sandbar

Reputation: 84

Manually reorder columns to a specific location (Spark 3 / Scala)

I have a DataFrame with over 100 columns. There are a handful of columns I'd like to move to the very left of the DataFrame. Is there an easy way to specify only the columns I want to move to the left and have the remaining columns keep their original order? I know I can use select to reorder the columns, but with over 100 columns I want to avoid listing every one of them by hand.

Upvotes: 1

Views: 107

Answers (1)

mvasyliv

Reputation: 1214

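import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{Encoders, SparkSession}

// spark-shell already provides `spark`; in a standalone app, create the session first
val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._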

case class D(
    C1: String,
    C2: String,
    C3: String,
    C4: String,
    C5: String,
    C6: String,
    C7: String,
    C8: String,
    C9: String,
    C10: String
)
val schema: StructType = Encoders.product[D].schema
val fields = schema.fieldNames
// or, for any existing DataFrame df: val fields = df.columns

val source = Seq(
  D("1", "1", "1", "1", "1", "1", "1", "1", "1", "1"),
  D("2", "2", "2", "2", "2", "2", "2", "2", "2", "2"),
  D("3", "3", "3", "3", "3", "3", "3", "3", "3", "3")
).toDF()

source.printSchema()
//    root
//    |-- C1: string (nullable = true)
//    |-- C2: string (nullable = true)
//    |-- C3: string (nullable = true)
//    |-- C4: string (nullable = true)
//    |-- C5: string (nullable = true)
//    |-- C6: string (nullable = true)
//    |-- C7: string (nullable = true)
//    |-- C8: string (nullable = true)
//    |-- C9: string (nullable = true)
//    |-- C10: string (nullable = true)
source.show()
//    +---+---+---+---+---+---+---+---+---+---+
//    | C1| C2| C3| C4| C5| C6| C7| C8| C9|C10|
//    +---+---+---+---+---+---+---+---+---+---+
//    |  1|  1|  1|  1|  1|  1|  1|  1|  1|  1|
//    |  2|  2|  2|  2|  2|  2|  2|  2|  2|  2|
//    |  3|  3|  3|  3|  3|  3|  3|  3|  3|  3|
//    +---+---+---+---+---+---+---+---+---+---+

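// Columns to move to the front, in the desired order
val colFirst = Array("C1", "C2", "C10", "C7")
// All remaining columns; diff keeps their original relative order
val tmpLast = fields.diff(colFirst)
val cols = colFirst ++ tmpLast

// select(String, String*) needs the first name separately from the rest
val res1 = source.select(cols.head, cols.tail: _*)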
res1.printSchema()
//    root
//    |-- C1: string (nullable = true)
//    |-- C2: string (nullable = true)
//    |-- C10: string (nullable = true)
//    |-- C7: string (nullable = true)
//    |-- C3: string (nullable = true)
//    |-- C4: string (nullable = true)
//    |-- C5: string (nullable = true)
//    |-- C6: string (nullable = true)
//    |-- C8: string (nullable = true)
//    |-- C9: string (nullable = true)

res1.show(false)
//    +---+---+---+---+---+---+---+---+---+---+
//    |C1 |C2 |C10|C7 |C3 |C4 |C5 |C6 |C8 |C9 |
//    +---+---+---+---+---+---+---+---+---+---+
//    |1  |1  |1  |1  |1  |1  |1  |1  |1  |1  |
//    |2  |2  |2  |2  |2  |2  |2  |2  |2  |2  |
//    |3  |3  |3  |3  |3  |3  |3  |3  |3  |3  |
//    +---+---+---+---+---+---+---+---+---+---+
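The same idea can be packaged as a reusable helper. This is a minimal sketch, assuming column names are unique: the hypothetical object ColumnOrder works on plain sequences of column names (so it runs without Spark) and fails fast on names that don't exist, which select would otherwise only report as an AnalysisException. The hypothetical moveTo is a generalisation for placing columns at an arbitrary position rather than only at the far left.

```scala
// Hypothetical helper (not a Spark API): computes a new column order from
// plain column names, assuming those names are unique.
object ColumnOrder {

  // Fail fast if any requested column is not present in `all`.
  private def checkKnown(all: Seq[String], wanted: Seq[String]): Unit = {
    val missing = wanted.filterNot(all.contains)
    require(missing.isEmpty, s"Unknown columns: ${missing.mkString(", ")}")
  }

  // Put `first` at the far left (in the given order); every other column
  // keeps its original relative order.
  def moveToFront(all: Seq[String], first: Seq[String]): Seq[String] = {
    checkKnown(all, first)
    first.distinct ++ all.filterNot(first.contains)
  }

  // Hypothetical generalisation: insert `moved` before position `index`
  // of the remaining columns (0 = far left, rest.length = far right).
  def moveTo(all: Seq[String], moved: Seq[String], index: Int): Seq[String] = {
    checkKnown(all, moved)
    val rest = all.filterNot(moved.contains)
    require(index >= 0 && index <= rest.length, s"index $index out of range")
    val (before, after) = rest.splitAt(index)
    before ++ moved.distinct ++ after
  }

  def main(args: Array[String]): Unit = {
    val fields = (1 to 10).map(i => s"C$i")
    // Prints C1,C2,C10,C7,C3,C4,C5,C6,C8,C9 (same order as res1 above)
    println(moveToFront(fields, Seq("C1", "C2", "C10", "C7")).mkString(","))
    // Prints C1,C2,C10,C3,C4,C5,C6,C7,C8,C9
    println(moveTo(fields, Seq("C10"), 2).mkString(","))
  }
}
```

With a DataFrame df, the result plugs into the same select call used in the answer: val order = ColumnOrder.moveToFront(df.columns.toSeq, Seq("C1", "C2", "C10", "C7")); df.select(order.head, order.tail: _*).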

Upvotes: 2