R/predict.R

Defines functions predict.twdtw_knn1

Documented in predict.twdtw_knn1

#' Predict using the twdtw_knn1 model
#'
#' This function predicts the classes of new data using the Time-Weighted Dynamic Time Warping
#' (TWDTW) method with a 1-nearest neighbor approach. Each new observation is assigned the label
#' of the known pattern stored in the `twdtw_knn1` model that has the minimum TWDTW distance to it.
#'
#' @param object A `twdtw_knn1` model object generated by the `twdtw_knn1` function.
#' @param newdata A data frame or similar object containing the new observations
#' (time series data) to be predicted.
#' @param ... Additional named arguments passed to the \link[twdtw]{twdtw} function.
#' If provided, they will overwrite twdtw arguments previously passed to \link[dtwSat]{twdtw_knn1};
#' named arguments that were not set there are added. Unnamed arguments are ignored.
#'
#' @return A factor of predicted classes with one element per observation in `newdata`.
#' Its levels are the class levels of the training labels, so the predictions can be compared
#' directly with a factor of true labels. An element is `NA` when the model holds no patterns or
#' when every distance computed for that observation is `NA`/`NaN`.
#'
#' @seealso twdtw_knn1
#'
#' @inherit twdtw_knn1 examples
#'
#' @export
predict.twdtw_knn1 <- function(object, newdata, ...) {

  # Merge arguments passed via ... into the twdtw arguments stored in the
  # model: existing entries are overwritten and new named entries are added.
  # Unnamed arguments cannot be matched to a twdtw parameter and are dropped.
  twdtw_args <- object$twdtw_args
  new_twdtw_args <- list(...)
  arg_names <- names(new_twdtw_args)
  if (!is.null(arg_names)) {
    new_twdtw_args <- new_twdtw_args[nzchar(arg_names)]
    twdtw_args[names(new_twdtw_args)] <- new_twdtw_args
  }

  # Convert newdata to time series
  newdata_ts <- prepare_time_series(newdata)
  new_obs <- newdata_ts$observations
  patterns <- object$data$observations

  # Distance matrix: one row per new observation, one column per pattern.
  # It is pre-allocated rather than produced by nested sapply() because
  # sapply() collapses to a vector for a single observation (and to a list
  # for empty input), which breaks the row-wise nearest-neighbour search.
  distances <- matrix(NA_real_, nrow = length(new_obs), ncol = length(patterns))
  for (i in seq_along(patterns)) {
    pattern <- as.data.frame(patterns[[i]])
    for (j in seq_along(new_obs)) {
      distances[j, i] <- do.call(
        proxy::dist,
        c(
          list(x = as.data.frame(new_obs[[j]]), y = pattern, method = "twdtw"),
          twdtw_args
        )
      )
    }
  }

  # Column index of the nearest pattern for each new observation.
  # which.min() discards NA/NaN and returns integer(0) when nothing is
  # left; map that case to NA so the result stays one integer per row.
  nearest_neighbor <- vapply(seq_len(nrow(distances)), function(j) {
    idx <- which.min(distances[j, ])
    if (length(idx) == 0L) NA_integer_ else idx
  }, integer(1))

  # Keep every training class as a level so predictions can be compared
  # with, or tabulated against, a factor of true labels.
  labels <- object$data$label
  factor(labels[nearest_neighbor], levels = levels(as.factor(labels)))
}
vwmaus/dtwSat documentation built on Jan. 28, 2024, 9 a.m.