Source: R/run_classifier.R
model_fn_builder: model_fn closure for TPUEstimator

Description:

Returns a model_fn closure, which is an input to TPUEstimator.
Usage:

model_fn_builder(
  bert_config,
  num_labels,
  init_checkpoint,
  learning_rate,
  num_train_steps,
  num_warmup_steps,
  use_tpu
)
Arguments:

bert_config: A BertConfig object describing the model architecture (see BertConfig()).

num_labels: Integer; number of classification labels.

init_checkpoint: Character; path to the checkpoint directory, plus the checkpoint name stub (e.g. "bert_model.ckpt"). The path must be absolute and explicit, starting with "/".

learning_rate: Numeric; the learning rate.

num_train_steps: Integer; number of steps to train for.

num_warmup_steps: Integer; number of initial training steps used for learning-rate "warm-up".

use_tpu: Logical; whether to use a TPU.
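The learning_rate, num_train_steps, and num_warmup_steps arguments jointly define a learning-rate schedule. The sketch below assumes RBERT follows the optimizer of the original BERT implementation: during the first num_warmup_steps steps the rate rises linearly from 0 toward learning_rate, and otherwise it decays linearly from learning_rate to 0 at num_train_steps. bert_learning_rate() is a hypothetical helper written only for illustration; it is not part of RBERT, and the package's own optimizer code is the authority on the actual schedule.

```r
# Hypothetical helper (not part of RBERT): learning rate at a given step,
# assuming the original BERT schedule of linear warm-up followed by linear
# (power = 1) polynomial decay to zero at num_train_steps.
bert_learning_rate <- function(global_step, learning_rate,
                               num_train_steps, num_warmup_steps) {
  # Linear decay from learning_rate to 0, measured from step 0;
  # steps past num_train_steps stay at 0.
  step <- pmin(global_step, num_train_steps)
  decayed <- learning_rate * (1 - step / num_train_steps)
  if (num_warmup_steps > 0) {
    # During warm-up, scale learning_rate by the fraction of warm-up done.
    warmup <- learning_rate * global_step / num_warmup_steps
    ifelse(global_step < num_warmup_steps, warmup, decayed)
  } else {
    decayed
  }
}

# Values from the example below: learning_rate = 0.01, 20 steps, 10 warm-up.
steps <- c(0, 5, 10, 15, 20)
rates <- bert_learning_rate(steps, learning_rate = 0.01,
                            num_train_steps = 20, num_warmup_steps = 10)
rates  # 0, 0.005, 0.005, 0.0025, 0
```

Under this assumed schedule the rate drops at the end of warm-up (0.009 at step 9, then 0.005 at step 10), because the decay is computed from step 0 rather than from the end of warm-up.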
Details:

The model_fn function takes four parameters:

features: A list (or similar structure) that contains objects such as input_ids, input_mask, segment_ids, and label_ids. These objects are passed as inputs to the create_model function.

labels: Not used in this function; the slot is kept so the signature matches the conventional model_fn(features, labels, mode, params) form.

mode: Character; a value such as "train", "infer", or "eval".

params: Not used in this function; the slot is kept so the signature matches the conventional model_fn(features, labels, mode, params) form.
The output of model_fn is the result of a tf$contrib$tpu$TPUEstimatorSpec call.

This reference on creating custom estimators may be helpful: https://tensorflow.rstudio.com/tfestimators/articles/creating_estimators.html
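To illustrate the closure pattern described above, here is a dependency-free sketch in base R. toy_model_fn_builder() is hypothetical: it captures its arguments and returns a function with the (features, labels, mode, params) signature, accepting but ignoring labels and params and switching on mode. A plain list stands in for the TPUEstimatorSpec and all-zero logits stand in for create_model; the real model_fn_builder builds a TensorFlow graph instead.

```r
# Hypothetical builder (not part of RBERT) showing the closure pattern:
# arguments are captured once, and the returned function is what an
# estimator would call with (features, labels, mode, params).
toy_model_fn_builder <- function(num_labels, learning_rate) {
  # Force evaluation so the closure captures the values, not promises.
  force(num_labels)
  force(learning_rate)
  function(features, labels, mode, params) {
    # labels and params are accepted but ignored, as in model_fn_builder.
    n_examples <- nrow(features$input_ids)
    # Stand-in for create_model: all-zero logits, one column per label.
    logits <- matrix(0, nrow = n_examples, ncol = num_labels)
    # A plain list stands in for the TPUEstimatorSpec returned per mode.
    switch(
      mode,
      train = list(mode = mode, learning_rate = learning_rate, logits = logits),
      eval = list(mode = mode, label_ids = features$label_ids, logits = logits),
      # 0-based predicted label ids (argmax of each row of logits).
      infer = list(mode = mode,
                   predictions = max.col(logits, ties.method = "first") - 1L),
      stop("Unsupported mode: ", mode)
    )
  }
}

# Usage: build the closure once, then call it as an estimator would.
toy_fn <- toy_model_fn_builder(num_labels = 2L, learning_rate = 0.01)
features <- list(
  input_ids   = matrix(c(31L, 15L, 51L, 5L, 99L, 0L), nrow = 2),
  input_mask  = matrix(c(1L, 1L, 1L, 1L, 1L, 0L), nrow = 2),
  segment_ids = matrix(0L, nrow = 2, ncol = 3),
  label_ids   = c(0L, 1L)
)
out <- toy_fn(features, labels = NULL, mode = "infer", params = NULL)
out$predictions
```

Returning a closure lets the estimator call model_fn repeatedly with only the per-call inputs, while settings such as num_labels and learning_rate stay fixed from the builder call.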
Value:

A model_fn closure for TPUEstimator.
Examples:

## Not run:
with(tensorflow::tf$variable_scope("examples",
                                   reuse = tensorflow::tf$AUTO_REUSE), {
  input_ids <- tensorflow::tf$constant(list(
    list(31L, 51L, 99L),
    list(15L, 5L, 0L)
  ))
  input_mask <- tensorflow::tf$constant(list(
    list(1L, 1L, 1L),
    list(1L, 1L, 0L)
  ))
  # Token type ids must be smaller than type_vocab_size (2 here).
  token_type_ids <- tensorflow::tf$constant(list(
    list(0L, 0L, 1L),
    list(0L, 1L, 0L)
  ))
  config <- BertConfig(
    vocab_size = 30522L,
    hidden_size = 768L,
    num_hidden_layers = 8L,
    type_vocab_size = 2L,
    num_attention_heads = 12L,
    intermediate_size = 3072L
  )
  temp_dir <- tempdir()
  init_checkpoint <- file.path(
    temp_dir,
    "BERT_checkpoints",
    "uncased_L-12_H-768_A-12",
    "bert_model.ckpt"
  )
  example_mod_fn <- model_fn_builder(
    bert_config = config,
    num_labels = 2L,
    init_checkpoint = init_checkpoint,
    learning_rate = 0.01,
    num_train_steps = 20L,
    num_warmup_steps = 10L,
    use_tpu = FALSE
  )
})
## End(Not run)