Reputation: 2269
I've noticed that Spark's collect function is extremely slow on large datasets, so I'm trying to work around it using parallelize.
My main method creates the Spark session and passes it to the get_data function.
def main():
    spark = SparkSession.builder.appName('app_name').getOrCreate()
    return get_data(spark)
Here is where I try to parallelize the collect call:
def get_data(spark):
    df = all_data(spark)
    data = spark.sparkContext.parallelize(df.select('my_column').distinct().collect())
    return map(lambda row: row['my_column'], data)
This does not work and fails with this error:
TypeError: 'RDD' object is not iterable
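From what I can tell, the builtin map expects a regular Python iterable and an RDD isn't one; the RDD has its own map method instead. Here's a minimal sketch of what I think the RDD-side equivalent would look like, assuming the same column name, though it still round-trips the distinct values through the driver via collect:
def get_data(spark):
    df = all_data(spark)
    # distribute the collected distinct values back out as an RDD
    data = spark.sparkContext.parallelize(df.select('my_column').distinct().collect())
    # RDD.map runs on the executors; collect() pulls the results back to the driver
    return data.map(lambda row: row['my_column']).collect()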
Does anyone have any ideas on how to parallelize or otherwise improve the performance of the get_data function?
Upvotes: 0
Views: 1050
Reputation: 2767
Here are examples of static and dynamic approaches using a broadcast variable (a read-only variable cached in each executor's memory, which avoids shipping a copy of the list from the driver with every distributed task) to retrieve the distinct values of a column. Also, if you don't provide hard-coded values during the pivot, it triggers an extra job (a wide transformation with a shuffle) to get the distinct values for that column.
Disclaimer: there may be a better-performing alternative out there for the dynamic approach.
print(spark.version)
2.4.3
import pyspark.sql.functions as F
# sample data
rawData = [(1, "a"),
           (1, "b"),
           (1, "c"),
           (2, "a"),
           (2, "b"),
           (2, "c"),
           (3, "a"),
           (3, "b"),
           (3, "c")]
df = spark.createDataFrame(rawData).toDF("id", "value")
# static list example
l = ["a", "b", "c"]
l = spark.sparkContext.broadcast(l)
pivot_static_df = df\
    .groupby("id")\
    .pivot("value", l.value)\
    .agg(F.expr("first(value)"))
pivot_static_df.show()
+---+---+---+---+
| id| a| b| c|
+---+---+---+---+
| 1| a| b| c|
| 3| a| b| c|
| 2| a| b| c|
+---+---+---+---+
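For comparison, here is a sketch of the same pivot without the hard-coded list; this is the variant that triggers the extra job mentioned above to compute the distinct values:
# no value list: Spark first runs an extra job to find the
# distinct values of "value" before pivoting
pivot_implicit_df = df\
    .groupby("id")\
    .pivot("value")\
    .agg(F.expr("first(value)"))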
# dynamic list example
v = df.select("value").distinct().rdd.flatMap(lambda x: x).collect()
v = spark.sparkContext.broadcast(v)
print(v.value)
['c', 'b', 'a']
pivot_dynamic_df = df\
    .groupby("id")\
    .pivot("value", v.value)\
    .agg(F.expr("first(value)"))
pivot_dynamic_df.show()
+---+---+---+---+
| id| a| b| c|
+---+---+---+---+
| 1| a| b| c|
| 3| a| b| c|
| 2| a| b| c|
+---+---+---+---+
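As an alternative for building the dynamic list, here is a sketch that gathers the distinct values in a single aggregation instead of the rdd round-trip (assuming the same df); I haven't benchmarked it against the flatMap version:
# collect_set gathers the distinct values of the column in one
# aggregation; first()[0] pulls the resulting list to the driver
v = df.agg(F.collect_set("value")).first()[0]
v = spark.sparkContext.broadcast(v)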
Upvotes: 2