R/PredictionDataClassif.R

Defines functions filter_prediction_data.PredictionDataClassif c.PredictionDataClassif is_missing_prediction_data.PredictionDataClassif check_prediction_data.PredictionDataClassif

Documented in check_prediction_data.PredictionDataClassif c.PredictionDataClassif is_missing_prediction_data.PredictionDataClassif

#' @rdname PredictionData
#' @param train_task ([Task])\cr
#'   Task used for training the learner.
#' @export
check_prediction_data.PredictionDataClassif = function(pdata, train_task, ...) { # nolint
  # Validates and normalizes classification prediction data and returns the
  # (possibly modified) `pdata`:
  # * `row_ids` are validated; `truth` (optional) must be a factor with one
  #   entry per row id.
  # * Without `truth`, a factor of missing values is created, using the
  #   target levels of `train_task`.
  # * `response` is coerced to a factor with the levels of `truth`.
  # * `prob` must be a matrix with values in [0, 1] that passes
  #   `assert_row_sums()`. Class columns missing from it are added with
  #   probability 0, and its columns are reordered to the levels of `truth`.
  # * If only `prob` is given, `response` is derived from the column with the
  #   highest probability (ties broken at random).
  pdata$row_ids = assert_row_ids(pdata$row_ids)
  n = length(pdata$row_ids)
  assert_factor(pdata$truth, len = n, null.ok = TRUE)

  # unsupervised task: no ground truth available, so fill `truth` with one
  # missing value per observation, using the target levels of the training
  # task (an empty factor if there are no observations)
  if (is.null(pdata$truth)) {
    lvls = train_task$col_info[train_task$target_names, get("levels"), on = "id"][[1]]
    pdata$truth = factor(rep(NA_character_, n), levels = lvls)
  }
  lvls = levels(pdata$truth)

  if (!is.null(pdata$response)) {
    pdata$response = assert_factor(as_factor(unname(pdata$response), levels = lvls))
    assert_prediction_count(length(pdata$response), n, "response")
  }

  if (!is.null(pdata$prob)) {
    prob = assert_matrix(pdata$prob)
    assert_prediction_count(nrow(prob), n, "prob")
    assert_numeric(prob, lower = 0, upper = 1)
    assert_row_sums(prob)

    if (!identical(colnames(prob), lvls)) {
      # columns are matched to classes by name; without names every original
      # column would be dropped by the reordering below
      if (is.null(colnames(prob))) {
        stopf("Predicted probabilities must have column names matching the class levels")
      }
      assert_subset(colnames(prob), lvls)

      # add missing columns with prob == 0; use the actual number of rows of
      # `prob`, which is not guaranteed to equal `n`
      miss = setdiff(lvls, colnames(prob))
      if (length(miss)) {
        prob = cbind(prob, matrix(0, nrow = nrow(prob), ncol = length(miss), dimnames = list(NULL, miss)))
      }

      # reorder columns to match the level order of `truth`
      prob = prob[, reorder_vector(colnames(prob), lvls), drop = FALSE]
    }

    if (!is.null(rownames(prob))) {
      rownames(prob) = NULL
    }

    pdata$prob = prob

    if (is.null(pdata$response)) {
      # calculate response from prob (ties are broken at random)
      i = max.col(prob, ties.method = "random")
      pdata$response = factor(colnames(prob)[i], levels = lvls)
    }
  }

  pdata
}



#' @rdname PredictionData
#' @export
is_missing_prediction_data.PredictionDataClassif = function(pdata, ...) { # nolint
  # Returns the row ids of observations whose prediction is missing. A
  # prediction is missing if its `response` is NA, or if any entry in its
  # row of the `prob` matrix is NA. Without `response` and `prob`, no row is
  # considered missing.
  miss = logical(length(pdata$row_ids))
  if (!is.null(pdata$response)) {
    miss = is.na(pdata$response)
  }
  if (!is.null(pdata$prob)) {
    # vectorized per-row NA check: same result as apply(prob, 1L, anyMissing)
    # without one R function call per row
    miss = miss | (rowSums(is.na(pdata$prob)) > 0L)
  }

  pdata$row_ids[miss]
}

#' @rdname PredictionData
#'
#' @param keep_duplicates (`logical(1)`)
#'   If `FALSE`, the combined [PredictionData] object is filtered for duplicated
#'   row ids (starting from last), i.e. only the last prediction for each row id
#'   is kept. If `TRUE` (default), all predictions are kept.
#' @param ... (one or more [PredictionData] objects).
#' @export
c.PredictionDataClassif = function(..., keep_duplicates = TRUE) {
  # Row-binds classification prediction data objects. All objects must carry
  # the same set of predict types.
  dots = list(...)
  assert_list(dots, "PredictionDataClassif")
  assert_flag(keep_duplicates)
  if (length(dots) == 1L) {
    return(dots[[1L]])
  }

  predict_types = names(mlr_reflections$learner_predict_types$classif)
  predict_types = map(dots, function(x) intersect(names(x), predict_types))
  if (!every(predict_types[-1L], setequal, y = predict_types[[1L]])) {
    stopf("Cannot rbind predictions: Different predict types")
  }

  # per-observation vectors are combined into a table; the probability matrix
  # (if any) is row-bound separately
  elems = c("row_ids", "truth", intersect(predict_types[[1L]], "response"))
  tab = map_dtr(dots, function(x) x[elems], .fill = FALSE)
  prob = do.call(rbind, map(dots, "prob"))

  if (!keep_duplicates) {
    # `fromLast = TRUE` flags all but the last occurrence of each row id
    keep = !duplicated(tab, by = "row_ids", fromLast = TRUE)
    tab = tab[keep]
    prob = prob[keep, , drop = FALSE]
  }

  result = as.list(tab)
  result$prob = prob
  new_prediction_data(result, "classif")
}

#' @export
filter_prediction_data.PredictionDataClassif = function(pdata, row_ids, ...) {
  # Restricts `pdata` to the observations whose row id occurs in `row_ids`.
  # All per-observation elements are subset with the same logical selector;
  # the probability matrix keeps its matrix shape even for a single row.
  sel = pdata$row_ids %in% row_ids

  pdata$row_ids = pdata$row_ids[sel]
  pdata$truth = pdata$truth[sel]

  # optional elements are only touched when present
  resp = pdata$response
  if (!is.null(resp)) {
    pdata$response = resp[sel]
  }

  pr = pdata$prob
  if (!is.null(pr)) {
    pdata$prob = pr[sel, , drop = FALSE]
  }

  pdata
}
mlr-org/mlr3 documentation built on July 10, 2024, 10:53 a.m.