Reputation: 59
I have a question about using weights in Keras. My data consists of events, and each event has an associated weight. When training my Keras model, I have been passing those weights through the sample_weight argument.
I then noticed that the model.predict method has no argument for passing weights, and now I'm not sure whether the weights I have are the kind that the sample_weight argument of the fit method expects.
My question is: what type of weights is the fit method supposed to receive? Also, is it correct that the predict method doesn't need any weights for the data?
Thanks!
Upvotes: 3
Views: 2097
Reputation: 7103
The parameter sample_weight is used when you do not have the same confidence in all the data in your sample. It lets you tell Keras that you are more confident about some samples than about others. It is used only for training, because it adjusts (weights) the loss function that the optimizer minimizes. Hence, you should not pass it to predict: at prediction time you do not know anything about the output, so you cannot say anything about your confidence in it.
From the Keras docs (https://keras.io/models/model/#train_on_batch):
"sample_weight: Optional Numpy array of weights for the training samples, used for weighting the loss function (during training only)."
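To make "weighting the loss function" concrete, here is a minimal NumPy sketch of a weighted mean squared error. It only illustrates the idea; the function name weighted_mse is hypothetical, and the way it averages the weighted terms is an assumption rather than a copy of how Keras reduces the loss internally.

```python
import numpy as np


def weighted_mse(y_true, y_pred, sample_weight=None):
    """Mean of the per-sample squared errors, each scaled by its weight.

    Hypothetical helper for illustration; not a Keras function.
    """
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    per_sample = (y_true - y_pred) ** 2
    if sample_weight is None:
        # No weights given: every sample counts equally.
        sample_weight = np.ones_like(per_sample)
    weights = np.asarray(sample_weight, dtype=float)
    return float(np.mean(weights * per_sample))


y_true = [1.0, 2.0, 3.0]
y_pred = [1.0, 2.0, 5.0]  # only the third sample is predicted wrongly

unweighted = weighted_mse(y_true, y_pred)
# Trust the third sample less, so its error contributes less to the loss.
downweighted = weighted_mse(y_true, y_pred, sample_weight=[1.0, 1.0, 0.1])

print(unweighted, downweighted)
```

The weights need the true targets to have any effect, which is why they only appear on the training side.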
Upvotes: 2
Reputation: 2174
I think there might be some confusion about what "sample_weight" means in the context of model.fit. When you call model.fit, you are minimizing a loss function. That loss function measures the error between your model's predictions and the true values. Some of the samples in your dataset might be more important to you, so you would weigh the loss function more heavily on those samples. So the sample weights are only used during training, to fit some samples better relative to others. They are an optional argument to model.fit (by default each sample is weighted equally, which is what you should do unless you have a good reason to do otherwise). And (hopefully my explanation was clear enough) they do not make any sense in the context of model.predict.
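Here is a rough sketch of that effect. To keep it self-contained it uses a weighted least-squares line fit in plain NumPy instead of a Keras model, so fit_line and predict below are hypothetical stand-ins for model.fit(x, y, sample_weight=w) and model.predict(x), and the data is made up.

```python
import numpy as np


def fit_line(x, y, sample_weight=None):
    """Fit y ~ slope * x + intercept by weighted least squares."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if sample_weight is None:
        w = np.ones_like(y)  # default: all samples weighted equally
    else:
        w = np.asarray(sample_weight, dtype=float)
    X = np.column_stack([x, np.ones_like(x)])
    XtW = X.T * w  # scale each sample's column by its weight
    slope, intercept = np.linalg.solve(XtW @ X, XtW @ y)
    return slope, intercept


def predict(params, x):
    """Prediction needs only the inputs; there is nothing to weight."""
    slope, intercept = params
    return slope * np.asarray(x, dtype=float) + intercept


x = [0.0, 1.0, 2.0, 3.0]
y = [0.0, 1.0, 2.0, 10.0]  # the last point sits far off the line y = x

equal = fit_line(x, y)
# Mark the last sample as much more important than the others.
weighted = fit_line(x, y, sample_weight=[1.0, 1.0, 1.0, 100.0])

err_equal = abs(10.0 - predict(equal, x)[3])
err_weighted = abs(10.0 - predict(weighted, x)[3])
print(err_equal, err_weighted)
```

With the heavier weight, the fitted line moves toward the last point, so its error there shrinks, while predict takes the same arguments in both cases.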
Upvotes: 1