# Document-wide chunk defaults: suppress warnings and messages in the rendered
# output and draw centred 7.5 x 4.5 figures.
knitr::opts_chunk$set(fig.align = "center",
                      fig.width = 7.5,
                      fig.height = 4.5,
                      message = FALSE,
                      warning = FALSE)

# Attach the packages used below. library() stops with an error when a package
# is not installed, whereas require() only returns FALSE and lets the script
# fail later with a less helpful message. The caller's workspace is left
# untouched (no rm(list = ls())).
library(e1071)
library(semisupervisr)
library(dplyr)
library(ggplot2)

# Simulated data from semisupervisr::gen_pivot_data(): 2 classes, 2 domains,
# 2 pivot features and no non-pivot features. The seed makes it reproducible.
set.seed(123)

# Passed to top_n() further down when selecting new training points.
rho <- 3

pivot_data <- gen_pivot_data(n = 200,
                             n_classes = 2,
                             n_domains = 2,
                             n_pivots = 2,
                             n_nonpivots = 0,
                             sd_class_means = 3,
                             sd_np_means = 10,
                             sd_obs = 1)
# Remember each observation's position so points can be labelled in plots.
pivot_data <- mutate(pivot_data, row_idx = row_number())

# Linear soft-margin SVM (cost = 1) on the two pivot features. Features are not
# rescaled, so the hyperplane recovered below is in the original feature units.
svm.model <- svm(Class ~ P_Feature_1 + P_Feature_2,
                 data = pivot_data,
                 kernel = "linear",
                 type = "C-classification",
                 scale = FALSE,
                 cost = 1)

# Hyperplane w . x + b = 0 from the fitted model: w is the coefficient-weighted
# sum of the support vectors (a 1 x 2 matrix for this two-feature fit) and the
# intercept b is the negated rho.
b <- -svm.model$rho
w <- t(svm.model$coefs) %*% svm.model$SV


# Signed distance of every observation to the hyperplane, (w . x + b) / ||w||,
# plus the model's predicted class for each observation.
feature_cols <- c("P_Feature_1", "P_Feature_2")
hyperplane_distance <- w %*% t(pivot_data[, feature_cols])
hyperplane_distance <- (hyperplane_distance + b) / norm(w, "2")
pivot_data[["Distance"]] <- as.vector(hyperplane_distance)
pivot_data[["Prediction"]] <- predict(svm.model, pivot_data)


# Distance from the hyperplane w . x + b = 0 to either margin line
# w . x + b = +/-1, which is 1 / ||w||. Computed directly rather than by
# evaluating a point on the margin line, because that approach divides by
# w[1, 2] and yields NaN when the hyperplane is vertical (w[1, 2] == 0).
margin_distance <- 1 / norm(w, "2")

# Flag observations inside (or on the edge of) the margin band, and record the
# side of the hyperplane each one lies on. Points exactly on the hyperplane
# (Distance == 0) and missing distances get NA for PointHyperplanePos.
# Columns are referenced directly (not via pivot_data$...) so mutate() works on
# the piped data, including when it is grouped.
pivot_data <- pivot_data %>%
  mutate(InMargin = abs(Distance) <= margin_distance,
         PointHyperplanePos = case_when(Distance > 0 ~ "Above",
                                        Distance < 0 ~ "Below"))


# Diagnostic output: the observations nearest the hyperplane (as many as there
# are support vectors), shown alongside the support-vector row indices that
# e1071 reports, for comparison.
head(arrange(pivot_data, abs(Distance)), n = length(svm.model$index))
svm.model$index

# Plot observations against the fitted linear SVM.
#
# Draws plot_data as points (shape = InMargin, colour = Class). It adds the
# hyperplane w . x + b = 0 in blue and the margin lines w . x + b = -1 and
# w . x + b = +1 in orange, and circles the support vectors in blue.
#
# plot_data:   data frame with P_Feature_1, P_Feature_2, InMargin, Class and
#              row_idx columns.
# sv_data:     data frame with the support vectors' P_Feature_1/P_Feature_2.
# w, b:        hyperplane weights (1 x 2 matrix) and intercept.
# show_labels: if TRUE, label every point with its row_idx.
# Returns a ggplot object.
plot_svm_hyperplane <- function(plot_data, sv_data, w, b, show_labels = FALSE) {
  # Lines are drawn in slope/intercept form by solving for P_Feature_2, which
  # requires w[1, 2] != 0 (a vertical hyperplane cannot be drawn this way).
  slope <- -w[1, 1] / w[1, 2]

  p <- ggplot(plot_data,
              aes(x = P_Feature_1, y = P_Feature_2, shape = InMargin,
                  colour = Class, label = row_idx)) +
    geom_point()
  if (show_labels) {
    p <- p + geom_text(vjust = 0, nudge_y = 0.1)
  }

  p +
    geom_abline(intercept = -b / w[1, 2], slope = slope, colour = "blue") +
    geom_abline(intercept = (-b - 1) / w[1, 2], slope = slope,
                colour = "orange") +
    geom_abline(intercept = (-b + 1) / w[1, 2], slope = slope,
                colour = "orange") +
    geom_point(data = sv_data,
               inherit.aes = FALSE,
               aes(x = P_Feature_1, y = P_Feature_2),
               shape = 21, colour = "blue", size = 3) +
    labs(title = "SVM classifications",
         subtitle = "Blue circles indicate support vectors.",
         caption = "Data generated by gen_pivot_data function from semisupervisr package.")
}

# Support-vector coordinates, circled in both plots. svm.model$index gives the
# support vectors' row positions in the data the model was fitted on.
support_vectors <- pivot_data[svm.model$index, c("P_Feature_1", "P_Feature_2")]

plot_svm_hyperplane(pivot_data, support_vectors, w, b)

# For each value of PointHyperplanePos (i.e. each side of the hyperplane), keep
# the rho in-margin points furthest from the hyperplane. top_n() keeps ties, so
# a side can contribute more than rho points. The result is ungrouped so later
# code does not silently operate per group.
new_training_points <- pivot_data %>%
  filter(InMargin) %>%
  group_by(PointHyperplanePos) %>%
  top_n(rho, wt = abs(Distance)) %>%
  ungroup()

plot_svm_hyperplane(new_training_points, support_vectors, w, b,
                    show_labels = TRUE)


# camroach87/semisupervisr documentation built on May 13, 2019, 11:04 a.m.