R/datasets.R

Defines functions hf_load_dataset

Documented in hf_load_dataset

#' Load a dataset from the Hugging Face Hub
#'
#' Downloads a dataset through the Python `datasets` library and converts the
#' requested split(s) to tibbles. Useful for getting pre-made datasets for
#' exploratory analysis, or as a means of evaluating fine-tuned models.
#'
#' @param dataset The name of a Hugging Face dataset saved on the Hub
#'   (browse \url{https://huggingface.co/datasets} to find one).
#' @param split `NULL` (the default) to load every split, or a character
#'   vector of split names. Usually one of 'train', 'test', 'validation';
#'   however, check the dataset's meta data on the Hub first.
#' @param label_conversion Which label conversion to add as an extra column:
#'   `"str2int"` adds a `label_id` column, `"int2str"` adds a `label_name`
#'   column, and `NULL` adds nothing. When the argument is left at its default
#'   the first choice (`"str2int"`) is used. Any other value produces a
#'   warning and no conversion.
#'
#' @returns A tibble when exactly one split is loaded, otherwise a named list
#'   of tibbles (one per split).
#' @export
#' @seealso
#' \url{https://huggingface.co/docs/datasets/index}
#' @examples
#' \dontrun{
#' (emo_all_splits <- hf_load_dataset("emo"))
#'
#' (imdb_train <- hf_load_dataset("imdb", split = "train"))
#' }
hf_load_dataset <- function(dataset, split = NULL,
                            label_conversion = c("str2int", "int2str", NULL)) {

  hf_import_datasets_transformers()

  # The default is a length-2 vector, and `&&` / `if` need a scalar (R >= 4.3
  # errors otherwise). Collapse it to NULL or a single recognised string,
  # keeping the first element so the historical default is preserved.
  if (length(label_conversion) == 0L) {
    label_conversion <- NULL
  } else {
    label_conversion <- as.character(label_conversion[[1L]])
    if (!label_conversion %in% c("str2int", "int2str")) {
      warning(
        "`label_conversion` should be 'str2int', 'int2str' or NULL; ",
        "no label conversion will be applied.",
        call. = FALSE
      )
      label_conversion <- NULL
    }
  }

  # Read in the dataset in Hugging Face datasets format.
  .dataset <- reticulate::py$load_dataset(dataset)
  split_names <- names(.dataset)

  if (is.null(split)) {
    # No split supplied: load every split the dataset offers.
    splits <- split_names
  } else {
    # Fail early, listing the valid options, if any requested split is unknown.
    if (length(setdiff(split, split_names)) > 0L) {
      stop(
        "The split you're looking for is not available for this dataset, try one of: ",
        paste0(split_names, collapse = ";"),
        call. = FALSE
      )
    }
    splits <- split
  }

  # Convert each requested split to a tibble via pandas.
  datasets <- purrr::map(
    splits,
    function(split_name) {
      tibble::as_tibble(.dataset[[split_name]]$to_pandas())
    }
  )
  names(datasets) <- splits

  # An unsupervised split is set aside so that no label conversion is applied
  # to it; it is added back, unchanged, before returning.
  unsupervised <- NULL
  if ("unsupervised" %in% splits) {
    unsupervised <- datasets[["unsupervised"]]
    datasets <- datasets[names(datasets) != "unsupervised"]
    if (length(datasets) > 0L) {
      message("Unsupervised detected in splits, so adding labels to other splits and leaving unsupervised as is")
    }
  }

  # Fetch the int2str / str2int converters from the first split's label
  # feature; fall back to no conversion when there is no 'label' feature.
  if (!is.null(label_conversion)) {
    features <- .dataset[[splits[[1]]]][["features"]]

    if ("label" %in% names(features)) {
      label_feature <- features[["label"]]
      int2str <- label_feature[["int2str"]]
      str2int <- label_feature[["str2int"]]
    } else {
      message("'label' not found in data set's column names, defaulting to no label conversion")
      label_conversion <- NULL
    }
  }

  # Add the converted label column to every (supervised) split.
  if (identical(label_conversion, "int2str")) {
    datasets <- purrr::map(datasets, function(df) {
      df[["label_name"]] <- int2str(df[["label"]])
      df
    })
  } else if (identical(label_conversion, "str2int")) {
    datasets <- purrr::map(datasets, function(df) {
      df[["label_id"]] <- str2int(df[["label"]])
      df
    })
  }

  # Keep only elements that really are data frames.
  is_df <- purrr::map_lgl(datasets, function(x) inherits(x, "data.frame"))
  datasets <- datasets[is_df]

  if (!is.null(unsupervised)) {
    # Only the unsupervised split was requested (or survived): return it as is.
    if (length(datasets) == 0L) {
      return(unsupervised)
    }
    # Otherwise put it back and restore the order the splits were requested in.
    datasets[["unsupervised"]] <- unsupervised
    datasets <- datasets[intersect(splits, names(datasets))]
  }

  # A single split is returned as a bare tibble rather than a length-1 list.
  if (length(datasets) == 1L) {
    datasets[[1L]]
  } else {
    datasets
  }
}
##' examples
##' dontrun{
##' # Retrieve the 'emo' dataset
##' emo <- hf_load_dataset("emo")
##' emo
##' # Extract and visualize the training split in the emotion data
##' hf_load_dataset("emo", split = "train") %>%
##'   dplyr::add_count(label) %>%
##'   dplyr::mutate(
##'     label = forcats::fct_reorder(as.factor(label), n)
##'   ) %>%
##'   ggplot2::ggplot(ggplot2::aes(label)) +
##'   ggplot2::geom_bar()
##' }

#Old version, temporarily maintained for posterity
# function(dataset,
#          split = NULL,
#          as_tibble = FALSE,
#          label_name = NULL,
#          ...) {
#   hf_import_datasets_transformers()
#
#   if (!as_tibble & !is.null(label_name)) {
#     stop("label_name must be specified with as_tibble = TRUE")
#   }
#
#
#   # Get this for str2int and int2str mapping later
#   dataset_base <- reticulate::py$load_dataset(dataset)
#
#   # If we just want the basic data set, unedited:
#   if (is.null(split) & !as_tibble & is.null(label_name)) {
#     return(reticulate::py$load_dataset(dataset, ...))
#
#     # If we want a specific, unedited split:
#   } else if (!is.null(split) & !as_tibble & is.null(label_name)) {
#     return(reticulate::py$load_dataset(dataset, split = split, ...))
#
#     # If we want all splits as a tibble without label_name specified:
#   } else if (is.null(split) & as_tibble == TRUE) {
#     dataset_load <- reticulate::py$load_dataset(dataset)
#     split_names <- names(dataset_load)
#     hf_data <- NULL
#
#     for (name in split_names) {
#       hf_dataset_loop <-
#         reticulate::py$load_dataset(dataset, split = name, ...)
#
#       hf_data_loop <- dplyr::as_tibble(hf_dataset_loop$to_pandas())
#
#       hf_data <- dplyr::bind_rows(hf_data, hf_data_loop)
#     }
#
#     # Adding the str2int or int2str logic, and if argument is blank, return the dataset
#     if (is.null(label_name)) {
#       return(hf_data)
#     } else if (label_name == "int2str") {
#       hf_data$label_name <- dataset_base$train$features$label$int2str(as.integer(hf_data$label))
#     } else if (label_name == "str2int") {
#       hf_data$label_name <- dataset_base$train$features$label$str2int(as.character(hf_data$label))
#     }
#
#     return(hf_data)
#
#     # Now add splits logic
#   } else if (!is.null(split) & as_tibble == TRUE) {
#     hf_data <- tibble::tibble(reticulate::py$load_dataset(dataset, split = split)$to_pandas())
#
#     # Now add str2int and int2str logic for splits (this could be refactored to not duplicate code, but ok for now)
#     if (is.null(label_name)) {
#       return(hf_data)
#     } else if (label_name == "int2str") {
#       hf_data$label_name <- dataset_base$train$features$label$int2str(as.integer(hf_data$label))
#     } else if (label_name == "str2int") {
#       hf_data$label_name <- dataset_base$train$features$label$str2int(as.character(hf_data$label))
#     }
#
#     return(hf_data)
#   }
# }
farach/huggingfaceR documentation built on Feb. 4, 2023, 10:31 p.m.