Reputation: 13
I am trying to replicate this repository, which is based on a Google Research paper.
I am running into problems when training the Latent Edge Predictor (LEP) with batch_size > 1. I have made some changes to the code and still have a few open questions; I would be grateful if the community could help me clear them up.
Issue 1:
File: commands/train_LEP.py, line 81
Why is torch.cat([latent_image] * 2) passed as a parameter instead of simply latent_image?
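To show what I mean, here is the call in question next to what I expected, with the latent shapes as I read them (assuming 512x512 images, so 64x64 VAE latents; the shape annotations are my own and not verified output):

# line 81 as it is in the repository: the latents are duplicated along the batch dimension
# latent_image:                  (batch_size,     4, 64, 64)
# torch.cat([latent_image] * 2): (batch_size * 2, 4, 64, 64)
pipe.unet(torch.cat([latent_image] * 2), timesteps, encoder_hidden_states=caption_embedding)

# what I expected instead
pipe.unet(latent_image, timesteps, encoder_hidden_states=caption_embedding)

My guess is that the duplication is meant to match the unconditional + conditional embeddings that encode_text concatenates, but I would like to confirm that.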
Issue 2:
File: commands/train_LEP.py, lines 78 and 93
Why is batch_size * 2 passed as a parameter instead of just batch_size?
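For reference, these are the two statements I believe lines 78 and 93 refer to, first as they appear in the repository and then as I changed them (this is my reading; the exact lines may differ):

# repository version (lines 78 and 93): both use batch_size * 2
noisy_image, noise_level, timesteps = noisy_latent(latent_image, pipe.scheduler, batch_size * 2, num_train_timestep)
pred_edge_map = rearrange(pred_edge_map, "(b w h) c -> b h w c", b=batch_size * 2, h=latent_edge.shape[1], w=latent_edge.shape[2])

# my version: plain batch_size, since latent_image itself only holds batch_size samples
noisy_image, noise_level, timesteps = noisy_latent(latent_image, pipe.scheduler, batch_size, num_train_timestep)
pred_edge_map = rearrange(pred_edge_map, "(b w h) c -> b h w c", b=batch_size, h=latent_edge.shape[1], w=latent_edge.shape[2])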
Issue 3:
I always get a dimensionality error (a mismatch of tensor shapes) when I try to train the LEP with batch_size > 1. I therefore made some changes to the code after noticing that the shapes of noise_level and pred_edge_map did not line up. The tensors now match in size wherever required, but the batch-size problem is still not resolved. Please find the updated code below; after the first script I have also written out the tensor shapes as I understand them for batch_size = 2.
Updated commands/train_LEP.py
import os
import math

from diffusers import StableDiffusionPipeline
from einops import rearrange
import numpy as np
import torch
from tqdm import tqdm
from transformers import CLIPTokenizer
import typer
from typing import List
from typing_extensions import Annotated

from internals.diffusion_utils import encode_img, encode_text, hook_unet, noisy_latent
from internals.latent_edge_predictor import LatentEdgePredictor
from internals.LEP_dataset import LEPDataset


def train_LEP(
    model_id: Annotated[str, typer.Option()] = "CompVis/stable-diffusion-v1-4",
    device: Annotated[str, typer.Option()] = "cuda:1",
    dataset_dir: Annotated[str, typer.Option(help="path to the parent directory of image data")] = "./data/imagenet/imagenet_images",
    edge_map_dir: Annotated[str, typer.Option(help="path to the parent directory of edge map data")] = "./data/imagenet/edge_maps",
    save_path: Annotated[str, typer.Option(help="path to save LEP model")] = "./output/LEP.pt",
    batch_size: Annotated[int, typer.Option(help="batch size for training LEP. Decrease this if OOM occurs.")] = 1,
    training_step: Annotated[int, typer.Option()] = 4633,
    lr: Annotated[float, typer.Option()] = 1e-4,  # not specified in the paper
    num_train_timestep: Annotated[int, typer.Option(help="maximum diffusion timestep")] = 250,  # not specified in the paper
):
    '''
    Train the Latent Edge Predictor.
    '''
    # create output folder
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    # create dataset & loader
    dataset = LEPDataset(dataset_dir, edge_map_dir)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # initialize stable diffusion pipeline.
    # the paper uses stable-diffusion-v1.4
    pipe = StableDiffusionPipeline.from_pretrained(model_id, safety_checker=None, requires_safety_checker=False).to(device)
    unet = pipe.unet
    unet.enable_xformers_memory_efficient_attention()

    # hook the feature_blocks of unet
    feature_blocks = hook_unet(pipe.unet)

    # initialize LEP
    LEP = LatentEdgePredictor(input_dim=9324, output_dim=4, num_layers=10).to(device)

    pipe.unet.eval()
    pipe.vae.eval()
    pipe.text_encoder.eval()

    # are these lines needed?
    pipe.unet.requires_grad_(False)
    pipe.text_encoder.requires_grad_(False)
    LEP.requires_grad_(True)

    # load clip tokenizer
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

    optimizer = torch.optim.Adam(LEP.parameters(), lr=lr)
    criterion = torch.nn.MSELoss()

    train_epochs = 10
    max_train_steps = train_epochs * len(dataloader)
    num_update_steps_per_epoch = len(dataloader)
    num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)

    progress_bar = tqdm(
        range(1, max_train_steps),
        smoothing=0,
        desc="steps",
        position=0, leave=True
    )

    for epoch in range(num_train_epochs):
        progress_bar.set_description_str(f"Epoch {epoch+1}/{num_train_epochs}")
        loss_total = 0

        for step, batch in enumerate(dataloader):
            image, edge_map, caption = batch[0], batch[1], batch[2]
            optimizer.zero_grad()

            # image to latent
            latent_image = encode_img(pipe.vae, image)
            latent_edge = encode_img(pipe.vae, edge_map)
            latent_edge = latent_edge.transpose(1, 3)

            caption_embedding = torch.cat([encode_text(pipe.text_encoder, tokenizer, c) for c in caption])

            noisy_image, noise_level, timesteps = noisy_latent(latent_image, pipe.scheduler, batch_size, num_train_timestep)

            # one reverse step to get the feature blocks
            pipe.unet(torch.cat([latent_image] * 2), timesteps, encoder_hidden_states=caption_embedding)

            activations = []
            for block in feature_blocks:
                activations.append(block.output)
                block.output = None

            features = activations
            assert all([isinstance(acts, torch.Tensor) for acts in features])

            size = latent_image.shape[2:]
            resized_activations = []
            for acts in features:
                acts = torch.nn.functional.interpolate(acts, size=size, mode="bilinear")
                acts = acts[:1]
                acts = acts.transpose(1, 3)
                resized_activations.append(acts)

            intermediate_result = torch.cat(resized_activations, dim=3)
            intermediate_result = intermediate_result.transpose(1, 3)

            pred_edge_map = LEP(intermediate_result, noise_level)
            pred_edge_map = rearrange(pred_edge_map, "(b w h) c -> b h w c", b=batch_size, h=latent_edge.shape[1], w=latent_edge.shape[2])

            # calculate MSE loss
            loss = criterion(pred_edge_map, latent_edge)
            loss.backward()
            optimizer.step()

            current_loss = loss.detach().item()
            loss_total += current_loss
            avr_loss = loss_total / (step + 1)

            if step % 10 == 0:
                progress_bar.set_description(f"Loss: {avr_loss:.3f}")
            if step >= max_train_steps:
                break
            step += 1

        if step >= training_step:
            print(f'Finish to optimize. Save file to {save_path}, Epoch = {epoch+1}')
            path = "./output/LEP-" + str(epoch+1) + ".pt"
            torch.save(LEP.state_dict(), path)
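To make Issue 3 concrete, this is the shape trace I expect for batch_size = 2 with 512x512 images (so 64x64 latents). These are my own annotations, not verified output, and may well be where my misunderstanding lies:

# expected shapes for batch_size = 2 (my own annotations, not verified)
# image                                        -> (2, 3, 512, 512)
# latent_image = encode_img(pipe.vae, image)   -> (2, 4, 64, 64)
# latent_edge (after .transpose(1, 3))         -> (2, 64, 64, 4)
# caption_embedding (uncond + cond per prompt) -> (4, 77, 768)
# noisy_image, noise_level                     -> (2, 4, 64, 64)
# timesteps                                    -> (2,)
# torch.cat([latent_image] * 2)                -> (4, 4, 64, 64)
# so the UNet receives a batch of 4 latents but only 2 timesteps,
# which is where I suspect the mismatch for batch_size > 1 starts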
Updated internals/diffusion_utils.py
from diffusers import AutoencoderKL, UNet2DConditionModel
import torch
from transformers.models.clip import CLIPTextModel, CLIPTokenizer


def encode_img(vae: AutoencoderKL, image: torch.Tensor):
    generator = torch.Generator(vae.device).manual_seed(0)
    latents = vae.encode(image.to(device=vae.device, dtype=vae.dtype)).latent_dist.sample(generator=generator)
    latents = latents * 0.18215
    return latents


def encode_text(text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, text):
    text_input = tokenizer([text], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
    with torch.no_grad():
        text_embeddings = text_encoder(text_input.input_ids.to(text_encoder.device))[0]
    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer([""], padding="max_length", max_length=max_length, return_tensors="pt")
    with torch.no_grad():
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(text_encoder.device))[0]
    # return torch.cat([uncond_embeddings, text_embeddings]).unsqueeze(0)
    return torch.cat([uncond_embeddings, text_embeddings])


def noisy_latent(image, noise_scheduler, batch_size, num_train_timestep):
    timesteps = torch.randint(0, num_train_timestep, (batch_size,), dtype=torch.int64, device=image.device).long()
    noise = torch.randn_like(image, device=image.device)
    alphas_cumprod = noise_scheduler.alphas_cumprod[timesteps.cpu()].to(image.device)
    # print("alpha_prod = ", alphas_cumprod)
    sqrt_alpha_prod = alphas_cumprod ** 0.5
    sqrt_alpha_prod = sqrt_alpha_prod.flatten()
    while len(sqrt_alpha_prod.shape) < len(image.shape):
        sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
    sqrt_one_minus_alpha_prod = (1 - alphas_cumprod) ** 0.5
    sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
    while len(sqrt_one_minus_alpha_prod.shape) < len(image.shape):
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
    noisy_samples = sqrt_alpha_prod * image + sqrt_one_minus_alpha_prod * noise
    noise_level = noisy_samples - (sqrt_alpha_prod * image)
    return noisy_samples, noise_level, timesteps


def hook_unet(unet: UNet2DConditionModel):
    blocks_idx = [0, 1, 2]
    feature_blocks = []

    def hook(module, input, output):
        if isinstance(output, tuple):
            output = output[0]
        if isinstance(output, torch.TensorType):
            feature = output.float()
            setattr(module, "output", feature)
        elif isinstance(output, dict):
            feature = output.sample.float()
            setattr(module, "output", feature)
        else:
            feature = output.float()
            setattr(module, "output", feature)

    # TODO: Check below lines are correct
    # 0, 1, 2 -> (ldm-down) 2, 4, 8
    for idx, block in enumerate(unet.down_blocks):
        if idx in blocks_idx:
            block.register_forward_hook(hook)
            feature_blocks.append(block)

    # ldm-mid 0, 1, 2
    for block in unet.mid_block.attentions + unet.mid_block.resnets:
        block.register_forward_hook(hook)
        feature_blocks.append(block)

    # 0, 1, 2 -> (ldm-up) 2, 4, 8
    for idx, block in enumerate(unet.up_blocks):
        if idx in blocks_idx:
            block.register_forward_hook(hook)
            feature_blocks.append(block)

    return feature_blocks
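To isolate the problem, I sanity-check noisy_latent on its own with something like the snippet below (the import assumes the repository layout; DDPMScheduler is only used to provide an alphas_cumprod table, and the shapes in the comment are what I expect rather than verified output):

import torch
from diffusers import DDPMScheduler

from internals.diffusion_utils import noisy_latent

scheduler = DDPMScheduler(num_train_timesteps=1000)
latent = torch.randn(2, 4, 64, 64)  # stand-in for VAE latents at batch_size = 2

noisy, noise_level, timesteps = noisy_latent(latent, scheduler, batch_size=2, num_train_timestep=250)

# expected: torch.Size([2, 4, 64, 64]) torch.Size([2, 4, 64, 64]) torch.Size([2])
print(noisy.shape, noise_level.shape, timesteps.shape)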
Upvotes: 0
Views: 39