R/predict.R

Defines functions accuracy predict.CBA

Documented in accuracy predict.CBA

#' Model Prediction for Classifiers Based on Association Rules
#'
#' Predicts classes for new data using a CBA classifier.
#'
#' The classifier's `method` element selects the strategy: `"first"` uses the
#' class of the first matching rule, while `"majority"`, `"weighted"` and
#' `"logit"` aggregate the (weighted) votes of all matching rules into one
#' score per class. If the classifier has a `best_k` element, the score of a
#' class is the average weight of its top `best_k` matching rules instead of
#' the sum. For score ties, the default class (if it is one of the class
#' levels) is predicted.
#'
#' @aliases predict
#' @name predict.CBA
#'
#' @param object An object of class [CBA].
#' @param newdata A data.frame or [arules::transactions] containing rows of new entries
#' to be classified.
#' @param type Predict `"class"` labels. Some classifiers can also return
#' `"score"`, a matrix of class scores. Scores are not available for
#' classifiers using method `"first"` or for classifiers without rules.
#' @param \dots Additional arguments are ignored.
#' @return For `type = "class"`, a factor vector with the classification
#' result (a classifier without rules returns the default class repeated for
#' each row of `newdata`). For `type = "score"`, a matrix with one row per
#' entry in `newdata` and one column per class level.
#' @author Michael Hahsler
#' @examples
#' data("iris")
#'
#' train_id <- sample(seq_len(nrow(iris)), 130)
#' iris_train <- iris[train_id, ]
#' iris_test <- iris[-train_id, ]
#'
#' cl <- CBA(Species ~., iris_train)
#' pr <- predict(cl, iris_test)
#' pr
#'
#' accuracy(pr, response(Species ~., iris_test))
#' @method predict CBA
#' @export
predict.CBA <-
  function(object,
    newdata,
    type = c("class", "score"),
    ...) {
    type <- match.arg(type)

    method <- object$method
    if (is.null(method))
      method <- "majority"

    methods <- c("first", "majority", "weighted", "logit")
    m <- pmatch(method, methods)
    if (is.na(m))
      stop("Unknown method")
    method <- methods[m]

    default <- object$default
    # TRUE when the classifier has no usable default class (missing or NA)
    no_default <- is.null(default) || isTRUE(is.na(default))

    n_rules <- length(object$rules)

    # no rules. Always predict the default class
    ### FIXME: Implement score.
    ### FIXME: class should return a factor.
    if (n_rules == 0) {
      if (type == "class")
        return(rep(default, nrow(newdata)))
      # score
      stop("prediction type 'score' is not yet implemented for classifier with no rules.")
    }

    ### convert data
    if (is.null(object$discretization) &&
        !is(newdata, "transactions"))
      stop(
        "Classifier does not contain discretization information. New data needs to be in the form of transactions. Check ? discretizeDF."
      )

    newdata <-
      prepareTransactions(
        object$formula,
        newdata,
        disc.method = object$discretization,
        match = object$rules
      )

    # Matrix of which rules match which transactions (sparse is only better for more
    # than 150000 entries). Rows are rules, columns are transactions.
    rulesMatchLHS <- is.subset(lhs(object$rules), newdata,
      sparse = (length(newdata) * n_rules > 150000))
    dimnames(rulesMatchLHS) <- list(NULL, NULL)

    # find class label for each rule
    RHSclass <- response(object$formula, object$rules)
    n_classes <- length(levels(RHSclass))

    # classify using first match
    if (method == "first") {
      if (type == "score")
        stop(
          "prediction type 'score' is not supported for CBA classifiers using classification method 'first' (matching rule)."
        )

      # index of the first matching rule per transaction (NA if none matches)
      w <-
        apply(
          rulesMatchLHS,
          MARGIN = 2,
          FUN = function(x)
            which(x)[1]
        )
      output <- RHSclass[w]

      unmatched <- is.na(w)
      if (any(unmatched) && no_default)
        warning("Classifier has no default class when no rules matches! Producing NAs!")
      # a NULL default cannot be assigned; unmatched entries are already NA
      if (!is.null(default))
        output[unmatched] <- default

      return(output)
    }

    # For each transaction, if it is matched by any rule, classify it using
    # the majority, weighted majority

    # weights: a quality measure name, a vector with one weight per rule, or
    # a rules x classes matrix
    weights <- object$weights
    if (is.character(weights))
      weights <- quality(object$rules)[[weights, exact = FALSE]]
    # majority voting counts every rule equally
    if (is.null(weights) || method == "majority")
      weights <- rep(1, n_rules)

    # transform weight vector into a rules x classes matrix where each rule
    # only carries weight in the column of the class it predicts
    if (!is.matrix(weights)) {
      if (length(weights) != n_rules)
        stop("number of weights does not match number of rules/classes.")

      rule_class <- as.integer(RHSclass)
      weight_vec <- weights
      # fill a preallocated matrix; sapply() would collapse to a plain vector
      # for a classifier with a single rule
      weights <- matrix(0, nrow = n_rules, ncol = n_classes)
      for (i in seq_len(n_classes)) {
        w <- weight_vec
        w[rule_class != i] <- 0
        weights[, i] <- w
      }
    }

    if (nrow(weights) != n_rules ||
        ncol(weights) != n_classes)
      stop("number of weights does not match number of rules/classes.")

    if (is.null(object$best_k)) {
      ### score is the sum of the weights of all matching rules

      # class bias (NROW() accepts a plain vector as well as a column matrix)
      bias <- object$bias

      if (!is.null(bias) && NROW(bias) != n_classes)
        stop("number of class bias values does not match number of rules/classes.")

      # sum score and add bias
      scores <- t(crossprod(weights, rulesMatchLHS))
      if (!is.null(bias))
        scores <- sweep(scores, 2, bias, "+")
    } else {
      ### score is the average of the top-N matching rules (see CPAR paper by Yin and Han, 2003)

      top_k <- min(object$best_k, n_rules)

      scores <- t(apply(
        rulesMatchLHS,
        MARGIN = 2,
        FUN = function(m) {
          # weights of the rules matching this transaction (0 for the rest)
          m_weights <- weights * m
          sorted <- apply(m_weights,
            MARGIN = 2,
            sort,
            decreasing = TRUE)
          # apply() drops the matrix shape when there is only a single rule
          dim(sorted) <- dim(m_weights)
          sorted <- sorted[seq_len(top_k), , drop = FALSE]
          # zero weights mark non-matching rules; exclude them from the mean
          sorted[sorted == 0] <- NA
          score <- colMeans(sorted, na.rm = TRUE)
          score[is.na(score)] <- 0
          score
        }
      ))
    }

    colnames(scores) <- levels(RHSclass)

    if (method == "logit")
      scores <- exp(scores) / (1 + rowSums(exp(scores)))

    if (type == "score")
      return(scores)

    # dense copy so that the indexing below also works for sparse score matrices
    dense_scores <- as.matrix(scores)
    winner <- apply(dense_scores, MARGIN = 1, which.max)

    # make sure default wins for ties. Adding .Machine$double.eps to the
    # default's score is lost in rounding for scores >= 2, so compare the
    # default's score with the best score of each row instead.
    if (!is.null(default)) {
      defaultLevel <- which(default == levels(RHSclass))
      if (length(defaultLevel) == 1L && is.atomic(winner)) {
        best_score <-
          dense_scores[cbind(seq_len(nrow(dense_scores)), winner)]
        winner[which(dense_scores[, defaultLevel] >= best_score)] <-
          defaultLevel
      }
    }

    factor(
      winner,
      levels = seq_len(n_classes),
      labels = levels(RHSclass)
    )
  }


#' @rdname predict.CBA
#' @param pred,true two factors with the same levels representing the
#' predictions and the ground truth (e.g., obtained with [response()]).
#' @export
accuracy <- function(pred, true) {
  # Non-factors both have NULL levels and would pass the level comparison,
  # but table() would then build a separate level set for each argument and
  # the diagonal would pair up unrelated labels.
  if (!is.factor(pred) || !is.factor(true) ||
      !identical(levels(pred), levels(true)))
    stop("pred and true need to be factors with matching levels!")

  # confusion table; the diagonal holds the correctly classified entries
  tbl <- table(pred, true)
  sum(diag(tbl)) / sum(tbl)
}

Try the arulesCBA package in your browser

Any scripts or data that you put into this service are public.

arulesCBA documentation built on Aug. 20, 2022, 1:06 a.m.