Harelz

Reputation: 162

Clip a Spark DataFrame's columns by their 95% and 5% quantiles

I am trying to make a custom transformer for my model using PySpark & Spark 2.2.

I want to take a DataFrame and saturate each column: values above the column's 95th percentile are set to the 95th percentile, and values below the 5th percentile are set to the 5th percentile (like clipping at the median, but at 0.95 and 0.05).

For example, I want this DataFrame:

    col_0  col_1
0      1     11
1      2     12
2      3     13
3      4     14
4      5     15
...............
...............
95    96    106
96    97    107
97    98    108
98    99    109
99   100    110

to become this DataFrame, with all values in the other rows remaining the same:

    col_0  col_1
0      5     15
1      5     15
2      5     15
3      5     15
4      5     15
...............
...............
95    96    106
96    96    106
97    96    106
98    96    106
99    96    106

Pandas DataFrames have this kind of function, clip. However, I want to do this on a DataFrame with possibly hundreds of columns and millions of rows, and do it for each column as efficiently as possible.
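For reference, the pandas behaviour the question is after can be sketched with DataFrame.clip and per-column quantiles (a small local example of the desired result, not the Spark solution):

```python
import pandas as pd

# Reproduce the example columns: col_0 runs 1..100, col_1 runs 11..110
df = pd.DataFrame({"col_0": range(1, 101), "col_1": range(11, 111)})

# Per-column 5% and 95% quantiles (pandas interpolates by default,
# so the bounds may fall between two observed values)
lower = df.quantile(0.05)
upper = df.quantile(0.95)

# Clip each column to its own bounds; axis=1 aligns the quantile
# Series (indexed by column name) with the DataFrame's columns
clipped = df.clip(lower=lower, upper=upper, axis=1)
```

Mid-range values pass through unchanged; only values outside the per-column bounds are saturated.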

Thank you very much!

Upvotes: 0

Views: 5570

Answers (1)

user10938362

Reputation: 4151

You can easily implement your own version using the approxQuantile method:

from pyspark.sql.functions import col, when

def clip(df, cols, lower=0.05, upper=0.95, relativeError=0.001):
    if not isinstance(cols, (list, tuple)):
        cols = [cols]
    # Compute [lower-quantile, upper-quantile] for every column in one pass
    bounds = df.stat.approxQuantile(cols, [lower, upper], relativeError)
    # Map each column name to an expression clipped to its quantile bounds
    clipped = {
        c: (when(col(c) < lo, lo)        # Below lower quantile -> saturate
                .when(col(c) > hi, hi)   # Above upper quantile -> saturate
                .otherwise(col(c))       # Between quantiles -> keep as-is
                .alias(c))
        for c, (lo, hi) in zip(cols, bounds)
    }

    return df.select([clipped.get(c, col(c)) for c in df.columns])

With example data:

from pyspark.sql.functions import randn

df = spark.range(10000).select(
    "id", 
    *[(randn(i) + i).alias(f"x{i}") for i in range(3)]
)

df.describe().show()
+-------+------------------+--------------------+------------------+------------------+
|summary|                id|                  x0|                x1|                x2|
+-------+------------------+--------------------+------------------+------------------+
|  count|             10000|               10000|             10000|             10000|
|   mean|            4999.5|-0.01196896108307...| 0.987731544718039| 1.990886090228083|
| stddev|2886.8956799071675|  1.0034484456898938|1.0008778389552515|1.0040412784708024|
|    min|                 0|  -3.788858342200328|-2.788858342200328|-1.788858342200328|
|    max|              9999|   4.298596405374866| 5.298596405374866| 6.298596405374866|
+-------+------------------+--------------------+------------------+------------------+

It can be used as shown below:

clipped_df = clip(df, df.columns[1:])

clipped_df.describe().show()
+-------+------------------+--------------------+-------------------+-------------------+
|summary|                id|                  x0|                 x1|                 x2|
+-------+------------------+--------------------+-------------------+-------------------+
|  count|             10000|               10000|              10000|              10000|
|   mean|            4999.5|-0.01482931519040...| 0.9857932036239195| 1.9886389327313188|
| stddev|2886.8956799071675|  0.9116409420899081| 0.9093740705137618| 0.9138394144215538|
|    min|                 0| -1.6570053933784237|-0.6408919025527111|0.35174956348817354|
|    max|              9999|  1.6268109900572905| 2.6333091349111575| 3.6369017007810296|
+-------+------------------+--------------------+-------------------+-------------------+
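The describe output shows the expected effect of clipping: the count is unchanged, the standard deviation shrinks, and min/max pull in to roughly the 5th and 95th percentiles (about mean ± 1.645·σ for normal data). The same effect can be checked with NumPy alone (a standalone sketch of the clipping behaviour, not part of the Spark code above):

```python
import numpy as np

rng = np.random.default_rng(42)
x = rng.standard_normal(10_000)

# Exact 5%/95% quantiles (Spark's approxQuantile approximates these)
lo, hi = np.quantile(x, [0.05, 0.95])
clipped = np.clip(x, lo, hi)

# Count unchanged, spread reduced, extremes saturated at the quantiles
# (for a standard normal, lo/hi land near -1.645 and +1.645)
```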

Upvotes: 2
