echo55

Reputation: 329

Spark dataframe operation

I'm trying to perform a very specific operation on my dataframe, but I can't find a clean way to do it.

I have a dataframe that looks like this:

+------------------+----------------+--------+
|CIVILITY_PREDICTED|COUNTRY_CODE_PRE|    name|
+------------------+----------------+--------+
|                 M|              CA|A HANNAN|
|                 M|              CA|   A JAY|
|                 M|              GB|   A JAY|
|                 M|              CA| A K I L|
|                 F|              CA|   A LAH|
|                 ?|              CN|  A LIAN|
|                 ?|              CN|   A MEI|
|                 ?|              CN|   A MIN|
|                 F|              CA|   A RIN|
|                 M|              CA|   A S M|
|                 ?|              CN|  A YING|
|                 F|              CA|AA ISHAH|
|                 M|              CA|   AABAN|
|                 M|              GB|   AABAN|
|                 M|              US|   AABAN|
|                 M|              GB|   AABAS|
|                 F|              CA|  AABEER|
|                 M|              CA|   AABEL|
|                 F|              US|   AABHA|
|                 F|              GB|   AABIA|
+------------------+----------------+--------+

As you can see, CIVILITY_PREDICTED contains some "?" values. Every name has one row per country, and sometimes CIVILITY_PREDICTED is "?" for one country but not for another country with the same name.

So basically, for each "?", I want to fill in the most common CIVILITY_PREDICTED for that name across the other countries.

I tried to do it like this (e is the dataframe and to_predict is another one containing only the names I want to fill in):
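In other words: compute the mode of the non-"?" civilities per name, then substitute it in. As a minimal plain-Scala sketch of that logic on hypothetical in-memory rows (no Spark), it would be:

```scala
// Sample rows: (civility, country, name) -- hypothetical data
val rows = Seq(
  ("?", "CA", "AABAN"),
  ("M", "GB", "AABAN"),
  ("M", "US", "AABAN"),
  ("F", "CA", "AABEER")
)

// Most common non-"?" civility per name
val mode: Map[String, String] = rows
  .filter(_._1 != "?")
  .groupBy(_._3)                                  // group by name
  .map { case (name, rs) =>
    name -> rs.groupBy(_._1).maxBy(_._2.size)._1  // pick the most frequent civility
  }

// Replace each "?" with the mode for that name, if one exists
val filled = rows.map {
  case ("?", country, name) => (mode.getOrElse(name, "?"), country, name)
  case r => r
}
```

The question is how to express the same thing efficiently with Spark dataframes.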

e.filter($"CIVILITY_PREDICTED" === "?" && $"name".isNotNull)
  .select("COUNTRY_CODE_PRE", "CIVILITY_PREDICTED", "name")
  .collect().map(a => {
    to_predict
      .filter($"name" === a.get(2))           // "name" is the third selected column (index 2)
      .filter($"CIVILITY_PREDICTED" =!= "?")  // =!= replaces the deprecated !==
      .groupBy("CIVILITY_PREDICTED")
      .count()
      .agg(org.apache.spark.sql.functions.max("CIVILITY_PREDICTED")).show()
  })

With this I get the CIVILITY_PREDICTED with the most occurrences for each name, but I suspect it isn't very efficient, and I don't know how to then replace the corresponding "?" values in the dataframe with the result.

Does anyone know how? Thank you very much.

Upvotes: 0

Views: 48

Answers (1)

Lars Skaug

Reputation: 1386

Window functions are the key here. The following solution uses first_value to pick, for each name, the civility value with the highest row count.

spark.sql("""select distinct name, first_value(CIVILITY_PREDICTED) over (partition by name order by count(*) desc) civility
             from civ
             group by name, CIVILITY_PREDICTED
             """).show

Based on the data recreated as seen below, this returns:

+-----+--------+
| name|civility|
+-----+--------+
|AABAN|       M|
+-----+--------+

To see the original value as well as the most common one:

spark.sql("""select name, CIVILITY_PREDICTED, 
             first(CIVILITY_PREDICTED) 
              over (partition by name order by count(*) desc) civility
             from civ 
             group by 1,2 
             order by 1,2
             """).show

which returns

+-----+------------------+--------+
| name|CIVILITY_PREDICTED|civility|
+-----+------------------+--------+
|AABAN|                 ?|       M|
|AABAN|                 M|       M|
+-----+------------------+--------+

I recreated only one name with the issue you're trying to solve: AABAN is "?" for one row and 'M' for two others.

val civ = """+------------------+----------------+--------+
|CIVILITY_PREDICTED|COUNTRY_CODE_PRE|    name|
+------------------+----------------+--------+
|                 ?|              CA|   AABAN|
|                 M|              GB|   AABAN|
|                 M|              US|   AABAN|""".stripMargin.replaceAll("\\+", "").replaceAll("\\-", "").split("\n").filter(_.size>10)

val df = spark.read
  .option("ignoreTrailingWhiteSpace", "true")
  .option("ignoreLeadingWhiteSpace", "true")
  .option("delimiter", "|")
  .option("header", "true")
  .csv(spark.sparkContext.parallelize(civ).toDS)
  .drop("_c3")

df.createOrReplaceTempView("civ")

df.orderBy("name").show(99)

+------------------+----------------+-----+
|CIVILITY_PREDICTED|COUNTRY_CODE_PRE| name|
+------------------+----------------+-----+
|                 ?|              CA|AABAN|
|                 M|              GB|AABAN|
|                 M|              US|AABAN|
+------------------+----------------+-----+
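To actually write the most common value back into the original dataframe (the substitution step the question asks about), one possible sketch is to compute the per-name mode with a window, join it back, and overwrite "?" using when/coalesce. The column names follow the setup above; the local session setup and the tiny stand-in data are assumptions, not tested against the real dataset:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val spark = SparkSession.builder.master("local[1]").appName("fill-civility").getOrCreate()
import spark.implicits._

// Tiny stand-in for the recreated data above
val df = Seq(
  ("?", "CA", "AABAN"),
  ("M", "GB", "AABAN"),
  ("M", "US", "AABAN")
).toDF("CIVILITY_PREDICTED", "COUNTRY_CODE_PRE", "name")

// Most common non-"?" civility per name, via a count-ordered window
val w = Window.partitionBy("name").orderBy(desc("cnt"))
val modes = df
  .filter($"CIVILITY_PREDICTED" =!= "?")
  .groupBy("name", "CIVILITY_PREDICTED")
  .agg(count(lit(1)).as("cnt"))
  .withColumn("civility", first("CIVILITY_PREDICTED").over(w))
  .select("name", "civility")
  .distinct()

// Join back and overwrite "?" where a mode exists; names with no
// non-"?" row anywhere keep their "?" via coalesce's fallback
val filledDf = df
  .join(modes, Seq("name"), "left")
  .withColumn(
    "CIVILITY_PREDICTED",
    when($"CIVILITY_PREDICTED" === "?", coalesce($"civility", lit("?")))
      .otherwise($"CIVILITY_PREDICTED"))
  .drop("civility")
```

The left join (rather than inner) keeps every original row, so names that are "?" in all countries are preserved unchanged.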

Upvotes: 0
