Although I have a solution for what I want to achieve, it is far too slow. The data set is not very large: it is 50 GB in total, and the affected part is probably only 5 to 10 GB. The following does what I need, but it is much too slow; it ran for an hour and still did not terminate.
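from pyspark.ml.feature import Tokenizer
from pyspark.sql import functions as F

# assumes an existing SparkSession named `spark`
df_ = spark.createDataFrame([
    ('1', 'hello how are are you today'),
    ('1', 'hello how are you'),
    ('2', 'hello are you here'),
    ('2', 'how is it'),
    ('3', 'hello how are you'),
    ('3', 'hello how are you'),
    ('4', 'hello how is it you today')
], schema=['label', 'text'])
tokenizer = Tokenizer(inputCol='text', outputCol='tokens')
tokens = tokenizer.transform(df_)
# explode(): one row per token occurrence, then count each (label, token) pair
token_counts = tokens.select('label', F.explode('tokens').alias('token'))\
    .groupby('label', 'token').count()
token_counts.groupby('label')\
    .agg(F.collect_list(F.struct(F.col('token'), F.col('count'))).alias('text'))\
    .show(truncate=False)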
Which gives me the token count for each label:
+-----+----------------------------------------------------------------+
|label|text                                                            |
+-----+----------------------------------------------------------------+
|3    |[[are,2], [how,2], [hello,2], [you,2]]                          |
|1    |[[today,1], [how,2], [are,3], [you,2], [hello,2]]               |
|4    |[[hello,1], [how,1], [is,1], [today,1], [you,1], [it,1]]        |
|2    |[[hello,1], [are,1], [you,1], [here,1], [is,1], [how,1], [it,1]]|
+-----+----------------------------------------------------------------+
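For intuition about data volume, here is a pure-Python model of the explode-then-count pipeline on the question's sample rows. No Spark is involved, and the lowercase-and-split tokenization only approximates Tokenizer. explode() materializes one row per token occurrence before the count, so the seven documents become 31 rows, which then collapse to 22 (label, token) counts.

```python
from collections import Counter

# Sample data from the question: (label, text) pairs
docs = [
    ('1', 'hello how are are you today'),
    ('1', 'hello how are you'),
    ('2', 'hello are you here'),
    ('2', 'how is it'),
    ('3', 'hello how are you'),
    ('3', 'hello how are you'),
    ('4', 'hello how is it you today'),
]

# explode(): one output row per (label, token) occurrence;
# lowercase + whitespace split approximates Tokenizer on this data
exploded = [(label, token) for label, text in docs for token in text.lower().split()]

# groupBy('label', 'token').count(): tally each distinct pair
pair_counts = Counter(exploded)

print(len(docs), 'documents ->', len(exploded), 'exploded rows')  # 7 documents -> 31 exploded rows
print(len(pair_counts), 'distinct (label, token) pairs')          # 22 distinct (label, token) pairs
```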
However, I think the call to explode() is too expensive for this. It might be faster to count the tokens in each "document" first and later merge those counts in a groupBy():
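from collections import Counter
from pyspark.sql import functions as F, types as T
# Assumed stand-in for the udf_get_tokens helper (not shown): lowercase + split on whitespace
udf_get_tokens = F.udf(lambda s: s.lower().split(), T.ArrayType(T.StringType()))
df_.select(['label'] + [udf_get_tokens(F.col('text')).alias('text')])\
    .rdd.map(lambda x: (x[0], list(Counter(x[1]).items()))) \
    .toDF(schema=['label', 'text'])\
    .show()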
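The merge being asked about amounts to summing per-document Counter objects that share a label. A minimal pure-Python sketch of that idea follows. In Spark the same merge would presumably be an RDD reduceByKey over (label, Counter) pairs, e.g. mapValues(Counter) on the tokenized RDD followed by reduceByKey(lambda a, b: a + b), but that wiring is an untested assumption here, not a benchmarked solution.

```python
from collections import Counter, defaultdict

# Sample data from the question: (label, text) pairs
docs = [
    ('1', 'hello how are are you today'),
    ('1', 'hello how are you'),
    ('2', 'hello are you here'),
    ('2', 'how is it'),
    ('3', 'hello how are you'),
    ('3', 'hello how are you'),
    ('4', 'hello how is it you today'),
]

# Step 1: count tokens within each document (the Counter map step above)
per_doc = [(label, Counter(text.lower().split())) for label, text in docs]

# Step 2: merge per-document counters that share a label.
# Counter addition sums counts token by token, which is the merge a
# reduceByKey(lambda a, b: a + b) would apply pairwise.
merged = defaultdict(Counter)
for label, counts in per_doc:
    merged[label] += counts

for label in sorted(merged):
    print(label, sorted(merged[label].items()))
```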
This gives the per-document counts:
+-----+--------------------+
|label| text|
+-----+--------------------+
| 1|[[are,2], [hello,...|
| 1|[[are,1], [hello,...|
| 2|[[are,1], [hello,...|
| 2|[[how,1], [it,1],...|
| 3|[[are,1], [hello,...|
| 3|[[are,1], [hello,...|
| 4|[[you,1], [today,...|
+-----+--------------------+
Is there a way to merge those token counts in a more efficient way?
If the groups defined by the key (`label`) are largish, the obvious target for improvement is the shuffle size. Instead of shuffling text, shuffle numeric token indices. First, vectorize the input:
from pyspark.ml.feature import CountVectorizer, Tokenizer
from pyspark.ml import Pipeline

pipeline_model = Pipeline(stages=[
    Tokenizer(inputCol='text', outputCol='tokens'),
    CountVectorizer(inputCol='tokens', outputCol='vectors')
]).fit(df_)
df_vec = pipeline_model.transform(df_).select("label", "vectors")
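CountVectorizer learns a vocabulary ordered by corpus term frequency and turns each token list into a sparse count vector, so later stages move small integer indices and counts instead of strings. Below is a rough pure-Python analogue. The helper names build_vocabulary and vectorize are hypothetical, plain dicts stand in for SparseVector, and breaking frequency ties alphabetically is an assumption (Spark does not promise a tie order).

```python
from collections import Counter

# Sample data from the question: (label, text) pairs
docs = [
    ('1', 'hello how are are you today'),
    ('1', 'hello how are you'),
    ('2', 'hello are you here'),
    ('2', 'how is it'),
    ('3', 'hello how are you'),
    ('3', 'hello how are you'),
    ('4', 'hello how is it you today'),
]

def build_vocabulary(texts):
    # Corpus-wide term frequency; the most frequent terms get the smallest indices.
    # Ties are broken alphabetically (assumption, for a deterministic result).
    freq = Counter(tok for text in texts for tok in text.lower().split())
    return [tok for tok, _ in sorted(freq.items(), key=lambda kv: (-kv[1], kv[0]))]

def vectorize(text, index):
    # Sparse representation {vocabulary index: count}, analogous to a SparseVector
    return dict(Counter(index[tok] for tok in text.lower().split()))

vocabulary = build_vocabulary(text for _, text in docs)
index = {tok: i for i, tok in enumerate(vocabulary)}
vectors = [(label, vectorize(text, index)) for label, text in docs]

print(vocabulary)
print(vectors[0])
```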
Then aggregate:
from collections import defaultdict
from pyspark.ml.linalg import SparseVector, DenseVector

def seq_func(acc, v):
    # Add one document's count vector into the accumulator {index: count}
    if isinstance(v, SparseVector):
        for i in v.indices:
            acc[int(i)] += v[int(i)]
    elif isinstance(v, DenseVector):
        for i in range(len(v)):
            acc[i] += v[i]
    return acc

def comb_func(acc1, acc2):
    # Merge two partial accumulators built for the same key
    for k, v in acc2.items():
        acc1[k] += v
    return acc1

# df_vec rows are (label, vectors) pairs, so its RDD can be used as a pair RDD
aggregated = df_vec.rdd.aggregateByKey(defaultdict(int), seq_func, comb_func)
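aggregateByKey works in two phases: within each partition, the values for a key are folded into a copy of the zero value with seq_func; after the shuffle, the per-partition accumulators for that key are merged with comb_func, so only one small {index: count} accumulator per key per partition is shuffled. The single-process model below illustrates this. aggregate_by_key is a hypothetical helper (not Spark code), plain dicts stand in for SparseVector, and the partition layout is made up. Each key gets a deep copy of the mutable defaultdict zero value, which PySpark's aggregateByKey also does.

```python
import copy
from collections import defaultdict

def seq_func(acc, v):
    # v is a sparse vector given as {index: count}; fold it into the accumulator
    for i, x in v.items():
        acc[i] += x
    return acc

def comb_func(acc1, acc2):
    # Merge two per-partition accumulators for the same key
    for k, x in acc2.items():
        acc1[k] += x
    return acc1

def aggregate_by_key(partitions, zero, seq, comb):
    # Phase 1: within each partition, fold values into a fresh copy of zero per key
    partials = []
    for part in partitions:
        local = {}
        for key, value in part:
            if key not in local:
                local[key] = copy.deepcopy(zero)
            local[key] = seq(local[key], value)
        partials.append(local)
    # Phase 2 (after the shuffle): merge the partial accumulators of each key
    result = {}
    for local in partials:
        for key, acc in local.items():
            result[key] = comb(result[key], acc) if key in result else acc
    return result

# Hypothetical layout: two partitions of (label, sparse count vector) pairs
partitions = [
    [('1', {0: 2, 1: 1}), ('2', {1: 1, 7: 1})],
    [('1', {0: 1, 2: 1}), ('2', {2: 1})],
]
aggregated = aggregate_by_key(partitions, defaultdict(int), seq_func, comb_func)
print({k: dict(v) for k, v in aggregated.items()})
```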
And map back to the required output:
vocabulary = pipeline_model.stages[-1].vocabulary

def f(x, vocabulary=vocabulary):
    # For list of tuples use [(vocabulary[i], float(v)) for i, v in x.items()]
    return {vocabulary[i]: float(v) for i, v in x.items()}

aggregated.mapValues(f).toDF(["id", "text"]).show(truncate=False)
# +---+-------------------------------------------------------------------------------------+
# |id |text                                                                                 |
# +---+-------------------------------------------------------------------------------------+
# |4  |[how -> 1.0, today -> 1.0, is -> 1.0, it -> 1.0, hello -> 1.0, you -> 1.0]           |
# |3  |[how -> 2.0, hello -> 2.0, are -> 2.0, you -> 2.0]                                   |
# |1  |[how -> 2.0, hello -> 2.0, are -> 3.0, you -> 2.0, today -> 1.0]                     |
# |2  |[here -> 1.0, how -> 1.0, are -> 1.0, is -> 1.0, it -> 1.0, hello -> 1.0, you -> 1.0]|
# +---+-------------------------------------------------------------------------------------+
This is worth trying only if the text part is considerably large; otherwise, all the required conversions between DataFrame and Python objects might be more expensive than collect_list.