R/F_run_modern.R

Defines functions internal_get_core_input InternalRunOneChain run_modern

Documented in run_modern

#-----------------------------------------------------
#' Run the modern (calibration) model
#'
#' Fits the B-spline multinomial calibration model to the modern species
#' counts with JAGS, running one JAGS fit per chain ID. Each fit is saved to
#' \code{temp.JAGSobjects/jags_mod<chain>.Rdata} in the working directory.
#' The posterior is then summarised into the inputs needed by the core
#' (reconstruction) run.
#'
#' @param modern_elevation Modern elevation data containing a \code{SWLI}
#'   column. Defaults to \code{BTF::NJ_modern_elevation}.
#' @param modern_species Modern species counts, passed to
#'   \code{\link{sort_modern}}. Defaults to \code{BTF::NJ_modern_species}.
#' @param dx The elevation interval for spacing the spline knots, on the
#'   standardised SWLI scale. Defaults to 0.4.
#' @param ChainNums The IDs of the MCMC chains to run (one JAGS run per ID).
#' @param n.iter The number of iterations
#' @param n.burnin The number of burnin samples
#' @param n.thin The thinning rate
#' @param validation.run If \code{TRUE}, split the samples into 10 random
#'   folds (fixed seed) and hold out fold \code{fold} as test data.
#' @param sigma_z_priors Optional list with elements \code{species},
#'   \code{mean_sigma_overall} and \code{sd_sigma_overall} giving prior
#'   hyperparameters for the species-level \code{sigma.z}. Priors are matched
#'   to species by name; species without a prior use mean 0 and sd 2. If
#'   \code{NULL}, every species uses mean 0 and sd 1.
#' @param fold Which fold (1 to 10) to hold out when \code{validation.run}
#'   is \code{TRUE}.
#'
#' @return Invisibly, a list of class \code{"BTF"} containing the JAGS data
#'   and monitored parameter names, the elevation range, \code{dx}, the
#'   species names, the posterior summaries returned by the core-input step
#'   (including \code{src_dat}, the predicted species proportions by SWLI)
#'   and the SWLI centring/scaling constants.
#' @export
#' @import R2jags rjags readr dplyr
#'
#' @examples
#' \dontrun{
#' run_modern()
#' }
run_modern <- function(modern_elevation = NULL,
                       modern_species = NULL,
                       dx = 0.4,
                       ChainNums = seq(1, 3),
                       n.iter = 40000,
                       n.burnin = 10000,
                       n.thin = 15,
                       validation.run = FALSE,
                       sigma_z_priors = NULL,
                       fold = 1) {

  dir.create(file.path(getwd(), "temp.JAGSobjects"), showWarnings = FALSE)

  # Modern species counts (packaged data when none supplied)
  if (!is.null(modern_species)) {
    modern_dat <- modern_species
  } else {
    modern_dat <- BTF::NJ_modern_species
  }

  # Species sorted by counts; begin0 indexes the first all-zero species (if any)
  modern_data_sorted <- sort_modern(modern_dat)
  counts <- modern_data_sorted$moderndat_sorted

  # Modern elevations (packaged data when none supplied)
  if (!is.null(modern_elevation)) {
    elevation_dat <- modern_elevation
  } else {
    elevation_dat <- BTF::NJ_modern_elevation
  }

  # Standardise SWLI; the centre/scale are kept so results can be back-transformed
  modern_elevation <- scale(elevation_dat$SWLI)
  scale_att <- attributes(modern_elevation)

  if (validation.run) {
    # Random 10-fold split with a fixed seed so the folds are reproducible
    set.seed(3847)
    n_folds <- 10
    folds <- rep(1:n_folds, ceiling(nrow(counts) / n_folds))
    folds <- folds[sample(seq_len(length(modern_elevation)))]
    test_samps <- which(folds == fold)
    # Negative indexing with an empty index would silently drop every row
    if (length(test_samps) == 0) {
      stop("No samples are assigned to fold ", fold, call. = FALSE)
    }

    y <- counts[-test_samps, , drop = FALSE]
    x <- modern_elevation[-test_samps, 1]
    y_test <- as_tibble(counts[test_samps, , drop = FALSE])
    # Same single `value` column that as_tibble() produced for a vector
    x_test <- tibble(value = modern_elevation[test_samps, 1])
  } else {
    y <- counts
    x <- modern_elevation[, 1]
    y_test <- NULL
    x_test <- NULL
  }

  species_names <- modern_data_sorted$species_names

  # Min/max standardised elevation over ALL samples (including held-out
  # ones); used as the spline basis limits
  elevation_min <- floor(min(modern_elevation))
  elevation_max <- ceiling(max(modern_elevation))

  begin0 <- modern_data_sorted$begin0

  # Total count per sample (the multinomial size in the model)
  N_count <- apply(y, 1, sum)

  # B-spline basis, re-expressed through the first-difference matrix so the
  # model's spline is Z.ih %*% delta.hj
  res <- BTF:::bbase(x, xl = elevation_min, xr = elevation_max, dx = dx)
  B.ik <- res$B.ik
  K <- dim(B.ik)[2]

  D <- 1
  Delta.hk <- diff(diag(K), differences = D)
  Deltacomb.kh <- t(Delta.hk) %*% solve(Delta.hk %*% t(Delta.hk))
  Z.ih <- B.ik %*% Deltacomb.kh
  H <- dim(Z.ih)[2]

  # Hyperparameters of the sigma.z prior, one entry per species column
  mean_sigma_z <- rep(0, ncol(y))
  if (is.null(sigma_z_priors)) {
    sd_sigma_z <- rep(1, ncol(y))
  } else {
    sd_sigma_z <- rep(2, ncol(y))
    # Match each species by name so priors are applied correctly whatever
    # the order of sigma_z_priors$species; unmatched species keep defaults.
    # NOTE(review): assumes species_names lines up with the columns of y.
    prior_index <- match(species_names, sigma_z_priors$species)
    has_prior <- which(!is.na(prior_index))
    mean_sigma_z[has_prior] <-
      sigma_z_priors$mean_sigma_overall[prior_index[has_prior]]
    sd_sigma_z[has_prior] <-
      sigma_z_priors$sd_sigma_overall[prior_index[has_prior]]
  }

  # JAGS model data and monitored parameters
  pars <- c("p", "beta.j", "sigma.z", "sigma.delta", "delta.hj", "spline")

  data <- list(y = y,
               n = nrow(y),
               m = ncol(y),
               N_count = N_count,
               H = H,
               Z.ih = Z.ih,
               begin0 = begin0,
               mean_sigma_z = mean_sigma_z,
               sd_sigma_z = sd_sigma_z)

  for (chainNum in ChainNums) {
    cat(paste("Start chain ID ", chainNum), "\n")

    BTF:::InternalRunOneChain(chainNum = chainNum, jags_data = data,
                              jags_pars = pars, n.burnin = n.burnin,
                              n.iter = n.iter, n.thin = n.thin)
  }

  # Inputs for summarising the fit (added after the JAGS runs)
  data[["x"]] <- x
  data[["y_test"]] <- y_test
  data[["x_test"]] <- x_test

  jags_data <- list(data = data,
                    pars = pars,
                    elevation_max = elevation_max,
                    elevation_min = elevation_min,
                    dx = dx,
                    species_names = species_names,
                    x_center = scale_att$`scaled:center`,
                    x_scale = scale_att$`scaled:scale`)

  core_input <- internal_get_core_input(ChainNums = ChainNums,
                                        jags_data = jags_data)

  modern_out <- list(data = data,
                     pars = pars,
                     elevation_max = elevation_max,
                     elevation_min = elevation_min,
                     dx = dx,
                     species_names = species_names,
                     delta0.hj = core_input$delta0.hj,
                     delta0_sd = core_input$delta0_sd,
                     beta0.j = core_input$beta0.j,
                     beta0_sd = core_input$beta0_sd,
                     sig0_z = core_input$sig0_z,
                     tau.z0 = core_input$tau.z0,
                     src_dat = core_input$src_dat,
                     x_center = scale_att$`scaled:center`,
                     x_scale = scale_att$`scaled:scale`)

  class(modern_out) <- "BTF"

  invisible(modern_out)
}

#-----------------------------------------------------
# Fit the modern calibration model in JAGS as a single chain and save it.
#
# chainNum  - chain ID; the R seed and jags.seed are both chainNum * 209846.
# jags_data - data list passed to R2jags::jags().
# jags_pars - names of the parameters to monitor.
# n.burnin, n.iter, n.thin - MCMC settings passed to R2jags::jags().
#
# The fitted object is saved, under the name `mod_upd`, to
# <working dir>/temp.JAGSobjects/jags_mod<chainNum>.Rdata.
# Returns NULL invisibly; the function is called for that side effect.
InternalRunOneChain <- function(chainNum, jags_data, jags_pars, n.burnin,
    n.iter, n.thin) {
  set.seed.chain <- chainNum * 209846
  set.seed(set.seed.chain)
  # Discarded draw: keeps R's RNG stream where earlier versions left it
  rnorm(1)

  out_dir <- file.path(getwd(), "temp.JAGSobjects")
  # save() below fails if the directory does not exist
  dir.create(out_dir, showWarnings = FALSE)

  # The model for the modern data. Species begin0:m (all-zero counts) get
  # lambda = 1; species 1:(begin0-1) get lambda = exp(z), where z is normal
  # around the penalised spline. Counts are multinomial with
  # p = lambda / sum(lambda).
  modernmodel <- "
  model
  {

  for(i in 1:n)
  {
  for(j in begin0:m){
  lambda[i,j] <- 1
  }
  for(j in 1:(begin0-1)){
  spline[i,j] <- beta.j[j] + inprod(Z.ih[i,],delta.hj[,j])
  z[i,j] ~ dnorm(spline[i,j],tau.z[j])
  lambda[i,j] <- exp(z[i,j])
  }#End j loop

  y[i,] ~ dmulti(p[i,],N_count[i])
  lambdaplus[i] <- sum(lambda[i,])
  }#End i loop

  ###Get p's for multinomial
  for(i in 1:n){
  for(j in 1:m){
  p[i,j] <- lambda[i,j]/lambdaplus[i]
  }#End j loop
  }#End i loop


  #####Spline parameters#####
  #Coefficients
  for(j in 1:(begin0-1)){
  for (h in 1:H)
  {
  delta.hj[h,j] ~ dnorm(0, tau.delta)
  }
  }
  #Smoothness
  tau.delta<-pow(sigma.delta,-2)
  sigma.delta~dt(0, 2^-2, 1)T(0,)
  ###Variance parameter###
  for(j in 1:(begin0-1)){
  tau.z[j] <- pow(sigma.z[j],-2)
  sigma.z[j] ~ dt(mean_sigma_z[j], sd_sigma_z[j]^-2, 1)T(0,)
  ###Intercept (species specific)
  beta.j[j] ~ dt(0,100^-2,1)
  }
  
  }##End model
  "

  # The object name `mod_upd` is what ends up inside the .Rdata file
  mod_upd <- jags(data = jags_data, parameters.to.save = jags_pars,
                  model.file = textConnection(modernmodel),
                  n.chains = 1, n.iter = n.iter, n.burnin = n.burnin,
                  n.thin = n.thin, DIC = FALSE, jags.seed = set.seed.chain)

  save(mod_upd, file = file.path(out_dir, paste0("jags_mod", chainNum, ".Rdata")))

  cat(paste("Hooraah, Chain", chainNum, "has finished!"), "\n")

  invisible(NULL)
}


# Summarise the saved modern-model MCMC output into the inputs the core
# (reconstruction) run needs.
#
# ChainNums - chain IDs whose saved JAGS runs are combined by
#             ConstructMCMCArray().
# jags_data - list built in run_modern(): `data` (containing y, x, H and
#             begin0), `species_names`, `x_scale` and `x_center`.
#
# Returns a list with
#   delta0.hj - posterior means of delta.hj (H x modelled species)
#   delta0_sd - per species, the median over h of the posterior sds of delta.hj
#   beta0.j   - posterior means of the species intercepts beta.j
#   beta0_sd  - median over species of the beta.j posterior sds
#   sig0_z    - posterior means of sigma.z
#   tau.z0    - 1 / (delta0_sd + sig0_z)^2
#   src_dat   - long data frame of predicted proportions (posterior mean,
#               2.5% and 97.5% quantiles) by species and back-transformed SWLI
internal_get_core_input <- function(ChainNums, jags_data)
{
  # Indexed below as [iteration, chain, parameter name]
  mcmc.array <- ConstructMCMCArray(ChainIDs = ChainNums)
  n_samps <- dim(mcmc.array)[1]
  n_chains <- dim(mcmc.array)[2]

  y <- jags_data$data$y
  m <- ncol(y)
  x <- jags_data$data$x
  H <- jags_data$data$H
  begin0 <- jags_data$data$begin0
  species_names <- jags_data$species_names

  # Species 1:(begin0 - 1) have spline parameters; begin0:m are the
  # all-zero species (possibly none, in which case begin0 - 1 == m)
  modelled <- seq_len(begin0 - 1)
  zero_species <- setdiff(seq_len(m), modelled)

  # ---------------------------------------------------- parameter draws
  # Draws come from the first chain in the array. (The previous
  # `sample(1, ChainNums)` always evaluated to 1 and raised an error
  # whenever ChainNums[1] > 1.)
  param_chain <- 1L
  delta.hj_samps <- array(NA_real_, c(n_samps, H, length(modelled)))
  beta.j_samps <- sigma.z_samps <- array(NA_real_, c(n_samps, length(modelled)))
  for (j in modelled) {
    delta.hj_samps[, , j] <-
      mcmc.array[, param_chain, paste0("delta.hj[", seq_len(H), ",", j, "]")]
    beta.j_samps[, j] <- mcmc.array[, param_chain, paste0("beta.j[", j, "]")]
    sigma.z_samps[, j] <- mcmc.array[, param_chain, paste0("sigma.z[", j, "]")]
  }

  delta0.hj <- apply(delta.hj_samps, 2:3, mean)
  delta0_sd <- apply(apply(delta.hj_samps, 2:3, sd), 2, median)

  beta0.j <- apply(beta.j_samps, 2, mean)
  beta0_sd <- median(apply(beta.j_samps, 2, sd))

  sig0_z <- apply(sigma.z_samps, 2, mean)

  # Per-species precision built from the spline-coefficient spread plus sigma.z
  sigma.z0 <- delta0_sd + sig0_z
  tau.z0 <- 1 / (sigma.z0^2)

  # ---------------------------------------------------- predicted proportions
  # exp(spline) for every draw, sample and species. All-zero species are set
  # to 0 so they drop out of the normalisation.
  # NOTE(review): the JAGS model fixes lambda = 1 (not 0) for those species;
  # confirm this difference is intended.
  x_index <- seq_along(x)
  spline_names <- lapply(modelled, function(j) {
    paste0("spline[", x_index, ",", j, "]")
  })
  spline_star_all <- array(NA_real_, c(n_samps, length(x), m))
  spline_star_all[, , zero_species] <- 0
  for (i in seq_len(n_samps)) {
    for (j in modelled) {
      # Each draw/species uses a randomly chosen chain from those available
      chain <- sample.int(n_chains, 1)
      spline_star_all[i, , j] <- exp(mcmc.array[i, chain, spline_names[[j]]])
    }
  }

  # Normalise across species within each draw and sample
  totals <- apply(spline_star_all, c(1, 2), sum)
  p_star_all <- sweep(spline_star_all, c(1, 2), totals, "/")

  pred_pi_mean <- apply(p_star_all, 2:3, mean)
  pred_pi_high <- apply(p_star_all, 2:3, "quantile", 0.975)
  pred_pi_low <- apply(p_star_all, 2:3, "quantile", 0.025)

  # ---------------------------------------------------- output data frame
  # Back-transform the standardised elevation to SWLI
  swli <- (x * jags_data$x_scale) + jags_data$x_center

  # One row per (SWLI, species) with the given value column
  to_long <- function(pred, value_col) {
    wide <- data.frame(swli, pred)
    colnames(wide) <- c("SWLI", species_names)
    tidyr::pivot_longer(wide, -SWLI, names_to = "species",
                        values_to = value_col)
  }

  lwr <- to_long(pred_pi_low, "proportion_lwr")[["proportion_lwr"]]
  upr <- to_long(pred_pi_high, "proportion_upr")[["proportion_upr"]]

  # The three long frames share the same row order, so columns align
  src_dat <- to_long(pred_pi_mean, "proportion") %>%
    dplyr::mutate(proportion_lwr = lwr, proportion_upr = upr) %>%
    dplyr::arrange(SWLI)

  list(delta0.hj = delta0.hj,
       delta0_sd = delta0_sd,
       beta0.j = beta0.j,
       beta0_sd = beta0_sd,
       sig0_z = sig0_z,
       tau.z0 = tau.z0,
       src_dat = src_dat)
}
ncahill89/BTF documentation built on March 29, 2021, 12:04 p.m.