# R/data_interface.R

# Defines functions .validate_create_usrdata create_usrdata .validate_create_testset_args .create_curvetest .create_benchtest create_testset

# Documented in create_testset create_usrdata

#' Create a list of test datasets
#'
#' The \code{create_testset} function creates test datasets either for
#' benchmarking or curve evaluation.
#'
#' @param test_type A single string to specify the type of dataset generated by
#'   this function.
#'
#'   \describe{
#'     \item{"bench"}{Create test datasets for benchmarking}
#'     \item{"curve"}{Create test datasets for curve evaluation}
#'   }
#'
#' @param set_names A character vector to specify the names of test
#'   datasets.
#'
#' \enumerate{
#'
#'   \item For benchmarking (\code{test_type = "bench"})
#'
#'   This function uses a naming convention for randomly generated data for
#'   benchmarking. The format is a prefix ('i' or 'b') followed by the size of
#'   the dataset. The prefix 'b' indicates a balanced dataset, whereas 'i'
#'   indicates an imbalanced dataset. The number can be used with a suffix 'k'
#'   or 'm', indicating respectively 1000 or 1 million.
#'
#'   Below are some examples.
#'   \describe{
#'     \item{"b100"}{A balanced data set with 50 positives and 50
#'         negatives.}
#'     \item{"b10k"}{A balanced data set with 5000 positives and 5000
#'         negatives.}
#'     \item{"b1m"}{A balanced data set with 500,000 positives and 500,000
#'         negatives.}
#'     \item{"i100"}{An imbalanced data set with 25 positives and 75
#'         negatives.}
#'   }
#'
#'   The function returns a list of \code{\link{TestDataB}} objects.
#'
#'   \item For curve evaluation (\code{test_type = "curve"})
#'
#'   The following four predefined datasets can be specified for curve
#'   evaluation.
#'
#'   \tabular{lll}{
#'     \strong{set name}
#'     \tab \strong{\code{S3} object}
#'     \tab \strong{data source} \cr
#'
#'     c1 or C1 \tab \code{\link{TestDataC}} \tab \code{\link{C1DATA}}   \cr
#'     c2 or C2 \tab \code{\link{TestDataC}} \tab \code{\link{C2DATA}}   \cr
#'     c3 or C3 \tab \code{\link{TestDataC}} \tab \code{\link{C3DATA}}   \cr
#'     c4 or C4 \tab \code{\link{TestDataC}} \tab \code{\link{C4DATA}}
#'   }
#'
#'   The function returns a list of \code{\link{TestDataC}} objects.
#' }
#'
#' @return A list of \code{R6} test dataset objects.
#'
#' @seealso \code{\link{run_benchmark}} and \code{\link{run_evalcurve}} require
#'  the list of the datasets generated by this function.
#'  \code{\link{TestDataB}} for benchmarking test data.
#'  \code{\link{TestDataC}}, \code{\link{C1DATA}}, \code{\link{C2DATA}},
#'  \code{\link{C3DATA}}, and \code{\link{C4DATA}} for curve evaluation
#'    test data.
#'  \code{\link{create_usrdata}} for creating a user-defined test set.
#'
#' @examples
#' ## Create a balanced data set with 50 positives and 50 negatives
#' tset1 <- create_testset("bench", "b100")
#' tset1
#'
#' ## Create an imbalanced data set with 25 positives and 75 negatives
#' tset2 <- create_testset("bench", "i100")
#' tset2
#'
#' ## Create C1 dataset
#' tset3 <- create_testset("curve", "c1")
#' tset3
#'
#' ## Create C1 and C2 datasets
#' tset4 <- create_testset("curve", c("c1", "c2"))
#' tset4
#'
#' @export
create_testset <- function(test_type, set_names = NULL) {
  # Expand test_type to "bench"/"curve" and lower-case the set names
  checked <- .validate_create_testset_args(test_type, set_names)
  snames <- checked$set_names

  # Pick the builder that matches the requested test type
  if (checked$test_type == "bench") {
    make_set <- .create_benchtest
  } else if (checked$test_type == "curve") {
    make_set <- .create_curvetest
  }

  # Build one dataset per set name and key the list by those names
  dsets <- lapply(snames, function(sname) make_set(sname))
  names(dsets) <- snames
  dsets
}

#
# Create a random sample dataset
#
# When `sname` is given it overrides `np` and `nn`. The name is a prefix
# ('i' = 25% positives, 'b' = 50% positives, 'r' = uniformly random ratio)
# followed by a number, optionally suffixed with 'k' (x 1000) or 'm'
# (x 1,000,000), e.g. "b100", "i10k", "r1m".
#
# `pfunc` / `nfunc` take a count n and return n positive / negative scores;
# they default to Beta(1, 1) and Beta(1, 4) samplers respectively.
#
# Returns a TestDataB object created with the scores, 1/0 labels and
# `sname` (as character) as its name.
#
.create_benchtest <- function(sname = NULL, np = 10, pfunc = NULL, nn = 10,
                              nfunc = NULL) {
  # Calculate np and nn when sname is specified
  if (!is.null(sname)) {
    lname <- tolower(sname)

    # Drop the prefix/suffix letters and parse the rest as the size
    tot <- suppressWarnings(as.numeric(gsub("[ibrkm]", "", lname)))
    if (grepl("k$", lname)) {
      tot <- tot * 1000
    } else if (grepl("m$", lname)) {
      tot <- tot * 1000 * 1000
    }
    if (is.na(tot)) {
      stop("Invalid set_names. Check the naming convention.", call. = FALSE)
    }
    if (tot < 2) {
      stop("Invalid set_names. Data set size must be >1.",
        call. = FALSE
      )
    }

    if (grepl("^i", lname)) {
      posratio <- 0.25
    } else if (grepl("^b", lname)) {
      posratio <- 0.5
    } else if (grepl("^r", lname)) {
      posratio <- stats::runif(1)
    } else {
      stop("Invalid set_names. Check the naming convention.", call. = FALSE)
    }

    # Rounding can leave one class empty: "i2" gives round(0.5) == 0
    # positives, and an 'r' ratio close to 0 or 1 can do the same. Keep at
    # least one positive and one negative so both labels are present.
    np <- min(max(round(tot * posratio), 1), tot - 1)
    nn <- tot - np
  }

  # Sample positive scores
  if (is.null(pfunc)) {
    pfunc <- function(n) stats::rbeta(n, shape1 = 1, shape2 = 1)
  }

  # Sample negative scores
  if (is.null(nfunc)) {
    nfunc <- function(n) stats::rbeta(n, shape1 = 1, shape2 = 4)
  }

  # Create scores and labels (positives first, then negatives)
  scores <- c(pfunc(np), nfunc(nn))
  labels <- c(rep(1, np), rep(0, nn))

  # Create a TestDataB object
  TestDataB$new(scores, labels, as.character(sname))
}

#
# Build a TestDataC object from one of the predefined curve datasets
# (C1DATA to C4DATA). `sname` is matched case-insensitively; any other
# name is an error.
#
.create_curvetest <- function(sname) {
  # Only the matching arm of switch() is evaluated
  pdata <- switch(tolower(sname),
    c1 = prcbench::C1DATA,
    c2 = prcbench::C2DATA,
    c3 = prcbench::C3DATA,
    c4 = prcbench::C4DATA,
    stop("Invalid dataset name", call. = FALSE)
  )

  # Attach the pre-calculated base points and text positions
  tdata <- TestDataC$new(pdata$scores, pdata$labels, sname)
  tdata$set_basepoints_x(pdata$bp_x)
  tdata$set_basepoints_y(pdata$bp_y)
  tdata$set_textpos_x(pdata$tp_x)
  tdata$set_textpos_y(pdata$tp_y)
  tdata$set_textpos_x2(pdata$tp_x2)
  tdata$set_textpos_y2(pdata$tp_y2)

  tdata
}

#
# Validate arguments and return updated arguments
#
# test_type is expanded from a partial match to "bench" or "curve", and
# set_names are lower-cased. A benchmark name must start with 'i', 'b' or
# 'r' (the prefixes .create_benchtest understands) and must parse as a
# number once those letters and a 'k'/'m' suffix are removed. Curve names
# must be one of "c1" to "c4".
#
.validate_create_testset_args <- function(test_type, set_names) {
  assertthat::assert_that(assertthat::is.string(test_type))
  if (!is.na(pmatch(test_type, "bench"))) {
    test_type <- "bench"
  } else if (!is.na(pmatch(test_type, "curve"))) {
    test_type <- "curve"
  } else {
    stop("Invalid test_type. It must be either 'bench' or 'curve'.",
      call. = FALSE
    )
  }

  if (!is.null(set_names)) {
    set_names <- tolower(set_names)
    if (test_type == "bench") {
      for (sname in set_names) {
        assertthat::assert_that(assertthat::is.string(sname))
        # "[ibrkm]" is a character class; a '|' inside brackets would be a
        # literal pipe character, not alternation
        cnum <- gsub("[ibrkm]", "", sname)
        valid_prefix <- grepl("^[ibr]", sname)
        if (!valid_prefix || suppressWarnings(is.na(as.numeric(cnum)))) {
          stop("Invalid set_names. Check the naming convention",
            call. = FALSE
          )
        }
      }
    } else if (test_type == "curve") {
      c_set_names <- c("c1", "c2", "c3", "c4")
      if (length(setdiff(set_names, c_set_names)) != 0) {
        stop("Invalid set_names. Valid set_names are 'c1', 'c2', 'c3' or 'c4'.",
          call. = FALSE
        )
      }
    }
  }

  list(test_type = test_type, set_names = set_names)
}

#' Create a user-defined test dataset
#'
#' The \code{create_usrdata} function creates various types of test datasets.
#'
#' @param test_type A single string to specify the type of dataset generated by
#'   this function.
#'
#'   \describe{
#'     \item{"bench"}{Create a test dataset for benchmarking}
#'     \item{"curve"}{Create a test dataset for curve evaluation}
#'   }
#'
#' @param scores A numeric vector to set scores.
#'
#' @param labels A numeric vector to set labels.
#'
#' @param tsname A single string to specify the name of the dataset. Defaults
#'   to \code{"usr"} when \code{NULL}.
#'
#' @param base_x A numeric vector to set pre-calculated recall values for
#'    curve evaluation.
#'
#' @param base_y A numeric vector to set pre-calculated precision values for
#'    curve evaluation.
#'
#' @param text_x A single numeric value to set the x position for displaying
#'    the test result in a plot.
#'
#' @param text_y A single numeric value to set the y position for displaying
#'    the test result in a plot.
#'
#' @param text_x2 A single numeric value to set the x position for displaying
#'    the test result (grouped into categories) in a plot. Defaults to
#'    \code{text_x}.
#'
#' @param text_y2 A single numeric value to set the y position for displaying
#'    the test result (grouped into categories) in a plot. Defaults to
#'    \code{text_y}.
#'
#' @return A list containing one \code{R6} test dataset object, named by
#'   \code{tsname}.
#'
#' @seealso \code{\link{create_testset}} for creating a predefined test set.
#'  \code{\link{TestDataB}} for benchmarking test data.
#'  \code{\link{TestDataC}} for curve evaluation test data.
#'
#' @examples
#' ## Create a test dataset for benchmarking
#' testset2 <- create_usrdata("bench",
#'   scores = c(0.1, 0.2), labels = c(1, 0),
#'   tsname = "m1"
#' )
#' testset2
#'
#' ## Create a test dataset for curve evaluation
#' testset <- create_usrdata("curve",
#'   scores = c(0.1, 0.2), labels = c(1, 0),
#'   base_x = c(0, 1.0), base_y = c(0, 0.5)
#' )
#' testset
#'
#' @export
create_usrdata <- function(test_type, scores = NULL, labels = NULL,
                           tsname = NULL, base_x = NULL, base_y = NULL,
                           text_x = NULL, text_y = NULL,
                           text_x2 = text_x, text_y2 = text_y) {
  # Check every argument; this also fills in the default dataset name
  args <- .validate_create_usrdata(
    test_type, scores, labels, tsname, base_x, base_y,
    text_x, text_y, text_x2, text_y2
  )

  if (args$test_type == "curve") {
    tdata <- TestDataC$new(args$scores, args$labels, args$tsname)
    tdata$set_basepoints_x(args$base_x)
    tdata$set_basepoints_y(args$base_y)

    # Optional text positions: call each setter only when a value was given
    textpos <- list(
      set_textpos_x = args$text_x,
      set_textpos_y = args$text_y,
      set_textpos_x2 = args$text_x2,
      set_textpos_y2 = args$text_y2
    )
    for (setter in names(textpos)) {
      pos <- textpos[[setter]]
      if (!is.null(pos)) {
        tdata[[setter]](pos)
      }
    }
  } else if (args$test_type == "bench") {
    tdata <- TestDataB$new(args$scores, args$labels, args$tsname)
  }

  # Wrap the single dataset in a list keyed by its name
  dsets <- list(tdata)
  names(dsets) <- args$tsname
  dsets
}

#
# Validate arguments and return updated arguments
#
# test_type is expanded from a partial match to "bench" or "curve", and
# tsname defaults to "usr". Scores must be numeric, labels numeric or
# factor with exactly two distinct values, and both must have the same
# length (> 1). For "curve", base_x/base_y must be equal-length numeric
# vectors in [0, 1], and every supplied text position must be a single
# number in [0, 1].
#
.validate_create_usrdata <- function(test_type, scores, labels, tsname, base_x,
                                     base_y, text_x, text_y, text_x2, text_y2) {
  assertthat::assert_that(assertthat::is.string(test_type))
  if (!is.na(pmatch(test_type, "bench"))) {
    test_type <- "bench"
  } else if (!is.na(pmatch(test_type, "curve"))) {
    test_type <- "curve"
  } else {
    stop("Invalid test_type. It must be either 'bench' or 'curve'.",
      call. = FALSE
    )
  }

  assertthat::assert_that(is.numeric(scores))
  assertthat::assert_that(length(scores) > 1)

  assertthat::assert_that(is.numeric(labels) || is.factor(labels))
  assertthat::assert_that(length(labels) > 1)
  assertthat::assert_that(length(unique(labels)) == 2)

  assertthat::assert_that(length(scores) == length(labels))

  if (is.null(tsname)) {
    tsname <- "usr"
  }
  assertthat::assert_that(assertthat::is.string(tsname))

  if (test_type == "curve") {
    assertthat::assert_that(is.numeric(base_x))
    assertthat::assert_that(all(base_x >= 0.0) && all(base_x <= 1.0))
    assertthat::assert_that(is.numeric(base_y))
    assertthat::assert_that(all(base_y >= 0.0) && all(base_y <= 1.0))
    assertthat::assert_that(length(base_x) == length(base_y))

    # Iterate over a list, not c(): c() drops NULLs and flattens vectors,
    # so c(0.2, 0.8) (or TRUE mixed with numbers) would be split into
    # elements that each pass is.number() and slip through validation.
    for (p in list(text_x, text_y, text_x2, text_y2)) {
      if (!is.null(p)) {
        assertthat::assert_that(assertthat::is.number(p))
        assertthat::assert_that(p >= 0.0 && p <= 1.0)
      }
    }
  }

  list(
    test_type = test_type, scores = scores, labels = labels, tsname = tsname,
    base_x = base_x, base_y = base_y, text_x = text_x, text_y = text_y,
    text_x2 = text_x2, text_y2 = text_y2
  )
}

# Try the prcbench package in your browser

# Any scripts or data that you put into this service are public.

# prcbench documentation built on March 31, 2023, 5:27 p.m.