R/methods_mlogit.R

Defines functions: sanitize_model_specific.mlogit, get_predict.mlogit

Documented in: get_predict.mlogit, sanitize_model_specific.mlogit

# Predictions for `mlogit` models, reshaped to a long data.table.
#
# `stats::predict()` is called on `model` with `newdata`. The result is
# handled in one of two shapes:
#   * an atomic vector: one row per element, with `group` taken from the
#     element names;
#   * a matrix: one row per (row, column) cell, with `rowid` taken from the
#     matrix row and `group` from the column name (presumably the choice
#     alternative -- confirm against mlogit's predict method).
# The output has columns `rowid`, `group`, and `estimate`, and is keyed by
# `rowid` then `group`. If `newdata` has a `term` column, that column is
# copied onto the output by position.
#' @rdname get_predict
#' @export
get_predict.mlogit <- function(model,
                               newdata,
                               ...) {

    mat <- stats::predict(model, newdata = newdata)
    if (isTRUE(checkmate::check_atomic_vector(mat))) {
        # NOTE(review): in this branch every element gets its own rowid,
        # whereas in the matrix branch all columns of one row share a rowid.
        # Confirm this difference is intended.
        out <- data.table(rowid = seq_along(mat),
                          group = names(mat),
                          estimate = mat)
    } else {
        # as.vector() flattens column-major (all rows of column 1, then all
        # rows of column 2, ...). So `rowid` must cycle through the rows
        # `times = ncol` while `group` repeats each column name `each = nrow`.
        # (The previous `rep = ncol(mat)` was not a `rep()` argument and was
        # silently ignored.)
        out <- data.table(rowid = rep(seq_len(nrow(mat)), times = ncol(mat)),
                          group = rep(colnames(mat), each = nrow(mat)),
                          estimate = as.vector(mat))
    }
    setkey(out, rowid, group)
    if ("term" %in% colnames(newdata)) {
        out[, "term" := newdata[["term"]]]
    }
    return(out)
}


# Validate `newdata` before it is used with an `mlogit` model.
#
# The number of distinct values in the second column of `model$model$idx` is
# treated as the number of choices. A non-NULL `newdata` must have a row count
# that is an exact multiple of that number; otherwise an error is raised
# (without the call in the message). When `newdata` is NULL or passes the
# check, `model` is returned unchanged.
#' @include sanity_model.R
#' @rdname sanitize_model_specific
#' @keywords internal
sanitize_model_specific.mlogit <- function(model, newdata, ...) {
    # Nothing to validate when no new data is supplied.
    if (is.null(newdata)) {
        return(model)
    }

    alt_ids <- model$model$idx[, 2]
    n_alternatives <- length(unique(alt_ids))

    # isTRUE() treats an NA or zero-length comparison as a failed check.
    fits_evenly <- isTRUE(nrow(newdata) %% n_alternatives == 0)
    if (fits_evenly) {
        return(model)
    }

    stop(
        sprintf("The `newdata` argument for `mlogit` models must be a data frame with a number of rows equal to a multiple of the number of choices: %s.", n_alternatives),
        call. = FALSE
    )
}

Try the marginaleffects package in your browser

Any scripts or data that you put into this service are public.

marginaleffects documentation built on Oct. 20, 2023, 1:07 a.m.