R/042_atoms_affine_sum.R

Defines functions sum_entries

Documented in sum_entries

#####
## DO NOT EDIT THIS FILE!! EDIT THE SOURCE INSTEAD: rsrc_tree/atoms/affine/sum.R
#####

## CVXPY SOURCE: atoms/affine/sum.py
## SumEntries -- sum the entries of an expression over a given axis
##
## CVXPY: Sum(AxisAtom, AffAtom) -- multiple inheritance resolved via AxisAffAtom
## axis=NULL: sum all entries -> scalar (1,1)
## axis=1: sum over columns (row-wise) -> column vector (nrow, 1)
## axis=2: sum over rows (column-wise) -> row vector (1, ncol)
## keepdims=TRUE: reduced dimension kept as size 1


SumEntries <- new_class("SumEntries", parent = AxisAffAtom, package = "CVXR",
  constructor = function(expr, axis = NULL, keepdims = FALSE) {
    ## Normalise the inputs: wrap the argument as an Expression, coerce a
    ## non-NULL axis to integer, and coerce keepdims to logical.
    arg <- as_expr(expr)
    axis <- if (is.null(axis)) NULL else as.integer(axis)
    keepdims <- as.logical(keepdims)

    ## Shape of the result after reducing `arg` over `axis`
    ## (axis = NULL reduces over every dimension).
    out_shape <- .axis_shape(arg@shape, axis, keepdims)

    atom <- new_object(
      S7_object(),
      id       = next_expr_id(),
      .cache   = new.env(parent = emptyenv()),
      args     = list(arg),
      shape    = out_shape,
      axis     = axis,
      keepdims = keepdims
    )
    ## Run the atom's argument checks before handing the object back.
    validate_arguments(atom)
    atom
  }
)

# -- sign_from_args --------------------------------------------------
## Inherits from AffAtom: sum_signs(args)

# -- is_atom_log_log_convex / concave -------------------------------
## CVXPY SOURCE: sum.py lines 69-75

## SumEntries always reports log-log convex (TRUE) ...
method(is_atom_log_log_convex, SumEntries) <- function(x) {
  TRUE
}
## ... and never log-log concave (FALSE).
method(is_atom_log_log_concave, SumEntries) <- function(x) {
  FALSE
}

# -- numeric_value ---------------------------------------------------
## CVXPY SOURCE: sum.py lines 93-101

## Evaluate the sum numerically.
##
## @param x A SumEntries atom; its `axis` property selects the reduction.
## @param values List of argument values; only the first is used.
## @returns A 2-D matrix: 1x1 for axis = NULL, (1, ncol) for axis = 2,
##   (nrow, 1) otherwise (axis = 1).
method(numeric_value, SumEntries) <- function(x, values, ...) {
  val <- values[[1L]]
  if (is.null(x@axis)) {
    ## Sum all entries -> scalar stored as a 1x1 matrix.
    return(matrix(sum(val), 1L, 1L))
  }
  ## Densify sparse input so base colSums/rowSums apply, and treat a bare
  ## vector as a single column.
  if (inherits(val, "sparseMatrix")) {
    val <- as.matrix(val)
  }
  if (!is.matrix(val)) {
    val <- matrix(val, ncol = 1L)
  }

  ## The result is always 2-D with the reduced dimension of size 1, so
  ## keepdims does not change the returned matrix here.
  if (x@axis == 2L) {
    ## axis=2: sum over rows (column-wise) -> row vector (1, ncol).
    matrix(colSums(val), nrow = 1L)
  } else {
    ## axis=1: sum over columns (row-wise) -> column vector (nrow, 1).
    matrix(rowSums(val), ncol = 1L)
  }
}

# -- graph_implementation --------------------------------------------
## CVXPY SOURCE: sum.py lines 103-141
## For 2D expressions:
##   axis=NULL: sum_entries_linop
##   axis=1: right-multiply by ones vector (rmul_expr) -> column vector (nrow, 1)
##   axis=2: left-multiply by ones vector (mul_expr) -> row vector (1, ncol)

## Build the linear operator that implements the sum.
##
## @param arg_objs List of linear-operator arguments; only the first is used.
## @param shape Output shape supplied by the caller.
## @param data data[[1L]] is the reduction axis (NULL, 1L or 2L).
## @returns list(linop, list()) -- the operator and no extra constraints.
method(graph_implementation, SumEntries) <- function(x, arg_objs, shape, data = NULL, ...) {
  ## data[[2L]] holds keepdims, but it is not read here: the target
  ## `shape` is passed in directly.
  axis <- data[[1L]]
  arg <- arg_objs[[1L]]

  if (is.null(axis)) {
    ## Sum of all entries.
    list(sum_entries_linop(arg, shape), list())
  } else if (axis == 1L) {
    ## axis=1: sum over columns (row-wise): x(nrow, ncol) * ones(ncol, 1)
    ## -> (nrow, 1). A plain vector suffices because format_matrix turns
    ## rep(1, n) into a column vector (n, 1).
    const_shape <- c(arg$shape[2L], 1L)
    ones <- create_const(rep(1.0, const_shape[1L]), const_shape)
    list(rmul_expr_linop(arg, ones, shape), list())
  } else {
    ## axis=2: sum over rows (column-wise): ones(1, nrow) * x(nrow, ncol)
    ## -> (1, ncol).
    ## CRITICAL: store the constant as matrix(1, 1, nrow), NOT rep(1, nrow).
    ## format_matrix(rep(1, n)) -> as.matrix -> column vector (n, 1), but the
    ## C++ MUL_EXPR block diagonal needs a row vector (1, n).
    const_shape <- c(1L, arg$shape[1L])
    ones <- create_const(
      matrix(1.0, nrow = 1L, ncol = const_shape[2L]),
      const_shape
    )
    list(mul_expr_linop(ones, arg, shape), list())
  }
}

# -- expr_name -----------------------------------------------------

## Printable name: "SumEntries(<arg names>, <data values>)", where each
## get_data() entry is shown via as.character() and NULL is shown as "NULL".
method(expr_name, SumEntries) <- function(x) {
  show_datum <- function(d) {
    if (is.null(d)) {
      return("NULL")
    }
    as.character(d)
  }
  data_parts <- vapply(get_data(x), show_datum, character(1))
  arg_parts <- vapply(x@args, expr_name, character(1))
  all_parts <- c(arg_parts, data_parts)
  sprintf("SumEntries(%s)", paste(all_parts, collapse = ", "))
}

# -- Convenience function ------------------------------------------

#' Sum the entries of an expression
#'
#' Creates a \code{SumEntries} atom that adds up the entries of \code{x},
#' either all together or along one dimension.
#'
#' @param x An Expression or numeric value (converted with \code{as_expr}).
#' @param axis \code{NULL} to sum every entry (1x1 result); \code{1} to sum
#'   row-wise, like \code{apply(X, 1, sum)} (column vector, nrow x 1); or
#'   \code{2} to sum column-wise, like \code{apply(X, 2, sum)} (row vector,
#'   1 x ncol).
#' @param keepdims Logical: if TRUE, keep the reduced dimension as size 1.
#' @returns A SumEntries expression.
#' @export
sum_entries <- function(x, axis = NULL, keepdims = FALSE) {
  SumEntries(expr = x, axis = axis, keepdims = keepdims)
}

Try the CVXR package in your browser

Any scripts or data that you put into this service are public.

CVXR documentation built on March 6, 2026, 9:10 a.m.