R/make_dataset.R

Defines functions make_dataset

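make_dataset <- function(data, x, device) {
	# Placeholder binding so R CMD check does not report `self` as an
	# undefined global; inside the methods below, `self` is the dataset object
	self <- NULL
	dataset <- torch::dataset(
		name = "tmp_crumble_dataset",
		initialize = function(data, x, device) {
			# `data` is a named list of data frames. Each element with at least
			# one column is subset to the columns in `x`, passed through
			# one_hot_encode() and as_torch(), and stored under the same name.
			for (df in names(data)) {
				if (ncol(data[[df]]) > 0) {
					df_x <- data[[df]][, x, drop = FALSE]
					self[[df]] <- one_hot_encode(df_x) |>
						as_torch(device = device)
				}
			}
		},
		.getitem = function(i) {
			# Return row `i` of every stored field whose name contains "data",
			# as a list named by field
			fields <- grep("data", names(self), value = TRUE)
			setNames(lapply(fields, function(x) self[[x]][i, ]), fields)
		},
		.length = function() {
			# Number of rows in the `data` field, which must always be present
			self$data$size()[1]
		}
	)
	# Instantiate the dataset generator with the supplied arguments
	dataset(data, x, device)
}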
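Running `make_dataset()` requires torch and crumble's internal helpers `one_hot_encode()` and `as_torch()`, so the following is a base-R sketch of the same logic, meant only to show what the dataset stores and what `.getitem()` returns. `one_hot_sketch()` and `make_dataset_sketch()` are hypothetical stand-ins: the sketch assumes character/factor columns become one 0/1 column per level and numeric columns pass through unchanged, computes levels separately for each data frame, and keeps plain matrices instead of tensors. The real `one_hot_encode()` may behave differently.

```r
# Hypothetical stand-in for crumble's internal one_hot_encode(): numeric
# columns are kept as-is, character/factor columns become one 0/1 column
# per level (levels are computed per data frame in this sketch).
one_hot_sketch <- function(df) {
  cols <- lapply(names(df), function(nm) {
    col <- df[[nm]]
    if (is.factor(col) || is.character(col)) {
      col <- factor(col)
      m <- outer(as.character(col), levels(col), "==") * 1
      colnames(m) <- paste0(nm, "_", levels(col))
      m
    } else {
      matrix(as.numeric(col), ncol = 1, dimnames = list(NULL, nm))
    }
  })
  do.call(cbind, cols)
}

# Hypothetical mimic of make_dataset(): matrices instead of torch tensors,
# and a plain list of closures instead of a torch::dataset object.
make_dataset_sketch <- function(data, x) {
  fields <- list()
  for (df in names(data)) {
    # Skip data frames with no columns, as the original does
    if (ncol(data[[df]]) > 0) {
      fields[[df]] <- one_hot_sketch(data[[df]][, x, drop = FALSE])
    }
  }
  list(
    # Row count of the required `data` field
    length = function() nrow(fields$data),
    # Row i of every field whose name contains "data", named by field
    getitem = function(i) {
      keep <- grep("data", names(fields), value = TRUE)
      setNames(lapply(keep, function(f) fields[[f]][i, ]), keep)
    }
  )
}

example_data <- list(
  data = data.frame(a = c(1, 2, 3), g = c("x", "y", "x")),
  data_0 = data.frame(a = c(0, 0, 0), g = c("x", "x", "y")),
  data_1 = data.frame()  # no columns, so it is skipped
)
ds <- make_dataset_sketch(example_data, x = c("a", "g"))
item <- ds$getitem(2)
print(item)
```

Each item is a list with one entry per `data*` field, all taken from the same row index, which matches how the original `.getitem()` pairs rows across the observed data and its counterparts.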

crumble documentation built on April 13, 2025, 5:10 p.m.