Source: R/run_classifier.R

model_fn_builder (R Documentation)
model_fn closure for TPUEstimator

Returns a model_fn closure, which is an input to TPUEstimator.
model_fn_builder( bert_config, num_labels, init_checkpoint, learning_rate, num_train_steps, num_warmup_steps, use_tpu )
| Argument | Description |
|---|---|
| bert_config | BertConfig instance. |
| num_labels | Integer; number of classification labels. |
| init_checkpoint | Character; path to the checkpoint directory plus the checkpoint name stub (e.g. "bert_model.ckpt"). The path must be absolute and explicit, starting with "/". |
| learning_rate | Numeric; the learning rate. |
| num_train_steps | Integer; number of steps to train for. |
| num_warmup_steps | Integer; number of initial training steps used for learning-rate "warm-up", during which the learning rate is gradually increased. |
| use_tpu | Logical; whether to use a TPU. |
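To make the interaction between learning_rate, num_train_steps and num_warmup_steps concrete, here is a minimal sketch of the schedule used by the original (Python) BERT optimizer: a linear ramp from 0 up to learning_rate over the warm-up steps, then a linear decay toward 0 at num_train_steps. It is an assumption that this package follows the same schedule, and bert_lr_schedule is a hypothetical helper, not part of the package.

```r
# Hypothetical helper (not part of the package): learning rate at each step,
# ASSUMING the schedule of the original BERT optimizer.
bert_lr_schedule <- function(step, learning_rate, num_train_steps,
                             num_warmup_steps) {
  # Linear decay to 0 over num_train_steps (polynomial decay with power = 1);
  # steps past num_train_steps are clamped, giving a rate of 0.
  decayed <- learning_rate *
    (1 - pmin(step, num_train_steps) / num_train_steps)
  if (num_warmup_steps > 0) {
    # During warm-up the rate ramps up linearly from 0 to learning_rate.
    warmup <- learning_rate * step / num_warmup_steps
    decayed <- ifelse(step < num_warmup_steps, warmup, decayed)
  }
  decayed
}

# Values from the example below: learning_rate = 0.01,
# num_train_steps = 20, num_warmup_steps = 10.
bert_lr_schedule(c(0, 5, 10, 15, 20), 0.01, 20, 10)
# c(0, 0.005, 0.005, 0.0025, 0), up to floating-point rounding
```

Note that under this schedule the rate at the end of warm-up is already the decayed value, not the full learning_rate, because the decay is measured from step 0.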
The model_fn function takes four parameters:

- features: A list (or similar structure) containing objects such as input_ids, input_mask, segment_ids, and label_ids. These objects are the inputs to the create_model function.
- labels: Not used in this function, but presumably the slot must be kept so that the signature matches the standard Estimator model_fn(features, labels, mode, params).
- mode: Character; one of "train", "eval", or "infer".
- params: Not used in this function; kept for the same reason as labels.
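The closure pattern described above can be illustrated without TensorFlow. The sketch below is a hypothetical toy builder, not this package's implementation: the builder captures its hyperparameters, and the returned function has the four-argument signature and dispatches on mode. The real closure builds the BERT graph and returns a tf$contrib$tpu$TPUEstimatorSpec; this toy returns a plain list so it runs in base R.

```r
# Hypothetical toy builder (NOT the package's code) showing the closure
# pattern behind model_fn_builder.
toy_model_fn_builder <- function(num_labels, learning_rate) {
  # Force the promises so the values are fixed when the builder runs.
  force(num_labels)
  force(learning_rate)
  function(features, labels, mode, params) {
    # labels and params are accepted but unused, mirroring model_fn_builder.
    action <- switch(mode,
      train = "compute loss and apply an optimizer step",
      eval  = "compute loss and evaluation metrics",
      infer = "compute predicted probabilities",
      stop("unknown mode: ", mode)
    )
    list(
      mode = mode,
      action = action,
      num_labels = num_labels,       # captured from the builder call
      learning_rate = learning_rate, # captured from the builder call
      feature_names = names(features)
    )
  }
}

# Usage: build once, then call with the four standard arguments.
model_fn <- toy_model_fn_builder(num_labels = 2L, learning_rate = 0.01)
out <- model_fn(
  features = list(input_ids = 1L, input_mask = 1L,
                  segment_ids = 0L, label_ids = 0L),
  labels = NULL, mode = "train", params = NULL
)
out$action     # "compute loss and apply an optimizer step"
out$num_labels # 2
```

Capturing the hyperparameters in a closure is what lets the estimator call model_fn with only the standard arguments while the model still knows its configuration.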
The output of model_fn is the result of a tf$contrib$tpu$TPUEstimatorSpec call.
This reference may be helpful: https://tensorflow.rstudio.com/tfestimators/articles/creating_estimators.html
Value: a model_fn closure for TPUEstimator.
```r
## Not run:
with(tensorflow::tf$variable_scope("examples",
  reuse = tensorflow::tf$AUTO_REUSE
), {
  # Toy inputs: a batch of two sequences of length 3. They are not needed
  # to build the closure; they illustrate the features model_fn receives.
  input_ids <- tensorflow::tf$constant(list(
    list(31L, 51L, 99L),
    list(15L, 5L, 0L)
  ))
  input_mask <- tensorflow::tf$constant(list(
    list(1L, 1L, 1L),
    list(1L, 1L, 0L)
  ))
  # Token type (segment) ids must be smaller than type_vocab_size (2 here).
  token_type_ids <- tensorflow::tf$constant(list(
    list(0L, 0L, 1L),
    list(0L, 1L, 0L)
  ))
  # Configuration matching the uncased_L-12_H-768_A-12 checkpoint
  # (12 layers, hidden size 768, 12 attention heads).
  config <- BertConfig(
    vocab_size = 30522L,
    hidden_size = 768L,
    num_hidden_layers = 12L,
    type_vocab_size = 2L,
    num_attention_heads = 12L,
    intermediate_size = 3072L
  )
  # Assumes the pretrained checkpoint has already been downloaded here.
  temp_dir <- tempdir()
  init_checkpoint <- file.path(
    temp_dir,
    "BERT_checkpoints",
    "uncased_L-12_H-768_A-12",
    "bert_model.ckpt"
  )
  example_mod_fn <- model_fn_builder(
    bert_config = config,
    num_labels = 2L,
    init_checkpoint = init_checkpoint,
    learning_rate = 0.01,
    num_train_steps = 20L,
    num_warmup_steps = 10L,
    use_tpu = FALSE
  )
})
## End(Not run)
```