R/07b.btrm_continuous_predict.R

Defines functions btrm_continuous_predict

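btrm_continuous_predict=function(ET,ynew,xnew,znew){
  # ynew is accepted but not used in this function.

  #1 Assign each new subject to a terminal node and record the marker
  #  selected in that node. Nodes are heap-numbered: node i has
  #  children 2*i (left) and 2*i+1 (right).
  n=nrow(xnew)               # number of subjects
  node.hat=rep(1,n)          # current node of each subject (all start at the root)
  marker.hat=rep(NA,n)       # selected marker for each subject in the test data
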
  if(ET$numNodes==1){         # the tree has a single terminal node
    marker.hat=rep(ET$marker[1],n)
    node.hat=rep(1,n)
  }else{
    # Visit internal nodes in increasing order so that each parent is split
    # before its children (a child's number is always larger than its parent's).
    for(i in sort(ET$internal)){
      idx=(node.hat==i)       # subjects currently in node i

      # Split node i: splitVariable <= cutoff goes left, > cutoff goes right
      left=which(idx & xnew[,ET$splitVariable[i]]<=ET$cutoff[i])
      right=which(idx & xnew[,ET$splitVariable[i]]>ET$cutoff[i])
      node.hat[left]=2*i
      node.hat[right]=2*i+1
      marker.hat[left]=ET$marker[2*i]                # marker selected in the left child
      marker.hat[right]=ET$marker[2*i+1]             # marker selected in the right child
    }
  }

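  #2 Linear predictor: for each subject, combine an intercept term and the
  #  value of its selected marker with the coefficients ET$bhat of its node
  yhat=rep(NA, n) # linear predictor
  for(i in seq_len(n)){                       # seq_len() avoids 1:0 when n == 0
    znew.i=c(1,znew[i,marker.hat[i]])         # (1, value of the selected marker)
    yhat[i]=sum(ET$bhat[[node.hat[i]]]*znew.i)
  }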

  return(list(node.hat=node.hat,marker.hat=marker.hat,yhat=yhat))
}
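The two steps above (route each subject down the tree, then apply its node's linear model) can be illustrated on a toy tree. This is a minimal sketch, not part of btrm: the helper names `route_to_leaf()` and `predict_linear()` and every tree value are hypothetical. It assumes, as the function above implies, that `splitVariable`, `cutoff`, `marker` and `bhat` are indexed by heap node number and that each `bhat` entry holds an intercept followed by one slope. The toy tree splits the root on `x1` at 0.5, keeps node 2 as a leaf, and splits node 3 on `x2` at 0 into leaves 6 and 7.

```r
# Hypothetical toy tree; all vectors/lists are indexed by heap node number.
internal  = c(1, 3)                            # internal (split) nodes
split_var = c("x1", NA, "x2")                  # split variable of node i
cutoff    = c(0.5, NA, 0)                      # cutoff of node i
marker    = c("z1", "z1", "z2", NA, NA, "z1", "z2")  # marker chosen in node i
bhat      = list(NULL, c(1, 2), NULL, NULL, NULL, c(0, -1), c(3, 0.5))  # (intercept, slope)

# Step 1 (sketch): send each row of x down the tree; <= cutoff goes to 2*i,
# > cutoff goes to 2*i + 1. Sorting visits every parent before its children.
route_to_leaf = function(x, internal, split_var, cutoff) {
  node = rep(1, nrow(x))                       # everyone starts at the root
  for (i in sort(internal)) {
    in_i  = node == i                          # rows currently in node i
    xi    = x[, split_var[i]]
    left  = which(in_i & xi <= cutoff[i])
    right = which(in_i & xi >  cutoff[i])
    node[left]  = 2 * i
    node[right] = 2 * i + 1
  }
  node
}

# Step 2 (sketch): yhat = intercept + slope * (value of the leaf's marker)
predict_linear = function(node, marker, z, bhat) {
  sapply(seq_along(node), function(k) {
    b = bhat[[node[k]]]
    b[1] + b[2] * z[k, marker[node[k]]]
  })
}

x = data.frame(x1 = c(0.2, 0.9, 0.7), x2 = c(5, -1, 2))
z = data.frame(z1 = c(10, 20, 30), z2 = c(1, 2, 4))
node = route_to_leaf(x, internal, split_var, cutoff)
yhat = predict_linear(node, marker, z, bhat)
print(node)   # 2 6 7
print(yhat)   # 21 -20 5
```

For example, the second subject has `x1 = 0.9 > 0.5` (go to node 3) and `x2 = -1 <= 0` (go to node 6), so its prediction is `0 + (-1) * 20 = -20`.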

btrm documentation built on June 8, 2025, 12:45 p.m.