Reputation: 14001
I have a cost function in TensorFlow.
activation = tf.add(tf.mul(X, W), b)
cost = (tf.pow(Y-y_model, 2)) # use sqr error for cost function
I am trying out this example. How can I change it to an RMSE cost function?
Upvotes: 16
Views: 32153
Reputation: 482
For those who want to implement RMSE as a metric:
rmse = tf.keras.metrics.RootMeanSquaredError()
Example of how to use it:
model.compile(optimizer=optimizer, loss='mean_squared_error',
metrics=[rmse,'mae'])
Upvotes: 6
Reputation: 123
Now we have tf.losses.mean_squared_error
Therefore,
RMSE = tf.sqrt(tf.losses.mean_squared_error(label, prediction))
Upvotes: 8
Reputation: 1905
tf.sqrt(tf.reduce_mean(tf.square(tf.subtract(targets, outputs))))
And slightly simplified (TensorFlow overloads the most important operators):
tf.sqrt(tf.reduce_mean((targets - outputs)**2))
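To see what that expression computes, here is a minimal plain-Python sketch of the same calculation (square the differences, average them, take the square root). It does not use TensorFlow; the helper name `rmse` and the sample numbers are made up for illustration.

```python
import math

def rmse(targets, outputs):
    """Root mean squared error of two equal-length sequences of numbers."""
    if len(targets) != len(outputs):
        raise ValueError("targets and outputs must have the same length")
    # Same three steps as tf.sqrt(tf.reduce_mean((targets - outputs)**2))
    squared_errors = [(t - o) ** 2 for t, o in zip(targets, outputs)]
    mean_squared_error = sum(squared_errors) / len(squared_errors)
    return math.sqrt(mean_squared_error)

# Hypothetical sample data: every prediction is off by exactly 2
targets = [1.0, 2.0, 3.0, 4.0]
outputs = [3.0, 0.0, 5.0, 2.0]
print(rmse(targets, outputs))  # prints 2.0
```

A hand-checkable case like this can be handy for sanity-checking the TensorFlow version against known numbers.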
Upvotes: 49
Reputation: 21917
(1) Are you sure you need this? Minimizing the l2 loss will give you the same result as minimizing the RMSE. (Walk through the math: You don't need to take the square root, because minimizing x^2 also minimizes x for x>=0, and you know that a sum of squares is never negative. Likewise, minimizing x*n minimizes x for a positive constant n, so dividing by the number of elements doesn't change the minimizer either.)
(2) If you need to know the numerical value of the RMSE, then implement it directly from the definition of RMSE:
tf.sqrt(tf.reduce_sum(...)/n)
(You need to know or calculate n - the number of elements in the sum, and set the reduction axis appropriately in the call to reduce_sum).
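Point (1) can be checked numerically. The sketch below is plain Python rather than TensorFlow, and the data, the one-parameter model y = w*x, and the candidate grid are all made up for illustration. It scans candidate slopes and picks the best one under each criterion.

```python
import math

# Hypothetical data, roughly y = 2x
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.1, 3.9, 6.2, 7.8]

def sum_squared_error(w):
    """l2-style loss: sum of squared residuals for the model y = w * x."""
    return sum((y - w * x) ** 2 for x, y in zip(xs, ys))

def rmse(w):
    """Root of the mean of the same squared residuals."""
    return math.sqrt(sum_squared_error(w) / len(xs))

# Candidate slopes 1.00, 1.01, ..., 3.00
candidates = [i / 100 for i in range(100, 301)]

best_by_sse = min(candidates, key=sum_squared_error)
best_by_rmse = min(candidates, key=rmse)
print(best_by_sse, best_by_rmse)  # prints 1.99 1.99
```

Both criteria select the same slope, because the square root and the division by n are increasing functions that preserve the ordering of the candidates.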
Upvotes: 6
Reputation: 222521
The formula for root mean square error is:
$$\mathrm{RMSE} = \sqrt{\frac{1}{n}\sum_{i=1}^{n}\left(Y1_i - Y2_i\right)^2}$$
The way to implement it in TF is tf.sqrt(tf.reduce_mean(tf.squared_difference(Y1, Y2))).
The important thing to remember is that there is no need to minimize the RMSE itself with the optimizer. You get the same result by minimizing just tf.reduce_mean(tf.squared_difference(Y1, Y2))
or even tf.reduce_sum(tf.squared_difference(Y1, Y2)),
and because these have a smaller graph of operations, they will be optimized faster.
But you can use this function if you just want to track the value of RMSE.
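The reason the optimizer does not need the square root can be written out with the chain rule. In this sketch, \theta is an assumed symbol standing for the trainable parameters (W and b in the question), and the derivation only holds where the MSE is strictly positive.

```latex
\mathrm{RMSE}(\theta) = \sqrt{\mathrm{MSE}(\theta)}
\quad\Longrightarrow\quad
\nabla_\theta \mathrm{RMSE}(\theta)
  = \frac{1}{2\sqrt{\mathrm{MSE}(\theta)}}\,\nabla_\theta \mathrm{MSE}(\theta),
\qquad \mathrm{MSE}(\theta) > 0 .
```

The factor in front is a positive scalar, so both gradients point in the same direction and vanish at the same parameter values. Only the step size differs, which in practice would interact with the learning rate.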
Upvotes: 15