Reputation: 25790
I have the following DataFrame:
name,email,phone,country
------------------------------------------------
[Mike,[email protected],+91-9999999999,Italy]
[Alex,[email protected],+91-9999999998,France]
[John,[email protected],+1-1111111111,United States]
[Donald,[email protected],+1-2222222222,United States]
[Dan,[email protected],+91-9999444999,Poland]
[Scott,[email protected],+91-9111999998,Spain]
[Rob,[email protected],+91-9114444998,Italy]
exposed as a temp table tagged_users:
resultDf.createOrReplaceTempView("tagged_users")
I need to add an additional column tag to this DataFrame and assign calculated tags based on different SQL conditions, which are described in the following map (key: tag name, value: condition for the WHERE clause):
val tags = Map(
  "big" -> "country IN (SELECT * FROM big_countries)",
  "medium" -> "country IN (SELECT * FROM medium_countries)",
  // 2000 other different tags and conditions
  "sometag" -> "name = 'Donald' AND email = '[email protected]' AND phone = '+1-2222222222'"
)
I have the following DataFrames (as data dictionaries) so that I can use them in SQL queries:
Seq("Italy", "France", "United States", "Spain").toDF("country").createOrReplaceTempView("big_countries")
Seq("Poland", "Hungary", "Spain").toDF("country").createOrReplaceTempView("medium_countries")
I want to test each row in my tagged_users table and assign it the appropriate tags. I tried to implement the following logic to achieve this:
tags.foreach {
  case (tag, tagCondition) =>
    resultDf = spark.sql(buildTagQuery(tag, tagCondition, "tagged_users"))
      .withColumn("tag", lit(tag).cast(StringType))
}

def buildTagQuery(tag: String, tagCondition: String, table: String): String = {
  s"SELECT * FROM $table WHERE $tagCondition"
}
but right now I don't know how to accumulate the tags instead of overwriting them. Currently the result is the following DataFrame:
name,email,phone,country,tag
Dan,[email protected],+91-9999444999,Poland,medium
Scott,[email protected],+91-9111999998,Spain,medium
but I need something like:
name,email,phone,country,tag
Mike,[email protected],+91-9999999999,Italy,big
Alex,[email protected],+91-9999999998,France,big
John,[email protected],+1-1111111111,United States,big
Donald,[email protected],+1-2222222222,United States,(big|sometag)
Dan,[email protected],+91-9999444999,Poland,medium
Scott,[email protected],+91-9111999998,Spain,(big|medium)
Rob,[email protected],+91-9114444998,Italy,big
Please note that Donald should have 2 tags (big|sometag) and Scott should have 2 tags (big|medium).
Please show how to implement it.
UPDATED
val spark = SparkSession
  .builder()
  .appName("Java Spark SQL basic example")
  .config("spark.master", "local")
  .getOrCreate()

import spark.implicits._
import spark.sql
Seq("Italy", "France", "United States", "Spain").toDF("country").createOrReplaceTempView("big_countries")
Seq("Poland", "Hungary", "Spain").toDF("country").createOrReplaceTempView("medium_countries")
val df = Seq(
  ("Mike", "[email protected]", "+91-9999999999", "Italy"),
  ("Alex", "[email protected]", "+91-9999999998", "France"),
  ("John", "[email protected]", "+1-1111111111", "United States"),
  ("Donald", "[email protected]", "+1-2222222222", "United States"),
  ("Dan", "[email protected]", "+91-9999444999", "Poland"),
  ("Scott", "[email protected]", "+91-9111999998", "Spain"),
  ("Rob", "[email protected]", "+91-9114444998", "Italy")
).toDF("name", "email", "phone", "country")
df.collect.foreach(println)
df.createOrReplaceTempView("tagged_users")
val tags = Map(
  "big" -> "country IN (SELECT * FROM big_countries)",
  "medium" -> "country IN (SELECT * FROM medium_countries)",
  "sometag" -> "name = 'Donald' AND email = '[email protected]' AND phone = '+1-2222222222'")
val sep_tag = tags.map((x) => { s"when array_contains(" + x._1 + ", country) then '" + x._1 + "' " }).mkString
val combine_sel_tag1 = tags.map((x) => { s" array_contains(" + x._1 + ",country) " }).mkString(" and ")
val combine_sel_tag2 = tags.map((x) => x._1).mkString(" '(", "|", ")' ")
val combine_sel_all = " case when " + combine_sel_tag1 + " then " + combine_sel_tag2 + sep_tag + " end as tags "
val crosqry = tags.map((x) => { s" cross join ( select collect_list(country) as " + x._1 + " from " + x._1 + "_countries) " + x._1 + " " }).mkString
val qry = " select name,email,phone,country, " + combine_sel_all + " from tagged_users " + crosqry
spark.sql(qry).show
spark.stop()
fails with the following exception, because the generated cross join assumes that every tag key has a matching <tag>_countries table, which does not exist for sometag:
Caused by: org.apache.spark.sql.catalyst.analysis.NoSuchTableException: Table or view 'sometag_countries' not found in database 'default';
at org.apache.spark.sql.catalyst.catalog.ExternalCatalog$class.requireTableExists(ExternalCatalog.scala:48)
at org.apache.spark.sql.catalyst.catalog.InMemoryCatalog.requireTableExists(InMemoryCatalog.scala:45)
at org.apache.spark.sql.catalyst.catalog.InMemoryCatalog.getTable(InMemoryCatalog.scala:326)
at org.apache.spark.sql.catalyst.catalog.ExternalCatalogWithListener.getTable(ExternalCatalogWithListener.scala:138)
at org.apache.spark.sql.catalyst.catalog.SessionCatalog.lookupRelation(SessionCatalog.scala:701)
at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveRelations$.org$apache$spark$sql$catalyst$analysis$Analyzer$ResolveRelations$$lookupTableFromCatalog(Analyzer.scala:730)
... 74 more
Upvotes: 1
Views: 905
Reputation: 8711
Check out this DF solution:
scala> val df = Seq(("Mike","[email protected]","+91-9999999999","Italy"),
| ("Alex","[email protected]","+91-9999999998","France"),
| ("John","[email protected]","+1-1111111111","United States"),
| ("Donald","[email protected]","+1-2222222222","United States"),
| ("Dan","[email protected]","+91-9999444999","Poland"),
| ("Scott","[email protected]","+91-9111999998","Spain"),
| ("Rob","[email protected]","+91-9114444998","Italy")
| ).toDF("name","email","phone","country")
df: org.apache.spark.sql.DataFrame = [name: string, email: string ... 2 more fields]
scala> val dfbc=Seq("Italy", "France", "United States", "Spain").toDF("country")
dfbc: org.apache.spark.sql.DataFrame = [country: string]
scala> val dfmc=Seq("Poland", "Hungary", "Spain").toDF("country")
dfmc: org.apache.spark.sql.DataFrame = [country: string]
scala> val dfbc2=dfbc.agg(collect_list('country).as("bcountry"))
dfbc2: org.apache.spark.sql.DataFrame = [bcountry: array<string>]
scala> val dfmc2=dfmc.agg(collect_list('country).as("mcountry"))
dfmc2: org.apache.spark.sql.DataFrame = [mcountry: array<string>]
scala> val df2=df.crossJoin(dfbc2).crossJoin(dfmc2)
df2: org.apache.spark.sql.DataFrame = [name: string, email: string ... 4 more fields]
scala> df2.selectExpr("*","case when array_contains(bcountry,country) and array_contains(mcountry,country) then '(big|medium)' when array_contains(bcountry,country) then 'big' when array_contains(mcountry,country) then 'medium' else 'none' end as `tags`").select("name","email","phone","country","tags").show(false)
+------+------------------+--------------+-------------+------------+
|name |email |phone |country |tags |
+------+------------------+--------------+-------------+------------+
|Mike |[email protected] |+91-9999999999|Italy |big |
|Alex |[email protected] |+91-9999999998|France |big |
|John |[email protected] |+1-1111111111 |United States|big |
|Donald|[email protected]|+1-2222222222 |United States|big |
|Dan |[email protected] |+91-9999444999|Poland |medium |
|Scott |[email protected] |+91-9111999998|Spain |(big|medium)|
|Rob |[email protected] |+91-9114444998|Italy |big |
+------+------------------+--------------+-------------+------------+
scala>
SQL approach
scala> Seq(("Mike","[email protected]","+91-9999999999","Italy"),
| ("Alex","[email protected]","+91-9999999998","France"),
| ("John","[email protected]","+1-1111111111","United States"),
| ("Donald","[email protected]","+1-2222222222","United States"),
| ("Dan","[email protected]","+91-9999444999","Poland"),
| ("Scott","[email protected]","+91-9111999998","Spain"),
| ("Rob","[email protected]","+91-9114444998","Italy")
| ).toDF("name","email","phone","country").createOrReplaceTempView("tagged_users")
scala> Seq("Italy", "France", "United States", "Spain").toDF("country").createOrReplaceTempView("big_countries")
scala> Seq("Poland", "Hungary", "Spain").toDF("country").createOrReplaceTempView("medium_countries")
scala> spark.sql(""" select name,email,phone,country,case when array_contains(bc,country) and array_contains(mc,country) then '(big|medium)' when array_contains(bc,country) then 'big' when array_contains(mc,country) then 'medium' else 'none' end as tags from tagged_users cross join ( select collect_list(country) as bc from big_countries ) b cross join ( select collect_list(country) as mc from medium_countries ) c """).show(false)
+------+------------------+--------------+-------------+------------+
|name |email |phone |country |tags |
+------+------------------+--------------+-------------+------------+
|Mike |[email protected] |+91-9999999999|Italy |big |
|Alex |[email protected] |+91-9999999998|France |big |
|John |[email protected] |+1-1111111111 |United States|big |
|Donald|[email protected]|+1-2222222222 |United States|big |
|Dan |[email protected] |+91-9999444999|Poland |medium |
|Scott |[email protected] |+91-9111999998|Spain |(big|medium)|
|Rob |[email protected] |+91-9114444998|Italy |big |
+------+------------------+--------------+-------------+------------+
scala>
Iterating through the tags
scala> val tags = Map(
| "big" -> "country IN (SELECT * FROM big_countries)",
| "medium" -> "country IN (SELECT * FROM medium_countries)")
tags: scala.collection.immutable.Map[String,String] = Map(big -> country IN (SELECT * FROM big_countries), medium -> country IN (SELECT * FROM medium_countries))
scala> val sep_tag = tags.map( (x) => { s"when array_contains("+x._1+", country) then '" + x._1 + "' " } ).mkString
sep_tag: String = "when array_contains(big, country) then 'big' when array_contains(medium, country) then 'medium' "
scala> val combine_sel_tag1 = tags.map( (x) => { s" array_contains("+x._1+",country) " } ).mkString(" and ")
combine_sel_tag1: String = " array_contains(big,country) and array_contains(medium,country) "
scala> val combine_sel_tag2 = tags.map( (x) => x._1 ).mkString(" '(","|", ")' ")
combine_sel_tag2: String = " '(big|medium)' "
scala> val combine_sel_all = " case when " + combine_sel_tag1 + " then " + combine_sel_tag2 + sep_tag + " end as tags "
combine_sel_all: String = " case when array_contains(big,country) and array_contains(medium,country) then '(big|medium)' when array_contains(big, country) then 'big' when array_contains(medium, country) then 'medium' end as tags "
scala> val crosqry = tags.map( (x) => { s" cross join ( select collect_list(country) as "+x._1+" from "+x._1+"_countries) "+ x._1 + " " } ).mkString
crosqry: String = " cross join ( select collect_list(country) as big from big_countries) big cross join ( select collect_list(country) as medium from medium_countries) medium "
scala> val qry = " select name,email,phone,country, " + combine_sel_all + " from tagged_users " + crosqry
qry: String = " select name,email,phone,country, case when array_contains(big,country) and array_contains(medium,country) then '(big|medium)' when array_contains(big, country) then 'big' when array_contains(medium, country) then 'medium' end as tags from tagged_users cross join ( select collect_list(country) as big from big_countries) big cross join ( select collect_list(country) as medium from medium_countries) medium "
scala> spark.sql(qry).show
+------+------------------+--------------+-------------+------------+
| name| email| phone| country| tags|
+------+------------------+--------------+-------------+------------+
| Mike| [email protected]|+91-9999999999| Italy| big|
| Alex| [email protected]|+91-9999999998| France| big|
| John| [email protected]| +1-1111111111|United States| big|
|Donald|[email protected]| +1-2222222222|United States| big|
| Dan| [email protected]|+91-9999444999| Poland| medium|
| Scott| [email protected]|+91-9111999998| Spain|(big|medium)|
| Rob| [email protected]|+91-9114444998| Italy| big|
+------+------------------+--------------+-------------+------------+
scala>
UPDATE2:
scala> Seq(("Mike","[email protected]","+91-9999999999","Italy"),
| ("Alex","[email protected]","+91-9999999998","France"),
| ("John","[email protected]","+1-1111111111","United States"),
| ("Donald","[email protected]","+1-2222222222","United States"),
| ("Dan","[email protected]","+91-9999444999","Poland"),
| ("Scott","[email protected]","+91-9111999998","Spain"),
| ("Rob","[email protected]","+91-9114444998","Italy")
| ).toDF("name","email","phone","country").createOrReplaceTempView("tagged_users")
scala> Seq("Italy", "France", "United States", "Spain").toDF("country").createOrReplaceTempView("big_countries")
scala> Seq("Poland", "Hungary", "Spain").toDF("country").createOrReplaceTempView("medium_countries")
scala> val tags = Map(
| "big" -> "country IN (SELECT * FROM big_countries)",
| "medium" -> "country IN (SELECT * FROM medium_countries)",
| "sometag" -> "name = 'Donald' AND email = '[email protected]' AND phone = '+1-2222222222'")
tags: scala.collection.immutable.Map[String,String] = Map(big -> country IN (SELECT * FROM big_countries), medium -> country IN (SELECT * FROM medium_countries), sometag -> name = 'Donald' AND email = '[email protected]' AND phone = '+1-2222222222')
scala> val sql_tags = tags.map( x => { val p = x._2.trim.toUpperCase.split(" ");
| val qry = if(p.contains("IN") && p.contains("FROM"))
| s" case when array_contains((select collect_list("+p.head +") from " + p.last.replaceAll("[)]","")+ " ), " +p.head + " ) then '" + x._1 + " ' else '' end " + x._1 + " "
| else
| " case when " + x._2 + " then '" + x._1 + " ' else '' end " + x._1 + " ";
| qry } ).mkString(",")
sql_tags: String = " case when array_contains((select collect_list(COUNTRY) from BIG_COUNTRIES ), COUNTRY ) then 'big ' else '' end big , case when array_contains((select collect_list(COUNTRY) from MEDIUM_COUNTRIES ), COUNTRY ) then 'medium ' else '' end medium , case when name = 'Donald' AND email = '[email protected]' AND phone = '+1-2222222222' then 'sometag ' else '' end sometag "
scala> val outer_query = tags.map( x=> x._1).mkString(" regexp_replace(trim(concat(", ",", " )),' ','|') tags ")
outer_query: String = " regexp_replace(trim(concat(big,medium,sometag )),' ','|') tags "
scala> spark.sql(" select name,email, country, " + outer_query + " from ( select name,email, country ," + sql_tags + " from tagged_users ) " ).show
+------+------------------+-------------+-----------+
| name| email| country| tags|
+------+------------------+-------------+-----------+
| Mike| [email protected]| Italy| big|
| Alex| [email protected]| France| big|
| John| [email protected]|United States| big|
|Donald|[email protected]|United States|big|sometag|
| Dan| [email protected]| Poland| medium|
| Scott| [email protected]| Spain| big|medium|
| Rob| [email protected]| Italy| big|
+------+------------------+-------------+-----------+
scala>
Upvotes: 1
Reputation: 21
If you need to aggregate the results rather than just execute each query, you could use map instead of foreach and then union the results:
val o = tags.map {
  case (tag, tagCondition) =>
    spark.sql(buildTagQuery(tag, tagCondition, "tagged_users"))
      .withColumn("tag", lit(tag)) // attach the tag name as a literal column
}

// fold starting from the head; folding over o itself would union the head with itself
o.tail.foldLeft(o.head) {
  case (acc, df) => acc.union(df)
}
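Note that the union produces one row per matching tag, so users like Scott and Donald appear on multiple rows. A minimal sketch of how those rows could then be collapsed into the pipe-joined format the question asks for, assuming the unioned DataFrame is bound to a name such as unioned (an illustrative name, not from the original answer):
import org.apache.spark.sql.functions._

val unioned = o.tail.foldLeft(o.head)(_ union _)

unioned
  .groupBy("name", "email", "phone", "country")
  .agg(concat_ws("|", collect_list("tag")).as("tag")) // e.g. Scott -> "big|medium"
  .show(false)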
Upvotes: 1
Reputation: 1868
I would define multiple tag tables with the columns value and tag.
Then your tags definition would be a collection, say Seq[(String, String)], where the first tuple element is the column on which the tag is calculated.
Let's say:
Seq(
  "country" -> "bigCountries", // Columns [country, bigCountry]
  "country" -> "mediumCountries", // Columns [country, mediumCountry]
  "email" -> "hotmailLosers" // Columns [email, hotmailLoser]
)
Then iterate through this list, left-joining each table on the relevant column.
After each join, set your tags column to the current value plus the joined column when it is not null, as sketched below.
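A minimal sketch of that loop as a foldLeft, assuming the input DataFrame is df, each dictionary table stores the join value in a column named value and the tag in a column named tag, and the table names follow the Seq above (all of these names are illustrative assumptions, not part of the original answer):
import org.apache.spark.sql.functions._

val tagTables = Seq(
  "country" -> "bigCountries",
  "country" -> "mediumCountries",
  "email" -> "hotmailLosers"
)

// start with an empty (null) tags column and accumulate it through each left join
val tagged = tagTables.foldLeft(df.withColumn("tags", lit(null).cast("string"))) {
  case (acc, (joinCol, tableName)) =>
    // rename the dictionary's value column so the left join lines up with joinCol
    val dict = spark.table(tableName).withColumnRenamed("value", joinCol)
    acc.join(dict, Seq(joinCol), "left")
      .withColumn("tags",
        when(col("tag").isNull, col("tags"))                     // no match: keep current tags
          .when(col("tags").isNull, col("tag"))                  // first matching tag
          .otherwise(concat(col("tags"), lit("|"), col("tag")))) // append with a separator
      .drop("tag")
}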
Upvotes: 0