#' Gram matrix module
#'
#' Builds a PyTorch \code{nn.Module} through \pkg{reticulate} that computes
#' the Gram matrix of a 4-d input tensor.
#'
#' The Python class is defined with \code{\link[reticulate]{py_run_string}}
#' using its default \code{local = FALSE}, i.e. in Python's main module, so
#' every call (re)defines \code{GramMatrix}, \code{nn} and \code{torch` there.
#'
#' \code{forward(input)} requires \code{input$dim() == 4} and flattens it to
#' \code{C x (H * W)}. This only succeeds when the first (batch) dimension is
#' 1. It returns the \code{C x C} product \code{F F^T} and also stores it in
#' \code{output}.
#'
#' \code{backward(input, gradOutput)} returns \code{(gradOutput +
#' t(gradOutput)) F}, reshaped to the shape of \code{input}. This is the
#' gradient of \code{F F^T} with respect to \code{input}. It is also stored
#' in \code{gradInput}.
#'
#' @return A reticulate reference to a new Python \code{GramMatrix} instance.
#' @export
#' @import reticulate
#' @references \insertRef{Gatys2016}{neuralstyleR}
#'
GramMatrix <- function() {
  # Python top-level statements must start at column 0, so the embedded
  # source is deliberately not indented to match the surrounding R code.
  py_code <- "
import torch.nn as nn
import torch


class GramMatrix(nn.Module):
    def __init__(self):
        super(GramMatrix, self).__init__()

    def forward(self, input):
        assert input.dim() == 4
        C, H, W = input.size(1), input.size(2), input.size(3)
        # reshape (unlike view) also accepts non-contiguous tensors; it
        # still fails unless the batch dimension is 1.
        x_flat = input.reshape(C, H * W)
        self.output = torch.mm(x_flat, x_flat.t())
        return self.output

    def backward(self, input, gradOutput):
        assert input.dim() == 4
        C, H, W = input.size(1), input.size(2), input.size(3)
        x_flat = input.reshape(C, H * W)
        # gradient of F F^T w.r.t. F, contracted with gradOutput:
        # gradOutput F + gradOutput^T F
        grad_flat = torch.mm(gradOutput, x_flat) + torch.mm(gradOutput.t(), x_flat)
        self.gradInput = grad_flat.view_as(input)
        return self.gradInput
"
  main <- reticulate::py_run_string(py_code)
  main$GramMatrix()
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.