R_MM_backup/opls.R

#' Fitting Orthogonal-Partial Least Squares Models
#' @export
#' @description This function is used to fit  Orthogonal-Partial Least Squares (O-PLS) models. It can be used to carry out regression or discriminant analysis. In the latter case the outcome can have two or more levels.
#' @param X Numeric input matrix (measurements derived by NMR spectroscopy or MS) with each row representing an observation and each column a metabolic feature.
#' @param Y Response vector or matrix with same length or number of columns than rows in X, respectively.
#' @param t_pred Parameter specifying the maximum number of predictive components (needed only for multi-factor Y)
#' @param center Logical value (TRUE or FALSE) indicating if features should be mean centered.
#' @param scale Desired scaling method (currently only no or unit variance scaling (UV) implemented).
#' @param cv.k The number of cross-validation sets. This depends on the number of observations in X but typically takes a value between 3 and 9.
#' @param cv.type Type or cross-validation: 'k-fold', 'k-fold_stratified', 'MC', 'MC_stratified' (see Details).
#' @param plotting Logical value (TRUE or FALSE) indicating if model parameters (R2X, Q2, etc) should be visualised once the model is trained.
#' @param maxPCo The maximum number of orthogonal components (in case stop criteria fail).
#' @details Models are fully statistically validated, currently only k-fold cross validation (CV) and class-balanced k-fold cross validation is implemented. Further extensions, e.g. Monte-Carlo CV, are work in progress. Although the algorithm accepts three and more levels as Y, model interpretation is more straightforward for pairwise group comparisons.
#' @references Trygg J. and Wold, S. (2002) Orthogonal projections to latent structures (O-PLS). \emph{Journal of Chemometrics}, 16.3, 119-128.
#' @references Geladi, P and Kowalski, B.R. (1986), Partial least squares and regression: a tutorial. \emph{Analytica Chimica Acta}, 185, 1-17.
#' @return This function returns an \emph{OPLS_MetaboMate} S4 object.
#' @seealso \code{\link{OPLS_MetaboMate-class}} \code{\link{dmodx}} \code{\link{plotscores}} \code{\link{plotload}} \code{\link{specload}}
#' @author Torben Kimhofer \email{tkimhofer@@gmail.com}
#' @importFrom graphics plot
#' @importFrom methods getSlots new representation setClass
#' @importFrom stats cov sd var
#' @importFrom utils methods
#' @importFrom ggplot2 ggplot aes aes_string scale_fill_manual scale_y_continuous theme_bw labs scale_x_discrete scale_alpha theme element_blank element_line element_rect geom_bar element_text
#' @importFrom pROC roc multiclass.roc
#' @importFrom reshape2 melt
#' @importFrom scales pretty_breaks

opls <- function(X,
                                     Y,
                                     t_pred = 1,
                                     center = T,
                                     scale = 'UV',
                                     cv.k = 7,
                                     cv.type = 'k-fold',
                                     plotting = T,
                                     maxPCo = 5) {
  {

    cat('Preparing data ... ', sep = '')

    # Determine if regression (R) or discriminant analysis (DA),
    # check Y for NA or non-numeric values
    Y1 <- Y
    if (is.numeric(Y)) {
      type <- 'R'
      if (is.infinite(max(abs(Y))) | anyNA(Y)) {
        stop('Y contains non-numeric values.')
      }
      Y_out <- create_dummy_Y(Y)
      Y <- Y_out[[1]]
    } else{
      type <- 'DA'
      if (anyNA(Y)) {
        stop('Y vector contains N/A\'s')
      }
      Y_out <- create_dummy_Y(Y)
      Y <- Y_out[[1]]

      levs=unique(apply(Y, 2, function(x){
        length(unique(x))
      }))
      if(length(levs)==1 & levs[1]==1){
        stop('Y vector has only a single level')
      }
    }

    # check dimensions and convert to matrix, stop if features are non-numeric
    if (nrow(X) != nrow(Y)) {
      stop('X and Y dimensions do not match.')
    }
    X <- as.matrix(X)
    if (is.infinite(max(abs(X))) |
        anyNA(X)) {
      stop('X matrix contains non-numeric values.')
    }

    # Exclude features with sd=0
    sdX <- apply(X, 2, sd)
    idx <- which(sdX == 0)
    if (length(idx) > 0) {
      cat(
        'A total of ',
        length(idx),
        ' features have a standard deviation of zero and will be excluded from analysis.\n',
        sep = ''
      )
      X <- X[, -idx]
      sdX <- sdX[-idx]
    }

    # Dats scaling and centering
    meanX <- apply(X, 2, mean)
    XcsTot <- center_scale(X,
                            idc = 'all',
                            center = meanX,
                            scale = sdX)
    YcsTot <- center_scale(Y,
                            idc = 'all',
                            center = T,
                            scale = 'UV')

    # Calculate Total Sum of Squares (TSS), required for calculation of Q2 etc
    # cat('Calculate total sum of squares...')
    tssx <- totSS(XcsTot)
    tssy <- totSS(YcsTot) # total variance that a model can explain (theoretical reference model, used for normalisation)
     cat('done.\n')
  }

  cat('Performing OPLS-', type, ' ... ', sep='')
  # k-fold cross-validation for calculation of Q2/AUROC
  # generate indices defining training set in each CV round
  cv_sets <- cv_sets_method(cv.k, Y=Y1, method = cv.type, type)

  # initialisations
  R2Y <- rss <- Press <- aucs <- rssx <- array()
  cv.res <- output <- list()
  t_cv <-  t_orth_cv <- preds <- matrix(NA, ncol = 1, nrow = nrow(X))
  nc <- 1
  enough <- F

  # cat('Fiting components...')

  # fit as many orthogoanl components until a stop criterion is TRUE
  while (enough == F) {

    # cat('Component ', nc, '...')
    if (nc>1) { # add column for nc > 1
       preds <- cbind(preds, matrix(NA, ncol = 1, nrow = nrow(X)))
       t_cv <- cbind(t_cv, matrix(NA, ncol = 1, nrow = nrow(X)))
       t_orth_cv <- cbind(t_orth_cv, matrix(NA, ncol = 1, nrow = nrow(X)))
  }
        # fit each cv training set fit component and predict
        # validation set. Prediction accuracy output of is be used in stop criterion
    for (k in 1:length(cv_sets)) {

      # scale/centre using CV trainine set samples
      idc <- cv_sets[[k]]
      Xcs <- center_scale(X, idc, center, scale)
      Ycs <- center_scale(Y, idc, center, scale)
      # filter data: calc PLS component, then orthogonalise X with t_pred (=> t_o, p_o),
      # finally remove orthogoanl variation from X: Xres=X-(t_o * p_o))
      if (nc == 1) {
        cv.res[[k]] = list()
        E_opls <- NIPALS_OPLS_component_mulitlevel(X = Xcs[idc, ], Y = cbind(Ycs[idc, ]))
        cv.res[[k]][['p_orth']] = E_opls$`Loadings X orth`
        cv.res[[k]][['w_orth']] = E_opls$`Weights X orth`
        cv.res[[k]][['Xres']] = E_opls$`Filtered X`
      } else{
        E_opls = NIPALS_OPLS_component_mulitlevel(X = cv.res[[k]][['Xres']], Y =cbind(Ycs[idc, ]))
        cv.res[[k]][['p_orth']] = rbind(cv.res[[k]][['p_orth']], E_opls$`Loadings X orth`)
        cv.res[[k]][['w_orth']] = rbind(cv.res[[k]][['w_orth']], E_opls$`Weights X orth`)
        cv.res[[k]][['Xres']] = E_opls$`Filtered X`
      }
      # calculate predictive component with filtered matrix (Xres)
      #print(cbind(Ycs[idc, ]))
      pls_comp = NIPALS_PLS_component(X = cv.res[[k]][['Xres']], Y = cbind(Ycs[idc, ]))

      # iteratively remove all orthogonal components from prediction data set
      # potentially t_orth could be saved and output as scores orthogonal in CV round
      e_new_orth = Xcs[-idc, ]
      for (i in 1:nc) {
        t_orth = e_new_orth %*% t(t(cv.res[[k]][['w_orth']][i, ])) / drop(crossprod(t(t(cv.res[[k]][['w_orth']][i, ]))))
        e_new_orth = e_new_orth - (t_orth %*% t(cv.res[[k]][['p_orth']][i, ]))
      }
      #print(pls_comp)
      pred = pls_prediction(pls_mod = pls_comp, X = e_new_orth)

      # Save predictions in prediction matrix (one column for nc)
      preds[-idc, nc] = pred$Y_hat
      t_cv[-idc, nc] = pred$Scores_pred
      t_orth_cv[-idc, nc] = t_orth
    }


    # calculate rss, press and Q2
    Press[nc] = sum(apply(YcsTot, 2, function(x) {
      sum((x - preds[, nc]) ^ 2)
    })) / ncol(YcsTot)
    #print(Press)
    Q2_1 = 1 - (Press / tssy)
    #print(Q2_1)

    if (type == 'DA') {
      if (ncol(YcsTot) == 2) {
        mod = roc(response = Y[, 1], predictor = preds[, nc])
        aucs[nc] = mod$auc
      } else{
        mod = multiclass.roc(response = Y1, predictor = preds[, nc])
        aucs[nc] = mod$auc
      }


    } else{
      aucs[nc] = 0
    }


    # calculate r2Y with model using all data (no cv)
    if (nc == 1) {
      # initialisation
      # OPLS-filter X matrix repetitively until some stop criterion, save orthogonal loadings, scores and weights
      # need to scale this again

      E_opls_all = NIPALS_OPLS_component(X = XcsTot, Y = YcsTot)

      output[['t_orth']] = E_opls_all$`Scores X orth` # don't need scores for any calculation (only cv scores, and these are further down in prediction bit)
      output[['p_orth']] = E_opls_all$`Loadings X orth`
      output[['w_orth']] = E_opls_all$`Weights X orth`
      Xres = E_opls_all$`Filtered X`
    } else{
      E_opls_all = NIPALS_OPLS_component_mulitlevel(Xres, YcsTot)

      output[['t_orth']] = cbind(output[['t_orth']], E_opls_all$`Scores X orth`)
      output[['p_orth']] = rbind(output[['p_orth']], E_opls_all$`Loadings X orth`)
      output[['w_orth']] = rbind(output[['w_orth']], E_opls_all$`Weights X orth`)
      Xres = E_opls_all$`Filtered X`
    }

    # make predictions and calc R2
    # calculate predictive component with filtered data
    pls_comp_all = NIPALS_PLS_component(Xres, YcsTot)
    pred_all = pls_prediction(pls_mod = pls_comp_all, X = E_opls_all$`Filtered X`)
    #rss = sum((pred_all$Y_hat - YcsTot[, 1]) ^ 2)
    rss = sum(apply(YcsTot, 2, function(x) {
      sum((x - pred_all$Y_hat) ^ 2)
    })) / ncol(YcsTot)
    R2Y[nc] = 1 - (rss / tssy)

    X_ex = pls_comp_all$scores %*% pls_comp_all$loadings
    rssx[nc] = totSS(X_ex)
    R2x = (rssx / tssx)


    ### define stop criteria for fitting components
    if (length(which(is.na(Q2_1))) > 0) {
      cat('Something went wrong, Q2 is NA for component', nc, '\n',  sep = '')
    }

    if((type=='R' & nc == 1 & ( Q2_1[nc] < 0.05)) | (type=='DA' & nc == 1 & ( aucs[nc] < 0.6))){
    #if (nc == 1 & ( Q2_1[nc] < 0.05 & aucs[nc] < 0.7) ) {
      cat('At first PC, Q2 < 0.03: ', round(Q2_1[nc], 3), '\n',  sep = '')
      cat('At first PC, AUROC < 0.6: ', round(aucs[nc], 3), '\n',  sep = '')
      print('No sign. orthogonal components!')
      return(NULL)
      #print(Q2_cv)
      enough = T
    }


    if (nc > 1) {
      if ((Q2_1[nc] - Q2_1[nc - 1]) < 0.05) {
        # if q2 does not rise by more than 0.03 with new component
        # cat('At PC ', nc, ': delta Q2 < 0.03: ', round((Q2_1[nc] - Q2_1[nc -                                                                  1]), 3), '\n',  sep = '')
        enough = T
      }

    }

    if (Q2_1[nc] > 0.98 | aucs[nc] > 0.97) {
      # cat('At PC ', nc, ', Q2 > 0.98: ', round(Q2_1[nc], 3), '\n', sep = '')
      enough = T
    }

    if (nc == maxPCo) {
      enough = T
      nc = nc + 1
    }

    if (enough == F) {
      nc = nc + 1
    }

  }
  cat('done.\nA model with 1 predictive and', nc-1, 'orthogonal component(s) was fitted.\n\n')
  if (type == 'DA') {
    model.summary = data.frame(
      R2X = round(R2x, 2),
      R2Y = round(R2Y, 2),
      Q2 = round(Q2_1, 2),
      AUROC = round(aucs, 2)
    )

    if(min(model.summary$Q2)<(-0.05)){
      model.summary$Q2[which(model.summary$Q2<(-0.05))]=-0.05
    }

  } else{
    model.summary = data.frame(
      R2X = round(R2x, 2),
      R2Y = round(R2Y, 2),
      Q2 = round(Q2_1, 2)
    )
  }
  rownames(model.summary) = paste('PC_o', 1:(nrow(model.summary)))

  mm = cbind('PC_o' = rownames(model.summary), model.summary)
  mm = melt(mm, id.vars = 'PC_o')
  mm$PC = factor(mm$'PC_o', levels = rownames(model.summary))

  mm$alpha1 = 1
  mm$alpha1[mm$'PC_o' == paste('PC_o', nrow(model.summary))] = 0.7

  g = ggplot(mm, aes_string('PC_o',' value', fill = 'variable')) +
    geom_bar(stat = 'identity',
             position = "dodge",
             colour = NA,
             aes(alpha= mm$'alpha1') )+
    scale_y_continuous(limits = c(min(c(0, min(-0.05) - 0.02)), 1), breaks = pretty_breaks()) +
    theme_bw() +
    labs(y = '',
      #title = paste('An  O-PLS-', type, ' model (1+', nc - 1, ') was fitted', sep = '')
      title = paste('O-PLS-', type, '  - Model Summary', sep = '')
    ) +
    scale_x_discrete(labels = paste('1+', 1:nc, sep = ''), name = 'Predictive + Orthogonal Component') +
    scale_fill_manual(
      values = c(
        'R2X' = 'lightgreen',
        'R2Y' = 'lightblue',
        'Q2' = 'red',
        'AUROC' = 'black'
      ),
      labels = c(
        expression(R ^ 2 * X),
        expression(R ^ {
          2
        } * Y),
        expression(Q ^ 2),
        expression(AUROC[cv])
      ),
      name = ''
    ) +
    scale_alpha(guide = F, limits=c(0,1)) +
    theme(
      legend.text.align = 0,
      panel.grid.major.x = element_blank(),
      panel.grid.major.y = element_line(color = 'black', size =
                                          0.15),
      panel.grid.minor = element_blank(),
      panel.border = element_blank(),
      axis.line.x = element_line(color = 'black', size = 0.55),
      axis.line.y = element_blank(),
      axis.ticks = element_blank(),
      legend.key = element_rect(colour = 'white'),
      text=element_text(family="Helvetica")
    )

  model.summary=model.summary[-nrow(model.summary),]
  #
  if (plotting == T) {
    plot(g)
  }

  if (nc > 0) {
    # re calc t_pred with number of nc-2 (one componenet remove and nc counter has already been set 1 up)
    Xorth = cbind(output[['t_orth']][, 1:(nc - 1)]) %*% rbind(output[['p_orth']][1:(nc -
                                                                                      1), ])
    Xres = XcsTot - Xorth
    pls_comp_all = NIPALS_PLS_component(Xres, YcsTot)

    E = Xres - (pls_comp_all$scores %*% pls_comp_all$loadings) # this is for calculation of DModX

    # this is for visualisation of orthogonal variation (currently not implemented)
    #E_pca=NIPALS_PCAcomponent(E+Xorth)



    dd = data.frame(
      Paramter = c(
        'Center',
        'Scale',
        'nPred',
        'nOrth',
        'CV.k',
        'CV.type',
        'tssx',
        'tssy'
      ),
      Value = c(center, scale, 1, nc - 1, cv.k, cv.type, tssx, tssy)
    )


    # define slots for OPLS_Torben object
    mod_opls = new('OPLS_MetaboMate',
      type = type,
      t_pred = pls_comp_all$scores,
      p_pred = pls_comp_all$loadings,
      w_pred = pls_comp_all$weights,
      betas_pred = pls_comp_all$betas,
      Qpc = pls_comp_all$Qpc,
      t_cv = cbind(t_cv[, (nc - 1)]),
      t_orth_cv = cbind(t_orth_cv[, (nc - 1)]) ,
      t_orth = cbind(output[['t_orth']][, 1:(nc - 1)]),
      p_orth = rbind(output[['p_orth']][1:(nc - 1), ]),
      w_orth = rbind(output[['w_orth']][1:(nc - 1), ]),
      nPC = nc - 1,
      summary = model.summary,
      X_orth = Xorth,
      Y_res = pls_comp_all$Y.res,
      Xcenter = attr(XcsTot, "scaled:center"),
      Xscale = attr(XcsTot, "scaled:scale"),
      Ycenter = attr(YcsTot, "scaled:center"),
      Yscale = attr(YcsTot, "scaled:scale"),
      Yout = Y_out[[2]],
      Parameters = dd,
      #sdFeatures=sdX,
      E = E
    )

    return(mod_opls)
  } else{
    print('No sign. orthogonal components - try PLS!')
    return(NULL)
  }

}
kimsche/MetaboMate documentation built on Aug. 8, 2020, 1:14 a.m.