## R/querycount.R

#' R6 worker object for use as a worker with master objects generated by [QueryCountMaster()]
#'
#' @description `QueryCountWorker` objects are worker objects at each site of
#' a distributed QueryCount model computation
#'
#' @seealso [QueryCountMaster()] which goes hand-in-hand with this object
#'
#' @importFrom R6 R6Class
#' @export
QueryCountWorker <- R6::R6Class(
    "QueryCountWorker",

    private = list(
        ## the computation definition; `queryCount()` reads its
        ## `filterCondition` element
        defn = NA,

        ## the local data
        data = NA,

        ## the stateful flag
        stateful = FALSE
    )
  , public = list(

        #' @description
        #' Create a new `QueryCountWorker` object.
        #' @param defn the computation definition
        #' @param data the local data
        #' @param stateful the statefulness flag, default `FALSE`
        #' @return a new `QueryCountWorker` object
        initialize = function(defn, data, stateful = FALSE) {
            private$defn <- defn
            private$data <- data
            ## Record the flag so that getStateful() reports what the
            ## caller asked for rather than always the field default.
            private$stateful <- stateful
            stopifnot(self$kosher())
        },

        #' @description
        #' Retrieve the value of the `stateful` field
        #' @return the value of the private `stateful` flag
        getStateful = function() {
            private$stateful
        },

        #' @description
        #' Check if inputs and state of object are sane. For future use
        #' @return `TRUE` or `FALSE`
        kosher = function() {
            TRUE
        },

        #' @description
        #' Return the query count on the local data
        #' @return the number of rows of the local data that satisfy the
        #'     `filterCondition` of the computation definition
        #' @import rlang
        #' @importFrom dplyr filter
        #' @importFrom magrittr %>%
        queryCount = function() {
            ## The filter condition is stored as a string; parse it into a
            ## single expression (an error is raised if it is not exactly
            ## one) and splice it into the filter call.
            filter_expr <- rlang::parse_expr(private$defn$filterCondition)
            private$data %>%
                dplyr::filter(!!filter_expr) %>%
                nrow()
        }
    )
)

#' Create a master object to control worker objects generated by [QueryCountWorker()]
#'
#' @description `QueryCountMaster` objects instantiate and run a distributed query count computation
#'
#' @docType class
#' @seealso [QueryCountWorker()] which goes hand-in-hand with this object
#' @importFrom R6 R6Class
#' @importFrom httr POST add_headers
#' @importFrom jsonlite toJSON
#' @export
QueryCountMaster <- R6::R6Class(
    "QueryCountMaster",
    private = list(
        ## the computation definition
        defn = NA,
        ## TRUE once any site has been added as a local worker object
        dry_run = FALSE,
        ## the list of participating sites
        sites = list(),
        ## Ask one remote site to execute `queryCount` on its worker
        ## instance and return the deserialized response.
        mapFn = function(site) {
            payload <- list(objectId = site$instanceId,
                            method = "queryCount")
            q <- httr::POST(.makeOpencpuURL(urlPrefix = site$url, fn = "executeMethod"),
                            body = jsonlite::toJSON(payload),
                            httr::add_headers("Content-Type" = "application/json"),
                            config = getConfig()$sslConfig
                            )
            ## Should really examine result here.
            .deSerialize(q)
        },
        result_cache = list(),
        ## A flag for debugging
        debug = FALSE
    ),

    public = list(

        #' @description
        #' Create a new `QueryCountMaster` object.
        #' @param defn the computation definition
        #' @param debug a flag for debugging, default `FALSE`
        #' @return a new `QueryCountMaster` object
        initialize = function(defn, debug = FALSE) {
            private$defn <- defn
            private$debug <- debug
            stopifnot(self$kosher())
        },

        #' @description
        #' Check if inputs and state of object are sane. For future use
        #' @return `TRUE` or `FALSE`
        kosher = function() {
            TRUE
        },

        #' @description
        #' Run the distributed query count and return the result
        #' @return the count
        queryCount = function() {
            ## In a dry run every site carries a local worker object;
            ## otherwise each site is reached over HTTP.
            mapFn <- if (private$dry_run) {
                         function(x) x$worker$queryCount()
                     } else {
                         private$mapFn
                     }
            results <- Map(mapFn, private$sites)
            value <- Reduce(f = sum, results)
            if (private$debug) {
                print("value")
                print(value)
            }
            value
        },

        #' @description
        #' Retrieve the value of the private `sites` field
        #' @return the list of sites added so far
        getSites = function() {
            private$sites
        },

        #' @description
        #' Add a url or worker object for a site for participating in the distributed computation. The worker object can be used to avoid complications in debugging remote calls during prototyping.
        #' @param name of the site
        #' @param url web url of the site; exactly one of `url` or `worker` should be specified
        #' @param worker worker object for the site; exactly one of `url` or `worker` should be specified
        addSite = function(name, url = NULL, worker = NULL) {
            ## critical section start
            ## Exactly one of url/worker must be supplied: a site with
            ## neither could never be queried.
            stopifnot(xor(is.null(url), is.null(worker)))
            n <- length(private$sites)
            if (is.null(url)) {
                private$dry_run <- TRUE
                private$sites[[n + 1L]] <- list(name = name, worker = worker)
            } else {
                ## Literal prefix match, so the dots in the IP address are
                ## not treated as regex wildcards.
                localhost <- (startsWith(url, "http://localhost") ||
                              startsWith(url, "http://127.0.0.1"))
                private$sites[[n + 1L]] <- list(name = name, url = url,
                                                localhost = localhost,
                                                dataFileName = if (localhost) paste0(name, ".rds") else NULL)
            }
            ## critical section end
        },

        #' @description
        #' Run the distributed query count
        #' @return the count
        run = function() {
            dry_run <- private$dry_run
            defn <- private$defn
            debug <- private$debug
            if (debug) {
                print("run(): checking worker object creation")
            }
            ## In a dry run the workers were already created and passed in
            ## via addSite(), so there is nothing to instantiate.
            if (!dry_run) {
                ## Create an instance Id
                ## Make remote call to instantiate workers
                instanceId <- generateId(object = list(Sys.time(), self))
                ## Augment each site with object instance ids
                private$sites <- sites <- lapply(private$sites,
                                                 function(x) list(name = x$name,
                                                                  url = x$url,
                                                                  localhost = x$localhost,
                                                                  dataFileName = x$dataFileName,
                                                                  instanceId = if (x$localhost) x$name else instanceId))
                ## NOTE(review): sapply() is kept because the shape of the
                ## value returned by .deSerialize() is not visible here.
                sitesOK <- sapply(sites,
                                  function(x) {
                                      payload <- if (is.null(x$dataFileName)) {
                                                     list(defnId = defn$id, instanceId = x$instanceId)
                                                 } else {
                                                     list(defnId = defn$id, instanceId = x$instanceId,
                                                          dataFileName = x$dataFileName)
                                                 }
                                      q <- httr::POST(url = .makeOpencpuURL(urlPrefix = x$url, fn = "createWorkerInstance"),
                                                      body = jsonlite::toJSON(payload),
                                                      httr::add_headers("Content-Type" = "application/json"),
                                                      config = getConfig()$sslConfig
                                                      )
                                      .deSerialize(q)
                                  })

                ## Stop on error
                if (!all(sitesOK)) {
                    stop("run():  Some sites did not respond successfully!")
                }
            }
            result <- self$queryCount()

            if (!dry_run) {
                if (debug) {
                    print("run(): checking worker object cleanup")
                }
                sitesOK <- sapply(sites,
                                  function(x) {
                                      payload <- list(instanceId = x$instanceId)
                                      q <- httr::POST(url = .makeOpencpuURL(urlPrefix = x$url, fn = "destroyInstanceObject"),
                                                      body = jsonlite::toJSON(payload),
                                                      httr::add_headers("Content-Type" = "application/json"),
                                                      config = getConfig()$sslConfig
                                                      )
                                      .deSerialize(q)
                                  })
                if (!all(sitesOK)) {
                    warning("run():  Some sites did not clean up successfully!")
                }
            }
            result
        }
    )
)

#' Create a homomorphic computation query count worker object for use with master objects generated by [HEQueryCountMaster()]
#' @description `HEQueryCountWorker` objects are worker objects at each site of
#' a distributed query count model computation using homomorphic encryption
#'
#' @seealso [HEQueryCountMaster()] which goes hand-in-hand with this object
#' @importFrom R6 R6Class
#' @export
HEQueryCountWorker <- R6::R6Class(
    "HEQueryCountWorker",
    inherit = QueryCountWorker,

    private = list(
        ## the result cache for saving parts of answers to NCP1 and NCP2
        result_cache = list()
    ),
    public = list(

        #' @field pubkey the master's public key visible to everyone
        pubkey = NA,

        #' @field den the denominator for rational arithmetic
        den = NA,

        #' @description
        #' Create a new `HEQueryCountWorker` object.
        #' @param defn the computation definition
        #' @param data the local data
        #' @param pubkey_bits the number of bits in public key
        #' @param pubkey_n the `n` for the public key
        #' @param den_bits the number of bits in the denominator (power of 2) used in rational approximations
        #' @importFrom homomorpheR PaillierPublicKey
        #' @return a new `HEQueryCountWorker` object
        initialize = function(defn, data, pubkey_bits = NULL, pubkey_n = NULL, den_bits = NULL) {
            private$defn <- defn
            private$data <- data
            private$stateful <- TRUE
            ## Delegate to setParams() so that a `pubkey_n` arriving as a
            ## character string is converted the same way in both places.
            ## With `pubkey_bits = NULL` this leaves pubkey/den untouched.
            self$setParams(pubkey_bits = pubkey_bits, pubkey_n = pubkey_n, den_bits = den_bits)
            stopifnot(self$kosher())
        },

        ## For debugging only
        ## setResultsCache = function(value) {
        ##     private$results_cache  <- value
        ## },
        ## getResultsCache = function() {
        ##     private$result_cache
        ## },

        #' @description
        #' Set some parameters for homomorphic computations
        #' @param pubkey_bits the number of bits in public key
        #' @param pubkey_n the `n` for the public key
        #' @param den_bits the number of bits in the denominator (power of 2) used in rational approximations
        #' @return `TRUE`
        setParams = function(pubkey_bits, pubkey_n, den_bits) {
            if (is.character(pubkey_n)) { ## we serialize stuff to character over the wire...
                pubkey_n <- gmp::as.bigz(pubkey_n)
            }
            if (!is.null(pubkey_bits)) {
                self$pubkey <- homomorpheR::PaillierPublicKey$new(pubkey_bits, pubkey_n)
                self$den <- gmp::as.bigq(2)^(den_bits)  # Our denominator for rational approximations
            }
            TRUE
        },

        #' @description
        #' Run the query count on local data and return the appropriate encrypted result to the party
        #' @param partyNumber the NCP party number (1 or 2)
        #' @param token a token to use for identifying parts of the same computation for NCP1 and NCP2
        #' @return the count as a list of encrypted items with components `int` and `frac`
        #' @importFrom homomorpheR random.bigz
        queryCount = function(partyNumber, token) {
            result <- private$result_cache[[token]]
            if (is.null(result)) {
                ## we have to compute result for token
                pubkey <- self$pubkey
                if (identical(pubkey, NA)) {
                    ## Fail with a readable message instead of the
                    ## "$ operator is invalid" error from NA$encrypt.
                    stop("queryCount(): public key is not set; call setParams() first")
                }
                result.int <- super$queryCount()
                ## Generate random offset for int and frac parts
                ## making nBits 256 ensures that x-r and x+r are both positive if x is 1024 bits
                offset <- homomorpheR::random.bigz(nBits = 256)
                ## For debugging set offset to zero
                ##offset  <- 0
                zero <- pubkey$encrypt(0)
                ## The two parties get the count shifted by -offset and
                ## +offset respectively; both shares are cached under the
                ## token so the second party's call reuses the same offset.
                ncp1Result <- list(int = pubkey$encrypt(result.int - offset), frac = zero)
                ncp2Result <- list(int = pubkey$encrypt(result.int + offset), frac = zero)
                result <- list(ncp1 = ncp1Result, ncp2 = ncp2Result)
                private$result_cache[[token]] <- result
            }
            if (partyNumber == 1) result$ncp1 else result$ncp2
        }
    )
)

#' Create a homomorphic computation query count master object to employ worker objects generated by [HEQueryCountWorker()]
#' @description `HEQueryCountMaster` objects instantiate and run a distributed homomorphic query count computation; they're instantiated by non-cooperating parties (NCPs)
#'
#' @seealso [HEQueryCountWorker()] which goes hand-in-hand with this object
#' @importFrom R6 R6Class
#' @importFrom httr POST add_headers
#' @importFrom jsonlite toJSON
#' @export
HEQueryCountMaster <- R6::R6Class(
    "HEQueryCountMaster",
    inherit = QueryCountMaster,
    private = list(
        ## the party number of the NCP this object belongs to
        partyNumber = NA,
        ## the computation definition
        defn = NA,
        ## the mapping function that runs the HE query count method on each site
        mapFn = function(site, token) {
            payload <- list(objectId = site$instanceId,
                            method = "queryCount", partyNumber = private$partyNumber,
                            token = token)
            q <- httr::POST(.makeOpencpuURL(urlPrefix = site$url, fn = "executeHEMethod"),
                            body = jsonlite::toJSON(payload),
                            httr::add_headers("Content-Type" = "application/json"),
                            config = getConfig()$sslConfig
                            )
            ## Should really examine result here.
            .deSerialize(q)
        },
        ## A flag for debugging
        debug = FALSE
    ),
    public = list(

        #' @field pubkey the master's public key visible to everyone
        pubkey = NA,

        #' @field pubkey_bits the number of bits in the public key (used for reconstructing public key remotely by serializing to character)
        pubkey_bits = NA,

        #' @field pubkey_n the `n` for the public key used for reconstructing public key remotely
        pubkey_n = NA,

        #' @field den the denominator for rational arithmetic
        den = NA,

        #' @field den_bits the number of bits in the denominator used for reconstructing denominator remotely
        den_bits = NA,

        #' @description
        #' Create a new `HEQueryCountMaster` object.
        #' @param defn the computation definition
        #' @param partyNumber the party number of the NCP that this object belongs to (1 or 2)
        #' @param debug a flag for debugging, default `FALSE`
        #' @return a new `HEQueryCountMaster` object
        initialize = function(defn, partyNumber, debug = FALSE) {
            private$defn <- defn
            private$partyNumber <- partyNumber
            private$debug <- debug
            stopifnot(self$kosher())
        },

        #' @description
        #' Set some parameters of the `HEQueryCountMaster` object for homomorphic computations
        #' @param pubkey_bits the number of bits in public key
        #' @param pubkey_n the `n` for the public key
        #' @param den_bits the number of bits in the denominator (power of 2) used in rational approximations
        #' @return `TRUE`
        setParams = function(pubkey_bits, pubkey_n, den_bits) {
            if (is.character(pubkey_n)) { ## we serialize stuff to character over the wire...
                pubkey_n <- gmp::as.bigz(pubkey_n)
            }
            self$pubkey_bits <- pubkey_bits
            self$pubkey_n <- pubkey_n
            self$den_bits <- den_bits
            self$pubkey <- homomorpheR::PaillierPublicKey$new(pubkey_bits, pubkey_n)
            self$den <- gmp::as.bigq(2)^(den_bits)  # Our denominator for rational approximations
            TRUE
        },

        #' @description
        #' Check if inputs and state of object are sane. For future use
        #' @return `TRUE` or `FALSE`
        kosher = function() {
            TRUE
        },

        #' @description
        #' Run the distributed query count, associate it with a token, and return the result
        #' @param token a token to use as key
        #' @return the partial result as a list of encrypted items with components `int` and `frac`
        #' @importFrom gmp as.bigz
        queryCount = function(token) {
            pubkey <- self$pubkey
            if (identical(pubkey, NA)) {
                ## Fail before contacting any site, and with a readable
                ## message rather than the "$ operator is invalid" error.
                stop("queryCount(): public key is not set; call setParams() first")
            }
            mapFn <- if (private$dry_run) {
                         function(site, token) site$worker$queryCount(private$partyNumber, token)
                     } else {
                         private$mapFn
                     }
            ## Every site is queried with the same token.
            results <- lapply(private$sites, function(site) mapFn(site, token))
            zero <- pubkey$encrypt(0)
            ## The results could arrive as strings over the wire, so convert
            intResults <- lapply(results, function(x) gmp::as.bigz(x$int))
            fracResults <- lapply(results, function(x) gmp::as.bigz(x$frac))
            ## Combine the per-site values with the public key's `add`
            ## method, starting from an encryption of zero.
            intSum <- Reduce(f = pubkey$add, x = intResults, init = zero)
            fracSum <- Reduce(f = pubkey$add, x = fracResults, init = zero)
            list(int = intSum, frac = fracSum)
        },

        #' @description
        #' Cleanup the instance objects
        #' @return `TRUE`, invisibly
        #' @importFrom httr POST add_headers
        #' @importFrom jsonlite toJSON
        cleanup = function() {
            if (!private$dry_run) {
                ## Sites have already been augmented with instanceId in run method, after which
                ## this should be called
                sitesOK <- sapply(private$sites,
                                  function(x) {
                                      payload <- list(instanceId = x$instanceId)
                                      q <- httr::POST(url = .makeOpencpuURL(urlPrefix = x$url, fn = "destroyInstanceObject"),
                                                      body = jsonlite::toJSON(payload),
                                                      httr::add_headers("Content-Type" = "application/json"),
                                                      config = getConfig()$sslConfig
                                                      )
                                      .deSerialize(q)
                                  })
                if (!all(sitesOK)) {
                    warning("cleanup():  Some sites did not clean up successfully!")
                }
            }
            invisible(TRUE)
        },

        #' @description
        #' Run the homomorphic encrypted distributed query count computation
        #' @param token a token to use as key
        #' @return the partial result as a list of encrypted items with components `int` and `frac`
        #' @importFrom httr POST add_headers
        #' @importFrom jsonlite toJSON
        run = function(token) {
            defn <- private$defn
            if (private$dry_run) {
                ## Workers have already been created and passed;
                ## we just need to set params on them
                for (site in private$sites) {
                    site$worker$setParams(pubkey_bits = self$pubkey_bits,
                                          pubkey_n = self$pubkey_n,
                                          den_bits = self$den_bits)
                }
            } else {
                ## Instance ID for HE method is just the token!
                instanceId <- token
                ## Augment each site with object instance ids
                private$sites <- sites <- lapply(private$sites,
                                                 function(x) list(name = x$name,
                                                                  url = x$url,
                                                                  localhost = x$localhost,
                                                                  dataFileName = x$dataFileName,
                                                                  instanceId = if (x$localhost) x$name else instanceId))
                ## The public key modulus is sent as a character string;
                ## the worker side converts it back.
                keyParams <- list(pubkey_bits = self$pubkey_bits,
                                  pubkey_n = as.character(self$pubkey_n),
                                  den_bits = self$den_bits)
                sitesOK <- sapply(sites,
                                  function(x) {
                                      ids <- list(defnId = defn$id, instanceId = x$instanceId)
                                      if (!is.null(x$dataFileName)) {
                                          ids <- c(ids, list(dataFileName = x$dataFileName))
                                      }
                                      payload <- c(ids, keyParams)
                                      q <- httr::POST(url = .makeOpencpuURL(urlPrefix = x$url, fn = "createHEWorkerInstance"),
                                                      body = jsonlite::toJSON(payload),
                                                      httr::add_headers("Content-Type" = "application/json"),
                                                      config = getConfig()$sslConfig
                                                      )
                                      .deSerialize(q)
                                  })

                ## Stop on error
                if (!all(sitesOK)) {
                    stop("run():  Some sites did not respond successfully!")
                }
            }
            self$queryCount(token)
        }
    )
)

## Try the distcomp package in your browser
##
## Any scripts or data that you put into this service are public.
##
## distcomp documentation built on Sept. 2, 2022, 1:07 a.m.