Reputation: 1274
Good afternoon !
Under R , i developed the following script :
library("rootSolve")
Sphere_1_pdim = function (x) return(sum(x^2))
RMSprop<-function(objFun ,iter = 50000, alpha = 0.00001,lambda=0.2 ,start_init ){
init = start_init
gradient_1 <- function(init , objFun ) {
p=ceiling(runif(1,min=0,max =length(init) ))
result <- gradient(objFun, init,pert = 1e-8) # vector of gradient / partial dérivatives
# print("gradient")
# print(replace(result, sample(length(init),2), 0))
# return(replace(result,p, 0))
return(result)
}
x <- init
# create a vector to contain all xs for all steps
x.All = matrix(c(paste("X",1:length(start_init)),"objfun"),ncol=length(start_init)+1)
tmp<-rep(0,length(init))
V=(1-lambda)*(gradient_1(x,objFun))^2
# gradient descent method to find the minimum
for(i in seq_len(iter)){
V=lambda*V+(1-lambda)*(gradient_1(x,objFun))^2
tmp = x - alpha*gradient_1(x,objFun)/sqrt(V)
if ( !is.nan(suppressWarnings(objFun(tmp))) ) {
x <- tmp
}
x.All = rbind(x.All,c(x,objFun(x)))
# print(c(i,x,objFun(x)))
}
# print result and plot all xs for every iteration
print(paste("The minimum of f(x) is ", objFun(x), " at position x = ", x, sep = ""))
# Here i should plot the x.All dataframe columns :
plot(x.All, type = "l")
x.All=as.data.frame(x.All[-1,])
print("x.All")
print(x.All)
dput(x.All)
return(list(x,x.All))
}
Example of output :
p=RMSprop(objFun=Sphere_1_pdim ,iter = 20, alpha = 0.01,lambda=0.2 ,rep(0.5,4) )
print(p[[2]])
[1] "x.All"
V1 V2 V3 V4
1 0.489793792738403 0.489793792738403 0.489793792738403 0.489793792738403
2 0.479794218064014 0.479794218064014 0.479794218064014 0.479794218064014
3 0.469836158529361 0.469836158529361 0.469836158529361 0.469836158529361
4 0.459887403385885 0.459887403385885 0.459887403385885 0.459887403385885
5 0.449941475647717 0.449941475647717 0.449941475647717 0.449941475647717
6 0.439997096792552 0.439997096792552 0.439997096792552 0.439997096792552
7 0.430054049875739 0.430054049875739 0.430054049875739 0.430054049875739
8 0.420112337792478 0.420112337792478 0.420112337792478 0.420112337792478
9 0.41017201167119 0.41017201167119 0.41017201167119 0.41017201167119
10 0.400233136231218 0.400233136231218 0.400233136231218 0.400233136231218
11 0.390295782982069 0.390295782982069 0.390295782982069 0.390295782982069
12 0.3803600292637 0.3803600292637 0.3803600292637 0.3803600292637
13 0.370425958630921 0.370425958630921 0.370425958630921 0.370425958630921
14 0.360493661256826 0.360493661256826 0.360493661256826 0.360493661256826
15 0.350563234797495 0.350563234797495 0.350563234797495 0.350563234797495
16 0.340634785238398 0.340634785238398 0.340634785238398 0.340634785238398
17 0.330708427713843 0.330708427713843 0.330708427713843 0.330708427713843
18 0.320784287730144 0.320784287730144 0.320784287730144 0.320784287730144
19 0.310862502424419 0.310862502424419 0.310862502424419 0.310862502424419
20 0.300943221952355 0.300943221952355 0.300943221952355 0.300943221952355
V5
1 0.95959183762028
2 0.920809966750635
3 0.882984063446507
4 0.845985695172046
5 0.809789326032179
6 0.774389780743497
7 0.739785943258098
8 0.705977505461846
9 0.672964316633562
10 0.640746253349907
11 0.609323192854345
12 0.578695007445931
13 0.548861563310546
14 0.519822719225405
15 0.491578326366735
16 0.464128227657637
17 0.43747225664385
18 0.411610237018144
19 0.386541981654287
20 0.362267291356257
Output of dput() :
structure(list(V1 = c("0.489793792738403", "0.479794218064014",
"0.469836158529361", "0.459887403385885", "0.449941475647717",
"0.439997096792552", "0.430054049875739", "0.420112337792478",
"0.41017201167119", "0.400233136231218", "0.390295782982069",
"0.3803600292637", "0.370425958630921", "0.360493661256826",
"0.350563234797495", "0.340634785238398", "0.330708427713843",
"0.320784287730144", "0.310862502424419", "0.300943221952355"
), V2 = c("0.489793792738403", "0.479794218064014", "0.469836158529361",
"0.459887403385885", "0.449941475647717", "0.439997096792552",
"0.430054049875739", "0.420112337792478", "0.41017201167119",
"0.400233136231218", "0.390295782982069", "0.3803600292637",
"0.370425958630921", "0.360493661256826", "0.350563234797495",
"0.340634785238398", "0.330708427713843", "0.320784287730144",
"0.310862502424419", "0.300943221952355"), V3 = c("0.489793792738403",
"0.479794218064014", "0.469836158529361", "0.459887403385885",
"0.449941475647717", "0.439997096792552", "0.430054049875739",
"0.420112337792478", "0.41017201167119", "0.400233136231218",
"0.390295782982069", "0.3803600292637", "0.370425958630921",
"0.360493661256826", "0.350563234797495", "0.340634785238398",
"0.330708427713843", "0.320784287730144", "0.310862502424419",
"0.300943221952355"), V4 = c("0.489793792738403", "0.479794218064014",
"0.469836158529361", "0.459887403385885", "0.449941475647717",
"0.439997096792552", "0.430054049875739", "0.420112337792478",
"0.41017201167119", "0.400233136231218", "0.390295782982069",
"0.3803600292637", "0.370425958630921", "0.360493661256826",
"0.350563234797495", "0.340634785238398", "0.330708427713843",
"0.320784287730144", "0.310862502424419", "0.300943221952355"
), V5 = c("0.95959183762028", "0.920809966750635", "0.882984063446507",
"0.845985695172046", "0.809789326032179", "0.774389780743497",
"0.739785943258098", "0.705977505461846", "0.672964316633562",
"0.640746253349907", "0.609323192854345", "0.578695007445931",
"0.548861563310546", "0.519822719225405", "0.491578326366735",
"0.464128227657637", "0.43747225664385", "0.411610237018144",
"0.386541981654287", "0.362267291356257")), class = "data.frame", row.names = c(NA,
-20L))
Here i used a number of 16 iterations (iter=16
) . Each row of the dataframe x.All
represents one iteration of the 16.
For each row , the row index-1 represents the iteration index. The last entrie is the value of the objective function.
I'm searching a way ( I prefer ggplot2
) in which i can plot the evolutions of columns ( from column X1
to the objfun
column ).
In this case , we will get 5 curves (in the same plot ) where the last curve is the evolution of the objective function Sphere_1_pdim
.
I had tried to add ( before return(x)
in RMSprop
function ) :
library(ggplot2)
library(reshape2)
x.All=as.data.frame(cbind(index=1:nrow(x.All[-1,]),x.All[-1,]))
colnames(x.All)=c("index",paste("X",1:length(start_init)),"objfun")
print(x.All)
df <- melt(x.All , id.vars = 'index', variable.name = 'series')
print(df)
#create line plot for each column in data frame
ggplot(df, aes(index, value)) + geom_line(aes(colour = series))
However , this had failed to plot the wanted curves .
I hope my question is clear for you !
Thank you a lot for your help !
Upvotes: 0
Views: 421
Reputation: 2485
I hope I have understood correctly.
This is your dataframe:
df <- structure(list(X1 = c(0.489793792738403, 0.479794218064014, 0.469836158529361,
0.459887403385885, 0.449941475647717, 0.439997096792552, 0.430054049875739,
0.420112337792478, 0.41017201167119, 0.400233136231218, 0.390295782982069,
0.3803600292637, 0.370425958630921, 0.360493661256826, 0.350563234797495,
0.340634785238398), X2 = c(0.489793792738403, 0.479794218064014,
0.469836158529361, 0.459887403385885, 0.449941475647717, 0.439997096792552,
0.430054049875739, 0.420112337792478, 0.41017201167119, 0.400233136231218,
0.390295782982069, 0.3803600292637, 0.370425958630921, 0.360493661256826,
0.350563234797495, 0.340634785238398), X3 = c(0.489793792738403,
0.479794218064014, 0.469836158529361, 0.459887403385885, 0.449941475647717,
0.439997096792552, 0.430054049875739, 0.420112337792478, 0.41017201167119,
0.400233136231218, 0.390295782982069, 0.3803600292637, 0.370425958630921,
0.360493661256826, 0.350563234797495, 0.340634785238398), X4 = c(0.489793792738403,
0.479794218064014, 0.469836158529361, 0.459887403385885, 0.449941475647717,
0.439997096792552, 0.430054049875739, 0.420112337792478, 0.41017201167119,
0.400233136231218, 0.390295782982069, 0.3803600292637, 0.370425958630921,
0.360493661256826, 0.350563234797495, 0.340634785238398), objfun = c(0.95959183762028,
0.920809966750635, 0.882984063446507, 0.845985695172046, 0.809789326032179,
0.774389780743497, 0.739785943258098, 0.705977505461846, 0.672964316633562,
0.640746253349907, 0.609323192854345, 0.578695007445931, 0.548861563310546,
0.519822719225405, 0.491578326366735, 0.464128227657637)), row.names = c(NA,
-16L), class = c("tbl_df", "tbl", "data.frame"))
First of all you need to reshape df
:
library(dplyr)
library(tidyr)
df2 <- df %>%
mutate(x = 1:16) %>%
pivot_longer(X1:objfun, "variable", "value")
Now you can plot the data:
library(ggplot2)
ggplot(df2) +
geom_line(aes(x, value, color = variable))
If your data are character
use:
x.All <- as.data.frame(sapply(x.All, as.numeric))
Upvotes: 1
Reputation: 886938
We can use matplot
from base R
matplot(as.matrix(df), type = 'l')
df <- structure(list(X1 = c(0.489793792738403, 0.479794218064014, 0.469836158529361,
0.459887403385885, 0.449941475647717, 0.439997096792552, 0.430054049875739,
0.420112337792478, 0.41017201167119, 0.400233136231218, 0.390295782982069,
0.3803600292637, 0.370425958630921, 0.360493661256826, 0.350563234797495,
0.340634785238398), X2 = c(0.489793792738403, 0.479794218064014,
0.469836158529361, 0.459887403385885, 0.449941475647717, 0.439997096792552,
0.430054049875739, 0.420112337792478, 0.41017201167119, 0.400233136231218,
0.390295782982069, 0.3803600292637, 0.370425958630921, 0.360493661256826,
0.350563234797495, 0.340634785238398), X3 = c(0.489793792738403,
0.479794218064014, 0.469836158529361, 0.459887403385885, 0.449941475647717,
0.439997096792552, 0.430054049875739, 0.420112337792478, 0.41017201167119,
0.400233136231218, 0.390295782982069, 0.3803600292637, 0.370425958630921,
0.360493661256826, 0.350563234797495, 0.340634785238398), X4 = c(0.489793792738403,
0.479794218064014, 0.469836158529361, 0.459887403385885, 0.449941475647717,
0.439997096792552, 0.430054049875739, 0.420112337792478, 0.41017201167119,
0.400233136231218, 0.390295782982069, 0.3803600292637, 0.370425958630921,
0.360493661256826, 0.350563234797495, 0.340634785238398), objfun = c(0.95959183762028,
0.920809966750635, 0.882984063446507, 0.845985695172046, 0.809789326032179,
0.774389780743497, 0.739785943258098, 0.705977505461846, 0.672964316633562,
0.640746253349907, 0.609323192854345, 0.578695007445931, 0.548861563310546,
0.519822719225405, 0.491578326366735, 0.464128227657637)), row.names = c(NA,
-16L), class = c("tbl_df", "tbl", "data.frame"))
Upvotes: 1
Reputation: 23727
For this type of exploratory plotting, maybe consider pairs
pairs(mydat, panel = lines)
Thanks to user Leonardo for the data +1
Upvotes: 1