R/yhat_draw_bads.R

Defines functions yhat.draw.bads

Documented in yhat.draw.bads

#' @title Draw of y when accept tree movement for bads version
#' @description Draws one value per terminal node of \code{btree_obj} from the
#'   conjugate normal posterior of that node's mean. The draws are used as
#'   fitted values for the training observations (\code{yhat}) and as
#'   predictions for the rows of \code{x.test} (\code{ypred}).
#' @param btree_obj Tree object. Its \code{t_data} element is a list with one
#'   entry per terminal node, used to index \code{Rj} (the training
#'   observations in that node). Its \code{RotMat} element is used when
#'   \code{rotate} is \code{"rr"} or \code{"rraug"}. The object is also passed
#'   to \code{find_terminal_idx()}.
#' @param x.test Matrix of test covariates, one row per test point.
#' @param Rj Numeric response for the training observations, indexed by the
#'   entries of \code{t_data} (presumably the partial residuals for this tree;
#'   confirm against callers).
#' @param tau Prior standard deviation of each leaf mean (the prior mean is 0).
#' @param sigma2 Observation noise variance.
#' @param rotate \code{"rr"} replaces \code{x.test} by
#'   \code{x.test \%*\% RotMat}. \code{"rraug"} appends the rotated columns to
#'   \code{x.test}. Any other value leaves \code{x.test} unchanged.
#' @return A list with elements:
#'   \code{yhat}, the draw for each training observation's node;
#'   \code{ypred}, the draw for each test row's node (\code{NA} if the row's
#'   node index does not match any terminal node);
#'   \code{t_idx}, the terminal-node index of each test row.
#' @export
yhat.draw.bads <- function(btree_obj, x.test, Rj, tau, sigma2, rotate) {
  t_data <- btree_obj$t_data

  # Conjugate normal update per node: prior mu ~ N(0, tau^2), Rj[i] ~ N(mu,
  # sigma2). Writing the mean as sum(r) / sigma2 * pvar keeps an empty node at
  # the prior (mean 0) instead of 0 * mean(numeric(0)) = NaN.
  leaf_mu <- vapply(t_data, function(idx) {
    r <- Rj[idx]
    pvar <- 1 / (length(r) / sigma2 + 1 / tau^2)
    pmean <- sum(r) / sigma2 * pvar
    rnorm(1, pmean, sqrt(pvar))
  }, numeric(1), USE.NAMES = FALSE)

  # Test covariates must be in the same feature space the tree was grown on.
  if (rotate == "rr") {
    x.test <- x.test %*% btree_obj$RotMat
  } else if (rotate == "rraug") {
    x.test <- cbind(x.test, x.test %*% btree_obj$RotMat)
  }

  t_idx <- apply(x.test, 1, function(x) find_terminal_idx(x, btree_obj))

  # Every training observation in node k gets that node's draw.
  yhat <- rep(NA_real_, length(Rj))
  for (k in seq_along(t_data)) {
    yhat[t_data[[k]]] <- leaf_mu[k]
  }

  # One prediction per test row; node ids outside 1..K give NA rather than
  # dropping or shifting the remaining predictions.
  ypred <- leaf_mu[match(t_idx, seq_along(leaf_mu))]

  list(yhat = yhat, ypred = ypred, t_idx = t_idx)
}
DongyueXie/bCART documentation built on Feb. 4, 2020, 12:26 a.m.