Michael

Reputation: 2556

How to properly use reduce with a dictionary

I am using a custom function as part of a reduce operation. The example below fails with TypeError: reduce() takes no keyword arguments. I believe this is caused by the way I use the dictionary mapping in the function exposed_column. Could you please help me fix this function?

from pyspark.sql import DataFrame, Row
from pyspark.sql.functions import col
from pyspark.sql import SparkSession
from functools import reduce


def process_data(df: DataFrame):
    col_mapping = dict(zip(["name", "age"], ["a", "b"]))

    # Do other things...

    def exposed_column(df: DataFrame, mapping: dict):
        return df.select([col(c).alias(mapping.get(c, c)) for c in df.columns])

    return reduce(exposed_column, sequence=col_mapping, initial=df)


spark = SparkSession.builder.appName("app").getOrCreate()
l = [
    ("Bob", 25, "Spain"),
    ("Marc", 22, "France"),
    ("Steve", 20, "Belgium"),
    ("Donald", 26, "USA"),
]
rdd = spark.sparkContext.parallelize(l)
people = rdd.map(lambda x: Row(name=x[0], age=int(x[1]), country=x[2])).toDF()

people.show()
process_data(people).show()

The output of people.show() looks like this:

+---+-------+------+
|age|country|  name|
+---+-------+------+
| 25|  Spain|   Bob|
| 22| France|  Marc|
| 20|Belgium| Steve|
| 26|    USA|Donald|
+---+-------+------+

And this is the expected output

+------+---+
|     a|  b|
+------+---+
|   Bob| 25|
|  Marc| 22|
| Steve| 20|
|Donald| 26|
+------+---+

Upvotes: 0

Views: 520

Answers (1)

Oliver W.

Reputation: 13459

reduce does not take keywords, that’s true. Once you remove the keywords, you’ll notice a more serious issue though: when you iterate over a dictionary, you’re iterating over its keys only. So the function in which you're trying to batch rename the columns won’t do what you had in mind.
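Both points can be checked without Spark. The following is a minimal sketch in plain Python, assuming CPython's built-in functools.reduce; the variable names `keys_seen`, `collected` and `keyword_call_failed` are only illustrative.

```python
from functools import reduce

col_mapping = dict(zip(["name", "age"], ["a", "b"]))

# Iterating over a dict yields its keys only; the values ("a", "b") never appear.
keys_seen = list(col_mapping)

# reduce is called positionally: function, iterable, initial value.
# Because the iterable is the dict itself, the lambda only ever receives keys.
collected = reduce(lambda acc, key: acc + [key], col_mapping, [])

# Passing the iterable as a keyword (as in the question) raises a TypeError.
keyword_call_failed = False
try:
    reduce(lambda acc, key: acc, sequence=col_mapping, initial=[])
except TypeError:
    keyword_call_failed = True
```

So even with positional arguments, the reducing function would be handed "name" and then "age" rather than the mapping as a whole.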

One way to do a batch column rename would be to iterate over the dictionary’s items:

from functools import reduce
from typing import Mapping

from pyspark.sql import DataFrame

def rename_columns(frame: DataFrame, mapping: Mapping[str, str]) -> DataFrame:
    return reduce(lambda f, old_new: f.withColumnRenamed(old_new[0], old_new[1]),
                  mapping.items(), frame)

This allows you to pass in a dictionary (note that the recommendation for adding type hints to arguments is to use Mapping, not dict) that maps column names to other names. Fortunately, withColumnRenamed won’t complain if you try to rename a column that isn’t in the DataFrame, so this is equivalent to your mapping.get(c, c).

One thing I don’t see in your code is anything that drops the country column, so it will still be in your output.
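If the country column should disappear, as in the expected output, one option is to select only the mapped columns instead of renaming in place. The sketch below is an assumption about what is wanted, not something taken from the question: the helper `columns_to_keep` is hypothetical and works on plain column names, with the Spark call shown in a comment.

```python
from typing import List, Mapping, Sequence, Tuple


def columns_to_keep(columns: Sequence[str],
                    mapping: Mapping[str, str]) -> List[Tuple[str, str]]:
    """Return (old, new) name pairs for mapped columns that exist, in mapping order."""
    return [(old, new) for old, new in mapping.items() if old in columns]


pairs = columns_to_keep(["age", "country", "name"], {"name": "a", "age": "b"})

# With a Spark DataFrame `frame`, the pairs could then be applied like this
# (needs `from pyspark.sql.functions import col`):
#   frame.select([col(old).alias(new) for old, new in pairs])
```

Because the select lists only the mapped columns, anything absent from the mapping (here country) is left out, and the output columns follow the order of the mapping.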

Upvotes: 2
