Reputation: 65853
I have a Spark DataFrame that contains multiple columns with free text. Separately, I have a dictionary of regular expressions where each regex maps to a key.
For instance:
import re
from pyspark.sql import Row

df = spark.sparkContext.parallelize([
    Row(**{'primary_loc': 'USA', 'description': 'PyCon happens annually in the United States, with satellite events in India, Brazil and Tokyo'}),
    Row(**{'primary_loc': 'Canada', 'description': 'The annual hockey championship has some events occurring in the US'})
]).toDF()

keywords = {'united states': re.compile(r'\b(usa|us|united states|texas|washington|new york)\b', re.I),
            'india': re.compile(r'\b(india|bangalore|mumbai|delhi)\b', re.I),
            'canada': re.compile(r'\b(canada|winnipeg|toronto|ontario|vancouver)\b', re.I),
            'japan': re.compile(r'\b(japan|tokyo|kyoto)\b', re.I)}
I want to be able to extract countries from the dataframe, such that I extract all countries from a set of columns (primary_loc and description in this case). So in this case, I'd get an output somewhat like:
primary_loc | description | country
--------------------------------------------
USA | PyCon... | united states
USA | PyCon... | india
USA | PyCon... | brazil
USA | PyCon... | japan
Canada | The ann... | canada
Canada | The ann... | united states
To get an idea of the scale of the problem, I have around 12-15k regexes and a dataframe with around 90 million rows.
I've tried using a Python UDF that looks somewhat like:
def get_countries(row):
    rd = row.asDict()
    rows_out = []
    for p, k in keywords.items():
        if k.search(rd['PRIMARY_LOC']) or k.search(rd['DESCRIPTION']):
            rows_out.append(Row(**{'product': p, **rd}))
    return rows_out
newDF = df.rdd.flatMap(lambda row: get_countries(row)).toDF()
but this is excruciatingly slow, even when operating on a subset of 10k or so rows.
If it matters, I'm using PySpark via DataBricks on Azure.
Upvotes: 1
Views: 1982
Reputation: 65853
For reference, I ended up solving the problem with a variant of Paul's answer: I built an Aho-Corasick automaton using pyahocorasick, pre-creating the dictionary of keywords and a reverse-lookup data structure. Since the Aho-Corasick algorithm doesn't deal with word boundaries, I still apply the corresponding regexes to any matches - but at least with my dataset, only a few (single-digit, typically) of the 10k regexes will produce a candidate match, and this approach allows me to restrict the regex work to only those. My run-time for this problem went from 360,000 core-minutes (so 6,000 hours on a single core) to around 100 core-minutes with this approach.
So:
import ahocorasick
import re

from pyspark.sql import functions as F
from pyspark.sql.types import ArrayType, StringType

def build_ahocorasick_from_keywords(kw_dict):
    '''Build an automaton for searching for keywords in a haystack; also build an
    inverted dictionary of keyword -> locations and return both.'''
    automaton = ahocorasick.Automaton()
    inverted = {}
    cnt = 0
    for location, keyword_string in kw_dict.items():
        keywords = [_.lower() for _ in keyword_string.split(',') if _.strip()]
        for kw in keywords:
            automaton.add_word(kw, (cnt, kw))
            cnt += 1
            if kw in inverted:
                inverted[kw].append(location)
            else:
                inverted[kw] = [location]
    automaton.make_automaton()
    return automaton, inverted
def get_locations(description, automaton, inverted_dict):
    description = description or ''
    haystack = description.lower().strip()
    locations = set()
    # automaton.iter yields (end_index, (insert_order, matched_word)) for every hit
    for _, (__, word) in automaton.iter(haystack):
        # Aho-Corasick has no notion of word boundaries, so confirm with a regex
        temp_re = r'\b{}\b'.format(re.escape(word))
        if re.search(temp_re, haystack):
            locations.update(inverted_dict[word])
    return list(locations) if locations else None
# kw_dict looks like {'united states': "usa,us,texas,washington,new york,united states", ...}
automaton, inverted = build_ahocorasick_from_keywords(kw_dict)
my_udf = F.udf(lambda description: get_locations(description, automaton, inverted), ArrayType(StringType()))
new_df = df.withColumn('locations', my_udf(df.description))
# save new_df etc
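As a quick local sanity check (sample kw_dict assumed from the comment above; the result order may vary since a set is built internally):

kw_dict = {'united states': "usa,us,texas,washington,new york,united states",
           'india': "india,bangalore,mumbai,delhi",
           'japan': "japan,tokyo,kyoto"}
automaton, inverted = build_ahocorasick_from_keywords(kw_dict)
desc = "PyCon happens annually in the United States, with satellite events in India, Brazil and Tokyo"
print(get_locations(desc, automaton, inverted))
# e.g. ['united states', 'india', 'japan'] (order not guaranteed)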
Upvotes: 1
Reputation: 3824
As suggested by @mck, you can perform the regexp matching using the native API with a join strategy; I use UDFs only as a last resort. The trick uses regexp_replace from the Scala API, which accepts input patterns from Columns. The function replaces the matched characters with an asterisk (it could be any char not present in your description column!), then contains checks for the asterisk, turning the match into a boolean join condition.
Here is the example:
val df_data = Seq(
  ("USA", "PyCon happens annually in the United States, with satellite events in India, Brazil and Tokyo"),
  ("Canada", "The annual hockey championship has some events occurring in the US")
).toDF("primary_loc", "description")

val df_keywords = Seq(
  ("united states", "(?i)\\b(usa|us|united states|texas|washington|new york)\\b"),
  ("india", "(?i)\\b(india|bangalore|mumbai|delhi)\\b"),
  ("canada", "(?i)\\b(canada|winnipeg|toronto|ontario|vancouver)\\b"),
  ("japan", "(?i)\\b(japan|tokyo|kyoto)\\b"),
  ("brazil", "(?i)\\b(brazil)\\b"),
  ("spain", "(?i)\\b(spain|es|barcelona)\\b")
).toDF("country", "pattern")

df_data.join(df_keywords,
    regexp_replace(df_data("description"), df_keywords("pattern"), lit("*")).contains("*"), "inner")
  .show(truncate = false)
Result:
+-----------+---------------------------------------------------------------------------------------------+-------------+--------------------------------------------------------+
|primary_loc|description |country |pattern |
+-----------+---------------------------------------------------------------------------------------------+-------------+--------------------------------------------------------+
|USA |PyCon happens annually in the United States, with satellite events in India, Brazil and Tokyo|united states|(?i)\b(usa|us|united states|texas|washington|new york)\b|
|Canada |The annual hockey championship has some events occurring in the US |united states|(?i)\b(usa|us|united states|texas|washington|new york)\b|
|USA |PyCon happens annually in the United States, with satellite events in India, Brazil and Tokyo|india |(?i)\b(india|bangalore|mumbai|delhi)\b |
|USA |PyCon happens annually in the United States, with satellite events in India, Brazil and Tokyo|japan |(?i)\b(japan|tokyo|kyoto)\b |
|USA |PyCon happens annually in the United States, with satellite events in India, Brazil and Tokyo|brazil |(?i)\b(brazil)\b |
+-----------+---------------------------------------------------------------------------------------------+-------------+--------------------------------------------------------+
Unfortunately, I could not make it work using the Python API; it returns a TypeError: Column is not iterable. It looks like the input patterns can only be strings there. The patterns were also prefixed with (?i) to make them case insensitive. Also make sure df_keywords is broadcast to all workers. The explain output is:
== Physical Plan ==
BroadcastNestedLoopJoin BuildLeft, Inner, Contains(regexp_replace(description#307, pattern#400, *), *)
:- BroadcastExchange IdentityBroadcastMode
: +- LocalTableScan [primary_loc#306, description#307]
+- LocalTableScan [country#399, pattern#400]
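If you need to stay in PySpark, one possible workaround (a sketch, assuming Spark SQL's regexp_replace accepts a column-valued pattern at runtime, and that df_data and df_keywords exist as PySpark DataFrames with the same schemas as above) is to build the condition with expr, sidestepping the Python function signature that only takes string patterns:

from pyspark.sql import functions as F

# Same join condition, built in SQL where both description and pattern are columns
cond = F.expr("regexp_replace(description, pattern, '*')").contains("*")
new_df = df_data.join(F.broadcast(df_keywords), cond, "inner")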
Upvotes: 1
Reputation: 1174
Since you seem to only want to match exact words, regexes are far more expensive than simply looking the words up. Assuming you only need to match whole words and not a complicated regular expression (e.g. numbers etc.), you can split the description into words and perform a lookup. If the words are stored in sets, each lookup is O(1).
Code would look something like this:
from pyspark.sql import Row

single_keywords = {'united states': {"usa", "us", "texas", "washington", "new york"},
                   'india': {"india", "bangalore", "mumbai", "delhi"},
                   'canada': {"canada", "winnipeg", "toronto", "ontario", "vancouver"},
                   'japan': {"japan", "tokyo", "kyoto"},
                   }
multiword_keywords = {"united states": {("united", "states")}}

def get_countries(row):
    rd = row.asDict()
    rows_out = []
    # Lowercase and use a set so each membership test is O(1);
    # note that punctuation still sticks to words and may need stripping on real data
    words = set((rd['PRIMARY_LOC'] + " " + rd['DESCRIPTION']).lower().split())
    for p, k in single_keywords.items():
        if any(word in k for word in words):
            rows_out.append(Row(**{'product': p, **rd}))
    for p, k in multiword_keywords.items():
        # Match when every part of some multi-word keyword appears among the words
        if any(all(word in words for word in t) for t in k):
            rows_out.append(Row(**{'product': p, **rd}))
    return rows_out
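Wired into Spark the same way as the flatMap in the question (a sketch, assuming the same df and column names as above):

new_df = df.rdd.flatMap(get_countries).toDF()

# Quick local check on a single Row (hypothetical sample values):
sample = Row(PRIMARY_LOC='USA', DESCRIPTION='PyCon events in India and Tokyo')
print([r['product'] for r in get_countries(sample)])
# ['united states', 'india', 'japan']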
Upvotes: 2