R/wconfusionmatrix.R

Defines functions wconfusionmatrix

Documented in wconfusionmatrix

#' Weighted confusion matrix
#'
#' This function calculates the weighted confusion matrix from a caret
#' ConfusionMatrix object or a simple matrix, according to one of several
#' weighting schemas and optionally prints the weighted accuracy score.
#'
#' @param m the caret confusion matrix object or simple matrix. Anything that
#' is not already a matrix is converted with `as.matrix()`. The resulting
#' matrix must be square.
#'
#' @param weight.type the weighting schema to be used. Can be one of:
#' "arithmetic" - a decreasing arithmetic progression weighting scheme,
#' "geometric" - a decreasing geometric progression weighting scheme,
#' "normal" - weights drawn from the right tail of a normal distribution,
#' "interval" - weights contained on a user-defined interval,
#' "sin" - a weighting scheme based on a sine function,
#' "tanh" - a weighting scheme based on a hyperbolic tangent function,
#' "custom" - custom weight vector defined by the user.
#' Any other value raises an error.
#'
#' @param weight.penalty determines whether the weights associated with
#' non-diagonal elements generated by the "normal", "arithmetic" and "geometric"
#' weight types are positive or negative values. By default, the value is set to
#' FALSE, which means that generated weights will be positive values. For the
#' "tanh" weight type, TRUE switches the weights from `1 - tanh(.)` to
#' `tanh(.)`. The argument has no effect on the "interval", "sin" and "custom"
#' weight types.
#'
#' @param standard.deviation standard deviation of the normal distribution, if
#' the normal distribution weighting schema is used.
#'
#' @param geometric.multiplier the multiplier used to construct the geometric
#' progression series, if the geometric progression weighting scheme is used.
#' Must be greater than zero.
#'
#' @param sin.high the upper segment of the sine function to be used in the
#' weighting scheme.
#'
#' @param sin.low the lower segment of the sine function to be used in the
#' weighting scheme.
#'
#' @param tanh.decay the decay factor of the hyperbolic tangent weighting
#' function. Higher values increase the rate of decay and place less weight on
#' observations farther away from the correctly predicted category.
#'
#' @param interval.high the upper bound of the weight interval, if the interval
#' weighting scheme is used.
#'
#' @param interval.low the lower bound of the weight interval, if the interval
#' weighting scheme is used.
#'
#' @param custom.weights the vector of custom weights to be applied, if the
#' custom weighting scheme was selected. The first value is the weight of the
#' diagonal, the second the weight of cells one category away, and so on. The
#' length of the vector should be equal to "n", but can be larger, with excess
#' values being ignored. An error is raised if fewer than "n" non-missing
#' weights are supplied.
#'
#' @param print.weighted.accuracy print the weighted accuracy metric, which
#' represents the sum of all weighted confusion matrix cells divided by the
#' total number of observations.
#'
#' @return an nxn weighted confusion matrix: the element-wise product of `m`
#' and the weight matrix of the selected scheme.
#'
#' @details The number of categories "n" should be greater or equal to 2. The
#' weight applied to a cell depends only on its distance from the diagonal. A
#' 1x1 matrix is returned with a weight taken from the start of the selected
#' scheme (1 for the "arithmetic" and "geometric" schemes).
#'
#' @usage wconfusionmatrix(m, weight.type = "arithmetic",
#'                         weight.penalty = FALSE,
#'                         standard.deviation = 2,
#'                         geometric.multiplier = 2,
#'                         interval.high=1, interval.low = -1,
#'                         sin.high=1.5*pi, sin.low = 0.5*pi,
#'                         tanh.decay = 3,
#'                         custom.weights = NA,
#'                         print.weighted.accuracy = FALSE)
#'
#' @keywords weighted confusion matrix accuracy score
#'
#' @seealso [weightmatrix()] for the weight matrix used in computations,
#'   [balancedaccuracy()] for accuracy metrics designed for imbalanced data.
#'
#' @author Alexandru Monahov, <https://www.alexandrumonahov.eu.org/>
#'
#' @examples
#' m = matrix(c(70,0,0,10,10,0,5,3,2), ncol = 3, nrow=3)
#' wconfusionmatrix(m, weight.type="arithmetic", print.weighted.accuracy = TRUE)
#' wconfusionmatrix(m, weight.type="geometric", print.weighted.accuracy = TRUE)
#' wconfusionmatrix(m, weight.type="interval", print.weighted.accuracy = TRUE)
#' wconfusionmatrix(m, weight.type="normal", print.weighted.accuracy = TRUE)
#' wconfusionmatrix(m, weight.type="sin", print.weighted.accuracy = TRUE)
#' wconfusionmatrix(m, weight.type="tanh", print.weighted.accuracy = TRUE)
#' wconfusionmatrix(m, weight.type= "custom", custom.weights = c(1,0.1,0),
#'                  print.weighted.accuracy = TRUE)
#'
#' @export
wconfusionmatrix <- function(m, weight.type = "arithmetic",
                             weight.penalty = FALSE,
                             standard.deviation = 2,
                             geometric.multiplier = 2,
                             interval.high=1, interval.low = -1,
                             sin.high=1.5*pi, sin.low = 0.5*pi,
                             tanh.decay = 3,
                             custom.weights = NA,
                             print.weighted.accuracy = FALSE) {

  # Reject unknown schemes up front instead of silently returning NULL.
  weight.type <- match.arg(
    weight.type,
    c("arithmetic", "geometric", "normal", "interval", "sin", "tanh", "custom")
  )

  if (!is.matrix(m)) {
    m <- as.matrix(m)
  }
  n <- nrow(m)
  if (n < 1 || ncol(m) != n) {
    stop("'m' must be a square matrix with at least one category.",
         call. = FALSE)
  }

  penalty <- isTRUE(as.logical(weight.penalty))

  # dist[i, j] is how many categories cell (i, j) lies away from the diagonal.
  steps <- seq_len(n) - 1
  dist <- abs(outer(steps, steps, `-`))

  # Expand a weight-per-distance vector (w[1] = diagonal weight) into the full
  # n x n weight matrix. Direct indexing means a weight value can never be
  # mistaken for a not-yet-replaced placeholder.
  by_distance <- function(w) {
    matrix(w[dist + 1], nrow = n, ncol = n)
  }

  mat <- switch(
    weight.type,
    normal = {
      # Normal density at each distance, rescaled so the diagonal weight is 1.
      dnorm(dist, mean = 0, sd = standard.deviation) /
        dnorm(0, mean = 0, sd = standard.deviation)
    },
    arithmetic = {
      # Linear decay from 1 on the diagonal to 0 at the farthest cell.
      if (n > 1) ((n - 1) - dist) / (n - 1) else matrix(1, nrow = 1, ncol = 1)
    },
    geometric = {
      mult <- geometric.multiplier
      if (!is.numeric(mult) || length(mult) != 1L || is.na(mult) || mult <= 0) {
        stop("Please enter a multiplier value greater than zero.")
      }
      if (n == 1) {
        w <- 1
      } else if (mult == 1) {
        # A constant progression cannot be min-max scaled; use linear decay.
        w <- 1 - steps / (n - 1)
      } else {
        growth <- mult^steps
        scaled <- (growth - min(growth)) / (max(growth) - min(growth))
        # Orient the scaled series so the diagonal always receives weight 1.
        w <- if (mult > 1) 1 - scaled else scaled
      }
      by_distance(w)
    },
    sin = by_distance(sin(seq(sin.low, sin.high, length.out = n))),
    tanh = {
      ramp <- tanh(seq(0, tanh.decay, length.out = n))
      # NOTE(review): with weight.penalty = TRUE the diagonal weight becomes 0
      # and off-diagonal weights are positive, unlike the penalty used by the
      # other schemes. Kept as-is; confirm this is intended.
      by_distance(if (penalty) ramp else 1 - ramp)
    },
    interval = by_distance(seq(interval.high, interval.low, length.out = n)),
    custom = {
      usable <- is.numeric(custom.weights) || is.logical(custom.weights)
      if (!usable || length(custom.weights) < n ||
          anyNA(custom.weights[seq_len(n)])) {
        stop("'custom.weights' must supply at least ", n,
             " non-missing numeric weights (one per category).",
             call. = FALSE)
      }
      by_distance(as.numeric(custom.weights[seq_len(n)]))
    }
  )

  # Penalty: off-diagonal weights become negative, the diagonal stays at 1.
  if (penalty && weight.type %in% c("normal", "arithmetic", "geometric")) {
    mat <- -(1 - mat)
    diag(mat) <- 1
  }

  weighted <- m * mat
  if (isTRUE(as.logical(print.weighted.accuracy))) {
    cat("Weighted accuracy = ", sum(weighted) / sum(m), "\n", "\n")
  }
  weighted
}

Try the wconf package in your browser

Any scripts or data that you put into this service are public.

wconf documentation built on Sept. 11, 2024, 6:22 p.m.