gokyori

Reputation: 367

Returning multiple columns from a single pyspark dataframe

I am trying to parse a single column of a pyspark dataframe and get a dataframe with multiple columns. My dataframe is as follows:

   a  b               dic
0  1  2  {'d': 1, 'e': 2}
1  3  4  {'d': 7, 'e': 0}
2  5  6  {'d': 5, 'e': 4}

I want to parse the dic column into separate columns, ideally with a pandas UDF. My intended output is as follows:

   a  b  d  e
0  1  2  1  2
1  3  4  7  0
2  5  6  5  4

Here is my attempt at a solution:

import pandas as pd
from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.types import StructType, StructField, IntegerType

schema = StructType([
    StructField("d", IntegerType()),
    StructField("e", IntegerType())])

@pandas_udf(schema, PandasUDFType.GROUPED_MAP)
def do_something(dic_col):
    return pd.DataFrame(dic_col)

df.apply(do_something).show(10)

But this gives the error: 'DataFrame' object has no attribute 'apply'.
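For context, a GROUPED_MAP pandas UDF is applied through groupBy().apply(), not DataFrame.apply(), which Spark DataFrames don't have. A minimal sketch of that pattern, assuming Spark 2.4 and that dic is stored as a string (the function name and grouping key here are illustrative):

import ast
import pandas as pd
from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.types import StructType, StructField, LongType

# pandas integers arrive as int64, so LongType avoids dtype mismatches
out_schema = StructType([
    StructField("a", LongType()),
    StructField("b", LongType()),
    StructField("d", LongType()),
    StructField("e", LongType())])

@pandas_udf(out_schema, PandasUDFType.GROUPED_MAP)
def expand_dic(pdf):
    # pdf is a pandas DataFrame holding one group
    parsed = pdf["dic"].apply(ast.literal_eval).apply(pd.Series)
    return pdf[["a", "b"]].join(parsed)

df.groupBy("a").apply(expand_dic).show()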

Upvotes: 1

Views: 1515

Answers (2)

blackbishop

Reputation: 32670

You can first transform the column to a JSON string by replacing the single quotes with double quotes, then use from_json to convert it into a struct or map column.

If you know the schema of the dict, you can do it like this:

from pyspark.sql.functions import col, from_json, regexp_replace
from pyspark.sql.types import StructType, StructField, StringType

data = [
    (1, 2, "{'c': 1, 'd': 2}"),
    (3, 4, "{'c': 7, 'd': 0}"),
    (5, 6, "{'c': 5, 'd': 4}")
]

df = spark.createDataFrame(data, ["a", "b", "dic"])

schema = StructType([
    StructField("c", StringType(), True),
    StructField("d", StringType(), True)
])

df = df.withColumn("dic", from_json(regexp_replace(col("dic"), "'", "\""), schema))

df.select("a", "b", "dic.*").show(truncate=False)

#+---+---+---+---+
#|a  |b  |c  |d  |
#+---+---+---+---+
#|1  |2  |1  |2  |
#|3  |4  |7  |0  |
#|5  |6  |5  |4  |
#+---+---+---+---+

If you don't know all the keys, you can convert it to a map instead of a struct, then explode it and pivot to get the keys as columns:

from pyspark.sql.functions import explode, first
from pyspark.sql.types import MapType

df = df.withColumn("dic", from_json(regexp_replace(col("dic"), "'", "\""), MapType(StringType(), StringType())))\
       .select("a", "b", explode("dic"))\
       .groupBy("a", "b")\
       .pivot("key")\
       .agg(first("value"))
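
With the sample data above, this should produce the same table as the struct version (pivot orders the key columns alphabetically; row order may vary):

#+---+---+---+---+
#|a  |b  |c  |d  |
#+---+---+---+---+
#|1  |2  |1  |2  |
#|3  |4  |7  |0  |
#|5  |6  |5  |4  |
#+---+---+---+---+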

Upvotes: 2

Georgina Skibinski

Reputation: 13387

Try:

# to convert the pyspark df into pandas:
df = df.toPandas()

df["d"] = df["dic"].str.get("d")
df["e"] = df["dic"].str.get("e")
df = df.drop(columns=["dic"])

Returns:

   a  b  d  e
0  1  2  1  2
1  3  4  7  0
2  5  6  5  4
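
Note that toPandas() collects the entire DataFrame onto the driver, so this approach only suits data small enough to fit in its memory. If you need the result back in Spark afterwards, a sketch (assuming an active SparkSession named spark):

# convert the pandas result back into a Spark DataFrame
sdf = spark.createDataFrame(df)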

Upvotes: 0
