R/strategy_eval.R

Defines functions expand_if_necessary compute_values compute_counts.eval_matrix compute_counts get_n_indiv.eval_strategy get_n_indiv get_eval_inflow.eval_strategy get_eval_inflow get_eval_init.eval_strategy get_eval_init eval_strategy

Documented in compute_counts compute_values eval_strategy expand_if_necessary

#' Evaluate Strategy
#' 
#' Given an unevaluated strategy, an initial number of 
#' individuals and a number of cycles to compute, returns the 
#' evaluated version of the objects and the count of 
#' individuals per state per model cycle.
#' 
#' States whose transitions or values use `state_time` are
#' first expanded (see [expand_if_necessary()]); the counts
#' of their sub-states are summed back into one column per
#' original state before returning.
#' 
#' `init` need not be integer. E.g. `c(A = 1, B = 0.5, C =
#' 0.1, ...)`.
#' 
#' @param strategy An `uneval_strategy` object.
#' @param parameters Optional. An object generated by 
#'   [define_parameters()].
#' @param cycles positive integer. Number of Markov Cycles 
#'   to compute.
#' @param init numeric vector, same length as number of 
#'   model states. Number of individuals in each model state
#'   at the beginning.
#' @param method Counting method, passed to `correct_counts()`.
#' @param expand_limit A named vector of state expansion 
#'   limits.
#' @param inflow Numeric vector, similar to `init`. Number
#'   of new individuals in each state per cycle.
#' @param strategy_name Name of the strategy.
#'   
#' @return An `eval_strategy` object: a list with elements
#'   `parameters`, `complete_parameters`, `transition`,
#'   `states`, `counts`, `values`, `e_init`, `e_inflow`,
#'   `n_indiv`, `cycles` and `expand_limit`.
#'   
#' @example inst/examples/example_eval_strategy.R
#'   
#' @keywords internal
eval_strategy <- function(strategy, parameters, cycles, 
                          init, method, expand_limit,
                          inflow, strategy_name) {
  
  # check the length first so that vector input gets a clear message
  stopifnot(
    length(cycles) == 1,
    cycles > 0
  )
  
  ## expand states if necessary, and retrieve values.   
  ##  If no expansion, then it returns the same values
  expanded <- expand_if_necessary(
    strategy      = strategy,
    parameters    = parameters, 
    cycles        = cycles,
    init          = init,
    method        = method,
    expand_limit  = expand_limit,
    inflow        = inflow,
    strategy_name = strategy_name)
  
  uneval_states <- expanded$uneval_states
  uneval_transition <- expanded$uneval_transition
  init <- expanded$init
  inflow <- expanded$inflow
  strategy_starting_values <- expanded$starting_values
  n_indiv <- expanded$n_indiv
  parameters <- expanded$parameters
  actually_expanded_something <- expanded$actually_expanded_something
  
  states <- eval_state_list(uneval_states, parameters)
  
  transition <- eval_transition(uneval_transition,
                                parameters)
  
  # raw counts, then corrected according to the counting method
  count_list <- compute_counts(
    x = transition,
    init = init,
    inflow = inflow
  ) %>% 
    correct_counts(method = method)
  
  values <- compute_values(states, count_list, strategy_starting_values)
  
  count_table <- count_list$counts
  
  if (actually_expanded_something) {
    # Collapse the sub-state columns of each expanded state into a
    # single column named after the original state.
    for (st in expanded$expanded_states) {
      sub_cols <- expanded$expansion_cols[[st]]
      count_table[[st]] <- rowSums(count_table[sub_cols])
      # Logical mask rather than `-which()`: `-which()` would select
      # no columns at all if nothing matched.
      count_table <- count_table[! names(count_table) %in% sub_cols]
    }
  }
  
  structure(
    list(
      parameters = parameters,
      complete_parameters = expanded$complete_parameters,
      transition = transition,
      states = states,
      counts = count_table,
      values = values,
      e_init = init,
      e_inflow = inflow,
      n_indiv = n_indiv,
      cycles = cycles,
      expand_limit = expand_limit
    ),
    class = c("eval_strategy")
  )
}

# Accessors for components of an `eval_strategy` object ------------------

# Evaluated initial counts (stored as `e_init` by eval_strategy()).
get_eval_init <- function(x) UseMethod("get_eval_init")

get_eval_init.eval_strategy <- function(x) x$e_init

# Evaluated inflow (stored as `e_inflow` by eval_strategy()).
get_eval_inflow <- function(x) UseMethod("get_eval_inflow")

get_eval_inflow.eval_strategy <- function(x) x$e_inflow

# Total number of individuals: initial counts plus all inflows
# (stored as `n_indiv` by eval_strategy()).
get_n_indiv <- function(x) UseMethod("get_n_indiv")

get_n_indiv.eval_strategy <- function(x) x$n_indiv

#' Compute Count of Individual in Each State per Cycle
#' 
#' Given an initial number of individuals and an evaluated 
#' transition object, returns the number of individuals per 
#' state per cycle.
#' 
#' No counting-method correction is applied here: whether
#' transitions happen at the beginning or the end of each cycle
#' (or are interpolated) is handled afterwards by
#' `correct_counts()`, which [eval_strategy()] applies to the
#' result.
#' 
#' @param x An `eval_matrix` or
#'   `eval_part_surv` object.
#' @param ... Arguments passed to methods, such as `init` and
#'   `inflow` below.
#' @param init numeric vector, same length as number of 
#'   model states. Number of individuals in each model state
#'   at the beginning.
#' @param inflow numeric vector, similar to `init`.
#'   Number of new individuals in each state per cycle.
#'   
#' @return For `eval_matrix` objects, a list with elements
#'   `counts` (a tibble with one column per state) and `diff`
#'   (a list of per-cycle matrices). Other methods may differ;
#'   [compute_values()] treats a `NULL` `diff` as coming from a
#'   partitioned survival model.
#'   
#' @keywords internal
compute_counts <- function(x, ...) {
  UseMethod("compute_counts")
}

#' @export
compute_counts.eval_matrix <- function(x, init, inflow, ...) {
  
  n_states <- get_matrix_order(x)
  
  # `init` and `inflow` must give one value per state of the matrix
  if (length(init) != n_states) {
    stop(sprintf(
      "Length of 'init' vector (%i) differs from the number of states (%i).",
      length(init),
      n_states
    ))
  }
  
  if (length(inflow) != n_states) {
    stop(sprintf(
      "Length of 'inflow' vector (%i) differs from the number of states (%i).",
      length(inflow),
      n_states
    ))
  }
  
  counts_and_diff <- get_counts_diff(x, init, inflow)
  
  # first component of each entry: the per-state count vector
  list_counts <- lapply(counts_and_diff, `[[`, 1)
  
  res <- dplyr::as_tibble(
    as.data.frame(
      matrix(
        unlist(list_counts),
        byrow = TRUE,
        ncol = n_states
      )
    )
  )
  
  colnames(res) <- get_state_names(x)
  
  # The counts tibble is returned as-is (no `cycle_counts` class is
  # added). `diff` keeps the second component of each entry, except
  # position `length(x) + 1`.
  # NOTE(review): presumably that is the final entry, which has no
  # following transition; confirm against get_counts_diff().
  list(
    counts = res,
    diff = lapply(counts_and_diff[-(length(x) + 1)], `[[`, 2)
  )
}

#' Compute State Values per Cycle
#' 
#' Given states and counts, computes the total state values 
#' per cycle.
#' 
#' Each state value is multiplied by the count of individuals in
#' its state and summed over states. Strategy-level starting values
#' are added to the first cycle's state values. State-level starting
#' values are weighted by the individuals entering each state, as
#' derived from `count_list$diff`.
#' 
#' @param states An object of class `eval_state_list`.
#' @param count_list A list with elements `counts` (counts per
#'   state and cycle) and `diff`, as returned by [compute_counts()]
#'   and then `correct_counts()`. Its `method` attribute is read.
#'   A `NULL` `diff` is treated as coming from a partitioned
#'   survival model.
#' @param strategy_starting_values Numeric vector of evaluated
#'   strategy-level starting values, presumably one per state
#'   value.
#'   
#' @return A data.frame with a `markov_cycle` column followed by
#'   one column per state value, and one row per cycle.
#'   
#' @keywords internal
## slightly harder to read than the original version, but much faster
## identical results to within a little bit of numerical noise
compute_values <- function(states, count_list, strategy_starting_values) {
  
  counts <- count_list$counts
  diff <- count_list$diff
  
  # counting method recorded on the count list; only "beginning" is
  # treated specially below
  method <- attr(count_list, "method")
  
  states_values <- structure(
    states$.dots, class = class(states)
  )
  states_starting <- states$starting_values
  states_names <- get_state_names(states)
  state_values_names <- get_state_value_names(states)
  num_cycles <- nrow(counts)
  num_states <- length(states_names)
  num_state_values <- length(state_values_names)

  ## combine the list of states into a single large array:
  ## cycle x state value x state
  dims_array_1 <- c(
    num_cycles,
    num_state_values,
    num_states)
  
  ## same, with one extra state-value slot for the markov_cycle column
  dims_array_2 <- dims_array_1 + c(0, 1, 0)
  
  state_val_array <- array(unlist(states_values), dim = dims_array_2)
  start_val_array <- array(unlist(states_starting), dim = dims_array_2)

  ## get rid of markov_cycle
  mc_col <- match("markov_cycle", names(states_values[[1]]))
  state_val_array <- state_val_array[, -mc_col, , drop = FALSE]
  start_val_array <- start_val_array[, -mc_col, , drop = FALSE]

  ## put counts into a similar large array: counts are recycled across
  ## state values, then permuted to cycle x state value x state
  counts_mat <- array(unlist(counts[, states_names]),
                      dim = dims_array_1[c(1, 3, 2)])
  counts_mat <- aperm(counts_mat, c(1, 3, 2))

  ## strategy starting values on the first cycle, zero afterwards,
  ## recycled across states
  starting_fill_zero <- c(strategy_starting_values, 
                          rep(0, num_state_values * (num_cycles - 1))) %>%
    matrix(nrow = num_cycles, byrow = TRUE) %>%
    array(dim = dims_array_1)
  
  # `&&`: scalar condition, and the sum is only needed without `diff`
  if (is.null(diff) && sum(start_val_array) > 0) {
    warning(
      "Partitioned survival models cannot take into account ",
      "the starting values in define_state()"
    )
  }
  
  ## individuals entering each state per cycle: column sums of each
  ## diff matrix with its diagonal zeroed (presumably individuals
  ## staying in the same state)
  new_starting_states <- lapply(diff, function(x){
    diag(x) <- 0
    colSums(x)
  }) %>% 
    unlist() 
  
  
  new_starting_states <- if (length(new_starting_states) > 0) {
    m <- matrix(new_starting_states, nrow = num_cycles, byrow = TRUE)
    if (method == "beginning") {
      # shift by one cycle, the first row being the initial counts;
      # drop = FALSE keeps a matrix when there is a single state
      m <- rbind(
        as.numeric(counts[1, ]),
        m[seq_len(num_cycles - 1), , drop = FALSE]
      )
    } else {
      m[1, ] <- as.numeric(counts[1, ])
    }
    m
  } else {
    matrix(rep(0, num_states), ncol = num_states)
  }
  
  ## state-level starting values weighted by entering individuals
  start_x_counts <- lapply(seq_len(num_states), function(i){
    new_starting_states[, i] * start_val_array[, , i]
  }) %>% 
    unlist() %>%
    array(dim = dims_array_1)
  
  # multiply, sum, add starting values and add markov_cycle back in
  vals_x_counts <- (state_val_array + starting_fill_zero) * counts_mat 
  wtd_sums <- rowSums(vals_x_counts, dims = 2) + rowSums(start_x_counts, dims = 2)
  res <- data.frame(markov_cycle = states_values[[1]]$markov_cycle, wtd_sums)

  names(res)[-1] <- state_values_names
  
  res
  
}

#' Expand States and Transition
#' 
#' Interpolates `parameters` into the transition and state
#' expressions of `strategy` and checks whether any of them use
#' `state_time`. If so, the affected states are expanded with
#' `expand_state()` (up to `expand_limit[st]` cycles per state),
#' together with `init` and `inflow`. In all cases the parameters,
#' `init`, `inflow` and strategy starting values are evaluated.
#' 
#' @inherit eval_strategy
#' @keywords internal
#'   
#' @return A list with elements `uneval_transition` and
#'   `uneval_states` (interpolated and expanded if needed,
#'   otherwise as stored in `strategy`), `init` and `inflow`
#'   (evaluated, expanded if needed), `starting_values`
#'   (evaluated strategy starting values), `n_indiv`,
#'   `parameters` (evaluated), `complete_parameters`,
#'   `actually_expanded_something` (logical), `expanded_states`
#'   (names of the expanded states, or `NULL`) and
#'   `expansion_cols` (sub-state column names per expanded
#'   state).
#'   
expand_if_necessary <- function(strategy, parameters, 
                                cycles, init, method,
                                expand_limit, inflow,
                                strategy_name) {
  uneval_transition <- get_transition(strategy)
  uneval_states <- get_states(strategy)
  # names of the states to expand; stays NULL when nothing is expanded
  to_expand <- NULL
  
  # Interpolate parameters into the transition and state expressions,
  # presumably so that uses of `state_time` made through a parameter
  # are also detected.
  i_parameters <- interpolate(parameters)
  
  i_uneval_transition <- interpolate(uneval_transition,
                                     more = as_expr_list(i_parameters))
  
  i_uneval_states <- interpolate(uneval_states,
                                 more = as_expr_list(i_parameters))
  
  
  # per-cell flags for the transition, per-state flags for the states
  td_tm <- has_state_time(i_uneval_transition)
  
  td_st <- has_state_time(i_uneval_states)
  
  # expand only if at least one transition cell or state uses
  # `state_time`
  expand <- any(c(td_tm, td_st))
  
  # keep a copy: `parameters` is replaced by an empty set when expanding
  old_parameters <- parameters
  
  if (expand) {
    # `state_time` is not supported for partitioned survival models
    if (inherits(uneval_transition, "part_surv")) {
      stop("Cannot use 'state_time' with partitionned survival.")
    }
    
    # continue with the interpolated expressions
    uneval_transition <- i_uneval_transition
    uneval_states <- i_uneval_states
    
    # parameters not needed anymore because of interp: their
    # expressions were interpolated above, so an empty set is evaluated
    parameters <- define_parameters()
    
    # Collapse the per-cell flags into one flag per matrix row (TRUE
    # if any cell of that row uses `state_time`).
    # NOTE(review): assumes cells are listed row-wise, i.e. by origin
    # state; confirm against has_state_time().
    td_tm <- td_tm %>%
      matrix(nrow = get_matrix_order(uneval_transition),
             byrow = TRUE) %>%
      apply(1, any)
    
    # states to expand: flagged by the transition and/or state values
    to_expand <- sort(unique(c(
      get_state_names(uneval_transition)[td_tm],
      get_state_names(uneval_states)[td_st]
    )))
    
    message(
      sprintf(
        "%s: detected use of 'state_time', expanding state%s: %s.",
        strategy_name,
        plur(length(to_expand)),
        paste(to_expand, collapse = ", ")
      )
    )
    
    # expand the init and inflow vectors
    for (st in to_expand) {
      init <- expand_state(init, state_name = st, cycles = expand_limit[st])
      inflow <- expand_state(inflow, state_name = st, cycles = expand_limit[st])
    }
    
    # expand the transition and the states; the state position is
    # looked up on the current (possibly already expanded) transition
    # at each iteration
    for (st in to_expand) {
      uneval_transition <- expand_state(
        x = uneval_transition,
        state_pos = which(get_state_names(uneval_transition) == st),
        state_name = st,
        cycles = expand_limit[st]
      )
      
      uneval_states <- expand_state(x = uneval_states,
                                    state_name = st,
                                    cycles = expand_limit[st])
    }
  }
  
  parameters <- eval_parameters(parameters,
                                cycles = cycles,
                                strategy_name = strategy_name)
  
  # to retain values in case of expansion: the original parameters are
  # evaluated for one cycle with `state_time = 1` prepended
  if (expand) {
    complete_parameters <- eval_parameters(structure(
      c(lazyeval::lazy_dots(state_time = 1),
        old_parameters),
      class = class(old_parameters)
    ),
    cycles = 1,
    strategy_name = strategy_name)
  } else {
    complete_parameters <- parameters[1,]
  }
  
  # init and starting values use the first row of the evaluated
  # parameters; inflow uses all of them
  e_init <- unlist(eval_init(x = init, parameters[1,]))
  e_inflow <- eval_inflow(x = inflow, parameters)
  
  e_starting_values <- unlist(
    eval_starting_values(
      x = strategy$starting_values,
      parameters[1, ])
  )
  # e_starting_values <- 
  #   list(starting_strategy = e_starting_values_strat,
  #        starting_state = lapply(i_uneval_states, function(x){
  #     unlist(eval_starting_values(
  #       x = x$starting_values,
  #       parameters[1, ]
  #     ))
  #   }))

  # total individuals: initial counts plus all inflows
  n_indiv <- sum(e_init, unlist(e_inflow))
  
  if (any(is.na(e_init)) || any(is.na(e_inflow)) || any(is.na(e_starting_values))) {
    stop("Missing values not allowed in 'init', 'inflow' or 'starting values'.")
  }
  
  if (!any(e_init > 0)) {
    stop("At least one init count must be > 0.")
  }
  
  # Sub-state column names for each expanded state (".<state>_1" to
  # ".<state>_<expand_limit + 1>"); eval_strategy() sums these columns
  # back into one.
  # NOTE(review): these must match the names produced by
  # expand_state(); confirm.
  exp_cols <- list()
  for (st in to_expand) {
    exp_cols[[st]] <- sprintf(".%s_%i", st, seq_len(expand_limit[st] + 1))
  }
  
  list(
    uneval_transition = uneval_transition,
    uneval_states = uneval_states,
    init = e_init,
    inflow = e_inflow,
    starting_values = e_starting_values,
    n_indiv = n_indiv,
    parameters = parameters,
    complete_parameters = complete_parameters,
    actually_expanded_something = expand,
    expanded_states = to_expand,
    expansion_cols = exp_cols)
}
pierucci/heemod documentation built on July 17, 2022, 9:27 p.m.