Aviral Srivastava


How to validate (and drop) a column based on a regex condition in PySpark without multiple scanning and shuffling?

I want to validate each column based on whether it contains even a single invalid entry, and drop the column if it does. My constraint is to avoid shuffling and multiple scans so that the solution scales to petabytes.

I tried validating columns using a plain string comparison and it worked, but I can't get it to work with a regex. The problem statement is as follows:


| Column 1      | Column 2      | Column 3      | Column 4      | Column 5      |
| --------------| --------------| --------------| --------------| --------------|
|(123)-456-7890 | 123-456-7890  |(123)-456-789  |               |(123)-456-7890 |
|(123)-456-7890 | 123-4567890   |(123)-456-7890 |(123)-456-7890 | null          |
|(123)-456-7890 | 1234567890    |(123)-456-7890 |(123)-456-7890 | null          |

The valid formats are:

`(xxx)-xxx-xxxx`, `xxx-xxx-xxxx`, `xxx-xxxxxxx` and `xxxxxxxxxx`, where each `x` is a digit.
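As a quick reference for these rules, the four formats can be folded into one anchored regular expression. The sketch below is plain Python (`re`), not Spark, and only checks the rules against the sample table; the names `PHONE_PATTERN` and `is_valid_phone` are illustrative, and counting null and empty values as invalid is an assumption read off the expected output below.

```python
import re
from typing import Optional

# The four accepted layouts, where each x is a digit:
#   (xxx)-xxx-xxxx | xxx-xxx-xxxx | xxx-xxxxxxx | xxxxxxxxxx
# The (?:...) group makes ^ and $ apply to every alternative.
# re.ASCII restricts \d to [0-9], like Java's default \d.
PHONE_PATTERN = re.compile(
    r"^(?:\(\d{3}\)-\d{3}-\d{4}|\d{3}-\d{3}-\d{4}|\d{3}-\d{7}|\d{10})$",
    re.ASCII,
)


def is_valid_phone(value: Optional[str]) -> bool:
    """Return True only for a non-null string in one of the four formats."""
    # Assumption: null counts as invalid, matching the expected output.
    if value is None:
        return False
    # fullmatch also rejects a trailing newline, which a bare $ would allow.
    return PHONE_PATTERN.fullmatch(value) is not None


# The question's sample table, one list per column (blank cell read as "").
table = {
    "Column 1": ["(123)-456-7890", "(123)-456-7890", "(123)-456-7890"],
    "Column 2": ["123-456-7890", "123-4567890", "1234567890"],
    "Column 3": ["(123)-456-789", "(123)-456-7890", "(123)-456-7890"],
    "Column 4": ["", "(123)-456-7890", "(123)-456-7890"],
    "Column 5": ["(123)-456-7890", None, None],
}

# A column is kept only if every one of its entries is valid.
valid_columns = [
    name for name, values in table.items() if all(map(is_valid_phone, values))
]
print(valid_columns)  # ['Column 1', 'Column 2']
```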

So, counting empty and null entries as invalid, the output for the above input should be:

| Column 1      | Column 2      |
| --------------| --------------| 
|(123)-456-7890 | 123-456-7890  |
|(123)-456-7890 | 123-4567890   |
|(123)-456-7890 | 1234567890    |

My current code is as follows:

```python
import regex as re
from pyspark.sql.functions import col, lit
from pyspark.sql.functions import sum as _sum
from pyspark.sql.functions import when
from pyspark.sql import Row

formats = [r'^(?:\(\d{3}\)-)\d{3}-\d{4}$',
           r'^(?:\d{3}-)\d{3}-\d{4}$', r'^(?:\d{3}-)\d{7}$', r'^\d{10}$']


def validate_format(number):
    length = len(number)
    if length == 14:
        if (re.match(formats[0], number)):
            return True
        return False
    if length == 12:
        if (re.match(formats[1], number)):
            return True
        return False
    if length == 11:
        if (re.match(formats[2], number)):
            return True
        return False
    if length == 10:
        if (re.match(formats[3], number)):
            return True
        return False
    return False


def create_dataframe(spark):
    my_cols = Row("Column1", "Column2", "Column3", "Column4")
    row_1 = my_cols('(617)-283-3811', 'Salah', 'Messi', None)
    row_2 = my_cols('617-2833811', 'Messi', 'Virgil', 'Messi')
    row_3 = my_cols('617-283-3811', 'Ronaldo', 'Messi', 'Ronaldo')
    row_seq = [row_1, row_2, row_3]
    df = spark.createDataFrame(row_seq)
    invalid_counts = invalid_counts_in_df(df)
    print(invalid_counts)


def invalid_counts_in_df(df):
    invalid_counts = df.select(
        *[_sum(when(validate_format(col(c)), lit(0)).otherwise(lit(1))).alias(c) for c in df.columns]).collect()
    return invalid_counts
```

When I was dealing with plain string comparisons (as here), it worked. However, my function now fails with this error:

```
>>> create_dataframe(spark)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<stdin>", line 8, in create_dataframe
  File "<stdin>", line 3, in invalid_counts_in_df
  File "<stdin>", line 3, in <listcomp>
  File "<stdin>", line 2, in validate_format
TypeError: object of type 'Column' has no len()
```

I can't work out the appropriate method for validating (or invalidating) columns in the most efficient way. I understand that multiple scans and heavy shuffling are definitely not the way to go.

I'm looking for a way to get the columns in which every entry has a valid format.


Answers (1)

cronoik


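In terms of performance, you should use the built-in PySpark functions instead of Python functions wherever possible. The built-in functions are optimized to use the resources of your cluster, and the data doesn't need to be converted to Python objects. This is also why your code fails: `validate_format(col(c))` is an ordinary Python call that runs once on the driver while the query is being built, so `number` is a `Column` expression rather than a string, and `len()` is not defined for a `Column`.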

The appropriate PySpark function for your use case is `rlike`, a `Column` method that tests each value against a regular expression. Have a look at the example below:
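One property of `rlike` matters for the pattern in that example: it returns true when the Java regex is found anywhere in the value (find semantics, not a full match), so `^` and `$` are the only thing enforcing a whole-value match, and they must apply to every alternative. Python's `re.search` has the same find-anywhere behaviour and, for the constructs used here, the same syntax, so this sketch uses it to compare an ungrouped alternation with a grouped one. `rlike_like` is a hypothetical stand-in, not a Spark API.

```python
import re

# Ungrouped: ^ binds only to the first alternative and $ only to the last,
# so every alternative is free to match just part of the value.
UNGROUPED = r"^(\(\d{3}\)-\d{3}-\d{4})|(\d{3}-\d{3}-\d{4})|(\d{3}-\d{7})|(\d{10})$"
# Grouped: the anchors wrap the whole alternation.
GROUPED = r"^(?:\(\d{3}\)-\d{3}-\d{4}|\d{3}-\d{3}-\d{4}|\d{3}-\d{7}|\d{10})$"


def rlike_like(pattern: str, value: str) -> bool:
    """Hypothetical stand-in for Column.rlike: true if the pattern is found anywhere."""
    return re.search(pattern, value) is not None


for value in ["123-456-7890", "x123-456-7890", "123-456-78901", "(123)-456-789"]:
    print(
        f"{value!r:18} ungrouped={rlike_like(UNGROUPED, value)} "
        f"grouped={rlike_like(GROUPED, value)}"
    )
```

With the ungrouped pattern, `x123-456-7890` and `123-456-78901` both count as valid; the grouped pattern rejects them while still accepting the four legitimate formats.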

```python
from pyspark.sql import Row, SparkSession

#in the pyspark shell `spark` already exists; getOrCreate() returns it
spark = SparkSession.builder.getOrCreate()

my_cols = Row("Column1", "Column2", "Column3", "Column4")
row_1 = my_cols('(617)-283-3811', 'Salah', 'Messi', None)
row_2 = my_cols('617-2833811', 'Messi', 'Virgil', 'Messi')
row_3 = my_cols('617-283-3811', 'Ronaldo', 'Messi', 'Ronaldo')
row_seq = [row_1, row_2, row_3]

df = spark.createDataFrame(row_seq)

numberOfRows = df.count()

#I have merged your four regexes into one alternation. The alternation
#must be wrapped in a group so that ^ and $ anchor every alternative, and
#a raw string keeps Python from treating the backslashes as escapes
expr = r"^(?:\(\d{3}\)-\d{3}-\d{4}|\d{3}-\d{3}-\d{4}|\d{3}-\d{7}|\d{10})$"

#you can also set it to df.columns
columnsToCheck = ['Column1']
columnsToRemove = []

for colName in columnsToCheck:
    #rlike is evaluated by Spark; nulls never match, so they count as invalid
    numberOfMatchingRows = df.filter(df[colName].rlike(expr)).count()
    if numberOfMatchingRows < numberOfRows:
        columnsToRemove.append(colName)

df = df.select(*[c for c in df.columns if c not in columnsToRemove])
df.show()
```

Output:

```
+--------------+-------+-------+-------+
|       Column1|Column2|Column3|Column4|
+--------------+-------+-------+-------+
|(617)-283-3811|  Salah|  Messi|   null|
|   617-2833811|  Messi| Virgil|  Messi|
|  617-283-3811|Ronaldo|  Messi|Ronaldo|
+--------------+-------+-------+-------+
```

