R/step4_rpl_e.R

Defines functions rpl_e

Documented in rpl_e

#' @title 
#' Step 4: Replaying the experiment with optimal parameters
#'
#' @description
#' After completing Step 3, in which \code{fit_p()} finds the optimal
#'  parameters for each subject and the result is saved locally as a CSV
#'  file, pass that result data frame to this function. It plugs each
#'  subject's optimal parameters back into the reinforcement learning model
#'  and replays the experiment, simulating how the "robot" (the model) would
#'  make its choices.
#'
#' You can then analyze the robot's data in the same way as human
#'  behavioral data. If a model's replayed data reproduce the experimental
#'  effects observed in human subjects, this is evidence that the model is
#'  a valid description of the underlying process.
#'
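#' @param data [data.frame]
#' 
#' The raw experimental data. It must include the following columns
#'  (typical column names in quotes):
#'  \itemize{
#'    \item \code{sub}: subject ID, \code{"Subject"}
#'    \item \code{time_line}: trial order, \code{"Block"} and \code{"Trial"}
#'    \item \code{L_choice}: option shown on the left, \code{"L_choice"}
#'    \item \code{R_choice}: option shown on the right, \code{"R_choice"}
#'    \item \code{L_reward}: reward of the left option, \code{"L_reward"}
#'    \item \code{R_reward}: reward of the right option, \code{"R_reward"}
#'    \item \code{sub_choose}: option chosen by the subject, \code{"Sub_Choose"}
#'  }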
#'   
#' @param id [CharacterVector]
#' 
#'  A vector of the subject ID(s) whose experiment should be replayed. Only
#'   the subjects listed in this vector are processed.
#'
#'  To replay all subjects, either set the argument explicitly as
#'   \code{id = unique(data$Subject)} or leave it at the default
#'   (\code{id = NULL}). Both approaches make the function replay the
#'   experiment for every unique subject in the dataset.
#'
#'  Avoid simple numeric sequences such as \code{id = 1:4}. They can cause
#'   errors when subject IDs are stored as strings (e.g., subject four is
#'   stored as "004") or are not numbered consecutively.
#'  
#'  default: \code{id = NULL}
#'  
#' @param result [data.frame]  
#' 
#' Output of the \code{fit_p()} function. Each row holds one subject's fit
#'  results for one model, identified by the \code{fit_model} column.
#'
#' @param model [Function]  
#' 
#' The model function used to replay the experiment with the fitted
#'  parameters (e.g., \code{binaryRL::TD}).
#'
#' @param model_name [string]  
#' 
#' A character string naming the model whose rows are taken from
#'  \code{result}; it is matched against the \code{fit_model} column.
#'
#' @param param_prefix [string]  
#' 
#' A prefix string identifying the parameter columns in \code{result}.
#'  Parameter columns that are empty (\code{NA}) for a subject are ignored.
#' 
#' default: \code{param_prefix = "param_"}
#' 
#' @param n_trials [integer]
#' 
#' Represents the total number of trials a single subject experienced
#'  in the experiment. If this parameter is kept at its default value
#'  of \code{NULL}, the program will automatically detect how many trials
#'  a subject experienced from the provided data. This information
#'  is primarily used for calculating model fit statistics such as
#'  AIC (Akaike Information Criterion) and BIC (Bayesian Information
#'  Criterion).
#'  
#'  default: \code{n_trials = NULL}
#'
#' @return 
#' A list in which each element is a data.frame holding one subject's
#'  replay results. Each data.frame contains the value-update history of
#'  each option, the learning rate (\code{eta}), the utility-function
#'  parameter (\code{gamma}), and other information used in each update.
#'  
#' @examples
#' \dontrun{
#' # Results of Step 3 (fit_p()), saved locally as a CSV file
#' result <- read.csv("../OUTPUT/result_comparison.csv")
#'
#' replays <- list()
#'
#' replays[[1]] <- dplyr::bind_rows(
#'   binaryRL::rpl_e(
#'     data = binaryRL::Mason_2024_G2,
#'     result = result,
#'     model = binaryRL::TD,
#'     model_name = "TD",
#'     param_prefix = "param_"
#'   )
#' )
#'
#' replays[[2]] <- dplyr::bind_rows(
#'   binaryRL::rpl_e(
#'     data = binaryRL::Mason_2024_G2,
#'     result = result,
#'     model = binaryRL::RSTD,
#'     model_name = "RSTD",
#'     param_prefix = "param_"
#'   )
#' )
#'
#' replays[[3]] <- dplyr::bind_rows(
#'   binaryRL::rpl_e(
#'     data = binaryRL::Mason_2024_G2,
#'     result = result,
#'     model = binaryRL::Utility,
#'     model_name = "Utility",
#'     param_prefix = "param_"
#'   )
#' )
#' }
#'  
rpl_e <- function(
  data, 
  id = NULL,
  result, 
  model,
  model_name,
  param_prefix = "param_",
  n_trials = NULL
) {
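  # Preparation: detect information about the data
  info <- suppressWarnings(suppressMessages(detect_information(data = data)))
  
  # Automatically get the name of the subject-ID column
  Subject <- info[["sub_col_name"]]
  
  # If no subjects are specified, replay the experiment for every subject
  # (provided that optimal parameters exist for that subject)
  if (is.null(id)) {id <- info[["all_ids"]]}
  
  # If the number of trials is not given, detect it automatically
  if (is.null(n_trials)) {n_trials <- info[["n_trials"]]}
  
  # `result` holds the fits of all models; keep only the rows of this model
  fit_model <- "fit_model"
  result <- result[result[[fit_model]] == model_name, ]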
  
  # Replay the experiment subject by subject
  res <- list()
  
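  for (i in seq_along(id)) {
    
    # Locate this subject's row in `result`: match on the subject column when
    # `result` has one named like the subject column in `data`; otherwise
    # fall back to row order, which assumes the rows of `result` follow `id`
    hit <- integer(0)
    if (isTRUE(Subject %in% names(result))) {
      hit <- which(as.character(result[[Subject]]) == as.character(id[i]))
    }
    row_i <- if (length(hit) > 0) hit[1] else i
    
    # Take the columns whose names start with `param_prefix` and drop empty
    # (NA) entries, e.g. unused columns of a model with fewer parameters
    is_param <- startsWith(names(result), param_prefix)
    params <- unlist(result[row_i, is_param, drop = FALSE])
    params <- params[!is.na(params)]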
    
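    # Create a temporary environment
    binaryRL.env <- new.env()
    
    # Populate it with the variables the model function reads as globals
    binaryRL.env$mode <- "replay"
    binaryRL.env$policy <- "on"
    
    binaryRL.env$estimate <- "MLE"
    binaryRL.env$priors <- NULL
    
    binaryRL.env$data <- data
    binaryRL.env$id <- id[i]
    binaryRL.env$n_params <- length(params)
    binaryRL.env$n_trials <- n_trials
    
    # Bind obj_func's enclosing environment to binaryRL.env, so the free
    # variables inside the model (mode, data, id, ...) are looked up there
    obj_func <- model
    environment(obj_func) <- binaryRL.env
    
    # Run the model; its first element is this subject's replay data.frame
    res[[i]] <- obj_func(params = params)[[1]]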
  }
  
  return(res)
}
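
# ---------------------------------------------------------------------------
# A minimal, illustrative sketch; it is not part of binaryRL. The helpers
# sketch_extract_params(), sketch_toy_model() and sketch_replay() are
# hypothetical and mimic, in miniature, what rpl_e() does for one subject:
#   1. keep the rows of `result` whose `fit_model` equals `model_name`;
#   2. take the columns whose names start with `param_prefix`, dropping NAs;
#   3. bind the model function's environment to a fresh environment holding
#      the settings (mode, id, n_params, n_trials), so the model reads them
#      as free variables instead of arguments.
# Assumption: the toy model uses a fixed, made-up log-likelihood and
# computes AIC and BIC with the standard formulas, taking n = n_trials:
#   AIC = 2 * k - 2 * logL,  BIC = k * log(n) - 2 * logL,  k = n_params.
# ---------------------------------------------------------------------------

# Hypothetical helper: parameter vector of `model_name` in a given row
sketch_extract_params <- function(result, model_name, row = 1,
                                  param_prefix = "param_") {
  rows <- result[result[["fit_model"]] == model_name, , drop = FALSE]
  is_param <- startsWith(names(rows), param_prefix)
  params <- unlist(rows[row, is_param, drop = FALSE])
  # When `result` mixes models, unused parameter columns are NA
  params[!is.na(params)]
}

# Hypothetical toy model: id, mode, n_params and n_trials are not
# arguments; they are looked up in the environment the function is bound to
sketch_toy_model <- function(params) {
  log_lik <- -10  # made-up value, for illustration only
  list(data.frame(
    id = id,
    mode = mode,
    AIC = 2 * n_params - 2 * log_lik,
    BIC = n_params * log(n_trials) - 2 * log_lik
  ))
}

# Hypothetical driver mirroring the body of the loop in rpl_e()
sketch_replay <- function(model, params, id, n_trials) {
  env <- new.env()
  env$mode <- "replay"
  env$id <- id
  env$n_params <- length(params)
  env$n_trials <- n_trials
  environment(model) <- env  # free variables now resolve in `env` first
  model(params = params)[[1]]
}

# Usage (not run):
# fits <- data.frame(fit_model = c("TD", "RSTD"), param_1 = c(0.3, 0.2),
#                    param_2 = c(2, 1), param_3 = c(NA, 0.5))
# p <- sketch_extract_params(fits, "TD")  # param_1 = 0.3, param_2 = 2
# sketch_replay(sketch_toy_model, p, id = "004", n_trials = 100)  # AIC = 24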

binaryRL documentation built on Aug. 21, 2025, 6:01 p.m.