R/gnlmm2.R

Defines functions mat.indices gnlmm2 getModelVars llik_binomial

Documented in gnlmm2

## gnlmm.R: population PK/PD modeling library
##
## Copyright (C) 2014 - 2016  Wenping Wang
##
## This file is part of nlmixr.
##
## nlmixr is free software: you can redistribute it and/or modify it
## under the terms of the GNU General Public License as published by
## the Free Software Foundation, either version 2 of the License, or
## (at your option) any later version.
##
## nlmixr is distributed in the hope that it will be useful, but
## WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
## GNU General Public License for more details.
##
## You should have received a copy of the GNU General Public License
## along with nlmixr.  If not, see <http://www.gnu.org/licenses/>.

## require(lbfgs)
## require(lbfgsb3)
## require(madness)
## require(Rcpp)
# require(nlmixr)
## require(parallel)
## require(minqa)
## require(Deriv)

llik_binomial <- function(y, n, params) {
  # Binomial log-likelihood computed by the compiled helper
  # llik_binomial_c(); its gradient element `J` is expanded into a diagonal
  # Jacobian matrix.
  #
  # NOTE(review): `J` is assumed to hold one derivative per observation when
  # it is a plain vector -- TODO confirm against llik_binomial_c().
  #
  # Returns the list from llik_binomial_c() with `J` replaced by that matrix.
  r <- llik_binomial_c(y, n, params)
  jac <- r$J
  # diag() on a length-one vector builds an identity matrix of that order
  # (or a 0 x 0 matrix when the value is < 1) rather than a 1 x 1 matrix,
  # so pass nrow explicitly. A matrix input keeps the original diag()
  # behaviour, since diag() rejects `nrow` for matrices.
  r$J <- if (is.matrix(jac)) diag(jac) else diag(jac, nrow = length(jac))
  r
}


#-- for new gnlmm
getModelVars <- function(blik, bpar, m1) {
  # Summarise the likelihood body `blik` for the gradient code in gnlmm2().
  #
  # `blik` is the body of the user's llik function (normally a `{` call).
  # Its last statement must be a supported density call such as
  # dnorm(DV, mu, sd); earlier statements compute that call's arguments.
  # `m1` is the RxODE model and is used only for its state names.
  # `bpar` is not used here and is kept for interface compatibility.
  #
  # Returns a list with
  #   state.llik    ODE states referenced by the likelihood,
  #   pars.llik     `rhs` symbols reported by nlmixrfindRhsLhs(blik),
  #   vars.par      `lhs` symbols reported by nlmixrfindRhsLhs(blik),
  #   dist          name of the density function,
  #   dist.df       text of the 2nd density argument for dbinom (size) and
  #                 dt (df), otherwise NULL,
  #   blik.new      parsed expression: the preceding statements followed by
  #                 `.arg1=...`, `.arg2=...` bindings of the density args,
  #   blik.new.text the same expression as text,
  #   args.dvdx     the `.argN` names that need derivatives.
  #
  # Errors when the body is empty, does not end in a call, uses an
  # unsupported density, or passes too few arguments to it.

  # ix   = positions (in the density call) of the args bound to .arg1, ...
  # dvdx = which of those .argN need derivatives
  argsList <- list(
    dpois = list(ix = 2, dvdx = ".arg1"),
    dbinom = list(ix = 2:3, dvdx = ".arg2"),
    dnorm = list(ix = 2:3, dvdx = c(".arg1", ".arg2")),
    dbeta = list(ix = 2:3, dvdx = c(".arg1", ".arg2")),
    dneg_binomial = list(ix = 2:3, dvdx = c(".arg1", ".arg2")),
    dbetabinomial = list(ix = 2:4, dvdx = c(".arg2", ".arg3")),
    dt = list(ix = 2:4, dvdx = c(".arg1", ".arg2", ".arg3"))
  )

  # Split the body into statements; a bare expression is a one-statement body.
  stmts <- if (is.call(blik) && identical(blik[[1L]], as.name("{"))) {
    as.list(blik)[-1L]
  } else {
    list(blik)
  }
  if (length(stmts) == 0L) {
    stop("llik body is empty; it must end with a density call")
  }
  dens <- stmts[[length(stmts)]]
  if (!is.call(dens)) {
    stop("the last statement of llik must be a density call, e.g. dnorm(DV, mu, sd)")
  }

  # Work on the call object instead of its deparsed text, so the closing
  # parenthesis, nested calls, commas inside arguments and long (wrapped)
  # lines are all handled correctly.
  dist <- paste(deparse(dens[[1L]]), collapse = "")
  args <- argsList[[dist]]
  if (is.null(args)) {
    stop("distribution not supported in llik: ", dist)
  }
  s <- unname(vapply(
    as.list(dens)[-1L],
    function(a) paste(deparse(a), collapse = " "),
    character(1)
  ))

  args.ix <- args$ix # position of dist pars
  args.dvdx <- args$dvdx # args need dvdx
  narg <- length(args.ix)
  if (length(s) < max(args.ix)) {
    stop(sprintf("%s() in llik needs at least %d arguments", dist, max(args.ix)))
  }

  # Preceding statements (possibly none) followed by .argN bindings.
  prefix <- unlist(lapply(stmts[-length(stmts)], deparse))
  blik.new.text <- paste(c(
    prefix,
    paste0(".arg", seq_len(narg), "=", s[args.ix])
  ), collapse = "\n")
  blik.new <- parse(text = blik.new.text)

  # binomial size / t df
  dist.df <- if (dist %in% c("dbinom", "dt")) s[2L] else NULL

  #----------------------------
  .lhsrhs <- nlmixrfindRhsLhs(blik)

  states <- m1$get.modelVars()$state

  state.llik <- intersect(states, .lhsrhs$rhs) # state used in llik

  list(
    state.llik = state.llik,
    pars.llik = .lhsrhs$rhs,
    vars.par = .lhsrhs$lhs, # vars defined in pars
    dist = dist, dist.df = dist.df,
    blik.new = blik.new, blik.new.text = blik.new.text,
    args.dvdx = args.dvdx
  )
}

##' @rdname gnlmm
##' @export
gnlmm2 <- function(llik, data, inits, syspar = NULL,
                   system = NULL, diag.xform = c("sqrt", "log", "identity"),
                   ..., control = list()) {
  # Generalized nonlinear mixed-effects fit of an RxODE model compiled with
  # forward sensitivities (calcSens = TRUE). Per subject, the random effects
  # (ETA) are optimized with analytic gradients (inner problem), and the
  # marginal likelihood is approximated by adaptive Gaussian quadrature (AQD)
  # around that mode. The outer optimizer minimizes the summed -2 log-lik.
  #
  #   llik       function whose body ends in a supported density call
  #              (see getModelVars()).
  #   data       data frame with ID (required), TIME, AMT, DV and optional
  #              EVID (defaults to 0).
  #   inits      list with THTA (fixed effects), OMGA (parsed by parseOM())
  #              and SGMA.
  #   syspar     function whose body computes the model parameters from
  #              THETA / ETA; required.
  #   system     RxODE model object or model string.
  #   diag.xform transform applied to the diagonal of the OMEGA
  #              inverse-Cholesky factor for the optimizer.
  #   control    options; see `con` below for names and defaults.
  #
  # Returns a list of class "gnlmm.fit": the outer optimizer result plus
  # `obj`, `ETA` (last inner-mode estimates per subject), `con`,
  # `diag.xform`, `nsplt`, `osplt`, `calls` and `par.unscaled`.
  #
  # NOTE: the user's syspar/llik bodies are eval()'d in frames nested inside
  # this function, so they can see its locals; avoid renaming locals here.
  RxODE::rxReq("Deriv")
  ## data
  if (is.null(data$ID)) stop('"ID" not found in data')
  if (is.null(data$EVID)) data$EVID <- 0
  data.obs <- data[data$EVID == 0, ]
  data.sav <- data
  names(data) <- tolower(names(data)) # needed in ev

  # syspar is required: its body is evaluated and symbolically
  # differentiated below. Without it the fit used to fail much later with
  # "object 'bpar' not found"; fail early with a clear message instead.
  if (is.null(syspar)) {
    stop("syspar must be supplied: a function whose body defines the model parameters")
  }

  # model
  if (is.null(system)) {
    stop("system must be an RxODE model or string")
  } else if (inherits(system, "RxODE")) {
    system <- RxODE(system, calcSens = TRUE)
  } else if (inherits(system, "character")) {
    system <- RxODE(model = system, calcSens = TRUE)
  } else {
    stop("invalid system input")
  }

  # options
  con <- list(
    trace = 0,
    maxit = 100L,
    atol.ode = 1e-08,
    rtol.ode = 1e-08,
    reltol.inner = 1.0e-4,
    reltol.outer = 1.0e-3,
    optim.inner = "lbfgs",
    optim.outer = "newuoa",
    start.zero.inner = FALSE,
    mc.cores = 1,
    nAQD = 1,
    transit_abs = FALSE,
    cov = FALSE,
    eps = c(1e-8, 1e-3), # finite difference step
    NOTRUN = FALSE,
    DEBUG.INNER = FALSE,
    rhobeg = .2,
    rhoend = 1e-3,
    iprint = 2,
    npt = NULL,
    do.optimHess = TRUE
  )
  nmsC <- names(con)
  con[(namc <- names(control))] <- control
  if (length(noNms <- namc[!namc %in% nmsC])) {
    warning("unknown names in control: ", paste(noNms, collapse = ", "))
  }

  # `square` is looked up by name via eval(call(diag.xform.inv, ...)).
  square <- function(x) x * x
  diag.xform <- match.arg(diag.xform)
  diag.xform.inv <- c("sqrt" = "square", "log" = "exp", "identity" = "identity")[diag.xform]

  # process inits
  lh <- parseOM(inits$OMGA)
  nlh <- sapply(lh, length)
  osplt <- rep(seq_along(lh), nlh)

  lini <- list(inits$THTA, unlist(lh))
  nlini <- sapply(lini, length)
  nsplt <- rep(seq_along(lini), nlini)

  om0 <- genOM(lh)
  # Per OMEGA block: upper triangle of the inverse Cholesky factor with the
  # diagonal transformed by diag.xform.
  th0.om <- lapply(seq_along(lh), function(k) {
    m <- genOM(lh[k])
    nr <- nrow(m)
    mi <- tryCatch(
      backsolve(chol(m), diag(nr)),
      error = function(e) {
        stop("OMEGA not positive-definite")
      }
    )
    diag(mi) <- eval(call(diag.xform, diag(mi)))
    mi[col(mi) >= row(mi)]
  })
  inits.vec <- c(inits$THTA, unlist(th0.om), inits$SGMA)
  names(inits.vec) <- NULL

  nTHTA <- nlini[1]
  nETA <- nrow(om0)
  ID.all <- unique(data[, "id"])
  ID.ord <- order(ID.all)
  names(ID.ord) <- ID.all
  nSUB <- length(ID.all)

  # gaussian quadrature nodes & wts
  nAQD <- con$nAQD
  nw <- gauss.quad(nAQD)
  # every combination of node indices across the nETA dimensions
  mij <- as.matrix(
    do.call("expand.grid", lapply(1:nETA, function(x) 1:nAQD))
  )
  nij <- nrow(mij)


  # obj fn by AQD
  bpar <- body(syspar)
  blik <- body(llik)
  modVars <- getModelVars(blik, bpar, system)


  # === start of dvdx code
  # Symbolic derivatives of the syspar body w.r.t. each ETA component.
  proc.deriv <- function() {
    # For each (parameter, ETA[j]) pair, differentiate the syspar body with
    # Deriv::Deriv. ETA[j] is rewritten to ETA<j> so Deriv sees a plain
    # symbol, initCondition lines are commented out and the last deparsed
    # line (closing brace) is dropped before appending the target. Results
    # are converted back to ETA[j]; "0" marks a non-call (constant) result.
    getDeriv <- function(pars) {
      npar <- length(pars)
      s <- deparse(bpar)
      for (i in 1:nETA) {
        s <- gsub(sprintf("\\bETA\\[%d\\]", i), sprintf("ETA%d", i), s, perl = TRUE)
      }
      s <- gsub("initCondition", "#initCondition", s)
      len <- length(s)
      s <- s[-len]

      ix <- matrix(unlist(expand.grid(1:npar, 1:nETA)), ncol = 2)
      m <- sapply(1:nrow(ix), function(k) {
        i <- ix[k, 1]
        j <- ix[k, 2]
        s1 <- c("Deriv::Deriv(~", s, pars[i], sprintf("}, \"ETA%d\")", j))
        a <- paste(s1, collapse = "\n")
        e <- eval(parse(text = a))
        s <- if (inherits(e, "call")) {
          s <- deparse(e)
          gsub(sprintf("\\bETA%d\\b", j), sprintf("ETA[%d]", j), s, perl = TRUE)
        } else {
          "0"
        }
        # parse(text=s)
        s
      })
      m
    }

    # Evaluate syspar once (first subject, ETA = 0 as a madness object);
    # variables that come out as madness objects depend on ETA.
    env <- environment()
    dati <- data.sav[data.sav$ID == ID.all[1], ]
    list2env(dati, env)
    THETA <- tapply(inits.vec, nsplt, identity)[[1]]
    ETA <- madness::madness(array(0, c(nETA, 1))) ## Is madness still needed...?
    eval(bpar)

    px <- as.list(env)
    madIx <- sapply(px, function(x) {
      if (inherits(x, "madness")) TRUE else FALSE
    })
    madVars <- names(px)[madIx]

    pars <- system$get.modelVars()$params
    px <- as.list(env)[pars]
    madIx <- sapply(px, function(x) {
      if (inherits(x, "madness")) TRUE else FALSE
    })

    pars <- pars[madIx]
    matode <- getDeriv(pars)
    pars <- setdiff(madVars, c(pars, "ETA"))
    matllk <- getDeriv(pars)

    # matode: deriv exprs for ODE pars; madIx: which ODE pars need derivs;
    # matllk: deriv exprs for llik pars; madllk: llik pars needing derivs
    list(matode = matode, madIx = madIx, matllk = matllk, madllk = pars)
  }
  s <- proc.deriv()

  # FIXME: chk par order of below 2 para against those in d(llik)/d(ETA) by chain rule
  # d(pars)/d(ETA) for ode
  madIx <- s$madIx
  npar <- sum(madIx)
  m <- s$matode
  idx.dpde.ode <- m != "0"
  expr.dpde.ode <- lapply(m[idx.dpde.ode], function(k) parse(text = k))
  dpde.ode <- matrix(0, npar, nETA)

  # d(pars)/d(ETA) for llik
  madVars.llk <- s$madllk
  npar <- length(madVars.llk)
  m <- s$matllk
  idx.dpde.llk <- m != "0"
  expr.dpde.llk <- lapply(m[idx.dpde.llk], function(k) parse(text = k))
  dpde.llk <- matrix(0, npar, nETA)
  dimnames(dpde.llk) <- list(madVars.llk, NULL)

  # d(args)/d(pars) for llik
  pars <- c(modVars$state.llik, madVars.llk)
  npar <- length(pars)
  lexpr <- unlist(lapply(modVars$args.dvdx, function(arg) {
    s <- sprintf("{\n%s\n%s", modVars$blik.new.text, arg) # FIXME

    expr.dldp <- lapply(1:npar, function(k) {
      s1 <- c("Deriv::Deriv(~", s, sprintf("}, \"%s\")", pars[k]))
      a <- paste(s1, collapse = "\n")
      e <- eval(parse(text = a))
      e
    })
    names(expr.dldp) <- pars
    expr.dldp
  }))

  llik.narg <- length(modVars$args.dvdx) # llik.narg = # of args in density that need dvdx
  # may have efficiency gain to rm args that do not need dvdx
  llik.npar <- npar
  m <- matrix(lexpr, llik.npar, llik.narg)
  dadp.expr <- c(t(m)) # chg the order of args & pars in llik; see dldp in ..fg()
  # === ends of dvdx code

  # Per-subject warm starts for the inner ETA optimization; obj.vec()
  # overwrites them after each evaluation while update_starts is TRUE.
  starts <- matrix(0., nSUB, nETA)
  .startsEnv <- environment()
  omga_save <- NULL
  update_starts <- TRUE


  # algo starts
  # Per-subject -2 log marginal likelihood. `th` is scaled: the actual
  # parameter vector is th * inits.vec.
  obj.vec <- function(th, noomga = FALSE) {
    th <- th * inits.vec
    # if (con$DEBUG.INNER) print(th)

    if (noomga) th <- c(th, omga_save)

    lth <- tapply(th, nsplt, identity)
    THETA <- lth[[1]]
    lD <- tapply(lth[[2]], osplt, identity)
    Dinv.5 <- genOMinv.5(lD)
    diag(Dinv.5) <- eval(call(diag.xform.inv, diag(Dinv.5)))
    detDinv.5 <- prod(diag(Dinv.5))

    ep <- environment()

    # Returns c(-2 * log-lik, row index into `starts`, ETA mode) for one ID.
    llik2.subj <- function(ix) {
      # ix=4; ETA=rep(0, nETA)
      dati <- data.sav[data.sav$ID == ix, ]
      evi <- dati[, c("TIME", "EVID", "AMT")]
      names(evi) <- tolower(names(evi))
      ev <- RxODE::eventTable()
      ev$import.EventTable(evi)
      dati <- dati[dati$EVID == 0, ]


      # ETA => pars & stateVar => .args => llik
      # d(llik)/d(ETA) = d(llik)/d(args) * d(args)/d(ETA)
      # d(args)/d(ETA) = d(args)/d(pars) * d(pars)/d(ETA)
      # stateVar is parallel to pars when forming args, however, stateVar = f(pars(ETA))
      # hence, we need d(State)/d(ETA). d(State)/d(ETA) = d(State)/d(pars) * d(pars)/d(ETA)
      # Returns the joint log density at ETA with its gradient in attr "dvdx".
      ..g.fn <- function(ETA) {
        env <- environment()
        list2env(dati, env)
        eval(bpar)

        pars <- system$get.modelVars()$params
        po <- unlist(as.list(env)[pars])
        x <- rxSolve(system, po, ev, returnType = "matrix")
        if (any(is.na(x))) {
          print(ID[1])
          print(po)
          print(head(x, 10))
          stop("NA in ODE solution")
        }

        # d(State)/d(ETA)
        whState <- modVars$state.llik
        senState <- paste0("rx__sens_", whState, "_BY_", pars[madIx], "__")
        fxJ <- list(fx = x[, whState], J = x[, senState]) # FIXME, t()

        dvdx <- sapply(expr.dpde.ode, eval, envir = env) # FIXME
        dpde.ode[idx.dpde.ode] <- dvdx
        # d(State)/d(ETA) = d(State)/d(pars) * d(pars)/d(ETA)
        dvdxState <- fxJ$J %*% dpde.ode # FIXME

        # make state var & other symbols available for calc
        valState <- fxJ$fx
        assign(whState, valState, envir = env)
        eval(modVars$blik.new) # FIXME: here or at llik_binomial?

        # d(pars)/d(ETA)
        if (length(expr.dpde.llk) > 0) {
          dvdx <- sapply(expr.dpde.llk, eval, envir = env) # FIXME
          dpde.llk[idx.dpde.llk] <- dvdx
        } else {
          dpde.llk <- NULL
        }

        # d(args)/d(pars)
        ni <- dim(x)[1]
        dadp <- sapply(1:length(dadp.expr), function(k) { # why lapply(m, eval) doesn't work?
          s <- eval(dadp.expr[[k]])
          if (length(s) == 1) rep(s, ni) else s
        })
        dim(dadp) <- c(ni, llik.narg, llik.npar) # FIXME

        # d(args)/d(ETA) = d(args)/d(pars) * d(pars)/d(ETA)
        dade <- sapply(1:ni, function(k) { # FIXME: vectorize?
          dpde <- rbind(dvdxState[k, ], dpde.llk)
          dadp[k, , ] %*% dpde # FIXME: need t()?
        })
        # dade = t(s)										#FIXME: t() can be rm'ed?


        dim(dade) <- c(llik.narg, nETA, ni)


        # d(llik)/d(ETA) = d(llik)/d(args) * d(args)/d(ETA)
        if (modVars$dist == "dt") {
          if (length(.arg1) == 1 && ni > 1) .arg1 <- rep(.arg1, ni)
          if (length(.arg2) == 1 && ni > 1) .arg2 <- rep(.arg2, ni)
          if (length(.arg3) == 1 && ni > 1) .arg3 <- rep(.arg3, ni)
          s <- sapply(1:ni, function(k) { # FIXME: vectorize?
            unlist(llik_student_t(DV[k], c(.arg1[k], .arg2[k], .arg3[k])))
          })
          s <- list(fx = s[1, ], J = t(s[-1, ]))
        } else if (modVars$dist == "dbetabinomial") {
          if (length(.arg1) == 1 && ni > 1) .arg1 <- rep(.arg1, ni)
          if (length(.arg2) == 1 && ni > 1) .arg2 <- rep(.arg2, ni)
          if (length(.arg3) == 1 && ni > 1) .arg3 <- rep(.arg3, ni)
          s <- sapply(1:ni, function(k) { # FIXME: vectorize?
            unlist(llik_betabinomial(DV[k], .arg1[k], c(.arg2[k], .arg3[k])))
          })
          s <- list(fx = s[1, ], J = t(s[-1, ]))
        } else if (modVars$dist == "dneg_binomial") {
          if (length(.arg1) == 1 && ni > 1) .arg1 <- rep(.arg1, ni)
          if (length(.arg2) == 1 && ni > 1) .arg2 <- rep(.arg2, ni)
          s <- sapply(1:ni, function(k) { # FIXME: vectorize?
            unlist(llik_neg_binomial(DV[k], c(.arg1[k], .arg2[k])))
          })
          s <- list(fx = s[1, ], J = t(s[-1, ]))
        } else if (modVars$dist == "dbeta") {
          if (length(.arg1) == 1 && ni > 1) .arg1 <- rep(.arg1, ni)
          if (length(.arg2) == 1 && ni > 1) .arg2 <- rep(.arg2, ni)
          s <- sapply(1:ni, function(k) { # FIXME: vectorize?
            unlist(llik_beta(DV[k], c(.arg1[k], .arg2[k])))
          })
          s <- list(fx = s[1, ], J = t(s[-1, ]))
        } else if (modVars$dist == "dnorm") {
          if (length(.arg1) == 1 && ni > 1) .arg1 <- rep(.arg1, ni)
          if (length(.arg2) == 1 && ni > 1) .arg2 <- rep(.arg2, ni)
          s <- sapply(1:ni, function(k) { # FIXME: vectorize?
            if (.arg2[k] < 0.0001) {
              # print("HAY!");print(.arg1[k]);print(.arg2[k])
            }
            unlist(llik_normal(DV[k], c(.arg1[k], .arg2[k])))
          })
          s <- list(fx = s[1, ], J = t(s[-1, ]))
        } else if (modVars$dist == "dbinom") {
          if (length(.arg1) == 1 && ni > 1) .arg1 <- rep(.arg1, ni)
          # NOTE(review): llik_binomial() turns J into a diagonal matrix; if
          # J has one entry per observation this has ni^2 elements and the
          # dim(s$J) <- c(ni, llik.narg) below looks like it would fail for
          # ni > 1 -- TODO confirm.
          s <- llik_binomial(DV, .arg1, c(.arg2))
        } else if (modVars$dist == "dpois") {
          s <- llik_poisson(DV, c(.arg1))
        } else {
          stop("dist not supported")
        }

        dim(s$J) <- c(ni, llik.narg)
        dlde <- sapply(1:ni, function(k) { # FIXME: vectorize?
          s$J[k, ] %*% dade[, , k] # FIXME: need t()?
        })
        s <- rbind(s$fx, dlde) # FIXME t()?
        s <- apply(s, 1, sum) # FIXME sum index; chg'ed w/ vec stan call

        # llik.dat = madness::madness(val=matrix(s[1], 1, 1), dvdx=matrix(s[-1], 1, nETA))
        # llik.eta = -crossprod(Dinv.5 %*% ETA)/2 -nETA/2*log(2*pi)+log(detDinv.5)
        # llik.eta = madness::madness(val=val(llik.eta), dvdx=dvdx(llik.eta))
        llik.eta.val <- -crossprod(Dinv.5 %*% ETA) / 2 - nETA / 2 * log(2 * pi) + log(detDinv.5)
        llik.eta.dvd <- -t(ETA) %*% crossprod(Dinv.5)

        r <- s[1] + c(llik.eta.val)
        attr(r, "dvdx") <- s[-1] + c(llik.eta.dvd)
        r
      }
      # fg() caches the last (par, value) so f() and g() at the same ETA
      # share one ..g.fn() evaluation.
      pvd <- NULL
      .pvdEnv <- environment()
      fg <- function(par) {
        if (identical(par, pvd[[1]])) {
          return(pvd)
        }
        ym <- ..g.fn(par)
        assign("pvd", list(par, ym), .pvdEnv)
      }
      f <- function(ETA) -as.vector(fg(ETA)[[2]])
      g <- function(ETA) -as.vector(attr(fg(ETA)[[2]], "dvdx"))


      .wh <- ID.ord[as.character(ix)]
      ETA.val <- starts[.wh, ]
      ..fit.inner <- nlminb(ETA.val, f, g, control = list(trace = FALSE, rel.tol = 1e-4))
      # ..fit.inner = lbfgs(f, g, ETA.val, invisible=TRUE, ftol=1e-4) # epsilon=1e-3)
      if (con$do.optimHess) {
        ..fit.inner$hessian <- optimHess(..fit.inner$par, f, g)
      }
      if (con$DEBUG.INNER) {
        # print(..fit.inner$message)
      }

      # =========================================================
      # Ginv.5: inverse Cholesky factor of the inner Hessian, used to scale
      # the quadrature nodes.
      if (con$do.optimHess) {
        Ginv.5 <- tryCatch(
          {
            .m <- chol(..fit.inner$hessian)
            backsolve(.m, diag(nETA))
          },
          error = function(e) {
            cat("Warning: Hessian not positive definite\n")
            print(..fit.inner$hessian)
            .m <- ..fit.inner$hessian
            # .m <- chol(.m+diag(nETA)*100)
            # .m[col(.m)!=row(.m)] = .001*.m[col(.m)!=row(.m)]
            .md <- matrix(0, nETA, nETA)
            diag(.md) <- abs(diag(.m)) * 1.1
            .m <- chol(.md)
            backsolve(.m, diag(nETA))
          }
        )
      } else {
        # NOTE(review): nlminb() returns no `Hessian.inv` element, so this
        # branch presumably fails in chol() -- TODO confirm (it may be left
        # over from the commented-out lbfgs inner optimizer).
        Ginv.5 <- chol(..fit.inner$Hessian.inv)
      }
      det.Ginv.5 <- prod(diag(Ginv.5))

      ## -- AQD
      ..lik.ij <- lapply(1:nij, function(ix) {
        ij <- mij[ix, ]
        w <- nw$weights[ij]
        z <- nw$nodes[ij]
        a <- ..fit.inner$par + sqrt(2) * Ginv.5 %*% z
        f1 <- exp(as.vector(..g.fn(a))) # FIXME
        f2 <- prod(w * exp(z^2))
        f1 * f2
      })

      ..lik <- 2^(nETA / 2) * det.Ginv.5 * do.call("sum", ..lik.ij)
      c(-2 * log(..lik), .wh, ..fit.inner$par)
    }
    s <- mclapply(ID.all, llik2.subj, mc.cores = con$mc.cores) # FIXME
    m <- matrix(unlist(s), ncol = 2 + nETA, byrow = TRUE)

    if (update_starts) {
      .starts <- starts
      .starts[m[, 2], ] <- m[, 3:(2 + nETA)]
      assign("starts", .starts, .startsEnv)
    }
    m[, 1]
  }

  # Outer objective: summed -2 log-lik, per-subject values in attr "subj".
  nobjcall <- 0
  .nobjEnv <- environment()
  obj <- function(th, noomga = FALSE) {
    assign("nobjcall", nobjcall + 1, .nobjEnv)
    s <- obj.vec(th, noomga)
    r <- sum(s)
    if (con$DEBUG.INNER) {
      print(rbind(
        c(nobjcall, r, th),
        c(nobjcall, r, th * inits.vec)
      ))
    }
    attr(r, "subj") <- s
    r
  }

  np <- length(inits.vec)
  start <- rep(1, np)
  args <- list(start, obj, control = list(trace = con$trace, reltol = con$reltol.outer))

  if (!con$NOTRUN) {
    fit <- if (con$optim.outer == "nmsimplex") {
      do.call("nmsimplex", args)
    } else if (con$optim.outer == "Nelder-Mead") {
      args <- c(args, method = con$optim.outer)
      do.call("optim", args)
    } else if (con$optim.outer == "nlminb") {
      args <- list(start, obj, control = list(trace = con$trace, rel.tol = con$reltol.outer))
      do.call("nlminb", args)
    } else {
      if (!is.null(con$npt)) {
        npt <- con$npt
      } else {
        npt <- 2 * np + 1
      }
      minqa::newuoa(start, obj, control = list(rhobeg = con$rhobeg, rhoend = con$rhoend, npt = npt, iprint = con$iprint))
    }
  } else {
    fit <- NULL
  }

  fit <- c(fit, obj = obj, list(ETA = starts, con = con, diag.xform = diag.xform, nsplt = nsplt, osplt = osplt, calls = list(data = data.sav, system = system, syspar = syspar)))
  fit$par.unscaled <- fit$par * inits.vec
  attr(fit, "class") <- "gnlmm.fit"
  fit
}


#------------------
mat.indices <- function(nETA) {
  # Index helpers for an nETA x nETA matrix stored column-major.
  #
  # Returns a list with
  #   idx     two-column matrix of (row, col) pairs covering the lower
  #           triangle including the diagonal, ordered column by column;
  #   Hlo.idx logical matrix, TRUE on and below the diagonal;
  #   lo.idx  linear indices of the strictly lower triangle;
  #   hi.idx  linear indices of the strictly upper triangle, listed so that
  #           hi.idx[i] is the transposed cell of lo.idx[i].
  cols <- lapply(1:nETA, function(k) cbind(k:nETA, k))
  idx <- do.call("rbind", cols)

  cell <- matrix(1:(nETA^2), nETA, nETA)
  below <- row(cell) > col(cell)

  list(
    idx = idx,
    Hlo.idx = row(cell) >= col(cell),
    lo.idx = cell[below],
    hi.idx = t(cell)[below]
  )
}

Try the nlmixr package in your browser

Any scripts or data that you put into this service are public.

nlmixr documentation built on March 27, 2022, 5:05 p.m.