diff options
Diffstat (limited to 'lbfgs.py')
-rw-r--r-- | lbfgs.py | 144 |
1 files changed, 144 insertions, 0 deletions
diff --git a/lbfgs.py b/lbfgs.py new file mode 100644 index 0000000..05fc55d --- /dev/null +++ b/lbfgs.py @@ -0,0 +1,144 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torchvision import datasets, transforms +import sys + +dataset="MNIST" +if(len(sys.argv) > 1): + if(sys.argv[1].upper() == "F" or sys.argv[1].upper() == "FASHION"): + dataset="FashionMNIST" + +# Define the CNN architecture with Batch Normalization +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(1, 6, 5, padding=2) + self.bn1 = nn.BatchNorm2d(6) + self.conv2 = nn.Conv2d(6, 16, 5) + self.bn2 = nn.BatchNorm2d(16) + self.fc1 = nn.Linear(400, 120) + self.bn3 = nn.BatchNorm1d(120) + self.fc2 = nn.Linear(120, 84) + self.bn4 = nn.BatchNorm1d(84) + self.fc3 = nn.Linear(84, 10) + + # Weight initialization + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None: # Only check bias if it exists + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.xavier_normal_(m.weight) + if m.bias is not None: # Only check bias if it exists + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = F.relu(self.bn1(self.conv1(x))) + x = F.max_pool2d(x, 2) + x = F.relu(self.bn2(self.conv2(x))) + x = F.max_pool2d(x, 2) + x = torch.flatten(x, 1) + x = F.relu(self.bn3(self.fc1(x))) + x = F.relu(self.bn4(self.fc2(x))) + x = self.fc3(x) + return x +# Hyperparameters +batch_size = 128 # Reduced batch size +test_batch_size = 1000 +epochs = 10 +learning_rate = 0.01 # Further reduced learning rate +log_interval = 100 + +# Set device (GPU if available, else CPU) +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +print(device) + +# Data preprocessing and augmentation +transform = transforms.Compose([ + transforms.ToTensor(), # Convert PIL images to tensors + transforms.Normalize((0.1307,), (0.3081,)) # Normalize images +]) + +# Load datasets +if(dataset=="FashionMNIST"): + train_dataset = datasets.FashionMNIST('data', train=True, download=True, transform=transform) + test_dataset = datasets.FashionMNIST('data', train=False, transform=transform) +else: + train_dataset = datasets.MNIST('data', train=True, download=True, transform=transform) + test_dataset = datasets.MNIST('data', train=False, transform=transform) + +train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True) +test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False) + +# Initialize model and optimizer +model = Net().to(device) +optimizer = optim.LBFGS( + model.parameters(), + lr=learning_rate, + max_iter=15, + max_eval=20, + tolerance_change=1e-09, + history_size=50, + line_search_fn='strong_wolfe' +) +criterion = nn.CrossEntropyLoss() + +# Training function +def train(model, device, train_loader, optimizer, epoch): + model.train() + losses = [] + correct = 0 + for batch_idx, (data, target) in enumerate(train_loader): + data, target = data.to(device), target.to(device) + + def closure(): + optimizer.zero_grad() + output = model(data) + loss = criterion(output, target) + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1) # Smaller clipping + return loss + + loss = optimizer.step(closure) + losses.append(loss.item()) + + with torch.no_grad(): + output = model(data) + pred = output.argmax(dim=1) + correct += pred.eq(target).sum().item() + + if batch_idx % log_interval == 0 and batch_idx != 0: + print(f'Epoch {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}] Loss: {loss.item():.6f}') + + avg_loss = sum(losses) / len(losses) + accuracy = 100. * correct / len(train_loader.dataset) + print(f'Epoch {epoch} - Avg Loss: {avg_loss:.6f}, Accuracy: {accuracy:.2f}%') + +# Test function +def test(model, device, test_loader): + model.eval() + test_loss = 0 + correct = 0 + with torch.no_grad(): + for data, target in test_loader: + data, target = data.to(device), target.to(device) + output = model(data) + test_loss += criterion(output, target).item() + pred = output.argmax(dim=1, keepdim=True) + correct += pred.eq(target.view_as(pred)).sum().item() + + test_loss /= len(test_loader.dataset) + accuracy = 100. * correct / len(test_loader.dataset) + print(f'\nTest Set - Loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%\n') + +# Training loop +for epoch in range(1, epochs + 1): + train(model, device, train_loader, optimizer, epoch) + test(model, device, test_loader) + +# Save the trained model +torch.save(model.state_dict(), "mnist_cnn.pth") +print("Model saved as mnist_cnn.pth") |