qwertz

Reputation: 373

Python / PySpark - How to replace some cells with averages?

I have a big problem and I hope someone can help me. I want to replace some cells in a column with another value.

The dataframe looks like:

----------------------------------------
|Timestamp           | Item_ID | Price |
----------------------------------------
|2017-05-01 11:05:00 | 12345   | 70    |
|2017-05-01 17:20:00 | 98765   | 10    |
|2017-05-01 11:50:00 | 12345   | 20    |
|2017-05-01 19:50:00 | 12345   | 0     |
|2017-05-01 20:17:00 | 12345   | 0     |
|2017-05-01 22:01:00 | 98765   | 0     |
----------------------------------------

As you can see, the same item has different prices over time. For example, item "12345" has three prices: 70, 20 and 0. Now I want to replace every "0" with the average of that item's other (non-zero) prices. Is something like this possible?

The result should be:
For item 12345: (70 + 20) / 2 = 45
For item 98765: there is only one other price, so take that one (10).

----------------------------------------
|Timestamp           | Item_ID | Price |
----------------------------------------
|2017-05-01 11:05:00 | 12345   | 70    |
|2017-05-01 17:20:00 | 98765   | 10    |
|2017-05-01 11:50:00 | 12345   | 20    |
|2017-05-01 19:50:00 | 12345   | 45    |
|2017-05-01 20:17:00 | 12345   | 45    |
|2017-05-01 22:01:00 | 98765   | 10    |
----------------------------------------
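To pin down the rule before touching Spark, here is a minimal plain-Python sketch of it, assuming that "average of the other prices" means the mean of the item's non-zero prices. The function name impute_zero_prices is hypothetical, and leaving a price at 0 when an item has no non-zero price at all is an assumption the question does not cover.

```python
from collections import defaultdict

# Rows from the question: (Timestamp, Item_ID, Price)
rows = [
    ("2017-05-01 11:05:00", 12345, 70),
    ("2017-05-01 17:20:00", 98765, 10),
    ("2017-05-01 11:50:00", 12345, 20),
    ("2017-05-01 19:50:00", 12345, 0),
    ("2017-05-01 20:17:00", 12345, 0),
    ("2017-05-01 22:01:00", 98765, 0),
]


def impute_zero_prices(rows):
    """Replace each 0 price with the mean of that item's non-zero prices."""
    nonzero = defaultdict(list)
    for _, item, price in rows:
        if price != 0:
            nonzero[item].append(price)
    means = {item: sum(p) / len(p) for item, p in nonzero.items()}
    # Assumption: an item with no non-zero price keeps its 0.
    return [
        (ts, item, price if price != 0 else means.get(item, price))
        for ts, item, price in rows
    ]


result = impute_zero_prices(rows)
for row in result:
    print(row)
```

For item 12345 the zeros become (70 + 20) / 2 = 45.0, and for item 98765 the zero becomes 10.0.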

Thank you very much and have a nice day! qwertz

Answers (1)

pault

Reputation: 43504

Here is a way to do it using Spark SQL:

from io import StringIO  # Python 3; Python 2 used "from StringIO import StringIO"

import pandas as pd
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# create dummy data (read_csv replaces the removed DataFrame.from_csv)
df = pd.read_csv(StringIO("""Timestamp|Item_ID|Price
2017-05-01 11:05:00|12345|70
2017-05-01 17:20:00|98765|10
2017-05-01 11:50:00|12345|20
2017-05-01 19:50:00|12345|0
2017-05-01 20:17:00|12345|0
2017-05-01 22:01:00|98765|0"""), sep="|")

# Timestamp is read as a plain string column, so no conversion is needed
spark_df = spark.createDataFrame(df)

# register the DataFrame under the name "table" (registerTempTable in older Spark);
# `table` is backquoted in the query because TABLE is also a SQL keyword
spark_df.createOrReplaceTempView('table')
spark.sql("""SELECT l.Timestamp,
    l.Item_ID,
    CASE WHEN l.Price > 0 THEN l.Price ELSE r.Price END AS Price
    FROM `table` l
    LEFT JOIN (
        SELECT Item_ID,
        AVG(Price) AS Price
        FROM `table`
        WHERE Price > 0
        GROUP BY Item_ID
    ) r ON l.Item_ID = r.Item_ID""".replace("\n", ' ')
).show()

The output:

+-------------------+-------+-----+
|Timestamp          |Item_ID|Price|
+-------------------+-------+-----+
|2017-05-01 19:50:00|12345  |45.0 |
|2017-05-01 20:17:00|12345  |45.0 |
|2017-05-01 11:05:00|12345  |70.0 |
|2017-05-01 11:50:00|12345  |20.0 |
|2017-05-01 17:20:00|98765  |10.0 |
|2017-05-01 22:01:00|98765  |10.0 |
+-------------------+-------+-----+

Explanation:

By calling spark_df.registerTempTable('table') (createOrReplaceTempView in Spark 2.0 and later), I am registering the Spark DataFrame as a temporary table named table, so that it can be queried with SQL. The query joins the table to itself on Item_ID, where the right-hand side (r) holds the average of each item's non-zero prices. The CASE expression then keeps the original price when it is greater than 0 and takes the average otherwise. An item whose prices are all 0 has no row in r, so the LEFT JOIN leaves its price as NULL.

The .replace("\n", " ") call is optional: Spark SQL treats newlines in the query string as ordinary whitespace, so the triple-quoted query also works as-is. Writing the query over several lines simply keeps it readable without having to put it all on one line.

Notes

The technique you are describing is called mean imputation. Since it is common in practice, there is probably another (possibly better) way to do this using only Spark DataFrame functions, avoiding SQL.