font-identifier / train_font_identifier.py
Gabor Cselle
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
from tqdm import tqdm
# Directory containing 'train' and 'test' subfolders, one folder per font class (ImageFolder layout)
data_dir = './train_test_images'
# Transformations for the image data
data_transforms = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),  # Grayscale, replicated to 3 channels for ResNet
    transforms.Resize((224, 224)),  # Resize to the model's expected input size
    transforms.ToTensor(),  # Convert images to PyTorch tensors
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize with ImageNet stats
])
# Create datasets
image_datasets = {
    x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms)
    for x in ['train', 'test']
}
# Create dataloaders
dataloaders = {
    'train': torch.utils.data.DataLoader(image_datasets['train'], batch_size=4, shuffle=True),
    'test': torch.utils.data.DataLoader(image_datasets['test'], batch_size=4, shuffle=False)  # No need to shuffle for evaluation
}
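# Optional sanity check (a quick sketch, not required for training): confirm the
# font classes were discovered from the folder names and a batch has the expected shape.
print(image_datasets['train'].classes)
sample_inputs, sample_targets = next(iter(dataloaders['train']))
print(sample_inputs.shape)  # Expected: torch.Size([4, 3, 224, 224])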
# Define the model: start from a ResNet-18 pretrained on ImageNet
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
# Modify the last fully connected layer to match the number of font classes you have
num_classes = len(image_datasets['train'].classes)
model.fc = nn.Linear(model.fc.in_features, num_classes)
# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())
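# Optional (sketch): a learning-rate scheduler such as StepLR can decay the LR
# during training; the step_size and gamma below are illustrative, not tuned values.
# If enabled, call scheduler.step() once per epoch at the end of the training loop.
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)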
# Function to perform a training step with progress bar
def train_step(model, data_loader, criterion, optimizer):
    model.train()
    total_loss = 0
    progress_bar = tqdm(data_loader, desc='Training', leave=True)
    for inputs, targets in progress_bar:
        optimizer.zero_grad()  # Reset gradients from the previous step
        outputs = model(inputs)  # Forward pass
        loss = criterion(outputs, targets)
        loss.backward()  # Backpropagate
        optimizer.step()  # Update weights
        total_loss += loss.item()
        progress_bar.set_postfix(loss=loss.item())
    progress_bar.close()
    return total_loss / len(data_loader)
# Function to perform a validation step with progress bar
def validate(model, data_loader, criterion):
    model.eval()
    total_loss = 0
    correct = 0
    progress_bar = tqdm(data_loader, desc='Validation', leave=False)
    with torch.no_grad():
        for inputs, targets in progress_bar:
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            total_loss += loss.item()
            _, predicted = torch.max(outputs, 1)  # Class with the highest logit
            correct += (predicted == targets).sum().item()
            progress_bar.set_postfix(loss=loss.item())
    progress_bar.close()
    return total_loss / len(data_loader), correct / len(data_loader.dataset)
# Training loop; per-batch progress bars come from train_step and validate
num_epochs = 25  # Replace with the number of epochs you'd like to train for
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    train_loss = train_step(model, dataloaders["train"], criterion, optimizer)
    val_loss, val_accuracy = validate(model, dataloaders["test"], criterion)
    print(f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")
# Save the model to disk
torch.save(model.state_dict(), 'font_identifier_model.pth')
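# Minimal sketch of reloading the saved weights for inference: the architecture
# must be rebuilt with the same num_classes before loading the state dict.
# 'sample_image.png' is a placeholder path, not a file produced by this script.
#
# from PIL import Image
# inference_model = models.resnet18()
# inference_model.fc = nn.Linear(inference_model.fc.in_features, num_classes)
# inference_model.load_state_dict(torch.load('font_identifier_model.pth'))
# inference_model.eval()
# image = data_transforms(Image.open('sample_image.png')).unsqueeze(0)  # Add a batch dimension
# with torch.no_grad():
#     pred_idx = inference_model(image).argmax(dim=1).item()
# print(image_datasets['train'].classes[pred_idx])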