Reputation: 179
I have a 'text' column that stores arrays of tokens. How can I filter each array so that only tokens at least three letters long remain?
from pyspark.sql.functions import regexp_replace, col
from pyspark.sql.session import SparkSession
spark = SparkSession.builder.getOrCreate()
columns = ['id', 'text']
vals = [
(1, ['I', 'am', 'good']),
(2, ['You', 'are', 'ok']),
]
df = spark.createDataFrame(vals, columns)
df.show()
# I tried this, but it raises TypeError: Column is not iterable
# df_clean = df.select('id', regexp_replace('text', [len(word) >= 3 for word
# in col('text')], ''))
# df_clean.show()
I expect to see:
id | text
1 | [good]
2 | [You, are]
Upvotes: 2
Views: 2254
Reputation: 179
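This is the solution I ended up using, a UDF that keeps only tokens of at least three characters (df_stemmed and 'words' are the DataFrame and column names from my own code):
from pyspark.sql.functions import udf, col
from pyspark.sql.types import ArrayType, StringType
filter_length_udf = udf(lambda row: [x for x in row if len(x) >= 3], ArrayType(StringType()))
df_final_words = df_stemmed.withColumn('words_filtered', filter_length_udf(col('words')))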
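One caveat with the lambda above: it raises a TypeError if a row's array is null, or if an array contains a null token. Below is a minimal sketch of a null-safe variant. The names keep_long_tokens and with_filtered_tokens, and the default column names, are hypothetical and only for illustration. The Spark imports are done lazily so the plain-Python helper can be checked without a cluster.

```python
def keep_long_tokens(tokens, min_len=3):
    """Return the tokens that have at least min_len characters.

    A null (None) array is passed through as None, and null tokens
    inside an array are dropped.
    """
    if tokens is None:
        return None
    return [t for t in tokens if t is not None and len(t) >= min_len]


def with_filtered_tokens(df, src="text", dst="text_filtered", min_len=3):
    """Add column dst to df holding the filtered copy of array column src.

    Assumes pyspark is installed; src and dst are illustrative defaults.
    """
    # Imported here so keep_long_tokens stays usable without Spark.
    from pyspark.sql.functions import udf, col
    from pyspark.sql.types import ArrayType, StringType

    filter_udf = udf(
        lambda tokens: keep_long_tokens(tokens, min_len),
        ArrayType(StringType()),
    )
    return df.withColumn(dst, filter_udf(col(src)))
```

With the question's data, with_filtered_tokens(df) should add a text_filtered column next to the original text column. Keep in mind that a Python UDF ships every row through a Python worker, so the built-in approach is usually the faster choice when it is available.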
Upvotes: 0
Reputation: 18098
This does it with the built-in filter higher-order function (Spark 2.4+), so no UDF is needed. Whether to drop rows whose array ends up empty is up to you; I added an extra column and filtered those rows out:
from pyspark.sql import functions as f
columns = ['id', 'text']
vals = [
(1, ['I', 'am', 'good']),
(2, ['You', 'are', 'ok']),
(3, ['ok'])
]
df = spark.createDataFrame(vals, columns)
#df.show()
df2 = df.withColumn("text_left_over", f.expr("filter(text, x -> not(length(x) < 3))"))
df2.show()
# This is the actual piece of logic you are looking for.
df3 = df.withColumn("text_left_over", f.expr("filter(text, x -> not(length(x) < 3))")).where(f.size(f.col("text_left_over")) > 0).drop("text")
df3.show()
returns:
+---+--------------+--------------+
| id|          text|text_left_over|
+---+--------------+--------------+
|  1| [I, am, good]|        [good]|
|  2|[You, are, ok]|    [You, are]|
|  3|          [ok]|            []|
+---+--------------+--------------+
+---+--------------+
| id|text_left_over|
+---+--------------+
|  1|        [good]|
|  2|    [You, are]|
+---+--------------+
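If you want the result in the original 'text' column (as in the question's expected output) and a configurable minimum length, the filter expression above can be generated from parameters. This is only a sketch: min_length_filter_expr and filter_tokens_in_place are hypothetical helper names, and the condition is written as length(x) >= n, which is equivalent to not(length(x) < n).

```python
def min_length_filter_expr(column, min_len=3):
    """Build the Spark SQL expression that keeps array elements
    of at least min_len characters."""
    # Only simple identifiers are accepted here, because the name is
    # spliced into a SQL string; names with dots or spaces are rejected.
    if not column.isidentifier():
        raise ValueError(f"unsupported column name: {column!r}")
    return f"filter({column}, x -> length(x) >= {int(min_len)})"


def filter_tokens_in_place(df, column="text", min_len=3):
    """Overwrite the array column with its filtered version.

    Assumes pyspark is installed and Spark 2.4+ (for the filter function).
    """
    # Imported lazily so the expression builder can be used without Spark.
    from pyspark.sql import functions as f

    return df.withColumn(column, f.expr(min_length_filter_expr(column, min_len)))
```

Calling filter_tokens_in_place(df) keeps every row, so a row such as id 3 would end up with an empty array; add the .where(f.size(...) > 0) step from above if those rows should go. On Spark 3.1+ the same filter can also be written with the DataFrame API as f.filter('text', lambda x: f.length(x) >= 3).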
Upvotes: 2