Stable Diffusionの中身はどうなっている? 内部処理をハンズオンで可視化

青色のウェーブ模様 皆さんこんにちは。ソニー・ミュージックエンタテインメントで生成AIの調査・検討を行っているみみずくです。

2021年に画像生成AIモデルである「DALL-E」がOpenAIからリリースされて以来、時代は生成AIブームとも言われており、画像だけでなくテキストや音声などさまざまなものに対する生成AIが急速なスピードで進化してきました。

そうした中、我々のようなエンタメ業界は、IPホルダーとして権利関係を保護しながらも、未来のコンテンツ制作に向けてAIをクリーンに利活用できる可能性について模索する必要性に迫られています。そしてそのためには、生成AIがどのようにして動いているのか、仕組みについて理解を深めておく必要があります。

今回はそうした生成AIの中から、画像生成AI「Stable Diffusion」に代表される「拡散モデル」に着目し、内部でどのような処理がなされているのかを皆さんにハンズオンで体験していただければと思います。内容としては初学者向けのものとなります。これから画像生成AIを学んでいこうと考えている方の参考になれば幸いです。

Stable Diffusionについて

Stable Diffusionとは、画像から別の画像を生成するImage-to-Image(I2I)機能や文字列から画像を生成するText-to-Image(T2I)機能を搭載した画像生成モデルの一種で、その中身にはスタンフォード大学で開発された「拡散モデル」という技術が使われています。

当時、高品質な画像をほとんど生成することができなかった画像生成の世界において、この拡散モデルという発明は大きなゲームチェンジャーとなりました。

拡散モデルについて

拡散モデルとは、ざっくり言えばランダムなノイズ画像から、徐々に元の情報を推測・復元するような形でデータを生成するシステムです。 このようなノイズを介したプロセスを「拡散処理」と呼びます。なお、拡散処理には順方向と逆方向が存在し、綺麗なデータにノイズを加えていくプロセスを順方向の拡散過程、完全なノイズから元の情報を推測・復元するプロセスを逆方向 の拡散過程と呼びます。

車の画像にノイズが加えられていくプロセスのイメージ
拡散処理のイメージ

学習の際には順方向と逆方向の両方が用いられ、生成の際には逆方向のみが用いられます。

学習過程ではデータセットの画像に徐々にノイズを加えていき、ノイズが加えられる前と加えられた後の両方の画像をニューラルネットワークに学習させます。これにより、例えばノイズが10回加えられたデータから、ノイズが9回だけ加えられた(少しノイズが除去された)データを予測する処理ができるようになります。

一方、生成過程では、まずはランダムなノイズ画像を用意します。そしてあたかもそれが綺麗な画像にノイズを少しずつ加えていった結果であるかのように見立て、1ステップずつ元画像を復元するような処理を行います。すると、最終的には拡散モデルが推測した「ノイズが加えられる前の画像」が出来上がり、これが生成物として出力されます。

実際にはこのプロセスは「潜在空間」と呼ばれる、データの特徴量がプロットされた多次元空間で行われているため可視化できないのですが、画像生成の最終段階で使われる「デコーダー」を通すことでそのプロセスを可視化できます。

以下の章では、Google Cloudの提供する、Webブラウザー上でPythonを記述、実行できる機械学習教育・研究サービス 「Google Colaboratory(以下、Colab)」 を用いて実際にそのプロセスを見ていきたいと思います。

なお、今回は直観的な理解を目的としているので、テキストや画像の入力なしで単にノイズから画像を生成するシンプルなモデルを利用します。

生成段階の可視化

では、逆方向の拡散処理によって画像が生成される段階を、Colab上で可視化してみましょう。

下準備

まずColabを起動し、プロジェクトファイルを作成します。タイトルは適当に「DiffusionTest.ipynb」とでもしておきましょう。

これから行うのはディープラーニングを利用した機械学習 なので、CPUよりもGPUの方が演算に適しています。CPUのまま実行すると数時間以上の時間がかかってしまいますが、右上の「RAM ディスク」などと表示されている部分をクリックし、「ランタイプのタイプを変更」から「T4 GPU」を選択することで、この後行う機械学習のプロセスをかなり早く実行することができます。

次に以下のコードでColabのPython上に演算用のライブラリとデータセット、拡散モデルを読み込みます。

!pip install -qq -U einops==0.7.0 datasets==2.14.6
!pip install Denoising-Diffusion-Probabilistic-Models

実行が完了したら、次に以下のコードで拡散モデルを動かすのに必要な関数を読み込みます。

from diffusion import DDMFunctions, Network, Train

次に、以下のコードを追加し、画像サイズとタイムステップ(処理の段階数)、画像のチャンネル数(白黒であれば1)を設定します。タイムステップは1280×720ピクセルを超えるような大きめの画像 であれば1000以上必要となりますが、ここでは比較的小さい画像データを扱うため300程度に設定します。

image_size = 28 #@param{type:'integer'}
timesteps= 300 #@param{type:'integer'}
channels = 1 #@param{type:'integer'}

次に除去するべきノイズを予測するニューラル・ネットワークである「U-Net」を読み込みます 。U-Netは画像のセグメンテーション(物体がどこにあるか)のタスクでも頻繁に利用されるネットワークで、拡散モデルにおいて不可欠な機械学習メカニズムです。

# @title U-Net
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"

model = Network.Unet(dim=image_size, channels=channels)
_ = model.to(device)

以上でライブラリのインストールとU-Netの準備が完了しました。

学習

次にモデルの学習を行います。

今回は、MITライセンスというオープンソースライセンスで管理されている、「Fashion-MNIST」という画像データセットを利用して習します。Fashion-MNISTは、6万枚以上のファッション商品の画像データによって構成されています。いずれの画像も28×28の白黒画像で、「T-shirt/top」や「Trouser」など10種類のラベルがつけられています。

Fashion-MNISTに含まれるデータの一例。Tシャツ、ズボンなど、10種類のラベルが用意されている。
出典: Fashion-MNIST: a Novel Image Dataset for Benchmarking Machine Learning Algorithms. Han Xiao, Kashif Rasul, Roland Vollgraf. arXiv:1708.07747

以下のコードでは、学習の際に必要なパラメータの設定をします。一連の学習を何度繰り返すかという意味を持つ「エポック数」を今回は6に設定します。

# from datasets import load_dataset
# @markdown パラメーターの設定
dataset_name = "fashion_mnist" #@param ["mnist", "fashion_mnist"]
batch_size = 64 #@param{type:'integer'}
epochs=6 #@param{type:'integer'}
save_and_sample_every = 500 #@param{type:'integer'}
schedule_type = "linear" #@param ["linear", "cosine"]

次に、学習のために必要なtrainクラスをいつでも使えるようにインスタンス化しておきます。

train = Train(model, image_size, channels, timesteps, dataset_name, schedule_type, device)

この際、「ノートブックにシークレットへのアクセス権がありません」などと表示される場合は、「アクセスを許可」のボタンを押して進めてください。 次に、以下のコードを実行し、現時点でのモデル(まだ学習を行っていない)を保存します。

#@モデルの保存
torch.save(model.state_dict(), 'model.pkl')

次に、以下の2つのコードを順番に実行します。

sample_batch_size = 10 #@param{type:"integer"}
dm = DDMFunctions(timesteps, schedule_type)
samples = dm.sample(model, image_size, sample_batch_size, channels)
#@markdown 生成結果の表示

import matplotlib.pyplot as plt
random_index = 2 #@param{type:"integer"}
plt.imshow(samples[-1][random_index].reshape(image_size, image_size, channels), cmap="gray")

一連のコードを実行すると、以下のようなランダムなノイズが表示されると思います。 ランダムなノイズの画像

これがランダムなノイズにしか見えない理由は、まだモデルの学習が行われていないためです。 次のコードを実行して、モデルの学習を行いましょう。

train.train(epochs, save_and_sample_every=save_and_sample_every, batch_size=batch_size)

実行すると、表示されている「Loss」の数字がだんだんと小さくなっていくのが分かるかと思います。このLossとは「損失」や「誤差」とも呼ばれ、機械学習にとって、学習データセットに含まれる「正解値」とモデルによって出力された「予測値」とのズレの大きさを示す値です。

この値が小さくなればなるほど、モデルは学習データセットに含まれる情報を上手に(齟齬なく)説明できるようになっていきます。

さて、この処理が終了したら、もう一度「#@モデルの保存」の処理を実行します。これにより、今学習を行ったモデルを保存することができます。 保存が終わったら、その下にある2つのコードも再実行してみましょう。 すると、先ほどとは異なり、ただのノイズではなくセーターのように見える 画像が表示されました。生成される画像はノイズの初期状態によっても異なるので、皆さんが実行した場合には異なる画像が表示されると思います。

28×28のセーターのような画像

これが拡散モデルによる画像生成です。 さらに仕組みが分かりやすいように、このプロセスをgifアニメーションで可視化してみましょう。 一番下に、以下のコードを追加します。

#@markdown 逆拡散プロセスのアニメーション化
gif_file_name = "/content/denoising_diffusion_process.gif" #@param{type:"string"}
import matplotlib.animation as animation

fig = plt.figure()
ims = []
for i in range(timesteps):
    im = plt.imshow(samples[i][random_index].reshape(image_size, image_size, channels), cmap="gray", animated=True)
    ims.append([im])

animate = animation.ArtistAnimation(fig, ims, interval=50, blit=True, repeat_delay=1000)
animate.save(gif_file_name)

これを実行すると、gifアニメーションが生成されます。 そのままでは静止画しか見えないので、左のサイドバーにあるファイルマークからdenoising_diffusion_process.gifを探し、ダブルクリックします。 すると、完全にランダムなノイズの中から先ほどの服の画像が出現するプロセスがアニメーションとして可視化されます。 28×28のセーターがノイズの中から出現していくgifアニメーション

このように、ノイズ画像から1タイムステップごとに「今の画像から少しノイズを除去した場合の画像」を推測し、それを繰り返した先で綺麗な画像が生成されるのが逆方向(ノイズ除去拡散プロセス)なのです(もちろん、Stable DiffusionやMidjourneyなどではより綺麗な画像を生成することができます)。

最後に

生成AIの技術的側面を数学的に説明するのは非常に難しいのですが、生成プロセスを可視化することでノイズから画像が生じてくるプロセスをイメージすることができるようになったのではないでしょうか(ただし、なぜ拡散モデルによってあれほどまでに綺麗な画像生成が可能なのかは、専門家でも厳密に説明するのは難しいようです)。

生成AI技術は日進月歩で進化しており、さまざまな新手法が登場していますが、私は画像生成においてはしばらくこの拡散モデルが主流であり続けると予想しています。

これを脅威と捉えるか、創造性を促進する便利な道具と考えるか、今後も議論が必要になるところですが、技術はあくまで技術。ツールとして上手な付き合い方を模索していくことが重要になっていくでしょう。そのために必要なもののひとつが技術に対する正しい理解です。この記事が、皆さんの生成AI理解の一助になれば幸いです。