R/Helpers.R

Defines functions .get_varying .make_sum_vals .get_cuts_cov .cov_latsp .cov_lnorm .cov_norm .cov_pois .cov_ord .cov_bern .remove_nas .count_cats .irf .gp_prior .na_if .idx_col2rowm .num_pars .calc_starts .pars_total_indexes .remove_empty_pars .check_pars .check_pars_second .get_kept_samples2 .extract_nonp .item_plot_ls .item_plot_ord_grm .item_plot_ord_rs .item_plot_binary .prepare_rollcall .calc_true_pts .init_stan .check_quoted .create_array .prepare_legis_data .extract_samples .vb_fix

#' Locate identification constraints with a quick variational fit
#'
#' Runs the model's variational algorithm on `this_data`, averages the
#' approximate posterior draws of either the person parameters (`L_full`) or
#' the observed item discriminations (`sigma_reg_full`), and records the
#' position and value of the largest and smallest means as the constraints to
#' use in the full run.
#'
#' @param object Model object providing `stanmodel_map$variational()`; its
#'   `restrict_num_*`, `restrict_ind_*` and `constraint_type` slots are filled
#'   in and the object is returned.
#' @param this_data Stan data list; `id_refresh` is used as the refresh rate.
#' @param const_type Either "persons" or "items".
#' @param tol_rel_obj Ignored: the identification run always uses .001.
#' @param nfix,ncores,all_args,restrict_ind_high,restrict_ind_low,model_type,use_groups,fixtype,...
#'   Accepted for interface compatibility; not used by this function.
#' @return `object` with the constraint slots updated.
#' @noRd
.vb_fix <- function(object=NULL,
                    this_data=NULL,nfix=NULL,
                    ncores=NULL,all_args=NULL,
                    restrict_ind_high=NULL,
                    restrict_ind_low=NULL,
                    tol_rel_obj=NULL,
                    model_type=NULL,
                    use_groups=NULL,
                    const_type=NULL,
                    fixtype=NULL,...) {

  # fail before the (potentially slow) variational run, not after it
  if(!(length(const_type)==1 && const_type %in% c("persons","items"))) {
    stop("const_type must be either 'persons' or 'items'.", call. = FALSE)
  }

  # identification-run settings; identical for every time process
  eval_elbo <- 100
  tol_rel_obj <- .001

  print("(First Step): Estimating model with variational inference to identify modes to constrain.")

  post_modes <- object@stanmodel_map$variational(data =this_data,
                          refresh=this_data$id_refresh,
                          eval_elbo=eval_elbo,threads=1,
                          tol_rel_obj=tol_rel_obj, # better convergence criterion than default
                          output_samples=200)

  # persons are constrained on ideal points, items on observed discriminations
  par_name <- if(const_type=="persons") "L_full" else "sigma_reg_full"
  par_means <- apply(as_draws_matrix(post_modes$draws(par_name)),2,mean)

  # first position attaining the largest / smallest posterior mean
  restrict_ind_high <- which.max(par_means)
  restrict_ind_low <- which.min(par_means)

  object@restrict_num_high <- par_means[restrict_ind_high]
  object@restrict_num_low <- par_means[restrict_ind_low]
  object@restrict_ind_high <- restrict_ind_high
  object@restrict_ind_low <- restrict_ind_low
  object@constraint_type <- const_type

  object
}



#' Extract posterior draws from a fitted model as a data.frame
#'
#' @param obj Fitted model object whose `stan_samples` slot is handed to
#'   `as.data.frame()`.
#' @param extract_type NULL to return every parameter, or one of 'persons',
#'   'obs_discrim', 'miss_discrim', 'obs_diff', 'miss_diff', 'cutpoints', each
#'   of which selects one Stan parameter through the `pars` argument. An
#'   unrecognised value makes `switch()` yield NULL, which is passed as `pars`.
#' @param ... Further arguments forwarded to `as.data.frame()`.
#' @return Whatever `as.data.frame()` returns for the samples object.
#' @noRd
.extract_samples <- function(obj=NULL,extract_type=NULL,...) {
  fit_draws <- obj@stan_samples

  # no parameter type requested: hand back everything
  if(is.null(extract_type)) {
    return(as.data.frame(fit_draws,...))
  }

  stan_par <- switch(extract_type,
                     persons='L_full',
                     obs_discrim='sigma_reg_free',
                     miss_discrim='sigma_abs_free',
                     obs_diff='B_int_free',
                     miss_diff='A_int_free',
                     cutpoints='steps_votes3')

  as.data.frame(fit_draws,pars=stan_par,...)
}



#' Helper function for preparing person ideal point plot data
#'
#' Pulls person (or group) parameter draws from a fitted object and matches
#' them back to the IDs in the score data.
#'
#' @param object Fitted model object; reads the `score_data`, `time_varying`,
#'   `stan_samples` and `use_groups` slots.
#' @param high_limit,low_limit Quantiles used for the posterior interval when
#'   `aggregate` is TRUE.
#' @param aggregate If TRUE, summarise each parameter's draws to a median and
#'   interval; otherwise return every draw in long form.
#' @param type 'ideal_pts' (reads `L_full`) or 'variance' (reads
#'   `time_var_free`). With more than one time point and any type other than
#'   'variance', the time-varying draws in `object@time_varying` are used.
#' @return A tibble of draws or summaries joined to the person/group IDs.
#' @noRd
.prepare_legis_data <- function(object,
                                high_limit=NULL,
                                low_limit=NULL,
                                aggregate=TRUE,
                                type='ideal_pts') {

  score_matrix <- object@score_data@score_matrix

  if(length(unique(score_matrix$time_id))>1 && type!='variance') {

    person_params <- object@time_varying

    if(inherits(person_params,"draws_matrix")) {
      # draws_df carries .chain/.iteration/.draw columns that are not parameters
      person_params <- person_params %>%
        as_draws_df %>%
        dplyr::select(-`.chain`,-`.iteration`,-`.draw`)
    } else {
      person_params <- as_tibble(person_params)
    }

    if(aggregate) {
      person_params <- person_params %>% gather(key = legis,value=ideal_pts) %>%
        group_by(legis) %>%
        summarize(low_pt=quantile(ideal_pts,low_limit),high_pt=quantile(ideal_pts,high_limit),
                  median_pt=median(ideal_pts)) %>%
        mutate(param_id=stringr::str_extract(legis,'[0-9]+\\]'),
               param_id=as.numeric(stringr::str_extract(param_id,'[0-9]+')),
               time_point=stringr::str_extract(legis,'\\[[0-9]+'),
               time_point=as.numeric(stringr::str_extract(time_point,'[0-9]+')))
    } else {
      person_params <- person_params %>% gather(key = legis,value=ideal_pts) %>%
        group_by(legis) %>%
        mutate(param_id=stringr::str_extract(legis,'[0-9]+\\]'),
               param_id=as.numeric(stringr::str_extract(param_id,'[0-9]+')),
               time_point=stringr::str_extract(legis,'\\[[0-9]+'),
               time_point=as.numeric(stringr::str_extract(time_point,'[0-9]+')))
    }

    # parameter names look like name[time,id]: the first index is matched to
    # time_id_num and the last to the person/group number below

    person_ids <- select(score_matrix,
                         !!quo(person_id),
                         !!quo(time_id),
                         !!quo(group_id)) %>%
      distinct %>%
      mutate(person_id_num=as.numeric(!!quo(person_id)),
             time_id_num=as.numeric(factor(!!quo(time_id))),
             group_id_num=as.numeric(!!quo(group_id)))

    if(object@use_groups) {
      person_params <-  person_params %>%
        left_join(person_ids,by=c(param_id='group_id_num',
                                  time_point='time_id_num'))
    } else {
      person_params <-  person_params %>%
        left_join(person_ids,by=c(param_id='person_id_num',
                                  time_point='time_id_num'))
    }

  } else {
    # static model (or variances): choose which parameter block to read
    static_par <- switch(type,
                         ideal_pts='L_full',
                         variance='time_var_free',
                         stop("type must be 'ideal_pts' or 'variance', not '",type,"'."))

    person_params <- object@stan_samples$draws(static_par) %>% as_draws_df %>%
      dplyr::select(-`.chain`,-`.iteration`,-`.draw`) %>%
      gather(key = legis,value=ideal_pts)

    # recover the numeric index embedded in each parameter name
    person_ids <- data_frame(long_name=person_params$legis) %>%
      distinct
    legis_nums <- stringr::str_extract_all(person_ids$long_name,'[0-9]+',simplify=TRUE)
    person_ids <- mutate(person_ids,id_num=as.numeric(legis_nums))

    person_data <- distinct(select(score_matrix,
                                   person_id,group_id))

    # the index refers to groups or persons depending on the model
    if(object@use_groups) {
      person_data <- mutate(person_data,id_num=as.numeric(group_id))
    } else {
      person_data <- mutate(person_data,id_num=as.numeric(person_id))
    }

    person_ids <- left_join(person_ids,person_data,by='id_num')

    if(aggregate) {
      person_params <-  person_params %>%
        group_by(legis) %>%
        summarize(low_pt=quantile(ideal_pts,low_limit),high_pt=quantile(ideal_pts,high_limit),
                  median_pt=median(ideal_pts)) %>%
        left_join(person_ids,by=c(legis='long_name'))
    } else {
      person_params <-  person_params %>%
        left_join(person_ids,by=c(legis='long_name'))
    }
  }

  person_params
}

#' Helper function to create arrays
#'
#' Function takes a data.frame in long mode and converts it to an array. Function can also repeat a
#' single matrix to fill out an array.
#'
#' @param input_matrix Either a data.frame in long mode or a single matrix
#' @param arr_dim If \code{input_matrix} is a single matrix, \code{arr_dim} determines the length of the resulting array
#' @param row_var Unquoted variable name that identifies the data.frame column corresponding to the rows (1st dimension) of the array (must be unique)
#' @param col_var_name Unquoted variable name that identifies the data.frame column corresponding to the names of the columns (2nd dimension) of the array
#' @param col_var_value Unquoted variable name that identifies the data.frame column corresponding to the values that populate the cells of the array
#' @param third_dim_var Unquoted variable name that identifies the data.frame column corresponding to the dimension around which to stack the matrices (3rd dimension of array)
#' @return An array. For data.frame input, rows and columns are named from
#'   the spread output and the stack dimension is named by the sorted values
#'   of \code{third_dim_var}.
#' @noRd
.create_array <- function(input_matrix,arr_dim=2,row_var=NULL,
                          col_var_name=NULL,
                          col_var_value,third_dim_var=NULL) {

  if(inherits(input_matrix,'matrix')) {
    # a single matrix: stack arr_dim copies along a new third dimension
    return(array(rep(c(input_matrix),arr_dim),
                 dim=c(dim(input_matrix),arr_dim)))
  }

  if(!inherits(input_matrix,'data.frame')) {
    stop('input_matrix must be either a matrix or a data.frame in long form.')
  }

  # long-form data: capture the column arguments, select, then spread per slice
  row_var <- enquo(row_var)
  col_var_name <- enquo(col_var_name)
  col_var_value <- enquo(col_var_value)
  third_dim_var <- enquo(third_dim_var)
  to_spread <- ungroup(input_matrix) %>% select(!!row_var,!!col_var_name,!!third_dim_var,!!col_var_value)

  if(nrow(distinct(to_spread))!=nrow(to_spread)) stop('Each row in the data must be uniquely identified given row_var, col_var and third_dim_var.')

  # split() orders the slices by the sorted values of third_dim_var, so the
  # list names (not unique() in data order) label the stack dimension
  to_array <- lapply(split(to_spread,pull(to_spread,!!third_dim_var)), function(this_data) {
    spread_it <- tryCatch(spread(this_data,key=!!col_var_name,value=!!col_var_value),
                          error=function(e) {
                            print('Failed to find unique covariate values for dataset:')
                            print(this_data)
                            stop(conditionMessage(e), call. = FALSE)
                          })
    # spread() sorts rows by the remaining columns, so read the row IDs back
    # from its output rather than from the input order
    row_ids <- pull(spread_it,!!row_var)
    spread_mat <- spread_it %>%
      select(-!!row_var,-!!third_dim_var) %>% as.matrix
    row.names(spread_mat) <- row_ids
    spread_mat
  })

  # flatten the slices in order and refold them into rows x cols x stack
  array(c(do.call(c,to_array)),
        dim=c(dim(to_array[[1]]),length(to_array)),
        dimnames=list(row.names=row.names(to_array[[1]]),
                      colnames=colnames(to_array[[1]]),
                      stack=names(to_array)))
}

#' Convert an optional column name into a quosure
#'
#' Returns `default_val` when `quoted` is NULL. Turns a character column name
#' into a symbol and captures it with `enquo()`. Rejects anything else.
#' `default_val` should already be quoted (e.g. a quosure or formula); its
#' second `as.character()` element names the argument in the error message.
#' @return The default value or the captured quosure, invisibly.
#' @noRd
.check_quoted <- function(quoted=NULL,default_val) {
  if(is.null(quoted)) {
    return(invisible(default_val))
  }
  # is.character() is a scalar test even when the input has several classes
  if(!is.character(quoted)) {
    stop(paste0('Please do not enter a non-character value for ',as.character(default_val)[2]))
  }
  quoted <- as.name(quoted)
  invisible(enquo(quoted))
}

#' Simple function to provide initial values to Stan given current values of restrict_sd
#'
#' Draws random starting values for the core IRT parameters, pins the
#' constrained entries to their fixed values, and adds the extra parameters
#' required by the time-varying processes.
#'
#' @param const_type 1 pins entries of `L_full`, 2 pins entries of
#'   `sigma_reg_free`; any other value (including NULL) pins nothing.
#' @param actual If TRUE and `T > 1`, add time-process parameters. Otherwise
#'   (identification run or a single time point) only the core parameters
#'   are returned.
#' @param T Number of time points.
#' @param time_proc 4 adds `m_sd`, 3 adds `L_AR1` and `time_var_free`, 2 adds
#'   `time_var_free`; any other value adds nothing.
#' @param restrict_var If TRUE, one fewer `time_var_free` value is drawn.
#' @return A named list of initial values.
#' @importFrom stats optimize
#' @noRd
.init_stan <- function(chain_id=NULL,
                       restrict_sd_high=NULL,
                       restrict_sd_low=NULL,
                        person_sd=NULL,
                       num_legis=NULL,
                       legis_labels=NULL,
                       item_labels=NULL,
                       num_cit=NULL,
                        fix_high=NULL,
                       ar1_up=NULL,
                       ar1_down=NULL,
                       fix_low=NULL,
                       restrict_ind_high=NULL,
                       restrict_ind_low=NULL,
                       m_sd_par=NULL,
                       num_diff=NULL,
                       time_range=NULL,
                       const_type=NULL,
                       T=NULL,
                       time_proc=NULL,
                       time_fix_sd=NULL,
                       actual=TRUE,
                       use_ar=NULL,
                       person_start=NULL,
                       restrict_var=NULL) {

  L_full <- array(rnorm(n=num_legis,mean=0,sd=person_sd))
  sigma_reg_free <- array(rnorm(n=num_cit,mean=0,sd=2))
  sigma_abs_free <- array(rnorm(n=num_cit,mean=0,sd=2))
  A_int_free <- array(rnorm(n=num_cit,mean=0,sd=2))
  B_int_free <- array(rnorm(n=num_cit,mean=0,sd=2))

  names(L_full) <- legis_labels
  names(sigma_reg_free) <- paste0("Obs_Discrim_",item_labels)
  names(sigma_abs_free) <- paste0("Miss_Discrim_",item_labels)
  names(A_int_free) <- paste0("Obs_Difficulty_",item_labels)
  names(B_int_free) <- paste0("Miss_Difficulty_",item_labels)

  # test for NULL first: `NULL == 1` is a zero-length logical
  if(!is.null(const_type) && const_type==1) {
    L_full[restrict_ind_high] <- fix_high
    L_full[restrict_ind_low] <- fix_low
  } else if(!is.null(const_type) && const_type==2) {
    sigma_reg_free[restrict_ind_high] <- fix_high
    sigma_reg_free[restrict_ind_low] <- fix_low
  }

  core_inits <- list(L_full = L_full,
                     sigma_reg_free=sigma_reg_free,
                     sigma_abs_free=sigma_abs_free,
                     A_int_free=A_int_free,
                     B_int_free=B_int_free)

  # identification runs and single-time-point models need only the core set
  if(!isTRUE(actual==TRUE) || T<=1) {
    return(core_inits)
  }

  if(time_proc==4) {
    return(c(core_inits,list(m_sd=rep(m_sd_par,num_legis))))
  }

  if(time_proc==3 || time_proc==2) {
    # one time-variance value per person, minus one when restrict_var is set
    num_var <- if(restrict_var) num_legis - 1 else num_legis

    if(time_proc==3) {
      return(c(list(L_full = L_full,
                    L_AR1 = array(runif(n = num_legis,min = ar1_down+.1,max=ar1_up-.1)),
                    time_var_free = rexp(rate=1/time_fix_sd,n=num_var)),
               core_inits[-1]))
    }
    return(c(list(L_full = L_full,
                  time_var_free = rexp(rate=1/time_fix_sd,n=num_var)),
             core_inits[-1]))
  }

  core_inits
}

#' used to calculate the true ideal points
#' given that a non-centered parameterization is used.
#'
#' Combines the time-point draws (`L_tp1`) with the per-person draws
#' (`L_full`). With `use_ar`, time point 1 is `L_full` and every later time
#' point is `L_tp1 + L_full`. Otherwise only time point 1 of `L_tp1` is
#' replaced by `L_full`.
#'
#' @param obj Fitted object whose `stan_samples` are readable by
#'   `rstan::extract()` and with a logical `use_ar` slot.
#' @return An array shaped like the `L_tp1` draws (presumably
#'   draws x time x persons, inferred from the indexing -- TODO confirm).
#' @importFrom posterior as_draws_df as_draws_matrix
#' @noRd
.calc_true_pts <- function(obj) {

  over_time <- rstan::extract(obj@stan_samples,'L_tp1')$L_tp1
  drift <- rstan::extract(obj@stan_samples,'L_full')$L_full

  if(obj@use_ar) {
    # fill a preallocated array instead of mutating state from inside sapply()
    new_pts <- array(data=NA,dim=dim(over_time))
    for(tp in seq_len(dim(over_time)[2])) {
      for(i in seq_len(dim(over_time)[3])) {
        if(tp==1) {
          new_pts[,tp,i] <- drift[,i]
        } else {
          new_pts[,tp,i] <- over_time[,tp,i] + drift[,i]
        }
      }
    }
  } else {
    new_pts <- over_time
    new_pts[,1,] <- drift
  }

  new_pts
}

#' Pre-process rollcall objects
#'
#' Converts a rollcall-style list (elements `votes`, `legis.data` and,
#' optionally, `vote.data`) into long-form score data.
#'
#' @param rc_obj Rollcall object; `votes` is a person-by-vote matrix whose row
#'   names are used as person IDs.
#' @param item_id Column of `rc_obj$vote.data` matched against the vote names
#'   (only used when `vote.data` is present).
#' @param time_id Name of the time column. When `vote.data` is absent, a
#'   constant `time_id` column is added and this becomes 'time_id'.
#' @return A list with `score_data`, `time_id` and `item_id` (always 'item_id').
#' @noRd
.prepare_rollcall <- function(rc_obj=NULL,item_id=NULL,time_id=NULL) {

  vote_matrix <- rc_obj$votes
  legis_info <- rc_obj$legis.data

  # one row per (person, vote) with the recorded outcome
  long_votes <- as_data_frame(vote_matrix) %>%
    mutate(person_id=row.names(vote_matrix)) %>%
    gather(key = item_id,value = outcome,-person_id)

  # fall back to row names when no legislator name column exists
  if(is.null(legis_info$legis.names)) {
    legis_info$legis.names <- row.names(legis_info)
  }

  long_votes <- left_join(long_votes,legis_info,by=c(person_id='legis.names')) %>%
    mutate(group_id=party)

  if(is.null(rc_obj$vote.data)) {
    # no vote metadata: treat everything as a single time point
    long_votes$time_id <- 1
    time_id <- 'time_id'
  } else {
    # attach vote metadata; the item_id argument names the matching column
    long_votes <- left_join(long_votes,as_data_frame(rc_obj$vote.data),
                            by=c(item_id=item_id))
  }

  list(score_data=long_votes,
       time_id=time_id,
       item_id='item_id')
}

#' Generate item-level midpoints for binary IRT outcomes
#'
#' Reads one item's difficulty and discrimination draws for both the observed
#' (non-inflated) and missing (inflated) sides. It also summarises the implied
#' midpoints, computed as difficulty / discrimination.
#'
#' @param param_name Item label, matched against the levels of `item_id`.
#' @param object Fitted model object (reads `score_data` and `stan_samples`).
#' @param high_limit,low_limit Quantiles for the posterior interval.
#' @param all If FALSE, return only the two midpoint summaries in the plotting
#'   layout. If TRUE, also return discriminations and difficulties.
#' @param aggregate When `all` is TRUE: summarise the draws (TRUE) or return
#'   every draw with an `Iteration` index (FALSE).
#' @return A tibble of summaries or draws.
#' @noRd
.item_plot_binary <- function(param_name,object,
                       high_limit=NULL,
                       low_limit=NULL,
                       all=FALSE,
                       aggregate=FALSE) {

  # position of the item among the item_id factor levels
  param_num <- which(levels(object@score_data@score_matrix$item_id)==param_name)

  # first-column draws of each of this item's parameters
  reg_diff <- as_draws_matrix(object@stan_samples$draws(paste0('B_int_free[',param_num,']')))[,1]
  reg_discrim <- as_draws_matrix(object@stan_samples$draws(paste0('sigma_reg_free[',param_num,']')))[,1]
  abs_diff <- as_draws_matrix(object@stan_samples$draws(paste0('A_int_free[',param_num,']')))[,1]
  abs_discrim <- as_draws_matrix(object@stan_samples$draws(paste0('sigma_abs_free[',param_num,']')))[,1]

  reg_mid <- reg_diff/reg_discrim
  abs_mid <- abs_diff/abs_discrim

  # is.factor() also covers ordered factors, whose class() has two elements
  outcome_disc <- object@score_data@score_matrix$outcome_disc
  if(is.factor(outcome_disc)) {
    cut_names <- levels(outcome_disc)
  } else {
    cut_names <- as.character(unique(outcome_disc))
  }

  if(!all) {
    reg_data <- data_frame(item_median=quantile(reg_mid,0.5),
                           item_high=quantile(reg_mid,high_limit),
                           item_low=quantile(reg_mid,low_limit),
                           item_type='Non-Inflated\nDiscrimination',
                           Outcome=cut_names[2],
                           item_name=param_name)

    abs_data <- data_frame(item_median=quantile(abs_mid,0.5),
                           item_high=quantile(abs_mid,high_limit),
                           item_low=quantile(abs_mid,low_limit),
                           item_type='Inflated\nDiscrimination',
                           Outcome='Missing',
                           item_name=param_name)

    return(bind_rows(reg_data,abs_data))

  } else if(all && aggregate) {
    reg_data_mid <- data_frame(`Posterior Median`=quantile(reg_mid,0.5),
                               `High Posterior Interval`=quantile(reg_mid,high_limit),
                               `Low Posterior Interval`=quantile(reg_mid,low_limit),
                               `Item Type`='Non-Inflated Item Midpoint',
                               `Predicted Outcome`=cut_names[2],
                               `Item Name`=param_name,
                               `Parameter`='A function of other parameters')

    abs_data_mid <- data_frame(`Posterior Median`=quantile(abs_mid,0.5),
                               `High Posterior Interval`=quantile(abs_mid,high_limit),
                               `Low Posterior Interval`=quantile(abs_mid,low_limit),
                               `Item Type`='Inflated Item Midpoint',
                               `Item Name`=param_name,
                               `Predicted Outcome`='Missing',
                               `Parameter`='A function of other parameters')

    reg_data_discrim <- data_frame(`Posterior Median`=quantile(reg_discrim,0.5),
                                   `High Posterior Interval`=quantile(reg_discrim,high_limit),
                                   `Low Posterior Interval`=quantile(reg_discrim,low_limit),
                                   `Item Name`=param_name,
                                   `Item Type`='Non-Inflated Discrimination',
                                   `Predicted Outcome`=cut_names[2],
                                   `Parameter`=paste0('sigma_reg_free[',param_name,']'))

    abs_data_discrim <- data_frame(`Posterior Median`=quantile(abs_discrim,0.5),
                                   `High Posterior Interval`=quantile(abs_discrim,high_limit),
                                   `Low Posterior Interval`=quantile(abs_discrim,low_limit),
                                   `Item Name`=param_name,
                                   `Item Type`='Inflated Discrimination',
                                   `Predicted Outcome`='Missing',
                                   `Parameter`=paste0('sigma_abs_free[',param_name,']'))

    reg_data_diff <- data_frame(`Posterior Median`=quantile(reg_diff,0.5),
                                `High Posterior Interval`=quantile(reg_diff,high_limit),
                                `Low Posterior Interval`=quantile(reg_diff,low_limit),
                                `Item Name`=param_name,
                                `Item Type`='Non-Inflated Difficulty',
                                `Predicted Outcome`=cut_names[2],
                                `Parameter`=paste0('B_int_free[',param_name,']'))

    abs_data_diff <- data_frame(`Posterior Median`=quantile(abs_diff,0.5),
                                `High Posterior Interval`=quantile(abs_diff,high_limit),
                                `Low Posterior Interval`=quantile(abs_diff,low_limit),
                                `Item Name`=param_name,
                                `Item Type`='Inflated Difficulty',
                                `Predicted Outcome`='Missing',
                                `Parameter`=paste0('A_int_free[',param_name,']'))

    return(bind_rows(reg_data_mid,abs_data_mid,reg_data_discrim,
                     abs_data_discrim,
                     reg_data_diff,
                     abs_data_diff))

  } else if(all && !aggregate) {
    reg_data_mid <- data_frame(Posterior_Sample=as.numeric(reg_mid),
                               `Item Name`=param_name,
                               `Item Type`='Non-Inflated Item Midpoint',
                               `Predicted Outcome`=cut_names[2],
                               `Parameter`='A function of other parameters') %>%
      mutate(Iteration=1:n())

    abs_data_mid <- data_frame(`Posterior_Sample`=as.numeric(abs_mid),
                               `Item Name`=param_name,
                               `Item Type`='Inflated Item Midpoint',
                               `Predicted Outcome`='Missing',
                               `Parameter`='A function of other parameters') %>%
      mutate(Iteration=1:n())

    reg_data_discrim <- data_frame(`Posterior_Sample`=as.numeric(reg_discrim),
                                   `Item Name`=param_name,
                                   `Item Type`='Non-Inflated Discrimination',
                                   `Predicted Outcome`=cut_names[2],
                                   `Parameter`=paste0('sigma_reg_free[',param_name,']')) %>%
      mutate(Iteration=1:n())

    abs_data_discrim <- data_frame(`Posterior_Sample`=as.numeric(abs_discrim),
                                   `Item Name`=param_name,
                                   `Item Type`='Inflated Discrimination',
                                   `Predicted Outcome`='Missing',
                                   `Parameter`=paste0('sigma_abs_free[',param_name,']')) %>%
      mutate(Iteration=1:n())

    reg_data_diff <- data_frame(`Posterior_Sample`=as.numeric(reg_diff),
                                `Item Name`=param_name,
                                `Item Type`='Non-Inflated Difficulty',
                                `Predicted Outcome`=cut_names[2],
                                `Parameter`=paste0('B_int_free[',param_name,']')) %>%
      mutate(Iteration=1:n())

    # inflated difficulty draws come from A_int_free (abs_diff)
    abs_data_diff <- data_frame(`Posterior_Sample`=as.numeric(abs_diff),
                                `Item Name`=param_name,
                                `Item Type`='Inflated Difficulty',
                                `Predicted Outcome`='Missing',
                                `Parameter`=paste0('A_int_free[',param_name,']')) %>%
      mutate(Iteration=1:n())

    return(bind_rows(reg_data_mid,abs_data_mid,reg_data_discrim,
                     abs_data_discrim,
                     reg_data_diff,
                     abs_data_diff))
  }
}

#' Generate item-level midpoints for ordinal-rating scale IRT outcomes
#'
#' Reads one item's difficulty and discrimination draws plus the shared
#' rating-scale cutpoints (`steps_votes`), then summarises the per-cutpoint
#' midpoints (difficulty + cutpoint) / discrimination and the inflated
#' midpoint.
#'
#' @param param_name Item label, matched against the levels of `item_id`.
#' @param object Fitted model object (reads `score_data` and `stan_samples`).
#' @param high_limit,low_limit Quantiles for the posterior interval.
#' @param all If FALSE, return only midpoint summaries in the plotting layout.
#'   If TRUE, also return discriminations and difficulties.
#' @param aggregate When `all` is TRUE: summarise the draws (TRUE) or return
#'   every draw with an `Iteration` index (FALSE).
#' @return A tibble of summaries or draws.
#' @noRd
.item_plot_ord_rs <- function(param_name,object,
                              high_limit=NULL,
                              low_limit=NULL,
                              all=FALSE,
                              aggregate=FALSE) {

  # position of the item among the item_id factor levels
  param_num <- which(levels(object@score_data@score_matrix$item_id)==param_name)

  # first-column draws of each of this item's parameters
  reg_diff <- as_draws_matrix(object@stan_samples$draws(paste0('B_int_free[',param_num,']')))[,1]
  reg_discrim <- as_draws_matrix(object@stan_samples$draws(paste0('sigma_reg_free[',param_num,']')))[,1]
  abs_diff <- as_draws_matrix(object@stan_samples$draws(paste0('A_int_free[',param_num,']')))[,1]
  abs_discrim <- as_draws_matrix(object@stan_samples$draws(paste0('sigma_abs_free[',param_num,']')))[,1]

  # Cutpoint draws. A draws_matrix holds only the cutpoint variables, while a
  # draws_df would also carry .chain/.iteration/.draw columns that must not
  # be looped over as cutpoints.
  cuts <- as_draws_matrix(object@stan_samples$draws('steps_votes'))
  n_cuts <- ncol(cuts)

  # is.factor() also covers ordered factors, whose class() has two elements
  outcome_disc <- object@score_data@score_matrix$outcome_disc
  if(is.factor(outcome_disc)) {
    cut_names <- levels(outcome_disc)
  } else {
    cut_names <- as.character(unique(outcome_disc))
  }

  abs_mid <- abs_diff/abs_discrim
  # midpoint draws for each cutpoint c: (difficulty + cut_c) / discrimination
  reg_mids <- lapply(seq_len(n_cuts), function(c) {
    (reg_diff+as.numeric(cuts[,c]))/reg_discrim
  })

  if(!all) {
    reg_data <- lapply(seq_len(n_cuts), function(c) {
      data_frame(item_median=quantile(reg_mids[[c]],0.5),
                 item_high=quantile(reg_mids[[c]],high_limit),
                 item_low=quantile(reg_mids[[c]],low_limit),
                 item_type='Non-Inflated\nDiscrimination',
                 Outcome=cut_names[c],
                 item_name=param_name)
    }) %>% bind_rows

    abs_data <- data_frame(item_median=quantile(abs_mid,0.5),
                           item_high=quantile(abs_mid,high_limit),
                           item_low=quantile(abs_mid,low_limit),
                           item_type='Inflated\nDiscrimination',
                           Outcome='Missing',
                           item_name=param_name)

    return(bind_rows(abs_data,reg_data))

  } else if(all && aggregate) {

    reg_data <- lapply(seq_len(n_cuts), function(c) {
      data_frame(`Posterior Median`=quantile(reg_mids[[c]],0.5),
                 `High Posterior Interval`=quantile(reg_mids[[c]],high_limit),
                 `Low Posterior Interval`=quantile(reg_mids[[c]],low_limit),
                 `Item Type`='Non-Inflated Item Midpoint',
                 `Predicted Outcome`=cut_names[c],
                 `Parameter`=param_name)
    }) %>% bind_rows

    abs_data <- data_frame(`Posterior Median`=quantile(abs_mid,0.5),
                           `High Posterior Interval`=quantile(abs_mid,high_limit),
                           `Low Posterior Interval`=quantile(abs_mid,low_limit),
                           `Item Type`='Inflated Item Midpoint',
                           `Predicted Outcome`='Missing',
                           `Parameter`=param_name)

    reg_data_discrim <- data_frame(`Posterior Median`=quantile(reg_discrim,0.5),
                                   `High Posterior Interval`=quantile(reg_discrim,high_limit),
                                   `Low Posterior Interval`=quantile(reg_discrim,low_limit),
                                   `Item Type`='Non-Inflated Discrimination',
                                   `Predicted Outcome`=cut_names[2],
                                   `Parameter`=param_name)

    abs_data_discrim <- data_frame(`Posterior Median`=quantile(abs_discrim,0.5),
                                   `High Posterior Interval`=quantile(abs_discrim,high_limit),
                                   `Low Posterior Interval`=quantile(abs_discrim,low_limit),
                                   `Item Type`='Inflated Discrimination',
                                   `Predicted Outcome`='Missing',
                                   `Parameter`=param_name)

    reg_data_diff <- data_frame(`Posterior Median`=quantile(reg_diff,0.5),
                                `High Posterior Interval`=quantile(reg_diff,high_limit),
                                `Low Posterior Interval`=quantile(reg_diff,low_limit),
                                `Item Type`='Non-Inflated Difficulty',
                                `Predicted Outcome`=cut_names[2],
                                `Parameter`=param_name)

    # inflated difficulty is summarised from A_int_free (abs_diff)
    abs_data_diff <- data_frame(`Posterior Median`=quantile(abs_diff,0.5),
                                `High Posterior Interval`=quantile(abs_diff,high_limit),
                                `Low Posterior Interval`=quantile(abs_diff,low_limit),
                                `Item Type`='Inflated Difficulty',
                                `Predicted Outcome`='Missing',
                                `Parameter`=param_name)

    return(bind_rows(reg_data,abs_data,reg_data_discrim,
                     abs_data_discrim,
                     reg_data_diff,
                     abs_data_diff))

  } else if(all && !aggregate) {

    # each cutpoint's midpoints are labelled with that cutpoint's outcome
    reg_data_mid <- lapply(seq_len(n_cuts), function(c) {
      data_frame(Posterior_Sample=as.numeric(reg_mids[[c]]),
                 `Item Type`='Non-Inflated Item Midpoint',
                 `Predicted Outcome`=cut_names[c],
                 `Parameter`=param_name) %>%
        mutate(Iteration=1:n())
    }) %>% bind_rows

    abs_data_mid <- data_frame(`Posterior_Sample`=as.numeric(abs_mid),
                               `Item Type`='Inflated Item Midpoint',
                               `Predicted Outcome`='Missing',
                               `Parameter`=param_name) %>%
      mutate(Iteration=1:n())

    reg_data_discrim <- data_frame(`Posterior_Sample`=as.numeric(reg_discrim),
                                   `Item Type`='Non-Inflated Discrimination',
                                   `Predicted Outcome`=cut_names[2],
                                   `Parameter`=param_name) %>%
      mutate(Iteration=1:n())

    abs_data_discrim <- data_frame(`Posterior_Sample`=as.numeric(abs_discrim),
                                   `Item Type`='Inflated Discrimination',
                                   `Predicted Outcome`='Missing',
                                   `Parameter`=param_name) %>%
      mutate(Iteration=1:n())

    reg_data_diff <- data_frame(`Posterior_Sample`=as.numeric(reg_diff),
                                `Item Type`='Non-Inflated Difficulty',
                                `Predicted Outcome`=cut_names[2],
                                `Parameter`=param_name) %>%
      mutate(Iteration=1:n())

    abs_data_diff <- data_frame(`Posterior_Sample`=as.numeric(abs_diff),
                                `Item Type`='Inflated Difficulty',
                                `Predicted Outcome`='Missing',
                                `Parameter`=param_name) %>%
      mutate(Iteration=1:n())

    return(bind_rows(reg_data_mid,abs_data_mid,reg_data_discrim,
                     abs_data_discrim,
                     reg_data_diff,
                     abs_data_diff))
  }
}

#' Generate item-level midpoints for ordinal-GRM IRT outcomes
#' @noRd
.item_plot_ord_grm <- function(param_name,object,
                              high_limit=NULL,
                              low_limit=NULL,
                              all=FALSE,
                              aggregate=FALSE) {

  # Build plotting/summary data for one ordinal graded-response (GRM) item.
  #
  # Args:
  #   param_name: item label; one of levels(score_matrix$item_id)
  #   object: fitted model object whose @stan_samples provides $draws()
  #   high_limit, low_limit: quantile probabilities for the interval bounds
  #   all: if FALSE, return only the midpoint summaries used for plotting
  #   aggregate: when all = TRUE, return posterior summaries (TRUE) or every
  #     posterior draw with an Iteration index (FALSE)
  #
  # Returns:
  #   A tibble assembled with data_frame()/bind_rows().

  # position of the item among the item_id levels = its Stan index
  param_num <- which(levels(object@score_data@score_matrix$item_id)==param_name)

  # now get all the necessary components
  reg_diff <- as_draws_matrix(object@stan_samples$draws(paste0('B_int_free[',param_num,']')))[,1]
  reg_discrim <- as_draws_matrix(object@stan_samples$draws(paste0('sigma_reg_free[',param_num,']')))[,1]
  abs_diff <- as_draws_matrix(object@stan_samples$draws(paste0('A_int_free[',param_num,']')))[,1]
  abs_discrim <- as_draws_matrix(object@stan_samples$draws(paste0('sigma_abs_free[',param_num,']')))[,1]

  # figure out how many categories we need
  # NOTE(review): length() of a draws_df also counts its .chain/.iteration/.draw
  # meta columns, and only the single element steps_votes_grm[param_num,total_cat]
  # is requested, so the columns of `cuts` looped over below may include those
  # meta columns -- confirm against the Stan declaration of steps_votes_grm.
  total_cat <- length(as_draws_df(object@stan_samples$draws('steps_votes')))

  cuts <- as_draws_df(object@stan_samples$draws(paste0('steps_votes_grm[',param_num,',',total_cat,']')))

  # is.factor() rather than class(x) == 'factor': an ordered factor has a
  # length-2 class, which would make the if() condition error
  outcome_disc <- object@score_data@score_matrix$outcome_disc
  if(is.factor(outcome_disc)) {
    cut_names <- levels(outcome_disc)
  } else {
    cut_names <- as.character(unique(outcome_disc))
  }

  abs_mid <- abs_diff/abs_discrim

  # midpoint draws for each cutpoint column k: (difficulty + cut_k) / discrimination
  cut_idx <- seq_len(ncol(cuts))
  reg_mids <- lapply(cut_idx, function(k) (reg_diff+cuts[[k]])/reg_discrim)

  if(!all) {

    reg_data <- lapply(cut_idx, function(k) {
      data_frame(item_median=quantile(reg_mids[[k]],0.5),
                 item_high=quantile(reg_mids[[k]],high_limit),
                 item_low=quantile(reg_mids[[k]],low_limit),
                 item_type='Non-Inflated\nDiscrimination',
                 Outcome=cut_names[k],
                 item_name=param_name)
    }) %>% bind_rows

    abs_data <- data_frame(item_median=quantile(abs_mid,0.5),
                           item_high=quantile(abs_mid,high_limit),
                           item_low=quantile(abs_mid,low_limit),
                           item_type='Inflated\nDiscrimination',
                           Outcome='Missing',
                           item_name=param_name)

    return(bind_rows(abs_data,reg_data))
  }

  if(aggregate) {

    # one summary row (median + interval) for a vector of draws
    summ_row <- function(draws, type_lab, outcome_lab) {
      data_frame(`Posterior Median`=quantile(draws,0.5),
                 `High Posterior Interval`=quantile(draws,high_limit),
                 `Low Posterior Interval`=quantile(draws,low_limit),
                 `Item Type`=type_lab,
                 `Predicted Outcome`=outcome_lab,
                 `Parameter`=param_name)
    }

    reg_data <- lapply(cut_idx, function(k) {
      summ_row(reg_mids[[k]],'Non-Inflated Item Midpoint',cut_names[k])
    }) %>% bind_rows

    return(bind_rows(reg_data,
                     summ_row(abs_mid,'Inflated Item Midpoint','Missing'),
                     summ_row(reg_discrim,'Non-Inflated Discrimination',cut_names[2]),
                     summ_row(abs_discrim,'Inflated Discrimination','Missing'),
                     summ_row(reg_diff,'Non-Inflated Difficulty',cut_names[2]),
                     # inflated difficulty is summarised from abs_diff
                     summ_row(abs_diff,'Inflated Difficulty','Missing')))
  }

  # one row per posterior draw, numbered by Iteration
  draws_rows <- function(draws, type_lab, outcome_lab) {
    data_frame(Posterior_Sample=draws,
               `Item Type`=type_lab,
               `Predicted Outcome`=outcome_lab,
               `Parameter`=param_name) %>%
      mutate(Iteration=1:n())
  }

  # each cutpoint's midpoints are labelled with that cutpoint's outcome,
  # matching the plot and aggregate branches above
  reg_data_mid <- lapply(cut_idx, function(k) {
    draws_rows(reg_mids[[k]],'Non-Inflated Item Midpoint',cut_names[k])
  }) %>% bind_rows

  bind_rows(reg_data_mid,
            draws_rows(abs_mid,'Inflated Item Midpoint','Missing'),
            draws_rows(reg_discrim,'Non-Inflated Discrimination',cut_names[2]),
            draws_rows(abs_discrim,'Inflated Discrimination','Missing'),
            draws_rows(reg_diff,'Non-Inflated Difficulty',cut_names[2]),
            draws_rows(abs_diff,'Inflated Difficulty','Missing'))
}

#' Generate item-level midpoints for binary latent-space outcomes
#' @noRd
.item_plot_ls <- function(param_name,object,
                              high_limit=NULL,
                              low_limit=NULL,
                          aggregate=FALSE,
                          all=FALSE) {

  # Build plotting/summary data for one item of a binary latent-space model.
  #
  # Args:
  #   param_name: item label; one of levels(score_matrix$item_id)
  #   object: fitted model object whose @stan_samples provides $draws()
  #   high_limit, low_limit: quantile probabilities for the interval bounds
  #   aggregate: when all = TRUE, return posterior summaries (TRUE) or every
  #     posterior draw with an Iteration index (FALSE)
  #   all: if FALSE, return only the item ideal-point summaries used for
  #     plotting. It is the last argument so existing positional calls keep
  #     working; without it, `!all` would resolve to base::all() and error.
  #
  # Returns:
  #   A tibble assembled with data_frame()/bind_rows().

  # position of the item among the item_id levels = its Stan index
  param_num <- which(levels(object@score_data@score_matrix$item_id)==param_name)

  # now get all the necessary components
  reg_diff <- as_draws_matrix(object@stan_samples$draws(paste0('B_int_free[',param_num,']')))[,1]
  abs_diff <- as_draws_matrix(object@stan_samples$draws(paste0('A_int_free[',param_num,']')))[,1]
  item_int <- as_draws_matrix(object@stan_samples$draws(paste0('sigma_abs_free[',param_num,']')))[,1]
  # NOTE(review): ls_int is indexed by the item number here; confirm that the
  # Stan model declares ls_int per item.
  ideal_int <- as_draws_matrix(object@stan_samples$draws(paste0('ls_int[',param_num,']')))[,1]

  # is.factor() rather than class(x) == 'factor' so ordered factors work in if()
  outcome_disc <- object@score_data@score_matrix$outcome_disc
  if(is.factor(outcome_disc)) {
    cut_names <- levels(outcome_disc)
  } else {
    cut_names <- as.character(unique(outcome_disc))
  }

  if(!all) {

    reg_data <- data_frame(item_median=quantile(reg_diff,0.5),
                           item_high=quantile(reg_diff,high_limit),
                           item_low=quantile(reg_diff,low_limit),
                           item_type='Non-Inflated\nItem\nIdeal Point',
                           Outcome=cut_names[2],
                           item_name=param_name)

    abs_data <- data_frame(item_median=quantile(abs_diff,0.5),
                           item_high=quantile(abs_diff,high_limit),
                           item_low=quantile(abs_diff,low_limit),
                           item_type='Inflated\nItem\nIdeal Point',
                           Outcome='Missing',
                           item_name=param_name)

    return(bind_rows(reg_data,abs_data))
  }

  # Stan-style labels for the Parameter column
  # NOTE(review): the ideal point intercept draws come from ls_int but are
  # labelled sigma_reg_free -- confirm which label is intended.
  lab_reg <- paste0('B_int_free[',param_num,']')
  lab_abs <- paste0('A_int_free[',param_num,']')
  lab_ideal <- paste0('sigma_reg_free[',param_num,']')
  lab_item <- paste0('sigma_abs_free[',param_num,']')

  if(aggregate) {

    # one summary row (median + interval) for a vector of draws
    summ_row <- function(draws, type_lab, outcome_lab, par_lab) {
      data_frame(item_median=quantile(draws,0.5),
                 item_high=quantile(draws,high_limit),
                 item_low=quantile(draws,low_limit),
                 item_type=type_lab,
                 Outcome=outcome_lab,
                 item_name=param_name,
                 Parameter=par_lab)
    }

    return(bind_rows(summ_row(reg_diff,'Non-Inflated Item Ideal Point',cut_names[2],lab_reg),
                     summ_row(abs_diff,'Inflated Item Ideal Point','Missing',lab_abs),
                     summ_row(ideal_int,'Ideal Point Intercept',cut_names[2],lab_ideal),
                     summ_row(item_int,'Item Intercept',cut_names[2],lab_item)))
  }

  # one row per posterior draw, numbered by Iteration
  draws_rows <- function(draws, type_lab, outcome_lab, par_lab) {
    data_frame(Posterior_Sample=draws,
               `Item Name`=param_name,
               `Item Type`=type_lab,
               `Predicted Outcome`=outcome_lab,
               `Parameter`=par_lab) %>%
      mutate(Iteration=1:n())
  }

  # labels use the numeric index (param_num), matching the aggregate branch
  bind_rows(draws_rows(reg_diff,'Non-Inflated Item Ideal Point',cut_names[2],lab_reg),
            draws_rows(abs_diff,'Inflated Item Ideal Point','Missing',lab_abs),
            draws_rows(ideal_int,'Ideal Point Intercept',cut_names[2],lab_ideal),
            draws_rows(item_int,'Item Intercept','Missing',lab_item))
}

#' Modified copy of rstan's extract() that returns draws in chain order (not permuted)
#' @noRd
.extract_nonp <- function(object, pars, permuted = TRUE, 
                                inc_warmup = FALSE, include = TRUE) {
  # Extract the samples in different forms for different parameters.
  #
  # Args:
  #   object: the object of "stanfit" class
  #   pars: the names of parameters (including other quantiles)
  #   permuted: if TRUE, return one array per parameter with warmup removed
  #     and the chains concatenated in chain order (unlike rstan, the draws
  #     are NOT permuted); each array carries 'num_chains' and 'chain_order'
  #     attributes. inc_warmup is ignored in this case.
  #   inc_warmup: if TRUE, warmup samples are kept; otherwise, discarded.
  #   include: if FALSE interpret pars as those to exclude
  #
  # Returns:
  #   If permuted is TRUE, a named list of arrays whose first dimension is
  #   the iterations. If permuted is FALSE, an array with dimensions
  #   (# of iter (with or w.o. warmup), # of chains, # of flat parameters).
  #   NULL (invisibly) when the fit has no samples (mode 1 or 2).

  if (object@mode == 1L) {
    cat("Stan model '", object@model_name, "' is of mode 'test_grad';\n",
        "sampling is not conducted.\n", sep = '')
    return(invisible(NULL))
  } else if (object@mode == 2L) {
    cat("Stan model '", object@model_name, "' does not contain samples.\n", sep = '')
    return(invisible(NULL))
  }

  if(!include) pars <- setdiff(object@sim$pars_oi, pars)
  pars <- if (missing(pars)) object@sim$pars_oi else .check_pars_second(object@sim, pars)
  pars <- .remove_empty_pars(pars, object@sim$dims_oi)
  tidx <- .pars_total_indexes(object@sim$pars_oi,
                              object@sim$dims_oi,
                              object@sim$fnames_oi,
                              pars)

  n_kept <- object@sim$n_save - object@sim$warmup2
  fun1 <- function(par_i) {
    # do.call(cbind, lapply(...)) is faster than sapply() + do.call(c, ...)
    sss <- do.call(cbind, lapply(tidx[[par_i]], .get_kept_samples2, object@sim))
    dim(sss) <- c(sum(n_kept), object@sim$dims_oi[[par_i]])
    dimnames(sss) <- list(iterations = NULL)
    attr(sss,'num_chains') <- object@sim$chains
    attr(sss,'chain_order') <- rep(1:object@sim$chains,each=dim(sss)[1]/object@sim$chains)

    sss
  }

  if (permuted) {
    slist <- lapply(pars, fun1)
    names(slist) <- pars
    return(slist)
  }

  # Per-chain draws for flat parameter index n, dropping each chain's saved
  # warmup unless inc_warmup is TRUE. .get_kept_samples2() cannot be used
  # here: it takes no inc_warmup argument and already concatenates chains.
  get_chain_samples <- function(n) {
    lapply(seq_len(object@sim$chains), function(ic) {
      draws <- object@sim$samples[[ic]][[n]]
      n_warm <- object@sim$warmup2[ic]
      if (!inc_warmup && n_warm > 0) draws[-seq_len(n_warm)] else draws
    })
  }

  tidx <- unlist(tidx, use.names = FALSE)
  tidxnames <- object@sim$fnames_oi[tidx]
  sss <- lapply(tidx, get_chain_samples)
  sss2 <- lapply(sss, function(x) do.call(c, x))  # concatenate samples from different chains
  sssf <- unlist(sss2, use.names = FALSE)

  n2 <- object@sim$n_save[1]  ## assuming all the chains have equal iter
  if (!inc_warmup) n2 <- n2 - object@sim$warmup2[1]
  dim(sssf) <- c(n2, object@sim$chains, length(tidx))
  cids <- sapply(object@stan_args, function(x) x$chain_id)
  dimnames(sssf) <- list(iterations = NULL, chains = paste0("chain:", cids), parameters = tidxnames)
  sssf
}


#' Modified copy of rstan's internal get_kept_samples2() that does not permute:
#' post-warmup draws of one flat parameter are concatenated in chain order (maybe submit a PR to rstan)
#' @noRd
.get_kept_samples2 <- function(n, sim) {
  # Post-warmup draws of flat parameter `n`, chain by chain, joined into a
  # single vector in chain order (no permutation).
  #   n: flat parameter index into each chain's sample list
  #   sim: the sim slot of a stanfit (uses chains, warmup2, samples)
  per_chain <- lapply(seq_len(sim$chains), function(chain) {
    chain_draws <- sim$samples[[chain]][[n]]
    n_warm <- sim$warmup2[chain]
    if (n_warm > 0) chain_draws[-(1:n_warm)] else chain_draws
  })
  # returned invisibly, as the original's trailing assignment did
  invisible(do.call(c, per_chain))
}

#' Validate requested parameter names against a stanfit sim slot (adapted from rstan)
#' @noRd
.check_pars_second <- function(sim, pars) {
  # Validate `pars` against every name whose samples were saved in the
  # stanfit sim slot: parameter names (pars_oi) and flat names (fnames_oi).
  #
  # Returns:
  #   sim$pars_oi when `pars` is not supplied; otherwise `pars` with white
  #   space removed and duplicates dropped. .check_pars() stops with an
  #   error if any name is unknown or `pars` is empty.
  if (!missing(pars)) {
    saved_names <- c(sim$pars_oi, sim$fnames_oi)
    return(.check_pars(saved_names, pars))
  }
  sim$pars_oi
}

#' Check that every requested parameter name exists, stopping otherwise (adapted from rstan)
#' @noRd
.check_pars <- function(allpars, pars) {
  # Strip white space from `pars`, require every name to appear in `allpars`
  # and `pars` to be non-empty, then return the unique cleaned names.
  cleaned <- gsub('\\s+', '', pars)
  unknown <- !(cleaned %in% allpars)
  if (any(unknown)) {
    # report the names as the caller wrote them
    stop("no parameter ", paste(pars[unknown], collapse = ', '))
  }
  if (!length(cleaned)) {
    stop("no parameter specified (pars is empty)")
  }
  unique(cleaned)
}

#' Drop parameters whose declared dimensions contain a zero (adapted from rstan)
#' @noRd
.remove_empty_pars <- function(pars, model_dims) {
  # Drop parameters that have zero elements. This can happen when, for
  # example, a user writes the following Stan code:
  #
  #   transformed data { int n; n <- 0; }
  #   parameters { real y[n]; }
  #
  # Args:
  #   pars: character vector of parameter names
  #   model_dims: named list of parameter dimensions
  #
  # Returns:
  #   `pars` without the empty parameters; names not found in model_dims
  #   are kept.
  declared <- names(model_dims)
  if (is.null(declared)) stop("model_dims need be a named list")
  is_empty <- vapply(pars, function(p) {
    p %in% declared && prod(model_dims[[p]]) == 0
  }, logical(1), USE.NAMES = FALSE)
  pars[!is_empty]
}

#' Map parameter names to their flat column-major indexes (adapted from rstan)
#' @noRd
.pars_total_indexes <- function(names, dims, fnames, pars) {
  # Obtain the total indexes for parameters (pars) in the whole sequence of
  # names, ordered column-major.
  #
  # Args:
  #   names: all the parameter names, specifying the sequence of parameters
  #   dims: the dimensions of all parameters, in the same order as `names`
  #   fnames: all the flat element names implied by names and dims
  #   pars: the parameters of interest; assumed to be in names (or fnames)
  #
  # Returns:
  #   A named list with one integer index vector per non-empty parameter.
  #   Each vector carries a 'row_major_idx' attribute for callers that need
  #   row-major order. Parameters with zero elements are dropped.
  #
  # Note: inside each parameter (vector or array) the sequence is
  #   column-major. With alpha of dims [2,2] and beta of dims [2,3] the whole
  #   sequence is alpha[1,1], alpha[2,1], alpha[1,2], alpha[2,2], beta[1,1],
  #   beta[2,1], beta[1,2], beta[2,2], beta[1,3], beta[2,3].

  starts <- .calc_starts(dims)
  par_total_indexes <- function(par) {
    # for just one parameter; a scalar (or an already-flat name) matches
    # one of `fnames` directly
    p <- match(par, fnames)
    if (!is.na(p)) {
      names(p) <- par
      attr(p, "row_major_idx") <- p
      return(p)
    }
    p <- match(par, names)
    np <- .num_pars(dims[[p]])
    if (np == 0) return(NULL)
    idx <- starts[p] + seq(0, by = 1, length.out = np)
    names(idx) <- fnames[idx]
    attr(idx, "row_major_idx") <- starts[p] + .idx_col2rowm(dims[[p]]) - 1
    idx
  }
  idx <- lapply(pars, FUN = par_total_indexes)
  # vapply (not sapply) so an empty `pars` yields logical(0) instead of
  # list(), which `!` cannot negate
  nulls <- vapply(idx, is.null, logical(1))
  idx <- idx[!nulls]
  names(idx) <- pars[!nulls]
  idx
}

#' Starting flat index of each parameter in the concatenated layout (adapted from rstan)
#' @noRd
.calc_starts <- function(dims) {
  # 1-based flat index at which each parameter starts when all parameters
  # are laid out back to back; parameter i occupies .num_pars(dims[[i]])
  # slots. Returns numeric(0) for an empty `dims` (sapply + [1:len] used to
  # break in that case).
  sizes <- vapply(unname(dims), function(d) .num_pars(d), numeric(1))
  cumsum(c(1, sizes))[seq_along(sizes)]
}

#' Number of scalar elements in a parameter with dimensions d (adapted from rstan)
#' @noRd
.num_pars <- function(d) {
  # element count = product of the dimensions; a scalar has d = integer(0),
  # and prod() of an empty vector is 1
  n_elements <- prod(d)
  n_elements
}

#' Column-major to row-major index permutation for dimensions d (adapted from rstan)
#' @noRd
.idx_col2rowm <- function(d) {
  # Given one iteration of an array parameter stored column-major, return
  # the indexes that reorder it row-major.
  #   d: the dimensions of the parameter
  n_dims <- length(d)
  # scalars and vectors: the order is already the same
  if (n_dims == 0) return(1)
  if (n_dims == 1) return(1:d)
  # transpose an array of positions and read it back column-major
  positions <- array(1:prod(d), dim = d)
  as.vector(aperm(positions))
}

#' A wrapper around na_if that also works on factors
#' @noRd
.na_if <- function(x,to_na=NULL,discrete_mods=NULL) {
  # Replace values equal to `to_na` with NA.
  #   x: a factor or any vector accepted by dplyr::na_if()
  #   to_na: the value to turn into NA
  #   discrete_mods: accepted for compatibility; not used
  if (!is.factor(x)) {
    return(na_if(x, to_na))
  }
  # for factors, set the matching level to NA so that level is dropped and
  # its entries become NA
  lv <- levels(x)
  lv[lv == to_na] <- NA
  levels(x) <- lv
  x
}

#' Calculate priors for Gaussian processes
#' @noRd
.gp_prior <- function(time_points=NULL) {
  # Run one Fixed_param iteration of the precompiled Stan model
  # stanmodels[['gp_prior_tune']] and return its `a` and `b` draws as list(a=, b=).
  #
  # NOTE(review): min_diff and max_diff are computed below but never passed to
  # sampling() (there is no data argument), so time_points currently does not
  # affect the result. Confirm which data gp_prior_tune expects.

  # need to calculate minimum and maximum difference between *any* two points
  # NOTE(review): diff() only gives gaps between consecutive points, so this is
  # the minimum over all pairs only if time_points is sorted -- confirm.
  diff_time <- diff(time_points)
  min_diff <- min(diff_time)+1
  # NOTE(review): an earlier comment said max_diff is divided by 2, but the code
  # takes the gap between the first two points and multiplies it by 2.
  max_diff <- abs(time_points[1]-time_points[2])*2
  
  # now run the stan program (note: no data is passed -- see NOTE above)
  
  fit <- sampling(object = stanmodels[['gp_prior_tune']], iter=1, warmup=0, chains=1,
              seed=5838298, algorithm="Fixed_param")
  params <- extract(fit)
  
  return(list(a=c(params$a),
              b=c(params$b)))
  
}

#' Function to calculate IRFs
#' @noRd
.irf <- function( time=1,shock=1,
                  adj_in=NULL,
                  y_1=0,
                  total_t=10,
                  old_output=NULL) {

  # Compute an impulse response path y_t = adj_in * y_{t-1} + x_t for every
  # draw in adj_in, where the exogenous shock x_t equals `shock` at time 1
  # and 0 afterwards.
  #
  # Args:
  #   time: first time point to compute
  #   shock: size of the exogenous shock applied at time 1
  #   adj_in: vector of draws of the adjustment parameter (one path per draw)
  #   y_1: value(s) of y at the previous time point
  #   total_t: last time point to compute
  #   old_output: rows from earlier time points to prepend, or NULL
  #
  # Returns:
  #   A tibble with columns y_shock, time and iter (draw index), one row per
  #   draw per time point from `time` through `total_t` (at least one time point).
  #
  # A loop replaces the original self-recursion so long horizons do not hit
  # R's nesting limits, and rows are bound once instead of at every step.

  pieces <- if (is.null(old_output)) list() else list(old_output)

  repeat {
    # set up the exogenous shock
    # unless the shock comes from an exogenous covariate beta_x
    x_1 <- if (time == 1) shock else 0

    print(paste0('Now processing time point ',time))

    # Calculate current values of y and x given posterior uncertainty
    output <- data_frame(y_shock= adj_in*y_1 + x_1,
                         time=time,
                         iter=seq_along(adj_in))
    pieces[[length(pieces) + 1]] <- output

    # continue until the time limit is reached
    if (!(time < total_t)) break

    time <- time + 1
    y_1 <- output$y_shock
  }

  bind_rows(pieces)
}

#' Function to create table/matrix of which rows of the data
#' correspond to which model types.
#' @noRd
.count_cats <- function(modelpoints=NULL,
                        billpoints=NULL,
                        Y_int=NULL,
                        discrete=NULL,
                        within_chain=NULL,
                        pad_id=NULL) {

  # For model types 3-6, count the distinct observed outcome categories of
  # each (modelpoints, billpoints) item among rows with discrete == 1, and
  # turn those counts into per-row category-ordering indices.
  #
  # Args:
  #   modelpoints, billpoints, discrete: per-row model type, item index and
  #     discrete indicator; only rows with discrete == 1 are used
  #   Y_int: integer outcomes. NOTE(review): Y_int is not subset by
  #     discrete == 1 while modelpoints/billpoints are, so this assumes Y_int
  #     already has one entry per discrete row -- confirm with callers.
  #   within_chain, pad_id: not used in this function
  #
  # Returns:
  #   list(order_cats_rat, order_cats_grm, n_cats_rat, n_cats_grm):
  #   order_cats_* have one entry per discrete row (0 where no count is
  #   available); n_cats_* have one entry per category count 3:10, equal to
  #   that count when it occurs and 1L otherwise.

  if(length(Y_int)>1 && any(unique(modelpoints) %in% c(3,4,5,6))) {
      
      # count cats for ordinal models
      # NOTE(review): the first group_by(billpoints) is immediately replaced by
      # group_by(modelpoints,billpoints). summarize() drops only the last
      # grouping level by default, so the mutate() below is still grouped by
      # modelpoints and factor() numbers the category counts separately within
      # each model type -- confirm this is intended. Also note that rows of
      # types 5/6 keep num_cats in num_cats_rat (and 3/4 in num_cats_grm), and
      # type 4/6 items with fewer than 4 categories are set to 3L, not 4L.
      
      get_counts <- group_by(tibble(modelpoints=modelpoints[discrete==1],
                                    billpoints=billpoints[discrete==1],
                                    Y_int),billpoints) %>% 
        group_by(modelpoints,billpoints) %>% 
        summarize(num_cats=length(unique(Y_int))) %>% 
        mutate(num_cats_rat=if_else(modelpoints==3 & num_cats<3,
                                3L,
                                num_cats),
               num_cats_rat=if_else(modelpoints==4 & num_cats<4,
                                3L,
                                num_cats_rat),
               order_cats_rat=as.numeric(factor(num_cats_rat)),
               num_cats_grm=if_else(modelpoints==5 & num_cats<3,
                                    3L,
                                    num_cats),
               num_cats_grm=if_else(modelpoints==6 & num_cats<4,
                                    3L,
                                    num_cats_grm),
               order_cats_grm=as.numeric(factor(num_cats_grm))) 
      
      num_cats_rat <- sort(unique(get_counts$num_cats_rat))
      num_cats_grm <- sort(unique(get_counts$num_cats_grm))
      
      # need to zero out non-present categories: each slot 3:10 keeps its
      # count if it occurs, otherwise it is a 1L placeholder
      
      n_cats_rat <- ifelse(3:10 %in% num_cats_rat,3:10,1L)
      n_cats_grm <- ifelse(3:10 %in% num_cats_grm,3:10,1L)
      
      # join the per-item ordering indices back onto every discrete row
      
      out_data <- left_join(tibble(modelpoints=modelpoints[discrete==1],
                                   billpoints=billpoints[discrete==1],
                                   Y_int),
                            select(get_counts,
                                   -num_cats_rat,
                                   -num_cats_grm),
                            by=c("modelpoints","billpoints")) 
      out_data$order_cats_grm[is.na(out_data$order_cats_grm)] <- 0L
      out_data$order_cats_rat[is.na(out_data$order_cats_rat)] <- 0L
      
    } else {
      
      # no ordinal items: zero ordering indices and 1L placeholders for 3:10
      out_data <- tibble(order_cats_grm=rep(0L,length(modelpoints[discrete==1])),
                         order_cats_rat=rep(0L,length(modelpoints[discrete==1])))
      n_cats_rat <- rep(1L,length(3:10))
      n_cats_grm <- rep(1L,length(3:10))
    }
  
  return(list(order_cats_rat=out_data$order_cats_rat,
              order_cats_grm=out_data$order_cats_grm,
              n_cats_rat=n_cats_rat,
              n_cats_grm=n_cats_grm))
}

#' Function to figure out how to remove missing values from
#' data before running models.
#' @noRd
.remove_nas <- function( Y_int=NULL,
                        Y_cont=NULL,
                        discrete=NULL,
                        legispoints=NULL,
                        billpoints=NULL,
                        timepoints=NULL,
                        modelpoints=NULL,
                        ordered_id=NULL,
                        idealdata=NULL,
                        time_ind=NULL,
                        time_proc=NULL,
                        gp_sd_par=NULL,
                        num_diff=NULL,
                        m_sd_par=NULL,
                        min_length=NULL,
                        const_type=NULL,
                        legis_sd=NULL,
                        restrict_sd_high=NULL,
                        restrict_sd_low=NULL,
                        restrict_high=NULL,
                        restrict_low=NULL,
                        ar_sd=NULL,
                        diff_reg_sd=NULL,
                        diff_miss_sd=NULL,
                        discrim_reg_sd=NULL,
                        discrim_miss_sd=NULL,
                        fix_high=NULL,
                        fix_low=NULL) {
  

  # need to determine which missing values should not be considered
  # only remove missing values if non-inflated model is used

  if(length(Y_cont)>1 && !is.na(idealdata@miss_val[2])) {

      Y_cont <- ifelse(modelpoints %in% c(10,12),
                       Y_cont,
                       .na_if(Y_cont,as.numeric(idealdata@miss_val[2])))
  }

  if(length(Y_int)>1 && !is.na(idealdata@miss_val[1])) {

      Y_int <- if_else(modelpoints %in% c(0,2,
                                                      4,
                                                      6,
                                                      8,
                                                      14),
                      Y_int,
                      .na_if(Y_int,idealdata@miss_val[1]))
      
    
    # need to downward adjust Y_int
    # convert from factor back to numeric as we have dealt with missing data
    # drop unused levels
    # need to get back to zero index
    if(levels(Y_int)[1]=="0") {
      
      Y_int <- as.numeric(Y_int) - 1
      
    } else {
      
      Y_int <- as.numeric(Y_int)
      
    }
    
    
    #Y_int[modelpoints %in% c(1,2) & Y_int<3] <- Y_int[modelpoints %in% c(1,2)  & Y_int<3] - 1

  }
  
  idealdata@Y_int <- Y_int
  idealdata@Y_cont <- Y_cont

  # now need to calculate the true remove NAs
  # if within chain, we by definition won't have any NAs
  remove_nas_cont <- !is.na(Y_cont)
  remove_nas_int <- !is.na(Y_int)

  # this works because the data were sorted in id_make
  if(length(Y_cont)>1 && length(Y_int)>1) {

      remove_nas <- remove_nas_int & remove_nas_cont
    
  } else if(length(Y_cont)>1) {
    remove_nas <- remove_nas_cont
  } else {
    remove_nas <- remove_nas_int
  }
   
    if(length(Y_cont)>1) {
      Y_cont <- Y_cont[remove_nas]
      N_cont <- length(Y_cont)
    } else {
      N_cont <- 0
      Y_cont <- array(dim=c(0)) + 0
    }
    
    if(length(Y_int)>1) {
      Y_int <- Y_int[remove_nas]
      N_int <- length(Y_int)
    } else {
      N_int <- 0
      Y_int <- array(dim=c(0)) + 0L
    }
    
    N <- pmax(N_int, N_cont)
    
    legispoints <- legispoints[remove_nas]
    billpoints <- billpoints[remove_nas]
    timepoints <- timepoints[remove_nas]
    modelpoints <- modelpoints[remove_nas]
    ordered_id <- ordered_id[remove_nas]
    discrete <- discrete[remove_nas]
    
    # no padding necessary
    
    if(any(unique(modelpoints) %in% c(1,13)) && length(table(Y_int[modelpoints %in% c(1,13)]))>3) {
      stop('Too many values in score matrix for a binary model. Choose a different model_type.')
    } else if(any(unique(modelpoints) %in% c(2,14)) && length(table(Y_int[modelpoints %in% c(2,14)]))>4) {
      stop("Too many values in score matrix for a binary model. Choose a different model_type.")
    }
    
    # use zero values for map_rect stuff
    
    all_int_array <- array(dim=c(0,0)) + 0L
    all_cont_array <- array(dim=c(0,0)) + 0L
    
    Y_cont_map <- 0
    N_cont_map <- 0
    Y_int_map <- 0
    N_int_map <- 0
    N_cont_map <- 0
    N_map <- 0
    
    # create covariates
    
    legis_pred <- as.matrix(select(idealdata@score_matrix,
                                      idealdata@person_cov))[remove_nas,,drop=F]
    
    srx_pred <- as.matrix(select(idealdata@score_matrix,
                                 idealdata@item_cov))[remove_nas,,drop=F]
    
    sax_pred <- as.matrix(select(idealdata@score_matrix,
                                 idealdata@item_cov_miss))[remove_nas,,drop=F]
    
    LX <- ncol(legis_pred)
    SRX <- ncol(srx_pred)
    SAX <- ncol(sax_pred)
    
    if(!is.infinite(max(Y_int))) {
      
      if(N_cont>0) {
        
        # Top level is always joint posterior
        
        y_int_miss <- max(Y_int) - 1
        
      } else {
        
        y_int_miss <- max(Y_int)
        
      }
      
      
      
    } else {
      y_int_miss <- 0
    }
    
    if(!is.infinite(max(Y_cont))) {
      y_cont_miss <- max(Y_cont)
    } else {
      y_cont_miss <- 0
    }
 

  max_t <- max(timepoints,na.rm=T)
  num_bills <- max(billpoints,na.rm=T)
  num_legis <- max(legispoints)
  
  if(any(c(5,6) %in% modelpoints)) {
    num_bills_grm <- num_bills
  } else {
    num_bills_grm <- 0L
  }
  
  if(any(c(13,14) %in% modelpoints)) {
    num_ls <- num_legis
  } else {
    num_ls <- 0L
  }
  
  
  
  
  # now need to determine number of categories

  # need to calculate number of categories for ordinal models
  
  if(N_int>0) {
    
    order_cats_rat <- ordered_id
    order_cats_grm <- ordered_id
    
  } else {
    
    order_cats_rat <- array(dim=c(0)) + 0L
    order_cats_grm <- array(dim=c(0)) + 0L
    
  }

    
    
    if(any(modelpoints %in% c(3,4))) {
      n_cats_rat <- unique(order_cats_rat)
    } else {
      n_cats_rat <- 0
    } 
    
    if(any(modelpoints %in% c(5,6))) {
      n_cats_grm <- unique(order_cats_grm)
    } else {
      n_cats_grm <- 0
    }
  
  n_cats_rat <- sapply(3:10,function(s) {
    if(s %in% n_cats_rat) {
      s
    } else {
      1
    }
  })
  
  n_cats_grm <- sapply(3:10,function(s) {
    if(s %in% n_cats_grm) {
      s
    } else {
      1
    }
  })
  

    
    if(length(time_ind)==1) {
      tibble_time <- tibble(time_ind=rep(time_ind,nrow(idealdata@score_matrix)))
    }
  
  
    return(list(Y_int=Y_int,
                Y_cont=Y_cont,
                legispoints=legispoints,
                billpoints=billpoints,
                timepoints=timepoints,
                modelpoints=modelpoints,
                remove_nas=remove_nas,
                N=N,
                y_cont_miss=y_cont_miss,
                y_int_miss=y_int_miss,
                discrete=discrete,
                max_t=max_t,
                num_bills=num_bills,
                num_legis=num_legis,
                num_ls=num_ls,
                num_bills_grm=num_bills_grm,
                N_cont=N_cont,
                N_int=N_int,
                order_cats_rat=order_cats_rat,
                order_cats_grm=order_cats_grm,
                n_cats_rat=n_cats_rat,
                n_cats_grm=n_cats_grm,
                legis_pred=legis_pred,
                srx_pred=srx_pred,
                sax_pred=sax_pred,
                LX=LX,
                SRX=SRX,
                SAX=SAX,
                idealdata=idealdata))
  
  
}

# Functions to calculate average predicted outcomes from linear predictors

#' Bernoulli
#'
#' Average predicted probability of a positive outcome, centered so that a
#' probability of 0.5 maps to 0.
#'
#' @param lin_val Numeric vector of linear-predictor values on the logit scale.
#' @param ... Ignored; accepted so all `.cov_*` helpers share one signature.
#' @return A single number: the mean of `plogis(lin_val) - 0.5`.
#' @noRd
.cov_bern <- function(lin_val=NULL,...) {
  prob_yes <- plogis(lin_val)
  mean(prob_yes - 0.5)
}

#' Ordinal outcomes
#'
#' Average predicted probability of a single ordinal category under a
#' cumulative-logit model, given the cutpoint draw(s) bounding that category.
#'
#' @param lin_val Numeric vector of linear-predictor values.
#' @param covp Cutpoint values for the category, in the layout returned by
#'   `.get_cuts_cov()`: one column for the lowest or highest category, two
#'   columns (lower cutpoint, upper cutpoint) for an interior category.
#' @param K Category index used when `k` is not supplied (`K == 1` selects the
#'   lowest category).
#' @param k Optional category index (1 = lowest category); takes precedence
#'   over `K` when supplied.
#' @param ... Ignored; accepted so all `.cov_*` helpers share one signature.
#' @return A single number: the mean predicted probability of the category.
#' @noRd
.cov_ord <- function(lin_val=NULL,
                      covp=NULL,
                      K=NULL,
                      k=NULL,
                      ...) {

  # category index: prefer an explicit `k`, otherwise fall back to `K`
  cat_k <- if (is.null(k)) K else k

  if (is.null(cat_k)) {
    stop("Either `k` or `K` must be supplied to .cov_ord().", call. = FALSE)
  }

  covp <- as.matrix(covp)

  if (cat_k == 1) {
    # lowest category: P(Y = 1) = 1 - logit^-1(eta - c_1)
    1 - mean(plogis(lin_val - covp[, 1]))
  } else if (ncol(covp) >= 2) {
    # interior category: only these are bounded by two cutpoints
    mean(plogis(lin_val - covp[, 1]) - plogis(lin_val - covp[, 2]))
  } else {
    # highest category: P(Y = K) = logit^-1(eta - c_{K-1}); averaged like the
    # other branches so the helper always returns a scalar
    mean(plogis(lin_val - covp[, 1]))
  }
}

#' Poisson
#'
#' Average expected count under a log link.
#'
#' @param lin_val Numeric vector of linear-predictor values on the log scale.
#' @param ... Ignored; accepted so all `.cov_*` helpers share one signature.
#' @return A single number: the mean of `exp(lin_val)`.
#' @noRd
.cov_pois <- function(lin_val,...) {
  expected_counts <- exp(lin_val)
  mean(expected_counts)
}

#' Normal
#'
#' Average predicted value under an identity link.
#'
#' @param lin_val Numeric vector of linear-predictor values.
#' @param ... Ignored; accepted so all `.cov_*` helpers share one signature.
#' @return A single number: the mean of `lin_val`.
#' @noRd
.cov_norm <- function(lin_val,...) {
  # identity link: the linear predictor is already on the outcome scale
  fitted_mean <- mean(lin_val)
  fitted_mean
}

#' Log-Normal
#'
#' Average back-transformed prediction, obtained by exponentiating the linear
#' predictor.
#'
#' NOTE(review): `exp(lin_val)` is the median of a log-normal whose log-mean is
#' `lin_val`; the mean would also need the scale parameter, which this helper
#' does not receive. Confirm that the median is what callers intend.
#'
#' @param lin_val Numeric vector of linear-predictor values on the log scale.
#' @param ... Ignored; accepted so all `.cov_*` helpers share one signature.
#' @return A single number: the mean of `exp(lin_val)`.
#' @noRd
.cov_lnorm <- function(lin_val,...) {
  back_transformed <- exp(lin_val)
  mean(back_transformed)
}

#' Latent-Space
#'
#' Average predicted probability under a logit link, centered so that a
#' probability of 0.5 maps to 0 (the same calculation as `.cov_bern`).
#'
#' @param lin_val Numeric vector of linear-predictor values on the logit scale.
#' @param ... Ignored; accepted so all `.cov_*` helpers share one signature.
#' @return A single number: the mean of `plogis(lin_val) - 0.5`.
#' @noRd
.cov_latsp <- function(lin_val,...) {
  centered_prob <- plogis(lin_val) - 0.5
  mean(centered_prob)
}

#' How to find cutpoints for id_plot_cov function
#'
#' Pull the posterior draws of the cutpoint(s) bounding ordinal category `k`.
#'
#' @param k Category index (1 = lowest category).
#' @param m Model type code: `3`/`4` select the rating-scale cutpoints (shared
#'   across items); any other value selects the graded-response cutpoints
#'   (item-specific).
#' @param i Position in `these_items` of the item whose graded-response
#'   cutpoints are needed (ignored for rating-scale models).
#' @param sigma_all Unused.
#' @param K Number of categories; also the numeric suffix of the Stan
#'   cutpoint parameter name.
#' @param obj Fitted model object; its `stan_samples` slot is passed to
#'   `rstan::extract()`.
#' @param these_items Vector of item indices.
#' @return A matrix of draws: one column for the lowest or highest category,
#'   two columns (lower cutpoint, then upper cutpoint) for an interior category.
#' @noRd
.get_cuts_cov <- function(k=NULL,
                          m=NULL,
                          i=NULL,
                          sigma_all=NULL,
                          K=NULL,
                          obj=NULL,
                          these_items=NULL) {

  # draws of one named parameter element, as a matrix
  pull_draws <- function(par_name) {
    as.matrix(rstan::extract(obj@stan_samples, par_name)[[1]])
  }

  if (m %in% c(3,4)) {

    # rating scale: one cutpoint vector shared by all items, indexed [cut]
    par_base <- paste0("steps_votes", K)
    index_of <- function(j) paste0(par_base, "[", j, "]")

  } else {

    # graded response: item-specific cutpoints, indexed [item, cut].
    # All categories must read the same parameter (previously only k == 1
    # used the GRM name).
    # NOTE(review): confirm "grm_steps_votes" matches the Stan parameter name.
    item_num <- these_items[i]
    par_base <- paste0("grm_steps_votes", K)
    index_of <- function(j) paste0(par_base, "[", item_num, ",", j, "]")

  }

  if (k == 1) {
    # lowest category: only an upper cutpoint
    pull_draws(index_of(k))
  } else if (k > 1 && k < K) {
    # interior category: lower and upper cutpoints
    cbind(pull_draws(index_of(k - 1)), pull_draws(index_of(k)))
  } else {
    # highest category: only a lower cutpoint
    pull_draws(index_of(k - 1))
  }
}

#' Function to square data for map_rect
#'
#' Sort the (optionally filtered) long-format data by the ID the model maps
#' over, then time, and record the first and last row each ID occupies.
#'
#' @param this_data Data frame with `person_id`, `item_id` and `time_id`
#'   columns (plus `group_id` when `use_groups` is `TRUE`).
#' @param map_over_id `"persons"` to map over persons (or groups); any other
#'   value maps over items.
#' @param use_groups When `TRUE` and mapping over persons, use `group_id`
#'   instead of `person_id`.
#' @param remove_nas Logical row filter applied first; `NULL` keeps all rows.
#' @return A list with `sum_vals` (one row per ID with columns ID, `start`,
#'   `end`; the ID is coerced with `as.numeric()`) and `this_data` (the
#'   filtered data sorted by ID then time, with an added `orig_order` column).
#' @noRd
.make_sum_vals <- function(this_data,map_over_id=NULL,use_groups=FALSE,
                           remove_nas=NULL) {

  if (!is.null(remove_nas)) {
    this_data <- this_data %>%
      filter(remove_nas)
  }

  # need to save original order to reconvert if necessary
  this_data$orig_order <- seq_len(nrow(this_data))

  # first/last row number per group. summarise() (unlike spread()) still
  # yields an `end` column when every ID occupies a single row.
  row_bounds <- function(grouped) {
    grouped %>%
      dplyr::summarise(start = min(rownum), end = max(rownum)) %>%
      ungroup()
  }

  if (map_over_id == "persons") {

    if (use_groups) {

      this_data <- dplyr::arrange(this_data, group_id, time_id)

      sum_vals <- this_data %>%
        mutate(rownum = row_number()) %>%
        group_by(group_id) %>%
        row_bounds() %>%
        mutate(group_id = as.numeric(group_id))

    } else {

      this_data <- dplyr::arrange(this_data, person_id, time_id)

      sum_vals <- this_data %>%
        mutate(rownum = row_number()) %>%
        group_by(person_id) %>%
        row_bounds() %>%
        mutate(person_id = as.numeric(person_id))

    }

  } else {

    this_data <- dplyr::arrange(this_data, item_id, time_id)

    sum_vals <- this_data %>%
      mutate(rownum = row_number()) %>%
      group_by(item_id) %>%
      row_bounds() %>%
      mutate(item_id = as.numeric(item_id))

  }

  return(list(sum_vals=sum_vals,this_data=this_data))

}


#' Re-create time-varying ideal points from the fitted model's draws
#'
#' The result depends on how the model was mapped:
#' - Mapped over items: the draws of `L_tp1` are returned directly.
#' - Mapped over persons, random-walk model (`time_proc == 2`, fewer than 50
#'   time points): the series is rebuilt recursively per person from
#'   `L_tp1_var`.
#' - Mapped over persons, AR(1) model (`time_proc == 3`, more than 50 time
#'   points): the series is rebuilt the same way.
#' - Any other case: the draws of `L_tp1_var` are returned as-is.
#'
#' @param obj Fitted model object. Uses the slots `use_groups`, `score_data`,
#'   `map_over_id`, `stan_samples`, `time_proc`, `restrict_var` and
#'   `time_fix_sd`.
#' @return A draws matrix with one row per draw. In the two rebuilt branches
#'   the columns are named `L_tp1[t,p]`, with time varying fastest.
#' @importFrom tidyr unite
#' @noRd
.get_varying <- function(obj) {
  
  # when persons are pooled into groups, the group ID stands in for person_id
  if(obj@use_groups) {
    obj@score_data@score_matrix$person_id <- obj@score_data@score_matrix$group_id
  }
  
  if(obj@map_over_id=="items") {
    
    # needs to be in the same format, varying in T then person
      
      all_time <- obj@stan_samples$draws("L_tp1") %>% as_draws_matrix()
    
  } else {
      
    L_tp1_var <- obj@stan_samples$draws("L_tp1_var") %>% as_draws_matrix()
    
    
    # Random-walk reconstruction.
    # NOTE(review): this branch needs < 50 time points while the AR(1) branch
    # below needs > 50, so exactly 50 time points falls through to the final
    # `else`. Confirm these thresholds match the Stan model.
    if(obj@time_proc==2 && length(unique(obj@score_data@score_matrix$time_id))<50) {
      
        
        L_full <- obj@stan_samples$draws("L_full") %>% as_draws_matrix()
        
        time_var_free <- obj@stan_samples$draws("time_var_free") %>% as_draws_matrix()
      
      #make a grid, time varying fastest
      # (assumed to match the column order of L_tp1_var -- TODO confirm)
      
      time_grid <- expand.grid(1:length(unique(obj@score_data@score_matrix$time_id)),
                               unique(as.numeric(obj@score_data@score_matrix$person_id)))
      
      # Recursive helper: appends the estimate for time t to `prior_est` (one
      # column per time point, one row per draw), then recurses to t + 1.
      # When t reaches max(points) it returns a tibble with one column per
      # time point plus `person_id` and `iter`. `initial` (the t == 1 column)
      # is only used at t == 2; later calls carry it inside `prior_est`.
      # `L_full` is passed along but not used in this random-walk version.
      time_func <- function(t=NULL,
                            points=NULL,
                            prior_est=NULL,
                            time_var_free=NULL,
                            initial=NULL,
                            L_full=NULL,
                            p=NULL,
                            L_tp1_var=NULL) {
        
        # Innovation scale. With restrict_var, person 1 uses the fixed
        # obj@time_fix_sd and persons p > 1 read column p - 1 of
        # time_var_free; otherwise every person reads column p.
        if(obj@restrict_var) {
          
          time_fix_sd <- obj@time_fix_sd
          p_time <- p - 1
          
        } else {
          
          time_fix_sd <- time_var_free[,p]
          p_time <- p
          
        }
        
        # NOTE(review): the innovation here is read at Var1 == t-1, whereas
        # the AR(1) branch reads Var1 == t. At t == 2 this reuses the same
        # L_tp1_var column as `initial`. Confirm against the Stan model.
        if(p>1) {
          
          if(t==2) {
            
            prior_est <- initial + time_var_free[,p_time]*L_tp1_var[,(time_grid$Var1==(t-1) & time_grid$Var2==p)]
            
            prior_est <- cbind(initial,prior_est)
            
          } else {
            
            this_t <- prior_est[,t-1]  + time_var_free[,p_time]*L_tp1_var[,(time_grid$Var1==(t-1) & time_grid$Var2==p)]
            prior_est <- cbind(prior_est,this_t)
            
            
          }
          
          if(t<max(points)) { 
            
            time_func(t=t+1,
                      points=points,
                      prior_est=prior_est,
                      time_var_free=time_var_free,
                      p=p,
                      L_full=L_full,
                      L_tp1_var=L_tp1_var)
          } else {
            # break recursion
            
            out_d <- as_tibble(prior_est) 
            names(out_d) <- as.character(1:length(unique(obj@score_data@score_matrix$time_id)))
            
            out_d <- mutate(out_d,person_id=p,
                            iter=1:n())
            
            return(out_d)
          }
          
        } else {
          
          # person 1: same update, using time_fix_sd as the scale
          if(t==2) {
            
            prior_est <- initial + time_fix_sd*L_tp1_var[,(time_grid$Var1==(t-1) & time_grid$Var2==p)]
            
            prior_est <- cbind(initial,prior_est)
            
          } else {
            
            this_t <- prior_est[,t-1]  + time_fix_sd*L_tp1_var[,(time_grid$Var1==(t-1) & time_grid$Var2==p)]
            prior_est <- cbind(prior_est,this_t)
            
            
          }
          
          if(t<max(points)) { 
            
            time_func(t=t+1,
                      points=points,
                      prior_est=prior_est,
                      time_var_free=time_var_free,
                      p=p,
                      L_full=L_full,
                      L_tp1_var=L_tp1_var)
          } else {
            # break recursion
            
            out_d <- as_tibble(prior_est) 
            names(out_d) <- as.character(1:length(unique(obj@score_data@score_matrix$time_id)))
            
            out_d <- mutate(out_d,person_id=p,
                            iter=1:n())
            
            return(out_d)
          }
          
        }
        
        
        # we don't do anything here because we need to return results from the
        # enclosing function call above
        
      }
      
      # one tibble per person, starting the recursion at t = 2 with the
      # t = 1 column of L_tp1_var as the initial value
      all_time <- lapply(unique(as.numeric(obj@score_data@score_matrix$person_id)), 
                         function (p) {
        
          initial <- L_tp1_var[,(time_grid$Var1==1 & time_grid$Var2==p)]
        
          time_func(t=2,
                           points=1:length(unique(obj@score_data@score_matrix$time_id)),
                           time_var_free=time_var_free,
                           initial=initial,
                           p=p,
                           L_tp1_var=L_tp1_var)
          
          
        }) %>% bind_rows()
      
      # need to reformat by spreading in correct manner
      # one row per sample
      # make joint time-person IDs
      # (key3 numbers the time/person pairs in person-then-time order, so
      # columns come out with time varying fastest, matching time_grid)
      
      all_time <- gather(all_time,"time_id",value="estimate",
                         -person_id,-iter) %>% 
        mutate(time_id=as.numeric(time_id)) %>% 
        arrange(person_id,time_id) %>% 
        unite(col='key',time_id,person_id) %>% 
        mutate(key2=factor(key,levels=unique(key)),
               key3=as.numeric(key2)) %>% 
        select(-key,-key2) %>% 
        spread(key="key3",value="estimate") %>% 
        select(-iter) %>% 
        as.matrix
      
      time_grid <- expand.grid(1:length(unique(obj@score_data@score_matrix$time_id)),
                               unique(as.numeric(obj@score_data@score_matrix$person_id)))
      
      colnames(all_time) <- paste0("L_tp1[",time_grid$Var1,",",time_grid$Var2,"]")

      
    } else if(obj@time_proc==3  && length(unique(obj@score_data@score_matrix$time_id))>50) {
        
        # AR(1) reconstruction: x_t = L_full[p] + L_AR1[p] * x_{t-1} + scale * innovation_t
        
        L_full <- obj@stan_samples$draws("L_full") %>% as_draws_matrix()
        
        time_var_free <- obj@stan_samples$draws("time_var_free") %>% as_draws_matrix()
        
        L_AR1 <- obj@stan_samples$draws("L_AR1") %>% as_draws_matrix()
        
      #make a grid, time varying fastest
      
      time_grid <- expand.grid(1:length(unique(obj@score_data@score_matrix$time_id)),
                               unique(as.numeric(obj@score_data@score_matrix$person_id)))
      
      # what we use for the recursion
      # (same structure as the random-walk helper above, plus the L_full
      # intercept and the L_AR1 coefficient; innovations read at Var1 == t)
      
      time_func <- function(t=NULL,
                            points=NULL,
                            prior_est=NULL,
                            time_var_free=NULL,
                            initial=NULL,
                            L_AR1=NULL,
                            L_full=NULL,
                            p=NULL,
                            L_tp1_var=NULL) {
        
          
        # innovation scale: same person-1 / offset rule as the random walk
        if(obj@restrict_var) {
          
          time_fix_sd <- obj@time_fix_sd
          p_time <- p - 1
          
        } else {
          
          time_fix_sd <- time_var_free[,p]
          p_time <- p
          
        }
        

          if(p>1) {
            
            if(t==2) {
              
              prior_est <- L_full[,p] + L_AR1[,p]*initial + time_var_free[,p_time]*L_tp1_var[,(time_grid$Var1==t & time_grid$Var2==p)]
              
              prior_est <- cbind(initial,prior_est)
              
            } else {
              
              this_t <- L_full[,p] + L_AR1[,p]*prior_est[,t-1]  + time_var_free[,p_time]*L_tp1_var[,(time_grid$Var1==t & time_grid$Var2==p)]
              prior_est <- cbind(prior_est,this_t)
              
              
            }
            
            if(t<max(points)) { 
              
              time_func(t=t+1,
                        points=points,
                        prior_est=prior_est,
                        time_var_free=time_var_free,
                        p=p,
                        L_AR1=L_AR1,
                        L_full=L_full,
                        L_tp1_var=L_tp1_var)
            } else {
              # break recursion
              
              out_d <- as_tibble(prior_est) 
              names(out_d) <- as.character(1:length(unique(obj@score_data@score_matrix$time_id)))
              
              out_d <- mutate(out_d,person_id=p,
                              iter=1:n())
              
              return(out_d)
            }
            
          } else {
            
            # person 1: same update, using time_fix_sd as the scale
            if(t==2) {
              
              prior_est <- L_full[,p] + L_AR1[,p]*initial + time_fix_sd*L_tp1_var[,(time_grid$Var1==t & time_grid$Var2==p)]
              
              prior_est <- cbind(initial,prior_est)
              
            } else {
              
              this_t <- L_full[,p] + L_AR1[,p]*prior_est[,t-1]  + time_fix_sd*L_tp1_var[,(time_grid$Var1==t & time_grid$Var2==p)]
              prior_est <- cbind(prior_est,this_t)
              
              
            }
            
            if(t<max(points)) { 
            
              time_func(t=t+1,
                        points=points,
                        prior_est=prior_est,
                        time_var_free=time_var_free,
                        p=p,
                        L_AR1=L_AR1,
                        L_full=L_full,
                        L_tp1_var=L_tp1_var)
            } else {
              # break recursion
              
              out_d <- as_tibble(prior_est) 
              names(out_d) <- as.character(1:length(unique(obj@score_data@score_matrix$time_id)))
              
              out_d <- mutate(out_d,person_id=p,
                              iter=1:n())
              
              return(out_d)
            }
            
          }
        
        
        # we don't do anything here because we need to return results from the
        # enclosing function call above
        
      }
      
      # one tibble per person, starting the recursion at t = 2
      all_time <- lapply(unique(as.numeric(obj@score_data@score_matrix$person_id)), 
                         function (p) {
                           
                           initial <- L_tp1_var[,(time_grid$Var1==1 & time_grid$Var2==p)]
                           
                           out_d <- time_func(t=2,
                                     points=1:length(unique(obj@score_data@score_matrix$time_id)),
                                     time_var_free=time_var_free,
                                     initial=initial,
                                     p=p,
                                     L_full=L_full,
                                     L_AR1=L_AR1,
                                     L_tp1_var=L_tp1_var)
                           
                           return(out_d)
                           
                         }) %>% bind_rows()
      
      # need to reformat by spreading in correct manner
      # one row per sample
      # make joint time-person IDs
      
      all_time <- gather(all_time,"time_id",value="estimate",
                         -person_id,-iter) %>% 
                  mutate(time_id=as.numeric(time_id)) %>% 
                  arrange(person_id,time_id) %>% 
                  unite(col='key',time_id,person_id) %>% 
                  mutate(key2=factor(key,levels=unique(key)),
                         key3=as.numeric(key2)) %>% 
                  select(-key,-key2) %>% 
                  spread(key="key3",value="estimate") %>% 
        select(-iter) %>% 
        as.matrix
      
      time_grid <- expand.grid(1:length(unique(obj@score_data@score_matrix$time_id)),
                               unique(as.numeric(obj@score_data@score_matrix$person_id)))
      
      colnames(all_time) <- paste0("L_tp1[",time_grid$Var1,",",time_grid$Var2,"]")

      
    } else {
      
      # GP or random walk and AR(1) but with centered time series   
      # (L_tp1_var draws are returned unchanged)
      
      all_time <- L_tp1_var <- obj@stan_samples$draws("L_tp1_var") %>% as_draws_matrix()
      
    } 
    
  } # end of if statement differentiating between mapping over items vs. persons
    
    return(all_time)
  }
# saudiwin/idealstan documentation built on Sept. 2, 2023, 1:29 a.m.