summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--mnist.py17
1 files changed, 13 insertions, 4 deletions
diff --git a/mnist.py b/mnist.py
index dcb6dad..3ec8da1 100644
--- a/mnist.py
+++ b/mnist.py
@@ -3,7 +3,12 @@ import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
+import sys
+dataset="MNIST"
+if(len(sys.argv) > 1):
+ if(sys.argv[1].upper() == "F" or sys.argv[1].upper() == "FASHION"):
+ dataset="FashionMNIST"
# Define the CNN architecture
class Net(nn.Module):
def __init__(self):
@@ -43,8 +48,12 @@ transform = transforms.Compose([
])
# Load datasets
-train_dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
-test_dataset = datasets.MNIST('data', train=False, transform=transform)
+if(dataset=="FashionMNIST"):
+ train_dataset = datasets.FashionMNIST('data', train=True, download=True, transform=transform)
+ test_dataset = datasets.FashionMNIST('data', train=False, transform=transform)
+else:
+ train_dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
+ test_dataset = datasets.MNIST('data', train=False, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)
@@ -71,11 +80,11 @@ def train(model, device, train_loader, optimizer, epoch):
correct += pred.eq(target.view_as(pred)).sum().item()
if batch_idx % log_interval == 0 and batch_idx != 0:
- print(f'Epoch {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}] Loss: {loss.item():.6f}')
+ print(f'{dataset}: Epoch {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}] Loss: {loss.item():.6f}')
avg_loss = sum(losses) / len(losses)
accuracy = 100. * correct / len(train_loader.dataset)
- print(f'Epoch {epoch} - Avg Loss: {avg_loss:.6f}, Accuracy: {accuracy:.2f}%')
+ print(f'{dataset}: Epoch {epoch} - Avg Loss: {avg_loss:.6f}, Accuracy: {accuracy:.2f}%')
# Test function
def test(model, device, test_loader):