sandbox/VB.R

require(mvtnorm)
PurifyVI <- function(chain_index,
                     current_pars,
                     current_weight,
                     LogPostLike,
                     n_chains,
                     control_params,
                     S, ...) {

  # get statistics about chain
  weight_use <- current_weight[chain_index]
  params_use <- current_pars[chain_index, ]
  len_param_use <- length(params_use)

  params_use <- matrix(params_use, 1, len_param_use)

  # get log weight
  weight_proposal <- ELBO(
    params_use,
    LogPostLike,
    control_params,
    S, ...
  )

  if (is.na(weight_proposal)) weight_proposal <- -Inf

  # greedy acceptance rule
  if (is.finite(weight_proposal)) {
    current_pars[chain_index, ] <- params_use
    current_weight[chain_index] <- weight_proposal
  }

  return(c(current_weight[chain_index], current_pars[chain_index, ]))
}


CrossoverVI <- function(chain_index,
                        current_pars,
                        current_weight,
                        LogPostLike,
                        step_size = .8,
                        jitter_size = 1e-6,
                        n_chains,
                        crossover_rate = 1,
                        control_params,
                        S, ...) {

  # get statistics about chain
  weight_use <- current_weight[chain_index]
  params_use <- current_pars[chain_index, ]
  len_param_use <- length(params_use)

  # use binomial to sample which pars to update matching crossover rate frequency
  param_idices_bool <- stats::rbinom(len_param_use, prob = crossover_rate, size = 1)

  # if no pars selected, randomly sample 1 parameter
  if (all(param_idices_bool == 0)) {
    param_idices_bool[sample(x = 1:len_param_use, size = 1)] <- 1
  }

  # indices of parameters to be updated
  param_indices <- seq(1, len_param_use, by = 1)[as.logical(param_idices_bool)]

  # sample parent chains
  parent_chain_indices <- sample(c(1:n_chains)[-chain_index], 3, replace = F)

  # mate parents for proposal
  params_use[param_indices] <- current_pars[parent_chain_indices[3], param_indices] +
    runif(1, .5, 1) * (current_pars[parent_chain_indices[1], param_indices] -
      current_pars[parent_chain_indices[2], param_indices]) +
    stats::runif(1, -jitter_size, jitter_size)

  params_use <- matrix(params_use, 1, len_param_use)

  # get log weight
  weight_proposal <- ELBO(
    params_use,
    LogPostLike,
    control_params,
    S, ...
  )
  if (is.na(weight_proposal)) weight_proposal <- -Inf

  # greedy acceptance rule
  if (weight_proposal > weight_use) {
    current_pars[chain_index, ] <- params_use
    current_weight[chain_index] <- weight_proposal
  }

  return(c(current_weight[chain_index], current_pars[chain_index, ]))
}

QLog <- function(use_theta, use_lambda, control_params, S = 1) {
  # returns log density Q(theta|lambda)
  out <- 0
  out <- stats::dnorm(
    x = use_theta,
    mean = rep(use_lambda[1:(control_params$n_params_model)], S),
    sd = rep(exp(use_lambda[((control_params$n_params_model) + 1):(control_params$n_params_dist)]), S),
    log = T
  )
  out[(out == -Inf) | is.na(out)] <- control_params$neg_inf
  return(sum(out) / S)
}

QSample <- function(use_lambda, control_params, S) {
  # returns a collapsed vector for a S by n_params_model matrix sampled from Q(theta|lambda)
  if (control_params$use_QMC == T) {
    if (control_params$quasi_rand_seq == "sobol") quantileMat <- c(t(randtoolbox::sobol(S, control_params$n_params_model)))
    if (control_params$quasi_rand_seq == "halton") quantileMat <- c(t(randtoolbox::halton(S, control_params$n_params_model)))
  } else {
    quantileMat <- stats::runif(S * control_params$n_params_model, 0, 1)
  }

  out <- stats::qnorm(quantileMat, rep(use_lambda[1:control_params$n_params_model], S), rep(exp(use_lambda[(control_params$n_params_model + 1):(control_params$n_params_dist)]), S))
  return(out)
}


KLHat <- function(lambda, LogPostLike, control_params, S, ...) {
  # Monte Carlo approximation KL divergence up to a constant

  out <- 0 # initalize output vector

  # sample from q
  theta_mat <- QSample(use_lambda = lambda, control_params, S)

  # calc mean differences in log densities for theta_mat
  q_log_density <- QLog(theta_mat, use_lambda = lambda, control_params, S)
  post_log_density <- mean(apply(matrix(theta_mat, ncol = control_params$n_params_model, byrow = T), MARGIN = 1, FUN = LogPostLike, ...))
  out <- q_log_density - post_log_density


  return(out)
}



ELBO <- function(lambda, LogPostLike, control_params, S, ...) {
  # Monte Carlo approximation ELBO
  out <- KLHat(lambda, LogPostLike, control_params, S, ...) * -1
  return(out)
}

AlgoParamsDEVI <- function(n_pars,
                         param_names = NULL,
                         n_chains = NULL,
                         n_iter = 1000,
                         init_sd = 0.01,
                         init_center = 0,
                         n_cores_use = 1,
                         step_size = NULL,
                         jitter_size = 1e-6,
                         parallel_type = "none",
                         use_QMC = T,
                         quasi_rand_seq = "halton",
                         n_samples_ELBO = 10,
                         LRVB_correction = TRUE,
                         n_samples_LRVB = 25,
                         neg_inf = -750,
                         thin = 1,
                         burnin = 0,
                         return_trace = FALSE,
                         crossover_rate = 1) {
  # n_pars
  ### catch errors
  n_pars <- as.integer(n_pars)
  if (any(!is.finite(n_pars))) {
    stop("ERROR: n_pars is not finite")
  } else if (n_pars < 1 | length(n_pars) > 1) {
    stop("ERROR: n_pars must be a postitive integer scalar")
  }

  # param_names
  ### catch errors
  if (is.null(param_names)) {
    param_names <- paste0("par", 1:n_pars)
  } else if (!(length(param_names) == n_pars)) {
    stop("ERROR: param_names does not match size of n_pars")
  }

  dist_param_names <- c(paste0(param_names, "_MEAN"), paste0(param_names, "_VAR"))

  # n_chains
  ### if null assign default value
  if (is.null(n_chains)) {
    n_chains <- max(3 * n_pars, 4)
  }
  ### catch errors
  n_chains <- as.integer(n_chains)
  if (any(!is.finite(n_chains))) {
    stop("ERROR: n_chains is not finite")
  } else if (n_chains < 4 | length(n_chains) > 1) {
    stop("ERROR: n_chains must be a postitive integer scalar, and atleast 4")
  }

  # n_iter
  ### if null assign default value
  if (is.null(n_iter)) {
    n_iter <- 1000
  }
  ### catch errors
  n_iter <- as.integer(n_iter)
  if (any(!is.finite(n_iter))) {
    stop("ERROR: n_iter is not finite")
  } else if (n_iter < 4 | length(n_iter) > 1) {
    stop("ERROR: n_iter must be a postitive integer scalar, and atleast 4")
  }

  # init_sd
  init_sd <- as.numeric(init_sd)
  if (any(!is.finite(init_sd))) {
    stop("ERROR: init_sd is not finite")
  } else if (any(init_sd < 0 | is.complex(init_sd))) {
    stop("ERROR: init_sd must be positive and real-valued")
  } else if (!(length(init_sd) == 1 | length(init_sd) == n_pars)) {
    stop("ERROR: init_sd vector length must be 1 or n_pars")
  }
  if (any(init_sd == 0)) {
    warning("WARNING an init_sd value is 0")
  }

  # init_center
  init_center <- as.numeric(init_center)
  if (any(!is.finite(init_center))) {
    stop("ERROR: init_center is not finite")
  } else if (any(is.complex(init_center))) {
    stop("ERROR: init_center must be real valued")
  } else if (!(length(init_center) == 1 | length(init_center) == n_pars)) {
    stop("ERROR: init_center vector length must be 1 or n_pars")
  }

  # n_cores_use
  ### assign NULL value default
  if (is.null(n_cores_use)) {
    n_cores_use <- 1
  }
  ### catch any errors
  n_cores_use <- as.integer(n_cores_use)
  if (any(!is.finite(n_cores_use))) {
    stop("ERROR: n_cores_use is not finite")
  } else if (n_cores_use < 1 | length(n_cores_use) > 1) {
    stop("ERROR: n_cores_use must be a postitive integer scalar, and atleast 1")
  }


  # step_size
  ### assign NULL value default
  if (is.null(step_size)) {
    step_size <- 2.38 / sqrt(2 * n_pars) # step size recommend in ter braak's 2006 paper
  }
  ### catch any errors
  if (any(!is.finite(step_size))) {
    stop("ERROR: step_size is not finite")
  } else if (any(step_size <= 0 | is.complex(step_size))) {
    stop("ERROR: step_size must be positive and real-valued")
  } else if (!(length(step_size) == 1)) {
    stop("ERROR: step_size vector length must be 1 ")
  }

  # jitter_size
  ### assign NULL value default
  if (is.null(jitter_size)) {
    jitter_size <- 1e-6
  }
  ### catch any errors
  if (any(!is.finite(jitter_size))) {
    stop("ERROR: jitter_size is not finite")
  } else if (any(jitter_size <= 0 | is.complex(jitter_size))) {
    stop("ERROR: jitter_size must be positive and real-valued")
  } else if (!(length(jitter_size) == 1)) {
    stop("ERROR: jitter_size vector length must be 1 ")
  }

  # crossover_rate
  ### if null assign default value
  if (any(is.null(crossover_rate))) {
    crossover_rate <- 1
  }
  ### catch errors
  crossover_rate <- as.numeric(crossover_rate)
  if (any(!is.finite(crossover_rate))) {
    stop("ERROR: crossover_rate is not finite")
  } else if (any(crossover_rate > 1) | any(crossover_rate <= 0) | length(crossover_rate) > 1) {
    stop("ERROR: crossover_rate must be a numeric scalar on the interval (0,1]")
  } else if (is.complex(crossover_rate)) {
    stop("ERROR: crossover_rate cannot be complex")
  }

  # parallel_type
  validParType <- c("none", "FORK", "PSOCK")
  ### assign NULL value default
  if (is.null(parallel_type)) {
    parallel_type <- "none"
  }
  ### catch any errors
  if (!parallel_type %in% validParType) {
    stop("ERROR: invalid parallel_type")
  }

  # thin
  ### if null assign default value
  if (is.null(thin)) {
    thin <- 0
  }
  ### catch errors
  thin <- as.integer(thin)
  if (any(!is.finite(thin))) {
    stop("ERROR: thin is not finite")
  } else if (any(thin < 1) | length(thin) > 1) {
    stop("ERROR: thin must be a scalar postive integer")
  }


  # use_QMC
  if (any(is.null(use_QMC))) {
    use_QMC <- TRUE
  }
  if (length(use_QMC) > 1) {
    stop("length(use_QMC)>1, please use a scalar logical")
  }
  if (!((use_QMC == 0) | (use_QMC == 1))) {
    stop("ERROR: use_QMC must be a scalar logical")
  }
  use_QMC <- as.logical(use_QMC)


  # LRVB_correction
  ### assign value if null
  if (any(is.null(LRVB_correction))) {
    LRVB_correction <- TRUE
  }
  ### catch errors
  if (length(LRVB_correction) > 1) {
    stop("length(LRVB_correction)>1, please use a scalar logical")
  }
  if (!((LRVB_correction == 0) | (LRVB_correction == 1))) {
    stop("ERROR: LRVB_correction must be a scalar logical")
  }
  LRVB_correction <- as.logical(LRVB_correction)

  if (LRVB_correction) {
    # if using LRVB correction check for valid samples count
    # n_samples_LRVB
    ### assign value if null
    n_samples_LRVB <- as.integer(n_samples_LRVB)
    if (any(is.null(n_samples_LRVB))) {
      n_samples_LRVB <- 25
    }
    ### catch errors
    if (any(!is.finite(n_samples_LRVB))) {
      stop("ERROR: n_samples_LRVB is not finite")
    } else if (any(n_samples_LRVB < 1) | length(n_samples_LRVB) > 1) {
      stop("ERROR: n_samples_LRVB must be a scalar postive integer")
    }
  }

  # n_samples_ELBO
  ### assign value if null
  n_samples_ELBO <- as.integer(n_samples_ELBO)
  if (any(is.null(n_samples_ELBO))) {
    n_samples_ELBO <- 10
  }
  ### catch errors
  if (any(!is.finite(n_samples_ELBO))) {
    stop("ERROR: n_samples_ELBO is not finite")
  } else if (any(n_samples_ELBO < 1) | length(n_samples_ELBO) > 1) {
    stop("ERROR: n_samples_ELBO must be a scalar postive integer")
  }

  # quasi_rand_seq
  quasi_rand_seq <- tolower(as.character(quasi_rand_seq))
  valid_quasi_rand_seqs <- c("sobol", "halton")
  ### assign NULL value default
  if (is.null(quasi_rand_seq)) {
    quasi_rand_seq <- "sobol"
  }
  ### catch any errors
  if (!quasi_rand_seq %in% valid_quasi_rand_seqs) {
    stop(paste0("ERROR: invalid quasi_rand_seq; must be one of: ", paste(valid_quasi_rand_seqs, sep = ",")))
  }

  # neg_inf
  ### assign NULL value default
  if (any(is.null(neg_inf))) {
    neg_inf <- -750
  }
  ### catch any errors
  if (length(neg_inf) > 1) {
    stop("length(neg_inf)>1, please use a scalar numeric")
  }
  if (!is.numeric(neg_inf)) {
    stop("neg_inf must be a numeric")
  }

  # burnin
  ### if null assign default value
  if (is.null(burnin)) {
    burnin <- 0
  }
  ### catch errors
  burnin <- as.integer(burnin)
  if (any(!is.finite(burnin))) {
    stop("ERROR: burnin is not finite")
  } else if (any(burnin < 0) | any(burnin >= n_iter) | length(burnin) > 1) {
    stop("ERROR: burnin must be a scalar integer from the interval [0,n_iter)")
  }

  # nSamples Per Chains
  n_iters_per_chain <- floor((n_iter - burnin) / thin)
  ### catch errors
  if (n_iters_per_chain < 1 | (!is.finite(n_iters_per_chain))) {
    stop("ERROR: number of iters per chain is negative or non finite.
           n_iters_per_chain=floor((n_iter-burnin)/thin)")
  }

  out <- list(
    "n_params_model" = n_pars,
    "param_names" = param_names,
    "n_chains" = n_chains,
    "n_iter" = n_iter,
    "init_sd" = init_sd,
    "init_center" = init_center,
    "n_cores_use" = n_cores_use,
    "step_size" = step_size,
    "jitter_size" = jitter_size,
    "crossover_rate" = crossover_rate,
    "parallel_type" = parallel_type,
    "return_trace" = return_trace,
    "purify" = Inf,
    "use_QMC" = use_QMC,
    "quasi_rand_seq" = quasi_rand_seq,
    "n_samples_ELBO" = n_samples_ELBO,
    "n_samples_LRVB" = n_samples_LRVB,
    "LRVB_correction" = LRVB_correction,
    "thin" = thin,
    "neg_inf" = neg_inf,
    "n_params_dist" = 2 * n_pars,
    "n_iters_per_chain" = n_iters_per_chain
  )

  return(out)
}

dataExample <- rmvnorm(200, c(-1, 1), sigma = matrix(c(
  1, .5,
  .5, 1
), nrow = 2, ncol = 2, byrow = T))
## list parameter names
param_names_example <- c("mu_1", "mu_2")

# log posterior likelihood function = log likelihood + log prior | returns a scalar
LogPostLikeExample <- function(x, data, param_names) {
  out <- 0

  names(x) <- param_names

  # log prior
  out <- out + sum(dnorm(x["mu_1"], 0, sd = 1, log = TRUE))
  out <- out + sum(dnorm(x["mu_2"], 0, sd = 1, log = TRUE))

  # log likelihoods
  # out=out+sum(dnorm(data[,1],x["mu_1"],sd=1,log=TRUE))
  # out=out+sum(dnorm(data[,2],x["mu_2"],sd=1,log=TRUE))
  out <- out + dmvnorm(data, mean = x, sigma = matrix(c(
    1, .5,
    .5, 1
  ), nrow = 2, ncol = 2, byrow = T))
  return(out)
}

DEVI <- function(LogPostLike, control_params = AlgoParamsDEVI(), ...) {

  # import values we will reuse throughout process
  # create memory structures for storing posterior samples
  lambda <- array(NA, dim = c(
    control_params$n_iters_per_chain,
    control_params$n_chains,
    control_params$n_params_dist
  ))
  ELBO_values <- matrix(-Inf,
    nrow = control_params$n_iters_per_chain,
    ncol = control_params$n_chains
  )


  # chain initialization
  message("initalizing chains...")
  for (chain_idx in 1:control_params$n_chains) {
    count <- 0
    while (ELBO_values[1, chain_idx] == -Inf) {
      lambda[1, chain_idx, ] <- stats::rnorm(
        control_params$n_params_dist,
        control_params$init_center,
        control_params$init_sd
      )

      ELBO_values[1, chain_idx] <- ELBO(lambda[1, chain_idx, ], LogPostLike, control_params, S = control_params$n_samples_ELBO, ...)
      count <- count + 1
      if (count > 100) {
        stop("chain initialization failed.
        inspect likelihood and prior or change init_center/init_sd to sample more
             likely parameter values")
      }
    }
    message(paste0(chain_idx, " / ", control_params$n_chains))
  }
  message("chain initialization complete  :)")

  # cluster initialization
  if (!control_params$parallel_type == "none") {
    message(paste0(
      "initalizing ",
      control_params$parallel_type, " cluser with ",
      control_params$n_cores_use, " cores"
    ))

    doParallel::registerDoParallel(control_params$n_cores_use)
    cl_use <- parallel::makeCluster(control_params$n_cores_use,
      type = control_params$parallel_type
    )
  }

  message("running DE to find best variational approximation")
  lambdaIdx <- 1
  for (iter in 1:control_params$n_iter) {

    #####################
    ####### crossover
    #####################
    if (control_params$parallel_type == "none") {
      temp <- matrix(unlist(lapply(1:control_params$n_chains, CrossoverVI,
        current_pars = lambda[lambdaIdx, , ], # current parameter values for chain (numeric vector)
        current_weight = ELBO_values[lambdaIdx, ], # corresponding log like for (numeric vector)
        LogPostLike = LogPostLike, # log likelihood function (returns scalar)
        step_size = control_params$step_size,
        jitter_size = control_params$jitter_size,
        n_chains = control_params$n_chains,
        crossover_rate = control_params$crossover_rate,
        control_params = control_params,
        S = control_params$n_samples_ELBO, ...
      )),
      nrow = control_params$n_chains,
      ncol = control_params$n_params_dist + 1, byrow = T
      )
    } else {
      temp <- matrix(unlist(parallel::parLapply(cl_use, 1:control_params$n_chains, CrossoverVI,
        current_pars = lambda[lambdaIdx, , ], # current parameter values for chain (numeric vector)
        current_weight = ELBO_values[lambdaIdx, ], # corresponding log like for (numeric vector)
        LogPostLike = LogPostLike, # log likelihood function (returns scalar)
        step_size = control_params$step_size,
        jitter_size = control_params$jitter_size,
        n_chains = control_params$n_chains,
        crossover_rate = control_params$crossover_rate,
        control_params,
        S = control_params$n_samples_ELBO, ...
      )),
      control_params$n_chains,
      control_params$n_params_dist + 1,
      byrow = T
      )
    }
    # update particle chains
    ELBO_values[lambdaIdx, ] <- temp[, 1]
    lambda[lambdaIdx, , ] <- temp[, 2:(control_params$n_params_dist + 1)]
    if (iter < control_params$n_iter) {
      ELBO_values[lambdaIdx + 1, ] <- temp[, 1]
      lambda[lambdaIdx + 1, , ] <- temp[, 2:(control_params$n_params_dist + 1)]
    }

    #####################
    ####### purify
    #####################
    if (iter %% control_params$purify == 0) {
      if (control_params$parallel_type == "none") {
        temp <- matrix(unlist(lapply(1:control_params$n_chains, PurifyVI,
          current_pars = lambda[lambdaIdx, , ], # current parameter values for chain (numeric vector)
          current_weight = ELBO_values[lambdaIdx, ], # corresponding log like for (numeric vector)
          LogPostLike = LogPostLike, # log likelihood function (returns scalar)
          n_chains = control_params$n_chains,
          control_params = control_params,
          S = control_params$n_samples_ELBO, ...
        )),
        nrow = control_params$n_chains,
        ncol = control_params$n_params_dist + 1, byrow = T
        )
      } else {
        temp <- matrix(unlist(parallel::parLapply(cl_use, 1:control_params$n_chains, PurifyVI,
          current_pars = lambda[lambdaIdx, , ], # current parameter values for chain (numeric vector)
          current_weight = ELBO_values[lambdaIdx, ], # corresponding log like for (numeric vector)
          LogPostLike = LogPostLike, # log likelihood function (returns scalar)
          step_size = control_params$step_size,
          jitter_size = control_params$jitter_size,
          n_chains = control_params$n_chains,
          crossover_rate = control_params$crossover_rate,
          control_params,
          S = control_params$n_samples_ELBO, ...
        )),
        control_params$n_chains,
        control_params$n_params_dist + 1,
        byrow = T
        )
      }

      # update particle chains
      ELBO_values[lambdaIdx, ] <- temp[, 1]
      lambda[lambdaIdx, , ] <- temp[, 2:(control_params$n_params_dist + 1)]
      if (iter < control_params$n_iter) {
        ELBO_values[lambdaIdx + 1, ] <- temp[, 1]
        lambda[lambdaIdx + 1, , ] <- temp[, 2:(control_params$n_params_dist + 1)]
      }
    }




    if (iter %% 100 == 0) message(paste0("iter ", iter, "/", control_params$n_iter))
    if (iter %% control_params$thin == 0) {
      lambdaIdx <- lambdaIdx + 1
    }
  }
  # cluster stop
  if (!control_params$parallel_type == "none") {
    parallel::stopCluster(cl = cl_use)
  }
  maxIdx <- which.max(ELBO_values[control_params$n_iters_per_chain, ])

  means <- lambda[
    control_params$n_iters_per_chain,
    maxIdx, 1:control_params$n_params_model
  ]
  names(means) <- paste0(control_params$param_names, "_mean")

  covariance <- diag(exp(2 * lambda[
    control_params$n_iters_per_chain,
    maxIdx, (control_params$n_params_model + 1):control_params$n_params_dist
  ]))


  if (control_params$return_trace == T) {
    return(list(
      "means" = means,
      "covariance" = covariance,
      "ELBO" = ELBO_values[control_params$n_iters_per_chain, maxIdx],
      "lambda_trace" = lambda,
      "ELBO_trace" = ELBO_values,
      "control_params" = control_params
    ))
  } else {
    return(list(
      "means" = means,
      "covariance" = covariance,
      "ELBO" = ELBO_values[control_params$n_iters_per_chain, maxIdx],
      "control_params" = control_params
    ))
  }
}

out <- DEVI(
  LogPostLike = LogPostLikeExample,
  control_params = AlgoParamsDEVI(
    n_pars = length(param_names_example),
    n_iter = 200,
    n_samples_ELBO = 5,
    n_chains = 25, return_trace = T, crossover_rate = .8, use_QMC = T
  ),
  data = dataExample,
  param_names = param_names_example
)



baseout1 <- optim(
  par = rep(0, 2 * length(param_names_example)),
  fn = KLHat,
  method = "Nelder-Mead",
  LogPostLike = LogPostLikeExample,
  control_params = AlgoParamsDEVI(n_pars = length(param_names_example)),
  S = 50, param_names = param_names_example, data =
    dataExample, hessian = TRUE
)

baseout2 <- optim(
  par = rep(0, 2 * length(param_names_example)),
  fn = KLHat,
  method = "CG",
  LogPostLike = LogPostLikeExample,
  control_params = AlgoParamsDEVI(n_pars = length(param_names_example)),
  S = 50, param_names = param_names_example, data =
    dataExample, hessian = TRUE
)

baseout3 <- optim(
  par = rep(0, 2 * length(param_names_example)),
  fn = KLHat,
  method = "",
  LogPostLike = LogPostLikeExample,
  control_params = AlgoParamsDEVI(n_pars = length(param_names_example)),
  S = 50, param_names = param_names_example, data =
    dataExample, hessian = TRUE
)


par(mfrow = c(2, 2))
matplot(out$lambda_trace[, , 2], type = "l")
matplot(out$lambda_trace[, , 1], type = "l")
matplot(out$lambda_trace[, , 3], type = "l")
matplot(out$lambda_trace[, , 4], type = "l")
bmgaldo/DEBBI documentation built on May 22, 2022, 7:37 p.m.