Whitewater

Reputation: 317

Extract an item by name from a MapType() column in a PySpark DataFrame

I have a PySpark DataFrame structured like this, where each map entry is keyed by a number and its value is a "key: value" string. Most, but not all, of the time the entries follow the order city / state / zip; sometimes the order is different.

data = [
    ("Item A", "2024-12-01", {"1": "city: Palo Alto", "2": "state: CA", "3": "zip: 94301"}),
    ("Item B", "2024-12-02", {"1": "state: NY", "2": "city: New York", "3": "zip: 10001"}),
    ("Item B", "2024-12-03", {"1": "city: Austin", "2": "state: TX", "3": "zip: 73301"})
]

from pyspark.sql import functions as F
from pyspark.sql.types import StructType, StructField, StringType, MapType

schema = StructType([
    StructField("item", StringType(), True),
    StructField("date", StringType(), True),
    StructField("geo_data", MapType(StringType(), StringType()), True)
])

sample_df = spark.createDataFrame(data, schema)

I want to extract the "state" value from the geo_data column into a new column using .withColumn, but I'm running into an issue because the map in the geo_data column is not ordered consistently (state will not always appear under key "1" or "2", and in some rows there is no "state" entry at all). This means I cannot use something simple like:

.withColumn("state_code", F.expr("geo_data[1]"))

I've also tried getItem() and getField(), for example:

new_data = sample_df.withColumn(
    "state_code", 
    F.col("geo_data").getField("state")
)

Each attempt results in this error:

[DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve "geo_data[State]" due to data type mismatch: Parameter 2 requires the "INTEGRAL" type, however "State" has the type "STRING". SQLSTATE: 42K09

Looking for a better way to do this. Thanks.

Upvotes: 0

Views: 58

Answers (2)

lihao

Reputation: 781

You can use the filter function on map_values to find the item beginning with the string state:. Based on your description, this returns an array with 1 or 0 items; take the first one and then use a string function (substring_index, right, regexp_replace, split, etc.) to retrieve the state.

sample_df.withColumn('state_code', F.expr("""
    substring_index(
        filter(map_values(geo_data), x -> x rlike '^state: ')[0],
        ': ',
        -1
    )
""")).show()
+------+----------+--------------------+----------+
|  item|      date|            geo_data|state_code|
+------+----------+--------------------+----------+
|Item A|2024-12-01|{1 -> city: Palo ...|        CA|
|Item B|2024-12-02|{1 -> state: NY, ...|        NY|
|Item B|2024-12-03|{1 -> city: Austi...|        TX|
+------+----------+--------------------+----------+
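If you prefer the Python column API over a SQL string, the same logic can be written with F.filter and F.substring_index. A minimal sketch, assuming Spark 3.1+ (where F.filter accepts a Python lambda):

from pyspark.sql import functions as F

# Keep only the map values that start with 'state: ', take the first
# match (null when there is none, assuming ANSI mode is off), and
# strip everything up to the last ': ' separator.
state_values = F.filter(F.map_values("geo_data"), lambda x: x.rlike("^state: "))

sample_df.withColumn(
    "state_code",
    F.substring_index(state_values.getItem(0), ": ", -1)
).show()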

Upvotes: 0

techtech

Reputation: 190

You cannot access the key 'state' because in your test data the keys in geo_data are "1", "2" and "3". You need to transform the values in geo_data first.

from pyspark.sql import functions as F
from pyspark.sql.functions import udf
from pyspark.sql.types import MapType, StringType

@udf(MapType(StringType(), StringType()))
def split_list(list_):
    # Split each "key: value" string on the first colon only,
    # so values that themselves contain ':' stay intact.
    return {s.split(':', 1)[0]: s.split(':', 1)[1].strip() for s in list_}


display(sample_df.withColumn(
    "state_code",
    split_list(F.map_values("geo_data")).getItem("state")
))
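A note on the design choice: the same remapping can be done without a Python UDF, which avoids serializing every row to a Python worker. A minimal sketch using the built-in map_from_arrays and transform functions (Spark 3.1+), assuming every value contains the ': ' separator:

from pyspark.sql import functions as F

# Build a new map whose keys are the text before ': ' and whose
# values are the text after it, then look up 'state' directly
# (returns null for rows with no 'state' entry).
keys = F.transform(F.map_values("geo_data"), lambda v: F.substring_index(v, ": ", 1))
vals = F.transform(F.map_values("geo_data"), lambda v: F.substring_index(v, ": ", -1))

sample_df.withColumn(
    "state_code",
    F.map_from_arrays(keys, vals).getItem("state")
).show()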

Upvotes: 0
