I have a dataframe with a schema like:
root
|-- field_a: string (nullable = true)
|-- field_b: integer (nullable = true)
and I'd like to add a nested column to my dataframe, to have something like this:
root
|-- field_a: string (nullable = true)
|-- field_b: integer (nullable = true)
|-- field_c: struct (nullable = true)
| |-- subfield_a: integer (nullable = true)
| |-- subfield_b: integer (nullable = true)
How can I achieve this in pyspark?
You have two options: either declare a new schema that nests pyspark.sql.types.StructField entries inside a StructType, or use pyspark.sql.functions.struct as follows:
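import pyspark.sql.functions as f
from pyspark.sql import SparkSession

# Reuse the active session (shell, notebook) or start a new one
spark = SparkSession.builder.getOrCreate()

# Build a small example DataFrame with four flat columns
df = spark.sparkContext.parallelize([
    [0, 1.0, 0.71, 0.143],
    [1, 0.0, 0.97, 0.943],
    [0, 0.123, 0.27, 0.443],
    [1, 0.67, 0.3457, 0.243],
    [1, 0.39, 0.7777, 0.143]
]).toDF(['col1', 'col2', 'col3', 'col4'])

# Pack col2 and col3 into a single struct column, renamed as subfields
df_new = df.withColumn(
    'tada',
    f.struct(f.col('col2').alias('subcol_1'), f.col('col3').alias('subcol_2'))
)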
df_new.show()
+----+-----+------+-----+--------------+
|col1| col2| col3| col4| tada|
+----+-----+------+-----+--------------+
| 0| 1.0| 0.71|0.143| [1.0, 0.71]|
| 1| 0.0| 0.97|0.943| [0.0, 0.97]|
| 0|0.123| 0.27|0.443| [0.123, 0.27]|
| 1| 0.67|0.3457|0.243|[0.67, 0.3457]|
| 1| 0.39|0.7777|0.143|[0.39, 0.7777]|
+----+-----+------+-----+--------------+
Now, since tada is a StructType column, you can access its fields with the [...] notation as follows:
df_new.select(f.col('tada')['subcol_1']).show()
+-------------+
|tada.subcol_1|
+-------------+
| 1.0|
| 0.0|
| 0.123|
| 0.67|
| 0.39|
+-------------+
Printing the schema confirms the nested structure:
df_new.printSchema()
root
|-- col1: long (nullable = true)
|-- col2: double (nullable = true)
|-- col3: double (nullable = true)
|-- col4: double (nullable = true)
|-- tada: struct (nullable = false)
| |-- subcol_1: double (nullable = true)
| |-- subcol_2: double (nullable = true)
NB1: Instead of using f.col(...) to take an existing column, you can use any other expression that returns a pyspark.sql.Column, such as f.lit().
NB2: When using f.col(...), the types of the existing columns are carried over into the struct, as the schema above shows.
Hope this helps!