R/model_minirocket.R

Defines functions py_to_r predict.MinirocketModel score.MinirocketModel fit.MinirocketModel post_load_model.MinirocketModel pre_save_model.MinirocketModel is_model_trained.MinirocketModel is_model_trained get_trained_model_path.MinirocketModel get_trained_model_path get_trained_model_ref.MinirocketModel get_trained_model_ref MinirocketModel

Documented in fit.MinirocketModel MinirocketModel predict.MinirocketModel score.MinirocketModel

#' MINIROCKET model
#'
#' "MINIROCKET allows state-of-the-art accuracy with a minimum training computation cost."
#' The model is the Python implementation made available by the sktime library.
#'
#' @return MinirocketModel. A `Model` named "MINIROCKET" whose class vector starts with
#'   "MinirocketModel"
#'
#' @export
MinirocketModel <- function() {
  model <- Model("MINIROCKET")
  # Put the subclass in front so S3 dispatch tries the MinirocketModel methods first.
  class(model) <- append(class(model), "MinirocketModel", after = 0L)
  model
}

# -------
# METHODS
# -------

# Reference name of the trained model attached to `obj`.
#
# obj: model object; only its `name` field is read.
# ...: accepted so the method signature agrees with the generic; unused.
#
# Returns a character string of the form "<name>-trained".
get_trained_model_ref <- function(obj, ...) UseMethod("get_trained_model_ref")
get_trained_model_ref.MinirocketModel <- function(obj, ...) {
  paste(obj$name, "trained", sep = "-")
}

# Path of the pickle file holding the trained model of `obj`.
#
# obj: model object, forwarded to get_trained_model_ref().
# ...: accepted so the method signature agrees with the generic; unused.
#
# Returns a character string: "<extdata dir of bonski.predict>/<ref>.pickle".
# NOTE(review): system.file() returns "" when the directory cannot be found, in which case the
# result is "/<ref>.pickle" - confirm callers are fine with that.
get_trained_model_path <- function(obj, ...) UseMethod("get_trained_model_path")
get_trained_model_path.MinirocketModel <- function(obj, ...) {
  pickle_name <- paste(get_trained_model_ref(obj), "pickle", sep = ".")
  file.path(system.file("extdata", package = "bonski.predict"), pickle_name)
}

# Pickle the trained Python model of `obj` to get_trained_model_path(obj).
#
# obj: model object; its `trained_model` field is read.
# ...: accepted so the method signature agrees with the generic; unused.
#
# Nothing is written when `trained_model` is NULL or is a null external pointer.
save_trained_model <- function(obj, ...) UseMethod("save_trained_model")
save_trained_model.MinirocketModel <- function(obj, ...) {
  if (!is.null(obj$trained_model) && !reticulate::py_is_null_xptr(obj$trained_model)) {
    # `object` is spelled out in full instead of relying on partial argument matching.
    reticulate::py_save_object(object = obj$trained_model, filename = get_trained_model_path(obj))
  }
}

# Load the pickled Python model from get_trained_model_path(obj) into `obj$trained_model`.
#
# obj: model object. NOTE(review): the assignment below is only visible to the caller if `obj`
#      has reference semantics (e.g. an environment) - confirm against Model().
# ...: accepted so the method signature agrees with the generic; unused.
#
# Does nothing when the pickle file does not exist.
load_trained_model <- function(obj, ...) UseMethod("load_trained_model")
load_trained_model.MinirocketModel <- function(obj, ...) {
  path <- get_trained_model_path(obj)
  if (file.exists(path)) {
    # NOTE(review): this unpickles the file; only safe as long as the file is trusted.
    obj$trained_model <- reticulate::py_load_object(filename = path)
  }
}

# Make sure `obj` carries a usable trained model.
#
# Despite its name this is not a predicate: when no usable model is attached it first tries to
# load the pickled model from disk and, if that still leaves no usable model, calls fit(obj).
#
# obj: model object; its `trained_model` field is read.
# ...: accepted so the method signature agrees with the generic; unused.
is_model_trained <- function(obj, ...) UseMethod("is_model_trained")
is_model_trained.MinirocketModel <- function(obj, ...) {
  # A model is unusable when it is absent or when it is a null external pointer.
  model_unusable <- function() {
    is.null(obj$trained_model) || reticulate::py_is_null_xptr(obj$trained_model)
  }

  if (model_unusable()) {
    load_trained_model(obj)
  }
  # Re-test with the same criterion: a null pointer with no pickle on disk must also be refitted.
  if (model_unusable()) {
    fit(obj)
  }
}


# Hook run before the model object is saved: pickles the trained Python model to disk, then
# hands over to the next pre_save_model method in the class chain.
#
# obj: MinirocketModel
#
# Returns whatever the next method returns.
pre_save_model.MinirocketModel <- function(obj) {
  save_trained_model(obj)
  NextMethod()
}

# Hook run after the model object is loaded: lets the next post_load_model method in the class
# chain run first, then re-attaches the pickled Python model from disk (if the file exists).
#
# obj: MinirocketModel
post_load_model.MinirocketModel <- function(obj) {
  NextMethod()
  load_trained_model(obj)
}

#' MINIROCKET model fitting
#'
#' Builds a scikit-learn pipeline made of a MiniRocketMultivariate transformation followed by a
#' cross-validated ridge estimator, fits it on the training split and stores the fitted
#' pipeline in `obj$trained_model`. `classify` and `scoring` are recorded on `obj` as well.
#'
#' @param obj MinirocketModel
#' @param classify logical. use classification or regression?
#' @param scoring character. Scoring method used by the classifier
#'
#' @export
fit.MinirocketModel <- function(obj, classify = TRUE, scoring = NULL) {
  # Create the train/test split on demand.
  training_split_missing <- is.null(obj$data$X_train) || is.null(obj$data$y_train)
  if (training_split_missing) {
    split_Xy(obj)
  }

  obj$classify <- classify
  obj$scoring <- scoring

  np <- reticulate::import("numpy", convert = FALSE, delay_load = FALSE)
  sklearn <- reticulate::import("sklearn", convert = FALSE, delay_load = FALSE)
  sktime <- reticulate::import("sktime", convert = FALSE, delay_load = FALSE)

  # -----
  # Python calls start
  # -----
  build_pipeline <- sklearn$pipeline$make_pipeline
  transformer_class <- sktime$transformations$panel$rocket$MiniRocketMultivariate

  # Ridge classifier or ridge regressor, both with built-in cross-validation.
  ridge_class <- if (classify) {
    sklearn$linear_model$RidgeClassifierCV
  } else {
    sklearn$linear_model$RidgeCV
  }

  transformer <- transformer_class()
  # NOTE(review): `normalize` has been deprecated in newer scikit-learn releases - confirm it is
  # still accepted by the version this package targets.
  ridge <- ridge_class(alphas = np$logspace(-3, 3, 10L), normalize = TRUE, scoring = scoring)
  pipeline <- build_pipeline(transformer, ridge)

  # Every cell of X is turned into a numpy array; y is cast to int.
  X_train_py <- reticulate::r_to_py(obj$data$X_train)$applymap(np$array)
  y_train_py <- reticulate::np_array(obj$data$y_train)$astype("int")
  pipeline$fit(X_train_py, y_train_py)
  # -----
  # Python calls end
  # -----

  obj$trained_model <- pipeline
}

#' MINIROCKET model score
#'
#' Scores the trained model on the test split and stores the result in `obj$score`.
#'
#' If `ord_score` is TRUE, the predictions on the test split are passed to `ordinal_score()`
#' together with the true values, taking advantage of the fact that the ski evaluation score is
#' an ordinal variable. Otherwise the `score()` method of the fitted Python pipeline is used.
#'
#' @param obj MinirocketModel
#' @param ord_score logical. Whether the score should be computed as ordinal
#'
#' @return numeric. score
#'
#' @export
score.MinirocketModel <- function(obj, ord_score = FALSE) {
  # Create the train/test split on demand.
  if (is.null(obj$data$X_test) || is.null(obj$data$y_test)) {
    split_Xy(obj)
  }
  is_model_trained(obj)

  obj$ord_score <- ord_score

  if (ord_score) {
    y_pred <- predict(obj, obj$data$X_test)
    score <- ordinal_score(obj$data$y_test, y_pred)
  } else {
    # numpy has to be imported here as well: `np` is not otherwise defined in this function.
    np <- reticulate::import("numpy", convert = FALSE, delay_load = FALSE)
    # -----
    # Python calls start
    # -----
    py_score <- obj$trained_model$score(reticulate::r_to_py(obj$data$X_test)$applymap(np$array),
                                        reticulate::np_array(obj$data$y_test)$astype("int"))
    # -----
    # Python calls end
    # -----

    score <- py_to_r(py_score)
  }

  obj$score <- score
  score
}

#' MINIROCKET model predict
#'
#' Makes sure a trained model is attached to `obj`, then runs the `predict()` method of the
#' fitted Python pipeline on `X_pred` and converts the result back to R when possible.
#'
#' @param obj MinirocketModel
#' @param X_pred data.frame. Model input
#'
#' @return list. prediction
#'
#' @export
predict.MinirocketModel <- function(obj, X_pred) {
  is_model_trained(obj)

  np <- reticulate::import("numpy", convert = FALSE, delay_load = FALSE)

  # -----
  # Python calls start
  # -----
  trained_pipeline <- obj$trained_model
  # Every cell of the input is turned into a numpy array before prediction.
  predictions <- trained_pipeline$predict(reticulate::r_to_py(X_pred)$applymap(np$array))
  # -----
  # Python calls end
  # -----

  py_to_r(predictions)
}

# Best-effort conversion of a Python object to R.
#
# obj: object to convert.
#
# Returns the result of reticulate::py_to_r(obj); if that call raises an error, `obj` is
# returned unchanged.
py_to_r <- function(obj) {
  tryCatch(
    reticulate::py_to_r(obj),
    error = function(conversion_error) obj
  )
}
vadmbertr/bonski.predict documentation built on Dec. 23, 2021, 2:06 p.m.