Thomas R

Reputation: 1217

Filtering in arrays located in cells of a pyspark.sql.dataframe

I'm pretty new to PySpark but have some Python experience. I'm already able to filter rows of a DataFrame, and I have written UDFs that calculate an int or double result from arrays in DataFrame cells. Now I need an array as output, and after hours of searching I haven't found a useful example.

Here is the problem:

The DataFrame has the following schema, where number is the number of entries in each array of the same row:

DataFrame[number: int, code: array<string>, d1: array<double>, d2: array<double>]

Here is an example of the DataFrame called df1:

[4 ,['correct', 'correct', 'wrong', 'correct'], [33, 42, 35, 76], [12, 35, 15, 16]] 
[2 ,['correct', 'wrong'], [47, 43], [13, 17]] 

Now I want to keep the i-th entry of d1 and d2 (and of code itself) only if the i-th entry of the code column in that row is 'correct'. Additionally, I want a new column numberNew with the number of remaining positions. The resulting structure and DataFrame "df2" should look like this:

DataFrame[number: int, numberNew: int, code: array<string>, d1: array<double>, d2: array<double>]

[4 , 3, ['correct', 'correct', 'correct'], [33, 42, 76], [12, 35, 16]] 
[2 , 1, ['correct'], [47], [13]] 
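The per-row logic can be checked in plain Python before it is wrapped in a UDF. A minimal sketch, assuming each array in a row has `number` entries; `filter_row` is a hypothetical helper name, not part of PySpark:

```python
def filter_row(number, code, d1, d2):
    """Keep position i of code, d1 and d2 only where code[i] == 'correct'.

    Returns (number, numberNew, code, d1, d2), mirroring the df2 schema.
    """
    kept = [
        (c, x, y)
        for c, x, y in zip(code[:number], d1[:number], d2[:number])
        if c == 'correct'
    ]
    # Split the kept triples back into three separate lists.
    code_new = [c for c, _, _ in kept]
    d1_new = [x for _, x, _ in kept]
    d2_new = [y for _, _, y in kept]
    return (number, len(kept), code_new, d1_new, d2_new)


rows = [
    (4, ['correct', 'correct', 'wrong', 'correct'],
     [33.0, 42.0, 35.0, 76.0], [12.0, 35.0, 15.0, 16.0]),
    (2, ['correct', 'wrong'], [47.0, 43.0], [13.0, 17.0]),
]
for row in rows:
    print(filter_row(*row))
# (4, 3, ['correct', 'correct', 'correct'], [33.0, 42.0, 76.0], [12.0, 35.0, 16.0])
# (2, 1, ['correct'], [47.0], [13.0])
```

Returning a tuple rather than a pandas DataFrame is deliberate: a tuple can map onto a struct return type for a UDF, for example something like `struct<number: int, numberNew: int, code: array<string>, d1: array<double>, d2: array<double>>` (a schema string assumed here from the df2 schema above).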

Among several other things (and based on a solution that worked in plain Python), I tried the following code:

def filterDF(number, code, d1, d2):
    dataFiltered = []
    numberNew = 0
    for i in range(number):
        if code[i] == 'correct':
            dataFiltered.append([d1[i],d2[i]])
            countNew += 1
    newTable = {'countNew' : countNew, 'data' : dataFiltered}
    newDf = pd.DataFrame(newTable)
    return newDf    

from pyspark.sql.types import ArrayType
filterDFudf = sqlContext.udf.register("filterDF", filterDF, "Array<double>")

df2 = df1.select(df1.number, filterDFudf(df1.number, df1.code, df1.d1, df1.d2)).alias('dataNew')

I got a long and not very helpful error message. Among other things, it contained: TypeError: 'float' object has no attribute '__getitem__'

It would be fantastic if someone here could show me how to solve this.

Upvotes: 0

Views: 266

Answers (2)

Merelda

Reputation: 1338

As an alternative, you can also use Python list comprehensions in your function:

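from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, StringType

def get_filtered_data(code, d1, d2):
    # Positions whose code entry is exactly 'correct'
    indices = [i for i, s in enumerate(code) if s == 'correct']
    d1_ = [d1[index] for index in indices]
    d2_ = [d2[index] for index in indices]
    # [count, filtered d1, filtered d2]; the declared return type is
    # array<string>, so Spark stores each element as a string
    return [len(indices), d1_, d2_]

udf_get_filtered_data = udf(get_filtered_data, ArrayType(StringType()))

df = df.withColumn('filtered_data', udf_get_filtered_data('code', 'd1', 'd2'))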

df.show() prints the following:

+------+--------------------+----------------+----------------+--------------------+
|number|                code|              d1|              d2|       filtered_data|
+------+--------------------+----------------+----------------+--------------------+
|     4|[correct, correct...|[33, 42, 35, 76]|[12, 35, 15, 16]|[3, [33, 42, 76],...|
|     2|    [correct, wrong]|        [47, 43]|        [13, 17]|     [1, [47], [13]]|
+------+--------------------+----------------+----------------+--------------------+
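One detail worth checking in a function like this: Python's `in` operator on two strings is a substring test, so `'correct' in s` is also true for a value such as 'incorrect' (a hypothetical code value, not in the question's data). A small sketch comparing a substring test with an exact comparison:

```python
code = ['correct', 'incorrect', 'wrong']

# Substring test: 'incorrect' contains 'correct', so index 1 is kept too.
substring_hits = [i for i, s in enumerate(code) if 'correct' in s]

# Exact comparison keeps only the entries equal to 'correct'.
exact_hits = [i for i, s in enumerate(code) if s == 'correct']

print(substring_hits)  # [0, 1]
print(exact_hits)      # [0]
```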

By the way, if you use

dataFiltered.append([d1[i],d2[i]]) 

it will not give you the result you specified ([33, 42, 76], [12, 35, 16]). Instead, it gives you pairs: ([33, 12], [42, 35], [76, 16]).

The function in this answer returns the filtered d1 and d2 values as separate lists, as requested in the question.
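The difference is easy to see in plain Python. A sketch using the first example row: appending `[d1[i], d2[i]]` builds one list of pairs, and `zip(*pairs)` can transpose such pairs back into per-column lists:

```python
code = ['correct', 'correct', 'wrong', 'correct']
d1 = [33, 42, 35, 76]
d2 = [12, 35, 15, 16]

# Appending [d1[i], d2[i]] produces one pair per kept position.
pairs = [[d1[i], d2[i]] for i in range(len(code)) if code[i] == 'correct']
print(pairs)  # [[33, 12], [42, 35], [76, 16]]

# Transposing the pairs recovers the separate d1 and d2 lists.
d1_new, d2_new = (list(col) for col in zip(*pairs))
print(d1_new, d2_new)  # [33, 42, 76] [12, 35, 16]
```

If no position is kept, `zip(*pairs)` yields nothing and the two-way unpacking fails, so building the two lists directly from the kept indices, as `get_filtered_data` does, is the more robust choice.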

Upvotes: 1

user10485905

Reputation: 11

You cannot return a pandas DataFrame from a UDF like this (other UDF variants support that, but they don't match your logic), and the declared schema doesn't match the output anyway. Redefine your function like this:

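def filterDF(number, code, d1, d2):
    dataFiltered = []
    countNew = 0
    for i in range(number):
        if code[i] == 'correct':
            # One [d1, d2] pair per kept position
            dataFiltered.append([d1[i],d2[i]])
            countNew += 1
    # A plain tuple maps onto the struct return type declared below
    return (countNew, dataFiltered)

# long matches the integer test data below; use double for array<double> columns
filterDFudf = sqlContext.udf.register(
    "filterDF", filterDF, 
    "struct<countNew: long, data: array<array<long>>>"
)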

Test:

df = sqlContext.createDataFrame([
    (4 ,['correct', 'correct', 'wrong', 'correct'], [33, 42, 35, 76], [12, 35, 15, 16]),
    (2 ,['correct', 'wrong'], [47, 43], [13, 17])
]).toDF("number", "code", "d1", "d2")

df.select(filterDFudf("number", "code", "d1", "d2")).show()
# +------------------------------+                                                
# |filterDF(number, code, d1, d2)|
# +------------------------------+
# |          [3, [[33, 12], [4...|
# |               [1, [[47, 13]]]|
# +------------------------------+

Upvotes: 1
