#### SCM ####
# Synthetic Control Method weights solved as a quadratic program with OSQP.
#
# The decision vector is x = (w, s): n source weights w followed by d slack
# terms s (n = nrow(source), d = ncol(source)). For one target row t_j the
# program assembled in `initialize()` is
#   minimize    sum(s^2)                 (P is the identity on the slacks only)
#   subject to  t(source) %*% w + s == t_j
#               w >= 0,  sum(w) == 1
#               -delta <= t(bf$A) %*% w <= delta   (only with a balance.formula)
# so w are convex weights whose weighted combination of source rows is closest
# (in squared L2 distance) to t_j. `solve()` re-solves the program for every
# target row and combines the per-row solutions using the target weights b.
SCM <- R6::R6Class("SCM",
  public = list(
    source = "matrix",  # n x d matrix of source units (coerced with as.matrix)
    target = "matrix",  # matrix whose rows are the targets to reproduce
    a = "vector",       # source-row weights, as returned by check_weights()
    b = "vector",       # target-row weights, as returned by check_weights()
    A = "matrix",       # NOTE(review): never assigned by this class
    bf = "R6",          # balanceFunction object, or list(NULL) when unused

    # SCM has no tuning parameter of its own, so the penalty "grid" is a
    # single NA placeholder.
    gridInit = function(...) {
      return(c(NA_real_))
    },

    # Solve for the source weights.
    #   penalty: optional new bound (delta) for the balance-function
    #            constraints; a list is reduced to its first element. It is
    #            used only when a balanceFunction is present and it is a single
    #            non-NA number >= 0; otherwise a warning is issued and the
    #            current delta is kept.
    #   w:       not used here (presumably kept so the method signature
    #            matches sibling methods -- TODO confirm).
    # Returns sum_j b[j] * w_j, where w_j are the solved source weights for
    # target row j with negative entries clipped to 0 and then passed through
    # renormalize().
    solve = function(penalty = NULL, w = NULL) {
      n_t <- nrow(self$target)
      n_s <- nrow(self$source)
      d <- ncol(self$source)
      w_list <- vector("list", n_t)
      # Warm start: uniform weights on the sources, zero slack.
      w_init <- c(rep(1 / n_s, n_s), rep(0, d))

      # Optionally update the balance-function bound.
      if (inherits(self$bf, "balanceFunction") && !is.null(penalty)) {
        if (is.list(penalty)) penalty <- penalty[[1]]
        # Scalar checks keep `&&` safe: vector or NA input warns instead of
        # raising an error.
        if (is.numeric(penalty) && length(penalty) == 1L &&
            !is.na(penalty) && penalty >= 0) {
          self$bf$delta <- as.numeric(penalty)
        } else {
          warning("provided penalty for balanceFunction method is not a number >= 0. Not used!")
        }
      }

      # Re-solve once per target row. Only the equality bounds of the target
      # constraint rows change between solves.
      t_idx <- private$target_idx
      l <- private$solver$GetData(element = "l")
      u <- private$solver$GetData(element = "u")
      for (j in seq_len(n_t)) {
        l[t_idx] <- u[t_idx] <- c(self$target[j, ])
        private$solver$Update(l = l, u = u)
        sol <- osqp_R6_solve(private$solver, self$bf$delta,
                             self$bf$delta_idx, w_init, normalize = FALSE)
        # The solver can return tiny negative values within tolerance.
        sol[sol < 0] <- 0
        w_list[[j]] <- renormalize(sol[seq_len(n_s)]) * self$b[j]
      }
      return(Reduce("+", w_list))
    },

    # Build the OSQP problem.
    #   source, target: coerced with as.matrix(); must have matching columns.
    #   a, b:           optional weights for source / target rows, validated
    #                   by check_weights().
    #   options:        an `scmOptions` object, or a plain list that is passed
    #                   to scmOptions() via do.call().
    # The constraint rows are ordered: k balance-function rows, n
    # non-negativity rows, 1 sum-to-one row, then d target rows.
    initialize = function(source, target,
                          a = NULL, b = NULL,
                          options = list()) {
      if (!inherits(options, "scmOptions")) {
        if (!is.list(options)) stop("options must be a list or output of scmOptions function")
        options <- do.call(scmOptions, options)
      }
      self$source <- as.matrix(source)
      self$target <- as.matrix(target)
      self$a <- check_weights(a, self$source)
      self$b <- check_weights(b, self$target)
      n <- nrow(self$source)
      d <- ncol(self$source)

      if (!is.null(options$balance.formula)) {
        # model.matrix() needs named columns; the target reuses source names.
        if (is.null(colnames(source))) colnames(source) <- paste0("X", seq_len(ncol(source)))
        if (is.null(colnames(target))) colnames(target) <- colnames(source)
        tbf <- terms(formula(options$balance.formula))
        attr(tbf, "intercept") <- 0
        source.bf <- model.matrix(tbf, data.frame(source))
        target.bf <- model.matrix(tbf, data.frame(target))
        # Use this object's own weights. The former `self$ot$a` / `self$ot$b`
        # referred to a field SCM does not have and evaluated to NULL.
        self$bf <- balanceFunction$new(source = source.bf,
                                       target = target.bf,
                                       a = as.numeric(self$a),
                                       b = as.numeric(self$b),
                                       delta = options$delta)
        k <- ncol(target.bf)
      } else {
        self$bf <- list(NULL)
        k <- 0
      }

      nvars <- n + d
      # Objective: quadratic on the d slack variables only; no linear term.
      q <- NULL
      P <- Matrix::sparseMatrix(i = n + seq_len(d), j = n + seq_len(d), x = 1,
                                dims = c(nvars, nvars))

      # t(source) %*% w + s == target row (bounds refreshed in solve()).
      A_quad <- cbind(Matrix::Matrix(t(self$source)),
                      Matrix::Diagonal(d, x = 1))
      u_quad <- l_quad <- c(self$target[1, ])

      # w >= 0 (one row per weight) and sum(w) == 1.
      A_bounds <- rbind(cbind(Matrix::Diagonal(n, x = 1),
                              Matrix::Matrix(matrix(0, n, d))),
                        Matrix::sparseMatrix(i = rep(1, n), j = seq_len(n),
                                             x = 1, dims = c(1, nvars)))
      l_bounds <- c(rep(0, n), 1)
      u_bounds <- c(rep(Inf, n), 1)

      # Optional balance-function rows: -delta <= t(bf$A) %*% w <= delta.
      if (k > 0) {
        A_bf <- cbind(Matrix::Matrix(t(self$bf$A)),
                      Matrix::Matrix(matrix(0, k, d)))
        l_bf <- rep(-self$bf$delta, k)
        u_bf <- rep(self$bf$delta, k)
        self$bf$delta_idx <- seq_len(k)
      } else {
        A_bf <- l_bf <- u_bf <- NULL
      }

      l <- c(l_bf, l_bounds, l_quad)
      u <- c(u_bf, u_bounds, u_quad)
      A <- rbind(A_bf, A_bounds, A_quad)
      # Target rows follow the k balance rows and the n + 1 bound rows.
      private$target_idx <- (k + n + 2):length(l)
      private$solver <- osqp::osqp(P = P, q = q,
                                   A = A, l = l, u = u,
                                   pars = options$solver.options)
    }
  ),
  private = list(
    target_idx = "numeric",  # indices of the target-equality rows in l / u
    solver = "R6"            # the osqp model object
  )
)
# Developer notes on settings that are fixed inside scmOptions() for now and
# are not user-facing arguments:
# @param lambda Penalty parameter on the weights. Not currently used; kept as
#   a note because adding it is planned.
# @param delta The constraint parameter for the balancing functions. Not currently used.
# @param grid.length The number of penalty parameters to try. Not currently used.
# @param nboot The number of bootstrap samples. Not currently used.
# @param balance.formula The formula that denotes the covariate functions to balance. Not currently used.
#' Options for the SCM Method
#'
#' @param ... Arguments passed to the [osqpSettings()][osqp::osqpSettings()]
#'   function, which configures the solver. Arguments whose names are not
#'   formal arguments of `osqpSettings()` are dropped silently.
#'
#' @return An object of class `scmOptions`. This is a list containing
#'   `delta` and `balance.formula` (fixed to `NULL` in the current
#'   implementation), `grid.length`, `nboot`, and `solver.options`.
#'   `solver.options` is the subset of `...` that is passed to
#'   [osqpSettings()][osqp::osqpSettings()].
#' @export
#'
#' @details Options for the solver used in the optimization of the Synthetic Control Method of Abadie and Gardeazabal (2003).
#'
#' @examples
#' opts <- scmOptions()
scmOptions <- function(
    # lambda = NULL,
    # delta = NULL,
    # grid.length = 7L,
    # nboot = 1000L,
    # balance.formula = NULL,
    ...) { # dots are the osqp args
  # Fixed internal settings (see the developer notes above). NOTE(review):
  # osqp::osqpSettings() has its own `delta` argument, which currently reaches
  # the solver through `...`. Adding a `delta` formal here would capture it.
  grid.length <- 7L
  nboot <- 1000L
  delta <- NULL
  balance.formula <- NULL

  gsOpts <- gridSearchOptions(nboot = nboot, grid.length = grid.length)
  nboot <- gsOpts$nboot
  grid.length <- gsOpts$grid.length

  output <- list()

  # lambda is not supported yet, so no `lambda` element is created. Possible
  # future code:
  # if (arg_not_used(lambda)) {
  #   output["lambda"] <- list(NULL)
  # } else {
  #   if (any(lambda < 0)) stop("lambda must be >= 0")
  #   output$lambda <- sort(lambda, decreasing = TRUE)
  # }

  if (arg_not_used(delta)) {
    output["delta"] <- list(NULL)
  } else {
    if (any(delta < 0)) stop("delta must be >= 0")
    output$delta <- sort(delta, decreasing = TRUE)
  }

  # Only an L2 penalty is supported at this time; kept for possible future use:
  # if (arg_not_used(penalty)) {
  #   output$penalty <- "entropy"
  # } else {
  #   output$penalty <- match.arg(penalty, c("entropy", "L2"))
  # }

  if (arg_not_used(balance.formula)) {
    output["balance.formula"] <- list(NULL)
  } else {
    # Rebuild as a one-sided formula without an intercept.
    balance.formula <- as.character(balance.formula)
    bf_split <- strsplit(balance.formula, "~")
    output$balance.formula <- paste0("~ 0 +", bf_split[[1]][2])
  }

  # grid.length currently only applies to delta.
  if (arg_not_used(grid.length)) {
    output$grid.length <- 7L
  } else {
    output$grid.length <- as.integer(grid.length)
    if (grid.length <= 0) stop("grid.length must be greater than 0")
  }

  if (!is.null(nboot)) {
    output$nboot <- as.integer(nboot)
  } else {
    output$nboot <- 1000L
  }

  # A user-supplied delta means no grid search, so no grid length is needed.
  if (!is.null(output$delta)) {
    output["grid.length"] <- list(NULL)
  }

  # Keep only the named dots that osqpSettings() accepts.
  output$solver.options <- list(...)[...names() %in% methods::formalArgs(osqp::osqpSettings)]

  class(output) <- "scmOptions"
  return(output)
}
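# NOTE(review): the two lines below look like residue from a web page's
# embed widget, not part of this source; they are commented out so the file
# parses as R.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.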