StormPooper

Reputation: 513

PySpark - How to deal with list of lists as a column of a dataframe

My source data is a JSON file, and one of the fields is a list of lists (I generated the file with another Python script; the idea was to make a list of tuples, but the result was "converted" into a list of lists). I have a list of values, and for each of these values I want to filter my DataFrame to get all the rows whose list of lists contains that value. A simple example:

JSON row: {"id": "D1", "class": "WARRIOR", "archetype": "Pirate Warrior", "matches": 140000, "duration": 6.2, "turns": 7.5, "winrate": 58.0, "cards": [["DRG_024", 2], ["CS2_146", 1], ["EX1_409", 1]]}

value: "CS2_146"

expected result: all the rows containing "CS2_146" as the first element of one of the nested lists

Upvotes: 1

Views: 7515

Answers (2)

pardeep garg

Reputation: 219

You can use the array_contains function, but since you have a nested array you first need to use flatten to create a single array.

from pyspark.sql.functions import array_contains, flatten, col

a = {"id": "D1", "class": "WARRIOR", "archetype": "Pirate Warrior", "matches": 140000,
     "duration": 6.2, "turns": 7.5, "winrate": 58.0, "cards": [["DRG_024", 2],
     ["CS2_146", 1], ["EX1_409", 1]]}
df = spark.createDataFrame([a])

# flatten merges the nested lists into one flat array, and array_contains
# checks it for the value; "t" is a boolean column we can filter on directly
df.withColumn("t", array_contains(flatten("cards"), "CS2_146")).where(col("t")).show()

Upvotes: 1

notNull

Reputation: 31470

Since you have a nested array, we need to explode the arrays and then filter the records based on the index value.

Example:

df.printSchema()
#root
# |-- archetype: string (nullable = true)
# |-- cards: array (nullable = true)
# |    |-- element: array (containsNull = true)
# |    |    |-- element: string (containsNull = true)
# |-- class: string (nullable = true)
# |-- duration: double (nullable = true)
# |-- id: string (nullable = true)
# |-- matches: long (nullable = true)
# |-- turns: double (nullable = true)
# |-- winrate: double (nullable = true)

df.show(truncate=False)
#+--------------+------------------------------------------+-------+--------+---+-------+-----+-------+
#|archetype     |cards                                     |class  |duration|id |matches|turns|winrate|
#+--------------+------------------------------------------+-------+--------+---+-------+-----+-------+
#|Pirate Warrior|[[DRG_024, 2], [CS2_146, 1], [EX1_409, 1]]|WARRIOR|6.2     |D1 |140000 |7.5  |58.0   |
#+--------------+------------------------------------------+-------+--------+---+-------+-----+-------+

#first explode cards array then explode the nested array with position
#finally filter on pos=0 and cards_arr="CS2_146"

from pyspark.sql.functions import col

df.selectExpr("*", "explode(cards)") \
  .selectExpr("*", "posexplode(col) as (pos,cards_arr)") \
  .filter((col("pos") == 0) & (col("cards_arr") == "CS2_146")) \
  .show()
#+--------------+--------------------+-------+--------+---+-------+-----+-------+------------+---+---------+
#|     archetype|               cards|  class|duration| id|matches|turns|winrate|         col|pos|cards_arr|
#+--------------+--------------------+-------+--------+---+-------+-----+-------+------------+---+---------+
#|Pirate Warrior|[[DRG_024, 2], [C...|WARRIOR|     6.2| D1| 140000|  7.5|   58.0|[CS2_146, 1]|  0|  CS2_146|
#+--------------+--------------------+-------+--------+---+-------+-----+-------+------------+---+---------+

Upvotes: 2
