Reputation: 116
I wish to collect the names of all the fields in a nested schema. The data were imported from a json file.
The schema looks like:
root
|-- column_a: string (nullable = true)
|-- column_b: string (nullable = true)
|-- column_c: struct (nullable = true)
| |-- nested_a: struct (nullable = true)
| | |-- double_nested_a: string (nullable = true)
| | |-- double_nested_b: string (nullable = true)
| | |-- double_nested_c: string (nullable = true)
| |-- nested_b: string (nullable = true)
|-- column_d: string (nullable = true)
If I use df.schema.fields or df.schema.names, it only returns the names of the top-level columns - none of the nested fields.
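For example, with the DataFrame above:
print(df.schema.names)
# ['column_a', 'column_b', 'column_c', 'column_d']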
The desired output is a Python list containing all the column names, such as:
['column_a', 'column_b', 'column_c.nested_a.double_nested_a', 'column_c.nested_a.double_nested_b', etc...]
The information is there if I want to write a custom function - but am I missing something? Does a method exist that achieves what I need?
Upvotes: 2
Views: 3842
Reputation: 31540
By default, Spark doesn't have a method that returns the flattened schema names. Use the code from this post:
from pyspark.sql.types import ArrayType, StructType

def flatten(schema, prefix=None):
    # Recursively walk the schema, building the dotted path to every leaf field
    fields = []
    for field in schema.fields:
        name = prefix + '.' + field.name if prefix else field.name
        dtype = field.dataType
        # Unwrap arrays so that arrays of structs are flattened too
        if isinstance(dtype, ArrayType):
            dtype = dtype.elementType
        if isinstance(dtype, StructType):
            # Struct: recurse with the current path as the prefix
            fields += flatten(dtype, prefix=name)
        else:
            # Leaf field: record its full dotted name
            fields.append(name)
    return fields
df.printSchema()
#root
# |-- column_a: string (nullable = true)
# |-- column_c: struct (nullable = true)
# | |-- nested_a: struct (nullable = true)
# | | |-- double_nested_a: string (nullable = true)
# | |-- nested_b: string (nullable = true)
# |-- column_d: string (nullable = true)
sch = df.schema
print(flatten(sch))
#['column_a', 'column_c.nested_a.double_nested_a', 'column_c.nested_b', 'column_d']
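The dotted names that flatten returns are valid column references, so as a usage sketch you can select every leaf field directly (leaf fields that share a name would need aliases to avoid ambiguity):
df.select(*flatten(df.schema))
# resulting columns: column_a, double_nested_a, nested_b, column_d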
Upvotes: 5