#' process inputs to run simulation
#' @param simulation_flags named vector (or list) with the logical elements "normaliseTravel"
#'  "seasonal" and "real_data"
#' @param life_history_params named vector (or list) with the numeric elements
#' "R0", "TR" (time to recovery) and "LP" (latent period)
#' @param travel_params list of parameters relating to travel. This should have the following elements: 1) epsilon, which scales the off-diagonals of the travel matrix
#' @param seed_params named vector (or list) of parameters to do with seeding the pandemic.
#' contains the elements Countries: vector of country names in which to seed
#' Sizes: vector of how many to seed in each country
#' Ages: vector of which age group to seed in each country
#' RiskGroups: vector of which risk group to seed in each country
#' @param time_params list of parameters to do with time steps in simulation.
#' contains the elements tmax (Maximum time of simulation), tdiv (Number of time steps per day)
#' @param seasonality_params list of seasonality parameters.
#' contains the elements tdelay (0 <= tdelay <= 364): shifts the seasonality function - changing this effectively changes the seed time.
#' tdelay = 0 is seed at t = 0 in sinusoidal curve, roughly start of autumn in Northern hemisphere
#' division: Average seasonality into this many blocks of time
#' amp: amplitude of seasonality
#' @param vax_params named vector (or list) with the numeric elements "efficacy"
#' and "propn_vax0" (initial proportion of vaccinated individuals; assumed constant
#' across location, age and risk groups)
#' @param user_specified_cum_vax_pool_func function or character string of function to produce vaccine pool
#' @param vax_production_params named vector (or list) with named elements matching the arguments of user_specified_cum_vax_pool_func
#' @param user_specified_vax_alloc_func function or character string of function to allocate vaccines
#' @param vax_allocation_params named vector (or list) with named elements matching the arguments of user_specified_vax_alloc_func
#' @return a list of arguments to go into run_simulation
#' @export
setup_inputs <- function(simulation_flags, life_history_params, travel_params, 
                         user_specified_cum_vax_pool_func, vax_production_params,
                         user_specified_vax_alloc_func, vax_allocation_params){
    ## NOTE(review): the body below reads seed_params, which is not in the
    ## visible argument list -- part of the signature appears to be missing;
    ## confirm against version control.
    ## NOTE(review): the "} else {" below has no visible opening "if";
    ## presumably the branch is selected by simulation_flags[["real_data"]]
    ## (as in the coverage section further down) -- confirm.
    ## Build the demography list from data files or from synthetic settings.
        demography <- setup_demography_real_data(simulation_flags=simulation_flags, travel_params=travel_params)
    } else {
      demography <- setup_demography_sim_data(simulation_flags=simulation_flags, travel_params=travel_params)
  ## Enumerate the case fatality ratio over locations, then keep only the ratio column.
  ## NOTE(review): the expand.grid() call is not closed in the visible text.
  case_fatality_ratio_vec <- expand.grid("Location"=seq_len(demography[["n_countries"]]), 
                                         "case_fatality_ratio" = life_history_params$case_fatality_ratio,
  case_fatality_ratio_vec <- case_fatality_ratio_vec$case_fatality_ratio
  ## Seeding vector: one entry per element of demography[["popns"]], initially all zero.
  seed_vec <- double(length(demography[["popns"]]))
  labels <- demography[["labels"]]
  ## Only the first label row matching country, age and risk group is assigned.
  ## NOTE(review): the right-hand side of this assignment is not visible
  ## (presumably the seed size) -- confirm.
  seed_vec[(which(labels$Location == seed_params[["Countries"]] & 
                    labels$Age == seed_params[["Ages"]] &
                    labels$RiskGroup == seed_params[["RiskGroups"]]))[1]] <- 
  ## If the allocation parameters carry a "coverage" element, replace it with
  ## coverage read from file (real data only).
  if ("coverage" %in% names(vax_allocation_params)) {
    if (simulation_flags[["real_data"]]) {
        ## fall back to the default file unless the caller named one
        if(!("coverage_filename" %in% names(vax_allocation_params))){
            coverage_filename <- "data/coverage_data_intersect.csv"
        } else {
            coverage_filename <- vax_allocation_params$coverage_filename
        ## NOTE(review): cat() prints its arguments and returns NULL, so
        ## message() receives NULL here -- confirm this is intended.
        message(cat("Using coverage file: ", coverage_filename,sep="\t"))
        coverage <- read_coverage_data(coverage_filename, labels)
    } else {
      # every country has same seasonal coverage
      # the below line is incorrect: ada to fix
      # coverage <- rep(1/n_countries, n_countries)#
      stop("uniform coverage not yet implemented")
    vax_allocation_params$coverage <- coverage
  ## process the vaccine production function
  ## (a character string is parsed and evaluated as R code to obtain the function)
  if(is.character(user_specified_cum_vax_pool_func)) {
    user_specified_cum_vax_pool_func <- eval(parse(text = user_specified_cum_vax_pool_func))
  cum_vax_pool_func <- cum_vax_pool_func_closure(user_specified_cum_vax_pool_func, vax_production_params)
  ## process the vaccine allocation function
  ## (a character string is parsed and evaluated as R code to obtain the function)
  if(is.character(user_specified_vax_alloc_func)) {
    user_specified_vax_alloc_func <- eval(parse(text = user_specified_vax_alloc_func))
  ## NOTE(review): travelMatrix is not assigned anywhere in the visible body;
  ## the value returned below is demography[["travelMatrix"]] -- confirm which
  ## object is meant here.
  vax_allocation_func <- vaccine_allocation_closure(user_specified_vax_alloc_func,
                                                    travelMatrix, vax_allocation_params, labels)
  ## Bundle the demography elements, the two vaccine closures, the case
  ## fatality vector and the seeding vector into one list.
  return(list(popns = demography[["popns"]],
         labels = demography[["labels"]],
         contactMatrix = demography[["contactMatrix"]],
         travelMatrix = demography[["travelMatrix"]],
         latitudes = demography[["latitudes"]],
         n_countries = demography[["n_countries"]],
         n_ages = demography[["n_ages"]],
         n_riskgroups = demography[["n_riskgroups"]],
         cum_vax_pool_func = cum_vax_pool_func,
         vax_allocation_func = vax_allocation_func,
         case_fatality_ratio_vec = case_fatality_ratio_vec,
         seed_vec = seed_vec))
  ## NOTE(review): the closing braces of the blocks above and of the function
  ## itself are not present in the visible text.

## Build the demography inputs (population sizes, travel matrix, latitudes and
## contact matrices) from data files.
## NOTE(review): the argument list is cut off after simulation_flags, and
## demography_filename, risk_filename, travel_filename, latitude_filename,
## contact_filename and travel_params are used below without a visible
## definition -- confirm against version control.
setup_demography_real_data <- function(simulation_flags,
    ## Read in demography data
    ## one row per country; all but two columns are counted as age groups
    dem_tmp <- read.csv(demography_filename, sep=",")
    n_countries <- nrow(dem_tmp)
    n_ages <- ncol(dem_tmp)-2

    ## Risk group data
    ## headerless file: rows are compared against n_ages, columns give n_riskgroups
    risk_propns <- as.matrix(read.csv(risk_filename,sep=",",header=FALSE))
    n_riskgroups <- ncol(risk_propns)
    if(nrow(risk_propns) != n_ages) {
        stop("number of age groups inconsistent between data sets")
    ## NOTE(review): n_riskgroups was just set to ncol(risk_propns), so this
    ## condition can never be TRUE as written.
    if(ncol(risk_propns) !=n_riskgroups) {
        stop("number of risk groups inconsistent between data sets")

    risk_factors <- rep(1, n_riskgroups) ## Assume that each risk group has same modifier

    ## Enumerate out risk factors for each age group
    ## NOTE(review): this matrix() call is not closed in the visible text.
    age_specific_riskgroup_factors <- matrix(rep(risk_factors,each=n_ages),
    ## construct demography matrix
    demography_matrix <- setup_populations_real_data(demography_filename,
                                                     risk_propns, risk_factors)

    ## construct travel matrix
    travelMatrix <- setup_travel_real_data(travel_filename, demography_matrix$pop_size, travel_params)

    ## construct latitude vector
    latitudes <- read_latitude_data(latitude_filename)
    ## Generate risk factor modifier. ie. modifier for each age/risk group pair, same dimensions as C2
    ## The risk factor modifier modifies the susceptibility of age/risk groups
    ## risk is the factor matrix flattened row by row; after the transpose each
    ## row of risk_matrix holds a copy of risk
    risk <- c(t(age_specific_riskgroup_factors))
    risk_matrix <- t(kronecker(risk,matrix(1,1,n_riskgroups*n_ages)))
    ## Generate a contact matrix with dimensions (n_ages*n_riskgroups) * (n_ages*n_riskgroups).
    ## ie. get age specific, then enumerate out by risk group.
    ## If we have country specific contact rates, we get a list
    ## of these matrices of length n_countries
    C1 <- read_contact_data(contact_filename)
    if(is.list(C1)) { # for country specific contact rates
        ## expand each entry into an n_riskgroups x n_riskgroups block, then
        ## scale elementwise by the risk modifier
        C2 <- lapply(C1, function(x) kronecker(x, matrix(1,n_riskgroups,n_riskgroups)))
        C3 <- lapply(C2, function(x) x*risk_matrix)
    } else {
        C2 <- kronecker(C1, matrix(1,n_riskgroups,n_riskgroups))
        C3 <- C2*risk_matrix
    ## NOTE(review): no return value and no closing braces are visible here.

## Build the demography inputs (population sizes, travel matrix, latitudes and
## contact matrices) for a synthetic population.
## NOTE(review): the argument list is cut off after simulation_flags, and
## n_riskgroups, n_ages, n_countries, popn_size and age_propns are used below
## without a visible definition -- confirm against version control.
setup_demography_sim_data <- function(simulation_flags,
    ## Setup risk group data

    risk_propns <- rep(1/n_riskgroups,n_riskgroups) ## Assume risk groups are uniformly distributed
    risk_propns <- matrix(rep(risk_propns,each=n_ages),ncol=n_riskgroups) ## Assume that proportion of ages in each risk group are the same for all ages
    risk_factors <- rep(1, n_riskgroups) ## Assume that each risk group has same modifier
    ## Enumerate out risk factors for each age group
    age_specific_riskgroup_factors <- matrix(rep(risk_factors,each=n_ages),ncol=n_riskgroups)
    ## construct demography matrix
    demography_matrix <- setup_populations(popn_size,n_countries,age_propns, 
                             risk_propns, risk_factors)
    ## construct travel matrix
    ## (1 off the diagonal, 1000 on the diagonal)
    travelMatrix <- matrix(1,n_countries,n_countries)+999*diag(n_countries) #Travel coupling - assumed independent of age (but can be changed)
    ## construct latitude vector
    ## (placeholder values 1..n_countries in a one-column matrix)
    latitudes <- matrix(seq_len(n_countries), n_countries, 1)

    ## Generate a contact matrix with dimensions (n_ages*n_riskgroups) * (n_ages*n_riskgroups).
    ## ie. get age specific, then enumerate out by risk group.
    ## If we have country specific contact rates, we get a list
    ## of these matrices of length n_countries
    ## Contact rates
    ## 16 hard-coded values each; NOTE(review): presumably a 4 x 4 age-group
    ## table, which would require n_ages == 4 -- confirm.
    contactRates <- c(6.92,.25,.77,.45,.19,3.51,.57,.2,.42,.38,1.4,.17,.36,.44,1.03,1.83)
    contactDur <- c(3.88,.28,1.04,.49,.53,2.51,.75,.5,1.31,.8,1.14,.47,1,.85,.88,1.73)

    ## Generate risk factor modifier. ie. modifier for each age/risk group pair, same dimensions as C2
    ## The risk factor modifier modifies the susceptibility of age/risk groups
    ## risk is the factor matrix flattened row by row; after the transpose each
    ## row of risk_matrix holds a copy of risk
    risk <- c(t(age_specific_riskgroup_factors))
    risk_matrix <- t(kronecker(risk,matrix(1,1,n_riskgroups*n_ages)))
    ## for now, make contact matrices same for all countries
    C1 <- generate_contact_matrix(contactRates, contactDur, n_ages, simulation_flags[["ageMixing"]])
    if(simulation_flags[["country_specific_contact"]]) {
        ## same matrix repeated once per country, then expanded by risk group
        C1 <- rep(list(C1), n_countries)
        C2 <- lapply(C1, function(x) kronecker(x, matrix(1,n_riskgroups,n_riskgroups)))
        C3 <- lapply(C2, function(x) x*risk_matrix)
    } else {
        C2 <- kronecker(C1, matrix(1,n_riskgroups,n_riskgroups))
        C3 <- C2*risk_matrix
    ## NOTE(review): no return value and no closing braces are visible here.

#' function to setup synthetic population sizes for each location, age, risk group,
#' and labels associated with these
#' @param popn_size numeric vector of length 1: population size of a country.
#' assumed to be same across countries.
#' @param n_countries numeric vector of length 1: number of countries
#' @param age_propns numeric vector: proportion of individuals in each age group
#' (must sum to 1). assumed to be the same across countries.
#' @param risk_propns numeric matrix with one row per age group and one column
#' per risk group: proportion of each age group in each risk group
#' (every row must sum to 1). assumed to be the same across countries.
#' @param risk_factors numeric vector: degree of increased susceptibility
#' in each risk group.
#' assumed to be the same across countries and age groups.
#' Not used by this function; kept so existing callers keep working.
#' @return list with the following elements:
#' X: one-column numeric matrix with n_groups = n_countries * n_ages * n_riskgroups
#' rows containing the population sizes in each country, age, risk group
#' labels: data frame containing X plus three columns: the location, risk group
#' and age corresponding to each element of X (location varies fastest,
#' then risk group, then age).
#' nrow(labels) = length(X)
#' pop_size: numeric vector of length n_countries containing the total population
#' size in each country.
setup_populations <- function(popn_size, n_countries,
                              age_propns, risk_propns, risk_factors){
    ## proportions are floating point, so compare to 1 with a tolerance
    tol <- sqrt(.Machine$double.eps)
    stopifnot(abs(sum(age_propns) - 1) < tol,
              all(abs(rowSums(risk_propns) - 1) < tol))
    n_ages <- length(age_propns)
    n_riskgroups <- ncol(risk_propns)
    ## Assuming the same population size for each country, the same age
    ## distribution and the same risk proportions. Once we input real data,
    ## we would input these matrices directly
    ## n_countries x n_ages matrix of country population sizes
    country_popns <- matrix(popn_size, nrow = n_countries, ncol = n_ages)
    ## n_countries x n_ages matrix: every row is the age distribution
    age_propns_country <- matrix(age_propns, nrow = n_countries, ncol = n_ages,
                                 byrow = TRUE)
    ## Use proportion to get actual population size
    age_groups <- country_popns * age_propns_country
    ## Generate risk propns for each age group
    ## Enumerate out risk groups to same dimension as countries.
    ## The transpose makes the flattened vector run risk group fastest within
    ## each age group, which is the order of age_group_risk and of the labels
    ## below (location, then risk group, then age).
    risk_matrix <- kronecker(c(t(risk_propns)), matrix(1, n_countries, 1))
    ## Enumerate out country/age propns to same dimensions as risk groups (ie. n_riskgroups*n_ages*n_countries x 1)
    age_group_risk <- c(kronecker(matrix(1, n_riskgroups, 1), age_groups))

    ## Multiply together to get population in each location/age/risk group combo
    X <- round(risk_matrix * age_group_risk)
    labels <- cbind(X, expand.grid("Location" = seq_len(n_countries),
                                   "RiskGroup" = seq_len(n_riskgroups),
                                   "Age" = seq_len(n_ages)))
    return(list(X = X, labels = labels, "pop_size" = rep(popn_size, n_countries)))
}

#' function to setup population sizes for each location, age, risk group,
#' and labels associated with these, from data
#' @param demography_filename character vector of length 1: comma-separated file
#' where demographic data is located. Must contain the columns "countryID" and
#' "N" (total population) in its first two positions, followed by one column
#' per age group holding the proportion of the population in that age group.
#' @param risk_propns numeric matrix with one row per age group and one column
#' per risk group: proportion of each age group in each risk group.
#' assumed to be the same across countries.
#' @param risk_factors numeric vector: degree of increased susceptibility
#' in each risk group.
#' assumed to be the same across countries and age groups.
#' Not used by this function; kept so existing callers keep working.
#' @return list with the following elements:
#' X: one-column numeric matrix with n_groups = n_countries * n_ages * n_riskgroups
#' rows containing the population sizes in each country, age, risk group
#' labels: data frame containing X plus three columns: the location, risk group
#' and age corresponding to each element of X (location varies fastest,
#' then risk group, then age).
#' nrow(labels) = length(X)
#' pop_size: numeric vector of length n_countries containing the total population
#' size in each country (countries ordered by countryID).
setup_populations_real_data <- function(demography_filename,
                                        risk_propns, risk_factors){
    n_riskgroups <- ncol(risk_propns)
    ## read in demographic data
    demographic_data <- read.csv(demography_filename, sep = ",")
    n_countries <- nrow(demographic_data)
    n_ages <- ncol(demographic_data) - 2
    ## age proportions live in columns 3, 4, ..., n_ages + 2
    age_cols <- seq_len(n_ages) + 2
    ## ensure proportions sum to 1 -- eliminate rounding errors.
    ## The last age column is recomputed from the others; selecting the
    ## "others" with head(..., -1) also works when there is a single age group
    ## (the sum over zero columns is 0, so the lone proportion becomes 1).
    other_cols <- head(age_cols, -1)
    demographic_data[, ncol(demographic_data)] <-
        1 - rowSums(demographic_data[, other_cols, drop = FALSE])
    ## alphabetise countries for consistency across data sets
    demographic_data <- demographic_data[order(demographic_data$countryID), ]
    pop_size <- demographic_data[, "N"]
    ## n_countries x n_ages matrix of population per country and age group
    age_groups <- as.matrix(pop_size * demographic_data[, age_cols, drop = FALSE])
    ## Generate risk propns for each age group
    ## Enumerate out risk groups to same dimension as countries
    ## (flattened risk group fastest within each age group)
    risk_matrix <- kronecker(c(t(risk_propns)), matrix(1, n_countries, 1))
    ## Enumerate out country/age propns to same dimensions as risk groups (ie. n_riskgroups*n_ages*n_countries x 1)
    age_group_risk <- c(kronecker(matrix(1, n_riskgroups, 1), age_groups))
    ## Multiply together to get population in each location/age/risk group combo
    X <- round(risk_matrix * age_group_risk)
    location_names <- demographic_data[, "countryID"]
    labels <- cbind(X, expand.grid("Location" = location_names,
                                   "RiskGroup" = seq_len(n_riskgroups),
                                   "Age" = seq_len(n_ages)))
    return(list("X" = X, "labels" = labels, "pop_size" = pop_size))
}

#' function to construct contact matrix from data
#' @param contact_filename character vector of length 1: comma-separated file
#' where contact data is located. The first column holds the row label of each
#' record (one record per country); the remaining columns hold a flattened
#' square contact matrix, column by column.
#' @return named list with one element per record in the file (sorted by row
#' label), each a square matrix whose side length is the square root of the
#' number of value columns.
#' for each matrix, contactMatrix[i,j] denotes the amount of influence
#' an individual of type i has on an individual of type j.
#' Expansion over risk groups is left to the caller.
read_contact_data <- function(contact_filename){
    contact_data <- read.table(contact_filename, sep = ",", stringsAsFactors = FALSE,
                               row.names = 1, header = TRUE)
    # alphabetise (drop = FALSE keeps a data frame even with one value column)
    contact_data <- contact_data[order(rownames(contact_data)), , drop = FALSE]
    # each record must hold a complete square matrix
    n_entries <- ncol(contact_data)
    side <- round(sqrt(n_entries))
    if (side^2 != n_entries) {
        stop("contact data must have a square number of value columns")
    }
    # one column per record after transposing, so lapply walks over records
    contact_data <- as.data.frame(t(contact_data))
    contact_matrices <- lapply(contact_data, function(x) matrix(x, nrow = side))
    return(contact_matrices)
}

#' function to construct travel matrix from data
#' @param travel_filename character vector of length 1: comma-separated file
#' with a header row where travel data is located. Rows are assumed to be in
#' the same country order as the columns.
#' @param pop_size numeric vector of length n_countries: population size in each
#' country
#' @param travel_params named vector/list containing the element "epsilon":
#' a numeric vector of length 1 which scales the off-diagonal terms of the
#' travel matrix.  Roughly, the ratio of off-diagonal to on-diagonal term size.
#' @return square travel matrix of side length n_countries (unnormalised)
setup_travel_real_data <- function(travel_filename, pop_size, travel_params) {
    travel_data <- read.table(travel_filename, sep = ",", stringsAsFactors = FALSE,
                              header = TRUE)
    # alphabetise for consistency between data sets
    country_order <- order(colnames(travel_data))
    travel_data <- travel_data[country_order, country_order, drop = FALSE]
    travel_data <- as.matrix(travel_data)
    colnames(travel_data) <- NULL
    if (nrow(travel_data) != length(pop_size)) {
        stop("number of countries inconsistent between travel data and pop_size")
    }
    # normalise travel data to mean so that epsilon is more meaningful
    travel_data <- travel_data / mean(travel_data[travel_data != 0]) * mean(pop_size)
    # nrow is given explicitly: diag(k) on a single number k would build a
    # k x k identity matrix instead of a 1 x 1 matrix holding k
    travel_matrix <- diag(pop_size, nrow = length(pop_size)) +
        travel_params[["epsilon"]] * travel_data
    return(travel_matrix)
}

#' function to construct latitude matrix from data
#' @param latitude_filename character vector of length 1: comma-separated file
#' with a header row where latitude data is located. Must contain the columns
#' "Location" and "latitude".
#' @return matrix with 1 column and nrow = n_countries.
#' The latitude of each country in degrees, with countries sorted by Location.
read_latitude_data <- function(latitude_filename){
    latitude_data <- read.table(latitude_filename, sep = ",", header = TRUE)
    # alphabetise by location so rows line up with the other data sets
    alphabetical <- order(latitude_data$Location)
    latitudes <- latitude_data$latitude[alphabetical]
    return(matrix(latitudes, ncol = 1))
}

#' Uniform location age matrix
#' Given a vector of age boundaries and a maximum age, returns a matrix with a uniform age distribution
#' @param ages the vector of age boundaries
#' @param maxAge the maximum age
#' @return a one-column matrix with one row per element of ages, holding
#' ages / maxAge (the proportion in each age group)
generate_age_matrix_uniform <- function(ages,maxAge){
    ageProp <- matrix(ages / maxAge, nrow = length(ages), ncol = 1)
    return(ageProp)
}

#' Age matrix proportions
#' Given a vector of proportions in each age group, converts this to a matrix
#' @param ages the vector of proportions, one per age group
#' @return the same values as a one-column matrix with length(ages) rows
generate_age_matrix <- function(ages){
    ageProp <- matrix(ages, nrow = length(ages), ncol = 1)
    return(ageProp)
}

#' Age mixing manipulation
#' Converts a vector of age-mixing rates or contact durations into a matrix. If a matrix is passed, just returns the original matrix. Can be switched off to return a matrix of 1s
#' @param contactVector vector of length n_ages*n_ages giving age specific contact rates/durations,
#' or an already-built matrix
#' @param n_ages the number of age groups
#' @param ON bool, if FALSE, turns off age mixing
#' @param TRANSPOSE bool, if TRUE, returns the transpose of the created matrix
#' (ignored when contactVector is already a matrix)
#' @return the matrix of age-specific contact rates/durations
generate_age_mixing <- function(contactVector, n_ages, ON=TRUE, TRANSPOSE=TRUE){
    # age mixing switched off: every pair of age groups gets weight 1
    if(!ON) return(matrix(1,n_ages,n_ages))
    if(!is.matrix(contactVector)) {
        # fill column by column, then optionally transpose
        Cnum <- matrix(contactVector, n_ages, n_ages)
        if(TRANSPOSE) Cnum <- t(Cnum)
    } else {
        # a matrix is passed through untouched
        Cnum <- contactVector
    }
    return(Cnum)
}

#' Age mixing matrix generation
#' Generates the age-mixing contact matrix taking into account contact rates and durations
#' @param contactRates the vector or matrix of age-specific contact rates
#' @param contactDur the vector or matrix of age-specific contact durations
#' @param n_ages the number of age groups
#' @param ON bool, if FALSE, turns off age mixing
#' @param TRANSPOSE bool, if TRUE, returns the transpose of the created matrix
#' @return the matrix of age-specific contact rates: the elementwise product
#' of the rate matrix and the duration matrix
generate_contact_matrix <- function(contactRates, contactDur, n_ages, ON=TRUE, TRANSPOSE=TRUE){
    Cnum <- generate_age_mixing(contactRates, n_ages, ON, TRANSPOSE)
    Cdur <- generate_age_mixing(contactDur, n_ages, ON, TRANSPOSE)
    # elementwise product, not a matrix product
    C1 <- Cnum * Cdur
    return(C1)
}

#' Tile a matrix over groups
#' Takes the Kronecker product of an n_groups x n_groups matrix of ones with y,
#' i.e. repeats y in an n_groups x n_groups grid of blocks.
#' @param n_groups the number of groups (blocks per side)
#' @param y the matrix to repeat
#' @return the tiled matrix
kronecker_by_group <- function(n_groups, y){
    x <- matrix(1, n_groups, n_groups)
    y <- kronecker(x, y)
    return(y)
}

#' Risk group proportion matrix
#' Combines the proportion not at risk (1 - propRisk) and the proportion at
#' risk into one matrix.
#' @param propRisk numeric vector: proportion at risk, one entry per group
#' @param n_riskgroups the number of columns of the matrix before transposing
#' @param transpose bool, if TRUE, returns the transpose of the created matrix
#' @return matrix built column by column from c(1 - propRisk, propRisk) with
#' length(propRisk) rows and n_riskgroups columns, transposed if requested
generate_risk_matrix <- function(propRisk, n_riskgroups, transpose=TRUE){
    non_risk <- 1 - propRisk
    risk_mat <- matrix(c(non_risk, propRisk), length(propRisk), n_riskgroups)
    # the argument is spelled in lower case; R names are case-sensitive
    if(transpose) risk_mat <- t(risk_mat)
    # explicit return so the matrix also comes back when transpose is FALSE
    return(risk_mat)
}

#' Read and process seasonal vaccine coverage data
#' Calculates the proportion of seasonal vaccine doses distributed to each country in 2013
#' @param coverage_filename path to a comma-separated file with a header row;
#' its "country" and "dose_per_1000" columns are read
#' @param labels a data frame containing the number of individuals in each location, age, risk group
#' (its X and Location columns are used)
#' @return the proportion of doses distributed to each country
#' (NOTE(review): the visible code only computes dose_per_1000 * pop_size;
#' any normalisation to a proportion is not shown here -- confirm)
read_coverage_data <- function(coverage_filename, labels) {
  # total population per country, obtained by applying the closure to the
  # group sizes (presumably summing over age and risk groups -- confirm
  # against sum_age_risk_closure)
  sum_age_risk_func <- sum_age_risk_closure(labels)
  pop_size <- sum_age_risk_func(labels$X)
  coverage_df <- read.table(coverage_filename, sep = ",", header = TRUE, stringsAsFactors = FALSE)
  # the file must list countries in exactly the order of the Location levels
  stopifnot(all(coverage_df$country == levels(labels$Location)))
  coverage <- coverage_df$dose_per_1000 * pop_size
  # NOTE(review): no closing brace or return is visible after this line.
