Reputation: 1899
How can I add or replace fields in a struct at any nesting level?
This input:
// Four JSON records whose "a" struct carries differently-cased spellings of one key (xX, XX, xx);
// every record also has b.z.
val rdd = sc.parallelize(Seq(
"""{"a": {"xX": 1,"XX": 2},"b": {"z": 0}}""",
"""{"a": {"xX": 3},"b": {"z": 0}}""",
"""{"a": {"XX": 3},"b": {"z": 0}}""",
"""{"a": {"xx": 4},"b": {"z": 0}}"""))
// Declared as `var` because the snippets further down reassign df.
var df = sqlContext.read.json(rdd)
Yields the following schema:
root
|-- a: struct (nullable = true)
| |-- XX: long (nullable = true)
| |-- xX: long (nullable = true)
| |-- xx: long (nullable = true)
|-- b: struct (nullable = true)
| |-- z: long (nullable = true)
Then I can do this:
import org.apache.spark.sql.functions._
// The three case variants of the same logical field inside struct "a".
val overlappingNames = Seq(col("a.xx"), col("a.xX"), col("a.XX"))
// Coalesce the variants into a single column named "a_xx", then drop each original
// nested field. dropNestedColumn is not a built-in Spark method; it must be supplied separately.
df = df
.withColumn("a_xx",
coalesce(overlappingNames:_*))
.dropNestedColumn("a.xX")
.dropNestedColumn("a.XX")
.dropNestedColumn("a.xx")
(`dropNestedColumn` is borrowed from this answer: https://stackoverflow.com/a/39943812/1068385. I'm basically looking for the inverse operation of that.)
And the schema becomes:
root
|-- a: struct (nullable = false)
|-- b: struct (nullable = true)
| |-- z: long (nullable = true)
|-- a_xx: long (nullable = true)
Obviously this doesn't replace (or add) `a.xx`; instead it adds the new field `a_xx` at the root level.
I'd like to be able to do this instead:
// Desired usage: withNestedColumn is the wished-for API that would write the coalesced
// value to a.xx inside the struct, after which only the other two variants need dropping.
val overlappingNames = Seq(col("a.xx"), col("a.xX"), col("a.XX"))
df = df
.withNestedColumn("a.xx",
coalesce(overlappingNames:_*))
.dropNestedColumn("a.xX")
.dropNestedColumn("a.XX")
So that it would result in this schema:
root
|-- a: struct (nullable = false)
| |-- xx: long (nullable = true)
|-- b: struct (nullable = true)
| |-- z: long (nullable = true)
How can I achieve that?
The practical goal here is to be case-insensitive with column names in the input JSON. The final step would be simple: collect all overlapping column names and apply the coalesce on each.
Upvotes: 12
Views: 9274
Reputation: 2094
It might not be as elegant or as efficient as it could be but here is what I came up with:
/**
 * Helpers for adding or replacing a field at any nesting depth inside the struct
 * columns of a DataFrame.
 *
 * Import `DataFrameUtils._` to make `withNestedColumn` available on a DataFrame.
 */
object DataFrameUtils {

  /** Evaluates to `c` where `parentCol` is not null; the `when` has no `otherwise` branch. */
  private def nullableCol(parentCol: Column, c: Column): Column =
    when(parentCol.isNotNull, c)

  /** Single-column form: guards `c` by its own null check. */
  private def nullableCol(c: Column): Column =
    nullableCol(c, c)

  /**
   * Wraps `newCol` in one single-field struct per name in `splitted`, innermost name last.
   * With an empty `splitted` this is just `newCol`.
   */
  private def createNestedStructs(splitted: Seq[String], newCol: Column): Column =
    splitted.foldRight(newCol) { (colName, nestedStruct) =>
      nullableCol(struct(nestedStruct as colName))
    }

  /**
   * Rebuilds the struct column `parentCol` so that the field reached by `splitted` holds `newCol`.
   *
   * @param splitted  remaining path segments below `parentCol`
   * @param parentCol column expression for the value currently being rebuilt
   * @param colType   schema type of `parentCol`
   * @param nullable  whether the schema marks `parentCol` as nullable
   * @param newCol    value to place at the end of the path
   */
  private def recursiveAddNestedColumn(
      splitted: Seq[String],
      parentCol: Column,
      colType: DataType,
      nullable: Boolean,
      newCol: Column): Column =
    colType match {
      case structType: StructType if splitted.nonEmpty =>
        val target = splitted.head
        val remaining = splitted.tail

        // Re-select every existing field, descending into the one that matches the path.
        val existingFields: Seq[(String, Column)] = structType.fields.toSeq.map { f =>
          val current = parentCol.getField(f.name)
          val updated =
            if (f.name == target) {
              recursiveAddNestedColumn(remaining, current, f.dataType, f.nullable, newCol)
            } else {
              current
            }
          (f.name, updated as f.name)
        }

        // No field with that name yet: append it, null-guarded by the parent struct.
        val allFields: Seq[(String, Column)] =
          if (existingFields.exists(_._1 == target)) {
            existingFields
          } else {
            val added = nullableCol(parentCol, createNestedStructs(remaining, newCol)) as target
            existingFields :+ ((target, added))
          }

        val rebuilt = struct(allFields.map(_._2): _*)
        // Keep a null parent null instead of turning it into a struct of nulls.
        if (nullable) nullableCol(parentCol, rebuilt) else rebuilt

      // Path exhausted, or the existing value is not a struct: build the value from scratch.
      case _ => createNestedStructs(splitted, newCol)
    }

  /**
   * Adds or replaces the column named by the dot-separated `newColName`.
   * Names without a dot go straight to `DataFrame.withColumn`.
   */
  private def addNestedColumn(df: DataFrame, newColName: String, newCol: Column): DataFrame =
    if (newColName.contains('.')) {
      val splitted = newColName.split('.').toSeq
      // A name consisting only of dots splits into zero segments; fail with a clear message
      // rather than an opaque error from taking the head of an empty sequence.
      require(splitted.nonEmpty, s"Nested column name '$newColName' contains no field names")
      val topLevel = splitted.head
      val nestedPath = splitted.tail

      val replacement: Column = df.schema.fields
        .find(_.name == topLevel)
        .map(f => recursiveAddNestedColumn(nestedPath, col(f.name), f.dataType, f.nullable, newCol))
        .getOrElse(createNestedStructs(nestedPath, newCol) as topLevel)

      df.withColumn(topLevel, replacement)
    } else {
      // Top level addition, use spark method as-is
      df.withColumn(newColName, newCol)
    }

  implicit class ExtendedDataFrame(df: DataFrame) extends Serializable {
    /**
     * Add nested field to DataFrame
     *
     * @param newColName Dot-separated nested field name
     * @param newCol     New column value
     */
    def withNestedColumn(newColName: String, newCol: Column): DataFrame =
      DataFrameUtils.addNestedColumn(df, newColName, newCol)
  }
}
Feel free to improve on it.
// One JSON record with a two-level nested struct: a3 -> b2 -> {c1, c2}.
val data = spark.sparkContext.parallelize(List("""{ "a1": 1, "a3": { "b1": 3, "b2": { "c1": 5, "c2": 6 } } }"""))
val df: DataFrame = spark.read.json(data)
// Neither c3 nor d1 exists yet, so both levels have to be created under a3.b2; the value placed
// at d1 is the existing a3.b2 struct. NOTE(review): the $"..." syntax presumably needs
// spark.implicits._ in scope — confirm.
val df2 = df.withNestedColumn("a3.b2.c3.d1", $"a3.b2")
should produce:
// Compare the compact string form of the resulting schema against the expected nesting.
assertResult("struct<a1:bigint,a3:struct<b1:bigint,b2:struct<c1:bigint,c2:bigint,c3:struct<d1:struct<c1:bigint,c2:bigint>>>>>")(df2.schema.simpleString)
Upvotes: 10