prcvl

Reputation: 55

How does the predict method work on scikit-learn?

How does the predict() method in scikit-learn work? Does it return random values or is there a calculation under the hood?

Upvotes: 0

Views: 3407

Answers (1)

Steven

Reputation: 2133

predict() must follow fit(). fit() builds a model that tries to find patterns that map the input data to their labels; the input data used at this stage is called the training set. predict() then asks the trained model to apply those patterns to new inputs and map them to labels. The returned predictions are not random: they are the model's best guess, given what it learned during training. Optimizing those patterns so that the predictions are as accurate as possible is the whole art and science of machine learning.
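A minimal sketch of the fit-then-predict workflow, assuming scikit-learn is installed. The bundled iris dataset and LogisticRegression are just example choices (the question doesn't name a model). Calling predict() twice on the same inputs returns the same labels, which shows that for a fitted model like this one the output is computed, not random.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# Example data: flower measurements (inputs) and species (labels).
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=0
)

model = LogisticRegression(max_iter=1000)
model.fit(X_train, y_train)  # learn patterns from the training set

first = model.predict(X_test)   # map unseen inputs to labels
second = model.predict(X_test)  # same inputs, same fitted model...
print(np.array_equal(first, second))  # ...same outputs
print((first == y_test).mean())       # fraction of test labels guessed correctly
```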

Imagine I wanted to make you an expert at identifying cat breeds. At first you might have no idea, so I have to train you. I show you a bunch of pictures and tell you the breed of each cat (the label). After a while, you start to see patterns and you take notes. Eventually you feel you grasp even the finer distinctions and are ready to identify the breed of any cat shown to you. Your new knowledge is called a model. This is what fit() does: it builds a model through training.
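To make the "notes" idea concrete, here is a hypothetical toy classifier in plain Python (not scikit-learn code; scikit-learn's own `sklearn.neighbors.NearestCentroid` works on a similar idea). It follows scikit-learn's fit()/predict() naming convention: fit() summarizes the training examples into a learned attribute (one average feature vector per label, stored in `centroids_`), and predict() gives each new input the label of the nearest average. The feature values and breed names are invented for illustration.

```python
class ToyCentroidClassifier:
    """Hypothetical toy classifier using the fit/predict convention."""

    def fit(self, X, y):
        # "Training": sum the feature vectors seen for each label.
        sums, counts = {}, {}
        for features, label in zip(X, y):
            acc = sums.setdefault(label, [0.0] * len(features))
            for i, value in enumerate(features):
                acc[i] += value
            counts[label] = counts.get(label, 0) + 1
        # The learned "notes": one mean vector (centroid) per label.
        self.centroids_ = {
            label: [v / counts[label] for v in acc] for label, acc in sums.items()
        }
        return self

    def predict(self, X):
        predictions = []
        for features in X:
            # Choose the label whose centroid is closest (squared Euclidean distance).
            best = min(
                self.centroids_,
                key=lambda label: sum(
                    (a - b) ** 2 for a, b in zip(features, self.centroids_[label])
                ),
            )
            predictions.append(best)
        return predictions


# Invented training data: two features per cat, two made-up breeds.
X_train = [[1.0, 1.0], [1.2, 0.8], [5.0, 5.0], [4.8, 5.2]]
y_train = ["breed_A", "breed_A", "breed_B", "breed_B"]

model = ToyCentroidClassifier().fit(X_train, y_train)
print(model.predict([[0.9, 1.1], [5.1, 4.9]]))  # ['breed_A', 'breed_B']
```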

Now I start showing you pictures of cats without telling you their breed (label). Using your new knowledge, notes and patterns (i.e. your model), you can now "predict" the breed of a cat you haven't seen before, provided it is one of the breeds you learned about during training. This is what predict() does: it runs new data through the model to get predictions.
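To see the calculation under the hood, the sketch below reproduces `LogisticRegression.predict()` by hand, assuming scikit-learn is installed. For a linear classifier like this one, predict() scores each class as a weighted sum of the features plus an intercept (the `coef_` and `intercept_` learned in fit()) and returns the class with the highest score. Other estimators compute predictions differently (a decision tree follows its learned splits, for example), but the principle is the same: a deterministic calculation using what fit() learned.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
model = LogisticRegression(max_iter=1000).fit(X_train, y_train)

# One score per class: weighted sum of features + bias, using weights from fit().
scores = X_test @ model.coef_.T + model.intercept_
# Pick the highest-scoring class for each row.
manual = model.classes_[np.argmax(scores, axis=1)]

print(np.array_equal(manual, model.predict(X_test)))
```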

You can't predict cat breeds without training, and you can't have predict() without fit() because fit() builds the model that predict() uses.
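A short sketch of what happens when predict() is called before fit(), assuming scikit-learn is installed: the estimator has no learned attributes yet, so scikit-learn raises `NotFittedError` instead of guessing. The single input row is an arbitrary four-feature example.

```python
from sklearn.exceptions import NotFittedError
from sklearn.linear_model import LogisticRegression

model = LogisticRegression()  # created, but never trained
raised = False
try:
    model.predict([[5.1, 3.5, 1.4, 0.2]])  # no fit() yet
except NotFittedError as err:
    raised = True
    print("predict() before fit() fails:", err)
```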

I strongly recommend following this tutorial that will help you understand how it all fits together.

Upvotes: 3
