u1ug

Shape mismatch while training diffusers/UNet2DConditionModel

I am trying to train a diffusers UNet2DConditionModel from scratch. Currently I get an error during the UNet forward pass: mat1 and mat2 shapes cannot be multiplied (288x512 and 1280x512). I noticed that mat1's first dimension (288) can vary depending on the dataset batch.

How do I fix this shape error? Do I need to pad mat1 with zeros so that its shape matches mat2 (1280x512), or have I set invalid model init parameters? I would be thankful for any help.
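
What I have figured out so far: 288 = 16 * 18, i.e. the batch size times the padded caption length, so mat1 looks like my flattened text-encoder hidden states. Here is a minimal repro of the same error under that assumption (to_k is a stand-in I made up for the UNet's cross-attention projection, whose in_features would come from cross_attention_dim):

import torch
from torch import nn

# Assumption: the failing matmul is a cross-attention projection expecting
# 1280-dim encoder states (the UNet2DConditionModel default), while
# clip-vit-base-patch32 produces 512-dim hidden states.
encoder_hidden_states = torch.randn(16, 18, 512)  # [batch, padded_seq_len, hidden]
to_k = nn.Linear(1280, 512, bias=False)           # in_features = default cross_attention_dim
to_k(encoder_hidden_states)
# RuntimeError: mat1 and mat2 shapes cannot be multiplied (288x512 and 1280x512)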

Here is my training code:

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
from transformers import CLIPModel, CLIPProcessor, CLIPTextModel, AutoTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
import pandas as pd
from PIL import Image
import io
from tqdm.auto import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Dataset class, returns images and corresponding textual captions
class TextImgDataset(Dataset):
    def __init__(self, fp: str):
        self.df = pd.read_parquet(fp)
        self.transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.Resize((64, 64)),
            T.ToTensor(),
        ])

    def __len__(self) -> int:
        return self.df.shape[0]

    def __getitem__(self, idx) -> tuple[torch.Tensor, str]:
        row = self.df.iloc[idx]
        img_bytes = io.BytesIO(row['image']['bytes'])
        image = Image.open(img_bytes)
        image_tensor = self.transform(image)
        caption = row['text']

        return image_tensor, caption
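
# Note (assumption): the Stable Diffusion VAE used below expects inputs scaled
# to [-1, 1], while T.ToTensor() above yields [0, 1]; a T.Normalize([0.5, 0.5, 0.5],
# [0.5, 0.5, 0.5]) step may be missing from the transform.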


# Initialize models
# CLIP full model and processor (loaded but not used in the training loop below)
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")

vae = AutoencoderKL.from_single_file(
    "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors"
).to(device)

unet = UNet2DConditionModel(
    in_channels=4,
    out_channels=4,
    layers_per_block=2,
    sample_size=64,
    block_out_channels=(128, 256, 512, 512),
    down_block_types=("DownBlock2D", "DownBlock2D", "DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D", "UpBlock2D", "UpBlock2D"),
).to(device)
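# Possibly relevant (my reading of the diffusers defaults): cross_attention_dim is not
# set above, so it defaults to 1280, while the text encoder produces 512-dim states;
# also AttnDownBlock2D / AttnUpBlock2D are self-attention blocks, so the captions
# would only enter through the UNet's default cross-attention mid block.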

noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_start=0.0001, beta_end=0.02, beta_schedule="linear")

# DataLoader
# I use dataset from here https://huggingface.co/datasets/pranked03/flowers-blip-captions
dataset = TextImgDataset(fp='~/dataset.parquet')
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)
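# The default collate stacks the image tensors into [B, 3, 64, 64] and returns
# the captions as a list of strings, which the tokenizer below accepts directly.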

# Optimizers
optimizer_vae = torch.optim.Adam(vae.parameters(), lr=1e-4)
optimizer_unet = torch.optim.Adam(unet.parameters(), lr=1e-4)
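# Note: in the usual latent-diffusion setup the pretrained VAE is kept frozen;
# optimizer_vae additionally trains it here.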

# Training Loop
num_epochs = 10

unet.train()

for epoch in range(num_epochs):
    for batch in tqdm(dataloader):
        images, captions = batch
        images = images.to(device)
        latents = vae.encode(images).latent_dist.sample()
        latents = latents * vae.config.scaling_factor

        noise = torch.randn_like(latents)
        bsz = latents.shape[0]
        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device).long()

        # padding=True pads to the longest caption in the batch, so the sequence
        # length (and with it mat1's first dim = batch_size * seq_len) varies per batch
        inputs = tokenizer(captions, padding=True, return_tensors="pt", truncation=True).to(device)
        outputs = text_encoder(**inputs).last_hidden_state

        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
        # noisy_latents shape: torch.Size([16, 4, 8, 8])
        # timesteps shape: torch.Size([16])
        # encoder_hidden_states (outputs) shape: torch.Size([16, 18, 512])
        pred = unet(sample=noisy_latents, timestep=timesteps, encoder_hidden_states=outputs, return_dict=False) # getting the error here

        # .... rest of the code (backpropagation and sampling)
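
Based on the shapes above, my current guess (not yet verified) is that padding mat1 is not the answer; instead the UNet's cross_attention_dim, which defaults to 1280, has to match the text encoder's hidden size (512 for clip-vit-base-patch32), e.g.:

unet = UNet2DConditionModel(
    in_channels=4,
    out_channels=4,
    layers_per_block=2,
    sample_size=64,
    cross_attention_dim=512,  # match text_encoder.config.hidden_size
    block_out_channels=(128, 256, 512, 512),
    down_block_types=("DownBlock2D", "DownBlock2D", "DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D", "UpBlock2D", "UpBlock2D"),
).to(device)

Is that the right fix, or do I also need the CrossAttn* block variants ("CrossAttnDownBlock2D" / "CrossAttnUpBlock2D") for the captions to condition anything beyond the mid block?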
