R/cotOOP.R

Defines functions sum.weightEnv `/.weightEnv` `*.weightEnv` `-.weightEnv` `+.weightEnv` unaryop.weightEnv binaryop.weightEnv `/.OTProblem` `*.OTProblem` `-.OTProblem` `+.OTProblem` binaryop.OTProblem binaryop is_ot_problem OTProblem oop_loss_select Measure

Documented in Measure oop_loss_select OTProblem

# COT general form using object oriented code and R6 classes

#' @name Measure_
#' @title An R6 object for measures
#' @rdname Measure_-class
#' @description Internal R6 class object for Measure objects
#' @keywords internal
Measure_ <- R6::R6Class("Measure", # change name later for Roxygen purposes
 public = {list(
   # data
   
   #' @field balance_functions the functions of the data that 
   #' we want to adjust towards the targets
   balance_functions = "torch_tensor",
   
   #' @field balance_target the values the balance_functions are targeting
   balance_target = "vector",
   
   # values
   #' @field adapt What aspect of the data will be adapted. One of "none","weights", or "x".
   adapt  = "character",
   
   #' @field device the [torch::torch_device()] of the data.
   device = "torch_device",
   
   #' @field dtype the [torch::torch_dtype] of the data.
   dtype  = "torch_dtype",
   
   #' @field n the rows of the covariates, x.
   n      = "integer",
   
   #' @field d the columns of the covariates, x.
   d      = "integer",
   
   #' @field probability_measure is the measure a probability measure?
   probability_measure = "logical",
   
   # functions
   #' @description 
   #' generates a deep clone of the object without gradients.
   detach = function() {
     orig_adapt <- self$adapt
     
     # Temporarily switch autograd off on the adapted tensor so that the
     # deep clone is taken from tensors that do not require gradients.
     if (orig_adapt == "weights") {
       private$mass_$requires_grad <- FALSE
     } else if (orig_adapt == "x") {
       private$data_$requires_grad <- FALSE
     }
     
     temp_obj <- self$clone(deep = TRUE)
     
     # Flag the copy as non-adaptive, then restore autograd on the original.
     if (orig_adapt != "none") {
       temp_obj$adapt <- "none"
       if (orig_adapt == "weights") {
         private$mass_$requires_grad <- TRUE
       } else if (orig_adapt == "x") {
         private$data_$requires_grad <- TRUE
       }
     }
     return(temp_obj)
   },
   
   #' @description 
   #' Makes a copy of the weights parameters.
   get_weight_parameters = function() {
     private$mass_$clone()
   },
   
   #' @description prints the measure object
   #' @param ... Not used
   print = function(...) {
     cat("Measure: ",  rlang::obj_address(self), "\n", sep = "")
     cat("  x      : a ", self$n, "x", self$d, " matrix \n", sep = "")
     
     # Show at most the first five rows; an ellipsis is appended to each row
     # when there are more than five columns.
     row_suffix <- if (self$d > 5) ", \u2026" else ""
     for (i in seq_len(min(self$n, 5L))) {
       cat("           " , paste(round(as_matrix(private$data_[i,1:5]), digits = 2), collapse = ", "), row_suffix, "\n", sep = "")
     }
     if(self$n > 5) {
       cat("           \u22EE \n")
     }
     
     weights_suffix <- if (self$n > 5) ", \u2026" else ""
     cat("  weights: ", paste(round(t(as_numeric(self$weights[1:5])), digits = 2), collapse = ", "), weights_suffix, "\n", sep = "")
     
     # Balance information is only shown when targets have been set.
     if(all(as_logical(!is.na(self$balance_target)))) {
       cat("  balance: ", "\n", sep = "")
       cat("   funct.: ",  paste(round(as_matrix(self$balance_functions[1,1:5]), digits = 2), collapse = ", "), "\n", sep = "")
       cat("   target: ",  paste(round(t(as_numeric(self$balance_target[1:5])), digits = 2), collapse = ", "), " \u2026", "\n", sep = "")
     }
     cat("  adapt  : ", self$adapt, "\n", sep = "")
     cat("  dtype  : ", capture.output(self$dtype), "\n", sep = "")
     cat("  device : ", capture.output(self$device), "\n", sep = "")
   },
   
   #' @description Constructor function
   #' @param x The data points
   #' @param weights The empirical measure. If NULL, assigns equal weight to each observation
   #' @param probability.measure Is the empirical measure a probability measure? Default is TRUE.
   #' @param adapt Should we try to adapt the data ("x"), the weights ("weights"), or neither ("none"). Default is "none".
   #' @param balance.functions A matrix of functions of the covariates to target for mean balance. If NULL and `target.values` are provided, will use the data in `x`.
   #' @param target.values The targets for the balance functions. Should be the same length as columns in `balance.functions`.
   #' @param dtype The [torch::torch_dtype] or NULL.
   #' @param device The device to have the data on. Should be result of [torch::torch_device()] or NULL.
   initialize = function(x, weights = NULL, 
                         probability.measure = TRUE, 
                         adapt = c("none","weights", "x"), 
                         balance.functions = NA_real_,
                         target.values = NA_real_,
                         dtype = NULL, device = NULL) {
     if(!is.matrix(x)) x <- as.matrix(x)
     self$d <- ncol(x)
     self$n <- nrow(x)
     
     self$adapt <- match.arg(adapt)
     self$probability_measure <- isTRUE(probability.measure)
     
     # Device and dtype resolution (including NULL handling) is delegated to
     # the package helpers.
     self$device <- cuda_device_check(device)
     self$dtype  <- cuda_dtype_check(dtype, self$device)
     
     # set data
     private$data_ <- torch::torch_tensor(x, dtype = self$dtype, device = self$device)$contiguous()
     
     if (self$adapt == "x") {
       if(!private$data_$requires_grad) private$data_$requires_grad <- TRUE
     }
     
     # Balance targets/functions are dropped when nothing is adapted, and the
     # data itself is used as the balance functions when only targets are given.
     self$balance_target <- as.numeric(target.values)
     if(self$adapt == "none") self$balance_target <- NA_real_
     if(missing(balance.functions)) balance.functions <- NA_real_
     if(all(is.na(self$balance_target)) || self$adapt == "none") balance.functions <- NA_real_
     if(all(is.na(balance.functions)) && !all(is.na(self$balance_target))) balance.functions <- private$data_
     
     if (!inherits(balance.functions, "torch_tensor") && !all(is.na(balance.functions)) ) {
       self$balance_functions <- torch::torch_tensor(as.matrix(balance.functions), 
                                                     dtype = self$dtype, device = self$device)$contiguous()
     } else {
       self$balance_functions <- balance.functions
     }
     if (!inherits( self$balance_target, "torch_tensor") && !all(is.na(self$balance_target)) )  {
       self$balance_target <- torch::torch_tensor(self$balance_target, 
                                                     dtype = self$dtype, device = self$device)$contiguous()
     } else if (!all(is.na(self$balance_target))) {
       self$balance_target <- self$balance_target$to(device = self$device, dtype = self$dtype)$contiguous()
     }
     
     if (self$adapt == "x") {
       if (!all(is.na(self$balance_functions)) && !self$balance_functions$requires_grad) self$balance_functions$requires_grad <- TRUE
       if (!all(is.na(self$balance_functions))) {
         # Different R objects for x and balance.functions means they were
         # supplied separately.
         if(rlang::obj_address(private$data_) != rlang::obj_address(self$balance_functions)) warning("x and balance.functions may not come from same data. Adapting the two together may not make sense. If you would like x and balance.functions to have the same underlying data, you can either feed the same torch tensor into both or leave balance.functions blank and it will use the values in data by default if target.values are supplied.")
       }
     }
     
     stopifnot("Argument 'target.values' must be NA or the same length as the number of columns in 'balance.functions'." = all(is.na(self$balance_target)) || length(self$balance_target) == ncol(self$balance_functions))
 
     # adjust by Std Dev: scale balance functions and targets by the column
     # standard deviations and drop columns with zero variance
     if (inherits( self$balance_target, "torch_tensor") && !as.logical(all(is.na(self$balance_target$to(device = "cpu")))) ) {
       if (!self$adapt == "x") {
         sds <- self$balance_functions$std(1)
         self$balance_functions <- self$balance_functions/sds
         self$balance_target <- self$balance_target /sds
         non_zero_sd <-  as.logical(0 < sds$to(device = "cpu"))
         if (sum(non_zero_sd) == 0) {
           warning("All columns of balance.functions have zero variance. Balance functions not used.")
           self$balance_functions <- NA_real_
           self$balance_target <- NA_real_
         } else {
           sel_nz <- which(non_zero_sd)
           self$balance_functions <- self$balance_functions[,sel_nz, drop = FALSE]
           self$balance_target <- self$balance_target[sel_nz]
         }
         
       }
     }
     
     weights <- check_weights(weights, private$data_, self$device)
     
     if (self$adapt == "weights"){
       # The weights are stored as unconstrained parameters and mapped to the
       # weight scale through transform_().
       # NOTE(review): the parameter vector has n - 1 entries, which matches
       # the probability-measure inv_transform_ below; confirm the
       # non-probability case (torch_log, n entries) is supported here.
       private$mass_ <- torch::torch_zeros(length(weights)-1L,
                                            dtype = self$dtype,
                                           device = self$device)
       private$get_mass_ <- function() {
         private$transform_(private$mass_)
       }
       private$assign_mass_ <- function(value) {
         torch::with_no_grad({
         private$mass_$copy_(private$inv_transform_(value)$to(device = self$device))
         # reset any existing gradient since the parameters were overwritten
         if(! torch::is_undefined_tensor(private$mass_$grad) ) private$mass_$grad$copy_(0.0)
         })
       }
       
       private$mass_$requires_grad <- TRUE
     } else {
       # Weights are stored directly when they are not being adapted.
       private$get_mass_ <- function() {
         return(private$mass_)
       }
       private$assign_mass_ <- function(value) {
         private$mass_ <- value$to(device = self$device)
       }
     }
     
     if (self$probability_measure) {
       # Log-ratio parameterisation relative to the first weight; the first
       # parameter is implicitly fixed at zero.
       private$inv_transform_ <- function(value) {
         min_neg <- round(log(.Machine$double.xmin)) - 50.0
         logs <- torch::torch_log(value/sum(value))
         logs[logs< min_neg] <- min_neg # floor very small log-weights
         logs <- logs[2:length(logs)] - logs[1]
         return(logs)
       }
       
       private$transform_ <- function(value) {
         full_param <- torch::torch_cat(
           list( torch::torch_zeros(1L, device = self$device, dtype = value$dtype),
                value)
         )
         return(full_param$log_softmax(1)$exp())
       }
       
     } else {
       private$inv_transform_ <- torch::torch_log
       private$transform_ <- torch::torch_exp
     }
     private$assign_mass_(weights)
     
     # assign initial values
     private$init_weights_ <- self$weights$detach()
     if (self$adapt == "none" || self$adapt == "weights") {
       private$init_data_ <- private$data_
     } else if (self$adapt == "x") {
       private$init_data_ <- private$data_$detach()$clone()
     } else {
       stop("adapt hasn't been properly set!")
     }
     if (torch::cuda_is_available()) torch::cuda_empty_cache()
     return(invisible(self))
   } 
 )},
 active = {list(
   #' @field grad gets or sets gradient
   grad = function(value) {
     
     # return grad values as appropriate
     if (missing(value)) {
       if(self$adapt == "none") {
         return(NULL)
       } else if (self$adapt == "x") {
         return(private$data_$grad)
       } else if (self$adapt == "weights") {
         return(private$mass_$grad)
       }
     }
     
     
     # save grad values
     torch::with_no_grad({
       if (!is_torch_tensor(value)) {
         value <- torch::torch_tensor(value, dtype = self$dtype,
                                      device = self$device)
       }
       if(self$adapt == "none") {
         stop("No elements of this Measure object have gradients")
       } else if (self$adapt == "x") {
         
         l_v <- length(value)
         dim_v <- dim(value)
         
         if(any(dim_v != c(self$n,self$d)) && l_v != self$d && l_v != self$d * self$n) {
           stop(sprintf("Length of input for the `x` gradients must be of length %s or %s. Alternatively, better to supply a matrix of dimension %s by %s directly.", self$d, self$d*self$n, self$n, self$d))
         }
         
         private$data_$grad <- value$to(device = private$data_$device)
         
       } else if (self$adapt == "weights") {
         
         l_v <- length(value)
         
         if (l_v != (self$n - 1)) {
           stop(sprintf("Input value must be of length %s for the weight gradients. The first value is fixed to make the vector identifiable and thus does not have a gradient.", self$n - 1))
         }
         
         private$mass_$grad <- value$to(device = private$mass_$device)
       }
     })
   },
   
   #' @field init_weights returns the initial value of the weights
   init_weights = function(value) {
     if(!missing(value)) stop("Can't change the initial weights. Try setting the weights by using the `$weights` operator.")
     return(private$init_weights_$clone())
   },
   
   #' @field init_data returns the initial value of the data
   init_data = function(value) {
     if(!missing(value)) stop("Can't change the initial values. Try setting the data by using the `$x` operator.")
     return(private$init_data_$clone())
   },
   
   #' @field requires_grad checks or turns on/off gradient
   requires_grad = function(value) {
     if (missing(value)) {
       rg <- switch(self$adapt,
              "none" = FALSE,
              TRUE)
       return(rg)
     }
     value <- match.arg(value, c("none","weights","x"))
     if (value == "none") {
       if(self$adapt == "weights") {
         private$mass_$requires_grad <- FALSE
       } else if (self$adapt == "x") {
         private$data_$requires_grad <- FALSE
       }
       self$adapt <- "none"
     } else if (value == "x") {
       if (self$adapt == "weights") {
         warning("Turning off gradients for weights and turning on for x values.")
         private$mass_$requires_grad <- FALSE
       } else if (self$adapt == "x") {
         warning("Gradients already on for the x values.")
       }
       private$data_$requires_grad <- TRUE
       self$adapt <- "x"
     } else if (value == "weights") {
       if (self$adapt == "x") {
         warning("Turning off gradients for x values and turning on for weights")
         private$data_$requires_grad <- FALSE
         
       } else if (self$adapt == "weights") {
         warning("Gradients already on for the weights.")
       }
       if(self$adapt != "weights") {
         # Read the current weights BEFORE the accessors are replaced:
         # once get_mass_ applies transform_, reading self$weights from a
         # directly-stored mass_ would transform the raw weights a second time.
         current_weights <- self$weights$detach()$clone()
         
         private$get_mass_ <- function() {
           private$transform_(private$mass_)
         }
         private$assign_mass_ <- function(value) {
           torch::with_no_grad({
             private$mass_$copy_(private$inv_transform_(value)$to(device = self$device))
           })
         }
         
         if (self$probability_measure) {
           private$inv_transform_ <- function(value) {
             min_neg <- round(log(.Machine$double.xmin)) - 50.0
             logs <- torch::torch_log(value/sum(value))
             logs[logs< min_neg] <- min_neg
             logs <- logs[2:length(logs)] - logs[1]
             return(logs)
           }
           
           private$transform_ <- function(value) {
             # the leading zero must live on the same device as the parameters,
             # as in the constructor
             full_param <- torch::torch_cat(
               list( torch::torch_zeros(1L, device = self$device, dtype = value$dtype),
                     value)
             )
             return(full_param$log_softmax(1)$exp())
           }
           
         } else {
           private$inv_transform_ <- torch::torch_log
           private$transform_ <- torch::torch_exp
         }
         
         # The parameter vector is one element shorter than the weights for a
         # probability measure (first log-ratio fixed at zero); make sure the
         # storage has the size inv_transform_ will produce before copying.
         n_par <- if (self$probability_measure) self$n - 1L else self$n
         if (length(private$mass_) != n_par) {
           private$mass_ <- torch::torch_zeros(n_par,
                                               dtype = current_weights$dtype,
                                               device = self$device)
         }
         
         private$assign_mass_(current_weights)
       }
       private$mass_$requires_grad <- TRUE
       self$adapt <- "weights"
     }
   },
   
   #' @field weights gets or sets weights
   weights = function(value) {
     if(missing(value)) {
       return(private$get_mass_())
     }
     stopifnot("Input is NA" = !isTRUE(all(is.na(value))))
     stopifnot("Input is NULL" = !isTRUE(is.null(value)))
     stopifnot("Input value is not same length as nrows of data" = (length(value) == self$n) )
     
     if(!inherits(value, "torch_tensor")) {
       value <- torch::torch_tensor(value, dtype = private$mass_$dtype, device = self$device)$contiguous()
     } else {
       stopifnot("Input tensor and original weights have different dtypes! " = isTRUE(value$dtype == private$mass_$dtype))
       if (isFALSE(value$device == private$mass_$device) ) {
         value <- value$to(device = private$mass_$device)
       }
     }
     
     if (self$probability_measure) {
       stopifnot("supplied weights must be >=0" = all(as.logical((value >=0)$to(device = "cpu"))))
       # renormalise so the weights sum to one
       if(as.logical((sum(value) != 1)$to(device = "cpu"))) value <- (value/sum(value))$detach()
     }
     
     private$assign_mass_(value)
     
   },
   
   #' @field x Gets or sets the data.
   x = function(value) {
     
     # return data tensor if no value provided
     if (missing(value)) {
       return(private$data_)
     }
     
     # check input data
     stopifnot("Input is NA" = !all(is.na(value)))
     stopifnot("Input is NULL" = !is.null(value))
     
     # check if input is tensor
     if (!inherits(value, "torch_tensor")) {
       value <- torch::torch_tensor(value, device = self$device, dtype = private$data_$dtype)
     } else {
       stopifnot("Input tensor and original data have different dtypes! " = isTRUE(value$dtype == private$data_$dtype))
     }
     
     # make sure dimensions are correct
     l_value <- length(value)
     
     if(l_value != (self$n * self$d) ) stop(sprintf("Input must either be a matrix of dimension %s by %s or a vector with total length %s", self$n, self$d, self$n * self$d))
     
     dims_v <- dim(value)
     
     # anything that is not already two dimensional is reshaped to n x d
     if (length(dims_v) != 2) {
       if(length(dims_v) >  2) warning("Tensor being reshaped to a two dimensional tensor")
       value <- value$view(c(self$n, self$d))
     }
     
     # set data in place, outside of the autograd graph
     torch::with_no_grad({
     private$data_$copy_(value$to(device = self$device))
     })
     
     # check if balance.function data is equal to data
     bf_not_equal_data <- isTRUE(rlang::obj_address(self$balance_functions) != rlang::obj_address(private$data_))
     
     if (bf_not_equal_data && !all(is.na(self$balance_functions))) {
       warning("Measure data reset but not the balance_functions. You may need to manually reset this as well.")
     }
     
   }
 )},
 private = {list(
   # values
   data_ = "torch_tensor",
   init_data_ = "torch_tensor",
   init_weights_ = "torch_tensor",
   mass_ = "torch_tensor",
   
   # functions
   assign_mass_ = "function", #transforms if needed
   # used by clone(deep = TRUE): tensors are cloned, everything else is
   # passed through unchanged
   deep_clone = function(name, value) {
     if (inherits(value, "torch_tensor")) {
       value$clone()
     } else {
       value
     }
   },
   get_mass_ = "function", #transforms if needed
   inv_transform_ = "function", # needs inv log_softmax for prob measure
   transform_ = "function" # needs log_softmax
   
 )}
)

#' @name Measure
#' @title An R6 Class for setting up measures
#'
#' @param x The data points
#' @param weights The empirical measure. If NULL, assigns equal weight to each observation
#' @param probability.measure Is the empirical measure a probability measure? Default is TRUE.
#' @param adapt Should we try to adapt the data ("x"), the weights ("weights"), or neither ("none"). Default is "none".
#' @param balance.functions A matrix of functions of the covariates to target for mean balance. If NULL and `target.values` are provided, will use the data in `x`.
#' @param target.values The targets for the balance functions. Should be the same length as columns in `balance.functions`.
#' @param dtype The torch_tensor dtype or NULL.
#' @param device The device to have the data on. Should be result of [torch::torch_device()] or NULL.
#' @return Returns a Measure object
#' 
#' @details # Public fields
#'   \if{html}{\out{<div class="r6-fields">}}
#'   \describe{
#'     \item{\code{balance_functions}}{the functions of the data that
#'       we want to adjust towards the targets}
#'     \item{\code{balance_target}}{the values the balance_functions are targeting}
#'     \item{\code{adapt}}{What aspect of the data will be adapted. One of "none","weights", or "x".}
#'     \item{\code{device}}{the \code{\link[torch:torch_device]{torch::torch_device}} of the data.}
#'     \item{\code{dtype}}{the \link[torch:torch_dtype]{torch::torch_dtype} of the data.}
#'     \item{\code{n}}{the rows of the covariates, x.}
#'     \item{\code{d}}{the columns of the covariates, x.}
#'     \item{\code{probability_measure}}{is the measure a probability measure?}
#'   }
#'   \if{html}{\out{</div>}}
#' @details # Active bindings
#'   \if{html}{\out{<div class="r6-active-bindings">}}
#'   \describe{
#'     \item{\code{grad}}{gets or sets gradient}
#'     \item{\code{init_weights}}{returns the initial value of the weights}
#'     \item{\code{init_data}}{returns the initial value of the data}
#'     \item{\code{requires_grad}}{checks or turns on/off gradient}
#'     \item{\code{weights}}{gets or sets weights}
#'     \item{\code{x}}{Gets or sets the data}
#'   }
#'   \if{html}{\out{</div>}}
#' @details # Methods
#' \subsection{Public methods}{
#' \itemize{
#' \item \href{#method-Measure-detach}{\code{Measure$detach()}}
#' \item \href{#method-Measure-get_weight_parameters}{\code{Measure$get_weight_parameters()}}
#' \item \href{#method-Measure-clone}{\code{Measure$clone()}}
#' }
#' }
#' \if{html}{\out{<hr>}}
#' \if{html}{\out{<a id="method-Measure-detach"></a>}}
#' \if{latex}{\out{\hypertarget{method-Measure-detach}{}}}
#' \subsection{Method \code{detach()}}{
#' generates a deep clone of the object without gradients.
#' \subsection{Usage}{
#' \if{html}{\out{<div class="r">}}\preformatted{Measure$detach()}\if{html}{\out{</div>}}
#' }
#' }
#' \if{html}{\out{<hr>}}
#' \if{html}{\out{<a id="method-Measure-get_weight_parameters"></a>}}
#' \if{latex}{\out{\hypertarget{method-Measure-get_weight_parameters}{}}}
#' \subsection{Method \code{get_weight_parameters()}}{
#' Makes a copy of the weights parameters.
#' \subsection{Usage}{
#' \if{html}{\out{<div class="r">}}\preformatted{Measure$get_weight_parameters()}\if{html}{\out{</div>}}
#' }
#' }
#' \if{html}{\out{<hr>}}
#' \if{html}{\out{<a id="method-Measure-clone"></a>}}
#' \if{latex}{\out{\hypertarget{method-Measure-clone}{}}}
#' \subsection{Method \code{clone()}}{
#' The objects of this class are cloneable with this method.
#' \subsection{Usage}{
#' \if{html}{\out{<div class="r">}}\preformatted{Measure$clone(deep = FALSE)}\if{html}{\out{</div>}}
#' }
#' \subsection{Arguments}{
#' \if{html}{\out{<div class="arguments">}}
#' \describe{
#'   \item{\code{deep}}{Whether to make a deep clone.}
#' }
#' \if{html}{\out{</div>}}
#' }
#' }
#' @examples 
#' if(torch::torch_is_installed()) {
#' m <- Measure(x = matrix(0, 10, 2), adapt = "none")
#' print(m)
#' m$x
#' m$x <- matrix(1,10,2) # must have same dimensions
#' m$x
#' m$weights
#' m$weights <- 1:10/sum(1:10)
#' m$weights
#' 
#' # with gradients
#' m <- Measure(x = matrix(0, 10, 2), adapt = "weights")
#' m$requires_grad # TRUE
#' m$requires_grad <- "none" # turns off
#' m$requires_grad # FALSE
#' m$requires_grad <- "x"
#' m$requires_grad # TRUE
#' m <- Measure(matrix(0, 10, 2), adapt = "none")
#' m$grad # NULL
#' m <- Measure(matrix(0, 10, 2), adapt = "weights")
#' loss <- sum(m$weights * 1:10)
#' loss$backward()
#' m$grad
#' # note the weights gradient is on the log softmax scale
#' #and the first parameter is fixed for identifiability
#' m$grad <- rep(1,9)  
#' m$grad
#' }
#' @export
Measure <- function(x, 
                    weights = NULL, 
                    probability.measure = TRUE, 
                    adapt = c("none","weights", "x"), 
                    balance.functions = NA_real_,
                    target.values = NA_real_,
                    dtype = NULL, 
                    device = NULL) {
  # Thin user-facing constructor: every argument is forwarded by name,
  # unchanged, to the internal R6 generator.
  measure <- Measure_$new(
    x = x,
    weights = weights,
    probability.measure = probability.measure,
    adapt = adapt,
    balance.functions = balance.functions,
    target.values = target.values,
    dtype = dtype,
    device = device
  )
  
  measure
}

#' Internal function to select appropriate loss function
#'
#' @description Selects sinkhorn or energy distance losses depending on value
#' of penalty parameter
#' 
#' @param ot an OT object
#' @keywords internal
oop_loss_select <- function(ot) {
  # Dispatch on the penalty stored in the OT object: a finite penalty uses
  # sinkhorn_dist(), an infinite one uses inf_sinkhorn_dist().
  penalty <- ot$penalty
  
  if (is.finite(penalty)) {
    return(sinkhorn_dist(ot))
  }
  
  if (is.infinite(penalty)) {
    return(inf_sinkhorn_dist(ot))
  }
  
  # Neither finite nor infinite (e.g. NA): nothing is computed.
  invisible(NULL)
}

#' @name OTProblem_-class
#' @title An R6 class to construct OTProblems
#' @rdname OTProblem_-class
#' @description OTProblem R6 class
#' @keywords internal
OTProblem_ <- R6::R6Class("OTProblem",
 public = {list(
   # objects
   
   #' @field device the [torch::torch_device()] of the data.
   device = "torch_device",
   
   #' @field dtype the [torch::torch_dtype] of the data.
   dtype  = "torch_dtype",
   
   #' @field selected_delta the delta value selected after `choose_hyperparameters`
   selected_delta = "numeric", # final delta
   
   #' @field selected_lambda the lambda value selected after `choose_hyperparameters`
   selected_lambda = "numeric", # final lambda

   # functions
   
   #' @param o2 A number or object of class OTProblem
   #' @description adds `o2` to the OTProblem
   add = function(o2) {
     private$unaryop(o2, "+")
   },
   
   #' @param o2 A number or object of class OTProblem
   #' @description subtracts `o2` from OTProblem
   subtract = function(o2) {
     private$unaryop(o2, "-")
   },
   
   #' @param o2 A number or object of class OTProblem
   #' @description multiplies OTProblem by `o2`
   multiply = function(o2) {
     private$unaryop(o2, "*")
   },
   
#' @param o2 A number or object of class OTProblem
#' @description divides OTProblem by `o2`
   divide = function(o2) {
     private$unaryop(o2, "/")
   },

#' @description prints the OT problem object
#' @param ... Not used
print = function(...) {
  obj <- rlang::expr_text(private$objective)
  obj <- gsub("oop_loss_select", "OT", obj)
  obj <- gsub('private$ot_objects[[\"', "", obj, fixed = TRUE)
  obj <- gsub('\"]]', "", obj, fixed = TRUE)
  obj <- gsub("\n    ", "", obj)
  obj <- gsub("+ ", "+\n  ", obj, fixed = TRUE)
  obj <- gsub("- ", "-\n  ", obj, fixed = TRUE)
  cat("OT Problem: \n")
  cat("  ", obj, "\n", sep = "")
},

#' @description Constructor method
#' @param measure_1 An object of class [Measure]
#' @param measure_2 An object of class [Measure]
#' @param ... Not used at this time 
#'
#' @return An R6 object of class "OTProblem"
initialize = function(measure_1, measure_2) {
  # browser()
  add_1 <- rlang::obj_address(measure_1)
  add_2 <- rlang::obj_address(measure_2)
  addresses <- c(add_1, add_2)
  address_names <- paste0(addresses, collapse = ", ")
  
  stopifnot("argument 'measure_1' must be of class 'Measure'" = inherits(measure_1, "Measure"))
  stopifnot("argument 'measure_2' must be of class 'Measure'" = inherits(measure_2, "Measure"))
  
  dtype <- measure_1$dtype
  device <- measure_1$device
  
  if(isFALSE(measure_2$dtype == dtype) ) {
    stop(sprintf("Measures must have same data type! measure_1 is of type %s, while measure_2 is of type %s.", dtype, measure_2$dtype))
  }
  if (!(measure_2$device == device) ) { # can't use != with torch device
    stop(sprintf("Measures should be on same device! measure_1 is is on device %s, while measure_2 is on device %s.", device, measure_2$device) )
  }
  if (measure_1$d != measure_2$d) {
    stop(sprintf("Measures should have the same number of columns! measure_1 has %s columns, while measure_2 has %s columns.", measure_1$d, measure_2$d))
  }
  
  self$dtype <- dtype
  self$device <- device
  
  #environment with names as obj_add, and measures as elements of environment
  private$measures <- rlang::env(!!add_1 := measure_1, !!add_2 := measure_2)
  
  # env with names as obj_add1, obj_add2 (sorted),
  #contains vector with c(obj_add1, obj_add2)
  private$problems <- rlang::env(!!address_names := addresses)
  
  # envionrment with names as obj_add1, obj_add2 (sorted), then a list with f, g duals
  # self$duals <- rlang::env(!!address_names := list(!!addresses[1] := torch::torch_zeros(private$measures[[addresses[1] ]]$n, device = device, dtype = dtype)),
  #                          !!addresses[2] := torch::torch_zeros(private$measures[[addresses[2] ]]$n, device = device, dtype = dtype) )
  
  # ot_objects
  # envionrment with names as obj_add1, obj_add2 (sorted), with OT class objects
  private$ot_objects <- rlang::env()
  
  # target_objects
  # envionrment with names as obj_add1, obj_add2 (sorted), with list of balance.functions, means and delta values
  private$target_objects <- rlang::env()
  
  #penalty list
  private$penalty_list <- list(lambda = NA_real_, delta = NA_real_)
  
  # parameter list initialize
  private$parameters <- list()
  
  
  private$objective <- rlang::expr(
    oop_loss_select(private$ot_objects[[!!address_names]])
  )
  
  private$args_set <- FALSE
  private$opt <- private$sched <- NULL
  
  return(invisible(self))
},

#' @param lambda The penalty parameters to try for the OT problems. If not provided, function will select some
#' @param delta The constraint parameters to try for the balance function problems, if any
#' @param grid.length The number of hyperparameters to try if not provided
#' @param cost.function The cost function for the data. Can be any function that takes arguments `x1`, `x2`, `p`. Defaults to the Euclidean distance
#' @param p The power to raise the cost matrix by. Default is 2
#' @param cost.online Should online costs be used? Default is "auto" but "tensorized" stores the cost matrix in memory while "online" will calculate it on the fly.
#' @param debias Should debiased OT problems be used? Defaults to TRUE
#' @param diameter Diameter of the cost function.
#' @param ot_niter Number of iterations to run the OT problems
#' @param ot_tol The tolerance for convergence of the OT problems
#'
#' @return The object itself (`self`), invisibly.
setup_arguments = function(lambda, delta, 
                           grid.length = 7L,
                           cost.function = NULL, 
                           p = 2,
                           cost.online = "auto",
                           debias = TRUE,
                           diameter = NULL, ot_niter = 1000L,
                           ot_tol = 1e-3) {
  # Prepares the state that `solve` relies on:
  #   1. one OT object per registered problem (pair of measures),
  #   2. the OT solver iteration/tolerance settings,
  #   3. balance-target objects for adaptable measures that carry targets,
  #   4. the list of trainable parameters,
  #   5. the lambda and delta penalty grids.
  # Returns `self` invisibly.
  
  prob_names <- ls(private$problems)
  if (private$args_set) warning("OT problems already set up. This function will erase previous objects")
  
  # ---- OT objects ----
  problem_1 <- problem_2 <- measure_1 <- measure_2 <- NULL
  device_vector <- NULL
  not_warned <- TRUE # warn about mixed devices at most once
  devices_differ <- FALSE
  
  for (v in prob_names) {
    problem_1 <- private$problems[[v]][[1]]
    problem_2 <- private$problems[[v]][[2]]
    measure_1 <- private$measures[[problem_1]]
    measure_2 <- private$measures[[problem_2]]
    # `penalty` is only a placeholder value at construction time
    private$ot_objects[[v]] <- OT$new(x = measure_1$x$detach(),
                                      y = measure_2$x$detach(),
                                      a = measure_1$weights$detach(),
                                      b = measure_2$weights$detach(),
                                      penalty = 10.0, 
                                      cost_function = cost.function, 
                                      p = p, 
                                      debias = debias, 
                                      tensorized = cost.online,
                                      diameter = diameter,
                                      device = self$device,
                                      dtype = self$dtype)
    
    # Compare against the device seen on the previous problem. Nothing to
    # compare on the first pass; a failing comparison is treated as
    # "not known to differ" so this check can never abort the setup.
    devices_differ <- !is.null(device_vector) &&
      isTRUE(tryCatch(!(device_vector == measure_1$device),
                      error = function(e) FALSE))
    if (not_warned && devices_differ) {
      warning("All measures not on same device. This could slow things down.")
      not_warned <- FALSE
    } else {
      device_vector <- measure_1$device
    }
  }
  
  # ---- OT optimisation settings ----
  private$ot_niter <- as.integer(ot_niter)
  private$ot_tol <- as.numeric(ot_tol)
  
  # ---- balance targets ----
  # Only measures that are adapted and have both balance functions and
  # balance targets get a target object.
  measure_addresses <- ls(private$measures)
  delta_set <- NA_real_
  not_na_bt <- not_na_bf <- FALSE
  meas <- NULL
  for (v in measure_addresses) {
    meas <- private$measures[[v]]
    not_na_bt <- !all(is.na(meas$balance_target))
    not_na_bf <- !all(is.na(meas$balance_functions))
    if (not_na_bt && not_na_bf && meas$adapt != "none") {
      # balance functions that carry gradients get delta fixed at 0;
      # otherwise delta is left unset (NA) for later selection
      delta_set <- if (meas$balance_functions$requires_grad) 0.0 else NA_real_
      private$target_objects[[v]] <- list(
        bf = meas$balance_functions,
        bt = meas$balance_target,
        delta = delta_set)
    }
  }
  
  # ---- trainable parameters ----
  adapt <- NULL
  for (v in measure_addresses) {
    meas <- private$measures[[v]]
    adapt <- meas$adapt
    if (adapt != "none") {
      private$parameters[[v]] <-
        switch(adapt,
               "x" = meas$.__enclos_env__$private$data_,
               "weights" = meas$.__enclos_env__$private$mass_)
    }
  }
  
  # make sure ot_args ok
  stopifnot("'ot_niter' must be > 0" = private$ot_niter > 0)
  stopifnot("'ot_tol' must be > 0" = private$ot_tol > 0)
  
  # ---- lambda (OT penalty) grid ----
  names_ot <- ls(private$ot_objects)
  diameters <- numeric(length(names_ot))
  for (i in seq_along(names_ot)) {
    diameters[i] <- private$ot_objects[[names_ot[i]]]$diameter
  }
  max_diameter <- max(diameters)
  stopifnot("The maximum diameter of the OT problem is not finite!" = is.finite(max_diameter))
  l_md <- log(max_diameter)
  stopifnot("The log of the maximum diameter of the OT problem is not finite!" = is.finite(l_md))
  if ( missing(lambda) || is.null(lambda) || all(is.na(lambda)) ) {
    lambda <- c( #log(max_diam * 1e-6) to log(max_diam * 1e4), about
      exp(seq(l_md - 13.82, l_md + 9.21, length.out = grid.length)),
      Inf)
  } 
  # Sort first (this also drops NAs), then replace exact zeros. The mask
  # must be computed on the sorted vector so that it lines up with the
  # entries being replaced.
  lambda_sorted <- sort(lambda, decreasing = TRUE)
  lambda_sorted[lambda_sorted == 0] <- max_diameter / 1e9
  private$penalty_list$lambda <- lambda_sorted
  
  # ---- delta (balance constraint) grid ----
  if (length(private$target_objects) != 0) {
    
    if ( missing(delta) || is.null(delta) || all(is.na(delta)) ) {
      # default grid: from 1e-4 up to the largest absolute imbalance under
      # the current weights (only measures adapting their weights count)
      names_TO <- ls(private$target_objects)
      diffs    <- numeric(length(names_TO))
      measure <- target <- weighted_bf <- NULL
      for (i in seq_along(names_TO)) {
        v <- names_TO[i]
        measure <- private$measures[[v]]
        if (measure$adapt == "weights") {
          target <- private$target_objects[[v]]
          weighted_bf <- target$bf * measure$weights$detach()$view(c(measure$n, 1))
          diffs[i] <- (weighted_bf$sum(1) - target$bt)$abs()$max()$item()
        }
      }
      max_diffs <- max(diffs)
      
      delta <- c(seq(1e-4, max_diffs, length.out = grid.length))
    }
    
    private$penalty_list$delta <- sort(as.numeric(delta), decreasing = TRUE)
    
    stopifnot("'delta' values must be >=0."=all(private$penalty_list$delta >= 0))
  }
  
  # flag to warn about overwrite next time function called
  private$args_set <- TRUE
  
  return(invisible(self))
},

#' @description Solve the OTProblem at each parameter value. Must run setup_arguments first.
#' @param niter The number of iterations to run solver at each combination of hyperparameter values 
#' @param tol The tolerance for convergence
#' @param optimizer The optimizer to use. One of "torch" or "frank-wolfe"
#' @param torch_optim The `torch_optimizer` to use. Default is [torch::optim_lbfgs]
#' @param torch_scheduler The [torch::lr_scheduler] to use. Default is [torch::lr_reduce_on_plateau]
#' @param torch_args Arguments passed to the torch optimizer and scheduler
#' @param osqp_args Arguments passed to [osqp::osqpSettings()] if appropriate
#' @param quick.balance.function Should [osqp::osqp()] be used to select balance function constraints (delta) or not. Default true.
  solve = function(niter = 1000L, tol = 1e-5, optimizer = c("torch", "frank-wolfe"),
                   torch_optim = torch::optim_lbfgs,
                   torch_scheduler = torch::lr_reduce_on_plateau,
                   torch_args = NULL,
                   osqp_args = NULL,
                   quick.balance.function = TRUE) {
    
    # --- preconditions -------------------------------------------------
    # setup_arguments must have been run, and the stopping rules must be sane
    stopifnot("arguments not set! Run '$setup_arguments' first." = private$args_set)
    stopifnot("Argument `niter` must be > 0" = (niter > 0))
    stopifnot("Argument `tol` must be >=0" = (tol >= 0))
    
    # --- osqp settings -------------------------------------------------
    # keep only the entries that osqp::osqpSettings actually accepts
    is_osqp_setting <- names(osqp_args) %in% methods::formalArgs(osqp::osqpSettings)
    private$osqp_args <- osqp_args[is_osqp_setting]
    
    # --- delta feasibility / quick delta selection ---------------------
    private$delta_values_setup(run.quick = quick.balance.function,
                               osqp_args = private$osqp_args)
    
    # --- optimizer -----------------------------------------------------
    optimizer <- match.arg(optimizer)
    stopifnot("Optimizer must be one of 'torch' or 'frank-wolfe'." = (optimizer %in% c("torch", "frank-wolfe") ))
    
    switch(optimizer,
           "torch" = private$torch_optim_setup(torch_optim,
                                               torch_scheduler,
                                               torch_args),
           "frank-wolfe" = {
             private$frankwolfe_setup()
             private$optimization_step <- private$frankwolfe_step
           },
           stop("Optimizer must be one of 'torch' or 'frank-wolfe'."))
    
    # --- result holders, one slot per lambda value ---------------------
    # weights, iteration counts and final losses are all stored in lists
    # named by the lambda value they belong to
    lambda_labels <- as.character(private$penalty_list$lambda)
    n_lambda <- length(private$penalty_list$lambda)
    
    private$weights_list <- vector("list", n_lambda)
    names(private$weights_list) <- lambda_labels
    
    private$iterations_run <- setNames(vector("list", n_lambda), lambda_labels)
    private$final_loss <- setNames(vector("list", n_lambda), lambda_labels)
    
    # --- run -----------------------------------------------------------
    # the outer loop over lambda (and inner loop over delta) lives in the
    # private helper
    private$iterate_over_lambda(niter, tol)
    torch_cubic_reassign()
    return(invisible(self))
  },

  #' @param n_boot_lambda The number of bootstrap iterations to run when selecting lambda
  #' @param n_boot_delta The number of bootstrap iterations to run when selecting delta
  #' @param lambda_bootstrap The penalty parameter to use when selecting lambda. Higher numbers run faster.
  #'
  #' @description Selects the hyperparameter values through a bootstrap algorithm
  choose_hyperparameters =  function(n_boot_lambda = 100L, n_boot_delta = 1000L, lambda_bootstrap = Inf) {
    # Selects delta (per lambda) and then lambda by averaging bootstrap
    # evaluation metrics, and writes the selected weights back to the
    # measures. With zero bootstrap iterations the metrics stay at zero and
    # the first candidate is selected.
    
    # check arguments
    stopifnot("n_boot_lambda must be >= 0"= n_boot_lambda>=0)
    stopifnot("n_boot_delta must be >= 0"= n_boot_delta>=0)
    stopifnot("lambda_bootstrap must be > 0"=lambda_bootstrap>0)
    
    # alter lists if needed for inherited classes
    res <- private$setup_choose_hyperparameters()
    
    # current candidate values and the weights estimated for them
    delta_values  <- res$delta
    lambda_values <- res$lambda
    weights_list  <- res$weights_list
    n_delta  <- length(delta_values)
    n_lambda <- length(lambda_values)
    
    # setup final metrics list
    private$weights_metrics <- list(delta = vector("list", n_lambda),
                                    lambda= vector("list", n_lambda))
    
    # check if function already ran before
    if (is.numeric(self$selected_lambda)) {
      warning(sprintf("Lambda value of %s previously selected. This function will erase previous bootstrap selection", self$selected_lambda))
    }
    
    if (n_delta > 1) { # begin delta select
      # running mean of the delta metric, one numeric vector per lambda
      lambda_temp_metric <- vector("list", n_lambda)
      for (l in seq_along(lambda_values)) {
        lambda_temp_metric[[l]] <- vector("numeric", n_delta)
      }
      
      self$selected_delta <- vector("numeric", n_lambda) 
      
      boot_measure <- delta_temp_metric <- NULL
      
      # seq_len() so that n_boot_delta = 0 runs no iterations at all
      # (1:0 would run twice and divide by zero)
      for (i in seq_len(n_boot_delta)) {
        boot_measure <- private$draw_boot_measure()
        for (l in seq_along(lambda_values)) {
          delta_temp_metric <- vapply(weights_list[[l]],
                                      FUN = private$eval_delta,
                                      FUN.VALUE = 0.0,
                                      boot = boot_measure
          )
          lambda_temp_metric[[l]] <- lambda_temp_metric[[l]] + delta_temp_metric/n_boot_delta
        }
      }
      
      # assign metrics back to final location
      private$weights_metrics$delta <- lambda_temp_metric
      
      # for each lambda keep only the weights of the best-scoring delta
      min_d_idx <- NULL
      for (l in seq_along(lambda_values)) {
        min_d_idx <- which.min(lambda_temp_metric[[l]])[1]
        weights_list[[l]] <-  weights_list[[l]][[min_d_idx]]
        self$selected_delta[[l]] <- delta_values[[min_d_idx]]
      }
      
    } else {
      # a single delta candidate: unwrap it for every lambda
      new_weights_list <- vector("list", length(weights_list))
      for (l in seq_along(lambda_values)) {
        new_weights_list[[l]] <- weights_list[[l]][[1L]]
      }
      weights_list <- new_weights_list
      
      # report the delta stored on each target object
      if (length(private$target_objects) > 0) {
        self$selected_delta <- list(rep(NA_real_, length(private$target_objects)))
        names_target <- ls(private$target_objects)
        for (v in seq_along(private$target_objects) ) {
          self$selected_delta[[1L]][[v]] <- private$target_objects[[names_target[v] ]]$delta
        }
        names(self$selected_delta[[1L]]) <- names_target
      }
    }# end of delta selection
    
    if (n_lambda > 1) {
      
      # evaluate every candidate under the common bootstrap penalty
      private$set_lambda(lambda_bootstrap)
      
      # setup holder
      lambda_metrics <- rep(0.0, n_lambda)
      boot_measure <- NULL
      
      # seq_len() for the same reason as in the delta loop
      for (i in seq_len(n_boot_lambda)) {
        boot_measure <- private$draw_boot_measure()
        lambda_metrics <- vapply(X = weights_list,
                                 FUN = private$eval_lambda,
                                 FUN.VALUE = 0.0,
                                 boot = boot_measure)/n_boot_lambda +
          lambda_metrics
      }
      # choose final lambda value
      idx_lambda <- which.min(lambda_metrics)
      selected_lambda <- as.numeric(lambda_values)[idx_lambda]
      
      # set metrics
      private$weights_metrics$lambda <- lambda_metrics
      
      # pull out final wts
      weights_list <- weights_list[[idx_lambda]]
      
      # save selected lambda
      self$selected_lambda <- selected_lambda
      private$set_lambda(selected_lambda)
      
      if (length(self$selected_delta) > 1) self$selected_delta <- self$selected_delta[[idx_lambda]]
    } else {
      weights_list <- weights_list[[1L]]
      self$selected_lambda <- as.numeric(lambda_values)[1L]
      private$set_lambda(self$selected_lambda)
    }
    
    # set weights back to the measures
    private$parameters_get_set(value = weights_list)
  },

#' @description Provides diagnostics after solve and choose_hyperparameter methods have been run.
#'
#' @return a list with slots
#' \itemize{
#' \item `loss` the final loss values
#' \item `iterations` The number of iterations run for each combination of parameters
#' \item `balance.function.differences` The final differences in the balance functions
#' \item `hyperparam.metrics` A list of the bootstrap evaluation for delta and lambda values}
   info = function(){
     # Row-binds a per-lambda list holder into one object; anything that is
     # not a list (e.g. the un-run placeholder) yields NULL.
     stack_rows <- function(holder) {
       if (!is.list(holder)) {
         return(NULL)
       }
       do.call("rbind", holder)
     }
     
     loss_table <- stack_rows(private$final_loss)
     
     # Bootstrap metrics: while the slot is still a character placeholder
     # report a message; otherwise column-bind the per-lambda delta metrics.
     hyper <- private$weights_metrics
     if (is.character(hyper)) {
       hyper <- "Hyperparameters not selected yet"
     } else {
       hyper["delta"] <- list(do.call("cbind", hyper$delta))
     }
     
     # current balance-function differences, as a plain list
     balance <- as.list(private$balance_check())
     
     iteration_table <- stack_rows(private$iterations_run)
     
     return(list(loss = loss_table,
                 iterations = iteration_table,
                 balance.function.differences = balance,
                 hyperparam.metrics = hyper))
   }
   
 )},
 active = {list(
#' @field loss prints the current value of the objective. Only available after the solve method has been run
   loss = function() {
     # refresh the OT problems, then evaluate the stored objective
     # expression and move the result to this object's device
     private$ot_update()
     objective_value <- eval(private$objective)
     objective_value$to(device = self$device)
   },
   
   #' @field penalty Returns a list of the lambda and delta penalties that will be iterated through. To set these values, use the `setup_arguments` function.
   penalty = function() {
     # read-only view of the lambda/delta grids
     private$penalty_list
   }
 )},
 private = {list(
   # objects
   args_set = "logical",
   final_loss = "list",
   iterations_run = "list",
   lbfgs_reset = 0L,
   lbfgs_count = 0L,
   # @field measures An environment of measure objects named by address
   measures = "env", # environment of measure objects, named by address
   # @field objective An [rlang::expr()] giving the objective function
   objective = "rlang",
   opt_calls = "list", #saves arguments for the optimizers to reset
   opt = "R6", # store optimizer so easier to reset for LBFGS
   osqp_args = "list",
   ot_niter = "integer",
   ot_objects = "env", # environment with ot R6 classes
   ot_tol = "numeric",
   parameters = "list", # list of the parameters of model
   # @field problems An environment giving the object addresses for each measure in the OT problems
   problems = "env", # character vector of object addresses crossed
   penalty_list = "list", # list of penalties to try
   sched = "R6", # store scheduler
   target_objects = "env", # balance target objects
   weights_metrics = "list", # list of bootstrap metric for each penalty
   weights = "weightEnv", # holder of transformed parameters that are weights
   weights_list = "list", # list of estimated weights for each penalty
   
   # functions
   # adds new element to specified environment
   append = function(o, env_name) {
     # Copies into this object's private `env_name` container every entry
     # of `o`'s container of the same name that is not already present
     # here. Existing entries are never overwritten.
     theirs <- o$.__enclos_env__$private[[env_name]]
     missing_here <- setdiff(ls(theirs), ls(private[[env_name]]))
     for (nm in missing_here) {
       private[[env_name]][[nm]] <- theirs[[nm]]
     }
   },
   
   # update balance constraint parameters
   bal_param_update = function(osqp_args = NULL, tol = 1e-7) {
     # Computes the penalty term for balance-constraint violations that is
     # added to the loss.
     #
     # osqp_args: accepted for interface compatibility; not used here.
     # tol:       a target's penalty is only added when its violation,
     #            relative to delta, exceeds this value.
     #
     # Returns a torch tensor holding the summed penalty (zero when there
     # are no targets or no violations).
     
     # save weights in case not already done
     private$weights <- private$get_weights()
     
     loss <- torch::torch_tensor(0.0, dtype = self$dtype, device = self$device)
     
     # nothing to penalise without balance targets
     if (length(private$target_objects) == 0L) {
       return(loss)
     }
     
     machine_tol <- .Machine$double.xmin # keeps the ratio finite when delta is 0
     coef <- torch::torch_tensor(10000.0, dtype = self$dtype, device = self$device)
     
     # variables in for loop
     v <- w <- delta <- bal <- d <- NULL
     ng_bal <- abs_bal <- const.viol <- idx_max <- sign_max <- gamma <- cur_loss <- NULL
     
     # run through target means
     for (adds in ls(private$target_objects)) {
       # get objects
       v <- private$target_objects[[adds]]
       w <- private$weights[[adds]]
       
       # weighted balance functions minus their targets
       delta <- v$delta
       bal <- v$bf$transpose(2,1)$matmul(w) - v$bt
       
       if (v$bf$requires_grad) {
         # balance functions that carry gradients are re-centred in place
         # (outside the autograd graph) instead of being penalised
         d <- ncol(v$bf)
         torch::with_no_grad(
           v$bf$subtract_(bal$view(c(1L,d)))
         )
         next
       }
       
       # Approximate dual variable: zero everywhere except at the coordinate
       # with the largest absolute imbalance, where it equals the sign of
       # that imbalance if it exceeds delta (and zero otherwise).
       ng_bal     <- bal$detach()
       abs_bal    <- ng_bal$abs()
       const.viol <- abs_bal > delta
       idx_max    <- abs_bal$argmax()
       sign_max   <- ng_bal[idx_max]$sign()
       gamma      <- torch::torch_zeros_like(ng_bal)
       gamma[idx_max] <- sign_max * const.viol[idx_max]
       
       # update loss only when the violation is non-negligible
       cur_loss <- bal$dot(gamma) - sum(abs(gamma)) * delta 
       if ( cur_loss$item() / (machine_tol + delta) > tol ) {
         loss <- loss + cur_loss * coef
       }
     }
     
     return(loss)
   },
   
   # check balance constraints
   balance_check = function() {
     # Reports, for every balance target, the current difference between
     # the weighted balance functions and their target together with the
     # target's delta. The result is an environment keyed by measure
     # address; it is returned (empty) even when there are no targets.
     ret <- rlang::env()
     
     target <- wts <- gap <- NULL
     for (adds in ls(private$target_objects)) {
       target <- private$target_objects[[adds]]
       wts    <- private$measures[[adds]]$weights
       
       gap <- target$bf$transpose(2,1)$matmul(wts) - target$bt
       
       # balance functions that carry gradients are reported on a
       # standardised scale
       if (target$bf$requires_grad) {
         gap <- gap/target$bf$std(1)
       }
       
       ret[[adds]] <- list(balance = gap, delta = target$delta)
     }
     
     return(ret)
   },
   
   # deep clone function
   deep_clone = function(name, value) {
     # Fields stored as environments are copied into a fresh environment
     # (with an empty parent); every other field is passed through as is.
     env_fields <- c("problems", "target_objects", "measures", "ot_objects")
     if (!(name %in% env_fields)) {
       return(value)
     }
     list2env(as.list.environment(value, all.names = TRUE),
              parent = emptyenv())
   },
   
   delta_values_setup = function(run.quick = TRUE, osqp_args = NULL) {
     # Prepares the delta (balance constraint) values before solving.
     #
     # run.quick = FALSE: keep only the global delta values for which the
     #   balancing problem can be solved for every target.
     # otherwise: pick one delta per target by a bootstrap criterion, store
     #   it on the target object and set the global delta grid to NA.
     #
     # Does nothing unless there are balance targets and more than one
     # candidate delta. Targets whose balance functions carry gradients are
     # skipped in both modes.
     
     # check if any target_objects
     n_tf    <- length(private$target_objects)
     
     # check if ndelta>1
     n_delta <- length(private$penalty_list$delta)
     
     # default to not verbose
     if (is.null(osqp_args)) {
       osqp_args <- list(verbose = FALSE)
     }
     
     # run if are target_objects
     if (n_tf > 0 && n_delta > 1) {
       
       names_d <- as.character(private$penalty_list$delta)
       target_addresses <- ls(private$target_objects)
       d <- bf <- w <- v <- m <- NULL
       
       # if check all deltas for all lambdas
       if (isFALSE(run.quick)) {
         successful.deltas <- feasible <- NULL
         
         for (addy in  target_addresses) {
           
           v  <- private$target_objects[[addy]]
           m  <- private$measures[[addy]]
           
           if(v$bf$requires_grad) next
           
           bf <- SBW_4_oop$new(source = v$bf,
                               target = v$bt,
                               prob.measure = m$probability_measure,
                               osqp_opts = osqp_args)
           
           # Collect the deltas that work for THIS target only. `names_d`
           # already holds just the deltas that worked for all earlier
           # targets, so the running result is an intersection: a delta
           # infeasible for any target is dropped.
           feasible <- NULL
           for (nd in names_d) {
             d <- as.numeric(nd)
             if (!is.null(bf$solve(d))) {
               feasible <- c(feasible, d)
             }
           }
           successful.deltas <- sort(unique(feasible), decreasing = TRUE)
           names_d  <- as.character(successful.deltas)
         }
         
         # set feasible deltas
         private$penalty_list$delta <- successful.deltas
         
       } else { # quick run
         
         # selects best delta for each target
         means <- NULL
         nboot <- 1000L
         
         # check for each target
         for (addy in  target_addresses) {
           w        <- vector("list", n_delta)
           names(w) <- names_d
           
           v  <- private$target_objects[[addy]]
           m  <- private$measures[[addy]]
           
           if(v$bf$requires_grad) next
           
           bf <- SBW_4_oop$new(source = v$bf$to(device = "cpu"),
                               target = v$bt$to(device = "cpu"),
                               prob.measure = m$probability_measure,
                               osqp_opts = osqp_args)
           
           # solve at every candidate delta; a NULL result removes that
           # entry from `w`, so only solved candidates are scored
           for (nd in names_d) {
             d <- as.numeric(nd)
             w[[nd]] <- bf$solve(d)
           }
           
           # bootstrap score for each remaining candidate
           means <- sbw_oop_bs_(w, 
                                nboot, 
                                as.matrix(bf$source_scale$to(device = "cpu")), 
                                as.numeric(bf$target_scale$to(device = "cpu")), 
                                as.numeric(m$init_weights$to(device = "cpu")))
           
           # keep the delta with the smallest score for this target
           private$target_objects[[addy]]$delta <- as.numeric(names(w)[which.min(means)[1L]])
           
         }
         
         # set global delta to NA, which will keep the individualized ones
         private$penalty_list$delta <- NA_real_ 
       }
     }
   },
   
   # make sure no barycenters!!
   frankwolfe_setup = function() {
     # Guard for the Frank-Wolfe optimizer: abort if any trainable
     # parameter belongs to a measure that adapts its support ("x").
     for (addr in ls(private$parameters)) {
       adapts_support <- private$measures[[addr]]$adapt == "x"
       if (adapts_support) stop("Our Frank-Wolfe algorithm can't handle barycenters. Depending on your problem, you can optimize the barycenter and then optimize the weights with Frank-Wolfe. Alternatively, you can use the torch optimizers directly.")
     }
   },
   
   # frankwolfe optimization algorithm
   # One Frank-Wolfe iteration over the trainable weights:
   #   1. evaluate the loss and back-propagate to get weight gradients,
   #   2. for each parameter solve an osqp problem whose linear term is
   #      that gradient, subject to non-negativity, the mass constraint and
   #      (where present) the balance constraints,
   #   3. move from the old weights towards that solution with a step size
   #      chosen by `scalar_search_armijo`.
   # Arguments:
   #   opt       - not used in this function body
   #   osqp_args - list of settings forwarded to osqp::osqpSettings
   #   tol       - the step is only attempted when the negative directional
   #               derivative exceeds this value
   # Returns the (detached) loss after the step, or the initial loss when
   # no step was taken.
   frankwolfe_step = function(opt, osqp_args, tol) {
     
     # get weights and retain grad
     # can probably be deleted
     # (not called anywhere in this function)
     weight_setup <- function() {
       private$weights <- private$get_weights()
       pw <- NULL
       for(w in ls(private$weights)) {
         pw <- private$weights[[w]]
         if(pw$requires_grad) pw$retain_grad()
       }
     }
     
     # retain grad
     # asks every weight tensor that requires grad to keep its gradient
     # after backward()
     weight_retain_grad <- function() {
       pw <- NULL
       for(w in ls(private$weights)) {
         pw <- private$weights[[w]]
         if(pw$requires_grad) pw$retain_grad()
       }
     }
     
     # get grad of weights
     # returns a weightEnv holding the gradient of each weight that
     # requires grad, keyed by measure address
     weight_grad <- function() {
       w_names <- ls(private$weights)
       out_grad <- rlang::env()
       class(out_grad) <- c("weightEnv", class(out_grad))
       pw <- NULL
       for(w in w_names) {
         pw <- private$weights[[w]]
         if(pw$requires_grad) out_grad[[w]] <- pw$grad
       }
       
       return(out_grad)
     }
     
     # get addresses
     bf_address    <- ls(private$target_objects) # measures with balance functions
     param_address <- ls(private$parameters) # all parameters/ meas with grads
     
     # setup holders
     # snapshot of the current parameters; used as the line-search origin
     # and restored if no step is taken
     old_weights <- private$parameters_get_set(clone = TRUE)
     # class(old_weights) <- c("weightEnv", class(old_weights))
     
     new_weights <- rlang::env()
     class(new_weights) <- c("weightEnv", class(new_weights))
     
     ## Run functions ##
     
     # zero out gradients
     private$zero_grad()
     
     # eval loss and get gradients
     # weight_setup()
     # private$ot_update() # now in loss fun
     # NOTE: order matters here - the loss is evaluated first, then the
     # weights are told to retain their gradients, then backward() runs
     init_loss          <- self$loss
     weight_retain_grad()
     init_loss$backward()
     
     # setup osqp args
     # build the settings object by splicing the list into osqpSettings()
     osqp_arg_call <- rlang::call2(osqp::osqpSettings, !!!osqp_args)
     osqp_args <- eval(osqp_arg_call)
     
     res <- cur_env <- meas <- osqp_opt <- n <- prob.measure <- sums <- sum_bounds <- bt_cpu <- NULL
     
     
     # get solutions most correlated with the gradients
     for (addy in param_address) {
       meas <- private$measures[[addy]]
       n <- meas$n
       prob.measure <- meas$probability_measure
       # switch() on 1L + logical: first alternative when FALSE, second
       # when TRUE. For a probability measure the extra row sums the
       # weights and is bounded to [1, 1]; otherwise the row is all zeros
       # with bounds [0, Inf].
       sums <- switch(1L + prob.measure,
                      Matrix::Matrix(data = 0, nrow = 1, ncol = n),
                      Matrix::Matrix(data = 1, nrow = 1, ncol = n))
       sum_bounds <- switch(1L + prob.measure,
                            c(0, Inf),
                            c(1, 1))
       if (addy %in% bf_address) {
         # with balance targets: additionally constrain t(bf) %*% w to lie
         # within bt +/- delta
         cur_env <- private$target_objects[[addy ]]
         bt_cpu <- as.numeric(cur_env$bt$to(device = "cpu"))
         osqp_opt <- osqp::osqp(q = as.numeric(private$weights[[addy]]$grad$to(device = "cpu")), 
                                A = rbind(
                                  t(as.matrix(cur_env$bf$to(device = "cpu"))),
                                  Matrix::Diagonal(n, x = 1),
                                  sums
                                ),
                                l = c(-cur_env$delta + bt_cpu, rep(0,n), sum_bounds[1]),
                                u = c(cur_env$delta + bt_cpu, rep(Inf,n), sum_bounds[2]),
                                pars = osqp_args)
       } else {
         # no balance targets: only non-negativity and the sum row
         osqp_opt <- osqp::osqp(q = as.numeric(private$weights[[addy]]$grad$to(device = "cpu")), 
                                A = rbind(
                                  Matrix::Diagonal(n, x = 1),
                                  sums
                                ),
                                l = c(rep(0,n), sum_bounds[1]),
                                u = c(rep(Inf,n), sum_bounds[2]),
                                pars = osqp_args)
       }
       
       res <- osqp_opt$Solve()
       # clip negative entries of the solver output to zero
       res$x[res$x < 0] <- 0
       
       #save into env of weights
       # probability measures go through `renormalize` (defined elsewhere;
       # presumably rescales to sum to one - TODO confirm)
       new_weights[[addy]] <- switch(1L + prob.measure,
         res$x,
         renormalize(res$x))
     }
     
     deriv <- weight_grad()
     # class(deriv) <- c("weightEnv", class(deriv))
     
     # search direction and its directional derivative; these rely on the
     # weightEnv arithmetic/sum methods defined elsewhere in this file
     deltaG <- new_weights - old_weights
     derphi0 <- sum(deltaG * deriv)
     
     # need function to update weights and run update on OT
     # loss as a function of the step size alpha along dx starting at x
     armijo_loss_fun <- function(x, dx, alpha, ...) {
       
       if(inherits(alpha, "torch_tensor")) {
         alpha <- as.numeric(alpha$item())
       }
       # assign linearly shifted weights
       private$weights <- x + dx * alpha
       
       # update OT problems
       # private$ot_update()
       
       # evaluate loss (which updates ot problems)
       loss <- self$loss$detach()
       
       # return value
       return(loss)
     }
     
     # only search when the direction is a sufficient descent direction
     if (-derphi0$item() >  tol) {
       search_res <-  scalar_search_armijo(phi = armijo_loss_fun,
                                           phi0 = init_loss$detach()$item(),
                                           derphi0 = derphi0$item(),
                                           x = old_weights,
                                           dx = deltaG,
                                           c1 = 1e-4, alpha0 = 1.0, 
                                           amin = 0)
       if ( !is.null(search_res$alpha)) {
         # accepted step: clamp alpha to [0, 1] and commit the new weights
         if(inherits(search_res$alpha, "torch_tensor")) search_res$alpha <- as.numeric(search_res$alpha$item())
         search_res$alpha = min(max(search_res$alpha,0), 1)
         new_weights <- old_weights +  deltaG * search_res$alpha
         private$parameters_get_set(new_weights)
         loss <- search_res$phi1$detach()
       } else {
         # line search returned no step size: restore the old weights
         private$parameters_get_set(old_weights)
         loss <- init_loss$detach()
       }
       
     } else {
       # no usable descent direction: restore the old weights
       private$parameters_get_set(old_weights)
       loss <- init_loss$detach()
     }
     
     return(loss)
   },
   
   # Gather the current `weights` of every registered measure into a new
   # environment of class "weightEnv", keyed by the measure's address.
   #
   # clone: accepted for interface compatibility; it is not consulted here.
   # Returns the populated "weightEnv" environment.
   get_weights = function(clone = FALSE) {
     weight_env <- rlang::env()
     class(weight_env) <- c("weightEnv", class(weight_env))
     
     for (addr in ls(private$measures)) {
       weight_env[[addr]] <- private$measures[[addr]]$weights
     }
     
     weight_env
   },
   
   # returns current gradient value for parameters
   # The result is a plain environment holding, for each entry of
   # private$parameters, that entry's `$grad` under the same name.
   grad = function() {
     grad_env <- rlang::env()
     for (addr in ls(private$parameters)) {
       grad_env[[addr]] <- private$parameters[[addr]]$grad
     }
     grad_env
   },
   
   # Run the inner optimization loop once per delta value stored in
   # private$penalty_list$delta and collect the results.
   #
   # niter, tol: forwarded unchanged to private$optimization_loop().
   # Returns a list with
   #   loss    - named numeric vector of final losses, one per delta
   #   iter    - named numeric vector of iterations run, one per delta
   #   weights - named list of cloned parameter environments, one per delta
   iterate_over_delta = function(niter, tol) {
     
     deltas <- private$penalty_list$delta
     n_delta <- length(deltas)
     
     # Label with paste() so an NA delta is named "NA" (as.character() would
     # give a missing name that can never be matched again).
     names_delta <- paste(deltas)
     
     # result holders, one slot per delta
     weights_delta <- vector("list", n_delta)
     names(weights_delta) <- names_delta
     
     iter_delta <- vector("numeric", n_delta) |> setNames(names_delta)
     losses_delta <- vector("numeric", n_delta) |> setNames(names_delta)
     
     # Fill by position rather than by name: this keeps the holders aligned
     # with `deltas` even when a value is NA or appears more than once.
     for (k in seq_len(n_delta)) {
       # set bf constraint if needed
       private$set_delta(deltas[[k]])
       
       # run inner loop that actually optimizes
       res <- private$optimization_loop(niter, tol)
       
       # copy estimated weights to weights_delta holder
       weights_delta[[k]] <- private$parameters_get_set(clone = TRUE)
       
       # save iterations and final losses
       losses_delta[[k]] <- res$loss
       iter_delta[[k]] <- res$iter
     }
     
     return(list(loss = losses_delta, 
                 iter = iter_delta,
                 weights = weights_delta))
     
   },
   
   # Outer loop over the lambda penalties in private$penalty_list$lambda.
   # For each lambda: set it on the OT problems, call ot_update() with
   # only_params = FALSE / use_grad = FALSE, run the delta loop, and store the
   # weights, iteration counts and final losses under as.character(lambda).
   #
   # niter, tol: forwarded unchanged to private$iterate_over_delta().
   iterate_over_lambda = function(niter, tol) {
     for (lambda in private$penalty_list$lambda) {
       key <- as.character(lambda)
       
       # set penalty for OT problems
       private$set_lambda(lambda)
       
       # initialize ot dual
       private$ot_update(only_params = FALSE, get_weights = TRUE, use_grad = FALSE)
       
       # optimize over delta values (calls main workhorse functions)
       delta_res <- private$iterate_over_delta(niter, tol)
       
       # keep the per-delta weights and the diagnostics for this lambda
       private$weights_list[[key]] <- delta_res$weights
       private$iterations_run[[key]] <- delta_res$iter
       private$final_loss[[key]] <- delta_res$loss
     }
   },
   
   # Step the learning-rate scheduler (if one is set) with the latest loss and
   # tell the caller whether it may test for convergence.
   #
   # loss: the loss value of the latest optimization step (optimization_loop
   #       passes the result of `$item()`).
   # Returns a logical `check`:
   #   * TRUE when there is no scheduler, or the scheduler is not an
   #     "lr_reduce_on_plateau" (those are simply stepped).
   #   * for "lr_reduce_on_plateau": see the two branches below.
   lr_reduce = function(loss) { #complicated lr reduce to fix bugs in torch
     sched <- private$sched
     if (is.null(sched)) return(TRUE)
     
     check <- TRUE
     
     if (inherits(sched, "lr_reduce_on_plateau")) {
       # read the current learning rate; if sched$get_lr() errors, fall back to
       # the "lr" entries of the optimizer's param groups
       get_lr <- function() {
         tryCatch(
           sched$get_lr(),
           error = function(e) sapply(sched$optimizer$state_dict()$param_groups, function(p) p[["lr"]])
         )
       }
       check <- FALSE
       
       # temporarily switch to the "abs" threshold mode when the loss exactly
       # equals the scheduler's best value; the old mode is restored below
       old_mode <- sched$threshold_mode
       if (as.logical(loss == sched$best) ) sched$threshold_mode <- "abs"
       
       init_lr <- sched$optimizer$defaults$lr # only referenced by the commented-out test below
       lr <- get_lr()
       min_lr <- sched$min_lrs[[1]]
       
       # record whether this loss beats the best one *before* stepping the
       # scheduler
       improved <- as.logical(sched$.is_better(loss, sched$best))
       sched$step(loss)
       
       if (inherits(private$opt, "optim_lbfgs") && isTRUE(private$opt$defaults$line_search_fn == "strong_wolfe") ) {
         # LBFGS with strong-Wolfe line search: count consecutive
         # non-improving steps in private$lbfgs_count
         if (private$lbfgs_count == sched$patience) {
           # patience exhausted: allow a convergence check only if a reset has
           # already happened since the last improvement, then rebuild the
           # optimizer/scheduler while carrying over the best loss
           if (private$lbfgs_reset > 0) check <- TRUE
           best <- sched$best
           private$torch_optim_reset()
           private$lbfgs_count <- 0L
           private$lbfgs_reset <- private$lbfgs_reset + 1L
           private$sched$best <- best
         } else if (!improved) {
           private$lbfgs_count <- private$lbfgs_count + 1L
         } else if (improved) {
           # an improvement clears both counters
           private$lbfgs_count <- 0L
           private$lbfgs_reset <- 0L
         }
         
       } else {
         # other optimizers: allow a convergence check only once the learning
         # rate read before this step is within a relative 1e-3 of min_lr
         # if ( sched$num_bad_epochs == sched$patience && init_lr != lr) {
         if (abs(lr - min_lr)/(lr + .Machine$double.eps) < 1e-3) {
           check <- TRUE 
         } else {
           check <- FALSE
         }
       }
       
       # restore the threshold mode changed above
       sched$threshold_mode <- old_mode
     } else {
       sched$step()
     }
     
     
     return(check)
   },
   # holder variable for selected optimization method; the "function" string
   # is only a placeholder. torch_optim_setup() replaces it with
   # private$torch_optim_step, and optimization_loop() calls it as
   # optimization_step(opt, osqp_args, tol = tol).
   optimization_step = "function",
   
   # optimization inner loop
   # Repeatedly calls private$optimization_step() for the current penalties
   # until convergence or `niter` steps, then resets the torch optimizer.
   #
   # niter: maximum number of optimization steps; 0 runs no steps.
   # tol:   tolerance passed to the step function and to converged().
   # Returns list(loss = <last loss>, iter = <number of steps taken>). With
   # niter == 0 the loss is the initial loss and iter is 0L.
   optimization_loop = function(niter, tol) {
     # set initial loss
     loss_old <- self$loss$detach()$item()
     
     # defined up front so the return value is valid even when no step runs
     loss <- loss_old
     iter <- 0L
     
     # check convergence variable initialization  
     check <- TRUE
     
     # run optimization steps for given penalties
     # seq_len() (not 1:niter) so that niter == 0 performs zero iterations
     for ( i in seq_len(niter) ) {
       iter <- i
       
       # take opt step and return loss
       loss <- private$optimization_step(private$opt, private$osqp_args, tol = tol)$detach()$to(device = "cpu")$item()
       
       # reduce lr; `check` says whether a convergence test is allowed
       check <- private$lr_reduce(loss)
       
       # see if converged (never on the first step)
       if ( check && (i > 1) && converged(loss, loss_old, tol) ) break
       
       # if not converged, save old loss
       loss_old <- loss
     }
     
     #reset torch optimizer if present
     private$torch_optim_reset()
     private$lbfgs_count <- private$lbfgs_reset <- 0L
     
     return(list(loss = loss,
                 iter = iter))
     
   },
   
   #update ot problems
   # Pushes the current measure data/weights into every OT object and reruns
   # sinkhorn where needed.
   #
   # only_params: if TRUE, sinkhorn is rerun only for problems touching an
   #              adapted measure; if FALSE it is rerun for every problem
   #              with a finite penalty.
   # get_weights: if TRUE, refresh private$weights from the measures first.
   # use_grad:    if FALSE, the whole update runs inside torch::with_no_grad().
   ot_update = function(only_params = TRUE, get_weights = TRUE, use_grad = TRUE) {
     ot_adds <- ls(private$ot_objects)
     
     # set weights
     if (isTRUE(get_weights)) {
       private$weights <- private$get_weights()
     }
     
     # Refresh the cost objects of `ot` when either measure adapts its data.
     # Returns TRUE when at least one of the two measures has adapt == "x".
     cost_forward <- function(add_1, add_2, ot) {
       measure_1 <- private$measures[[add_1]]
       measure_2 <- private$measures[[add_2]]
       a_1 <- measure_1$adapt
       a_2 <- measure_2$adapt
       
       if(a_1 == "x" || a_2 == "x") {
         x <- measure_1$.__enclos_env__$private$data_
         y <- measure_2$.__enclos_env__$private$data_
         
         if(x$requires_grad) {
           update_cost(ot$C_xy, x, y$detach())
           if(ot$debias) update_cost(ot$C_xx, x, x$detach())
         }
         if(y$requires_grad) {
           update_cost(ot$C_yx, y, x$detach())
           if(ot$debias) update_cost(ot$C_yy, y, y$detach())
         }
         has_grad <- TRUE
       } else{
         has_grad <- FALSE
       }
       
       return(has_grad)
     }
     
     # Copy the current weights into `ot$a` / `ot$b` for measures that adapt
     # their weights. Returns TRUE when either side was updated.
     weights_forward <- function(add_1, add_2, ot) {
       
       has_grad <- FALSE
       if(private$measures[[add_1]]$adapt == "weights") {
         ot$a <- private$weights[[add_1]]
         has_grad <- TRUE
       }
       if(private$measures[[add_2]]$adapt == "weights") {
         ot$b <-  private$weights[[add_2]]
         has_grad <- TRUE
       }
       return(has_grad)
     }
     
     # loop over OT problems
     ot_prob_loop <- function() {
       cur_ot <- NULL
       for (addy in ot_adds) {
         cur_ot <- private$ot_objects[[addy]]
         
         # the pair of measure addresses belonging to this problem
         problem_addy <- private$problems[[addy]]
         
         # update cost if needed
         cost_grad <- cost_forward(problem_addy[[1L]],
                                   problem_addy[[2L]],
                                   cur_ot)
         
         # update weights if needed
         wt_grad <- weights_forward(problem_addy[[1L]],
                                    problem_addy[[2L]],
                                    cur_ot)
         
         has_grad <- (cost_grad || wt_grad)
         
         # run sinkhorn if needed
         if ( is.finite(cur_ot$penalty) && (has_grad || !only_params) ) {
           cur_ot$sinkhorn_opt(niter = private$ot_niter, tol = private$ot_tol)
         }
       }
     }
     # differential run with/without grad
     if (use_grad) {
       ot_prob_loop() # grad, if needed
     } else {
       # The loop must be *called* here: with_no_grad() only evaluates the
       # expression it is given, so passing the bare function would skip the
       # update entirely.
       torch::with_no_grad(ot_prob_loop()) # no grad
     }
     
   },
   # return or set parameter weights
   #
   # Getter (value missing): returns a "weightEnv" environment with one entry
   #   per name in private$parameters. The entry is the measure's `weights`
   #   when adapt == "weights", its `x` when adapt == "x", and NULL otherwise.
   #   With clone = TRUE the tensors are detached and cloned first.
   # Setter (value supplied): `value` is an environment or an indexable
   #   collection with exactly one entry per parameter. Entries are matched to
   #   parameters by position in ls() order (environments) or by index
   #   (everything else) and assigned inside torch::with_no_grad().
   #   Errors if the counts differ or a target measure has no adapted
   #   quantity.
   parameters_get_set = function(value, clone = FALSE) {
     
     # return a clone of the weights
     if (missing(value)) {
       out <- rlang::env()
       class(out) <- c("weightEnv", class(out))
       
       param_addresses <- ls(private$parameters)
       for (v in param_addresses) {
         out[[v]] <- if (isTRUE(clone)) {
           if(private$measures[[v]]$adapt == "weights") {
             private$measures[[v]]$weights$detach()$clone()
           } else if (private$measures[[v]]$adapt == "x") {
             private$measures[[v]]$x$detach()$clone()
           }
         } else {
           if(private$measures[[v]]$adapt == "weights") {
             private$measures[[v]]$weights
           } else if (private$measures[[v]]$adapt == "x") {
             private$measures[[v]]$x
           }
         }
       }
       return(out)
     }
     
     # set weights
     param_addresses <- ls(private$parameters)
     
     if(!rlang::is_environment(value)){
       # seq_along() instead of 1:length(): an empty `value` yields zero
       # addresses rather than the bogus pair c(1, 0)
       value_addresses <- seq_along(value)
       names(value) <- value_addresses
     } else {
       value_addresses <- ls(value)
     }
     
     if(length(value_addresses) == length(param_addresses)) {
       v <- NULL
       u <- NULL
       torch::with_no_grad(
         for (i in seq_along(param_addresses)) {
           v <- param_addresses[[i]]
           u <- value_addresses[[i]]
           if(private$measures[[v]]$adapt == "weights") {
             private$measures[[v]]$weights <- value[[u]]
           } else if (private$measures[[v]]$adapt == "x") {
             private$measures[[v]]$x <- value[[u]]
           } else {
             stop("Error in assignment. Tried to assign to a measure without any gradients.")
           }
         }
       )
     } else {
       stop("Input must have same number of groups as do the parameters.")
     }
   },
   
   # Take one optimization step with a torch optimizer.
   #
   # opt:       the torch optimizer to step.
   # osqp_args: forwarded to private$bal_param_update().
   # tol:       forwarded (positionally) to private$bal_param_update().
   # Returns the loss produced by the step. For "optim_lbfgs" this is whatever
   # opt$step(closure) returns; otherwise it is the closure's loss moved to
   # the cpu device.
   torch_optim_step = function(opt, osqp_args = NULL, tol) {
     
     # Zero the gradients, evaluate balance penalty + objective, and
     # backpropagate.
     eval_loss <- function() {
       opt$zero_grad()
       bal_term <- private$bal_param_update(osqp_args = osqp_args, tol)
       total <- bal_term + self$loss
       total$backward()
       return(total$to(device = "cpu"))
     }
     
     if (!inherits(opt, "optim_lbfgs")) {
       # evaluate once ourselves, then let the optimizer apply the update
       loss <- eval_loss()
       opt$step()
     } else {
       # LBFGS is handed the closure so it can evaluate the loss itself
       loss <- opt$step(eval_loss)
     }
     
     return(loss)
     
   },
   # Build the torch optimizer (and optional scheduler) and remember the calls
   # used so they can be re-evaluated later.
   #
   # torch_optim:     optimizer constructor; called with
   #                  params = private$parameters plus matching torch_args.
   # torch_scheduler: scheduler constructor or NULL for no scheduler.
   # torch_args:      named arguments; each constructor receives only the
   #                  entries whose names match its formals.
   # Side effects: sets private$opt, private$sched, private$opt_calls and
   # points private$optimization_step at private$torch_optim_step.
   torch_optim_setup = function(torch_optim, 
                                torch_scheduler,
                                torch_args) {
     
     supplied <- names(torch_args)
     
     # optimizer: keep only the arguments its constructor knows about
     opt_formals <- names(formals(torch_optim))
     opt_args <- torch_args[match(opt_formals, supplied, nomatch = 0L)]
     opt_call <- rlang::call2(torch_optim,
                              params = private$parameters,
                              !!!opt_args)
     private$opt <- eval(opt_call)
     
     torch_lbfgs_check(private$opt)
     
     private$optimization_step <- private$torch_optim_step
     
     # scheduler (optional)
     opt_sched <- scheduler_call <- NULL
     if (!is.null(torch_scheduler)) {
       sched_formals <- names(formals(torch_scheduler))
       sched_args <- torch_args[match(sched_formals, supplied, nomatch = 0L)]
       
       if (inherits(torch_scheduler, "lr_reduce_on_plateau")) {
         # defaults applied only when the caller did not supply them
         if (is.null( sched_args$patience )) {
           sched_args$patience <- 1L
         }
         if (is.null( sched_args$min_lr ) && inherits(private$opt, "optim_lbfgs")) {
           sched_args$min_lr <- private$opt$defaults$lr * 1e-3
         }
       }
       
       scheduler_call <- rlang::call2(torch_scheduler,
                                      optimizer = private$opt,
                                      !!!sched_args)
       opt_sched <- eval(scheduler_call)
     }
     
     private$opt_calls <- list(opt = opt_call,
                               sched = scheduler_call)
     
     private$sched <- opt_sched
     
   },
   # Recreate the optimizer and scheduler from the calls stored in
   # private$opt_calls. Nothing is done for a component that is currently NULL.
   #
   # lr: optional learning rate; when given it overrides the `lr` argument of
   #     the stored optimizer call.
   torch_optim_reset = function(lr = NULL) {
     if (!is.null(private$opt)) {
       stored_call <- private$opt_calls$opt
       fresh_call <- if (is.null(lr)) {
         rlang::call_modify(stored_call,
                            params = private$parameters)
       } else {
         rlang::call_modify(stored_call,
                            params = private$parameters,
                            lr = lr)
       }
       private$opt <- eval(fresh_call)
     }
     
     if (!is.null(private$sched)) {
       # point the scheduler at the (possibly new) optimizer
       sched_call <- rlang::call_modify(private$opt_calls$sched,
                                        optimizer = private$opt)
       private$sched <- eval(sched_call)
     }
     
   },
   # Assign the penalty `lambda` to every object in private$ot_objects.
   # Errors unless lambda >= 0.
   set_lambda = function(lambda) {
     stopifnot("lambda value must be >= 0" = (lambda >= 0))
     for (addr in ls(private$ot_objects)) {
       private$ot_objects[[addr]]$penalty <- lambda
     }
   },
   # Assign the constraint value `delta` to every target object whose balance
   # functions (`bf`) do not require a gradient. A no-op when there are no
   # target objects or delta is NA; errors unless delta >= 0.
   set_delta = function(delta) {
     if (length(private$target_objects) == 0 || is.na(delta)) {
       return(invisible(NULL))
     }
     
     stopifnot("delta value must be >= 0" = (delta >= 0))
     for (addr in ls(private$target_objects)) {
       if (!private$target_objects[[addr]]$bf$requires_grad) {
         private$target_objects[[addr]]$delta <- delta
       }
     }
   },
   # Convenience setter: forwards whichever of `lambda` / `delta` was supplied
   # to private$set_lambda() / private$set_delta(). Omitted arguments are
   # left untouched.
   set_penalties = function(lambda, delta) {
     has_lambda <- !missing(lambda)
     has_delta <- !missing(delta)
     
     if (has_lambda) private$set_lambda(lambda)
     
     if (has_delta) private$set_delta(delta)
   },
   # Bundle the penalty grids and the stored weights into one list with
   # elements `delta`, `lambda` and `weights_list`.
   setup_choose_hyperparameters = function() {
     hyper <- list(
       delta = private$penalty_list$delta,
       lambda = private$penalty_list$lambda,
       weights_list = private$weights_list
     )
     return(hyper)
   },
   
   # Fold `o` into this problem's objective: objective <- objective `fun` o.
   #
   # o:   either a non-OTProblem value (used as-is as the right operand) or
   #      another OTProblem, whose measures and problems are appended to this
   #      one and whose objective becomes the right operand.
   # fun: name of the infix operator to apply, as a string.
   #
   # The new objective is built as a call object instead of by pasting text.
   # Pasting drops the grouping of the operands, so e.g. multiplying the
   # objective `a + b` by 2 produced `a + b * 2`; a call keeps `(a + b) * 2`.
   unaryop = function(o, fun) {
     if(! is_ot_problem(o)) {
       rhs <- o
     } else {
       stopifnot(self$device == o$device)
       stopifnot(self$dtype == o$dtype)
       private$append(o, "measures")
       private$append(o, "problems")
       rhs <- o$.__enclos_env__$private$objective
     }
     
     private$objective <- rlang::call2(fun, private$objective, rhs)
   },
  
   # Overwrite the gradient of every entry in private$parameters with zeros.
   # Entries with an undefined gradient are skipped; an entry that has a
   # gradient but does not require one triggers a warning and is skipped.
   zero_grad = function() {
     for ( param in private$parameters ) {
       grad_defined <- !torch::is_undefined_tensor(param$grad)
       
       if (grad_defined && !param$requires_grad) {
         warning("One of the objects in parameters doesn't require a grad. Report this bug!")
       } else if (grad_defined) {
         torch::with_no_grad( param$grad$copy_(0.0) )
       }
     }
   }
 )}
)

#' Object Oriented OT Problem
#'
#' @param measure_1 An object of class [Measure]
#' @param measure_2 An object of class [Measure]
#' @param ... Not used at this time 
#'
#' @return An R6 object of class "OTProblem"
#' @details # Public fields
#'   \if{html}{\out{<div class="r6-fields">}}
#'   \describe{
#'     \item{\code{device}}{the \code{\link[torch:torch_device]{torch::torch_device()}} of the data.}
#'     \item{\code{dtype}}{the \link[torch:torch_dtype]{torch::torch_dtype} of the data.}
#'     \item{\code{selected_delta}}{the delta value selected after \code{choose_hyperparameters}}
#'     \item{\code{selected_lambda}}{the lambda value selected after \code{choose_hyperparameters}}
#'   }
#'   \if{html}{\out{</div>}}
#' @details # Active bindings
#'   \if{html}{\out{<div class="r6-active-bindings">}}
#'   \describe{
#'     \item{\code{loss}}{prints the current value of the objective. Only available after the \href{#method-OTProblem-solve}{\code{OTProblem$solve()}} method has been run}
#'     \item{\code{penalty}}{Returns a list of the lambda and delta penalties that will be iterated through. To set these values, use the \href{#method-OTProblem-setup_arguments}{\code{OTProblem$setup_arguments()}} function.}
#'   }
#'   \if{html}{\out{</div>}}
#' @details # Methods
#'   \subsection{Public methods}{
#'     \itemize{
#'     \item \href{#method-OTProblem-add}{\code{OTProblem$add()}}
#'     \item \href{#method-OTProblem-subtract}{\code{OTProblem$subtract()}}
#'     \item \href{#method-OTProblem-multiply}{\code{OTProblem$multiply()}}
#'     \item \href{#method-OTProblem-divide}{\code{OTProblem$divide()}}
#'     \item \href{#method-OTProblem-setup_arguments}{\code{OTProblem$setup_arguments()}}
#'     \item \href{#method-OTProblem-solve}{\code{OTProblem$solve()}}
#'     \item \href{#method-OTProblem-choose_hyperparameters}{\code{OTProblem$choose_hyperparameters()}}
#'     \item \href{#method-OTProblem-info}{\code{OTProblem$info()}}
#'     \item \href{#method-OTProblem-clone}{\code{OTProblem$clone()}}
#'     }
#'     }
#' \if{html}{\out{<hr>}}
#' \if{html}{\out{<a id="method-OTProblem-add"></a>}}
#' \if{latex}{\out{\hypertarget{method-OTProblem-add}{}}}
#' \subsection{Method \code{add()}}{
#'   adds \code{o2} to the OTProblem
#'   \subsection{Usage}{
#'     \if{html}{\out{<div class="r">}}\preformatted{OTProblem$add(o2)}\if{html}{\out{</div>}}
#'   }
#'   \subsection{Arguments}{
#'     \if{html}{\out{<div class="arguments">}}
#'     \describe{
#'       \item{\code{o2}}{A number or object of class OTProblem}
#'     }
#'     \if{html}{\out{</div>}}
#'   }
#' }
#' \if{html}{\out{<hr>}}
#' \if{html}{\out{<a id="method-OTProblem-subtract"></a>}}
#' \if{latex}{\out{\hypertarget{method-OTProblem-subtract}{}}}
#' \subsection{Method \code{subtract()}}{
#'   subtracts \code{o2} from OTProblem
#'   \subsection{Usage}{
#'     \if{html}{\out{<div class="r">}}\preformatted{OTProblem$subtract(o2)}\if{html}{\out{</div>}}
#'   }
#'   \subsection{Arguments}{
#'     \if{html}{\out{<div class="arguments">}}
#'     \describe{
#'       \item{\code{o2}}{A number or object of class OTProblem}
#'     }
#'     \if{html}{\out{</div>}}
#'   }
#' }
#' \if{html}{\out{<hr>}}
#' \if{html}{\out{<a id="method-OTProblem-multiply"></a>}}
#' \if{latex}{\out{\hypertarget{method-OTProblem-multiply}{}}}
#' \subsection{Method \code{multiply()}}{
#'   multiplies OTProblem by \code{o2}
#'   \subsection{Usage}{
#'     \if{html}{\out{<div class="r">}}\preformatted{OTProblem$multiply(o2)}\if{html}{\out{</div>}}
#'   }
#'   \subsection{Arguments}{
#'     \if{html}{\out{<div class="arguments">}}
#'     \describe{
#'       \item{\code{o2}}{A number or an object of class OTProblem}
#'     }
#'     \if{html}{\out{</div>}}
#'   }
#' }
#' \if{html}{\out{<hr>}}
#' \if{html}{\out{<a id="method-OTProblem-divide"></a>}}
#' \if{latex}{\out{\hypertarget{method-OTProblem-divide}{}}}
#' \subsection{Method \code{divide()}}{
#'   divides OTProblem by \code{o2}
#'   \subsection{Usage}{
#'     \if{html}{\out{<div class="r">}}\preformatted{OTProblem$divide(o2)}\if{html}{\out{</div>}}
#'   }
#'   \subsection{Arguments}{
#'     \if{html}{\out{<div class="arguments">}}
#'     \describe{
#'       \item{\code{o2}}{A number or object of class OTProblem}
#'     }
#'     \if{html}{\out{</div>}}
#'   }
#' }
#' \if{html}{\out{<hr>}}
#' \if{html}{\out{<a id="method-OTProblem-setup_arguments"></a>}}
#' \if{latex}{\out{\hypertarget{method-OTProblem-setup_arguments}{}}}
#' \subsection{Method \code{setup_arguments()}}{
#'   \subsection{Usage}{
#'     \if{html}{\out{<div class="r">}}\preformatted{OTProblem$setup_arguments(
#'       lambda,
#'       delta,
#'       grid.length = 7L,
#'       cost.function = NULL,
#'       p = 2,
#'       cost.online = "auto",
#'       debias = TRUE,
#'       diameter = NULL,
#'       ot_niter = 1000L,
#'       ot_tol = 0.001
#'     )}\if{html}{\out{</div>}}
#'   }
#'   \subsection{Arguments}{
#'     \if{html}{\out{<div class="arguments">}}
#'     \describe{
#'       \item{\code{lambda}}{The penalty parameters to try for the OT problems. If not provided, function will select some}
#'       \item{\code{delta}}{The constraint parameters to try for the balance function problems, if any}
#'       \item{\code{grid.length}}{The number of hyperparameters to try if not provided}
#'       \item{\code{cost.function}}{The cost function for the data. Can be any function that takes arguments \code{x1}, \code{x2}, \code{p}. Defaults to the Euclidean distance}
#'       \item{\code{p}}{The power to raise the cost matrix by. Default is 2}
#'       \item{\code{cost.online}}{Should online costs be used? Default is "auto" but "tensorized" stores the cost matrix in memory while "online" will calculate it on the fly.}
#'       \item{\code{debias}}{Should debiased OT problems be used? Defaults to TRUE}
#'       \item{\code{diameter}}{Diameter of the cost function.}
#'       \item{\code{ot_niter}}{Number of iterations to run the OT problems}
#'       \item{\code{ot_tol}}{The tolerance for convergence of the OT problems}
#'     }
#'     \if{html}{\out{</div>}}
#'   }
#'   \subsection{Returns}{
#'     NULL
#'   }
#'   \subsection{Examples}{
#'     \if{html}{\out{<div class="r example copy">}}
#'     \preformatted{ ot$setup_arguments(lambda = c(1000,10))
#'     }
#'     \if{html}{\out{</div>}}
#'   }
#' }
#' \if{html}{\out{<hr>}}
#' \if{html}{\out{<a id="method-OTProblem-solve"></a>}}
#' \if{latex}{\out{\hypertarget{method-OTProblem-solve}{}}}
#' \subsection{Method \code{solve()}}{
#'   Solve the OTProblem at each parameter value. Must run setup_arguments first.
#'   \subsection{Usage}{
#'     \if{html}{\out{<div class="r">}}\preformatted{OTProblem$solve(
#'       niter = 1000L,
#'       tol = 1e-05,
#'       optimizer = c("torch", "frank-wolfe"),
#'       torch_optim = torch::optim_lbfgs,
#'       torch_scheduler = torch::lr_reduce_on_plateau,
#'       torch_args = NULL,
#'       osqp_args = NULL,
#'       quick.balance.function = TRUE
#'     )}\if{html}{\out{</div>}}
#'   }
#'   \subsection{Arguments}{
#'     \if{html}{\out{<div class="arguments">}}
#'     \describe{
#'       \item{\code{niter}}{The number of iterations to run solver at each combination of hyperparameter values}
#'       \item{\code{tol}}{The tolerance for convergence}
#'       \item{\code{optimizer}}{The optimizer to use. One of "torch" or "frank-wolfe"}
#'       \item{\code{torch_optim}}{The \code{torch_optimizer} to use. Default is \link[torch:optim_lbfgs]{torch::optim_lbfgs}}
#'       \item{\code{torch_scheduler}}{The \link[torch:lr_scheduler]{torch::lr_scheduler} to use. Default is \link[torch:lr_reduce_on_plateau]{torch::lr_reduce_on_plateau}}
#'       \item{\code{torch_args}}{Arguments passed to the torch optimizer and scheduler}
#'       \item{\code{osqp_args}}{Arguments passed to \code{\link[osqp:osqpSettings]{osqp::osqpSettings()}} if appropriate}
#'       \item{\code{quick.balance.function}}{Should \code{\link[osqp:osqp]{osqp::osqp()}} be used to select balance function constraints (delta) or not. Default true.}
#'     }
#'     \if{html}{\out{</div>}}
#'   }
#'   \subsection{Examples}{
#'     \if{html}{\out{<div class="r example copy">}}
#'     \preformatted{ ot$solve(niter = 1, torch_optim = torch::optim_rmsprop)
#'     }
#'     \if{html}{\out{</div>}}
#'   }
#' }
#' \if{html}{\out{<hr>}}
#' \if{html}{\out{<a id="method-OTProblem-choose_hyperparameters"></a>}}
#' \if{latex}{\out{\hypertarget{method-OTProblem-choose_hyperparameters}{}}}
#' \subsection{Method \code{choose_hyperparameters()}}{
#'   Selects the hyperparameter values through a bootstrap algorithm
#'   \subsection{Usage}{
#'     \if{html}{\out{<div class="r">}}\preformatted{OTProblem$choose_hyperparameters(
#'       n_boot_lambda = 100L,
#'       n_boot_delta = 1000L,
#'       lambda_bootstrap = Inf
#'     )}\if{html}{\out{</div>}}
#'   }
#'   \subsection{Arguments}{
#'     \if{html}{\out{<div class="arguments">}}
#'     \describe{
#'       \item{\code{n_boot_lambda}}{The number of bootstrap iterations to run when selecting lambda}
#'       \item{\code{n_boot_delta}}{The number of bootstrap iterations to run when selecting delta}
#'       \item{\code{lambda_bootstrap}}{The penalty parameter to use when selecting lambda. Higher numbers run faster.}
#'     }
#'     \if{html}{\out{</div>}}
#'   }
#'   \subsection{Examples}{
#'     \if{html}{\out{<div class="r example copy">}}
#'     \preformatted{ ot$choose_hyperparameters(n_boot_lambda = 10, 
#'                                              n_boot_delta = 10, 
#'                                              lambda_bootstrap = Inf)
#'     }
#'     \if{html}{\out{</div>}}
#'   }
#' }
#' \if{html}{\out{<hr>}}
#' \if{html}{\out{<a id="method-OTProblem-info"></a>}}
#' \if{latex}{\out{\hypertarget{method-OTProblem-info}{}}}
#' \subsection{Method \code{info()}}{
#'   Provides diagnostics after solve and choose_hyperparameter methods have been run.
#'   \subsection{Usage}{
#'     \if{html}{\out{<div class="r">}}\preformatted{OTProblem$info()}\if{html}{\out{</div>}}
#'   }
#'   \subsection{Returns}{
#'     a list with slots
#'     \itemize{
#'       \item \code{loss} the final loss values
#'       \item \code{iterations} The number of iterations run for each combination of parameters
#'       \item \code{balance.function.differences} The final differences in the balance functions
#'       \item \code{hyperparam.metrics} A list of the bootstrap evaluation for delta and lambda values}
#'   }
#'   \subsection{Examples}{
#'     \if{html}{\out{<div class="r example copy">}}
#'     \preformatted{ ot$info()
#'     }
#'     \if{html}{\out{</div>}}
#'   }
#' }
#' \if{html}{\out{<hr>}}
#' \if{html}{\out{<a id="method-OTProblem-clone"></a>}}
#' \if{latex}{\out{\hypertarget{method-OTProblem-clone}{}}}
#' \subsection{Method \code{clone()}}{
#'   The objects of this class are cloneable with this method.
#'   \subsection{Usage}{
#'     \if{html}{\out{<div class="r">}}\preformatted{OTProblem$clone(deep = FALSE)}\if{html}{\out{</div>}}
#'   }
#'   \subsection{Arguments}{
#'     \if{html}{\out{<div class="arguments">}}
#'     \describe{
#'       \item{\code{deep}}{Whether to make a deep clone.}
#'     }
#'     \if{html}{\out{</div>}}
#'   }
#' }
#' @examples
#' ## ------------------------------------------------
#' ## Method `OTProblem(measure_1, measure_2)`
#' ## ------------------------------------------------
#'
#' if (torch::torch_is_installed()) {
#'   # setup measures
#'   x <- matrix(1, 100, 10)
#'   m1 <- Measure(x = x)
#'   
#'   y <- matrix(2, 100, 10)
#'   m2 <- Measure(x = y, adapt = "weights")
#'   
#'   z <- matrix(3,102, 10)
#'   m3 <- Measure(x = z)
#'   
#'   # setup OT problems
#'   ot1 <- OTProblem(m1, m2)
#'   ot2 <- OTProblem(m3, m2)
#'   ot <- 0.5 * ot1 + 0.5 * ot2
#'   print(ot)
#'
#' ## ------------------------------------------------
#' ## Method `OTProblem$setup_arguments`
#' ## ------------------------------------------------
#'
#'   ot$setup_arguments(lambda = 1000)
#'
#' ## ------------------------------------------------
#' ## Method `OTProblem$solve`
#' ## ------------------------------------------------
#'
#'   ot$solve(niter = 1, torch_optim = torch::optim_rmsprop)
#'
#' ## ------------------------------------------------
#' ## Method `OTProblem$choose_hyperparameters`
#' ## ------------------------------------------------
#'
#'   ot$choose_hyperparameters(n_boot_lambda = 1,
#'                             n_boot_delta = 1, 
#'                             lambda_bootstrap = Inf)
#'
#' ## ------------------------------------------------
#' ## Method `OTProblem$info`
#' ## ------------------------------------------------
#'
#' ot$info()
#' }
#' @export
OTProblem <- function(measure_1, measure_2,...) {
  # `...` is accepted for future use and is not forwarded to the constructor
  ot_problem <- OTProblem_$new(
    measure_1 = measure_1,
    measure_2 = measure_2
  )
  ot_problem
}

# TRUE when `obj` inherits from class "OTProblem", FALSE otherwise.
is_ot_problem <- function(obj) {
  inherits(obj, what = "OTProblem")
}

# S3 generic for combining two operands; dispatches on the class of `e1`.
# `fun` is passed through to the selected method.
binaryop <- function(e1, e2, fun) {
  UseMethod("binaryop", e1)
}

#' @export
# Apply the OTProblem method named `fun` ("add", "subtract", "multiply" or
# "divide") to a deep clone of the OTProblem operand, leaving the inputs
# untouched.
#
# e1, e2: the operands; at least one must be an OTProblem, the other may be
#         numeric.
# fun:    name of the public OTProblem method to call.
# Returns the modified clone.
#
# The methods always compute `<problem> op <other>`. With a numeric left-hand
# side that is only correct for commutative operators: `1 - ot` used to yield
# `ot - 1` and `2 / ot` used to yield `ot / 2`. For those two cases the
# clone's objective is rebuilt directly with the operands in the right order.
binaryop.OTProblem <- function(e1, e2, fun) {
  if ( !is_ot_problem(e1) ) {
    stopifnot("LHS must be numeric or OTProblem" = is.numeric(e1))
    o_new <- e2$clone(deep = TRUE)
    
    # operators whose result depends on operand order
    reversed_op <- switch(fun, subtract = "-", divide = "/", NULL)
    if ( !is.null(reversed_op) ) {
      priv <- o_new$.__enclos_env__$private
      priv$objective <- call(reversed_op, e1, priv$objective)
      return(o_new)
    }
    
    o_old <- e1
  } else {
    # e1 is an OTProblem; e2 must be numeric or another OTProblem
    if ( !is_ot_problem(e2) ) {
      stopifnot("RHS must be numeric or OTProblem"=is.numeric(e2))
    }
    o_new <- e1$clone(deep = TRUE)
    o_old <- e2
  }
  
  o_new[[fun]](o_old)
  
  return(o_new)
  
}

#' @export
`+.OTProblem` <- function(e1, e2) {
  # delegate to the "add" method; the result is returned invisibly
  invisible(binaryop.OTProblem(e1, e2, "add"))
}

#' @export
`-.OTProblem` <- function(e1, e2) {
  # delegate to the "subtract" method; the result is returned invisibly
  invisible(binaryop.OTProblem(e1, e2, "subtract"))
}

#' @export
`*.OTProblem` <- function(e1, e2) {
  # delegate to the "multiply" method; the result is returned invisibly
  invisible(binaryop.OTProblem(e1, e2, "multiply"))
}

#' @export
`/.OTProblem` <- function(e1, e2) {
  # delegate to the "divide" method; the result is returned invisibly
  invisible(binaryop.OTProblem(e1, e2, "divide"))
}


OTProblem_$set("private", 
               "draw_boot_measure",
# Draw one bootstrap resample per measure.
# For each measure, `n` indices are sampled with replacement using the
# measure's `init_weights` as probabilities, and the per-index counts are
# stored under the measure's address.
# NOTE(review): the `$add(1L)` presumably shifts the sampled indices to the
# base that `$bincount()` expects; confirm against the torch tensor methods.
# Returns an environment of count tensors keyed by measure address.
function() {
  boot_counts <- rlang::env()
  
  for (addr in ls(private$measures)) {
    cur_measure <- private$measures[[addr]]
    n_obs <- cur_measure$n
    sample_prob <- cur_measure$init_weights
    draws <- sample_prob$multinomial(n_obs, replacement = TRUE)
    boot_counts[[addr]] <- draws$add(1L)$bincount(minlength = n_obs)
  }
  
  return(boot_counts)
}               
)

OTProblem_$set("private", 
               "eval_delta",
# Score a set of weights on one bootstrap draw by the mean absolute
# difference between balance-function values and their targets.
#
# wts:  weights keyed by measure address; used for measures with
#       adapt == "weights" (other measures use their `init_weights`).
# boot: bootstrap counts keyed by measure address.
# Returns the average of the per-target mean absolute differences, weighted
# by nrow() of each target's balance functions.
function (wts, boot) {
  addresses <- ls(private$target_objects)
  means     <- rep(0.0, length(addresses)) |> setNames(addresses)
  ns        <- rep(0.0, length(addresses)) |> setNames(addresses)
  n <- NULL
  w <- NULL
  b <- NULL
  m <- NULL
  s <- NULL
  targ <- NULL
  w_star <- NULL
  target_obj<- NULL
  
  for (add in addresses) {
    if (private$measures[[add]]$adapt == "weights") {
      w <- wts[[add]]
    } else {
      # measures live in private$measures; the previous `self$measure`
      # lookup does not match the store used everywhere else in this class
      w <- private$measures[[add]]$init_weights
    }
    b <- boot[[add]]
    
    target_obj <- private$target_objects[[add]]
    n <- nrow(target_obj$bf)
    
    # bootstrap-reweighted weights
    w_star <- w * b
    if(!target_obj$bf$requires_grad) {
      # fixed balance functions: compare the weighted sums with the targets
      m <- target_obj$bf$transpose(2,1)$matmul( w_star )
      targ <- target_obj$bt
    } else {
      # balance functions that require a gradient: compare their mean along
      # dimension 1 with the targets, both scaled by the std along dimension 1
      s <- target_obj$bf$detach()$std(1)
      m <- target_obj$bf$detach()$mean(1)/s
      targ <- target_obj$bt/s
    }
    
    means[[add]] <- (m - targ)$abs()$mean()$item()
    ns[[add]] <- n
  }
  
  return(weighted.mean(means, ns))
}               
               
)

OTProblem_$set("private", "eval_lambda",
# Score a set of weights on one bootstrap draw by the value of the objective.
#
# wts:  environment of estimated parameters keyed by measure address
#       (weights, or the data values for measures with adapt == "x").
# boot: bootstrap counts keyed by measure address.
# Side effects: overwrites the `a` / `b` masses (and, for adapt == "x", a cost
# object) of the OT objects and reruns sinkhorn on those with a finite
# penalty.
# Returns the numeric value of self$loss after the update.
function (wts, boot) {
  addresses <- ls(private$measures)
  prob_adds <- ls(private$ot_objects)
  wt_adds   <- ls(wts)
  sel_probs <- NULL
  prob_hold <- NULL
  p1_add    <- p2_add <- NULL
  
  w         <- NULL
  b         <- NULL
  w_star    <- NULL
  meas      <- NULL
  
  for (add in addresses) {
   b <- boot[[add]]
   meas <- private$measures[[add]]
   
   w <- if(add %in% wt_adds) {
     wts[[add]]$detach()
   } else {
     meas$init_weights$detach()
   }
   
   if (meas$adapt == "weights" || meas$adapt == "none") {
     w_star <- w * b
   } else if (meas$adapt == "x") {
     w_star <- meas$init_weights * b
   } else {
     stop("Bug in this if else statement!")
   }
   
   # every OT problem whose address mentions this measure
   sel_probs <- prob_adds[grep(add, prob_adds)]
   for (i in sel_probs) {
     prob_hold <- private$problems[[i]]
     p1_add <- prob_hold[[1]]
     p2_add <- prob_hold[[2]]
     if (p1_add == add){
       private$ot_objects[[i]]$a <- w_star
       if (meas$adapt == "x") {
         update_cost(private$ot_objects[[i]]$C_xy, 
                     w$detach(), #w should be the data values in this case, not weights
                     private$measures[[p2_add]]$x$detach())
       }
     } else if (p2_add == add) {
       private$ot_objects[[i]]$b <- w_star
       if (meas$adapt == "x") {
         update_cost(private$ot_objects[[i]]$C_yx, 
                     w$detach(), #w should be the data values in this case, not weights
                     private$measures[[p1_add]]$x$detach())
       }
     } else {
       stop("Problem address not found. You found a bug!")
     }
   }
   
  }
  
  # Check each OT object's own penalty. The previous guard read the penalty
  # through the leftover loop variable `i`, which is undefined when no
  # problem matched any measure and otherwise is just whichever problem was
  # visited last.
  for(o in prob_adds) {
    if(is.finite(private$ot_objects[[o]]$penalty)) {
      private$ot_objects[[o]]$sinkhorn_opt(20, 1e-3)
    }
  }
  dists <- self$loss$detach()$item()
  return(dists)
}               
               
)


# operators to add weight environments for OTProblems
# Register the S3 class name "weightEnv" with the S4 system via setOldClass().
setOldClass("weightEnv")
# Apply a binary function element-wise to weightEnv objects.
#
# Arguments
#   e1, e2 : operands; at least one must inherit from "weightEnv". The other
#            may be any value accepted by `fun` together with the entries of
#            the weightEnv.
#   fun    : a function of two arguments (e.g. `+`, `-`, `*`, `/`).
#
# Returns a new "weightEnv" environment. When both operands are weightEnv
# objects they must hold the same number of entries; entries are paired by
# their position in ls() order and the result is stored under the name used
# in `e1`. When only one operand is a weightEnv, `fun` is applied to each of
# its entries and the other operand, keeping the original left/right order.
binaryop.weightEnv <- function(e1,e2, fun) {
  
  we1 <- inherits(e1, "weightEnv")
  we2 <- inherits(e2, "weightEnv")
  
  if (!we1 && !we2) {
    stop("You found a bug!")
  }
  if (we1 && !we2) {
    return(unaryop.weightEnv(e1, e2, fun))
  }
  if (!we1 && we2) {
    # unaryop.weightEnv always puts the weightEnv entry first; flip the
    # arguments back so non-commutative operators (`-`, `/`) keep e1 on the left
    return(unaryop.weightEnv(e2, e1, function(x, y) fun(y, x)))
  }
  
  listE1 <- ls(e1)
  listE2 <- ls(e2)
  
  stopifnot("weightEnv objects must have the same length." = length(listE1) == length(listE2))
  
  out <- new.env()
  class(out) <- c("weightEnv", class(out))
  for (i in seq_along(listE1)) {
    v <- listE1[[i]]
    u <- listE2[[i]]
    # v is a name in e1 and u the matching name in e2
    out[[v]] <- fun(e1[[v]], e2[[u]])
  }
  
  return(out)
}

# Apply `fun(entry, e2)` to every entry of the weightEnv `e1`.
#
# Arguments
#   e1  : a weightEnv (environment) whose entries are the first argument.
#   e2  : the second argument passed unchanged to every call of `fun`.
#   fun : a function of two arguments.
#
# Returns a new "weightEnv" environment with the same entry names as `e1`.
unaryop.weightEnv <- function(e1,e2, fun) {
  result <- rlang::env()
  class(result) <- c("weightEnv", class(result))
  
  for (key in ls(e1)) {
    result[[key]] <- fun(e1[[key]], e2)
  }
  
  result
}

#' @export
`+.weightEnv` <- function(e1, e2) {
  # element-wise addition, routed through the shared binaryop() helper
  op <- `+`
  binaryop(e1, e2, op)
}

#' @export
`-.weightEnv` <- function(e1, e2) {
  # element-wise subtraction, routed through the shared binaryop() helper
  op <- `-`
  binaryop(e1, e2, op)
}

#' @export
`*.weightEnv` <- function(e1, e2) {
  # element-wise multiplication, routed through the shared binaryop() helper
  op <- `*`
  binaryop(e1, e2, op)
}

#' @export
`/.weightEnv` <- function(e1, e2) {
  # element-wise division, routed through the shared binaryop() helper
  op <- `/`
  binaryop(e1, e2, op)
}


#' @export
sum.weightEnv <- function(..., na.rm = FALSE) {
  # With exactly one argument, that argument is treated as a weightEnv and the
  # sums of its entries (in ls() order) are accumulated. With any other number
  # of arguments, sum() of each argument is accumulated instead. `na.rm` is
  # accepted for compatibility with sum() but is not used.
  args <- list(...)
  
  if (length(args) == 1) {
    env <- args[[1]]
    total <- Reduce(function(acc, key) sum(env[[key]]) + acc, ls(env), 0.0)
  } else {
    total <- Reduce(function(acc, arg) sum(arg) + acc, args, 0.0)
  }
  
  return(total)
}


#### Optimizers relying on OTProblem class ####
#### NNM class ####
# maybe make a part of OTProblem...unclear how to unite several problems though...
# R6 class "NNM", inheriting from OTProblem_.
# setup_arguments() forwards to the parent with lambda fixed at 0, a single
# grid point and debias = FALSE, then caches the cost, masses and size of the
# first OT object. solve() finds, via argmin over the cost, an index per
# target point, tabulates the target masses `b` over those indices with
# torch_bincount(), and stores the result as the weights of every measure
# whose `adapt` is "weights".
NNM <- R6::R6Class(
  classname = "NNM",
  public = {list(
    setup_arguments = function(lambda = NULL, delta = NULL , 
                               grid.length = 7L,
                               cost.function = NULL, 
                               p = 2,
                               cost.online = "auto",
                               debias = TRUE,
                               diameter = NULL, ot_niter = 1000L,
                               ot_tol = 1e-5) {
      # lambda, delta, grid.length and debias supplied by the caller are
      # ignored; fixed values are passed to the parent instead
      super$setup_arguments(lambda = 0, delta = NULL,
                            grid.length = 1,
                            cost.function = cost.function,
                            p = p,
                            cost.online = cost.online,
                            debias = FALSE,
                            diameter = diameter,
                            ot_niter = ot_niter,
                            ot_tol = ot_tol)
      
      # cache what solve() needs from the first OT object
      first_problem <- ls(private$ot_objects)[[1]]
      ot <- private$ot_objects[[first_problem]]
      private$C_xy <- ot$C_xy
      private$tensorized <- ot$tensorized
      private$b <- ot$b
      private$n <- as.integer(ot$n)
    },
    solve = function(...) {
      cost <- private$C_xy
      if (private$tensorized) {
        # cost data is a tensor: take the argmin along dimension 1
        mins <- cost$data$argmin(1)
      } else {
        # online cost: build an rkeops ArgMin reduction over the cost formula
        x <- as_matrix(cost$data$x)
        y <- as_matrix(cost$data$y)
        d <- ncol(x)
        keops_v2 <- utils::packageVersion("rkeops") >= pkg_vers_number("2.0")
        dim.red <- if (keops_v2) "0" else "1"
        argmin_op <- rkeops::keops_kernel(
          formula = paste0("ArgMin_Reduction(", cost$fun, ",", dim.red, ")"),
          args = c(
            paste0("X = Vi(", d, ")"),
            paste0("Y = Vj(", d, ")")
          )
        )
        # presumably the "+ 1" shifts 0-based rkeops indices to 1-based ones
        mins <- torch::torch_tensor(c(argmin_op(list(x, y))) + 1,
                                    dtype = torch::torch_int64(),
                                    device = private$b$device)
      }
      w_nnm <- torch::torch_bincount(self = mins,
                                     weights = private$b,
                                     minlength = private$n)
      
      for (m in ls(private$measures)) {
        if (private$measures[[m]]$adapt == "weights") {
          private$measures[[m]]$weights <- w_nnm
        }
      }
    },
    choose_hyperparameters = function(...) {
      # nothing to tune: the selected lambda is always 0
      self$selected_lambda <- 0.0
    }
  )},
  private = list(
    # placeholders, overwritten in setup_arguments() (device is not set there)
    b = "tensor",
    C_xy = "cost",
    device = "torch_device",
    n = "integer",
    tensorized = "logical"
  ),
  inherit = OTProblem_
)

#### Primal optimizer ####
# this uses the cotOOP paradigm, OTProblem
# therefore, no individual functions

#### Dual function optimizer ####

# forward functions in torchscript
# TorchScript source kept as an R string; cotDualOpt compiles it with
# torch::jit_compile() (see its private dual_forwards()). It defines four
# functions:
#   calc_w1     - weights obtained from f, the cross cost C_xy and the log
#                 masses a_log / b_log via two logsumexp passes.
#   calc_w2     - weights obtained from f and the self cost C_xx via a
#                 logsumexp followed by log_softmax.
#   cot_dual    - dual loss -gamma . (w1 - w2) with the weights computed from a
#                 detached gamma; returns loss, avg_diff and a zero bf_diff.
#   cot_bf_dual - the same plus the balance-function terms in beta; also
#                 returns bf_diff and beta_check.
# The string is executed as code, so its contents must not be reformatted.
dual_forward_code_tensorized <- "

 def calc_w1(f: Tensor, C_xy: Tensor, a_log: Tensor, b_log: Tensor, lambda: float, n: int):
   f_minus_C_lambda = f.view([n,1]) + a_log.view([n,1]) - C_xy/lambda #  f/lambda - C/lambda + a_log
   g_lambda = b_log -(f_minus_C_lambda).logsumexp(0) # = g/lambda + b_log
   w1 = (f_minus_C_lambda + g_lambda).logsumexp(1).exp()
   return w1
   
 def calc_w2(f: Tensor, C_xx: Tensor, a_log: Tensor, lambda: float, n: int):
   f_lambda = f + a_log
   w2 = ( f_lambda.view([n,1]) + f_lambda  - C_xx / lambda).logsumexp(1).log_softmax(0).exp()
   return w2
   
 def cot_dual(gamma: Tensor, C_xy: Tensor, C_xx: Tensor, a_log: Tensor, b_log: Tensor, lambda: float, n: int):
    
   f_star = gamma.detach()
   
   w1 = calc_w1(f_star, C_xy, a_log, b_log, lambda, n)
   w2 = calc_w2(f_star, C_xx, a_log,        lambda, n)
   
   measure_diff = (w1-w2).detach()
   loss =  -1.0 * gamma.dot(measure_diff) 
   # mult by -1 because is a maximization and are turning into minimization
   # print(-loss.item())
   return {'loss' : loss, 'avg_diff' : measure_diff.detach().norm(), 'bf_diff' : torch.zeros(1,dtype=loss.dtype)}
   
 def cot_bf_dual(gamma: Tensor, C_xy: Tensor, C_xx: Tensor, a_log: Tensor, b_log: Tensor, lambda: float, n: int, beta: Tensor, bf: Tensor, bt: Tensor, delta: float):
   
   w1 = calc_w1(gamma.detach(), C_xy, a_log, b_log, lambda, n)
   w2 = calc_w2(gamma.detach(), C_xx, a_log,        lambda, n)
   
   measure_diff = (w1-w2).detach()
   loss_gamma = gamma.dot(measure_diff) 
   
   bf_diff = bf.transpose(0,1).matmul(w1) - bt
   
   beta_check = bf_diff * beta.detach() - delta * beta.detach().abs()
   
   loss_beta = bf_diff.dot(beta) - beta.abs().sum() * delta
   
   loss = (loss_gamma + loss_beta) * -1.0 # mult by neg 1 because is a maximization
                                   
   return {'loss' : loss, 'avg_diff' : measure_diff.detach().norm(), 'bf_diff' :  bf_diff.detach().abs().max(), 'beta_check' : beta_check}
   
"

# rkeops forward functions
# R implementations of the dual forward functions for online (rkeops) costs.
# They mirror the names used by the TorchScript versions: calc_w1, calc_w2,
# cot_dual and cot_bf_dual. The cost objects are expected to expose
# `$data$x`, `$data$y` and a `$reduction()` function. The layout of the
# reduction output depends on the installed rkeops version (a single vector
# for >= 2.0, otherwise two columns combined as log(col 2) + col 1).
dual_forwards_keops <- list(
  calc_w1 = function(f, C_xy, a_log, b_log, lambda, n) {
    x_cpu <- as.matrix(C_xy$data$x$to(device = "cpu"))
    y_cpu <- as.matrix(C_xy$data$y$to(device = "cpu"))
    f_lambda <- f + a_log
    
    # first pass: reduce over x for every y to get the g potential
    red_g <- C_xy$reduction(list(y_cpu, x_cpu,
                                 as.numeric(f_lambda$to(device = "cpu")),
                                 1.0 / lambda))
    keops_v2 <- utils::packageVersion("rkeops") >= pkg_vers_number("2.0")
    
    g_lambda <- if (keops_v2) {
      - c(red_g) + as_numeric(b_log)
    } else {
      - (log(red_g[, 2]) + red_g[, 1]) + as_numeric(b_log)
    }
    
    # second pass: reduce over y for every x using g
    red_a1 <- C_xy$reduction(list(x_cpu, y_cpu,
                                  g_lambda,
                                  1.0 / lambda))
    a1_values <- if (keops_v2) {
      c(red_a1)
    } else {
      log(red_a1[, 2]) + red_a1[, 1]
    }
    a1_log <- torch::torch_tensor(a1_values, dtype = f$dtype, device = f$device) + f_lambda
    
    return(a1_log$log_softmax(1)$exp()$view(-1))
  },
  calc_w2 = function(f, C_xy, a_log, lambda, n) {
    f_lambda <- f + a_log
    x_cpu <- as.matrix(C_xy$data$x$to(device = "cpu"))
    
    keops_v2 <- utils::packageVersion("rkeops") >= pkg_vers_number("2.0")
    # self reduction: x against x
    red_a2 <- C_xy$reduction(list(x_cpu, x_cpu,
                                  as.numeric(f_lambda$to(device = "cpu")),
                                  1.0 / lambda))
    a2_values <- if (keops_v2) {
      c(red_a2)
    } else {
      log(red_a2[, 2]) + red_a2[, 1]
    }
    a2_log <- torch::torch_tensor(a2_values, dtype = f$dtype, device = f$device) + f_lambda
    
    return(a2_log$log_softmax(1)$exp()$view(-1))
  },
  cot_dual = function(gamma, C_xy, C_xx, a_log, b_log, lambda, n) {
    f_star <- gamma$detach()
    
    w1 <- dual_forwards_keops$calc_w1(f_star, C_xy, a_log, b_log, lambda, n)
    w1 <- w1$detach()$to(device = gamma$device)
    w2 <- dual_forwards_keops$calc_w2(f_star, C_xx, a_log, lambda, n)
    w2 <- w2$detach()$to(device = gamma$device)
    
    measure_diff <- w1 - w2
    # mult by neg 1 because is a maximization
    loss <- -1.0 * gamma$dot(measure_diff)
    
    return(list(
      loss = loss,
      avg_diff = measure_diff$norm(),
      bf_diff = torch::torch_zeros(1, dtype = loss$dtype)
    ))
  },
  cot_bf_dual = function(gamma, C_xy, C_xx, a_log, b_log, lambda, n,
                         beta, bf, bt, delta) {
    f_star <- gamma$detach()
    beta_d <- beta$detach()
    
    w1 <- dual_forwards_keops$calc_w1(f_star, C_xy, a_log, b_log, lambda, n)
    w1 <- w1$detach()$to(device = gamma$device)
    w2 <- dual_forwards_keops$calc_w2(f_star, C_xx, a_log, lambda, n)
    w2 <- w2$detach()$to(device = gamma$device)
    
    measure_diff <- w1 - w2
    loss_gamma <- gamma$dot(measure_diff)
    
    # balance-function discrepancy under w1 and its penalty terms
    bf_diff    <- bf$transpose(1, 2)$matmul(w1) - bt
    beta_check <- bf_diff * beta_d - delta * beta_d$abs()
    loss_beta  <- bf_diff$dot(beta) - delta * beta$abs()$sum()
    
    # the sign flip turns the maximization into a minimization
    loss <- (loss_gamma + loss_beta) * -1.0
    
    return(list(
      loss = loss,
      avg_diff = measure_diff$abs()$mean(),
      bf_diff = bf_diff$abs()$max(),
      beta_check = beta_check
    ))
  }
  
)



# optimizer without bf
# torch nn_module holding the dual variable `gamma` (length n, initialised at
# zero) for the dual problem without balance functions. The forward pass calls
# the compiled TorchScript function selected in set_forward(); converged()
# implements the stopping rule used by the training loop.
cotDualOpt <- torch::nn_module(
  classname = "cotDualOpt",
  initialize = function(n, d = NULL, device = NULL, dtype = NULL) {
    
    self$device <- cuda_device_check(device)
    self$dtype  <- cuda_dtype_check(dtype, self$device)
    
    self$n <- torch::jit_scalar(as.integer(n))
    gamma_init <- torch::torch_zeros(self$n,
                                     dtype = self$dtype,
                                     device = self$device)
    self$gamma <- torch::nn_parameter(gamma_init, requires_grad = TRUE)
    private$set_forward(bf = FALSE)
  },
  # bf, bt and delta are accepted for a uniform call signature but unused here
  forward = function(C_xy, C_xx, a_log, b_log, lambda, bf=NULL, bt=NULL, delta = NULL) {
    lambda_scalar <- torch::jit_scalar(lambda)
    private$ts_forward(self$gamma, C_xy$data, C_xx$data, a_log, b_log, lambda_scalar, self$n)
  },
  backward = function(res) {
    res$loss$backward()
  },
  # Returns the parameter list itself when requires_grad is TRUE, otherwise a
  # list (same names) of detached clones.
  clone_param = function(requires_grad = FALSE) {
    if (isTRUE(requires_grad)) {
      return(self$parameters)
    }
    detached <- lapply(self$parameters, function(p) p$detach()$clone())
    return(detached)
  },
  # Stopping rule. Requires avg_diff < 1/lambda and the balance-function
  # discrepancy to be within tolerance of delta (a missing delta is treated as
  # 0), plus at least one of the relative/absolute change checks on the loss
  # or on avg_diff. `old_param` is accepted but not used.
  converged = function(res, avg_diff_old, loss_old, old_param,
                        tol=1e-5, lambda, delta) {
    eps <- .Machine$double.eps
    if (is.null(delta) || is.na(delta)) {
      delta <- 0
    }
    
    avg_diff_now <- res$avg_diff$item()
    loss_now     <- res$loss$detach()$item()
    bf_diff_now  <- res$bf_diff$item()
    
    tol_sq <- tol * tol
    
    change_avg_diff <- abs(avg_diff_now - avg_diff_old)
    change_loss     <- abs(loss_now - loss_old)
    
    loss_abs_ok     <- change_loss < tol_sq
    avg_diff_abs_ok <- change_avg_diff < tol_sq
    
    bf_excess     <- (bf_diff_now - delta)
    bf_excess_rel <- bf_excess / (delta + eps)
    
    avg_diff_rel_ok <- change_avg_diff / (avg_diff_now + eps) < tol
    loss_rel_ok     <- change_loss / (abs(loss_now) + eps) < tol
    
    must_pass <- avg_diff_now < 1/lambda &&
      (bf_excess_rel <= tol || bf_excess <= min(tol_sq, delta/1e4))
    
    rel_checks <- loss_rel_ok || avg_diff_rel_ok
    abs_checks <- loss_abs_ok || avg_diff_abs_ok
    
    return(must_pass && (rel_checks || abs_checks))
  },
  # placeholders; calc_w1/calc_w2 are set in set_forward(), dtype/device in
  # initialize()
  calc_w1 =  "torch_jit",
  calc_w2 =  "torch_jit",
  dtype = "torch_dtype",
  device = "torch_dtype",
  private = list(
    ts_forward = "function",
    # choose the compiled forward (bf = FALSE -> cot_dual, TRUE -> cot_bf_dual)
    # and expose the compiled weight functions
    set_forward = function(bf = FALSE) {
      compiled <- private$dual_forwards()
      private$ts_forward <- switch(bf + 1L,
                                   compiled$cot_dual,
                                   compiled$cot_bf_dual)
      self$calc_w1 <- compiled$calc_w1
      self$calc_w2 <- compiled$calc_w2
    },
    dual_forwards = function() {
      torch::jit_compile(dual_forward_code_tensorized)
    }
  )
)

# Variant of cotDualOpt for online (rkeops) costs: the cost objects are passed
# to the forward function as they are (no `$data`), and the forward/weight
# functions come from dual_forwards_keops instead of compiled TorchScript.
cotDualOpt_keops <- torch::nn_module(
  classname = "cotDualOpt_keops",
  inherit = cotDualOpt,
  forward = function(C_xy, C_xx, a_log, b_log, lambda, bf, bt, delta) {
    lambda_scalar <- torch::jit_scalar(lambda)
    private$ts_forward(self$gamma, C_xy, C_xx, a_log, b_log, lambda_scalar, self$n)
  },
  private = list(
    set_forward = function(...) {
      private$ts_forward <- dual_forwards_keops$cot_dual
      self$calc_w1 <- dual_forwards_keops$calc_w1
      self$calc_w2 <- dual_forwards_keops$calc_w2
    }
  )
)

# optimizer with bf
# Dual optimizer with balance functions: adds a second parameter `beta`
# (length d, initialised at zero) to `gamma`. The forward pass first subtracts
# bf %*% beta from gamma in place (without gradient tracking), then calls the
# balance-function forward. After back-propagation, backward() multiplies
# beta in place by the indicator `beta_check > 0`.
cotDualBfOpt <- torch::nn_module(
  classname="cotDualBfOpt",
  inherit = cotDualOpt,
  initialize = function(n, d, device = NULL, dtype = NULL) {
    self$device <- cuda_device_check(device)
    self$dtype  <- cuda_dtype_check(dtype, self$device)
    
    self$n <- torch::jit_scalar(as.integer(n))
    self$d <- torch::jit_scalar(as.integer(d))
    
    gamma_init <- torch::torch_zeros(self$n,
                                     dtype = self$dtype,
                                     device = self$device)
    self$gamma <- torch::nn_parameter(gamma_init, requires_grad = TRUE)
    
    beta_init <- torch::torch_zeros(self$d,
                                    dtype = self$dtype,
                                    device = self$device)
    self$beta <- torch::nn_parameter(beta_init, requires_grad = TRUE)
    
    private$set_forward(bf = TRUE)
  },
  forward = function(C_xy, C_xx, a_log, b_log, lambda, bf, bt, delta) {
    torch::with_no_grad({
      self$gamma$sub_(bf$matmul(self$beta))
    })
    private$ts_forward(self$gamma, C_xy$data, C_xx$data, a_log, b_log,
                       torch::jit_scalar(lambda), self$n,
                       self$beta, bf, bt, torch::jit_scalar(delta))
  },
  backward = function(res) {
    res$loss$backward()
    keep <- res$beta_check > 0
    torch::with_no_grad(self$beta$mul_(keep))
  }
)

# Variant of cotDualBfOpt for online (rkeops) costs: cost objects are passed
# on unchanged and the forward/weight functions come from dual_forwards_keops.
cotDualBfOpt_keops <- torch::nn_module(
  classname = "cotDualBfOpt_keops",
  inherit = cotDualBfOpt,
  forward = function(C_xy, C_xx, a_log, b_log, lambda, bf, bt, delta) {
    torch::with_no_grad({
      self$gamma$sub_(bf$matmul(self$beta))
    })
    private$ts_forward(self$gamma, C_xy, C_xx, a_log, b_log,
                       torch::jit_scalar(lambda), self$n,
                       self$beta, bf, bt, torch::jit_scalar(delta))
  },
  private = list(
    set_forward = function(...) {
      private$ts_forward <- dual_forwards_keops$cot_bf_dual
      self$calc_w1 <- dual_forwards_keops$calc_w1
      self$calc_w2 <- dual_forwards_keops$calc_w2
    }
  )
)

# main object to do training, inherit from OTproblem and makes changes where needed
cotDualTrain <- R6::R6Class(
  classname = "cotDualTrain",
  public = {list(
    ot = "OT",
    # Prepare the dual trainer. Arguments are passed unchanged (positionally)
    # to the parent's setup_arguments(). Afterwards the first OT problem is
    # used to cache the costs and log masses, the first target object (if
    # any) supplies the balance functions, the matching nn_module is
    # created, and the lambda grid is made finite, strictly positive and
    # sorted increasingly. Returns `self` invisibly.
    setup_arguments = function(lambda = NULL, delta = NULL , 
                               grid.length = 7L,
                               cost.function = NULL, 
                               p = 2,
                               cost.online = "auto",
                               debias = TRUE,
                               diameter = NULL, ot_niter = 1000L,
                               ot_tol = 1e-5) {
      super$setup_arguments(lambda, delta, 
                            grid.length,
                            cost.function, 
                            p,
                            cost.online ,
                            debias,
                            diameter , ot_niter,
                            ot_tol)
      
      m_add   <- ls(private$measures)
      p_add   <- ls(private$problems)
      # extract the two measure addresses as single elements (`[[`), matching
      # how eval_lambda reads the same structure
      first_problem <- private$problems[[ p_add[[1]] ]]
      adapt_m <- first_problem[[1]]
      targ_m  <- first_problem[[2]]
      
      self$ot <- private$ot_objects[[ p_add[[1]] ]]
      
      private$C_xy <- self$ot$C_xy
      private$C_xx <- self$ot$C_xx
      # makes sure the non- changing weights are used...
      stopifnot("Wrong measure is being fed for adapatation. Report this bug please!" = private$measures[[ adapt_m ]]$requires_grad)
      private$a_log <- log_weights(private$measures[[ adapt_m ]]$init_weights$detach())
      private$b_log <- log_weights(private$measures[[ targ_m  ]]$init_weights$detach())
      
      # balance functions are used when at least one target object exists
      runbf <-  length(private$target_objects) >= 1
      if (runbf) {
        t_add <- ls(private$target_objects)
        targ  <- private$target_objects[[ t_add[1] ]]
        
        private$bf <- targ$bf
        private$bt <- targ$bt
        
      } else {
        private$bf <- private$bt <- NULL
      }
      
      # index 1..4: (online, no bf), (online, bf), (tensorized, no bf),
      # (tensorized, bf)
      tensorized <- self$ot$tensorized
      nn_fun <- switch(tensorized * 2 + runbf + 1L ,
                       cotDualOpt_keops,
                       cotDualBfOpt_keops,
                       cotDualOpt,
                       cotDualBfOpt
      )
      private$nn_holder  <- nn_fun$new(n = self$ot$n, 
                                       d = length(private$bt),
                                       device = private$measures[[m_add[1L]]]$device,
                                       dtype = private$measures[[m_add[1L]]]$dtype)
      private$parameters <- private$nn_holder$parameters
      
      # replace infinite / zero penalties by large / small multiples of the
      # problem diameter
      private$penalty_list$lambda[ is.infinite(private$penalty_list$lambda) ] <- self$ot$diameter * 1e5
      private$penalty_list$lambda[private$penalty_list$lambda == 0] <- self$ot$diameter / 1e9
      # runs faster and more accurately when reversed (small -> large)
      private$penalty_list$lambda <- sort(private$penalty_list$lambda, decreasing = FALSE)
      
      private$lambda <- private$penalty_list$lambda[1L]
      
      private$prev_lambda <- private$lambda
      return(invisible(self))
      
    }
  )},
  active = {list(
    # Active binding: current weights implied by the dual variables.
    # Takes a detached copy of gamma (minus bf %*% beta when balance functions
    # are in use), evaluates calc_w1 and calc_w2 at the current lambda and
    # returns their average. `value` is not used.
    weights = function(value) {
      nn <- private$nn_holder
      potential <- nn$gamma$detach()$clone()
      if (!is.null(private$bf)) {
        potential$sub_(private$bf$matmul(nn$beta$detach()))
      }
      
      # "costTensor" costs are passed as their underlying data; other cost
      # objects are passed whole
      is_tensor_cost <- inherits(private$C_xy, "costTensor")
      C_xy <- if (is_tensor_cost) private$C_xy$data else private$C_xy
      C_xx <- if (is_tensor_cost) private$C_xx$data else private$C_xx
      
      lambda_scalar <- torch::jit_scalar(private$lambda)
      n_scalar      <- torch::jit_scalar(nn$n)
      
      w1 <- nn$calc_w1(potential, C_xy, private$a_log, private$b_log, lambda_scalar, n_scalar)
      w2 <- nn$calc_w2(potential, C_xx, private$a_log, lambda_scalar, n_scalar)
      
      return((w1 + w2) * 0.5)
    }
  )},
  private = {list(
    # Bootstrap evaluation of the OT loss for a candidate lambda.
    #   wts  : candidate values per measure address (weights, or data values
    #          when the measure adapts "x").
    #   boot : bootstrap counts per measure address.
    # The bootstrap-reweighted masses are written into the `a` / `b` slots of
    # every OT object whose address contains the measure address; costs are
    # refreshed with update_cost() when the measure adapts "x". OT objects
    # with a finite penalty get a short sinkhorn_opt(20, 1e-3) run, then the
    # detached loss is returned as a number.
    eval_lambda = function (wts, boot) {
      addresses <- ls(private$measures)
      prob_adds <- ls(private$ot_objects)
      wt_adds   <- ls(wts)
      
      for (add in addresses) {
        b    <- boot[[add]]
        meas <- private$measures[[add]]
        
        w <- if (add %in% wt_adds) {
          wts[[add]]$detach()
        } else {
          meas$init_weights$detach()
        }
        
        if (meas$adapt == "weights" || meas$adapt == "none") {
          w_star <- w * b
        } else if (meas$adapt == "x") {
          w_star <- meas$init_weights * b
        } else {
          stop("Bug in this if else statement!")
        }
        
        # addresses are literal strings, so match them literally
        sel_probs <- prob_adds[grep(add, prob_adds, fixed = TRUE)]
        for (i in sel_probs) {
          prob_hold <- private$problems[[i]]
          p1_add <- prob_hold[[1]]
          p2_add <- prob_hold[[2]]
          if (p1_add == add) {
            private$ot_objects[[i]]$a <- w_star
            if (meas$adapt == "x") {
              update_cost(private$ot_objects[[i]]$C_xy, w, private$measures[[p2_add]]$x$detach())
            }
          } else if (p2_add == add) {
            private$ot_objects[[i]]$b <- w_star
            if (meas$adapt == "x") {
              update_cost(private$ot_objects[[i]]$C_yx, w, private$measures[[p1_add]]$x$detach())
            }
          } else {
            stop("OT Problem address not found. You found a bug!")
          }
        }
        
      }
      # Visit the OT objects explicitly instead of relying on the loop
      # variable `i` leaking out of the loops above (undefined when no problem
      # was matched).
      for (o in prob_adds) {
        if (is.finite(private$ot_objects[[o]]$penalty)) {
          private$ot_objects[[o]]$sinkhorn_opt(20, 1e-3)
        }
      }
      dists <- self$loss$detach()$item()
      return(dists)
    },
    # Set the OT penalty. `l` must be non-negative; a value of 0 is replaced
    # by diameter / 1e9 before the parent's set_lambda() is called, and an
    # infinite value is replaced by diameter * 1e4 for the dual problem only
    # (the parent still receives Inf). The result is stored as a jit scalar.
    set_lambda = function (l) {
      stopifnot(l >= 0.0, l <= Inf)
      if (l == 0.0) {
        l <- self$ot$diameter / 1e9
      }
      super$set_lambda(l)
      
      dual_lambda <- if (is.infinite(l)) self$ot$diameter * 1e4 else l
      private$lambda <- torch::jit_scalar(dual_lambda)
    },
    # Set the balance-function tolerance. `d` must be NA or non-negative.
    # A non-missing `d` is stored as a jit scalar; if `d` is NA the delta of
    # the first target object is used when one exists, otherwise NA_real_.
    set_delta = function (d) {
      stopifnot(d >= 0.0 || is.na(d))
      private$delta <- if(!is.na(d)) {
        torch::jit_scalar(d)
      } else if(length(private$target_objects) >= 1L) {
        # index with a single address: `[[` on the full ls() vector fails as
        # soon as more than one target object is stored. The first target is
        # the one setup_arguments() uses for bf / bt.
        first_target <- ls(private$target_objects)[[1L]]
        private$target_objects[[first_target]]$delta
      } else {
        NA_real_
      }
      
    },
    # Set both penalties from one object.
    #   value : a list or named vector with names in c("lambda", "delta"),
    #           or an unnamed vector whose first element is lambda.
    # A missing delta is passed on as NA_real_ so that set_delta() can fall
    # back to the target object's delta; a missing lambda is an error.
    set_penalties = function (value) {
      if (is.list(value)) {
        if (!all(names(value) %in% c("lambda", "delta")) ) {
          stop("If penalties are provided as a list, must be with names in c('lambda', 'delta')")
        }
        lambda <- value[["lambda"]]
        delta  <- value[["delta"]]
      } else {
        if (!is.null(names(value))) {
          if (!all(names(value) %in% c("lambda", "delta")) ) {
            stop("If penalties are provided as a named vector, must be with names in c('lambda', 'delta')")
          }
          lambda <- value["lambda"]
          delta <- value["delta"]
          names(lambda) <- NULL
          names(delta) <- NULL
        } else {
          lambda <- value[1]
          delta <- NA_real_
          if(length(value) > 1) warning("For unnamed vectors, only the first number is used to set the OT penalty, lambda. Other values are ignored.")
        }
      }
      # a list without a "lambda" entry yields NULL, which set_lambda() cannot
      # handle; fail with an explicit message instead
      if (is.null(lambda)) {
        stop("A value for 'lambda' must be provided.")
      }
      # a list without a "delta" entry yields NULL; use the same "not
      # provided" marker as the vector branches
      if (is.null(delta)) {
        delta <- NA_real_
      }
      private$set_lambda(lambda)
      private$set_delta(delta)
    },
    # Run up to `niter` optimizer steps on the dual problem at the current
    # lambda / delta. The optimizer is rebuilt first and beta (if present) is
    # reset to zero. From the third iteration on, the loop stops as soon as
    # the module's converged() check passes. Returns a list with the final
    # loss and the iteration index at which the loop ended.
    optimization_loop = function (niter, tol) {
      
      #reset torch optimizer if present
      private$torch_optim_reset()
      
      nn <- private$nn_holder
      if (!is.null(nn$beta)) {
        torch::with_no_grad(nn$beta$copy_(0.0))
      }
      
      prev_avg_diff <- 10.0
      prev_loss     <- 10.0
      prev_param    <- NULL
      
      for (i in 1:niter) {
        private$opt$zero_grad()
        res <- nn$forward(private$C_xy, private$C_xx,
                          private$a_log,
                          private$b_log,
                          private$lambda, private$bf, private$bt, private$delta)
        
        finished <- (i > 2L) && nn$converged(res, prev_avg_diff,
                                              prev_loss, prev_param,
                                              tol,
                                              private$lambda, private$delta)
        if (finished) {
          break
        }
        
        # remember this iteration's diagnostics for the next convergence check
        prev_avg_diff <- res$avg_diff$item()
        prev_loss     <- res$loss$detach()$item()
        
        nn$backward(res)
        private$opt$step()
        private$lr_reduce(prev_avg_diff)
      }
      private$prev_lambda <- private$lambda
      return(list(loss = res$loss$detach()$item(),
                  iter = i))
      
    }, # overwrite super function
    # Getter / setter for the adapted weights.
    #   value : missing to read; otherwise a list or environment with exactly
    #           one entry, which is assigned as the weights of every measure
    #           whose `adapt` is "weights".
    #   clone : accepted for compatibility; not used.
    # When reading, returns a "weightEnv" environment holding self$weights
    # under the address of each measure that adapts its weights.
    parameters_get_set = function (value, clone = FALSE) {
      
      param_addresses <- ls(private$measures)
      
      # return the current weights
      if ( missing(value) ) {
        out <- rlang::env()
        class(out) <- c("weightEnv", class(out))
        
        for (v in param_addresses) {
          if(private$measures[[v]]$adapt == "weights") {
            out[[v]] <- self$weights
          }
        }
        return(out)
      }
      
      # set weights
      if(!rlang::is_environment(value)){
        stopifnot("value must be a list or environment" = is.list(value))
        # seq_along() keeps an empty list from producing the indices c(1, 0)
        names(value) <- value_addresses <- seq_along(value)
      } else {
        value_addresses <- ls(value)
      }
      
      if(length(value_addresses) == 1) {
        u <- value_addresses[[1L]]
        torch::with_no_grad(
          for (v in param_addresses) {
            if (private$measures[[v]]$adapt == "weights")  {
              private$measures[[v]]$weights <- value[[u]]
            }
          }
        )
      } else {
        stop("Input must have same number of groups as do the parameters.")
      }
    }, #overwrite super
    # Build the torch optimizer (and optional scheduler) for the dual
    # parameters.
    #   torch_optim     : optimizer constructor.
    #   torch_scheduler : scheduler constructor, or NULL for none.
    #   torch_args      : named list of extra arguments; entries are routed to
    #                     the optimizer / scheduler by matching formal names.
    # Stores the optimizer in private$opt, the scheduler in private$sched and
    # the unevaluated constructor calls in private$opt_calls.
    torch_optim_setup = function (torch_optim, 
                                 torch_scheduler,
                                 torch_args) {
      
      # keep only the entries of torch_args that are formals of the optimizer
      names_args <- names(torch_args)
      optim_args_names <- names(formals(torch_optim))
      opt_args <- torch_args[match(optim_args_names,
                                   names_args, nomatch = 0L)]
      param <- private$parameters
      # learning rate for gamma: user supplied `lr`, otherwise 1e-2
      gamma_lr <- if(!is.null(opt_args$lr)) {
        opt_args$lr
        } else {
          1e-2 #private$lambda/100
        }
      # wrap gamma as a parameter group (re-using `params` if already wrapped)
      param$gamma <- if(!is.list(param$gamma) ) {
        list(params = param$gamma, lr = gamma_lr)
      } else {
        list(params = param$gamma$params, lr = gamma_lr)
      }
      # beta gets its own group with lr = min(delta, lr, 1e-2)
      # NOTE(review): unlike torch_optim_reset(), private$delta is not checked
      # for NA here; an NA delta would give an NA learning rate - confirm that
      # delta is always set before this method runs.
      if(!is.null(param$beta)) {
        param$beta <- if(!is.list(param$beta)) {
          list(params = param$beta,
                           lr = min(private$delta, min(opt_args$lr, 1e-2)))
        } else {
          list(params = param$beta$params,
               lr = min(private$delta, min(opt_args$lr, 1e-2)))
        }
      }
      private$parameters <- param
      # build the constructor call so that it can be re-evaluated later
      opt_call <- rlang::call2(torch_optim,
                               params = private$parameters,
                               !!!opt_args)
      private$opt <- eval(opt_call)
      
      torch_lbfgs_check(private$opt)
      if (inherits(private$opt, "optim_lbfgs")) {
        warning("Torch's LBFGS optimizer does not work well on the dual problem. Please use another optimizer.")
      }
      
      if (!is.null(torch_scheduler)) {
        # route matching entries of torch_args to the scheduler
        scheduler_args_names <- names(formals(torch_scheduler))
        sched_args <- torch_args[match(scheduler_args_names, 
                                       names_args, 
                                       nomatch = 0L)]
        # defaults applied only for lr_reduce_on_plateau combined with LBFGS
        if(inherits(torch_scheduler, "lr_reduce_on_plateau") &&
           is.null( sched_args$patience )  && inherits(private$opt, "optim_lbfgs")) {
          sched_args$patience <- 1L
        }
        if(inherits(torch_scheduler, "lr_reduce_on_plateau") && is.null( sched_args$min_lr ) && inherits(private$opt, "optim_lbfgs") ) {
          sched_args$min_lr <- private$opt$defaults$lr * 1e-3
        }
        # default multiplicative factor of 0.99 per epoch
        if(inherits(torch_scheduler, "lr_multiplicative") &&
           is.null( sched_args$lr_lambda ) ) {
          sched_args$lr_lambda <- function(epoch) {0.99}
        }
        scheduler_call <- rlang::call2(torch_scheduler,
                                       optimizer = private$opt,
                                       !!!sched_args)
        opt_sched <- eval(scheduler_call)
      } else {
        opt_sched <- scheduler_call <- NULL
      }
      
      # keep the calls so torch_optim_reset() can rebuild both objects
      private$opt_calls <- list(opt = opt_call,
                                sched = scheduler_call)
      
      private$sched <- opt_sched
      
      # return(list(opt = opt, opt_call = opt_call,
      #             sched = opt_sched, sched_call = scheduler_call))
      
    },
    # Rebuild the optimizer (and the scheduler, if one exists) from the calls
    # stored in private$opt_calls.
    #   lr : optional learning rate. For gamma the order of preference is
    #        `lr`, then the `lr` recorded in the optimizer call, then the
    #        optimizer's default. For beta it is min(delta, lr) when delta is
    #        not NA, otherwise lr.
    # Does nothing to the optimizer when private$opt is NULL.
    torch_optim_reset = function (lr = NULL) {
      # browser()
      if(!is.null(private$opt)) {
        default_lr <- private$opt$defaults$lr
        opt_call <- private$opt_calls$opt
        
        if (!is.null(lr)) {
          private$parameters$gamma$lr <- lr
        # } else if(is.null(opt_call$lr)) {
        #   private$parameters$gamma$lr <- private$lambda/100
        } else if (!is.null(opt_call$lr)) {
          private$parameters$gamma$lr <- opt_call$lr
        } else {
          private$parameters$gamma$lr <- default_lr
        }
          
        # from here on `lr` falls back to the optimizer's default
        if( is.null(lr)) lr <- default_lr
        
        if(!is.null(private$parameters$beta)) {
          if(!is.na(private$delta)) {
            lr_new <- min(private$delta, lr)
          } else {
            lr_new <- lr
          }
          private$parameters$beta$lr <- lr_new
        }
        
        # NOTE(review): `lr` was replaced by default_lr above, so this first
        # branch is only reached when default_lr itself is NULL - confirm
        # whether that is intended.
        if(is.null(lr)) {
          private$opt <- eval(rlang::call_modify(opt_call, params = private$parameters))
        } else {
          # browser()
          # def <- rlang::call_match(opt_call[[1]], opt_call, defaults = TRUE)
          private$opt <- eval(rlang::call_modify(opt_call, params = private$parameters,
                                                 lr = lr))
        }
        
      }
      
      # re-attach the scheduler to the (possibly new) optimizer
      if(!is.null(private$sched)) {
        private$sched <- eval(rlang::call_modify(private$opt_calls$sched, 
                                                 optimizer = private$opt))
      }
      
    },
    # No-op: ot_update() accepts any arguments and returns NULL.
    ot_update = function (...) {NULL},
    # The remaining entries are private fields initialised with placeholder
    # strings. Within this class, lambda, delta, nn_holder, prev_lambda, C_xy,
    # C_xx, a_log, b_log, bf and bt are overwritten by setup_arguments(),
    # set_lambda() or set_delta(), and sched by torch_optim_setup().
    # `optim` is not assigned anywhere in this class (the optimizer itself is
    # kept in private$opt).
    lambda = "numeric",
    delta = "numeric",
    # tol = "numeric",
    nn_holder = "dualCotOpt",
    # niter = "integer",
    optim = "optim",
    prev_lambda = "numeric",
    sched = "scheduler",
    C_xy = "torch_tensor",
    C_xx = "torch_tensor",
    a_log = "torch_tensor",
    b_log = "torch_tensor",
    bf = "torch_tensor",
    bt = "torch_tensor"
  )},
  inherit = OTProblem_
)

Try the causalOT package in your browser

Any scripts or data that you put into this service are public.

causalOT documentation built on May 29, 2024, 6:16 a.m.