R/itr_run_bartmachine.R

Defines functions: test_bart, train_bart, run_bartmachine

#'
run_bartmachine <- function(
  dat_train, dat_test, dat_total, params, indcv, iter, budget
) {
  # Fit a BART model on the training data and score it in one step.
  #
  # Arguments:
  #   dat_train  data handed to train_bart() to fit the model.
  #   dat_test   data forwarded to test_bart() as `dat_test`.
  #   dat_total  data forwarded to test_bart() as `dat_total`.
  #   params     list-like object; only `n_df` and `n_tb` are read from it.
  #   indcv      forwarded unchanged to test_bart().
  #   iter       forwarded unchanged to test_bart().
  #   budget     forwarded unchanged to test_bart().
  #
  # Returns:
  #   A list with two elements: `test` (the result of test_bart()) and
  #   `train` (the fitted object returned by train_bart()).

  # Step 1: estimate the model on the training split.
  model <- train_bart(dat_train)

  # Step 2: use the fitted model to produce the evaluation quantities.
  scored <- test_bart(
    fit_train = model,
    dat_test  = dat_test,
    dat_total = dat_total,
    n_df      = params$n_df,
    n_tb      = params$n_tb,
    indcv     = indcv,
    iter      = iter,
    budget    = budget
  )

  list(test = scored, train = model)
}


#' @import haven
train_bart <- function(dat_train) {
  # Fit a bartMachine model to the training data.
  #
  # Arguments:
  #   dat_train  training data, passed as-is to create_ml_args_bart().
  #
  # Returns:
  #   The fitted object returned by bartMachine::bartMachine().

  # Convert the training data into the pieces the learner needs; the result
  # is read by name: "Y" (outcome) and "X_and_T" (model inputs).
  ml_args <- create_ml_args_bart(dat_train)

  # An outcome with exactly two distinct values is recoded as a factor with
  # level order (1, 0).
  # NOTE(review): presumably this switches bartMachine to classification with
  # 1 as the reference class -- confirm against the bartMachine docs.
  y <- ml_args[["Y"]]
  has_two_values <- length(unique(y)) == 2
  if (has_two_values) {
    y <- factor(y, levels = c(1, 0))
  }

  # Fit with a fixed ensemble of 30 trees; serialize = TRUE is requested so
  # the fitted object carries its serialized form.
  bartMachine::bartMachine(
    X = ml_args[["X_and_T"]],
    y = y,
    num_trees = 30,
    run_in_sample = TRUE,
    serialize = TRUE
  )
}

#'@importFrom stats predict runif
test_bart <- function(
  fit_train, dat_test, dat_total, n_df, n_tb, indcv, iter, budget
) {
  # Score the full sample with a fitted model and derive treatment rules.
  #
  # Arguments:
  #   fit_train  fitted model; must work with predict(fit_train, newdata).
  #   dat_test   kept for interface compatibility; not used in the body.
  #   dat_total  full data, formatted via create_ml_args_bart().
  #   n_df       number of jitter draws added to the effect estimates.
  #              NOTE(review): assumed to equal the number of rows scored;
  #              a mismatch would trigger vector recycling -- confirm callers.
  #   n_tb       kept for interface compatibility; not used in the body.
  #   indcv      vector of fold labels aligned with the scored rows.
  #   iter       label of the fold treated as the test fold.
  #   budget     share of the test fold used to place the treatment cutoff.
  #
  # Returns:
  #   A list with:
  #     tau       test-fold effects, padded with NA to the full-sample length
  #     tau_cv    effects for every scored row
  #     That_cv   1 where the effect is positive, else 0
  #     That_pcv  1 where the effect reaches the budget cutoff, else 0

  ## format data (only the full-sample design matrices are needed)
  total_args <- create_ml_args_bart(dat_total)

  ## predict under the two design matrices "X0t" and "X1t"
  # NOTE(review): presumably these are the covariates with treatment set to
  # 0 and 1 respectively -- confirm in create_ml_args_bart().
  y0_total <- predict(fit_train, total_args[["X0t"]])
  y1_total <- predict(fit_train, total_args[["X1t"]])

  # A tiny uniform jitter is added so that exact ties between estimated
  # effects are broken.
  tau_total <- y1_total - y0_total + runif(n_df, -1e-6, 1e-6)

  ## compute quantities of interest
  tau_test <- tau_total[indcv == iter]
  n_test <- length(tau_test)
  that <- as.numeric(tau_total > 0)

  # The cutoff is the (floor(budget * n_test) + 1)-th largest test-fold
  # effect. With budget == 1 that rank is n_test + 1, which indexes past the
  # end of the vector and yields NA for every unit; clamp the rank to n_test
  # so a full budget treats the whole test fold. max(n_test, 1) keeps the
  # empty-fold behaviour (NA cutoff) unchanged.
  cutoff_rank <- min(floor(budget * n_test) + 1, max(n_test, 1))
  cutoff <- sort(tau_test, decreasing = TRUE)[cutoff_rank]
  that_p <- as.numeric(tau_total >= cutoff)

  ## output
  list(
    tau      = c(tau_test, rep(NA, length(tau_total) - n_test)),
    tau_cv   = tau_total,
    That_cv  = that,
    That_pcv = that_p
  )
}

Try the evalITR package in your browser

Any scripts or data that you put into this service are public.

evalITR documentation built on Aug. 26, 2023, 1:08 a.m.