path: root/lbfgs.py
author     Jeff Heiges <jeff.heiges@colorado.edu>    2025-03-05 16:20:15 -0700
committer  Jeff Heiges <jeff.heiges@colorado.edu>    2025-03-05 16:20:15 -0700
commit     d81318b0707dba6e7939994d5f9011be3be4219c (patch)
tree       062644115367a6cceb6f263a5a4144fd31d984e6 /lbfgs.py
parent     144b3858fecde8193fd7e0854904203b5a23acb0 (diff)
Modifications required for LBFGS to work
Diffstat (limited to 'lbfgs.py')
-rw-r--r--  lbfgs.py  144
1 files changed, 144 insertions, 0 deletions
diff --git a/lbfgs.py b/lbfgs.py
new file mode 100644
index 0000000..05fc55d
--- /dev/null
+++ b/lbfgs.py
@@ -0,0 +1,144 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torchvision import datasets, transforms
+import sys
+
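+# Select the dataset from the command line: pass "F" or "FASHION" for
+# FashionMNIST; the default is MNIST.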
+dataset = "MNIST"
+if len(sys.argv) > 1:
+    if sys.argv[1].upper() in ("F", "FASHION"):
+        dataset = "FashionMNIST"
+
+# Define the CNN architecture with Batch Normalization
+class Net(nn.Module):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.conv1 = nn.Conv2d(1, 6, 5, padding=2)
+        self.bn1 = nn.BatchNorm2d(6)
+        self.conv2 = nn.Conv2d(6, 16, 5)
+        self.bn2 = nn.BatchNorm2d(16)
+        self.fc1 = nn.Linear(400, 120)
+        self.bn3 = nn.BatchNorm1d(120)
+        self.fc2 = nn.Linear(120, 84)
+        self.bn4 = nn.BatchNorm1d(84)
+        self.fc3 = nn.Linear(84, 10)
+
+        # Weight initialization: Kaiming (fan-out, ReLU) for conv layers,
+        # Xavier for linear layers, zero biases where present
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+                if m.bias is not None:  # Only check bias if it exists
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.Linear):
+                nn.init.xavier_normal_(m.weight)
+                if m.bias is not None:  # Only check bias if it exists
+                    nn.init.constant_(m.bias, 0)
+
+    def forward(self, x):
+        x = F.relu(self.bn1(self.conv1(x)))
+        x = F.max_pool2d(x, 2)
+        x = F.relu(self.bn2(self.conv2(x)))
+        x = F.max_pool2d(x, 2)
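+        # 28x28 input -> 28x28 after conv1 (padding=2) -> 14x14 after pooling
+        # -> 10x10 after conv2 -> 5x5 after pooling, so 16 * 5 * 5 = 400 features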
+        x = torch.flatten(x, 1)
+        x = F.relu(self.bn3(self.fc1(x)))
+        x = F.relu(self.bn4(self.fc2(x)))
+        x = self.fc3(x)
+        return x
+
+# Hyperparameters
+batch_size = 128 # Reduced batch size
+test_batch_size = 1000
+epochs = 10
+learning_rate = 0.01 # Further reduced learning rate
+log_interval = 100
+
+# Set device (GPU if available, else CPU)
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+print(device)
+
+# Data preprocessing (normalization only; no augmentation is applied)
+transform = transforms.Compose([
+    transforms.ToTensor(),                      # Convert PIL images to tensors
+    transforms.Normalize((0.1307,), (0.3081,))  # Normalize images
+])
+
+# Load datasets
+if dataset == "FashionMNIST":
+    train_dataset = datasets.FashionMNIST('data', train=True, download=True, transform=transform)
+    test_dataset = datasets.FashionMNIST('data', train=False, transform=transform)
+else:
+    train_dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
+    test_dataset = datasets.MNIST('data', train=False, transform=transform)
+
+train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
+test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)
+
+# Initialize model and optimizer
+model = Net().to(device)
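+# L-BFGS is a quasi-Newton method: each step() re-evaluates the loss and
+# gradients through a closure, possibly several times (bounded by max_eval),
+# and the strong Wolfe line search picks the step length along each direction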
+optimizer = optim.LBFGS(
+    model.parameters(),
+    lr=learning_rate,
+    max_iter=15,
+    max_eval=20,
+    tolerance_change=1e-09,
+    history_size=50,
+    line_search_fn='strong_wolfe'
+)
+criterion = nn.CrossEntropyLoss()
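+# CrossEntropyLoss applies log-softmax internally, so Net.forward returns raw logits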
+
+# Training function
+def train(model, device, train_loader, optimizer, epoch):
+    model.train()
+    losses = []
+    correct = 0
+    for batch_idx, (data, target) in enumerate(train_loader):
+        data, target = data.to(device), target.to(device)
+
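+        # L-BFGS calls this closure repeatedly within a single step() to
+        # re-evaluate the loss and gradients at trial parameter values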
+        def closure():
+            optimizer.zero_grad()
+            output = model(data)
+            loss = criterion(output, target)
+            loss.backward()
+            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)  # Smaller clipping
+            return loss
+
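+        # step() runs up to max_iter L-BFGS iterations on this batch and
+        # returns the loss from the first closure evaluation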
+        loss = optimizer.step(closure)
+        losses.append(loss.item())
+
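+        # Extra forward pass (no gradients) to track training accuracy
+        # after the L-BFGS update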
+        with torch.no_grad():
+            output = model(data)
+            pred = output.argmax(dim=1)
+            correct += pred.eq(target).sum().item()
+
+        if batch_idx % log_interval == 0 and batch_idx != 0:
+            print(f'Epoch {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}] Loss: {loss.item():.6f}')
+
+    avg_loss = sum(losses) / len(losses)
+    accuracy = 100. * correct / len(train_loader.dataset)
+    print(f'Epoch {epoch} - Avg Loss: {avg_loss:.6f}, Accuracy: {accuracy:.2f}%')
+
+# Test function
+def test(model, device, test_loader):
+    model.eval()
+    test_loss = 0
+    correct = 0
+    with torch.no_grad():
+        for data, target in test_loader:
+            data, target = data.to(device), target.to(device)
+            output = model(data)
+            # Weight each batch's mean loss by its size so the division below
+            # yields a true per-sample average
+            test_loss += criterion(output, target).item() * data.size(0)
+            pred = output.argmax(dim=1, keepdim=True)
+            correct += pred.eq(target.view_as(pred)).sum().item()
+
+    test_loss /= len(test_loader.dataset)
+    accuracy = 100. * correct / len(test_loader.dataset)
+    print(f'\nTest Set - Loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%\n')
+
+# Training loop
+for epoch in range(1, epochs + 1):
+    train(model, device, train_loader, optimizer, epoch)
+    test(model, device, test_loader)
+
+# Save the trained model
+torch.save(model.state_dict(), "mnist_cnn.pth")
+print("Model saved as mnist_cnn.pth")