Vladimir Matveev

Reputation: 128111

Dataset.groupByKey + untyped aggregation functions

Suppose I have types like these:

case class SomeType(id: String, x: Int, y: Int, payload: String)
case class Key(x: Int, y: Int)

Then suppose I did groupByKey on a Dataset[SomeType] like this:

val input: Dataset[SomeType] = ...

val grouped: KeyValueGroupedDataset[Key, SomeType] =
  input.groupByKey(s => Key(s.x, s.y))

Then suppose I have a function which determines which field I want to use in an aggregation:

val chooseDistinguisher: SomeType => String = _.id

And now I would like to run an aggregation function over the grouped dataset, for example, functions.countDistinct, using the field obtained by the function:

grouped.agg(
  countDistinct(<something which depends on chooseDistinguisher>).as[Long]
)

The problem is that I cannot create a UDF from chooseDistinguisher, because countDistinct accepts a Column, and turning a UDF into a Column requires specifying the input column names - which I cannot do, since I do not know which name to use for the "values" of a KeyValueGroupedDataset.

I think it should be possible, because KeyValueGroupedDataset itself does something similar:

def count(): Dataset[(K, Long)] = agg(functions.count("*").as(ExpressionEncoder[Long]()))

However, this method cheats a bit because it uses "*" as the column name, whereas I need to refer to a particular column (i.e. the column holding the "value" in a key-value grouped dataset). Also, when you use typed functions from the typed object, you do not need to specify the column name either, and it works somehow.
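For reference, this is what I mean by the typed functions working without a column name (a sketch against Spark 2.x's scalalang.typed object; note that typed.count only counts values, it does not count distinct ones, which is why it does not solve my problem):

import org.apache.spark.sql.expressions.scalalang.typed

// Typed aggregators take a function of the value type, not a column name
val counted: Dataset[(Key, Long)] =
  grouped.agg(typed.count[SomeType](chooseDistinguisher))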

So, is it possible to do this, and if it is, how can I do it?

Upvotes: 0

Views: 1489

Answers (2)

Sim

Reputation: 13548

Currently, this use case is better handled with DataFrame, which you can later convert back into a Dataset[A].

// Code assumes SQLContext implicits are present
import org.apache.spark.sql.{functions => f}

val colName = "id"
ds.toDF
  .withColumn("key", f.concat('x, f.lit(":"), 'y))
  .groupBy('key)
  .agg(f.countDistinct(f.col(colName)).as("cntd"))
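If you need a typed result afterwards, the aggregated DataFrame can be converted back with as[...]; here KeyCount is just a hypothetical result type for this sketch:

// Hypothetical result type used to convert the aggregated DataFrame
// back into a typed Dataset, as mentioned above
case class KeyCount(key: String, cntd: Long)

val typedResult: Dataset[KeyCount] =
  ds.toDF
    .withColumn("key", f.concat('x, f.lit(":"), 'y))
    .groupBy('key)
    .agg(f.countDistinct(f.col(colName)).as("cntd"))
    .as[KeyCount]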

Upvotes: 0

Piotr Kalański

Reputation: 689

As far as I know, it's not possible with the agg transformation, which expects a TypedColumn; a TypedColumn is constructed from a Column using the as method, so you have to start from a non-type-safe expression. If somebody knows a solution, I would be interested to see it...

If you need type-safe aggregation, you can use one of the approaches below:

  • mapGroups - where you implement a Scala function responsible for aggregating the Iterator of values
  • implement your own custom Aggregator, as suggested above (see the sketch after the mapGroups example below)

The first approach needs less code, so below I'm showing a quick example:

// Counts distinct projections of the grouped values in plain Scala
def countDistinct[T](values: Iterator[T])(chooseDistinguisher: T => String): Long =
  values.map(chooseDistinguisher).toSeq.distinct.size.toLong

ds
  .groupByKey(s => Key(s.x, s.y))
  .mapGroups((k, vs) => (k, countDistinct(vs)(_.id)))
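For the second approach, a custom Aggregator could look roughly like this (a sketch; distinctCount is just a name I made up, and it assumes the number of distinct values per group is small enough to buffer in a Set):

import org.apache.spark.sql.{Encoder, Encoders, TypedColumn}
import org.apache.spark.sql.expressions.Aggregator

// Buffers the distinct projected values in a Set and returns its size
def distinctCount[T](chooseDistinguisher: T => String): TypedColumn[T, Long] =
  new Aggregator[T, Set[String], Long] {
    def zero: Set[String] = Set.empty
    def reduce(buffer: Set[String], value: T): Set[String] = buffer + chooseDistinguisher(value)
    def merge(b1: Set[String], b2: Set[String]): Set[String] = b1 union b2
    def finish(buffer: Set[String]): Long = buffer.size.toLong
    def bufferEncoder: Encoder[Set[String]] = Encoders.kryo[Set[String]]
    def outputEncoder: Encoder[Long] = Encoders.scalaLong
  }.toColumn

ds
  .groupByKey(s => Key(s.x, s.y))
  .agg(distinctCount[SomeType](_.id))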

In my opinion, the Spark Dataset type-safe API is still much less mature than the non-type-safe DataFrame API. Some time ago I was thinking that it could be a good idea to implement a simple-to-use type-safe aggregation API for Spark Datasets.

Upvotes: 1
