ChaseHardin

Reputation: 2269

Parallelize Spark Collect Function

I've noticed that Spark's collect function is extremely slow on large datasets, so I'm trying to speed it up using parallelize.

My main function creates the Spark session and passes it to get_data:

from pyspark.sql import SparkSession

def main():
    spark = SparkSession.builder.appName('app_name').getOrCreate()
    return get_data(spark)

Here is where I try to parallelize the collect call:

def get_data(spark):
    df = all_data(spark)
    data = spark.sparkContext.parallelize(df.select('my_column').distinct().collect())
    return map(lambda row: row['my_column'], data)

This doesn't work; it raises the following error:

TypeError: 'RDD' object is not iterable

Does anyone have ideas on how to parallelize get_data or otherwise improve its performance?

Upvotes: 0

Views: 1050

Answers (1)

thePurplePython

Reputation: 2767

Here are examples of static and dynamic approaches that use a broadcast variable (a read-only variable cached in each executor's memory, which avoids shipping a copy of the list from the driver with every distributed task) to supply the distinct values of a column. Also, if you don't pass a hard-coded list of values to pivot, Spark triggers an extra job (a wide transformation with a shuffle) to compute the distinct values of that column.

Disclaimer: there may be a better-performing alternative to the dynamic approach.

print(spark.version)
2.4.3

import pyspark.sql.functions as F

# sample data
rawData = [(1, "a"),
           (1, "b"),
           (1, "c"),
           (2, "a"),
           (2, "b"),
           (2, "c"),
           (3, "a"),
           (3, "b"),
           (3, "c")]

df = spark.createDataFrame(rawData).toDF("id","value")

# static list example
l = ["a", "b", "c"]
l = spark.sparkContext.broadcast(l)

pivot_static_df = df\
  .groupby("id")\
  .pivot("value", l.value)\
  .agg(F.expr("first(value)"))

pivot_static_df.show()
+---+---+---+---+
| id|  a|  b|  c|
+---+---+---+---+
|  1|  a|  b|  c|
|  3|  a|  b|  c|
|  2|  a|  b|  c|
+---+---+---+---+

# dynamic list example
v = df.select("value").distinct().rdd.flatMap(lambda x: x).collect()
v = spark.sparkContext.broadcast(v)

print(v.value)

# pivot on the broadcast dynamic list; sorting keeps the column order deterministic
pivot_dynamic_df = df\
  .groupby("id")\
  .pivot("value", sorted(v.value))\
  .agg(F.expr("first(value)"))

pivot_dynamic_df.show()
['c', 'b', 'a']
+---+---+---+---+
| id|  a|  b|  c|
+---+---+---+---+
|  1|  a|  b|  c|
|  3|  a|  b|  c|
|  2|  a|  b|  c|
+---+---+---+---+

Upvotes: 2
