smm: Fit SMM model to the data

View source: R/smm.R


Description

Fits a support measure machine (SMM), an algorithm suited to multiple instance learning. The algorithm computes a kernel matrix between the empirical measures of the instances using kernel mean embedding. Pass the data with one row per sample, where each sample belongs to an instance (a set of samples). smm() computes a kernel on the instances and passes it to kernlab::ksvm() to train the SVM model.
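The kernel mean embedding idea can be illustrated with a minimal base-R sketch. This is not the mildsvm implementation. It assumes a Gaussian kernel in kernlab's parameterization, exp(-sigma * ||u - v||^2), and uses the plain "mean map" kernel between two sets of samples: the average of the instance-level kernel over all cross pairs. The helper names rbf_kernel(), mean_embedding_kernel(), and instance_kernel() are hypothetical.

```r
# Gaussian (RBF) kernel matrix between the rows of a and b, using the
# kernlab-style parameterization k(u, v) = exp(-sigma * ||u - v||^2).
rbf_kernel <- function(a, b, sigma) {
  a <- as.matrix(a)
  b <- as.matrix(b)
  # squared Euclidean distances via ||u||^2 + ||v||^2 - 2 u.v
  d2 <- outer(rowSums(a^2), rowSums(b^2), "+") - 2 * tcrossprod(a, b)
  exp(-sigma * pmax(d2, 0))
}

# Kernel between two empirical measures: the inner product of their mean
# embeddings, i.e. the average of k over all pairs of samples.
mean_embedding_kernel <- function(a, b, sigma) {
  mean(rbf_kernel(a, b, sigma))
}

# Instance-by-instance kernel matrix; rows of x are samples and `instances`
# says which instance each sample belongs to.
instance_kernel <- function(x, instances, sigma) {
  x <- as.matrix(x)
  ids <- unique(instances)
  groups <- lapply(ids, function(id) x[instances == id, , drop = FALSE])
  n <- length(ids)
  K <- matrix(0, n, n, dimnames = list(ids, ids))
  for (i in seq_len(n)) {
    for (j in i:n) {
      K[i, j] <- K[j, i] <- mean_embedding_kernel(groups[[i]], groups[[j]], sigma)
    }
  }
  K
}

# Usage: 4 instances of 5 samples each, 2 covariates
set.seed(1)
x <- matrix(rnorm(40), ncol = 2)
inst <- rep(c("a", "b", "c", "d"), each = 5)
K <- instance_kernel(x, inst, sigma = 1 / ncol(x))
```

The resulting 4 x 4 matrix is symmetric and positive semidefinite, so it is the kind of instance-level Gram matrix an SVM solver such as kernlab::ksvm() can consume.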

Usage

## Default S3 method:
smm(
  x,
  y,
  instances,
  cost = 1,
  weights = TRUE,
  control = list(kernel = "radial", sigma = if (is.vector(x)) 1 else 1/ncol(x), scale = TRUE),
  ...
)

## S3 method for class 'formula'
smm(formula, data, instances = "instance_name", ...)

## S3 method for class 'mild_df'
smm(x, ...)

Arguments

x

A data.frame, matrix, or similar object of covariates, where each row represents a sample. If a mild_df object is passed, y and instances are automatically extracted, bags are ignored, and all other columns are used as predictors.

y

A numeric, character, or factor vector of bag labels for each instance. Must satisfy length(y) == nrow(x). It is recommended that one of the levels be 1, '1', or TRUE, which then becomes the positive class; otherwise, a positive class is chosen and a message is printed.

instances

A vector specifying which instance each sample belongs to. Can be a character, numeric, or factor vector.

cost

The cost parameter in SVM, fed to the C argument in kernlab::ksvm().

weights

A named vector, or TRUE, that controls the weight of the cost parameter for each possible y value. Weights multiply the cost parameter. If TRUE, weights are calculated from the inverse counts of instances with a given label, where only one positive instance is counted per bag. Otherwise, the names must match the levels of y.
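A rough base-R sketch of inverse-count weighting follows. It is an assumption-laden illustration, not mildsvm's code: the helper inverse_count_weights() is hypothetical, it counts each instance once by taking the label of its first sample, it ignores the bag-level rule for positives, and the rescaling so the weights sum to the number of classes is an arbitrary choice. The effective per-class cost would then be cost * weight.

```r
# Hypothetical helper: inverse-count weights per label, counting each
# instance once (all samples of an instance share its label).
inverse_count_weights <- function(y, instances) {
  # one label per instance: keep the first sample of each instance
  inst_labels <- y[!duplicated(instances)]
  counts <- table(inst_labels)
  w <- 1 / counts
  # assumed normalization: weights sum to the number of classes
  w <- w * length(w) / sum(w)
  setNames(as.numeric(w), names(counts))
}

# Usage: 3 positive instances and 1 negative instance, 2 samples each
y <- rep(c(1, 1, 1, -1), each = 2)
instances <- rep(1:4, each = 2)
w <- inverse_count_weights(y, instances)
cost <- 1
per_class_cost <- cost * w  # the rarer class gets the larger cost
```

Here the single negative instance receives weight 1.5 and the positive class 0.5, so misclassifying the minority class is penalized three times as heavily.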

control

A list of additional parameters, passed to the method, that control computation. It has the following components:

  • kernel either a character that describes the kernel ('linear' or 'radial') or a kernel matrix at the instance level.

  • sigma the parameter of the radial basis kernel.

  • scale argument used for all methods. A logical for whether to rescale the input before fitting.
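The control components can be illustrated with a small base-R sketch of the two named kernels, the default sigma from the Usage section, and input scaling. The helpers linear_kernel() and radial_kernel() are hypothetical, and the radial form assumes kernlab's rbfdot parameterization exp(-sigma * ||u - v||^2). Note that is.vector() returns FALSE for a data frame, so data-frame input gets the default sigma = 1/ncol(x).

```r
# Hypothetical instance-level kernels matching the two kernel names.
linear_kernel <- function(a, b) tcrossprod(as.matrix(a), as.matrix(b))

radial_kernel <- function(a, b, sigma) {
  a <- as.matrix(a)
  b <- as.matrix(b)
  # squared Euclidean distances between all row pairs
  d2 <- outer(rowSums(a^2), rowSums(b^2), "+") - 2 * tcrossprod(a, b)
  exp(-sigma * pmax(d2, 0))
}

x <- data.frame(x1 = c(0, 1, 2), x2 = c(1, 1, 0))

# default sigma from the Usage section: 1 for a vector, else 1 / ncol(x)
sigma <- if (is.vector(x)) 1 else 1 / ncol(x)

# scale = TRUE: center and standardize each column before computing kernels
xs <- scale(x)

K_rbf <- radial_kernel(xs, xs, sigma)
K_lin <- linear_kernel(xs, xs)
```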

...

Arguments passed to or from other methods.

formula

A formula with specification y ~ x. This argument is an alternative to the x and y arguments, but requires the data and instances arguments. See Examples.

data

If formula is provided, a data.frame or similar object from which the formula elements are extracted.

Value

An object of class smm. The object contains at least the following components:

  • ksvm_fit: A fit of class ksvm from the kernlab package.

  • call_type: A character indicating which method smm() was called with.

  • x: The training data needed for computing the kernel matrix in prediction.

  • features: The names of features used in training.

  • levels: The levels of y that are recorded for future prediction.

  • cost: The cost parameter from function inputs.

  • sigma: The radial basis function kernel parameter.

  • weights: The calculated weights on the cost parameter, if applicable.

  • x_scale: If scale = TRUE, the scaling parameters for new predictions.
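To show why x_scale is kept, here is a base-R sketch of reusing training-set scaling on new data. The list layout of x_scale (center and scale entries) is an assumption for illustration; the actual structure stored in an smm object may differ.

```r
x_train <- data.frame(x1 = c(1, 2, 3), x2 = c(10, 20, 30))
x_new <- data.frame(x1 = c(2, 4), x2 = c(20, 40))

# scale the training data and keep its centering/scaling parameters
xs_train <- scale(x_train)
# hypothetical stand-in for the stored x_scale component
x_scale <- list(center = attr(xs_train, "scaled:center"),
                scale = attr(xs_train, "scaled:scale"))

# apply the *training* parameters to new data, not the new data's own
xs_new <- scale(x_new, center = x_scale$center, scale = x_scale$scale)
```

Reusing the training center and scale keeps new samples on the same axes as the training samples, which matters because the kernel matrix in prediction is computed between new and training data.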

Methods (by class)

  • default: Method for data.frame-like objects

  • formula: Method for passing formula

  • mild_df: Method for mild_df objects. Use the bag_label as y at the instance level, then perform smm() ignoring the MIL structure.
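The reduction performed by the mild_df method can be sketched with a plain data frame. The column layout (bag_label, bag_name, instance_name, then covariates) is an assumed mild_df-like layout for illustration, and the final smm() call is left as a comment.

```r
# Hypothetical data frame laid out like a mild_df
df <- data.frame(
  bag_label = c(1, 1, 1, 1, 0, 0),
  bag_name = c("b1", "b1", "b1", "b1", "b2", "b2"),
  instance_name = c("i1", "i1", "i2", "i2", "i3", "i3"),
  x1 = c(0.5, 0.1, 1.2, 0.9, -0.3, 0.0)
)

y <- df$bag_label              # bag label reused as the instance-level label
instances <- df$instance_name  # samples grouped by instance
# every other column (bag_name excluded) is a predictor
x <- df[, setdiff(names(df), c("bag_label", "bag_name", "instance_name")), drop = FALSE]

# smm(x, y, instances) would then be fit on these, ignoring the bag structure
```

Instance i2 inherits the positive label of bag b1 even if, in truth, only i1 were positive; this is what "ignoring the MIL structure" means in practice.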

Author(s)

Sean Kent, Yifei Liu

References

Muandet, K., Fukumizu, K., Dinuzzo, F., & Schölkopf, B. (2012). Learning from distributions via support measure machines. Advances in Neural Information Processing Systems, 25.

See Also

predict.smm() for prediction on new data.

Examples

library(mildsvm)
set.seed(8)
n_instances <- 10
n_samples <- 20
y <- rep(c(1, -1), each = n_samples * n_instances / 2)
instances <- as.character(rep(1:n_instances, each = n_samples))
x <- data.frame(x1 = rnorm(length(y), mean = 1*(y==1)),
                x2 = rnorm(length(y), mean = 2*(y==1)),
                x3 = rnorm(length(y), mean = 3*(y==1)))

df <- data.frame(instance_name = instances, y = y, x)

mdl <- smm(x, y, instances)
mdl2 <- smm(y ~ ., data = df)

# instance level predictions
suppressPackageStartupMessages(library(dplyr))
df %>%
  dplyr::bind_cols(predict(mdl, type = "raw", new_data = x, new_instances = instances)) %>%
  dplyr::bind_cols(predict(mdl, type = "class", new_data = x, new_instances = instances)) %>%
  dplyr::distinct(instance_name, y, .pred, .pred_class)


mildsvm documentation built on July 14, 2022, 9:08 a.m.