R/dm_wavelet_reconstruct.R

Defines functions print.summary.dm_wavelet_reconstruct summary.dm_wavelet_reconstruct plot.dm_wavelet_reconstruct dm_wavelet_reconstruct

Documented in dm_wavelet_reconstruct plot.dm_wavelet_reconstruct print.summary.dm_wavelet_reconstruct summary.dm_wavelet_reconstruct

#' Reconstruct or remove selected cycle components from a dm_wavelet object
#'
#' @description
#' Reconstructs a selected oscillatory component, or a selected period band,
#' from a \code{dm_wavelet} object using \code{WaveletComp::reconstruct()}.
#'
#' Requested periods are supplied in \strong{hours}. They are internally
#' converted to the native wavelet period units used when \code{dm_wavelet()}
#' was computed, then passed to \code{WaveletComp::reconstruct()}.
#'
#' The function supports two modes:
#' \itemize{
#'   \item \code{mode = "extract"} returns the selected cycle or band itself.
#'   \item \code{mode = "remove"} returns the original series with the selected
#'   cycle or band removed.
#' }
#'
#' @param x An object of class \code{"dm_wavelet"} returned by
#'   \code{dm_wavelet()}.
#' @param series Optional character vector of series names to reconstruct. If
#'   \code{NULL}, all available series are used.
#' @param mode One of \code{"extract"} or \code{"remove"}.
#' @param period_hours Optional numeric vector of exact periods, in hours, to
#'   reconstruct. If supplied, \code{lower_hours} and \code{upper_hours} are
#'   ignored.
#' @param lower_hours Optional lower period bound in hours for band
#'   reconstruction.
#' @param upper_hours Optional upper period bound in hours for band
#'   reconstruction.
#' @param lvl Minimum wavelet power level to include in the reconstruction.
#' @param only_sig Logical. If \code{TRUE}, use wavelet power significance in
#'   reconstruction.
#' @param siglvl Significance level used when \code{only_sig = TRUE}.
#' @param only_coi Logical. If \code{TRUE}, restrict reconstruction to the cone
#'   of influence.
#' @param only_ridge Logical. If \code{TRUE}, reconstruct from the power ridge
#'   only.
#' @param rescale Logical. Passed to \code{WaveletComp::reconstruct()}.
#' @param verbose Logical. If \code{TRUE}, prints a completion message.
#'
#' @return
#' An object of class \code{"dm_wavelet_reconstruct"} with elements:
#' \describe{
#'   \item{call}{Matched function call.}
#'   \item{series}{Selected series names.}
#'   \item{mode}{Reconstruction mode, either \code{"extract"} or \code{"remove"}.}
#'   \item{selection}{A list describing the requested selection in hours and in
#'   native wavelet units.}
#'   \item{results}{Named list of \code{WaveletComp::reconstruct()} outputs.}
#'   \item{reconstructed_long}{Tidy table with columns \code{TIME},
#'   \code{series}, \code{original}, \code{reconstructed}, \code{difference},
#'   and \code{filtered}.}
#'   \item{used_periods}{Per-series table of periods actually used in the
#'   reconstruction, in native units and in hours.}
#'   \item{filtered_wide}{Wide table with \code{TIME} in the first column and
#'   one filtered series column per selected tree/series.}
#'   \item{parent}{Minimal metadata copied from the parent \code{dm_wavelet}
#'   object.}
#' }
#'
#' @examples
#' \donttest{
#' wv <- dm_wavelet(
#'   x = gf_nepa17,
#'   TreeNum = 1:2,
#'   source = "raw",
#'   make_pval = TRUE,
#'   verbose = FALSE
#' )
#'
#' # extract circadian component
#' rec_extract <- dm_wavelet_reconstruct(
#'   wv,
#'   mode = "extract",
#'   lower_hours = 20,
#'   upper_hours = 28,
#'   only_sig = TRUE
#' )
#'
#' # remove circadian component
#' rec_remove <- dm_wavelet_reconstruct(
#'   wv,
#'   mode = "remove",
#'   lower_hours = 20,
#'   upper_hours = 28,
#'   only_sig = TRUE
#' )
#'
#' head(rec_extract$filtered_wide)
#' head(rec_remove$filtered_wide)
#' }
#'
#' @importFrom dplyr bind_rows %>%
#' @importFrom tidyr pivot_wider
#' @importFrom tibble tibble
#' @export
dm_wavelet_reconstruct <- function(x,
                                   series = NULL,
                                   mode = c("extract", "remove"),
                                   period_hours = NULL,
                                   lower_hours = NULL,
                                   upper_hours = NULL,
                                   lvl = 0,
                                   only_sig = TRUE,
                                   siglvl = 0.05,
                                   only_coi = FALSE,
                                   only_ridge = FALSE,
                                   rescale = FALSE,
                                   verbose = TRUE) {
  # Reconstructs (mode = "extract") or removes (mode = "remove") a selected
  # period or period band, given in hours, for each requested series of a
  # 'dm_wavelet' object, and returns a 'dm_wavelet_reconstruct' list.

  # Bound locally so the bare column name used in dplyr::select() below does
  # not raise an R CMD check "no visible binding" note.
  TIME <- NULL

  if (!inherits(x, "dm_wavelet")) {
    stop("'x' must be an object of class 'dm_wavelet'.")
  }

  mode <- match.arg(mode)

  # ---- series selection ---------------------------------------------------
  if (is.null(series)) {
    series <- x$series
  }

  # Duplicated names would produce duplicated rows in the long table and
  # list-columns in the wide table, so each series is processed once.
  series <- unique(series)

  if (length(series) == 0) {
    stop("No series available to reconstruct.")
  }

  miss <- setdiff(series, x$series)
  if (length(miss) > 0) {
    stop("These requested series were not found: ", paste(miss, collapse = ", "))
  }

  # ---- period / band selection --------------------------------------------
  if (!is.null(period_hours) && (!is.null(lower_hours) || !is.null(upper_hours))) {
    warning("'period_hours' was supplied, so 'lower_hours' and 'upper_hours' are ignored.")
    lower_hours <- NULL
    upper_hours <- NULL
  }

  if (is.null(period_hours) && xor(is.null(lower_hours), is.null(upper_hours))) {
    stop("Provide both 'lower_hours' and 'upper_hours' for a band reconstruction.")
  }

  if (is.null(period_hours) && is.null(lower_hours) && is.null(upper_hours)) {
    stop("Provide either 'period_hours' or both 'lower_hours' and 'upper_hours'.")
  }

  if (!is.null(period_hours)) {
    # A zero-length vector is rejected too: any() on an empty vector is FALSE,
    # so it would otherwise slip through as an empty selection.
    if (!is.numeric(period_hours) || length(period_hours) == 0 ||
        any(!is.finite(period_hours)) || any(period_hours <= 0)) {
      stop("'period_hours' must contain positive finite numbers.")
    }
  }

  # Both bounds are non-NULL here whenever 'lower_hours' is (see xor check).
  if (!is.null(lower_hours)) {
    if (!is.numeric(lower_hours) || length(lower_hours) != 1 || !is.finite(lower_hours) || lower_hours <= 0) {
      stop("'lower_hours' must be one positive finite number.")
    }
    if (!is.numeric(upper_hours) || length(upper_hours) != 1 || !is.finite(upper_hours) || upper_hours <= 0) {
      stop("'upper_hours' must be one positive finite number.")
    }
    if (lower_hours >= upper_hours) {
      stop("'lower_hours' must be smaller than 'upper_hours'.")
    }
  }

  if (!is.numeric(siglvl) || length(siglvl) != 1 || is.na(siglvl) || siglvl <= 0 || siglvl >= 1) {
    stop("'siglvl' must be a single number between 0 and 1.")
  }

  if (!is.numeric(lvl) || length(lvl) != 1 || is.na(lvl)) {
    stop("'lvl' must be a single number.")
  }

  # The unit conversions below compare against this value; without the check
  # a missing unit fails with an opaque "argument is of length zero" error.
  time_unit <- x$time_unit
  if (!is.character(time_unit) || length(time_unit) != 1 || is.na(time_unit)) {
    stop("'x$time_unit' must be a single non-missing character string.")
  }

  # Checked after argument validation so that input errors are reported
  # independently of which optional packages are installed.
  if (!requireNamespace("WaveletComp", quietly = TRUE)) {
    stop("Package 'WaveletComp' is required for dm_wavelet_reconstruct().")
  }

  cl <- match.call()

  # ---- unit conversion helpers --------------------------------------------
  # Unrecognised units fall through unchanged, i.e. they are treated as hours.
  dwr_hours_to_native <- function(hours) {
    switch(time_unit,
      days = hours / 24,
      hours = hours,
      mins = hours * 60,
      secs = hours * 3600,
      hours
    )
  }

  dwr_native_to_hours <- function(native) {
    switch(time_unit,
      days = native * 24,
      hours = native,
      mins = native / 60,
      secs = native / 3600,
      native
    )
  }

  period_native <- if (!is.null(period_hours)) dwr_hours_to_native(period_hours) else NULL
  lower_native <- if (!is.null(lower_hours)) dwr_hours_to_native(lower_hours) else NULL
  upper_native <- if (!is.null(upper_hours)) dwr_hours_to_native(upper_hours) else NULL

  # ---- per-series reconstruction ------------------------------------------
  results_list <- vector("list", length(series))
  names(results_list) <- series

  tidy_list <- vector("list", length(series))
  names(tidy_list) <- series

  used_periods_list <- vector("list", length(series))
  names(used_periods_list) <- series

  for (i in seq_along(series)) {
    ss <- series[i]
    parent_res <- x$results[[ss]]
    wt <- parent_res$wavelet
    tt <- parent_res$time

    if (is.null(wt)) {
      stop("No wavelet result is stored in 'x' for series '", ss, "'.")
    }

    rec_args <- list(
      WT = wt,
      lvl = lvl,
      only.coi = only_coi,
      only.sig = only_sig,
      siglvl = siglvl,
      only.ridge = only_ridge,
      rescale = rescale,
      plot.waves = FALSE,
      plot.rec = FALSE,
      verbose = FALSE
    )

    # Exact periods take precedence; otherwise a lower/upper band is used.
    if (!is.null(period_native)) {
      rec_args$sel.period <- period_native
    } else {
      rec_args$sel.lower <- lower_native
      rec_args$sel.upper <- upper_native
    }

    rec <- do.call(WaveletComp::reconstruct, rec_args)

    sdat <- rec$series
    col_names <- names(sdat)

    # The analysed column is the first one that has a matching "<name>.r"
    # reconstruction column. Requiring the counterpart keeps auxiliary columns
    # (e.g. a leading date column) from being mistaken for the series.
    base_cols <- col_names[!grepl("\\.trend$|\\.r$", col_names)]
    sname <- base_cols[paste0(base_cols, ".r") %in% col_names][1]
    rname <- paste0(sname, ".r")

    if (is.na(sname) || !(sname %in% col_names) || !(rname %in% col_names)) {
      stop(
        "Could not identify original and reconstructed series columns in ",
        "WaveletComp::reconstruct() output for series '", ss, "'."
      )
    }

    original <- as.numeric(sdat[[sname]])
    reconstructed <- as.numeric(sdat[[rname]])
    difference <- original - reconstructed

    # "extract" keeps the selected component; "remove" keeps what is left.
    filtered <- if (mode == "extract") reconstructed else difference

    tidy_list[[i]] <- tibble::tibble(
      TIME = tt,
      series = ss,
      original = original,
      reconstructed = reconstructed,
      difference = difference,
      filtered = filtered
    )

    # rec$rnum.used indexes into wt$Period (native units).
    used_idx <- rec$rnum.used
    if (length(used_idx) > 0 && !is.null(wt$Period)) {
      used_native <- as.numeric(wt$Period[used_idx])
      used_hours <- dwr_native_to_hours(used_native)

      used_periods_list[[i]] <- tibble::tibble(
        series = ss,
        period_native = used_native,
        period_hours = used_hours
      )
    } else {
      used_periods_list[[i]] <- tibble::tibble(
        series = ss,
        period_native = numeric(0),
        period_hours = numeric(0)
      )
    }

    results_list[[i]] <- rec
  }

  # ---- assemble output ----------------------------------------------------
  reconstructed_long <- dplyr::bind_rows(tidy_list)

  filtered_wide <- reconstructed_long %>%
    dplyr::select(TIME, series, filtered) %>%
    tidyr::pivot_wider(
      names_from = series,
      values_from = filtered
    ) %>%
    as.data.frame()

  out <- list(
    call = cl,
    series = series,
    mode = mode,
    selection = list(
      period_hours = period_hours,
      lower_hours = lower_hours,
      upper_hours = upper_hours,
      period_native = period_native,
      lower_native = lower_native,
      upper_native = upper_native,
      lvl = lvl,
      only_sig = only_sig,
      siglvl = siglvl,
      only_coi = only_coi,
      only_ridge = only_ridge,
      rescale = rescale
    ),
    results = results_list,
    reconstructed_long = reconstructed_long,
    used_periods = dplyr::bind_rows(used_periods_list),
    filtered_wide = filtered_wide,
    parent = list(
      input_type = x$input_type,
      source = x$source,
      time_unit = x$time_unit,
      dt = x$dt,
      dt_hours = x$dt_hours
    )
  )

  class(out) <- "dm_wavelet_reconstruct"

  if (isTRUE(verbose)) {
    msg <- if (!is.null(period_hours)) {
      paste0("selected period(s): ", paste(round(period_hours, 4), collapse = ", "), " hour(s)")
    } else {
      paste0("selected band: ", round(lower_hours, 4), " to ", round(upper_hours, 4), " hours")
    }

    action_msg <- if (mode == "extract") "extracted" else "removed"

    message(
      "dm_wavelet_reconstruct completed for ", length(series),
      " series (", msg, "; mode = '", mode, "', selected component ", action_msg, ")."
    )
  }

  out
}

############### plotting #####################################################
#' Plot method for dm_wavelet_reconstruct objects
#'
#' @description
#' Plots reconstructed or filtered cycle components extracted by
#' \code{dm_wavelet_reconstruct()}.
#'
#' @param x An object of class \code{"dm_wavelet_reconstruct"}.
#' @param y Unused.
#' @param series Optional character vector of series names to plot. If
#'   \code{NULL}, all available reconstructed series are used.
#' @param type One of:
#'   \describe{
#'     \item{`"compare"`}{Original and reconstructed series together.}
#'     \item{`"reconstructed"`}{Reconstructed component only.}
#'     \item{`"difference"`}{Original minus reconstructed.}
#'     \item{`"filtered"`}{Directly plot the returned filtered series. This is
#'       the extracted component for \code{mode = "extract"} and the
#'       component-removed series for \code{mode = "remove"}.}
#'   }
#' @param facet Logical. If \code{TRUE}, facet by series.
#' @param legend_position Legend position passed to ggplot2.
#' @param line_width Line width.
#' @param alpha Alpha transparency for original series in compare plots.
#' @param main Optional title.
#' @param ... Further arguments passed to or from other methods.
#'
#' @return A \code{ggplot2} object.
#'
#' @method plot dm_wavelet_reconstruct
#' @importFrom ggplot2 ggplot aes geom_line facet_wrap theme_bw theme
#'   element_text labs
#' @export
plot.dm_wavelet_reconstruct <- function(x,
                                        y = NULL,
                                        series = NULL,
                                        type = c("compare", "reconstructed", "difference", "filtered"),
                                        facet = TRUE,
                                        legend_position = "right",
                                        line_width = 0.8,
                                        alpha = 0.7,
                                        main = NULL,
                                        ...) {
  # Draws one of four line plots from x$reconstructed_long and returns the
  # ggplot object. 'y' and '...' are accepted for S3 compatibility only.

  # Column names used inside aes() are bound to NULL so that R CMD check does
  # not report "no visible binding" notes.
  TIME <- value <- signal <- reconstructed <- difference <- filtered <- NULL

  if (!inherits(x, "dm_wavelet_reconstruct")) {
    stop("'x' must be an object of class 'dm_wavelet_reconstruct'.")
  }

  if (!requireNamespace("ggplot2", quietly = TRUE)) {
    stop("Package 'ggplot2' is required for plot.dm_wavelet_reconstruct().")
  }

  type <- match.arg(type)

  dat <- x$reconstructed_long

  # nrow(NULL) is NULL, which would make the row-count test below fail with
  # an unhelpful "argument is of length zero" error.
  if (is.null(dat)) {
    stop("The object contains no reconstructed data.")
  }

  if (!is.null(series)) {
    miss <- setdiff(series, unique(dat$series))
    if (length(miss) > 0) {
      stop("These requested series were not found: ", paste(miss, collapse = ", "))
    }
    dat <- dat[dat$series %in% series, , drop = FALSE]
  }

  if (nrow(dat) == 0) {
    stop("No rows remain after filtering.")
  }

  component_removed <- !is.null(x$mode) && x$mode == "remove"

  # Adds the theme, axis/legend labels, title and optional per-series facets
  # that all plot types share.
  finish_plot <- function(p, y_lab, colour_lab, default_title, legend_pos) {
    p <- p +
      ggplot2::theme_bw() +
      ggplot2::theme(
        legend.position = legend_pos,
        axis.title = ggplot2::element_text(size = 14),
        axis.text = ggplot2::element_text(size = 11)
      ) +
      ggplot2::labs(
        x = "Time",
        y = y_lab,
        colour = colour_lab,
        title = if (is.null(main)) default_title else main
      )

    if (isTRUE(facet)) {
      p <- p + ggplot2::facet_wrap(stats::as.formula("~ series"), scales = "free_y", ncol = 1)
    }

    p
  }

  if (type == "compare") {
    # Stack original and reconstructed values into one long table.
    long_dat <- tibble::tibble(
      TIME = c(dat$TIME, dat$TIME),
      series = c(dat$series, dat$series),
      signal = rep(c("original", "reconstructed"), each = nrow(dat)),
      value = c(dat$original, dat$reconstructed)
    )

    # Group by series as well as signal: with facet = FALSE all series share
    # one panel, and grouping by signal alone would join them into a single
    # zig-zag line per signal.
    p <- ggplot2::ggplot(
      long_dat,
      ggplot2::aes(
        x = TIME, y = value, colour = signal,
        group = interaction(series, signal)
      )
    ) +
      ggplot2::geom_line(
        data = long_dat[long_dat$signal == "original", , drop = FALSE],
        linewidth = line_width,
        alpha = alpha
      ) +
      ggplot2::geom_line(
        data = long_dat[long_dat$signal == "reconstructed", , drop = FALSE],
        linewidth = line_width
      )

    return(finish_plot(
      p,
      y_lab = "Value",
      colour_lab = "Series type",
      default_title = "Original and reconstructed wavelet component",
      legend_pos = legend_position
    ))
  }

  # The three single-variable plots differ only in the y column and labels.
  spec <- switch(type,
    reconstructed = list(
      mapping = ggplot2::aes(x = TIME, y = reconstructed, colour = series),
      y_lab = "Reconstructed value",
      title = "Reconstructed wavelet component"
    ),
    difference = list(
      mapping = ggplot2::aes(x = TIME, y = difference, colour = series),
      y_lab = "Original - reconstructed",
      title = "Difference between original and reconstructed series"
    ),
    filtered = list(
      mapping = ggplot2::aes(x = TIME, y = filtered, colour = series),
      y_lab = if (component_removed) {
        "Filtered value (component removed)"
      } else {
        "Filtered value (component extracted)"
      },
      title = if (component_removed) {
        "Filtered series (selected component removed)"
      } else {
        "Filtered series (selected component extracted)"
      }
    )
  )

  p <- ggplot2::ggplot(dat, spec$mapping) +
    ggplot2::geom_line(linewidth = line_width)

  # With facets each panel already names its series, so the legend is dropped.
  finish_plot(
    p,
    y_lab = spec$y_lab,
    colour_lab = "Series",
    default_title = spec$title,
    legend_pos = if (isTRUE(facet)) "none" else legend_position
  )
}

#################### summary extraction #############################
#' Summarize a dm_wavelet_reconstruct object
#'
#' @description
#' Summarizes the output of \code{dm_wavelet_reconstruct()}.
#'
#' For each reconstructed series, the summary reports:
#' \itemize{
#'   \item number of observations,
#'   \item variance of the original series,
#'   \item variance of the reconstructed component,
#'   \item variance of the remaining component,
#'   \item variance of the returned filtered series,
#'   \item proportion of original variance represented by reconstructed and filtered series,
#'   \item correlation and \eqn{R^2} between original and reconstructed,
#'   \item number and range of periods used in reconstruction.
#' }
#'
#' @param object An object of class \code{"dm_wavelet_reconstruct"}.
#' @param ... Further arguments passed to or from other methods.
#'
#' @return
#' An object of class \code{"summary.dm_wavelet_reconstruct"} with elements:
#' \describe{
#'   \item{overview}{One-row summary of the reconstruction object.}
#'   \item{series_summary}{Per-series summary table.}
#'   \item{used_periods}{The table of periods actually used in the reconstruction.}
#'   \item{selection}{The selection settings used for reconstruction.}
#' }
#'
#' @method summary dm_wavelet_reconstruct
#' @export
summary.dm_wavelet_reconstruct <- function(object, ...) {
  # Builds a one-row overview plus a per-series table of variances, variance
  # ratios, original/reconstructed correlation and used-period statistics.

  if (!inherits(object, "dm_wavelet_reconstruct")) {
    stop("'object' must be an object of class 'dm_wavelet_reconstruct'.")
  }

  dat <- object$reconstructed_long
  used <- object$used_periods

  if (is.null(dat) || nrow(dat) == 0) {
    stop("The object contains no reconstructed data.")
  }

  # Without a used-period table the per-series row counts below would be
  # NULL; substitute an empty table so every series reports zero periods.
  if (is.null(used)) {
    used <- tibble::tibble(
      series = character(0),
      period_native = numeric(0),
      period_hours = numeric(0)
    )
  }

  # Variance over finite values only; NA when fewer than two remain.
  safe_var <- function(z) {
    z <- z[is.finite(z)]
    if (length(z) < 2) return(NA_real_)
    stats::var(z)
  }

  # Pearson correlation over pairwise-finite values; NA when undefined.
  safe_cor <- function(a, b) {
    ok <- is.finite(a) & is.finite(b)
    if (sum(ok) < 2) return(NA_real_)
    a <- a[ok]
    b <- b[ok]
    # A constant vector has no defined correlation. Returning NA here avoids
    # the "standard deviation is zero" warning from stats::cor(), e.g. when
    # the reconstructed component is flat.
    if (min(a) == max(a) || min(b) == max(b)) return(NA_real_)
    stats::cor(a, b)
  }

  ser_names <- unique(dat$series)

  series_rows <- lapply(ser_names, function(ss) {
    dss <- dat[dat$series == ss, , drop = FALSE]
    uss <- used[used$series == ss, , drop = FALSE]

    v_orig <- safe_var(dss$original)
    v_rec  <- safe_var(dss$reconstructed)
    v_rem  <- safe_var(dss$difference)
    v_filt <- safe_var(dss$filtered)

    # Ratios are only meaningful against a positive, finite original variance.
    has_ref <- is.finite(v_orig) && v_orig > 0
    prop_of_original <- function(v) if (has_ref) v / v_orig else NA_real_

    cor_or_rec <- safe_cor(dss$original, dss$reconstructed)
    r2_or_rec <- if (is.finite(cor_or_rec)) cor_or_rec^2 else NA_real_

    has_periods <- nrow(uss) > 0

    tibble::tibble(
      series = ss,
      n_obs = sum(is.finite(dss$original)),
      variance_original = v_orig,
      variance_reconstructed = v_rec,
      variance_remaining = v_rem,
      variance_filtered = v_filt,
      prop_variance_reconstructed = prop_of_original(v_rec),
      prop_variance_remaining = prop_of_original(v_rem),
      prop_variance_filtered = prop_of_original(v_filt),
      correlation_original_reconstructed = cor_or_rec,
      r_squared_original_reconstructed = r2_or_rec,
      n_used_periods = nrow(uss),
      min_used_period_hours = if (has_periods) min(uss$period_hours, na.rm = TRUE) else NA_real_,
      max_used_period_hours = if (has_periods) max(uss$period_hours, na.rm = TRUE) else NA_real_,
      mean_used_period_hours = if (has_periods) mean(uss$period_hours, na.rm = TRUE) else NA_real_
    )
  })

  series_summary <- dplyr::bind_rows(series_rows)

  overview <- tibble::tibble(
    mode = object$mode,
    n_series = length(ser_names),
    input_type = object$parent$input_type,
    source = object$parent$source,
    time_unit = object$parent$time_unit,
    dt = object$parent$dt,
    dt_hours = object$parent$dt_hours
  )

  out <- list(
    overview = overview,
    series_summary = series_summary,
    used_periods = used,
    selection = object$selection
  )

  class(out) <- "summary.dm_wavelet_reconstruct"
  out
}


#' Print method for summary.dm_wavelet_reconstruct
#'
#' @param x An object of class \code{"summary.dm_wavelet_reconstruct"}.
#' @param digits Number of digits for rounded numeric printing.
#' @param ... Further arguments passed to or from other methods.
#'
#' @return The input object, invisibly.
#'
#' @method print summary.dm_wavelet_reconstruct
#' @export
print.summary.dm_wavelet_reconstruct <- function(x, digits = 4, ...) {
  # Prints the overview, the selection settings, the rounded per-series
  # table and (when non-empty) the rounded used-period table, then returns
  # the summary object invisibly.

  if (!inherits(x, "summary.dm_wavelet_reconstruct")) {
    stop("'x' must be an object of class 'summary.dm_wavelet_reconstruct'.")
  }

  # Rounds every numeric column of a table to 'digits'; other columns are
  # left untouched.
  round_numeric_cols <- function(tbl) {
    is_num <- vapply(tbl, is.numeric, logical(1))
    tbl[is_num] <- lapply(tbl[is_num], round, digits = digits)
    tbl
  }

  cat(
    "dm_wavelet_reconstruct summary\n",
    "------------------------------\n",
    sep = ""
  )
  print(x$overview)

  cat("\nSelection settings:\n")
  print(x$selection)

  cat("\nPer-series summary:\n")
  print(round_numeric_cols(x$series_summary))

  # The used-period section is skipped when the table is absent or empty.
  periods <- x$used_periods
  show_periods <- !is.null(periods) && nrow(periods) > 0
  if (show_periods) {
    cat("\nUsed periods (hours):\n")
    print(round_numeric_cols(periods))
  }

  cat(
    "\nNote:\n",
    "Variance ratios are reported relative to the original series.\n",
    "Wavelet reconstruction is not strictly orthogonal, so reconstructed and remaining variances do not necessarily sum to the original variance.\n",
    sep = ""
  )

  invisible(x)
}

Try the dendRoAnalyst package in your browser

Any scripts or data that you put into this service are public.

dendRoAnalyst documentation built on May 20, 2026, 5:07 p.m.