R/csurvey.R

Defines functions .get_csurvey_option .onLoad ftable.csvy svycontrast.csvy coef.csvy SE.csvy barplot.csvy deff.csvy predict.csvy print.summary.csvy summary.csvy vcov.csvy print.csvy confint.csvy fitted.csvy fn_nbors_3 fn_nbors_2 fn_nbors impute_em2 makeamat_2d_block makeamat_2d makeamat_2D makeamat block.Ord getbin getvars makebin plotpersp.csvy plotpersp_csvy_control reorder_outputs recode_ordered_factor_to_numeric capitalize_first compute_bstat compute_mixture match_rows safe_svyby get_design_variables get_complete_design get_varname csvy.fit csvy

Documented in barplot.csvy block.Ord coef.csvy confint.csvy csvy plotpersp.csvy plotpersp_csvy_control predict.csvy summary.csvy vcov.csvy

#############################
## main routine for the user#
#############################
# Fit a shape-constrained survey domain model (user-facing entry point).
#
# Arguments:
#   formula    Model formula. A predictor whose model-frame column carries a
#              numeric "shape" attribute is treated as constrained; its "nm",
#              "categ", "order" and "relax" attributes are read as well. Any
#              other predictor is recorded with shape code 0.
#   design     Survey design object. It is row-subsetted with
#              `design[rows, ]` when `subset` is given, otherwise passed
#              through get_complete_design().
#   family     A family object, a family function, or the name of one.
#   multicore, level, n.mix, test
#              Forwarded unchanged to csvy.fit().
#   subset     Optional expression evaluated in model.frame(design); the
#              result must not contain NA.
#
# Returns:
#   The object returned by csvy.fit(), extended with survey.design, data,
#   na.action, call, family, terms, xmat_add, xnms_add, shapes_add, ynm,
#   levels_old and replaced, with class c("csvy", "cgam", "svyglm", "glm").
#
# Errors:
#   Stops if `subset` evaluates to NA values, if `family` cannot be resolved
#   to a family object, or if every predictor has shape code 0.
csvy <- function(formula, design, family = stats::gaussian(),
                 multicore = getOption("csurvey.multicore"), level = 0.95,
                 n.mix = 100L, test = FALSE, subset = NULL) {
  cl <- match.call()

  # Remember the design exactly as supplied. It is stored on the fit as
  # `survey.design` and must exist on BOTH the subset and the no-subset path
  # (previously it was only defined when `subset` was NULL, so any call with
  # a subset failed with "object 'design0' not found").
  design0 <- design

  # Handle subset argument
  subset_expr <- substitute(subset)
  subset_eval <- eval(subset_expr, model.frame(design), parent.frame())
  if (!is.null(subset_eval)) {
    if (anyNA(subset_eval)) stop("subset must not contain NA values")
    # NOTE(review): this path does not run get_complete_design(); confirm
    # whether rows with missing model variables should also be dropped here.
    design <- design[subset_eval, ]
  } else {
    cleaned <- get_complete_design(formula, design)
    design <- cleaned$design
  }

  # Resolve `family` to a family object (accepts name, function, or object).
  if (is.character(family)) {
    family <- get(family, mode = "function", envir = parent.frame())
  }
  if (is.function(family)) {
    family <- family()
  }
  if (is.null(family$family)) {
    stop("'family' not recognized!")
  }

  formula <- terms.formula(formula)
  mf <- model.frame(formula, data = get_design_variables(design))
  ynm <- names(mf)[1]
  mt <- attr(mf, "terms")
  y <- model.response(mf, "any")

  # Initialization
  n_pred <- ncol(mf) - 1
  block.ave.lst <- vector("list", length = n_pred)
  block.ord.lst <- vector("list", length = n_pred)
  xmat_add <- NULL
  xnms_add <- NULL
  shapes_add <- NULL
  relaxes <- NULL
  cnms <- colnames(mf)
  # One-row flag table: TRUE where a factor predictor was recoded to numeric.
  replaced <- as.data.frame(matrix(FALSE, nrow = 1, ncol = n_pred))
  levels_old <- vector("list", length = n_pred)
  # Local accumulator for factor levels with zero observations. Kept as a
  # plain local list so that a same-named object in the caller's or global
  # environment can never leak into the note printed below.
  zero_level_warnings <- list()

  # Process each predictor. seq_len() keeps the loop empty when the formula
  # has no predictors (seq(2, 1) would iterate over c(2, 1)).
  for (i in seq_len(n_pred) + 1L) {
    xi <- mf[, i]
    attr_xi <- attributes(xi)
    shape <- attr_xi$shape
    relax <- isTRUE(attr_xi$relax)
    block_type <- attr_xi$categ
    block_order <- attr_xi$order

    relaxes <- c(relaxes, relax)
    if (is.character(xi)) {
      xi <- factor(xi)
    }

    # For factors, record which levels are observed and which are empty.
    levels_old_xi <- NULL
    if (is.factor(xi)) {
      tbl <- table(xi)
      levels_old_xi <- names(tbl[tbl > 0])
      zero_obs_levels <- names(tbl[tbl == 0])
      if (length(zero_obs_levels) > 0) {
        warn_key <- if (is.null(attr_xi$nm)) cnms[i] else attr_xi$nm
        zero_level_warnings[[warn_key]] <- zero_obs_levels
      }
    }
    recode_levels <- is.factor(xi) && is.character(levels_old_xi)

    if (!is.null(shape) && is.numeric(shape)) {
      # Constrained predictor: shape code and name come from its attributes.
      shapes_add <- c(shapes_add, shape)
      if (recode_levels) {
        xi_mat <- recode_ordered_factor_to_numeric(xi, levels_wanted = levels_old_xi)
        design$variables[[attr_xi$nm]] <- xi_mat
        replaced[[i - 1]] <- TRUE
        levels_old[[i - 1]] <- levels_old_xi
      } else if (is.factor(xi)) {
        xi_mat <- as.numeric(xi)
      } else {
        xi_mat <- xi
      }
      # The symbol name passed to cbind() becomes the column label of
      # xmat_add, so the `xi_mat` / `xi` names are kept distinct on purpose.
      xmat_add <- cbind(xmat_add, xi_mat)
      xnms_add <- c(xnms_add, attr_xi$nm)
      # identical() tolerates a missing "categ" attribute, where `==` would
      # fail with "argument is of length zero".
      block.ord.lst[[i - 1]] <- if (identical(block_type, "block.ord")) block_order else -1
      block.ave.lst[[i - 1]] <- if (identical(block_type, "block.ave")) block_order else -1
    } else {
      # No shape info: record shape code 0 under the model-frame column name.
      shapes_add <- c(shapes_add, 0)
      if (recode_levels) {
        xi <- recode_ordered_factor_to_numeric(xi, levels_wanted = levels_old_xi)
        design$variables[[cnms[i]]] <- xi
        replaced[[i - 1]] <- TRUE
        levels_old[[i - 1]] <- levels_old_xi
      } else if (is.factor(xi)) {
        xi <- as.numeric(xi)
      }
      xmat_add <- cbind(xmat_add, xi)
      xnms_add <- c(xnms_add, cnms[i])
      block.ord.lst[[i - 1]] <- -1
      block.ave.lst[[i - 1]] <- -1
    }
  }
  colnames(replaced) <- xnms_add

  if (all(shapes_add == 0)) {
    stop("No monotonic or partial-ordering is defined! use survey package instead!")
  }

  # Arguments are evaluated here and embedded by value in the constructed
  # call before it is evaluated.
  fit <- eval(call("csvy.fit", design = design, family = family,
                   relaxes = relaxes, xm = xmat_add, xnms_add = xnms_add,
                   sh = shapes_add, y = y, ynm = ynm,
                   block.ave.lst = block.ave.lst,
                   block.ord.lst = block.ord.lst,
                   level = level, n.mix = n.mix, multicore = multicore,
                   cl = cl, test = test))

  fit$survey.design <- design0
  fit$data <- design$variables
  fit$na.action <- attr(mf, "na.action")
  fit$call <- cl
  fit$family <- family
  fit$terms <- mt
  fit$xmat_add <- xmat_add
  fit$xnms_add <- xnms_add
  fit$shapes_add <- shapes_add
  fit$ynm <- ynm
  fit$levels_old <- levels_old
  fit$replaced <- replaced
  class(fit) <- c("csvy", "cgam", "svyglm", "glm")

  if (length(zero_level_warnings) > 0) {
    cat("Note: Some factor levels had zero observations and were excluded:\n")
    for (var in names(zero_level_warnings)) {
      cat("-", var, ":", paste(zero_level_warnings[[var]], collapse = ", "), "\n")
    }
  }
  fit
}

############
## csvy.fit#
############
csvy.fit = function(design, family=stats::gaussian(), M=NULL, relaxes=NULL, 
                    xm=NULL, xnms_add=NULL, sh=1, y=NULL, ynm=NULL, nd=NULL, 
                    ID=NULL, block.ave.lst=NULL, block.ord.lst=NULL, 
                    amat=NULL, level=0.95, n.mix=100L, multicore=getOption("csvy.multicore"), 
                    cl=NULL, test=TRUE){
  #wp is population wt
  #Ds: observed domains?
  #make sure each column in xm is numeric:
  # stopifnot(is.data.frame(xm) || is.matrix(xm))
  # Convert to data.frame (if matrix) and ensure all numeric
  xm <- as.data.frame(xm)
  for (j in seq_along(xm)) {
    col <- xm[[j]]
    if (is.factor(col)) {
      xm[[j]] <- as.numeric(col)
    }
  }
  colnames(xm) = xnms_add
  # Number of unique values per column
  Ds <- vapply(xm, function(x) length(unique(x)), integer(1))
  # Unique combinations of rows
  xm.red <- unique(xm)
  #print (xm.red)
  # Sorted unique values per column
  xvs2 <- lapply(xm.red, function(x) sort(unique(x)))
  # Generate full grid of combinations
  grid2 <- expand.grid(xvs2, KEEP.OUT.ATTRS = FALSE, stringsAsFactors = FALSE)
  nD <- n_total_domains <- NROW(grid2)
  
  colnames(grid2) <- xnms_add
  df <- design$variables
  ds <- design
  #sample sizes for each observed domain
  if(is.null(nd)){
    xm.df = as.data.frame(xm)
    nd = aggregate(list(nd = rep(1, nrow(xm.df))), xm.df, length)$nd
  }
  if(is.null(ID)){
    ID = match_rows(xm, grid2)
  }
  
  ID = as.ordered(ID)
  uds = unique(ID)
  domain.ord = order(uds)
  uds = sort(uds)
  #new M:
  obs_cells = as.numeric(levels(uds))[uds]
  Tot_obs = length(obs_cells)
  if (is.null(M)) {
    #stop("nD (total number of domains) is not provided! \n")
    #M = max(obs_cells)
    M = nD
  } else if (M %% 1 != 0 | M <= 1) {
    stop(paste("nD (total number of domains) must be a positive integer > 1!", "\n"))
    #M = max(obs_cells)
  } else if (M < Tot_obs) {
    stop(paste("There are at least", Tot_obs, "domains! nD (total number of domains) should be >= ",
               Tot_obs, sep = " ", "\n"))
  }
  
  #new: to be used in makeamat_2d_block
  M_0 = M
  #if (any(class(ds) == "survey.design")) {
  if (inherits(ds, "svyrep.design")) {
    weights <- ds$pweights
  } else if (inherits(ds, "survey.design")) {
    weights <- 1 / ds$prob
  } else {
    stop("Unsupported design object: must be 'survey.design' or 'svyrep.design'.")
  }
  
  #Nhat and stratawt will follow the order: 1 to M
  #y = df[[ynm]] #not work if y is log(income) but df has income
  #test:
  if(family$family %in% c("binomial", "quasibinomial")){
    uvals = unique(y)
    if(length(uvals) > 2) {
      stop(paste("The response has more than 2 outcomes! family = binomial won't work!", "\n"))
    } else {
      if(is.factor(y)){
        ulvls = levels(y)
        success = ulvls[2] #same as glm; 2nd level is success
        y = ifelse(y == success, 1, 0)
      } 
    }
  }
  n = length(y)
  ys = stratawt = vector("list", length=M)
  #add ybar_vec for cic of non-gaussian cases
  Nhat = nd = ybar_vec = 1:M*0
  uds_num = as.numeric(levels(uds))[uds] 
  #new: for binomial
  y_ch_id = NULL
  
  #check more!
  if (length(Ds) >= 1) {
    #test!
    min_udi <- min(uds_num)
    for(udi in uds_num){
      ps = which(ID %in% udi)
      ysi = y[ps]
      wi = weights[ps]
      Nhi = sum(wi)
      
      idx <- udi - min_udi + 1
      Nhat[idx] = Nhi
      nd[idx] = length(wi)
      ys[[idx]] = ysi
      #new: for binomial
      #print (ysi)
      if (all(ysi == 0) | all(ysi == 1)) {
        y_ch_id = c(y_ch_id, udi)
      }
      stratawt[[idx]] = wi
      ybar_vec[idx] = mean(ysi)
    }
  }
  
  #new: for binomial
  y_ch_id2 = NULL
  log_odds = NULL
  
  if (!is.null(y_ch_id) & family$family %in% c('binomial', 'quasibinomial')){
    n_ch_id = length(y_ch_id)
    y_ch_id2 = y_ch_id - 1
    
    #check!
    if (y_ch_id2[1] == 0) {
      nbor = 1
      while(all(ys[[nbor]] == 1) | all(ys[[nbor]] == 0) | is.null(ys[[nbor]])){
        nbor = nbor + 1
      }
      y_ch_id2[1] = nbor
    }
    
    ch = sapply(y_ch_id2, function(e) all(ys[[e]] == 1) | all(ys[[e]] == 0) | is.null(ys[[e]]))
    
    if (any(ch)) {
      ps = which(ch)
      #for(i in 1:length(ps)) {
      for(psi in ps){
        #psi = ps[i]
        nbor = y_ch_id2[psi]
        if (nbor == 1) {
          while(nbor %in% y_ch_id | is.null(ys[[nbor]])){
            nbor = nbor + 1
          }
        } else {
          while(nbor %in% y_ch_id | is.null(ys[[nbor]])) {
            nbor = nbor - 1
            #test more
            if (nbor == 1) {
              while(nbor %in% y_ch_id | is.null(ys[[nbor]])){
                nbor = nbor + 1
              }
            }
          }
        }
        y_ch_id2[psi] = nbor
      }
    }
    
    ###
    y_merge = rbind(y_ch_id2, y_ch_id)
    n_y_merge = ncol(y_merge)
    for(i in 1:n_y_merge){
      y_merge_i = c(y_merge[, i])
      den_i = 0
      num_i = 0
      for(j in y_merge_i){
        #print (j)
        ys_j = ys[[j]]
        den_i = length(ys_j) + den_i
        num_i = sum(ys_j) + num_i
      }
      p_i = num_i / den_i
      log_odds = c(log_odds, log(p_i / (1 - p_i)))
    }
  }
  
  #new:
  obs_cells = which(nd > 0)
  fo = as.formula(paste0("~", ynm))
  #use fo2 for svyglm
  xnms_add_fac <- sapply(xnms_add, function(xnmi) {
    if (!isTRUE(all.equal(xnmi, make.names(xnmi)))) {
      xnmi <- paste0("`", xnmi, "`")
    }
    paste0("factor(", xnmi, ")")
  })
  xnms_fo_fac = paste(xnms_add_fac, collapse = "*")
  fo2 = formula(paste(ynm, "~", xnms_fo_fac))
  #new:
  fo2_null = formula(paste(ynm, "~", 1))
  xnms_fo = paste(xnms_add, collapse = "+")
  #new:
  if(family$family == "gaussian"){
    #ans.unc = suppressWarnings(svyby(formula=fo, by=~ID, design=ds, FUN=svymean, 
    #                                 covmat=TRUE, drop.empty.groups=FALSE))
    ans.unc <- safe_svyby(formula = fo, by = ~ID, design = ds, FUN = svymean, drop.empty.groups = FALSE)
    v1 = vcov(ans.unc)
    etahatu = yvecu = yvec = ans.unc[[ynm]]
    #vsu = round(diag(v1), 5)
    vsu = diag(v1)
    ans.unc_null = svyglm(formula=fo2_null, design=ds, family=family)
    prior.weights = ans.unc_null$prior.weights
    ans.unc_cp_0 = ans.unc
  } else {
    #fo_by = if (length(xnms_add) == 1) formula(paste("~", xnms_add)) else formula(paste("~", paste(xnms_add, collapse = "+")))
    fo_by <- reformulate(xnms_add)
    # ans.unc_cp_0 = suppressWarnings(svyby(formula = fo, by = fo_by, design = ds, FUN = svymean, keep.var = TRUE,
    #                                      keep.names = TRUE, verbose = FALSE, vartype = "se",
    #                                      drop.empty.groups = FALSE, covmat = TRUE,
    #                                      na.rm.by = FALSE, na.rm.all = FALSE))
    ans.unc_cp_0 <- safe_svyby(formula = fo, by = fo_by, design = ds, family = "binomial")
    ans.unc = suppressWarnings(svyglm(formula=fo2, design=ds, family=family, 
                                      control = list(epsilon = 1e-8, maxit = 10)))
    prior.weights = ans.unc$prior.weights
    #create newdata
    newdata = grid2
    #test more! colnames
    #colnames(newdata) = xnms_add_fac
    empty_cells = which(nd == 0)
    #predict.svyglm will ignore empty cells for 1D?
    if(length(Ds) > 1 & length(empty_cells) > 0) {
      newdata = newdata[-empty_cells, ,drop=FALSE]
    }
    
    #new: more concise
    # newdata <- if (length(Ds) > 1 && length(empty_cells) > 0) {
    #   grid2[-empty_cells, , drop = FALSE]
    # } else {
    #   grid2
    # }
    
    if(length(empty_cells) == 0){
      p.unc = predict(ans.unc, newdata, se.fit = TRUE, type = 'link', vcov = TRUE)
      v1 = vcov(p.unc)
      etahatu = yvecu = yvec = as.data.frame(p.unc)$link
      vsu = round(diag(v1), 5)
    }
    
    if(length(empty_cells) > 0){
      #print (length(empty_cells))
      tt = delete.response(terms(formula(ans.unc)))
      mf = model.frame(tt, data = newdata)
      mm = model.matrix(tt, mf)
      #modified from predict.svrepglm
      mm = mm[, -empty_cells, drop = FALSE]
      v1 = mm %*% vcov(ans.unc) %*% t(mm)
      #etahatu = yvecu = yvec = NULL
      etahatu = drop(mm %*% coef(ans.unc))
      yvecu <- etahatu
      yvec <- etahatu
      vsu = round(diag(v1), 5)
    }
  }
  
  z.mult = qnorm((1 - level)/2, lower.tail=FALSE)
  hlu = z.mult*sqrt(vsu)
  lwru = yvecu - hlu
  uppu = yvecu + hlu
  #new: replace n=1 variance with the average
  #not to change vsu, which will be used for 
  #ps = which(round(diag(v1), 7) == 0L)
  ps = which(diag(v1) == 0L)
  diag(v1)[ps] = mean(diag(v1)[-ps])
  
  w = Nhat/sum(Nhat)
  Nd = length(Ds)
  
  empty_cells = zeros_ps = which(nd == 0)
  ne = zeros = length(zeros_ps)
  #nd_ne = nd[empty_cells]
  nd_ne = rep(0, ne)
  #small_cells = which(nd >= 1 & nd <= 10)
  #nsm = length(small_cells)
  #temp
  #block.ave is not considered now
  #noord_cells will be a list, each elememt may have different length
  noord_cells = NULL
  #if(length(block.ave.lst) > 0 & Nd == 1){
  bool_ave = map_dbl(block.ave.lst, .f = function(.x) all(.x == -1))
  bool_ord = map_dbl(block.ord.lst, .f = function(.x) all(.x == -1))
  if(any(bool_ave == 0)){
    noord_cells = map(block.ave.lst, .f = function(.x) which(.x == 0))
    if(length(Ds) == 1){
      noord_cells = noord_cells[[1]]
    }
  }
  
  if(any(bool_ord == 0)){
    noord_cells = map(block.ord.lst, .f = function(.x) which(.x == 0))
    if(length(Ds) == 1){
      noord_cells = noord_cells[[1]]
    }
  }
  
  pskeep = ps = NULL
  M_0 = M
  w_0 = w
  
  M = M - zeros
  imat = diag(M)
  if (ne >= 1) {
    w = w[-zeros_ps]
  }
  #define weight mat and its inverse here to avoid repeating it
  #need to consider zeros
  wt = diag(w)
  wtinv = diag(1/w)
  hkeep = NULL
  amat_0 = NULL #to be used in Nd > 1
  amat_ci = NULL #to be used for ci
  
  if (is.null(amat)) {
    if (Nd == 1) {
      #sh = 0 means user must define amat 
      if (sh > 0) {
        #Ds is not correct when >= 1 empty cell
        #new: try to find the constrained est. first
        block.ave = NULL
        block.ord = NULL
        if (length(block.ave.lst) > 0) {
          #print ('a')
          block.ave = block.ave.lst[[1]]
        }
        if (length(block.ord.lst) > 0) {
          #print ('a')
          block.ord = block.ord.lst[[1]]
        }
        if (relaxes) {
          #amats = makeamat(1:M, sh=sh, Ds=M, interp=TRUE, relaxes=relaxes)
          if (is.null(h_grid)) {
            n_h = 60 
            h_grid = seq(0.01, 6, length = n_h)
          } else {
            n_h = length(h_grid)
          }
          amats = vector("list", length = n_h)
          for (i in seq_len(n_h)) {
            h = h_grid[i]
            amat = makeamat(1:M, sh=sh, Ds=M, interp=TRUE, relaxes=relaxes, h=h)
            amats[[i]] = amat
          }
          #CIC to choose h
          #find the h_hat besed on CIC criterion for each simulated y
          B1 = 10 #??
          #ysim = MASS::mvrnorm(B1, mu = yvec, Sigma = v1)
          ysim = MASS::mvrnorm(B1, mu = muhat, Sigma = v1)
          index = rep(0, B1)
          #wtinv = diag(1/w)
          for (j in 1:B1) {
            CIC_h = rep(0, n_h)
            #rme = matrix(0, M, n_h)
            for (i in seq_len(n_h)) {
              amati = amats[[i]]
              ansi = coneA(ysim[j, ], amati, w, msg=FALSE)
              #rme[,i] = ansi$thetahat
              facei = ansi$face
              
              if (length(facei) == 0){
                #pmat_is = diag(rep(0, M))
                mat = wt %*% v1
              } else {
                dd = amati[facei, ,drop=FALSE]
                #check solve
                wtinvdd = wtinv %*% t(dd)
                pmat_is = wtinvdd %*% solve(dd %*% wtinvdd) %*% dd
                mat = wt %*% (imat - pmat_is) %*% v1
              }
              evec = ysim[j,] - ansi$thetahat
              CIC_h[i] = t(evec) %*% wt %*% evec + 2*(sum(diag(mat)))
            }
            CIC_h = round(CIC_h, 6)
            index[j] = which.min(CIC_h)[1]
          }
          #ps = max(index)
          #try the 4th largest
          ps = (sort(index, decreasing = TRUE))[4]
          #choose amat_h 
          amat = amats[[ps]]
          hkeep = h_grid[ps]
        } else {
          #amat = makeamat(1:M, sh=sh, Ds=M, interp=TRUE, relaxes=FALSE)
          if(length(zeros_ps) > 0){
            M2 = seq(1, M+zeros)[-zeros_ps]
          }else{
            M2 = 1:M
          }
          amat = makeamat(x=M2, sh=sh, Ds=M, interp=TRUE, relaxes=FALSE, 
                          block.ave=block.ave.lst[[1]], block.ord=block.ord.lst[[1]], 
                          zeros_ps=zeros_ps, noord_cells=noord_cells, xvs2=xvs2, grid2=grid2)
          amat_0 = makeamat(x=1:M_0, sh=sh, Ds=M_0, interp=TRUE, relaxes=FALSE, 
                            block.ave=block.ave.lst[[1]], block.ord=block.ord.lst[[1]], 
                            zeros_ps=NULL, noord_cells=noord_cells, xvs2=xvs2, grid2=grid2)
        }  
      } else if (sh == 0 & is.null(amat)) {
        #print (sh)
        stop ("User must define a constraint matrix!")
      }
    } else if (Nd >= 2) {
      #temp:
      if (any(sh > 0)) {
        #if (!relaxes) {
        #print (zeros_ps)
        #new: create another amat used for forcing monotonicity on ci
        amat_rslt = makeamat_2D(x=NULL, sh=sh, grid2=grid2, xvs2=xvs2,
                                zeros_ps=zeros_ps, noord_cells=noord_cells, 
                                Ds=Ds, interp=TRUE, 
                                block.ave.lst=block.ave.lst,
                                block.ord.lst=block.ord.lst, M_0=M_0)
        amat = amat_rslt$amat
        #amat_ci = amat_rslt$amat_ci #wrong, should include empty domains
        #to be used in impute_2 or impute_3
        amat_0_rslt = makeamat_2D(x=NULL, sh=sh, grid2=grid2, xvs2=xvs2,
                                  zeros_ps=NULL, noord_cells=noord_cells,
                                  Ds=Ds, interp=TRUE,
                                  block.ave.lst=block.ave.lst,
                                  block.ord.lst=block.ord.lst, M_0=M_0)
        amat_0 = amat_0_rslt$amat
        amat_ci = amat_0_rslt$amat_ci
      } else if (all(sh == 0) & is.null(amat)) {
        stop (paste("User must define a constraint matrix!", "\n"))
      }
    }
  }
  
  #new: for binomial, to avoid extreme etahatun
  if(family$family %in% c("binomial", "quasibinomial")){
    if(ne >= 1){
      yvec_all = numeric(M_0)
      yvec_all[-empty_cells] = yvec
    }
    if(ne == 0){
      yvec_all = yvec
    }
    if(!is.null(log_odds)){
      iter = 1
      for(i in seq_len(n_y_merge)){
        y_merge_i = c(y_merge[, i])
        #for(j in 1:length(y_merge_i)){
        for(ps_j in y_merge_i){
          #ps_j = y_merge_i[j]
          yvec_all[ps_j] = log_odds[iter]
        }
        iter = iter + 1
      }
    }
    if(ne >= 1){
      yvec = yvec_all[-empty_cells]
    }
    if(ne == 0){
      yvec = yvec_all
    }
    etahatu = yvecu = yvec
  }
  
  #new: check irreducibility of amat
  if (!is.null(amat)) {
    #new:
    #if (family$family == "gaussian"){
    ans.polar = coneA(yvec, amat, w=w, msg=FALSE)
    etahat = round(ans.polar$thetahat, 10)
    #new
    #muhat = etahat 
    face = ans.polar$face
    facekeep = face #resid df
    #new: to be used in summary and anova
    #edf = 1.5*ans.polar$df
    edf = ans.polar$df #constr df
    if (length(face) == 0){
      mat = wt %*% v1
    } else {
      dd = amat[face, ,drop=FALSE]
      wtinvdd = wtinv %*% t(dd)
      pmat_is = wtinvdd %*% solve(dd %*% wtinvdd) %*% dd
      mat = wt %*% (imat - pmat_is) %*% v1
    }
    #evec = yvec - muhat
    if(family$family == "gaussian"){
      evec = yvec - etahat
      CIC = t(evec) %*% wt %*% evec + 2*(sum(diag(mat)))
      CIC = as.numeric(CIC)
      #new: add CIC.un
      CIC.un = 2*(sum(diag(wt %*% v1)))
      CIC.un = as.numeric(CIC.un)
    }else{
      #check for empty cells
      evec = ybar_vec[which(nd != 0)] - c(family$linkinv(etahat))
      CIC = t(evec) %*% wt %*% evec + 2*(sum(diag(mat)))
      CIC = as.numeric(CIC)
      CIC.un = 2*(sum(diag(wt %*% v1)))
      CIC.un = as.numeric(CIC.un)
    }
  }
  
  acov = v1
  lwr = upp = NULL
  use_parallel = isTRUE(getOption("csurvey.multicore", FALSE))
  #decide how many cores to use
  cores = 1
  if(use_parallel){
    cores = getOption("csurvey.cores", parallel::detectCores(logical = FALSE) - 1)
    #cores = min(8, getOption("csurvey.cores", parallel::detectCores(logical = FALSE) - 1))
    cores = max(1, cores)
  }
  #detect OS platform
  is_windows = .Platform$OS.type == "windows"
  
  if (!is.numeric(n.mix) || length(n.mix) != 1) {
    warning("Argument 'n.mix' must be a single numeric value. Setting n.mix to 10.")
    n.mix <- 10
  } else if (n.mix < 10) {
    warning("There should be at least 10 simulations to compute the mixture covariance. Setting n.mix to 10.")
    n.mix <- 10
  } else if (n.mix %% 1 != 0) {
    warning("Argument 'n.mix' must be an integer. Rounding down to the nearest integer.")
    n.mix <- floor(n.mix)
  }
  
  if(n.mix > 0){
    #get the constrained variance-covariance matrix
    #lwr = upp = NULL
    #acov = v1
    dp = -t(amat)
    dp = apply(dp, 2, function(e) e / (sum(e^2))^(.5))
    #test!
    #dp = Matrix(dp, sparse = TRUE)
    
    m_acc = M 
    sector = NULL
    times = NULL
    df.face = NULL
    iter = 1
    obs = 1:M
    #binomial can use mvrnorm?
    ysims = rbind(MASS::mvrnorm(n.mix, mu=etahat, Sigma=v1))
    #new: 2025
    if(use_parallel && cores > 1){
      #message(sprintf("Running in parallel using %d cores.", cores))
      if(is_windows){
        # Windows: use parLapply with a PSOCK cluster
        cl = parallel::makeCluster(cores)
        on.exit(parallel::stopCluster(cl))  # ensure cleanup
        mixture_rslt = parallel::parLapply(cl, 1:n.mix, compute_mixture, ysims=ysims, amat=amat, w=w)
      } else {
        # Unix/macOS: use mclapply
        mixture_rslt = parallel::mclapply(1:n.mix, compute_mixture, ysims=ysims, amat=amat, w=w)
        #cat('finished mixture', '\n')
      }
    }else{
      #message("Running in serial mode.")
      mixture_rslt = lapply(1:n.mix, compute_mixture, ysims=ysims, amat=amat, w=w)
    }
    
    mixture_rslt <- do.call(rbind, mixture_rslt) |> as.data.frame()
    bsec = aggregate(list(freq = rep(1, nrow(mixture_rslt))), mixture_rslt, length) |>
      mutate(prob = freq / sum(freq)) |>
      dplyr::select(freq, prob, everything())
    
    sector = bsec[, -c(1, 2), drop = FALSE]
    nsec = NROW(sector)
    imat = diag(M)
    if(nsec > 0){
      acov = matrix(0, nrow=M, ncol=M)
      wtinv = diag(1/w)
      for(isec in 1:nsec){
        #print (isec)
        jvec = sector[isec, ]
        if(all(jvec == 0)){
          next 
        } else {
          smat = dp[, which(jvec == 1), drop = FALSE]
          wtinv_sm = wtinv %*% smat
          pmat_is_p = wtinv_sm %*% solve(t(smat) %*% wtinv_sm, t(smat))
          pmat_is = (imat - pmat_is_p)
          p_j = bsec[isec,2]
          acov = acov + p_j * pmat_is%*%v1%*%t(pmat_is)
        }
      }
    } else {
      acov = v1
    }
    z.mult = qnorm((1 - level)/2, lower.tail=FALSE)
    vsc = diag(acov)
    hl = z.mult*sqrt(vsc)
    lwr = etahat - hl
    upp = etahat + hl
  }  else {
    z.mult = qnorm((1 - level)/2, lower.tail=FALSE)
    #z.mult = 2
    vsc = diag(v1)
    hl = z.mult*sqrt(vsc)
    lwr = etahat - hl
    upp = etahat + hl
  }
  
  vscu = diag(v1)
  hlu = z.mult*sqrt(vscu)
  lwru = etahatu - hlu
  uppu = etahatu + hlu
  
  if(ne >= 1){
    #etahatu = yvecu
    #new: if there's empty cells, just augment muhatu to include NA's
    etahatu_all = 1:(M+zeros)*0
    etahatu_all[zeros_ps] = NA
    etahatu_all[obs_cells] = etahatu
    etahatu = etahatu_all
    
    lwru_all = 1:(M+zeros)*0
    lwru_all[zeros_ps] = NA
    lwru_all[obs_cells] = lwru
    lwru = lwru_all
    
    uppu_all = 1:(M+zeros)*0
    uppu_all[zeros_ps] = NA
    uppu_all[obs_cells] = uppu
    uppu = uppu_all
    
    #for monotone only: to find the c.i. for empty/1 cells by using smallest L and largest U of neighbors
    #new: create grid2 again because the routines to do imputation need to have each x start from 1
    #xvs2 = apply(xm.red, 2, function(x) 1:length(unique(x)))
    xvs2 <- lapply(xm.red, function(x) sort(unique(x)))
    grid2 = expand.grid(xvs2, KEEP.OUT.ATTRS = FALSE, stringsAsFactors = FALSE)
    #revise later: replace lwru with lwr etc
    ans_im = impute_em2(empty_cells, obs_cells, M=(M+zeros), etahatu, Nd, nd, Nhat, w, 
                        domain.ord, grid2, Ds, sh, etahat, lwr=lwr, upp=upp, vsc=vsc, amat_0, level=level)
    etahat = ans_im$muhat
    #lwr = ans_im$lwr
    #upp = ans_im$upp
    vsc = ans_im$vsc
    domain.ord = ans_im$domain.ord
  }
  
  #sig1 = vsc
  #sig2 = sig1
  vsc_mix = vsc
  
  hl = z.mult*sqrt(vsc_mix)
  lwr = etahat - hl
  upp = etahat + hl
  
  #new: learn from cgam; nasa's idea
  #for now, only handles one predictor
  #new: don't constrain endpoints to follow monotonicity; make a copy of old lwr and upp
  # lwr_0 = lwr
  # upp_0 = upp
  # nci = length(lwr)
  
  #new: check monotonicity of lwr and upp
  sh_0 = sh
  #relabel sh == 1 and 9?
  if(any(sh == 9)){
    sh_0[which(sh == 9)] = 1
  }
  #remove sh = 0, there is no amat for sh = 0
  Ds_0 = Ds
  if(any(sh == 0)){
    rm_id = which(sh == 0)
    sh_0 = sh_0[-rm_id]
    Ds_0 = Ds[-rm_id]
  }
  
  #print (dim(amat_0))
  #print (length(lwr))
  check_ps_lwr = round(c(amat_0 %*% lwr), 4) 
  check_ps_upp = round(c(amat_0 %*% upp), 4) 
  not_monotone_ci = any(check_ps_lwr < 0) | any(check_ps_upp < 0)
  #check more
  wvec = nd/sum(nd)
  if(any(wvec == 0)){
     #1e-12 fail to have the constraint
     #need to be changed....
     wvec[which(wvec == 0)] = pmin(1e-5, min(wvec[wvec > 0]))
     #wvec[which(wvec == 0)] = 1e-4#1e-12#1e-4
  }
  #wvec[wvec == 0] <- 1e-4  
  
  vsc_mix_0 = vsc_mix
  # if (length(sh_0) == 1 & n.mix > 0 & not_monotone_ci) {
  #   lwr = coneA(lwr, amat_0, w = wvec)$thetahat
  #   upp = coneA(upp, amat_0, w = wvec)$thetahat
  #   vsc_mix_0 = ((upp - lwr) / 4)^2
  # }
  
  #new:
  check_ps_etahat = round(c(amat_0 %*% etahat), 4) 
  if (length(sh_0) > 1 & n.mix > 0 & any(check_ps_etahat < 0)) {
    etahat = coneA(etahat, amat_0, w = wvec)$thetahat
  }
  
  #handles small domains automatically
  if (length(sh_0) >= 1 & n.mix > 0 & not_monotone_ci) {
    lwr = coneA(lwr, amat_0, w = wvec)$thetahat
    upp = coneA(upp, amat_0, w = wvec)$thetahat
    #new:
    #vsc_mix_0 = ((upp - lwr) / 4)^2
    #new:
    vsc_mix_0 = ((upp - lwr) / (2*z.mult))^2
  }
  #new: should let vsc_mix = vsc_mix_0 
  vsc_mix = vsc_mix_0 
  #new: test of H_0: theta in V and H_1: theta in C
  #do lsq fit with Atheta = 0 to get fit under H0
  #use the 1st option to compute T1 and T2
  bval = NULL
  pval = NULL
  if(test){
    m = qr(amat)$rank
    #vmat = qr.Q(qr(t(amat)), complete = TRUE)[, -(1:(qr(t(amat))$rank)), drop = FALSE]
    vec = amat %*% yvecu
    tst = amat %*% v1 %*% t(amat)
    #print (ginv(tst))
    T1_hat = t(vec) %*% MASS::ginv(amat %*% v1 %*% t(amat)) %*% vec
    
    eans = eigen(v1, symmetric=TRUE)
    evecs = eans$vectors; evecs = round(evecs, 8)
    #new: change negative eigenvalues to be a small positive value
    #evals = eans$values
    evals = pmax(eans$values, 1e-12)
    sm = 1e-8 #1e-7
    
    #need test more
    #neg_ps = which(evals < sm)
    #if(length(neg_ps) > 0){
    #  evals[neg_ps] = sm
    #}
    L = evecs %*% diag(sqrt(evals))
    Linv = solve(L) #solve will give singular error message when n is small
    atil = amat %*% L
    Z_s = Linv %*% yvecu
    theta_hat = coneA(Z_s, atil, msg=FALSE)$thetahat
    T2_hat = t(Z_s-theta_hat) %*% (Z_s-theta_hat)
    bval = (T1_hat-T2_hat)/T1_hat
    if(bval > sm){
      nloop = 100
      zerovec = rep(0, M)
      zsims = mvrnorm(nloop, mu = zerovec, Sigma = imat)
      # J_length = rep(0, nloop)
      # for (i in 1:nloop) {
      #   J_length[i] = length(coneA(zsims[i, ], atil, msg=FALSE)$face)
      # }
      #J_length = parallel::mclapply(1:nloop, .compute_bstat, zsims=zsims, atil=atil) |> simplify2array()
      if(use_parallel && cores > 1){
        #message(sprintf("Running in parallel using %d cores.", cores))
        if(is_windows){
          # Windows: use parLapply with a PSOCK cluster
          cl = parallel::makeCluster(cores)
          on.exit(parallel::stopCluster(cl))  # ensure cleanup
          J_length = parallel::parLapply(cl, 1:nloop, compute_bstat, zsims=zsims, atil=atil) |> simplify2array()
        } else {
          # Unix/macOS: use mclapply
          J_length = parallel::mclapply(1:nloop, compute_bstat, zsims=zsims, atil=atil) |> simplify2array()
        }
      }else{
        #message("Running in serial mode.")
        J_length = lapply(1:nloop, compute_bstat, zsims=zsims, atil=atil) |> simplify2array()
      }
      mdist = 0:m*0
      for (i in 1:(m+1)){
        mdist[i] = length(which(J_length == (i-1)))
      }
      mdist = mdist/nloop
      pval = 1 - sum(pbeta(bval, m:0/2, 0:m/2)*mdist)
    } else {pval = 1}
  }
  #set h_grid = NULL for now
  h_grid = NULL
  muhat = c(family$linkinv(etahat))
  muhat.un = c(family$linkinv(etahatu))
  #used to compute deviance: every data point will be used, not just for domains 1,...,M
  muhat_all = 1:n*0 #not include empty cells
  etahat_all = 1:n*0
  #new: for null deviance of gaussian case
  #ybar_all = 1:n*0
  # for (udi in uds){
  #   udi = as.numeric(udi)
  #   ps = which(ID %in% udi)
  #   #print (ps)
  #   #should use muhat, not etahat
  #   #print (muhat[which(uds%in%udi)])
  #   print (which(uds%in%udi))
  #   muhat_all[ps] = muhat[which(uds%in%udi)] #should not use which(uds%in%udi): won't skip empty cells?
  #   etahat_all[ps] = etahat[which(uds%in%udi)]
  #   #ybar_all[ps] = ybar_vec[which(uds%in%udi)]
  # }
  
  #more correct: skip imputed empty cells
  min_udi <- min(uds_num)
  for(udi in uds_num){
    ps <- which(ID %in% udi)
    #print (ps)
    idx <- udi - min_udi + 1
    #print (muhat[idx])
    #print (idx)
    muhat_all[ps] <- muhat[idx]
    etahat_all[ps] <- etahat[idx]
  }
  
  # if (family$family %in% c("binomial", "quasibinomial")){
  #   wt <- prior.weights
  #   # prevent log(0) issues by forcing y and mu into a safe range
  #   eps <- .Machine$double.eps
  #   muhat_all <- pmin(pmax(muhat_all, eps), 1 - eps)
  #   #print (anyNA(wt))
  #   #print (anyNA(y))
  #   #print (anyNA(muhat_all))
  #   #deviance <- 2*sum(wt*(y*log(y/muhat_all) + (1-y)*log((1-y)/(1-muhat_all))))
  #   log_term1 <- ifelse(y == 0, 0, y * log(y / muhat_all))
  #   log_term2 <- ifelse(y == 1, 0, (1 - y) * log((1 - y) / (1 - muhat_all)))
  #   deviance <- 2 * sum(wt * (log_term1 + log_term2))
  # }
  # if (identical(family$family, "poisson")){
  #   wt <- prior.weights
  #   # Replace any y == 0 to avoid NaNs in log(y/mu)
  #   deviance <- 2*sum(wt*ifelse(y == 0, -muhat_all, y*log(y / muhat_all) - (y - muhat_all)))
  # }
  # if(identical(family$family, "gaussian")){
  #   wt <- prior.weights
  #   deviance <- sum(wt*(y-muhat_all)^2)
  # }
  #cat('deviance by formula:', deviance, '\n')
  deviance = sum(family$dev.resids(y, muhat_all, wt = prior.weights))
  #cat('deviance by dev.resids:', sum(family$dev.resids(y, muhat_all, wt = prior.weights)), '\n')  
  #CIC = t(evec) %*% wt %*% evec + 2*(sum(diag(mat)))
  #same as CIC:
  #for gaussian: -2loglike = nlog(t(evec) %*% wt %*% evec)
  #for binomial: yvec is not 0-1, and it doesn't work, then the penalty won't match
  #logLike_CIC = sum(family$dev.resids(yvec, muhat, diag(wt))) + 2*(sum(diag(mat)))
  logLike_CIC = deviance + 2*(sum(diag(mat)))
  ans = new.env()
  #revise later: about domain.ord
  ans$linear.predictors = etahat_all
  ans$etahat = c(etahat)
  ans$linear.predictors.un = ans.unc$linear.predictors
  #ans$etahatu = c(yvecu)
  ans$etahatu = etahatu #include NA's for zero cells
  ans$fitted.values = muhat_all
  ans$muhat = muhat
  ans$fitted.values.un = ans.unc$fitted.values
  ans$muhatu = muhat.un
  ans$cov.un = v1 #unconstrained cov
  #ans$cov.unscaled = vsc_mix #constrained cov with imputation, if fixed with monotonicity, then change it to be vsc_mix_0 
  ans$vsc_mix = vsc_mix #constrained cov with imputation, if fixed with monotonicity, then change it to be vsc_mix_0 
  ans$n.mix = n.mix
  if(n.mix == 0L){
    ans$acov = v1
  } else {
    ans$acov = acov #mixture variance-covariance without imputation
  }
  if(family$family == "gaussian"){
    ans$null.deviance = ans.unc_null$deviance
    #ans$null.deviance = sum(family$dev.resids(y, mean(y), prior.weights))
    #ans$null.deviance = sum(family$dev.resids(y, ybar_all, prior.weights))
  }else{
    ans$null.deviance = ans.unc$null.deviance
  }
  #ans$df.null = n - 1
  ans$df.null = M - 1
  ans$deviance = deviance
  #ans$df.residual = n - edf
  ans$df.residual = length(facekeep) #resid df obs #M - edf
  ans$edf = edf #constr df, obs
  ans$lwr = c(lwr)
  #ans$lwr = c(family$linkinv(lwr))
  ans$upp = c(upp)
  #ans$upp = c(family$linkinv(upp))
  ans$lwru = c(lwru)
  #ans$lwru = c(family$linkinv(lwru))
  ans$uppu = c(uppu)
  #ans$uppu = c(family$linkinv(uppu))
  ans$y = y
  ID = as.numeric(levels(ID))[ID] 
  ans$domain = ID
  ans$grid_ps = pskeep
  ans$amat = amat
  ans$Ds = Ds
  ans$domain.ord = domain.ord
  ans$nd = nd
  ans$grid = grid2
  ans$xvs2 = xvs2
  ans$ec = empty_cells
  ans$hkeep = hkeep
  ans$h_grid = h_grid
  ans$zeros_ps = zeros_ps
  ans$pval = pval
  ans$bval = bval
  ans$CIC = CIC
  ans$CIC.un = CIC.un
  ans$logLike_CIC = logLike_CIC
  #new: inherit attributes if svyby is used; put constrained fits in it
  if (family$family == "gaussian") {
    ans.unc_cp = data.frame(matrix(nrow = M_0, ncol = ncol(ans.unc)))
    colnames(ans.unc_cp) = colnames(ans.unc)
    ans.unc_cp$ID = factor(1:M_0)
    #ans.unc_cp = ans.unc_cp_0
    ans.unc_cp[[ynm]] = as.numeric(etahat)
    ans.unc_cp$se = vsc_mix^(.5) #updated if vsc_mix_0 is used to fix monotonicity
    #ans.unc_cp$se = vsc_mix_0^(.5) #new
    attr(ans.unc_cp, "class") = c("svyby", "data.frame")
    attr(ans.unc_cp, "svyby") = attr(ans.unc, "svyby")
    attr(ans.unc_cp, "var") = acov
  } else {
    ans.unc_cp = data.frame(matrix(nrow = M_0, ncol = ncol(ans.unc_cp_0)))
    colnames(ans.unc_cp) = colnames(ans.unc_cp_0)
    ans.unc_cp = ans.unc_cp_0
    ans.unc_cp[[ynm]] = as.numeric(etahat)
    ans.unc_cp$se = vsc_mix^(.5)
    #ans.unc_cp$se = vsc_mix_0^(.5) #new
    attr(ans.unc_cp, "class") = c("svyby", "data.frame")
    attr(ans.unc_cp, "svyby") = attr(ans.unc_cp_0, "svyby")
    attr(ans.unc_cp, "var") = acov
  }
  ans$M = M
  ans$w = w
  ans$xm.red = xm.red
  ans$ne = ne
  ans$empty_cells = empty_cells
  #ans$small_cells = small_cells
  ans$obs_cells = obs_cells
  ans$zeros = zeros
  ans$Nd = Nd
  ans$Nhat = Nhat
  ans$amat_0 = amat_0
  ans$amat_ci = amat_ci
  #very little difference between weights and prior.weights?
  ans$weights = weights
  ans$prior.weights = prior.weights
  ans$ans.unc = ans.unc
  ans$ans.unc_cp = ans.unc_cp
  ans$offset = ans.unc$offset
  ans$contrasts = ans.unc$contrasts
  ans$n.mix = n.mix
  return (ans)
}


####################################
#inherit from cgam
#apply plotpersp to a csvy.fit object
#####################################
# plotpersp = function(object,...) {
# #x1nm = deparse(substitute(x1))
# #x2nm = deparse(substitute(x2))
#  UseMethod("plotpersp", object)
# }
# Extract the name of the variable an expression is built around.
#
# Starting from `expr`, repeatedly steps into the first argument of a call
# (element 2 of the call object) until a bare symbol is reached, and returns
# that symbol as a character string. For example, `log(sqrt(x) + 1)` yields
# "x".
#
# Args:
#   expr: an unevaluated R expression (symbol or call), e.g. from quote().
# Returns:
#   A length-one character vector holding the symbol name.
# Raises:
#   An error when no symbol can be reached: the expression bottoms out in a
#   constant, or in a call that has no arguments (e.g. `f()`).
get_varname <- function(expr) {
  unsupported <- "Unsupported expression for variable name extraction."
  while (is.call(expr)) {
    # A call of length 1 has no first argument; indexing [[2]] would fail
    # with an unrelated "subscript out of bounds" error.
    if (length(expr) < 2L) {
      stop(unsupported)
    }
    expr <- expr[[2L]]
  }
  if (is.name(expr)) {
    return(as.character(expr))
  }
  stop(unsupported)
}

# get_complete_design <- function(formula, design) {
#   mf <- model.frame(formula, data = design$variables, na.action = na.omit)
#   rows_kept <- rownames(mf)
#   design_new <- subset(design, rownames(design$variables) %in% rows_kept)
#   list(model_frame = mf, design = design_new)
# }

# Restrict a survey design to the rows that are complete for `formula`.
#
# The design's data are taken from `design$variables` when present. Otherwise
# the `data` argument recorded in `design$call` is evaluated in the caller's
# frame (the same lookup get_design_variables() performs); previously the
# unevaluated expression itself was handed to model.frame(), which cannot use
# it. Rows with missing values in any model variable are dropped via
# na.action = na.omit, and the design is subset to the surviving row names.
#
# Args:
#   formula: model formula whose variables define "complete".
#   design:  a design object; must support subset().
# Returns:
#   list(model_frame = <complete-case model frame>, design = <subset design>).
# Raises:
#   An error if no data frame can be located for the design.
get_complete_design <- function(formula, design) {
  # Capture the caller's frame up front so lazy evaluation cannot change it.
  caller_env <- parent.frame()
  dat <- design$variables
  if (is.null(dat)) {
    data_expr <- design$call$data
    if (!is.null(data_expr)) {
      dat <- tryCatch(eval(data_expr, envir = caller_env),
                      error = function(e) NULL)
    }
    if (!is.data.frame(dat)) {
      stop("no idea about how to access variables in the design!")
    }
  }
  mf <- model.frame(formula, data = dat, na.action = na.omit)
  rows_kept <- rownames(mf)
  # NOTE(review): subset() methods for survey designs typically evaluate this
  # condition inside the design's own variables first, so a data column named
  # `dat` or `rows_kept` would shadow these locals -- confirm this is not a
  # concern for the designs passed here.
  design_new <- subset(design, rownames(dat) %in% rows_kept)
  list(model_frame = mf, design = design_new)
}

# Locate the data behind a design object, trying three sources in order:
#   1. the `variables` component stored on the design;
#   2. the `data` argument recorded in the design's call, evaluated in the
#      frame of whoever called this function;
#   3. model.frame(design).
# Returns the first source that works. If the recorded `data` expression
# cannot be evaluated, source 3 is tried; if that errors too, NULL is
# returned.
get_design_variables <- function(design) {
  stored <- design$variables
  if (!is.null(stored)) {
    return(stored)
  }

  data_expr <- design$call$data
  if (!is.null(data_expr)) {
    evaluated <- try(eval(data_expr, envir = parent.frame()), silent = TRUE)
    if (!inherits(evaluated, "try-error")) {
      return(evaluated)
    }
  }

  # Last resort; any failure is turned into NULL.
  tryCatch(model.frame(design), error = function(e) NULL)
}

# Call svyby() asking for the covariance matrix, and retry without it on error.
#
# The first attempt uses covmat = TRUE. If that call signals an error, a
# warning carrying the error message is issued and the call is repeated with
# covmat = FALSE.
#
# Args:
#   formula, by, design: passed to svyby() unchanged.
#   family: "gaussian" or "binomial". For "binomial" the estimator is always
#     svymean (FUN is not consulted) and warnings raised by svyby() are
#     suppressed; for "gaussian" FUN is used and warnings are left alone.
#   FUN: estimator used in the gaussian case.
#   drop.empty.groups, vartype, keep.var, keep.names, verbose, na.rm.by,
#     na.rm.all: forwarded to svyby(). Previously the gaussian branch
#     accepted these arguments but never passed them on, so caller-supplied
#     values were silently ignored there; both branches now forward them.
#   ...: further arguments forwarded to svyby().
# Returns:
#   Whatever svyby() returns.
safe_svyby <- function(formula, by, design, family = c("gaussian", "binomial"),
                       FUN = svymean,
                       drop.empty.groups = FALSE,
                       vartype = "se", keep.var = TRUE, keep.names = TRUE,
                       verbose = FALSE, na.rm.by = FALSE, na.rm.all = FALSE, ...) {
  family <- match.arg(family)
  is_binomial <- identical(family, "binomial")

  # One svyby() attempt with the requested covmat setting. The two calls are
  # spelled out separately so the FUN argument reaches svyby() as the plain
  # symbol it was in each original branch.
  run_svyby <- function(covmat, ...) {
    if (is_binomial) {
      suppressWarnings(
        svyby(formula = formula, by = by, design = design, FUN = svymean,
              keep.var = keep.var, keep.names = keep.names, verbose = verbose,
              vartype = vartype, drop.empty.groups = drop.empty.groups,
              covmat = covmat, na.rm.by = na.rm.by, na.rm.all = na.rm.all, ...)
      )
    } else {
      svyby(formula = formula, by = by, design = design, FUN = FUN,
            keep.var = keep.var, keep.names = keep.names, verbose = verbose,
            vartype = vartype, drop.empty.groups = drop.empty.groups,
            covmat = covmat, na.rm.by = na.rm.by, na.rm.all = na.rm.all, ...)
    }
  }

  tryCatch(
    run_svyby(TRUE, ...),
    error = function(e) {
      warning("svyby() failed with covmat = TRUE: ", conditionMessage(e),
              ". Falling back to covmat = FALSE.")
      run_svyby(FALSE, ...)
    }
  )
}


# For every row of `xm`, find the row number of `grid2` holding the same
# values in the columns the two data frames have in common.
#
# Both inputs must be data frames. The lookup is a data.table join of the
# grid onto `xm` over the shared columns; the grid's original row number is
# recorded before it is keyed, so the returned ids refer to `grid2` as
# supplied. Stops with an installation hint when data.table is unavailable.
match_rows <- function(xm, grid2) {
  stopifnot(is.data.frame(xm), is.data.frame(grid2))
  if (!requireNamespace("data.table", quietly = TRUE)) {
    stop("The 'data.table' package is required. Please install it with install.packages('data.table')")
  }

  obs_dt <- as.data.table(xm)
  lookup_dt <- as.data.table(grid2)

  # Keep only the columns present in both tables, in a common order.
  shared_cols <- intersect(names(obs_dt), names(lookup_dt))
  obs_dt <- obs_dt[, shared_cols, with = FALSE]
  lookup_dt <- lookup_dt[, shared_cols, with = FALSE]

  # Tag each grid row with its position before keying reorders the table.
  lookup_dt[, grid_id := .I]
  data.table::setkeyv(lookup_dt, shared_cols)

  # Join the grid onto the observations and pull out the recorded positions.
  joined <- lookup_dt[obs_dt, on = shared_cols]
  joined[["grid_id"]]
}

# Indicator vector for the `face` reported by coneA() on one simulated draw.
#
# Row `iloop` of `ysims` is passed to coneA() with constraint matrix `amat`
# and weights `w`. The result has one entry per row of `amat`: 1 at the
# positions listed in the fit's `face` component, 0 elsewhere.
#
# Args:
#   iloop: row index into `ysims`.
#   ysims: matrix of simulated vectors, one per row.
#   amat:  constraint matrix handed to coneA().
#   w:     weights handed to coneA().
# Returns:
#   Numeric 0/1 vector of length nrow(amat).
compute_mixture <- function(iloop, ysims, amat, w)
{
  ysim <- ysims[iloop, ]
  face <- coneA(ysim, amat, w = w, msg = FALSE)$face
  # numeric(n) is correctly empty for a zero-row amat, whereas 1:n * 0
  # would produce a length-2 vector there.
  sec <- numeric(nrow(amat))
  if (length(face) > 0) {
    sec[face] <- 1
  }
  return (sec)
}

# Size of the `face` reported by coneA() for one simulated draw.
#
# Takes row `iloop` of `zsims`, passes it to coneA() with constraint matrix
# `atil` (messages off), and returns how many entries the fit's `face`
# component contains.
compute_bstat <- function(iloop, zsims, atil){
  fit <- coneA(zsims[iloop, ], atil, msg=FALSE)
  n_face <- length(fit$face)
  n_face
}

#--------------------------------------------------
#ci can be ci or upp
#need to include block ordering
#not used anymore
#--------------------------------------------------
# fix_monotone_1D = function(ci, sh, nd,...){
#   check_ps_ci = round(diff(ci), 6)
#   bool_ci1 = any(check_ps_ci < 0) & sh == 1
#   bool_ci2 = any(check_ps_ci > 0) & sh == 2
#   not_monotone_ci = ifelse(bool_ci1 | bool_ci2, TRUE, FALSE)
#   nci = length(ci)
#   
#   nrep = 0
#   while(not_monotone_ci & nrep < 20) {
#     nrep = nrep + 1
#     if (sh == 1) {
#       check_id_ci = which(check_ps_ci < 0)
#       n_check_ci = length(check_id_ci)
#       for(k in 1:n_check_ci){
#         i = check_id_ci[k]
#         if (any(which(ci[(i + 1):nci] < ci[i]))){
#           ps_left = i
#           check_ps = which(ci[(i + 1):nci] < ci[i])
#           ps_right = max(((i+1):nci)[check_ps])
#           nd_lr = nd[ps_left:ps_right]
#           if (any(nd_lr > 0)){
#             ws = nd_lr / sum(nd_lr)
#             ci[ps_left:ps_right] = sum(ws * ci[ps_left:ps_right])
#           }   
#         }
#       }
#     }
#     
#     if (sh == 2){
#       check_id_ci = which(check_ps_ci > 0)
#       n_check_ci = length(check_id_ci)
#       for(k in n_check_ci:1){
#         #the domain that violates the constraint is the ith domain
#         i = check_id_ci[k]
#         if (any(which(ci[1:i] < ci[i + 1]))){
#           check_ps = which(ci[1:i] < ci[i + 1])
#           ps_left = min((1:i)[check_ps])
#           ps_right = i + 1
#           nd_lr = nd[ps_left:ps_right]
#           if (any(nd_lr > 0)){
#             ws = nd_lr / sum(nd_lr)
#             ci[ps_left:ps_right] = sum(ws * ci[ps_left:ps_right])
#           }       
#         }
#       }
#     }
#     
#     check_ps_ci = round(diff(ci), 6)
#     bool_ci1 = any(check_ps_ci < 0) & sh == 1
#     bool_ci2 = any(check_ps_ci > 0) & sh == 2
#     not_monotone_ci = ifelse(bool_ci1 | bool_ci2, TRUE, FALSE)
#   }
#   return (ci)
# }

#--------------------------------------------------
# fix_monotone_2D = function(amat_ci, amat, ci, grid, Ds, nd, sh, sh_all, xvs2,...){
#   if(all(sh == 1)) {
#     ci = fix_monotone_2D_0(amat, ci, grid, Ds, nd)$ci
#   }
#   if(all(sh == 2)){
#     ci = fix_monotone_2D_0(-amat, -ci, grid, Ds, nd)$ci
#     ci = -ci
#   }
#   if(any(sh == 1) & any(sh == 2)){
#     nD = length(Ds)
#     decr_ps = which(sh_all == 2) #sh_all include sh = 0, xvs2 include sh = 0
#     
#     xvs2_0 = xvs2
#     for(i in decr_ps){
#       if(is.matrix(xvs2_0)){
#         xvs2_0[,i] = rev(xvs2_0[,i])
#       }else if(is.list(xvs2_0)){
#         xvs2_0[[i]] = rev(xvs2[[i]])
#       }
#     }
#     
#     if(is.matrix(xvs2_0)){
#       grid_rev = expand.grid(as.data.frame(xvs2_0))
#     }else if(is.list(xvs2_0)){
#       grid_rev = expand.grid(xvs2_0)
#     }
#     
#     ps_rev = apply(grid_rev, 1, function(elem) which(apply(grid, 1, function(gi) all(gi == elem))))
#     ci_rev = ci[ps_rev]
#     
#     amat_rev = amat
#     sh_track = NULL
#     for(i in 1:nD){
#       amati = amat_ci[[i]]
#       shi = sh[i]
#       sh_track = c(sh_track, rep(shi, nrow(amati)))
#     }
#     amat_rev[which(sh_track == 2), ] = -amat[which(sh_track == 2), , drop = FALSE]
#     
#     nd_rev = nd[ps_rev]
#     Ds_rev = Ds #?
#     ci_rev = fix_monotone_2D_0(amat_rev, ci_rev, grid_rev, Ds, nd_rev)$ci
#     ci = ci_rev[ps_rev]
#   }
#   return (ci)
# }

#--------------------------------------------------
# fix_monotone_2D_0 = function(amat, ci, grid, Ds, nd,...){
#   sm = 1e-5
#   check_ps_ci = round(c(amat %*% ci), 5) #4th or 5th digit?
#   not_monotone_ci = any(check_ps_ci < 0)
#   nrep = 0
#   upp_limit = 1000 # should be 2^D1 * 2^D2?
#   clusters_merged = list()
#   
#   while(not_monotone_ci & nrep < upp_limit){
#     nrep = nrep + 1
#     check_id_ci = which(check_ps_ci < 0)
#     amat_not_mono = amat[check_id_ci, ,drop = FALSE]
#     grid_ci = cbind(grid, ci, t(amat_not_mono)) #no use, for visual check only
#     #different from 1D case
#     ps_left_cands = which(apply(amat_not_mono, 2, function(e) any(-1 %in% e)))
#     ps_left_cands = sort(ps_left_cands)
#     n_lefts = length(ps_left_cands)
#     
#     rows_lefts = NULL
#     clusters = vector('list', length = n_lefts) #each cluster will be weight-averaged
#     for(k in 1:n_lefts){
#       ps_left_k = ps_left_cands[k]
#       colk = amat_not_mono[, ps_left_k, drop = TRUE]
#       row_ps = which(colk == -1)
#       rows_lefts = c(rows_lefts, row_ps)
#       rowk = amat_not_mono[row_ps, ,drop = FALSE] #need to move along rowk to find 1
#       ps_right_k = apply(rowk, 1, function(es) max(which(es == 1)))
#       clusters[[k]] = c(ps_left_k, ps_right_k)
#     }
#     
#     #make a copy
#     clusters_0 = clusters
#     #merge elements in clusters which have common domains
#     #clusters_merged = list()
#     iter = 1
#     nc = length(clusters)
#     
#     #----------------------------------------------------------------------------------------------------
#     #found here: https://stackoverflow.com/questions/47322126/merging-list-with-common-elements
#     #----------------------------------------------------------------------------------------------------
#     i = rep(1:length(clusters_0), lengths(clusters_0))
#     j = factor(unlist(clusters_0))
#     tab = sparseMatrix(i = i, j = as.integer(j), x = TRUE, dimnames = list(NULL, levels(j)))
#     connects = Matrix::tcrossprod(tab, boolArith = TRUE)
#     group = clusters(graph_from_adjacency_matrix(connects, mode = "undirected"))$membership
#     clusters_merged = tapply(clusters_0, group, function(x) sort(unique(unlist(x))))
#     
#     nc_merged = length(clusters_merged)
#     for(k in 1:nc_merged){
#       ps_lr = clusters_merged[[k]]
#       #print (ps_lr)
#       nd_lr = nd[ps_lr]
#       #print (nd_lr)
#       if (any(nd_lr > 0)){
#         ws = nd_lr / sum(nd_lr)
#         ci[ps_lr] = sum(ws * ci[ps_lr])
#         #temp
#         if(any(nd_lr == 0)){
#           nd_lr[which(nd_lr == 0)] = round(mean(nd_lr[which(nd_lr != 0)]))
#           nd[ps_lr] = nd_lr
#         }
#       }   
#       
#       #temp: difficult case, the merged cluster consists of zero cells only
#       if(all(nd_lr == 0)){
#         ci_left_most = ci[ps_lr[1]]
#         ci[ps_lr] = ci_left_most
#       }
#     }
#     check_ps_ci = round(c(amat %*% ci), 5)
#     not_monotone_ci = any(check_ps_ci < 0)
#   }
#   #print (nrep)
#   #rslt = list(ci = ci, nrep = nrep, clusters_merged = clusters_merged)
#   rslt = list(ci = ci, nrep = nrep)
#   return (rslt)
# }

###################################################
#marginal plot function with some helper functions
###################################################
# Upper-case the first character of each string and lower-case the rest.
#
# Args:
#   x: character vector (other atomic or factor input is converted with
#      as.character() first).
# Returns:
#   Character vector of the same length as `x`. Missing values stay missing;
#   without the explicit reset below, paste0() would turn each NA into the
#   literal string "NANA".
capitalize_first <- function(x) {
  x <- as.character(x)
  out <- paste0(toupper(substr(x, 1, 1)), tolower(substr(x, 2, nchar(x))))
  out[is.na(x)] <- NA_character_
  out
}

# Map each value of `x` to its position in `levels_wanted`.
#
# `x` is turned into an ordered factor whose levels are exactly
# `levels_wanted`, and the factor's integer codes are returned as a numeric
# vector (1 for the first level, 2 for the second, ...). Values of `x` that
# are not among `levels_wanted` come back as NA.
recode_ordered_factor_to_numeric <- function(x, levels_wanted) {
  as.numeric(ordered(x, levels = levels_wanted))
}

# Re-order fitted values, cell counts and confidence bounds of a fit so that
# they follow a grid whose columns are arranged for plotting.
#
# A new grid is built with expand.grid() from the entries of `object$xvs2`
# named by the predictors: (x1nm, x2nm) when the fit has two predictors,
# (x3nm, x1nm, x2nm) when it has three or more. Each row of the new grid is
# then located in the old grid, and the fit's vectors are permuted to match.
#
# Args:
#   object: fit object; must work with family() and supply `xvs2`, `grid`,
#     `etahat`/`etahatu`, `nd`, and optionally `lwr`/`upp`/`lwru`/`uppu`.
#   x1nm, x2nm: names of the two main predictors (names in `object$xvs2`).
#   x3nm: name of the third predictor; required with three or more predictors.
#   grid: old grid to look rows up in; defaults to `object$grid`.
#   keep_id: positions of the fit's vectors to use before re-ordering;
#     defaults to every row of `object$grid`.
#   is_constrained: TRUE uses etahat/lwr/upp, FALSE uses etahatu/lwru/uppu.
#   ci: when FALSE (or when the fit has no `lwr`/`upp`), the bounds are
#     returned as a single NA.
#   type: validated against its choices but not otherwise used here.
# Returns:
#   list(grid, muhat, nd, lwr, upp); muhat and the bounds are passed through
#   the family's inverse link. Rows of the new grid that do not occur in the
#   old grid yield NA.
reorder_outputs <- function(object, x1nm, x2nm, x3nm = NULL, grid=NULL, keep_id=NULL, 
                            is_constrained = TRUE, ci = TRUE, type=c("constrained", "unconstrained", "both")) {
  type <- match.arg(type)
  linkinv <- family(object)$linkinv
  grid_old <- grid %||% object$grid
  ps <- keep_id %||% seq_len(NROW(object$grid))
  xvs2 <- object$xvs2
  nxvs <- length(xvs2)

  if (nxvs == 2) {
    grid_new <- expand.grid(xvs2[[x1nm]], xvs2[[x2nm]])
    colnames(grid_new) <- c(x1nm, x2nm)
  } else if (nxvs >= 3) {
    if (is.null(x3nm)) {
      stop("'x3nm' must be supplied when the fit has three or more predictors.")
    }
    grid_new <- expand.grid(xvs2[[x3nm]], xvs2[[x1nm]], xvs2[[x2nm]])
    colnames(grid_new) <- c(x3nm, x1nm, x2nm)
  } else {
    stop("reorder_outputs() needs at least two predictors in 'object$xvs2'.")
  }

  # Build one text key per row by pasting the columns directly. Going through
  # apply() would first convert the data frame with as.matrix(), which
  # format()s numeric columns of a mixed-type frame to a common width; the
  # padding depends on the values present, so keys from two grids holding
  # different sets of values could fail to match.
  row_key <- function(df) {
    do.call(paste, c(unname(as.list(as.data.frame(df))), sep = "_"))
  }
  index <- match(
    row_key(grid_new),
    row_key(grid_old[, colnames(grid_new), drop = FALSE])
  )

  # Pick the constrained or unconstrained version of a bound and re-order it;
  # a single NA stands in when bounds are unavailable or not requested.
  pick_bound <- function(constrained, unconstrained) {
    linkinv(if (is_constrained) constrained[ps] else unconstrained[ps])[index]
  }
  lwr <- if (is.null(object$lwr) || !ci) {
    NA_real_
  } else {
    pick_bound(object$lwr, object$lwru)
  }
  upp <- if (is.null(object$upp) || !ci) {
    NA_real_
  } else {
    pick_bound(object$upp, object$uppu)
  }

  list(
    grid = grid_new,
    muhat = linkinv(if (is_constrained) object$etahat[ps] else object$etahatu[ps])[index],
    nd = (object$nd[ps])[index],
    lwr = lwr,
    upp = upp
  )
}

#########################################
#new plotpersp
#with a control function
#########################################
# Build the control list used by plotpersp.csvy().
#
# The list starts from cgam::plotpersp_control() and is then extended with
# the csvy-specific settings (surface, ci, cex, categnm, type, cex.main, box,
# axes, nticks, palette, NCOL, transpose).
#
# The arguments shared with the cgam control (x1nm, x2nm, categ, col, random,
# ngrid, xlim, ylim, zlim, xlab, ylab, zlab, th, ltheta, main, sub, ticktype)
# used to be accepted but never stored, so values supplied by the caller had
# no effect. They are now written into the list whenever the caller supplies
# them explicitly; arguments left at their defaults keep whatever
# cgam::plotpersp_control() provides, as before.
#
# Returns:
#   A named list of plotting controls. As before, csvy-specific settings
#   whose value is NULL are not stored as list entries (modifyList() drops
#   them), so reading them from the result yields NULL.
plotpersp_csvy_control <- function(surface = c("C", "U"), x1nm = NULL, x2nm = NULL, categ = NULL,
                                   col = NULL, random = FALSE, ngrid = 12,
                                   xlim = NULL, ylim = NULL, zlim = NULL,
                                   xlab = NULL, ylab = NULL, zlab = NULL,
                                   th = NULL, ltheta = NULL, main = NULL,
                                   sub = NULL, ticktype = "simple",
                                   # New csvy-specific controls:
                                   ci = c("none", "lwr", "up", "both"), 
                                   cex = 1, categnm = NULL, type = c("response", "link"),
                                   cex.main = 0.8, box = TRUE, axes = TRUE,
                                   nticks = 5, palette = NULL, NCOL = NULL, transpose = FALSE) {
  # Names of the arguments the caller actually wrote in the call.
  supplied <- names(match.call())[-1L]
  shared_names <- c("x1nm", "x2nm", "categ", "col", "random", "ngrid",
                    "xlim", "ylim", "zlim", "xlab", "ylab", "zlab",
                    "th", "ltheta", "main", "sub", "ticktype")

  # Step 1: Base control list from plotpersp_control()
  base_control <- cgam::plotpersp_control()

  # Step 2: Overlay the shared settings the caller supplied. Assigning through
  # `[<-` with a one-element list stores the value even when it is NULL.
  for (nm in intersect(shared_names, supplied)) {
    base_control[nm] <- list(get(nm, envir = environment(), inherits = FALSE))
  }

  # Step 3: Add extra csvy-specific fields
  csvy_control <- list(surface = match.arg(surface), ci = match.arg(ci), cex = cex, categnm = categnm, type = match.arg(type), 
                       cex.main = cex.main, box = box, axes = axes, nticks = nticks, palette = palette, NCOL=NCOL, transpose=transpose)

  # Step 4: Merge base and extra
  lst <- modifyList(base_control, csvy_control)
  return (lst)
}

#plotpersp.csvy = function(object, x1=NULL, x2=NULL, x1nm = NULL, x2nm = NULL, data = NULL, 
#                          ci = c("none", "lwr", "up", "both"),
#                          transpose = FALSE, main = NULL, categ = NULL, categnm = NULL, surface = c("C", "U"),
#                          type = c("link", "response"), col = "white", cex.main = 0.8, xlab = NULL, ylab = NULL,
#                          zlab = NULL, zlim = NULL, box = TRUE, axes = TRUE, th = NULL, ltheta = NULL, 
#                          ticktype = "detailed",
#                          nticks = 5, palette = NULL, NCOL = NULL,...) {
plotpersp.csvy = function(object, x1 = NULL, x2 = NULL, control = plotpersp_csvy_control(),...) {
  if (!inherits(object, "csvy")) {
    warning("calling plotpersp(<fake-csvy-object>) ...")
  }
  defaults <- plotpersp_csvy_control()
  control <- modifyList(defaults, control)
  x1nm   <- control$x1nm
  x2nm   <- control$x2nm
  surface <- control$surface
  categ  <- control$categ
  col    <- control$col
  random <- control$random
  ngrid  <- control$ngrid
  xlim   <- control$xlim
  ylim   <- control$ylim
  zlim   <- control$zlim
  xlab   <- control$xlab
  ylab   <- control$ylab
  zlab   <- control$zlab
  th     <- control$th
  ltheta <- control$ltheta
  main   <- control$main
  ticktype <- control$ticktype
  sub <- control$sub

  ci <- control$ci
  cex <- control$cex
  categnm <- control$categnm
  type <- control$type
  cex.main <- control$cex.main
  box <- control$box
  axes <- control$axes
  nticks <- control$nticks
  palette <- control$palette
  NCOL <- control$NCOL
  transpose <- control$transpose

  x1nm <- x1nm %||% if (!is.null(x1)) { 
    if (is.character(x1)) x1 else deparse(substitute(x1)) 
  }
  x2nm <- x2nm %||% if (!is.null(x2)) { 
    if (is.character(x2)) x2 else deparse(substitute(x2)) 
  }
  
  xnms = object$xnms_add
  shp = object$shapes_add
  Ds = object$Ds
  Nd = length(Ds)
  xmat = object$xmat_add
  bool = apply(xmat, 2, function(x) is.numeric(x))
  if(any(!bool)){
    xmat = apply(xmat, 2, function(x) as.numeric(x))
  }
  ynm = object$ynm
  family = object$family
  etahat = object$etahat
  lwr = object$lwr
  upp = object$upp
  grid = object$grid
  data = object$data
  main = ifelse(is.null(main), "Constrained Fit", main)
  
  switch(surface, U = {
    etahat = object$etahatu
    lwr = object$lwru
    upp = object$uppu
    main = "Unconstrained Fit"
  }, C = )
  
  switch(type, response = {
    etahat = family$linkinv(etahat)
    lwr = family$linkinv(lwr)
    upp = family$linkinv(upp)
  }, link = )
  
  obs = 1:Nd
  if (x1nm == "NULL" | x2nm == "NULL") {
    if (length(xnms) >= 2) {
      if (is.null(categ)) {
        x1nm = xnms[1]
        x2nm = xnms[2]
        x1_id = 1
        x2_id = 2
      } else {
        if (!is.character(categ)) {
          stop("categ must be a character argument!")
        } else if (any(grepl(categ, xnms))) {
          id = which(grepl(categ, xnms))
          x12nms = xnms[-id]
          x12ids = obs[-id]
          x1nm = x12nms[1]
          x2nm = x12nms[2]
          x1_id = x12ids[1]
          x2_id = x12ids[2]
        }  
      }
    } else {stop ("Number of non-parametric predictors must >= 2!")}
  } else {
    #new: use the data as the environment for x1 and x2
    x1 = data[[x1nm]]; x2 = data[[x2nm]]
    if (all(xnms != x1nm)) {
      if (length(x1) != nrow(xmat)) {
        stop(cat("Number of observations in the data set is not the same as the number of elements in x1!"), "\n")
      }
      bool = apply(xmat, 2, function(x) all(x1 == x))
      if (any(bool)) {
        x1_id = obs[bool]
        #change x1nm to be the one in formula
        x1nm = xnms[bool]
      } else {
        stop (cat(paste("'", x1nm, "'", sep = ''), "is not a predictor defined in the model!"), "\n")
      }
    } else {
      x1_id = obs[xnms == x1nm]
    }
    
    if (all(xnms != x2nm)) {
      if (length(x2) != nrow(xmat)) {
        stop (cat("Number of observations in the data set is not the same as the number of elements in x2!"), "\n")
      }
      bool = apply(xmat, 2, function(x) all(x2 == x))
      if (any(bool)) {
        x2_id = obs[bool]
        x2nm = xnms[bool]
      } else {
        stop (cat(paste("'", x2nm, "'", sep = ''), "is not a predictor defined in the model!", "\n"))
      }
    } else {
      x2_id = obs[xnms == x2nm]
    }
  }
  #xmat keeps the order of x's in the original data frame, but x1 and x2 may change depending on x1_id and x2_id
  
  #x12_id = c(2, 3)
  x12_id = c(x1_id, x2_id)
  # x1_id = x12_id[1]
  # x2_id = x12_id[2]
  if (Nd > 2) {
    x3_ids = obs[-x12_id] 
  }
  shp_12 = shp[x12_id]
  
  x1nm = xnms[x1_id]
  x2nm = xnms[x2_id]
  if (Nd > 2) {
    x3nms = xnms[x3_ids]
  }
  
  x1 = xmat[, x1_id]
  x2 = xmat[, x2_id]
  if (Nd > 2) {
    x3s = xmat[, x3_ids, drop = FALSE] #need change
    x3_vals = apply(x3s, 2, median)
  }
  
  if (Nd > 2) {
    if(is.null(categ)) {
      ps = 1:nrow(grid) #track rows with the median values of columns that are not x1 or x2
      nx3 = length(x3_vals)
      for(k in 1:nx3){
        x3_val_k = x3_vals[k]
        ps_k = which(grid[, x3_ids[k]] %in% x3_val_k)
        ps = intersect(ps, ps_k)
      }
      surf.etahat = matrix(etahat[ps], Ds[x1_id], Ds[x2_id])
      surf.lwr = matrix(lwr[ps], Ds[x1_id], Ds[x2_id])
      surf.upp = matrix(upp[ps], Ds[x1_id], Ds[x2_id])
    } else {
      if (!is.character(categ)) {
        stop("categ must be a character argument!")
      } else if (any(grepl(categ, x3nms))) {
        x3_id = which(xnms %in% categ)
        x3nm = xnms[x3_id]
        if (x3nm %in% c(x1nm, x2nm)) {
          stop("categ must be different than x1 and x2!")
        }
        x3p = 1:Ds[x3_id]
        nsurf = length(x3p)
        surf.etahat = vector('list', nsurf)
        surf.lwr = vector('list', nsurf)
        surf.upp = vector('list', nsurf)
        
        #test more
        ps = 1:nrow(grid) #track rows with the median values of columns that are not x1 or x2
        x3_others_id = which(!x3nms %in% x3nm)
        x3_vals_others = x3_vals[x3_others_id]
        nx3 = length(x3_others_id)
        if(nx3 >= 1){#used when there are more than 3 x's
          for(k in 1:nx3){
            x3_other_val_k = x3_vals_others[k]
            colk = x3_ids[x3_others_id[k]]
            ps_k = which(grid[, colk] %in% x3_other_val_k)
            ps = intersect(ps, ps_k)
          }
        }
        
        for(k in 1:nsurf){
          x3_val_k = x3p[k]
          ps_k = intersect(ps, which(grid[, x3_id] %in% x3_val_k))
          surf.etahat[[k]] = matrix(etahat[ps_k], Ds[x1_id], Ds[x2_id])
          surf.lwr[[k]] = matrix(lwr[ps_k], Ds[x1_id], Ds[x2_id])
          surf.upp[[k]] = matrix(upp[ps_k], Ds[x1_id], Ds[x2_id])
        }
      } else {
        stop(cat(categ, "is not an exact character name defined in the csurvey fit!", "\n"))
      }
      
      if(is.null(palette)){
        palette = c("peachpuff", "lightblue", "limegreen", "grey", "wheat", "yellowgreen", 
                    "seagreen1", "palegreen", "azure", "whitesmoke")
      }
      if (nsurf > 1 && nsurf < 11) {
        col = palette[1:nsurf]
      } else {
        #new: use rainbow
        col = topo.colors(nsurf)
      }
      width = nsurf^.5
      wd = round(width)
      if (width > wd) {
        wd = wd + 1
      }  
      if ((wd^2 - nsurf) >= wd ) {
        fm = c(wd, wd - 1)
      } else {
        fm = rep(wd, 2)
      }
      if (wd > 3) {
        cex = .7
        cex.main = .8
      }
      if (transpose) {
        fm = rev(fm)
      }
    }
  } else if (Nd == 2) {
    surf.etahat = matrix(etahat, Ds[x1_id], Ds[x2_id])
    surf.lwr = matrix(lwr, Ds[x1_id], Ds[x2_id])
    surf.upp = matrix(upp, Ds[x1_id], Ds[x2_id])
  } else if (Nd == 1) {
    stop('persp plot only works when there are >= 2 predictors!')
  }
  
  t_col = function(color, percent = 80, name = NULL) {
    rgb.val = col2rgb(color)
    ## Make new color using input color as base and alpha set by trobjectparency
    t.col = rgb(rgb.val[1], rgb.val[2], rgb.val[3],
                maxColorValue = 255,
                alpha = (100-percent)*255/100,names = name)
    ## Save the color
    invisible(t.col)
  }
  col.upp = t_col("green", perc = 90, name = "lt.green")
  col.lwr = t_col("pink", perc = 80, name = "lt.pink")
  
  if (is.null(xlab)) {
    #xlab = deparse(x1nm)
    xlab = x1nm
  }
  if (is.null(ylab)) {
    #ylab = deparse(x2nm)
    ylab = x2nm
  }
  if (is.null(zlab)) {
    if("response" %in% type){
      zlab = paste("Est mean of", ynm)
    } 
    if("link" %in% type){
      zlab = paste("Est systematic component of", ynm)
    }
  }
  
  ang = NULL
  if (is.null(th)) {
    if (all(shp_12 == 1)) {
      ang = -40
    } else if (all(shp_12 == 2)){
      ang = 40
    } else if (shp_12[1] == 1 & shp_12[2] == 2) {
      ang = -50
    } else if (shp_12[1] == 2 & shp_12[2] == 1) {
      ang = -230
    } else {
      ang = -40
    }
  } else {
    ang = th
  }
  
  x1p = 1:Ds[x1_id]
  x2p = 1:Ds[x2_id]
  
  if (is.list(surf.upp)) {
    zlim0 = vector('list', length = nsurf)
    for(i in 1:nsurf){
      if (ci != "none") {
        rg = sapply(surf.upp, max) - sapply(surf.lwr, min)
        zlim0[[i]] = c(min(surf.lwr[[i]]) - rg[i]/8, max(surf.upp[[i]]) + rg[i]/8)
      } else {
        rg = sapply(surf.etahat, max) - sapply(surf.etahat, min)
        zlim0[[i]] = c(min(surf.etahat[[i]])  - rg[i]/10, max(surf.etahat[[i]]) + rg[i]/10)
      }     
    }
    #zlim0 = c(min(surf.lwr[[1]]) - rg/5, max(surf.upp[[1]]) + rg/5)
    #print (zlim0)
  } else {
    rg = max(surf.upp) - min(surf.lwr)
    zlim0 = c(min(surf.lwr) - rg/5, max(surf.upp) + rg/5)
  }
  
  if (is.list(surf.etahat)) {#this means categ is not null
    if(is.null(NCOL)){
      old.par = par(mfrow = c(fm[1], fm[2]))
    } else {
      nr = fm[1]*fm[2]/NCOL
      old.par = par(mfrow = c(nr, NCOL))
    }
    on.exit(par(old.par))
    par(mar = c(4, 1, 1, 1))
    par(cex.main = cex.main)
    for(i in 1:nsurf){
      if (is.null(categnm)) {
        persp(x1p, x2p, surf.etahat[[i]], zlim = zlim0[[i]], col=col[i], xlab = xlab,
              ylab = ylab, zlab = zlab, main = paste0(x3nm, " = ", x3p[i]), 
              theta = ang, ltheta = -135, ticktype = ticktype, box=box, axes=axes, nticks=nticks)
        if (ci == "lwr" | ci == "both") {
          par(new = TRUE)
          persp(x1p, x2p, surf.lwr[[i]], zlim = zlim0[[i]], col=col.lwr, theta = ang, box=FALSE, axes=FALSE)
        }
        if (ci == "up" | ci == "both") {
          par(new = TRUE)
          persp(x1p, x2p, surf.upp[[i]], zlim = zlim0[[i]], col=col.upp, theta = ang, box=FALSE, axes=FALSE)
        }
      } else {
        if (length(categnm) == nsurf) {
          persp(x1p, x2p, surf.etahat[[i]], zlim = zlim0[[i]], col=col[i], xlab = xlab,
                ylab = ylab, zlab = zlab, main = categnm[i], 
                theta = ang, ltheta = -135, ticktype = ticktype, box=box, axes=axes, nticks=nticks)
          if (ci == "lwr" | ci == "both") {
            par(new = TRUE)
            persp(x1p, x2p, surf.lwr[[i]], zlim = zlim0[[i]], col=col.lwr, theta = ang, box=FALSE, axes=FALSE)
          }
          if (ci == "up" | ci == "both") {
            par(new = TRUE)
            persp(x1p, x2p, surf.upp[[i]], zlim = zlim0[[i]], col=col.upp, theta = ang, box=FALSE, axes=FALSE)
          }
        } else if (length(categnm) == 1) {
          persp(x1p, x2p, surf.etahat[[i]], zlim = zlim0[[i]], col=col[i], xlab = xlab,
                ylab = ylab, zlab = zlab, main = paste0(categnm, " = ", x3p[i]), 
                theta = ang, ltheta = -135, ticktype = ticktype, box=box, axes=axes, nticks=nticks)
          if (ci == "lwr" | ci == "both") {
            par(new = TRUE)
            persp(x1p, x2p, surf.lwr[[i]], zlim = zlim0[[i]], col=col.lwr, theta = ang, box=FALSE, axes=FALSE)
          }
          if (ci == "up" | ci == "both") {
            par(new = TRUE)
            persp(x1p, x2p, surf.upp[[i]], zlim = zlim0[[i]], col=col.upp, theta = ang, box=FALSE, axes=FALSE)
          }
        }
      }
    }
  } else {
    #old.par = ifelse(ci == "both", par(mfrow=c(1, 2)),  par(mfrow=c(1, 1)))
    if (ci == "both") {old.par = par(mfrow=c(1, 2))} else {old.par = par(mfrow=c(1, 1))}
    persp(x1p, x2p, surf.etahat, zlim = zlim0, 
          col=col, xlab = xlab, ylab = ylab, zlab = zlab, main = main,
          theta = ang, ltheta = -135, ticktype = ticktype, box=box, axes=axes, nticks=nticks)
    if (ci == "up") {
      par(new = TRUE)
      persp(x1p, x2p, surf.upp, zlim = zlim0, col=col.upp, theta = ang, box=FALSE, axes=FALSE)
    }
    #par(new = FALSE)
    
    if (ci == "lwr") {
      par(new = TRUE)
      persp(x1p, x2p, surf.lwr, zlim = zlim0, col=col.lwr, theta = ang, box=FALSE, axes=FALSE)
    }
    
    if (ci == "both") {
      par(new = TRUE)
      persp(x1p, x2p, surf.upp, zlim = zlim0, col=col.upp, theta = ang, box=FALSE, axes=FALSE)
      par(new=FALSE)
      
      persp(x1p, x2p, surf.etahat, zlim = zlim0, 
            col=col, xlab = xlab, ylab = ylab, zlab = zlab, main = main,
            theta = ang, ltheta = -135, ticktype = ticktype, box=box, axes=axes, nticks=nticks)
      par(new = TRUE)
      persp(x1p, x2p, surf.lwr, zlim = zlim0, col=col.lwr, theta = ang, box=FALSE, axes=FALSE)
    }
    
    par(new=FALSE)
    on.exit(par(old.par))
  }
  rslt = list(surf.etahat = surf.etahat, surf.upp = surf.upp, surf.lwr = surf.lwr, zlim = zlim0, xlab = xlab, ylab = ylab, zlab = zlab, theta = ang, 
              ltheta = ltheta, col = col, cex.axis = .75, main = main, ticktype = ticktype)
  invisible(rslt)
}

#########################################
#subroutines for confidence interval
#########################################
# Convert a vector of binary digits (most significant digit first) into the
# number it represents, e.g. c(1, 0, 1) -> 5.
#
# x: numeric vector of 0/1 digits; x[1] is the most significant digit.
# Returns a single number; an empty vector yields 0.
makebin <- function(x) {
  k <- length(x)
  r <- 0
  # seq_len() instead of 1:k so that an empty x performs no iterations
  # (1:0 would iterate over c(1, 0) and produce numeric(0)).
  for (i in seq_len(k)) {
    # the i-th digit counted from the right carries weight 2^(i - 1)
    r <- r + x[k - i + 1] * 2^(i - 1)
  }
  r
}

# Binary digits of `num`, most significant digit first
# (e.g. 5 -> c(1, 0, 1); 0 -> 0). Intended for non-negative whole numbers.
#
# num: the number to expand.
# Returns a numeric vector of 0/1 digits with at least one element.
getvars <- function(num) {
  # number of digits: the smallest n_bits >= 1 such that num < 2^n_bits
  n_bits <- 1
  while (num >= 2^n_bits) {
    n_bits <- n_bits + 1
  }

  bits <- numeric(n_bits)
  # the leading digit is set for every positive number
  if (num > 0) {
    bits[1] <- 1
  }

  # peel off the remaining powers of two, from the largest downwards
  remainder <- num - 2^(n_bits - 1)
  for (p in (n_bits - 2):0) {
    if (remainder >= 2^p) {
      remainder <- remainder - 2^p
      bits[n_bits - p] <- 1
    }
  }
  bits
}

# Binary digits of (num - 1), left-padded with zeros to length `capl`.
#
# num:  1-based index; the digits of num - 1 are produced by getvars().
# capl: length of the returned vector.
# Returns a numeric 0/1 vector of length capl whose trailing entries hold
# the digits of num - 1.
getbin <- function(num, capl) {
  bits <- getvars(num - 1)
  padded <- (1:capl) * 0
  # write the digits into the right-most length(bits) slots
  first_slot <- capl - length(bits) + 1
  replace(padded, first_slot:capl, bits)
}

##############################################
#eight shape functions
#add constr for the user-defined amat case
##############################################
# constr = function(x, numknots = 0, knots = 0, space = "E")
# {
#   cl = match.call()
#   pars = match.call()[-1]
#   attr(x, "nm") = deparse(pars$x)
#   attr(x, "shape") = 0
#   attr(x, "numknots") = numknots
#   attr(x, "knots") = knots
#   attr(x, "space") = space
#   attr(x, "categ") = "additive"
#   #class(x) = "additive"
#   x
# }

# incr = function(x, numknots = 0, knots = 0, space = "E")
# {
#   cl = match.call()
#   pars = match.call()[-1]
#   attr(x, "nm") = deparse(pars$x)
#   attr(x, "shape") = 1
#   attr(x, "numknots") = numknots
#   attr(x, "knots") = knots
#   attr(x, "space") = space
#   attr(x, "categ") = "additive"
#   #class(x) = "additive"
#   x
# }

#relaxed increasing
# relax.incr = function(x, numknots = 0, knots = 0, space = "E")
# {
#   cl = match.call()
#   pars = match.call()[-1]
#   attr(x, "nm") = deparse(pars$x)
#   attr(x, "shape") = 1
#   attr(x, "numknots") = numknots
#   attr(x, "knots") = knots
#   attr(x, "space") = space
#   attr(x, "categ") = "additive"
#   attr(x, "relax") = TRUE
#   #class(x) = "additive"
#   x
# }

# decr = function(x, numknots = 0, knots = 0, space = "E")
# {
#   cl = match.call()
#   pars = match.call()[-1]
#   attr(x, "nm") = deparse(pars$x)
#   attr(x, "shape") = 2
#   attr(x, "numknots") = numknots
#   attr(x, "knots") = knots
#   attr(x, "space") = space
#   attr(x, "categ") = "additive"
#   #class(x) = "additive"
#   x
# }

#relaxed decreasing
# relax.decr = function(x, numknots = 0, knots = 0, space = "E")
# {
#   cl = match.call()
#   pars = match.call()[-1]
#   attr(x, "nm") = deparse(pars$x)
#   attr(x, "shape") = 2
#   attr(x, "numknots") = numknots
#   attr(x, "knots") = knots
#   attr(x, "space") = space
#   attr(x, "categ") = "additive"
#   attr(x, "relax") = TRUE
#   #class(x) = "additive"
#   x
# }

# conv = function(x, numknots = 0, knots = 0, space = "E")
# {
#   cl = match.call()
#   pars = match.call()[-1]
#   attr(x, "nm") = deparse(pars$x)
#   attr(x, "shape") = 3
#   attr(x, "numknots") = numknots
#   attr(x, "knots") = knots
#   attr(x, "space") = space
#   attr(x, "categ") = "additive"
#   #class(x) = "additive"
#   x
# }

# conc = function(x, numknots = 0, knots = 0, space = "E")
# {
#   cl = match.call()
#   pars = match.call()[-1]
#   attr(x, "nm") = deparse(pars$x)
#   attr(x, "shape") = 4
#   attr(x, "numknots") = numknots
#   attr(x, "knots") = knots
#   attr(x, "space") = space
#   attr(x, "categ") = "additive"
#   #class(x) = "additive"
#   x
# }

# incr.conv = function(x, numknots = 0, knots = 0, space = "E")
# {
#   cl = match.call()
#   pars = match.call()[-1]
#   attr(x, "nm") = deparse(pars$x)
#   attr(x, "shape") = 5
#   attr(x, "numknots") = numknots
#   attr(x, "knots") = knots
#   attr(x, "space") = space
#   attr(x, "categ") = "additive"
#   #class(x) = "additive"
#   x
# }

# decr.conv = function(x, numknots = 0, knots = 0, space = "E")
# {
#   cl = match.call()
#   pars = match.call()[-1]
#   attr(x, "nm") = deparse(pars$x)
#   attr(x, "shape") = 6
#   attr(x, "numknots") = numknots
#   attr(x, "knots") = knots
#   attr(x, "space") = space
#   attr(x, "categ") = "additive"
#   #class(x) = "additive"
#   x
# }

# incr.conc = function(x, numknots = 0, knots = 0, space = "E")
# {
#   cl = match.call()
#   pars = match.call()[-1]
#   attr(x, "nm") = deparse(pars$x)
#   attr(x, "shape") = 7
#   attr(x, "numknots") = numknots
#   attr(x, "knots") = knots
#   attr(x, "space") = space
#   attr(x, "categ") = "additive"
#   #class(x) = "additive"
#   x
# }

# decr.conc = function(x, numknots = 0, knots = 0, space = "E")
# {
#   cl = match.call()
#   pars = match.call()[-1]
#   attr(x, "nm") = deparse(pars$x)
#   attr(x, "shape") = 8
#   attr(x, "numknots") = numknots
#   attr(x, "knots") = knots
#   attr(x, "space") = space
#   attr(x, "categ") = "additive"
#   #class(x) = "additive"
#   x
# }

# incr <- function (x, numknots = 0, knots = 0, space = "E") 
# {
#   cl <- match.call()
#   pars <- match.call()[-1]
#   #attr(x, "nm") <- deparse(pars$x)
#   attr(x, "nm") <- get_varname(pars$x)
#   #attr(x, "nm") <- all.vars(pars$x)[1]
#   #attr(x, "nm") <- rlang::as_name(pars$x)
#   attr(x, "shape") <- 1
#   attr(x, "numknots") <- numknots
#   attr(x, "knots") <- knots
#   attr(x, "space") <- space
#   attr(x, "categ") <- "additive"
#   x
# }

# 
# decr <- function (x, numknots = 0, knots = 0, space = "E") 
# {
#   cl <- match.call()
#   pars <- match.call()[-1]
#   #attr(x, "nm") <- deparse(pars$x)
#   attr(x, "nm") <- get_varname(pars$x)
#   #attr(x, "nm") <- all.vars(pars$x)[1]
#   #attr(x, "nm") <- rlang::as_name(pars$x)
#   attr(x, "shape") <- 2
#   attr(x, "numknots") <- numknots
#   attr(x, "knots") <- knots
#   attr(x, "space") <- space
#   attr(x, "categ") <- "additive"
#   x
# }

# Mark a predictor as a block-ordering term by attaching the attributes
# used to describe a shape-restricted term.
#
# x:        the predictor values; returned with attributes added.
# order:    stored as attribute "order" (NULL leaves the attribute unset).
# numknots, knots, space: stored unchanged as attributes of the same name.
#
# Attributes set: "nm" (variable name derived from the expression supplied
# for x via get_varname()), "shape" = 9, "numknots", "knots", "space",
# "categ" = "block.ord", and "order".
block.Ord = function(x, order = NULL, numknots = 0, knots = 0, space = "E")
{
  supplied <- match.call()
  term_attrs <- list(
    nm = get_varname(supplied$x),
    shape = 9,
    numknots = numknots,
    knots = knots,
    space = space,
    categ = "block.ord",
    order = order
  )
  # assign one attribute at a time, in the order listed above
  for (key in names(term_attrs)) {
    attr(x, key) <- term_attrs[[key]]
  }
  x
}

# block.Ave = function(x, order = NULL, numknots = 0, knots = 0, space = "E")
# {
#   cl = match.call()
#   pars = match.call()[-1]
#   attr(x, "nm") = deparse(pars$x)
#   attr(x, "shape") = 10
#   attr(x, "numknots") = numknots
#   attr(x, "knots") = knots
#   attr(x, "space") = space
#   attr(x, "categ") = "block.ave"
#   attr(x, "order") = order
#   #class(x) = "additive"
#   x
# }

######
#1D
######
# Build the constraint matrix ("amat") for a single predictor.
#
# Arguments
#   x            numeric vector of positions (callers in this file pass 1:D).
#                The shape branches work on its sorted unique values; the
#                relaxed and block-ordering branches use its last element M.
#   sh           shape code. 0 = zero matrix, 1 = increasing, 2 = decreasing,
#                3 = convex, 4 = concave, 5 = increasing convex,
#                6 = decreasing convex, 7 = increasing concave,
#                8 = decreasing concave. Only consulted when neither block
#                argument is in use.
#   relaxes, h   if relaxes is TRUE, a kernel-difference matrix with
#                bandwidth h is returned and all other settings are ignored.
#   block.ave, block.ord
#                block labels per cell. NULL or -1 means "not used"; a label
#                of 0 marks a cell that takes part in no comparison. Only
#                one of the two may be supplied as a label vector.
#   zeros_ps     positions dropped before the block constraints are built
#                (presumably empty cells -- confirm against callers).
#   noord_cells  in the block-ordering branch, its length is the number of
#                unordered cells whose (all-zero) columns are re-inserted.
#   D_index, xvs2, grid2
#                only read by the block-ordering branch to compute `ps`.
#   Ds, suppre, interp
#                accepted for interface compatibility; not used here.
#
# Returns a numeric matrix with one row per constraint. If none of the
# branches applies, the result is NULL (invisibly).
makeamat = function(x, sh, Ds = NULL, suppre = FALSE, interp = FALSE, 
                    relaxes = FALSE, h = NULL, block.ave = NULL, 
                    block.ord = NULL, zeros_ps = NULL, noord_cells = NULL,
                    D_index = 1, xvs2 = NULL, grid2 = NULL) {
  n = length(x)
  xu = sort(unique(x))
  n1 = length(xu)
  sm = 1e-7            # tolerance when matching x against its unique values
  obs = seq_len(n)
  M = rev(x)[1]        # last element of x
  if (relaxes) {
    # Row k is the difference of two normalised exponential kernels
    # exp(-|x - c| / h), centred at c = k + 1 and c = k.
    d = l = x
    amat = matrix(0, nrow=(M-1), ncol=M)
    for(k in seq_len(M-1)){
      amat[k, ] = exp(-abs(l-(k+1))/h)/sum(exp(-abs(d-(k+1))/h))-exp(-abs(l-k)/h)/sum(exp(-abs(d-k)/h))
    }
    # Fix: without an explicit return this branch evaluated to NULL (the
    # value of the for loop) and the matrix built above was discarded.
    return (amat)
  } else {
    # block.ave and block.ord are -1 (or NULL) if sh is between 1 and 9
    if (all(block.ave == -1) & all(block.ord == -1)) {
      if (sh < 3) {
        # monotone: one row per pair of adjacent unique values;
        # if sh = 0, the zero matrix is kept
        amat = matrix(0, nrow = n1 - 1, ncol = n)
        if (sh > 0) {
          for (i in seq_len(n1 - 1)) {
            c1 = min(obs[abs(x - xu[i]) < sm]); c2 = min(obs[abs(x - xu[i + 1]) < sm])
            amat[i, c1] = -1; amat[i, c2] = 1
          }
          if (sh == 2) {amat = -amat}
        }
      } else if (sh == 3 | sh == 4) {
        # convex or concave: second differences weighted by the spacing
        amat = matrix(0, nrow = n1 - 2 ,ncol = n)
        for (i in seq_len(n1 - 2)) {
          c1 = min(obs[x == xu[i]]); c2 = min(obs[x == xu[i+1]]); c3 = min(obs[x == xu[i+2]])
          amat[i, c1] = xu[i+2] - xu[i+1]; amat[i, c2] = xu[i] - xu[i+2]; amat[i, c3] = xu[i+1] - xu[i]
        }
        if (sh == 4) {amat = -amat}
      } else if (sh > 4 & sh < 9) {
        # monotone + convex/concave: the n1 - 2 second-difference rows plus
        # one first-difference row (the last row) at one end of the range
        amat = matrix(0, nrow = n1 - 1, ncol = n)
        for (i in seq_len(n1 - 2)) {
          c1 = min(obs[x == xu[i]]); c2 = min(obs[x == xu[i + 1]]); c3 = min(obs[x == xu[i + 2]])
          amat[i, c1] = xu[i + 2] - xu[i + 1]; amat[i, c2] = xu[i] - xu[i + 2]; amat[i, c3] = xu[i + 1] - xu[i]
        }
        if (sh == 5) { ### increasing convex
          c1 = min(obs[x == xu[1]]); c2 = min(obs[x == xu[2]])
          amat[n1 - 1, c1] = -1; amat[n1 - 1, c2] = 1
        } else if (sh == 6) {  ## decreasing convex
          c1 = min(obs[x == xu[n1]]); c2 = min(obs[x == xu[n1 - 1]])
          amat[n1 - 1, c1] = -1; amat[n1 - 1, c2] = 1
        } else if (sh == 7) { ## increasing concave
          amat = -amat
          c1 = min(obs[x == xu[n1]]); c2 = min(obs[x == xu[n1 - 1]])
          amat[n1 - 1, c1] = 1; amat[n1 - 1, c2] = -1
        } else if (sh == 8) {## decreasing concave
          amat = -amat
          c1 = min(obs[x == xu[1]]); c2 = min(obs[x == xu[2]])
          amat[n1 - 1, c1] = 1; amat[n1 - 1, c2] = -1
        }
      } 
      return (amat)
    } else if (all(block.ave == -1) & (length(block.ord) > 1)) {
      # ---- block ordering ----
      # cells labelled 0 and the positions in zeros_ps take part in no
      # comparison and are dropped before the blocks are formed
      block.ord_nz = block.ord
      noord_ps = which(block.ord == 0)
      rm_id = unique(sort(c(zeros_ps, noord_ps)))
      if(length(rm_id)>0){
        block.ord_nz = block.ord[-rm_id]
      }
      
      #use values in block.ord as an ordered integer vector
      ubck = unique(block.ord_nz)
      ubck = sort(ubck)
      #nbcks is the total number of blocks
      nbcks = length(table(block.ord_nz))
      
      # NOTE(review): `ps` is computed here but is not used when the rows
      # are built below; it is kept so evaluation stays as before.
      if(is.list(xvs2)){
        ps = sapply(xvs2[[D_index]], function(.x) which(grid2[, D_index] %in% .x)[1])
      }else if(is.matrix(xvs2)){
        #test! wrong to use x
        ps = 1:M
      }
      if(length(rm_id) > 0){
        ps = ps[-rm_id]
      }
      
      # loop invariants: number of dropped columns and number of cells
      rm_l = length(zeros_ps) + length(noord_cells)
      M.ord = length(block.ord)
      
      # For every pair of consecutive blocks, each cell of the lower block
      # is compared with each cell of the upper block: l1 * l2 rows with -1
      # in the lower cell and +1 in the upper cell.
      amat = NULL
      for(i in 1:(nbcks-1)) {
        ubck1 = ubck[i]
        ubck2 = ubck[i+1]
        bck_1 = which(block.ord_nz == ubck1)
        bck_2 = which(block.ord_nz == ubck2)
        l1 = length(bck_1)
        l2 = length(bck_2)
        
        amat_bcki = NULL
        # rows for the first cell of block 1 against every cell of block 2
        amatj = matrix(0, nrow = l2, ncol = M.ord-rm_l)
        bck_1k = bck_1[1]
        amatj[, bck_1k] = -1
        row_pointer = 1
        for(j in 1:l2){
          bck_2j = bck_2[j]
          amatj[row_pointer, bck_2j] = 1
          row_pointer = row_pointer + 1
        }
        amatj0 = amatj
        amat_bcki = rbind(amat_bcki, amatj)
        
        if(l1 >= 2){
          for(k in 2:l1) {
            bck_1k = bck_1[k]; bck_1k0 = bck_1[k-1]
            amatj = amatj0
            #set the value for the kth element in block 1 to be -1
            #set the value for the (k-1)th element in block 1 back to 0
            #keep the values for block 2
            amatj[, bck_1k] = -1; amatj[, bck_1k0] = 0
            amatj0 = amatj
            amat_bcki = rbind(amat_bcki, amatj)
          }
        }
        amat = rbind(amat, amat_bcki)
      }
      
      # re-insert all-zero columns for the unordered cells
      if(length(noord_cells)>0){
        nr = nrow(amat); nc = ncol(amat)
        amat_tmp = matrix(0, nrow = nr, ncol = (M.ord-rm_l+length(noord_cells)))
        if(length(zeros_ps) > 0){
          ps = which(block.ord[-zeros_ps] != 0)
        }else{
          ps = which(block.ord != 0)
        }
        amat_tmp[, ps] = amat
        amat = amat_tmp
      }
      return (amat)
    } else if ((length(block.ave) > 1) & all(block.ord == -1)) {
      # ---- block average ----
      # one row per pair of consecutive blocks: mean of the upper block
      # minus mean of the lower block
      block.ave_nz = block.ave
      noord_ps = which(block.ave == 0)
      rm_id = unique(sort(c(zeros_ps, noord_ps)))
      if(length(rm_id)>0){
        block.ave_nz = block.ave[-rm_id]
      }
      
      #use values in block.ave as an ordered integer vector
      ubck = unique(block.ave_nz)
      ubck = sort(ubck)
      nbcks = length(table(block.ave_nz))
      
      rm_l = length(zeros_ps) + length(noord_ps)
      M.ave = length(block.ave)
      amat = matrix(0, nrow = (nbcks-1), ncol = (M.ave-rm_l))
      # seq_len() so that a single block yields a 0-row matrix instead of
      # an out-of-bounds error
      for(i in seq_len(nbcks-1)) {
        ubck1 = ubck[i]
        ubck2 = ubck[i+1]
        ps1 = which(block.ave_nz == ubck1)
        ps2 = which(block.ave_nz == ubck2)
        sz1 = length(ps1)
        sz2 = length(ps2)
        amat[i, ps1] = -1/sz1
        amat[i, ps2] = 1/sz2
      }
      
      # re-insert all-zero columns for the cells labelled 0
      if(length(noord_ps)>0){
        amat_tmp = matrix(0, nrow = (nbcks-1), ncol = (M.ave-rm_l+length(noord_ps)))
        if(length(zeros_ps) > 0){
          ps = which(block.ave[-zeros_ps] != 0)
        }else{
          ps = which(block.ave != 0)
        }
        amat_tmp[, ps] = amat
        amat = amat_tmp
      }
      return (amat)
    } else if ((length(block.ave) > 1) & (length(block.ord) > 1)) {
      stop ("only one of block average and block ordering can be used!")
    }
  }
}

######
#>=2D
######
makeamat_2D = function(x=NULL, sh, grid2=NULL, xvs2=NULL, zeros_ps=NULL, noord_cells=NULL,
                       Ds = NULL, suppre = FALSE, interp = FALSE, 
                       relaxes = FALSE, block.ave.lst = NULL, block.ord.lst = NULL, M_0 = NULL) {
  #n = length(x)
  #xu = sort(unique(x))
  #n1 = length(xu)
  sm = 1e-7
  ms = NULL
  Nd = length(Ds)
  ne = length(zeros_ps)
  #new: for ci
  amat_ci = vector('list', length = Nd)
  if (Nd == 2) {
    if (any(sh > 0)) {
      D1 = Ds[1]
      D2 = Ds[2]
      obs = 1:(D1*D2)
      #new: group sh = 9 | 10 with other shapes
      if (sh[1] <= 10) {
        #new:
        if(sh[1] == 0){
          amat1 = NULL
        }else{
          if (length(zeros_ps) > 0) {
            amat0_lst = list()
            grid3 = grid2[-zeros_ps, ,drop=FALSE]
            row_id = as.numeric(rownames(grid3))
            cts = row_id[which(row_id%%D1 == 0)]
            #temp: check if zeros_ps contains some point %% D1 = 0
            zeros_ps_at_cts <- zeros_ps[which((zeros_ps %% D1) == 0)]
            #if (any((zeros_ps %% D1) == 0)) {
            if(length(zeros_ps_at_cts) > 0){
              #zeros_ps_at_cts = zeros_ps[which((zeros_ps %% D1) == 0)]
              #cts_add = row_id[sapply(zeros_ps_at_cts, function(elem) max(which(row_id < elem)))]
              cts_add <- sapply(zeros_ps_at_cts, function(empty_dom) {
                prior_rows <- row_id[row_id < empty_dom]
                if (length(prior_rows) > 0) max(prior_rows) else NA_integer_
              })
              cts_add <- cts_add[!is.na(cts_add)]
              cts = unique(sort(c(cts, cts_add)))
            }
            obs = 1:nrow(grid3)
            for (i in 1:length(cts)) {
              if (i == 1) {
                st = 1
                #ed = cts[1]
                ed = obs[row_id == cts[1]]
              } else {
                st = obs[row_id == cts[i-1]]+1
                ed = obs[row_id == cts[i]]
              }
              gridi = grid3[st:ed,,drop=FALSE]
              na_ps = which(apply(gridi, 1, function(x) all(is.na(x))))
              if (length(na_ps) > 0) {
                gridi = gridi[-na_ps,,drop=FALSE]
              }
              obsi = as.numeric(rownames(gridi))
              vec = (D1*(i-1) + 1):(D1*i)
              #vec = 1:D1
              zerosi = vec[-which(vec %in% obsi)]
              non_zerosi = vec[which(vec %in% obsi)]
              D1i = nrow(gridi)
              #if ed == st, it could be: 1,2,3,4 with 0,0,0,92 observations
              #in this case, amat0i = NULL, and it will be removed before calling bdiag
              if (length(zerosi) >= 1 & (ed > st)) {
                if(sh[1] < 9){
                  amat0i = makeamat(1:D1i, sh[1])
                } else {
                  #new:
                  zeros_ps_i = zerosi%%D1
                  if(any(zeros_ps_i %in% 0)){
                    zeros_ps_i[which(zeros_ps_i %in% 0)] = D1
                  } else {
                    zeros_ps_i = zerosi%%D1
                  }
                  #not sure....
                  amat0i = makeamat(1:D1i, sh[1], block.ave = block.ave.lst[[1]], 
                                    block.ord = block.ord.lst[[1]], 
                                    zeros_ps = zeros_ps_i, 
                                    noord_cells = noord_cells[[1]], 
                                    D_index = 1, xvs2=xvs2, grid2=grid2)
                }
                amat0_lst[[i]] = amat0i
              } else if (length(zerosi) >= 1 & (ed == st)) {
                #next
                amat0i = 0
                amat0_lst[[i]] = amat0i
              } else if (length(zerosi) == 0) {
                if(sh[1] < 9){
                  amat0i = makeamat(1:D1, sh[1])
                }else{
                  #new:
                  amat0i = makeamat(1:D1i, sh[1], block.ave = block.ave.lst[[1]], 
                                    block.ord = block.ord.lst[[1]], 
                                    zeros_ps = NULL, 
                                    noord_cells = noord_cells[[1]], 
                                    D_index = 1,
                                    xvs2=xvs2, grid2=grid2)
                }
                amat0_lst[[i]] = amat0i
              }
              #amat0_lst[[i]] = amat0i
            }
          } else {
            #if(sh[1] < 9){
            amat0 = makeamat(1:D1, sh=sh[1], block.ave = block.ave.lst[[1]], block.ord = block.ord.lst[[1]],
                             zeros_ps=NULL, noord_cells = noord_cells[[1]], D_index = 1, xvs2 = xvs2, grid2 = grid2)
            #} else {
            #  amat0 = makeamat_2d_block(x=1:D1, sh=sh[1], block.ave = block.ave.lst[[1]], block.ord = block.ord.lst[[1]], 
            #        zeros_ps=NULL, noord_cells = noord_cells[[1]], D_index = 1, xvs2 = xvs2, grid2 = grid2, M_0 = M_0)
            #}
            amat0_lst = rep(list(amat0), D2)
          }
          amat0_lst = Filter(Negate(is.null), amat0_lst)
          amat1 = bdiag(amat0_lst)
          amat1 = as.matrix(amat1)
          amat_ci[[1]] = amat1
        }
      }
      #2D
      amat2 = NULL
      if (sh[2] < 3) {
        nr2 = D1*(D2-1)
        nc2 = D1*D2
        if (length(zeros_ps) == 0) {
          amat2 = matrix(0, nrow=nr2, ncol=nc2)
          #new: if sh[2] == 0, keep the zero matrix
          if (sh[2]>0){
            for(i in 1:nr2){
              amat2[i,i] = -1; amat2[i,(i+D1)] = 1
            }
          }
        } else {
          #temp
          #amat2 = matrix(0, nrow=1, ncol=nc2)
          amat2_lst = vector('list', length = nr2)
          rm_id = NULL
          for(i in 1:nr2) {
            if (i %in% zeros_ps) {
              if (i <= D1) {
                amat2_lst[[i]] = 1:nc2*0
                rm_id = c(rm_id, i)
              } else if (i <= (nc2-D1) & i > D1) {
                #ch_ps = c(ch_ps, i-D1)
                amat2_lst[[i]] = 1:nc2*0
                rm_id = c(rm_id, i)
                rowi = 1:nc2*0
                rowi[i-D1] = -1; rowi[i+D1] = 1
                amat2_lst[[i-D1]] = rowi
              }
            } else {
              rowi = 1:nc2*0
              rowi[i] = -1; rowi[i+D1] = 1
              amat2_lst[[i]] = rowi
              #amat2 = rbind(amat2, amat2i)
            }
          }
          for(i in (nr2+1):nc2) {
            if (i %in% zeros_ps)  {
              amat2_lst[[i-D1]] = 1:nc2*0
              rm_id = c(rm_id, i-D1)
            }
          }
          if (!is.null(rm_id)){
            amat2_lst = amat2_lst[-rm_id]
          }
          amat2 = NULL
          for(i in 1:length(amat2_lst)){
            amat2 = rbind(amat2, amat2_lst[[i]])
          }
          amat2 = amat2[,-zeros_ps,drop=FALSE]
        }
        #new: if sh[2] == 0, keep the zero matrix
        if (sh[2] == 0) {
          #nr2 = nrow(amat2)
          #nc2 = ncol(amat2)
          #amat2 = matrix(0, nrow=nr2, ncol=nc2)
          amat2 = NULL
        }
        if (sh[2] == 2) {amat2 = -amat2}
      } else if (sh[2] == 3 | sh[2] == 4) {
        nr2 = D1*(D2-2)
        nc2 = D1*D2
        amat2 = matrix(0, nrow=nr2, ncol=nc2)
        xu = 1:D2
        for (i in 1:nr2) {
          #c1 = min(obs[xu == xu[i]]); c2 = min(obs[xu == xu[i+1]]); c3 = min(obs[xu == xu[i+2]])
          amat2[i, i] = 1; amat2[i, (i+D1)] = -2; amat2[i, (i+2*D1)] = 1
        }
        if (sh[2] == 4) {amat2 = -amat2}
      } else if (sh[2] > 4 & sh[2] < 11) {
        nr2 = D1*(D2-2)
        nc2 = D1*D2
        amat2 = matrix(0, nrow=nr2, ncol=nc2)
        xu = 1:D2
        for (i in 1:nr2) {
          #c1 = min(obs[xu == xu[i]]); c2 = min(obs[xu == xu[i+1]]); c3 = min(obs[xu == xu[i+2]])
          #amat2[i, c1] = 1; amat2[i, (c2+D1)] = -2; amat2[i, (c3+2*D1)] = 1
          amat2[i, i] = 1; amat2[i, (i+D1)] = -2; amat2[i, (i+2*D1)] = 1
        }
        if (sh[2] == 5) { ### increasing convex
          amat2_add = matrix(0, nrow=D1, ncol=nc2)
          #c1 = min(obs[xu == xu[1]])
          #c2 = min(obs[xu == xu[2]])
          #amat2[nr2, c1] = -1; amat2[nr2, (c1+D1)] = 1
          for(i in 1:D1) {
            amat2_add[i,i] = -1; amat2_add[i,i+D1] = 1
          }
          amat2 = rbind(amat2, amat2_add)
        } else if (sh[2] == 6) {  ## decreasing convex
          #c1 = min(obs[xu == xu[D2]]); c2 = min(obs[xu == xu[D2 - 1]])
          amat2_add = matrix(0, nrow=D1, ncol=nc2)
          for(i in 1:D1) {
            amat2_add[i, i+2*D1] = -1; amat2_add[i, i+D1] = 1
          }
          amat2 = rbind(amat2, amat2_add)
        } else if (sh[2] == 7) { ## increasing concave
          amat2 = -amat2
          #c1 = min(obs[xu == xu[D2]]); c2 = min(obs[xu == xu[D2 - 1]])
          #amat2[nr2, c1+D1] = 1; amat2[nr2, c2] = -1
          amat2_add = matrix(0, nrow=D1, ncol=nc2)
          for(i in 1:D1) {
            amat2_add[i, i+2*D1] = 1; amat2_add[i, i+D1] = -1
          }
          amat2 = rbind(amat2, amat2_add)
        } else if (sh[2] == 8) {## decreasing concave
          amat2 = -amat2
          #c1 = min(obs[xu == xu[1]]); c2 = min(obs[xu == xu[2]])
          #amat2[nr2, c1] = 1; amat2[nr2, (c2+D1)] = -1
          amat2_add = matrix(0, nrow=D1, ncol=nc2)
          for(i in 1:D1) {
            amat2_add[i,i] = 1; amat2_add[i,i+D1] = -1
          }
          amat2 = rbind(amat2, amat2_add)
        } else if (sh[2] == 9 | sh[2] == 10) {## block ordering or block average
          amat2 = makeamat_2d_block(x=NULL, sh[2], block.ave = block.ave.lst[[2]], 
                                    block.ord = block.ord.lst[[2]], 
                                    zeros_ps = zeros_ps, 
                                    noord_cells = noord_cells[[2]], 
                                    D_index=2, xvs2=xvs2, grid2=grid2, M_0=M_0)
        }
      }
      amat = rbind(amat1, amat2)
      #new: for ci
      amat_ci[[2]] = amat2
    } 
  } else if (Nd >= 3) {
    M = length(Ds)
    D1 = Ds[1]
    cump = cumprod(Ds)
    nc = cump[M]
    amat1 = amat2 = amatm = NULL
    
    # 1D: group sh=9 or sh=10 with other shapes in 1D
    if(sh[1] == 0){
      amat1 = NULL
    } else {
      if(length(zeros_ps)>0){
        amat0_lst = list()
        grid3 = grid2[-zeros_ps, ,drop=FALSE]
        row_id = as.numeric(rownames(grid3))
        cts = row_id[which(row_id %% D1 == 0)]
        zeros_ps_at_cts <- zeros_ps[which((zeros_ps %% D1) == 0)]
        #if(any((zeros_ps %% D1) == 0)){
        if(length(zeros_ps_at_cts) > 0){
          #zeros_ps_at_cts = zeros_ps[which((zeros_ps %% D1) == 0)]
          #cts_add = row_id[sapply(zeros_ps_at_cts, function(elem) max(which(row_id < elem)))]
          cts_add <- sapply(zeros_ps_at_cts, function(empty_dom) {
            prior_rows <- row_id[row_id < empty_dom]
            if (length(prior_rows) > 0) max(prior_rows) else NA_integer_
          })
          cts_add <- cts_add[!is.na(cts_add)]
          cts = unique(sort(c(cts, cts_add)))
        }
        obs = 1:nrow(grid3)
        
        for (i in 1:length(cts)) {
          if (i == 1) {
            st = 1
            ed = obs[row_id == cts[1]]
          } else {
            st = obs[row_id == cts[i-1]]+1
            ed = obs[row_id == cts[i]]
          }
          gridi = grid3[st:ed,,drop=FALSE]
          na_ps = which(apply(gridi, 1, function(x) all(is.na(x))))
          if (length(na_ps) > 0) {
            gridi = gridi[-na_ps, ,drop=FALSE]
          }
          obsi = as.numeric(rownames(gridi))
          vec = (D1*(i-1) + 1):(D1*i)
          zerosi = vec[-which(vec %in% obsi)]
          non_zerosi = vec[which(vec %in% obsi)]
          D1i = nrow(gridi)
          if (length(zerosi) >= 1 & (ed > st)) {
            #amat0i = makeamat(1:D1i, sh[1], block.ave=block.ave.lst[[1]], block.ord=block.ord.lst[[1]])
            #new:
            zeros_ps_i = zerosi%%D1
            if(any(zeros_ps_i %in% 0)){
              zeros_ps_i[which(zeros_ps_i %in% 0)] = D1
            } #else {
            #zeros_ps_i = zerosi%%D1
            #}
            amat0i = makeamat(1:D1i, sh[1], block.ave = block.ave.lst[[1]], block.ord = block.ord.lst[[1]], 
                              zeros_ps = zeros_ps_i, noord_cells = noord_cells[[1]], xvs2=xvs2, grid2=grid2)
            amat0_lst[[i]] = amat0i
          } else if (length(zerosi) >= 1 & (ed == st)) {
            amat0i = 0
            amat0_lst[[i]] = amat0i
          } else if (length(zerosi) == 0) {
            #amat0i = makeamat(1:D1, sh[1], block.ave=block.ave.lst[[1]], block.ord=block.ord.lst[[1]])
            amat0i = makeamat(1:D1, sh[1], block.ave=block.ave.lst[[1]], block.ord=block.ord.lst[[1]], 
                              zeros_ps = zeros_ps, noord_cells = noord_cells[[1]], xvs2=xvs2, grid2=grid2)
            amat0_lst[[i]] = amat0i
          }
        }
        amat0_lst = Filter(Negate(is.null), amat0_lst)
        amat1 = bdiag(amat0_lst)
        amat1 = as.matrix(amat1)
      }else{
        n1 = nc/D1
        amat1_0 = makeamat(1:D1, sh[1], block.ave=block.ave.lst[[1]], block.ord=block.ord.lst[[1]],
                           zeros_ps = zeros_ps, noord_cells = noord_cells[[1]], xvs2=xvs2, grid2=grid2)
        amat1_0_lst = rep(list(amat1_0), n1)
        amat1 = bdiag(amat1_0_lst)
        amat1 = as.matrix(amat1)
      }
    }
    amat = amat1
    #new: for ci
    amat_ci[[1]] = amat1
    #2D
    for (i in 2:(M-1)) {
      Di = Ds[i]
      cump = cumprod(Ds[1:i])
      nci = cump[i]
      gap = cump[i-1]
      ni = nc/nci
      
      if (sh[i] %in% c(9,10)) {
        amati = makeamat_2d_block(x=NULL, sh[i], block.ave = block.ave.lst[[i]], 
                                  block.ord = block.ord.lst[[i]], 
                                  zeros_ps = zeros_ps, 
                                  noord_cells = noord_cells[[i]], 
                                  D_index=i, xvs2=xvs2, grid2=grid2, M_0=M_0)
      } else {
        #new: sh[i] == 0
        if(sh[i] == 0){
          amati = NULL
        }
        if(sh[i] > 0) {
          if ((sh[i] %in% c(1,2))) {
            nr = gap*(Di - 1)
          } 
          if(sh[i] > 2){
            if (Di < 3) {
              stop ("For monotonicity + convexity constraints, number of domains must be >= 3!")
            } else {
              nr = gap*(Di - 2)
            }
          }
          amati_0 = makeamat_2d(sh=sh[i], nr2=nr, nc2=nci, gap=gap, D2=Di, Ds=Ds)
          
          if(length(zeros_ps) == 0){
            amati_0_lst = rep(list(amati_0), ni)
          } else {
            amati_0_lst = vector("list", length = ni)
            cts = seq(nci, nc, length=ni)
            bool_lst = map(zeros_ps, .f = function(.x) as.numeric(cts >= .x))
            #ps_lst keeps track of the cut points such that the point giving -1 is the upp bd 
            #of the interval containining one zero cell
            ps = map_dbl(bool_lst, .f = function(.x) min(which(.x == 1)))
            cts_check = cts[ps]
            for(k in 1:ni){
              if(!(cts[k] %in% cts_check)){
                #no zero cell in this interval
                amati_0_lst[[k]] = amati_0
              } else {
                #find out the empty cells within an interval
                if(k == 1) {
                  zeros_ctsk = zeros_ps[zeros_ps <= cts[k]]
                }else{
                  zeros_ctsk = zeros_ps[zeros_ps <= cts[k] & zeros_ps > cts[k-1]]
                }
                nek = length(zeros_ctsk)
                rm_id = NULL
                amati_0_cp = amati_0; nci = ncol(amati_0)
                col_ps_vec = NULL
                for(k2 in 1:nek){
                  zero_k2 = zeros_ctsk[k2]
                  #check: zero_k2 %% nr = 0
                  if(zero_k2 != cts[k]){
                    col_ps = zero_k2 %% nci
                  }else{
                    col_ps = nci
                  }
                  col_ps_vec = c(col_ps_vec, col_ps)
                  
                  row_ps = col_ps - gap
                  #if((zero_k2 %% nr + gap) > nr){
                  if((col_ps + gap) > nci){
                    rm_id = c(rm_id, row_ps)
                  }else{
                    rm_id = c(rm_id, col_ps)
                    #test:
                    if(row_ps > 0){
                      amati_0_cp[row_ps, (col_ps+gap)] = 1
                      amati_0_cp[row_ps, col_ps] = 0
                    }
                  }
                }
                amati_0_cp = amati_0_cp[-rm_id, ]
                #amati_0_cp = amati_0_cp[,-(zeros_ctsk%%nr)]
                amati_0_cp = amati_0_cp[,-col_ps_vec]
                amati_0_lst[[k]] = amati_0_cp
              }
            }
          }
          amati = bdiag(amati_0_lst)
          amati = as.matrix(amati)
        }
      }  
      amat = rbind(amat, amati)
      #new: for ci
      amat_ci[[i]] = amati
    }
    
    #3D
    cump = cumprod(Ds[1:(M-1)])
    gap = cump[M-1]
    Dm = Ds[M]
    if((sh[M] %in% c(9,10))){
      amatm = makeamat_2d_block(x=NULL, sh[M], block.ave = block.ave.lst[[M]], 
                                block.ord = block.ord.lst[[M]], 
                                zeros_ps = zeros_ps, 
                                noord_cells = noord_cells[[M]], 
                                D_index=M, xvs2=xvs2, grid2=grid2, M_0=M_0)
    } else {
      if ((sh[M] %in% c(1,2))) {
        nr = gap*(Dm-1)
      } 
      if (sh[M] >= 3) {
        nr = gap*(Dm-2)
      }
      #new: sh[M] == 0
      if(sh[M] == 0){
        #nrm = nrow(amatm)
        #ncm = ncol(amatm)
        #amatm = matrix(0, nrow=nrm, ncol=ncm)
        amatm = NULL
      } else {
        amatm = makeamat_2d(sh=sh[M], nr2=nr, nc2=nc, gap=gap, D2=Dm, Ds=Ds)
        ncm = ncol(amatm)
        if(length(zeros_ps) > 0){
          amatm_cp = amatm
          rm_id = NULL
          for(k in 1:ne){
            zero_k = zeros_ps[k]
            col_ps = zero_k 
            row_ps = col_ps - gap
            if((col_ps + gap) > ncm){
              rm_id = c(rm_id, row_ps)
            }else{
              rm_id = c(rm_id, col_ps)
              #test:
              if(row_ps > 0){
                amatm_cp[row_ps, (col_ps+gap)] = 1
                amatm_cp[row_ps, col_ps] = 0
              }
            }
          }
          amatm_cp = amatm_cp[-rm_id, ]
          amatm_cp = amatm_cp[,-zeros_ps]
          amatm = amatm_cp
        }
      }
    }
    amat = rbind(amat, amatm)
    #new: for ci
    amat_ci[[M]] = amatm
  }
  rslt = list(amat = amat, amat_ci = amat_ci)
  return (rslt)
  #return (amat)
}

############################################
#local function to be called in makeamat_2D
############################################
makeamat_2d = function(x=NULL, sh, nr2, nc2, gap, D1=NULL, D2, Ds = NULL, suppre = FALSE, interp = FALSE) {
  # Build the constraint matrix for one dimension of a multi-way grid of cells.
  #
  # Cells that are neighbours along the constrained dimension sit `gap`
  # columns apart, so row i always relates columns i, i + gap (and
  # i + 2 * gap for second differences).
  #
  # Arguments used here:
  #   sh   shape code: 1 increasing, 2 decreasing, 3 convex, 4 concave,
  #        5 increasing convex, 6 decreasing convex, 7 increasing concave,
  #        8 decreasing concave.
  #   nr2  number of difference rows to generate.
  #   nc2  number of columns of the result.
  #   gap  column distance between adjacent levels of this dimension.
  # x, D1, D2, Ds, suppre and interp are accepted for interface
  # compatibility but are not used.
  #
  # Returns an nr2 x nc2 matrix for sh < 5, an (nr2 + gap) x nc2 matrix for
  # sh in 5:8, and NULL for any other shape code.
  rows <- seq_len(nr2)

  if (sh < 3) {
    # first differences: theta[i + gap] - theta[i] >= 0 (negated for sh == 2)
    amat2 <- matrix(0, nrow = nr2, ncol = nc2)
    amat2[cbind(rows, rows)] <- -1
    amat2[cbind(rows, rows + gap)] <- 1
    if (sh == 2) {
      amat2 <- -amat2
    }
    return(amat2)
  }
  if (!(sh %in% 3:8)) {
    return(NULL)
  }

  # second differences: theta[i] - 2 * theta[i + gap] + theta[i + 2 * gap] >= 0
  amat2 <- matrix(0, nrow = nr2, ncol = nc2)
  amat2[cbind(rows, rows)] <- 1
  amat2[cbind(rows, rows + gap)] <- -2
  amat2[cbind(rows, rows + 2 * gap)] <- 1
  if (sh == 3) {
    return(amat2)
  }
  if (sh == 4) {
    return(-amat2)
  }

  # Convexity/concavity plus monotonicity: one extra first-difference row per
  # slice. The slope that has to be pinned is the FIRST one for increasing
  # convex / decreasing concave and the LAST one for decreasing convex /
  # increasing concave (previously the rows for sh 6 and 7 always used levels
  # 2 and 3, which is only the last pair when there are exactly 3 levels).
  edge <- seq_len(gap)
  first_lo <- edge
  first_hi <- edge + gap
  last_lo <- nc2 - 2 * gap + edge
  last_hi <- nc2 - gap + edge
  amat2_add <- matrix(0, nrow = gap, ncol = nc2)
  if (sh == 5) {
    # increasing convex: first slope >= 0
    amat2_add[cbind(edge, first_lo)] <- -1
    amat2_add[cbind(edge, first_hi)] <- 1
  } else if (sh == 6) {
    # decreasing convex: last slope <= 0
    amat2_add[cbind(edge, last_hi)] <- -1
    amat2_add[cbind(edge, last_lo)] <- 1
  } else if (sh == 7) {
    # increasing concave: last slope >= 0
    amat2 <- -amat2
    amat2_add[cbind(edge, last_hi)] <- 1
    amat2_add[cbind(edge, last_lo)] <- -1
  } else {
    # sh == 8, decreasing concave: first slope <= 0
    amat2 <- -amat2
    amat2_add[cbind(edge, first_lo)] <- 1
    amat2_add[cbind(edge, first_hi)] <- -1
  }
  rbind(amat2, amat2_add)
}

################################################################
#local function to be called in makeamat_2D for block ordering
################################################################
makeamat_2d_block = function(x=NULL, sh, block.ave = NULL, block.ord = NULL, zeros_ps = NULL, 
                             noord_cells = NULL, D_index = 1, xvs2 = NULL, 
                             grid2 = NULL, M_0 = NULL) {
  # Build the block-ordering constraint matrix for dimension `D_index`.
  #
  # block.ord   integer block label per level of the dimension; levels with
  #             label 0 take part in no comparison.
  # block.ave   must be all -1 (or NULL): block averaging is not implemented.
  # zeros_ps    column positions of empty cells; rows comparing an empty cell
  #             and the columns of the empty cells are dropped.
  # noord_cells levels to leave out when the constraints are replicated over
  #             the slices of a multi-dimensional grid (D_index > 1).
  # xvs2, grid2 levels per dimension and the grid of cells; the first grid row
  #             holding each level gives that level's column position.
  # M_0         number of columns of the result when D_index > 1 (for
  #             D_index == 1 it is length(block.ord)).
  # x and sh are accepted for interface compatibility but are not used.
  #
  # Every row carries a -1 for a member of a lower block and a +1 for a member
  # of the next higher block. Returns a matrix (always a matrix, even with a
  # single row) or NULL when there are fewer than two blocks to compare.
  if (!(all(block.ave == -1) & (length(block.ord) > 1))) {
    # previously this path fell through to return(amat) with `amat` undefined
    stop("block ordering needs block.ave = -1 and length(block.ord) > 1; block averaging is not implemented yet")
  }

  # first grid row in which each level of this dimension occurs
  lvls <- xvs2[[D_index]]
  grid_col <- grid2[, D_index]
  ps_all <- vapply(seq_along(lvls),
                   function(j) which(grid_col %in% lvls[[j]])[1],
                   integer(1))

  # drop the levels that are excluded from every comparison (label 0)
  rm_id <- unique(sort(which(block.ord == 0)))
  block.ord_nz <- block.ord
  ps <- ps_all
  if (length(rm_id) > 0) {
    block.ord_nz <- block.ord[-rm_id]
    ps <- ps_all[-rm_id]
  }

  # block labels are used as an ordered integer scale
  ubck <- sort(unique(block.ord_nz))
  nbcks <- length(ubck)
  M.ord <- if (D_index == 1) length(block.ord) else M_0

  amat <- NULL
  for (i in seq_len(nbcks - 1)) {
    ps_bck1 <- ps[which(block.ord_nz == ubck[i])]
    ps_bck2 <- ps[which(block.ord_nz == ubck[i + 1])]
    # one row per member of the higher block, +1 in that member's column
    upper <- matrix(0, nrow = length(ps_bck2), ncol = M.ord)
    upper[cbind(seq_along(ps_bck2), ps_bck2)] <- 1
    # pair every member of the lower block with all members of the higher one
    for (lo in ps_bck1) {
      amatj <- upper
      amatj[, lo] <- -1
      amat <- rbind(amat, amatj)
    }
  }
  if (is.null(amat)) {
    # a single block: nothing to order
    return(NULL)
  }

  if (D_index > 1) {
    # replicate the pattern for the remaining slices of the lower dimensions
    # by shifting the columns one position at a time
    amat0 <- amat
    ps_rep <- ps_all
    ps1_ed <- ps_rep[2] - 1
    if (length(noord_cells) > 0) {
      ps_rep <- ps_rep[-noord_cells]
    }
    iter <- 1
    while (iter < ps1_ed) {
      amati <- matrix(0, nrow = nrow(amat0), ncol = ncol(amat0))
      amati[, ps_rep + iter] <- amat0[, ps_rep]
      amat <- rbind(amat, amati)
      iter <- iter + 1
    }
  }

  if (length(zeros_ps) > 0) {
    # a comparison that touches an empty cell cannot be enforced
    in_range <- zeros_ps[zeros_ps >= 1 & zeros_ps <= ncol(amat)]
    touches_empty <- rowSums(amat[, in_range, drop = FALSE] != 0) > 0
    # empty cells have no estimate, so their columns go as well;
    # drop = FALSE keeps a one-row result a matrix
    amat <- amat[!touches_empty, -zeros_ps, drop = FALSE]
  }
  return (amat)
}

#########################
#empty cell imputation 
##########################
impute_em2 = function(empty_cells, obs_cells, M, yvec, Nd, nd, Nhat, w, domain.ord, grid2, 
                      Ds=NULL, sh=NULL, muhat=NULL, lwr=NULL, upp=NULL, vsc=NULL, amat_0=NULL,
                      level=0.95)
{
  # Impute estimates, variances and confidence limits for cells listed in
  # `empty_cells` from the estimates of neighbouring observed cells.
  #
  # empty_cells  positions (among the M cells) to impute.
  # obs_cells    positions of the observed cells; muhat, lwr, upp, vsc and
  #              domain.ord are given in this order.
  # M            total number of cells; every returned vector has length M.
  # Nd           number of dimensions: 1 uses the left/right observed cells,
  #              >= 2 uses the neighbours found by fn_nbors_3() in amat_0.
  # nd           sample size per cell, indexed by cell position.
  # sh           shape code of the single predictor (used when Nd == 1).
  # level        confidence level; z.mult is the matching normal quantile.
  # yvec, Nhat, w, Ds are accepted but not used.
  #
  # Returns list(muhat, lwr, upp, domain.ord, vsc). A limit of -1e+5 / 1e+5
  # marks the side on which an empty boundary cell has no observed neighbour.
  # NOTE(review): the loops below use 1:ne, so callers are presumably expected
  # to pass at least one empty cell -- confirm against the call sites.
  z.mult = qnorm((1 - level)/2, lower.tail=FALSE)
  ne = length(empty_cells)
  #observed sample size in cells with 0 or 1 n
  nd_ne = nd[empty_cells]
  zeros = ones = 0
  zeros_ps = NULL
  if (ne >= 1) {
    #cells with nd == 1 are ignored for now
    zeros_ps = which(nd == 0)
    zeros = length(zeros_ps)
  }
  
  lc = rc = 1:ne*0
  if (Nd == 1) {
    vsc_all = lwr_all = upp_all = muhat_all = domain.ord_all = 1:M*0
    nslst = pslst = vector('list', ne)
    vsce = lwre = uppe = muhate = domain.ord_em = NULL
    for(k in 1:ne){
      e = empty_cells[k]
      ndek = nd_ne[k]
      #is the cell beyond (or equal to) the last / first observed cell?
      at_max = all(obs_cells < e | e == max(obs_cells))
      at_min = all(obs_cells > e | e == min(obs_cells))
      #lc / rc: nearest observed cell on the left / right; at a boundary the
      #only available side is used for both
      if (at_max) {
        lc = max(obs_cells[obs_cells < e])
        rc = lc       
        dmem = obs_cells[which(obs_cells%in%lc)]
      }
      if (at_min) {
        rc = min(obs_cells[obs_cells > e])
        lc = rc
        dmem = obs_cells[which(obs_cells%in%rc)]
      }
      if (!any(at_max) & !any(at_min)) {
        lc = max(obs_cells[obs_cells < e])
        rc = min(obs_cells[obs_cells > e])
        dmem = obs_cells[which(obs_cells%in%lc)]
      }
      n_lc = nd[lc]
      n_rc = nd[rc]
      
      #switch lc and rc if sh=2 (decreasing) so that "left" is the smaller side
      if (sh == 2) {
        lc_0 = rc; rc_0 = lc
        lc = lc_0; rc = rc_0
      }
      if (!is.null(muhat)) {
        mul = muhat[which(obs_cells%in%lc)]
        mur = muhat[which(obs_cells%in%rc)]
      }
      
      #variance
      if (!is.null(vsc)) {
        vl = vsc[which(obs_cells%in%lc)]
        vr = vsc[which(obs_cells%in%rc)]
      }
      
      if (!is.null(muhat)) {
        #midpoint of [lower limit of left cell, upper limit of right cell]
        muhatek = (mul - z.mult*sqrt(vl) + mur + z.mult*sqrt(vr))/2
        #check if muhatek follows the shape constraint
        lr = c(lc, rc)
        #sh=10 is for block ave
        #sh=9 is for block ord
        if (sh == 1 | sh == 10 | sh == 9) {
          if (length(lr) > 1) {
            bool = mul <= muhatek & mur >= muhatek
          } else if (length(lr) == 1) {
            if (at_min) {
              bool = mur >= muhatek
            } else if (at_max) {
              bool = mul <= muhatek 
            }
          }
        }
        if (sh == 2) {
          if (length(lr) > 1) {
            bool = mul >= muhatek & mur <= muhatek
          } else if (length(lr) == 1) {
            if (at_min) {
              bool = mur <= muhatek
            } else if (at_max) {
              bool = mul >= muhatek 
            }
          }
        }
        #fall back to the plain average / the neighbour's value when the
        #midpoint violates the shape
        if (!bool) {
          if (!any(at_max) & !any(at_min)) {
            muhatek = (mul+mur)/2
          } else if (at_min) {
            muhatek = mul
          } else if (at_max)  {
            muhatek = mur
          }
        }
        muhate = c(muhate, muhatek)
      }
      
      if (!is.null(vsc)) {
        #variance chosen so that +/- z.mult sd spans [lpt, rpt]
        rpt = mur + z.mult*sqrt(vr); lpt = mul - z.mult*sqrt(vl)
        varek = (rpt-lpt)^2/(4*z.mult^2)
        vsce = c(vsce, varek)
        if (at_min & (ndek == 0)) {
          lwrek = -1e+5
          uppek = muhatek + z.mult*sqrt(varek)
        } else if (at_max & (ndek == 0)) {
          uppek = 1e+5
          #lower limit lies below the estimate (was '+', a sign error)
          lwrek = muhatek - z.mult*sqrt(varek)
        } else if (at_min & (ndek >= 1) & (ndek <= 10)) {
          varek_lwr = max(vsc)
          lwrek = muhatek - z.mult*sqrt(varek_lwr)
          uppek = muhatek + z.mult*sqrt(varek)
        } else if (at_max & (ndek >= 1) & (ndek <= 10)) {
          varek_upp = max(vsc)
          lwrek = muhatek - z.mult*sqrt(varek)
          uppek = muhatek + z.mult*sqrt(varek_upp)
        } else {
          lwrek = muhatek - z.mult*sqrt(varek)
          uppek = muhatek + z.mult*sqrt(varek)
        }
        lwre = c(lwre, lwrek)
        uppe = c(uppe, uppek)
      }
      
      ps = unique(c(lc, rc))
      ns = unique(c(n_lc, n_rc))
      pslst[[k]] = ps
      nslst[[k]] = ns
      domain.ord_em = c(domain.ord_em, dmem)
    }
    
    #scatter observed and imputed values back to the full set of M cells
    domain.ord_all[obs_cells] = domain.ord
    domain.ord_all[empty_cells] = domain.ord_em
    if (!is.null(muhat)) {
      muhat_all[obs_cells] = muhat
      muhat_all[empty_cells] = muhate
      muhat = muhat_all
    }
    
    if (!is.null(vsc)) {
      lwr_all[obs_cells] = lwr
      lwr_all[empty_cells] = lwre
      
      upp_all[obs_cells] = upp
      upp_all[empty_cells] = uppe
      
      upp = upp_all
      lwr = lwr_all
      
      vsc_all[obs_cells] = vsc
      vsc_all[empty_cells] = vsce
      vsc = vsc_all
    }
    
    domain.ord = domain.ord_all
    
  } else if (Nd >= 2) {
    vsc_all = lwr_all = upp_all = muhat_all = 1:M*0
    yvec_all = w_all = domain.ord_all = 1:M*0
    empty_grids = grid2[empty_cells, ]
    maxs = apply(grid2, 2, max)
    mins = apply(grid2, 2, min)
    at_max = (empty_grids == maxs)
    at_min = (empty_grids == mins)
    
    vsce = lwre = uppe = muhate = domain.ord_em = NULL
    ye = we = NULL
    nslst = pslst = vector('list', 2)
    nbors = NULL
    #candidate neighbours bounding each empty cell from above / below
    nbors_lst_0 = fn_nbors_3(empty_cells, amat_0, muhat)
    for(k in 1:length(empty_cells)){
      nbors_k_maxs = nbors_lst_0$nb_lst_maxs[[k]]
      nbors_k_mins = nbors_lst_0$nb_lst_mins[[k]]
      em_k = empty_cells[k]
      
      #nbors_k_mins and nbors_k_maxs must be in obs_cells
      if(length(nbors_k_mins) > 0 & any(nbors_k_mins %in% obs_cells) | length(nbors_k_maxs) > 0 & any(nbors_k_maxs %in% obs_cells)){
        max_ps = min_ps = NULL
        if(length(nbors_k_maxs)>0){
          mu_max = max(muhat[obs_cells %in% nbors_k_maxs])
          max_ps = (nbors_k_maxs[which(muhat[obs_cells %in% nbors_k_maxs] == mu_max)])[1]
        }else{
          #no upper neighbour: the lower neighbour is used for both sides
          mu_min = min(muhat[obs_cells %in% nbors_k_mins])
          min_ps = (nbors_k_mins[which(muhat[obs_cells %in% nbors_k_mins] == mu_min)])[1]
          mu_max = mu_min
          max_ps = min_ps
          dmem = obs_cells[which(obs_cells%in%min_ps)]
        }
        
        if(length(nbors_k_mins)>0){
          mu_min = min(muhat[obs_cells %in% nbors_k_mins])
          min_ps = (nbors_k_mins[which(muhat[obs_cells %in% nbors_k_mins] == mu_min)])[1]
        } else{
          #no lower neighbour: the upper neighbour is used for both sides
          mu_max = max(muhat[obs_cells %in% nbors_k_maxs])
          max_ps = (nbors_k_maxs[which(muhat[obs_cells %in% nbors_k_maxs] == mu_max)])[1]
          mu_min = mu_max
          min_ps = max_ps
          dmem = obs_cells[which(obs_cells%in%max_ps)]
        }
        
        if(!is.null(max_ps)){
          v_max = vsc[which(obs_cells %in% max_ps)]
        }
        if(!is.null(min_ps)){
          v_min = vsc[which(obs_cells %in% min_ps)]
        }
        
        if(!is.null(max_ps) & !is.null(min_ps)){
          #same construction as in 1D, with the level-dependent multiplier
          varek_new = (mu_max + z.mult*sqrt(v_max) - mu_min + z.mult*sqrt(v_min))^2 / (4*z.mult^2) 
          muhatek = (mu_min - z.mult*sqrt(v_min) + mu_max + z.mult*sqrt(v_max))/2
          
          dmem = obs_cells[which(obs_cells%in%min_ps)]
        }
        
        varek = varek_new
        vsce = c(vsce, varek)
        muhate = c(muhate, muhatek)
        domain.ord_em = c(domain.ord_em, dmem)
        
        if (length(nbors_k_mins) == 0) {
          lwrek = -1e+5
          uppek = muhatek + z.mult*sqrt(varek)
        } else if (length(nbors_k_maxs) == 0) {
          uppek = 1e+5
          lwrek = muhatek - z.mult*sqrt(varek)
        } else if (length(nbors_k_mins) > 0 & length(nbors_k_maxs) > 0) {
          uppek = muhatek + z.mult*sqrt(varek)
          lwrek = muhatek - z.mult*sqrt(varek)
        }
        
        uppe = c(uppe, uppek)
        lwre = c(lwre, lwrek)
        
      } else {
        #no usable neighbour: repeat the values imputed for the previous cell
        #need to test more
        if(!is.null(rev(domain.ord_em)[1])){
          domain.ord_em = c(domain.ord_em, rev(domain.ord_em)[1])
        }
        if (!is.null(muhat)) {
          muhate = c(muhate, rev(muhate)[1])
        }
        if (!is.null(vsc)) {
          vsce = c(vsce, rev(vsce)[1])
          lwre = c(lwre, rev(lwre)[1])
          uppe = c(uppe, rev(uppe)[1])
        }
      }
    }
    
    #scatter observed and imputed values back to the full set of M cells
    domain.ord_all[obs_cells] = domain.ord
    domain.ord_all[empty_cells] = domain.ord_em
    domain.ord = domain.ord_all
    
    if (!is.null(muhat)) {
      muhat_all[obs_cells] = muhat
      muhat_all[empty_cells] = muhate
      muhat = muhat_all
    }
    
    if (!is.null(vsc)){
      lwr_all[obs_cells] = lwr
      lwr_all[empty_cells] = lwre
      upp_all[obs_cells] = upp
      upp_all[empty_cells] = uppe
      vsc_all[obs_cells] = vsc
      vsc_all[empty_cells] = vsce
      
      upp = upp_all
      lwr = lwr_all
      vsc = vsc_all
    }
  }
  
  rslt = list(muhat = muhat, lwr = lwr, upp = upp, domain.ord = domain.ord, vsc = vsc)
  return(rslt)
}

###############################
#subroutine to find nbors (>=2D)
###############################
fn_nbors = function(empty_grids, grid2, mins, maxs){
  # Find the axis neighbours of every empty cell on a multi-dimensional grid.
  #
  # empty_grids  one row per empty cell, one column per grid dimension.
  # grid2        the full grid; only its number of columns is used.
  # mins, maxs   smallest / largest coordinate per dimension; neighbours
  #              outside these bounds are discarded.
  #
  # Returns a list with one matrix per empty cell whose rows are the
  # coordinates of its neighbours (coordinate +/- 1 along one dimension).
  # Empty cells that are adjacent to each other are grouped: every member of
  # a group receives the pooled neighbours of the group with the empty cells
  # themselves removed.
  nr <- nrow(empty_grids)
  nc <- ncol(grid2)
  nb_lst <- vector("list", length = nr)
  to_merge <- list()
  iter <- 1

  # TRUE for each row of empty_grids that coincides with the point `pt`
  matches_empty <- function(pt) {
    apply(empty_grids, 1, function(e) all(e == pt))
  }

  # unit steps in the order (-1, +1) for dimension 1, then dimension 2, ...
  steps <- diag(nc)[rep(seq_len(nc), each = 2), , drop = FALSE] *
    rep(c(-1, 1), times = nc)

  for (k in seq_len(nr)) {
    ptk <- as.numeric(empty_grids[k, ])
    nbors_k <- steps + matrix(ptk, nrow = 2 * nc, ncol = nc, byrow = TRUE)
    rm_id <- which(apply(nbors_k, 1, function(e) any(e < mins | e > maxs)))
    if (length(rm_id) > 0) {
      nbors_k <- nbors_k[-rm_id, , drop = FALSE]
    }
    # label the cells whose neighbours should be merged later: a neighbour
    # that is itself empty ties cell k to that empty cell
    for (h in seq_len(nrow(nbors_k))) {
      hits <- matches_empty(nbors_k[h, ])
      if (any(hits)) {
        to_merge[[iter]] <- sort(c(k, as.numeric(which(hits))))
        iter <- iter + 1
      }
    }
    nb_lst[[k]] <- nbors_k
  }

  # grow the pairs into groups: a later set that shares a member with set i
  # is absorbed into it (single forward pass, as before)
  to_merge <- unique(to_merge)
  nm <- length(to_merge)
  if (nm > 1) {
    for (i in seq_len(nm - 1)) {
      to_merge_i <- to_merge[[i]]
      vec <- to_merge_i
      vec_ps <- i
      for (j in (i + 1):nm) {
        to_merge_j <- to_merge[[j]]
        if (any(to_merge_j %in% to_merge_i)) {
          vec <- sort(unique(c(vec, to_merge_j)))
          vec_ps <- c(vec_ps, j)
        }
      }
      if (length(vec_ps) > 1) {
        to_merge[vec_ps] <- rep(list(vec), length(vec_ps))
      }
    }
    to_merge <- unique(to_merge)
  }

  if (nm > 0) {
    for (grp in to_merge) {
      # pool the neighbours of the group, then drop the empty cells themselves
      nbors_ps <- unique(do.call(rbind, nb_lst[grp]))
      is_empty <- vapply(seq_len(nrow(nbors_ps)),
                         function(h) any(matches_empty(nbors_ps[h, ])),
                         logical(1))
      rm_id2 <- which(is_empty)
      if (length(rm_id2) > 0) {
        nbors_ps <- nbors_ps[-rm_id2, , drop = FALSE]
      }
      nb_lst[grp] <- rep(list(nbors_ps), length(grp))
    }
  }
  return (nb_lst)
}

####################################################################################
#compute the variance for n=2...10 cells
#keep muhat, only compute weighted variance
#not used anymore; use coneA to project lwr and upp to handle small cells
####################################################################################
# impute_em3 = function(small_cells, M, yvec, Nd, nd, Nhat, w, domain.ord, grid2, 
#                       Ds=NULL, sh=NULL, muhat=NULL, lwr=NULL, upp=NULL, 
#                       vsc=NULL, new_obs_cells=NULL, amat_0=NULL)
# {
#   ne = length(small_cells)
#   #observed sample size for n=2...10 cells
#   #the `small_cells` are not empty
#   nd_ne = nd[small_cells]
#   #vsc_ne = vsc[which(new_obs_cells%in%small_cells)]
#   vsc_ne = vsc[small_cells]
#   
#   lc = rc = 1:ne*0
#   if (Nd >= 2) {
#     vsc_all = lwr_all = upp_all = muhat_all = 1:M*0
#     yvec_all = w_all = domain.ord_all = 1:M*0
#     empty_grids = grid2[small_cells, ]
#     maxs = apply(grid2, 2, max)
#     mins = apply(grid2, 2, min)
#     at_max = (empty_grids == maxs)
#     at_min = (empty_grids == mins)
#     
#     vsce = lwre = uppe = muhate = domain.ord_em = NULL
#     ye = we = NULL
#     nslst = pslst = vector('list', 2)
#     nbors = NULL
#     nbors_lst_0 = fn_nbors_2(small_cells, amat_0, muhat)
#     for(k in 1:length(small_cells)){
#       #new: if is new because of rm_id2
#       #nbors_k_maxs = nbors_lst_0$nb_lst_maxs[[k]]
#       #nbors_k_mins = nbors_lst_0$nb_lst_mins[[k]]
#       max_ps = (nbors_lst_0$nb_lst_maxs[[k]])[1]
#       min_ps = (nbors_lst_0$nb_lst_mins[[k]])[1]
#       sm_k = small_cells[k]
#       
#       if(length(max_ps) > 0){      
#         mu_max = muhat[max_ps]
#         v_max = vsc[max_ps]
#       }
#       if(length(min_ps) > 0){
#         mu_min = muhat[min_ps]
#         v_min = vsc[min_ps]
#       }
#       #new:
#       vsc_nek = vsc_ne[k]
#       if(length(min_ps) > 0 & length(max_ps) > 0){
#         varek_new = (mu_max + 2*sqrt(v_max) - mu_min + 2*sqrt(v_min))^2 / 16 
#       }else{
#         varek_new = vsc_nek
#       }
#       #the variance by survey
#       #the if else can be replaced by vectorized computation?
#       vsce = c(vsce, varek_new)
#       #print(c(k, length(vsce)))
#     }
#     
#     if (!is.null(vsc)){
#       vsc_all = vsc
#       vsc_all[small_cells] = vsce
#       vsc = vsc_all  
#     }
#   }
#   #rslt = list(lwr = lwr, upp = upp, vsc = vsc)
#   rslt = list(vsc = vsc)
#   return(rslt)
# }

###############################
#subroutine to find nbors (>=2D)
###############################
fn_nbors_2 = function(small_cells, amat, muhat){
  # For every cell in `small_cells`, pick its bounding neighbours from the
  # constraint matrix `amat`.
  #
  # A row of amat with a non-zero entry in the cell's column ties the cell to
  # the other non-zero column(s) of that row. The sign of the cell's own
  # entry decides the role of the neighbour: +1 -> candidate lower bound
  # ("min"), -1 -> candidate upper bound ("max"). Among the candidates the
  # neighbour(s) with the smallest (min) / largest (max) muhat are kept.
  # NOTE(review): signs are matched to neighbours by position, which assumes
  # one distinct neighbour per constraint row -- confirm for amat with more
  # than two non-zero entries per row.
  #
  # Returns list(nb_lst_maxs, nb_lst_mins), each a list with one integer
  # vector (possibly empty) per element of small_cells.
  nr <- length(small_cells)
  nb_lst_mins <- nb_lst_maxs <- vector("list", length = nr)
  for (k in seq_len(nr)) {
    ptk <- small_cells[k]
    col_ptk <- amat[, ptk]
    rows_ps <- which(col_ptk != 0)
    amat_sub <- amat[rows_ps, , drop = FALSE]
    nbors_k <- apply(amat_sub, 1, function(.x) which(.x != 0))
    nbors_k <- unique(as.vector(nbors_k))
    # logical mask instead of -which(): safe when nothing matches
    nbors_k <- nbors_k[nbors_k != ptk]
    #1: min candidates; -1: max candidates
    sgn_nbors_k <- col_ptk[rows_ps]
    nbors_k_mins <- nbors_k[sgn_nbors_k == 1]
    nbors_k_maxs <- nbors_k[sgn_nbors_k == -1]
    nbors_k_min <- nbors_k_max <- integer(0L)
    if (length(nbors_k_maxs) > 0) {
      mu_cand <- muhat[nbors_k_maxs]
      nbors_k_max <- nbors_k_maxs[which(mu_cand == max(mu_cand))]
    }
    if (length(nbors_k_mins) > 0) {
      mu_cand <- muhat[nbors_k_mins]
      nbors_k_min <- nbors_k_mins[which(mu_cand == min(mu_cand))]
    }

    nb_lst_mins[[k]] <- nbors_k_min
    nb_lst_maxs[[k]] <- nbors_k_max
  }

  rslt <- list(nb_lst_maxs = nb_lst_maxs, nb_lst_mins = nb_lst_mins)
  return (rslt)
}

##############################################################
#subroutine to find nbors (>=2D)
#new neighbor routine for >= 2D empty cells
##############################################################
# For every empty cell, find the non-empty cells that share a constraint row
# with it in `amat_0` and keep the extreme ones.
#
# Arguments:
#   empty_cells  integer vector of cell indices (columns of `amat_0`).
#   amat_0       constraint matrix; a non-zero entry in column j of a row
#                means cell j takes part in that constraint.
#   muhat        numeric vector of cell estimates, indexed by cell.
#
# Value: list(nb_lst_maxs, nb_lst_mins). Each component is a list with one
#   entry per element of `empty_cells`:
#   * nb_lst_mins[[k]]: among the non-empty neighbours whose shared row has +1
#     in the cell's own column, the one(s) with the smallest `muhat`;
#   * nb_lst_maxs[[k]]: among the non-empty neighbours whose shared row has -1
#     in the cell's own column, the one(s) with the largest `muhat`.
#   An entry is integer(0) when there is no such neighbour.
#
# Empty cells are removed from the candidate sets before the extremes are
# selected, so a selected neighbour can never itself be an empty cell; no
# merging of neighbour sets between empty cells is therefore required.
fn_nbors_3 <- function(empty_cells, amat_0, muhat) {
  nr <- length(empty_cells)
  nb_lst_mins <- nb_lst_maxs <- vector("list", length = nr)
  # seq_len() (not 1:nr) so that an empty `empty_cells` returns empty lists
  for (k in seq_len(nr)) {
    ptk <- empty_cells[k]
    col_ptk <- amat_0[, ptk]
    # constraint rows in which the cell takes part
    rows_ps <- which(col_ptk != 0)
    amat_sub <- amat_0[rows_ps, , drop = FALSE]
    # all cells appearing in those rows, in row order, without the cell itself
    nbors_k <- apply(amat_sub, 1, function(.x) which(.x != 0))
    nbors_k <- unique(as.vector(nbors_k))
    nbors_k <- nbors_k[nbors_k != ptk]
    # sign of the cell's own coefficient in each of its rows
    # 1: min candidates; -1: max candidates
    # NOTE(review): this pairs the i-th neighbour with the i-th row, i.e. it
    # assumes every row contributes exactly one distinct neighbour -- confirm.
    sgn_nbors_k <- col_ptk[rows_ps]
    # other empty cells carry no information and are dropped as candidates
    nbors_k_mins <- setdiff(nbors_k[sgn_nbors_k == 1], empty_cells)
    nbors_k_maxs <- setdiff(nbors_k[sgn_nbors_k == -1], empty_cells)

    if (length(nbors_k_maxs) > 0) {
      mu_maxs <- muhat[nbors_k_maxs]
      nbors_k_maxs <- nbors_k_maxs[which(mu_maxs == max(mu_maxs))]
    }
    if (length(nbors_k_mins) > 0) {
      mu_mins <- muhat[nbors_k_mins]
      nbors_k_mins <- nbors_k_mins[which(mu_mins == min(mu_mins))]
    }

    nb_lst_mins[[k]] <- nbors_k_mins
    nb_lst_maxs[[k]] <- nbors_k_maxs
  }

  list(nb_lst_maxs = nb_lst_maxs, nb_lst_mins = nb_lst_mins)
}


############
#fitted.csvy
############
# fitted() method for csvy objects: returns the `muhat` component of the fit.
# Extra arguments in `...` are accepted for S3 compatibility and ignored.
fitted.csvy <- function(object, ...) {
  object$muhat
}

##############
#confint.csvy
##############
# confint() method for csvy objects.
#
# Returns the interval bounds stored in the fitted object (`object$lwr` and
# `object$upp`), optionally mapped to the response scale.
#
# Arguments:
#   object  fitted object; must contain `lwr` and `upp`.
#   parm    currently ignored (a warning is issued when supplied).
#   level   confidence level used to LABEL the two columns and stored in the
#           "conf.level" attribute. The bounds themselves are taken as stored
#           in `object`; they are not recomputed for `level`.
#   type    "link" returns the stored bounds unchanged; "response" applies
#           the inverse link of family(object) to both bounds.
#
# Value: data frame with a lower and an upper column, plus attribute
#   "conf.level".
confint.csvy <- function(object, parm = NULL, level = 0.95, type = c("link", "response"), ...) {
  type <- match.arg(type)

  # Warn if parm is specified but not used
  if (!is.null(parm)) {
    warning("'parm' argument is currently ignored and all parameters are returned.")
  }

  # Validate required components
  if (is.null(object$lwr) || is.null(object$upp)) {
    stop("The object must contain 'lwr' and 'upp' components for confidence interval computation.")
  }

  # A level outside [0, 1] would only produce meaningless column labels
  if (!is.numeric(level) || length(level) != 1L || is.na(level) ||
      level < 0 || level > 1) {
    stop("'level' must be a single number between 0 and 1.")
  }

  # Extract or transform bounds; the family is only needed on the response scale
  lwr <- object$lwr
  upp <- object$upp
  if (type == "response") {
    linkinv <- family(object)$linkinv
    lwr <- linkinv(lwr)
    upp <- linkinv(upp)
  }

  # Construct output
  ci_bands <- as.data.frame(cbind(lwr, upp))

  # Label the columns with confidence level
  alpha <- (1 - level) / 2
  colnames(ci_bands) <- paste0(format(c(alpha, 1 - alpha) * 100, digits = 3), "%")

  # Attach confidence level as attribute
  attr(ci_bands, "conf.level") <- level

  ci_bands
}


###############
#print method
###############
# print() method for csvy objects: prints the stored survey design (without
# variable names or design summaries) and then hands over to the next print
# method in the object's class chain.
print.csvy <- function(x, ...) {
  design <- x$survey.design
  print(design, varnames = FALSE, design.summaries = FALSE, ...)
  NextMethod()
}

###################################################################################
#extract mixture variance-covariance matrix
###################################################################################
# vcov() method for csvy objects: returns `object$acov` with dimnames set.
#
# When every entry of `object$nd` is positive, rows and columns are labelled
# with names(coef(object)). Otherwise (some cell has nd <= 0) they are
# labelled "1", "2", ...; the original note for this case reads: with empty
# cells, off-diagonal values are not imputed.
vcov.csvy <- function(object, ...) {
  v <- object$acov
  if (all(object$nd > 0)) {
    nms <- names(coef(object))
  } else {
    # seq_len() keeps this valid even for a 0 x 0 matrix
    nms <- as.character(seq_len(dim(v)[1]))
  }
  dimnames(v) <- list(nms, nms)
  v
}

#############################################
#summary for csvy
#I don't remember why I changed this summary function
#use the one in _0517 instead
#############################################
# summary.csvy = function(object, ...) {
#   family = object$family
#   CIC = object$CIC
#   CIC.un = object$CIC.un
#   deviance = object$deviance
#   null.deviance = object$null.deviance
#   bval = object$bval
#   pval = object$pval
#   # edf
#   df.residual = object$df.residual
#   df.null = object$df.null
#   edf = object$edf
#   tms = object$terms
#   rslt = NULL
#   # learn from summary.svyglm
#   # not work for gaussian; svyby
#   #dispersion = svyvar(resid(object,"pearson"), object$survey.design, na.rm = TRUE)
#   #disperson = survey:::summary.svrepglm(object)$disperson
#   if(inherits(object, "svrepglm")){
#     presid = resid(object,"pearson")
#     dispersion = sum(object$survey.design$pweights*presid^2, na.rm = TRUE) / sum(object$survey.design$pweights)
#   } else if (inherits(object, "svyglm")){
#     dispersion = svyvar(resid(object, "pearson"), object$survey.design, na.rm = TRUE)
#   }
#   if (!is.null(pval)) {
#     # rslt = data.frame("edf" = round(resid_df_obs, 4), "mixture of Beta" = round(bval, 4), "p.value" = round(pval, 4))
#     # same as cgam:
#     rslt = data.frame("edf" = round(edf, 4), "mixture of Beta" = round(bval, 4), "p.value" = round(pval, 4))
#     # check?
#     rownames(rslt) = rev((attributes(tms)$term.labels))[1]
#     structure(list(
#       call = object$call, coefficients = rslt, CIC = CIC, CIC.un = CIC.un, df.null = df.null, df.residual = df.residual,
#       null.deviance = null.deviance, deviance = deviance, family = family, dispersion = dispersion
#     ), class = "summary.csvy")
#   } else {
#     structure(list(
#       call = object$call, CIC = CIC, CIC.un = CIC.un, df.null = df.null, df.residual = df.residual,
#       null.deviance = null.deviance, deviance = deviance, family = family, dispersion = dispersion
#     ), class = "summary.csvy")
#     # warning("summary function for a csvy object only works when test = TRUE!")
#   }
# }


# summary() method for csvy objects.
#
# Collects call, CIC, CIC.un, df.null, df.residual, null.deviance, deviance
# and family from the fit into an object of class "summary.csvy". When the
# fit carries a p-value (`object$pval` is not NULL) a one-row `coefficients`
# data frame (edf, mixture of Beta, p.value; rounded to 4 digits) is inserted
# right after `call`; its row name is the last term label of `object$terms`.
summary.csvy <- function(object, ...) {
  # leading part of the result; `coefficients` is appended only with a test
  head_parts <- list(call = object$call)

  pval <- object$pval
  if (!is.null(pval)) {
    # same layout as cgam
    coef_tab <- data.frame(
      "edf" = round(object$edf, 4),
      "mixture of Beta" = round(object$bval, 4),
      "p.value" = round(pval, 4)
    )
    rownames(coef_tab) <- rev((attributes(object$terms)$term.labels))[1]
    head_parts <- c(head_parts, list(coefficients = coef_tab))
  }

  # components present in both shapes of the summary, in output order
  tail_parts <- list(
    CIC = object$CIC,
    CIC.un = object$CIC.un,
    df.null = object$df.null,
    df.residual = object$df.residual,
    null.deviance = object$null.deviance,
    deviance = object$deviance,
    family = object$family
  )

  structure(c(head_parts, tail_parts), class = "summary.csvy")
}

#####################
#print.summary.csvy#
#####################
# print() method for "summary.csvy" objects.
#
# Prints the call, the null and residual deviance with their degrees of
# freedom, the coefficient table (only when `x$coefficients` is present) and
# the CIC value(s) (only when `x$CIC` is present). Returns `x` invisibly.
print.summary.csvy <- function(x, ...) {
  # one "<label> <deviance>  on <df>  degrees of freedom" line
  show_deviance <- function(label, dev, df) {
    cat(label, round(dev, 4), "", "on", df, "", "degrees of freedom", "\n")
  }

  cat("Call:\n")
  print(x$call)
  cat("\n")
  show_deviance("Null deviance: ", x$null.deviance, x$df.null)
  show_deviance("Residual deviance: ", x$deviance, x$df.residual)

  coef_tab <- x$coefficients
  if (!is.null(coef_tab)) {
    cat("\nApproximate significance of constrained fit: \n")
    printCoefmat(coef_tab, P.values = TRUE, has.Pvalue = TRUE)
  }

  if (!is.null(x$CIC)) {
    cat("CIC (constrained estimator): ", round(x$CIC, 4))
    cat("\n")
    # the unconstrained value is only reported next to the constrained one
    if (!is.null(x$CIC.un)) {
      cat("CIC (unconstrained estimator): ", round(x$CIC.un, 4), "\n")
    }
  }
  invisible(x)
}


##############
#predict.csvy#
##############
# predict() method for csvy objects.
#
# Arguments:
#   object   fitted object. Components read here include grid, xnms_add,
#            etahat, vsc_mix, lwr, upp, n.mix and family.
#   newdata  optional data frame of domains to predict for. It must contain
#            every predictor named in object$xnms_add; other columns are
#            ignored. Defaults to the full grid stored in the object.
#   type     "link" or "response" (inverse link applied to fit and bounds;
#            the standard error is scaled by |mu.eta(fit)|).
#   se.fit   if FALSE only list(fit = ) is returned.
#   level    confidence level. It is only used when object$n.mix == 0, where
#            the intervals are recomputed here; otherwise the bounds stored
#            in the object are returned as they are.
#   n.mix    number of simulated draws used for the mixture covariance when
#            object$n.mix == 0.
#
# Value: list(fit, lwr, upp, se.fit), or list(fit) when se.fit = FALSE.
#   Rows of `newdata` that do not match any row of the grid are dropped with
#   a warning; an error is raised when no row matches.
predict.csvy = function(object, newdata=NULL, type=c("link","response"), 
                        se.fit=TRUE, level=0.95, n.mix=100,...)
{
  type <- match.arg(type)
  if (!inherits(object, "csvy")) { 
    warning("calling predict.csvy(<fake-csvy-object>) ...")
  }
  #match newdata with grid
  xnms_add <- object$xnms_add
  grid2 <- object$grid
  colnames(grid2) <- xnms_add
  if (missing(newdata) || is.null(newdata)) {
    newdata <- grid2
  } else {
    #rows in new data must exist in grid2
    if (!is.data.frame(newdata)) {
      stop ("newdata must be a data frame!")	
    } 
    #report predictors that are really absent from newdata
    missing_vars <- setdiff(xnms_add, colnames(newdata))
    if (length(missing_vars) > 0) {
      stop("newdata is missing the following required predictor(s): ", paste(missing_vars, collapse = ", "))
    }
    #keep only the model predictors, in the column order of the grid, in case
    #the user doesn't supply the variables in the order used in the formula
    newdata <- newdata[, xnms_add, drop = FALSE]
  }
  
  ps <- match_rows(newdata, grid2)
  #identify rows with undefined predictions
  na_rows <- which(is.na(ps))
  if (all(is.na(ps))) {
    stop("None of the domains in 'newdata' are defined in the model.")
  } else if (length(na_rows) > 0) {
    removed_rows <- newdata[na_rows, , drop = FALSE]
    warning("The following rows were removed because their domains are not defined in the model:\n",
            paste(capture.output(print(removed_rows)), collapse = "\n"))
    newdata <- newdata[-na_rows, , drop = FALSE]
    ps <- ps[-na_rows]
  }
  
  #if object$n.mix = 0, then simulate vsc_mix, lwr, upp below
  #otherwise, mixture covariance is already computed in the object, just use vsc_mix, lwr, upp
  etahat <- object$etahat 
  vsc_mix <- object$vsc_mix 
  lwr <- object$lwr
  upp <- object$upp
  if(!se.fit) {
    fit = etahat[ps]
    fit = switch(type, response = family(object)$linkinv(fit), link = fit)
    return (list(fit = fit))
  } else {
    use_parallel = isTRUE(getOption("csurvey.multicore", FALSE))
    #decide how many cores to use
    cores = 1
    if(use_parallel){
      cores = getOption("csurvey.cores", parallel::detectCores(logical = FALSE) - 1)
      #detectCores() may return NA; fall back to a single core in that case
      cores = max(1, cores, na.rm = TRUE)
    }
    #detect OS platform
    is_windows = .Platform$OS.type == "windows"
    if(object$n.mix == 0) {
      c(ynm, amat, nd, muhat, etahatu, w, v1, domain.ord, xvs2, xm.red, ne, zeros_ps, obs_cells, zeros, Nhat, Ds, sh, amat_0, Nd) %<-% list(
        object$ynm, object$amat, object$nd, object$muhat, object$etahatu, object$w, object$cov.un, object$domain.ord, object$xvs2, object$xm.red, object$ne,
        object$empty_cells, object$obs_cells, object$zeros, object$Nhat, object$Ds, object$shapes_add, object$amat_0, object$Nd)
      
      M = length(nd)
      empty_cells = zeros_ps
      #M excludes empty cells from here on
      M = M - zeros
      #etahat returned from object has been imputed
      if(length(empty_cells) > 0){
        etahat = etahat[-empty_cells]
        #new: domain.ord has been imputed
        domain.ord = domain.ord[-empty_cells]
      }
      
      #get the constrained variance-covariance matrix
      dp = -t(amat)
      #scale every column to unit length
      dp = apply(dp, 2, function(e) e / (sum(e^2))^(.5))
      
      m_acc = M 
      sector = NULL
      #binomial can use mvrnorm?
      ysims = rbind(MASS::mvrnorm(n.mix, mu=etahat, Sigma=v1))
      if(use_parallel && cores > 1){
        if(is_windows){
          # Windows: use parLapply with a PSOCK cluster
          cl = parallel::makeCluster(cores)
          # ensure cleanup without discarding handlers registered earlier
          on.exit(parallel::stopCluster(cl), add = TRUE)
          mixture_rslt = parallel::parLapply(cl, 1:n.mix, compute_mixture, ysims=ysims, amat=amat, w=w)
        } else {
          # Unix/macOS: use mclapply with the requested number of cores
          mixture_rslt = parallel::mclapply(1:n.mix, compute_mixture, ysims=ysims, amat=amat, w=w, mc.cores = cores)
        }
      }else{
        mixture_rslt = lapply(1:n.mix, compute_mixture, ysims=ysims, amat=amat, w=w)
      }
      
      #tabulate how often each distinct row of results occurred
      mixture_rslt <- do.call(rbind, mixture_rslt) |> as.data.frame()
      bsec = aggregate(list(freq = rep(1, nrow(mixture_rslt))), mixture_rslt, length) |>
        mutate(prob = freq / sum(freq)) |>
        dplyr::select(freq, prob, everything())
      
      sector = bsec[, -c(1, 2), drop = FALSE]
      nsec = NROW(sector)
      imat = diag(M)
      if(nsec > 0){
        #probability-weighted sum of the projected covariance over sectors
        acov = matrix(0, nrow=M, ncol=M)
        wtinv = diag(1/w)
        for(isec in 1:nsec){
          jvec = sector[isec, ]
          if(all(jvec == 0)){
            next 
          } else {
            smat = dp[, which(jvec == 1), drop = FALSE]
            wtinv_sm = wtinv %*% smat
            pmat_is_p = wtinv_sm %*% solve(t(smat) %*% wtinv_sm, t(smat))
            pmat_is = (imat - pmat_is_p)
            p_j = bsec[isec,2]
            acov = acov + p_j * pmat_is%*%v1%*%t(pmat_is)
          }
        }
      } else {
        acov = v1
      }
      
      z.mult = qnorm((1 - level)/2, lower.tail=FALSE)
      vsc = diag(acov)
      hl = z.mult*sqrt(vsc)
      lwr = etahat - hl
      upp = etahat + hl
      
      #add unconstr: in case the level in predict is different from the 
      #level in csvy?
      vscu = diag(v1)
      hlu = z.mult*sqrt(vscu)
      lwru = etahatu - hlu
      uppu = etahatu + hlu
      
      grid2 = expand.grid(xvs2, KEEP.OUT.ATTRS = FALSE, stringsAsFactors = FALSE)
      if(ne >= 1){
        #new: if there's empty cells, just augment muhatu to include NA's
        etahatu_all = 1:(M+zeros)*0
        etahatu_all[zeros_ps] = NA
        etahatu_all[obs_cells] = etahatu
        etahatu = etahatu_all
        
        lwru_all = 1:(M+zeros)*0
        lwru_all[zeros_ps] = NA
        lwru_all[obs_cells] = lwru
        lwru = lwru_all
        
        uppu_all = 1:(M+zeros)*0
        uppu_all[zeros_ps] = NA
        uppu_all[obs_cells] = uppu
        uppu = uppu_all
        
        #for monotone only: to find the c.i. for empty/1 cells by using smallest L and largest U of neighbors
        #new: create grid2 again because the routines to do imputation need to have each x start from 1
        xvs2 <- lapply(xm.red, function(x) sort(unique(x)))
        grid2 = expand.grid(xvs2, KEEP.OUT.ATTRS = FALSE, stringsAsFactors = FALSE)
        #revise later: replace lwru with lwr etc
        ans_im = impute_em2(empty_cells, obs_cells, M=(M+zeros), etahatu, Nd, nd, Nhat, w, 
                            domain.ord, grid2, Ds, sh, etahat, lwr=lwr, upp=upp, vsc=vsc, amat_0, level=level)
        etahat = ans_im$muhat
        vsc = ans_im$vsc
        domain.ord = ans_im$domain.ord
      }
      
      vsc_mix = vsc
      
      hl = z.mult*sqrt(vsc_mix)
      lwr = etahat - hl
      upp = etahat + hl
      
      #new: check monotonicity of lwr and upp
      sh_0 = sh
      #relabel sh == 1 and 9?
      if(any(sh == 9)){
        sh_0[which(sh == 9)] = 1
      }
      #remove sh = 0, there is no amat for sh = 0
      Ds_0 = Ds
      if(any(sh == 0)){
        rm_id = which(sh == 0)
        sh_0 = sh_0[-rm_id]
        Ds_0 = Ds[-rm_id]
      }
      
      check_ps_lwr = round(c(amat_0 %*% lwr), 4) 
      check_ps_upp = round(c(amat_0 %*% upp), 4) 
      not_monotone_ci = any(check_ps_lwr < 0) | any(check_ps_upp < 0)
      #weights for the projections below; zero weights are replaced by a
      #small positive value
      wvec = nd/sum(nd)
      if(any(wvec == 0)){
        wvec[which(wvec == 0)] = pmin(1e-5, min(wvec[wvec > 0]))
      }
      
      vsc_mix_0 = vsc_mix
      
      #new: project etahat when it violates the constraints in amat_0
      check_ps_etahat = round(c(amat_0 %*% etahat), 4) 
      if (length(sh_0) > 1 & n.mix > 0 & any(check_ps_etahat < 0)) {
        etahat = coneA(etahat, amat_0, w = wvec)$thetahat
      }
      
      #handles small domains automatically
      if (length(sh_0) >= 1 & n.mix > 0 & not_monotone_ci) {
        lwr = coneA(lwr, amat_0, w = wvec)$thetahat
        upp = coneA(upp, amat_0, w = wvec)$thetahat
        #variance implied by the projected bounds
        vsc_mix_0 = ((upp - lwr) / (2*z.mult))^2
      }
      #new: should let vsc_mix = vsc_mix_0 
      vsc_mix = vsc_mix_0 
    }
    fit = etahat[ps]
    se.fit = sqrt(vsc_mix[ps])
    lwr = lwr[ps]
    upp = upp[ps]
    
    lwr = switch(type, response = object$family$linkinv(lwr), link = lwr)
    upp = switch(type, response = object$family$linkinv(upp), link = upp)
    #standard error on the response scale: se * |d mu / d eta|
    se.fit <- switch(type, response = se.fit * abs(object$family$mu.eta(fit)), link = se.fit)
    fit = switch(type, response = object$family$linkinv(fit), link = fit)
    return (list(fit = fit, lwr = lwr, upp = upp, se.fit = se.fit))
  }
}

#---------------------------------------------------------------------------------------------------------------------------------------
#routines modified from methods of svyby or svyglm, svyrepglm
#---------------------------------------------------------------------------------------------------------------------------------------
# dotchart.csvy <- function(x, ..., pch = 19) {
#   xx <- x$ans.unc_cp
#   args <- list(...)
#   if (is.null(args$pch)) args$pch <- pch
#   do.call(survey::dotchart, c(list(xx), args))
# }

#---------------------------------------------------------------------------------------------------------------------------------------
# deff() method for csvy objects: forwards to survey::deff() on the
# `ans.unc_cp` component of the fit, and stops when that component is NULL.
deff.csvy <- function(object, ...) {
  unc <- object$ans.unc_cp
  if (is.null(unc)) {
    stop("Cannot compute design effect: object$ans.unc_cp is NULL.")
  }
  survey::deff(unc, ...)
}

#---------------------------------------------------------------------------------------------------------------------------------------
# barplot() method for csvy objects: draws a bar plot of the `ans.unc_cp`
# component of the fit.
#
# Arguments:
#   height  fitted csvy object.
#   beside  passed on to graphics::barplot(); defaults to TRUE.
#   ...     further arguments passed on to graphics::barplot().
barplot.csvy <- function(height, beside = TRUE, ...) {
  x <- height$ans.unc_cp
  # honour the caller's `beside` instead of always forcing TRUE
  graphics::barplot(x, beside = beside, ...)
}

#plot.csvy = barplot.csvy
#---------------------------------------------------------------------------------------------------------------------------------------
# SE() method for csvy objects: forwards to survey::SE() on the `ans.unc_cp`
# component of the fit.
SE.csvy <- function(object, ...) {
  survey::SE(object$ans.unc_cp, ...)
}

#---------------------------------------------------------------------------------------------------------------------------------------
# coef() method for csvy objects: forwards to stats::coef() on the
# `ans.unc_cp` component of the fit.
coef.csvy <- function(object, ...) {
  stats::coef(object$ans.unc_cp, ...)
}

#---------------------------------------------------------------------------------------------------------------------------------------
# svycontrast() method for csvy objects: forwards to survey::svycontrast()
# on the `ans.unc_cp` component of the fit, with the given contrasts.
svycontrast.csvy <- function(stat, contrasts, ...) {
  survey::svycontrast(stat$ans.unc_cp, contrasts, ...)
}

#---------------------------------------------------------------------------------------------------------------------------------------
#fix
# ftable() method for csvy objects: forwards to stats::ftable() on the
# `ans.unc_cp` component of the fit.
ftable.csvy <- function(x, ...) {
  stats::ftable(x$ans.unc_cp, ...)
}

#####################################################################
#for parallel
#####################################################################
# Package load hook: installs default values for the csurvey options without
# overriding anything the user has already set.
#   csurvey.multicore  FALSE
#   csurvey.cores      number of physical cores minus one, at least 1
.onLoad <- function(libname, pkgname) {
  op <- options()
  # detectCores() returns NA when the count is unknown; na.rm = TRUE makes
  # the default fall back to 1 instead of NA in that case
  n_physical <- parallel::detectCores(logical = FALSE)
  op.csurvey <- list(
    csurvey.multicore = FALSE,
    csurvey.cores = max(1, n_physical - 1, na.rm = TRUE)
  )
  toset <- !(names(op.csurvey) %in% names(op))
  if (any(toset)) options(op.csurvey[toset])
  invisible()
}

# Look up the option "csurvey.<name>", returning `default` when it is unset.
.get_csurvey_option <- function(name, default = NULL) {
  option_key <- paste0("csurvey.", name)
  getOption(option_key, default)
}

#---------------------------------------------------------------------------------------------------------------------------------------
# confint.csvyby = function(object, parm, level = 0.95, df = Inf,...){
#   x = object$ans.unc_cp
#   confint(x, parm, level, df,...)
# }

#---------------------------------------------------------------------------------------------------------------------------------------
# deff.csvyby = function(object,...){
#   x = object$ans.unc_cp
#   if(missing(x) | is.null(x)){
#     stop("only works when the fit is a class of svyby!")
#   }
#   deff(x,...)
# }

Try the csurvey package in your browser

Any scripts or data that you put into this service are public.

csurvey documentation built on June 8, 2025, 12:41 p.m.