user1111

Reputation: 65

Flatten Nested Struct in PySpark Array

Given a schema like:

root
|-- first_name: string
|-- last_name: string
|-- degrees: array
|    |-- element: struct
|    |    |-- school: string
|    |    |-- advisors: struct
|    |    |    |-- advisor1: string
|    |    |    |-- advisor2: string

How can I get a schema like:

root
|-- first_name: string
|-- last_name: string
|-- degrees: array
|    |-- element: struct
|    |    |-- school: string
|    |    |-- advisor1: string
|    |    |-- advisor2: string

Currently, I explode the array, flatten the structure by selecting advisors.*, and then group by first_name and last_name to rebuild the array with collect_list. I'm hoping there's a cleaner/shorter way to do this. As it stands there's a lot of pain renaming fields, which I don't want to get into here. Thanks!

Upvotes: 3

Views: 8072

Answers (2)

pauli

Reputation: 4301

You can use a udf to change the datatype of nested columns in a dataframe. Suppose you have read the dataframe as df1:

from pyspark.sql.functions import udf
from pyspark.sql.types import *

def foo(data):
    return list(map(
        lambda x: (
            x["school"],
            x["advisors"]["advisor1"],
            x["advisors"]["advisor2"]
        ),
        data
    ))

struct = ArrayType(
    StructType([
        StructField("school", StringType()),
        StructField("advisor1", StringType()),
        StructField("advisor2", StringType())
    ])
)
udf_foo = udf(foo, struct)

df2 = df1.withColumn("degrees", udf_foo("degrees"))
df2.printSchema()

output:

root
 |-- degrees: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- school: string (nullable = true)
 |    |    |-- advisor1: string (nullable = true)
 |    |    |-- advisor2: string (nullable = true)
 |-- first_name: string (nullable = true)
 |-- last_name: string (nullable = true)
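Since the mapping inside foo runs on plain Python rows, the field selection can be sanity-checked without a Spark session (the sample data here is made up):

```python
# A row of the degrees array, as foo would see it (dicts stand in
# for Spark Row objects, which support the same [] access).
rows = [{"school": "MIT",
         "advisors": {"advisor1": "Ada", "advisor2": "Grace"}}]

# Same field selection as the lambda inside foo.
flat = [
    (r["school"], r["advisors"]["advisor1"], r["advisors"]["advisor2"])
    for r in rows
]
print(flat)  # [('MIT', 'Ada', 'Grace')]
```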

Upvotes: 1

Aydin K.

Reputation: 3378

Here's a more generic solution which can flatten multiple nested struct layers:

from pyspark.sql.functions import col

def flatten_df(nested_df, layers):
    flat_cols = []
    nested_cols = []
    flat_df = []

    # Split columns into plain columns and struct columns.
    flat_cols.append([c[0] for c in nested_df.dtypes if c[1][:6] != 'struct'])
    nested_cols.append([c[0] for c in nested_df.dtypes if c[1][:6] == 'struct'])

    # Flatten the first layer: promote each struct field to a
    # top-level column named <struct>_<field>.
    flat_df.append(nested_df.select(flat_cols[0] +
                               [col(nc + '.' + c).alias(nc + '_' + c)
                                for nc in nested_cols[0]
                                for c in nested_df.select(nc + '.*').columns])
                  )
    # Repeat for the remaining layers.
    for i in range(1, layers):
        flat_cols.append([c[0] for c in flat_df[i-1].dtypes if c[1][:6] != 'struct'])
        nested_cols.append([c[0] for c in flat_df[i-1].dtypes if c[1][:6] == 'struct'])

        flat_df.append(flat_df[i-1].select(flat_cols[i] +
                                [col(nc + '.' + c).alias(nc + '_' + c)
                                 for nc in nested_cols[i]
                                 for c in flat_df[i-1].select(nc + '.*').columns])
        )

    return flat_df[-1]

Just call it with:

my_flattened_df = flatten_df(my_df_having_structs, 3)

(the second parameter is the number of struct layers to flatten; in my case it's 3)

Upvotes: 0
