Ignacio Alorre

Reputation: 7605

Spark - Change value of records which belong to the long tail in a Dataset

I am working on a data-cleaning step in a machine learning problem where I need to group all the elements in the long tail into a common category named "Others". For example, I have a dataframe like this:

val df = sc.parallelize(Seq(
(1, "ABC"),
(2, "ABC"),
(3, "123"),
(4, "FPK"),
(5, "FPK"),
(6, "ABC"),
(7, "ABC"),
(8, "980"),
(9, "abc"),
(10, "FPK")
)).toDF("n", "s")

I want to keep the categories "ABC" and "FPK" since they appear several times, but I don't want separate categories for 123, 980 and abc, since each of them appears just once. What I would like to have instead is:

+---+------+
|  n|     s|
+---+------+
|  1|   ABC|
|  2|   ABC|
|  3|Others|
|  4|   FPK|
|  5|   FPK|
|  6|   ABC|
|  7|   ABC|
|  8|Others|
|  9|Others|
| 10|   FPK|
+---+------+

To achieve this, I tried the following:

val newDF = df.withColumn("s", when($"s".isin("123","980","abc"), "Others").otherwise('s))

This works fine.

But I would like to programmatically decide which categories belong to the long tail, which in my case means they appear just once in the original dataframe. So I wrote this to create a dataframe with those categories that only appear once:

val longTail = df.groupBy("s").agg(count("*").alias("cnt")).orderBy($"cnt".desc).filter($"cnt"<2)

+---+---+
|  s|cnt|
+---+---+
|980|  1|
|abc|  1|
|123|  1|
+---+---+

Now I was trying to convert the values of the column "s" in this longTail dataset into a List, to replace the hardcoded list above. So I tried:

val ar = longTail.select("s").collect().map(_(0)).toList

ar: List[Any] = List(123, 980, abc)

But when I try to use ar in the condition,

val newDF = df.withColumn("s",when($"s".isin(ar),"Others").otherwise('s))

I get the following error:

java.lang.RuntimeException: Unsupported literal type class scala.collection.immutable.$colon$colon List(123, 980, abc)

What am I missing?

Upvotes: 1

Views: 250

Answers (2)

eliasah

Reputation: 40370

This is the correct syntax:

scala> df.withColumn("s", when($"s".isin(ar : _*), "Others").otherwise('s)).show
+---+------+
|  n|     s|
+---+------+
|  1|   ABC|
|  2|   ABC|
|  3|Others|
|  4|   FPK|
|  5|   FPK|
|  6|   ABC|
|  7|   ABC|
|  8|Others|
|  9|Others|
| 10|   FPK|
+---+------+

This is called a repeated parameter (a vararg): isin is declared as isin(list: Any*), so the list has to be expanded with ar: _* rather than passed as a single value.
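As a quick illustration of why the expansion is needed (a minimal sketch in plain Scala, not from Spark itself; the sum function here is hypothetical):

def sum(xs: Int*): Int = xs.sum

sum(1, 2, 3)            // fine: three individual arguments
val nums = List(1, 2, 3)
// sum(nums)            // does not compile: a List[Int] is not an Int
sum(nums: _*)           // fine: `: _*` expands the list into varargs

In the Spark case the parameter type is Any*, so isin(ar) does compile (a List is an Any), but the whole List arrives as one argument that Spark then tries, and fails, to convert into a single literal, hence the RuntimeException. As a side note, collecting through a typed Dataset, e.g. longTail.select("s").as[String].collect().toList, gives you a List[String] instead of a List[Any] (assuming spark.implicits._ is in scope).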

Upvotes: 3

Ramesh Maharjan

Reputation: 41957

You don't have to go through all that hassle. You can use a window function to get the count of each group of "s", and then a when/otherwise to decide whether to replace the value with "Others", as below:

val df = sc.parallelize(Seq(
  (1, "ABC"),
  (2, "ABC"),
  (3, "123"),
  (4, "FPK"),
  (5, "FPK"),
  (6, "ABC"),
  (7, "ABC"),
  (8, "980"),
  (9, "abc"),
  (10, "FPK")
)).toDF("n", "s")

import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions._

// count the occurrences of each value of s across its whole group
// (the frame spans the entire partition), and keep the value only
// when it appears more than once
df.withColumn("s",
    when(count("s").over(Window.partitionBy("s").orderBy("n").rowsBetween(Long.MinValue, Long.MaxValue)) > 1, col("s"))
      .otherwise("Others"))
  .show(false)

which should give you

+---+------+
|n  |s     |
+---+------+
|4  |FPK   |
|5  |FPK   |
|10 |FPK   |
|8  |Others|
|9  |Others|
|1  |ABC   |
|2  |ABC   |
|6  |ABC   |
|7  |ABC   |
|3  |Others|
+---+------+
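Note that the window shuffles rows by the partitioning column, so the result comes back grouped by s rather than in the original order, as you can see above. If you want the input order back, you can sort by n at the end (a small addition on top of the answer, not part of the original):

df.withColumn("s",
    when(count("s").over(Window.partitionBy("s").orderBy("n").rowsBetween(Long.MinValue, Long.MaxValue)) > 1, col("s"))
      .otherwise("Others"))
  .orderBy("n") // restore the insertion order
  .show(false)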

I hope the answer is helpful

Upvotes: 3
