DALLE2-pytorch - Implementation of DALL-E 2, OpenAI's updated text-to-image synthesis neural network, in Pytorch

Created at: 2022-04-07 12:14:08
Language: Python
License: MIT

DALL-E 2 - Pytorch

Implementation of DALL-E 2, OpenAI's updated text-to-image synthesis neural network, in Pytorch.

Yannic Kilcher summary | AssemblyAI explainer

The main novelty seems to be an extra layer of indirection with the prior network (whether it is an autoregressive transformer or a diffusion network), which predicts an image embedding based on the text embedding from CLIP. Specifically, this repository will only build out the diffusion prior network, as it is the best performing variant (although it incidentally involves a causal transformer as the denoising network 😂).

This model is SOTA for text-to-image at the moment.

If you are interested in helping out with the replication with the LAION community, please join their Discord. | Yannic interview

As of 5/23/22, it is no longer SOTA. SOTA is here. The Jax version, as well as the text-to-video project, will be shifted towards the Imagen architecture, as it is much simpler.

Status

  • A research group has used the code in this repository to train a functional diffusion prior for their CLIP generations. They will share their work once they release their preprint. This, as well as Katherine's own experiments, validates OpenAI's finding that the extra prior increases the variety of generations.

  • The decoder has been verified to work for unconditional generation in an experimental setup on Oxford flowers. Two researchers have also confirmed that the decoder is working for them.

Ongoing at 21k steps

  • Justin Pinkney successfully trained the diffusion prior in this repository for his CLIP to Stylegan2 text-to-image application.

Pre-Trained Models

  • LAION is training prior models. Checkpoints are available on 🤗 huggingface and the training statistics are available on 🐝 WANDB
  • Decoder - in-progress test run 🚧
  • DALL-E 2 🚧

Install

$ pip install dalle2-pytorch

Usage

Training DALLE-2 is a 3-step process, with the training of CLIP being the most important.

To train CLIP, you can either use the x-clip package, or join the LAION discord, where a lot of replication efforts are already underway.

This repository will demonstrate integration with x-clip for starters.

import torch
from dalle2_pytorch import CLIP

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 49408,
    text_enc_depth = 1,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 1,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8,
    use_all_token_embeds = True,            # whether to use fine-grained contrastive learning (FILIP)
    decoupled_contrastive_learning = True,  # use decoupled contrastive learning (DCL) objective function, removing positive pairs from the denominator of the InfoNCE loss (CLOOB + DCL)
    extra_latent_projection = True,         # whether to use separate projections for text-to-image vs image-to-text comparisons (CLOOB)
    use_visual_ssl = True,                  # whether to do self supervised learning on images
    visual_ssl_type = 'simclr',             # can be either 'simclr' or 'simsiam', depending on using DeCLIP or SLIP
    use_mlm = False,                        # use masked language learning (MLM) on text (DeCLIP)
    text_ssl_loss_weight = 0.05,            # weight for text MLM loss
    image_ssl_loss_weight = 0.05            # weight for image self-supervised learning loss
).cuda()

# mock data

text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

# train

loss = clip(
    text,
    images,
    return_loss = True              # needs to be set to True to return contrastive loss
)

loss.backward()

# do the above with as many texts and images as possible in a loop
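
For orientation, here is a minimal sketch of what such a loop might look like. The mock TensorDataset, the choice of optimizer, and the learning rate are illustrative assumptions rather than anything prescribed by the library; `clip` is the instance constructed above.

import torch
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset

# hypothetical stand-in for a real text-image dataset (token ids + images)
dataset = TensorDataset(
    torch.randint(0, 49408, (64, 256)),
    torch.randn(64, 3, 256, 256)
)
dataloader = DataLoader(dataset, batch_size = 4, shuffle = True)

optimizer = Adam(clip.parameters(), lr = 3e-4)  # illustrative optimizer settings

for text, images in dataloader:
    loss = clip(
        text.cuda(),
        images.cuda(),
        return_loss = True
    )

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()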

Then, you will need to train the decoder, which learns to generate images based on the image embedding coming from the trained CLIP above.

import torch
from dalle2_pytorch import Unet, Decoder, CLIP

# trained clip from step 1

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 49408,
    text_enc_depth = 1,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 1,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8
).cuda()

# unet for the decoder

unet = Unet(
    dim = 128,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults=(1, 2, 4, 8)
).cuda()

# decoder, which contains the unet and clip

decoder = Decoder(
    unet = unet,
    clip = clip,
    timesteps = 100,
    image_cond_drop_prob = 0.1,
    text_cond_drop_prob = 0.5
).cuda()

# mock images (get a lot of this)

images = torch.randn(4, 3, 256, 256).cuda()

# feed images into decoder

loss = decoder(images)
loss.backward()

# do the above for many many many many steps
# then it will learn to generate images based on the CLIP image embeddings

Finally, the main contribution of the paper: the repository offers the diffusion prior network. It takes the CLIP text embeddings and tries to generate the CLIP image embeddings. Again, you will need the trained CLIP from the first step.

import torch
from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, CLIP

# get trained CLIP from step one

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 49408,
    text_enc_depth = 6,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 6,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8,
).cuda()

# setup prior network, which contains an autoregressive transformer

prior_network = DiffusionPriorNetwork(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8
).cuda()

# diffusion prior network, which contains the CLIP and network (with transformer) above

diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = clip,
    timesteps = 100,
    cond_drop_prob = 0.2
).cuda()

# mock data

text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

# feed text and images into diffusion prior network

loss = diffusion_prior(text, images)
loss.backward()

# do the above for many many many steps
# now the diffusion prior can generate image embeddings from the text embeddings

In the paper, they actually used a recently discovered technique, from Jonathan Ho himself (original author of DDPMs, the core technique used in DALL-E v2), for high-resolution image synthesis.

This can easily be used within this framework as follows.

import torch
from dalle2_pytorch import Unet, Decoder, CLIP

# trained clip from step 1

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 49408,
    text_enc_depth = 6,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 6,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8
).cuda()

# 2 unets for the decoder (a la cascading DDPM)

unet1 = Unet(
    dim = 32,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8)
).cuda()

unet2 = Unet(
    dim = 32,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8, 16)
).cuda()

# decoder, which contains the unet(s) and clip

decoder = Decoder(
    clip = clip,
    unet = (unet1, unet2),            # insert both unets in order of low resolution to highest resolution (you can have as many stages as you want here)
    image_sizes = (256, 512),         # resolutions, 256 for first unet, 512 for second. these must be unique and in ascending order (matches with the unets passed in)
    timesteps = 1000,
    image_cond_drop_prob = 0.1,
    text_cond_drop_prob = 0.5
).cuda()

# mock images (get a lot of this)

images = torch.randn(4, 3, 512, 512).cuda()

# feed images into decoder, specifying which unet you want to train
# each unet can be trained separately, which is one of the benefits of the cascading DDPM scheme

loss = decoder(images, unet_number = 1)
loss.backward()

loss = decoder(images, unet_number = 2)
loss.backward()

# do the above for many steps for both unets

Finally, to generate the DALL-E2 images from text, insert the trained DiffusionPrior as well as the Decoder (which wraps CLIP, the causal transformer, and the unet(s)).

from dalle2_pytorch import DALLE2

dalle2 = DALLE2(
    prior = diffusion_prior,
    decoder = decoder
)

# send the text as a string if you want to use the simple tokenizer from DALLE v1
# or you can do it as token ids, if you have your own tokenizer

texts = ['glistening morning dew on a flower petal']
images = dalle2(texts) # (1, 3, 256, 256)

That's it!

Let's see the whole script below.

import torch
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, CLIP

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 49408,
    text_enc_depth = 6,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 6,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8
).cuda()

# mock data

text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

# train

loss = clip(
    text,
    images,
    return_loss = True
)

loss.backward()

# do above for many steps ...

# prior networks (with transformer)

prior_network = DiffusionPriorNetwork(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8
).cuda()

diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = clip,
    timesteps = 100,
    cond_drop_prob = 0.2
).cuda()

loss = diffusion_prior(text, images)
loss.backward()

# do above for many steps ...

# decoder (with unet)

unet1 = Unet(
    dim = 128,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults=(1, 2, 4, 8)
).cuda()

unet2 = Unet(
    dim = 16,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8, 16)
).cuda()

decoder = Decoder(
    unet = (unet1, unet2),
    image_sizes = (128, 256),
    clip = clip,
    timesteps = 100,
    image_cond_drop_prob = 0.1,
    text_cond_drop_prob = 0.5,
    condition_on_text_encodings = False  # set this to True if you wish to condition on text during training and sampling
).cuda()

for unet_number in (1, 2):
    loss = decoder(images, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
    loss.backward()

# do above for many steps

dalle2 = DALLE2(
    prior = diffusion_prior,
    decoder = decoder
)

images = dalle2(
    ['cute puppy chasing after a squirrel'],
    cond_scale = 2. # classifier free guidance strength (> 1 would strengthen the condition)
)

# save your image (in this example, of size 256x256)
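
If you want to actually write the sampled tensor to disk, one option (an assumption here - torchvision is not a requirement of this repository) is torchvision's save_image, clamping to the [0, 1] range expected by image writers:

from torchvision.utils import save_image

# `images` is the sampled tensor returned by dalle2 above, one 3 x 256 x 256 image per prompt
for i, image in enumerate(images):
    save_image(image.clamp(0., 1.), f'./dalle2-sample-{i}.png')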

Everything in this readme should run without error.

You can also train the decoder on images larger (say 512x512) than the size at which CLIP was trained (256x256). The images will be resized to the CLIP image resolution for the image embeddings.

For the layperson, no worries: training will all be automated into a CLI tool, at least for small-scale training.

Training on Preprocessed CLIP Embeddings

When scaling up, it is likely that you will first preprocess your images and text into the corresponding embeddings before training the prior network. You can do so easily by simply passing in image_embed, text_embed, and optionally text_encodings and text_mask.

Working example below:

import torch
from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, CLIP

# get trained CLIP from step one

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 49408,
    text_enc_depth = 6,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 6,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8,
).cuda()

# setup prior network, which contains an autoregressive transformer

prior_network = DiffusionPriorNetwork(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8
).cuda()

# diffusion prior network, which contains the CLIP and network (with transformer) above

diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = clip,
    timesteps = 100,
    cond_drop_prob = 0.2,
    condition_on_text_encodings = False  # this probably should be true, but just to get Laion started
).cuda()

# mock data

text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

# precompute the text and image embeddings
# here using the diffusion prior class, but could be done with CLIP alone

clip_image_embeds = diffusion_prior.clip.embed_image(images).image_embed
clip_text_embeds = diffusion_prior.clip.embed_text(text).text_embed

# feed text and images into diffusion prior network

loss = diffusion_prior(
    text_embed = clip_text_embeds,
    image_embed = clip_image_embeds
)

loss.backward()

# do the above for many many many steps
# now the diffusion prior can generate image embeddings from the text embeddings

You can also go completely CLIP-less, in which case you will need to pass the image_embed_dim into the DiffusionPrior on initialization.

import torch
from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior

# setup prior network, which contains an autoregressive transformer

prior_network = DiffusionPriorNetwork(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8
).cuda()

# diffusion prior network, which contains the CLIP and network (with transformer) above

diffusion_prior = DiffusionPrior(
    net = prior_network,
    image_embed_dim = 512,               # this needs to be set
    timesteps = 100,
    cond_drop_prob = 0.2,
    condition_on_text_encodings = False  # this probably should be true, but just to get Laion started
).cuda()

# mock data

text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

# precompute the text and image embeddings
# here using the diffusion prior class, but could be done with CLIP alone

clip_image_embeds = torch.randn(4, 512).cuda()
clip_text_embeds = torch.randn(4, 512).cuda()

# feed text and images into diffusion prior network

loss = diffusion_prior(
    text_embed = clip_text_embeds,
    image_embed = clip_image_embeds
)

loss.backward()

# do the above for many many many steps
# now the diffusion prior can generate image embeddings from the text embeddings

OpenAI CLIP

Although there is the possibility that they are using an unreleased, more powerful CLIP, you can use one of the released ones if you do not wish to train your own CLIP from scratch. This will also allow the community to more quickly validate the conclusions of the paper.

To use a pretrained OpenAI CLIP, simply import OpenAIClipAdapter and pass it into the DiffusionPrior or Decoder like so:

import torch
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, OpenAIClipAdapter

# openai pretrained clip - defaults to ViT-B/32

clip = OpenAIClipAdapter()

# mock data

text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

# prior networks (with transformer)

prior_network = DiffusionPriorNetwork(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8
).cuda()

diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = clip,
    timesteps = 100,
    cond_drop_prob = 0.2
).cuda()

loss = diffusion_prior(text, images)
loss.backward()

# do above for many steps ...

# decoder (with unet)

unet1 = Unet(
    dim = 128,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults=(1, 2, 4, 8)
).cuda()

unet2 = Unet(
    dim = 16,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8, 16)
).cuda()

decoder = Decoder(
    unet = (unet1, unet2),
    image_sizes = (128, 256),
    clip = clip,
    timesteps = 100,
    image_cond_drop_prob = 0.1,
    text_cond_drop_prob = 0.5,
    condition_on_text_encodings = False  # set this to True if you wish to condition on text during training and sampling
).cuda()

for unet_number in (1, 2):
    loss = decoder(images, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
    loss.backward()

# do above for many steps

dalle2 = DALLE2(
    prior = diffusion_prior,
    decoder = decoder
)

images = dalle2(
    ['a butterfly trying to escape a tornado'],
    cond_scale = 2. # classifier free guidance strength (> 1 would strengthen the condition)
)

# save your image (in this example, of size 256x256)

Now you only have to worry about training the Prior and the Decoder.

Experimental

DALL-E2 with Latent Diffusion

This repository decides to take the next step and offer DALL-E v2 combined with latent diffusion, from Rombach et al.

You can use it as follows. Latent diffusion can be limited to just the first U-Net in the cascade, or to any number you wish.

The repository also comes equipped with all the settings necessary to recreate ViT-VQGan from the Improved VQGANs paper. Furthermore, the vector quantization library is equipped to do residual or multi-headed quantization, which I believe will give an even better-performing autoencoder.

import torch
from dalle2_pytorch import Unet, Decoder, CLIP, VQGanVAE

# trained clip from step 1

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 49408,
    text_enc_depth = 1,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 1,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8
)

# 3 unets for the decoder (a la cascading DDPM)

# first two unets are doing latent diffusion
# vqgan-vae must be trained beforehand

vae1 = VQGanVAE(
    dim = 32,
    image_size = 256,
    layers = 3,
    layer_mults = (1, 2, 4)
)

vae2 = VQGanVAE(
    dim = 32,
    image_size = 512,
    layers = 3,
    layer_mults = (1, 2, 4)
)

unet1 = Unet(
    dim = 32,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    sparse_attn = True,
    sparse_attn_window = 2,
    dim_mults = (1, 2, 4, 8)
)

unet2 = Unet(
    dim = 32,
    image_embed_dim = 512,
    channels = 3,
    dim_mults = (1, 2, 4, 8, 16),
    cond_on_image_embeds = True,
    cond_on_text_encodings = False
)

unet3 = Unet(
    dim = 32,
    image_embed_dim = 512,
    channels = 3,
    dim_mults = (1, 2, 4, 8, 16),
    cond_on_image_embeds = True,
    cond_on_text_encodings = False,
    attend_at_middle = False
)

# decoder, which contains the unet(s) and clip

decoder = Decoder(
    clip = clip,
    vae = (vae1, vae2),                # latent diffusion for unet1 (vae1) and unet2 (vae2), but not for the last unet3
    unet = (unet1, unet2, unet3),      # insert unets in order of low resolution to highest resolution (you can have as many stages as you want here)
    image_sizes = (256, 512, 1024),    # resolutions, 256 for first unet, 512 for second, 1024 for third
    timesteps = 100,
    image_cond_drop_prob = 0.1,
    text_cond_drop_prob = 0.5
).cuda()

# mock images (get a lot of this)

images = torch.randn(1, 3, 1024, 1024).cuda()

# feed images into decoder, specifying which unet you want to train
# each unet can be trained separately, which is one of the benefits of the cascading DDPM scheme

with decoder.one_unet_in_gpu(1):
    loss = decoder(images, unet_number = 1)
    loss.backward()

with decoder.one_unet_in_gpu(2):
    loss = decoder(images, unet_number = 2)
    loss.backward()

with decoder.one_unet_in_gpu(3):
    loss = decoder(images, unet_number = 3)
    loss.backward()

# do the above for many steps for both unets

# then it will learn to generate images based on the CLIP image embeddings

# chaining the unets from lowest resolution to highest resolution (thus cascading)

mock_image_embed = torch.randn(1, 512).cuda()
images = decoder.sample(mock_image_embed) # (1, 3, 1024, 1024)

Training wrapper

Decoder Training

Training the Decoder can be confusing, as one needs to keep track of an optimizer for each Unet separately. Each Unet will also need its own corresponding exponential moving average. The DecoderTrainer hopes to make this simple, as shown below.

import torch
from dalle2_pytorch import DALLE2, Unet, Decoder, CLIP, DecoderTrainer

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 49408,
    text_enc_depth = 6,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 6,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8
).cuda()

# mock data

text = torch.randint(0, 49408, (32, 256)).cuda()
images = torch.randn(32, 3, 256, 256).cuda()

# decoder (with unet)

unet1 = Unet(
    dim = 128,
    image_embed_dim = 512,
    text_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults=(1, 2, 4, 8)
).cuda()

unet2 = Unet(
    dim = 16,
    image_embed_dim = 512,
    text_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8, 16),
    cond_on_text_encodings = True
).cuda()

decoder = Decoder(
    unet = (unet1, unet2),
    image_sizes = (128, 256),
    clip = clip,
    timesteps = 1000,
    condition_on_text_encodings = True
).cuda()

decoder_trainer = DecoderTrainer(
    decoder,
    lr = 3e-4,
    wd = 1e-2,
    ema_beta = 0.99,
    ema_update_after_step = 1000,
    ema_update_every = 10,
)

for unet_number in (1, 2):
    loss = decoder_trainer(
        images,
        text = text,
        unet_number = unet_number, # which unet to train on
        max_batch_size = 4         # gradient accumulation - this sets the maximum batch size in which to do forward and backwards pass - for this example 32 / 4 == 8 times
    )

    decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average

# after much training
# you can sample from the exponentially moving averaged unets as so

mock_image_embed = torch.randn(4, 512).cuda()
images = decoder_trainer.sample(mock_image_embed, text = text) # (4, 3, 256, 256)

Diffusion Prior Training

Similarly, one can use the DiffusionPriorTrainer to automatically instantiate and keep track of an exponential moving averaged diffusion prior.

import torch
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, DiffusionPriorTrainer, Unet, Decoder, CLIP

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 49408,
    text_enc_depth = 6,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 6,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8
).cuda()

# mock data

text = torch.randint(0, 49408, (512, 256)).cuda()
images = torch.randn(512, 3, 256, 256).cuda()

# prior networks (with transformer)

prior_network = DiffusionPriorNetwork(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8
).cuda()

diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = clip,
    timesteps = 100,
    cond_drop_prob = 0.2
).cuda()

diffusion_prior_trainer = DiffusionPriorTrainer(
    diffusion_prior,
    lr = 3e-4,
    wd = 1e-2,
    ema_beta = 0.99,
    ema_update_after_step = 1000,
    ema_update_every = 10,
)

loss = diffusion_prior_trainer(text, images, max_batch_size = 4)
diffusion_prior_trainer.update()  # this will update the optimizer as well as the exponential moving averaged diffusion prior

# after much of the above three lines in a loop
# you can sample from the exponential moving average of the diffusion prior identically to how you do so for DiffusionPrior

image_embeds = diffusion_prior_trainer.sample(text, max_batch_size = 4) # (512, 512) - exponential moving averaged image embeddings

Bonus

Unconditional Training

The repository also contains the means to train an unconditional DDPM model, or even cascading DDPMs. You simply have to set unconditional = True in the Decoder, for example:

import torch
from dalle2_pytorch import Unet, Decoder, DecoderTrainer

# unet for the cascading ddpm

unet1 = Unet(
    dim = 128,
    dim_mults=(1, 2, 4, 8)
).cuda()

unet2 = Unet(
    dim = 32,
    dim_mults = (1, 2, 4, 8, 16)
).cuda()

# decoder, which contains the unets

decoder = Decoder(
    unet = (unet1, unet2),
    image_sizes = (256, 512),  # first unet up to 256px, then second to 512px
    timesteps = 1000,
    unconditional = True
).cuda()

# decoder trainer

decoder_trainer = DecoderTrainer(decoder)

# images (get a lot of this)

images = torch.randn(1, 3, 512, 512).cuda()

# feed images into decoder

for i in (1, 2):
    loss = decoder_trainer(images, unet_number = i)
    decoder_trainer.update(unet_number = i)

# do the above for many many many many images
# then it will learn to generate images

images = decoder_trainer.sample(batch_size = 36, max_batch_size = 4) # (36, 3, 512, 512)

Dataloaders

Decoder Dataloaders

In order to make loading data simple and efficient, we include some general dataloaders that can be used to train portions of the network.

Decoder: Image Embedding Dataset

When training the decoder (and upsamplers, if training together) in isolation, you will need to load images and corresponding image embeddings. This dataset can read two similar types of datasets. First, it can read a webdataset that contains .jpg and .npy files in the .tars, holding the images and associated image embeddings respectively. Alternatively, you can also specify a source for the embeddings outside of the webdataset. In this case, the path to the embeddings should contain .npy files with the same shard numbers as the webdataset, and there should be a correspondence between the filename of the .jpg and the index of the embedding in the .npy. So, for example, 0001.tar from the webdataset with image 00010509.jpg (the first four digits are the shard number and the last four are the index) in it should be paired with an img_emb_0001.npy containing a NumPy array with the embedding at index 509.

Generating a dataset of this type (a rough command-line sketch follows the steps below):

  1. Use img2dataset to generate a webdataset.
  2. Use clip-retrieval to convert the images to embeddings.
  3. Use embedding-dataset-reordering to reorder the embeddings into the expected format.
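
A sketch of those three steps on the command line is shown below. The flags and paths are illustrative assumptions only and may differ between versions of these tools, so consult each project's documentation for the authoritative options; the reordering step is left to the embedding-dataset-reordering README.

$ img2dataset --url_list my_captions.parquet --input_format parquet --url_col url --caption_col caption --output_format webdataset --output_folder my_webdataset --image_size 256
$ clip-retrieval inference --input_dataset my_webdataset --output_folder my_embeddings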

Usage:

from dalle2_pytorch.dataloaders import ImageEmbeddingDataset, create_image_embedding_dataloader

# Create a dataloader directly.
dataloader = create_image_embedding_dataloader(
    tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses bracket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
    embeddings_url="path/or/url/to/embeddings/folder",     # Included if .npy files are not in webdataset. Left out or set to None otherwise
    num_workers=4,
    batch_size=32,
    shard_width=4,                                         # If a file in the webdataset shard 3 is named 0003039.jpg, we know the shard width is 4 and the last three digits are the index
    shuffle_num=200,                                       # Does a shuffle of the data with a buffer size of 200
    shuffle_shards=True,                                   # Shuffle the order the shards are read in
    resample_shards=False,                                 # Sample shards with replacement. If true, an epoch will be infinite unless stopped manually
)
for img, emb in dataloader:
    print(img.shape)  # torch.Size([32, 3, 256, 256])
    print(emb.shape)  # torch.Size([32, 512])
    # Train decoder only as shown above

# Or create a dataset without a loader so you can configure it manually
dataset = ImageEmbeddingDataset(
    urls="/path/or/url/to/webdataset/{0000..9999}.tar",
    embedding_folder_url="path/or/url/to/embeddings/folder",
    shard_width=4,
    shuffle_shards=True,
    resample=False
)

Scripts (wip)

train_diffusion_prior.py

This script allows you to train the DiffusionPrior on pre-computed text and image embeddings; the working examples in the section on preprocessed CLIP embeddings above illustrate the process. Note that, unlike those examples, the script passes text_embed and image_embed into the DiffusionPrior internally.

Usage

$ python train_diffusion_prior.py

The most important parameters of the script are listed below (an example invocation follows the list):

  • image-embed-url, default = "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/"
  • text-embed-url, default = "https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/"
  • image-embed-dim, default = 768 - 768 corresponds to the ViT-L/14 embedding size; change it to whatever size your chosen ViT produces
  • learning-rate, default = 1.1e-4
  • weight-decay, default = 6.02e-2
  • max-grad-norm, default = 0.5
  • batch-size, default = 10 ** 4
  • num-epochs, default = 5
  • clip, default = None # signals the prior to use pre-computed embeddings
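
For example, a hypothetical invocation overriding a few of these defaults could look as follows (the flag spellings mirror the list above; check the script itself for the authoritative set):

$ python train_diffusion_prior.py --image-embed-dim 768 --learning-rate 1.1e-4 --weight-decay 6.02e-2 --max-grad-norm 0.5 --batch-size 10000 --num-epochs 5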

Loading and saving DiffusionPrior models

Two methods, load_diffusion_model and save_diffusion_model, are provided; their names are self-explanatory.

from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model
Loading
load_diffusion_model(dprior_path, device) 
    dprior_path : path to saved model(.pth)
    device      : the cuda device you're running on
Saving
save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim)
    save_path : path to save at
    model     : object of Diffusion_Prior
    optimizer : optimizer object - see train_diffusion_prior.py for how to create one. 
        e.g: optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
    scaler    : a GradScaler object.
        e.g: scaler = GradScaler(enabled=amp)
    config    : config object created in train_diffusion_prior.py - see file for example. 
    image_embed_dim - the dimension of the image_embedding
        e.g: 768
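
A minimal loading sketch based on the signature above; the checkpoint path and device string are placeholders. Saving mirrors the save_diffusion_model signature documented above, with the optimizer, scaler, and config built as in train_diffusion_prior.py.

from dalle2_pytorch.train import load_diffusion_model

# placeholder checkpoint path and device
diffusion_prior = load_diffusion_model('./prior.pth', 'cuda:0')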

CLI (wip)

$ dream 'sharing a sunset at the summit of mount everest with my dog'

Once built, images will be saved to the same directory from which the command was invoked.

Template

Training CLI (wip)

Template

Appreciation

This library would not have gotten to its working state without the help of:

  • Zion Kumar for the diffusion training script
  • Aidan for the decoder training script and the dataloaders
  • Romain for pull request reviews and project management
  • He Cao and xiankgx for the Q&A and for identifying critical bugs
  • Katherine for her advice
  • Stability AI for the generous sponsorship

... and many others. Thank you! 🙏

Todo

  • [x] finish off gaussian diffusion class for latent embedding - allow for prediction of epsilon
  • [x] add what was proposed in the paper, where DDPM objective for image latent embedding predicts x0 directly (reread vq-diffusion paper and get caught up on that line of work)
  • [x] make sure it works end to end to produce an output tensor, taking a single gradient step
  • [x] augment unet so that it can also be conditioned on text encodings (although in paper they hinted this didn't make much a difference)
  • [x] figure out all the current bag of tricks needed to make DDPMs great (starting with the blur trick mentioned in paper)
  • [x] build the cascading ddpm by having Decoder class manage multiple unets at different resolutions
  • [x] add efficient attention in unet
  • [x] be able to finely customize what to condition on (text, image embed) for specific unet in the cascade (super resolution ddpms near the end may not need too much conditioning)
  • [x] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
  • [x] build out latent diffusion architecture, with the vq-reg variant (vqgan-vae), make it completely optional and compatible with cascading ddpms
  • [x] for decoder, allow ability to customize objective (predict epsilon vs x0), in case latent diffusion does better with prediction of x0
  • [x] use attention-based upsampling https://arxiv.org/abs/2112.11435
  • [x] use inheritance just this once for sharing logic between decoder and prior network ddpms
  • [x] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion
  • [x] abstract interface for CLIP adapter class, so other CLIPs can be brought in
  • [x] take care of mixed precision as well as gradient accumulation within decoder trainer
  • [x] just take care of the training for the decoder in a wrapper class, as each unet in the cascade will need its own optimizer
  • [x] bring in tools to train vqgan-vae
  • [x] add convnext backbone for vqgan-vae (in addition to vit [vit-vqgan] + resnet)
  • [x] make sure DDPMs can be run with traditional resnet blocks (but leave convnext as an option for experimentation)
  • [x] make sure for the latter unets in the cascade, one can train on crops for learning super resolution (constrain the unet to be only convolutions in that case, or allow conv-like attention with rel pos bias)
  • [x] offer setting in diffusion prior to split time and image embeddings into multiple tokens, configurable, for more surface area during attention
  • [x] make sure resnet hyperparameters can be configurable across unet depth (groups and expansion factor)
  • [x] pull logic for training diffusion prior into a class DiffusionPriorTrainer, for eventual script based + CLI based training
  • [x] make sure the cascading ddpm in the repository can be trained unconditionally, offer a one-line CLI tool for training on a folder of images
  • [x] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
  • [x] cross embed layers for downsampling, as an option
  • [x] use an experimental tracker agnostic setup, as done here
  • [x] use pydantic for config drive training
  • [x] for both diffusion prior and decoder, all exponential moving averaged models needs to be saved and restored as well (as well as the step number)
  • [x] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
  • [x] allow for creation of diffusion prior model off pydantic config classes - consider the same for tracker configs
  • [ ] become an expert with unets, clean up the unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in the ddpm repository) - consider https://github.com/lucidrains/uformer-pytorch for an attention-based unet
  • [ ] transcribe the code to Jax, which will lower the activation energy for distributed training, given access to TPUs
  • [ ] train on a toy task and offer it in a colab
  • [ ] think of the best way to design a declarative training config that handles pre-encoding as well as the training of multiple networks in the decoder
  • [ ] extend the diffusion head to use diffusion-gan (possibly using lightweight-gan) to speed up inference
  • [ ] consider augmenting with external memory, if possible, as described in https://arxiv.org/abs/2204.11824
  • [ ] test grid attention in cascading ddpm locally, decide whether to keep or remove https://arxiv.org/abs/2204.01697
  • [ ] interface vqgan-vae so that a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
  • [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
  • [ ] bring in skip-layer excitation (from the lightweight gan paper) and see whether it helps either the unet or the vqgan-vae training of the decoder
  • [ ] the decoder needs a day of refactoring due to technical debt
  • [ ] allow the unet to also be conditioned in a non-cross-attention style
  • [ ] read, understand, and build https://github.com/lucidrains/DALLE2-pytorch/issues/89

Citations

@misc{ramesh2022,
    title   = {Hierarchical Text-Conditional Image Generation with CLIP Latents}, 
    author  = {Aditya Ramesh et al},
    year    = {2022}
}
@misc{crowson2022,
    author  = {Katherine Crowson},
    url     = {https://twitter.com/rivershavewings}
}
@misc{rombach2021highresolution,
    title   = {High-Resolution Image Synthesis with Latent Diffusion Models}, 
    author  = {Robin Rombach and Andreas Blattmann and Dominik Lorenz and Patrick Esser and Björn Ommer},
    year    = {2021},
    eprint  = {2112.10752},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@article{shen2019efficient,
    author  = {Zhuoran Shen and Mingyuan Zhang and Haiyu Zhao and Shuai Yi and Hongsheng Li},
    title   = {Efficient Attention: Attention with Linear Complexities},
    journal = {CoRR},
    year    = {2018},
    url     = {http://arxiv.org/abs/1812.01243},
}
@inproceedings{Tu2022MaxViTMV,
    title   = {MaxViT: Multi-Axis Vision Transformer},
    author  = {Zhengzhong Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
    year    = {2022},
    url     = {https://arxiv.org/abs/2204.01697}
}
@article{Yu2021VectorquantizedIM,
    title   = {Vector-quantized Image Modeling with Improved VQGAN},
    author  = {Jiahui Yu and Xin Li and Jing Yu Koh and Han Zhang and Ruoming Pang and James Qin and Alexander Ku and Yuanzhong Xu and Jason Baldridge and Yonghui Wu},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2110.04627}
}
@article{Shleifer2021NormFormerIT,
    title   = {NormFormer: Improved Transformer Pretraining with Extra Normalization},
    author  = {Sam Shleifer and Jason Weston and Myle Ott},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2110.09456}
}
@article{Yu2022CoCaCC,
    title   = {CoCa: Contrastive Captioners are Image-Text Foundation Models},
    author  = {Jiahui Yu and Zirui Wang and Vijay Vasudevan and Legg Yeung and Mojtaba Seyedhosseini and Yonghui Wu},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2205.01917}
}
@misc{wang2021crossformer,
    title   = {CrossFormer: A Versatile Vision Transformer Hinging on Cross-scale Attention},
    author  = {Wenxiao Wang and Lu Yao and Long Chen and Binbin Lin and Deng Cai and Xiaofei He and Wei Liu},
    year    = {2021},
    eprint  = {2108.00154},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@article{ho2021cascaded,
    title   = {Cascaded Diffusion Models for High Fidelity Image Generation},
    author  = {Ho, Jonathan and Saharia, Chitwan and Chan, William and Fleet, David J and Norouzi, Mohammad and Salimans, Tim},
    journal = {arXiv preprint arXiv:2106.15282},
    year    = {2021}
}
@misc{Saharia2022,
    title   = {Imagen: unprecedented photorealism × deep level of language understanding},
    author  = {Chitwan Saharia*, William Chan*, Saurabh Saxena†, Lala Li†, Jay Whang†, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S. Sara Mahdavi, Rapha Gontijo Lopes, Tim Salimans, Jonathan Ho†, David Fleet†, Mohammad Norouzi*},
    year    = {2022}
}

Creating noise from data is easy; creating data from noise is generative modeling. - Yang Song's paper