# R/TVLoss.R

#' Total Variation Loss
#'
#' Defines a PyTorch \code{nn.Module} named \code{TVLoss} in the embedded
#' Python session (via reticulate) and returns an instance of it.
#'
#' The module's \code{forward()} is the identity. Its \code{backward()} method
#' takes the layer input and the incoming gradient and returns
#' \code{strength * d/d(input) [0.5 * sum(x_diff^2 + y_diff^2)] + gradOutput},
#' where \code{x_diff} / \code{y_diff} are differences between each pixel and
#' its right / lower neighbour.
#'
#' @param tvWeight Strength multiplier applied to the total-variation
#'   gradient. Passed unchanged to the Python constructor.
#' @return The Python \code{TVLoss} object constructed with \code{tvWeight}.
#'
#' @export
#' @import reticulate
#' @references \insertRef{Gatys2016}{neuralstyleR}
#'
TVLoss <- function(tvWeight = 1e-3) {
  # The Python source must not contain double quotes or backslashes, since it
  # lives inside an R double-quoted string literal.
  tv_module <- py_run_string(
    "
import torch.nn as nn
import torch


class TVLoss(nn.Module):
    def __init__(self, strength):
        super(TVLoss, self).__init__()
        self.strength = strength
        self.x_diff = torch.Tensor()
        self.y_diff = torch.Tensor()

    def forward(self, input):
        # Identity pass-through; the penalty only acts through backward().
        self.output = input
        return self.output

    def backward(self, input, gradOutput):
        # NOTE(review): torch autograd does not call a Module-level backward();
        # confirm that the caller invokes this method explicitly.
        # Indexing below requires a 4-D input; the last two axes are the ones
        # differenced (presumably height and width - confirm with caller).
        grad = torch.zeros_like(input)

        # All pixels except the last row / last column.
        core = input[:, :, :-1, :-1]
        # Difference with the neighbour one step along the last axis.
        self.x_diff = core - input[:, :, :-1, 1:]
        # Difference with the neighbour one step along the third axis.
        self.y_diff = core - input[:, :, 1:, :-1]

        # Accumulate in place into views of the full-size buffer so that the
        # gradient keeps the same shape as the input.
        grad[:, :, :-1, :-1] += self.x_diff + self.y_diff
        grad[:, :, :-1, 1:] -= self.x_diff
        grad[:, :, 1:, :-1] -= self.y_diff

        self.gradInput = grad.mul(self.strength).add(gradOutput)
        return self.gradInput
"
  )

  tv_module$TVLoss(tvWeight)
}
# David-J-R/neuralstyleR documentation built on May 8, 2019, 1:54 p.m.