七難ハック
gRPCを活用してLLM推論サーバをGoでつくる
最終更新: 2024/12/08

どうも、ken11です。おかげさまで、自分も著者の1人として参加させていただいた「事例でわかるMLOps」も無事に出版され、引き続きMLOps芸人として生きてます。

さて、MLOps芸人をしているとどうしてももどかしく感じることの1つとして、言語・フレームワーク選択肢の少なさというものがあります。どうしたってPythonを使わざるを得ないというのが現実で、推論用のWebサーバを構築しようとすると、FastAPIやFlaskを使うことになりがちです。

そしてサービス自体のバックエンドはRuby on RailsやGolang Echoなどで構築され、機械学習APIサーバは必然的にマイクロサービス化するのです。それはそれでメリットでもあるのですが、できればHTTPの通信は減らしたいですし、そもそも開発者の少ない場面ではどうせサービス自体のバックエンドも機械学習の推論APIもメンテナンスするのは自分なのになぜ別にしないといけないのかという気持ちになることもあります。

というわけで、本記事では、実際にMetaのLlama 3.2 1B Instructモデルを用いて、PythonでgRPCサーバーを構築し、それをGolangで利用する推論システムの作成手順を解説します。また、Flaskを利用した一般的なPythonサーバーとの比較も行います。


gRPCとは

gRPC(gRPC Remote Procedure Call)は、Googleが開発した高性能なオープンソースの通信フレームワークです。HTTP/2を基盤としており、高速かつ効率的な通信を実現します。また、Protocol Buffers(Protobuf)というシリアライズフォーマットを使用することで、データのやり取りを軽量化できます。


必要な準備

必要なツールとライブラリ

  • Python:
    • grpcio, grpcio-tools, transformers, torch
  • Go:
    • protoc, protoc-gen-go, protoc-gen-go-grpc, labstack/echo
  • 共通:
    • protoc(Protocol Buffersコンパイラ)

環境構築

# Pythonライブラリのインストール
pip install grpcio grpcio-tools transformers torch

# Goライブラリのインストール
go install google.golang.org/protobuf/cmd/protoc-gen-go@latest
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@latest
go get github.com/labstack/echo/v4

gRPCサーバー(Python)の構築

inference.proto の作成

以下の.protoファイルを作成します。このファイルは、gRPC通信で使用されるサービスとメッセージの定義です。

syntax = "proto3";

option go_package = "./your-project-grpc;your-project-grpc";

service LanguageModelService {
  rpc GenerateText (GenerateTextRequest) returns (GenerateTextResponse);
}

message GenerateTextRequest {
  string prompt = 1;
  int32 max_tokens = 2;
}

message GenerateTextResponse {
  string generated_text = 1;
}

gRPCサーバーの実装

PythonでgRPCサーバーを構築します。Llamaモデルをロードし、クライアントから受け取ったプロンプトに応じたテキスト生成を行います。

from concurrent import futures
import grpc
import inference_pb2
import inference_pb2_grpc
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16)
model.to("cuda")

class LanguageModelService(inference_pb2_grpc.LanguageModelServiceServicer):
    def GenerateText(self, request, context):
        inputs = tokenizer(request.prompt, return_tensors="pt", truncation=True)
        outputs = model.generate(
            inputs["input_ids"].to("cuda"),
            max_new_tokens=request.max_tokens,
            pad_token_id=tokenizer.eos_token_id,
        )
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return inference_pb2.GenerateTextResponse(generated_text=generated_text)

def serve():
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    inference_pb2_grpc.add_LanguageModelServiceServicer_to_server(
        LanguageModelService(), server
    )
    server.add_insecure_port("[::]:50051")
    server.start()
    server.wait_for_termination()

if __name__ == "__main__":
    serve()

このサーバーは、クライアントから受け取ったリクエストを処理し、Llamaモデルで生成したテキストを返します。

python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. inference.proto

gRPCのPythonコードを生成します。


gRPCクライアント(Golang)の構築

Echoを利用したWebサーバーを構築します。このサーバーは、HTTPリクエストを受け付け、gRPCサーバーにリクエストを転送してレスポンスを返します。

go mod init your-project
go get google.golang.org/grpc
go get google.golang.org/protobuf

プロジェクトの初期設定をしておきます。

package main

import (
	"context"
	"fmt"
	"log"
	"net/http"
	"time"

	pb "your-project/your-project-grpc"

	"github.com/labstack/echo/v4"
	"google.golang.org/grpc"
)

type Request struct {
	Prompt    string `json:"prompt"`
	MaxTokens int32  `json:"max_tokens"`
}

type Response struct {
	GeneratedText string `json:"generated_text"`
}

func main() {
	conn, err := grpc.Dial("localhost:50051", grpc.WithInsecure())
	if err != nil {
		log.Fatalf("failed to connect to gRPC server: %v", err)
	}
	defer conn.Close()
	client := pb.NewLanguageModelServiceClient(conn)

	e := echo.New()
	e.POST("/generate", func(c echo.Context) error {
		req := new(Request)
		if err := c.Bind(req); err != nil {
			return c.JSON(http.StatusBadRequest, map[string]string{"error": "invalid request"})
		}

		ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
		defer cancel()
		gRPCReq := &pb.GenerateTextRequest{
			Prompt:    req.Prompt,
			MaxTokens: req.MaxTokens,
		}

		gRPCResp, err := client.GenerateText(ctx, gRPCReq)
		if err != nil {
			log.Printf("gRPC error: %v", err)
			return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to generate text"})
		}

		resp := Response{
			GeneratedText: gRPCResp.GeneratedText,
		}
		return c.JSON(http.StatusOK, resp)
	})

	port := ":8080"
	fmt.Printf("Starting server on port %s...\n", port)
	if err := e.Start(port); err != nil {
		log.Fatalf("failed to start server: %v", err)
	}
}
protoc --go_out=. --go-grpc_out=. inference.proto

Go用のコードも生成します。


Flaskサーバー(比較用)の構築

比較用に以下のようなシンプルなFlaskサーバーを用意します。実行内容は同じです。

from flask import Flask, jsonify, request
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

app = Flask(__name__)
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16)
model.to("cuda")

@app.route("/generate", methods=["POST"])
def generate_text():
    data = request.get_json()
    prompt = data.get("prompt", "")
    max_tokens = data.get("max_tokens", 50)

    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    outputs = model.generate(inputs["input_ids"].to("cuda"), max_new_tokens=max_tokens)
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return jsonify({"generated_text": generated_text})

if __name__ == "__main__":
    app.run(host="0.0.0.0", port=5000)

性能テストスクリプト

レスポンスタイムを比較するために、以下のスクリプトを実行します。

#!/bin/bash
echo "Testing Python Server:"
for i in {1..10}; do
    curl -X POST http://localhost:5000/generate \
         -H "Content-Type: application/json" \
         -d '{"prompt": "What is the capital of France?", "max_tokens": 50}' \
         -w "\nTotal time: %{time_total}s\n" -o /dev/null -s
done

echo "Testing Golang Server:"
for i in {1..10}; do
    curl -X POST http://localhost:8080/generate \
         -H "Content-Type: application/json" \
         -d '{"prompt": "What is the capital of France?", "max_tokens": 50}' \
         -w "\nTotal time: %{time_total}s\n" -o /dev/null -s
done

実行結果と比較

それぞれの平均レスポンスタイムを計算し、PythonとGolangのパフォーマンス差を評価します。

まずmax_tokensを50にして10回リクエストした結果、Pythonだけを使った今までどおりのFlaskサーバが平均0.46秒、GoとgRPCを使った今回の実装が0.51秒と差は僅かでした。

出力される文字列長によってばらつきがあるので、max_tokensを10まで減らして実行したところPythonだけの場合が0.1500秒、GoとgRPCの場合が0.1515秒となりました。

実用上ほとんど気にする必要がないレベルの遅延だということがわかります。むしろ、Goを使うことでリクエスト処理の並列化など、効率的に処理できる部分はあるので、速度向上を見込むこともできるでしょう。また、今回はトークナイズからPythonに渡してしまっていますが、実際にはトークナイズのような前処理はGo側で並列して実行できるので効率を上げることができると思います。


まとめ

自分はMLOpsの中でも特に推論の効率化や持続可能な推論サービスの提供について力を入れてきました。今年も振り返ってみるとそんなようなことばかりたくさん考えていましたが、ずっと頭の片隅にあって実現しなかったことの1つが今回の推論サーバアーキテクチャです。

実際にやってみるとそんなに大変なことは多くなく、スムーズに扱えることもわかったので、今後は基本的にこの構成で推論サーバをつくろうかなあと考えています。