addmeaning

Reputation: 1398

How to get all columns in Spark DataFrame recursively

I want to get all columns of a DataFrame. If the DataFrame has a flat structure (no nested StructTypes), df.columns produces the correct result, but I also want it to return all nested column names, e.g.

Given

import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

val schema = StructType(
  StructField("name", StringType) ::
  StructField("nameSecond", StringType) ::
  StructField("nameDouble", StringType) ::
  StructField("someStruct", StructType(
    StructField("insideS", StringType)::
    StructField("insideD", DoubleType)::
    Nil
  )) ::
  Nil
)
val rdd = spark.sparkContext.emptyRDD[Row]
val df = spark.createDataFrame(rdd, schema)

I want to get

Seq("name", "nameSecond", "nameDouble", "someStruct", "insideS", "insideD")

Upvotes: 1

Views: 1920

Answers (1)

Tzach Zohar

Reputation: 37832

You can use this recursive function to traverse the schema:

def flattenSchema(schema: StructType): Seq[String] = {
  schema.fields.flatMap {
    // A struct field contributes its own name plus everything nested inside it
    case StructField(name, inner: StructType, _, _) => Seq(name) ++ flattenSchema(inner)
    // Any other field contributes just its own name
    case StructField(name, _, _, _) => Seq(name)
  }
}

println(flattenSchema(schema)) 
// prints: ArraySeq(name, nameSecond, nameDouble, someStruct, insideS, insideD)
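If you also want fully qualified paths for the nested fields (e.g. someStruct.insideS, which is the form select accepts), a small variation of the same recursion can carry the parent name down as a prefix. This is a sketch, not part of the answer above, and flattenSchemaQualified is a hypothetical name:

// Hypothetical variant of the recursion above that keeps the parent path,
// so nested names come back as "someStruct.insideS" etc.
def flattenSchemaQualified(schema: StructType, prefix: String = ""): Seq[String] = {
  schema.fields.flatMap {
    case StructField(name, inner: StructType, _, _) =>
      Seq(prefix + name) ++ flattenSchemaQualified(inner, prefix + name + ".")
    case StructField(name, _, _, _) => Seq(prefix + name)
  }
}

println(flattenSchemaQualified(schema))
// prints: ArraySeq(name, nameSecond, nameDouble, someStruct, someStruct.insideS, someStruct.insideD)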

Upvotes: 6
