R/predict-my-grad-descent.R

Defines functions predict.my_grad_descent

Documented in predict.my_grad_descent

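#' @title Predict Method for Linear Models Fitted by Gradient Descent
#' @description Computes predicted values from a model fitted by
#'   my_grad_descent().
#' @param object Object of class inheriting from "my_grad_descent".
#' @param ... The first element must be a data frame containing the variables
#'   used in the model formula; it is used to build the model matrix. Any
#'   further arguments are ignored.
#' @return A matrix of predicted values: the model matrix built from the
#'   supplied data multiplied by the fitted coefficients.
#'
#' @examples
#' data(iris)
#' fit <- my_grad_descent(Sepal.Length ~ ., data = iris)
#' predict(fit, iris)
#' @export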

# predict.my_grad_descent <- function(object, newdata = NULL, ...){
#   dots <- list(...)
#   data <- dots[[1]]
#   if(is.null(newdata)){
#     if(!inherits(data, "data.frame")){
#       stop("Third argument must be a data frame.")
#     }
#     m <- model.matrix(object$form, data)
#     m %*% object$coefficients
#   }else{
#     newform <- paste0(as.character(form)[1], as.character(form)[3])
#     m <- model.matrix(formula(newform), newdata)
#     m %*% object$coefficients
#   }
# }

predict.my_grad_descent <- function(object, ...) {
  # The data frame is the first element of `...`, so both
  # predict(fit, df) and predict(fit, newdata = df) work.
  dots <- list(...)
  # Guard against predict(fit) with no data: dots[[1]] would otherwise
  # fail with an opaque "subscript out of bounds" error.
  data <- if (length(dots) > 0L) dots[[1L]] else NULL
  if (!inherits(data, "data.frame")) {
    stop("Second argument must be a data frame.", call. = FALSE)
  }
  # Remove the response from the model terms so `data` only needs the
  # predictors (and rows are not dropped for a missing response).
  # `data` is still passed to terms() so a `.` in the formula is expanded
  # against its columns, as model.matrix() would do.
  tt <- delete.response(terms(object$form, data = data))
  m <- model.matrix(tt, data)
  m %*% object$coefficients
}
tqchen07/bis557 documentation built on Dec. 21, 2020, 3:06 a.m.