fierywasabi

Reputation: 15

Why does scikit-learn MLP training take so much time?

I am trying to train an MLP using scikit-learn's MLPClassifier.

import numpy as np
from sklearn.neural_network import MLPClassifier

I am training the MLP for 5400 iterations, but it takes approximately 40 minutes. What am I doing wrong? Here is the MLP I created:

mlp = MLPClassifier(hidden_layer_sizes=(128,), activation='relu', solver='adam', batch_size=500, shuffle=False, verbose=True)

Here is the training part of my code:

for j in range(5400):
    mlp.partial_fit(train_X, y_train, classes=np.unique(y_train))  # 1 step

train_X has shape (27000, 784): 27000 samples, each with 28*28 = 784 pixels.

My processor is an Intel i7-9750H and I have 16 GB of RAM.

Upvotes: 1

Views: 6904

Answers (1)

desertnaut

Reputation: 60370

You are not training for 5400 steps but for 5400 epochs, which is far more work; this is not the way to do it.

Checking the docs, you'll see that for the stochastic solvers MLPClassifier counts one iteration as a full pass over the data, as the description of the max_iter parameter spells out:

max_iter: int, default=200

Maximum number of iterations. The solver iterates until convergence (determined by ‘tol’) or this number of iterations. For stochastic solvers (‘sgd’, ‘adam’), note that this determines the number of epochs (how many times each data point will be used), not the number of gradient steps.

partial_fit does not loop up to max_iter; according to its documentation, it updates the model with a single iteration over the given data, which for adam means one epoch. With 27000 samples and batch_size=500, one epoch is 27000/500 = 54 mini-batch updates. So your 5400 calls run 5400 epochs, i.e. 5400 x 54 = 291,600 gradient steps, rather than the 5400 steps that the "1 step" comment in your loop suggests. Those 5400 steps would correspond to just 100 epochs.

It's not clear why you chose a for loop with partial_fit; you may want to either go for a full fit with max_iter set to the number of epochs you actually want (100 here) and no loop, or stay with your existing loop + partial_fit and shorten it to that many calls. Keep in mind that a plain fit can also stop before max_iter once the loss stops improving by more than tol.
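The no-loop option could look like the following minimal sketch. It assumes random placeholder data standing in for train_X / y_train, and n_epochs is an illustrative value, not a recommendation for the real dataset.

```python
import warnings

import numpy as np
from sklearn.exceptions import ConvergenceWarning
from sklearn.neural_network import MLPClassifier

# Placeholder data (assumption): random pixels and labels, only to make the sketch runnable.
rng = np.random.default_rng(0)
X = rng.random((1000, 784))
y = rng.integers(0, 10, size=1000)

n_epochs = 5  # hypothetical budget; for adam, max_iter counts epochs, not gradient steps

mlp = MLPClassifier(
    hidden_layer_sizes=(128,),
    activation='relu',
    solver='adam',
    batch_size=500,
    shuffle=False,
    max_iter=n_epochs,
    random_state=0,
)

# A single fit call replaces the explicit loop. Hitting max_iter before
# convergence raises a ConvergenceWarning, which is silenced here on purpose.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", ConvergenceWarning)
    mlp.fit(X, y)

print("epochs run:", mlp.n_iter_)
```

After fitting, mlp.n_iter_ reports how many epochs were actually run and mlp.loss_curve_ holds one loss value per epoch, which makes it easy to check whether the budget was too small or too large.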

In what you have shown, I cannot see any reason for the loop approach. It would be justified if your data did not fit in memory and you used the loop to feed a different slice of the data on each iteration, but as it stands it does not make sense.
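For comparison, the situation where a partial_fit loop is justified, with data arriving in slices, might be sketched like this. iter_chunks is a hypothetical helper that slices an in-memory array here; in a real out-of-core setting it would read each slice from disk. The data, chunk size and number of passes are placeholder assumptions.

```python
import numpy as np
from sklearn.neural_network import MLPClassifier


def iter_chunks(X, y, chunk_size):
    """Hypothetical helper: yield consecutive (X, y) slices of at most chunk_size rows."""
    for start in range(0, X.shape[0], chunk_size):
        stop = start + chunk_size
        yield X[start:stop], y[start:stop]


# Placeholder data (assumption), standing in for a dataset too large for memory.
rng = np.random.default_rng(0)
X = rng.random((3000, 784))
y = rng.integers(0, 10, size=3000)

# All labels are declared up front, because a single chunk may not contain every class.
classes = np.arange(10)

mlp = MLPClassifier(hidden_layer_sizes=(128,), solver='adam',
                    batch_size=500, random_state=0)

n_passes = 2  # illustrative number of sweeps over all chunks
for _ in range(n_passes):
    for X_chunk, y_chunk in iter_chunks(X, y, chunk_size=1000):
        # Update the model using only the slice currently in memory.
        mlp.partial_fit(X_chunk, y_chunk, classes=classes)

pred = mlp.predict(X[:5])
```

The point of the loop here is that every call sees different rows; calling partial_fit repeatedly on the same full array gains nothing over a single fit.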

Upvotes: 3
