Bagel912

Reputation: 331

pyspark: filtering and extracting a struct through an ArrayType column

I'm using pyspark 2.2 and have the following schema

root
 |-- col1: string (nullable = true)
 |-- col2: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- id: string (nullable = true)
 |    |    |-- metadata: map (nullable = true)
 |    |    |    |-- key: string
 |    |    |    |-- value: string (valueContainsNull = true)

and data

+----+----------------------------------------------+
|col1|col2                                          |
+----+----------------------------------------------+
|A   |[[id1, [k -> v1]], [id2, [k2 -> v5, k -> v2]]]|
|B   |[[id3, [k -> v3]], [id4, [k3 -> v6, k -> v4]]]|
+----+----------------------------------------------+

col2 is a complex structure: an array of structs, where every struct has two fields, an id string and a metadata map. (This is a simplified dataset; the real one has 10+ fields per struct and 10+ key-value pairs in the metadata map.)

I want to form a query that returns a dataframe matching my filtering logic (say col1 == 'A' and col2.id == 'id2' and col2.metadata.k == 'v2').

The result would look like this. The filtering logic can match at most one struct within the array, so the second column holds a single struct rather than an array of one struct:

+----+--------------------------+
|col1|col2_filtered             |
+----+--------------------------+
|A   |[id2, [k2 -> v5, k -> v2]]|
+----+--------------------------+

I know how to achieve this through explode, but the issue is that col2 normally has 100+ structs and at most one of them will match my filtering logic, so I don't think explode is a scalable solution (a sketch of that approach follows the setup code below).

Can someone tell me how to do this? Thanks in advance!

Below is the code block for setting things up.

from pyspark.sql.types import StructType, StructField, StringType, ArrayType, MapType

schema = StructType([
    StructField('col1', StringType(), True),
    StructField('col2', ArrayType(
        StructType([
            StructField('id', StringType(), True),
            StructField('metadata', MapType(StringType(), StringType()), True)
        ])
    ))
])

data = [
    ('A', [('id1', {'k': 'v1'}), ('id2', {'k': 'v2', 'k2': 'v5'})]),
    ('B', [('id3', {'k': 'v3'}), ('id4', {'k': 'v4', 'k3': 'v6'})])
]

df = spark.createDataFrame(data=data, schema=schema)
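
For reference, here is the explode-based version I'd like to avoid (a rough sketch; it works on this small example, but on the real data it expands each row into 100+ rows before filtering them back down):

from pyspark.sql import functions as F

df_exploded = (
    df.withColumn('s', F.explode('col2'))            # one row per struct in col2
      .filter((F.col('col1') == 'A') &
              (F.col('s.id') == 'id2') &
              (F.col('s.metadata')['k'] == 'v2'))
      .select('col1', F.col('s').alias('col2_filtered'))
)
df_exploded.show(truncate=False)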

Upvotes: 1

Views: 4887

Answers (2)

Bagel912

Reputation: 331

Besides the solution from @mck, after some searching I tried three other approaches that all produce the desired result.

1. Filter using a udf that returns the matching structs:

from pyspark.sql.functions import udf, expr
from pyspark.sql.types import ArrayType, IntegerType

df.filter(df.col1 == 'A') \
  .select(df.col1, udf(lambda a: [s for s in a if s.id == 'id2' and s.metadata['k'] == 'v2'], df.schema['col2'].dataType)('col2')[0].alias('col2_filtered')) \
  .na.drop('any')

2. Filter using a udf that returns the index of the matching struct:

df.filter(df.col1 == 'A') \
  .select(df.col1, df.col2.getItem(udf(lambda a: [i for i, s in enumerate(a) if s.id == 'id2' and s.metadata['k'] == 'v2'], ArrayType(IntegerType(), True))(df.col2)[0]).alias('col2_filtered')) \
  .na.drop('any')

3. Filter using the filter higher-order function in a SQL expression; this requires Spark 2.4+, so it's a candidate for after an upgrade:

df.filter(df.col1 == 'A') \
  .select(df.col1, expr("filter(col2, s -> s.id == 'id2' AND s.metadata['k'] == 'v2')").getItem(0).alias('col2_filtered')) \
  .na.drop('any')
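
For completeness, on Spark 3.1+ the same higher-order filter is exposed directly in the Python API. I haven't run this on 2.2, so treat it as a sketch:

import pyspark.sql.functions as F

df.filter(df.col1 == 'A') \
  .select(
      df.col1,
      # keep only the structs matching the predicate, then take the single match
      F.filter('col2', lambda s: (s['id'] == 'id2') & (s['metadata']['k'] == 'v2'))
       .getItem(0)
       .alias('col2_filtered')
  ) \
  .na.drop('any')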

Upvotes: 0

mck

Reputation: 42352

EDIT: you can try a UDF:

import pyspark.sql.functions as F

df2 = df.filter(
    F.udf(lambda x: any([y.id == 'id2' and 'k' in y.metadata.keys() for y in x]), 'boolean')('col2')
).withColumn(
    'col2',
    F.udf(lambda x: [y for y in x if y.id == 'id2' and 'k' in y.metadata.keys()][0], 'struct<id:string,metadata:map<string,string>>')('col2')
)

df2.show(truncate=False)
+----+--------------------------+
|col1|col2                      |
+----+--------------------------+
|A   |[id2, [k2 -> v5, k -> v2]]|
+----+--------------------------+

Alternatively, you can convert the columns to JSON and check whether col2's JSON contains the JSON of the desired struct:

import pyspark.sql.functions as F

df2 = df.filter(
    (F.col('col1') == 'A') &
    F.to_json('col2').contains(
        F.to_json(
            F.struct(
                F.lit('id2').alias('id'),
                F.create_map(F.lit('k'), F.lit('v2')).alias('metadata')
            )
        )
    )
)

df2.show(truncate=False)
+----+------------------------------------+
|col1|col2                                |
+----+------------------------------------+
|A   |[[id1, [k -> v1]], [id2, [k -> v2]]]|
+----+------------------------------------+

If you just want to keep the matching struct in col2, you can replace it using withColumn:

df3 = df2.withColumn(
    'col2', 
    F.struct(
        F.lit('id2').alias('id'),
        F.create_map(F.lit('k'), F.lit('v2')).alias('metadata')
    )
)

df3.show()
+----+----------------+
|col1|            col2|
+----+----------------+
|   A|[id2, [k -> v2]]|
+----+----------------+

Upvotes: 2
