# R/Surrogate.R

#' @title Surrogate
#' @usage NULL
#'
#' @format [R6::R6Class] object.
#' @include load_functions.R
#'
#' @description
#' Allows the construction of surrogates from a given meta-data dataset
#' of hyperparameters and a given performance.
#'
#' @section Construction:
#' ```
#'   surr = Surrogate$new(oml_task_id = 31, base_learner = "regr.glmnet",
#'                        eval_measure = "auc", param_set = get_param_set("glmnet"),
#'                        surrogate_learner = "regr.ranger",
#'                        data_source = "inst/extdata/glmnet_sample.csv", load_fun = load_from_csv)
#' ```
#'
#' @section Fields:
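#' * `use_cache` :: `logical(1)`\cr
#' Whether newly created tasks, models and resample results are written to the
#' cache below `save_path`.
#'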
#' * `oml_task_id` :: `integer(1)`\cr
#' OpenML Task id
#'
#' * `eval_measure` :: `character()`\cr
#' c('auc', 'acc', 'brier')
#'
#' * `base_learner` :: `character(1)` \cr Does not need to be an mlr learner.
#'
#' * `surrogate_learner` :: `character(1)`\cr An mlr learner.
#'
#' * `param_names` :: `character()`\cr
#'
#' * `param_set` :: [ParamSet]
#'
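#' * `rtask` :: mlr regression task\cr
#' Task with target `performance`, built from `data` by `file_rtask_to_disk()`.
#'
#' * `model` :: trained mlr model\cr
#' Result of training `surrogate_learner` on `rtask`.
#'
#' * `resample` :: mlr resample result\cr
#' Result of resampling `surrogate_learner` on `rtask`.
#'
#' * `scaler` :: `Scaler`\cr
#' Object whose `rescale()` method is applied to predictions when
#' `predict(rescale = TRUE)` is used.
#'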
#'
#' * `save_path` :: `character(1)`\cr
#' Root directory to save the results
#'
#' * `data_source` :: `character(1)`\cr
#' Path to an existing file, or a URL starting with `http`, from which the data is loaded
#'
#' * `load_fun` :: `FUN`\cr
#' Function to load the data
#'
#' * `data` :: `data.frame()`\cr
#' Loaded data stored in a data.frame
#'
#' @section Methods:
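#' Unless stated otherwise, methods are called for their side effects on the object.
#'
#' * `train()`\cr
#' `()` -> `self`\cr
#' Calls `acquire_model()` so that `model` is available.
#'
#' * `print()`\cr
#' `()` -> `self`\cr
#' Prints the task id, measure, base learner and which of `rtask`, `model`
#' and `resample` are currently set.
#'
#' * `predict(newdata, rescale = FALSE)`\cr
#' (`data.frame()`, `logical(1)`) -> vector of predicted responses\cr
#' Trains the surrogate if no model is present, predicts on `newdata` and, if
#' `rescale` is `TRUE`, passes the predictions through `scaler$rescale()`.
#'
#' * `file_rtask_to_disk()`\cr
#' `()` -> \cr
#' Loads the data via `load_fun`, builds `rtask` and, if `use_cache` is set,
#' writes it to the cache.
#'
#' * `file_model_to_disk()`\cr
#' `()` -> \cr
#' Trains `surrogate_learner` on `rtask`, stores the result in `model` and, if
#' `use_cache` is set, writes it to the cache.
#'
#' * `file_resample_to_disk(cv, measures)`\cr
#' (resampling description, `list()` of measures) -> \cr
#' Resamples `surrogate_learner` on `rtask`, stores the result in `resample`
#' and, if `use_cache` is set, writes it to the cache.
#'
#' * `acquire_object(id)`\cr
#' (`character(1)`) -> \cr
#' Fills the field named `id` (`"rtask"`, `"model"` or `"resample"`) if it is
#' empty: reads it from the cache when present there, otherwise creates it via
#' the matching `file_*_to_disk()` method.
#'
#' * `acquire_rtask()`\cr
#' `()` -> \cr
#' Shortcut for `acquire_object("rtask")`.
#'
#' * `acquire_model()`\cr
#' `()` -> \cr
#' Shortcut for `acquire_object("model")`.
#'
#' * `acquire_resample()`\cr
#' `()` -> \cr
#' Shortcut for `acquire_object("resample")`.
#'
#' * `fail_path`\cr
#' active binding -> `character(1)`\cr
#' Cache directory composed of `save_path`, the base learner, the surrogate
#' learner's short name, the measure and the scaler name.
#'
#' * `save(keep.model = FALSE, keep.task = FALSE)`\cr
#' (`logical(1)`, `logical(1)`) -> \cr
#' Writes the object itself to the cache; `model` and `rtask` are set to `NULL`
#' beforehand unless the respective `keep.*` argument is `TRUE`.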
#'
#' @family Surrogate
#' @export
#' @examples
#'   ps = get_param_set("glmnet")
#'   ds = system.file("extdata", "glmnet_sample.csv", package = "surrogates")
#'   surr = Surrogate$new(oml_task_id = 31, base_learner = "regr.glmnet",
#'                        eval_measure = "auc", param_set = ps,
#'                        surrogate_learner = "regr.ranger",
#'                        save_path = tempdir(),
#'                        data_source = ds, load_fun = load_from_csv)
Surrogate = R6Class("Surrogate",
  public = list(

    # Identification of the meta-data subset the surrogate is built on.
    oml_task_id = NULL,
    eval_measure = NULL,
    base_learner = NULL,
    surrogate_learner = NULL,
    param_set = NULL,
    use_cache = TRUE,

    # Artefacts that are created lazily by the acquire_*() methods.
    rtask = NULL,
    model = NULL,
    resample = NULL,
    scaler = NULL,
    load_fun = NULL,

    # `save_path` starts out as a directory name; initialize() replaces it
    # with the fail::fail() handle created for `fail_path`.
    save_path = ".",
    data_source = NULL,
    data = NULL,

    param_names = NULL,
    cst_performance_ = 0,

    # Validate and store the configuration, then open the cache handle.
    # All arguments are checked before fail::fail() is called, so an invalid
    # call does not create a cache directory as a side effect.
    initialize = function(oml_task_id, base_learner, eval_measure, surrogate_learner,
          param_set, use_cache = TRUE, save_path, data_source, load_fun, scaler = Scaler$new()) {

      # Info used for sub-setting the data:
      self$oml_task_id = assert_int(oml_task_id)
      self$base_learner = assert_string(base_learner)
      self$eval_measure = assert_string(eval_measure)
      self$scaler = assert_class(scaler, "Scaler")
      self$surrogate_learner = mlr::checkLearner(surrogate_learner)

      self$use_cache = checkmate::assert_flag(use_cache)
      self$data_source = assert_string(data_source, null.ok = TRUE)
      self$load_fun = assert_function(load_fun)

      if (missing(param_set)) stop("Please provide a valid param_set of class ParamSet!")
      self$param_set = assert_class(param_set, "ParamSet")

      if (!missing(save_path))
        self$save_path = assert_string(save_path, null.ok = TRUE)
      # `fail_path` must be read while `save_path` is still a character
      # string; afterwards `save_path` holds the cache handle.
      self$save_path = fail::fail(self$fail_path)
      invisible(self)
    },

    # Make sure a model is available (from cache or by fitting).
    train = function() {
      self$acquire_model()
      invisible(self)
    },

    # Predict the response for `newdata`; trains first if no model is present.
    # With `rescale = TRUE` the predictions are passed through the scaler's
    # rescale() method. `newdata` itself is not modified.
    predict = function(newdata, rescale = FALSE) {
      newdata = private$convert_data_types_for_ps(newdata)
      if (is.null(self$model)) self$train()
      prd = predict(self$model, newdata = data.frame(newdata))$data$response
      if (rescale)
        prd = self$scaler$rescale(prd)
      return(prd)
    },

    # Short status summary of the surrogate.
    print = function(...) {
      catf("Surrogate for OML task <%i> for measure <%s> for BL <%s>",
        self$oml_task_id, self$eval_measure, self$base_learner)
      catf("RTask: %s", if (is.null(self$rtask)) "no" else "yes")
      catf("Model: %s", if (is.null(self$model)) "no" else "yes")
      perf = if (is.null(self$resample)) {
        "N/A"
      } else {
        # Show every aggregated measure, not just the first one.
        aggr = self$resample$aggr
        paste(sprintf("%s=%.4g", names(aggr), aggr), collapse = ", ")
      }
      catf("Performance: %s", perf)
      invisible(self)
    },

    # Object Getters ---------------------------------------------------------------------
    acquire_rtask = function() self$acquire_object("rtask"),
    acquire_model = function() self$acquire_object("model"),
    acquire_resample = function() self$acquire_object("resample"),

    # Fill the field `id` ("rtask", "model" or "resample") if it is empty:
    # read it from the cache when present there, otherwise create it via the
    # matching file_<id>_to_disk() method.
    acquire_object = function(id) {
      if (!is.null(self[[id]]))
        return(invisible(NULL))
      # Check `save_path` first: the in_cache_* bindings call self$save_path$ls().
      if (is.null(self$save_path) || !self[[sprintf("in_cache_%s", id)]]) {
        catf('Object not in cache')
        self[[sprintf("file_%s_to_disk", id)]]()
      } else {
        self[[id]] = self$save_path$get(self[[sprintf("key_%s", id)]])
      }
    },

    # Train the surrogate learner on the task and optionally cache the model.
    file_model_to_disk = function() {
      self$acquire_rtask()
      catf("<Obtaining Model>")
      private$fixup_learner()
      # Namespace-qualified: this class also has a method called `train`.
      self$model = mlr::train(self$surrogate_learner, self$rtask)
      if (self$use_cache) {
        catf("<Writing model to disk>")
        self$save_path$put(keys = self$key_model, self$model)
      }
    },

    # Resample the surrogate learner on the task and optionally cache the result.
    file_resample_to_disk = function(cv = cv3, measures = list(rmse, spearmanrho, kendalltau, expvar, timepredict)) {
      self$acquire_rtask()
      catf("<Obtaining Resampling>")
      private$fixup_learner()
      # Namespace-qualified: this class also has a field called `resample`.
      self$resample = mlr::resample(self$surrogate_learner, self$rtask, cv, measures)
      if (self$use_cache) {
        catf("<Writing resample to disk>")
        self$save_path$put(keys = self$key_resample, self$resample)
      }
    },

    # Load the data, build the regression task and optionally cache it.
    file_rtask_to_disk = function() {
      # Load the data
      src = self$data_source
      if (is.null(src) || !(file.exists(src) || stringi::stri_startswith_fixed(src, "http")))
        stop("Input file doesn't exist!")
      self$data = self$load_fun(self)

      # In case no data exists, we always sample 0 performance
      if (is.null(self$data) || nrow(self$data) == 0L) {
        warning("No rows found in data")
        self$data = self$fit_constant_model()
      }
      catf("<Obtaining Task>")
      self$data = private$convert_data_types_for_ps(self$data)
      tsk = makeRegrTask(id = as.character(self$oml_task_id), data = as.data.frame(self$data), target = "performance")
      self$rtask = removeConstantFeatures(tsk)

      if (self$use_cache) {
        catf("<Writing task to disk>")
        self$save_path$put(keys = self$key_rtask, self$rtask)
      }
    },

    # Build a small grid design over the parameter set with the constant
    # performance `cst_performance_` as target. Used when no data is available.
    fit_constant_model = function() {
      d = ParamHelpers::generateGridDesign(self$param_set, resolution = 3L)
      # Without explicit `param_names` all parameters of the set are kept.
      if (!is.null(self$param_names))
        d = d[, self$param_names, drop = FALSE]

      # Convert logicals/character to factors for learner. Columns are
      # replaced in place so names and column order are preserved.
      to_factor = vapply(d, function(x) is.logical(x) || is.character(x), logical(1L))
      if (any(to_factor))
        d[to_factor] = lapply(d[to_factor], function(x) as.factor(as.character(x)))
      d$performance = self$cst_performance_
      return(d)
    },

    # Write the object itself to the cache. NOTE: unless the keep.* flags are
    # set, `model` and `rtask` are removed from this object (not only from the
    # stored copy).
    save = function(keep.model = FALSE, keep.task = FALSE) {
      if (!keep.model) self$model = NULL
      if (!keep.task) self$rtask = NULL
      self$save_path$put(keys = self$key_class, self)
    }
  ),

  active = list(
    # Cache keys, all derived from task id, measure and base learner.
    key_base = function() sprintf("%i_%s_%s",
      self$oml_task_id, self$eval_measure, self$base_learner),
    key_model = function() paste0("surr_model_", self$key_base),
    key_rtask = function() paste0("surr_rtask_", self$key_base),
    key_resample = function() paste0("surr_resample_", self$key_base),
    key_class = function() paste0("surrogate_", self$key_base),

    # Whether the respective object is already stored in the cache.
    in_cache_rtask = function() self$key_rtask %in% self$save_path$ls(),
    in_cache_model = function() self$key_model %in% self$save_path$ls(),
    in_cache_resample = function() self$key_resample %in% self$save_path$ls(),

    # Cache directory. Only meaningful while `save_path` is still a character
    # string, i.e. before initialize() replaces it with the cache handle.
    fail_path = function() {
      paste(self$save_path, "surrogates", self$base_learner,
        paste0(self$surrogate_learner$short.name, "_surrogate"),
        self$eval_measure, self$scaler$scaler_name, sep = "/"
      )
    },
    # Getter/setter for the constant performance used by fit_constant_model().
    cst_performance = function(val) {
      if (missing(val)) self$cst_performance_ else self$cst_performance_ = val
    }
  ),

  private = list(
    # Return a data.table copy of `data` in which integer parameters of the
    # parameter set are integer columns and character columns are factors.
    convert_data_types_for_ps = function(data) {
      # Work on a copy: setDT() would convert the caller's object by reference.
      data = data.table::as.data.table(data)
      typedf = ParamHelpers:::getParSetPrintData(self$param_set)
      # Only convert integer parameters that are actually present in `data`.
      to_int = intersect(rownames(typedf[typedf$Type == "integer", , drop = FALSE]), names(data))
      if (length(to_int) > 0L)
        data[, to_int] = data[, lapply(.SD, as.integer), .SDcols = to_int]

      to_factor = names(data)[vapply(data, is.character, logical(1L))]
      if (length(to_factor) > 0L)
        data[, to_factor] = data[, lapply(.SD, as.factor), .SDcols = to_factor]
      return(data)
    },
    # Wrap the learner with constant imputation if the task has missing values.
    # The class check keeps repeated calls (model and resample) from stacking
    # several impute wrappers on top of each other.
    fixup_learner = function() {
      if (self$rtask$task.desc$has.missings && !inherits(self$surrogate_learner, "ImputeWrapper"))
        self$surrogate_learner =
          makeImputeWrapper(self$surrogate_learner, list(
            numeric = imputeConstant(-9999),
            integer = imputeConstant(-9999L),
            factor = imputeConstant("_NA_")
        ))
    }
  )
)
# compstat-lmu/surrogates documentation built on Jan. 29, 2020, 1:23 p.m.