# R/strategy_eval.R

# Defines functions eval_strategy

# Documented in eval_strategy

#**************************************************************************
#* 
#* Original work Copyright (C) 2016  Antoine Pierucci
#* Modified work Copyright (C) 2017  Matt Wiener
#* Modified work Copyright (C) 2017  Jordan Amdahl
#*
#* This program is free software: you can redistribute it and/or modify
#* it under the terms of the GNU General Public License as published by
#* the Free Software Foundation, either version 3 of the License, or
#* (at your option) any later version.
#*
#* This program is distributed in the hope that it will be useful,
#* but WITHOUT ANY WARRANTY; without even the implied warranty of
#* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#* GNU General Public License for more details.
#*
#* You should have received a copy of the GNU General Public License
#* along with this program.  If not, see <http://www.gnu.org/licenses/>.
#**************************************************************************

#' Evaluate Strategy
#' 
#' Given an unevaluated strategy, an initial number of 
#' individuals and a number of cycles to compute, returns the 
#' evaluated version of the objects and the count of 
#' individuals per state per model cycle.
#' 
#' `init` need not be integer. E.g. `c(A = 1, B = 0.5, C =
#' 0.1, ...)`.
#' 
#' @param strategy An `uneval_strategy` object.
#' @param parameters Optional. An object generated by 
#'   [define_parameters()].
#' @param cycles positive integer. Number of Markov Cycles 
#'   to compute.
#' @param init numeric vector, same length as number of 
#'   model states. Number of individuals in each model state
#'   at the beginning.
#' @param method Counting method.
#' @param expand_limit A named vector of state expansion 
#'   limits.
#' @param inflow Numeric vector, similar to `init`. Number
#'   of new individuals in each state per cycle.
#' @param strategy_name Name of the strategy.
#' @param aux_params Optional object parameters. They are
#'   handed to `eval_obj_parameters()` together with the
#'   evaluated parameters.
#' @param disc_method Passed unchanged to
#'   `eval_parameters()` and `eval_state_list()` (presumably
#'   the discounting method; confirm against those
#'   functions). Defaults to `'start'`.
#' @param progress_reporter Object exposing a
#'   `report_progress()` method. It is called with `1L`
#'   after each evaluation stage; each call is wrapped in
#'   `try()` so a failing reporter does not abort the
#'   evaluation.
#' @param state_groups Optional table with columns `name`,
#'   `group` and `share`. States that are not listed are
#'   put in a group of their own with `share` set to 0.
#'   When `NULL`, every state forms its own group.
#' @param individual_level Logical flag forwarded to
#'   `eval_init()`.
#'   
#' @return An `eval_strategy` object: a list with elements
#'   `parameters`, `transition`, `states`, `counts`,
#'   `counts_uncorrected`, `values`, `e_init`, `e_inflow`,
#'   `n_indiv`, `cycles` and `expand_limit`.
#'   
#' @example inst/examples/example_eval_strategy.R
#'   
#' @keywords internal
eval_strategy <- function(strategy, parameters, cycles, 
                          init, method, expand_limit,
                          inflow, strategy_name, aux_params = NULL,
                          disc_method = 'start', progress_reporter = create_null_prog_reporter(),
                          state_groups = NULL, individual_level = FALSE) {
  
  # Column names used through non-standard evaluation below are bound
  # locally (presumably to keep R CMD check quiet about undefined globals).
  .state <- .full_state <- .expand <- NULL
  
  # Check the length first so the comparison below is a scalar test
  stopifnot(
    length(cycles) == 1,
    cycles > 0
  )
  
  # Extract and count states
  states <- get_states(strategy)
  n_states <- length(states)
  
  # Extract transitions
  transitions <- get_transition(strategy)
  
  if (!inherits(transitions, 'uneval_matrix')) {
    # No need for interpolation to figure out there is no state_time in a PSM
    to_expand <- rep(FALSE, n_states)
  } else {
  
    # Check for any references to state time
    params_st <- has_state_time.state_transition(
      parameters[!names(parameters) %in% c('state_time', 'state_day', 'state_week', 'state_month', 'state_year')]
    )
    mat_st <- has_state_time(transitions)
    state_st <- has_state_time(states)
    
    # `if` needs a single logical, so combine with the scalar operator
    if (any(params_st) || any(mat_st) || any(state_st)) {
      # Interpolate to determine propagation of state_time
      to_expand <- get_states_to_expand(parameters, states, transitions)
    } else {
      # No need for interpolation to figure out there is no state_time if there are
      # no references to state_time
      to_expand <- rep(FALSE, n_states)
    }
  }
  
  state_names <- get_state_names(strategy)
  
  # Handle state groups
  if (is.null(state_groups)) {
    # Default: every state is a group of its own, nothing is shared
    state_groups <- tibble(
      name = state_names,
      group = state_names,
      share = FALSE
    )
  } else {
    # Complete the user-supplied table with defaults for unlisted states
    state_groups <- rbind(
      tibble(
        name = state_names,
        group = state_names,
        share = 0
      ) %>%
        filter(!(name %in% state_groups$name)),
      select(state_groups, name, group, share)
    )
  }
  
  # Build table to determine number of tunnels for each state
  if (any(to_expand)) {
    expand_table <- tibble::tibble(
      .state = attr(states, "names"),
      .expand = to_expand
    ) %>%
    left_join(
      select(state_groups, name, group, share),
      by = c('.state' = 'name')
    ) %>%
    group_by(group, share) %>%
    mutate(
      share = ifelse(is.na(share), FALSE, as.logical(share)),
      # A state is also expanded when its group shares state time
      # and the group has more than one member
      .expand = .expand | (any(share) && n() > 1)
    ) %>%
    ungroup() %>%
    mutate(
      .limit = ifelse(.expand, expand_limit, 1)
    ) %>%
    plyr::ddply(
      ".state",
      function(st) {
        # One row per tunnel for expanded states, a single row otherwise
        full_names <- if (st$.expand) {
          paste0(".", st$.state, "_", seq_len(st$.limit))
        } else {
          st$.state
        }
        tibble::tibble(
          state_time = seq_len(st$.limit),
          .limit = st$.limit,
          .full_state = full_names,
          .expand = st$.expand
        )
      }
    ) %>%
    # Restore the original state order, which ddply does not preserve
    mutate(.state = factor(.state, levels = attr(states, "names"))) %>%
    arrange(.state, state_time) %>%
    mutate(
      .state = as.character(.state),
      .full_state = as.character(.full_state)
    )
  } else {
    st_name_vec <- attr(states, "names")
    expand_table <- tibble::tibble(
      .state = st_name_vec,
      .full_state = st_name_vec,
      state_time = 1,
      .expand = to_expand,
      .limit = 1
    )
  }
  
  # Inform user about state expansion
  if (any(expand_table$.expand)) {
    expanded <- expand_table %>%
      filter(.expand) %>%
      distinct(.state)
    message(
      sprintf(
        "%s: detected use of 'state_time', expanding state%s: %s.",
        strategy_name,
        plur(length(expanded$.state)),
        paste(expanded$.state, collapse = ", ")
      )
    )
  }
  try(progress_reporter$report_progress(1L))
  
  # Evaluate parameters
  e_parameters <- eval_parameters(
    parameters,
    cycles = cycles,
    strategy_name = strategy_name,
    max_state_time = max(expand_table$.limit),
    disc_method = disc_method
  )
  try(progress_reporter$report_progress(1L))
  
  # Evaluate object parameters.  Doesn't need to
  # be returned since it modifies via reference
  # environment other lazy expression already
  # have
  eval_obj_parameters(
    aux_params,
    e_parameters
  )
  try(progress_reporter$report_progress(1L))
  
  # Evaluate Initial State Values
  e_start_values <- eval_starting_values(
    strategy$starting_values,
    e_parameters
  )
  try(progress_reporter$report_progress(1L))
  
  # Evaluate Initial Counts
  e_init <- eval_init(
    init,
    e_parameters,
    expand_table,
    individual_level = individual_level
  )
  try(progress_reporter$report_progress(1L))
  
  # Inflow (now includes init)
  e_inflow <- eval_inflow(
    inflow,
    e_parameters,
    expand_table
  )
  try(progress_reporter$report_progress(1L))
  
  # Evaluate States
  e_states <- eval_state_list(
    get_states(strategy),
    e_parameters,
    expand_table,
    disc_method = disc_method
  )
  try(progress_reporter$report_progress(1L))
  
  # Evaluate Transitions
  e_transition <- eval_transition(
    get_transition(strategy),
    e_parameters,
    expand_table,
    state_groups = state_groups
  )
  try(progress_reporter$report_progress(1L))
  
  # Compute counts
  count_table_uncorrected <- compute_counts(
    x = e_transition,
    init = e_init,
    inflow = e_inflow
  )
  
  count_table <- correct_counts(count_table_uncorrected, method = method)
  
  # Compute values
  values <- compute_values(
    states = e_states,
    counts = count_table,
    init = e_init,
    inflow = e_inflow,
    starting = e_start_values
  )
  try(progress_reporter$report_progress(1L))
  
  # Get counts of individuals
  n_indiv <- sum(e_inflow) + sum(e_init)
  
  # Both aggregations below collapse the expanded (tunnel) columns back to
  # one column per model state, keeping states in order of first appearance.
  expand_by_state <- expand_table %>%
    mutate(.state = factor(.state, unique(.state)))
  
  # Aggregate corrected counts over states
  count_table_agg <- plyr::dlply(
    expand_by_state,
    ".state",
    function(st) unname(rowSums(count_table[st$.full_state]))
  ) %>%
    do.call(tibble::tibble, .)
  
  # Aggregate uncorrected counts over states
  count_table_agg_uncorrected <- plyr::dlply(
    expand_by_state,
    ".state",
    function(st) rowSums(count_table_uncorrected[st$.full_state])
  ) %>%
    do.call(tibble::tibble, .)
  
  try(progress_reporter$report_progress(1L))
  structure(
    list(
      parameters = e_parameters,
      transition = e_transition,
      states = e_states,
      counts = count_table_agg,
      counts_uncorrected = count_table_agg_uncorrected,
      values = values,
      e_init = e_init,
      e_inflow = e_inflow,
      n_indiv = n_indiv,
      cycles = cycles,
      expand_limit = expand_limit
    ),
    class = c("eval_strategy")
  )
}

# S3 generic: dispatches on the class of `x` to return the evaluated
# initial counts stored in an evaluated object.
get_eval_init <- function(x) UseMethod("get_eval_init")

# Method for `eval_strategy` objects: returns the `e_init` element.
get_eval_init.eval_strategy <- function(x) {
  init_counts <- x$e_init
  init_counts
}

# S3 generic: dispatches on the class of `x` to return the evaluated
# inflow stored in an evaluated object.
get_eval_inflow <- function(x) UseMethod("get_eval_inflow")

# Method for `eval_strategy` objects: returns the `e_inflow` element.
get_eval_inflow.eval_strategy <- function(x) {
  inflow_counts <- x$e_inflow
  inflow_counts
}

# S3 generic: dispatches on the class of `x` to return the number of
# individuals stored in an evaluated object.
get_n_indiv <- function(x) UseMethod("get_n_indiv")

# Method for `eval_strategy` objects: returns the `n_indiv` element.
get_n_indiv.eval_strategy <- function(x) {
  total_individuals <- x$n_indiv
  total_individuals
}

#' Compute Count of Individuals in Each State per Cycle
#' 
#' Given an initial number of individuals and an evaluated 
#' transition object, returns the number of individuals per 
#' state per cycle.
#' 
#' This is an S3 generic; the work is done by the method
#' selected from the class of `x`. The generic itself takes
#' no counting-method argument.
#' 
#' @param x An `eval_matrix` or
#'   `eval_part_surv` object.
#' @param init numeric vector, same length as number of 
#'   model states. Number of individuals in each model state
#'   at the beginning.
#' @param inflow numeric vector, similar to `init`.
#'   Number of new individuals in each state per cycle.
#' @param ... Arguments handed on to the selected method
#'   (this is how `init` and `inflow` reach it).
#'   
#' @return A `cycle_counts` object.
#'   
#' @keywords internal
compute_counts <- function(x, ...) UseMethod("compute_counts")

# Counts of individuals per state per cycle for a list of per-cycle
# transition matrices stored in sparse form.
#
# x       list-like object with one transition matrix per cycle.
# init    numeric vector of initial counts, one entry per state.
# inflow  table with one row per cycle and one column per state.
#
# Returns a tibble with one row per time point (the initial state followed
# by one row per cycle) and one column per state. It has class
# "cycle_counts_sparse"/"cycle_counts" and carries the per-cycle flow
# matrices in the "transitions" attribute.
#' @export
compute_counts.eval_sparse_matrix <- function(x, init, inflow, ...) {
  
  n_state <- get_matrix_order(x)
  n_cycle <- length(x)
  state_names <- get_state_names(x)
  
  if (ncol(inflow) != n_state) {
    stop(sprintf(
      "Number of columns of 'inflow' matrix (%i) differs from the number of states (%i).",
      ncol(inflow),
      n_state
    ))
  }
  
  # Make a diagonal matrix of the initial state vector. The dimensions are
  # given explicitly because diag() called with a single number builds an
  # identity matrix of that size instead of a 1 x 1 matrix holding the value,
  # which would break single-state models.
  init_mat <- diag(init, nrow = n_state, ncol = n_state)
  
  # Do element-wise multiplication to get the numbers
  # undergoing each transition
  uncond_trans <- vector(mode = 'list', length = n_cycle + 1)
  uncond_trans[[1]] <- init_mat
  trace_mat <- matrix(nrow = n_cycle + 1, ncol = n_state)
  colnames(trace_mat) <- state_names
  trace_mat[1, ] <- init
  for (i in seq_len(n_cycle)) {
    # Column sums of the previous flow matrix are the counts currently in
    # each state.
    # NOTE(review): the count vector is recycled down the columns, so entry
    # [r, c] is counts[r] + inflow[r] * (r == c); inflow therefore only
    # enters through the diagonal (stay-in-state) transition. Confirm this
    # is intended rather than (counts + inflow) applied to the whole row.
    current_counts <- colSums(as.matrix(uncond_trans[[i]]))
    inflow_mat <- diag(unlist(inflow[i, ]), nrow = n_state, ncol = n_state)
    mat <- (current_counts + inflow_mat) * as.matrix(x[[i]])
    uncond_trans[[i + 1]] <- as_sparse_matrix(mat)
    trace_mat[i + 1, ] <- colSums(mat)
  }
  
  trace_df <- as_tibble(as.data.frame(trace_mat))
  
  structure(
    trace_df,
    class = c("cycle_counts_sparse", "cycle_counts", class(trace_df)),
    # Drop the initial-state matrix: only per-cycle flows are exposed
    transitions = uncond_trans[-1]
  )
}


# Counts of individuals per state per cycle for a list of dense per-cycle
# transition matrices. Returns a tibble with one row per time point (the
# initial state followed by one row per cycle) and one column per state.
# It has class "cycle_counts" and carries the per-cycle flow array in the
# "transitions" attribute.
#' @export
compute_counts.eval_matrix <- function(x, init, inflow, ...) {
  
  order_n <- get_matrix_order(x)
  cycles_n <- length(x)
  st_names <- get_state_names(x)
  
  if (ncol(inflow) != order_n) {
    stop(sprintf(
      "Number of columns of 'inflow' matrix (%i) differs from the number of states (%i).",
      ncol(inflow),
      order_n
    ))
  }
  
  # flows[from, to, k] holds the number of individuals moving from one state
  # to another; slice 1 is the initial state laid out on the diagonal.
  flows <- array(
    0,
    dim = c(order_n, order_n, cycles_n + 1),
    dimnames = list(st_names, st_names, NULL)
  )
  flows[, , 1] <- diag(init, nrow = order_n, ncol = order_n)
  
  for (k in seq_len(cycles_n)) {
    # Re-wrap the slice as a matrix so a single-state model does not
    # collapse to a bare number.
    previous <- matrix(flows[, , k], nrow = order_n, ncol = order_n)
    entering <- diag(unlist(inflow[k, ]), nrow = order_n, ncol = order_n)
    # NOTE(review): the count vector is recycled down the columns, so entry
    # [r, c] is counts[r] + inflow[r] * (r == c); inflow therefore only
    # enters through the diagonal (stay-in-state) transition. Confirm this
    # is intended.
    flows[, , k + 1] <- (colSums(previous) + entering) * x[[k]]
  }
  
  # Summing over the "from" dimension gives states x time points; transpose
  # so that rows are time points.
  trace <- t(colSums(flows, dims = 1))
  
  counts_tbl <- as_tibble(as.data.frame(trace))
  colnames(counts_tbl) <- st_names
  
  structure(
    counts_tbl,
    class = c("cycle_counts", class(counts_tbl)),
    # Drop the initial-state slice: only per-cycle flows are exposed
    transitions = flows[, , -1, drop = FALSE]
  )
}

#' Compute State Values per Cycle
#' 
#' Given states and counts, computes the total state values 
#' per cycle: for every cycle and every state value, the 
#' per-state values are multiplied by the per-state counts 
#' and summed over states. Starting values and, when present,
#' values attached to transitions are added on top.
#' 
#' @param states An object of class `eval_state_list`.
#' @param counts An object of class `cycle_counts`.
#' @param init Initial counts. Their sum is multiplied by the
#'   first row of `starting` and added to the first cycle.
#' @param inflow Inflow table. Its row sums are multiplied 
#'   by `starting` and added to the weighted sums.
#' @param starting Starting values, indexed as a 
#'   two-dimensional object (presumably one row per cycle 
#'   and one column per state value; confirm against 
#'   `eval_starting_values()`).
#'   
#' @return A data.frame of state values, one column per 
#'   state value and one row per cycle, preceded by a 
#'   `markov_cycle` column.
#'   
#' @keywords internal
## slightly harder to read than the original version, but much faster
## identical results to within a little bit of numerical noise
compute_values <- function(states, counts, init, inflow, starting) {
  states_names <- get_state_names(states)
  state_values_names <- get_state_value_names(states)
  
  n_states <- length(states_names)
  # NOTE(review): n_state_vals is not used anywhere below
  n_state_vals <- length(state_values_names)
  num_cycles <- nrow(counts)
  
  ## combine the list of states into a single large array
  # Target layout: cycles x state values x states
  dims_array_1 <- c(
    num_cycles,
    length(state_values_names),
    length(states_names))
  
  # One extra slot in the second dimension for the markov_cycle column
  # that each state table carries alongside its values
  dims_array_2 <- dims_array_1 + c(0, 1, 0)
  
  # Flatten every column of every state into one vector and reshape it;
  # this assumes each state has num_cycles rows and the same columns
  # in the same order -- TODO confirm against eval_state_list()
  state_val_array <- array(unlist(states, TRUE, FALSE), dim = dims_array_2)
  
  ## get rid of markov_cycle
  mc_col <- match("markov_cycle", names(states[[1]]))
  state_val_array <- state_val_array[, -mc_col, , drop = FALSE]
  
  ## put counts into a similar large array
  # The counts (cycles x states) are recycled by array() so that every
  # state value gets its own copy, laid out as cycles x states x values...
  counts_mat <- array(unlist(counts[, states_names], TRUE, FALSE),
                      dim = dims_array_1[c(1, 3, 2)])
  # ...then permuted to cycles x values x states to line up with
  # state_val_array
  counts_mat <- aperm(counts_mat, c(1, 3, 2))
  
  # multiply, sum, and add markov_cycle back in
  vals_x_counts <- state_val_array * counts_mat
  # rowSums(dims = 2) sums over the third (state) dimension, leaving a
  # cycles x values matrix; starting values weighted by inflow are added
  wtd_sums <- rowSums(vals_x_counts, dims = 2) + starting * rowSums(inflow)
  # The initial population receives the starting values in the first row only
  wtd_sums[1, ] <- wtd_sums[1, ] + sum(init) * starting[1, ]
  
  # Handle transitional costs
  if(!is.null(attr(states, "transitions"))) {
    
    if ("cycle_counts_sparse" %in% class(counts)) {
    
      # Sparse counts keep one flow matrix per cycle in a list, so rows are
      # grouped by cycle and the matching matrix is picked by position
      trans_values_df <- attr(states, "transitions") %>%
        group_by(markov_cycle) %>%
        mutate(
          # presumably factors whose integer codes are the row and column
          # positions in the flow matrix -- TODO confirm
          .dim1 = as.numeric(.from_name_expanded),
          .dim2 = as.numeric(.to_name_expanded),
          # column-major linear index into an n_states x n_states matrix
          .index = .dim1 + ((.dim2 - 1) * n_states),
          .product = value * as.numeric(attr(counts, "transitions")[[markov_cycle[1]]])[.index]
        ) %>%
        group_by(markov_cycle, variable) %>%
        summarize(value = sum(.product))
      
    } else {
      
      # Dense counts keep all flows in one n_states x n_states x cycles
      # array, so the cycle becomes the third index
      trans_values_df <- attr(states, "transitions") %>%
        mutate(
          .dim1 = as.numeric(.from_name_expanded),
          .dim2 = as.numeric(.to_name_expanded),
          .dim3 = as.numeric(markov_cycle),
          # column-major linear index into the three-dimensional array
          .index = .dim1 + ((.dim2 - 1) * n_states) + ((.dim3 - 1) * (n_states ^ 2)),
          .product = value * as.numeric(attr(counts, "transitions"))[.index]
        ) %>%
        group_by(markov_cycle, variable) %>%
        summarize(value = sum(.product))
    }
      
    # Reshape to a cycles x variables matrix
    trans_values <- reshape2::acast(trans_values_df, markov_cycle~variable, value.var = "value")
    
    # NOTE(review): this element-wise addition relies on trans_values having
    # the same shape and column order as wtd_sums -- confirm
    wtd_sums <- wtd_sums + trans_values
    
  }
  
  res <- data.frame(markov_cycle = states[[1]]$markov_cycle, wtd_sums)
  names(res)[-1] <- state_values_names
  
  res
}
# PolicyAnalysisInc/heRoMod documentation built on March 23, 2024, 4:29 p.m.