Angelo

Reputation: 655

Pivot a PySpark DataFrame to get multi-level columns

My PySpark DataFrame looks like this:

+--------+----------+----+----+----+
|latitude| longitude|var1|date|var2|
+--------+----------+----+----+----+
|    3.45|     -8.65|   1|   7|   2|
|   30.45|     45.65|   1|   7|   2|
|   40.45|    123.65|   1|   7|   2|
|   43.45|     13.65|   1|   7|   2|
|   44.45|    -12.65|   1|   7|   2|
|   54.45|   -128.65|   1|   7|   2|
+--------+----------+----+----+----+

But I don't know how to reshape it to get a single record for each date, with multi-level columns indexed by [variable, latitude, longitude] in that order, so that I can treat each combination of variable, latitude and longitude as a separate column.

Running this:

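from pyspark.sql import functions as F

var_cols = [c for c in df.columns if c not in ['date', 'latitude', 'longitude']]
df.select(
    'date',
    *[F.array(F.col(c), F.col('latitude'), F.col('longitude')) for c in var_cols]
).show()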

I get:

+----+---------------------------------+---------------------------------+
|date|array(var1, latitude, longitude) |array(var2, latitude, longitude) |
+----+---------------------------------+---------------------------------+
|   7|               [1.0, 3.45, -8.65]|               [2.0, 3.45, -8.65]|
|   7|              [1.0, 30.45, 45.65]|              [2.0, 30.45, 45.65]|
|   7|             [1.0, 40.45, 123.65]|             [2.0, 40.45, 123.65]|
|   7|              [1.0, 43.45, 13.65]|              [2.0, 43.45, 13.65]|
|   7|             [1.0, 44.45, -12.65]|             [2.0, 44.45, -12.65]|
|   7|             [1.0, 54.45, -128...|             [2.0, 54.45, -128...|
+----+---------------------------------+---------------------------------+

Instead, I would like each cell to hold a single value (the value of the variable), with one column for EACH combination of latitude and longitude values. Imagine setting [date, latitude, longitude] as the index in pandas and then unstacking the latitude and longitude levels.

For example, in pandas I would do this:

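# Unstack both levels at once so the columns come out as (variable, latitude, longitude)
df.set_index(["date", "latitude", "longitude"]).unstack(["latitude", "longitude"])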
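To make the target shape concrete, here is a minimal pure-Python sketch (no Spark or pandas; the row tuples simply copy the sample table) of the reshape being asked for: one record per date, keyed by (variable, latitude, longitude). The helper name `unstack_by_location` is hypothetical and only illustrates the layout, not how Spark would compute it.

```python
from collections import defaultdict

# Sample rows copied from the table above: (latitude, longitude, var1, date, var2)
ROWS = [
    (3.45, -8.65, 1, 7, 2),
    (30.45, 45.65, 1, 7, 2),
    (40.45, 123.65, 1, 7, 2),
    (43.45, 13.65, 1, 7, 2),
    (44.45, -12.65, 1, 7, 2),
    (54.45, -128.65, 1, 7, 2),
]
VAR_NAMES = ("var1", "var2")


def unstack_by_location(rows):
    """Hypothetical helper: return {date: {(variable, latitude, longitude): value}}."""
    wide = defaultdict(dict)
    for lat, lon, var1, date, var2 in rows:
        for name, value in zip(VAR_NAMES, (var1, var2)):
            # Each (variable, latitude, longitude) triple becomes its own "column"
            wide[date][(name, lat, lon)] = value
    return dict(wide)


wide = unstack_by_location(ROWS)
print(len(wide))      # number of records (one per date)
print(len(wide[7]))   # number of "columns" for date 7 (2 variables x 6 locations)
print(wide[7][("var1", 3.45, -8.65)])
```

With a single date in the sample, this yields one record with 12 cells, which is the shape the pandas `unstack` call produces as columns.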

Upvotes: 1

Views: 103

Answers (2)

mck

Reputation: 42422

How about this:

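from pyspark.sql import functions as F

var_cols = [c for c in df.columns if c not in ['date', 'latitude', 'longitude']]

# Build one string key per (latitude, longitude) pair, e.g. "3.45_-8.65",
# then pivot on it so each pair (per variable) becomes its own column
df.withColumn('latlong',
              F.concat_ws('_', F.col('latitude'), F.col('longitude'))) \
  .groupBy('date') \
  .pivot('latlong') \
  .agg(*[F.first(c) for c in var_cols])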
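The answer above works by collapsing latitude and longitude into one string key, grouping by date, and pivoting on that key with `first()` as the aggregate. A rough plain-Python analogue of those semantics is sketched below, assuming the sample rows from the question; `pivot_first` is a hypothetical helper, and the `(key, variable)` tuple labels are a simplification, since Spark names pivoted multi-aggregation columns in its own way. Note also that Spark's `first()` gives no ordering guarantee after a shuffle; here "first" just means the first row encountered.

```python
# Sample rows copied from the question: (latitude, longitude, var1, date, var2)
ROWS = [
    (3.45, -8.65, 1, 7, 2),
    (30.45, 45.65, 1, 7, 2),
    (40.45, 123.65, 1, 7, 2),
    (43.45, 13.65, 1, 7, 2),
    (44.45, -12.65, 1, 7, 2),
    (54.45, -128.65, 1, 7, 2),
]
VAR_NAMES = ("var1", "var2")


def pivot_first(rows, sep="_"):
    """Hypothetical analogue of groupBy('date').pivot('latlong').agg(first(...)).

    Returns {date: {(latlong_key, variable): value}}.
    """
    table = {}
    for lat, lon, var1, date, var2 in rows:
        key = f"{lat}{sep}{lon}"  # mirrors concat_ws('_', latitude, longitude)
        record = table.setdefault(date, {})  # one record per date (groupBy)
        for name, value in zip(VAR_NAMES, (var1, var2)):
            # first(): keep only the first value encountered for this cell
            record.setdefault((key, name), value)
    return table


result = pivot_first(ROWS)
print(result[7][("3.45_-8.65", "var1")])
```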

Upvotes: 1

Angelo

Reputation: 655

I came across this solution:

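from pyspark.sql import functions as F

var_cols = [c for c in df.columns if c not in ['date', 'latitude', 'longitude']]

# Combine latitude and longitude into a single string key such as "3.45,-8.65"
df = df.withColumn('latlong', F.array(F.col('latitude'), F.col('longitude')))
df = df.withColumn('latlong', F.concat_ws(',', 'latlong'))
# One row per date; one column per (latlong, variable) combination
df = df.groupBy(["date"]).pivot("latlong").max(*var_cols)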
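This differs from a `first()`-based pivot only in the aggregate: `max(*var_cols)` keeps the largest value when several rows share the same date and location, while for unique (date, latitude, longitude) triples, as in the sample data, both give the same cells. Because the key is just "lat,lon" as a string, it can also be split back into numbers to recover the (variable, latitude, longitude) structure the question asks for. The plain-Python sketch below illustrates both points; `pivot_max` and `to_multicolumn` are hypothetical helpers, and the extra duplicate row is made up purely to show the effect of `max`.

```python
# Sample rows from the question: (latitude, longitude, var1, date, var2)
ROWS = [
    (3.45, -8.65, 1, 7, 2),
    (30.45, 45.65, 1, 7, 2),
    (40.45, 123.65, 1, 7, 2),
    (43.45, 13.65, 1, 7, 2),
    (44.45, -12.65, 1, 7, 2),
    (54.45, -128.65, 1, 7, 2),
]
VAR_NAMES = ("var1", "var2")


def pivot_max(rows, sep=","):
    """Hypothetical analogue of groupBy('date').pivot('latlong').max(*var_cols).

    Returns {date: {(latlong_key, variable): max_value}}.
    """
    table = {}
    for lat, lon, var1, date, var2 in rows:
        key = f"{lat}{sep}{lon}"  # mirrors concat_ws(',', array(latitude, longitude))
        record = table.setdefault(date, {})
        for name, value in zip(VAR_NAMES, (var1, var2)):
            cell = (key, name)
            # max(): keep the largest value seen for this cell
            record[cell] = max(record.get(cell, value), value)
    return table


def to_multicolumn(record, sep=","):
    """Split each "lat,lon" key back into floats -> {(variable, lat, lon): value}."""
    out = {}
    for (key, name), value in record.items():
        lat, lon = (float(part) for part in key.split(sep))
        out[(name, lat, lon)] = value
    return out


# Made-up duplicate for location (3.45, -8.65): max keeps 5 for var1, 2 for var2
with_dup = ROWS + [(3.45, -8.65, 5, 7, 0)]
print(pivot_max(with_dup)[7][("3.45,-8.65", "var1")])
print(to_multicolumn(pivot_max(ROWS)[7])[("var2", 54.45, -128.65)])
```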

Upvotes: 1
