R/predict.R

Defines functions predict.reduced

#' Predict from a model fitted on dimension-reduced data
#'
#' Removes the first class of `object` (expected to be `"reduced"`) so that
#' `predict()` dispatches to the method of the underlying model. It can first
#' map `newdata` through `object$projection`. When `newdata` contains the
#' grouping column(s) named by `object$group`, those columns are split off
#' before prediction. They are then used to compute the classification error
#' rate.
#'
#' @param object A fitted model whose first class is `"reduced"`. Its
#'   `projection` element is a matrix with one column per feature column of
#'   `newdata`. Its `group` element names the grouping column.
#' @param newdata Data to predict on; coerced with `as.data.frame()`.
#' @param ... Currently ignored (not forwarded to the underlying method).
#' @param reduced If `TRUE` (the default), `newdata` is replaced by
#'   `t(object$projection %*% t(newdata))` before predicting; otherwise it is
#'   passed through unchanged.
#' @return Whatever the underlying `predict()` method returns. When the
#'   grouping column is present in `newdata`, the result also gains an element
#'   `cer`: the proportion of rows whose predicted `class` differs from the
#'   observed group (0 = all correct). This assumes the underlying method
#'   returns a list with a `class` component (as e.g. `MASS::predict.lda`
#'   does) -- TODO confirm for other model classes.
#' @export
#' @keywords internal
#' @importFrom dplyr select
#' @importFrom stringr str_extract
#' @importFrom dplyr ungroup
#' @importFrom stats predict
predict.reduced <- function(object, newdata, ..., reduced = TRUE) {
  # as.data.frame() also drops any dplyr grouping, so no ungroup() is needed.
  newdata <- as.data.frame(newdata)
  # Drop the leading "reduced" class so predict() below reaches the
  # underlying model's method instead of recursing into this one.
  class(object) <- class(object)[-1]

  group_cols <- paste(object$group)
  has_group <- any(names(newdata) %in% group_cols)
  if (has_group) {
    # Split off the observed labels; the model must only see feature columns.
    newdatagroup <- newdata[group_cols]
    newdata <- newdata[setdiff(names(newdata), group_cols)]
  }

  reducednewdata <- if (reduced) {
    t(object$projection %*% t(newdata))
  } else {
    newdata
  }
  pred <- predict(object, reducednewdata)

  if (has_group) {
    # Classification error rate = share of misclassified rows. Compare as
    # character so differing factor level sets between the training labels
    # and the test labels cannot distort the result.
    truth <- as.character(newdatagroup[[1]])
    cer <- mean(truth != as.character(pred$class))
    pred <- c(pred, cer = cer)
  }
  pred
}
BenBarnard/slidR documentation built on Jan. 2, 2018, 4:32 p.m.