R/inference_runs.R

Defines functions find_runs list_predictive_names list_parameter_names .get_run list_runs

Documented in find_runs list_parameter_names list_predictive_names list_runs

#' List runs from the Generable API
#'
#' List runs from the Generable API for a specific project.
#'
#' A run is generated by a model and a dataset. This function retrieves the
#' attributes about all runs within a project version. The returned `data.frame`
#' contains information about what draws and quantiles are available and the model
#' and dataset ids.
#'
#' The `run_draws` and `run_quantiles` columns contain two named lists of `parameter_names`
#' and `predictive_names`. The `parameter_names` is a list of the parameters of the model, e.g.
#' the sld parameters f, kg, and ks. The `predictive_names` is a list of the predicted quantities,
#' e.g. the predicted survival for each trial arm `predicted_survival_per_trial_arm`.
#'
#' Authentication (see \code{\link{login}}) is required prior to using this function
#' and this pulls the metadata from the Generable API.
#'
#' @note
#' A project can be specified by using the project name or a specific project version.
#' \enumerate{
#'   \item If a project is specified using the name, data is fetched for the latest version of the project.
#'   \item If a project is specified using the project version, the project name is not required.
#'   \item If neither a project nor a project version is provided, the default project or project version is used. These are set by the environment variables GECO_API_PROJECT and GECO_API_PROJECT_VERSION
#' }
#'
#' @param project Project name. If NULL, defaults to value of environment variable GECO_API_PROJECT
#' @param project_version_id Project version. If NULL, defaults to the most recent version of the project if provided, or the value of environment variable GECO_API_PROJECT_VERSION
#' @return data.frame of run attributes for the project specified
#' @seealso \code{\link{list_models}}, \code{\link{list_datasets}},
#'          \code{\link{fetch_quantiles}}, \code{\link{fetch_draws}}
#'
#' @importFrom magrittr %>%
#' @importFrom lubridate ymd_hms
#' @importFrom dplyr arrange
#' @importFrom dplyr desc
#' @export
list_runs <- function(project = NULL, project_version_id = NULL) {
  # Resolve the project / project-version inputs into one project version id.
  pv_id <- .process_project_inputs(project = project, project_version_id = project_version_id)
  ret <- geco_api(IRUNS, project_version_id = pv_id)

  # Nothing returned: hand back an empty tibble that still carries the id columns.
  if (length(ret$content) == 0) {
    futile.logger::flog.info('No runs returned.')
    return(tibble::tibble(run_id = character(0), model_id = character(0), dataset_id = character(0)))
  }

  # Wrap multi-element list fields in an outer list so each run becomes
  # exactly one tibble row (the field ends up in a list-column).
  d <- ret$content %>%
    purrr::map(purrr::map_if, ~ is.list(.x) && length(.x) > 1, ~ list(.x)) %>%
    purrr::map_dfr(tibble::as_tibble_row)

  # Parse run_started_on (when present) into a date-time column.
  if ('run_started_on' %in% names(d)) {
    d <- d %>%
      dplyr::mutate(run_start_datetime = lubridate::ymd_hms(.data$run_started_on))
  }

  # Prefix every column with 'run_' unless it already has that prefix or is
  # one of the dataset / model id columns.
  d <- d %>%
    dplyr::rename_at(.vars = dplyr::vars(-dplyr::starts_with('run_'),
                                         -dplyr::all_of(c('dataset_id', 'model_id'))),
                     .funs = ~ stringr::str_c('run_', .x))

  # Most recent runs first. The column is treated as optional above, so only
  # sort when it actually exists instead of failing on a missing column.
  if ('run_started_on' %in% names(d)) {
    d <- d %>% dplyr::arrange(dplyr::desc(.data$run_started_on))
  }
  d
}

# Fetch the attributes of a single run within a project version.
#
# Queries the runs endpoint filtered by `run_id` and formats the result the
# same way as `list_runs()`. Returns a one-row tibble when exactly one run is
# returned; otherwise logs a message and returns a zero-row tibble with only a
# `run_id` column, which callers detect via `nrow(run) == 0`.
.get_run <- function(project_version_id, run_id) {
  ret <- geco_api(IRUNS, project_version_id = project_version_id,
                  url_query_parameters = list(run_id = run_id))

  # Anything other than exactly one match is treated as "not found".
  if (length(ret$content) != 1) {
    futile.logger::flog.info('No runs returned.')
    return(tibble::tibble(run_id = character(0)))
  }

  # Wrap multi-element list fields so the run fits in a single tibble row.
  d <- ret$content %>%
    purrr::map(purrr::map_if, ~ is.list(.x) && length(.x) > 1, ~ list(.x)) %>%
    purrr::map_dfr(tibble::as_tibble_row)

  # Parse run_started_on (when present) into a date-time column.
  if ('run_started_on' %in% names(d)) {
    d <- d %>%
      dplyr::mutate(run_start_datetime = lubridate::ymd_hms(.data$run_started_on))
  }

  # Prefix every column with 'run_' unless it already has that prefix or is
  # one of the dataset / model id columns.
  d %>%
    dplyr::rename_at(.vars = dplyr::vars(-dplyr::starts_with('run_'),
                                         -dplyr::all_of(c('dataset_id', 'model_id'))),
                     .funs = ~ stringr::str_c('run_', .x))
}

#' List the parameter names for a run
#'
#' List the parameter names from the Generable API for a specific run.
#'
#' A run is generated by a model and a dataset. This function retrieves the
#' names of all parameters for a specific run, together with a description and
#' submodel for each. The names can be used in \code{\link{fetch_quantiles}}.
#'
#' Authentication (see \code{\link{login}}) is required prior to using this function
#' and this pulls the list of parameter names from the Generable API.
#'
#' A project can be specified by using the project name or a specific project version.
#' \enumerate{
#'   \item If a project is specified using the name, data is fetched for the latest version of the project.
#'   \item If a project is specified using the project version, the project name is not required.
#'   \item If neither a project nor a project version is provided, the default project or project version is used. These are set by the environment variables GECO_API_PROJECT and GECO_API_PROJECT_VERSION
#' }
#'
#' @param run_id Run id; required
#' @param project Project name. If NULL, defaults to value of environment variable GECO_API_PROJECT
#' @param project_version_id Project version. If NULL, defaults to the most recent version of the project if provided, or the value of environment variable GECO_API_PROJECT_VERSION
#' @param include_raw (bool) if TRUE, include the raw parameters (on unconstrained scale) in the listing. Default FALSE
#' @return data.frame with name, description, and submodel for each parameter exposed from the specified run
#' @seealso \code{\link{list_models}}, \code{\link{list_datasets}},
#'          \code{\link{fetch_quantiles}}, \code{\link{fetch_draws}}
#'
#' @importFrom magrittr %>%
#' @importFrom dplyr filter
#' @export
list_parameter_names <- function(run_id, project = NULL, project_version_id = NULL, include_raw = FALSE) {

  pv_id <- .process_project_inputs(project = project, project_version_id = project_version_id)
  run <- .get_run(project_version_id = pv_id, run_id = run_id)
  if (nrow(run) == 0) {
    msg <- glue::glue('Provided run id: {run_id} was not found in project version {pv_id}.')
    futile.logger::flog.error(msg)
    stop(msg)
  }

  # Parameter names live in the `parameter_names` element of the run's
  # quantiles entry. as.character() keeps `name` a character column even when
  # the run exposes no parameters, so the join below yields an empty result
  # rather than failing on a missing `name` column.
  parameters <- sort(as.character(unlist((run %>% dplyr::pull(.data$run_quantiles))[[1]]$parameter_names)))

  # Attach the human-readable description for each known parameter name;
  # unknown names get an NA description.
  parameter_data <- tibble::tibble(name = parameters) %>%
    dplyr::left_join(tibble::tibble(name = names(.PAR_DESCRIPTIONS), description = .PAR_DESCRIPTIONS),
                     by = 'name') %>%
    dplyr::mutate(
      # Names ending in `_raw` or `_unif` are flagged as raw-scale parameters.
      raw_scale = stringr::str_detect(.data$name, pattern = '.*_raw$')
        | stringr::str_detect(.data$name, pattern = '.*_unif$'),
      # Classify by name pattern; the first matching rule wins.
      # NOTE(review): 'kg', 'ks' and 'f' are unanchored, so any name containing
      # those letters is labelled 'biomarker' -- confirm this is intended.
      submodel = dplyr::case_when(stringr::str_detect(.data$name, pattern = 'kg')
                                  | stringr::str_detect(.data$name, pattern = 'ks')
                                  | stringr::str_detect(.data$name, pattern = 'f')
                                  | stringr::str_detect(.data$name, pattern = '^d')
                                  | stringr::str_detect(.data$name, pattern = 'sld') ~ 'biomarker',
                                  stringr::str_detect(.data$name, pattern = 'association') ~ 'association',
                                  stringr::str_detect(.data$name, pattern = 'hazard')
                                  | stringr::str_detect(.data$name, pattern = 'lambda')
                                  | stringr::str_detect(.data$name, pattern = 'betas') ~ 'hazard',
                                  TRUE ~ NA_character_))

  # Raw-scale parameters are dropped only when include_raw is exactly FALSE.
  if (isFALSE(include_raw)) {
    parameter_data <- parameter_data %>%
      dplyr::filter(!.data$raw_scale)
  }
  parameter_data
}

# Lookup table of human-readable descriptions, keyed by parameter or
# predictive-quantity name. list_parameter_names() and list_predictive_names()
# join against it by name; names without an entry get an NA description.
.PAR_DESCRIPTIONS <- c(
  predicted_relative_hazard="Predicted difference in log(hazard) per subject, by arm",
  smoking_exposure_betas="Relative hazard [as log(HR)] by smoking history",
  hazard_multiplier="Subject-specific relative hazard",
  lambdas="baseline hazard terms (RBF coefficients)",
  lambdas_per_study="study-specific baseline hazard terms (RBF coefficients)",
  survival_log_lik="subject-specific survival sub-model log-likelihood",
  survival_event="observed event status (0: censored, 1: observed) per subject",
  survival_time="observed event time per subject, in days",
  association_betas="relative hazard [as log(HR)] for each unit change in association state",
  kg="kg per subject (kg = growth rate among resistant cells)",
  ks="ks per subject (ks = shrinkage rate among susceptible cells)",
  f="f per subject (f = portion of cells that are drug-susceptible)",
  d="d per subject (d = delay time to start of regrowth via kg)",
  bas_sld="subject-specific estimated SLD at time 0",
  association_states="derived quantities for association structure",
  sld_trial_arm_betas="SLD parameter offsets per trial arm",
  sld_trial_arm_betas_tau="SLD parameter variance per trial arm",
  sld_trial_arm_betas_tau_unif="helper term for SLD parameter estimation per trial arm",
  sld_trial_arm_betas_L_Omega="cholesky factor for correlation matrix among SLD parameters per trial arm",
  sigma_sld="measurement noise for biomarker measurements, on log scale",
  bas_sld_raw="bas_sld terms, before transformation",
  sigma_f_raw="variance in f term, before transformation",
  f_raw="f per subject, before transformation",
  sigma_ks_raw="variance in ks term, before transformation",
  ks_raw="ks per subject, before transformation",
  kg_raw="kg per subject, before transformation",
  sigma_kg_raw="variance in kg term among subjects, before transformation",
  log_sld_hat="expected log_sld for each observed measurement occasion",
  biomarker_log_lik="log_likelihood for observed biomarker values",
  predicted_biomarker="predicted biomarker values at prediction times",
  predicted_biomarker_hat_overall="population-level predicted values for expected biomarker, excluding measurement noise",
  predicted_biomarker_hat_per_trial_arm="trial-arm-level predicted values for expected biomarker, excluding measurement noise",
  predicted_biomarker_overall="population-level predicted values for biomarker measurements, including measurement noise",
  predicted_biomarker_per_trial_arm="trial-arm-level predicted values for biomarker measurements, including measurement noise",
  predicted_hazard="subject-specific predicted hazard rate (per day) at each follow-up time",
  predicted_hazard_overall="predicted hazard rate (per day) at each follow-up time, overall",
  predicted_hazard_per_study="predicted hazard rate (per day) at each follow-up time, by study",
  predicted_hazard_per_trial_arm="predicted hazard rate (per day) at each follow-up time, by trial arm",
  predicted_median_survival="predicted median survival time (in days) per subject",
  predicted_median_survival_overall="predicted median survival time (in days) overall",
  predicted_median_survival_per_study="predicted median survival time (in days) per study",
  predicted_median_survival_per_trial_arm="predicted median survival time (in days) per trial arm",
  predicted_survival="subject-specific predicted survival over time",
  predicted_survival_overall="predicted survival probability at each follow-up time, overall",
  predicted_survival_per_study="predicted survival probability at each follow-up time, by study",
  predicted_survival_per_trial_arm="predicted survival probability at each follow-up time, by trial arm",
  ppc_biomarker="predicted biomarker values at observed measurement times",
  betas_kg_trial_arm="log(kg) parameter offset per trial arm",
  betas_ks_trial_arm="log(ks) parameter offset per trial arm",
  betas_f_trial_arm="logit(f) parameter offset per trial arm",
  log_ks_trial_arm="log(ks) estimate per trial arm",
  log_kg_trial_arm="log(kg) estimate per trial arm",
  logit_f_trial_arm="logit(f) estimate per trial arm",
  log_ks_overall="log(ks) at population-level, excluding covariate effects",
  log_kg_overall="log(kg) at population-level, excluding covariate effects",
  logit_f_overall="logit(f) at population-level, excluding covariate effects",
  betas_f_trial_arm_raw='Estimate of how f varies by trial-arm, before transformation',
  betas_kg_trial_arm_raw='Estimate of how kg varies by trial-arm, before transformation',
  betas_ks_trial_arm_raw='Estimate of how ks varies by trial-arm, before transformation',
  f_overall_raw='f per subject, before transformation',
  kg_overall_raw='kg per subject, before transformation',
  ks_overall_raw='ks per subject, before transformation',
  log_bas_sld='baseline sld (sld at time 0) per subject, on log-scale',
  log_hazard_multiplier='relative hazard per subject (xB), including covariate and association effects',
  raw_lambdas_per_study='study-specific baseline hazard terms (RBF coefficients), before transformation',
  smoking_exposure_betas_raw='Relative hazard [as log(HR)] by smoking history, before transformation',
  L_Omega_betas_f_trial_arm='cholesky factor for correlation matrix for shared variance structure in betas_f_trial_arm',
  L_Omega_betas_kg_trial_arm='cholesky factor for correlation matrix for shared variance structure in betas_kg_trial_arm',
  L_Omega_betas_ks_trial_arm='cholesky factor for correlation matrix for shared variance structure in betas_ks_trial_arm',
  predicted_log_hazard_ratio='predicted log(HR) for treatment vs control arms',
  predicted_trial_arm_state='predicted trial-arm states, from trial-arm-level parameters'
)

#' List the predictive names for a run
#'
#' List the predictive names from the Generable API for a specific run.
#'
#' A run is generated by a model and a dataset. This function retrieves the
#' names of all predictive quantities for a specific run, together with a
#' description and submodel for each. The names can be used in \code{\link{fetch_quantiles}}.
#'
#' Authentication (see \code{\link{login}}) is required prior to using this function
#' and this pulls the list of predictive names from the Generable API.
#'
#' A project can be specified by using the project name or a specific project version.
#' \enumerate{
#'   \item If a project is specified using the name, data is fetched for the latest version of the project.
#'   \item If a project is specified using the project version, the project name is not required.
#'   \item If neither a project nor a project version is provided, the default project or project version is used. These are set by the environment variables GECO_API_PROJECT and GECO_API_PROJECT_VERSION
#' }
#'
#' @param run_id Run id; required
#' @param project Project name. If NULL, defaults to value of environment variable GECO_API_PROJECT
#' @param project_version_id Project version. If NULL, defaults to the most recent version of the project if provided, or the value of environment variable GECO_API_PROJECT_VERSION
#' @return data.frame with name, description, and submodel for each predictive quantity exposed from the specified run
#' @seealso \code{\link{list_models}}, \code{\link{list_datasets}},
#'          \code{\link{fetch_quantiles}}, \code{\link{fetch_draws}}
#'
#' @importFrom magrittr %>%
#' @importFrom dplyr filter
#' @export
list_predictive_names <- function(run_id, project = NULL, project_version_id = NULL) {
  pv_id <- .process_project_inputs(project = project, project_version_id = project_version_id)
  run <- .get_run(project_version_id = pv_id, run_id = run_id)
  if (nrow(run) == 0) {
    msg <- glue::glue('Provided run id: {run_id} was not found in project version {pv_id}.')
    futile.logger::flog.error(msg)
    stop(msg)
  }

  # Predictive names live in the `predictive_names` element of the run's
  # quantiles entry. as.character() keeps `name` a character column even when
  # the run exposes no predictive quantities, so the join below yields an
  # empty result rather than failing on a missing `name` column.
  predictives <- sort(as.character(unlist((run %>% dplyr::pull(.data$run_quantiles))[[1]]$predictive_names)))

  # Attach descriptions and classify each quantity by name pattern; the first
  # matching rule wins. Predictive quantities are never flagged as raw-scale.
  predictive_data <- tibble::tibble(name = predictives) %>%
    dplyr::left_join(tibble::tibble(name = names(.PAR_DESCRIPTIONS), description = .PAR_DESCRIPTIONS),
                     by = 'name') %>%
    dplyr::mutate(raw_scale = FALSE,
                  submodel = dplyr::case_when(stringr::str_detect(.data$name, pattern = 'biomarker') ~ 'biomarker',
                                              stringr::str_detect(.data$name, pattern = 'hazard')
                                              | stringr::str_detect(.data$name, pattern = 'survival') ~ 'hazard',
                                              stringr::str_detect(.data$name, pattern = 'state') ~ 'association',
                                              TRUE ~ NA_character_))

  # Return the value visibly: ending on the assignment would return it
  # invisibly, so nothing would print when called at the console.
  predictive_data
}


#' Find and filter runs on key features
#'
#' This function is a higher-level wrapper around \code{\link{list_runs}} to aid in run discovery
#'
#' A run is generated by a model and a dataset. This function retrieves the
#' key features for all runs meeting certain criteria within a project version.
#'
#' The returned `data.frame` combines information about the run, model, and datasets
#' for each run.
#'
#' Authentication (see \code{\link{login}}) is required prior to using this function
#' and this pulls run, dataset, and model metadata from the Generable API.
#'
#' A project can be specified by using the project name or a specific project version.
#' \enumerate{
#'   \item If a project is specified using the name, data is fetched for the latest version of the project.
#'   \item If a project is specified using the project version, the project name is not required.
#'   \item If neither a project nor a project version is provided, the default project or project version is used. These are set by the environment variables GECO_API_PROJECT and GECO_API_PROJECT_VERSION
#' }
#'
#' @param project Project name. If NULL, defaults to value of environment variable GECO_API_PROJECT
#' @param project_version_id Project version. If NULL, defaults to the most recent version of the project if provided, or the value of environment variable GECO_API_PROJECT_VERSION
#' @param model_type (character vector) filter to runs with this model type, as one of: joint, survival, biomarker. NULL to disable this filter.
#' @param model_version (character vector) filter to runs with this model version string. NULL to disable this filter.
#' @param min_draws (scalar int) filter to runs with >= this many draws combined across all chains. NULL to disable this filter.
#' @param extra_fields (character vector) names of additional fields to include in the summary. See results of \code{\link{list_datasets}}, \code{\link{list_runs}}, and \code{\link{list_models}} for available fields.
#' @return a data frame with key metadata about the run.
#' @seealso \code{\link{list_runs}}, \code{\link{list_models}}, \code{\link{list_datasets}},
#'          \code{\link{fetch_quantiles}}, \code{\link{fetch_draws}}
#' @export
find_runs <- function(project = NULL, project_version_id = NULL,
                      model_type = NULL, model_version = NULL,
                      min_draws = 100, extra_fields = c()) {
  # validate and normalise inputs
  checkmate::assert_character(model_type, null.ok = TRUE, unique = TRUE)
  checkmate::assert_character(model_version, null.ok = TRUE, unique = TRUE)
  checkmate::assert_int(min_draws, null.ok = TRUE)
  if (!is.null(model_type))
    model_type <- match.arg(model_type, choices = c('joint', 'biomarker', 'survival'), several.ok = TRUE)
  pv_id <- .process_project_inputs(project = project, project_version_id = project_version_id)

  # combine run attributes with those of the dataset and model behind each run
  runs <- list_runs(project_version_id = pv_id) %>%
    dplyr::left_join(list_datasets(project_version_id = pv_id),
                     by = 'dataset_id') %>%
    dplyr::left_join(list_models(project_version_id = pv_id),
                     by = 'model_id')

  # no runs in this project version: return the (empty) joined table as-is
  if (nrow(runs) == 0) {
    return(runs)
  }

  # spread the run arguments into columns, parse the start time, and make
  # model_version an ordered factor
  run_info <- runs %>%
    extract_subsample_info() %>%
    tidyr::unnest_wider('run_args') %>%
    dplyr::mutate(run_started_on = lubridate::ymd_hms(.data$run_started_on),
                  model_version = factor(.data$model_version, ordered = TRUE))

  # apply each filter only when it is enabled (non-NULL); `!!` forces the
  # function argument rather than the column of the same name
  if (!is.null(model_type)) {
    run_info <- run_info %>%
      dplyr::filter(.data$model_type %in% !!model_type)
  }
  if (!is.null(model_version)) {
    run_info <- run_info %>%
      dplyr::filter(.data$model_version %in% !!model_version)
  }
  if (!is.null(min_draws)) {
    run_info <- run_info %>%
      dplyr::filter(.data$num_draws >= !!min_draws)
  }

  # keep the key fields plus any requested extras, oldest run first; columns
  # are selected by name because the `.data` pronoun is deprecated in
  # selection contexts
  run_info %>%
    dplyr::select(dplyr::all_of(c('run_id', 'dataset_description', 'sample_id',
                                  'model_type', 'model_version', 'run_started_on')),
                  !!!rlang::syms(extra_fields)) %>%
    dplyr::arrange(.data$run_started_on)
}
generable/rgeco documentation built on Oct. 16, 2024, 2:45 a.m.