Nick Resnick

Reputation: 621

How to flatten columns of type array of structs (as returned by Spark ML API)?

Maybe it's just because I'm relatively new to the API, but I feel like Spark ML methods often return DFs that are unnecessarily difficult to work with.

This time, it's the ALS model that's tripping me up. Specifically, the recommendForAllUsers method. Let's reconstruct the type of DF it would return:

scala> import org.apache.spark.sql.types._

scala> val arrayType = ArrayType(new StructType().add("itemId", IntegerType).add("rating", FloatType))

scala> val recs = Seq((1, Array((1, .7), (2, .5))), (2, Array((0, .9), (4, .1)))).
  toDF("userId", "recommendations").
  select($"userId", $"recommendations".cast(arrayType))

scala> recs.show()
+------+------------------+
|userId|   recommendations|
+------+------------------+
|     1|[[1,0.7], [2,0.5]]|
|     2|[[0,0.9], [4,0.1]]|
+------+------------------+

scala> recs.printSchema
root
 |-- userId: integer (nullable = false)
 |-- recommendations: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- itemId: integer (nullable = true)
 |    |    |-- rating: float (nullable = true)

Now, I only care about the itemId values in the recommendations column. After all, the method is called recommendForAllUsers, not recommendAndScoreForAllUsers (ok ok, I'll stop being sassy...)

How do I do this??

I thought I had it when I created a UDF:

scala> val itemIds = udf((arr: Array[(Int, Float)]) => arr.map(_._1))

but that produces an error:

scala> recs.withColumn("items", itemIds($"recommendations"))
org.apache.spark.sql.AnalysisException: cannot resolve 'UDF(recommendations)' due to data type mismatch: argument 1 requires array<struct<_1:int,_2:float>> type, however, '`recommendations`' is of array<struct<itemId:int,rating:float>> type.;;
'Project [userId#87, recommendations#92, UDF(recommendations#92) AS items#238]
+- Project [userId#87, cast(recommendations#88 as array<struct<itemId:int,rating:float>>) AS recommendations#92]
   +- Project [_1#84 AS userId#87, _2#85 AS recommendations#88]
      +- LocalRelation [_1#84, _2#85]

Any ideas? thanks!

Upvotes: 7

Views: 14268

Answers (2)

Nick Resnick

Reputation: 621

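Wow, my coworker came up with an extremely elegant solution. Selecting a struct field through an array column returns an array containing that field from every element: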

scala> recs.select($"userId", $"recommendations.itemId").show
+------+------+
|userId|itemId|
+------+------+
|     1|[1, 2]|
|     2|[0, 4]|
+------+------+
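For anyone who wants to try this outside the shell, here is a minimal self-contained sketch of the same field access. It assumes a local SparkSession (the master setting and app name are placeholders) and rebuilds the sample data from the question; the alias itemIds is just an illustrative name.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types._

// Local session for experimentation only; in spark-shell `spark` already exists.
val spark = SparkSession.builder().master("local[*]").appName("flatten-recs").getOrCreate()
import spark.implicits._

// Same shape as the output of recommendForAllUsers, rebuilt as in the question.
val arrayType = ArrayType(new StructType().add("itemId", IntegerType).add("rating", FloatType))
val recs = Seq((1, Array((1, 0.7), (2, 0.5))), (2, Array((0, 0.9), (4, 0.1))))
  .toDF("userId", "recommendations")
  .select($"userId", $"recommendations".cast(arrayType))

// Accessing a struct field through the array yields an array of that field.
val itemsOnly = recs.select($"userId", $"recommendations.itemId".as("itemIds"))
itemsOnly.printSchema()
```

The result keeps one row per user, so it suits cases where the list of item IDs should stay together.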

So maybe the Spark ML API isn't that difficult after all :)

Upvotes: 9

Jacek Laskowski

Reputation: 74669

With an array as the type of a column, e.g. recommendations, you'd be quite productive using the explode function (or the more advanced flatMap operator).

explode(e: Column): Column
Creates a new row for each element in the given array or map column.

That gives you bare structs to work with.

import org.apache.spark.sql.types._
val structType = new StructType().
  add($"itemId".int).
  add($"rating".float)
val arrayType = ArrayType(structType)
val recs = Seq((1, Array((1, .7), (2, .5))), (2, Array((0, .9), (4, .1)))).
  toDF("userId", "recommendations").
  select($"userId", $"recommendations" cast arrayType)

val exploded = recs.withColumn("recs", explode($"recommendations"))

scala> exploded.show
+------+------------------+-------+
|userId|   recommendations|   recs|
+------+------------------+-------+
|     1|[[1,0.7], [2,0.5]]|[1,0.7]|
|     1|[[1,0.7], [2,0.5]]|[2,0.5]|
|     2|[[0,0.9], [4,0.1]]|[0,0.9]|
|     2|[[0,0.9], [4,0.1]]|[4,0.1]|
+------+------------------+-------+

Structs work nicely with the select operator: * (star) flattens a struct into one column per field.

With the exploded column named recs as above, you could do select($"recs.*").

scala> exploded.select("userId", "recs.*").show
+------+------+------+
|userId|itemId|rating|
+------+------+------+
|     1|     1|   0.7|
|     1|     2|   0.5|
|     2|     0|   0.9|
|     2|     4|   0.1|
+------+------+------+
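The two steps above can also be chained so that the original array column is never carried along. This is a minimal sketch under the same assumptions as the sample data in the question; the local SparkSession settings and the names rec and flat are illustrative only.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.explode
import org.apache.spark.sql.types._

// Local session for experimentation only; in spark-shell `spark` already exists.
val spark = SparkSession.builder().master("local[*]").appName("explode-recs").getOrCreate()
import spark.implicits._

val arrayType = ArrayType(new StructType().add("itemId", IntegerType).add("rating", FloatType))
val recs = Seq((1, Array((1, 0.7), (2, 0.5))), (2, Array((0, 0.9), (4, 0.1))))
  .toDF("userId", "recommendations")
  .select($"userId", $"recommendations".cast(arrayType))

// One row per (user, recommendation), then one column per struct field.
val flat = recs
  .select($"userId", explode($"recommendations").as("rec"))
  .select($"userId", $"rec.*")
flat.printSchema()
```

Using select instead of withColumn here drops the recommendations column up front, so each exploded row does not repeat the full array.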

I think that could do what you're after.


p.s. Stay away from UDFs as long as possible, since they "trigger" conversion of each row from the internal format (InternalRow) to JVM objects, which can lead to excessive GC.
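If a UDF really is needed, the error in the question comes from declaring the argument as an array of tuples: the declared type maps to struct<_1:int,_2:float>, which does not match the named fields. A common workaround, sketched here with the question's sample data and a placeholder local SparkSession, is to accept the structs as Row objects and read fields by name. It does not handle a null recommendations array.

```scala
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types._

// Local session for experimentation only; in spark-shell `spark` already exists.
val spark = SparkSession.builder().master("local[*]").appName("udf-recs").getOrCreate()
import spark.implicits._

val arrayType = ArrayType(new StructType().add("itemId", IntegerType).add("rating", FloatType))
val recs = Seq((1, Array((1, 0.7), (2, 0.5))), (2, Array((0, 0.9), (4, 0.1))))
  .toDF("userId", "recommendations")
  .select($"userId", $"recommendations".cast(arrayType))

// Structs reach a Scala UDF as Row objects, so read the field by name
// instead of pattern matching on tuples.
val itemIds = udf((rs: Seq[Row]) => rs.map(_.getAs[Int]("itemId")))

val withItems = recs.withColumn("items", itemIds($"recommendations"))
withItems.printSchema()
```

The built-in field access and explode shown earlier remain preferable where they suffice, for the reason given above.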

Upvotes: 5
