# ============================================================
# MNIST Digit Classification using a Simple Neural Network (PyTorch)
# ============================================================
"""Train a small fully-connected classifier on MNIST and report its accuracy.

Running this file as a script does the following:
- downloads MNIST into ``./data``;
- trains for a few epochs;
- prints the test accuracy;
- plots the training curves and a few sample predictions.

Importing the module only defines the model and the helper functions.
"""
from typing import List, Tuple

import numpy as np  # currently unused; kept from the original script
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# torchvision and matplotlib are imported inside the functions that need them,
# so the model and training helpers can be imported without those packages.


# ------------------------------------------------------------
# 1. Device Configuration (GPU if available)
# ------------------------------------------------------------
def get_device() -> torch.device:
    """Return the CUDA device when one is available, otherwise the CPU."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


# ------------------------------------------------------------
# 2. Load MNIST Dataset
# ------------------------------------------------------------
def get_data_loaders(root: str = "./data",
                     batch_size: int = 64) -> Tuple[DataLoader, DataLoader]:
    """Build the MNIST train/test loaders, downloading the data if needed.

    Args:
        root: Directory where torchvision stores (or finds) MNIST.
        batch_size: Samples per batch for both loaders.

    Returns:
        ``(train_loader, test_loader)``. Only the training loader shuffles.
    """
    import torchvision
    import torchvision.transforms as transforms

    transform = transforms.Compose([
        transforms.ToTensor(),                 # convert to tensor
        transforms.Normalize((0.5,), (0.5,)),  # normalize: (x - 0.5) / 0.5
    ])
    train_dataset = torchvision.datasets.MNIST(
        root=root, train=True, transform=transform, download=True)
    test_dataset = torchvision.datasets.MNIST(
        root=root, train=False, transform=transform, download=True)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader


# ------------------------------------------------------------
# 3. Define Simple Neural Network
# ------------------------------------------------------------
class SimpleNN(nn.Module):
    """Three-layer MLP: in_features -> 128 -> 64 -> num_classes.

    The output is raw logits with no softmax, because ``nn.CrossEntropyLoss``
    applies log-softmax itself.

    Args:
        in_features: Length of each flattened input sample (28*28 for MNIST).
        num_classes: Number of output logits.
    """

    def __init__(self, in_features: int = 28 * 28, num_classes: int = 10):
        super().__init__()
        self.fc1 = nn.Linear(in_features, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Flatten ``x`` to rows of ``in_features`` values and return logits."""
        # Use reshape (not view) so non-contiguous inputs also work. Reading the
        # width from fc1 keeps it in sync with the constructor argument.
        x = x.reshape(-1, self.fc1.in_features)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)


# ------------------------------------------------------------
# 4. Training / Evaluation helpers
# ------------------------------------------------------------
def train_one_epoch(model: nn.Module, loader: DataLoader, criterion: nn.Module,
                    optimizer: optim.Optimizer,
                    device: torch.device) -> Tuple[float, float]:
    """Run one optimisation pass over ``loader``.

    Returns:
        ``(mean per-sample loss, accuracy in percent)`` over the epoch.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.train()  # explicit: the model may have been put in eval mode earlier
    total_loss = 0.0
    correct = 0
    total = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        # forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # Assumes the criterion returns the batch *mean* (the CrossEntropyLoss
        # default). Weighting by batch size keeps a smaller final batch from
        # skewing the epoch average.
        total_loss += loss.item() * batch_size
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += batch_size

    if total == 0:
        raise ValueError("loader yielded no samples")
    return total_loss / total, 100 * correct / total


@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader, device: torch.device) -> float:
    """Return ``model``'s classification accuracy (percent) on ``loader``.

    Switches the model to eval mode and disables gradient tracking.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        predicted = model(images).argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError("loader yielded no samples")
    return 100 * correct / total


# ------------------------------------------------------------
# 5. Plotting helpers
# ------------------------------------------------------------
def plot_history(values: List[float], title: str, ylabel: str) -> None:
    """Plot one per-epoch metric, with epochs numbered from 1 on the x-axis."""
    import matplotlib.pyplot as plt

    # 1-based so the x-axis matches the "Epoch [k/N]" lines printed in training.
    epochs = range(1, len(values) + 1)
    plt.figure(figsize=(8, 4))
    plt.plot(epochs, values, label=title)
    plt.title(title)
    plt.xlabel("Epoch")
    plt.ylabel(ylabel)
    plt.grid()
    plt.legend()
    plt.show()


@torch.no_grad()
def show_predictions(model: nn.Module, loader: DataLoader, device: torch.device,
                     count: int = 5) -> None:
    """Print and plot the model's predictions for the first ``count`` images.

    The images come from the first batch of ``loader``. If that batch is
    smaller than ``count``, only the available images are shown.
    """
    import matplotlib.pyplot as plt

    model.eval()
    images, labels = next(iter(loader))
    images, labels = images[:count], labels[:count]
    preds = model(images.to(device)).argmax(dim=1).cpu()

    print("\nPredicted:", preds.numpy())
    print("Actual: ", labels.numpy())

    shown = len(images)
    plt.figure(figsize=(2 * shown, 2))
    for i in range(shown):
        plt.subplot(1, shown, i + 1)
        plt.imshow(images[i].squeeze(), cmap="gray")
        plt.title(f"P:{preds[i].item()} / T:{labels[i].item()}")
        plt.axis("off")
    plt.show()


# ------------------------------------------------------------
# 6. Script entry point
# ------------------------------------------------------------
def main(num_epochs: int = 5) -> None:
    """Train SimpleNN on MNIST, report test accuracy, and plot the results."""
    device = get_device()
    print("Using device:", device)

    train_loader, test_loader = get_data_loaders()

    model = SimpleNN().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    train_loss_list: List[float] = []
    train_acc_list: List[float] = []
    for epoch in range(num_epochs):
        avg_loss, accuracy = train_one_epoch(
            model, train_loader, criterion, optimizer, device)
        train_loss_list.append(avg_loss)
        train_acc_list.append(accuracy)
        print(f"Epoch [{epoch+1}/{num_epochs}] Loss: {avg_loss:.4f} Accuracy: {accuracy:.2f}%")

    test_accuracy = evaluate(model, test_loader, device)
    print("\nTest Accuracy:", test_accuracy)

    plot_history(train_acc_list, "Training Accuracy", "Accuracy (%)")
    plot_history(train_loss_list, "Training Loss", "Loss")
    show_predictions(model, test_loader, device)


if __name__ == "__main__":
    main()