Georg Heiler

Reputation: 17676

Spark SQL refer to columns programmatically

I am about to develop a function which uses Spark SQL to perform an operation per column. In this function I need to refer to the column's name:

val input = Seq(
    (0, "A", "B", "C", "D"),
    (1, "A", "B", "C", "D"),
    (0, "d", "a", "jkl", "d"),
    (0, "d", "g", "C", "D"),
    (1, "A", "d", "t", "k"),
    (1, "d", "c", "C", "D"),
    (1, "c", "B", "C", "D")
  ).toDF("TARGET", "col1", "col2", "col3TooMany", "col4")

The following example, which refers to columns explicitly via the 'column symbol syntax, works fine.

val pre1_1 = input.groupBy('col1).agg(mean($"TARGET").alias("pre_col1"))
val pre2_1 = input.groupBy('col1, 'TARGET).agg((count("*") / input.filter('TARGET === 1).count).alias("pre2_col1"))

input.as('a)
    .join(pre1_1.as('b), $"a.col1" === $"b.col1").drop($"b.col1")
    .join(pre2_1.as('b), ($"a.col1" === $"b.col1") and ($"a.TARGET" === $"b.TARGET")).drop($"b.col1").drop($"b.TARGET").show

When the columns are referred to programmatically, they can no longer be resolved once two joins are performed one after the other, even though the same chain of joins worked fine in the snippet above.

I observed that for this code snippet the initial col1 of the DataFrame was moved from the beginning to the end of the schema; probably this is why it can no longer be resolved. But so far I could not figure out how to access a column when only passing a string, i.e. how to properly reference the column names inside a function.

val pre1_1 = input.groupBy("col1").agg(mean('TARGET).alias("pre_" + "col1"))
val pre2_1 = input.groupBy("col1", "TARGET").agg(count("*") / input.filter('TARGET === 1).count alias ("pre2_" + "col1"))
  input.join(pre1_1, input("col1") === pre1_1("col1")).drop(pre1_1("col1"))
    .join(pre2_1, (input("col1") === pre2_1("col1")) and (input("TARGET") === pre2_1("TARGET"))).drop(pre2_1("col1")).drop(pre2_1("TARGET"))

as well as an alternative approach like:

df.as('a)
  .join(pre1_1.as('b), $"a.${col}" === $"b.${col}").drop($"b.${col}")

did not succeed, as $"a.${col}" was no longer resolved to a column of a but rather to df("a.col1"), which does not exist.
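As far as I can tell, the $ interpolator and col build the same unresolved column, so the string interpolation itself should not be at fault; a quick sketch, with c standing in for the column-name parameter:

import org.apache.spark.sql.functions.col

val c = "col1"
// both construct the same unresolved column "col1" qualified by the alias "a"
val viaInterpolator = $"a.${c}"   // requires import spark.implicits._
val viaCol          = col(s"a.$c")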

Upvotes: 2

Views: 710

Answers (1)

user7248695

Reputation: 46

In complex cases, always use unique aliases to reference columns with shared lineage. This is the only way to ensure correct and stable behavior.

import org.apache.spark.sql.functions.col

val pre1_1 = input.groupBy("col1")
  .agg(mean('TARGET).alias("pre_" + "col1"))
  .alias("pre1_1")
val pre2_1 = input.groupBy("col1", "TARGET")
  .agg((count("*") / input.filter('TARGET === 1).count).alias("pre2_" + "col1"))
  .alias("pre2_1")

input.alias("input")
  .join(pre1_1, col("input.col1") === col("pre1_1.col1"))
  .join(pre2_1, (col("input.col1") === col("pre2_1.col1")) and (col("input.TARGET") === col("pre2_1.TARGET")))
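The same technique generalizes to the per-column function the question asks for. A minimal sketch, assuming the input DataFrame above; the helper name addPreColumn is made up:

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, count, mean}

// Hypothetical per-column helper: unique aliases keep the shared lineage unambiguous
def addPreColumn(df: DataFrame, c: String): DataFrame = {
  val positives = df.filter(col("TARGET") === 1).count
  val pre1 = df.groupBy(c).agg(mean("TARGET").alias(s"pre_$c")).alias("pre1")
  val pre2 = df.groupBy(c, "TARGET")
    .agg((count("*") / positives).alias(s"pre2_$c")).alias("pre2")

  df.alias("df")
    .join(pre1, col(s"df.$c") === col(s"pre1.$c"))
    .drop(col(s"pre1.$c"))
    .join(pre2, (col(s"df.$c") === col(s"pre2.$c")) and (col("df.TARGET") === col("pre2.TARGET")))
    .drop(col(s"pre2.$c"))
    .drop(col("pre2.TARGET"))
}

Folding this over the feature columns, e.g. Seq("col1", "col2").foldLeft(input)(addPreColumn), then applies it column by column.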

If you check the logs, you will actually see warnings like:

WARN Column: Constructing trivially true equals predicate, 'col1#12 = col1#12'. Perhaps you need to use aliases

and the code you use works only because there are "special cases" in the Spark source.

In a simple case like this, just use the equi-join syntax:

input.join(pre1_1, Seq("col1"))
  .join(pre2_1, Seq("col1", "TARGET"))
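With the equi-join form the join columns appear only once in the result, so the drop calls disappear entirely. Applied per column, the whole thing reduces to a fold; a minimal sketch, assuming just the pre_ aggregation and a made-up feature list:

import org.apache.spark.sql.functions.mean

val features = Seq("col1", "col2", "col3TooMany", "col4")  // hypothetical

val result = features.foldLeft(input) { (df, c) =>
  val pre = input.groupBy(c).agg(mean("TARGET").alias(s"pre_$c"))
  df.join(pre, Seq(c))  // equi-join: c appears only once in the output
}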

Upvotes: 3
