R/as_precrec.R

Defines functions as_precrec.BenchmarkResult as_precrec.ResampleResult as_precrec.PredictionClassif roc_data as_precrec

Documented in as_precrec as_precrec.BenchmarkResult as_precrec.PredictionClassif as_precrec.ResampleResult

#' @title Convert to 'precrec' Format
#'
#' @description
#' Converts an object into the input format expected by [precrec::evalmod()]
#' from package \CRANpkg{precrec}.
#' Methods are provided for [mlr3::PredictionClassif], [mlr3::ResampleResult]
#' and [mlr3::BenchmarkResult]; all of them require a binary classification
#' problem with predicted probabilities.
#'
#' @param object (`any`)\cr
#'   Object to convert.
#' @return Object as created by [precrec::mmdata()].
#'
#' @references
#' `r format_bib("precrec")`
#' @export
as_precrec = function(object) UseMethod("as_precrec")

# Extract positive-class scores and true labels from a classification
# prediction as a two-column data.table (`scores`, `labels`).
# Errors if the problem is not binary or if no probabilities were predicted.
roc_data = function(prediction) {
  pred = mlr3::as_prediction(prediction)

  is_binary = nlevels(pred$truth) == 2L
  if (!is_binary) {
    stopf("Need a binary classification problem to plot a ROC curve")
  }

  has_prob = "prob" %in% pred$predict_types
  if (!has_prob) {
    stopf("Need predicted probabilities to plot a ROC curve")
  }

  # The first probability column is used as the score; callers treat the
  # first truth level as the positive class.
  # NOTE(review): assumes prob columns are ordered like the truth levels
  # (mlr3 convention) -- confirm.
  pos_prob = pred$prob[, 1L, drop = TRUE]
  data.table(scores = pos_prob, labels = pred$truth)
}

#' @rdname as_precrec
#' @export
as_precrec.PredictionClassif = function(object) { # nolint
  require_namespaces("precrec")

  # A single prediction becomes a single precrec data set (id 1); the first
  # truth level is declared as the positive class.
  roc = roc_data(object)
  positive = levels(object$truth)[1L]

  precrec::mmdata(
    scores = roc$scores,
    labels = roc$labels,
    dsids = 1L,
    posclass = positive
  )
}


#' @rdname as_precrec
#' @export
as_precrec.ResampleResult = function(object) { # nolint
  require_namespaces("precrec")

  # One precrec data set per resampling iteration.
  preds = object$predictions()
  per_iter = map(preds, roc_data)
  # Turn the list of per-iteration tables into list(scores = ..., labels = ...),
  # each holding one vector per iteration.
  columns = transpose_list(per_iter)
  n_iters = length(preds)

  precrec::mmdata(
    scores = columns$scores,
    labels = columns$labels,
    dsids = seq_len(n_iters),
    posclass = object$task$positive
  )
}


#' @rdname as_precrec
#' @export
as_precrec.BenchmarkResult = function(object) { # nolint
  require_namespaces("precrec")
  tab = object$score(measures = list())

  # Resampling iterations are mapped to precrec data set ids, so every row
  # must stem from the same task and the same resampling.
  if (uniqueN(tab$task_id) > 1L) {
    stopf("Unable to convert benchmark results with multiple tasks.")
  }
  if (uniqueN(tab$resampling_id) > 1L) {
    stopf("Unable to convert benchmark results with multiple resamplings.")
  }

  roc = transpose_list(map(tab$prediction, roc_data))

  # Group the per-row vectors by resampling iteration; within each group the
  # learners keep their row order, matching `unique(tab$learner_id)`.
  by_iter = tab$iteration
  grouped_labels = split(roc$labels, by_iter)
  grouped_scores = split(roc$scores, by_iter)

  learner_ids = unique(tab$learner_id)
  iteration_ids = unique(by_iter)
  precrec::mmdata(
    scores = grouped_scores,
    labels = grouped_labels,
    dsids = iteration_ids,
    modnames = learner_ids,
    posclass = object$tasks$task[[1L]]$positive
  )
}

Try the mlr3viz package in your browser

Any scripts or data that you put into this service are public.

mlr3viz documentation built on Nov. 23, 2023, 5:07 p.m.