R/predict.graper.R

#' @title Predict response on new data
#' @name predict.graper
#' @description Function to predict the response on a
#' new data set using a fitted graper model.
#' @param object a fitted graper model, as obtained from \code{\link{graper}}
#' @param newX predictor matrix of size n_test (number of new test samples)
#' x p (number of predictors), with the same feature structure as used
#' in \code{\link{graper}}
#' @param type type of prediction returned, one of:
#' \describe{
#'  \item{\strong{response}:}{returns the linear predictions
#'   for linear regression and class probabilities
#'   for logistic regression}
#'  \item{\strong{link}:}{returns the linear predictions (the log-odds
#'   for logistic regression)}
#'  \item{\strong{inRange} (default):}{returns the linear predictions for
#'   linear regression and predicted class labels (0/1) for logistic regression}
#' }
#' @param ... other arguments (currently unused)
#' @importFrom methods is
#' @return A one-column matrix with one prediction per row of \code{newX}.
#' @export
#' @examples
#' # create data
#' dat <- makeExampleData()
#' # split data into train and test sets of equal size
#' ntrain <- dat$n / 2
#' # fit the model to the train data
#' fit <- graper(dat$X[seq_len(ntrain), ],
#'               dat$y[seq_len(ntrain)], dat$annot)
#' # make predictions on the test data
#' ypred <- predict(fit, dat$X[seq_len(ntrain) + dat$n / 2, ])
#'
#' # create data for logistic regression
#' dat <- makeExampleData(response="bernoulli")
#' # split data into train and test sets of equal size
#' ntrain <- dat$n / 2
#' # fit the graper model for a logistic model
#' fit <- graper(dat$X[seq_len(ntrain), ],
#'               dat$y[seq_len(ntrain)],
#'               dat$annot, family="binomial")
#' # make predictions on the test data
#' ypred <- predict(fit, dat$X[seq_len(ntrain) + dat$n / 2, ], type = "inRange")
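#'
#' # compare the prediction types for the logistic model on the same test data:
#' # "response" gives class probabilities, "link" the linear predictor (log-odds)
#' newX <- dat$X[seq_len(ntrain) + dat$n / 2, ]
#' pprob <- predict(fit, newX, type = "response")
#' plink <- predict(fit, newX, type = "link")
#' # the class labels returned by type = "inRange" are the rounded probabilities
#' stopifnot(all(ypred == round(pprob)))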

predict.graper <- function(object, newX,
                           type = c("inRange", "response", "link"), ...) {

    # sanity checks: check the class first, then the dimensions
    if (!is(object, "graper")) {
        stop("object needs to be a graper object.")
    }
    if (ncol(newX) != nrow(object$EW_beta)) {
        stop("Number of columns in newX needs to agree with the number ",
             "of predictors in the graper object.")
    }
    # get the type of predictions requested (defaults to "inRange")
    type <- match.arg(type)

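    # treat a missing intercept as zero
    if (is.null(object$intercept)) {
        object$intercept <- 0
    }
    # linear predictor: intercept + newX times the posterior mean of beta
    linpred <- object$intercept + newX %*% object$EW_beta
    if (object$Options$family == "gaussian") {
        # linear regression: all types return the linear predictions
        pred <- linpred
    } else {
        # logistic function; stats::plogis() avoids the NaN that
        # exp(x) / (1 + exp(x)) produces when exp(x) overflows
        probs <- stats::plogis(linpred)
        if (type == "link") {
            pred <- linpred
        } else if (type == "response") {
            pred <- probs
        } else {
            # "inRange": 0/1 class labels obtained by rounding the probabilities
            pred <- round(probs)
        }
    }
    return(pred)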
}

graper documentation built on Nov. 8, 2020, 5:45 p.m.