Rishi Suman
Rishi Suman

Reputation: 13

Sketch Guided Text to Image Generation

I was trying to replicate this repository based on a Google Research Paper.

I am facing issues with Training the Latent Edge Predictor for batch_size > 1. There are some changes which I made to the code and some issues. I shall be helpful if the community could help me clear them.

Issue 1: File: commands/train_LEP.py Line 81 Why is torch.cat([latent_image] * 2 passed as a parameter instead of simply passing latent_image?

Issue 2 File: commands/train_LEP.py Lines 78 and 93 Why is batch_size*2 passed as a parameter instead of just batch_size?

Issue 3 I always get some dimensionality issue (mismatch of tensor dimensions) when I try to train the LEP with batch_size>1. As a result I made some changes to the code, observing the inaccuracy in tensor dimensions of noise_level, pred_edge_map. This now matches the size of tensors wherever required, but the batch size issue is still not resolved. Please find below the updated code:

UPDATED train_LEP.py

import os
import math
from diffusers import StableDiffusionPipeline
from einops import rearrange
import numpy as np
import torch
from tqdm import tqdm
from transformers import CLIPTokenizer
import typer
from typing import List
from typing_extensions import Annotated

from internals.diffusion_utils import encode_img, encode_text, hook_unet, noisy_latent
from internals.latent_edge_predictor import LatentEdgePredictor
from internals.LEP_dataset import LEPDataset


def train_LEP(
    model_id: Annotated[str, typer.Option()] = "CompVis/stable-diffusion-v1-4",
    device: Annotated[str, typer.Option()] = "cuda:1",
    dataset_dir: Annotated[str, typer.Option(help="path to the parent directory of image data")] = "./data/imagenet/imagenet_images",
    edge_map_dir: Annotated[str, typer.Option(help="path to the parent directory of edge map data")] = "./data/imagenet/edge_maps",
    save_path: Annotated[str, typer.Option(help="path to save LEP model")] = "./output/LEP.pt",
    batch_size: Annotated[int, typer.Option(help="batch size for training LEP. Decrease this if OOM occurs.")] = 1,
    training_step: Annotated[int, typer.Option()] = 4633,
    lr: Annotated[float, typer.Option()] = 1e-4, # not specified in the paper
    num_train_timestep: Annotated[int, typer.Option(help="maximum diffusion timestep")] = 250, # not specified in the paper
):
    '''
    Train the Latent Edge Predictor.
    '''
    # create output folder
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    # create dataset & loader
    dataset = LEPDataset(dataset_dir, edge_map_dir)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # initialize stable diffusion pipeline.
    # the paper use stable-diffusion-v1.4
    pipe = StableDiffusionPipeline.from_pretrained(model_id, safety_checker=None, requires_safety_checker = False).to(device)

    unet = pipe.unet
    unet.enable_xformers_memory_efficient_attention()

    # hook the feature_blocks of unet
    feature_blocks = hook_unet(pipe.unet)       
    
    # initialize LEP
    LEP = LatentEdgePredictor(input_dim=9324, output_dim=4, num_layers=10).to(device)

    pipe.unet.eval()
    pipe.vae.eval()
    pipe.text_encoder.eval()

    # need this lines?
    pipe.unet.requires_grad_(False)
    pipe.text_encoder.requires_grad_(False)
    LEP.requires_grad_(True)

    # load clip tokenizer
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

    optimizer = torch.optim.Adam(LEP.parameters(), lr=lr)
    criterion = torch.nn.MSELoss()

    train_epochs = 10
    max_train_steps = train_epochs * len(dataloader)
    num_update_steps_per_epoch = len(dataloader)
    num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
    progress_bar = tqdm(
        range(1, max_train_steps),
        smoothing=0,
        desc="steps",
        position=0, leave=True
    )

    for epoch in range(num_train_epochs):
        progress_bar.set_description_str(f"Epoch {epoch+1}/{num_train_epochs}")
        loss_total = 0
        for step, batch in enumerate(dataloader):
            image, edge_map, caption = batch[0], batch[1], batch[2]
            optimizer.zero_grad()
            
            # image to latent
            latent_image = encode_img(pipe.vae, image)
            latent_edge = encode_img(pipe.vae, edge_map)
            latent_edge = latent_edge.transpose(1,3)
            
            caption_embedding = torch.cat([encode_text(pipe.text_encoder, tokenizer, c) for c in caption])
            noisy_image, noise_level, timesteps = noisy_latent(latent_image, pipe.scheduler, batch_size , num_train_timestep)
           
            # one reverse step to get the feature blocks
            pipe.unet(torch.cat([latent_image] * 2), timesteps, encoder_hidden_states=caption_embedding)

            activations = []
            for block in feature_blocks:
                activations.append(block.output)
                block.output = None
        
            features = activations
        
            assert all([isinstance(acts, torch.Tensor) for acts in features])
            size = latent_image.shape[2:]
            resized_activations = []
            for acts in features:
                acts = torch.nn.functional.interpolate(acts, size=size, mode="bilinear")
                acts = acts[:1]
                acts = acts.transpose(1,3)
                resized_activations.append(acts)
            
            intermediate_result = torch.cat(resized_activations, dim=3)
            intermediate_result = intermediate_result.transpose(1,3)
            
            pred_edge_map = LEP(intermediate_result, noise_level)
            pred_edge_map = rearrange(pred_edge_map, "(b w h) c -> b h w c", b=batch_size, h=latent_edge.shape[1], w=latent_edge.shape[2])
            
            # calculate MSE loss
            loss = criterion(pred_edge_map, latent_edge)
            loss.backward()

            optimizer.step()

            current_loss = loss.detach().item()
            loss_total += current_loss
            avr_loss = loss_total / (step + 1)

            if step % 10 == 0:
                progress_bar.set_description(f"Loss: {avr_loss:.3f}")

            if step >= max_train_steps:
                break

            step += 1
        
        if step >= training_step:
            print(f'Finish to optimize. Save file to {save_path}, Epoch = {epoch+1}')
            path = "./output/LEP-" + str(epoch+1) + ".pt"
            torch.save(LEP.state_dict(), path)

Updated internals/diffusion_utils.py

from diffusers import AutoencoderKL, UNet2DConditionModel
import torch
from transformers.models.clip import CLIPTextModel, CLIPTokenizer


def encode_img(vae: AutoencoderKL, image: torch.Tensor):
    generator = torch.Generator(vae.device).manual_seed(0)
    latents = vae.encode(image.to(device=vae.device, dtype=vae.dtype)).latent_dist.sample(generator=generator)
    latents = latents * 0.18215
    return latents


def encode_text(text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, text):
    text_input = tokenizer([text], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
    with torch.no_grad():
        text_embeddings = text_encoder(text_input.input_ids.to(text_encoder.device))[0]
    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer([""], padding="max_length", max_length=max_length, return_tensors="pt")
    with torch.no_grad():
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(text_encoder.device))[0]   
    # return torch.cat([uncond_embeddings, text_embeddings]).unsqueeze(0)
    return torch.cat([uncond_embeddings, text_embeddings])


def noisy_latent(image, noise_scheduler, batch_size, num_train_timestep):
    timesteps = torch.randint(0, num_train_timestep, (batch_size,), dtype=torch.int64, device=image.device).long()
    noise = torch.randn_like(image, device=image.device)

    alphas_cumprod = noise_scheduler.alphas_cumprod[timesteps.cpu()].to(image.device)
    # print("alpha_prod = ", alphas_cumprod)
    sqrt_alpha_prod = alphas_cumprod ** 0.5
    sqrt_alpha_prod = sqrt_alpha_prod.flatten()
    
    while len(sqrt_alpha_prod.shape) < len(image.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

    sqrt_one_minus_alpha_prod = (1 - alphas_cumprod) ** 0.5
    sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
    
    while len(sqrt_one_minus_alpha_prod.shape) < len(image.shape):
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

    noisy_samples = sqrt_alpha_prod * image + sqrt_one_minus_alpha_prod * noise
    noise_level = noisy_samples - (sqrt_alpha_prod * image)

    return noisy_samples, noise_level, timesteps
    
def hook_unet(unet: UNet2DConditionModel):
    blocks_idx = [0, 1, 2]
    feature_blocks = []
    def hook(module, input, output):
        if isinstance(output, tuple):
            output = output[0]
        
        if isinstance(output, torch.TensorType):
            feature = output.float()
            setattr(module, "output", feature)
        elif isinstance(output, dict): 
            feature = output.sample.float()
            setattr(module, "output", feature)
        else: 
            feature = output.float()
            setattr(module, "output", feature)
    
    # TODO: Check below lines are correct

    # 0, 1, 2 -> (ldm-down) 2, 4, 8
    for idx, block in enumerate(unet.down_blocks):
        if idx in blocks_idx:
            block.register_forward_hook(hook)
            feature_blocks.append(block) 
            
    # ldm-mid 0, 1, 2
    for block in unet.mid_block.attentions + unet.mid_block.resnets:
        block.register_forward_hook(hook)
        feature_blocks.append(block) 
    
    # 0, 1, 2 -> (ldm-up) 2, 4, 8
    for idx, block in enumerate(unet.up_blocks):
        if idx in blocks_idx:
            block.register_forward_hook(hook)
            feature_blocks.append(block)
     
    return feature_blocks

Upvotes: 0

Views: 39

Answers (0)

Related Questions