Leo

Reputation: 908

Pyspark addition across columns

I have a dataframe with 100+ columns. I need to create a derived field that is the sum of some of these columns, based on a condition.

For example, NEW_COLUMN_VALUE should be the sum of A_2, A_3 and A_4:

from pyspark.sql.functions import col, lit, when

df = df.withColumn('NEW_COLUMN_VALUE',
                   when(col('Id') == 1, col("A_2") + col("A_3") + col("A_4"))
                   .otherwise(lit(None)))

Another column should be the sum of A_18 to A_40. Is there an easy way to avoid writing it out as below (adding more than 20 columns by hand)? The columns follow a pattern A_1, A_2, ... up to A_80, and there are other id fields as well.

col("A_18")+col("A_19")+col("A_20").......

Upvotes: 0

Views: 87

Answers (2)

abiratsis

Reputation: 7336

Here is another solution. All the metadata is stored in a Python dictionary, cols, which acts as the configuration of the application. Each item of the dictionary describes one new column: the key is the name of the new column, and the value is a list holding the metadata for that column.

More specifically, this list consists of:

  1. The id value used in the condition, i.e. df["id"] == lit(v[0])
  2. The starting element of the range
  3. The last element of the range

In order to create the columns that we need to add, we first generate the column names from that range.

Next we use expr to add the generated columns together, and finally we attach each new column to the dataframe with withColumn.

Here is the complete code:

from pyspark.sql.functions import expr, lit, when

col_prefix = "a"  # change this to match your column prefix, e.g. "A_"
df_cols = ['id', 'a1', 'a2', 'a3', 'a4']
data = [(1, 1, 2, 4, 7),
        (2, 2, 4, 5, 8)]

df = spark.createDataFrame(data, df_cols)

cols = {
  "col1": [1, 1, 3],
  "col2": [2, 2, 4]
}

for k, v in cols.items():
  target_cols = [f"{col_prefix}{idx}" for idx in range(v[1], v[2] + 1)]  # e.g. ["a1", "a2", "a3"]
  add_expr = " + ".join(target_cols)  # e.g. "a1 + a2 + a3"
  df = df.withColumn(k, when(df["id"] == lit(v[0]), expr(add_expr)))

df.show()

# +---+---+---+---+---+----+----+
# | id| a1| a2| a3| a4|col1|col2|
# +---+---+---+---+---+----+----+
# |  1|  1|  2|  4|  7|   7|null|
# |  2|  2|  4|  5|  8|null|  17|
# +---+---+---+---+---+----+----+
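To map this onto the naming in your question (prefix A_, condition on Id), the configuration would look roughly like the sketch below; the second column name and its id value are placeholders, since the question does not give them, and df["id"] in the loop should point at your actual Id column:

col_prefix = "A_"
cols = {
    "NEW_COLUMN_VALUE": [1, 2, 4],     # when Id == 1, sum A_2 + A_3 + A_4
    "OTHER_COLUMN_VALUE": [1, 18, 40]  # placeholder name: sum A_18 .. A_40
}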

Upvotes: 1

Kafels

Reputation: 4069

A few lines of Python solve your problem easily:

from functools import reduce
from operator import add

import pyspark.sql.functions as f


def filter_columns(dataframe, start, stop):
    """Yield every A_<n> column whose index n falls in [start, stop]."""
    for column in dataframe.columns:
        if column.startswith('A_'):
            number = int(column.split('_')[-1])
            if start <= number <= stop:
                yield f.col(column)


# From A_1 to A_4
new_df = df.withColumn('foo', reduce(add, filter_columns(df, start=1, stop=4)))

# From A_18 to A_40
new_df = new_df.withColumn('bar', reduce(add, filter_columns(df, start=18, stop=40)))
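
As a quick sanity check, assuming a small sample dataframe with columns Id and A_1..A_4 (hypothetical data, not taken from the question), the helper can be exercised like this:

sample = spark.createDataFrame(
    [(1, 1, 2, 4, 7), (2, 2, 4, 5, 8)],
    ['Id', 'A_1', 'A_2', 'A_3', 'A_4'])

# Sum A_2 through A_4 for every row
sample.withColumn('NEW_COLUMN_VALUE',
                  reduce(add, filter_columns(sample, start=2, stop=4))).show()

# +---+---+---+---+---+----------------+
# | Id|A_1|A_2|A_3|A_4|NEW_COLUMN_VALUE|
# +---+---+---+---+---+----------------+
# |  1|  1|  2|  4|  7|              13|
# |  2|  2|  4|  5|  8|              17|
# +---+---+---+---+---+----------------+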

Upvotes: 1
