Hello StackOverflowers.
I have a PySpark DataFrame that consists of a date column (`snapshot`) and a column of values, e.g.:
+----------+--------------------+
| snapshot| values|
+----------+--------------------+
|2005-01-31| 0.19120256617637743|
|2005-01-31| 0.7972692479278891|
|2005-02-28|0.005236883665445502|
|2005-02-28| 0.5474099672222935|
|2005-02-28| 0.13077227571485905|
+----------+--------------------+
I would like to run a two-sample Kolmogorov-Smirnov (KS) test that compares the values of each snapshot with those of the previous snapshot.
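For context, the statistic that `ks_2samp` returns as `ks_value` is the largest vertical gap between the two samples' empirical CDFs: D = max over x of |F1(x) - F2(x)|. The stdlib-only sketch below computes D directly. `ks_statistic` is a hypothetical helper, not part of SciPy or PySpark, and it does not compute a p-value. It shows that each comparison needs nothing except the values of two snapshots, which is what makes the work parallelizable per pair of snapshots.

```python
from bisect import bisect_right

def ks_statistic(sample1, sample2):
    """Two-sample KS statistic D = max over x of |F1(x) - F2(x)| (hypothetical helper)."""
    if not sample1 or not sample2:
        raise ValueError("both samples must be non-empty")
    s1, s2 = sorted(sample1), sorted(sample2)
    n1, n2 = len(s1), len(s2)
    d = 0.0
    # Both empirical CDFs are step functions that only jump at observed values,
    # so evaluating them at every pooled value is enough to find the maximum gap.
    for x in s1 + s2:
        f1 = bisect_right(s1, x) / n1  # fraction of sample1 that is <= x
        f2 = bisect_right(s2, x) / n2  # fraction of sample2 that is <= x
        d = max(d, abs(f1 - f2))
    return d

# The two snapshots from the sample table above
jan = [0.19120256617637743, 0.7972692479278891]
feb = [0.005236883665445502, 0.5474099672222935, 0.13077227571485905]
print(ks_statistic(jan, feb))  # 0.6666666666666666
```

Here D = 2/3 because at x ≈ 0.131 the January ECDF is still 0 while the February ECDF has already reached 2/3. SciPy's `ks_2samp` evaluates the same right-continuous ECDFs, so its statistic should agree.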
I tried to do it with a for loop.
```python
import numpy as np
from scipy.stats import ks_2samp
import pyspark.sql.functions as F

def KS_for_one_snapshot(temp_df, snapshots_list, j, var="values"):
    # Values of snapshot j and of the previous snapshot (j - 1) to compare with
    sample1 = temp_df.filter(F.col("snapshot") == snapshots_list[j])
    sample2 = temp_df.filter(F.col("snapshot") == snapshots_list[j - 1])
    if sample1.count() == 0 or sample2.count() == 0:
        ks_value = -1  # previously "0 observations", which gave a type error
    else:
        # collect() brings each snapshot's values to the driver as a flat NumPy array
        ks_value, p_value = ks_2samp(
            np.array(sample1.select(var).collect()).reshape(-1),
            np.array(sample2.select(var).collect()).reshape(-1),
            alternative="two-sided",  # method defaults to "auto"
        )
    return ks_value

results = []
# Distinct snapshot dates in ascending order, collected to the driver
snapshots_list = df.select("snapshot").dropDuplicates().sort("snapshot").rdd.flatMap(lambda x: x).collect()
for j in range(len(snapshots_list) - 1):
    results.append(KS_for_one_snapshot(df, snapshots_list, j + 1))
results
```
But the real data is huge, so this takes forever. I am using Databricks and PySpark, so I wonder whether there is a more efficient way to run it that avoids the for loop and makes use of the available workers.
I also tried to do it with a UDF, but in vain.
Any ideas?
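One way to remove the driver-side loop is to reshape the data so that every KS comparison becomes an independent group: each row is sent once to the pair (previous snapshot, its snapshot) and once to the pair (its snapshot, next snapshot), and the test then runs once per pair. The stdlib-only sketch below simulates that shape on plain Python tuples. `consecutive_ks` and `ks_statistic` are hypothetical helpers, and the KS statistic is computed by hand instead of with `ks_2samp` so the example stays self-contained.

```python
from bisect import bisect_right
from collections import defaultdict

def ks_statistic(sample1, sample2):
    """Two-sample KS statistic: max |F1(x) - F2(x)| over the pooled values."""
    s1, s2 = sorted(sample1), sorted(sample2)
    return max(abs(bisect_right(s1, x) / len(s1) - bisect_right(s2, x) / len(s2))
               for x in s1 + s2)

def consecutive_ks(rows):
    """rows: iterable of (snapshot, value) pairs.
    Returns {(previous_snapshot, snapshot): KS statistic}."""
    rows = list(rows)
    # Distinct snapshots in order (the dropDuplicates().sort() step of the loop version)
    snapshots = sorted({snapshot for snapshot, _ in rows})
    prev_of = dict(zip(snapshots[1:], snapshots))  # snapshot -> previous snapshot
    next_of = dict(zip(snapshots, snapshots[1:]))  # snapshot -> next snapshot
    # "Explode" step: every row goes to up to two comparison groups,
    # as the later sample of (prev, s) and as the earlier sample of (s, next).
    groups = defaultdict(lambda: ([], []))  # pair -> (earlier values, later values)
    for snapshot, value in rows:
        if snapshot in prev_of:
            groups[(prev_of[snapshot], snapshot)][1].append(value)
        if snapshot in next_of:
            groups[(snapshot, next_of[snapshot])][0].append(value)
    # "Per-group" step: pairs are independent, so each could be handled by a different worker
    return {pair: ks_statistic(earlier, later)
            for pair, (earlier, later) in sorted(groups.items())}

rows = [("2005-01-31", 1.0), ("2005-01-31", 2.0),
        ("2005-02-28", 3.0), ("2005-02-28", 4.0),
        ("2005-03-30", 1.5), ("2005-03-30", 3.5)]
print(consecutive_ks(rows))
# {('2005-01-31', '2005-02-28'): 1.0, ('2005-02-28', '2005-03-30'): 0.5}
```

In Spark, the same shape could plausibly be built by computing each snapshot's previous and next date with `F.lag`/`F.lead` over the (few) distinct snapshots, joining that back to the data, using `F.explode` to duplicate each row into its two pair keys, and running `ks_2samp` per pair with `groupBy(...).applyInPandas(...)`. This is an untested outline. `applyInPandas` loads a whole group into one worker's memory, so two consecutive snapshots together must fit there.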
P.S. You can generate sample data with the following code:
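```python
import pyspark.sql.functions as F
import pyspark.sql.types as T
# 999 rows, each randomly assigned to one of three snapshot dates (array index 0, 1 or 2)
dates = F.array(F.lit("2005-01-31"), F.lit("2005-02-28"), F.lit("2005-03-30"))
df = (spark.createDataFrame(range(1, 1000), T.IntegerType())
      .withColumn("snapshot", dates[(F.rand() * 3).cast("int")])
      .withColumn("values", F.rand())
      .drop("value"))
```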
Update:
I tried the following approach, using a UDF:
```python
import pyspark.sql.functions as F
import pyspark.sql.types as T
from pyspark.sql.window import Window
from scipy.stats import ks_2samp

var_used = "values"
# One row per snapshot, holding all of its values as an array
data_input_1 = df.groupBy("snapshot").agg(F.collect_list(var_used).alias("value_list"))
data_input_2 = df.groupBy("snapshot").agg(F.collect_list(var_used).alias("value_list_2"))
# snapshot_2 = the previous snapshot; the first snapshot has none and is dropped
data_input_2 = (data_input_2
                .withColumn("snapshot_2", F.lag("snapshot", 1).over(Window.orderBy("snapshot")))
                .filter("snapshot_2 is not NULL"))
# Pair each snapshot's values (value_list) with the next snapshot's values (value_list_2)
data_input_final = data_input_1.join(data_input_2, data_input_1.snapshot == data_input_2.snapshot_2)

def KS_one_snapshot_general(sample_in_list_1, sample_in_list_2):
    if len(sample_in_list_1) == 0 or len(sample_in_list_2) == 0:
        ks_value = -1  # previously "0 observations", which gave a type error
    else:
        ks_value, p_value = ks_2samp(sample_in_list_1, sample_in_list_2,
                                     alternative="two-sided")  # method defaults to "auto"
    return ks_value

KS_one_snapshot_general_udf = F.udf(KS_one_snapshot_general, T.FloatType())
data_input_final.select(KS_one_snapshot_general_udf("value_list", "value_list_2")).display()
```
This works fine when each snapshot has only a few rows, but if I increase the number of rows I get the following error:
`PickleException: expected zero arguments for construction of ClassDict (for numpy.dtype)`
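A likely cause, hedged because it is based on how PySpark serializes Python UDF results rather than verified against this exact setup: `ks_2samp` returns NumPy scalars (`numpy.float64`), and PySpark cannot rebuild NumPy objects such as `numpy.dtype` when the UDF result is unpickled on the JVM side. A common fix is to return a built-in Python `float` from the UDF. The sketch below is the question's function with that change. The -1 sentinel is also returned as a float so that it matches the UDF's declared float return type.

```python
from scipy.stats import ks_2samp

def KS_one_snapshot_general(sample_in_list_1, sample_in_list_2):
    # Empty input: return the -1 sentinel as a float to match the UDF's float return type
    if len(sample_in_list_1) == 0 or len(sample_in_list_2) == 0:
        return -1.0
    result = ks_2samp(sample_in_list_1, sample_in_list_2, alternative="two-sided")
    # result.statistic is a NumPy scalar (numpy.float64). float() turns it into a
    # built-in Python float, which PySpark can pass back to the JVM without
    # having to reconstruct NumPy objects there.
    return float(result.statistic)

print(KS_one_snapshot_general([1.0, 2.0, 3.0], [4.0, 5.0, 6.0]))  # 1.0
```

It could then be wrapped as before, e.g. `F.udf(KS_one_snapshot_general, T.DoubleType())`. `DoubleType` matches the 64-bit Python float, whereas `FloatType` narrows the result to 32 bits.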