dataset_rejection_resample: A transformation that resamples a dataset to a target...

View source: R/dataset_methods.R

Description

A transformation that resamples a dataset to a target distribution.

Usage

dataset_rejection_resample(
  dataset,
  class_func,
  target_dist,
  initial_dist = NULL,
  seed = NULL,
  name = NULL
)

Arguments

dataset

A tf.Dataset

class_func

A function mapping an element of the input dataset to a scalar tf.int32 tensor. Values should be in [0, num_classes).

target_dist

A floating point type tensor, shaped [num_classes].

initial_dist

(Optional.) A floating point type tensor, shaped [num_classes]. If not provided, the true class distribution is estimated live in a streaming fashion.

seed

(Optional.) Integer seed for the resampler.

name

(Optional.) A name for the tf.data operation.
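The idea behind the arguments above can be sketched in base R. This is a textbook rejection-sampling illustration only, not the TensorFlow implementation: the helper name `rejection_resample_sketch` and its acceptance rule are assumptions made for this example, and it works on a plain R vector rather than a tf.Dataset. Each element is kept with a probability proportional to `target_dist / initial_dist` for its class, scaled so the largest probability is 1.

```r
# Hypothetical helper (base R only); not part of tfdatasets.
rejection_resample_sketch <- function(x, class_func, target_dist,
                                      initial_dist, seed = NULL) {
  if (!is.null(seed)) set.seed(seed)
  # Ratio of desired to observed frequency for each class.
  ratio <- target_dist / initial_dist
  # Scale so the class with the largest ratio is always kept.
  accept_prob <- ratio / max(ratio)
  # class_func returns 0-based class ids, as in the arguments above;
  # R vectors are 1-based, hence the + 1 when indexing.
  classes <- vapply(x, class_func, numeric(1))
  keep <- runif(length(x)) < accept_prob[classes + 1]
  x[keep]
}

initial_dist <- c(.5, .5)
target_dist <- c(.6, .4)
x <- sample.int(2, 100000, replace = TRUE, prob = initial_dist)

resampled <- rejection_resample_sketch(
  x,
  class_func = function(v) (v - 1) %% 2,
  target_dist = target_dist,
  initial_dist = initial_dist,
  seed = 42
)

# Empirical class proportions of the kept elements;
# these should be near target_dist.
proportions <- as.numeric(table(resampled)) / length(resampled)
```

In this sketch `initial_dist` must be supplied. The streaming estimate that dataset_rejection_resample() falls back on when `initial_dist = NULL` is not modelled here.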

Value

A tf.Dataset

Examples

## Not run: 
library(tfdatasets)
initial_dist <- c(.5, .5)
target_dist <- c(.6, .4)
num_classes <- length(initial_dist)
num_samples <- 100000
data <- sample.int(num_classes, num_samples, prob = initial_dist, replace = TRUE)
dataset <- tensor_slices_dataset(data)
tally <- c(0, 0)
`add<-` <- function (x, value) x + value
# tfautograph::autograph({
#   for(i in dataset)
#     add(tally[as.numeric(i)]) <- 1
# })
dataset %>%
  as_array_iterator() %>%
  iterate(function(i) {
    add(tally[i]) <<- 1
  }, simplify = FALSE)
# The value of `tally` will be close to c(50000, 50000) as
# per the `initial_dist` distribution.
tally # c(50287, 49713)

tally <- c(0, 0)
dataset %>%
  dataset_rejection_resample(
    class_func = function(x) (x-1) %% 2,
    target_dist = target_dist,
    initial_dist = initial_dist
  ) %>%
  as_array_iterator() %>%
  iterate(function(element) {
    names(element) <- c("class_id", "i")
    add(tally[element$i]) <<- 1
  }, simplify = FALSE)
# The value of `tally` will now be close to c(75000, 50000),
# thus satisfying the target_dist distribution.
tally # c(74822, 49921)

## End(Not run)

tfdatasets documentation built on Nov. 10, 2021, 1:07 a.m.