R/GramMatrix.R

#' GramMatrix
#'
#' Build a PyTorch module that computes the Gram matrix of a feature map
#' (the style representation of Gatys et al.).
#'
#' The Python class is defined with \code{py_run_string} in Python's
#' \code{__main__} module and subclasses \code{torch.nn.Module}.
#' Its \code{forward(input)} method requires a 4-dimensional tensor whose
#' batch dimension is 1, i.e. shape \code{(1, C, H, W)}. It flattens the
#' spatial dimensions into a \code{C x (H*W)} matrix \eqn{F} and returns
#' \eqn{G = F F^T} of shape \code{C x C}. The result is also stored on the
#' module as \code{output}. No normalisation (e.g. by \code{C*H*W}) is applied.
#'
#' \code{backward(input, gradOutput)} is a manual helper: it is not called by
#' autograd. It returns \eqn{(dG + dG^T) F}, reshaped to the shape of
#' \code{input}.
#'
#' @return A Python \code{GramMatrix} module instance.
#' @export
#' @import reticulate
#' @references \insertRef{Gatys2016}{neuralstyleR}
#'
GramMatrix <- function() {

  # Python source must start at column 0 inside the string literal; it must
  # not contain double quotes or backslashes (it lives in an R "..." string).
  GramMatrixMain <- py_run_string(
    "
import torch
import torch.nn as nn


class GramMatrix(nn.Module):
    def __init__(self):
        super(GramMatrix, self).__init__()

    def forward(self, input):
        assert input.dim() == 4
        C, H, W = input.size(1), input.size(2), input.size(3)
        # reshape (unlike view) also accepts non-contiguous tensors.
        # Requires batch size 1: numel must equal C * H * W.
        x_flat = input.reshape(C, H * W)
        self.output = torch.mm(x_flat, x_flat.t())
        return self.output

    def backward(self, input, gradOutput):
        # Manual gradient of G = F F^T w.r.t. the input; autograd does not
        # call this method, it already differentiates forward().
        assert input.dim() == 4
        C, H, W = input.size(1), input.size(2), input.size(3)
        x_flat = input.reshape(C, H * W)
        grad = torch.mm(gradOutput + gradOutput.t(), x_flat)
        self.gradInput = grad.reshape_as(input)
        return self.gradInput
")

  GramMatrixMain$GramMatrix()
}
David-J-R/neuralstyleR documentation built on May 8, 2019, 1:54 p.m.