# Global knitr chunk options, applied to every chunk in this document.
knitr::opts_chunk$set(
  # Figure geometry and placement.
  fig.align = "center",
  fig.width = 7,
  fig.height = 4.5,
  # Keep package messages and warnings out of the rendered output.
  message = FALSE,
  warning = FALSE
)

# Attach the packages used throughout this document.
#
# `library()` is used instead of `require()`: `require()` only returns FALSE
# (with a warning) when a package is missing, so the failure would surface
# later as a confusing "could not find function" error. `library()` stops
# immediately with a clear message.
#
# The workspace is intentionally NOT cleared with rm(list = ls()). That call
# does not reset attached packages or options, and it can delete the caller's
# objects when the document is evaluated in an existing session.
library(semisupervisr)
library(ggplot2)
library(dplyr)
library(tidyr)
library(e1071)

# Use the black-and-white theme for every ggplot2 figure below.
theme_set(theme_bw())

Introduction

This document gives an example of semi-supervised learning using pivot features. An iterative learning process is carried out with the iter_pivot function and compared against standard support vector machine (SVM) classifiers.

# Fix the RNG seed so that the simulated data, and therefore every figure and
# table derived from it below, is identical on each build. The original note
# on this seed was that it "seems to produce a good data set".
set.seed(3234)

# Simulate pivot feature data with one non-pivot feature, one pivot feature,
# two domains and two classes.
# NOTE(review): the exact meaning of `n` and of the `sd_*` spread parameters
# is defined by gen_pivot_data(); confirm against the package documentation.
pivot_data <- gen_pivot_data(
  n_nonpivots = 1,
  n_pivots = 1,
  n_domains = 2,
  n_classes = 2,
  n = 200,
  sd_class_means = 2,
  sd_np_means = 10,
  sd_obs = 1
)

# Quick overview plot of the simulated data.
plot(pivot_data)

# Scatter plot of the simulated observations: pivot feature on the x axis,
# non-pivot feature on the y axis, coloured by class and shaped by domain.
ggplot(
  data = pivot_data,
  mapping = aes(
    x = P_Feature_1,
    y = NP_Feature_1,
    colour = Class,
    shape = Domain
  )
) +
  geom_point() +
  labs(
    title = "Simulated data.",
    caption = "Data simulated using semisupervisr package's gen_pivot_data function."
  )

# Split the simulated data by domain: Domain_1 rows form the training set and
# Domain_2 rows form the test set.
train <- pivot_data[pivot_data[["Domain"]] == "Domain_1", , drop = FALSE]
test <- pivot_data[pivot_data[["Domain"]] == "Domain_2", , drop = FALSE]

Two SVM classifiers are trained. One uses only the pivot features and the other uses all features.

# Fit two SVM classifiers on the training data: one using both the pivot and
# the non-pivot feature, and one using the pivot feature only.
fit_svm_all <- svm(Class ~ P_Feature_1 + NP_Feature_1, data = train)
fit_svm_pivot <- svm(Class ~ P_Feature_1, data = train)

# Store each classifier's predictions for the test set as new columns, so
# they can be plotted and tabulated alongside the true class.
test[["Prediction_SVM_All"]] <- predict(fit_svm_all, newdata = test)
test[["Prediction_SVM_Pivot"]] <- predict(fit_svm_pivot, newdata = test)

# Test-set predictions of the SVM fitted on all features. Colour shows the
# predicted class and point shape shows the true class.
# The title previously read "Random forest predictions", but the model being
# plotted is an SVM.
ggplot(test, aes(x = P_Feature_1, y = NP_Feature_1,
                 colour = Prediction_SVM_All, shape = Class)) +
  geom_point() +
  facet_wrap(~Domain) +
  labs(title = "SVM predictions",
       subtitle = "Fit using all features")

# Same plot for the SVM fitted on the pivot feature only.
ggplot(test, aes(x = P_Feature_1, y = NP_Feature_1,
                 colour = Prediction_SVM_Pivot, shape = Class)) +
  geom_point() +
  facet_wrap(~Domain) +
  labs(title = "SVM predictions",
       subtitle = "Fit using pivot features only")

Now we can predict results using the iter_pivot function, which implements an iterative algorithm. We set the probability threshold to 0.75. For some reason, this seems to perform better than a higher cutoff when dealing with more than two output classes. TODO: investigate why, and run a repeated simulation with different p_threshold values to check.

# Fit the iterative pivot learner, using the training set as the source data
# and the test set as the target data. P_Feature_1 is passed as the pivot
# feature and NP_Feature_1 as the non-pivot feature.
# NOTE(review): the role of `rho` is defined by iter_pivot(); confirm its
# meaning against the package documentation.
iter_model <- iter_pivot(
  source = train,
  target = test,
  x_pivot = "P_Feature_1",
  x_nonpivot = "NP_Feature_1",
  rho = 3
)

# Plot the fitted iterative model with type = "Learning" over the two
# features, then overlay the training observations coloured by their class.
# NOTE(review): what the "Learning" plot type displays is defined by the plot
# method for the object returned by iter_pivot(); confirm before relying on it.
plot(iter_model, aes(x = P_Feature_1, y = NP_Feature_1), type = "Learning") +
  geom_point(data = train, aes(x = P_Feature_1, y = NP_Feature_1, colour = Class))

# Predict the test set with the iterative model, passing only the two feature
# columns, and store the result next to the SVM predictions.
feature_cols <- c("P_Feature_1", "NP_Feature_1")
test$Prediction_Iter <- predict(iter_model, test[feature_cols])

# Colour shows the iterative learner's predicted class and point shape shows
# the true class.
ggplot(
  test,
  aes(
    x = P_Feature_1,
    y = NP_Feature_1,
    colour = Prediction_Iter,
    shape = Class
  )
) +
  geom_point() +
  facet_wrap(~Domain) +
  labs(title = "Iterative learner predictions")

Results

In general the iterative learning process outperforms the SVM classifiers. We do see a drop in performance on the source domain but the target domain has improved classifications. We do not care about the source domain drop in performance as it is already labelled (so we wouldn't be doing any predictions on it). Effectively, this algorithm seems to be grabbing the useful bits from the source domains for the purpose of classifying observations in the target.

# Confusion-style summary for every model on the test set.
# Step 1: stack all Prediction_* columns into long form (one row per
# observation and model) and count each Model / Class / Prediction combination.
prediction_counts <- test %>%
  select(starts_with("Prediction"), Class) %>%
  gather(key = "Model", value = "Prediction", -Class) %>%
  count(Model, Class, Prediction)

# Step 2: spread the predicted classes into columns, giving one row per model
# and true class; combinations that never occur are filled with 0.
prediction_counts %>%
  spread(key = Prediction, value = n, fill = 0) %>%
  ungroup()

TODO

Run replicated simulations with different parameter values to better determine how well this is performing.



camroach87/semisupervisr documentation built on May 13, 2019, 11:04 a.m.