predict.misvm: Predict method for 'misvm' object

View source: R/misvm.R

predict.misvm    R Documentation

Predict method for misvm object

Description

Predict bag-level or instance-level classes or raw scores from a fitted misvm object.

Usage

## S3 method for class 'misvm'
predict(
  object,
  new_data,
  type = c("class", "raw"),
  layer = c("bag", "instance"),
  new_bags = "bag_name",
  ...
)

Arguments

object

An object of class misvm.

new_data

A data frame to predict from. It must contain all of the features that the model was originally fitted with.

type

If 'class', return predicted classes as -1 or +1, obtained by thresholding the raw scores at 0. If 'raw', return the raw predicted scores.

layer

If 'bag', return predictions at the bag level. If 'instance', return predictions at the instance level.

new_bags

A character string or character vector. Either a single string giving the name of the column in new_data that holds the bag names (default 'bag_name'), or a vector of length nrow(new_data) giving the bag name for each row.

...

Arguments passed to or from other methods.
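The interaction of new_bags, layer, and type can be illustrated in base R without fitting a model. The sketch below is not the package's implementation: the raw scores are made-up values, resolve_bags() is a hypothetical helper, and two rules are assumptions rather than documented behavior. These are that a bag's score is the maximum of its instance scores (the usual MI-SVM convention) and that a score of exactly 0 maps to -1.

```r
# Made-up example: 6 rows of new_data and hypothetical raw instance scores
new_data <- data.frame(
  bag_name = c("a", "a", "a", "b", "b", "c"),
  stringsAsFactors = FALSE
)
raw_scores <- c(-1.2, 0.4, -0.3, -0.8, -0.1, -2.0)

# Hypothetical helper (not part of mildsvm): new_bags is either a single
# column name in new_data or a vector with one bag name per row
resolve_bags <- function(new_data, new_bags = "bag_name") {
  if (length(new_bags) == 1) new_data[[new_bags]] else new_bags
}
bags <- resolve_bags(new_data)  # same result as passing the vector directly

# layer = "instance": one raw score per row of new_data
instance_raw <- raw_scores

# layer = "bag": assumed MI-SVM rule, a bag's score is the max of its instances
bag_raw <- tapply(raw_scores, bags, max)

# type = "class": threshold at 0; a score of exactly 0 is assumed to map to -1
to_class <- function(score) ifelse(score > 0, 1, -1)
instance_class <- to_class(instance_raw)
bag_class <- to_class(bag_raw)
```

Under the max rule, bag "a" is classified +1 because one of its instances scores above 0, even though the other two score below 0.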

Details

When the object was fitted using the formula method, the parameter new_bags is not necessary, as long as the column names in new_data match those used in the original function call.

Value

A tibble with nrow(new_data) rows. If type = 'class', the tibble will have a column .pred_class. If type = 'raw', the tibble will have a column .pred.
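Because the result always has nrow(new_data) rows, bag-level predictions are presumably repeated for every row that belongs to the same bag (the dplyr example below relies on this when it calls distinct()). The following base R sketch shows that expansion using made-up bag scores. It uses data.frame() as a stand-in for a tibble, and it is an illustration of the output shape under that assumption, not the package's code.

```r
# Made-up per-row bag names and bag-level raw scores
bags <- c("a", "a", "a", "b", "b", "c")
bag_raw <- c(a = 0.4, b = -0.1, c = -2.0)

# Repeat each bag's score for every row in that bag,
# so there is one prediction per row of new_data
pred <- unname(bag_raw[bags])

# type = "raw": a single column named .pred
out_raw <- data.frame(.pred = pred)

# type = "class": a single column named .pred_class (threshold at 0)
out_class <- data.frame(.pred_class = ifelse(pred > 0, 1, -1))
```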

Author(s)

Sean Kent

See Also

  • misvm() for fitting the misvm object.

  • cv_misvm() for fitting the misvm object with cross-validation.

Examples

library(mildsvm)
mil_data <- generate_mild_df(nbag = 20,
                             positive_prob = 0.15,
                             sd_of_mean = rep(0.1, 3))
df1 <- build_instance_feature(mil_data, seq(0.05, 0.95, length.out = 10))
mdl1 <- misvm(x = df1[, 4:63], y = df1$bag_label,
              bags = df1$bag_name, method = "heuristic")

predict(mdl1, new_data = df1, type = "raw", layer = "bag")

# summarize predictions at the bag layer
library(dplyr)
df1 %>%
  bind_cols(predict(mdl1, df1, type = "class")) %>%
  bind_cols(predict(mdl1, df1, type = "raw")) %>%
  distinct(bag_name, bag_label, .pred_class, .pred)


mildsvm documentation built on July 14, 2022, 9:08 a.m.