library(ggplot2)
library(RegularizedFits)
RegularizedFits is a package I created to fit trendlines to data without overfitting. Instead of minimizing the squared error as a normal linear model would, it minimizes a 'regularized' error that adds a penalty term: a weight $\lambda$ times the sum of the squares of the coefficients:
$$ \sum_{i=1}^n \left(y_i-\left(\sum_{j=0}^d a_jx_i^j\right)\right)^2+\lambda\sum_{j=0}^da_j^2 $$
Note the first summation is just your normal squared error, and $\lambda$ is a weight that can be adjusted depending on how worried you are about overfitting.
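To make the formula concrete, here is a minimal base-R sketch of one way to minimize it. Setting the gradient to zero gives the linear system $(X^TX + \lambda I)a = X^Ty$, where column $j$ of $X$ holds $x_i^j$. The helper name `ridge_poly` is made up for this illustration and is not part of RegularizedFits; I'm not claiming this is how reg.lm() is implemented internally.

```r
# Hypothetical helper (not part of RegularizedFits): closed-form minimizer
# of the regularized error for a polynomial of the given degree.
ridge_poly <- function(y, x, degree, lambda = 1) {
  # Design matrix with columns x^0, x^1, ..., x^degree
  X <- outer(x, 0:degree, `^`)
  # Solve (X'X + lambda * I) a = X'y. The intercept a_0 is penalized too,
  # matching the j = 0 term in the penalty sum.
  a <- solve(crossprod(X) + lambda * diag(degree + 1), crossprod(X, y))
  setNames(as.numeric(a), paste0("a", 0:degree))
}

x <- 0:5
y <- (0:5)^2 + c(-0.40, 3.50, 1.80, 10.00, -9.75, 13.80)

ridge_poly(y, x, degree = 3, lambda = 1)
```

With `lambda = 0` the penalty vanishes and the sketch reduces to ordinary least squares, which is a handy sanity check against lm().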
This package consists of two functions. The first is reg.lm() (analogous to lm()), which fits a polynomial model by minimizing the above expression. The second function is crossval(), which will find the 'best' value of $\lambda$ from a collection that you provide it.
reg.lm()
I'll first show how reg.lm() works. We'll work with the following (noisy) data:
x <- 0:5
y <- (0:5)^2 + c(-0.40, 3.50, 1.80, 10.00, -9.75, 13.80)
data <- data.frame(y, x)
If we use lm() to fit a cubic model to this data, we'll see that it reacts too strongly to the noise and overfits:
fit <- lm(y ~ cbind(x, x^2, x^3))
fit <- as.numeric(fit$coefficients)
fit
p1 <- ggplot(data, aes(x = x, y = y)) +
  geom_point() +
  stat_function(fun = function(x) fit[1] + fit[2]*x + fit[3]*x^2 + fit[4]*x^3, aes(color = 'unregularized')) +
  stat_function(fun = function(x) x^2, color = 'grey') +
  # theme_bw() goes first; applied after theme(), it would reset legend.position
  theme_bw() +
  theme(legend.position = "bottom")
p1
Now let's apply reg.lm() to fit a cubic to the data (we will use the default $\lambda=1$):
regcube <- reg.lm(y, x, degree = 3)
regcube
Comparing the coefficients, we can already see that this is much better at picking out the quadratic pattern in the data.
regfit <- regcube$coefficients
p2 <- p1 +
  stat_function(fun = function(x) regfit[1] + regfit[2]*x + regfit[3]*x^2 + regfit[4]*x^3, aes(color = 'regularized'))
p2
crossval()
Now we will look at how to select a value of $\lambda$ using crossval(). crossval() works by using a so-called 'cross-validation' procedure. There are plenty of resources out there that summarize what this means (e.g. Wikipedia), but in essence, it divides the data into a number of partitions, then applies reg.lm() to each partition and evaluates the quality of a fit by how well it predicts data for the rest of the partitions.
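For readers who want to see the idea in code, here is a rough base-R sketch of choosing $\lambda$ by cross-validation. It is not the source of crossval(): all helper names (`ridge_poly`, `predict_poly`, `cv_error`) are hypothetical, the fit uses the closed-form ridge solution rather than reg.lm(), and it follows the common convention of fitting on all folds but one and scoring on the held-out fold, which may differ from what the package does. Using one point per fold (leave-one-out) is also just an assumption that suits such a small data set.

```r
# Hypothetical helpers for illustration; none of these names come from RegularizedFits.

# Closed-form minimizer of squared error + lambda * sum of squared coefficients
ridge_poly <- function(y, x, degree, lambda) {
  X <- outer(x, 0:degree, `^`)
  as.numeric(solve(crossprod(X) + lambda * diag(degree + 1), crossprod(X, y)))
}

# Evaluate a polynomial with coefficients a_0, a_1, ... at the points x
predict_poly <- function(coefs, x) {
  as.numeric(outer(x, seq_along(coefs) - 1, `^`) %*% coefs)
}

# Mean squared prediction error over held-out folds for a single lambda
cv_error <- function(y, x, degree, lambda, folds) {
  errs <- vapply(sort(unique(folds)), function(k) {
    held_out <- folds == k
    coefs <- ridge_poly(y[!held_out], x[!held_out], degree, lambda)
    sum((y[held_out] - predict_poly(coefs, x[held_out]))^2)
  }, numeric(1))
  sum(errs) / length(y)
}

x <- 0:5
y <- (0:5)^2 + c(-0.40, 3.50, 1.80, 10.00, -9.75, 13.80)

folds <- seq_along(x)        # leave-one-out: every point is its own fold
lambdas <- c(0.1, 1, 10)     # candidate values to compare
errors <- vapply(lambdas, function(l) cv_error(y, x, 3, l, folds), numeric(1))
best_lambda <- lambdas[which.min(errors)]
best_lambda
```

The candidate with the smallest held-out error wins; a final fit would then be computed on the full data with that value.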
If we apply crossval() to the above data:
cv.fit <- crossval(y, x, degree = 3, lambda = c(0.1, 1, 10))
cv.fit
cv.fit <- cv.fit$coefficients
p2 +
  stat_function(fun = function(x) cv.fit[1] + cv.fit[2]*x + cv.fit[3]*x^2 + cv.fit[4]*x^3, aes(color = 'cross validated'))
We can see that the cross-validated fit is decently close to the actual data within the range we're considering and, by eye, is also the closest of the three fits to the underlying function outside that range. Extrapolating beyond the observed data is perhaps the biggest application of these functions.