resume_training_from_model_card: Continue training from model card

View source: R/train.R

resume_training_from_model_card {deepG}    R Documentation

Continue training from model card

Description

Uses the information stored in a model card to resume training from the corresponding checkpoint with the same training arguments.

Usage

resume_training_from_model_card(
  path_model_card,
  seed = NULL,
  epoch = NULL,
  new_run_name = NULL,
  new_args = NULL,
  new_compile = NULL,
  use_mirrored_strategy = NULL,
  unfreeze = FALSE,
  verbose = FALSE
)

Arguments

path_model_card

Path to the model card to resume training from.

seed

Seed for reproducible results. If NULL, a random seed is set.

epoch

Epoch to resume from. If NULL, the last epoch is used.

new_run_name

Name of the new run. If NULL, the old run name with the suffix '_cont' appended is used.

new_args

Named list of arguments to overwrite. All other arguments are taken from the model card. For example, to change the batch size and the padding option: new_args = list(batch_size = 6, padding = TRUE).

new_compile

List of arguments used to compile the model again. If NULL, the compiled model from the checkpoint is used. Example: new_compile = list(loss = 'binary_crossentropy', metrics = 'acc', optimizer = keras::optimizer_adam())

use_mirrored_strategy

Whether to use a distributed mirrored strategy. If NULL, the mirrored strategy is used only if more than one GPU is available.

unfreeze

If TRUE, set the trainable attribute of the model to TRUE (unfreeze weights).

verbose

Whether to print all training arguments.
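The NULL defaults described for new_run_name, new_args and use_mirrored_strategy can be sketched in base R as follows. The helpers resolve_run_name(), resolve_args() and resolve_mirrored_strategy(), the n_gpus argument and the stored argument list are hypothetical illustrations, not deepG functions. The sketch assumes that a NULL argument means "fall back to the documented default" and uses utils::modifyList() as one plausible way to overwrite stored arguments; the package itself may do this differently.

```r
# Hypothetical helpers, NOT part of deepG: they only mirror the defaults
# documented for new_run_name, new_args and use_mirrored_strategy.

# new_run_name = NULL -> old run name with '_cont' appended
resolve_run_name <- function(old_run_name, new_run_name = NULL) {
  if (is.null(new_run_name)) paste0(old_run_name, "_cont") else new_run_name
}

# new_args -> overwrite matching entries, keep all other stored arguments
resolve_args <- function(old_args, new_args = NULL) {
  if (is.null(new_args)) return(old_args)
  # modifyList() replaces entries with matching names and keeps the rest
  utils::modifyList(old_args, new_args)
}

# use_mirrored_strategy = NULL -> distribute only if more than one GPU
# (n_gpus is a hypothetical argument; an explicit TRUE/FALSE always wins)
resolve_mirrored_strategy <- function(use_mirrored_strategy = NULL, n_gpus) {
  if (is.null(use_mirrored_strategy)) n_gpus > 1 else isTRUE(use_mirrored_strategy)
}

# Hypothetical stand-in for the arguments stored in a model card
old_args <- list(batch_size = 8, epochs = 3, steps_per_epoch = 6, padding = FALSE)

resolve_run_name("test_run_1")
merged <- resolve_args(old_args, list(batch_size = 6, padding = TRUE))
str(merged)
resolve_mirrored_strategy(NULL, n_gpus = 1)
```

One caveat of modifyList(): a NULL entry in new_args deletes the stored argument instead of setting it to NULL, and nested lists are merged recursively, so an actual implementation may need keep.null = TRUE or a flat replacement instead.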

Value

A list of training metrics.

Examples


library(deepG)
library(keras)
# create dummy data and temp directories
path_train_1 <- tempfile()
path_train_2 <- tempfile()
path_val_1 <- tempfile()
path_val_2 <- tempfile()
path_checkpoint <- tempfile()
dir.create(path_checkpoint)
path_model_card <- tempfile()
dir.create(path_model_card)

for (current_path in c(path_train_1, path_train_2,
                       path_val_1, path_val_2)) {
  dir.create(current_path)
  create_dummy_data(file_path = current_path,
                    num_files = 3,
                    seq_length = 10,
                    num_seq = 5,
                    vocabulary = c("a", "c", "g", "t"))
}

# create model
model <- create_model_lstm_cnn(layer_lstm = 8, layer_dense = 2, maxlen = 5)

# train model
run_name <- 'test_run_1'
hist <- train_model(train_type = "label_folder",
                    run_name = run_name,
                    path_checkpoint = path_checkpoint,
                    model_card = list(path_model_card = path_model_card, description = 'test run'),
                    model = model,
                    path = c(path_train_1, path_train_2),
                    path_val = c(path_val_1, path_val_2),
                    batch_size = 8,
                    epochs = 3,
                    steps_per_epoch = 6,
                    vocabulary_label = c("label_1", "label_2"))

# resume training
resume_training_from_model_card(path_model_card = file.path(path_model_card, run_name))


GenomeNet/deepG documentation built on Dec. 24, 2024, 12:11 p.m.