# Document-wide knitr chunk options: collapse each chunk's source and output
# into a single block, and prefix printed output lines with "#>".
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>"
)
# Attach the packages used below. Each call must be its own statement: two
# calls on one line without a separator do not parse.
library(keras)
library(rspektral)
# Load the citation dataset and unpack its six components in one step with the
# multi-assignment operator:
#   A          - graph structure handed to preprocess_chebyshev() below
#                (presumably the adjacency matrix; confirm against rspektral)
#   X          - node features (rows are nodes, columns are features)
#   y          - labels, one column per class
#   *_mask     - per-node weights selecting the train / validation / test nodes
c(
  A,
  X,
  y,
  train_mask,
  val_mask,
  test_mask
) %<-% dataset_citations()
# Hyperparameters and graph dimensions.
# Each assignment sits on its own line so that a trailing comment cannot
# swallow the assignments that follow it.
channels <- 16          # Number of channels in the first layer
K <- 2                  # Max degree of the Chebyshev polynomials
N <- dim(X)[1]          # Number of nodes in the graph
# NOTE(review): `F` shadows base R's shorthand for FALSE; the name is kept
# because the input layer definition reads it (`layer_input(shape = c(F))`).
F <- dim(X)[2]          # Original size of node features
n_classes <- dim(y)[2]  # Number of classes
dropout <- 0.5          # Dropout rate for the features
l2_reg <- 5e-4 / 2      # L2 regularization rate
learning_rate <- 1e-2   # Learning rate
epochs <- 300           # Number of training epochs
es_patience <- 10       # Patience for early stopping
# Preprocessing operations
# Build the Chebyshev filter from A; `densify = FALSE` is passed through as in
# the original call (presumably keeps the result sparse, matching the
# `sparse = TRUE` input layer; confirm against rspektral).
fltr <- preprocess_chebyshev(A, densify = FALSE)
dim(fltr)

# The model is fed X as a plain matrix.
X <- as.matrix(X)
# Model inputs: the node features (F columns) and the graph filter (N columns,
# declared sparse).
X_in <- layer_input(shape = c(F))
fltr_in <- layer_input(shape = c(N), sparse = TRUE)

# First block: dropout on the raw features, then a Chebyshev graph convolution
# with `channels` outputs, ReLU activation and L2-regularized kernel.
dropout_1 <- X_in %>%
  layer_dropout(rate = dropout)
graph_conv_1 <- list(dropout_1, fltr_in) %>%
  layer_cheb_conv(
    channels,
    K = K,
    activation = "relu",
    kernel_regularizer = regularizer_l2(l2_reg),
    use_bias = FALSE
  )

# Second block: dropout again, then a Chebyshev graph convolution producing one
# softmax output per class.
dropout_2 <- graph_conv_1 %>%
  layer_dropout(rate = dropout)
graph_conv_2 <- list(dropout_2, fltr_in) %>%
  layer_cheb_conv(
    n_classes,
    K = K,
    activation = "softmax",
    use_bias = FALSE
  )
# Assemble the two-input model ending at the second graph convolution.
model <- keras_model(inputs = list(X_in, fltr_in), outputs = graph_conv_2)

# NOTE(review): newer keras releases name this argument `learning_rate`;
# `lr` is kept here unchanged, so confirm against the installed keras version.
optimizer <- optimizer_adam(lr = learning_rate)

# Accuracy is requested as a *weighted* metric so that the node masks passed as
# sample weights during fit/evaluate apply to it as well as to the loss.
model %>% compile(
  optimizer = optimizer,
  loss = "categorical_crossentropy",
  weighted_metrics = "acc"
)

# Print the model.
model
# Validation uses the same inputs and targets as training; only the node mask
# (third element, supplied as sample weights) differs.
validation_data <- list(list(X, fltr), y, val_mask)

# Train on the whole graph in a single batch (batch_size = N, the number of
# nodes), weighting nodes by `train_mask`. Training stops early after
# `es_patience` epochs without improvement and restores the best weights.
history <- model %>% fit(
  x = list(X, fltr),
  y = y,
  sample_weight = train_mask,
  epochs = epochs,
  batch_size = N,
  validation_data = validation_data,
  shuffle = FALSE, # Shuffling data means shuffling the whole graph
  callbacks = callback_early_stopping(
    patience = es_patience,
    restore_best_weights = TRUE
  ),
  verbose = 0,
  view_metrics = TRUE
)

# Plot the recorded training history.
plot(history)
# Evaluate on the full graph in one batch, weighting nodes by `test_mask`.
eval_results <- model %>% evaluate(
  list(X, fltr),
  y,
  sample_weight = test_mask,
  batch_size = N
)

# Print the evaluation results.
eval_results
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.