Raphael Roth

Reputation: 27373

Extract columns in nested Spark DataFrame as scala Arrays

I have a DataFrame myDf which contains an array of points (each a pair of x and y coordinates). It has the following schema:

myDf.printSchema

root
 |-- pts: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- x: float (nullable = true)
 |    |    |-- y: float (nullable = true)

I want to get x and y as individual plain Scala Arrays. I think I need to apply the explode function, but I cannot figure out how. I tried to apply this solution, but I can't get it to work.

I'm using Spark 1.6.1 with Scala 2.10.

EDIT: I realize that I had misunderstood how Spark works: the actual arrays can only be obtained by collecting the data to the driver (or inside UDFs).


Answers (2)

Raphael Roth

Reputation: 27373

There are two ways to get the points as plain Scala Arrays:

By collecting to the driver:

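import org.apache.spark.sql.Row
val localRows: Array[Row] = myDf.take(10) // first 10 rows, each holding one "pts" array
val points: Array[Row] = localRows.flatMap(_.getAs[Seq[Row]]("pts"))
val xs: Array[Float] = points.map(_.getAs[Float]("x"))
val ys: Array[Float] = points.map(_.getAs[Float]("y"))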
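Once the rows are on the driver, splitting them into xs and ys is ordinary Scala collection work. The sketch below leaves Spark out entirely: the `Point` case class and the sample values are hypothetical stand-ins for the collected `Row` objects, so the flatten-then-unzip step can be run on its own.

```scala
// Hypothetical stand-in for one element of the "pts" array (not a Spark type)
case class Point(x: Float, y: Float)

// Each inner Seq models the "pts" array of one collected row
val collectedRows: Seq[Seq[Point]] = Seq(
  Seq(Point(0.0f, 0.1f), Point(1.0f, 1.1f)),
  Seq(Point(2.0f, 2.1f))
)

// Flatten all rows' arrays into one sequence of points, then split the pairs
val (xs, ys) = collectedRows.flatten.map(p => (p.x, p.y)).unzip
val xsArr: Array[Float] = xs.toArray
val ysArr: Array[Float] = ys.toArray
```

Because the arrays of all collected rows are flattened first, xsArr and ysArr cover every point across those rows, not just one row.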

Or inside a UDF:

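import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.udf

// Spark passes an array<struct> column to a Scala UDF as a Seq[Row] (a WrappedArray)
val processArr = udf((pts: Seq[Row]) => {
  val xs: Array[Float] = pts.map(_.getAs[Float]("x")).toArray
  val ys: Array[Float] = pts.map(_.getAs[Float]("y")).toArray
  //...do something with xs and ys; a UDF must return a type Spark supports,
  // so this placeholder just returns the number of points
  xs.length
})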
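A UDF sees the array of a single row at a time, so the xs and ys it builds are per row, not across the whole DataFrame. The plain-Scala sketch below makes that difference from the collect approach visible; the `Point` case class, the `splitRow` function and the sample rows are hypothetical, and `rows.map` stands in for Spark applying the UDF to each row.

```scala
// Hypothetical stand-in for one element of the "pts" array
case class Point(x: Float, y: Float)

// Models the UDF body: it receives the "pts" array of ONE row
def splitRow(pts: Seq[Point]): (Array[Float], Array[Float]) = {
  val xs: Array[Float] = pts.map(_.x).toArray
  val ys: Array[Float] = pts.map(_.y).toArray
  (xs, ys)
}

// Two rows of a hypothetical DataFrame; the function is applied per row,
// so there is one (xs, ys) pair per row rather than one pair for the whole table
val rows: Seq[Seq[Point]] = Seq(
  Seq(Point(0.0f, 0.1f), Point(1.0f, 1.1f)),
  Seq(Point(2.0f, 2.1f))
)
val perRow: Seq[(Array[Float], Array[Float])] = rows.map(splitRow)
```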


Yuan JI

Reputation: 2995

Assuming myDf is a DataFrame read from a JSON file like this:

{
 "pts":[
    {
     "x":0.0,
     "y":0.1
    },
    {
     "x":1.0,
     "y":1.1
    },
    {
     "x":2.0,
     "y":2.1
    }
  ]
}

You can use explode like this:

Java:

import org.apache.spark.sql.DataFrame;
import static org.apache.spark.sql.functions.explode;

DataFrame pts = myDf.select(explode(myDf.col("pts")).as("pts"))
                    .select("pts.x", "pts.y");
pts.printSchema();
pts.show();

Scala:

// Translated from the Java code above; it needs these imports
// (sqlContext is the SQLContext in scope, e.g. in spark-shell)
import org.apache.spark.sql.functions.explode
import sqlContext.implicits._

val pts = myDf.select(explode($"pts").as("pts"))
              .select($"pts.x", $"pts.y")
pts.printSchema()
pts.show()
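Conceptually, `explode` behaves like `flatMap`: each element of a row's array becomes its own output row, and the following `select` projects each struct onto its fields. The plain-Scala sketch below mirrors that on the sample JSON; the `Point` and `Record` case classes are hypothetical stand-ins for the inferred schema, not Spark types.

```scala
// Hypothetical stand-ins for the JSON records (Spark infers double for these numbers)
case class Point(x: Double, y: Double)
case class Record(pts: Seq[Point])

val records: Seq[Record] = Seq(
  Record(Seq(Point(0.0, 0.1), Point(1.0, 1.1), Point(2.0, 2.1))),
  Record(Seq.empty) // a row with an empty array yields no exploded rows
)

// explode($"pts") emits one output row per array element, i.e. a flatMap ...
val exploded: Seq[Point] = records.flatMap(_.pts)
// ... and select($"pts.x", $"pts.y") then projects each struct onto its fields
val table: Seq[(Double, Double)] = exploded.map(p => (p.x, p.y))
```

As with `flatMap`, a row whose `pts` array is empty contributes nothing, which matches how `explode` (the non-outer variant) drops such rows.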

Here is the printed schema (JSON schema inference produces double rather than float):

root
 |-- x: double (nullable = true)
 |-- y: double (nullable = true)

And here is the pts.show() result:

+---+---+
|  x|  y|
+---+---+
|0.0|0.1|
|1.0|1.1|
|2.0|2.1|
+---+---+
