R/simulate_QL.R

#' Simulate data from single and dual learning rate Q-learning models
#'
#' \code{simulate_QL} is a function to simulate data from 1-alpha and 2-alpha
#' Q-learning models, with an experiment structure identical to that run online.
#' The parameter values can be from (a sample of) those fitted previously to the
#' real data, or can be randomly sampled.
#'
#' @param summary_df [cmdstanr::summary()] output containing posterior mean
#' parameter values for a group of individuals. If left as NULL, a random sample
#' of size \code{sample_size} will be taken from specified distributions.
#' @param sample_size How many sets of parameters to sample; defaults to 100,
#' or to the number of individuals in \code{summary_df} if that is supplied.
#' @param gain_loss Simulate from the dual learning rate model?
#' @param test Simulate test choices in addition to training choices?
#' @param affect Simulate subjective affect ratings (uses the full
#' passage-of-time model)?
#' @param time_pars Which time parameters to include in the affect model:
#' \code{"none"}, or one or both of \code{"overall"} (the default) and
#' \code{"block"}.
#' @param prev_sample An optional previous sample of id numbers (if you wish to
#' simulate data for the same subset of individual parameters across a number
#' of models).
#' @param raw_df The raw data originally used to fit the model, so that
#' subject IDs can be labelled appropriately.
#' @param ... Other arguments which can be used to control the parameters of the
#' Beta/Gaussian distributions from which parameter values are sampled.
#'
#' @returns Simulated training data (and test data, if requested) for a random
#' or previously fitted sample of parameter values.
#'
#' @importFrom stats plogis rbeta rnorm rgamma
#' @importFrom utils txtProgressBar setTxtProgressBar
#'
#' @examples
#' train_sim_2a <- simulate_QL(
#'   sample_size = 5,
#'   alpha_pos_dens = c(1.5, 3), # Beta(shape1, shape2); the default
#'   alpha_neg_dens = c(1.5, 3), # Beta(shape1, shape2); the default
#'   beta_dens = c(3, 4) # Beta(shape1, shape2), scaled by 10; the default
#' )
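#'
#' # A hypothetical sketch (not run): re-simulating from previously fitted
#' # posterior means; `fit_summary` and `train_df` are assumed objects (a
#' # cmdstanr summary of the fit and the original training data, respectively)
#' \dontrun{
#' train_sim_fitted <- simulate_QL(
#'   summary_df = fit_summary,
#'   raw_df = train_df,
#'   sample_size = 10
#' )
#' }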
#'
#' @export

simulate_QL <- function(summary_df = NULL,
                        sample_size = NULL,
                        gain_loss = TRUE,
                        test = FALSE,
                        affect = FALSE,
                        time_pars = "overall",
                        prev_sample = NULL,
                        raw_df = NULL,
                        ...) {

  # to appease R CMD check
  parameter <- subjID <- value <- trial_no <- id_no <- aff_num <-
    hidden_reward <- question_type <- missing_times <- trial_block <- adj <-
    question_response <- outc_lag <- NULL

  l <- list(...)

  if (affect) {
    time_prm <- match.arg(
      time_pars, c("none", "overall", "block"), several.ok = TRUE
    )
    sample_time <- is.null(raw_df)
    if (is.null(l$question_order)) {
      l$question_order <- c("happy", "confident", "engaged")
    }
    if (is.null(l$delta_model)) l$delta_model <- FALSE
    if (is.null(l$int_max)) l$int_max <- 5
  }
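  # (roughly:) question_order sets the cycling order of the three affect
  # questions; delta_model switches to the lag-weighted ("delta") affect model,
  # for which int_max gives the number of per-lag EV/PE weights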

  if (is.null(summary_df)) {
    if (is.null(sample_size)) sample_size <- 100
    if (is.null(l$alpha_dens)) l$alpha_dens <- c(1.5, 3) # Beta(alpha, beta)
    if (is.null(l$alpha_pos_dens)) l$alpha_pos_dens <- l$alpha_dens
    if (is.null(l$alpha_neg_dens)) l$alpha_neg_dens <- l$alpha_dens
    if (is.null(l$beta_dens)) l$beta_dens <- c(3, 4) # Beta(alpha, beta) * 10
    ids_sample <- 1:sample_size

    if (gain_loss) {
      pars_df <- tibble::tibble(
        id_no = ids_sample,
        alpha_pos = rbeta(
          sample_size, l$alpha_pos_dens[1], l$alpha_pos_dens[2]
        ),
        alpha_neg = rbeta(
          sample_size, l$alpha_neg_dens[1], l$alpha_neg_dens[2]
        ),
        beta = rbeta(sample_size, l$beta_dens[1], l$beta_dens[2]) * 10
      )
    } else {
      pars_df <- tibble::tibble(
        id_no = ids_sample,
        alpha = rbeta(sample_size, l$alpha_dens[1], l$alpha_dens[2]),
        beta = rbeta(sample_size, l$beta_dens[1], l$beta_dens[2]) * 10
      )
    }
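    # pars_df now has one row per simulated individual: id_no plus
    # alpha_pos/alpha_neg/beta (dual model) or alpha/beta (single model), with
    # learning rates on [0, 1] and inverse temperatures on [0, 10]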
    if (affect) {
      ids_tb <- tibble::tibble(id_no = ids_sample)

      if (is.null(l$w0_dens)) l$w0_dens <- c(0, 0.5) # Normal(mu, sigma)
      if (is.null(l$w1_o_dens)) l$w1_o_dens <- c(-0.5, 1) # Normal(mu, sigma)
      if (is.null(l$w1_b_dens)) l$w1_b_dens <- c(-0.2, 0.5) # Normal(mu, sigma)
      if (is.null(l$w2_dens)) l$w2_dens <- c(0.2, 0.1)
      if (is.null(l$w3_dens)) l$w3_dens <- c(0.2, 0.1)
      if (is.null(l$gamma_dens)) l$gamma_dens <- c(2, 2) # Beta(alpha, beta)
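      # (roughly:) w0 = per-question intercept; w1_o = slope on overall elapsed
      # time; w1_b = slope on block time; w2/w3 = weights on the decayed sums
      # of EVs and PEs; gamma = decay rate for those sums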

      wt_df <-
        dplyr::bind_rows(
          list("happy" = ids_tb, "confident" = ids_tb, "engaged" = ids_tb),
          .id = "adj"
        ) |>
        dplyr::rowwise() |>
        dplyr::mutate(
          aff_num = grep(adj, l$question_order),
          w0 = rnorm(1, l$w0_dens[1], l$w0_dens[2]),
          w1_o = rnorm(1, l$w1_o_dens[1], l$w1_o_dens[2]),
          w1_b = rnorm(1, l$w1_b_dens[1], l$w1_b_dens[2]),
          w2 = rnorm(1, l$w2_dens[1], l$w2_dens[2]),
          w3 = rnorm(1, l$w3_dens[1], l$w3_dens[2]),
          gamma = rbeta(1, l$gamma_dens[1], l$gamma_dens[2])
        ) |>
        dplyr::ungroup()

      if (l$delta_model) {
        wt_a <- wt_df |> dplyr::select(-w0, -w1_o, -w1_b)
        # duplicate wt_a int_max times, with outc_lag = 1, 2, ..., int_max
        wt_b <- lapply(1:l$int_max, function(i) wt_a) |>
          dplyr::bind_rows(.id = "outc_lag") |>
          dplyr::rowwise() |>
          dplyr::mutate(
            outc_lag = as.integer(outc_lag),
            w2 = w2 * rbeta(1, 1, outc_lag), # lower expectation w/ higher lag
            w3 = w3 * rbeta(1, 1, outc_lag)
          ) |>
          dplyr::ungroup()
        wt_df <- dplyr::bind_rows(wt_df, wt_b) |> dplyr::select(-gamma)
      }

      pars_df <- dplyr::bind_rows(pars_df, wt_df)

      if (!("block" %in% time_prm)) {
        pars_df  <- pars_df |> dplyr::select(-w1_b)
      }
      if (!("overall" %in% time_prm)) {
        pars_df <- pars_df |> dplyr::select(-w1_o)
      }
    }
  } else {
    pars_df <- clean_summary(summary_df) |>
      dplyr::mutate(
        adj = ifelse(is.na(aff_num), NA, l$question_order[aff_num])
      ) |>
      tidyr::pivot_wider(names_from = parameter, values_from = mean)

    if (!is.null(prev_sample)) {
      ids_sample <- prev_sample
    } else if (!is.null(sample_size)) { # do we want a sample < everyone
      ids_sample <- sample(1:max(pars_df$id_no), sample_size) |> sort()
    } else {
      ids_sample <- sort(unique(pars_df$id_no)) # i.e., use everyone
    }

    if (!is.null(raw_df)) {
      if (test) raw_df <- raw_df[[1]]
      raw_df <- raw_df |>
        dplyr::mutate(
          id_no = as.integer(factor(subjID, levels = unique(raw_df$subjID)))
        ) |>
        dplyr::inner_join(
          tibble::as_tibble(ids_sample), by = c("id_no" = "value")
        )
    }
  }

  rewards <- function(i) {
    return(
      rbind(
        data.frame(
          "trial_block" = i, "type" = rep(12, 20),
          "hidden_reward" = sample(c(rep(0, 4), rep(1, 16)))
        ),
        data.frame(
          "trial_block" = i, "type" = rep(34, 20),
          "hidden_reward" = sample(c(rep(0, 6), rep(1, 14)))
        ),
        data.frame(
          "trial_block" = i, "type" = rep(56, 20),
          "hidden_reward" = sample(c(rep(0, 8), rep(1, 12)))
        )
      )
    )
  }
  # function that returns a random series of hidden rewards, generated the same
  # way as in the task (i.e., in each block, the better symbol is rewarded on
  # exactly 16 of the 20 "AB" trials, 14 of the "CD" trials, and 12 of the
  # "EF" trials)

  all_res <- data.frame()
  if (is.null(sample_size)) sample_size <- length(ids_sample)
  pb <- txtProgressBar(min = 0, max = sample_size, initial = 0, style = 3)

  for (id in seq_along(ids_sample)) {
    # make random sequence of trials, each condition balanced within blocks
    choice1 <- NULL
    for (i in 1:6) {
      choice1 <- c(choice1, sample(rep(c("A", "C", "E"), each = 20), 60))
    }
    choice2 <- ifelse(choice1 == "A", "B", ifelse(choice1 == "C", "D", "F"))
    conds <- data.frame(choice1, choice2)

    if (test) {
      choice1_test <- tibble::as_tibble(sample(rep(c("A", "C", "D", "E", "F"),
                                                   times = c(20, 16, 4, 12, 8)))
      ) |>
        dplyr::rename(choice1_test = value) |>
        dplyr::mutate(trial_no = dplyr::row_number()) |>
        dplyr::arrange(choice1_test)

      choice2_test <- c(
        sample(rep(c("B", "C", "D", "E", "F"), each = 4)),
        sample(rep(c("B", "D", "E", "F"), each = 4)),
        rep("B", times = 4),
        sample(rep(c("B", "D", "F"), each = 4)),
        sample(rep(c("B", "D"), each = 4))
      )

      conds_test <- cbind(choice1_test, choice2_test) |>
        dplyr::arrange(trial_no) |>
        dplyr::select(-trial_no)
    }

    indiv_pars <- pars_df |>
      dplyr::filter(id_no == ids_sample[id])

    if (gain_loss) {
      alpha_pos <- stats::na.omit(indiv_pars$alpha_pos)
      alpha_neg <- stats::na.omit(indiv_pars$alpha_neg)
      beta      <- stats::na.omit(indiv_pars$beta)
    } else {
      alpha <- stats::na.omit(indiv_pars$alpha)
      beta  <- stats::na.omit(indiv_pars$beta)
    }

    hidden_rewards <- dplyr::bind_rows(lapply(1:6, FUN = rewards)) |>
      # 6 blocks
      dplyr::group_by(type) |>
      dplyr::mutate(trial_no_group = dplyr::row_number()) |>
      dplyr::ungroup()

    training_results <- data.frame(
      "id_no" = ids_sample[id],
      "type" = ifelse(conds[, 1] == "A", 12, ifelse(conds[, 1] == "C", 34, 56)),
      "choice" = rep(NA, 360),
      "reward" = rep(NA, 360)
    )

    training_results <- training_results |>
      dplyr::group_by(type) |>
      dplyr::mutate(trial_no_group = dplyr::row_number()) |>
      dplyr::ungroup() |>
      dplyr::left_join(hidden_rewards, by = c("type", "trial_no_group")) |>
      dplyr::mutate(trial_no = dplyr::row_number())
    rm(hidden_rewards)

    if (test) {
      training_results <- training_results |>
        dplyr::mutate(exp_part = "training", test_type = NA)
    }

    if (affect) {
      training_results <- training_results |>
        dplyr::mutate(
          trial_time = NA,
          question_type = as.vector(
            sapply(1:120, function(i) sample(l$question_order))
          ),
          question_response = NA
        )

      if (l$delta_model) {
        training_results <- training_results |>
          dplyr::group_by(question_type) |>
          dplyr::mutate(
            int_trials = trial_no - dplyr::lag(trial_no, default = 0)
          )
      }

      if (!sample_time) {
        time <- raw_df |>
          dplyr::filter(id_no == ids_sample[id]) |>
          dplyr::select(trial_no, block_time, trial_time)
        missing_times <- which(!(1:360 %in% time$trial_no))
      }

      ev_vec <- rep(0, 360)
      pe_vec <- rep(0, 360)
      qn_vec <- rep(0, 360)
      trial_time <- rep(0, 360)
      block_time <- rep(0, 360)

      w0  <- stats::na.omit(indiv_pars$w0)
      gamma    <- stats::na.omit(indiv_pars$gamma)
      if ("w1_o" %in% names(indiv_pars))
        w1_o <- stats::na.omit(indiv_pars$w1_o)
      else
        w1_o <- c(0, 0, 0)
      if ("w1_b" %in% names(indiv_pars))
        w1_b <- stats::na.omit(indiv_pars$w1_b)
      else
        w1_b <- c(0, 0, 0)

      if (!l$delta_model) {
        w2    <- stats::na.omit(indiv_pars$w2)
        w3    <- stats::na.omit(indiv_pars$w3)
      } else {
        w2 <- sapply(
          1:l$int_max,
          function(i) stats::na.omit(indiv_pars[indiv_pars$outc_lag == i, ]$w2)
        )
        w3 <- sapply(
          1:l$int_max,
          function(i) stats::na.omit(indiv_pars[indiv_pars$outc_lag == i, ]$w3)
        )

        int_trials <- training_results$int_trials
      }
    }

    # Initial Q values
    Q <- data.frame("A" = 0, "B" = 0, "C" = 0, "D" = 0, "E" = 0, "F" = 0)

    for (i in 1:360) {

      if (affect) {
        if (!sample_time && i %in% missing_times) next
      }

      # probability to choose "correct" stimulus
      p_t <-
        (exp(Q[conds[i, 1]] * beta)) /
        (exp(Q[conds[i, 1]] * beta) + (exp(Q[conds[i, 2]] * beta)))
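      # equivalently: P(choose option 1) = 1 / (1 + exp(-beta * (Q1 - Q2)))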

      # make choice
      choice <- sample(c(1, 0), 1, prob = c(p_t, 1 - p_t))
      choice_idx <- ifelse(choice == 1, 1, 2)

      # rewarded?
      if (affect) {
        reward <- ifelse(choice == training_results$hidden_reward[i], 1, -1)
      } else {
        reward <- ifelse(choice == training_results$hidden_reward[i], 1, 0)
      }
      # i.e. if they choose "correctly" and hidden reward == 1 or
      # incorrectly but hidden reward == 0
      # recoded as -1 for affect models to allow for negative EVs

      ev <- Q[conds[i, choice_idx]]
      pe <- reward - Q[conds[i, choice_idx]]

      # update Q values
      if (gain_loss) { # dual learning rate model
        if (pe >= 0) { # i.e., a better-than-expected (rewarded) outcome
          Q[conds[i, choice_idx]] <- ev + alpha_pos * pe
        } else {
          Q[conds[i, choice_idx]] <- ev + alpha_neg * pe
        }
      } else {
        Q[conds[i, choice_idx]] <- ev + alpha * pe
      }
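      # i.e., the delta rule applied to the chosen option only:
      #   Q_chosen <- Q_chosen + alpha * (reward - Q_chosen),
      # with alpha_pos/alpha_neg selected by the sign of the prediction error
      # in the dual learning rate model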

      training_results$choice[i] <- choice
      training_results$reward[i] <- reward

      if (affect) {
        ev_vec[i] <- ev
        pe_vec[i] <- pe

        if (i > 1 && sample_time) {
          if ((i - 1) %% 60 != 0) {
            elapsed <- 5 + rgamma(1, shape = 3, scale = 2) # after trial
            trial_time[i] <- trial_time[i - 1] + (elapsed / 3600)
            block_time[i] <- 0
          } else {
            elapsed <- 5 + rgamma(1, shape = 5, scale = 5) # after block
            trial_time[i] <- trial_time[i - 1] + (elapsed / 3600)
            block_time[i] <- block_time[i - 1] + (elapsed / 3600)
          }
        } else if (!sample_time) {
          trial_time[i] <- time[time$trial_no == i, ]$trial_time / 60
          block_time[i] <- time[time$trial_no == i, ]$block_time / 60
        }

        q <- grep(training_results$question_type[i], l$question_order)
        qn_vec[i] <- q

        rating <-
          w0[q] +
          w1_o[q] * trial_time[i] +
          w1_b[q] * block_time[i]

        if (!l$delta_model) {
          rating <- rating +
            w2[q] * sum(
              sapply(1:i, function(j) gamma[q]^(i - j) * ev_vec[[j]])
            ) +
            w3[q] * sum(
              sapply(1:i, function(j) gamma[q]^(i - j) * pe_vec[[j]])
            )
        } else {
          for (j in 1:int_trials[i]) {
            t1 <- i + 1 # to include the current trial
            rating <- rating +
              w2[[qn_vec[[t1 - j]], j]] * ev_vec[[t1 - j]] +
              w3[[qn_vec[[t1 - j]], j]] * pe_vec[[t1 - j]]
          }
        }
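        # in the decayed-sum model above, the latent rating is
        #   w0 + w1_o * overall_time + w1_b * block_time
        #     + w2 * sum_j gamma^(i - j) * EV_j + w3 * sum_j gamma^(i - j) * PE_j
        # (the delta model instead applies lag-specific w2/w3 weights to the
        # outcomes since the last rating of the same question); the rating is
        # then mapped onto the 0-100 scale via plogis() below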
        training_results$question_response[i] <- plogis(rating) * 100
        training_results$trial_time[i]        <- trial_time[i] * 60 ## in mins
      }
    }

    training_results <- training_results |> tidyr::drop_na(choice)
    all_res <- dplyr::bind_rows(all_res, training_results)

    if (test) {
      test_results <- data.frame(
        "id_no" = ids_sample[id],
        "type" = rep(NA, 60),
        "choice" = rep(NA, 60),
        "reward" = rep(NA, 60),
        "hidden_reward" = rep(NA, 60),
        "exp_part" = rep("test", 60),
        "test_type" = rep(NA, 60)
      )

      for (j in 1:60) {
        p_t_test <- (exp(Q[conds_test[j, 1]] * beta)) /
          (exp(Q[conds_test[j, 1]] * beta) + (exp(Q[conds_test[j, 2]] * beta)))
        # probability to choose "correct" stimulus
        choice_test <- sample(c(1, 0), 1, prob = c(p_t_test, 1 - p_t_test))
        # make choice

        type <- as.numeric(
          paste0(
            match(tolower(conds_test[j, 1]), letters[1:6]),
            match(tolower(conds_test[j, 2]), letters[1:6])
          )
        )

        if (grepl("12|34|56", type)) test_type <- "training"
        else if (type / 10 < 2) test_type <- "chooseA"
        else if (type %% 10 == 2) test_type <- "avoidB"
        else test_type <- "novel"

        test_results$type[j] <- type
        test_results$choice[j] <- choice_test
        test_results$test_type[j] <- test_type
      }

      test_results <- test_results |>
        dplyr::group_by(type) |>
        dplyr::mutate(trial_no_group = dplyr::row_number()) |>
        dplyr::ungroup()
      all_res <- dplyr::bind_rows(all_res, test_results)
    }
    setTxtProgressBar(pb, id)
  }

  all_res <- tibble::as_tibble(all_res) |>
    dplyr::select(-hidden_reward)

  if (affect) {
    if (!sample_time) {
      pars_df <- pars_df |>
        dplyr::inner_join(
          raw_df |> dplyr::distinct(subjID, id_no), by = "id_no"
        ) |>
        dplyr::mutate(
          id_no = as.integer(factor(id_no, levels = unique(all_res$id_no)))
        )
      # to match up with summary for plotting

      all_res <- all_res |>
        dplyr::left_join(raw_df |> dplyr::distinct(subjID, id_no), by = "id_no")
    } else {
      all_res <- all_res |> dplyr::mutate(subjID = id_no)
    }

    all_res <- all_res |>
      dplyr::rowwise() |>
      dplyr::mutate(trial_no_block = trial_no - (trial_block - 1) * 60) |>
      dplyr::mutate(
        question = ifelse(
          question_type == l$question_order[1], 1,
          ifelse(question_type == l$question_order[2], 2, 3)
        )
      ) |>
      dplyr::mutate(reward = ifelse(reward == 0, -1, reward)) |>
      dplyr::group_by(subjID, trial_block) |>
      dplyr::mutate(block_time = trial_time - min(trial_time)) |>
      dplyr::group_by(subjID, question_type) |>
      dplyr::mutate(
        trial_no_q = order(trial_no, decreasing = FALSE),
        qn_response_prev = dplyr::lag(question_response),
        trials_elapsed = trial_no - dplyr::lag(trial_no, default = 0)
      ) |>
      dplyr::ungroup() |>
      dplyr::arrange(id_no)
  } else if (!is.null(raw_df)) {
    all_res <- all_res |>
      dplyr::left_join(raw_df |> dplyr::distinct(subjID, id_no), by = "id_no")
    pars_df <- pars_df |>
      dplyr::inner_join(raw_df |> dplyr::distinct(subjID, id_no), by = "id_no")
  } else {
    all_res <- all_res |>
      dplyr::mutate(subjID = id_no)
    pars_df <- pars_df |>
      dplyr::inner_join(
        tibble::as_tibble(ids_sample), by = c("id_no" = "value")
      ) |>
      dplyr::mutate(subjID = id_no)
  }

  ret <- list()
  ret$sim <- data.table::as.data.table(all_res)
  ret$pars <- data.table::as.data.table(pars_df)

  return(ret)
}