When I use DataFrame groupBy to group by "age" and count "id", I only get a DataFrame with the columns "age" and "count(id)", but df has many other columns, such as "name".
In short, I want to get the same result as in MySQL:
"select name,age,count(id) from df group by age"
What should I do when using groupBy in Spark?
If you use PySpark 1.6.0 or later, you can use collect_set() or collect_list().
For example, in the case of your code, you can use:
from pyspark.sql import functions as F

df = df.groupBy('age').agg(F.count('id').alias('idCount'),
                           F.collect_set('name').alias('userName'),
                           F.collect_set('age').alias('userAge'))
Please note that collect_list() keeps duplicates in the result, while collect_set() removes them.
Using alias() avoids creating columns with duplicate or auto-generated names, which can cause further problems in some Spark versions or on some platforms.
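To make the difference between the two functions concrete, here is a plain-Python model (not Spark code) of groupBy('age').agg(count('id'), collect_list('name'), collect_set('name')) on a few sample (name, age, id) rows. The helper name group_by_age is hypothetical. Spark returns collect_set elements in no particular order, so the model sorts the set only to keep its output stable.

```python
from collections import defaultdict

# Sample (name, age, id) rows
rows = [("abc", 24, 1001), ("cde", 24, 1002), ("efg", 22, 1003), ("ghi", 21, 1004),
        ("ijk", 20, 1005), ("klm", 19, 1006), ("mno", 18, 1007), ("pqr", 18, 1008),
        ("rst", 26, 1009), ("tuv", 27, 1010), ("pqr", 18, 1012), ("rst", 28, 1013),
        ("tuv", 29, 1011)]


def group_by_age(rows):
    """Model of groupBy('age') with count('id'), collect_list('name') and collect_set('name')."""
    names = defaultdict(list)
    for name, age, _id in rows:
        names[age].append(name)
    return {
        age: {
            "idCount": len(ns),                 # every id here is non-null
            "userNames_list": ns,               # collect_list keeps duplicates
            "userNames_set": sorted(set(ns)),   # collect_set drops them (sorted only for display)
        }
        for age, ns in names.items()
    }


result = group_by_age(rows)
print(result[18])  # {'idCount': 3, 'userNames_list': ['mno', 'pqr', 'pqr'], 'userNames_set': ['mno', 'pqr']}
```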
Problem: in Spark Scala, using groupBy and max on a DataFrame returns a DataFrame with only the groupBy columns and the max column. How do I get all the columns, i.e. the columns that are not in the groupBy?
Solution: please go through the full example below to get all the columns with groupBy and max.
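import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._ //{col, lit, when, to_timestamp}
import org.apache.spark.sql.types._
import org.apache.spark.sql.Column
val spark = SparkSession.builder().appName("app-name").master("local[*]").getOrCreate()
import spark.implicits._
val simpleData = Seq(("James","Sales","NY",90000,34,10000), ("Michael","Sales","NY",86000,56,20000),
  ("Robert","Sales","CA",81000,30,23000), ("Maria","Finance","CA",90000,24,23000),
  ("Raman","Finance","CA",99000,40,24000), ("Scott","Finance","NY",83000,36,19000),
  ("Jen","Finance","NY",79000,53,15000), ("Jeff","Marketing","CA",80000,25,18000),
  ("Kumar","Marketing","NY",91000,50,21000))
val df = simpleData.toDF("employee_name","department","state","salary","age","bonus")
df.show()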
Running df.show() gives the following output:
|employee_name|department|state|salary|age|bonus|
| James| Sales| NY| 90000| 34|10000|
| Michael| Sales| NY| 86000| 56|20000|
| Robert| Sales| CA| 81000| 30|23000|
| Maria| Finance| CA| 90000| 24|23000|
| Raman| Finance| CA| 99000| 40|24000|
| Scott| Finance| NY| 83000| 36|19000|
| Jen| Finance| NY| 79000| 53|15000|
| Jeff| Marketing| CA| 80000| 25|18000|
| Kumar| Marketing| NY| 91000| 50|21000|
The code below gives output with unhelpful column names, but it can still be used:
val dfwithmax = df.groupBy("department").agg(max("salary"), first("employee_name"), first("state"), first("age"), first("bonus"))
dfwithmax.show()
|department|max(salary)|first(employee_name, false)|first(state, false)|first(age, false)|first(bonus, false)|
| Sales| 90000| James| NY| 34| 10000|
| Finance| 99000| Maria| CA| 24| 23000|
| Marketing| 91000| Jeff| CA| 25| 18000|
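Note that first() is evaluated independently of max(), so the other columns do not necessarily come from the row holding the maximum salary: Finance shows Maria next to 99000, but the 99000 salary belongs to Raman.
To make the column names appropriate, alias each aggregate with as:
val dfwithmax1 = df.groupBy("department").agg(max("salary") as "salary", first("employee_name") as "employee_name",
  first("state") as "state", first("age") as "age", first("bonus") as "bonus")
dfwithmax1.show()
|department|salary|employee_name|state|age|bonus|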
| Sales| 90000| James| NY| 34|10000|
| Finance| 99000| Maria| CA| 24|23000|
| Marketing| 91000| Jeff| CA| 25|18000|
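Because max() and first() are two independent reductions, the values in one output row need not come from the same input row. The plain-Python model below (not Spark code; max_and_first is a hypothetical helper) mimics agg(max("salary"), first("employee_name")) on the sample rows, assuming rows are visited in the listed order. Spark itself does not guarantee which row first() sees.

```python
# Sample rows (employee_name, department, salary) from the df above, in listed order
rows = [("James", "Sales", 90000), ("Michael", "Sales", 86000), ("Robert", "Sales", 81000),
        ("Maria", "Finance", 90000), ("Raman", "Finance", 99000), ("Scott", "Finance", 83000),
        ("Jen", "Finance", 79000), ("Jeff", "Marketing", 80000), ("Kumar", "Marketing", 91000)]


def max_and_first(rows):
    """Model of agg(max('salary'), first('employee_name')): two independent reductions per group."""
    result = {}
    for name, dept, salary in rows:
        if dept not in result:
            # first() keeps the first row it happens to see for this group
            result[dept] = {"max_salary": salary, "first_name": name}
        else:
            # max() keeps updating, regardless of which row first() picked
            result[dept]["max_salary"] = max(result[dept]["max_salary"], salary)
    return result


agg = max_and_first(rows)
print(agg["Finance"])  # {'max_salary': 99000, 'first_name': 'Maria'}
# The combined (name, department, salary) row does not exist in the input
print(("Maria", "Finance", 99000) in rows)  # False
```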
If you also want to change the order of the DataFrame's columns, you can do it as follows:
val reOrderedColumnName: Array[String] = Array("employee_name", "department", "state", "salary", "age", "bonus")
val orderedDf = dfwithmax1.select(reOrderedColumnName.head, reOrderedColumnName.tail: _*)
orderedDf.show()
Full code together:
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.Column
object test {
  def main(args: Array[String]): Unit = {
    /** spark session object */
    val spark = SparkSession.builder().appName("app-name").master("local[*]").getOrCreate()
    import spark.implicits._

    val simpleData = Seq(("James","Sales","NY",90000,34,10000), ("Michael","Sales","NY",86000,56,20000),
      ("Robert","Sales","CA",81000,30,23000), ("Maria","Finance","CA",90000,24,23000),
      ("Raman","Finance","CA",99000,40,24000), ("Scott","Finance","NY",83000,36,19000),
      ("Jen","Finance","NY",79000,53,15000), ("Jeff","Marketing","CA",80000,25,18000),
      ("Kumar","Marketing","NY",91000,50,21000))
    val df = simpleData.toDF("employee_name","department","state","salary","age","bonus")
    df.show()

    val dfwithmax = df.groupBy("department").agg(max("salary"), first("employee_name"), first("state"), first("age"), first("bonus"))
    dfwithmax.show()
    val dfwithmax1 = df.groupBy("department").agg(max("salary") as "salary", first("employee_name") as "employee_name",
      first("state") as "state", first("age") as "age", first("bonus") as "bonus")
    dfwithmax1.show()

    val reOrderedColumnName: Array[String] = Array("employee_name", "department", "state", "salary", "age", "bonus")
    val orderedDf = dfwithmax1.select(reOrderedColumnName.head, reOrderedColumnName.tail: _*)
    orderedDf.show()
    spark.stop()
  }
}
Full output (the two aggregated tables below come from a run that also aggregated first("department"), which is why department appears twice; referencing department afterwards triggers the exception described below):
|employee_name|department|state|salary|age|bonus|
| James| Sales| NY| 90000| 34|10000|
| Michael| Sales| NY| 86000| 56|20000|
| Robert| Sales| CA| 81000| 30|23000|
| Maria| Finance| CA| 90000| 24|23000|
| Raman| Finance| CA| 99000| 40|24000|
| Scott| Finance| NY| 83000| 36|19000|
| Jen| Finance| NY| 79000| 53|15000|
| Jeff| Marketing| CA| 80000| 25|18000|
| Kumar| Marketing| NY| 91000| 50|21000|
|department|max(salary)|first(employee_name, false)|first(department, false)|first(state, false)|first(age, false)|first(bonus, false)|
| Sales| 90000| James| Sales| NY| 34| 10000|
| Finance| 99000| Maria| Finance| CA| 24| 23000|
| Marketing| 91000| Jeff| Marketing| CA| 25| 18000|
|department|salary|employee_name|department|state|age|bonus|
| Sales| 90000| James| Sales| NY| 34|10000|
| Finance| 99000| Maria| Finance| CA| 24|23000|
| Marketing| 91000| Jeff| Marketing| CA| 25|18000|
Exception:
Exception in thread "main" org.apache.spark.sql.AnalysisException: Reference 'department' is ambiguous, could be: department, department.;
This means you have the department column twice: once because it is used in groupBy, and once because you also aggregated first("department") as "department".
For example (note the first("department") as "department" in the aggregation):
val dfwithmax1 = df.groupBy("department").agg(max("salary") as "salary", first("employee_name") as "employee_name",
  first("department") as "department", first("state") as "state", first("age") as "age", first("bonus") as "bonus")
This PySpark code selects the B value of the maximum [A, B] combination in each group (if several identical maxima exist in a group, any one of them is picked). In your case, group would be age, and B any of the columns you did not group by but nevertheless want to select.
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame([
    [1, 1, 0.2],
    [1, 1, 0.9],
    [1, 2, 0.6],
    [1, 2, 0.5],
    [1, 2, 0.6],
    [2, 1, 0.2],
    [2, 2, 0.1],
], ["group", "A", "B"])

out = (
    df
    .withColumn("AB", F.struct("A", "B"))
    .groupBy("group")
    # F.max(AB) selects AB-combinations with max `A`. If more
    # than one combination remains the one with max `B` is selected. If
    # after this identical combinations remain, a single one of them is picked
    # randomly.
    .agg(F.max("AB").alias("max_AB"))
    .select("group", F.expr("max_AB.B"))
)
out.show()
|group| B|
| 1|0.6|
| 2|0.1|
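The ordering that F.max relies on here (compare A first, then B) matches how Python compares tuples. The plain-Python model below (not Spark code; b_of_max_ab is a hypothetical helper) reproduces the output above for the same rows. It ignores nulls, which Spark's struct ordering treats separately.

```python
# Rows (group, A, B) from the example above
rows = [(1, 1, 0.2), (1, 1, 0.9), (1, 2, 0.6), (1, 2, 0.5), (1, 2, 0.6), (2, 1, 0.2), (2, 2, 0.1)]


def b_of_max_ab(rows):
    """For each group, return B of the lexicographically largest (A, B) pair."""
    best = {}
    for group, a, b in rows:
        # Tuples compare element by element: first by A, then by B on ties
        if group not in best or (a, b) > best[group]:
            best[group] = (a, b)
    return {group: ab[1] for group, ab in best.items()}


print(b_of_max_ab(rows))  # {1: 0.6, 2: 0.1}
```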
Here is an example that I came across in spark-workshop:
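import org.apache.spark.sql.functions.{col, max, regexp_replace}
import spark.implicits._ // enables the 'name symbol syntax
val populationDF = spark.read
  .option("inferSchema", "true") // Spark's CSV reader does not recognize "infer-schema"
  .option("header", "true")
  .csv(path) // path: the workshop's CSV file (not given here)
  .select('name, regexp_replace(col("population"), "\\s", "").cast("integer").as("population"))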
val maxPopulationDF = populationDF.agg(max('population).as("populationmax"))
To get the other columns, I do a simple join between the original DataFrame and the aggregated one:
populationDF.join(maxPopulationDF,populationDF.col("population") === maxPopulationDF.col("populationmax")).select('name, 'populationmax).show()
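The same join-back idea works per group: aggregate per key, then join on both the key and the aggregated value. The plain-Python model below (not Spark code; max_salary_rows is a hypothetical helper) does this for the maximum salary per department on the employee sample data from earlier in this thread. Unlike first(), it keeps every row that attains the maximum, so ties yield several rows. In Spark, this would roughly correspond to joining df with the aggregated DataFrame on Seq("department", "salary") after aliasing max(salary) as salary (an untested sketch).

```python
# Sample rows (employee_name, department, salary)
rows = [("James", "Sales", 90000), ("Michael", "Sales", 86000), ("Robert", "Sales", 81000),
        ("Maria", "Finance", 90000), ("Raman", "Finance", 99000), ("Scott", "Finance", 83000),
        ("Jen", "Finance", 79000), ("Jeff", "Marketing", 80000), ("Kumar", "Marketing", 91000)]


def max_salary_rows(rows):
    # Step 1: aggregate, like groupBy("department").agg(max("salary"))
    max_by_dept = {}
    for _name, dept, salary in rows:
        max_by_dept[dept] = max(salary, max_by_dept.get(dept, salary))
    # Step 2: "join" back on (department, salary); every matching original row survives
    return [(name, dept, salary) for name, dept, salary in rows if salary == max_by_dept[dept]]


print(max_salary_rows(rows))
# [('James', 'Sales', 90000), ('Raman', 'Finance', 99000), ('Kumar', 'Marketing', 91000)]
```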
Aggregate functions reduce the values of the specified columns within each group. If you wish to retain other row values, you need to implement reduction logic that specifies which row each value comes from, for instance, keep all values of the row with the most money in each age group. To this end, you can use a UDAF (user-defined aggregate function) to reduce the rows within the group. (UserDefinedAggregateFunction is deprecated since Spark 3.0 in favor of an Aggregator registered with functions.udaf, but it still works.)
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

object AggregateKeepingRowJob {
  def main (args: Array[String]): Unit = {
    val sparkSession = SparkSession
      .builder()
      .master("local[*]") // local mode, assumed here so the example runs standalone
      .appName(this.getClass.getName.replace("$", ""))
      .getOrCreate()
    import sparkSession.sqlContext.implicits._

    val rawDf = Seq(
      (1L, "Moe", "Slap", 2.0, 18),
      (2L, "Larry", "Spank", 3.0, 15),
      (3L, "Curly", "Twist", 5.0, 15),
      (4L, "Laurel", "Whimper", 3.0, 15),
      (5L, "Hardy", "Laugh", 6.0, 15),
      (6L, "Charley", "Ignore", 5.0, 5)
    ).toDF("id", "name", "requisite", "money", "age")

    // Per age group, keep (name, requisite, money, age) of the row with the most money
    val maxAmtUdaf = new KeepRowWithMaxAmt
    val aggDf = rawDf
      .groupBy("age")
      .agg(maxAmtUdaf($"name", $"requisite", $"money", $"age").as("best"))
      .select($"age", $"best.*") // expand the returned struct into columns
    aggDf.show()
    sparkSession.stop()
  }
}

class KeepRowWithMaxAmt extends UserDefinedAggregateFunction {
  // These are the input fields for your aggregate function.
  override def inputSchema: StructType = StructType(
    StructField("store", StringType) ::
    StructField("prod", StringType) ::
    StructField("amt", DoubleType) ::
    StructField("units", IntegerType) :: Nil)

  // These are the internal fields you keep while computing the aggregate (the best row so far).
  override def bufferSchema: StructType = StructType(
    StructField("store", StringType) ::
    StructField("prod", StringType) ::
    StructField("amt", DoubleType) ::
    StructField("units", IntegerType) :: Nil)

  // This is the output type of your aggregation function.
  override def dataType: DataType = StructType(Seq(
    StructField("store", StringType),
    StructField("prod", StringType),
    StructField("amt", DoubleType),
    StructField("units", IntegerType)))

  override def deterministic: Boolean = true

  // Initial buffer value: no row seen yet, so every field is null.
  override def initialize(buffer: MutableAggregationBuffer): Unit =
    (0 until 4).foreach(i => buffer(i) = null)

  // Copies candidate into buffer if its amt is non-null and strictly larger than the buffer's amt.
  private def keepIfLarger(buffer: MutableAggregationBuffer, candidate: Row): Unit =
    if (!candidate.isNullAt(2) && (buffer.isNullAt(2) || buffer.getDouble(2) < candidate.getDouble(2))) {
      (0 until 4).foreach(i => buffer(i) = candidate.get(i))
    }

  // This is how to update your buffer given an input row.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = keepIfLarger(buffer, input)

  // Merging two partial buffers needs the same comparison; blindly copying buffer2 would lose the maximum.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = keepIfLarger(buffer1, buffer2)

  // This is where you output the final value, given the final value of your buffer.
  override def evaluate(buffer: Row): Any = Row.fromSeq(buffer.toSeq)
}
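Spark runs update() within each partition and then combines the partial buffers with merge(), so merge has to apply the same comparison as update. The plain-Python model below (not Spark code; keep_richer and aggregate are hypothetical helpers) splits the age-15 rows of rawDf into two hypothetical partitions and shows that a merge which simply copies the second buffer can lose the maximum.

```python
# Rows (name, money, age) from rawDf above
rows = [("Moe", 2.0, 18), ("Larry", 3.0, 15), ("Curly", 5.0, 15),
        ("Laurel", 3.0, 15), ("Hardy", 6.0, 15), ("Charley", 5.0, 5)]


def keep_richer(buffer, row):
    """update/merge step: replace the buffer only if the candidate has strictly more money."""
    if buffer is None or row[1] > buffer[1]:
        return row
    return buffer


def aggregate(partitions, merge):
    # Each partition builds its own partial buffer with the update step...
    partials = []
    for part in partitions:
        buf = None
        for row in part:
            buf = keep_richer(buf, row)
        partials.append(buf)
    # ...and the partial buffers are then combined with the merge step
    result = partials[0]
    for buf in partials[1:]:
        result = merge(result, buf)
    return result


age_15 = [r for r in rows if r[2] == 15]
partitions = [age_15[2:], age_15[:2]]  # the partition holding Hardy is merged first
print(aggregate(partitions, keep_richer))        # ('Hardy', 6.0, 15)
print(aggregate(partitions, lambda b1, b2: b2))  # ('Curly', 5.0, 15): overwriting loses the max
```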
You need to remember that aggregate functions reduce the rows, so you have to specify which row's name you want with a reducing function. If you want to retain all rows of a group (warning: this can cause memory explosions or skewed partitions), you can collect them as a list. You can then use a UDF (user-defined function) to reduce them by your criterion, in my example money, and then extract columns from the single reduced row with another UDF. For the purpose of this answer, I assume you wish to retain the name of the person who has the most money.
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.StringType

object TestJob3 {
  def main (args: Array[String]): Unit = {
    val sparkSession = SparkSession
      .builder()
      .master("local[*]") // local mode, assumed here so the example runs standalone
      .appName(this.getClass.getName.replace("$", ""))
      // Spark 3.x rejects the untyped udf(f, dataType) overload used below unless this flag is set
      .config("spark.sql.legacy.allowUntypedScalaUDF", "true")
      .getOrCreate()
    import sparkSession.sqlContext.implicits._

    val rawDf = Seq(
      (1, "Moe", "Slap", 2.0, 18),
      (2, "Larry", "Spank", 3.0, 15),
      (3, "Curly", "Twist", 5.0, 15),
      (4, "Laurel", "Whimper", 3.0, 9),
      (5, "Hardy", "Laugh", 6.0, 18),
      (6, "Charley", "Ignore", 5.0, 5)
    ).toDF("id", "name", "requisite", "money", "age")

    val rawSchema = rawDf.schema
    // One UDF reduces a list of rows to the richest row, the other extracts that row's name
    val fUdf = udf(reduceByMoney, rawSchema)
    val nameUdf = udf(extractName, StringType)

    val aggDf = rawDf
      .groupBy("age")
      .agg(
        count("id").as("count"),
        max("money"),
        // collect every row of the group as a list of structs
        collect_list(struct(rawDf.columns.map(col): _*)).as("horizontal"))
      .withColumn("short", fUdf($"horizontal"))
      .withColumn("name", nameUdf($"short"))
      .drop("horizontal")
    aggDf.show(false)
    sparkSession.stop()
  }

  // Keeps the row with more money (the left one on ties)
  def reduceByMoney = (x: Any) => {
    val d = x.asInstanceOf[scala.collection.Seq[Row]]
    d.reduce((r1, r2) => {
      val money1 = r1.getAs[Double]("money")
      val money2 = r2.getAs[Double]("money")
      money1 match {
        case a if a >= money2 => r1
        case _ => r2
      }
    })
  }

  // Reads the "name" field of the reduced row
  def extractName = (x: Any) => {
    val d = x.asInstanceOf[Row]
    d.getAs[String]("name")
  }
}
Here is the output:
|age|count|max(money)|short |name |
|5 |1 |5.0 |[6, Charley, Ignore, 5.0, 5]|Charley|
|15 |2 |5.0 |[3, Curly, Twist, 5.0, 15] |Curly |
|9 |1 |3.0 |[4, Laurel, Whimper, 3.0, 9]|Laurel |
|18 |2 |6.0 |[5, Hardy, Laugh, 6.0, 18] |Hardy |
Maybe this solution will help: instead of groupBy, use a window partitioned by age, so every row keeps all its columns and gets the count of its age group.
from pyspark.sql import SparkSession, Window
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()

name_list = [(101, 'abc', 24), (102, 'cde', 24), (103, 'efg', 22), (104, 'ghi', 21),
             (105, 'ijk', 20), (106, 'klm', 19), (107, 'mno', 18), (108, 'pqr', 18),
             (109, 'rst', 26), (110, 'tuv', 27), (111, 'pqr', 18), (112, 'rst', 28), (113, 'tuv', 29)]
age_w = Window.partitionBy("age")
name_age_df = spark.createDataFrame(name_list, ['id', 'name', 'age'])
name_age_count_df = name_age_df.withColumn("count", F.count("id").over(age_w)).orderBy("count")
name_age_count_df.show()
| id|name|age|count|
|109| rst| 26| 1|
|113| tuv| 29| 1|
|110| tuv| 27| 1|
|106| klm| 19| 1|
|103| efg| 22| 1|
|104| ghi| 21| 1|
|105| ijk| 20| 1|
|112| rst| 28| 1|
|101| abc| 24| 2|
|102| cde| 24| 2|
|107| mno| 18| 3|
|111| pqr| 18| 3|
|108| pqr| 18| 3|
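A window aggregate, unlike groupBy, does not collapse rows. This plain-Python model (not Spark code; with_age_count is a hypothetical helper) of count("id").over(Window.partitionBy("age")) keeps every input row and appends its group's count, matching the counts in the table above.

```python
from collections import Counter

name_list = [(101, 'abc', 24), (102, 'cde', 24), (103, 'efg', 22), (104, 'ghi', 21),
             (105, 'ijk', 20), (106, 'klm', 19), (107, 'mno', 18), (108, 'pqr', 18),
             (109, 'rst', 26), (110, 'tuv', 27), (111, 'pqr', 18), (112, 'rst', 28), (113, 'tuv', 29)]


def with_age_count(rows):
    """Model of withColumn('count', count('id').over(Window.partitionBy('age')))."""
    counts = Counter(age for _id, _name, age in rows)
    # Every input row is kept; it simply gains the size of its age group
    return [(id_, name, age, counts[age]) for id_, name, age in rows]


result = with_age_count(name_list)
print(len(result))                        # 13
print([r for r in result if r[2] == 18])  # [(107, 'mno', 18, 3), (108, 'pqr', 18, 3), (111, 'pqr', 18, 3)]
```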
One way to get all columns after doing a groupBy is to join the aggregated result back to the original DataFrame:
feature_group = ['name', 'age']
data_counts = df.groupBy(feature_group).count().alias("counts")
data_joined = df.join(data_counts, feature_group)
data_joined now has all the original columns plus the count column.
You can do it like this:
Sample data:
name age id
abc 24 1001
cde 24 1002
efg 22 1003
ghi 21 1004
ijk 20 1005
klm 19 1006
mno 18 1007
pqr 18 1008
rst 26 1009
tuv 27 1010
pqr 18 1012
rst 28 1013
tuv 29 1011
df.select("name","age","id").groupBy("name","age").count().show();
Output:
|name|age|count|
| efg| 22| 1|
| tuv| 29| 1|
| rst| 28| 1|
| klm| 19| 1|
| pqr| 18| 2|
| cde| 24| 1|
| tuv| 27| 1|
| ijk| 20| 1|
| abc| 24| 1|
| mno| 18| 1|
| ghi| 21| 1|
| rst| 26| 1|
Long story short, in general you have to join the aggregated results with the original table. Spark SQL follows the same pre-SQL:1999 convention as most of the major databases (PostgreSQL, Oracle, MS SQL Server), which doesn't allow additional columns in aggregation queries.
Since for aggregations like count the values of the additional columns are not well defined, and behavior tends to vary between the systems that do support this type of query, you can just include additional columns using an arbitrary aggregate like first or last.
In some cases you can replace agg with a select using window functions followed by a where filter, but depending on the context this can be quite expensive.
