Reputation: 8628
I have this data in a parquet file:
ID=222, ORDER=None, PARENT=101
ID=111, ORDER=None, PARENT=001
ID=333, ORDER=None, PARENT=111
ID=444, ORDER=None, PARENT=111
ID=101, ORDER=None, PARENT=0
ID=001, ORDER=None, PARENT=0
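For reproducibility, here is a minimal sketch of writing an equivalent parquet file; the string-typed columns and the /tmp/tree.parquet path are assumptions for this sketch, not part of the question:

from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.sql.types import StructType, StructField, StringType

sc = SparkContext(appName='tree-example')
sqlContext = SQLContext(sc)

# An explicit schema is needed because ORDER is entirely None and cannot be inferred
schema = StructType([StructField('ID', StringType()),
                     StructField('ORDER', StringType()),
                     StructField('PARENT', StringType())])
data = [('222', None, '101'), ('111', None, '001'),
        ('333', None, '111'), ('444', None, '111'),
        ('101', None, '0'), ('001', None, '0')]
# '/tmp/tree.parquet' is only a hypothetical location for this sketch
sqlContext.createDataFrame(data, schema).write.parquet('/tmp/tree.parquet')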
I want to create a map like this: id -> (parent, level, order). In the above example there are three levels: 0, 1 and 2. However, I don't want the number of levels to be hardcoded.
The output should be the following:
222 -> 101,1,None
101 -> 101,0,None
111 -> 001,1,None
001 -> 001,0,None
333 -> 111,2,None
444 -> 111,2,None
Level 0 means a root level without a parent.
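Expressed as a Python dict (assuming string-typed IDs, matching the answer below), the desired result would look roughly like this:

expected = {
    '222': ('101', 1, None),
    '101': ('101', 0, None),
    '111': ('001', 1, None),
    '001': ('001', 0, None),
    '333': ('111', 2, None),
    '444': ('111', 2, None),
}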
I have written the following function (see below), but I suspect there is an easier way to create the map, perhaps even without collecting the data from RDDs into a dictionary. It seems to me that I am not using the power of distributed computing when I apply rdd.collect().
def get_map(sqlContext, pathtoparquetfile):
    # Collect every row to the driver and index the rows by ID
    f = sqlContext.read.parquet(pathtoparquetfile).rdd.collect()
    f = dict([(r.ID, r.asDict()) for r in f])
    # Fix root vertices without parent pointers
    for (k, t) in f.iteritems():
        p = t['PARENT']
        if p == k or not f.has_key(p):
            t['PARENT'] = 0
    parent = {r['ID']: r['PARENT'] for r in f.values()}
    level = {}
    # Walk up the parent chain recursively, memoising each vertex's level
    def find_level(id):
        if not level.has_key(id):
            if not f.has_key(parent[id]): parent[id] = 0
            level[id] = 0 if parent[id] == 0 else find_level(parent[id]) + 1
        return level[id]
    for (k, t) in f.iteritems():
        t.update(level=find_level(k))
    # Collect the children of every vertex so siblings can be numbered
    for (k, t) in f.iteritems():
        t['children'] = []
    for (k, t) in f.iteritems():
        p = t['PARENT']
        if p != 0 and f.has_key(p):
            f[p]['children'].append(k)
    # Sort siblings and assign each one its position as the order
    for (k, t) in f.iteritems():
        t['children'].sort(key=lambda c: (f[c]['ORDER'], c))
        pos = 0
        for c in t['children']:
            f[c]['order'] = pos
            pos = pos + 1
    for (k, t) in f.iteritems():
        if not t.has_key('order'): t['order'] = 0
    # Root vertices map to themselves; every other vertex maps to its parent
    return {k: (k if t['level'] == 0 else t['PARENT'], t['level'], t['order'])
            for (k, t) in f.iteritems()}
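For context, a minimal usage sketch of the function above; the SQLContext setup and the parquet path are hypothetical:

from pyspark import SparkContext
from pyspark.sql import SQLContext

sc = SparkContext(appName='tree-example')
sqlContext = SQLContext(sc)

# Build the id -> (parent, level, order) map from the hypothetical file location
id_map = get_map(sqlContext, '/tmp/tree.parquet')
for k, v in id_map.items():
    print('{0} -> {1}'.format(k, v))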
Upvotes: 0
Views: 677
Reputation: 4719
In general, we 'chain' the data through repeated joins to find out each row's level.
from pyspark import SparkConf
from pyspark.sql import SparkSession
from pyspark.sql import functions as f
from pyspark.sql.types import StructType, StructField, StringType

cfg = SparkConf().setAppName('s')
spark = SparkSession.builder.enableHiveSupport().config(conf=cfg).getOrCreate()
spark.sparkContext.setLogLevel('WARN')

# Join the current-level rows with their potential child rows, append the
# current level to the global result (df_result) and return the rows that
# may themselves have children
def join_again(i, x):
    global df_atom, df_result
    tmp = df_atom.join(x, on=[x['id'] == df_atom['parent_atom']], how='right').cache()
    # df.union was added in Spark 2.0; df.unionAll has been available since Spark 1.3
    df_result = df_result.unionAll(tmp.select('id', 'parent', f.lit(i), 'order'))
    # these rows may have child rows and should participate in 'join_again' next time
    res = tmp.filter(tmp['parent_atom'].isNotNull()) \
        .select(tmp['id'].alias('parent'), tmp['id_atom'].alias('id'), tmp['order'])
    tmp.unpersist()
    return res

def join_cycle(y):
    # 'n' counts the calls to join_again and is the level assigned in each pass,
    # so the levels 0, 1, 2, ... from the question are discovered without hardcoding
    n = 1
    while 1:
        if y.rdd.isEmpty():
            break
        y = join_again(n, y)
        n += 1

if __name__ == '__main__':
    df = spark.createDataFrame([('222', None, '101'), ('111', None, '001'),
                                ('333', None, '111'), ('444', None, '111'),
                                ('555', None, '444'), ('666', None, '444')],
                               schema=StructType([StructField('id', StringType()),
                                                  StructField('order', StringType()),
                                                  StructField('parent', StringType())]))
    df_atom = df.select(df['id'].alias('id_atom'), df['parent'].alias('parent_atom')).cache()
    df_result = spark.createDataFrame([], schema=StructType([StructField('id', StringType()),
                                                             StructField('parent', StringType()),
                                                             StructField('lv', StringType()),
                                                             StructField('order', StringType())]))
    # find the rows whose parent does not appear as an id; those parents are the level-0 roots
    df_init = df.join(df_atom, on=[df['parent'] == df_atom['id_atom']], how='left') \
        .filter(df_atom['id_atom'].isNull()).cache()
    # the level has to be set manually through pyspark.sql.functions.lit()
    df_result = df_result.unionAll(df_init.select('parent', 'parent', f.lit(0), 'order'))
    df = df_init.select('order', 'parent', 'id')
    df_init.unpersist()
    join_cycle(df)
    df_result.distinct().show(truncate=False)
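If the result is ultimately needed as the id -> (parent, level, order) map from the question, the (now small) df_result can be collected only at the very end; a sketch continuing from the code above:

# Collect the final result and build the map; df_result comes from the code above
result_map = {row['id']: (row['parent'], row['lv'], row['order'])
              for row in df_result.distinct().collect()}
print(result_map)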
Upvotes: 1