## Source file: R/svdstuff.R

#' R6 class for a SVD worker object to use with master objects generated by [SVDMaster()]
#' @description `SVDWorker` objects are worker objects at each site of a distributed SVD model computation
#'
#' @seealso [SVDMaster()] which goes hand-in-hand with this object
#' @importFrom R6 R6Class
#' @export
SVDWorker <- R6Class("SVDWorker",
                     private = list(
                         ## the computation definition
                         defn = NA,
                         ## the stateful flag
                         stateful = TRUE,
                         ## the local x matrix; only assigned in initialize()
                         x = NA,
                         ## the current u vector, one entry per row of x
                         u = NA,
                         ## the current p value (not referenced by any method of this class)
                         p = NA,
                         ## the working copy of x, residualized by fixU()
                         workX = NA
                     ),
                     public = list(

                         #' @description
                         #' Create a new `SVDWorker` object.
                         #' @param defn the computation definition
                         #' @param data the local `x` matrix
                         #' @param stateful a boolean flag indicating if state needs to be preserved between REST calls, `TRUE` by default
                         #' @return a new `SVDWorker` object
                         initialize = function(defn, data, stateful = TRUE) {
                             ## Store the definition when one is supplied. The
                             ## missing() guard keeps calls that omit `defn`
                             ## working, since the argument was never evaluated
                             ## before.
                             if (!missing(defn)) {
                                 private$defn <- defn
                             }
                             private$x <- private$workX <- data
                             private$u <- rep(1, nrow(data))
                             private$stateful <- stateful
                             stopifnot(self$kosher())
                         },

                         #' @description
                         #' Reset the computation state by initializing work matrix and set up starting values for iterating
                         reset = function() {
                             private$workX <- private$x
                             private$u <- rep(1, nrow(private$x))
                         },

                         #' @description
                         #' Return the dimensions of the matrix
                         #' @param ... other args ignored
                         #' @return the dimension of the matrix
                         dimX = function(...) dim(private$x),

                         #' @description
                         #' Return an updated value for the `V` vector, normalized by `arg`
                         #' @param arg the normalizing value
                         #' @param ... other args ignored
                         #' @return updated `V`
                         updateV = function(arg, ...) {
                             ## `%*%` binds tighter than `/`, so this is
                             ## (t(workX) %*% u) / arg
                             t(private$workX) %*% private$u / arg
                         },

                         #' @description
                         #' Update `U` and return the updated norm of `U`
                         #' @param arg the vector that the working matrix is multiplied by to obtain the new `U`
                         #' @param ... other args ignored
                         #' @return the sum of squares of the updated `U`
                         updateU = function(arg, ...) {
                             u <- private$u <- as.numeric(private$workX %*% arg)
                             sum(u^2)
                         },

                         #' @description
                         #' Normalize `U` vector
                         #' @param arg the normalizing value
                         #' @param ... other args ignored
                         #' @return `TRUE` invisibly
                         normU = function(arg, ...) {
                             private$u <- private$u / arg
                             invisible(TRUE)
                         },

                         #' @description
                         #' Construct residual matrix using `arg`
                         #' @param arg the value to use for residualizing
                         #' @param ... other args ignored
                         #' @return the updated working matrix, invisibly
                         fixU = function(arg, ...) {
                             ## Subtract the outer product of the current `u`
                             ## and `arg` from the working matrix.
                             ## NOTE(review): the assignment is the last
                             ## expression, so the residual matrix itself is the
                             ## (invisible) return value; confirm the remote
                             ## dispatch layer does not ship it back to the master.
                             private$workX <- private$workX - private$u %*% t(arg)
                         },

                         #' @description
                         #' Get the number of rows of `x` matrix
                         #' @return the number of rows of `x` matrix
                         getN = function() {
                             nrow(private$x)
                         },

                         #' @description
                         #' Get the number of columns of `x` matrix
                         #' @return the number of columns of `x` matrix
                         getP = function() {
                             ncol(private$x)
                         },

                         #' @description
                         #' Return the stateful status of the object.
                         #' @return the stateful flag, `TRUE` or `FALSE`
                         getStateful = function() private$stateful,

                         #' @description
                         #' Check if inputs and state of object are sane. For future use
                         #' @return `TRUE` or `FALSE`
                         kosher = function() {
                             TRUE
                         })
                     )

#' R6 class for SVD master object to control worker objects generated by [SVDWorker()]
#' @description `SVDMaster` objects instantiate and run a distributed SVD computation
#'
#' @seealso [SVDWorker()] which goes hand-in-hand with this object
#' @importFrom R6 R6Class
#' @importFrom httr POST add_headers
#' @importFrom jsonlite toJSON
#' @export
SVDMaster <- R6Class("SVDMaster",
                     private = list(
                         ## computation definition
                         defn = NA,
                         ## dry_run flag for prototyping; set once a worker object is added
                         dry_run = FALSE,
                         ## list of sites
                         sites = list(),
                         ## dimension of x (not referenced by any method of this class)
                         dimX = NA,
                         ## mapping function for sites: POSTs a method
                         ## invocation with a single argument to one remote
                         ## site and returns the deserialized response
                         mapFn = function(site, arg, method) {
                             payload <- list(objectId = site$instanceId,
                                             method = method,
                                             arg = arg)
                             q <- httr::POST(.makeOpencpuURL(urlPrefix=site$url, fn="executeMethod"),
                                             body = jsonlite::toJSON(payload),
                                             httr::add_headers("Content-Type" = "application/json"),
                                             config = getConfig()$sslConfig
                                             )
                             ## Should really examine result here.
                             .deSerialize(q)
                         },
                         ## current result
                         result = list(),
                         ## debug flag
                         debug = FALSE
                     ),
                     public = list(

                         #' @description
                         #' `SVDMaster` objects instantiate and run a distributed SVD computation
                         #' @param defn a computation definition
                         #' @param debug a flag for debugging, default `FALSE`
                         #' @return R6 `SVDMaster` object
                         initialize = function(defn, debug = FALSE) {
                             private$defn <- defn
                             private$debug <- debug
                             stopifnot(self$kosher())
                         },

                         #' @description
                         #' Check if inputs and state of object are sane. For future use
                         #' @return `TRUE` or `FALSE`
                         kosher = function() {
                             TRUE
                         },

                         #' @description
                         #' Return an updated value for the `V` vector, normalized by `arg`
                         #' @param arg the normalizing value
                         #' @return updated `V`
                         updateV = function(arg) {
                             sites <- private$sites
                             n <- length(sites)
                             ## Every site receives the same `arg`. Wrapping it
                             ## in a list keeps it intact even when it is not a
                             ## scalar, as the other methods of this class do.
                             args <- rep(list(arg), n)
                             if (private$dry_run) {
                                 mapFn <- function(x, arg) x$worker$updateV(arg)
                                 results <- Map(mapFn, sites, args)
                             } else {
                                 results <- Map(private$mapFn, sites, args, rep(list("updateV"), n))
                             }
                             ## Sum the per-site contributions and scale to unit length
                             vd <- Reduce(f = '+', results)
                             vd / sqrt(sum(vd^2))
                         },

                         #' @description
                         #' Update `U` and return the updated norm of `U`
                         #' @param arg the vector passed to each worker's `updateU` method
                         #' @return updated norm of `U`
                         updateU = function(arg) { ## arg is a single vector
                             sites <- private$sites
                             n <- length(sites)
                             if (private$dry_run) {
                                 mapFn <- function(x, arg) x$worker$updateU(arg)
                                 results <- Map(mapFn, sites, rep(list(arg), n))
                             } else {
                                 results <- Map(private$mapFn, sites, rep(list(arg), n), rep(list("updateU"), n))
                             }
                             ## Workers return sums of squares; add them up
                             ## across sites before taking the root
                             sqrt(Reduce(f = '+', results))
                         },

                         #' @description
                         #' Construct the residual matrix using given the `V` vector and `d` so far
                         #' @param v the value for `v`
                         #' @param d the value for `d`
                         #' @return result
                         fixFit = function(v, d) {
                             ## Append the newly found component to the result
                             result <- private$result
                             result$v <- cbind(result$v, v)
                             result$d <- c(result$d, d)

                             ## Ask every site to residualize its working matrix
                             sites <- private$sites
                             n <- length(sites)
                             if (private$dry_run) {
                                 mapFn <- function(x, v) x$worker$fixU(v)
                                 Map(mapFn, sites, rep(list(v), n))
                             } else {
                                 Map(private$mapFn, sites, rep(list(v), n), rep(list("fixU"), n))
                             }
                             private$result <- result
                         },

                         #' @description
                         #' Reset the computation state by initializing work matrix and set up starting values for iterating
                         reset = function() {
                             private$result <- list()
                             sites <- private$sites
                             n <- length(sites)
                             if (private$dry_run) {
                                 lapply(sites, function(x) x$worker$reset())
                             } else {
                                 ## The payload carries no `arg`, hence a
                                 ## dedicated function rather than private$mapFn
                                 mapFn <- function(site, method) {
                                     payload <- list(objectId = site$instanceId,
                                                     method = method)
                                     q <- httr::POST(.makeOpencpuURL(urlPrefix=site$url, fn="executeMethod"),
                                                     body = jsonlite::toJSON(payload),
                                                     httr::add_headers("Content-Type" = "application/json"),
                                                     config = getConfig()$sslConfig
                                                     )
                                 }
                                 Map(mapFn, sites, rep(list("reset"), n))
                             }
                         },

                         #' @description
                         #' Add a url or worker object for a site for participating in the distributed computation. The worker object can be used to avoid complications in debugging remote calls during prototyping.
                         #' @param name name of the site
                         #' @param url web url of the site; exactly one of `url` or `worker` should be specified
                         #' @param worker worker object for the site; exactly one of `url` or `worker` should be specified
                         addSite = function(name, url = NULL, worker = NULL) {
                             ## critical section start
                             ## This is the time to cache "p" and check it
                             ## against all added sites
                             ## Exactly one of url/worker must be non-null; a
                             ## site with neither cannot take part in any call
                             if (is.null(url) == is.null(worker)) {
                                 stop("addSite(): exactly one of url or worker must be specified")
                             }
                             n <- length(private$sites)
                             if (is.null(url)) {
                                 private$dry_run <- TRUE
                                 private$sites[[n+1]] <- list(name = name, worker = worker)
                             } else {
                                 localhost <- (grepl("^http://localhost", url) ||
                                               grepl("^http://127\\.0\\.0\\.1", url))
                                 private$sites[[n+1]] <- list(name = name, url = url,
                                                              localhost = localhost,
                                                              dataFileName = if (localhost) paste0(name, ".rds") else NULL)
                             }
                             ## critical section end
                         },

                         #' @description
                         #' Run the distributed SVD computation and return the estimates
                         #' @param thr the threshold for convergence, default 1e-8
                         #' @param max.iter the maximum number of iterations, default 100
                         #' @return a named list of `V`, `d`
                         run = function(thr = 1e-8, max.iter = 100) {
                             dry_run <- private$dry_run
                             defn <- private$defn
                             debug <- private$debug
                             sites <- private$sites
                             n <- length(sites)
                             stopifnot(n > 1)
                             ## At least one iteration is needed so that the
                             ## norm handed to fixFit() below is defined
                             stopifnot(is.numeric(max.iter), length(max.iter) == 1, max.iter >= 1)
                             if (debug) {
                                 print("run(): checking worker object creation")
                             }

                             if (dry_run) {
                                 ## Workers have already been created and passed
                                 pVals <- sapply(sites, function(x) x$worker$getP())
                             } else {
                                 ## Make remote call to instantiate workers
                                 instanceId <- generateId(object=list(Sys.time(), self))
                                 ## Augment each site with object instance ids
                                 private$sites <- sites <- lapply(private$sites,
                                                                  function(x) list(name = x$name,
                                                                                   url = x$url,
                                                                                   localhost = x$localhost,
                                                                                   dataFileName = x$dataFileName,
                                                                                   instanceId = if (x$localhost) x$name else instanceId))
                                 sitesOK <- sapply(sites,
                                                   function(x) {
                                                       payload <- if (is.null(x$dataFileName)) {
                                                                      list(defnId = defn$id, instanceId = x$instanceId)
                                                                  } else {
                                                                      list(defnId = defn$id, instanceId = x$instanceId,
                                                                           dataFileName = x$dataFileName)
                                                                  }
                                                       q <- httr::POST(url = .makeOpencpuURL(urlPrefix=x$url, fn="createWorkerInstance"),
                                                                       body = jsonlite::toJSON(payload),
                                                                       httr::add_headers("Content-Type" = "application/json"),
                                                                       config = getConfig()$sslConfig
                                                                       )
                                                       .deSerialize(q)
                                                   })

                                 ## I am not checking the value of p here; I do it later below
                                 if (!all(sitesOK)) {
                                     warning("run():  Some sites did not respond successfully!")
                                     configuredSites <- sites
                                     sites <- sites[which(sitesOK)]  ## Only use sites that created objects successfully.
                                     ## reset/updateU/updateV/fixFit iterate over
                                     ## private$sites, so restrict it for the
                                     ## duration of this run and put the full
                                     ## list back afterwards
                                     private$sites <- sites
                                     on.exit(private$sites <- configuredSites, add = TRUE)
                                 }
                                 ## stop if no sites
                                 if (length(sites) == 0) {
                                     stop("run(): No sites responded successfully! Stopping!")
                                 }
                                 if (debug) {
                                     print("run(): checking p")
                                 }

                                 pVals <- sapply(sites,
                                                 function(x) {
                                                     payload <- list(objectId = x$instanceId, method = "getP")
                                                     q <- httr::POST(.makeOpencpuURL(urlPrefix=x$url, fn="executeMethod"),
                                                                     body = jsonlite::toJSON(payload),
                                                                     httr::add_headers("Content-Type" = "application/json"),
                                                                     config=getConfig()$sslConfig
                                                                     )
                                                     .deSerialize(q)
                                                 })
                             }
                             if (debug) {
                                 print(pVals)
                             }
                             ## All sites must report the same number of columns
                             p <- pVals[1]
                             if (any(pVals != p)) {
                                 stop("run(): Heterogeneous sites! Stopping!")
                             }

                             self$reset()
                             k <- private$defn$rank
                             ## One pass of the outer loop per requested component
                             for (j in seq_len(k)) {
                                 v <-  rep(1.0, p)
                                 vold <- rep(0.0, p)
                                 for (i in seq_len(max.iter)) {
                                     unorm <- self$updateU(v)
                                     v <- self$updateV(unorm) # computes vd
                                     discrepancy <- max(abs(v - vold))
                                     if (debug) {
                                         print(paste("Iteration:", i, "; Discrepancy: ", discrepancy))
                                     }
                                     if (discrepancy < thr) break
                                     vold <- v
                                 }
                                 ## Record the component and residualize at every site
                                 self$fixFit(v, unorm)
                             }

                             if (!dry_run) {
                                 if (debug) {
                                     print("run(): checking worker object cleanup")
                                 }
                                 sitesOK <- sapply(sites,
                                                   function(x) {
                                                       payload <- list(instanceId = x$instanceId)
                                                       q <- httr::POST(url = .makeOpencpuURL(urlPrefix=x$url, fn="destroyInstanceObject"),
                                                                       body = jsonlite::toJSON(payload),
                                                                       httr::add_headers("Content-Type" = "application/json"),
                                                                       config=getConfig()$sslConfig
                                                                       )
                                                       .deSerialize(q)
                                                   })
                                 if (!all(sitesOK)) {
                                     warning("run():  Some sites did not clean up successfully!")
                                 }
                             }
                             private$result
                         },
                         #' @description
                         #' Return the summary result
                         #' @return a named list of `V`, `d`
                         summary = function() {
                             result <- private$result
                             if (length(result) == 0) {
                                 stop ("Run the computation first using run()")
                             }
                             result
                         }
                     )
                     )
## hrpcisd/distcomp documentation built on Feb. 14, 2023, 4:56 p.m.