Marsellus Wallace

Reputation: 18601

Spark MLlib: Should I call .cache before fitting a model?

Imagine that I am training a Spark MLlib model as follows:

import org.apache.spark.ml.classification.LogisticRegression

val trainingData = loadTrainingData(...) // my own loading helper
val logisticRegression = new LogisticRegression()

trainingData.cache()
val logisticRegressionModel = logisticRegression.fit(trainingData)

Does calling trainingData.cache improve performance at training time, or is it unnecessary?

Does the .fit(...) method of an ML algorithm call cache/unpersist internally?

Upvotes: 3

Views: 1060

Answers (1)

Marsellus Wallace

Reputation: 18601

There is no need to call .cache before fitting Spark's LogisticRegression (or some other models). The train method in LogisticRegression (called by Predictor.fit(...)) is implemented as follows:

override protected[spark] def train(dataset: Dataset[_]): LogisticRegressionModel = {
  val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE // true if the caller has not cached/persisted the dataset
  train(dataset, handlePersistence)
}

And later...

if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)

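This is generally even more efficient than calling .cache yourself, because instances in the line above contains only (label, weight, features) rather than all of the input columns. Note that if you do cache the DataFrame yourself, handlePersistence is false, so instances is not persisted and is recomputed from your cached data on each pass.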
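To make the logic above concrete without a Spark cluster, here is a minimal, Spark-free Scala model of the same "persist only if the caller didn't" pattern. FakeData, its string storage levels, select, and train are hypothetical stand-ins, not Spark APIs. The unpersist in the finally block is an assumption of this sketch and is not part of the excerpt quoted above.

```scala
// Spark-free model of the handlePersistence pattern.
// FakeData and its string storage levels are hypothetical; they are NOT Spark APIs.
object HandlePersistenceSketch {

  // Hypothetical stand-in for a Dataset/RDD whose storage level can be inspected.
  final class FakeData(val columns: Seq[String], var storageLevel: String = "NONE") {
    def persist(level: String): FakeData = { storageLevel = level; this }
    def unpersist(): FakeData = { storageLevel = "NONE"; this }
    // Projection keeps only the requested columns; the result starts unpersisted.
    def select(cols: String*): FakeData = new FakeData(cols.toVector)
  }

  // Returns the columns and storage level that the "optimizer" saw on `instances`.
  def train(dataset: FakeData): (Seq[String], String) = {
    // true if the caller did not persist the input
    val handlePersistence = dataset.storageLevel == "NONE"
    // Keep only what the algorithm needs, mirroring (label, weight, features)
    val instances = dataset.select("label", "weight", "features")
    if (handlePersistence) instances.persist("MEMORY_AND_DISK")
    try {
      // Stand-in for an iterative optimizer that would scan `instances` repeatedly
      (instances.columns, instances.storageLevel)
    } finally {
      // Assumption of this sketch: clean up only what this method persisted itself
      if (handlePersistence) instances.unpersist()
    }
  }

  def main(args: Array[String]): Unit = {
    val raw = new FakeData(Vector("id", "label", "weight", "features", "notes"))
    println(train(raw)) // input not cached: the projected instances get persisted
    raw.persist("MEMORY_ONLY")
    println(train(raw)) // input cached by the caller: instances are left unpersisted
  }
}
```

The first call prints the projected columns with MEMORY_AND_DISK, and the second prints them with NONE, which is the case where you cached the full input yourself.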

Upvotes: 2
