cv_misvm: Fit MI-SVM model to the data using cross-validation

View source: R/cv_misvm.R

Description

Cross-validation wrapper around the misvm() function that fits the MI-SVM model over a sequence of specified cost parameters. The optimal cost parameter is the one whose cross-fitted models achieve the highest AUC. See ?misvm for more details on the fitting function.
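AUC can be computed in several ways. As a hedged illustration of the selection criterion only, the base-R sketch below computes a rank-based (Mann-Whitney) AUC from true 0/1 labels and predicted scores. The helper name auc_rank is hypothetical and is not part of mildsvm, whose internal AUC computation may differ.

```r
# Hypothetical helper (not part of mildsvm): rank-based (Mann-Whitney) AUC.
# `y` is a 0/1 vector of true labels, `score` the predicted scores.
auc_rank <- function(y, score) {
  n_pos <- sum(y == 1)
  n_neg <- sum(y == 0)
  # AUC is undefined when only one class is present
  if (n_pos == 0 || n_neg == 0) return(NA_real_)
  r <- rank(score)  # average ranks, so tied scores count as half
  (sum(r[y == 1]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

# 3 of the 4 positive-negative pairs are ordered correctly, so this is 0.75
auc_rank(c(0, 0, 1, 1), c(0.1, 0.4, 0.35, 0.8))
```

Returning NA for a single-class validation set mirrors the rule, described under Value, that such folds are excluded from the averaged AUC.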

Usage

## Default S3 method:
cv_misvm(
  x,
  y,
  bags,
  cost_seq,
  n_fold,
  fold_id,
  method = c("heuristic", "mip", "qp-heuristic"),
  weights = TRUE,
  control = list(kernel = "linear", sigma = 1,
    nystrom_args = list(m = nrow(x), r = nrow(x), sampling = "random"),
    max_step = 500, type = "C-classification", scale = TRUE,
    verbose = FALSE, time_limit = 60, start = FALSE),
  ...
)

## S3 method for class 'formula'
cv_misvm(formula, data, cost_seq, n_fold, fold_id, ...)

## S3 method for class 'mi_df'
cv_misvm(x, ...)

Arguments

x

A data.frame, matrix, or similar object of covariates, where each row represents a sample.

y

A numeric, character, or factor vector of bag labels for each instance. Must satisfy length(y) == nrow(x). Ideally, one of the levels is 1, '1', or TRUE, which becomes the positive class; otherwise, a positive class is chosen and a message is printed.

bags

A vector specifying which bag each instance belongs to. Can be a string, numeric, or factor.

cost_seq

A sequence of values for the cost argument of misvm() to evaluate (default 2^(-2:2)).

n_fold

The number of folds (default 5). If this is specified, fold_id need not be specified.

fold_id

The fold id for each instance. Care must be taken to ensure that the ids respect the bag structure (all instances of a bag share one fold) to avoid information leakage. If n_fold is specified, fold_id will be computed automatically.
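When supplying fold_id yourself, one way to respect the bag structure is to assign folds to whole bags rather than to individual instances. The base-R sketch below shows this under that assumption; make_bag_folds is a hypothetical helper, not a mildsvm function.

```r
# Hypothetical helper (not part of mildsvm): draw a fold for each bag, then
# give every instance the fold of its bag, so no bag is split across folds.
make_bag_folds <- function(bags, n_fold) {
  bag_ids <- unique(as.character(bags))
  # balanced fold labels 1..n_fold, randomly shuffled across bags
  bag_fold <- sample(rep_len(seq_len(n_fold), length(bag_ids)))
  names(bag_fold) <- bag_ids
  unname(bag_fold[as.character(bags)])
}

set.seed(1)
bags <- rep(c("a", "b", "c", "d", "e", "f"), each = 3)
fold_id <- make_bag_folds(bags, n_fold = 3)
# each bag should map to exactly one fold
tapply(fold_id, bags, function(f) length(unique(f)))
```

The result has one entry per instance, which is the shape the fold_id argument expects.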

method

The algorithm to use in fitting (default 'heuristic'). When method = 'heuristic', an algorithm similar to Andrews et al. (2003) is used. When method = 'mip', the novel MIP method is used. When method = 'qp-heuristic', the heuristic algorithm is computed using the dual SVM. See ?misvm for details.

weights

A named vector, or TRUE, controlling the weight of the cost parameter for each possible y value. Weights multiply the cost vector. If TRUE, weights are calculated from the inverse counts of instances with a given label, where only one positive instance is counted per bag. Otherwise, names must match the levels of y.
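As a hedged sketch of the weights = TRUE rule described above, the base-R code below computes inverse-count weights for 0/1 labels, counting every negative instance but only one instance per positive bag. The helper inverse_count_weights is hypothetical, and rescaling so the positive class has weight 1 is an assumption; misvm() may normalize the weights differently.

```r
# Hypothetical helper (not part of mildsvm). Assumes y uses 0/1 labels with
# 1 as the positive class. Negative instances are counted individually;
# each positive bag contributes a single count.
# ASSUMPTION: weights are rescaled so the positive class has weight 1.
inverse_count_weights <- function(y, bags) {
  n_pos <- length(unique(bags[y == 1]))  # number of positive bags
  n_neg <- sum(y == 0)                   # number of negative instances
  c("0" = n_pos / n_neg, "1" = 1)        # proportional to 1/n_neg, 1/n_pos
}

y    <- c(0, 0, 0, 0, 1, 1, 1, 1, 1, 1)
bags <- c(1, 1, 2, 2, 3, 3, 3, 4, 4, 4)
# 4 negative instances and 2 positive bags: weight 0.5 for "0", 1 for "1"
inverse_count_weights(y, bags)
```

A named vector of this form, with names matching the levels of y, is what the weights argument accepts when it is not TRUE.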

control

A list of additional parameters passed to the method that control computation, with the following components:

  • kernel either a character string that describes the kernel ('linear' or 'radial') or a kernel matrix at the instance level.

  • sigma argument needed for the radial basis kernel.

  • nystrom_args a list of parameters to pass to kfm_nystrom(). This is used when method = 'mip' and kernel = 'radial' to generate a Nystrom approximation of the kernel features.

  • max_step argument used when method = 'heuristic'. Maximum number of iterations for the heuristic algorithm.

  • type argument used when method = 'heuristic'. The type argument is passed to e1071::svm().

  • scale argument used for all methods. A logical for whether to rescale the input before fitting.

  • verbose argument used when method = 'mip'. Whether to print output to the console.

  • time_limit argument used when method = 'mip'. FALSE, or a time limit (in seconds) passed to the gurobi() parameters. If FALSE, no time limit is given.

  • start argument used when method = 'mip'. If TRUE, the MIP program will be warm-started with the solution from method = 'qp-heuristic' to potentially improve speed.

...

Arguments passed to or from other methods.

formula

A formula with specification mi(y, bags) ~ x, which uses the mi() function to create the bag-instance structure. This argument is an alternative to the x, y, and bags arguments, but requires the data argument. See examples.

data

If formula is provided, a data.frame or similar from which formula elements will be extracted.

Value

An object of class cv_misvm. The object contains the following components:

  • misvm_fit: A fit object of class misvm trained on the full data with the cross-validated choice of cost parameter. See misvm() for details.

  • cost_seq: The input sequence of cost arguments.

  • cost_aucs: Estimated AUC for the models trained with each cost_seq value. Each value is the average AUC across the fold models for that cost, excluding any folds that don't have both levels of y in the validation set.

  • best_cost: The optimal choice of cost parameter, chosen as the one with the maximum AUC. If there are ties, the smallest cost with the maximum AUC is chosen.
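The rule behind cost_aucs and best_cost can be sketched in base R as follows. This is an illustration under assumptions, not mildsvm's implementation: select_best_cost is a hypothetical helper, and it expects a matrix of per-fold AUCs with NA for folds whose validation set lacks one of the two classes.

```r
# Hypothetical helper (not part of mildsvm).
# fold_aucs: matrix of AUCs, rows = cost values, columns = folds;
# NA marks folds whose validation set lacked one of the classes.
select_best_cost <- function(cost_seq, fold_aucs) {
  cost_aucs <- rowMeans(fold_aucs, na.rm = TRUE)  # skip degenerate folds
  best_auc <- max(cost_aucs, na.rm = TRUE)
  # ties: take the smallest cost that reaches the maximum AUC
  best_cost <- min(cost_seq[which(cost_aucs == best_auc)])
  list(cost_aucs = cost_aucs, best_cost = best_cost)
}

cost_seq <- 2^(-2:2)
fold_aucs <- rbind(c(0.75,  0.5,   NA),
                   c(0.75,  0.75,  0.75),
                   c(0.875, 0.625, NA),
                   c(0.5,   0.5,   0.5),
                   c(1,     0.5,   NA))
# costs 0.5, 1 and 4 tie at a mean AUC of 0.75, so best_cost is 0.5
select_best_cost(cost_seq, fold_aucs)
```

The exact equality test works here because the example AUCs are exactly representable; with arbitrary AUC values a small tolerance may be preferable.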

Methods (by class)

  • default: Method for data.frame-like objects

  • formula: Method for passing formula

  • mi_df: Method for mi_df objects, automatically handling bag names, labels, and all covariates.

Author(s)

Sean Kent, Yifei Liu

See Also

misvm() for fitting without cross-validation.

Examples

set.seed(8)
mil_data <- generate_mild_df(nbag = 20,
                             positive_prob = 0.15,
                             dist = rep("mvnormal", 3),
                             mean = list(rep(1, 10), rep(2, 10)),
                             sd_of_mean = rep(0.1, 3))
df <- build_instance_feature(mil_data, seq(0.05, 0.95, length.out = 10))
cost_seq <- 2^seq(-5, 7, length.out = 3)

# Heuristic method
mdl1 <- cv_misvm(x = df[, 4:123], y = df$bag_label,
                 bags = df$bag_name, cost_seq = cost_seq,
                 n_fold = 3, method = "heuristic")
# Formula interface (uses the default 'heuristic' method)
mdl2 <- cv_misvm(mi(bag_label, bag_name) ~ X1_mean + X2_mean + X3_mean, data = df,
                 cost_seq = cost_seq, n_fold = 3)

if (require(gurobi)) {
  # solve using the MIP method
  mdl3 <- cv_misvm(x = df[, 4:123], y = df$bag_label,
                   bags = df$bag_name, cost_seq = cost_seq,
                   n_fold = 3, method = "mip")
}

# raw bag-level predictions from the heuristic fit
predict(mdl1, new_data = df, type = "raw", layer = "bag")

# summarize predictions at the bag layer
suppressWarnings(library(dplyr))
df %>%
  bind_cols(predict(mdl2, df, type = "class")) %>%
  bind_cols(predict(mdl2, df, type = "raw")) %>%
  distinct(bag_name, bag_label, .pred_class, .pred)


mildsvm documentation built on July 14, 2022, 9:08 a.m.