# Constants --------------------------------------------------------------------
# Text used for an entity whose token indexes are empty (i.e. the entity was
# not detected in the hypothesis statement). Currently the empty string.
missing_entity_tag <- ""
# Functions --------------------------------------------------------------------
#' Locate the bundled entity extraction model
#'
#' Looks up the `extdata/models/entity_extraction` directory inside the
#' installed HypothesisReader package instead of hard-coding an absolute
#' path, which would fail the staged-installation check.
#'
#' @return Path to the model directory, or `""` when it cannot be found
#'   (the [system.file()] convention).
#'
#' @noRd
get_path_entity_model <- function() {
  model_subdir <- file.path("extdata", "models", "entity_extraction")
  system.file(model_subdir, package = "HypothesisReader")
}
#' Load entity extraction model
#'
#' Loads the Keras entity extraction model from the directory returned by
#' [get_path_entity_model()]. Use the memoised wrapper
#' `mem_load_entity_model()` (defined below) so repeated calls reuse the
#' cached model instead of reloading it.
#'
#' @return The model object returned by `tf$keras$models$load_model()`.
#'
#' @noRd
load_entity_model <- function() {
  model_path <- get_path_entity_model()
  # system.file() signals a missing file with "", which load_model() would
  # otherwise report with an unhelpful error.
  if (!nzchar(model_path)) {
    stop(
      "Entity extraction model not found in the HypothesisReader installation.",
      call. = FALSE
    )
  }
  tf$keras$models$load_model(model_path)
}
# Memoised wrapper so the model is loaded only once and then served from the
# memoise cache.
mem_load_entity_model <- memoise::memoise(load_entity_model)
#' Generate entity class prediction
#'
#' Converts an extracted hypothesis statement into predicted entity classes,
#' on a token-by-token basis, using the entity extraction model.
#'
#' @param hypothesis hypothesis statements
#'
#' @return Numeric vector with one zero-based class label per token position:
#'   the column of the highest-scoring class minus one (first column wins on
#'   ties, as with [which.max()]).
#'
#' @noRd
gen_entity_class <- function(hypothesis) {
  # Convert to Numpy array, then to a tensor for the Keras model
  hypothesis.np <- np$array(hypothesis)
  hypothesis.tf <- tf$convert_to_tensor(hypothesis.np)
  # Load (memoised) entity extraction model and generate predictions
  model_entity <- mem_load_entity_model()
  pred_classes.np <- model_entity$predict(hypothesis.tf)
  # Reshape to a (token x class) score matrix. NOTE(review): assumes a single
  # hypothesis, i.e. dim (1, tokens, classes); column-major flattening then
  # lists every token's score for class 1, then class 2, ..., so filling
  # `n_classes` rows by row and transposing yields one row per token.
  n_classes <- dim(pred_classes.np)[3]
  class_scores <- t(
    matrix(
      data = pred_classes.np,
      nrow = n_classes,
      byrow = TRUE
    )
  )
  # Argmax over every class column (not only the first three); "first"
  # reproduces which.max()'s tie-breaking.
  max.col(class_scores, ties.method = "first") - 1
}
#' Generate entity class indexes
#'
#' Groups token positions by predicted entity class, as produced by
#' [gen_entity_class()]. Class 0 positions are ignored.
#'
#' @param pred_classes predicted class vector
#'
#' @return Unnamed list of length two: the positions tagged as entity 1 and
#'   the positions tagged as entity 2. An undetected entity gives
#'   `integer(0)`.
#'
#' @noRd
gen_entity_class_index <- function(pred_classes) {
  lapply(
    c(1, 2),
    function(entity_class) which(pred_classes %in% entity_class)
  )
}
#' Trim overlapping entity indexes.
#'
#' Repeatedly trims one token from the entity spanning the longer range
#' until the spans (min to max index) of entity 1 and entity 2 no longer
#' overlap. The entity whose median index is lower is treated as coming
#' first. The later entity loses its leading index, or the earlier entity
#' loses its trailing index. Ties in span length trim entity 2.
#'
#' @param entity_index_input entity indexes by class, output of
#'   [gen_entity_class_index()]
#'
#' @return `entity_index_input` unchanged when the spans do not overlap or
#'   either entity is empty; otherwise an unnamed list of the two trimmed
#'   index vectors.
#'
#' @noRd
trim_overlapping_entities <- function(entity_index_input) {
  entity_idx_1 <- entity_index_input[[1]]
  entity_idx_2 <- entity_index_input[[2]]
  # Integer spans [min, max] overlap when the later start is not after the
  # earlier end. Empty vectors never overlap; this also avoids min()/max()
  # on empty input, which would yield Inf and make `:` fail.
  spans_overlap <- function(idx_a, idx_b) {
    length(idx_a) > 0 && length(idx_b) > 0 &&
      max(min(idx_a), min(idx_b)) <= min(max(idx_a), max(idx_b))
  }
  if (!spans_overlap(entity_idx_1, entity_idx_2)) {
    return(entity_index_input)
  }
  while (spans_overlap(entity_idx_1, entity_idx_2)) {
    # Define entity class range length
    span_entity_1 <- max(entity_idx_1) - min(entity_idx_1) + 1
    span_entity_2 <- max(entity_idx_2) - min(entity_idx_2) + 1
    # Median position decides which entity comes first in the statement
    med_entity_1_idx <- stats::median(entity_idx_1)
    med_entity_2_idx <- stats::median(entity_idx_2)
    if (med_entity_1_idx < med_entity_2_idx) {
      # Entity 1 first: trim the start of entity 2 or the end of entity 1
      if (span_entity_2 >= span_entity_1) {
        entity_idx_2 <- entity_idx_2[-1]
      } else {
        entity_idx_1 <- entity_idx_1[-length(entity_idx_1)]
      }
    } else {
      # Entity 2 first: trim the end of entity 2 or the start of entity 1
      if (span_entity_2 >= span_entity_1) {
        entity_idx_2 <- entity_idx_2[-length(entity_idx_2)]
      } else {
        entity_idx_1 <- entity_idx_1[-1]
      }
    }
  }
  list(entity_idx_1, entity_idx_2)
}
#' Trim outlier indexes
#'
#' Drops entity indexes outside the Tukey fences
#' `[Q1 - 1.5 * IQR, Q3 + 1.5 * IQR]`, computed separately for entity 1 and
#' entity 2. Empty index vectors pass through untouched.
#'
#' @param entity_index_input entity indexes by class, output of
#'   [trim_overlapping_entities()]
#'
#' @return Unnamed list (length two) of the filtered index vectors.
#'
#' @noRd
trim_outlier_indexes <- function(entity_index_input) {
  # Quartiles via quantile type 7, the same values summary() reports.
  drop_outliers <- function(idx) {
    if (length(idx) == 0) {
      return(idx)
    }
    quartiles <- stats::quantile(
      idx,
      probs = c(0.25, 0.75),
      names = FALSE,
      na.rm = TRUE
    )
    fence <- 1.5 * (quartiles[2] - quartiles[1])
    idx <- idx[idx >= quartiles[1] - fence]
    idx[idx <= quartiles[2] + fence]
  }
  trimmed <- vector(mode = "list", length = 2)
  for (k in seq_along(entity_index_input)) {
    trimmed[[k]] <- drop_outliers(entity_index_input[[k]])
  }
  trimmed
}
#' Convert entity indexes to text strings.
#'
#' Converts each entity index vector to the text spanning its first to last
#' index, using the space-separated tokens of the hypothesis.
#'
#' @param hypothesis hypothesis statement text, string (only the first
#'   element is used)
#' @param entity_index_input entity indexes by class, output of
#'   [trim_outlier_indexes()]
#'
#' @return Character vector of length two (entity 1, entity 2). An entity
#'   with no usable index gets `missing_entity_tag`.
#'
#' @noRd
index_to_entity <- function(hypothesis, entity_index_input) {
  # Convert hypothesis to tokens
  tokens_all <- stringr::str_split(
    string = hypothesis,
    pattern = " "
  )[[1]]
  n_tokens <- length(tokens_all)
  entity_text_output <- vector(mode = "character", length = 2)
  for (i in seq_along(entity_index_input)) {
    index <- entity_index_input[[i]]
    # Keep only positions that exist among the whitespace tokens. The model
    # may emit more positions than there are tokens (NOTE(review): e.g.
    # padding); out-of-range positions would index NA tokens, and str_c()
    # would turn the whole entity into NA.
    index <- index[index >= 1 & index <= n_tokens]
    if (length(index) == 0) {
      entity_text_output[i] <- missing_entity_tag
      next
    }
    # Concatenate the contiguous token span into a text string
    entity_text_output[i] <- stringr::str_c(
      tokens_all[min(index):max(index)],
      collapse = " "
    )
  }
  entity_text_output
}
#' Drop trailing stopwords
#'
#' Removes the last token(s) of each entity text while they are stopwords
#' (case-insensitive) or empty. Tokens are separated by single spaces.
#'
#' @param entity_text_input entity text, output of [index_to_entity()]
#' @param stopwords character vector of lower-case stopwords to strip from
#'   the end of each entity; defaults to a small set of English function
#'   words
#'
#' @return Character vector the same length as `entity_text_input`. `NA`
#'   elements are returned unchanged; an entity made only of stopwords
#'   becomes `""`.
#'
#' @noRd
remove_trailing_stopwords <- function(entity_text_input,
                                      stopwords = c(
                                        "a", "an", "and", "as", "at", "by",
                                        "for", "from", "in", "into", "of",
                                        "on", "or", "the", "to", "with"
                                      )) {
  vapply(
    entity_text_input,
    function(text) {
      if (is.na(text)) {
        return(text)
      }
      tokens <- strsplit(text, " ", fixed = TRUE)[[1]]
      n_keep <- length(tokens)
      while (n_keep > 0 &&
             (!nzchar(tokens[n_keep]) ||
              tolower(tokens[n_keep]) %in% stopwords)) {
        n_keep <- n_keep - 1
      }
      paste(tokens[seq_len(n_keep)], collapse = " ")
    },
    FUN.VALUE = character(1),
    USE.NAMES = FALSE
  )
}
#' Extract entity clauses (single case)
#'
#' Wrapper function. Executes all steps in the entity extraction process for
#' a single hypothesis statement: class prediction, grouping by entity,
#' overlap trimming (only when both entities are present), outlier removal
#' and conversion back to text.
#'
#' @param hypothesis hypothesis statement text, string
#'
#' @return The output of [index_to_entity()]: entity 1 text and entity 2
#'   text.
#'
#' @noRd
entity_extraction_indv <- function(hypothesis) {
  # Generate entity class predictions and group positions by entity
  pred_classes <- gen_entity_class(hypothesis)
  index_entities <- gen_entity_class_index(pred_classes)
  # Overlap trimming compares both spans, so it needs both entities
  both_entity_present <- length(index_entities[[1]]) > 0 &&
    length(index_entities[[2]]) > 0
  if (both_entity_present) {
    index_entities <- trim_overlapping_entities(index_entities)
  }
  # Remove outliers, then convert indexes to text
  index_entities <- trim_outlier_indexes(index_entities)
  index_to_entity(hypothesis, index_entities)
}
#' Extract entity clauses (multiple cases)
#'
#' Wrapper function. Executes all steps in the entity extraction process for
#' multiple hypothesis statements.
#'
#' @param hypothesis.df hypothesis statement output of
#'   [hypothesis_extraction()]; must contain a `hypothesis` column.
#'
#' @return A data frame with one row per hypothesis and character columns
#'   `cause` (entity 1) and `effect` (entity 2, with periods removed). An
#'   empty input yields a zero-row data frame with the same columns.
#'
#' @noRd
entity_extraction <- function(hypothesis.df) {
  # Extract entity extraction hypothesis input
  hypothesis.v <- dplyr::pull(hypothesis.df, "hypothesis")
  # Replace & with and
  hypothesis.v <- stringr::str_replace_all(
    hypothesis.v,
    pattern = "&",
    replacement = "and"
  )
  # One c(entity 1, entity 2) character vector per hypothesis
  lst_entity_text_output <- lapply(hypothesis.v, entity_extraction_indv)
  # vapply() keeps the result character even for zero hypotheses, where
  # do.call(rbind, list()) would return NULL.
  cause <- vapply(lst_entity_text_output, `[[`, character(1), 1L)
  effect <- vapply(lst_entity_text_output, `[[`, character(1), 2L)
  data.frame(
    cause = cause,
    # Remove periods (e.g. sentence-final) from the effect clause only
    effect = stringr::str_remove_all(string = effect, pattern = "\\."),
    stringsAsFactors = FALSE
  )
}
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.