Reputation: 2663
I've been working on this in PySpark for a while and I'm stuck. I'm trying to get the median of the column numbers within its respective window. I need to do this without the use of other libraries such as numpy, etc.
So far (as shown below), I've grouped the dataset into windows by the column id. This is reflected by the column row_number, which shows what each window looks like. There are three windows in this DataFrame example.
This is what I would like:
I want each row to also contain the median of its window (partitioned by the column id) without taking its own row into consideration. The location of the median I require is computed by my function below, called median_loc.
Example: For row_number = 5, I need to find the median of rows 1 to 4 above it (i.e. not including row_number 5). The median (per my requirement) is therefore the average of the column numbers in the same window where row_number = 1 and row_number = 2, i.e.:
Date id numbers row_number med_loc
2017-03-02 group 1 98 1 [1]
2017-04-01 group 1 50 2 [1]
2018-03-02 group 1 5 3 [1, 2]
2016-03-01 group 2 49 1 [1]
2016-12-22 group 2 81 2 [1]
2017-12-31 group 2 91 3 [1, 2]
2018-08-08 group 2 19 4 [2]
2018-09-25 group 2 52 5 [1, 2]
2017-01-01 group 3 75 1 [1]
2018-12-12 group 3 17 2 [1]
The code I used to get the last column, med_loc, is as follows:
import pyspark.sql.functions as F

def median_loc(sz):
    # sz = the number of rows above the current row in its window
    if sz == 1 or sz == 0:
        kth = [1]
        return kth
    elif sz % 2 == 0 and sz > 1:
        szh = sz // 2
        kth = [szh - 1, szh] if szh != 1 else [1, 2]
        return kth
    elif sz % 2 != 0 and sz > 1:
        kth = [(sz + 1) // 2]
        return kth

sqlContext.udf.register("median_location", median_loc)
median_loc_udf = F.udf(median_loc)
df = df.withColumn("med_loc", median_loc_udf(df.row_number - 1))
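For reference, calling the plain Python median_loc function with sz = row_number - 1 reproduces the med_loc values shown above for the five rows of group 2 (just a quick sanity check in plain Python, not part of the Spark job):
# row_number runs from 1 to 5 for group 2; sz is the number of rows above it
print([median_loc(rn - 1) for rn in range(1, 6)])
# [[1], [1], [1, 2], [2], [1, 2]]  -> matches the med_loc column for group 2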
Note: I only made them look like a list for easier understanding; it just shows where the median is located in the respective window, for the benefit of folks reading this on Stack Overflow.
The output that I desire is as follows:
Date id numbers row_number med_loc median
2017-03-02 group 1 98 1 [1] 98
2017-04-01 group 1 50 2 [1] 98
2018-03-02 group 1 5 3 [1, 2] 74
2016-03-01 group 2 49 1 [1] 49
2016-12-22 group 2 81 2 [1] 49
2017-12-31 group 2 91 3 [1, 2] 65
2018-08-08 group 2 19 4 [2] 81
2018-09-25 group 2 52 5 [1, 2] 65
2017-01-01 group 3 75 1 [1] 75
2018-12-12 group 3 17 2 [1] 75
Basically, the way to get the median so far is something like this (a plain-Python sketch of the rule follows below):
If med_loc has one element (i.e. the list is something like [1] or [3], etc.), then median = df.numbers where df.row_number = df.med_loc
If med_loc has two elements (i.e. the list is something like [1, 2] or [2, 3], etc.), then median = average(df.numbers) where df.row_number in df.med_loc
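To make that rule concrete, here is a small plain-Python sketch of what I'm after (nothing Spark-specific; the helper name median_of_previous is just for illustration), reproducing the desired median column from group 2's numbers:
def median_of_previous(window_numbers, row_number):
    # values of the rows above the current one, sorted so the middle
    # element(s) are the median candidates; row 1 falls back to its own value
    previous = sorted(window_numbers[:row_number - 1]) or [window_numbers[0]]
    sz = len(previous)
    if sz % 2 == 1:
        return float(previous[sz // 2])
    return (previous[sz // 2 - 1] + previous[sz // 2]) / 2.0

group2 = [49, 81, 91, 19, 52]  # numbers for group 2 in window order
print([median_of_previous(group2, rn) for rn in range(1, 6)])
# [49.0, 49.0, 65.0, 81.0, 65.0] -> matches the desired median column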
I cannot stress enough how important it is for me not to use other libraries such as numpy, etc. to get the output. There are other solutions I looked at that used np.median and they work; however, that is not my requirement at this time.
I'm sorry if this explanation is long-winded and if I'm overcomplicating it. I've been looking at this for days and can't seem to figure it out. I also tried to use the percent_rank function, but I'm not able to make it work because not all windows contain the 0.5 percentile.
Any help will be appreciated.
Upvotes: 0
Views: 618
Reputation: 43504
Suppose you start with the following DataFrame, df:
+----------+-------+-------+
| Date| id|numbers|
+----------+-------+-------+
|2017-03-02|group 1| 98|
|2017-04-01|group 1| 50|
|2018-03-02|group 1| 5|
|2016-03-01|group 2| 49|
|2016-12-22|group 2| 81|
|2017-12-31|group 2| 91|
|2018-08-08|group 2| 19|
|2018-09-25|group 2| 52|
|2017-01-01|group 3| 75|
|2018-12-12|group 3| 17|
+----------+-------+-------+
First add the row_number as you did in your example and assign the output to a new DataFrame, df2:
import pyspark.sql.functions as f
from pyspark.sql import Window
df2 = df.select(
    "*", f.row_number().over(Window.partitionBy("id").orderBy("Date")).alias("row_number")
)
df2.show()
#+----------+-------+-------+----------+
#| Date| id|numbers|row_number|
#+----------+-------+-------+----------+
#|2017-03-02|group 1| 98| 1|
#|2017-04-01|group 1| 50| 2|
#|2018-03-02|group 1| 5| 3|
#|2016-03-01|group 2| 49| 1|
#|2016-12-22|group 2| 81| 2|
#|2017-12-31|group 2| 91| 3|
#|2018-08-08|group 2| 19| 4|
#|2018-09-25|group 2| 52| 5|
#|2017-01-01|group 3| 75| 1|
#|2018-12-12|group 3| 17| 2|
#+----------+-------+-------+----------+
Now you can join df2 to itself on the id column, with the condition that the left side's row_number is 1 or is greater than the right side's row_number. Then group by the left DataFrame's ("id", "Date", "row_number") and collect the numbers from the right DataFrame into a list.
For the case when the row_number is equal to 1, we only want to keep the first element of this collected list. Otherwise keep all the numbers, but sort them, because we need them ordered to calculate the median.
Call this intermediate DataFrame df3:
df3 = df2.alias("l").join(df2.alias("r"), on="id", how="left")\
    .where("l.row_number = 1 OR (r.row_number < l.row_number)")\
    .groupBy("l.id", "l.Date", "l.row_number")\
    .agg(f.collect_list("r.numbers").alias("numbers"))\
    .select(
        "id",
        "Date",
        "row_number",
        f.when(
            f.col("row_number") == 1,
            f.array([f.col("numbers").getItem(0)])
        ).otherwise(f.sort_array("numbers")).alias("numbers")
    )
df3.show()
#+-------+----------+----------+----------------+
#| id| Date|row_number| numbers|
#+-------+----------+----------+----------------+
#|group 1|2017-03-02| 1| [98]|
#|group 1|2017-04-01| 2| [98]|
#|group 1|2018-03-02| 3| [50, 98]|
#|group 2|2016-03-01| 1| [49]|
#|group 2|2016-12-22| 2| [49]|
#|group 2|2017-12-31| 3| [49, 81]|
#|group 2|2018-08-08| 4| [49, 81, 91]|
#|group 2|2018-09-25| 5|[19, 49, 81, 91]|
#|group 3|2017-01-01| 1| [75]|
#|group 3|2018-12-12| 2| [75]|
#+-------+----------+----------+----------------+
Notice that the numbers column of df3 holds a list of the appropriate values for which we want to find the median.
Since your version of Spark is greater than 2.1, you can use pyspark.sql.functions.posexplode() to compute the median from this list of values. For lower versions of Spark, you would need to use a udf (a minimal sketch of one is given at the end of this answer).
First create 2 helper columns in df3:
- isEven: a Boolean to indicate whether the numbers array has an even number of elements
- middle: the index of the middle of the array, which is the floor of the length / 2

After these columns are created, explode the array using posexplode(), which will return two new columns: pos and col. We then filter the resulting DataFrame to keep only the positions that we need to compute the median.
The logic on which positions to keep is as follows:
- If isEven is False, we only keep the middle position
- If isEven is True, we keep the middle position and the middle position - 1

Finally, group by the id and Date and average the remaining numbers.
df3.select(
    "*",
    f.when(
        (f.size("numbers") % 2) == 0,
        f.lit(True)
    ).otherwise(f.lit(False)).alias("isEven"),
    f.floor(f.size("numbers")/2).alias("middle")
).select(
    "id",
    "Date",
    "middle",
    "isEven",
    f.posexplode("numbers")
).where(
    "(isEven=False AND middle=pos) OR (isEven=True AND pos BETWEEN middle-1 AND middle)"
).groupby("id", "Date").agg(f.avg("col").alias("median")).show()
#+-------+----------+------+
#| id| Date|median|
#+-------+----------+------+
#|group 1|2017-03-02| 98.0|
#|group 1|2017-04-01| 98.0|
#|group 1|2018-03-02| 74.0|
#|group 2|2016-03-01| 49.0|
#|group 2|2016-12-22| 49.0|
#|group 2|2017-12-31| 65.0|
#|group 2|2018-08-08| 81.0|
#|group 2|2018-09-25| 65.0|
#|group 3|2017-01-01| 75.0|
#|group 3|2018-12-12| 75.0|
#+-------+----------+------+
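As an aside, if you were on a Spark version without posexplode(), one option for the last step would be a plain Python udf over the sorted numbers array in df3 (a minimal sketch under that assumption; array_median is a made-up name, and it still avoids numpy):
from pyspark.sql.types import DoubleType

def array_median(xs):
    # xs is the (already sorted, non-empty) numbers array built in df3
    n = len(xs)
    if n % 2 == 1:
        return float(xs[n // 2])
    return (xs[n // 2 - 1] + xs[n // 2]) / 2.0

array_median_udf = f.udf(array_median, DoubleType())
df3.select("id", "Date", array_median_udf("numbers").alias("median")).show()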
Upvotes: 2