Intro

La ultima clase derivamos la forma mas simple de un problema de maquinas de soporte vectorial. Por ejemplo, supongamos que tenemos los siguientes datos

X <- data.frame(
  x1 = c(-1, 0, -0.25, 1.5, 0.5),
  x2 = c(-0.5, 0, 1, 2, 1)
)
y = c(-1, -1, -1, 1, 1)
library(ggplot2)
plot_data <- data.frame(X, y = factor(y))
p <- ggplot(plot_data, aes(x = x1, y = x2, colour = y)) +
  geom_point(size = 6, alpha = 0.3)
p

El problema de clasificacion de svm para separar los puntos y maximizar el margen entre ellos es $$ \begin{aligned} \min_{w,b} \; &\; \frac{1}{2}\lVert w \rVert^2 \ s.a. \; & \; y_i(w^\top x_i + b) \geq 1 \quad \forall i \end{aligned} $$ donde $y_i$ es el signo de clasificacion del individuos $i$ y $x_i$ es su conjunto de datos, i.e., la $i$-esima fila de $X$. Encontrar estos vectores $w$ y $b$ equivale a maximizar el margen de separacion entre puntos.

Solucion problemas sencillos

Para problemas sencillos podemos solucionar usando la libreria de optimizacion nloptr y optimizando directemante

library(nloptr)
eval_f <- function(x) {
  w <- x[-length(x)]
  0.5*sum(w^2)
}
eval_grad_f <- function(x) c(x[-length(x)], 0) # the gradient w,0
eval_g_ineq <-  function(x) {
  w <- x[-length(x)]
  b <- x[length(x)]
  y*(as.matrix(X)%*%w + b) - 1# default g(x) >= 0
}
eval_jac_g_ineq <- function(x) {
  do.call("rbind", lapply(1:nrow(X), function(i) y[i]*c(as.numeric(X[i, ]), 1)))
}
x0 = rep(0, ncol(X) + 1) # initial guess
res = slsqp(
  x0 = x0,
  fn = eval_f,
  gr = eval_grad_f,
  hin = eval_g_ineq,
  hinjac = eval_jac_g_ineq
)
w = res$par[1:ncol(X)]
b = res$par[ncol(X) + 1]

Active constrains

active = eval_g_ineq(res$par) < 10e-6
active
ablines = data.frame(
  slope = rep(-w[1]/w[2], 3),
  intercept = c(-b, 1-b, -1-b) / w[2],
  linetype = c("yi(w'xi + b)=0", "yi(w'xi + b)=1", "yi(w'xi + b)=1")
)
plot_data <- data.frame(X, y = factor(y), active = active)
p <- ggplot(plot_data, aes(x = x1, y = x2, colour = y, shape = active)) +
  geom_point(size = 6, alpha = 0.3) +
  geom_abline(data = ablines, aes(slope = slope, intercept = intercept, linetype = linetype))
p


mauriciogtec/metodosMultivariados2017 documentation built on May 21, 2019, 1:37 p.m.