diff options
-rw-r--r-- | mnist.py | 17 |
1 files changed, 13 insertions, 4 deletions
@@ -3,7 +3,12 @@ import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torchvision import datasets, transforms +import sys +dataset="MNIST" +if(len(sys.argv) > 1): + if(sys.argv[1].upper() == "F" or sys.argv[1].upper() == "FASHION"): + dataset="FashionMNIST" # Define the CNN architecture class Net(nn.Module): def __init__(self): @@ -43,8 +48,12 @@ transform = transforms.Compose([ ]) # Load datasets -train_dataset = datasets.MNIST('data', train=True, download=True, transform=transform) -test_dataset = datasets.MNIST('data', train=False, transform=transform) +if(dataset=="FashionMNIST"): + train_dataset = datasets.FashionMNIST('data', train=True, download=True, transform=transform) + test_dataset = datasets.FashionMNIST('data', train=False, transform=transform) +else: + train_dataset = datasets.MNIST('data', train=True, download=True, transform=transform) + test_dataset = datasets.MNIST('data', train=False, transform=transform) train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True) test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False) @@ -71,11 +80,11 @@ def train(model, device, train_loader, optimizer, epoch): correct += pred.eq(target.view_as(pred)).sum().item() if batch_idx % log_interval == 0 and batch_idx != 0: - print(f'Epoch {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}] Loss: {loss.item():.6f}') + print(f'{dataset}: Epoch {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}] Loss: {loss.item():.6f}') avg_loss = sum(losses) / len(losses) accuracy = 100. * correct / len(train_loader.dataset) - print(f'Epoch {epoch} - Avg Loss: {avg_loss:.6f}, Accuracy: {accuracy:.2f}%') + print(f'{dataset}: Epoch {epoch} - Avg Loss: {avg_loss:.6f}, Accuracy: {accuracy:.2f}%') # Test function def test(model, device, test_loader): |