# Show the source of every chunk in the knitted output.
knitr::opts_chunk$set(echo = TRUE)

# Example data-generating process ----
# W1, W2: independent draws from runif(-1, 1).
# A:      one Bernoulli draw with success probability plogis(0.7 * (W1 + W2)).
# mu:     A * 0.5 + 0.5 * plogis(W1 + W2), used as the success probability
#         of the binary outcome Y.
# NOTE(review): DAG.empty/node/set.DAG/sim are not defined or attached in the
# visible code; presumably they come from the simcausal package -- confirm.
D <- DAG.empty() +
  node("W1", distr = "runif", min = -1, max = 1) +
  node("W2", distr = "runif", min = -1, max = 1) +
  node("A", distr = "rbinom", size = 1, prob = plogis(0.7 * (W1 + W2))) +
  node("mu", distr = "rconst", const = A * 0.5 + 0.5 * plogis(W1 + W2)) +
  node("Y", distr = "rbinom", size = 1, prob = mu)

# Finalise the DAG and draw n observations from it.
n <- 1500
setD <- set.DAG(D)
data <- sim(setD, n = n)

# Binomial family object; mu.eta() is the derivative of its inverse link,
# evaluated here at a linear predictor of 2.
family <- binomial()
family$mu.eta(2)
# Coverage simulation for spRR ----
# Every replicate simulates a fresh data set in which the outcome probability
# under A = 1 is RR = 2 times the probability under A = 0, fits spRR with an
# intercept-only log-RR formula, and records whether columns 5 and 6 of the
# result bracket log(2).
# NOTE(review): columns 5 and 6 are presumably the lower and upper confidence
# limits returned by spRR -- confirm against its documentation.
passes <- c()
for (i in seq_len(100)) {
  print(i)

  # Data-generating process with a constant relative risk.
  D <- DAG.empty() +
    node("W1", distr = "runif", min = -1, max = 1) +
    node("W2", distr = "runif", min = -1, max = 1) +
    node("A", distr = "rbinom", size = 1, prob = plogis(0.7 * (W1 + W2))) +
    node("mu0", distr = "rconst", const = 0.3 * plogis(1 + W1 + W2)) +
    node("RR", distr = "rconst", const = 2) +
    node("Y", distr = "rbinom", size = 1, prob = (1 - A) * mu0 + A * mu0 * RR)

  setD <- set.DAG(D)
  n <- 1500
  data <- sim(setD, n = n)

  # Pull out treatment, outcome and the covariate matrix.
  A <- data$A
  Y <- data$Y
  W <- as.matrix(data[, c("W1", "W2")])

  out <- spRR(
    formula_logRR = ~ 1,
    W,
    A,
    Y,
    family_RR = gaussian(),
    sl3_Lrnr_A = Lrnr_gam$new(),
    sl3_Lrnr_Y = Lrnr_gam$new(),
    smoothness_order = 1,
    max_degree = 2,
    num_knots = c(5, 3),
    fit_control = list()
  )
  print(out)

  # One new column per replicate; the row means are the running coverage.
  passes <- cbind(passes, out[, 5] <= log(2) & out[, 6] >= log(2))
  print(rowMeans(passes))
}
# Coverage simulation for spCATE ----
# The outcome mean is (1 + W1) * A * 0.25 + (W1 + W2), so the treatment
# contrast is 0.25 + 0.25 * W1: both coefficients of the `~ W1 + 1` CATE
# formula equal 0.25, which is the value the coverage check compares against.
# NOTE(review): columns 5 and 6 of the result are presumably the lower and
# upper confidence limits -- confirm against the spCATE documentation.
passes <- c()
for (i in seq_len(100)) {
  print(i)

  # Data-generating process with a Gaussian outcome (sd = 0.2).
  D <- DAG.empty() +
    node("W1", distr = "runif", min = -1, max = 1) +
    node("W2", distr = "runif", min = -1, max = 1) +
    node("A", distr = "rbinom", size = 1, prob = plogis(0.7 * (W1 + W2))) +
    node("mu", distr = "rconst", const = (1 + W1) * A * 0.25 + (W1 + W2)) +
    node("Y", distr = "rnorm", mean = mu, sd = 0.2)

  setD <- set.DAG(D)
  n <- 2500
  data <- sim(setD, n = n)

  # Pull out treatment, outcome and the covariate matrix.
  A <- data$A
  Y <- data$Y
  W <- as.matrix(data[, c("W1", "W2")])

  out <- spCATE(
    formula_CATE = ~ W1 + 1,
    W,
    A,
    Y,
    family_CATE = gaussian(),
    sl3_Lrnr_A = Lrnr_glm$new(),
    sl3_Lrnr_Y_optional = NULL,
    sl3_Lrnr_sigma = Lrnr_glm$new(),
    smoothness_order_Y0W = 1,
    max_degree_Y0W = 1,
    num_knots_Y0W = 10,
    fit_control = list()
  )

  # One new column per replicate; the row means are the running coverage.
  passes <- cbind(passes, out[, 5] <= 0.25 & out[, 6] >= 0.25)
  print(rowMeans(passes))
}
# Settings for the step-by-step CATE fit ----
# Working model for the CATE: intercept only, Gaussian family.
formula_CATE <- ~ 1
family_CATE <- gaussian()

# Tuning values that are handed to fit_hal() further down when no separate
# outcome learner is supplied.
smoothness_order_Y0W <- 1
max_degree_Y0W <- 1
num_knots_Y0W <- 10
fit_control <- list()

# Learners for the treatment mechanism and for the second-moment regression.
sl3_Lrnr_A <- Lrnr_glm$new()
sl3_Lrnr_sigma <- Lrnr_glm$new()

# Design matrix of the CATE working model.
# NOTE(review): W and A are not assigned in this section; they are whatever
# the last iteration of the preceding loop left in the workspace.
V <- model.matrix(formula_CATE, data = as.data.frame(W))

# Estimate g: regress A on the columns of W, then predict on the same task.
data_A <- data.frame(W, A = A)
task_A <- sl3_Task$new(data_A, covariates = colnames(W), outcome = "A")
sl3_Lrnr_A <- sl3_Lrnr_A$train(task_A)
g1 <- sl3_Lrnr_A$predict(task_A)
g0 <- 1 - g1


# Outcome regression and projection onto the CATE working model ----
# A non-NULL `sl3_Lrnr_Y_optional` selects the separate-learner branch below;
# with NULL the partially linear fit via fit_hal() is used instead.
sl3_Lrnr_Y_optional <- Lrnr_gam$new()
fit_separate <- !is.null(sl3_Lrnr_Y_optional)

# Estimate part lin Q
if (!fit_separate) {
  # fit_hal() on W with A * V passed as the unpenalized design; predictions
  # are taken at the observed A, at A = 0 and at A = 1.
  fit_Y <- fit_hal(
    X = as.matrix(W),
    X_unpenalized = as.matrix(A * V),
    Y = as.vector(Y),
    family = family_CATE,
    fit_control = fit_control,
    smoothness_orders = smoothness_order_Y0W,
    max_degree = max_degree_Y0W,
    num_knots = num_knots_Y0W
  )
  Q <- predict(fit_Y, new_data = as.matrix(W), new_X_unpenalized = (A * V))
  Q0 <- predict(fit_Y, new_data = as.matrix(W), new_X_unpenalized = (0 * V))
  Q1 <- predict(fit_Y, new_data = as.matrix(W), new_X_unpenalized = (1 * V))
} else {
  # Regress Y on (W, A) with the supplied learner, then predict on copies of
  # the data in which A is set to 1 and to 0.
  data_Y <- data.frame(W, A = A, Y = Y)
  task_Y <- sl3_Task$new(data_Y, covariates = c(colnames(W), "A"), outcome = "Y")
  data_Y1 <- data.frame(W, A = 1, Y = Y)
  task_Y1 <- sl3_Task$new(data_Y1, covariates = c(colnames(W), "A"), outcome = "Y")
  data_Y0 <- data.frame(W, A = 0, Y = Y)
  task_Y0 <- sl3_Task$new(data_Y0, covariates = c(colnames(W), "A"), outcome = "Y")

  sl3_Lrnr_Y_optional <- sl3_Lrnr_Y_optional$train(task_Y)
  Q <- sl3_Lrnr_Y_optional$predict(task_Y)
  Q1 <- sl3_Lrnr_Y_optional$predict(task_Y1)
  Q0 <- sl3_Lrnr_Y_optional$predict(task_Y0)
}
print(quantile(Q1 - Q0))

# Project the fitted contrast Q1 - Q0 onto the working model spanned by V.
# `FALSE` is spelled out because `F` is an ordinary variable that can be
# reassigned.
beta <- coef(glm.fit(V, Q1 - Q0, family = family_CATE, intercept = FALSE))
link <- V %*% beta
CATE <- family_CATE$linkinv(link)

# Rebuild Q and Q1 so that Q1 - Q0 equals the projected CATE exactly.
Q <- A * CATE + Q0
Q1 <- CATE + Q0

print(quantile(beta))


# Estimate var
# With `binary` set, the conditional variance is taken as Q * (1 - Q) from
# the current outcome fits; otherwise a learner regresses Y^2 on (W, A) and
# the variance is computed as the fitted second moment minus the squared mean.
# `TRUE` is spelled out because `T` is an ordinary variable that can be
# reassigned.
# NOTE(review): the flag is hard-coded here; confirm it matches the outcome
# type of the data currently in the workspace.
binary <- TRUE
if (binary) {
  sigma2 <- Q * (1 - Q)
  sigma21 <- Q1 * (1 - Q1)
  sigma20 <- Q0 * (1 - Q0)
} else {
  # The column named "Y" in these tasks holds Y^2, not Y.
  data_sigma <- data.frame(W, A = A, Y = Y^2)
  task_sigma <- sl3_Task$new(data_sigma, covariates = c(colnames(W), "A"), outcome = "Y")
  data_sigma1 <- data.frame(W, A = 1, Y = Y^2)
  task_sigma1 <- sl3_Task$new(data_sigma1, covariates = c(colnames(W), "A"), outcome = "Y")
  data_sigma0 <- data.frame(W, A = 0, Y = Y^2)
  task_sigma0 <- sl3_Task$new(data_sigma0, covariates = c(colnames(W), "A"), outcome = "Y")

  sl3_Lrnr_sigma <- sl3_Lrnr_sigma$train(task_sigma)
  EY2 <- sl3_Lrnr_sigma$predict(task_sigma)
  EY2_1 <- sl3_Lrnr_sigma$predict(task_sigma1)
  EY2_0 <- sl3_Lrnr_sigma$predict(task_sigma0)
  sigma2 <- EY2 - Q^2
  sigma20 <- EY2_0 - Q0^2
  sigma21 <- EY2_1 - Q1^2
}



# Iterative update of Q0 and beta ----
# At most 100 passes. Each pass recomputes the variance terms, builds the
# per-observation terms `EIF`, and stops once every column mean of `EIF` is
# at most 1 / n in absolute value. Otherwise it takes a small step (size
# chosen in [0, 0.01] by a one-dimensional search) along the normalised mean
# score, and refits beta on the updated CATE.
# NOTE(review): `EIF` is presumably the efficient influence function of beta
# (going by its name) -- confirm.
for (i in 1:100) {
  # Refresh the variance terms from the current Q, Q1 and Q0.
  if (binary) {
    sigma2 <- Q * (1 - Q)
    sigma21 <- Q1 * (1 - Q1)
    sigma20 <- Q0 * (1 - Q0)
  } else {
    sigma2 <- EY2 - Q^2
    sigma20 <- EY2_0 - Q0^2
    sigma21 <- EY2_1 - Q1^2
  }

  # Derivative of the inverse link at the current linear predictor, times V.
  gradM <- family_CATE$mu.eta(V %*% beta) * V
  num <- gradM * (g1 / sigma21)
  denom <- (g0 / sigma20 + g1 / sigma21)
  hstar <- -num / denom
  H <- (A * gradM + hstar) / sigma2
  EIF <- as.matrix(H * (Y - Q))

  # Square normalising matrix, one column per column of gradM; EIF is
  # multiplied by its inverse.
  # NOTE(review): colMeans_safe is defined outside this file; presumably a
  # column-mean helper -- confirm.
  scale <- as.matrix(apply(gradM, 2, function(v) {
    colMeans_safe((A * gradM + hstar) * A * gradM * v / sigma2)
  }))
  scaleinv <- solve(scale)
  EIF <- EIF %*% scaleinv

  # Mean score per coefficient and its normalised direction.
  scores <- colMeans(EIF)
  direction_beta <- scores / sqrt(mean(scores^2))
  print(scores)
  if (max(abs(scores)) <= 1 / n) {
    break
  }

  # Variance-weighted squared-error risk along the update direction.
  linpred <- family_CATE$linkfun(Q1 - Q0)
  risk_function <- function(eps) {
    loss <- (Y - family_CATE$linkinv(A * linpred + eps * A * V %*% direction_beta) -
      Q0 - eps * hstar %*% direction_beta)^2 / sigma2
    mean(loss)
  }

  optim_fit <- optim(
    par = list(epsilon = 0.01), fn = risk_function,
    lower = 0, upper = 0.01,
    method = "Brent"
  )

  # Apply the step to Q0 and to the CATE, then refit beta on the new CATE.
  eps <- direction_beta * optim_fit$par
  Q0 <- Q0 + hstar %*% eps
  CATE <- family_CATE$linkinv(linpred + V %*% eps)
  # `FALSE` is spelled out because `F` is an ordinary variable that can be
  # reassigned.
  beta <- coef(glm.fit(V, CATE, family = family_CATE, intercept = FALSE))
  link <- V %*% beta
  CATE <- family_CATE$linkinv(link)
  Q <- A * CATE + Q0
  Q1 <- CATE + Q0
}

# Confidence interval and coverage check ----
# Interval of the form beta -/+ 1.96 * sd(EIF) / sqrt(n), checked against the
# value 0.5. sd() treats EIF as a single vector, which fits the one-column
# EIF produced by the intercept-only formula_CATE used in this script.
se <- sd(EIF)
print(beta)
ci <- c(beta - 1.96 * se / sqrt(n), beta + 1.96 * se / sqrt(n))
print(ci)

# NOTE(review): `passes` is not reset in this section, so the new indicator
# is appended to whatever the earlier simulation left in it. The code reads
# as if it once sat inside a replicate loop whose opening line is no longer
# in the file; the unmatched closing brace that followed this statement has
# been dropped so the script parses. Confirm the intended loop and reset.
passes <- c(passes, ci[1] <= 0.5 & ci[2] >= 0.5)
print(mean(passes))


# Larsvanderlaan/npOddsRatio documentation built on May 3, 2022, 12:05 p.m.