R/main_evalcurves.R

Defines functions .validate_run_evalcurve_args .make_titles .predict_curves .get_base_points .add_cat_labels .summarize_cat_scores .summarize_scores .eval_epoint .eval_intpts .eval_fpoint .eval_y_range .eval_x_range .run_curve_tests run_evalcurve

Documented in run_evalcurve

#' Evaluate Precision-Recall curves with specified tools and test sets
#'
#' The \code{run_evalcurve} function runs several tests to evaluate
#'    the accuracy of Precision-Recall curves.
#'
#' @param testset A list of test data objects (class \code{TestDataC})
#'   generated by \code{\link{create_testset}}.
#'
#' @param toolset A list of tool objects (class \code{ToolIFBase})
#'   generated by \code{\link{create_toolset}}.
#'
#' @param auto_combo A Boolean value to specify whether a combination of test
#'   and tool sets is automatically created. When \code{TRUE}, every tool is
#'   evaluated on every test set. When \code{FALSE}, \code{testset} and
#'   \code{toolset} must have the same length and are paired element by
#'   element.
#'
#' @return An S3 object of class \code{"evalcurve"}: a list with the elements
#'   \code{testscores} (scores of the individual tests), \code{testsum}
#'   (scores summarized per test set and tool), \code{catres} (scores
#'   summarized per test category), \code{basepoints} (base points of the
#'   test sets), \code{predictions} (curve points produced by the tools), and
#'   \code{titles} (plot titles of the tools).
#'
#' @seealso \code{\link{create_testset}} to generate a test dataset.
#'    \code{\link{create_toolset}} to generate a tool set.
#'
#' @examples
#' ## Evaluate curves for c1, c2, c3 test sets and crv5 tool set
#' testset <- create_testset("curve", c("c1", "c2", "c3"))
#' toolset <- create_toolset(set_names = "crv5")
#' res1 <- run_evalcurve(testset, toolset)
#' res1
#'
#' @export
run_evalcurve <- function(testset, toolset, auto_combo = TRUE) {
  # Validate arguments
  new_args <- .validate_run_evalcurve_args(testset, toolset)
  assertthat::assert_that(
    assertthat::is.flag(auto_combo),
    !is.na(auto_combo)
  )

  # Pair up test data sets and tools: element i of new_testset is evaluated
  # with element i of new_toolset
  if (auto_combo) {
    # Every tool is evaluated on every test set
    new_testset <- rep(new_args$testset, length(new_args$toolset))
    new_toolset <- rep(new_args$toolset, each = length(new_args$testset))
  } else {
    # Pairs are taken positionally, so both lists must have the same length;
    # otherwise the helpers fail on toolset[[i]] or silently drop tools
    if (length(new_args$testset) != length(new_args$toolset)) {
      stop(
        paste0(
          "testset and toolset must have the same length ",
          "when auto_combo is FALSE"
        ),
        call. = FALSE
      )
    }
    new_testset <- new_args$testset
    new_toolset <- new_args$toolset
  }

  # Evaluate curves
  testres <- .run_curve_tests(new_testset, new_toolset)
  summres <- .summarize_scores(testres, new_args$testset)
  catres <- .summarize_cat_scores(testres)
  summres <- .add_cat_labels(summres, catres, new_args$testset)
  bptsres <- .get_base_points(new_args$testset)
  # Each tool is called again here to collect its predicted curve points
  predres <- .predict_curves(new_testset, new_toolset)
  titles <- .make_titles(new_args$toolset)

  reslst <- list(
    testscores = testres, testsum = summres, catres = catres,
    basepoints = bptsres, predictions = predres,
    titles = titles
  )

  # Create an S3 object
  structure(reslst, class = "evalcurve")
}

#
# Run every curve test on each (test set, tool) pair and collect the scores
#
# `testset` and `toolset` are parallel lists: the idx-th tool is called on the
# idx-th test data set. The result has one row per pair and test item, sorted
# by test set, tool set and tool name.
#
.run_curve_tests <- function(testset, toolset) {
  score_pair <- function(idx) {
    tool <- toolset[[idx]]
    tdata <- testset[[idx]]
    tool$call(tdata)

    # Each .eval_* helper returns c(success = , total = ); stack them in the
    # same order as the test item labels below
    scores <- do.call(rbind, list(
      .eval_x_range(tool),
      .eval_y_range(tool),
      .eval_fpoint(tdata, tool),
      .eval_intpts(tdata, tool),
      .eval_epoint(tdata, tool)
    ))

    data.frame(
      testset = tdata$get_tsname(),
      toolset = tool$get_setname(),
      toolname = tool$get_toolname(),
      testitem = c("x_range", "y_range", "fpoint", "intpts", "epoint"),
      testcat = c("Rg", "Rg", "SE", "Ip", "SE"),
      success = scores[, "success"],
      total = scores[, "total"]
    )
  }

  all_scores <- do.call(rbind, lapply(seq_along(testset), score_pair))
  ordering <- order(
    all_scores$testset, all_scores$toolset, all_scores$toolname
  )
  sorted <- all_scores[ordering, ]
  rownames(sorted) <- NULL
  sorted
}

#
# Check that every x value of a Precision-Recall curve lies in [0, 1]
#
# NA values are ignored, so a curve without any non-NA x value passes.
# Returns c(success = 0 or 1, total = 1).
#
.eval_x_range <- function(tool) {
  xs <- tool$get_x()
  in_range <- all(xs >= 0, na.rm = TRUE) && all(xs <= 1, na.rm = TRUE)
  c(success = if (in_range) 1 else 0, total = 1)
}

#
# Check that every y value of a Precision-Recall curve lies in [0, 1]
#
# NA values are ignored, so a curve without any non-NA y value passes.
# Returns c(success = 0 or 1, total = 1).
#
.eval_y_range <- function(tool) {
  ys <- tool$get_y()
  in_range <- all(ys >= 0, na.rm = TRUE) && all(ys <= 1, na.rm = TRUE)
  c(success = if (in_range) 1 else 0, total = 1)
}

#
# Check the first point of a Precision-Recall curve
#
# Compares the first predicted point (x[1], y[1]) of `tool` with the first
# base point of `tset`; both coordinates must lie within `tolerance`.
# Missing (NA/NaN) values and empty curves count as a failure instead of
# raising an error.
#
# Returns c(success = 0 or 1, total = 1).
#
.eval_fpoint <- function(tset, tool, tolerance = 1e-4) {
  bx <- tset$get_basepoints_x()
  by <- tset$get_basepoints_y()
  xs <- tool$get_x()
  ys <- tool$get_y()

  # isTRUE() maps NA (missing values) and logical(0) (empty vectors, where
  # xs[1] is NULL) to FALSE, so such curves score 0 instead of erroring in if()
  matched <- isTRUE(abs(xs[1] - bx[1]) < tolerance) &&
    isTRUE(abs(ys[1] - by[1]) < tolerance)

  c(success = if (matched) 1 else 0, total = 1)
}

#
# Check intermediate points of a Precision-Recall curve
#
# For every base point of `tset` except the first and the last, look for
# predicted points of `tool` at the same x (within `tolerance`). When none
# exists, linearly interpolate y between the nearest predicted points on the
# left and on the right. A base point counts as a success when any of the
# resulting y values lies within `tolerance` of the expected y.
#
# NA values in the predicted curve are ignored rather than raising an error.
#
# Returns c(success = <matched points>, total = <intermediate base points>);
# both are 0 when the test set has no intermediate base points.
#
.eval_intpts <- function(tset, tool, tolerance = 1e-4) {
  bx <- tset$get_basepoints_x()
  by <- tset$get_basepoints_y()

  if (length(bx) <= 2) {
    return(c(success = 0, total = 0))
  }

  # The predicted points do not depend on the base point; fetch them once
  xs <- tool$get_x()
  ys <- tool$get_y()

  # y values of the predicted curve at x = bx_i (NA when none can be derived)
  curve_y_at <- function(bx_i) {
    # which() drops NA comparisons, so NA points are simply skipped
    hit <- which(abs(xs - bx_i) < tolerance)
    if (length(hit) > 0) {
      return(ys[hit])
    }

    left <- which(xs <= bx_i)
    right <- which(xs >= bx_i)
    if (length(left) == 0 || length(right) == 0) {
      return(NA_real_)
    }

    prev_x <- max(xs[left])
    next_x <- min(xs[right])
    # A curve can hold several points with the same x (a vertical segment).
    # Interpolate from the LAST point at prev_x to the FIRST point at next_x,
    # i.e. along the segment that actually spans bx_i. (match() would only
    # ever return the first occurrence.)
    prev_y <- ys[max(which(xs == prev_x))]
    next_y <- ys[min(which(xs == next_x))]
    if (anyNA(c(prev_y, next_y))) {
      return(NA_real_)
    }
    stats::approx(c(prev_x, next_x), c(prev_y, next_y), xout = bx_i)$y
  }

  inner <- 2:(length(bx) - 1)
  hits <- vapply(inner, function(i) {
    if (any(abs(curve_y_at(bx[i]) - by[i]) < tolerance, na.rm = TRUE)) 1 else 0
  }, numeric(1))

  total <- length(bx) - 2
  success <- min(sum(hits), total)
  c(success = success, total = total)
}

#
# Check the end point of a Precision-Recall curve
#
# Compares the last predicted point of `tool` with the last base point of
# `tset`; both coordinates must lie within `tolerance`. The y value is taken
# at the position of the last x value. Empty curves, empty base points and
# missing (NA/NaN) values count as a failure.
#
# Returns c(success = 0 or 1, total = 1).
#
.eval_epoint <- function(tset, tool, tolerance = 1e-4) {
  bx <- tset$get_basepoints_x()
  by <- tset$get_basepoints_y()
  xs <- tool$get_x()
  ys <- tool$get_y()

  last_pred <- length(xs)
  last_base <- length(bx)

  # length() is never NULL or NA, so only emptiness needs a guard; isTRUE()
  # turns NA comparisons (missing values) into FALSE
  matched <- last_pred > 0 && last_base > 0 &&
    isTRUE(abs(xs[last_pred] - bx[last_base]) < tolerance) &&
    isTRUE(abs(ys[last_pred] - by[last_base]) < tolerance)

  c(success = if (matched) 1 else 0, total = 1)
}

#
# Summarize curve evaluation results
#
# Sums `success` and `total` per (test set, tool set, tool name), adds a
# "success/total" label and the label positions reported by each test set
# (rows of test sets not found in `testset` keep position 0).
#
.summarize_scores <- function(testres, testset) {
  groups <- list(
    testset = testres$testset,
    toolset = testres$toolset,
    toolname = testres$toolname
  )
  sumdf <- stats::aggregate(testres[, c("success", "total")],
    by = groups, FUN = sum, na.rm = TRUE
  )
  sumdf$label <- factor(paste(sumdf$success, sumdf$total, sep = "/"))
  sumdf$lbl_pos_x <- 0
  sumdf$lbl_pos_y <- 0

  # A later test set with the same name overwrites an earlier one
  for (ts_obj in testset) {
    rows <- sumdf$testset == ts_obj$get_tsname()
    sumdf[rows, "lbl_pos_x"] <- ts_obj$get_textpos_x()
    sumdf[rows, "lbl_pos_y"] <- ts_obj$get_textpos_y()
  }

  ordering <- order(sumdf$testset, sumdf$toolset, sumdf$toolname)
  result <- sumdf[ordering, ]
  rownames(result) <- NULL
  result
}

#
# Summarize curve evaluation results by category
#
# Sums `success` and `total` per (test set, test category, tool set, tool
# name) and adds a "success/total" label. Categories are ordered SE, Ip, Rg.
#
.summarize_cat_scores <- function(testres) {
  groups <- list(
    testset = testres$testset,
    testcat = testres$testcat,
    toolset = testres$toolset,
    toolname = testres$toolname
  )
  sumdf <- stats::aggregate(testres[, c("success", "total")],
    by = groups, FUN = sum, na.rm = TRUE
  )
  # Fixed category order used for sorting and display
  sumdf$testcat <- factor(sumdf$testcat, levels = c("SE", "Ip", "Rg"))
  sumdf$label <- factor(paste(sumdf$success, sumdf$total, sep = "/"))

  ordering <- with(sumdf, order(testset, toolset, toolname, testcat))
  result <- sumdf[ordering, ]
  rownames(result) <- NULL
  result
}

#
# Add label2 and pos2 to result summary
#
# label2 concatenates the "category: success/total" labels of `catres` for
# each (test set, tool set, tool name) row, separated by newlines in the row
# order of `catres`. lbl_pos_x2/lbl_pos_y2 come from the first test set in
# `testsets` with a matching name.
#
.add_cat_labels <- function(summres, catres, testsets) {
  summres$label2 <- NA
  for (row in seq_len(nrow(catres))) {
    hit <- summres$testset == catres$testset[row] &
      summres$toolset == catres$toolset[row] &
      summres$toolname == catres$toolname[row]
    cat_lab <- paste(catres$testcat[row], catres$label[row], sep = ": ")

    current <- summres[hit, "label2"]
    summres[hit, "label2"] <- if (is.na(current)) {
      cat_lab
    } else {
      paste(current, cat_lab, sep = "\n")
    }
  }

  summres$lbl_pos_x2 <- NA
  summres$lbl_pos_y2 <- NA
  for (ts_obj in testsets) {
    rows <- summres$testset == ts_obj$get_tsname()
    # Positions already set by an earlier test set with this name are kept
    if (!is.na(summres[rows, "lbl_pos_x2"][1])) {
      next
    }
    summres[rows, "lbl_pos_x2"] <- ts_obj$get_textpos_x2()
    summres[rows, "lbl_pos_y2"] <- ts_obj$get_textpos_y2()
  }

  summres
}

#
# Get base points from test datasets
#
# Returns one row per base point with columns testset, x and y, stacking the
# test sets in the order given.
#
.get_base_points <- function(testset) {
  per_set <- lapply(testset, function(ts_obj) {
    name <- ts_obj$get_tsname()
    xs <- ts_obj$get_basepoints_x()
    ys <- ts_obj$get_basepoints_y()
    data.frame(testset = rep(name, length(xs)), x = xs, y = ys)
  })
  combined <- do.call(rbind, per_set)
  rownames(combined) <- NULL
  combined
}

#
# Predict curves by tools with test datasets
#
# `testset` and `toolset` are parallel lists. The k-th tool is called on the
# k-th test data set, and the x/y values it then reports are collected into
# one data frame, together with the test set, tool set and tool names.
#
.predict_curves <- function(testset, toolset) {
  curves <- lapply(seq_along(testset), function(k) {
    tool <- toolset[[k]]
    tdata <- testset[[k]]
    tool$call(tdata)

    ts_name <- tdata$get_tsname()
    set_name <- tool$get_setname()
    tool_name <- tool$get_toolname()
    xs <- tool$get_x()
    ys <- tool$get_y()
    n_pts <- length(xs)

    data.frame(
      testset = rep(ts_name, n_pts),
      toolset = rep(set_name, n_pts),
      toolname = rep(tool_name, n_pts),
      x = xs,
      y = ys
    )
  })
  stacked <- do.call(rbind, curves)
  rownames(stacked) <- NULL
  stacked
}

#
# Make plot titles
#
# A tool whose name equals its set name is titled by that name alone;
# otherwise the title is "setname:toolname". Duplicate titles are dropped.
#
.make_titles <- function(toolsets) {
  label_for <- function(tool) {
    tool_nm <- tool$get_toolname()
    set_nm <- tool$get_setname()
    if (tool_nm != set_nm) {
      paste(set_nm, tool_nm, sep = ":")
    } else {
      tool_nm
    }
  }

  unique(unlist(lapply(toolsets, label_for)))
}

#
# Validate arguments and return updated arguments
#
# `testset` must be a non-empty list of TestDataC objects and `toolset` a
# non-empty list of ToolIFBase objects. Tools from the "auc5"/"auc4" sets are
# rejected, as are tools whose private `def_store_res` flag is not TRUE.
# Returns list(testset = , toolset = ) with both arguments unchanged.
#
.validate_run_evalcurve_args <- function(testset, toolset) {
  assertthat::assert_that(is.list(testset))
  assertthat::assert_that(length(testset) > 0)
  for (tset in testset) {
    if (!methods::is(tset, "TestDataC")) {
      stop("Invalid testset", call. = FALSE)
    }
  }

  assertthat::assert_that(is.list(toolset))
  assertthat::assert_that(length(toolset) > 0)
  for (tool in toolset) {
    if (!methods::is(tool, "ToolIFBase")) {
      stop("Invalid toolset", call. = FALSE)
    }
    if (tool$get_setname() %in% c("auc5", "auc4")) {
      stop(paste0("Invalid predefined tool set: ", tool$get_setname()),
        call. = FALSE
      )
    }
    # The curve tests read x/y from the tool after tool$call(), presumably
    # requiring the tool to keep its results (store_res). The flag is read
    # from the `private` object in the enclosing environment of clone()
    # (NOTE(review): relies on R6 internals). isTRUE() also rejects a missing
    # (NULL) or NA flag with this message instead of failing inside if().
    if (!isTRUE(environment(tool$clone)$private$def_store_res)) {
      stop(paste0("Incorrect store_res value in ", tool$get_toolname()),
        call. = FALSE
      )
    }
  }

  list(testset = testset, toolset = toolset)
}

Try the prcbench package in your browser

Any scripts or data that you put into this service are public.

prcbench documentation built on March 31, 2023, 5:27 p.m.