#' WAPL with cross-validation
#'
#' Based on WAPL: fits \code{gSAM} with residual-based weights and uses m-fold
#' cross validation to tune both the penalty parameter lambda and the number of
#' basis functions.
#'
#' @param H n by p covariate matrix.
#' @param A Treatment assigned, length-n vector.
#' @param R2 Residual or the original outcome, length-n vector.
#' @param prop Pr(A|H), the propensity score, assumed to be known. Defaults to
#'   a vector of ones of length n.
#' @param pentype Regression used to compute the residuals: "lasso" (default,
#'   \code{glmnet::cv.glmnet}), "gglasso" (group lasso via
#'   \code{gglasso::cv.gglasso}, using \code{group}) or "LSE" (unpenalized
#'   least squares via \code{lm}).
#' @param m Number of cross-validation folds used to choose lambda and the
#'   number of basis.
#' @param group Group information, should be consecutive. Default is NULL,
#'   i.e. no group information is present.
#' @param plist Candidate values for the number of basis. Default is 3:10.
#' @param lambda A user-supplied lambda sequence (in decreasing order).
#'   Normally left NULL so that the fitting algorithm computes it.
#' @param nlambda Number of lambda values used when \code{lambda} is NULL.
#'   Default is 50.
#' @param lambda.min.ratio Ratio between the smallest and the largest lambda
#'   (passed to \code{gSAM}).
#' @param thol Stopping precision. The default value is 1e-5.
#' @param mu Smoothing parameter used to approximate the hinge loss. The
#'   default value is 0.05.
#' @param max.ite Maximum number of iterations. The default value is 1e5.
#' @param verbose If TRUE, report progress with \code{message()}. Default is
#'   FALSE.
#' @return The \code{gSAM} fit on the full data using the selected number of
#'   basis and lambda.
#' @import gglasso
#' @import glmnet
#' @export
#' @examples
#' train.data <- gSim(N=200, sigma=0, scenario=1)
#' H <- train.data[[1]]
#' A <- train.data[[2]]
#' R2 <- train.data[[3]]
#' group=rep(1:20, each=3)
#' tst = cv_WAPL(H, A, R2 , prop=rep(1,200), pentype = "lasso",lambda.min.ratio=0.2, m=7, group= group, plist=c(3:5))
cv_WAPL <- function(H, A, R2, prop = rep(1, n), pentype = "lasso", m = 10,
                    group = NULL, plist = c(3:10), lambda = NULL, nlambda = 50,
                    lambda.min.ratio = 0.2, thol = 1e-5, mu = 0.05,
                    max.ite = 1e5, verbose = FALSE) {
  # `n` must exist before `prop` is forced, because its default is rep(1, n).
  n <- length(A)
  w <- .cv_wapl_weights(H, R2, prop, pentype, group, n)

  # Balanced random fold labels: every fold is non-empty whenever m <= n.
  folds <- sample(rep_len(seq_len(m), n))

  # 2 x length(plist) matrix: row 1 = best CV value, row 2 = its lambda.
  # vapply keeps the matrix shape even when plist has a single element.
  cv_res <- vapply(plist, function(pp) {
    if (verbose) message("cv_WAPL: number of basis = ", pp)
    .cv_wapl_score_p(H, A, w, pp, folds, m, group, lambda, nlambda,
                     lambda.min.ratio, thol, mu, max.ite)
  }, numeric(2))

  scores <- cv_res[1, ]
  if (all(is.na(scores))) {
    stop("cv_WAPL: cross validation failed for every value in 'plist'",
         call. = FALSE)
  }
  best <- which.max(scores)
  bestP <- plist[best]
  bestL <- cv_res[2, best]
  if (verbose) {
    message("cv_WAPL: selected number of basis = ", bestP,
            ", lambda = ", bestL)
  }
  gSAM(X = H, y = A, w = w, p = bestP, group = group, lambda = bestL,
       thol = thol, mu = mu, max.ite = max.ite)
}

# Regress R2 on H according to `pentype`, rescale the residuals to [0, 1] and
# divide by the propensity score to obtain the subject weights used by gSAM.
# A constant R2 skips the regression; constant residuals fall back to equal
# weights (with a warning) instead of producing 0/0 = NaN.
.cv_wapl_weights <- function(H, R2, prop, pentype, group, n) {
  if (max(R2) != min(R2)) {
    Hm <- as.matrix(H)
    co <- switch(
      pentype,
      lasso = {
        cvfit <- glmnet::cv.glmnet(Hm, R2, nfolds = 10)
        as.matrix(predict(cvfit, s = "lambda.min", type = "coeff"))
      },
      gglasso = {
        cvfit <- gglasso::cv.gglasso(x = Hm, y = R2, group = group,
                                     pred.loss = "L1", nfolds = 10)
        coef(cvfit, s = "lambda.min")
      },
      LSE = coef(lm(R2 ~ Hm)),
      stop("'pentype' is the penalization type for the regression step of ",
           "Olearning, the default is 'lasso',\nit can also be 'gglasso' ",
           "(group lasso) or 'LSE' without penalization", call. = FALSE)
    )
    r <- R2 - cbind(rep(1, n), Hm) %*% co
  } else {
    r <- R2
  }
  rng <- max(r) - min(r)
  if (!is.na(rng) && rng == 0) {
    warning("cv_WAPL: residuals are constant; using equal weights",
            call. = FALSE)
    r <- rep(1, n)
  } else {
    r <- (r - min(r)) / rng
  }
  r / prop
}

# Cross-validated value estimate for one number of basis `pp`.
# Returns c(best mean CV value, lambda attaining it), or c(NA, NA) when no
# lambda could be scored on any fold.
.cv_wapl_score_p <- function(H, A, w, pp, folds, m, group, lambda, nlambda,
                             lambda.min.ratio, thol, mu, max.ite) {
  fit <- gSAM(X = H, y = A, w = w, p = pp, group = group, lambda = lambda,
              nlambda = nlambda, lambda.min.ratio = lambda.min.ratio,
              thol = thol, mu = mu, max.ite = max.ite)
  lambda0 <- fit$lambda
  n_lambda <- length(lambda0)
  # One row per fold, one column per lambda of the full-data path.
  V <- matrix(NA_real_, nrow = m, ncol = n_lambda)
  for (i in seq_len(m)) {
    train <- folds != i
    model <- gSAM(X = H[train, , drop = FALSE], y = A[train], w = w[train],
                  p = pp, group = group, lambda = lambda0, nlambda = nlambda,
                  lambda.min.ratio = lambda.min.ratio, thol = thol, mu = mu,
                  max.ite = max.ite)
    pred <- tryCatch(predict(model, H[!train, , drop = FALSE]),
                     error = function(e) e)
    if (inherits(pred, "error")) {
      warning(sprintf("cv_WAPL: prediction failed on fold %d (p = %s): %s",
                      i, pp, conditionMessage(pred)), call. = FALSE)
      next
    }
    # labels: held-out subjects x lambdas; TRUE where the rule matches A.
    hit <- matrix(pred$labels == A[!train], ncol = n_lambda)
    Wt <- w[!train]
    V[i, ] <- colSums(hit * Wt, na.rm = TRUE) / colSums(hit, na.rm = TRUE)
  }
  mat <- colMeans(V, na.rm = TRUE)
  if (all(is.na(mat))) {
    return(c(NA_real_, NA_real_))
  }
  best_val <- max(mat, na.rm = TRUE)
  cand <- which(mat == best_val)
  if (length(cand) > 1) {
    # Break ties with the in-sample value on the full data; `pick` indexes
    # into `cand`, not into `lambda0`.
    hit <- matrix(predict(fit, H)$labels == A, ncol = n_lambda)
    hit <- hit[, cand, drop = FALSE]
    full_val <- colSums(hit * drop(w)) / colSums(hit)
    pick <- which.max(full_val)
    cand <- if (length(pick) == 1) cand[pick] else cand[1]
  }
  c(best_val, lambda0[cand])
}
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.