9. GCN Template

We have provided a template to help you get started. As part of this assignment, you will fill in the missing sections of this document.

9.1. Q1. Load the dataset

We will be using the MNISTSuperPixels dataset. See the original paper here: https://arxiv.org/pdf/1611.08402

We will not go into the depth the authors did. Instead, we will use their dataset to get a feel for implementing graph convolutional networks (GCNs) on a familiar problem.

import torch
from torch_geometric.datasets import MNISTSuperpixels
from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_networkx

import matplotlib.pyplot as plt
import networkx as nx
import random

# 1. Load the dataset
dataset = MNISTSuperpixels(root='/tmp/MNISTSuperpixels')

# 2. Shuffle and split
# Use a traditional 70/15/15 split (train/val/test)
######### Your code here ##############




#######################################


print(f"Train: {len(train_dataset)} | Val: {len(val_dataset)} | Test: {len(test_dataset)}")

# 3. Create your dataloaders
# Create training, validation and testing dataloaders
# You can use any batch size you want - it may affect your performance.
######### Your code here ##############




#######################################
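
One way to approach the shuffle-and-split step is to permute the graph indices once and cut the permutation into three contiguous pieces. The sketch below works on plain integer indices only. The helper name `split_indices`, the default fractions (chosen here so that they sum to one), and the seed are illustrative assumptions, not part of the template.

```python
import random

def split_indices(num_items, fractions=(0.70, 0.15, 0.15), seed=0):
    """Shuffle range(num_items) and cut it into train/val/test index lists."""
    assert abs(sum(fractions) - 1.0) < 1e-9, "fractions must sum to 1"
    indices = list(range(num_items))
    random.Random(seed).shuffle(indices)  # local RNG, leaves the global state alone

    n_train = int(fractions[0] * num_items)
    n_val = int(fractions[1] * num_items)
    train_idx = indices[:n_train]
    val_idx = indices[n_train:n_train + n_val]
    test_idx = indices[n_train + n_val:]  # whatever remains goes to the test split
    return train_idx, val_idx, test_idx

train_idx, val_idx, test_idx = split_indices(1000)
print(len(train_idx), len(val_idx), len(test_idx))
```

PyTorch Geometric datasets can generally be indexed with a sequence of integers, so index lists like these could be used to select the three subsets before wrapping each one in a DataLoader.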

9.2. Visualize the dataset

We want to visualize the dataset you have just created. Recall that these are no longer "images", but rather graph representations of MNIST.

You may want to run this a few times to see the different possible graphs that could occur.

def visualize_graph(data, title=None):
    G = to_networkx(data, node_attrs=['x'], to_undirected=True)

    # Extract 2D coordinates for layout (superpixel positions)
    pos = {i: data.pos[i].numpy() for i in range(data.num_nodes)}

    plt.figure(figsize=(4, 4))
    nx.draw(G, pos, node_size=50, with_labels=False, node_color='skyblue')
    if title:
        plt.title(title)
    plt.show()

# Pick some random samples to visualize
for i in range(3):
    idx = random.randint(0, len(train_dataset)-1)
    visualize_graph(train_dataset[idx], title=f"Label: {train_dataset[idx].y.item()}")

9.3. Q2. Create a basic GCN

We will use torch_geometric here. There are a few layers we are going to need:

  • GCNConv: Graph Convolutional Layer from torch_geometric.nn.

  • Global Mean Pooling: Aggregates node features to graph-level representation.

  • ReLU: Activation function after each layer (except the final one).

  • Linear: Fully connected layers for classification.
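
For intuition about what a graph convolutional layer computes, the sketch below implements the propagation rule from Kipf and Welling's GCN formulation, H' = D^{-1/2} (A + I) D^{-1/2} H W, using dense NumPy arrays. It is a toy illustration that assumes an undirected, unweighted graph. The function name `gcn_propagate` and the example graph are made up, and the real GCNConv layer operates on a sparse edge_index and also learns a bias term.

```python
import numpy as np

def gcn_propagate(adj, features, weight):
    """One GCN layer without bias or activation: D^-1/2 (A + I) D^-1/2 X W."""
    adj_hat = adj + np.eye(adj.shape[0])   # add self-loops
    deg = adj_hat.sum(axis=1)              # node degree, self-loop included
    d_inv_sqrt = 1.0 / np.sqrt(deg)
    # Symmetric normalization: entry (i, j) is scaled by 1 / sqrt(deg_i * deg_j).
    norm_adj = d_inv_sqrt[:, None] * adj_hat * d_inv_sqrt[None, :]
    return norm_adj @ features @ weight

# Path graph 0 - 1 - 2 with one scalar feature per node.
adj = np.array([[0.0, 1.0, 0.0],
                [1.0, 0.0, 1.0],
                [0.0, 1.0, 0.0]])
features = np.array([[1.0], [2.0], [3.0]])
weight = np.eye(1)  # identity weight, so only the neighbourhood mixing is visible

out = gcn_propagate(adj, features, weight)
print(out)
```

Each output row is a normalized mix of a node's own feature and its neighbours' features, which is why stacking several such layers lets information travel several hops across the graph.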

The network should take:

  • x: node features

  • edge_index: graph connectivity

  • batch: batch assignment vector (for pooling)

    Input → GCNConv (input_dim → 128) → ReLU

    → GCNConv (128 → 64) → ReLU

    → GCNConv (64 → 128) → ReLU

    → Global Mean Pooling

    → Linear (128 → 64) → ReLU

    → Linear (64 → num_classes) → Output
    

The output should be the raw logits (no final activation).

Instead of storing the activation function as an attribute of the class (i.e., self.activation = nn.ReLU()), apply it directly in the forward call.

As an example: self.fc(x).relu()
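
The role of the batch vector in the pooling step can be illustrated with plain PyTorch. The sketch below averages node features per graph, which is the idea behind global mean pooling. The helper `mean_pool_by_graph` and the toy tensors are hypothetical and only meant to show how `batch` maps each node to its graph.

```python
import torch

def mean_pool_by_graph(x, batch, num_graphs):
    """Average node features per graph; batch[i] is the graph id of node i."""
    sums = torch.zeros(num_graphs, x.size(1), dtype=x.dtype)
    sums.index_add_(0, batch, x)  # add each node's row to its graph's row
    counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1)
    return sums / counts.unsqueeze(1).to(x.dtype)

# Two graphs in one batch: nodes 0-1 belong to graph 0, nodes 2-4 to graph 1.
x = torch.tensor([[1.0, 2.0],
                  [3.0, 4.0],
                  [0.0, 0.0],
                  [3.0, 3.0],
                  [6.0, 9.0]])
batch = torch.tensor([0, 0, 1, 1, 1])

pooled = mean_pool_by_graph(x, batch, num_graphs=2)
print(pooled.shape)  # one row per graph

# Calling .relu() on a tensor gives the same result as the functional form.
same_activation = torch.equal((x - 2.0).relu(), torch.nn.functional.relu(x - 2.0))
```

After pooling, every graph is represented by a single fixed-size vector regardless of how many nodes it has, which is what allows the Linear layers to follow.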

import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool


class GCN(torch.nn.Module):
    def __init__(self, input_dim, num_classes):  # input_dim is the number of input node features
        super(GCN, self).__init__()


    def forward(self, x, edge_index, batch):

        return x

9.4. Q3. Implement the training procedure

You will implement the training and validation procedures. We have provided some hints for the quantities you should calculate in the training portion; the validation portion is left entirely up to you. Make sure the values you see during training make sense in terms of magnitude, i.e., that divisions by the number of batches or the number of elements are correct. Feel free to change this however works for you.

The pkbar import is a handy package that makes PyTorch training loops more akin to TensorFlow's .fit() function in terms of output.

The training function will return the trained model, along with a dictionary called history that can be used for plotting your metrics during training.
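
The remark about magnitudes is easiest to see with numbers. The sketch below uses made-up per-batch losses to contrast dividing by the number of batches with dividing by the number of samples. The two agree only when every batch has the same size.

```python
# Hypothetical per-batch results: (mean loss over the batch, graphs in the batch).
batches = [(0.90, 64), (0.70, 64), (0.20, 8)]  # the last batch is smaller

# Option 1: average of per-batch means (divide by the number of batches).
mean_of_means = sum(loss for loss, _ in batches) / len(batches)

# Option 2: weight each batch mean by its size (divide by the number of samples).
total_samples = sum(n for _, n in batches)
per_sample_mean = sum(loss * n for loss, n in batches) / total_samples

print(f"mean of batch means: {mean_of_means:.4f}")
print(f"per-sample mean:     {per_sample_mean:.4f}")
```

The training loop in the template accumulates loss.item() * data.num_graphs and divides by the dataset size, which corresponds to the second option. The validation averages divide by len(val_loader), which corresponds to the first option if per-batch means are accumulated. Whichever convention you pick, what you accumulate has to match what you divide by.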

import pkbar

def trainer(net, train_loader, val_loader, num_epochs=50, lr=1e-3, device='cuda'):
    # Setup random seed
    torch.manual_seed(8)
    torch.cuda.manual_seed(8)

    history = {'train_loss':[], 'val_loss':[], 'train_acc':[], 'val_acc':[]}

    print("Training Size: {0}".format(len(train_loader.dataset)))
    print("Validation Size: {0}".format(len(val_loader.dataset)))

    # Create your optimizer


    print('===========  Optimizer  ==================:')
    print('      LR:', lr)
    print('      num_epochs:', num_epochs)
    print('')

    # Define your loss function, we are doing multiclass classification remember
    CCE = torch.nn.CrossEntropyLoss()
    for epoch in range(num_epochs):
        # Progress bar setup
        kbar = pkbar.Kbar(target=len(train_loader), epoch=epoch, num_epochs=num_epochs, width=20, always_stateful=False)

        net.train()  # Set the model to training mode
        running_loss = 0.0
        running_acc = 0.0

        for i, data in enumerate(train_loader):
            data = data.to(device)  # Move data to the specified device (e.g., GPU)

            optimizer.zero_grad()

            # Forward pass of your models
            # You can access different variables such as data.x, data.edge_index,data.batch,data.y
            logits =

            # We want to monitor our accuracy during training
            pred =
            train_acc =

            # Calculate your loss
            loss =

            # Backward pass and optimization
            loss.backward()
            optimizer.step()

            # Note we have per batch, and running metrics
            running_loss += loss.item() * data.num_graphs
            running_acc += train_acc * data.num_graphs
            kbar.update(i, values=[("loss", loss.item()), ("acc", train_acc)])

        # Track training loss
        history['train_loss'].append(running_loss / len(train_loader.dataset))
        history['train_acc'].append(running_acc / len(train_loader.dataset))

        ######################
        ## Validation phase ##
        ######################
        net.eval()  # Set the model to evaluation mode
        val_loss = 0.0
        val_acc = 0.0
        with torch.no_grad():
            for i, data in enumerate(val_loader):
                data = data.to(device)  # Move data to the specified device (e.g., GPU)

                # Forward pass
                out =

                # Compute validation metrics
                loss =
                val_loss +=
                pred =
                val_acc +=

        # Average validation loss
        val_loss /= len(val_loader)
        val_acc /= len(val_loader)

        # Track validation loss
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        # Report the averaged validation loss rather than the loss of the last batch
        kbar.add(1, values=[("val_loss", val_loss), ("val_acc", val_acc)])

    return net, history

# Instantiate and train your model
# What is your input size? How many classes do we have?
input_dim =
num_classes =
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN(input_dim=input_dim, num_classes=num_classes).to(device)

model, history = trainer(model, train_loader, val_loader, num_epochs=50, lr=0.001, device=device)

9.5. Q4. Plotting

Plot the loss and accuracy curves using the history from training. Make sure to overlay both training and validation. Provide analysis on potential issues you see if any.
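
As a starting point, overlaying training and validation curves could look like the sketch below. The function name `plot_history_sketch` and the numbers in `dummy_history` are made up. Only the dictionary keys follow the history structure described in the training section.

```python
import matplotlib.pyplot as plt

def plot_history_sketch(history):
    """Draw loss and accuracy curves with training and validation overlaid."""
    epochs = range(1, len(history['train_loss']) + 1)
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(10, 4))

    ax_loss.plot(epochs, history['train_loss'], label='train')
    ax_loss.plot(epochs, history['val_loss'], label='val')
    ax_loss.set_xlabel('Epoch')
    ax_loss.set_ylabel('Loss')
    ax_loss.legend()

    ax_acc.plot(epochs, history['train_acc'], label='train')
    ax_acc.plot(epochs, history['val_acc'], label='val')
    ax_acc.set_xlabel('Epoch')
    ax_acc.set_ylabel('Accuracy')
    ax_acc.legend()

    fig.tight_layout()
    return fig

# Made-up numbers, only to exercise the function.
dummy_history = {
    'train_loss': [1.9, 1.4, 1.1], 'val_loss': [1.8, 1.5, 1.3],
    'train_acc': [0.30, 0.50, 0.62], 'val_acc': [0.32, 0.47, 0.55],
}
fig = plot_history_sketch(dummy_history)
```

A widening gap between the training and validation curves is one of the issues worth looking for in your analysis.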

def plot_loss(history):
    # Two individual plots: one for the losses and one for the accuracies

plot_loss(history)

9.6. Q5. Implement a function to evaluate your model on the testing dataset

We want to return two things:

  1. Test accuracy

  2. Confusion matrix on the test set (plot)

Hint: see what we have imported from sklearn and view the documentation. You might find some useful functions.

Provide some analysis of what you see in terms of performance. Is this surprising? Are there biases towards any specific classes? Why?
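
When turning raw logits into class predictions, it helps to remember that softmax is monotonic within each row, so it does not change which class has the largest score. The NumPy sketch below checks this on made-up logits. The `softmax` helper is written out here for illustration and is not part of the template.

```python
import numpy as np

def softmax(logits):
    """Row-wise softmax, shifted by the row maximum for numerical stability."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

# Made-up logits for 3 samples and 4 classes.
logits = np.array([[ 2.0, -1.0, 0.5,  0.1],
                   [-0.3,  0.2, 0.1,  3.0],
                   [ 0.0,  1.5, 1.4, -2.0]])

pred_from_logits = logits.argmax(axis=1)
pred_from_probs = softmax(logits).argmax(axis=1)
print(pred_from_logits, pred_from_probs)
```

Softmax is still useful when you want calibrated-looking probabilities, but for accuracy and the confusion matrix only the index of the largest score matters.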

import sklearn.metrics as metrics
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np


def evaluate_model(model, loader):
    model.eval()
    all_preds = []  # You will want to append to these and eventually combine them into a single array
    all_labels = []
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            # Recall your model is going to return raw logits over N classes
            # How do we obtain the prediction for N classes?
            # Do we need to softmax here?


    # Calculate confusion matrix
    conf_matrix =
    # Calculate accuracy
    accuracy =


    # Plot confusion matrix
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(conf_matrix, annot=True, fmt="d", cmap="Blues", xticklabels=np.unique(all_labels), yticklabels=np.unique(all_labels))
    ax.set_xlabel('Predicted Labels')
    ax.set_ylabel('True Labels')
    ax.set_title('Confusion Matrix')
    plt.show()

    return accuracy, conf_matrix

# Evaluate the model on the test set
test_accuracy, conf_matrix = evaluate_model(model, test_loader)

# Print the results
print(f"Test Accuracy: {test_accuracy:.4f}")

9.7. Q6. Class-wise accuracy

The confusion matrix gives us a good indication of class-wise performance, but it might not be the easiest thing to read. Let's instead report the accuracy per class, which is perhaps more easily interpretable.

While we could make this cleaner (i.e., combining the above function with the one below) and reduce computation, the dataset is small and therefore we are not worried. You can reuse some of the above function here, or if you want you can simply combine the two functions.
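
The per-class computation itself can be sketched independently of the model. The example below assumes the labels and predictions have already been gathered into two NumPy arrays. The arrays here are made up for a 3-class problem.

```python
import numpy as np

# Made-up labels and predictions for a 3-class problem.
labels = np.array([0, 0, 0, 1, 1, 2, 2, 2, 2])
preds = np.array([0, 0, 1, 1, 1, 2, 0, 2, 2])

class_acc = []
for class_idx in np.unique(labels):
    mask = labels == class_idx  # samples whose true class is class_idx
    class_acc.append((preds[mask] == class_idx).mean())
class_acc = np.array(class_acc)

overall_acc = (preds == labels).mean()
print(class_acc, overall_acc)
```

These values equal the diagonal of the confusion matrix divided by its row sums when rows hold the true labels, so they can serve as a cross-check against the plot from the previous question.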

def class_wise_accuracy(model, loader):
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for data in loader:
            data = data.to(device)

    # Calculate class-wise accuracy
    class_acc = []
    for class_idx in np.unique(all_labels):


    return np.array(class_acc)


# Calculate the class-wise accuracy on the test set
test_class_wise_acc = class_wise_accuracy(model, test_loader)

# Print the class-wise accuracy for each class
print("Class-wise accuracy:")
for i, acc in enumerate(test_class_wise_acc):
    print(f"Class {i}: {acc:.4f}")

9.8. Q7. Training with positional information of nodes

In previous experiments, we have neglected information that might be very important for our task of classifying digits: node position.

Modify the training script from above to also utilize this information in the form:

data.x = torch.cat([data.x,data.pos],dim=1)

You will need to think about how the shape of your input has changed with the addition of this new information.
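
The effect of the concatenation on tensor shapes can be checked on stand-in tensors before touching the training loop. The sizes below (75 nodes, one feature per node, 2D positions) are assumptions about the dataset used for illustration. Verify them against data.x.shape and data.pos.shape on a real sample.

```python
import torch

num_nodes = 75                        # assumed number of superpixel nodes in one graph
x = torch.rand(num_nodes, 1)          # stand-in for data.x (one value per node)
pos = torch.rand(num_nodes, 2) * 28   # stand-in for data.pos (2D coordinates)

# dim=1 stacks the columns side by side, so the node count stays the same.
x_with_pos = torch.cat([x, pos], dim=1)
print(x.shape, pos.shape, x_with_pos.shape)

# The first layer of the model must accept this many input features.
new_input_dim = x_with_pos.size(1)
```

Because the same concatenation has to be applied wherever the model is called, the evaluation functions need it as well as the training and validation loops.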

Train a new model with this additional information and provide the same metrics as above:

  1. plot_loss() - no changes required

  2. evaluate_model() - changes required for inputs

  3. class_wise_accuracy() - changes required for inputs

Provide analysis on how this additional information affects the performance of your model. Is this helpful information? Why?

# Instantiate and train your model
# What is your input size? How many classes do we have?
input_dim =
num_classes =
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN(input_dim=input_dim, num_classes=num_classes).to(device)

model, history = trainer(model, train_loader, val_loader, num_epochs=50, lr=0.001, device=device)

9.9. Q8. Optimize your model

Let's see how well you can make your model perform. You are free to make any design choices you like, including changing hyperparameters. Provide detailed summaries of the choices you have made.

9.10. Bonus Question

Implement a CNN design of your choice and compare performance. Are graph structures the optimal way of representing the data?

For those interested, you can find the SOTA MNIST models here: https://paperswithcode.com/sota/image-classification-on-mnist