From e74082fb6e95bea894b4f8c78c76dafc769dd0bd Mon Sep 17 00:00:00 2001
From: Jeff Heiges <jeff.heiges@colorado.edu>
Date: Tue, 4 Mar 2025 18:45:44 -0700
Subject: LeNet-inspired model

---
 mnist.py | 113 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 113 insertions(+)
 create mode 100644 mnist.py

diff --git a/mnist.py b/mnist.py
new file mode 100644
index 0000000..dcb6dad
--- /dev/null
+++ b/mnist.py
@@ -0,0 +1,113 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torchvision import datasets, transforms
+
+# Define the CNN architecture
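+# LeNet-5-style layout: two convolution + max-pool stages followed by three
+# fully connected layers, adapted here to single-channel 28x28 MNIST inputs.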
+class Net(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)
+        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0)
+        self.fc1 = nn.Linear(in_features=400, out_features=120)
+        self.fc2 = nn.Linear(in_features=120, out_features=84)
+        self.fc3 = nn.Linear(in_features=84, out_features=10)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))                     # [B, 1, 28, 28] -> [B, 6, 28, 28]
+        x = F.max_pool2d(x, kernel_size=2, stride=2)  # -> [B, 6, 14, 14]
+        x = F.relu(self.conv2(x))                     # -> [B, 16, 10, 10]
+        x = F.max_pool2d(x, kernel_size=2, stride=2)  # -> [B, 16, 5, 5]
+        x = torch.flatten(x, 1)                       # -> [B, 16*5*5] = [B, 400]
+        x = F.relu(self.fc1(x))
+        x = F.relu(self.fc2(x))
+        x = self.fc3(x)
+        return x
+
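+# Optional sanity check of the output shape (not executed here):
+#   Net()(torch.zeros(1, 1, 28, 28)).shape == torch.Size([1, 10])
+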
+# Hyperparameters
+batch_size = 64
+test_batch_size = 1000
+epochs = 10
+learning_rate = 0.001
+log_interval = 100  # Log training progress every 100 batches
+
+# Set device (GPU if available, else CPU)
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+print(f"Using device: {device}")
+
+# Data preprocessing: convert to tensor and normalize
+transform = transforms.Compose([
+    transforms.ToTensor(),  # Convert PIL images to tensors
+    transforms.Normalize((0.1307,), (0.3081,))  # Normalize with the MNIST mean and std
+])
+
+# Load datasets
+train_dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
+test_dataset = datasets.MNIST('data', train=False, transform=transform)
+
+train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
+test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)
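+# Note: these DataLoaders use the default num_workers and pin_memory settings;
+# raising num_workers or setting pin_memory=True can speed up loading on a GPU machine.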
+
+# Initialize model and optimizer
+model = Net().to(device)
+optimizer = optim.Adam(model.parameters(), lr=learning_rate)
+criterion = nn.CrossEntropyLoss()  # Combines LogSoftmax and NLLLoss, so the model outputs raw logits
+
+# Training function
+def train(model, device, train_loader, optimizer, epoch):
+    model.train()
+    losses = []
+    correct = 0
+    for batch_idx, (data, target) in enumerate(train_loader):
+        data, target = data.to(device), target.to(device)
+        optimizer.zero_grad()
+        output = model(data)
+        loss = criterion(output, target)
+        loss.backward()
+        optimizer.step()
+        losses.append(loss.item())
+        pred = output.argmax(dim=1, keepdim=True)
+        correct += pred.eq(target.view_as(pred)).sum().item()
+
+        if batch_idx % log_interval == 0 and batch_idx != 0:
+            print(f'Epoch {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}] Loss: {loss.item():.6f}')
+
+    avg_loss = sum(losses) / len(losses)
+    accuracy = 100. * correct / len(train_loader.dataset)
+    print(f'Epoch {epoch} - Avg Loss: {avg_loss:.6f}, Accuracy: {accuracy:.2f}%')
+
+# Test function
+def test(model, device, test_loader):
+    model.eval()
+    test_loss = 0
+    correct = 0
+    with torch.no_grad():
+        for data, target in test_loader:
+            data, target = data.to(device), target.to(device)
+            output = model(data)
+            test_loss += criterion(output, target).item() * data.size(0)  # undo the batch mean to accumulate a per-sample sum
+            pred = output.argmax(dim=1, keepdim=True)
+            correct += pred.eq(target.view_as(pred)).sum().item()
+
+    test_loss /= len(test_loader.dataset)  # average loss per sample
+    accuracy = 100. * correct / len(test_loader.dataset)
+    print(f'\nTest Set - Loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%\n')
+
+# Training loop
+for epoch in range(1, epochs + 1):
+    train(model, device, train_loader, optimizer, epoch)
+    test(model, device, test_loader)
+
+# Save the trained model
+torch.save(model.state_dict(), "mnist_cnn.pth")
+print("Model saved as mnist_cnn.pth")
-- 
cgit v1.2.3