3DFauna_demo / video3d /segmentation.py
kyleleey
first commit
98a77e0
raw
history blame contribute delete
No virus
4.81 kB
import configargparse
import torch
import torch.nn as nn
import torch.utils.data
import torchvision.utils as tvutils
import torchvision.transforms
from video3d.utils.segmentation_transforms import *
from video3d.utils.misc import setup_runtime
from video3d import networks
from video3d.trainer import Trainer
from video3d.dataloaders import SegmentationDataset
class Segmentation:
    """Binary segmentation model: an encoder-decoder predicting per-pixel logits.

    Holds an ``networks.EDDeconv`` network (3 input channels, 1 output channel,
    no output activation) and an Adam optimizer over its trainable parameters.
    The class is handed to ``Trainer`` by the script entry point; the methods
    below are presumably the hooks ``Trainer`` calls (TODO confirm against
    ``video3d.trainer``).
    """

    def __init__(self, cfgs, _):
        """Build the network and optimizer.

        Args:
            cfgs: Mapping of configuration values. Reads ``device`` (default
                ``'cpu'``) and ``lr`` (default ``1e-4``).
            _: Unused positional argument, accepted only to match the
                constructor call made by the caller (TODO confirm what
                ``Trainer`` passes here).
        """
        self.cfgs = cfgs
        self.device = cfgs.get('device', 'cpu')
        # Populated by ``forward`` and consumed by ``backward``.
        self.total_loss = None
        # activation=None: the raw output is used as logits for
        # binary_cross_entropy_with_logits in ``forward``.
        self.net = networks.EDDeconv(cin=3, cout=1, zdim=128, nf=64, activation=None)
        self.optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.net.parameters()),
            lr=cfgs.get('lr', 1e-4),
            betas=(0.9, 0.999),
            weight_decay=5e-4)

    def load_model_state(self, cp):
        """Restore network weights from ``cp["net"]``."""
        self.net.load_state_dict(cp["net"])

    def load_optimizer_state(self, cp):
        """Restore optimizer state from ``cp["optimizer"]``.

        The state dict must go to the optimizer, not the network: a network
        cannot load the optimizer's ``state``/``param_groups`` entries.
        """
        self.optimizer.load_state_dict(cp["optimizer"])

    @staticmethod
    def get_data_loaders(cfgs):
        """Create the training and validation loaders.

        Reads ``batch_size`` (64), ``num_workers`` (4), ``data_dir``
        (``'./data'``), ``image_size`` (64), ``aug_min_resize`` (0.5),
        ``aug_max_resize`` (2.0), ``aug_horizontal_flip`` (0.4),
        ``aug_color_jitter`` (``{}``) and ``aug_grayscale`` (0.2) from ``cfgs``.

        The training split is augmented (random resize relative to
        ``image_size``, horizontal flip, crop to ``image_size``, image-only
        color jitter and grayscale). The validation split is only converted to
        tensors. The two splits use ``sequence_range`` (0, 0.5) and (0.5, 1.0);
        presumably these are fractions of each sequence (TODO confirm in
        ``SegmentationDataset``).

        Returns:
            ``(train_loader, val_loader, None)``. The third slot is always
            ``None``, presumably meaning "no test loader".
        """
        batch_size = cfgs.get('batch_size', 64)
        num_workers = cfgs.get('num_workers', 4)
        data_dir = cfgs.get('data_dir', './data')
        img_size = cfgs.get('image_size', 64)
        min_size = int(img_size * cfgs.get('aug_min_resize', 0.5))
        max_size = int(img_size * cfgs.get('aug_max_resize', 2.0))

        train_transform = Compose([
            RandomResize(min_size, max_size),
            RandomHorizontalFlip(cfgs.get("aug_horizontal_flip", 0.4)),
            RandomCrop(img_size),
            # Photometric augmentations touch only the image, not the mask.
            ImageOnly(torchvision.transforms.ColorJitter(**cfgs.get("aug_color_jitter", {}))),
            ImageOnly(torchvision.transforms.RandomGrayscale(cfgs.get("aug_grayscale", 0.2))),
            ToTensor(),
        ])
        train_loader = torch.utils.data.DataLoader(
            SegmentationDataset(data_dir, is_validation=False, transform=train_transform,
                                sequence_range=(0, 0.5)),
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=True,
        )

        val_transform = Compose([ToTensor()])
        val_loader = torch.utils.data.DataLoader(
            SegmentationDataset(data_dir, is_validation=True, transform=val_transform,
                                sequence_range=(0.5, 1.0)),
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=True,
        )
        return train_loader, val_loader, None

    def get_state_dict(self):
        """Return a checkpoint dict with ``"net"`` and ``"optimizer"`` states."""
        return {
            "net": self.net.state_dict(),
            "optimizer": self.optimizer.state_dict(),
        }

    def to(self, device):
        """Move the network to ``device`` and use it for incoming batches."""
        self.device = device
        self.net.to(device)

    def set_train(self):
        """Put the network in training mode."""
        self.net.train()

    def set_eval(self):
        """Put the network in evaluation mode."""
        self.net.eval()

    def backward(self):
        """Run one optimizer step on the loss computed by the last ``forward``."""
        self.optimizer.zero_grad()
        self.total_loss.backward()
        self.optimizer.step()

    def forward(self, batch, visualize=False):
        """Compute the segmentation loss for one batch.

        Args:
            batch: ``(image, target)`` pair. Assumes ``image`` is (B, 3, H, W)
                in [0, 1] and ``target`` is (B, C, H, W) with the mask in
                channel 0 (TODO confirm against ``SegmentationDataset``).
            visualize: If true, also build uint8 image grids for logging.

        Returns:
            ``metrics`` (``{'loss': loss_tensor}``), or ``(metrics, visuals)``
            when ``visualize`` is true. The loss is also stored in
            ``self.total_loss`` for ``backward``.
        """
        image, target = batch
        # Map [0, 1] inputs to [-1, 1].
        image = image.to(self.device) * 2 - 1
        # Keep only the first target channel, retaining a singleton channel dim
        # so it matches the 1-channel prediction.
        target = target[:, 0, :, :].to(self.device).unsqueeze(1)
        pred = self.net(image)
        self.total_loss = nn.functional.binary_cross_entropy_with_logits(pred, target)
        metrics = {'loss': self.total_loss}

        if not visualize:
            return metrics

        visuals = {
            'rgb': self.image_visual(image, normalize=True, value_range=(-1, 1)),
            'target': self.image_visual(target, normalize=True, value_range=(0, 1)),
            # The network emits logits; squash them to probabilities for display.
            'pred': self.image_visual(torch.sigmoid(pred), normalize=True, value_range=(0, 1)),
        }
        return metrics, visuals

    def visualize(self, logger, total_iter, max_bs=25):
        """No-op; visuals are produced by ``forward(..., visualize=True)``."""
        pass

    def save_results(self, save_dir):
        """No-op; this model saves no results."""
        pass

    def save_scores(self, path):
        """No-op; this model saves no scores."""
        pass

    @staticmethod
    def image_visual(tensor, **kwargs):
        """Tile a batch of images into one uint8 (H, W, 3) grid on the CPU.

        Single-channel inputs are repeated to three channels. ``nrow`` is
        sqrt(batch size) rounded to the nearest integer, giving a roughly
        square grid. Remaining keyword arguments are forwarded to
        ``torchvision.utils.make_grid``. The legacy ``range`` keyword is
        translated to ``value_range``, which newer torchvision releases
        require; an explicit ``value_range`` wins if both are given.
        """
        if "range" in kwargs:
            legacy_range = kwargs.pop("range")
            kwargs.setdefault("value_range", legacy_range)
        if tensor.shape[1] == 1:
            tensor = tensor.repeat(1, 3, 1, 1)
        n = int(tensor.shape[0] ** 0.5 + 0.5)
        grid = tvutils.make_grid(tensor.detach(), nrow=n, **kwargs).permute(1, 2, 0)
        return torch.clamp(grid[:, :, :3] * 255, 0, 255).byte().cpu()
if __name__ == "__main__":
    # Script entry point: read options from the command line and/or a config
    # file (configargparse), prepare the runtime, then train the segmenter.
    arg_parser = configargparse.ArgumentParser(description='Training configurations.')
    arg_parser.add_argument(
        '--config',
        default="config/train_segmentation.yml",
        type=str,
        is_config_file=True,
        help='Specify a config file path',
    )
    arg_parser.add_argument('--gpu', default=1, type=int, help='Specify a GPU device')
    arg_parser.add_argument('--seed', default=0, type=int, help='Specify a random seed')
    # parse_known_args: unrecognized arguments are tolerated and discarded.
    cli_args = arg_parser.parse_known_args()[0]
    runtime_cfgs = setup_runtime(cli_args)
    Trainer(runtime_cfgs, Segmentation).train()