cnn_learner: Cnn_learner

View source: R/cnn_learner.R

cnn_learner {fastai}    R Documentation

Cnn_learner

Description

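Build a convnet-style learner from 'dls' and 'arch'. The body of 'arch' (optionally with pretrained weights) is kept up to 'cut', a new head with 'n_out' outputs is attached, and the resulting model is wrapped in a Learner.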

Usage

cnn_learner(
  dls,
  arch,
  loss_func = NULL,
  pretrained = TRUE,
  cut = NULL,
  splitter = NULL,
  y_range = NULL,
  config = NULL,
  n_out = NULL,
  normalize = TRUE,
  opt_func = Adam(),
  lr = 0.001,
  cbs = NULL,
  metrics = NULL,
  path = NULL,
  model_dir = "models",
  wd = NULL,
  wd_bn_bias = FALSE,
  train_bn = TRUE,
  moms = list(0.95, 0.85, 0.95)
)

Arguments

dls

A DataLoaders object holding the training and validation data.

arch

A model architecture (for example resnet18()) whose body is used as the backbone.

loss_func

Loss function. If NULL, it is inferred from dls.

pretrained

Whether to load pretrained weights for arch.

cut

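Where to cut arch to form the body (an index or a function). If NULL, the cut point stored for the architecture is used, falling back to the last pooling layer.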

splitter

A function that takes the model and returns a list of parameter groups (or a single parameter group if there are no distinct groups). If NULL, the default splitter for the architecture is used.

y_range

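Optional two-element range (low, high). If given, a sigmoid scaled to this range is applied to the model's output, which is useful for bounded regression targets.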
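When y_range is set, fastai ends the head with a sigmoid scaled to that range, so every prediction falls strictly between low and high. The base-R sketch below reproduces that mapping; the helper name sigmoid_range_sketch is hypothetical and is not a function exported by the package.

```r
# Hypothetical helper: squash raw model outputs into (low, high) using the
# same formula as fastai's SigmoidRange layer: sigmoid(x) * (high - low) + low
sigmoid_range_sketch <- function(x, low, high) {
  # plogis() is the logistic sigmoid 1 / (1 + exp(-x))
  plogis(x) * (high - low) + low
}

raw <- c(-10, 0, 10)                 # unbounded head outputs
sigmoid_range_sketch(raw, 0, 5.5)    # the middle value maps to 2.75
```

Extreme raw outputs approach, but never reach, the bounds, which is why a range slightly wider than the true target range is often chosen.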

config

Optional list of extra arguments passed on when building the model (for example, options for the head).

n_out

Number of outputs of the final layer. If NULL, it is inferred from dls.

normalize

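Whether to add a Normalize transform to dls using the statistics the pretrained weights were trained on. This only takes effect when pretrained = TRUE and those statistics are known for arch.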
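For torchvision ResNets the statistics in question are the ImageNet per-channel mean and standard deviation, and normalization subtracts the mean and divides by the standard deviation channel by channel. A base-R sketch of that arithmetic, assuming an RGB image stored as a height x width x 3 array with values in [0, 1]; normalize_sketch is a hypothetical helper, not the package's transform.

```r
# ImageNet per-channel statistics (R, G, B) used by torchvision pretrained models
imagenet_mean <- c(0.485, 0.456, 0.406)
imagenet_std  <- c(0.229, 0.224, 0.225)

# Hypothetical helper: normalize an H x W x 3 array one channel at a time
normalize_sketch <- function(img, mean = imagenet_mean, std = imagenet_std) {
  for (ch in seq_len(dim(img)[3])) {
    img[, , ch] <- (img[, , ch] - mean[ch]) / std[ch]
  }
  img
}

img <- array(runif(4 * 4 * 3), dim = c(4, 4, 3))  # toy 4 x 4 RGB image
x <- normalize_sketch(img)                        # same shape, centred per channel
```

Using the pretraining statistics keeps the inputs on the scale the pretrained weights expect.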

opt_func

The function used to create the optimizer

lr

Default learning rate.

cbs

One callback or a list of callbacks to pass to the Learner.

metrics

Optional list of metrics; each can be a function or a Metric object.

path

The working folder. If NULL, the path of dls is used.

model_dir

Subfolder of path where models are saved and loaded (path and model_dir are combined to locate model files).

wd

Default weight decay used when training the model. If NULL, the optimizer's own default applies.

wd_bn_bias

Whether weight decay is also applied to BatchNorm layers and bias parameters.
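By default fastai applies weight decay in decoupled form: each affected parameter is multiplied by (1 - lr * wd) at every step, separately from the gradient update. With wd_bn_bias = FALSE, BatchNorm and bias parameters are skipped. The base-R sketch below shows only that decay step; the params list, its is_bn_or_bias flag and decay_step are hypothetical stand-ins, not the package's data structures.

```r
# Hypothetical parameter list: each entry holds a value and a flag saying
# whether it belongs to a BatchNorm layer or is a bias term.
params <- list(
  conv_weight = list(value = c(0.50, -0.20), is_bn_or_bias = FALSE),
  conv_bias   = list(value = c(0.10),        is_bn_or_bias = TRUE),
  bn_weight   = list(value = c(1.00, 1.00),  is_bn_or_bias = TRUE)
)

# Decoupled weight decay: p <- p * (1 - lr * wd), skipping BatchNorm/bias
# parameters unless wd_bn_bias is TRUE. The gradient step is not shown.
decay_step <- function(params, lr, wd, wd_bn_bias = FALSE) {
  lapply(params, function(p) {
    if (wd_bn_bias || !p$is_bn_or_bias) {
      p$value <- p$value * (1 - lr * wd)
    }
    p
  })
}

updated <- decay_step(params, lr = 0.001, wd = 0.01)
updated$conv_weight$value  # shrunk slightly
updated$conv_bias$value    # unchanged
```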

train_bn

Whether BatchNorm layers are trained even when they would otherwise be frozen according to splitter.

moms

The default momentum values (start, middle, end) used by Learner.fit_one_cycle.
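In the one-cycle policy, momentum moves opposite to the learning rate: it falls from the first value to the second while the learning rate warms up, then returns to the third. A base-R sketch of that schedule, assuming fastai's default pct_start = 0.25 and cosine annealing between the values; sched_cos and one_cycle_mom are hypothetical helpers, not functions from the package.

```r
# Cosine interpolation from `start` (pos = 0) to `end` (pos = 1)
sched_cos <- function(start, end, pos) {
  start + (1 + cos(pi * (1 - pos))) * (end - start) / 2
}

# Hypothetical helper: momentum at training progress `pct` (0 to 1)
one_cycle_mom <- function(pct, moms = c(0.95, 0.85, 0.95), pct_start = 0.25) {
  ifelse(pct < pct_start,
         sched_cos(moms[1], moms[2], pct / pct_start),
         sched_cos(moms[2], moms[3], (pct - pct_start) / (1 - pct_start)))
}

# momentum at the start, at the end of warm-up, and at the end of training
one_cycle_mom(c(0, 0.25, 1))
```

Lower momentum while the learning rate is high keeps the large steps from compounding; momentum rises again as the learning rate decays.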

Value

A Learner object.

Examples


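## Not run: 

library(fastai)

# download the MNIST sample dataset
URLs_MNIST_SAMPLE()

# data augmentation (no flipping: digits are not mirror-symmetric)
tfms = aug_transforms(do_flip = FALSE)
path = 'mnist_sample'
bs = 20

# build DataLoaders from the image folders
data = ImageDataLoaders_from_folder(path, batch_tfms = tfms, size = 26, bs = bs)

# ResNet-18 based learner that reports accuracy
learn = cnn_learner(data, resnet18(), metrics = accuracy, path = getwd())

## End(Not run)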


fastai documentation built on June 22, 2024, 11:15 a.m.