John

Reputation: 1167

Flatten Nested Spark Dataframe

Is there a way to flatten an arbitrarily nested Spark Dataframe? Most of the work I'm seeing is written for a specific schema, and I'd like to be able to generically flatten a Dataframe with different nested types (e.g. StructType, ArrayType, MapType).

Say I have a schema like:

StructType(List(StructField(field1,...), StructField(field2,...), StructField(nested_array,ArrayType(StructType(List(StructField(nested_field1,...), StructField(nested_field2,...)))),...)))
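For concreteness, one hypothetical way to declare that schema with pyspark.sql.types (the actual field types are elided above, so integers are just placeholders):

from pyspark.sql.types import StructType, StructField, ArrayType, IntegerType

# Placeholder declaration of the schema above; IntegerType is an assumption.
schema = StructType([
    StructField("field1", IntegerType()),
    StructField("field2", IntegerType()),
    StructField("nested_array", ArrayType(StructType([
        StructField("nested_field1", IntegerType()),
        StructField("nested_field2", IntegerType()),
    ]))),
])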

Looking to adapt this into a flat table with a structure like:

field1
field2
nested_array.nested_field1
nested_array.nested_field2

FYI, I'm looking for suggestions for PySpark, but other flavors of Spark are also appreciated.

Upvotes: 16

Views: 46166

Answers (6)

Sofiia Alieva

Reputation: 11

I wrote it the following way:

def to_flatten(df):
    # needConversion() returns True for StructType columns, so this adds a
    # flat copy of the nested field while keeping the existing columns.
    for field in df.schema:
        if field.needConversion():
            df = df.withColumn(f"{field.name}.<<your_inner_column>>",
                               df[f"{field.name}.<<your_inner_column>>"])
    return df

The solution does not drop existing columns.

The nested columns are of type StructType and for StructType the needConversion() method returns True.

(Beware that needConversion() also returns True for some other types; those were not present in my dataframe.)

For me it yields a result similar to the stack-based solution here: https://stackoverflow.com/a/65256632/21404451.
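For reference, a generalized sketch of the same idea that discovers the inner field names instead of hard-coding them, using an isinstance check on StructType in place of needConversion() (my own adaptation, not the original code):

from pyspark.sql.types import StructType

def to_flatten_all(df):
    # For each StructType column, add one flat copy of every inner field
    # while keeping the original columns (new names contain a literal dot).
    for field in df.schema.fields:
        if isinstance(field.dataType, StructType):
            for inner in field.dataType.fields:
                df = df.withColumn(f"{field.name}.{inner.name}",
                                   df[field.name][inner.name])
    return df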

Upvotes: 1

Narahari B M

Reputation: 337

This flattens a nested DataFrame that has both struct and array types. It typically helps when reading data in through JSON. Improved on https://stackoverflow.com/a/56533459/7131019.

from pyspark.sql.types import *
from pyspark.sql import functions as f

def flatten_structs(nested_df):
    # Walk the schema with an explicit stack, collecting a flat select list.
    # Struct columns are expanded; everything else (including arrays) is kept
    # as-is, aliased with its full path joined by "_".
    stack = [((), nested_df)]
    columns = []

    while len(stack) > 0:
        parents, df = stack.pop()

        flat_cols = [
            f.col(".".join(parents + (c[0],))).alias("_".join(parents + (c[0],)))
            for c in df.dtypes
            if c[1][:6] != "struct"
        ]

        nested_cols = [
            c[0]
            for c in df.dtypes
            if c[1][:6] == "struct"
        ]

        columns.extend(flat_cols)

        for nested_col in nested_cols:
            projected_df = df.select(nested_col + ".*")
            stack.append((parents + (nested_col,), projected_df))

    return nested_df.select(columns)

def flatten_array_struct_df(df):
    # Alternate between exploding top-level array columns and flattening
    # struct columns until no array columns remain.
    array_cols = [
        c[0]
        for c in df.dtypes
        if c[1][:5] == "array"
    ]

    while len(array_cols) > 0:

        for array_col in array_cols:
            df = df.withColumn(array_col, f.explode(f.col(array_col)))

        df = flatten_structs(df)

        array_cols = [
            c[0]
            for c in df.dtypes
            if c[1][:5] == "array"
        ]
    return df

flat_df = flatten_array_struct_df(df)

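A quick usage sketch (assuming an active SparkSession named spark and input shaped like the question's, with nested_array as an array of structs):

sample_df = spark.read.json(spark.sparkContext.parallelize(["""
{"field1": 1, "field2": 2,
 "nested_array": [{"nested_field1": 3, "nested_field2": 4},
                  {"nested_field1": 5, "nested_field2": 6}]}
"""]))

flatten_array_struct_df(sample_df).show()
# Expected columns: field1, field2,
# nested_array_nested_field1, nested_array_nested_field2
# (one output row per exploded array element)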

Upvotes: 6

Igor Tavares

Reputation: 969

I've developed a recursive approach to flatten any nested DataFrame.

The implementation is on the AWS Data Wrangler code base on GitHub.

P.S. Spark support has since been deprecated in the package, but the code base is still useful.
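The package isn't required to follow the idea; below is a minimal sketch of one possible recursive flattener in the same spirit (my own illustration, not the Data Wrangler implementation):

from pyspark.sql import DataFrame, functions as F
from pyspark.sql.types import ArrayType, StructType

def flatten_recursive(df: DataFrame, sep: str = "_") -> DataFrame:
    # Expand the first struct column found, or explode the first array
    # column, then recurse until neither type remains in the schema.
    for field in df.schema.fields:
        if isinstance(field.dataType, StructType):
            expanded = [
                F.col(f"{field.name}.{sub.name}").alias(f"{field.name}{sep}{sub.name}")
                for sub in field.dataType.fields
            ]
            others = [F.col(c) for c in df.columns if c != field.name]
            return flatten_recursive(df.select(*others, *expanded), sep)
        if isinstance(field.dataType, ArrayType):
            return flatten_recursive(
                df.withColumn(field.name, F.explode_outer(field.name)), sep)
    return df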

Upvotes: 3

MaFF

Reputation: 10096

This issue might be a bit old, but for anyone out there still looking for a solution, you can flatten complex data types inline using select *:

First, let's create the nested dataframe:

from pyspark.sql import HiveContext
hc = HiveContext(sc)
nested_df = hc.read.json(sc.parallelize(["""
{
  "field1": 1, 
  "field2": 2, 
  "nested_array":{
     "nested_field1": 3,
     "nested_field2": 4
  }
}
"""]))

Now, to flatten it:

flat_df = nested_df.select("field1", "field2", "nested_array.*")

You'll find useful examples here: https://docs.databricks.com/delta/data-transformation/complex-types.html

If you have too many nested struct columns to list by hand, you can build the select dynamically:

flat_cols = [c[0] for c in nested_df.dtypes if c[1][:6] != 'struct']
nested_cols = [c[0] for c in nested_df.dtypes if c[1][:6] == 'struct']
flat_df = nested_df.select(*flat_cols, *[c + ".*" for c in nested_cols])
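Note that this expands only one level of struct nesting. A small sketch of repeating it until no struct columns remain (the helper name is my own, not part of the original answer):

def flatten_all_structs(df):
    # Repeatedly expand struct columns one level at a time. Inner fields
    # keep only their own names, so identically named fields from different
    # structs will collide.
    while True:
        struct_cols = [c for c, t in df.dtypes if t.startswith("struct")]
        if not struct_cols:
            return df
        flat_cols = [c for c, t in df.dtypes if not t.startswith("struct")]
        df = df.select(*flat_cols, *[c + ".*" for c in struct_cols])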

Upvotes: 24

bhavin tandel

Reputation: 75

The following gist will flatten the structure of the nested JSON:

import typing as T

import cytoolz.curried as tz
import pyspark


def schema_to_columns(schema: pyspark.sql.types.StructType) -> T.List[T.List[str]]:
    """
    Produce a flat list of column specs from a possibly nested DataFrame schema
    """

    columns = list()

    def helper(schm: pyspark.sql.types.StructType, prefix: list = None):

        if prefix is None:
            prefix = list()

        for item in schm.fields:
            if isinstance(item.dataType, pyspark.sql.types.StructType):
                helper(item.dataType, prefix + [item.name])
            else:
                columns.append(prefix + [item.name])

    helper(schema)

    return columns

def flatten_frame(frame: pyspark.sql.DataFrame) -> pyspark.sql.DataFrame:
    """Select every leaf column, aliasing nested ones as parent:child."""

    aliased_columns = list()

    for col_spec in schema_to_columns(frame.schema):
        c = tz.get_in(col_spec, frame)
        if len(col_spec) == 1:
            aliased_columns.append(c)
        else:
            aliased_columns.append(c.alias(':'.join(col_spec)))

    return frame.select(aliased_columns)

You can then flatten the nested data as

flatten_data = flatten_frame(nested_df)

This will give you the flattened dataframe.
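For example, on the nested_df built in MaFF's answer above, my reading of the code is that the result should have these columns (nested fields joined with ':'):

flatten_frame(nested_df).columns
# expected: ['field1', 'field2',
#            'nested_array:nested_field1', 'nested_array:nested_field2']

Like the inline select approach, this handles nested StructType columns but not ArrayType columns, which would still need an explode first.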

The gist was taken from https://gist.github.com/DGrady/b7e7ff3a80d7ee16b168eb84603f5599

Upvotes: -2

John

Reputation: 1167

Here's my final approach:

1) Map the rows in the dataframe to an RDD of dicts. Find suitable Python code online for flattening a dict (one possible version is sketched after step 2).

flat_rdd = nested_df.rdd.map(lambda x: flatten(x))

where

def flatten(x):
  x_dict = x.asDict()
  ...some flattening code...
  return x_dict

2) Convert the RDD[dict] back to a dataframe

flat_df = sqlContext.createDataFrame(flat_rdd)
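For the elided "...some flattening code..." in step 1, one possible recursive dict flattener looks like this (an illustrative sketch, not the original code):

def flatten(x, prefix=''):
    # Recursively flatten a Row (or dict) into a single-level dict with
    # dot-separated keys. Illustrative helper, not the original code.
    out = {}
    items = x.asDict() if hasattr(x, 'asDict') else x
    for key, value in items.items():
        name = f"{prefix}.{key}" if prefix else key
        if hasattr(value, 'asDict') or isinstance(value, dict):
            out.update(flatten(value, name))
        else:
            out[name] = value
    return out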

Upvotes: 2
