Ébe Isaac

Reputation: 12331

How to traverse a schema in Spark?

I would like to iterate over a schema in Spark. Using df.schema returns the schema as a StructType, a sequence of nested StructFields.

The root elements can be indexed like so:

IN: val temp = df.schema

IN: temp(0)
OUT: StructField(A,StringType,true)

IN: temp(3)
OUT: StructField(D,StructType(StructField(D1,StructType(StructField(D11,StringType,true), StructField(D12,StringType,true), StructField(D13,StringType,true)),true), StructField(D2,StringType,true), StructField(D3,StringType,true)),true)

When I try to index into the nested StructType, however, the following error occurs:

IN: val temp1 = temp(3).dataType

IN: temp1(0)
OUT:
Name: Unknown Error
Message: <console>:38: error: org.apache.spark.sql.types.DataType does not take parameters
       temp1(0)
            ^
StackTrace: 

What I don't understand is that both temp and temp1 are of the StructType class, yet temp can be indexed while temp1 can't.

IN: temp.getClass
OUT: class org.apache.spark.sql.types.StructType

IN: temp1.getClass
OUT: class org.apache.spark.sql.types.StructType
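
For what it's worth, explicitly casting temp1 back to StructType does make indexing work again (sketched here; the OUT line is inferred from the schema shown above):

IN: val temp2 = temp1.asInstanceOf[StructType]

IN: temp2(0)
OUT: StructField(D1,StructType(StructField(D11,StringType,true), StructField(D12,StringType,true), StructField(D13,StringType,true)),true)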

I also tried dtypes but ran into a similar problem when trying to access nested elements.

IN: df.dtypes(3)(0)
OUT:
Name: Unknown Error
Message: <console>:36: error: (String, String) does not take parameters
       df.dtypes(3)(0)
                   ^
StackTrace: 
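
Tuple accessors do work here, since dtypes returns the fields as (name, type) pairs of strings, but strings are not something I can traverse further (output inferred from the schema above):

IN: df.dtypes(3)._1
OUT: D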

So, how can you traverse a schema prior to knowing the sub-fields?

Upvotes: 8

Views: 7969

Answers (3)

semicolon

Reputation: 338

Adding to the answer given by @richardjmorton:

This version also works when the nested schema has a struct inside an array (or an array inside a struct):

    import org.apache.spark.sql.types.{ArrayType, StructField, StructType}

    // Recursively collect every field name, descending into nested
    // structs and into structs that sit inside arrays.
    def flattenSchema(schema: StructType): Seq[String] = {
      schema.fields.flatMap {
        case StructField(name, inner: StructType, _, _) => Seq(name) ++ flattenSchema(inner)
        case StructField(name, ArrayType(structType: StructType, _), _, _) => Seq(name) ++ flattenSchema(structType)
        case StructField(name, _, _, _) => Seq(name)
      }
    }

    flattenSchema(yourDf.schema)
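
For example, on a hypothetical schema with an array of structs (names invented for illustration), it returns both the array column and its inner field names:

    import org.apache.spark.sql.types.{ArrayType, LongType, StringType, StructField, StructType}

    // Hypothetical schema: "events" is an array whose elements are structs.
    val arraySchema = StructType(
      StructField("id", StringType) ::
      StructField("events", ArrayType(StructType(
        StructField("ts", LongType) ::
        StructField("kind", StringType) :: Nil
      ))) :: Nil
    )

    flattenSchema(arraySchema)  // Seq(id, events, ts, kind)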

Upvotes: 0

richardjmorton

Reputation: 321

In Spark SQL schemas there are a few complex datatypes to worry about when recursing through them, e.g. StructType, ArrayType and MapType. Writing a function that fully traverses a schema containing maps of structs and arrays of maps is quite involved (a sketch extending the function below to MapType is at the end of this answer).

To recurse down most schemas I have come across, I have only needed to account for StructType and ArrayType.

Given a schema like:

    root
     |-- name: string (nullable = true)
     |-- nameSecond: long (nullable = true)
     |-- acctRep: string (nullable = true)
     |-- nameDouble: array (nullable = true)
     |    |-- element: struct (containsNull = true)
     |    |    |-- insideK: string (nullable = true)
     |    |    |-- insideS: string (nullable = true)
     |    |    |-- insideD: long (nullable = true)
     |-- inside1: long (nullable = true)

I would use a recursive function like this:

    import org.apache.spark.sql.types.{ArrayType, StructField, StructType}

    // Recursively collect field names, descending into structs
    // and into arrays whose elements are structs.
    def collectAllFieldNames(schema: StructType): List[String] =
        schema.fields.flatMap {
            case StructField(name, structType: StructType, _, _) => name :: collectAllFieldNames(structType)
            case StructField(name, ArrayType(structType: StructType, _), _, _) => name :: collectAllFieldNames(structType)
            case StructField(name, _, _, _) => name :: Nil
        }.toList

Giving the result:

    List(name, nameSecond, acctRep, nameDouble, insideK, insideS, insideD, inside1)
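
For schemas that also contain maps, the same pattern extends with a MapType case. A sketch (not part of the original answer) that additionally descends into struct values stored inside maps; arrays of maps and deeper combinations would need further cases:

    import org.apache.spark.sql.types.{ArrayType, MapType, StructField, StructType}

    // Sketch: also recurse into map values that are structs.
    def collectAllFieldNamesWithMaps(schema: StructType): List[String] =
        schema.fields.flatMap {
            case StructField(name, structType: StructType, _, _) => name :: collectAllFieldNamesWithMaps(structType)
            case StructField(name, ArrayType(structType: StructType, _), _, _) => name :: collectAllFieldNamesWithMaps(structType)
            case StructField(name, MapType(_, structType: StructType, _), _, _) => name :: collectAllFieldNamesWithMaps(structType)
            case StructField(name, _, _, _) => name :: Nil
        }.toList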

Upvotes: 3

addmeaning

Reputation: 1398

Well, if you want the list of all nested columns, you can write a recursive function like this.

Given:

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.types.{StringType, StructField, StructType}

    val schema = StructType(
      StructField("name", StringType) ::
        StructField("nameSecond", StringType) ::
        StructField("nameDouble", StringType) ::
        StructField("someStruct", StructType(
          StructField("insideS", StringType) ::
            StructField("insideD", StructType(
              StructField("inside1", StringType) :: Nil
            )) ::
            Nil
        )) ::
        Nil
    )
    val rdd = session.sparkContext.emptyRDD[Row]
    val df = session.createDataFrame(rdd, schema)

    df.printSchema()

Which will produce:

    root
     |-- name: string (nullable = true)
     |-- nameSecond: string (nullable = true)
     |-- nameDouble: string (nullable = true)
     |-- someStruct: struct (nullable = true)
     |    |-- insideS: string (nullable = true)
     |    |-- insideD: struct (nullable = true)
     |    |    |-- inside1: string (nullable = true)

If you want the list of full names of the columns you can write something like this:

    def fullFlattenSchema(schema: StructType): Seq[String] = {
      // Build the full dotted name of each field, e.g. "someStruct.insideD".
      def helper(schema: StructType, prefix: String): Seq[String] = {
        val fullName: String => String = name => if (prefix.isEmpty) name else s"$prefix.$name"
        schema.fields.flatMap {
          case StructField(name, inner: StructType, _, _) =>
            fullName(name) +: helper(inner, fullName(name))
          case StructField(name, _, _, _) => Seq(fullName(name))
        }
      }

      helper(schema, "")
    }

Which will return:

    ArraySeq(name, nameSecond, nameDouble, someStruct, someStruct.insideS, someStruct.insideD, someStruct.insideD.inside1)
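
The dotted names this returns can be fed straight to selection; a small usage sketch with the df from above:

    import org.apache.spark.sql.functions.col

    // Select a nested leaf field by its full dotted path.
    df.select(col("someStruct.insideD.inside1")).printSchema()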

Upvotes: 5
