R/threshold_strategies.R

Defines functions threshold_strategies

Documented in threshold_strategies

#' Subset an rtrack_strategies object.
#'
#' Subsets strategy calls based on a confidence threshold.
#'
#' For strategy-calling algorithms yielding a confidence score (such as
#' \code{\link{call_strategy}}), a value between 0 and 1 will return a new
#' \code{rtrack_strategies} object only including calls with a confidence score
#' at or above the given threshold.
#'
#' @param strategies An \code{rtrack_strategies} object as generated by
#'   \code{\link{call_strategy}}.
#' @param threshold A single numeric value between 0 and 1. Values above 1 are
#'   clamped to 1 and values below 0 are clamped to 0, with a warning. A
#'   \code{NULL}, \code{NA}, non-scalar, or non-convertible value is treated as
#'   0 (no thresholding), also with a warning.
#'
#' @return An \code{rtrack_strategies} object including only calls whose
#'   confidence is at or above \code{threshold}; calls with a missing
#'   (\code{NA}) confidence are dropped. In addition, the component
#'   \code{thresholded} is set to \code{TRUE} if thresholding was performed
#'   (i.e. if the effective threshold is greater than 0).
#'
#' @examples
#' library(Rtrack)
#' track_file <- system.file("extdata", "Track_1.tab", package = "Rtrack")
#' arena_description <- system.file("extdata", "Arena.txt", package = "Rtrack")
#' arena <- read_arena(arena_description)
#' path <- read_path(track_file, arena, track.format = "raw.tab")
#' metrics <- calculate_metrics(path, arena)
#' strategies <- call_strategy(metrics)
#' # Inspect the strategy call (minimal experiment only has one track)
#' strategies$calls
#' # Thresholding at 0.7 will retain the track (confidence = 0.72)
#' strategies = threshold_strategies(strategies, threshold = 0.7)
#' strategies$calls
#' # Thresholding at 0.8 will discard the track, still returning an (empty) rtrack_strategies object
#' strategies = threshold_strategies(strategies, threshold = 0.8)
#' strategies$calls
#'
#' @importFrom methods is
#' @importFrom stats predict sd
#' @importFrom utils data
#' @import randomForest
#'
#' @export
threshold_strategies = function(strategies, threshold = NULL) {
	if(!inherits(strategies, "rtrack_strategies")){
		stop("Supplied parameter 'strategies' must be a 'rtrack_strategies' object.")
	}
	# Thresholding is only supported for strategies called with method "rtrack".
	# isTRUE() turns a missing or malformed 'method' into this clean error
	# instead of an "argument is of length zero" failure inside if().
	if(!isTRUE(strategies$method == "rtrack")){
		stop("Invalid 'rtrack_strategies' object.")
	}

	# Validate before comparing: if() errors on length-zero (NULL) or NA
	# conditions, so those cases must be handled before any range check.
	threshold = as.numeric(threshold)
	if(length(threshold) != 1 || is.na(threshold)){
		threshold = 0
		warning("The parameter 'threshold' should be in the range 0 - 1.")
	}else if(threshold > 1){
		threshold = 1
		warning("The parameter 'threshold' should be in the range 0 - 1.")
	}else if(threshold < 0){
		threshold = 0
		warning("The parameter 'threshold' should be in the range 0 - 1.")
	}

	# A threshold of 0 keeps everything, so the object is returned unchanged.
	if(threshold > 0){
		# which() discards calls whose confidence is NA, rather than producing
		# all-NA rows in 'calls' and NULL elements in 'tracks'.
		keep = which(strategies$calls$confidence >= threshold)
		# 'tracks' and 'calls' are subset with the same index, so they are
		# expected to be in the same order.
		strategies$tracks = strategies$tracks[keep]
		# drop = FALSE keeps 'calls' a data frame even if it has a single column.
		strategies$calls = strategies$calls[keep, , drop = FALSE]
		strategies$thresholded = TRUE
	}

	class(strategies) = "rtrack_strategies"
	return(strategies)
}

Try the Rtrack package in your browser

Any scripts or data that you put into this service are public.

Rtrack documentation built on Aug. 10, 2023, 9:10 a.m.