Source code for pipeline.parser

import json
import logging
import sys

import yaml
from pydantic import BaseModel, ValidationError

import config as cfg
from pipeline.blocks import ConfigBlocks


[docs] def j_print(data, *args, **kwargs): """ Print the given data in JSON format if possible, otherwise print it as is. For development purposes only. Args: data: The data to be printed. *args: Additional positional arguments to be passed to the print function. **kwargs: Additional keyword arguments to be passed to the print function. """ try: print(json.dumps(data, indent=4), *args, **kwargs) except Exception: print(data, *args, **kwargs)
[docs] def read_config_file(config_file: str) -> dict: """ Parse the configuration file :param config_file: The path to the configuration file :type config_file: str :return: The parsed configuration data :rtype: dict :raises yaml.YAMLError: If there is an error parsing the configuration file """ try: with open(config_file, "r") as f: data = yaml.safe_load(f) return data except yaml.YAMLError as e: logging.error(f"Error parsing the configuration file: {e}") raise e
[docs] def parse_config(config: dict) -> dict[str, BaseModel]: """ Parse and validate the configuration :param config: The configuration dictionary to be parsed and validated :type config: dict :return: A dictionary containing the parsed and validated configuration blocks :rtype: dict[str, BaseModel] """ results = {} unique_essential_blocks = [ConfigBlocks.output.name, ConfigBlocks.general.name] for block in ConfigBlocks: block_config = config.get(block.name.lower()) if block.name.lower() == ConfigBlocks.experiment.name.lower(): block_config = {**block_config, **get_experiment_configs(config)} # get general configs if block.name.lower() not in unique_essential_blocks: general_config = config.get(ConfigBlocks.general.name) if general_config is None: logging.error("General config not found") sys.exit("Parsing config failed") general_block_config = general_config.get(block.name.lower()) block_config = {**general_block_config, **block_config} if general_block_config else block_config try: block_instance: BaseModel = block.value(**block_config) results[block.name.lower()] = block_instance except ValidationError as e: logging.error(f"Error while processing block: {block.name} ") for e in e.errors(): logging.error( f"Error type: \033[1m{e['type']}\033[0m \tfor \033[1m{e['loc']}\033[0m\t| Error message: {e['msg']}" ) sys.exit("Parsing config failed") if block.name == ConfigBlocks.experiment.name and cfg.DEV_DEBUG: j_print(block_instance.model_dump()) print("OK", block.name) return results
[docs] def get_experiment_configs(config: dict) -> dict: """ Get training configuration defined in other blocks. :param config: The configuration dictionary. :type config: dict :return: The experiment configuration dictionary. :rtype: dict """ output_config = config.get(ConfigBlocks.output.name.lower()) general_config = config.get(ConfigBlocks.general.name.lower()) experiment_config = config.get(ConfigBlocks.experiment.name.lower()) if output_config is None: logging.error("Output config not found") sys.exit("Parsing config failed") if general_config is None: logging.error("General config not found") sys.exit("Parsing config failed") if experiment_config is None: logging.error("Experiment config not found") sys.exit("Parsing config failed") experiment_config = { "total_steps": general_config.get("total_steps", 0), "image_size": general_config.get("image_size"), "experiment_id": general_config.get("experiment_id"), "models": general_config.get("models"), "results_dir": output_config.get("results_dir"), "copy_results_to": output_config.get("copy_results_to"), **experiment_config, } return experiment_config