Kenny

Reputation: 1982

Convert a Spark DataFrame to multiple lists with one column as the key

Consider a Spark DataFrame df like this:

+---+------+---+---+
|bin|median|min|end|
+---+------+---+---+
|  1|   0.0|0.0|0.5|
|  2|   1.0|0.8|1.7|
|  3|   2.0|1.6|2.5|
|  4|   4.0|3.7|4.7|
|  5|   6.0|5.7|6.3|
+---+------+---+---+

I would like to pull out each column as a separate dictionary (or list) keyed by bin, meaning:

median[1] = 0.0   # the median value of df[df.bin == 1]
median = {1: 0.0, 2: 1.0, 3: 2.0, 4: 4.0, 5: 6.0}
min    = {1: 0.0, 2: 0.8, 3: 1.6, 4: 3.7, 5: 5.7}

I am thinking of something like mapping over the RDD, but is there a more DataFrame-oriented way to do this? And is there a way to pull out all the lists at the same time?

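# Build {bin: median} by mapping over the RDD and collecting to the driver.
# (Assigning to a driver-side dict inside the lambda would not work: assignment
# is not allowed in a lambda, and the closure runs on the executors.)
median = df.rdd.map(lambda row: (row.bin, row.median)).collectAsMap()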

What if I want to pull out a list instead of a dictionary, assuming bin is numbered consecutively from 1? How do I make sure the order is preserved? With .orderBy().collect()?

Upvotes: 1

Views: 669

Answers (2)

abiratsis

Reputation: 7316

Here is another approach that supports filtering by both key and column. The solution consists of two functions:

  • as_dict(df, cols, ids, key): collects the data into a dictionary keyed by the key column
  • extract_col_from_dict(dct, col, ids): extracts one column's data from that dictionary

First, let's extract the desired data from the given DataFrame into a dictionary:

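def as_dict(df, cols=None, ids=None, key=0):
  # Resolve the key to a column name; it may be given as a position or a name
  if isinstance(key, int):
    key = df.columns[key]
  elif not isinstance(key, str):
    raise ValueError("Please provide a valid key, e.g. 1 or 'col1'")

  # Keep only the key plus the requested columns (all columns by default)
  if cols:
    df = df.select([key] + [c for c in cols if c != key])

  # Filter on the key column to avoid collecting the whole dataset on the driver
  if ids:
    df = df.where(df[key].isin(ids))

  # Look the key up by name: after select() its position may have changed
  return df.rdd.map(lambda x: (x[key], x.asDict())).collectAsMap()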

Arguments:

  • df: the dataframe
  • cols: the columns you want to work with; defaults to all columns
  • ids: the key values to keep. Filtering on the key column avoids collecting the whole dataset on the driver; defaults to all records
  • key: the key column, given as a name (str) or a position (int); defaults to 0

Let's call the function with your dataset:

df = spark.createDataFrame(
[(1, 0.0, 0., 0.5),
(2, 1.0, 0.8, 1.7),
(3, 2.0, 1.6, 2.5),
(4, 4.0, 3.7, 4.7),
(5, 6.0, 5.7, 6.3)], ["bin", "median", "min", "end"])

dict_ = as_dict(df)
dict_
{1: {'bin': 1, 'min': 0.0, 'end': 0.5, 'median': 0.0},
 2: {'bin': 2, 'min': 0.8, 'end': 1.7, 'median': 1.0},
 3: {'bin': 3, 'min': 1.6, 'end': 2.5, 'median': 2.0},
 4: {'bin': 4, 'min': 3.7, 'end': 4.7, 'median': 4.0},
 5: {'bin': 5, 'min': 5.7, 'end': 6.3, 'median': 6.0}}

# or with filters applied
dict_ = as_dict(df, cols = ['min', 'end'], ids = [1, 2, 3])
dict_
{1: {'bin': 1, 'min': 0.0, 'end': 0.5},
 2: {'bin': 2, 'min': 0.8, 'end': 1.7},
 3: {'bin': 3, 'min': 1.6, 'end': 2.5}}

The function maps each record to a key/value pair whose value is itself a dictionary (obtained by calling row.asDict()).
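Because PySpark's collectAsMap() builds an ordinary Python dict from the collected pairs, the key values must be unique: if two rows share the same bin, whichever pair is collected last silently overwrites the other. The sketch below mimics this in plain Python so it runs without Spark; the dicts in rows are hypothetical stand-ins for row.asDict() results.

```python
# Hypothetical stand-ins for row.asDict() results; note the duplicated bin 2
rows = [
    {"bin": 1, "median": 0.0},
    {"bin": 2, "median": 1.0},
    {"bin": 2, "median": 1.5},
]

# Same shape as the (key, value) pairs built by as_dict's map step
pairs = [(row["bin"], row) for row in rows]

# collectAsMap() turns the collected pairs into a dict, much like dict(pairs):
# for a repeated key, the last pair wins and the earlier record is lost
as_map = dict(pairs)
print(as_map)  # {1: {'bin': 1, 'median': 0.0}, 2: {'bin': 2, 'median': 1.5}}
```

If bin is not guaranteed to be unique, check it first (for example by comparing df.count() with df.select("bin").distinct().count()).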

After calling as_dict, the data is located on the driver, and you can extract the column you need with extract_col_from_dict:

def extract_col_from_dict(dct, col, ids=None):
  # Optionally keep only the requested keys
  keys = [k for k in dct if k in ids] if ids else list(dct)

  # Sort by key so the list order follows the key column (e.g. bin),
  # not the order in which the records happened to be collected
  return [dct[k][col] for k in sorted(keys)]

Arguments:

  • dct: the source dictionary
  • col: column to be extracted
  • ids: an optional list of keys to keep; defaults to all records

And the output of the function:

min_data = extract_col_from_dict(dict_, 'min')
min_data
[0.0, 0.8, 1.6, 3.7, 5.7]
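If you want every column at once, as the question asks, one option is to transpose the row-oriented dictionary returned by as_dict into one list per column. This is a plain-Python sketch that assumes dict_ has the shape shown above; the helper name columns_as_lists is hypothetical. Sorting the keys makes each list follow the bin order rather than the collection order.

```python
def columns_as_lists(dct):
    """Hypothetical helper: turn {key: {col: value}} into {col: [values ordered by key]}."""
    ordered_keys = sorted(dct)
    # Column names are taken from the first record; all records share them
    cols = dct[ordered_keys[0]].keys() if ordered_keys else []
    return {col: [dct[k][col] for k in ordered_keys] for col in cols}


# Same shape as the as_dict(df) output above
dict_ = {
    1: {'bin': 1, 'min': 0.0, 'end': 0.5, 'median': 0.0},
    2: {'bin': 2, 'min': 0.8, 'end': 1.7, 'median': 1.0},
    3: {'bin': 3, 'min': 1.6, 'end': 2.5, 'median': 2.0},
    4: {'bin': 4, 'min': 3.7, 'end': 4.7, 'median': 4.0},
    5: {'bin': 5, 'min': 5.7, 'end': 6.3, 'median': 6.0},
}

lists = columns_as_lists(dict_)
print(lists['median'])  # [0.0, 1.0, 2.0, 4.0, 6.0]
print(lists['min'])     # [0.0, 0.8, 1.6, 3.7, 5.7]
```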

Upvotes: 1

pault

Reputation: 43494

If you're trying to collect your data anyway, the easiest way IMO to get the data in your desired format is via pandas.

You can call toPandas(), set the index to bin, and then call to_dict():

output = df.toPandas().set_index("bin").to_dict()
print(output)
#{'end': {1: 0.5, 2: 1.7, 3: 2.5, 4: 4.7, 5: 6.3},
# 'median': {1: 0.0, 2: 1.0, 3: 2.0, 4: 4.0, 5: 6.0},
# 'min': {1: 0.0, 2: 0.8, 3: 1.6, 4: 3.7, 5: 5.7}}

This creates a dictionary of dictionaries, where the outer key is the column name and the inner key is the bin. If you want separate variables, you can extract them from output, but don't use min as a variable name, since it would shadow the built-in min() function (__builtin__.min in Python 2, builtins.min in Python 3).

median, min_, end = output['median'], output['min'], output['end']
print(median[1])
#0.0
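For the follow-up about lists: if bin is numbered consecutively from 1, each inner dictionary of output can be turned into a list by position, so index i - 1 holds bin i. A minimal sketch, assuming output has the shape printed above (written out literally here so it runs without Spark or pandas); the helper to_list is hypothetical and raises a KeyError if a bin is missing instead of silently shifting values.

```python
def to_list(by_bin):
    """Hypothetical helper: {1: a, 2: b, ...} -> [a, b, ...], assuming bins 1..n."""
    # Indexing by range(1, n + 1) fixes the order and fails loudly on gaps
    return [by_bin[b] for b in range(1, len(by_bin) + 1)]


# Same shape as the output printed above
output = {
    'end': {1: 0.5, 2: 1.7, 3: 2.5, 4: 4.7, 5: 6.3},
    'median': {1: 0.0, 2: 1.0, 3: 2.0, 4: 4.0, 5: 6.0},
    'min': {1: 0.0, 2: 0.8, 3: 1.6, 4: 3.7, 5: 5.7},
}

lists = {col: to_list(vals) for col, vals in output.items()}
print(lists['median'])  # [0.0, 1.0, 2.0, 4.0, 6.0]
```

Alternatively, staying in pandas, df.toPandas().sort_values("bin").to_dict("list") returns a {column: list} dictionary ordered by bin (including a "bin" entry) without relying on consecutive numbering.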

Upvotes: 1
