#' @title Predict Method for Multiple Class Logistic Regression Model Fits
#' @description Predicts a class label for every row of a design matrix from
#'   a fitted "my_multiclass" object. For each class, the linear predictor
#'   \code{X \%*\% object$coefficients[[k]]} is passed through the logistic
#'   function, and the class with the largest resulting score is returned.
#' @param object object of class inheriting from "my_multiclass". The method
#'   reads \code{object$coefficients} (a list with one coefficient vector per
#'   class) and \code{object$y_lab} (the class labels, indexed in the same
#'   order as \code{object$coefficients}).
#' @param X design matrix to predict from; its columns must be conformable
#'   with the coefficient vectors stored in \code{object}.
#' @param ... further arguments, accepted for compatibility with the
#'   \code{predict} generic and otherwise ignored.
#'
#' @return A list with a single element \code{Y_hat}: the predicted label for
#'   each row of \code{X}, taken from \code{object$y_lab}. A row for which no
#'   class score can be computed (all scores missing) yields \code{NA}.
#'
#' @examples
#' data(iris)
#' X <- model.matrix(Species ~ ., data = iris)
#' Y <- iris$Species
#' fit_m <- my_multiclass(X, Y)
#' y_hat <- predict(fit_m, X)$Y_hat
#' @export
predict.my_multiclass <- function(object, X, ...) {
  if (missing(X)) {
    stop("argument \"X\" is missing: supply the design matrix to predict from",
         call. = FALSE)
  }

  coefs <- object$coefficients
  n_class <- length(coefs)
  if (n_class == 0L) {
    stop("`object$coefficients` is empty: nothing to predict with",
         call. = FALSE)
  }

  scores <- matrix(NA_real_, nrow = nrow(X), ncol = n_class)
  for (k in seq_len(n_class)) {
    eta_hat <- as.vector(X %*% coefs[[k]])
    # Written as 1 / (1 + exp(-eta)) rather than exp(eta) / (1 + exp(eta)):
    # the latter evaluates to Inf / Inf = NaN for large positive eta, and
    # which.max() then silently ignores that (most confident) class.
    scores[, k] <- 1 / (1 + exp(-eta_hat))
  }

  # which.max() returns integer(0) when every score in a row is missing;
  # map that to NA so the result always has one entry per row of X.
  best <- vapply(
    seq_len(nrow(scores)),
    function(i) {
      k <- which.max(scores[i, ])
      if (length(k) == 0L) NA_integer_ else k
    },
    integer(1)
  )

  list(Y_hat = object$y_lab[best])
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.