R/entity.R


# Constants --------------------------------------------------------------------
missing_entity_tag <- ""

# Functions --------------------------------------------------------------------

#' Retrieve path to entity extraction model
#'
#' Retrieves the path to the entity extraction model. This prevents a
#' hard path being defined, which would cause an error when verifying
#' staged installation.
#'
#' @noRd

get_path_entity_model <- function() {
  system.file("extdata", "models","entity_extraction",
              package = 'HypothesisReader')
}


#' Load entity extraction model
#'
#' Loads the entity extraction model. Wrapped in memoise to avoid
#' repeated loading of the same model.
#'
#' @noRd

load_entity_model <- function() {
  # Load the Keras entity extraction model bundled with the package
  model_entity <- tf$keras$models$load_model(get_path_entity_model())
  model_entity
}

mem_load_entity_model <- memoise::memoise(load_entity_model)


#' Generate entity class prediction
#'
#' Converts extracted hypothesis statements into predicted entity classes, on
#' a token-by-token basis, using the entity extraction model.
#'
#' @param hypothesis hypothesis statements
#'
#' @noRd

gen_entity_class <- function(hypothesis) {
  # For R CMD Checks (non-standard evaluation column names)
  X1 <- X2 <- X3 <- NULL

  # Convert to Numpy Array
  hypothesis.np <- np$array(hypothesis)
  hypothesis.tf <- tf$convert_to_tensor(hypothesis.np)

  # Load entity extraction model
  model_entity <- mem_load_entity_model()

  # Generate predictions
  pred_classes.np <- model_entity$predict(hypothesis.tf)

  # Convert predictions to dataframe
  pred_classes.df <- data.frame(
    t(
      matrix(
        data = pred_classes.np,
        nrow = dim(pred_classes.np)[3],
        byrow=TRUE
      )
    )
  )

  # Determine class prediction for each token
  pred_classes <- pred_classes.df %>%
    dplyr::mutate(
      class = (purrr::pmap_int(
        .l = list(X1, X2, X3),
        .f = ~which.max(c(...))
      )
      ) - 1
    ) %>%
    dplyr::pull(class)

  pred_classes
}
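
# Illustrative sketch of the class-assignment step above. The probability
# values are hypothetical (no model required); it only demonstrates how the
# per-token score columns X1/X2/X3 map to classes 0/1/2 via which.max() - 1.
if (FALSE) {
  pred_classes.df <- data.frame(
    X1 = c(0.7, 0.2, 0.1),   # class 0 (non-entity) score per token
    X2 = c(0.2, 0.7, 0.1),   # class 1 (entity 1) score per token
    X3 = c(0.1, 0.1, 0.8)    # class 2 (entity 2) score per token
  )
  purrr::pmap_int(
    .l = list(pred_classes.df$X1, pred_classes.df$X2, pred_classes.df$X3),
    .f = ~which.max(c(...))
  ) - 1
  #> [1] 0 1 2
}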


#' Generate entity class indexes
#'
#' Returns the vector index location of each entity type as generated by
#' [gen_entity_class()]
#'
#' @param pred_classes predicted class vector
#'
#' @noRd

gen_entity_class_index <- function(pred_classes){
  # Define each entity class vector index
  entity_idx_1 <- which(pred_classes %in% 1)
  entity_idx_2 <- which(pred_classes %in% 2)

  # Store indexes in list for output
  index_entities <- vector(
    mode   = "list",
    length = 2
  )

  index_entities[[1]] = entity_idx_1
  index_entities[[2]] = entity_idx_2

  index_entities
}
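
# Illustrative sketch (hypothetical predictions, not package code): class 1
# tokens become the entity 1 index vector and class 2 tokens the entity 2
# index vector; class 0 tokens are ignored.
if (FALSE) {
  pred_classes <- c(0, 1, 1, 1, 0, 0, 2, 2, 2, 0)
  gen_entity_class_index(pred_classes)
  #> [[1]]
  #> [1] 2 3 4
  #>
  #> [[2]]
  #> [1] 7 8 9
}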


#' Trim overlapping entity indexes.
#'
#' Trims the indexes of overlapping entity class predictions.
#'
#' @param entity_index_input entity indexes by class, output of
#'  [gen_entity_class_index()]
#'
#' @noRd

trim_overlapping_entities <- function(entity_index_input){
  # Extract entity indexes
  entity_idx_1 <- entity_index_input[[1]]
  entity_idx_2 <- entity_index_input[[2]]

  # Guard: if either index vector is empty (possible after repeated trimming),
  # there is nothing left to compare, so return the input unchanged
  if (purrr::is_empty(entity_idx_1) || purrr::is_empty(entity_idx_2)) {
    return(entity_index_input)
  }

  # Define entity index range
  entity_range_1 <- min(entity_idx_1):max(entity_idx_1)
  entity_range_2 <- min(entity_idx_2):max(entity_idx_2)

  # Determine if ranges overlap
  range_overlap <- intersect(entity_range_1, entity_range_2)

  if (length(range_overlap) == 0) {

    return(entity_index_input)

    # If overlap exists, trim index of entity with longest range
  } else {
    # Define entity class range length
    span_entity_1 <- length(entity_range_1)
    span_entity_2 <- length(entity_range_2)

    # Determine median of entity index to determine which direction to trim
    med_entity_1_idx <- stats::median(entity_idx_1)
    med_entity_2_idx <- stats::median(entity_idx_2)

    # Entity 1 appears first in the hypothesis statement
    if (med_entity_1_idx < med_entity_2_idx) {
      if (span_entity_2 >= span_entity_1){
        entity_idx_2 <- entity_idx_2[-1]

      } else {
        entity_1_dim <- length(entity_idx_1)
        entity_idx_1 <- entity_idx_1[-entity_1_dim]
      }

      # Entity 2 appears first in the hypothesis statement
    } else {
      if (span_entity_2 >= span_entity_1){
        entity_2_dim <- length(entity_idx_2)
        entity_idx_2 <- entity_idx_2[-entity_2_dim]

      } else {
        entity_idx_1 <- entity_idx_1[-1]
      }
    }

    # Store trimmed indexes in list for output
    entity_index_trim <- vector(
      mode = "list",
      length = 2
    )

    entity_index_trim[[1]] = entity_idx_1
    entity_index_trim[[2]] = entity_idx_2

    # Recursively execute function
    trim_overlapping_entities(entity_index_trim)
  }
}
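
# Illustrative sketch (hypothetical index vectors): entity 1 spans tokens 2-5
# and entity 2 spans tokens 4-6, so the ranges overlap. At each step the
# entity with the wider (or equal) range loses one index on the side facing
# the other entity, recursing until no overlap remains.
if (FALSE) {
  trim_overlapping_entities(list(c(2, 3, 4, 5), c(4, 5, 6)))
  #> [[1]]
  #> [1] 2 3 4
  #>
  #> [[2]]
  #> [1] 5 6
}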


#' Trim outlier indexes
#'
#' Trims entity 1 or entity 2 indexes that are outliers,
#' based on the 1.5 * IQR criterion.
#'
#' @param entity_index_input entity indexes by class, output of
#'  [trim_overlapping_entities()]
#'
#' @noRd

trim_outlier_indexes <- function(entity_index_input) {
  # Initialize
  entity_index_output <- vector(
    mode   = "list",
    length = 2
  )

  for (i in seq_along(entity_index_input)){
    # Define index vector
    index = entity_index_input[[i]]

    # Skip process if index vector is empty
    if (purrr::is_empty(index)) {
      entity_index_output[[i]] <- index
      next
    }

    # Calculate summary statistics (Min, Q1, Median, Mean, Q3, Max)
    summary <- as.vector(summary(index))

    # Define outlier fences using 1.5 * IQR
    iqr.range <- summary[5] - summary[2]   # IQR = Q3 - Q1
    upper <- summary[5] + iqr.range * 1.5
    lower <- summary[2] - iqr.range * 1.5

    # Drop if index is outlier
    index <- index[index >= lower]
    index <- index[index <= upper]

    entity_index_output[[i]] <- index
  }

  entity_index_output
}
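
# Illustrative sketch (hypothetical index vectors): token position 30 lies
# above Q3 + 1.5 * IQR for the first entity and is dropped; empty index
# vectors pass through untouched.
if (FALSE) {
  trim_outlier_indexes(list(c(3, 4, 5, 6, 30), integer(0)))
  #> [[1]]
  #> [1] 3 4 5 6
  #>
  #> [[2]]
  #> integer(0)
}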


#' Convert entity indexes to text strings.
#'
#' Converts the entity index vector to the entity text.
#'
#' @param hypothesis hypothesis statement text, string
#' @param entity_index_input entity indexes by class, output of
#'  [trim_outlier_indexes()]
#'
#' @noRd

index_to_entity <- function(hypothesis, entity_index_input) {
  # Initialize
  entity_text_output <- vector(
    mode = "character",
    length = 2
  )

  # Convert hypothesis to tokens
  tokens_all <- stringr::str_split(
    string  = hypothesis,
    pattern = " ")[[1]]

  for (i in seq_along(entity_index_input)){
    # Define index vector
    index = entity_index_input[[i]]

    # Skip process if index vector is empty
    if (purrr::is_empty(index)){
      entity_text_output[[i]] <- missing_entity_tag
      next
    }

    # Extract entity tokens
    tokens_entity <- tokens_all[
      min(entity_index_input[[i]]):
        max(entity_index_input[[i]])
    ]

    # Concatenate tokens to text strings
    entity_text_output[i] <- stringr::str_c(
      tokens_entity,
      collapse = " "
    )
  }

  entity_text_output
}
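
# Illustrative sketch (hypothetical hypothesis text and index vectors): each
# index vector selects a token span from the whitespace-tokenised statement,
# which is then collapsed back into an entity string.
if (FALSE) {
  hypothesis <- "increased job autonomy improves employee retention"
  index_to_entity(hypothesis, list(c(1, 2, 3), c(5, 6)))
  #> [1] "increased job autonomy" "employee retention"
}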


#' Drop trailing stopwords
#'
#' Removes the last token(s) of each entity string if they are stopwords.
#'
#' @param entity_text_input entity text, output of
#'  [index_to_entity()]
#'
#' @noRd

remove_trailing_stopwords <- function(entity_text_input) {
  # Minimal sketch: this stopword list is an assumption, not a definitive
  # list; a dedicated stopword package could be substituted
  stopword_list <- c("a", "an", "and", "the", "of", "in", "on", "to", "for", "with")

  purrr::map_chr(entity_text_input, function(entity_text) {
    tokens <- stringr::str_split(entity_text, pattern = " ")[[1]]

    # Drop trailing token(s) while they are stopwords
    while (length(tokens) > 0 &&
           tolower(tokens[length(tokens)]) %in% stopword_list) {
      tokens <- tokens[-length(tokens)]
    }

    stringr::str_c(tokens, collapse = " ")
  })
}
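
# Illustrative usage of the sketch above (hypothetical entity strings):
# trailing stopword tokens are stripped, non-trailing stopwords are kept.
if (FALSE) {
  remove_trailing_stopwords(c("quality of work life in the", "job autonomy"))
  #> [1] "quality of work life" "job autonomy"
}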


#' Extract entity clauses (single case)
#'
#' Wrapper function. Executes all steps in the entity extraction process for
#' a single hypothesis statement.
#'
#' @param hypothesis hypothesis statement text, string
#'
#' @noRd

entity_extraction_indv <- function(hypothesis) {
  # Generate entity class predictions
  pred_classes <- gen_entity_class(hypothesis)
  index_entities <- gen_entity_class_index(pred_classes)

  # Trim overlapping entities
  ## Verify both entities detected
  both_entity_present <-
    !purrr::is_empty(index_entities[[1]]) &&
    !purrr::is_empty(index_entities[[2]])

  ## Trim overlap
  if (both_entity_present) {
    index_entities <- trim_overlapping_entities(index_entities)
  }

  # Remove Outliers
  index_entities <- trim_outlier_indexes(index_entities)

  ## Convert Indexes to Text
  entity_text_output <- index_to_entity(hypothesis, index_entities)

  entity_text_output
}


#' Extract entity clauses (multiple cases)
#'
#' Wrapper function. Executes all steps in the entity extraction process for
#' multiple hypothesis statements.
#'
#' @param hypothesis.df hypothesis statement output of [hypothesis_extraction()]
#'
#' @noRd

entity_extraction <- function(hypothesis.df){
  # For R CMD Checks
  cause <- effect <- hypothesis <- V1 <- V2 <- NULL

  # Extract entity extraction hypothesis input
  hypothesis.v <- hypothesis.df %>%
    dplyr::pull(hypothesis)

  # Replace & with and
  hypothesis.v <- hypothesis.v %>%
    stringr::str_replace_all(
      pattern = "&",
      replacement = "and"
      )

  # Initialize output list
  num_hypothesis <- length(hypothesis.v)

  lst_entity_text_output <- vector(
    mode = "list",
    length = num_hypothesis
  )

  for (i in seq_along(hypothesis.v)){
    # Extract hypothesis
    hypothesis <- hypothesis.v[i]

    # Extract entities
    entity_text_output <- entity_extraction_indv(hypothesis)

    # Store in output list
    lst_entity_text_output[[i]] <- entity_text_output
  }

  # Convert list of lists to Dataframe
  entity_text_output.df <- as.data.frame(
    do.call(
      rbind,
      lapply(lst_entity_text_output, as.vector))
  ) %>%
    dplyr::rename(
      cause  = V1,
      effect = V2
    )

  # Clean entity text (missing-entity replacement currently disabled)
  entity_text_output.df <- entity_text_output.df %>%
    # dplyr::mutate(                                    # NA is missing
    #   cause  = dplyr::na_if(
    #     cause,
    #     missing_entity_tag
    #     ),
    #   effect = dplyr::na_if(
    #     effect,
    #     missing_entity_tag
    #     )
    # ) %>%
    dplyr::mutate(                                   # remove periods
      effect = stringr::str_remove_all(
        string  = effect,
        pattern = "\\."
      )
    ) %>%
    dplyr::mutate(
      cause = as.character(cause)
    )
  entity_text_output.df

}
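
# End-to-end usage sketch. Running this requires the packaged TensorFlow
# entity extraction model, so the hypothesis text and the printed output
# below are hypothetical and shown only to illustrate the expected shape
# (one row per hypothesis with `cause` and `effect` columns).
if (FALSE) {
  hypothesis.df <- data.frame(
    hypothesis = "Hypothesis 1: Job autonomy increases employee satisfaction."
  )
  entity_extraction(hypothesis.df)
  #>          cause                effect
  #> 1 job autonomy employee satisfaction
}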