Haterind

Reputation: 1445

How to index every element in an array of arrays?

I have an ArrayType column where every element is itself an array of exactly 2 elements.

from pyspark.sql import SparkSession

data = [
  {"u": ["apple", 23]},
  {"u": ["banana", 12]}
]

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(data)

df.show()
+------------+
|           u|
+------------+
| [apple, 23]|
|[banana, 12]|
+------------+

I want to replace each inner array with its first element. Had I been writing vanilla Python, it would be:

result = [ar[0] for ar in array_of_arrays]

With Spark, I can use a UDF:

from pyspark.sql import functions as f, types as t

fn = f.udf(lambda u: u[0], t.StringType())
new_df = df.select(fn(f.col("u")))

new_df.show()
+-----------+
|<lambda>(u)|
+-----------+
|      apple|
|     banana|
+-----------+

Which is the output I want. But how can I do this with native PySpark, without a UDF?
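
As an aside on the UDF version above: the `<lambda>(u)` header comes from the anonymous function, and the output column can be renamed with `alias` (same `fn` and `df` as above; the name `first_elem` is arbitrary):

# give the UDF output a readable name instead of <lambda>(u)
new_df = df.select(fn(f.col("u")).alias("first_elem"))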

Upvotes: 0

Views: 114

Answers (1)

wwnde

Reputation: 26676

from pyspark.sql.functions import col

# index the array column directly; no UDF needed
df.withColumn('u_1', col('u')[0]).show()

+------------+------+
|           u|   u_1|
+------------+------+
| [apple, 23]| apple|
|[banana, 12]|banana|
+------------+------+
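
For reference, `col('u')[0]` is shorthand for `col('u').getItem(0)`, and `element_at(col('u'), 1)` (1-based) returns the same value. And if the column really were an array of arrays, as the title suggests, the built-in `transform` (in the Python API since Spark 3.1) can index every inner array without a UDF. A sketch, assuming a hypothetical `nested` array-of-arrays column:

from pyspark.sql import functions as f

# apply x[0] to every inner array of the (hypothetical) nested column
df.withColumn('firsts', f.transform(f.col('nested'), lambda x: x[0])).show()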

Upvotes: 1
