paired_sampler: Sampling Paired Observations

View source: R/approach_vaeac_torch_modules.R

paired_sampler    R Documentation

Description

A sampler used to sample batches in which each instance is sampled twice.

Usage

paired_sampler(vaeac_dataset_object, shuffle = FALSE)

Arguments

vaeac_dataset_object

A vaeac_dataset() object containing the data.

shuffle

Boolean. If TRUE, the data is shuffled. If FALSE, the data is returned in the order in which it is stored in the dataset.

Details

A sampler object that allows for paired sampling by always including each observation from the vaeac_dataset() twice. A torch::sampler() object can be used with torch::dataloader() when creating batches from a torch dataset, torch::dataset(). See https://rdrr.io/cran/torch/src/R/utils-data-sampler.R for more information. This sampler does not use batch iterators; using them might increase the speed.
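
The pairing described above can be illustrated with a minimal base R sketch. This is not the shapr implementation: paired_indices() is a hypothetical helper, and it assumes that the two copies of each observation are emitted next to each other, so that an even batch size keeps every pair inside the same batch.

```r
# Hypothetical helper (not part of shapr or torch): builds the order of
# observation indices that a paired sampler could emit.
# Assumption: the two copies of each observation are adjacent.
paired_indices <- function(n_observations, shuffle = FALSE) {
  # Either a random permutation of the observations or their stored order
  base_order <- if (shuffle) sample.int(n_observations) else seq_len(n_observations)
  # Repeat every index twice, keeping the two copies together
  rep(base_order, each = 2)
}

idx <- paired_indices(5)
idx # 1 1 2 2 3 3 4 4 5 5

# With batch_size = 4, the 2 * 5 = 10 indices are split into 3 batches
n_batches <- ceiling(length(idx) / 4)
n_batches # 3
```

Under this assumption, an odd batch size would split a pair across two batches, which is why the example below forces the batch size to be even.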

Author(s)

Lars Henry Berge Olsen

Examples

## Not run: 
# Example of how to use the sampler together with a mask generator with paired sampling activated
batch_size <- 4
if (batch_size %% 2 == 1) batch_size <- batch_size - 1 # Make sure that batch size is even
n_features <- 3
n_observations <- 5
shuffle <- TRUE
data <- torch::torch_tensor(matrix(rep(seq(n_observations), each = n_features),
  ncol = n_features, byrow = TRUE
))
data
dataset <- vaeac_dataset(data, rep(1, n_features))
dataload <- torch::dataloader(dataset,
  batch_size = batch_size,
  sampler = paired_sampler(dataset,
    shuffle = shuffle
  )
)
dataload$.length() # Number of batches, same as ceiling((2 * n_observations) / batch_size)
mask_generator <- mcar_mask_generator(paired = TRUE)
coro::loop(for (batch in dataload) {
  mask <- mask_generator(batch)
  obs <- mask * batch
  print(torch::torch_cat(c(batch, mask, obs), -1))
})

## End(Not run)

NorskRegnesentral/shapr documentation built on April 19, 2024, 1:19 p.m.