diff options
author | Jeff Heiges <jeff.heiges@colorado.edu> | 2025-03-04 19:07:49 -0700 |
---|---|---|
committer | Jeff Heiges <jeff.heiges@colorado.edu> | 2025-03-04 19:07:49 -0700 |
commit | 144b3858fecde8193fd7e0854904203b5a23acb0 (patch) | |
tree | e671ed663048ca821a1be0f7f6a14a756c5e345f /mnist.py | |
parent | e74082fb6e95bea894b4f8c78c76dafc769dd0bd (diff) |
Added FashionMNIST
Diffstat (limited to 'mnist.py')
-rw-r--r-- | mnist.py | 17 |
1 files changed, 13 insertions, 4 deletions
@@ -3,7 +3,12 @@ import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torchvision import datasets, transforms +import sys +dataset="MNIST" +if(len(sys.argv) > 1): + if(sys.argv[1].upper() == "F" or sys.argv[1].upper() == "FASHION"): + dataset="FashionMNIST" # Define the CNN architecture class Net(nn.Module): def __init__(self): @@ -43,8 +48,12 @@ transform = transforms.Compose([ ]) # Load datasets -train_dataset = datasets.MNIST('data', train=True, download=True, transform=transform) -test_dataset = datasets.MNIST('data', train=False, transform=transform) +if(dataset=="FashionMNIST"): + train_dataset = datasets.FashionMNIST('data', train=True, download=True, transform=transform) + test_dataset = datasets.FashionMNIST('data', train=False, transform=transform) +else: + train_dataset = datasets.MNIST('data', train=True, download=True, transform=transform) + test_dataset = datasets.MNIST('data', train=False, transform=transform) train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True) test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False) @@ -71,11 +80,11 @@ def train(model, device, train_loader, optimizer, epoch): correct += pred.eq(target.view_as(pred)).sum().item() if batch_idx % log_interval == 0 and batch_idx != 0: - print(f'Epoch {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}] Loss: {loss.item():.6f}') + print(f'{dataset}: Epoch {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}] Loss: {loss.item():.6f}') avg_loss = sum(losses) / len(losses) accuracy = 100. * correct / len(train_loader.dataset) - print(f'Epoch {epoch} - Avg Loss: {avg_loss:.6f}, Accuracy: {accuracy:.2f}%') + print(f'{dataset}: Epoch {epoch} - Avg Loss: {avg_loss:.6f}, Accuracy: {accuracy:.2f}%') # Test function def test(model, device, test_loader): |