M80

Reputation: 994

Get the first not-null value in a group

In Spark SQL, how do I get the first non-null value (or the first value that is not placeholder text such as 'N/A') in a group? In the example below a user is watching a TV channel. The first records belong to channel 100; its first record has a SIGNAL_STRENGHT of N/A, whereas the next record has the value Good, so I want to use that value instead.

I tried window functions, but they only seem to offer aggregates like MAX, MIN, etc.

If I use LEAD I only get the next row, and with an unbounded frame I don't see anything like a firstNotNull method. Please advise.
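
To make the intent concrete, the sketch below is roughly the shape of what I am after: treat N/A as null and take the first non-null value looking forward within the same channel. My real code is Java (see below); PySpark is used here only to keep the sketch short, and I am not sure the ignorenulls flag of first() really behaves this way over a forward-looking frame, which is essentially what I am asking.

    from pyspark.sql import SparkSession, functions as F
    from pyspark.sql.window import Window

    spark = SparkSession.builder.master("local[*]").getOrCreate()

    # Tiny stand-in for sample_channel.csv (same column names as the data below).
    df = spark.createDataFrame(
        [(1, 100, 0, "N/A"), (1, 100, 1, "Good"), (1, 100, 2, "Meduim"), (1, 100, 3, "N/A")],
        ["CUSTOMER_ID", "TV_CHANNEL_ID", "TIME", "SIGNAL_STRENGHT"])

    # Frame from the current row to the end of the channel's partition.
    fwd = (Window.partitionBy("CUSTOMER_ID", "TV_CHANNEL_ID")
                 .orderBy("TIME")
                 .rowsBetween(Window.currentRow, Window.unboundedFollowing))

    # Treat 'N/A' as null, then take the first non-null value in the forward frame.
    df.withColumn(
        "SIGNAL_STRENGHT_FILLED",
        F.first(F.when(F.col("SIGNAL_STRENGHT") != "N/A", F.col("SIGNAL_STRENGHT")),
                ignorenulls=True).over(fwd)
    ).show()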

Input:

CUSTOMER_ID || TV_CHANNEL_ID || TIME || SIGNAL_STRENGHT
1 || 100 || 0 || N/A
1 || 100 || 1 || Good
1 || 100 || 2 || Meduim
1 || 100 || 3 || N/A
1 || 100 || 4 || Poor
1 || 100 || 5 || Meduim
1 || 200 || 6 || N/A
1 || 200 || 7 || N/A
1 || 200 || 8 || Poor
1 || 300 || 9 || Good
1 || 300 || 10 || Good
1 || 300 || 11 || Good

Expected Output:

CUSTOMER_ID || TV_CHANNEL_ID || TIME || SIGNAL_STRENGHT
1 || 100 || 0 || Good
1 || 100 || 1 || Good
1 || 100 || 2 || Meduim
1 || 100 || 3 || Poor
1 || 100 || 4 || Poor
1 || 100 || 5 || Meduim
1 || 200 || 6 || Poor
1 || 200 || 7 || Poor
1 || 200 || 8 || Poor
1 || 300 || 9 || Good
1 || 300 || 10 || Good
1 || 300 || 11 || Good

Actual code

    package com.ganesh.test;

    import org.apache.spark.SparkContext;
    import org.apache.spark.sql.*;
    import org.apache.spark.sql.expressions.Window;
    import org.apache.spark.sql.expressions.WindowSpec;
    import org.slf4j.Logger;
    import org.slf4j.LoggerFactory;

    public class ChannelLoader {

        private static final Logger LOGGER = LoggerFactory.getLogger(ChannelLoader.class);

        public static void main(String[] args) throws AnalysisException {
            String master = "local[*]";
            //region
            SparkSession sparkSession = SparkSession
                    .builder()
                    .appName(ChannelLoader.class.getName())
                    .master(master).getOrCreate();
            SparkContext context = sparkSession.sparkContext();
            context.setLogLevel("ERROR");

            SQLContext sqlCtx = sparkSession.sqlContext();

            Dataset<Row> rawDataset = sparkSession.read()
                    .format("com.databricks.spark.csv")
                    .option("delimiter", ",")
                    .option("header", "true")
                    .load("sample_channel.csv");

            rawDataset.printSchema();

            rawDataset.createOrReplaceTempView("channelView");
            //endregion

            WindowSpec windowSpec = Window.partitionBy("CUSTOMER_ID").orderBy("TV_CHANNEL_ID");

            rawDataset = sqlCtx.sql("select * ," +
                    " ( isNan(SIGNAL_STRENGHT) over ( partition by CUSTOMER_ID, TV_CHANNEL_ID order by TIME ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING )  ) as updatedStren " +
                    " from channelView " +
                    " order by CUSTOMER_ID, TV_CHANNEL_ID, TIME "
            );

            rawDataset.show();

            sparkSession.close();

        }
    }

UPDATE

I looked at many possible approaches but had no luck, so I used brute force and got the desired result: I compute several helper columns and derive the final value from them. I first convert N/A to null so that it does not show up in collect_list (which skips nulls); the forward and backward lists plus the two row numbers then let a CASE expression pick the next non-null value, falling back to the previous one when the last row of a partition is null.

    rawDataset = sqlCtx.sql("select * " +
            " , ( collect_list(SIGNAL_STRENGTH) " +
            " over ( partition by CUSTOMER_ID, TV_CHANNEL_ID order by TIME ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING )  )" +
            " as fwdValues " +
            " , ( collect_list(SIGNAL_STRENGTH) " +
            " over ( partition by CUSTOMER_ID, TV_CHANNEL_ID order by TIME ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW )  )" +
            " as bkwdValues " +
            " , ( row_number() over ( partition by CUSTOMER_ID, TV_CHANNEL_ID order by TIME ) ) as rank_fwd " +
            " , ( row_number() over ( partition by CUSTOMER_ID, TV_CHANNEL_ID order by TIME DESC ) ) as rank_bkwd " +
            " from channelView " +
            " order by CUSTOMER_ID, TV_CHANNEL_ID, TIME "
    );
    rawDataset.show();
    rawDataset.createOrReplaceTempView("updatedChannelView");
    sqlCtx.sql("select * " +
            " , SIGNAL_STRENGTH " +
            ", ( case " +
            "   when (SIGNAL_STRENGTH IS NULL AND rank_bkwd = 1) then bkwdValues[size(bkwdValues)-1] " +
            "   when (SIGNAL_STRENGTH IS NULL ) then fwdValues[0] " +
            "   else SIGNAL_STRENGTH " +
            "  end ) as NEW_SIGNAL_STRENGTH" +
            " from updatedChannelView " +
            ""
    ).show();

Output from code

    +-----------+-------------+----+---------------+--------------------+--------------------+--------+---------+---------------+-------------------+
    |CUSTOMER_ID|TV_CHANNEL_ID|TIME|SIGNAL_STRENGTH|           fwdValues|          bkwdValues|rank_fwd|rank_bkwd|SIGNAL_STRENGTH|NEW_SIGNAL_STRENGTH|
    +-----------+-------------+----+---------------+--------------------+--------------------+--------+---------+---------------+-------------------+
    |          1|          100|   0|           null|[Good, Meduim, Poor]|                  []|       1|        6|           null|               Good|
    |          1|          100|   1|           Good|[Good, Meduim, Poor]|              [Good]|       2|        5|           Good|               Good|
    |          1|          100|   2|         Meduim|      [Meduim, Poor]|      [Good, Meduim]|       3|        4|         Meduim|             Meduim|
    |          1|          100|   3|           null|              [Poor]|      [Good, Meduim]|       4|        3|           null|               Poor|
    |          1|          100|   4|           Poor|              [Poor]|[Good, Meduim, Poor]|       5|        2|           Poor|               Poor|
    |          1|          100|   5|           null|                  []|[Good, Meduim, Poor]|       6|        1|           null|               Poor|
    |          1|          200|   6|           null|              [Poor]|                  []|       1|        3|           null|               Poor|
    |          1|          200|   7|           null|              [Poor]|                  []|       2|        2|           null|               Poor|
    |          1|          200|   8|           Poor|              [Poor]|              [Poor]|       3|        1|           Poor|               Poor|
    |          1|          300|  10|           null|              [Good]|                  []|       1|        3|           null|               Good|
    |          1|          300|  11|           null|              [Good]|                  []|       2|        2|           null|               Good|
    |          1|          300|   9|           Good|              [Good]|              [Good]|       3|        1|           Good|               Good|
    +-----------+-------------+----+---------------+--------------------+--------------------+--------+---------+---------------+-------------------+
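
If the Spark version in use supports the two-argument first(x, true) / last(x, true) forms, the helper columns above could presumably be collapsed into a single coalesce of two window expressions: take the next non-null value, and fall back to the previous non-null one when nothing non-null follows (the rank_bkwd = 1 case). An untested sketch, in PySpark only for brevity:

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[*]").getOrCreate()

    # Small stand-in for channelView with N/A already converted to null, as above.
    spark.createDataFrame(
        [(1, 100, 0, None), (1, 100, 1, "Good"), (1, 100, 2, "Meduim"),
         (1, 100, 3, None), (1, 100, 4, "Poor"), (1, 100, 5, None)],
        ["CUSTOMER_ID", "TV_CHANNEL_ID", "TIME", "SIGNAL_STRENGTH"]
    ).createOrReplaceTempView("channelView")

    # first(x, true) / last(x, true) skip nulls; coalesce falls back to the
    # previous non-null value when nothing non-null follows (last row of a partition).
    spark.sql(
        "select *, "
        " coalesce( "
        "   first(SIGNAL_STRENGTH, true) over ( partition by CUSTOMER_ID, TV_CHANNEL_ID "
        "     order by TIME rows between current row and unbounded following ), "
        "   last(SIGNAL_STRENGTH, true) over ( partition by CUSTOMER_ID, TV_CHANNEL_ID "
        "     order by TIME rows between unbounded preceding and current row ) "
        " ) as NEW_SIGNAL_STRENGTH "
        "from channelView "
        "order by CUSTOMER_ID, TV_CHANNEL_ID, TIME"
    ).show(truncate=False)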

Upvotes: 2

Views: 2307

Answers (1)

Prem

Reputation: 11955

Hope this helps!

[Edit note - Solution approach updated after original question was modified]

    import pyspark.sql.functions as f

    df = sc.parallelize([
        [1, 100, 0, None],
        [1, 100, 1, 'Good'],
        [1, 100, 2, 'Medium'],
        [1, 100, 3, None],
        [1, 100, 4, 'Poor'],
        [1, 100, 5, 'Medium'],
        [1, 200, 6, None],
        [1, 200, 7, None],
        [1, 200, 8, 'Poor'],
        [1, 300, 9, 'Good'],
        [1, 300, 10, 'Good'],
        [1, 300, 11, 'Good']
    ]).toDF(('customer_id', 'tv_channel_id', 'time', 'signal_strength'))
    df.show()

    # convert to a pandas dataframe and fill NA as per the requirement, then convert it back to a Spark dataframe
    df1 = df.sort('customer_id', 'tv_channel_id', 'time').select('customer_id', 'tv_channel_id', 'signal_strength')
    p_df = df1.toPandas()
    p_df["signal_strength"] = p_df.groupby(["customer_id", "tv_channel_id"]).transform(lambda x: x.fillna(method='bfill'))
    df2 = sqlContext.createDataFrame(p_df).withColumnRenamed("signal_strength", "signal_strength_new")

    # replace the 'signal_strength' column of the original dataframe with the column from the pandas dataframe above
    df = df.withColumn('row_index', f.monotonically_increasing_id())
    df2 = df2.withColumn('row_index', f.monotonically_increasing_id())
    final_df = df.join(df2, on=['customer_id', 'tv_channel_id', 'row_index']).drop("row_index", "signal_strength").\
        withColumnRenamed("signal_strength_new", "signal_strength").\
        sort('customer_id', 'tv_channel_id', 'time')
    final_df.show()

Output is:

    +-----------+-------------+----+---------------+
    |customer_id|tv_channel_id|time|signal_strength|
    +-----------+-------------+----+---------------+
    |          1|          100|   0|           Good|
    |          1|          100|   1|           Good|
    |          1|          100|   2|         Medium|
    |          1|          100|   3|           Poor|
    |          1|          100|   4|           Poor|
    |          1|          100|   5|         Medium|
    |          1|          200|   6|           Poor|
    |          1|          200|   7|           Poor|
    |          1|          200|   8|           Poor|
    |          1|          300|   9|           Good|
    |          1|          300|  10|           Good|
    |          1|          300|  11|           Good|
    +-----------+-------------+----+---------------+

Upvotes: 1
