# R/utils-data-fetcher.R

# Abstract base class for dataset fetchers.
#
# Stores the dataset together with the collation settings that subclasses
# read, and forwards state (de)serialisation to the wrapped dataset.
# Subclasses are expected to override `fetch()`.
BaseDatasetFetcher <- R6::R6Class(
  classname = "BaseDatasetFetcher",
  lock_objects = FALSE,
  public = list(
    # @param dataset The dataset the fetcher reads from.
    # @param auto_collation Flag consulted by subclasses to decide whether
    #   `fetch()` receives a batch of indices or a single index.
    # @param collate_fn Function subclasses apply to the fetched data.
    # @param drop_last Flag consulted by subclasses when a final batch is
    #   incomplete.
    initialize = function(dataset, auto_collation, collate_fn, drop_last) {
      self$dataset <- dataset
      self$auto_collation <- auto_collation
      self$collate_fn <- collate_fn
      self$drop_last <- drop_last
    },
    # Abstract: signals `not_implemented_error()` until a subclass overrides it.
    fetch = function(possibly_batched_index) {
      not_implemented_error()
    },
    # Delegates to the wrapped dataset's own `state_dict()`.
    state_dict = function() {
      ds <- self$dataset
      ds$state_dict()
    },
    # Delegates to the wrapped dataset's own `load_state_dict()`, forwarding
    # any extra arguments unchanged.
    load_state_dict = function(x, ...) {
      ds <- self$dataset
      ds$load_state_dict(x, ...)
    }
  )
)

# Fetcher for iterable-style datasets.
#
# Items are pulled sequentially from one iterator built over the dataset with
# `as_iterator()`. The index passed to `fetch()` is only used for its length
# (the number of items to pull) when auto-collation is enabled.
IterableDatasetFetcher <- R6::R6Class(
  classname = "IterableDatasetFetcher",
  lock_objects = FALSE,
  inherit = BaseDatasetFetcher,
  public = list(
    initialize = function(dataset, auto_collation, collate_fn, drop_last) {
      super$initialize(dataset, auto_collation, collate_fn, drop_last)
      # A single iterator is kept for the fetcher's lifetime so consecutive
      # calls to `fetch()` continue where the previous call stopped.
      self$dataset_iter <- as_iterator(self$dataset)
    },
    # Pull the next item (or batch of items) from the dataset iterator.
    #
    # @param possibly_batched_index With auto-collation, a vector/list whose
    #   length is the number of items to pull; otherwise ignored.
    # @return `collate_fn()` applied to the fetched data, or
    #   `coro::exhausted()` when nothing is left to yield. That includes an
    #   incomplete final batch when `drop_last` is TRUE.
    fetch = function(possibly_batched_index) {
      if (self$auto_collation) {
        n <- length(possibly_batched_index)
        data <- vector(mode = "list", length = n)
        for (i in seq_len(n)) {
          d <- self$dataset_iter()

          if (is_exhausted(d)) {
            if (self$drop_last || i == 1L) {
              return(coro::exhausted())
            }
            # Keep only the items fetched before the iterator ran out.
            data <- data[seq_len(i - 1L)]
            break
          }

          # `data[[i]] <- NULL` would delete slot i and misalign the batch;
          # wrapping in list() stores NULL items in place instead.
          data[i] <- list(d)
        }
      } else {
        data <- self$dataset_iter()
        # Propagate exhaustion instead of handing the sentinel to
        # `collate_fn()`, mirroring the batched branch above.
        if (is_exhausted(data)) {
          return(coro::exhausted())
        }
      }

      self$collate_fn(data)
    }
  )
)

# Fetcher for map-style datasets, whose items are looked up with `[`.
MapDatasetFetcher <- R6::R6Class(
  classname = "MapDatasetFetcher",
  lock_objects = FALSE,
  inherit = BaseDatasetFetcher,
  public = list(
    # Forwards all arguments to `BaseDatasetFetcher$initialize()`.
    initialize = function(dataset, auto_collation, collate_fn, drop_last) {
      super$initialize(dataset, auto_collation, collate_fn, drop_last)
    },
    # Look up the requested item(s) and collate them.
    #
    # @param possibly_batched_index With auto-collation, a vector/list of
    #   indices, each looked up individually with `dataset[idx]`; otherwise a
    #   single index passed straight to `[`.
    # @return `collate_fn()` applied to the fetched item(s).
    fetch = function(possibly_batched_index) {
      if (self$auto_collation) {
        dataset <- self$dataset
        # lapply keeps one slot per index even when an item is NULL, unlike
        # `data[[i]] <- NULL`, which deletes the slot and misaligns the batch.
        # Iterating over positions keeps the result unnamed.
        data <- lapply(
          seq_len(length(possibly_batched_index)),
          function(i) dataset[possibly_batched_index[[i]]]
        )
      } else {
        data <- self$dataset[possibly_batched_index]
      }
      self$collate_fn(data)
    }
  )
)

# Try the torch package in your browser

# Any scripts or data that you put into this service are public.

# torch documentation built on May 29, 2024, 9:54 a.m.