Monika

Reputation: 143

Apply a function to each row of a Spark DataFrame

I have a Spark DataFrame (df) with two columns (Report_id and Cluster_number).

I want to apply a function (getClusterInfo) to df that returns the name of each cluster, i.e. if the cluster number is '3', then for a given report_id the three rows below should be written:

{"cluster_id":"1","influencers":[{"screenName":"A"},{"screenName":"B"},{"screenName":"C"},...]}
{"cluster_id":"2","influencers":[{"screenName":"D"},{"screenName":"E"},{"screenName":"F"},...]}
{"cluster_id":"3","influencers":[{"screenName":"G"},{"screenName":"H"},{"screenName":"E"},...]}

I am using foreach on df to apply the getClusterInfo function, but I can't figure out how to convert the output into a DataFrame of the shape (Report_id, Array[cluster_info]).

Here is the code snippet:

  df.foreach(row => {
    val report_id = row(0)
    val cluster_no = row(1).toString
    for (cluster <- 0 until cluster_no.toInt) {
      val cluster_id = report_id + "_" + cluster
      // get cluster influencers
      val result = getClusterInfo(cluster_id)
      println(result.get)
      val res: String = result.get.toString
      // TODO ?
    }
    // TODO ?
  })

Upvotes: 4

Views: 13653

Answers (1)

Tzach Zohar

Reputation: 37822

Generally speaking, you shouldn't use foreach when you want to map something into something else; foreach is good for applying functions that only have side effects and return nothing.
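A minimal plain-Scala illustration of the distinction (no Spark needed, since DataFrames follow the same pattern):

```scala
// foreach: applies a side-effecting function, returns Unit -
// nothing to collect afterwards
Seq(1, 2, 3).foreach(x => println(x))

// map: transforms each element into a new value, returning a new collection
val doubled = Seq(1, 2, 3).map(_ * 2)  // List(2, 4, 6)
```

With a DataFrame, the map-like operation is a select with a UDF, as shown below.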

In this case, if I got the details right (probably not), you can use a User-Defined Function (UDF) and explode the result:

import org.apache.spark.sql.functions._
import spark.implicits._

// I'm assuming we have these case classes (or similar)
case class Influencer(screenName: String)
case class ClusterInfo(cluster_id: String, influencers: Array[Influencer])

// I'm assuming this method is supplied - with your own implementation
def getClusterInfo(clusterId: String): ClusterInfo =
  ClusterInfo(clusterId, Array(Influencer(clusterId)))

// some sample data - assuming both columns are integers:
val df = Seq((222, 3), (333, 4)).toDF("Report_id", "Cluster_number")

// actual solution:

// UDF that returns an array of ClusterInfo;
// Array size is 'clusterNo', creates cluster id for each element and maps it to info
val clusterInfoUdf = udf { (clusterNo: Int, reportId: Int) =>
  (1 to clusterNo).map(v => s"${reportId}_$v").map(getClusterInfo)
}

// apply UDF to each record and explode - to create one record per array item
val result = df.select(explode(clusterInfoUdf($"Cluster_number", $"Report_id")))

result.printSchema()
// root
// |-- col: struct (nullable = true)
// |    |-- cluster_id: string (nullable = true)
// |    |-- influencers: array (nullable = true)
// |    |    |-- element: struct (containsNull = true)
// |    |    |    |-- screenName: string (nullable = true)

result.show(truncate = false)
// +-----------------------------+
// |col                          |
// +-----------------------------+
// |[222_1,WrappedArray([222_1])]|
// |[222_2,WrappedArray([222_2])]|
// |[222_3,WrappedArray([222_3])]|
// |[333_1,WrappedArray([333_1])]|
// |[333_2,WrappedArray([333_2])]|
// |[333_3,WrappedArray([333_3])]|
// |[333_4,WrappedArray([333_4])]|
// +-----------------------------+
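If you also want to keep Report_id next to each cluster's info (closer to the (Report_id, Array[cluster_info]) shape asked about), a sketch building on the same df and clusterInfoUdf: alias the exploded struct and then select its fields with dot notation (the "info" column name here is just a choice).

```scala
// keep Report_id alongside each exploded ClusterInfo struct
val withReport = df.select(
  $"Report_id",
  explode(clusterInfoUdf($"Cluster_number", $"Report_id")).as("info"))

// flatten the struct into top-level columns
withReport.select($"Report_id", $"info.cluster_id", $"info.influencers").show()
```

Alternatively, drop the explode entirely and select `$"Report_id"` with the raw UDF result to get one row per report with the full array of cluster info.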

Upvotes: 7
