R/model_setup_R.R

Defines functions get_feature_specs

# Validate `get_model_specs` and use it to extract the feature specification
# of `model`.
#
# Arguments:
#   get_model_specs  A function taking the model as its only argument, `NULL`,
#                    or a single `NA`.
#                    * `NULL`: if `get_supported_models()` reports a native
#                      function for the first class of `model`, the function
#                      named `get_model_specs.<class>` is looked up and used.
#                    * `NA`: no function is used at all. This lets callers
#                      avoid an internally defined get_model_specs that is
#                      defined but not valid for the specified model.
#   model            The model object. Only `class(model)[1]` is inspected
#                    here; the object itself is handed to `get_model_specs`.
#
# Returns:
#   The value of `get_model_specs(model)`, which must be a list of length 3
#   with the names "labels", "classes", "factor_levels" (in that order), or
#   `NULL` when no `get_model_specs` function is available.
#
# Errors (via stop()):
#   * `get_model_specs` is not NULL, a single NA, or a function.
#   * `get_model_specs(model)` throws an error.
#   * `get_model_specs(model)` does not return the required list format.
get_feature_specs <- function(get_model_specs, model) {
  model_class <- NULL # due to NSE

  model_class0 <- class(model)[1]

  # Only a length-one atomic NA counts as "NA". Restricting the test this way
  # keeps `&&` from receiving a zero-length or multi-element logical, so that
  # vectors and lists reach the informative error below.
  is_single_na <- is.atomic(get_model_specs) &&
    length(get_model_specs) == 1L &&
    is.na(get_model_specs)

  if (!is.function(get_model_specs) &&
    !is.null(get_model_specs) &&
    !is_single_na) {
    stop("`get_model_specs` must be NULL, NA or a function.")
  }

  # Get native get_model_specs if not passed and exists
  if (is.null(get_model_specs)) {
    # The table of supported models is only needed in this branch.
    supported_models <- get_supported_models()

    # NOTE(review): `get_model_specs` and `model_class` inside `[...]` are
    # presumably columns of the table returned by get_supported_models()
    # (non-standard evaluation) -- confirm against that helper.
    native_func_available <- supported_models[get_model_specs == TRUE, model_class0 %in% model_class]
    if (native_func_available) {
      get_model_specs <- get(paste0("get_model_specs.", model_class0))
    }
    # Otherwise get_model_specs stays NULL.
    # The checks are disabled in the check_data function
  }

  # Without a function there is nothing to extract.
  if (!is.function(get_model_specs)) {
    return(NULL)
  }

  # Tests the get_model_specs function
  feature_specs <- tryCatch(get_model_specs(model), error = errorfun)
  # NOTE(review): relies on errorfun() returning an object of class "error"
  # whose first element describes the caught error -- confirm against errorfun.
  if (inherits(feature_specs, "error")) {
    stop(paste0(
      "The get_model_specs function of class `", model_class0, "` is invalid.\n",
      "See the 'Advanced usage' section of the vignette:\n",
      "vignette('understanding_shapr', package = 'shapr')\n",
      "for more information on running shapr with custom models.\n",
      "Note that `get_model_specs` is not required (can be set to NULL)\n",
      "unless you require consistency checks between model and data.\n",
      "A basic function test threw the following error:\n", as.character(feature_specs[[1]])
    ))
  }

  # identical() also rejects an unnamed list: names() is NULL there, and
  # all(NULL == c(...)) would be TRUE because the comparison has length zero.
  required_names <- c("labels", "classes", "factor_levels")
  has_valid_format <- is.list(feature_specs) &&
    length(feature_specs) == 3 &&
    identical(names(feature_specs), required_names)

  if (!has_valid_format) {
    stop(
      paste0(
        "The `get_model_specs` function of class `", model_class0,
        "` does not return a list of length 3 with elements \"labels\",\"classes\",\"factor_levels\".\n",
        "See the 'Advanced usage' section of the vignette:\n",
        "vignette('understanding_shapr', package = 'shapr')\n",
        "for more information on running shapr with custom models and the required output format of get_model_specs."
      )
    )
  }

  return(feature_specs)
}
NorskRegnesentral/shapr documentation built on April 19, 2024, 1:19 p.m.