MrCartoonology
MrCartoonology

Reputation: 2067

How do I add a column to a nested struct in a PySpark dataframe?

I have a dataframe with a schema like

root
 |-- state: struct (nullable = true)
 |    |-- fld: integer (nullable = true)

I'd like to add columns within the state struct, that is, create a dataframe with a schema like

root
 |-- state: struct (nullable = true)
 |    |-- fld: integer (nullable = true)
 |    |-- a: integer (nullable = true)

I tried

# NOTE: withColumn does not resolve dots into struct fields, so 'state.a' is
# added as a new top-level column rather than as a field inside 'state'.
df.withColumn('state.a', val).printSchema()
# root
#  |-- state: struct (nullable = true)
#  |    |-- fld: integer (nullable = true)
#  |-- state.a: integer (nullable = true)

Upvotes: 28

Views: 65977

Answers (7)

Henrique Maia
Henrique Maia

Reputation: 11

You can use the `struct` function to rebuild `state` with the new field:

import pyspark.sql.functions as f

# Rebuild 'state' from its existing field plus the new one. Only the fields
# listed inside f.struct(...) end up in the new struct, so every existing
# field that should be kept must be listed here.
df = df.withColumn(
    "state",
    f.struct(
        f.col("state.fld").alias("fld"),
        f.lit(1).alias("a")
    )
)

Upvotes: 1

ZygD
ZygD

Reputation: 24386

Spark 3.1+

# Column.withField (Spark 3.1+) adds a field to a struct column, or replaces
# the field if one with that name already exists.
F.col('state').withField('a', F.lit(1))

Example:

from pyspark.sql import functions as F
df = spark.createDataFrame([((1,),)], 'state:struct<fld:int>')
df.printSchema()
# root
#  |-- state: struct (nullable = true)
#  |    |-- fld: integer (nullable = true)

# withField keeps the existing fields of 'state' and appends 'a' after them,
# as the schema below shows; no need to list the existing fields by hand.
df = df.withColumn('state', F.col('state').withField('a', F.lit(1)))
df.printSchema()
# root
#  |-- state: struct (nullable = true)
#  |    |-- fld: integer (nullable = true)
#  |    |-- a: integer (nullable = false)

Upvotes: 6

malthe
malthe

Reputation: 1449

Use a transformation such as the following:

import pyspark.sql.functions as f

# "state.*" expands to every existing field of the struct, so they don't have
# to be named one by one; the new field "a" is added after them.
df = df.withColumn(
    "state",
    f.struct(
        f.col("state.*"),
        f.lit(123).alias("a")
    )
)

Upvotes: 24

Clay
Clay

Reputation: 2726

Here's a way to do it without a UDF.

Initialize example dataframe:

# Build an example DataFrame from a JSON array; each array element becomes
# one row (two rows here).
nested_df1 = (spark.read.json(sc.parallelize(["""[
        { "state": {"fld": 1} },
        { "state": {"fld": 2}}
    ]"""])))

nested_df1.printSchema()
# root
#  |-- state: struct (nullable = true)
#  |    |-- fld: long (nullable = true)

Spark's `read.json` infers all integers as `long` by default. If `state.fld` has to be an `int`, you need to cast it.

from pyspark.sql import functions as F

# Cast first, then alias: the alias has to be the outermost expression for the
# struct field to be named 'fld'. Calling .cast() on an already-aliased column
# loses the name, and the struct field falls back to a generated name ('col1').
nested_df1 = (nested_df1
    .select(F.struct(F.col("state.fld").cast('int').alias("fld")).alias("state")))

nested_df1.printSchema()
# root
#  |-- state: struct (nullable = false)
#  |    |-- fld: integer (nullable = true)
nested_df1.show()
# +-----+
# |state|
# +-----+
# |  [1]|
# |  [2]|
# +-----+

Finally

Use .select to get the nested columns you want from the existing struct with the "parent.child" notation, create the new column, then re-wrap the old columns together with the new columns in a struct.

val_a = 3

# Rebuild 'state' from its existing field plus the new literal field 'a'.
# Must select from nested_df1 (the DataFrame built above); 'nested_df' is undefined.
nested_df2 = (nested_df1
    .select(
        F.struct(
            F.col("state.fld"),
            F.lit(val_a).alias("a")
        ).alias("state")
    )
)


nested_df2.printSchema()
# root
#  |-- state: struct (nullable = false)
#  |    |-- fld: integer (nullable = true)
#  |    |-- a: integer (nullable = false)
nested_df2.show()
# +------+
# | state|
# +------+
# |[1, 3]|
# |[2, 3]|
# +------+

Flatten if needed with "parent.*".

# "state.*" promotes each field of the struct to a top-level column.
nested_df2.select("state.*").printSchema()
# root
#  |-- fld: integer (nullable = true)
#  |-- a: integer (nullable = false)
nested_df2.select("state.*").show()
# +---+---+
# |fld|  a|
# +---+---+
# |  1|  3|
# |  2|  3|
# +---+---+

Upvotes: 2

Xingang Wang
Xingang Wang

Reputation: 1

from pyspark.sql.functions import *
from pyspark.sql.types import *
def add_field_in_dataframe(nfield, df, dt):
    """Return *df* with a new null-valued field of type *dt* at path *nfield*.

    *nfield* is a dot-separated path such as ``"state.a"`` or ``"a.b.c"``.
    A path without dots adds a top-level column. For a nested path, every
    struct on the way to the new field is rebuilt with ``struct()``. Existing
    fields keep their names and their original position. The new leaf field is
    appended to its parent struct, or replaces a field of the same name in place.

    Args:
        nfield: dot-separated path of the field to add.
        df: the DataFrame to extend.
        dt: the type passed to ``Column.cast`` for the null value.

    Returns:
        The DataFrame returned by ``df.withColumn`` with the rebuilt top-level column.
    """
    fields = nfield.split(".")
    if len(fields) == 1:
        return df.withColumn(fields[0], lit(None).cast(dt))

    last = len(fields) - 1

    def _rebuild(depth):
        # Rebuild the struct at path fields[:depth] with fields[depth] added
        # (or replaced), recursing until the leaf field is reached.
        path = ".".join(fields[:depth])
        # Selecting a nested path yields a column named after its last segment;
        # a missing path makes this select fail, as in the original code.
        names = df.select(path).schema[fields[depth - 1]].dataType.names
        parent = col(path)
        target = fields[depth]
        if depth == last:
            new_child = lit(None).cast(dt).alias(target)
        else:
            new_child = _rebuild(depth + 1).alias(target)
        # Keep sibling fields in their original order; swap the target in place.
        children = [
            new_child if name == target else parent[name].alias(name)
            for name in names
        ]
        if target not in names:
            children.append(new_child)
        return struct(*children)

    return df.withColumn(fields[0], _rebuild(1))

Upvotes: -2

desaiankitb
desaiankitb

Reputation: 1052

Although this answer comes late, the following works in PySpark 2.x:

Assuming `dfOld` already contains the `state` struct with its `fld` field, as in the question:

from pyspark.sql.functions import col, lit, struct

# Add the new value as a top-level column first, then fold it into 'state'.
# withColumn requires a Column, so the literal must be wrapped in lit(); and its
# result must be used, otherwise col("a") below cannot be resolved.
# Add any other top-level columns of dfOld that should be kept to this select.
dfNew = (
    dfOld
    .withColumn("a", lit("value"))
    .select(struct(col("state.fld").alias("fld"), col("a")).alias("state"))
)

Reference: https://medium.com/@mrpowers/adding-structtype-columns-to-spark-dataframes-b44125409803

Upvotes: 5

pault
pault

Reputation: 43504

Here is a way to do it without using a UDF:

# create example dataframe
import pyspark.sql.functions as f
from pyspark.sql.types import IntegerType, StructField, StructType

data = [
    ({'fld': 0},)
]

schema = StructType(
    [
        StructField('state',
            StructType(
                [StructField('fld', IntegerType())]
            )
        )
    ]
)

# Assumes sqlCtx is a SQLContext; with a SparkSession, use
# spark.createDataFrame(data, schema) instead.
df = sqlCtx.createDataFrame(data, schema)
df.printSchema()
#root
# |-- state: struct (nullable = true)
# |    |-- fld: integer (nullable = true)

Now use withColumn() and add the new field using lit() and alias().

val = 1
# Rebuild 'state' from its existing field 'fld' plus the new literal field 'a'.
df_new = df.withColumn(
    'state',
    f.struct(f.col('state')['fld'].alias('fld'), f.lit(val).alias('a'))
)
df_new.printSchema()
#root
# |-- state: struct (nullable = false)
# |    |-- fld: integer (nullable = true)
# |    |-- a: integer (nullable = false)

If you have a lot of fields in the nested struct you can use a list comprehension, using df.schema["state"].dataType.names to get the field names. For example:

val = 1
# Field names of the existing 'state' struct, read from the schema.
s_fields = df.schema["state"].dataType.names # ['fld']
# Copy every existing field of 'state' by name, then append the new field 'a'.
df_new = df.withColumn(
    'state', 
    f.struct(*([f.col('state')[c].alias(c) for c in s_fields] + [f.lit(val).alias('a')]))
)
df_new.printSchema()
#root
# |-- state: struct (nullable = false)
# |    |-- fld: integer (nullable = true)
# |    |-- a: integer (nullable = false)

References

  • I found how to get the struct's field names without listing them manually in this answer.

Upvotes: 35

Related Questions