R/methods_flexsurv.R

Defines functions get_predict.flexsurvreg set_coef.flexsurvreg

Documented in get_predict.flexsurvreg set_coef.flexsurvreg

#' @rdname set_coef
#' @export
set_coef.flexsurvreg <- function(model, coefs, ...) {
  # R copies on modify, so assigning into `model` here only touches this
  # function's local copy; the caller's object is left unchanged and the
  # modified copy is returned.
  # NOTE(review): the first column of `res` is filled positionally, so this
  # presumably expects `coefs` to follow the row order of `res` -- confirm.
  model$res[, 1] <- coefs
  model$coefficients <- coefs
  model
}


#' @rdname get_predict
#' @export
get_predict.flexsurvreg <- function(model, newdata, type, ...) {
  # Reshape `stats::predict()` output into a data frame with columns
  # `rowid`, `estimate` and, when the prediction carries one, `group`.
  # Three layouts of `preds` are handled:
  #   1. a single `.pred` list-column whose cells are nested tables
  #      (first nested column -> `group`, second -> `estimate`);
  #   2. any other single column (-> `estimate`, no `group`);
  #   3. two or more columns (first -> `group`, second -> `estimate`).
  preds <- stats::predict(
    object = model,
    newdata = newdata,
    type = type,
    ...
  )

  if (ncol(preds) == 1L) {
    # `identical()` rather than `==` so the condition is always a single
    # TRUE/FALSE value.
    if (identical(names(preds), ".pred")) {
      nested <- preds$.pred
      gp <- lapply(nested, function(x) x[, 1, drop = TRUE])
      val <- lapply(nested, function(x) x[, 2, drop = TRUE])

      # Each input row expands into as many output rows as its nested table
      # has, so `rowid` is repeated once per nested row. A plain
      # `seq_len(nrow(preds))` would be silently recycled by `data.frame()`
      # (mislabelling rows) or would error when the nested tables differ in
      # length.
      out <- data.frame(
        rowid = rep(seq_len(nrow(preds)), times = lengths(gp)),
        group = as.vector(unlist(gp)),
        estimate = as.vector(unlist(val))
      )
      out$group <- group_to_factor(out$group, model)
      return(out)
    }

    return(data.frame(
      rowid = seq_len(nrow(preds)),
      estimate = as.vector(preds[, 1, drop = TRUE])
    ))
  }

  out <- data.frame(
    rowid = seq_len(nrow(preds)),
    group = as.vector(preds[, 1, drop = TRUE]),
    estimate = as.vector(preds[, 2, drop = TRUE])
  )
  out$group <- group_to_factor(out$group, model)
  out
}

Try the marginaleffects package in your browser

Any scripts or data that you put into this service are public.

marginaleffects documentation built on May 29, 2024, 4:03 a.m.