# Source code for pipeline.blocks.experiment_dto

from pathlib import Path
from typing import Literal, Optional, Union

from pydantic import BaseModel, ConfigDict, computed_field, field_validator

from pipeline.blocks.models_dto import ModelsDTO

# Loop-stage name constants: the "name" field of every entry in
# ExperimentDTO.loop must equal one of these values.
TRAIN_GENERATOR = "train_generator"
GENERATE_SAMPLES = "generate_samples"
VALIDATE = "validate"
FOO = "foo"


class LoopObjectDTO(BaseModel):
    """Base DTO for a single stage of an experiment's training loop."""

    model_config = ConfigDict(protected_namespaces=())

    # Stage identifier; must be one of the loop-name constants above.
    name: str
    # Whether this stage runs on repeated loop passes — presumably; confirm
    # against the loop runner.
    repeat: bool = True

    @field_validator("name")
    def name_validator(cls, v):
        """Accept only known stage names; normalise via str.title()."""
        known = (TRAIN_GENERATOR, GENERATE_SAMPLES, VALIDATE, FOO)
        if v in known:
            return v.title()
        raise ValueError("Undefined training loop object")
class TrainGeneratorDTO(LoopObjectDTO):
    """Loop stage that trains the generator model."""

    batch_size: int
    lr: float
    save_and_sample_every: int
    num_steps: int
    results_dir: str
    gradient_accumulate_every: int = 4
    experiment_id: Optional[str] = None
    copy_results_to: Optional[str] = None
    start_from_checkpoint: Optional[str] = None  # TODO implement
    dataset_split_type: Literal["train", "val", "test"] = "train"
    diagnosis: Literal["precancerous", "fluid", "benign", "reference"]

    @field_validator("name")
    def name_validator(cls, v):
        """Restrict name to the train-generator stage.

        Consistency fix: use the shared constant in the error message
        (the original hardcoded the literal 'train_generator', unlike
        every sibling class).
        """
        if v != TRAIN_GENERATOR:
            raise ValueError(f"name must be {TRAIN_GENERATOR}")
        return v.title()
class GenerateSamplesDTO(LoopObjectDTO):
    """Loop stage that generates samples from a trained generator checkpoint."""

    num_samples: int
    batch_size: int
    wandb: bool = True
    model_version: str  # checkpoint file stem (".pt" is appended below)
    base_on: str        # presumably the experiment whose checkpoint is used — confirm
    results_dir: str
    copy_results_to: Optional[str] = None
    relative_dataset_results_dir: str = "dataset"

    @computed_field
    @property
    def checkpoint_path(self) -> str:
        """Path of the .pt checkpoint samples are generated from."""
        checkpoint = Path(self.results_dir, self.base_on, f"{self.model_version}.pt")
        return str(checkpoint)

    @computed_field
    @property
    def generete_samples_dir(self) -> str:
        """Directory the generated samples are written to."""
        # NOTE(review): "generete" is a typo, but this computed field is part
        # of the serialized model, so the name is kept for compatibility.
        target = Path(self.results_dir, self.relative_dataset_results_dir, self.base_on)
        return str(target)

    @field_validator("name")
    def name_validator(cls, v):
        """Restrict name to the generate-samples stage."""
        if v == GENERATE_SAMPLES:
            return v.title()
        raise ValueError(f"name must be {GENERATE_SAMPLES}")
class FooDTO(LoopObjectDTO):
    """Minimal LoopObjectDTO subclass carrying a single `foo` field."""

    foo: str = "foo"

    @field_validator("name")
    def name_validator(cls, v):
        """Restrict name to the foo stage."""
        if v == FOO:
            return v.title()
        raise ValueError(f"name must be {FOO}")
class ClassificationDTO(BaseModel):
    """Configuration for the classifier run used during validation."""

    epochs: int
    lr: float
    loss_multiply: float = 1.0
    # presumably a train/val split ratio — confirm against the consumer.
    # (Mutable defaults are safe here: pydantic deep-copies field defaults.)
    ratio: list[float] = [0.8, 0.2]
    class_names: list[str] = ["fluid", "benign", "precancerous", "reference"]
    train_data_type: Literal["real", "synthetic"] = "synthetic"
    train_dataset_dir: str
    val_dataset_dir: Optional[str] = None
    test_dataset_dir: str
    loss_fn: str = "cross_entropy"
    num_workers: int = 4
    batch_size: int = 32
    log_every_n_steps: int = 10
    logger_tags: Optional[list[str]] = None
    logger_experiment_name: Optional[str] = None
    offline: bool = False
    results_dir: str

    @computed_field
    @property
    def num_classes(self) -> int:
        """Number of target classes, derived from class_names."""
        return len(self.class_names)
class ValidateDTO(LoopObjectDTO):
    """Loop stage that validates results, optionally via a classifier run."""

    results_dir: str
    copy_results_to: Optional[str] = None
    classification: Optional[ClassificationDTO] = None

    @field_validator("name")
    def name_validator(cls, v):
        """Restrict name to the validate stage."""
        if v != VALIDATE:
            raise ValueError(f"name must be {VALIDATE}")
        return v.title()

    @field_validator("classification", mode="before")
    @classmethod
    def classification_validator(cls, v, values):
        """Build the ClassificationDTO, injecting the parent's results_dir.

        Bug fix: the field is Optional, but the original unconditionally did
        ``v["results_dir"] = ...``, raising TypeError when the caller passes
        ``classification=None`` explicitly (or an already-built instance).
        Also copies the input dict instead of mutating the caller's data.
        """
        if v is None:
            return None
        if isinstance(v, ClassificationDTO):
            return v
        data = dict(v)
        data["results_dir"] = values.data.get("results_dir", ".results")
        return ClassificationDTO(**data)
class ExperimentDTO(BaseModel):
    """Top-level experiment config: models plus an ordered loop of stages."""

    total_steps: int
    image_size: list[int]
    models: ModelsDTO
    results_dir: str
    copy_results_to: Optional[str] = None
    loop: list[Union[TrainGeneratorDTO, GenerateSamplesDTO, ValidateDTO, FooDTO]]

    @field_validator("loop", mode="before")
    @classmethod
    def loop_validator(cls, v, values):
        """Parse raw loop entries into their concrete stage DTOs.

        Propagates the experiment-level results_dir / copy_results_to into
        every entry before construction.

        Bug fix: the original silently dropped entries with an unknown
        "name" (the loop shrank with no error); this now raises ValueError.
        Also copies each entry dict instead of mutating the caller's data,
        and dispatches via a mapping rather than an if-chain.
        """
        dto_by_name = {
            TRAIN_GENERATOR: TrainGeneratorDTO,
            GENERATE_SAMPLES: GenerateSamplesDTO,
            VALIDATE: ValidateDTO,
            FOO: FooDTO,
        }
        parsed = []
        for raw in v:
            entry = dict(raw)  # do not mutate the caller's dict
            entry["results_dir"] = values.data.get("results_dir", ".results")
            entry["copy_results_to"] = values.data.get("copy_results_to", None)
            dto_cls = dto_by_name.get(entry["name"])
            if dto_cls is None:
                raise ValueError("Undefined training loop object")
            parsed.append(dto_cls(**entry))
        return parsed