diff options
Diffstat (limited to 'mnist.py')
-rw-r--r-- | mnist.py | 104 |
1 files changed, 104 insertions, 0 deletions
diff --git a/mnist.py b/mnist.py new file mode 100644 index 0000000..dcb6dad --- /dev/null +++ b/mnist.py @@ -0,0 +1,104 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torchvision import datasets, transforms + +# Define the CNN architecture +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2) + self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0) + self.fc1 = nn.Linear(in_features=400, out_features=120) + self.fc2 = nn.Linear(in_features=120, out_features=84) + self.fc3 = nn.Linear(in_features=84, out_features=10) + + def forward(self, x): + x = F.relu(self.conv1(x)) + x = F.max_pool2d(x, kernel_size=2, stride=2) + x = F.relu(self.conv2(x)) + x = F.max_pool2d(x, kernel_size=2, stride=2) + x = torch.flatten(x, 1) + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + return x + +# Hyperparameters +batch_size = 64 +test_batch_size = 1000 +epochs = 10 +learning_rate = 0.001 +log_interval = 100 # Log training progress every 100 batches + +# Set device (GPU if available, else CPU) +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +print(device) + +# Data preprocessing and augmentation +transform = transforms.Compose([ + transforms.ToTensor(), # Convert PIL images to tensors + transforms.Normalize((0.1307,), (0.3081,)) # Normalize images +]) + +# Load datasets +train_dataset = datasets.MNIST('data', train=True, download=True, transform=transform) +test_dataset = datasets.MNIST('data', train=False, transform=transform) + +train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True) +test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False) + +# Initialize model and optimizer +model = Net().to(device) +optimizer = optim.Adam(model.parameters(), lr=learning_rate) +criterion = nn.CrossEntropyLoss() # Combines softmax and NLL loss + +# Training function +def train(model, device, train_loader, optimizer, epoch): + model.train() + losses = [] + correct = 0 + for batch_idx, (data, target) in enumerate(train_loader): + data, target = data.to(device), target.to(device) + optimizer.zero_grad() + output = model(data) + loss = criterion(output, target) + loss.backward() + optimizer.step() + losses.append(loss.item()) + pred = output.argmax(dim=1, keepdim=True) + correct += pred.eq(target.view_as(pred)).sum().item() + + if batch_idx % log_interval == 0 and batch_idx != 0: + print(f'Epoch {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}] Loss: {loss.item():.6f}') + + avg_loss = sum(losses) / len(losses) + accuracy = 100. * correct / len(train_loader.dataset) + print(f'Epoch {epoch} - Avg Loss: {avg_loss:.6f}, Accuracy: {accuracy:.2f}%') + +# Test function +def test(model, device, test_loader): + model.eval() + test_loss = 0 + correct = 0 + with torch.no_grad(): + for data, target in test_loader: + data, target = data.to(device), target.to(device) + output = model(data) + test_loss += criterion(output, target).item() + pred = output.argmax(dim=1, keepdim=True) + correct += pred.eq(target.view_as(pred)).sum().item() + + test_loss /= len(test_loader.dataset) + accuracy = 100. * correct / len(test_loader.dataset) + print(f'\nTest Set - Loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%\n') + +# Training loop +for epoch in range(1, epochs + 1): + train(model, device, train_loader, optimizer, epoch) + test(model, device, test_loader) + +# Save the trained model +torch.save(model.state_dict(), "mnist_cnn.pth") +print("Model saved as mnist_cnn.pth") |