Reputation: 4582
I want to validate columns based on whether they contain even a single invalid entry. My constraint is to avoid shuffling and multiple scans so that this scales to petabytes.
I validated columns successfully using a plain string comparison, but I am unable to do the same with a regex. The problem statement I have is as follows:
| Column 1 | Column 2 | Column 3 | Column 4 | Column 5 |
| --------------| --------------| --------------| --------------| --------------|
|(123)-456-7890 | 123-456-7890 |(123)-456-789 | |(123)-456-7890 |
|(123)-456-7890 | 123-4567890 |(123)-456-7890 |(123)-456-7890 | null |
|(123)-456-7890 | 1234567890 |(123)-456-7890 |(123)-456-7890 | null |
The valid formats are:
(xxx)-xxx-xxxx, xxx-xxx-xxxx, xxx-xxxxxxx and xxxxxxxxxx
And so, the output for the above input should be:
| Column 1 | Column 2 |
| --------------| --------------|
|(123)-456-7890 | 123-456-7890 |
|(123)-456-7890 | 123-4567890 |
|(123)-456-7890 | 1234567890 |
My current code is as follows:
import regex as re
from pyspark.sql.functions import col, lit
from pyspark.sql.functions import sum as _sum
from pyspark.sql.functions import when
from pyspark.sql import Row
formats = [r'^(?:\(\d{3}\)-)\d{3}-\d{4}$',
           r'^(?:\d{3}-)\d{3}-\d{4}$', r'^(?:\d{3}-)\d{7}$', r'^\d{10}$']

def validate_format(number):
    length = len(number)
    if length == 14:
        if re.match(formats[0], number):
            return True
        return False
    if length == 12:
        if re.match(formats[1], number):
            return True
        return False
    if length == 11:
        if re.match(formats[2], number):
            return True
        return False
    if length == 10:
        if re.match(formats[3], number):
            return True
        return False
    return False

def create_dataframe(spark):
    my_cols = Row("Column1", "Column2", "Column3", "Column4")
    row_1 = my_cols('(617)-283-3811', 'Salah', 'Messi', None)
    row_2 = my_cols('617-2833811', 'Messi', 'Virgil', 'Messi')
    row_3 = my_cols('617-283-3811', 'Ronaldo', 'Messi', 'Ronaldo')
    row_seq = [row_1, row_2, row_3]
    df = spark.createDataFrame(row_seq)
    invalid_counts = invalid_counts_in_df(df)
    print(invalid_counts)

def invalid_counts_in_df(df):
    invalid_counts = df.select(
        *[_sum(when(validate_format(col(c)), lit(0)).otherwise(lit(1))).alias(c) for c in df.columns]).collect()
    return invalid_counts
When I was dealing with plain string comparisons this approach worked. Now, however, my function raises an error:
>>> create_dataframe(spark)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "<stdin>", line 8, in create_dataframe
File "<stdin>", line 3, in invalid_counts_in_df
File "<stdin>", line 3, in <listcomp>
File "<stdin>", line 2, in validate_format
TypeError: object of type 'Column' has no len()
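From what I can tell, the error occurs because col(c) returns a lazy Column expression rather than a Python string, so a plain Python function like len() can never work on it. A minimal reproduction (the column name is arbitrary):
from pyspark.sql.functions import col

c = col("Column1")
print(type(c))   # <class 'pyspark.sql.column.Column'>
len(c)           # TypeError: object of type 'Column' has no len()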
I am falling short of the appropriate method to validate or invalidate columns in the most efficient way; I understand that multiple scans and heavy shuffling are definitely not the way to go.
I expect to find a way to get the columns in which every entry matches one of the valid formats.
Upvotes: 2
Views: 1480
Reputation: 19385
In terms of performance you should always prefer the built-in PySpark functions over plain Python functions. The PySpark functions are optimized to utilize the resources of your cluster, and the data does not need to be converted to Python objects.
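For comparison, the Python route would be to wrap your validate_format in a udf. A sketch of what that would look like (it runs, but every value has to be serialized into a Python worker and matched there, which is exactly the overhead the built-in functions avoid):
import re
from pyspark.sql.functions import udf
from pyspark.sql.types import BooleanType

# Hypothetical UDF variant: each value is shipped row by row to a Python
# process and matched there, so this is correct but considerably slower.
phone_pattern = re.compile(r'^(\(\d{3}\)-\d{3}-\d{4}|\d{3}-\d{3}-\d{4}|\d{3}-\d{7}|\d{10})$')
is_valid = udf(lambda s: s is not None and bool(phone_pattern.match(s)), BooleanType())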
The appropriate PySpark function for your use case is rlike. Have a look at the example below:
from pyspark.sql import Row
my_cols = Row("Column1", "Column2", "Column3", "Column4")
row_1 = my_cols('(617)-283-3811', 'Salah', 'Messi', None)
row_2 = my_cols('617-2833811', 'Messi', 'Virgil', 'Messi')
row_3 = my_cols('617-283-3811', 'Ronaldo', 'Messi', 'Ronaldo')
row_seq = [row_1, row_2, row_3]
df = spark.createDataFrame(row_seq)
numberOfRows = df.count()
#I have simplified your regexes a bit because I don't see a reason
#why you need non-capturing groups. Note that the anchors have to wrap
#the whole alternation; otherwise ^ and $ only bind to the first and
#last option, and rlike would match substrings for the middle ones.
expr = r"^(\(\d{3}\)-\d{3}-\d{4}|\d{3}-\d{3}-\d{4}|\d{3}-\d{7}|\d{10})$"
#you can also set it to df.columns
columnsToCheck = ['Column1']
columnsToRemove = []
for col in columnsToCheck:
    numberOfMatchingRows = df.filter(df[col].rlike(expr)).count()
    if numberOfMatchingRows < numberOfRows:
        columnsToRemove.append(col)
df = df.select(*[c for c in df.columns if c not in columnsToRemove])
df.show()
Output:
+--------------+-------+-------+-------+
| Column1|Column2|Column3|Column4|
+--------------+-------+-------+-------+
|(617)-283-3811| Salah| Messi| null|
| 617-2833811| Messi| Virgil| Messi|
| 617-283-3811|Ronaldo| Messi|Ronaldo|
+--------------+-------+-------+-------+
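If you also want to avoid firing one count() job per checked column, you can replace the filter/count loop above by combining rlike with the conditional aggregation from your own invalid_counts_in_df, computing every column's invalid count in a single scan. A sketch, assuming all columns are string-typed and treating null as invalid:
from pyspark.sql.functions import col, sum as _sum, when

# One pass over the data: per column, add 1 for every entry that does not
# match the pattern (nulls fail the rlike condition and count as invalid).
counts = df.select(
    [_sum(when(col(c).rlike(expr), 0).otherwise(1)).alias(c) for c in df.columns]
).first()
validColumns = [c for c in df.columns if counts[c] == 0]
df.select(*validColumns).show()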
Upvotes: 3