# Default knitr chunk options for this document: `collapse` and `comment`
# are set once here so every chunk below picks them up.
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>"
)
# Test signal: sin(2 * pi * x) evaluated on `n` evenly spaced points
# x in [-0.5, 0.5] (one full period).
#
# n: number of sample points.
# Returns a numeric vector of length `n`.
sin_func <- function(n) {
  x <- seq(-0.5, 0.5, length.out = n)
  sin(2 * pi * x)
}

library(smashrgen)
library(smashr)
library(torch)

# Run the speed/accuracy simulation comparing smash.gaus, sgp and sgp_torch.
#
# For every sample size in `n_list`, `n_rep` data sets are simulated and each
# one is fitted by smash.gaus once and by sgp / sgp_torch once per value in
# `m_list`.
#
# n_list:      sample sizes to simulate.
# signal_func: function taking `n` and returning the mean signal of length n.
# snr:         the signal is rescaled so that its sample variance equals `snr`.
# sigma2:      variance of the Gaussian noise added to the signal; also passed
#              to the fitting functions as the (fixed) noise variance.
# n_rep:       number of replicates per sample size.
# m_list:      values passed as `m` to sgp() and sgp_torch().
#              NOTE(review): presumably the number of inducing points - confirm
#              against the smashrgen documentation.
# seed:        RNG seed, set once at the start of the run.
#
# Returns a nested list: res[["n_<n>"]][["rep_<rep>"]] is a list with
#   res_smash     - list(posterior = list(mean = <smash fit>), run_time = ...)
#   res_sgp       - list of sgp() fits named "m<m>"
#   res_sgp_torch - list of sgp_torch() fits named "m<m>"
#   data          - list(mu = true signal, y = noisy observations)
simu_func <- function(n_list, signal_func, snr = 3, sigma2 = 1, n_rep = 10,
                      m_list = c(10, 30, 50, 100, 150), seed = 12345) {
  set.seed(seed)
  res <- list()
  for (n in n_list) {
    res_n <- list()
    # seq_len() so that n_rep = 0 runs no replicates (1:0 would run two).
    for (rep in seq_len(n_rep)) {
      print(paste0("running n=", n, ", rep=", rep))

      # generate data
      mu <- signal_func(n)
      mu <- mu * sqrt(snr / var(mu))
      y <- mu + rnorm(n, 0, sqrt(sigma2))

      # fit smashr
      t0_smash <- Sys.time()
      fit_smash <- smash.gaus(y, sqrt(sigma2))
      # Pin the units: `Sys.time() - t0` picks secs/mins/hours automatically,
      # which would make run times incomparable once a fit exceeds a minute.
      t_smash <- difftime(Sys.time(), t0_smash, units = "secs")
      res_smash <- list(posterior = list(mean = fit_smash), run_time = t_smash)

      # fit gp
      res_sgp <- list()
      res_sgp_torch <- list()
      for (m in m_list) {
        fit_sgp <- sgp(y, m = m, sigma2 = sigma2,
                       fix_sigma2 = TRUE, fix_X_ind = TRUE)
        fit_sgp_torch <- sgp_torch(y, m = m, sigma2 = sigma2,
                                   fix_sigma2 = TRUE, fix_X_ind = TRUE)
        key <- paste0("m", m)
        res_sgp[[key]] <- fit_sgp
        res_sgp_torch[[key]] <- fit_sgp_torch
      }

      res_n[[paste0("rep_", rep)]] <- list(
        res_smash = res_smash,
        res_sgp = res_sgp,
        res_sgp_torch = res_sgp_torch,
        data = list(mu = mu, y = y)
      )
    }
    res[[paste0("n_", n)]] <- res_n
  }
  res
}

# test = simu_func(c(200,210),sin_func,n_rep=2,m_list=c(10,20))
# output = simu_func(c(300,500,1000,3000,5000,10000),sin_func)
# Load the saved simulation results for the sine signal.
output <- readRDS("~/myRpackages/smashrgen/output/test_speed_gaussian_sin.rds")

# Average run time per method for each sample size.
# Each replicate contributes one row: smashr first, then every sgp fit,
# then every sgp_torch fit; rows are averaged column-wise.
times <- lapply(output, function(reps_for_n) {
  per_rep <- lapply(reps_for_n, function(one_rep) {
    sgp_times <- unlist(lapply(one_rep$res_sgp, function(fit) {
      fit$run_time
    }))
    torch_times <- unlist(lapply(one_rep$res_sgp_torch, function(fit) {
      fit$run_time
    }))
    c(one_rep$res_smash$run_time, sgp_times, torch_times)
  })
  colMeans(do.call(rbind, per_rep))
})

# One row per sample size, one column per method.
times <- do.call(rbind, times)
ns <- c(300, 500, 1000, 3000, 5000, 10000)
rownames(times) <- ns
colnames(times) <- c(
  "smashr",
  "sgp_m10", "sgp_m30", "sgp_m50", "sgp_m100", "sgp_m150",
  "sgp_torch_m10", "sgp_torch_m30", "sgp_torch_m50",
  "sgp_torch_m100", "sgp_torch_m150"
)
round(times, 3)
# Mean squared error between two numeric vectors.
mse <- function(x, y) {
  sq_diff <- (x - y)^2
  mean(sq_diff)
}

# Average MSE of each method's posterior mean against the true signal,
# per sample size. Column order: smashr, every sgp fit, every sgp_torch fit.
errors <- lapply(output, function(reps_for_n) {
  per_rep <- lapply(reps_for_n, function(one_rep) {
    truth <- one_rep$data$mu
    smash_err <- mse(one_rep$res_smash$posterior$mean, truth)
    sgp_err <- unlist(lapply(one_rep$res_sgp, function(fit) {
      mse(fit$posterior$mean, truth)
    }))
    torch_err <- unlist(lapply(one_rep$res_sgp_torch, function(fit) {
      mse(fit$posterior$mean, truth)
    }))
    c(smash_err, sgp_err, torch_err)
  })
  colMeans(do.call(rbind, per_rep))
})

# One row per sample size, one column per method.
errors <- do.call(rbind, errors)
ns <- c(300, 500, 1000, 3000, 5000, 10000)
rownames(errors) <- ns
colnames(errors) <- c(
  "smashr",
  "sgp_m10", "sgp_m30", "sgp_m50", "sgp_m100", "sgp_m150",
  "sgp_torch_m10", "sgp_torch_m30", "sgp_torch_m50",
  "sgp_torch_m100", "sgp_torch_m150"
)
# Report the square root of the averaged MSE.
round(sqrt(errors), 3)
# Visual check on the first replicate with n = 10000: noisy data (points),
# true signal, and the fits from sgp (m = 10, m = 150) and smashr.
# The full expression is kept inside plot() because its deparsed text is
# used as the default y-axis label.
plot(
  output$n_10000$rep_1$data$y,
  col = "grey80",
  pch = 20
)
lines(output$n_10000$rep_1$data$mu, col = "grey70")                    # truth
lines(output$n_10000$rep_1$res_sgp$m10$posterior$mean, col = 1)        # sgp m10
lines(output$n_10000$rep_1$res_sgp$m150$posterior$mean, col = 2)       # sgp m150
lines(output$n_10000$rep_1$res_smash$posterior$mean, col = 4)          # smashr
# Load the saved simulation results for the step signal.
output <- readRDS("~/myRpackages/smashrgen/output/test_speed_gaussian_step.rds")

# Average run time per method for each sample size.
# Each replicate contributes one row: smashr first, then every sgp fit,
# then every sgp_torch fit; rows are averaged column-wise.
times <- lapply(output, function(reps_for_n) {
  per_rep <- lapply(reps_for_n, function(one_rep) {
    sgp_times <- unlist(lapply(one_rep$res_sgp, function(fit) {
      fit$run_time
    }))
    torch_times <- unlist(lapply(one_rep$res_sgp_torch, function(fit) {
      fit$run_time
    }))
    c(one_rep$res_smash$run_time, sgp_times, torch_times)
  })
  colMeans(do.call(rbind, per_rep))
})

# One row per sample size, one column per method.
times <- do.call(rbind, times)
ns <- c(300, 500, 1000, 3000, 5000, 10000)
rownames(times) <- ns
colnames(times) <- c(
  "smashr",
  "sgp_m10", "sgp_m30", "sgp_m50", "sgp_m100", "sgp_m150",
  "sgp_torch_m10", "sgp_torch_m30", "sgp_torch_m50",
  "sgp_torch_m100", "sgp_torch_m150"
)
round(times, 3)
# Mean squared error between two numeric vectors.
mse <- function(x, y) {
  sq_diff <- (x - y)^2
  mean(sq_diff)
}

# Average MSE of each method's posterior mean against the true signal,
# per sample size. Column order: smashr, every sgp fit, every sgp_torch fit.
errors <- lapply(output, function(reps_for_n) {
  per_rep <- lapply(reps_for_n, function(one_rep) {
    truth <- one_rep$data$mu
    smash_err <- mse(one_rep$res_smash$posterior$mean, truth)
    sgp_err <- unlist(lapply(one_rep$res_sgp, function(fit) {
      mse(fit$posterior$mean, truth)
    }))
    torch_err <- unlist(lapply(one_rep$res_sgp_torch, function(fit) {
      mse(fit$posterior$mean, truth)
    }))
    c(smash_err, sgp_err, torch_err)
  })
  colMeans(do.call(rbind, per_rep))
})

# One row per sample size, one column per method.
errors <- do.call(rbind, errors)
ns <- c(300, 500, 1000, 3000, 5000, 10000)
rownames(errors) <- ns
colnames(errors) <- c(
  "smashr",
  "sgp_m10", "sgp_m30", "sgp_m50", "sgp_m100", "sgp_m150",
  "sgp_torch_m10", "sgp_torch_m30", "sgp_torch_m50",
  "sgp_torch_m100", "sgp_torch_m150"
)
# Report the square root of the averaged MSE.
round(sqrt(errors), 3)
# Visual check on the first replicate with n = 10000: noisy data (points),
# true signal, and the fits from sgp (m = 10), smashr and sgp (m = 150),
# drawn in that order.
# The full expression is kept inside plot() because its deparsed text is
# used as the default y-axis label.
plot(
  output$n_10000$rep_1$data$y,
  col = "grey80",
  pch = 20
)
lines(output$n_10000$rep_1$data$mu, col = "grey70")                    # truth
lines(output$n_10000$rep_1$res_sgp$m10$posterior$mean, col = 1)        # sgp m10
lines(output$n_10000$rep_1$res_smash$posterior$mean, col = 4)          # smashr
lines(output$n_10000$rep_1$res_sgp$m150$posterior$mean, col = 2)       # sgp m150
# Record the R version and the attached/loaded packages for reproducibility.
sessionInfo()
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.