R/permute.R

Defines functions permute

Documented in permute

#' Create permutations of resamples
#'
#' Create permutations of the response variable(s) in the training set of
#' each resample, to assess the quality of a reference model fit (i.e. to build
#' a null distribution of performance metrics through a permutation test).
#'
#' @param object an object of class resamples, created by a `resample_***()`
#'   function; it must contain a `train` column.
#' @param resp name(s) of the response variable(s). When several names are
#'   given, all of them are shuffled with the *same* permutation of rows, so
#'   that the association among responses is kept while their link with the
#'   explanatory variables is broken.
#'
#' @returns The input object in which the training set of each row is now a
#'   data.frame where the response variable column(s) have been shuffled.
#'   NB: This function therefore turns the `train` column from type `resample`
#'   to an actual `data.frame`, which takes more memory.
#'
#' @export
#' @examples
#' # Split train and test, fit a reference model and compute its performance stats
#' set.seed(123)
#' rs <- resample_split(mtcars, p=0.7)
#' rs_fitted <- rs %>% xgb_fit(resp="mpg", expl=c("cyl", "hp", "qsec"),
#'     eta=0.1, max_depth=6, nrounds=100)
#' plot(val_rmse_mean ~ iter, data=xgb_summarise_fit(rs_fitted))
#' pred <- xgb_predict(rs_fitted, niter=50, fns=c())
#' ref_stats <- regression_metrics(pred$pred, pred$mpg)
#' ref_stats
#'
#' # Now define permutations of the training set and compute the stats for each
#' rs_perm <- replicate(rs, n=100) %>% permute(resp="mpg")
#' rs_perm_fitted <- rs_perm %>% xgb_fit(resp="mpg", expl=c("cyl", "hp", "qsec"),
#'     eta=0.1, max_depth=6, nrounds=50)
#' pred <- group_by(rs_perm_fitted, replic) %>%
#'   # NB: group_by replicate number to keep it in the returned data.frame
#'   xgb_predict(niter=50, fns=c())
#' # compute the stats separately for each replicate
#' perm_stats <- pred %>%
#'   group_by(replic) %>%
#'   summarise(regression_metrics(pred, mpg))
#'
#' # Compare the observed RMSE with the one from the permutations
#' hist(perm_stats$RMSE, breaks=50)
#' abline(v=ref_stats$RMSE, col="red")
#' # and compute a p-value, through permutation
#' p.value <- sum(perm_stats$RMSE <= ref_stats$RMSE) / nrow(perm_stats)
#' p.value
permute <- function(object, resp) {
  if (!is.character(resp) || length(resp) == 0L) {
    stop("`resp` must be a non-empty character vector of column names.",
         call. = FALSE)
  }
  # seq_len() makes a zero-row `object` a no-op instead of an error
  for (i in seq_len(nrow(object))) {
    train <- as.data.frame(object$train[[i]])

    # A misspelled response would otherwise leave the data unpermuted silently
    absent <- setdiff(resp, names(train))
    if (length(absent) > 0L) {
      stop("Response column(s) not found in training set ", i, ": ",
           paste(absent, collapse = ", "), call. = FALSE)
    }

    # Index with sample.int() rather than calling sample(x): on a length-1
    # numeric `x`, sample(x) returns a permutation of 1:x, not `x` itself.
    # One shared permutation keeps multiple responses aligned row-wise.
    perm <- sample.int(nrow(train))
    for (r in resp) {
      train[[r]] <- train[[r]][perm]
    }

    object$train[[i]] <- train
  }
  object
}
jiho/joml documentation built on Dec. 6, 2023, 5:50 a.m.