Lazloo Xp

Reputation: 1008

PySpark UDF issues when referencing outside of function

I am facing an issue where I get the error

TypeError: cannot pickle '_thread.RLock' object

when I try to apply the following code:

from pyspark.sql.types import *
from pyspark.sql.functions import *

data_1 = [('James','Smith','M',30),('Anna','Rose','F',41),
  ('Robert','Williams','M',62), 
]
data_2 = [('Junior','Smith','M',15),('Helga','Rose','F',33),
  ('Mike','Williams','M',77), 
]
columns = ["firstname","lastname","gender","age"]
df_1 = spark.createDataFrame(data=data_1, schema = columns)
df_2 = spark.createDataFrame(data=data_2, schema = columns)

def find_n_people_with_higher_age(x):
  return df_2.filter(df_2['age']>=x).count()

find_n_people_with_higher_age_udf = udf(find_n_people_with_higher_age, IntegerType())
df_1.select(find_n_people_with_higher_age_udf(col('age')))

Upvotes: 1

Views: 1500

Answers (2)

Matt Andruff

Reputation: 5155

I added some additional data and used a window to build a rolling count:

from pyspark.sql import Window
from pyspark.sql.types import *
from pyspark.sql.functions import *  # note: sum below is pyspark's sum, not the builtin

data_1 = [('James','Smith','M',30),('Anna','Rose','F',41),
  ('Robert','Williams','M',62),
]
# add more data to make it more interesting.
data_2 = [('Junior','Smith','M',15),('Helga','Rose','F',33),('Gia','Rose','F',34),
  ('Mike','Williams','M',77), ('John','Williams','M',77), ('Bill','Williams','F',79),
]
columns = ["firstname","lastname","gender","age"]
df_1 = spark.createDataFrame(data=data_1, schema = columns)
df_2 = spark.createDataFrame(data=data_2, schema = columns)

# dataframe to help fill in missing ages (1 to 109, each with a count of 0)
ref = spark.range( 1, 110, 1).toDF("age").withColumn("count", lit(0)).withColumn("rolling_Count", lit(0))

countAges = df_2.groupby("age").count()
# this actually gives you the short list of ages (only the ages present in df_2)
rollingCounts = countAges.withColumn("rolling_Count", sum(col("count")).over(Window.partitionBy().orderBy(col("age").desc())))
# fill in missing ages and remove duplicates (union matches columns by position)
filled = rollingCounts.union(ref).groupBy("age").agg(sum("count").alias("count"))
# add a rolling count across all ages, from oldest to youngest
allAgeCounts = filled.withColumn("rolling_Count", sum(col("count")).over(Window.partitionBy().orderBy(col("age").desc())))
# do an inner join because we've filled in all ages.
df_1.join(allAgeCounts, df_1.age == allAgeCounts.age, "inner").show()
+---------+--------+------+---+---+-----+-------------+                         
|firstname|lastname|gender|age|age|count|rolling_Count|
+---------+--------+------+---+---+-----+-------------+
|     Anna|    Rose|     F| 41| 41|    0|            3|
|   Robert|Williams|     M| 62| 62|    0|            3|
|    James|   Smith|     M| 30| 30|    0|            5|
+---------+--------+------+---+---+-----+-------------+

I wouldn't normally want to use a window over an entire table, but here the window only iterates over the filled-in ages, which is fewer than 110 rows, so this is reasonable.
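To make the rolling count easier to follow, here is a minimal plain-Python sketch of the same idea: count people per age, walk the ages from oldest to youngest, and keep a running total. It is only an illustration of the logic and not the Spark code itself; the names `ages_1`, `ages_2`, `rolling` and `result` are mine, and the ages are copied from `data_1` and `data_2` above.

```python
from collections import Counter
from itertools import accumulate

# Ages copied from data_2 and data_1 in the answer above.
ages_2 = [15, 33, 34, 77, 77, 79]
ages_1 = [30, 41, 62]

# Equivalent of groupby("age").count()
counts = Counter(ages_2)

# Ages 109 down to 1, mirroring spark.range(1, 110) ordered by age descending.
all_ages = range(109, 0, -1)

# Running total from the oldest age down: rolling[a] = number of people aged >= a.
rolling = dict(zip(all_ages, accumulate(counts.get(a, 0) for a in all_ages)))

# Equivalent of the inner join against df_1.
result = {a: rolling[a] for a in ages_1}
print(result)  # {30: 5, 41: 3, 62: 3}
```

The values match the `rolling_Count` column in the table above.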

Upvotes: 0

Matt Andruff

Reputation: 5155

Here's a good article on Python UDFs.

I used it as a reference because I suspected that you were running into a serialization issue. I'm quoting the entire paragraph for context, but really it's the serialization that's the issue.

Performance Considerations

It’s important to understand the performance implications of Apache Spark’s UDF features. Python UDFs for example (such as our CTOF function) result in data being serialized between the executor JVM and the Python interpreter running the UDF logic – this significantly reduces performance as compared to UDF implementations in Java or Scala. Potential solutions to alleviate this serialization bottleneck include:

If you consider what you are asking, you may see why this isn't working. You are asking for all the data in your dataframe (df_2) to be shipped (serialized) to every executor, which then serializes it again and ships it to Python to be interpreted. Dataframes don't serialize, which is why you get the pickle error. Even if they did, you would be sending an entire dataframe to each executor. With your sample data that wouldn't matter, but with trillions of records it would blow up the JVM.
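The error can be reproduced without Spark. This is a minimal sketch using only the standard library; `HoldsLock` is a hypothetical stand-in for any object that keeps a lock internally. My assumption, not something the article states, is that the dataframe captured by the UDF reaches such an object through its session or gateway.

```python
import pickle
import threading


class HoldsLock:
    """Hypothetical stand-in for an object that keeps a lock internally."""

    def __init__(self):
        self._lock = threading.RLock()


try:
    # Pickling walks the instance attributes and fails on the lock.
    pickle.dumps(HoldsLock())
    message = None
except TypeError as exc:
    message = str(exc)

# The message names the RLock that could not be pickled.
print(message)
```

Anything a UDF references from the enclosing scope has to survive this kind of serialization, so plain values such as numbers, strings and lists are fine, while objects tied to the driver are not.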

What you're asking is doable; I just need to figure out how to do it. Likely a window or a group by would do the trick.
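If a UDF is still wanted, one option is to have it close over a plain Python list instead of a dataframe. The sketch below shows only the function body in plain Python; `make_counter` and `count_at_least` are hypothetical names, and the ages are the three from `data_2` in the question. It assumes the ages of `df_2` are few enough to collect onto the driver.

```python
from bisect import bisect_left


def make_counter(ages):
    """Return a function that counts how many of `ages` are >= x."""
    sorted_ages = sorted(ages)

    def count_at_least(x):
        # bisect_left gives the index of the first age >= x.
        return len(sorted_ages) - bisect_left(sorted_ages, x)

    return count_at_least


# Ages from data_2 in the question.
count_at_least = make_counter([15, 33, 77])
print([count_at_least(a) for a in (30, 41, 62)])  # [2, 1, 1]
```

In Spark, the list could come from something like `[row["age"] for row in df_2.select("age").collect()]`, and the function could then be wrapped with `udf(count_at_least, IntegerType())` and applied to `col("age")`. That keeps the closure down to a list of integers, but it only makes sense while that list stays small; for large data the window or join approach is the safer route.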

Upvotes: 1
