# Chunk defaults for the rendered document: show source and output in one
# block and prefix each output line with "#>".
knitr::opts_chunk$set(collapse = TRUE, comment = "#>")
library(keras)
library(rspektral)
library(tfdatasets)
# Load the citation dataset. Judging by how the pieces are used further down:
# `A` is the graph matrix handed to preprocess_laplacian_mod(), `X` has one row
# per node and one column per feature, `y` has one column per class, and the
# three masks are passed as sample weights for training/validation/testing.
c(A, X, y, train_mask, val_mask, test_mask) %<-% dataset_citations()

# Hyperparameters
channels <- 16           # Number of channels in the first layer
dropout <- 0.5           # Dropout rate for the features
l2_reg <- 5e-4 / 2       # L2 regularization rate
learning_rate <- 1e-2    # Learning rate
epochs <- 300            # Number of training epochs
es_patience <- 10        # Patience for early stopping

# Problem dimensions, read off the data
N <- nrow(X)             # Number of nodes in the graph
# Called `n_features` rather than `F` so that base R's `F` (an alias for
# FALSE) is not masked for the rest of the script.
n_features <- ncol(X)    # Original size of node features
n_classes <- ncol(y)     # Number of classes

# Preprocessing operations
# `densify = FALSE` is paired with the `sparse = TRUE` filter input below.
fltr <- preprocess_laplacian_mod(A, densify = FALSE)
dim(fltr)
X <- as.matrix(X)

# Model inputs: one feature vector per node, and one row of the (sparse)
# filter per node.
X_in <- layer_input(shape = c(n_features))
fltr_in <- layer_input(shape = c(N), sparse = TRUE)

# Two-layer graph convolutional network. Each graph convolution receives the
# node features together with the filter input as a two-element list.

# Dropout on the raw node features.
dropout_1 <- layer_dropout(X_in, rate = dropout)

# First graph convolution: `channels` outputs, ReLU, L2 penalty on the kernel.
graph_conv_1 <- layer_graph_conv(
  list(dropout_1, fltr_in),
  channels,
  activation = "relu",
  kernel_regularizer = regularizer_l2(l2_reg),
  use_bias = FALSE
)

# Dropout again between the two convolutions.
dropout_2 <- layer_dropout(graph_conv_1, rate = dropout)

# Second graph convolution: one softmax output per class.
graph_conv_2 <- layer_graph_conv(
  list(dropout_2, fltr_in),
  n_classes,
  activation = "softmax",
  use_bias = FALSE
)

model <- keras_model(
  inputs = list(X_in, fltr_in),
  outputs = graph_conv_2
)

optimizer <- optimizer_adam(lr = learning_rate)

# `weighted_metrics` so that accuracy honours the sample weights (masks)
# supplied at fit/evaluate time.
compile(
  model,
  optimizer = optimizer,
  loss = "categorical_crossentropy",
  weighted_metrics = "acc"
)

# Show the model.
model
# Validation uses the same inputs and targets as training; only the mask
# (passed as the third element, the sample weights) differs.
validation_data <- list(list(X, fltr), y, val_mask)

# Train on the whole graph as a single batch of N nodes, weighting the loss
# by the training mask and stopping early after `es_patience` epochs.
history <- fit(
  model,
  x = list(X, fltr),
  y = y,
  sample_weight = train_mask,
  batch_size = N,
  epochs = epochs,
  validation_data = validation_data,
  shuffle = FALSE, # Shuffling data means shuffling the whole graph
  callbacks = callback_early_stopping(
    patience = es_patience,
    restore_best_weights = TRUE
  ),
  verbose = 0,
  view_metrics = TRUE
)

# Plot the recorded training history.
plot(history)
# Evaluate on the full graph in one batch, weighting by the test mask.
eval_results <- evaluate(
  model,
  list(X, fltr),
  y,
  sample_weight = test_mask,
  batch_size = N
)

eval_results
# Build a single-element dataset holding the whole graph as tensors.
# NOTE(review): `gen` is reassigned further down before it is ever passed to
# fit_generator(), so this dataset looks unused -- confirm before relying on it.
gen <- tensors_dataset(
  list(
    inputs = lapply(list(X, fltr), tf$convert_to_tensor),
    targets = tf$convert_to_tensor(y),
    sample_weights = tf$convert_to_tensor(train_mask)
  )
)

# Build a constant "generator": the three arguments are packed into a named
# list once, and the returned zero-argument function hands back that same
# list on every call.
#
# inputs, targets, sample_weights: stored under the list names `inputs`,
#   `targets` and `sample_weights` respectively.
# Returns: a function taking no arguments.
make_data_gen <- function(inputs, targets, sample_weights) {
  batch <- list(
    inputs = inputs,
    targets = targets,
    sample_weights = sample_weights
  )
  function() batch
}

# Zero-argument function that returns the full-graph batch (inputs, targets,
# sample weights) on every call.
genit <- make_data_gen(list(X, fltr), y, train_mask)

# Turn the R function into a generator object for fit_generator().
# NOTE(review): this reaches into keras internals via `:::`; fit_generator()
# may accept `genit` directly -- confirm against the installed keras version.
gen <- keras:::as_generator.function(genit)

# Only needed if the commented-out `validation_data` argument below is
# re-enabled.
validation_data <- list(list(X, fltr), y, val_mask)

# Continue training from the generator: 5 epochs of 100 steps, each step
# drawing the same full-graph batch.
model %>%
  fit_generator(generator = gen,
                steps_per_epoch = 100,
                epochs = 5,
                #validation_data = validation_data,
                verbose = 1,
                view_metrics = TRUE)


# Source: rdinnager/rspektral documentation, built on June 12, 2021, 1:26 a.m.