Rajesh Ravindran

Reputation: 1

Dataframe column name is not updated using alias

I'm running an aggregation on a DataFrame I created. Here are the steps:

val initDF = spark.read.format("csv").schema(someSchema).option("header","true").load(filePath).as[someCaseClass]

var maleFemaleDistribution = initDF.select("DISTRICT","GENDER","ENROLMENT_ACCEPTED","ENROLMENT_REJECTED").groupBy("DISTRICT").agg(
     count( lit(1).alias("OVERALL_COUNT")),
     sum(when(col("GENDER") === "M", 1).otherwise(0).alias("MALE_COUNT")),
     sum(when(col("GENDER") === "F", 1).otherwise(0).alias("FEMALE_COUNT"))
      ).orderBy("DISTRICT")

However, when I call printSchema on the new DataFrame, the columns don't have the alias names I provided. Instead it shows:

maleFemaleDistribution.printSchema
root
 |-- DISTRICT: string (nullable = true)
 |-- count(1 AS `OVERALL_COUNT`): long (nullable = false)
 |-- sum(CASE WHEN (GENDER = M) THEN 1 ELSE 0 END AS `MALE_COUNT`): long (nullable = true)
 |-- sum(CASE WHEN (GENDER = F) THEN 1 ELSE 0 END AS `FEMALE_COUNT`): long (nullable = true)

whereas I expect the column names to be:

maleFemaleDistribution.printSchema
root
 |-- DISTRICT: string (nullable = true)
 |-- OVERALL_COUNT: long (nullable = false)
 |-- MALE_COUNT: long (nullable = true)
 |-- FEMALE_COUNT: long (nullable = true) 

I'd like to understand why the aliases are not applied in the new DataFrame, and how I should modify the code so the columns get the names given in alias.

Upvotes: 0

Views: 421

Answers (2)

Mann

Reputation: 307

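You should apply alias to the result of the sum operation, not to its argument. An alias inside sum(...) only names the inner CASE WHEN expression; the aggregate that wraps it has no alias of its own, so Spark generates the column name from the whole expression text, which is why the schema shows sum(CASE WHEN ... AS `MALE_COUNT`). So instead of this: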

sum(when(col("GENDER") === "M", 1).otherwise(0).alias("MALE_COUNT"))

It should look like this:

sum(when(col("GENDER") === "M", 1).otherwise(0)).alias("MALE_COUNT")
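To see the difference side by side, here is a minimal, self-contained sketch. It assumes Spark (2.x or 3.x) is on the classpath and uses a local master; the object name AliasPlacement and the sample rows are hypothetical stand-ins for the CSV from the question. Printing both schemas should show that only the second variant produces a column literally named MALE_COUNT.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{col, sum, when}

// Hypothetical demo object; the sample rows stand in for the CSV from the question.
object AliasPlacement {
  // Local session so the sketch runs without a cluster (assumption).
  val spark: SparkSession = SparkSession.builder()
    .master("local[1]")
    .appName("alias-placement")
    .getOrCreate()
  import spark.implicits._

  // Made-up sample data with just the columns this example needs.
  val sample: DataFrame = Seq(("D1", "M"), ("D1", "F"), ("D2", "M")).toDF("DISTRICT", "GENDER")

  // Alias on the argument: only the inner CASE WHEN expression is named,
  // so Spark derives the output column name from the whole sum(...) expression.
  val aliasInside: DataFrame = sample.groupBy("DISTRICT")
    .agg(sum(when(col("GENDER") === "M", 1).otherwise(0).alias("MALE_COUNT")))

  // Alias on the aggregate: the output column itself is named MALE_COUNT.
  val aliasOutside: DataFrame = sample.groupBy("DISTRICT")
    .agg(sum(when(col("GENDER") === "M", 1).otherwise(0)).alias("MALE_COUNT"))

  def main(args: Array[String]): Unit = {
    aliasInside.printSchema()
    aliasOutside.printSchema()
    spark.stop()
  }
}
```

The same rule applies to count(lit(1).alias("OVERALL_COUNT")) in the question: move the alias outside the count(...) call.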

Upvotes: 0

Gaurang Shah

Reputation: 12910

I haven't tried running the query, but it should look like this:

val maleFemaleDistribution = initDF.select("DISTRICT","GENDER","ENROLMENT_ACCEPTED","ENROLMENT_REJECTED").groupBy("DISTRICT").agg(
     count(lit(1)).alias("OVERALL_COUNT"),
     sum(when(col("GENDER") === "M", 1).otherwise(0)).alias("MALE_COUNT"),
     sum(when(col("GENDER") === "F", 1).otherwise(0)).alias("FEMALE_COUNT")
      ).orderBy("DISTRICT")
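Since the query above is untested, here is a hedged, runnable sketch of the same aggregation. It assumes Spark on the classpath with a local master, and it replaces spark.read...load(filePath) with a few made-up in-memory rows; the object name MaleFemaleDistributionSketch is hypothetical. The select step is omitted because groupBy/agg only reference these columns anyway.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{col, count, lit, sum, when}

// Hypothetical demo: in-memory rows replace spark.read...load(filePath) from the question.
object MaleFemaleDistributionSketch {
  // Local session so the sketch runs without a cluster (assumption).
  val spark: SparkSession = SparkSession.builder()
    .master("local[1]")
    .appName("male-female-distribution")
    .getOrCreate()
  import spark.implicits._

  // Made-up sample rows with the four columns the query selects.
  val initDF: DataFrame = Seq(
    ("D1", "M", 1, 0),
    ("D1", "F", 0, 1),
    ("D2", "M", 1, 0)
  ).toDF("DISTRICT", "GENDER", "ENROLMENT_ACCEPTED", "ENROLMENT_REJECTED")

  // Each alias is applied to the aggregate column, not to its argument.
  val maleFemaleDistribution: DataFrame = initDF
    .groupBy("DISTRICT")
    .agg(
      count(lit(1)).alias("OVERALL_COUNT"),
      sum(when(col("GENDER") === "M", 1).otherwise(0)).alias("MALE_COUNT"),
      sum(when(col("GENDER") === "F", 1).otherwise(0)).alias("FEMALE_COUNT")
    )
    .orderBy("DISTRICT")

  def main(args: Array[String]): Unit = {
    maleFemaleDistribution.printSchema()
    maleFemaleDistribution.show()
    spark.stop()
  }
}
```

With the aliases moved outside the aggregate calls, printSchema should list DISTRICT, OVERALL_COUNT, MALE_COUNT and FEMALE_COUNT, matching the expected schema in the question.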

Upvotes: 1
