# R/machine_split.R

# Defines functions machine_split

# Documented in machine_split

#' Split a data set into training, validation and (optionally) test sets
#'
#' Rows are first nested into units, one per unique combination of the
#' `behaviour` and `group` columns (by default all rows sharing one
#' timestamp). Units are then sampled within each behaviour, so every split
#' keeps roughly the same behaviour mix. Any unit whose `group` value was
#' drawn for training is excluded from validation and test. This match is on
#' `group` alone, across all behaviours. Each resulting set is unnested back
#' to rows and shuffled.
#'
#' @param data A data frame.
#' @param group Name of the column whose values define the units that are
#'   never split across sets. Defaults to `"timestamp"`.
#' @param behaviour Name of the column holding the class label; sampling is
#'   stratified by this column.
#' @param train_size,val_size Proportions (rounded to two decimals) of the
#'   units per behaviour used for training and validation. Their sum must not
#'   exceed 1; whatever is left over becomes the test set.
#' @param seed Seed passed to [set.seed()] so the split is reproducible. Note
#'   that this resets the global RNG state.
#' @param names Names of the objects created in the global environment for the
#'   training, validation and test sets. The third name is only used when
#'   `train_size + val_size < 1`.
#' @return Invisibly, the validation set. The function is mainly called for its
#'   side effect of assigning the splits into the global environment.
#' @export
machine_split <- function(data, group = "timestamp", behaviour, train_size,
                          val_size, seed = 142,
                          names = c("train_data", "val_data", "test_data")) {
  is_scalar_number <- function(x) {
    is.numeric(x) && length(x) == 1L && !is.na(x)
  }
  if (!is_scalar_number(train_size) || !is_scalar_number(val_size)) {
    stop("train_size and val_size must each be a single number",
         call. = FALSE)
  }
  train_size <- round(x = train_size, digits = 2)
  val_size <- round(x = val_size, digits = 2)

  # Work in whole percentages so the sum and the ratio below are exact.
  # In doubles, 0.1 / (1 - 0.9) is slightly above 1, which made
  # sample_frac() refuse a perfectly valid 90/10 split.
  train_pct <- round(train_size * 100)
  val_pct <- round(val_size * 100)
  if (train_pct < 0 || val_pct < 0) {
    stop("train_size and val_size must not be negative", call. = FALSE)
  }
  total_pct <- train_pct + val_pct
  if (total_pct > 100) {
    stop("train_size + val_size must not be greater than 1", call. = FALSE)
  }
  has_test <- total_pct < 100

  n_needed <- if (has_test) 3L else 2L
  if (!is.character(names) || length(names) < n_needed) {
    stop("`names` must supply a name for each of the ", n_needed, " sets",
         call. = FALSE)
  }

  # Ungroup, shuffle the unit order, then expand units back into rows.
  # The RNG is consumed in the same order as before, so seeded results match.
  shuffle_unnest <- function(nested) {
    nested <- dplyr::ungroup(nested)
    nested <- dplyr::slice(nested, sample(nrow(nested)))
    tidyr::unnest(nested, cols = c(data))
  }

  set.seed(seed)

  # One row per behaviour x group unit, with the unit's rows in list-column
  # `data`; regroup by behaviour so sampling is stratified.
  nested_data <- data %>%
    dplyr::group_by(dplyr::across(dplyr::all_of(c(behaviour, group)))) %>%
    tidyr::nest() %>%
    dplyr::ungroup() %>%
    dplyr::group_by(dplyr::across(dplyr::all_of(behaviour)))

  train_data <- nested_data %>%
    dplyr::sample_frac(size = train_size) %>%
    shuffle_unnest()

  remaining <- nested_data[!nested_data[[group]] %in% train_data[[group]], ]

  # Share of the *remaining* units needed so validation is about val_size of
  # the whole. It is capped at 1, and is 0 when training took everything.
  remaining_pct <- 100 - train_pct
  val_frac <- if (remaining_pct > 0) min(1, val_pct / remaining_pct) else 0

  val_data <- remaining %>%
    dplyr::sample_frac(size = val_frac) %>%
    shuffle_unnest()

  if (has_test) {
    test_data <- remaining[!(remaining[[group]] %in% val_data[[group]]), ] %>%
      shuffle_unnest()

    assign(names[3], test_data, envir = .GlobalEnv)
  }

  assign(names[1], train_data, envir = .GlobalEnv)
  assign(names[2], val_data, envir = .GlobalEnv)
}
# wanjarast/accelerateR documentation built on June 21, 2022, 3:29 p.m.