R/fit.R

Defines functions flpcr

# Fit the PSR model by block-coordinate updates of alpha, Sigma, tau and Psi.
#
# Starting values: if either Psi or tau is missing, both are initialised from
# a probabilistic-PCA style eigendecomposition of crossprod(X) / n; if either
# alpha or Sigma is missing, both are initialised by one round of updates
# given Psi. Each outer iteration then updates alpha, Sigma, tau and Psi in
# turn (the objective is minimised: no block update should increase it) and
# stops once the change in the objective -- relative to the previous value
# when relative = TRUE -- has absolute value below tol.
#
# Arguments:
#   Y, X       Response and predictor matrices sharing the same n rows.
#   k          Rank used for Psi; must satisfy k < min(n, ncol(X)).
#   tol        Convergence tolerance for the outer loop.
#   maxit      Maximum number of outer iterations (must be at least 1).
#   quiet      If FALSE, print the objective change after every block update
#              and warn when an update increases the objective.
#   relative   Use the relative objective change in the convergence check;
#              also forwarded to update_psi_mapg().
#   maxit_upd, tol_upd, quiet_upd, step1, step2
#              Forwarded to update_psi_mapg() (as maxit, tol, quiet, step1,
#              step2).
#   pen        Forwarded to update_alpha().
#   alpha, Sigma, Psi, tau
#              Optional starting values.
#
# Returns a list with the final alpha, Sigma, Psi and tau, the number of
# outer iterations run (iter), the final objective value (obj) and the last
# (possibly relative) objective change (change). Emits the warning
# "Reached maxiter" when the loop ends without meeting the tolerance.
flpcr <- function(Y, X, k, tol = 1e-6, maxit = 1e3, quiet = TRUE,
                relative = TRUE, maxit_upd = 50, tol_upd = 1e-6,
                quiet_upd = TRUE, step1 = NA, step2 = NA, pen = 0,
                alpha, Sigma, Psi, tau)
{
  # Define constants
  p <- ncol(X)
  n <- nrow(X)
  if (k >= min(n, p)) stop("PSR requires k < min(n, p)")
  # seq(0) would be c(1, 0), so reject a non-positive iteration budget.
  if (maxit < 1) stop("maxit must be at least 1")

  if (missing(Psi) || missing(tau)) {
    # Use probabilistic principal components estimates
    e_S <- eigen(crossprod(X) / n, symmetric = TRUE)
    lead <- seq_len(k)
    tau <- mean(e_S$values[(k + 1):p])
    U_k <- e_S$vectors[, lead, drop = FALSE]
    # diag(..., k) keeps a k x k matrix even when k == 1.
    Psi <- U_k %*% tcrossprod(diag(e_S$values[lead] / tau, k), U_k)
    rm(e_S, U_k, lead)
  }

  if (missing(alpha) || missing(Sigma)) {
    # Do one round of updates for initialization.
    # NOTE(review): when Sigma is missing it is forwarded to update_alpha as
    # a missing argument; this only works if update_alpha does not evaluate
    # Sigma (perhaps for some values of pen) -- confirm.
    alpha <- update_alpha(Y = Y, Z = X %*% Psi, pen = pen, Sigma = Sigma)
    Sigma <- crossprod(Y - X %*% Psi %*% alpha) / n
  }

  # Objective at the current parameter values. Free variables are looked up
  # in this function's frame at call time, so it always sees the latest
  # alpha, Sigma, tau and e_P.
  current_obj <- function() {
    obj_fun(U = e_P$vectors, d = e_P$values, alpha = alpha, Sigma = Sigma,
            tau = tau, Y = Y, X = X)
  }

  # Diagnostics printed after one block update when quiet = FALSE.
  report_step <- function(what, change, tau_change = NULL) {
    if (change > tol) {
      warning(what, " update increased objective \n", call. = FALSE)
    }
    cat("At iteration: : ", iter, "\n")
    cat(paste0("Change from ", what, " update: "), change, "\n")
    if (!is.null(tau_change)) cat("Change in tau: ", tau_change, "\n")
    cat("# ---------", "\n")
  }

  # Starting objective value. Psi is treated as symmetric, so request the
  # symmetric solver (real eigenvalues, orthonormal eigenvectors).
  e_P <- eigen(Psi, symmetric = TRUE)
  e_P$values <- proj_evals(e_P$values, k = k)
  obj_old <- current_obj()
  # is.finite() also rejects -Inf and NaN (NaN == Inf would be NA here).
  if (!is.finite(obj_old)) stop("Need finite starting objective value")

  obj_new <- obj_old
  change <- NA_real_
  converged <- FALSE
  for (iter in seq_len(maxit)) {
    # Update alpha ------------------------------------------------------------
    Z <- X %*% Psi
    alpha <- update_alpha(Y = Y, Z = Z, pen = pen, Sigma = Sigma)
    if (!quiet) {
      obj_after_alpha <- current_obj()
      report_step("alpha", obj_after_alpha - obj_old)
    }

    # Update Sigma ------------------------------------------------------------
    Sigma <- update_sigma(Y = Y, Z = Z, alpha = alpha)
    if (!quiet) {
      obj_after_sigma <- current_obj()
      report_step("Sigma", obj_after_sigma - obj_after_alpha)
    }

    # Update tau --------------------------------------------------------------
    tau_old <- tau
    tau <- update_tau(X = X, U = e_P$vectors, d = e_P$values)
    if (!quiet) {
      obj_after_tau <- current_obj()
      report_step("tau", obj_after_tau - obj_after_sigma,
                  tau_change = tau - tau_old)
    }

    # Update Psi --------------------------------------------------------------
    psi_update <- update_psi_mapg(alpha = alpha, Sigma = Sigma, Psi = Psi,
                                  tau = tau, k = k, X = X, Y = Y,
                                  step1 = step1, step2 = step2,
                                  maxit = maxit_upd, tol = tol_upd,
                                  relative = relative, quiet = quiet_upd)
    Psi <- psi_update$Psi
    e_P <- eigen(Psi, symmetric = TRUE)
    # NOTE(review): unlike the starting objective, these eigenvalues are not
    # passed through proj_evals(); presumably update_psi_mapg already returns
    # a feasible Psi -- confirm.
    if (!quiet) {
      obj_after_psi <- current_obj()
      report_step("Psi", obj_after_psi - obj_after_tau)
    }

    # Wrap up iteration -------------------------------------------------------
    obj_new <- current_obj()
    change <- obj_new - obj_old
    # Divide by |obj_old| so the sign of `change` still means "increase".
    if (relative) change <- change / abs(obj_old)
    obj_old <- obj_new
    if (abs(change) < tol) {
      converged <- TRUE
      break
    }
  }
  # Warn only when the budget ran out without converging.
  if (!converged) warning("Reached maxiter \n")

  list(alpha = alpha, Sigma = Sigma, Psi = Psi, tau = tau, iter = iter,
       obj = obj_new, change = change)
}
koekvall/mpredcc documentation built on Nov. 4, 2019, 3:54 p.m.