

Pyspark SQL: Keep entries with only null values in pivot table

I am trying to create a pivot table on a PySpark SQL DataFrame that doesn't drop the null values. My input table has the following structure:


I am running everything in the IBM Data Science Experience cloud under Python 2 with Spark 2.1.

When doing this on a pandas DataFrame, the dropna=False parameter gives me the result I want:

import pandas as pd
table = pd.pivot_table(ratings, columns=['movieId'], index=['monthyear', 'userId'], values='rating', dropna=False)

As an output I get the following:

Pandas pivot result

In PySpark SQL, I am currently using the following command:

ratings_pivot = spark_df.groupBy('monthyear', 'userId').pivot('movieId').sum('rating')
ratings_pivot.show()

As an output I get the following:

PySpark pivot result

As you can see, none of the entries that contain only null values are shown. Is there something similar to dropna=False in Spark SQL? Since this is very specific, I can't find anything about it on the internet.

I extracted a small dataset for reproduction:

df = spark.createDataFrame([("1", 30, 2.5, 200912), ("1", 32, 3.0, 200912), ("2", 40, 4.0, 201002), ("3", 45, 2.5, 200002)], ("userID", "movieID", "rating", "monthyear"))

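+------+-------+------+---------+
|userID|movieID|rating|monthyear|
+------+-------+------+---------+
|     1|     30|   2.5|   200912|
|     1|     32|   3.0|   200912|
|     2|     40|   4.0|   201002|
|     3|     45|   2.5|   200002|
+------+-------+------+---------+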

If I now run the pivot query, I get the following result:


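+---------+------+----+----+----+----+
|monthyear|UserID|  30|  32|  40|  45|
+---------+------+----+----+----+----+
|   201002|     2|null|null| 4.0|null|
|   200912|     1| 2.5| 3.0|null|null|
|   200002|     3|null|null|null| 2.5|
+---------+------+----+----+----+----+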
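To see why the all-null combinations are missing, here is a plain-Python model of what groupBy(...).pivot(...).sum(...) computes on the sample data. It is only a sketch, not Spark's implementation, and pivot_sum is a hypothetical helper name: output rows are built only for (monthyear, userID) pairs that actually occur in the input, which is why just three rows come back.

```python
from collections import defaultdict

# Sample rows from the question: (userID, movieID, rating, monthyear)
ROWS = [
    ("1", 30, 2.5, 200912),
    ("1", 32, 3.0, 200912),
    ("2", 40, 4.0, 201002),
    ("3", 45, 2.5, 200002),
]


def pivot_sum(rows):
    """Model groupBy("monthyear", "userID").pivot("movieID").sum("rating").

    Hypothetical helper for illustration: returns (movie_ids, table) where
    table maps each *observed* (monthyear, userID) pair to a dict of
    movieID -> summed rating, with None standing in for Spark's null.
    """
    movie_ids = sorted({movie for _, movie, _, _ in rows})
    sums = defaultdict(dict)
    for user, movie, rating, month in rows:
        cell = sums[(month, user)]
        cell[movie] = cell.get(movie, 0.0) + rating
    # Every pivot value becomes a column; missing cells become None.
    table = {
        key: {movie: cells.get(movie) for movie in movie_ids}
        for key, cells in sums.items()
    }
    return movie_ids, table


movie_ids, table = pivot_sum(ROWS)
for (month, user), cells in sorted(table.items()):
    print(month, user, [cells[m] for m in movie_ids])
```

Pairs such as (200912, "2") never appear in the input, so nothing in a group-then-pivot pipeline can produce a row for them.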

What I want is a result that looks like this:

+---------+------+----+----+----+----+
|monthyear|UserID|  30|  32|  40|  45|
+---------+------+----+----+----+----+
|   201002|     2|null|null| 4.0|null|
|   200912|     2|null|null|null|null|
|   200002|     2|null|null|null|null|
|   200912|     1| 2.5| 3.0|null|null|
|   200002|     1|null|null|null|null|
|   201002|     1|null|null|null|null|
|   200002|     3|null|null|null| 2.5|
|   200912|     3|null|null|null|null|
|   201002|     3|null|null|null|null|
+---------+------+----+----+----+----+


Alper t. Turker

Spark doesn't provide anything like this, because it just won't scale; pivot alone is expensive enough. It can be done manually with an outer join:

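n = 20  # Adjust value depending on the data

wide = (df
    # Get unique months
    .select("monthyear")
    .distinct()
    .coalesce(n)  # Coalesce to avoid partition number "explosion"
    # Same as above for UserID and get Cartesian product
    .crossJoin(df.select("UserID").distinct().coalesce(n))
    # Join with pivoted data
    .join(
        df.groupBy("monthyear", "UserID")
            .pivot("movieID")
            .sum("rating"),
        ["monthyear", "UserID"],
        "leftouter"))

wide.show()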

# +---------+------+----+----+----+----+
# |monthyear|UserID|  30|  32|  40|  45|
# +---------+------+----+----+----+----+
# |   201002|     3|null|null|null|null|
# |   201002|     2|null|null| 4.0|null|
# |   200002|     1|null|null|null|null|
# |   200912|     1| 2.5| 3.0|null|null|
# |   200002|     3|null|null|null| 2.5|
# |   200912|     2|null|null|null|null|
# |   200912|     3|null|null|null|null|
# |   201002|     1|null|null|null|null|
# |   200002|     2|null|null|null|null|
# +---------+------+----+----+----+----+
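The same two-step idea (Cartesian product of the distinct keys, then a left outer join with the pivoted data) can be sketched in plain Python to check which rows come out. This is an illustration under the assumption that None stands in for null; pivot_all_combinations is a hypothetical name, and the comments map each step to the Spark calls above.

```python
from itertools import product

# Sample rows from the question: (userID, movieID, rating, monthyear)
ROWS = [
    ("1", 30, 2.5, 200912),
    ("1", 32, 3.0, 200912),
    ("2", 40, 4.0, 201002),
    ("3", 45, 2.5, 200002),
]


def pivot_all_combinations(rows):
    """Mimic the cross join + left outer join in plain Python.

    Hypothetical helper for illustration only; None stands in for null.
    """
    months = sorted({month for _, _, _, month in rows})  # select("monthyear").distinct()
    users = sorted({user for user, _, _, _ in rows})     # select("UserID").distinct()
    movie_ids = sorted({movie for _, movie, _, _ in rows})

    # Pivoted data: only the (monthyear, userID) pairs present in the input.
    pivoted = {}
    for user, movie, rating, month in rows:
        cells = pivoted.setdefault((month, user), {})
        cells[movie] = cells.get(movie, 0.0) + rating

    # Cartesian product of the keys, left-outer-joined with the pivoted data:
    # pairs missing from `pivoted` get None in every movie column.
    wide = {}
    for month, user in product(months, users):
        cells = pivoted.get((month, user), {})
        wide[(month, user)] = [cells.get(movie) for movie in movie_ids]
    return movie_ids, wide


movie_ids, wide = pivot_all_combinations(ROWS)
print(len(wide))  # 3 distinct months x 3 distinct users = 9 rows
```

The row count is (#distinct months × #distinct users), which is the scaling concern mentioned above; coalesce(n) in the Spark version only limits the number of partitions the cross join produces, not the number of rows.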


Spark does keep entries with all null values, for both rows and columns:

Spark 2.1:

Welcome to
      ____              __
     / __/__  ___ _____/ /__
    _\ \/ _ \/ _ `/ __/  '_/
   /__ / .__/\_,_/_/ /_/\_\   version 2.1.1

Using Python version 3.6.4 (default, Dec 21 2017 21:42:08)
SparkSession available as 'spark'.

In [1]: df = spark.createDataFrame([("a", 1, 4), ("a", 2, 2), ("b", 3, None), (None, 4, None)], ("x", "y", "z"))

In [2]: df.groupBy("x").pivot("y").sum("z").show()
+----+----+----+----+----+
|   x|   1|   2|   3|   4|
+----+----+----+----+----+
|   b|null|null|null|null|
|   a|   4|   2|null|null|
+----+----+----+----+----+

Spark 2.2:

Welcome to
      ____              __
     / __/__  ___ _____/ /__
    _\ \/ _ \/ _ `/ __/  '_/
   /__ / .__/\_,_/_/ /_/\_\   version 2.2.1

Using Python version 3.6.4 (default, Dec 21 2017 21:42:08)
SparkSession available as 'spark'.

In [1]: df = spark.createDataFrame([("a", 1, 4), ("a", 2, 2), ("b", 3, None), (None, 4, None)], ("x", "y", "z"))

In [2]: df.groupBy("x").pivot("y").sum("z").show()
+----+----+----+----+----+
|   x|   1|   2|   3|   4|
+----+----+----+----+----+
|   b|null|null|null|null|
|   a|   4|   2|null|null|
+----+----+----+----+----+

Spark 2.3:

Welcome to
      ____              __
     / __/__  ___ _____/ /__
    _\ \/ _ \/ _ `/ __/  '_/
   /__ / .__/\_,_/_/ /_/\_\   version 2.3.0

Using Python version 3.6.4 (default, Dec 21 2017 21:42:08)
SparkSession available as 'spark'.

In [1]: df = spark.createDataFrame([("a", 1, 4), ("a", 2, 2), ("b", 3, None), (None, 4, None)], ("x", "y", "z"))

In [2]: df.groupBy("x").pivot("y").sum("z").show()
+----+----+----+----+----+
|   x|   1|   2|   3|   4|
+----+----+----+----+----+
|   b|null|null|null|null|
|   a|   4|   2|null|null|
+----+----+----+----+----+
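This is consistent with the first answer once two cases are separated. The plain-Python model below is a sketch, not Spark's code, and pivot_sum and null_aware_sum are hypothetical names. It shows that a group which exists in the input keeps its row even when every aggregated value is null (x = "b"), and every distinct pivot value gets a column. What pivot does not do is invent rows for key combinations that never occur in the data, which is what the question asks for.

```python
# Rows from this answer: (x, y, z); None stands in for null.
ROWS = [("a", 1, 4), ("a", 2, 2), ("b", 3, None), (None, 4, None)]


def null_aware_sum(values):
    """Sum like SQL SUM: ignore nulls, return None if every input is null."""
    present = [v for v in values if v is not None]
    return sum(present) if present else None


def pivot_sum(rows):
    """Plain-Python model of df.groupBy("x").pivot("y").sum("z").

    Hypothetical helper: one output row per x value that occurs in the data,
    one column per distinct y value.
    """
    columns = sorted({y for _, y, _ in rows})
    groups = {}
    for x, y, z in rows:
        groups.setdefault(x, {}).setdefault(y, []).append(z)
    return columns, {
        x: [null_aware_sum(cells.get(y, [])) for y in columns]
        for x, cells in groups.items()
    }


columns, table = pivot_sum(ROWS)
print(columns)     # [1, 2, 3, 4]
print(table["b"])  # [None, None, None, None]
print(table["a"])  # [4, 2, None, None]
```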
