9. GCN Template
We have provided a template to help you get started. As part of this assignment, you will fill in the missing sections of this document.
9.1. Q1. Load the dataset
We will be using the MNISTSuperpixels dataset. See the original paper here: https://arxiv.org/pdf/1611.08402
We will not go into the depth the authors did; we will just use their dataset to get a feel for implementing Graph Convolutional Networks on a familiar dataset.
import torch
from torch_geometric.datasets import MNISTSuperpixels
from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_networkx
import matplotlib.pyplot as plt
import networkx as nx
import random
# 1. Load the dataset
dataset = MNISTSuperpixels(root='/tmp/MNISTSuperpixels')
# 2. Shuffle and split
# Use a traditional splitting of 70/15/15 (train/val/test)
######### Your code here ##############
#######################################
print(f"Train: {len(train_dataset)} | Val: {len(val_dataset)} | Test: {len(test_dataset)}")
# 3. Create your dataloaders
# Create training, validation and testing dataloaders
# You can use any batch size you want - it may affect your performance.
######### Your code here ##############
#######################################
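The shuffle-and-split step above can be sketched without any framework. This is a minimal illustration, not the expected solution: the helper name `split_indices`, the 70/15/15 percentages (chosen so the parts sum to 100), and the seed are all assumptions made for the example.

```python
import random

def split_indices(num_items, percents=(70, 15, 15), seed=0):
    """Shuffle the indices 0..num_items-1 and cut them into train/val/test lists.

    `percents` are integer percentages and must sum to 100 (assumed 70/15/15 here).
    """
    assert sum(percents) == 100, "split percentages must sum to 100"
    indices = list(range(num_items))
    random.Random(seed).shuffle(indices)  # seeded so the split is reproducible

    n_train = num_items * percents[0] // 100
    n_val = num_items * percents[1] // 100
    train = indices[:n_train]
    val = indices[n_train:n_train + n_val]
    test = indices[n_train + n_val:]  # remainder, absorbs any rounding
    return train, val, test

train_idx, val_idx, test_idx = split_indices(1000)
print(len(train_idx), len(val_idx), len(test_idx))
```

The same index lists can then be used to select graphs from the dataset; alternatively, shuffling the dataset once and slicing it by the computed sizes gives an equivalent split.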
9.2. Visualize the dataset
We want to visualize the dataset you have just created. Recall that these are no longer "images", but rather graph representations of MNIST.
You may want to run this a few times to see the different graphs that can occur.
def visualize_graph(data, title=None):
    G = to_networkx(data, node_attrs=['x'], to_undirected=True)
    # Extract 2D coordinates for layout (superpixel positions)
    pos = {i: data.pos[i].numpy() for i in range(data.num_nodes)}
    plt.figure(figsize=(4, 4))
    nx.draw(G, pos, node_size=50, with_labels=False, node_color='skyblue')
    if title:
        plt.title(title)
    plt.show()

# Pick some random samples to visualize
for i in range(3):
    idx = random.randint(0, len(train_dataset)-1)
    visualize_graph(train_dataset[idx], title=f"Label: {train_dataset[idx].y.item()}")
9.3. Q2. Create a basic GCN
We will use torch_geometric here. There are a few layers we are going to need:
GCNConv: Graph Convolutional Layer from torch_geometric.nn.
Global Mean Pooling: Aggregates node features to a graph-level representation.
ReLU: Activation function after each layer (except the final one).
Linear: Fully connected layers for classification.
The network should take:
x: node features
edge_index: graph connectivity
batch: batch assignment vector (for pooling)
Input → GCNConv (input_dim → 128) → ReLU → GCNConv (128 → 64) → ReLU → GCNConv (64 → 128) → ReLU → Global Mean Pooling → Linear (128 → 64) → ReLU → Linear (64 → num_classes) → Output
The output should be the raw logits (no final activation).
Instead of implementing the activation function as part of the class, i.e., self.activation = nn.ReLU(), apply it directly in the forward call.
As an example: self.fc(x).relu()
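To make the role of the batch vector concrete, here is a framework-free sketch of what global mean pooling computes: every node row is averaged together with the other nodes that share its graph id. The function name `global_mean_pool_lists` and the toy values are hypothetical; in the assignment itself you would use the pooling function from torch_geometric.nn.

```python
def global_mean_pool_lists(x, batch):
    """Average node feature rows per graph.

    x:     list of node feature rows, one row per node
    batch: graph id (0-based) for each node, same length as x
    """
    num_graphs = max(batch) + 1
    dim = len(x[0])
    sums = [[0.0] * dim for _ in range(num_graphs)]
    counts = [0] * num_graphs
    for row, graph_id in zip(x, batch):
        counts[graph_id] += 1
        for j, value in enumerate(row):
            sums[graph_id][j] += value
    # One output row per graph, regardless of how many nodes it has
    return [[s / counts[g] for s in sums[g]] for g in range(num_graphs)]

x = [[1.0, 2.0], [3.0, 4.0], [10.0, 20.0]]  # three nodes, two features each
batch = [0, 0, 1]                            # first two nodes belong to graph 0
pooled = global_mean_pool_lists(x, batch)
print(pooled)  # [[2.0, 3.0], [10.0, 20.0]]
```

The output has one row per graph, which is why the pooling step sits between the node-level GCNConv layers and the graph-level Linear layers.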
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
class GCN(torch.nn.Module):
    def __init__(self, input_dim, num_classes):  # Pass input_features as an argument
        super(GCN, self).__init__()

    def forward(self, x, edge_index, batch):
        return x
9.4. Q3. Implement the training procedure
You will implement the training procedure and the validation procedure. We have provided some hints for the quantities you should calculate in the training portion. The validation portion is left entirely up to you. Make sure the values you see during training make sense in terms of magnitude, i.e., that any division by the number of batches or by the number of elements is correct. Feel free to change this in whatever way works for you.
The pkbar import is a handy package that makes the output of PyTorch training loops more akin to TensorFlow's .fit() function.
The training function will return the trained model, along with a dictionary called history that can be used for plotting your metrics during training.
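The point about dividing by the number of batches versus the number of elements can be illustrated with made-up numbers. In this sketch each tuple is a hypothetical (mean loss of the batch, number of graphs in the batch) pair; the last batch is smaller, as often happens at the end of an epoch.

```python
# (mean loss over the batch, number of graphs in the batch); values are made up
batches = [(0.5, 4), (0.5, 4), (2.0, 2)]

total_graphs = sum(n for _, n in batches)

# Weight each batch mean by its size, then divide by the number of elements
per_sample_mean = sum(loss * n for loss, n in batches) / total_graphs

# Average the batch means directly, i.e. divide by the number of batches
per_batch_mean = sum(loss for loss, _ in batches) / len(batches)

print(per_sample_mean, per_batch_mean)  # 0.8 1.0
```

The two conventions only agree when every batch has the same size, so whichever accumulation you choose, the divisor has to match it.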
import pkbar
def trainer(net, train_loader, val_loader, num_epochs=50, lr=1e-3, device='cuda'):
    # Setup random seed
    torch.manual_seed(8)
    torch.cuda.manual_seed(8)

    history = {'train_loss':[], 'val_loss':[], 'train_acc':[], 'val_acc':[]}
    print("Training Size: {0}".format(len(train_loader.dataset)))
    print("Validation Size: {0}".format(len(val_loader.dataset)))

    # Create your optimizer (it must be named `optimizer`, since it is used below)

    print('=========== Optimizer ==================:')
    print(' LR:', lr)
    print(' num_epochs:', num_epochs)
    print('')

    # Define your loss function, we are doing multiclass classification remember
    CCE = torch.nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        # Progress bar setup
        kbar = pkbar.Kbar(target=len(train_loader), epoch=epoch, num_epochs=num_epochs, width=20, always_stateful=False)

        net.train()  # Set the model to training mode
        running_loss = 0.0
        running_acc = 0.0

        for i, data in enumerate(train_loader):
            data = data.to(device)  # Move data to the specified device (e.g., GPU)
            optimizer.zero_grad()

            # Forward pass of your models
            # You can access different variables such as data.x, data.edge_index, data.batch, data.y
            logits =

            # We want to monitor our accuracy during training
            pred =
            train_acc =

            # Calculate your loss
            loss =

            # Backward pass and optimization
            loss.backward()
            optimizer.step()

            # Note we have per batch, and running metrics
            running_loss += loss.item() * data.num_graphs
            running_acc += train_acc * data.num_graphs
            kbar.update(i, values=[("loss", loss.item()), ("acc", train_acc)])

        # Track training loss
        history['train_loss'].append(running_loss / len(train_loader.dataset))
        history['train_acc'].append(running_acc / len(train_loader.dataset))

        ######################
        ## Validation phase ##
        ######################
        net.eval()  # Set the model to evaluation mode
        val_loss = 0.0
        val_acc = 0.0

        with torch.no_grad():
            for i, data in enumerate(val_loader):
                data = data.to(device)  # Move data to the specified device (e.g., GPU)

                # Forward pass
                out =

                # Compute validation metrics
                loss =
                val_loss +=
                pred =
                val_acc +=

        # Average validation metrics
        # (this divides by the number of batches; change the divisor if you accumulate per-sample sums)
        val_loss /= len(val_loader)
        val_acc /= len(val_loader)

        # Track validation loss
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        # Report the averaged validation loss rather than the loss of the last batch
        kbar.add(1, values=[("val_loss", val_loss), ("val_acc", val_acc)])

    return net, history
# Instantiate and train your model
# What is your input size? How many classes do we have?
input_dim =
num_classes =
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN(input_dim=input_dim, num_classes=num_classes).to(device)
model, history = trainer(model, train_loader, val_loader, num_epochs=50, lr=0.001, device=device)
9.5. Q4. Plotting
Plot the loss and accuracy curves using the history from training. Make sure to overlay both training and validation. Provide analysis on potential issues you see if any.
def plot_loss(history):
    # Two individual plots, one for losses and one for accuracy

plot_loss(history)
9.6. Q5. Implement a function to evaluate your model on the testing dataset
We want to return two things:
Test accuracy
Confusion matrix on the test set (plot)
Hint: see what we have imported from sklearn and view the documentation. You might find some useful functions.
Provide some analysis of what you see in terms of performance. Is this surprising? Are there biases towards any specific classes? Why?
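As a reminder of what the two requested quantities mean, here is a small framework-free sketch that turns logits into predictions, builds a confusion matrix by hand, and reads the accuracy off its diagonal. The helper names and the toy logits are hypothetical; in the assignment you would rely on the sklearn functions mentioned in the hint. Rows are true labels and columns are predicted labels, matching the axis labels used in the plot.

```python
def argmax(values):
    """Index of the largest entry in a list of scores."""
    return max(range(len(values)), key=lambda i: values[i])

def confusion_matrix(labels, preds, num_classes):
    """matrix[t][p] counts samples with true label t that were predicted as p."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(labels, preds):
        matrix[t][p] += 1
    return matrix

# Toy logits for four samples over three classes (made-up values)
logits = [[2.0, 0.1, -1.0], [0.2, 0.3, 0.1], [0.0, 1.5, 3.0], [1.0, 0.9, 0.0]]
labels = [0, 1, 2, 1]

preds = [argmax(row) for row in logits]
cm = confusion_matrix(labels, preds, num_classes=3)
accuracy = sum(cm[c][c] for c in range(3)) / len(labels)

print(preds)     # [0, 1, 2, 0]
print(cm)        # [[1, 0, 0], [1, 1, 0], [0, 0, 1]]
print(accuracy)  # 0.75
```

Off-diagonal entries such as cm[1][0] show which classes are being confused with each other, which is the starting point for the analysis asked for above.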
import sklearn.metrics as metrics
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

def evaluate_model(model, loader):
    model.eval()
    all_preds = []  # You will want to append to these and eventually combine into a singular array
    all_labels = []

    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            # Recall your model is going to return raw logits over N classes
            # How do we obtain the prediction for N classes?
            # Do we need to softmax here?

    # Calculate confusion matrix
    conf_matrix =

    # Calculate accuracy
    accuracy =

    # Plot confusion matrix
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(conf_matrix, annot=True, fmt="d", cmap="Blues", xticklabels=np.unique(all_labels), yticklabels=np.unique(all_labels))
    ax.set_xlabel('Predicted Labels')
    ax.set_ylabel('True Labels')
    ax.set_title('Confusion Matrix')
    plt.show()

    return accuracy, conf_matrix
# Evaluate the model on the test set
test_accuracy, conf_matrix = evaluate_model(model, test_loader)
# Print the results
print(f"Test Accuracy: {test_accuracy:.4f}")
9.7. Q6. Class-wise accuracy
The confusion matrix gives us a good indication of class-wise performance, but it might not be the easiest thing to read. Let's instead report the accuracy per class, which is perhaps more easily interpretable.
While we could make this cleaner (i.e., by combining the above function with the one below) and reduce computation, the dataset is small, so we are not worried. You can reuse parts of the above function here, or simply combine the two functions if you prefer.
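One way to see the link between the confusion matrix and class-wise accuracy: for each class, divide the diagonal entry by the sum of its row (the number of true samples of that class). The sketch below assumes rows are true labels and columns are predictions; the function name and the example matrix are made up for illustration.

```python
def class_wise_accuracy_from_cm(cm):
    """Per-class accuracy: correct predictions for a class / true samples of that class."""
    accuracies = []
    for class_idx, row in enumerate(cm):
        total = sum(row)
        # Guard against classes that never appear in the labels
        accuracies.append(row[class_idx] / total if total else float("nan"))
    return accuracies

# Rows are true labels, columns are predicted labels (made-up counts)
cm = [[8, 2, 0],
      [1, 4, 0],
      [0, 0, 5]]

per_class = class_wise_accuracy_from_cm(cm)
print(per_class)  # [0.8, 0.8, 1.0]
```

This is the same quantity you would obtain by filtering the predictions by true label and averaging the correct ones, so either route can be used in the function below.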
def class_wise_accuracy(model, loader):
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for data in loader:
            data = data.to(device)

    # Calculate class-wise accuracy
    class_acc = []
    for class_idx in np.unique(all_labels):

    return np.array(class_acc)
# Calculate the class-wise accuracy on the test set
test_class_wise_acc = class_wise_accuracy(model, test_loader)
# Print the class-wise accuracy for each class
print("Class-wise accuracy:")
for i, acc in enumerate(test_class_wise_acc):
    print(f"Class {i}: {acc:.4f}")
9.8. Q7. Training with positional information of nodes
In the previous experiments, we neglected information that might be very important for our task of classifying digits: node position.
Modify the training script from above to also utilize this information in the form:
data.x = torch.cat([data.x,data.pos],dim=1)
You will need to think about how the shape of your input has changed with the addition of this new information.
Train a new model with this additional information and provide the same metrics as above:
plot_loss() - no changes required
evaluate_model() - changes required for inputs
class_wise_accuracy() - changes required for inputs
Provide analysis on how this additional information affects the performance of your model. Is this helpful information? Why?
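The effect of the concatenation on the input shape can be checked on a toy example. This sketch mimics a column-wise concatenation (dim=1) with plain lists; the helper name `concat_features` and the node values are hypothetical, and the feature counts are chosen for illustration rather than read from the dataset.

```python
def concat_features(x, pos):
    """Append each node's coordinates to its feature row (column-wise concatenation)."""
    assert len(x) == len(pos), "x and pos must describe the same nodes"
    return [list(x_row) + list(pos_row) for x_row, pos_row in zip(x, pos)]

# Toy graph: 3 nodes, f = 1 feature per node and d = 2 coordinates per node
x = [[0.9], [0.1], [0.5]]
pos = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

combined = concat_features(x, pos)
input_dim = len(combined[0])  # f + d columns per node

print(combined[0])  # [0.9, 1.0, 2.0]
print(input_dim)    # 3
```

The number of nodes is unchanged; only the number of columns grows. If the coordinates live on a different scale than the original features, it may be worth considering whether to rescale them, which is a design choice you can discuss in your analysis.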
# Instantiate and train your model
# What is your input size? How many classes do we have?
input_dim =
num_classes =
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN(input_dim=input_dim, num_classes=num_classes).to(device)
model, history = trainer(model, train_loader, val_loader, num_epochs=50, lr=0.001, device=device)
9.9. Q8. Optimize your model
Let's see how performant you can make your model. You are free to make any design choices you like, as well as change hyperparameters. Provide detailed summaries of the choices you have made.
9.10. Bonus Question
Implement a CNN design of your choice and compare performance. Are graph structures the optimal way of representing the data?
For those interested, you can find the SOTA MNIST models here: https://paperswithcode.com/sota/image-classification-on-mnist