#' Predict using the twdtw_knn1 model
#'
#' This function predicts the classes of new data using the Time-Weighted
#' Dynamic Time Warping (TWDTW) method with a 1-nearest neighbor approach.
#' Each new observation receives the label of the known pattern, stored in the
#' `twdtw_knn1` model, with the minimum TWDTW distance to it.
#'
#' @param object A `twdtw_knn1` model object generated by the `twdtw_knn1`
#' function.
#' @param newdata A data frame or similar object containing the new
#' observations (time series data) to be predicted.
#' @param ... Additional arguments passed to the \link[twdtw]{twdtw} function.
#' If provided, they will overwrite twdtw arguments previously passed to
#' \link[dtwSat]{twdtw_knn1}. Only arguments whose names are already stored in
#' the model are used; arguments with other names are ignored.
#'
#' @return A factor of predicted classes, one element per observation in
#' `newdata`. Its levels are all labels stored in the model (not only the
#' predicted ones), so predictions from the same model share one level set.
#' Observations whose distances to every pattern are `NA` are predicted as `NA`.
#'
#' @seealso twdtw_knn1
#'
#' @inherit twdtw_knn1 examples
#'
#' @export
predict.twdtw_knn1 <- function(object, newdata, ...) {
  # Override stored twdtw arguments with same-named ones passed via `...`.
  # Names the model did not already store are intentionally ignored.
  new_twdtw_args <- list(...)
  matching_twdtw_args <- intersect(
    names(new_twdtw_args),
    names(object$twdtw_args)
  )
  object$twdtw_args[matching_twdtw_args] <- new_twdtw_args[matching_twdtw_args]

  # Convert newdata to time series
  newdata_ts <- prepare_time_series(newdata)

  # Convert every series once, instead of once per pairing inside the loop.
  pattern_dfs <- lapply(object$data$observations, as.data.frame)
  new_dfs <- lapply(newdata_ts$observations, as.data.frame)
  n_new <- length(new_dfs)
  n_patterns <- length(pattern_dfs)

  # TWDTW distance matrix: one row per new observation, one column per
  # pattern. It is preallocated with explicit dimensions so that a single
  # observation or a single pattern still yields a matrix (nested sapply()
  # collapses to a vector in that case and row-wise apply() then fails).
  distances <- matrix(NA_real_, nrow = n_new, ncol = n_patterns)
  for (i in seq_len(n_patterns)) {
    for (j in seq_len(n_new)) {
      d <- do.call(
        proxy::dist,
        c(
          list(x = new_dfs[[j]], y = pattern_dfs[[i]], method = "twdtw"),
          object$twdtw_args
        )
      )
      distances[j, i] <- as.numeric(d)
    }
  }

  # Index of the nearest pattern for each new observation. which.min()
  # discards NA and returns integer(0) when a row is entirely NA; map that
  # case to NA so the result stays an integer vector.
  nearest_neighbor <- vapply(seq_len(n_new), function(j) {
    k <- which.min(distances[j, ])
    if (length(k) > 0) k else NA_integer_
  }, integer(1))

  # Predicted labels, using every label level known to the model so that
  # separate predictions are directly comparable.
  train_labels <- object$data$label
  factor(
    train_labels[nearest_neighbor],
    levels = levels(factor(train_labels))
  )
}
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.