# R/download_checkpoint.R
#
# Defines functions .maybe_download_checkpoint `%||%` set_BERT_dir .infer_checkpoint_archive_path .infer_ckpt_dir .infer_archive_type .get_model_archive_path .get_model_subdir .get_model_archive_type .get_model_url .process_BERT_checkpoint .download_BERT_checkpoint .has_checkpoint .choose_BERT_dir download_BERT_checkpoint
#
# Documented in .choose_BERT_dir .download_BERT_checkpoint download_BERT_checkpoint .get_model_archive_path .get_model_archive_type .get_model_subdir .get_model_url .has_checkpoint .infer_archive_type .infer_checkpoint_archive_path .infer_ckpt_dir .maybe_download_checkpoint .process_BERT_checkpoint set_BERT_dir

# Copyright 2019 Bedford Freeman & Worth Pub Grp LLC DBA Macmillan Learning.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# download_BERT_checkpoint ------------------------------------------------

#' Download a BERT checkpoint
#'
#' Downloads the specified BERT checkpoint from the Google Research collection
#' or other repositories.
#'
#' @section Checkpoints: \code{download_BERT_checkpoint} knows about several
#'   pre-trained BERT checkpoints. You can specify these checkpoints using the
#'   \code{model} parameter. Alternatively, you can supply a direct \code{url}
#'   to any BERT tensorflow checkpoint.
#'
#'   \tabular{rccccl}{ model \tab layers \tab hidden \tab heads \tab parameters
#'   \tab special\cr bert_base_* \tab 12 \tab 768 \tab 12 \tab 110M\cr
#'   bert_large_* \tab 24 \tab 1024 \tab 16 \tab 340M\cr bert_large_*_wwm \tab
#'   24 \tab 1024 \tab 16 \tab 340M \tab whole word masking\cr
#'   bert_base_multilingual_cased \tab 12 \tab 768 \tab 12 \tab 110M \tab 104
#'   languages\cr bert_base_chinese \tab 12 \tab 768 \tab 12 \tab 110M \tab
#'   Chinese Simplified and Traditional\cr scibert_scivocab_* \tab 12 \tab 768
#'   \tab 12 \tab 110M \tab Trained using the full text of 1.14M scientific
#'   papers (18\% computer science, 82\% biomedical), with a science-specific
#'   vocabulary.\cr scibert_basevocab_uncased \tab 12 \tab 768 \tab 12 \tab 110M
#'   \tab As scibert_scivocab_*, but using the original BERT vocabulary. }
#'
#' @param model Character vector. Which model checkpoint to download.
#' @param dir Character vector. Destination directory for checkpoints. Leave
#'   \code{NULL} to allow RBERT to automatically choose a directory. The path is
#'   determined from the \code{dir} parameter if supplied, followed by the
#'   `BERT.dir` option (set using \link{set_BERT_dir}), followed by an "RBERT"
#'   folder in the user cache directory (determined using
#'   \code{\link[rappdirs]{user_cache_dir}}). If you provide a \code{dir}, the
#'   `BERT.dir` option will be updated to that location. Note that the
#'   checkpoint will be placed in a subdirectory inside this \code{dir}.
#' @param url Character vector. An optional url from which to download a
#'   checkpoint. Overrides \code{model} parameter if not NULL.
#' @param force Logical. Download even if the checkpoint already exists in the
#'   specified directory? Default \code{FALSE}.
#' @param keep_archive Logical. Keep the zip (or other archive) file? Leave as
#'   \code{FALSE} to save space.
#' @param archive_type How is the checkpoint archived? We currently support
#'   "zip" and "tar-gzip". Leave NULL to infer from the \code{url}.
#'
#' @return If successful, returns the path to the downloaded checkpoint.
#' @export
#'
#' @source \url{https://github.com/google-research/bert}
#' @source \url{https://github.com/allenai/scibert}
#'
#' @examples
#' \dontrun{
#' download_BERT_checkpoint("bert_base_uncased")
#' download_BERT_checkpoint("bert_large_uncased")
#' temp_dir <- tempdir()
#' download_BERT_checkpoint("bert_base_uncased", dir = temp_dir)
#' }
download_BERT_checkpoint <- function(model = c(
                                       "bert_base_uncased",
                                       "bert_base_cased",
                                       "bert_large_uncased",
                                       "bert_large_cased",
                                       "bert_large_uncased_wwm",
                                       "bert_large_cased_wwm",
                                       "bert_base_multilingual_cased",
                                       "bert_base_chinese",
                                       "scibert_scivocab_uncased",
                                       "scibert_scivocab_cased",
                                       "scibert_basevocab_uncased",
                                       "scibert_basevocab_cased"
                                     ),
                                     dir = NULL,
                                     url = NULL,
                                     force = FALSE,
                                     keep_archive = FALSE,
                                     archive_type = NULL) {
  # Settle on the parent directory first, then remember it for the rest of the
  # session. set_BERT_dir() also creates the directory when it is missing.
  dir <- .choose_BERT_dir(dir)
  set_BERT_dir(dir)

  # Work out where the archive comes from and where it (and its extracted
  # contents) will live. A user-supplied url takes precedence over `model`.
  if (!is.null(url)) {
    model <- NULL
    archive_type <- archive_type %||% .infer_archive_type(url)
    ckpt_dir <- .infer_ckpt_dir(url, dir)
    checkpoint_archive_path <- .infer_checkpoint_archive_path(url, dir)
  } else {
    model <- match.arg(model)
    url <- .get_model_url(model)
    archive_type <- .get_model_archive_type(model)
    ckpt_dir <- .get_model_subdir(model, dir)
    checkpoint_archive_path <- .get_model_archive_path(model, dir, archive_type)
  }

  already_present <- .has_checkpoint(
    model = model,
    dir = dir,
    ckpt_dir = ckpt_dir
  )

  # Fetch the archive when forced, when the caller wants to keep an archive
  # that is not on disk, or when the checkpoint itself is missing.
  needs_download <- force ||
    (keep_archive && !file.exists(checkpoint_archive_path)) ||
    !already_present
  if (needs_download) {
    .download_BERT_checkpoint(url, checkpoint_archive_path)
  }

  # Extraction is only required when forced or when the checkpoint is missing.
  needs_extraction <- force || !already_present
  if (needs_extraction) {
    .process_BERT_checkpoint(
      dir,
      checkpoint_archive_path,
      ckpt_dir,
      archive_type
    )
  }

  # Tidy up the archive unless the caller asked to keep it.
  if (!keep_archive && file.exists(checkpoint_archive_path)) {
    file.remove(checkpoint_archive_path)
  }

  # Normalizing again shouldn't be necessary, but results were inconsistent on
  # Windows, presumably because the path is reported slightly differently once
  # it exists.
  normalizePath(ckpt_dir)
}

#' Choose a directory for BERT checkpoints
#'
#' If \code{dir} is not NULL, this function simply returns \code{dir}. Otherwise
#' it checks the `BERT.dir` option, and then uses
#' \code{\link[rappdirs]{user_cache_dir}} to choose a directory if necessary.
#'
#' @inheritParams download_BERT_checkpoint
#'
#' @return A character vector indicating a directory in which BERT checkpoints
#'   are stored.
#' @keywords internal
.choose_BERT_dir <- function(dir) {
  # An explicit directory always wins.
  if (!is.null(dir)) {
    return(dir)
  }

  # Next preference: the directory remembered in the session options.
  remembered_dir <- getOption("BERT.dir")
  if (!is.null(remembered_dir)) {
    return(remembered_dir)
  }

  # Last resort: an "RBERT" folder in the user's cache directory.
  rappdirs::user_cache_dir(appname = "RBERT")
}

#' Check whether the user already has a checkpoint
#'
#' Check the specified dir (or the default dir if none is specified) for a given
#' model or url.
#'
#' @inheritParams download_BERT_checkpoint
#' @param ckpt_dir The path to the subdir where this checkpoint should
#'   be saved. If model is given, ckpt_dir is inferred.
#'
#' @return A logical indicating whether the user already has that checkpoint in
#'   that location.
#' @keywords internal
.has_checkpoint <- function(model = NULL,
                            dir = NULL,
                            ckpt_dir = NULL) {
  # The parent directory only matters when the checkpoint folder has to be
  # derived from the model name.
  if (is.null(ckpt_dir)) {
    ckpt_dir <- .get_model_subdir(model, .choose_BERT_dir(dir))
  }
  filenames <- list.files(ckpt_dir)

  # A usable checkpoint has a config file, a vocabulary, and at least one
  # "bert_model.ckpt*" file. The names are matched literally (fixed = TRUE) so
  # that "." is not treated as a regex wildcard.
  required_pieces <- c("bert_config.json", "vocab.txt", "bert_model.ckpt")
  piece_found <- vapply(
    required_pieces,
    function(piece) any(grepl(piece, filenames, fixed = TRUE)),
    logical(1)
  )

  all(piece_found)
}

#' Download a checkpoint zip file
#'
#' @inheritParams download_BERT_checkpoint
#' @param checkpoint_zip_path The path to which the checkpoint zip should be
#'   downloaded.
#'
#' @return \code{TRUE} invisibly.
#' @keywords internal
.download_BERT_checkpoint <- function(url, checkpoint_zip_path) {
  # This function is stubbed for testing, so tests don't see it. Test manually
  # from time to time, but it's really straightforward.
  # nocov start
  # Always request a binary transfer. On Windows, download.file() only picks
  # binary mode by itself when the url ends in a recognised extension; an
  # archive served from any other url would be written in text mode and
  # corrupted.
  status <- utils::download.file(
    url = url,
    destfile = checkpoint_zip_path,
    method = "libcurl",
    mode = "wb"
  )
  # download.file() reports success as a zero status code.
  if (status != 0) {
    stop("Checkpoint download failed.") # nocov
  }
  invisible(TRUE)
  # nocov end
}

#' Unpack and check a BERT checkpoint archive
#'
#' Extracts the archive into \code{ckpt_dir}, moves the contents of any folder
#' created by the extraction up into \code{ckpt_dir}, and warns if the expected
#' checkpoint files are not found afterwards.
#'
#' @inheritParams download_BERT_checkpoint
#' @param checkpoint_archive_path The path to the downloaded checkpoint archive.
#' @param ckpt_dir The path to the subdir into which the checkpoint should be
#'   extracted.
#'
#' @return \code{TRUE} invisibly.
#' @keywords internal
.process_BERT_checkpoint <- function(dir,
                                     checkpoint_archive_path,
                                     ckpt_dir,
                                     archive_type) {
  # We're only here if the files don't exist or we're supposed to overwrite, so
  # always overwrite.
  switch(
    archive_type,
    "zip" = utils::unzip(
      zipfile = checkpoint_archive_path,
      exdir = ckpt_dir,
      overwrite = TRUE
    ),
    "tar-gzip" = {
      con <- gzfile(checkpoint_archive_path, open = "rb")
      # Close the connection even when extraction fails, so the archive is not
      # left open.
      tryCatch(
        utils::untar(
          con,
          exdir = ckpt_dir
        ),
        finally = close(con)
      )
    },
    # Fail loudly instead of silently extracting nothing.
    stop(
      "Unsupported archive_type: \"", archive_type, "\". ",
      "Supported types are \"zip\" and \"tar-gzip\"."
    )
  )

  # We write into the subdir, but *usually* it'll make a folder inside of that
  # dir. Move everything up to be inside ckpt_dir.
  extra_dirs <- list.dirs(
    ckpt_dir,
    full.names = TRUE, recursive = FALSE
  )
  for (dir_name in extra_dirs) {
    cp_files <- list.files(
      dir_name,
      recursive = TRUE
    )
    destinations <- file.path(ckpt_dir, cp_files)

    # Files in nested folders keep their relative location, so make sure the
    # target folders exist; file.rename() does not create them.
    for (destination_dir in unique(dirname(destinations))) {
      dir.create(destination_dir, recursive = TRUE, showWarnings = FALSE)
    }

    file.rename(
      file.path(dir_name, cp_files),
      destinations
    )
    unlink(dir_name, recursive = TRUE)
  }

  filenames <- list.files(ckpt_dir)

  # Quick check to see if expected files found.
  if (!("bert_config.json" %in% filenames)) {
    warning("No bert_config file found.") # nocovr
  }
  if (!("vocab.txt" %in% filenames)) {
    warning("No vocabulary file found.") # nocovr
  }
  if (!any(grepl("bert_model.ckpt", filenames))) {
    warning("No checkpoint file found.") # nocovr
  }

  invisible(TRUE)
}



# .get_model_* ------------------------------------------------

#' Get url of a BERT checkpoint
#'
#' Returns the url of the specified BERT checkpoint from the Google Research
#' collection or other repository.
#'
#' @inheritParams download_BERT_checkpoint
#'
#' @return The url to the specified BERT model.
#' @keywords internal
.get_model_url <- function(model) {
  # Look the model up in the package's checkpoint table and pull its url.
  is_requested_model <- checkpoint_url_map$model == model
  checkpoint_url_map[["url"]][is_requested_model]
}

#' Get archive type of a BERT checkpoint
#'
#' Returns the archive type ("zip" or "tar-gzip") of the specified BERT
#' checkpoint from the Google Research collection or other repository.
#'
#' @inheritParams download_BERT_checkpoint
#'
#' @return The archive type to the specified BERT model.
#' @keywords internal
.get_model_archive_type <- function(model) {
  # Look the model up in the package's checkpoint table and pull its archive
  # type.
  is_requested_model <- checkpoint_url_map$model == model
  checkpoint_url_map[["archive_type"]][is_requested_model]
}

#' Locate a subdir for a BERT checkpoint
#'
#' @inheritParams download_BERT_checkpoint
#'
#' @return The path to the sub-directory where the checkpoint should be saved.
#' @keywords internal
.get_model_subdir <- function(model, dir) {
  # Each model gets a folder named after it inside `dir`. The folder may not
  # exist yet, so normalization must not insist on it.
  model_folder <- file.path(dir, model)
  normalizePath(model_folder, mustWork = FALSE)
}

#' Locate an archive file for a BERT checkpoint
#'
#' @inheritParams download_BERT_checkpoint
#'
#' @return The path to the archive file where the raw checkpoint should be
#'   saved.
#' @keywords internal
.get_model_archive_path <- function(model, dir, archive_type) {
  # File extension used for each supported archive type.
  extensions <- c(".zip", ".tar.gz")
  names(extensions) <- c("zip", "tar-gzip")
  archive_ending <- extensions[[archive_type]]

  # The archive sits next to the model folder and is named after the model.
  archive_file <- file.path(dir, paste0(model, archive_ending))
  normalizePath(archive_file, mustWork = FALSE)
}


# .infer_archive_* --------------------------------------------------------

#' Infer the archive type for a BERT checkpoint
#'
#' @inheritParams download_BERT_checkpoint
#'
#' @return A character vector, currently either "zip" or "tar-gzip".
#' @keywords internal
.infer_archive_type <- function(url) {
  # Decide purely from the file extension at the end of the url.
  if (stringr::str_detect(url, "\\.tar\\.gz$")) {
    return("tar-gzip")
  }
  if (stringr::str_detect(url, "\\.zip$")) {
    return("zip")
  }

  # Anything else cannot be guessed; the caller has to say what it is.
  # nocov start
  stop(
    "Unknown archive type. Please supply an explicit archive_type."
  )
  # nocov end
}

#' Infer the subdir for a BERT checkpoint
#'
#' @inheritParams download_BERT_checkpoint
#'
#' @return A character vector file path, reflecting the "name" part of a
#'   checkpoint \code{url}, placed within \code{dir}.
#' @keywords internal
.infer_ckpt_dir <- function(url, dir) {
  # Strip the supported archive extensions from the file name in the url; what
  # is left names the checkpoint folder inside `dir`.
  extension_removals <- c(
    "\\.tar\\.gz$" = "",
    "\\.zip$" = ""
  )
  ckpt_name <- stringr::str_replace_all(basename(url), extension_removals)

  # The folder may not exist yet, so normalization must not insist on it.
  normalizePath(file.path(dir, ckpt_name), mustWork = FALSE)
}

#' Infer the path to the archive for a BERT checkpoint
#'
#' @inheritParams download_BERT_checkpoint
#'
#' @return A character vector file path, pointing to where the raw checkpoint
#'   archive should be saved.
#' @keywords internal
.infer_checkpoint_archive_path <- function(url, dir) {
  # The archive keeps the file name it has in the url and is stored directly
  # inside `dir`. It may not exist yet, so normalization must not insist on it.
  archive_name <- basename(url)
  archive_file <- file.path(dir, archive_name)
  normalizePath(archive_file, mustWork = FALSE)
}

# set_BERT_dir ------------------------------------------------------------

#' Set the directory for BERT checkpoints
#'
#' Set a given \code{dir} as the default BERT checkpoint directory for this
#' session, and create it if it does not exist.
#'
#' @inheritParams download_BERT_checkpoint
#'
#' @return A list with the previous value of `BERT.dir` (invisibly).
#' @export
#'
#' @examples
#' \dontrun{
#' set_BERT_dir("fake_dir")
#' }
set_BERT_dir <- function(dir) {
  # Create the directory (and any missing parents) on first use.
  dir_is_missing <- !file.exists(dir)
  if (dir_is_missing) {
    dir.create(dir, recursive = TRUE) # nocov
  }

  # Remember the normalized location for this session; options() hands back
  # the previous setting, which is returned invisibly.
  previous_setting <- options(BERT.dir = normalizePath(dir))
  invisible(previous_setting)
}

# Copied from `rlang` to avoid importing that package. Roxygen doesn't like it
# and I'm not sure how to fix that, so instead I'm not documenting.
`%||%` <- function(x, y) {
  # Null-default operator: keep `x` unless it is NULL. `y` is only evaluated
  # when it is actually needed.
  if (!is.null(x)) {
    return(x)
  }
  y
}

#' Find or Possibly Download a Checkpoint
#'
#' Verify that the user has a specified checkpoint, and prompt to download if
#' they don't (in interactive mode).
#'
#' @inheritParams .has_checkpoint
#' @return TRUE (invisibly)
#' @keywords internal
.maybe_download_checkpoint <- function(model = c(
                                         "bert_base_uncased",
                                         "bert_base_cased",
                                         "bert_large_uncased",
                                         "bert_large_cased",
                                         "bert_large_uncased_wwm",
                                         "bert_large_cased_wwm",
                                         "bert_base_multilingual_cased",
                                         "bert_base_chinese",
                                         "scibert_scivocab_uncased",
                                         "scibert_scivocab_cased",
                                         "scibert_basevocab_uncased",
                                         "scibert_basevocab_cased"
                                       ),
                                       dir = NULL,
                                       ckpt_dir = NULL) {
  # Without an explicit ckpt_dir the checkpoint location is derived from the
  # model name, so collapse the default vector of choices to a single model
  # first. Otherwise the folders of every listed model would be searched at
  # once.
  if (is.null(ckpt_dir)) {
    model <- match.arg(model)
  }

  has_checkpoint <- .has_checkpoint(
    model = model,
    dir = dir,
    ckpt_dir = ckpt_dir
  )
  if (has_checkpoint) {
    return(invisible(TRUE))
  }

  not_found_message <- paste0(
    "Could not find the specified model. Specify ckpt_dir, or ",
    "run download_BERT_checkpoint."
  )

  # There is nobody to ask in a non-interactive session.
  if (!interactive()) {
    stop(not_found_message)
  }

  # nocov start
  do_download <- utils::menu(
    c("Yes.", "No."),
    title = paste(
      "Model not found. Do you wish to download the model?",
      "\nThis may take a long time and use a lot of disk space."
    )
  )
  # Anything but an explicit "Yes." (including cancelling the menu) declines.
  if (do_download != 1L) {
    stop(not_found_message)
  }
  download_BERT_checkpoint(model, dir)
  # nocov end

  invisible(TRUE)
}
# jonathanbratt/RBERT documentation built on Jan. 26, 2023, 4:15 p.m.