R/model_ezdm.R

Defines functions posterior_predict_ezdm_4par log_lik_ezdm_4par configure_model.ezdm_4par posterior_predict_ezdm_3par log_lik_ezdm_3par configure_model.ezdm_3par bmf2bf.ezdm_4par bmf2bf.ezdm_3par check_data.ezdm ezdm .model_ezdm

Documented in ezdm

############################################################################# !
# MODELS                                                                 ####
############################################################################# !

# Version-specific specification of the EZ-diffusion model.
#
# For each supported version ("3par", "4par") this table holds the parameter
# descriptions, link functions, fixed parameters, default priors and ranges
# for initial values that `.model_ezdm()` copies into the model object. The two
# versions share every entry except the relative starting point `zr`, which
# only the 4par version has. The shared specification is therefore written once
# and subset per version; `mu` is always the first entry of the init ranges.
.ezdm_version_table <- local({
  par_descriptions <- list(
    drift = "Drift rate = Average rate of evidence accumulation of the decision processes",
    bound = "Boundary separation = Distance between the decision boundaries that need to be reached",
    ndt = "Non-decision time = Additional time required beyond the evidence accumulation process",
    zr = "Relative starting point = Starting point between the decision thresholds relative to the upper bound.",
    s = "The diffusion constant, that is the standard deviation of the Gaussian noise during sampling"
  )

  par_links <- list(
    drift = "identity", bound = "log", ndt = "log", zr = "logit", s = "log"
  )

  par_priors <- list(
    drift = list(main = "cauchy(0,1)", effects = "normal(0,0.5)"),
    bound = list(main = "normal(0,0.5)", effects = "normal(0,0.5)"),
    ndt = list(main = "normal(-1.5,0.5)", effects = "normal(0,0.3)"),
    zr = list(main = "normal(0,0.5)", effects = "normal(0,0.3)"),
    s = list(main = "normal(0,1)", effects = "normal(0,0.3)")
  )

  par_inits <- list(
    mu = c(0, 1),
    drift = c(-1, 1),
    bound = c(1, 2),
    ndt = c(0.25, 0.5),
    zr = c(0.45, 0.55),
    s = c(0.99, 1.01)
  )

  # Assemble the specification of one version from its parameter names
  # (in the order in which they should appear in every sub-list).
  make_spec <- function(pars) {
    list(
      parameters = par_descriptions[pars],
      links = par_links[pars],
      # NOTE(review): presumably on the link scale, i.e. s = exp(0) = 1 - confirm
      fixed_parameters = list(s = 0, mu = 0),
      priors = par_priors[pars],
      init_ranges = par_inits[c("mu", pars)]
    )
  }

  list(
    "3par" = make_spec(c("drift", "bound", "ndt", "s")),
    "4par" = make_spec(c("drift", "bound", "ndt", "zr", "s"))
  )
})


# Internal constructor of the EZ-diffusion model object (see `ezdm()`).
#
# Looks up the parameter descriptions, links, fixed parameters, default priors
# and init ranges for `version` in `.ezdm_version_table` and appends
# "ezdm_<version>" to the class vector, so that the version-specific S3
# methods (e.g. bmf2bf.ezdm_3par, configure_model.ezdm_4par) are dispatched.
#
# @param mean_rt,var_rt,n_upper Names of the response variables in the data.
# @param n_trials Name of the variable holding the number of trials.
# @param version "3par" or "4par"; any other value is an error.
# @param links Optional named list of links overriding the version defaults.
# @param call The call of the user-facing constructor, stored as attribute.
# @param ... Ignored here; accepted for internal testing.
# @return An object of class c("bmmodel", "ezdm", "ezdm_<version>").
.model_ezdm <- function(mean_rt = NULL, var_rt = NULL, n_upper = NULL, n_trials = NULL, version = "3par", links = NULL, call = NULL, ...) {
  # Fail fast on an unsupported version: the table lookups below would
  # otherwise silently yield NULL parameters, links and priors, and the
  # problem would only surface later in check_data().
  supported_versions <- names(.ezdm_version_table)
  if (!is.character(version) || length(version) != 1L || !version %in% supported_versions) {
    stop2("Unknown ezdm version: {toString(version)}. Supported versions are '3par' and '4par'.")
  }
  spec <- .ezdm_version_table[[version]]

  out <- structure(
    list(
      resp_vars = nlist(mean_rt, var_rt, n_upper),
      other_vars = nlist(n_trials),
      domain = "Decision Making / Response times",
      task = "Choice Reaction Time tasks",
      name = "EZ-Diffusion Model",
      citation = glue(
        "Wagenmakers, E.-J., Van Der Maas, H. L. J., & Grasman, R. P. P. P. (2007). An EZ-diffusion model for response time and accuracy. Psychonomic Bulletin & Review, 14(1), 3-22. https://doi.org/10/fk447c", "\n",
        "- Ch\u00e1vez De la Pe\u00f1a, A. F., & Vandekerckhove, J. (2025). An EZ Bayesian hierarchical drift diffusion model for response time and accuracy. Psychonomic Bulletin & Review. https://doi.org/10.3758/s13423-025-02729-y"
      ),
      version = version,
      requirements = glue(
        "Provide aggregated statistics for each subject and condition that model parameters should vary over:", "\n\n",
        "  - Mean reaction times (mean_rt) in seconds", "\n",
        # the variance of RTs measured in seconds has units of seconds squared
        "  - Variance of reaction times (var_rt) in seconds squared", "\n",
        "  - Number of responses to the upper decision threshold (n_upper)", "\n",
        "  - Total number of trials used to calculate aggregated statistics (n_trials)"
      ),
      parameters = spec[["parameters"]],
      links = spec[["links"]],
      fixed_parameters = spec[["fixed_parameters"]],
      default_priors = spec[["priors"]],
      init_ranges = spec[["init_ranges"]],
      void_mu = TRUE
    ),
    # version-specific class last, so generic ezdm methods and the
    # ezdm_<version> methods can both be found during dispatch
    class = c("bmmodel", "ezdm", paste0("ezdm_", version)),
    call = call
  )
  # user-supplied links override the version defaults
  out$links[names(links)] <- links
  out
}
# User-facing alias of .model_ezdm(). The title and details sections of the
# documentation below are generated automatically from `.model_ezdm()$name`
# and `model_info(.model_ezdm(version = "4par"))`.

#' @title `r .model_ezdm()$name`
#' @name ezdm
#' @details `r model_info(.model_ezdm(version = "4par"))`
#' @param mean_rt The name of the variable coding the mean reaction time in
#'   seconds in the data. For the 4par version, a vector of two variable names
#'   coding the mean reaction times of responses at the upper and at the lower
#'   boundary: `c(mean_rt_upper, mean_rt_lower)`.
#' @param var_rt The name of the variable coding the variance of the reaction
#'   time in the data (in seconds squared, i.e. computed from reaction times in
#'   seconds). For the 4par version, a vector of two variable names:
#'   `c(var_rt_upper, var_rt_lower)`.
#' @param n_upper The name of the variable coding the number of responses that
#'   hit the upper response threshold (typically the number of correct
#'   responses) in the data.
#' @param n_trials The name of the variable coding the number of trials that
#'   were used to calculate the aggregated statistics.
#' @param links A named list of links for the parameters. Entries override the
#'   default links of the chosen version.
#' @param version A character label for the version of the model. There is a
#'   three-parameter version (version = "3par") of the `ezdm` that fixes the
#'   relative starting point `zr` to 0.5, and a four-parameter version
#'   (version = "4par") that estimates the starting point freely.
#' @param ... used internally for testing, ignore it
#' @return An object of class `bmmodel`
#' @keywords bmmodel
#' @export
#' @examples
#' \dontrun{
#' # Minimal parameter recovery example with 3-parameter EZDM
#'
#' # Simulate data from known parameters
#' set.seed(123)
#' sim_data <- rezdm(
#'   n = 10,
#'   n_trials = 100,
#'   drift = 2,
#'   bound = 1.5,
#'   ndt = 0.3,
#'   version = "3par"
#' )
#'
#' # Add subject ID
#' sim_data$id <- 1:10
#'
#' # Specify model
#' model <- ezdm(
#'   mean_rt = "mean_rt",
#'   var_rt = "var_rt",
#'   n_upper = "n_upper",
#'   n_trials = "n_trials",
#'   version = "3par"
#' )
#'
#' # Specify formula with random effects
#' formula <- bmf(
#'   drift ~ 1 + (1 | id),
#'   bound ~ 1 + (1 | id),
#'   ndt ~ 1
#' )
#'
#' # Fit model (using cmdstanr backend)
#' fit <- bmm(
#'   formula = formula,
#'   data = sim_data,
#'   model = model,
#'   backend = "cmdstanr",
#'   cores = 4,
#'   chains = 4,
#'   iter = 2000,
#'   warmup = 1000
#' )
#'
#' # Check parameter recovery
#' summary(fit)
#'
#' # Extract population-level effects
#' # True values: drift = 2, bound = 1.5, ndt = 0.3
#' # drift uses an identity link, whereas bound and ndt use a log link,
#' # so only the bound and ndt estimates are exponentiated
#' fixed_effects <- brms::fixef(fit)
#' fixed_effects["drift_Intercept", ]
#' exp(fixed_effects[c("bound_Intercept", "ndt_Intercept"), c("Estimate", "Q2.5", "Q97.5")])
#' }
ezdm <- function(mean_rt, var_rt, n_upper, n_trials, links = NULL, version = "3par", ...) {
  call <- match.call()
  stop_missing_args()
  .model_ezdm(
    mean_rt = mean_rt, var_rt = var_rt, n_upper = n_upper, n_trials = n_trials,
    links = links, version = version, call = call, ...
  )
}

############################################################################# !
# CHECK_DATA S3 methods                                                  ####
############################################################################# !

#' @export
check_data.ezdm <- function(model, data, formula) {
  # Validate the aggregated data for the EZ-diffusion model: the number of
  # variable names per version, the presence and numeric type of all required
  # variables, and the plausibility of their values. Implausible but usable
  # values (RTs that look like milliseconds, non-integer counts) only warn;
  # all other violations are errors. Afterwards the next check_data() method
  # is called.

  # retrieve required variable names
  mean_rt <- model$resp_vars$mean_rt
  var_rt <- model$resp_vars$var_rt
  n_upper <- model$resp_vars$n_upper
  n_trials <- model$other_vars$n_trials

  # validate length of mean_rt and var_rt dependent on version
  if (model$version == "3par") {
    stopif(length(mean_rt) != 1, "mean_rt must be a single variable name.")
    stopif(length(var_rt) != 1, "var_rt must be a single variable name.")
  } else if (model$version == "4par") {
    stopif(length(mean_rt) != 2, "mean_rt must be a vector of two variable names: c(mean_rt_upper, mean_rt_lower).")
    stopif(length(var_rt) != 2, "var_rt must be a vector of two variable names: c(var_rt_upper, var_rt_lower).")
  } else {
    stop2("Unknown ezdm version: {model$version}. Supported versions are '3par' and '4par'.")
  }

  # check that all required variables exist in data
  required_vars <- c(mean_rt, var_rt, n_upper, n_trials)
  missing_vars <- setdiff(required_vars, colnames(data))
  stopif(
    length(missing_vars),
    "The following required variables are missing from the data: {collapse_comma(missing_vars)}"
  )

  # check that all required variables are numeric. Without this, the range
  # checks below would compare character columns as strings (e.g. "9" > 10 is
  # TRUE) and would return NA for factors, which `na.rm = TRUE` then silently
  # ignores.
  is_numeric_var <- vapply(
    unique(required_vars),
    function(var) is.numeric(data[[var]]),
    logical(1)
  )
  non_numeric_vars <- names(is_numeric_var)[!is_numeric_var]
  stopif(
    length(non_numeric_vars),
    "The following required variables must be numeric: {collapse_comma(non_numeric_vars)}"
  )

  # check that mean RT values are plausible (warn if likely in milliseconds)
  # typical RTs in seconds are 0.2-3s; values > 10 suggest milliseconds
  mean_rt_values <- unlist(data[mean_rt])
  warnif(
    any(mean_rt_values > 10, na.rm = TRUE),
    "Some mean RT values are greater than 10. If your reaction times are in
    milliseconds, please convert them to seconds before fitting the model.
    The model assumes reaction times are measured in seconds."
  )

  # check that mean RT values are positive
  stopif(
    any(mean_rt_values <= 0, na.rm = TRUE),
    "Mean RT values must be positive. Found non-positive values in the data."
  )

  # check that variance values are positive
  var_rt_values <- unlist(data[var_rt])
  stopif(
    any(var_rt_values <= 0, na.rm = TRUE),
    "Variance of RT must be positive. Found non-positive values in the data."
  )

  # check that n_trials is larger than two (error) and a whole number (warning)
  n_trials_values <- data[[n_trials]]
  stopif(
    any(n_trials_values <= 2, na.rm = TRUE),
    "Number of trials (n_trials) must be larger than two."
  )
  warnif(
    any(n_trials_values != round(n_trials_values), na.rm = TRUE),
    "Number of trials (n_trials) should be whole numbers. Found non-integer values."
  )

  # check that n_upper is non-negative (zero is allowed; error) and a whole
  # number (warning)
  n_upper_values <- data[[n_upper]]
  stopif(
    any(n_upper_values < 0, na.rm = TRUE),
    "Number of upper boundary responses (n_upper) needs to be positive."
  )
  warnif(
    any(n_upper_values != round(n_upper_values), na.rm = TRUE),
    "Number of upper boundary responses (n_upper) should be whole numbers."
  )

  # check that n_upper <= n_trials (proportion correct between 0 and 1)
  stopif(
    any(n_upper_values > n_trials_values, na.rm = TRUE),
    "Number of upper boundary responses (n_upper) cannot exceed total trials (n_trials)."
  )

  NextMethod("check_data")
}

############################################################################# !
# Convert bmmformula to brmsformla methods                               ####
############################################################################# !

#' @export
bmf2bf.ezdm_3par <- function(model, formula) {
  # Build the brms response formula of the 3par EZDM. The mean RT is the
  # response; the RT variance is passed as a `vreal`, the number of
  # upper-boundary responses as a `vint` and the number of trials as a
  # `trials` addition term. The `formula` argument is not used here.
  vars <- model$resp_vars
  response <- paste0(
    vars$mean_rt,
    " | vreal(", vars$var_rt, ")",
    " + vint(", vars$n_upper, ")",
    " + trials(", model$other_vars$n_trials, ")"
  )
  brms::bf(paste0(response, " ~ 1"))
}

#' @export
bmf2bf.ezdm_4par <- function(model, formula) {
  # Build the brms response formula of the 4par EZDM. The mean RT of
  # upper-boundary responses is the response. The mean RT of lower-boundary
  # responses and both RT variances are passed through `vreal`, and the
  # number of upper-boundary responses and the number of trials through
  # `vint`. The `formula` argument is not used here.
  rt_means <- model$resp_vars$mean_rt
  rt_variances <- model$resp_vars$var_rt
  n_upper <- model$resp_vars$n_upper
  n_trials <- model$other_vars$n_trials

  rt_upper <- rt_means[1]
  rt_lower <- rt_means[2]
  var_upper <- rt_variances[1]
  var_lower <- rt_variances[2]

  brms::bf(glue::glue(
    "{rt_upper} | vreal({rt_lower}, {var_upper}, {var_lower}) + vint({n_upper}, {n_trials}) ~ 1"
  ))
}

############################################################################# !
# CONFIGURE_MODEL S3 METHODS                                             ####
############################################################################# !

#' @export
configure_model.ezdm_3par <- function(model, data, formula) {
  # Build the brms configuration of the 3par EZDM. It converts the bmm formula
  # into a brms formula, attaches the custom "ezdm_3par" family and loads the
  # Stan functions from ezdm_3par_functions.stan. `data` is returned unchanged.
  brms_formula <- bmf2bf(model, formula)
  lnk <- model$links

  # `mu` uses the identity link. bound, ndt and s have a lower bound of 0.
  # The RT variance, upper-boundary count and number of trials of
  # observation n are handed to the Stan likelihood through `vars`.
  brms_formula$family <- brms::custom_family(
    "ezdm_3par",
    dpars = c("mu", "drift", "bound", "ndt", "s"),
    links = c("identity", lnk$drift, lnk$bound, lnk$ndt, lnk$s),
    lb = c(NA, NA, 0, 0, 0),
    ub = rep(NA, 5),
    type = "real",
    log_lik = log_lik_ezdm_3par,
    posterior_predict = posterior_predict_ezdm_3par,
    loop = TRUE,
    vars = c("vreal1[n]", "vint1[n]", "trials[n]")
  )

  # Stan functions of the family go into the `functions` block
  stan_file <- file.path(system.file("stan_chunks", package = "bmm"), "ezdm_3par_functions.stan")
  stanvars <- brms::stanvar(scode = read_lines2(stan_file), block = "functions")

  list(formula = brms_formula, data = data, stanvars = stanvars)
}


# Pointwise log-likelihood of observation `i` for the 3par EZDM, registered as
# the `log_lik` function of the custom family. The observed summary statistics
# come from `prep$data` (Y = mean_rt, vreal1 = var_rt, vint1 = n_upper,
# trials = n_trials). The parameter draws come from brms::get_dpar(), and
# dezdm() returns the density on the log scale.
log_lik_ezdm_3par <- function(i, prep) {
  draws_of <- function(par) brms::get_dpar(prep, par, i = i)
  obs <- prep$data
  dezdm(
    mean_rt = obs$Y[i],
    var_rt = obs$vreal1[i],
    n_upper = obs$vint1[i],
    n_trials = obs$trials[i],
    drift = draws_of("drift"),
    bound = draws_of("bound"),
    ndt = draws_of("ndt"),
    s = draws_of("s"),
    version = "3par",
    log = TRUE
  )
}

# Posterior predictions of observation `i` for the 3par EZDM, registered as
# the `posterior_predict` function of the custom family (used e.g. by
# pp_check). It simulates one set of summary statistics per posterior draw
# with rezdm() and returns the requested dependent variable:
#   dv = "mean_rt" (default; the primary response Y), "var_rt" or "n_upper".
# Usage: posterior_predict(fit, dv = "var_rt")
posterior_predict_ezdm_3par <- function(i, prep, ..., dv = c("mean_rt", "var_rt", "n_upper")) {
  dv <- match.arg(dv)

  # Simulate exactly one value per posterior draw. prep$ndraws is used rather
  # than the length of the drift draws, which is 1 when drift is fixed to a
  # constant instead of being estimated.
  rezdm(
    n = prep$ndraws,
    n_trials = prep$data$trials[i],
    drift = brms::get_dpar(prep, "drift", i = i),
    bound = brms::get_dpar(prep, "bound", i = i),
    ndt = brms::get_dpar(prep, "ndt", i = i),
    s = brms::get_dpar(prep, "s", i = i),
    version = "3par"
  )[[dv]]
}

#' @export
configure_model.ezdm_4par <- function(model, data, formula) {
  # Build the brms configuration of the 4par EZDM. It converts the bmm formula
  # into a brms formula, attaches the custom "ezdm_4par" family and loads the
  # Stan functions from ezdm_4par_functions.stan. `data` is returned unchanged.

  # construct brms formula from the bmm formula
  formula <- bmf2bf(model, formula)
  links <- model$links

  # construct the family & add to formula object
  formula$family <- brms::custom_family(
    "ezdm_4par",
    dpars = c("mu", "drift", "bound", "ndt", "zr", "s"),
    links = c("identity", links$drift, links$bound, links$ndt, links$zr, links$s),
    lb = c(NA, NA, 0, 0, 0, 0), # lower bound 0 for bound, ndt, zr and s
    ub = c(NA, NA, NA, NA, 1, NA), # zr is a relative starting point in [0, 1]
    type = "real", # the response (mean RT of upper-boundary responses) is continuous
    log_lik = log_lik_ezdm_4par,
    posterior_predict = posterior_predict_ezdm_4par,
    loop = TRUE, # evaluate the likelihood in a loop over observations (not vectorized)
    # additional data of observation n, in the order created by
    # bmf2bf.ezdm_4par: vreal1 = mean_rt_lower, vreal2 = var_rt_upper,
    # vreal3 = var_rt_lower, vint1 = n_upper, vint2 = n_trials
    vars = c("vreal1[n]", "vreal2[n]", "vreal3[n]", "vint1[n]", "vint2[n]")
  )

  # prepare initial stanvars to pass to brms, model formula and priors
  stan_file <- file.path(system.file("stan_chunks", package = "bmm"), "ezdm_4par_functions.stan")
  stan_functions <- read_lines2(stan_file)
  stanvars <- brms::stanvar(scode = stan_functions, block = "functions")

  # return the list
  nlist(formula, data, stanvars)
}


# Pointwise log-likelihood of observation `i` for the 4par EZDM, registered as
# the `log_lik` function of the custom family. The data columns follow
# bmf2bf.ezdm_4par: Y = mean_rt_upper, vreal1 = mean_rt_lower,
# vreal2 = var_rt_upper, vreal3 = var_rt_lower, vint1 = n_upper,
# vint2 = n_trials. dezdm() evaluates the log density for all posterior draws.
log_lik_ezdm_4par <- function(i, prep) {
  obs <- prep$data
  draws_of <- function(par) brms::get_dpar(prep, par, i = i)

  # upper-boundary statistic first, then the lower-boundary one
  mean_rts <- c(obs$Y[i], obs$vreal1[i])
  var_rts <- c(obs$vreal2[i], obs$vreal3[i])

  dezdm(
    mean_rt = mean_rts,
    var_rt = var_rts,
    n_upper = obs$vint1[i],
    n_trials = obs$vint2[i],
    drift = draws_of("drift"),
    bound = draws_of("bound"),
    ndt = draws_of("ndt"),
    zr = draws_of("zr"),
    s = draws_of("s"),
    version = "4par",
    log = TRUE
  )
}

# Posterior predictions of observation `i` for the 4par EZDM, registered as
# the `posterior_predict` function of the custom family (used e.g. by
# pp_check). It simulates one set of summary statistics per posterior draw
# with rezdm() and returns the requested dependent variable:
#   "mean_rt_upper" (default; the primary response Y), "mean_rt_lower",
#   "var_rt_upper", "var_rt_lower" or "n_upper".
# Usage: posterior_predict(fit, dv = "var_rt_upper")
posterior_predict_ezdm_4par <- function(i, prep, ..., dv = c(
                                          "mean_rt_upper", "mean_rt_lower",
                                          "var_rt_upper", "var_rt_lower", "n_upper"
                                        )) {
  dv <- match.arg(dv)

  # Simulate exactly one value per posterior draw. prep$ndraws is used rather
  # than the length of the drift draws, which is 1 when drift is fixed to a
  # constant instead of being estimated. n_trials is stored in vint2
  # (see bmf2bf.ezdm_4par).
  rezdm(
    n = prep$ndraws,
    n_trials = prep$data$vint2[i],
    drift = brms::get_dpar(prep, "drift", i = i),
    bound = brms::get_dpar(prep, "bound", i = i),
    ndt = brms::get_dpar(prep, "ndt", i = i),
    zr = brms::get_dpar(prep, "zr", i = i),
    s = brms::get_dpar(prep, "s", i = i),
    version = "4par"
  )[[dv]]
}

Try the bmm package in your browser

Any scripts or data that you put into this service are public.

bmm documentation built on March 30, 2026, 5:08 p.m.