R/modelApprox.R

Defines functions fitted.lcApproxModel

Documented in fitted.lcApproxModel

#' @include model.R
#' @export
#' @name lcApproxModel-class
#' @rdname lcApproxModel-class
#' @aliases lcApproxModel
#' @title lcApproxModel class
#' @description approx models have defined cluster trajectories at fixed moments in time, which should be interpolated
#' For a correct implementation, `lcApproxModel` requires the extending class to implement `clusterTrajectories(at=NULL)`
#' to return the fixed cluster trajectories
setClass('lcApproxModel', contains = 'lcModel')

#' @export
#' @rdname lcApproxModel-class
#' @inheritParams fitted.lcModel
fitted.lcApproxModel = function(object, ..., clusters = trajectoryAssignments(object)) {
  columns = c(idVariable(object), timeVariable(object), responseVariable(object))
  newdata = subset(model.data(object), select = columns)
  pred = predict(object, newdata = newdata, useCluster = FALSE)
  transformFitted(pred, model = object, clusters = clusters)
}

#. predictForCluster ####
#' @rdname lcApproxModel-class
#' @inheritParams predictForCluster
#' @param approxFun Function to interpolate between measurement moments, \link[stats]{approx}() by default.
setMethod('predictForCluster', 'lcApproxModel',
  function(object, newdata, cluster, what = 'mu', approxFun = approx, ...) {
  assert_that(is.function(approxFun))

  clusTrajs = clusterTrajectories(object, at = numeric(), what = what, approxFun = approxFun, ...) %>%
    as.data.table() %>%
    .[Cluster == cluster]

  time = timeVariable(object)
  resp = responseVariable(object)
  clusTimes = clusTrajs[[time]]
  newtimes = newdata[[time]]

  # check if we need to do any interpolation
  if (all(newtimes %in% clusTimes)) {
    pred = clusTrajs[match(newtimes, .time), .resp, env = list(.time = time, .resp = resp)]
    return(pred)
  }

  if (sum(is.finite(clusTrajs[[resp]])) < 2) {
    warning(
      sprintf(
        'Cannot interpolate cluster trajectory of cluster %s: need at least two non-NA cluster trajectory values',
        cluster
      )
    )
    return(rep(NaN, length(newtimes)))
  }

  dtpred = clusTrajs[,
    lapply(.SD, function(y) approxFun(x = .time, y = y, xout = newtimes)$y),
    keyby = Cluster, .SDcols = -c(time),
    env = list(.time = time)
  ]

  dtpred[[resp]]
})
philips-software/latrend documentation built on Jan. 9, 2025, 4:33 p.m.