R/dnm_nhmm.R

Defines functions dnm_nhmm

dnm_nhmm <- function(model, inits, init_sd, restarts, lambda, bound, control,
                     control_restart, save_all_solutions) {
  # Estimate NHMM parameters by direct numerical optimization ("DNM") with
  # nloptr, optionally preceded by randomly perturbed restarts.
  #
  # The objective returned by make_objective_nhmm() is minimized; the
  # log-likelihood is reported as -objective * nobs(model).
  #
  # Arguments:
  #   model              NHMM object. Reads n_states, n_symbols, X_pi, X_A,
  #                      X_B and the attributes np_pi, np_A, np_B; its etas,
  #                      gammas and estimation_results are overwritten.
  #   inits              Passed to create_initial_values().
  #   init_sd            Passed to create_initial_values(); also the sd of the
  #                      normal perturbations added to the base point for
  #                      each restart.
  #   restarts           Number of perturbed preliminary optimizations
  #                      (0 = none).
  #   lambda             Forwarded to make_objective_nhmm() and stored in the
  #                      results (presumably a penalty weight -- confirm).
  #   bound              Each parameter is box-constrained to [-bound, bound].
  #   control            nloptr `opts` for the final optimization.
  #   control_restart    nloptr `opts` for each restart.
  #   save_all_solutions If TRUE, the raw nloptr results of all restarts are
  #                      kept in estimation_results$all_solutions.
  #
  # Returns `model` with updated etas/gammas and `estimation_results`.

  # Only NLopt algorithms whose name contains "NLOPT_LD_" request a gradient
  # from the objective.
  need_grad <- grepl("NLOPT_LD_", control$algorithm)
  objectivef <- make_objective_nhmm(model, lambda, need_grad)

  all_solutions <- NULL
  logliks <- NULL
  return_codes <- NULL

  if (restarts > 0L) {
    # The restarts may use a different algorithm than the final fit; build a
    # separate objective only when its gradient requirement differs, so that
    # a gradient-based restart algorithm is not handed a gradient-free
    # objective.
    need_grad_restart <- grepl("NLOPT_LD_", control_restart$algorithm)
    objectivef_restart <- if (identical(need_grad_restart, need_grad)) {
      objectivef
    } else {
      make_objective_nhmm(model, lambda, need_grad_restart)
    }

    p <- progressr::progressor(along = seq_len(restarts))
    # Lift future's size cap on exported globals for the parallel restarts
    # (presumably because model/objective can be large); the caller's setting
    # is restored when this function exits.
    original_options <- options(future.globals.maxSize = Inf)
    on.exit(options(original_options), add = TRUE)

    base_init <- unlist(create_initial_values(inits, model, init_sd = 0))
    n_pars <- length(base_init)
    lower <- -rep(bound, n_pars)
    upper <- rep(bound, n_pars)
    # Maximin Latin hypercube sample on [0, 1] mapped to N(0, init_sd^2);
    # after transposing, column i holds the perturbation for restart i.
    u <- t(
      stats::qnorm(lhs::maximinLHS(restarts, n_pars), sd = init_sd)
    )

    run_restart <- function(base_init, u) {
      fit <- nloptr(
        x0 = base_init + u, eval_f = objectivef_restart,
        lb = lower, ub = upper, opts = control_restart
      )
      p()
      fit
    }
    restart_fits <- future.apply::future_lapply(
      seq_len(restarts), \(i) run_restart(base_init, u[, i]),
      future.seed = TRUE
    )

    logliks <- -unlist(lapply(restart_fits, "[[", "objective")) * nobs(model)
    return_codes <- unlist(lapply(restart_fits, "[[", "status"))
    # Positive nloptr status codes indicate successful termination.
    successful <- which(return_codes > 0)
    if (length(successful) == 0) {
      warning_(
        c("All restarts terminated due to error.",
          paste("Error of first restart:", return_msg(return_codes[1])),
          "Trying once more. ")
      )
      init <- unlist(create_initial_values(inits, model, init_sd))
    } else {
      # Start the final fit from the best successful restart.
      optimum <- successful[which.max(logliks[successful])]
      init <- restart_fits[[optimum]]$solution
    }
    if (save_all_solutions) {
      all_solutions <- restart_fits
    }
  } else {
    init <- unlist(create_initial_values(inits, model, init_sd))
  }

  fit <- nloptr(
    x0 = init, eval_f = objectivef, lb = -rep(bound, length(init)),
    ub = rep(bound, length(init)), opts = control
  )
  if (fit$status < 0) {
    warning_(
      paste("Optimization terminated due to error:", return_msg(fit$status))
    )
    loglik <- NaN
  } else {
    loglik <- -fit$objective * nobs(model)
  }

  # The solution vector is laid out as [pi | A | B] with np_pi, np_A and
  # np_B entries respectively.
  pars <- fit$solution

  S <- model$n_states
  M <- model$n_symbols
  np_pi <- attr(model, "np_pi")
  np_A <- attr(model, "np_A")
  np_B <- attr(model, "np_B")
  K_pi <- K(model$X_pi)
  K_A <- K(model$X_A)
  K_B <- K(model$X_B)

  model$etas$eta_pi <- create_eta_pi_nhmm(pars[seq_len(np_pi)], S, K_pi)
  model$gammas$gamma_pi <- eta_to_gamma_mat(model$etas$eta_pi)
  model$etas$eta_A <- create_eta_A_nhmm(pars[np_pi + seq_len(np_A)], S, K_A)
  model$gammas$gamma_A <- eta_to_gamma_cube(model$etas$eta_A)
  model$etas$eta_B <- create_eta_B_nhmm(
    pars[np_pi + np_A + seq_len(np_B)], S, M, K_B
  )
  model$gammas$gamma_B <- drop(eta_to_gamma_cube_field(model$etas$eta_B))

  # logliks/return_codes stay NULL when no restarts were run.
  model$estimation_results <- list(
    loglik = loglik,
    return_code = fit$status,
    message = fit$message,
    iterations = fit$iterations,
    logliks_of_restarts = logliks,
    return_codes_of_restarts = return_codes,
    all_solutions = all_solutions,
    lambda = lambda,
    bound = bound,
    method = "DNM",
    algorithm = control$algorithm
  )
  model
}

Try the seqHMM package in your browser

Any scripts or data that you put into this service are public.

seqHMM documentation built on June 8, 2025, 10:16 a.m.