
Reputation: 5838

Extract weights from fitted regr.nnet object in mlr3

This question is related to the solution provided by @Sebastian for a previous question. It showed how to do repeated training for a regr.nnet learner using a custom (=fixed) resampling strategy and cloned learners.


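library("mlr3")
library("mlr3learners")  # provides the regr.nnet learner
library("dplyr")         # mutate()
library("ggplot2")

x <- 1:20
obs <- data.frame(
  x = rep(x, 3),
  f = factor(rep(c("a", "b", "c"), each = 20)),
  y = c(3 * dnorm(x, 10, 3), 5 * dlnorm(x, 2, 0.5), dexp(20 - x, .5)) +
        rnorm(60, sd = 0.02)
)

## one row per combination of x_test and f (300 rows in total)
x_test <- seq(0, 20, length.out = 100)
test <- data.frame(
  x = rep(x_test, 3),
  f = factor(rep(c("a", "b", "c"), each = 100)),
  y = c(3 * dnorm(x_test, 10, 3), 5 * dlnorm(x_test, 2, 0.5),
        dexp(20 - x_test, .5)) + rnorm(300, sd = 0.02)
)

dat <- rbind(obs, test)
task <- as_task_regr(dat, target = "y")

## rows of obs are used for training, rows of test for testing
resampling <- rsmp("custom")
resampling$instantiate(task,
  train_sets = list(seq_len(nrow(obs))),
  test_sets  = list(nrow(obs) + seq_len(nrow(test)))
)
learner <- lrn("regr.nnet", size = 5, trace = FALSE)

## ten copies of the same learner, all evaluated on the same fixed split
learners <- replicate(10, learner$clone())
design <- benchmark_grid(
  tasks = task,
  learners = learners,
  resamplings = resampling
)
bmr <- benchmark(design)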

The next step is to evaluate the benchmark further and to use the model for further evaluation inside and outside of mlr3. Below, I try to evaluate model performance and to plot the predictions for the test data:

## evaluate quality criteria
bmr$aggregate()[learner_id == "regr.nnet"] # ok
bmr$aggregate(msr("time_train")) # works
# bmr$aggregate(msr("regr.rmse"), msr("regr.rsq"), msr("regr.bias")) # not possible

## select the best fit
i_best  <- which.min(bmr$aggregate()$regr.mse)
best    <- bmr$resample_result(i_best)

## do prediction
pr <- as.data.table(best$predictions()[[1]])$response

## visualization
pred_test <- test |> mutate(y = pr)
ggplot(obs, aes(x, y, colour = f)) + geom_point() +
  geom_line(data = pred_test, mapping = aes(x, y, colour = f))

The R6 style of course has its advantages (I was myself involved in the development of the pre-R6 proto package), but it is sometimes not easy to find the best way to access internal data. The mlr3 book is very helpful, but some questions remain:

  1. Is it easily possible to extract additional measures, e.g. msr("regr.rmse") from the benchmark object?
  2. I am not satisfied with my code line pr <- ....., but have not found a better way yet.
  3. Finally, I want to access the internal data structure of the fitted nnet to extract the raw weights for an "offline" use of the neural network outside of R.

Upvotes: 1


Answers (2)


Reputation: 5838

Based on the answer of @Sebastian, I was able to solve the few remaining details, in particular how to combine it with the previous thread. I used the iris example as in @Sebastian's answer, but turned it into a regression task with Petal.Width as the target and the remaining columns (including Species) as inputs. Then I integrated the "brute force" method to train several networks with the same settings on the same split into training and test data.
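A quick way to convince oneself that a custom resampling really fixes the split is to inspect it after instantiation. A minimal sketch, assuming mlr3 is installed; the seed is only there to make this illustration reproducible:

```r
library(mlr3)

set.seed(1)  # only to make this illustrative split reproducible
task <- as_task_regr(iris, target = "Petal.Width")

# half of the rows for training, the other half for testing
id_train <- sample(task$row_ids, task$nrow %/% 2)
id_test  <- setdiff(task$row_ids, id_train)

resampling <- rsmp("custom")
resampling$instantiate(task, train_sets = list(id_train), test_sets = list(id_test))

resampling$iters               # one fixed train/test iteration
head(resampling$train_set(1))  # row ids used for training
```

When such an already instantiated resampling is passed to benchmark_grid(), it should be used as is, so every cloned learner in the design is trained and tested on exactly these rows.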

library(mlr3learners)
#> Loading required package: mlr3

data(iris) # to show that we use this data set

## half of indices for training and other for test subset
id_train  <- sample(1:nrow(iris), nrow(iris) %/% 2)
id_test   <- setdiff(seq_len(nrow(iris)), id_train)

## create task, learner and custom resampling strategy
task <- as_task_regr(iris, target = "Petal.Width")
learner <- lrn("regr.nnet", size = 5, trace = FALSE)
resampling <- rsmp("custom")
resampling$instantiate(task, train_sets = list(id_train), test_sets = list(id_test))

## replicate the learners, create and run a benchmark design
learners <- replicate(10, learner$clone())
design <- benchmark_grid(
  tasks = task,
  learners = learners,
  resamplings = resampling
)

bmr <- benchmark(design, store_models = TRUE)

## summary of results and "best" result according to mse
bmr$aggregate(msrs(c("regr.mse", "regr.rmse", "regr.rsq", "time_train")))

(i_best <- which.min(bmr$aggregate()$regr.mse))
#> [1] 3
best <- bmr$resample_result(i_best)

## prediction for the test set
best$predictions()[[1]]$response

## or for both subsets
best$learners[[1]]$predict(task, row_ids = id_train)$response
best$learners[[1]]$predict(task, row_ids = id_test)$response

## extract raw model and the weights with coef() method from package "nnet"
best$learners[[1]]$model
#> a 5-5-1 network with 36 weights
#> inputs: Petal.Length Sepal.Length Sepal.Width Speciesversicolor Speciesvirginica 
#> output(s): Petal.Width 
#> options were - linear output units

coef(best$learners[[1]]$model)
#>        b->h1       i1->h1       i2->h1       i3->h1       i4->h1       i5->h1 
#>   0.33696598   5.87087353   5.86170711   2.01521854  -0.40507201   0.59342370 
#>        b->h2       i1->h2       i2->h2       i3->h2       i4->h2       i5->h2 
#>  -0.28521282  -0.66194804  -2.53390441  -1.83519953  -0.36281388   0.08545735 
#>        b->h3       i1->h3       i2->h3       i3->h3       i4->h3       i5->h3 
#>  -0.73454761  -0.57450360   0.15459231  -0.92357099  -1.44787867   3.87019807 
#>        b->h4       i1->h4       i2->h4       i3->h4       i4->h4       i5->h4 
#>   4.15247439  -0.31732640  -0.07613404  -0.33243948  -1.59663476  -5.46270442 
#>        b->h5       i1->h5       i2->h5       i3->h5       i4->h5       i5->h5 
#>   1.31088094  -0.51583287   3.79240044   3.43102418  -0.56412793  -0.46786541 
#>         b->o        h1->o        h2->o        h3->o        h4->o        h5->o 
#>   0.30750849  12.55643667   3.18489708  -2.57517235  -2.36851810 -10.44863295

Created on 2023-03-08 with reprex v2.0.2
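For the "offline" use, the named weight vector can be written to a plain CSV file that any other language can read. A minimal sketch, assuming the nnet package; the model is fitted directly with nnet::nnet() (linout = TRUE, matching the "linear output units" in the printout above) as a stand-in for best$learners[[1]]$model, and the file name nnet_weights.csv is hypothetical:

```r
library(nnet)

set.seed(1)
# stand-in for best$learners[[1]]$model (regr.nnet keeps an nnet object there)
fit <- nnet(Petal.Width ~ ., data = iris, size = 5, linout = TRUE, trace = FALSE)

w <- coef(fit)   # named vector: "b->h1", "i1->h1", ..., "h5->o"
fit$n            # numbers of input, hidden and output units
fit$coefnames    # input columns behind i1, i2, ... in this order

# hypothetical output file: one row per connection
out_file <- file.path(tempdir(), "nnet_weights.csv")
write.csv(data.frame(connection = names(w), weight = unname(w)),
          out_file, row.names = FALSE)

# reading it back recovers the weights (up to the 15 digits written)
w_back <- read.csv(out_file)
```

Note that the input order in fit$coefnames can differ from the mlr3 fit above (there, Petal.Length comes first), so it is worth exporting it alongside the weights.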

Upvotes: 0


Reputation: 939

  1. (added by myself) If you create teaching material with mlr3 and would like to share it, you can open an issue or PR in the mlr-org/mlr3website repository. The mlr-org website has a resources tab where we can link material like that :)

  2. Is it easily possible to extract additional measures, e.g. msr("regr.rmse") from the benchmark object?

The $aggregate() method accepts a list of measures (for example, as constructed by msrs()); see the code below.
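The same call works for regression measures, including the RMSE, R-squared and bias asked about in the question. A minimal sketch, using the built-in mtcars task and two learners that ship with mlr3 as stand-ins for the nnet benchmark:

```r
library(mlr3)

task <- tsk("mtcars")
learners <- lrns(c("regr.featureless", "regr.rpart"))
design <- benchmark_grid(
  tasks = task,
  learners = learners,
  resamplings = rsmp("holdout")
)
bmr <- benchmark(design)

# msrs() turns a character vector of ids into a list of Measure objects
measures <- msrs(c("regr.mse", "regr.rmse", "regr.rsq", "regr.bias", "time_train"))
res <- bmr$aggregate(measures)
print(res[, c("learner_id", "regr.rmse", "regr.rsq", "regr.bias")])
```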

  3. I am not satisfied with my code line pr <- ....., but have not found a better way yet.

You can use best$predictions()[[1]]$response directly, without the conversion.
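A minimal sketch of the two accessors on a small stand-in resample (the built-in mtcars task and an rpart learner, chosen only because they ship with mlr3): $predictions() returns one Prediction object per iteration, while $prediction() combines the test-set predictions of all iterations into one object, which is convenient when a resampling has more than one iteration.

```r
library(mlr3)

task <- tsk("mtcars")
rr <- resample(task, lrn("regr.rpart"), rsmp("holdout"))

# list with one Prediction object per resampling iteration
p1 <- rr$predictions()[[1]]
head(p1$response)

# test-set predictions of all iterations combined in one Prediction object
p_all <- rr$prediction()
head(p_all$response)
```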

  4. Finally, I want to access the internal data structure of the fitted nnet to extract the raw weights for an "offline" use of the neural network outside of R.

We do not meddle with the internal structure of the fitted objects; the fitted model of a trained learner (here an nnet object) can be accessed through its $model field (see code below).
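To use the weights outside R, the forward pass has to be recomputed from them. A minimal sketch in base R, assuming a single hidden layer with logistic units, no skip-layer connections and one linear output unit (linout = TRUE); nnet_forward() is a hypothetical helper, and the model is fitted with nnet::nnet() directly so that the example runs without mlr3:

```r
library(nnet)

set.seed(42)
# stand-in for learner$model of a trained regr.nnet learner
fit <- nnet(Petal.Width ~ ., data = iris, size = 5, linout = TRUE, trace = FALSE)

# Recompute predictions from the raw weight vector only (layout as in coef()).
nnet_forward <- function(w, n, X) {
  n_in <- n[1]; n_hid <- n[2]
  # input -> hidden: one column per hidden unit, first row holds the bias
  W1 <- matrix(w[seq_len((n_in + 1) * n_hid)], nrow = n_in + 1)
  # hidden -> output: bias followed by one weight per hidden unit
  w2 <- w[(n_in + 1) * n_hid + seq_len(n_hid + 1)]
  H <- 1 / (1 + exp(-(cbind(1, X) %*% W1)))  # logistic hidden units
  drop(cbind(1, H) %*% w2)                    # linear output unit
}

# inputs in the order of fit$coefnames (factor levels dummy-coded, unscaled)
X <- model.matrix(Petal.Width ~ ., data = iris)[, fit$coefnames]
y_manual <- nnet_forward(coef(fit), fit$n, X)
y_nnet <- drop(predict(fit, iris))
```

The two prediction vectors should agree up to small numerical differences. Porting nnet_forward() to another language then needs only the weight vector, the unit counts in fit$n and the input order in fit$coefnames.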


library(mlr3)
library(mlr3learners)  # provides the classif.nnet learner

learner = lrns(c("classif.rpart", "classif.nnet"))
task = tsk("iris")
resampling = rsmp("holdout")

design = benchmark_grid(
  tasks = task,
  learners = learner,
  resamplings = resampling
)

bmr = benchmark(design, store_models = TRUE)
#> INFO  [07:13:52.802] [mlr3] Running benchmark with 2 resampling iterations
#> INFO  [07:13:52.889] [mlr3] Applying learner 'classif.rpart' on task 'iris' (iter 1/1)
#> INFO  [07:13:52.921] [mlr3] Applying learner 'classif.nnet' on task 'iris' (iter 1/1)
#> # weights:  27
#> initial  value 118.756408 
#> iter  10 value 58.639749
#> iter  20 value 45.676852
#> iter  30 value 21.336083
#> iter  40 value 8.646964
#> iter  50 value 6.041813
#> iter  60 value 5.906140
#> iter  70 value 5.902865
#> iter  80 value 5.898339
#> final  value 5.898161 
#> converged
#> INFO  [07:13:52.946] [mlr3] Finished benchmark

bmr$aggregate(msrs(c("classif.acc", "time_train")))
#>    nr      resample_result task_id    learner_id resampling_id iters
#> 1:  1 <ResampleResult[21]>    iris classif.rpart       holdout     1
#> 2:  2 <ResampleResult[21]>    iris  classif.nnet       holdout     1
#>    classif.acc time_train
#> 1:        0.98      0.007
#> 2:        1.00      0.005

# get the first resample result
rr1 = bmr$resample_result(1)

# Get the model from the first resampling iteration of this ResampleResult
rr1$learners[[1]]$model
#> n= 100 
#> node), split, n, loss, yval, (yprob)
#>       * denotes terminal node
#> 1) root 100 66 versicolor (0.32000000 0.34000000 0.34000000)  
#>   2) Petal.Length< 2.45 32  0 setosa (1.00000000 0.00000000 0.00000000) *
#>   3) Petal.Length>=2.45 68 34 versicolor (0.00000000 0.50000000 0.50000000)  
#>     6) Petal.Width< 1.75 37  4 versicolor (0.00000000 0.89189189 0.10810811) *
#>     7) Petal.Width>=1.75 31  1 virginica (0.00000000 0.03225806 0.96774194) *

Created on 2023-03-08 with reprex v2.0.2

Upvotes: 2
