# R/my-glm.R
#
# Defines functions: my_glm
#
# Documented in: my_glm

#' @title Solve Generalized Linear Model with Newton-Ralphson Method
#' @description This is a simple algorithm to solve a generalized linear regression model.
#' @param X design matrix
#' @param Y a vector of outcome
#' @param family an R 'family' object.
#' @param step_option a character string of "constant" or "momentum", if "momentum", then gamma should be specified
#' @param lambda a numeric number indicating the learning rate for gradient descent algorithm
#' @param gamma a fraction for momentum step size
#' @param tol a numeric number indicating the precision of the algorithm
#' @param maxit an integer indicating the maximum number iterations
#'
#' @examples
#' n <- 1000; p <- 3
#' beta <- c(0.2, 2, 1)
#' X <- cbind(1, matrix(rnorm(n * (p-1)), ncol = (p-1) ))
#' mu <- 1 - pcauchy(X %*% beta)
#' Y <- as.numeric(runif(n) > mu)
#'
#' df <- as.data.frame(cbind(Y, X))
#'
#' fit0 <- glm(Y ~ . -1, data = df, family = binomial(link = "cauchit"))
#' fit1 <- my_glm(X, Y, family = binomial(link = "cauchit"),
#'                step_option = "constant", lambda = 0.001)
#' fit2 <- my_glm(X, Y, family = binomial(link = "cauchit"),
#'                step_option = "momentum", lambda = 0.001, gamma = 0.2)
#'
#' cbind(fit0$coefficients, fit1$coefficients, fit2$coefficients)
#' @export

my_glm <- function(X,
                   Y,
                   family,
                   step_option = c("constant", "momentum"),
                   lambda = 0.001,
                   gamma = NULL,
                   maxit = 2e3,
                   tol = 1e-12){

  beta <- rep(0, ncol(X))

  for(i in seq_len(maxit)){
    beta_old <- beta
    eta <- X %*% beta
    mu <- family$linkinv(eta)
    score <- t(X) %*% (Y - mu)

    if(step_option == "constant"){
      beta <- beta + lambda * score
    }else if(step_option == "momentum"){
      beta <- beta + sum(gamma^(1:i-1)) * lambda * score
    }

    if(sqrt(crossprod(beta - beta_old)) < tol) break
  }

  rslt <- list(coefficients = beta, iter = i)
  class(rslt) <- "my_glm"

  return(rslt)
}
# tqchen07/bis557 documentation built on Dec. 21, 2020, 3:06 a.m.