#' Compute Posterior
#'
#' @param x matrix
#' @param y matrix
#' @param k clusters
#' @param N observations
#' @param p dimension
#' @param m dimension
#' @param lam penalty param
#' @param rank rank
#' @param clust_assign current assignment
#' @param val_frac fraction validation
#' @param penal_search search vector
#'
#' @return tibble of posterior
#' @export
#'
#' @import dplyr
#' @import purrr
#' @import stats
#'
#' @examples
fct_gamma <- function(x, y, k, N, p, m, lam, rank, clust_assign, val_frac, penal_search){
names <- paste0("c_",1:k)
gamma_calc <- 1:k %>%
list() %>%
purrr::pmap(.f = function(.x){
cluster_rows <- (clust_assign==.x) %>%
which()
n_k <- cluster_rows %>%
length()
X_k <- x %>%
dplyr::as_tibble(.name_repair = "universal") %>%
dplyr::slice(cluster_rows) %>%
as.matrix()
Y_k <- y %>%
dplyr::as_tibble(.name_repair = "universal") %>%
dplyr::slice(cluster_rows) %>%
as.matrix()
if(is.null(lam) & nrow(X_k) > 3){
# Cross Validate to select penalization
val_rows <- cluster_rows %>%
dplyr::tibble() %>%
dplyr::slice_sample(prop = ifelse((val_frac*nrow(X_k)) < 1, 0.5, 1)) %>%
dplyr::pull()
train_rows <- cluster_rows %>%
dplyr::tibble() %>%
dplyr::filter(!((.) %in% val_rows)) %>%
dplyr::pull()
split_data <- fct_data_split(X_k, Y_k, ifelse((val_frac*nrow(X_k)) < 1, 0.5, val_frac))
x_train <- split_data %>%
purrr::pluck("x_train")
x_test <- split_data %>%
purrr::pluck("x_test")
y_train <- split_data %>%
purrr::pluck("y_train")
y_test <- split_data %>%
purrr::pluck("y_test")
n_train <- split_data %>%
purrr::pluck("n_train")
n_test <- split_data %>%
purrr::pluck("n_test")
rank_var_test <- fct_rank_var(x_train, y_train, n_train, p, m)
sigmahat_test <- rank_var_test %>%
purrr::pluck("sigmahat")
grid_search <- expand.grid(lam = penal_search)
model_k <- list(grid_search$lam) %>%
purrr::pmap_dfr(.f = function(.l){
lam_0 <- fct_lam_coef(x_train, sigmahat_test, m, p)
lam <- .l*lam_0
model <- fct_sarrs(y_train,x_train,rank, lam, "grLasso")
error <- mean((y_test-(cbind(x_test,1) %*% model$Ahat))^2)
dplyr::tibble(lam = lam, rank = rank, error = error, model = list(model))
}) %>%
dplyr::arrange(error) %>%
dplyr::slice(1) %>%
dplyr::pull(model) %>%
purrr::pluck(1)
} else if (nrow(X_k) > 1){
rank_var <- fct_rank_var(x, y, n_k, p, m)
sigmahat <- rank_var %>%
purrr::pluck("sigmahat")
lam_0 <- fct_lam_coef(x, sigmahat, m, p)
model_k <- fct_sarrs(Y_k, X_k, rank, lam_0, "grLasso")
} else if (nrow(X_k) <= 1){
# model_k <- fct_sarrs(Y_k, X_k, 1, 1, "grLasso")
return(dplyr::tibble(-Inf) %>% stats::setNames(names[.x]))
} else {
model_k <- fct_sarrs(Y_k, X_k, rank, lam, "grLasso")
}
A_k <- model_k %>%
purrr::pluck("Ahat")
sig_vec <- model_k %>%
purrr::pluck("sigvec")
mu_mat <- (x %>%
dplyr::bind_cols(int = rep(1,N)) %>%
as.matrix()) %*% A_k
gam <- fct_log_lik(mu_mat, sig_vec, y, N, m)
return(list(gamma = dplyr::tibble(gam) %>% stats::setNames(names[.x]), A_k = A_k, sig_vec = sig_vec))
})
gamma <- gamma_calc %>%
purrr::map_dfc(.f = function(.x){.x %>% purrr::pluck("gamma")})
A <- gamma_calc %>%
purrr::map(.f = function(.x){.x %>% purrr::pluck("A_k")})
sig_vec <- gamma_calc %>%
purrr::map(.f = function(.x){.x %>% purrr::pluck("sig_vec")})
return(list(gamma = gamma, A = A, sig_vec = sig_vec))
}
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.