model_fn_builder: Define 'model_fn' closure for 'TPUEstimator'

Source file: R/run_classifier.R

Description

Returns a model_fn closure, which is passed to TPUEstimator as its model function.

Usage

model_fn_builder(
  bert_config,
  num_labels,
  init_checkpoint,
  learning_rate,
  num_train_steps,
  num_warmup_steps,
  use_tpu
)

Arguments

bert_config

BertConfig instance.

num_labels

Integer; number of classification labels.

init_checkpoint

Character; path to the checkpoint directory plus the checkpoint name stub (e.g. "bert_model.ckpt"). The path must be absolute, starting with "/".

learning_rate

Numeric; the learning rate.

num_train_steps

Integer; number of steps to train for.

num_warmup_steps

Integer; number of initial training steps used for learning-rate "warm-up".

use_tpu

Logical; whether to use TPU.
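The sketch below shows how learning_rate, num_train_steps, and num_warmup_steps typically combine into a per-step learning rate. It assumes RBERT follows the schedule of the original BERT optimizer (linear warm-up from zero, with linear decay to zero over num_train_steps counted from step 0); the helper bert_lr_schedule is hypothetical and not part of RBERT.

```r
# Hypothetical helper (not part of RBERT). Assumes the original BERT
# schedule: linear warm-up, then linear decay to 0 at num_train_steps.
bert_lr_schedule <- function(step, learning_rate, num_train_steps,
                             num_warmup_steps) {
  # Polynomial decay with power = 1, i.e. a straight line from
  # learning_rate at step 0 down to 0 at num_train_steps.
  decayed <- learning_rate * (1 - pmin(step, num_train_steps) / num_train_steps)
  if (num_warmup_steps > 0) {
    # While step < num_warmup_steps, ramp linearly up from 0.
    warmup <- learning_rate * step / num_warmup_steps
    ifelse(step < num_warmup_steps, warmup, decayed)
  } else {
    decayed
  }
}

# Values taken from the Examples section below.
steps <- 0:20
lr <- bert_lr_schedule(steps, learning_rate = 0.01,
                       num_train_steps = 20L, num_warmup_steps = 10L)
round(lr, 4)
```

With these values the rate climbs to 0.009 at step 9, then follows the decay line (0.005 at step 10, 0 at step 20). Because the decay is measured from step 0 rather than from the end of warm-up, this schedule has a drop where warm-up ends.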

Details

The model_fn function takes four parameters:

features

A list (or similar structure) that contains objects such as input_ids, input_mask, segment_ids, and label_ids. These objects will be inputs to the create_model function.

labels

Not used by model_fn; the argument is kept to match the signature that TPUEstimator expects.

mode

Character; the mode the estimator is running in: "train", "eval", or "infer" (prediction). These are the string values of tf$estimator$ModeKeys.

params

Not used by model_fn; like labels, the argument is kept to match the signature that TPUEstimator expects.

The output of model_fn is the result of a tf$contrib$tpu$TPUEstimatorSpec call.

For background on defining custom estimators from R, see https://tensorflow.rstudio.com/tfestimators/articles/creating_estimators.html
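To illustrate the closure pattern, here is a minimal, TensorFlow-free sketch: a builder captures its configuration arguments and returns a function with the four-argument (features, labels, mode, params) signature described above, branching on mode. The names toy_model_fn_builder and the returned lists are hypothetical stand-ins; the real model_fn builds the BERT graph via create_model and returns a TPUEstimatorSpec.

```r
# Hypothetical stand-in for model_fn_builder (no TensorFlow involved).
toy_model_fn_builder <- function(num_labels, learning_rate, use_tpu) {
  # Force the promises now so the closure captures the values supplied at
  # build time (R arguments are otherwise evaluated lazily).
  force(num_labels)
  force(learning_rate)
  force(use_tpu)

  # The estimator only ever calls this with (features, labels, mode, params);
  # the configuration above is reached through the enclosing environment.
  function(features, labels, mode, params) {
    # labels and params are accepted but unused, as in the real model_fn.
    batch_size <- length(features$label_ids)
    if (mode == "train") {
      list(mode = mode, learning_rate = learning_rate, use_tpu = use_tpu)
    } else if (mode == "eval") {
      list(mode = mode, n_examples = batch_size)
    } else if (mode == "infer") {
      # Placeholder: uniform class probabilities, one row per example.
      probs <- matrix(1 / num_labels, nrow = batch_size, ncol = num_labels)
      list(mode = mode, probabilities = probs)
    } else {
      stop("Unsupported mode: ", mode)
    }
  }
}

model_fn <- toy_model_fn_builder(num_labels = 2L, learning_rate = 0.01,
                                 use_tpu = FALSE)
features <- list(input_ids = list(c(31L, 51L, 99L), c(15L, 5L, 0L)),
                 label_ids = c(0L, 1L))
str(model_fn(features, labels = NULL, mode = "infer", params = list()))
```

Keeping the configuration in the closure lets the estimator treat model_fn as an opaque callback while every call still sees the same settings.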

Value

A model_fn closure for TPUEstimator.

Examples

## Not run: 
with(tensorflow::tf$variable_scope("examples",
  reuse = tensorflow::tf$AUTO_REUSE
), {
  # Toy batch of two length-3 sequences. model_fn itself is not called
  # below, so these tensors are illustrative and not consumed here.
  input_ids <- tensorflow::tf$constant(list(
    list(31L, 51L, 99L),
    list(15L, 5L, 0L)
  ))

  input_mask <- tensorflow::tf$constant(list(
    list(1L, 1L, 1L),
    list(1L, 1L, 0L)
  ))
  # Token type (segment) ids must be less than type_vocab_size: 0 or 1 here.
  token_type_ids <- tensorflow::tf$constant(list(
    list(0L, 0L, 1L),
    list(0L, 1L, 0L)
  ))
  config <- BertConfig(
    vocab_size = 30522L,
    hidden_size = 768L,
    num_hidden_layers = 8L,
    type_vocab_size = 2L,
    num_attention_heads = 12L,
    intermediate_size = 3072L
  )

  # Checkpoint directory plus the "bert_model.ckpt" name stub.
  temp_dir <- tempdir()
  init_checkpoint <- file.path(
    temp_dir,
    "BERT_checkpoints",
    "uncased_L-12_H-768_A-12",
    "bert_model.ckpt"
  )

  # Build the closure to be passed to TPUEstimator as its model function.
  example_mod_fn <- model_fn_builder(
    bert_config = config,
    num_labels = 2L,
    init_checkpoint = init_checkpoint,
    learning_rate = 0.01,
    num_train_steps = 20L,
    num_warmup_steps = 10L,
    use_tpu = FALSE
  )
})

## End(Not run)
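Because init_checkpoint must be an absolute path ending in a checkpoint name stub, a quick check before calling model_fn_builder can catch a malformed value early. This is a sketch only: check_init_checkpoint is a hypothetical helper, not part of RBERT, and it assumes the usual TensorFlow layout in which a stub "bert_model.ckpt" is accompanied by a "bert_model.ckpt.index" file.

```r
# Hypothetical helper (not part of RBERT): sanity-check an init_checkpoint
# value before passing it to model_fn_builder().
check_init_checkpoint <- function(init_checkpoint) {
  stopifnot(is.character(init_checkpoint), length(init_checkpoint) == 1L)
  # The argument documentation requires an absolute path starting with "/".
  if (!startsWith(init_checkpoint, "/")) {
    stop("init_checkpoint must be an absolute path starting with '/': ",
         init_checkpoint)
  }
  # Assumption: a TensorFlow checkpoint stub "x.ckpt" has a companion
  # index file "x.ckpt.index". Returns TRUE only if that file exists.
  file.exists(paste0(init_checkpoint, ".index"))
}

check_init_checkpoint("/path/to/uncased_L-12_H-768_A-12/bert_model.ckpt")
```

A relative stub such as "bert_model.ckpt" raises an error, while a well-formed absolute path returns FALSE until the checkpoint files have actually been downloaded.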

jonathanbratt/RBERT documentation built on Jan. 26, 2023, 4:15 p.m.