ArnoXf

Reputation: 135

Is it possible to filter columns by the sum of their values in Spark?

I'm loading a sparse table with PySpark, and I want to remove all columns where the sum of the values in the column is below a threshold (i.e. keep only the columns whose sum reaches it).

For example, the column sums of the following table:

+---+---+---+---+---+---+
|  a|  b|  c|  d|  e|  f|
+---+---+---+---+---+---+
|  1|  0|  1|  1|  0|  0|
|  1|  1|  0|  0|  0|  0|
|  1|  0|  0|  1|  1|  1|
|  1|  0|  0|  1|  1|  1|
|  1|  1|  0|  0|  1|  0|
|  0|  0|  1|  0|  1|  0|
+---+---+---+---+---+---+

are 5, 2, 2, 3, 4 and 2. Filtering for all columns with sum >= 3 should produce this table:

+---+---+---+
|  a|  d|  e|
+---+---+---+
|  1|  1|  0|
|  1|  0|  0|
|  1|  1|  1|
|  1|  1|  1|
|  1|  0|  1|
|  0|  0|  1|
+---+---+---+

I have tried many different solutions without success. df.groupBy().sum() gives me the sums of the column values, but I haven't found a way to then filter them against the threshold and select only the remaining columns from the original DataFrame.
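For reference, this is roughly how the example above can be reproduced and what df.groupBy().sum() returns (the SparkSession setup is just assumed here):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# The example table from above as a DataFrame
df = spark.createDataFrame(
    [
        (1, 0, 1, 1, 0, 0),
        (1, 1, 0, 0, 0, 0),
        (1, 0, 0, 1, 1, 1),
        (1, 0, 0, 1, 1, 1),
        (1, 1, 0, 0, 1, 0),
        (0, 0, 1, 0, 1, 0),
    ],
    ["a", "b", "c", "d", "e", "f"],
)

# groupBy() with no columns aggregates over the whole DataFrame, and
# sum() with no arguments sums every numeric column (renamed to sum(a), sum(b), ...)
df.groupBy().sum().show()
# +------+------+------+------+------+------+
# |sum(a)|sum(b)|sum(c)|sum(d)|sum(e)|sum(f)|
# +------+------+------+------+------+------+
# |     5|     2|     2|     3|     4|     2|
# +------+------+------+------+------+------+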

As there are not just 6 but a couple of thousand columns, I'm looking for a scalable solution where I don't have to type in every column name. Thanks for your help!

Upvotes: 0

Views: 1154

Answers (1)

Steven

Reputation: 15258

You can do this with a collect (or a first) step.

from pyspark.sql import functions as F

# Sum every column in a single aggregation and pull the resulting
# one-row result back to the driver as a Row.
sum_result = df.groupBy().agg(*(F.sum(col).alias(col) for col in df.columns)).first()

# Select only the columns whose sum meets the threshold.
filtered_df = df.select(
    *(col for col, value in sum_result.asDict().items() if value >= 3)
)

filtered_df.show()
+---+---+---+
|  a|  d|  e|
+---+---+---+
|  1|  1|  0|
|  1|  0|  0|
|  1|  1|  1|
|  1|  1|  1|
|  1|  0|  1|
|  0|  0|  1|
+---+---+---+
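
If you would rather phrase it as removing the low-sum columns (as in the question), the same sum_result can be used the other way around. A minimal sketch, assuming the threshold of 3 from the example:

# Drop the columns whose sum is below the threshold; everything else is kept.
cols_to_drop = [col for col, value in sum_result.asDict().items() if value < 3]
filtered_df = df.drop(*cols_to_drop)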

Upvotes: 1
