

Spark Array column - Find max interval between two values

I have a Scala Spark dataframe with the schema:

     |-- passengerId: string (nullable = true)
     |-- travelHist: array (nullable = true)
     |    |-- element: integer (containsNull = true)

I want to iterate through the array elements and find the maximum number of consecutive 0 values that occur between a 1 and the following 2.

|passengerID|travelHist                                            |
|1          |[1, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0]   |
|2          |[0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 2, 0, 0, 0, 0]|
|3          |[0, 0, 0, 2, 1, 0, 2, 1, 0]                           |

The output for the records above should look like this:

|passengerID|maxStreak|
|1          |7        |
|2          |3        |
|3          |1        |

What would be the most efficient way to find this interval, given that an array never contains more than 50 elements?
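A minimal pure-Python sketch of the definition can help check any Spark implementation against the expected output above. It is not taken from either answer: it run-length encodes the array with `itertools.groupby` and keeps each run of 0s whose neighbouring runs are a 1 and a 2. The name `max_streak` is hypothetical, and returning `None` when no such interval exists is an assumption.

```python
from itertools import groupby
from typing import List, Optional


def max_streak(travel_hist: List[int]) -> Optional[int]:
    """Longest run of 0s sitting directly between a 1 and a 2 (None if there is none)."""
    # Run-length encode the array: [1, 0, 0, 2] -> [(1, 1), (0, 2), (2, 1)]
    runs = [(value, len(list(group))) for value, group in groupby(travel_hist)]
    # Slide over consecutive triples of runs and keep the 0-run length of every 1 / 0s / 2 pattern
    streaks = [
        zeros
        for (left, _), (mid, zeros), (right, _) in zip(runs, runs[1:], runs[2:])
        if left == 1 and mid == 0 and right == 2
    ]
    # Hypothetical convention: None when no complete 1 ... 2 interval exists
    return max(streaks) if streaks else None


print(max_streak([1, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0]))  # 7
print(max_streak([0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 2, 0, 0, 0, 0]))  # 3
print(max_streak([0, 0, 0, 2, 1, 0, 2, 1, 0]))  # 1
```

Each array is scanned a constant number of times, so the cost is linear in its length, which is negligible for arrays of at most 50 elements.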


Answers (2)



Here's a solution that calls a Scala UDF from PySpark. The UDF source code and the release jar used by the PySpark script are available in a separate repository.


The Scala UDF code is as follows:

package com.help.udf

import org.apache.spark.sql.api.java.UDF1

// Returns the longest run of 0s found between a 1 and the next 2,
// or -1 if the array never contains a complete 1 ... 2 interval.
// Spark hands an array<int> column to a Java-API UDF as a scala.collection.Seq,
// not as a JVM Array, so the input type is declared as Seq[Int].
class CountZeros extends UDF1[scala.collection.Seq[Int], Int] {

  override def call(givenArray: scala.collection.Seq[Int]): Int = {
    if (givenArray == null) return -1

    var maxCount = -1         // stays -1 if no interval is ever closed by a 2
    var runningCount = 0
    var insideLoop = false    // true after a 1 has been seen and before the closing 2

    for (ele <- givenArray) {
      if (ele == 1) {
        // A 1 opens (or reopens) an interval, so the count restarts at 0
        runningCount = 0
        insideLoop = true
      } else if (ele == 0 && insideLoop) {
        runningCount += 1
      } else if (ele == 2 && insideLoop) {
        // A 2 closes the interval; keep the longest run seen so far
        insideLoop = false
        maxCount = math.max(maxCount, runningCount)
      }
    }

    maxCount
  }
}
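Because debugging inside a JVM UDF is awkward, the state machine can be checked locally before building the jar. The following is a hypothetical Python port (`count_zeros` is an illustrative name, not part of the answer's code) that mirrors the intended logic of `CountZeros.call`, including the -1 sentinel returned when no 1 ... 2 interval is ever closed.

```python
from typing import List


def count_zeros(given_array: List[int]) -> int:
    """Python port of the CountZeros UDF: longest 0-run between a 1 and the next 2, or -1."""
    max_count = -1        # stays -1 if no interval is ever closed by a 2
    running_count = 0
    inside_loop = False   # True after a 1 and before the closing 2

    for ele in given_array:
        if ele == 1:
            # A 1 opens (or reopens) an interval, so the count restarts
            running_count = 0
            inside_loop = True
        elif ele == 0 and inside_loop:
            running_count += 1
        elif ele == 2 and inside_loop:
            # A 2 closes the interval; keep the longest run seen so far
            inside_loop = False
            max_count = max(max_count, running_count)

    return max_count


for hist in ([1, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 2, 0, 0, 0, 0],
             [0, 0, 0, 2, 1, 0, 2, 1, 0]):
    print(count_zeros(hist))
```

Note the edge cases this logic implies: a trailing `1, 0` with no closing 2 (as at the end of passenger 1's history) is ignored, and `[1, 2]` yields 0 rather than -1.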

The following PySpark code registers and calls the UDF above:

import pyspark.sql.functions as F
from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType

# Make the jar containing com.help.udf.CountZeros available to the session
spark = SparkSession.builder \
    .appName("MyApp") \
    .config("spark.jars", "file:/path/to/pyspark-native-udfs/releases/pyspark-native-udfs-assembly-0.1.2.jar") \
    .getOrCreate()

data1 = [
    [1, [1, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0]],
    [2, [0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 2, 0, 0, 0, 0]],
    [3, [0, 0, 0, 2, 1, 0, 2, 1, 0]],
]

df1Columns = ["passengerID", "travelHist"]
df1 = spark.createDataFrame(data=data1, schema=df1Columns)
# Python ints are inferred as bigint; cast so the array matches the UDF's Int input
df1 = df1.withColumn("travelHist", F.col("travelHist").cast("array<int>"))

df1.show(n=100, truncate=False)
df1.printSchema()

# Register the Scala class as a SQL function and expose the DataFrame as a view
spark.udf.registerJavaFunction("count_zeros_udf", "com.help.udf.CountZeros", IntegerType())
df1.createOrReplaceTempView("given_table")

df1_array = spark.sql("select *, count_zeros_udf(travelHist) as maxStreak from given_table")
print("Dataframe after applying SCALA NATIVE UDF")
df1_array.show(n=100, truncate=False)

Output:

|passengerID|travelHist                                            |
|1          |[1, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0]   |
|2          |[0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 2, 0, 0, 0, 0]|
|3          |[0, 0, 0, 2, 1, 0, 2, 1, 0]                           |

 |-- passengerID: long (nullable = true)
 |-- travelHist: array (nullable = true)
 |    |-- element: integer (containsNull = true)

Dataframe after applying SCALA NATIVE UDF
|passengerID|travelHist                                            |maxStreak|
|1          |[1, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0]   |7        |
|2          |[0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 2, 0, 0, 0, 0]|3        |
|3          |[0, 0, 0, 2, 1, 0, 2, 1, 0]                           |1        |


Shubham Sharma


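Let's use pattern matching: join the array into a string, extract every run of 0s that sits between a 1 and a 2 with a regular expression, then take the length of the longest run. (`regexp_extract_all` requires Spark 3.1 or later.)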

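import pyspark.sql.functions as F
# df: the question's DataFrame (passengerID, travelHist); the name is assumed
df1 = (
    df
    .withColumn('matches', F.expr("array_join(travelHist, '')"))  # [1, 0, 2] -> '102'
    .withColumn('matches', F.expr("regexp_extract_all(matches, '1(0+)2', 1)"))  # captured 0-runs
    .withColumn('matches', F.expr("transform(matches, x -> length(x))"))  # run lengths
    .withColumn('maxStreak', F.expr("array_max(matches)"))  # longest run
)
df1.show()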

|passengerID|          travelHist|matches|maxStreak|
|          1|[1, 0, 0, 0, 0, 2...| [4, 7]|        7|
|          2|[0, 0, 0, 0, 0, 0...|    [3]|        3|
|          3|[0, 0, 0, 2, 1, 0...|    [1]|        1|
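The same string trick can be tried outside Spark. This is a minimal sketch with Python's `re` module, assuming (as the Spark version implicitly does) that every element is a single digit, so joining with an empty separator is unambiguous; `max_streak_regex` is a hypothetical name. `re.findall` with one capture group returns the captured text of each non-overlapping match, much like `regexp_extract_all(..., 1)`.

```python
import re
from typing import List, Optional


def max_streak_regex(travel_hist: List[int]) -> Optional[int]:
    """Mirror of array_join -> regexp_extract_all -> transform(length) -> array_max."""
    joined = "".join(str(x) for x in travel_hist)  # [1, 0, 0, 2] -> "1002"
    runs = re.findall(r"1(0+)2", joined)           # captured 0-runs, e.g. ["00"]
    lengths = [len(run) for run in runs]
    # array_max over an empty array is NULL in Spark; None plays that role here
    return max(lengths) if lengths else None


print(max_streak_regex([1, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0]))  # 7
```

The single-digit assumption matters: a value such as 10 would be joined as "10" and could create false matches. Also, because the pattern requires at least one 0, `[1, 2]` produces no match (NULL in Spark), whereas the Scala UDF in the other answer returns 0 for it.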

