R/plotStates.R

Defines functions plotStates

Documented in plotStates

#' Plot states
#'
#' Plot the states and states probabilities.
#'
#' For each selected animal, \code{nbStates + 1} plots are stacked in a single column.
#' The first shows the decoded state sequence. It is followed by one time series of
#' probabilities per state, with a dashed grey line marking probability 0.5.
#' For hierarchical models, such a set of plots is drawn for each level of the state hierarchy.
#' For \code{momentuHMM} objects, the states and probabilities are computed with
#' \code{\link{viterbi}} and \code{\link{stateProbs}}. For \code{miHMM}/\code{miSum} objects,
#' the estimates stored in the object are used.
#' The graphical parameters \code{mfrow}, \code{mar} and \code{ask} are restored to their
#' previous values on exit, even if an error occurs while plotting.
#'
#' @param m A \code{\link{momentuHMM}}, \code{\link{momentuHierHMM}}, \code{\link{miHMM}}, or \code{\link{miSum}} object
#' @param animals Vector of indices or IDs (character or factor) of animals for which states will be plotted.
#' If \code{NULL} (the default), all animals are plotted.
#' @param ask If \code{TRUE}, the execution pauses between each plot.
#'
#' @return \code{NULL}, invisibly. The function is called for its side effect of producing plots.
#'
#' @examples
#' # m is a momentuHMM object (as returned by fitHMM), automatically loaded with the package
#' m <- example$m
#'
#' # plot states for first and second animals
#' plotStates(m,animals=c(1,2))
#'
#' @export

plotStates <- function(m,animals=NULL,ask=TRUE)
{
  if(!is.momentuHMM(m) && !is.miHMM(m) && !is.miSum(m))
    stop("'m' must be a momentuHMM, miHMM, or miSum object (as output by fitHMM, MIfitHMM, or MIpool)")

  # miHMM objects are plotted from their 'miSum' component
  if(is.miHMM(m)) m <- m$miSum

  hierarchical <- inherits(m,"hierarchical")
  if(hierarchical) installDataTree()

  IDs <- unique(m$data$ID) # computed once; reused for indexing and plot titles
  nbAnimals <- length(IDs)
  nbStates <- length(m$stateNames)

  if(nbStates==1)
    stop("Only one state.")

  # define animals to be plotted (validated before decoding so bad input fails fast)
  if(is.null(animals)) { # all animals are plotted
    animalsInd <- seq_len(nbAnimals)
  } else {
    if(is.factor(animals)) animals <- as.character(animals)
    if(is.character(animals)) { # animals' IDs provided
      animalsInd <- match(animals,IDs)
      if(anyNA(animalsInd)) # ID not found
        stop("Check animals argument.")
    } else if(is.numeric(animals)) { # animals' indices provided
      if(anyNA(animals) || any(animals<1 | animals>nbAnimals)) # index missing or out of bounds
        stop("Check animals argument.")
      animalsInd <- animals
    } else {
      # any other type would otherwise leave 'animalsInd' undefined
      stop("Check animals argument.")
    }
  }

  if(is.momentuHMM(m)){
    cat("Decoding state sequence... ")
    states <- viterbi(m, hierarchical = hierarchical)
    cat("DONE\n")
    cat("Computing state probabilities... ")
    sp <- stateProbs(m, hierarchical = hierarchical)
    cat("DONE\n")
  } else {
    states <- m$Par$states
    sp <- m$Par$stateProbs$est
    if(hierarchical){
      states <- hierViterbi(m, states)
      sp <- hierStateProbs(m, sp)
    }
  }

  # remember the graphical parameters modified below and restore them on exit,
  # including when an error occurs part-way through plotting
  oldPar <- par(c("mfrow","mar","ask"))
  on.exit(par(oldPar),add=TRUE)

  if(!hierarchical){

    par(mfrow=c(nbStates+1,1))
    par(ask=ask)

    for(zoo in animalsInd) {
      ind <- which(m$data$ID==IDs[zoo])

      # plot the states
      par(mar=c(5,4,4,2)-c(2,0,0,0))
      plot(states[ind],main=paste0("ID ",IDs[zoo]),ylim=c(0.5,nbStates+0.5),
           yaxt="n",xlab="",ylab="State")
      axis(side=2,at=seq_len(nbStates),labels=as.character(seq_len(nbStates)))

      # plot the states probabilities
      par(mar=c(5,4,4,2)-c(0,0,2,0))
      for(i in seq_len(nbStates)) {
        plot(sp[ind,i],type="l",xlab="Observation index",ylab=paste0("Pr(State=",i,")"))
        abline(h=0.5,lty=2,col="darkgrey")
      }
    }
  } else {
    hierStates <- m$conditions$hierStates
    nbLevels <- hierStates$height-1
    for(j in seq_len(nbLevels)){
      # state names for level j: the tree's leaves for the last level,
      # otherwise the nodes whose tree level is j+1
      if(j==nbLevels) ref <- hierStates$Get("name",filterFun=data.tree::isLeaf)
      else ref <- hierStates$Get("name",filterFun=function(x) x$level==j+1)
      nbStates <- length(ref)
      # IDs of the observations whose 'level' is j or paste0(j,"i")
      lData <- m$data[which(m$data$level %in% c(j,paste0(j,"i"))),"ID"]

      par(mfrow=c(nbStates+1,1))
      par(ask=ask)

      for(zoo in animalsInd) {
        ind <- which(lData==IDs[zoo])

        # plot the states (mapped to integer positions of 'ref')
        par(mar=c(5,6.5,4,2)-c(2,0,0,0))
        plot(seq_len(nbStates)[match(states[[paste0("level",j)]][ind],ref)],
             main=paste0("level",j,": ID ",IDs[zoo]),ylim=c(0.5,nbStates+0.5),
             yaxt="n",xlab="",ylab=NA)
        axis(side=2,at=seq_len(nbStates),labels=ref,las=2)

        # plot the states probabilities
        par(mar=c(5,6.5,4,2)-c(0,0,2,0))
        for(i in seq_len(nbStates)) {
          plot(sp[[paste0("level",j)]][ind,i],type="l",xlab="Observation index",ylab=paste0("Pr(State=",ref[i],")"))
          abline(h=0.5,lty=2,col="darkgrey")
        }
      }
    }
  }

  invisible(NULL)
}

Try the momentuHMM package in your browser

Any scripts or data that you put into this service are public.

momentuHMM documentation built on Oct. 19, 2022, 1:07 a.m.