coef.citocnn: Retrieve parameters of a fitted CNN model

View source: R/cnn.R

coef.citocnn        R Documentation

Retrieve parameters of a fitted CNN model

Description

This function returns a list of the parameters (weights and biases) and buffers (e.g., the running mean and variance of batch normalization layers) that a neural network model created with the cnn function currently uses, i.e., those of its currently used model epoch.

Usage

## S3 method for class 'citocnn'
coef(object, ...)

Arguments

object

A model created by cnn.

...

Additional arguments (currently not used).

Value

A list with two components:

  • parameters: A list of the model's weights and biases for the currently used model epoch.

  • buffers: A list of buffers (e.g., running statistics) for the currently used model epoch.
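The two components can be inspected directly. The following is a minimal sketch and not one of the package's own examples. It reuses the internal helper cito:::simulate_shapes() from the Examples below and trains a deliberately small network for only two epochs on the CPU, purely for illustration. The exact names, number, and shapes of the entries depend on the architecture, so no particular output is assumed.

```r
# Sketch: look at the structure returned by coef() for a fitted CNN.
# Assumes cito and a working torch installation.
if (torch::torch_is_installed()) {
  library(cito)

  set.seed(42)
  shapes <- cito:::simulate_shapes(320, 28)  # internal helper, as in the Examples
  X <- shapes$data
  Y <- shapes$labels

  # Small architecture and very short training run, only for illustration
  architecture <- create_architecture(conv(5), maxPool(), linear(10))
  fit <- cnn(X, Y, architecture, loss = "softmax", epochs = 2,
             validation = 0.1, lr = 0.05, device = "cpu")

  w <- coef(fit)
  names(w)                           # component names of the returned list
  str(w$parameters, max.level = 1)   # weights and biases of the current model epoch
  str(w$buffers, max.level = 1)      # buffers such as batch-norm running statistics
}
```

Calling str() with max.level = 1 keeps the printout short while still showing one line per stored weight, bias, or buffer entry.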

Examples


if(torch::torch_is_installed()){
library(cito)

device <- if (torch::cuda_is_available()) "cuda" else "cpu"

set.seed(222)

## Data
shapes <- cito:::simulate_shapes(320, 28)
X <- shapes$data
Y <- shapes$labels

## Architecture
architecture <- create_architecture(conv(5), maxPool(), conv(5), maxPool(), linear(10))

## Build and train network
cnn.fit <- cnn(X, Y, architecture, loss = "softmax", epochs = 50,
               validation = 0.1, lr = 0.05, device = device)

# Parameters (weights and biases) and buffers of the fitted network
coef(cnn.fit)
}


citoverse/cito documentation built on Jan. 16, 2025, 11:49 p.m.