R/simulator.R

#' Class providing the simulation functionality.
#'
#' @description
#' Provides methods for initializing the synthetic population, parallel simulation, calibration and monitoring.
#'
#' @docType class
#' @export
#' @import data.table
#' @import dqrng
#' @importFrom R6 R6Class
Simulator <- R6::R6Class("Simulator",

    public = list(

        ##' @description
        ##' Create a new Simulator object.
        ##' @param initializer A function that generates the initial population.
        ##' @param acc_events A list of AccumulationEvent objects.
        ##' @param man_events A list of ManipulationEvent objects.
        ##' @param seeds Seed values for random number generation.
        ##' @param keep_init A logical value indicating whether the initial status should be stored.
        ##' If \code{TRUE} the initial population can be retrieved with the method get_init.
        ##' @param ... Additional arguments passed to \code{initializer}.
        ##' @return A new \code{Simulator} object.
        initialize = function(initializer = NULL, acc_events = list(), man_events = list(), seeds = 123, keep_init = FALSE, ...) {
            if (!is.function(initializer)) {
                stop("Argument 'initializer' is not a function")
            }
            # Accept NULL (no events of that kind) and a single event object as well as lists.
            as_list <- function(x) {
                if (is.null(x)) list() else if (is.list(x)) x else list(x)
            }
            acc_events <- as_list(acc_events)
            man_events <- as_list(man_events)
            n_acc <- length(acc_events)
            n_man <- length(man_events)
            if (n_acc == 0L && n_man == 0L) {
                stop("There must be at least one accumulation event or manipulation event")
            }
            if (n_acc > 0L) {
                not_acc <- which(!is_event(acc_events, "AccumulationEvent"))
                if (length(not_acc) > 0L) {
                    # Collapse explicitly: stop() would otherwise run several indices together.
                    stop("Elements at ", paste(not_acc, collapse = ", "), " are not of R6Class 'AccumulationEvent'")
                }
            }
            if (n_man > 0L) {
                not_man <- which(!is_event(man_events, "ManipulationEvent"))
                if (length(not_man) > 0L) {
                    stop("Elements at ", paste(not_man, collapse = ", "), " are not of R6Class 'ManipulationEvent'")
                }
            }
            if (!is.numeric(seeds)) {
                stop("Seed values must be numeric")
            }
            private$status <- initializer(...)
            if (keep_init) {
                # Deep copy: data.table columns are modified by reference, so a plain
                # assignment would alias the live status.
                # NOTE(review): this snapshot is taken before the events' allocate() calls
                # below; if allocate() adds columns to the status, reset() restores a
                # population without them. Confirm this is intended.
                private$status_init <- data.table::copy(private$status)
            }
            private$accumulation_events <- acc_events
            private$manipulation_events <- man_events
            private$settings <- list(cluster_enabled = FALSE, monitor_enabled = FALSE)
            # Flat list of all events, manipulation events first. Each event's id is its
            # position in this list, which lets reconfigure() look events up by name.
            private$unlisted_events <- unlist(list(private$manipulation_events, private$accumulation_events))
            for (i in seq_along(private$unlisted_events)) {
                private$unlisted_events[[i]]$set_id(i)
            }
            private$settings$event_indices <- setNames(
                lapply(private$unlisted_events, function(e) e$get_id()),
                sapply(private$unlisted_events, function(e) e$get_name())
            )
            for (event in private$unlisted_events) {
                event$allocate(private$status)
            }
            private$runtime <- 0
            private$seeds <- as.integer(seeds)
            private$reset_seed <- TRUE
            private$history <- list()
        },

        ##' @description
        ##' Start a parallel computation cluster.
        ##' @param cl A cluster object.
        ##' @param nc Size of the cluster (available computing cores).
        ##' @param export A character vector giving the names of objects from the global environment that should be exported to the workers.
        ##' @param packages A character vector giving the names of packages that should be loaded on the workers.
        ##' @param interface Interface to "foreach", must match the cluster type ("doParallel" and "doMPI" are currently supported).
        start_cluster = function(cl, nc, export, packages, interface = c("doParallel", "doMPI")) {
            interface <- match.arg(interface)
            # A second start would re-split the (already emptied) local status and
            # overwrite the existing chunks, losing the population.
            if (isTRUE(private$settings$cluster_enabled)) {
                stop("A cluster is already active, call 'stop_cluster' first")
            }
            if (!private$parallel_available(interface)) {
                return(invisible(NULL))
            }
            if (missing(cl)) {
                stop("No cluster object 'cl' provided")
            }
            if (missing(nc)) {
                # Use the size of the supplied cluster for both interfaces, so that the
                # population is split into one chunk per worker.
                warning("Number of cores 'nc' not specified, defaulting to the size of the cluster")
                nc <- if (identical(interface, "doParallel")) length(cl) else doMPI::clusterSize(cl)
            }
            nc <- as.integer(nc)
            if (length(nc) != 1L || is.na(nc) || nc < 1L) {
                stop("Number of cores 'nc' must be a positive integer")
            }
            if (missing(export)) {
                export <- character(0)
            }
            private$settings$packages <- if (missing(packages)) character(0) else packages
            private$settings$cluster_enabled <- TRUE
            private$settings$cluster_interface <- interface
            private$settings$cluster <- cl
            private$settings$cluster_ix <- seq_len(nc)
            private$settings$n_cores <- nc
            private$export(vars = c("accumulation_events", "manipulation_events"), envir = private)
            # These functions need to be copied and their evaluation environment set to .GlobalEnv.
            # Otherwise the 'private' environment (which is the original evaluation environment)
            # will be serialized along with the functions to the workers, resulting in excessive use of memory.
            event_funs <- c("step_accumulation_recursive_parallel", "step_manipulation_recursive_parallel",
                            "process_accumulation_events_parallel", "process_manipulation_events_parallel")
            export_env <- private$copy_local(vars = event_funs, envir = private)
            private$export(vars = event_funs, envir = export_env)
            if (length(export) > 0L) {
                export_env <- private$copy_local(vars = export)
                private$export(vars = export, envir = export_env)
            }
            private$split_status()
            message("Cluster ready: using ", nc, " cores")
        },

        ##' @description
        ##' Stop a parallel computation cluster.
        stop_cluster = function() {
            # Nothing to tear down: report and leave all state untouched.
            if (!private$settings$cluster_enabled) {
                message("There is no active cluster")
                return(invisible(NULL))
            }
            cl <- private$settings$cluster
            backend <- private$settings$cluster_interface
            if (identical(backend, "doMPI")) {
                doMPI::closeCluster(cl)
            } else if (identical(backend, "doParallel")) {
                parallel::stopCluster(cl)
            }
            # Gather the worker chunks back into a single local population.
            private$combine_status()
            private$settings$cluster_enabled <- FALSE
            private$settings[c("n_cores", "cluster", "cluster_ix")] <- NULL
            message("Cluster stopped")
        },

        ##' @description
        ##' Progress the simulation.
        ##' @param N An integer giving the number of time steps to simulate.
        ##' @param show_progress A logical value.
        ##' If \code{TRUE}, shows a progress bar that tracks the completion of the simulation
        ##' (for non-parallel computations only).
        run = function(N, show_progress = TRUE) {
            N <- as.integer(N)
            # (runtime + 1):(runtime + N) would count backwards for N < 1 and silently
            # simulate extra steps, so validate N up front.
            if (length(N) != 1L || is.na(N) || N < 0L) {
                stop("Argument 'N' must be a single non-negative integer")
            }
            if (N == 0L) {
                return(invisible(NULL))
            }
            if (private$settings$cluster_enabled) {
                # Parallel computation.
                # Warnings are relayed as messages without aborting the run (a plain
                # tryCatch(warning = ) would unwind and discard the whole computation).
                # Errors are reported as messages and leave the population unchanged.
                tryCatch(withCallingHandlers({
                    # create local copies to avoid exporting 'self' or 'private'.
                    mon <- private$settings$monitor_enabled
                    seeds <- private$seeds
                    reset_seed <- private$reset_seed
                    private$reset_seed <- FALSE
                    runtime <- private$runtime
                    # NULL initialization necessary here due to foreach export
                    mon_ag <- NULL
                    mon_iv <- NULL
                    monitor <- NULL
                    if (mon) {
                        mon_ag <- private$settings$monitor_aggregators
                        mon_iv <- private$settings$monitor_intervals
                        monitor <- private$monitor_factory(mon_ag, mon_iv)
                    }
                    # No '.combine' on purpose. The foreach default returns a list with one
                    # element per chunk. '.combine = list' with '.multicombine' does not
                    # wrap a single result and nests results beyond '.maxcombine' (100) chunks.
                    result <- foreach(
                        dt_sub = private$status_split,
                        # These values may change between runs and need to be exported again.
                        .export = c(
                            "seeds", "reset_seed", "runtime", "N", "mon", "mon_ag", "mon_iv", "monitor"
                        ),
                        # Do not export anything extra to avoid possible conflicts.
                        .noexport = c(ls(globalenv()), ls(environment()), ls(self), ls(private)),
                        # These packages have to be installed on the workers.
                        .packages = c("data.table", "dqrng", private$settings$packages),
                        .verbose = FALSE) %dopar% {
                        # Parallel block
                        key <- as.numeric(dt_sub$key)
                        dt <- dt_sub$value
                        new_hist <- NULL
                        if (reset_seed) {
                            # The chunk key selects the RNG stream (and the seed when several are given).
                            n_seeds <- length(seeds)
                            if (n_seeds > 1) {
                                dqrng::dqset.seed(seeds[key %% n_seeds + 1], key)
                            } else {
                                dqrng::dqset.seed(seeds, key)
                            }
                        }
                        if (mon) {
                            mon_n <- sort(unique(unlist(lapply(mon_iv, function(x) seq(1, N, by = x)))))
                            mon_ix <- 0
                            new_hist <- vector(mode = "list", length = length(mon_n))
                        }
                        for (n in (runtime + 1):(runtime + N)) {
                            dt <- process_accumulation_events_parallel(dt, accumulation_events, parallel = key)
                            process_manipulation_events_parallel(dt, manipulation_events)
                            if (mon) {
                                temp_hist <- monitor(dt, n)
                                if (!is.null(temp_hist)) {
                                    mon_ix <- mon_ix + 1
                                    new_hist[[mon_ix]] <- temp_hist
                                }
                            }
                        }
                        if (mon) {
                            # Keep only the slots the monitor actually filled.
                            new_hist <- new_hist[seq_len(mon_ix)]
                        }
                        return(list(value = dt, key = dt_sub$key, history = new_hist))
                    }
                    private$status_split <- lapply(result, "[", c("value", "key"))
                    if (mon) {
                        new_hist_len <- length(result[[1]]$history)
                        if (new_hist_len > 0L) {
                            combiners <- private$settings$monitor_combiners
                            ag_len <- length(mon_ag)
                            new_history <- lapply(seq_len(new_hist_len), function(t) {
                                worker_hist <- lapply(result, function(x) x$history[[t]])
                                entry <- vector(mode = "list", length = ag_len + 1L)
                                # The time index is taken from the first worker's record.
                                entry[[1]] <- worker_hist[[1]][[1]]
                                for (i in seq_len(ag_len)) {
                                    values <- lapply(worker_hist, function(y) y[[i + 1L]])
                                    # Without combiners, keep the per-worker values as a list.
                                    entry[[i + 1L]] <- if (is.null(combiners)) values else combiners[[i]](values)
                                }
                                entry
                            })
                            private$history <- c(private$history, new_history)
                        }
                    }
                    # Advance the clock so the next call continues at runtime + N + 1.
                    private$runtime <- runtime + N
                }, warning = function(w) {
                    message(conditionMessage(w))
                    invokeRestart("muffleWarning")
                }), error = function(e) {
                    message(conditionMessage(e))
                })
            } else {
                # Sequential computation
                runtime_start <- private$runtime
                if (show_progress) {
                    pb <- utils::txtProgressBar(min = 0, max = N, style = 3, width = 80)
                    # Close the bar even if an event fails part-way through.
                    on.exit(close(pb), add = TRUE)
                }
                if (private$reset_seed) {
                    dqrng::dqset.seed(private$seeds[1])
                    private$reset_seed <- FALSE
                }
                mon <- private$settings$monitor_enabled
                if (mon) {
                    private$monitor <- private$monitor_factory(private$settings$monitor_aggregators, private$settings$monitor_intervals)
                    mon_n <- sort(unique(unlist(lapply(private$settings$monitor_intervals, function(x) seq(1, N, x)))))
                    mon_ix <- 0
                    new_history <- vector(mode = "list", length = length(mon_n))
                }
                for (n in (runtime_start + 1):(runtime_start + N)) {
                    private$status <- private$process_accumulation_events(private$status, private$accumulation_events)
                    private$process_manipulation_events(private$status, private$manipulation_events)
                    if (mon) {
                        temp_hist <- private$monitor(private$status, n)
                        if (!is.null(temp_hist)) {
                            mon_ix <- mon_ix + 1
                            new_history[[mon_ix]] <- temp_hist
                        }
                    }
                    if (show_progress) {
                        # Progress is relative to this call, not to the absolute time 'n'.
                        utils::setTxtProgressBar(pb, n - runtime_start)
                    }
                }
                if (mon && mon_ix > 0) {
                    # Append only the filled slots of the preallocated list.
                    private$history <- c(private$history, new_history[seq_len(mon_ix)])
                }
                # Advance the clock so the next call continues at runtime + N + 1.
                private$runtime <- runtime_start + N
            }
            invisible(NULL)
        },
        
        ##' @description 
        ##' Run various scenarios with the same initial population using different interventions with replications.
        ##' @param interventions A list of interventions corresponding to the scenarios to run, see 'details'.
        ##' @param replications An integer describing how many times each scenario should be replicated.
        ##' @param N A single integer or an integer vector describing the simulation runtime for each scenario.
        ##' @param seeds_override If \code{NULL}, a different seed is automatically used for each scenario. 
        ##' If provided, should be an integer matrix with \code{replications} rows and \code{length(interventions)} columns.
        ##' @param output A function to apply to the population after each scenario.
        ##' @param ... Additional arguments passed to \code{output}.
        ##' @details Each element of \code{interventions} should be a named list with the following elements: 
        ##' \code{acc_events}, \code{acc_pars}, \code{man_events}, \code{man_pars} and \code{pars}.
        ##' The element \code{acc_events} should be a list of objects of class \code{AccumulationEvent} and
        ##' \code{acc_pars} should be a list of the same length of numeric vectors giving the parameter values for these events.
        ##' Similarly, \code{man_events} should be a list of objects of class \code{ManipulationEvent} with \code{man_pars}
        ##' being a list of the same length of numeric vectors giving the parameter values.
        ##' Finally, \code{pars} should be a list of parameter values (see \code{configure}).
        ##' For additional details, see \code{intervene}.
        run_scenarios = function(interventions, replications, N, seeds_override = NULL, output, ...) {
            M <- as.integer(replications)
            I <- length(interventions)
            if (is.null(seeds_override)) {
                # A distinct seed per run: replication m of scenario i uses I * (m - 1) + i.
                seeds <- outer(I * (seq_len(M) - 1L), seq_len(I), "+")
            } else {
                seeds <- seeds_override
                if (!is.matrix(seeds) || nrow(seeds) < M || ncol(seeds) < I) {
                    stop("Argument 'seeds_override' must be a matrix with 'replications' rows and 'length(interventions)' columns")
                }
            }
            # One runtime per scenario, recycled over the replications. With 'byrow = TRUE',
            # column i holds N[i]; the default column-major fill would mix the scenarios up.
            N_mat <- if (is.matrix(N)) N else matrix(N, nrow = M, ncol = I, byrow = TRUE)
            # Normalize an optional intervention element: NULL -> empty list,
            # a single object -> list of one.
            as_list <- function(x) {
                if (is.null(x)) list() else if (is.list(x)) x else list(x)
            }
            result <- vector(mode = "list", length = M)
            for (m in seq_len(M)) {
                result[[m]] <- vector(mode = "list", length = I)
                for (i in seq_len(I)) {
                    self$reset(seeds = seeds[m, i])
                    intv <- interventions[[i]]
                    acc_events <- as_list(intv$acc_events)
                    man_events <- as_list(intv$man_events)
                    for (j in seq_along(acc_events)) {
                        acc_events[[j]]$set_parameters(intv$acc_pars[[j]])
                    }
                    for (j in seq_along(man_events)) {
                        man_events[[j]]$set_parameters(intv$man_pars[[j]])
                    }
                    self$intervene(acc_events, man_events, as_list(intv$pars))
                    self$run(N_mat[m, i])
                    result[[m]][[i]] <- output(self$get_status(), ...)
                }
            }
            # result[[m]][[i]] holds the output of replication m of scenario i.
            result
        },

        ##' @description
        ##' Configure the parameters of Event objects.
        ##' @param pars A named list of parameters, see 'details'.
        ##' @details 
        ##' The parameter list \code{pars} should be structured as follows:
        ##' Each element should have a name of an Event object and each
        ##' element should be a named list where the names correspond
        ##' to the names of the parameters whose values are to be configured.
        reconfigure = function(pars) {
            event_names <- names(pars)
            # Fail early with a readable message instead of an opaque subscript error.
            unknown <- setdiff(event_names, names(private$settings$event_indices))
            if (length(unknown) > 0L) {
                stop("Unknown event name(s) in 'pars': ", paste(unknown, collapse = ", "))
            }
            for (name in event_names) {
                index <- private$settings$event_indices[[name]]
                private$unlisted_events[[index]]$set_parameters(pars[[name]])
            }
            # The workers received their own copies of the events in start_cluster(),
            # so re-export the updated ones.
            if (private$settings$cluster_enabled) {
                private$export(c("manipulation_events", "accumulation_events"), private)
            }
        },

        ##' @description
        ##' Configure events, run a simulation and compute a summary statistic for calibration. 
        ##' The original status is restored after completion.
        ##' @param N A numeric value indicating the length of the simulation
        ##' @param pars A list of parameter values, see 'details'.
        ##' @param output A function that computes a summary statistic from the simulation.
        ##' @param seeds Seed values for random number generation.
        ##' @param ... Additional arguments passed to \code{output}.
        ##' @return The output of \code{output} evaluated for the resulting population.
        ##' @details 
        ##' The parameter list \code{pars} should be structured as follows:
        ##' Each element should have a name of an Event object and each
        ##' element should be a named list where the names correspond
        ##' to the names of the parameters whose values are to be configured.
        configure = function(N, pars, output, seeds = 123, ...) {
            if (length(private$history)) {
                warning("Previous simulation history was overwritten by calibration")
            }
            if (!is.numeric(seeds)) {
                stop("Seed values must be numeric")
            }
            private$history <- list()
            # Parameter changes are kept on purpose; only the population, seeds and clock are restored.
            self$reconfigure(pars)
            cluster <- private$settings$cluster_enabled
            # Snapshot everything 'run' mutates. on.exit() restores it even when 'run' or
            # 'output' fails, so calibration always leaves the simulator as it found it.
            orig_seeds <- private$seeds
            orig_runtime <- private$runtime
            if (cluster) {
                orig_status <- lapply(private$status_split, function(x) {
                    list(value = data.table::copy(x$value), key = x$key)
                })
            } else {
                orig_status <- data.table::copy(private$status)
            }
            on.exit({
                if (cluster) {
                    private$status_split <- orig_status
                } else {
                    private$status <- orig_status
                }
                private$seeds <- orig_seeds
                private$reset_seed <- TRUE
                private$history <- list()
                private$runtime <- orig_runtime
            }, add = TRUE)
            private$seeds <- as.integer(seeds)
            private$reset_seed <- TRUE
            if (cluster) {
                self$run(N = N)
                status_temp <- data.table::rbindlist(lapply(private$status_split, "[[", "value"))
                # Match get_status(): 'output' never sees the internal chunk index.
                if ("parallelization_index" %in% names(status_temp)) {
                    status_temp[, parallelization_index := NULL]
                }
                out <- output(status_temp, ...)
            } else {
                self$run(N = N, show_progress = FALSE)
                out <- output(private$status, ...)
            }
            out
        },

        ##' @description 
        ##' Carry out an intervention that applies a set of accumulation events and manipulation events, 
        ##' and changes the values of specific parameters.
        ##' By default, the events are applied first, then the parameters are changed.
        ##' @param acc_events A list of accumulation events to apply.
        ##' @param man_events A list of manipulation events to apply.
        ##' @param pars A list of parameter values to set, see 'details'.
        ##' @param events_first A logical value. If \code{TRUE} (the default) then the events are applied before any changes to the parameters. If \code{FALSE} the parameter values are changed first.
        ##' @details 
        ##' The parameter list \code{pars} should be structured as follows:
        ##' Each element should have a name of an Event object and each
        ##' element should be a named list where the names correspond
        ##' to the names of the parameters whose values are to be configured.
        intervene = function(acc_events = list(), man_events = list(), pars = list(), events_first = TRUE) {
            # Accept NULL (nothing to apply) and single objects as well as lists.
            as_list <- function(x) {
                if (is.null(x)) list() else if (is.list(x)) x else list(x)
            }
            acc_events <- as_list(acc_events)
            man_events <- as_list(man_events)
            pars <- as_list(pars)
            n_acc <- length(acc_events)
            n_man <- length(man_events)
            n_p <- length(pars)
            if (n_acc > 0L) {
                not_acc <- which(!is_event(acc_events, "AccumulationEvent"))
                if (length(not_acc) > 0L) {
                    # Collapse explicitly: stop() would otherwise run several indices together.
                    stop("Elements at ", paste(not_acc, collapse = ", "), " of 'acc_events' are not of R6Class 'AccumulationEvent'")
                }
            }
            if (n_man > 0L) {
                not_man <- which(!is_event(man_events, "ManipulationEvent"))
                if (length(not_man) > 0L) {
                    stop("Elements at ", paste(not_man, collapse = ", "), " of 'man_events' are not of R6Class 'ManipulationEvent'")
                }
            }
            # Honour 'events_first': change the parameters before the events when it is FALSE.
            if (!events_first && n_p > 0L) {
                self$reconfigure(pars)
            }
            if (n_acc > 0L || n_man > 0L) {
                private$accumulation_events_temp <- acc_events
                private$manipulation_events_temp <- man_events
                # The one-off events are always cleared, also if applying them fails.
                on.exit({
                    private$accumulation_events_temp <- list()
                    private$manipulation_events_temp <- list()
                }, add = TRUE)
                if (private$settings$cluster_enabled) {
                    private$export(c("accumulation_events_temp", "manipulation_events_temp"), envir = private)
                    # No '.combine': the foreach default always yields one list element per chunk.
                    result <- foreach(
                        dt_sub = private$status_split,
                        .export = c("n_acc", "n_man"),
                        # Do not export anything extra to avoid possible conflicts.
                        .noexport = c(ls(globalenv()), ls(environment()), ls(self), ls(private)),
                        # These packages have to be installed on the workers.
                        .packages = c("data.table", "dqrng", private$settings$packages),
                        .verbose = FALSE) %dopar% {
                        # Parallel block
                        key <- as.numeric(dt_sub$key)
                        dt <- dt_sub$value
                        if (n_acc) {
                            dt <- process_accumulation_events_parallel(dt, accumulation_events_temp, parallel = key)
                        }
                        if (n_man) {
                            process_manipulation_events_parallel(dt, manipulation_events_temp)
                        }
                        return(list(value = dt, key = dt_sub$key))
                    }
                    private$status_split <- lapply(result, "[", c("value", "key"))
                } else {
                    if (n_acc > 0L) {
                        private$status <- private$process_accumulation_events(private$status, private$accumulation_events_temp)
                    }
                    if (n_man > 0L) {
                        private$process_manipulation_events(private$status, private$manipulation_events_temp)
                    }
                }
            }
            if (events_first && n_p > 0L) {
                self$reconfigure(pars)
            }
            invisible(NULL)
        },

        ##' @description
        ##' Get the status of the current population.
        ##' @return A data.table of the current status.
        get_status = function() {
            # Sequential mode: the population lives in a single local table.
            if (!private$settings$cluster_enabled) {
                return(private$status)
            }
            # Parallel mode: stack the worker chunks and drop the internal chunk index.
            chunks <- lapply(private$status_split, function(chunk) chunk[["value"]])
            combined <- data.table::rbindlist(chunks)
            if (is.element("parallelization_index", names(combined))) {
                combined[, parallelization_index := NULL]
            }
            combined
        },
        
        ##' @description
        ##' Get the initial status of the population, if available.
        ##' @return A data.table of the initial status if it was recorded, otherwise NULL.
        get_init = function() {
            if (is.null(private$status_init)) {
                message("Unable to get the initial status: the initial status was not recorded")
                return(NULL)
            }
            # Hand out a deep copy. data.table modifies by reference, so returning the
            # stored table would let a caller's ':=' corrupt the baseline reset() restores.
            data.table::copy(private$status_init)
        },

        ##' @description
        ##' Get a sample of the current population.
        ##' @param sampler A function that constructs the output sample.
        ##' @param ... Additional arguments passed to \code{sampler}.
        ##' @return The sample generated by \code{sampler}.
        get_sample = function(sampler, ...) {
            # get_status() gathers the worker chunks when a cluster is active; in that
            # case private$status is an empty data.table and would yield an empty sample.
            sampler(status = self$get_status(), ...)
        },

        ##' @description 
        ##' Get the simulation history as recorded by the monitor.
        ##' @return If the monitor is enabled, a list with an element for each evaluation of the monitor aggregators at their specified intervals.
        ##' These elements are lists as well, where the first element is the time index, and the remaining elements
        ##' correspond to the values returned by the evaluated aggregators of the monitor. If no monitor is specified, returns NULL (invisibly).
        get_history = function() {
            # The history is initialized as list(), never NULL, so test for emptiness.
            if (length(private$history) == 0L) {
                message("No history is available")
                return(invisible(NULL))
            }
            private$history
        },

        ##' @description
        ##' Initialize monitors to track specific status variables throughout the simulation at certain intervals.
        ##' @param aggregators A list of functions to compute aggregate summary statistics from the simulated population.
        ##' @param intervals A numeric vector giving the time interval after which each aggregator should be evaluated.
        ##' @param combiners If parallel computation is used, this is a list of functions that combine the aggregator results from each parallel worker.
        initialize_monitor = function(aggregators, intervals, combiners = NULL) {
            # Accept a single function as well as a list of functions.
            if (is.function(aggregators)) {
                aggregators <- list(aggregators)
            }
            if (is.function(combiners)) {
                combiners <- list(combiners)
            }
            if (length(aggregators) != length(intervals)) {
                stop("Invalid number of intervals")
            }
            # run() builds the evaluation schedule with seq(1, N, by = interval),
            # which requires positive intervals.
            interval_values <- unlist(intervals)
            if (!is.numeric(interval_values) || anyNA(interval_values) || any(interval_values <= 0)) {
                stop("Monitor intervals must be positive numbers")
            }
            # In parallel runs, combiners[[i]] merges the worker results of aggregators[[i]].
            if (!is.null(combiners) && length(combiners) != length(aggregators)) {
                stop("Invalid number of combiners")
            }
            private$settings$monitor_enabled <- TRUE
            private$settings$monitor_aggregators <- aggregators
            private$settings$monitor_intervals <- intervals
            private$settings$monitor_combiners <- combiners
            message("Monitor initialized")
            # Monitor does not actually exist at this point.
            # It is generated when 'run' is called (and exported to the workers if cluster in enabled).
        },
        
        ##' @description
        ##' Reset the simulated population to the initial state (if initial status was kept).
        ##' Note that any changes to event parameters are kept as is and have to be configured manually, if so desired.
        ##' @param force If \code{TRUE}, reset the simulator even if the initial status is not available. 
        ##' @param seeds An integer vector. If given, replaces the pseudo-RNG seeds.
        ##' In either case, the seeds are re-applied at the start of the next run, and the simulation time and history are cleared.
        reset = function(force = FALSE, seeds = NULL) {
            has_init <- !is.null(private$status_init)
            # Without a recorded initial status, only an explicit 'force' may proceed.
            if (!has_init && !force) {
                stop("Unable to reset the simulator: the initial status was not recorded\nUse reset(force = TRUE) to reset anyway")
            }
            if (has_init) {
                # Deep copy so the stored baseline stays pristine for later resets.
                private$status <- data.table::copy(private$status_init)
                if (private$settings$cluster_enabled) {
                    private$split_status()
                }
            }
            if (!is.null(seeds)) {
                private$seeds <- as.integer(seeds)
            }
            private$history <- list()
            private$reset_seed <- TRUE
            private$runtime <- 0
        }

    ),

    private = list(

        ## @field status The status of the simulated population (a data.table). It is emptied
        ## while a cluster is active; the population then lives in 'status_split'.
        status = data.table::data.table(),
        ## @field status_init A deep copy of the initial population, or NULL if 'keep_init' was FALSE.
        status_init = NULL,

        ## @field status_split Chunks of the population to be sent to parallel workers.
        status_split = list(),

        ## @field accumulation_events A list of AccumulationEvent objects.
        accumulation_events = list(),
        ## @field accumulation_events_temp One-off AccumulationEvent objects applied by 'intervene'; empty otherwise.
        accumulation_events_temp = list(),

        ## @field manipulation_events A list of ManipulationEvent objects.
        manipulation_events = list(),
        ## @field manipulation_events_temp One-off ManipulationEvent objects applied by 'intervene'; empty otherwise.
        manipulation_events_temp = list(),

        ## @field unlisted_events A flat list of both types of events, manipulation events first.
        ## An event's position in this list is the id assigned to it at initialization.
        unlisted_events = list(),

        ## @field settings A list of simulator settings (cluster, monitor and event-index configuration).
        settings = list(),

        ## @field history A list of monitor records appended by 'run'; each record is a value
        ## returned by the monitor (see 'get_history').
        history = list(),

        ## @field monitor The monitor function built by 'monitor_factory' for sequential runs;
        ## it is called as monitor(status, n) to record specific status variables throughout the simulation.
        monitor = NULL,

        ## @field runtime The simulation time offset: 'run' numbers its steps from runtime + 1.
        ## Set to 0 by 'reset' and 'configure'.
        runtime = 0,

        ## @field seeds Seed values used for random number generation in the simulation. Only the first value is used if parallel computation is not used.
        seeds = 0,

        ## @field reset_seed A logical value indicating whether the RNG should be re-seeded at the start of the next 'run'.
        reset_seed = FALSE,

        ## @description
        ## Split the population into chunks for parallel processing.
        split_status = function() {
            # Already split (e.g. a repeated start): the local status is empty and
            # re-splitting it would overwrite the existing chunks with nothing.
            if (nrow(private$status) == 0L && length(private$status_split) > 0L) {
                return(invisible(NULL))
            }
            if (is.null(private$status$parallelization_index)) {
                # Round-robin assignment of rows to chunks 1..n_cores; rep_len() also
                # handles an empty population and fewer rows than cores.
                n <- nrow(private$status)
                private$status[, parallelization_index := factor(rep_len(seq_len(private$settings$n_cores), n))]
            } else if (!is.factor(private$status$parallelization_index)) {
                # Respect a user-supplied partition, but the split below needs factor levels.
                private$status[, parallelization_index := factor(parallelization_index)]
            }
            # Previously an existing index skipped the split but still emptied the status,
            # losing the population; now the split always happens before the status is cleared.
            data.table::setkey(private$status, parallelization_index)
            private$status_split <- as.list(data.table_isplit(private$status, levels(private$status$parallelization_index)))
            private$status <- data.table::data.table()
        },

        ## @description
        ## Reintegrate population chunks when parallel computation finishes.
        combine_status = function() {
            # Every element of 'status_split' holds its chunk under the name "value";
            # stack those chunks back into a single table.
            chunk_tables <- lapply(private$status_split, function(res) res[["value"]])
            private$status <- data.table::rbindlist(chunk_tables)
            # Remove the helper column used for chunking, when it is present.
            has_index <- "parallelization_index" %in% names(private$status)
            if (has_index) private$status[, parallelization_index := NULL]
            private$status_split <- list()
        },
        
        ## @description
        ## Create local copies of 'vars' from 'envir'
        ## @param vars A character vector of object names to copy
        ## @param envir The environment in which to look up \code{vars}.
        ## @return An environment containing the objects corresponding to 'vars'
        copy_local = function(vars, envir = .GlobalEnv) {
            # Fresh hashed environment receiving one binding per requested name.
            local_env <- new.env(hash = TRUE)
            for (name in vars) {
                value <- get(name, envir = envir)
                # Closures are re-parented to the global environment so they stop capturing
                # their defining environment (presumably to keep the enclosing object out of
                # what is shipped to workers).
                if (is.function(value)) environment(value) <- .GlobalEnv
                local_env[[name]] <- value
            }
            local_env
        },
        
        ## @description
        ## Export data to cluster.
        ## @param vars A character vector of object names to export.
        ## @param envir The environment in which to look up \code{vars}.
        export = function(vars, envir) {
            # Dispatch on the configured foreach backend; any other value exports nothing.
            interface <- private$settings$cluster_interface
            cl <- private$settings$cluster
            if (identical(interface, "doMPI")) {
                doMPI::exportDoMPI(cl, varlist = vars, envir = envir)
            } else if (identical(interface, "doParallel")) {
                parallel::clusterExport(cl, varlist = vars, envir = envir)
            }
        },
        
        ## @description
        ## Check for availability of packages related to parallel computation
        ## @param interface Interface to "foreach", must match the cluster type 
        ## ("doParallel" and "doMPI" are currently supported).
        parallel_available = function(interface) {
            # 'foreach' is needed by every interface.
            if (!requireNamespace("foreach", quietly = TRUE)) {
                stop("Please install 'foreach' to use parallel functionality")
            }
            # The MPI interface additionally needs the foreach backend and the MPI bindings.
            if (identical(interface, "doMPI")) {
                if (!requireNamespace("doMPI", quietly = TRUE)) {
                    stop("Please install 'doMPI' to use the MPI interface")
                }
                # The argument must be spelled 'quietly': an unknown name would be forwarded
                # through '...' to loadNamespace(), which rejects it and makes
                # requireNamespace() report FALSE even when Rmpi is installed.
                if (!requireNamespace("Rmpi", quietly = TRUE)) {
                    stop("Please install 'Rmpi' to use the MPI interface")
                }
            }
            TRUE
        },

        ## @description
        ## A function to handle recursive event structures for accumulation events (nested & ordered/unordered subgroups of events)
        ## @param dt The data.table to update
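        ## @param event_group A list of events or a single event to apply to 'dt'
        ## @param parallel A logical value passed on to each event's \code{apply} method
        ## @return The updated data.table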
        step_accumulation_recursive = function(dt, event_group, parallel) {
            # A single event is a leaf: apply it and hand back the updated table.
            if (inherits(event_group, "AccumulationEvent")) {
                dt <- event_group$apply(dt, parallel)
                return(dt)
            }
            # Otherwise 'event_group' is a (possibly nested) list of events. Members are
            # visited in list order when its "ordered" attribute is TRUE and in a random
            # order drawn with dqrng otherwise (isTRUE() treats a missing attribute as FALSE).
            n_members <- length(event_group)
            if (n_members > 0) {
                is_ordered <- isTRUE(attr(event_group, "ordered"))
                visit_order <- if (is_ordered) seq_len(n_members) else dqrng::dqsample(n_members)
                for (idx in visit_order) {
                    dt <- private$step_accumulation_recursive(dt, event_group[[idx]], parallel)
                }
            }
            return(dt)
        },

        step_accumulation_recursive_parallel = function(dt, event_group, parallel) {
            # Same traversal as step_accumulation_recursive, but the recursive call is looked
            # up by plain name (no 'private$'), so this function must be reachable under the
            # name 'step_accumulation_recursive_parallel' wherever it runs (presumably a
            # worker it was exported to — confirm).
            if (inherits(event_group, "AccumulationEvent")) {
                dt <- event_group$apply(dt, parallel)
                return(dt)
            }
            # Nested group: fixed order when the "ordered" attribute is TRUE, random otherwise.
            n_members <- length(event_group)
            if (n_members > 0) {
                is_ordered <- isTRUE(attr(event_group, "ordered"))
                visit_order <- if (is_ordered) seq_len(n_members) else dqrng::dqsample(n_members)
                for (idx in visit_order) {
                    dt <- step_accumulation_recursive_parallel(dt, event_group[[idx]], parallel)
                }
            }
            return(dt)
        },

        ## @description
        ## A function to handle recursive event structures for manipulation events (nested & ordered/unordered subgroups of events)
        ## @param dt The data.table to update
        ## @param event_group A list of events or a single event to apply to 'dt'
        step_manipulation_recursive = function(dt, event_group) {
            # Leaf: apply the single event and return whatever its apply() returns.
            if (inherits(event_group, "ManipulationEvent")) {
                return(event_group$apply(dt))
            }
            # Group: nested results are discarded, so events presumably update 'dt' by
            # reference. Visit order is fixed when "ordered" is TRUE, random otherwise.
            n_members <- length(event_group)
            if (n_members > 0) {
                visit_order <- if (isTRUE(attr(event_group, "ordered"))) seq_len(n_members) else dqrng::dqsample(n_members)
                for (idx in visit_order) {
                    private$step_manipulation_recursive(dt, event_group[[idx]])
                }
            }
        },

        step_manipulation_recursive_parallel = function(dt, event_group) {
            # Worker-side variant of step_manipulation_recursive: the recursive call uses the
            # plain function name, so it must be resolvable where this copy runs.
            if (inherits(event_group, "ManipulationEvent")) {
                return(event_group$apply(dt))
            }
            # Group: fixed order when "ordered" is TRUE, random order otherwise.
            n_members <- length(event_group)
            if (n_members > 0) {
                visit_order <- if (isTRUE(attr(event_group, "ordered"))) seq_len(n_members) else dqrng::dqsample(n_members)
                for (idx in visit_order) {
                    step_manipulation_recursive_parallel(dt, event_group[[idx]])
                }
            }
        },

        ## @description
        ## Function to process events that introduce new individuals into the population
        ## @param dt The data.table to update
        ## @param events A list of accumulation events of the simulator
        ## @param parallel A logical value indicating if parallel computation is used
        process_accumulation_events = function(dt, events, parallel) {
            # The top-level event list is traversed exactly like any nested group.
            private$step_accumulation_recursive(dt = dt, event_group = events, parallel = parallel)
        },

        process_accumulation_events_parallel = function(dt, events, parallel) {
            # Worker-side entry point; relies on step_accumulation_recursive_parallel being
            # resolvable by name in the calling scope.
            step_accumulation_recursive_parallel(dt = dt, event_group = events, parallel = parallel)
        },

        ## @description
        ## Function to process events that modify the status variables
        ## @param dt The data.table to update
        ## @param events The list of manipulation events of the simulator
        process_manipulation_events = function(dt, events) {
            # Delegate the whole (possibly nested) event list to the recursive walker.
            private$step_manipulation_recursive(dt = dt, event_group = events)
        },

        process_manipulation_events_parallel = function(dt, events) {
            # Worker-side entry point; relies on step_manipulation_recursive_parallel being
            # resolvable by name in the calling scope.
            step_manipulation_recursive_parallel(dt = dt, event_group = events)
        },

        ## @description
        ## A generator function used to create monitors based on the input aggregators and intervals
        ## @param aggregators A list of functions with a single argument that is the status data.table
        ## @param intervals An integer vector describing the time interval to apply each aggregator
        monitor_factory = function(aggregators, intervals) {
            # Force both promises so the returned closure captures the values supplied now
            # instead of evaluating them lazily on first use.
            force(aggregators)
            force(intervals)
            # Returned monitor: for step 'n', runs every aggregator whose interval divides 'n'
            # (in index order) and returns list(n, result_1, ..., result_k); returns NULL when
            # no aggregator is due at this step.
            function(dt, n) {
                due <- which(n %% intervals == 0, useNames = FALSE)
                if (length(due) == 0L) {
                    return(NULL)
                }
                # lapply() keeps a NULL aggregator result as its own list entry, so the record
                # always has length(due) + 1 elements; assigning NULL with `[[<-` would instead
                # delete the slot and shorten the record.
                c(list(n), lapply(due, function(i) aggregators[[i]](dt)))
            }
        }

    )

)
santikka/Sima documentation built on Dec. 22, 2021, 10:15 p.m.