Nothing
#' Model Prediction for Classifiers Based on Association Rules
#'
#' Predicts classes for new data using a CBA classifier.
#'
#' The classification method stored in the classifier (`object$method`, one of
#' `"first"`, `"majority"`, `"weighted"` or `"logit"`; partial matching is
#' allowed) is used. If no method is stored, `"majority"` is used. `"first"`
#' predicts the class of the first matching rule; the other methods score each
#' class using the weights of all matching rules. For these score-based
#' methods, ties are broken in favor of the classifier's default class.
#'
#' @aliases predict
#' @name predict.CBA
#'
#' @family classifier
#'
#' @param object An object of class [CBA].
#' @param newdata A data.frame or [arules::transactions] containing rows of new entries
#' to be classified.
#' @param type Predict `"class"` labels. Some classifiers can also return
#' `"score"`, a matrix of class scores.
#' @param \dots Additional arguments are ignored.
#' @return For `type = "class"`, a factor vector with the classification result
#' (if the classifier contains no rules, the default class is repeated for
#' every entry instead). For `type = "score"`, a numeric matrix with one row
#' per entry in `newdata` and one column per class level.
#' @author Michael Hahsler
#' @examples
#' data("iris")
#'
#' train_id <- sample(seq_len(nrow(iris)), 130)
#' iris_train <- iris[train_id, ]
#' iris_test <- iris[-train_id, ]
#'
#' cl <- CBA(Species ~., iris_train)
#' pr <- predict(cl, iris_test)
#' pr
#'
#' accuracy(pr, response(Species ~., iris_test))
#' @method predict CBA
#' @export
predict.CBA <- function(object,
                        newdata,
                        type = c("class", "score"),
                        ...) {
  type <- match.arg(type)

  # Resolve the classification method (partial matching is allowed).
  method <- object$method
  if (is.null(method)) {
    method <- "majority"
  }
  methods <- c("first", "majority", "weighted", "logit")
  m <- pmatch(method, methods)
  if (is.na(m)) {
    stop("Unknown method")
  }
  method <- methods[m]

  # No rules: always predict the default class.
  ### FIXME: Implement score.
  ### FIXME: class should return a factor.
  if (length(object$rules) == 0) {
    if (type == "class") {
      return(rep(object$default, nrow(newdata)))
    }
    stop("prediction type 'score' is not yet implemented for classifier with no rules.")
  }

  ### convert data
  if (is.null(object$discretization) &&
      !is(newdata, "transactions")) {
    stop(
      "Classifier does not contain discretization information. New data needs to be in the form of transactions. Check ? discretizeDF."
    )
  }
  newdata <- prepareTransactions(
    object$formula,
    newdata,
    disc.method = object$discretization,
    match = object$rules
  )

  # Logical matrix: which rules (rows) match which transactions (columns).
  # A sparse matrix only pays off for more than 150000 entries.
  rulesMatchLHS <- is.subset(
    lhs(object$rules),
    newdata,
    sparse = (length(newdata) * length(object$rules) > 150000)
  )
  dimnames(rulesMatchLHS) <- list(NULL, NULL)

  # class label of each rule
  RHSclass <- response(object$formula, object$rules)
  n_rules <- length(object$rules)
  n_classes <- nlevels(RHSclass)

  # classify using the first matching rule
  if (method == "first") {
    if (type == "score") {
      stop(
        "prediction type 'score' is not supported for CBA classifiers using classification method 'first' (matching rule)."
      )
    }
    # index of the first rule matching each transaction (NA if none matches)
    w <- apply(rulesMatchLHS, MARGIN = 2, FUN = function(x) which(x)[1])
    output <- RHSclass[w]
    unmatched <- is.na(w)
    if (any(unmatched)) {
      default <- object$default
      # a NULL default is treated like a missing (NA) default
      if (is.null(default) || is.na(default)) {
        warning("Classifier has no default class when no rules matches! Producing NAs!")
        default <- NA
      }
      output[unmatched] <- default
    }
    # output keeps the factor levels of the rules' class labels
    return(output)
  }

  # For each transaction, score every class using the weights of the matching
  # rules (majority or weighted majority).
  weights <- object$weights
  if (is.character(weights)) {
    weights <- quality(object$rules)[[weights, exact = FALSE]]
  }
  if (is.null(weights) || method == "majority") {
    weights <- rep(1, n_rules)
  }

  # Transform a weight vector into a rules x classes matrix where each rule
  # contributes its weight only to the column of its own class. The matrix is
  # built explicitly so that a single rule still yields a 1 x classes matrix
  # (sapply() would collapse it to a plain vector).
  if (!is.matrix(weights)) {
    class_id <- as.integer(RHSclass)
    weights <- matrix(
      unlist(lapply(seq_len(n_classes), function(i) {
        w <- weights
        w[class_id != i] <- 0
        w
      })),
      nrow = length(weights),
      ncol = n_classes
    )
  }
  if (nrow(weights) != n_rules || ncol(weights) != n_classes) {
    stop("number of weights does not match number of rules/classes.")
  }

  if (is.null(object$best_k)) {
    ### score is the sum of the weights of all matching rules
    # class bias: one value per class (vector or one-column matrix)
    bias <- object$bias
    if (!is.null(bias) && NROW(bias) != n_classes) {
      stop("number of class bias values does not match number of rules/classes.")
    }
    # sum score and add bias; as.matrix() returns a base matrix also when
    # rulesMatchLHS is sparse
    scores <- as.matrix(t(crossprod(weights, rulesMatchLHS)))
    if (!is.null(bias)) {
      scores <- sweep(scores, 2, bias, "+")
    }
  } else {
    ### score is the average of the top-N matching rules (see CPAR paper by Yin and Han, 2003)
    top_k <- seq_len(min(object$best_k, n_rules))
    per_trans <- apply(
      rulesMatchLHS,
      MARGIN = 2,
      FUN = function(m) {
        m_weights <- weights * m
        # sort each class column; matrix() restores the shape for a single rule
        m_weights <- matrix(
          apply(m_weights, MARGIN = 2, sort, decreasing = TRUE),
          nrow = n_rules
        )[top_k, , drop = FALSE]
        m_weights[m_weights == 0] <- NA
        score <- colMeans(m_weights, na.rm = TRUE)
        score[is.na(score)] <- 0
        score
      }
    )
    # apply() returns classes x transactions (or a plain vector for a single
    # class); reshape into transactions x classes in both cases.
    scores <- matrix(
      per_trans,
      nrow = ncol(rulesMatchLHS),
      ncol = n_classes,
      byrow = TRUE
    )
  }
  colnames(scores) <- levels(RHSclass)

  # logit: each row becomes exp(s) / (1 + sum(exp(s)))
  if (method == "logit") {
    exp_scores <- exp(scores)
    scores <- exp_scores / (1 + rowSums(exp_scores))
  }

  if (type == "score") {
    return(scores)
  }

  # index of the best-scoring class per transaction (first one on ties)
  best <- apply(scores, MARGIN = 1, which.max)

  # Make sure the default class wins ties. Adding a fixed .Machine$double.eps
  # to its score is lost to rounding once scores reach 2 (e.g. integer vote
  # counts), so compare with the winning score using a tolerance relative to
  # the score's magnitude instead.
  default_level <- match(object$default, levels(RHSclass))[1]
  if (!is.na(default_level)) {
    best_score <- scores[cbind(seq_along(best), best)]
    tol <- .Machine$double.eps * pmax(1, abs(best_score))
    best[scores[, default_level] >= best_score - tol] <- default_level
  }

  factor(best, levels = seq_len(n_classes), labels = levels(RHSclass))
}
#' @rdname predict.CBA
#' @param pred,true two factors with the same levels representing the predictions and the ground truth (e.g., obtained with [response()]).
#' @export
accuracy <- function(pred, true) {
  # Proportion of correct predictions. Pairs where the prediction or the
  # ground truth is missing are ignored (the same pairs table() drops).
  if (!identical(levels(pred), levels(true))) {
    stop("pred and true need to be factors with matching levels!")
  }
  if (length(pred) != length(true)) {
    stop("pred and true need to have the same length!")
  }
  # Compare element-wise instead of summing diag(table(pred, true)): for
  # non-factor inputs (levels both NULL) the table's rows and columns need
  # not line up, so its diagonal would count mismatches as hits.
  ok <- !is.na(pred) & !is.na(true)
  sum(pred[ok] == true[ok]) / sum(ok)
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.