R/basis.R

Defines functions fm_block_prep fm_block_log_shift fm_block_log_weights fm_block_weights fm_block_logsumexp_eval fm_block_eval fm_block internal_bspline2 internal_bspline fm_basis_mesh_1d fm_basis_mesh_2d fm_raw_basis internal_spline_mesh_1d fm_basis.fm_evaluator fm_basis.fm_basis fm_basis.list fm_basis.Matrix fm_basis.matrix fm_basis.fm_tensor fm_basis.fm_lattice_Nd fm_basis.fm_lattice_2d fm_basis.fm_mesh_3d fm_basis.fm_mesh_2d fm_basis.fm_mesh_1d fm_basis.default fm_basis fm_is_within

Documented in fm_basis fm_basis.default fm_basis.fm_basis fm_basis.fm_evaluator fm_basis.fm_lattice_2d fm_basis.fm_lattice_Nd fm_basis.fm_mesh_1d fm_basis.fm_mesh_2d fm_basis.fm_mesh_3d fm_basis.fm_tensor fm_basis.list fm_basis.matrix fm_basis.Matrix fm_basis_mesh_1d fm_basis_mesh_2d fm_block fm_block_eval fm_block_log_shift fm_block_logsumexp_eval fm_block_log_weights fm_block_prep fm_block_weights fm_is_within fm_raw_basis

# fm_is_within ####

#' @title Query if points are inside a mesh
#'
#' @description
#'  Queries whether each input point is within a mesh or not.
#'
#' @param x A set of points/locations of a class supported by `fm_basis(y, loc =
#'   x, ..., full = TRUE)`
#' @param y An [fm_mesh_2d] or other class supported by
#' `fm_basis(y, loc = x, ..., full = TRUE)`
#' @param \dots Passed on to [fm_basis()]
#' @returns A logical vector
#' @examples
#' all(fm_is_within(fmexample$loc, fmexample$mesh))
#' @export
fm_is_within <- function(x, y, ...) {
  # Evaluate the complete basis object for the points `x` on `y`, and keep
  # only its per-point validity flags.
  basis_info <- fm_basis(y, loc = x, ..., full = TRUE)
  basis_info$ok
}



# fm_basis ####

#' @title Compute mapping matrix between mesh function space and points
#'
#' @description Computes the basis mapping matrix between a function space on a
#' mesh, and locations.
#'
#' @param x A function space object, or other supported object
#'   (`matrix`, `Matrix`, `list`)
#' @param loc A location/value information object (`numeric`, `matrix`, `sf`,
#' `fm_bary`, etc, depending on the class of `x`)
#' @param full logical; if `TRUE`, return a `fm_basis` object, containing at
#'   least a projection matrix `A` and logical vector `ok` indicating which
#'   evaluations are valid. If `FALSE`, return only the projection matrix `A`.
#'   Default is `FALSE`.
#' @param \dots Passed on to submethods
#' @returns A `sparseMatrix` object (if `full = FALSE`), or a `fm_basis` object
#'   (if `full = TRUE` or `isTRUE(derivatives)`). The `fm_basis` object contains
#'   at least the projection matrix `A` and logical vector `ok`; If `x_j`
#'   denotes the latent basis coefficient for basis function `j`, the field is
#'   defined as `u(loc_i)=sum_j A_ij x_j` for all `i` where `ok[i]` is `TRUE`,
#'   and `u(loc_i)=0.0` where `ok[i]` is `FALSE`.
#' @seealso [fm_raw_basis()]
#' @examples
#' # Compute basis mapping matrix
#' dim(fm_basis(fmexample$mesh, fmexample$loc))
#' print(fm_basis(fmexample$mesh, fmexample$loc, full = TRUE))
#'
#' # From precomputed `fm_bary` information:
#' bary <- fm_bary(fmexample$mesh, fmexample$loc)
#' print(fm_basis(fmexample$mesh, bary, full = TRUE))
#' @export
fm_basis <- function(x, ..., full = FALSE) {
  # S3 generic: the method is selected from the class of `x`.
  UseMethod("fm_basis", x)
}

#' @rdname fm_basis
#' @export
fm_basis.default <- function(x, ..., full = FALSE) {
  # There is no generic fallback implementation: reaching this method always
  # raises the lifecycle deprecation error below.
  lifecycle::deprecate_stop(
    when = "0.1.7.9002",
    what = "fm_basis.default()",
    details = "Each mesh class needs its own `fm_basis()` method."
  )
}

#' @param derivatives If `TRUE`, include derivative matrices
#' in the output. Forces `full = TRUE`.
#' @describeIn fm_basis If `derivatives=TRUE`, the `fm_basis` object contains
#'   additional derivative weight matrices, `dA` and `d2A`, `du/dx(loc_i)=sum_j
#'   dA_ij w_j`.
#' @export
fm_basis.fm_mesh_1d <- function(x,
                                loc,
                                weights = NULL,
                                derivatives = NULL,
                                ...,
                                full = FALSE) {
  # Delegate the actual evaluation to the 1D helper.
  basis_info <- fm_basis_mesh_1d(
    x,
    loc = loc,
    weights = weights,
    derivatives = derivatives,
    ...
  )
  # Derivative matrices are only carried by the complete object, so an
  # explicit request for them overrides `full = FALSE`.
  force_full <- isTRUE(derivatives) && !full
  fm_basis(basis_info, full = if (force_full) TRUE else full)
}

#' @describeIn fm_basis If `derivatives=TRUE`, additional derivative weight
#'   matrices are included in the `full=TRUE` output: Derivative weight matrices
#' `dx`, `dy`, `dz`; `du/dx(loc_i)=sum_j dx_ij w_j`, etc.
#' @export
fm_basis.fm_mesh_2d <- function(x, loc, weights = NULL, derivatives = NULL, ...,
                                full = FALSE) {
  # Delegate the actual evaluation to the 2D helper.
  basis_info <- fm_basis_mesh_2d(
    x,
    loc = loc,
    weights = weights,
    derivatives = derivatives,
    ...
  )
  # Derivative matrices are only carried by the complete object, so an
  # explicit request for them overrides `full = FALSE`.
  force_full <- isTRUE(derivatives) && !full
  fm_basis(basis_info, full = if (force_full) TRUE else full)
}

#' @describeIn fm_basis `fm_mesh_3d` basis functions.
#' @export
fm_basis.fm_mesh_3d <- function(x, loc, weights = NULL, ...,
                                full = FALSE) {
  # Barycentric information for every location; entries with an NA index are
  # treated as failed evaluations.
  bary <- fm_bary(x, loc, ...)
  n_loc <- NROW(bary)
  ok <- !is.na(bary$index)
  rows <- which(ok)

  # Basis function indices for the successfully located points only.
  simplex <- fm_bary_simplex(x, bary[ok, , drop = FALSE])

  # Expand missing or scalar weights to one weight per location.
  weights <- if (is.null(weights)) {
    rep(1.0, n_loc)
  } else if (length(weights) == 1) {
    rep(weights, n_loc)
  } else {
    weights
  }

  # Four matrix entries per located point (presumably one per cell vertex).
  entry_rows <- rep(rows, 4)
  A <- Matrix::sparseMatrix(
    i = entry_rows,
    j = as.vector(simplex),
    x = as.numeric(as.vector(bary$where[ok, ]) * weights[entry_rows]),
    dims = c(n_loc, fm_dof(x))
  )

  fm_basis(list(A = A, ok = ok, bary = bary), full = full)
}

#' @describeIn fm_basis `fm_lattice_2d` bilinear basis functions.
#' @export
fm_basis.fm_lattice_2d <- function(x, loc, weights = NULL, ...,
                                   full = FALSE) {
  # Barycentric information for every location; entries with an NA index are
  # treated as failed evaluations.
  bary <- fm_bary(x, loc, ...)
  n_loc <- NROW(bary)
  ok <- !is.na(bary$index)
  rows <- which(ok)

  # Basis function indices for the successfully located points only.
  simplex <- fm_bary_simplex(x, bary[ok, , drop = FALSE])

  # Expand missing or scalar weights to one weight per location.
  weights <- if (is.null(weights)) {
    rep(1.0, n_loc)
  } else if (length(weights) == 1) {
    rep(weights, n_loc)
  } else {
    weights
  }

  # Four matrix entries per located point (presumably one per cell corner).
  entry_rows <- rep(rows, 4)
  A <- Matrix::sparseMatrix(
    i = entry_rows,
    j = as.vector(simplex),
    x = as.numeric(as.vector(bary$where[ok, ]) * weights[entry_rows]),
    dims = c(n_loc, fm_dof(x))
  )

  fm_basis(list(A = A, ok = ok), full = full)
}

#' @describeIn fm_basis `fm_lattice_Nd` multilinear basis functions.
#' @export
fm_basis.fm_lattice_Nd <- function(x, loc, weights = NULL, ...,
                                   full = FALSE) {
  # Barycentric information for every location; entries with an NA index are
  # treated as failed evaluations.
  bary <- fm_bary(x, loc, ...)
  n_loc <- NROW(bary)
  ok <- !is.na(bary$index)
  rows <- which(ok)

  # Basis function indices for the successfully located points only.
  simplex <- fm_bary_simplex(x, bary[ok, , drop = FALSE])

  # Expand missing or scalar weights to one weight per location.
  weights <- if (is.null(weights)) {
    rep(1.0, n_loc)
  } else if (length(weights) == 1) {
    rep(weights, n_loc)
  } else {
    weights
  }

  # One matrix entry per located point and column of the `where` matrix.
  n_per_point <- ncol(bary$where)
  entry_rows <- rep(rows, n_per_point)
  A <- Matrix::sparseMatrix(
    i = entry_rows,
    j = as.vector(simplex),
    x = as.numeric(as.vector(bary$where[ok, ]) * weights[entry_rows]),
    dims = c(n_loc, fm_dof(x))
  )

  fm_basis(list(A = A, ok = ok), full = full)
}

#' @export
#' @describeIn fm_basis Evaluates a basis matrix for a `fm_tensor` function
#'   space.
fm_basis.fm_tensor <- function(x,
                               loc,
                               weights = NULL,
                               ...,
                               full = FALSE) {
  spaces <- x[["fun_spaces"]]
  n_spaces <- length(spaces)
  space_names <- names(spaces)

  # One location object is required per component function space.
  if (length(loc) != n_spaces) {
    stop(
      "Length of location list (",
      length(loc),
      ") doesn't match the number of function spaces (",
      n_spaces,
      ")"
    )
  }

  # Unnamed locations are matched by position; named ones must use exactly
  # the function space names (in any order).
  if (is.null(names(loc))) {
    names(loc) <- space_names
  } else if (!setequal(space_names, names(loc))) {
    stop("Name mismatch between location list names and function space names.")
  }

  # Evaluate each component basis, by name when available, else by position.
  keys <- if (is.null(space_names)) seq_along(spaces) else space_names
  proj <- lapply(
    keys,
    function(key) {
      fm_basis(spaces[[key]], loc = loc[[key]], full = TRUE)
    }
  )
  names(proj) <- space_names

  # Combine the matrices
  # (A1, A2, A3) -> rowkron(A3, rowkron(A2, A1))
  # The row weights are applied once, to the first component.
  A <- proj[[1]][["A"]]
  if (!is.null(weights)) {
    A <- Matrix::Diagonal(nrow(A), x = weights) %*% A
  }
  ok <- proj[[1]][["ok"]]
  for (component in proj[-1]) {
    A <- fm_row_kron(component[["A"]], A)
    # A row is valid only if it is valid in every component.
    ok <- component[["ok"]] & ok
  }

  fm_basis(list(A = A, ok = ok), full = full)
}


#' @describeIn fm_basis Creates a new `fm_basis` object with elements `A` and
#'   `ok`, from a pre-evaluated basis matrix, including optional additional
#'   elements in the `...` arguments. If `ok` is `NULL`, it is inferred as
#'   `rep(TRUE, NROW(x))`, indicating that all rows correspond to successful
#'   basis evaluations. If `full = FALSE`,
#'   returns the matrix unchanged.
#' @param ok logical of length `NROW(x)`, indicating which rows of `x` are
#'   valid/successful basis evaluations. If `NULL`, inferred as
#'   `rep(TRUE, NROW(x))`.
#' @param weights Optional weight vector to apply (from the left, one
#' weight for each row of the basis matrix)
#' @export
fm_basis.matrix <- function(x, ok = NULL, weights = NULL, ..., full = FALSE) {
  if (!full) {
    # Plain-matrix output: `x` is handed back as-is; `ok`, `weights` and the
    # extra elements are not used on this path.
    x
  } else {
    # Bundle the matrix with its metadata and let the list method validate
    # and weight it.
    basis_parts <- list(A = x, ok = ok, ...)
    fm_basis(basis_parts, weights = weights, full = TRUE)
  }
}

#' @describeIn fm_basis Creates a new `fm_basis` object with elements `A` and
#'   `ok`, from a pre-evaluated basis matrix, including optional additional
#'   elements in the `...` arguments. If `ok` is `NULL`, it is inferred as
#'   `rep(TRUE, NROW(x))`, indicating that all rows correspond to successful
#'   basis evaluations. If `full = FALSE`,
#'   returns the matrix unchanged.
#' @export
fm_basis.Matrix <- function(x, ok = NULL, weights = NULL, ..., full = FALSE) {
  if (!full) {
    # Plain-matrix output: `x` is handed back as-is; `ok`, `weights` and the
    # extra elements are not used on this path.
    x
  } else {
    # Bundle the matrix with its metadata and let the list method validate
    # and weight it.
    basis_parts <- list(A = x, ok = ok, ...)
    fm_basis(basis_parts, weights = weights, full = TRUE)
  }
}

#' @describeIn fm_basis Creates a new `fm_basis` object from a plain list
#'   containing at least an element `A`. If an `ok` element is missing,
#'   it is inferred as `rep(TRUE, NROW(x$A))`. If `full = FALSE`,
#'   extracts the `A` matrix.
#' @export
fm_basis.list <- function(x, weights = NULL, ..., full = FALSE) {
  stopifnot("A" %in% names(x))

  # Row weights are applied from the left, before anything is returned.
  if (!is.null(weights)) {
    row_scaling <- Matrix::Diagonal(nrow(x[["A"]]), x = weights)
    x[["A"]] <- row_scaling %*% x[["A"]]
  }
  if (!full) {
    return(x[["A"]])
  }

  # Complete object: make sure a valid `ok` vector accompanies the matrix.
  n_eval <- NROW(x[["A"]])
  ok_flags <- x[["ok"]]
  if (is.null(ok_flags)) {
    x[["ok"]] <- rep(TRUE, n_eval)
  } else if (!(is.logical(ok_flags) && (length(ok_flags) == n_eval))) {
    stop(
      "Invalid 'ok' element in 'x'; should be a logical vector of length ",
      n_eval
    )
  }
  structure(x, class = "fm_basis")
}

#' @describeIn fm_basis If `full` is `TRUE`, returns `x` unchanged, otherwise
#'   returns the `A` matrix contained in `x`.
#' @export
fm_basis.fm_basis <- function(x, ..., full = FALSE) {
  # The complete object is already available; hand it back untouched.
  if (full) {
    return(x)
  }
  # Otherwise only the basis matrix is wanted.
  x[["A"]]
}

#' @describeIn fm_basis Extract `fm_basis` information from an `fm_evaluator`
#'   object. If `full = FALSE`, returns the `A` matrix contained in the
#'   `fm_basis` object.
#' @export
fm_basis.fm_evaluator <- function(x, ..., full = FALSE) {
  # The evaluator keeps its basis information in the `proj` element;
  # re-dispatch on that object.
  stored_basis <- x$proj
  fm_basis(stored_basis, full = full)
}




# Build the `fm_mesh_1d` object whose function space provides roughly `m`
# B-spline basis functions of the given `degree` on `interval`.
#
# The number of equally spaced knots `n` is derived from `m`, the degree and
# the boundary type. When the request is too small to be represented (fewer
# than 2 knots), the specification falls back to a 2-knot mesh of lower
# degree.
#
# interval:     numeric; elements 1 and 2 are used as the knot range
# m:            requested number of basis functions
# degree:       B-spline degree; values above 1 must equal 2
# boundary:     one of "neumann", "dirichlet", "free", "cyclic"
#               (abbreviations are resolved by match.arg())
# free.clamped: passed on unchanged to fm_mesh_1d()
#
# Returns the result of fm_mesh_1d().
internal_spline_mesh_1d <- function(interval,
                                    m,
                                    degree,
                                    boundary,
                                    free.clamped) {
  boundary <-
    match.arg(
      boundary,
      c("neumann", "dirichlet", "free", "cyclic")
    )
  if (degree <= 1) {
    n <- switch(boundary,
      neumann = m,
      dirichlet = m + 2,
      free = m,
      cyclic = m + 1
    )
    if (n < 2) {
      # Degenerate request: piecewise constant cyclic mesh with 2 knots.
      # The boundary is spelled out in full; the previous "c" relied on
      # partial matching downstream.
      n <- 2
      degree <- 0
      boundary <- "cyclic"
    }
  } else {
    stopifnot(degree == 2)
    n <- switch(boundary,
      neumann = m + 1,
      dirichlet = m + 1,
      free = m - 1,
      cyclic = m
    )
    if (boundary == "free") {
      if (m <= 1) {
        # Same degenerate fallback as in the low-degree case above.
        n <- 2
        degree <- 0
        boundary <- "cyclic"
      } else if (m == 2) {
        # Two basis functions: fall back to a linear mesh with 2 knots.
        n <- 2
        degree <- 1
      }
    } else if (boundary == "cyclic") {
      if (m <= 1) {
        n <- 2
        degree <- 0
      }
    }
  }
  fm_mesh_1d(seq(interval[1], interval[2], length.out = n),
    degree = degree,
    boundary = boundary,
    free.clamped = free.clamped
  )
}


#' Basis functions for mesh manifolds
#'
#' Calculate basis functions on [fm_mesh_1d()] or [fm_mesh_2d()],
#' without necessarily matching the default function space of the given mesh
#' object.
#'
#' @param mesh An [fm_mesh_1d()] or [fm_mesh_2d()] object.
#' @param type `b.spline` (default) for B-spline basis functions,
#' `sph.harm` for spherical harmonics (available only for meshes on the
#' sphere)
#' @param n For B-splines, the number of basis functions in each direction (for
#' 1d meshes `n` must be a scalar, and for planar 2d meshes either a 2-vector
#' or a scalar that is used for both directions).
#' For spherical harmonics, `n` is the maximal harmonic order.
#' @param degree Degree of B-spline polynomials.  See
#' [fm_mesh_1d()].
#' @param knot.placement For B-splines on the sphere, controls the latitudinal
#' placements of knots. `"uniform.area"` (default) gives uniform spacing
#' in `sin(latitude)`, `"uniform.latitude"` gives uniform spacing in
#' latitudes.
#' @param rot.inv For spherical harmonics on a sphere, `rot.inv=TRUE`
#' gives the rotationally invariant subset of basis functions.
#' @param boundary Boundary specification, default is free boundaries.  See
#' [fm_mesh_1d()] for more information.
#' @param free.clamped If `TRUE` and `boundary` is `"free"`, the
#' boundary basis functions are clamped to 0/1 at the interval boundary by
#' repeating the boundary knots. See
#' [fm_mesh_1d()] for more information.
#' @param ... Unused
#' @returns A matrix with evaluated basis function
#' @author Finn Lindgren \email{finn.lindgren@@gmail.com}
#' @seealso [fm_mesh_1d()], [fm_mesh_2d()], [fm_basis()]
#' @examples
#'
#' loc <- rbind(c(0, 0), c(1, 0), c(1, 1), c(0, 1))
#' mesh <- fm_mesh_2d(loc, max.edge = 0.15)
#' basis <- fm_raw_basis(mesh, n = c(4, 5))
#'
#' proj <- fm_evaluator(mesh, dims = c(10, 10))
#' image(proj$x, proj$y, fm_evaluate(proj, basis[, 7]), asp = 1)
#' \donttest{
#' if (interactive() && require("rgl")) {
#'   plot_rgl(mesh, col = basis[, 7], draw.edges = FALSE, draw.vertices = FALSE)
#' }
#' }
#'
#' @export
fm_raw_basis <- function(mesh,
                         type = "b.spline",
                         n = 3,
                         degree = 2,
                         knot.placement = "uniform.area",
                         rot.inv = TRUE,
                         boundary = "free",
                         free.clamped = TRUE,
                         ...) {
  type <- match.arg(type, c("b.spline", "sph.harm"))
  knot.placement <- match.arg(
    knot.placement,
    c(
      "uniform.area",
      "uniform.latitude"
    )
  )

  if (identical(type, "b.spline")) {
    if (fm_manifold(mesh, c("R1", "S1"))) {
      # 1D mesh: evaluate a dedicated spline mesh at the mesh vertices.
      mesh1 <-
        internal_spline_mesh_1d(
          mesh$interval, n, degree,
          boundary, free.clamped
        )
      basis <- fm_basis(mesh1, mesh$loc)
    } else if (identical(mesh$manifold, "R2")) {
      # Planar mesh: product of two 1D spline bases, one per coordinate.
      # Scalar settings are used for both directions.
      if (length(n) == 1) {
        n <- rep(n, 2)
      }
      if (length(degree) == 1) {
        degree <- rep(degree, 2)
      }
      if (length(boundary) == 1) {
        boundary <- rep(boundary, 2)
      }
      if (length(free.clamped) == 1) {
        free.clamped <- rep(free.clamped, 2)
      }
      mesh1x <-
        internal_spline_mesh_1d(
          range(mesh$loc[, 1]),
          n[1], degree[1],
          boundary[1], free.clamped[1]
        )
      mesh1y <-
        internal_spline_mesh_1d(
          range(mesh$loc[, 2]),
          n[2], degree[2],
          boundary[2], free.clamped[2]
        )
      basis <-
        fm_row_kron(
          fm_basis(mesh1y, mesh$loc[, 2]),
          fm_basis(mesh1x, mesh$loc[, 1])
        )
    } else if (identical(mesh$manifold, "S2")) {
      # Sphere: splines in the third coordinate only.
      loc <- mesh$loc
      uniform.lat <- identical(knot.placement, "uniform.latitude")
      # The degree is limited to the range [0, n - 1].
      degree <- max(0L, min(n - 1L, degree))
      basis <- fmesher_spherical_bsplines1(
        loc[, 3],
        n = n,
        degree = degree,
        uniform = uniform.lat
      )
      if (!rot.inv) {
        warning("Currently only 'rot.inv=TRUE' is supported for B-splines.")
      }
    } else {
      # The message lists every manifold handled by the branches above; it
      # previously omitted the 1D cases.
      stop("Only know how to make B-splines on R1, S1, R2 and S2.")
    }
  } else if (identical(type, "sph.harm")) {
    if (!identical(mesh$manifold, "S2")) {
      stop("Only know how to make spherical harmonics on S2.")
    }
    # With GSL activated:
    #        if (rot.inv) {
    #            basis <- (inla.fmesher.smorg(
    #                mesh$loc,
    #                mesh$graph$tv,
    #                sph0 = n
    #            )$sph0)
    #        } else {
    #            basis <- (inla.fmesher.smorg(
    #                mesh$loc,
    #                mesh$graph$tv,
    #                sph = n
    #            )$sph)
    #        }

    fm_require_stop(
      "gsl",
      "The 'gsl' R package is needed for spherical harmonics."
    )

    # Make sure we have radius-1 coordinates
    loc <- mesh$loc / rowSums(mesh$loc^2)^0.5
    if (rot.inv) {
      # One column per order l = 0, ..., n.
      basis <- matrix(0, nrow(loc), n + 1)
      for (l in seq(0, n)) {
        basis[, l + 1] <- sqrt(2 * l + 1) *
          gsl::legendre_Pl(l = l, x = loc[, 3])
      }
    } else {
      angle <- atan2(loc[, 2], loc[, 1])
      # Column 1 + l * (l + 1) holds the m = 0 function of order l; the
      # sine and cosine terms for m = 1, ..., l are stored m columns before
      # and after it, respectively.
      basis <- matrix(0, nrow(loc), (n + 1)^2)
      for (l in seq(0, n)) {
        basis[, 1 + l * (l + 1)] <-
          sqrt(2 * l + 1) *
            gsl::legendre_Pl(l = l, x = loc[, 3])
        for (m in seq_len(l)) {
          scaling <- sqrt(2 * (2 * l + 1) * exp(lgamma(l - m + 1) -
            lgamma(l + m + 1)))
          poly <- gsl::legendre_Plm(l = l, m = m, x = loc[, 3])
          basis[, 1 + l * (l + 1) - m] <-
            scaling * sin(-m * angle) * poly
          basis[, 1 + l * (l + 1) + m] <-
            scaling * cos(m * angle) * poly
        }
      }
    }
  }

  basis
}




#' @title Internal helper functions for mesh field evaluation
#'
#' @description Methods called internally by [fm_basis()] methods.
#' @param weights Optional weight vector, one weight for each location
#' @param derivatives logical; If true, also return matrices `dA` and `d2A`
#' for `fm_mesh_1d` objects, and `dx`, `dy`, `dz` for `fm_mesh_2d`.
#' @inheritParams fm_basis
#' @export
#' @keywords internal
#' @returns A `fm_basis` object; a list of evaluator information objects,
#' at least a matrix `A` and logical vector `ok`.
#' @name fm_basis_helpers
#' @examples
#' str(fm_basis_mesh_2d(fmexample$mesh, loc = fmexample$loc))
#'
fm_basis_mesh_2d <- function(mesh,
                             loc = NULL,
                             weights = NULL,
                             derivatives = NULL,
                             crs = NULL,
                             ...) {
  # Precomputed barycentric information is used as-is; anything else is
  # located in the mesh first.
  bary <- if (inherits(loc, "fm_bary")) {
    loc
  } else {
    fm_bary(mesh, loc = loc, crs = crs, ...)
  }
  n_loc <- NROW(bary)

  # Points with an NA triangle index are treated as failed evaluations.
  ok <- !is.na(bary$index)
  rows <- which(ok)

  # Expand missing or scalar weights to one weight per location.
  if (is.null(weights)) {
    weights <- rep(1.0, n_loc)
  } else if (length(weights) == 1) {
    weights <- rep(weights, n_loc)
  }

  # Every located point contributes three entries, one per triangle vertex.
  entry_rows <- rep(rows, 3)
  entry_weights <- weights[entry_rows]
  tv <- mesh$graph$tv[bary$index[rows], , drop = FALSE]

  # Sparse (location x mesh vertex) matrix from per-vertex values, laid out
  # like `tv` (all first vertices, then all second, then all third).
  vertex_matrix <- function(values) {
    Matrix::sparseMatrix(
      dims = c(n_loc, mesh$n),
      i = entry_rows,
      j = as.vector(tv),
      x = values
    )
  }

  info <- list(
    bary = bary,
    A = vertex_matrix(
      as.numeric(as.vector(bary$where[rows, ]) * entry_weights)
    ),
    ok = ok
  )

  if (!is.null(derivatives) && derivatives) {
    n_ok <- sum(ok)
    vertex_loc <- function(k) mesh$loc[tv[, k], , drop = FALSE]

    # Triangle edge vectors; edge k does not involve vertex k.
    e1 <- vertex_loc(3) - vertex_loc(2)
    e2 <- vertex_loc(1) - vertex_loc(3)
    e3 <- vertex_loc(2) - vertex_loc(1)

    # Part of `b` orthogonal to `a`, divided by its own squared length.
    scaled_normal <- function(a, b) {
      normal <- b - a * matrix(rowSums(a * b) / rowSums(a * a), n_ok, 3)
      normal / matrix(rowSums(normal * normal), n_ok, 3)
    }
    g1 <- scaled_normal(e1, e2)
    g2 <- scaled_normal(e2, e3)
    g3 <- scaled_normal(e3, e1)

    # One derivative weight matrix per coordinate direction.
    deriv <- lapply(
      c(dx = 1L, dy = 2L, dz = 3L),
      function(k) {
        vertex_matrix(
          as.vector(cbind(g1[, k], g2[, k], g3[, k])) * entry_weights
        )
      }
    )
    info <- c(info, deriv)
  }

  fm_basis(info, full = TRUE)
}


#' @param method character; either "default", "nearest", "linear", or
#' "quadratic". With `NULL` or "default", uses the object definition of the
#' function space. Otherwise overrides the object definition.
#' @export
#' @rdname fm_basis_helpers
fm_basis_mesh_1d <- function(mesh,
                             loc,
                             weights = NULL,
                             derivatives = NULL,
                             method = deprecated(),
                             ...) {
  if (lifecycle::is_present(method)) {
    lifecycle::deprecate_stop(
      "0.0.9.9020",
      "fm_evaluator_mesh_1d(method)",
      details = c("Create a separate fm_mesh_1d() object instead.")
    )
    method <- match.arg(method, c(
      "default",
      "nearest",
      "linear",
      "quadratic"
    ))

    if (!(method %in% "default") &&
      (mesh$degree != c(nearest = 0, linear = 1, quadratic = 2)[method])) {
      deg <- c(nearest = 0, linear = 1, quadratic = 2)[method]
      info <- fm_basis_mesh_1d(
        fm_mesh_1d(mesh$loc,
          interval = mesh$interval,
          boundary = mesh$boundary,
          free.clamped = mesh$free.clamped,
          degree = deg
        ),
        loc = loc,
        weights = weights,
        derivatives = derivatives
      )
      return(info)
    }
  }

  if (is.null(weights)) {
    weights <- rep(1.0, NROW(loc))
  } else if (length(weights) == 1L) {
    weights <- rep(weights, NROW(loc))
  }

  derivatives <- !is.null(derivatives) && derivatives
  info_ <- list()

  ## Compute basis based on mesh$degree and mesh$boundary
  if (mesh$degree == 0) {
    info <- fm_bary(mesh, loc = loc, method = "nearest")
    info_ <- list(bary = info)
    bary_ok <- !is.na(info$index)
    i_ <- seq_along(loc)[bary_ok]
    j_ <- info$index[bary_ok]
    x_ <- info$where[bary_ok, 1]
    if (derivatives) {
      if (mesh$cyclic) {
        j_prev <- (j_ - 2L) %% mesh$n + 1L
        j_next <- j_ %% mesh$n + 1L
        ok <- rep(TRUE, length(j_))
        dist <- (mesh$loc[j_next] - mesh$loc[j_prev]) %% diff(mesh$interval)
      } else {
        j_prev <- j_ - 1L
        j_next <- j_ + 1L
        ok <- (j_prev >= 1L) & (j_next <= mesh$n)
        dist <- mesh$loc[j_next] - mesh$loc[j_prev]
      }
      i_d <- c(i_[ok], i_[ok])
      j_d <- c(j_prev[ok], j_next[ok])
      x_d <- c(-x_[ok], x_[ok]) / dist
    }

    if (mesh$boundary[1] == "dirichlet") {
      ok <- j_ > 1L
      i_ <- i_[ok]
      j_ <- j_[ok] - 1L
      x_ <- x_[ok]
      if (derivatives) {
        ok <- j_d > 1L
        i_d <- i_d[ok]
        j_d <- j_d[ok] - 1L
        x_d <- x_d[ok]
      }
    }
    if (mesh$boundary[2] == "dirichlet") {
      ok <- j_ <= mesh$m
      i_ <- i_[ok]
      j_ <- j_[ok]
      x_ <- x_[ok]
      if (derivatives) {
        ok <- j_d <= mesh$m
        i_d <- i_d[ok]
        j_d <- j_d[ok]
        x_d <- x_d[ok]
      }
    }
  } else if (mesh$degree == 1) {
    info <- fm_bary(mesh, loc = loc, method = "linear")
    info_ <- list(bary = info)
    bary_ok <- !is.na(info$index)
    info <- info[bary_ok, ]
    i_ <- c(which(bary_ok), which(bary_ok))
    simplex <- fm_bary_simplex(mesh, info)
    j_ <- as.vector(simplex)
    x_ <- as.vector(info$where)
    if (derivatives) {
      j_curr <- simplex[, 1]
      j_next <- simplex[, 2]
      if (mesh$cyclic) {
        if (mesh$n > 1) {
          dist <- (mesh$loc[j_next] - mesh$loc[j_curr]) %% diff(mesh$interval)
        } else {
          dist <- rep(diff(mesh$interval), length(j_curr))
        }
      } else {
        dist <- mesh$loc[j_next] - mesh$loc[j_curr]
      }
      i_d <- i_
      j_d <- c(j_curr, j_next)
      x_d <- rep(c(-1, 1), each = sum(bary_ok)) / rep(dist, times = 2)
    }

    if (mesh$boundary[1] == "dirichlet") {
      ok <- j_ > 1L
      i_ <- i_[ok]
      j_ <- j_[ok] - 1L
      x_ <- x_[ok]
      if (derivatives) {
        ok <- j_d > 1L
        i_d <- i_d[ok]
        j_d <- j_d[ok] - 1L
        x_d <- x_d[ok]
      }
    } else if (mesh$boundary[1] == "neumann") {
      if (derivatives) {
        x_d[(j_ == 1) & (x_ > 1)] <- 0.0
        x_d[(j_ == 2) & (x_ < 0)] <- 0.0
      }
      # Set Anew[, 1] = 1 on the left
      # Set Anew[, 2] = 0 on the left
      x_[(j_ == 1) & (x_ > 1)] <- 1.0
      x_[(j_ == 2) & (x_ < 0)] <- 0.0
    }
    if (mesh$boundary[2] == "dirichlet") {
      ok <- j_ <= mesh$m
      i_ <- i_[ok]
      j_ <- j_[ok]
      x_ <- x_[ok]
      if (derivatives) {
        ok <- j_d <= mesh$m
        i_d <- i_d[ok]
        j_d <- j_d[ok]
        x_d <- x_d[ok]
      }
    } else if (mesh$boundary[2] == "neumann") {
      if (derivatives) {
        x_d[(j_ == mesh$m) & (x_ > 1)] <- 0.0
        x_d[(j_ == mesh$m - 1L) & (x_ < 0)] <- 0.0
      }
      # Set Anew[, m] = 1 on the right
      # Set Anew[, m-1] = 0 on the right
      x_[(j_ == mesh$m) & (x_ > 1)] <- 1.0
      x_[(j_ == mesh$m - 1L) & (x_ < 0)] <- 0.0
    }
  } else if (mesh$degree == 2) {
    if (mesh$cyclic) {
      knots <- mesh$loc - mesh$loc[1]
      loc <- loc - mesh$loc[1]
      inter <- c(0, diff(mesh$interval))
    } else {
      knots <- mesh$loc - mesh$loc[1]
      loc <- loc - mesh$loc[1]
      inter <- range(knots)
    }

    # Note: If loc is `fm_bary`, it's still valid for this local fm_mesh_1d.
    info <-
      fm_bary(
        fm_mesh_1d(
          knots,
          interval = inter,
          boundary = if (isTRUE(mesh$cyclic)) "cyclic" else "free",
          degree = 1
        ),
        loc = loc,
        method = "linear"
      )
    info_ <- list(bary = info)
    bary_ok <- !is.na(info$index)
    info <- info[bary_ok, ]

    if (mesh$cyclic) {
      d <-
        (knots[c(seq_len(length(knots) - 1L) + 1L, 1)] - knots) %%
        diff(mesh$interval)
      d2 <- (knots[c(seq_len(length(knots) - 2L) + 2L, seq_len(2))] -
        knots) %% diff(mesh$interval)
      d2[d2 == 0] <- diff(mesh$interval)
      d <- d[c(length(d), seq_len(length(d) - 1L))]
      d2 <- d2[c(length(d2), seq_len(length(d2) - 1L))]
    } else {
      d <- knots[-1] - knots[-length(knots)]
      d2 <- knots[-seq_len(2)] - knots[seq_len(length(knots) - 2L)]
    }

    if (mesh$cyclic) {
      ## Left intervals for each basis function:
      simplex <- fm_bary_simplex(mesh, info)
      i.l <- which(bary_ok)
      j.l <- simplex[, 1] + 2L
      x.l <- (info$where[, 2] * d[simplex[, 2]] / d2[simplex[, 2]] *
        info$where[, 2])
      if (derivatives) {
        x.d1.l <- (2 / d2[simplex[, 2]] * info$where[, 2])
        x.d2.l <- (2 / d2[simplex[, 2]] / d[simplex[, 2]])
      }
      ## Right intervals for each basis function:
      i.r <- seq_along(simplex[, 1])
      j.r <- simplex[, 1]
      x.r <- (info$where[, 1] * d[simplex[, 2]] / d2[simplex[, 1]] *
        info$where[, 1])
      if (derivatives) {
        x.d1.r <- -(2 / d2[simplex[, 2]] * info$where[, 1])
        x.d2.r <- (2 / d2[simplex[, 1]] / d[simplex[, 2]])
      }
      ## Middle intervals for each basis function:
      i.m <- seq_along(simplex[, 1])
      j.m <- simplex[, 1] + 1L
      x.m <- (1 - (info$where[, 1] * d[simplex[, 2]] / d2[simplex[, 1]] *
        info$where[, 1] +
        info$where[, 2] * d[simplex[, 2]] / d2[simplex[, 2]] *
          info$where[, 2]))
      if (derivatives) {
        x.d1.m <- (2 / d2[simplex[, 1]] * info$where[, 1]) -
          (2 / d2[simplex[, 2]] * info$where[, 2])
        x.d2.m <- -(2 / d2[simplex[, 1]] / d[simplex[, 2]]) -
          (2 / d2[simplex[, 2]] / d[simplex[, 2]])
      }
    } else {
      d2 <- c(2 * d[1], 2 * d[1], d2, 2 * d[length(d)], 2 * d[length(d)])
      d <- c(d[1], d[2], d, d[length(d)], d[length(d)])
      simplex <- fm_bary_simplex(mesh, info)
      ok <- (simplex[, 1] >= 1L) & (simplex[, 2] <= length(knots))
      index <- simplex[ok, , drop = FALSE] + 1L
      bary <- info$where[ok, , drop = FALSE]
      ## Left intervals for each basis function:
      i.l <- which(bary_ok)[ok]
      j.l <- index[, 2]
      x.l <- (bary[, 2] * d[index[, 2]] / d2[index[, 2]] * bary[, 2])
      if (derivatives) {
        x.d1.l <- (2 / d2[index[, 2]] * bary[, 2])
        x.d2.l <- (2 / d2[index[, 2]] / d[index[, 2]])
      }
      ## Right intervals for each basis function:
      i.r <- which(bary_ok)[ok]
      j.r <- index[, 1] - 1L
      x.r <- (bary[, 1] * d[index[, 2]] / d2[index[, 1]] * bary[, 1])
      if (derivatives) {
        x.d1.r <- -(2 / d2[index[, 1]] * bary[, 1])
        x.d2.r <- (2 / d2[index[, 1]] / d[index[, 2]])
      }
      ## Middle intervals for each basis function:
      i.m <- which(bary_ok)[ok]
      j.m <- index[, 1]
      x.m <- (1 - (bary[, 1] * d[index[, 2]] / d2[index[, 1]] * bary[, 1] +
        bary[, 2] * d[index[, 2]] / d2[index[, 2]] * bary[, 2]
      ))
      if (derivatives) {
        x.d1.m <- (2 / d2[index[, 1]] * bary[, 1]) -
          (2 / d2[index[, 2]] * bary[, 2])
        x.d2.m <- -(2 / d2[index[, 1]] / d[index[, 2]]) -
          (2 / d2[index[, 2]] / d[index[, 2]])
      }
    }

    i_ <- c(i.l, i.r, i.m)
    j_ <- c(j.l, j.r, j.m)
    x_ <- c(x.l, x.r, x.m)
    if (derivatives) {
      x_d1 <- c(x.d1.l, x.d1.r, x.d1.m)
      x_d2 <- c(x.d2.l, x.d2.r, x.d2.m)
    }

    if (!mesh$cyclic) {
      simplex <- fm_bary_simplex(mesh, info)

      # Convert boundary basis functions to linear
      # First remove anything from above outside the interval, then add back in
      # the appropriate values
      ok <- (loc[bary_ok] >= inter[1]) & (loc[bary_ok] <= inter[2])
      i_ <- i_[ok]
      j_ <- j_[ok]
      x_ <- x_[ok]
      if (derivatives) {
        x_d1 <- x_d1[ok]
        x_d2 <- x_d2[ok]
      }

      # left
      ok <- (loc < 0) & (simplex[, 1] == 1L)
      i_l <- c(which(bary_ok)[ok], which(bary_ok)[ok])
      j_l <- c(simplex[ok, 1], simplex[ok, 2])
      x_l <- c(
        0.5 + (info$where[ok, 1] - 1),
        0.5 - (info$where[ok, 1] - 1)
      )
      if (derivatives) {
        x_d1_l <- rep(c(-1, 1) / d[1], each = sum(ok))
        x_d2_l <- rep(c(0, 0), each = sum(ok))
      }

      # right
      ok <- (loc > inter[2]) & (simplex[, 2] == length(knots))
      i_r <- c(which(bary_ok)[ok], which(bary_ok)[ok])
      j_r <- c(simplex[ok, 2], simplex[ok, 1]) + 1L
      x_r <- c(
        0.5 + (info$where[ok, 2] - 1),
        0.5 - (info$where[ok, 2] - 1)
      )
      if (derivatives) {
        x_d1_r <- rep(c(1, -1) / d[length(d)], each = sum(ok))
        x_d2_r <- rep(c(0, 0), each = sum(ok))
      }

      i_ <- c(i_, i_l, i_r)
      j_ <- c(j_, j_l, j_r)
      x_ <- c(x_, x_l, x_r)
      if (derivatives) {
        x_d1 <- c(x_d1, x_d1_l, x_d1_r)
        x_d2 <- c(x_d2, x_d2_l, x_d2_r)
      }

      if (mesh$boundary[1] == "dirichlet") {
        ok <- j_ > 1L
        j_[ok] <- j_[ok] - 1L
        x_[!ok] <- -x_[!ok]
        if (derivatives) {
          x_d1[!ok] <- -x_d1[!ok]
          x_d2[!ok] <- -x_d2[!ok]
        }
      } else if (mesh$boundary[1] == "neumann") {
        ok <- j_ > 1L
        j_[ok] <- j_[ok] - 1L
      } else if ((mesh$boundary[1] == "free") &&
        (mesh$free.clamped[1])) {
        # new1 <- 2 * basis1
        # new2 <- basis2 - basis1
        ok1 <- j_ == 1L
        i2 <- i_[ok1]
        j2 <- j_[ok1] + 1L
        x2 <- -x_[ok1]
        x_[ok1] <- 2 * x_[ok1]
        if (derivatives) {
          x2_d1 <- -x_d1[ok1]
          x_d1[ok1] <- 2 * x_d1[ok1]
          x2_d2 <- -x_d2[ok1]
          x_d2[ok1] <- 2 * x_d2[ok1]
        }
        i_ <- c(i_, i2)
        j_ <- c(j_, j2)
        x_ <- c(x_, x2)
        if (derivatives) {
          x_d1 <- c(x_d1, x2_d1)
          x_d2 <- c(x_d2, x2_d1)
        }
      }
      if (mesh$boundary[2] == "dirichlet") {
        ok <- j_ > mesh$m
        j_[ok] <- j_[ok] - 1L
        x_[ok] <- -x_[ok]
        if (derivatives) {
          x_d1[ok] <- -x_d1[ok]
          x_d2[ok] <- -x_d2[ok]
        }
      } else if (mesh$boundary[2] == "neumann") {
        ok <- j_ > mesh$m
        j_[ok] <- mesh$m
      } else if ((mesh$boundary[2] == "free") &&
        (mesh$free.clamped[2])) {
        # new_m <- m + {m-1};     m = 1, m - 1 = 2
        # new_{m-1} <- {m-1} - m; m = 1, m - 1 = 2
        # new1 <- 2 * basis1
        # new2 <- basis2 - basis1
        ok1 <- j_ == mesh$m
        i2 <- i_[ok1]
        j2 <- j_[ok1] - 1L
        x2 <- -x_[ok1]
        x_[ok1] <- 2 * x_[ok1]
        if (derivatives) {
          x2_d1 <- -x_d1[ok1]
          x_d1[ok1] <- 2 * x_d1[ok1]
          x2_d2 <- -x_d2[ok1]
          x_d2[ok1] <- 2 * x_d2[ok1]
        }
        i_ <- c(i_, i2)
        j_ <- c(j_, j2)
        x_ <- c(x_, x2)
        if (derivatives) {
          x_d1 <- c(x_d1, x2_d1)
          x_d2 <- c(x_d2, x2_d1)
        }
      }
    }

    if (mesh$cyclic) {
      j_ <- (j_ - 1L - 1L) %% mesh$m + 1L
    }
  } else {
    stop("Unsupported B-spline degree = ", mesh$degree)
  }

  info_$A <- Matrix::sparseMatrix(
    i = i_,
    j = j_,
    x = (weights[i_] * x_),
    dims = c(NROW(loc), mesh$m)
  )
  if (derivatives) {
    if (mesh$degree <= 1) {
      info_$dA <- Matrix::sparseMatrix(
        i = i_d,
        j = j_d,
        x = weights[i_d] * x_d,
        dims = c(NROW(loc), mesh$m)
      )
    } else {
      # degree is 2
      info_$dA <- Matrix::sparseMatrix(
        i = i_,
        j = j_,
        x = weights[i_] * x_d1,
        dims = c(length(loc), mesh$m)
      )
      info_$d2A <- Matrix::sparseMatrix(
        i = i_,
        j = j_,
        x = weights[i_] * x_d2,
        dims = c(NROW(loc), mesh$m)
      )
    }
  }

  info_[["ok"]] <- bary_ok

  fm_basis(info_, full = TRUE)
}



# Plain B-spline basis evaluation by Farin eq 10.13-10.14.
# Starts from the degree-0 (interval indicator) basis and raises the degree one
# step at a time, keeping the basis in triplet form (index vectors + values).
#
# x: evaluation points; an error is raised if any lies outside range(knots)
# knots: knot locations
# degree: B-spline degree; 0 returns the indicator basis directly
# deriv: accepted but not used by this implementation
#
# Returns a list with entries i (index into x), j (basis function index) and
# values; for degree >= 1 the list also carries an always-empty element x.
internal_bspline <- function(x, knots, degree = 1, deriv = 0) {
  if (min(x) < min(knots)) {
    stop("Some x out of range (too small)")
  }
  if (max(x) > max(knots)) {
    stop("Some x out of range (too large)")
  }

  # Degree 0: every point sits in exactly one knot interval, with value 1
  basis <- list(
    i = seq_along(x),
    j = findInterval(x, knots, all.inside = TRUE),
    values = rep(1.0, length(x))
  )
  if (degree == 0) {
    return(basis)
  }

  # Repeat the end knots (degree - 1) times and shift the basis indices
  knots <- c(rep(min(knots), degree - 1), knots, rep(max(knots), degree - 1))
  basis$j <- basis$j + degree
  l_range <- unique(sort(basis$j)) - 1L

  for (deg in seq_len(degree)) {
    prev <- basis
    l_range <- unique(sort(c(l_range - 1L, l_range)))

    # Collect the contributions per l and concatenate once at the end.
    # The typed empty seeds fix the result types when nothing is collected.
    parts_i <- list(integer(0))
    parts_j <- list(integer(0))
    parts_v <- list(numeric(0))
    for (l in l_range) {
      # +1L since j is base-1 but l is base-0
      from_left <- prev$j == l + 1L
      from_right <- prev$j == l + 2L
      if (any(from_left)) {
        rows <- prev$i[from_left]
        slot <- length(parts_i) + 1L
        parts_i[[slot]] <- rows
        parts_j[[slot]] <- prev$j[from_left]
        parts_v[[slot]] <- prev$values[from_left] *
          (x[rows] - knots[l]) /
          (knots[l + deg] - knots[l])
      }
      if (any(from_right)) {
        rows <- prev$i[from_right]
        slot <- length(parts_i) + 1L
        parts_i[[slot]] <- rows
        parts_j[[slot]] <- prev$j[from_right] - 1L
        parts_v[[slot]] <- prev$values[from_right] *
          (knots[l + deg + 1L] - x[rows]) /
          (knots[l + deg + 1L] - knots[l + 1L])
      }
    }

    basis <- list(
      i = do.call(c, parts_i),
      j = do.call(c, parts_j),
      x = numeric(0),
      values = do.call(c, parts_v)
    )
  }

  basis
}


# Plain B-spline basis evaluation by Farin eq 10.13-10.14.
# Vectorised variant: each degree-raising step handles all "left" and all
# "right" contributions in one go, writing into preallocated index vectors.
#
# x: evaluation points; an error is raised if any lies outside range(knots)
# knots: knot locations
# degree: B-spline degree; 0 returns the interval indicator basis directly
# deriv: derivative order; values > 0 take the separate branch below
#
# For deriv == 0, returns a list with entries i (index into x), j (basis
# function index) and values. For deriv > 0 the result is a matrix product.
internal_bspline2 <- function(x, knots, degree = 1, deriv = 0) {
  if (min(x) < min(knots)) {
    stop("Some x out of range (too small)")
  }
  if (max(x) > max(knots)) {
    stop("Some x out of range (too large)")
  }

  if (deriv > 0) {
    # NOTE(review): this branch looks unfinished -- for deriv >= 2 the
    # recursive call returns a matrix rather than a list, and the column index
    # `rows - 1L` below contains 0, which 1-based sparse indexing presumably
    # rejects. Confirm before relying on deriv > 0.
    lower <-
      internal_bspline2(x, knots, degree = degree - 1, deriv = deriv - 1)
    n_basis <- length(knots) + degree - 1L
    n_lower <- n_basis - 1L
    rows <- seq_len(n_lower)
    lower_mat <- Matrix::sparseMatrix(
      i = lower$i,
      j = lower$j,
      x = lower$values,
      dims = c(length(x), n_lower)
    )
    diff_mat <- Matrix::sparseMatrix(
      i = c(rows, rows),
      j = c(rows, rows - 1L),
      x = rep(c(1, -1), c(length(knots) - 1, length(knots) - 1)),
      dims = c(n_lower, n_basis)
    )
    return(lower_mat %*% diff_mat)
  }

  # Degree 0: every point sits in exactly one knot interval, with value 1
  basis <- list(
    i = seq_along(x),
    j = findInterval(x, knots, all.inside = TRUE),
    values = rep(1.0, length(x))
  )
  if (degree == 0) {
    return(basis)
  }

  # Repeat the end knots (degree - 1) times and shift the basis indices
  knots <- c(rep(min(knots), degree - 1), knots, rep(max(knots), degree - 1))
  basis$j <- basis$j + degree
  l_range <- unique(sort(basis$j)) - 1L

  for (deg in seq_len(degree)) {
    prev <- basis
    l_range <- unique(sort(c(l_range - 1L, l_range)))
    # +1L since j is base-1 but l is base-0
    take_left <- prev$j %in% (l_range + 1L)
    take_right <- prev$j %in% (l_range + 2L)
    n_left <- sum(take_left)
    n_right <- sum(take_right)
    n_out <- n_left + n_right

    # Left contributions fill the first n_left slots, right ones the rest
    basis <- list(
      i = integer(n_out),
      j = integer(n_out),
      values = numeric(n_out)
    )
    if (n_left > 0L) {
      dest <- seq_len(n_left)
      rows <- prev$i[take_left]
      cols <- prev$j[take_left]
      l <- cols - 1L
      basis$i[dest] <- rows
      basis$j[dest] <- cols
      basis$values[dest] <- prev$values[take_left] *
        (x[rows] - knots[l]) /
        (knots[l + deg] - knots[l])
    }
    if (n_right > 0L) {
      dest <- n_left + seq_len(n_right)
      rows <- prev$i[take_right]
      cols <- prev$j[take_right]
      l <- cols - 2L
      basis$i[dest] <- rows
      basis$j[dest] <- cols - 1L
      basis$values[dest] <- prev$values[take_right] *
        (knots[l + deg + 1L] - x[rows]) /
        (knots[l + deg + 1L] - knots[l + 1L])
    }
  }

  basis
}




# Block methods ####

#' Blockwise aggregation matrices
#'
#' Creates an aggregation matrix for blockwise aggregation, with optional
#' weighting.
#'
#' @param block integer vector; block information. If `NULL`,
#'   `rep(1L, block_len)` is used, where `block_len` is determined by
#'   `length(log_weights)` or `length(weights)`. A single scalar is also
#'   repeated to a vector of corresponding length to the weights. 'character'
#'   input is converted to integer with `as.integer(factor(block))` (from
#'   `0.2.0.9017`).
#' @param weights Optional weight vector
#' @param log_weights Optional `log(weights)` vector. Overrides `weights` when
#' non-NULL.
#' @param rescale logical; If `TRUE`, normalise the weights by `sum(weights)`
#' or `sum(exp(log_weights))` within each block.
#' Default: `FALSE`
#' @param n_block integer; The number of conceptual blocks. Only needs to be
#' specified if it's larger than `max(block)`, or to keep the output of
#' consistent size for different inputs.
#'
#' @returns A (sparse) matrix
#' @export
#' @describeIn fm_block A (sparse) matrix of size `n_block` times
#'   `length(block)`.
#' @examples
#' block <- rep(1:2, 3:2)
#' fm_block(block)
#' fm_block(block, rescale = TRUE)
#' fm_block(block, log_weights = -2:2, rescale = TRUE)
#' fm_block_eval(
#'   block,
#'   weights = 1:5,
#'   rescale = TRUE,
#'   values = 11:15
#' )
#' fm_block_logsumexp_eval(
#'   block,
#'   weights = 1:5,
#'   rescale = TRUE,
#'   values = log(11:15),
#'   log = FALSE
#' )
fm_block <- function(block = NULL,
                     weights = NULL,
                     log_weights = NULL,
                     rescale = FALSE,
                     n_block = NULL) {
  # Normalise block / weight inputs to full-length vectors
  prep <- fm_block_prep(
    block = block,
    n_block = n_block,
    log_weights = log_weights,
    weights = weights
  )
  n_elem <- length(prep[["block"]])

  # No elements at all: return an empty 0 x 0 sparse matrix
  if (n_elem == 0) {
    empty <- Matrix::sparseMatrix(
      i = integer(0),
      j = integer(0),
      x = numeric(0),
      dims = c(0L, 0L)
    )
    return(empty)
  }

  # Plain-scale weights, optionally normalised within each block
  w <- fm_block_weights(
    block = prep[["block"]],
    weights = prep[["weights"]],
    log_weights = prep[["log_weights"]],
    n_block = prep[["n_block"]],
    rescale = rescale
  )

  # Row = block index, column = element index, entry = element weight
  Matrix::sparseMatrix(
    i = prep[["block"]],
    j = seq_len(n_elem),
    x = as.numeric(w),
    dims = c(prep[["n_block"]], n_elem)
  )
}

#' @param values Vector to be blockwise aggregated
#' @describeIn fm_block Evaluate aggregation. More efficient alternative to
#' `as.vector(fm_block(...) %*% values)`.
#' @export
fm_block_eval <- function(block = NULL,
                          weights = NULL,
                          log_weights = NULL,
                          rescale = FALSE,
                          n_block = NULL,
                          values = NULL) {
  # Normalise block / weight inputs; the vector length follows `values`
  prep <- fm_block_prep(
    block = block,
    n_block = n_block,
    log_weights = log_weights,
    weights = weights,
    values = values
  )
  if (length(prep[["block"]]) == 0) {
    return(numeric(0L))
  }

  # Plain-scale weights, optionally normalised within each block
  w <- fm_block_weights(
    block = prep[["block"]],
    weights = prep[["weights"]],
    log_weights = prep[["log_weights"]],
    n_block = prep[["n_block"]],
    rescale = rescale
  )

  # Sum the weighted values within each block that actually occurs
  weighted <- data.frame(x = values * w)
  block_sums <- stats::aggregate(
    weighted,
    by = list(block = prep[["block"]]),
    FUN = sum,
    simplify = TRUE,
    drop = TRUE
  )

  # Blocks without any elements keep the initial value 0
  out <- numeric(prep[["n_block"]])
  out[block_sums[["block"]]] <- block_sums[["x"]]
  out
}


#' @param log If `TRUE` (default), return log-sum-exp. If `FALSE`,
#' return sum-exp.
#' @describeIn fm_block Evaluate log-sum-exp aggregation.
#' More efficient and numerically stable alternative to
#' `log(as.vector(fm_block(...) %*% exp(values)))`.
#' @export
fm_block_logsumexp_eval <- function(block = NULL,
                                    weights = NULL,
                                    log_weights = NULL,
                                    rescale = FALSE,
                                    n_block = NULL,
                                    values = NULL,
                                    log = TRUE) {
  # Normalise block / weight inputs; the vector length follows `values`
  prep <- fm_block_prep(
    block = block,
    n_block = n_block,
    log_weights = log_weights,
    weights = weights,
    values = values
  )
  blk <- prep[["block"]]
  if (length(blk) == 0) {
    return(numeric(0L))
  }
  n_blk <- prep[["n_block"]]

  # Log-scale weights, optionally normalised within each block
  lw <- fm_block_log_weights(
    block = blk,
    weights = prep[["weights"]],
    log_weights = prep[["log_weights"]],
    n_block = n_blk,
    rescale = rescale
  )

  # Per-block shift for a stable log-sum-exp
  log_terms <- values + lw
  shift <- fm_block_log_shift(
    log_weights = log_terms,
    block = blk,
    n_block = n_blk
  )

  # Blockwise sum of the shifted exponentials via a one-column sparse matrix
  shifted_sums <- as.vector(
    Matrix::sparseMatrix(
      i = blk,
      j = rep(1L, length(blk)),
      x = as.numeric(exp(log_terms - shift[blk])),
      dims = c(n_blk, 1)
    )
  )

  # Undo the shift, on the log scale or on the plain scale
  if (log) {
    log(shifted_sums) + shift
  } else {
    shifted_sums * exp(shift)
  }
}


#' @describeIn fm_block Computes (optionally) blockwise renormalised weights
#' @export
fm_block_weights <-
  function(block = NULL,
           weights = NULL,
           log_weights = NULL,
           rescale = FALSE,
           n_block = NULL) {
    # Normalise the inputs: `info$block` is the full-length block index
    # vector (defaulted, recycled, converted and clamped as needed), and at
    # most one of `info$weights` / `info$log_weights` is non-NULL.
    info <-
      fm_block_prep(
        block = block,
        n_block = n_block,
        log_weights = log_weights,
        weights = weights
      )
    if (length(info$block) == 0) {
      return(numeric(0))
    }
    if (rescale) {
      # Compute blockwise normalised weights
      if (!is.null(info$log_weights)) {
        # Normalise on the log scale, then transform back
        info$log_weights <- fm_block_log_weights(
          log_weights = info$log_weights,
          block = info$block,
          n_block = info$n_block,
          rescale = rescale
        )
        info$weights <- exp(info$log_weights)
      } else {
        if (is.null(info$weights)) {
          info$weights <- rep(1.0, length(info$block))
        }
        # Per-block weight totals; entries sharing a row index are summed
        scale <- as.vector(Matrix::sparseMatrix(
          i = info$block,
          j = rep(1L, length(info$block)),
          x = as.numeric(info$weights),
          dims = c(info$n_block, 1L)
        ))
        # Index with the prepared block vector, not the raw `block` argument,
        # which may be NULL, a scalar, character, or out of range.
        info$weights <- info$weights / scale[info$block]
      }
    } else if (!is.null(info$log_weights)) {
      info$weights <- exp(info$log_weights)
    } else if (is.null(info$weights)) {
      info$weights <- rep(1.0, length(info$block))
    }
    info$weights
  }

#' @describeIn fm_block Computes (optionally) blockwise renormalised log-weights
#' @export
fm_block_log_weights <- function(block = NULL,
                                 weights = NULL,
                                 log_weights = NULL,
                                 rescale = FALSE,
                                 n_block = NULL) {
  # Normalise the inputs: `info$block` is the full-length block index vector
  # (defaulted, recycled, converted and clamped as needed); `force_log = TRUE`
  # requests the weights on the log scale.
  info <-
    fm_block_prep(
      block = block,
      n_block = n_block,
      log_weights = log_weights,
      weights = weights,
      force_log = TRUE
    )
  if (length(info$block) == 0) {
    return(numeric(0))
  }
  # Defensive fallback in case no log-weights were produced above
  if (is.null(info$log_weights)) {
    if (is.null(info$weights)) {
      info$log_weights <- rep(0.0, length(info$block))
    } else {
      info$log_weights <- log(info$weights)
    }
  }
  if (rescale) {
    # Stable blockwise log-sum-exp of the log-weights: subtract a per-block
    # shift before exponentiating, and add it back after taking the log.
    shift <- fm_block_log_shift(
      block = info$block, log_weights = info$log_weights,
      n_block = info$n_block
    )
    log_rescale <- as.vector(
      Matrix::sparseMatrix(
        i = info$block,
        j = rep(1L, length(info$block)),
        x = as.numeric(exp(info$log_weights - shift[info$block])),
        dims = c(info$n_block, 1L)
      )
    )
    # Expand the per-block totals to per-element values. Index with the
    # prepared block vector, not the raw `block` argument, which may be NULL,
    # a scalar, character, or out of range.
    log_rescale <- (log(log_rescale) + shift)[info$block]
    info$log_weights <- info$log_weights - log_rescale
  }

  info$log_weights
}


#' @describeIn fm_block Computes shifts for stable blocked log-sum-exp.
#' To compute \eqn{\log(\sum_{i; \textrm{block}_i=k} \exp(v_i) w_i)}{
#' log(sum_(i;block_i=k) exp(v_i) w_i)
#' } for
#' each block `k`, first compute combined values and weights, and a shift:
#' ```
#' w_values <- values + fm_block_log_weights(block, log_weights = log_weights)
#' shift <- fm_block_log_shift(block, log_weights = w_values)
#' ```
#' Then aggregate the values within each block:
#' ```
#' agg <- aggregate(exp(w_values - shift[block]),
#'                  by = list(block = block),
#'                  \(x) log(sum(x)))
#' agg$x <- agg$x + shift[agg$block]
#' ```
#' The implementation uses a faster method:
#' ```
#' log(as.vector(
#'   Matrix::sparseMatrix(
#'     i = block,
#'     j = rep(1L, length(block)),
#'     x = exp(w_values - shift[block]),
#'     dims = c(n_block, 1))
#' )) + shift
#' ```
#' @export
fm_block_log_shift <- function(block = NULL,
                               log_weights = NULL,
                               n_block = NULL) {
  # Normalise block / log-weight inputs to full-length vectors
  prep <- fm_block_prep(
    block = block,
    n_block = n_block,
    log_weights = log_weights,
    force_log = TRUE
  )
  if (length(prep[["block"]]) == 0) {
    return(0.0)
  }

  # Blocks that actually occur, in increasing order; all others keep shift 0
  present <- sort(unique(prep[["block"]]))
  shift <- numeric(prep[["n_block"]])
  if (is.null(prep[["log_weights"]])) {
    return(shift)
  }

  # The shift of a block is the largest log-weight within that block
  if (prep[["n_block"]] <= 200) {
    # Fast for small n_block
    block_max <- function(k) {
      max(prep[["log_weights"]][prep[["block"]] == k])
    }
    shift[present] <- vapply(present, block_max, 0.0)
  } else {
    # Fast for large n_block
    per_block <- stats::aggregate(
      prep[["log_weights"]],
      by = list(block = prep[["block"]]),
      FUN = max,
      simplify = TRUE
    )
    shift[present] <- per_block$x
  }

  shift
}


#' @describeIn fm_block Helper function for preparing `block`, `weights`, and
#' `log_weights`, `n_block` inputs.
#' @export
#' @param n_values When supplied, used instead of `length(values)` to determine
#' the value vector input length.
#' @param force_log When `FALSE` (default),
#' passes on either `weights` or `log_weights`, if provided, with `log_weights`
#' taking precedence. If `TRUE`, forces the computation of `log_weights`,
#' whether given in the input or not.
fm_block_prep <- function(block = NULL,
                          log_weights = NULL,
                          weights = NULL,
                          n_block = NULL,
                          values = NULL,
                          n_values = NULL,
                          force_log = FALSE) {
  # Target vector length: explicit n_values, else length(values), else the
  # longest of the supplied block / weight vectors
  if (is.null(n_values)) {
    if (is.null(values)) {
      n_values <- max(length(block), length(weights), length(log_weights))
    } else {
      n_values <- length(values)
    }
  }

  # Missing or empty block means a single block; a scalar block is recycled
  if (is.null(block) || (length(block) == 0)) {
    block <- rep(1L, n_values)
  } else if (length(block) == 1L) {
    block <- rep(block, n_values)
  }
  if (is.character(block)) {
    block <- as.integer(factor(block))
  }

  # Range checks only make sense for non-empty input; min()/max() of an empty
  # vector would warn and return +/-Inf.
  has_blocks <- length(block) > 0L
  if (has_blocks && (min(block) < 1L)) {
    warning(paste0(
      "min(block) = ", min(block),
      " < 1L. Setting too small values to 1L."
    ))
    block <- pmax(1L, block)
  }
  if (is.null(n_block)) {
    # No elements means no blocks
    n_block <- if (has_blocks) max(block) else 0L
  } else if (has_blocks && (max(block) > n_block)) {
    warning(paste0(
      max(block), " = max(block) > n_block = ",
      n_block,
      ". Setting too large values to n_block."
    ))
    block <- pmin(n_block, block)
  }

  # log_weights take precedence when both weight forms are supplied
  if (!is.null(log_weights) && !is.null(weights)) {
    warning("Both weights and log_weights supplied. Using log_weights.",
      immediate. = TRUE
    )
    weights <- NULL
  }
  if (force_log) {
    # Always hand back log_weights, deriving them from weights if needed
    if (is.null(log_weights)) {
      if (is.null(weights)) {
        log_weights <- rep(0.0, n_values)
      } else {
        log_weights <- log(weights)
        weights <- NULL
      }
    }
  } else if (is.null(log_weights) && is.null(weights)) {
    # Neither supplied: unit weights
    weights <- rep(1.0, n_values)
  }

  # Recycle scalar weights to the full length
  if (length(weights) == 1L) {
    weights <- rep(weights, n_values)
  }
  if (length(log_weights) == 1L) {
    log_weights <- rep(log_weights, n_values)
  }

  list(
    block = block, weights = weights, log_weights = log_weights,
    n_block = n_block
  )
}

Try the fmesher package in your browser

Any scripts or data that you put into this service are public.

fmesher documentation built on April 3, 2025, 7:45 p.m.