library(palmerpenguins)
library(tidyverse)
library(scatteR)
# library(parallel)
library(here)
#' Draw a polar-coordinate plot of scatteR-generated point clouds
#'
#' For every unordered pair of numeric columns in `data`, the scagnostics of
#' the pair are computed and handed to `scatteR::scatteR()`, which is asked to
#' generate as many points as `data` has rows. The generated clouds are
#' stacked into one data frame and drawn as a legend-free, axis-free polar
#' plot coloured by the second variable of each pair.
#'
#' The generated coordinates are plotted as returned by scatteR; the rescaling
#' to the original variable ranges is not applied in this function.
#'
#' @param data A data frame with at least two numeric columns.
#' @return A ggplot object.
gen_art_data <- function(data) {
  # Names of the numeric columns, in their original order.
  cols <- names(data)[vapply(data, is.numeric, logical(1))]
  if (length(cols) < 2) {
    stop("`data` must contain at least two numeric columns.", call. = FALSE)
  }

  # One row per unordered pair: column 1 is the x variable, column 2 the y.
  results <- t(combn(cols, 2))

  # Plain lapply(): the script never attaches a parallel backend, so the
  # pairs are processed sequentially.
  synth_ls <- lapply(seq_len(nrow(results)), function(i) {
    x <- results[i, 1]
    y <- results[i, 2]
    n_obs <- length(data[[x]])

    # init_points is taken from the divisors of the row count: the divisor
    # sitting at the 75% position of the sorted divisor list.
    divisors <- which(n_obs %% seq_len(n_obs) == 0)
    init_idx <- ceiling(quantile(seq_along(divisors), 0.75))

    # NOTE(review): assumes scatteR() returns a data frame with `x` and `y`
    # columns (they are used by mutate() and aes() below) -- confirm.
    scatteR::scatteR(
      scagnostics::scagnostics(data[[x]], data[[y]]),
      n_points = n_obs,
      init_points = divisors[init_idx],
      epochs = 200
    )
  })

  # Tag every generated cloud with its pair index and variable names, then
  # stack them.
  synth_df <- bind_rows(lapply(seq_along(synth_ls), function(i) {
    mutate(
      synth_ls[[i]],
      comb = i,
      comb_x = results[i, 1],
      comb_y = results[i, 2]
    )
  }))

  synth_df %>%
    ggplot(aes(x, y, color = factor(comb_y))) +
    geom_jitter() +
    geom_smooth(se = FALSE, method = "lm") +
    coord_polar(theta = "x", start = 0.1) +
    theme_void() +
    theme(legend.position = "none") +
    scale_color_brewer(palette = "Reds")
}
#' Synthesise bivariate data with scatteR and measure how close it is
#'
#' For every unordered pair of numeric columns in `data`:
#' 1. the scagnostics of the pair are passed to `scatteR::scatteR()` to
#'    generate as many points as `data` has rows;
#' 2. the generated `x`/`y` are mapped onto the min-max range of the original
#'    columns;
#' 3. the 2-Wasserstein distance between the rescaled generated points and
#'    the original points is computed with the transport package.
#'
#' The per-pair clouds are then collapsed into one synthetic column per
#' variable by taking medians across rank-aligned values.
#'
#' @param data A data frame with at least two numeric columns. The rescaling
#'   uses `min()`/`max()` without `na.rm`, so missing values should be removed
#'   by the caller.
#' @return A list with three elements:
#'   * `results`: tibble with columns `X`, `Y` (the variable pair) and
#'     `distance` (its Wasserstein distance);
#'   * `gen`: all generated points, tagged with `comb`, `comb_x`, `comb_y`;
#'   * `synth_data`: tibble with one column per variable holding the
#'     rank-wise medians of the generated values.
check_synthesis <- function(data) {
  # Names of the numeric columns, in their original order.
  cols <- names(data)[vapply(data, is.numeric, logical(1))]
  if (length(cols) < 2) {
    stop("`data` must contain at least two numeric columns.", call. = FALSE)
  }

  # One row per unordered pair. Stored as a tibble with named columns so the
  # pairs can be looked up as X / Y and extended with mutate() further down
  # (a bare matrix from combn() supports neither).
  pairs <- t(combn(cols, 2))
  results <- tibble(X = pairs[, 1], Y = pairs[, 2])

  exp_tib <- lapply(seq_len(nrow(results)), function(i) {
    x <- results$X[[i]]
    y <- results$Y[[i]]
    n_obs <- length(data[[x]])

    # init_points is taken from the divisors of the row count: the divisor
    # sitting at the 75% position of the sorted divisor list.
    divisors <- which(n_obs %% seq_len(n_obs) == 0)
    init_idx <- ceiling(quantile(seq_along(divisors), 0.75))

    d <- scatteR::scatteR(
      scagnostics::scagnostics(data[[x]], data[[y]]),
      n_points = n_obs,
      init_points = divisors[init_idx],
      epochs = 200
    )

    # Map the generated coordinates onto the original variables' ranges.
    # NOTE(review): presumably scatteR output lies in [0, 1] -- confirm.
    destd_x <- (d$x * (max(data[[x]]) - min(data[[x]]))) + min(data[[x]])
    destd_y <- (d$y * (max(data[[y]]) - min(data[[y]]))) + min(data[[y]])
    d$x <- destd_x
    d$y <- destd_y

    w <- transport::wasserstein(
      transport::pp(cbind(destd_x, destd_y)),
      transport::pp(cbind(data[[x]], data[[y]])),
      p = 2
    )
    list(scatter = d, wass_dist = w)
  })

  # Stack all generated clouds, tagging each with its pair index and names.
  gen <- bind_rows(lapply(seq_along(exp_tib), function(i) {
    mutate(
      exp_tib[[i]]$scatter,
      comb = i,
      comb_x = results$X[[i]],
      comb_y = results$Y[[i]]
    )
  }))

  # A variable used as `y` in several pairs has several generated versions.
  # Rank the values within each pair, then take the median across pairs at
  # each rank. rank() averages ties, so ranks may be fractional.
  y_medians <- gen %>%
    group_by(comb_y, comb_x) %>%
    mutate(rank_y = rank(y)) %>%
    group_by(comb_y, rank_y) %>%
    summarize(value = median(y, na.rm = TRUE), .groups = "drop") %>%
    mutate(origin = "y") %>%
    select(variable = comb_y, value, origin)

  # Same aggregation for the pairs in which a variable played the `x` role.
  x_medians <- gen %>%
    group_by(comb_x, comb_y) %>%
    mutate(rank_x = rank(x)) %>%
    group_by(comb_x, rank_x) %>%
    summarize(value = median(x, na.rm = TRUE), .groups = "drop") %>%
    mutate(origin = "x") %>%
    select(variable = comb_x, value, origin)

  # Merge the x-role and y-role versions of each variable the same way
  # (rank within origin, median across origins), then spread the variables
  # into columns.
  synth_data <- bind_rows(y_medians, x_medians) %>%
    group_by(variable, origin) %>%
    mutate(rank_val = rank(value)) %>%
    group_by(variable, rank_val) %>%
    summarize(value = median(value, na.rm = TRUE), .groups = "drop") %>%
    pivot_wider(
      id_cols = rank_val,
      names_from = variable,
      values_from = value
    ) %>%
    select(-rank_val)

  results <- mutate(
    results,
    distance = map_dbl(exp_tib, function(res) res$wass_dist)
  )

  list(results = results, gen = gen, synth_data = synth_data)
}
# Driver ----
# Name of the dataset to synthesise: one of "mtcars", "adult" or "penguins".
# Any other value makes the dispatch below a no-op.
dataset <- "penguins"

# For the chosen dataset: read its CSV from
# experiment/data_synthesis_accuracy_experiment (resolved through here()),
# run check_synthesis() on it and save the result as an .rds file in the same
# directory.
switch(dataset,
  mtcars = {
    mtcars <- read_csv(
      here("experiment", "data_synthesis_accuracy_experiment", "mtcars.csv")
    )
    # Only the first four columns are synthesised.
    mtsynth <- check_synthesis(mtcars[, 1:4])
    saveRDS(
      mtsynth,
      file = here(
        "experiment", "data_synthesis_accuracy_experiment", "mtsynth.rds"
      )
    )
    # Optional heat map of the pairwise distances:
    # mtsynth$results %>%
    #   ggplot(aes(x = X, y = Y, fill = distance)) +
    #   geom_tile() +
    #   geom_label(aes(label = round(distance, 2)), color = "white") +
    #   theme_minimal()
  },
  penguins = {
    penguins <- read_csv(
      here("experiment", "data_synthesis_accuracy_experiment", "penguins.csv")
    )
    # Rows with missing values and the `year` column are removed first.
    penguinsynth <- check_synthesis(select(drop_na(penguins), -year))
    saveRDS(
      penguinsynth,
      file = here(
        "experiment", "data_synthesis_accuracy_experiment", "penguinsynth.rds"
      )
    )
    # Optional heat map of the pairwise distances:
    # penguinsynth %>%
    #   ggplot(aes(x = X, y = Y, fill = distance)) +
    #   geom_tile() +
    #   geom_label(aes(label = round(distance, 2)), color = "white") +
    #   theme_minimal() +
    #   labs(title = "Wasserstein distance metrics for synthetic bivariate data generated by scatteR", subtitle = "Original number of data points were synthesized", x = "Variable 1", y = "Variable 2", fill = "Distance")
  },
  adult = {
    adults <- read_csv(
      here("experiment", "data_synthesis_accuracy_experiment", "adult.csv")
    )
    adultsynth <- check_synthesis(adults)
    saveRDS(
      adultsynth,
      file = here(
        "experiment", "data_synthesis_accuracy_experiment", "adultsynth.rds"
      )
    )
  }
)
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.