benchmarks/benchmark_gaussian.R

remove(list=ls())
pkgload::load_all()
library(magrittr)

data("CPS1988", package = "AER" )
df <- CPS1988

X <- model.matrix( wage~.-1, data = df )
y <- df$wage

set.seed(1071)

rmse <- function( y, pred ) {
  sqrt( mean( (y - pred)^2))
}


lambdas <- c(0,1,-2)
n_folds = 100

glm_rmse <- function( train_y, test_y, train_X, test_X ) {

  df <- data.frame( y = train_y, train_X )
  fit <- glm( y~.-1, data = df, family = gaussian )

  pred <- c(predict( fit, as.data.frame(test_X), type = "response"))
  rmse( test_y, pred )
}

rmse_list <- list()

for( i in seq_len(n_folds) ) {
  size <- nrow(X)

  train <- sample( seq_len(size), size * 0.7 )
  test <- seq_len(size)[ -train ]

  train_x <- X[ train, ]
  test_x <- X[ test, ]

  train_y <- y[train]
  test_y <- y[test]

  rmse_list[[i]] <- purrr::map( lambdas, function(lambda) {
    if( lambda < 0 ){
      lambda <- NULL
    }
    model <- gaussian_ridge_ls(
      y = train_y,
      X = train_x,
      reg_lambda = lambda,
      tol = 1e-4
    )
    pred <- predict(model, test_x)
    rmse( test_y, pred )
  })
  # verify that the error is similar to glm
  rmse_list[[i]] <- c( rmse_list[[i]],
                       glm_rmse(train_y, test_y, train_x, test_x) )
}

rmse_list %>%
  do.call(rbind,.) %>%
  data.frame %>%
  lapply( function(i) mean(unlist(i)))
JSzitas/ridges documentation built on May 3, 2022, 12:03 a.m.