Dinosaurius

Reputation: 8628

Creating a Map structure from a PySpark DataFrame

I have this data in the parquet file:

ID=222, ORDER=None, PARENT=101
ID=111, ORDER=None, PARENT=001
ID=333, ORDER=None, PARENT=111
ID=444, ORDER=None, PARENT=111
ID=101, ORDER=None, PARENT=0
ID=001, ORDER=None, PARENT=0

I want to create a map of the form id -> (parent, level, order). In the example above the hierarchy has levels 0, 1 and 2, but I don't want the number of levels to be hardcoded.

The output should be the following:

222 -> 101,1,None
101 -> 101,0,None
111 -> 001,1,None
001 -> 001,0,None
333 -> 111,2,None
444 -> 111,2,None

Level 0 means a root level, i.e. a vertex without a parent.
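Expressed as a plain Python dict, the map I am after would look like this (built directly from the expected output above; the string ids are only for illustration):

expected_map = {
    '222': ('101', 1, None),
    '101': ('101', 0, None),
    '111': ('001', 1, None),
    '001': ('001', 0, None),
    '333': ('111', 2, None),
    '444': ('111', 2, None),
}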

I have written the function below, but I suspect there is a simpler way to build this map, ideally without collecting the RDD into a driver-side dictionary. As it stands, rdd.collect() throws away the benefit of distributed computing.

def get_map(sqlContext, pathtoparquetfile):
    # collect everything to the driver and index the rows by ID
    rows = sqlContext.read.parquet(pathtoparquetfile).rdd.collect()
    f = {r.ID: r.asDict() for r in rows}

    # Fix root vertices without parent pointers
    for k, t in f.items():
        p = t['PARENT']
        if p == k or p not in f:
            t['PARENT'] = 0

    parent = {r['ID']: r['PARENT'] for r in f.values()}
    level = {}

    # walk up the parent chain; roots (parent == 0) get level 0
    def find_level(id):
        if id not in level:
            if parent[id] not in f:
                parent[id] = 0
            level[id] = 0 if parent[id] == 0 else find_level(parent[id]) + 1
        return level[id]

    for k, t in f.items():
        t.update(level=find_level(k))

    # collect the children of every vertex
    for k, t in f.items():
        t['children'] = []
    for k, t in f.items():
        p = t['PARENT']
        if p != 0 and p in f:
            f[p]['children'].append(k)

    # sort the children by ORDER (then by id) and number them
    for k, t in f.items():
        t['children'].sort(key=lambda c: (f[c]['ORDER'], c))
        for pos, c in enumerate(t['children']):
            f[c]['order'] = pos
    for k, t in f.items():
        if 'order' not in t:
            t['order'] = 0

    return {k: (t['PARENT'] if t['level'] == 2 else k, t['level'], t['order'])
            for k, t in f.items()}
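For completeness, this is roughly how I call it (the parquet path is just a placeholder):

# placeholder path, only for illustration
id_map = get_map(sqlContext, '/path/to/data.parquet')
for key, (parent, level, order) in id_map.items():
    print(key, '->', parent, level, order)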

Upvotes: 0

Views: 677

Answers (1)

Zhang Tong

Reputation: 4719

In general, you 'chain' (repeatedly self-join) the data to work out each row's level.

from pyspark import SparkConf
from pyspark.sql import SparkSession
from pyspark.sql import functions as f
from pyspark.sql.types import StructType, StructField, StringType

cfg = SparkConf().setAppName('s')
spark = SparkSession.builder.enableHiveSupport().config(conf=cfg).getOrCreate()
spark.sparkContext.setLogLevel('WARN')


# match each row of x against df_atom to find its 'child' rows, save the
# current rows (with their level 'i') to the global result (df_result),
# and return the child rows, which may have children of their own
def join_again(i, x):
    global df_atom, df_result
    tmp = df_atom.join(x, on=[x['id'] == df_atom['parent_atom']], how='right').cache()
    # df.union was added in Spark 2.0; df.unionAll (added in Spark 1.3) also works
    df_result = df_result.unionAll(tmp.select('id', 'parent', f.lit(i), 'order'))

    # these child rows may have children of their own, so they feed the next join_again
    res = tmp.filter(tmp['parent_atom'].isNotNull()) \
        .select(tmp['id'].alias('parent'), tmp['id_atom'].alias('id'), tmp['order'])
    tmp.unpersist()
    return res


def join_cycle(y):
    # 'n' counts how many times join_again has run; it is also the level
    # assigned in that pass (level 0 is handled before the loop)
    n = 1
    while True:
        if y.rdd.isEmpty():
            break
        y = join_again(n, y)
        n += 1


if __name__ == '__main__':
    df = spark.createDataFrame([('222', None, '101'), ('111', None, '001'),
                                ('333', None, '111'), ('444', None, '111'),
                                ('555', None, '444'), ('666', None, '444')],
                               schema=StructType([StructField('id', StringType()),
                                                  StructField('order', StringType()),
                                                  StructField('parent', StringType())]))

    df_atom = df.select(df['id'].alias('id_atom'), df['parent'].alias('parent_atom')).cache()
    df_result = spark.createDataFrame([], schema=StructType([StructField('id', StringType()),
                                                             StructField('parent', StringType()),
                                                             StructField('lv', StringType()),
                                                             StructField('order', StringType())]))

    # find the rows whose parent never appears as an id: those parents are the roots (level 0)
    df_init = df.join(df_atom, on=[df['parent'] == df_atom['id_atom']], how='left') \
        .filter(df_atom['id_atom'].isNull()).cache()
    # the level has to be set explicitly via pyspark.sql.functions.lit()
    df_result = df_result.unionAll(df_init.select('parent', 'parent', f.lit(0), 'order'))

    df = df_init.select('order', 'parent', 'id')
    df_init.unpersist()

    join_cycle(df)

    df_result.distinct().show(truncate=False)
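If you still need the final result as a Python dict of the form id -> (parent, level, order), one way (just a sketch) is to collect only the small result at the very end, so the level computation itself stays distributed:

# collect only the final result; the joins above ran on the cluster
result_map = {row['id']: (row['parent'], row['lv'], row['order'])
              for row in df_result.distinct().collect()}
print(result_map)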

Upvotes: 1
