Code Review question: Generic “reduceBy” or “groupBy + aggregate” functionality with Spark DataFrame

Alright everyone. Maybe I totally reinvented the wheel here, or maybe I've invented something useful. Can one of you tell me if there's a better way of doing this? Here's what I'm trying to do:

I want a generic reduceBy function that works like an RDD's reduceByKey but lets me group by any column of a Spark DataFrame. You may say that we already have that, and it's called groupBy, but as far as I can tell, groupBy only lets you aggregate with a small set of built-in functions (avg, max, min, sum, count). I want to groupBy and then run an arbitrary function to aggregate. Has anyone already done that?

Basically, I'm taking a Spark DataFrame that looks like this...

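+----------+---------+-----+-------------+------------+-------------------+
| birthdate|favecolor| name|twitterhandle|facebookpage|           favesong|
+----------+---------+-----+-------------+------------+-------------------+
|2000-01-01|     blue|Alice|     allyblue|        null|               null|
|1999-12-31|     null|  Bob|         null|      BobbyG| Gangsters Paradise|
|      null|     null|Alice|         null|        null|Rolling in the Deep|
+----------+---------+-----+-------------+------------+-------------------+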

...and reducing by the column 'name' to get this:

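+----------+---------+-------------------+-----+-------------+------------+
| birthdate|favecolor|           favesong| name|twitterhandle|facebookpage|
+----------+---------+-------------------+-----+-------------+------------+
|2000-01-01|     blue|Rolling in the Deep|Alice|     allyblue|        null|
|1999-12-31|     null| Gangsters Paradise|  Bob|         null|      BobbyG|
+----------+---------+-------------------+-----+-------------+------------+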

I just noticed the change in column order. I think I can fix that pretty quickly by recording the original column order before starting. Anyway, I had to write a ton of code to get this to work, and it seems like such a simple operation that somebody else must have done it by now.
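The reordering most likely has two sources: in Spark versions before 3.0, `Row(**kwargs)` sorts its fields by name, and `concatTwoDfs` appends any missing columns at the end. Because the column list is known before reducing, restoring it in PySpark should only take `reduced_df.select(*df.columns)` (where `reduced_df` is a hypothetical name for the output of `reduce_by`). The plain-Python sketch below models that step on dict records; `restore_order` is likewise a hypothetical helper, not Spark API.

```python
def restore_order(records, columns):
    """Return each record as a dict whose keys follow the original column order.

    A column missing from a record is filled with None, like a null column.
    """
    return [{c: rec.get(c) for c in columns} for rec in records]


# Column order captured from the input DataFrame before reducing.
original_columns = ["birthdate", "favecolor", "name", "twitterhandle", "facebookpage", "favesong"]

# A reduced record whose keys came out in a different order
# (sorted Row fields, then a column appended by the concat step).
reduced = [{"birthdate": "2000-01-01", "favecolor": "blue", "favesong": "Rolling in the Deep",
            "name": "Alice", "twitterhandle": "allyblue", "facebookpage": None}]

fixed = restore_order(reduced, original_columns)
print(list(fixed[0]))
```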

Here's the code, written with Python 3.5.1 and Spark 1.5.2:

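 def addEmptyColumns(df, colNames):
     """
     Returns df with an extra all-null column for every name in colNames.

     :param df: the Spark DataFrame to extend
     :param colNames: an iterable of column names to add
     """
     exprs = df.columns + ["null as " + colName for colName in colNames]
     return df.selectExpr(*exprs)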
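`selectExpr` takes SQL expression strings, so `addEmptyColumns(df, {"twitterhandle"})` on a DataFrame with columns `name` and `favesong` amounts to `df.selectExpr("name", "favesong", "null as twitterhandle")`. The plain-Python sketch below (the helper name `build_select_exprs` is hypothetical) builds that argument list. Sorting the missing names is an assumption added here: `colNames` arrives as a set, whose iteration order is arbitrary, so the original function may append the new columns in any order.

```python
def build_select_exprs(columns, missing):
    """Mirror addEmptyColumns: keep the existing columns, then add 'null as <name>' for each missing one."""
    # sorted() gives a reproducible order; iterating the raw set would not
    return list(columns) + ["null as " + name for name in sorted(missing)]


exprs = build_select_exprs(["name", "favesong"], {"twitterhandle", "birthdate"})
print(exprs)  # ['name', 'favesong', 'null as birthdate', 'null as twitterhandle']
```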

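 def concatTwoDfs(left, right):
     """
     Unions two DataFrames whose columns may differ, filling each side's missing columns with nulls.

     :param left: the first Spark DataFrame
     :param right: the second Spark DataFrame
     :return: a DataFrame with the columns of both inputs
     """
     # append columns from right df to left df
     missingColumnsLeft = set(right.columns) - set(left.columns)
     left = addEmptyColumns(left, missingColumnsLeft)

     # append columns from left df to right df
     missingColumnsRight = set(left.columns) - set(right.columns)
     right = addEmptyColumns(right, missingColumnsRight)

     # put right's columns in left's order, since unionAll matches columns by position
     right = right[left.columns]

     # finally, union them
     return left.unionAll(right)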
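The key detail in `concatTwoDfs` is the `right[left.columns]` step: `unionAll` pairs columns by position, not by name, so the right-hand DataFrame has to be rearranged into the left one's order first. The sketch below models that in plain Python, with a hypothetical "frame" represented as a pair of column names and row tuples; it is not Spark code.

```python
def add_empty_columns(frame, names):
    """Append an all-None column for each name; a frame is (column_names, list_of_row_tuples)."""
    cols, rows = frame
    names = sorted(names)  # deterministic order (the Spark version iterates a set arbitrarily)
    return cols + names, [row + (None,) * len(names) for row in rows]


def concat_two(left, right):
    """Positional union after aligning both frames to the same columns, as concatTwoDfs does."""
    left = add_empty_columns(left, set(right[0]) - set(left[0]))
    right = add_empty_columns(right, set(left[0]) - set(right[0]))
    # Rearrange right's values into left's column order, because the union matches by position.
    positions = [right[0].index(c) for c in left[0]]
    right_rows = [tuple(row[i] for i in positions) for row in right[1]]
    return left[0], left[1] + right_rows


left_frame = (["name", "favecolor"], [("Alice", "blue")])
right_frame = (["favesong", "name"], [("Gangsters Paradise", "Bob")])
cols, rows = concat_two(left_frame, right_frame)
print(cols)  # ['name', 'favecolor', 'favesong']
print(rows)  # [('Alice', 'blue', None), ('Bob', None, 'Gangsters Paradise')]
```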

 def reduce(function, iterable, initializer=None):
     """
     A copy of the rough code from Python 2's reduce function documentation.  Why did Python 3 get rid of it?

     Apply function of two arguments cumulatively to the items of iterable, from left to right, so as to reduce the
     iterable to a single value. For example, reduce(lambda x, y: x+y, [1, 2, 3, 4, 5]) calculates ((((1+2)+3)+4)+5).
     The left argument, x, is the accumulated value and the right argument, y, is the update value from the iterable.
     If the optional initializer is present, it is placed before the items of the iterable in the calculation, and
     serves as a default when the iterable is empty. If initializer is not given and iterable contains only one item,
     the first item is returned.

     :param function: use this function to reduce the elements of iterable
     :param iterable: the items to reduce
     :param initializer: optional starting value
     """
     it = iter(iterable)
     if initializer is None:
         try:
             initializer = next(it)
         except StopIteration:
             raise TypeError('reduce() of empty sequence with no initial value')
     accum_value = initializer
     for x in it:
         accum_value = function(accum_value, x)
     return accum_value
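Python 3 did not drop `reduce`; it moved it into the standard-library `functools` module, so the hand-written copy can be replaced with an import. The sketch below checks the docstring's own example plus the single-item and initializer cases.

```python
from functools import reduce

# The docstring's example: ((((1+2)+3)+4)+5)
total = reduce(lambda x, y: x + y, [1, 2, 3, 4, 5])
print(total)  # 15

# With a single item and no initializer, that item is returned as-is.
single = reduce(lambda x, y: x + y, ["only"])
print(single)  # only

# The initializer is the starting value, and the result for an empty iterable.
empty = reduce(lambda x, y: x + y, [], 0)
print(empty)  # 0
```

One small difference: `functools.reduce` treats an explicitly passed `None` as a real initializer, whereas the copied code's `if initializer is None` check cannot tell "no initializer" apart from "initializer is None".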

 def concat(dfs):
     """
     Concatenates Spark DataFrames intelligently, adding missing columns with null entries where appropriate.

     :param dfs: a non-empty list or tuple of Spark DataFrames
     :return: single DataFrame consisting of dfs' columns and data
     """
     return reduce(concatTwoDfs, dfs)
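Since `concat` folds `concatTwoDfs` over the whole sequence, it works for any non-empty list of DataFrames, not just two: each step keeps the columns collected so far and appends whatever new columns the next DataFrame brings. The sketch below traces only the resulting column list in plain Python; `merge_columns` is a hypothetical helper.

```python
from functools import reduce


def merge_columns(left_cols, right_cols):
    """Column order after one concatTwoDfs step: left's columns, then right's extra columns."""
    # Keeping right's own order for the extras is a simplification; concatTwoDfs uses a set.
    return left_cols + [c for c in right_cols if c not in left_cols]


frames = [["name", "favecolor"], ["name", "favesong"], ["facebookpage", "name"]]
final_columns = reduce(merge_columns, frames)
print(final_columns)  # ['name', 'favecolor', 'favesong', 'facebookpage']
```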

 def combine_rows(row1, row2):
     """
     Takes two rows assumed to have the same columns and combines them, using values from row1 when available
     and from row2 otherwise.

     :param row1: pyspark.sql.Row
     :param row2: pyspark.sql.Row
     :return: pyspark.sql.Row combined from row1 and row2
     """
     from pyspark.sql import Row
     d1, d2 = row1.asDict(), row2.asDict()
     combined = {}
     for col in d1:
         if d1[col] is not None:
             combined[col] = d1[col]
         else:
             combined[col] = d2[col]
     return Row(**combined)
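The coalescing rule in `combine_rows` can be checked without Spark. The sketch below is a plain-dict analogue (the name `combine_dicts` is hypothetical) applied to Alice's two rows from the example table.

```python
def combine_dicts(row1, row2):
    """Plain-dict analogue of combine_rows: prefer row1's value unless it is None."""
    return {col: row1[col] if row1[col] is not None else row2[col] for col in row1}


alice_1 = {"birthdate": "2000-01-01", "favecolor": "blue", "name": "Alice",
           "twitterhandle": "allyblue", "facebookpage": None, "favesong": None}
alice_2 = {"birthdate": None, "favecolor": None, "name": "Alice",
           "twitterhandle": None, "facebookpage": None, "favesong": "Rolling in the Deep"}

merged = combine_dicts(alice_1, alice_2)
print(merged["favesong"], merged["birthdate"], merged["facebookpage"])
# Rolling in the Deep 2000-01-01 None
```

Because `row1` wins whenever both values are non-null, the outcome for conflicting values depends on which row Spark passes first, and the `RDD.reduce` documentation asks for a commutative and associative operator. For this data that does not matter, since each column has at most one non-null value per name.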

 def remove_nones(row):
     """
     Takes in a row and returns that same row minus all of the columns that have a None entry.  This is required
     in order to create a new DataFrame from this row alone: createDataFrame cannot infer a column's type when its
     only value is None.

     :param row: pyspark.sql.Row
     :return: pyspark.sql.Row without the None-valued fields
     """
     from pyspark.sql import Row
     values = row.asDict()
     cleaned = {}
     for col in values:
         if values[col] is not None:
             cleaned[col] = values[col]
     return Row(**cleaned)

 def reduce_by(df, col, func):
     """
     Does pretty much the same thing as an RDD's reduceByKey, but much more generic.  Kind of like a Spark DataFrame's
     groupBy, but lets you aggregate by any generic function.  Relies on the pyspark shell's global sc (SparkContext)
     and sqlContext (SQLContext).

     :param df: the DataFrame to be reduced
     :param col: the column you want to use for grouping in df
     :param func: the function you will use to reduce df
     :return: a reduced DataFrame
     """
     return_df = None
     # distinct values of the grouping column, as one-field Rows (hence entry[0] below)
     unique_entries = df.select(col).distinct().collect()
     for entry in unique_entries:
         # reduce every row sharing this key to one Row, then wrap it in a one-row DataFrame
         reduced_row = remove_nones(df.filter(df[col] == entry[0]).rdd.reduce(func))
         row_df = sqlContext.createDataFrame(sc.parallelize([reduced_row]))
         return_df = row_df if return_df is None else concat((return_df, row_df))
     return return_df
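The loop above launches a separate filter and `reduce` job for every distinct key and then unions one-row DataFrames, so its cost grows with the number of keys. An RDD's `reduceByKey` does the same work in a single pass; in PySpark that would look roughly like `df.rdd.keyBy(lambda row: row[col]).reduceByKey(func).values()`, after which the rows still have to be turned back into a DataFrame. The sketch below is a plain-Python model of that single-pass idea on dict rows (`reduce_by_key` and `add_counts` are hypothetical names), not Spark code.

```python
def reduce_by_key(rows, key, func):
    """Single-pass model of reduceByKey on dict rows: fold each row into its key's accumulator."""
    accumulators = {}
    for row in rows:
        k = row[key]
        accumulators[k] = func(accumulators[k], row) if k in accumulators else row
    return list(accumulators.values())


def add_counts(a, b):
    # Example reducer: keep the name, sum the counts.
    return {"name": a["name"], "n": a["n"] + b["n"]}


rows = [{"name": "Alice", "n": 1}, {"name": "Bob", "n": 2}, {"name": "Alice", "n": 3}]
print(reduce_by_key(rows, "name", add_counts))
# [{'name': 'Alice', 'n': 4}, {'name': 'Bob', 'n': 2}]
```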

You kick it all off by creating a DataFrame called test_df and running this:

reduce_by(test_df, 'name', combine_rows).show()

I think that for your specific aggregation need this will work as well:

from pyspark.sql import SQLContext

data = sc.parallelize([("2000-01-01", "blue", "Alice", "allyblue", None, None),
                       ("1999-12-31", None, "Bob", None, "BobbyG", "Gangsters Paradise"),
                       (None, None, "Alice", None, None, "Rolling in the Deep")])

df = sqlContext.createDataFrame(
    data, ["birthdate", "favecolor", "name", "twitterhandle", "facebookpage", "favesong"])

# group on name, then take the (null-skipping) minimum of every other column
df = df.groupBy('name').agg({'birthdate': 'min', 'favecolor': 'min',
                             'twitterhandle': 'min', 'facebookpage': 'min', 'favesong': 'min'})
print(df.collect())

[Row(name=u'Alice', min(favesong)=u'Rolling in the Deep',
min(twitterhandle)=u'allyblue', min(favecolor)=u'blue', 
min(facebookpage)=u'null', min(birthdate)=u'2000-01-01'), Row(name=u'Bob',
min(favesong)=u'Gangsters Paradise', min(twitterhandle)=u'null', 
min(favecolor)=u'null', min(facebookpage)=u'BobbyG', min(birthdate)=u'1999-12-31')]

