# tests/testthat/testthat-pretrained-pipeline.R

setup({
  # Create the shared Spark connection and the test text table used by
  # every test in this file.
  sc <- testthat_spark_connection()
  text_tbl <- testthat_tbl("test_text")

  # Assign the fixtures into the global environment so they are visible to
  # the test blocks and, crucially, so teardown() can remove them from
  # .GlobalEnv. The original assigned into parent.frame(), which is not
  # guaranteed to be the environment teardown() cleans up.
  assign("sc", sc, envir = .GlobalEnv)
  assign("text_tbl", text_tbl, envir = .GlobalEnv)
})

teardown({
  # Close the shared Spark connection, then drop the fixtures from the
  # global environment. Guard against the objects being absent (e.g. when
  # setup() failed part-way) so teardown itself cannot raise a secondary
  # error that masks the real test failure.
  if (exists("sc", envir = .GlobalEnv)) {
    spark_disconnect(sc)
  }
  rm(
    list = intersect(c("sc", "text_tbl"), ls(envir = .GlobalEnv)),
    envir = .GlobalEnv
  )
})

test_that("nlp_pretrained_pipeline() tbl_spark", {
  # Applying a pretrained pipeline directly to a Spark table should yield
  # a transformed result containing the pipeline's "entities" column.
  transformed <- nlp_pretrained_pipeline(text_tbl, "recognize_entities_dl")
  expect_true("entities" %in% colnames(transformed))
})

test_that("nlp_pretrained_pipeline() spark_connection", {
  # Requesting a pretrained pipeline from a bare connection should return
  # a wrapper around a Java PretrainedPipeline object.
  pretrained <- nlp_pretrained_pipeline(sc, "recognize_entities_dl")
  jobj_classes <- jobj_class(spark_jobj(pretrained))
  expect_equal(jobj_classes, c("PretrainedPipeline", "Object"))
})

test_that("nlp_pretrained_pipeline annotate", {
  # nlp_annotate() on a pretrained pipeline should produce annotations
  # that include the pipeline's "entities" output column.
  recognizer <- nlp_pretrained_pipeline(sc, "recognize_entities_dl")
  annotated <- nlp_annotate(recognizer, text_tbl, column = "text")
  expect_true("entities" %in% colnames(annotated))
})

test_that("as_pipeline_model().nlp_pretrained_pipeline", {
  # Converting a pretrained pipeline should yield a sparklyr
  # ml_pipeline_model S3 object.
  pretrained <- nlp_pretrained_pipeline(sc, "recognize_entities_dl")
  converted <- as_pipeline_model(pretrained)
  expect_s3_class(converted, "ml_pipeline_model")
})
# r-spark/sparknlp documentation built on Oct. 15, 2022, 10:50 a.m.