七難ハック
Llama2の継続事前学習をしてみよう
最終更新: 2024/02/25

産経の記事、読みました?そうですアレのことです。何度も何度も某氏の名前を出したというのが最後に書いてありましたが、それを見て僕は「変わらないな」と思ったのでした。

あのときあなたに出会わなければ僕はいまエンジニアになっていないと思う。それほどにも人生に影響を与えた人が、いまも変わらず最前線で自分自身と戦い続けているのを見ると、とても勇気づけられる。

僕たちは不幸になることでしかアウトプットを生み出せない


Llama2の継続事前学習をしよう

どうも、社会不適合士1級の@ken11です。今日はタイトルにもある通り、Llama2(とか、いわゆるローカルLLM)の継続事前学習について話しましょう。

ところで継続事前学習ってなんだ?

継続事前学習とは

そもそも継続事前学習ってなんだよという話なんですが、まあ要は事前学習済みモデルとか基盤モデルとか呼んだりされるものに対して、追加で事前学習を施そうという試みです。

いわゆる日本語Llamaとか日本語Mistralとかって呼んでるのは基本的には公開されているLlama2やらMistralやらのモデルに対して日本語データで追加の事前学習をしているものになります。

(一部、Llamaのモデルアーキテクチャを使ってゼロからフルスクラッチで事前学習をしているモデルも存在します。これはめちゃくちゃ労力がかかるすごいことですね)

追加で事前学習を行う場合のアプローチ

では追加で事前学習を施す場合どういったアプローチがあるでしょうか?大きく分けて2種類存在します。

そのまま学習を実行する

一つは当然ながらそのまま学習を実行するパターンです。この場合、CausalLMとしての学習をそのまま新しいデータセットで実行するということになります。

当然ながらモデルパラメータが大きいものになれば、必要となるメモリサイズも大きくなるので、Llama2の7Bでやろうとするとそれなりのマシンが必要になってきます。

LoRAを使う

もう一つはLoRAを使う方法です。LoRA自体の説明はここではしませんが、要はメモリ消費を抑えて学習をする手段です。継続事前学習についてはLoRAではよくないのでは(フルファインチューンするべきなのでは)みたいに思ったりもするのですが、意外とやっている人もいたりして、なによりインフラ面で背に腹は変えられないケースもあるのでこれはこれとして手段の一つになるのだなというのが所感です。

いま世に多く出回っている日本語hogehoge的なモデル群が、いったいどの手法で追加の事前学習を行ったものかはほとんどわかっていないのが実情です。

実際にやってみる

では、実際に追加の事前学習をしてみましょう。

環境

今回は試しにやってみるということで、お金をかけず手元の環境で実行したいと思います。

僕がいま使っている環境は、Turing世代の2080TiとAmpere世代の3080Tiの2枚を挿したマシンです。合計のメモリサイズ的には20GBちょっと、LLMに使うには心もとないスペックとなっています。

そのため、今回の実験ではTinyLlamaを使います。このモデルは、Llama2アーキテクチャですが、1.1Bサイズで作成されたモデルです。もともとMetaが発表したLlama2シリーズ自体は最小でも7Bで、手元の環境で実験するにはやや厳しいものがありますので、今回はこのTinyLlamaを使いたいと思います。

学習方法

学習は今回、QLoRAでやりたいと思います。LoRAをさらにQuantizeして4bitで計算する手法です。

なぜこの手法をとるかというと、そこまでしないとメモリサイズが苦しいからです。世知辛いというか、LLMやっぱり強いインフラがないと厳しいものがあるな……

学習の実行

そして今回の学習にはLLaMA-Factoryを使います。これはLLMの学習を(継続事前学習のみならずSFTなど含め)全部いい感じに実行できるようにしてくれるツールです。

やっていることとしてはtransformersのTrainer等を実行しているだけなんですが、それにしたって多様なオプションなどをいい感じにして実行してくれるのでとても助かります。

我々の関心事としては data ディレクトリ内にある dataset_info.json と実行時のオプションだけで済みます。ここ数ヶ月で出会ったツールのなかでも最もクールだと思いますね。

というわけで今回はこのLLaMA-Factoryを使って、試しに日本語WikipediaのデータでTinyLlamaの追加事前学習を実行します。

accelerate launch src/train_bash.py \
    --stage pt \
    --do_train \
    --model_name_or_path TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
    --dataset wikipedia_ja \
    --finetuning_type lora \
    --lora_target all \
    --output_dir $OUTPUT_DIR \
    --use_fast_tokenizer \
    --streaming \
    --preprocessing_num_workers 16 \
    --ignore_pad_token_for_loss \
    --max_steps 40000 \
    --push_to_hub False \
    --quantization_bit 4 \
    --overwrite_cache \
    --per_device_train_batch_size 16 \
    --gradient_accumulation_steps 1 \
    --lr_scheduler_type cosine \
    --logging_steps 1 \
    --save_steps 1000 \
    --save_total_limit 3 \
    --learning_rate 1e-5 \
    --num_train_epochs 1.0 \
    --fp16 \
    --report_to "wandb"

こんな感じで実行するだけです。簡単すぎて感動……昔は自分で一生懸命学習コードを書いていたのに…時代がよくなったなあ。

ポイント

  • 手元の環境にTuringが混ざっているのでbf16は使えてません
  • 同様の理由でflash_attnも使えてません
  • unslothも入れたらもっと早くなる?
  • per_device_train_batch_size * gradient_accumulation_steps = 16になるようしないといけないそうです

学習の様子

まだ学習途中ですが、こんな感じで順調にlossは下がっているので、学習自体はうまくいってそうです。

気になる時間等々ですが

  • メモリはこれでほぼフルフル(20GB)
  • GPUごとに16バッチ、1ステップで計32バッチ進む想定です。日本語Wikipediaのデータセットが130万件くらいなので、40000ステップほどやると1エポックですね。
  • 40000ステップで48時間かかる見込みです(丸二日)

世代の古いGPUがいるとはいえ、20GBメモリで1.1Bのモデルを学習するのに、130万件で2日かかるというのはなかなかしんどいですね。実際にはoscarなどもっと大規模なデータセットで学習させていく必要があるので、このままでは数ヶ月単位の時間がかかってしまいます。しかもQLoRAですし、これFull Finetuneだったらものすごく大変ってことですね。

一口にLLMと言っても、1B程度でもこんなに大変な学習をしているということがよくわかる結果です。

まとめ

今回はLlama2の継続事前学習を実際にトライしてみようということで、LLaMA-Factoryを使って、試しに日本語WikipediaのデータでTinyLlamaの追加事前学習をしてみました。

SFTの話は結構見つかるんですけどね、意外とここの情報は少なくて、なかなか苦労しました。

そういう意味でもLLaMA-Factoryのようなツールは非常に助かります。これで継続事前学習の手段等はよくわかったので、ローカルではなくクラウドを使ったりしながら、いろいろ試していきたいなと思います。