summaryrefslogtreecommitdiffhomepage
path: root/mnist.py
diff options
context:
space:
mode:
Diffstat (limited to 'mnist.py')
-rw-r--r--mnist.py104
1 files changed, 104 insertions, 0 deletions
diff --git a/mnist.py b/mnist.py
new file mode 100644
index 0000000..dcb6dad
--- /dev/null
+++ b/mnist.py
@@ -0,0 +1,104 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torchvision import datasets, transforms
+
+# Define the CNN architecture
+class Net(nn.Module):
+ def __init__(self):
+ super(Net, self).__init__()
+ self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)
+ self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0)
+ self.fc1 = nn.Linear(in_features=400, out_features=120)
+ self.fc2 = nn.Linear(in_features=120, out_features=84)
+ self.fc3 = nn.Linear(in_features=84, out_features=10)
+
+ def forward(self, x):
+ x = F.relu(self.conv1(x))
+ x = F.max_pool2d(x, kernel_size=2, stride=2)
+ x = F.relu(self.conv2(x))
+ x = F.max_pool2d(x, kernel_size=2, stride=2)
+ x = torch.flatten(x, 1)
+ x = F.relu(self.fc1(x))
+ x = F.relu(self.fc2(x))
+ x = self.fc3(x)
+ return x
+
+# Hyperparameters
+batch_size = 64
+test_batch_size = 1000
+epochs = 10
+learning_rate = 0.001
+log_interval = 100 # Log training progress every 100 batches
+
+# Set device (GPU if available, else CPU)
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+print(device)
+
+# Data preprocessing and augmentation
+transform = transforms.Compose([
+ transforms.ToTensor(), # Convert PIL images to tensors
+ transforms.Normalize((0.1307,), (0.3081,)) # Normalize images
+])
+
+# Load datasets
+train_dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
+test_dataset = datasets.MNIST('data', train=False, transform=transform)
+
+train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
+test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)
+
+# Initialize model and optimizer
+model = Net().to(device)
+optimizer = optim.Adam(model.parameters(), lr=learning_rate)
+criterion = nn.CrossEntropyLoss() # Combines softmax and NLL loss
+
+# Training function
+def train(model, device, train_loader, optimizer, epoch):
+ model.train()
+ losses = []
+ correct = 0
+ for batch_idx, (data, target) in enumerate(train_loader):
+ data, target = data.to(device), target.to(device)
+ optimizer.zero_grad()
+ output = model(data)
+ loss = criterion(output, target)
+ loss.backward()
+ optimizer.step()
+ losses.append(loss.item())
+ pred = output.argmax(dim=1, keepdim=True)
+ correct += pred.eq(target.view_as(pred)).sum().item()
+
+ if batch_idx % log_interval == 0 and batch_idx != 0:
+ print(f'Epoch {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}] Loss: {loss.item():.6f}')
+
+ avg_loss = sum(losses) / len(losses)
+ accuracy = 100. * correct / len(train_loader.dataset)
+ print(f'Epoch {epoch} - Avg Loss: {avg_loss:.6f}, Accuracy: {accuracy:.2f}%')
+
+# Test function
+def test(model, device, test_loader):
+ model.eval()
+ test_loss = 0
+ correct = 0
+ with torch.no_grad():
+ for data, target in test_loader:
+ data, target = data.to(device), target.to(device)
+ output = model(data)
+ test_loss += criterion(output, target).item()
+ pred = output.argmax(dim=1, keepdim=True)
+ correct += pred.eq(target.view_as(pred)).sum().item()
+
+ test_loss /= len(test_loader.dataset)
+ accuracy = 100. * correct / len(test_loader.dataset)
+ print(f'\nTest Set - Loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%\n')
+
+# Training loop
+for epoch in range(1, epochs + 1):
+ train(model, device, train_loader, optimizer, epoch)
+ test(model, device, test_loader)
+
+# Save the trained model
+torch.save(model.state_dict(), "mnist_cnn.pth")
+print("Model saved as mnist_cnn.pth")