futurenext110

Reputation: 2109

Creating a dictionary type column in dataframe

Consider the following dataframe:

+------------+--------------------+
|          id|              values|
+------------+--------------------+
|          39|   a,a,b,b,c,c,c,c,d|
|         520|               a,b,c|
|         832|                 a,a|
+------------+--------------------+

I want to convert it into the following DataFrame:

+------------+--------------------------------+
|          id|                          values|
+------------+--------------------------------+
|          39|{"a": 2, "b": 2, "c": 4, "d": 1}|
|         520|        {"a": 1, "b": 1, "c": 1}|
|         832|                        {"a": 2}|
+------------+--------------------------------+

I tried two approaches:

  1. Converting the DataFrame to an RDD, then mapping the values column through a frequency-counter function. I get errors when converting the RDD back to a DataFrame (a sketch of this attempt is below).

  2. Using a udf to essentially do the same thing as above.

The reason I want a dictionary column is to load it as JSON in one of my Python applications.
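For reference, a minimal sketch of approach 1; the explicit MapType schema on the way back is my guess at what the conversion errors were about (spark is a SparkSession, and the column names match the example above):

from collections import Counter
from pyspark.sql.types import (StructType, StructField, LongType,
                               MapType, StringType, IntegerType)

# Explicit schema so Spark does not have to infer the map type
schema = StructType([
    StructField("id", LongType()),
    StructField("values", MapType(StringType(), IntegerType())),
])

# Count the comma-separated tokens for each row on the RDD side
counted = df.rdd.map(lambda row: (row["id"], dict(Counter(row["values"].split(',')))))
result = spark.createDataFrame(counted, schema)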

Upvotes: 13

Views: 22216

Answers (3)

dfernig

Reputation: 626

You can do this with a udf that returns a MapType column.

from pyspark.sql.functions import udf
from pyspark.sql.types import MapType, StringType, IntegerType
from collections import Counter

# Split on commas, count each token, and return the counts as a map
my_udf = udf(lambda s: dict(Counter(s.split(','))), MapType(StringType(), IntegerType()))
df = df.withColumn('values', my_udf('values'))
df.collect()

[Row(id=39, values={u'a': 2, u'c': 4, u'b': 2, u'd': 1}),
 Row(id=520, values={u'a': 1, u'c': 1, u'b': 1}),
 Row(id=832, values={u'a': 2})]
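If what you ultimately need is a JSON string rather than a map, Spark can serialize the MapType column directly with to_json (map support landed around Spark 2.3, if I remember right):

from pyspark.sql.functions import to_json

# Turn the map column into a JSON string for the downstream Python app
df = df.withColumn('values_json', to_json('values'))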

Upvotes: 14

futurenext110

Reputation: 2109

I ended up using this; if you feel there is a better approach, do let me know.

from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

def split_test(str_in):
    # Count occurrences of each comma-separated token
    counts = {}
    for token in str_in.split(','):
        if token not in counts:
            counts[token] = 1
        else:
            counts[token] += 1
    return str(counts)

udf_value_count = udf(split_test, StringType())

value_count_df = value_df.withColumn('value_count', udf_value_count(value_df.values)).drop('values')
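One caveat: str() on a dict produces Python repr with single quotes, which json.loads in the consuming application will reject. A variant that emits valid JSON instead (the name is illustrative; json.dumps is the only real change):

import json
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

def split_test_json(str_in):
    counts = {}
    for token in str_in.split(','):
        counts[token] = counts.get(token, 0) + 1
    return json.dumps(counts)  # double-quoted keys, valid JSON

udf_value_count_json = udf(split_test_json, StringType())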

Upvotes: 0

Josemy

Reputation: 838

I could not get exactly the output you need, but I came really close. This is what I got:

from pyspark.sql.functions import explode, split

counts = (df
    .select("id", explode(split("values", ",")).alias("value"))
    .groupby("id", "value")
    .count())
counts.show()

Output:

+---+-----+-----+
| id|value|count|
+---+-----+-----+
|520|    a|    1|
|520|    b|    1|
|520|    c|    1|
| 39|    a|    2|
| 39|    b|    2|
| 39|    c|    4|
| 39|    d|    1|
|832|    a|    2|
+---+-----+-----+

Perhaps someone can add what it needs to get the output you require (one option is sketched below). Hope it helps.
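To collapse those rows into a single map per id, you could gather the (value, count) pairs and build the map from them, assuming Spark 2.4 or later for map_from_entries:

from pyspark.sql.functions import collect_list, map_from_entries, struct

# Gather (value, count) pairs per id and build one map column
result = (counts
    .groupby("id")
    .agg(map_from_entries(collect_list(struct("value", "count"))).alias("values")))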

Upvotes: 0
