# R/fit-family.R

#==================================================================================================#
# Family fitting function
#==================================================================================================#

#' @title Fit family
#'
#' @author Ryan Thompson <ryan.thompson1@unsw.edu.au>
#'
#' @description Fits a family of centers.
#'
#' @param x a numeric vector of data; must not contain \code{NA} values
#' @param w a numeric vector of nonnegative weights with the same length as \code{x}; must not
#' contain \code{NA} values and must contain at least one positive value
#' @param family the location family; currently only allows 'huber' for Huber family
#' @param spread.fun a function used for the spread of \code{x}; must accept data \code{x} and
#' weights \code{w} (in that order), and return a numeric
#' @param eps a numerical tolerance parameter
#'
#' @return An object of class \code{fit.family}; a data frame with the following columns:
#' \item{mu}{the fitted values}
#' \item{lambda}{the thresholding parameter}
#' The family name and the weighted mean and weighted median of \code{x} are attached as the
#' attributes \code{family}, \code{mean} and \code{median}.
#'
#' @example R/examples/example-fit-family.R
#'
#' @export

fit.family <- \(x, w = rep(1, length(x)), family = 'huber', spread.fun = weighted.mad,
                eps = .Machine$double.eps) {

  # Check arguments are valid. The length check runs first so that a mismatched w is reported as
  # such, and NA weights are rejected explicitly because the sign checks below would otherwise
  # evaluate to NA and fail with an uninformative error.
  family <- match.arg(family)
  if (length(x) != length(w)) stop('x and w must have the same length')
  if (anyNA(x)) stop('x must not contain NA values')
  if (anyNA(w)) stop('w must not contain NA values')
  if (any(w < 0)) stop('w must not contain negative values')
  if (all(w == 0)) stop('w must contain at least one positive value')

  # Calculate mean and median
  mean. <- stats::weighted.mean(x, w)
  median. <- weighted.median(x, w)

  # Compute family
  if (family == 'huber') {
    result <- huber.family(x, w, median., eps)
    # Express lambda in units of the spread of x; a zero or NaN spread cannot be divided by, so
    # the unit scale is used instead
    spread.x <- spread.fun(x, w)
    if (is.nan(spread.x) || spread.x == 0) spread.x <- 1 # Handles constant x
    result$lambda <- result$lambda / spread.x
  }

  # Return result, carrying the family name and the reference centers as attributes
  class(result) <- c('fit.family', 'data.frame')
  attr(result, 'family') <- family
  attr(result, 'mean') <- mean.
  attr(result, 'median') <- median.
  result

}

# Huber family
#
# Traces a sequence of (mu, lambda) pairs. The first pair is the weighted mean of x together with
# the largest absolute deviation from it; each later pair is obtained by stepping mu and lowering
# lambda until every observation is in the active set or n pairs have been produced.
#
# Arguments:
#   x    numeric vector of data
#   w    numeric vector of weights (normalized internally to sum to one)
#   med  reference value; each observation's sign is taken relative to it
#   eps  tolerance; an observation is active when its squared slack d ^ 2 is below eps
#
# Returns a data frame with columns mu and lambda, one row per pair computed.
huber.family <- \(x, w, med, eps) {

  # Compute important quantities upfront
  n <- length(x)
  w <- w / sum(w)
  s <- sign(x - med)

  # Initialize variables
  lambda <- mu <- numeric(n)
  mu[1] <- sum(x * w)
  lambda[1] <- max(abs(x - mu[1]))

  # Number of valid entries in mu and lambda; tracked explicitly rather than read off the loop
  # variable after the loop
  n.steps <- 1L

  # seq_len(n)[-1] is 2, ..., n and is empty when n < 2, unlike 2:n which would count down
  for (m in seq_len(n)[-1]) {

    # Update active set: d is the slack between lambda and the signed residual, capped at zero
    d <- lambda[m - 1] - pmin(s * (x - mu[m - 1]), lambda[m - 1])
    A <- d ^ 2 < eps

    # Stop if finished
    if (sum(A) == n) break

    # Compute gradient and step size
    # NOTE(review): if every inactive observation has zero weight, sum(w[!A]) is zero and eta is
    # not finite -- confirm whether such weights can reach this point
    eta <- sum(w[A] * s[A]) / sum(w[!A])
    gamma <- min(d[!A] / (1 + s[!A] * eta))

    # Update mu and lambda
    mu[m] <- mu[m - 1] - gamma * eta
    lambda[m] <- lambda[m - 1] - gamma
    n.steps <- m

  }

  # Return result, dropping the unused tail of the preallocated vectors
  keep <- seq_len(n.steps)
  data.frame(mu = mu[keep], lambda = lambda[keep])

}

#==================================================================================================#
# Plot function for fit.family object
#==================================================================================================#

# Register the column names that the plot method below passes by bare name to ggplot2::aes(), so
# that R CMD check does not report them as undefined global variables
globalVariables(c('mu', 'center'))

#' @title Plot function for \code{fit.family} object
#'
#' @author Ryan Thompson <ryan.thompson@monash.edu>
#'
#' @description Plot a fitted family.
#'
#' @param x an object of class \code{fit.family}
#' @param y an optional second object of class \code{fit.family}; if supplied, the difference
#' between the fitted families of \code{x} and \code{y} is plotted instead
#' @param ... any other arguments
#'
#' @return A plot of the fitted family.
#'
#' @method plot fit.family
#'
#' @export
#'
#' @importFrom graphics "plot"

plot.fit.family <- \(x, y = NULL, ...) {

  # Remember whether a difference is being plotted; x is replaced below in that case
  is.difference <- !is.null(y)

  # Interpolate x and y at same lambda sequence
  if (is.difference) {
    # if (attributes(x)$family != attributes(y)$family) {
    #   stop('x and y must belong to the same family')
    # }

    # Evaluates a fitted family at the lambda values in `at`. Values outside the fitted range take
    # mu at the smallest / largest fitted lambda. A family with a single point cannot be linearly
    # interpolated, so it is treated as constant.
    interpolate <- \(fit, at) {
      stats::approx(fit$lambda, fit$mu, at, yleft = fit$mu[which.min(fit$lambda)],
                    yright = fit$mu[which.max(fit$lambda)],
                    method = if (length(fit$mu) > 1) 'linear' else 'constant')$y
    }

    family <- attr(x, 'family', exact = TRUE)
    mean.diff <- attr(x, 'mean', exact = TRUE) - attr(y, 'mean', exact = TRUE)
    median.diff <- attr(x, 'median', exact = TRUE) - attr(y, 'median', exact = TRUE)
    lambda <- union(x$lambda, y$lambda)
    x <- data.frame(mu = interpolate(x, lambda) - interpolate(y, lambda), lambda = lambda)
    attr(x, 'family') <- family
    attr(x, 'mean') <- mean.diff
    attr(x, 'median') <- median.diff
  }

  # Plot family; the mean and median are drawn as horizontal reference lines spanning all lambda
  family.name <- tools::toTitleCase(attr(x, 'family', exact = TRUE))
  x$center <- family.name
  x.mean <- data.frame(mu = attr(x, 'mean', exact = TRUE), lambda = c(0, Inf), center = 'Mean')
  x.median <- data.frame(mu = attr(x, 'median', exact = TRUE), lambda = c(0, Inf),
                         center = 'Median')
  x <- rbind(x, x.mean, x.median)
  x$center <- factor(x$center, c(family.name, 'Mean', 'Median'))

  # Plain if/else keeps the label an unmodified expression object
  y.label <- if (is.difference) {
    expression(mu[X](lambda)-mu[Y](lambda))
  } else {
    expression(mu(lambda))
  }

  ggplot2::ggplot(x, ggplot2::aes(lambda, mu, linetype = center)) +
    ggplot2::geom_line() +
    ggplot2::xlab(expression(lambda)) +
    ggplot2::ylab(y.label) +
    ggplot2::labs(linetype = 'Center')

}

# Try the familial package in your browser
#
# Any scripts or data that you put into this service are public.
#
# familial documentation built on July 9, 2023, 7:27 p.m.