Reputation: 2253
I am just studying PySpark and got confused about the following code:
df.groupBy(['Category','Register']).agg({'NetValue':'sum',
    'Units':'mean'}).show(5,truncate=False)
df.groupBy(['Category','Register']).agg({'NetValue':'sum',
    'Units': lambda x: pd.Series(x).nunique()}).show(5,truncate=False)
The first statement works, but the second one fails with the following error:
AttributeError: 'function' object has no attribute '_get_object_id'
It looks like I did not use the lambda function correctly, but this is how I would use a lambda in plain Python, where it works fine.
Could anyone help me here?
Upvotes: 0
Views: 4444
Reputation: 61
If you are okay with the performance of RDD primitives running pure Python functions, the following code gives the desired result. You can modify the logic in _map
to suit your specific needs. I made some assumptions about what your data schema might look like.
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, LongType

schema = StructType([
    StructField('Category', StringType(), True),
    StructField('Register', LongType(), True),
    StructField('NetValue', LongType(), True),
    StructField('Units', LongType(), True)
])
test_records = [
    {'Category': 'foo', 'Register': 1, 'NetValue': 1, 'Units': 1},
    {'Category': 'foo', 'Register': 1, 'NetValue': 2, 'Units': 2},
    {'Category': 'foo', 'Register': 2, 'NetValue': 3, 'Units': 3},
    {'Category': 'foo', 'Register': 2, 'NetValue': 4, 'Units': 4},
    {'Category': 'bar', 'Register': 1, 'NetValue': 5, 'Units': 5},
    {'Category': 'bar', 'Register': 1, 'NetValue': 6, 'Units': 6},
    {'Category': 'bar', 'Register': 2, 'NetValue': 7, 'Units': 7},
    {'Category': 'bar', 'Register': 2, 'NetValue': 8, 'Units': 8}
]
spark = SparkSession.builder.getOrCreate()
dataframe = spark.createDataFrame(test_records, schema)
def _map(group):
    # Each group is a ((Category, Register), records) pair produced by groupBy below.
    (category, register), records = group
    net_value_sum = 0
    uniques = set()
    for record in records:
        net_value_sum += record['NetValue']
        uniques.add(record['Units'])
    # Return the sum of NetValue and the number of distinct Units per group.
    return category, register, net_value_sum, len(uniques)
new_dataframe = spark.createDataFrame(
    dataframe.rdd.groupBy(lambda x: (x['Category'], x['Register'])).map(_map),
    schema
)
new_dataframe.show()
Result:
+--------+--------+--------+-----+
|Category|Register|NetValue|Units|
+--------+--------+--------+-----+
| bar| 2| 15| 2|
| foo| 1| 3| 2|
| foo| 2| 7| 2|
| bar| 1| 11| 2|
+--------+--------+--------+-----+
If you need better performance or want to stay within the pyspark.sql framework, see this related question and its linked questions:
Custom aggregation on PySpark dataframes
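As a side note, the dict form of agg only accepts the names of built-in aggregate functions as strings (e.g. 'sum', 'mean'), which is why passing a lambda raises that error. A minimal sketch of the same aggregation with built-in pyspark.sql functions, reusing the dataframe defined above (illustrative only, not part of the RDD approach shown here), would be:
# Sketch: built-in aggregates instead of the pure-Python RDD aggregation above.
from pyspark.sql import functions as F

dataframe.groupBy('Category', 'Register').agg(
    F.sum('NetValue').alias('NetValue'),
    F.countDistinct('Units').alias('Units')
).show()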
Upvotes: 2