R/torch_diag_batch.R

Defines functions torch_diag_batch

#' @export
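torch_diag_batch <- function(x) {
  is_bm = is_batch_mat(x)
  if (is_bm) {
    # Batched input: torch_eye_embed() (defined elsewhere in this package) is
    # expected to return a stack of identity matrices matching x, so the
    # elementwise product zeroes every off-diagonal entry while keeping the
    # full matrix shape. Allocating the whole identity stack is probably
    # unnecessarily slow.
    dim_x = dim(x)
    Id = torch_eye_embed(dim_x[1], dim_x[2])
    d = x * Id
  } else {
    # Non-batched input: torch_diag() returns the diagonal of a matrix as a
    # 1-D tensor, or builds a diagonal matrix from a 1-D tensor.
    d = torch_diag(x)
  }
  return(d)
}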
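The batched branch above keeps only the diagonal of each matrix by multiplying elementwise with a stack of identity matrices. The following base-R sketch illustrates that masking idea on a plain 3-D array without torch. It assumes a batch-first layout (B x n x n, matching the use of `dim_x[1]` and `dim_x[2]` above); the helper name `diag_mask_batch` is hypothetical and stands in for `torch_eye_embed()` plus the multiplication, whose exact behaviour is not shown in this file.

```r
# Hypothetical base-R sketch (no torch): zero the off-diagonal entries of
# every n x n slice in a batch stored as a B x n x n array (assumed layout).
diag_mask_batch <- function(x) {
  d <- dim(x)
  stopifnot(length(d) == 3, d[2] == d[3])
  # array() recycles diag(n) into n x n x B identity slices;
  # aperm() then moves the batch axis to the front: B x n x n.
  eye_stack <- aperm(array(diag(d[2]), dim = c(d[2], d[2], d[1])), c(3, 1, 2))
  # Elementwise product keeps only the diagonal of each slice.
  x * eye_stack
}

# Usage: a batch of two 3 x 3 matrices.
x <- array(1:18, dim = c(2, 3, 3))
y <- diag_mask_batch(x)
y[1, , ]  # diagonal is c(1, 9, 17); every off-diagonal entry is 0
```

Materialising the identity stack mirrors the approach in `torch_diag_batch()` and costs an extra B x n x n allocation, which is the slowness the source comment refers to.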
adsb85/lqp documentation built on April 9, 2022, 12:35 a.m.