Reputation: 169
I have a Spark DataFrame with the following structure:
root
|-- topicDistribution: vector (nullable = true)
+--------------------+
| topicDistribution|
+--------------------+
| [0.1, 0.2] |
| [0.3, 0.2] |
| [0.5, 0.2] |
| [0.1, 0.7] |
| [0.1, 0.8] |
| [0.1, 0.9] |
+--------------------+
My question is: how do I add a column containing the index of the maximum value in each row?
It should be something like this:
root
|-- topicDistribution: vector (nullable = true)
|-- max_index: integer (nullable = true)
+--------------------+-----------+
| topicDistribution| max_index |
+--------------------+-----------+
| [0.1, 0.2] | 1 |
| [0.3, 0.2] | 0 |
| [0.5, 0.2] | 0 |
| [0.1, 0.7] | 1 |
| [0.1, 0.8] | 1 |
| [0.1, 0.9] | 1 |
+--------------------+-----------+
Thanks a lot.
I tried the following method, but I got an error:
import org.apache.spark.sql.functions.udf
val func = udf( (x: Vector[Double]) => x.indices.maxBy(x) )
df.withColumn("max_idx", func($"topicDistribution")).show()
Error says:
Exception in thread "main" org.apache.spark.sql.AnalysisException:
cannot resolve 'UDF(topicDistribution)' due to data type mismatch:
argument 1 requires array<double> type, however, '`topicDistribution`'
is of vector type.;;
Upvotes: 4
Views: 2481
Reputation: 74779
NOTE: The solution may not be the best performance-wise, but it shows another way to tackle the problem (and how rich Spark SQL's Dataset API is).
vector is from Spark MLlib's VectorUDT, so let me create a sample dataset first.
val input = Seq((0.1, 0.2), (0.3, 0.2)).toDF
import org.apache.spark.ml.feature.VectorAssembler
val vecAssembler = new VectorAssembler()
  .setInputCols(Array("_1", "_2"))
  .setOutputCol("distribution")
val ds = vecAssembler.transform(input).select("distribution")
scala> ds.printSchema
root
|-- distribution: vector (nullable = true)
The schema looks exactly like yours.
Let's change the type from VectorUDT to a regular Array[Double].
import org.apache.spark.ml.linalg.Vector
val arrays = ds
  .map { r => r.getAs[Vector](0).toArray }
  .withColumnRenamed("value", "distribution")
scala> arrays.printSchema
root
|-- distribution: array (nullable = true)
| |-- element: double (containsNull = false)
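(Side note: assuming Spark 3.0 or later, the built-in vector_to_array function from org.apache.spark.ml.functions does the same conversion without a typed map; arrays30 below is just an illustrative name.)
import org.apache.spark.ml.functions.vector_to_array
// Convert the VectorUDT column to array<double> in a single select
val arrays30 = ds.select(vector_to_array($"distribution") as "distribution")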
With arrays you could use posexplode to index the elements, groupBy with max to find the maximum value per row, and a join plus filter to recover the position of that maximum.
import org.apache.spark.sql.functions.{posexplode, max}
// Explode each array into (pos, col) pairs, keeping the original row
val pos = arrays.select($"*", posexplode($"distribution"))
// Find the maximum value per row
val max_cols = pos
  .groupBy("distribution")
  .agg(max("col") as "max_col")
// Keep only the positions where the value equals the row's maximum
val solution = pos
  .join(max_cols, "distribution")
  .filter($"col" === $"max_col")
  .select("distribution", "pos")
scala> solution.show
+------------+---+
|distribution|pos|
+------------+---+
| [0.1, 0.2]| 1|
| [0.3, 0.2]| 0|
+------------+---+
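(Side note: if the maximum value appears at several positions, the join approach returns one row per tied position. Assuming Spark 2.4 or later, the groupBy and join can be avoided entirely with the built-in array_max and array_position functions, which return the first matching position; solution24 below is just an illustrative name.)
import org.apache.spark.sql.functions.{array_max, array_position}
// array_position is 1-based, hence the - 1
val solution24 = arrays
  .withColumn("pos", array_position($"distribution", array_max($"distribution")) - 1)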
Upvotes: 1
Reputation: 3110
// create some sample data:
import org.apache.spark.mllib.linalg.{Vectors, Vector}
case class myrow(topics: Vector)
val rdd = sc.parallelize(Array(myrow(Vectors.dense(0.1, 0.2)), myrow(Vectors.dense(0.6, 0.2))))
val mydf = sqlContext.createDataFrame(rdd)
mydf.show()
+----------+
| topics|
+----------+
|[0.1, 0.2]|
|[0.6, 0.2]|
+----------+
// build the udf
import org.apache.spark.sql.functions.udf
val func = udf( (x: Vector) => x.toDense.values.toSeq.indices.maxBy(x.toDense.values) )
mydf.withColumn("max_idx",func($"topics")).show()
+----------+-------+
| topics|max_idx|
+----------+-------+
|[0.1, 0.2]| 1|
|[0.6, 0.2]| 0|
+----------+-------+
// note: you might have to adjust the UDF argument type for your particular use case
// (edited to take a Vector instead of a Seq, as the original question and your comment asked)
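A shorter variant, assuming Spark 1.5 or later, where MLlib Vectors expose an argmax method returning the index of the first maximum value (argmaxUdf is just an illustrative name):
// same result via the built-in Vector.argmax
val argmaxUdf = udf( (x: Vector) => x.argmax )
mydf.withColumn("max_idx", argmaxUdf($"topics")).show()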
Upvotes: 3