Reputation: 7605
I am trying to solve a data cleaning step in a Machine Learning problem, where I should group all the elements in the long tail into a common category named "Others". For example, I have a dataframe like this:
val df = sc.parallelize(Seq(
(1, "ABC"),
(2, "ABC"),
(3, "123"),
(4, "FPK"),
(5, "FPK"),
(6, "ABC"),
(7, "ABC"),
(8, "980"),
(9, "abc"),
(10, "FPK")
)).toDF("n", "s")
I want to keep the categories "ABC" and "FPK", since they appear several times, but I don't want a separate category for each of 123, 980, and abc, since they appear just once. So what I would like to have instead is:
+---+------+
| n| s|
+---+------+
| 1| ABC|
| 2| ABC|
| 3|Others|
| 4| FPK|
| 5| FPK|
| 6| ABC|
| 7| ABC|
| 8|Others|
| 9|Others|
| 10| FPK|
+---+------+
To achieve this, I tried the following:
val newDF = df.withColumn("s", when($"s".isin("123", "980", "abc"), "Others").otherwise('s))
This works fine.
But I would like to decide programmatically which categories belong to the long tail, in my case those that appear just once in the original dataframe. So I wrote this to create a dataframe with the categories that appear only once:
val longTail = df.groupBy("s").agg(count("*").alias("cnt")).orderBy($"cnt".desc).filter($"cnt"<2)
+---+---+
| s|cnt|
+---+---+
|980| 1|
|abc| 1|
|123| 1|
+---+---+
Now I wanted to convert the values of column "s" in this longTail dataframe into a List, to replace the hardcoded one above. So I tried:
val ar = longTail.select("s").collect().map(_(0)).toList
ar: List[Any] = List(123, 980, abc)
But when I try to use ar in the same expression:
val newDF = df.withColumn("s",when($"s".isin(ar),"Others").otherwise('s))
I get the following error:
java.lang.RuntimeException: Unsupported literal type class scala.collection.immutable.$colon$colon List(123, 980, abc)
What am I missing?
Upvotes: 1
Views: 250
Reputation: 40370
This is the correct syntax:
scala> df.withColumn("s", when($"s".isin(ar : _*), "Others").otherwise('s)).show
+---+------+
| n| s|
+---+------+
| 1| ABC|
| 2| ABC|
| 3|Others|
| 4| FPK|
| 5| FPK|
| 6| ABC|
| 7| ABC|
| 8|Others|
| 9|Others|
| 10| FPK|
+---+------+
This is called a repeated parameter: isin is declared as isin(list: Any*), so a List must be expanded with : _* rather than passed as a single argument. Passing the List itself makes Spark try to turn the whole List into one literal, which is exactly what the error above complains about.
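For reference, here is a minimal sketch of the whole programmatic pipeline (assuming a SparkSession is in scope so that spark.implicits._ can be imported), which also collects the long-tail values as a typed List[String] instead of List[Any]:

import org.apache.spark.sql.functions._
import spark.implicits._

// Collect the rare categories as a typed List[String] (avoids List[Any])
val rare: List[String] = df.groupBy("s").count()
  .filter($"count" < 2)
  .select("s").as[String]
  .collect().toList

// ": _*" expands the List into isin's repeated (Any*) parameter
val newDF = df.withColumn("s", when($"s".isin(rare: _*), "Others").otherwise($"s"))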
Upvotes: 3
Reputation: 41957
You don't have to go through all that hassle: you can use a window function to get the count of each group, and a when/otherwise expression to decide whether to replace the value with Others, as below.
val df = sc.parallelize(Seq(
(1, "ABC"),
(2, "ABC"),
(3, "123"),
(4, "FPK"),
(5, "FPK"),
(6, "ABC"),
(7, "ABC"),
(8, "980"),
(9, "abc"),
(10, "FPK")
)).toDF("n", "s")
import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions._
// count over a frame spanning the whole partition gives each group's size;
// rows whose group has more than one member keep their value, the rest become "Others"
df.withColumn("s", when(
    count("s").over(Window.partitionBy("s").orderBy("n").rowsBetween(Long.MinValue, Long.MaxValue)) > 1,
    col("s")
  ).otherwise("Others"))
  .show(false)
which should give you
+---+------+
|n |s |
+---+------+
|4 |FPK |
|5 |FPK |
|10 |FPK |
|8 |Others|
|9 |Others|
|1 |ABC |
|2 |ABC |
|6 |ABC |
|7 |ABC |
|3 |Others|
+---+------+
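As a side note, the orderBy and explicit frame are not strictly required here: for an aggregate such as count, a window defined only by partitionBy already covers the whole partition by default, so the following sketch should behave identically:

// Equivalent: partitionBy alone makes count("s") the size of each "s" group
val bySize = Window.partitionBy("s")
df.withColumn("s", when(count("s").over(bySize) > 1, col("s")).otherwise("Others")).show(false)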
I hope the answer is helpful.
Upvotes: 3