Psychevic

Reputation: 709

How to get other columns when using Spark DataFrame groupby?

When I use DataFrame groupBy like this:

df.groupBy(df("age")).agg(Map("id"->"count"))

I only get a DataFrame with the columns "age" and "count(id)", but df has many other columns, such as "name".

In short, I want the same result as this MySQL query:

"select name,age,count(id) from df group by age"

What should I do when using groupBy in Spark?

Upvotes: 57

Views: 113664

Answers (10)

RFAI

Reputation: 469

If you use PySpark 1.6.0 or later, you can use collect_set() or collect_list().

For example, in the case of your code, you can use:

from pyspark.sql import functions as F
df = df.groupBy('age').agg(F.count('id').alias('idCount'),
    F.collect_set('name').alias('userName'), F.collect_set('age').alias('userAge'))

  • Please note that collect_list() keeps duplicates in the result, whereas collect_set() removes them.

  • The aliases avoid creating columns with the same name, which can cause problems in some versions of Spark or on some platforms.

Upvotes: 0

NickyPatel

Reputation: 565

#solved #working solution

I put this solution together with the help of @Azmisov's comment in this thread and a code sample taken from https://sparkbyexamples.com/spark/using-groupby-on-dataframe/

Problem: In Spark Scala, when you use groupBy and max on a DataFrame, the result contains only the groupBy columns and the max column. How do you get the other (non-grouped) columns as well?

Solution: Go through the full example below to get all the columns with groupBy and max.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._ //{col, lit, when, to_timestamp}
import org.apache.spark.sql.types._
import org.apache.spark.sql.Column

val spark = SparkSession
              .builder()
              .appName("app-name")
              .master("local[*]")
              .getOrCreate()
              
    spark.sparkContext.setLogLevel("ERROR")
    import spark.implicits._
    
    val simpleData = Seq(("James","Sales","NY",90000,34,10000),
      ("Michael","Sales","NY",86000,56,20000),
      ("Robert","Sales","CA",81000,30,23000),
      ("Maria","Finance","CA",90000,24,23000),
      ("Raman","Finance","CA",99000,40,24000),
      ("Scott","Finance","NY",83000,36,19000),
      ("Jen","Finance","NY",79000,53,15000),
      ("Jeff","Marketing","CA",80000,25,18000),
      ("Kumar","Marketing","NY",91000,50,21000)
    )
    val df = simpleData.toDF("employee_name","department","state","salary","age","bonus")
    df.show()

df.show() prints the generated DataFrame:

output :
+-------------+----------+-----+------+---+-----+
|employee_name|department|state|salary|age|bonus|
+-------------+----------+-----+------+---+-----+
|        James|     Sales|   NY| 90000| 34|10000|
|      Michael|     Sales|   NY| 86000| 56|20000|
|       Robert|     Sales|   CA| 81000| 30|23000|
|        Maria|   Finance|   CA| 90000| 24|23000|
|        Raman|   Finance|   CA| 99000| 40|24000|
|        Scott|   Finance|   NY| 83000| 36|19000|
|          Jen|   Finance|   NY| 79000| 53|15000|
|         Jeff| Marketing|   CA| 80000| 25|18000|
|        Kumar| Marketing|   NY| 91000| 50|21000|
+-------------+----------+-----+------+---+-----+

The code below produces awkward column names, but the result is still usable:

val dfwithmax = df.groupBy("department").agg(max("salary"), first("employee_name"), first("state"), first("age"), first("bonus"))
dfwithmax.show()

+----------+-----------+---------------------------+-------------------+-----------------+-------------------+
|department|max(salary)|first(employee_name, false)|first(state, false)|first(age, false)|first(bonus, false)|
+----------+-----------+---------------------------+-------------------+-----------------+-------------------+
|     Sales|      90000|                      James|                 NY|               34|              10000|
|   Finance|      99000|                      Maria|                 CA|               24|              23000|
| Marketing|      91000|                       Jeff|                 CA|               25|              18000|
+----------+-----------+---------------------------+-------------------+-----------------+-------------------+

To get proper column names, alias each aggregate with as, as shown below:

val dfwithmax1 = df.groupBy("department").agg(max("salary") as "salary", first("employee_name") as "employee_name", first("state") as "state", first("age") as "age",first("bonus") as "bonus")
dfwithmax1.show()

output:
+----------+------+-------------+-----+---+-----+
|department|salary|employee_name|state|age|bonus|
+----------+------+-------------+-----+---+-----+
|     Sales| 90000|        James|   NY| 34|10000|
|   Finance| 99000|        Maria|   CA| 24|23000|
| Marketing| 91000|         Jeff|   CA| 25|18000|
+----------+------+-------------+-----+---+-----+

If you also want to change the order of the DataFrame columns, you can do it as follows:

val reOrderedColumnName : Array[String] = Array("employee_name", "department", "state", "salary", "age", "bonus")
val orderedDf = dfwithmax1.select(reOrderedColumnName.head, reOrderedColumnName.tail: _*)
orderedDf.show()

Full code together:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.Column

object test {
    def main(args: Array[String]): Unit = {
        /** spark session object */
        val spark = SparkSession.builder().appName("app-name").master("local[*]")
                  .getOrCreate()
                  
        spark.sparkContext.setLogLevel("ERROR")
        import spark.implicits._
        
        val simpleData = Seq(("James","Sales","NY",90000,34,10000),
          ("Michael","Sales","NY",86000,56,20000),
          ("Robert","Sales","CA",81000,30,23000),
          ("Maria","Finance","CA",90000,24,23000),
          ("Raman","Finance","CA",99000,40,24000),
          ("Scott","Finance","NY",83000,36,19000),
          ("Jen","Finance","NY",79000,53,15000),
          ("Jeff","Marketing","CA",80000,25,18000),
          ("Kumar","Marketing","NY",91000,50,21000)
        )
        val df = simpleData.toDF("employee_name","department","state","salary","age","bonus")
        df.show()

        val dfwithmax = df.groupBy("department").agg(max("salary"), first("employee_name"), first("state"), first("age"), first("bonus"))
        dfwithmax.show()
        
        val dfwithmax1 = df.groupBy("department").agg(max("salary") as "salary", first("employee_name") as "employee_name", first("state") as "state", first("age") as "age",first("bonus") as "bonus")
        dfwithmax1.show()
        
        val reOrderedColumnName : Array[String] = Array("employee_name", "department", "state", "salary", "age", "bonus")
        val orderedDf = dfwithmax1.select(reOrderedColumnName.head, reOrderedColumnName.tail: _*)
        orderedDf.show()
    }
}

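Full output (the second and third tables contain an extra department column, so they come from the variant that also aggregates first("department"), shown under Exceptions below):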
+-------------+----------+-----+------+---+-----+
|employee_name|department|state|salary|age|bonus|
+-------------+----------+-----+------+---+-----+
|        James|     Sales|   NY| 90000| 34|10000|
|      Michael|     Sales|   NY| 86000| 56|20000|
|       Robert|     Sales|   CA| 81000| 30|23000|
|        Maria|   Finance|   CA| 90000| 24|23000|
|        Raman|   Finance|   CA| 99000| 40|24000|
|        Scott|   Finance|   NY| 83000| 36|19000|
|          Jen|   Finance|   NY| 79000| 53|15000|
|         Jeff| Marketing|   CA| 80000| 25|18000|
|        Kumar| Marketing|   NY| 91000| 50|21000|
+-------------+----------+-----+------+---+-----+

+----------+-----------+---------------------------+------------------------+-------------------+-----------------+-------------------+
|department|max(salary)|first(employee_name, false)|first(department, false)|first(state, false)|first(age, false)|first(bonus, false)|
+----------+-----------+---------------------------+------------------------+-------------------+-----------------+-------------------+
|     Sales|      90000|                      James|                   Sales|                 NY|               34|              10000|
|   Finance|      99000|                      Maria|                 Finance|                 CA|               24|              23000|
| Marketing|      91000|                       Jeff|               Marketing|                 CA|               25|              18000|
+----------+-----------+---------------------------+------------------------+-------------------+-----------------+-------------------+

+----------+------+-------------+----------+-----+---+-----+
|department|salary|employee_name|department|state|age|bonus|
+----------+------+-------------+----------+-----+---+-----+
|     Sales| 90000|        James|     Sales|   NY| 34|10000|
|   Finance| 99000|        Maria|   Finance|   CA| 24|23000|
| Marketing| 91000|         Jeff| Marketing|   CA| 25|18000|
+----------+------+-------------+----------+-----+---+-----+

Exceptions:

Exception in thread "main" org.apache.spark.sql.AnalysisException: Reference 'department' is ambiguous, could be: department, department.;

This means the result contains the department column twice: once as the groupBy key, and once more because you also passed first("department") as "department" to agg. Any later reference to department, such as the select above, is then ambiguous.

For example, the following produces the duplicate column:

val dfwithmax1 = df.groupBy("department").agg(max("salary") as "salary", first("employee_name") as "employee_name", first("department") as "department", first("state") as "state", first("age") as "age",first("bonus") as "bonus")

Thanks! Please upvote if this is helpful.

Upvotes: 4

zwithouta

Reputation: 1591

This PySpark code groups by the column group and, for each group, selects the B value of the maximum (A, B) combination (if several identical maxima exist in a group, any one of them is picked).

In your case group would be age, A the column that decides which row to keep, and B any of the columns you did not group by but nevertheless want to select.

from pyspark.sql import functions as F

df = spark.createDataFrame([
    [1, 1, 0.2],
    [1, 1, 0.9],
    [1, 2, 0.6],
    [1, 2, 0.5],
    [1, 2, 0.6],
    [2, 1, 0.2],
    [2, 2, 0.1],
], ["group", "A", "B"])

out = (
    df
    .withColumn("AB", F.struct("A", "B"))
    .groupby("group")
    # F.max(AB) selects AB-combinations with max `A`. If more
    # than one combination remains the one with max `B` is selected. If
    # after this identical combinations remain, a single one of them is picked
    # randomly.
    .agg(F.max("AB").alias("max_AB"))
    .select("group", F.expr("max_AB.B"))
)
out.show()

Output

+-----+---+
|group|  B|
+-----+---+
|    1|0.6|
|    2|0.1|
+-----+---+

Upvotes: 1

Mohamed Hosni

Reputation: 3

Here is an example that I came across in spark-workshop:

val populationDF = spark.read
                .option("inferSchema", "true")
                .option("header", "true")
                .format("csv").load("file:///databricks/driver/population.csv")
                .select('name, regexp_replace(col("population"), "\\s", "").cast("integer").as("population"))

val maxPopulationDF = populationDF.agg(max('population).as("populationmax"))

To get the other columns, I do a simple join between the original DF and the aggregated one:

populationDF.join(maxPopulationDF,populationDF.col("population") === maxPopulationDF.col("populationmax")).select('name, 'populationmax).show()

Upvotes: 0

Rubber Duck

Reputation: 3743

Aggregate functions reduce the values of the specified columns within each group. If you want to retain values from other columns, you need reduction logic that specifies which row each value comes from, for instance keeping all values of the first row with the maximum value of a chosen column (in the UDAF below, the amt input). To this end you can use a UDAF (user-defined aggregate function) to reduce the rows within each group.

import org.apache.spark.sql._
import org.apache.spark.sql.functions._


object AggregateKeepingRowJob {

  def main (args: Array[String]): Unit = {

    val sparkSession = SparkSession
      .builder()
      .appName(this.getClass.getName.replace("$", ""))
      .master("local")
      .getOrCreate()

    val sc = sparkSession.sparkContext
    sc.setLogLevel("ERROR")

    import sparkSession.sqlContext.implicits._

    val rawDf = Seq(
      (1L, "Moe",  "Slap",  2.0, 18),
      (2L, "Larry",  "Spank",  3.0, 15),
      (3L, "Curly",  "Twist", 5.0, 15),
      (4L, "Laurel", "Whimper", 3.0, 15),
      (5L, "Hardy", "Laugh", 6.0, 15),
      (6L, "Charley",  "Ignore",   5.0, 5)
    ).toDF("id", "name", "requisite", "money", "age")

    rawDf.show(false)
    rawDf.printSchema

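    // The class defined below is KeepRowWithMaxAmt; its four inputs
    // (String, String, Double, Int) map to name, requisite, money and age.
    val maxAmtUdaf = new KeepRowWithMaxAmt

    val aggDf = rawDf
      .groupBy("age")
      .agg(
        count("id"),
        max(col("money")),
        maxAmtUdaf(
          col("name"),
          col("requisite"),
          col("money"),
          col("age")).as("KeepRowWithMaxAmt")
      )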
      )

    aggDf.printSchema
    aggDf.show(false)

  }


}

The UDAF:

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class KeepRowWithMaxAmt extends UserDefinedAggregateFunction {
// This is the input fields for your aggregate function.
override def inputSchema: org.apache.spark.sql.types.StructType =
  StructType(
    StructField("store", StringType) ::
    StructField("prod", StringType) ::
    StructField("amt", DoubleType) ::
    StructField("units", IntegerType) :: Nil
  )

// This is the internal fields you keep for computing your aggregate.
override def bufferSchema: StructType = StructType(
  StructField("store", StringType) ::
  StructField("prod", StringType) ::
  StructField("amt", DoubleType) ::
  StructField("units", IntegerType) :: Nil
)


// This is the output type of your aggregation function.
override def dataType: DataType =
  StructType((Array(
    StructField("store", StringType),
    StructField("prod", StringType),
    StructField("amt", DoubleType),
    StructField("units", IntegerType)
  )))

override def deterministic: Boolean = true

// This is the initial value for your buffer schema.
override def initialize(buffer: MutableAggregationBuffer): Unit = {
  buffer(0) = ""
  buffer(1) = ""
  buffer(2) = 0.0
  buffer(3) = 0
}

// This is how to update your buffer schema given an input.
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {

  val amt = buffer.getAs[Double](2)
  val candidateAmt = input.getAs[Double](2)

  amt match {
    case a if a < candidateAmt =>
      buffer(0) = input.getAs[String](0)
      buffer(1) = input.getAs[String](1)
      buffer(2) = input.getAs[Double](2)
      buffer(3) = input.getAs[Int](3)
    case _ =>
  }
}

// This is how to merge two objects with the bufferSchema type:
// keep whichever buffer holds the larger amt.
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {

  if (buffer1.getAs[Double](2) < buffer2.getAs[Double](2)) {
    buffer1(0) = buffer2.getAs[String](0)
    buffer1(1) = buffer2.getAs[String](1)
    buffer1(2) = buffer2.getAs[Double](2)
    buffer1(3) = buffer2.getAs[Int](3)
  }
}

// This is where you output the final value, given the final value of your bufferSchema.
override def evaluate(buffer: Row): Any = {
  buffer
}
}

Upvotes: 1

Rubber Duck

Reputation: 3743

Remember that aggregate functions reduce rows, so you need a reducing function that specifies which row's name you want. If you want to retain all rows of a group (warning: this can cause memory blow-ups or skewed partitions), you can collect them as a list. You can then use a UDF (user-defined function) to reduce them by your criterion, money in my example, and then extract columns from the single reduced row with another UDF. For the purpose of this answer I assume you want to retain the name of the person who has the most money.

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.StringType

import scala.collection.mutable


object TestJob3 {

def main (args: Array[String]): Unit = {

val sparkSession = SparkSession
  .builder()
  .appName(this.getClass.getName.replace("$", ""))
  .master("local")
  .getOrCreate()

val sc = sparkSession.sparkContext

import sparkSession.sqlContext.implicits._

val rawDf = Seq(
  (1, "Moe",  "Slap",  2.0, 18),
  (2, "Larry",  "Spank",  3.0, 15),
  (3, "Curly",  "Twist", 5.0, 15),
  (4, "Laurel", "Whimper", 3.0, 9),
  (5, "Hardy", "Laugh", 6.0, 18),
  (6, "Charley",  "Ignore",   5.0, 5)
).toDF("id", "name", "requisite", "money", "age")

rawDf.show(false)
rawDf.printSchema

val rawSchema = rawDf.schema

val fUdf = udf(reduceByMoney, rawSchema)

val nameUdf = udf(extractName, StringType)

val aggDf = rawDf
  .groupBy("age")
  .agg(
    count(struct("*")).as("count"),
    max(col("money")),
    collect_list(struct("*")).as("horizontal")
  )
  .withColumn("short", fUdf($"horizontal"))
  .withColumn("name", nameUdf($"short"))
  .drop("horizontal")

aggDf.printSchema

aggDf.show(false)

}

def reduceByMoney= (x: Any) => {

val d = x.asInstanceOf[mutable.WrappedArray[GenericRowWithSchema]]

val red = d.reduce((r1, r2) => {

  val money1 = r1.getAs[Double]("money")
  val money2 = r2.getAs[Double]("money")

  val r3 = money1 match {
    case a if a >= money2 =>
      r1
    case _ =>
      r2
  }

  r3
})

red
}

def extractName = (x: Any) => {

  val d = x.asInstanceOf[GenericRowWithSchema]

  d.getAs[String]("name")
}
}

Here is the output:

+---+-----+----------+----------------------------+-------+
|age|count|max(money)|short                       |name   |
+---+-----+----------+----------------------------+-------+
|5  |1    |5.0       |[6, Charley, Ignore, 5.0, 5]|Charley|
|15 |2    |5.0       |[3, Curly, Twist, 5.0, 15]  |Curly  |
|9  |1    |3.0       |[4, Laurel, Whimper, 3.0, 9]|Laurel |
|18 |2    |6.0       |[5, Hardy, Laugh, 6.0, 18]  |Hardy  |
+---+-----+----------+----------------------------+-------+

Upvotes: -1

Thirupathi Chavati

Reputation: 1871

Maybe this solution will be helpful.

from pyspark.sql import SQLContext
from pyspark import SparkContext, SparkConf
from pyspark.sql import functions as F
from pyspark.sql import Window

# Create the contexts the example relies on.
sc = SparkContext.getOrCreate()
sqlContext = SQLContext(sc)

name_list = [(101, 'abc', 24), (102, 'cde', 24), (103, 'efg', 22), (104, 'ghi', 21),
             (105, 'ijk', 20), (106, 'klm', 19), (107, 'mno', 18), (108, 'pqr', 18),
             (109, 'rst', 26), (110, 'tuv', 27), (111, 'pqr', 18), (112, 'rst', 28), (113, 'tuv', 29)]

# Window over all rows that share the same age.
age_w = Window.partitionBy("age")
name_age_df = sqlContext.createDataFrame(name_list, ['id', 'name', 'age'])

name_age_count_df = name_age_df.withColumn("count", F.count("id").over(age_w)).orderBy("count")
name_age_count_df.show()

Output:

+---+----+---+-----+
| id|name|age|count|
+---+----+---+-----+
|109| rst| 26|    1|
|113| tuv| 29|    1|
|110| tuv| 27|    1|
|106| klm| 19|    1|
|103| efg| 22|    1|
|104| ghi| 21|    1|
|105| ijk| 20|    1|
|112| rst| 28|    1|
|101| abc| 24|    2|
|102| cde| 24|    2|
|107| mno| 18|    3|
|111| pqr| 18|    3|
|108| pqr| 18|    3|
+---+----+---+-----+

Upvotes: 12

Swetha Kannan

Reputation: 236

One way to get all columns after a groupBy is to join the aggregated result back to the original DataFrame.

feature_group = ['name', 'age']
data_counts = df.groupBy(feature_group).count().alias("counts")
data_joined = df.join(data_counts, feature_group)

data_joined now has all the original columns plus a count column holding the size of each group.

Upvotes: 18

Sree Eedupuganti

Reputation: 440

You can do it like this:

Sample data:

name    age id
abc     24  1001
cde     24  1002
efg     22  1003
ghi     21  1004
ijk     20  1005
klm     19  1006
mno     18  1007
pqr     18  1008
rst     26  1009
tuv     27  1010
pqr     18  1012
rst     28  1013
tuv     29  1011

df.select("name","age","id").groupBy("name","age").count().show()

Output:

    +----+---+-----+
    |name|age|count|
    +----+---+-----+
    | efg| 22|    1|
    | tuv| 29|    1|
    | rst| 28|    1|
    | klm| 19|    1|
    | pqr| 18|    2|
    | cde| 24|    1|
    | tuv| 27|    1|
    | ijk| 20|    1|
    | abc| 24|    1|
    | mno| 18|    1|
    | ghi| 21|    1|
    | rst| 26|    1|
    +----+---+-----+

Upvotes: -6

zero323

Reputation: 330393

Long story short: in general, you have to join the aggregated results with the original table. Spark SQL follows the same pre-SQL:1999 convention as most of the major databases (PostgreSQL, Oracle, MS SQL Server), which doesn't allow additional columns in aggregation queries.

Since the results are not well defined for aggregations like count, and behavior tends to vary in systems that support this type of query, you can just include additional columns using an arbitrary aggregate like first or last.

In some cases you can replace agg with a select using window functions and a subsequent where, but depending on the context this can be quite expensive.

Upvotes: 56
