createDenseUnetModel2D: 2-D implementation of the dense U-net deep learning...

View source: R/createDenseUnetModel.R

createDenseUnetModel2D    R Documentation

2-D implementation of the dense U-net deep learning architecture.

Description

Creates a keras model of the dense U-net deep learning architecture for image segmentation.

Usage

createDenseUnetModel2D(
  inputImageSize,
  numberOfOutputs = 1L,
  numberOfLayersPerDenseBlock = c(6, 12, 36, 24),
  growthRate = 48,
  initialNumberOfFilters = 96,
  reductionRate = 0,
  depth = 7,
  dropoutRate = 0,
  weightDecay = 0.0001,
  mode = c("classification", "regression")
)

Arguments

inputImageSize

Specifies the shape of the input tensor: the image dimensions followed by the number of channels (e.g., 3 for red, green, and blue, or 1 for a grayscale image).
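As a quick illustration, inputImageSize can be read off an array that is already laid out as rows x columns x channels. The arrays below are placeholder data, and the variable names are hypothetical.

```r
# Placeholder 2-D RGB image laid out as rows x columns x channels
rgbImage <- array( 0, dim = c( 128L, 96L, 3L ) )
inputImageSize <- dim( rgbImage )                  # c( 128, 96, 3 )

# A single-channel (grayscale) 2-D image needs an explicit channel count of 1
grayImage <- matrix( 0, nrow = 128L, ncol = 96L )
grayInputImageSize <- c( dim( grayImage ), 1L )    # c( 128, 96, 1 )
```

The grayscale case mirrors the Examples section, where the model is created with c( dim( domainImage ), 1L ).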

numberOfOutputs

Meaning depends on the mode. For 'classification' this is the number of segmentation labels. For 'regression' this is the number of outputs.
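In 'classification' mode the network predicts one channel per segmentation label, so integer label maps are typically one-hot encoded with the labels along the last axis before training; the Examples section does this with encodeUnet(). The base-R function below, oneHotEncodeSketch(), is a hypothetical stand-in that illustrates this assumed layout; it is not ANTsRNet's implementation.

```r
# Hypothetical helper: one-hot encode a label array, appending a label axis last
oneHotEncodeSketch <- function( labelArray, segmentationLabels )
{
  inputDims <- if( is.null( dim( labelArray ) ) ) length( labelArray ) else dim( labelArray )
  numberOfVoxels <- length( labelArray )
  encoded <- array( 0, dim = c( inputDims, length( segmentationLabels ) ) )
  for( j in seq_along( segmentationLabels ) )
    {
    # R arrays are column-major, so channel j occupies one contiguous block
    channelIndices <- ( j - 1 ) * numberOfVoxels + seq_len( numberOfVoxels )
    encoded[channelIndices] <- as.numeric( labelArray == segmentationLabels[j] )
    }
  encoded
}

labelMap <- matrix( c( 1, 2, 3, 1 ), nrow = 2 )
encoded <- oneHotEncodeSketch( labelMap, c( 1, 2, 3 ) )
dim( encoded )   # 2 2 3
```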

numberOfLayersPerDenseBlock

number of layers in each dense block; the length of this vector sets the number of dense blocks (default = c(6, 12, 36, 24)).

growthRate

number of filters to add for each dense block layer (default = 48).

initialNumberOfFilters

number of filters before the first dense block (default = 96).

reductionRate

reduction factor applied to the number of filters in the transition blocks (default = 0, i.e., no reduction).
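To see how growthRate, initialNumberOfFilters, and reductionRate interact, the sketch below tracks the number of feature maps through the encoder using the standard DenseNet convention: each layer in a dense block concatenates growthRate new feature maps, and each transition block between dense blocks scales the count by 1 - reductionRate. That this function follows exactly this rule is an assumption, and denseFeatureMapCountsSketch() is a hypothetical helper.

```r
# Hypothetical helper: number of feature maps at the output of each dense block,
# assuming DenseNet-style growth and transition compression of (1 - reductionRate)
denseFeatureMapCountsSketch <- function( numberOfLayersPerDenseBlock = c( 6, 12, 36, 24 ),
  growthRate = 48, initialNumberOfFilters = 96, reductionRate = 0 )
{
  compression <- 1 - reductionRate
  numberOfFilters <- initialNumberOfFilters
  numberOfDenseBlocks <- length( numberOfLayersPerDenseBlock )
  counts <- numeric( numberOfDenseBlocks )
  for( i in seq_len( numberOfDenseBlocks ) )
    {
    # Every layer in the block concatenates growthRate new feature maps
    numberOfFilters <- numberOfFilters + numberOfLayersPerDenseBlock[i] * growthRate
    counts[i] <- numberOfFilters
    # A transition block follows every dense block except the last one
    if( i < numberOfDenseBlocks )
      {
      numberOfFilters <- floor( numberOfFilters * compression )
      }
    }
  counts
}

denseFeatureMapCountsSketch()                        # 384 960 2688 3840
denseFeatureMapCountsSketch( reductionRate = 0.5 )   # 384 768 2112 2208
```

With the default reductionRate = 0 no compression takes place, so the feature-map count grows quickly with depth; a value such as 0.5 keeps the deeper blocks considerably narrower.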

depth

number of layers; must equal 3 * N + 4 for some integer N (default = 7).
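The constraint on depth is easy to check before building a model. isValidDepthSketch() is a hypothetical helper; it also requires N >= 1, which is an assumption rather than a documented rule.

```r
# Hypothetical helper: is depth of the form 3 * N + 4 for an integer N >= 1?
isValidDepthSketch <- function( depth )
{
  depth >= 7 & ( depth - 4 ) %% 3 == 0
}

isValidDepthSketch( c( 7, 8, 10, 40 ) )   # TRUE FALSE TRUE TRUE
( 7 - 4 ) / 3                             # N = 1 for the default depth
```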

dropoutRate

dropout rate (default = 0).
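dropoutRate is the fraction of activations zeroed during training. The base-R sketch below illustrates inverted dropout, the variant Keras dropout layers use, in which surviving activations are rescaled by 1 / (1 - dropoutRate) so their expected value is unchanged. applyDropoutSketch() is hypothetical and only illustrates the arithmetic.

```r
# Hypothetical helper illustrating inverted dropout on a numeric vector
applyDropoutSketch <- function( x, dropoutRate = 0, training = TRUE )
{
  stopifnot( dropoutRate >= 0, dropoutRate < 1 )
  if( !training || dropoutRate == 0 )
    {
    return( x )   # dropout is a no-op at inference time or when the rate is 0
    }
  keep <- runif( length( x ) ) >= dropoutRate
  x * keep / ( 1 - dropoutRate )
}

set.seed( 1 )
applyDropoutSketch( rep( 1, 8 ), dropoutRate = 0.5 )   # entries are 0 or 2
```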

weightDecay

weight decay (default = 1e-4).
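Weight decay is commonly implemented in Keras models as an L2 kernel penalty that adds weightDecay times the sum of squared weights to the loss. Assuming this model applies weightDecay that way (R/createDenseUnetModel.R is the place to confirm), the penalty works out as in the hypothetical base-R helper below.

```r
# Hypothetical helper: L2 penalty over a list of weight arrays,
# assuming weightDecay is applied as an L2 regularization factor
l2PenaltySketch <- function( weightList, weightDecay = 1e-4 )
{
  sumOfSquares <- sum( vapply( weightList, function( w ) sum( w^2 ), numeric( 1 ) ) )
  weightDecay * sumOfSquares
}

weights <- list( c( 1, 2 ), matrix( 1, nrow = 2, ncol = 2 ) )
l2PenaltySketch( weights, weightDecay = 0.5 )   # 0.5 * ( 1 + 4 + 4 ) = 4.5
```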

mode

A switch determining the activation function of the output layer. If mode = "classification", the activation is sigmoid or softmax, depending on numberOfOutputs; if mode = "regression", the activation is linear.
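For reference, the output activations act on each voxel's channel scores as follows: softmax normalizes across the numberOfOutputs channels so they sum to 1, sigmoid maps each channel to (0, 1) independently, and linear leaves the scores unchanged. The helper names are hypothetical, and the sketch does not encode the exact rule the function uses to choose between sigmoid and softmax.

```r
# Hypothetical helpers: output activations applied to one voxel's channel scores
sigmoidSketch <- function( x ) 1 / ( 1 + exp( -x ) )
softmaxSketch <- function( x )
{
  e <- exp( x - max( x ) )   # subtract the max for numerical stability
  e / sum( e )
}
linearSketch <- function( x ) x

scores <- c( 2, 0, -1 )    # e.g., scores for three segmentation labels
softmaxSketch( scores )    # non-negative values that sum to 1
sigmoidSketch( 0 )         # 0.5
```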

Details

X. Li, H. Chen, X. Qi, Q. Dou, C.-W. Fu, P.-A. Heng. H-DenseUNet: Hybrid Densely Connected UNet for Liver and Tumor Segmentation from CT Volumes

available here:

    https://arxiv.org/pdf/1709.07330.pdf

with the author's implementation available at:

    https://github.com/xmengli999/H-DenseUNet

Value

a DenseUnet keras model

Author(s)

Tustison NJ

Examples

library( ANTsR )
library( ANTsRNet )
library( keras )

imageIDs <- c( "r16", "r27", "r30", "r62", "r64", "r85" )
trainingBatchSize <- length( imageIDs )

# Perform simple 3-tissue segmentation.

segmentationLabels <- c( 1, 2, 3 )
numberOfLabels <- length( segmentationLabels )
initialization <- paste0( 'KMeans[', numberOfLabels, ']' )

domainImage <- antsImageRead( getANTsRData( imageIDs[1] ) )

X_train <- array( data = NA, dim = c( trainingBatchSize, dim( domainImage ), 1 ) )
Y_train <- array( data = NA, dim = c( trainingBatchSize, dim( domainImage ) ) )

images <- list()
segmentations <- list()

for( i in seq_len( trainingBatchSize ) )
  {
  cat( "Processing image", imageIDs[i], "\n" )
  image <- antsImageRead( getANTsRData( imageIDs[i] ) )
  mask <- getMask( image )
  segmentation <- atropos( image, mask, initialization )$segmentation

  X_train[i,,, 1] <- as.array( image )
  Y_train[i,,] <- as.array( segmentation )
  }
Y_train <- encodeUnet( Y_train, segmentationLabels )

# Perform a simple normalization

X_train <- ( X_train - mean( X_train ) ) / sd( X_train )

# Create the model

model <- createDenseUnetModel2D( c( dim( domainImage ), 1L ),
  numberOfOutputs = numberOfLabels )

metric_multilabel_dice_coefficient <-
  custom_metric( "multilabel_dice_coefficient",
    multilabel_dice_coefficient )

loss_dice <- function( y_true, y_pred ) {
  -multilabel_dice_coefficient(y_true, y_pred)
}
attr(loss_dice, "py_function_name") <- "multilabel_dice_coefficient"

model %>% compile( loss = loss_dice,
  optimizer = optimizer_adam( learning_rate = 0.0001 ),
  metrics = c( metric_multilabel_dice_coefficient,
    metric_categorical_crossentropy ) )

# The remaining steps are commented out due to Travis build constraints.

# Fit the model

# checkpoint_file = tempfile(fileext = ".h5")
# track <- model %>% fit( X_train, Y_train,
#              epochs = 5, batch_size = 4, verbose = 1, shuffle = TRUE,
#              callbacks = list(
#                callback_model_checkpoint( checkpoint_file,
#                    monitor = 'val_loss', save_best_only = TRUE ),
#                callback_reduce_lr_on_plateau( monitor = "val_loss", factor = 0.1 )
#              ),
#              validation_split = 0.2 )


# Save the model and/or save the model weights

# save_model_hdf5( model, filepath = 'unetModel.h5' )
# save_model_weights_hdf5( model, filepath = 'unetModelWeights.h5' )
rm(model); gc()
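The example trains with loss_dice, the negated multilabel Dice coefficient, so maximizing overlap minimizes the loss. For intuition, the hypothetical base-R function below computes a per-label Dice score, 2 * sum( A * B ) / ( sum( A ) + sum( B ) ), on one-hot arrays with labels along the last axis and averages the scores over labels. ANTsRNet's multilabel_dice_coefficient() operates on tensors and may differ in details such as smoothing and how labels are combined, so this is only an illustration.

```r
# Hypothetical helper: mean per-label Dice score for one-hot arrays whose
# last axis indexes the labels (an illustration, not ANTsRNet's implementation)
multilabelDiceSketch <- function( yTrue, yPred )
{
  arrayDims <- dim( yTrue )
  numberOfLabels <- arrayDims[length( arrayDims )]
  numberOfVoxels <- length( yTrue ) / numberOfLabels
  dice <- numeric( numberOfLabels )
  for( j in seq_len( numberOfLabels ) )
    {
    idx <- ( j - 1 ) * numberOfVoxels + seq_len( numberOfVoxels )
    denominator <- sum( yTrue[idx] ) + sum( yPred[idx] )
    # Treat a label absent from both arrays as perfect agreement
    dice[j] <- if( denominator == 0 ) 1 else 2 * sum( yTrue[idx] * yPred[idx] ) / denominator
    }
  mean( dice )
}

truth <- array( c( 1, 0, 0, 1,  0, 1, 1, 0 ), dim = c( 2, 2, 2 ) )
multilabelDiceSketch( truth, truth )       # 1
multilabelDiceSketch( truth, 1 - truth )   # 0
```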

ANTsX/ANTsRNet documentation built on April 23, 2024, 1:24 p.m.