jbly1309

Reputation: 41

PySpark: Filter out all lines that have more columns than the header line

df = spark.read.csv('input.csv', header=True, inferSchema=True)

Let's say that the 'input.csv' file contains the following data:

id, name, age
1, John, 20
2, Mike, 33
3, Phil, 19, 180, 78
4, Sean, 40

I would like to filter out rows that contain more columns than the header and save them to a different output, somewhat like this (illustratively):

df2 = df.filter(condition1)  # condition1: rows with more columns than the header
df = df.filter(condition2)   # condition2: rows with at most as many columns as the header
df.show()
df2.show()

So I would get output as follows:

+---+------+----+
| id|  name| age|
+---+------+----+
|  1|  John|  20|
|  2|  Mike|  33|
|  4|  Sean|  40|
+---+------+----+


+---+------+----+
| id|  name| age|
+---+------+----+
|  3|  Phil|  19|
+---+------+----+

So far I've found nothing that works: spark.read.csv simply truncates the longer rows to fit the header, with no way of recovering the dropped values. What can I do? Thanks

EDIT: The schema does not necessarily need to be "id", "name", "age". It should literally take anything, so the filtering cannot depend on a specific column. Moreover, the solution cannot be tied to a specific way of reading the data; the input will be whatever the user chooses, and the only things that can be modified are parameters and options.
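To pin down the rule independent of any schema, here is the intended split sketched in plain Python (illustrative only, not a Spark solution): a row is rejected exactly when it has more comma-separated fields than the header line, whatever the header happens to contain.

```python
# Illustrative plain-Python sketch of the desired rule:
# a row is filtered out when it has more fields than the header.
lines = [
    'id, name, age',
    '1, John, 20',
    '2, Mike, 33',
    '3, Phil, 19, 180, 78',
    '4, Sean, 40',
]

header_len = len(lines[0].split(','))  # 3 here, but never hardcoded

good = [line for line in lines[1:] if len(line.split(',')) <= header_len]
bad = [line for line in lines[1:] if len(line.split(',')) > header_len]

print(good)  # the three well-formed rows
print(bad)   # ['3, Phil, 19, 180, 78']
```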

Upvotes: 1

Views: 696

Answers (1)

mck

Reputation: 42382

You can read the file as plain text and find the number of columns in each row. Build a dataframe with the ids of the rows having more than 3 columns (the width of the header). Then do a semi or anti join against the dataframe obtained with read.csv, using the id.

text = spark.read.text('input.csv')
textdf = text.selectExpr(
    "value",
    "size(split(value, ',')) len",
    "split(value, ',')[0] id"
).filter('len > 3')

textdf.show()
+----------------+---+---+
|           value|len| id|
+----------------+---+---+
|3,Phil,19,180,78|  5|  3|
+----------------+---+---+


df = spark.read.csv('input.csv', header=True)

df1 = df.join(textdf, 'id', 'anti')
df2 = df.join(textdf, 'id', 'semi')

df1.show()
+---+----+---+
| id|name|age|
+---+----+---+
|  1|John| 20|
|  2|Mike| 33|
|  4|Sean| 40|
+---+----+---+

df2.show()
+---+----+---+
| id|name|age|
+---+----+---+
|  3|Phil| 19|
+---+----+---+

Upvotes: 2
