R/007_utilities_shape.R

Defines functions .broadcast_2d mul_shapes mul_shapes_promote sum_shapes

#####
## DO NOT EDIT THIS FILE!! EDIT THE SOURCE INSTEAD: rsrc_tree/utilities/shape.R
#####

## CVXPY SOURCE: utilities/shape.py
## Shape inference rules for expression arithmetic
##
## All shapes in CVXR are integer(2) vectors (2D only, per ALLOW_ND_EXPR = FALSE).

#' Shape of a sum of expressions (with broadcasting)
#'
#' Computes the result shape from summing multiple shapes using
#' NumPy-style broadcasting rules (restricted to 2D). Shapes are folded
#' left to right, broadcasting the running result against each next shape.
#'
#' @param shapes List of integer(2) shape vectors (numeric vectors are
#'   accepted and coerced with `as.integer()`).
#' @returns Integer(2) shape vector (the broadcast result). With a single
#'   input shape, that shape is returned after integer coercion.
#' @noRd
sum_shapes <- function(shapes) {
  ## CVXPY: sum_shapes(shapes) in utilities/shape.py lines 27-49
  ## CVXPY delegates to np.broadcast_shapes(*shapes)
  ## We implement 2D broadcasting directly.
  if (length(shapes) == 0L) {
    cli_abort("sum_shapes requires at least one shape.")
  }
  result <- shapes[[1L]]
  for (i in seq_along(shapes)[-1L]) {
    result <- .broadcast_2d(result, shapes[[i]])
  }
  ## Coerce so the documented integer(2) contract holds even when only one
  ## (possibly double) shape is given and the loop above never runs.
  as.integer(result)
}

#' Validate operand shapes for matrix multiplication and return result shape
#'
#' All CVXR shapes are 2D (ALLOW_ND_EXPR = FALSE), so unlike CVXPY no
#' 1D -> 2D promotion is performed here. The operand shapes are coerced to
#' integer and returned unchanged alongside the product shape.
#'
#' A (1, 1) operand is treated as a scalar (R's stand-in for CVXPY's `()`
#' shape): the product then takes the other operand's shape, and no inner
#' dimension check is done. If the left operand is (1, 1), the right shape
#' wins. Otherwise the inner dimensions must agree, and the product shape
#' is (rows of lh, cols of rh).
#'
#' @param lh_shape Integer(2) shape of left operand
#' @param rh_shape Integer(2) shape of right operand
#' @returns List with components: lh_shape, rh_shape, shape (all integer(2))
#' @noRd
mul_shapes_promote <- function(lh_shape, rh_shape) {
  ## CVXPY: mul_shapes_promote(lh_shape, rh_shape) in utilities/shape.py lines 52-102
  lh_shape <- as.integer(lh_shape)
  rh_shape <- as.integer(rh_shape)

  ## Scalar operand: (1,1) is the R equivalent of CVXPY's () shape, and
  ## scalar * X has the shape of X.
  if (is_scalar_shape(lh_shape)) {
    return(list(lh_shape = lh_shape, rh_shape = rh_shape,
                shape = rh_shape))
  }
  if (is_scalar_shape(rh_shape)) {
    return(list(lh_shape = lh_shape, rh_shape = rh_shape,
                shape = lh_shape))
  }

  ## Inner dimension check
  if (lh_shape[2L] != rh_shape[1L]) {
    cli_abort("Incompatible dimensions ({lh_shape[1L]}, {lh_shape[2L]}) and ({rh_shape[1L]}, {rh_shape[2L]}).")
  }

  ## Both operands were coerced above, so the product shape is integer too.
  list(lh_shape = lh_shape,
       rh_shape = rh_shape,
       shape = c(lh_shape[1L], rh_shape[2L]))
}

#' Result shape of a matrix (or scalar) product
#'
#' Convenience wrapper around `mul_shapes_promote()` that discards the
#' operand shapes and keeps only the product shape. For 2D operands this
#' follows np.matmul semantics. Scalar (1, 1) operands and
#' inner-dimension errors are handled by `mul_shapes_promote()`.
#'
#' @param lh_shape Integer(2) shape of left operand
#' @param rh_shape Integer(2) shape of right operand
#' @returns Integer(2) shape vector of the product
#' @noRd
mul_shapes <- function(lh_shape, rh_shape) {
  ## CVXPY: mul_shapes(lh_shape, rh_shape) in utilities/shape.py lines 105-135
  product_shape <- mul_shapes_promote(lh_shape, rh_shape)[["shape"]]
  as.integer(product_shape)
}

# -- Internal helper ---------------------------------------------------

#' 2D broadcasting (internal)
#'
#' Broadcasts two 2D shapes following NumPy rules:
#' - Dimensions must either be equal or one of them must be 1
#' - A size-1 dimension stretches to the other operand's extent. This
#'   includes 0, so 1 and 0 broadcast to 0; the result is therefore not
#'   simply the maximum.
#'
#' @param s1 Integer(2) first shape (numeric input is coerced to integer)
#' @param s2 Integer(2) second shape (numeric input is coerced to integer)
#' @returns Integer(2) broadcast shape
#' @noRd
.broadcast_2d <- function(s1, s2) {
  ## Coerce up front: assigning a double element into the integer `result`
  ## below would otherwise silently promote the whole vector to double.
  s1 <- as.integer(s1)
  s2 <- as.integer(s2)
  result <- integer(2L)
  for (d in 1:2) {
    if (s1[d] == s2[d]) {
      result[d] <- s1[d]
    } else if (s1[d] == 1L) {
      result[d] <- s2[d]
    } else if (s2[d] == 1L) {
      result[d] <- s1[d]
    } else {
      cli_abort("Cannot broadcast shapes ({s1[1L]}, {s1[2L]}) and ({s2[1L]}, {s2[2L]}).")
    }
  }
  result
}

Try the CVXR package in your browser

Any scripts or data that you put into this service are public.

CVXR documentation built on March 6, 2026, 9:10 a.m.