Source code for run_experiment

import torch
from denoising_diffusion_pytorch import Unet
from torchvision import transforms as T

import config as cfg
from datasets import OpthalAnonymizedDataset, get_csv_dataset
from models import GaussianDiffusion
from trainers import Trainer
from utils import copy_results_directory
from utils.transforms import HorizontalCenterCrop

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

IMAGE_SIZE = (256, 512)


[docs] def main(): df = get_csv_dataset(cfg.DATASET_FILE_PATH)["train"] DIAGNOSIS = "reference" # ["precancerous", "fluid", "benign", "reference"] transform = T.Compose( [ HorizontalCenterCrop(512), T.Resize(IMAGE_SIZE), T.RandomHorizontalFlip(), T.Grayscale(num_output_channels=1), T.ToTensor(), ] ) dataset = OpthalAnonymizedDataset( diagnosis=DIAGNOSIS, df=df, images_dir=cfg.DATASET_FILE_PATH.parent / "images", transform=transform, convert_image_to="L", ) model = Unet( dim=64, dim_mults=(1, 2, 4, 8), channels=1, ) model.to(device=DEVICE) diffusion = GaussianDiffusion( model, image_size=IMAGE_SIZE, timesteps=1000, # number of steps # # timesteps=2, # number of steps # loss_type = 'l1' # L1 or L2 ) diffusion.to(device=DEVICE) print(f"Model loaded to {DEVICE} device.") # data_path = "./data/input_datasets/with_fluid_eyes_512x256" # data_path = "./data/input_datasets/healthy_eyes_512x256" experiment_id = "0003" trainer = Trainer( # noqa : F841 diffusion, str(cfg.DATASET_FILE_PATH.parent / "images"), dataset=dataset, train_batch_size=4, train_lr=2e-4, save_and_sample_every=2000, # save_and_sample_every=10, results_folder=f"./.results/{experiment_id}/{DIAGNOSIS}", train_num_steps=100_000, # total training steps gradient_accumulate_every=4, # gradient accumulation steps ema_decay=0.995, # exponential moving average decay amp=True, # turn on mixed precision num_samples=9, # number of samples to save calculate_fid=False, # calculate FID during sampling tracker="wandb", tracker_kwargs={ "tags": [DIAGNOSIS, "opthal_anonymized"], "mode": "online", }, ) # trainer.load(".results/0002/reference/model-30.pt") trainer.train() copy_results_directory( f"./.results/{experiment_id}/{DIAGNOSIS}", f"/home/wmi/OneDrive/General/results/05.2024/.results/{experiment_id}/", )
if __name__ == "__main__": main()