b-ryce

PySpark: calculate a field on a grouped table

I've got a data frame that looks like this:

+-------+-----+-------------+------------+
|startID|endID|trip_distance|total_amount|
+-------+-----+-------------+------------+
|      1|    3|            5|          12|
|      1|    3|            0|           4|
+-------+-----+-------------+------------+

I need to create a new table that groups the trips by their start and end IDs and then computes the average trip rate for each group.

The trip rate is computed from all the trips that share the same start and end IDs. In my case, startID 1 and endID 3 had a total of 2 trips; across those 2 trips the average trip_distance was 2.5 and the average total_amount was 8, so the trip rate should be 8 / 2.5 = 3.2.

So the end result should look like this:

+-------+-----+-----+----------+
|startID|endID|count| trip_rate|
+-------+-----+-----+----------+
|      1|    3|    2|       3.2|
+-------+-----+-----+----------+
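
Checking that arithmetic in plain Python, using just the two sample rows:

# plain-Python check of the expected numbers for startID 1 / endID 3
distances = [5, 0]
amounts = [12, 4]
avg_distance = sum(distances) / len(distances)  # 2.5
avg_amount = sum(amounts) / len(amounts)        # 8.0
trip_rate = avg_amount / avg_distance           # 8 / 2.5 = 3.2
print(len(distances), trip_rate)                # 2 3.2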

Here is what I'm trying to do:

from pyspark.shell import spark
from pyspark.sql.functions import avg

df = spark.createDataFrame(
    [
        (1, 3, 5, 12),
        (1, 3, 0, 4)
    ],
    ['startID', 'endID', 'trip_distance', 'total_amount']  # column names
)
df.show()
grouped_table = df.groupBy('startID', 'endID').count().alias('count')
grouped_table.show()

# this is the line that raises the AnalysisException below
grouped_table = df.withColumn('trip_rate', (avg('total_amount') / avg('trip_distance')))
grouped_table.show()

But I'm getting the following error:

pyspark.sql.utils.AnalysisException: "grouping expressions sequence is empty, and '`startID`' is not an aggregate function. Wrap '((avg(`total_amount`) / avg(`trip_distance`)) AS `trip_rate`)' in windowing function(s) or wrap '`startID`' in first() (or first_value) if you don't care which value you get.;;
Aggregate [startID#0L, endID#1L, trip_distance#2L, total_amount#3L, (avg(total_amount#3L) / avg(trip_distance#2L)) AS trip_rate#44]
+- LogicalRDD [startID#0L, endID#1L, trip_distance#2L, total_amount#3L], false"

I tried wrapping the calculation in an AS function, but I kept getting syntax errors.
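
If I'm reading the error's hint about "windowing function(s)" correctly, it presumably means something like this (just a sketch of my reading of the hint, not verified, and it keeps one row per trip rather than one row per group):

from pyspark.sql import functions as F
from pyspark.sql.window import Window

# compute the aggregates over a window partitioned by the grouping keys,
# as the error message suggests
w = Window.partitionBy('startID', 'endID')
df.withColumn(
    'trip_rate',
    F.avg('total_amount').over(w) / F.avg('trip_distance').over(w)
).show()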


Answers (1)

Cena

Group by, then sum and divide. Both count() and sum() can be used inside agg():

from pyspark.sql import functions as F

df.groupBy('startID', 'endID').agg(
    F.count(F.lit(1)).alias('count'),
    (F.sum('total_amount') / F.sum('trip_distance')).alias('trip_rate')
).show()
