# R/geodist.R

# Defines functions: tidy.step_geodist, print.step_geodist, bake.step_geodist, geo_dist_calc_lat_lon, geo_dist_calc_xy, prep.step_geodist, check_is_lat_lon, step_geodist_new, step_geodist

# Documented in: step_geodist, tidy.step_geodist

#' Distance between two locations
#'
#' `step_geodist()` creates a *specification* of a recipe step that will
#' calculate the distance between points on a map to a reference location.
#'
#' @inheritParams step_pca
#' @inheritParams step_center
#' @param lon,lat Selector functions to choose which variables are
#'  used by the step. See [selections()] for more details.
#' @param ref_lon,ref_lat Single, non-missing numeric values for the location
#'  of the reference point. When `is_lat_lon = TRUE`, `ref_lat` must be
#'  between -90 and 90 and `ref_lon` must be between -180 and 180.
#' @param is_lat_lon A logical: Are coordinates in latitude and longitude? If
#'  `TRUE` the Haversine formula is used and the returned result is meters. If
#'  `FALSE` the Pythagorean formula is used. Default is `TRUE` and for recipes
#'  created from previous versions of recipes, a value of `FALSE` is used.
#'  Missing values (`NA`) are not allowed.
#' @param log A logical: should the distance be transformed by
#'  the natural log function? Missing values (`NA`) are not allowed.
#' @param name A single, non-missing character value to use for the new
#'  predictor column. If a column exists with this name, an error is
#'  issued.
#' @template step-return
#' @family multivariate transformation steps
#' @references https://en.wikipedia.org/wiki/Haversine_formula
#' @export
#' @details `step_geodist` uses the Pythagorean theorem to calculate Euclidean
#'  distances if `is_lat_lon` is FALSE. If `is_lat_lon` is TRUE, the Haversine
#'  formula is used to calculate the great-circle distance in meters.
#'
#' # Tidying
#'
#' When you [`tidy()`][tidy.recipe()] this step, a tibble with columns
#' echoing the values of `lat`, `lon`, `ref_lat`, `ref_lon`,
#' `is_lat_lon`, `name`, and `id` is returned.
#'
#' @template case-weights-not-supported
#'
#' @examplesIf rlang::is_installed("modeldata")
#' data(Smithsonian, package = "modeldata")
#'
#' # How close are the museums to Union Station?
#' near_station <- recipe(~., data = Smithsonian) %>%
#'   update_role(name, new_role = "location") %>%
#'   step_geodist(
#'     lat = latitude, lon = longitude, log = FALSE,
#'     ref_lat = 38.8986312, ref_lon = -77.0062457,
#'     is_lat_lon = TRUE
#'   ) %>%
#'   prep(training = Smithsonian)
#'
#' bake(near_station, new_data = NULL) %>%
#'   arrange(geo_dist)
#'
#' tidy(near_station, number = 1)
step_geodist <- function(recipe,
                         lat = NULL,
                         lon = NULL,
                         role = "predictor",
                         trained = FALSE,
                         ref_lat = NULL,
                         ref_lon = NULL,
                         is_lat_lon = TRUE,
                         log = FALSE,
                         name = "geo_dist",
                         columns = NULL,
                         keep_original_cols = TRUE,
                         skip = FALSE,
                         id = rand_id("geodist")) {
  # Scalar validation. `||` short-circuits, so `is.na()` is only reached for
  # a length-one value of the expected type. Rejecting `NA` here keeps the
  # range checks below (and the `if ()` tests at bake time) from failing with
  # "missing value where TRUE/FALSE needed".
  if (length(ref_lon) != 1 || !is.numeric(ref_lon) || is.na(ref_lon)) {
    rlang::abort("`ref_lon` should be a single numeric value.")
  }
  if (length(ref_lat) != 1 || !is.numeric(ref_lat) || is.na(ref_lat)) {
    rlang::abort("`ref_lat` should be a single numeric value.")
  }
  if (length(is_lat_lon) != 1 || !is.logical(is_lat_lon) || is.na(is_lat_lon)) {
    rlang::abort("`is_lat_lon` should be a single logical value.")
  }
  if (length(log) != 1 || !is.logical(log) || is.na(log)) {
    rlang::abort("`log` should be a single logical value.")
  }
  if (length(name) != 1 || !is.character(name) || is.na(name)) {
    rlang::abort("`name` should be a single character value.")
  }

  # Range checks only apply when the reference point is given in degrees.
  if (is_lat_lon && abs(ref_lat) > 90.0) {
    rlang::abort("`ref_lat` should be between -90 and 90")
  }
  if (is_lat_lon && abs(ref_lon) > 180.0) {
    rlang::abort("`ref_lon` should be between -180 and 180")
  }

  # The selectors are captured unevaluated; they are resolved against the
  # training data in `prep.step_geodist()`.
  add_step(
    recipe,
    step_geodist_new(
      lon = enquos(lon),
      lat = enquos(lat),
      role = role,
      trained = trained,
      ref_lon = ref_lon,
      ref_lat = ref_lat,
      is_lat_lon = is_lat_lon,
      log = log,
      name = name,
      columns = columns,
      keep_original_cols = keep_original_cols,
      skip = skip,
      id = id
    )
  )
}

# Low-level constructor for the geodist step. Every argument is handed to
# `step()` unchanged, together with `subclass = "geodist"`; no validation
# happens here.
step_geodist_new <- function(lon,
                             lat,
                             role,
                             trained,
                             ref_lon,
                             ref_lat,
                             is_lat_lon,
                             log,
                             name,
                             columns,
                             keep_original_cols,
                             skip,
                             id) {
  step(
    subclass = "geodist",
    # column selectors
    lon = lon, lat = lat,
    # generic step fields
    role = role, trained = trained,
    # reference point and distance options
    ref_lon = ref_lon, ref_lat = ref_lat,
    is_lat_lon = is_lat_lon, log = log,
    # output naming and bookkeeping
    name = name, columns = columns,
    keep_original_cols = keep_original_cols,
    skip = skip, id = id
  )
}

# Backward compatibility helper: a step object that has no `is_lat_lon`
# element gets one set to `FALSE`; an object that already carries the element
# is returned untouched.
check_is_lat_lon <- function(x) {
  has_flag <- "is_lat_lon" %in% names(x)
  if (has_flag) {
    return(x)
  }

  x$is_lat_lon <- FALSE
  x
}


#' @export
prep.step_geodist <- function(x, training, info = NULL, ...) {
  # Resolve the `lon` / `lat` selectors against the training data.
  lon_name <- recipes_eval_select(x$lon, training, info)
  lat_name <- recipes_eval_select(x$lat, training, info)
  # A single call covers every selected column, so the per-column type checks
  # do not need to be repeated further down.
  check_type(training[, c(lon_name, lat_name)], types = c("double", "integer"))

  # Steps created without an `is_lat_lon` field get `is_lat_lon = FALSE`.
  x <- check_is_lat_lon(x)

  if (length(lon_name) > 1) {
    rlang::abort("`lon` should resolve to a single column name.")
  }
  if (length(lat_name) > 1) {
    rlang::abort("`lat` should resolve to a single column name.")
  }
  # At this point each selector matched zero or one column. `bake()` reads the
  # stored columns positionally as c(lat, lon), so a selection where only one
  # of the two matched a column cannot be used; fail here with a clear message.
  if (length(lon_name) != length(lat_name)) {
    rlang::abort("`lon` and `lat` should both resolve to a single column name.")
  }

  # Return the trained step; `columns` keeps the order c(lat, lon).
  step_geodist_new(
    lon = x$lon,
    lat = x$lat,
    role = x$role,
    trained = TRUE,
    ref_lon = x$ref_lon,
    ref_lat = x$ref_lat,
    is_lat_lon = x$is_lat_lon,
    log = x$log,
    name = x$name,
    columns = c(lat_name, lon_name),
    keep_original_cols = get_keep_original_cols(x),
    skip = x$skip,
    id = x$id
  )
}

# TODO case weights for distances
# Pythagorean (straight-line) distance between the points (x_1, y_1) and
# (x_2, y_2). Plain vector arithmetic, so the usual recycling rules apply.
geo_dist_calc_xy <- function(x_1, y_1, x_2, y_2) {
  delta_x <- x_1 - x_2
  delta_y <- y_1 - y_2

  sqrt(delta_x^2L + delta_y^2L)
}


# Great-circle distance between (x_1, y_1) and (x_2, y_2) using the Haversine
# formula.
#
# x_1, y_1      longitude and latitude in degrees; checked against
#               [-180, 180] and [-90, 90] respectively.
# x_2, y_2      longitude and latitude of the other point, in degrees.
# earth_radius  radius multiplied onto the central angle;
#               earth_radius = 6371e3 in meters
# call          passed to `rlang::abort()` when a range check fails.
#
# Missing coordinates do not trigger an error; they propagate to `NA` in the
# result.
geo_dist_calc_lat_lon <- function(x_1, y_1, x_2, y_2, earth_radius = 6371e3,
                                  call = caller_env()) {
  # `na.rm = TRUE`: without it a missing coordinate makes `any()` return `NA`
  # and `if ()` fails with "missing value where TRUE/FALSE needed".
  if (any(abs(x_1) > 180.0, na.rm = TRUE)) {
    rlang::abort("All `lon` values should be between -180 and 180", call = call)
  }

  if (any(abs(y_1) > 90.0, na.rm = TRUE)) {
    rlang::abort("All `lat` values should be between -90 and 90", call = call)
  }

  to_rad <- pi / 180.0

  y_1 <- y_1 * to_rad
  y_2 <- y_2 * to_rad

  x_1 <- x_1 * to_rad
  x_2 <- x_2 * to_rad

  # Half-differences, as used by the Haversine formula.
  delta_lat <- (y_2 - y_1) / 2.0
  delta_lon <- (x_2 - x_1) / 2.0

  a <- sin(delta_lat)^2L + cos(y_1) * cos(y_2) * sin(delta_lon)^2L
  # Floating-point rounding can leave `a` marginally outside [0, 1] (for
  # example for near-antipodal points), which would turn `sqrt(1 - a)` into
  # NaN. Clamp it; `NA` values pass through unchanged.
  a <- pmin(pmax(a, 0), 1)
  central_angle <- 2.0 * atan2(sqrt(a), sqrt(1 - a))

  earth_radius * central_angle
}


#' @export
bake.step_geodist <- function(object, new_data, ...) {
  col_names <- names(object$columns)
  check_new_data(col_names, object, new_data)

  # No columns recorded on the step: hand the data back untouched.
  if (length(col_names) == 0) {
    return(new_data)
  }

  object <- check_is_lat_lon(object)

  # `prep()` stores the columns as c(lat, lon).
  lon_vals <- new_data[[col_names[2]]]
  lat_vals <- new_data[[col_names[1]]]

  # Haversine distance for latitude/longitude input, Pythagorean otherwise.
  dist_vals <- if (object$is_lat_lon) {
    geo_dist_calc_lat_lon(lon_vals, lat_vals, object$ref_lon, object$ref_lat)
  } else {
    geo_dist_calc_xy(lon_vals, lat_vals, object$ref_lon, object$ref_lat)
  }

  if (object$log) {
    dist_vals <- log(dist_vals)
  }

  # One-column tibble named after the `name` field of the step.
  geo_data <- tibble(dist_vals)
  names(geo_data) <- object$name
  geo_data <- check_name(geo_data, new_data, object, newname = object$name)

  combined <- vec_cbind(new_data, geo_data)
  remove_original_cols(combined, object, col_names)
}

# Print method for geodist steps. Builds a title containing the reference
# latitude and longitude (formatted with up to 10 significant digits), passes
# it to `print_step()` together with the selected columns / selectors, and
# returns `x` invisibly.
#' @export
print.step_geodist <-
  function(x, width = max(20, options()$width - 30), ...) {
    # The title deliberately ends in "using "; presumably `print_step()`
    # appends the variable names after it (verify against `print_step()`).
    title <- paste(
      "Geographical distances from",
      format(x$ref_lat, digits = 10), "x",
      format(x$ref_lon, digits = 10), "using "
    )
    print_step(x$columns, c(x$lat, x$lon), x$trained, title, width)
    invisible(x)
  }



#' @rdname tidy.recipe
#' @export
tidy.step_geodist <- function(x, ...) {
  x <- check_is_lat_lon(x)

  # The trained state has to be tested first: `columns` is `NULL` until the
  # step has been prepped, so testing `length(x$columns) == 0` first would send
  # every untrained step down the "empty selection" branch and never report
  # its selectors.
  if (!is_trained(x)) {
    # Untrained: report the selectors as text.
    res <- tibble(
      latitude = sel2char(x$lat),
      longitude = sel2char(x$lon),
      ref_latitude = x$ref_lat,
      ref_longitude = x$ref_lon,
      is_lat_lon = x$is_lat_lon,
      name = x$name
    )
  } else if (length(x$columns) == 0) {
    # Trained, but no columns were selected: zero-row tibble with the same
    # column layout.
    res <- tibble(
      latitude = character(),
      longitude = character(),
      ref_latitude = double(),
      ref_longitude = double(),
      is_lat_lon = logical(),
      name = character()
    )
  } else {
    # Trained: `columns` is stored as c(lat, lon).
    res <- tibble(
      latitude = x$columns[1],
      longitude = x$columns[2],
      ref_latitude = x$ref_lat,
      ref_longitude = x$ref_lon,
      is_lat_lon = x$is_lat_lon,
      name = x$name
    )
  }

  res$id <- x$id
  res
}

# Try the recipes package in your browser

# Any scripts or data that you put into this service are public.

# recipes documentation built on Aug. 26, 2023, 1:08 a.m.