R/llm.predict.R


#' Create Logit Leaf Model Prediction
#'
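#' This function creates predictions from an object of class logitleafmodel. It expects a data frame with numeric
#' values as input and an object of class logitleafmodel, which is the result of the \code{\link{llm}} function.
#' Each observation is first assigned to a segment (terminal node) of the model's decision tree and then scored
#' with the logistic regression fitted for that segment. Currently only binary classification is supported.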
#'
#' @param X Data frame containing the numeric independent variables.
#' @param object An object of class logitleafmodel, such as the one created by \code{\link{llm}}.
#' @param ... Further arguments passed to or from other methods (currently unused).
#' @return A data frame with a single column, \code{probability}, holding the predicted probability for every instance according to the LLM. Row numbers can optionally be added afterwards.
#' @export
#' @import partykit
#' @importFrom stats predict
#' @references Arno De Caigny, Kristof Coussement, Koen W. De Bock, A New Hybrid Classification Algorithm for Customer Churn Prediction Based on Logistic Regression and Decision Trees, European Journal of Operational Research (2018), doi: 10.1016/j.ejor.2018.02.009.
#' @author Arno De Caigny, \email{a.de-caigny@@ieseg.fr}, Kristof Coussement, \email{k.coussement@@ieseg.fr} and Koen W. De Bock, \email{kdebock@@audencia.com}
#' @seealso \code{\link{llm}}, \code{\link{table.llm.html}}, \code{\link{llm.cv}}
#' @examples
#' ## Load the PimaIndiansDiabetes dataset from the mlbench package
#' if (requireNamespace("mlbench", quietly = TRUE)) {
#'   data("PimaIndiansDiabetes", package = "mlbench")
#'   ## Split into training (2/3) and test (1/3) sets
#'   set.seed(1)
#'   idtrain <- sample(seq_len(nrow(PimaIndiansDiabetes)), 512)
#'   PimaTrain <- PimaIndiansDiabetes[idtrain, ]
#'   Pimatest <- PimaIndiansDiabetes[-idtrain, ]
#'   ## Create the LLM
#'   Pima.llm <- llm(X = PimaTrain[, -9], Y = PimaTrain$diabetes,
#'                   threshold_pruning = 0.25, nbr_obs_leaf = 100)
#'   ## Use the model on the test set to make a prediction
#'   PimaPrediction <- predict.llm(object = Pima.llm, X = Pimatest[, -9])
#'   ## Optionally add the dependent variable to calculate performance statistics such as AUC
#'   # PimaPrediction <- cbind(PimaPrediction, "diabetes" = Pimatest[, "diabetes"])
#' }
#' @export predict.llm
#'

predict.llm <- function(object, X, ...){

  # Convert the stored segmentation tree to a party object once,
  # instead of once per observation
  tree <- partykit::as.party(object[[3]])

  # Helper function: predict_llm_obs
  # Inputs
  # == observation: a single observation
  # == llmobject: object of class logitleafmodel
  # == tree: the segmentation tree of llmobject, converted to a party object
  predict_llm_obs <- function(observation, llmobject, tree){
    # apply() passes each row as a named vector: convert it to a one-row data frame
    if (!is.data.frame(observation)) {
      observation <- as.data.frame(t(observation))
    }
    # Find the terminal node (segment) the observation falls into ...
    node <- stats::predict(tree, newdata = observation, type = "node")
    # ... and retrieve the segment-specific logistic regression model
    segment_lr <- llmobject[[2]][[which(names(llmobject[[1]]) == node)]]

    # Return the predicted probability
    return(stats::predict(segment_lr, newdata = observation, type = "response", verbose = FALSE))
  }

  # Apply predict_llm_obs to every row of X.
  # X must be all-numeric: apply() converts it to a matrix first.
  myreturn <- as.data.frame(apply(X = X, MARGIN = 1, FUN = predict_llm_obs,
                                  llmobject = object, tree = tree))
  names(myreturn) <- "probability"

  # Return the probabilities
  return(myreturn)
}
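To make the two-stage prediction in predict.llm concrete, the following base R sketch mimics it on simulated data: each observation is first routed to a segment and then scored with that segment's logistic regression. It is an illustration under stated assumptions, not the package's implementation: a single hand-coded split on x1 stands in for the LLM decision tree, and assign_segment and predict_leaf_logit are hypothetical names.

```r
# Sketch of logit leaf model prediction (NOT the LLM package code).
# Assumption: one hand-coded split on x1 stands in for the decision tree.
set.seed(42)
n <- 400
train <- data.frame(x1 = runif(n, -2, 2), x2 = rnorm(n))
# Simulate a response whose dependence on x2 flips sign between segments
eta <- ifelse(train$x1 < 0, 1.5 * train$x2, -1.5 * train$x2)
train$y <- rbinom(n, 1, plogis(eta))

# Hypothetical segmentation rule (plays the role of the tree's terminal nodes)
assign_segment <- function(df) ifelse(df$x1 < 0, "left", "right")

# Fit one logistic regression per segment
segments <- assign_segment(train)
models <- lapply(split(train, segments), function(d) {
  glm(y ~ x1 + x2, family = binomial, data = d)
})

# Predict: look up each row's segment, then score it with that segment's model
predict_leaf_logit <- function(models, newdata) {
  seg <- assign_segment(newdata)
  prob <- numeric(nrow(newdata))
  for (s in unique(seg)) {
    rows <- seg == s
    prob[rows] <- predict(models[[s]], newdata = newdata[rows, , drop = FALSE],
                          type = "response")
  }
  data.frame(probability = prob)
}

test <- data.frame(x1 = c(-1, 1), x2 = c(1, 1))
pred <- predict_leaf_logit(models, test)
print(pred)
```

Grouping rows by segment lets each glm score all of its rows in one vectorized predict() call and keeps the data frame's column types intact, whereas predict.llm uses apply(), which converts X to a matrix and scores one row at a time.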

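The roxygen example mentions appending the dependent variable to compute performance statistics such as AUC. As one option, here is a base R sketch of the rank-based (Mann-Whitney) AUC. The helper auc_rank is hypothetical and not part of the LLM package, and the toy scores and labels are made up for illustration.

```r
# Rank-based AUC: the probability that a randomly chosen positive instance
# receives a higher score than a randomly chosen negative one (ties count 1/2).
auc_rank <- function(prob, actual, positive) {
  is_pos <- actual == positive
  n_pos <- sum(is_pos)
  n_neg <- sum(!is_pos)
  if (n_pos == 0 || n_neg == 0) stop("need both positive and negative instances")
  r <- rank(prob)  # tied scores get average ranks, which counts ties as 1/2
  (sum(r[is_pos]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

prob <- c(0.9, 0.4, 0.3, 0.6, 0.2)
actual <- factor(c("pos", "pos", "neg", "neg", "neg"))
auc <- auc_rank(prob, actual, positive = "pos")
print(auc)  # 5 of the 6 positive/negative pairs are ordered correctly, so 5/6
```

With the example's objects, a call could look like auc_rank(PimaPrediction$probability, Pimatest$diabetes, positive = "pos"), assuming the probability column refers to the "pos" level of diabetes.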
LLM documentation built on July 1, 2020, 7:19 p.m.