Reputation: 329
I'm trying to do a very specific operation on my dataframe, but I can't find a clean way to do it.
I have a dataframe that looks like this:
+------------------+----------------+--------+
|CIVILITY_PREDICTED|COUNTRY_CODE_PRE| name|
+------------------+----------------+--------+
| M| CA|A HANNAN|
| M| CA| A JAY|
| M| GB| A JAY|
| M| CA| A K I L|
| F| CA| A LAH|
| ?| CN| A LIAN|
| ?| CN| A MEI|
| ?| CN| A MIN|
| F| CA| A RIN|
| M| CA| A S M|
| ?| CN| A YING|
| F| CA|AA ISHAH|
| M| CA| AABAN|
| M| GB| AABAN|
| M| US| AABAN|
| M| GB| AABAS|
| F| CA| AABEER|
| M| CA| AABEL|
| F| US| AABHA|
| F| GB| AABIA|
+------------------+----------------+--------+
As you can see, CIVILITY_PREDICTED contains some "?" values. Every name has one row per country, and sometimes CIVILITY_PREDICTED is "?" for one country but not for another country with the same name.
So basically, for each "?", I want to fill in the most common CIVILITY_PREDICTED found in the other countries for that name.
I tried the following (e is the dataframe, and to_predict is another dataframe containing only the names I want to look up):
e.filter($"CIVILITY_PREDICTED" === "?" && $"name".isNotNull)
  .select("COUNTRY_CODE_PRE", "CIVILITY_PREDICTED", "name")
  .collect()
  .map(a => {
    to_predict
      // name is the third selected column, so its index is 2
      .filter($"name" === a.get(2))
      .filter($"CIVILITY_PREDICTED" !== "?")
      .groupBy("CIVILITY_PREDICTED")
      .count()
      .agg(org.apache.spark.sql.functions.max("CIVILITY_PREDICTED"))
      .show()
  })
With this I get the CIVILITY_PREDICTED with the most occurrences for each name, but I guess it's not very efficient, and I don't know how to then replace the corresponding "?" in the dataframe with it.
Does anyone know how to do this? Thank you very much.
Upvotes: 0
Views: 48
Reputation: 1386
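Window functions are the key here. The following solution groups the rows by name and CIVILITY_PREDICTED, then uses first_value over a window ordered by descending row count to pick the most frequent CIVILITY_PREDICTED value for each name.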
spark.sql("""
  select distinct name,
         first_value(CIVILITY_PREDICTED) over (partition by name order by count(*) desc) civility
  from civ
  group by name, CIVILITY_PREDICTED
""").show
Based on the data recreated as seen below, this returns:
+-----+--------+
| name|civility|
+-----+--------+
|AABAN| M|
+-----+--------+
To see the original value as well as the most common one:
spark.sql("""
  select name, CIVILITY_PREDICTED,
         first(CIVILITY_PREDICTED) over (partition by name order by count(*) desc) civility
  from civ
  group by 1, 2
  order by 1, 2
""").show
which returns
+-----+------------------+--------+
| name|CIVILITY_PREDICTED|civility|
+-----+------------------+--------+
|AABAN| ?| M|
|AABAN| M| M|
+-----+------------------+--------+
I recreated only one name that has the issue you're trying to solve: AABAN is ? for one row and M for the two others.
val civ = """+------------------+----------------+--------+
|CIVILITY_PREDICTED|COUNTRY_CODE_PRE| name|
+------------------+----------------+--------+
| ?| CA| AABAN|
| M| GB| AABAN|
| M| US| AABAN|""".stripMargin.replaceAll("\\+", "").replaceAll("\\-", "").split("\n").filter(_.size>10)
val df = spark.read
  .option("ignoreTrailingWhiteSpace", "true")
  .option("ignoreLeadingWhiteSpace", "true")
  .option("delimiter", "|")
  .option("header", "true")
  .csv(spark.sparkContext.parallelize(civ).toDS)
  .drop("_c3")
df.createOrReplaceTempView("civ")
df.orderBy("name").show(99)
+------------------+----------------+-----+
|CIVILITY_PREDICTED|COUNTRY_CODE_PRE| name|
+------------------+----------------+-----+
| ?| CA|AABAN|
| M| GB|AABAN|
| M| US|AABAN|
+------------------+----------------+-----+
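The queries above report the most common value but do not write it back into the dataframe, and they count ? as a candidate too, so a name whose rows are mostly ? would still come out as ?. One possible way to actually replace the ? values is sketched below with the DataFrame API. It is only a sketch under some assumptions: the names counts, byCount, mostCommon, most_common, rn and filled are made up for illustration, ties between equally frequent values are broken alphabetically (an arbitrary choice), and names that only ever have ? are left unchanged.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{row_number, when}

// In spark-shell getOrCreate returns the existing session; local[*] is only for standalone runs.
val spark = SparkSession.builder().master("local[*]").appName("fill-civility").getOrCreate()
import spark.implicits._

// Same three rows as the recreated data above.
val df = Seq(
  ("?", "CA", "AABAN"),
  ("M", "GB", "AABAN"),
  ("M", "US", "AABAN")
).toDF("CIVILITY_PREDICTED", "COUNTRY_CODE_PRE", "name")

// Count each known value per name, leaving "?" out of the vote.
val counts = df
  .filter($"CIVILITY_PREDICTED" =!= "?")
  .groupBy("name", "CIVILITY_PREDICTED")
  .count()

// Keep the top-ranked value per name; ties are broken alphabetically (assumption).
val byCount = Window.partitionBy("name").orderBy($"count".desc, $"CIVILITY_PREDICTED")
val mostCommon = counts
  .withColumn("rn", row_number().over(byCount))
  .filter($"rn" === 1)
  .select($"name", $"CIVILITY_PREDICTED".as("most_common"))

// Left join so that names without any known value keep their "?".
val filled = df
  .join(mostCommon, Seq("name"), "left")
  .withColumn(
    "CIVILITY_PREDICTED",
    when($"CIVILITY_PREDICTED" === "?" && $"most_common".isNotNull, $"most_common")
      .otherwise($"CIVILITY_PREDICTED"))
  .drop("most_common")

filled.show()
```

With the three AABAN rows above, all three rows of filled end up with CIVILITY_PREDICTED equal to M. Because the counting happens in a single groupBy over the whole dataframe, nothing is collected to the driver and no per-name query is needed.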
Upvotes: 0