R/model.R

#' @include method.R trajectories.R latrend.R
#' @importFrom stats coef deviance df.residual getCall logLik model.frame model.matrix predict residuals sigma time update

# Model ####
#' @name lcModel
#' @title Longitudinal cluster result (**`lcModel`**)
#' @description A longitudinal cluster model (\code{[lcModel][lcModel-class]}) describes the clustered representation of a certain longitudinal dataset.
#'
#' An `lcModel` is obtained by estimating a specified [longitudinal cluster method][lcMethod] on a [longitudinal dataset][latrend-data].
#' The estimation is done via one of the [latrend estimation functions][latrend-estimation].
#'
#' A longitudinal cluster result represents the dataset in terms of a partitioning of the trajectories into a number of clusters.
#' The [trajectoryAssignments()] function outputs the most likely membership for the respective trajectories.
#' Each cluster has a longitudinal representation, obtained via [clusterTrajectories()], and can be plotted via [plotClusterTrajectories()].
#' @section Functionality:
#' **Clusters and partitioning:**
#' * [nClusters()]: The number of clusters this model represents.
#' * [clusterNames()]: The names of the clusters.
#' * [clusterSizes()]: The respective number of trajectories assigned to each cluster.
#' * [clusterProportions()]: The respective proportional size of each cluster.
#' * [trajectoryAssignments()]: The most likely cluster membership of each trajectory.
#' * [postprob()]: The posterior probability of each trajectory to each cluster.
#'
#' **Longitudinal cluster representation (i.e., trends):**
#' * [clusterTrajectories()]: A `data.frame` containing the longitudinal representation of each cluster.
#' * [plotClusterTrajectories()]: Plots the longitudinal representation of each cluster.
#' * [fittedTrajectories()]: A `data.frame` containing the longitudinal representation of each trajectory. For many methods, this is the cluster center.
#' * [plotFittedTrajectories()]: Plot the trajectory representation.
#'
#' **Training data:**
#' * [nIds()]: The number of trajectories used for estimation.
#' * [ids()]: A vector of identifiers of the trajectories that were used for estimation.
#' * [nobs()]: The number of observations used for estimation, across trajectories.
#' * [time()]: Moments in time on which observations are present.
#' * [trajectories()]: The trajectories that were used for estimation.
#' * [plotTrajectories()]: Plot the trajectories that were used for estimation.
#'
#' **Model evaluation:**
#' * [summary()]: Obtain a summary of the model.
#' * [metric()]: Compute an internal metric.
#' * [externalMetric()]: Compute an external metric in relation to a second `lcModel`.
#' * [converged()]: Whether the estimation procedure converged.
#' * [estimationTime()]: Total time that was needed for the fitting steps.
#' * [sigma()]: Residual error scale.
#' * [qqPlot()]: QQ plot of the model residuals.
#'
#' **Model prediction:**
#' * [predictForCluster()]: Cluster-specific prediction on new data. Not supported for all methods.
#' * [predictPostprob()]: Predict posterior probability for new data. Not supported for all methods.
#' * [predictAssignments()]: Predict cluster membership for new data. Not supported for all methods.
#'
#' **Other functionality:**
#' * [getLcMethod()]: Get the [method specification][lcMethod] by which this model was estimated.
#' * [update()]: Retrain a model with altered method arguments.
#' * [strip()]: Removes non-essential (meta) data and environments from the model to facilitate efficient serialization.
#' @seealso [lcModel-class]
#' @examples
#' data(latrendData)
#' # define the method
#' method <- lcMethodLMKM(Y ~ Time, id = "Id", time = "Time")
#' # estimate the method, giving the model
#' model <- latrend(method, data = latrendData)
#'
#' if (require("ggplot2")) {
#'   plotClusterTrajectories(model)
#' }
NULL


#' @export
#' @name lcModel-class
#' @title `lcModel` class
#' @description Abstract class for defining estimated longitudinal cluster models.
#' @details An extending class must implement the following methods to ensure basic functionality:
#' * `predict.lcModelExt`: Used to obtain the cluster trajectories and the fitted trajectories.
#' * `postprob(lcModelExt)`: The posterior probability matrix is used to determine the cluster assignments of the trajectories.
#'
#' To predict the posterior probability for unseen data, the `predictPostprob()` method should be implemented as well.
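#'
#' A minimal skeleton of such an extension might look as follows. This is only a sketch: `lcModelExt` is a placeholder class name, the `predict` signature is inferred from how `predict()` is called elsewhere in this file, and the method bodies are left to the implementer.
#' \preformatted{
#' setClass("lcModelExt", contains = "lcModel")
#'
#' predict.lcModelExt <- function(object, newdata = NULL, what = "mu", ...) {
#'   # return the predictions for newdata
#' }
#'
#' setMethod("postprob", "lcModelExt", function(object, ...) {
#'   # return the posterior probability matrix:
#'   # one row per trajectory, one column per cluster
#' })
#' }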
#'
#' @param object The `lcModel` object.
#' @param ... Any additional arguments.
#' @slot method The \link{lcMethod-class} object specifying the arguments under which the model was fitted.
#' @slot call The `call` that was used to create this `lcModel` object. Typically, this is the call to `latrend()` or any of the other fitting functions.
#' @slot model An arbitrary underlying model representation.
#' @slot data A `data.frame` object, or an expression that resolves to the `data.frame` object.
#' @slot date The date-time when the model estimation was initiated.
#' @slot id The name of the trajectory identifier column.
#' @slot time The name of the time variable.
#' @slot response The name of the response variable.
#' @slot label The label assigned to this model.
#' @slot ids The trajectory identifier values the model was fitted on.
#' @slot times The exact times on which the model has been trained.
#' @slot clusterNames The names of the clusters.
#' @slot estimationTime The time, in seconds, that it took to fit the model.
#' @slot tag An arbitrary user-specified data structure. This slot may be accessed and updated directly.
#' @family lcModel functions
setClass(
  'lcModel',
  representation(
    model = 'ANY',
    method = 'lcMethod',
    call = 'call',
    data = 'ANY',
    id = 'character',
    time = 'character',
    response = 'character',
    label = 'character',
    ids = 'vector',
    times = 'vector',
    clusterNames = 'character',
    date = 'POSIXct',
    estimationTime = 'numeric',
    tag = 'ANY'
  )
)

# . initialize ####
setMethod('initialize', 'lcModel', function(.Object, ...) {
  .Object@date = Sys.time()
  .Object = callNextMethod(.Object, ...)
  method = .Object@method

  assert_that(
    length(.Object@id) > 0 || has_name(method, 'id'),
    msg = '@id not specified, nor defined in lcMethod'
  )
  if (length(.Object@id) == 0) {
    .Object@id = idVariable(method)
  }

  assert_that(
    length(.Object@time) > 0 || has_name(method, 'time'),
    msg = '@time not specified, nor defined in lcMethod'
  )

  if (length(.Object@time) == 0) {
    .Object@time = timeVariable(method)
  }
  if (length(.Object@response) == 0) {
    .Object@response = responseVariable(method)
    assert_that(!is.null(.Object@response))
  }
  .Object
})


setValidity('lcModel', function(object) {
  # NOTE: validation is short-circuited by this early return; the checks below are never reached
  return(TRUE)

  if (as.character(object@call[[1]]) == "<undef>") {
    # nothing to validate as lcModel is incomplete
    return(TRUE)
  }

  assert_that(
    nchar(object@id) > 0,
    nchar(object@time) > 0,
    nchar(object@response) > 0
  )

  data = model.data(object)
  assert_that(
    !is.null(data),
    msg = 'invalid data object for new lcModel. Either specify the data slot or ensure that the model call contains a data argument which correctly evaluates.'
  )
  assert_that(has_name(data, c(object@id, object@time, object@response)))

  TRUE
})


# . clusterTrajectories ####
#' @export
#' @name clusterTrajectories
#' @rdname clusterTrajectories
#' @aliases clusterTrajectories,lcModel-method
#' @title Extract the cluster trajectories
#' @description Extracts a data frame of all cluster trajectories.
#' @inheritParams predict.lcModel
#' @inheritParams predictForCluster
#' @param at An optional vector of the times at which to compute the cluster trajectory predictions.
#' @return A data.frame of the estimated values at the given times. The first column should be named "Cluster". The second column should be time, with the name matching the `timeVariable(object)`. The third column should be the expected value of the observations, named after the `responseVariable(object)`.
#' @examples
#' method <- lcMethodLMKM(Y ~ Time, id = "Id", time = "Time")
#' model <- latrend(method, latrendData)
#'
#' clusterTrajectories(model)
#'
#' clusterTrajectories(model, at = c(0, .5, 1))
#' @family lcModel functions
setMethod('clusterTrajectories', 'lcModel', function(object, at = time(object), what = 'mu', ...) {
  newdata = data.table(
    Cluster = rep(clusterNames(object, factor = TRUE), each = length(at)),
    Time = at
  ) %>%
    setnames('Time', timeVariable(object))

  assert_that(
    has_name(newdata, timeVariable(object)),
    msg = sprintf('"at" argument of clusterTrajectories() requires time index column of name "%s"', timeVariable(object))
  )

  dfPred = predict(object, newdata = newdata, what = what, ...)
  assert_that(is.data.frame(dfPred), msg = 'invalid output from predict()')
  assert_that(
    nrow(dfPred) == nrow(newdata),
    msg = 'invalid output from predict function of lcModel; expected a prediction per newdata row'
  )
  newdata[, c(responseVariable(object, what = what)) := dfPred$Fit]

  newdata[]
})


#' @export
#' @title Get the cluster names
#' @param object The `lcModel` object.
#' @param factor Whether to return the cluster names as a factor.
#' @return A `character` vector of the cluster names, or a `factor` if `factor = TRUE`.
#' @examples
#' data(latrendData)
#' method <- lcMethodLMKM(Y ~ Time, id = "Id", time = "Time")
#' model <- latrend(method, latrendData)
#' clusterNames(model) # A, B
#' @family lcModel functions
clusterNames = function(object, factor = FALSE) {
  assert_that(is.lcModel(object))
  if (isTRUE(factor)) {
    factor(object@clusterNames, levels = object@clusterNames)
  } else {
    object@clusterNames
  }
}

#' @export
#' @title Update the cluster names
#' @param object The `lcModel` object to update.
#' @param value The `character` with the new names.
#' @return The updated `lcModel` object.
#' @examples
#' data(latrendData)
#' method <- lcMethodLMKM(Y ~ Time, id = "Id", time = "Time")
#' model <- latrend(method, latrendData, nClusters = 2)
#' clusterNames(model) <- c("Group 1", "Group 2")
`clusterNames<-` = function(object, value) {
  assert_that(
    is.lcModel(object),
    is.character(value),
    length(value) == nClusters(object)
  )

  object@clusterNames = value

  object
}

#' @export
#' @title Number of trajectories per cluster
#' @description Obtain the size of each cluster, where the size is the number of trajectories assigned to that cluster.
#' @details The cluster sizes are computed from the trajectory cluster membership as decided by the [trajectoryAssignments()] function.
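#'
#' As a rough illustration of this counting step (a toy sketch with made-up assignments, assuming the assignments are a `factor` with one level per cluster):
#' \preformatted{
#' assignments <- factor(c("A", "B", "A", "A"), levels = c("A", "B"))
#' sizes <- setNames(as.integer(table(assignments)), levels(assignments))
#' sizes # A = 3, B = 1
#' }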
#' @param object The `lcModel` object.
#' @param ... Additional arguments passed to [trajectoryAssignments()].
#' @seealso [clusterProportions] [trajectoryAssignments]
#' @return A named `integer vector` of length `nClusters(object)` with the number of assigned trajectories per cluster.
#' @examples
#' data(latrendData)
#' method <- lcMethodLMKM(Y ~ Time, id = "Id", time = "Time")
#' model <- latrend(method, latrendData, nClusters = 2)
#' clusterSizes(model)
#' @family lcModel functions
clusterSizes = function(object, ...) {
  assert_that(is.lcModel(object))

  trajectoryAssignments(object, ...) %>%
    table() %>%
    as.integer() %>%
    setNames(clusterNames(object))
}

#. clusterProportions ####
#' @export
#' @name clusterProportions
#' @aliases clusterProportions,lcModel-method
#' @title Proportional size of each cluster
#' @description Obtain the proportional size per cluster, with sizes between 0 and 1.
#' By default, the cluster proportions are determined from the cluster-averaged posterior probabilities of the fitted data (as computed by the [postprob()] function).
#' @section Implementation:
#' Classes extending `lcModel` can override this method to return, for example, the exact estimated mixture proportions based on the model coefficients.
#' \preformatted{
#' setMethod("clusterProportions", "lcModelExt", function(object, ...) {
#'   # return cluster proportion vector
#' })
#' }
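#'
#' For intuition, the default computation reduces to taking the column means of the posterior probability matrix. A toy sketch with made-up probabilities (not output from a real model):
#' \preformatted{
#' pp <- rbind(c(0.9, 0.1), c(0.2, 0.8), c(0.7, 0.3))
#' colnames(pp) <- c("A", "B")
#' colMeans(pp) # A = 0.6, B = 0.4
#' }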
#' @param object The `lcModel` to obtain the proportions from.
#' @param ... Additional arguments passed to [postprob()].
#' @return A named `numeric vector` of length `nClusters(object)` with the proportional size of each cluster.
#' @seealso [clusterSizes] [postprob]
#' @examples
#' data(latrendData)
#' method <- lcMethodLMKM(Y ~ Time, id = "Id", time = "Time")
#' model <- latrend(method, latrendData, nClusters = 2)
#' clusterProportions(model)
#' @family lcModel functions
setMethod('clusterProportions', 'lcModel', function(object, ...) {
  pp = postprob(object, ...)
  assert_that(
    !is.null(pp),
    msg = 'cannot determine cluster assignments because postprob() returned NULL'
  )
  assert_that(
    nrow(pp) > 0,
    msg = 'cannot determine cluster assignments because postprob() returned a matrix without rows'
  )

  colMeans(pp)
})


#' @export
#' @importFrom stats coef
#' @title Extract lcModel coefficients
#' @description Extract the coefficients of the `lcModel` object, if defined.
#' The returned set of coefficients depends on the underlying type of `lcModel`.
#' The default implementation checks for the existence of a `coef()` function for the internal model as defined in the `@model` slot, returning the output if available.
#' @param object The `lcModel` object.
#' @param ... Additional arguments.
#' @section Implementation:
#' Classes extending `lcModel` can override this method to return model-specific coefficients.
#' \preformatted{
#' coef.lcModelExt <- function(object, ...) {
#'   # return model coefficients
#' }
#' }
#' @return A named `numeric vector` with all coefficients, or a `matrix` with each column containing the cluster-specific coefficients. If `coef()` is not defined for the given model, an empty `numeric vector` is returned.
#' @family lcModel functions
#' @examples
#' data(latrendData)
#' method <- lcMethodLMKM(Y ~ Time, id = "Id", time = "Time")
#' model <- latrend(method, latrendData, nClusters = 2)
#' coef(model)
coef.lcModel = function(object, ...) {
  if (is.null(object@model) ||
      is.null(getS3method('coef', class = class(object@model)[1], optional = TRUE))) {
    numeric()
  } else {
    coef(object@model, ...)
  }
}


# . converged ####
#' @export
#' @name converged
#' @aliases converged,lcModel-method
#' @title Check model convergence
#' @description Check convergence of the fitted `lcModel` object.
#' The default implementation returns `NA`.
#' @param object The `lcModel` to check for convergence.
#' @param ... Additional arguments.
#' @return Either `logical` indicating convergence, or a `numeric` status code.
#' @section Implementation:
#' Classes extending `lcModel` can override this method to return a convergence status or code.
#' \preformatted{
#' setMethod("converged", "lcModelExt", function(object, ...) {
#'   # return convergence code
#' })
#' }
#' @examples
#' data(latrendData)
#' method <- lcMethodLMKM(Y ~ Time, id = "Id", time = "Time")
#' model <- latrend(method, latrendData, nClusters = 2)
#' converged(model)
#' @family lcModel functions
setMethod('converged', 'lcModel', function(object, ...) {
  NA
})


#' @export
#' @importFrom stats deviance
#' @title lcModel deviance
#' @description Get the deviance of the fitted `lcModel` object.
#' @details The default implementation checks for the existence of the `deviance()` function for the internal model, and returns the output, if available.
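#'
#' The availability check relies on [utils::getS3method()] with `optional = TRUE`, which returns `NULL` instead of raising an error when no method is registered for the first class of the internal model. A small base-R sketch of that lookup, where `"noSuchClass"` is a made-up class name:
#' \preformatted{
#' is.null(getS3method("deviance", "lm", optional = TRUE)) # FALSE
#' is.null(getS3method("deviance", "noSuchClass", optional = TRUE)) # TRUE
#' }
#' Because only the first class is inspected, methods of parent classes and the default method are not considered by this check.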
#' @param object The `lcModel` object.
#' @param ... Additional arguments.
#' @return A `numeric` with the deviance value. If unavailable, `NA` is returned.
#' @seealso [stats::deviance] [metric]
#' @family lcModel functions
deviance.lcModel = function(object, ...) {
  if (is.null(object@model) ||
      is.null(getS3method('deviance', class = class(object@model)[1], optional = TRUE))) {
    as.numeric(NA)
  } else {
    # nocov start
    deviance(object@model)
    # nocov end
  }
}


#' @export
#' @importFrom stats df.residual
#' @title Extract the residual degrees of freedom from a lcModel
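#' @details If no `df.residual()` method is available for the internal model, the fallback is the number of observations minus the `df` attribute of `logLik()`. A base-R sketch of that arithmetic on an `lm` fit of the built-in `cars` data (purely illustrative, not an `lcModel`):
#' \preformatted{
#' fit <- lm(dist ~ speed, data = cars)
#' nobs(fit) - attr(logLik(fit), "df") # 50 - 3 = 47
#' }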
#' @param object The `lcModel` object.
#' @param ... Additional arguments.
#' @return A `numeric` with the residual degrees of freedom. If unavailable, `NA` is returned.
#' @seealso [stats::df.residual] [nobs] [residuals]
#' @family lcModel functions
df.residual.lcModel = function(object, ...) {
  if (is.null(object@model) ||
      is.null(getS3method('df.residual', class = class(object@model)[1], optional = TRUE))) {
    df = attr(logLik(object), 'df')
    if (!is.null(df) && is.finite(df) && is.numeric(nobs(object))) {
      nobs(object) - df
    } else {
      as.numeric(NA)
    }
  } else {
    # nocov start
    df.residual(object@model)
    # nocov end
  }
}


#. externalMetric ####
#' @export
#' @rdname externalMetric
#' @aliases externalMetric,lcModel,lcModel-method
#' @return For `externalMetric(lcModel, lcModel)`: A `numeric` vector of the computed metrics.
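#' @details Metric names that are not defined yield `NA` in the output, accompanied by a warning. The fill-in pattern can be sketched in base R as follows, where `known` is a hypothetical stand-in for the set of defined metric names and `0.5` is a placeholder value rather than a real result:
#' \preformatted{
#' name <- c("adjustedRand", "notAMetric")
#' known <- "adjustedRand"
#' mask <- is.element(name, known)
#' values <- rep(NA_real_, length(name))
#' values[mask] <- 0.5
#' names(values) <- name
#' values # adjustedRand = 0.5, notAMetric = NA
#' }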
#' @examples
#' data(latrendData)
#' method <- lcMethodLMKM(Y ~ Time, id = "Id", time = "Time")
#' model2 <- latrend(method, latrendData, nClusters = 2)
#' model3 <- latrend(method, latrendData, nClusters = 3)
#'
#' if (require("mclustcomp")) {
#'   externalMetric(model2, model3, "adjustedRand")
#' }
#' @family metric functions
#' @family lcModel functions
setMethod('externalMetric', c('lcModel', 'lcModel'), function(object, object2, name, ...) {
  assert_that(length(name) > 0, msg = 'no external metric names provided')
  assert_that(is.character(name))

  funMask = name %in% getExternalMetricNames()
  if (!all(funMask)) {
    warning(
      'External metric(s) ',
      paste0('"', name[!funMask], '"', collapse = ', '),
      ' are not defined. Returning NA.'
    )
  }
  metricFuns = lapply(name[funMask], getExternalMetricDefinition)
  metricValues = Map(
    function(fun, name) {
      value = fun(object, object2)
      assert_that(
        is.scalar(value) && (is.numeric(value) || is.logical(value)),
        msg = sprintf('invalid output for metric "%s"; expected scalar number or logical value', name)
      )
      return(value)
    },
    metricFuns,
    name[funMask]
  )

  allMetrics = rep(NA_real_, length(name))
  allMetrics[funMask] = unlist(metricValues)
  names(allMetrics) = name

  allMetrics
})


#' @export
#' @importFrom stats fitted
#' @title Extract lcModel fitted values
#' @description Returns the cluster-specific fitted values for the given `lcModel` object.
#' The default implementation calls [predict()] with `newdata = NULL`.
#' @param object The `lcModel` object.
#' @param ... Additional arguments.
#' @param clusters Optional cluster assignments per id. By default, the most likely assignments from [trajectoryAssignments()] are used. If `NULL`, a `matrix` is returned containing the cluster-specific predictions per column.
#' @return A `numeric` vector of the fitted values for the respective class, or a `matrix` of fitted values for each cluster.
#' @section Implementation:
#' Classes extending `lcModel` can override this method to adapt the computation of the predicted values for the training data.
#' Note that the implementation of this function is only needed when [predict()] and [predictForCluster()] are not defined for the `lcModel` subclass.
#' \preformatted{
#' fitted.lcModelExt <- function(object, ..., clusters = trajectoryAssignments(object)) {
#'   pred = predict(object, newdata = NULL)
#'   transformFitted(pred = pred, model = object, clusters = clusters)
#' }
#' }
#' The [transformFitted()] function takes care of transforming the prediction input to the right output format.
#' @examples
#' data(latrendData)
#' method <- lcMethodLMKM(Y ~ Time, id = "Id", time = "Time")
#' model <- latrend(method, latrendData)
#' fitted(model)
#' @seealso [fittedTrajectories] [plotFittedTrajectories] [stats::fitted] [predict.lcModel] [trajectoryAssignments] [transformFitted]
#' @family lcModel functions
fitted.lcModel = function(object, ..., clusters = trajectoryAssignments(object)) {
  if (suppressWarnings(nIds(object)) == 0) {
    warning('No result for fitted() because this model has no associated trajectories that were used for training.')
    transformFitted(pred = NULL, model = object, clusters = clusters)
  } else {
    pred = predict(object, newdata = NULL, ...)
    transformFitted(pred = pred, model = object, clusters = clusters)
  }
}


# . fittedTrajectories ####
#' @export
#' @name fittedTrajectories
#' @rdname fittedTrajectories
#' @aliases fittedTrajectories,lcModel-method
#' @title Extract the fitted trajectories for all strata
#' @param object The model.
#' @param at The time points at which to compute the id-specific trajectories.
#' The default implementation merely filters the output of `fitted()`,
#' so fitted values can only be outputted for times at which the model was trained.
#' @param what The distributional parameter to compute the response for.
#' @param clusters The cluster assignments for the strata to base the trajectories on.
#' @param ... Additional arguments.
#' @return A `data.frame` representing the fitted response per trajectory per moment in time for the respective cluster.
#' @details The default implementation uses the output of `fitted()` of the respective model.
#' @examples
#' data(latrendData)
#' # Note: not a great example because the fitted trajectories
#' # are identical to the respective cluster trajectory
#' method <- lcMethodLMKM(Y ~ Time, id = "Id", time = "Time")
#' model <- latrend(method, latrendData)
#' fittedTrajectories(model)
#'
#' fittedTrajectories(model, at = time(model)[c(1, 2)])
#' @family lcModel functions
setMethod('fittedTrajectories', 'lcModel', function(object, at, what, clusters, ...) {
  assert_that(
    is.null(clusters) || length(clusters) == nIds(object),
    is.numeric(at),
    all(at %in% time(object))
  )

  newdata = model.data(object) %>%
    subset(select = c(idVariable(object), timeVariable(object))) %>%
    as.data.table()

  fits = fitted(object, what = what, clusters = clusters, ...)

  if (is.matrix(fits)) {
    # fit per cluster
    assert_that(
      is.numeric(fits),
      nrow(fits) == nobs(object)
    )

    newdata = cbind(newdata, fits) %>%
      melt(
        id.vars = names(newdata)[1:2],
        variable.name = 'Cluster',
        value.name = responseVariable(object, what = what)
      )
  } else {
    # fit for assigned cluster only
    assert_that(
      is.numeric(fits),
      length(fits) == nobs(object)
    )

    newdata[, c(responseVariable(object, what = what)) := fits]
    newdata[, Cluster := trajectoryAssignments(object)[make.idRowIndices(object)]]
  }

  # filter times
  newdata[get(timeVariable(object)) %in% at]
})


#' @export
#' @importFrom stats formula
#' @title Extract the formula of a lcModel
#' @description Get the formula associated with the fitted `lcModel` object.
#' This is determined by the `formula` argument of the `lcMethod` specification that was used to fit the model.
#' @param x The `lcModel` object.
#' @param what The distributional parameter.
#' @param ... Additional arguments.
#' @return Returns the associated `formula`, or `response ~ 0` if not specified.
#' @seealso [stats::formula]
#' @examples
#' data(latrendData)
#' method <- lcMethodLMKM(Y ~ Time, id = "Id", time = "Time")
#' model <- latrend(method, data = latrendData)
#' formula(model) # Y ~ Time
formula.lcModel = function(x, what = 'mu', ...) {
  method = getLcMethod(x)
  if (what == 'mu') {
    if (has_name(method, 'formula')) {
      method$formula
    } else {
      as.formula(paste(x@response, '~ 0'))
    }
  } else {
    formulaName = paste('formula', what, sep = '.')
    if (has_name(method, formulaName)) {
      formula(method, what = what, ...)
    } else {
      ~ 0
    }
  }
}


#' @export
#' @importFrom stats getCall
#' @title Get the model call
#' @description Extract the `call` that was used to fit the given `lcModel` object.
#' @param x The `lcModel` object.
#' @param ... Not used.
#' @return A `call` to [latrend()] with the necessary arguments and data.
#' @keywords internal
#' @examples
#' data(latrendData)
#' method <- lcMethodRandom("Y", id = "Id", time = "Time")
#' model <- latrend(method, latrendData)
#' getCall(model)
#' @seealso [stats::getCall] [getLcMethod]
#' @family lcModel functions
getCall.lcModel = function(x, ...) {
  x@call
}

#. getLabel ####
#' @export
#' @rdname getLabel
#' @aliases getLabel,lcModel-method
setMethod('getLabel', 'lcModel', function(object, ...) {
  lbl = object@label
  if (length(lbl) > 0) {
    lbl
  } else {
    ''
  }
})

# . getLcMethod ####
#' @export
#' @name getLcMethod
#' @rdname getLcMethod
#' @aliases getLcMethod,lcModel-method
#' @title Get the method specification of a lcModel
#' @description Get the `lcMethod` specification object that was used for fitting the given `lcModel` object.
#' @param object The `lcModel` object.
#' @return An `lcMethod` object.
#' @examples
#' method <- lcMethodRandom("Y", id = "Id", time = "Time")
#' model <- latrend(method, latrendData)
#' getLcMethod(model)
#' @seealso [getCall.lcModel]
#' @family lcModel functions
setMethod('getLcMethod', 'lcModel', function(object) object@method)


# . getName ####
#' @export
#' @rdname getName
#' @aliases getName,lcModel-method
setMethod('getName', 'lcModel', function(object) {
  basename = getLcMethod(object) %>% getName()
  lbl = getLabel(object)
  if (length(lbl) > 0 && nchar(lbl) > 0) {
    paste(basename, lbl, sep = '-')
  } else {
    basename
  }
})

# . getShortName ####
#' @export
#' @rdname getName
#' @aliases getShortName,lcModel-method
setMethod('getShortName',  'lcModel',
  function(object) getShortName(getLcMethod(object))
)


# . ids ####
#' @export
#' @title Get the trajectory ids on which the model was fitted
#' @details The order returned by `ids(object)` determines the id order for any output involving id-specific values, such as in [trajectoryAssignments()] or [postprob()].
#' @param object The `lcModel` object.
#' @return A `character vector` or `integer vector` of the identifier for every fitted trajectory.
#' @examples
#' data(latrendData)
#' method <- lcMethodRandom("Y", id = "Id", time = "Time")
#' model <- latrend(method, latrendData)
#' ids(model) # 1, 2, ..., 200
#' @family lcModel functions
ids = function(object) {
  assert_that(is.lcModel(object))
  if (length(object@ids) == 0) {
    idvec = model.data(object)[[idVariable(object)]]
    make.ids(idvec)
  } else {
    object@ids
  }
}


#. idVariable ####
#' @export
#' @name idVariable
#' @rdname idVariable
#' @aliases idVariable,lcModel-method
#' @examples
#' method <- lcMethodRandom("Y", id = "Id", time = "Time")
#' model <- latrend(method, latrendData)
#' idVariable(model) # "Id"
#' @family lcModel variables
setMethod('idVariable', 'lcModel', function(object) object@id)


#' @export
#' @rdname is
is.lcModel = function(x) {
  isS4(x) && is(x, 'lcModel')
}


#. metric ####
#' @export
#' @name metric
#' @rdname metric
#' @aliases metric,lcModel-method
#' @examples
#' data(latrendData)
#' method <- lcMethodLMKM(Y ~ Time, id = "Id", time = "Time")
#' model <- latrend(method, latrendData)
#' metric(model, "WMAE")
#'
#' if (require("clusterCrit")) {
#'   metric(model, c("WMAE", "Dunn"))
#' }
#' @family metric functions
#' @family lcModel functions
setMethod('metric', 'lcModel', function(object, name, ...) {
  assert_that(length(name) > 0, msg = 'no metric names provided')
  assert_that(
    is.lcModel(object),
    is.character(name)
  )

  funMask = name %in% getInternalMetricNames()
  if (!all(funMask)) {
    warning(
      'Internal metric(s) ',
      paste0('"', name[!funMask], '"', collapse = ', '),
      ' are not defined. Returning NA.'
    )
  }
  metricFuns = lapply(name[funMask], getInternalMetricDefinition)
  metricValues = Map(
    function(fun, name) {
      value = fun(object)
      assert_that(
        is.scalar(value) && (is.numeric(value) || is.logical(value)),
        msg = sprintf(
          'invalid output for metric "%s"; expected scalar number or logical value',
          name
        )
      )
      return(value)
    },
    metricFuns,
    name[funMask]
  )

  allMetrics = rep(NA_real_, length(name))
  allMetrics[funMask] = unlist(metricValues)
  names(allMetrics) = name

  allMetrics
})


#' @export
#' @importFrom stats model.frame
#' @title Extract model training data
#' @description See [stats::model.frame()] for more details.
#' @param formula The `lcModel` object.
#' @param ... Additional arguments.
#' @return A `data.frame` containing the variables used by the model.
#' @seealso [stats::model.frame] [model.data.lcModel]
#' @family lcModel functions
#' @examples
#' data(latrendData)
#' method <- lcMethodLMKM(Y ~ Time, id = "Id", time = "Time")
#' model <- latrend(method, data = latrendData)
#' model.frame(model)
model.frame.lcModel = function(formula, ...) {
  if (is.null(formula@model) ||
      is.null(getS3method('model.frame', class = class(formula@model)[1], optional = TRUE))) {
    labs = stats::formula(formula) %>%
      terms() %>%
      labels()

    if (length(labs) > 0) {
      model.data(formula) %>% subset(select = labs)
    } else {
      stop(sprintf('cannot determine model.frame for the given model of class %s', class(formula)))
    }
  } else {
    model.frame(formula@model, ...)
  }
}


#' @export
#' @title Extract the model training data
#' @param object The object.
#' @param ... Additional arguments.
#' @keywords internal
model.data = function(object, ...) {
  UseMethod('model.data')
}

#' @export
#' @title Extract the model data that was used for fitting
#' @param object The `lcModel` object.
#' @param ... Additional arguments.
#' @description Evaluates the data call in the environment that the model was trained in.
#' @return The full `data.frame` that was used for fitting the `lcModel`.
#' @seealso [model.frame.lcModel] [time.lcModel]
#' @examples
#' data(latrendData)
#' method <- lcMethodLMKM(Y ~ Time, id = "Id", time = "Time")
#' model <- latrend(method, latrendData)
#' model.data(model)
model.data.lcModel = function(object, ...) {
  if (!is.null(object@data)) {
    assert_that(is.data.frame(object@data), msg = 'expected data reference to be a data.frame')

    object@data
  } else if (has_name(getCall(object), 'data')) {
    dataCall = getCall(object)$data
    data = eval(dataCall, envir = environment(object))
    # report the data expression from the call; deparsing the evaluated NULL would only print "NULL"
    assert_that(!is.null(data),
      msg = sprintf('could not find "%s" in the model environment', paste(deparse(dataCall), collapse = ' ')))
    assert_that(!is.function(data), msg = 'The data object was not found in the model environment. The data object currently evaluates to a function, indicating the original training data is not loaded.')

    modelData = trajectories(
      data,
      id = idVariable(object),
      time = timeVariable(object),
      response = responseVariable(object),
      envir = environment(object)
    )

    assert_that(is.data.frame(modelData), msg = 'expected data reference to be a data.frame')

    modelData
  } else {
    warning('Cannot determine data used to train this lcModel. Data not part of model call, and not assigned to the @data slot. Returning NULL.')

    NULL
  }
}

# Number of rows used for fitting the model
.ndobs = function(object, ...) {
  data = model.data(object)
  if (is.null(data)) {
    0L
  } else {
    nrow(data)
  }
}

#' @export
#' @title Number of trajectories
#' @description Get the number of trajectories (strata) that were used for fitting the given `lcModel` object.
#' The number of trajectories is determined from the number of unique identifiers in the training data. In case the trajectory ids were supplied using a `factor` column, the number of trajectories is determined by the number of levels instead.
#' @param object The `lcModel` object.
#' @return An `integer` with the number of trajectories on which the `lcModel` was fitted.
#' @seealso [nobs] [nClusters]
#' @examples
#' data(latrendData)
#' method <- lcMethodRandom("Y", id = "Id", time = "Time")
#' model <- latrend(method, latrendData)
#' nIds(model)
#' @family lcModel functions
nIds = function(object) {
  modelIds = ids(object)
  length(modelIds)
}


# . nClusters ####
#' @export
#' @name nClusters
#' @aliases nClusters,lcModel-method
#' @title Number of clusters
#' @description Get the number of clusters estimated by the given `lcModel` object.
#' @param object The `lcModel` object.
#' @param ... Not used.
#' @return An `integer` with the number of clusters identified by the `lcModel`.
#' @seealso [nIds] [nobs]
#' @examples
#' data(latrendData)
#' method <- lcMethodRandom("Y", id = "Id", time = "Time", nClusters = 3)
#' model <- latrend(method, latrendData)
#' nClusters(model) # 3
#' @family lcModel functions
setMethod('nClusters', 'lcModel', function(object, ...) {
  nClus = length(object@clusterNames)
  assert_that(
    is.count(nClus),
    nClus > 0
  )
  nClus
})


#' @export
#' @importFrom stats nobs
#' @title Number of observations used for the lcModel fit
#' @description Extracts the number of observations that contributed information towards fitting the cluster trajectories of the respective `lcModel` object.
#' Therefore, only non-missing response observations count towards the number of observations.
#' @param object The `lcModel` object.
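#' @param ... Additional arguments.
#' @return An `integer` with the number of non-missing (finite) response observations, or `0L` if the model data is unavailable.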
#' @family lcModel functions
#' @seealso [nIds] [nClusters]
#' @examples
#' data(latrendData)
#' method <- lcMethodLMKM(Y ~ Time, id = "Id", time = "Time")
#' model <- latrend(method, latrendData)
#' nobs(model)
nobs.lcModel = function(object, ...) {
  suppressWarnings({
    data = model.data(object)
  })

  if (is.null(data)) {
    0L
  } else {
    x = data[[responseVariable(object)]]
    sum(is.finite(x))
  }
}


#' @export
#' @rdname predict.lcModel
#' @importFrom stats predict
#' @title lcModel predictions
#' @description Predicts the expected trajectory observations at the given time for each cluster.
#' @section Implementation:
#' Note: Subclasses of `lcModel` should preferably implement [predictForCluster()] instead of overriding `predict.lcModel`, because that function is single-purpose and therefore easier to implement.
#'
#' The `predict.lcModelExt` function should be able to handle the case where `newdata = NULL` by returning the fitted values.
#' After post-processing the non-NULL newdata input, the observation- and cluster-specific predictions can be computed.
#' Lastly, the output logic is handled by the [transformPredict()] function. It converts the computed predictions (e.g., `matrix` or `data.frame`) to the appropriate output format.
#' \preformatted{
#' predict.lcModelExt <- function(object, newdata = NULL, what = "mu", ...) {
#'   if (is.null(newdata)) {
#'     newdata = model.data(object)
#'     if (hasName(newdata, 'Cluster')) {
#'       # allowing the Cluster column to remain would break the fitted() output.
#'       newdata[['Cluster']] = NULL
#'     }
#'   }
#'
#'   # compute cluster-specific predictions for the given newdata
#'   pred <- NEWDATA_COMPUTATIONS_HERE
#'   transformPredict(pred = pred, model = object, newdata = newdata)
#' }
#' }
#' @param object The `lcModel` object.
#' @param newdata Optional `data.frame` for which to compute the model predictions. If omitted, the model training data is used.
#' Cluster trajectory predictions are made when ids are not specified.
#' @param what The distributional parameter to predict. By default, the mean response 'mu' is predicted. The cluster membership predictions can be obtained by specifying `what = 'mb'`.
#' @return If `newdata` specifies the cluster membership, a `data.frame` of cluster-specific predictions. Otherwise, a `list` of `data.frame`s with the cluster-specific predictions is returned.
#' @param ... Additional arguments.
#' @param useCluster Whether to use the "Cluster" column in the newdata argument for computing predictions conditional on the respective cluster.
#' For `useCluster = NA` (the default), the feature is enabled if newdata contains the "Cluster" column.
#' @examples
#' data(latrendData)
#' method <- lcMethodLMKM(Y ~ Time, id = "Id", time = "Time")
#' model <- latrend(method, latrendData)
#'
#' predFitted <- predict(model) # same result as fitted(model)
#'
#' # Cluster trajectory of cluster A
#' predCluster <- predict(model, newdata = data.frame(Cluster = "A", Time = time(model)))
#'
#' # Prediction for id S1 given cluster A membership
#' predId <- predict(model, newdata = data.frame(Cluster = "A", Id = "S1", Time = time(model)))
#'
#' # Prediction matrix for id S1 for all clusters
#' predIdAll <- predict(model, newdata = data.frame(Id = "S1", Time = time(model)))
#' @seealso [predictForCluster] [stats::predict] [fitted.lcModel] [clusterTrajectories] [trajectories] [predictPostprob] [predictAssignments]
#' @family lcModel functions
predict.lcModel = function(object, newdata = NULL, what = 'mu', ..., useCluster = NA) {
  assert_that(
    is_newdata(newdata),
    is.string(what),
    nzchar(what),
    is.flag(useCluster)
  )

  if (is.na(useCluster)) {
    useCluster = hasName(newdata, 'Cluster')
  }

  predMethod = selectMethod('predictForCluster', class(object), optional = TRUE)
  if (is.null(predMethod) || predMethod@defined@.Data == 'lcModel') {
    stop(sprintf('Cannot compute predictions for model of class %1$s because neither predict.%1$s nor predictForCluster(%1$s) are implemented for this model', class(object)[1]))
  }

  # special case for when no newdata is provided
  if (is.null(newdata)) {
    newdata = model.data(object)
    if (hasName(newdata, 'Cluster')) {
      newdata[['Cluster']] = NULL # allowing the Cluster column to remain would break the fitted() output.
    }
  }
  else {
    if (nrow(newdata) == 0) {
      warning('called predict() with empty newdata data.frame (nrow = 0)')
    }
  }

  newdata = as.data.table(newdata)

  if (useCluster) {
    assert_that(
      hasName(newdata, 'Cluster'),
      msg = 'newdata must contain a "Cluster" column when useCluster = TRUE'
    )
    # enforce cluster ordering
    newdata[, Cluster := factor(Cluster, levels = clusterNames(object))]

    assert_that(
      noNA(newdata$Cluster),
      all(unique(newdata$Cluster) %in% clusterNames(object)),
      msg = paste0(
        'The provided newdata "Cluster" column must be complete and only contain cluster names associated with the model (',
        paste0(shQuote(clusterNames(object)), collapse = ', '), ').'
      )
    )

    # predictForCluster with newdata subsets
    clusdataList = split(newdata, by = 'Cluster', sorted = TRUE, drop = TRUE) %>%
      lapply(function(cdata) cdata[, Cluster := NULL])
  }
  else {
    # drop Cluster column because it is not being used in order to prevent merge issues in transformPredict()
    if (hasName(newdata, 'Cluster')) {
      newdata[, Cluster := NULL]
    }

    # predictForCluster with newdata for each cluster
    clusdataList = replicate(nClusters(object), newdata, simplify = FALSE)
    names(clusdataList) = clusterNames(object)
  }

  predList = Map(
    function(cname, cdata) {
      predictForCluster(object, cluster = cname, newdata = cdata, what = what, ...)
    },
    names(clusdataList),
    clusdataList
  )

  assert_that(
    length(predList) == length(clusdataList),
    msg = 'unexpected internal state. please report'
  )
  assert_that(
    all(vapply(predList, function(x) is(x, class(predList[[1]])), FUN.VALUE = TRUE)),
    msg = 'output from predictForCluster() must be same class for all clusters. Check the model implementation.'
  )


  if (is.data.frame(predList[[1]])) {
    pred = rbindlist(predList, idcol = 'Cluster')
    # predList is named by cluster, so rbindlist() stores cluster names (not indices) in the id column
    pred[, Cluster := factor(Cluster, levels = clusterNames(object))]
  }
  else if (is.numeric(predList[[1]])) {
    clusDataRows = vapply(clusdataList, nrow, FUN.VALUE = 0)
    clusPredRows = vapply(predList, length, FUN.VALUE = 0)
    assert_that(
      all(clusDataRows == clusPredRows),
      msg = 'Numeric output length from predictForCluster() does not match the number of input newdata rows for one or more clusters'
    )

    pred = data.table(
      Cluster = rep(factor(names(clusDataRows), levels = clusterNames(object)), clusDataRows),
      Fit = do.call(c, predList)
    )
  }
  else {
    stop(
      'unsupported output from predictForCluster(): must be data.frame or numeric. Check the model implementation.'
    )
  }

  preddata = cbind(
    rbindlist(clusdataList),
    pred
  )

  transformPredict(pred = preddata, model = object, newdata = newdata)
}


# . predictForCluster ####
#' @export
#' @name predictForCluster
#' @rdname predictForCluster
#' @aliases predictForCluster,lcModel-method
#' @title lcModel prediction conditional on a cluster
#' @description Predicts the expected trajectory observations at the given time under the assumption that the trajectory belongs to the specified cluster.
#'
#' The same result can be obtained by calling [`predict()`][predict.lcModel()] with the `newdata` `data.frame` having a `"Cluster"` assignment column.
#' The main purpose of this function is to make it easier to implement the prediction computations for custom `lcModel` classes.
#'
#' @details The default `predictForCluster()` method makes use of [predict.lcModel()], and vice versa. For this to work, any extending `lcModel` classes, e.g., `lcModelExample`, should implement either `predictForCluster(lcModelExample)` or `predict.lcModelExample()`. When implementing new models, it is advisable to implement `predictForCluster` as the cluster-specific computation generally results in shorter and simpler code.
#' @inheritParams predict.lcModel
#' @param cluster The cluster name (as `character`) to predict for.
#' @param ... Additional arguments.
#' @return A `vector` with the predictions per `newdata` observation, or a `data.frame` with the predictions and newdata alongside.
#' @section Implementation:
#' Classes extending `lcModel` should override this method, unless [predict.lcModel()] is preferred.
#' \preformatted{
#' setMethod("predictForCluster", "lcModelExt",
#'  function(object, newdata = NULL, cluster, ..., what = "mu") {
#'   # return model predictions for the given data under the
#'   # assumption of the data belonging to the given cluster
#' })
#' }
#' @seealso [predict.lcModel]
#' @family lcModel functions
#' @examples
#' data(latrendData)
#' method <- lcMethodLMKM(Y ~ Time, id = "Id", time = "Time")
#' model <- latrend(method, latrendData)
#'
#' predictForCluster(
#'   model,
#'   newdata = data.frame(Time = c(0, 1)),
#'   cluster = "B"
#' )
#'
#' # all fitted values under cluster B
#' predictForCluster(model, cluster = "B")
setMethod('predictForCluster', 'lcModel',
  function(object, newdata = NULL, cluster, ..., what = 'mu') {
  # check whether predict.lcModelType exists
  cls = class(object)
  classes = extends(class(object)) %>% setdiff('lcModel')
  methodsAvailable = vapply(classes, function(x) !is.null(getS3method('predict', x, optional = TRUE)), FUN.VALUE = FALSE)

  assert_that(
    any(methodsAvailable),
    msg = sprintf('Cannot compute cluster-specific predictions for model of class %1$s because neither predict.%1$s nor predictForCluster(%1$s) are implemented for this model', cls)
  )

  newdata = cbind(newdata, Cluster = cluster)
  pred = predict(object, newdata = newdata, ..., what = what)
  pred$Fit
})



# . predictPostprob ####
#' @export
#' @name predictPostprob
#' @rdname predictPostprob
#' @aliases predictPostprob,lcModel-method
#' @title lcModel posterior probability prediction
#' @description Returns the observation-specific posterior probabilities for the given data.
#' The default implementation returns a uniform probability matrix.
#' @param object The `lcModel` to predict the posterior probabilities with.
#' @param newdata Optional data frame for which to compute the posterior probability. If omitted, the model training data is used.
#' @param ... Additional arguments.
#' @return An N-by-K `matrix` indicating the posterior probability per trajectory per measurement on each row, for each cluster (the columns).
#' Here, `N = nrow(newdata)` and `K = nClusters(object)`.
#' @section Implementation:
#' Classes extending `lcModel` should override this method to enable posterior probability predictions for new data.
#' \preformatted{
#' setMethod("predictPostprob", "lcModelExt", function(object, newdata = NULL, ...) {
#'   # return observation-specific posterior probability matrix
#' })
#' }
#' @family lcModel functions
setMethod('predictPostprob', 'lcModel', function(object, newdata = NULL, ...) {
  if (is.null(newdata)) {
    pp = postprob(object, ...)
    rownames(pp) = NULL
    pp[make.idRowIndices(object), ]
  }
  else {
    warning(
      'predictPostprob() not implemented for ',
      class(object)[1],
      '. Returning uniform probability matrix.'
    )

    matrix(1 / nClusters(object), nrow = nrow(newdata), ncol = nClusters(object))
  }
})


#. predictAssignments ####
#' @export
#' @name predictAssignments
#' @rdname predictAssignments
#' @aliases predictAssignments,lcModel-method
#' @title Predict the cluster assignments for new trajectories
#' @description Predicts the cluster assignment for each observation in the provided (observed) data, based on the posterior probabilities.
#' @inheritParams predict.lcModel
#' @param strategy A function returning the cluster index based on the given `vector` of membership probabilities.
#' By default (`strategy = which.max`), trajectories are assigned to the most likely cluster.
#' @details The default implementation uses [predictPostprob] to determine the cluster membership.
#' @return A `factor` of length `nrow(newdata)` that indicates the assigned cluster per trajectory per observation.
#' @seealso [predictPostprob] [predict.lcModel]
#' @family lcModel functions
#' @examples
#' \dontrun{
#' data(latrendData)
#' if (require("kml")) {
#'   model <- latrend(method = lcMethodKML("Y", id = "Id", time = "Time"), latrendData)
#'   predictAssignments(model, newdata = data.frame(Id = 999, Y = 0, Time = 0))
#' }
#' }
setMethod('predictAssignments', 'lcModel', function(
  object,
  newdata = NULL,
  strategy = which.max,
  ...
  ) {

  pp = predictPostprob(object, newdata = newdata, ...)

  if (is.null(newdata)) {
    newdata = model.data(object)
  }

  assert_that(
    is_valid_postprob(pp, object),
    nrow(pp) == nrow(newdata)
  )

  apply(pp, 1, strategy, ...) %>%
    factor(levels = 1:nClusters(object), labels = clusterNames(object))
})


#. plot ####
#' @export
#' @name plot-lcModel-method
#' @aliases plot,lcModel,ANY-method plot,lcModel-method
#' @title Plot a lcModel
#' @description Plot a `lcModel` object. By default, this plots the cluster trajectories of the model, along with the training data.
#' @param x The `lcModel` object.
#' @param y Not used.
#' @inheritDotParams plotClusterTrajectories
#' @return A `ggplot` object.
#' @seealso [plotClusterTrajectories] [plotFittedTrajectories] [plotTrajectories] [ggplot2::ggplot]
#' @examples
#' data(latrendData)
#' method <- lcMethodLMKM(Y ~ Time, id = "Id", time = "Time")
#' model <- latrend(method, latrendData, nClusters = 3)
#'
#' if (require("ggplot2")) {
#'   plot(model)
#' }
#' @family lcModel functions
setMethod('plot', c('lcModel', 'ANY'), function(x, y, ...) {
  args = list(...)

  if (!has_name(args, 'trajectories')) {
    args$trajectories = !has_name(args, 'what')
  }

  do.call(plotClusterTrajectories, c(x, args))
})


#. plotFittedTrajectories ####
#' @export
#' @name plotFittedTrajectories
#' @rdname plotFittedTrajectories
#' @aliases plotFittedTrajectories,lcModel-method
#' @title Plot fitted trajectories of a lcModel
#' @param object The `lcModel` object.
#' @param ... Arguments passed to [fittedTrajectories()].
#' @inheritDotParams trajectories
#' @seealso [fittedTrajectories] [plotClusterTrajectories] [plotTrajectories] [plot]
#' @examples
#' data(latrendData)
#' method <- lcMethodLMKM(Y ~ Time, id = "Id", time = "Time")
#' model <- latrend(method, latrendData, nClusters = 3)
#'
#' if (require("ggplot2")) {
#'   plotFittedTrajectories(model)
#' }
#' @family lcModel functions
setMethod('plotFittedTrajectories', 'lcModel', function(object, ...) {
  data = fittedTrajectories(object, ...)

  plotTrajectories(
    data,
    response = responseVariable(object),
    time = timeVariable(object),
    id = idVariable(object),
    cluster = 'Cluster',
    ...
  )
})

#. plotClusterTrajectories ####
# NOTE: Cannot include @inheritDotParams clusterTrajectories-method because the name is not supported by Roxygen
#' @export
#' @name plotClusterTrajectories
#' @rdname plotClusterTrajectories
#' @aliases plotClusterTrajectories,lcModel-method
#' @title Plot the cluster trajectories of a lcModel
#' @inheritParams clusterTrajectories
#' @param clusterLabels Cluster display names. By default it's the cluster name with its proportion enclosed in parentheses.
#' @param trajAssignments The cluster assignments for the fitted trajectories. Only used when `trajectories = TRUE` and `facet = TRUE`. See [trajectoryAssignments].
#' @param ... Arguments passed to [clusterTrajectories()], or [ggplot2::geom_line()] for plotting the cluster trajectory lines.
#' @return A `ggplot` object.
#' @seealso [clusterTrajectories] [plotFittedTrajectories] [plotTrajectories] [plot]
#' @examples
#' data(latrendData)
#' method <- lcMethodLMKM(Y ~ Time, id = "Id", time = "Time")
#' model <- latrend(method, latrendData, nClusters = 3)
#'
#' if (require("ggplot2")) {
#'   plotClusterTrajectories(model)
#'
#'   # show assigned trajectories
#'   plotClusterTrajectories(model, trajectories = TRUE)
#'
#'   # show 95th percentile observation interval
#'   plotClusterTrajectories(model, trajectories = "95pct")
#'
#'   # show observation standard deviation
#'   plotClusterTrajectories(model, trajectories = "sd")
#'
#'   # show observation standard error
#'   plotClusterTrajectories(model, trajectories = "se")
#'
#'   # show observation range
#'   plotClusterTrajectories(model, trajectories = "range")
#' }
#' @family lcModel functions
setMethod('plotClusterTrajectories', 'lcModel',
  function(object,
    what = 'mu',
    at = time(object),
    clusterLabels = NULL,
    trajectories = FALSE,
    facet = !isFALSE(as.logical(trajectories[1])),
    trajAssignments = trajectoryAssignments(object),
    ...
  ) {
  if (is.null(clusterLabels)) {
    clusterLabels = sprintf(
      '%s (%g%%)',
      clusterNames(object),
      round(clusterProportions(object) * 100)
    )
  }
  assert_that(length(clusterLabels) == nClusters(object))

  clusdata = clusterTrajectories(object, at = at, what = what, ...) %>% as.data.table()
  clusdata[, Cluster := factor(Cluster, levels = levels(Cluster), labels = clusterLabels)]

  rawdata = model.data(object) %>% as.data.table()
  if (nrow(rawdata) > 0 && !is.null(trajAssignments)) {
    assert_that(
      length(trajAssignments) == nIds(object),
      all(trajAssignments %in% clusterNames(object))
    )
    trajAssignments = factor(trajAssignments, levels = clusterNames(object), labels = clusterLabels)
    rawdata[, Cluster := trajAssignments[make.idRowIndices(object)]]
  }

  .plotClusterTrajs(
    clusdata,
    response = responseVariable(object, what = what),
    time = timeVariable(object),
    id = idVariable(object),
    trajectories = trajectories[1],
    facet = facet,
    rawdata = rawdata,
    ...
  )
})


#. plotTrajectories ####
#' @export
#' @rdname plotTrajectories
#' @aliases plotTrajectories,lcModel-method
#' @param ... Additional arguments passed to [trajectories()].
#' @examples
#' data(latrendData)
#' method <- lcMethodLMKM(Y ~ Time, id = "Id", time = "Time")
#' model <- latrend(method, latrendData, nClusters = 3)
#'
#' if (require("ggplot2")) {
#'   plotTrajectories(model)
#' }
setMethod('plotTrajectories', 'lcModel', function(object, ...) {
  data = trajectories(object, ...)
  plotTrajectories(
    data,
    id = idVariable(object),
    time = timeVariable(object),
    response = responseVariable(object),
    ...
  )
})


#. postprob ####
#' @export
#' @name postprob
#' @rdname postprob
#' @aliases postprob,lcModel-method
#' @title Posterior probability per fitted trajectory
#' @description Get the posterior probability matrix with element \eqn{(i,j)} indicating the probability of trajectory \eqn{i} belonging to cluster \eqn{j}.
#' @details This method should be extended by `lcModel` implementations. The default implementation returns uniform probabilities for all trajectories.
#' @param object The `lcModel`.
#' @param ... Additional arguments.
#' @return An I-by-K `matrix` with `I = nIds(object)` and `K = nClusters(object)`.
#' @section Implementation:
#' Classes extending `lcModel` should override this method.
#' \preformatted{
#' setMethod("postprob", "lcModelExt", function(object, ...) {
#'   # return trajectory-specific posterior probability matrix
#' })
#' }
#' @section Troubleshooting:
#' If you are getting errors about undefined model signatures when calling `postprob(model)`,
#' check whether the `postprob()` function is still the one defined by the latrend package.
#' It may have been masked when attaching another package (e.g., lcmm). If you need to attach conflicting packages, load them first.
#' @seealso [trajectoryAssignments] [predictPostprob] [predictAssignments]
#' @family lcModel functions
#' @examples
#' data(latrendData)
#' method <- lcMethodLMKM(Y ~ Time, id = "Id", time = "Time")
#' model <- latrend(method, latrendData)
#'
#' postprob(model)
#'
#' if (rlang::is_installed("lcmm")) {
#'   gmmMethod <- lcMethodLcmmGMM(
#'     fixed = Y ~ Time,
#'     mixture = ~ Time,
#'     id = "Id",
#'     time = "Time",
#'     idiag = TRUE,
#'     nClusters = 2
#'   )
#'   gmmModel <- latrend(gmmMethod, data = latrendData)
#'   postprob(gmmModel)
#' }
setMethod('postprob', 'lcModel', function(object, ...) {
  if (nIds(object) > 0) {
    warning(
      'postprob() not implemented for ',
      class(object)[1],
      '. Returning uniform posterior probability matrix.'
    )

    matrix(1 / nClusters(object), nrow = nIds(object), ncol = nClusters(object))
  }
  else {
    warning(
      'postprob() not implemented for ',
      class(object)[1],
      ' and no associated trajectories for this model. Returning empty matrix.'
    )

    matrix(1 / nClusters(object), nrow = 0, ncol = nClusters(object))
  }
})


# . QQ plot ####
#' @export
#' @name qqPlot
#' @rdname qqPlot
#' @aliases qqPlot,lcModel-method
#' @title Quantile-quantile plot
#' @description Plot the quantile-quantile (Q-Q) plot for the fitted `lcModel` object. This function is based on the \pkg{qqplotr} package.
#' @param object The `lcModel` object.
#' @param byCluster Whether to plot the Q-Q line per cluster.
#' @param ... Additional arguments passed to [qqplotr::geom_qq_band()], [qqplotr::stat_qq_line()], and [qqplotr::stat_qq_point()].
#' @return A `ggplot` object.
#' @seealso [residuals.lcModel] [metric] [plotClusterTrajectories]
#' @examples
#' data(latrendData)
#' method <- lcMethodLMKM(Y ~ Time, id = "Id", time = "Time", nClusters = 3)
#' model <- latrend(method, latrendData)
#'
#' if (require("ggplot2") && require("qqplotr")) {
#'   qqPlot(model)
#' }
#' @family lcModel functions
setMethod('qqPlot', 'lcModel', function(object, byCluster = FALSE, ...) {
  .loadOptionalPackage('ggplot2')
  assert_that(
    is.lcModel(object),
    is.flag(byCluster)
  )

  res = residuals(object, ...)
  mdata = model.data(object)
  assert_that(!is.null(mdata), msg = 'no model data available')

  idIndexColumn = factor(mdata[[idVariable(object)]], levels = ids(object)) %>%
    as.integer()
  rowClusters = trajectoryAssignments(object)[idIndexColumn]

  requireNamespace('qqplotr')
  p = ggplot2::ggplot(
      data = data.frame(Cluster = rowClusters, res = res),
      mapping = ggplot2::aes(sample = res)
    ) +
    qqplotr::geom_qq_band(...) +
    qqplotr::stat_qq_line(...) +
    qqplotr::stat_qq_point(...) +
    ggplot2::labs(
      x = 'Theoretical quantiles',
      y = 'Sample quantiles',
      title = 'Quantile-quantile plot'
    )

  if (isTRUE(byCluster)) {
    p = p + ggplot2::facet_wrap(~ Cluster)
  }

  p
})


#' @export
#' @importFrom stats residuals
#' @title Extract lcModel residuals
#' @description Extract the residuals for a fitted `lcModel` object.
#' By default, residuals are computed under the most likely cluster assignment for each trajectory.
#' @inheritParams fitted.lcModel
#' @return A `numeric vector` of residuals for the cluster assignments specified by `clusters`.
#' If `clusters = NULL`, a `matrix` of cluster-specific residuals per observation is returned.
#' @family lcModel functions
#' @seealso [fitted.lcModel] [trajectories]
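#' @examples
#' # Minimal usage sketch; it reuses the LMKM setup from the other examples
#' # in this file and assumes that fit succeeds on latrendData.
#' data(latrendData)
#' method <- lcMethodLMKM(Y ~ Time, id = "Id", time = "Time")
#' model <- latrend(method, latrendData)
#'
#' # residuals under the most likely cluster assignment of each trajectory
#' res <- residuals(model)
#' summary(res)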
residuals.lcModel = function(object, ..., clusters = trajectoryAssignments(object)) {
  ypred = fitted(object, clusters = clusters, ...)
  yref = model.data(object)[[responseVariable(object)]]

  if (is.matrix(ypred)) {
    assert_that(length(yref) == nrow(ypred))
    resMat = matrix(yref, nrow = nrow(ypred), ncol = ncol(ypred)) - ypred
    colnames(resMat) = colnames(ypred)

    resMat
  } else if (is.numeric(ypred)) {
    assert_that(length(yref) == length(ypred))

    yref - ypred
  } else {
    NULL
  }
}

#. responseVariable ####
#' @export
#' @name responseVariable
#' @rdname responseVariable
#' @aliases responseVariable,lcModel-method
#' @title Get the response variable
#' @examples
#' data(latrendData)
#' method <- lcMethodRandom("Y", id = "Id", time = "Time")
#' model <- latrend(method, latrendData)
#' responseVariable(model) # "Y"
#' @family lcModel variables
setMethod('responseVariable', 'lcModel', function(object, ...) object@response)

# . estimationTime ####
#' @export
#' @name estimationTime
#' @rdname estimationTime
#' @aliases estimationTime,lcModel-method
#' @title Get the model estimation time
#' @description Get the estimation time of the model, determined by the time taken for the associated [fit()] function to finish.
#' @param object The `lcModel` object.
#' @param unit The time unit in which to return the estimation time.
#' By default, estimation time is in seconds.
#' For accepted units, see [base::difftime].
#' @param ... Additional arguments.
#' @return A `numeric` representing the model estimation time, in the specified unit.
#' @examples
#' data(latrendData)
#' method <- lcMethodLMKM(Y ~ Time, id = "Id", time = "Time")
#' model <- latrend(method, latrendData)
#'
#' estimationTime(model)
#' estimationTime(model, unit = 'mins')
#' estimationTime(model, unit = 'days')
#' @family lcModel functions
setMethod('estimationTime', 'lcModel', function(object, unit, ...) {
  assert_that(is.lcModel(object))
  dtime = as.difftime(object@estimationTime, units = 'secs')
  as.numeric(dtime, units = unit)
})


# . show ####
setMethod('show', 'lcModel', function(object) {
  show(summary(object))
})


#' @export
#' @importFrom stats sigma
#' @title Extract residual standard deviation from a lcModel
#' @description Extracts or estimates the residual standard deviation. If [sigma()] is not defined for a model, it is estimated from the residual error vector.
#' @param object The `lcModel` object.
#' @param ... Additional arguments.
#' @return A `numeric` indicating the residual standard deviation.
#' @seealso [coef.lcModel] [metric]
#' @family lcModel functions
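#' @examples
#' # Minimal usage sketch; it reuses the LMKM setup from the other examples
#' # in this file and assumes that fit succeeds on latrendData.
#' data(latrendData)
#' method <- lcMethodLMKM(Y ~ Time, id = "Id", time = "Time")
#' model <- latrend(method, latrendData)
#'
#' # residual standard deviation of the fitted model
#' sigma(model)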
sigma.lcModel = function(object, ...) {
  if (is.null(object@model) ||
      is.null(getS3method('sigma', class = class(object@model)[1], optional = TRUE))) {
    sd(residuals(object, ...))
  } else {
    # nocov start
    sigma(object@model, ...)
    # nocov end
  }
}


#. strip ####
#' @export
#' @name strip
#' @rdname strip
#' @aliases strip,lcModel-method
#' @title Reduce the lcModel memory footprint for serialization
#' @description Strip a lcModel of non-essential variables and environments in order to reduce the model size for serialization.
#' @param object The `lcModel` object.
#' @param classes The object classes for which to remove their assigned environment. By default, only environments from `formula` are removed.
#' @param ... Additional arguments.
#' @return An `lcModel` object of the same type as the `object` argument.
#' @section Implementation:
#' Classes extending `lcModel` can override this method to remove additional non-essentials.
#' \preformatted{
#' setMethod("strip", "lcModelExt", function(object, ..., classes = "formula") {
#'   object <- callNextMethod()
#'   # further process the object
#'   return(object)
#' })
#' }
#' @examples
#' data(latrendData)
#' method <- lcMethodLMKM(Y ~ Time, id = "Id", time = "Time")
#' model <- latrend(method, latrendData)
#' newModel <- strip(model)
#' @family lcModel functions
setMethod('strip', 'lcModel', function(object, ..., classes = 'formula') {
  newObject = object

  environment(newObject) = NULL
  newObject@method = strip(object@method, ..., classes = classes)
  newObject@call = strip(object@call, ..., classes = classes)

  newObject
})


#. timeVariable ####
#' @export
#' @name timeVariable
#' @rdname timeVariable
#' @aliases timeVariable,lcModel-method
#' @examples
#' data(latrendData)
#' method <- lcMethodRandom("Y", id = "Id", time = "Time")
#' model <- latrend(method, latrendData)
#' timeVariable(model) # "Time"
#' @family lcModel variables
setMethod('timeVariable', 'lcModel', function(object) object@time)


#' @export
#' @importFrom stats time
#' @title Sampling times of a lcModel
#' @description Extract the sampling times on which the `lcModel` was fitted.
#' @param x The `lcModel` object.
#' @param ... Not used.
#' @return A `numeric vector` of the unique times at which observations occur, in increasing order.
#' @family lcModel functions
#' @seealso [timeVariable] [model.data]
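#' @examples
#' # Minimal usage sketch; it reuses the LMKM setup from the other examples
#' # in this file and assumes that fit succeeds on latrendData.
#' data(latrendData)
#' method <- lcMethodLMKM(Y ~ Time, id = "Id", time = "Time")
#' model <- latrend(method, latrendData)
#'
#' # sampling times on which the model was fitted
#' times <- time(model)
#' range(times)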
time.lcModel = function(x, ...) {
  if (length(x@times) == 0) {
    assert_that(
      has_name(model.data(x), timeVariable(x)),
      msg = sprintf(
        'cannot identify model times: model.data() is missing time variable column "%s"',
        timeVariable(x)
    ))
    times = model.data(x)[[timeVariable(x)]] %>% unique() %>% sort()
  } else {
    times = x@times
  }

  assert_that(
    length(times) > 0,
    is.vector(times),
    is.numeric(times)
  )

  times
}


# . trajectories ####
#' @rdname trajectories
#' @aliases trajectories,lcModel-method
setMethod('trajectories', 'lcModel', function(object, ...) {
  data = model.data(object)

  assert_that(!is.null(data), msg = 'no associated training data for this model')

  id = idVariable(object)
  time = timeVariable(object)
  res = responseVariable(object)

  trajdata = subset(data, select = c(id, time, res))

  trajectories(trajdata, id = id, time = time, response = res, ...)
})


#. trajectoryAssignments ####
#' @export
#' @name trajectoryAssignments
#' @aliases trajectoryAssignments,lcModel-method
#' @title Get the cluster membership of each trajectory
#' @description Classify the fitted trajectories based on the posterior probabilities computed by [postprob()], according to a given classification strategy.
#'
#' By default, trajectories are assigned based on the highest posterior probability using [which.max()].
#' In cases where identical probabilities are expected between clusters, it is preferable to use \link[nnet]{which.is.max} instead, as this function breaks ties at random.
#' Another strategy to consider is the function [which.weight()], which enables weighted sampling of cluster assignments based on the trajectory-specific probabilities.
#' @param object The object to obtain the cluster assignments from.
#' @param strategy A function returning the cluster index based on the given vector of membership probabilities. By default, ids are assigned to the cluster with the highest probability.
#' @param ... Any additional arguments passed to the strategy function.
#' @return A `factor` indicating the cluster membership for each trajectory.
#' @seealso [postprob] [clusterSizes] [predictAssignments]
#' @examples
#' data(latrendData)
#' method <- lcMethodLMKM(Y ~ Time, id = "Id", time = "Time")
#' model <- latrend(method, latrendData)
#' trajectoryAssignments(model)
#'
#' # assign trajectories at random using weighted sampling
#' trajectoryAssignments(model, strategy = which.weight)
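#'
#' # A minimal sketch, assuming the 'nnet' package is installed:
#' # break ties between equal posterior probabilities at random
#' if (requireNamespace("nnet", quietly = TRUE)) {
#'   trajectoryAssignments(model, strategy = nnet::which.is.max)
#' }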
#' @family lcModel functions
setMethod('trajectoryAssignments', 'lcModel', function(object, strategy = which.max, ...) {
  if (suppressWarnings(nIds(object)) == 0) {
    # seq_len() yields an empty level set for zero clusters, whereas 1:0 would yield c(1, 0)
    return(factor(levels = seq_len(nClusters(object)), labels = clusterNames(object)))
  }

  pp = postprob(object, ...)

  result = apply(pp, 1, strategy, ...)

  assert_that(
    is.numeric(result),
    length(result) == nIds(object),
    all(
      vapply(result, is.count, FUN.VALUE = TRUE) |
        vapply(result, is.na, FUN.VALUE = TRUE)
    ),
    min(result, na.rm = TRUE) >= 1,
    max(result, na.rm = TRUE) <= nClusters(object)
  )

  assignments = factor(result, levels = 1:nClusters(object), labels = clusterNames(object))

  make.trajectoryAssignments(object, assignments)
})


#' @export
#' @importFrom stats update
#' @title Update a lcModel
#' @description Refit the model by re-evaluating the call that created the current model, with the given arguments added or replaced.
#' @param object The `lcModel` object.
#' @param ... Any new method arguments to refit the model with.
#' @return The refitted `lcModel` object, of the same type as the `object` argument.
#' @inheritDotParams latrend
#' @seealso [latrend] [getCall]
#' @examples
#' data(latrendData)
#' method <- lcMethodLMKM(Y ~ Time, id = "Id", time = "Time")
#' model2 <- latrend(method, latrendData, nClusters = 2)
#'
#' # fit for a different number of clusters
#' model3 <- update(model2, nClusters = 3)
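#'
#' # A brief sketch: inspect the call that was re-evaluated for the refit,
#' # and the resulting number of clusters
#' getCall(model3)
#' nClusters(model3)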
update.lcModel = function(object, ...) {
  assert_that(is.lcModel(object))
  modelCall = getCall(object)

  assert_that(
    as.character(modelCall[[1]]) != '<undef>',
    msg = 'cannot update lcModel because lcMethod call is undefined'
  )

  updateCall = match.call() %>% tail(-2)
  updateNames = names(updateCall)

  clCall = replace(modelCall, updateNames, updateCall[updateNames])

  eval(clCall, envir = parent.frame())
}

latrend documentation built on March 31, 2023, 5:45 p.m.