View source: R/dataset_methods.R
dataset_rejection_resample | R Documentation |
A transformation that resamples a dataset to a target distribution.
dataset_rejection_resample(
  dataset,
  class_func,
  target_dist,
  initial_dist = NULL,
  seed = NULL,
  name = NULL
)
dataset |
A |
class_func |
A function mapping an element of the input dataset to a
scalar |
target_dist |
A floating point type tensor, shaped |
initial_dist |
(Optional.) A floating point type tensor, shaped
|
seed |
(Optional.) Integer seed for the resampler. |
name |
(Optional.) A name for the tf.data operation. |
A tf.Dataset
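To make the idea of resampling to a target distribution concrete, the following is a minimal conceptual sketch in base R, with no TensorFlow involved. It assumes 1-based integer class labels and a known `initial_dist`. The helper `rejection_resample_sketch` is hypothetical and is not part of tfdatasets; the actual tf.data operation may use a different scheme and need not return the same number of elements.

```r
# Conceptual sketch only: plain rejection sampling on a vector of class labels.
# `rejection_resample_sketch` is a hypothetical helper, not a tfdatasets function.
rejection_resample_sketch <- function(classes, initial_dist, target_dist) {
  # How strongly each class must be up- or down-weighted.
  ratio <- target_dist / initial_dist
  # Scale so the class with the largest ratio is always accepted.
  accept_prob <- ratio / max(ratio)
  # Keep each element with the acceptance probability of its class.
  keep <- runif(length(classes)) < accept_prob[classes]
  classes[keep]
}

set.seed(42)
initial_dist <- c(.5, .5)
target_dist <- c(.6, .4)
classes <- sample.int(2, 100000, prob = initial_dist, replace = TRUE)

resampled <- rejection_resample_sketch(classes, initial_dist, target_dist)

# Observed class proportions after resampling; expected to be close to target_dist.
observed <- as.numeric(table(factor(resampled, levels = 1:2))) / length(resampled)
observed
```

In this sketch, elements are only ever dropped, so the result is never longer than the input. That is a property of the sketch, not a statement about the tf.data operation.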
## Not run:
initial_dist <- c(.5, .5)
target_dist <- c(.6, .4)
num_classes <- length(initial_dist)
num_samples <- 100000
data <- sample.int(num_classes, num_samples, prob = initial_dist, replace = TRUE)
dataset <- tensor_slices_dataset(data)

tally <- c(0, 0)
# Replacement function so that `add(x) <- value` increments x by value.
`add<-` <- function (x, value) x + value

# tfautograph::autograph({
#   for(i in dataset)
#     add(tally[as.numeric(i)]) <- 1
# })
dataset %>%
  as_array_iterator() %>%
  iterate(function(i) {
    add(tally[i]) <<- 1
  }, simplify = FALSE)

# The value of `tally` will be close to c(50000, 50000) as
# per the `initial_dist` distribution.
tally # c(50287, 49713)

tally <- c(0, 0)
dataset %>%
  dataset_rejection_resample(
    class_func = function(x) (x-1) %% 2,
    target_dist = target_dist,
    initial_dist = initial_dist
  ) %>%
  as_array_iterator() %>%
  iterate(function(element) {
    names(element) <- c("class_id", "i")
    add(tally[element$i]) <<- 1
  }, simplify = FALSE)

# The value of `tally` will now be close to c(75000, 50000),
# thus satisfying the `target_dist` distribution.
tally # c(74822, 49921)
## End(Not run)
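The final counts in the example are raw tallies, so it may not be obvious that they match `target_dist`. A short check in base R, using the counts quoted in the example's comment (copied here as literals, not recomputed), normalizes them to proportions:

```r
# Counts copied from the comment in the example; not recomputed here.
target_dist <- c(.6, .4)
tally <- c(74822, 49921)

# Convert raw counts to class proportions.
proportions <- tally / sum(tally)
proportions

# Largest absolute deviation from the target distribution.
max_deviation <- max(abs(proportions - target_dist))
max_deviation
```

The proportions come out at roughly 0.6 and 0.4, which is what "satisfying the `target_dist` distribution" refers to.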