knitr::opts_chunk$set(echo = TRUE)
require(tidyverse)
require(ggplot2)
require(tie)
require(janitor)

# Pick the smallest dimension whose error lies within a relative tolerance
# of the best (lowest) error.
#
# Args:
#   rs:    numeric vector of candidate dimensions.
#   lhats: numeric vector of error values, aligned element-wise with `rs`.
#   t:     relative slack; an entry qualifies when
#          lhat <= (1 + t) * min(lhats). The default 0 keeps only the minima.
#
# Returns:
#   list(r.star, Er.Rt.Star): the smallest qualifying dimension and the error
#   observed at that dimension. Both are NA when no pair is complete or when
#   nothing qualifies (with t > 0 that happens only if min(lhats) is negative,
#   because the threshold then falls below the minimum itself).
compute_cutoff <- function(rs, lhats, t=0) {
  stopifnot(length(rs) == length(lhats))
  # build the mask once, from the untouched inputs, so both vectors are
  # filtered with the same rows and stay aligned
  keep <- !is.na(rs) & !is.na(lhats)
  rs <- rs[keep]
  lhats <- lhats[keep]

  no.cutoff <- list(r.star=NA_real_, Er.Rt.Star=NA_real_)
  if (length(rs) == 0) {
    return(no.cutoff)
  }

  # minimum error, inflated by the relative tolerance
  lhat.thresh <- (1 + t)*min(lhats)
  # candidates are all entries at or below the threshold
  below <- lhats <= lhat.thresh
  if (!any(below)) {
    return(no.cutoff)
  }
  rs.below <- rs[below]
  lhats.below <- lhats[below]

  # position of the smallest dimension among the candidates; which.min()
  # returns an index, so the input does not need to be sorted by dimension
  tmin.ix <- which.min(rs.below)
  list(r.star=rs.below[tmin.ix], Er.Rt.Star=lhats.below[tmin.ix])
}


# Pull the legend grob out of a ggplot object so it can be placed separately.
#
# Args:
#   a.gplot: a ggplot object that draws a legend.
#
# Returns:
#   the grob of the built plot whose name is "guide-box".
g_legend <- function(a.gplot) {
  built.table <- ggplot_gtable(ggplot_build(a.gplot))
  grob.names <- sapply(built.table$grobs, function(grob) grob$name)
  legend.ix <- which(grob.names == "guide-box")
  built.table$grobs[[legend.ix]]
}

# Data Loading

# Read the stored results, drop the columns that are not used below, give the
# remaining ones their display names and tag every row with Classifier "LDA".
lol.dat <- readRDS("../data/real_data/lda_results.rds") %>%
  dplyr::select(-xv, -ntrain, -repo, -K, -d) %>%
  dplyr::rename(Dataset=exp, Algorithm=alg, Fold=fold, Er.Rt=lhat) %>%
  dplyr::mutate(Classifier="LDA")

# The "RandomGuess" rows are kept in their own table, relabelled "RC" and
# stored under `Classifier`; the dimension column `r` is dropped for them.
lol.rc <- lol.dat %>%
  dplyr::filter(Algorithm == "RandomGuess") %>%
  dplyr::select(-c(Classifier, r)) %>%
  dplyr::mutate(Algorithm=recode_factor(Algorithm, "RandomGuess"="RC")) %>%
  dplyr::rename(Classifier=Algorithm)

# Everything that is not random guessing stays in lol.dat.
lol.dat <- dplyr::filter(lol.dat, Algorithm != "RandomGuess")
# For every (Algorithm, Classifier, Fold, Dataset, n) group, take r.star and
# Er.Rt.star from compute_cutoff() with a 5% tolerance: the minimal dimension
# within 5% of the minimum misclassification rate, and the error there.
lol.dat.star <- lol.dat %>%
  dplyr::group_by(Algorithm, Classifier, Fold, Dataset, n) %>%
  bow(tie(r.star, Er.Rt.star) := compute_cutoff(r, Er.Rt, t=.05))

# Put LOL's own optimum and the random-chance error next to every algorithm's
# optimum, normalize the differences, and average them over folds.
lol.dat.prep <- local({
  # LOL rows only, with the optimum columns renamed for the side-by-side join
  lol.star <- lol.dat.star %>%
    dplyr::filter(Algorithm == "LOL") %>%
    ungroup() %>%
    dplyr::rename(LOL.r.star=r.star, LOL.Er.Rt.star=Er.Rt.star) %>%
    dplyr::select(-c(Algorithm,Classifier))

  # random-chance error per (Fold, Dataset)
  chance <- lol.rc %>%
    dplyr::rename(RC.Er.Rt=Er.Rt) %>%
    dplyr::select(-c(Classifier, n))

  lol.dat.star %>%
    dplyr::inner_join(lol.star, by=c("Fold", "Dataset", "n")) %>%
    dplyr::inner_join(chance, by=c("Fold", "Dataset")) %>%
    # dimension gap is scaled by min(n, 100), error gap by the chance error
    mutate(r.star.norm=(LOL.r.star-r.star)/pmin(n, 100), 
           Lhat.star.norm=(LOL.Er.Rt.star-Er.Rt.star)/RC.Er.Rt) %>%
    # one row per (Algorithm, Dataset, n): mean of the normalized gaps
    dplyr::group_by(Algorithm, Dataset, n) %>%
    dplyr::summarize(r.star=mean(r.star.norm), Er.Rt.star=mean(Lhat.star.norm))
})
# Plot aesthetics, each keyed by algorithm name.
algs <- c("LOL", "PLS", "CCA", "LDA", "PCA", "RP")
acols <- setNames(
  c("#e41a1c", "#377eb8", "#4daf4a", "#984ea3", "#ff7f00", "#a65628"),
  algs
)
linestyle <- setNames(rep("solid", 6), algs)
#shapes <- c(21, 24, 21, 24, 23, 23, 21, 24, 23)
shapes <- setNames(c(21, 24, 21, 22, 21, 23), algs)
# algs itself is labelled with the colour assigned to each algorithm
names(algs) <- unname(acols)

# Relabel algorithm "LRLDA" as "LDA", the name used by the plotting palettes.
lol.dat.prep <- dplyr::mutate(
  lol.dat.prep,
  Algorithm=recode_factor(Algorithm, "LRLDA"="LDA")
)

# Draw the benchmark comparison figure and print the quadrant summary tables.
#
# The figure is a 2x2 grid: a scatter of (r.star, Er.Rt.star) per row of
# `data` with per-algorithm medians overlaid, a scaled density of r.star above
# it, a scaled density of Er.Rt.star to its right, and the shared legend in
# the remaining corner.
#
# Args:
#   data:         data frame with columns Algorithm, Dataset, r.star and
#                 Er.Rt.star.
#   xlims, ylims: axis limits of the scatter; also used to size the shaded
#                 quadrants, which are mirrored through the origin.
#   plot.title:   title shown on the top density panel.
#   xl, yl:       axis labels of the scatter.
#   leg.title:    legend title.
#   legend.style: guide passed to the colour/fill/shape scales.
#
# Side effects: prints the quadrant count table and its per-row proportions.
# Uses the file-level lookups `acols` and `shapes`, and `g_legend()`.
#
# Returns: the value of gridExtra::grid.arrange() for the four panels.
realdat.panels <- function(data, xlims=c(-1, 1), ylims=c(-.35, .35), plot.title="(A) Performance on UCI/PMLB Benchmarks", 
                             xl="Relative # Embedding Dimensions", yl="Relative Error", leg.title="Algorithm",
                             legend.style=guide_legend(ncol=2, byrow=TRUE)) {
  data <- data %>% mutate(Dataset=factor(Dataset), Algorithm=factor(Algorithm))
  # lower-left quadrant of the plotting window; negating x and y gives the
  # upper-right one
  box <- data.frame(x=c(min(xlims), mean(xlims), mean(xlims), min(xlims)),
                    y=c(min(ylims), min(ylims), mean(ylims), mean(ylims)))
  data.medians <- data %>%
    dplyr::group_by(Algorithm) %>%
    dplyr::summarize(r.star=median(r.star), Er.Rt.star=median(Er.Rt.star))

  # Weighted count of points in the quadrant where r has sign `r.sign` and e
  # has sign `e.sign`: interior points count 1, points on one of the
  # quadrant's two axes count 1/2, and points at the origin count 1/4.
  quadrant.weight <- function(r, e, r.sign, e.sign) {
    r.side <- sign(r) == r.sign
    e.side <- sign(e) == e.sign
    sum(r.side & e.side, na.rm=TRUE) +
      0.5*sum((r.side & e == 0) | (e.side & r == 0), na.rm=TRUE) +
      .25*sum(r == 0 & e == 0, na.rm=TRUE)
  }

  # table results
  tab <- data %>%
    dplyr::filter(Algorithm != "LOL") %>%
    dplyr::group_by(Algorithm) %>%
    dplyr::summarize(Q1=quadrant.weight(r.star, Er.Rt.star, 1, 1),
                     Q2=quadrant.weight(r.star, Er.Rt.star, -1, 1),
                     Q3=quadrant.weight(r.star, Er.Rt.star, -1, -1),
                     Q4=quadrant.weight(r.star, Er.Rt.star, 1, -1),
                     `Q3+Q4`=Q3+Q4,
                     `n`=Q1 + Q2 + Q3 + Q4) %>%
    adorn_totals("row")
  print(tab)
  # same table as proportions of each row's total; `n` is the last numeric
  # column, so the other columns are divided before it is overwritten
  per.tab <- tab %>%
    dplyr::mutate_if(is.numeric, ~ . / n) %>%
    dplyr::select(-n)
  print(per.tab)

  center <- ggplot(data, aes(x=r.star, y=Er.Rt.star)) +
    geom_polygon(data=box, aes(x=x, y=y), fill='green', alpha=0.15) +
    geom_polygon(data=box, aes(x=-x, y=-y), fill='red', alpha=0.15) +
    geom_point(aes(x=r.star, y=Er.Rt.star, shape=Algorithm, color=Algorithm, fill=Algorithm), alpha=0.5, size=1.2) +
    geom_point(data=data.medians, aes(x=r.star, y=Er.Rt.star, 
                                      shape=Algorithm, color=Algorithm, fill=Algorithm), alpha=1.0, size=2.5) +
    scale_fill_manual(values=acols, guide=legend.style, name=leg.title) +
    scale_color_manual(values=acols, guide=legend.style, name=leg.title) +
    scale_shape_manual(values=shapes, guide=legend.style, name=leg.title) +
    ylab(yl) +
    xlab(xl) +
    ggtitle("") +
    scale_y_continuous(limits=ylims) +
    scale_x_continuous(limits=xlims) +
    theme_bw() +
    annotate("text", size=4, label="LOL better", color="darkgreen", x=-.5, y=.35) +
    annotate("text", size=4, label="LOL better", color="darkgreen", x=1.0, y=-.2, angle=-90) +
    annotate("text", size=4, label="LOL worse", color="red", x=.5, y=.35) +
    annotate("text", size=4, label="LOL worse", color="red", x=1.0, y=.2, angle=-90)
  # grab the legend while it is still drawn, then hide it on the panel
  leg <- g_legend(center)
  center <- center + theme(legend.position="none")

  # LOL is the reference of the comparison, so it is left out of the
  # densities and drawn as a vertical line at 0 instead
  others <- data %>% dplyr::filter(Algorithm != "LOL")
  lol.col <- as.character(acols["LOL"])

  right <- ggplot(others, aes(x=Er.Rt.star, y=..scaled.., color=Algorithm)) +
    geom_density(size=1.1) +
    scale_color_manual(values=acols, guide=legend.style, name=leg.title) +
    geom_vline(xintercept=0, color=lol.col, size=1.1) +
    scale_fill_manual(values=acols, guide=legend.style, name=leg.title) +
    #scale_linetype_manual(values=linestyle, guide=legend.style, name=leg.title) +
    scale_x_continuous(limits=ylims) +
    ylab("Likelihood") +
    xlab("") +
    ggtitle("") +
    theme_bw() +
    theme(legend.position="none",
          axis.text.y=element_blank(),
          axis.text.x=element_blank(),
          axis.ticks.x=element_blank()) +
    coord_flip()
  top <- ggplot(others, aes(x=r.star, y=..scaled.., color=Algorithm)) +
    geom_density(size=1.1) +
    scale_color_manual(values=acols, guide=legend.style, name=leg.title) +
    #scale_linetype_manual(values=linestyle, guide=legend.style, name=leg.title) +
    geom_vline(xintercept=0, color=lol.col, size=1.1) +
    scale_x_continuous(limits=xlims) +
    ylab("Likelihood") +
    xlab("") +
    ggtitle(plot.title) +
    theme_bw() + 
    theme(legend.position="none",
          axis.text.y=element_blank(),
          axis.ticks.y=element_blank(),
          axis.text.x=element_blank(),
          axis.ticks.x=element_blank())
  # namespace-qualified because gridExtra is not attached by this file
  gridExtra::grid.arrange(top, leg, center, right, ncol=2, nrow=2, widths=c(4,1.5), heights=c(2,4))
}

# Render the benchmark figure; the call also prints the quadrant tables.
realdat.panels(lol.dat.prep)

# Per-algorithm tally on error alone: rows with Er.Rt.star < 0 count towards
# LOL.better, rows with Er.Rt.star > 0 towards LOL.worse, and exact ties are
# split evenly between the two.
test <- lol.dat.prep %>%
  dplyr::group_by(Algorithm) %>%
  dplyr::summarize(
    LOL.better=sum(Er.Rt.star < 0) + 0.5*sum(Er.Rt.star == 0),
    LOL.worse=sum(Er.Rt.star > 0) + 0.5*sum(Er.Rt.star == 0)
  )


# neurodata/fselect documentation built on March 6, 2021, 12:54 p.m.