## File: R/034_atoms_axis_atom.R

## Defines functions: .axis_out_of_bounds_error, .validate_axis, .axis_shape

#####
## DO NOT EDIT THIS FILE!! EDIT THE SOURCE INSTEAD: rsrc_tree/atoms/axis_atom.R
#####

## CVXPY SOURCE: atoms/axis_atom.py
## AxisAtom -- abstract base class for atoms applied along an axis
##
## Atoms that can reduce over an axis (e.g., sum, max, min, norm).
## Stores axis and keepdims properties. Shape is reduced along the given axis.


AxisAtom <- new_class("AxisAtom", parent = Atom, package = "CVXR",
  properties = list(
    axis     = class_any,     # NULL (reduce over everything) or an integer axis
    keepdims = class_logical  # flag: keep the reduced dimensions?
  ),
  constructor = function(expr, axis = NULL, keepdims = FALSE) {
    ## Coerce inputs: the operand becomes an expression, a supplied axis
    ## becomes integer (NULL is left as NULL), keepdims becomes logical.
    expr <- as_expr(expr)
    axis <- if (is.null(axis)) NULL else as.integer(axis)
    keepdims <- as.logical(keepdims)

    ## The reduced shape is computed *before* the object is assembled, so an
    ## invalid axis fails before a new expression id is requested.
    ## CVXPY: AxisAtom.__init__ -> Atom.__init__ -> self.shape_from_args()
    reduced_shape <- .axis_shape(expr@shape, axis, keepdims)

    atom <- new_object(S7_object(),
      id       = next_expr_id(),
      .cache   = new.env(parent = emptyenv()),
      args     = list(expr),
      shape    = reduced_shape,
      axis     = axis,
      keepdims = keepdims
    )

    ## Validate the fully built atom, then hand it back.
    validate_arguments(atom)
    atom
  }
)

# -- shape_from_args --------------------------------------------------
## CVXPY SOURCE: axis_atom.py lines 36-60
## Returns the shape after reducing along the given axis.
## In R, we always maintain 2D shapes: c(nrow, ncol).

method(shape_from_args, AxisAtom) <- function(x) {
  ## Take the shape of the first argument and reduce it according to the
  ## axis / keepdims settings stored on the atom.
  operand_shape <- x@args[[1L]]@shape
  .axis_shape(operand_shape, x@axis, x@keepdims)
}

## Internal: compute axis-reduced shape
## CVXPY uses arbitrary ndim; R is always 2D c(nrow, ncol).
## R convention (1-based axis):
##   axis=1 -> reduce cols (row-wise) -> like apply(X, 1, FUN)
##   axis=2 -> reduce rows (column-wise) -> like apply(X, 2, FUN)
##   axis=NULL -> reduce all
##
## Shape results for (m, n) input:
##   axis=1 -> c(m, 1)  (column vector of row results)
##   axis=2 -> c(1, n)  (row vector of column results)
##   axis=NULL -> c(1, 1)  (scalar)
##
## With keepdims:
##   axis=1, keepdims -> c(m, 1)
##   axis=2, keepdims -> c(1, n)
.axis_shape <- function(arg_shape, axis, keepdims) {
  ## Compute the shape that results from reducing `arg_shape` along `axis`.
  ##
  ## arg_shape: shape of the operand, c(nrow, ncol).
  ## axis:      NULL (reduce everything), 1 or 2, or a negative alias
  ##            (-1 -> 2, -2 -> 1).
  ## keepdims:  logical flag; selects how the reduced dimension is written.
  ## Returns:   integer shape vector.
  ## Errors:    via .axis_out_of_bounds_error() when the axis is not valid.
  if (is.null(axis)) {
    ## Reduce all -> scalar, regardless of keepdims.
    return(c(1L, 1L))
  }

  ## Resolve negative axes into 1..ndim. The resolved value lives in its own
  ## variable so that the error below reports the axis the caller actually
  ## passed (e.g. -3), not the shifted value (0), which would also trigger the
  ## misleading "1-based indexing" hint.
  ndim <- 2L
  resolved <- if (axis < 0L) axis + ndim + 1L else axis
  if (resolved < 1L || resolved > ndim) {
    .axis_out_of_bounds_error(axis, ndim)
  }

  shape <- arg_shape
  if (keepdims) {
    ## Overwrite the dimension that is reduced away (the "other" one) with 1.
    shape[3L - resolved] <- 1L
  } else if (resolved == 1L) {
    ## axis = 1: reduce cols -> c(nrow, 1), a column vector.
    shape <- c(shape[1L], 1L)
  } else {
    ## axis = 2: reduce rows -> c(1, ncol), a row vector.
    shape <- c(1L, shape[2L])
  }
  as.integer(shape)
}

# -- validate axis helper --------------------------------------------
## Used by AxisAtom subclasses to validate axis in constructors
.validate_axis <- function(axis, ndim = 2L) {
  ## Check that `axis` addresses one of `ndim` dimensions.
  ##
  ## axis: NULL (always accepted) or a value coercible to integer; negative
  ##       values count from the end (-1 -> ndim).
  ## ndim: number of dimensions of the expression (default 2).
  ## Returns invisible(NULL) on success; otherwise raises the error produced
  ## by .axis_out_of_bounds_error().
  if (is.null(axis)) {
    return(invisible(NULL))
  }

  requested <- as.integer(axis)
  ## Resolve into a separate variable so the error reports the axis that was
  ## requested (e.g. -3) rather than its shifted value (0), which would also
  ## trigger the unrelated "1-based indexing" hint.
  resolved <- if (requested < 0L) requested + ndim + 1L else requested
  if (resolved < 1L || resolved > ndim) {
    .axis_out_of_bounds_error(requested, ndim)
  }
  invisible(NULL)
}

## Informative error for axis out of bounds -- helps users migrate
.axis_out_of_bounds_error <- function(axis, ndim) {
  ## Abort with an "axis out of bounds" message. `axis` and `ndim` appear as
  ## {axis} / {ndim} placeholders in the message text, so cli_abort() must be
  ## called from this frame and the parameter names must not change.
  headline <- "axis {axis} is out of bounds for expression with {ndim} dimensions."

  ## An axis of exactly 0 on a 2D expression gets extra bullets explaining
  ## the 1-based convention; every other case gets the headline only.
  zero_axis_on_2d <- axis == 0L && ndim == 2L
  bullets <- if (zero_axis_on_2d) {
    c(
      headline,
      "i" = "CVXR uses 1-based axis indexing (R convention).",
      "i" = "Use {.code axis = 1L} for row-wise reduction (like {.fn apply} with MARGIN=1).",
      "i" = "Use {.code axis = 2L} for column-wise reduction (like {.fn apply} with MARGIN=2)."
    )
  } else {
    headline
  }
  cli_abort(bullets)
}

# -- get_data --------------------------------------------------------
## CVXPY SOURCE: axis_atom.py lines 62-66

method(get_data, AxisAtom) <- function(x) {
  ## Reduction settings as an unnamed two-element list: axis first (which may
  ## be NULL), then keepdims.
  axis_setting <- x@axis
  keepdims_setting <- x@keepdims
  list(axis_setting, keepdims_setting)
}

# -- validate_arguments ----------------------------------------------
## CVXPY SOURCE: axis_atom.py lines 68-76

method(validate_arguments, AxisAtom) <- function(x) {
  ## Step 1: a non-NULL axis must resolve to 1 or 2 (negative values count
  ## from the end). The error reports the axis as stored on the atom.
  requested <- x@axis
  if (!is.null(requested)) {
    ndim <- 2L
    resolved <- if (requested < 0L) requested + ndim + 1L else requested
    in_range <- resolved >= 1L && resolved <= ndim
    if (!in_range) {
      .axis_out_of_bounds_error(requested, ndim)
    }
  }

  ## Step 2: reject complex arguments. Per the original note, this repeats
  ## the parent (Atom) validation inline rather than delegating to it.
  has_complex_arg <- .any_args(x, is_complex)
  if (has_complex_arg) {
    cli_abort("Arguments to {.cls {class(x)[[1L]]}} cannot be complex.")
  }
  invisible(NULL)
}

# -- expr_name: include axis/keepdims data ---------------------------

method(expr_name, AxisAtom) <- function(x) {
  ## Render one get_data() entry as text; a NULL entry is spelled "NULL".
  describe_setting <- function(d) {
    if (is.null(d)) {
      return("NULL")
    }
    as.character(d)
  }

  setting_labels <- vapply(get_data(x), describe_setting, character(1))
  operand_labels <- vapply(x@args, expr_name, character(1))

  ## Result looks like "<class>(<arg names>, <axis>, <keepdims>)".
  pieces <- c(operand_labels, setting_labels)
  sprintf("%s(%s)", class(x)[[1L]], paste(pieces, collapse = ", "))
}

## Try the CVXR package in your browser

## Any scripts or data that you put into this service are public.

## CVXR documentation built on March 6, 2026, 9:10 a.m.