R/gp_helpers.R

Defines functions dissim_sample pick_candidate pred_gp quiet_pred_gp fit_gp partial_encode check_gp make_kernel find_qual_param

# Determine any qualitative parameters and their possible levels

find_qual_param <- function(pset) {
  # Flag the rows of `pset` whose parameter object is a qualitative parameter
  qual_rows <- vapply(pset$object, inherits, logical(1), what = "qual_param")
  if (!any(qual_rows)) {
    return(list())
  }

  # One element per qualitative parameter: its possible values, named by id
  qual_set <- pset[qual_rows, ]
  stats::setNames(
    lapply(qual_set$object, function(param) param$values),
    qual_set$id
  )
}

# Qualitative parameters are left as-is while the others are scaled; columns are ordered by parameter id in pset

# Create a Matern 3/2 kernel for the quantitative parameters and multiply it
# by one factor kernel per qualitative parameter

# Build the GauPro kernel for a parameter set.
#
# `lvls` is the named list returned by find_qual_param(): names are the ids of
# the qualitative parameters and values are their possible levels.
#
# - With no qualitative parameters, a single Matern 3/2 kernel over all
#   columns is returned.
# - Otherwise the result is the product of a Matern 3/2 kernel that ignores
#   the qualitative columns (only when quantitative parameters exist) and one
#   factor kernel per qualitative parameter. The column index of each
#   parameter is its position in `pset$id`.
make_kernel <- function(pset, lvls) {
  qual_ind <- which(pset$id %in% names(lvls))
  quant_ind <- setdiff(seq_along(pset$id), qual_ind)

  if (length(qual_ind) == 0) {
    return(GauPro::k_Matern32(D = length(quant_ind)))
  }

  # Each factor kernel needs the level count of *its own* parameter, looked
  # up by id (not the first qualitative parameter's count for all of them).
  qual_kernels <- lapply(qual_ind, function(ind) {
    GauPro::k_FactorKernel(
      D = 1,
      xindex = ind,
      nlevels = length(lvls[[pset$id[ind]]])
    )
  })

  kernels <- qual_kernels
  if (length(quant_ind) > 0) {
    quant_kernel <- GauPro::k_IgnoreIndsKernel(
      k = GauPro::k_Matern32(D = length(quant_ind)),
      ignoreinds = qual_ind
    )
    kernels <- c(list(quant_kernel), qual_kernels)
  }

  # Product kernel, combined left to right in the same order as above
  Reduce(`*`, kernels)
}

# Decide whether a fitted GP (or the `try-error` from fitting it) is usable.
#
# The model is unusable when:
# - fitting failed;
# - the leave-one-out R^2 is not above `gp_threshold`, including NA/NaN;
# - the LOO 95% coverage is below `gp_threshold`, including NA.
# A failure is always reported. The other two cases are reported only when
# `control$verbose_iter` is TRUE.
#
# Returns list(use = <logical(1)>, rsq = <LOO R^2, 0 on failure>).
check_gp <- function(x, control) {
  gp_threshold <- 0.1
  model_fail <- inherits(x, "try-error")

  if (model_fail) {
    loo_rsq <- 0.0
    rsq_bad <- TRUE
    loo_bad <- TRUE
  } else {
    loo_res <- summary(x)
    loo_rsq <- loo_res$r.squaredLOO
    # isTRUE() makes NA/NaN statistics (e.g. from a constant outcome) count
    # as unreliable instead of erroring inside `if` or yielding use = NA
    rsq_bad <- !isTRUE(loo_rsq > gp_threshold)
    loo_bad <- !isTRUE(loo_res$coverage95LOO >= gp_threshold)
  }

  if (model_fail) {
    message_wrap(
      cli::format_inline("GP failed: {as.character(x)}"),
      prefix = cli::symbol$checkbox_circle_on,
      color_text = get_tune_colors()$message$danger
    )
  } else if (rsq_bad) {
    if (control$verbose_iter) {
      msg <- cli::format_inline(
        "GP has a LOO R\u00B2 of {round(loo_rsq * 100, 1)}% and is unreliable."
      )
      message_wrap(
        msg,
        prefix = cli::symbol$checkbox_circle_on,
        color_text = get_tune_colors()$message$danger
      )
    }
  } else if (loo_bad) {
    if (control$verbose_iter) {
      msg <- cli::format_inline(
        "GP has a coverage rate < {round(gp_threshold * 100, 1)}% and is unreliable."
      )
      message_wrap(
        msg,
        prefix = cli::symbol$checkbox_circle_on,
        color_text = get_tune_colors()$message$danger
      )
    }
  }

  list(use = !(model_fail || rsq_bad || loo_bad), rsq = loo_rsq)
}

# encode_set() was created to work on all types of tuning parameters; usage of
# GauPro means that we should not encode qualitative tuning parameters so we
# need a wrapper
# Encode the tuning-parameter columns of `dat` for GauPro.
#
# Quantitative parameters are encoded with encode_set(). Qualitative
# parameters are replaced with factors built from the original values, using
# the levels reported by find_qual_param(). If `dat` has a `mean` column, it
# is copied into `.outcome`.
partial_encode <- function(dat, pset) {
  qual_info <- find_qual_param(pset)

  normalized <- encode_set(
    dat |> dplyr::select(dplyr::all_of(pset$id)),
    pset = pset,
    as_matrix = FALSE
  )

  # GauPro::gpkm works directly with factor columns, so restore the original
  # qualitative values instead of their encoded form
  for (nm in names(qual_info)) {
    normalized[[nm]] <- factor(dat[[nm]], levels = qual_info[[nm]])
  }

  if ("mean" %in% names(dat)) {
    normalized$.outcome <- dat$mean
  }
  normalized
}

# ------------------------------------------------------------------------------

# Fit a GauPro Gaussian process model to the resampled performance estimates.
#
# `dat` is filtered to `metric` and, when given, to `eval_time`. It is then
# passed through check_gp_data() and reduced to the parameter columns plus
# `mean`. The parameters are encoded with partial_encode() and fit with the
# kernel from make_kernel(). The fit is judged by check_gp().
#
# Returns list(fit, use, rsq, tr). `tr` holds the encoded training data with
# syntactic column names.
fit_gp <- function(
  dat,
  pset,
  metric,
  eval_time = NULL,
  control,
  ...
) {
  tune::empty_ellipses(...)

  dat <- dat |> dplyr::filter(.metric == metric)

  if (!is.null(eval_time)) {
    dat <- dat |> dplyr::filter(.eval_time == eval_time)
  }

  dat <- dat |>
    check_gp_data() |>
    dplyr::select(dplyr::all_of(pset$id), mean)

  qual_info <- find_qual_param(pset)
  num_pred <- nrow(pset)
  num_cand <- nrow(dat)

  normalized <- partial_encode(dat, pset)
  # GauPro's formula interface fails on names containing `*` or `:`, so use
  # syntactic names; pred_gp() applies the same make.names() mapping
  colnames(normalized) <- make.names(colnames(normalized))

  gp_kernel <- make_kernel(pset, qual_info)

  # Scalar condition for `if`: use short-circuiting `&&` throughout
  if (control$verbose_iter && num_cand > 0 && num_cand <= num_pred) {
    msg <- cli::format_inline(
      "The Gaussian process model is being fit using {num_pred} feature{?s} but
			only has {num_cand} data point{?s} to do so. This may cause errors or a
			poor model fit."
    )
    message_wrap(
      msg,
      prefix = "!",
      color_text = get_tune_colors()$message$warning
    )
  }

  gp_fit <- try(
    GauPro::gpkm(
      .outcome ~ .,
      data = normalized,
      kernel = gp_kernel,
      verbose = 0,
      restarts = 5,
      nug.est = FALSE,
      parallel = FALSE
    ),
    silent = TRUE
  )

  new_check <- check_gp(gp_fit, control)

  if (control$verbose_iter) {
    if (new_check$use) {
      msg <- cli::format_inline(
        "Gaussian process model (LOO R\u00B2: {round(new_check$rsq * 100, 1)}%)"
      )
      message_wrap(
        msg,
        prefix = cli::symbol$tick,
        color_text = get_tune_colors()$message$success
      )
    } else {
      message_wrap(
        "Gaussian process model failed",
        prefix = cli::symbol$tick,
        color_text = get_tune_colors()$message$danger
      )
    }
  }

  list(
    fit = gp_fit,
    use = new_check$use,
    rsq = new_check$rsq,
    tr = normalized
  )
}

# ------------------------------------------------------------------------------

# Predict from the fitted GP in `object$fit` while capturing its output.
#
# Warnings raised during prediction are collected. Those whose text contains
# "Too small" are discarded; the rest are re-signalled one at a time through
# cli. Returns the prediction result itself.
quiet_pred_gp <- function(object, new_data, ...) {
  captured <- purrr::quietly(object$fit$pred)(new_data, ...)

  caught <- captured$warnings
  relevant <- caught[!grepl("Too small", caught)]
  for (w in relevant) {
    cli::cli_warn("{w}")
  }

  captured$result
}

# Generate candidate parameter combinations and score them with the GP.
#
# `size` candidates are drawn with a Latin hypercube design and de-duplicated.
#
# - If the GP is not usable (`object$use` is FALSE): candidates matching a
#   training point in `object$tr` are removed, and the single remaining
#   candidate farthest from the training data (per dissim_sample()) is
#   returned with NA `.mean`/`.sd`.
# - Otherwise: candidates already in `current` are removed, and the GP's
#   predicted mean and standard error are added as `.mean` and `.sd`.
pred_gp <- function(object, pset, size = 5000, current = NULL, control) {
  candidates <- dials::grid_space_filling(
    pset,
    size = size,
    type = "latin_hypercube"
  ) |>
    dplyr::distinct()

  if (!object$use) {
    x <- partial_encode(candidates, pset)
    colnames(x) <- make.names(colnames(x))
    # drop = FALSE keeps a data frame when there is a single parameter
    x_old <- object$tr[, names(x), drop = FALSE]

    # Tag each encoded row with its row in `candidates` before filtering:
    # dissim_sample() returns an index into the *filtered* rows, which no
    # longer lines up with `candidates` once existing points are removed.
    x$.candidate_row <- seq_len(nrow(x))
    x <- dplyr::anti_join(x, x_old, by = make.names(pset$id))
    candidate_row <- x$.candidate_row
    x$.candidate_row <- NULL

    if (length(candidate_row) == 0) {
      message_wrap(
        "No remaining candidate models",
        prefix = cli::symbol$tick,
        color_text = get_tune_colors()$message$warning
      )
      return(
        candidates[0, ] |>
          dplyr::mutate(.mean = NA_real_, .sd = NA_real_)
      )
    }

    keep_ind <- dissim_sample(x_old, x, pset, max_n = Inf)
    candidates <- candidates[candidate_row[keep_ind], ] |>
      dplyr::mutate(.mean = NA_real_, .sd = NA_real_)

    if (control$verbose_iter) {
      message_wrap(
        "Generating a candidate as far away from existing points as possible.",
        prefix = cli::symbol$info,
        color_text = get_tune_colors()$message$info
      )
    }

    return(candidates)
  }

  if (control$verbose_iter) {
    message_wrap(
      paste("Generating", nrow(candidates), "candidates."),
      prefix = cli::symbol$info,
      color_text = get_tune_colors()$message$info
    )
  }

  if (!is.null(current)) {
    candidates <- candidates |>
      dplyr::anti_join(current, by = pset$id)
  }

  if (nrow(candidates) == 0) {
    message_wrap(
      "No remaining candidate models",
      prefix = cli::symbol$tick,
      color_text = get_tune_colors()$message$warning
    )
    return(candidates |> dplyr::mutate(.mean = NA_real_, .sd = NA_real_))
  }

  x <- partial_encode(candidates, pset)
  colnames(x) <- make.names(colnames(x))

  gp_pred <- quiet_pred_gp(object, x, se.fit = TRUE)

  gp_pred <- tibble::as_tibble(gp_pred) |>
    dplyr::select(.mean = mean, .sd = se)
  dplyr::bind_cols(candidates, gp_pred)
}


# Choose the next parameter combination from scored candidates.
#
# When the GP produced predictions (not all `.mean` are NA) and
# `info$uncertainty` is below `control$uncertain`, the single row with the
# largest `objective` is returned. Otherwise an "uncertainty sample" is
# taken: one row drawn at random from the top 10% (at least one row) of
# candidates ranked by `.sd`. An empty `results` is returned unchanged.
pick_candidate <- function(results, info, control) {
  # sample_n(1) errors on an empty frame; there is nothing to pick from
  if (nrow(results) == 0) {
    return(results)
  }

  bad_gp <- all(is.na(results$.mean))
  # Scalar condition for `if`: use short-circuiting `&&`
  if (!bad_gp && info$uncertainty < control$uncertain) {
    results <- results |>
      dplyr::slice_max(objective, n = 1, with_ties = FALSE)
  } else {
    if (control$verbose_iter) {
      message_wrap(
        "Uncertainty sample",
        prefix = cli::symbol$info,
        color_text = get_tune_colors()$message$info
      )
    }
    n_top <- max(floor(.1 * nrow(results)), 1L)
    results <- results |>
      dplyr::arrange(dplyr::desc(.sd)) |>
      dplyr::slice(seq_len(n_top)) |>
      dplyr::sample_n(1)
  }
  results
}

# Pick the candidate farthest from the reference data (max-min distance).
#
# The first `max_n` rows of `candidates` are considered. For each one, its
# smallest Euclidean distance to any row of `ref_data` is computed, and the
# index of the candidate with the largest such distance is returned.
# Qualitative parameters (per find_qual_param(pset)) are converted to factors
# so model.matrix() expands them into indicator columns before distances are
# computed.
#
# Returns a single integer index into `candidates`, NA_integer_ when there are
# no candidates, or 1L when `ref_data` has no rows.
dissim_sample <- function(ref_data, candidates, pset, max_n = Inf) {
  max_n <- min(max_n, nrow(candidates))
  # seq_len(): 1:max_n would select rows c(1, 0) when max_n is 0
  candidates <- candidates[seq_len(max_n), , drop = FALSE]
  n_ref <- nrow(ref_data)
  n_cand <- nrow(candidates)

  if (n_cand == 0) {
    return(NA_integer_)
  }
  if (n_ref == 0) {
    # Nothing to keep away from, so every candidate is equally good
    return(1L)
  }

  all_data <- dplyr::bind_rows(ref_data, candidates)

  # Turn qualitative predictors into factors so they become 0/1 indicators
  qual_info <- find_qual_param(pset)
  for (nm in names(qual_info)) {
    uniq <- sort(unique(all_data[[nm]]))
    all_data[[nm]] <- factor(as.character(all_data[[nm]]), levels = uniq)
  }
  all_data <- stats::model.matrix(~ . + 0, data = all_data)

  distances <- as.matrix(stats::dist(all_data))
  # Rows: reference points; columns: candidates. drop = FALSE keeps a matrix
  # when there is a single reference point or a single candidate, otherwise
  # apply() below fails on a plain vector.
  distances <- distances[seq_len(n_ref), n_ref + seq_len(n_cand), drop = FALSE]

  # Zero distances are kept: a candidate that coincides with a reference point
  # is the worst choice and must not look "far away"
  min_distances <- apply(distances, 2, min)
  which.max(min_distances)[1]
}

Try the tune package in your browser

Any scripts or data that you put into this service are public.

tune documentation built on April 17, 2026, 5:07 p.m.