R/get_loss.R

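#' Calculate the per-observation loss
#'
#' @param y The response vector.
#' @param pred The predicted values for the response: either a vector of the
#'   same length as y, or a matrix with one row per observation.
#' @param type A string giving the type of regression model. "linear" uses
#'   squared-error loss; any other value is treated as binomial and uses the
#'   deviance contribution -2 * log(pred) for y == 1 and -2 * log(1 - pred)
#'   for y == 0, after clipping pred to [0.00001, 0.99999].
#'
#' @returns The element-wise losses: a vector, or a matrix with the same
#'   dimensions as pred when pred is a matrix.
#'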
#'
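get.loss <- function(y, pred, type) {

  if (type == "linear") {
    # Squared-error loss; if pred is a matrix, y is recycled down each column.
    loss <- (y - pred)^2
  } else {
    # Binomial: clip predicted probabilities away from 0 and 1 so that
    # log() stays finite.
    pred[pred < 0.00001] <- 0.00001
    pred[pred > 0.99999] <- 0.99999

    if (is.matrix(pred)) {
      # One column per set of predictions; each row is one observation.
      loss <- matrix(NA, nrow = nrow(pred), ncol = ncol(pred))
      loss[y == 1, ] <- -2 * log(pred[y == 1, , drop = FALSE])
      loss[y == 0, ] <- -2 * log(1 - pred[y == 0, , drop = FALSE])
    } else {
      # Per-observation deviance contribution: -2 * log-likelihood.
      loss <- double(length(y))
      loss[y == 1] <- -2 * log(pred[y == 1])
      loss[y == 0] <- -2 * log(1 - pred[y == 0])
    }
  }
  loss
}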
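For a 0/1 response, the binomial branch returns each observation's contribution to the deviance, that is, -2 times its log-likelihood. The sketch below illustrates this under stated assumptions: binomial_loss() is a hypothetical vectorized helper (not part of SGPR) that should agree with get.loss() only when y is coded 0/1 (get.loss() leaves other values as 0 or NA), and the comparison against glm() uses only base R on simulated data.

```r
# Hypothetical vectorized equivalent of the binomial branch of get.loss(),
# assuming y is coded 0/1 and p holds predicted probabilities
# (a vector, or a matrix with one row per observation).
binomial_loss <- function(y, p, eps = 1e-5) {
  # Clip probabilities away from 0 and 1 so log() stays finite,
  # mirroring the 0.00001 / 0.99999 bounds used in get.loss().
  # pmax()/pmin() keep the dim attribute of their first argument.
  p <- pmin(pmax(p, eps), 1 - eps)
  # Per-observation deviance contribution: -2 * log-likelihood.
  # When p is a matrix, y is recycled down each column.
  -2 * (y * log(p) + (1 - y) * log(1 - p))
}

set.seed(1)
n <- 200
x <- rnorm(n)
y <- rbinom(n, size = 1, prob = plogis(0.5 + x))

fit <- glm(y ~ x, family = binomial)
p_hat <- fitted(fit)

# For a 0/1 response the saturated log-likelihood is 0, so the summed
# per-observation losses equal the residual deviance reported by glm().
total_loss <- sum(binomial_loss(y, p_hat))
print(all.equal(total_loss, deviance(fit)))

# Matrix case: column 1 holds the fitted model's probabilities, column 2
# the intercept-only prediction mean(y); the column sums give the
# residual and null deviances.
p_mat <- cbind(p_hat, rep(mean(y), n))
print(colSums(binomial_loss(y, p_mat)))
```

Summing the columns of a matrix of per-observation losses is one way to compare several sets of predictions at once; the linear branch works the same way, with column sums giving residual sums of squares.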

SGPR documentation built on May 29, 2024, 5:27 a.m.