wu wei

Reputation: 89

How to create a new column 'count' in a Spark DataFrame under some condition

I have a DataFrame of connection logs with columns ID, targetIP, and Time. Every record in this DataFrame is a connection event to one system. ID identifies the connection, targetIP is the target IP address of that connection, and Time is the connection time. The values are:

ID  Time  targetIP
1   1     192.163.0.1
2   2     192.163.0.2
3   3     192.163.0.1
4   5     192.163.0.1
5   6     192.163.0.2
6   7     192.163.0.2
7   8     192.163.0.2

I want to create a new column under some condition: for each record, the count of connections to that record's target IP address within the past 2 time units. So the resulting DataFrame should be:

ID  Time  targetIP     count
1   1     192.163.0.1  0
2   2     192.163.0.2  0
3   3     192.163.0.1  1
4   5     192.163.0.1  1
5   6     192.163.0.2  0
6   7     192.163.0.2  1
7   8     192.163.0.2  2

For example, take ID=7: its targetIP is 192.163.0.2. Within the past 2 time units there are two connections to the system, ID=5 and ID=6, and their targetIP is also 192.163.0.2, so the count for ID=7 is 2.
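
For reproducibility, here is a minimal sketch (Scala) of how the sample data above could be built; the column names ID, Time, targetIP match the tables, and it assumes a SparkSession named spark with implicits available:

import spark.implicits._

// sample data from the tables above, as a DataFrame with columns ID, Time, targetIP
val df = Seq(
  (1, 1, "192.163.0.1"),
  (2, 2, "192.163.0.2"),
  (3, 3, "192.163.0.1"),
  (4, 5, "192.163.0.1"),
  (5, 6, "192.163.0.2"),
  (6, 7, "192.163.0.2"),
  (7, 8, "192.163.0.2")
).toDF("ID", "Time", "targetIP")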

Looking forward to your help.

Upvotes: 3

Views: 2357

Answers (2)

asm0dey

Reputation: 2931

So, what you basically need is a window function.

Let's start with your initial data:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import spark.implicits._

case class Event(ID: Int, Time: Int, targetIP: String)

val events = Seq(
    Event(1, 1, "192.163.0.1"),
    Event(2, 2, "192.163.0.2"),
    Event(3, 3, "192.163.0.1"),
    Event(4, 5, "192.163.0.1"),
    Event(5, 6, "192.163.0.2"),
    Event(6, 7, "192.163.0.2"),
    Event(7, 8, "192.163.0.2")
).toDS()

Now we need to define the window function itself:

// take the two preceding rows (ordered by Time), excluding the current row
val timeWindow = Window.orderBy($"Time").rowsBetween(-2, -1)

And now the most interesting part: how do we count something over the window? There is no simple way, so we'll do the following:

  1. Aggregate all the targetIP values into a list
  2. Filter the list to keep only the IPs equal to the current row's targetIP
  3. Count the size of the filtered list

val df = events
        .withColumn("tmp", collect_list($"targetIP").over(timeWindow))
        .withColumn("count", size(expr("filter(tmp, x -> x == targetIP)")))
        .drop($"tmp")

And the result will contain a new column "count" which we need!
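
If you are on Spark 3.0+, the same filter step can also be written with the Column-based higher-order filter function instead of a SQL expr. This is just a sketch of the same idea (not a different algorithm), reusing the events and timeWindow defined above:

// same approach: collect the targetIPs seen in the window, keep only the ones
// equal to the current row's targetIP, then take the size of the filtered list
val df2 = events
        .withColumn("tmp", collect_list($"targetIP").over(timeWindow))
        .withColumn("count", size(filter($"tmp", x => x === $"targetIP")))
        .drop("tmp")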

UPD:

There is a much shorter version without aggregation, suggested by @blackbishop:

val timeWindow = Window.partitionBy($"targetIP").orderBy($"Time").rangeBetween(-2, Window.currentRow)
// the range window includes the current row itself, hence the "- lit(1)"
val df = events
        .withColumn("count", count("*").over(timeWindow) - lit(1))

Upvotes: 3

blackbishop

Reputation: 32700

You can use count over a Window bounded by the range between -2 and the current row to get the count of connections to the same IP within the last 2 time units.

Using Spark SQL you can do something like this:

df.createOrReplaceTempView("connection_logs")

df1 = spark.sql("""
    SELECT  *,
            COUNT(*) OVER(PARTITION BY targetIP ORDER BY Time 
                          RANGE BETWEEN 2 PRECEDING AND CURRENT ROW
                          ) -1 AS count
    FROM    connection_logs
    ORDER BY ID
""")

df1.show()

#+---+----+-----------+-----+
#| ID|Time|   targetIP|count|
#+---+----+-----------+-----+
#|  1|   1|192.163.0.1|    0|
#|  2|   2|192.163.0.2|    0|
#|  3|   3|192.163.0.1|    1|
#|  4|   5|192.163.0.1|    1|
#|  5|   6|192.163.0.2|    0|
#|  6|   7|192.163.0.2|    1|
#|  7|   8|192.163.0.2|    2|
#+---+----+-----------+-----+

Or using DataFrame API:

from pyspark.sql import Window
from pyspark.sql import functions as F

# placeholder for converting the interval into the Time column's unit (identity here)
time_unit = lambda x: x

w = Window.partitionBy("targetIP").orderBy(F.col("Time").cast("int")).rangeBetween(-time_unit(2), 0)

df1 = df.withColumn("count", F.count("*").over(w) - 1).orderBy("ID")

df1.show()

Upvotes: 1
