R/train.R

Defines functions get_time_series_freq pretty_arguments print.twdtw_knn1 twdtw_knn1

Documented in print.twdtw_knn1 twdtw_knn1

#'
#' Train a KNN-1 TWDTW model
#'
#' This function prepares a KNN-1 model with the Time-Weighted Dynamic Time Warping (TWDTW) algorithm.
#'
#' @param x A three-dimensional stars object (x, y, time) with bands as attributes.
#' @param y An sf object with the coordinates of the training points.
#' @param time_weight A numeric vector with length two (steepness and midpoint of logistic weight) or a function.
#' See details in \link[twdtw]{twdtw}.
#' @param cycle_length The length of the cycle, e.g. phenological cycles. Details in \link[twdtw]{twdtw}.
#' @param time_scale Specifies the time scale for the observations. Details in \link[twdtw]{twdtw}.
#' @param resampling_fun a function specifying how to create temporal patterns using the samples.
#' If not defined, it will keep all samples. Note that reducing the samples to patterns can significantly
#' improve computational time of predictions. The resampling function must receive a single data frame as 
#' argument and return a model. See details.
#' @param start_column Name of the column in y that indicates the start date. Default is 'start_date'.
#' @param end_column Name of the column in y that indicates the end date. Default is 'end_date'.
#' @param label_colum Name of the column in y containing land use labels. Default is 'label'.
#' @param resampling_freq The time for sampling the time series if `resampling_fun` is given.
#' If NULL, the unique observation times of the training samples are used.
#' @param ... Additional arguments passed to \link[twdtw]{twdtw}.
#'
#' @details If \code{resampling_fun} is not provided, the KNN-1 model will retain all training samples.
#'
#' If a custom smoothing function is passed to `resampling_fun`, the function will be used to
#' resample values of samples sharing the same label (land cover class). 
#'
#' The custom smoothing function takes a single data frame and returns a model.
#' The data frame has two named columns:
#' \itemize{
#'   \item `x` is the first column representing the independent variable (typically time).
#'   \item `y` is the second column representing the dependent variable (e.g., band values) corresponding to each coordinate in `x`.
#' }
#'
#' Smoothing the samples can significantly reduce the processing time for prediction using a `twdtw_knn1` model.
#' 
#' See the examples section for further clarity.
#'
#' @return A 'twdtw_knn1' model containing the trained model information and the data used.
#'
#' @examples
#' \dontrun{
#'
#' # Read training samples
#' samples_path <-
#'   system.file("mato_grosso_brazil/samples.gpkg", package = "dtwSat")
#'
#' samples <- st_read(samples_path, quiet = TRUE)
#'
#' # Get satellite image time series files
#' tif_path <- system.file("mato_grosso_brazil", package = "dtwSat")
#' tif_files <- dir(tif_path, pattern = "\\.tif$", full.names = TRUE)
#'
#' # Get acquisition dates
#' acquisition_date <- regmatches(tif_files, regexpr("[0-9]{8}", tif_files))
#' acquisition_date <- as.Date(acquisition_date, format = "%Y%m%d")
#'
#' # Create a 3D datacube
#' dc <- read_stars(tif_files,
#'                  proxy = FALSE,
#'                  along = list(time = acquisition_date),
#'                  RasterIO = list(bands = 1:6))
#' dc <- st_set_dimensions(dc, 3, c("EVI", "NDVI", "RED", "BLUE", "NIR", "MIR"))
#' dc <- split(dc, c("band"))
#'
#' # Create a knn1-twdtw model
#' m <- twdtw_knn1(
#'  x = dc,
#'  y = samples,
#'  resampling_fun = function(data) mgcv::gam(y ~ s(x), data = data),
#'  cycle_length = 'year',
#'  time_scale = 'day',
#'  time_weight = c(steepness = 0.1, midpoint = 50))
#'
#' print(m)
#'
#' # Visualize model patterns
#' plot(m)
#'
#' # Classify satellite images
#' system.time(lu <- predict(dc, model = m))
#'
#' # Visualise land use classification
#' ggplot() +
#'   geom_stars(data = lu) +
#'   theme_minimal()
#'
#'
#' # Create a knn1-twdtw model with custom smoothing function
#'
#' m <- twdtw_knn1(x = dc,
#'  y = samples,
#'  resampling_fun = function(data) lm(y ~ factor(x), data = data),
#'  cycle_length = 'year',
#'  time_scale = 'day',
#'  time_weight = c(steepness = 0.1, midpoint = 50))
#' 
#' plot(m)
#'
#' }
#' @export
twdtw_knn1 <- function(x, y, resampling_fun = NULL, resampling_freq = NULL,
                       time_weight, cycle_length, time_scale,
                       start_column = 'start_date', end_column = 'end_date',
                       label_colum = 'label', ...){

  # Check if x is a three-dimensional stars object with a time dimension.
  # The presence of 'time' is tested by name: indexing a missing dimension
  # yields NA, which would make the condition itself fail with an unrelated
  # "missing value where TRUE/FALSE needed" error.
  x_dims <- if (inherits(x, "stars")) dim(x) else NULL
  if (is.null(x_dims) || length(x_dims) != 3 ||
      !("time" %in% names(x_dims)) || x_dims[["time"]] < 1) {
    stop("x must be a three-dimensional stars object with a 'time' dimension")
  }

  x <- split(x, c("time"))

  # Check if y is an sf object with point geometry
  if (!inherits(y, "sf") || !all(st_is(y, "POINT"))) {
    stop("y must be an sf object with point geometry")
  }

  # Check for the minimum set of twdtw arguments. These arguments have no
  # defaults, so missing() is tested first to report the intended messages
  # instead of R's generic "argument is missing, with no default".
  if (missing(time_weight) ||
      !(is.function(time_weight) ||
        (is.numeric(time_weight) && length(time_weight) == 2))) {
    stop("'time_weight' should be either a function or a numeric vector with length two")
  }
  if (missing(cycle_length) || is.null(cycle_length)) {
    stop("The 'cycle_length' argument is missing.")
  }
  if (missing(time_scale) || is.null(time_scale)) {
    stop("The 'time_scale' argument is missing for 'cycle_length' type character.")
  }

  # Check for required columns in y
  required_columns <- c(start_column, end_column, label_colum)
  missing_columns <- setdiff(required_columns, names(y))
  if (length(missing_columns) > 0) {
    stop(paste("Missing required columns in y:", paste(missing_columns, collapse = ", ")))
  }

  # Normalise the geometry column name and keep only the columns used below
  st_geometry(y) <- 'geom'
  y <- y[, c(start_column, end_column, label_colum, 'geom')]

  # Convert columns to date-time
  y[, start_column] <- to_date_time(y[[start_column]])
  y[, end_column] <- to_date_time(y[[end_column]])

  # Prepare time series samples from stars object
  ts_data <- st_extract(x, y)
  ts_data$label <- y[[label_colum]]
  ts_data <- prepare_time_series(as.data.frame(ts_data))
  ts_data$ts_id <- NULL

  smooth_models <- NULL
  if (!is.null(resampling_fun)) {

    # Shift dates
    ts_data$observations <- lapply(ts_data$observations, shift_ts_dates)

    # Pool all observations that share a label into one data frame per label
    ts_data <- unnest(ts_data, cols = 'observations')
    ts_data <- nest(ts_data, .by = 'label', .key = "observations")

    # Fit one model per label and band using the user-supplied function
    smooth_models <- lapply(ts_data$observations, function(ts) {

      # Separate the timeline from the band columns
      y_time <- ts$time
      ts$time <- NULL

      lapply(as.list(ts), function(band) {
        resampling_fun(data.frame(x = as.numeric(y_time), y = band))
      })

    })

    names(smooth_models) <- ts_data$label

    # Replace the pooled samples of each label by the model predictions
    ts_data$observations <- lapply(seq_along(ts_data$observations), function(l) {

      y_time <- ts_data$observations[[l]]$time

      # Determine the time points at which the pattern is resampled
      if (is.null(resampling_freq)) {
        # Pooled samples are concatenated, so unique() alone returns the times
        # in order of first appearance; sort them so the pattern is
        # chronological, as it is in the seq() branch below.
        pred_time <- sort(unique(y_time))
      } else {
        pred_time <- seq(min(y_time), max(y_time), by = resampling_freq)
      }

      band_predictions <- lapply(smooth_models[[l]], function(m) {
        # The second column of the model frame holds the predictor as the
        # model saw it; new data must be converted to the same type.
        predictor <- model.frame(m)[, 2]
        # is.factor() also covers multi-class predictors such as ordered
        # factors, where comparing class() to a string has length > 1.
        pred_points <- if (is.factor(predictor)) {
          factor(as.numeric(pred_time), levels = levels(predictor))
        } else {
          as(as.numeric(pred_time), class(predictor)[1])
        }
        predict(m, newdata = data.frame(x = pred_points))
      })

      # cbind keeps one column per band even with a single prediction time,
      # where sapply would collapse the result into a plain vector.
      smoothed_data <- do.call(cbind, band_predictions)

      # Bind time and smoothed data into a data frame
      data.frame(time = pred_time, smoothed_data)
    })

  }

  model <- list()
  model$call <- match.call()
  model$smooth_models <- smooth_models
  model$data <- ts_data
  # Default twdtw arguments stored with the model
  model$twdtw_args <- list(time_weight = time_weight,
                           cycle_length = cycle_length,
                           time_scale = time_scale,
                           origin = NULL,
                           max_elapsed = Inf,
                           version = "f90")
  # Only arguments in `...` whose names match a stored twdtw argument
  # override the defaults; any other argument in `...` is not stored.
  new_twdtw_args <- list(...)
  matching_twdtw_args <- intersect(names(model$twdtw_args), names(new_twdtw_args))
  model$twdtw_args[matching_twdtw_args] <- new_twdtw_args[matching_twdtw_args]

  class(model) <- "twdtw_knn1"

  return(model)

}

#' Print method for objects of class twdtw_knn1
#'
#' This method provides a structured printout of the important components
#' of a `twdtw_knn1` object.
#'
#' @param x An object of class `twdtw_knn1`.
#' @param ... ignored
#'
#' @return Invisible `twdtw_knn1` object.
#'
#' @export
print.twdtw_knn1 <- function(x, ...) {
  cat("\nModel of class 'twdtw_knn1'\n")
  cat("-----------------------------\n")

  # The call that created the model
  cat("Call:\n")
  print(x$call)

  # RMSE of one fitted model. Returns NA when residuals cannot be extracted
  # or are not numeric. tryCatch keeps extraction failures quiet, so printing
  # a model does not emit error text for unsupported model types.
  model_rmse <- function(fit) {
    res <- tryCatch(resid(fit), error = function(e) NULL)
    if (is.numeric(res)) sqrt(mean(res^2)) else NA_real_
  }

  # RMSE of the smooth models (per band and label), if available
  cat("\nRoot Mean Squared Error (RMSE) of smooth models:\n")
  if (is.null(x$smooth_models)) {
    print(NULL)
  } else {
    print(sapply(x$smooth_models, function(band_models) {
      sapply(band_models, model_rmse)
    }))
  }

  # The data stored in the model
  cat("\nData:\n")
  print(x$data)

  # The twdtw arguments stored in the model
  cat("\nTWDTW Arguments:\n")
  pretty_arguments(x$twdtw_args)

  invisible(x) # Returns the object invisibly, so it doesn't print twice
}

#' Print Pretty Arguments
#'
#' Display a list of arguments of a given function in a human-readable format.
#'
#' @param args A list of named arguments to display.
#'
#' @return Invisible NULL. The function is mainly used for its side effect of printing.
#'
#' @examples
#' \dontrun{
#' pretty_arguments(formals(twdtw_knn1))
#' }
#'
#' @noRd
#' @keywords internal
pretty_arguments <- function(args) {

  if (is.null(args)) {
    cat("Arguments are missing.\n")
    return(invisible(NULL))
  }

  # Print one " - name: value" line per argument
  for (name in names(args)) {
    default_value <- args[[name]]

    if (is.symbol(default_value)) {
      # Includes the empty symbol of formals without a default
      default_value <- as.character(default_value)

    } else if (is.null(default_value)) {
      default_value <- "NULL"

    } else if (is.function(default_value) || is.language(default_value)) {
      # Functions cannot be pasted into a string and calls would be split
      # into their components, so both are shown as collapsed source text.
      default_value <- paste(trimws(deparse(default_value)), collapse = " ")

    } else if (is.vector(default_value) && !is.null(names(default_value))) {
      # Handle named vectors
      values <- paste(names(default_value), default_value, sep = "=", collapse = ", ")
      default_value <- paste0("c(", values, ")")

    } else if (length(default_value) > 1) {
      # Unnamed multi-element values are collapsed into a single line
      default_value <- paste0("c(", paste(default_value, collapse = ", "), ")")
    }

    cat(paste0(" - ", name, ": ", default_value, "\n"))
  }

  invisible(NULL)
}



#' Compute the Most Common Sampling Frequency across all observations
#'
#' This function calculates the most common difference between consecutive time points.
#' This can be useful for determining the approximate sampling frequency of the time series data.
#'
#' @param x A data frame including a column called `observations` with the time series.
#'
#' @return A difftime object representing the most common time difference between consecutive samples.
#'
#' @noRd
#' @keywords internal
get_time_series_freq <- function(x) {

  # Differences between consecutive time points, computed within each series
  # and converted to days before pooling. unlist() drops the difftime class,
  # so the units must be made consistent first.
  # NOTE(review): non-difftime (plain numeric) differences are taken to be
  # in days already -- confirm against callers.
  time_diffs <- unlist(lapply(x$observations, function(ts) {
    step <- diff(ts$time)
    if (inherits(step, "difftime")) {
      units(step) <- "days"
    }
    as.numeric(step)
  }))
  time_diffs <- time_diffs[!is.na(time_diffs)]

  # No consecutive observations available: the frequency is undefined
  if (length(time_diffs) == 0) {
    return(as.difftime(NA_real_, units = "days"))
  }

  # Identify the mode: the distinct difference that occurs most often
  unique_diffs <- unique(time_diffs)
  mode_val_index <- which.max(tabulate(match(time_diffs, unique_diffs)))

  as.difftime(unique_diffs[mode_val_index], units = "days")

}
vwmaus/dtwSat documentation built on Jan. 28, 2024, 9 a.m.