Hans

Reputation: 165

R keras Beginner Question predict_classes

I have no prior experience with TensorFlow or Keras. I am trying to follow the tutorial at https://tensorflow.rstudio.com/tutorials/beginners/

library(keras)


mnist <- dataset_mnist()
mnist$train$x <- mnist$train$x/255
mnist$test$x <- mnist$test$x/255

model <- keras_model_sequential() %>% 
  layer_flatten(input_shape = c(28, 28)) %>% 
  layer_dense(units = 128, activation = "relu") %>% 
  layer_dropout(0.2) %>% 
  layer_dense(10, activation = "softmax")

summary(model)

model %>% 
  compile(
    loss = "sparse_categorical_crossentropy",
    optimizer = "adam",
    metrics = "accuracy"
  )

#Note that compile and fit (which we are going to see next) modify the model object in place, unlike most R functions.

model %>% 
  fit(
    x = mnist$train$x, y = mnist$train$y,
    epochs = 5,
    validation_split = 0.3,
    verbose = 2
  )

predictions <- predict(model, mnist$test$x)
head(predictions, 2)

class_predictions <- predict(model, mnist$test$x) %>% k_argmax()
class_predictions

predict_classes is deprecated. k_argmax() was advertised as the alternative in the error message. However, I have no idea how to get the predicted classes (the digits 0-9 in this case) as a vector that I can pass to confusionMatrix, as I would with other R models. Any help would be appreciated.

Upvotes: 1

Views: 503

Answers (2)

J MG

Reputation: 3

predict() %>% k_argmax() returns a Tensor object. To replicate the result of predict_classes(), convert that Tensor object to a vector, like this:

class_predictions <- predict(model, mnist$test$x) %>% k_argmax() %>% as.vector()

Additionally, this page may be useful.

Upvotes: 0

Hans

Reputation: 165

For this problem, the following code works:

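library(caret)  # provides confusionMatrix()

predictions <- predict(model, mnist$test$x)
# which.max() returns a 1-based column index, so subtract 1 to get the digit (0-9)
pred_digits <- apply(predictions, 1, which.max) - 1
confusionMatrix(factor(pred_digits, levels = 0:9), factor(mnist$test$y, levels = 0:9))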
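To illustrate why 1 is subtracted, here is a minimal base R sketch that needs neither keras nor a trained model. The matrix `probs` is a made-up stand-in for the output of predict() (one row per sample, one column per digit), not real model output:

```r
# Hypothetical stand-in for predict() output: 3 samples x 10 class probabilities.
probs <- matrix(0.01, nrow = 3, ncol = 10)
probs[1, 1]  <- 0.91  # sample 1: highest probability in column 1  (digit 0)
probs[2, 8]  <- 0.91  # sample 2: highest probability in column 8  (digit 7)
probs[3, 10] <- 0.91  # sample 3: highest probability in column 10 (digit 9)

# which.max() gives the 1-based column index of the row maximum;
# subtracting 1 maps columns 1..10 onto the digits 0..9.
pred_digits <- apply(probs, 1, which.max) - 1
print(pred_digits)

# max.col() is a vectorised base R alternative that should give the same result here.
pred_digits_alt <- max.col(probs, ties.method = "first") - 1
```

With the real model, `probs` would simply be replaced by `predict(model, mnist$test$x)`.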

However, I still find it odd that predict_classes has been deprecated without a direct replacement. All the tutorials I have looked at so far use it.

Upvotes: 1
