R/helper.R

Defines functions plot_cv_field plot_sequentially detrend checkInputData setCores

Documented in checkInputData detrend plot_cv_field plot_sequentially setCores

#' Internal function: Set the number of cores for parallel computing
#'
#' Validates `ncores` and, when it is acceptable, stores it in the
#' `RCPP_PARALLEL_NUM_THREADS` environment variable, which RcppParallel
#' honours for both its TBB and tinythread backends.
#'
#' The upper bound for `ncores` is the current value of
#' `RCPP_PARALLEL_NUM_THREADS` when that parses to a positive integer,
#' otherwise `parallel::detectCores(logical = TRUE)`. If neither yields a
#' usable positive count (e.g. `detectCores()` returns `NA` or fails), the
#' bound falls back to 1.
#'
#' @keywords internal
#' @param ncores Number of cores for parallel computing. Default is NULL,
#'   in which case nothing is validated or changed.
#' @return `TRUE` after `ncores` has been applied; `NULL` (invisibly) when
#'   `ncores` is `NULL`. Signals an error for an invalid `ncores`.
#'
setCores <- function(ncores = NULL) {
  if (is.null(ncores)) {
    return(invisible(NULL))
  }
  if (!is.numeric(ncores)) {
    # collapse so multi-class inputs (e.g. matrices) give one message
    stop(paste0(
      "Please enter valid type - but got ",
      paste(class(ncores), collapse = ", ")
    ))
  }
  # A vector or NA would otherwise fail inside if() with an opaque message.
  if (length(ncores) != 1L || is.na(ncores)) {
    stop("Please enter a single non-missing number of cores.")
  }

  max_cores <- tryCatch(
    {
      # Prefer a positive environment override; otherwise use detected cores.
      env_threads <- suppressWarnings(
        as.integer(Sys.getenv("RCPP_PARALLEL_NUM_THREADS", unset = NA))
      )
      candidate <- if (!is.na(env_threads) && env_threads > 0L) {
        env_threads
      } else {
        as.integer(parallel::detectCores(logical = TRUE))
      }
      # detectCores() returns NA when the count cannot be determined.
      if (is.na(candidate) || candidate < 1L) 1L else candidate
    },
    error = function(e) 1L
  )
  # NOTE(review): this function itself writes RCPP_PARALLEL_NUM_THREADS below,
  # so after e.g. setCores(1) every later call is capped at 1 as well.
  # Confirm whether only a user-supplied override should act as the cap.

  if (ncores > max_cores) {
    stop(paste0("The input number of cores is invalid - default is ", max_cores))
  }
  if (ncores < 1) {
    stop(paste0("The number of cores is not greater than 1 - but got ", ncores))
  }
  # RcppParallel honors this environment variable for both TBB and tinythread backends
  Sys.setenv(RCPP_PARALLEL_NUM_THREADS = as.integer(ncores))
  TRUE
}

#'
#' Internal function: Validate input data for a spatmca object
#'
#' Signals an error describing the first failed check. The inputs must be
#' two-dimensional, the location counts must match the data columns, each
#' field needs at least 3 locations in at most 3 spatial dimensions, both
#' data matrices need the same number of rows, and `M` must be smaller than
#' that sample size.
#'
#' @keywords internal
#' @param x1 Location matrix (\eqn{p \times d}) corresponding to Y1.
#' @param x2 Location matrix (\eqn{q \times d}) corresponding to Y2.
#' @param Y1 Data matrix (\eqn{n \times p}) of the first variable stores the values at \eqn{p} locations with sample size \eqn{n}.
#' @param Y2 Data matrix (\eqn{n \times q}) of the second variable stores the values at \eqn{q} locations with sample size \eqn{n}.
#' @param M Number of folds for cross-validation
#' @return `NULL`, invisibly, when all checks pass.
#'
checkInputData <- function(x1, x2, Y1, Y2, M) {
  # nrow()/ncol() return NULL for plain vectors, which would make the
  # comparisons below fail with an opaque "argument is of length zero".
  inputs <- list(x1 = x1, x2 = x2, Y1 = Y1, Y2 = Y2)
  is_not_2d <- vapply(inputs, function(a) length(dim(a)) != 2L, logical(1))
  if (any(is_not_2d)) {
    stop(
      "Inputs must be matrices or data frames, but got a non-2D object for: ",
      paste(names(inputs)[is_not_2d], collapse = ", ")
    )
  }
  if (nrow(x1) != ncol(Y1)) {
    stop("The number of rows of x1 should be equal to the number of columns of Y1.")
  }
  if (nrow(x2) != ncol(Y2)) {
    stop("The number of rows of x2 should be equal to the number of columns of Y2.")
  }
  if (nrow(x1) < 3 || nrow(x2) < 3) {
    stop("Number of locations must be larger than 2.")
  }
  if (ncol(x1) > 3 || ncol(x2) > 3) {
    stop("Dimension of locations must be less than 4.")
  }
  if (nrow(Y1) != nrow(Y2)) {
    stop("The numbers of sample sizes of both data should be equal.")
  }
  if (M >= nrow(Y1)) {
    stop("Number of folds must be less than sample size, but got M = ", M)
  }
  invisible(NULL)
}

#'
#' Internal function: Detrend Y by column-wise centering
#'
#' @keywords internal
#' @param Y Data matrix
#' @param is_Y_detrended Logical flag; when `TRUE` every column of `Y` has its
#'   mean subtracted, otherwise `Y` is returned untouched.
#' @return `Y` with its column means removed when `is_Y_detrended` is `TRUE`,
#'   otherwise `Y` itself
#'
detrend <- function(Y, is_Y_detrended) {
  if (is_Y_detrended) {
    # Repeating each column mean nrow(Y) times lines the vector up with R's
    # column-major matrix storage, so the subtraction centres each column.
    column_means <- colMeans(Y)
    return(Y - rep(column_means, each = nrow(Y)))
  }
  Y
}

#' Internal function: Plot sequentially
#'
#' Prints each object in turn with "ask before new page" enabled on the
#' current graphics device, so interactive users can step through the plots.
#' The previous ask setting is restored on exit.
#'
#' @keywords internal
#' @param objs A list of printable plot objects (e.g. ggplot2 objects). A
#'   single ggplot2 object is also accepted and treated as a list of one.
#' @return `NULL`, invisibly.
#'
plot_sequentially <- function(objs) {
  # A ggplot object is itself a list; iterating over it directly would print
  # its internal components instead of the plot.
  if (inherits(objs, c("ggplot", "gg"))) {
    objs <- list(objs)
  }
  # Nothing to show: avoid opening a graphics device for no reason.
  if (length(objs) == 0L) {
    return(invisible(NULL))
  }
  # devAskNewPage() supersedes the deprecated par(ask = ) and also applies to
  # grid-based (ggplot2) output; it returns the previous setting.
  original_ask <- grDevices::devAskNewPage(TRUE)
  on.exit(grDevices::devAskNewPage(original_ask), add = TRUE)
  for (obj in objs) {
    # Rendering warnings are deliberately muted.
    suppressWarnings(print(obj))
  }
  invisible(NULL)
}


#' Internal function: Plot 2D fields for cross validation results
#' @keywords internal
#' @param cv_data A data frame containing the columns ``u``, ``v``, and ``cv``
#' @param variate A character string used as the plot title
#' @return A ggplot object showing `cv` as tiles over a log-log `u`/`v` grid
plot_cv_field <- function(cv_data, variate) {
  # Both axes use a natural-log transform with breaks labelled as e^k;
  # `scale_fun` is scale_x_continuous or scale_y_continuous.
  log_axis <- function(scale_fun) {
    scale_fun(
      trans = log_trans(),
      breaks = trans_breaks("log", function(x) exp(x)),
      labels = trans_format("log", math_format(e^.x))
    )
  }

  large_text_theme <- theme_classic() +
    theme(
      text = element_text(size = 24),
      plot.title = element_text(hjust = 0.5)
    )

  tile_plot <- ggplot(cv_data, aes(x = u, y = v, z = cv, fill = cv)) +
    geom_tile()
  tile_plot <- tile_plot +
    log_axis(scale_y_continuous) +
    log_axis(scale_x_continuous)
  tile_plot + ggtitle(variate) + large_text_theme
}

Try the SpatMCA package in your browser

Any scripts or data that you put into this service are public.

SpatMCA documentation built on Nov. 5, 2025, 5:42 p.m.