#' MINIROCKET model
#'
#' Creates a model object for MINIROCKET, a time-series method described by
#' its authors as: "MINIROCKET allows state-of-the-art accuracy with a minimum
#' training computation cost."
#' Training and prediction delegate to the Python implementation provided by
#' the sktime library, accessed through reticulate.
#'
#' @return MinirocketModel. A `Model("MINIROCKET")` object whose class vector
#'   is prefixed with `"MinirocketModel"`.
#'
#' @export
MinirocketModel <- function() {
  model <- Model("MINIROCKET")
  # Prepend the subclass so MinirocketModel methods take precedence in S3
  # dispatch, falling back to the base Model methods.
  class(model) <- append(class(model), "MinirocketModel", after = 0L)
  model
}
# -------
# METHODS
# -------
# Identifier of the persisted trained model: "<model name>-trained".
get_trained_model_ref <- function(obj, ...) UseMethod("get_trained_model_ref")
# `...` is accepted (and ignored) so the method matches its generic; without it,
# passing any extra argument through the generic fails with "unused argument".
get_trained_model_ref.MinirocketModel <- function(obj, ...) {
  paste(obj$name, "trained", sep = "-")
}
# Path of the pickle file holding the trained Python model.
get_trained_model_path <- function(obj, ...) UseMethod("get_trained_model_path")
# @param obj MinirocketModel
# @param model_dir Directory holding the pickle. Defaults to the installed
#   package's `extdata` directory. Note that `system.file()` returns "" when
#   that directory cannot be found, which yields a path at the filesystem root.
# @param ... Unused; accepted for consistency with the generic.
# @return character(1). "<model_dir>/<model name>-trained.pickle"
get_trained_model_path.MinirocketModel <- function(
    obj,
    model_dir = system.file("extdata", package = "bonski.predict"),
    ...) {
  file_name <- paste(get_trained_model_ref(obj), "pickle", sep = ".")
  file.path(model_dir, file_name)
}
# Pickle the trained Python model to `get_trained_model_path(obj)`.
save_trained_model <- function(obj, ...) UseMethod("save_trained_model")
# Does nothing when no model is stored or when the stored Python reference is a
# NULL external pointer (a Python object that no longer exists in this session).
save_trained_model.MinirocketModel <- function(obj, ...) {
  model <- obj$trained_model
  if (!is.null(model) && !reticulate::py_is_null_xptr(model)) {
    reticulate::py_save_object(
      obj = model,
      filename = get_trained_model_path(obj)
    )
  }
}
# Restore the trained Python model from its pickle file, if the file exists.
load_trained_model <- function(obj, ...) UseMethod("load_trained_model")
# Assigns into `obj$trained_model`; for the caller to see the update, `obj`
# must have reference semantics (presumably an environment built by Model() --
# TODO confirm). Silently does nothing when the file is absent.
load_trained_model.MinirocketModel <- function(obj, ...) {
  path <- get_trained_model_path(obj)
  if (file.exists(path)) {
    # NOTE: unpickling can execute arbitrary Python code; only load pickle
    # files from a trusted source such as the package's own extdata.
    obj$trained_model <- reticulate::py_load_object(filename = path)
  }
}
# Ensure `obj$trained_model` holds a usable Python model: load it from the
# pickle file if possible, otherwise fit it. Despite its name, this is called
# for that side effect; its return value is not meaningful.
is_model_trained <- function(obj, ...) UseMethod("is_model_trained")
# Relies on load_trained_model()/fit() updating `obj` in place (reference
# semantics -- presumably an environment; TODO confirm).
is_model_trained.MinirocketModel <- function(obj, ...) {
  # A model is usable only if present AND its Python pointer is not NULL
  # (a NULL external pointer refers to a Python object that no longer exists).
  has_usable_model <- function() {
    !is.null(obj$trained_model) &&
      !reticulate::py_is_null_xptr(obj$trained_model)
  }
  if (!has_usable_model()) {
    load_trained_model(obj)
  }
  # Re-check with the same criterion: a dead pointer that could not be
  # reloaded from disk must be refitted, not just a missing model.
  if (!has_usable_model()) {
    fit(obj)
  }
}
# Before the object is saved, pickle the trained Python model (a Python object
# reference is not itself persisted with the R object), then continue with the
# next method in the class chain. `...` keeps the signature compatible with a
# generic that takes extra arguments; NextMethod() forwards them.
pre_save_model.MinirocketModel <- function(obj, ...) {
  save_trained_model(obj)
  NextMethod()
}
# After the object is loaded, let the next method in the class chain run first,
# then restore the trained Python model from its pickle file. `...` keeps the
# signature compatible with a generic that takes extra arguments.
post_load_model.MinirocketModel <- function(obj, ...) {
  NextMethod()
  load_trained_model(obj)
}
#' MINIROCKET model fitting
#'
#' Builds an sklearn pipeline made of sktime's `MiniRocketMultivariate`
#' transform followed by a cross-validated ridge model, fits it on the
#' training split and stores it in `obj$trained_model`. The train/test split
#' is computed with `split_Xy()` first when it is missing.
#'
#' @param obj MinirocketModel
#' @param classify logical. `TRUE` fits a `RidgeClassifierCV` (classification,
#'   integer labels); `FALSE` fits a `RidgeCV` (regression, numeric targets).
#' @param scoring character or NULL. Passed as the `scoring` argument of the
#'   ridge estimator; `NULL` keeps the estimator's default.
#' @param ... Unused; accepted for S3 method consistency.
#'
#' @return The fitted Python pipeline, invisibly.
#'
#' @export
fit.MinirocketModel <- function(obj, classify = TRUE, scoring = NULL, ...) {
  if (is.null(obj$data$X_train) || is.null(obj$data$y_train)) {
    split_Xy(obj)
  }
  obj$classify <- classify
  obj$scoring <- scoring
  np <- reticulate::import("numpy", convert = FALSE, delay_load = FALSE)
  sklearn <- reticulate::import("sklearn", convert = FALSE, delay_load = FALSE)
  sktime <- reticulate::import("sktime", convert = FALSE, delay_load = FALSE)
  # -----
  # Python calls start
  # -----
  make_pipeline <- sklearn$pipeline$make_pipeline
  rocket <- sktime$transformations$panel$rocket
  if (classify) {
    RidgeModel <- sklearn$linear_model$RidgeClassifierCV
    y_dtype <- "int"
  } else {
    RidgeModel <- sklearn$linear_model$RidgeCV
    # Keep regression targets as floats: an int cast would silently truncate
    # any fractional target values.
    y_dtype <- "float64"
  }
  # NOTE(review): `normalize` was deprecated in scikit-learn 1.0 and removed
  # in 1.2, so this call requires scikit-learn < 1.2. The suggested
  # replacement is a StandardScaler(with_mean = FALSE) step in the pipeline.
  minirocket_pipeline <- make_pipeline(
    rocket$MiniRocketMultivariate(),
    RidgeModel(
      alphas = np$logspace(-3, 3, 10L),
      normalize = TRUE,
      scoring = scoring
    )
  )
  # applymap(np.array) converts every cell of the pandas DataFrame into a
  # numpy array (presumably the nested per-series layout sktime expects).
  X_train_py <- reticulate::r_to_py(obj$data$X_train)$applymap(np$array)
  y_train_py <- reticulate::np_array(obj$data$y_train)$astype(y_dtype)
  minirocket_pipeline$fit(X_train_py, y_train_py)
  # -----
  # Python calls end
  # -----
  obj$trained_model <- minirocket_pipeline
  invisible(minirocket_pipeline)
}
#' MINIROCKET model score
#'
#' By default the score is the fitted pipeline's own `score()` on the test
#' split (for a classifier, the accuracy: the share of predicted classes equal
#' to the true ones). If `ord_score` is `TRUE`, predictions are instead
#' compared with `ordinal_score()`, taking advantage of the fact that the ski
#' evaluation score is an ordinal variable.
#'
#' The test split is computed with `split_Xy()` first when it is missing, and
#' the model is loaded or fitted if needed. The result is also stored in
#' `obj$score`.
#'
#' @param obj MinirocketModel
#' @param ord_score logical. Whether the score should be computed as ordinal
#' @param ... Unused; accepted for S3 method consistency.
#'
#' @return numeric. score
#'
#' @export
score.MinirocketModel <- function(obj, ord_score = FALSE, ...) {
  if (is.null(obj$data$X_test) || is.null(obj$data$y_test)) {
    split_Xy(obj)
  }
  is_model_trained(obj)
  obj$ord_score <- ord_score
  if (ord_score) {
    y_pred <- predict(obj, obj$data$X_test)
    result <- ordinal_score(obj$data$y_test, y_pred)
  } else {
    # numpy must be imported locally: `np` is used below.
    np <- reticulate::import("numpy", convert = FALSE, delay_load = FALSE)
    # Integer labels for classification; regression targets
    # (classify = FALSE) stay floats so they are not truncated.
    y_dtype <- if (isFALSE(obj$classify)) "float64" else "int"
    # -----
    # Python calls start
    # -----
    X_test_py <- reticulate::r_to_py(obj$data$X_test)$applymap(np$array)
    y_test_py <- reticulate::np_array(obj$data$y_test)$astype(y_dtype)
    result <- obj$trained_model$score(X_test_py, y_test_py)
    # -----
    # Python calls end
    # -----
    result <- py_to_r(result)
  }
  obj$score <- result
  result
}
#' MINIROCKET model predict
#'
#' Makes sure a trained model is available (loading it from disk or fitting
#' it if needed), then runs the Python pipeline's `predict()` on `X_pred`.
#'
#' @param obj MinirocketModel
#' @param X_pred data.frame. Model input; every cell is converted to a numpy
#'   array before being passed to the pipeline.
#' @param ... Unused; accepted for consistency with the `predict` generic.
#'
#' @return Predictions converted to R with `py_to_r()` when possible,
#'   otherwise the unconverted Python object.
#'
#' @export
predict.MinirocketModel <- function(obj, X_pred, ...) {
  is_model_trained(obj)
  np <- reticulate::import("numpy", convert = FALSE, delay_load = FALSE)
  # -----
  # Python calls start
  # -----
  X_pred_py <- reticulate::r_to_py(X_pred)$applymap(np$array)
  y_pred <- obj$trained_model$predict(X_pred_py)
  # -----
  # Python calls end
  # -----
  py_to_r(y_pred)
}
# Best-effort conversion of a Python object to R.
# Delegates to reticulate::py_to_r(); if that conversion throws an error for
# any reason, the input is returned unchanged instead of propagating the error.
py_to_r <- function(obj) {
  fallback <- function(cond) obj
  tryCatch(reticulate::py_to_r(obj), error = fallback)
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.