library(ggplot2)
library(RegularizedFits)

RegularizedFits is a package I created to fit trendlines to data without overfitting. It does this not by minimizing the squared error as a normal linear model would, but by minimizing a 'regularized' error that adds a penalty term equal to the sum of the squares of the coefficients:

$$ \sum_{i=1}^n \left(y_i-\left(\sum_{j=0}^d a_jx_i^j\right)\right)^2+\lambda\sum_{j=0}^da_j^2 $$

Note the first summation is just your normal squared error, and $\lambda$ is a weight that can be adjusted depending on how worried you are about overfitting.
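For concreteness, here is a minimal sketch of how this objective can be minimized for a fixed $\lambda$ using the closed-form ridge solution. This is only an illustration of the math (the helper name ridge_coef is made up), and reg.lm() may well compute its fit differently under the hood.

# Illustrative only: closed-form minimizer of the penalized objective above.
# The design matrix X has columns x^0, x^1, ..., x^d.
ridge_coef <- function(y, x, degree, lambda = 1) {
    X <- outer(x, 0:degree, `^`)
    solve(t(X) %*% X + lambda * diag(degree + 1), t(X) %*% y)
}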

This package consists of two functions. The first is reg.lm() (analogous to lm()), which fits a polynomial model by minimizing the above expression. The second is crossval(), which finds the 'best' value of $\lambda$ from a collection of candidates that you provide.

reg.lm()

I'll first show how reg.lm() works. We'll work with the following (noisy) data:

x <- 0:5
y <- (0:5)^2 + c(-0.40, 3.50, 1.80, 10.00, -9.75, 13.80)
data <- data.frame(y,x) 

If we use lm() to fit a cubic model to this data, we'll see that it reacts too strongly to the noise and overfits:

fit <- lm(y ~ cbind(x, x^2, x^3))

fit <- as.numeric(fit$coefficients)

fit

p1 <- ggplot(data, aes(x = x, y = y)) + geom_point() +
    stat_function(fun = function(x) fit[1] + fit[2]*x + fit[3]*x^2 + fit[4]*x^3,
                  aes(color = 'unregularized')) +
    stat_function(fun = function(x) x^2, color = 'grey') +
    theme_bw() + theme(legend.position = "bottom")

p1

Now let's apply reg.lm() to fit a cubic to the data (we will use the default $\lambda=1$):

regcube <- reg.lm(y, x, degree = 3)
regcube

Comparing the coefficients, we can already see that this is much better at picking out the quadratic pattern in the data.

regfit <- regcube$coefficients

p2 <- p1 + 
    stat_function(fun = function(x) regfit[1] + regfit[2]*x + regfit[3]*x^2 + regfit[4]*x^3, aes(color = 'regularized'))

p2

crossval()

Now we will look at how to select a value of $\lambda$ using crossval(). crossval() works by using a so-called 'cross-validation' procedure. There are plenty of resources out there that explain what this means (e.g. Wikipedia), but in essence, it divides the data into a number of partitions, then applies reg.lm() to each partition and evaluates the quality of a fit by how well it predicts the data in the remaining partitions. A rough sketch of the idea is given below.
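To make that concrete, here is a rough sketch of a standard k-fold scheme for choosing $\lambda$, reusing the illustrative ridge_coef helper from earlier. Both cv_lambda and its splitting/scoring details are hypothetical; crossval() may partition and score the data differently.

# Illustrative only: pick the lambda with the lowest average held-out squared error.
cv_lambda <- function(y, x, degree, lambdas, k = 3) {
    folds <- sample(rep(1:k, length.out = length(y)))  # assign each point to a fold
    avg_err <- sapply(lambdas, function(l) {
        mean(sapply(1:k, function(fold) {
            train <- folds != fold
            coefs <- ridge_coef(y[train], x[train], degree, lambda = l)
            preds <- outer(x[!train], 0:degree, `^`) %*% coefs
            mean((y[!train] - preds)^2)                 # error on the held-out fold
        }))
    })
    lambdas[which.min(avg_err)]
}

cv_lambda(y, x, degree = 3, lambdas = c(0.1, 1, 10))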

If we apply it to the above data:

cv.fit <- crossval(y, x, degree = 3, lambda = c(0.1, 1, 10))

cv.fit

cv.fit <- cv.fit$coefficients
p2 + 
    stat_function(fun = function(x) cv.fit[1] + cv.fit[2]*x + cv.fit[3]*x^2 + cv.fit[4]*x^3, aes(color = 'cross validated'))

We can see that the cross-validated fit is reasonably close to the observed data within the range we're considering, but it is also (judging by eye) the closest to the underlying function outside that range, which is perhaps the biggest application of these functions.


