R/ensemble.R

Defines functions detect_sdg

Documented in detect_sdg

#' Detect SDGs in text using ensemble model
#'
#' \code{detect_sdg} identifies SDGs in text using an ensemble model approach considering multiple existing SDG query systems and text length.
#'
#' \code{detect_sdg} implements an ensemble model to detect SDGs in text. The ensemble model combines the six systems implemented by \code{\link{detect_sdg_systems}} and text length in a random forest architecture. The ensemble model has been trained on three data sets with SDG labels assigned by experts and a matching number of synthetic texts generated by random sampling from a word frequency list. The user has the choice of multiple versions of the ensemble model that have been trained on different amounts of synthetic texts to adjust the sensitivity and specificity of the model. Increasing the amount of synthetic data makes the ensemble more conservative, leading to decreased sensitivity and increased specificity.
#'
#' By default, \code{detect_sdg} implements the version of the ensemble model that has been trained on an equal amount of expert-labeled and synthetic data, providing a reasonable balance between sensitivity and specificity. For details, see the article by Wulff et al. (2023).
#'
#'
#' @param text \code{character} vector or object of class \code{tCorpus} containing text in which SDGs shall be detected.
#' @param systems Deprecated. As of text2sdg 1.0.0 the `systems` argument of `detect_sdg()` is no longer supported and supplying it raises an error. This is because `detect_sdg()` now makes use of an ensemble approach that draws on all systems as well as on the text length, see Wulff et al. (2023) for more information. The old version of `detect_sdg()` is available through the `detect_sdg_systems()` function.
#' @param output Deprecated. As of text2sdg 1.0.0 the `output` argument of `detect_sdg()` is no longer supported and supplying it raises an error. This is because `detect_sdg()` now makes use of an ensemble approach that draws on all systems as well as on the text length, see Wulff et al. (2023) for more information. The old version of `detect_sdg()` is available through the `detect_sdg_systems()` function.
#' @param sdgs \code{numeric} vector with integers between 1 and 17 specifying the sdgs to identify in \code{text}. Defaults to \code{1:17}.
#' @param synthetic \code{character} vector specifying the ensemble version(s) to be used. These versions vary in terms of the amount of synthetic data used in training (relative to the amount of expert-labeled data). Can be one or more of \code{"none"}, \code{"third"}, \code{"equal"}, and \code{"triple"}. Each selected version contributes its own set of hits. The default is \code{"equal"}.
#' @param verbose \code{logical} specifying whether messages on the function's progress should be printed.
#'
#' @return The function returns a \code{tibble} containing the SDG hits found in the vector of documents. If no SDGs are detected, a zero-row \code{tibble} with the same columns is returned. The columns of the \code{tibble} are described below. The \code{tibble} also includes as an attribute with name \code{"system_hits"} the predictions of the individual systems produced by \code{detect_sdg_systems()}.
#' \describe{
#'  \item{document}{Index of the element in \code{text} where match was found. Formatted as a factor with the number of levels matching the original number of documents.}
#'  \item{sdg}{Label of the SDG found in document, formatted as \code{"SDG-01"} to \code{"SDG-17"}.}
#'  \item{system}{Name of the ensemble version that produced the match, i.e., \code{"Ensemble "} followed by the value of \code{synthetic} (e.g., \code{"Ensemble equal"}).}
#'  \item{hit}{Running index of the hit within each ensemble version.}
#' }
#'
#' @references Wulff, D. U., Meier, D., & Mata, R. (2023). Using novel data and ensemble models to improve automated SDG-labeling. arXiv
#' @importFrom ranger treeInfo
#'
#' @examples
#' \donttest{
#' # run sdg detection
#' hits <- detect_sdg(projects)
#'
#' # run sdg detection for sdg 3 only
#' hits <- detect_sdg(projects, sdgs = 3)
#'
#' # extract systems hits
#' attr(hits, "system_hits")
#' }
#' @export
detect_sdg <- function(text,
                       systems = lifecycle::deprecated(),
                       output = lifecycle::deprecated(),
                       sdgs = 1:17,
                       synthetic = c("equal"),
                       verbose = TRUE) {

  # `systems` and `output` are defunct since 1.0.0: deprecate_stop() raises an
  # error that points users to detect_sdg_systems().
  if (lifecycle::is_present(systems)) {
    lifecycle::deprecate_stop("1.0.0", "text2sdg::detect_sdg(systems = )", details = "As of text2sdg 1.0.0, the `systems` argument of `detect_sdg()` is deprecated. This is because `detect_sdg()` now implements an ensemble model that pools the predictions of all other systems and considers text length, see `?detect_sdg` for more information. The old functionality of `detect_sdg()` is now provided by the `detect_sdg_systems()` function.")
  }

  if (lifecycle::is_present(output)) {
    lifecycle::deprecate_stop("1.0.0", "text2sdg::detect_sdg(output = )", details = "As of text2sdg 1.0.0, the `output` argument of `detect_sdg()` is deprecated. This is because `detect_sdg()` now implements an ensemble model that pools the predictions of all other systems and considers text length, see `?detect_sdg` for more information. The old functionality of `detect_sdg()` is now provided by the `detect_sdg_systems()` function.")
  }

  # Zero-row result with the documented columns. The individual system
  # predictions are attached as attribute "system_hits" so that early returns
  # honour the same contract as the regular return value.
  empty_hits <- function(system_hits) {
    out <- tibble::tibble(
      document = factor(),
      sdg = character(),
      system = character(),
      hit = integer()
    )
    attr(out, "system_hits") <- system_hits
    out
  }

  # make corpus
  if (inherits(text, "character")) {
    if (length(text) == 1 && text == "") {
      stop("Argument text must not be an empty string.")
    }
    corpus <- make_corpus(text)
  } else if (inherits(text, "tCorpus")) {
    corpus <- text
  } else {
    stop("Argument text must be either class character or corpustools::tCorpus.")
  }

  # test model selector
  if (any(!(synthetic %in% c("none", "third", "equal", "triple")))) {
    stop('Argument synthetic must be one or more of "none","third","equal", or "triple".')
  }

  # the query systems whose hits serve as ensemble features
  feature_systems <- c("Aurora", "Elsevier", "Auckland", "SIRIS", "SDSN", "SDGO")

  # run systems
  if (verbose) cat("Running systems\n", sep = "")
  system_hits <- detect_sdg_systems(
    text = corpus,
    sdgs = sdgs,
    systems = feature_systems,
    output = "documents",
    verbose = FALSE
  )

  # return empty tibble if no SDGs were detected
  if (nrow(system_hits) == 0) {
    return(empty_hits(system_hits))
  }

  # add lengths (number of tokens per document)
  if (verbose) cat("Obtaining text lengths\n", sep = "")
  lens <- table(corpus$tokens$doc_id)
  lens <- tibble::tibble(
    document = factor(names(lens)),
    n_words = c(lens)
  )

  # generate features: one row per document x SDG, one logical column per
  # system (TRUE if that system flagged the SDG), plus the text length
  if (verbose) cat("Building features\n", sep = "")
  tbl <- tibble::tibble(document = factor(seq_len(corpus$n_meta))) %>%
    dplyr::left_join(system_hits %>%
      dplyr::select(document, sdg, system) %>%
      dplyr::mutate(hit = TRUE),
    by = "document"
    ) %>%
    dplyr::mutate(system = factor(system, levels = feature_systems)) %>%
    tidyr::complete(document, sdg, system) %>%
    dplyr::filter(!is.na(system)) %>%
    dplyr::mutate(hit = dplyr::case_when(is.na(hit) ~ FALSE, TRUE ~ hit)) %>%
    tidyr::pivot_wider(names_from = system, values_from = hit) %>%
    dplyr::left_join(lens, by = "document")

  # get around ::: warning
  predict.ranger <- utils::getFromNamespace("predict.ranger", "ranger")

  ignore_unused_imports <- function() {
    ranger::treeInfo
  }

  if (verbose) cat("Running ensemble\n", sep = "")

  # A fixed seed is set before every ranger prediction below so that the
  # ensemble output is reproducible. Remember the caller's RNG state and
  # restore it on exit so that calling detect_sdg() does not reset the user's
  # random number stream as a side effect.
  had_seed <- exists(".Random.seed", envir = globalenv(), inherits = FALSE)
  old_seed <- if (had_seed) get(".Random.seed", envir = globalenv(), inherits = FALSE)
  on.exit({
    if (had_seed) {
      assign(".Random.seed", old_seed, globalenv())
    } else if (exists(".Random.seed", envir = globalenv(), inherits = FALSE)) {
      rm(".Random.seed", envir = globalenv())
    }
  }, add = TRUE)

  # model labels as used in text2sdgData::ensembles and in the `sdg` column;
  # de-duplicated so a repeated SDG does not yield duplicate hits
  sdg_labels <- sort(unique(paste0("SDG-", ifelse(sdgs < 10, "0", ""), sdgs)))

  # the SDG-17 model is fed only these features
  sdg17_features <- c("Aurora", "SDGO", "SDSN", "n_words")

  # run every requested ensemble version on every requested SDG
  hits <- lapply(unique(synthetic), function(synt) {
    ensemble_sel <- text2sdgData::ensembles[[synt]]

    hits_ensemble <- lapply(sdg_labels, function(sdg_label) {
      tbl_sdg <- tbl %>% dplyr::filter(sdg == !!sdg_label)
      if (nrow(tbl_sdg) == 0) {
        return(NULL)
      }
      # decide by SDG label, not by loop position, which feature set applies
      if (sdg_label == "SDG-17") {
        tbl_sdg <- tbl_sdg %>% dplyr::select(document, dplyr::all_of(sdg17_features))
      }
      # set seed for ranger model
      set.seed(1)
      tibble::tibble(
        document = tbl_sdg %>% dplyr::pull(document),
        sdg = sdg_label,
        pred = predict.ranger(ensemble_sel[[sdg_label]], data = tbl_sdg)$predictions
      )
    })

    dplyr::bind_rows(hits_ensemble) %>%
      dplyr::mutate(system = paste0("Ensemble ", !!synt))
  })

  # combine hits from all ensemble models
  hits <- dplyr::bind_rows(hits)

  # return early if no ensemble predicted any SDG
  if (nrow(hits) == 0 || !any(hits$pred == 1, na.rm = TRUE)) {
    return(empty_hits(system_hits))
  }

  # output: keep positive predictions and number them per ensemble version
  hits <- hits %>%
    dplyr::filter(pred == 1) %>%
    dplyr::select(-pred) %>%
    dplyr::group_by(system) %>%
    dplyr::mutate(hit = seq_len(dplyr::n())) %>%
    dplyr::ungroup() %>%
    dplyr::arrange(document, sdg, system)

  # set attribute
  attr(hits, "system_hits") <- system_hits

  hits
}

Try the text2sdg package in your browser

Any scripts or data that you put into this service are public.

text2sdg documentation built on March 31, 2023, 7:22 p.m.