Adam Merckx

Reputation: 1214

Where can I find the algorithm behind model.predict?

I would like to implement the code for model.predict (https://keras.io/models/model/) in C++, but I am unable to find the exact logic (equations, formulas) used in prediction.

For C++, I have been working with the source code here: https://github.com/Dobiasd/frugally-deep, but unfortunately I could not find the equations behind its predict function. (Frugally-deep exports the model as a .json file and performs the prediction using the predict function.)

Are there any resources I could refer to for the equations behind model.predict?

Upvotes: 0

Views: 189

Answers (2)

Dr. Snoopy

Reputation: 56367

model.predict implements a forward pass of the model, so there is no single equation behind it; the computation is determined by the model's computation graph.
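As a simplified illustration, assuming a purely sequential model (no branches or shared layers), the forward pass is the composition of the per-layer functions. For a Keras `Dense` layer, for example, the documented operation is `activation(dot(input, kernel) + bias)`:

$$\hat{y} = f_L\big(f_{L-1}(\cdots f_1(x) \cdots)\big), \qquad f_{\mathrm{Dense}}(x) = \sigma(xW + b)$$

Here $x$ is the input as a row vector, $W$ is the kernel of shape (input_dim, units), $b$ is the bias vector and $\sigma$ is the layer's activation. Every other layer type (convolution, pooling, normalization, and so on) contributes its own $f_i$.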

So to reproduce the same behavior, you have to do a forward pass through the layers of the model, where each layer implements its own computation. It is not as simple as "use equation X": there is a large set of formulas to implement, one for each kind of layer.
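A minimal sketch of that idea in C++, assuming a purely sequential model built only from Dense, ReLU and Softmax layers whose weights have already been loaded (for example, exported from Keras). The names `Layer`, `Dense`, `ReLU`, `Softmax` and `predict` are hypothetical and are not part of the Keras or frugally-deep APIs; a real model would need one such class per layer type it uses.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <memory>
#include <numeric>
#include <utility>
#include <vector>

using Vec = std::vector<float>;

// Hypothetical base class: each layer type implements its own computation.
struct Layer {
    virtual ~Layer() = default;
    virtual Vec forward(const Vec& x) const = 0;
};

// Fully connected layer: y = x * W + b, with W stored as [input_dim][units]
// (the kernel layout Keras uses for Dense). No activation applied here.
struct Dense : Layer {
    std::vector<Vec> kernel;  // kernel[i][j]: weight from input i to unit j
    Vec bias;                 // one entry per unit
    Dense(std::vector<Vec> w, Vec b) : kernel(std::move(w)), bias(std::move(b)) {}
    Vec forward(const Vec& x) const override {
        assert(x.size() == kernel.size());
        Vec y = bias;  // start from the bias, then accumulate x_i * W_ij
        for (std::size_t i = 0; i < x.size(); ++i)
            for (std::size_t j = 0; j < y.size(); ++j)
                y[j] += x[i] * kernel[i][j];
        return y;
    }
};

// Element-wise ReLU activation: max(v, 0).
struct ReLU : Layer {
    Vec forward(const Vec& x) const override {
        Vec y(x.size());
        std::transform(x.begin(), x.end(), y.begin(),
                       [](float v) { return std::max(v, 0.0f); });
        return y;
    }
};

// Softmax activation; subtracts the max first for numerical stability.
struct Softmax : Layer {
    Vec forward(const Vec& x) const override {
        assert(!x.empty());
        const float m = *std::max_element(x.begin(), x.end());
        Vec y(x.size());
        std::transform(x.begin(), x.end(), y.begin(),
                       [m](float v) { return std::exp(v - m); });
        const float sum = std::accumulate(y.begin(), y.end(), 0.0f);
        for (float& v : y) v /= sum;
        return y;
    }
};

// Forward pass for a sequential model: feed each layer's output into the next.
Vec predict(const std::vector<std::unique_ptr<Layer>>& layers, Vec x) {
    for (const auto& layer : layers) x = layer->forward(x);
    return x;
}
```

For example, a model made of `Dense` with kernel `{{1, -1}, {2, 0.5}}` and bias `{0, 1}` followed by `ReLU` maps the input `{1, 1}` to `{3, 0.5}`. Models with branches (the functional API) would additionally need the layers evaluated in the topological order of the computation graph rather than a flat list.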

Upvotes: 1

Michael Kolber

Reputation: 1429

Looking at the repo, it appears you're looking for this.

Upvotes: 1
