R/predict.logforest.R

Defines functions predict.logforest

Documented in predict.logforest

#' Predict Outcomes Using a Logic Forest Model
#'
#' Computes predicted values for new observations, or returns the stored
#' out-of-bag (OOB) predictions, for a logic forest model fitted using
#' \code{logforest}.
#'
#' @param object An object of class \code{"logforest"}.
#' @param newdata A matrix or data frame of new predictor values with the same
#'   number of columns as the data used to fit the model. If omitted, the OOB
#'   predictions stored in \code{object} are returned.
#' @param cutoff A numeric value between 0 and 1 specifying the minimum
#'   proportion of trees that must predict a class of 1 for the overall
#'   prediction to be 1. Defaults to 0.5. Only used for classification models
#'   when \code{newdata} is supplied; stored OOB predictions are returned as-is.
#'   Ignored for non-classification models.
#' @param ... Additional arguments (currently ignored).
#'
#' @return An object of class \code{"LFprediction"}: a list containing
#' \itemize{
#'   \item \code{model.type}: the model type recorded in \code{object}.
#'   \item \code{LFprediction}: numeric vector of predicted responses.
#'   \item \code{proportion_one}: numeric vector of the proportion of trees
#'     predicting class 1 (classification only).
#'   \item \code{AllTrees}: data frame with the predicted values from each tree
#'     (\code{tree1}, \code{tree2}, ...), followed by \code{proportion_one}
#'     and \code{PredictedValue} for classification, or by
#'     \code{PredictedValue} (the mean over trees) for regression and
#'     time-to-event models. Only present when \code{newdata} is supplied.
#'   \item \code{OOBmse}, \code{ptype}: the OOB error stored in
#'     \code{object$OOBmiss} and the string \code{"OOBprediction"}
#'     (regression and time-to-event models without \code{newdata} only).
#' }
#'
#' @details
#' For classification models, the prediction for each row of \code{newdata} is
#' obtained by comparing the proportion of trees predicting 1 with
#' \code{cutoff}. For regression and time-to-event models, the prediction is
#' the mean of the per-tree predictions. When \code{newdata} is omitted, the
#' OOB predictions (and, for non-classification models, the OOB error) stored
#' in \code{object} are returned.
#'
#' @author Bethany Wolf \email{wolfb@@musc.edu}
#'
#' @seealso \code{\link{logforest}}
#' @export
predict.logforest <- function(object, newdata, cutoff, ...) {
  # inherits() must be told which class to test for; without `what` it errors.
  if (!inherits(object, "logforest")) {
    stop("object not of class logforest")
  }
  trees <- object$AllFits
  nBS <- length(trees)
  mtype <- object$model.type
  regression.types <- c("Linear Regression", "Cox-PH Time-to-Event",
                        "Exp. Time-to-Event")
  # Fail loudly instead of falling through with `ans` undefined.
  if (length(mtype) != 1L ||
      !(mtype %in% c("Classification", regression.types))) {
    stop("unsupported model type: ", paste(mtype, collapse = ", "))
  }
  tree.names <- paste0("tree", seq_len(nBS))

  if (mtype == "Classification") {
    if (missing(cutoff)) {
      cutoff <- 0.5
    }
    if (missing(newdata)) {
      # Stored OOB results: column 1 is returned as the predicted class and
      # column 2 as the proportion of trees predicting 1.
      ans <- list(model.type = mtype,
                  LFprediction = object$OOBprediction[, 1],
                  proportion_one = object$OOBprediction[, 2])
    } else {
      predict.new <- .lf_tree_predictions(trees, newdata, object$predictors)
      predictions <- proportion.positive(predictmatrix = predict.new,
                                         cutoff = cutoff)
      predframe <- as.data.frame(cbind(predict.new, predictions$predmat))
      names(predframe)[seq_len(nBS + 2L)] <-
        c(tree.names, "proportion_one", "PredictedValue")
      ans <- list(model.type = mtype,
                  LFprediction = predframe[["PredictedValue"]],
                  proportion_one = predframe[["proportion_one"]],
                  AllTrees = predframe)
    }
  } else {
    # Linear regression and both time-to-event model types are handled
    # identically: the forest prediction is the mean over trees.
    if (missing(newdata)) {
      ans <- list(model.type = mtype,
                  LFprediction = object$OOBprediction[, 2],
                  OOBmse = object$OOBmiss,
                  ptype = "OOBprediction")
    } else {
      predict.new <- .lf_tree_predictions(trees, newdata, object$predictors)
      predframe <- as.data.frame(cbind(predict.new, rowMeans(predict.new)))
      names(predframe) <- c(tree.names, "PredictedValue")
      ans <- list(model.type = mtype,
                  LFprediction = predframe[["PredictedValue"]],
                  AllTrees = predframe)
    }
  }
  class(ans) <- "LFprediction"
  ans
}

# Predict `newdata` with every tree of a logic forest. Returns a matrix with
# nrow(newdata) rows and one column per tree (via predict.logreg2).
.lf_tree_predictions <- function(trees, newdata, npredictors) {
  if (ncol(newdata) != npredictors) {
    stop("the predictors in newdata do not match the original predictors")
  }
  # Convert once, without column subsetting: `newdata[, 1:p]` dropped a
  # single-row matrix to a vector, which as.matrix() turned into a column.
  newX <- as.matrix(newdata)
  predict.new <- matrix(0, nrow = nrow(newX), ncol = length(trees))
  for (i in seq_along(trees)) {
    predict.new[, i] <- predict.logreg2(object = trees[[i]], newbin = newX)
  }
  predict.new
}

Try the LogicForest package in your browser

Any scripts or data that you put into this service are public.

LogicForest documentation built on Feb. 14, 2026, 1:08 a.m.