# R/btrm.R
#
# Defines functions: btrm
#
# Documented in: btrm

btrm=function(y, x, z,ynew=NULL, xnew=NULL,znew=NULL,
             sparse=TRUE, nwarm=1000, niter=1000,
             minsample=20, base = 0.95, power = 0.8){

  ###
  #1. setup
  ###
  #1.1. outcome type
  type="c"
  if(is.factor(y))
    type="b" #binary

  if(type=="b"){
    compLTP=match.fun(compLTP_binary)
    Grow=match.fun(Grow_binary)
    Change=match.fun(Change_binary)
    btrm_predict=match.fun(btrm_binary_predict)
  }else if(type=="c"){
    compLTP=match.fun(compLTP_continuous)
    Grow=match.fun(Grow_continuous)
    Change=match.fun(Change_continuous)
    btrm_predict=match.fun(btrm_continuous_predict)
  }

  posterior.improved="no" #warm-up or posterior improved
  ###################################################
  ## A. Warm Up
  ###################################################
  #A1. Initial tree; uniform priors for x & z; compute LTP (likelihood, tree probability, posterior)
  ET=InitialTree(y,x,z)
  ET$dir.predictor=rep(1/ET$q,ET$q) #priors
  ET$dir.marker=rep(1/ET$p,ET$p)

  ET$sparse=FALSE
  ET=compLTP(ET,base,power)

  #A2. Warm up tree
  m.predictor=NA #selected predictors
  m.marker=NA    #selected biomarkers

  for(warm in 1:nwarm){
    if(warm %% 100==0)
      message(paste0("Number of warm-up: ", warm, " of total ", nwarm))
    ET1=UpdateTree(ET,minsample, Grow,Change) #stochastic search: growth, prune, change, swap, assign
    if(ET1$size.cond){ #n for each terminal node is large enough
      ET1=compLTP(ET1,base,power)   #compute LTP
      ET=MH(ET1,ET,base,power)
      if(ET$MH=="accepted"){
        posterior.improved="yes"
        #Store the marker and predictor information
        m.predictor=append(m.predictor, ET$splitVariable[ET$internal])
        m.marker=append(m.marker, ET$marker[ET$terminal])
      }
    }
  }

  #A3. Dirchlet priors
  m.predictor=m.predictor[!is.na(m.predictor)]
  m.marker=m.marker[!is.na(m.marker)]

  if(sparse==TRUE){
    for(j in 1:ET$q)
      ET$dir.predictor[j]=ET$dir.predictor[j]+sum(m.predictor==j)
    ET$dir.predictor=ET$dir.predictor/sum(ET$dir.predictor)

    for(j in 1:ET$p)
      ET$dir.marker[j]=ET$dir.marker[j]+sum(m.marker==j)
    ET$dir.marker=ET$dir.marker/sum(ET$dir.marker)
  }

  ###################################################
  ## B. Update Tree
  ###################################################
  #B1. initial value for evaluation
  ET2=ET
  #B2. update tree
  for(iter in 1:niter){
    if(iter %% 100==0)
      message(paste0("Number of update: ", iter, " of total ", niter))
    ET1=UpdateTree(ET,minsample, Grow,Change) #stochastic search: growth, prune, change, swap, assign
    if(ET1$size.cond){ #n for each terminal node is large enough
      ET1=compLTP(ET1,base,power)   #compute LTP
      ET=MH(ET1,ET,base,power)      #Metropolis-Hastings
      if(ET$MH=="accepted"){        #choose the one with the highest posterior
        if(ET$logPosterior>ET2$logPosterior){
          posterior.improved="yes"
          ET2=ET
        }
      }
    }
  }
  if(posterior.improved=="no"){
    warning("Increase nwarm and/or niter\n")
  }else{
    ET=ET2 #choose the one with the highest posterior

    #B3. summarize result
    fit_pred=btrm_predict(ET, ynew=y, xnew=x, znew=z)
    ET$y.hat=fit_pred$yhat #fitted (should be same as ET$yhat)
    ET$node.hat=fit_pred$node.hat

    if(type=="b"){
      y_num=as.numeric(y==1)
      ET$bs=mean((y_num-ET$y.hat)^2)
      ET$roc=ET$auc=NA
      try_roc=try(ET$roc<-pROC::roc(y~ET$y.hat,direction="<",levels=c(0,1)),silent=TRUE)
      if(!inherits(try_roc,'try-error'))
        ET$auc=ET$roc$auc
    }else if(type=="c"){
      ET$mse=mean((y-ET$y.hat)^2)
    }

    if(is.null(ynew)==FALSE & is.null(xnew)==FALSE & is.null(znew)==FALSE){
      fit_new=btrm_predict(ET, ynew=ynew, xnew=xnew, znew=znew)
      ET$y.hat.new=fit_new$yhat
      ET$node.hat.new=fit_new$node.hat
      ET$marker.hat.new=fit_new$marker.hat

      if(type=="b"){
        y_num_new=as.numeric(ynew==1)
        ET$bs.new=mean((y_num_new-ET$y.hat.new)^2)
        ET$roc.new=ET$auc.new=NA
        try_roc=try(ET$roc.new<-pROC::roc(ynew~ET$y.hat.new,direction="<",levels=c(0,1)),silent=TRUE)
        if(!inherits(try_roc,'try-error'))
          ET$auc.new=ET$roc.new$auc
      }else if(type=="c"){
        ET$mse.new=mean((ynew-ET$y.hat.new)^2)
      }
    }

    ###
    #C. further summary
    ###
    #C1. remove unnecessary info during MH search
    if(sum(ET$terminal!=1)){
      #terminal
      nt=max(ET$terminal,na.rm=TRUE)
      ET$marker=ET$marker[1:nt]
      ntt=1:nt
      for(s in setdiff(ntt,ET$terminal)){
        ET$marker[s]=NA
        if(s<=nt){
          ET$bhat[[s]]=NA
        }else{
          ET$bhat[[s]]=NULL
        }
      }

      #internal
      ni=max(ET$internal,na.rm=TRUE)
      ET$splitVariable=ET$splitVariable[1:ni]
      ET$cutoff=ET$cutoff[1:ni]
    }

    #C2. rename
    names(ET)[names(ET)=="bhat"]="beta.hat"

    #C3. print
    btrm_print(ET)

    #C4. remove some variables
    ET$n=ET$q=ET$p=NULL
    ET$y=ET$x=ET$z=NULL
    ET$eta=NULL
    ET$numNodes=NULL

    ET$sparse=NULL
    ET$dir.predictor=ET$dir.marker=NULL
    ET$loglik=ET$logTreeProb=ET$logPosterior=NULL
    ET$MH=ET$method=ET$size.cond=ET$STN=ET$SIN=NULL

    ET$yhat=NULL #should be same as ET$y.hat

    return(ET)
  }
}

# Try the btrm package in your browser
#
# Any scripts or data that you put into this service are public.
#
# btrm documentation built on June 8, 2025, 12:45 p.m.