alexanoid

Reputation: 25790

Spark SQL add column/update-accumulate value

I have the following DataFrame:

name,email,phone,country
------------------------------------------------
[Mike,[email protected],+91-9999999999,Italy]
[Alex,[email protected],+91-9999999998,France]
[John,[email protected],+1-1111111111,United States]
[Donald,[email protected],+1-2222222222,United States]
[Dan,[email protected],+91-9999444999,Poland]
[Scott,[email protected],+91-9111999998,Spain]
[Rob,[email protected],+91-9114444998,Italy]

exposed as temp table tagged_users:

resultDf.createOrReplaceTempView("tagged_users")

I need to add an additional column, tag, to this DataFrame and fill it with tags computed from different SQL conditions. The conditions are described in the following map (key: tag name, value: condition for the WHERE clause):

val tags = Map(
  "big" -> "country IN (SELECT * FROM big_countries)",
  "medium" -> "country IN (SELECT * FROM medium_countries)",
  //2000 other different tags and conditions
  "sometag" -> "name = 'Donald' AND email = '[email protected]' AND phone = '+1-2222222222'"
  )

I have the following DataFrames (used as data dictionaries) registered as temp views so that they can be referenced in SQL queries:

Seq("Italy", "France", "United States", "Spain").toDF("country").createOrReplaceTempView("big_countries")
Seq("Poland", "Hungary", "Spain").toDF("country").createOrReplaceTempView("medium_countries")

I want to test each line in my tagged_users table and assign it appropriate tags. I tried to implement the following logic in order to achieve it:

tags.foreach {
  case (tag, tagCondition) => {
    resultDf = spark.sql(buildTagQuery(tag, tagCondition, "tagged_users"))
       .withColumn("tag", lit(tag).cast(StringType))
  }
}

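def buildTagQuery(tag: String, tagCondition: String, table: String): String = {
    // s-interpolator: unlike f"...", it does not treat a '%' inside a condition (e.g. LIKE '%x%') as a format specifier
    s"SELECT * FROM $table WHERE $tagCondition"
}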

but I don't know how to accumulate the tags instead of overwriting them on each iteration. Currently the result is the following DataFrame:

name,email,phone,country,tag
Dan,[email protected],+91-9999444999,Poland,medium
Scott,[email protected],+91-9111999998,Spain,medium

but I need something like:

name,email,phone,country,tag
Mike,[email protected],+91-9999999999,Italy,big
Alex,[email protected],+91-9999999998,France,big
John,[email protected],+1-1111111111,United States,big
Donald,[email protected],+1-2222222222,United States,(big|sometag)
Dan,[email protected],+91-9999444999,Poland,medium
Scott,[email protected],+91-9111999998,Spain,(big|medium)
Rob,[email protected],+91-9114444998,Italy,big

Please note that Donald should have 2 tags (big|sometag) and Scott should have 2 tags (big|medium).

Please show how to implement it.

UPDATED

val spark = SparkSession
  .builder()
  .appName("Java Spark SQL basic example")
  .config("spark.master", "local")
  .getOrCreate();

import spark.implicits._
import spark.sql

Seq("Italy", "France", "United States", "Spain").toDF("country").createOrReplaceTempView("big_countries")
Seq("Poland", "Hungary", "Spain").toDF("country").createOrReplaceTempView("medium_countries")

val df = Seq(
  ("Mike", "[email protected]", "+91-9999999999", "Italy"),
  ("Alex", "[email protected]", "+91-9999999998", "France"),
  ("John", "[email protected]", "+1-1111111111", "United States"),
  ("Donald", "[email protected]", "+1-2222222222", "United States"),
  ("Dan", "[email protected]", "+91-9999444999", "Poland"),
  ("Scott", "[email protected]", "+91-9111999998", "Spain"),
  ("Rob", "[email protected]", "+91-9114444998", "Italy")).toDF("name", "email", "phone", "country")

df.collect.foreach(println)

df.createOrReplaceTempView("tagged_users")

val tags = Map(
  "big" -> "country IN (SELECT * FROM big_countries)",
  "medium" -> "country IN (SELECT * FROM medium_countries)",
  "sometag" -> "name = 'Donald' AND email = '[email protected]' AND phone = '+1-2222222222'")

val sep_tag = tags.map((x) => { s"when array_contains(" + x._1 + ", country) then '" + x._1 + "' " }).mkString

val combine_sel_tag1 = tags.map((x) => { s" array_contains(" + x._1 + ",country) " }).mkString(" and ")

val combine_sel_tag2 = tags.map((x) => x._1).mkString(" '(", "|", ")' ")

val combine_sel_all = " case when " + combine_sel_tag1 + " then " + combine_sel_tag2 + sep_tag + " end as tags "

val crosqry = tags.map((x) => { s" cross join ( select collect_list(country) as " + x._1 + " from " + x._1 + "_countries) " + x._1 + "  " }).mkString

val qry = " select name,email,phone,country, " + combine_sel_all + " from tagged_users " + crosqry

spark.sql(qry).show

spark.stop()

fails with the following exception:

Caused by: org.apache.spark.sql.catalyst.analysis.NoSuchTableException: Table or view 'sometag_countries' not found in database 'default';
    at org.apache.spark.sql.catalyst.catalog.ExternalCatalog$class.requireTableExists(ExternalCatalog.scala:48)
    at org.apache.spark.sql.catalyst.catalog.InMemoryCatalog.requireTableExists(InMemoryCatalog.scala:45)
    at org.apache.spark.sql.catalyst.catalog.InMemoryCatalog.getTable(InMemoryCatalog.scala:326)
    at org.apache.spark.sql.catalyst.catalog.ExternalCatalogWithListener.getTable(ExternalCatalogWithListener.scala:138)
    at org.apache.spark.sql.catalyst.catalog.SessionCatalog.lookupRelation(SessionCatalog.scala:701)
    at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveRelations$.org$apache$spark$sql$catalyst$analysis$Analyzer$ResolveRelations$$lookupTableFromCatalog(Analyzer.scala:730)
    ... 74 more

Upvotes: 1

Views: 905

Answers (3)

stack0114106

Reputation: 8711

Check out this DF solution:

scala> val df = Seq(("Mike","[email protected]","+91-9999999999","Italy"),
     | ("Alex","[email protected]","+91-9999999998","France"),
     | ("John","[email protected]","+1-1111111111","United States"),
     | ("Donald","[email protected]","+1-2222222222","United States"),
     | ("Dan","[email protected]","+91-9999444999","Poland"),
     | ("Scott","[email protected]","+91-9111999998","Spain"),
     | ("Rob","[email protected]","+91-9114444998","Italy")
     | ).toDF("name","email","phone","country")
df: org.apache.spark.sql.DataFrame = [name: string, email: string ... 2 more fields]

scala> val dfbc=Seq("Italy", "France", "United States", "Spain").toDF("country")
dfbc: org.apache.spark.sql.DataFrame = [country: string]

scala> val dfmc=Seq("Poland", "Hungary", "Spain").toDF("country")
dfmc: org.apache.spark.sql.DataFrame = [country: string]

scala> val dfbc2=dfbc.agg(collect_list('country).as("bcountry"))
dfbc2: org.apache.spark.sql.DataFrame = [bcountry: array<string>]

scala> val dfmc2=dfmc.agg(collect_list('country).as("mcountry"))
dfmc2: org.apache.spark.sql.DataFrame = [mcountry: array<string>]

scala> val df2=df.crossJoin(dfbc2).crossJoin(dfmc2)
df2: org.apache.spark.sql.DataFrame = [name: string, email: string ... 4 more fields]

scala> df2.selectExpr("*","case when array_contains(bcountry,country) and array_contains(mcountry,country) then '(big|medium)' when array_contains(bcountry,country) then 'big' when array_contains(mcountry,country) then 'medium' else 'none' end as `tags`").select("name","email","phone","country","tags").show(false)
+------+------------------+--------------+-------------+------------+
|name  |email             |phone         |country      |tags        |
+------+------------------+--------------+-------------+------------+
|Mike  |[email protected]  |+91-9999999999|Italy        |big         |
|Alex  |[email protected]  |+91-9999999998|France       |big         |
|John  |[email protected]  |+1-1111111111 |United States|big         |
|Donald|[email protected]|+1-2222222222 |United States|big         |
|Dan   |[email protected]   |+91-9999444999|Poland       |medium      |
|Scott |[email protected] |+91-9111999998|Spain        |(big|medium)|
|Rob   |[email protected]   |+91-9114444998|Italy        |big         |
+------+------------------+--------------+-------------+------------+


scala>

SQL approach

scala> Seq(("Mike","[email protected]","+91-9999999999","Italy"),
     |       ("Alex","[email protected]","+91-9999999998","France"),
     |       ("John","[email protected]","+1-1111111111","United States"),
     |       ("Donald","[email protected]","+1-2222222222","United States"),
     |       ("Dan","[email protected]","+91-9999444999","Poland"),
     |       ("Scott","[email protected]","+91-9111999998","Spain"),
     |       ("Rob","[email protected]","+91-9114444998","Italy")
     |       ).toDF("name","email","phone","country").createOrReplaceTempView("tagged_users")

scala> Seq("Italy", "France", "United States", "Spain").toDF("country").createOrReplaceTempView("big_countries")

scala> Seq("Poland", "Hungary", "Spain").toDF("country").createOrReplaceTempView("medium_countries")

scala> spark.sql(""" select name,email,phone,country,case when array_contains(bc,country) and array_contains(mc,country) then '(big|medium)' when array_contains(bc,country) then 'big' when array_contains(mc,country) then 'medium' else 'none' end as tags from tagged_users cross join ( select collect_list(country) as bc from big_countries ) b cross join ( select collect_list(country) as mc from medium_countries ) c """).show(false)
+------+------------------+--------------+-------------+------------+
|name  |email             |phone         |country      |tags        |
+------+------------------+--------------+-------------+------------+
|Mike  |[email protected]  |+91-9999999999|Italy        |big         |
|Alex  |[email protected]  |+91-9999999998|France       |big         |
|John  |[email protected]  |+1-1111111111 |United States|big         |
|Donald|[email protected]|+1-2222222222 |United States|big         |
|Dan   |[email protected]   |+91-9999444999|Poland       |medium      |
|Scott |[email protected] |+91-9111999998|Spain        |(big|medium)|
|Rob   |[email protected]   |+91-9114444998|Italy        |big         |
+------+------------------+--------------+-------------+------------+


scala>

Iterating through the tags

scala> val tags = Map(
     |   "big" -> "country IN (SELECT * FROM big_countries)",
     |   "medium" -> "country IN (SELECT * FROM medium_countries)")
tags: scala.collection.immutable.Map[String,String] = Map(big -> country IN (SELECT * FROM big_countries), medium -> country IN (SELECT * FROM medium_countries))

scala> val sep_tag = tags.map( (x) => { s"when array_contains("+x._1+", country) then '" + x._1 + "' " } ).mkString
sep_tag: String = "when array_contains(big, country) then 'big' when array_contains(medium, country) then 'medium' "

scala> val combine_sel_tag1 = tags.map( (x) => { s" array_contains("+x._1+",country) " } ).mkString(" and ")
combine_sel_tag1: String = " array_contains(big,country)  and  array_contains(medium,country) "

scala> val combine_sel_tag2 = tags.map( (x) => x._1 ).mkString(" '(","|", ")' ")
combine_sel_tag2: String = " '(big|medium)' "

scala> val combine_sel_all = " case when " + combine_sel_tag1 + " then " + combine_sel_tag2 +  sep_tag + " end as tags "
combine_sel_all: String = " case when  array_contains(big,country)  and  array_contains(medium,country)  then  '(big|medium)' when array_contains(big, country) then 'big' when array_contains(medium, country) then 'medium'  end as tags "

scala> val crosqry = tags.map( (x) => { s" cross join ( select collect_list(country) as "+x._1+" from "+x._1+"_countries) "+ x._1 + "  " } ).mkString
crosqry: String = " cross join ( select collect_list(country) as big from big_countries) big   cross join ( select collect_list(country) as medium from medium_countries) medium  "

scala> val qry = " select name,email,phone,country, " + combine_sel_all + " from tagged_users " + crosqry
qry: String = " select name,email,phone,country,  case when  array_contains(big,country)  and  array_contains(medium,country)  then  '(big|medium)' when array_contains(big, country) then 'big' when array_contains(medium, country) then 'medium'  end as tags  from tagged_users  cross join ( select collect_list(country) as big from big_countries) big   cross join ( select collect_list(country) as medium from medium_countries) medium  "

scala> spark.sql(qry).show
+------+------------------+--------------+-------------+------------+
|  name|             email|         phone|      country|        tags|
+------+------------------+--------------+-------------+------------+
|  Mike|  [email protected]|+91-9999999999|        Italy|         big|
|  Alex|  [email protected]|+91-9999999998|       France|         big|
|  John|  [email protected]| +1-1111111111|United States|         big|
|Donald|[email protected]| +1-2222222222|United States|         big|
|   Dan|   [email protected]|+91-9999444999|       Poland|      medium|
| Scott| [email protected]|+91-9111999998|        Spain|(big|medium)|
|   Rob|   [email protected]|+91-9114444998|        Italy|         big|
+------+------------------+--------------+-------------+------------+


scala>

UPDATE2:

scala> Seq(("Mike","[email protected]","+91-9999999999","Italy"),
     | ("Alex","[email protected]","+91-9999999998","France"),
     | ("John","[email protected]","+1-1111111111","United States"),
     | ("Donald","[email protected]","+1-2222222222","United States"),
     | ("Dan","[email protected]","+91-9999444999","Poland"),
     | ("Scott","[email protected]","+91-9111999998","Spain"),
     | ("Rob","[email protected]","+91-9114444998","Italy")
     | ).toDF("name","email","phone","country").createOrReplaceTempView("tagged_users")

scala> Seq("Italy", "France", "United States", "Spain").toDF("country").createOrReplaceTempView("big_countries")

scala> Seq("Poland", "Hungary", "Spain").toDF("country").createOrReplaceTempView("medium_countries")

scala> val tags = Map(
     |   "big" -> "country IN (SELECT * FROM big_countries)",
     |   "medium" -> "country IN (SELECT * FROM medium_countries)",
     |   "sometag" -> "name = 'Donald' AND email = '[email protected]' AND phone = '+1-2222222222'")
tags: scala.collection.immutable.Map[String,String] = Map(big -> country IN (SELECT * FROM big_countries), medium -> country IN (SELECT * FROM medium_countries), sometag -> name = 'Donald' AND email = '[email protected]' AND phone = '+1-2222222222')

scala> val sql_tags = tags.map( x => { val p = x._2.trim.toUpperCase.split(" ");
     | val qry = if(p.contains("IN") && p.contains("FROM"))
     | s" case when array_contains((select collect_list("+p.head +") from " + p.last.replaceAll("[)]","")+ " ), " +p.head + " ) then '" + x._1 + " ' else '' end " + x._1 + " "
     | else
     | " case when " + x._2 + " then '" + x._1 + " ' else '' end " + x._1 + " ";
     | qry } ).mkString(",")
sql_tags: String = " case when array_contains((select collect_list(COUNTRY) from BIG_COUNTRIES ), COUNTRY ) then 'big ' else '' end big , case when array_contains((select collect_list(COUNTRY) from MEDIUM_COUNTRIES ), COUNTRY ) then 'medium ' else '' end medium , case when name = 'Donald' AND email = '[email protected]' AND phone = '+1-2222222222' then 'sometag ' else '' end sometag "

scala> val outer_query = tags.map( x=> x._1).mkString(" regexp_replace(trim(concat(", ",", " )),' ','|') tags ")
outer_query: String = " regexp_replace(trim(concat(big,medium,sometag )),' ','|') tags "

scala> spark.sql(" select name,email, country, " + outer_query + " from ( select name,email, country ," + sql_tags + "   from tagged_users ) " ).show
+------+------------------+-------------+-----------+
|  name|             email|      country|       tags|
+------+------------------+-------------+-----------+
|  Mike|  [email protected]|        Italy|        big|
|  Alex|  [email protected]|       France|        big|
|  John|  [email protected]|United States|        big|
|Donald|[email protected]|United States|big|sometag|
|   Dan|   [email protected]|       Poland|     medium|
| Scott| [email protected]|        Spain| big|medium|
|   Rob|   [email protected]|        Italy|        big|
+------+------------------+-------------+-----------+


scala>
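A possible simplification of UPDATE2, as a sketch only: concat_ws skips NULL arguments, so each tag can be written as a CASE WHEN ... THEN 'tag' END without an ELSE branch, and the trailing-space plus regexp_replace step is no longer needed. The helper name buildTagsExpr is hypothetical, and the sketch assumes the dictionary conditions have already been rewritten into the array_contains((select collect_list(...)) ...) form used above. The sometag condition is shortened for illustration.

```scala
object TagsExprSketch {
  // Hypothetical helper: builds one SQL expression that joins the names of
  // all matching tags with '|'. A CASE without ELSE yields NULL when the
  // condition is false, and concat_ws ignores NULL arguments.
  def buildTagsExpr(tags: Seq[(String, String)]): String = {
    val cases = tags.map { case (tag, cond) => s"CASE WHEN $cond THEN '$tag' END" }
    cases.mkString("concat_ws('|', ", ", ", ") AS tags")
  }

  def main(args: Array[String]): Unit = {
    // A Seq (not a Map) keeps the tag order stable in the output string.
    val tags = Seq(
      "big" -> "array_contains((SELECT collect_list(country) FROM big_countries), country)",
      "medium" -> "array_contains((SELECT collect_list(country) FROM medium_countries), country)",
      "sometag" -> "name = 'Donald' AND phone = '+1-2222222222'")

    // Pass the printed string to spark.sql(...) in a session where the
    // tagged_users, big_countries and medium_countries views exist.
    println(s"SELECT name, email, phone, country, ${buildTagsExpr(tags)} FROM tagged_users")
  }
}
```

Rows that match no condition get an empty string in tags. Tag names containing a single quote would need escaping before being embedded this way.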

Upvotes: 1

Louis Thompson

Reputation: 21

If you need to aggregate the results rather than just execute each query, use map instead of foreach and then union the results:

val o = tags.map {
  case (tag, tagCondition) =>
    // Tag every matching row with the literal tag name
    // (requires import org.apache.spark.sql.functions.lit)
    spark.sql(buildTagQuery(tag, tagCondition, "tagged_users"))
      .withColumn("tag", lit(tag))
}

// reduce unions each per-tag DataFrame exactly once
o.reduce(_ union _)
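The union by itself yields one row per (user, tag) pair, so a user with two tags appears twice. One way to collapse those rows into a single tag string is a groupBy with collect_list and concat_ws. The following is a minimal sketch under assumptions: the object and method names are hypothetical, the sample data is trimmed to two columns, and the sometag condition is shortened.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{collect_list, concat_ws, lit, sort_array}

object UnionTagsSketch {
  // Hypothetical helper: runs one query per tag, unions the results,
  // then merges the tags of each user into a single '|'-separated string.
  def tagUsers(spark: SparkSession): DataFrame = {
    import spark.implicits._

    // Trimmed-down sample data (email and phone omitted for brevity)
    Seq(("Donald", "United States"), ("Dan", "Poland"), ("Scott", "Spain"))
      .toDF("name", "country").createOrReplaceTempView("tagged_users")
    Seq("Italy", "France", "United States", "Spain")
      .toDF("country").createOrReplaceTempView("big_countries")
    Seq("Poland", "Hungary", "Spain")
      .toDF("country").createOrReplaceTempView("medium_countries")

    val tags = Map(
      "big" -> "country IN (SELECT * FROM big_countries)",
      "medium" -> "country IN (SELECT * FROM medium_countries)",
      "sometag" -> "name = 'Donald'")

    // One DataFrame per tag, each carrying the tag name as a literal column
    val perTag: Seq[DataFrame] = tags.toSeq.map { case (tag, cond) =>
      spark.sql(s"SELECT * FROM tagged_users WHERE $cond").withColumn("tag", lit(tag))
    }

    perTag.reduce(_ union _)
      .groupBy("name", "country")
      // sort_array makes the tag order deterministic before joining with '|'
      .agg(concat_ws("|", sort_array(collect_list("tag"))).as("tags"))
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("union-tags-sketch").master("local[*]").getOrCreate()
    tagUsers(spark).orderBy("name").show(false)
    spark.stop()
  }
}
```

With this sample data Donald should come out as big|sometag and Scott as big|medium. Users that match no tag drop out of the result entirely, so a left join back to tagged_users would be needed to keep them. With thousands of tags the unioned plan may also grow large.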

Upvotes: 1

Terry Dactyl

Reputation: 1868

I would define multiple tags tables with columns value, tag.

Then your tags definition would be a collection, say Seq[(String, String)], where the first tuple element is the column on which the tag is calculated and the second is the name of the tags table.

Let's say:

Seq(
  "country" -> "bigCountries", // Columns [country, bigCountry]
  "country" -> "mediumCountries", // Columns [country, mediumCountry]
  "email" -> "hotmailLosers" // Columns [email, hotmailLoser]
)

Then iterate through this list and left join each tags table on the relevant column.

After joining each table, set your tags column to its current value plus the joined tag column, if that column is not null.
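A minimal sketch of the left-join idea, under assumptions: the object and method names are hypothetical, only the two country dictionaries are shown, and each tags table is assumed to hold unique values in its join column (duplicates would multiply user rows). It relies on concat_ws skipping NULLs, so unmatched joins contribute nothing to the tag string.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{col, concat_ws, lit}

object JoinTagsSketch {
  // Hypothetical helper: left joins every tags table onto the users and
  // concatenates the non-null tag columns with '|'.
  def tagUsers(spark: SparkSession): DataFrame = {
    import spark.implicits._

    // Trimmed-down sample data (email and phone omitted for brevity)
    val users = Seq(("Mike", "Italy"), ("Dan", "Poland"), ("Scott", "Spain"))
      .toDF("name", "country")

    // Each tags table: the join column plus one column holding the tag name
    val bigCountries = Seq("Italy", "France", "United States", "Spain")
      .toDF("country").withColumn("bigCountry", lit("big"))
    val mediumCountries = Seq("Poland", "Hungary", "Spain")
      .toDF("country").withColumn("mediumCountry", lit("medium"))

    // (column to join on, tags table)
    val tagTables: Seq[(String, DataFrame)] =
      Seq("country" -> bigCountries, "country" -> mediumCountries)

    // Left join keeps every user; a missing match leaves the tag column NULL
    val joined = tagTables.foldLeft(users) { case (acc, (joinCol, tagDf)) =>
      acc.join(tagDf, Seq(joinCol), "left")
    }

    // Names of the tag columns contributed by the joins
    val tagColNames: Seq[String] = tagTables.flatMap { case (joinCol, tagDf) =>
      tagDf.columns.filter(_ != joinCol).toSeq
    }

    joined
      .withColumn("tags", concat_ws("|", tagColNames.map(col): _*))
      .drop(tagColNames: _*)
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("join-tags-sketch").master("local[*]").getOrCreate()
    tagUsers(spark).orderBy("name").show(false)
    spark.stop()
  }
}
```

A user with no match in any table keeps its row and gets an empty tags string. Conditions that span several columns, such as the sometag one from the question, do not fit a single-column dictionary and would need a multi-column join key or a separate filter.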

Upvotes: 0
