R/predict-my-multiclass.R

#' @title Predict Method for Multiple Class Logistic Regression Model Fits
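#' @description Predicts class labels from a fitted "my_multiclass" model. For each class, the
#' logistic function of that class's linear predictor is computed, and each observation is
#' assigned the label with the largest predicted probability.
#' @param object An object of class inheriting from "my_multiclass".
#' @param X A design matrix with the same columns, in the same order, as the one used to fit "my_multiclass".
#' @return A list with a single element, Y_hat: the predicted class label for each row of X.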
#'
#' @examples
#' data(iris)
#' X <- model.matrix(Species ~ ., data = iris)
#' Y <- iris$Species
#' fit_m <- my_multiclass(X, Y)
#' y_hat <- predict(fit_m, X)$Y_hat
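#' @param ... Further arguments passed to or from other methods (currently ignored).
#' @export

predict.my_multiclass <- function(object, X, ...){
  X <- as.matrix(X)
  K <- length(object$coefficients)
  # One column of predicted probabilities per class
  Y_hat_tmp <- matrix(NA_real_, ncol = K, nrow = nrow(X))
  for(i in seq_len(K)){
    # Linear predictor for class i
    eta_hat <- X %*% object$coefficients[[i]]
    # plogis() is the logistic function; unlike exp(eta) / (1 + exp(eta)) it does not return NaN for large eta
    Y_hat_tmp[, i] <- plogis(eta_hat)
  }
  # Assign each row to the class with the largest predicted probability
  Y_hat <- object$y_lab[apply(Y_hat_tmp, 1, which.max)]
  return(list(Y_hat = Y_hat))
}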
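The prediction rule can be checked without installing the package. The sketch below is a minimal, self-contained illustration, not the package's own code: toy_fit, predict_ovr, naive_logistic and eta_1000 are hypothetical names, and the coefficients are made up. It assumes, from the fields the method reads, that a fit stores one coefficient vector per class in coefficients and the class labels in y_lab. It also shows why plogis() is safer than exp(eta) / (1 + exp(eta)): for a large linear predictor the latter evaluates Inf / Inf and returns NaN, and which.max() silently discards NaN, so the wrong class can be returned.

```r
# Hypothetical toy fit: intercept + one predictor, three classes.
# Layout assumed from the fields predict.my_multiclass reads.
toy_fit <- list(
  coefficients = list(c(2, -1), c(0, 0), c(-4, 1)),
  y_lab = c("low", "mid", "high")
)

# Hypothetical stand-in for the method: per-class logistic probability, then argmax
predict_ovr <- function(object, X) {
  probs <- sapply(object$coefficients, function(b) plogis(drop(X %*% b)))
  probs <- matrix(probs, nrow = nrow(X))  # rows = observations, even when nrow(X) == 1
  object$y_lab[apply(probs, 1, which.max)]
}

X_new <- cbind(1, c(-3, 3, 6))
predict_ovr(toy_fit, X_new)
# "low" "mid" "high"

# Overflow check: exp(996) is Inf, and Inf / Inf is NaN
naive_logistic <- function(eta) exp(eta) / (1 + exp(eta))
naive_logistic(996)  # NaN
plogis(996)          # 1

# At x = 1000 the "high" class has eta = 996; the naive transform turns it into NaN,
# which.max() drops it, and "mid" (probability 0.5) is returned instead
eta_1000 <- sapply(toy_fit$coefficients, function(b) sum(c(1, 1000) * b))
toy_fit$y_lab[which.max(naive_logistic(eta_1000))]  # "mid"
toy_fit$y_lab[which.max(plogis(eta_1000))]          # "high"
```

Because the logistic function is monotone, taking which.max() over the linear predictors themselves would give the same labels; computing probabilities with plogis() keeps them available without the overflow.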
tqchen07/bis557 documentation built on Dec. 21, 2020, 3:06 a.m.