Reputation: 128111
Suppose I have types like these:
case class SomeType(id: String, x: Int, y: Int, payload: String)
case class Key(x: Int, y: Int)
Then suppose I did groupByKey on a Dataset[SomeType] like this:
val input: Dataset[SomeType] = ...
val grouped: KeyValueGroupedDataset[Key, SomeType] =
  input.groupByKey(s => Key(s.x, s.y))
Then suppose I have a function which determines which field I want to use in an aggregation:
val chooseDistinguisher: SomeType => String = _.id
And now I would like to run an aggregation function over the grouped dataset, for example functions.countDistinct, using the field obtained by the function:
grouped.agg(
  countDistinct(<something which depends on chooseDistinguisher>).as[Long]
)
The problem is, I cannot create a UDF from chooseDistinguisher, because countDistinct accepts a Column, and to turn a UDF into a Column you need to specify the input column names, which I cannot do: I do not know which name to use for the "values" of a KeyValueGroupedDataset.
I think it should be possible, because KeyValueGroupedDataset itself does something similar:

def count(): Dataset[(K, Long)] = agg(functions.count("*").as(ExpressionEncoder[Long]()))

However, this method cheats a bit because it uses "*" as the column name, whereas I need to refer to a particular column (i.e. the column holding the "value" in a key-value grouped dataset). Also, when you use typed functions from the typed object, you do not need to specify the column name either, and it somehow works.
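For illustration, this is the kind of thing I mean: the typed aggregators take a Scala function rather than a column name (a sketch, assuming the org.apache.spark.sql.expressions.scalalang.typed helpers are available):

import org.apache.spark.sql.expressions.scalalang.typed

// typed aggregators take a function instead of a column name,
// so nothing has to be known about the name of the "value" column:
val perGroupCounts: Dataset[(Key, Long)] = grouped.agg(typed.count[SomeType](_.id))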
So, is it possible to do this, and if so, how?
Upvotes: 0
Views: 1489
Reputation: 13548
Currently, this use case is better handled with DataFrame, which you can later convert back into a Dataset[A].
// Code assumes SQLContext implicits are present (needed for the 'x symbol-to-column syntax)
import org.apache.spark.sql.{functions => f}

val colName = "id"

ds.toDF
  .withColumn("key", f.concat('x, f.lit(":"), 'y))
  .groupBy('key)
  .agg(f.countDistinct(f.col(colName)).as("cntd"))
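If you then want a typed Dataset back, a sketch of the conversion could look like this (assuming a hypothetical Result case class that matches the aggregated columns and spark.implicits._ in scope for its encoder):

// hypothetical result type matching the "key" and "cntd" columns above
case class Result(key: String, cntd: Long)

val result: Dataset[Result] =
  ds.toDF
    .withColumn("key", f.concat('x, f.lit(":"), 'y))
    .groupBy('key)
    .agg(f.countDistinct(f.col(colName)).as("cntd"))
    .as[Result]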
Upvotes: 0
Reputation: 689
As far as I know, this is not possible with the agg transformation, which expects a TypedColumn that is constructed from a Column using the as method, so you need to start from a non-type-safe expression. If somebody knows a solution, I would be interested to see it...

If you need type-safe aggregation, you can use one of the approaches below:
- mapGroups, where you implement a Scala function responsible for aggregating the Iterator of values
- Aggregator, as suggested above (a minimal sketch of this option is shown after the example below)

The first approach needs less code, so here is a quick example:
// counts distinct values of the chosen field within a single group
def countDistinct[T](values: Iterator[T])(chooseDistinguisher: T => String): Long =
  values.map(chooseDistinguisher).toSeq.distinct.size

ds
  .groupByKey(s => Key(s.x, s.y))
  .mapGroups((k, vs) => (k, countDistinct(vs)(_.id)))
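For completeness, a minimal sketch of the Aggregator option could look like this (it collects the distinguisher values into a Set, so it is only sensible when the per-group cardinality is small):

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

// sketch: distinct count over the field chosen by chooseDistinguisher
def distinctCount[T](chooseDistinguisher: T => String): Aggregator[T, Set[String], Long] =
  new Aggregator[T, Set[String], Long] {
    def zero: Set[String] = Set.empty
    def reduce(buf: Set[String], value: T): Set[String] = buf + chooseDistinguisher(value)
    def merge(b1: Set[String], b2: Set[String]): Set[String] = b1 ++ b2
    def finish(buf: Set[String]): Long = buf.size.toLong
    def bufferEncoder: Encoder[Set[String]] = Encoders.kryo[Set[String]]
    def outputEncoder: Encoder[Long] = Encoders.scalaLong
  }

// usage:
// ds.groupByKey(s => Key(s.x, s.y))
//   .agg(distinctCount[SomeType](_.id).toColumn)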
In my opinion, the type-safe Spark Dataset API is still much less mature than the non-type-safe DataFrame API. Some time ago I thought it would be a good idea to implement an easy-to-use, type-safe aggregation API for Spark Datasets.
Upvotes: 1