Rituraj Ramteke

Reputation: 35

What is the best way to avoid using the collect() function in PySpark code? What are the best ways to write optimized PySpark code?

Hi, I am trying to write PySpark code that creates a list from a DataFrame column. I use the collect() function, but I am not sure it is the right way to get a filtered list of values from a column. Since collect() brings all of the data to the driver node, it would be a bad option for a large DataFrame, say 10 GB.
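In this question, collect() is only used to build a list of child values so that `isin` can keep the parents that never appear as a child. That is a set difference, which Spark can compute as a join without moving the column to the driver. In PySpark it might be written as `mydata.join(mydata.select(col('child').alias('parent')), on='parent', how='left_anti')`, or, if `left_anti` is not available in your Spark version, as a left outer join followed by a filter on the null right-hand column (both untested here). The plain-Python sketch below only illustrates the anti-join semantics on the question's sample rows; `top_level_rows` is a hypothetical helper, not Spark code.

```python
# Input rows from the question, as (parent, child) pairs.
ROWS = [("p1", "c1"), ("p11", "p1"), ("p111", "p11"),
        ("p2", "c2"), ("p22", "p2"), ("p222", "p22"), ("p2222", "p222")]


def top_level_rows(rows):
    """Return the rows whose parent never appears as a child.

    This mirrors a left anti join of the table with its own child column
    (parent == child): a row survives only if it has no match.
    """
    children = {child for _, child in rows}
    return [(parent, child) for parent, child in rows if parent not in children]


print(top_level_rows(ROWS))  # [('p111', 'p11'), ('p2222', 'p222')]
```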

Below is my input DataFrame:

[Row(parent=u'p1', child=u'c1'), Row(parent=u'p11', child=u'p1'),
Row(parent=u'p111', child=u'p11'), Row(parent=u'p2', child=u'c2'), 
Row(parent=u'p22', child=u'p2'), Row(parent=u'p222', child=u'p22'),
Row(parent=u'p2222', child=u'p222')]

I want to produce the following output DataFrame:

[Row(parent=u'p2222', child1=u'p222', child2=u'p22', child3=u'p2',
child4=u'c2'), Row(parent=u'p111', child1=u'p11', child2=u'p1',
child3=u'c1', child4=None)]

Below is the working code I have written, but I am not sure it is optimized the way Spark code should be:

from pyspark.sql.types import StructType, StructField, StringType

customSchema = StructType([StructField('parent', StringType(), True),
                           StructField('child', StringType(), True)])

# loading data from a CSV file and creating a dataframe
mydata = sqlContext.load(source='com.databricks.spark.csv',
                         path='/FileStore/tables/34v0qouq1507635707462/parent_child_input.csv',
                         header=True, schema=customSchema)
mydata.registerTempTable('mydata')

# creating a list of values of column "child" from the dataframe "mydata"
childlist = [x[1] for x in mydata.collect()]

# keeping only rows whose "parent" value is not present in childlist (the top of each chain)
level1 = mydata.selectExpr('parent', 'child as child1').where(~mydata.parent.isin(childlist))
i = 1

# recursively joins one more level until the newest column has no non-null values
def getChild(level1, i):
    cname = 'child' + str(i)
    # distinct non-null values of the newest column (collect() runs once per level)
    tmp = list(set(x[i] for x in level1.collect() if x[i]))
    level1.registerTempTable('level1')
    if len(tmp) > 0:
        i += 1
        ccname = 'child' + str(i)
        querystr = ('select level1.*, mydata.child as ' + ccname +
                    ' from level1 left outer join mydata on level1.' + cname + ' = mydata.parent')
        level1 = sqlContext.sql(querystr)
        level1 = getChild(level1, i)
    return level1

level1 = getChild(level1, i)
# child5 is the all-null column added by the last join for this data
level1.drop('child5').show()
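In `getChild`, collect() runs once per level only to decide whether another join is needed. That decision just needs to know whether any non-null value exists in the newest column, which in PySpark could be asked with something like `level1.where(col(cname).isNotNull()).first() is not None`, fetching at most one row instead of the whole DataFrame (untested here). The sketch below models the same level-by-level left outer join as a plain-Python loop (iterative rather than recursive) so the logic can be checked against the expected output. `expand_levels` is a hypothetical name; it assumes each parent has at most one child, children are non-null, and the hierarchy has no cycles.

```python
# Input rows from the question, as (parent, child) pairs.
ROWS = [("p1", "c1"), ("p11", "p1"), ("p111", "p11"),
        ("p2", "c2"), ("p22", "p2"), ("p222", "p22"), ("p2222", "p222")]


def expand_levels(rows):
    """Emulate the repeated left outer join, adding one child column per pass."""
    child_of = dict(rows)  # lookup table: parent -> child
    children = set(child_of.values())
    # Level 1: rows whose parent is never a child (the top of each chain).
    level = [[parent, child] for parent, child in rows if parent not in children]
    # Keep joining while the newest column still holds a non-null value.
    # In Spark this check needs at most one row on the driver, not a full collect().
    while any(row[-1] is not None for row in level):
        # child_of.get returns None where the join finds no match (left outer join).
        level = [row + [child_of.get(row[-1])] for row in level]
    # The last pass added an all-null column (the question drops it as 'child5').
    return [tuple(row[:-1]) for row in level]


for row in expand_levels(ROWS):
    print(row)
# ('p111', 'p11', 'p1', 'c1', None)
# ('p2222', 'p222', 'p22', 'p2', 'c2')
```

Each pass corresponds to one join in the Spark version, so the number of jobs grows with the depth of the hierarchy, not with the number of rows.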

Upvotes: 2

Views: 2251

Answers (1)

Suresh

Reputation: 5870

If your parent and child values contain integers, as in your input data, we can extract the first digit of the parent and group by it. This is how I tried it; hope it helps.

>>> df.show()
+-----+------+
|child|parent|
+-----+------+
|   c1|    p1|
|   p1|   p11|
|  p11|  p111|
|   c2|    p2|
|   p2|   p22|
|  p22|  p222|
| p222| p2222|
+-----+------+

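>>> import pyspark.sql.functions as F
>>> from pyspark.sql.types import ArrayType, StringType, StructType, StructField
>>> udf = F.udf(lambda x, y: (x, y), ArrayType(StringType()))
>>> udf1 = F.udf(lambda x: tuple(filter(str.isdigit, x))[0], StringType())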

>>> df1 = df.select("*",udf1('parent').alias('group'),udf('parent','child').alias('set'))
>>> df1.show()
+-----+------+-----+-------------+
|child|parent|group|          set|
+-----+------+-----+-------------+
|   c1|    p1|    1|     [p1, c1]|
|   p1|   p11|    1|    [p11, p1]|
|  p11|  p111|    1|  [p111, p11]|
|   c2|    p2|    2|     [p2, c2]|
|   p2|   p22|    2|    [p22, p2]|
|  p22|  p222|    2|  [p222, p22]|
| p222| p2222|    2|[p2222, p222]|
+-----+------+-----+-------------+
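The `group` column above comes from `udf1`, which keeps only the first digit character of `parent`. The plain-Python copy of that expression below (wrapped in a hypothetical `group_key` function) shows what key it produces and where the naming assumption stops holding: two unrelated chains whose names start with the same digit would be merged into one group.

```python
def group_key(name):
    """Same expression as the answer's udf1: the first digit character in name."""
    return tuple(filter(str.isdigit, name))[0]


print(group_key("p2222"))  # 2
# Caveat: only the first digit is used, so 'p12' lands in the same group as 'p1',
# and a name without any digit raises IndexError.
print(group_key("p12") == group_key("p1"))  # True
```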

>>> udf2 = F.udf(lambda x :sorted(set(sum(x,[])),reverse=True),ArrayType(StringType()))
>>> df2 = df1.groupby('group').agg(udf2(F.collect_set('set')).alias('column'))
>>> df2.show(truncate=False)
+-----+--------------------------+
|group|column                    |
+-----+--------------------------+
|1    |[p111, p11, p1, c1]       |
|2    |[p2222, p222, p22, p2, c2]|
+-----+--------------------------+
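`udf2` flattens the collected pairs, removes duplicates, and sorts in descending string order. Running the same expression in plain Python (as a hypothetical `merge_pairs` helper) shows that the parent-to-child order comes from string comparison, which works here only because each parent's name extends its child's name and the leaves start with `c`. Note also that `sum(pairs, [])` copies the list on every addition, so it gets slow for very large groups.

```python
def merge_pairs(pairs):
    """Same expression as the answer's udf2: flatten, dedupe, sort descending."""
    return sorted(set(sum(pairs, [])), reverse=True)


group1 = [["p1", "c1"], ["p11", "p1"], ["p111", "p11"]]
print(merge_pairs(group1))  # ['p111', 'p11', 'p1', 'c1']

# The order ignores the parent/child links: a leaf named 'x1' sorts above its parent 'p1'.
print(merge_pairs([["p1", "x1"]]))  # ['x1', 'p1']
```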

>>> maxval = df2.select(F.max(F.size('column'))).first()[0]
>>> # the schema must declare exactly maxval fields (5 for this data)
>>> schema = StructType([StructField("parent", StringType(), True),
...                      StructField("child1", StringType(), True),
...                      StructField("child2", StringType(), True),
...                      StructField("child3", StringType(), True),
...                      StructField("child4", StringType(), True)])
>>> udf3 = F.udf(lambda x: x if len(x) == maxval else x + [None] * (maxval - len(x)), schema)

>>> df2.select("*",udf3('column').alias('merged')).select("merged.*").show()
+------+------+------+------+------+
|parent|child1|child2|child3|child4|
+------+------+------+------+------+
|  p111|   p11|    p1|    c1|  null|
| p2222|  p222|   p22|    p2|    c2|
+------+------+------+------+------+
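`udf3` relies on the hard-coded `schema` having exactly `maxval` fields. One way to keep the two in step is to derive the field names from `maxval`; in PySpark the resulting list could then feed `StructType([StructField(n, StringType(), True) for n in names])`. The helpers `field_names` and `to_record` below are hypothetical and written in plain Python to show the naming and padding logic. Padding with `[None] * (width - len(chain))` yields an empty list when the chain is already full, so the `len(x) == maxval` branch in `udf3` is not strictly needed.

```python
def field_names(width):
    """Column names for a chain of the given width: parent, child1, child2, ..."""
    return ["parent"] + ["child%d" % i for i in range(1, width)]


def to_record(chain, width):
    """Pad a chain with None up to width and pair it with the field names (like udf3)."""
    padded = list(chain) + [None] * (width - len(chain))
    return dict(zip(field_names(width), padded))


print(to_record(["p111", "p11", "p1", "c1"], 5))
# {'parent': 'p111', 'child1': 'p11', 'child2': 'p1', 'child3': 'c1', 'child4': None}
```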

Upvotes: 1
