Reputation: 908
I have a dataframe with 100+ columns. I need to derive a new field that is the sum of some of the columns, based on a condition.
For example, NEW_COLUMN_VALUE should be the sum of A_2, A_3 and A_4:
df = df.withColumn('NEW_COLUMN_VALUE',
                   when(col('Id') == 1, col("A_2") + col("A_3") + col("A_4"))
                   .otherwise(lit(None)))
Another column should be the sum of A_18 to A_40. Is there an easy way to avoid writing it out as below (adding 20+ columns by hand)? The columns follow the pattern A_1, A_2, ... up to A_80, and there are other Id values as well.
col("A_18") + col("A_19") + col("A_20") + ...
Upvotes: 0
Views: 87
Reputation: 7336
Here is another solution. All the metadata is stored in a Python dictionary, namely cols, which acts as the configuration of the application. Each item of the dictionary describes one new column: the key is the name of the new column, and the value is a list that holds the metadata for that column.
More specifically, this list consists of:
- the Id value used in the condition df["id"] == lit(v[0]),
- the index of the first column in the range, and
- the index of the last column in the range.
To create the columns that need to be added, we first build the column names from the configured range. Next we use expr to add those columns together, and finally we chain the withColumn calls in a loop, one per entry of the configuration.
Here is the complete code:
from pyspark.sql.functions import expr, lit, when

col_prefix = "a"  # modify this to match your column naming

df_cols = ['id', 'a1', 'a2', 'a3', 'a4']
data = [(1, 1, 2, 4, 7),
        (2, 2, 4, 5, 8)]
df = spark.createDataFrame(data, df_cols)

# configuration: new column -> [id value, first column index, last column index]
cols = {
    "col1": [1, 1, 3],
    "col2": [2, 2, 4]
}

for k, v in cols.items():
    target_cols = [f"{col_prefix}{idx}" for idx in range(v[1], v[2] + 1)]  # produces [a1, a2, ..., aN]
    add_expr = " + ".join(target_cols)  # produces "a1 + a2 + ... + aN"
    df = df.withColumn(k, when(df["id"] == lit(v[0]), expr(add_expr)))

df.show()
# +---+---+---+---+---+----+----+
# | id| a1| a2| a3| a4|col1|col2|
# +---+---+---+---+---+----+----+
# | 1| 1| 2| 4| 7| 7|null|
# | 2| 2| 4| 5| 8|null| 17|
# +---+---+---+---+---+----+----+
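Applied to the layout in the question, the same loop can produce the derived fields directly. This is a sketch, not tested against your data: it assumes the columns are literally named A_1 ... A_80, the condition column is Id, the second derived column and its Id value are hypothetical placeholders, and .otherwise(lit(None)) is added so non-matching rows get an explicit null:
from pyspark.sql.functions import col, expr, lit, when

col_prefix = "A_"

# new column -> [Id value, first A_ index, last A_ index]
cols = {
    "NEW_COLUMN_VALUE": [1, 2, 4],    # sum of A_2..A_4 when Id == 1
    "OTHER_VALUE":      [2, 18, 40],  # hypothetical: sum of A_18..A_40 when Id == 2
}

for k, v in cols.items():
    target_cols = [f"{col_prefix}{idx}" for idx in range(v[1], v[2] + 1)]
    add_expr = " + ".join(target_cols)
    df = df.withColumn(k,
                       when(col("Id") == lit(v[0]), expr(add_expr))
                       .otherwise(lit(None)))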
Upvotes: 1
Reputation: 4069
A few lines of Python solve this easily:
from functools import reduce
from operator import add
import pyspark.sql.functions as f
def filter_columns(dataframe, start, stop):
    # yield every column whose name is A_<n> with start <= n <= stop
    for column in dataframe.columns:
        if column.startswith('A_'):
            number = int(column.split('_')[-1])
            if start <= number <= stop:
                yield f.col(column)

# From A_1 to A_4
new_df = df.withColumn('foo', reduce(add, filter_columns(df, start=1, stop=4)))

# From A_18 to A_40
new_df = new_df.withColumn('bar', reduce(add, filter_columns(df, start=18, stop=40)))
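If the sum should only apply to rows with a particular Id, as in the question, the reduced expression can be wrapped in when/otherwise. A minimal sketch building on filter_columns above, assuming the condition column is named Id:
# Sum A_2..A_4 only where Id == 1, null otherwise
new_df = df.withColumn(
    'NEW_COLUMN_VALUE',
    f.when(f.col('Id') == 1, reduce(add, filter_columns(df, start=2, stop=4)))
     .otherwise(f.lit(None))
)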
Upvotes: 1