サイトアイコン アマチュア無線局JS2IIU

HTTP APIで使えるシンプルな推論サーバをFlaskで作る

推論サーバをFlaskで構築する

推論サーバをFlaskで構築する

こんにちは、JS2IIUです。
機械学習モデルを実際のアプリケーションで利用するためには、モデルをAPI化し、外部から推論リクエストを受け取れるようにする必要があります。例えば、Webアプリやモバイルアプリから画像やテキストを送信し、その結果を返すといった仕組みです。

このような「モデルのAPI化」には、FlaskやFastAPIといった軽量なPythonフレームワークがよく使われます。本記事では、Flaskを用いて学習済みPyTorchモデルをHTTP APIとして公開する方法を、初学者にもわかりやすく解説します。

また、推論エンドポイント /predict の実装だけでなく、バッチ推論の処理方法スレッド安全性の確保、そして簡単な負荷テスト方法についても取り上げます。

本記事を読み終えるころには、自分のモデルをFlaskサーバとしてデプロイし、HTTP経由で推論を呼び出す方法を理解できるようになります。今回もよろしくお願いします。

FlaskとHTTP APIの基礎

まずはFlaskの基本から確認しましょう。FlaskはPythonで書かれた軽量なWebフレームワークで、数行のコードでHTTPサーバを立ち上げることができます。HTTP API(Application Programming Interface)は、アプリケーション間でデータをやり取りするための仕組みで、リクエストとレスポンスで構成されます。

以下のコードは、最も基本的なFlaskサーバの例です。

Python
from flask import Flask, jsonify

app = Flask(__name__)

@app.route('/')
def hello():
    return jsonify({'message': 'Hello, Flask API!'})

if __name__ == '__main__':
    app.run(debug=True)

このスクリプトを実行すると、http://127.0.0.1:5000/ にアクセスした際に "Hello, Flask API!" というJSONレスポンスが返ります。
@app.route('/') はエンドポイントを定義するデコレータで、HTTPリクエストを受け取ったときの処理を指定します。

これを応用して、学習済みモデルにデータを渡し、予測結果を返すAPIを構築します。

Flaskに関する関連情報リンク

学習済みPyTorchモデルを読み込む

次に、推論に使うモデルをPyTorchで読み込みます。ここでは、説明を簡単にするため、事前に保存されたダミーの線形モデルを使用します。

Python
import torch
import torch.nn as nn

# シンプルな線形モデルを定義
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.linear = nn.Linear(3, 1)  # 入力3次元、出力1次元

    def forward(self, x):
        return self.linear(x)

# 学習済みモデルの読み込み(ここではダミー)
model = SimpleModel()
model.load_state_dict(torch.load('simple_model.pt', map_location='cpu'))
model.eval()  # 推論モードに設定

model.eval() を呼び出すことで、ドロップアウトやバッチ正規化などが推論用の動作になります。
このようにモデルをロードしたあと、Flaskサーバ内で利用できるように保持します。

ダミーのsimple_model.ptを作成する方法

テストや記事のハンズオンで手早く動作を確認したい場合は、あらかじめ簡単なダミーモデル(あるいは短時間学習したモデル)の重みを simple_model.pt として保存しておくと便利です。ここでは素早く試せる2つの方法(1. 再現性のある「固定重み」ダミー、2. 小さなデータで少しだけ学習して保存)を示します。どちらもローカル環境で数秒〜数分で作成できます。

1) 最速:固定重みのダミーモデルを作る(テスト向け)

入力形状やモデル定義が本体の Flask コードと一致していれば、推論フローや API レスポンスの確認だけを行う目的で固定重みを保存するのが最も手早いです。再現性のために重みを定数で埋めています。

save_dummy_model.py:

Python
import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 1)

    def forward(self, x):
        return self.linear(x)

if __name__ == "__main__":
    model = SimpleModel()
    # 再現性のために重みを固定値で埋める
    with torch.no_grad():
        model.linear.weight.fill_(0.1)
        model.linear.bias.fill_(0.0)

    # state_dict を保存するのが推奨
    torch.save(model.state_dict(), "simple_model.pt")
    print("Saved simple_model.pt (state_dict)")

ターミナルで実行:

Bash
python3 save_dummy_model.py

これで同ディレクトリに simple_model.pt が作成され、記事内の Flask サンプルでそのまま読み込めます。

2) 少しだけ学習して保存する(実データに近い挙動を確認したい場合)

実際の学習済み重みで挙動を確認したい場合は、小さなダミーデータで短時間だけ学習して state_dict を保存します。モデルの保存方法は実運用と同じにしておくと移行がスムーズです。

train_and_save.py:

Python
import torch
import torch.nn as nn
import torch.optim as optim

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 1)

    def forward(self, x):
        return self.linear(x)

if __name__ == "__main__":
    # ダミーデータ(入力:3次元 -> 出力:1次元)
    X = torch.tensor([[1.0, 2.0, 3.0],
                      [0.5, 0.2, 0.1],
                      [2.0, 1.0, 0.0]], dtype=torch.float32)
    y = torch.tensor([[1.0], [0.5], [1.8]], dtype=torch.float32)

    model = SimpleModel()
    optimizer = optim.SGD(model.parameters(), lr=0.1)
    loss_fn = nn.MSELoss()

    for epoch in range(200):
        optimizer.zero_grad()
        preds = model(X)
        loss = loss_fn(preds, y)
        loss.backward()
        optimizer.step()

    torch.save(model.state_dict(), "simple_model.pt")
    print("Trained (briefly) and saved simple_model.pt")

実行:

Bash
python3 train_and_save.py

短時間学習でも、入力に対する出力の分布がより現実に近くなり、API のエンドツーエンド検証がやりやすくなります。

推奨と注意点

Python
torch.save({
    'model_state': model.state_dict(),
    'input_size': 3,
    'notes': 'dummy model for demo'
}, 'simple_model_bundle.pt')

動作確認(Flask との接続テスト)

記事の Flask サンプルと同じディレクトリに simple_model.pt を置いたら、Flask アプリを起動して以下の curl で推論が返るか確認します。

Bash
curl -X POST -H "Content-Type: application/json" \
  -d '{"inputs": [[1.0, 2.0, 3.0]]}' \
  http://127.0.0.1:5000/predict

期待通り JSON が返れば成功です。

上のスクリプトは記事の読者がすぐに試せるように最小限にまとめています。実運用を想定する場合は、モデルのバージョン管理(ファイル名にバージョンを含める、メタデータで保存する)、認証、外部から供給されたモデルの検証などを追加してください。

/predict エンドポイントの実装

次に、Flaskサーバに /predict エンドポイントを追加し、クライアントから送られてきたデータをモデルに入力して予測結果を返すようにします。

Python
from flask import Flask, request, jsonify
import torch

app = Flask(__name__)

# すでにロード済みのモデルを使用
model = SimpleModel()
model.load_state_dict(torch.load('simple_model.pt', map_location='cpu'))
model.eval()

@app.route('/predict', methods=['POST'])
def predict():
    # クライアントから送られたJSONデータを取得
    data = request.get_json()

    # 入力データの検証
    if 'inputs' not in data:
        return jsonify({'error': 'Missing "inputs" key'}), 400

    # テンソルに変換
    inputs = torch.tensor(data['inputs'], dtype=torch.float32)

    # 推論を実行(勾配は不要)
    with torch.no_grad():
        outputs = model(inputs)

    # 結果をPythonのリストに変換して返す
    predictions = outputs.numpy().tolist()
    return jsonify({'predictions': predictions})

このエンドポイントは、JSON形式の入力を受け取り、推論結果をJSONで返します。
例えば、次のようなリクエストを送ると動作します。

Bash
curl -X POST -H "Content-Type: application/json" \
    -d '{"inputs": [[1.0, 2.0, 3.0], [0.5, 0.2, 0.1]]}' \
    http://127.0.0.1:5000/predict

出力は次のような形式になります。

JSON
{
  "predictions": [[2.345], [0.983]]
}

このように、Flaskはリクエストデータの処理やレスポンスの返却をシンプルに記述できます。

バッチ推論とスレッド安全性

サーバは同時に複数のリクエストを受け取る可能性があります。そのため、スレッド安全性を確保しつつ、複数データをまとめて処理する「バッチ推論」を行う設計が望ましいです。

PyTorchのモデルはスレッド安全ではない場合があるため、グローバルロックを使う方法があります。

Python
import threading

model_lock = threading.Lock()

@app.route('/predict', methods=['POST'])
def predict():
    data = request.get_json()
    inputs = torch.tensor(data['inputs'], dtype=torch.float32)

    with model_lock:  # 同時アクセスを防ぐ
        with torch.no_grad():
            outputs = model(inputs)

    return jsonify({'predictions': outputs.numpy().tolist()})

このようにロックを導入することで、同時リクエストが発生しても安全に処理できます。
Flaskを本番環境で動かす際は、Gunicorn のようなWSGIサーバを利用してマルチワーカーで動作させることが一般的です。

起動例:

Bash
gunicorn -w 4 app:app

ここで -w 4 はワーカー数を示し、同時に4つのリクエストを処理できます。

簡単な負荷テストを行う

API を構築したら、並列リクエストや高い同時接続数に対してサービスがどのように振る舞うかを確認する「負荷テスト(ロードテスト)」を行いましょう。

負荷テストは「何に負荷をかけるか」を明確にするところから始めます。代表的な対象は次の通りです

負荷をかけると何が起きるか(典型的な挙動)

これらの現象を観測するため、負荷テストでは下記の指標を取得して比べます。

スレッドとスレッドプールの概念は以下の通りです。

注: サーバ側(Flask)を本番で動かす場合、Gunicorn のような WSGI サーバで「ワーカー(プロセス)」と「各ワーカーのスレッド数」を組み合わせてチューニングするのが一般的です。CPU バウンドな推論処理はプロセス単位で分散する方が効果的なケースがあります。

簡単な負荷テストの実例(Python スクリプト)

以下は記事の範囲で気軽に試せるスクリプトです。ローカルの Flask サーバ(記事の /predict)に対して並列にリクエストを投げ、結果を集めます。

Python
import requests
import concurrent.futures
import time

def send_request(i):
    url = "http://127.0.0.1:5000/predict"
    payload = {"inputs": [[1.0, 2.0, 3.0]]}
    start = time.time()
    try:
        r = requests.post(url, json=payload, timeout=5)
        elapsed = time.time() - start
        return {'idx': i, 'status': r.status_code, 'elapsed': elapsed, 'body': r.json() if r.ok else None}
    except Exception as e:
        return {'idx': i, 'status': 'error', 'error': str(e)}

if __name__ == '__main__':
    # 同時に送るスレッド数(ここは増やして試す)
    CONCURRENCY = 10
    with concurrent.futures.ThreadPoolExecutor(max_workers=CONCURRENCY) as ex:
        futures = [ex.submit(send_request, i) for i in range(CONCURRENCY)]
        results = [f.result() for f in concurrent.futures.as_completed(futures)]

    # 結果の簡易集計
    successes = [r for r in results if r.get('status') == 200]
    errors = [r for r in results if r.get('status') != 200]
    latencies = [r['elapsed'] for r in successes]
    print(f"total: {len(results)}, success: {len(successes)}, errors: {len(errors)}")
    if latencies:
        print(f"avg latency: {sum(latencies)/len(latencies):.3f}s, p95: {sorted(latencies)[int(len(latencies)*0.95)-1]:.3f}s")
    print(errors)

使い方のヒント

どう解釈するか?

最後に、本番運用では負荷試験は単発で終わらせず、デプロイ前の回帰テストや負荷変動シナリオ(長時間テスト、スパイクテスト)も組み合わせて実施することをおすすめします。

まとめ

本記事では、Flaskを使って学習済みPyTorchモデルをHTTP APIとして提供する方法を解説しました。
主なポイントを振り返ります。

  1. Flaskは軽量でシンプルなWebフレームワークであり、少ないコードでAPIを構築できる。
  2. PyTorchモデルを model.eval() で推論モードにし、/predict エンドポイントでHTTP経由の入力を処理できる。
  3. 同時アクセスに備えて threading.Lock を使いスレッド安全性を確保する。
  4. 本番運用にはGunicornなどのWSGIサーバを使うことで安定したスケーラビリティを実現できる。
  5. 簡単な負荷テストを行うことで、サーバの性能を検証できる。

このように、Flaskを使えばPythonで手軽に推論サーバを構築でき、AIモデルを外部アプリケーションから利用できるようになります。
次のステップとして、FastAPIによる非同期推論やDockerでのデプロイにも挑戦してみるとよいでしょう。

最後まで読んでいただきありがとうございました。

モバイルバージョンを終了