R/sits_combine_predictions.R

Defines functions sits_combine_predictions.default sits_combine_predictions.uncertainty sits_combine_predictions.average sits_combine_predictions

Documented in sits_combine_predictions sits_combine_predictions.average sits_combine_predictions.default sits_combine_predictions.uncertainty

#' @title Estimate ensemble prediction based on list of probs cubes
#'
#' @name  sits_combine_predictions
#'
#' @author Gilberto Camara, \email{gilberto.camara@@inpe.br}
#' @author Rolf Simoes, \email{rolf.simoes@@inpe.br}
#'
#' @param  cubes         List of probability data cubes (class "probs_cube")
#' @param  type          Method used to combine the predictions. One of
#'                       "average" or "uncertainty"
#' @param  ...           Parameters for specific functions.
#' @param  weights       Weights for averaging (numeric vector).
#' @param  uncert_cubes  Uncertainty cubes to be used as local weights when
#'                       type = "uncertainty" is selected
#'                       (list of tibbles with class "uncertainty_cube")
#' @param  memsize       Memory available for classification in GB
#'                       (integer, min = 1, max = 16384).
#' @param  multicores    Number of cores to be used for classification
#'                       (integer, min = 1, max = 2048).
#' @param  output_dir    Valid directory for output file.
#'                       (character vector of length 1).
#' @param  version       Version of the output
#'                      (character vector of length 1).
#' @return A combined probability cube (tibble of class "probs_cube").
#'
#' @description Calculate an ensemble predictor based on a list of probability
#' cubes. The function combines the output of two or more classifiers
#' to derive a value which is based on weights assigned to each model.
#' The supported types of ensemble predictors are 'average' and
#' 'uncertainty'.
#'
#' @examples
#' if (sits_run_examples()) {
#'     # create a data cube from local files
#'     data_dir <- system.file("extdata/raster/mod13q1", package = "sits")
#'     cube <- sits_cube(
#'         source = "BDC",
#'         collection = "MOD13Q1-6",
#'         data_dir = data_dir
#'     )
#'     # create a random forest model
#'     rfor_model <- sits_train(samples_modis_ndvi, sits_rfor())
#'     # classify a data cube using rfor model
#'     probs_rfor_cube <- sits_classify(
#'         data = cube, ml_model = rfor_model, output_dir = tempdir(),
#'         version = "rfor"
#'     )
#'     # create an SVM model
#'     svm_model <- sits_train(samples_modis_ndvi, sits_svm())
#'     # classify a data cube using SVM model
#'     probs_svm_cube <- sits_classify(
#'         data = cube, ml_model = svm_model, output_dir = tempdir(),
#'         version = "svm"
#'     )
#'     # create a list of predictions to be combined
#'     pred_cubes <- list(probs_rfor_cube, probs_svm_cube)
#'     # combine predictions
#'     comb_probs_cube <- sits_combine_predictions(
#'         pred_cubes,
#'         output_dir = tempdir()
#'     )
#'     # plot the resulting combined prediction cube
#'     plot(comb_probs_cube)
#' }
#' @export
sits_combine_predictions <- function(cubes, type = "average", ...,
                                     memsize = 8L, multicores = 2L,
                                     output_dir, version = "v1") {
    # S3 generic: validates the input list, then dispatches on `type`.
    # All probs cubes in the list must pass the shared organization check.
    .check_probs_cube_lst(cubes)
    # Tag `type` with its own value as S3 class, so that UseMethod()
    # selects sits_combine_predictions.<type> (or the default method).
    class(type) <- type
    UseMethod("sits_combine_predictions", type)
}

#' @rdname sits_combine_predictions
#' @export
sits_combine_predictions.average <- function(cubes,
                                             type = "average", ...,
                                             weights = NULL,
                                             memsize = 8L,
                                             multicores = 2L,
                                             output_dir,
                                             version = "v1") {
    # Combine a list of probability cubes using a weighted average.
    # When `weights` is NULL, every cube gets the same weight
    # (1 / number of cubes). Extra arguments in `...` are forwarded
    # to `.comb()`, whose result is returned.

    # Check memsize
    .check_memsize(memsize, min = 1, max = 16384)
    # Check multicores
    .check_multicores(multicores, min = 1, max = 2048)
    # Check output dir
    .check_output_dir(output_dir)
    # Check version
    version <- .check_version(version)
    # version is case-insensitive in sits
    version <- tolower(version)
    # Get weights (equal weights unless the caller supplied them)
    n_inputs <- length(cubes)
    weights <- .default(weights, rep(1 / n_inputs, n_inputs))
    .check_that(
        length(weights) == n_inputs,
        msg = "number of weights does not match number of inputs"
    )
    # Compare the sum against 1 with a numeric tolerance: an exact `==`
    # on doubles may reject valid weights whose sum differs from 1 only
    # by floating-point rounding error. isTRUE() also turns a
    # non-numeric or NA result into a plain FALSE.
    .check_that(
        isTRUE(all.equal(sum(weights), 1)),
        msg = "weights should add up to 1.0"
    )
    # Get combine function
    comb_fn <- .comb_fn_average(cubes, weights = weights)
    # Call combine predictions (no uncertainty cubes for plain averaging)
    probs_cube <- .comb(
        probs_cubes = cubes,
        uncert_cubes = NULL,
        comb_fn = comb_fn,
        band = "probs",
        memsize = memsize,
        multicores = multicores,
        output_dir = output_dir,
        version = version,
        progress = FALSE, ...
    )
    return(probs_cube)
}

#' @rdname sits_combine_predictions
#' @export
sits_combine_predictions.uncertainty <- function(cubes,
                                                 type = "uncertainty", ...,
                                                 uncert_cubes,
                                                 memsize = 8L,
                                                 multicores = 2L,
                                                 output_dir,
                                                 version = "v1") {
    # Combine a list of probability cubes, passing the matching
    # uncertainty cubes on to `.comb()`. Extra arguments in `...` are
    # forwarded to `.comb()` as well.

    # Validate processing parameters
    .check_memsize(memsize, min = 1, max = 16384)
    .check_multicores(multicores, min = 1, max = 2048)
    # Validate output location
    .check_output_dir(output_dir)
    # Validate version; sits treats it as case-insensitive
    version <- tolower(.check_version(version))
    # There must be one uncertainty cube per probability cube
    n_probs <- length(cubes)
    n_uncert <- length(uncert_cubes)
    .check_that(
        n_probs == n_uncert,
        local_msg = "uncert_cubes must have same length of cubes",
        msg = "invalid uncert_cubes parameter"
    )
    # Validate the uncertainty list, then match the first cube of each list
    .check_uncert_cube_lst(uncert_cubes)
    .check_cubes_match(cubes[[1]], uncert_cubes[[1]])
    # Build the combination function for the uncertainty method
    combine_fn <- .comb_fn_uncertainty(cubes)
    # Run the combination and hand back its result
    combined_cube <- .comb(
        probs_cubes = cubes,
        uncert_cubes = uncert_cubes,
        comb_fn = combine_fn,
        band = "probs",
        memsize = memsize,
        multicores = multicores,
        output_dir = output_dir,
        version = version,
        progress = FALSE, ...
    )
    combined_cube
}
#' @rdname sits_combine_predictions
#' @export
sits_combine_predictions.default <- function(cubes, type, ...) {
    # Fallback S3 method: used when dispatch finds no method for the
    # class attached to `type`; it always signals an error.
    error_msg <- "Invalid method for combining predictions"
    stop(error_msg)
}
e-sensing/sits documentation built on Jan. 28, 2024, 6:05 a.m.