Concatenating lists in PySpark

In my Spark DataFrame, one of the columns contains strings:

Activities
"1 1 1 1 0 0 0 0 0"
"0 0 0 1 1 1 0 0 0"
"1 1 1 1 0 0 0 0 0"
"0 0 0 1 1 1 0 0 0"
"1 1 1 1 0 0 0 0 0"
"0 0 0 1 1 1 0 0 0"

I want to collect the strings from every row of this column and concatenate them into one long string, then split that string into a single large integer array, like:

[1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0,...]

(Of course, one could split each string into a list first and then append all the lists into one big list, but the question of how to concatenate RDD-based lists remains.)

Using Python's local data structures, I can do:

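import pyspark.sql.functions as F

allActivities = []
# Gather every row's string into one list on the driver
activitiesListColumn = df.agg(F.collect_list("Activities").alias("Activities")).collect()[0]
for rowActivity in activitiesListColumn["Activities"]:
    # "1 1 1 ..." -> [1, 1, 1, ...] (convert tokens to int, not just strings)
    activities = [int(token) for token in rowActivity.split()]
    allActivities += activities
print(allActivities)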
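The loop above is a driver-side flatten: split each string, then concatenate the resulting lists. A plain-Python sketch of that step with `itertools.chain.from_iterable`, where the hypothetical `collected` list stands in for `activitiesListColumn["Activities"]`:

```python
from itertools import chain

# Hypothetical stand-in for activitiesListColumn["Activities"] after collect().
collected = ["1 1 1 1 0 0 0 0 0", "0 0 0 1 1 1 0 0 0"]

# Split every string, convert tokens to int, and chain the per-row results together.
all_activities = list(chain.from_iterable(
    (int(token) for token in row.split()) for row in collected
))
print(all_activities)
```

This is the same split-then-concatenate pattern that `flatMap` applies to an RDD, except that here everything runs on the driver.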

How can I do this with RDD-based (i.e., parallelized) data structures?


Answers (1)

lvnt

In MySQL this is possible with the GROUP_CONCAT aggregate function, but Spark SQL doesn't provide that function. We can, however, define a user-defined aggregate function (UDAF) that behaves like GROUP_CONCAT. For details about this UDAF, see the question "SPARK SQL replacement for mysql GROUP_CONCAT aggregate function". We only need to change the separator character from ',' to ' '. After that, you can try this line:

import org.apache.spark.sql.functions.col
df.agg(GroupConcat(col("Activities"))).show()

The GroupConcat object is:

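import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import scala.collection.mutable.ArrayBuffer

// GROUP_CONCAT-like aggregate that joins strings with a space.
// Note: UserDefinedAggregateFunction is deprecated since Spark 3.0.
object GroupConcat extends UserDefinedAggregateFunction {
  // One string input column
  def inputSchema = new StructType().add("x", StringType)
  // Intermediate buffer: the strings seen so far
  def bufferSchema = new StructType().add("buff", ArrayType(StringType))
  def dataType = StringType
  def deterministic = true

  def initialize(buffer: MutableAggregationBuffer) = {
    buffer.update(0, ArrayBuffer.empty[String])
  }

  // Append each non-null input string to the buffer
  def update(buffer: MutableAggregationBuffer, input: Row) = {
    if (!input.isNullAt(0))
      buffer.update(0, buffer.getSeq[String](0) :+ input.getString(0))
  }

  // Combine partial buffers coming from different partitions
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
    buffer1.update(0, buffer1.getSeq[String](0) ++ buffer2.getSeq[String](0))
  }

  // Join with ' ' instead of GROUP_CONCAT's ','; a plain String matches dataType
  def evaluate(buffer: Row) = buffer.getSeq[String](0).mkString(" ")
}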
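To see how the buffer methods fit together, here is a plain-Python model of the same lifecycle with no Spark involved: each partition starts from an empty buffer (`initialize`), appends its non-null strings (`update`), the partial buffers are combined (`merge`), and the result is joined with a space (`evaluate`). The partition split below is made up for illustration; splitting the final string then yields the integer list from the question.

```python
from functools import reduce

def initialize():
    # Empty buffer, like ArrayBuffer.empty[String] in the Scala version.
    return []

def update(buffer, value):
    # Skip nulls (None), mirroring the isNullAt(0) check.
    return buffer if value is None else buffer + [value]

def merge(buffer1, buffer2):
    # Concatenate two partial buffers.
    return buffer1 + buffer2

def evaluate(buffer):
    # Join with a space instead of GROUP_CONCAT's default comma.
    return " ".join(buffer)

# Hypothetical partitioning of the question's rows (one None shows null handling).
partitions = [
    ["1 1 1 1 0 0 0 0 0", None],
    ["0 0 0 1 1 1 0 0 0"],
]

partials = [reduce(update, part, initialize()) for part in partitions]
concatenated = evaluate(reduce(merge, partials, initialize()))
all_activities = [int(token) for token in concatenated.split()]
print(all_activities)
```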
