# R/RetrafoState.R
#
# Defines functions convertPrettyState objectStateToFunction functionToObjectState getPrettyState makeCPOTrainedFromState getCPOTrainedState.CPOTrained getCPOTrainedState.CPOTrainedPrimitive getCPOTrainedState
#
# Documented in getCPOTrainedState makeCPOTrainedFromState

# RetrafoState.R provides functionality for inspecting the internal state of a
# CPOTrained object (CPORetrafo or CPOInverter), and for reconstructing such an
# object from a (possibly modified) state.

#' @title Get the Internal State of a CPOTrained Object
#'
#' @description
#' Every \code{\link{CPOTrained}} object (a \code{\link{CPORetrafo}} or a
#' \code{\link{CPOInverter}}) carries a state. The state holds information
#' derived from the training data, together with the hyperparameters the
#' \code{\link{CPO}} was called with. This function exposes that state.
#'
#' Only primitive \code{\link{CPOTrained}} objects can be inspected. A compound
#' \code{\link{CPOTrained}} must first be split into its primitive constituents
#' with \code{\link{as.list.CPOTrained}}. Calling this function on a compound
#' object is an error.
#'
#' The layout of the state depends on the \code{\link{CPO}} backend:
#' \itemize{
#'   \item Functional CPOs: the environment of the retrafo (or inverter)
#'     function, converted to a list.
#'   \item Object based CPOs: a list of the hyperparameters, plus the control
#'     object produced by the trafo function in the \code{$control} entry.
#' }
#'
#' The \code{$data} entry holds bookkeeping information. For a retrafo, this is
#' \code{shapeinfo.input} and \code{shapeinfo.output}. For an inverter, it is
#' \code{truth} and \code{task.desc}. When a retrafo can also invert, its
#' inverter state is stored in \code{$target}.
#'
#' A slightly modified state can be turned back into a new
#' \code{\link{CPOTrained}} object with \code{\link{makeCPOTrainedFromState}}.
#'
#' @param trained.object [\code{CPOTrained}]\cr
#'   The object whose state is returned.
#' @return [\code{list}]. A named list containing the complete internal state of the
#'   \code{\link{CPOTrained}}, or \code{NULL} if the object is \code{NULLCPO}.
#' @family state functions
#' @family retrafo related
#' @family inverter related
#' @export
getCPOTrainedState = function(trained.object) {
  # S3 dispatch on the class of 'trained.object': primitive objects are handled,
  # compound ones are rejected by their method.
  UseMethod("getCPOTrainedState", trained.object)
}


# RETRAFO / INVERTER State
# The state is basically the control object, or the trafo-function's environment
# We also keep the shapeinfo.input and shapeinfo.output information
#' @export
getCPOTrainedState.CPOTrainedPrimitive = function(trained.object) {
  if (is.nullcpo(trained.object)) {
    return(NULL)
  }
  cpo = trained.object$element$cpo
  is.retrafo = getCPOClass(trained.object) == "CPORetrafo"

  if (is.retrafo) {
    res = getPrettyState(trained.object$element$state, cpo, cpo$control.type$retrafo, FALSE)
    if (trained.object$capability["invert"] == 1) {
      # also has an inverter state
      res$target = getPrettyState(trained.object$element$state.invert, cpo, cpo$control.type$invert, TRUE, TRUE)
    }
  } else {
    # inverter
    assert(cpo$operating.type == "target")
    res = getPrettyState(trained.object$element$state, cpo, cpo$control.type$invert, TRUE)
  }

  if (is.retrafo) {
    res$data = c(trained.object$element[c("shapeinfo.input", "shapeinfo.output")])  # nolint
  } else {
    res$data = list(truth = trained.object$element$truth, task.desc = trained.object$element$task.desc)
  }
  res
}

# Compound CPOTrained objects have no single state; they must be split first.
#' @export
getCPOTrainedState.CPOTrained = function(trained.object) {
  msg = "Cannot get state of compound CPOTrained. Use as.list to get individual elements"
  stop(msg)
}


# rebuilds a retrafo or inverter  object from a given state. It does that by
# constructing a "bare" (empty) CPOTrained object and filling in the missing slots.
#' @title Create a CPOTrained with Given Internal State
#'
#' @description
#' This creates a new \code{\link{CPOTrained}} object which will
#' behave according to the given state. The state should usually be obtained using
#' \code{\link{getCPOTrainedState}} and then slightly modified. No checks for correctness
#' of the state will (or can) be done, it is the user's responsibility to ensure
#' that the correct \code{\link{CPOConstructor}} is used, and that the state is
#' only modified in a way the CPO can handle.
#'
#' @param constructor [\code{\link{CPOConstructor}}]\cr
#'   A cpo constructor.
#' @param state [\code{list}]\cr
#'   A state gotten from another \code{\link{CPORetrafo}} or \code{\link{CPOInverter}} object using
#'   \code{\link{getCPOTrainedState}}.
#' @param get.inverter [logical(1)]\cr
#'   Whether to get a \code{\link{CPOInverter}}. Usually a \code{\link{CPORetrafo}} is
#'   created. This must be \code{TRUE} if the \code{state} was created from a \code{\link{CPOInverter}},
#'   \code{FALSE} otherwise. Default is \code{FALSE}.
#' @return [\code{CPOTrained}]. A \code{\link{CPORetrafo}} or \code{\link{CPOInverter}}
#'   (as if retrieved using \code{\link{retrafo}} or \code{\link{inverter}} after
#'   a primitive \code{\link{CPO}} was applied to some data) with the given state.
#' @family state functions
#' @family retrafo related
#' @family inverter related
#' @export
makeCPOTrainedFromState = function(constructor, state, get.inverter = FALSE) {
  assertClass(constructor, "CPOConstructor")
  assertList(state, names = "unique")
  bare = constructor()

  data = state$data
  state$data = NULL
  if (get.inverter) {
    assertSetEqual(names(data), c("truth", "task.desc"))
    converted = convertPrettyState(bare, state, "cpo.invert", bare$control.type$invert)
    result = makeCPOInverter(converted$bare, converted$newstate, NULLCPO, data.frame())
    result$element$truth = data$truth
    result$element$task.desc = data$task.desc

    result
  } else {
    assertSetEqual(names(data), c("shapeinfo.input", "shapeinfo.output"))
    target = state$target
    state$target = NULL
    converted = convertPrettyState(bare, state, "cpo.retrafo", bare$control.type$retrafo)
    bare = converted$bare
    if (bare$operating.type == "target" && bare$constant.invert) {
      assert(bare$control.type$retrafo != "dual.functional")
      state.invert = convertPrettyState(bare, target, "cpo.invert", bare$control.type$invert, TRUE)$newstate
    } else {
      assert(bare$control.type$retrafo == "dual.functional", is.null(target))
      state.invert = NULL
    }

    makeCPORetrafo(bare, converted$newstate, state.invert, NULLCPO, data$shapeinfo.input, data$shapeinfo.output)
  }
}



# Turn the state from the CPO into a prettier state
#
# @param state [any] the state, as returned by trafo, retrafo etc.
# @param cpo [CPO] the constructing CPO
# @param statekind [character(1)] whether the state is a function (the environment
#   needs to be unpacked), an object (the CPO's hyperparameters need to be added),
#   or a list of two functions (happens for functional simple target operation cpo:
#   then the two functions are cpo.retrafo and cpo.train.invert)
# @param is.invert [logical(1)] whether the state is for an inverter
# @param only.basic [logical(1)] whether, for object based, to not get the hyperparameters but only the state itself.
# @return [list] a list that resembles a state of a trafo / retrafo
getPrettyState = function(state, cpo, statekind = c("functional", "object", "dual.functional"), is.invert, only.basic = FALSE) {
  statekind = match.arg(statekind)
  if (statekind == "functional") {
    fun.name = if (is.invert) "cpo.invert" else "cpo.retrafo"
    res = functionToObjectState(state, fun.name, cpo$debug.name)
  } else if (statekind == "object") {
    if (only.basic) {
      res = state
    } else {
      res = getBareHyperPars(cpo)
      res$control = state
    }
  } else {  # statekind == "dual.functional"
    res = functionToObjectState(state$cpo.retrafo, "cpo.retrafo", cpo$debug.name)
    res$target = functionToObjectState(state$cpo.train.invert, "cpo.train.invert", cpo$debug.name)
  }
  res
}

# Turn the function state into an object kind of state
#
# This turns the function's environment into a list and puts the function itself in that list
# @param fun [function] the function to turn into a state
# @param name [character(1)] the name of the function within that list
# @param cpo.name [character(1)] name of the CPO, for message printing
# @return [list] a named list representing the environment of the function
functionToObjectState = function(fun, name, cpo.name) {
  assertFunction(fun)
  res = as.list(environment(fun))
  if (!name %in% names(res)) {
    res[[name]] = fun
  } else if (!identical(res[[name]], fun)) {
    stopf("Could not get coherent state of CPO Retrafo %s, since '%s' in\n%s",
      cpo.name, name, "the environment of the retrafo function is not identical to the retrafo function.")
  }
  res
}

# Turn the object into a function state
#
# This retrieves the named function from the list and turns the list into the
# function's environment
# @param state [list] the object state to use
# @param name [character(1)] the name of the function within that list
# @return [function] a function with the state as its innermost environment
objectStateToFunction = function(state, name) {
  assertSubset(name, names(state))

  newstate = state[[name]]
  # update newstate's environment to actually contain the
  # values set in the 'state'
  env = new.env(parent = parent.env(environment(newstate)))
  list2env(state, envir = env)
  environment(newstate) = env
  # also set the [[name]] in the env to point to the current 'newstate' function.
  # if we did not do this, the [[name]] variable visible to newstate would
  # be the same function *but with a different environment* -- recursion would break
  # (this is because of 'environment(newstate) = env' above)
  env[[name]] = newstate
  newstate
}

# Recover internally usable 'state' from 'CPOTrainedState'
#
# Does some internal work
# @param bare [CPO] the bare cpo to build around
# @param state [list] the prettified state to use
# @param fun.name [character(1)] the name of the function, if the state is functional
# @param control.type [character(1)] one of "functional", "object", "dual.functional"
# @return [list] list(bare, newstate)
convertPrettyState = function(bare, state, fun.name, control.type, only.basic = FALSE) {
  if (control.type == "dual.functional") {
    assert(bare$operating.type == "target")
    assert(bare$constant.invert)
    assert(fun.name == "cpo.retrafo")
    bare$par.vals = list()
    newstate = list(
        cpo.retrafo = objectStateToFunction(state, "cpo.retrafo"),
        cpo.train.invert = objectStateToFunction(state, "cpo.train.invert"))
  } else if (control.type == "functional") {
    bare$par.vals = list()
    newstate = objectStateToFunction(state, fun.name)
  } else {
    assert(control.type == "object")
    if (only.basic) {
      newstate = state
    } else {
      assertSubset("control", names(state))
      newstate = state$control
      state$control = NULL
      assertSubset(names(state), c(names(bare$bare.par.set$pars), names(bare$unexportedpar.set$pars)))
      bare$unexported.pars = state[names(bare$unexportedpar.set$pars)]
      state = dropNamed(state, names(bare$unexportedpar.set$pars))
      if (length(state)) {
        names(state) = paste(bare$id, names2(state), sep = ".")
      }
      bare$par.vals = state

    }
  }
  list(bare = bare, newstate = newstate)
}

# Try the mlrCPO package in your browser
#
# Any scripts or data that you put into this service are public.
#
# mlrCPO documentation built on Nov. 18, 2022, 1:05 a.m.