l3rn1ngData

Reputation: 59

Spark scala selecting multiple columns from a list and single columns

I'm attempting to do a select on a dataframe but I'm having a little bit of trouble.

I have this initial dataframe

+---+-------+-------+-------+-------+
| id|value_a|value_b|value_c|value_d|
+---+-------+-------+-------+-------+

And what I have to do is sum value_a with value_b and keep the others the same. So I have this list

val select_list = List("id", "value_c", "value_d")

and after this I do the select

df.select(select_list.map(col): _*, (col("value_a") + col("value_b")).as("value_b"))

And I'm expecting to get this:

+---+-------+-------+-------+
| id|value_c|value_d|value_b|   --- this value_b is the sum of value_a and value_b (original)
+---+-------+-------+-------+

But I'm getting a "no : _* annotation allowed here" error. Keep in mind that in reality I have a lot of columns, so I need to use a list; I can't simply select each column by hand. I'm running into this trouble because the new column that results from the sum has the same name as an existing column, so I can't just do select(col("*"), sum...).drop("value_b") or I'd be dropping both the old column and the new one with the sum.
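For reference, select takes a Column* varargs parameter, and Scala only allows the : _* expansion to stand in for the entire argument list, so mixing an expanded sequence with extra arguments triggers that error. A minimal sketch (using the column names above, not necessarily the cheapest approach) that builds a single Seq[Column] first and expands only that:

import org.apache.spark.sql.functions.col

// Append the summed column to the list columns, then expand the whole Seq as varargs
val allCols = select_list.map(col) :+ (col("value_a") + col("value_b")).as("value_b")
df.select(allCols: _*)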

What is the correct syntax to add multiple and single columns in a single select, or how else can I solve this? For now I decided to do this:

df.select(col("*"), (col(value_a) + col(value_b)).as("value_b_tmp")).
drop("value_a", "value_b").withColumnRenamed("value_b_tmp", "value_b")

Which works fine, but I understand withColumn and withColumnRenamed are expensive because I'm pretty much creating a new dataframe with a new or renamed column, and I'm looking for the least expensive operation possible.

Thanks in advance!

Upvotes: 1

Views: 1389

Answers (2)

Emiliano Martinez

Reputation: 4133

You can create a new sum field and accumulate the sum of the n columns into it, for example:

import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

val df: DataFrame = spark.createDataFrame(
  spark.sparkContext.parallelize(Seq(Row(1, 2, 3), Row(1, 2, 3))),
  StructType(List(
    StructField("field1", IntegerType),
    StructField("field2", IntegerType),
    StructField("field3", IntegerType))))

val columnsToSum = df.schema.fieldNames

// Start from a zero-valued "sum" column and fold every field except "field1" into it
columnsToSum.filter(name => name != "field1")
  .foldLeft(df.withColumn("sum", lit(0)))((df, column) =>
    df.withColumn("sum", col("sum") + col(column)))

Gives:

+------+------+------+---+
|field1|field2|field3|sum|
+------+------+------+---+
|     1|     2|     3|  5|
|     1|     2|     3|  5|
+------+------+------+---+
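If it helps, the same total can also be built as a single Column expression by reducing over the column names, so "sum" is added in one withColumn call instead of once per column (a sketch reusing columnsToSum and df from above):

// Sketch: build (field2 + field3 + ...) as one expression, then add it once
val sumExpr = columnsToSum
  .filter(_ != "field1")
  .map(col)
  .reduce(_ + _)

df.withColumn("sum", sumExpr)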

Upvotes: 0

falcon-le0

Reputation: 609

Simply use the .withColumn function; it will replace the column if it already exists:

df
  .withColumn("value_b", col("value_a") + col("value_b"))
  .select(select_list.map(col):_*)
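One caveat, assuming the select_list from the question (id, value_c, value_d): the replaced value_b column also has to appear in the selected list, e.g.:

// Hypothetical adjustment: append "value_b" so the summed column survives the select
df
  .withColumn("value_b", col("value_a") + col("value_b"))
  .select((select_list :+ "value_b").map(col): _*)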

Upvotes: 3
