Reputation: 307
I have the following test data and need to check the statement below with the help of PySpark (the real data is very large: 700,000 transactions, each with 10+ products):
import pandas as pd
import datetime

data = {'date': ['2014-01-01', '2014-01-02', '2014-01-03', '2014-01-04', '2014-01-05', '2014-01-06'],
        'customerid': [1, 2, 2, 3, 4, 3],
        'productids': ['A;B', 'D;E', 'H;X', 'P;Q;G', 'S;T;U', 'C;G']}
data = pd.DataFrame(data)
data['date'] = pd.to_datetime(data['date'])
"The transactions that exist for a customer ID within x days are characterized by at least one identical product in the shopping cart."
So far I have the following approach (example x = 2):
from pyspark.sql import SparkSession, Window
import pyspark.sql.functions as F

spark = SparkSession.builder \
    .master('local[*]') \
    .config("spark.driver.memory", "500g") \
    .appName('my-pandasToSparkDF-app') \
    .getOrCreate()
spark.conf.set("spark.sql.execution.arrow.enabled", "true")
spark.sparkContext.setLogLevel("OFF")

df = spark.createDataFrame(data)

x = 2
win = Window().partitionBy('customerid').orderBy(F.col("date").cast("long")).rangeBetween(-(86400 * x), Window.currentRow)

test = df.withColumn("productids", F.array_distinct(F.split("productids", ";"))) \
    .withColumn("flat_col", F.array_distinct(F.flatten(F.collect_list("productids").over(win)))) \
    .orderBy(F.col("date"))
test = test.toPandas()
So from every transaction we look 2 days into the past, group by customerid, and collect the corresponding products in the "flat_col" column.
But what I actually need is the intersection of the shopping baskets for the same customer ID; only then can I judge whether there are common products.
So instead of ['P', 'Q', 'G', 'C'] in the last row of "flat_col" there should only be ['G'], and [] should appear in all the other rows of "flat_col".
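To make the expected output concrete, this is what I would like to end up with for x = 2, written out by hand in pandas (same column names as above):
# expected result for x = 2, spelled out manually for the toy data above
expected = pd.DataFrame({
    'date': pd.to_datetime(['2014-01-01', '2014-01-02', '2014-01-03',
                            '2014-01-04', '2014-01-05', '2014-01-06']),
    'customerid': [1, 2, 2, 3, 4, 3],
    'productids': [['A', 'B'], ['D', 'E'], ['H', 'X'], ['P', 'Q', 'G'], ['S', 'T', 'U'], ['C', 'G']],
    'flat_col': [[], [], [], [], [], ['G']]})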
Thank you so much!
Upvotes: 0
Views: 870
Reputation: 8410
You can achieve this without a self-join (joins are expensive shuffle operations on big data) by using the higher-order functions introduced in Spark 2.4: filter, transform and aggregate.
df = spark.createDataFrame(data)
x = 2
win = Window().partitionBy('customerid').orderBy(F.col("date").cast("long")).rangeBetween(-(86400 * x), Window.currentRow)

# For every product in the current basket, count how often it appears in the
# flattened window contents ("flat_col") and keep it only if it appears more than once.
test = df.withColumn("productids", F.array_distinct(F.split("productids", ";"))) \
    .withColumn("flat_col", F.flatten(F.collect_list("productids").over(win))) \
    .withColumn("occurrences", F.expr("""filter(transform(productids, x ->
        IF(aggregate(flat_col, 0, (acc, t) -> acc + IF(t = x, 1, 0)) > 1, x, null)),
        y -> y is not null)""")) \
    .drop("flat_col") \
    .orderBy("date")
test.show()
+-------------------+----------+----------+-----------+
|               date|customerid|productids|occurrences|
+-------------------+----------+----------+-----------+
|2014-01-01 00:00:00|         1|    [A, B]|         []|
|2014-01-02 00:00:00|         2|    [D, E]|         []|
|2014-01-03 00:00:00|         2|    [H, X]|         []|
|2014-01-04 00:00:00|         3| [P, Q, G]|         []|
|2014-01-05 00:00:00|         4| [S, T, U]|         []|
|2014-01-06 00:00:00|         3|    [C, G]|        [G]|
+-------------------+----------+----------+-----------+
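If all that is needed afterwards is a yes/no flag per transaction, a small follow-up sketch on top of the result above (shares_product is just an illustrative name) could be:
# flag transactions that share at least one product with another transaction
# of the same customer within the window
test.withColumn("shares_product", F.size("occurrences") > 0).orderBy("date").show()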
Upvotes: 1
Reputation: 212
A self-join is the best trick ever:
from pyspark.sql.functions import concat_ws, collect_list

spark.createDataFrame(data).createOrReplaceTempView("df")
# one row per (date, customerid, productid)
spark.sql("SELECT date, customerid, explode(split(productids, ';')) AS productid FROM df").createOrReplaceTempView("altered")
# self-join: same customer, same product, a different transaction, at most 2 days apart
df = spark.sql("""SELECT al.date, al.customerid, al.productid AS productids, altr.productid AS flat_col
                  FROM altered al LEFT JOIN altered altr ON altr.customerid = al.customerid
                   AND al.productid = altr.productid AND al.date != altr.date
                   AND datediff(al.date, altr.date) <= 2 AND datediff(al.date, altr.date) >= -2""")
result = df.groupBy("date", "customerid").agg(concat_ws(",", collect_list("productids")).alias('productids'),
                                              concat_ws(",", collect_list("flat_col")).alias('flat_col'))
result.show()
Upvotes: 0