# R/S4-TSClusters-methods.R
#
# Defines functions n_of_objects.TSClusters n_of_classes.TSClusters is.cl_partition.TSClusters is.cl_hierarchy.TSClusters is.cl_hard_partition.TSClusters is.cl_dendrogram.TSClusters cl_membership.TSClusters cl_class_ids.TSClusters as.cl_membership.TSClusters cvi_TSClusters plot.TSClusters predict.TSClusters update.TSClusters
#
# Documented in plot.TSClusters predict.TSClusters update.TSClusters

# Documentation-only topic: this roxygen block defines the "tsclusters-methods" help page (it is
# attached to NULL); the methods in this file that use `@rdname tsclusters-methods` join it.
#' Methods for `TSClusters`
#'
#' Methods associated with [TSClusters-class] and derived objects.
#'
#' @name tsclusters-methods
#' @rdname tsclusters-methods
#' @aliases TSClusters-methods
#' @include S4-TSClusters-classes.R
#' @importFrom methods setMethod
#'
NULL

# ==================================================================================================
# Custom initialize
# ==================================================================================================

#' @rdname tsclusters-methods
#' @aliases initialize,TSClusters
#' @importFrom methods callNextMethod
#' @importFrom methods initialize
#' @importFrom methods new
#'
#' @param .Object A `TSClusters` prototype. You *shouldn't* use this, see Initialize section and the
#'   examples.
#' @param ... For `initialize`, any valid slots. For `plot`, passed to [ggplot2::geom_line()] for
#'   the plotting of the *cluster centroids*, or to [stats::plot.hclust()]; see Plotting section and
#'   the examples. For `update`, any supported argument. Otherwise ignored.
#' @param override.family Logical. Attempt to substitute the default family with one that conforms
#'   to the provided elements? See Initialize section.
#'
#' @section Initialize:
#'
#'   The initialize method is used when calling [methods::new()]. The `family` slot can be
#'   substituted with an appropriate one if certain elements are provided by the user. The
#'   initialize methods of derived classes also inherit the family and can use it to calculate other
#'   slots. In order to get a fully functional object, at least the following slots should be
#'   provided:
#'
#'   - `type`: "partitional", "hierarchical", "fuzzy" or "tadpole".
#'   - `datalist`: The data in one of the supported formats.
#'   - `centroids`: The time series centroids in one of the supported formats.
#'   - `cluster`: The cluster indices for each series in the `datalist`.
#'   - `control*`: A [tsclust-controls] object with the desired parameters.
#'   - `distance*`: A string indicating the distance that should be used.
#'   - `centroid*`: A string indicating the centroid to use (only necessary for partitional
#'   clustering).
#'
#'   *Necessary when overriding the default family for the calculation of other slots, CVIs or
#'   prediction. Maybe not always needed, e.g. for plotting.
#'
#' @examples
#'
#' data(uciCT)
#'
#' # Assuming this was generated by some clustering procedure
#' centroids <- CharTraj[seq(1L, 100L, 5L)]
#' cluster <- unclass(CharTrajLabels)
#'
#' pc_obj <- new("PartitionalTSClusters",
#'               type = "partitional", datalist = CharTraj,
#'               centroids = centroids, cluster = cluster,
#'               distance = "sbd", centroid = "dba",
#'               control = partitional_control(),
#'               args = tsclust_args(cent = list(window.size = 8L, norm = "L2")))
#'
#' fc_obj <- new("FuzzyTSClusters",
#'               type = "fuzzy", datalist = CharTraj,
#'               centroids = centroids, cluster = cluster,
#'               distance = "sbd", centroid = "fcm",
#'               control = fuzzy_control())
#'
#' show(fc_obj)
#'
setMethod("initialize", "TSClusters", function(.Object, ..., override.family = TRUE) {
    # Fills the slots given in `...` through the next initialize method, then sets some defaults
    # (call, proctime, preproc, k, args). When `override.family = TRUE`, the `family` slot is
    # rebuilt from the `type`, `distance`, `centroid`, `preproc` and `control` slots.
    tic <- proc.time()
    dots <- list(...)
    # some minor checks
    # (data and centroids are normalized with tslist() before being stored)
    if (!is.null(dots$datalist)) dots$datalist <- tslist(dots$datalist)
    if (!is.null(dots$centroids)) dots$centroids <- tslist(dots$centroids)
    # avoid infinite recursion (see https://bugs.r-project.org/bugzilla/show_bug.cgi?id=16629)
    # A `call` passed in `...` is removed from `dots` (so it is not forwarded below) and is
    # stored in the slot after callNextMethod; otherwise this method's own call is recorded.
    if (is.null(dots$call)) {
        call <- match.call()
    }
    else {
        call <- dots$call
        dots$call <- NULL
    }
    # apparently a non-NULL value is needed if proc_time class is virtual?
    # The placeholder `tic` is replaced by the elapsed time at the end when fill_proctime is TRUE.
    if (is.null(dots$proctime)) {
        dots$proctime <- tic
        fill_proctime <- TRUE
    }
    else {
        fill_proctime <- FALSE # nocov
    }

    # no quoted_call here, apparently do.call evaluates as parent and callNextMethod needs that
    # (the third positional argument of do.call is `quote`)
    .Object <- do.call(methods::callNextMethod, enlist(.Object = .Object, dots = dots), TRUE)
    .Object@call <- call

    # some "defaults"
    # NOTE(review): `dots$k` uses `$`, which partially matches list names; this is only safe as
    # long as no other slot name passed in `...` starts with "k".
    if (is.null(dots$preproc)) .Object@preproc <- "none"
    if (is.null(dots$k)) .Object@k <- length(.Object@centroids)

    # Without explicit `args`, the `dots` slot is used for preproc, dist and cent alike;
    # otherwise the given `args` are passed through adjust_args() together with the `dots` slot.
    if (is.null(dots$args))
        .Object@args <- tsclust_args(preproc = .Object@dots,
                                     dist = .Object@dots,
                                     cent = .Object@dots)
    else
        .Object@args <- adjust_args(.Object@args, .Object@dots) # UTILS-utils.R

    # more helpful for hierarchical/tadpole
    if (override.family) {
        if (length(.Object@type) == 0L)
            warning("Could not override family, 'type' slot is missing.")
        else if (length(.Object@distance) == 0L)
            warning("Could not override family, 'distance' slot is missing.")
        else {
            # "none" keeps the family's current preproc function; any other value is treated as
            # the name of a function and looked up with get_from_callers()
            if (.Object@preproc == "none")
                preproc <- .Object@family@preproc
            else
                preproc <- get_from_callers(.Object@preproc, "function")

            if (.Object@type == "partitional") {
                if (length(.Object@centroid) > 0L) {
                    allcent <- all_cent2(.Object@centroid, .Object@control)

                    # shape centroids without explicit preprocessing switch to z-normalization
                    if (.Object@centroid == "shape" && .Object@preproc == "none") {
                        preproc <- zscore
                        .Object@preproc <- "zscore"
                    }
                }
                else {
                    allcent <- .Object@family@allcent
                    warning("Could not override allcent in family, 'centroid' slot is missing.")
                }
            }
            else if (.Object@type == "hierarchical") {
                # if `centroid` names a function, wrap it so that allcent returns a one-element list
                centfun <- try(match.fun(.Object@centroid), silent = TRUE)
                if (!inherits(centfun, "try-error")) {
                    allcent <- function(...) { list(centfun(...)) }
                    # fresh environment whose parent is .GlobalEnv, holding only what the
                    # closure needs (presumably to avoid capturing this method's frame)
                    allcent_env <- new.env(parent = .GlobalEnv)
                    allcent_env$centfun <- centfun
                    environment(allcent) <- allcent_env
                }
                else if (length(.Object@datalist) > 0L && !is.null(.Object@distmat)) {
                    # otherwise pick the series whose row of distmat has the smallest sum
                    allcent <- function(...) {
                        datalist[which.min(apply(distmat, 1L, sum))] # for CVI's global_cent
                    }

                    datalist <- distmat <- NULL # avoid check NOTE
                    allcent_env <- new.env(parent = .GlobalEnv)
                    allcent_env$datalist <- .Object@datalist
                    allcent_env$distmat <- .Object@distmat
                    environment(allcent) <- allcent_env
                }
                else {
                    allcent <- .Object@family@allcent
                    warning("Could not override allcent in family, 'datalist' or 'distmat' slots are missing.")
                }
            }
            else if (.Object@type == "fuzzy") {
                # for fuzzy clustering, allcent is the centroid *name* (a string)
                if (length(.Object@centroid) > 0L && .Object@centroid %in% centroids_fuzzy)
                    allcent <- .Object@centroid
                else {
                    allcent <- .Object@family@allcent
                    warning("Could not override allcent in family, 'centroid' slot is missing.")
                }
            }
            else if (.Object@type == "tadpole") {
                centfun <- try(match.fun(.Object@centroid), silent = TRUE)
                if (!inherits(centfun, "try-error")) {
                    allcent <- function(...) { list(centfun(...)) }
                    allcent_env <- new.env(parent = .GlobalEnv)
                    allcent_env$centfun <- centfun
                    environment(allcent) <- allcent_env
                }
                else if (length(.Object@centroids) > 0L) {
                    # fallback: always return the first stored centroid
                    centroids <- NULL # avoid check NOTE
                    allcent_env <- new.env(parent = .GlobalEnv)
                    allcent_env$centroids <- .Object@centroids
                    allcent <- function(...) { centroids[1L] } # for CVI's global_cent
                    environment(allcent) <- allcent_env
                }
                else {
                    allcent <- .Object@family@allcent
                    warning("Could not override allcent in family, 'centroids' slot is missing.")
                }
            }
            else { # nocov start
                allcent <- .Object@family@allcent
                warning("Could not override allcent in family, 'type' slot is not recognized.")
            } # nocov end

            .Object@family <- methods::new("tsclustFamily",
                                           dist = .Object@distance,
                                           allcent = allcent,
                                           preproc = preproc,
                                           control = .Object@control,
                                           fuzzy = isTRUE(.Object@type == "fuzzy"))
        }
    }
    # just a filler
    # (replace the placeholder with the time elapsed since `tic`)
    if (fill_proctime) .Object@proctime <- proc.time() - tic
    # return
    .Object
})

# for derived classes
#' @importFrom methods callNextMethod
#' @importFrom methods initialize
#'
setMethod("initialize", "PartitionalTSClusters", function(.Object, ...) {
    # The TSClusters method fills the common slots (including the family) first.
    obj <- methods::callNextMethod()

    # Defaults for slots that a clustering run would normally fill.
    if (length(obj@iter) == 0L) obj@iter <- 1L
    if (length(obj@converged) == 0L) obj@converged <- TRUE

    n_series <- length(obj@datalist)
    can_derive_cldist <- nrow(obj@cldist) == 0L &&
        length(formals(obj@family@dist)) > 0L &&
        length(obj@cluster) > 0L &&
        n_series > 0L &&
        length(obj@centroids) > 0L

    if (can_derive_cldist) {
        # No cldist was given: compute series-to-centroid distances and keep, for each series,
        # the entry in the column of its own cluster.
        cross_dist <- quoted_call(
            obj@family@dist,
            obj@datalist,
            obj@centroids,
            dots = obj@args$dist
        )
        own_cluster <- cross_dist[cbind(seq_len(n_series), obj@cluster)]
        obj@cldist <- base::as.matrix(own_cluster)
        dimnames(obj@cldist) <- NULL
    }

    # No clusinfo was given, but it can be derived from cluster and cldist (see UTILS-utils.R).
    if (nrow(obj@clusinfo) == 0L && length(obj@cluster) > 0L && nrow(obj@cldist) > 0L)
        obj@clusinfo <- compute_clusinfo(obj@k, obj@cluster, obj@cldist)

    obj
})

#' @importFrom methods callNextMethod
#' @importFrom methods initialize
#'
setMethod("initialize", "HierarchicalTSClusters", function(.Object, ...) {
    # The TSClusters method fills the common slots (including the family) first.
    .Object <- methods::callNextMethod()
    has_dist <- length(formals(.Object@family@dist)) > 0L
    # replace distmat with NULL so that, if the distance function is called again, it won't subset it
    if (has_dist)
        environment(.Object@family@dist)$control$distmat <- NULL
    n_series <- length(.Object@datalist)
    # No cldist available, but dist, cluster, data and centroids can be used to calculate it.
    # As in the partitional method, data and centroids must be present: otherwise the distance
    # function would be called on empty input and 1:0 would be used as a row index.
    if (nrow(.Object@cldist) == 0L && has_dist && length(.Object@cluster) > 0L &&
        n_series > 0L && length(.Object@centroids) > 0L) {
        dm <- quoted_call(
            .Object@family@dist,
            .Object@datalist,
            .Object@centroids,
            dots = .Object@args$dist
        )
        # keep, for each series, the distance to the centroid of its own cluster
        .Object@cldist <- base::as.matrix(dm[cbind(seq_len(n_series), .Object@cluster)])
        dimnames(.Object@cldist) <- NULL
    }
    if (nrow(.Object@clusinfo) == 0L && length(.Object@cluster) > 0L &&
        nrow(.Object@cldist) > 0L) {
        # no clusinfo available, but cluster and cldist can be used to calculate it (see UTILS-utils.R)
        .Object@clusinfo <- compute_clusinfo(.Object@k, .Object@cluster, .Object@cldist)
    }
    # return
    .Object
})

#' @importFrom methods callNextMethod
#' @importFrom methods initialize
#'
setMethod("initialize", "FuzzyTSClusters", function(.Object, ...) {
    # The TSClusters method fills the common slots (including the family) first.
    obj <- methods::callNextMethod()

    # Defaults for slots that a clustering run would normally fill.
    if (length(obj@iter) == 0L) obj@iter <- 1L
    if (length(obj@converged) == 0L) obj@converged <- TRUE

    # Memberships were supplied: nothing else to derive.
    if (nrow(obj@fcluster) > 0L) return(obj)

    # No distance function to derive memberships from.
    if (length(formals(obj@family@dist)) == 0L) {
        obj@fcluster <- matrix(NA_real_) # nocov
        return(obj)
    }

    # Derive fuzzy memberships from the series-to-centroid distances, then the hard
    # assignment from the first column with the highest membership.
    cross_dist <- quoted_call(
        obj@family@dist,
        obj@datalist,
        obj@centroids,
        dots = obj@args$dist
    )
    obj@fcluster <- obj@family@cluster(cross_dist, m = obj@control$fuzziness)
    colnames(obj@fcluster) <- paste0("cluster_", 1L:obj@k)
    obj@cluster <- max.col(obj@fcluster, "first")
    obj
})

# ==================================================================================================
# Show
# ==================================================================================================

#' @rdname tsclusters-methods
#' @aliases show,TSClusters
#' @exportMethod show
#' @importFrom utils head
#'
#' @param object,x An object that inherits from [TSClusters-class] as returned by [tsclust()].
#'
setMethod("show", "TSClusters", function(object) {
    # Summary header: kind of clustering and the ingredients that were used.
    cat(object@type, "clustering with", object@k, "clusters\n")
    cat("Using", object@distance, "distance\n")
    cat("Using", object@centroid, "centroids\n")

    is_hierarchical <- inherits(object, "HierarchicalTSClusters")
    if (is_hierarchical) cat("Using method", object@method, "\n")

    used_preproc <- object@preproc != "none"
    if (used_preproc) cat("Using", object@preproc, "preprocessing\n")

    cat("\nTime required for analysis:\n")
    print(object@proctime)

    # Fuzzy objects show their memberships; others show clusinfo when that slot exists.
    has_clusinfo <- !is.null(attr(object, "clusinfo"))
    if (inherits(object, "FuzzyTSClusters")) {
        cat("\nHead of fuzzy memberships:\n\n")
        print(utils::head(object@fcluster))
    } else if (has_clusinfo) {
        cat("\nCluster sizes with average intra-cluster distance:\n\n")
        print(object@clusinfo)
    }

    invisible(NULL)
})

# ==================================================================================================
# update from stats
# ==================================================================================================

#' @rdname tsclusters-methods
#' @method update TSClusters
#' @export
#' @importFrom stats update
#'
#' @param evaluate Logical. Defaults to `TRUE` and evaluates the updated call, which will result in
#'   a new `TSClusters` object. Otherwise, it returns the unevaluated call.
#'
#' @details
#'
#' The `update` method takes the original function call, replaces any provided argument and
#' optionally evaluates the call again. Use `evaluate = FALSE` if you want to get the unevaluated
#' call. If no arguments are provided, the object is updated to a new version if necessary (this is
#' due to changes in the internal functions of the package, here for backward compatibility).
#'
update.TSClusters <- function(object, ..., evaluate = TRUE) {
    # Replaces arguments in the call that created `object` and optionally re-evaluates it.
    # Without arguments, the object's closures (family@dist/allcent) are rebuilt instead.
    args <- as.pairlist(list(...))
    if (length(args) == 0L) {
        if (evaluate) {
            if (object@type != "tadpole") {
                # update dist closure
                object@family@dist <- ddist2(object@distance, object@control)
                # isTRUE(): the centroid slot may be empty (e.g. objects created with new()),
                # and `if (logical(0))` is an error
                # update allcent closure
                if (isTRUE(object@centroid %in% centroids_included))
                    object@family@allcent <- all_cent2(object@centroid, object@control)
                # update distmat in allcent environment with internal class?
                if (isTRUE(object@centroid %in% c("pam", "fcmdd")))
                    environment(object@family@allcent)$control$distmat <- object@control$distmat
            }
            return(object)
        }
        else
            return(object@call) # nocov
    }
    new_call <- object@call
    new_call[names(args)] <- args
    # n = 2L: evaluate in the frame that called the generic (see the original design)
    if (evaluate)
        eval.parent(new_call, n = 2L)
    else
        new_call # nocov
}

#' @rdname tsclusters-methods
#' @aliases update,TSClusters
#' @exportMethod update
#' @importFrom methods signature
#'
setMethod(f = "update",
          signature = methods::signature(object = "TSClusters"),
          definition = update.TSClusters)

# ==================================================================================================
# predict from stats
# ==================================================================================================

#' @rdname tsclusters-methods
#' @method predict TSClusters
#' @export
#' @importFrom stats predict
#'
#' @param newdata New data to be assigned to a cluster. It can take any of the supported formats of
#'   [tsclust()]. Note that for multivariate series, this means that it **must** be a list of
#'   matrices, even if the list has only one matrix.
#'
#' @section Prediction:
#'
#'   The `predict` generic can take the usual `newdata` argument. If `NULL`, the method simply
#'   returns the obtained cluster indices (or the fuzzy membership matrix for fuzzy clustering).
#'   Otherwise, a nearest-neighbor classification based on the centroids obtained from clustering
#'   is performed:
#'
#'   1. `newdata` is preprocessed with `object@family@preproc` using the parameters in
#'   `object@args$preproc`.
#'   2. A cross-distance matrix between the processed series and `object@centroids` is computed with
#'   `object@family@dist` using the parameters in `object@args$dist`.
#'   3. For non-fuzzy clustering, the series are assigned to their nearest centroid's cluster. For
#'   fuzzy clustering, the fuzzy membership matrix for the series is calculated. In both cases,
#'   the function in `object@family@cluster` is used.
#'
predict.TSClusters <- function(object, newdata = NULL, ...) {
    # Assigns new series to clusters (or computes fuzzy memberships) using the stored centroids.
    is_fuzzy <- inherits(object, "FuzzyTSClusters")

    # Without new data, hand back the memberships found during clustering.
    if (is.null(newdata))
        return(if (is_fuzzy) object@fcluster else object@cluster)

    newdata <- tslist(newdata)
    check_consistency(newdata, "vltslist")
    series_names <- names(newdata)

    fam <- object@family
    processed <- quoted_call(
        fam@preproc,
        newdata,
        dots = subset_dots(object@args$preproc, fam@preproc)
    )
    cross_dist <- quoted_call(
        fam@dist,
        x = processed,
        centroids = object@centroids,
        dots = object@args$dist
    )
    assignment <- fam@cluster(distmat = cross_dist, m = object@control$fuzziness)

    if (is_fuzzy) {
        dimnames(assignment) <- list(series_names, paste0("cluster_", 1L:ncol(assignment)))
    } else {
        names(assignment) <- series_names
    }
    assignment
}

#' @rdname tsclusters-methods
#' @aliases predict,TSClusters
#' @exportMethod predict
#' @importFrom methods signature
#'
setMethod(f = "predict",
          signature = methods::signature(object = "TSClusters"),
          definition = predict.TSClusters)

# ==================================================================================================
# Plot
# ==================================================================================================

#' @rdname tsclusters-methods
#' @method plot TSClusters
#' @export
#' @importFrom dplyr anti_join
#' @importFrom dplyr bind_rows
#' @importFrom dplyr sample_n
#' @importFrom methods S3Part
#' @importFrom ggplot2 aes_string
#' @importFrom ggplot2 facet_wrap
#' @importFrom ggplot2 geom_line
#' @importFrom ggplot2 geom_vline
#' @importFrom ggplot2 ggplot
#' @importFrom ggplot2 guides
#' @importFrom ggplot2 labs
#' @importFrom ggplot2 theme_bw
#' @importFrom ggrepel geom_label_repel
#' @importFrom graphics plot
#' @importFrom reshape2 melt
#'
#' @param y Ignored.
#' @param clus A numeric vector indicating which clusters to plot.
#' @param labs.arg A list with arguments to change the title and/or axis labels. See the examples
#'   and [ggplot2::labs()] for more information.
#' @param series Optionally, the data in the same format as it was provided to [tsclust()].
#' @param time Optional values for the time axis. If series have different lengths, provide the time
#'   values of the longest series.
#' @param plot Logical flag. You can set this to `FALSE` in case you want to save the ggplot object
#'   without printing anything to the screen.
#' @param type What to plot. `NULL` means default. See details.
#' @param labels Whether to include labels in the plot (not for dendrogram plots). See details and
#'   note that this is subject to **randomness**.
#'
#' @section Plotting:
#'
#'   The plot method uses the `ggplot2` plotting system (see [ggplot2::ggplot()]).
#'
#'   The default depends on whether a hierarchical method was used or not. In those cases, the
#'   dendrogram is plotted by default; you can pass any extra parameters to [stats::plot.hclust()]
#'   via the ellipsis (`...`).
#'
#'   Otherwise, the function plots the time series of each cluster along with the obtained centroid.
#'   The default values for cluster centroids are: `linetype = "dashed"`, `size = 1.5`, `colour =
#'   "black"`, `alpha = 0.5`. You can change this by means of the ellipsis (`...`).
#'
#'   You can choose what to plot with the `type` parameter. Possible options are:
#'
#'   - `"dendrogram"`: Only available for hierarchical clustering.
#'   - `"series"`: Plot the time series divided into clusters without including centroids.
#'   - `"centroids"`: Plot the obtained centroids only.
#'   - `"sc"`: Plot both series and centroids.
#'
#'   In order to enable labels on the (non-dendrogram) plot, you have to select an option that plots
#'   the series and at least provide an empty list in the `labels` argument. This list can contain
#'   arguments for [ggrepel::geom_label_repel()] and will be passed along. The following are
#'   set by the plot method if they are not provided:
#'
#'   - `"mapping"`: set to [aes_string][ggplot2::aes_string](x = "t", y = "value", label = "label")
#'   - `"data"`: a data frame with as many rows as series in the `datalist` and 4 columns:
#'     + `t`: x coordinate of the label for each series.
#'     + `value`: y coordinate of the label for each series.
#'     + `cl`: index of the cluster to which the series belongs (i.e. `x@cluster`).
#'     + `label`: the label for the given series (i.e. `names(x@datalist)`).
#'
#'   You can provide your own data frame if you want, but it must have those columns and, even if
#'   you override `mapping`, the `cl` column must have that name. The method will attempt to spread
#'   the labels across the plot, but note that this is **subject to randomness**, so be careful if
#'   you need reproducibility of any commands used after plotting (see examples).
#'
#'   If created, the function returns the `gg` object invisibly, in case you want to modify it to
#'   your liking. You might want to look at [ggplot2::ggplot_build()] if that's the case.
#'
#'   If you want to free the scale of the X axis, you can do the following:
#'
#'   `plot(x, plot = FALSE)` `+` `facet_wrap(~cl, scales = "free")`
#'
#'   For more complicated changes, you're better off looking at the source code at
#'   \url{https://github.com/asardaes/dtwclust/blob/master/R/S4-TSClusters-methods.R} and creating your
#'   own plotting function.
#'
#' @return
#'
#' The plot method returns a `gg` object (or `NULL` for dendrogram plot) invisibly.
#'
#' @examples
#'
#' \dontrun{
#' plot(pc_obj, type = "c", linetype = "solid",
#'      labs.arg = list(title = "Clusters' centroids"))
#'
#' set.seed(15L)
#' plot(pc_obj, labels = list(nudge_x = -5, nudge_y = 0.2),
#'      clus = c(1L,4L))
#' }
#'
plot.TSClusters <- function(x, y, ...,
                            clus = seq_len(x@k), labs.arg = NULL,
                            series = NULL, time = NULL,
                            plot = TRUE, type = NULL,
                            labels = NULL)
{
    # set default type if none was provided
    if (!is.null(type)) # nocov start
        type <- match.arg(type, c("dendrogram", "series", "centroids", "sc"))
    else if (x@type == "hierarchical")
        type <- "dendrogram"
    else
        type <- "sc" # nocov end

    # plot dendrogram?
    if (inherits(x, "HierarchicalTSClusters") && type == "dendrogram") {
        # extract the S3 "hclust" part so graphics::plot() dispatches to stats::plot.hclust()
        x <- methods::S3Part(x, strictS3 = TRUE, "hclust")
        if (plot) graphics::plot(x, ...)
        return(invisible(NULL))
    }
    else if (x@type != "hierarchical" && type == "dendrogram") {
        stop("Dendrogram plot only applies to hierarchical clustering.")
    }

    # Obtain data, the priority is: provided data > included data list
    if (!is.null(series)) {
        data <- tslist(series)
    }
    else {
        if (length(x@datalist) < 1L)
            stop("Provided object has no data. Please provide the data manually.") # nocov
        data <- x@datalist
    }

    # centroids consistency
    check_consistency(centroids <- x@centroids, "vltslist")

    # force same length for all multivariate series/centroids in the same cluster by
    # adding NAs
    if (mv <- is_multivariate(data)) {
        data <- lapply(data, base::as.matrix)
        centroids <- lapply(centroids, base::as.matrix)
        clusters <- split(data, factor(x@cluster, levels = 1L:x@k), drop = FALSE)
        for (id_clus in 1L:x@k) {
            cluster <- clusters[[id_clus]]
            if (length(cluster) < 1L) next # nocov (empty cluster)
            nc <- NCOL(cluster[[1L]])
            len <- sapply(cluster, NROW)
            L <- max(len, NROW(centroids[[id_clus]]))
            trail <- L - len
            clusters[[id_clus]] <- Map(cluster, trail, f = function(mvs, trail) {
                rbind(mvs, matrix(NA, trail, nc))
            })
            trail <- L - NROW(centroids[[id_clus]])
            centroids[[id_clus]] <- rbind(centroids[[id_clus]], matrix(NA, trail, nc))
        }
        # split returns the result in order of the factor levels,
        # but I want to keep the original order as returned from clustering
        # (ido is the inverse of the stable cluster-sorting permutation)
        ido <- sort(sort(x@cluster, index.return = TRUE)$ix, index.return = TRUE)$ix
        data <- unlist(clusters, recursive = FALSE)[ido]
    }

    # helper values (lengths() here, see issue #18 in GitHub)
    L1 <- lengths(data)
    L2 <- lengths(centroids)
    # timestamp consistency
    if (!is.null(time) && length(time) < max(L1, L2))
        stop("Length mismatch between values and timestamps") # nocov
    # Check if data was z-normalized
    if (x@preproc == "zscore")
        title_str <- "Clusters' members (z-normalized)"
    else
        title_str <- "Clusters' members"

    # transform to data frames
    dfm <- reshape2::melt(data)
    dfcm <- reshape2::melt(centroids)

    # time, cluster and colour indices
    # color_ids counts the series seen so far in each cluster, so every series of a cluster gets
    # its own colour index (it is updated with <<- from inside the mapply closure)
    color_ids <- integer(x@k)
    dfm_tcc <- mapply(x@cluster, L1, USE.NAMES = FALSE, SIMPLIFY = FALSE,
                      FUN = function(clus, len) {
                          t <- if (is.null(time)) seq_len(len) else time[1L:len]
                          cl <- rep(clus, len)
                          color <- rep(color_ids[clus], len)
                          color_ids[clus] <<- color_ids[clus] + 1L
                          data.frame(t = t, cl = cl, color = color)
                      })
    dfcm_tc <- mapply(1L:x@k, L2, USE.NAMES = FALSE, SIMPLIFY = FALSE,
                      FUN = function(clus, len) {
                          t <- if (is.null(time)) seq_len(len) else time[1L:len]
                          cl <- rep(clus, len)
                          data.frame(t = t, cl = cl)
                      })

    # bind
    dfm <- data.frame(dfm, do.call(rbind, dfm_tcc, TRUE))
    dfcm <- data.frame(dfcm, do.call(rbind, dfcm_tc, TRUE))
    # make factor
    dfm$cl <- factor(dfm$cl)
    dfcm$cl <- factor(dfcm$cl)
    dfm$color <- factor(dfm$color)

    # create gg object
    gg <- ggplot2::ggplot(data.frame(t = integer(),
                                     variable = factor(),
                                     value = numeric(),
                                     cl = factor(),
                                     color = factor()),
                          ggplot2::aes_string(x = "t",
                                              y = "value",
                                              group = "L1"))

    # add centroids first if appropriate, so that they are at the very back
    # (any argument given in `...` replaces all of the default centroid aesthetics)
    if (type %in% c("sc", "centroids")) {
        if (length(list(...)) == 0L)
            gg <- gg + ggplot2::geom_line(data = dfcm[dfcm$cl %in% clus, ],
                                          linetype = "dashed",
                                          size = 1.5,
                                          colour = "black",
                                          alpha = 0.5)
        else
            gg <- gg + ggplot2::geom_line(data = dfcm[dfcm$cl %in% clus, ], ...)
    }

    # add series next if appropriate
    if (type %in% c("sc", "series"))
        gg <- gg + ggplot2::geom_line(data = dfm[dfm$cl %in% clus, ],
                                      mapping = ggplot2::aes_string(colour = "color"))

    # add vertical lines to separate variables of multivariate series
    # (one break per boundary between consecutive variables, at multiples of the padded
    # centroid length of each cluster)
    if (mv) {
        ggdata <- data.frame(cl = rep(1L:x@k, each = (nc - 1L)),
                             vbreaks = as.numeric(1L:(nc - 1L) %o% sapply(centroids, NROW)))
        gg <- gg + ggplot2::geom_vline(data = ggdata[ggdata$cl %in% clus, , drop = FALSE],
                                       colour = "black", linetype = "longdash",
                                       ggplot2::aes_string(xintercept = "vbreaks"))
    }

    # add labels
    if (type %in% c("sc", "series") && is.list(labels)) {
        if (is.null(labels$mapping))
            labels$mapping <- ggplot2::aes_string(x = "t", y = "value", label = "label")
        if (is.null(labels$data) && !is.null(names(x@datalist))) {
            label <- names(x@datalist)[x@cluster %in% clus]
            label <- split(label, x@cluster[x@cluster %in% clus])
            dfm <- dfm[dfm$cl %in% clus,]
            labels$data <- dplyr::bind_rows(lapply(split(dfm, dfm$cl), function(df_cluster) {
                # keep given order for split
                df_cluster$L1 <- factor(df_cluster$L1, levels = unique(df_cluster$L1))
                df_series <- split(df_cluster[c("t", "value", "cl")], df_cluster$L1)
                ret <- vector("list", length(df_series))

                # for each series, drop the points it shares with other series of the cluster
                # (identical series are skipped, otherwise nothing would remain) and then pick
                # one remaining point at random as the label position
                for (i in seq_along(df_series)) {
                    this_df <- df_series[[i]]

                    for (not_this in df_series[-i]) {
                        if (nrow(this_df) == nrow(not_this) && isTRUE(all.equal(not_this, this_df, check.attributes = FALSE))) {
                            next
                        }
                        this_df <- dplyr::anti_join(this_df, not_this, by = c("t", "value"))
                    }
                    ret[[i]] <- dplyr::sample_n(this_df, 1L)
                }

                # return
                dplyr::bind_rows(ret)
            }))
            labels$data$label <- unlist(label)
        }
        labels$data <- labels$data[labels$data$cl %in% clus,]
        labels$inherit.aes <- FALSE
        gg <- gg + do.call(ggrepel::geom_label_repel, labels, TRUE)
    }

    # add facets, remove legend, apply kinda black-white theme
    gg <- gg +
        ggplot2::facet_wrap(~cl, scales = "free_y") +
        ggplot2::guides(colour = "none") +
        ggplot2::theme_bw()

    # labs
    if (!is.null(labs.arg)) { # nocov start
        # splice the list into labs(): recent ggplot2 versions do not unpack a single list
        # argument, so labs(labs.arg) would silently ignore the requested labels
        gg <- gg + do.call(ggplot2::labs, as.list(labs.arg))
    }
    else {
        gg <- gg + ggplot2::labs(title = title_str)
    }

    # plot without warnings in case I added NAs for multivariate cases
    if (plot) suppressWarnings(graphics::plot(gg))
    invisible(gg) # nocov end
}

#' @rdname tsclusters-methods
#' @aliases plot,TSClusters,missing
#' @exportMethod plot
#' @importFrom methods signature
#'
setMethod(f = "plot",
          signature = methods::signature(x = "TSClusters", y = "missing"),
          definition = plot.TSClusters)

# ==================================================================================================
# Cluster validity indices
# ==================================================================================================

#' @importFrom cluster silhouette
#'
cvi_TSClusters <- function(a, b = NULL, type = "valid", ...) {
    # Cluster validity indices (CVIs) for crisp TSClusters objects (partitional/hierarchical).
    #
    # a:    the clustering result.
    # b:    optional ground-truth cluster IDs, only used by the external indices.
    # type: index names, or one of "valid", "internal", "external".
    # ...:  passed to the cvi method that computes the external indices from a@cluster.
    #
    # External indices (RI, ARI, J, FM, VI) are delegated to cvi(a@cluster, b, ...). Internal
    # indices are computed here; those whose inputs are missing (original series and/or distance
    # matrix) are dropped with a warning. Returns a named numeric vector.
    type <- match.arg(type, several.ok = TRUE,
                      c("RI", "ARI", "J", "FM", "VI",
                        "Sil", "SF", "CH", "DB", "DBstar", "D", "COP",
                        "valid", "internal", "external"))
    internal <- c("Sil", "SF", "CH", "DB", "DBstar", "D", "COP")
    external <- c("RI", "ARI", "J", "FM", "VI")
    if (any(type == "valid")) {
        type <- if (is.null(b)) internal else c(internal, external)
    }
    else if (any(type == "internal")) {
        type <- internal
    }
    else if (any(type == "external")) {
        type <- external
    }

    which_internal <- type %in% internal
    which_external <- type %in% external
    if (any(which_external))
        CVIs <- cvi(a@cluster, b = b, type = type[which_external], ...)
    else
        CVIs <- numeric()

    type <- type[which_internal]
    if (any(which_internal)) {
        if (length(a@datalist) == 0L && any(type %in% c("SF", "CH"))) {
            warning("Internal CVIs: the original data must be in object to calculate ",
                    "the following indices:",
                    "\n\tSF\tCH")
            type <- setdiff(type, c("SF", "CH"))
        }

        # series' cross-distance matrix for Sil, D and COP: use the stored one if present,
        # otherwise compute it from the original series when they are available
        distmat <- NULL
        if (any(type %in% c("Sil", "D", "COP"))) {
            if (is.null(a@distmat)) {
                if (length(a@datalist) == 0L) {
                    warning("Internal CVIs: distmat OR original data needed for indices:",
                            "\n\tSil\tD\tCOP")
                    type <- setdiff(type, c("Sil", "D", "COP"))
                }
                else {
                    distmat <- quoted_call(
                        a@family@dist,
                        x = a@datalist,
                        centroids = NULL,
                        dots = a@args$dist
                    )
                }
            }
            else {
                distmat <- a@distmat
            }
        }
        if (!is.null(distmat)) {
            distmat <- base::as.matrix(distmat)
            if (!base::isSymmetric(distmat))
                warning("Internal CVIs: series' cross-distance matrix is NOT symmetric, ",
                        "which can be problematic for:",
                        "\n\tSil\tD\tCOP")
        }
        # are no valid indices left?
        if (length(type) == 0L) return(CVIs) # nocov

        # values shared by both Davies-Bouldin variants
        if (any(type %in% c("DB", "DBstar"))) {
            S <- a@clusinfo$av_dist
            # distance between centroids
            distcent <- quoted_call(
                a@family@dist,
                x = a@centroids,
                centroids = NULL,
                dots = a@args$dist
            )
            distcent <- base::as.matrix(distcent)
            if (!base::isSymmetric(distcent))
                warning("Internal CVIs: centroids' cross-distance matrix is NOT symmetric, ",
                        "which can be problematic for:",
                        "\n\tDB\tDB*")
        }

        # global centroid (one centroid for the whole dataset) for SF and CH
        if (any(type %in% c("SF", "CH"))) {
            N <- length(a@datalist)
            if (a@type == "partitional") {
                global_cent <- quoted_call(
                    a@family@allcent,
                    x = a@datalist,
                    cl_id = rep(1L, N),
                    k = 1L,
                    cent = a@datalist[sample(N, 1L)],
                    cl_old = rep(0L, N),
                    dots = a@args$cent
                )
            }
            else {
                global_cent <- quoted_call(a@family@allcent,
                                           a@datalist,
                                           dots = subset_dots(a@args$cent, a@family@allcent))
            }
            dist_global_cent <- quoted_call(a@family@dist,
                                            x = a@centroids,
                                            centroids = global_cent,
                                            dots = a@args$dist)
            dim(dist_global_cent) <- NULL
        }

        CVIs <- c(CVIs, vapply(type, function(CVI) {
            switch(EXPR = CVI,
                   # Silhouette
                   Sil = {
                       mean(cluster::silhouette(a@cluster, dmatrix = distmat)[,3L])
                   },
                   # Dunn: smallest inter-cluster point distance over largest intra-cluster one
                   D = {
                       pairs <- call_pairs(a@k)
                       deltas <- mapply(pairs[ , 1L], pairs[ , 2L],
                                        USE.NAMES = FALSE, SIMPLIFY = TRUE,
                                        FUN = function(i, j) {
                                            min(distmat[a@cluster == i,
                                                        a@cluster == j,
                                                        drop = TRUE])
                                        })
                       Deltas <- vapply(seq_len(a@k), function(k) {
                           max(distmat[a@cluster == k, a@cluster == k, drop = TRUE])
                       }, numeric(1L))
                       min(deltas) / max(Deltas)
                   },
                   # Davies-Bouldin
                   DB = {
                       mean(vapply(seq_len(a@k), function(k) {
                           max((S[k] + S[-k]) / distcent[k, -k])
                       }, numeric(1L)))
                   },
                   # Modified DB -> DB*
                   DBstar = {
                       mean(vapply(seq_len(a@k), function(k) {
                           max(S[k] + S[-k]) / min(distcent[k, -k, drop = TRUE])
                       }, numeric(1L)))
                   },
                   # Calinski-Harabasz
                   CH = {
                       (N - a@k) /
                           (a@k - 1) *
                           sum(a@clusinfo$size * dist_global_cent) /
                           sum(a@cldist[, 1L])
                   },
                   # Score function
                   SF = {
                       bcd <- sum(a@clusinfo$size * dist_global_cent) / (N * a@k)
                       wcd <- sum(a@clusinfo$av_dist)
                       1 - 1 / exp(exp(bcd - wcd))
                   },
                   # COP: per cluster, summed distances to its centroid divided by
                   # min_{x_i outside} max_{x_j inside} d(x_i, x_j)
                   COP = {
                       1 / nrow(distmat) * sum(vapply(seq_len(a@k), function(k) {
                           in_k <- a@cluster == k
                           # rows are series outside cluster k, columns series inside it, so the
                           # max is taken per row (per outside series) before the min
                           d_out_in <- distmat[!in_k, in_k, drop = FALSE]
                           sum(a@cldist[in_k, 1L]) / min(apply(d_out_in, 1L, max))
                       }, numeric(1L)))
                   })
        }, numeric(1L)))
    }
    # return
    CVIs
}

#' @rdname cvi
#' @aliases cvi,PartitionalTSClusters
#' @exportMethod cvi
#' @importFrom methods signature
#' @include GENERICS-cvi.R
#'
setMethod(
    f = "cvi",
    # partitional results use the shared crisp implementation
    signature = methods::signature(a = "PartitionalTSClusters"),
    definition = cvi_TSClusters
)

#' @rdname cvi
#' @aliases cvi,HierarchicalTSClusters
#' @exportMethod cvi
#' @importFrom methods signature
#'
setMethod(
    f = "cvi",
    # hierarchical results use the shared crisp implementation
    signature = methods::signature(a = "HierarchicalTSClusters"),
    definition = cvi_TSClusters
)

#' @rdname cvi
#' @aliases cvi,FuzzyTSClusters
#' @exportMethod cvi
#' @importFrom methods signature
#'
setMethod(
    "cvi", methods::signature(a = "FuzzyTSClusters", b = "ANY"),
    function(a, b = NULL, type = "valid", ...) {
        # Cluster validity indices (CVIs) for fuzzy clusterings.
        #
        # Internal indices (MPC, K, T, SC, PBMF) use the membership matrix a@fcluster; all but MPC
        # also need the original series in a@datalist and are dropped with a warning otherwise.
        # External indices (RI, ARI, VI, NMIM) are delegated to cvi(a@fcluster, b, ...).
        #
        # Returns a named numeric vector; external indices already computed are always kept, even
        # when every requested internal index had to be dropped.
        type <- match.arg(type, several.ok = TRUE,
                          c("MPC", "K", "T", "SC", "PBMF",
                            "RI", "ARI", "VI", "NMIM",
                            "valid", "internal", "external"))
        internal <- c("MPC", "K", "T", "SC", "PBMF")
        external <- c("RI", "ARI", "VI", "NMIM")
        if (any(type == "valid")) {
            type <- if (is.null(b)) internal else c(internal, external)
        } else if (any(type == "internal")) {
            type <- internal
        } else if (any(type == "external")) {
            type <- external
        }

        which_internal <- type %in% internal
        which_external <- type %in% external
        if (any(which_external))
            CVIs <- cvi(a@fcluster, b = b, type = type[which_external], ...)
        else
            CVIs <- numeric()

        type <- type[which_internal]
        if (any(which_internal)) {
            needs_series <- c("K", "T", "SC", "PBMF")
            if (length(a@datalist) == 0L && any(type %in% needs_series)) {
                warning("Fuzzy CVIs: the original series must be in the object to calculate ",
                        "the following indices:\n",
                        "\tK\tT\tSC\tPBMF")
                type <- setdiff(type, needs_series)
            }
            # are no valid indices left? keep the external ones computed above
            if (length(type) == 0L) return(CVIs) # nocov
            # global centroid (one centroid for the whole dataset) if needed
            if (any(type %in% c("K", "SC", "PBMF"))) {
                N <- length(a@datalist)
                global_cent <- quoted_call(
                    a@family@allcent,
                    x = a@datalist,
                    cl_id = cbind(rep(1L, N)),
                    k = 1L,
                    dots = a@args$cent
                )
                dist_global_cent <- quoted_call(
                    a@family@dist,
                    x = a@centroids,
                    centroids = global_cent,
                    dots = a@args$dist
                )
                dim(dist_global_cent) <- NULL
            }
            # distance between centroids
            if (any(type %in% c("K", "T", "PBMF"))) {
                distcent <- quoted_call(a@family@dist,
                                        x = a@centroids,
                                        centroids = NULL,
                                        dots = a@args$dist)
            }
            # distance between series and centroids
            if (any(type %in% c("K", "T", "SC", "PBMF"))) {
                dsc <- quoted_call(a@family@dist,
                                   x = a@datalist,
                                   centroids = a@centroids,
                                   dots = a@args$dist)
            }
            CVIs <- c(CVIs, vapply(type, function(CVI) {
                switch(EXPR = CVI,
                       # ---------------------------------------------------------------------------
                       "MPC" = {
                           PC <- sum(a@fcluster ^ 2) / nrow(a@fcluster)
                           1 - (a@k / (a@k - 1L)) * (1 - PC)
                       },
                       # ---------------------------------------------------------------------------
                       "K" = {
                           numerator <- sum((a@fcluster ^ 2) * (dsc ^ 2)) +
                               sum(dist_global_cent ^ 2) / a@k
                           # !diag(k) selects the off-diagonal centroid distances
                           denominator <- min(distcent[!diag(a@k)] ^ 2)
                           numerator / denominator
                       },
                       # ---------------------------------------------------------------------------
                       "T" = {
                           numerator <- sum((a@fcluster ^ 2) * (dsc ^ 2)) +
                               sum(distcent[!diag(a@k)] ^ 2) / (a@k * (a@k - 1L))
                           denominator <- min(distcent[!diag(a@k)] ^ 2) + 1 / a@k
                           numerator / denominator
                       },
                       # ---------------------------------------------------------------------------
                       "SC" = {
                           u <- a@fcluster
                           m <- a@control$fuzziness
                           SC1_numerator <- sum((dist_global_cent ^ 2) / a@k)
                           SC1_denominator <- sum(colSums((u ^ m) * (dsc ^ 2)) / colSums(u))
                           SC1 <- SC1_numerator / SC1_denominator
                           # sum over all cluster pairs (i, j) with i < j; pmin keeps working
                           # when there is a single series (no matrix-to-vector drop)
                           SC2_numerator <- sum(vapply(seq_len(a@k - 1L), function(i) {
                               sum(vapply(seq_len(a@k - i), function(r) {
                                   j <- r + i
                                   temp <- pmin(u[, i], u[, j])
                                   sum(temp ^ 2) / sum(temp)
                               }, numeric(1L)))
                           }, numeric(1L)))
                           SC2_denominator <- apply(u, 1L, max)
                           SC2_denominator <- sum(SC2_denominator ^ 2) / sum(SC2_denominator)
                           SC2 <- SC2_numerator / SC2_denominator
                           SC1 - SC2
                       },
                       # ---------------------------------------------------------------------------
                       "PBMF" = {
                           u <- a@fcluster
                           m <- a@control$fuzziness
                           dsgc <- quoted_call(a@family@dist,
                                               x = a@datalist,
                                               centroids = global_cent,
                                               dots = a@args$dist)
                           factor1 <- 1 / a@k
                           factor2 <- sum(dsgc) / sum(dsc * (u ^ m))
                           factor3 <- max(distcent[!diag(a@k)])
                           (factor1 * factor2 * factor3) ^ 2
                       })
            }, numeric(1L)))
        }
        # return
        CVIs
    }
)

# ==================================================================================================
# Functions to support package 'clue'
# ==================================================================================================

#' @method as.cl_membership TSClusters
#' @export
#' @importFrom clue as.cl_membership
#'
as.cl_membership.TSClusters <- function(x) {
    # membership matrix built by clue from the crisp cluster IDs stored in the object
    cluster_ids <- x@cluster
    clue::as.cl_membership(cluster_ids)
}

#' @method cl_class_ids TSClusters
#' @export
#' @importFrom clue as.cl_class_ids
#' @importFrom clue cl_class_ids
#'
cl_class_ids.TSClusters <- function(x) {
    # wrap the crisp cluster IDs stored in the object as clue class IDs
    cluster_ids <- x@cluster
    clue::as.cl_class_ids(cluster_ids)
}

#' @method cl_membership TSClusters
#' @export
#' @importFrom clue as.cl_membership
#' @importFrom clue cl_membership
#'
cl_membership.TSClusters <- function(x, k = n_of_classes(x)) {
    # k is part of clue's generic signature but is not used here; the object itself is handed
    # to as.cl_membership, which dispatches on its class
    membership <- clue::as.cl_membership(x)
    membership
}

#' @method is.cl_dendrogram TSClusters
#' @export
#' @importFrom clue is.cl_dendrogram
#'
is.cl_dendrogram.TSClusters <- function(x) {
    # TRUE exactly when the object's type slot is "hierarchical"
    clustering_type <- x@type
    clustering_type == "hierarchical"
}

#' @method is.cl_hard_partition TSClusters
#' @export
#' @importFrom clue is.cl_hard_partition
#'
is.cl_hard_partition.TSClusters <- function(x) {
    # every non-fuzzy type slot value counts as a hard partition
    is_fuzzy <- x@type == "fuzzy"
    !is_fuzzy
}

#' @method is.cl_hierarchy TSClusters
#' @export
#' @importFrom clue is.cl_hierarchy
#'
is.cl_hierarchy.TSClusters <- function(x) {
    # TRUE exactly when the object's type slot is "hierarchical"
    clustering_type <- x@type
    clustering_type == "hierarchical"
}

#' @method is.cl_partition TSClusters
#' @export
#' @importFrom clue is.cl_partition
#'
is.cl_partition.TSClusters <- function(x) {
    # unconditionally reported as a partition, whatever the clustering type
    is_partition <- TRUE
    is_partition
}

#' @method n_of_classes TSClusters
#' @export
#' @importFrom clue n_of_classes
#'
n_of_classes.TSClusters <- function(x) {
    # number of clusters, as stored in the k slot
    num_clusters <- x@k
    num_clusters
}

#' @method n_of_objects TSClusters
#' @export
#' @importFrom clue n_of_objects
#'
n_of_objects.TSClusters <- function(x) {
    # one cluster ID per clustered series, so count the IDs
    cluster_ids <- x@cluster
    length(cluster_ids)
}
asardaes/dtwclust documentation built on March 3, 2023, 5:32 a.m.