R/changeDecisionsDummies.R

Defines functions changeSingleRuleDummies dumColnames getDummyLevels changeDecisionsDummies

Documented in changeDecisionsDummies

#' Takes decisions and modifies them so that only one level of a multiclass variable is used in decisions
#'
#' Conditions of the form `X[,i]<=0.5` bearing on the dummy columns of a
#' multiclass variable are rewritten as `>0.5` conditions on the remaining
#' levels of that variable (see `changeSingleRuleDummies`). Each rewritten rule
#' is then passed to `singleRulePerLevel` and the results are stacked. If
#' `rules` has an "err" column, the metrics of the rewritten rules are
#' recomputed with `getDecisionsMetrics`.
#'
#' @param rules a data frame with a column "condition" and a column "n". If it also has a column "err", the columns "len", "support" and "pred" are expected too. The input is copied, so the object passed by the caller is not modified.
#' @param dummy_var string vector with the names of columns to change to dummy variable.
#' @param data data on which to fit the decision ensemble.
#' @param target response variable.
#' @param classPos for classification, the positive class.
#' @param in_parallel if TRUE, the function is run in parallel.
#' @param n_cores if in_parallel = TRUE, and no cluster has been passed: number of cores to use, default is detectCores() - 1.
#' @param cluster the cluster to use to run the function in parallel (opt).
#'
#' @return a data.table of the modified decisions, with "n" summed over identical conditions.
#'
#' @export
changeDecisionsDummies <- function(rules, dummy_var, data, target, classPos = NULL,
                                   in_parallel = FALSE, n_cores = detectCores() - 1, cluster = NULL) {

  # get the colnames and column numbers of dummy variables
  dum_lev <- lapply(dummy_var, getDummyLevels, colN = colnames(data))
  names(dum_lev) <- dummy_var
  dum_colN <- unlist(lapply(dummy_var, FUN = dumColnames, colN = colnames(data)))

  # Work on a private data.table copy: `:=` updates by reference and would
  # otherwise leave a `to_update` column in the caller's table (and fail on a
  # plain data.frame).
  rules <- data.table::copy(data.table::as.data.table(rules))

  # replace the <=0.5 conditions in rules for all other levels with >0.5
  rules <- rules[, to_update := "no"]
  # separate levels in rules
  if (in_parallel == TRUE) {
    if (is.null(cluster)) {
      message("Initiate parallelisation ... ")
      cluster <- makeCluster(n_cores)
      clusterEvalQ(cluster, library(endoR))
      clusterEvalQ(cluster, library(stringr))
      # only a cluster created here is stopped; add = TRUE keeps any other exit handler
      on.exit(stopCluster(cluster), add = TRUE)
    }

    rules <- t(parApply(
      cl = cluster, rules, MARGIN = 1, FUN = changeSingleRuleDummies,
      dum_colN = dum_colN, dum_lev = dum_lev
    ))
  } else {
    rules <- t(apply(rules,
      MARGIN = 1, FUN = changeSingleRuleDummies,
      dum_colN = dum_colN, dum_lev = dum_lev
    ))
  }


  # get one rule per level
  # `rules` is now a character matrix: splitting it by row index yields one
  # unnamed vector per rule, which is given the column names back below.
  tmp <- split(rules, seq_len(nrow(rules)))
  tmp <- lapply(tmp,
    FUN = function(x, newN) {
      names(x) <- newN
      return(x)
    },
    newN = colnames(rules)
  )
  if (in_parallel == TRUE) {
    rules <- parLapply(cl = cluster, tmp, fun = singleRulePerLevel)
  } else {
    rules <- lapply(tmp, FUN = singleRulePerLevel)
  }

  rules <- do.call(what = rbind, rules)
  rules <- as.data.table(rules)[, n := as.numeric(n)][, n := sum(n), by = condition]

  if ("err" %in% colnames(rules)) {
    # `to_up` is "yes" for rewritten rules and NA for the others
    rules <- rules[to_update == "yes", to_up := "yes"][, to_update := NULL]
    rules <- unique(rules)
    # NOTE(review): when no rule was rewritten, the `to_up` marker column looks
    # like it stays in the returned table -- confirm whether callers expect it.
    # Get the new metrics
    if (nrow(rules[to_up == "yes", ]) > 0) {
      if ("imp" %in% colnames(rules)) {
        importances <- TRUE
        colN <- c("len", "err", "support", "pred", "imp", "to_up")
      } else {
        importances <- FALSE
        colN <- c("len", "err", "support", "pred", "to_up")
      }

      # rules left untouched keep their metrics; rewritten rules whose new
      # condition is not already among them get their old metrics dropped
      no_up <- unique(subset(rules, is.na(to_up), select = -to_up))
      tmp1 <- unique(subset(rules, to_up == "yes" & !(condition %in% no_up$condition))[, (colN) := NULL])

      tmp2 <- getDecisionsMetrics(tmp1,
        data = data, target = target, classPos = classPos,
        importances = importances,
        in_parallel = in_parallel, n_cores = n_cores, cluster = cluster
      )
      tmp2 <- tmp1[tmp2, on = "condition"]
      # tmp2 <- merge(tmp2, tmp1, all.x = TRUE, by = 'condition' )

      if (nrow(no_up) > 0) {
        rules <- rbind(no_up, tmp2)
      } else {
        rules <- tmp2
      }

      rules <- rules[, c("len", "support", "err", "pred", "n") := lapply(.SD, as.numeric),
        .SDcols = c("len", "support", "err", "pred", "n")
      ]
      rules <- unique(rules[, n := sum(n), by = condition])
    }
  } else {
    rules <- unique(rules[, to_update := NULL])
  }


  return(rules)
}


##################################################
# Build the "<=0.5" condition string of every column derived from a variable.
#
# var:  name of the original multiclass variable; columns whose name starts
#       with it are taken to be its dummy columns.
# colN: character vector of column names to search.
#
# Returns one string "X[,i]<=0.5" per index i of colN whose name starts with
# var, or character(0) when no column matches.
getDummyLevels <- function(var, colN) {
  # literal prefix match: var is a column name, not a regular expression, so
  # characters such as "." or "(" must not be interpreted as regex syntax
  idx <- which(startsWith(as.character(colN), as.character(var)))
  # sprintf() gives character(0) for an empty index, where paste0() would
  # produce the bogus condition "X[,]<=0.5"
  sprintf("X[,%d]<=0.5", idx)
}
# Map the "<=0.5" condition of each dummy column back to its variable name.
#
# var:  name of the original multiclass variable; columns whose name starts
#       with it are taken to be its dummy columns.
# colN: character vector of column names to search.
#
# Returns a character vector holding var once per matching column, named by
# the corresponding condition string "X[,i]<=0.5"; it is empty when no column
# matches.
dumColnames <- function(var, colN) {
  # literal prefix match: var is a column name, not a regular expression
  idx <- which(startsWith(as.character(colN), as.character(var)))
  dum <- rep(var, length(idx))
  # sprintf() keeps the names zero-length when nothing matched; paste0() would
  # return "X[,]<=0.5" and make the names assignment fail on the empty vector
  names(dum) <- sprintf("X[,%d]<=0.5", idx)
  dum
}

#################################################
# Rewrite the dummy-variable conditions of a single rule.
#
# rule:     named character vector with at least the entries "condition"
#           (sub-conditions joined by " & ") and "to_update".
# dum_colN: character vector of variable names, named by the "X[,i]<=0.5"
#           condition of each of their dummy columns.
# dum_lev:  not used in the body; kept in the signature for the callers.
#
# Every "X[,i]<=0.5" sub-condition on a dummy column is dropped and replaced,
# per variable, by ">0.5" conditions on the levels the rule did not exclude
# (joined by " | " inside parentheses when there are several). Nothing is added
# for a variable when the rule already contains one of those ">0.5"
# conditions. A rewritten rule gets to_update = "yes"; a rule without any such
# sub-condition is returned unchanged.
changeSingleRuleDummies <- function(rule, dum_colN, dum_lev) {
  dummy_conditions <- names(dum_colN)
  parts <- unlist(strsplit(rule["condition"], " & "))
  is_dummy <- parts %in% dummy_conditions

  # no dummy level involved: leave the rule as it is
  if (!any(is_dummy)) {
    return(rule)
  }

  kept <- parts[!is_dummy]
  excluded <- parts[is_dummy]
  variables <- unique(as.character(dum_colN[excluded]))

  replacements <- lapply(variables, function(v) {
    # levels of v that the rule did not exclude, turned into ">0.5" conditions
    levels_v <- dummy_conditions[dum_colN == v]
    remaining <- levels_v[!(levels_v %in% excluded)]
    remaining <- str_replace(remaining, pattern = "<=", replacement = ">")

    # the rule already selects one of these levels explicitly
    if (any(remaining %in% kept)) {
      return(NULL)
    }
    if (length(remaining) > 1) {
      remaining <- paste0("(", paste(remaining, collapse = " | "), ")")
    }
    remaining
  })

  rule["condition"] <- paste0(sort(c(kept, unlist(replacements))), collapse = " & ")
  rule["to_update"] <- "yes"
  rule
}
aruaud/endoR documentation built on Jan. 25, 2025, 2:20 a.m.