R/neuralTransferStyle.R

#' Neural style transfer
#'
#' The popular neural style transfer described here:
#'
#'     https://arxiv.org/abs/1508.06576 and https://arxiv.org/abs/1605.04603
#'
#' and taken from François Chollet's implementation
#'
#'     https://keras.io/examples/generative/neural_style_transfer/
#'
#' and titu1994's modifications:
#'
#'     https://github.com/titu1994/Neural-Style-Transfer
#'
#' in order to enable experimentation with style transfer on medical images.
#'
#' @param contentImage ANTsImage (1 or 3-component).  Content (or base) image.
#' @param styleImages ANTsImage or list of ANTsImages as the style (or reference)
#' image.
#' @param initialCombinationImage ANTsImage (1 or 3-component).  Starting point
#' for the optimization.  Allows one to start from the output from a previous
#' run.  Otherwise, start from the content image. Note that the original paper
#' starts with a noise image.
#' @param numberOfIterations Number of gradient steps taken during optimization.
#' @param learningRate Parameter for Adam optimization.
#' @param totalVariationWeight Weight of the total variation regularization
#' term, which keeps the features of the output image locally coherent.
#' @param contentWeight Weight of the content layers in the optimization function.
#' @param styleImageWeights float or vector of floats.  Weights of the style term
#' in the optimization function for each style image.  Can either specify a
#' single scalar to be used for all the images or one for each image.  The
#' style term computes the sum of the L2 norms between the Gram matrices of the
#' specified layers (of an ImageNet-trained VGG19) for the style and combination
#' images.
#' @param contentLayerNames vector of strings. Names of VGG layers from which
#' to compute the content loss.
#' @param styleLayerNames vector of strings. Names of VGG layers from which to
#' compute the style loss.  If "all", the layers used are c('block1_conv1',
#' 'block1_conv2', 'block2_conv1', 'block2_conv2', 'block3_conv1', 'block3_conv2',
#' 'block3_conv3', 'block3_conv4', 'block4_conv1', 'block4_conv2', 'block4_conv3',
#' 'block4_conv4', 'block5_conv1', 'block5_conv2', 'block5_conv3', 'block5_conv4').
#' This is a proposed improvement from https://arxiv.org/abs/1605.04603.  In the
#' original implementation, the layers used are: c('block1_conv1', 'block2_conv1',
#' 'block3_conv1', 'block4_conv1', 'block5_conv1').
#' @param contentMask an ANTsImage mask to specify the region for content consideration.
#' @param styleMasks ANTsImage masks to specify the region for style consideration.
#' @param useShiftedActivations boolean to determine whether or not to use shifted
#' activations in calculating the Gram matrix (improvement mentioned in
#' https://arxiv.org/abs/1605.04603).
#' @param useChainedInference boolean corresponding to another proposed improvement
#' from https://arxiv.org/abs/1605.04603.
#' @param verbose boolean to print progress to the screen.
#' @param outputPrefix If specified, outputs a png image to disk at each iteration.
#' @return ANTs 3-component image.
#' @author Tustison, NJ
#' @examples
#' \dontrun{
#' library( ANTsRNet )
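#'
#' # A minimal sketch; the file names are hypothetical placeholders and any
#' # two same-sized 2-D images will work.
#' contentImage <- antsImageRead( "content.png" )
#' styleImage <- antsImageRead( "style.png" )
#' combinationImage <- neuralStyleTransfer( contentImage, styleImage,
#'   numberOfIterations = 50, verbose = TRUE )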
#'
#' }
#' @export
neuralStyleTransfer <- function(contentImage, styleImages,
  initialCombinationImage = NULL, numberOfIterations = 10,
  learningRate = 1.0, totalVariationWeight = 8.5e-5, contentWeight = 0.025,
  styleImageWeights = 1.0, contentLayerNames = c( 'block5_conv2' ),
  styleLayerNames = "all", contentMask = NULL, styleMasks = NULL,
  useShiftedActivations = TRUE, useChainedInference = TRUE,
  verbose = FALSE, outputPrefix = NULL )
{

  K <- keras::backend()
  tf <- tensorflow::tf

  preprocessAntsImage <- function( image, doScaleAndCenter = TRUE )
    {
    imageArray <- array( data = 0, dim = c( 1, dim( image ), 3 ) )
    if( image@components == 1 )
      {
      imageArray[1,,,1] <- as.array( image )
      imageArray[1,,,2] <- as.array( image )
      imageArray[1,,,3] <- as.array( image )
      } else if( image@components == 3 ) {
      imageChannels <- splitChannels( image )
      imageArray[1,,,1] <- as.array( imageChannels[[1]] )
      imageArray[1,,,2] <- as.array( imageChannels[[2]] )
      imageArray[1,,,3] <- as.array( imageChannels[[3]] )
      } else {
      stop( "Unexpected number of components." )
      }

    if( doScaleAndCenter == TRUE )
      {
      for( i in seq.int( 3 ) )
        {
        imageArray[1,,,i] <- ( imageArray[1,,,i] - min( imageArray[1,,,i] ) ) /
          ( max( imageArray[1,,,i] ) - min( imageArray[1,,,i] ) )
        }
      imageArray <- imageArray * 255
      # RGB -> BGR
      imageArray <- imageArray[,,,rev( seq.int( 3 ) ), drop = FALSE]
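      # Subtract the ImageNet channel means (in BGR order); together with the
      # RGB -> BGR flip above, this matches the preprocessing convention of the
      # ImageNet-trained VGG19 weights used below.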
      imageArray[1,,,1] <- imageArray[1,,,1] - 103.939
      imageArray[1,,,2] <- imageArray[1,,,2] - 116.779
      imageArray[1,,,3] <- imageArray[1,,,3] - 123.68
      }
    return( imageArray )
    }

  postProcessArray <- function( imageArray, referenceImage )
    {
    imageArray <- drop( imageArray )
    imageArray[,,1] <- imageArray[,,1] + 103.939
    imageArray[,,2] <- imageArray[,,2] + 116.779
    imageArray[,,3] <- imageArray[,,3] + 123.68
    # BGR -> RGB
    imageArray <- imageArray[,,rev( seq.int( 3 ) ), drop = FALSE]
    imageArray[imageArray < 0] <- 0
    imageArray[imageArray > 255] <- 255
 
    imageChannels <- list()
    imageChannels[[1]] <- as.antsImage( drop( imageArray[,,1] ), reference = referenceImage )  
    imageChannels[[2]] <- as.antsImage( drop( imageArray[,,2] ), reference = referenceImage )  
    imageChannels[[3]] <- as.antsImage( drop( imageArray[,,3] ), reference = referenceImage )  

    image <- mergeChannels( imageChannels )
    return( image )
    }

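  # Gram matrix of a feature map:  with F the (channels x positions)
  # flattening of x, the Gram matrix is F %*% t( F ), i.e., the matrix of
  # inner products between channel activations.  For example, with
  # F <- matrix( 1:6, nrow = 2 ), F %*% t( F ) is the 2 x 2 matrix of
  # channel correlations.  Subtracting 1 from the activations first is the
  # "shifted activations" variant of https://arxiv.org/abs/1605.04603.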
  gramMatrix <- function( x, shiftedActivations = FALSE )
    {
    F <- K$batch_flatten( K$permute_dimensions( x, c( 2L, 0L, 1L ) ) )
    if( shiftedActivations )
      {
      F <- F - 1
      }
    gram <- K$dot( F, K$transpose( F ) )
    return( gram )
    }

  processMask <- function( mask, shape )
    {
    # Resize the mask to the feature layer's spatial dimensions and replicate
    # it across the feature channels.  shape is the (height, width, channels)
    # TensorShape of the corresponding feature tensor.
    maskProcessed <- tf$image$resize( mask, size = c( shape[[1]], shape[[2]] ),
      method = tf$image$ResizeMethod$NEAREST_NEIGHBOR )
    maskProcessed <- as.array( maskProcessed )
    maskProcessedTensor <- array( data = 0, dim = c( shape[[1]], shape[[2]], shape[[3]] ) )
    for( i in seq.int( shape[[3]] ) )
      {
      maskProcessedTensor[,,i] <- maskProcessed[,,1]
      }
    return( maskProcessedTensor )
    }

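  # Style loss for one (style, combination) pair:  the squared Frobenius
  # distance between the two Gram matrices, normalized by
  # 4 * numberOfChannels^2 * size^2, following the Keras example on which
  # this function is based.  Optional masks restrict the loss to a region.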
  styleLoss <- function( styleFeatures, combinationFeatures, imageShape, styleMask = NULL, contentMask = NULL )
    {
    if( ! is.null( contentMask ) )
      {
      maskTensor <- K$variable( processMask( contentMask, combinationFeatures$shape ) )
      combinationFeatures <- combinationFeatures * K$stop_gradient( maskTensor )
      rm( maskTensor )
      }

    if( ! is.null( styleMask ) )
      {
      maskTensor <- K$variable( processMask( styleMask, styleFeatures$shape ) )
      styleFeatures <- styleFeatures * K$stop_gradient( maskTensor )
      if( ! is.null( contentMask ) )
        {
        combinationFeatures <- combinationFeatures * K$stop_gradient( maskTensor )
        }
      rm( maskTensor )
      }
    styleGram <- gramMatrix( styleFeatures, useShiftedActivations )
    combinationGram <- gramMatrix( combinationFeatures, useShiftedActivations )
    size <- imageShape[1] * imageShape[2]
    numberOfChannels <- 3
    loss <- tf$reduce_sum( tf$square( styleGram - combinationGram ) ) /
      ( 4.0 * numberOfChannels^2 * size^2 )
    return( loss )
    }

  contentLoss <- function( contentFeatures, combinationFeatures )
    {
    loss <- tf$reduce_sum( tf$square( contentFeatures - combinationFeatures ) )
    return( loss )
    }

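  # Total variation regularizer:  the sum over pixels of
  # ( ( x[i, j] - x[i + 1, j] )^2 + ( x[i, j] - x[i, j + 1] )^2 )^1.25,
  # penalizing differences between neighboring pixels so that the
  # combination image stays locally coherent.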
  totalVariationLoss <- function( x )
    {
    shape <- x$shape
    a <- tf$square( x[, 1:( shape[[2]] - 1L ), 1:( shape[[3]] - 1L ),] - x[, 2:shape[[2]], 1:( shape[[3]] - 1L ),] )
    b <- tf$square( x[, 1:( shape[[2]] - 1L ), 1:( shape[[3]] - 1L ),] - x[, 1:( shape[[2]] - 1L ), 2:shape[[3]],] )
    loss <- tf$reduce_sum( tf$pow( a + b, 1.25 ) )
    return( loss )
    }

  computeTotalLoss <- function( contentArray, styleArrayList, combinationTensor,
                               featureModel, contentLayerNames, styleLayerIndices,
                               imageShape, contentMaskTensor = NULL, styleMaskTensorList = NULL )
    {
    numberOfStyleImages <- length( styleArrayList )

    inputArrays <- list()
    inputArrays[[1]] <- contentArray
    for( i in seq.int( numberOfStyleImages ) )
      {
      inputArrays[[i + 1]] <- styleArrayList[[i]]
      }
    inputArrays[[2 + numberOfStyleImages]] <- combinationTensor
    inputTensor <- tf$concat( inputArrays, axis = 0L )

    features <- featureModel( inputTensor )
    # The model call returns an unnamed list of layer outputs, so attach the
    # layer names to allow the content layers to be looked up by name below.
    names( features ) <- unlist( lapply( model$layers, function( x ) x$name ) )

    totalLoss <- tf$zeros( shape = list() )

    # content loss
    for( i in seq.int( length( contentLayerNames ) ) )
      {
      layerFeatures <- features[[contentLayerNames[i]]]
      contentFeatures <- layerFeatures[1,,,]
      combinationFeatures <- layerFeatures[( numberOfStyleImages + 2 ),,,]
      totalLoss <- totalLoss + contentLoss( contentFeatures, combinationFeatures ) *
        contentWeight / as.numeric( length( contentLayerNames ) )
      }

    # style loss
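    # With chained inference (https://arxiv.org/abs/1605.04603), the style
    # term penalizes the difference between the style losses of adjacent
    # VGG layers rather than summing layers independently, with shallower
    # layer pairs down-weighted by the 2^( L - ( i + 1 ) ) denominator below.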
    if( useChainedInference )
      {
      for( i in seq_len( length( styleLayerIndices ) - 1 ) )
        {
        layerFeatures <- features[styleLayerIndices[i]][[1]]
        styleFeatures <- layerFeatures[2:( numberOfStyleImages + 1 ),,,]
        combinationFeatures <- layerFeatures[( numberOfStyleImages + 2 ),,,]
        loss <- list()
        for( j in seq.int( numberOfStyleImages ) )
          {
          if( is.null( styleMaskTensorList ) )
            {
            loss[[j]] <- styleLoss( styleFeatures[j,,,], combinationFeatures, imageShape,
                                    styleMask = NULL, contentMask = contentMaskTensor )
            } else {
            loss[[j]] <- styleLoss( styleFeatures[j,,,], combinationFeatures, imageShape,
                                    styleMask = styleMaskTensorList[[j]], contentMask = contentMaskTensor )
            }
          }

        layerFeatures <- features[styleLayerIndices[i + 1]][[1]]
        styleFeatures <- layerFeatures[2:( numberOfStyleImages + 1 ),,,]
        combinationFeatures <- layerFeatures[( numberOfStyleImages + 2 ),,,]
        lossP1 <- list()
        for( j in seq.int( numberOfStyleImages ) )
          {
          if( is.null( styleMaskTensorList ) )
            {
            lossP1[[j]] <- styleLoss( styleFeatures[j,,,], combinationFeatures, imageShape,
                                      styleMask = NULL, contentMask = contentMaskTensor )
            } else {
            lossP1[[j]] <- styleLoss( styleFeatures[j,,,], combinationFeatures, imageShape,
                                      styleMask = styleMaskTensorList[[j]], contentMask = contentMaskTensor )
            }
          }

        for( j in seq.int( numberOfStyleImages ) )
          {
          lossDifference <- loss[[j]] - lossP1[[j]]
          totalLoss <- totalLoss + styleImageWeights[j] * lossDifference /
            ( 2^( as.numeric( length( styleLayerNames ) - ( i + 1 ) ) ) )
          }
        }
      } else {
      for( i in seq.int( length( styleLayerIndices ) ) )
        {
        layerFeatures <- features[styleLayerIndices[i]][[1]]
        styleFeatures <- layerFeatures[2:( numberOfStyleImages + 1 ),,,]
        combinationFeatures <- layerFeatures[( numberOfStyleImages + 2 ),,,]
        loss <- list()
        for( j in seq.int( numberOfStyleImages ) )
          {
          if( is.null( styleMaskTensorList ) )
            {
            loss[[j]] <- styleLoss( styleFeatures[j,,,], combinationFeatures, imageShape,
                                    styleMask = NULL, contentMask = contentMaskTensor )
            } else {
            loss[[j]] <- styleLoss( styleFeatures[j,,,], combinationFeatures, imageShape,
                                    styleMask = styleMaskTensorList[[j]], contentMask = contentMaskTensor )
            }
          }
        for( j in seq.int( numberOfStyleImages ) )
          {
          totalLoss <- totalLoss + ( loss[[j]] * styleImageWeights[j] / 
            as.numeric( length( styleLayerIndices ) ) )
          }
        }
      }
    totalLoss <- totalLoss + totalVariationWeight * tf$cast( totalVariationLoss( combinationTensor ), tf$float32 )
    return( totalLoss )
    }

  computeLossAndGradients <- function( contentArray, styleArrayList, combinationTensor,
                featureModel, contentLayerNames, styleLayerIndices, imageShape,
                contentMaskTensor, styleMaskTensorList )
    {
    with( tf$GradientTape() %as% tape, {
      loss <- computeTotalLoss( contentArray, styleArrayList, combinationTensor,
                    featureModel, contentLayerNames, styleLayerIndices, imageShape, contentMaskTensor,
                    styleMaskTensorList )
      })
    gradients <- tape$gradient( loss, combinationTensor )
    list( loss, gradients )
    }


  styleImageList <- NULL
  if( is.list( styleImages ) )
    {
    styleImageList <- styleImages
    } else {
    styleImageList <- list( styleImages )
    }
  numberOfStyleImages <- length( styleImageList )

  for( i in seq.int( numberOfStyleImages ) )
    {
    if( styleImageList[[i]]@dimension != 2 )
      {
      stop( "Input style images must be 2-D." )
      }
    if( any( dim( styleImageList[[i]] ) != dim( contentImage ) ) )
      {
      stop( "Input images must have matching dimensions/shapes." )
      }
    }

  numberOfStyleMasks <- 0
  styleMaskTensorList <- NULL
  if( ! is.null( styleMasks ) )
    {
    styleMaskList <- NULL
    if( is.list( styleMasks ) )
      {
      styleMaskList <- styleMasks
      } else {
      styleMaskList <- list( styleMasks )
      }
    numberOfStyleMasks <- length( styleMaskList )

    styleMaskTensorList <- list()
    for( i in seq.int( numberOfStyleMasks ) )
      {
      styleMaskArray <- as.array( thresholdImage( styleMaskList[[i]], 0, 0, 0, 1 ) )
      styleMaskTensorList[[i]] <- array( data = styleMaskArray, dim = c( dim( styleMaskArray ), 1 ) )
      }
    }

  if( numberOfStyleMasks > 0 && numberOfStyleImages != numberOfStyleMasks )
    {
    stop( "The number of style images/masks are not the same." )
    }

  if( length( styleImageWeights ) == 1 )
    {
    styleImageWeights <- rep( styleImageWeights, length( styleImageList ) )
    } else if( length( styleImageWeights ) != length( styleImageList ) ) {
    stop( "Length of style weights must be 1 or the number of style images." )
    }

  if( contentImage@dimension != 2 )
    {
    stop( "Input content image must be 2-D." )
    }

  contentMaskTensor <- NULL
  if( ! is.null( contentMask ) )
    {
    contentMaskArray <- as.array( thresholdImage( contentMask, 0, 0, 0, 1 ) )
    contentMaskTensor <- array( data = contentMaskArray, dim = c( dim( contentMaskArray ), 1 ) )
    }

  if( length( styleLayerNames ) == 1 && styleLayerNames == "all" )
    {
    styleLayerNames <- c( 'block1_conv1', 'block1_conv2', 'block2_conv1',
          'block2_conv2', 'block3_conv1', 'block3_conv2', 'block3_conv3',
          'block3_conv4', 'block4_conv1', 'block4_conv2', 'block4_conv3',
          'block4_conv4', 'block5_conv1', 'block5_conv2', 'block5_conv3',
          'block5_conv4')
    }

  model <- tf$keras$applications$VGG19( weights = "imagenet", include_top = FALSE )

  styleLayerIndices <- c()
  for( i in seq.int( length( model$layers ) ) )
    {
    if( model$layers[[i]]$name %in% styleLayerNames )
      {
      styleLayerIndices <- append( styleLayerIndices, i )
      }
    }
  if( length( styleLayerIndices ) != length( styleLayerNames ) )
    {
    stop( "Style layer names don't match model." )
    }

  outputsList <- list()
  for( i in seq.int( length( model$layers ) ) )
    {
    outputsList[[i]] <- model$layers[[i]]$output
    }

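  # A single model mapping an input batch to every VGG19 layer's output so
  # that one forward pass yields content, style, and combination features at
  # all layers of interest.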
  featureModel <- tf$keras$Model( inputs = model$inputs, outputs = outputsList )

  # Preprocess data
  contentArray <- preprocessAntsImage( contentImage )
  styleArrayList <- list()
  for( i in seq.int( numberOfStyleImages ) )
    {
    styleArrayList[[i]] <- preprocessAntsImage( styleImageList[[i]] )
    }

  imageShape <- c( dim( contentArray )[2:3], 3 )

  combinationTensor <- NULL
  if( is.null( initialCombinationImage ) )
    {
    combinationTensor <- tf$Variable( array( data = contentArray, dim = dim( contentArray ) ), dtype = tf$float32 )
    } else {
    initialCombinationTensor <- preprocessAntsImage( initialCombinationImage, doScaleAndCenter = FALSE )
    combinationTensor <- tf$Variable( initialCombinationTensor, dtype = tf$float32 )
    }

  if( any( imageShape != c( dim( combinationTensor )[2:3], 3 ) ) )
    {
    stop( "Initial combination image size does not match content image." )
    }

  optimizer <- tf$optimizers$Adam( learning_rate = learningRate, beta_1 = 0.99, epsilon = 0.1 )

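  # Optimization loop:  each iteration differentiates the total loss with
  # respect to the combination image itself (the VGG19 weights stay fixed)
  # and takes one Adam step on that image.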
  for( i in seq.int( numberOfIterations ) )
    {
    startTime <- Sys.time()
    c( loss, gradients ) %<-% computeLossAndGradients( contentArray, styleArrayList,
                          combinationTensor, featureModel, contentLayerNames,
                          styleLayerIndices, imageShape, contentMaskTensor,
                          styleMaskTensorList )
    endTime <- Sys.time()
    if( verbose == TRUE )
      {
      cat( "Iteration ", i, " of ", numberOfIterations, ": total loss = ", 
        as.numeric( loss ), 
        " (elapsed time = ", endTime - startTime, "s)\n",  sep = "" )
      }
    optimizer$apply_gradients( list( tuple( gradients, combinationTensor ) ) )

    if( ! is.null( outputPrefix ) )
      {
      combinationArray <- as.array( combinationTensor )
      combinationImage <- postProcessArray( combinationArray, contentImage )
      combinationRgb <- antsImageClone( combinationImage, out_pixeltype = 'unsigned char' )
      antsImageWrite( combinationRgb, paste0( outputPrefix, "_iteration", i, ".png" ) )
      }
    }

  combinationArray <- as.array( combinationTensor )
  combinationImage <- postProcessArray( combinationArray, contentImage )
  return( combinationImage )
}