nir

Reputation: 3868

Spark - How to avoid duplicate columns after join?

Extending upon use case given here: How to avoid duplicate columns after join?

I have two dataframes with hundreds of columns each. Here is a sample of the columns, including the join columns:

df1.columns
//  Array(ts, id, X1, X2, ...)

and

df2.columns
//  Array(ts, id, X1, Y2, ...)

After I do:

val df_combined = df1.join(df2, df1("X1") === df2("X1") and df1("X2") === df2("Y2"))

I end up with the following columns: Array(ts, id, X1, X2, ts, id, X1, Y2). X1 is duplicated.

I can't use the join(right: Dataset[_], usingColumns: Seq[String]) API, because it requires every join column to exist under the same name in both dataframes, which is not the case here (X2 and Y2). The only options I see are to rename a column and drop it later, or to alias the dataframes and drop the column from the second one afterwards. Isn't there a simple API for this, e.g. one that automatically drops one of the join columns in an equality join?

Upvotes: 2

Views: 9831

Answers (1)

Shaido

Reputation: 28322

As you noted, the best way to avoid duplicate columns is using a Seq[String] as input to the join. However, since the columns have different names in the dataframes there are only two options:

  1. Rename the Y2 column to X2 and perform the join as df1.join(df2, Seq("X1", "X2")). If you want to keep both the Y2 and X2 columns afterwards, simply copy X2 to a new column Y2.

  2. Perform the join as before and drop the unwanted duplicated column(s) afterwards:

    df1.join(df2, df1.col("X1") === df2.col("X1") and df1.col("X2") === df2.col("Y2"))
      .drop(df1.col("X1"))
    

Unfortunately, currently there is no automatic way to achieve this.
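
For reference, option 1 can be sketched roughly as below. This is only an illustration on toy data: the local SparkSession, the app name, and the extra columns A and B are assumptions made for the example and are not part of the question.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

// Local session for illustration only; in spark-shell a session already exists.
val spark = SparkSession.builder().master("local[1]").appName("join-demo").getOrCreate()
import spark.implicits._

// Toy stand-ins for df1 and df2 (A and B are made-up payload columns).
val df1 = Seq((1, "a", 10.0)).toDF("X1", "X2", "A")
val df2 = Seq((1, "a", 20.0)).toDF("X1", "Y2", "B")

// Rename Y2 so both join keys share a name, then join on the list of names.
// With a Seq[String] join, each key column appears only once in the result.
val joined = df1.join(df2.withColumnRenamed("Y2", "X2"), Seq("X1", "X2"))

// Optionally restore Y2 as a copy of X2.
val withY2 = joined.withColumn("Y2", col("X2"))
```

Here joined contains X1 and X2 once each, followed by the remaining columns of df1 and then those of df2.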


When joining dataframes, it's better to make sure they do not have the same column names (with the exception of the columns used in the join). For example, the ts and id columns above. If there are a lot of columns it can be hard to rename them all manually. To do it automatically, the below code can be used:

import org.apache.spark.sql.functions.col

val prefix = "df1_"
// Alias every column of df1 with the prefix
val cols = df1.columns.map(c => col(c).as(s"$prefix$c"))
df1.select(cols: _*)
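
If the join columns should keep their original names, the same idea can be restricted to the non-join columns. The sketch below assumes a hypothetical helper prefixExcept (not a Spark API), a local SparkSession, and toy rows; it also assumes column names without dots, since col(c) would otherwise need backticks.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.col

// Hypothetical helper: prefix every column except those listed in `keep`.
def prefixExcept(df: DataFrame, prefix: String, keep: Set[String]): DataFrame =
  df.select(df.columns.map { c =>
    if (keep.contains(c)) col(c) else col(c).as(s"$prefix$c")
  }: _*)

// Local session for illustration only; in spark-shell a session already exists.
val spark = SparkSession.builder().master("local[1]").appName("prefix-demo").getOrCreate()
import spark.implicits._

// Toy stand-ins with the column layout from the question.
val df1 = Seq((100L, 1, "k", "a")).toDF("ts", "id", "X1", "X2")
val df2 = Seq((200L, 2, "k", "a")).toDF("ts", "id", "X1", "Y2")

// Prefix the non-join columns, then align the name of the second join key.
val left  = prefixExcept(df1, "df1_", Set("X1", "X2"))
val right = prefixExcept(df2, "df2_", Set("X1", "Y2")).withColumnRenamed("Y2", "X2")

// Join on names, so X1 and X2 appear once and no other names clash.
val combined = left.join(right, Seq("X1", "X2"))
```

The result has the columns X1, X2, df1_ts, df1_id, df2_ts, df2_id, so every column can be referenced unambiguously after the join.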

Upvotes: 4
