jakko

Reputation: 835

pyspark RDD - add list of tuples at some index

I have an RDD that looks like this:

[(3,6,7), (2,5,7), (4,3,7)]

I would like to get the average of the first elements, as well as the sum of the second elements and the sum of the third elements. The output should look like this:

(3,14,21)

Is it possible to do this using pyspark?

Upvotes: 0

Views: 2202

Answers (3)

Manu Gupta

Reputation: 830

Yes, this is possible in PySpark. You can convert the RDD to a DataFrame and use its aggregate functions to compute all three values in one pass:

# import only what is needed; note that this sum shadows Python's built-in sum
from pyspark.sql.functions import avg, col, sum

my_rdd=sc.parallelize([(3,6,7), (2,5,7), (4,3,7)])
df = sqlContext.createDataFrame(my_rdd,("fld1", "fld2", "fld3"))
df.groupBy().agg(avg(col("fld1")),sum(col("fld2")),sum(col("fld3"))).rdd.collect()

Alternatively, register the DataFrame as a temporary table and run the same aggregation in SQL:

df.registerTempTable('mytable')
df1=sqlContext.sql("select avg(fld1), sum(fld2), sum(fld3) from mytable")
df1.rdd.collect()

Thanks, Manu

Upvotes: -1

Alper t. Turker

Reputation: 35249

With an RDD, you can map each tuple to a NumPy array and call stats():

import numpy as np

# stats() returns a StatCounter; with NumPy arrays its statistics are element-wise
stats = sc.parallelize([(3,6,7), (2,5,7), (4,3,7)]).map(np.array).stats()
stats.mean()[0], stats.sum()[1], stats.sum()[2]

# (3.0, 14.0, 21.0)
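
If NumPy is not available on the workers, the same result can be sketched with a plain map and reduce. The helper names below (to_partial, merge, finalize) are hypothetical and not part of PySpark. The example runs on a local list with functools.reduce so it can be tried without a cluster; assuming an RDD named rdd with the same rows, the equivalent call would be finalize(rdd.map(to_partial).reduce(merge)).

```python
from functools import reduce

def to_partial(row):
    # Turn one row into a partial result: (sum1, sum2, sum3, row count)
    a, b, c = row
    return (a, b, c, 1)

def merge(p, q):
    # Add two partial results element-wise; associative and commutative,
    # which is what a distributed reduce requires
    return tuple(x + y for x, y in zip(p, q))

def finalize(p):
    # Average of the first elements, plain sums of the second and third
    s1, s2, s3, n = p
    return (s1 / n, s2, s3)

rows = [(3, 6, 7), (2, 5, 7), (4, 3, 7)]
result = finalize(reduce(merge, map(to_partial, rows)))
print(result)
# (3.0, 14, 21)
```

Carrying the row count alongside the sums means the average needs no second pass over the data.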

Upvotes: 2

desertnaut

Reputation: 60369

You can convert the RDD to a DataFrame and use groupBy:

spark.version
# u'2.2.0'

# toy data
rdd = sc.parallelize([(3,6,7), (2,5,7), (4,3,7)])
df = spark.createDataFrame(rdd,("x1", "x2", "x3"))

(df.groupBy().avg("x1").collect()[0][0],
 df.groupBy().sum('x2').collect()[0][0],
 df.groupBy().sum('x3').collect()[0][0])
# (3.0, 14, 21)

Or you could combine the two sum operations into a single call:

ave = df.groupBy().avg("x1").collect()
sums = df.groupBy().sum("x2","x3").collect()
(ave[0][0], sums[0][0], sums[0][1])
# (3.0, 14, 21)

UPDATE (after comment): user8371915's proposal leads to an even more elegant solution:

from pyspark.sql.functions import avg, sum

num_cols = len(df.columns) # number of columns
res = df.groupBy().agg(avg("x1"), sum("x2"), sum("x3")).first()
[res[i] for i in range(num_cols)]
# [3.0, 14, 21]

Upvotes: 2
