Source code for models.gausian_diffusion

import torch
from denoising_diffusion_pytorch import GaussianDiffusion as GausianDiffusionModel


class GaussianDiffusion(GausianDiffusionModel):
    """Subclass of ``denoising_diffusion_pytorch.GaussianDiffusion``.

    Overrides ``forward`` and ``sample`` so that ``image_size`` is handled the
    same way in both: a single ``int`` means a square image, while a
    ``tuple``/``list`` is read as ``(height, width)``.

    Args:
        *args: Positional arguments forwarded unchanged to the base class.
        **kwargs: Keyword arguments forwarded unchanged to the base class.

    Attributes read here (presumably set by the base-class constructor):
        image_size: Side length (``int``) or ``(height, width)`` sequence.
        channels: Number of channels in the generated / input images.
        num_timesteps: Number of diffusion timesteps; training timesteps are
            drawn uniformly from ``[0, num_timesteps)``.

    Methods:
        forward: Computes the training loss for a batch of images.
        sample: Generates samples from the model.
    """

    def __init__(self, *args, **kwargs):
        # TODO: add direct argument passing with typing instead of a
        # blind *args/**kwargs pass-through.
        super().__init__(*args, **kwargs)

    def _image_hw(self):
        """Return ``self.image_size`` as a ``(height, width)`` tuple.

        A tuple or list is converted to a tuple as-is; any other value is
        treated as a square side length and duplicated.
        """
        size = self.image_size
        if isinstance(size, (tuple, list)):
            return tuple(size)
        return (size, size)

    def forward(self, img, *args, **kwargs):
        """Compute the diffusion training loss for a batch of images.

        One random timestep is drawn per image, the batch is passed through
        ``self.normalize``, and ``self.p_losses`` is evaluated.

        Args:
            img: Image batch; its shape is unpacked as ``(b, c, h, w)``, so it
                must be a 4-D tensor.
            *args: Extra positional arguments forwarded to ``self.p_losses``.
            **kwargs: Extra keyword arguments forwarded to ``self.p_losses``.

        Returns:
            Whatever ``self.p_losses`` returns.

        Raises:
            AssertionError: If the spatial size ``(h, w)`` of ``img`` does not
                match ``self.image_size``.
        """
        b, _, h, w = img.shape
        expected_hw = self._image_hw()
        assert (h, w) == expected_hw, f"height and width of image must be {expected_hw}"

        # Independent timestep per sample, on the same device as the input.
        t = torch.randint(0, self.num_timesteps, (b,), device=img.device, dtype=torch.long)
        return self.p_losses(self.normalize(img), t, *args, **kwargs)

    @torch.inference_mode()
    def sample(self, batch_size=16, return_all_timesteps=False):
        """Generate a batch of samples from the model.

        Uses ``self.ddim_sample`` when ``self.is_ddim_sampling`` is truthy,
        otherwise ``self.p_sample_loop``. The requested output shape is
        ``(batch_size, self.channels, height, width)``, where the spatial
        size comes from ``self.image_size`` (an ``int`` means square).

        Args:
            batch_size: Number of samples to generate.
            return_all_timesteps: Forwarded to the sampling function.

        Returns:
            Whatever the selected sampling function returns.
        """
        sample_fn = self.ddim_sample if self.is_ddim_sampling else self.p_sample_loop
        # Normalizing via _image_hw keeps list-valued image_size working
        # (previously only tuples were expanded here).
        shape = (batch_size, self.channels, *self._image_hw())
        return sample_fn(shape, return_all_timesteps=return_all_timesteps)