R/principalCurves.R

Defines functions pcurves .numplots

#' Principal Curves Analysis
#'
#' @description This is an alternative to the function contained in the princurve package. It is
#' somewhat simpler to use by comparison. This implementation uses a robust smoother from
#' the loess.smooth function by default.
#'
#' @details The curve is initialised from the first singular direction of the
#' centred data (centre from \code{L1median}, decomposition from \code{wsvd}).
#' Each iteration smooths every column of \code{x} against the current
#' \code{lambda} with \code{stats::loess.smooth} and re-projects the data with
#' \code{projectToCurve}. Iteration stops when the change in \code{dist}
#' between two successive fits is at most \code{tol} times the previous
#' \code{dist}, or after \code{maxit} iterations.
#'
#' @param x A matrix or data frame of numeric variables.
#' @param tol Convergence threshold on the relative change in the total distance to the curve.
#' @param maxit Maximum number of iterations.
#' @param stretch A stretch factor for the endpoints of the curve, allowing the curve to grow to avoid bunching at the end. Must be a numeric value between 0 and 2. Defaults to 1.25.
#' @param family one of "symmetric" or "gaussian". defaults to "symmetric" which gives nearly identical
#' results as the "gaussian" option, unless there are outliers, in which case "symmetric" gives better results.
#' therefore the "symmetric" is typically the best option.
#' @param plot Should the final fitted curve be plotted against each variable? Defaults to FALSE.
#'
#' @return a list with class "principalCurves" holding the elements
#' \code{scores}, \code{ord}, \code{lambda}, \code{dist} (all taken from the
#' last fit), \code{converged} (logical), \code{num_iterations} (integer) and
#' \code{call}. If \code{x} contains non-numeric variables a message is
#' printed and \code{NULL} is returned invisibly.
#' @export
#'
#' @examples pcurves(iris[,-5])
#' @references Hastie, T. and Stuetzle, W., Principal Curves, JASA, Vol. 84, No. 406 (Jun., 1989), pp. 502-516, DOI: 10.2307/2289936
#'
pcurves <- function (x, tol = 0.00001, maxit = 1000, stretch = 1.25, family= c("symmetric", "gaussian"), plot = FALSE) {
  # Check column types for data frames and the storage mode for matrices.
  # isFALSE() is only TRUE for a single FALSE, so it cannot be used to scan a
  # vector of per-column results.
  is_num <- if (is.data.frame(x)) {
    vapply(x, is.numeric, logical(1))
  } else {
    is.numeric(x)
  }
  if (!all(is_num)) {
    cat(crayon::red("'x' must not contain character or factor variables!"))
    return(invisible(NULL))
  }

  x <- as.matrix(x)
  family <- match.arg(family)
  function_call <- match.call()
  n_vars <- ncol(x)

  # Smooth one variable against the current arc-length parameter. The fit is
  # evaluated at NROW(y) points so the result has one value per observation.
  smoother_function <- function(lambda, y) {
    stats::loess.smooth(lambda, y, family = family, evaluation = NROW(y))$y
  }

  # Centre of the data and root-mean-square spread of each column around it.
  xbar <- L1median(x)$center
  xscale <- vapply(
    seq_len(n_vars),
    function(j) sqrt(mean((x[, j] - xbar[j])^2)),
    numeric(1)
  )
  dist_old <- sum(xscale^2)

  # Centre, rescale, and decompose to obtain the starting line.
  xstar <- sweep(x, 2, xbar, "-")
  xstar <- scale(xstar, center = FALSE, scale = TRUE)
  xstar <- sweep(xstar, 2, xscale, "*")
  svd_xstar <- wsvd(xstar)
  dd <- svd_xstar$d
  lambda <- svd_xstar$u[, 1] * dd[1]

  # Points on the starting line, shifted back by xbar.
  init_curve <- scale(
    outer(lambda, svd_xstar$v[, 1]),
    center = -xbar,
    scale = FALSE
  )
  dimnames(init_curve) <- dimnames(x)

  # Stored under `scores` so the result and the plotting code below still
  # find the curve when the loop performs zero iterations.
  pcurve <- list(
    scores = init_curve,
    ord = order(lambda),
    lambda = lambda,
    dist = sum((dd^2)[-1]) * nrow(x)
  )

  it <- 0
  smooth_fit <- matrix(
    NA_real_,
    nrow = nrow(x),
    ncol = n_vars,
    dimnames = list(NULL, colnames(x))
  )
  has_converged <- abs(dist_old - pcurve$dist) <= tol * dist_old

  while (!has_converged && it < maxit) {
    it <- it + 1
    for (jj in seq_len(n_vars)) {
      smooth_fit[, jj] <- smoother_function(pcurve$lambda, x[, jj])
    }
    dist_old <- pcurve$dist
    pcurve <- projectToCurve(x = x, s = smooth_fit, stretch = stretch)
    has_converged <- abs(dist_old - pcurve$dist) <= tol * dist_old
  }

  if (plot) {
    # Put the caller's graphics settings back when the function exits.
    old_par <- par(no.readonly = TRUE)
    on.exit(par(old_par), add = TRUE)

    par(oma = c(0, 0, 3, 0), mar = c(4, 5, 1, 1) + 0.1)
    resp.n <- .numplots(n_vars)
    par(mfrow = resp.n, cex = 1)
    per_page <- prod(resp.n)
    for (i in seq_len(n_vars)) {
      plot(
        pcurve$lambda,
        x[, i],
        xlab = " ",
        # colnames() yields NULL (default label) when x has no dimnames.
        ylab = colnames(x)[i],
        pch = 21,
        col = "#382222CC",
        bg = "#a3a3a326"
      )
      lines(pcurve$lambda[pcurve$ord],
            pcurve$scores[pcurve$ord, i],
            col = "#2e76f2D6",
            lwd = 2)
      # Draw the outer title once per page, on the first panel of the page.
      if ((i - 1) %% per_page == 0) {
        mtext(
          "Marginal plots of fitted curves",
          line = 1,
          cex = 1.25,
          outer = TRUE
        )
      }
    }
  }

  structure(
    list(
      scores = pcurve$scores,
      ord = pcurve$ord,
      lambda = pcurve$lambda,
      dist = pcurve$dist,
      converged = has_converged,
      num_iterations = as.integer(it),
      call = function_call
    ),
    class = c("principalCurves", "principal_curve")
  )
}


# Choose the c(rows, columns) panel grid used for `par(mfrow = )` when
# drawing `n` marginal plots. Counts without a dedicated grid (5, and
# anything above 9) use 2 x 3. The value is returned invisibly, so callers
# are expected to capture it.
.numplots <- function(n) {
  grid <- if (n == 1) {
    c(1, 1)
  } else if (n == 2) {
    c(1, 2)
  } else if (n == 3) {
    c(1, 3)
  } else if (n == 4) {
    c(2, 2)
  } else if (n == 6) {
    c(3, 2)
  } else if (n == 8 || n == 7) {
    c(4, 2)
  } else if (n == 9) {
    c(3, 3)
  } else {
    c(2, 3)
  }
  invisible(grid)
}
abnormally-distributed/cvreg documentation built on May 3, 2020, 3:45 p.m.