R/gibbSampler.R

Defines functions gibbSampler

Documented in gibbSampler

#' Gibbs sampler for mixture model.
#'
#' Gibbs sampler for a mixture model in which each site is allocated to one
#' of \code{ncol(p.sites)} components. Each iteration updates, in this order,
#' the site allocations \code{z}, the site compositions \code{p}, the mean
#' parameters and the component variances.
#'
#' @param priors list of prior hyperparameters. \code{theta$m0} and
#'   \code{theta$s0} are the normal prior mean and variance of the mean
#'   parameters (\code{m0} also provides their starting values, so it should
#'   hold one value per mean parameter). \code{sigma2$a0} and \code{sigma2$b0}
#'   are the inverse-gamma hyperparameters of the component variances, one
#'   value per component; a single value is recycled to every component.
#'   \code{z$a} holds the Dirichlet prior parameters, one row per site; it is
#'   required when \code{N.run > 1} and is not part of the default.
#' @param y data frame of observations with columns \code{ID} and \code{obs}.
#' @param p.sites the composition of each site (one row per site, one column
#'   per mixture component).
#' @param season.sites season data for each site.
#' @param basis currently unused.
#' @param ID ID vector of the sites, in the row order of \code{p.sites}.
#' @param N.run Number of iterations of the sampler (the first one is the
#'   starting state).
#' @param mixture must be \code{TRUE}; the non-mixture sampler is not
#'   implemented and \code{FALSE} raises an error.
#' @param meanFunc function returning the fitted means, called as
#'   \code{meanFunc(theta, n.mixt = n.components, z, season.sites)};
#'   required when \code{N.run > 1}.
#' @param debug if \code{TRUE}, enter \code{browser()} once per iteration.
#' @param print.res if \code{TRUE}, print the iteration number every 100
#'   iterations.
#' @param ... currently unused.
#' @return A list with elements \code{theta} (an \code{N.run} x n.param
#'   matrix of draws: mean parameters first, component variances last),
#'   \code{z} and \code{p} (sites x components x \code{N.run} arrays of
#'   allocations and compositions) and \code{priors} (as supplied).
#' @keywords Gibbs sampler
#' @export
#' @examples
#' data("wq_analysis_week2")
#' SPTMData(wq.raw.obs, frequency = "quarter")
gibbSampler <- function(priors = list(theta = list(m0 = 0, s0 = 1), sigma2 = list(a0 = 2, b0 = 1)),
                        y, p.sites, season.sites, basis, ID,
                        N.run = 10000,
                        mixture = TRUE,
                        meanFunc = NULL, debug = FALSE, print.res = FALSE, ...) {

  # Only the mixture sampler exists. Previously mixture = FALSE fell through
  # to `return(RET)` with RET never assigned (or silently returned a global).
  if (!mixture) {
    stop("gibbSampler: only mixture = TRUE is implemented.", call. = FALSE)
  }

  n.sites <- nrow(p.sites)
  n.lu <- ncol(p.sites)
  n.time <- nrow(season.sites)
  n.season <- if (length(dim(season.sites)) == 2) 1 else dim(season.sites)[2]

  # The sampler indexes a0[j] / b0[j] and keeps one variance per component,
  # so recycle single (e.g. default) hyperparameters to every component.
  sigma2.a0 <- priors$sigma2$a0
  sigma2.b0 <- priors$sigma2$b0
  if (length(sigma2.a0) == 1) sigma2.a0 <- rep(sigma2.a0, n.lu)
  if (length(sigma2.b0) == 1) sigma2.b0 <- rep(sigma2.b0, n.lu)

  # Starting values: m0 for the means, b0 / (a0 - 1) for the variances.
  theta0 <- c(priors$theta$m0, sigma2.b0 / (sigma2.a0 - 1))
  n.param <- length(theta0)
  # The component variances occupy the last n.lu columns of theta.
  idx.sigma <- (n.param - n.lu + 1):n.param
  n.mean <- n.season * n.lu

  theta.hist <- matrix(0, nrow = N.run, ncol = n.param)
  p.hist <- array(0, dim = c(n.sites, n.lu, N.run))
  z.hist <- array(0, dim = c(n.sites, n.lu, N.run))

  # Initial state: each site is allocated to its dominant component.
  theta.hist[1, ] <- theta0
  p.hist[, , 1] <- as.matrix(p.sites)
  idx.max <- apply(p.sites, 1, which.max)
  z.hist[cbind(seq_len(n.sites), idx.max, 1)] <- 1

  # y$obs laid out as sites x time x component (one copy per component).
  # NOTE(review): presumably assumes y is ordered with sites varying fastest
  # within each time point -- confirm against the data preparation code.
  Ym <- array(rep(y$obs, n.lu), dim = c(n.sites, n.time, n.lu))

  # seq_len(N.run)[-1] is empty when N.run == 1 (2:N.run would count down).
  for (t in seq_len(N.run)[-1]) {

    sigma <- theta.hist[t - 1, idx.sigma]

    ## Update z: draw each site's component from its posterior weights.
    m <- meanFunc(theta.hist[t - 1, ], n.mixt = n.lu, z.hist[, , t - 1], season.sites)
    Yp <- array(rep(m, n.time), dim = c(n.time, n.lu, n.sites))

    # NOTE(review): `sigma` holds variance draws (see the variance update
    # below) but is passed to dnorm() as a standard deviation, and the summed
    # log-density is multiplied by sqrt(2*pi*sigma^2) instead of being shifted
    # additively. Both look unintended -- confirm before trusting z draws.
    log.lik <- vapply(seq_len(n.lu), function(j) {
      rowSums(dnorm(Ym[, , j], t(Yp[, j, ]), sigma[j], log = TRUE), na.rm = TRUE) *
        sqrt(2 * pi * sigma[j]^2)
    }, numeric(n.sites))

    log.post <- log.lik + log(p.hist[, , t - 1])
    # Subtract each row's maximum before exponentiating to avoid underflow.
    log.post <- t(apply(log.post, 1, function(x) x - max(x, na.rm = TRUE)))
    post <- t(apply(log.post, 1, function(x) exp(x) / sum(exp(x))))

    z.new <- apply(post, 1, function(x) sample(seq_len(n.lu), 1, prob = x))
    z.hist[cbind(seq_len(n.sites), z.new, t)] <- 1

    ## Update p: one Dirichlet draw per site.
    # NOTE(review): the allocation is divided by the number of components
    # before being added to the prior (not the usual a + z) -- confirm.
    for (i in seq_len(n.sites)) {
      alpha <- priors$z$a[i, ] + z.hist[i, , t] / n.lu
      p.hist[i, , t] <- rdirichlet(c(as.matrix(alpha)))
    }

    ## Update the means: least-squares estimate per component, ...
    season.lu <- updateSeason(season.sites, p.sites = z.hist[, , t])

    for (j in seq_len(n.lu)) {
      # Columns of theta holding component j's mean parameters.
      idx.Xlu <- j + (0:((n.param - n.lu) / n.lu - 1)) * n.lu
      idx.jlu <- (0:(n.season - 1)) * n.lu + j
      Xlu <- season.lu[, idx.jlu]
      index.j <- ID[which(z.hist[, j, t] == 1)]

      if (length(index.j) > 0) {
        y.p <- y[y$ID %in% index.j, ]
        # Repeat each time row of the design once per allocated site.
        Xs <- kronecker(Xlu, matrix(1, ncol = 1, nrow = length(index.j)))
        keep <- which(!is.na(y.p$obs))
        # drop = FALSE: a single remaining row must stay a 1-row matrix.
        Xs <- Xs[keep, , drop = FALSE]
        theta.hist[t, idx.Xlu] <- ginv(t(Xs) %*% Xs) %*% t(Xs) %*% y.p$obs[keep]
      } else {
        # Empty component: carry the previous values forward.
        theta.hist[t, idx.Xlu] <- theta.hist[t - 1, idx.Xlu]
      }
    }

    if (debug) browser()

    # ... then a normal draw combining those estimates with the prior.
    n.j <- colSums(z.hist[, , t])
    sigma.prev <- theta.hist[t - 1, idx.sigma]
    var.post <- (1 / priors$theta$s0 + rep(n.j, n.season) / sigma.prev)^(-1)
    mean.post <- var.post * (priors$theta$m0 / priors$theta$s0 +
                               n.j * theta.hist[t, seq_len(n.mean)] / sigma.prev)
    theta.hist[t, seq_len(n.mean)] <- rnorm(n.mean, mean.post, sd = sqrt(var.post))

    ## Update the component variances (inverse-gamma draws).
    m <- meanFunc(theta.hist[t, ], n.mixt = n.lu, z.hist[, , t], season.sites)
    # Fitted value of every site under its allocated component, flattened
    # time-major (all sites for time 1, then time 2, ...).
    fitted <- as.vector(t(m[, z.new, drop = FALSE]))

    for (j in seq_len(n.lu)) {
      idx.obs <- which(y$ID %in% ID[which(z.hist[, j, t] == 1)])
      mean.se <- 0
      if (length(idx.obs) > 0) {
        mean.se <- mean((fitted[idx.obs] - y$obs[idx.obs])^2, na.rm = TRUE)
      }
      # NOTE(review): this counts sites, not observations, allocated to j.
      n.obs <- sum(z.hist[, j, t] == 1)

      a <- sigma2.a0[j] + n.obs / 2
      b <- n.obs * mean.se / 2 + sigma2.b0[j]
      theta.hist[t, idx.sigma[j]] <- rigamma(1, a, b)
    }

    if (print.res && t %% 100 == 0) print(t)
  }

  list(theta = theta.hist, z = z.hist, p = p.hist, priors = priors)
}
ick003/SpTMixture documentation built on May 18, 2019, 2:32 a.m.