Tail of Godzilla

Reputation: 551

How to add a nested column to a dataframe in pyspark?

I have a dataframe with a schema like:

root
 |-- field_a: string (nullable = true)
 |-- field_b: integer (nullable = true)

and I'd like to add a nested column to my dataframe, to have something like this:

root
 |-- field_a: string (nullable = true)
 |-- field_b: integer (nullable = true)
 |-- field_c: struct (nullable = true)
 |    |-- subfield_a: integer (nullable = true)
 |    |-- subfield_b: integer (nullable = true)

How can I achieve this in pyspark?

Upvotes: 0

Views: 2374

Answers (1)

Napoleon Borntoparty

Reputation: 1962

You really have two options: one is to declare a new schema and nest your pyspark.sql.types.StructField objects (a sketch of this appears near the end of this answer); the other is to use pyspark.sql.functions.struct, as follows:

import pyspark.sql.functions as f

# 'spark' is the active SparkSession; use the public sparkContext
# attribute rather than the private _sc
df = spark.sparkContext.parallelize([
    [0, 1.0, 0.71, 0.143],
    [1, 0.0, 0.97, 0.943],
    [0, 0.123, 0.27, 0.443],
    [1, 0.67, 0.3457, 0.243],
    [1, 0.39, 0.7777, 0.143]
]).toDF(['col1', 'col2', 'col3', 'col4'])


# struct() takes the Columns directly; no need to unpack a list
df_new = df.withColumn(
    'tada',
    f.struct(f.col('col2').alias('subcol_1'), f.col('col3').alias('subcol_2'))
)
df_new.show()
+----+-----+------+-----+--------------+
|col1| col2|  col3| col4|          tada|
+----+-----+------+-----+--------------+
|   0|  1.0|  0.71|0.143|   [1.0, 0.71]|
|   1|  0.0|  0.97|0.943|   [0.0, 0.97]|
|   0|0.123|  0.27|0.443| [0.123, 0.27]|
|   1| 0.67|0.3457|0.243|[0.67, 0.3457]|
|   1| 0.39|0.7777|0.143|[0.39, 0.7777]|
+----+-----+------+-----+--------------+

Now, given that tada is a StructType, you can access its fields with the [...] notation as follows:

df_new.select(f.col('tada')['subcol_1']).show()
+-------------+
|tada.subcol_1|
+-------------+
|          1.0|
|          0.0|
|        0.123|
|         0.67|
|         0.39|
+-------------+
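As an aside (not in the original answer), Spark also understands dot notation for nested fields, so the same column can be selected either way:

df_new.select(f.col('tada.subcol_1')).show()
# or equivalently, passing the path as a string
df_new.select('tada.subcol_1').show()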

Printing the schema confirms the nested structure:

df_new.printSchema()

root
 |-- col1: long (nullable = true)
 |-- col2: double (nullable = true)
 |-- col3: double (nullable = true)
 |-- col4: double (nullable = true)
 |-- tada: struct (nullable = false)
 |    |-- subcol_1: double (nullable = true)
 |    |-- subcol_2: double (nullable = true)
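
For completeness, here is a minimal sketch of the first option, declaring the nested schema explicitly with pyspark.sql.types (the field names and sample row mirror the question and are purely illustrative):

from pyspark.sql.types import StructType, StructField, StringType, IntegerType

# field_c is itself a StructType nested inside the outer schema
schema = StructType([
    StructField('field_a', StringType(), True),
    StructField('field_b', IntegerType(), True),
    StructField('field_c', StructType([
        StructField('subfield_a', IntegerType(), True),
        StructField('subfield_b', IntegerType(), True)
    ]), True)
])

# a plain Python tuple maps onto the nested struct
df_nested = spark.createDataFrame([('a', 1, (2, 3))], schema=schema)
df_nested.printSchema()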

NB1: Instead of f.col(...) to reference an existing column, you can use any other function that returns a pyspark.sql.functions.Column, such as f.lit() for a constant (see the sketch just below).

NB2: When using f.col(...), the existing column types are carried over into the struct, as the schema above shows.
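A purely illustrative variation (df_lit is a hypothetical name, not from the original answer), mixing an existing column with a literal inside the struct:

df_lit = df.withColumn(
    'tada',
    f.struct(
        f.col('col2').alias('subcol_1'),  # carries over col2's double type
        f.lit(1).alias('subcol_2')        # constant column
    )
)

Hope this helps!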

Upvotes: 2
