R/dplyr-mutate.R

Defines functions ensure_named_exprs check_transmute_args transmute.arrow_dplyr_query mutate.arrow_dplyr_query

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.


# The following S3 methods are registered on load if dplyr is present

# mutate() for Arrow objects.
#
# Each expression in `...` is evaluated against an Arrow data mask, and the
# resulting Expressions are stored as (new or replaced) columns of the
# returned arrow_dplyr_query.
#
# Arguments mirror dplyr::mutate():
#   .by             per-call grouping. It is applied only while this call
#                   runs and is removed from the result again, both on the
#                   "nothing to do" shortcut and on the normal path.
#   .keep           which pre-existing columns to retain. Columns created or
#                   overwritten by this call, and grouping columns, are always
#                   kept (as in dplyr).
#   .before/.after  where to place newly created columns (via relocate()).
#
# Expressions that contain aggregation functions (window functions), that
# arrow_eval() reports as errors, or that evaluate to literals of length != 1
# are passed to abandon_ship() together with the original `.data`.
mutate.arrow_dplyr_query <- function(.data,
                                     ...,
                                     .by = NULL,
                                     .keep = c("all", "used", "unused", "none"),
                                     .before = NULL,
                                     .after = NULL) {
  call <- match.call()
  out <- as_adq(.data)

  by <- compute_by({{ .by }}, out, by_arg = ".by", data_arg = ".data")

  if (by$from_by) {
    out$group_by_vars <- by$names
  }
  grv <- out$group_by_vars
  expression_list <- expand_across(out, quos(...), exclude_cols = grv)
  exprs <- ensure_named_exprs(expression_list)

  .keep <- match.arg(.keep)
  .before <- enquo(.before)
  .after <- enquo(.after)

  if (.keep %in% c("all", "unused") && length(exprs) == 0) {
    # Nothing to do. The grouping set from `.by` above only applies to this
    # call, so remove it before returning (same as the normal path below).
    if (by$from_by) {
      out$group_by_vars <- character()
    }
    return(out)
  }

  # Restrict the cases we support for now
  has_aggregations <- any(unlist(lapply(exprs, all_funs)) %in% names(agg_funcs))
  if (has_aggregations) {
    # ARROW-13926
    # mutate() on a grouped dataset does calculations within groups
    # This doesn't matter on scalar ops (arithmetic etc.) but it does
    # for things with aggregations (e.g. subtracting the mean)
    return(abandon_ship(call, .data, "window functions not currently supported in Arrow"))
  }

  mask <- arrow_mask(out)
  results <- list()
  for (i in seq_along(exprs)) {
    # Iterate over the indices and not the names because names may be repeated
    # (which overwrites the previous name)
    new_var <- names(exprs)[i]
    results[[new_var]] <- arrow_eval(exprs[[i]], mask)
    if (inherits(results[[new_var]], "try-error")) {
      msg <- handle_arrow_not_supported(
        results[[new_var]],
        format_expr(exprs[[i]])
      )
      return(abandon_ship(call, .data, msg))
    } else if (!inherits(results[[new_var]], "Expression") &&
      !is.null(results[[new_var]])) {
      # We need some wrapping to handle literal values
      if (length(results[[new_var]]) != 1) {
        msg <- paste0("In ", new_var, " = ", format_expr(exprs[[i]]), ", only values of size one are recycled")
        return(abandon_ship(call, .data, msg))
      }
      results[[new_var]] <- Expression$scalar(results[[new_var]])
    }
    # Put it in the data mask too, so later expressions can refer to it
    mask[[new_var]] <- mask$.data[[new_var]] <- results[[new_var]]
  }

  old_vars <- names(out$selected_columns)
  # Note that this is names(exprs) not names(results):
  # if results$new_var is NULL, that means we are supposed to remove it
  new_vars <- names(exprs)

  # Assign the new columns into the out$selected_columns
  for (new_var in new_vars) {
    out$selected_columns[[new_var]] <- results[[new_var]]
  }

  # Deduplicate new_vars and remove NULL columns from new_vars
  new_vars <- intersect(union(new_vars, grv), names(out$selected_columns))

  # Respect .before and .after
  if (!quo_is_null(.before) || !quo_is_null(.after)) {
    new <- setdiff(new_vars, old_vars)
    out <- dplyr::relocate(out, all_of(new), .before = !!.before, .after = !!.after)
  }

  # Respect .keep
  if (.keep == "none") {
    ## for consistency with dplyr, this appends new columns after existing columns
    ## by specifying the order
    new_cols_last <- c(intersect(old_vars, new_vars), setdiff(new_vars, old_vars))
    out$selected_columns <- out$selected_columns[new_cols_last]
  } else if (.keep != "all") {
    # "used" or "unused". Only pre-existing columns that this call did not
    # (re)define are candidates for dropping. Columns created or overwritten
    # here, and group vars (which are part of new_vars), are always retained,
    # matching dplyr.
    used_vars <- unlist(lapply(exprs, all.vars), use.names = FALSE)
    droppable <- setdiff(old_vars, new_vars)
    if (.keep == "used") {
      out$selected_columns[setdiff(droppable, used_vars)] <- NULL
    } else {
      # "unused"
      out$selected_columns[intersect(droppable, used_vars)] <- NULL
    }
  }

  if (by$from_by) {
    out$group_by_vars <- character()
  }

  # Even if "none", we still keep group vars
  ensure_group_vars(out)
}
mutate.Dataset <- mutate.ArrowTabular <- mutate.RecordBatchReader <- mutate.arrow_dplyr_query

# transmute() for Arrow objects.
#
# Implemented as mutate(.keep = "none") followed by a select() that puts the
# group vars first and then the new columns in the order they were given.
# `.keep`, `.before` and `.after` are rejected by check_transmute_args().
# If no expressions are given, or any of them is NULL (a column removal), the
# result of mutate() is returned without reordering.
transmute.arrow_dplyr_query <- function(.data, ...) {
  dots <- check_transmute_args(...)
  .data <- as_adq(.data)
  grv <- .data$group_by_vars
  expression_list <- expand_across(.data, dots, exclude_cols = grv)

  has_null <- map_lgl(expression_list, quo_is_null)
  .data <- dplyr::mutate(.data, !!!expression_list, .keep = "none")
  if (is_empty(expression_list) || any(has_null)) {
    return(.data)
  }

  ## keeping with: https://github.com/tidyverse/dplyr/issues/6086
  # Unnamed expressions are named by mutate() via ensure_named_exprs(). Use
  # the same helper here so the names we select always match the columns
  # mutate() actually created. as_label() can name some expressions
  # differently (e.g. `.data$x` or long deparsed calls), which made select()
  # fail on a non-existent column.
  transmute_order <- names(ensure_named_exprs(expression_list))
  dplyr::select(.data, all_of(c(grv, transmute_order)))
}
transmute.Dataset <- transmute.ArrowTabular <- transmute.RecordBatchReader <- transmute.arrow_dplyr_query

# This function is a copy of dplyr:::check_transmute_args at
# https://github.com/tidyverse/dplyr/blob/main/R/mutate.R
check_transmute_args <- function(..., .keep, .before, .after) {
  # transmute() controls column retention and placement itself, so these
  # mutate()-only arguments are refused. When several are supplied, the error
  # names the first one in the order .keep, .before, .after.
  supplied <- c(
    .keep = !missing(.keep),
    .before = !missing(.before),
    .after = !missing(.after)
  )
  messages <- c(
    .keep = "`transmute()` does not support the `.keep` argument",
    .before = "`transmute()` does not support the `.before` argument",
    .after = "`transmute()` does not support the `.after` argument"
  )
  if (any(supplied)) {
    abort(messages[[which(supplied)[1]]])
  }
  # Everything else is captured lazily as quosures for the caller.
  enquos(...)
}

ensure_named_exprs <- function(exprs) {
  # Give every expression whose name is the empty string a name built from
  # its own text. format_expr() reduces long, multi-line deparses to a
  # single line. Expressions that already have a name are left untouched.
  needs_name <- which(names(exprs) == "")
  generated <- vapply(exprs[needs_name], format_expr, character(1))
  names(exprs)[needs_name] <- generated
  exprs
}

Try the arrow package in your browser

Any scripts or data that you put into this service are public.

arrow documentation built on Nov. 25, 2023, 1:09 a.m.