Programmer

Reputation: 418

spark dataframe aggregation of column based on condition in scala

I have CSV data in the following format.

I need to find the top 2 vendors whose turnover is greater than 100 in the year 2017.

Turnover = Sum(invoices whose status is Paid-in-Full) - Sum(invoices whose status is Exception or Rejected)

I have loaded the data from the CSV in a Databricks Scala notebook as follows:

val file_type = "csv"  // defined earlier in the notebook; shown here so the snippet is self-contained

val invoices_data = spark.read.format(file_type)
  .option("header", "true")
  .option("dateFormat", "M/d/yy")
  .option("inferSchema", "true")
  .load("invoice.csv")
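
Note: with inferSchema, InvoiceDate usually comes back as a plain string for this format (the dateFormat option only takes effect when the schema declares a date column), so I parse it after loading. A minimal sketch, assuming the column is a string:

    import org.apache.spark.sql.functions.{col, to_date}

    // Parse the M/d/yy strings into a proper DateType column
    val invoicesWithDates = invoices_data
      .withColumn("InvoiceDate", to_date(col("InvoiceDate"), "M/d/yy"))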

Then I tried to group by vendor name:

val avg_invoice_by_vendor = invoices_data.groupBy("VendorName")

But now I don't know how to proceed further.

Here is sample csv data.

Id  InvoiceDate  Status        Invoice  VendorName
2   2/23/17      Exception     23       V1
3   11/23/17     Paid-in-Full  56       V1
1   12/20/17     Paid-in-Full  12       V1
5   8/4/19       Paid-in-Full  123      V2
6   2/6/17       Paid-in-Full  237      V2
9   3/9/17       Rejected      234      V2
7   4/23/17      Paid-in-Full  78       V3
8   5/23/17      Exception     345      V4
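
For example, using the rows above, V2's 2017 turnover would be 237 - 234 = 3 (the 123 invoice falls in 2019, so it is excluded).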

Upvotes: 0

Views: 115

Answers (2)

Programmer

Reputation: 418

I used the pivot method to solve the above issue (the column names here, InvoiceStatusDesc, InvoiceVendorName, and InvoiceTotal, come from my actual dataset rather than the simplified sample above).

import org.apache.spark.sql.functions.{to_date, year}

invoices_data
  .filter(invoices_data("InvoiceStatusDesc") === "Paid-in-Full" ||
    invoices_data("InvoiceStatusDesc") === "Exception" ||
    invoices_data("InvoiceStatusDesc") === "Rejected")
  .filter(year(to_date(invoices_data("InvoiceDate"), "M/d/yy")) === 2017)
  .groupBy("InvoiceVendorName")
  .pivot("InvoiceStatusDesc")
  .sum("InvoiceTotal")
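
The pivot produces one column per status. A possible follow-up, assuming the result above is bound to a val named pivoted and the status columns come out as Exception, Paid-in-Full, and Rejected (null where a vendor has no invoices with that status), to compute the turnover and take the top 2 vendors:

    import org.apache.spark.sql.functions.{coalesce, col, lit}

    // Turnover = Paid-in-Full minus Exception and Rejected, treating missing statuses as 0
    val top2 = pivoted
      .withColumn("turnover",
        coalesce(col("Paid-in-Full"), lit(0L))
          - coalesce(col("Exception"), lit(0L))
          - coalesce(col("Rejected"), lit(0L)))
      .filter(col("turnover") > 100)
      .orderBy(col("turnover").desc)
      .limit(2)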

Upvotes: 1

Boris Azanov

Reputation: 4481

You can use a UDF to sign the invoice amount depending on its status, then group and aggregate the DataFrame using the sum function:

import org.apache.spark.sql.functions._
import spark.implicits._  // needed for the 'Status / 'Invoice symbol syntax

// Sign the invoice amount according to its status
def signInvoice: (String, Int) => Int = (status: String, invoice: Int) => {
  status match {
    case "Exception" | "Rejected" => -invoice
    case "Paid-in-Full" => invoice
    case _ => throw new IllegalStateException("wrong status")
  }
}

val signInvoiceUdf = spark.udf.register("signInvoice", signInvoice)

val top2_vendorsDF = invoices_data
  // a plain cast to DateType returns null for "M/d/yy" strings, so parse explicitly
  .withColumn("InvoiceDate", to_date(col("InvoiceDate"), "M/d/yy"))
  .filter(year(col("InvoiceDate")) === lit(2017))
  .withColumn("Invoice", col("Invoice").cast("int"))
  .groupBy("VendorName")
  .agg(sum(signInvoiceUdf('Status, 'Invoice)).as("sum_invoice"))
  .filter(col("sum_invoice") > 100)
  .orderBy(col("sum_invoice").desc)
  .take(2)  // note: take is an action and returns an Array[Row], not a DataFrame
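
As a side note, the same signing can be expressed without a UDF using when/otherwise, which keeps the logic visible to the Catalyst optimizer instead of hiding it in a black-box function. A minimal sketch using the sample's column names:

    import org.apache.spark.sql.functions.{col, sum, to_date, when, year}

    val top2Vendors = invoices_data
      .withColumn("InvoiceDate", to_date(col("InvoiceDate"), "M/d/yy"))
      .filter(year(col("InvoiceDate")) === 2017)
      .groupBy("VendorName")
      // Paid-in-Full counts positively, Exception/Rejected negatively
      .agg(sum(when(col("Status") === "Paid-in-Full", col("Invoice"))
        .otherwise(-col("Invoice"))).as("sum_invoice"))
      .filter(col("sum_invoice") > 100)
      .orderBy(col("sum_invoice").desc)
      .limit(2)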

Upvotes: 1
