R/weigh.R

Defines functions weigh_population weigh_named weigh_grouped weigh_direct weigh

Documented in weigh

#' Weigh survey participants
#'
#' @description
#' Applies weights to participants in a `contact_survey` object. Weights are
#' always multiplied into an existing `weight` column (or one is created with
#' value 1), making multiple calls composable.
#'
#' The behaviour depends on the combination of arguments:
#' \describe{
#'   \item{`target = NULL`}{Numeric column: multiply `weight` by column values
#'     directly.}
#'   \item{Unnamed `target` + `groups`}{Map column values to groups, assign
#'     `target[g] / n_in_group` per participant. Participants whose value
#'     falls in no group (missing values included, unless a group lists `NA`)
#'     receive the average `sum(target) / n_participants`.}
#'   \item{Named `target`}{Names match column values, assign
#'     `target[val] / n_with_val` per participant. Values absent from
#'     `names(target)`, and missing values, get an `NA` weight.}
#'   \item{Data frame `target`}{Post-stratify against population data (expanded
#'     to single-year ages via [pop_age()]).}
#' }
#'
#' @param survey a [survey()] object (must have been processed by
#'   [assign_age_groups()] if using data frame target)
#' @param by a single string naming a column in the participant data to weigh
#'   by. It must exist in the participant data even when `target` is a data
#'   frame. That mode does not use `by` itself and requires a `part_age`
#'   column instead.
#' @param target target weights: `NULL` for direct numeric weighting, an
#'   unnamed numeric vector (with `groups`), a named numeric vector, or a
#'   data frame with columns `lower.age.limit` and `population`
#' @param groups a list of value sets mapping column values to groups (used
#'   with unnamed `target` vector); must be the same length as `target`
#' @param ... further arguments passed to [pop_age()] when `target` is a data
#'   frame
#' @returns the survey object with updated participant weights
#'
#' @examples
#' data(polymod)
#' # Direct numeric weighting
#' if ("survey_weight" %in% names(polymod$participants)) {
#'   polymod |> weigh("survey_weight")
#' }
#'
#' # Dayofweek weighting with groups (POLYMOD uses 0 = Sunday, 6 = Saturday)
#' polymod |>
#'   weigh("dayofweek", target = c(5, 2), groups = list(1:5, c(0, 6)))
#'
#' @export
#' @autoglobal
weigh <- function(survey, by, target = NULL, groups = NULL, ...) {
  check_if_contact_survey(survey)

  # `by` feeds a scalar `if` condition and is used as one column name. Reject
  # anything other than exactly one non-missing string up front. Otherwise
  # `if ()` fails with a cryptic "length zero" / "length > 1" error.
  if (!is.character(by) || length(by) != 1L || is.na(by)) {
    cli::cli_abort("{.arg by} must be a single column name (a string).")
  }

  survey <- copy_survey(survey)
  participants <- survey$participants

  if (!by %in% colnames(participants)) {
    cli::cli_abort(
      "Column {.val {by}} not found in participant data."
    )
  }

  # Start from a neutral weight so every method can multiply into it.
  if (!"weight" %in% colnames(participants)) {
    participants[, weight := 1]
  }

  # Dispatch on the shape of `target` / `groups`; see the @description table.
  if (is.null(target) && is.null(groups)) {
    participants <- weigh_direct(participants, by)
  } else if (is.data.frame(target)) {
    participants <- weigh_population(participants, target, ...)
  } else if (!is.null(names(target))) {
    if (!is.null(groups)) {
      cli::cli_warn(
        "{.arg groups} is ignored when {.arg target} is a named vector."
      )
    }
    participants <- weigh_named(participants, by, target)
  } else if (!is.null(groups)) {
    participants <- weigh_grouped(participants, by, target, groups)
  } else {
    cli::cli_abort(
      "Cannot determine weighting method. Provide {.arg groups} with an \\
       unnamed {.arg target}, use a named {.arg target}, or pass a data frame."
    )
  }

  survey$participants <- participants
  survey
}

#' @autoglobal
weigh_direct <- function(participants, by) {
  # The `by` column is itself the multiplicative factor, so only numeric
  # columns are accepted. `weight` is updated in place with data.table's `:=`.
  if (is.numeric(participants[[by]])) {
    participants[, weight := weight * get(by)]
  } else {
    cli::cli_abort(
      "Column {.val {by}} must be numeric for direct weighting \\
       (without {.arg target}). Got {.cls {class(participants[[by]])}}."
    )
  }
  participants
}

#' @autoglobal
weigh_grouped <- function(participants, by, target, groups) {
  # Each group g receives total weight target[g], split evenly over the
  # participants whose `by` value is in groups[[g]].
  if (length(target) != length(groups)) {
    cli::cli_abort(
      "{.arg target} (length {length(target)}) and {.arg groups} \\
       (length {length(groups)}) must have the same length."
    )
  }

  # Index of the group each participant's value belongs to (NA = no group).
  # If a value is listed in several groups, the last one wins because later
  # iterations overwrite earlier assignments.
  col_vals <- participants[[by]]
  group_idx <- rep(NA_integer_, length(col_vals))
  for (g in seq_along(groups)) {
    group_idx[col_vals %in% groups[[g]]] <- g
  }

  group_counts <- tabulate(group_idx, nbins = length(groups))
  empty <- which(group_counts == 0L)
  if (length(empty) > 0L) {
    n_empty <- length(empty)
    # Every `{?}` marker sits before `{.val {empty}}`. cli takes the
    # quantity for `{?}` from the nearest preceding substitution and reads a
    # numeric value as the count itself. A vector of group indices would
    # give the wrong plural, or an error if it has more than one element.
    cli::cli_warn(
      "{n_empty} group{?s} ha{?s/ve} no matching participants; \\
       {?its/their} target weight{?s} will be ignored \\
       (group{?s} {.val {empty}})."
    )
    # Zero the target so empty groups do not inflate the average weight used
    # for unmatched participants below. The count of 1 is defensive: no
    # participant carries these group indices.
    target[empty] <- 0
    group_counts[empty] <- 1L
  }
  unmatched_non_na <- is.na(group_idx) & !is.na(col_vals)
  n_unmatched <- sum(unmatched_non_na)
  if (n_unmatched > 0) {
    cli::cli_warn(
      "{n_unmatched} participant{?s} ha{?s/ve} values in {.val {by}} \\
       that do not match any group; an average weight will be used."
    )
  }

  # Participants without a group (including missing values, which are not
  # warned about) get the average target per participant.
  n_total <- length(col_vals)
  weight_factor <- ifelse(
    is.na(group_idx),
    sum(target) / n_total,
    target[group_idx] / group_counts[group_idx]
  )

  participants[, weight := weight * weight_factor]
  participants
}

#' @autoglobal
weigh_named <- function(participants, by, target) {
  # Compare as character so that, e.g., numeric codes line up with the
  # character names of `target`.
  keys <- as.character(participants[[by]])

  # Distinct non-missing values that have no entry in `target`. Rows holding
  # them end up with an NA weight.
  unmatched <- setdiff(unique(keys[!is.na(keys)]), names(target))
  n_unmatched <- length(unmatched)
  if (n_unmatched > 0) {
    cli::cli_warn(
      "{n_unmatched} value{?s} in column {.val {by}} not found in \\
       {.arg target} names ({.val {unmatched}}); \\
       {?its/their} weight{?s} will be set to {.val NA}."
    )
  }

  # Per-row target (NA when the value has no target entry or is missing) and
  # per-row count of participants sharing that value (NA for missing values).
  row_target <- target[keys]
  row_count <- as.numeric(table(keys)[keys])

  # Each value's target is split evenly across the participants holding it.
  weight_factor <- ifelse(is.na(row_target), NA_real_, row_target / row_count)

  participants[, weight := weight * weight_factor]
  participants
}

#' @autoglobal
weigh_population <- function(participants, target, ...) {
  # Post-stratify participants against a population table by age.
  required_cols <- c("lower.age.limit", "population")
  if (!all(required_cols %in% colnames(target))) {
    cli::cli_abort(
      "Data frame {.arg target} must have columns {.val lower.age.limit} \\
       and {.val population}."
    )
  }

  if (!("part_age" %in% colnames(participants))) {
    cli::cli_abort(
      "Column {.val part_age} not found in participant data. \\
       Run {.fn assign_age_groups} first."
    )
  }

  pop <- data.table(target)

  # Without explicit upper limits, derive them from the participants'
  # `age.group` column via agegroups_to_limits().
  # NOTE(review): `age.group` is read here but never checked for presence —
  # confirm agegroups_to_limits() copes with NULL.
  if (!("upper.age.limit" %in% colnames(pop))) {
    breaks <- agegroups_to_limits(participants$age.group)
    pop <- add_survey_upper_age_limit(survey = pop, age_breaks = breaks)
  }

  # Evaluate the reference population eagerly, then hand it to weight_by_age().
  pop_reference <- survey_pop_reference(pop, ...)
  weight_by_age(participants, pop_reference)
}

Try the socialmixr package in your browser

Any scripts or data that you put into this service are public.

socialmixr documentation built on April 29, 2026, 9:07 a.m.