Reputation: 847
I have a Scala Spark dataframe with the schema:
root
|-- passengerId: string (nullable = true)
|-- travelHist: array (nullable = true)
| |-- element: integer (containsNull = true)
I want to iterate through the array elements and find, for each passenger, the maximum number of consecutive 0 values that occur between a 1 and a 2.
| passengerID | travelHist |
| --- | --- |
| 1 | 1, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0 |
| 2 | 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 2, 0, 0, 0, 0 |
| 3 | 0, 0, 0, 2, 1, 0, 2, 1, 0 |
The output for the above records should look like this:
| passengerID | maxStreak |
| --- | --- |
| 1 | 7 |
| 2 | 3 |
| 3 | 1 |
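For reference, the example data above can be recreated with something like this (a minimal sketch, assuming an active SparkSession named spark; passengerId is kept as a string to match the schema):

import spark.implicits._

val df = Seq(
  ("1", Seq(1, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0)),
  ("2", Seq(0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 2, 0, 0, 0, 0)),
  ("3", Seq(0, 0, 0, 2, 1, 0, 2, 1, 0))
).toDF("passengerId", "travelHist")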
What would be the most efficient way to find such a streak, assuming the number of elements in the array does not exceed 50?
Upvotes: 1
Views: 121
Reputation: 2468
Here's a solution using a Scala UDF in PySpark. You can find the code for the UDF and the release jar used in the PySpark script in the following repository:
https://github.com/dineshdharme/pyspark-native-udfs
The code for the Scala UDF is as follows:
package com.help.udf

import org.apache.spark.sql.api.java.UDF1

class CountZeros extends UDF1[Array[Int], Int] {

  override def call(given_array: Array[Int]): Int = {
    // longest run of zeros seen between a 1 and a 2; -1 if no such run exists
    var maxCount = -1
    var runningCount = 0
    var insideLoop = false
    for (ele <- given_array) {
      if (ele == 1) {
        // a 1 opens a new candidate run
        runningCount = 0
        insideLoop = true
      } else if (ele == 0 && insideLoop) {
        runningCount += 1
      } else if (ele == 2 && insideLoop) {
        // a 2 closes the run; keep the longest one seen so far
        insideLoop = false
        if (runningCount > maxCount) {
          maxCount = runningCount
        }
      }
    }
    maxCount
  }
}
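Since the question is about Scala Spark, the same compiled class can also be registered and called directly from Scala. A minimal sketch, assuming the assembly jar containing com.help.udf.CountZeros is on the classpath and the question's DataFrame is called df:

import org.apache.spark.sql.types.IntegerType
import com.help.udf.CountZeros

// register the Java-style UDF under a SQL name, then call it from Spark SQL
spark.udf.register("count_zeros_udf", new CountZeros(), IntegerType)
df.createOrReplaceTempView("given_table")
val result = spark.sql("select *, count_zeros_udf(travelHist) as maxStreak from given_table")
result.show(truncate = false)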
Following is the PySpark code that uses the above UDF:
import pyspark.sql.functions as F
from pyspark import SQLContext
from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType

spark = SparkSession.builder \
    .appName("MyApp") \
    .config("spark.jars", "file:/path/to/pyspark-native-udfs/releases/pyspark-native-udfs-assembly-0.1.2.jar") \
    .getOrCreate()

sc = spark.sparkContext
sqlContext = SQLContext(sc)

data1 = [
    [1, [1, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0]],
    [2, [0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 2, 0, 0, 0, 0]],
    [3, [0, 0, 0, 2, 1, 0, 2, 1, 0]],
]
df1Columns = ["passengerID", "travelHist"]
df1 = sqlContext.createDataFrame(data=data1, schema=df1Columns)
df1 = df1.withColumn("travelHist", F.col("travelHist").cast("array<int>"))
df1.show(n=100, truncate=False)
df1.printSchema()

# register the compiled Scala UDF from the jar under a SQL-callable name
spark.udf.registerJavaFunction("count_zeros_udf", "com.help.udf.CountZeros", IntegerType())

df1.createOrReplaceTempView("given_table")
df1_array = sqlContext.sql("select *, count_zeros_udf(travelHist) as maxStreak from given_table")

print("Dataframe after applying SCALA NATIVE UDF")
df1_array.show(n=100, truncate=False)
Output:
+-----------+------------------------------------------------------+
|passengerID|travelHist |
+-----------+------------------------------------------------------+
|1 |[1, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0] |
|2 |[0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 2, 0, 0, 0, 0]|
|3 |[0, 0, 0, 2, 1, 0, 2, 1, 0] |
+-----------+------------------------------------------------------+
root
|-- passengerID: long (nullable = true)
|-- travelHist: array (nullable = true)
| |-- element: integer (containsNull = true)
Dataframe after applying SCALA NATIVE UDF
+-----------+------------------------------------------------------+---------+
|passengerID|travelHist |maxStreak|
+-----------+------------------------------------------------------+---------+
|1 |[1, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0] |7 |
|2 |[0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 2, 0, 0, 0, 0]|3 |
|3 |[0, 0, 0, 2, 1, 0, 2, 1, 0] |1 |
+-----------+------------------------------------------------------+---------+
Upvotes: 2
Reputation: 71689
Let us do some pattern matching: join the array into a single string, pull out every run of zeros that sits between a 1 and a 2 with regexp_extract_all (available in Spark 3.1+), and take the length of the longest run.
df1 = (
    df
    .withColumn('matches', F.expr("array_join(travelHist, '')"))
    .withColumn('matches', F.expr("regexp_extract_all(matches, '1(0+)2', 1)"))
    .withColumn('matches', F.expr("transform(matches, x -> length(x))"))
    .withColumn('maxStreak', F.expr("array_max(matches)"))
)
df1.show()
+-----------+--------------------+-------+---------+
|passengerID| travelHist|matches|maxStreak|
+-----------+--------------------+-------+---------+
| 1|[1, 0, 0, 0, 0, 2...| [4, 7]| 7|
| 2|[0, 0, 0, 0, 0, 0...| [3]| 3|
| 3|[0, 0, 0, 2, 1, 0...| [1]| 1|
+-----------+--------------------+-------+---------+
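The same expression chain carries over to Scala Spark unchanged, since all of the work happens inside SQL expressions. A minimal sketch, assuming the question's DataFrame is called df and a Spark version that has regexp_extract_all (3.1+):

import org.apache.spark.sql.functions.expr

// join the digits into a string, capture every run of zeros between a 1 and a 2,
// map each capture to its length, then keep the longest run
val result = df
  .withColumn("matches", expr("array_join(travelHist, '')"))
  .withColumn("matches", expr("regexp_extract_all(matches, '1(0+)2', 1)"))
  .withColumn("matches", expr("transform(matches, x -> length(x))"))
  .withColumn("maxStreak", expr("array_max(matches)"))

result.select("passengerId", "maxStreak").show()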
Upvotes: 3