# R/inference_direct_fitc.R

# inference_direct_fitc

# FITC inference for Gaussian likelihood and sparse GP

# y is response data
# data is a dataframe containing columns on which the kernel acts,
#   giving observations on which to train the model
# inducing_data is a dataframe giving the locations of inducing points
# kernel is a gpe kernel object
# mean_function is a gpe mean function (for now just an R function)
inference_direct_fitc <- function(y,
                                  data,
                                  kernel,
                                  likelihood,
                                  mean_function,
                                  inducing_data,
                                  weights,
                                  verbose = FALSE) {

  # FITC (fully independent training conditional) inference for a sparse GP
  # with direct, closed-form Gaussian computations.
  #
  # y             - response data
  # data          - dataframe of training locations the kernel acts on
  # kernel        - kernel(a) gives the covariance among rows of a;
  #                 kernel(a, b) gives the cross-covariance between a and b
  # likelihood    - ignored here; only passed through to the posterior
  # mean_function - prior mean function, called as mean_function(data)
  # inducing_data - dataframe of inducing point locations
  # weights       - currently ignored; a warning is raised unless all are 1
  # verbose       - currently unused
  #
  # Returns the object built by createPosterior(), carrying the log marginal
  # likelihood lZ and the components (Kzzi, Kzx, Sigma, iLambda, y, mn_prior)
  # consumed by project_direct_fitc().
  #
  # NB: the only term added to diag(Kxx - Qxx) is a small jitter, so any
  # observation noise presumably has to come from the kernel - TODO confirm.

  n <- nrow(data)

  # throw a warning (for now) if weights aren't all 1; isTRUE() makes NA
  # weights trigger the warning rather than an "NA in if" error
  if (!isTRUE(all(weights == rep(1, length(y))))) {
    warning('weights argument currently ignored for FITC inference')
  }

  # small jitter added to the FITC diagonal for numerical stability
  jitter <- 10 ^ -6

  # log-determinant of a positive-definite matrix, without forming det(M),
  # which can under/overflow for moderately sized matrices
  log_det <- function(M) {
    as.numeric(determinant(M, logarithm = TRUE)$modulus)
  }

  # evaluate prior mean function and the residuals from it
  mn_prior <- mean_function(data)
  resid <- y - mn_prior

  # prior self-covariance of the training data (only its diagonal is used)
  Kxx <- kernel(data)

  # FITC components
  Kzz <- kernel(inducing_data)
  Kzzi <- solve(Kzz)
  Kxz <- kernel(data, inducing_data)
  Kzx <- t(Kxz)

  # diag(Qxx) with Qxx = Kxz Kzz^-1 Kzx (Quinonero-Candela & Rasmussen),
  # computed row-wise so the n x n matrix Qxx is never formed
  diag_Qxx <- rowSums((Kxz %*% Kzzi) * Kxz)

  # $\Lambda$ = diag(Kxx - Qxx) + jitter, held as the vector of its diagonal
  lambda <- diag(Kxx) - diag_Qxx + jitter

  # nrow = n is required: for a length-one vector, diag(x) would build an
  # identity matrix of size floor(x) rather than a 1 x 1 matrix
  iLambda <- diag(1 / lambda, nrow = n)

  # $\Sigma$ = (Kzz + Kzx Lambda^-1 Kxz)^-1; Kxz / lambda scales row i of Kxz
  # by 1 / lambda[i], avoiding an n x n product
  A <- Kzz + crossprod(Kxz, Kxz / lambda)
  Sigma <- solve(A)

  # log marginal likelihood log N(y | mn_prior, Qxx + Lambda), using only
  # m x m solves via the Woodbury identity and matrix determinant lemma:
  #   (Qxx + Lambda)^-1 = Lambda^-1 - Lambda^-1 Kxz Sigma Kzx Lambda^-1
  #   log|Qxx + Lambda| = log|A| - log|Kzz| + sum(log(lambda))
  iLr <- resid / lambda
  u <- Kzx %*% iLr
  quad <- sum(resid * iLr) - sum(u * (Sigma %*% u))
  logdet <- log_det(A) - log_det(Kzz) + sum(log(lambda))

  # the quadratic form enters with a negative sign
  lZ <- -(quad + logdet + n * log(2 * pi)) / 2

  # return posterior object; weights is passed positionally exactly as before
  createPosterior(inference_name = 'inference_direct_fitc',
                  lZ = lZ,
                  data = data,
                  kernel = kernel,
                  likelihood = likelihood,
                  mean_function = mean_function,
                  inducing_data = inducing_data,
                  weights,
                  Kzzi = Kzzi,
                  Kzx = Kzx,
                  Sigma = Sigma,
                  iLambda = iLambda,
                  y = y,
                  mn_prior = mn_prior)
}

# projection for FITC models
project_direct_fitc <- function(posterior,
                                new_data,
                                variance = c('none', 'diag', 'matrix')) {

  # Projection (prediction) for FITC models.
  #
  # posterior - object from inference_direct_fitc(); its kernel,
  #             mean_function, inducing_data and components (Kzzi, Kzx,
  #             Sigma, iLambda, y, mn_prior) are used
  # new_data  - dataframe of prediction locations
  # variance  - 'none' (mean only), 'diag' (pointwise variances) or
  #             'matrix' (full predictive covariance)
  #
  # Returns list(mu, var): mu is the predictive mean (Kxpz %*% ... + prior
  # mean), var is NULL, a vector of variances or a covariance matrix.

  # get the required variance argument
  variance <- match.arg(variance)
  comp <- posterior$components

  # prior mean over the test locations
  mn_prior_xp <- posterior$mean_function(new_data)

  # cross-covariance between test locations and inducing points
  Kxpz <- posterior$kernel(new_data, posterior$inducing_data)

  # predictive mean Kxpz Sigma Kzx Lambda^-1 (y - m) + m_xp, multiplied
  # right-to-left so each product is matrix-by-vector (no p x n x n work)
  resid <- comp$y - comp$mn_prior
  coef_z <- comp$Sigma %*% (comp$Kzx %*% (comp$iLambda %*% resid))
  mu <- Kxpz %*% coef_z + mn_prior_xp

  if (variance == 'none') {

    # mean only: the prior covariance over new_data is not needed
    pred_var <- NULL

  } else {

    # prior covariance over the test locations; the kernel only returns full
    # matrices, so this is p x p even when just its diagonal is used
    Kxpxp <- posterior$kernel(new_data)

    if (variance == 'diag') {

      # diag(B %*% M %*% t(B)) == rowSums((B %*% M) * B), which avoids
      # forming the p x p matrices Qxpxp and Kxpz Sigma Kzxp
      pred_var <- diag(Kxpxp) -
        rowSums((Kxpz %*% comp$Kzzi) * Kxpz) +
        rowSums((Kxpz %*% comp$Sigma) * Kxpz)

    } else {

      # full covariance: Kxpxp - Qxpxp + Kxpz Sigma Kzxp
      Kzxp <- t(Kxpz)
      pred_var <- Kxpxp -
        Kxpz %*% comp$Kzzi %*% Kzxp +
        Kxpz %*% comp$Sigma %*% Kzxp

    }
  }

  # return both
  list(mu = mu,
       var = pred_var)
}
# goldingn/gpe documentation built on May 17, 2019, 7:41 a.m.