R/cv_WAPL.R

Defines functions cv_WAPL

Documented in cv_WAPL

#' WAPL with cross-validation
#'
#' Fits WAPL (through \code{gSAM}) and uses m-fold cross-validation to tune
#' both the penalty parameter lambda and the number of basis functions.
#'
#' The outcome \code{R2} is first residualized on \code{H} (unless it is
#' constant), rescaled to [0, 1] and divided by \code{prop} to form the
#' observation weights. For every candidate in \code{plist} a full-data fit
#' provides the lambda path; each fold is then refitted on the remaining
#' folds and every lambda is scored by the mean weight of the held-out
#' observations whose predicted treatment equals the assigned one. Ties in
#' the averaged CV score are broken with the same criterion computed on the
#' full-data fit.
#'
#' @param H n by p covariate matrix.
#' @param A Treatment assigned, length n vector.
#' @param R2 Residual or the original outcome, length n.
#' @param prop Pr(A|H), the propensity score, assumed to be known. Defaults
#'   to \code{rep(1, n)}.
#' @param pentype Regression used to residualize \code{R2}: \code{"lasso"}
#'   (default; \code{glmnet::cv.glmnet}), \code{"gglasso"}
#'   (\code{gglasso::cv.gglasso}, requires \code{group}) or \code{"LSE"}
#'   (unpenalized \code{lm}).
#' @param m Number of cross-validation folds (at least 2, at most n).
#' @param group Group information, should be consecutive. Default
#'   \code{NULL}, i.e. no group information.
#' @param plist Candidate numbers of basis functions. Default \code{3:10}.
#' @param lambda Optional user-supplied lambda sequence (in decreasing
#'   order). Normally \code{NULL} so that \code{gSAM} computes it.
#' @param nlambda Number of lambda values when \code{lambda} is \code{NULL}.
#'   Default 50.
#' @param lambda.min.ratio Ratio between the smallest and the largest lambda
#'   of the automatically computed sequence (passed to \code{gSAM}).
#' @param thol Stopping precision. Default 1e-5.
#' @param mu Smoothing parameter used to approximate the hinge loss.
#'   Default 0.05.
#' @param max.ite Maximum number of iterations. Default 1e5.
#' @param verbose If \code{TRUE}, report progress with \code{message()}.
#'   Default \code{FALSE}.
#' @return The \code{gSAM} fit on all observations using the selected number
#'   of basis functions and lambda.
#' @import gglasso
#' @import glmnet
#' @export
#' @examples
#' train.data <- gSim(N = 200, sigma = 0, scenario = 1)
#' H <- train.data[[1]]
#' A <- train.data[[2]]
#' R2 <- train.data[[3]]
#' group <- rep(1:20, each = 3)
#' tst <- cv_WAPL(H, A, R2, prop = rep(1, 200), pentype = "lasso",
#'                lambda.min.ratio = 0.2, m = 7, group = group, plist = c(3:5))
cv_WAPL <- function(H, A, R2, prop = rep(1, n), pentype = "lasso", m = 10,
                    group = NULL, plist = c(3:10), lambda = NULL,
                    nlambda = 50, lambda.min.ratio = 0.2, thol = 1e-5,
                    mu = 0.05, max.ite = 1e5, verbose = FALSE) {
  # `n` must exist before the lazy default `prop = rep(1, n)` is evaluated.
  n <- length(A)
  if (m < 2 || m > n) {
    stop("'m' must be between 2 and the number of observations.",
         call. = FALSE)
  }

  # Residualize the outcome on H; a constant outcome is left untouched.
  if (max(R2) != min(R2)) {
    Hmat <- as.matrix(H)
    if (pentype == "lasso") {
      cvfit <- glmnet::cv.glmnet(Hmat, R2, nfolds = 10)
      co <- as.matrix(predict(cvfit, s = "lambda.min", type = "coefficients"))
    } else if (pentype == "gglasso") {
      if (is.null(group)) {
        stop("pentype = 'gglasso' requires 'group'.", call. = FALSE)
      }
      cvfit <- gglasso::cv.gglasso(x = Hmat, y = R2, group = group,
                                   pred.loss = "L1", nfolds = 10)
      co <- coef(cvfit, s = "lambda.min")
    } else if (pentype == "LSE") {
      co <- coef(lm(R2 ~ H))
    } else {
      stop("'pentype' is the penalization type for the residual regression; ",
           "it must be 'lasso' (default), 'gglasso' or 'LSE'.",
           call. = FALSE)
    }
    r <- drop(R2 - cbind(1, Hmat) %*% co)
  } else {
    r <- R2
  }

  # Shift/scale residuals to [0, 1] so the weights are non-negative. A
  # constant residual has zero range (dividing by it gave NaN weights), so
  # every observation then gets the same weight of 1.
  rng <- range(r)
  r <- if (rng[2] > rng[1]) (r - rng[1]) / (rng[2] - rng[1]) else rep(1, n)
  w <- r / prop

  # Balanced random fold labels: no fold can end up empty.
  folds <- sample(rep_len(seq_len(m), n))

  # CV for one candidate number of basis `pp`; returns c(best score, lambda),
  # or c(NA, NA) when no fold produced a usable score.
  cv_one_p <- function(pp) {
    if (verbose) message("Number of basis: ", pp)
    fit <- gSAM(X = H, y = A, w = w, p = pp, group = group, lambda = lambda,
                nlambda = nlambda, lambda.min.ratio = lambda.min.ratio,
                thol = thol, mu = mu, max.ite = max.ite)
    lambda0 <- fit$lambda
    # One row per fold, one column per lambda on the full-data path.
    scores <- matrix(NA_real_, nrow = m, ncol = length(lambda0))
    for (i in seq_len(m)) {
      train <- folds != i
      model <- gSAM(X = H[train, , drop = FALSE], y = A[train], w = w[train],
                    p = pp, group = group, lambda = lambda0,
                    nlambda = nlambda, lambda.min.ratio = lambda.min.ratio,
                    thol = thol, mu = mu, max.ite = max.ite)
      pred <- tryCatch(predict(model, H[!train, , drop = FALSE]),
                       error = function(e) e)
      if (inherits(pred, "error")) {
        warning("Prediction failed for fold ", i, " (p = ", pp, "): ",
                conditionMessage(pred), call. = FALSE)
        next
      }
      # `labels` is expected to be an n_test by #lambda matrix.
      agree <- pred$labels == A[!train]
      scores[i, ] <- colSums(agree * w[!train], na.rm = TRUE) /
        colSums(agree, na.rm = TRUE)
    }
    avg <- colMeans(scores, na.rm = TRUE)
    if (all(is.na(avg))) {
      return(c(NA, NA))
    }
    best <- max(avg, na.rm = TRUE)
    tied <- which(avg == best)
    if (length(tied) > 1) {
      # Break ties with the weighted agreement of the full-data fit; the
      # winner is indexed within `tied`, not within `lambda0`.
      agree <- (predict(fit, H)$labels == A)[, tied, drop = FALSE]
      pick <- which.max(colSums(agree * w) / colSums(agree))
      tied <- if (length(pick) == 1) tied[pick] else tied[1]
    }
    c(best, lambda0[tied])
  }

  # do.call(rbind, ...) keeps a matrix even when `plist` has one element.
  cv_res <- do.call(rbind, lapply(plist, cv_one_p))
  if (all(is.na(cv_res[, 1]))) {
    stop("Cross-validation failed for every candidate number of basis.",
         call. = FALSE)
  }
  best_row <- which.max(cv_res[, 1])
  bestP <- plist[best_row]
  bestL <- cv_res[best_row, 2]
  if (verbose) message("Selected number of basis: ", bestP, ", lambda: ", bestL)
  gSAM(X = H, y = A, w = w, p = bestP, group = group, lambda = bestL,
       thol = thol, mu = mu, max.ite = max.ite)
}
sambiostat/WAPL documentation built on May 26, 2020, 12:17 a.m.