R/callCPO.R

Defines functions checkAllParams getBareHyperPars predict.CPORetrafo applyCPO.CPO callCPORetrafoElement callCPO.CPOPipeline callCPO.CPOPrimitive callCPO makeCPOTrainedBasic makeCPORetrafo makeCPOInverter

# callCPO.R: functions involved in calling CPOs and retrafos.
# They handle input checks, call the relevant CPO functions, and construct
# the resulting Retrafo / Inverter objects.

#' @include FormatCheck.R
##################################
### Creators                   ###
##################################

# Builds the "Inverter" S3 object. "Inverter" and "Retrafo" objects both carry
# the class "CPOTrained" and differ only in details; since a Retrafo can
# sometimes do the work of an Inverter, S3 classes alone can not tell them apart
# effectively.
# @param cpo [CPOPrimitive] the CPO that produced this inverter
# @param state [any] the inverter state, usually generated by cpo.trafo
# @param prev.inverter [CPOInverter | NULL] inverter to prepend to the new one, if any
# @param data [data.frame | Task] the data the inverter is built from. For a Task,
#   its task description and target column(s) ("truth") are stored in the element;
#   for a data.frame both stay NULL.
# @return [CPOInverter] a new CPOInverter object. It inherits from `CPOTrained` and
#   can be used to undo a target-bound CPO's action on a prediction object.
makeCPOInverter = function(cpo, state, prev.inverter, data) {
  from.task = !is.data.frame(data)
  task.desc = if (from.task) getTaskDesc(data) else NULL
  truth = if (from.task) getTaskData(data, target.extra = TRUE)$target else NULL

  # an inverter can not retrafo (-1) but can invert (1)
  inverter = makeCPOTrainedBasic(cpo, state, "CPOInverter", "InverterElement",
    c(retrafo = -1L, invert = 1L))

  # --- state that only a pure "inverter" carries
  inverter$element$task.desc = task.desc
  inverter$element$truth = truth
  composeCPO(prev.inverter, inverter)
}

# Builds the "Retrafo" S3 object (see also 'makeCPOInverter'). Some CPOs create
# `CPORetrafo` objects that are hybrid retrafo / inverter, see getCPOTrainedCapability.
# @param cpo [CPOPrimitive] the CPO that produced this retrafo
# @param state [any] the retrafo state, usually generated by cpo.trafo
# @param state.invert [any] the inverter state, usually generated by cpo.trafo; only
#   stored if the CPO has a constant inverter, and must be NULL for non-target CPOs
# @param prev.retrafo [CPORetrafo | NULL] retrafo to prepend to the new one, if any
# @param shapeinfo.input [InputShapeInfo] required shape of input data
# @param shapeinfo.output [OutputShapeInfo] required shape of data coming out of cpo.retrafo
# @return [CPORetrafo] a new CPORetrafo object. It inherits from `CPOTrained` and can be
#   used to re-apply a CPO's action on feature columns of new data.
makeCPORetrafo = function(cpo, state, state.invert, prev.retrafo, shapeinfo.input, shapeinfo.output) {
  if (cpo$operating.type == "target") {
    # target operating CPOs can invert from the retrafo only with a constant inverter
    invert.capability = if (cpo$constant.invert) 1L else -1L
  } else {
    # other CPOs have no inverter state; inverting with them has no effect (0)
    assert(is.null(state.invert))
    invert.capability = 0L
  }

  retrafo = makeCPOTrainedBasic(cpo, state, "CPORetrafo", "RetrafoElement",
    c(retrafo = 1L, invert = invert.capability))

  # --- state that only a "CPORetrafo" carries
  retrafo$element$shapeinfo.input = shapeinfo.input
  retrafo$element$shapeinfo.output = shapeinfo.output
  if (cpo$constant.invert) {
    # a constant inverter state lives in the retrafo, so the retrafo can also invert
    retrafo$element$state.invert = state.invert
  }
  composeCPO(prev.retrafo, retrafo)
}

# Builds an object of class "CPOTrained", the common basis of "Inverter" and
# "Retrafo" objects. It holds one linked-list node ('element') and a copy of some
# information taken from the CPO (name, predict.type, properties, conversion).
# @param cpo [CPOPrimitive] the CPO that produced the retrafo / inverter
# @param state [any] the state of the retrafo / inverter, usually generated by cpo.trafo
# @param subclass [character] class(es) of the returned object besides "CPOTrained"
# @param elclass [character(1)] class of the 'element' linked list node
# @param capability [integer(2)] named vector c(retrafo = x, invert = x), where
#   -1 means not possible, 0 means no effect, 1 means effect.
# @return [CPOTrained] object on which `CPORetrafo` and `CPOInverter` objects are based
makeCPOTrainedBasic = function(cpo, state, subclass, elclass, capability) {
  # a single node of the linked list of internal CPO states
  node = makeS3Obj(elclass,
    cpo = cpo,                          # the CPOPrimitive that created this object
    state = state,                      # control object or retrafo function given by the CPO
    capability = capability,            # capability of this single CPOTrained unit
    prev.predict.type =                 # predict.type map of all preceding elements (in prev.retrafo.elt)
      cpo.identity.predict.type.map,
    prev.retrafo.elt = NULL)            # the "previous" CPOTrained element; NULL if there is none

  makeS3Obj(c("CPOTrainedPrimitive", subclass, "CPOTrained"),
    element = node,
    # --- information cached for chained CPOTrained objects
    name = cpo$name,
    capability = capability,            # c(retrafo = x, invert = x)
    predict.type = cpo$predict.type,    # named list: type to predict --> needed type
    properties = cpo$properties,        # list(handling, adding, needed)
    convertfrom = cpo$convertfrom,      # task type converted from, or NULL if not converting
    convertto = cpo$convertto)          # task type converted to, or NULL if not converting
}

##################################
### Primary Operations         ###
##################################
# A CPO is a tree data structure: CPOPrimitive objects are
# the leaves, CPOPipeline objects the inner nodes.
# CPOTrained is a linked list, which is constructed automatically
# in 'callCPO'.

# Calls the (possibly compound) CPO on the data; dispatches on the class of `cpo`.
# Internal function; the user-facing function (applyCPO) additionally performs some
# checks and strips the retrafo and inverter tags.
# @param cpo [CPOPrimitive | CPOPipeline | ...] the CPO to apply to the data
# @param data [data.frame | Task] the data to transform
# @param build.retrafo [logical(1)] whether to create a 'retrafo' object
# @param prev.retrafo [CPORetrafo | NULL] retrafo linked list the newly created retrafo gets appended to, if any
# @param build.inverter [logical(1)] whether to create an 'inverter' object, if applicable
# @param prev.inverter [CPOInverter | NULL] inverter linked list the new inverter gets appended to, if any
# @return [list] list(data, retrafo, inverter)
callCPO = function(cpo, data, build.retrafo, prev.retrafo, build.inverter, prev.inverter) {
  # dispatch explicitly on the CPO object (e.g. CPOPrimitive or CPOPipeline)
  UseMethod("callCPO", cpo)
}

# TRAFO main function for a single (primitive) CPO:
# - checks that all parameters are given
# - checks the inbound and outbound data is in the right format and turns the data
#   into the shape requested by the cpo (prepareTrafoInput / handleTrafoOutput)
# - checks properties (inbound and outbound)
# - passes the bare hyperparameters of the cpo to cpo.trafo
# - collects the state(s) returned by cpo.trafo into new retrafo / inverter objects,
#   appended to prev.retrafo / prev.inverter
# - returns list(data, retrafo, inverter)
# for parameters, see 'callCPO' documentation.
callCPO.CPOPrimitive = function(cpo, data, build.retrafo, prev.retrafo, build.inverter, prev.inverter) {
  # an existing inverter chain is only accepted when it is going to be extended
  assert(build.inverter, is.null(prev.inverter))
  checkAllParams(cpo$par.vals, cpo$par.set, cpo$debug.name)

  trafo.in = prepareTrafoInput(data, cpo$dataformat, cpo$strict.factors, cpo$properties.raw,
    getCPOAffect(cpo, FALSE), cpo$fix.factors, cpo$operating.type, cpo$debug.name)

  # a constant inverter state is stored in the retrafo (see makeCPORetrafo), so it is
  # requested from cpo.trafo even when no inverter object is built
  trafo.in$indata$build.inverter = build.inverter || cpo$constant.invert
  trafo.args = insert(getBareHyperPars(cpo), trafo.in$indata)
  called = do.call(cpo$trafo.funs$cpo.trafo, trafo.args)

  # (simple) feature operating CPOs return df.features instead of df.all / task
  returns.features = cpo$operating.type.extended == "feature"
  trafo.out = handleTrafoOutput(called$result, trafo.in, cpo$properties$needed.max,
    cpo$properties$adding.min, cpo$convertto, returns.features)

  retrafo.out = prev.retrafo
  if (build.retrafo && cpo$operating.type != "retrafoless") {
    retrafo.out = makeCPORetrafo(cpo, called$state, called$state.invert, prev.retrafo,
      trafo.in$shapeinfo, trafo.out$shapeinfo)
  }

  inverter.out = prev.inverter
  if (build.inverter && cpo$operating.type == "target") {
    # the inverter is built from the data as it was before this CPO
    inverter.out = makeCPOInverter(cpo, called$state.invert, prev.inverter, data)
  }

  list(data = trafo.out$outdata, retrafo = retrafo.out, inverter = inverter.out)
}

# Calls cpo$first, then cpo$second, threading data, retrafo and inverter through.
#
# A CPO tree looks like this:
#
#                CPOPipeline
#               /[first]    \[second]
#     CPOPipeline           CPOPipeline
#     /[first]  \[second]   /[first]  \[second]
# CPOPrim1    CPOPrim2  CPOPrim3    CPOPrim4
#
# callCPO calls go as:
#
#    1 -------> 2 -------> 3 -------> 4
#
# The retrafos form a linked list in which each element points to its
# predecessor (slot 'prev.retrafo.elt'):
#
# retr.1 <-- retr.2 <-- retr.3 <-- retr.4
#
# for parameters, see 'callCPO' documentation
callCPO.CPOPipeline = function(cpo, data, build.retrafo, prev.retrafo, build.inverter, prev.inverter) {
  checkAllParams(cpo$par.vals, cpo$par.set, cpo$debug.name)

  # each child only receives the parameter values belonging to its own param set
  children = lapply(list(cpo$first, cpo$second), function(child) {
    child$par.vals = subsetParams(cpo$par.vals, child$par.set)
    child
  })

  acc = list(data = data, retrafo = prev.retrafo, inverter = prev.inverter)
  for (child in children) {
    acc = callCPO(child, acc$data, build.retrafo, acc$retrafo, build.inverter, acc$inverter)
  }
  acc
}

# RETRAFO main function; the retrafo equivalent of callCPO:
# - first applies all preceding elements of the linked list (prev.retrafo.elt)
# - checks the inbound and outbound data is in the right format and has the same
#   shape as during training; turns the data into the shape requested by the cpo
# - checks properties (inbound and outbound)
# - passes the bare hyperparameters of the cpo and the stored state to cpo.retrafo
# - optionally builds an inverter for target operating CPOs
# - returns the resulting data and inverter
# Since CPOTrained is a linked list, not a tree like compound CPO ("CPOPipeline"),
# no S3 dispatch is needed here.
# @param retrafo [RetrafoElement | InverterElement] the linked list element to apply
# @param data [data.frame | Task] the data to transform
# @param build.inverter [logical(1)] whether to create an 'inverter' object, if applicable
# @param prev.inverter [CPOInverter | NULL] inverter linked list the new inverter gets appended to, if any
# @return [list] list(data, inverter)
callCPORetrafoElement = function(retrafo, data, build.inverter, prev.inverter) {
  assert(checkClass(retrafo, "RetrafoElement"), checkClass(retrafo, "InverterElement"))
  cpo = retrafo$cpo

  # preceding elements are applied before this one
  predecessor = retrafo$prev.retrafo.elt
  if (!is.null(predecessor)) {
    upstream = callCPORetrafoElement(predecessor, data, build.inverter, prev.inverter)
    data = upstream$data
    prev.inverter = upstream$inverter
  }

  assertChoice(cpo$operating.type, c("feature", "target"))  # "retrafoless" has no retrafo
  is.target = cpo$operating.type == "target"

  retrafo.in = prepareRetrafoInput(data, cpo$dataformat, cpo$strict.factors, cpo$properties.raw,
    retrafo$shapeinfo.input, cpo$operating.type, cpo$name)
  has.target = !is.null(retrafo.in$indata$target)

  # a target operating CPO with no target to change and no inverter to build:
  # cpo.retrafo does not need to be called at all
  if (is.target && !build.inverter && !has.target) {
    return(list(data = data, inverter = prev.inverter))
  }

  retrafo.in$indata$state = retrafo$state
  called = do.call(cpo$trafo.funs$cpo.retrafo, insert(getBareHyperPars(cpo), retrafo.in$indata))

  if (!is.target || has.target) {
    # in retrafo, newly created 'missings' are tolerated unless they are in adding.min
    needed = cpo$properties$needed.max
    adding = cpo$properties$adding.min
    if (!("missings" %in% adding)) {
      needed = union(needed, "missings")
    }
    data = handleRetrafoOutput(called$result, retrafo.in, needed,
      adding, cpo$convertto, retrafo$shapeinfo.output)
  }

  if (build.inverter && is.target) {
    # constant inverters keep their state in the retrafo element
    inverter.state = if (cpo$constant.invert) retrafo$state.invert else called$state.invert
    prev.inverter = makeCPOInverter(cpo, inverter.state, prev.inverter, firstNonNull(retrafo.in$task, data))
  }

  list(data = data, inverter = prev.inverter)
}

# Applies a CPO or a CPORetrafo to a data object (data.frame or Task).
# The retrafo / inverter tags of `task` are checked and stripped, the data is
# transformed, and retrafo and inverter are attached to the result:
# - for a CPO, 'callCPO' is called and the retrafo it returns (built on top of the
#   task's previous retrafo) is attached;
# - for a CPORetrafo, 'callCPORetrafoElement' is called and the task's previous
#   retrafo is re-attached unchanged.
# Tasks with weights are rejected.
#' @export
applyCPO.CPO = function(cpo, task) {
  if (inherits(task, "Task") && hasTaskWeights(task)) {
    stop("CPO can not handle tasks with weights!")
  }

  prev.inverter = inverter(task)
  assert(is.nullcpo(prev.inverter), checkClass(prev.inverter, "CPOInverter"))
  prev.retrafo = retrafo(task)
  assert(is.nullcpo(prev.retrafo), checkClass(prev.retrafo, "CPORetrafo"))

  task = clearRI(task)

  is.retrafo = inherits(cpo, "CPORetrafo")
  result = if (is.retrafo) {
    callCPORetrafoElement(cpo$element, task, TRUE, prev.inverter)
  } else {
    callCPO(cpo, task, TRUE, prev.retrafo, TRUE, prev.inverter)
  }

  out = result$data
  retrafo(out) = if (is.retrafo) prev.retrafo else result$retrafo
  inverter(out) = result$inverter
  out
}

# User-facing application of a CPORetrafo to a data object. This is the very
# same function as applyCPO.CPO, which itself checks whether `cpo` is a
# CPORetrafo and in that case calls 'callCPORetrafoElement'.
#' @export
applyCPO.CPORetrafo = applyCPO.CPO  # nolint

# predict() method for CPORetrafo objects: applies the retrafo to `data` by
# calling applyCPO(object, data). No additional arguments in `...` are allowed;
# passing any fails the assertion.
#' @export
predict.CPORetrafo = function(object, data, ...) {
  assert(length(list(...)) == 0)
  applyCPO(object, data)
}


# Gets the par.vals of a CPOPrimitive keyed by their bare names, i.e. the
# parameter names without the ID.
# @param cpo [CPOPrimitive] the cpo to query
# @param include.unexported [logical(1)] whether to also include the values of
#   parameters that are not exported (cpo$unexported.pars).
# @return [list] named list (bare parameter name) => (parameter value); an empty
#   named list if there are no values at all.
getBareHyperPars = function(cpo, include.unexported = TRUE) {
  assertClass(cpo, "CPOPrimitive")
  par.vals = cpo$par.vals

  # lookup table: full parameter name --> bare parameter name
  bare.names = setNames(names2(cpo$bare.par.set$pars), names(cpo$par.set$pars))
  bare.vals = setNames(par.vals, bare.names[names(par.vals)])

  unexported = if (include.unexported) cpo$unexported.pars
  combined = c(bare.vals, unexported)
  if (length(combined)) combined else namedList()
}


# Checks whether all parameters of a CPO are given when it is called: either as a
# default value, during construction, or later via setHyperPars.
# A parameter counts as required if it has no 'requires' expression, or if that
# expression evaluates to TRUE for the given par.vals. Requirements that can not
# be evaluated (e.g. because a parameter they refer to is missing) count as not
# fulfilled.
# @param par.vals [list] the par.vals of a CPO
# @param par.set [ParamSet] the param set of the CPO
# @param name [character(1)] the name of the CPO, to be used in error messages
# @return [invisible(NULL)]; throws an error listing the missing parameters otherwise.
checkAllParams = function(par.vals, par.set, name) {
  present = names(par.vals)

  # The requirement is evaluated with par.vals as environment and the base
  # environment as enclosure. With eval()'s default enclosure (the calling frame),
  # a reference to a missing parameter could silently resolve to a local variable
  # of this function (e.g. 'name', 'present', 'par.vals') instead of failing.
  requirementMet = function(par) {
    is.null(par$requires) ||
      isTRUE(try(eval(par$requires, envir = par.vals, enclos = baseenv()), silent = TRUE))
  }

  # these parameters are either present or have fulfilled requirements
  needed = names(Filter(function(par) {
    par$id %in% present || requirementMet(par)
  }, par.set$pars))

  missing.pars = setdiff(needed, present)
  if (length(missing.pars)) {
    plur = length(missing.pars) > 1
    stopf("Parameter%s %s of CPO %s %s missing\n%s", if (plur) "s" else "",
      collapse(missing.pars, sep = ", "), name, if (plur) "are" else "is",
      "Either give it during construction, or with setHyperPars.")
  }
  invisible(NULL)
}

Try the mlrCPO package in your browser

Any scripts or data that you put into this service are public.

mlrCPO documentation built on Nov. 18, 2022, 1:05 a.m.