Regressor

Reputation: 1973

Filtering on multiple columns in Spark dataframes

Suppose I have a dataframe in Spark as shown below:

val df = Seq(
(0,0,0,0.0),
(1,0,0,0.1),
(0,1,0,0.11),
(0,0,1,0.12),
(1,1,0,0.24),
(1,0,1,0.27),
(0,1,1,0.30),
(1,1,1,0.40)
).toDF("A","B","C","rate")

Here is how it looks:

scala> df.show()
+---+---+---+----+
|  A|  B|  C|rate|
+---+---+---+----+
|  0|  0|  0| 0.0|
|  1|  0|  0| 0.1|
|  0|  1|  0|0.11|
|  0|  0|  1|0.12|
|  1|  1|  0|0.24|
|  1|  0|  1|0.27|
|  0|  1|  1| 0.3|
|  1|  1|  1| 0.4|
+---+---+---+----+

A, B, and C are the advertising channels in this case. 0 and 1 represent the absence and presence of a channel, respectively, so the dataframe contains 2^3 = 8 combinations.

I want to filter the records in this dataframe where exactly 2 channels are present at a time (AB, AC, BC). Here is how I want my output to look:

+---+---+---+----+
|  A|  B|  C|rate|
+---+---+---+----+
|  1|  1|  0|0.24|
|  1|  0|  1|0.27|
|  0|  1|  1| 0.3|
+---+---+---+----+

I can get this output by writing 3 separate statements:

scala> df.filter($"A" === 1 && $"B" === 1 && $"C" === 0).show()
+---+---+---+----+
|  A|  B|  C|rate|
+---+---+---+----+
|  1|  1|  0|0.24|
+---+---+---+----+


scala> df.filter($"A" === 1 && $"B" === 0  && $"C" === 1).show()
+---+---+---+----+
|  A|  B|  C|rate|
+---+---+---+----+
|  1|  0|  1|0.27|
+---+---+---+----+


scala> df.filter($"A" === 0 && $"B" === 1 && $"C" === 1).show()
+---+---+---+----+
|  A|  B|  C|rate|
+---+---+---+----+
|  0|  1|  1| 0.3|
+---+---+---+----+

However, I want to achieve this using either a single statement or a function. I was thinking of using a case statement to match the values, but in general my dataframe might contain more than 3 channels:

scala> df.show()
+---+---+---+---+----+
|  A|  B|  C|  D|rate|
+---+---+---+---+----+
|  0|  0|  0|  0| 0.0|
|  0|  0|  0|  1| 0.1|
|  0|  0|  1|  0| 0.1|
|  0|  0|  1|  1|0.59|
|  0|  1|  0|  0| 0.1|
|  0|  1|  0|  1|0.89|
|  0|  1|  1|  0|0.39|
|  0|  1|  1|  1| 0.4|
|  1|  0|  0|  0| 0.0|
|  1|  0|  0|  1|0.99|
|  1|  0|  1|  0|0.49|
|  1|  0|  1|  1| 0.1|
|  1|  1|  0|  0|0.79|
|  1|  1|  0|  1| 0.1|
|  1|  1|  1|  0| 0.1|
|  1|  1|  1|  1| 0.1|
+---+---+---+---+----+

In this scenario I would want my output to be:
+---+---+---+---+----+
|  A|  B|  C|  D|rate|
+---+---+---+---+----+
|  0|  0|  1|  1|0.59|
|  0|  1|  0|  1|0.89|
|  0|  1|  1|  0|0.39|
|  1|  0|  0|  1|0.99|
|  1|  0|  1|  0|0.49|
|  1|  1|  0|  0|0.79|
+---+---+---+---+----+

which shows the rates for each pairwise presence of channels => (AB, AC, AD, BC, BD, CD).

Kindly help.

Upvotes: 3

Views: 6119

Answers (1)

philantrovert

Reputation: 10092

One way could be to sum the columns and then keep only the rows where the sum equals 2.

import org.apache.spark.sql.functions._

df.withColumn("res", $"A" + $"B" + $"C").filter($"res" === lit(2)).drop("res").show

The output is:

+---+---+---+----+
|  A|  B|  C|rate|
+---+---+---+----+
|  1|  1|  0|0.24|
|  1|  0|  1|0.27|
|  0|  1|  1| 0.3|
+---+---+---+----+
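
Since the question mentions that the number of channels can grow, the same idea generalizes by building the sum dynamically instead of naming each column. A sketch, assuming every column except `rate` is a 0/1 channel flag:

    import org.apache.spark.sql.functions.col

    // Assumption: every column other than "rate" is a 0/1 channel indicator
    val channelCols = df.columns.filter(_ != "rate").map(col)

    // Keep rows where exactly two channels are present
    df.filter(channelCols.reduce(_ + _) === 2).show()

On the 4-channel example this should yield the six AB, AC, AD, BC, BD, CD rows, without creating and dropping a helper column; changing `2` to another value selects a different number of simultaneously present channels.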

Upvotes: 2
