# Global knitr chunk defaults for this report: code, warnings and messages are
# hidden, caching is off, and figures are 11 x 7 and centred.
# `FALSE` is spelled out because the `F` shorthand is an ordinary, reassignable
# variable.
knitr::opts_chunk$set(echo = FALSE, fig.width = 11, fig.height = 7,
                      cache = FALSE, warning = FALSE, message = FALSE,
                      fig.align = 'center')
library(dplyr)
library(magrittr)
# library(FAM)
devtools::load_all() # at least SHA b9ea7e8 
library(reshape2)
library(whoppeR)
library(grid)
library(gridExtra)
library(ggplot2)
library(xtable)
library(tidyr)
library(optimx)
# Apply one study opportunity (and optional forgetting) to memory strengths.
#
# mem        Array of current feature counts; must carry a `dim` attribute,
#            since the number of random draws is prod(dim(mem)).
# nFeatures  Maximum number of features an item can hold.
# LR         Learning rate: mean proportion of the remaining features gained.
# FR         Optional forgetting rate; NULL or 0 skips the forgetting step.
#
# Returns the updated `mem`.
# betaParams() comes from outside this file; it is used here as a converter
# from a mean/sd pair to beta shape parameters `a` and `b`.
study_beta <- function(mem, nFeatures, LR, FR = NULL) {

  n_cells <- prod(dim(mem))

  # Keep at least 1.1 features left to learn so the sd below stays usable
  nearly_full <- nFeatures - mem < 1
  mem[nearly_full] <- nFeatures - 1.1
  remaining <- nFeatures - mem

  learn_sd <- sqrt((LR * (1 - LR)) / remaining)
  learn_shape <- betaParams(mean = LR, sd = learn_sd)
  gained <- rbeta(n_cells, learn_shape$a, learn_shape$b) * remaining
  mem <- mem + gained

  # Guard clause: nothing more to do when forgetting is switched off
  if (is.null(FR) || FR == 0) {
    return(mem)
  }

  # Forgetting removes a beta-distributed proportion of the current strength
  mem[mem < 1] <- 1
  forget_sd <- sqrt((FR * (1 - FR)) / mem)
  forget_shape <- betaParams(mean = FR, sd = forget_sd)
  lost <- rbeta(n_cells, forget_shape$a, forget_shape$b) * mem

  mem - lost
}

# Apply one test (retrieval practice) opportunity to memory strengths and
# retrieval thresholds, followed by optional forgetting.
#
# mem        Array of current feature counts (must carry a `dim` attribute).
# nFeatures  Maximum number of features an item can hold.
# thresh     Retrieval thresholds, indexed the same way as `mem`.
# acc        Selector of the cells that were answered correctly; only those
#            cells are updated. sum(acc) is used as their count, so this is
#            presumably a logical array shaped like `mem` -- TODO confirm.
# LR         Learning rate: mean proportion of the remaining features gained.
# TR         Threshold reduction rate: mean proportion of threshold removed.
# FR         Optional forgetting rate; NULL or 0 skips the forgetting step.
#
# Returns list(mem = , thresh = ) holding the updated arrays.
# betaParams() comes from outside this file; it is used here as a converter
# from a mean/sd pair to beta shape parameters `a` and `b`.
test_beta <- function(mem, nFeatures, thresh, acc, LR, TR, FR = NULL) {

  mxn <- prod(dim(mem))
  nCor <- sum(acc)

  # memory feature updating
  # Keep at least 1.1 features left to learn so the sd below stays usable
  mem[nFeatures - mem < 1] <- nFeatures - 1.1
  toGo <- nFeatures - mem
  beta_pars <- betaParams(mean = LR, sd = sqrt((LR * (1 - LR)) / toGo))
  # The shape parameters vary per cell (they depend on toGo), so they must be
  # subset with `acc` just like mem and toGo. Without the subset, rbeta() would
  # silently use the parameters of the first nCor cells instead.
  mem[acc] <- mem[acc] +
    (rbeta(nCor, beta_pars$a[acc], beta_pars$b[acc]) * toGo[acc])

  # threshold updating
  thresh[thresh < 1] <- 1
  beta_pars <- betaParams(mean = TR, sd = sqrt(TR * (1 - TR) / thresh))
  thresh[acc] <- thresh[acc] -
    (rbeta(nCor, beta_pars$a[acc], beta_pars$b[acc]) * thresh[acc])

  # forgetting applies to every cell, not only the correct ones
  if (!is.null(FR) && FR != 0) {
    mem[mem < 1] <- 1
    beta_pars <- betaParams(mean = FR, sd = sqrt((FR * (1 - FR)) / mem))
    mem <- mem - (rbeta(mxn, beta_pars$a, beta_pars$b) * mem)
  }

  list(mem = mem, thresh = thresh)
}

# Make local copies of the functions from the FAM package that are needed
# during fitting. As plain objects in this session they can be exported by
# name to parallel workers (see the clusterExport() call in the fitting
# section) without needing to export the entire installed FAM package.
CFR_PCR <- FAM::CFR_PCR
RTdist <- FAM::RTdist

# Sum-of-squares objective comparing predicted and observed accuracy.
#
# free, fixed            Named parameter vectors; they are concatenated for
#                        the validity checks and passed separately to `fcn`.
# accuracy_observations  Object with an `acc` element aligned with the `acc`
#                        element of the model predictions.
# fcn                    Model function, called as
#                        fcn(free =, fixed =, summarised = TRUE).
# ...                    Ignored; lets callers pass arguments meant for other
#                        error functions (e.g. RT_observations).
#
# Returns the summed squared error, or the penalty value 1000000 when any
# parameter is NA or outside the bounds checked by paramBounds().
accuracyErrorFcn <- function(free, fixed, accuracy_observations, fcn = CFR_PCR, ...) {

  all_pars <- c(free, fixed)
  # Both conditions are scalars, so use the short-circuiting `||`
  if (anyNA(all_pars) || !paramBounds(all_pars)) {
    return(1000000)
  }

  preds <- fcn(free = free, fixed = fixed, summarised = TRUE)
  sum((preds$acc - accuracy_observations$acc)^2)
}

# Negative log-likelihood of the observed retrieval times under the model.
#
# free, fixed      Named parameter vectors. Together they must contain
#                  `nList` (number of output positions) and `Time` (upper
#                  end of the RT grid, in steps of 0.1).
# RT_observations  Data frame with columns `class`, `order` and `RTrounded`.
# fcn              Model function, called as
#                  fcn(free =, fixed =, summarised = FALSE).
# ...              Ignored; lets callers pass arguments meant for other
#                  error functions (e.g. accuracy_observations).
#
# Returns -sum(log(y)) over the observations matched to a predicted cell, or
# the penalty value 1000000 when any parameter is NA or outside the bounds
# checked by paramBounds(). A matched cell whose predicted `y` is 0
# contributes log(0), which makes the returned error Inf.
RT_ErrorFcn <- function(free, fixed, RT_observations, fcn = CFR_PCR, ...) {

  all_pars <- c(free, fixed)
  # Both conditions are scalars, so use the short-circuiting `||`
  if (anyNA(all_pars) || !paramBounds(all_pars)) {
    return(1000000)
  }
  output_positions <- seq_len(all_pars[["nList"]])

  preds <- fcn(free = free, fixed = fixed, summarised = FALSE)
  # RTdist() is defined outside this file; its result is used below as a table
  # with columns class, order, RT and y (the predicted density of each cell).
  dist <- preds %>% group_by(class, obsOrder) %>% RTdist()
  completes <- dist %>%
    group_by(class) %>%
    summarise(complete = identical(unique(order), output_positions))

  # If any class is missing, or lacks predictions for some output position,
  # pad the table out to the full class x order x RT grid with zero density.
  if (!all(completes$complete) ||
      nrow(completes) != length(unique(RT_observations$class))) {
    full_grid <- expand.grid(class = unique(RT_observations$class),
                             order = output_positions,
                             RT = seq(.1, all_pars[["Time"]], .1),
                             stringsAsFactors = FALSE)
    dist <- left_join(full_grid, dist, by = c("class", "order", "RT"))
    dist$y[is.na(dist$y)] <- 0
  }

  # NOTE(review): the join key `RT` is a double on both sides; confirm that
  # RTdist(), seq(.1, Time, .1) and round(RT, 1) yield bit-identical values,
  # otherwise observations are silently dropped by the inner join.
  likelihoods <- inner_join(select(RT_observations,
                                   class, order, RT = RTrounded),
                            dist,
                            by = c("class", "order", "RT"))
  -sum(log(likelihoods$y))
}

# Fit the free parameters of one subject with Nelder-Mead (via optimx).
#
# model    List with elements `free` (list of starting vectors, indexed by
#          subject), `fixed`, `obj` (the model function) and the data frames
#          `RT_observations` / `accuracy_observations`, each of which has a
#          `subject` column.
# erf      Error function to minimise; it receives both observation sets and
#          is expected to ignore the one it does not need.
# subject  Subject to fit. Defaults to the loop variable `j` looked up in the
#          enclosing environment, which is what the body previously relied on
#          implicitly; pass it explicitly to avoid the hidden dependency.
#
# Returns the optimx result with an added `subject` column.
minimization_routine <- function(model, erf = RT_ErrorFcn, subject = j) {

  force(subject)
  rt_obs <- model$RT_observations[model$RT_observations$subject == subject, ]
  acc_obs <- model$accuracy_observations[model$accuracy_observations$subject == subject, ]

  # parscale has seven entries, so the starting vector must hold seven free
  # parameters in the matching order.
  fit <- optimx(par = model$free[[subject]],
                fn = erf,
                method = "Nelder-Mead",
                itnmax = 1000,
                control = list(maxit = 1000, kkt = FALSE,
                               parscale = c(1, 1, 200, 1, 1, 100, 1)),
                fcn = model$obj, # passed to erf
                fixed = model$fixed, # passed to erf
                RT_observations = rt_obs, # passed to erf
                accuracy_observations = acc_obs) # passed to erf
  fit$subject <- subject
  fit
}

# Evaluate the error function once at a subject's starting parameters (no
# optimisation) and wrap the result so it looks like an optimx fit.
#
# model    List with elements `free` (list of parameter vectors, indexed by
#          subject), `fixed`, `obj` (the model function) and the data frames
#          `RT_observations` / `accuracy_observations`, each of which has a
#          `subject` column.
# erf      Error function to evaluate; it receives both observation sets and
#          is expected to ignore the one it does not need.
# subject  Subject to evaluate. Defaults to the loop variable `j` looked up
#          in the enclosing environment, which is what the body previously
#          relied on implicitly.
#
# Returns a one-row data frame of class c("optimx", "data.frame") holding the
# parameters, the error (`value`) and the subject.
testing_routine <- function(model, erf = RT_ErrorFcn, subject = j) {

  force(subject)
  start_pars <- model$free[[subject]]

  # The model function is stored as model$obj and the error functions take it
  # as `fcn`. Passing `obj = model$fcn` handed them NULL, so they silently
  # fell back to their default model.
  err <- erf(free = start_pars,
             fixed = model$fixed,
             fcn = model$obj,
             RT_observations = model$RT_observations[model$RT_observations$subject == subject, ],
             accuracy_observations = model$accuracy_observations[model$accuracy_observations$subject == subject, ])

  ## Manually create an optimx results data frame object
  fit <- structure(data.frame(t(c(start_pars, value = err, subject = subject))),
                   details = NULL,
                   maximize = NULL,
                   npar = length(start_pars),
                   follow.on = NULL,
                   class = c("optimx", "data.frame"))
  fit
}

# Check that every recognised parameter in `p` lies inside its legal range.
#
# p  Named numeric vector of parameters. Names not listed below are ignored,
#    and NA values are skipped rather than treated as violations.
#
# Ranges enforced:
#   ER, LR, TR, FR       must satisfy 0 <= value < 1
#   Tmin, Tmax, lambda   must be > 0
#   Ta, Tb               must be > 1
#
# Returns TRUE when no range is violated, FALSE otherwise.
paramBounds <- function(p) {

  values_for <- function(param_names) p[names(p) %in% param_names]

  probabilities <- values_for(c("ER", "LR", "TR", "FR"))
  strictly_positive <- values_for(c("Tmin", "Tmax", "lambda"))
  above_one <- values_for(c("Ta", "Tb"))

  violations <- c(
    any(probabilities < 0, na.rm = TRUE),
    any(probabilities >= 1, na.rm = TRUE),
    any(strictly_positive <= 0, na.rm = TRUE),
    any(above_one <= 1, na.rm = TRUE)
  )

  !any(violations)
}
# Observed correct recalls, one row per subject x class x list x output
# position.
# Keeps scored-correct responses from real lists (list != 0) in the final
# test, or in the practice phase when `practice` is 'T'. Responses are
# numbered in their existing row order within each subject/class/list, and RTs
# are rounded to 0.1 to give `RTrounded`.
# complete() then adds a row for every missing combination of subject, class,
# order and list, marked score = 0 (other columns NA), so that output
# positions a subject never reached count as failures.
raw_data <- CFR_allSs %>%
  ungroup() %>%
  filter(list != 0, score == 1,
         phase == 'final' | (phase == 'prac' & practice == 'T')) %>%
  # select() replaces the deprecated select_(.dots = ...) form
  select(subject, cond_list, class, target, resp, score, RT) %>%
  group_by(subject, class, cond_list) %>%
  mutate(order = seq_len(n()),
         RTrounded = round(RT, 1)) %>%
  ungroup() %>%
  complete(subject, class, order, cond_list, fill = list(score = 0))

# Observed accuracy per subject, class and output position: the mean of
# `score` over the rows of raw_data in each cell.
prob_atLeast_x <- ungroup(
  summarise(group_by(raw_data, subject, class, order),
            acc = mean(score))
)

# Observed RTs tagged with the total number of items recalled on their list.
# The per-list totals (sum of `score`) are joined back onto every row of
# raw_data; rows without an RT are then dropped.
RT_byOutput_nRecalled <- local({
  list_keys <- c("subject", "class", "cond_list")

  recall_totals <- ungroup(
    summarise(group_by(raw_data, subject, class, cond_list),
              nRecalled = sum(score))
  )
  with_totals <- right_join(recall_totals, raw_data, by = list_keys)
  observed <- filter(with_totals, !is.na(RT))

  arrange(
    select(observed, subject, class, cond_list, order, nRecalled, RT, RTrounded),
    subject, class, cond_list, order
  )
})
# Starting values of the free parameters, one named vector per list entry
# (the trailing #n comment gives the entry's position in the list).
# The fitting routines look an entry up as free[[j]] with j taken from the
# subject ids, so the position in this list must equal the subject id.
# Every entry must carry the same parameter names as params$pars; this is
# checked immediately after the list.
free <- list(c(ER=.62,LR=.025,Ta=150,TR=.02,Tmin=1,Tmax=45,lambda=.5), #1
             c(ER=.525,LR=.05,Ta=75,TR=.01,Tmin=1,Tmax=45,lambda=.8), #2
             c(ER=.52,LR=.06,Ta=200,TR=.01,Tmin=1,Tmax=45,lambda=.75), #3
             c(ER=.525,LR=.05,Ta=125,TR=.01,Tmin=1,Tmax=45,lambda=.75), #4
             c(ER=.53,LR=.06,Ta=150,TR=.03,Tmin=1,Tmax=45,lambda=.8), #5
             c(ER=.52,LR=.09,Ta=125,TR=.02,Tmin=1,Tmax=45,lambda=.65), #6
             c(ER=.524,LR=.09,Ta=150,TR=.025,Tmin=1,Tmax=45,lambda=.9), #7
             c(ER=.54,LR=.06,Ta=350,TR=.01,Tmin=1,Tmax=45,lambda=.8), #8
             c(ER=.5,LR=.06,Ta=80,TR=.015,Tmin=1,Tmax=45,lambda=.9), #9 
             c(ER=.55,LR=.075,Ta=100,TR=.01,Tmin=1,Tmax=45,lambda=.75), #10
             c(ER=.52,LR=.075,Ta=150,TR=.02,Tmin=1,Tmax=45,lambda=.75), #11
             c(ER=.55,LR=.07,Ta=150,TR=.04,Tmin=1,Tmax=45,lambda=.85), #12
             c(ER=.57,LR=.045,Ta=125,TR=.03,Tmin=1,Tmax=45,lambda=.75), #13
             c(ER=.54,LR=.09,Ta=100,TR=.02,Tmin=1,Tmax=45,lambda=.75), #14
             c(ER=.52,LR=.09,Ta=200,TR=.01,Tmin=1,Tmax=45,lambda=.85), #15
             c(ER=.55,LR=.09,Ta=175,TR=.035,Tmin=1,Tmax=45,lambda=.65), #16
             c(ER=.52,LR=.12,Ta=100,TR=.035,Tmin=1,Tmax=45,lambda=.7), #17
             c(ER=.56,LR=.09,Ta=200,TR=.03,Tmin=1,Tmax=45,lambda=.75), #18
             c(ER=.53,LR=.09,Ta=125,TR=.02,Tmin=1,Tmax=45,lambda=.6), #19
             c(ER=.53,LR=.09,Ta=80,TR=.05,Tmin=1,Tmax=60,lambda=.75), #20
             c(ER=.515,LR=.09,Ta=100,TR=.025,Tmin=1,Tmax=45,lambda=.85), #21
             c(ER=.53,LR=.09,Ta=80,TR=.04,Tmin=1,Tmax=45,lambda=.65), #22
             c(ER=.56,LR=.05,Ta=200,TR=.03,Tmin=1,Tmax=45,lambda=.75), #23
             c(ER=.55,LR=.06,Ta=120,TR=.06,Tmin=1,Tmax=45,lambda=.75), #24
             c(ER=.57,LR=.015,Ta=250,TR=.03,Tmin=1,Tmax=45,lambda=.75), #25
             c(ER=.53,LR=.05,Ta=80,TR=.035,Tmin=1,Tmax=45,lambda=.85), #26
             c(ER=.53,LR=.02,Ta=50,TR=.06,Tmin=1,Tmax=45,lambda=.75), #27
             c(ER=.53,LR=.05,Ta=100,TR=.03,Tmin=1,Tmax=45,lambda=.75), #28
             c(ER=.51,LR=.06,Ta=150,TR=.03,Tmin=1,Tmax=45,lambda=.9), #29
             c(ER=.56,LR=.05,Ta=120,TR=.02,Tmin=1,Tmax=45,lambda=.7), #30
             c(ER=.505,LR=.08,Ta=150,TR=.04,Tmin=1,Tmax=45,lambda=.7), #31
             c(ER=.52,LR=.08,Ta=150,TR=.02,Tmin=1,Tmax=60,lambda=.95), #32
             c(ER=.57,LR=.06,Ta=120,TR=.02,Tmin=1,Tmax=45,lambda=.7), #33
             c(ER=.54,LR=.09,Ta=120,TR=.06,Tmin=1,Tmax=45,lambda=.65)) #34

# Abort early unless every subject's starting vector names exactly the
# parameters declared in params$pars (order does not matter).
if (!all(vapply(free,
                function(start) setequal(names(start), params$pars),
                logical(1)))) {
  stop("Model specs in header don't match free parameters. Edit header or list of free parameters")
}

# Build the results file name from the model's bare name (anything before a
# "::" qualifier is dropped) and the free-parameter names, e.g.
# "<model>_<par1>_<par2>.rds" inside ../data.
mname <- strsplit(as.character(params$model), "::")[[1]]
fname <- paste(tail(mname, 1),
               paste(params$pars, collapse = "_"),
               sep = "_")
fpath <- file.path("..", "data", paste0(fname, ".rds"))

# Bundle everything the fitting routines need.
# RTs are trimmed to those strictly between the 2.5th and 97.5th percentiles
# so the most extreme RT points are not fitted. Both quantiles are taken over
# all rows of RT_byOutput_nRecalled at once, not separately per subject.
# `obj` is the model function named by the params$model string.
model <- list(
  free = free,
  fixed = params$fixed,
  obj = eval(parse(text = params$model)),
  RT_observations = RT_byOutput_nRecalled %>%
    filter(RT > quantile(RT, .025), RT < quantile(RT, .975)),
  accuracy_observations = prob_atLeast_x
)

# "minimize" runs the optimiser; any other value only evaluates the error at
# the starting parameters.
routine <- if (params$routine == "minimize") {
  minimization_routine
} else {
  testing_routine
}
# Subject ids present in the observed data, in order of first appearance
subjects <- unique(raw_data$subject)

 if (!file.exists(fpath)) {

  # No cached results for this model/parameter combination: fit every subject,
  # either on a cluster of workers or sequentially.
  if (params$inpar) {

    cluster <- doParallelCluster(detectCores() - 1)
    message(paste("Log at", cluster$logfile))
    cat(paste('[', Sys.time(), ']', "INIT parlog"),
        file = cluster$logfile, sep = '\n')

    # tryCatch(finally = ) guarantees the workers are released even when
    # exporting or fitting raises an error; previously they were leaked.
    results <- tryCatch({
      clusterExport(cluster$handle, c("accuracyErrorFcn", "RT_ErrorFcn", "routine",
                                      "paramBounds", "study_beta", "test_beta",
                                      "model", "CFR_PCR", "RTdist"))
      # Print warnings on the workers as they occur
      clusterCall(cluster$handle, options, warn = 1)
      foreach(j = subjects, .verbose = TRUE,
              .packages = c("optimx", "PCR", "whoppeR", "magrittr",
                            "dplyr", "reshape2", "tidyr")) %dopar% {

        cat(paste('[', Sys.time(), ']', "Fitting subject", j),
            sep = '\n')
        # try() keeps one failed subject from aborting the whole run; the
        # failure is returned in place of that subject's fit.
        try(routine(model))
      }
    }, finally = stopCluster(cluster$handle))

    cat(paste('[', Sys.time(), ']', "Finished Fitting"),
        file = cluster$logfile, sep = '\n', append = TRUE)

  } else {

    # NOTE(review): results are stored at index j (the subject id), so this
    # assumes subject ids are the integers 1..length(subjects) -- confirm.
    results <- vector(mode = 'list', length = length(subjects))
    for (j in subjects) {
      message(paste("Fitting subject", j))
      results[[j]] <- try(routine(model))
    }
  }

  # Only real fits are cached; test runs are cheap to repeat
  if (params$routine == "minimize") {
    saveRDS(results, file = fpath)
  }

} else {

  results <- readRDS(fpath)
}
# Simulate unsummarised predictions from each subject's fitted parameters.
# The model function stored in `model$obj` (the one the error functions were
# given as `fcn`) is used here instead of a hard-coded CFR_PCR, so the
# predictions always come from the model named in params$model.
# `subject` is the position of the fit within `results`, and the model's
# `obsOrder` column is renamed to `order` to match the observed data.
raw_predictions <- lapply(results, function(fit) {
  model$obj(free = coef(fit)[1, ], fixed = model$fixed, summarised = FALSE)
  }) %>%
  bind_rows(.id = "subject") %>%
  mutate(subject = as.numeric(subject)) %>%
  rename(order = obsOrder)
# Predicted median RT (and number of simulated recalls, N) for each
# subject x class x recall total x output position. Simulated items with no
# output position (order is NA) were not recalled and are dropped first.
RT_preds <- local({
  recalled <- filter(raw_predictions, !is.na(order))
  per_cell <- summarise(group_by(recalled, subject, class, nRecalled, order),
                        RT = median(obsRT),
                        N = n())
  select(ungroup(per_cell), subject, class, order, nRecalled, RT, N)
})

# Summarise the observed RTs within each subject: median RT (and number of
# observations, N) per class x recall total x output position, collapsing
# across lists.
RT_byOutput_nRecalled <- RT_byOutput_nRecalled %>%
  group_by(subject, class, nRecalled, order) %>%
  summarise(RT = median(RT, na.rm = TRUE),
            N = n())

# Stack model predictions and observed data, distinguished by `type`
RT_byOutput_nRecalled <- bind_rows(mutate(RT_preds, type = "model"),
                                   mutate(RT_byOutput_nRecalled, type = "real"))
# RT_preds is intentionally not removed here: the forward-IRT summary near the
# end of the script (`x <- RT_preds %>% ...`) still reads it, and deleting it
# made that step fail with "object 'RT_preds' not found".

# Predicted probability of recalling at least 1, 2, 3, ... items.
#
# For each subject, class and output position we count the simulated
# observations at that position and divide by the number of simulated lists
# (params$fixed['nSim']), which is the most that could have been observed
# there.
#
# Simulated items that were not recalled are removed first, since their
# output position is not observed (order is NA).
#
# For example, 7 outputs in the 14th position out of 1000 simulated lists
# gives a probability of recalling at least 14 items of 7/1000 = .007.
#
# Positions never reached in the simulation are filled in with acc = 0, and
# the predictions are stacked under the observed accuracies (type = "real")
# with type = "model".
prob_atLeast_x <- bind_rows(
  mutate(prob_atLeast_x, type = "real"),
  raw_predictions %>%
    filter(!is.na(order)) %>%
    group_by(subject, class, order) %>%
    summarise(acc = n()/params$fixed['nSim']) %>%  # max(sim) should equal nSims == 1000
    ungroup() %>%
    complete(subject, class, order, fill = list(acc = 0)) %>%
    mutate(type = "model")
)

## Fixed Specs

# Emit the fixed parameters as a single-row HTML table, caption on top
print(
  xtable(t(as.matrix(model$fixed)),
         digits = 3,
         caption = "Fixed Parameters"),
  type = "html",
  include.rownames = FALSE,
  caption.placement = "top"
)

## Model Results by Subject

# Human-readable facet labels for the practice conditions and data sources
class_labels <- c(np = "No Practice",
                  sp = "Study Practice",
                  tp = "Test Practice")
type_labels <- c(model = "PCR Model",
                 real = "Obs. Data")

# For every subject: a heading, the starting and best-fitting parameter
# tables, then the accuracy and RT plots comparing model and data.
# NOTE(review): free[[i]] and results[[i]] are indexed by subject id, so this
# assumes ids coincide with list positions -- confirm.
# rtPlot and accPlot are left defined after the loop and are reused as
# templates by the averaged plots further down.
for (i in unique(raw_predictions$subject)) {
  rtPlot <- ggplot(data = filter(RT_byOutput_nRecalled, subject == i),
                   aes(x = factor(order),
                       y = RT,
                       group = nRecalled,
                       color = factor(nRecalled))) +
    geom_point() +
    geom_line() +
    facet_grid(class ~ type, labeller = labeller(class = class_labels,
                                                 type = type_labels),
               scales = "free_y") +
    # Axis titles previously read "Retreival" and "Postion"
    scale_y_continuous("Median Inter-Retrieval Time of Item x") +
    scale_x_discrete("Output Position (x)") +
    scale_color_discrete("Items Recalled") +
    ggtitle("Data vs. PCR: Reaction Time")

  accPlot <- ggplot(filter(prob_atLeast_x, subject == i),
                    aes(x = factor(order), y = acc, group = type, linetype = type)) +
    geom_point() +
    geom_line() +
    facet_grid(class ~ ., labeller = labeller(class = class_labels)) +
    scale_y_continuous("Probability of Recalling at Least x Items") +
    scale_x_discrete("Output Position (x)") +
    scale_linetype_discrete(name = "", labels = c(model = "PCR Model",
                                                  real = "Obs. Data")) +
    ggtitle("Data vs. PCR: Accuracy")

  cat(paste0("<h4 class='subid'>", "Subject ", i, "</h4>"))
  print(xtable(t(data.frame(free[[i]])), digits = 3, caption = "Starting Parameters"),
        type = "html", include.rownames = FALSE, caption.placement = "top")
  print(xtable(results[[i]], digits = 3, caption = "Best Parameters"),
        type = "html", include.rownames = FALSE, caption.placement = "top")
  print(accPlot)
  print(rtPlot)

}

## Averaged Results

# Collapse across subjects.
# Accuracy: mean and sd of the per-subject accuracies, plus the number of
# rows averaged. `sd` is computed before `acc` is overwritten by its mean.
prob_atLeast_x_collapsed <- summarise(
  group_by(prob_atLeast_x, class, order, type),
  sd = sd(acc),
  acc = mean(acc),
  nSubs = n()
)
# RT: median and unscaled median absolute deviation (constant = 1) of the
# per-subject medians. `MAD` is computed before `RT` is overwritten.
RT_byOutput_nRecalled_collapsed <- summarise(
  group_by(RT_byOutput_nRecalled, class, order, nRecalled, type),
  MAD = mad(RT, constant = 1),
  RT = median(RT),
  nSubs = n()
)
# Reuse the plot objects left over from the last iteration of the per-subject
# loop as templates, swapping in the across-subject summaries with %+%.

# RT: error bars span the median +/- one MAD
rtPlot %+% RT_byOutput_nRecalled_collapsed +
  geom_point() +
  geom_errorbar(aes(ymin = RT - MAD, ymax = RT + MAD),
                width = .3,
                linetype = 1)

# Accuracy: error bars span the mean +/- one standard error
accPlot %+% prob_atLeast_x_collapsed +
  geom_errorbar(aes(ymin = acc - sd/sqrt(nSubs), ymax = acc + sd/sqrt(nSubs)),
                width = .3,
                linetype = 1)

# Average every column of the stacked per-subject fits, then print the columns
# up to and including `value` (the parameters and the error) as an HTML table.
avg_res <- do.call(rbind, results) %>%
  summarise_each(funs(mean))
print(
  xtable(as.matrix(avg_res[1:which(names(avg_res) == 'value')]),
         digits = 3,
         caption = "Average Parameters"),
  type = "html",
  include.rownames = FALSE,
  caption.placement = "top"
)
# Median RT at each output position, pooled over subjects and recall totals,
# for the model predictions and the observed data.
medianRT_by_output_order <- bind_rows(
  raw_predictions %>%
    select(subject, class, order, RT = obsRT) %>%
    mutate(type = "model"),
  raw_data %>%
    select(subject, class, order, RT) %>%
    mutate(type = "real")
  ) %>%
  filter(!is.na(order)) %>%
  group_by(class, order, type) %>%
  summarise(RT = median(RT, na.rm = TRUE), N = n()) %>%
  ungroup()

# One panel per practice condition; line type separates model from data
ggplot(medianRT_by_output_order,
       aes(x = order, y = RT, color = class, linetype = type)) +
  geom_point() +
  geom_line() +
  facet_grid(class ~ ., labeller = labeller(class = class_labels))
# Median RT by output position after binning lists by how many items were
# recalled: [3,6], (6,10] and (10,15]. Lists with fewer than three recalls
# are left out.
binned <- local({
  at_least_three <- filter(RT_byOutput_nRecalled, nRecalled >= 3)
  with_bins <- mutate(group_by(at_least_three, class, type),
                      bin = cut(nRecalled,
                                breaks = c(3, 6, 10, 15),
                                include.lowest = TRUE))
  summarise(group_by(with_bins, bin, order, class, type),
            RT = median(RT))
})

# One panel per practice condition; colour marks the bin, line type the source
ggplot(binned, aes(x = order, y = RT, color = bin, linetype = type)) +
  geom_point() +
  geom_line() +
  facet_grid(class ~ ., labeller = labeller(class = class_labels))
# Across-subject summary of the predicted RTs for each class x output
# position x recall total: count, unscaled MAD, median, mean and sd.
x <- RT_preds %>%
  group_by(class, order, nRecalled) %>%
  summarise(count = n(),
            MAD = mad(RT, na.rm = TRUE, constant = 1),
            median_RT = median(RT, na.rm = TRUE),
            mean_RT = mean(RT, na.rm = TRUE),
            sd_RT = sd(RT, na.rm = TRUE))

# Mean predicted IRT by output position, one line per recall total (lists
# with more than two recalls only), one panel per practice condition.
forward <- ggplot(filter(x, nRecalled > 2),
                  aes(x = order, y = mean_RT, color = factor(nRecalled))) +
  geom_point() +
  geom_line(size = .75) +
  # The lookup must be named after the facetting variable. Passed unnamed, it
  # is never matched to `class` and the readable labels are not applied.
  facet_grid(class ~ ., labeller = labeller(class = class_labels)) +
  scale_x_continuous(breaks = 1:15) +
  scale_color_discrete("Items Recalled") +
  ylab("Mean IRT") +
  xlab("Output Order") +
  ggtitle("Forward IRTs by Recall Total")
print(forward)


# wjhopper/FAM documentation built on May 4, 2019, 7:33 a.m.