R/utils.R

Defines functions order_named_args get_example_batch CallbacksNone OptimizerNone LossNone check_callbacks check_nn_module replace_head auto_cache_lazy_tensors list_to_batch test_equal shape_to_str uniqueify argument_matcher test_equal_col_info get_nout load_col_info sample_input_from_shapes class_with_init default_task_id get_init assert_inherits_classname check_integerish_or_null check_function_or_null make_check_vector inferps running_on_mac auto_device

Documented in auto_device replace_head

#' Auto Device
#'
#' Resolves the device used for torch computations.
#' If `device` is `NULL` or `"auto"`, CUDA is used when available and the CPU otherwise.
#'
#' @param device (`character(1)` or `NULL`)\cr
#'   The device. If `NULL` or `"auto"`, it is auto-detected (first tries `"cuda"`, then `"cpu"`).
#'   Any other value is returned as is.
#' @return The resolved device.
#' @export
auto_device = function(device = NULL) {
  # `device == "auto"` yields logical(0) for the default `NULL`, which makes `if` fail,
  # so both auto-detection sentinels are checked explicitly.
  if (is.null(device) || identical(device, "auto")) {
    device = if (cuda_is_available()) "cuda" else "cpu"
    lg$debug("Auto-detected device '%s'.", device)
  }
  return(device)
}

# Returns whether the current OS reported by `Sys.info()` is "Darwin" (macOS).
# The result is a logical of length 1 that keeps the name "sysname" from the subsetting.
running_on_mac = function() {
  sys_info = Sys.info()
  sys_info["sysname"] == "Darwin"
}

# Infers a paradox ParamSet from the formal arguments of a function.
# fn: a function, or an R6ClassGenerator whose (possibly inherited) `initialize` method is used.
# ignore: names of formal arguments that should not become parameters; `...` is always dropped.
# tags: tags assigned to every resulting parameter.
# Returns a ParamSet with one untyped parameter (`p_uty()`) per remaining argument.
inferps = function(fn, ignore = character(0), tags = "train") {
  if (inherits(fn, "R6ClassGenerator")) {
    fn = get_init(fn)
    # no initialize method found -> no constructor arguments to turn into parameters
    if (is.null(fn)) return(ps())
  }
  assert_function(fn)
  assert_character(ignore, any.missing = FALSE)

  arg_list = formals(fn)
  keep = !(names(arg_list) %in% union(ignore, "..."))
  domains = lapply(arg_list[keep], function(arg) p_uty(tags = tags))

  do.call(paradox::ps, domains)
}


# Creates a checkmate-style check function. The check accepts `NULL`, or an integerish vector
# without missing values whose length is 1 or `d`.
# The check returns TRUE on success and an error message string otherwise.
# It is crated so that its environment only contains `d` (parent: the top-level environment).
make_check_vector = function(d) {
  crate(function(x) {
    if (is.null(x)) {
      return(TRUE)
    }
    if (test_integerish(x, any.missing = FALSE) && length(x) %in% c(1, d)) {
      return(TRUE)
    }
    suffix = if (d == 1) "." else sprintf(" or %s.", d)
    sprintf("Must be an integerish vector of length 1%s", suffix)
  }, d, .parent = topenv())
}

# checkmate-style check: `x` must be a function or NULL.
check_function_or_null = function(x) {
  check_function(x, null.ok = TRUE)
}

# checkmate-style check: `x` must be integerish or NULL.
check_integerish_or_null = function(x) {
  check_integerish(x, null.ok = TRUE)
}

# Asserts that `classname` is the class name of `class_generator` or of one of its ancestors
# (walked via `$get_inherit()`). Returns TRUE on success and throws an error otherwise.
assert_inherits_classname = function(class_generator, classname) {
  assert_class(class_generator, "R6ClassGenerator")
  while (!is.null(class_generator)) {
    # identical() instead of `==`: anonymous R6 classes have a NULL classname, for which
    # `==` returns logical(0) and makes `if` fail.
    if (identical(class_generator$classname, classname)) {
      return(TRUE)
    }
    class_generator = class_generator$get_inherit()
  }
  stopf("R6ClassGenerator does not generate object that inherits from %s.", classname)
}

# Returns the `initialize` entry of `public_methods` of the first class in the inheritance
# chain of `x` that class_with_init() selects, or NULL if there is none.
get_init = function(x) {
  owner = class_with_init(x)
  if (!is.null(owner)) {
    return(owner$public_methods$initialize)
  }
  NULL
}

# Returns the `task_id` variable from the calling frame if it exists and is not NULL.
# Otherwise returns a default task id based on the learner's task type:
# "iris" for "classif" and "mtcars" for "regr".
# Errors for learners inheriting from "LearnerTorchImage" and for other task types.
default_task_id = function(learner) {
  caller_task_id = get0("task_id", envir = parent.frame(), inherits = FALSE)
  if (is.null(caller_task_id)) {
    if (inherits(learner, "LearnerTorchImage")) {
      stopf("Currently not available!")
    }
    return(switch(learner$task_type,
      classif = "iris",
      regr = "mtcars",
      stopf("No default task type.")
    ))
  }
  caller_task_id
}

# Walks up the inheritance chain of `x` (via `$get_inherit()`). It returns the first class
# that either has no `public_methods` or whose `public_methods` contains `initialize`.
# Returns NULL when the chain ends without such a class, i.e. no initialize method was found.
class_with_init = function(x) {
  current = x
  while (!is.null(current)) {
    methods = current$public_methods
    if (is.null(methods) || exists("initialize", methods, inherits = FALSE)) {
      return(current)
    }
    current = current$get_inherit()
  }
  NULL
}

# Generates one standard-normal tensor (torch_randn) per shape in `shapes`.
# shapes: non-empty list of numeric shape vectors; the first entry of each is replaced by `n`.
# n: size of the first dimension of every generated tensor.
# Returns the list of generated tensors.
sample_input_from_shapes = function(shapes, n = 1L) {
  assert_list(shapes, types = "numeric", min.len = 1)
  assert_int(n)
  imap(shapes, function(shape, shape_name) {
    concrete_shape = replace(shape, 1, n)
    invoke(torch_randn, .args = as.list(concrete_shape))
  })
}

# Reads the RDS file `<name>.rds` from the `col_info` directory of the installed mlr3torch package.
# name: file name without the `.rds` extension.
# Fails with "no file found" if the file does not exist. Without `mustWork = TRUE`,
# system.file() would return "" and readRDS("") would fail with a cryptic connection error.
load_col_info = function(name) {
  path = system.file("col_info", paste0(name, ".rds"), package = "mlr3torch", mustWork = TRUE)
  readRDS(path)
}

# Returns the number of outputs for a task: the number of classes for "classif" and 1 for "regr".
# Errors for any other task type.
get_nout = function(task) {
  task_type = task$task_type
  switch(task_type,
    classif = length(task$class_names),
    regr = 1,
    stopf("Unknown task type '%s'.", task_type)
  )
}


# Tests whether two column-info tables are equal. Both must have exactly the columns
# `id`, `type` and `levels` (in any order). Rows are compared after sorting by `id`,
# and `levels` is compared element-wise with all.equal().
# NOTE(review): the `x[order(get("id"))]` indexing assumes data.table inputs; confirm.
test_equal_col_info = function(x, y) {
  required_cols = c("id", "type", "levels")
  if (!test_permutation(colnames(x), required_cols) || !test_permutation(colnames(y), required_cols)) {
    return(FALSE)
  }

  x = x[order(get("id"))]
  y = y[order(get("id"))]

  # parameter names must match the names of the list handed to pmap_lgl() below
  equal = function(x, y) isTRUE(all.equal(x, y))
  equal(x$id, y$id) &&
    equal(x$type, y$type) &&
    all(pmap_lgl(list(x = x$levels, y = y$levels), equal))
}


# Builds a function whose formal arguments are exactly `args` (none has a default value)
# and which returns the values it was called with as a named list.
# Used to emulate R's argument matching inside functions that only take `...`.
#
# Example:
#   f = argument_matcher(c("a", "b", "c"))
#   f(1, 2, 3)     --> list(a = 1, b = 2, c = 3)
#   f(1, 2, a = 3) --> list(a = 3, b = 1, c = 2)
#
# Use case:
#   ff = function(...) {
#     l = argument_matcher(c("a", "b", "c"))(...)
#     l$a + l$b
#   }
#   # behaves like: function(a, b, c) a + b
#   # (except in the awkward case of missing arguments)
argument_matcher = function(args) {
  # `substitute()` yields the empty symbol, so every formal is created without a default
  arg_formals = named_list(args, substitute())
  matcher = as.function(c(arg_formals, quote(as.list(environment()))))
  environment(matcher) = topenv()
  matcher
}

# Makes the names in `new` unique with respect to `existing` and to each other, using
# make.unique() with separator "_". Only the (possibly renamed) entries corresponding to
# `new` are returned, so the result has the same length as `new`.
uniqueify = function(new, existing) {
  n_existing = length(existing)
  all_names = make.unique(c(existing, new), sep = "_")
  all_names[n_existing + seq_along(new)]
}

# Formats a list of shapes as a single string, e.g. "[(NA,3);(NA,4)]".
# A NULL shape is shown as "<unknown>". If `x` is named, each shape is prefixed with its
# name, e.g. "[a: (NA,3);b: <unknown>]".
# Returns a character(1).
shape_to_str = function(x) {
  shapedescs = map_chr(x, function(y) {
    if (is.null(y)) {
      return("<unknown>")
    }
    paste0("(", paste(y, collapse = ",", recycle0 = TRUE), ")")
  })
  if (test_named(x)) {
    # Pair each name with its own shape before collapsing. Pasting `names(x)` next to the
    # already-collapsed shapes would return one string per name, each holding every shape.
    repr = paste0("[", paste(names(x), shapedescs, sep = ": ", collapse = ";", recycle0 = TRUE), "]")
    return(repr)
  }
  paste0("[", paste(shapedescs, collapse = ";", recycle0 = TRUE), "]")
}

# Returns TRUE if all.equal() reports `x` and `y` as equal, FALSE otherwise
# (never a character vector of differences).
test_equal = function(x, y) {
  comparison = all.equal(x, y)
  isTRUE(comparison)
}


# torch dataset over a vector of image URIs/paths.
# Item `i` is `list(x = <tensor>)`: the image loaded with torchvision::base_loader() and
# converted with torchvision::transform_to_tensor(). The length is the number of URIs.
dataset_image = dataset("image_dataset",
  initialize = function(uris) {
    self$uris = uris
  },
  .getitem = function(x) {
    img = torchvision::base_loader(self$uris[x])
    list(x = torchvision::transform_to_tensor(img))
  },
  .length = function() {
    length(self$uris)
  }
)

# Combines a list of tensors into one tensor: each tensor gets a new leading dimension
# (unsqueeze(1)) and the results are concatenated along it.
list_to_batch = function(tensors) {
  with_batch_dim = map(tensors, function(tensor) tensor$unsqueeze(1))
  torch_cat(with_batch_dim, dim = 1L)
}

# Returns TRUE if at least two of the given lazy tensors have data descriptors (dd())
# with the same `dataset_hash`, FALSE otherwise.
auto_cache_lazy_tensors = function(lts) {
  hashes = map_chr(lts, function(lt) dd(lt)$dataset_hash)
  anyDuplicated(hashes) > 0L
}

#' Replace the head of a network
#'
#' Replaces the head of the network with a linear layer with `d_out` classes.
#' This is an S3 generic: the method is selected by the class of `network`.
#'
#' @param network ([`torch::nn_module`])\cr
#'   The network.
#' @param d_out (`integer(1)`)\cr
#'   The number of output classes.
#' @export
#' @keywords internal
replace_head = function(network, d_out) {
  UseMethod("replace_head")
}

# checkmate-style check: `x` must inherit from class "nn_module".
check_nn_module = function(x) {
  check_class(x, classes = "nn_module")
}

# checkmate-style check for the `callbacks` parameter. Accepts a single "TorchCallback" or a
# list of them, and requires their ids (via ids()) to be unique names.
# Returns TRUE on success and an error message string otherwise.
check_callbacks = function(x) {
  is_single = test_class(x, "TorchCallback")
  if (!is_single && !test_list(x, "TorchCallback")) {
    return("Invalid parameter 'callbacks'")
  }
  callbacks = if (is_single) list(x) else x
  if (test_names(ids(callbacks), type = "unique")) {
    return(TRUE)
  }
  "Callbacks must have unique IDS that are valid names."
}

# Sentinel object: an empty list with class "LossNone".
LossNone = function() {
  sentinel = list()
  class(sentinel) = "LossNone"
  sentinel
}

# Sentinel object: an empty list with class "OptimizerNone".
OptimizerNone = function() {
  sentinel = list()
  class(sentinel) = "OptimizerNone"
  sentinel
}

# Sentinel object: an empty list with class "CallbacksNone".
CallbacksNone = function() {
  sentinel = list()
  class(sentinel) = "CallbacksNone"
  sentinel
}

# Returns an example batch made of the first two observations of the dataloader's dataset.
# Uses the dataset's `.getbatch()` if available. Otherwise it fetches items 1 and 2 with
# `.getitem()` and combines them with torch_stack().
# Errors if the dataset has fewer than 2 observations.
# NOTE(review): the `.getitem()` fallback only works if items are tensors; datasets returning
# lists (e.g. `list(x = ...)`) would fail in torch_stack(). Confirm this is intended.
get_example_batch = function(dl) {
  ds = dl$dataset
  if (length(ds) < 2) {
    stopf("Dataset needs to contain at least 2 observations")
  }
  if (is.null(ds$.getbatch)) {
    items = lapply(c(1, 2), function(i) ds$.getitem(i))
    return(torch_stack(items))
  }
  ds$.getbatch(1:2)
}

# Matches the arguments in `l` against the formals of `f` (via match.call() on a copy of `f`)
# and returns the matched arguments as a list, so they can be passed by position.
# f: function whose signature is used for matching.
# l: list containing the arguments.
# Errors if the matched list has names while `...` exists but is not the last formal,
# because positional passing would then be ambiguous:
#   (function(..., x = 1) {})(1, 2, 3) works
#   (function(..., x = 1) {})(1, 2, x = 3) DOES NOT WORK
order_named_args = function(f, l) {
  formal_names = formalArgs(f)
  dots_pos = match("...", formal_names)

  # same signature as `f`, but the body only reports how the arguments were matched
  matcher = f
  body(matcher) = quote(as.list(match.call())[-1L])
  matched = do.call(matcher, l)

  dots_not_last = !is.na(dots_pos) && dots_pos != length(formal_names)
  if (dots_not_last && !is.null(names(matched))) {
    stopf("Because arguments are passed to tracer by position, `...` must either be the only or the last argument when named arguments are also passed.")
  }
  matched
}

Try the mlr3torch package in your browser

Any scripts or data that you put into this service are public.

mlr3torch documentation built on April 4, 2025, 3:03 a.m.