Dominik

Reputation: 307

PySpark: Intersection of multiple arrays

I have the following test data and need to check the statement below with the help of PySpark (the real data is very large: 700,000 transactions, each with 10+ products):

import pandas as pd
import datetime

data = {'date': ['2014-01-01', '2014-01-02', '2014-01-03', '2014-01-04', '2014-01-05', '2014-01-06'],
        'customerid': [1, 2, 2, 3, 4, 3],
        'productids': ['A;B', 'D;E', 'H;X', 'P;Q;G', 'S;T;U', 'C;G']}
data = pd.DataFrame(data)
data['date'] = pd.to_datetime(data['date'])

"The transactions that exist for a customer ID within x days are characterized by at least one identical product in the shopping cart."

So far I have the following approach (example x = 2):

from pyspark.sql import SparkSession, Window
from pyspark.sql import functions as F

spark = SparkSession.builder \
    .master('local[*]') \
    .config("spark.driver.memory", "500g") \
    .appName('my-pandasToSparkDF-app') \
    .getOrCreate()
spark.conf.set("spark.sql.execution.arrow.enabled", "true")
spark.sparkContext.setLogLevel("OFF")

df=spark.createDataFrame(data)

x = 2

# window: for each customer, the current transaction plus anything up to x days back
win = Window().partitionBy('customerid').orderBy(F.col("date").cast("long")).rangeBetween(-(86400 * x), Window.currentRow)
# split the product string into an array and collect every product seen in the window
test = df.withColumn("productids", F.array_distinct(F.split("productids", r"\;")))\
    .withColumn("flat_col", F.array_distinct(F.flatten(F.collect_list("productids").over(win)))).orderBy(F.col("date"))

test = test.toPandas()

(screenshot of the resulting DataFrame showing the "flat_col" column)

So, for every transaction, we look two days into the past within the same customerid, and the corresponding products are collected in the "flat_col" column.

But what I actually need is the intersection of the shopping baskets that belong to the same customer ID. Only then can I judge whether there are common products.

So instead of ['P', 'Q', 'G', 'C'], the "flat_col" entry for the last row (customer 3, 2014-01-06) should be ['G'], and [] should appear in all other rows of "flat_col".
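For illustration, here is the result I am after for x = 2, written out by hand (expected is just an illustrative name, not part of my code):

expected = pd.DataFrame({
    'date': data['date'],
    'customerid': data['customerid'],
    'flat_col': [[], [], [], [], [], ['G']]
})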

Thank you so much!

Upvotes: 0

Views: 870

Answers (2)

murtihash

Reputation: 8410

You can achieve this without a self-join (joins are expensive shuffle operations on big data) by using the higher-order functions available in Spark 2.4: filter, transform, and aggregate.

df = spark.createDataFrame(data)

x = 2

# same window as in the question: the current transaction plus anything up to x days back
win = Window().partitionBy('customerid').orderBy(F.col("date").cast("long")).rangeBetween(-(86400 * x), Window.currentRow)

# keep a product only if it occurs more than once in the window, i.e. it also appears in another basket
df.withColumn("productids", F.array_distinct(F.split("productids", r"\;")))\
    .withColumn("flat_col", F.flatten(F.collect_list("productids").over(win)))\
    .withColumn("occurrences", F.expr("""filter(transform(productids, x ->
         IF(aggregate(flat_col, 0, (acc, t) -> acc + IF(t = x, 1, 0)) > 1, x, null)), y -> y IS NOT NULL)"""))\
    .drop("flat_col").orderBy("date").show()

+-------------------+----------+----------+-----------+
|               date|customerid|productids|occurrences|
+-------------------+----------+----------+-----------+
|2014-01-01 00:00:00|         1|    [A, B]|         []|
|2014-01-02 00:00:00|         2|    [D, E]|         []|
|2014-01-03 00:00:00|         2|    [H, X]|         []|
|2014-01-04 00:00:00|         3| [P, Q, G]|         []|
|2014-01-05 00:00:00|         4| [S, T, U]|         []|
|2014-01-06 00:00:00|         3|    [C, G]|        [G]|
+-------------------+----------+----------+-----------+
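If the nested expression is hard to read, here is a minimal, self-contained sketch of the same filter/transform/aggregate pattern on made-up literal arrays: array('A', 'B') plays the role of productids (the current basket) and array('A', 'B', 'B') the role of flat_col (the current basket plus a previous basket containing 'B').

spark.sql("""
    SELECT filter(
             transform(array('A', 'B'),
                       x -> IF(aggregate(array('A', 'B', 'B'), 0,
                                         (acc, t) -> acc + IF(t = x, 1, 0)) > 1,
                               x, null)),
             y -> y IS NOT NULL) AS kept
""").show()
# kept = [B]: only 'B' occurs more than once across the current basket and the window

transform maps each product of the current basket to itself if aggregate counts it more than once in the flattened window, otherwise to null, and filter then drops the nulls.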

Upvotes: 1

Sergio Alyoshkin

Reputation: 212

A self-join is the best trick ever.

from pyspark.sql.functions import concat_ws, collect_list
spark.createDataFrame(data).createOrReplaceTempView("df")
spark.sql("SELECT date, customerid, explode(split(productids, ';')) AS productid FROM df").createOrReplaceTempView("altered")
# self-join: same customer, same product, a different transaction, at most 2 days apart
df = spark.sql("""SELECT al.date, al.customerid, al.productid AS productids, altr.productid AS flat_col
                  FROM altered al LEFT JOIN altered altr
                    ON altr.customerid = al.customerid AND al.productid = altr.productid
                   AND al.date != altr.date AND datediff(al.date, altr.date) BETWEEN -2 AND 2""")
df.groupBy("date", "customerid").agg(concat_ws(",", collect_list("productids")).alias('productids'),
                                     concat_ws(",", collect_list("flat_col")).alias('flat_col')).show()

(screenshot of the Spark output)
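For readers more comfortable with the DataFrame API, the same self-join can be sketched roughly as follows (an untested outline; exploded, al and altr are just illustrative names):

from pyspark.sql import functions as F

# one row per (date, customerid, productid)
exploded = (spark.createDataFrame(data)
            .withColumn("productid", F.explode(F.split("productids", ";")))
            .select("date", "customerid", "productid"))

al, altr = exploded.alias("al"), exploded.alias("altr")

# same customer, same product, a different transaction, at most 2 days apart
joined = al.join(
    altr,
    (F.col("al.customerid") == F.col("altr.customerid")) &
    (F.col("al.productid") == F.col("altr.productid")) &
    (F.col("al.date") != F.col("altr.date")) &
    (F.datediff("al.date", "altr.date").between(-2, 2)),
    "left")

(joined.groupBy("al.date", "al.customerid")
    .agg(F.concat_ws(",", F.collect_list("al.productid")).alias("productids"),
         F.concat_ws(",", F.collect_list("altr.productid")).alias("flat_col"))
    .show())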

Upvotes: 0
