R/MC_gWAPL.R

Defines functions MC_gSAM

Documented in MC_gSAM

#' Monte Carlo comparison of WAPL against competing treatment-rule learners
#'
#' Performs a Monte Carlo simulation for a given scenario and compares six
#' models on a common test set of 10000 observations: group WAPL, WAPL,
#' group Q-learning, Q-learning, OWL (`DTRlearn2::owl`) and RWL
#' (`DTRlearn::Olearning_Single`). For every replicate a fresh training set is
#' simulated with `gSim()`, all models are fitted, and variable-selection
#' sensitivity/specificity, proportion of correct decisions (PCD), and the mean
#' and median potential outcome under each estimated rule are recorded.
#'
#' Side effects: the simulated test set is written to
#' `MCgSAM_<N>S<s>test.RDS`, and after every replicate the running results are
#' written to `MCgSAM<N>S<s>RunningTmp.RDS`, both in the working directory.
#'
#' @param n.rep Number of Monte Carlo replicates (at least 1).
#' @param N Sample size of each simulated training set.
#' @param sgm Noise level, passed to `gSim()` as `sigma`.
#' @param sim.seed Seed passed to `set.seed()` before any data are simulated.
#' @param group Group membership vector. NOTE: the value supplied by the caller
#'   is currently ignored; the grouping is always derived from `s`
#'   (50 groups of 2 for scenarios 1 and 3, otherwise 10 groups of 2).
#' @param s Scenario identifier, passed to `gSim()` as `scenario`.
#' @param PosX.idx Indices of the truly prognostic variables, used to compute
#'   selection sensitivity and specificity.
#' @param m Number of folds for cross validation.
#' @param lambda User supplied lambda sequence passed to `cv_WAPL()`. Normally
#'   left `NULL` so that the fitting routine computes it.
#' @param nlambda Number of lambda values (default 100). NOTE: currently not
#'   forwarded to any fitting routine by this function.
#' @param lambda.min.ratio Ratio between the smallest and the largest lambda,
#'   passed to `cv_WAPL()`.
#' @param splineP Candidate numbers of spline basis functions, passed to
#'   `cv_WAPL()` as `p`.
#' @param sgm_list Candidate `sigma` values for the rbf kernel used by OWL/RWL.
#' @return A list with `results.aggregated` (one row of metrics per replicate),
#'   `mean.results` and `sd.results` (column-wise mean and standard deviation,
#'   ignoring `NA`), the 0/1 selection matrices `VarselgSAM`, `VarselSAM`,
#'   `VarselgQlearn` and `VarselQlearn` (`n.rep` rows), `scenario`, and an
#'   unnamed element holding `unlist(test.data[8:19])`.
#' @import DTRlearn
MC_gSAM <- function(n.rep = 100, N = 400, sgm = 0.1, sim.seed = 0425, group = NULL,
                    s = 1, PosX.idx = c(1, 2, 3), m = 10, lambda = NULL, nlambda = 100,
                    lambda.min.ratio = 0.2,
                    splineP = c(3:10), sgm_list = c(0.01, 0.05, 0.1, 1)) {
  # `1:n.rep` would silently iterate over c(1, 0) for n.rep = 0, so reject it.
  stopifnot(length(n.rep) == 1, n.rep >= 1)

  # The scenario fixes both the number of covariates and their grouping.
  if (s == 1 || s == 3) {
    group <- rep(1:50, each = 2)
    p <- 100
  } else {
    group <- rep(1:10, each = 2)
    p <- 20
  }

  set.seed(sim.seed)
  test.tmp <- paste0("MCgSAM_", N, "S", s, "test.RDS")
  agg.tmp <- paste0("MCgSAM", N, "S", s, "RunningTmp.RDS")

  # One large test set shared by all replicates.
  test.data <- gSim(N = 10000, sigma = sgm, scenario = s)
  saveRDS(test.data, test.tmp)
  X.test <- test.data$X
  optTr.test <- factor(test.data$optA, levels = c(-1, 1))
  # NOTE(review): the training outcome is read from `$Y` while this reads `$y`;
  # confirm against gSim() that a lowercase `y` element exists.
  y.test <- test.data$y
  Tr.test <- test.data$A
  y1.test <- test.data$Y1
  y0.test <- test.data$Y0

  # Sensitivity / specificity of a selection, where a variable counts as
  # selected when its score is non-zero.
  selection_summary <- function(score) {
    picked <- which(score != 0)
    dropped <- which(score == 0)
    sens <- if (length(picked) == 0) {
      0
    } else {
      sum(picked %in% PosX.idx) / length(PosX.idx)
    }
    spec <- if (length(dropped) == 0) {
      0
    } else {
      sum(!(dropped %in% PosX.idx)) / (p - length(PosX.idx))
    }
    list(picked = picked, dropped = dropped, sens = sens, spec = spec)
  }

  # Mean absolute spline coefficient per covariate of a cv_WAPL fit; the last
  # element of `w` is dropped before reshaping into p columns.
  wapl_scores <- function(fit) {
    w <- fit$w
    colMeans(abs(matrix(w[-length(w)], ncol = p)))
  }

  # PCD plus mean/median potential outcome on the test set under a rule.
  evaluate_rule <- function(estA) {
    potential <- c(y1.test[estA == 1], y0.test[estA == -1])
    c(PCD = mean(estA == optTr.test), EY = mean(potential), MedY = median(potential))
  }

  # Will store the estimated values and proportion of correct decisions (PCD).
  results.aggregated <- NULL
  VarselgSAM <- matrix(NA, nrow = n.rep, ncol = p)
  VarselSAM <- matrix(NA, nrow = n.rep, ncol = p)
  VarselQlearn <- matrix(NA, nrow = n.rep, ncol = p)
  VarselgQlearn <- matrix(NA, nrow = n.rep, ncol = p)

  for (rep.number in seq_len(n.rep)) {
    print(rep.number)
    gc()
    train.data <- gSim(N = N, sigma = sgm, scenario = s)
    X <- train.data$X
    Tr <- train.data$A
    y <- train.data$Y
    prop.train <- rep(0.5, N) # known randomisation probability

    # WAPL fit, no group info used.
    # NOTE(review): this ungrouped fit uses pentype = "gglasso" while the
    # grouped fit below uses pentype = "lasso"; confirm against cv_WAPL() that
    # the two are not swapped.
    Model <- cv_WAPL(
      H = X, A = Tr, R2 = y, prop = prop.train, pentype = "gglasso", m = m,
      group = NULL, p = splineP, lambda = lambda, thol = 1e-5, mu = 0.05,
      max.ite = 1e5, lambda.min.ratio = lambda.min.ratio
    )
    sam.sel <- selection_summary(wapl_scores(Model))
    VarselSAM[rep.number, sam.sel$picked] <- 1
    VarselSAM[rep.number, sam.sel$dropped] <- 0
    sam.predA <- predict(Model, newdata = X.test)$labels
    rm(Model)

    # gWAPL fit
    print(group)
    Model <- cv_WAPL(
      H = X, A = Tr, R2 = y, prop = prop.train, pentype = "lasso", m = m,
      group = group, p = splineP, lambda = lambda, thol = 1e-5, mu = 0.05,
      max.ite = 1e5, lambda.min.ratio = lambda.min.ratio
    )
    gsam.sel <- selection_summary(wapl_scores(Model))
    VarselgSAM[rep.number, gsam.sel$picked] <- 1
    VarselgSAM[rep.number, gsam.sel$dropped] <- 0
    gsam.predA <- predict(Model, newdata = X.test)$labels
    rm(Model)

    # Q-learning fit, no group info used: every coefficient is its own group.
    qfit <- Qlearning_gglasso(
      H = as.matrix(X), A = Tr, R = y, pA = prop.train,
      group = seq_len(1 + 2 * length(group)),
      pentype = "lasso", m = m, loss = "ls"
    )
    pred.qlearn <- predict.qlearn(qfit, as.matrix(X.test))
    Q.predA <- pred.qlearn$opt_trt
    q.coef <- qfit$co
    # Rows whose name starts with "A"; the first of them is dropped before the
    # remaining ones are treated as per-covariate contrast coefficients.
    cont_cols <- which(substr(rownames(q.coef), 1, 1) == "A")
    q.sel <- selection_summary(q.coef[cont_cols][-1])
    VarselQlearn[rep.number, q.sel$picked] <- 1
    VarselQlearn[rep.number, q.sel$dropped] <- 0
    rm(qfit, pred.qlearn, q.coef, cont_cols)
    print("Q Check")

    # Group Q-learning fit
    qfit <- Qlearning_gglasso(
      H = as.matrix(X), A = Tr, R = y, pA = prop.train,
      group = c(group, max(group) + 1, group + 1 + max(group)),
      pentype = "lasso", m = m, loss = "ls"
    )
    pred.qlearn <- predict.qlearn(qfit, as.matrix(X.test))
    gQ.predA <- pred.qlearn$opt_trt
    q.coef <- qfit$co
    cont_cols <- which(substr(rownames(q.coef), 1, 1) == "A")
    gq.sel <- selection_summary(q.coef[cont_cols][-1])
    VarselgQlearn[rep.number, gq.sel$picked] <- 1
    VarselgQlearn[rep.number, gq.sel$dropped] <- 0
    rm(qfit, pred.qlearn, q.coef, cont_cols)
    print("GQ Check")

    # RWL: a failed fit is reported and scored as NA instead of aborting.
    rwl.fit <- tryCatch(
      DTRlearn::Olearning_Single(X, Tr, y,
        pi = prop.train, kernel = "rbf", sigma = sgm_list,
        clinear = 2^(-2:2), m = m
      ),
      error = function(e) e
    )
    if (inherits(rwl.fit, "error")) {
      optTr.rwl <- NA
      print(rwl.fit)
    } else {
      optTr.rwl <- predict(rwl.fit, X.test)
      optTr.rwl <- factor(optTr.rwl, levels = c(-1, 1))
      print("rwl check")
    }
    rm(rwl.fit)

    # OWL: same failure handling as RWL.
    owl.fit <- tryCatch(
      DTRlearn2::owl(
        H = X, AA = Tr, RR = y, n = N, K = 1, pi = prop.train, kernel = "rbf",
        loss = "hinge", sigma = sgm_list, augment = FALSE, m = m, res.lasso = TRUE
      ),
      error = function(e) e
    )
    if (inherits(owl.fit, "error")) {
      # Must reset the OWL result (not the RWL one); otherwise a failed fit
      # would reuse the previous replicate's OWL predictions.
      optTr.owl <- NA
      print(owl.fit)
    } else {
      optTr.owl <- predict(owl.fit,
        H = X.test, AA = Tr.test, RR = y.test, K = 1,
        pi = rep(0.5, length(Tr.test))
      )$treatment[[1]]
      optTr.owl <- factor(optTr.owl, levels = c(-1, 1))
      print("owl check")
    }
    rm(owl.fit)

    # Output summary, one entry per method.
    outsum <- lapply(
      list(
        gSAM = gsam.predA, SAM = sam.predA, GQ = gQ.predA,
        Q = Q.predA, OWL = optTr.owl, RWL = optTr.rwl
      ),
      evaluate_rule
    )

    # Column names "MeanValuegOWL" and "AccuracygRWL" are kept as they were so
    # existing consumers of the result matrix keep working.
    results <- c(
      gSamSen = gsam.sel$sens, gSAMSpe = gsam.sel$spec,
      SamSen = sam.sel$sens, SAMSpe = sam.sel$spec,
      gQSen = gq.sel$sens, gQSpe = gq.sel$spec,
      QSen = q.sel$sens, QSpe = q.sel$spec,

      MeanValuegSAM = outsum$gSAM[["EY"]], MeanValueSAM = outsum$SAM[["EY"]],
      MeanValueGQ = outsum$GQ[["EY"]], MeanValueQ = outsum$Q[["EY"]],
      MeanValuegOWL = outsum$OWL[["EY"]], MeanValueRWL = outsum$RWL[["EY"]],

      AccuracygSAM = outsum$gSAM[["PCD"]], AccuracySAM = outsum$SAM[["PCD"]],
      AccuracyGQ = outsum$GQ[["PCD"]], AccuracyQ = outsum$Q[["PCD"]],
      AccuracyOWL = outsum$OWL[["PCD"]], AccuracygRWL = outsum$RWL[["PCD"]],

      MedValuegSAM = outsum$gSAM[["MedY"]], MedValueSAM = outsum$SAM[["MedY"]],
      MedValueGQ = outsum$GQ[["MedY"]], MedValueQ = outsum$Q[["MedY"]],
      MedValueOWL = outsum$OWL[["MedY"]], MedValueRWL = outsum$RWL[["MedY"]]
    )

    print(results)

    # Checkpoint after every replicate so a long run can be inspected/resumed.
    results.aggregated <- rbind(results.aggregated, results)
    saveRDS(
      list(
        "results.aggregated" = results.aggregated,
        "VarselgSAM" = VarselgSAM, "VarselSAM" = VarselSAM,
        "VarselgQlearn" = VarselgQlearn, "VarselQlearn" = VarselQlearn
      ),
      agg.tmp
    )
  }

  mean.results <- apply(results.aggregated, 2, mean, na.rm = TRUE)
  sd.results <- apply(results.aggregated, 2, sd, na.rm = TRUE)
  return(list(
    results.aggregated = results.aggregated,
    mean.results = mean.results, sd.results = sd.results,
    "VarselgSAM" = VarselgSAM, "VarselSAM" = VarselSAM,
    "VarselgQlearn" = VarselgQlearn, "VarselQlearn" = VarselQlearn,
    scenario = s, unlist(test.data[8:19])
  ))
}
sambiostat/WAPL documentation built on May 26, 2020, 12:17 a.m.