ParagM

Reputation: 73

processing group of rows from spark dataset in scala

I am looking for a way to divide my large Spark dataset into groups/batches and process each group of rows in some function. Basically, a group of rows should be the input to my function, and the output is Unit, because I don't want to aggregate or update the input records, just perform some calculation on them.

Just to understand, let's say I have the following input.

Col1 Col2 Col3
1    A    1
1    B    2
1    C    3
1    A    4
1    A    5
1    C    6
2    X    7
2    X    8

and let's say I need to group by Col1 and Col2, which will give me the following groups:

(1, A, 1), (1, A, 4), (1, A, 5) ---> first group

(1, B, 2) ---> second group

(1, C, 3), (1, C, 6) ---> third group

(2, X, 7), (2, X, 8) ---> fourth group

So I want to pass these groups to my function to perform some logic. For now, let's just say the function sums Col3 (this is not my actual requirement, but assume I want to do that summation in a separate method), generating the following output.

Col1 Col2 Col3
1    A    10
1    B    2
1    C    9
2    X    15

How can I achieve this? Based on some suggestions I tried to look at UDAFs, but I couldn't figure out how to use them. Please note that my real input dataset has more than 500 million records. Thanks.

Upvotes: 0

Views: 756

Answers (1)

Gaarv

Reputation: 824

Here's a simple example (in PySpark) based on your input to get you started:

    from pyspark.sql.types import IntegerType
    from pyspark.sql import SparkSession
    from pyspark.sql import functions as F
    
    spark = SparkSession.builder.getOrCreate()
    
    data = [
        (1, "A",  1),
        (1, "B",  2),
        (1, "C",  3),
        (1, "A",  4),
        (1, "A",  5),
        (1, "C",  6),
        (2, "X",  7),
        (2, "X",  8),
    ]
    
    df = spark.createDataFrame(data, ["col1", "col2", "col3"])
    df.show()
    
    +----+----+----+
    |col1|col2|col3|
    +----+----+----+
    |   1|   A|   1|
    |   1|   B|   2|
    |   1|   C|   3|
    |   1|   A|   4|
    |   1|   A|   5|
    |   1|   C|   6|
    |   2|   X|   7|
    |   2|   X|   8|
    +----+----+----+
    
    # define your function - pure Python here, no Spark needed
    def dummy_f(xs):
      return sum(xs)
    
    
    # apply your function as UDF - needs input function and return type (integer here)
    (
      df
      .groupBy(F.col("col1"), F.col("col2"))
      .agg(F.collect_list(F.col("col3").cast("int")).alias("col3"))
      .withColumn("col3sum", F.udf(dummy_f, IntegerType())(F.col("col3")))
    ).show()
    
    +----+----+---------+-------+
    |col1|col2|     col3|col3sum|
    +----+----+---------+-------+
    |   1|   A|[1, 4, 5]|     10|
    |   1|   B|      [2]|      2|
    |   1|   C|   [3, 6]|      9|
    |   2|   X|   [7, 8]|     15|
    +----+----+---------+-------+

The key is aggregating the columns your input function needs: you can use create_map to build a map (dict) per row, or collect_list as shown here.
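For instance, here is a minimal sketch of the create_map route, reusing `df`, `F` and `IntegerType` from the snippet above (`sum_from_maps` and the `kv` column are just placeholder names, not a definitive implementation):

    # each row becomes a small {col2: col3} map; the UDF then receives a list of
    # those maps per (col1, col2) group and can read whatever values it needs
    def sum_from_maps(maps):
        return sum(v for m in maps for v in m.values())

    (
      df
      .withColumn("kv", F.create_map(F.col("col2"), F.col("col3").cast("int")))
      .groupBy(F.col("col1"), F.col("col2"))
      .agg(F.collect_list(F.col("kv")).alias("kv"))
      .withColumn("col3sum", F.udf(sum_from_maps, IntegerType())(F.col("kv")))
    ).show()

The result should match the col3sum column above; whether you collect a list or a map per group just depends on what shape of input your processing function expects.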

Upvotes: 1
