"""Train a LeNet-style CNN with batch normalization on MNIST or FashionMNIST.

Pass ``F`` or ``FASHION`` (case-insensitive) as the first command-line
argument to train on FashionMNIST; any other invocation trains on MNIST.
The trained weights are written to ``mnist_cnn.pth`` in both cases.
"""
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

# Dataset selection from the first CLI argument (defaults to MNIST).
dataset = "MNIST"
if len(sys.argv) > 1 and sys.argv[1].upper() in ("F", "FASHION"):
    dataset = "FashionMNIST"


# Define the CNN architecture with Batch Normalization
class Net(nn.Module):
    """Two conv layers and three fully connected layers, each hidden layer
    followed by batch normalization and ReLU.

    ``fc1`` expects 400 input features, i.e. 16 channels of 5x5 after the
    second pooling step, which is what a 1x28x28 input produces
    (28 -> conv/pad 2 -> 28 -> pool -> 14 -> conv -> 10 -> pool -> 5).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5, padding=2)
        self.bn1 = nn.BatchNorm2d(6)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.bn2 = nn.BatchNorm2d(16)
        self.fc1 = nn.Linear(400, 120)
        self.bn3 = nn.BatchNorm1d(120)
        self.fc2 = nn.Linear(120, 84)
        self.bn4 = nn.BatchNorm1d(84)
        self.fc3 = nn.Linear(84, 10)

        # Weight initialization: Kaiming for conv layers, Xavier for linear
        # layers, zero biases. BatchNorm layers keep their defaults.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
            else:
                continue
            if m.bias is not None:  # Only touch the bias if it exists
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return the raw class scores (logits) for a batch of images."""
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)  # keep the batch dimension, flatten the rest
        x = F.relu(self.bn3(self.fc1(x)))
        x = F.relu(self.bn4(self.fc2(x)))
        x = self.fc3(x)
        return x


# Hyperparameters
batch_size = 128
test_batch_size = 1000
epochs = 10
learning_rate = 0.01
log_interval = 100

# Set device (GPU if available, else CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Data preprocessing
# NOTE(review): (0.1307, 0.3081) are the commonly quoted MNIST mean/std; the
# same values are applied when FashionMNIST is selected - confirm that is
# intended.
transform = transforms.Compose([
    transforms.ToTensor(),  # Convert PIL images to tensors
    transforms.Normalize((0.1307,), (0.3081,))  # Normalize images
])

# Load datasets
dataset_cls = datasets.FashionMNIST if dataset == "FashionMNIST" else datasets.MNIST
train_dataset = dataset_cls('data', train=True, download=True, transform=transform)
test_dataset = dataset_cls('data', train=False, transform=transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)

# Initialize model and optimizer
model = Net().to(device)
optimizer = optim.LBFGS(
    model.parameters(),
    lr=learning_rate,
    max_iter=15,
    max_eval=20,
    tolerance_change=1e-09,
    history_size=50,
    line_search_fn='strong_wolfe'
)
criterion = nn.CrossEntropyLoss()


# Training function
def train(model, device, train_loader, optimizer, epoch):
    """Run one training epoch and print the average loss and accuracy.

    Args:
        model: network to optimize; put into training mode here.
        device: device the batches are moved to.
        train_loader: iterable of ``(data, target)`` batches.
        optimizer: optimizer whose ``step`` accepts a closure (L-BFGS style).
        epoch: epoch number, used only in the printed messages.
    """
    model.train()
    losses = []
    correct = 0

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        def closure():
            # The optimizer may call this several times per step, so the
            # gradients are cleared and recomputed on every call.
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            # NOTE(review): clipping rescales the gradient the optimizer sees
            # without changing the returned loss; confirm this interacts
            # acceptably with the strong-Wolfe line search.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
            return loss

        loss = optimizer.step(closure)
        loss_value = loss.item()
        losses.append(loss_value)

        # Accuracy is measured after the parameter update. The model is still
        # in training mode here, so batch-norm layers use batch statistics.
        with torch.no_grad():
            output = model(data)
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()

        if batch_idx % log_interval == 0 and batch_idx != 0:
            print(f'Epoch {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}] Loss: {loss_value:.6f}')

    avg_loss = sum(losses) / len(losses)
    accuracy = 100. * correct / len(train_loader.dataset)
    print(f'Epoch {epoch} - Avg Loss: {avg_loss:.6f}, Accuracy: {accuracy:.2f}%')


# Test function
def test(model, device, test_loader):
    """Evaluate the model and print the mean per-sample loss and accuracy.

    Args:
        model: network to evaluate; put into evaluation mode here.
        device: device the batches are moved to.
        test_loader: iterable of ``(data, target)`` batches.
    """
    model.eval()
    test_loss = 0
    correct = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Sum the per-sample losses so that dividing by the dataset size
            # below yields a true per-sample mean. Adding up per-batch means
            # and dividing by the sample count would understate the loss by
            # roughly the batch size.
            test_loss += F.cross_entropy(output, target, reduction="sum").item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()

    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    print(f'\nTest Set - Loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%\n')


# Training loop
for epoch in range(1, epochs + 1):
    train(model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)

# Save the trained model (the file name is the same for both datasets)
torch.save(model.state_dict(), "mnist_cnn.pth")
print("Model saved as mnist_cnn.pth")