R/Wrapper_VectorDistribution.R

Defines functions length.VectorDistribution as.VectorDistribution

Documented in as.VectorDistribution length.VectorDistribution

#' @name VectorDistribution
#' @title Vectorise Distributions
#' @description A wrapper for creating a vector of distributions.
#' @template class_vecdist
#' @template method_wrappedModels
#' @template method_mode
#' @template method_kurtosis
#' @template method_entropy
#' @template method_pgf
#' @template method_mgfcf
#' @template param_paramid
#' @template param_log
#' @template param_logp
#' @template param_simplify
#' @template param_data
#' @template param_lowertail
#' @template param_n
#' @template param_decorators
#' @template param_ids
#'
#' @details A vector distribution is intended to vectorize distributions more efficiently than
#' storing a list of distributions. To improve speed and reduce memory usage, distributions are
#' only constructed when methods (e.g. d/p/q/r) are called.
#'
#' @export
VectorDistribution <- R6Class("VectorDistribution",
  inherit = DistributionWrapper,
  lock_objects = FALSE,
  lock_class = FALSE,
  public = list(
    #' @description
    #' Creates a new instance of this [R6][R6::R6Class] class.
    #' @param ... Passed to the parent class constructor unless `vecdist` is supplied.
    #'
    #' @examples
    #' \dontrun{
    #' VectorDistribution$new(
    #'   distribution = "Binomial",
    #'   params = list(
    #'     list(prob = 0.1, size = 2),
    #'     list(prob = 0.6, size = 4),
    #'     list(prob = 0.2, size = 6)
    #'   )
    #' )
    #'
    #' VectorDistribution$new(
    #'   distribution = "Binomial",
    #'   params = data.table::data.table(prob = c(0.1, 0.6, 0.2), size = c(2, 4, 6))
    #' )
    #'
    #' # Alternatively
    #' VectorDistribution$new(
    #'   list(
    #'   Binomial$new(prob = 0.1, size = 2),
    #'   Binomial$new(prob = 0.6, size = 4),
    #'   Binomial$new(prob = 0.2, size = 6)
    #'   )
    #' )
    #' }
    initialize = function(distlist = NULL, distribution = NULL, params = NULL,
                          shared_params = NULL, name = NULL, short_name = NULL,
                          decorators = NULL, vecdist = NULL, ids = NULL, ...) {

      # Three construction routes are handled below, checked in this order:
      #   1. `vecdist`: build from an existing wrapper, or from a list of them;
      #   2. `distribution` + `params` (taken when `distlist` is NULL);
      #   3. `distlist`: wrap distributions that have already been constructed.
      # Routes 2 and 3 share the naming logic and the call to the parent
      # constructor at the end of the method; `...` is only forwarded there.

      if (!is.null(ids)) {
        # 3.6 fix
        ids <- as.character(assert_alphanum(ids))
      }

      #-----------------
      # Decorate wrapper
      #-----------------
      if (!is.null(decorators)) {
        suppressMessages(decorate(self, decorators))
      }

      if (!is.null(vecdist)) {

        if (checkmate::testList(vecdist)) {

          #----------------------------------
          # list of vecdists constructor
          #----------------------------------
          # the distribution class is read from the first element only
          dist <- as.character(unlist(vecdist[[1]]$modelTable$Distribution[[1]]))
          if (is.null(ids)) {
            # default ids: class short name followed by a running index over
            # every row of every supplied modelTable
            ids <- paste0(
              get(dist)$public_fields$short_name,
              seq.int(sum(sapply(vecdist, function(.x) nrow(.x$modelTable)))))
          } else {
            checkmate::assertCharacter(ids, unique = TRUE)
          }

          private$.modelTable <- as.data.table(data.frame(Distribution = dist,
                                                          shortname = ids))
          private$.distlist <- FALSE
          private$.univariate <- vecdist[[1]]$.__enclos_env__$private$.univariate
          # need to recopy function to prevent referencing error
          # (loop variable deliberately not called `which`, which would mask base::which)
          for (fn_name in c(".pdf", ".cdf", ".quantile", ".rand")) {
            private[[fn_name]] <- function() {}
            formals(private[[fn_name]]) <-
              formals(vecdist[[1]]$.__enclos_env__$private[[fn_name]])
            body(private[[fn_name]]) <-
              body(vecdist[[1]]$.__enclos_env__$private[[fn_name]])
          }

          ## TODO: This is very messy, not too slow but probably inefficient
          # collect one un-prefixed list of parameter values per wrapped distribution
          params <- unlist(lapply(vecdist, function(.x) {
            vals <- .x$parameters()$values
            pf <- unique(get_prefix(names(vals)))
            lapply(pf, function(i) {
              v <- vals[grepl(i, names(vals))]
              names(v) <- unprefix(names(v))
              v
            })
          }), FALSE)
          parameters <- do.call(get(paste0("getParameterSet.", dist)),
                                c(params[[1]], shared_params))
          parameters$rep(length(params), prefix = ids)
          names(params) <- ids
          params <- unlist(params, recursive = FALSE)
          # unlist() joins the id and the parameter name with "."; the
          # parameter set uses "__" as the separator instead
          names(params) <- gsub(".", "__", names(params), fixed = TRUE)
          parameters$values <- params

          super$initialize(
            distlist = if (vecdist[[1]]$distlist)
              unlist(lapply(vecdist, function(.x) .x$wrappedModels()), recursive = FALSE) else NULL,
            name = paste0("Vector: ", length(ids), " ", dist, "s"),
            short_name = paste0("Vec", length(ids), get(dist)$public_fields$short_name),
            description = paste0("Vector of ", length(ids), " ", dist, "s"),
            support = do.call(setproduct, lapply(vecdist, function(.x) .x$properties$support)),
            type = do.call(setproduct, lapply(vecdist, function(.x) .x$traits$type)),
            valueSupport = vecdist[[1]]$traits$valueSupport,
            variateForm = "multivariate",
            parameters = parameters
          )

          invisible(self)

        } else {

          #----------------------------------
          # single vecdist constructor
          #----------------------------------
          # copy the private state of the supplied wrapper as-is
          private$.modelTable <- vecdist$modelTable
          private$.distlist <- vecdist$distlist
          private$.univariate <- vecdist$.__enclos_env__$private$.univariate
          private$.pdf <- vecdist$.__enclos_env__$private$.pdf
          private$.cdf <- vecdist$.__enclos_env__$private$.cdf
          private$.quantile <- vecdist$.__enclos_env__$private$.quantile
          private$.rand <- vecdist$.__enclos_env__$private$.rand

          parameters  <- vecdist$parameters()

          if (checkmate::testClass(vecdist, "MixtureDistribution")) {
            # drop the parameters carrying the "mix" prefix
            parameters$remove(prefix = "mix")
          }

          super$initialize(
            distlist = distlist,
            name = vecdist$name,
            short_name = vecdist$short_name,
            description = vecdist$description,
            support = vecdist$properties$support,
            type = vecdist$traits$type,
            valueSupport = vecdist$traits$valueSupport,
            variateForm = "multivariate",
            parameters = parameters
          )

          # relabel a Product/Mixture wrapper as a Vector
          if (is.null(name)) self$name <- gsub("Product|Mixture", "Vector", self$name)
          if (is.null(short_name)) self$short_name <- gsub("Prod|Mix", "Vec", self$short_name)
          self$description <- gsub("Product|Mixture", "Vector", self$description)

          invisible(self)
        }

      } else {

        #----------------------------------
        # distribution + params constructor
        #----------------------------------
        if (is.null(distlist)) {
          # scalar condition, so use the short-circuiting operator
          if (is.null(distribution) || is.null(params)) {
            stop("Either distlist or distribution and params must be provided.")
          }

          distribution <- match.arg(distribution, c(listDistributions(simplify = TRUE),
                                                    listKernels(simplify = TRUE)))

          if (grepl("Empirical|Matdist", distribution)) {
            stop("Matdist, Empirical and EmpiricalMV not currently available for `distribution/params`
constructor, use `distlist` instead.")
          }

          # convert params to list
          # NOTE(review): apply() goes through as.matrix(), so a table with
          # mixed column types would be coerced to a common type - confirm
          # callers only pass homogeneous columns.
          if (!checkmate::testList(params)) {
            params <- apply(params, 1, as.list)
          }

          # catch for Geometric and NegativeBinomial
          if (distribution == "Geometric" && "trials" %in% names(unlist(params))) {
            stop("For Geometric distributions either `trials` must be passed to `shared_params`
or `distlist` should be used.")
          }

          if (distribution == "NegativeBinomial" && "form" %in% names(unlist(params))) {
            stop("For NegativeBinomial distributions either `form` must be passed to
`shared_params` or `distlist` should be used.")
          }

          # convert shared_params to list
          if (is.null(shared_params)) {
            shared_params <- list()
          } else {
            if (!checkmate::testList(shared_params)) {
              shared_params <- as.list(shared_params)
            }
          }
          private$.sharedparams <- shared_params

          # create wrapper parameters by cloning distribution parameters and setting by given params
          # skip if no parameters
          pdist <- get(distribution)
          if (is.null(ids)) {
            shortname <- pdist$public_fields$short_name
            shortnames <- NULL
          } else {
            shortname <- shortnames <- ids
          }

          if (!is.null(names(params)) && all(grepl("__", names(params)))) {
            # params are already named "<prefix>__<parameter>": reuse the prefixes
            pf <- unique(get_prefix(names(params)))
            shortnames <- pf
          } else {
            if (is.null(shortnames)) {
              shortnames <- sprintf("%s%d", shortname, seq(length(params)))
            }
            if (length(drop_null(params))) {
              names(params) <- shortnames
              params <- unlist(params, recursive = FALSE)
              names(params) <- gsub(".", "__", names(params), fixed = TRUE)
            }
          }
          lng <- length(shortnames)

          # a missing getParameterSet.<distribution> leaves `parameters` NULL
          parameters <- tryCatch(get(paste0("getParameterSet.", distribution)),
                                error = function(e) NULL)
          if (!is.null(parameters)) {
            parameters <- do.call(parameters, c(params[[1]], shared_params))
            parameters$rep(lng, prefix = shortname)
            parameters$values <- params
          }

          # modelTable is for reference and later
          # construction; coercion to table from frame due to recycling
          private$.modelTable <- as.data.table(
            data.frame(
              Distribution = distribution,
              shortname = shortnames
            )
          )

          # set univariate flag for calling d/p/q/r
          private$.univariate <- pdist$private_fields$.traits$variateForm == "univariate"
          # inheritance catch: fall back to the parent generator; the field is
          # spelled out in full rather than relying on `$` partial matching
          if (!length(private$.univariate)) {
            private$.univariate <-
              pdist$get_inherit()$private_fields$.traits$variateForm == "univariate"
          }
          # set valueSupport
          valueSupport <- pdist$private_fields$.traits$valueSupport
          # inheritance catch
          if (!length(valueSupport)) {
            valueSupport <- pdist$get_inherit()$private_fields$.traits$valueSupport
          }

          # set d/p/q/r if non-NULL
          # Each wrapper body is built with substitute(): FUN is replaced by the
          # body of the matching private method of the wrapped class, which is
          # then installed as the body of the local function `fun`.
          pdist_pri <- pdist[["private_methods"]]
          if (!is.null(pdist_pri[[".pdf"]])) {
            private$.pdf <- function(x1, log) {}
            body(private$.pdf) <- substitute(
              {
                fun <- function(x, log) {}
                body(fun) <- substitute(FUN)

                dpqr <- data.table()
                if (private$.univariate) {
                  if (ncol(x1) == 1) {
                    dpqr <- fun(unlist(x1), log = log)
                  } else if (nrow(x1) == 1) {
                    dpqr <- fun(x1, log = log)
                    if (nrow(dpqr) > 1) {
                      dpqr <- diag(as.matrix(dpqr))
                    }
                  } else {
                    # one column of x1 per wrapped distribution: keep only the
                    # i-th result of each evaluation
                    for (i in seq_len(ncol(x1))) {
                      a_dpqr <- fun(unlist(x1[, i]), log = log)
                      a_dpqr <- if (inherits(a_dpqr, "numeric")) a_dpqr[i] else a_dpqr[, i]
                      dpqr <- cbind(dpqr, a_dpqr)
                    }
                  }
                } else {
                  if (length(dim(x1)) == 2) {
                    dpqr <- data.table(matrix(fun(x1, log = log), nrow = nrow(x1)))
                  } else {
                    for (i in seq_len(dim(x1)[3])) {
                      mx <- x1[, , i]
                      if (inherits(mx, "numeric")) {
                        mx <- matrix(mx, nrow = 1)
                      }
                      a_dpqr <- fun(mx, log = log)
                      a_dpqr <- if (inherits(a_dpqr, "numeric")) a_dpqr[i] else a_dpqr[, i]
                      dpqr <- cbind(dpqr, a_dpqr)
                    }
                  }
                }

                return(dpqr)
              },
              list(FUN = body(pdist_pri[[".pdf"]]))
            )
          }
          if (!is.null(pdist_pri[[".cdf"]])) {
            private$.cdf <- function(x1, lower.tail, log.p) {}
            body(private$.cdf) <- substitute(
              {
                fun <- function(x, lower.tail, log.p) {}
                body(fun) <- substitute(FUN)

                dpqr <- data.table()
                if (private$.univariate) {
                  if (ncol(x1) == 1) {
                    dpqr <- fun(unlist(x1), lower.tail = lower.tail, log.p = log.p)
                  } else if (nrow(x1) == 1) {
                    dpqr <- fun(x1, lower.tail = lower.tail, log.p = log.p)
                    if (nrow(dpqr) > 1) {
                      dpqr <- diag(as.matrix(dpqr))
                    }
                  } else {
                    for (i in seq_len(ncol(x1))) {
                      a_dpqr <- fun(unlist(x1[, i]), lower.tail = lower.tail, log.p = log.p)
                      a_dpqr <- if (inherits(a_dpqr, "numeric")) a_dpqr[i] else a_dpqr[, i]
                      dpqr <- cbind(dpqr, a_dpqr)
                    }
                  }
                }
                # TODO - This will be uncommented once EmpiricalMV can be used here
                # else {
                # for (i in seq(dim(x1)[3])) {
                #   a_dpqr <- fun(unlist(x1[, , i]), lower.tail = lower.tail, log.p = log.p)
                #   a_dpqr <- if (class(a_dpqr)[1] == "numeric") a_dpqr[i] else a_dpqr[, i]
                #   dpqr <- cbind(dpqr, a_dpqr)
                # }
                #}

                return(dpqr)
              },
              list(FUN = body(pdist_pri[[".cdf"]]))
            )
          }
          if (!is.null(pdist_pri[[".quantile"]])) {
            private$.quantile <- function(x1, lower.tail, log.p) {}
            body(private$.quantile) <- substitute(
              {
                fun <- function(p, lower.tail, log.p) {}
                body(fun) <- substitute(FUN)

                dpqr <- data.table()
                if (ncol(x1) == 1) {
                  dpqr <- fun(unlist(x1), lower.tail = lower.tail, log.p = log.p)
                } else if (nrow(x1) == 1) {
                  dpqr <- fun(x1, lower.tail = lower.tail, log.p = log.p)
                  if (nrow(dpqr) > 1) {
                    dpqr <- diag(as.matrix(dpqr))
                  }
                } else {
                  for (i in seq_len(ncol(x1))) {
                    a_dpqr <- fun(unlist(x1[, i]), lower.tail = lower.tail, log.p = log.p)
                    a_dpqr <- if (inherits(a_dpqr, "numeric")) a_dpqr[i] else a_dpqr[, i]
                    dpqr <- cbind(dpqr, a_dpqr)
                  }
                }

                return(dpqr)
              },
              list(FUN = body(pdist_pri[[".quantile"]]))
            )
          }
          if (!is.null(pdist_pri[[".rand"]])) {
            private$.rand <- function(x1) {}
            body(private$.rand) <- substitute(
              {
                fun <- function(n) {}
                body(fun) <- substitute(FUN)

                return(fun(x1))
              },
              list(FUN = body(pdist_pri[[".rand"]]))
            )
          }

          #----------------------------------
          # distlist constructor
          #----------------------------------
        } else {

          # set flag to TRUE
          private$.distlist <- TRUE
          distribution <- c()

          if (is.null(ids)) {
            shortname <- character(0)
          } else {
            shortname <- checkmate::assertCharacter(ids, unique = TRUE)
          }

          # get all parameters in a list
          # assert all variateForm the same
          # collect valueSupport, short_name, name
          vf <- distlist[[1]]$traits$variateForm
          paramlst <- vector("list", length(distlist))
          vs <- distlist[[1]]$traits$valueSupport
          for (i in seq_along(distlist)) {
            stopifnot(distlist[[i]]$traits$variateForm == vf)
            if (is.null(ids)) {
              shortname <- c(shortname, distlist[[i]]$short_name)
            }
            distribution <- c(distribution, distlist[[i]]$name)
            if (!is.null(distlist[[i]]$parameters())) {
              paramlst[[i]] <- distlist[[i]]$parameters()
            }
            vs <- c(vs, distlist[[i]]$traits$valueSupport)
          }

          # a single shared valueSupport is kept, anything else is "mixture"
          valueSupport <- if (length(unique(vs)) == 1) vs[[1]] else "mixture"
          if (is.null(ids)) {
            shortname <- makeUniqueNames(shortname)
          }
          parameters <- NULL

          names(distlist) <- shortname

          private$.univariate <- vf == "univariate"

          # modelTable is for reference and later
          # construction; coercion to table from frame due to recycling
          private$.modelTable <- as.data.table(
            data.frame(Distribution = distribution, shortname = shortname)
          )

          # set dpqr
          # Each closure evaluates the i-th wrapped distribution (`self[i]`) on
          # the i-th column of `x` (univariate) or the i-th slice of the third
          # dimension of `x` (otherwise), binding the results column-wise.
          private$.pdf <- function(x, log = FALSE) {
            dpqr <- data.table()
            if (private$.univariate) {
              for (i in seq_len(ncol(x))) {
                a_dpqr <- self[i]$pdf(x[, i], log = log)
                dpqr <- cbind(dpqr, a_dpqr)
              }
            } else {
              for (i in seq_len(dim(x)[[3]])) {
                a_dpqr <- self[i]$pdf(data = matrix(x[, , i], nrow = nrow(x), ncol = ncol(x)),
                                      log = log)
                dpqr <- cbind(dpqr, a_dpqr)
              }
            }

            return(dpqr)
          }
          private$.cdf <- function(x, lower.tail = TRUE, log.p = FALSE) {
            dpqr <- data.table()
            if (private$.univariate) {
              for (i in seq_len(ncol(x))) {
                a_dpqr <- self[i]$cdf(x[, i], lower.tail = lower.tail, log.p = log.p)
                dpqr <- cbind(dpqr, a_dpqr)
              }
            } else {
              for (i in seq_len(dim(x)[3])) {
                a_dpqr <- self[i]$cdf(data = matrix(x[, , i], nrow = nrow(x), ncol = ncol(x)),
                                      lower.tail = lower.tail, log.p = log.p)
                dpqr <- cbind(dpqr, a_dpqr)
              }
            }

            return(dpqr)
          }
          private$.quantile <- function(x, lower.tail = TRUE, log.p = FALSE) {
            dpqr <- data.table()
            for (i in seq_len(ncol(x))) {
              a_dpqr <- self[i]$quantile(x[, i], lower.tail = lower.tail, log.p = log.p)
              dpqr <- cbind(dpqr, a_dpqr)
            }

            return(dpqr)
          }
          private$.rand <- function(n) {
            if (private$.univariate) {
              if (n == 1) {
                # sapply() returns a plain vector for n == 1; reshape it to one row
                return(matrix(sapply(self$wrappedModels(), function(x) x$rand(n)), nrow = 1))
              } else {
                return(sapply(self$wrappedModels(), function(x) x$rand(n)))
              }
            } else {
              return(lapply(self$wrappedModels(), function(x) x$rand(n)))
            }

          }
        }

        # define number of distributions from modelTable
        ndist <- nrow(private$.modelTable)

        # create name, short_name, description, type, support
        dst <- unique(self$modelTable$Distribution)
        # `&&` so that `dst[[1]]` is only evaluated when exactly one class is present
        if (length(dst) == 1 && dst[[1]] %in% c(listDistributions(simplify = TRUE),
                                                listKernels(simplify = TRUE))) {
          distribution <- get(as.character(unlist(self$modelTable[1, 1])))
          if (is.null(name)) {
            name <- paste0(
              "Vector: ", ndist, " ",
              distribution$public_fields$name, "s"
            )
          }
          if (is.null(short_name)) {
            short_name <- paste0(
              "Vec", ndist,
              distribution$public_fields$short_name
            )
          }
          description <- paste0("Vector of ", ndist, " ", distribution$public_fields$name, "s")
          type <- distribution$new()$traits$type^ndist
          # FIXME - support defined as same as type
          support <- type
        } else {
          type <- do.call(setproduct, lapply(distlist, function(x) x$traits$type))
          # FIXME - support defined as same as type
          support <- type

          # name depends on length of distributions, anything over 3 shortened
          if (ndist > 3) {
            if (is.null(name)) name <- paste("Vector:", ndist, "Distributions")
            if (is.null(short_name)) short_name <- paste0("Vec", ndist, "Dists")
            description <- paste0("Vector of ", ndist, " distributions.")
          } else {
            if (is.null(name)) name <- paste("Vector:", paste0(distribution, collapse = ", "))
            if (is.null(short_name)) short_name <- paste0(shortname, collapse = "Vec")
            description <- paste0("Vector of: ", paste0(shortname, collapse = ", "))
          }
        }

        super$initialize(
          distlist = distlist,
          name = name,
          short_name = short_name,
          description = description,
          support = support,
          type = type,
          valueSupport = valueSupport,
          variateForm = "multivariate",
          parameters = parameters,
          ...
        )

      }
    },

    #' @description
    #' Returns the value of the supplied parameter.
    #' @param ... Unused
    getParameterValue = function(id, ...) {
      # Look the parameter up in the wrapper's parameter set.
      out <- private$.parameters$get_values(id)
      labels <- names(out)

      # Unnamed results are handed back untouched.
      if (is.null(labels)) {
        return(out)
      }

      # Otherwise the names are passed through `get_n_prefix()` before returning.
      names(out) <- get_n_prefix(labels)
      out
    },

    #' @description
    #' Returns model(s) wrapped by this wrapper.
    wrappedModels = function(model = NULL) {
      # Returns the wrapped distribution(s).
      #   model = NULL : every wrapped distribution;
      #   otherwise    : the distributions whose shortname is in `model`
      #                  (one or several names).
      # A single match is returned as the bare distribution, several matches
      # as a list named by shortname. Errors if no shortname matches.
      shortnames <- as.character(unlist(private$.modelTable$shortname))
      # class generator, only needed when distributions are built on demand
      dist_name <- as.character(unlist(private$.modelTable$Distribution[[1]]))

      if (is.null(model)) {
        selected <- shortnames
        if (private$.distlist) {
          distlist <- private$.wrappedModels
        } else {
          # construct each distribution from its prefixed parameter values
          distlist <- lapply(selected, function(x) {
            pars <- self$parameters()[prefix = x]$values
            do.call(get(dist_name)$new, c(pars, list(decorators = self$decorators)))
          })
        }
      } else {
        # `%in%` rather than `==`: `==` recycles `model` against the shortname
        # column and picks the wrong rows once `model` has several elements
        selected <- shortnames[shortnames %in% model]

        if (length(selected) == 0) {
          stop(sprintf("No distribution called %s.", paste0(model, collapse = ", ")))
        }

        if (private$.distlist) {
          distlist <- private$.wrappedModels[selected]
        } else {
          distlist <- lapply(selected, function(x) {
            dist <- do.call(get(dist_name)$new, list(decorators = self$decorators))
            dist$setParameterValue(lst = self$parameters()[prefix = x]$values)
            dist
          })
        }
      }

      if (length(distlist) == 1) {
        return(distlist[[1]])
      }

      # name by the distributions actually returned; naming by the full
      # modelTable fails with a length mismatch for a proper subset
      names(distlist) <- selected
      distlist
    },

    #' @description
    #' Printable string representation of the `VectorDistribution`. Primarily used internally.
    #' @param n `(integer(1))`\cr
    #' Number of distributions to include when printing.
    strprint = function(n = 10) {
      # Shortnames of the wrapped distributions. When there are more than
      # `2 * n` of them, only the first `n` and last `n` are kept, separated
      # by "...".
      names <- as.character(unlist(self$modelTable$shortname))
      lng <- length(names)
      if (lng > (2 * n)) {
        # seq_len()/length.out instead of `1:n` and `(lng - n + 1):lng`, which
        # count backwards for n = 0 and would return the first name and an NA
        head_idx <- seq_len(n)
        tail_idx <- seq.int(lng - n + 1, length.out = n)
        names <- c(names[head_idx], "...", names[tail_idx])
      }

      names
    },

    #' @description
    #' Returns named vector of means from each wrapped [Distribution].
    #' @param ... Passed to [CoreStatistics]`$genExp` if numeric.
    mean = function(...) {
      if (!self$distlist) {
        # Borrow the `mean` method of the wrapped class (or of its parent) and
        # call it with `self` bound to this wrapper.
        generator <- get(as.character(unlist(self$modelTable$Distribution[[1]])))
        method <- generator$public_methods$mean
        if (is.null(method)) {
          method <- generator$get_inherit()$public_methods$mean
        }
        if (is.null(method)) {
          stop("Not implemented for this distribution.")
        }
        formals(method) <- c(list(self = self), alist(... = )) # nolint
        out <- method()
        # a single value is repeated once per wrapped distribution
        if (length(out) == 1) {
          out <- rep(out, nrow(self$modelTable))
        }
      } else {
        # Evaluate each wrapped distribution in turn; failures become NaN.
        out <- sapply(seq(nrow(private$.modelTable)), function(idx) {
          ifnerror(self[idx]$mean(...), error = NaN)
        })
      }

      labels <- as.character(unlist(private$.modelTable[, "shortname"]))
      if (!is.null(dim(out))) {
        # matrix result: transpose and return one column per distribution
        out <- data.table(t(out))
        colnames(out) <- labels
      } else {
        names(out) <- labels
      }

      out
    },

    #' @description
    #' Returns named vector of modes from each wrapped [Distribution].
    mode = function(which = "all") {
      if (!self$distlist) {
        # Borrow the `mode` method of the wrapped class (or of its parent) and
        # call it with `self` and `which` fixed as defaults.
        generator <- get(as.character(unlist(self$modelTable$Distribution[[1]])))
        method <- generator$public_methods$mode
        if (is.null(method)) {
          method <- generator$get_inherit()$public_methods$mode
        }
        if (is.null(method)) {
          stop("Not implemented for this distribution.")
        }
        formals(method) <- list(self = self, which = which)
        out <- method()
        # a single value is repeated once per wrapped distribution
        if (length(out) == 1) {
          out <- rep(out, nrow(self$modelTable))
        }
      } else {
        # Evaluate each wrapped distribution in turn; failures become NaN.
        out <- sapply(seq(nrow(private$.modelTable)), function(idx) {
          ifnerror(self[idx]$mode(which), error = NaN)
        })
      }

      if (!is.null(dim(out))) {
        # hacky catch for MVN: every other class has its result transposed
        first_dist <- as.character(unlist(self$modelTable$Distribution[[1]]))
        if (first_dist != "MultivariateNormal") {
          out <- t(out)
        }

        out <- data.table(out)
        colnames(out) <- as.character(unlist(private$.modelTable[, "shortname"]))
      } else {
        names(out) <- as.character(unlist(private$.modelTable[, "shortname"]))
      }

      out
    },

    #' @description
    #' Returns named vector of medians from each wrapped [Distribution].
    median = function() {
      # the median is the quantile at probability one half
      prob <- 0.5
      self$quantile(prob)
    },

    #' @description
    #' Returns named vector of variances from each wrapped [Distribution].
    #' @param ... Passed to [CoreStatistics]`$genExp` if numeric.
    variance = function(...) {
      if (!self$distlist) {
        # Borrow the `variance` method of the wrapped class (or of its parent)
        # and call it with `self` bound to this wrapper.
        generator <- get(as.character(unlist(self$modelTable$Distribution[[1]])))
        method <- generator$public_methods$variance
        if (is.null(method)) {
          method <- generator$get_inherit()$public_methods$variance
        }
        if (is.null(method)) {
          stop("Not implemented for this distribution.")
        }
        formals(method) <- c(list(self = self), alist(... = )) # nolint
        out <- method()
        # a single value is repeated once per wrapped distribution
        if (length(out) == 1) {
          out <- rep(out, nrow(self$modelTable))
        }
      } else {
        # Evaluate each wrapped distribution in turn; failures become NaN.
        out <- sapply(seq(nrow(private$.modelTable)), function(idx) {
          ifnerror(self[idx]$variance(...), error = NaN)
        })
      }

      if (!is.null(dim(out))) {
        # catch for covariance matrices: label the third dimension
        dimnames(out)[3] <- as.list(private$.modelTable[, "shortname"])
      } else {
        names(out) <- as.character(unlist(private$.modelTable[, "shortname"]))
      }

      out
    },

    #' @description
    #' Returns named vector of skewness from each wrapped [Distribution].
    #' @param ... Passed to [CoreStatistics]`$genExp` if numeric.
    skewness = function(...) {
      if (!self$distlist) {
        # Borrow the `skewness` method of the wrapped class (or of its parent)
        # and call it with `self` bound to this wrapper.
        generator <- get(as.character(unlist(self$modelTable$Distribution[[1]])))
        method <- generator$public_methods$skewness
        if (is.null(method)) {
          method <- generator$get_inherit()$public_methods$skewness
        }
        if (is.null(method)) {
          stop("Not implemented for this distribution.")
        }
        formals(method) <- c(list(self = self), alist(... = )) # nolint
        out <- method()
        # a single value is repeated once per wrapped distribution
        if (length(out) == 1) {
          out <- rep(out, nrow(self$modelTable))
        }
      } else {
        # Evaluate each wrapped distribution in turn; failures become NaN.
        out <- sapply(seq(nrow(private$.modelTable)), function(idx) {
          ifnerror(self[idx]$skewness(...), error = NaN)
        })
      }

      # label the result by shortname
      names(out) <- as.character(unlist(private$.modelTable[, "shortname"]))
      out
    },

    #' @description
    #' Returns named vector of kurtosis from each wrapped [Distribution].
    #' @param ... Passed to [CoreStatistics]`$genExp` if numeric.
    kurtosis = function(excess = TRUE, ...) {
      if (!self$distlist) {
        # Borrow the `kurtosis` method of the wrapped class (or of its parent)
        # and call it with `self` and `excess` fixed as defaults.
        generator <- get(as.character(unlist(self$modelTable$Distribution[[1]])))
        method <- generator$public_methods$kurtosis
        if (is.null(method)) {
          method <- generator$get_inherit()$public_methods$kurtosis
        }
        if (is.null(method)) {
          stop("Not implemented for this distribution.")
        }
        formals(method) <- c(list(self = self, excess = excess), alist(... = )) # nolint
        out <- method()
        # a single value is repeated once per wrapped distribution
        if (length(out) == 1) {
          out <- rep(out, nrow(self$modelTable))
        }
      } else {
        # Evaluate each wrapped distribution in turn; failures become NaN.
        out <- sapply(seq(nrow(private$.modelTable)), function(idx) {
          ifnerror(self[idx]$kurtosis(excess, ...), error = NaN)
        })
      }

      # label the result by shortname
      names(out) <- as.character(unlist(private$.modelTable[, "shortname"]))
      out
    },

    #' @description
    #' Returns named vector of entropy from each wrapped [Distribution].
    #' @param ... Passed to [CoreStatistics]`$genExp` if numeric.
    entropy = function(base = 2, ...) {
      # Returns one entropy value per wrapped distribution, named by shortname.
      #   base : passed on to the wrapped `entropy` method(s);
      #   ...  : forwarded to each wrapped distribution in the distlist case.
      if (self$distlist) {
        # Evaluate each wrapped distribution in turn; failures become NaN.
        ret <- sapply(seq(nrow(private$.modelTable)), function(i) {
          ifnerror(self[i]$entropy(base, ...), error = NaN)
        })
      } else {
        # Borrow the `entropy` method of the wrapped class (or of its parent)
        # and call it with `self` and `base` fixed as defaults.
        f <- get(as.character(unlist(self$modelTable$Distribution[[1]])))$public_methods$entropy
        if (is.null(f)) {
          f <- get(as.character(unlist(self$modelTable$Distribution[[1]])))$get_inherit()$
            public_methods$entropy
        }
        # same guard as the other moment methods; without it `formals<-` fails
        # on NULL with an unhelpful message
        if (is.null(f)) {
          stop("Not implemented for this distribution.")
        }
        formals(f) <- c(list(self = self, base = base), alist(... = )) # nolint
        ret <- f()
        # a single value is repeated once per wrapped distribution
        if (length(ret) == 1) {
          ret <- rep(ret, nrow(self$modelTable))
        }
      }

      names(ret) <- as.character(unlist(private$.modelTable[, "shortname"]))

      return(ret)
    },

    #' @description
    #' Returns named vector of mgf from each wrapped [Distribution].
    #' @param ... Passed to [CoreStatistics]`$genExp` if numeric.
    mgf = function(t, ...) {
      # Evaluates the mgf of one wrapped distribution; failures become NaN.
      one_mgf <- function(idx) {
        ifnerror(self[idx]$mgf(t, ...), error = NaN)
      }

      # FIXME - VECTORISE PROPERLY
      # Until then every distribution is evaluated separately, even when the
      # wrapper was not built from a distlist - hence the warning.
      if (!self$distlist) {
        warning("mgf not currently efficiently vectorised, may be slow.")
      }

      out <- sapply(seq(nrow(private$.modelTable)), one_mgf)

      labels <- as.character(unlist(private$.modelTable[, "shortname"]))
      if (!is.null(dim(out))) {
        # matrix result: one column per distribution
        out <- data.table(out)
        colnames(out) <- labels
      } else {
        names(out) <- labels
      }

      out
    },

    #' @description
    #' Returns named vector of cf from each wrapped [Distribution].
    #' @param ... Passed to [CoreStatistics]`$genExp` if numeric.
    cf = function(t, ...) {
      # Evaluates the cf of one wrapped distribution; failures become NaN.
      one_cf <- function(idx) {
        ifnerror(self[idx]$cf(t, ...), error = NaN)
      }

      # FIXME - VECTORISE PROPERLY
      # Until then every distribution is evaluated separately, even when the
      # wrapper was not built from a distlist - hence the warning.
      if (!self$distlist) {
        warning("cf not currently efficiently vectorised, may be slow.")
      }

      out <- sapply(seq(nrow(private$.modelTable)), one_cf)

      labels <- as.character(unlist(private$.modelTable[, "shortname"]))
      if (!is.null(dim(out))) {
        # matrix result: one column per distribution
        out <- data.table(out)
        colnames(out) <- labels
      } else {
        names(out) <- labels
      }

      out
    },

    #' @description
    #' Returns named vector of pgf from each wrapped [Distribution].
    #' @param ... Passed to [CoreStatistics]`$genExp` if numeric.
    pgf = function(z, ...) {
      # The loop below is used for every call; the warning is raised only when
      # the `distlist` flag is FALSE.
      if (!self$distlist) {
        warning("pgf not currently efficiently vectorised, may be slow.")
      }

      # Extract each wrapped distribution with `[` and evaluate its pgf at `z`.
      # `ifnerror(..., error = NaN)` presumably substitutes NaN when the call
      # fails -- confirm against the helper's definition.
      # seq_len() is used instead of seq() so that an empty model table yields
      # an empty index rather than c(1, 0).
      ret <- sapply(seq_len(nrow(private$.modelTable)), function(i) {
        ifnerror(self[i]$pgf(z, ...), error = NaN)
      })

      # FIXME - VECTORISE PROPERLY
      # } else {
      #   f <- get(as.character(unlist(self$modelTable$Distribution[[1]])))$public_methods$pgf
      #   formals(f) = list(self = self, z = z)
      #   ret <- f()
      # }

      # Label the result with the short names of the wrapped distributions:
      # element names when sapply() produced a vector (or list), column names
      # of a data.table when it produced a matrix.
      ids <- as.character(unlist(private$.modelTable[, "shortname"]))
      if (is.null(dim(ret))) {
        names(ret) <- ids
      } else {
        ret <- data.table(ret)
        colnames(ret) <- ids
      }

      return(ret)
    },

    #' @description
    #' Returns named vector of pdfs from each wrapped [Distribution].
    #' @param ... `(numeric())` \cr
    #' Points to evaluate the function at. Arguments do not need
    #' to be named. The length of each argument corresponds to the number of points to evaluate,
    #' the number of arguments corresponds to the number of variables in the distribution.
    #' See examples.
    #' @examples
    #' vd <- VectorDistribution$new(
    #'  distribution = "Binomial",
    #'  params = data.frame(size = 9:10, prob = c(0.5,0.6)))
    #'
    #' vd$pdf(2)
    #' # Equivalently
    #' vd$pdf(2, 2)
    #'
    #' vd$pdf(1:2, 3:4)
    #' # or as a matrix
    #' vd$pdf(data = matrix(1:4, nrow = 2))
    #'
    #' # when wrapping multivariate distributions, arrays are required
    #' vd <- VectorDistribution$new(
    #'  distribution = "Multinomial",
    #'  params = list(
    #'  list(size = 5, probs = c(0.1, 0.9)),
    #'  list(size = 8, probs = c(0.3, 0.7))
    #'  )
    #'  )
    #'
    #' # evaluates Multinom1 and Multinom2 at (1, 4)
    #' vd$pdf(1, 4)
    #'
    #' # evaluates Multinom1 at (1, 4) and Multinom2 at (5, 3)
    #' vd$pdf(data = array(c(1,4,5,3), dim = c(1,2,2)))
    #'
    #' # and the same across many samples
    #' vd$pdf(data = array(c(1,2,4,3,5,1,3,7), dim = c(2,2,2)))
    pdf = function(..., log = FALSE, simplify = TRUE, data = NULL) {
      # Build the evaluation points. Arguments given through `...` become the
      # columns of a matrix; a two-dimensional `data` is coerced to a matrix;
      # any other `data` (e.g. a 3-d array) is used as supplied.
      # `simplify` is accepted but not used in this method.
      if (is.null(data)) {
        data <- as.matrix(data.table(...))
      } else if (length(dim(data)) == 2) {
        data <- as.matrix(data)
      }

      n_dist <- nrow(self$modelTable)

      # `if` needs a single TRUE/FALSE, so the scalar, short-circuiting
      # `&&`/`||` operators are used instead of the element-wise `&`/`|`.
      if (ncol(data) != n_dist && ncol(data) > 1 && private$.univariate) {
        stopf("Expected data with %s or 1 columns, received %s.", n_dist, ncol(data))
      }

      if (private$.univariate) {
        # A single column is reused for every wrapped distribution.
        if (private$.distlist && ncol(data) == 1) {
          data <- matrix(rep(data, n_dist), nrow = nrow(data), ncol = n_dist)
        }
        dpqr <- private$.pdf(data, log = log)
        # A bare numeric vector is reshaped to one column per distribution.
        if (inherits(dpqr, "numeric")) {
          dpqr <- matrix(dpqr, ncol = n_dist)
        }
        dpqr <- as.data.table(dpqr)
        colnames(dpqr) <- as.character(unlist(private$.modelTable[, 2]))
        return(dpqr)
      }

      # Multivariate case: at least two columns (variables) are required.
      if (ncol(data) == 1) {
        stop("Distribution is multivariate but values have only been passed to one argument.")
      }
      # Input without a third dimension is replicated into one slice per
      # wrapped distribution.
      if ((inherits(data, "array") || inherits(data, "matrix")) && private$.distlist) {
        if (is.na(dim(data)[3])) {
          data <- array(rep(data, n_dist),
            dim = c(nrow(data), ncol(data), n_dist)
          )
        }
      }
      dpqr <- private$.pdf(data, log = log)
      colnames(dpqr) <- as.character(unlist(private$.modelTable[, 2]))
      return(dpqr)
    },

    #' @description
    #' Returns a [data.table::data.table] of cdfs with one column per wrapped [Distribution].
    #' Same usage as `$pdf`.
    #' @param ... `(numeric())` \cr
    #' Points to evaluate the function at. Arguments do not need
    #' to be named. The length of each argument corresponds to the number of points to evaluate,
    #' the number of arguments corresponds to the number of variables in the distribution.
    #' See examples.
    cdf = function(..., lower.tail = TRUE, log.p = FALSE, simplify = TRUE, data = NULL) {
      # Build the evaluation points. Arguments given through `...` become the
      # columns of a matrix. A two-dimensional `data` (e.g. a data.frame) is
      # coerced to a matrix, as `$pdf` does, so that the rep()/matrix() calls
      # below operate on an atomic matrix; 3-d arrays are used as supplied.
      # `simplify` is accepted but not used in this method.
      if (is.null(data)) {
        data <- as.matrix(data.table(...))
      } else if (length(dim(data)) == 2) {
        data <- as.matrix(data)
      }

      n_dist <- nrow(self$modelTable)

      # `if` needs a single TRUE/FALSE, so the scalar, short-circuiting
      # `&&`/`||` operators are used instead of the element-wise `&`/`|`.
      if (ncol(data) != n_dist && ncol(data) > 1 && private$.univariate) {
        stopf("Expected data with %s or 1 columns, received %s.", n_dist, ncol(data))
      }

      if (private$.univariate) {
        # A single column is reused for every wrapped distribution.
        if (ncol(data) == 1 && private$.distlist) {
          data <- matrix(rep(data, n_dist), nrow = nrow(data))
        }
      } else {
        if (ncol(data) == 1) {
          stop("Distribution is multivariate but values have only been passed to one argument.")
        } else if (inherits(data, "array") || inherits(data, "matrix")) {
          # Input without a third dimension is replicated into one slice per
          # wrapped distribution.
          if (is.na(dim(data)[3]) && private$.distlist) {
            data <- array(rep(data, n_dist),
                          dim = c(nrow(data), ncol(data), n_dist)
            )
          }
        }
      }

      dpqr <- private$.cdf(data, lower.tail = lower.tail, log.p = log.p)
      # A bare numeric vector is reshaped to one column per distribution.
      if (inherits(dpqr, "numeric")) {
        dpqr <- matrix(dpqr, ncol = n_dist)
      }
      dpqr <- as.data.table(dpqr)
      colnames(dpqr) <- as.character(unlist(private$.modelTable[, 2]))
      return(dpqr)
    },

    #' @description
    #' Returns a [data.table::data.table] of quantiles with one column per wrapped [Distribution].
    #' Same usage as `$cdf`.
    #' @param ... `(numeric())` \cr
    #' Points to evaluate the function at. Arguments do not need
    #' to be named. The length of each argument corresponds to the number of points to evaluate,
    #' the number of arguments corresponds to the number of variables in the distribution.
    #' See examples.
    quantile = function(..., lower.tail = TRUE, log.p = FALSE, simplify = TRUE, data = NULL) {
      # Build the evaluation points. Arguments given through `...` become the
      # columns of a matrix. A two-dimensional `data` (e.g. a data.frame) is
      # coerced to a matrix, as `$pdf` does, so that the rep()/matrix() calls
      # below operate on an atomic matrix.
      # `simplify` is accepted but not used in this method.
      if (is.null(data)) {
        data <- as.matrix(data.table(...))
      } else if (length(dim(data)) == 2) {
        data <- as.matrix(data)
      }

      n_dist <- nrow(self$modelTable)

      # `if` needs a single TRUE/FALSE, so the scalar, short-circuiting `&&`
      # operator is used instead of the element-wise `&`.
      if (ncol(data) != n_dist && ncol(data) > 1 && private$.univariate) {
        stopf("Expected data with %s or 1 columns, received %s.", n_dist, ncol(data))
      }

      # Quantiles are only evaluated for univariate distributions.
      if (!private$.univariate) {
        stop("Quantile not possible for non-univariate distributions.")
      }

      # A single column is reused for every wrapped distribution.
      if (ncol(data) == 1 && private$.distlist) {
        data <- matrix(rep(data, n_dist), nrow = nrow(data))
      }

      dpqr <- private$.quantile(data, lower.tail = lower.tail, log.p = log.p)
      # A bare numeric vector is reshaped to one column per distribution.
      if (inherits(dpqr, "numeric")) {
        dpqr <- matrix(dpqr, ncol = n_dist)
      }
      dpqr <- as.data.table(dpqr)
      colnames(dpqr) <- as.character(unlist(private$.modelTable[, 2]))
      return(dpqr)
    },

    #' @description
    #' Returns [data.table::data.table] of draws from each wrapped [Distribution].
    rand = function(n, simplify = TRUE) {
      # A vector `n` is read as a request for length(n) draws.
      # `simplify` is accepted but not used in this method.
      if (length(n) > 1) {
        n <- length(n)
      }

      if (private$.univariate) {
        # One column per wrapped distribution, named by column 2 of the
        # model table.
        dpqr <- as.data.table(private$.rand(n))
        colnames(dpqr) <- as.character(unlist(private$.modelTable[, 2]))
        return(dpqr)
      } else {
        # `.rand` is indexed as a list of per-distribution results here; they
        # are stacked into a 3-d array whose third dimension indexes the
        # wrapped distributions and whose columns are labelled V1, V2, ...
        dpqr <- private$.rand(n)
        dpqr <- array(unlist(dpqr), c(nrow(dpqr[[1]]), ncol(dpqr[[1]]), length(dpqr)))
        # seq_len() keeps the label vector empty (rather than c(1, 0)) should
        # the array ever have zero columns.
        dimnames(dpqr) <- list(NULL, paste0("V", seq_len(ncol(dpqr))),
                               as.character(unlist(private$.modelTable$shortname)))
        return(dpqr)
      }
    }
  ),

  active = list(
    #' @field modelTable
    #' Returns the reference table of wrapped [Distribution]s.
    modelTable = function() {
      return(private$.modelTable)
    },
    #' @field distlist
    #' Returns the private `distlist` value of the wrapped [Distribution]s.
    distlist = function() {
      private$.distlist
    },
    #' @field ids
    #' Returns the ids (short names) of the wrapped [Distribution]s as a character vector.
    ids = function() {
      short_names <- unlist(private$.modelTable$shortname)
      as.character(short_names)
    }
  ),

  private = list(
    # NOTE(review): methods of this class also read private members that are
    # not declared in this list (`.modelTable`, `.pdf`, `.cdf`, `.quantile`,
    # `.rand`). Presumably they are assigned at construction time, which the
    # class definition's `lock_objects = FALSE` would permit -- confirm.
    # Flag read by `$pdf`, `$cdf`, `$quantile` and `$rand` to select the
    # univariate code path.
    .univariate = logical(0),
    # Flag exposed through the `distlist` active binding.
    .distlist = FALSE,
    # Passed on as `shared_params` when `[` builds a sub-vector.
    .sharedparams = list(),
    .properties = list(),
    # Initial value of the traits list.
    .traits = list(type = NA, valueSupport = "mixture", variateForm = "multivariate"),
    .trials = logical(0)
  )
)

# Append the class generator to the `wrappers` list held in `.distr6`, under
# the name "VectorDistribution".
.distr6$wrappers <- append(.distr6$wrappers, list(VectorDistribution = VectorDistribution))


#' @title Extract one or more Distributions from a VectorDistribution
#' @description Once a \code{VectorDistribution} has been constructed, use \code{[}
#' to extract one or more \code{Distribution}s from inside it.
#' @param vecdist VectorDistribution from which to extract Distributions.
#' @param i indices specifying distributions to extract or ids of wrapped distributions.
#' @usage \method{[}{VectorDistribution}(vecdist, i)
#' @examples
#' v <- VectorDistribution$new(distribution = "Binom", params = data.frame(size = 1:2, prob = 0.5))
#' v[1]
#' v["Binom1"]
#'
#' @export
"[.VectorDistribution" <- function(vecdist, i) {

  if (checkmate::testCharacter(i)) {
    checkmate::assertSubset(i, as.character(unlist(vecdist$modelTable$shortname)))
    i <- match(i, as.character(unlist(vecdist$modelTable$shortname)), 0)
  }
  if (is.logical(i)) {
    i <- which(i)
  }
  i <- i[i %in% seq_len(nrow(vecdist$modelTable))]
  if (length(i) == 0) {
    stop("Index i too large, should be less than or equal to ", nrow(vecdist$modelTable))
  }

  decorators <- vecdist$decorators

  if (!vecdist$distlist) {
    distribution <- as.character(unlist(vecdist$modelTable[1, 1]))
    if (length(i) == 1) {
      id <- as.character(unlist(vecdist$modelTable[i, 2]))
      dist <- get(distribution)$new(decorators = decorators)
      pri <- get_private(dist)
      pri$.parameters <- vecdist$parameters()[prefix = id]
      return(dist)
    } else {
      id <- as.character(unlist(vecdist$modelTable[i, 2]))
      pars <- private(vecdist$parameters())$.value
      pars <- pars[grepl(paste0("^", id, "__", collapse = "|"), names(pars))]

      return(VectorDistribution$new(
        distribution = distribution, params = pars,
        decorators = decorators,
        shared_params = vecdist$.__enclos_env__$private$.sharedparams,
        ids = vecdist$modelTable$shortname[i]
      ))
    }
  } else {
    if (length(i) == 1) {
      dist <- vecdist$wrappedModels()[[i]]
      if (!is.null(decorators)) {
        suppressMessages(decorate(dist, decorators))
      }
      return(dist)
    } else {
      return(VectorDistribution$new(
        distlist = vecdist$wrappedModels()[i],
        decorators = decorators,
        ids = vecdist$modelTable$shortname[i]
      ))
    }
  }
}

#' @title Coercion to Vector Distribution
#' @description Helper functions to quickly convert compatible objects to
#' a [VectorDistribution].
#' @param object Object inheriting from [VectorDistribution], e.g. a [MixtureDistribution]
#' or [ProductDistribution], or from `Matdist`.
#' @export
as.VectorDistribution <- function(object) {
  # A VectorDistribution is handed to the constructor through its `vecdist`
  # argument.
  if (checkmate::testClass(object, "VectorDistribution")) {
    return(VectorDistribution$new(vecdist = object))
  }

  # A Matdist is converted with as.Distribution() from its "pdf" parameter,
  # carrying over the object's decorators.
  if (checkmate::testClass(object, "Matdist")) {
    return(as.Distribution(gprm(object, "pdf"), "pdf", object$decorators, TRUE))
  }

  # Anything else is rejected.
  stop("Object must inherit from VectorDistribution or Matdist.")
}

#' @title Get Number of Distributions in Vector Distribution
#' @description Gets the number of distributions in an object inheriting from
#' [VectorDistribution].
#' @param x [VectorDistribution]
#' @export
length.VectorDistribution <- function(x) {
  # The number of wrapped distributions is the number of rows of the model
  # table, i.e. its first dimension.
  model_table <- x$modelTable
  dim(model_table)[1L]
}
# RaphaelS1/distr6 documentation built on Feb. 24, 2024, 9:14 p.m.