downscaleTrain.keras: Train a deep model with keras in the climate4R framework

View source: R/downscaleTrain.keras.R

Description

Train a deep learning model with keras within the climate4R framework, using the data prepared by prepareData.keras.

Usage

downscaleTrain.keras(
  obj,
  model,
  compile.args = list(object = model),
  fit.args = list(object = model),
  clear.session = FALSE
)

Arguments

obj

The object as returned by prepareData.keras.

model

A keras sequential or functional model.

compile.args

List of arguments passed to the keras compile function, such as the loss function or the optimizer. For example: compile.args = list("loss" = "mse", "optimizer" = optimizer_adam(learning_rate = 0.0001)). Arguments that are not supplied take the defaults of the official keras compile function, except for the loss and the optimizer, which default to loss = "mse" and optimizer = optimizer_adam().

fit.args

List of arguments passed to the keras fit function; see the R documentation of fit for the available arguments. For example: fit.args = list("batch_size" = 100, "epochs" = 50, "validation_split" = 0.1). Arguments that are not supplied take the defaults of the official keras fit function.

clear.session

A logical value indicating whether to destroy the current TensorFlow graph and clear the model from memory after training, by calling k_clear_session(). If FALSE (the default), the trained model is returned. If TRUE, k_clear_session() is applied and no model is returned.
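The way compile.args and fit.args are completed before training can be sketched in base R. This is a minimal illustration under stated assumptions, not the package's source: it assumes that user-supplied entries are kept, that missing loss and optimizer entries receive the documented defaults, and that the model and the training data stored in obj (the x.global and y$Data slots read in the Examples section) are added before keras is called. The helpers fill_compile_args and fill_fit_args are hypothetical, and the string "adam" stands in for optimizer_adam() so the sketch runs without keras.

```r
# Hypothetical helpers (NOT part of downscaleR.keras) illustrating how the
# argument lists could be completed before being handed to keras.

# compile.args: keep user entries, fill in the documented defaults.
fill_compile_args <- function(compile.args, model) {
  compile.args[["object"]] <- model
  if (is.null(compile.args[["loss"]])) compile.args[["loss"]] <- "mse"
  # "adam" stands in for optimizer_adam() so this sketch runs without keras
  if (is.null(compile.args[["optimizer"]])) compile.args[["optimizer"]] <- "adam"
  compile.args
}

# fit.args: keep user entries, add the model and the training data
# stored in the object returned by prepareData.keras (assumed slots).
fill_fit_args <- function(fit.args, obj, model) {
  fit.args[["object"]] <- model
  fit.args[["x"]] <- obj$x.global   # predictors, samples along the first dimension
  fit.args[["y"]] <- obj$y$Data     # predictand, samples along the first dimension
  fit.args
}

# Mock object with the slots used in the Examples section (random values)
set.seed(1)
obj <- list(x.global = array(rnorm(20 * 5 * 6 * 3), dim = c(20, 5, 6, 3)),
            y = list(Data = matrix(rnorm(20 * 11), nrow = 20)))

cargs <- fill_compile_args(list(loss = "mae"), model = "my_model")
fargs <- fill_fit_args(list(epochs = 30, batch_size = 100), obj, model = "my_model")
str(cargs)
names(fargs)
```

Under these assumptions, a user-supplied loss such as "mae" is kept while the optimizer falls back to Adam, and any x or y entries placed in fit.args would be overwritten by the data in obj.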

Details

This function relies on keras, a high-level neural network API capable of running on top of TensorFlow, CNTK or Theano. Official keras tutorials explain how to build deep learning models; we recommend that users, especially beginners, consult them before using downscaleTrain.keras.

Value

The trained keras model. If clear.session = TRUE, no model is returned.

Author(s)

J. Bano-Medina

See Also

downscalePredict.keras for predicting with a keras model
prepareNewData.keras for predictor preparation with new (test) data
prepareData.keras for predictor preparation of training data
downscaleR.keras Wiki

Other downscaling.functions: downscalePredict.keras(), relevanceMaps()

Examples


# Loading packages and data
require(climate4R.datasets)
require(transformeR)
require(keras)
require(downscaleR.keras)
data("VALUE_Iberia_tas")
y <- VALUE_Iberia_tas
data("NCEP_Iberia_hus850", "NCEP_Iberia_psl", "NCEP_Iberia_ta850")
x <- makeMultiGrid(NCEP_Iberia_hus850, NCEP_Iberia_psl, NCEP_Iberia_ta850)
# We standardize the predictors using transformeR function scaleGrid
x <- scaleGrid(x,type = "standardize") 
# Preparing the predictors
data <- prepareData.keras(x = x, y = y, 
                          first.connection = "conv",
                          last.connection = "dense",
                          channels = "last")

# Defining the keras model
# We define 3 hidden layers: 2 convolutional layers
# followed by a dense layer.
input_shape  <- dim(data$x.global)[-1]
output_shape  <- dim(data$y$Data)[2]
inputs <- layer_input(shape = input_shape)
hidden <- inputs %>% 
  layer_conv_2d(filters = 25, kernel_size = c(3,3), activation = 'relu') %>%  
  layer_conv_2d(filters = 10, kernel_size = c(3,3), activation = 'relu') %>% 
  layer_flatten() %>% 
  layer_dense(units = 10, activation = "relu")
outputs <- layer_dense(hidden,units = output_shape)
model <- keras_model(inputs = inputs, outputs = outputs)

# We can print model in console to observe its configuration
model

# Training the deep learning model
model <- downscaleTrain.keras(data,
                              model = model,
                              compile.args = list("loss" = "mse",
                                                  "optimizer" = optimizer_adam(learning_rate = 0.01)),
                              fit.args = list("epochs" = 30, "batch_size" = 100))

# Training a deep learning model with an early-stopping criterion,
# saving the best model (lowest validation loss) with a checkpoint callback.
# Note that 'model' already holds the weights learned above,
# so training resumes from them.
downscaleTrain.keras(data,
                     model = model,
                     compile.args = list("loss" = "mse",
                                         "optimizer" = optimizer_adam(learning_rate = 0.01)),
                     fit.args = list("epochs" = 50, "batch_size" = 100,
                                     "validation_split" = 0.1,
                                     "callbacks" = list(
                                       callback_early_stopping(patience = 10),
                                       callback_model_checkpoint(filepath = file.path(getwd(), "model.h5"),
                                                                 monitor = "val_loss",
                                                                 save_best_only = TRUE))),
                     clear.session = TRUE)
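The early-stopping criterion used in the last example can be illustrated with a simplified base R re-implementation. This is for illustration only: early_stop_epoch is a hypothetical helper, not a keras function, and it ignores min_delta and the other callback options. It follows the documented meaning of patience in callback_early_stopping, which monitors val_loss by default: training stops once the monitored loss has not improved on its best value for patience consecutive epochs. Because the checkpoint callback uses save_best_only = TRUE, the saved file holds the model from the best epoch rather than the last one.

```r
# Simplified illustration of the early-stopping rule (min_delta = 0):
# returns the epoch after which training would stop.
early_stop_epoch <- function(val_loss, patience) {
  best <- Inf
  wait <- 0
  for (epoch in seq_along(val_loss)) {
    if (val_loss[epoch] < best) {
      best <- val_loss[epoch]  # new best: a checkpoint with save_best_only = TRUE would save here
      wait <- 0
    } else {
      wait <- wait + 1         # one more epoch without improvement
      if (wait >= patience) return(epoch)  # training stops after this epoch
    }
  }
  length(val_loss)             # criterion never met: all epochs run
}

# Toy validation-loss curve (made-up values)
val_loss <- c(1.0, 0.8, 0.7, 0.72, 0.71, 0.75, 0.69, 0.7, 0.73, 0.74)
early_stop_epoch(val_loss, patience = 3)
early_stop_epoch(val_loss, patience = 4)
```

With patience = 3 the toy curve stops at epoch 6, three epochs after the best value at epoch 3; with patience = 4 the improvement at epoch 7 resets the counter and all 10 epochs run.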


SantanderMetGroup/downscaleR.keras documentation built on July 7, 2023, 1:22 p.m.