# experiment/data_synthesis_accuracy_experiment/synth_experiment.R

library(palmerpenguins)
library(tidyverse)
library(scatteR)
# library(parallel)
library(here)

#' Draw a polar "art" plot from scatteR-synthesised column pairs
#'
#' For every unordered pair of numeric columns in `data`, the scagnostics of
#' the pair are computed and handed to `scatteR::scatteR()` to generate a
#' synthetic scatterplot with as many points as `data` has rows. All synthetic
#' scatterplots are stacked and drawn in polar coordinates, coloured by the
#' second column of each pair.
#'
#' The synthetic points are plotted on whatever scale `scatteR::scatteR()`
#' returns them; they are not mapped back to the range of the original
#' columns.
#'
#' @param data A data frame with at least two numeric columns.
#' @return A ggplot object.
gen_art_data <- function(data) {
  cols <- colnames(data %>% select(where(is.numeric)))
  if (length(cols) < 2) {
    stop("`data` must contain at least two numeric columns.", call. = FALSE)
  }
  # One row per unordered pair of numeric columns.
  results <- t(combn(cols, 2))

  # Iterate sequentially: the original call to `future_apply()` relied on the
  # future.apply package, which this script never attaches.
  synth_df <- bind_rows(lapply(seq_len(nrow(results)), function(i) {
    x <- results[i, 1]
    y <- results[i, 2]
    n <- length(data[[x]])
    # `init_points` is chosen among the divisors of `n`: the one sitting at the
    # 75% position of the ordered divisor list.
    # NOTE(review): presumably scatteR needs `n_points` to be a multiple of
    # `init_points` -- confirm against the scatteR documentation.
    divisors <- which(n %% seq_len(n) == 0)
    d <- scatteR::scatteR(
      scagnostics::scagnostics(data[[x]], data[[y]]),
      n_points = n,
      init_points = divisors[ceiling(quantile(seq_along(divisors), 0.75))],
      epochs = 200
    )
    # Tag every synthetic point with the pair it was generated from.
    d |> mutate(comb = i, comb_x = x, comb_y = y)
  }))

  synth_df %>%
    ggplot(aes(x, y, color = factor(comb_y))) +
    geom_jitter() +
    geom_smooth(se = FALSE, method = "lm") +
    coord_polar(theta = "x", start = 0.1) +
    theme_void() +
    theme(legend.position = "none") +
    scale_color_brewer(palette = "Reds")
}

#' Synthesise every numeric column pair with scatteR and score the result
#'
#' For each unordered pair of numeric columns in `data`, a synthetic
#' scatterplot with the same number of points is generated from the pair's
#' scagnostics, mapped back onto the range of the original columns, and
#' compared with the original points using the 2-Wasserstein distance.
#'
#' @param data A data frame with at least two numeric columns. The columns
#'   must not contain `NA`, because `min()`/`max()` are taken without
#'   `na.rm`.
#' @return A list with three elements:
#'   * `results`: a tibble with one row per column pair (`X`, `Y`) and the
#'     Wasserstein `distance` between original and synthetic points.
#'   * `gen`: all synthetic points stacked, tagged with `comb` (pair index),
#'     `comb_x` and `comb_y` (column names of the pair).
#'   * `synth_data`: a wide tibble with one column per variable, obtained by
#'     taking, rank by rank, the median of the synthetic values generated for
#'     that variable across all pairs it took part in.
check_synthesis <- function(data) {
  cols <- colnames(data %>% select(where(is.numeric)))
  if (length(cols) < 2) {
    stop("`data` must contain at least two numeric columns.", call. = FALSE)
  }
  # Keep the pairs in a tibble with named X/Y columns: later steps look the
  # pair up by name and attach the distance with `mutate()`, neither of which
  # works on the bare character matrix returned by `combn()`.
  pairs <- combn(cols, 2)
  results <- tibble(X = pairs[1, ], Y = pairs[2, ])

  exp_tib <- map2(results$X, results$Y, function(x, y) {
    n <- length(data[[x]])
    # `init_points` is chosen among the divisors of `n`: the one sitting at the
    # 75% position of the ordered divisor list.
    # NOTE(review): presumably scatteR needs `n_points` to be a multiple of
    # `init_points` -- confirm against the scatteR documentation.
    divisors <- which(n %% seq_len(n) == 0)
    d <- scatteR::scatteR(
      scagnostics::scagnostics(data[[x]], data[[y]]),
      n_points = n,
      init_points = divisors[ceiling(quantile(seq_along(divisors), 0.75))],
      epochs = 200
    )
    # Map the synthetic coordinates onto the range of the original columns.
    # NOTE(review): this assumes scatteR returns x/y on a 0-1 scale -- confirm.
    destd_x <- (d$x * (max(data[[x]]) - min(data[[x]]))) + min(data[[x]])
    destd_y <- (d$y * (max(data[[y]]) - min(data[[y]]))) + min(data[[y]])
    d$x <- destd_x
    d$y <- destd_y
    w <- transport::wasserstein(
      transport::pp(cbind(destd_x, destd_y)),
      transport::pp(cbind(data[[x]], data[[y]])),
      p = 2
    )
    list(scatter = d, wass_dist = w)
  })

  # Stack all synthetic scatterplots, tagging each point with its pair.
  gen <- imap_dfr(exp_tib, function(ls, ind) {
    ls$scatter |> mutate(
      comb = ind,
      comb_x = results$X[[ind]],
      comb_y = results$Y[[ind]]
    )
  })

  # A variable can appear as `y` in several pairs: rank its synthetic values
  # within each pair, then take the median across pairs at every rank.
  from_y <- gen %>%
    group_by(comb_y, comb_x) %>%
    mutate(rank_y = rank(y)) %>%
    ungroup() %>%
    group_by(comb_y, rank_y) %>%
    summarize(value = median(y, na.rm = TRUE), .groups = "drop") %>%
    rename(variable = comb_y) %>%
    mutate(origin = "y") %>%
    select(variable, value, origin)

  # Same aggregation for the pairs in which the variable appears as `x`.
  from_x <- gen %>%
    group_by(comb_x, comb_y) %>%
    mutate(rank_x = rank(x)) %>%
    ungroup() %>%
    group_by(comb_x, rank_x) %>%
    summarize(value = median(x, na.rm = TRUE), .groups = "drop") %>%
    rename(variable = comb_x) %>%
    mutate(origin = "x") %>%
    select(variable, value, origin)

  # Merge the x- and y-based estimates of each variable rank by rank and
  # spread the variables into columns.
  synth_data <- bind_rows(from_y, from_x) %>%
    group_by(variable, origin) %>%
    mutate(rank_val = rank(value)) %>%
    ungroup() %>%
    group_by(variable, rank_val) %>%
    summarize(value = median(value, na.rm = TRUE), .groups = "drop") %>%
    pivot_wider(
      id_cols = rank_val,
      names_from = variable,
      values_from = value
    ) %>%
    select(-rank_val)

  results <- results %>%
    mutate(distance = map_dbl(exp_tib, function(res) res$wass_dist))

  list(results = results, gen = gen, synth_data = synth_data)
}

# Dataset to run the experiment on: one of "mtcars", "adult" or "penguins".
dataset <- "penguins"

# Each branch reads the dataset's CSV from the experiment folder, runs
# check_synthesis() on it and stores the result next to the CSV as an .rds
# file.
switch(
  dataset,
  mtcars = {
    mtcars <- read_csv(
      here("experiment", "data_synthesis_accuracy_experiment", "mtcars.csv")
    )
    # Only the first four columns are synthesised.
    mtsynth <- check_synthesis(mtcars[, 1:4])
    saveRDS(
      mtsynth,
      file = here(
        "experiment", "data_synthesis_accuracy_experiment", "mtsynth.rds"
      )
    )
    # mtsynth$results %>%
    #   ggplot(aes(x = X,y = Y,fill = distance))+
    #   geom_tile()+
    #   geom_label(aes(label = round(distance,2)),color="white")+
    #   theme_minimal()
  },
  penguins = {
    penguins <- read_csv(
      here("experiment", "data_synthesis_accuracy_experiment", "penguins.csv")
    )
    # Rows with missing values and the `year` column are dropped first.
    penguinsynth <- check_synthesis(penguins |> drop_na() |> select(-year))
    saveRDS(
      penguinsynth,
      file = here(
        "experiment", "data_synthesis_accuracy_experiment", "penguinsynth.rds"
      )
    )
    # penguinsynth %>%
    #   ggplot(aes(x = X,y = Y,fill = distance))+
    #   geom_tile()+
    #   geom_label(aes(label = round(distance,2)),color="white")+
    #   theme_minimal() +
    #   labs(title = "Wasserstein distance metrics for synthetic bivariate data generated by scatteR",subtitle = "Original number of data points were synthesized",x = "Variable 1",y = "Variable 2",fill = "Distance")
  },
  adult = {
    adults <- read_csv(
      here("experiment", "data_synthesis_accuracy_experiment", "adult.csv")
    )
    adultsynth <- check_synthesis(adults)
    saveRDS(
      adultsynth,
      file = here(
        "experiment", "data_synthesis_accuracy_experiment", "adultsynth.rds"
      )
    )
  }
)
# janithwanni/scatteR documentation built on March 1, 2023, 6:08 a.m.