library(dplyr)
library(ggplot2)
source("./functions/20210205_sarrs_alex.R")
# set.seed(19921124)

# Settings ----
# Clustering loop
num_pert <- 60 # number of observations whose starting label is switched
lam <- 1       # passed straight to SARRS() (presumably a penalty weight)
maxiter <- 100 # upper bound on the number of loop iterations

# Sample layout
N <- 100             # total number of observations
prob <- c(0.5, 0.5)  # probability of belonging to each cluster
k <- length(prob)    # number of clusters
nvld <- 1e4          # 2nd argument of sim_sarrs() (presumably validation size)

# Per-cluster simulation inputs handed to sim_sarrs(); `r` is also given to
# SARRS() as `rhat`. NOTE(review): exact meaning lives in the sourced file.
rho <- c(0, 0)
sigma <- 1
p <- 50 # passed to sim_sarrs(); must be at least k * s (checked below)
m <- 20 # passed to sim_sarrs(); number of response columns scored per row
s <- 5  # number of coefficient rows reserved for each cluster
r <- c(1, 1)
b <- c(5, 10)

# Cumulative cluster probabilities and one uniform draw per observation,
# used together to draw the true cluster labels.
int <- cumsum(prob)
rand_assign <- runif(N)

# Column labels for the per-cluster quantities: c_1, ..., c_k.
names <- paste0("c_", 1:k)

# make sure k*s is less than p
if (k * s > p) {
  print("FOR SEPARATION, VERIFY THAT K*S < P")
}
# True cluster labels ----
# Each uniform draw u is mapped to the first cluster j with u <= int[j], where
# `int` holds the cumulative cluster probabilities. findInterval() with
# left.open = TRUE returns the number of break points strictly below u, so
# adding one gives that cluster index for the whole vector at once.
clust_assign_true <- findInterval(rand_assign, int, left.open = TRUE) + 1L
if (any(clust_assign_true > k)) {
  stop(
    "A uniform draw exceeds the last cumulative probability; ",
    "`prob` must sum to 1.",
    call. = FALSE
  )
}
# Sort so that observations of the same cluster are contiguous.
clust_assign_true <- sort(clust_assign_true)

# Row order used to align the stacked simulated data with the labels. The
# labels are already sorted, so this ordering is expected to be 1..N.
clust_assign_true_key <- clust_assign_true %>%
  tibble() %>%
  mutate(order = seq_len(N)) %>%
  arrange((.))
clust_assign_true_vec <- clust_assign_true_key %>%
  pull(order)

# Cluster sizes, always of length k. Counting with a fixed number of bins
# keeps `n` aligned with the other per-cluster vectors (r, rho, b) even when a
# cluster received no observations.
n <- tabulate(clust_assign_true, nbins = k)
if (any(n == 0L)) {
  stop(
    "At least one cluster received no observations; ",
    "increase N or adjust `prob`.",
    call. = FALSE
  )
}
# History of per-cluster log-likelihoods, seeded with an all-zero row tagged
# iter = 0 (that row is filtered out again before plotting). The columns are
# named up front so as_tibble() does not have to invent names for an unnamed
# matrix.
gamma_store <- as_tibble(
  matrix(0, nrow = 1, ncol = k, dimnames = list(NULL, names))
) %>%
  mutate(w_ll = 0, iter = 0)
# clust_assign_true <- c(rep(1,39),rep(2,52), rep(3,109))

# Simulate one data set per cluster ----
# Cluster j is simulated with its own size, r, rho and b, and is given the
# coefficient rows ((j - 1) * s + 1):(j * s), so clusters use disjoint blocks
# of `s` rows. The block is derived from the cluster index rather than from
# counters mutated with `<<-`, so the result does not depend on global state.
# NOTE(review): argument meaning of sim_sarrs() is defined in the sourced
# file; the call order is unchanged.
data_k <- purrr::pmap(
  list(n, r, rho, b, seq_along(n)),
  function(.n, .r, .rho, .b, .idx) {
    a_rows <- ((.idx - 1) * s + 1):(.idx * s)
    sim <- sim_sarrs(.n, nvld, .n, a_rows, p, m, .r, .rho, sigma, .b)
    list(
      X = purrr::pluck(sim, "X"),
      Y = purrr::pluck(sim, "Y"),
      A = purrr::pluck(sim, "A"),
      S = purrr::pluck(sim, "sig")
    )
  }
)
# Stack the per-cluster predictors into one matrix, with rows reordered by
# clust_assign_true_vec.
stack_X <- data_k %>%
  purrr::map(function(.sim) as_tibble(.sim[["X"]])) %>%
  bind_rows() %>%
  slice(clust_assign_true_vec) %>%
  as.matrix()

# Same stacking and row order for the responses.
stack_Y <- data_k %>%
  purrr::map(function(.sim) as_tibble(.sim[["Y"]])) %>%
  bind_rows() %>%
  slice(clust_assign_true_vec) %>%
  as.matrix()
# Starting labels ----
# A uniformly random start (every cluster equally likely) is drawn first ...
init_int <- cumsum(rep(1 / k, k))
init_rand_assign <- runif(N)
clust_assign <- vapply(
  init_rand_assign,
  function(u) min(which(u <= init_int)),
  integer(1)
)
# ... and then replaced by the true labels, so the run actually starts from
# the truth; the random start above only remains as an alternative.
clust_assign <- clust_assign_true
# Perturb the starting labels ----
# `num_pert` distinct observations are moved to a cluster other than their
# current one.
pert_samples <- sample(1:N, num_pert)
unique_vals <- 1:k
for (obs in pert_samples) {
  orig <- clust_assign[obs]
  vals <- unique_vals[unique_vals != orig]
  if (length(vals) == 0) {
    # Only one cluster exists, so there is nothing to switch to.
    next
  }
  if (length(vals) == 1) {
    # Assign directly: sample() on a single number n would draw from 1:n.
    clust_assign[obs] <- vals
  } else {
    # Draw exactly one alternative label; sample(vals) would return a whole
    # permutation, which cannot be stored in a single element.
    clust_assign[obs] <- vals[sample.int(length(vals), 1)]
  }
}
# Iterative clustering ----
# Each pass (1) fits one SARRS() model per cluster on the observations
# currently assigned to it, (2) scores every observation under every cluster
# model with a Gaussian log-likelihood, and (3) reassigns each observation to
# its best-scoring cluster. The loop ends when no label changes or after
# `maxiter` passes.
conv <- Inf
main_clust <- Inf
iter <- 0
while (conv > 0 && iter < maxiter) {
  iter <- iter + 1

  # Share of observations currently assigned to each cluster.
  pi_vec <- purrr::map_dbl(
    seq_len(k),
    function(.x) sum(clust_assign == .x) / N
  )

  # Predictors with a trailing column of ones. This is the same for every
  # cluster, so it is built once per pass.
  # NOTE(review): assumes the last row of Ahat is the intercept - confirm
  # against SARRS().
  X_aug <- cbind(stack_X, int = rep(1, N))
  resp_idx <- seq_len(m)

  # N x k table of log-likelihoods, one column (c_1, ..., c_k) per cluster.
  gamma <- seq_len(k) %>%
    purrr::map(function(.x) {
      members <- which(clust_assign == .x)
      X_k <- stack_X[members, , drop = FALSE]
      Y_k <- stack_Y[members, , drop = FALSE]
      # P=X%*% ginv(t(X)%*%X) %*%t(X)
      # rhat = sum(svd(P%*%Y)$d > sigmahat*(sqrt(2*m)+sqrt(2*min(n,p))))
      rhat <- r[.x]
      estimate <- SARRS(Y_k, X_k, rhat, lam, "grLasso")
      A_k <- purrr::pluck(estimate, "Ahat")
      sig_vec <- purrr::pluck(estimate, "sigvec")

      # Fitted means of ALL observations under this cluster's model.
      mu_mat <- X_aug %*% A_k

      # Row-wise sum over the m responses of the Gaussian log-density, with
      # response j using sig_vec[j] as its standard deviation. log = TRUE
      # computes the log-density directly, so far-away observations get a
      # large negative finite value instead of log(0) = -Inf.
      sd_mat <- matrix(
        as.numeric(sig_vec)[resp_idx],
        nrow = N, ncol = m, byrow = TRUE
      )
      log_dens <- dnorm(
        stack_Y[, resp_idx, drop = FALSE],
        mean = mu_mat[, resp_idx, drop = FALSE],
        sd = sd_mat,
        log = TRUE
      )
      gam <- rowSums(matrix(log_dens, nrow = N))
      tibble(gam) %>% setNames(names[.x])
    }) %>%
    bind_cols()

  gamma_mat <- as.matrix(gamma)

  # Per-observation sum over clusters of pi_k * log-likelihood_k. The matrix
  # product pairs pi_vec with the cluster COLUMNS; `gamma * pi_vec` on a data
  # frame would recycle pi_vec down the rows instead. This value is only
  # stored for the diagnostic plots; it does not drive the reassignment.
  weighted_ll <- drop(gamma_mat %*% pi_vec)

  gamma_store <- gamma_store %>%
    bind_rows(tibble(gamma, w_ll = weighted_ll, iter = iter))

  # Hard reassignment: each observation goes to the cluster with the highest
  # log-likelihood (first one on ties).
  clust_assign_old <- clust_assign
  clust_assign <- purrr::map_int(
    seq_len(N),
    function(.x) which.max(gamma_mat[.x, ])
  )

  # Number of observations whose label changed in this pass.
  conv <- sum(clust_assign != clust_assign_old)
  print((table(clust_assign_true, clust_assign)))
  print(paste("Iter:",iter,"| conv:",conv))
}
# Match estimated labels to true labels ----
# Cluster numbering is arbitrary, so the columns of the confusion table are
# permuted to maximise agreement with the true labels. Both factors carry all
# k levels so the table stays k x k (as solve_LSAP() needs) even if a cluster
# ended up empty.
confusion <- table(
  clust_assign_true = factor(clust_assign_true, levels = seq_len(k)),
  clust_assign = factor(clust_assign, levels = seq_len(k))
)
shuffle <- clue::solve_LSAP(confusion, maximum = TRUE)
(confusion[, shuffle])
# Diagnostics ----
# Column totals of the stored log-likelihoods for every pass of the loop; the
# iter = 0 seed row is dropped. Computed once and shared by both plots.
gamma_by_iter <- gamma_store %>%
  filter(iter > 0) %>%
  group_by(iter) %>%
  summarise(across(everything(), sum))

# Total log-likelihood per cluster column against the iteration number.
gamma_by_iter %>%
  tidyr::pivot_longer(
    cols = tidyselect::starts_with("c"),
    names_to = "cluster"
  ) %>%
  ggplot() +
  geom_path(aes(x = iter, y = value, colour = cluster))

# Total weighted log-likelihood against the iteration number.
gamma_by_iter %>%
  ggplot() +
  geom_path(aes(x = iter, y = w_ll))
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.