import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import v2

from early_stopping import EarlyStopping

MAX_EPOCHS = 100
BATCH_SIZE = 64

# Shared preprocessing for both splits: ToImage converts each sample to an
# image tensor, and ToDtype(..., scale=True) casts to float32 while rescaling
# the integer pixel range to [0, 1].
_TRANSFORM = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=_TRANSFORM,
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=_TRANSFORM,
)

# NOTE(review): the training loader does not shuffle, so every epoch sees the
# batches in the same order; consider shuffle=True for train_dataloader.
train_dataloader = DataLoader(training_data, batch_size=BATCH_SIZE)
test_dataloader = DataLoader(test_data, batch_size=BATCH_SIZE)

class NeuralNetwork(nn.Module):
    """Fully connected 10-class classifier for 28x28 inputs.

    Each sample is flattened (all dims after the batch dim) to 28 * 28 = 784
    features, passed through two 512-unit ReLU hidden layers, and mapped to
    10 raw (unnormalized) logits.
    """

    def __init__(self):
        # Must be the real ``__init__`` calling ``super().__init__()``: a method
        # merely named ``init`` is never invoked, so nn.Module would not be
        # initialised and the submodules below could not be registered.
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        """Return logits of shape ``(batch, 10)`` for the input batch ``x``."""
        x = self.flatten(x)
        return self.linear_relu_stack(x)

def train_loop(dataloader, model, loss_fn, optimizer):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        dataloader: Iterable of ``(X, y)`` batches (e.g. a ``DataLoader``);
            it does not need to support ``len()``.
        model: Module to train; it is switched to training mode.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar loss tensor.
        optimizer: Optimizer over ``model``'s parameters.

    Returns:
        The arithmetic mean of the per-batch loss values (a Python float).

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.train()

    # Drop any gradients left over from before this call so they are not
    # folded into the first optimizer step.
    optimizer.zero_grad()

    running_loss = 0.0
    num_batches = 0
    for X, y in dataloader:
        pred = model(X)
        loss = loss_fn(pred, y)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        running_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError("dataloader yielded no batches")
    return running_loss / num_batches

def test_loop(dataloader, model, loss_fn=None):
    """Evaluate ``model`` on ``dataloader``, print accuracy/loss, return mean loss.

    Args:
        dataloader: Iterable of ``(X, y)`` batches, where ``y`` holds the
            integer class label of each sample along dim 0.
        model: Module to evaluate; it is switched to eval mode and run under
            ``torch.no_grad()``.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar loss tensor.
            If omitted, falls back to the module-level ``criterion`` (which is
            only defined when this file runs as a script), matching the
            original behavior.

    Returns:
        Mean per-batch loss (a Python float).

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    if loss_fn is None:
        # Legacy fallback for callers that rely on the global; prefer passing
        # loss_fn explicitly so this works when imported as a module.
        loss_fn = criterion

    model.eval()
    num_batches = 0
    num_samples = 0
    test_loss, correct = 0.0, 0.0

    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).sum().item()
            num_batches += 1
            # Count samples actually seen, so accuracy stays correct even when
            # the loader drops a trailing partial batch.
            num_samples += len(y)

    if num_samples == 0:
        raise ValueError("dataloader yielded no samples")

    test_loss /= num_batches
    correct /= num_samples
    print(
        f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
    )
    return test_loss

if __name__ == "__main__":
    model = NeuralNetwork()

    learning_rate = 1e-3

    # All-ones class weights: with the default 'mean' reduction this is
    # numerically the same as an unweighted CrossEntropyLoss. The tensor is
    # kept because it is also handed to EarlyStopping below.
    weight_ratio = torch.ones(10, dtype=torch.float32)
    # Keep `criterion` as a module-level name: test_loop falls back to it when
    # no loss function is passed.
    criterion = nn.CrossEntropyLoss(weight=weight_ratio)

    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

    # NOTE(review): early stopping monitors the *test* split, so the stopping
    # point is selected on test data; a separate validation split would keep
    # the test loss an unbiased estimate.
    early_stopping = EarlyStopping(
        patience=5, path="model.pth", class_weights=weight_ratio
    )
    for t in range(MAX_EPOCHS):
        train_loss = train_loop(train_dataloader, model, criterion, optimizer)
        test_loss = test_loop(test_dataloader, model)
        print(f"Epoch {t + 1} Train loss {train_loss} Test loss {test_loss}")
        early_stopping(test_loss, model)
        if early_stopping.early_stop:
            print("Early stopping triggered! Training halted.")
            break

    print("Done!")