tests/testthat/test-serve.R

context("Serve")

library("carrier")

wait_for_server_to_start <- function(server_process, port) {
  status_code <- 500
  # Poll the server's /ping/ endpoint, giving it up to five attempts
  # (with a 5-second pause between attempts) to come up.
  for (i in 1:5) {
    tryCatch(
      {
        status_code <- httr::status_code(httr::GET(sprintf("http://127.0.0.1:%d/ping/", port)))
      },
      error = function(...) {
        # Connection refused: the server is not listening yet. Use `<<-` so the
        # assignment reaches the enclosing function, not the handler's scope.
        status_code <<- 500
      }
    )
    if (status_code != 200) {
      Sys.sleep(5)
    } else {
      break
    }
  }
  if (status_code != 200) {
    write("FAILED to start the server!", stderr())
    error_text <- server_process$read_error()
    stop("Failed to start the server! ", error_text)
  }
}

testthat_model_server <- NULL

teardown({
  if (!is.null(testthat_model_server)) {
    testthat_model_server$kill()
  }
})

test_that("mlflow can serve a model function", {
  mlflow_clear_test_dir("model")

  model <- lm(Sepal.Width ~ Sepal.Length + Petal.Width, iris)
  fn <- crate(~ stats::predict(model, .x), model = model)
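  # crate() from the carrier package builds a self-contained function: `model`
  # is captured explicitly, so the crated function can be serialized with the
  # saved model and evaluated later in a fresh R session.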
  mlflow_save_model(fn, path = "model")
  expect_true(dir.exists("model"))
  port <- httpuv::randomPort()
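  # Serve the saved model from a background R process on a free random port.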
  testthat_model_server <<- processx::process$new(
    "Rscript",
    c(
      "-e",
      sprintf(
        "mlflow::mlflow_rfunc_serve('model', port = %d, browse = FALSE)",
        port
      )
    ),
    supervise = TRUE,
    stdout = "|",
    stderr = "|"
  )
  wait_for_server_to_start(testthat_model_server, port)
  newdata <- iris[1:2, c("Sepal.Length", "Petal.Width")]

  http_prediction <- httr::content(
    httr::POST(
      sprintf("http://127.0.0.1:%d/predict/", port),
      body = jsonlite::toJSON(as.list(newdata))
    )
  )
  if (is.character(http_prediction)) {
    stop(http_prediction)
  }

  expect_equal(
    unlist(http_prediction),
    as.vector(predict(model, newdata)),
    tolerance = 1e-5
  )
})

test_that("mlflow models server api works with R model function", {
  mlflow_clear_test_dir("model")

  model <- lm(Sepal.Width ~ Sepal.Length + Petal.Width, iris)
  fn <- crate(~ stats::predict(model, .x), model = model)
  mlflow_save_model(fn, path = "model")
  expect_true(dir.exists("model"))
  port <- httpuv::randomPort()
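  # Start the pyfunc scoring server in the background via the
  # `mlflow models serve` CLI.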
  testthat_model_server <<- mlflow:::mlflow_cli(
    "models", "serve", "-m", "model", "-p", as.character(port),
    background = TRUE
  )
  wait_for_server_to_start(testthat_model_server, port)
  newdata <- iris[1:2, c("Sepal.Length", "Petal.Width")]
  check_prediction <- function(http_prediction) {
    if (is.character(http_prediction)) {
      stop(http_prediction)
    }
    expect_equal(
      unlist(http_prediction),
      as.vector(predict(model, newdata)),
      tolerance = 1e-5
    )
  }
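  # The scoring server accepts several input formats; exercise the three used
  # below -- JSON (pandas "records" orient), JSON (pandas "split" orient), and
  # CSV -- all against the /invocations endpoint.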
  # json records
  check_prediction(
    httr::content(
      httr::POST(
        sprintf("http://127.0.0.1:%d/invocations", port),
        httr::content_type("application/json"),
        body = jsonlite::toJSON(list(dataframe_records = as.list(newdata)))
      )
    )
  )
  newdata_split <- list(
    columns = names(newdata), index = row.names(newdata),
    data = as.matrix(newdata)
  )
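  # Serialized with toJSON() below, this mirrors pandas' "split" orientation:
  # {"columns": ["Sepal.Length", "Petal.Width"], "index": [...],
  #  "data": [[5.1, 0.2], [4.9, 0.2]]}, wrapped under the "dataframe_split" key.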
  # json split
  check_prediction(
    httr::content(
      httr::POST(
        sprintf("http://127.0.0.1:%d/invocations", port),
        httr::content_type("application/json"),
        body = jsonlite::toJSON(list(dataframe_split = newdata_split))
      )
    )
  )

  # csv
  csv_header <- paste(names(newdata), collapse = ", ")
  csv_row_1 <- paste(newdata[1, ], collapse = ", ")
  csv_row_2 <- paste(newdata[2, ], collapse = ", ")
  newdata_csv <- paste(csv_header, csv_row_1, csv_row_2, "", sep = "\n")
  check_prediction(
    httr::content(
      httr::POST(
        sprintf("http://127.0.0.1:%d/invocations", port),
        httr::content_type("text/csv"),
        body = newdata_csv
      )
    )
  )
})
