R/plot_momentuHMM.R

Defines functions getCovs plotTPM plotHistMVN plot.momentuHMM

Documented in plot.momentuHMM

#' Plot \code{momentuHMM}
#'
#' Plot the fitted step and angle densities over histograms of the data, transition probabilities
#' as functions of the covariates, and maps of the animals' tracks colored by the decoded states.
#'
#' @method plot momentuHMM
#'
#' @param x A \code{momentuHMM} object.
#' @param animals Vector of indices or IDs of animals for which information will be plotted.
#' Default: \code{NULL} ; all animals are plotted.
#' @param covs Data frame consisting of a single row indicating the covariate values to be used in plots. 
#' If none are specified, the means of any covariates appearing in the model are used (unless the covariate is a factor, in which case the first factor value in the data is used).
#' @param ask If \code{TRUE}, the execution pauses between each plot.
#' @param breaks Histogram parameter. See \code{hist} documentation.
#' @param hist.ylim An optional named list of vectors specifying \code{ylim=c(ymin,ymax)} for the data stream histograms.
#' See \code{hist} documentation. Default: \code{NULL} ; the function sets default values for all data streams.
#' @param sepAnimals If \code{TRUE}, the data is split by individuals in the histograms.
#' Default: \code{FALSE}.
#' @param sepStates If \code{TRUE}, the data is split by states in the histograms.
#' Default: \code{FALSE}.
#' @param col Vector of colors for the states (one color per state).
#' @param cumul If \code{TRUE}, the sum of weighted densities is plotted (default). Ignored (set to \code{FALSE}) when \code{sepStates=TRUE} or the model has a single state.
#' @param plotTracks If TRUE, the Viterbi-decoded tracks are plotted (default).
#' @param plotCI Logical indicating whether to include confidence intervals in natural parameter plots (default: FALSE)
#' @param alpha Confidence level of the confidence intervals (if \code{plotCI=TRUE}). Default: 0.95 (i.e. 95\% CIs).
#' @param plotStationary Logical indicating whether to plot the stationary state probabilities as a function of any covariates (default: FALSE). Ignored unless covariates are included in \code{formula}.
#' @param ... Additional arguments passed to \code{graphics::plot} and \code{graphics::hist} functions. These are currently limited to \code{asp}, \code{cex}, \code{cex.axis}, \code{cex.lab}, \code{cex.legend}, \code{cex.main}, \code{legend.pos}, and \code{lwd}. See \code{\link[graphics]{par}}. \code{legend.pos} must be a single keyword from the list ``bottomright'', ``bottom'', ``bottomleft'', ``left'', ``topleft'', ``top'', ``topright'', ``right'', and ``center''. Note that \code{asp} and \code{cex} only apply to plots of animal tracks. 
#'
#' @details The state-dependent densities are weighted by the frequency of each state in the most
#' probable state sequence (decoded with the function \code{\link{viterbi}}). For example, if the
#' most probable state sequence indicates that one third of observations correspond to the first
#' state, and two thirds to the second state, the plots of the densities in the first state are
#' weighted by a factor 1/3, and in the second state by a factor 2/3.
#' 
#' Confidence intervals for natural parameters are calculated from the working parameter point and covariance estimates
#' using finite-difference approximations of the first derivative for the transformation (see \code{\link[numDeriv]{grad}}).
#' For example, if \code{dN} is the numerical approximation of the first derivative of the transformation \code{N = exp(x_1 * B_1 + x_2 * B_2)}
#' for covariates (x_1, x_2) and working parameters (B_1, B_2), then 
#' \code{var(N)=dN \%*\% Sigma \%*\% dN}, where \code{Sigma=cov(B_1,B_2)}, and normal confidence intervals can be 
#' constructed as \code{N +/- qnorm(1-(1-alpha)/2) * se(N)}.
#'
#' @examples
#' # m is a momentuHMM object (as returned by fitHMM), automatically loaded with the package
#' m <- example$m
#'
#' plot(m,ask=TRUE,animals=1,breaks=20,plotCI=TRUE)
#'
#' @export
#' @importFrom graphics legend lines segments arrows layout image contour barplot plot.new
#' @importFrom grDevices adjustcolor gray hcl colorRampPalette
#' @importFrom stats as.formula
#' @importFrom CircStats circ.mean
# @importFrom scatterplot3d scatterplot3d
#' @importFrom MASS kde2d

plot.momentuHMM <- function(x,animals=NULL,covs=NULL,ask=TRUE,breaks="Sturges",hist.ylim=NULL,sepAnimals=FALSE,
                            sepStates=FALSE,col=NULL,cumul=TRUE,plotTracks=TRUE,plotCI=FALSE,alpha=0.95,plotStationary=FALSE,...)
{
  m <- x # the name "x" is for compatibility with the generic method
  m <- delta_bc(m)
  
  nbAnimals <- length(unique(m$data$ID))
  stateNames <- m$stateNames
  nbStates <- length(stateNames)
  
  distnames <- names(m$conditions$dist)
  
  if(is.null(hist.ylim)){
    hist.ylim<-vector('list',length(distnames))
    names(hist.ylim)<-distnames
  }
  for(i in distnames){
    if(!is.null(hist.ylim[[i]]) & length(hist.ylim[[i]])!=2)
      stop("hist.ylim$",i," needs to be a vector of two values (ymin,ymax)")
  }
  
  # prepare colors for the states (used in the maps and for the densities)
  if (!is.null(col) & length(col) >= nbStates)
    col <- col[1:nbStates]
  if(!is.null(col) & length(col)<nbStates) {
    warning("Length of 'col' should be at least number of states - argument ignored")
    if(nbStates<8) {
      pal <- c("#E69F00", "#56B4E9", "#009E73", "#F0E442", 
               "#0072B2", "#D55E00", "#CC79A7")
      col <- pal[1:nbStates]
    } else {
      # to make sure that all colours are distinct (emulate ggplot default palette)
      hues <- seq(15, 375, length = nbStates + 1)
      col <- hcl(h = hues, l = 65, c = 100)[1:nbStates]
    }
  }
  if (is.null(col) & nbStates < 8) {
    pal <- c("#E69F00", "#56B4E9", "#009E73", "#F0E442", 
             "#0072B2", "#D55E00", "#CC79A7")
    col <- pal[1:nbStates]
  }
  if (is.null(col) & nbStates >= 8) {
    # to make sure that all colours are distinct (emulate ggplot default palette)
    hues <- seq(15, 375, length = nbStates + 1)
    col <- hcl(h = hues, l = 65, c = 100)[1:nbStates]
  }
  
  if (sepStates | nbStates < 2) 
    cumul <- FALSE
  
  if(inherits(x,"miSum")) plotEllipse <- m$plotEllipse
  else plotEllipse <- FALSE
  
  # this function is used to muffle the warning "zero-length arrow is of indeterminate angle and so skipped" when plotCI=TRUE
  muffWarn <- function(w) {
    if(any(grepl("zero-length arrow is of indeterminate angle and so skipped",w)))
      invokeRestart("muffleWarning")
  }
  
  coordNames <- attr(m$data,"coords")
  
  if(!is.null(m$conditions$mvnCoords)){
    coordNames <- c("x","y")
    if(m$conditions$dist[[m$conditions$mvnCoords]] %in% c("mvnorm3","rw_mvnorm3")) coordNames <- c("x","y","z")
    coordNames <- paste0(m$conditions$mvnCoords,".",coordNames)
  } else if(is.null(coordNames)) coordNames <- c("x","y")
  
  ######################
  ## Prepare the data ##
  ######################
  # determine indices of animals to be plotted
  if(is.null(animals)) # all animals are plotted
    animalsInd <- 1:nbAnimals
  else {
    if(is.character(animals)) { # animals' IDs provided
      animalsInd <- NULL
      for(zoo in 1:length(animals)) {
        if(length(which(unique(m$data$ID)==animals[zoo]))==0) # ID not found
          stop("Check 'animals' argument, ID not found")
        
        animalsInd <- c(animalsInd,which(unique(m$data$ID)==animals[zoo]))
      }
    }
    
    if(is.numeric(animals)) { # animals' indices provided
      if(length(which(animals<1))>0 | length(which(animals>nbAnimals))>0) # index out of bounds
        stop("Check 'animals' argument, index out of bounds")
      
      animalsInd <- animals
    }
  }
  
  nbAnimals <- length(animalsInd)
  ID <- unique(m$data$ID)[animalsInd]
  
  ##################################
  ## States decoding with Viterbi ##
  ##################################
  if(nbStates>1) {
    if(inherits(x,"miSum")) states <- m$Par$states
    else {
      cat("Decoding state sequence... ")
      states <- viterbi(m)
      cat("DONE\n")
      if(inherits(m,"hierarchical")){
        cat("Decoding hierarchical state sequence... ")
        hStates <- viterbi(m,hierarchical = TRUE)
        cat("DONE\n")
      }
    }
  } else
    states <- rep(1,nrow(m$data))
  
  # proportion of each state in the states sequence returned by the Viterbi algorithm
  w <- iStates <- list()
  for(i in distnames){
    if(!inherits(x,"hierarchical")){
      if(sepStates | nbStates==1)
        w[[i]] <- rep(1,nbStates)
      else {
        w[[i]] <- rep(NA,nbStates)
        for(state in 1:nbStates)
          w[[i]][state] <- length(which(states==state))/length(states)
      }
      iStates[[i]] <- 1:nbStates
      names(iStates[[i]]) <- stateNames
    } else {
      installDataTree()
      w[[i]] <- rep(0,nbStates)
      iLev <- gsub(paste0(".",i),"",names(m$conditions$hierDist$leaves)[grepl(i,names(m$conditions$hierDist$leaves))])
      iStates[[i]] <- m$conditions$hierStates$Get(function(x) data.tree::Aggregate(x,"state",min),filterFun=function(x) x$level==(as.numeric(gsub("level","",iLev))+1))
      if(sepStates) {
        w[[i]][iStates[[i]]] <- 1
      } else {
        denom <- length(hStates[[iLev]])
        for(state in 1:length(iStates[[i]])){
          w[[i]][iStates[[i]][state]] <- length(which(hStates[[iLev]]==names(iStates[[i]][state])))/denom
        }
      }
    }
  }
  
  if(all(coordNames %in% names(m$data))){
    x <- list()
    y <- list()
    z <- list()
    if(plotEllipse)  errorEllipse <- list()
    for(zoo in 1:nbAnimals) {
      ind <- which(m$data$ID==ID[zoo])
      x[[zoo]] <- m$data[[coordNames[1]]][ind]
      y[[zoo]] <- m$data[[coordNames[2]]][ind]
      if(!is.null(m$conditions$mvnCoords)){
        if(m$conditions$dist[[m$conditions$mvnCoords]] %in% c("mvnorm3","rw_mvnorm3")) {
          z[[zoo]] <- m$data[[coordNames[3]]][ind]
          plotEllipse <- FALSE
          errorEllipse <- NULL
        } 
      }
      if(plotEllipse) errorEllipse[[zoo]] <- m$errorEllipse[ind]
    }
  }
  
  covs <- getCovs(m,covs,ID)
  
  # identify covariates
  reForm <- formatRecharge(nbStates,m$conditions$formula,m$conditions$betaRef,m$data,par=m$mle)
  recharge <- reForm$recharge
  hierRecharge <- reForm$hierRecharge
  newformula <- reForm$newformula
  nbCovs <- reForm$nbCovs
  aInd <- reForm$aInd
  nbG0covs <- reForm$nbG0covs
  nbRecovs <- reForm$nbRecovs
  g0covs <- reForm$g0covs
  recovs <- reForm$recovs
  
  if(!is.null(recharge)){
    rechargeNames <- colnames(reForm$newdata)
    m$data[rechargeNames] <- reForm$newdata
    g0covs <- stats::model.matrix(recharge$g0,covs)
    g0 <- m$mle$g0 %*% t(g0covs)
    recovs <- stats::model.matrix(recharge$theta,covs)
    for(j in rechargeNames){
      if(is.null(covs[[j]])) covs[[j]] <- mean(m$data[[j]])
    }
    #if(is.null(covs$recharge)) covs$recharge <- mean(m$data$recharge) #g0 + theta%*%t(recovs)
    covsCol <- cbind(get_all_vars(newformula,m$data),get_all_vars(recharge$theta,m$data))#rownames(attr(stats::terms(formula),"factors"))#attr(stats::terms(formula),"term.labels")#seq(1,ncol(data))[-match(c("ID","x","y",distnames),names(data),nomatch=0)]
    if(!all(names(covsCol) %in% names(m$data))){
      covsCol <- covsCol[,names(covsCol) %in% names(m$data),drop=FALSE]
    }
    rawCovs <- covsCol[which(m$data$ID %in% ID),c(unique(colnames(covsCol))),drop=FALSE]
  } else {
    rawCovs <- m$rawCovs[which(m$data$ID %in% ID),,drop=FALSE]
  }

  Par <- m$mle[distnames]
  
  ncmean <- get_ncmean(distnames,m$conditions$fullDM,m$conditions$circularAngleMean,nbStates)
  nc <- ncmean$nc
  meanind <- ncmean$meanind
  
  tmPar <- lapply(Par,function(x) c(t(x)))
  parCount<- lapply(m$conditions$fullDM,ncol)
  for(i in distnames[!unlist(lapply(m$conditions$circularAngleMean,isFALSE))]){
    parCount[[i]] <- length(unique(gsub("cos","",gsub("sin","",colnames(m$conditions$fullDM[[i]])))))
  }
  parindex <- c(0,cumsum(unlist(parCount))[-length(m$conditions$fullDM)])
  names(parindex) <- distnames
  
  for(i in distnames){
    if(!is.null(m$conditions$DM[[i]])){# & m$conditions$DMind[[i]]){
      Par[[i]] <- m$mod$estimate[parindex[[i]]+1:parCount[[i]]]
      if(!isFALSE(m$conditions$circularAngleMean[[i]])){
        names(Par[[i]]) <- unique(gsub("cos","",gsub("sin","",colnames(m$conditions$fullDM[[i]]))))
      } else names(Par[[i]])<-colnames(m$conditions$fullDM[[i]])
    }
  }
  Par<-lapply(Par,function(x) c(t(x)))
  beta <- list(beta=m$mle$beta)
  if(!is.null(m$conditions$recharge)){
    beta$g0 <- m$mle$g0
    beta$theta <- m$mle$theta
  }
  
  mixtures <- m$conditions$mixtures
  
  if(!m$conditions$stationary){
    nbCovsDelta <- ncol(m$covsDelta)-1
    foo <- length(m$mod$estimate)-ifelse(nbRecovs,(nbRecovs+1)+(nbG0covs+1),0)-(nbCovsDelta+1)*(nbStates-1)*mixtures+1
    delta <- m$mod$estimate[foo:(length(m$mod$estimate)-ifelse(nbRecovs,(nbRecovs+1)+(nbG0covs+1),0))]
  } else {
    delta <- NULL
  }
  
  if(mixtures>1){
    if(!m$conditions$stationary) beta[["pi"]] <- m$mod$estimate[length(m$mod$estimate)-ncol(m$covsPi)*(mixtures-1)-ifelse(nbRecovs,(nbRecovs+1)+(nbG0covs+1),0)-(nbCovsDelta+1)*(nbStates-1)*mixtures+1:(ncol(m$covsPi)*(mixtures-1))]
    else beta[["pi"]] <- m$mod$estimate[length(m$mod$estimate)-ncol(m$covsPi)*(mixtures-1)-ifelse(nbRecovs,(nbRecovs+1)+(nbG0covs+1),0)+1:(ncol(m$covsPi)*(mixtures-1))]
  } else beta[["pi"]] <- NULL
  
  tmpPar <- Par
  tmpConditions <- m$conditions
  
  for(i in distnames[which(m$conditions$dist %in% angledists)]){
    if(!m$conditions$estAngleMean[[i]]){
      tmpConditions$estAngleMean[[i]]<-TRUE
      tmpConditions$userBounds[[i]]<-rbind(matrix(rep(c(-pi,pi),nbStates),nbStates,2,byrow=TRUE),m$conditions$bounds[[i]])
      tmpConditions$workBounds[[i]]<-rbind(matrix(rep(c(-Inf,Inf),nbStates),nbStates,2,byrow=TRUE),m$conditions$workBounds[[i]])
      if(!is.null(m$conditions$DM[[i]])){
        tmpPar[[i]] <- c(rep(0,nbStates),Par[[i]])
        if(is.list(m$conditions$DM[[i]])){
          tmpConditions$DM[[i]]$mean<- ~1
        } else {
          tmpDM <- matrix(0,nrow(tmpConditions$DM[[i]])+nbStates,ncol(tmpConditions$DM[[i]])+nbStates)
          tmpDM[nbStates+1:nrow(tmpConditions$DM[[i]]),nbStates+1:ncol(tmpConditions$DM[[i]])] <- tmpConditions$DM[[i]]
          diag(tmpDM)[1:nbStates] <- 1
          tmpConditions$DM[[i]] <- tmpDM
        }
      } else {
        Par[[i]] <- Par[[i]][-(1:nbStates)]
      }
    }
  }
  
  # get pars for probability density plots
  tmpInputs <- checkInputs(nbStates,tmpConditions$dist,tmpPar,tmpConditions$estAngleMean,tmpConditions$circularAngleMean,tmpConditions$zeroInflation,tmpConditions$oneInflation,tmpConditions$DM,tmpConditions$userBounds,stateNames)
  tmpp <- tmpInputs$p
  
  splineInputs<-getSplineDM(distnames,tmpInputs$DM,m,covs)
  covs<-splineInputs$covs
  DMinputs<-getDM(covs,splineInputs$DM,tmpInputs$dist,nbStates,tmpp$parNames,tmpp$bounds,tmpPar,tmpConditions$zeroInflation,tmpConditions$oneInflation,tmpConditions$circularAngleMean)
  fullDM <- DMinputs$fullDM
  DMind <- DMinputs$DMind
  wpar <- n2w(tmpPar,tmpp$bounds,beta,delta,nbStates,tmpInputs$estAngleMean,tmpInputs$DM,tmpp$Bndind,tmpInputs$dist)
  #if(!m$conditions$stationary & nbStates>1) {
  #  wpar[(length(wpar)-nbStates+2):length(wpar)] <- m$mod$estimate[(length(m$mod$estimate)-nbStates+2):length(m$mod$estimate)] #this is done to deal with any delta=0 in n2w
  #}
  
  ncmean <- get_ncmean(distnames,fullDM,tmpInputs$circularAngleMean,nbStates)
  nc <- ncmean$nc
  meanind <- ncmean$meanind
  
  par <- w2n(wpar,tmpp$bounds,tmpp$parSize,nbStates,nbCovs,tmpInputs$estAngleMean,tmpInputs$circularAngleMean,tmpInputs$consensus,stationary=m$conditions$stationary,fullDM,DMind,1,tmpInputs$dist,tmpp$Bndind,nc,meanind,m$covsDelta,tmpConditions$workBounds,m$covsPi)
  
  inputs <- checkInputs(nbStates,m$conditions$dist,Par,m$conditions$estAngleMean,m$conditions$circularAngleMean,m$conditions$zeroInflation,m$conditions$oneInflation,m$conditions$DM,m$conditions$userBounds,stateNames)
  p <- inputs$p
  
  Fun <- lapply(inputs$dist,function(x) paste("d",x,sep=""))
  for(i in names(Fun)){
    if(Fun[[i]]=="dcat"){
      if (!requireNamespace("extraDistr", quietly = TRUE))
        stop("Package \"extraDistr\" needed for categorical distribution. Please install it.",
             call. = FALSE) 
      dcat <- extraDistr::dcat
    }
  }
  
  zeroMass<-oneMass<-vector('list',length(inputs$dist))
  names(zeroMass)<-names(oneMass)<-distnames
  
  # text for legends
  legText <- stateNames
  
  tmpcovs<-covs
  for(i in which(mapply(is.numeric,covs))){
    tmpcovs[i]<-round(covs[i],2)
  }
  for(i in which(mapply(is.factor,covs))){
    tmpcovs[i] <- as.character(covs[[i]])
  }
  
  if(inherits(m,"miSum")){
    if(length(m$conditions$optInd)){
      Sigma <- matrix(0,length(m$mod$estimate),length(m$mod$estimate))
      Sigma[(1:length(m$mod$estimate))[-m$conditions$optInd],(1:length(m$mod$estimate))[-m$conditions$optInd]] <- m$MIcombine$variance
    } else {
      Sigma <- m$MIcombine$variance
    }
  } else if(!is.null(m$mod$hessian) && !inherits(m$mod$Sigma,"error")){
    Sigma <- m$mod$Sigma
  } else {
    Sigma <- NULL
    plotCI <- FALSE
  }
  
  # set graphical parameters
  par(mfrow=c(1,1))
  par(mar=c(5,4,4,2)-c(0,0,2,1)) # bottom, left, top, right
  par(ask=ask)
  
  if(!missing(...)){
    arg <- list(...)
    if(any(!(names(arg) %in% plotArgs))) stop("additional graphical parameters are currently limited to: ",paste0(plotArgs,collapse=", "))
    if(!is.null(arg$cex.main)) cex.main <- arg$cex.main
    else cex.main <- 1
    arg$cex.main <- NULL
    if(!is.null(arg$cex.legend)) cex.legend <- arg$cex.legend
    else cex.legend <- 1
    arg$cex.legend <- NULL 
    if(!is.null(arg[["cex"]])) cex <- arg[["cex"]]
    else cex <- 0.6
    arg$cex <- NULL
    if(!is.null(arg$asp)) asp <- arg$asp
    else asp <- 1
    arg$asp <- NULL
    if(!is.null(arg$lwd)) lwd <- arg$lwd
    else lwd <- 1.3
    arg$lwd <- NULL
    if(!is.null(arg$legend.pos)) {
      if(!(arg$legend.pos %in% c("bottomright", "bottom", "bottomleft", "left", "topleft", "top", "topright", "right", "center"))) 
        stop('legend.pos must be a single keyword from the list "bottomright", "bottom", "bottomleft", "left", "topleft", "top", "topright", "right" and "center"')
      legend.pos <- arg$legend.pos
    }
    else legend.pos <- NULL
    arg$legend.pos <- NULL
  } else {
    cex <- 0.6
    asp <- 1
    lwd <- 1.3
    cex.main <- 1
    cex.legend <- 1
    legend.pos <- NULL
    arg <- NULL
  }
  marg <- arg
  marg$cex <- NULL
  
  for(i in distnames){
    
    if(m$conditions$dist[[i]] %in% mvndists){
      if(m$conditions$dist[[i]]=="mvnorm2" || m$conditions$dist[[i]]=="rw_mvnorm2"){
        tmpData <- c(m$data[[paste0(i,".x")]],m$data[[paste0(i,".y")]])
        if(m$conditions$dist[[i]]=="mvnorm2") ndim <- as.numeric(gsub("mvnorm","",m$conditions$dist[[i]]))
        else ndim <- as.numeric(gsub("rw_mvnorm","",m$conditions$dist[[i]]))
      } else if(m$conditions$dist[[i]]=="mvnorm3" || m$conditions$dist[[i]]=="rw_mvnorm3"){
        tmpData <- c(m$data[[paste0(i,".x")]],m$data[[paste0(i,".y")]],m$data[[paste0(i,".z")]])
        if(m$conditions$dist[[i]]=="mvnorm3") ndim <- as.numeric(gsub("mvnorm","",m$conditions$dist[[i]]))
        else ndim <- as.numeric(gsub("rw_mvnorm","",m$conditions$dist[[i]]))
      }
    } else {
      tmpData <- m$data[[i]]
    }
    
    # split data by animals if necessary
    if(sepAnimals) {
      genData <- list()
      for(zoo in 1:nbAnimals) {
        ind <- which(m$data$ID==ID[zoo])
        if(m$conditions$dist[[i]] %in% mvndists){
          if(m$conditions$dist[[i]] %in% c("mvnorm2","rw_mvnorm2"))
            genData[[zoo]] <- c(tmpData[ind],tmpData[(-(1:nrow(m$data)))][ind])
          else if(m$conditions$dist[[i]] %in% c("mvnorm3","rw_mvnorm3"))
            genData[[zoo]] <- c(tmpData[ind],tmpData[-(1:nrow(m$data))][ind],tmpData[-(1:(2*nrow(m$data)))][ind])
        } else genData[[zoo]] <- tmpData[ind]
      }
    } else {
      ind <- which(m$data$ID %in% ID)
      if(m$conditions$dist[[i]] %in% mvndists){
        if(m$conditions$dist[[i]] %in% c("mvnorm2","rw_mvnorm2"))
          genData <- tmpData[c(ind,ind+nrow(m$data))]
        else if(m$conditions$dist[[i]] %in% c("mvnorm3","rw_mvnorm3"))
          genData <- tmpData[c(ind,ind+nrow(m$data),ind+2*nrow(m$data))]
      } else genData <- tmpData[ind]
    }
    
    zeroMass[[i]] <- rep(0,nbStates)
    oneMass[[i]] <- rep(0,nbStates)
    if(m$conditions$zeroInflation[[i]] | m$conditions$oneInflation[[i]]) {
      if(m$conditions$zeroInflation[[i]]) zeroMass[[i]] <- par[[i]][nrow(par[[i]])-nbStates*m$conditions$oneInflation[[i]]-(nbStates-1):0,]
      if(m$conditions$oneInflation[[i]]) oneMass[[i]] <- par[[i]][nrow(par[[i]])-(nbStates-1):0,]
      par[[i]] <- par[[i]][-(nrow(par[[i]])-(nbStates*m$conditions$oneInflation[[i]]-nbStates*m$conditions$zeroInflation[[i]]-1):0),,drop=FALSE]
    }
    
    infInd <- FALSE
    if(inputs$dist[[i]] %in% angledists)
      if(i=="angle" & ("step" %in% distnames))
        if(inputs$dist$step %in% stepdists & m$conditions$zeroInflation$step)
          if(all(coordNames %in% names(m$data)))
            infInd <- TRUE
    
    #get covariate names
    covNames <- getCovNames(m,p,i)
    DMterms<-covNames$DMterms
    DMparterms<-covNames$DMparterms
    
    if(inputs$consensus[[i]]){
      for(jj in 1:nbStates){
        if(!is.null(DMparterms$mean[[jj]])) DMparterms$kappa[[jj]] <- c(DMparterms$mean[[jj]],DMparterms$kappa[[jj]])
      }
    }
    
    factorterms<-names(m$data)[unlist(lapply(m$data,is.factor))]
    factorcovs<-paste0(rep(factorterms,times=unlist(lapply(m$data[factorterms],nlevels))),unlist(lapply(m$data[factorterms],levels)))
    
    if(length(DMterms)){
      for(jj in 1:length(DMterms)){
        cov<-DMterms[jj]
        form<-stats::formula(paste("~",cov))
        varform<-all.vars(form)
        if(any(varform %in% factorcovs) && !all(varform %in% factorterms)){
          factorvar<-factorcovs %in% (varform[!(varform %in% factorterms)])
          DMterms[jj]<-rep(factorterms,times=unlist(lapply(m$data[factorterms],nlevels)))[which(factorvar)]
        } 
      }
    }
    DMterms<-unique(DMterms)
    
    if(length(DMparterms)){
      for(ii in 1:length(DMparterms)){
        for(state in 1:nbStates){
          if(length(DMparterms[[ii]][[state]])){
            for(jj in 1:length(DMparterms[[ii]][[state]])){
              cov<-DMparterms[[ii]][[state]][jj]
              form<-stats::formula(paste("~",cov))
              varform<-all.vars(form)
              if(any(varform %in% factorcovs) && !all(varform %in% factorterms)){
                factorvar<-factorcovs %in% (varform[!(varform %in% factorterms)])
                DMparterms[[ii]][[state]][jj]<-rep(factorterms,times=unlist(lapply(m$data[factorterms],nlevels)))[which(factorvar)]
              }
            }
            DMparterms[[ii]][[state]]<-unique(DMparterms[[ii]][[state]])
          }
        }
      }
    }
    covmess <- ifelse(!m$conditions$DMind[[i]],paste0(": ",paste0(DMterms," = ",tmpcovs[DMterms],collapse=", ")),"")
    
    ###########################################
    ## Compute estimated densities on a grid ##
    ###########################################
    genDensities <- list()
    genFun <- Fun[[i]]
    if(inputs$dist[[i]] %in% angledists) {
      grid <- seq(-pi,pi,length=1000)
    } else {
      
      if(inputs$dist[[i]] %in% integerdists){
        if(all(is.na(m$data[[i]])) || !is.finite(max(m$data[[i]],na.rm=TRUE))) next;
        if(inputs$dist[[i]]=="cat"){
          dimCat <- as.numeric(gsub("cat","",m$conditions$dist[[i]]))
          grid <- seq(1,dimCat)
        }
        else grid <- seq(0,max(m$data[[i]],na.rm=TRUE))
      } else if(inputs$dist[[i]] %in% stepdists){
        if(all(is.na(m$data[[i]])) || !is.finite(max(m$data[[i]],na.rm=TRUE))) next;
        grid <- seq(0,max(m$data[[i]],na.rm=TRUE),length=10000)
      } else if(inputs$dist[[i]] %in% mvndists){
        if(inputs$dist[[i]]=="mvnorm2" || inputs$dist[[i]]=="rw_mvnorm2"){
          if(all(is.na(m$data[paste0(i,c(".x",".y"))])) || !is.finite(max(m$data[paste0(i,c(".x",".y"))],na.rm=TRUE))) next;
          grid <- c(seq(min(m$data[[paste0(i,".x")]],na.rm=TRUE), max(m$data[[paste0(i,".x")]],na.rm=TRUE), length=100),
                    seq(min(m$data[[paste0(i,".y")]],na.rm=TRUE), max(m$data[[paste0(i,".y")]],na.rm=TRUE), length=100))
        } else if(all(is.na(m$data[paste0(i,c(".x",".y",".z"))])) || !is.finite(max(m$data[paste0(i,c(".x",".y",".z"))],na.rm=TRUE))) next;
      } else {
        if(all(is.na(m$data[[i]])) || !is.finite(max(m$data[[i]],na.rm=TRUE))) next;
        grid <- seq(min(m$data[[i]],na.rm=TRUE),max(m$data[[i]],na.rm=TRUE),length=10000)
      }
    }
    for(state in iStates[[i]]) {
      genArgs <- list(grid)
      
      if(m$conditions$dist[[i]] %in% mvndists){
        genArgs[[2]] <- matrix(par[[i]][seq(state,nbStates*ndim,nbStates)],ndim,1)
        sig <- matrix(0,ndim,ndim)
        lowertri <- par[[i]][nbStates*ndim+seq(state,sum(lower.tri(matrix(0,ndim,ndim),diag=TRUE))*nbStates,nbStates)]
        sig[lower.tri(sig, diag=TRUE)] <- lowertri
        sig <- t(sig)
        sig[lower.tri(sig, diag=TRUE)] <- lowertri
        genArgs[[3]] <- sig
      } else if(grepl("cat",m$conditions$dist[[i]])){
        genArgs[[2]] <- t(par[[i]][seq(state,dimCat*nbStates,nbStates),])
      } else {
        for(j in 1:(nrow(par[[i]])/nbStates))
          genArgs[[j+1]] <- par[[i]][(j-1)*nbStates+state,]
      }
      
      # conversion between mean/sd and shape/scale if necessary
      if(inputs$dist[[i]]=="gamma") {
        shape <- genArgs[[2]]^2/genArgs[[3]]^2
        scale <- genArgs[[3]]^2/genArgs[[2]]
        genArgs[[2]] <- shape
        genArgs[[3]] <- 1/scale # dgamma expects rate=1/scale
      }
      # (weighted by the proportion of each state in the Viterbi states sequence)
      if(m$conditions$zeroInflation[[i]] | m$conditions$oneInflation[[i]]){
        genDensities[[state]] <- cbind(grid,(1-zeroMass[[i]][state]-oneMass[[i]][state])*w[[i]][state]*do.call(genFun,genArgs))
      } else if(infInd) {
        genDensities[[state]] <- cbind(grid,(1-zeroMass$step[state])*w[[i]][state]*do.call(genFun,genArgs))
      } else if(inputs$dist[[i]] %in% mvndists){
        if(inputs$dist[[i]]=="mvnorm2" || inputs$dist[[i]]=="rw_mvnorm2"){
          dens <- outer(genArgs[[1]][1:100],genArgs[[1]][101:200], function(x,y) dmvnorm2(c(x,y),matrix(rep(genArgs[[2]],10000),2),matrix(rep(genArgs[[3]],10000),2*2)))
          genDensities[[state]] <- list(x=genArgs[[1]][1:100], y=genArgs[[1]][101:200], z=w[[i]][state]*dens)
        }
      } else {
        genDensities[[state]] <- cbind(grid,w[[i]][state]*do.call(genFun,genArgs))
      }
      
      for(j in p$parNames[[i]]){
        
        for(jj in DMparterms[[j]][[state]]){
          
          if(!is.factor(m$data[,jj])){
            
            gridLength <- 101
            
            inf <- min(m$data[,jj],na.rm=T)
            sup <- max(m$data[,jj],na.rm=T)
            
            # set all covariates to their mean, except for "cov"
            # (which takes a grid of values from inf to sup)
            tempCovs <- data.frame(matrix(covs[jj][[1]],nrow=gridLength,ncol=1))
            if(length(DMterms)>1)
              for(ii in DMterms[which(!(DMterms %in% jj))])
                tempCovs <- cbind(tempCovs,rep(covs[[ii]],gridLength))
            names(tempCovs) <- c(jj,DMterms[which(!(DMterms %in% jj))])
            tempCovs[,jj] <- seq(inf,sup,length=gridLength)
          } else {
            gridLength<- nlevels(m$data[,jj])
            # set all covariates to their mean, except for "cov"
            tempCovs <- data.frame(matrix(covs[jj][[1]],nrow=gridLength,ncol=1))
            if(length(DMterms)>1)
              for(ii in DMterms[which(!(DMterms %in% jj))])
                tempCovs <- cbind(tempCovs,rep(covs[[ii]],gridLength))
            names(tempCovs) <- c(jj,DMterms[which(!(DMterms %in% jj))])
            tempCovs[,jj] <- as.factor(levels(m$data[,jj]))
          }
          
          for(ii in DMterms[which(unlist(lapply(m$data[DMterms],is.factor)))])
            tempCovs[[ii]] <- factor(tempCovs[[ii]],levels=levels(m$data[[ii]]))
          
          tmpSplineInputs<-getSplineDM(i,inputs$DM,m,tempCovs)
          tempCovs<-tmpSplineInputs$covs
          DMinputs<-getDM(tempCovs,tmpSplineInputs$DM[i],inputs$dist[i],nbStates,p$parNames[i],p$bounds[i],Par[i],m$conditions$zeroInflation[i],m$conditions$oneInflation[i],m$conditions$circularAngleMean[i])
          
          fullDM <- DMinputs$fullDM
          DMind <- DMinputs$DMind
          
          nc[[i]] <- apply(fullDM[[i]],1:2,function(x) !all(unlist(x)==0))
          if(!isFALSE(inputs$circularAngleMean[[i]])) {
            meanind[[i]] <- which((apply(fullDM[[i]][1:nbStates,,drop=FALSE],1,function(x) !all(unlist(x)==0))))
            # deal with angular covariates that are exactly zero
            if(length(meanind[[i]])){
              angInd <- which(is.na(match(gsub("cos","",gsub("sin","",colnames(nc[[i]]))),colnames(nc[[i]]),nomatch=NA)))
              sinInd <- colnames(nc[[i]])[which(grepl("sin",colnames(nc[[i]])[angInd]))]
              nc[[i]][meanind[[i]],sinInd]<-ifelse(nc[[i]][meanind[[i]],sinInd],nc[[i]][meanind[[i]],sinInd],nc[[i]][meanind[[i]],gsub("sin","cos",sinInd)])
              nc[[i]][meanind[[i]],gsub("sin","cos",sinInd)]<-ifelse(nc[[i]][meanind[[i]],gsub("sin","cos",sinInd)],nc[[i]][meanind[[i]],gsub("sin","cos",sinInd)],nc[[i]][meanind[[i]],sinInd])
            }
          }
          gradfun<-function(wpar,k) {
            w2n(wpar,p$bounds[i],p$parSize[i],nbStates,nbCovs,inputs$estAngleMean[i],inputs$circularAngleMean[i],inputs$consensus[i],stationary=TRUE,fullDM,DMind,gridLength,inputs$dist[i],p$Bndind[i],nc[i],meanind[i],m$covsDelta,m$conditions$workBounds[c(i,"beta")],m$covsPi)[[i]][(which(tmpp$parNames[[i]]==j)-1)*nbStates+state,k]
          }
          est<-w2n(c(m$mod$estimate[parindex[[i]]+1:parCount[[i]]],beta$beta,beta[["pi"]]),p$bounds[i],p$parSize[i],nbStates,nbCovs,inputs$estAngleMean[i],inputs$circularAngleMean[i],inputs$consensus[i],stationary=TRUE,fullDM,DMind,gridLength,inputs$dist[i],p$Bndind[i],nc[i],meanind[i],m$covsDelta,m$conditions$workBounds[c(i,"beta")],m$covsPi)[[i]][(which(tmpp$parNames[[i]]==j)-1)*nbStates+state,]
          if(plotCI){
            dN<-t(mapply(function(x) tryCatch(numDeriv::grad(gradfun,c(m$mod$estimate[parindex[[i]]+1:parCount[[i]]],beta$beta,beta[["pi"]]),k=x),error=function(e) NA),1:gridLength))
            se<-t(apply(dN[,1:parCount[[i]]],1,function(x) tryCatch(suppressWarnings(sqrt(x%*%Sigma[parindex[[i]]+1:parCount[[i]],parindex[[i]]+1:parCount[[i]]]%*%x)),error=function(e) NA)))
            uci<-est+qnorm(1-(1-alpha)/2)*se
            lci<-est-qnorm(1-(1-alpha)/2)*se
            do.call(plot,c(list(tempCovs[,jj],est,ylim=range(c(lci,est,uci),na.rm=TRUE),xaxt="n",xlab=jj,ylab=paste(i,ifelse(j=="kappa","concentration",j),'parameter'),main=paste0(names(iStates[[i]])[match(state,iStates[[i]])],ifelse(length(tempCovs[,DMparterms[[j]][[state]][-which(DMparterms[[j]][[state]]==jj)]]),paste0(": ",paste(DMparterms[[j]][[state]][-which(DMparterms[[j]][[state]]==jj)],"=",tmpcovs[,DMparterms[[j]][[state]][-which(DMparterms[[j]][[state]]==jj)]],collapse=", ")),"")),type="l",lwd=lwd,cex.main=cex.main),arg))            
            if(!all(is.na(se))){
              ciInd <- which(!is.na(se))
              
              withCallingHandlers(do.call(arrows,c(list(as.numeric(tempCovs[ciInd,jj]), lci[ciInd], as.numeric(tempCovs[ciInd,jj]),
                                         uci[ciInd], length=0.025, angle=90, code=3, col=gray(.5), lwd=lwd),arg)),warning=muffWarn)
              
            }
          } else do.call(plot,c(list(tempCovs[,jj],est,xaxt="n",xlab=jj,ylab=paste(i,ifelse(j=="kappa","concentration",j),'parameter'),main=paste0(names(iStates[[i]])[match(state,iStates[[i]])],ifelse(length(tempCovs[,DMparterms[[j]][[state]][-which(DMparterms[[j]][[state]]==jj)]]),paste0(": ",paste(DMparterms[[j]][[state]][-which(DMparterms[[j]][[state]]==jj)],"=",tmpcovs[,DMparterms[[j]][[state]][-which(DMparterms[[j]][[state]]==jj)]],collapse=", ")),"")),type="l",lwd=lwd,cex.main=cex.main),arg)) 
          if(is.factor(tempCovs[,jj])) do.call(axis,c(list(1,at=tempCovs[,jj],labels=tempCovs[,jj]),arg))
          else do.call(axis,c(list(1),arg))
          
        }
      }
    }
    
    #########################
    ## Plot the histograms ##
    #########################
    if(!(inputs$dist[[i]] %in% mvndists)){
      if(sepAnimals) {
        
        # loop over the animals
        for(zoo in 1:nbAnimals) {
          if(sepStates) {
            
            # loop over the states
            for(state in iStates[[i]]) {
              gen <- genData[[zoo]][which(states[which(m$data$ID==ID[zoo])]==state)]
              message <- paste0("ID ",ID[zoo]," - ",names(iStates[[i]])[match(state,iStates[[i]])],covmess)
              
              # the function plotHist is defined below
              plotHist(gen,genDensities,inputs$dist[i],message,sepStates,breaks,state,hist.ylim[[i]],col,names(iStates[[i]]), cumul = cumul, iStates[[i]], ...)
            }
            
          } else { # if !sepStates
            gen <- genData[[zoo]]
            message <- paste0("ID ",ID[zoo],covmess)
            
            plotHist(gen,genDensities,inputs$dist[i],message,sepStates,breaks,NULL,hist.ylim[[i]],col,names(iStates[[i]]), cumul = cumul, iStates[[i]], ...)
          }
        }
      } else { # if !sepAnimals
        if(sepStates) {
          
          # loop over the states
          for(state in iStates[[i]]) {
            gen <- genData[which(states==state)]
            if(nbAnimals>1) message <- paste0("All animals - ",names(iStates[[i]])[match(state,iStates[[i]])],covmess)
            else message <- paste0("ID ",ID," - ",names(iStates[[i]])[match(state,iStates[[i]])],covmess)
            
            plotHist(gen,genDensities,inputs$dist[i],message,sepStates,breaks,state,hist.ylim[[i]],col,names(iStates[[i]]), cumul = cumul, iStates[[i]], ...)
          }
          
        } else { # if !sepStates
          gen <- genData
          if(nbAnimals>1) message <- paste0("All animals",covmess)
          else message <- paste0("ID ",ID,covmess)
          
          plotHist(gen,genDensities,inputs$dist[i],message,sepStates,breaks,NULL,hist.ylim[[i]],col,names(iStates[[i]]), cumul = cumul, iStates[[i]], ...)
        }
      }
    } else if(inputs$dist[[i]]=="mvnorm2" || inputs$dist[[i]]=="rw_mvnorm2"){
      
      datNames <- paste0(i,c(".x",".y"))
      
      if(sepAnimals) {
        
        # loop over the animals
        for(zoo in 1:nbAnimals) {
          if(sepStates) {
            
            # loop over the states
            for(state in iStates[[i]]) {
              gen <- m$data[which(states==state & m$data$ID==ID[zoo]),datNames]
              message <- paste0("ID ",ID[zoo]," - ",names(iStates[[i]])[match(state,iStates[[i]])],covmess)
              
              # the function plotHistMVN is defined below
              plotHistMVN(gen,genDensities,inputs$dist[i],message,sepStates,breaks,state,col,names(iStates[[i]]), cumul=cumul, iStates[[i]], ...)
            }
            
          } else { # if !sepStates
            gen <- m$data[which(m$data$ID==ID[zoo]),datNames]
            message <- paste0("ID ",ID[zoo],covmess)
            
            plotHistMVN(gen,genDensities,inputs$dist[i],message,sepStates,breaks,NULL,col,names(iStates[[i]]), cumul=cumul, iStates[[i]], ...)
          }
        }
      } else { # if !sepAnimals
        if(sepStates) {
          
          # loop over the states
          for(state in iStates[[i]]) {
            gen <- m$data[which(states==state),datNames]
            if(nbAnimals>1) message <- paste0("All animals - ",names(iStates[[i]])[match(state,iStates[[i]])],covmess)
            else message <- paste0("ID ",ID," - ",names(iStates[[i]])[match(state,iStates[[i]])],covmess)

            plotHistMVN(gen,genDensities,inputs$dist[i],message,sepStates,breaks,state,col,names(iStates[[i]]), cumul=cumul, iStates[[i]], ...)            
          }
          
        } else { # if !sepStates
          gen <- m$data[,datNames]
          if(nbAnimals>1) message <- paste0("All animals",covmess)
          else message <- paste0("ID ",ID,covmess)
          
          plotHistMVN(gen,genDensities,inputs$dist[i],message,sepStates,breaks,NULL,col,names(iStates[[i]]), cumul=cumul, iStates[[i]], ...)
        }
      }     
      # reset graphical parameters
      par(mfrow=c(1,1))
      par(mar=c(5,4,4,2)-c(0,0,2,1)) # bottom, left, top, right
      par(ask=ask)
    } else if(inputs$dist[[i]]=="mvnorm3" || inputs$dist[[i]]=="rw_mvnorm3"){
      
      datNames <- paste0(i,c(".x",".y",".z"))
      
      tmpdist <- list()
      for(j in 1:length(datNames)){
        
        tmpdist[[datNames[j]]] <- "norm"
        
        genArgs[[1]] <- seq(min(m$data[[datNames[j]]],na.rm=TRUE),max(m$data[[datNames[j]]],na.rm=TRUE),length=10000)
        genArgs[[2]] <- j
        
        for(state in iStates[[i]]){
          
          genArgs[[3]] <- par[[i]][seq(state,nbStates*ndim,nbStates)]
          sig <- matrix(0,ndim,ndim)
          lowertri <- par[[i]][nbStates*ndim+seq(state,sum(lower.tri(matrix(0,ndim,ndim),diag=TRUE))*nbStates,nbStates)]
          sig[lower.tri(sig, diag=TRUE)] <- lowertri
          sig <- t(sig)
          sig[lower.tri(sig, diag=TRUE)] <- lowertri
          genArgs[[4]] <- sig

          genDensities[[state]] <- cbind(genArgs[[1]],w[[i]][state]*do.call("dtmvnorm.marginal",genArgs))
        }

        if(sepAnimals) {
          
          # loop over the animals
          for(zoo in 1:nbAnimals) {
            if(sepStates) {
              
              # loop over the states
              for(state in iStates[[i]]) {
                gen <- m$data[which(states==state & m$data$ID==ID[zoo]),datNames[j]]
                message <- paste0("ID ",ID[zoo]," - ",names(iStates[[i]])[match(state,iStates[[i]])],covmess)

                # the function plotHist is defined below
                plotHist(gen,genDensities,tmpdist[datNames[j]],message,sepStates,breaks,state,hist.ylim[[i]],col,names(iStates[[i]]), cumul = cumul, iStates[[i]], ...)
              }
              
            } else { # if !sepStates
              gen <- m$data[which(m$data$ID==ID[zoo]),datNames[j]]
              message <- paste0("ID ",ID[zoo],covmess)
              
              plotHist(gen,genDensities,tmpdist[datNames[j]],message,sepStates,breaks,NULL,hist.ylim[[i]],col,names(iStates[[i]]), cumul = cumul, iStates[[i]], ...)
            }
          }
        } else { # if !sepAnimals
          if(sepStates) {
            
            # loop over the states
            for(state in iStates[[i]]) {
              gen <- m$data[which(states==state),datNames[j]]
              if(nbAnimals>1) message <- paste0("All animals - ",names(iStates[[i]])[match(state,iStates[[i]])],covmess)
              else message <- paste0("ID ",ID," - ",names(iStates[[i]])[match(state,iStates[[i]])],covmess)
              
              plotHist(gen,genDensities,tmpdist[datNames[j]],message,sepStates,breaks,state,hist.ylim[[i]],col,names(iStates[[i]]), cumul = cumul, iStates[[i]], ...)
            }
            
          } else { # if !sepStates
            gen <- m$data[,datNames[j]]
            if(nbAnimals>1) message <- paste0("All animals",covmess)
            else message <- paste0("ID ",ID,covmess)
            
            plotHist(gen,genDensities,tmpdist[datNames[j]],message,sepStates,breaks,NULL,hist.ylim[[i]],col,names(iStates[[i]]), cumul = cumul, iStates[[i]], ...)
          }
        }
      }
      #message("Plotting multivariate normal histograms and densities in 3D is hard -- you're on your own!")
    }
  }
  
  ##################################################
  ## Plot the t.p. as functions of the covariates ##
  ##################################################
  if(nbStates>1) {
    par(mar=c(5,4,4,2)-c(0,0,1.5,1)) # bottom, left, top, right
    
    gamInd<-(length(m$mod$estimate)-(nbCovs+1)*nbStates*(nbStates-1)*mixtures+1):(length(m$mod$estimate))-(ncol(m$covsPi)*(mixtures-1))-ifelse(nbRecovs,nbRecovs+1+nbG0covs+1,0)-ncol(m$covsDelta)*(nbStates-1)*(!m$conditions$stationary)*mixtures
    quantSup<-qnorm(1-(1-alpha)/2)
    
    if(nbCovs>0) {
      
      # values of each covariate
      #rawCovs <- m$rawCovs[which(m$data$ID %in% ID),,drop=FALSE]
      #if(is.null(covs)) {
      #  rawCovs <- m$rawCovs
      #  meanCovs <- colSums(rawCovs)/nrow(rawCovs)
      #} else {
      #  rawCovs <- m$data[,names(covs),drop=FALSE]
      #  meanCovs <- as.numeric(covs)
      #}
      
      if(inherits(m,"hierarchical")) {
        covIndex <- which(!(names(rawCovs)=="level"))
        covs$level <- NULL
        covs <- data.frame(covs[rep(1:nrow(covs),nlevels(m$data$level)),,drop=FALSE],level=rep(levels(m$data$level),each=nrow(covs)))
      } else covIndex <- 1:ncol(rawCovs)
      
      for(cov in covIndex) {
        
        if(!is.factor(rawCovs[,cov])){
          
          gridLength <- 101
          hGridLength <- gridLength*ifelse(inherits(m,"hierarchical"),nlevels(m$data$level),1)
          
          inf <- min(rawCovs[,cov],na.rm=TRUE)
          sup <- max(rawCovs[,cov],na.rm=TRUE)
          
          # set all covariates to their mean, except for "cov"
          # (which takes a grid of values from inf to sup)
          tempCovs <- data.frame(matrix(covs[names(rawCovs)][[1]],nrow=hGridLength,ncol=1))
          if(ncol(rawCovs)>1)
            for(i in 2:ncol(rawCovs))
              tempCovs <- cbind(tempCovs,rep(covs[names(rawCovs)][[i]],gridLength))
          
          tempCovs[,cov] <- rep(seq(inf,sup,length=gridLength),each=hGridLength/gridLength)
        } else {
          gridLength<- nlevels(rawCovs[,cov])
          hGridLength <- gridLength*ifelse(inherits(m,"hierarchical"),nlevels(m$data$level),1)
          # set all covariates to their mean, except for "cov"
          tempCovs <- data.frame(matrix(covs[names(rawCovs)][[1]],nrow=hGridLength,ncol=1))
          if(ncol(rawCovs)>1)
            for(i in 2:ncol(rawCovs))
              tempCovs <- cbind(tempCovs,rep(covs[names(rawCovs)][[i]],gridLength))
          
          tempCovs[,cov] <- as.factor(rep(levels(rawCovs[,cov]),each=hGridLength/gridLength))
        }
        
        names(tempCovs) <- names(rawCovs)
        tmpcovs<-covs[names(rawCovs)]
        for(i in which(unlist(lapply(rawCovs,is.factor)))){
          tempCovs[[i]] <- factor(tempCovs[[i]],levels=levels(rawCovs[,i]))
          tmpcovs[i] <- as.character(tmpcovs[[i]])
        }
        for(i in which(!unlist(lapply(rawCovs,is.factor)))){
          tmpcovs[i]<-round(covs[names(rawCovs)][i],2)
        }
        if(!is.null(recharge)){
          tmprecovs<-covs[names(m$reCovs)]
          for(i in which(unlist(lapply(m$reCovs,is.factor)))){
            tmprecovs[i] <- as.character(tmprecovs[[i]])
          }
          for(i in which(!unlist(lapply(m$reCovs,is.factor)))){
            tmprecovs[i]<-round(recovs[names(m$reCovs)][i],2)
          }
        }
        
        if(inherits(m$data,"hierarchical")) class(tempCovs) <- append("hierarchical",class(tempCovs))
        
        tmpSplineInputs<-getSplineFormula(newformula,m$data,tempCovs)
        desMat <- stats::model.matrix(tmpSplineInputs$formula,data=tmpSplineInputs$covs)
        
        for(mix in 1:mixtures){
          
          if(is.null(recharge)){
            trMat <- trMatrix_rcpp(nbStates,beta$beta[(mix-1)*(nbCovs+1)+1:(nbCovs+1),,drop=FALSE],desMat,m$conditions$betaRef)
          } else {
            trMat <- array(unlist(lapply(split(tmpSplineInputs$covs,1:nrow(desMat)),function(x) tryCatch(get_gamma_recharge(m$mod$estimate[c(gamInd[unique(c(m$conditions$betaCons))],length(m$mod$estimate)-nbRecovs:0)],covs=x,formula=tmpSplineInputs$formula,hierRecharge=hierRecharge,nbStates=nbStates,betaRef=m$conditions$betaRef,betaCons=m$conditions$betaCons,workBounds=rbind(m$conditions$workBounds$beta,m$conditions$workBounds$theta),mixture = mix),error=function(e) NA))),dim=c(nbStates,nbStates,nrow(desMat)))
          }
          
          if(!inherits(m,"hierarchical")){
            
            plotTPM(nbStates,cov,ref=1:nbStates,tempCovs,trMat,rawCovs,lwd,arg,plotCI,Sigma,gamInd,m,desMat,nbRecovs,tmpSplineInputs$formula,hierRecharge,mix,muffWarn,quantSup,tmpSplineInputs$covs,stateNames=1:nbStates)
            
            txt <- paste(names(rawCovs)[-cov],"=",tmpcovs[-cov],collapse=", ")
            if(nbRecovs & names(rawCovs)[cov]=="recharge"){
              tmpNames <- c(names(rawCovs)[-cov],colnames(m$reCovs))
              txt <- paste(tmpNames[!duplicated(tmpNames)],"=",c(tmpcovs[-cov],tmprecovs)[!duplicated(tmpNames)],collapse=", ")
            }
            if(ncol(rawCovs)>1 | nbRecovs) do.call(mtext,c(list(paste0(ifelse(mixtures>1,paste0("Mixture ",mix," t"),"T"),"ransition probabilities",ifelse(nbRecovs," at next time step: ",": "),txt),side=3,outer=TRUE,padj=2,cex=cex.main),marg))
            else do.call(mtext,c(list(paste0(ifelse(mixtures>1,paste0("Mixture ",mix," t"),"T"),"ransition probabilities"),side=3,outer=TRUE,padj=2,cex=cex.main),marg))
            
          } else {
            for(j in 1:(m$conditions$hierStates$height-1)){
              
              txt <- paste(names(rawCovs)[-cov][which(names(rawCovs)[-cov]!="level")],"=",tmpcovs[which(tmpcovs$level==j),-cov][which(names(rawCovs)[-cov]!="level")],collapse=", ")
              if(nbRecovs & grepl("recharge",names(rawCovs)[cov])){
                tmpNames <- c(names(rawCovs)[-cov][which(names(rawCovs)[-cov]!="level" & !grepl("recharge",names(rawCovs)[-cov]))],colnames(m$reCovs)[which(colnames(m$reCovs)!="level" & !grepl("recharge",colnames(m$reCovs)))])
                tmprecovs <- tmprecovs[,which(colnames(tmprecovs)!="level" & !grepl("recharge",colnames(tmprecovs))),drop=FALSE]
                txt <- paste(tmpNames[!duplicated(tmpNames)],"=",c(tmpcovs[which(tmpcovs$level==j),-cov][which(names(rawCovs)[-cov]!="level" & !grepl("recharge",names(rawCovs)[-cov]))],tmprecovs)[!duplicated(tmpNames)],collapse=", ")
              }
              
              if(j==1) {

                ref <- m$conditions$hierStates$Get(function(x) data.tree::Aggregate(x,"state",min),filterFun=function(x) x$level==j+1)
                
                # only plot if there is variation in stationary state proabilities
                if(!all(apply(trMat[ref,ref,which(tempCovs$level==j)],1:2,function(x) all( abs(x - mean(x)) < 1.e-6 )))){
                  
                  plotTPM(nbStates,cov,ref,tempCovs[which(tempCovs$level==j),],trMat[,,which(tempCovs$level==j)],rawCovs,lwd,arg,plotCI,Sigma,gamInd,m,desMat[which(tempCovs$level==j),],nbRecovs,tmpSplineInputs$formula,hierRecharge,mix,muffWarn,quantSup,tmpSplineInputs$covs[which(tempCovs$level==j),],stateNames=names(ref))
     
                  if(ncol(rawCovs[-cov])>1 | nbRecovs) do.call(mtext,c(list(paste0(ifelse(mixtures>1,paste0("Mixture ",mix," t"),"T"),"ransition probabilities for level",j,ifelse(nbRecovs," at next time step: ",": "),txt),side=3,outer=TRUE,padj=2,cex=cex.main),marg))
                  else do.call(mtext,c(list(paste0(ifelse(mixtures>1,paste0("Mixture ",mix," t"),"T"),"ransition probabilities for level",j),side=3,outer=TRUE,padj=2,cex=cex.main),marg))
                
                  #if(length(covnames[-cov])>1) do.call(mtext,c(list(paste0(ifelse(mixtures>1,paste0("Mixture ",mix," s"),"S"),"tationary state probabilities for level",j,": ",paste(covnames[-cov][which(covnames[-cov]!="level")]," = ",tmpcovs[which(tmpcovs$level==j),-cov][which(covnames[-cov]!="level")],collapse=", ")),side=3,outer=TRUE,padj=2,cex=cex.main),marg))
                  #else do.call(mtext,c(list(paste0(ifelse(mixtures>1,paste0("Mixture ",mix," s"),"S"),"tationary state probabilities for level",j),side=3,outer=TRUE,padj=2,cex=cex.main),marg))
                } 
              } else {
                t <- data.tree::Traverse(m$conditions$hierStates,filterFun=function(x) x$level==j)
                names(t) <- m$conditions$hierStates$Get("name",filterFun=function(x) x$level==j)
                for(k in names(t)){
                  ref <- t[[k]]$Get(function(x) data.tree::Aggregate(x,"state",min),filterFun=function(x) x$level==j+1)#t[[k]]$Get("state",filterFun = data.tree::isLeaf)
                  # only plot if jth node has children and there is variation in stationary state proabilities
                  if(!is.null(ref) && !all(apply(trMat[ref,ref,which(tempCovs$level==j)],1:2,function(x) all( abs(x - mean(x)) < 1.e-6 )))){
                    
                    plotTPM(nbStates,cov,ref,tempCovs[which(tempCovs$level==j),],trMat[,,which(tempCovs$level==j)],rawCovs,lwd,arg,plotCI,Sigma,gamInd,m,desMat[which(tempCovs$level==j),],nbRecovs,tmpSplineInputs$formula,hierRecharge,mix,muffWarn,quantSup,tmpSplineInputs$covs[which(tempCovs$level==j),],stateNames=names(ref))
                    
                    if(ncol(rawCovs[-cov])>1 | nbRecovs) do.call(mtext,c(list(paste0(ifelse(mixtures>1,paste0("Mixture ",mix," t"),"T"),"ransition probabilities for level",j," ",k,ifelse(nbRecovs," at next time step: ",": "),txt),side=3,outer=TRUE,padj=2,cex=cex.main),marg))
                    else do.call(mtext,c(list(paste0(ifelse(mixtures>1,paste0("Mixture ",mix," t"),"T"),"ransition probabilities for level",j," ",k),side=3,outer=TRUE,padj=2,cex=cex.main),marg))
                    
                  }
                }
              }
            }
          }
        }
        
        if(plotStationary) {
          par(mfrow=c(1,1))
          if(inherits(m,"hierarchical")){
            if(is.null(recharge)){
              tmpSplineInputs$covs <- tempCovs
            } else {
              tmpSplineInputs$covs <- tmpSplineInputs$covs[which(tmpSplineInputs$covs$level==levels(m$data$level)[1]),]
            }
          }
          statPlot(m,Sigma,nbStates,tmpSplineInputs$formula,tmpSplineInputs$covs,tempCovs,tmpcovs,cov,hierRecharge,alpha,gridLength,gamInd,names(rawCovs),col,plotCI,...)
        }
      }
    }
  }
  
  #################################
  ## Plot maps colored by states ##
  #################################

  if(all(coordNames %in% names(m$data)) | nbRecovs){
    
    if(nbStates>1) { # no need to plot the map if only one state
      par(mfrow=c(1,1))
      par(mar=c(5,4,4,2)-c(0,0,2,1)) # bottom, left, top, right
      
      #if(inherits(m,"hierarchical") & all(coordNames %in% names(m$data))){
      #  hViterbi <- hierViterbi(m, states, stateNames=FALSE)
      #  cat("DONE\n")
      #  for(j in 1:(m$conditions$hierStates$height-1)){
      #  }
      #}
      
      for(zoo in 1:nbAnimals) {
        # states for animal 'zoo'
        subStates <- states[which(m$data$ID==ID[zoo])]
        
        if(nbRecovs){
          #par(mfrow=c(1,1))
          ind <- which(m$data$ID==ID[zoo])
          
          for(j in 1:length(rechargeNames)){
            if(plotCI){
              irecovs <- stats::model.matrix(recharge$theta,m$data[ind,])
              ig0covs <- stats::model.matrix(recharge$g0,m$data[ind,])
              rechargeSigma <- Sigma[length(m$mod$estimate)-(nbRecovs+nbG0covs+1):0,length(m$mod$estimate)-(nbRecovs+nbG0covs+1):0]
              dN<-t(mapply(function(x) tryCatch(numDeriv::grad(get_recharge,m$mod$estimate[length(m$mod$estimate)-(nbRecovs+nbG0covs+1):0],recovs=irecovs,g0covs=ig0covs,recharge=recharge,hierRecharge=hierRecharge,rechargeName=rechargeNames[j],covs=m$data[ind,],workBounds=m$conditions$workBounds,k=x),error=function(e) NA),1:length(ind)))
              if(any(!is.finite(sqrt(diag(rechargeSigma))))) se <-NA
              else se<-t(apply(dN,1,function(x) tryCatch(suppressWarnings(sqrt(x%*%rechargeSigma%*%x)),error=function(e) NA)))
              lci<-m$data[[rechargeNames[j]]][ind]-quantSup*se
              uci<-m$data[[rechargeNames[j]]][ind]+quantSup*se
              
              if(!all(is.na(se))) reylim <- c(min(lci,na.rm=TRUE),max(uci,na.rm=TRUE))
              else reylim <- NULL
              do.call(plot,c(list(x=1:length(ind),y=m$data[[rechargeNames[j]]][ind],pch=16,xlab="t",ylab="g(t)",col=col[subStates],cex=cex,ylim=reylim),arg))
              do.call(segments,c(list(y0=m$data[[rechargeNames[j]]][ind][-length(ind)],x0=1:(length(ind)-1),y1=m$data[[rechargeNames[j]]][ind][-1],x1=2:length(ind),
                                      col=col[subStates[-length(subStates)]],lwd=lwd),arg))
              if(!all(is.na(se))) {
                ciInd <- which(!is.na(se))
                
                withCallingHandlers(do.call(arrows,c(list((1:length(ind))[ciInd], lci[ciInd], (1:length(ind))[ciInd], 
                                           uci[ciInd], length=0.025, angle=90, code=3, col=col[subStates], lwd=lwd),arg)),warning=muffWarn)
              }
            } else {
              do.call(plot,c(list(x=1:length(ind),y=m$data[[rechargeNames[j]]][ind],pch=16,xlab="t",ylab="g(t)",col=col[subStates],cex=cex),arg))
              do.call(segments,c(list(y0=m$data[[rechargeNames[j]]][ind][-length(ind)],x0=1:(length(ind)-1),y1=m$data[[rechargeNames[j]]][ind][-1],x1=2:length(ind),
                                      col=col[subStates[-length(subStates)]],lwd=lwd),arg))
            }
            do.call(mtext,c(list(paste("ID",ID[zoo],"recharge function",ifelse(rechargeNames[j]=="recharge","",paste0("for level",gsub("recharge","",rechargeNames[j])))),side=3,outer=TRUE,padj=2,cex=cex.main),marg))
            legend(ifelse(is.null(legend.pos),"topleft",legend.pos),legText,lwd=rep(lwd,nbStates),col=col,bty="n",cex=cex.legend)
            abline(h=0,lty=2)
          }
        }
        
        if(all(coordNames %in% names(m$data))){
          if(plotTracks){
            # plot trajectory
            remNA <- which(!is.na(x[[zoo]]))
            x[[zoo]] <- x[[zoo]][remNA]
            y[[zoo]] <- y[[zoo]][remNA]
            subStates <- subStates[remNA]
            if(length(coordNames)==2){
              do.call(plot,c(list(x=x[[zoo]],y=y[[zoo]],pch=16,xlab=coordNames[1],ylab=coordNames[2],col=col[subStates],cex=cex,asp=asp),arg))
              
              do.call(segments,c(list(x0=x[[zoo]][-length(x[[zoo]])],y0=y[[zoo]][-length(x[[zoo]])],x1=x[[zoo]][-1],y1=y[[zoo]][-1],
                       col=col[subStates[-length(subStates)]],lwd=lwd),arg))
              
              if(plotEllipse) {
                for(i in 1:length(x[[zoo]]))
                  do.call(lines,c(list(errorEllipse[[zoo]][[i]],col=adjustcolor(col[subStates[i]],alpha.f=0.25),cex=cex),arg))
              }
            } else {
              
              if (!requireNamespace("scatterplot3d", quietly = TRUE)) {
                warning("Package \"scatterplot3d\" needed to plot tracks. Please install it.",
                       call. = FALSE)
              }
              
              z[[zoo]] <- z[[zoo]][remNA]
              ## interactive 3d plot
              #do.call(plotly::plot_ly,list(x=x[[zoo]],y=y[[zoo]],z=z[[zoo]],type='scatter3d',mode='lines',
              #                             opacity = 1, line = list(width = 6, color = col[subStates])))  %>% plotly::layout(title=paste("ID",ID[zoo]))
              plot3d <- scatterplot3d::scatterplot3d(x = x[[zoo]], y = y[[zoo]], z = z[[zoo]],color=col[subStates],type="o",pch=20,xlab=coordNames[1],ylab=coordNames[2],zlab=coordNames[3],box=FALSE)
              plot2d <- plot3d$xyz.convert(x = x[[zoo]], y = y[[zoo]], z = z[[zoo]])
              do.call(segments,c(list(x0=plot2d$x[-length(plot2d$x)],y0=plot2d$y[-length(plot2d$y)],x1=plot2d$x[-1],y1=plot2d$y[-1],
                                      col=col[subStates[-length(subStates)]],lwd=lwd),arg))
            }  
            do.call(mtext,c(list(paste("ID",ID[zoo]),side=3,outer=TRUE,padj=2,cex=cex.main),marg))
            legend(ifelse(is.null(legend.pos),"topleft",legend.pos),legText,lwd=rep(lwd,nbStates),col=col,bty="n",cex=cex.legend)
          }
        }
      }
    }
  }
  
  # set the graphical parameters back to default
  par(mfrow=c(1,1))
  par(mar=c(5,4,4,2)) # bottom, left, top, right
  par(ask=FALSE)
}

# Plot histograms
#
# Plot the histogram of one (univariate) data stream and overlay the fitted
# state-dependent densities. This function is only used in plot.momentuHMM.
#
# Parameters:
#  - gen: numeric vector of observations of the data stream (for one animal
#    and/or one state, or pooled, depending on the caller).
#  - genDensities: list of matrices of values of the fitted densities, indexed by
#    state. Each matrix has two columns, the first being the grid of values on
#    which the density is evaluated, and the second the values of the density.
#  - dist: named list indicating the probability distribution for the data stream
#    (e.g. list(step="gamma"), list(angle="vm")); its name is used as axis label.
#  - message: message to print above the histogram.
#  - sepStates, breaks, hist.ylim, cumul: see arguments of plot.momentuHMM.
#  - state: if sepStates, the state whose density is plotted.
#  - col: colors of the state-dependent density lines, indexed by state.
#  - legText: legend labels, one per element of iStates.
#  - iStates: states (indices into genDensities and col) to plot.
#  - ...: graphical parameters; lwd, cex.main, cex.legend and legend.pos are used
#    here, cex and asp are dropped, and anything else is passed on to hist, axis
#    and mtext.
#
# Nothing is drawn when hist() cannot bin 'gen' (e.g. when it holds no finite
# value because a state was never decoded for an animal).

plotHist <- function (gen,genDensities,dist,message,
                      sepStates,breaks="Sturges",state=NULL,hist.ylim=NULL,col=NULL,legText, cumul=TRUE, iStates, ...)
{
  # vertical limits; ymax is computed from the data below when hist.ylim is NULL
  if(!is.null(hist.ylim)) {
    ymin <- hist.ylim[1]
    ymax <- hist.ylim[2]
  } else {
    ymin <- 0
    ymax <- NA
  }
  
  nbStates <- length(iStates)
  lty <- rep(1, nbStates)
  if(!sepStates && cumul) {
    # extra legend entry for the (dashed, black) sum of the state densities
    legText <- c(legText, "Total")
    col <- c(col, "black")
    lty <- c(lty, 2)
  }
  
  distname <- names(dist)
  
  # graphical parameters consumed here; whatever remains in 'arg' is forwarded
  arg <- list(...)
  lwd <- if(!is.null(arg[["lwd"]])) arg[["lwd"]] else 2
  cex.main <- if(!is.null(arg[["cex.main"]])) arg[["cex.main"]] else NA
  cex.legend <- if(!is.null(arg[["cex.legend"]])) arg[["cex.legend"]] else 1
  legend.pos <- arg[["legend.pos"]]
  arg$cex <- NULL
  arg$asp <- NULL
  arg$lwd <- NULL
  arg$cex.main <- NULL
  arg$cex.legend <- NULL
  arg$legend.pos <- NULL
  
  # Default upper y-limit: 1.3 times the histogram peak, raised to the peak of
  # the fitted densities when that is higher. For a single state this only
  # happens if the density peak is below twice the histogram peak; for all
  # states the limit is capped at twice the histogram peak.
  defaultYmax <- function(h) {
    histMax <- max(h$density)
    ylimMax <- 1.3*histMax
    if(sepStates) {
      maxdens <- max(genDensities[[state]][,2])
      if(maxdens>ylimMax && maxdens<2*histMax) ylimMax <- maxdens
    } else {
      maxdens <- max(genDensities[[iStates[1]]][,2])
      for(s in iStates[-1]) {
        sdens <- max(genDensities[[s]][,2])
        # non-finite density peaks of the other states are skipped
        if(is.finite(sdens) && sdens>maxdens) maxdens <- sdens
      }
      if(maxdens>ylimMax) ylimMax <- min(maxdens, 2*histMax)
    }
    ylimMax
  }
  
  # Draw the fitted density (or densities, optional total, and legend)
  # on top of the current histogram.
  overlayDensities <- function() {
    if(sepStates) {
      lines(genDensities[[state]],col=col[state],lwd=lwd)
      return(invisible(NULL))
    }
    for(s in iStates) lines(genDensities[[s]],col=col[s],lwd=lwd)
    if(cumul) {
      total <- genDensities[[iStates[1]]]
      for(s in iStates[-1]) total[,2] <- total[,2] + genDensities[[s]][,2]
      lines(total,lwd=lwd,lty=2)
    }
    # lty is passed so the "Total" entry is dashed like the line it labels
    legend(if(is.null(legend.pos)) "topright" else legend.pos,legText,
           lwd=lwd,lty=lty,col=col[c(iStates,if(cumul) length(col))],
           bty="n",cex=cex.legend)
  }
  
  # A first histogram gives the number of bins (angles) or the bins themselves
  # (other streams). If hist() cannot bin 'gen', nothing is drawn, for
  # angular streams as well as for the others.
  h <- tryCatch(hist(gen,plot=FALSE,breaks=breaks),error=function(e) e)
  if(inherits(h,"error")) return(invisible(NULL))
  
  if(dist %in% angledists){
    # same number of bins, but spread evenly over [-pi, pi]
    breaks <- seq(-pi,pi,length.out=length(h$breaks))
    if(is.null(hist.ylim)) ymax <- defaultYmax(hist(gen,plot=FALSE,breaks=breaks))
    
    # the user-supplied lower limit (ymin) is honoured here too
    do.call(hist,c(list(gen,probability=TRUE,main="",ylim=c(ymin,ymax),xlab=paste0(distname," (radians)"),
                        col="grey",border="white",breaks=breaks,xaxt="n"),arg))
    do.call(axis,c(list(1, at = c(-pi, -pi/2, 0, pi/2, pi),
                        labels = expression(-pi, -pi/2, 0, pi/2, pi)),arg))
  } else {
    if(is.null(hist.ylim)) ymax <- defaultYmax(h)
    
    do.call(hist,c(list(gen,probability=TRUE,main="",ylim=c(ymin,ymax),xlab=distname,
                        col="grey",border="white",breaks=breaks),arg))
  }
  
  do.call(mtext,c(list(message,side=3,outer=TRUE,padj=2,cex=cex.main),arg))
  
  overlayDensities()
}

# Plot bivariate normal histograms
#
# For a 2-dimensional ("mvnorm2"/"rw_mvnorm2") data stream, draw a 2D kernel
# density image of the observations with the fitted state-dependent densities
# as contours, the marginal histograms (as bar plots) above and to the right,
# and a legend. This function is only used in plot.momentuHMM.
#
# Parameters:
#  - gen: data frame holding the columns "<name>.x" and "<name>.y", where <name>
#    is names(dist).
#  - genDensities: list indexed by state of objects accepted by contour(); the
#    total is formed by summing their $z components.
#  - dist: named list giving the distribution of the data stream.
#  - message: message printed above the plot.
#  - sepStates, breaks, cumul: see arguments of plot.momentuHMM.
#  - state: if sepStates, the state whose density is plotted.
#  - col: colors of the state-dependent contours, indexed by state.
#  - legText: legend labels, one per element of iStates.
#  - iStates: states (indices into genDensities and col) to plot.
#  - ...: graphical parameters; lwd, cex.main, cex.legend and legend.pos are used
#    here, cex and asp are dropped, and anything else is passed on to mtext.
#
# Note: this function changes par(mar) and the layout; the caller resets them.
plotHistMVN <- function(gen,genDensities,dist,message,sepStates,breaks="Sturges",state=NULL,col=NULL,legText, cumul=TRUE, iStates, ...){
  
  nbStates <- length(iStates)
  lty <- rep(1, nbStates)
  if(!sepStates && cumul) {
    # extra legend entry for the (dashed, black) sum of the state densities
    legText <- c(legText, "Total")
    col <- c(col, "black")
    lty <- c(lty, 2)
  }
  
  distname <- names(dist)
  
  # graphical parameters consumed here; whatever remains in 'arg' is forwarded
  arg <- list(...)
  lwd <- if(!is.null(arg[["lwd"]])) arg[["lwd"]] else 2
  cex.main <- if(!is.null(arg[["cex.main"]])) arg[["cex.main"]] else NA
  cex.legend <- if(!is.null(arg[["cex.legend"]])) arg[["cex.legend"]] else 1
  legend.pos <- arg[["legend.pos"]]
  arg$cex <- NULL
  arg$asp <- NULL
  arg$lwd <- NULL
  arg$cex.main <- NULL
  arg$cex.legend <- NULL
  arg$legend.pos <- NULL
  
  if(dist[[distname]]=="mvnorm2" || dist[[distname]]=="rw_mvnorm2"){
    xName <- paste0(distname,".x")
    yName <- paste0(distname,".y")
    xObs <- gen[[xName]]
    yObs <- gen[[yName]]
    
    # kde2d needs paired coordinates: keep only rows where both are observed
    # (dropping NAs separately could mis-pair or mis-size x and y)
    paired <- !is.na(xObs) & !is.na(yObs)
    # with fewer than two paired observations there is nothing to smooth;
    # draw nothing (and leave the layout untouched)
    if(sum(paired)<2) return(invisible(NULL))
    
    par(mar=c(4,4,1,1)) # bottom, left, top, right
    graphics::layout(matrix(c(2,4,1,3),2,2,byrow=TRUE),c(3,1), c(1,3))
    h1 <- hist(xObs, breaks=breaks, plot=FALSE)
    h2 <- hist(yObs, breaks=breaks, plot=FALSE)
    top <- max(h1$counts, h2$counts)
    k <- MASS::kde2d(xObs[paired], yObs[paired], n=25)
    rf <- grDevices::colorRampPalette(c("#5E4FA2", "#3288BD", "#66C2A5", "#ABDDA4", "#E6F598", "#FFFFBF", "#FEE08B", "#FDAE61", "#F46D43", "#D53E4F", "#9E0142")) # grDevices::colorRampPalette(rev(RColorBrewer::brewer.pal(11,'Spectral')))
    r <- rf(32)
    graphics::image(k, col=r,xlab=xName,ylab=yName,cex.lab=1.1) #plot the image
    
    # plot fitted density contours over the empirical density image
    if(sepStates) {
      graphics::contour(genDensities[[state]],add=TRUE,col=col[state])
    } else {
      for(s in iStates)
        graphics::contour(genDensities[[s]],add=TRUE,col=col[s])
      if(cumul){
        total <- genDensities[[iStates[1]]]
        for (s in iStates[-1]) total$z <- total$z + genDensities[[s]]$z
        graphics::contour(total,add=TRUE,lty=2)
      }
    }
    
    # marginal histograms: x on top, y on the right
    par(mar=c(0,3,1,0))
    barplot(h1$counts, axes=FALSE, ylim=c(0, top), space=0,col="grey",border="white")
    do.call(mtext,c(list(message,side=3,outer=TRUE,padj=2,cex=cex.main),arg))
    par(mar=c(3,0,0.5,1))
    barplot(h2$counts, axes=FALSE, xlim=c(0, top), space=0, horiz=TRUE,col="grey",border="white")
    
    # legend in the remaining (bottom-right) panel; lty makes "Total" dashed
    plot.new()
    if(!sepStates) legend(if(is.null(legend.pos)) "topright" else legend.pos,legText,
                          lwd=lwd,lty=lty,col=col[c(iStates,if(cumul) length(col))],
                          bty="n",cex=cex.legend)
  }
}

# Plot transition probabilities against one covariate as a length(ref) x
# length(ref) grid of panels (row = "from" state, column = "to" state),
# optionally adding delta-method confidence intervals built on the logit scale.
#
# nbStates     number of states (passed to get_gamma / get_gamma_recharge)
# cov          column index of the covariate in tempCovs (x values) and
#              rawCovs (x-axis label)
# ref          state indices shown in the grid (default: all states)
# tempCovs     covariate values; tempCovs[, cov] is the x-axis
# trMat        array of transition probabilities indexed [from, to, row of tempCovs]
# rawCovs      only names(rawCovs)[cov] is used, as the x-axis label
# lwd, arg     line width and extra arguments passed to plot() and arrows()
# plotCI       if TRUE draw confidence intervals; the arguments below are only
#              used in that case
# Sigma        covariance matrix indexed like m$mod$estimate
# gamInd       positions of the transition-probability working parameters
# m            fitted model (m$mod$estimate and m$conditions are read)
# desMat       design matrix, one row per covariate value
# nbRecovs     recharge parameters are the last nbRecovs+1 entries of
#              m$mod$estimate (recharge models only)
# formula, hierRecharge, mix  passed through to get_gamma_recharge / get_gamma;
#              hierRecharge non-NULL selects the recharge computation
# muffWarn     warning handler wrapped around arrows()
# quantSup     normal quantile giving the CI half-width on the logit scale
# covs         covariate data split by row for the recharge computation
# stateNames   panel labels, indexed by grid position (i, j), not by ref[i]
#
# NOTE: sets par(mfrow) without restoring it.
plotTPM <- function(nbStates,cov,ref=1:nbStates,tempCovs,trMat,rawCovs,lwd,arg,plotCI,Sigma,gamInd,m,desMat,nbRecovs,formula,hierRecharge,mix,muffWarn,quantSup,covs,stateNames){

  par(mfrow=c(length(ref),length(ref)))

  if(plotCI){
    # None of this depends on the panel, so it is computed once rather than
    # once per (i, j) pair.
    parInd <- gamInd[unique(c(m$conditions$betaCons))]
    if(!is.null(hierRecharge)){
      # append the recharge parameters; `:` binds tighter than `-`, so this is
      # length - nbRecovs, ..., length
      parInd <- c(parInd,length(m$mod$estimate)-nbRecovs:0)
      workBounds <- rbind(m$conditions$workBounds$beta,m$conditions$workBounds$theta)
    } else {
      workBounds <- m$conditions$workBounds$beta
    }
    tmpSig <- Sigma[parInd,parInd]
    parEst <- m$mod$estimate[parInd]
  }

  for(i in seq_along(ref)){
    for(j in seq_along(ref)){
      probs <- trMat[ref[i],ref[j],]
      do.call(plot,c(list(tempCovs[,cov],probs,type="l",ylim=c(0,1),xlab=names(rawCovs)[cov],ylab=paste(stateNames[i],"->",stateNames[j]),lwd=lwd),arg))
      if(plotCI){
        # gradient of gamma[ref[i],ref[j]] w.r.t. the working parameters,
        # one row per covariate value (NA where the derivative fails)
        if(!is.null(hierRecharge)){
          dN<-matrix(unlist(lapply(split(covs,1:nrow(desMat)),function(x) tryCatch(numDeriv::grad(get_gamma_recharge,parEst,covs=x,formula=formula,hierRecharge=hierRecharge,nbStates=nbStates,i=ref[i],j=ref[j],betaRef=m$conditions$betaRef,betaCons=m$conditions$betaCons,workBounds=workBounds,mixture=mix),error=function(e) NA))),ncol=ncol(tmpSig),byrow=TRUE)
        } else {
          dN<-t(apply(desMat,1,function(x) tryCatch(numDeriv::grad(get_gamma,parEst,covs=matrix(x,1,dimnames=list(NULL,names(x))),nbStates=nbStates,i=ref[i],j=ref[j],betaRef=m$conditions$betaRef,betaCons=m$conditions$betaCons,workBounds=workBounds,mixture=mix),error=function(e) NA)))
        }
        # delta-method standard errors of the transition probability
        se<-t(apply(dN,1,function(x) tryCatch(suppressWarnings(sqrt(x%*%tmpSig%*%x)),error=function(e) NA)))
        if(!all(is.na(se))) {
          # move to the logit scale (d logit(p)/dp = 1/(p - p^2)), build a
          # symmetric interval there and map back so the bounds stay in (0, 1)
          logitP <- log(probs/(1-probs))
          halfWidth <- quantSup*(1/(probs-probs^2))*se
          lci<-1/(1+exp(-(logitP-halfWidth)))
          uci<-1/(1+exp(-(logitP+halfWidth)))

          ciInd <- which(!is.na(se))
          xCI <- as.numeric(tempCovs[ciInd,cov])

          withCallingHandlers(do.call(arrows,c(list(xCI, lci[ciInd], xCI,
                                                    uci[ciInd], length=0.025, angle=90, code=3, col=gray(.5), lwd=lwd),arg)),warning=muffWarn)
        }
      }
    }
  }
}

# Build the single-row covariate data frame used when plotting a fitted model.
#
# m          fitted model; m$data is read, plus m$conditions$recharge to decide
#            whether a "recharge" covariate is allowed
# covs       NULL, or a one-row data frame of user-chosen covariate values
# ID         animal IDs whose data supply the defaults when covs is NULL
# checkHier  if TRUE, covs$level is rejected for hierarchical models
#
# covs = NULL: start from the first row of m$data for the selected IDs
# (hierarchical models: the first non-NA value of every column). Then replace
# each column whose class is in meansList by its mean over those IDs; "angle"
# columns use the circular mean.
# Otherwise: validate covs and align ID/factor levels with m$data. Then fill
# each column of m$data missing from covs with its mean over all data (classes
# in meansList) or with its first (first non-NA for hierarchical) value.
getCovs <-function(m,covs,ID,checkHier=TRUE){
  if(is.null(covs)){
    idRows <- which(m$data$ID %in% ID)
    if(inherits(m,"hierarchical")) covs <- as.data.frame(lapply(m$data,function(x) x[which.max(!is.na(x))]))
    else covs <- m$data[idRows,][1,]
    for(j in names(m$data)[which(unlist(lapply(m$data,function(x) any(class(x) %in% meansList))))]){
      vals <- m$data[[j]][idRows]
      if(inherits(m$data[[j]],"angle")) covs[[j]] <- CircStats::circ.mean(vals[!is.na(vals)])
      else covs[[j]]<-mean(vals,na.rm=TRUE)
    }
  } else {
    if(!is.data.frame(covs)) stop('covs must be a data frame')
    if(nrow(covs)>1) stop('covs must consist of a single row')
    # "recharge" is an acceptable covariate name only for recharge models.
    # (The previous code tested an undefined variable `recharge`, and a
    # dangling `else` attached to the wrong `if`.)
    validNames <- names(m$data)
    if(!is.null(m$conditions$recharge)) validNames <- c(validNames,"recharge")
    if(!all(names(covs) %in% validNames)) stop('invalid covs specified')
    if(any(names(covs) %in% "ID")) covs$ID<-factor(covs$ID,levels=unique(m$data$ID))
    if(checkHier && inherits(m,"hierarchical") && any(names(covs) %in% "level")) stop("covs$level cannot be specified for hierarchical models")
    for(j in names(m$data)[which(names(m$data) %in% names(covs))]){
      if(inherits(m$data[[j]],"factor")) covs[[j]] <- factor(covs[[j]],levels=levels(m$data[[j]]))
      if(is.na(covs[[j]])) stop("check covs value for ",j)
    }
    for(j in names(m$data)[which(!(names(m$data) %in% names(covs)))]){
      if(any(class(m$data[[j]]) %in% meansList)){
        if(inherits(m$data[[j]],"angle")) covs[[j]] <- CircStats::circ.mean(m$data[[j]][!is.na(m$data[[j]])])
        else covs[[j]]<-mean(m$data[[j]],na.rm=TRUE)
      } else {
        # hierarchical data interleave levels, so use the first non-NA value
        if(inherits(m,"hierarchical")) covInd <- which.max(!is.na(m$data[[j]]))
        else covInd <- 1
        covs[[j]] <- m$data[[j]][covInd]
      }
    }
  }
  covs
}

# Get the 1-dimensional marginal density of component n of an (untruncated)
# multivariate normal distribution, evaluated at the points xn.
# Originally adapted from tmvtnorm::dtmvnorm.marginal. Without truncation both
# pmvnorm() factors of that formula equal 1 (integrals over the whole space),
# so it reduces exactly to the univariate normal N(mean[n], sigma[n, n]),
# which is computed directly here.
# xn    = Vector of quantiles to calculate the marginal density for.
# n     = Index position (1..k) within the random vector x to calculate the one-dimensional marginal density for.
# mean  = Mean vector, default is rep(0, length = nrow(sigma)).
# sigma = Covariance matrix, default is diag(length(mean)).
# log   = Logical; if TRUE, densities d are given as log(d).
# Returns a numeric vector the same length as xn (density 0 at +/-Inf).
#' @importFrom stats dnorm
#' @importFrom mvtnorm pmvnorm
dmvnorm.marginal <- function (xn, n = 1, mean = rep(0, nrow(sigma)), sigma = diag(length(mean)), log = FALSE) {

  if (NROW(sigma) != NCOL(sigma)) {
    stop("sigma must be a square matrix")
  }
  if (length(mean) != NROW(sigma)) {
    stop("mean and sigma have non-conforming size")
  }
  # type/length checks run first so that a non-scalar `n` produces this
  # message rather than a `||` length error (R >= 4.3)
  if (!is.numeric(n) || length(n) != 1L || !(n %in% seq_along(mean))) {
    stop("n must be an integer scalar in 1..length(mean)")
  }
  # as.matrix() also accepts a scalar variance when length(mean) == 1;
  # dnorm(log = ) stays accurate far in the tails, unlike log(dnorm())
  stats::dnorm(xn, mean = mean[n], sd = sqrt(as.matrix(sigma)[n, n]), log = log)
}

Try the momentuHMM package in your browser

Any scripts or data that you put into this service are public.

momentuHMM documentation built on Oct. 19, 2022, 1:07 a.m.