R/models_utils.R

Defines functions desingularize recursive_length gipsDA_from_json gipsDA_to_json deserialize_from_json serialize_for_json project_matrix_multiperm project_covs

#' Project covs
#'
#' Projects a list of square matrices onto cone of those invariant under the
#' action of the cyclic subgroup generated by the estimated perm.
#'
#' A `gipsmult` object is built from `emp_covs`/`ns_obs` and optimized with
#' `find_MAP()`. With `MAP = TRUE` every matrix is projected onto the single
#' MAP permutation. Otherwise every permutation whose estimated posterior
#' probability exceeds `tol` is kept, and each matrix is replaced by the
#' probability-weighted average of its projections.
#'
#' @param emp_covs List of empirical covariance matrices.
#' @param ns_obs Numbers of observations passed to `gipsmult()`
#'   (presumably one per matrix in `emp_covs` -- confirm against `gipsmult`).
#' @param MAP If `TRUE`, project onto the MAP permutation only; if `FALSE`,
#'   average projections weighted by posterior probabilities.
#' @param optimizer,max_iter Passed to `find_MAP()`.
#' @param tol Probability threshold below which permutations are discarded
#'   when `MAP = FALSE`.
#' @return A list with `covs` (the projected matrices) and `opt_info` (the
#'   MAP permutation, or the retained named probability vector).
#'
#' @noRd
project_covs <- function(emp_covs, ns_obs, MAP = TRUE, optimizer, max_iter,
                         tol = 1e-3) {
  gg <- gipsmult(emp_covs, ns_obs, was_mean_estimated = TRUE)

  if (MAP) {
    gg <- find_MAP(
      gg,
      optimizer = optimizer, max_iter = max_iter,
      show_progress_bar = FALSE
    )
    perm <- gg[[1]]
    return(list(
      covs = lapply(emp_covs, function(x) gips::project_matrix(x, perm)),
      opt_info = perm
    ))
  }

  gg <- find_MAP(
    gg,
    optimizer = optimizer, max_iter = max_iter,
    return_probabilities = TRUE, save_all_perms = TRUE,
    show_progress_bar = FALSE
  )
  probs <- get_probabilities_from_gipsmult(gg)

  # Without any probabilities there is nothing to average over; fail here
  # with a clear message instead of deep inside project_matrix_multiperm().
  if (length(probs) == 0) {
    stop("No posterior probabilities were returned by find_MAP(); ",
         "cannot average projections")
  }

  if (all(probs <= tol)) {
    warning("There are no perms with estimated probability above threshold, projecting onto MAP")
    # Select the most probable permutation explicitly instead of relying on
    # the probability vector being sorted in decreasing order.
    probs <- probs[which.max(probs)]
  } else {
    probs <- probs[probs > tol]
  }

  list(
    covs = lapply(emp_covs, function(x) project_matrix_multiperm(x, probs)),
    opt_info = probs
  )
}


#' Project matrix multiperm
#'
#' Returns an average of projections weighted by a posteriori probability.
#'
#' `emp_cov` is projected onto each permutation named in `names(probs)`. The
#' projections are summed with weights `probs` and divided by `sum(probs)`,
#' so the weights need not sum to one.
#'
#' @param emp_cov Matrix to project.
#' @param probs Named numeric vector of weights; names are permutations in a
#'   form accepted by `gips::project_matrix()`.
#' @return The weighted average of the projected matrices.
#'
#' @noRd
project_matrix_multiperm <- function(emp_cov, probs) {
  if (length(probs) == 0) {
    stop("'probs' must contain at least one permutation probability")
  }
  perms <- names(probs)
  # Size the accumulator explicitly from the input's row and column counts.
  projected_matrix <- matrix(0, nrow = nrow(emp_cov), ncol = ncol(emp_cov))
  for (i in seq_along(probs)) {
    projected_matrix <- projected_matrix +
      probs[[i]] * gips::project_matrix(emp_cov, perms[i])
  }
  projected_matrix / sum(probs)
}


#' Serialize for json
#'
#' Serializes gipsDA model objects to be saved in json format.
#'
#' The object is walked recursively. Values JSON cannot represent directly
#' are replaced by a list tagged with a `__type` field, which
#' `deserialize_from_json()` uses to rebuild them:
#' - gips permutations become `"gips_perm"`;
#' - formulas and terms objects become `"formula"`;
#' - other language objects become `"call"`;
#' - matrices become `"matrix"`.
#' Other lists are processed element-wise, and anything else is returned
#' unchanged.
#'
#' @param x Object to serialize.
#' @return A JSON-friendly representation of `x`.
#'
#' @noRd
serialize_for_json <- function(x) {
  if (inherits(x, "gips_perm")) {
    return(list(
      `__type` = "gips_perm",
      value = as.character(x),
      size = recursive_length(x)
    ))
  }

  # Formulas and terms objects are themselves language objects
  # (typeof() == "language"), so they must be tested before the generic
  # call branch below; otherwise this branch is unreachable and formulas
  # come back as plain calls without their class.
  if (inherits(x, "formula") || inherits(x, "terms")) {
    return(list(
      `__type` = "formula",
      value = paste(deparse(x), collapse = " ")
    ))
  }

  if (typeof(x) == "language") {
    return(list(
      `__type` = "call",
      value = paste(deparse(x), collapse = " ")
    ))
  }

  if (is.matrix(x)) {
    return(list(
      `__type` = "matrix",
      data = as.vector(x),
      nrow = nrow(x),
      ncol = ncol(x)
    ))
  }

  if (is.list(x)) {
    return(lapply(x, serialize_for_json))
  }

  x
}

#' Deserialize from json
#'
#' Deserializes data loaded from json to create a gipsDA model object.
#'
#' This reverses `serialize_for_json()`. Lists carrying a `__type` tag of
#' `"gips_perm"`, `"call"`, `"formula"` or `"matrix"` are rebuilt into the
#' corresponding R object. Other lists are processed element-wise, and any
#' other value is returned unchanged.
#'
#' @param x Parsed JSON content. This is the nested list produced by
#'   `jsonlite::read_json()` with its default `simplifyVector = FALSE`.
#' @return The reconstructed object.
#'
#' @noRd
deserialize_from_json <- function(x) {
  # `[[` matches the tag name exactly; `$` would partially match names
  # such as "__type_extra".
  type <- if (is.list(x)) x[["__type"]] else NULL

  if (!is.null(type)) {
    if (type == "gips_perm") {
      return(gips::gips_perm(x$value, x$size))
    }

    if (type == "call") {
      return(str2lang(x$value))
    }

    if (type == "formula") {
      return(as.formula(x$value))
    }

    if (type == "matrix") {
      # Without vector simplification, JSON arrays are read back as lists;
      # flatten them so the result is an atomic (e.g. numeric) matrix rather
      # than a matrix of mode "list".
      return(matrix(unlist(x$data), nrow = x$nrow, ncol = x$ncol))
    }
  }

  if (is.list(x)) {
    return(lapply(x, deserialize_from_json))
  }

  x
}

#' gipsDA to json
#'
#' Saves a gipsDA model object to json format.
#'
#' The object is first converted with `serialize_for_json()` and then written
#' with `jsonlite::write_json()`. Length-one vectors are written as JSON
#' scalars, and the output is pretty-printed.
#'
#' @param obj gipsDA model object to save.
#' @param file Path of the json file to write.
#' @return Used for its side effect of writing `file`.
#'
#' @noRd
gipsDA_to_json <- function(obj, file) {
  jsonlite::write_json(
    serialize_for_json(obj),
    file,
    pretty = TRUE,
    auto_unbox = TRUE,
    # jsonlite's default (digits = 4) rounds every number written, which
    # would silently perturb the fitted model parameters; NA requests
    # maximum precision.
    digits = NA
  )
}


#' gipsDA from json
#'
#' Loads a gipsDA model object from a json file.
#'
#' The file is parsed with `jsonlite::read_json()` using its defaults. The
#' result is rebuilt with `deserialize_from_json()`, and `classname` is then
#' assigned as its class.
#'
#' @param file Path of the json file to read.
#' @param classname Class (character vector) to assign to the loaded object.
#' @return The reconstructed model object.
#'
#' @noRd
gipsDA_from_json <- function(file, classname) {
  parsed_json <- jsonlite::read_json(file)
  model <- deserialize_from_json(parsed_json)
  class(model) <- classname
  return(model)
}

#' Recursive length
#'
#' Returns the number of atomic elements in an object.
#'
#' Atomic vectors count their length, and lists are summed recursively. Any
#' other object (functions, symbols, calls, and NULL on R >= 4.4, where
#' `is.atomic(NULL)` is FALSE) contributes zero.
#'
#' @param x Object to measure.
#' @return An integer count.
#'
#' @noRd
recursive_length <- function(x) {
  if (is.atomic(x)) {
    return(length(x))
  }
  if (is.list(x)) {
    return(sum(vapply(x, recursive_length, integer(1))))
  }
  # Must be an integer: the vapply() above requires integer(1) results, so
  # returning the double 0 made it fail for any non-atomic, non-list element.
  0L
}

#' Desingularize
#'
#' Regularizes a square matrix so that the modulus of its smallest (in
#' modulus) eigenvalue equals `target`.
#'
#' If that eigenvalue already has modulus of at least `target`, `A` is
#' returned unchanged. Otherwise `A` is replaced by `(A + s I) / (1 + s)`.
#' This map sends every eigenvalue `mu` to `(mu + s) / (1 + s)`, and
#' `s = (target - lambda) / (1 - target)` makes the smallest one exactly
#' `target`.
#'
#' @param A Square matrix.
#' @param target Desired modulus of the smallest eigenvalue.
#' @return The regularized matrix (or `A` itself if no change is needed).
#'
#' @noRd
desingularize <- function(A, target = 0.05) {
  # all.equal() returns TRUE or a *character* description of differences,
  # so it must be wrapped in isTRUE(); passing the description to eigen()
  # as `symmetric` makes every non-symmetric input fail. Dimnames are
  # ignored because eigen() drops them anyway.
  symmetric <- isTRUE(all.equal(A, t(A), check.attributes = FALSE))
  eigvals <- eigen(A, symmetric = symmetric, only.values = TRUE)$values
  lambda <- eigvals[which.min(abs(eigvals))]

  if (abs(lambda) >= target) {
    return(A)
  }

  # Solves (lambda + s) / (1 + s) == target for s.
  s <- (target - lambda) / (1 - target)

  if (1 + s <= 0) {
    stop("Invalid scaling: 1 + s <= 0")
  }

  (A + diag(s, nrow(A))) / (1 + s)
}

Try the gipsDA package in your browser

Any scripts or data that you put into this service are public.

gipsDA documentation built on Feb. 3, 2026, 5:07 p.m.