Nothing
#' @autoglobal
has_names <- function(x, nm) {
all(nm %in% names(x))
}
#' Warn if survey has multiple observations per participant
#'
#' @description
#' Issues a warning when a survey contains multiple observations per
#' participant (more rows than unique part_id values).
#'
#' @param participants participant data.table
#' @param observation_key optional column name(s) identifying observations
#' @param filter_hint character; "pipeline" for pipeline-style hint or
#' "legacy" for contact_matrix-style hint
#' @returns NULL invisibly
#' @keywords internal
#' @autoglobal
warn_multiple_observations <- function(
participants,
observation_key = NULL,
filter_hint = c("pipeline", "legacy")
) {
filter_hint <- match.arg(filter_hint)
n_participants <- uniqueN(participants$part_id)
n_rows <- nrow(participants)
if (n_participants >= n_rows) {
return(invisible(NULL))
}
has_obs_key <- !is.null(observation_key) && length(observation_key) > 0
hint <- if (has_obs_key && filter_hint == "pipeline") {
cli::format_inline(
"Use {.code survey[{observation_key} == ...]} to select specific \\
observations before calling {.fn compute_matrix}."
)
} else if (has_obs_key) {
cli::format_inline(
"Use {.arg filter} to select by {.val {observation_key}}."
)
} else if (filter_hint == "pipeline") {
"Filter the survey with {.code survey[...]} to select specific \\
observations before calling {.fn compute_matrix}."
} else {
"Use the {.arg filter} argument to select specific observations."
}
cli::cli_warn(c(
"Survey contains multiple observations per participant \\
({n_rows} rows, {n_participants} unique participants).",
"*" = "Results will aggregate across all observations.",
i = hint
))
invisible(NULL)
}
#' Impute ages from ranges (generic helper)
#'
#' @description
#' Generic function to impute ages from min/max ranges. Works for both
#' participant and contact data by specifying the column prefix.
#'
#' @param data A data.table containing age data
#' @param prefix Column name prefix: "part_age" for participants, "cnt_age" for
#' contacts
#' @param estimate Imputation method: "mean", "sample", or "missing"
#' @returns The data with ages imputed according to the specified method
#' @autoglobal
#' @keywords internal
impute_ages <- function(
data,
prefix,
estimate = c("mean", "sample", "missing")
) {
if (!is.data.frame(estimate)) {
estimate <- rlang::arg_match(estimate)
}
# Build column names from prefix
age_col <- prefix
exact_col <- paste0(prefix, "_exact")
min_col <- paste0(prefix, "_est_min")
max_col <- paste0(prefix, "_est_max")
age_cols_in_data <- has_names(data, c(min_col, max_col))
if (!age_cols_in_data || identical(estimate, "missing")) {
return(data)
}
if (identical(estimate, "mean")) {
# Impute using mean of min/max range
data[
is.na(get(exact_col)) &
!is.na(get(min_col)) &
!is.na(get(max_col)),
(age_col) := as.integer(rowMeans(.SD)),
.SDcols = c(min_col, max_col)
]
} else if (identical(estimate, "sample")) {
# Impute by sampling uniformly from range; runif() excludes its upper
# bound, so use max + 1 to make the integer draw inclusive of max.
data[
is.na(get(age_col)) &
!is.na(get(min_col)) &
!is.na(get(max_col)) &
get(min_col) <= get(max_col),
(age_col) := as.integer(runif(.N, get(min_col), get(max_col) + 1L))
]
} else if (is.data.frame(estimate)) {
# Impute by sampling from a supplied age distribution within [min, max]
rows_to_impute <- data[,
which(
is.na(get(age_col)) &
!is.na(get(min_col)) &
!is.na(get(max_col)) &
get(min_col) <= get(max_col)
)
]
if (length(rows_to_impute) > 0) {
data[
rows_to_impute,
(age_col) := sample_from_age_distribution(
get(min_col),
get(max_col),
estimate
)
]
}
}
data
}
#' Sample ages from a distribution within `[min, max]` bands
#' @param mins integer vector of lower bounds
#' @param maxs integer vector of upper bounds
#' @param distribution data.frame with `age` and `proportion` columns
#' @returns integer vector of sampled ages
#' @keywords internal
sample_from_age_distribution <- function(mins, maxs, distribution) {
n <- length(mins)
result <- integer(n)
fell_back <- FALSE
for (i in seq_len(n)) {
lo <- mins[i]
hi <- maxs[i]
sub_dist <- distribution[distribution$age >= lo & distribution$age <= hi, ]
if (nrow(sub_dist) == 0 || sum(sub_dist$proportion) == 0) {
fell_back <- TRUE
# Upper bound inclusive: runif() excludes its upper bound, so use hi + 1
result[i] <- as.integer(runif(1, lo, hi + 1))
} else if (nrow(sub_dist) == 1) {
result[i] <- sub_dist$age
} else {
result[i] <- sample(sub_dist$age, size = 1, prob = sub_dist$proportion)
}
}
if (fell_back) {
cli::cli_warn(
"Age distribution has no coverage for some age bands; \\
falling back to uniform sampling for those contacts."
)
}
result
}
#' Impute participant ages
#'
#' @description
#' Imputes participant survey data, where variables are named:
#' "part_age_est_min" and "part_age_est_max". Uses mean imputation, sampling
#' (hot deck), or leaves them as missing. These are controlled by the
#' `estimate` argument.
#'
#' @param participants A survey data set of participants
#' @param estimate if set to "mean" (default), people whose ages are given as a
#' range (in columns named "..._est_min" and "..._est_max") but not exactly
#' (in a column named "..._exact") will have their age set to the mid-point of
#' the range; if set to "sample", the age will be sampled from the range; if
#' set to "missing", age ranges will be treated as missing
#'
#' @returns
#' The participant data, potentially with participant ages imputed depending on
#' the `estimate` method and whether age columns are present in the data.
#'
#' @autoglobal
#' @keywords internal
impute_participant_ages <- function(
participants,
estimate = c("mean", "sample", "missing")
) {
impute_ages(data = participants, prefix = "part_age", estimate = estimate)
}
#' @autoglobal
drop_ages_below_age_limit <- function(data, age_limits) {
data[is.na(cnt_age) | cnt_age >= min(age_limits), ]
}
#' Impute contact ages
#'
#' @description
#' Imputes contact survey data, where variables are named:
#' "cnt_age_est_min" and "cnt_age_est_max". Uses mean imputation, sampling
#' (hot deck), or leaves them as missing. These are controlled by the
#' `estimate` argument.
#'
#' @param contacts a survey data set of contacts
#' @param estimate if set to "mean" (default), contacts whose ages are given as
#' a range (in columns named "..._est_min" and "..._est_max") but not exactly
#' (in a column named "..._exact") will have their age set to the mid-point of
#' the range; if set to "sample", the age will be sampled from the range; if
#' set to "missing", age ranges will be treated as missing
#'
#' @returns
#' The contact data, potentially with contact ages imputed depending on the
#' `estimate` method and whether age columns are present in the data.
#'
#' @autoglobal
#' @keywords internal
impute_contact_ages <- function(
contacts,
estimate = c("mean", "sample", "missing")
) {
impute_ages(data = contacts, prefix = "cnt_age", estimate = estimate)
}
#' @autoglobal
drop_missing_contact_ages <- function(contacts, missing_action) {
if (missing_action == "ignore" && nrow(contacts[is.na(cnt_age)]) > 0) {
contacts <- contacts[!is.na(cnt_age), ]
}
contacts
}
#' @autoglobal
max_participant_age <- function(data) {
if (has_names(data, c("part_age_est_max", "part_age_exact"))) {
part_age_data <- c(data[, part_age_exact], data[, part_age_est_max])
max_year <- max(part_age_data, na.rm = TRUE) + 1
} else {
max_year <- max(data[, part_age], na.rm = TRUE) + 1
}
max_year
}
#' @autoglobal
create_age_breaks <- function(age_limits, max_age) {
c(age_limits, max(max_age, max(age_limits) + 1))
}
#' @autoglobal
get_age_group_lower_limits <- function(age_limits) {
age_limits
}
#' @autoglobal
get_age_limits <- function(survey) {
participants <- survey$participants
contacts <- survey$contacts
part_ages <- if ("part_age" %in% colnames(participants)) {
as.integer(participants[, part_age])
} else if ("part_age_exact" %in% colnames(participants)) {
as.integer(participants[, part_age_exact])
} else {
integer(0)
}
cnt_ages <- if ("cnt_age" %in% colnames(contacts)) {
as.integer(contacts[, cnt_age])
} else if ("cnt_age_exact" %in% colnames(contacts)) {
as.integer(contacts[, cnt_age_exact])
} else {
integer(0)
}
all_ages <- c(part_ages, cnt_ages)
unique_non_missing_ages <- unique(all_ages[!is.na(all_ages)])
union(0, sort(unique_non_missing_ages))
}
#' @autoglobal
filter_countries <- function(participants, countries) {
multiple_countries <- length(countries) > 0
country_col_in_participants <- "country" %in% colnames(participants)
if (multiple_countries && country_col_in_participants) {
countries <- flexible_countrycode(countries)
participants <- participants[country %in% countries]
if (nrow(participants) == 0) {
cli::cli_abort(
"No participants left after selecting countries: {.val {countries}}"
)
}
}
participants
}
#' Add age column from exact age (generic helper)
#'
#' @description
#' Generic function to add an age column from an exact age column. Works for
#' both participant and contact data by specifying the column prefix.
#' If `<prefix>_exact` exists, it overwrites `<prefix>` with its values.
#' Otherwise, it creates `<prefix>` with NA values if it doesn't exist.
#'
#' @param data A data.table containing age data
#' @param prefix Column name prefix: "part_age" for participants, "cnt_age" for
#' contacts
#' @returns The data with the age column set from exact ages or
#' initialised to NA
#' @autoglobal
#' @keywords internal
add_age <- function(data, prefix) {
age_col <- prefix
exact_col <- paste0(prefix, "_exact")
if (exact_col %in% colnames(data)) {
data <- data[, (age_col) := as.integer(get(exact_col))]
} else if (!(age_col %in% colnames(data))) {
data <- data[, (age_col) := NA_integer_]
}
data
}
#' @autoglobal
add_part_age <- function(participants) {
add_age(data = participants, prefix = "part_age")
}
#' @autoglobal
add_contact_age <- function(contacts) {
add_age(data = contacts, prefix = "cnt_age")
}
#' @autoglobal
drop_invalid_ages <- function(
participants,
missing_action,
age_limits
) {
if (is.null(age_limits)) {
cli::cli_abort("{.arg age_limits} must be provided")
}
ppt_no_age_info <- participants[is.na(part_age) | part_age < min(age_limits)]
no_age_info <- nrow(ppt_no_age_info) > 0
if (missing_action == "remove" && no_age_info) {
participants <- participants[!is.na(part_age) & part_age >= min(age_limits)]
}
participants
}
#' @autoglobal
drop_invalid_contact_ages <- function(
contacts,
participants,
missing_action
) {
if (missing_action == "remove" && nrow(contacts[is.na(cnt_age)]) > 0) {
missing.age.id <- contacts[is.na(cnt_age), part_id]
participants <- participants[!(part_id %in% missing.age.id)]
}
participants
}
## convert factors to integers, preserving numeric values
#' @autoglobal
convert_factor_to_integer <- function(
data,
cols
) {
which_factors <- sapply(data, is.factor)
factor_cols <- intersect(cols, names(data)[which_factors])
data[,
(factor_cols) := lapply(.SD, function(x) {
# Convert factor levels to integers - non-numeric levels become NA
num_levels <- suppressWarnings(as.integer(levels(x)))
# Warn if any level failed to parse (i.e. non-numeric)
if (any(is.na(num_levels) & !is.na(levels(x)))) {
cli::cli_warn("Non-numeric factor levels found; these will become NA")
}
num_levels[x]
}),
.SDcols = factor_cols
]
}
## check if any filters have been requested
#' @autoglobal
apply_data_filter <- function(
survey,
survey_type,
filter,
call = rlang::caller_env()
) {
if (!is.null(filter)) {
missing_columns <- list()
for (table in survey_type) {
if (nrow(survey[[table]]) > 0) {
missing_columns <- c(
missing_columns,
list(setdiff(names(filter), colnames(survey[[table]])))
)
## filter contact data
for (column in names(filter)) {
if (column %in% colnames(survey[[table]])) {
survey[[table]] <- survey[[table]][get(column) == filter[[column]]]
}
}
}
}
missing_all <- do.call(intersect, missing_columns)
if (length(missing_all) > 0) {
cli::cli_warn(
message = "Filter column{?s}: {.val {missing_all}} not found.",
call = call
)
}
}
survey
}
# converts from [0,1) [1,5) [5,15) [15,80) to [0,1) [1,5) [5,15) [15,Inf)
#' @autoglobal
final_age_group_label <- function(age_groups) {
last <- age_groups[length(age_groups)]
lower <- sub("\\[([0-9]+),.*$", "\\1", last)
age_groups[length(age_groups)] <- paste0("[", lower, ",Inf)")
age_groups
}
# adjust age.group.breaks to the lower and upper ages in the survey
#' @autoglobal
adjust_ppt_age_group_breaks <- function(
participants,
age_limits
) {
max_age <- max_participant_age(participants)
participants[,
lower.age.limit := reduce_agegroups(
x = part_age,
limits = age_limits
)
]
part_age_group_breaks <- create_age_breaks(age_limits, max_age)
participants[,
age.group := cut(
participants[, part_age],
breaks = part_age_group_breaks,
right = FALSE
)
]
age.groups <- age_group_labels(participants)
participants[,
age.group := factor(
age.group,
levels = levels(age.group),
labels = age.groups
)
]
part_age_group_present <- get_age_group_lower_limits(age_limits)
participants <- add_upper_age_limits(
participants = participants,
part_age_group_present = part_age_group_present,
part_age_group_breaks = part_age_group_breaks
)
participants
}
#' @autoglobal
age_group_labels <- function(participants) {
age_groups <- participants[, levels(age.group)]
age_groups <- final_age_group_label(age_groups)
age_groups
}
#' @autoglobal
add_upper_age_limits <- function(
participants,
part_age_group_present,
part_age_group_breaks
) {
lower_upper_age_limits <- data.table(
lower.age.limit = part_age_group_present,
upper.age.limit = part_age_group_breaks[-1]
)
participants <- merge(
participants,
lower_upper_age_limits,
by = "lower.age.limit",
all.x = TRUE
)
participants
}
#' @autoglobal
survey_pop_from_data <- function(survey_pop, part_age_group_present) {
survey_pop <- data.table(survey_pop)
# make sure max survey_pop age exceeds participant age group breaks
if (max(survey_pop$lower.age.limit) < max(part_age_group_present)) {
survey_pop <- rbind(
survey_pop,
list(max(part_age_group_present + 1), 0)
)
}
survey_pop
}
#' @autoglobal
get_survey_countries <- function(survey_pop, countries, participants) {
if (!is.null(survey_pop)) {
## survey population is given as vector of countries
survey_countries <- survey_pop
} else if (!is.null(countries)) {
## survey population not given but countries requested from
## survey - get population data from those countries
survey_countries <- countries
} else if ("country" %in% colnames(participants)) {
## neither survey population nor country names given - try to
## guess country or countries surveyed from participant data
survey_countries <- unique(participants[, country])
}
survey_countries
}
#' @autoglobal
survey_is_representative <- function(countries, participants, survey_pop) {
no_countries <- is.null(countries) && !("country" %in% colnames(participants))
survey_representative <- is.null(survey_pop) && no_countries
survey_representative
}
#' @autoglobal
survey_pop_from_countries <- function(
survey_pop,
countries,
participants,
age_limits,
call = rlang::caller_env()
) {
# no countries, and no survey_pop
survey_representative <- survey_is_representative(
countries = countries,
participants = participants,
survey_pop = survey_pop
)
warn_if_no_survey_countries(survey_representative, call = call)
# there aren't countries or survey pop, get the countries
if (!survey_representative) {
survey_countries <- get_survey_countries(
survey_pop = survey_pop,
countries = countries,
participants = participants
)
## get population data for countries from 'wpp' package
country_pop <- data.table(wpp_age(survey_countries))
country_pop$country <- normalise_country_names(country_pop$country)
## check if survey data are from a specific year - in that case
## use demographic data from that year, otherwise latest
if ("year" %in% colnames(participants)) {
survey_year <- participants[, median(year, na.rm = TRUE)]
} else {
survey_year <- country_pop[, max(year, na.rm = TRUE)]
cli::cli_warn(
"No information on {.val year} found in the data. Will use
{.val {survey_year}} population data."
)
}
## check if any survey countries are not in wpp
check_any_missing_countries(survey_countries, country_pop)
## get demographic data closest to survey year
country_pop_year <- unique(country_pop[, year])
survey_year <- min(
country_pop_year[which.min(abs(survey_year - country_pop_year))]
)
survey_pop <- country_pop[year == survey_year][,
list(population = sum(population)),
by = "lower.age.limit"
]
}
if (survey_representative) {
survey_pop <- participants[, list(population = .N), by = lower.age.limit]
survey_pop <- survey_pop[!is.na(lower.age.limit)]
if ("year" %in% colnames(participants)) {
survey_year <- participants[, median(year, na.rm = TRUE)]
} else {
survey_year <- NULL
}
}
list(
survey_year = survey_year,
survey_pop = survey_pop
)
}
#' @autoglobal
survey_pop_year <- function(
survey_pop,
countries,
participants,
age_limits
) {
if (is.null(survey_pop) || is.character(survey_pop)) {
survey_pop_info <- survey_pop_from_countries(
survey_pop = survey_pop,
countries = countries,
participants = participants,
age_limits = age_limits
)
survey_pop <- survey_pop_info$survey_pop
survey_year <- survey_pop_info$survey_year
} else {
part_age_group_present <- get_age_group_lower_limits(age_limits)
# survey_pop is a data frame with 'lower.age.limit' and 'population'
survey_pop <- survey_pop_from_data(survey_pop, part_age_group_present)
# add dummy survey_year
survey_year <- NA_integer_
}
list(
survey_pop = survey_pop,
survey_year = survey_year
)
}
#' @autoglobal
add_survey_upper_age_limit <- function(survey, age_breaks) {
# add upper.age.limit after sorting the survey ages (and add
# maximum age > given ages)
survey <- survey[order(lower.age.limit), ]
# if any lower age limits are missing remove them
survey <- survey[!is.na(population)]
survey$upper.age.limit <- unlist(c(
survey[-1, "lower.age.limit"],
1 + max(survey$lower.age.limit, age_breaks)
))
survey
}
#' @autoglobal
survey_pop_reference <- function(survey_pop, ...) {
data.table(
pop_age(
survey_pop,
seq(
min(survey_pop$lower.age.limit),
max(survey_pop$upper.age.limit)
),
...
)
)
}
#' @autoglobal
adjust_survey_age_groups <- function(survey_pop, part_age_group_present, ...) {
survey_pop_max <- max(survey_pop$upper.age.limit)
survey_pop <- data.table(pop_age(survey_pop, part_age_group_present, ...))
## use the actual lower.age.limits from survey_pop (which may be a subset
## of part_age_group_present if population data doesn't cover all ages)
actual_age_limits <- survey_pop$lower.age.limit
## set upper age limits
survey_pop[,
upper.age.limit := c(actual_age_limits[-1], survey_pop_max)
]
}
#' @autoglobal
weight_by_day_of_week <- function(
participants,
call = rlang::caller_env()
) {
found_dayofweek <- FALSE
if ("dayofweek" %in% colnames(participants)) {
## Add column sum_weight: Number of entries on weekdays / weekends
participants[,
sum_weight := nrow(.SD),
by = (dayofweek %in% 1:5),
]
## The sum of the weights on weekdays is 5
participants[dayofweek %in% 1:5, weight := 5 / sum_weight]
## The sum of the weights on weekend is 2
participants[!(dayofweek %in% 1:5), weight := 2 / sum_weight]
participants[, sum_weight := NULL]
found_dayofweek <- TRUE
# add boolean for "weekday"
participants[, is.weekday := dayofweek %in% 1:5]
}
if (!found_dayofweek) {
cli::cli_warn(
message = c(
"{.code weigh_dayofweek} is {.val TRUE}, but no {.col dayofweek} \\
column in the data.",
# nolint start
"i" = "Will ignore."
# nolint end
),
call = call
)
}
participants
}
#' @autoglobal
weight_by_age <- function(participants, survey_pop_full) {
# get number and proportion of participants by age
participants[, age.count := .N, by = part_age]
participants[, age.proportion := age.count / .N]
# get reference population by age (absolute and proportional)
part_age_all <- range(unique(participants[, part_age]))
survey_pop_detail <- data.table(pop_age(
survey_pop_full,
seq(part_age_all[1], part_age_all[2] + 1)
))
names(survey_pop_detail) <- c("part_age", "population.count")
survey_pop_detail[,
population.proportion := population.count / sum(population.count)
]
# merge reference and survey population data
participants <- merge(
participants,
survey_pop_detail,
by = "part_age"
)
# calculate age-specific weights
participants[, weight.age := population.proportion / age.proportion]
# merge 'weight.age' into 'weight'
participants[, weight := weight * weight.age]
## Remove the additional columns
participants[, age.count := NULL]
participants[, age.proportion := NULL]
participants[, population.count := NULL]
participants[, population.proportion := NULL]
participants[, weight.age := NULL]
}
#' @autoglobal
weigh_by_user_defined <- function(participants, weights) {
for (i in seq_along(weights)) {
if (weights[i] %in% colnames(participants)) {
## Compute the overall weight
participants[, weight := weight * get(weights[i])]
}
}
participants
}
#' Post-stratification weight normalisation
#'
#' @description
#' Normalises participant weights within groups so that they sum to the number
#' of participants in each group. Optionally truncates extreme weights to a
#' threshold and re-normalises.
#'
#' @param participants participant data.table with a `weight` column
#' @param by character; column name(s) to group by (default "age.group")
#' @param threshold numeric; if provided, weights above this value are capped
#' and the weights are re-normalised (default NULL)
#' @returns the participants data.table (modified by reference)
#' @keywords internal
#' @autoglobal
normalise_weights <- function(
participants,
by = "age.group",
threshold = NULL
) {
participants[, weight := weight / sum(weight) * .N, by = c(by)]
if (!is.null(threshold) && !is.na(threshold)) {
participants[weight > threshold, weight := threshold]
participants[, weight := weight / sum(weight) * .N, by = c(by)]
}
invisible(participants)
}
## merge participants and contacts into a single data table
#' @autoglobal
merge_participants_contacts <- function(participants, contacts) {
setkey(participants, part_id)
# Merge participants with contacts, allowing Cartesian products as one
# participant can have multiple contacts
contacts <- merge(
x = contacts,
y = participants,
by = "part_id",
all = FALSE,
allow.cartesian = TRUE,
suffixes = c(".cont", ".part")
)
setkey(contacts, part_id)
contacts
}
## some contacts in the age group have an age, sample from these
#' @autoglobal
sample_present_age <- function(contacts, this_age_group) {
contacts[
is.na(cnt_age) & age.group == this_age_group,
cnt_age := sample(
contacts[
!is.na(cnt_age) & age.group == this_age_group,
cnt_age
],
size = .N,
replace = TRUE
)
]
}
#' @autoglobal
sample_uniform_age <- function(contacts, this_age_group) {
min_contact_age <- contacts[, min(cnt_age, na.rm = TRUE)]
max_contact_age <- contacts[, max(cnt_age, na.rm = TRUE)]
contacts[
is.na(cnt_age) & age.group == this_age_group,
cnt_age := as.integer(floor(runif(
.N,
min = min_contact_age,
max = max_contact_age + 1
)))
]
}
#' @autoglobal
impute_age_by_sample <- function(contacts) {
for (this_age_group in unique(contacts[is.na(cnt_age), age.group])) {
## first, deal with missing age
if (nrow(contacts[!is.na(cnt_age) & age.group == this_age_group]) > 0) {
## some contacts in the age group have an age, sample from these
contacts <- sample_present_age(contacts, this_age_group)
} else if (nrow(contacts[!is.na(cnt_age), ]) > 0) {
## no contacts in the age group have an age, sample uniformly
contacts <- sample_uniform_age(contacts, this_age_group)
}
}
# make sure the final set does not contain NA's anymore
contacts <- contacts[!is.na(cnt_age), ]
contacts
}
#' @autoglobal
add_contact_age_groups <- function(
contacts,
age_breaks,
age_groups
) {
max_contact_age <- contacts[, max(cnt_age, na.rm = TRUE) + 1]
if (max_contact_age > max(age_breaks)) {
age_breaks[length(age_breaks)] <- max_contact_age
}
contacts[,
contact.age.group := cut(
cnt_age,
breaks = age_breaks,
labels = age_groups,
right = FALSE
)
]
}
#' @autoglobal
create_bootstrap_weights <- function(part_sample) {
sample_table <- data.table(id = part_sample, weight = 1)
sample_table <- sample_table[,
list(bootstrap.weight = sum(weight)),
by = id
]
setnames(sample_table, "id", "part_id")
setkey(sample_table, part_id)
sample_table
}
#' @autoglobal
sample_from_participants <- function(
participants,
contacts,
age_limits,
sample_all_age_groups,
max.tries = 1000
) {
participant_ids <- unique(participants$part_id)
## check upfront if sampling all age groups is possible
if (sample_all_age_groups) {
present_age_limits <- unique(participants$lower.age.limit)
missing_age_limits <- setdiff(age_limits, present_age_limits)
if (length(missing_age_limits) > 0) {
cli::cli_abort(
"Cannot sample all age groups: no participants in age groups
starting at {.val {missing_age_limits}}."
)
}
}
good_sample <- FALSE
tries <- 0L
while (!good_sample) {
tries <- tries + 1L
if (sample_all_age_groups && tries > max.tries) {
cli::cli_abort(
"Failed to draw a bootstrap sample covering all age groups after
{.val {max.tries}} attempts."
)
}
## take a sample from the participants
part_sample <- sample(participant_ids, replace = TRUE)
part_age_limits <- unique(
participants[part_id %in% part_sample, lower.age.limit]
)
age_limits_match_part <- setequal(age_limits, part_age_limits)
good_sample <- !sample_all_age_groups || age_limits_match_part
sample_table <- create_bootstrap_weights(part_sample)
sampled_contacts <- merge(contacts, sample_table, by = "part_id")
sampled_contacts[, sampled.weight := weight * bootstrap.weight]
sampled_participants <- merge(participants, sample_table)
sampled_participants[, sampled.weight := weight * bootstrap.weight]
}
list(
sampled_contacts = sampled_contacts,
sampled_participants = sampled_participants
)
}
#' @autoglobal
sample_contacts_participants <- function(
sample_participants,
participants,
contacts,
age_limits,
sample_all_age_groups,
max.tries = 1000
) {
if (sample_participants) {
sampled_contacts_participants <- sample_from_participants(
participants,
contacts,
age_limits,
sample_all_age_groups,
max.tries
)
} else {
## just use all participants
sampled_contacts_participants <- list(
sampled_contacts = contacts[, sampled.weight := weight],
sampled_participants = participants[, sampled.weight := weight]
)
}
sampled_contacts_participants
}
#' @autoglobal
weighted_matrix_array <- function(contacts) {
weighted_matrix <- xtabs(
data = contacts,
formula = sampled.weight ~ age.group + contact.age.group,
addNA = TRUE
)
array(
weighted_matrix,
dim = dim(weighted_matrix),
dimnames = dimnames(weighted_matrix)
)
}
#' @autoglobal
calculate_weighted_matrix <- function(
sampled_contacts = sampled_contacts,
sampled_participants = sampled_participants,
survey_pop = survey_pop,
symmetric,
counts,
symmetric_norm_threshold
) {
weighted_matrix <- weighted_matrix_array(
contacts = sampled_contacts
)
if (!counts) {
## normalise to give mean number of contacts
weighted_matrix <- normalise_weights_to_counts(
sampled_participants = sampled_participants,
weighted_matrix = weighted_matrix
)
}
warn_symmetric_counts_na(symmetric, counts, weighted_matrix)
matrix_not_scalar <- prod(dim(as.matrix(weighted_matrix))) > 1
na_in_weighted_mtx <- na_in_weighted_matrix(weighted_matrix)
if (symmetric && matrix_not_scalar && !na_in_weighted_mtx) {
weighted_matrix <- normalise_weighted_matrix(
survey_pop = survey_pop,
weighted_matrix = weighted_matrix,
symmetric_norm_threshold = symmetric_norm_threshold
)
}
weighted_matrix
}
#' @autoglobal
normalise_weights_to_counts <- function(sampled_participants, weighted_matrix) {
## normalise to give mean number of contacts
## calculate normalisation vector
norm_vector <- c(xtabs(
data = sampled_participants,
formula = sampled.weight ~ age.group,
addNA = TRUE
))
## normalise contact matrix
weighted_matrix <- weighted_matrix / norm_vector
## set non-existent data to NA
weighted_matrix[is.nan(weighted_matrix)] <- NA_real_
weighted_matrix
}
## construct a warning in case there are NAs
#' @autoglobal
build_na_warning <- function(weighted_matrix) {
na_headers <- anyNA(dimnames(weighted_matrix), recursive = TRUE)
na_content <- anyNA(weighted_matrix)
na_present <- na_headers || na_content
warning_suggestion <- NULL
if (na_present) {
warning_suggestion <- " Consider "
if (na_headers) {
warning_suggestion <- paste0(warning_suggestion, "setting ")
suggested_options <- NULL
if (anyNA(rownames(weighted_matrix))) {
suggested_options <- c(suggested_options, "'missing.participant.age'")
}
if (anyNA(colnames(weighted_matrix))) {
suggested_options <- c(suggested_options, "'missing.contact.age'")
}
warning_suggestion <-
paste0(warning_suggestion, paste(suggested_options, collapse = " and "))
if (na_content) {
warning_suggestion <- paste0(warning_suggestion, ", and ")
} else {
warning_suggestion <- paste0(warning_suggestion, ".")
}
}
if (na_content) {
warning_suggestion <- paste0(
warning_suggestion,
"adjusting the age limits."
)
}
}
warning_suggestion
}
#' @autoglobal
na_in_weighted_matrix <- function(weighted_matrix) {
na_headers <- anyNA(dimnames(weighted_matrix), recursive = TRUE)
na_content <- anyNA(weighted_matrix)
na_present <- na_headers || na_content
na_present
}
#' @autoglobal
normalisation_factors <- function(normalised_matrix, weighted_matrix) {
normalisation_fctr <- c(
normalised_matrix / weighted_matrix,
weighted_matrix / normalised_matrix
)
normalisation_fctr <- normalisation_fctr[
!is.infinite(normalisation_fctr) & !is.na(normalisation_fctr)
]
normalisation_fctr
}
#' @autoglobal
normalise_weighted_matrix <- function(
survey_pop,
weighted_matrix,
symmetric_norm_threshold,
call = rlang::caller_env()
) {
## set c_{ij} N_i and c_{ji} N_j (which should both be equal) to
## 0.5 * their sum; then c_{ij} is that sum / N_i
normalised_weighted_matrix <- survey_pop$population * weighted_matrix
normalised_weighted_matrix <- 0.5 /
survey_pop$population *
(normalised_weighted_matrix + t(normalised_weighted_matrix))
warn_norm_fct_exceed_thresh(
normalised_weighted_matrix = normalised_weighted_matrix,
weighted_matrix = weighted_matrix,
symmetric_norm_threshold = symmetric_norm_threshold
)
normalised_weighted_matrix
}
#' @autoglobal
mean_contacts_per_person <- function(population, num_contacts) {
mean_contacts <- sum(population * num_contacts) / sum(population)
mean_contacts
}
#' @autoglobal
get_spectral_radius <- function(weighted_matrix) {
spectrum_matrix <- weighted_matrix
spectrum_matrix[is.na(spectrum_matrix)] <- 0
eigen_values <- eigen(spectrum_matrix, only.values = TRUE)
# get the largest eigenvalue
spectral_radius <- as.numeric(eigen_values$values[1])
spectral_radius
}
#' @autoglobal
split_mean_norm_contacts <- function(
weighted_matrix,
population
) {
## get rid of name but preserve row and column names
weighted_matrix <- unname(weighted_matrix)
num_contacts <- rowSums(weighted_matrix)
mean_contacts <- mean_contacts_per_person(
population = population,
num_contacts = num_contacts
)
# Maximum growth rate of the infection process
# the dominant eigenvalue or the spectral radius
spectral_radius <- get_spectral_radius(weighted_matrix)
# normalise: how much more transmission potential from pop. structure
normalisation <- spectral_radius / mean_contacts
age_proportions <- population / sum(population)
weighted_matrix <- diag(1 / num_contacts) %*%
weighted_matrix %*%
diag(1 / age_proportions)
num_contacts <- num_contacts / spectral_radius
list(
weighted_matrix = weighted_matrix,
mean_contacts = mean_contacts,
normalisation = normalisation,
contacts = num_contacts
)
}
#' @autoglobal
matrix_per_capita <- function(weighted_matrix, survey_pop) {
weighted_matrix_per_capita <- weighted_matrix /
matrix(
rep(survey_pop$population, nrow(survey_pop)),
ncol = nrow(survey_pop),
byrow = TRUE
)
weighted_matrix_per_capita
}
#' @autoglobal
n_participants_per_age_group <- function(participants) {
participant_population <- data.table(table(
participants[, age.group],
useNA = "ifany"
))
setnames(participant_population, c("age.group", "participants"))
participant_population[, proportion := participants / sum(participants)]
participant_population
}
#' @autoglobal
return_participant_weights <- function(
survey_participants,
weigh_age,
weigh_dayofweek
) {
# default
part_weights <- survey_participants[, .N, by = list(age.group, weight)]
part_weights <- part_weights[order(age.group, weight), ]
# add age and/or dayofweek info
if (weigh_age && weigh_dayofweek) {
part_weights <- survey_participants[,
.N,
by = list(age.group, participant.age = part_age, is.weekday, weight)
]
}
if (weigh_age && !weigh_dayofweek) {
part_weights <- survey_participants[,
.N,
by = list(age.group, participant.age = part_age, weight)
]
}
if (weigh_dayofweek && !weigh_age) {
part_weights <- survey_participants[,
.N,
by = list(age.group, is.weekday, weight)
]
}
# order (from left to right)
part_weights <- part_weights[order(part_weights), ] # nolint
# set name of last column
names(part_weights)[ncol(part_weights)] <- "participants"
part_weights[, proportion := participants / sum(participants)]
part_weights[]
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.