# Default knitr chunk options used when this document is rendered.
knitr::opts_chunk$set(collapse = TRUE, comment = "#>")
# Sine test signal: sin(2 * pi * x) evaluated on `n` evenly spaced points
# covering [-0.5, 0.5].
sin_func <- function(n) {
  grid <- seq(from = -0.5, to = 0.5, length.out = n)
  sin(2 * pi * grid)
}
library(smashrgen)
library(smashr)
library(torch)
#' Simulate Gaussian data and fit smash plus sparse-GP smoothers
#'
#' For every sample size in `n_list` and every replicate, a mean curve is
#' generated with `signal_func(n)`, rescaled so that its sample variance equals
#' `snr`, and observed with N(0, sigma2) noise. The data are then fitted with
#' `smash.gaus()`, and with `sgp()` and `sgp_torch()` once per value in
#' `m_list`.
#'
#' @param n_list Vector of sample sizes to simulate.
#' @param signal_func Function taking `n` and returning the mean vector.
#' @param snr Target sample variance of the rescaled mean curve.
#' @param sigma2 Noise variance; also passed to the fitting routines.
#' @param n_rep Number of replicates per sample size (0 yields empty results).
#' @param m_list Values passed as `m` to `sgp()` / `sgp_torch()`.
#' @param seed Seed passed to `set.seed()` before any data are generated.
#' @return A nested list indexed as `res$n_<n>$rep_<r>`, each entry holding
#'   `res_smash` (`posterior$mean`, `run_time`), `res_sgp` and `res_sgp_torch`
#'   (one fit per `m<m>`), and `data` (`mu`, `y`).
simu_func <- function(n_list, signal_func, snr = 3, sigma2 = 1, n_rep = 10,
                      m_list = c(10, 30, 50, 100, 150), seed = 12345) {
  set.seed(seed)
  res <- list()
  for (n in n_list) {
    res_n <- list()
    # seq_len() keeps the loop empty when n_rep is 0 (1:0 would run twice).
    for (rep in seq_len(n_rep)) {
      print(paste0('running n=', n, ', rep=', rep))
      # generate data
      mu <- signal_func(n)
      mu <- mu * sqrt(snr / var(mu))
      y <- mu + rnorm(n, 0, sqrt(sigma2))
      # fit smashr
      t0_smash <- Sys.time()
      fit_smash <- smash.gaus(y, sqrt(sigma2))
      # Fix the unit to seconds: the `-` operator picks units automatically,
      # so a slow fit would silently be reported in minutes.
      t_smash <- difftime(Sys.time(), t0_smash, units = "secs")
      res_smash <- list(posterior = list(mean = fit_smash), run_time = t_smash)
      # fit gp
      res_sgp <- list()
      res_sgp_torch <- list()
      for (m in m_list) {
        m_key <- paste0('m', m)
        fit_sgp <- sgp(y, m = m, sigma2 = sigma2,
                       fix_sigma2 = TRUE, fix_X_ind = TRUE)
        fit_sgp_torch <- sgp_torch(y, m = m, sigma2 = sigma2,
                                   fix_sigma2 = TRUE, fix_X_ind = TRUE)
        res_sgp[[m_key]] <- fit_sgp
        res_sgp_torch[[m_key]] <- fit_sgp_torch
      }
      res_n[[paste0('rep_', rep)]] <- list(
        res_smash = res_smash,
        res_sgp = res_sgp,
        res_sgp_torch = res_sgp_torch,
        data = list(mu = mu, y = y)
      )
    }
    res[[paste0('n_', n)]] <- res_n
  }
  res
}

# test = simu_func(c(200,210),sin_func,n_rep=2,m_list=c(10,20))
# output = simu_func(c(300,500,1000,3000,5000,10000),sin_func)

# Sin function ----

# Load the saved sin-signal results and tabulate the run time of each method
# (columns), averaged over replicates, for each sample size (rows).
output <- readRDS('~/myRpackages/smashrgen/output/test_speed_gaussian_sin.rds')
times <- lapply(output, function(res) {
  per_rep <- lapply(res, function(rep) {
    sgp_time <- unlist(lapply(rep$res_sgp, function(z) z$run_time))
    torch_time <- unlist(lapply(rep$res_sgp_torch, function(z) z$run_time))
    c(rep$res_smash$run_time, sgp_time, torch_time)
  })
  # one row per replicate -> average over replicates
  colMeans(do.call(rbind, per_rep))
})
times <- do.call(rbind, times)
ns <- c(300, 500, 1000, 3000, 5000, 10000)
dimnames(times) <- list(
  ns,
  c('smashr',
    'sgp_m10', 'sgp_m30', 'sgp_m50', 'sgp_m100', 'sgp_m150',
    'sgp_torch_m10', 'sgp_torch_m30', 'sgp_torch_m50', 'sgp_torch_m100',
    'sgp_torch_m150')
)
round(times, 3)
# Mean of the squared element-wise differences between `x` and `y`.
mse <- function(x, y) {
  sq_diff <- (x - y)^2
  mean(sq_diff)
}
# Mean squared error of each method against the true mean `data$mu`, averaged
# over replicates; the printed table shows its square root (RMSE).
errors <- lapply(output, function(res) {
  per_rep <- lapply(res, function(rep) {
    truth <- rep$data$mu
    mse_vs_truth <- function(z) mse(z$posterior$mean, truth)
    c(
      mse(rep$res_smash$posterior$mean, truth),
      unlist(lapply(rep$res_sgp, mse_vs_truth)),
      unlist(lapply(rep$res_sgp_torch, mse_vs_truth))
    )
  })
  # one row per replicate -> average over replicates
  colMeans(do.call(rbind, per_rep))
})
errors <- do.call(rbind, errors)
ns <- c(300, 500, 1000, 3000, 5000, 10000)
dimnames(errors) <- list(
  ns,
  c('smashr',
    'sgp_m10', 'sgp_m30', 'sgp_m50', 'sgp_m100', 'sgp_m150',
    'sgp_torch_m10', 'sgp_torch_m30', 'sgp_torch_m50', 'sgp_torch_m100',
    'sgp_torch_m150')
)
round(sqrt(errors), 3)
# First replicate at n = 10000: observations as points, then the true mean,
# the sgp fits with m = 10 (col 1) and m = 150 (col 2), and the smash fit
# (col 4) drawn on top, in that order.
local({
  # The first argument is spelled out in full because the default y-axis
  # label is taken from this expression.
  plot(output$n_10000$rep_1$data$y, col = 'grey80', pch = 20)
  first_rep <- output$n_10000$rep_1
  lines(first_rep$data$mu, col = 'grey70')
  lines(first_rep$res_sgp$m10$posterior$mean, col = 1)
  lines(first_rep$res_sgp$m150$posterior$mean, col = 2)
  lines(first_rep$res_smash$posterior$mean, col = 4)
})

# Step function ----

# Load the saved step-signal results and tabulate the run time of each method
# (columns), averaged over replicates, for each sample size (rows).
output <- readRDS('~/myRpackages/smashrgen/output/test_speed_gaussian_step.rds')
times <- lapply(output, function(res) {
  rep_rows <- lapply(res, function(rep) {
    t_sgp <- unlist(lapply(rep$res_sgp, function(z) z$run_time))
    t_torch <- unlist(lapply(rep$res_sgp_torch, function(z) z$run_time))
    c(rep$res_smash$run_time, t_sgp, t_torch)
  })
  # stack replicates as rows, then average each column
  colMeans(do.call(rbind, rep_rows))
})
times <- do.call(rbind, times)
ns <- c(300, 500, 1000, 3000, 5000, 10000)
dimnames(times) <- list(
  ns,
  c('smashr',
    'sgp_m10', 'sgp_m30', 'sgp_m50', 'sgp_m100', 'sgp_m150',
    'sgp_torch_m10', 'sgp_torch_m30', 'sgp_torch_m50', 'sgp_torch_m100',
    'sgp_torch_m150')
)
round(times, 3)
# Mean of the squared element-wise differences between `x` and `y`.
mse <- function(x, y) {
  squared_error <- (x - y)^2
  mean(squared_error)
}
# Mean squared error of each method against the true mean `data$mu`, averaged
# over replicates; the printed table shows its square root (RMSE).
errors <- lapply(output, function(res) {
  rep_rows <- lapply(res, function(rep) {
    target <- rep$data$mu
    err_of_fit <- function(z) mse(z$posterior$mean, target)
    c(
      mse(rep$res_smash$posterior$mean, target),
      unlist(lapply(rep$res_sgp, err_of_fit)),
      unlist(lapply(rep$res_sgp_torch, err_of_fit))
    )
  })
  # stack replicates as rows, then average each column
  colMeans(do.call(rbind, rep_rows))
})
errors <- do.call(rbind, errors)
ns <- c(300, 500, 1000, 3000, 5000, 10000)
dimnames(errors) <- list(
  ns,
  c('smashr',
    'sgp_m10', 'sgp_m30', 'sgp_m50', 'sgp_m100', 'sgp_m150',
    'sgp_torch_m10', 'sgp_torch_m30', 'sgp_torch_m50', 'sgp_torch_m100',
    'sgp_torch_m150')
)
round(sqrt(errors), 3)
# First replicate at n = 10000: observations as points, then the true mean,
# the sgp fit with m = 10 (col 1), the smash fit (col 4), and the sgp fit
# with m = 150 (col 2) drawn on top, in that order.
local({
  # The first argument is spelled out in full because the default y-axis
  # label is taken from this expression.
  plot(output$n_10000$rep_1$data$y, col = 'grey80', pch = 20)
  first_rep <- output$n_10000$rep_1
  lines(first_rep$data$mu, col = 'grey70')
  lines(first_rep$res_sgp$m10$posterior$mean, col = 1)
  lines(first_rep$res_smash$posterior$mean, col = 4)
  lines(first_rep$res_sgp$m150$posterior$mean, col = 2)
})
# Record the R version and loaded packages used for this run.
sessionInfo()


# DongyueXie/smashrgen documentation built on Jan. 14, 2024, 5:30 a.m.