Reputation: 1973
Suppose I have a DataFrame in Spark as shown below:
val df = Seq(
(0,0,0,0.0),
(1,0,0,0.1),
(0,1,0,0.11),
(0,0,1,0.12),
(1,1,0,0.24),
(1,0,1,0.27),
(0,1,1,0.30),
(1,1,1,0.40)
).toDF("A","B","C","rate")
Here is what it looks like:
scala> df.show()
+---+---+---+----+
| A| B| C|rate|
+---+---+---+----+
| 0| 0| 0| 0.0|
| 1| 0| 0| 0.1|
| 0| 1| 0|0.11|
| 0| 0| 1|0.12|
| 1| 1| 0|0.24|
| 1| 0| 1|0.27|
| 0| 1| 1| 0.3|
| 1| 1| 1| 0.4|
+---+---+---+----+
A, B and C are the advertising channels in this case; 0 and 1 represent the absence and presence of a channel, respectively. With 3 channels there are 2^3 = 8 combinations in the DataFrame.
I want to filter the records of this DataFrame that show exactly 2 channels present at a time (AB, AC, BC). Here is the output I want:
+---+---+---+----+
| A| B| C|rate|
+---+---+---+----+
| 1| 1| 0|0.24|
| 1| 0| 1|0.27|
| 0| 1| 1| 0.3|
+---+---+---+----+
I can get this output with 3 separate statements:
scala> df.filter($"A" === 1 && $"B" === 1 && $"C" === 0).show()
+---+---+---+----+
| A| B| C|rate|
+---+---+---+----+
| 1| 1| 0|0.24|
+---+---+---+----+
scala> df.filter($"A" === 1 && $"B" === 0 && $"C" === 1).show()
+---+---+---+----+
| A| B| C|rate|
+---+---+---+----+
| 1| 0| 1|0.27|
+---+---+---+----+
scala> df.filter($"A" === 0 && $"B" === 1 && $"C" === 1).show()
+---+---+---+----+
| A| B| C|rate|
+---+---+---+----+
| 0| 1| 1| 0.3|
+---+---+---+----+
However, I want to achieve this with either a single statement or a function that produces the output. I was thinking of using a case statement to match the values. In general, though, my DataFrame might contain more than 3 channels:
scala> df.show()
+---+---+---+---+----+
| A| B| C| D|rate|
+---+---+---+---+----+
| 0| 0| 0| 0| 0.0|
| 0| 0| 0| 1| 0.1|
| 0| 0| 1| 0| 0.1|
| 0| 0| 1| 1|0.59|
| 0| 1| 0| 0| 0.1|
| 0| 1| 0| 1|0.89|
| 0| 1| 1| 0|0.39|
| 0| 1| 1| 1| 0.4|
| 1| 0| 0| 0| 0.0|
| 1| 0| 0| 1|0.99|
| 1| 0| 1| 0|0.49|
| 1| 0| 1| 1| 0.1|
| 1| 1| 0| 0|0.79|
| 1| 1| 0| 1| 0.1|
| 1| 1| 1| 0| 0.1|
| 1| 1| 1| 1| 0.1|
+---+---+---+---+----+
In this scenario I would want my output to be:
scala> df.show()
+---+---+---+---+----+
| A| B| C| D|rate|
+---+---+---+---+----+
| 0| 0| 1| 1|0.59|
| 0| 1| 0| 1|0.89|
| 0| 1| 1| 0|0.39|
| 1| 0| 0| 1|0.99|
| 1| 0| 1| 0|0.49|
| 1| 1| 0| 0|0.79|
+---+---+---+---+----+
which shows the rates for every pair of channels present together: (AB, AC, AD, BC, BD, CD).
Kindly help.
Upvotes: 3
Views: 6119
Reputation: 10092
One way is to sum the channel columns and keep only the rows where the sum equals 2. Since each column is either 0 or 1, the sum is the number of channels present in that row.
import org.apache.spark.sql.functions._
df.withColumn("res", $"A" + $"B" + $"C").filter($"res" === lit(2)).drop("res").show
The output is:
+---+---+---+----+
| A| B| C|rate|
+---+---+---+----+
| 1| 1| 0|0.24|
| 1| 0| 1|0.27|
| 0| 1| 1| 0.3|
+---+---+---+----+
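The question also asks about more than three channels. Here is a minimal sketch of how the same sum-and-filter idea might be generalized, assuming every channel column holds only 0 or 1 and that `rate` is the only non-channel column. The names `withExactlyK` and `df4`, and the local `SparkSession` setup, are illustrative and not part of the original answer:

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.col

// Hypothetical helper: keep the rows in which exactly `k` of the given
// channel columns are set. Assumes each channel column contains only 0 or 1.
def withExactlyK(df: DataFrame, channels: Seq[String], k: Int): DataFrame = {
  require(channels.nonEmpty, "need at least one channel column")
  // Build a single Column expression: A + B + C + ...
  val active = channels.map(col).reduce(_ + _)
  df.filter(active === k)
}

// Local session for a standalone run; spark-shell already provides `spark`.
val spark = SparkSession.builder().master("local[*]").appName("channel-pairs").getOrCreate()
import spark.implicits._

// The four-channel table from the question.
val df4 = Seq(
  (0, 0, 0, 0, 0.0),
  (0, 0, 0, 1, 0.1),
  (0, 0, 1, 0, 0.1),
  (0, 0, 1, 1, 0.59),
  (0, 1, 0, 0, 0.1),
  (0, 1, 0, 1, 0.89),
  (0, 1, 1, 0, 0.39),
  (0, 1, 1, 1, 0.4),
  (1, 0, 0, 0, 0.0),
  (1, 0, 0, 1, 0.99),
  (1, 0, 1, 0, 0.49),
  (1, 0, 1, 1, 0.1),
  (1, 1, 0, 0, 0.79),
  (1, 1, 0, 1, 0.1),
  (1, 1, 1, 0, 0.1),
  (1, 1, 1, 1, 0.1)
).toDF("A", "B", "C", "D", "rate")

// Assumption: every column except "rate" is a channel flag.
val channels = df4.columns.filter(_ != "rate").toSeq
val pairs = withExactlyK(df4, channels, 2)
pairs.show()
```

In spark-shell the session and implicits already exist, so the two setup lines can be dropped. With the 16-row table from the question, this keeps the six rows for AB, AC, AD, BC, BD and CD. Building the sum as one Column expression also avoids adding and then dropping a temporary `res` column.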
Upvotes: 2