Ivan Lee

Reputation: 4261

How do I count word frequencies with CountVectorizer in Spark ML?

The code below produces a count vector for each row of the DataFrame:

import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel}

val df = spark.createDataFrame(Seq(
  (0, Array("a", "b", "c")),
  (1, Array("a", "b", "b", "c", "a"))
)).toDF("id", "words")

// fit a CountVectorizerModel from the corpus
val cvModel: CountVectorizerModel = new CountVectorizer()
  .setInputCol("words")
  .setOutputCol("features")
  .fit(df)


cvModel.transform(df).show(false)

The result is:

+---+---------------+-------------------------+
|id |words          |features                 |
+---+---------------+-------------------------+
|0  |[a, b, c]      |(3,[0,1,2],[1.0,1.0,1.0])|
|1  |[a, b, b, c, a]|(3,[0,1,2],[2.0,2.0,1.0])|
+---+---------------+-------------------------+

How can I get the total count of each word across all rows, like this:

+---+------+------+
|id |words |counts|
+---+------+------+
|0  |a     |  3   |
|1  |b     |  3   |
|2  |c     |  2   |
+---+------+------+

Upvotes: 3

Views: 1491

Answers (2)

hichris123

Reputation: 10223

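Shankar's answer only gives you the actual frequencies if the CountVectorizer model keeps every single word in the corpus (i.e. no minDF or vocabSize limits), because it counts the raw words rather than the model's output. When the model does filter its vocabulary, you can use Summarizer to sum the feature vectors directly: since term counts are never negative, the column-wise normL1 is exactly the total count of each term. Note: Summarizer requires Spark 2.3+.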

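import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.stat.Summarizer.metrics
import spark.implicits._  // for $"..." and .as[...]

// normL1 is the column-wise sum of absolute values, i.e. the total count of each term.
// You need to request normL1 together with another metric (like mean) because, for some
// reason, Spark won't let you select a single Vector this way (at least in 2.4).
val totalCounts: Vector = cvModel.transform(df)
    .select(metrics("normL1", "mean").summary($"features").as("summary"))
    .select("summary.normL1", "summary.mean")
    .as[(Vector, Vector)]
    .first()
    ._1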

You'll then have to zip cvModel.vocabulary with totalCounts.toArray to pair each word with its count (index i of the vector corresponds to vocabulary(i)).
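A minimal sketch of what the Summarizer step and the zip compute, written in plain Scala so it runs without a Spark session. The sparse vectors are hard-coded from the `features` column shown in the question as (size, indices, values), and `vocabulary` is an assumed stand-in for `cvModel.vocabulary`; the object, class and method names are hypothetical. CountVectorizer orders its vocabulary by corpus frequency, so the order of words with equal counts (here `a` and `b`) is not something to rely on.

```scala
object TotalCounts {
  // A sparse vector as Spark prints it: (size, indices, values). Hypothetical helper type.
  final case class Sparse(size: Int, indices: Array[Int], values: Array[Double])

  // Column-wise sum of absolute values, i.e. what Summarizer's "normL1" reports.
  // For non-negative term counts this is simply the column-wise sum.
  def normL1(rows: Seq[Sparse]): Array[Double] = {
    require(rows.nonEmpty, "need at least one vector")
    val acc = Array.fill(rows.head.size)(0.0)
    for (row <- rows; (i, v) <- row.indices.zip(row.values)) acc(i) += math.abs(v)
    acc
  }

  // Pair each vocabulary term with the total at the same vector index.
  def withWords(vocabulary: Array[String], totals: Array[Double]): Seq[(Int, String, Long)] = {
    require(vocabulary.length == totals.length, "vocabulary and totals must align")
    vocabulary.toSeq.zip(totals.toSeq).zipWithIndex.map {
      case ((word, total), id) => (id, word, total.toLong)
    }
  }

  def main(args: Array[String]): Unit = {
    // Feature vectors copied from the question's output table.
    val features = Seq(
      Sparse(3, Array(0, 1, 2), Array(1.0, 1.0, 1.0)),
      Sparse(3, Array(0, 1, 2), Array(2.0, 2.0, 1.0))
    )
    val vocabulary = Array("a", "b", "c") // assumed stand-in for cvModel.vocabulary
    withWords(vocabulary, normL1(features)).foreach { case (id, word, n) =>
      println(s"$id $word $n")
    }
  }
}
```

With the question's data the column sums are 3, 3 and 2, which matches the desired table once each index is paired with its word.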

Upvotes: 3

koiralo

Reputation: 23099

You can simply explode the words and groupBy to get the count of each word:

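import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{count, explode, row_number}
import spark.implicits._  // for the $"..." column syntax

cvModel.transform(df).withColumn("words", explode($"words"))
  .groupBy($"words")
  .agg(count($"words").as("counts"))
  // number the words 0, 1, 2, ... in alphabetical order
  .withColumn("id", row_number().over(Window.orderBy("words")) - 1)
  .show(false)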

Output:

+-----+------+---+
|words|counts|id |
+-----+------+---+
|a    |3     |0  |
|b    |3     |1  |
|c    |2     |2  |
+-----+------+---+
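To check the expected numbers without a Spark session, here is a plain-Scala sketch of the same explode-then-group logic: `flatten` plays the role of `explode`, `groupBy` plus `size` the role of `groupBy(...).agg(count(...))`, and sorting then numbering from 0 mimics `row_number() - 1` over `Window.orderBy("words")`. The `rows` value is hard-coded from the question's `df`; the object and method names are hypothetical.

```scala
object ExplodeCount {
  // Count every word across all rows, then number the words in sorted order from 0.
  def countWords(rows: Seq[Seq[String]]): Seq[(String, Int, Int)] = {
    val counts = rows.flatten                  // like explode($"words")
      .groupBy(identity)                       // like groupBy($"words")
      .map { case (word, hits) => (word, hits.size) } // like count($"words")
    counts.toSeq
      .sortBy(_._1)                            // like Window.orderBy("words")
      .zipWithIndex
      .map { case ((word, n), idx) => (word, n, idx) } // idx = row_number() - 1
  }

  def main(args: Array[String]): Unit = {
    // Same data as the question's df.
    val rows = Seq(Seq("a", "b", "c"), Seq("a", "b", "b", "c", "a"))
    countWords(rows).foreach { case (word, n, id) => println(s"$word $n $id") }
  }
}
```

Like the Spark version, this counts raw tokens, so it ignores any minDF or vocabSize filtering applied by the CountVectorizer.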

Upvotes: 1
