R/tokenize.R

Defines functions tokenize.character tokenize.default tokenize create_tagger tagger_impl tagger_inner

#' Tokenize a character vector with 'vibrato'
#'
#' Calls `vbrt()` and tidies its result into a tibble whose `sentence_id`
#' and `token_id` columns are integers.
#'
#' @param x Character vector to be tokenized.
#' @param sys_dic,user_dic,max_grouping_len Passed on to `vbrt()` unchanged.
#' @param verbose Logical scalar. When `FALSE`, the last four columns of the
#' result (the extra debugging information) are dropped; when `TRUE`, every
#' column is kept.
#' @returns A tibble.
#' @noRd
tagger_inner <- function(x, sys_dic, user_dic, max_grouping_len, verbose) {
  tokens <- dplyr::as_tibble(vbrt(x, sys_dic, user_dic, max_grouping_len))
  tokens <- dplyr::mutate(
    tokens,
    sentence_id = as.integer(.data$sentence_id),
    token_id = as.integer(.data$token_id)
  )
  # Number of trailing columns to drop; 0 keeps everything (verbose mode).
  n_drop <- if (verbose) 0 else 4
  dplyr::select(tokens, 1:dplyr::last_col(n_drop))
}

#' Wrapper that takes a tagger function
#'
#' Runs `tagger` over `sentences` and attaches the matching document name
#' to every token row as a `doc_id` column.
#'
#' @details
#' `tagger` is expected to be a function that takes a single argument
#' (character vector to be tokenized) and returns a data.frame
#' containing the following columns:
#'
#' * `sentence_id`
#' * `token`
#' * `feature`
#'
#' The join on `sentence_id` assumes that `tagger` numbers the elements of
#' its input 1, 2, ... in input order.
#'
#' @param sentences A character vector to be tokenized.
#' @param docnames A vector of document names, one per element of
#' `sentences`.
#' @param split Logical. When `TRUE`, each element of `sentences` is first
#' split into sentences with [stringi::stri_split_boundaries()], and
#' `sentence_id` is renumbered so it restarts at 1 within each document.
#' @param tagger A tagger function created by [create_tagger()].
#' @returns A data frame (a tibble when `tagger` returns one) whose first
#' column `doc_id` is a factor with levels in order of first appearance.
#' @noRd
tagger_impl <- function(sentences, docnames, split, tagger) {
  if (isTRUE(split)) {
    chunks <- stringi::stri_split_boundaries(sentences, type = "sentence")
    # Number of sub-sentences produced from each input element.
    sizes <- lengths(chunks)
    # `rep_len()` only accepts a single length; `rep(times = sizes)` repeats
    # each document name once per sub-sentence it produced.
    ids <- data.frame(
      doc_id = rep(docnames, times = sizes),
      sentence_id = seq_len(sum(sizes))
    )
    res <- dplyr::left_join(
      tagger(unlist(chunks, use.names = FALSE)),
      ids,
      by = "sentence_id"
    )
    # `sentence_id` currently counts sub-sentences across all documents;
    # renumber it so it restarts at 1 within each document.
    res <- dplyr::mutate(
      res,
      sentence_id = dplyr::consecutive_id(.data$sentence_id),
      .by = "doc_id"
    )
  } else {
    res <- dplyr::left_join(
      tagger(sentences),
      data.frame(
        sentence_id = seq_along(sentences),
        doc_id = docnames
      ),
      by = "sentence_id"
    )
  }
  res <- dplyr::mutate(
    res,
    doc_id = factor(.data$doc_id, unique(.data$doc_id))
  )
  dplyr::relocate(res, "doc_id", dplyr::everything())
}

#' Create a tagger function
#'
#' Validates the dictionary paths, then pre-fills them (tilde-expanded) and
#' the other settings into the internal tokenizer, so the returned function
#' only needs the character vector to tokenize.
#'
#' @param sys_dic Character scalar; path to the system dictionary for 'vibrato'.
#' An error is raised if the file does not exist.
#' @param user_dic Character scalar; path to the user dictionary for 'vibrato'.
#' A non-empty path must point to an existing file; the default empty string
#' is passed through without a check.
#' @param max_grouping_len Integer scalar;
#' The maximum grouping length for unknown words.
#' The default value is `0L`, indicating the infinity length.
#' @param verbose Logical.
#' If `TRUE`, returns additional information for debugging.
#' @returns A function inheriting class `purrr_function_partial`.
#' @export
create_tagger <- function(sys_dic,
                          user_dic = "",
                          max_grouping_len = 0L,
                          verbose = FALSE) {
  if (!file.exists(sys_dic)) {
    rlang::abort(c(
      "`sys_dic` is not found.",
      "i" = "Download the dictionary file at first."
    ))
  }
  # Check a user-supplied dictionary path up front as well, so a typo fails
  # here instead of being handed to the tokenizer unchecked.
  if (nzchar(user_dic) && !file.exists(user_dic)) {
    rlang::abort("`user_dic` is not found.")
  }
  purrr::partial(
    tagger_inner,
    sys_dic = path.expand(sys_dic),
    user_dic = path.expand(user_dic),
    max_grouping_len = as.integer(max_grouping_len),
    verbose = verbose
  )
}

#' Tokenize sentences using a tagger
#'
#' @param x A data.frame like object or a character vector to be tokenized.
#' @param text_field <[`data-masked`][rlang::args_data_masking]>
#' String or symbol; column containing texts to be tokenized.
#' @param docid_field <[`data-masked`][rlang::args_data_masking]>
#' String or symbol; column containing document IDs.
#' @param split Logical. When passed as `TRUE`, the function
#' internally splits the sentences into sub-sentences.
#' @param mode Character scalar to switch output format;
#' either `"parse"` or `"wakati"`.
#' @param tagger A tagger function created by [create_tagger()].
#' @returns A tibble when `mode = "parse"`, or a named list of tokens when
#' `mode = "wakati"`.
#' @export
tokenize <- function(x,
                     text_field = "text",
                     docid_field = "doc_id",
                     split = FALSE,
                     mode = c("parse", "wakati"),
                     tagger) {
  # S3 dispatch on the class of the first argument, `x`.
  UseMethod("tokenize")
}

#' @export
tokenize.default <- function(x,
                             text_field = "text",
                             docid_field = "doc_id",
                             split = FALSE,
                             mode = c("parse", "wakati"),
                             tagger) {
  mode <- rlang::arg_match(mode, c("parse", "wakati"))

  # Capture both column selections once so they can be injected below.
  text_quo <- enquo(text_field)
  docid_quo <- enquo(docid_field)

  x <- dplyr::as_tibble(x)

  tokens <- tagger_impl(
    dplyr::pull(x, !!text_quo),
    dplyr::pull(x, !!docid_quo),
    split,
    tagger
  )

  # Level order for the document IDs: reuse an existing factor's levels,
  # otherwise use order of first appearance.
  docid_col <- x[[rlang::as_name(docid_quo)]]
  id_levels <- if (is.factor(docid_col)) {
    levels(docid_col)
  } else {
    unique(docid_col)
  }

  # Every non-text column of `x`, with the ID column converted to a factor
  # and renamed to `doc_id` so it matches the key in `tokens`.
  meta <- dplyr::select(x, -!!text_quo)
  meta <- dplyr::mutate(
    meta,
    dplyr::across(!!docid_quo, function(col) factor(col, id_levels))
  )
  meta <- dplyr::rename(meta, doc_id = !!docid_quo)
  tbl <- dplyr::left_join(meta, tokens, by = "doc_id")

  if (identical(mode, "wakati")) {
    as_tokens(tbl, pos_field = NULL)
  } else {
    tbl
  }
}

#' @export
tokenize.character <- function(x,
                               text_field = "text",
                               docid_field = "doc_id",
                               split = FALSE,
                               mode = c("parse", "wakati"),
                               tagger) {
  mode <- rlang::arg_match(mode, c("parse", "wakati"))

  # Document IDs come from the names of `x`; an unnamed vector is
  # identified by element position instead.
  nm <- names(x)
  if (is.null(nm)) {
    nm <- seq_along(x)
  } else {
    # In a partially named vector the unnamed elements get "" (or NA)
    # names, which would leave them sharing a blank or missing doc_id.
    # Fall back to the element position for those, as for unnamed input.
    blank <- is.na(nm) | !nzchar(nm)
    nm[blank] <- as.character(which(blank))
  }
  tbl <- tagger_impl(x, nm, split, tagger)

  if (!identical(mode, "wakati")) {
    return(tbl)
  }
  as_tokens(tbl, pos_field = NULL)
}
paithiov909/kagomer documentation built on June 12, 2025, 7:44 a.m.