predict.misvm_orova: Predict method for 'misvm_orova' object

View source: R/misvm_orova.R

predict.misvm_orova    R Documentation

Predict method for a misvm_orova object

Description

Predict method for a misvm_orova object. Predictions use the K fitted MI-SVM models, one per class. For class predictions, the method returns the class whose MI-SVM model has the highest raw predicted score. For raw predictions, the scores of all K models are returned, with one column for each model.
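The class rule above amounts to a row-wise argmax over the K raw-score columns. The base R sketch below uses made-up scores for K = 3 and assumes the columns follow the .pred_{class_name} naming described under Value; it illustrates the rule and is not the package's internal code.

```r
# Hypothetical raw scores from K = 3 models for 4 rows of new_data.
# Column names follow the .pred_{class_name} pattern from the Value section.
raw <- data.frame(
  .pred_1 = c( 0.8, -0.2, -1.1, 0.1),
  .pred_2 = c(-0.5,  0.4, -0.3, 0.3),
  .pred_3 = c(-1.2, -0.9,  0.6, 0.2)
)

# Recover the class labels ("1", "2", "3") from the column names.
class_names <- sub("^\\.pred_", "", names(raw))

# For each row, pick the class whose model has the highest raw score.
# ties.method = "first" makes ties deterministic (the default is random).
winner <- max.col(as.matrix(raw), ties.method = "first")
pred_class <- factor(class_names[winner], levels = class_names)
pred_class
```

Here the rows resolve to classes 1, 2, 3 and 2, because each of those columns holds the largest score in its row.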

Usage

## S3 method for class 'misvm_orova'
predict(
  object,
  new_data,
  type = c("class", "raw"),
  layer = c("bag", "instance"),
  new_bags = "bag_name",
  ...
)

Arguments

object

An object of class misvm_orova

new_data

A data frame to predict from. It must contain all of the features that the model was originally fitted with.

type

If 'class', return the predicted class, namely the class whose individual model gives the highest output. If 'raw', return the raw predicted score from each model.

layer

If 'bag', return predictions at the bag level. If 'instance', return predictions at the instance level.
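To make the difference between the two layers concrete, here is a minimal base R sketch. It assumes the standard MI-SVM rule that a bag's score is the maximum of its instances' scores (an assumption about this method, not something stated on this page) and repeats that bag score on every row, so the output keeps one entry per row of new_data. The scores and bag labels are hypothetical.

```r
# Hypothetical instance-level raw scores from one MI-SVM model,
# together with the bag that each row (instance) belongs to.
inst_score <- c(-0.7, 0.4, -1.2, -0.3, -0.9)
bag <- c("a", "a", "b", "b", "b")

# Assumed bag-level rule: a bag's score is the max over its instances.
# ave() returns the group maximum repeated on every row of that group,
# so the result has the same length as inst_score.
bag_score <- ave(inst_score, bag, FUN = max)
bag_score
```

Under this assumption, bag "a" scores 0.4 on both of its rows and bag "b" scores -0.3 on all three of its rows, while the instance layer would simply return inst_score unchanged.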

new_bags

A character string or character vector. Either a single string giving the name of the column in new_data that holds the bag names (default 'bag_name'), or a vector of length nrow(new_data) giving the bag name for each row.
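The two accepted forms of new_bags can be sketched as follows. resolve_bags() is a hypothetical helper written for illustration, not a mildsvm function; it only mirrors the behaviour described above, and the error messages are made up.

```r
# Hypothetical helper (not part of mildsvm) mirroring the two forms of new_bags.
resolve_bags <- function(new_data, new_bags = "bag_name") {
  if (is.character(new_bags) && length(new_bags) == 1) {
    # A single string is read as the name of the bag column in new_data.
    if (!new_bags %in% names(new_data)) {
      stop("column '", new_bags, "' not found in new_data")
    }
    return(new_data[[new_bags]])
  }
  # Otherwise, expect one bag name per row of new_data.
  if (length(new_bags) != nrow(new_data)) {
    stop("new_bags must have length nrow(new_data)")
  }
  new_bags
}

df <- data.frame(
  bag_name = c("a", "a", "b"),
  V1 = c(0.1, 0.5, -0.2),
  stringsAsFactors = FALSE
)
resolve_bags(df)                               # read from the bag_name column
resolve_bags(df, new_bags = c("x", "x", "y"))  # supplied directly, one per row
```

Note that in this sketch a length-one string is always treated as a column name, which is the reading the argument description gives.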

...

Arguments passed to or from other methods.

Details

When the object was fitted using the formula method, the new_bags argument is not necessary, as long as the column names in new_data match those used in the original function call.

Value

A tibble with nrow(new_data) rows. If type = 'class', the tibble will have a column .pred_class. If type = 'raw', the tibble will have K columns .pred_{class_name} corresponding to the raw predictions of the K models.

Author(s)

Sean Kent

See Also

misvm_orova() for fitting the misvm_orova object.

Examples

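library(mildsvm)

# Example ordinal MIL data: columns 3:7 hold the features
data("ordmvnorm")
x <- ordmvnorm[, 3:7]
y <- ordmvnorm$bag_label
bags <- ordmvnorm$bag_name

# Fit the K MI-SVM models (one per class)
mdl1 <- misvm_orova(x, y, bags)

# Summarize predictions at the bag layer (the default layer), so rows
# from the same bag share predictions and distinct() keeps one row per bag
library(dplyr)
df1 <- bind_cols(y = y, bags = bags, as.data.frame(x))
df1 %>%
  bind_cols(predict(mdl1, df1, new_bags = bags, type = "class")) %>%
  bind_cols(predict(mdl1, df1, new_bags = bags, type = "raw")) %>%
  select(-starts_with("V")) %>%  # drop the feature columns
  distinct()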


mildsvm documentation built on July 14, 2022, 9:08 a.m.