#' Total Variation Loss
#'
#' Builds a PyTorch `TVLoss` module (defined via reticulate) that acts as a
#' total-variation smoothness regularizer. Its `forward(input)` returns
#' `input` unchanged. Its `backward(input, gradOutput)` returns
#' `gradOutput + tvWeight * g`, where `g` is the gradient of one half the sum
#' of squared horizontal and vertical neighbour differences of `input`, taken
#' over its last two (height, width) dimensions. Any leading dimensions
#' (e.g. batch and channel) are handled elementwise.
#'
#' NOTE(review): torch autograd does not call `nn.Module.backward`; the TV
#' gradient is only applied when callers invoke `backward()` explicitly.
#'
#' @param tvWeight Numeric weight (strength) of the total-variation term.
#'   Defaults to `1e-3`.
#' @return A Python `TVLoss` object (a `torch.nn.Module` subclass).
#' @export
#' @import reticulate
#' @references \insertRef{Gatys2016}{neuralstyleR}
#'
TVLoss <- function(tvWeight = 1e-3) {
  # Python source must start at column 0: py_run_string executes it as
  # top-level module code.
  tvLossMain <- reticulate::py_run_string(
"
import torch
import torch.nn as nn


class TVLoss(nn.Module):
    # Total-variation regularizer: identity forward pass, TV gradient
    # produced by an explicitly called backward().

    def __init__(self, strength):
        super(TVLoss, self).__init__()
        self.strength = strength
        self.x_diff = torch.Tensor()
        self.y_diff = torch.Tensor()

    def forward(self, input):
        # Identity pass-through.
        self.output = input
        return self.output

    def backward(self, input, gradOutput):
        # Forward differences along width (x) and height (y) over the last
        # two dimensions; the ellipsis covers any leading dimensions.
        self.x_diff = input[..., :-1, :-1] - input[..., :-1, 1:]
        self.y_diff = input[..., :-1, :-1] - input[..., 1:, :-1]
        grad = torch.zeros_like(input)
        # Augmented assignment on slices updates grad in place, so it keeps
        # the full shape of input.
        grad[..., :-1, :-1] += self.x_diff + self.y_diff
        grad[..., :-1, 1:] -= self.x_diff
        grad[..., 1:, :-1] -= self.y_diff
        self.gradInput = grad * self.strength + gradOutput
        return self.gradInput
")
  tvLossMain$TVLoss(tvWeight)
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.