## STATISTICS functions

#' Describe data
#'
#' Compute descriptive statistics for every numeric column of a data frame,
#' optionally within each level of a grouping variable. Logical columns are
#' converted to numeric (0/1); character and factor columns are dropped.
#' Statistics are computed by [psych::describe()].
#' @param x (data.frame) The data.
#' @param group (vector or NULL) Optional grouping variable with one value per row of `x`. It is converted to a factor, so its level order determines the order of the groups in the output.
#' @param all_vars (lgl scalar) If FALSE (default), only var, n, mean, median, sd, mad, min, max, skew and kurtosis are returned. If TRUE, all columns produced by [psych::describe()] are returned.
#' @return A data frame with one row per variable (per group, with a `group` column, if `group` is given).
#' @export
#' @examples
#' describe2(iris)
#' describe2(mpg)
#' describe2(iris, group = iris$Species)
describe2 = function(x, group = NULL, all_vars = FALSE) {

  #convert logical variables to numeric
  for (v in names(x)) {
    if (is.logical(x[[v]])) x[[v]] = as.numeric(x[[v]])
  }

  #remove factors and chrs
  x_nonnum = purrr::map_lgl(x, ~is.character(.) | is.factor(.))
  x = x[!x_nonnum]

  #group level: describe each group with a recursive, ungrouped call and bind the results
  if (!is.null(group)) {
    #factor for the right order of groups, if the user wants
    group = as.factor(group)

    y = plyr::ddply(dplyr::bind_cols(x, ..group = group), .variables = "..group", .fun = function(dd) {
      dd %>% dplyr::select(-`..group`) %>% describe2(all_vars = all_vars)
    }) %>% dplyr::rename(group = ..group)

    return(y)
  }

  #get results
  y = psych::describe(as.data.frame(x), na.rm = TRUE, interp = FALSE, skew = TRUE, ranges = TRUE,
                      trim = 0.1, type = 3, check = TRUE, fast = NULL, quant = NULL,
                      IQR = FALSE, omit = FALSE, data = NULL)

  #drop the psych classes so the result behaves as a plain data frame
  class(y) = "data.frame"

  #variable names as an explicit column
  y = tibble::rownames_to_column(y, var = "var")

  #subset to the main statistics unless all are wanted
  if (!all_vars) {
    y = dplyr::select(y, var, n, mean, median, sd, mad, min, max, skew, kurtosis)
  }

  y
}

#' Adjust Cohen's d for measurement error
#'
#' Adjust Cohen's d for measurement error using the Pearson r approach. Each d is
#' converted to r (r = d / sqrt(d^2 + 4)). The r is divided by sqrt(reliability)
#' and then converted back to d (d = 2r / sqrt(1 - r^2)).
#' @param x A vector of d values
#' @param rel A vector of reliabilities within [0, 1]. It is recycled against `x`.
#' @return A vector of adjusted d values.
#' - Where the reliability is 1, the d value is returned unchanged.
#' - Where it is 0, the result is `Inf` with the sign of the d value.
#' - If the disattenuated r reaches 1 in absolute value, the result is not finite (Inf/NaN).
#' @export
#' @examples
#' adj_d_reliability(1.00, 0.8)
#' adj_d_reliability(c(0.5, 1.00), c(0.8, 1))
adj_d_reliability = function(x, rel) {
  #validate first, so NA or out-of-range reliabilities give a clear error
  if (anyNA(rel) || any(rel < 0 | rel > 1)) stop("Reliability must be within [0:1]", call. = FALSE)

  #nothing to compute for empty input
  if (length(x) == 0 || length(rel) == 0) return(numeric(0))

  #recycle to a common length so the element-wise rules below line up
  len = max(length(x), length(rel))
  if (length(x) < len) x = rep_len(x, len)
  if (length(rel) < len) rel = rep_len(rel, len)

  #rel == 1: nothing to do, values are kept as they are
  out = x

  #rel == 0: cannot disattenuate, the result is infinite
  zero = rel == 0
  out[zero] = Inf * sign(x[zero])

  #0 < rel < 1: convert to r, adjust, and go back to d
  mid = rel > 0 & rel < 1
  r_adj = (x[mid] / (x[mid]^2 + 4)^0.5) / sqrt(rel[mid])
  out[mid] = 2 * r_adj / sqrt(1 - r_adj^2)

  out
}

#' Compute correlation confidence intervals from correlations and standard errors
#'
#' Normal-approximation confidence intervals: cor +/- z * se, where z is the two-sided critical value for `ci`.
#' @param cor A correlation value or matrix
#' @param se A standard error value or matrix (same length as `cor`)
#' @param ci The desired confidence level
#' @param winsorise Keep values within -1 to 1
#' @return The return value depends on the input:
#' - A single value gives 2 values back: the lower and upper confidence limits.
#' - A non-matrix vector gives all lower limits followed by all upper limits.
#' - Matrices give a list with `lower` and `upper` matrices, which keep the dimnames of `cor`.
#' @export
#' @examples
#' #simple results
#' cor_CI_from_SE(.50, .10)
#' cor_CI_from_SE(.50, .10, ci = .99)
#' #matrices
#' iris_res = suppressWarnings(weights::wtd.cor(iris[-5]))
#' cor_CI_from_SE(iris_res$correlation, iris_res$std.err)
cor_CI_from_SE = function(cor, se, ci = .95, winsorise = TRUE) {
  if (length(cor) != length(se)) stop("cor and se must have the same length!", call. = FALSE)

  #two-sided critical z value, e.g. about 1.96 for ci = .95
  z = qnorm(ci + ((1 - ci) / 2))

  #all lower bounds, then all upper bounds
  y = c(cor - z * se, cor + z * se)

  #keep bounds within the possible range of a correlation
  if (winsorise) y = pmax(pmin(y, 1), -1)

  #matrix input: restore the matrix shape (and names) for each bound
  if (is.matrix(cor)) {
    n_cells = length(cor)
    return(list(
      lower = matrix(y[seq_len(n_cells)], nrow = nrow(cor), ncol = ncol(cor), dimnames = dimnames(cor)),
      upper = matrix(y[n_cells + seq_len(n_cells)], nrow = nrow(cor), ncol = ncol(cor), dimnames = dimnames(cor))
    ))
  }

  y
}

#' Correlation matrix
#'
#' Outputs a correlation matrix. Supports weights, confidence intervals, p values, correcting for measurement error, rank order correlations, grouping and rounding.
#' Correction for measurement error is done using the standard Pearson formula: r_true = r_observed / sqrt(reliability_x * reliability_y). It is applied to the whole matrix, and the corrected values are winsorised to [-1, 1]. The reliabilities are put in the diagonal.
#' Weighted correlations are calculated using wtd.cor or wtd.cors from the weights package.
#' `rank_order` can take either a logical scalar or a character scalar:
#' - TRUE: the data are ranked with the default settings (average ranks).
#' - A chr scalar: that ranking method is used.
#' - FALSE (default): the data are not ranked.
#' Confidence intervals are analytic confidence intervals based on the standard error.
#' @param data (data.frame or coercible into data.frame) The data.
#' @param weights (numeric vector, numeric matrix/data.frame or character scalar) Weights to use for the correlations. It can be:
#' - a numeric vector with weights;
#' - the name of a variable in data;
#' - a matrix/data.frame with weights for each variable. In that case, for each pair of variables, the harmonic mean of the two weights is used for each case.
#' If none given, defaults to rep(1, nrow(data)).
#' @param by (character, factor or logical vector) Grouping variable with one value per row of data. Rows where it is `NA` are dropped with a warning. If given, a named list of matrices is returned, one per group.
#' @param reliabilities (num vector) Reliabilities used to correct for measurement error. If not present, assumed to be 1.
#' @param CI (numeric scalar) The confidence level to use as a fraction.
#' @param CI_template (character scalar) A template to use for formatting the confidence intervals. It takes the placeholders %r, %lower and %upper.
#' @param skip_nonnumeric (logical scalar) Whether to skip non-numeric variables. Defaults to TRUE.
#' @param CI_round (whole number scalar) If confidence intervals or p values are used, how many digits should be shown for the correlations and intervals?
#' @param p_val (log scalar) Add p values or not.
#' @param p_template (chr scalar) If p values are desired, the template to use. It takes the placeholders %r and %p.
#' @param p_round (int scalar) Number of digits to round p values to. Note: currently not used by the code; p values are formatted by p_to_asterisk().
#' @param asterisks The thresholds to use for p value asterisks
#' @param asterisks_only Whether to only include asterisks, not numerical p values
#' @param rank_order (lgl or chr) Whether to use rank ordered data so as to compute Spearman's correlations instead.
#' @return The result has one of two shapes:
#' - Without `by`, a correlation matrix. It is numeric when only correlations are requested, and character when CIs or p values are added.
#' - With `by`, a named list of such matrices, one per group.
#' @export
#' @examples
#' cor_matrix(iris) #just correlations
#' cor_matrix(iris, CI = .95) #with confidence intervals
#' cor_matrix(iris, CI = .99) #with 99% confidence intervals
#' cor_matrix(iris, p_val = TRUE) #with p values
#' cor_matrix(iris, p_val = TRUE, asterisks_only = FALSE, p_template = "%r (%p)") #with p values, with an alternative template
#' cor_matrix(iris, reliabilities = c(.8, .9, .7, .75)) #correct for measurement error
#' cor_matrix(iris, reliabilities = c(.8, .9, .7, .75), CI = .95) #correct for measurement error + CI
#' cor_matrix(iris, rank_order = T) #rank order correlations, default method
#' cor_matrix(iris, rank_order = "first") #rank order correlations, specific method
#' cor_matrix(iris, weights = "Petal.Width") #weights from name
#' cor_matrix(iris, weights = 1:150) #weights from vector
#' #complex weights
#' cor_matrix(iris, weights = matrix(runif(nrow(iris) * 4), nrow = nrow(iris)))
#' cor_matrix(iris, weights = matrix(runif(nrow(iris) * 4), nrow = nrow(iris)), CI = .95)
#' #groups
#' cor_matrix(iris, by = iris$Species)
#' cor_matrix(iris, by = iris$Species, weights = 1:150)
cor_matrix = function(data, weights = NULL, by = NULL, reliabilities = NULL, CI = NULL, CI_template = "%r [%lower %upper]", skip_nonnumeric = TRUE, CI_round = 2, p_val = FALSE, p_template = "%r [%p]", p_round = 3, rank_order = FALSE, asterisks = c(.01, .005, .001), asterisks_only = TRUE) {

  data = as.data.frame(data)
  if (skip_nonnumeric) data = extract_num_vars(data)
  if (!is_numeric(data)) stop("data contains non-numeric columns!")
  is_(rank_order, class = c("character", "logical"), size = 1, error_on_false = TRUE)

  #CI and p vals
  if (!is.null(CI) && p_val) stop("Cannot both calculate CIs and p values!")
  v_noextras = is.null(CI) && !p_val

  #reliabilities default to 1 (no correction)
  if (is.null(reliabilities)) {
    reliabilities = rep(1, ncol(data))
  } else if (length(reliabilities) != ncol(data)) {
    stop("reliabilities length incorrect")
  }

  #weights not given, or given as the name of a variable in data
  if (is.null(weights)) weights = rep(1, nrow(data))
  if (is.character(weights)) weights = data[[weights]] #fetch from data

  #simple weights (one per case) or complex weights (one per case per variable)?
  simpleweights = !(is.matrix(weights) || is.data.frame(weights))
  if (simpleweights) {
    weights = as.numeric(weights) #this prevents some weird data type errors!
    if (length(weights) != nrow(data)) stop("weights not the same length as the data!")
  } else if (!all(dim(weights) == dim(data))) {
    stop("weights did not fit the data!")
  }

  #grouping variable
  if (!is.null(by)) {
    assert_that(is.character(by) || is.factor(by) || is.logical(by))
    if (length(by) != nrow(data)) stop("by must have one value per row of data!", call. = FALSE)

    #drop rows where group is NA but throw warning as this might be an error
    if (anyNA(by)) {
      warning(call. = FALSE, "Grouping variable contains `NA` values, check your data")
      keep = !is.na(by)

      #error if no data remaining
      if (!any(keep)) stop("Grouping variable contains only `NA` values!", call. = FALSE)

      #subset data, grouping variable and weights together so their rows stay aligned
      data = data[keep, , drop = FALSE]
      by = by[keep]
      weights = if (simpleweights) weights[keep] else weights[keep, , drop = FALSE]
    }

    #groups and their order
    groups = levels(forcats::fct_drop(as.factor(by)))
  }

  #rank order data?
  if (is.logical(rank_order) && rank_order) data = df_rank(data)
  if (is.character(rank_order)) data = df_rank(data, ties.method = rank_order)

  #logical index of the rows in a group (all rows when group is NULL)
  group_rows = function(group = NULL) {
    if (is.null(group)) return(rep(TRUE, nrow(data)))
    rows = by %in% group
    if (!any(rows)) stop(stringr::str_glue("No data for group {group}!"))
    rows
  }

  #correct a correlation for the unreliability of variables i and j, and keep it within [-1, 1]
  correct_r = function(r, i, j) {
    winsorise(r / sqrt(reliabilities[i] * reliabilities[j]), upper = 1, lower = -1)
  }

  #apply a matrix-building function to each group (named list), or once to all data
  for_groups = function(f) {
    if (is.null(by)) return(f(NULL))
    res = purrr::map(groups, f)
    names(res) = groups
    res
  }

  #fast path: simple weights and no extras, so the whole matrix comes from a single call
  if (simpleweights && v_noextras) {
    make_cor_mat_simple = function(group = NULL) {
      rows = group_rows(group)

      #these warnings just indicate near-1 correlations
      m = suppressWarnings(weights::wtd.cor(data[rows, , drop = FALSE], weight = weights[rows]))$correlation

      #correct for unreliability; the whole matrix, so both triangles are corrected and it stays symmetric
      #(psych::correct.cor would only correct the upper triangle)
      m = m / sqrt(outer(reliabilities, reliabilities))

      #winsor to -1 to 1
      m = winsorise(m, lower = -1, upper = 1)

      #reliabilities in diagonal
      diag(m) = reliabilities
      m
    }

    return(for_groups(make_cor_mat_simple))
  }

  #general path: loop over variable pairs (CIs, p values and/or complex weights)
  make_cor_mat = function(group = NULL) {
    rows = group_rows(group)
    d = data[rows, , drop = FALSE]
    w = if (simpleweights) weights[rows] else weights[rows, , drop = FALSE]

    n_vars = ncol(d)
    m = matrix(NA, ncol = n_vars, nrow = n_vars)

    #fill in the lower triangle (col < row)
    for (row in seq_len(n_vars)) {
      for (col in seq_len(row - 1)) {
        #complex weights: for each case, the harmonic mean of the two variables' weights
        pair_w = if (simpleweights) w else psych::harmonic.mean(t(as.matrix(w[, c(row, col)])))

        x_var = d[row]
        y_var = d[col]

        #plain correlation (only reached with complex weights)
        if (v_noextras) {
          m[row, col] = correct_r(as.numeric(weights::wtd.cors(x_var, y_var, weight = pair_w)), row, col)
          next
        }

        #observed r with standard error and p value; warnings just indicate near-1 correlations
        r_obj = suppressWarnings(weights::wtd.cor(x_var, y_var, weight = pair_w))
        r = correct_r(r_obj[1], row, col)
        r_r = str_round(r, digits = CI_round)

        if (!is.null(CI)) {
          #confidence interval from the standard error, formatted into the template
          r_CI = str_round(cor_CI_from_SE(r, se = r_obj[2], ci = CI), digits = CI_round)
          m[row, col] = CI_template %>%
            stringr::str_replace("%r", r_r) %>%
            stringr::str_replace("%lower", r_CI[1]) %>%
            stringr::str_replace("%upper", r_CI[2])
        } else if (asterisks_only) {
          #r followed by significance asterisks; paste0 because `+` does not concatenate strings in R
          m[row, col] = paste0(r_r, p_to_asterisk(r_obj[, "p.value"], asterisks = asterisks, asterisks_only = TRUE))
        } else {
          m[row, col] = p_template %>%
            stringr::str_replace("%r", r_r) %>%
            stringr::str_replace("%p", p_to_asterisk(r_obj[, "p.value"], asterisks = asterisks, asterisks_only = FALSE))
        }
      }
    }

    #make symmetric by mirroring the lower triangle into the upper triangle
    m[upper.tri(m)] = t(m)[upper.tri(m)]

    rownames(m) = colnames(m) = colnames(d)

    diag(m) = reliabilities
    m
  }

  for_groups(make_cor_mat)
}

#' Remove redundant variables
#'
#' Remove redundant variables from a data.frame based on a threshold value.
#' All intercorrelations are calculated, and the pairs that correlate at or above the threshold (absolute value) are found.
#' Going from the strongest pair down, the second variable of each pair is removed.
#' Pairs involving an already removed variable are skipped, so no more variables are removed than strictly necessary.
#' @param df (data.frame) A data.frame with numeric variables.
#' @param threshold (numeric scalar) A threshold at or above which intercorrelations are removed. Defaults to .9.
#' @param cor_method (character scalar) The correlation method to use. Parameter is fed to cor(). Defaults to "pearson".
#' @param messages (boolean) Whether to print diagnostic messages.
#' @return The data.frame without the excluded variables.
#' @export
#' @examples
#' remove_redundant_vars(iris[-5]) %>% head
remove_redundant_vars = function(df, threshold = .9, cor_method = "pearson", messages = TRUE) {
  #Check input
  if (!is.data.frame(df)) {
    stop(paste0("First parameter is not a data frame. Instead it is ", paste(class(df), collapse = "/")))
  }
  if (!is.numeric(threshold)) {
    stop(paste0("Second parameter is not numeric. Instead is ", paste(class(threshold), collapse = "/")))
  }
  if (!(threshold > 0 && threshold <= 1)) stop("threshold must be 0<x<=1 !")
  if (!is.logical(messages)) stop("messages paramter was not a logical!")

  #all intercorrelations
  m = cor(df, use = "pairwise.complete.obs", method = cor_method)

  #long form of the upper triangle: every pair once, no self-correlations
  in_upper = upper.tri(m)
  m_long = data.frame(
    Var1 = rownames(m)[row(m)[in_upper]],
    Var2 = colnames(m)[col(m)[in_upper]],
    r = m[in_upper],
    stringsAsFactors = FALSE
  )

  #sort by abs r, strongest first (NA correlations last)
  m_long = m_long[order(abs(m_long$r), decreasing = TRUE), ]

  #pairs at or over the threshold; which() drops NA correlations (e.g. from constant variables)
  m_long_threshold = m_long[which(abs(m_long$r) >= threshold), ]

  #over threshold message
  if (nrow(m_long_threshold) != 0) {
    if (messages) {
      message("The following variable pairs had stronger intercorrelations than |", threshold, "|:")
      print(df_round(m_long_threshold)) #round and print
    }
  } else if (messages) {
    message("No variables needed to be excluded.")
  }

  #exclude variables
  #one cannot just remove the ones in the Var2 col because this can result in variables being removed despite them not correlating >threshold with any variable
  vars_to_exclude = character()
  while (nrow(m_long_threshold) > 0) {
    #add top var in Var2 to the exclusion vector and remove top row
    vars_to_exclude = c(vars_to_exclude, m_long_threshold$Var2[1])
    m_long_threshold = m_long_threshold[-1, ]

    #remove rows that contain any variable from the exclusion vector; they are already resolved
    exclude_rows = m_long_threshold$Var1 %in% vars_to_exclude | m_long_threshold$Var2 %in% vars_to_exclude
    m_long_threshold = m_long_threshold[!exclude_rows, ]
  }

  if (messages && length(vars_to_exclude) > 0) {
    message("The following variables were excluded:")
    message(paste(vars_to_exclude, collapse = ", "))
  }

  #return reduced df
  df[!colnames(df) %in% vars_to_exclude]
}

#' Semi-partial correlation with weights
#'
#' Returns the ordinary correlation of x and y, and the semi-partial correlation of x and y after z has been partialed out of y.
#' z is partialed out by taking the residuals of a (weighted) regression of y on z. Weights may be used.
#' @param x A numeric vector to correlate with y.
#' @param y A numeric vector to correlate with x after partialing out z.
#' @param z A numeric vector to partial out of y.
#' @param weights A numeric vector of weights to use. If none given, will return unweighted results.
#' @param complete_cases (lgl scalar) Whether to drop cases with a missing value in x, y, z or the weights first. Defaults to TRUE.
#' @return A list with `normal` (x with y) and `semi_partial` (x with the residualized y), each as returned by [weights::wtd.cor()].
#' @export
semi_par = function(x, y, z, weights = NULL, complete_cases = TRUE) {

  #we vectorize the input because otherwise we may get strange results when input is a df or matrix
  x = as.vector(x)
  y = as.vector(y)
  z = as.vector(z)

  #unit weights if none given
  if (is.null(weights)) {
    w = rep(1, length(x))
  } else {
    w = as.vector(weights)
  }

  df = data.frame(x = x, y = y, z = z, w = w)

  #complete cases only
  if (complete_cases) {
    df = df[complete.cases(df), ]
  }

  #residualize y on z; na.exclude keeps the residuals aligned with the rows of df if NAs remain
  #(weights = w refers to the w column of df, evaluated inside data)
  df$y_res = resid(lm(y ~ z, weights = w, data = df, na.action = na.exclude))
  r_sp = weights::wtd.cor(df$x, df$y_res, weight = df$w)
  r = weights::wtd.cor(df$x, df$y, weight = df$w)

  list(normal = r,
       semi_partial = r_sp)
}

#' Serial semi-partial correlations with weights
#'
#' Returns a table of semi-partial correlations. The dependent variable is first regressed on the primary predictor variable.
#' The residuals are then correlated with each of the secondary predictors in turn.
#' @param df A data frame with the variables.
#' @param dependent A string with the name of the dependent variable.
#' @param primary A string with the name of the primary predictor variable.
#' @param secondaries A character vector with the names of the secondary predictor variables.
#' @param weights A string with the name of the variable to use for weights. If none given, unit weights are used.
#' @param standardize Whether to standardize the data frame before running results. The weights variable will not be standardized.
#' @return A data frame with one row per secondary variable and two columns:
#' - "Orig. cor": the weighted correlation of the secondary variable with the dependent.
#' - "Semi-partial cor": its correlation with the residualized dependent.
#' Only complete cases are used.
#' @export
semi_par_serial = function(df, dependent, primary, secondaries, weights = NULL, standardize = TRUE) {

  #put the weights in a column with a fixed name (unit weights if none given)
  if (is.null(weights)) {
    df[["weights_var"]] = rep(1, nrow(df))
  } else {
    df[["weights_var"]] = df[[weights]] #move the weights var to another name
  }
  weights = "weights_var"

  #subset to the needed variables, complete cases only
  df = df[c(dependent, primary, secondaries, weights)]
  df = na.omit(df)

  if (standardize) df = df_standardize(df, exclude = weights)

  w = df[[weights]]

  #residualize the dependent on the primary predictor; this does not depend on the secondary variable, so fit once
  #weights = weights_var refers to the column of df, evaluated inside data
  tmp_model = stats::as.formula(stringr::str_c(dependent, " ~ ", primary))
  tmp_resids = resid(lm(tmp_model, weights = weights_var, data = df))

  #make results object
  results = data.frame(matrix(nrow = length(secondaries), ncol = 2))
  rownames(results) = secondaries
  colnames(results) = c("Orig. cor", "Semi-partial cor")

  #loop over each secondary
  for (sec_idx in seq_along(secondaries)) {
    sec_values = df[[secondaries[sec_idx]]]

    #secondary original and semi-partial correlations
    r_sec = weights::wtd.cor(df[[dependent]], sec_values, weight = w)[1]
    r_sec_sp = weights::wtd.cor(tmp_resids, sec_values, weight = w)[1]

    results[sec_idx, ] = c(r_sec, r_sec_sp)
  }

  results
}

#' Calculate a partial correlation.
#'
#' Calculates the partial correlation of x and y controlling for z.
#' Both x and y are regressed on the control variable(s) with weighted least squares, and the residuals are then correlated.
#' @param df A data frame.
#' @param x String with the name of the first variable.
#' @param y String with the name of the second variable.
#' @param z String (or character vector) with the name(s) of the control variable(s).
#' @param weights_var String with the name of the weights variable. Can be left out, in which case unit weights are used.
#' @return The result of [weights::wtd.cor()] for the two residual vectors.
#' @export
MOD_partial = function(df, x, y, z, weights_var = NULL) {

  #make or move weights into a fixed column name so lm() can find it inside df
  if (is.null(weights_var)) {
    df$weights___ = rep(1, nrow(df)) #make unit weights
  } else {
    df$weights___ = df[[weights_var]] #reassign weights var
  }

  #build models: each variable regressed on the control variable(s)
  controls = stringr::str_c(z, collapse = " + ")
  mod1 = stats::as.formula(stringr::str_c(x, " ~ ", controls))
  mod2 = stats::as.formula(stringr::str_c(y, " ~ ", controls))

  #fit models
  #na.exclude is important because otherwise NA values are removed and the residuals would no longer line up
  fit1 = lm(mod1, data = df, weights = weights___, na.action = na.exclude)
  fit2 = lm(mod2, data = df, weights = weights___, na.action = na.exclude)

  #correlate the residuals
  weights::wtd.cor(resid(fit1), resid(fit2), weight = df$weights___)
}

#' Get t value by confidence-level and degree of freedom
#'
#' Wrapper function for \code{qt} to get the t value needed for finding confidence intervals or p values.
#' The quantile is taken at conf + (1 - conf) / 2, i.e. two-sided.
#' @param conf (numeric scalar) The confidence level desired as a fraction.
#' @param df (numeric scalar) The degrees of freedom.
#' @param ... (other named arguments) Other arguments to pass to \code{qt}. See that function for details.
#' @return The t value.
#' @export
#' @examples
#' #get t value for 95 pct. confidence interval with df = 20
#' get_t_value(.95, 20)
get_t_value = function(conf, df, ...) {
  #two-sided: half of the remaining probability goes in the upper tail
  value = conf + ((1 - conf) / 2)
  qt(value, df = df, ...)
}

#' Compute width of confidence interval
#'
#' Calculate the width of a confidence interval given a standard error and sample size. Based on the t distribution with n - 1 degrees of freedom.
#' @param se (numeric scalar) The standard error.
#' @param n (numeric scalar) The sample size. Defaults to Inf, which gives the same result as the normal distribution.
#' @param confidence_level (numeric scalar) The confidence level. Defaults to .95.
#' @param single_arm (logical scalar) Whether to calculate the width for a single arm. Defaults to TRUE. If FALSE, the full width (both arms) is returned.
#' @return The width of the confidence interval.
#' @export
#' @examples
#' conf_interval_width(1) #about 1.96
#' conf_interval_width(1, confidence_level = .99) #about 2.58
conf_interval_width <- function(se, n = Inf, confidence_level = 0.95, single_arm = TRUE) {
  #use t distribution, but if no sample size is given, assume infinite degrees of freedom, same as z dist
  t_value <- qt(1 - (1 - confidence_level) / 2, df = n - 1)

  #single arm width
  width <- t_value * se

  #double it if desired
  if (!single_arm) {
    width <- 2 * width
  }

  width
}

#' Pooled sd
#'
#' Calculate the pooled standard deviation.
#' Each group's variance is weighted by its degrees of freedom (non-missing n - 1).
#' Groups with fewer than 2 non-missing values, and cases with a missing group, do not contribute.
#' @param x (numeric vector) The data.
#' @param group (vector) Group membership, same length as x.
#' @return The pooled standard deviation.
#' @export
#' @examples
#' #Wikipedia's example https://en.wikipedia.org/wiki/Pooled_variance
#' v_test_vals = c(31, 30, 29, 42, 41, 40, 39, 31, 28, 23, 22, 21, 19, 18, 21, 20, 19, 18, 17)
#' v_test_group = c(rep(1, 3), rep(2, 4), rep(3, 2), rep(4, 5), rep(5, 5))
#' pool_sd(v_test_vals, v_test_group)
pool_sd = function(x, group) {

  #validate input; as.factor() itself errors if group cannot be converted
  if (!is_simple_vector(x)) stop("x must be a vector!")
  group = as.factor(group)
  if (length(group) != length(x)) stop("x and group must have the same length!")

  #per group degrees of freedom and variance (split() leaves out cases with a missing group)
  x_by_group = split(x, group)
  group_df = vapply(x_by_group, function(v) sum(!is.na(v)) - 1, numeric(1))
  group_var = vapply(x_by_group, function(v) var(v, na.rm = TRUE), numeric(1))

  #groups without a computable variance (fewer than 2 values) are skipped
  ok = !is.na(group_var)

  #weighted sum divided by weights (with Bessel's correction), then square root
  sqrt(sum(group_df[ok] * group_var[ok]) / sum(group_df[ok]))
}

#' Standardized mean differences
#'
#' Calculate standardized mean differences between all pairs of groups. For each pair:
#' - The difference in central tendency (column group minus row group) is divided by a dispersion measure.
#' - The result is optionally corrected for measurement error.
#' - An analytic standard error, confidence interval and two-tailed p value (normal approximation) are computed.
#' @param x (numeric vector) A vector of values.
#' @param group (vector) A vector of group memberships.
#' @param central_tendency (function) A function to use for calculating the central tendency. Must support a parameter called na.rm. Ideal choices: mean, median.
#' @param dispersion (character or numeric scalar) Either the name of the metric to use (sd or mad) or a value to use.
#' @param dispersion_method (character scalar) If using one of the built-in methods for dispersion, where the dispersion comes from:
#' - "all" (default): the pooled value (sd) or the mean within-group MAD (mad) from the total dataset.
#' - "pair": the same, computed within the pairwise comparison.
#' - "total": the sd or mad of the total dataset, ignoring groups.
#' @param extended_output (lgl) Whether to output a list of matrices. Useful for computational reuse.
#' @param CI (num) Confidence interval coverage.
#' @param str_template (chr) A string template to use. It takes the placeholders %d, %lower, %upper, %n and %se.
#' @param digits (whole number scalar) Number of digits to use when filling in the string template.
#' @param reliability (num) A reliability to use for correcting for measurement error. Done via [kirkegaard::adj_d_reliability()].
#' @param se_analytic (lgl) Use analytic standard errors. Note: currently not used by the code; analytic standard errors are always computed.
#' @param ... (other arguments) Additional arguments to pass to the central tendency function.
#' @return The return value depends on `extended_output`:
#' - FALSE (default): the matrix of standardized mean differences.
#' - TRUE: a list with d, d_string, CI_lower, CI_upper, se_z, se, p and pairwise_n.
#' The lower triangles are computed directly; the upper halves are filled by MAT_half2full().
#' @export
#' @examples
#' SMD_matrix(iris$Sepal.Length, iris$Species)
#' SMD_matrix(iris$Sepal.Length, iris$Species, extended_output = T)
#' #pairwise SDs
#' SMD_matrix(iris$Sepal.Length, iris$Species, extended_output = T, dispersion_method = "pair")
SMD_matrix = function(x,
                      group,
                      central_tendency = wtd_mean,
                      dispersion = "sd",
                      dispersion_method = "all",
                      extended_output = FALSE,
                      CI = .95,
                      str_template = "%d [%lower %upper]",
                      digits = 2,
                      reliability = 1,
                      se_analytic = TRUE,
                      ...) {

  #validate dispersion choices up front, so the dispersion can never be left undefined
  if (!(is.numeric(dispersion) || dispersion %in% c("sd", "mad"))) stop("dispersion must be \"sd\", \"mad\" or a number!", call. = FALSE)
  if (!dispersion_method %in% c("all", "pair", "total")) stop("dispersion_method must be one of \"all\", \"pair\" or \"total\"!", call. = FALSE)

  #CI z score
  se_CI = qnorm(CI + (1 - CI)/2)

  #df form, remove missing data
  d_x = na.omit(data.frame(x = x, group = as.factor(group)))

  #groups, their number, and their sample sizes
  uniq = levels(d_x$group)
  n_groups = length(uniq)
  group_ns = table(d_x$group)

  #dispersion measure; "all" and "total" use the full data, "pair" only the two compared groups
  get_disp = function(d_pair) {
    if (is.numeric(dispersion)) return(dispersion) #use given number
    if (dispersion_method == "total") {
      return(if (dispersion == "sd") sd(d_x$x) else mad(d_x$x))
    }
    d_disp = if (dispersion_method == "pair") d_pair else d_x
    if (dispersion == "sd") return(pool_sd(d_disp$x, d_disp$group))
    #mad: mean of the within-group MADs, robust
    mean(vapply(split(d_disp$x, d_disp$group, drop = TRUE), mad, numeric(1)))
  }

  #make matrices for results, with group names
  m = matrix(NA, nrow = n_groups, ncol = n_groups, dimnames = list(uniq, uniq))
  CI_lower = m
  CI_upper = m
  se = m
  pairwise_n = m
  m_str = m
  pval = m

  #loop for each combo below the diagonal
  for (row_i in seq_len(n_groups)) {
    for (col_i in seq_len(row_i - 1)) {
      col = uniq[col_i]
      row = uniq[row_i]

      #partition data, removing unused factor levels
      d_comb = d_x[d_x$group %in% c(col, row), ]
      d_comb$group = forcats::fct_drop(d_comb$group)

      #difference in central tendency divided by the dispersion measure
      diff = central_tendency(d_comb$x[d_comb$group == col], ...) - central_tendency(d_comb$x[d_comb$group == row], ...)
      m[row, col] = diff / get_disp(d_comb)

      #pairwise and group sample sizes
      pairwise_n[row, col] = nrow(d_comb)
      pair_ns = as.numeric(group_ns[c(row, col)])

      #adjust for reliability?
      if (reliability != 1) {
        m[row, col] = adj_d_reliability(m[row, col], rel = reliability)
      }

      #standard error
      #note variance vs. se (hence sqrt)!
      se[row, col] = sqrt((sum(pair_ns)/prod(pair_ns)) + ((m[row, col]^2) / (2*(sum(pair_ns) - 2))))

      CI_upper[row, col] = m[row, col] + se_CI * se[row, col]
      CI_lower[row, col] = m[row, col] - se_CI * se[row, col]

      #p val, 2-tailed
      pval[row, col] = 2 * pnorm(abs(m[row, col]) / se[row, col], lower.tail = FALSE)

      m_str[row, col] = str_template %>%
        #insert estimate
        stringr::str_replace_all(pattern = "%d", replacement = str_round(m[row, col], digits = digits)) %>%
        #insert upper CI
        stringr::str_replace_all(pattern = "%upper", replacement = str_round(CI_upper[row, col], digits = digits)) %>%
        #insert lower CI
        stringr::str_replace_all(pattern = "%lower", replacement = str_round(CI_lower[row, col], digits = digits)) %>%
        #insert n
        stringr::str_replace_all(pattern = "%n", replacement = as.character(pairwise_n[row, col])) %>%
        #insert se
        stringr::str_replace_all(pattern = "%se", replacement = str_round(se[row, col], digits = digits))
    }
  }

  #fill upper halves
  m = MAT_half2full(m, diag = TRUE)
  m_str = MAT_half2full(m_str, diag = TRUE)
  se = MAT_half2full(se, diag = TRUE)
  CI_upper = MAT_half2full(CI_upper, diag = TRUE)
  CI_lower = MAT_half2full(CI_lower, diag = TRUE)
  pval = MAT_half2full(pval, diag = TRUE)
  pairwise_n = MAT_half2full(pairwise_n, diag = TRUE)

  if (!extended_output) return(m)

  list(
    d = m,
    d_string = m_str,
    CI_lower = CI_lower,
    CI_upper = CI_upper,
    se_z = se_CI,
    se = se,
    p = pval,
    pairwise_n = pairwise_n
  )
}

#' Calculate homogeneity/heterogeneity
#'
#' Calculate a simple index of homogeneity for nominal data, variously called Simpson's, Herfindahl's or Hirschman's index.
#' It is the sum of the squared group proportions. The reversed (heterogeneity) index is 1 minus that.
#' @param x (a vector) A vector of values.
#' @param reverse (log scalar) Whether to reverse the index to index heterogeneity (default false).
#' @param summary (log scalar) Whether to treat data as summary statistics of the group proportions (default false). If data are given in 0-100 format, it will automatically convert.
#' @return A numeric scalar.
#' @export
#' @examples
#' homogeneity(iris$Species)
#' homogeneity(iris$Species, reverse = T)
#' homogeneity(c(.7, .2, .1), summary = T)
#' homogeneity(c(80, 15, 5), summary = T)
homogeneity = function(x, reverse = FALSE, summary = FALSE) {

  if (!summary) {
    #proportion in each category (table() does not count NAs)
    props = as.vector(prop.table(table(x)))
  } else {
    #summary statistics: proportions summing to about 1, or percentages summing to about 100
    total = sum(x)
    if (isTRUE(total >= .99 && total <= 1.01)) {
      props = x
    } else if (isTRUE(total >= 99 && total <= 101)) {
      props = x / 100
    } else {
      stop("Tried to use summary statistics, but they did not sum to either around 1 or 100 (+/- 1%)")
    }
  }

  index = sum(props^2)
  if (reverse) 1 - index else index
}

#' Calculate weighted standard deviation
#'
#' Calculate the weighted standard deviation using a vector of values and a vector of weights. Cases where either the value or the weight is missing are dropped. Both the mean and the squared deviations are weighted, so with equal weights the result equals [stats::sd()].
#' @param x (num vector) A vector of values.
#' @param w (num vector) A vector of weights. If NULL (default), all cases are weighted equally.
#' @param sample (log scalar) Whether this is a sample as opposed to a population (default true). If true, the variance is multiplied by n/(n-1), where n is the number of complete cases.
#' @param error (lgl scalr) Whether to throw an error if there is no data at all or no pairwise complete cases. Default no (NA is returned).
#' @return A numeric scalar.
#' @export
#' @examples
#' set.seed(1)
#' X = rnorm(100)
#' set.seed(1)
#' W = runif(100)
#' sd(X) #0.898
#' wtd_sd(X, W) #slightly different
#' wtd_sd(X) #0.898, not using weights
wtd_sd = function(x, w = NULL, sample = TRUE, error = FALSE) {
  #x vector
  x = as.vector(x)

  if (is.null(w)) {
    w = rep(1, length(x))
  } else {
    w = as.vector(w)
  }
  if (length(x) != length(w)) stop("Lengths of x and w do not match!")

  #keep complete cases only, so x and w stay aligned
  keep = !is.na(x) & !is.na(w)
  x = x[keep]
  w = w[keep]
  n = length(x)

  #check sample
  if (n == 0) {
    if (error) stop("There were no complete cases!")
    return(NA) #return NA on no cases
  }

  #weighted mean and weighted (population) variance
  wtd_mean = sum(w * x) / sum(w)
  wtd_var = sum(w * (x - wtd_mean)^2) / sum(w)

  #sample correction based on the number of complete cases
  if (sample) wtd_var = wtd_var * n / (n - 1)

  #weighted sd
  sqrt(wtd_var)
}

#' Calculate a weighted mean
#'
#' This is an improvement on \code{\link{weighted.mean}} in \code{base-r}.
#' The original function returns \code{NA} when there are missing values in the weights vector despite na.rm=T. This function avoids that problem by dropping cases where either the value or the weight is missing. It also returns a useful error message if there are no complete cases. The function wraps base-r's function.
#' @param x (num vector) A vector of values.
#' @param w (num vector) A vector of weights. If NULL (default), all cases are weighted equally.
#' @param error (lgl scalr) Whether to throw an error if there is no data at all or no pairwise complete cases. Default no (NA is returned).
#' @return A numeric scalar.
#' @export
#' @examples
#' set.seed(1)
#' X = rnorm(100)
#' set.seed(1)
#' W = runif(100)
#' wtd_mean(X) # not using weights
#' mean(X) #same as above
#' wtd_mean(X, W) #slightly different
wtd_mean = function(x, w = NULL, error = FALSE) {

  #x vector
  x = as.vector(x)

  if (is.null(w)) {
    w = rep(1, length(x))
  } else {
    w = as.vector(w)
  }

  if (length(x) != length(w)) stop("Lengths of x and w do not match!")

  #pairwise complete cases
  keep = !is.na(x) & !is.na(w)

  #check sample
  if (!any(keep)) {
    if (error) stop("There were no complete cases!")
    return(NA) #return NA on no cases
  }

  stats::weighted.mean(x = x[keep], w = w[keep])
}

#' Calculate a weighted sum
#'
#' This is an improvement on \code{\link{sum}} in \code{base-r}.
#' It automatically handles missing data by dropping cases where either the value or the weight is missing. It returns a useful error message if there are no complete cases. The result is the weighted mean scaled by the number of complete cases, so with equal weights it equals the ordinary sum.
#' @param x (num vector) A vector of values.
#' @param w (num vector) A vector of weights. If NULL (default), all cases are weighted equally.
#' @param error (lgl scalr) Whether to throw an error if there is no data at all or no pairwise complete cases. Default no (NA is returned).
#' @return A numeric scalar.
#' @export
#' @examples
#' set.seed(1)
#' X = rnorm(100)
#' set.seed(1)
#' W = runif(100)
#' wtd_sum(X) # not using weights
#' sum(X) #same as above
#' wtd_sum(X, W) #different
wtd_sum = function(x, w = NULL, error = FALSE) {
  #x vector
  x = as.vector(x)

  if (is.null(w)) {
    w = rep(1, length(x))
  } else {
    w = as.vector(w)
  }

  #the length check must actually stop, not just be computed
  if (length(x) != length(w)) stop("Lengths of x and w do not match!")

  #pairwise complete cases
  keep = !is.na(x) & !is.na(w)
  n = sum(keep)

  #check sample
  if (n == 0) {
    if (error) stop("There were no complete cases!")
    return(NA) #return NA on no cases
  }

  x = x[keep]
  w = w[keep]

  #weighted mean scaled up by the number of complete cases
  (sum(x * w) / sum(w)) * n
}

#' Standardize a vector
#'
#' Standardize a vector. Can use weights and robust measures of central tendency and dispersion. Returns a clean vector as opposed to base-r's \code{\link{scale}}.
#' In the non-robust case, both the mean and the variance are weighted; with equal weights the result equals \code{as.vector(scale(x))}. Cases with a missing value or weight are ignored when computing the mean/SD, but all cases are standardized (missing values stay NA).
#' @param x (num vector) A vector of values.
#' @param w (num vector) A vector of weights. Not used when `robust = TRUE`.
#' @param robust (log vector) Whether to use robust measures (default false). See \code{\link{mad}} and \code{\link{median}}.
#' @param sample (log scalar) Whether this is a sample as opposed to a population (default true).
#' @param focal_group (lgl vector) A subset of the data to standardize the values to. This is useful when you want one subgroup to be the focal group, using their mean/sd as 0/1.
#' @return A numeric vector of the same length as `x`.
#' @export
#' @examples
#' set.seed(1)
#' X = rnorm(100, mean = 10, sd = 5)
#' set.seed(1)
#' W = runif(100)
#' standardize(X, W)
#' standardize(X, robust = T) #almost the same for these data
standardize = function(x, w = NULL, robust = FALSE, sample = TRUE, focal_group = NULL) {
  #x vector
  x = as.vector(x)

  if (is.null(w)) {
    w = rep(1, length(x))
  } else {
    w = as.vector(w)
  }

  if (is.null(focal_group)) focal_group = rep(TRUE, length(x))

  #assert equal lengths
  if (length(w) != length(x)) stop("Lengths of x and w do not match!")
  if (length(focal_group) != length(x)) stop("Lengths of x and focal_group do not match!")

  if (anyNA(focal_group)) {
    warning("`focal_group` contains `NA` values. These were converted to `FALSE` following tidyverse convention.")
    focal_group[is.na(focal_group)] = FALSE
  }

  #the reference values come from the focal group, but all values are transformed
  full_x = x
  ref_x = x[focal_group]
  ref_w = w[focal_group]

  if (robust) {
    #robust measures are unweighted
    center = stats::median(ref_x, na.rm = TRUE)
    spread = stats::mad(ref_x, na.rm = TRUE)
  } else {
    #pairwise complete cases, so values and weights stay aligned
    keep = !is.na(ref_x) & !is.na(ref_w)
    ref_x = ref_x[keep]
    ref_w = ref_w[keep]
    n = length(ref_x)

    #weighted mean and weighted variance
    center = sum(ref_w * ref_x) / sum(ref_w)
    variance = sum(ref_w * (ref_x - center)^2) / sum(ref_w)
    if (sample) variance = variance * n / (n - 1)
    spread = sqrt(variance)
  }

  (full_x - center) / spread
}

#' Transform to 0-1 scale
#'
#' Linearly rescale a numeric vector so that its minimum becomes 0 and its maximum becomes 1. Missing values stay NA.
#' @param x Numeric vector
#' @return Numeric vector on scale 0-1 (NULL for zero-length input). A constant vector yields NaN, since its range is 0.
#' @export
#' @examples
#' transform_01(1:5)
#' transform_01(c(1:3, NA, 4:5))
transform_01 = function(x) {

  #0 length
  if (length(x) == 0) return(c())

  #subtract minimum, so the lowest value becomes 0
  x = x - min(x, na.rm = TRUE)

  #divide by the new maximum, so the highest value becomes 1
  x / max(x, na.rm = TRUE)
}


#' Find a cutoff of a normal distribution that results in a given mean trait value above the cutoff
#'
#' Assuming a normal distribution for a trait and a cutoff value. Estimate what this cutoff value is to obtain a population above the cutoff with a known mean trait value.
#' The estimate is found by simulating one population and moving the cutoff in steps of `precision` until the simulated mean of the selected group is within `precision` of the target.
#' @param mean_above (num scalar) Mean trait level of population above the cutoff (or below it, if `below = TRUE`).
#' @param mean_pop (num scalar) Mean trait level of population. Default = 100 (IQ scale).
#' @param sd_pop (num scalar) Standard deviation of the trait in the population. Default = 15 (IQ scale).
#' @param n (num scalar) Sample size to generate in the process. More results in higher precision and memory use. Default = 1e4.
#' @param precision (num scalar) The precision to use. Default = .1.
#' @param below (log scalar) Reverse the model to find the cutoff for which the population below the cutoff has mean `mean_above`. Default = false.
#' @return The estimated cutoff (num scalar).
#' @export
#' @examples
#' #what cutoff is needed to get a population above the cutoff with a mean of 115 when the population mean is 100?
#' find_cutoff(115)
#' #the reverse: population below the cutoff with a mean of 85
#' find_cutoff(85, below = T)
#' #try impossible
#' try(find_cutoff(95))
find_cutoff = function(mean_above, mean_pop = 100, sd_pop = 15, n = 1e4, precision = .1, below = FALSE) {
  if (!below && mean_above < mean_pop) stop("This model is inapplicable if the mean trait level is lower than the population mean!")
  if (below && mean_above > mean_pop) stop("This model is inapplicable if the mean trait level is higher than the population mean!")

  #simulate the population once, so the search runs over a fixed sample
  population = stats::rnorm(n = n, mean = mean_pop, sd = sd_pop)

  #begin with unselected group
  cutoff = mean_pop
  for (iter in seq_len(1e5)) {

    #the selected part of the population
    if (below) {
      selected = population[population < cutoff]
    } else {
      selected = population[population > cutoff]
    }
    if (length(selected) == 0) stop("No simulated cases beyond the cutoff! Increase n or check mean_above.")

    v_diff = mean_above - mean(selected)
    if (abs(v_diff) < precision) return(cutoff)

    #in both directions, the selected group's mean rises when the cutoff rises
    if (v_diff > 0) {
      cutoff = cutoff + precision
    } else {
      cutoff = cutoff - precision
    }
  }

  stop("Loop reached 100k iterations without finding a solution! Use a lower level of precision!")
}

#' Calculate row-wise representativeness
#'
#' For each row, compute how far its values lie from the column central tendencies, plus the row mean and median of these errors.
#' @param x Data frame of numerical data (non-numeric columns are skipped with a message)
#' @param central_tendency_fun Function to use for central tendency. It is called on each column with missing values removed.
#' @param central_tendencies If custom values for central tendency, input here (one per numeric column)
#' @param standardize Whether to standardize values to ensure common metric. If not done, variables with larger SDs will be given more influence.
#' @param absolute Whether to use absolute errors (if not, then there is also an attempt to ensure the error directions are balanced across variables)
#' @return A data frame of error values and their mean and medians
#' @export
#' @examples
#' calc_row_representativeness(iris)
calc_row_representativeness = function(x, central_tendency_fun = median, central_tendencies = NULL, standardize = TRUE, absolute = TRUE) {

  #skip non-numeric variables
  is_num = purrr::map_lgl(x, is.numeric)
  if (any(!is_num)) message("Skipped non-numeric variables")
  x = x[is_num]

  #standardize if desired
  if (standardize) x = purrr::map_df(x, kirkegaard::standardize)

  #compute central tendencies
  if (is.null(central_tendencies)) {
    #drop NAs before calling, so any function works, not only those with an na.rm argument
    central_tendencies = purrr::map_dbl(x, function(v) central_tendency_fun(v[!is.na(v)]))
  } else {
    #ensure matching lengths
    if (ncol(x) != length(central_tendencies)) stop("`central_tendencies` must have one value per numeric variable")
  }

  #subtract central tendencies column-wise
  err_mat = t(t(as.matrix(x)) - central_tendencies)
  if (absolute) err_mat = abs(err_mat)

  #row summaries are computed from the error columns only
  tibble::tibble(
    row = seq_len(nrow(err_mat)),
    as.data.frame(err_mat),
    mean = rowMeans(err_mat, na.rm = TRUE),
    median = apply(err_mat, 1, stats::median, na.rm = TRUE)
  )
}

#internal helper functions
#adjusted R2 of a fitted linear model (lm or rms::ols), via summary.lm()
get_r2adj = function(x) {
  summary.lm(x)$adj.r.squared
}

#p value of the model likelihood-ratio chi-square test stored in an rms fit's `stats`
#upper tail is computed directly, which stays accurate for very small p values (1 - pchisq() underflows to 0)
get_p = function(x) {
  stats::pchisq(x$stats["Model L.R."], x$stats["d.f."], lower.tail = FALSE)
}

#' Test for heteroscedasticity
#'
#' Regresses the standardized absolute residuals on the predictor, with linear and spline (rms::rcs) models, both on raw values and on ranks. For the spline models, the p value is that of the improvement over the corresponding linear model.
#' @param resid Model residuals
#' @param x Model predictor of interest
#' @return Data frame of results
#' @export
#' @examples
#' #look for mostly nonexistent HS in iris
#' test_HS(resid = resid(lm(Sepal.Length ~ Petal.Length, data = iris2)), x = iris2$Petal.Length)
#' #a lot of HS here
#' test_HS(resid = resid(lm(Petal.Width ~ Petal.Length, data = iris2)), x = iris2$Petal.Length)
test_HS = function(resid, x) {
  #absolute residuals index the spread around the fit
  resid_std = standardize(abs(resid))
  x_std = standardize(x)

  d = tibble::tibble(
    resid = resid_std,
    x = x_std,
    #rank-based versions for a robust test; missing values stay missing instead of being ranked last
    resid_rank = rank(resid_std, na.last = "keep"),
    x_rank = rank(x_std, na.last = "keep")
  )

  mods = list(
    fit_linear = rms::ols(resid ~ x, data = d),
    fit_spline = rms::ols(resid ~ rms::rcs(x), data = d),
    fit_linear_rank = rms::ols(resid_rank ~ x_rank, data = d),
    fit_spline_rank = rms::ols(resid_rank ~ rms::rcs(x_rank), data = d)
  )

  y = tibble::tibble(
    test = c("linear raw", "spline raw", "linear rank", "spline rank"),
    r2adj = purrr::map_dbl(mods, get_r2adj),
    p = purrr::map_dbl(mods, get_p),
    fit = mods
  )

  #for the rcs, we have to compute the model improvement vs. linear
  y$p[2] = rms::lrtest(mods$fit_linear, mods$fit_spline)$stats["P"]
  y$p[4] = rms::lrtest(mods$fit_linear_rank, mods$fit_spline_rank)$stats["P"]

  y %>% dplyr::mutate(log10_p = -log10(p))
}

#' Quantile smoothening
#'
#' Estimate a smooth conditional quantile of `y` given `x`.
#' @param x Predictor variable
#' @param y Outcome variable
#' @param quantile Which quantile (e.g., .95)
#' @param method Method to use: "qgam" (quantile GAM), "Rq" (rms::Rq with a restricted cubic spline), or "running" (running quantile over `x`-sorted data)
#' @param k Number of knots for method qgam, see [qgam::qgam]
#' @param window Window size for method running, see [caTools::runquantile]
#' @return Vector of fitted values, in the same order as the input
#' @export
#' @examples
#' quantile_smooth(iris$Petal.Length, iris$Sepal.Length, quantile = .90)
quantile_smooth = function(x, y, quantile, method = c("qgam", "Rq", "running"), k = 5, window = 50) {

  method = match.arg(method[1], c("qgam", "Rq", "running"))

  if (method == "qgam") {
    #same as in example here
    if (!requireNamespace("qgam", quietly = TRUE)) stop("Package qgam is required for method qgam", call. = FALSE)
    fit = qgam::qgam(list(y ~ s(x, k = k, bs = "cr"), ~ s(x, k = k, bs = "cr")),
                     data = data.frame(x = x, y = y), qu = quantile)
    return(as.vector(stats::fitted(fit)))

  } else if (method == "Rq") {
    #fit with spline
    if (!requireNamespace("rms", quietly = TRUE)) stop("Package rms is required for method Rq", call. = FALSE)
    fit = rms::Rq(y ~ rms::rcs(x), data = data.frame(x = x, y = y), tau = quantile)
    return(as.vector(stats::fitted(fit)))

  } else if (method == "running") {
    #the running quantile needs y sorted by x
    if (!requireNamespace("caTools", quietly = TRUE)) stop("Package caTools is required for method running", call. = FALSE)
    x_order = order(x)
    smoothed = as.vector(caTools::runquantile(y[x_order], k = window, probs = quantile))
    #order(x_order) is the inverse permutation, restoring the original input order
    return(smoothed[order(x_order)])
  }

  stop(str_glue("method not recognized: {method}"), call. = FALSE)
}


#' Compute many proportion tests
#'
#' This is a convenient tidy wrapper for `stats::prop.test()`. For every group, the proportion of each level of `x` is tested, and the results are returned in one data frame.
#' @param x A factor variable
#' @param group A grouping variable
#' @param correct a logical indicating whether Yates' continuity correction should be applied where possible.
#' @param conf_level confidence level of the returned confidence interval. Must be a single number between 0 and 1. Only used when testing the null that a single proportion equals a given value, or that two proportions are equal; ignored otherwise.
#' @param alternative a character string specifying the alternative hypothesis, must be one of "two.sided" (default), "greater" or "less". You can specify just the initial letter. Only used for testing the null that a single proportion equals a given value, or that two proportions are equal; ignored otherwise.
#' @return A data frame, which may have missing levels if a combination did not exist in the data.
#' @export
#' @examples
#' prop_tests(mpg$cyl, mpg$manufacturer)
prop_tests = function(x, group, correct = TRUE, conf_level = .95, alternative = c("two.sided", "less", "greater")) {

  alternative = match.arg(alternative)

  #make data frame
  xdata = tibble::tibble(
    x = as.factor(x),
    group = as.factor(group)
  )

  xdata %>%
    #filter NA
    dplyr::filter(!is.na(x), !is.na(group)) %>%
    #loop groups
    plyr::ddply("group", function(dd) {
      #loop levels
      purrr::map_df(levels(dd$x), function(this_level) {
        n_level = sum(dd$x == this_level)
        #suppress the continuity correction warnings
        suppressWarnings(
          stats::prop.test(n_level, n = nrow(dd), correct = correct, conf.level = conf_level, alternative = alternative)
        ) %>%
          broom::tidy() %>%
          dplyr::mutate(n = nrow(dd), n_level = n_level)
      }) %>%
        dplyr::mutate(level = levels(dd$x))
    }) %>%
    #restore the original level orders
    dplyr::mutate(
      level = factor(level, levels = levels(xdata$x)),
      group = factor(group, levels = levels(xdata$group))
    ) %>%
    tibble::as_tibble() %>%
    dplyr::select(group, level, dplyr::everything())
}


#' Test for differential item functioning (DIF)
#'
#' Tests are done following the mirt package approach outlined by Chalmers: a fully invariant multiple-group model is fit, every item is tested for DIF with [mirt::DIF()], and models are then refit without the DIF items and with the non-DIF items as anchors. DIF is judged at p < .05 both before ("liberal") and after ("conservative") multiple testing correction.
#' @param items Item data to use (data frame)
#' @param model Item model
#' @param group Group (2 groups at most). Cases with a missing group are dropped.
#' @param fscores_pars Any extra scoring parameters used
#' @param messages Show progress messages
#' @param method Estimation method
#' @param technical Further technical args to pass to mirt
#' @param itemtype Item type
#' @param verbose Verbose output from mirt
#' @param DIF_args Arguments to pass to mirt::DIF. Default tests all parameters in the model with the "drop" scheme.
#' @param multiple_testing_method Method to use for multiple testing correction, see p.adjust(). Default is "bonferroni"
#' @param ... Other arguments passed to mirt functions
#' @return A list of results: `scores`, `fits`, `DIF_stats`, `effect_size_items` and `effect_size_test`. Effect sizes are only computed for 1-dimensional models (NULL otherwise).
#' @export
#' @examples
#' library(mirt)
#' n = 1000
#' n_items = 10
#' #slopes
#' set.seed(1)
#' a1 = runif(n_items, min = .5, max = 2)
#' a2 = a1
#' a2[1] = 0 #item doesnt work for this group
#' #intercepts
#' i1 = rnorm(n_items, mean = -0.5, sd = 2)
#' i2 = i1
#' i2[2] = -2 #item much harder for this group
#' #simulate data twice
#' d1 = simdata(
#' a1,
#' i1,
#' N = n,
#' itemtype = "2PL",
#' mu = 0
#' )
#' d2 = simdata(
#' a2,
#' i2,
#' N = n,
#'   itemtype = "2PL",
#'   mu = 1
#' )
#' #combine
#' d = rbind(
#'   d1 %>% set_names("item_" + 1:n_items),
#'   d2 %>% set_names("item_" + 1:n_items)
#' ) %>% as.data.frame()
#' #find the bias
#' DIF_results = DIF_test(d, model = 1, itemtype = "2PL", group = rep(c(1, 2), each = n))
#' DIF_results$effect_size_items$conservative
#' plot(DIF_results$fits$anchor_conservative)
#' plot(DIF_results$fits$anchor_conservative, type = "trace")
DIF_test = function(items, model, group, fscores_pars = list(full.scores = TRUE, full.scores.SE = TRUE), messages = TRUE, method = "EM", technical = list(), itemtype = NULL, verbose = TRUE, DIF_args = NULL, multiple_testing_method = "bonferroni", ...) {

  #deal with missing data in the group var
  group_keep = !is.na(group)
  items = items[group_keep, , drop = FALSE]
  group = group[group_keep]

  #make mirt args; set2 omits model/itemtype for the refits with custom Q matrices
  mirt_args = c(list(data = items, model = model, technical = technical, verbose = verbose, method = method, itemtype = itemtype), list(...))
  mirt_args_set2 = mirt_args[!names(mirt_args) %in% c("model", "itemtype")]

  #step 1: regular joint fit, ignoring groups
  if (messages) message("There are 8 steps")
  if (messages) message("Step 1: Initial joint fit\n")
  mirt_fit = rlang::exec(mirt::mirt, !!!mirt_args)

  #step 2: multiple group fit with full measurement invariance
  if (!is.character(group) && !is.factor(group)) group = factor(group)
  if (messages) message("\nStep 2: Initial MI fit")
  mirt_fit_MI = rlang::exec(mirt::multipleGroup, !!!mirt_args, group = group, invariance = c('intercepts', 'slopes', 'free_means', 'free_var'))

  #step 3: test each item for DIF
  if (messages) message("\nStep 3: Leave one out MI testing")

  if (is.null(DIF_args)) {
    DIF_args = list(
      #test all pars in model
      which.par = mirt::coef(mirt_fit, simplify = TRUE)$items %>% colnames(),
      scheme = "drop"
    )
  }

  #call DIF() with arguments
  DIFs = rlang::exec(
    mirt::DIF,
    MGmodel = mirt_fit_MI,
    !!!(mirt_args[!names(mirt_args) %in% c("data", "model", "group", "itemtype")]),
    !!!DIF_args
  )

  DIFs = tibble::rownames_to_column(DIFs, "item")
  DIFs$number = seq_len(nrow(DIFs))

  #adjust p values
  DIFs$p_adj = stats::p.adjust(DIFs$p, method = multiple_testing_method)

  #with significant DIF
  DIFs_detected_liberal = dplyr::filter(DIFs, p < .05)
  DIFs_detected_conservative = dplyr::filter(DIFs, p_adj < .05)

  #items without detected DIF serve as anchors
  anchors_liberal = setdiff(DIFs$item, DIFs_detected_liberal$item)
  anchors_conservative = setdiff(DIFs$item, DIFs_detected_conservative$item)

  #subset models
  #if its a g only model, we dont have to do anything
  #but if its complex we need name format or Q matrix format
  #extract loadings matrix and convert to Q matrix
  mirt_fit_loadings = mirt_fit@Fit$`F`
  model_noDIF_liberal_Q = mirt_fit_loadings %>% apply(MARGIN = 2, as.logical) %>% magrittr::set_rownames(rownames(mirt_fit_loadings))
  model_noDIF_conservative_Q = model_noDIF_liberal_Q

  #set unused items' rows to FALSE
  model_noDIF_liberal_Q[DIFs_detected_liberal$item, ] = FALSE
  model_noDIF_conservative_Q[DIFs_detected_conservative$item, ] = FALSE

  #fit together without DIF
  if (messages) message("\nStep 4: Fit without DIF items, liberal threshold")
  mirt_fit_noDIF_liberal = rlang::exec(mirt::mirt, model = mirt::mirt.model(model_noDIF_liberal_Q), !!!mirt_args_set2)

  if (messages) message("\nStep 5: Fit without DIF items, conservative threshold")
  mirt_fit_noDIF_conservative = rlang::exec(mirt::mirt, model = mirt::mirt.model(model_noDIF_conservative_Q), !!!mirt_args_set2)

  #with anchors
  if (messages) message("\nStep 6: Fit with anchor items, liberal threshold")
  mirt_fit_anchors_liberal = rlang::exec(mirt::multipleGroup, !!!mirt_args, group = group, invariance = c(anchors_liberal, 'free_means', 'free_var'))

  if (messages) message("\nStep 7: Fit with anchor items, conservative threshold")
  mirt_fit_anchors_conservative = rlang::exec(mirt::multipleGroup, !!!mirt_args, group = group, invariance = c(anchors_conservative, 'free_means', 'free_var'))

  #get scores
  if (messages) message("\nStep 8: Get scores")

  orig_scores = do.call(what = mirt::fscores, args = c(list(object = mirt_fit), fscores_pars))
  noDIF_scores_liberal = do.call(what = mirt::fscores, args = c(list(object = mirt_fit_noDIF_liberal), fscores_pars))
  noDIF_scores_conservative = do.call(what = mirt::fscores, args = c(list(object = mirt_fit_noDIF_conservative), fscores_pars))
  anchor_scores_liberal = do.call(what = mirt::fscores, args = c(list(object = mirt_fit_anchors_liberal), fscores_pars))
  anchor_scores_conservative = do.call(what = mirt::fscores, args = c(list(object = mirt_fit_anchors_conservative), fscores_pars))

  scores = list(
    #original scores
    original = orig_scores,

    #after DIF removal
    noDIF_liberal = noDIF_scores_liberal,
    noDIF_conservative = noDIF_scores_conservative,

    #anchor scores
    anchor_liberal = anchor_scores_liberal,
    anchor_conservative = anchor_scores_conservative
  )

  #effect sizes
  #this only works with 1 dimensional models
  if (ncol(mirt_fit_loadings) == 1) {
    #item level
    effect_size_items = list(
      liberal = mirt::empirical_ES(mirt_fit_anchors_liberal, DIF = TRUE, plot = FALSE),
      conservative = mirt::empirical_ES(mirt_fit_anchors_conservative, DIF = TRUE, plot = FALSE)
    )

    #test level
    effect_size_test = list(
      liberal = mirt::empirical_ES(mirt_fit_anchors_liberal, DIF = FALSE, plot = FALSE),
      conservative = mirt::empirical_ES(mirt_fit_anchors_conservative, DIF = FALSE, plot = FALSE)
    )
  } else {
    #we fill in NULLS to keep structure
    effect_size_items = list(
      liberal = NULL,
      conservative = NULL
    )

    effect_size_test = list(
      liberal = NULL,
      conservative = NULL
    )
  }

  list(
    scores = scores,
    fits = list(
      original = mirt_fit,
      noDIF_liberal = mirt_fit_noDIF_liberal,
      noDIF_conservative = mirt_fit_noDIF_conservative,
      anchor_liberal = mirt_fit_anchors_liberal,
      anchor_conservative = mirt_fit_anchors_conservative
    ),
    DIF_stats = DIFs,
    effect_size_items = effect_size_items,
    effect_size_test = effect_size_test
  )
}


#convert from slope to loading
#' Convert from slopes (discrimination) to factor loadings
#'
#' Uses loading = a / sqrt(a^2 + s), where a is the slope and s is the scaling factor. The sign of the slope is kept.
#' @param x Vector of slopes (discrimination parameters)
#' @param logit_scaling Whether the slopes are on the logit scale (scaling factor 3) rather than the probit scale (scaling factor 1)
#' @param scaling_factor Optional override of the scaling factor; by default chosen from `logit_scaling`
#' @return A vector of factor loadings
#' @export
#' @examples
#' slope_to_loading(c(-1, 0.5, 1, 2))
slope_to_loading = function(x, logit_scaling = TRUE, scaling_factor = if (logit_scaling) 3 else 1) {
  x / sqrt(x^2 + scaling_factor)
}

#' Convert from factor loadings to slopes (discrimination)
#'
#' Uses a = loading * sqrt(s / (1 - loading^2)), where s is the scaling factor; this is the inverse of the slope-to-loading conversion. The sign of the loading is kept. Loadings must lie in (-1, 1).
#' @param x A vector of factor loadings
#' @param logit_scaling Whether to return slopes on the logit scale (scaling factor 3) rather than the probit scale (scaling factor 1)
#' @param scaling_factor Optional override of the scaling factor; by default chosen from `logit_scaling`
#' @return A vector of slopes (discrimination parameters)
#' @export
#' @examples
#' loading_to_slope(c(-0.5, 0.3, 0.5, 0.8))
loading_to_slope = function(x, logit_scaling = TRUE, scaling_factor = if (logit_scaling) 3 else 1) {
  x * sqrt(scaling_factor / (1 - x^2))
}

#' Calculate item gaps
#'
#' For two groups, compute each item's pass rate per group and the gap on the latent (probit) scale: qnorm(focal pass rate) - qnorm(other pass rate). The first factor level of `group` is the focal group. Pass rates of 0 or 1 give infinite gaps.
#' @param x Item data frame (or matrix) of numeric item scores
#' @param group A grouping variable with exactly 2 levels (cases with missing group are dropped)
#' @param return_data_frame Whether you want a data frame back, or just a vector of the gaps
#' @return a data frame or a vector
#' @export
#' @examples
#' set.seed(1)
#' X = matrix(rbinom(10000, 1, .5), ncol = 10)
#' group = rbinom(1000, 1, .5)
#' calc_item_gaps(X, group)
calc_item_gaps = function(x, group, return_data_frame = TRUE) {
  #input check, before subsetting so bad input fails clearly
  assertthat::assert_that(is.matrix(x) | is.data.frame(x))
  assertthat::assert_that(length(group) == nrow(x))

  #subset to complete cases for group; drop = FALSE keeps single-item data 2-dimensional
  no_na = !is.na(group)
  x = as.data.frame(x[no_na, , drop = FALSE])
  group = droplevels(as.factor(group[no_na]))

  assertthat::assert_that(all(purrr::map_lgl(x, is.numeric)))
  assertthat::assert_that(length(levels(group)) == 2)

  #decide focal group
  focal_group = levels(group)[1]
  alt_group = levels(group)[2]
  message(str_glue("Focal group is {focal_group}. Positive values mean that {focal_group} > {alt_group}"))

  #prep data frame
  res = tibble::tibble(
    item_i = seq_len(ncol(x)),
    item = colnames(x),
    focal_pass_rate = purrr::map_dbl(x[group == focal_group, , drop = FALSE], mean, na.rm = TRUE),
    alt_pass_rate = purrr::map_dbl(x[group == alt_group, , drop = FALSE], mean, na.rm = TRUE),
    d_gap = stats::qnorm(focal_pass_rate) - stats::qnorm(alt_pass_rate)
  )

  if (return_data_frame) {
    return(res)
  }
  res$d_gap
}

#' Create norms and adjust scores for age effects
#'
#' This function adjusts for the effect of age on the mean and standard deviation of a score, and then creates norms for setting the resulting scores to follow the IQ scale of 100/15 for the chosen subgroup.
#' The age models are fit in the norm group and only applied if their overall p value is below `p_value`.
#' @param score A vector of scores
#' @param age A vector of ages
#' @param norm_group A logical vector of the same length as the score vector, indicating the norm group. Default is all cases. Missing values count as not in the norm group.
#' @param p_value The p value threshold for model inclusion. Default is .01.
#' @return A list of results, including corrected scores, age correction models used, and means and standard deviations for the norm group post-correction.
#' @export
#' @rdname Norms_and_age_corrections
#' @examples
#' #simulate some data, mess up the norms, and get them back
#' set.seed(1)
#' data = tibble(
#'  true_IQ = c(rnorm(1000, mean = 100, sd = 15), rnorm(1000, mean = 90, sd = 15)),
#'  true_z = (true_IQ - 100 ) / 15,
#'  age = runif(2000, min = 18, max = 80),
#'  norm_group = c(rep(T, 1000), rep(F, 1000))
#'  ) %>% mutate(
#'  #add an effect of age on the mean and dispersion of scores
#'  age_mean_effect = age * 0.3 - 0.3 * 18,
#'  age_sd_effect = true_z * 15 * rescale(age, new_min = .80, new_max = 1.20),
#'  score = 100 + age_sd_effect + age_mean_effect
#'  )
#'  #restore the norms
#'  norms = make_norms(data$score, data$age, data$norm_group)
#'  data$IQ = norms$data$IQ
#'  #manually apply norms
#'  data$IQ2 = apply_norms(data$score, data$age, prior_norms = norms)
#'  data$IQ3 = apply_norms(data$score, data$age,
#'  age_model_slope = norms$age_model$coefficients[2],
#'  age_model_intercept = norms$age_model$coefficients[1],
#'  age_model_abs_slope = norms$age_model_abs$coefficients[2],
#'  age_model_abs_intercept = norms$age_model_abs$coefficients[1],
#'  mean = norms$norm_desc$mean,
#'  sd = norms$norm_desc$sd
#'  )
#'  #verify that scores were made correct
#'  cor(data)
#'  GG_scatter(data, "age", "score")
#'  #can we detect the age heteroscedasticity too?
#'  test_HS(resid = resid(lm(score ~ age, data = data)), x = data$age)
#'  #plot the results
#'  GG_scatter(data, "true_IQ", "score") + geom_abline(slope = 1, intercept = 0, linetype = 2)
#'  GG_scatter(data, "true_IQ", "IQ") + geom_abline(slope = 1, intercept = 0, linetype = 2)
make_norms = function(score, age, norm_group = NULL, p_value = .01) {
  #if no norm group, use all rows
  if (is.null(norm_group)) {
    norm_group = rep(TRUE, length(score))
  }

  #make a data frame with variables
  d_all = tibble::tibble(
    score = score,
    age = age,
    norm_group = norm_group
  )

  #norm subset (rows with NA norm_group are dropped by filter)
  d_norm = d_all %>% dplyr::filter(norm_group)

  #model the mean effect of age in the norm group
  #na.exclude pads the residuals with NA, so they stay aligned with the rows of d_norm
  age_model = stats::lm(score ~ age, data = d_norm, na.action = stats::na.exclude)
  age_model_glance = broom::glance(age_model)
  #an undefined p value (e.g. constant scores) counts as not significant
  age_model_used = isTRUE(age_model_glance$p.value < p_value)

  #if the p value is too high, we dont use the model
  if (age_model_used) {
    d_norm$score_ageadj1 = stats::resid(age_model)
  } else {
    d_norm$score_ageadj1 = d_norm$score
  }

  #fit model for the SD effect of age (mean absolute deviation by age)
  age_model_abs = stats::lm(abs(score_ageadj1) ~ age, data = d_norm, na.action = stats::na.exclude)
  age_model_abs_glance = broom::glance(age_model_abs)
  age_model_abs_used = isTRUE(age_model_abs_glance$p.value < p_value)

  #apply corrections to the full data
  #if model is not significant, we dont use it
  if (!age_model_used) {
    message(str_glue("No detected linear effect of age on the score (p = {p_to_asterisk(age_model_glance$p.value)}). Model not used."))
    d_all$score_ageadj1 = d_all$score
  } else {
    message(str_glue("Detected linear effect of age on the score (p = {p_to_asterisk(age_model_glance$p.value)}). Model used."))
    d_all$score_ageadj1 = d_all$score - stats::predict(age_model, newdata = d_all)
  }

  #if model is not significant, we dont use it
  if (!age_model_abs_used) {
    message(str_glue("No detected variance effect of age on the score (p = {p_to_asterisk(age_model_abs_glance$p.value)}). Model not used."))
    d_all$score_ageadj2 = d_all$score_ageadj1
  } else {
    message(str_glue("Detected variance effect of age on the score (p = {p_to_asterisk(age_model_abs_glance$p.value)}). Model used."))
    d_all$score_ageadj2 = d_all$score_ageadj1 / stats::predict(age_model_abs, newdata = d_all)
  }

  #find the norm group mean/SD
  norm_desc = d_all %>% dplyr::filter(norm_group) %>% dplyr::select(score_ageadj2) %>% describe2()
  d_all$score_ageadj3 = (d_all$score_ageadj2 - norm_desc$mean) / norm_desc$sd

  #output IQ scores
  d_all$IQ = d_all$score_ageadj3 * 15 + 100

  #list of results
  list(
    #adjusted data
    data = d_all,

    #models fitted
    age_model = age_model,
    age_model_abs = age_model_abs,

    #used models?
    age_model_used = age_model_used,
    age_model_abs_used = age_model_abs_used,

    #means and SD for norm group post corrections
    norm_desc = norm_desc
  )
}

#make a function that applies norms to a new data set based on prior results

#' Apply norms to a new data set
#'
#' Apply age corrections and norm-group standardization (as created by `make_norms`) to new scores, returning IQ-scale scores (100/15). Values are taken from `prior_norms` if given, otherwise from the manually supplied arguments. An age model with slope 0 is treated as not used.
#' @param score A vector of scores to correct
#' @param age A vector of ages to use
#' @param prior_norms A list of prior norms to use from `make_norms`
#' @param age_model_slope A manually supplied slope for the age model
#' @param age_model_intercept A manually supplied intercept for the age model
#' @param age_model_abs_slope A manually supplied slope for the age model of absolute deviations
#' @param age_model_abs_intercept A manually supplied intercept for the age model of absolute deviations
#' @param mean A manually supplied mean for the norm group
#' @param sd A manually supplied standard deviation for the norm group
#' @return A vector of IQ scores
#' @export
#' @rdname Norms_and_age_corrections
apply_norms = function(score,
                       age,
                       prior_norms = NULL,
                       age_model_slope = 0,
                       age_model_intercept = 0,
                       age_model_abs_slope = 0,
                       age_model_abs_intercept = 0,
                       mean = NULL,
                       sd = NULL) {
  #if prior norms, use them
  if (!is.null(prior_norms)) {
    #pull values from prior norms
    #but don't if they are non-significant
    if (isTRUE(prior_norms$age_model_used)) {
      age_model_slope = unname(prior_norms$age_model$coefficients[2])
      age_model_intercept = unname(prior_norms$age_model$coefficients[1])
    }
    if (isTRUE(prior_norms$age_model_abs_used)) {
      age_model_abs_slope = unname(prior_norms$age_model_abs$coefficients[2])
      age_model_abs_intercept = unname(prior_norms$age_model_abs$coefficients[1])
    }

    mean = prior_norms$norm_desc$mean
    sd = prior_norms$norm_desc$sd
  }

  #without these, the arithmetic below would silently return numeric(0)
  if (is.null(mean) || is.null(sd)) stop("Supply either `prior_norms` or both `mean` and `sd`.", call. = FALSE)

  #apply linear correction
  if (age_model_slope == 0) {
    score_ageadj1 = score
  } else {
    score_ageadj1 = score - age_model_slope * age - age_model_intercept
  }

  #apply variance correction
  if (age_model_abs_slope == 0) {
    score_ageadj2 = score_ageadj1
  } else {
    score_ageadj2 = score_ageadj1 / (age_model_abs_slope * age + age_model_abs_intercept)
  }

  #apply norm group correction
  score_ageadj3 = (score_ageadj2 - mean) / sd

  #IQ scale
  score_ageadj3 * 15 + 100
}

