Nothing
#' Weighted confusion matrix
#'
#' This function calculates the weighted confusion matrix from a caret
#' ConfusionMatrix object or a simple matrix, according to one of several
#' weighting schemes, and optionally prints the weighted accuracy score.
#'
#' Every cell of the confusion matrix is multiplied by a weight that depends
#' only on its distance from the main diagonal, i.e. on how many categories
#' apart the predicted and the observed classes are. Correct predictions
#' (distance 0) receive the first weight of the chosen scheme; the cells
#' farthest from the diagonal (distance n - 1) receive the last one.
#'
#' @param m the caret confusion matrix object or simple matrix. It is coerced
#' with `as.matrix()` if it is not already a matrix and must be square.
#'
#' @param weight.type the weighting scheme to be used. Can be one of:
#' "arithmetic" - a decreasing arithmetic progression weighting scheme,
#' "geometric" - a decreasing geometric progression weighting scheme,
#' "normal" - weights drawn from the right tail of a normal distribution,
#' "interval" - weights contained on a user-defined interval,
#' "sin" - a weighting scheme based on a sine function,
#' "tanh" - a weighting scheme based on a hyperbolic tangent function,
#' "custom" - custom weight vector defined by the user.
#'
#' @param weight.penalty determines whether the weights associated with
#' non-diagonal elements generated by the "normal", "arithmetic", "geometric"
#' and "tanh" weight types are non-negative or non-positive values. By default,
#' the value is set to FALSE, which means that generated weights will be
#' non-negative. When TRUE, every off-diagonal weight w is replaced by
#' -(1 - w), while the diagonal weight is set to 1.
#'
#' @param standard.deviation standard deviation of the normal distribution, if
#' the normal distribution weighting scheme is used.
#'
#' @param geometric.multiplier the multiplier used to construct the geometric
#' progression series, if the geometric progression weighting scheme is used.
#' Must be greater than zero.
#'
#' @param sin.high the upper segment of the sine function to be used in the
#' weighting scheme.
#'
#' @param sin.low the lower segment of the sine function to be used in the
#' weighting scheme.
#'
#' @param tanh.decay the decay factor of the hyperbolic tangent weighting
#' function. Higher values increase the rate of decay and place less weight on
#' observations farther away from the correctly predicted category.
#'
#' @param interval.high the upper bound of the weight interval, if the interval
#' weighting scheme is used.
#'
#' @param interval.low the lower bound of the weight interval, if the interval
#' weighting scheme is used.
#'
#' @param custom.weights the vector of custom weights to be applied, if the
#' custom weighting scheme was selected. Element k + 1 is the weight given to
#' cells k categories away from the diagonal. The vector must contain at least
#' "n" values; excess values are ignored.
#'
#' @param print.weighted.accuracy print the weighted accuracy metric, which
#' represents the sum of all weighted confusion matrix cells divided by the
#' total number of observations.
#'
#' @return an n x n weighted confusion matrix, i.e. the element-wise product
#' of `m` and the weight matrix.
#'
#' @details The number of categories "n" should be greater than or equal to 2.
#' An error is raised if `m` is not square, has fewer than 2 categories, if
#' `weight.type` is not one of the supported schemes, or if fewer than "n"
#' custom weights are supplied for the "custom" scheme.
#'
#' @usage wconfusionmatrix(m, weight.type = "arithmetic",
#' weight.penalty = FALSE,
#' standard.deviation = 2,
#' geometric.multiplier = 2,
#' interval.high=1, interval.low = -1,
#' sin.high=1.5*pi, sin.low = 0.5*pi,
#' tanh.decay = 3,
#' custom.weights = NA,
#' print.weighted.accuracy = FALSE)
#'
#' @keywords weighted confusion matrix accuracy score
#'
#' @seealso [weightmatrix()] for the weight matrix used in computations,
#' [balancedaccuracy()] for accuracy metrics designed for imbalanced data.
#'
#' @author Alexandru Monahov, <https://www.alexandrumonahov.eu.org/>
#'
#' @examples
#' m = matrix(c(70,0,0,10,10,0,5,3,2), ncol = 3, nrow=3)
#' wconfusionmatrix(m, weight.type="arithmetic", print.weighted.accuracy = TRUE)
#' wconfusionmatrix(m, weight.type="geometric", print.weighted.accuracy = TRUE)
#' wconfusionmatrix(m, weight.type="interval", print.weighted.accuracy = TRUE)
#' wconfusionmatrix(m, weight.type="normal", print.weighted.accuracy = TRUE)
#' wconfusionmatrix(m, weight.type="sin", print.weighted.accuracy = TRUE)
#' wconfusionmatrix(m, weight.type="tanh", print.weighted.accuracy = TRUE)
#' wconfusionmatrix(m, weight.type= "custom", custom.weights = c(1,0.1,0),
#' print.weighted.accuracy = TRUE)
#'
#' @export
wconfusionmatrix <- function(m, weight.type = "arithmetic",
                             weight.penalty = FALSE,
                             standard.deviation = 2,
                             geometric.multiplier = 2,
                             interval.high = 1, interval.low = -1,
                             sin.high = 1.5 * pi, sin.low = 0.5 * pi,
                             tanh.decay = 3,
                             custom.weights = NA,
                             print.weighted.accuracy = FALSE) {
  if (!is.matrix(m)) {
    m <- as.matrix(m)
  }
  n <- nrow(m)
  if (n != ncol(m)) {
    stop("m must be a square (n x n) confusion matrix.", call. = FALSE)
  }
  if (n < 2) {
    stop("m must contain at least 2 categories.", call. = FALSE)
  }

  schemes <- c("normal", "arithmetic", "geometric", "sin", "tanh",
               "interval", "custom")
  # Validate before switch(): a numeric EXPR would make switch() select a
  # branch by position, and an unknown name would silently yield NULL.
  if (!is.character(weight.type) || length(weight.type) != 1L ||
      !(weight.type %in% schemes)) {
    stop("weight.type must be one of: ",
         paste0("\"", schemes, "\"", collapse = ", "), call. = FALSE)
  }

  # Possible distances from the main diagonal: 0 (correct) up to n - 1.
  steps <- seq(0, n - 1, by = 1)

  # w[k + 1] is the weight given to every cell k categories off the diagonal.
  w <- switch(weight.type,
    normal = {
      # Right tail of a normal density, rescaled so the diagonal weight is 1.
      dnorm(steps, mean = 0, sd = standard.deviation) /
        dnorm(0, mean = 0, sd = standard.deviation)
    },
    arithmetic = ((n - 1) - steps) / (n - 1),
    geometric = {
      mult <- geometric.multiplier
      if (mult <= 0) {
        stop("Please enter a multiplier value greater than zero.")
      }
      if (mult == 1) {
        # A ratio of 1 has no curvature: fall back to a linear decrease.
        1 - steps / (n - 1)
      } else {
        x <- mult^steps
        x_norm <- (x - min(x)) / (max(x) - min(x))
        # The rescaled series runs 0 -> 1 when mult > 1 and 1 -> 0 when
        # mult < 1; flip the former so weights decrease off the diagonal.
        if (mult > 1) 1 - x_norm else x_norm
      }
    },
    sin = sin(seq(sin.low, sin.high, length.out = n)),
    tanh = 1 - tanh(seq(0, tanh.decay, length.out = n)),
    interval = seq(interval.high, interval.low, length.out = n),
    custom = {
      if (length(custom.weights) < n) {
        stop("custom.weights must contain at least ", n, " values.",
             call. = FALSE)
      }
      as.numeric(custom.weights[seq_len(n)])
    }
  )

  if (isTRUE(weight.penalty) &&
      weight.type %in% c("normal", "arithmetic", "geometric", "tanh")) {
    # Penalty mode: misclassifications get non-positive weights -(1 - w)
    # while correct predictions keep full weight.
    w <- -(1 - w)
    w[1] <- 1
  }

  # Expand the per-distance weights into the n x n weight matrix by indexing
  # with each cell's distance from the diagonal (no find-and-replace needed).
  dist <- abs(outer(seq_len(n), seq_len(n), "-"))
  mat <- matrix(w[as.vector(dist) + 1L], nrow = n, ncol = n)

  weighted <- m * mat
  if (isTRUE(print.weighted.accuracy)) {
    cat("Weighted accuracy = ", sum(weighted) / sum(m), "\n", "\n")
  }
  weighted
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.