R/pl6_calc_average.R

Defines functions .validate.avgpoints .validate.avgcurves .validate_avg_common .calc_avg_common calc_avg_basic calc_avg_rocprc

#
# Calculate the average curves (with confidence bounds) for each model;
# curves that share a model name are averaged together
#
calc_avg_rocprc <- function(curves, modnames, uniq_modnames, cb_alpha,
                            x_bins) {
  .calc_avg_common(
    obj = curves,
    mode = "curve",
    class_name = "avgcurves",
    modnames = modnames,
    uniq_modnames = uniq_modnames,
    cb_alpha = cb_alpha,
    x_bins = x_bins
  )
}

#
# Calculate the average points (with confidence bounds) for each model;
# no x binning is used, so x_bins is left as NULL
#
calc_avg_basic <- function(epoints, modnames, uniq_modnames, cb_alpha) {
  .calc_avg_common(
    obj = epoints,
    mode = "point",
    class_name = "avgpoints",
    modnames = modnames,
    uniq_modnames = uniq_modnames,
    cb_alpha = cb_alpha,
    x_bins = NULL
  )
}

#
# Calculate averages and confidence bounds for each model
#
# Shared implementation behind calc_avg_rocprc() and calc_avg_basic().
#
# Arguments:
#   obj           - list of curves (mode "curve") or points (mode "point");
#                   obj[[i]] is assigned to the model named modnames[i]
#   mode          - either "curve" or "point"; selects the C++ routine
#   class_name    - S3 class given to the returned object
#   modnames      - model name for each element of obj
#   uniq_modnames - unique model names; one average is computed per name
#   cb_alpha      - alpha of the two-sided confidence bounds
#   x_bins        - number of x bins for curve averaging; NULL or NA
#                   falls back to 1
#
# Returns the result of .validate() on an object of class `class_name`
# that holds one average per element of uniq_modnames.
#
.calc_avg_common <- function(obj, mode, class_name, modnames, uniq_modnames,
                             cb_alpha, x_bins) {
  # === Validate input arguments ===
  # Fail fast on an unknown mode; otherwise the averaging step below would
  # fail later with an obscure "object 'avgs' not found" error.
  if (!is.character(mode) || length(mode) != 1L ||
      !(mode %in% c("curve", "point"))) {
    stop("'mode' must be either \"curve\" or \"point\"", call. = FALSE)
  }
  # A missing bin count (e.g. NULL from calc_avg_basic()) means one bin
  if (is.null(x_bins) || any(is.na(x_bins))) {
    x_bins <- 1
  }
  .validate_cb_alpha(cb_alpha)
  .validate_x_bins(x_bins, allow_zero = TRUE)

  # === Summarize curves by models ===
  # Z value of the two-sided confidence bounds, i.e. qnorm(1 - cb_alpha / 2)
  cb_zval <- stats::qnorm((1.0 - cb_alpha) + (cb_alpha * 0.5))

  # Group the elements of obj by model name
  ffunc <- function(mname) {
    obj[modnames == mname]
  }
  obj_by_model <- lapply(uniq_modnames, ffunc)

  # Calculate averages and confidence bounds with the compiled routines;
  # .check_cpp_func_error() presumably stops if the routine reported an
  # error (defined elsewhere -- confirm)
  vfunc <- function(i) {
    if (mode == "curve") {
      avgs <- calc_avg_curve(obj_by_model[[i]], x_bins, cb_zval)
      .check_cpp_func_error(avgs, "calc_avg_curve")
    } else {
      avgs <- calc_avg_points(obj_by_model[[i]], cb_zval)
      .check_cpp_func_error(avgs, "calc_avg_basic")
    }
    avgs[["avg"]]
  }
  lavgs <- lapply(seq_along(obj_by_model), vfunc)

  # === Create an S3 object ===
  s3obj <- structure(lavgs, class = class_name)

  # Set attributes
  attr(s3obj, "uniq_modnames") <- uniq_modnames
  attr(s3obj, "cb_zval") <- cb_zval
  attr(s3obj, "pauc") <- NA
  attr(s3obj, "spauc") <- NA
  attr(s3obj, "args") <- list(
    cb_alpha = cb_alpha,
    x_bins = x_bins
  )
  attr(s3obj, "validated") <- FALSE

  # Call .validate()
  .validate(s3obj)
}

#
# Validate an average object (avgcurves / avgpoints)
#
# Arguments:
#   avgobj     - object to validate
#   class_name - expected S3 class of avgobj
#   func_name  - name of the function that produced avgobj; forwarded to
#                .validate_basic() (presumably for its error messages)
#
# Returns avgobj with its "validated" attribute set to TRUE. Objects that
# already have the expected class and validated == TRUE are returned as is.
#
.validate_avg_common <- function(avgobj, class_name, func_name) {
  # Need to validate only once. isTRUE() keeps a missing (NULL) or NA
  # "validated" attribute from making `&&`/`if` error out, so such objects
  # fall through to the full check below instead.
  if (methods::is(avgobj, class_name) && isTRUE(attr(avgobj, "validated"))) {
    return(avgobj)
  }

  # Validate class items and attributes
  item_names <- NULL
  attr_names <- c("uniq_modnames", "cb_zval", "args", "validated")
  arg_names <- c("cb_alpha", "x_bins")
  .validate_basic(
    avgobj, class_name, func_name, item_names, attr_names,
    arg_names
  )

  attr(avgobj, "validated") <- TRUE
  avgobj
}

#
# .validate() method for 'avgcurves' objects, which are built by
# calc_avg_rocprc(); delegates to the shared average validator
#
.validate.avgcurves <- function(x) {
  .validate_avg_common(
    avgobj = x,
    class_name = "avgcurves",
    func_name = "calc_avg_rocprc"
  )
}

#
# .validate() method for 'avgpoints' objects, which are built by
# calc_avg_basic(); delegates to the shared average validator
#
.validate.avgpoints <- function(x) {
  .validate_avg_common(
    avgobj = x,
    class_name = "avgpoints",
    func_name = "calc_avg_basic"
  )
}
takayasaito/precrec documentation built on Oct. 19, 2023, 7:28 p.m.