R/estimate_2_RNN.R

Defines functions estimate_2_RNN

Documented in estimate_2_RNN

#' @title 
#' Estimation Method: Recurrent Neural Network (RNN)
#' @name estimate_2_RNN
#' @description 
#'  This function takes a large set of simulated data to train an Recurrent 
#'    Neural Network (RNN) model and then uses the trained model to 
#'    estimate optimal parameters for the target data.
#'
#' @param data 
#'  A data frame in which each row represents a single trial,
#'    see \link[multiRL]{data} 
#' @param colnames 
#'  Column names in the data frame,
#'    see \link[multiRL]{colnames}
#' @param behrule 
#'  The agent's implicitly formed internal rule,
#'    see \link[multiRL]{behrule}
#' @param ids 
#'  The Subject ID of the participant whose data needs to be fitted.
#' @param models 
#'  Reinforcement Learning Models
#' @param funcs 
#'  The functions forming the reinforcement learning model,
#'    see \link[multiRL]{funcs}
#' @param priors 
#'  Prior probability density function of the free parameters,
#'    see \link[multiRL]{priors}
#' @param settings 
#'  Other model settings, 
#'    see \link[multiRL]{settings}
#' @param lowers
#'  Lower bound of free parameters in each model.
#' @param uppers
#'  Upper bound of free parameters in each model.
#' @param control 
#'  Settings manage various aspects of the iterative process,
#'    see \link[multiRL]{control}
#' @param ... 
#'  Additional arguments passed to internal functions.
#'
#' @returns An S3 object of class \code{DataFrame} containing, for each model, 
#'  the estimated optimal parameters and associated model fit metrics.
#' 
estimate_2_RNN <- function(
    data,
    colnames,
    behrule,
    ids = NULL,
    
    models,
    funcs = NULL,
    priors,
    settings,
    
    lowers,
    uppers,
    control,
    ...
){
  
################################ [default] #####################################
  
  # 默认列名
  default <- list(
    subid = "Subject", 
    block = "Block", 
    trial = "Trial",
    object = NA_character_, 
    reward = NA_character_, 
    action = "Action"
  )
  colnames <- utils::modifyList(x = default, val = colnames)
  
  # 默认方程
  if (is.null(funcs)) {funcs <- rep(list(list()), length(models))}
  for (i in 1:length(funcs)) {
    default <- list(
      lrng_func = multiRL::func_alpha,
      prob_func = multiRL::func_beta,
      util_func = multiRL::func_gamma,
      bias_func = multiRL::func_delta,
      expl_func = multiRL::func_epsilon,
      dcay_func = multiRL::func_zeta
    )
    funcs[[i]] <- utils::modifyList(x = default, val = funcs[[i]])
  }
  
  # 默认设置
  settings <- .restructure_settings(x = settings, n = length(models))
  for (i in 1:length(settings)) {
    default <- list(
      name = paste0("Unknown_", i),
      policy = "on"
    )
    settings[[i]] <- utils::modifyList(x = default, val = settings[[i]])
  }
  
  # 强制设置
  for (i in 1:length(settings)) {
    settings[[i]]$mode <- "fitting"
    settings[[i]]$estimate <- "RNN"
    settings[[i]]$policy <- "on"
  }
  
  # 转换先验
  priors <- .convert_priors(priors = priors, to = "rfunc")
  
  # 默认控制
  default = list(
    # General
    seed = 123,
    core = 1,
    sample = 100,
    dash = 1e-5,
    # SBI
    train = 1000,
    scope = "individual",
    # RNN
    layer = "GRU",
    loss = "MSE",
    info = c(colnames$object, colnames$action),
    units = 128,
    dropout = 0,
    L = NA_character_,
    penalty = 1e-5,
    batch_size = 10,
    epochs = 100,
    keras3 = FALSE,
    backend = "tensorflow",
    check = TRUE
  )
  control <- utils::modifyList(x = default, val = control, keep.null = TRUE)
  # 解放control中的设定, 变成全局变量
  list2env(control, envir = environment())

################################# [check] ######################################
  
  if (check == TRUE) {.check_tensorflow()}

############################ [aotu-detect data] ################################
  
  # 自动探测数据
  suppressMessages({dfinfo <- .detect_data(data)})
  # 如果没有输入被试序号的列名. 则自动探测
  if ("subid" %in% names(colnames)) {
    subid <- colnames[["subid"]]
  } else {
    subid <- dfinfo$sub_col_name
  }
  # 如果没有输入要拟合的被试序号, 就拟合所有的被试
  if (is.null(ids)){ids <- dfinfo$all_ids}
  
################################## [RNN] #######################################

  result.RNN <- list()
  
  for (i in 1:length(models)) {
    
    model_name <- settings[[i]]$name
    
    message(paste0(
      "Fitting ", model_name,"\n"
    ))
    
    # 定义进度条
    progressr::handlers(progressr::handler_txtprogressbar)
    # 进度条启动
    progressr::with_progress({
      # 进度条参照
      p <- progressr::progressor(steps = length(ids))

      # RNN启动!
      suppressMessages({
        opt_params <- list()
        
        if ( scope == "shared" ) {
          
          # 只训练一个RNN
          if (keras3) {
              RNN <- engine_RNN3(
                data = data[data[, subid] == 1, ],
                behrule = behrule,
                colnames = colnames,
                funcs = funcs[[i]],
                settings = settings[[i]],
                priors = priors[[i]],
                model = models[[i]],
                control = control
              )
          } else {
              RNN <- engine_RNN(
                data = data[data[, subid] == 1, ],
                behrule = behrule,
                colnames = colnames,
                funcs = funcs[[i]],
                settings = settings[[i]],
                priors = priors[[i]],
                model = models[[i]],
                control = control
              )
          }

          for (j in ids) {
            
            sub_data <- data[data[, subid] == j, ]
            
            n_params <- length(priors[[i]])
            n_trials <- nrow(sub_data)
            
            # 预测真实数据对应的参数
            sub_matrix <- .df2matrix(df = sub_data)
            mat_cols <- colnames(sub_matrix)
            split_info <- unlist(lapply(info, function(col) {
              grep(paste0("^", col, "(_[0-9]+)?$"), mat_cols, value = TRUE)
            }))
            n_info <- length(split_info)
            
            X_sub <- array(NA, dim = c(1, n_trials, n_info))
            X_sub[1, , ] <- sub_matrix[, split_info, drop = FALSE]
            X_pred <- stats::predict(object = RNN, x = X_sub, verbose = 0)
            X_pred <- .name_rnnouts(
              X_pred = X_pred, loss = loss, param_names = names(priors[[i]])
            )
            X_pred <- .fix_params(
              params_df = X_pred, param_names = names(priors[[i]]),
              lower = lowers[[i]], upper = uppers[[i]], dash = dash
            )
            opt_params[[j]] <- X_pred
            p()
          }
          
        } else if ( scope == "individual" ) {
          
          for (j in ids) {
            
            sub_data <- data[data[, subid] == j, ]
            
            n_params <- length(priors[[i]])
            n_trials <- nrow(sub_data)
            
            # 为每个被试单独训练模型
            if (keras3) {
              RNN <- engine_RNN3(
                data = sub_data,
                behrule = behrule,
                colnames = colnames,
                funcs = funcs[[i]],
                settings = settings[[i]],
                priors = priors[[i]],
                model = models[[i]],
                control = control
              )
            } else {
              RNN <- engine_RNN(
                data = sub_data,
                behrule = behrule,
                colnames = colnames,
                funcs = funcs[[i]],
                settings = settings[[i]],
                priors = priors[[i]],
                model = models[[i]],
                control = control
              )
            }

            # 预测真实数据对应的参数
            sub_matrix <- .df2matrix(df = sub_data)
            mat_cols <- colnames(sub_matrix)
            split_info <- unlist(lapply(info, function(col) {
              grep(paste0("^", col, "(_[0-9]+)?$"), mat_cols, value = TRUE)
            }))
            n_info <- length(split_info)
            
            X_sub <- array(NA, dim = c(1, n_trials, n_info))
            X_sub[1, , ] <- sub_matrix[, split_info, drop = FALSE]
            X_pred <- stats::predict(object = RNN, x = X_sub, verbose = 0)
            X_pred <- .name_rnnouts(
              X_pred = X_pred, loss = loss, param_names = names(priors[[i]])
            )
            X_pred <- .fix_params(
              params_df = X_pred, param_names = names(priors[[i]]),
              lower = lowers[[i]], upper = uppers[[i]], dash = dash
            )
            opt_params[[j]] <- X_pred
            p()
          }
        }
        result.RNN[[i]] <- do.call(rbind, opt_params)
      })
    })
  }
  
  col_order <- c("fit_model", "Subject")
  
  for (i in 1:length(models)) {
    
    result.RNN[[i]] <- as.data.frame(result.RNN[[i]]) 
    # 新增两列作为序号
    result.RNN[[i]][["fit_model"]] <- settings[[i]]$name
    result.RNN[[i]][["Subject"]] <- ids
    # 找到原始列的名字
    remaining_cols <- setdiff(names(result.RNN[[i]]), col_order)
    # 序号列 + 数据列
    result.RNN[[i]] <- result.RNN[[i]][c(col_order, remaining_cols)]
    
  }
  
  result.RNN <- .rbind_fill(result.RNN)
  
  return(result.RNN)
}

Try the multiRL package in your browser

Any scripts or data that you put into this service are public.

multiRL documentation built on March 31, 2026, 5:06 p.m.