I want to count the number of distinct rows according to one column, i.e. the number of distinct values in that column. The following works:
long countDistinctAtt = Math.toIntExact(dataset.select(att).distinct().count());
But this doesn't:
long countDistinctAtt = dataset.agg(countDistinct(att)).agg(count("*")).collectAsList().get(0).getLong(0);
Why doesn't the second solution calculate the number of distinct rows?
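The second command chains two aggregations. Called without a preceding groupBy, agg aggregates over the entire DataFrame/Dataset, so dataset.agg(countDistinct(att)) already returns a single row that contains the distinct count. The chained agg(count("*")) then counts the rows of that one-row result, so the command always returns 1. Dropping the second aggregation, as in dataset.agg(countDistinct(att)).collectAsList().get(0).getLong(0), returns the distinct count instead.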
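Here is a small plain-Java model of what the two chained agg calls compute. It does not use Spark: the class and method names are hypothetical, and the input is a made-up integer column. The first method mimics a whole-Dataset countDistinct (one output row), and the second mimics count("*") applied to that output.

```java
import java.util.HashSet;
import java.util.List;

// Plain-Java model of the question's second command (no Spark involved).
public class ChainedAggModel {

    // Models dataset.agg(countDistinct(att)): an aggregation over the whole
    // dataset, which yields exactly one row holding the distinct count.
    public static List<Long> aggCountDistinct(List<Integer> column) {
        long distinct = new HashSet<>(column).size();
        return List.of(distinct);
    }

    // Models the chained .agg(count("*")): it counts the rows of its input,
    // which is the one-row result produced above.
    public static long aggCountStar(List<Long> rows) {
        return rows.size();
    }

    public static void main(String[] args) {
        List<Integer> att = List.of(5, 8, 14, 20, 8, 8, 20);
        List<Long> firstAgg = aggCountDistinct(att);
        System.out.println(firstAgg.get(0));         // 4, the value the question wants
        System.out.println(aggCountStar(firstAgg));  // 1, the number of rows in firstAgg
    }
}
```

The second aggregation never sees the original rows, only the single summary row, which is why its answer is 1 regardless of the data.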
Adding a groupBy would not help either. With grouping, the aggregations are computed per group rather than over the entire DataFrame/Dataset, so the result is a column of values, one per group, instead of a single total count. Taking the first of those values with get(0) would only give you the result for one group.
The first command avoids these issues: it selects only the column you care about, removes duplicate values with distinct, and counts the remaining rows. This yields a single long value, which you convert to int with Math.toIntExact.
As a rule of thumb, use groupBy/agg when you want per-group computations. When you do not care about groups and just want a single result for the whole DataFrame/Dataset, you can use Spark's built-in operations and SQL functions, as in the first command (the SQL functions are listed in the Spark API documentation, with separate pages for Java, Scala and Python).
To illustrate this, let's say we have a DataFrame (or Dataset, it doesn't matter at this point) named dfTest with the following data:
+------+------+
|letter|number|
+------+------+
|     a|     5|
|     b|     8|
|     c|    14|
|     d|    20|
|     e|     8|
|     f|     8|
|     g|    20|
+------+------+
If we use the basic built-in operations to select the number column, filter out the duplicates, and count the remaining rows, the command correctly outputs 4, because there are indeed 4 unique values in number:
// In Scala:
println(dfTest.select("number").distinct().count().toInt)
// In Java:
System.out.println(Math.toIntExact(dfTest.select("number").distinct().count()));
// Output:
4
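The order of select and distinct matters here. This rough plain-Java model of dfTest (not Spark; the Row record and method names are hypothetical stand-ins) projects to number before deduplicating, which gives 4, and also deduplicates whole rows, which keeps all 7 because every letter is different.

```java
import java.util.List;

// Plain-Java model of dfTest; Row is a hypothetical stand-in for a Spark Row.
public class DistinctCountModel {

    public record Row(String letter, int number) {}

    public static final List<Row> DF_TEST = List.of(
        new Row("a", 5), new Row("b", 8), new Row("c", 14), new Row("d", 20),
        new Row("e", 8), new Row("f", 8), new Row("g", 20));

    // Like dfTest.select("number").distinct().count(): project first, then deduplicate.
    public static long distinctNumbers(List<Row> rows) {
        return rows.stream().map(Row::number).distinct().count();
    }

    // Like dfTest.distinct().count(): deduplicates whole rows, so every row here stays.
    public static long distinctRows(List<Row> rows) {
        return rows.stream().distinct().count();
    }

    public static void main(String[] args) {
        System.out.println(Math.toIntExact(distinctNumbers(DF_TEST))); // 4
        System.out.println(Math.toIntExact(distinctRows(DF_TEST)));    // 7
    }
}
```

Records compare by all their components, which mirrors how distinct on a full row compares every column.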
On the other hand, if we group the rows by number and count the rows in each group (no need to use agg here, since count() on the grouped data counts the rows of each group directly), we get the following DataFrame, where the count is calculated separately for each distinct value of the number column:
// In Scala & Java:
dfTest.groupBy("number")
.count()
.show();
// Output:
+------+-----+
|number|count|
+------+-----+
|    20|    2|
|     5|    1|
|     8|    3|
|    14|    1|
+------+-----+
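The grouped result has one row per distinct value of number, so its row count equals the distinct count from the first command. A plain-Java sketch of the same grouping (no Spark; the class and method names are hypothetical) uses Collectors.groupingBy with Collectors.counting(). The TreeMap is there only to make the printed order deterministic; Spark does not guarantee the order of groupBy output, as the table above shows.

```java
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.function.Function;
import java.util.stream.Collectors;

// Plain-Java model of dfTest.groupBy("number").count().
public class GroupCountModel {

    // One entry per distinct value (the group key), mapped to its row count.
    public static Map<Integer, Long> countPerNumber(List<Integer> numbers) {
        return numbers.stream().collect(
            Collectors.groupingBy(Function.identity(), TreeMap::new, Collectors.counting()));
    }

    public static void main(String[] args) {
        List<Integer> number = List.of(5, 8, 14, 20, 8, 8, 20);
        Map<Integer, Long> counts = countPerNumber(number);
        System.out.println(counts);         // {5=1, 8=3, 14=1, 20=2}
        System.out.println(counts.size());  // 4 groups, i.e. 4 distinct values
    }
}
```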