View source: R/modeling.R

get_assignment_map_from_checkpoint    R Documentation

Compute the intersection of the current variables and checkpoint variables

Description

Returns the intersection (not the union, as the Python docs say -JDB) of the sets of variable names from the current graph and the checkpoint.

Usage

get_assignment_map_from_checkpoint(tvars, init_checkpoint)

Arguments

tvars

List of training variables in the current model.

init_checkpoint

Character; path to the checkpoint directory, plus checkpoint name stub (e.g. "bert_model.ckpt"). Path must be absolute and explicit, starting with "/".
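The path requirements above can be checked before calling the function. The sketch below uses a hypothetical helper, `is_valid_checkpoint_stub()`, which is not part of RBERT. It assumes that a TensorFlow checkpoint stub such as "bert_model.ckpt" is accompanied by an index file named "bert_model.ckpt.index" in the same directory, as in the standard BERT checkpoint downloads.

```r
# Hypothetical helper (not part of RBERT): checks an init_checkpoint value
# against the requirements above. Assumes a stub "X" has an index file
# "X.index" next to it, as TensorFlow checkpoints normally do.
is_valid_checkpoint_stub <- function(init_checkpoint) {
  # Must be a single, non-missing character string.
  if (!is.character(init_checkpoint) || length(init_checkpoint) != 1L ||
      is.na(init_checkpoint)) {
    return(FALSE)
  }
  # Must be absolute and explicit, i.e. start with "/".
  if (!startsWith(init_checkpoint, "/")) {
    return(FALSE)
  }
  # The stub itself is not a file; look for its index file instead.
  file.exists(paste0(init_checkpoint, ".index"))
}

# Usage: a relative stub is rejected before any file lookup.
is_valid_checkpoint_stub("uncased_L-12_H-768_A-12/bert_model.ckpt")
#> [1] FALSE
```

On Unix-like systems, a home-relative path such as "~/bert_model.ckpt" fails the "/" test as written; passing it through `path.expand()` first turns it into an absolute path.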

Details

Note that a TensorFlow checkpoint is not the same as a saved model. A saved model contains a complete description of the computational graph and is sufficient to reconstruct the entire model, while a checkpoint contains just the parameter values (and variable names), and so requires a specification of the original model structure to reconstruct the computational graph. -JDB

Value

List with two elements: the assignment map and the initialized variable names. The assignment map is a list of the "base" variable names (names without TensorFlow's ":0" output suffix) that are in both the current computational graph and the checkpoint. The initialized variable names list contains both the base names and the base names + ":0". (This seems redundant to me. I assume it will make sense later. -JDB)
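The matching described above can be sketched in base R on plain character vectors. This is a minimal sketch, not RBERT's implementation: `assignment_map_from_names()` is a hypothetical helper, and it assumes that current-graph names carry TensorFlow's ":0"-style output suffix while checkpoint names (such as the names returned by `tf$train$list_variables()`) do not. Each assignment-map entry maps a checkpoint name to the same current-graph name, mirroring the `assignment_map` dict in the original Python BERT code.

```r
# Hypothetical helper for illustration (not RBERT's API): intersects
# current-graph variable names with checkpoint variable names.
assignment_map_from_names <- function(current_names, checkpoint_names) {
  # Strip the ":<n>" output-index suffix, e.g. ".../bias:0" -> ".../bias".
  base_names <- sub(":[0-9]+$", "", current_names)
  # Keep names present in both sets (the intersection), in checkpoint order.
  shared <- checkpoint_names[checkpoint_names %in% base_names]
  list(
    # Named list: checkpoint name -> current-graph name (identical here).
    assignment_map = as.list(stats::setNames(shared, shared)),
    # Base names followed by the same names with ":0" appended.
    # sprintf() returns character(0) when `shared` is empty.
    initialized_variable_names = c(shared, sprintf("%s:0", shared))
  )
}

current <- c("bert/encoder/layer_9/output/dense/bias:0",
             "my_new_layer/kernel:0")  # hypothetical non-BERT variable
ckpt <- c("bert/embeddings/word_embeddings",
          "bert/encoder/layer_9/output/dense/bias")
amap <- assignment_map_from_names(current, ckpt)
# Only the shared bias name appears, once bare and once with ":0".
amap$initialized_variable_names
```

Keeping the ":0" forms lets a lookup by a variable's full tensor name also succeed. The original Python BERT scripts, for example, check each variable's `name` (which includes ":0") against the initialized names when logging which variables were loaded from the checkpoint.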

Examples

## Not run: 
# Just for illustration: create a "model" with a couple of variables
# that overlap some variable names in the BERT checkpoint.
with(tensorflow::tf$variable_scope("bert",
  reuse = tensorflow::tf$AUTO_REUSE
), {
  test_ten1 <- tensorflow::tf$get_variable(
    "encoder/layer_9/output/dense/bias",
    shape = c(1L, 2L, 3L)
  )
  test_ten2 <- tensorflow::tf$get_variable(
    "encoder/layer_9/output/dense/kernel",
    shape = c(1L, 2L, 3L)
  )
})
tvars <- tensorflow::tf$get_collection(
  tensorflow::tf$GraphKeys$GLOBAL_VARIABLES
)
# Assumes the BERT checkpoint has already been downloaded to this location.
temp_dir <- tempdir()
init_checkpoint <- file.path(
  temp_dir,
  "BERT_checkpoints",
  "uncased_L-12_H-768_A-12",
  "bert_model.ckpt"
)

amap <- get_assignment_map_from_checkpoint(tvars, init_checkpoint)

## End(Not run)

jonathanbratt/RBERT documentation built on Jan. 26, 2023, 4:15 p.m.