七難ハック
日本語モデルを作りたかった話
最終更新: 2021/01/19

掲題の通りといえばそうなんですが
NLP、こと日本語の自然言語処理って楽しいですよねっという話です


最近話題のKWがわからない

賑やかなのはいいことですが、賑やかすぎてなにがなんだか混乱するので軽く整理してみます。
汎用言語モデルが成長した結果、対話生成の領域が盛り上がってる、という感じでしょうか。

GPT-3

OpenAIのGPTシリーズがついに「ヤバいから無条件には公開しないよ><」って言い出した(と理解してますがあってますか笑)汎用言語モデル
Redditかなんかでしばらく人間のふりして会話してたとかなんとか
前述の経緯もあって商用化されてるので一般人が気軽に試せるものではないし真似するのも難しい(と思う)

Meena

Googleがつくったチャットボット
文脈を理解して人間と違和感なく会話できるチャットボット(らしい)
これも特にオープンソースとかではないらしいので一般人が試せるものではなさそう

Blender

Facebookがつくったチャットボット、Meenaを越えたと豪語していた(気がする、そんな記事を目にした)
これは実はオープンソースで、パーレイで試せるらしい
Recipes for building an open-domain chatbot


「すごいぞ!」「やばいぞ!」というニュースばかり舞い込んでくるが、最近のホットなものは一昔前と違ってわりともうクローズドで、一般人が試して遊べる世界ではなくなっている。
しかもお察しの通りこれらは英語に限った話なので、日本語となるとさらに厳しい。

そんな日本語の界隈において頑張ってるのがLINE

LINEは日本語版GPT-3と言えそうな規模の汎用言語モデルを作ろうとしているらしく、これはめちゃくちゃ楽しみ。

じゃあ手元でなにができるのさ

手元で気軽に試せてしかも日本語でできるのってなにがあるのさって話で、とりあえずGPT-2の日本語モデルを作ってみることにします。

    # use multiple gpu
    mirrored_strategy = tf.distribute.MirroredStrategy()

    with mirrored_strategy.scope():
        loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
        model = TFGPT2LMHeadModel(config)
        optimizer = tf.keras.optimizers.Adam(
            learning_rate=LEARNING_RATE, epsilon=EPSILON, clipnorm=CLIPNORM
        )
        print("model compile")
        model.compile(
            optimizer=optimizer,
            loss=[loss, *[None] * model.config.n_layer],
            metrics=[metric]
        )

ちゃんとしたソースはそのうちリポジトリごと公開しますが、↑はマルチGPU環境で動かすのに苦労というか悩んだところですね。
通常tf.distribute.MirroredStrategy()は引数にgpuを指定するんですが、全GPUを使いたいときは空でいいそうです。
僕は無駄に自宅PCが3080と2080Tiの二枚刺しとかになってるのでマルチGPUできるんです…笑

定番のWikipedia全文とかそういうのをいくつか使って学習させるんですが、これが恐ろしい時間かかる。
普通に1エポックに7時間くらいかかるんですよね。
しばらくPCつけっぱにしてたんですが、面倒なので例のごとくアレを頼ります。

AWSでA100を使いたかっただけなんだ

そう!
AWSには夢のA1008枚も積んだ伝説のインスタンスタイプが存在するのです。
p4d.24xlarge
時間あたりだいたい4000円
使い方を間違えるとすぐさま支払不能な額に陥りそうな恐怖のインスタンスタイプ
でも、その性能を目の当たりにしてみたくないですか?
自宅でコンシューマ向けとはいえ二枚刺しして7時間かかる学習も、このメモリ量なら1エポックに1時間かからない想定です。
A100がどんなもんなのかとにかくぶん回してみたい…!

制限1. EC2リミット

EC2ってvCPUの数で制限がかかってるんですね(今まで意識したことがなかった)
まずEC2で動かしてみよう〜くらいのノリだったんですが、当然個人がvCPU96個とか開放されてるわけもなく、AWSのサポートに連絡です。
ちなみにこのインスタンスたしか東京はまだ来てないんじゃなかったかな。
僕は昔の名残で家で使うときはずっとオレゴンなので←
USのサポートに連絡したら、快く親切に対応してくれました。

制限2. Sagemakerのリミット

↑のEC2の件をサポートに連絡している間に、Sagemakerなら使えたりしない?という希望的観測でチャレンジしてみたところ、こっちはこっちで別の制限にひっかかる

"The account-level service limit 'ml.p4d.24xlarge for spot training job usage' is 0 Instances, with current utilization of 0 Instances and a request delta of 1 Instances. Please contact AWS support to request an increase for this limit."

なんぞ
これがまたくせ者で、サポートから探しても全然わからない
というかService Quotaの方だと思ってたらそこにはなかった
どうやらSagemakerのトレーニングジョブで起動できるインスタンスタイプと数は表には出てないが制限がある模様

EC2の件でついでに「これはどこから申請するの?」って聞いたら教えてもらえた
普通にサポートケース作って、Sagemaker選んで、トレーニングジョブのインスタンス数制限を…

っとここでまた問題ですよ
レアなインスタンスタイプだからか、プルダウンリストにp4d.24xlargeがないんですよね。。
どうしようって思ったけどとりあえず適当な別のインスタンスタイプ選んで、メッセージに「プルダウンリストになかったから適当なのを選んだだけで、本当はp4d.24xlargeが使いたいよ」って書いて送ったら丁寧に対応してくれました。
AWSのUSサポートはマジ神対応

A100で学習させてみた感想

さて、利用するための申請は大変でしたが、実際にA100を使ってGPT-2な日本語モデルの学習を進めることができたのでその感想を。

学習にはSagemakerのトレーニングジョブがいい

当初お試し的にEC2で動かす気でいましたが、実際にやってみるとSagemakerのトレーニングジョブはかなりいいです。
まず、スポットインスタンスを指定できるので、料金が安くなる
spot_instance_job
これは実際にp4d.24xlargeで3時間ほどぶん回したときの
7割オフですよ!
A100を8枚積んだ夢のマシンが時間1400円程度で使えたんですよ!!
お得感が半端ないので、機械学習やるならとりあえずSagemaker使うのがオススメです。

ただ、トレーニングジョブの作り方のルールとかの学習コストが結構高いと思いました。
特に自分の場合はローカルからDockerでやってたので結構親和性高かったんですが、Docker使ってない人は面倒かもしれないです、イメージに固めてECR上げたりしないといけないので。

ちなみにスポットインスタンス使うにはチェックポイントの設定が必須なんですが、チェックポイントはチェックポイントファイルのあるS3のディレクトリを指定すると、普通のトレーニングジョブでもちゃんと読み込んで使うことができます。
チェックポイントの設定ってスポットインスタンスが途中で終わったときとかしか使えないのかなって思ってたけどそんなことなかった、めちゃくちゃ便利。

学習結果はまずまず

ちなみに目論見通り1エポックあたり50分程度まで短縮することに成功しました!(メモリ量が増えてバッチサイズ増やせるから当たり前)
しかし、val_lossが3.xxあたりから下がり方が緩くなってしまって
1エポック回しても0.1とかしか下がらないと、さすがに1400円かけてこれはちょっとって感じで…
なんでしょうね、これは学習データだったり自分のセンスのなさが原因だなあと思います

…が、みなさん実際どれくらいなんだろうと思いました。
GPT-2の日本語モデルって結構挑戦してる方いらっしゃるイメージですが、みんなどんな感じなんだろう。
1くらいまでは下がるかなーとか期待してたけど甘い?
まだまだ研究の余地がありますね

今後

まずはもう少し汎用言語モデルを鍛えたいですね。
データ増やすとかして。
その後は対話生成をやりたいんですよ。
Twitterデータとか集めて、って感じなんですかね。
海外だとRedditデータを集めて学習させたりしてますよね、あれって日本だとなにになるんだろう?とか