Dash応用編:第8回 機械学習モデルとの連携

Python
この記事は約15分で読めます。

こんにちは、JS2IIUです。機械学習とPythonは相性抜群です。今回はDashを活用して機械学習の可視化にトライします。よろしくお願いします。

はじめに

機械学習モデルとDashアプリケーションを統合することで、ユーザーにリアルタイムで予測や分析結果を提供できるインタラクティブなアプリを構築できます。今回の記事では、scikit-learnを使用して構築したシンプルな分類モデルをDashに組み込み、ユーザーがアップロードしたデータをリアルタイムで予測するフローを作成します。

機械学習モデルの準備

まずは、scikit-learnを使ってシンプルな機械学習モデルを準備します。ここでは、Irisデータセットを使用し、RandomForestClassifierを利用した分類モデルを構築します。

from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

# Irisデータセットのロード
iris = load_iris()
X = iris.data
y = iris.target

# データをトレーニングセットとテストセットに分割
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# ランダムフォレストモデルの作成とトレーニング
model = RandomForestClassifier(n_estimators=100, random_state=42)
model.fit(X_train, y_train)

# モデルの精度を確認
accuracy = model.score(X_test, y_test)
print(f"モデルの精度: {accuracy:.2f}")

このモデルを使って、ユーザーが入力した新しいデータに対してリアルタイムで予測を行います。

コードの解説

一つ目のサンプルコードでは、scikit-learnのRandomForestClassifierを使用してIrisデータセットに基づいた分類モデルを作成しています。このコードの各部分を詳しく解説していきます。

Irisデータセットのロード

from sklearn.datasets import load_iris

まず、scikit-learnの中に含まれているIrisデータセットをインポートしています。このデータセットは、アヤメの花の特徴量(花弁の長さや幅など)に基づいて3つの種類(Setosa、Versicolor、Virginica)に分類するためによく使われます。load_iris関数を使って、データセットを簡単にロードできます。

iris = load_iris()
X = iris.data
y = iris.target

irisという変数にデータセット全体を格納します。Xには特徴量(花弁やがく片の長さなど)のデータを、yには各サンプルがどの種類に分類されるかを示すターゲットデータ(0, 1, 2のラベル)をそれぞれ格納しています。

  • X: 特徴量データ。花の計測データ(4つの数値: 花弁やがく片の長さと幅)です。
  • y: ターゲットデータ。アヤメの種類を0, 1, 2という整数で表しています。

トレーニングセットとテストセットに分割

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

train_test_split関数を使用して、データをトレーニングセット(X_train, y_train)とテストセット(X_test, y_test)に分割します。ここでは、全データの70%をトレーニング用に、30%をテスト用にしています。test_size=0.3でこの比率を指定し、random_state=42で結果が再現できるように乱数シードを固定しています。

  • X_train: トレーニング用の特徴量データ。
  • X_test: テスト用の特徴量データ。
  • y_train: トレーニング用のターゲットデータ(クラスラベル)。
  • y_test: テスト用のターゲットデータ。

ランダムフォレストモデルの作成とトレーニング

from sklearn.ensemble import RandomForestClassifier

model = RandomForestClassifier(n_estimators=100, random_state=42)
model.fit(X_train, y_train)

次に、RandomForestClassifierを使用してランダムフォレストモデルを作成します。n_estimators=100は、100本の決定木を使って分類を行う設定です。ランダムフォレストは複数の決定木を使い、それぞれの決定木の結果を多数決で集計することで、より安定した予測を行います。random_state=42は、再現性を保つために乱数シードを固定するためのオプションです。

model.fit(X_train, y_train)で、トレーニングデータを使ってモデルをトレーニングします。この段階でモデルは、トレーニングデータからパターンを学習し、Irisの種類を予測できるようになります。

モデルの精度確認

accuracy = model.score(X_test, y_test)
print(f"モデルの精度: {accuracy:.2f}")

最後に、テストデータを使ってモデルの精度を確認します。model.scoreは、テストデータに対する正解率(Accuracy)を計算します。X_testを入力し、y_testの正解ラベルと比較して、どれだけ正確に予測できたかを出力します。

例えば、accuracyが0.95であれば、95%の精度でテストデータを正しく分類できたことを示します。

解説のポイント

  • Irisデータセット: 4つの特徴量(花弁とがく片の長さと幅)と3つのクラスラベル(Irisの3種類)で構成される小規模なデータセットです。
  • ランダムフォレスト: 複数の決定木を使った分類アルゴリズムで、過学習を防ぎつつ、高い精度での分類が可能です。n_estimators=100で100本の決定木を使う設定にしています。
  • データ分割: トレーニングとテストにデータを分割し、モデルが未知のデータに対してどの程度正確に予測できるかを確認します。
  • 精度: 最終的にモデルの精度を確認することで、モデルがどれだけ正確に予測できるかを評価します。

このようにして、まずローカルでトレーニングしたモデルをDashアプリケーションに統合し、ユーザーが入力したデータに対してリアルタイムに予測を行う仕組みが実現します。

Dashアプリケーションの構築

続いて、Dashアプリを構築します。ユーザーがCSVファイルをアップロードし、そのデータを分類モデルに適用するシンプルなUIを作成します。

import dash
from dash import dcc, html
from dash.dependencies import Input, Output, State
import pandas as pd
import io

app = dash.Dash(__name__)

app.layout = html.Div([
    html.H1("機械学習モデルでの予測"),
    
    # ファイルアップロードコンポーネント
    dcc.Upload(
        id='upload-data',
        children=html.Div([
            'Drag and Dropまたは',
            html.A('ファイルを選択')
        ]),
        style={
            'width': '100%',
            'height': '60px',
            'lineHeight': '60px',
            'borderWidth': '1px',
            'borderStyle': 'dashed',
            'borderRadius': '5px',
            'textAlign': 'center',
            'margin': '10px'
        },
        multiple=False  # 一度に一つのファイルのみ
    ),
    
    html.Div(id='output-data-upload'),
])

# ファイルの処理
def parse_contents(contents):
    content_type, content_string = contents.split(',')
    decoded = io.StringIO(pd.read_csv(io.BytesIO(base64.b64decode(content_string))))
    df = pd.read_csv(decoded)
    return df

@app.callback(
    Output('output-data-upload', 'children'),
    Input('upload-data', 'contents')
)
def update_output(contents):
    if contents is not None:
        df = parse_contents(contents)
        predictions = model.predict(df)  # ユーザーがアップロードしたデータに対して予測
        df['予測結果'] = predictions
        return html.Div([
            html.H5("予測結果:"),
            dcc.Graph(
                id='result-graph',
                figure={
                    'data': [
                        {'x': df.index, 'y': df['予測結果'], 'type': 'bar', 'name': '予測結果'},
                    ],
                    'layout': {'title': '予測結果の可視化'}
                }
            ),
            html.Hr(),
            html.Pre(df.to_csv(index=False), style={'whiteSpace': 'pre-wrap'})
        ])

if __name__ == '__main__':
    app.run_server(debug=True)

このDashアプリケーションでは、dcc.Uploadを利用してユーザーがファイルをアップロードし、その内容をscikit-learnでトレーニングしたモデルに入力して予測結果を表示しています。

プログラムの全体像

このサンプルプログラムでは、ユーザーがアップロードしたCSVファイルから機械学習モデルにデータを入力し、予測結果を表示する仕組みを実装しています。以下の手順で進めます。

  1. ユーザーがCSVファイルをアップロード。
  2. アップロードされたデータを前処理して、モデルに適用できる形に整形。
  3. 事前に学習させた機械学習モデルを用いて、アップロードされたデータに対する予測を実行。
  4. 結果を表示。

解説

dcc.Uploadを使ったCSVアップロード

dcc.Uploadコンポーネントを使って、ユーザーがCSVファイルをアップロードできるようにしています。アップロードしたファイルの内容は、コールバック関数update_outputで処理されます。

dcc.Upload(
    id='upload-data',
    children=html.Div(['Drag and Drop or ', html.A('Select a File')]),
    style={
        'width': '100%',
        'height': '60px',
        'lineHeight': '60px',
        'borderWidth': '1px',
        'borderStyle': 'dashed',
        'borderRadius': '5px',
        'textAlign': 'center',
        'margin': '10px'
    },
    multiple=False
)

dcc.Uploadは、ファイルのドラッグ&ドロップや選択を受け付けるコンポーネントです。multiple=Falseにより、複数ファイルではなく1つのファイルのみをアップロード可能にしています。

アップロードファイルの内容をパース

parse_contents関数で、アップロードされたCSVファイルの内容をデコードし、pandasのDataFrameとして読み込みます。

content_type, content_string = contents.split(',')
decoded = base64.b64decode(content_string)
df = pd.read_csv(io.StringIO(decoded.decode('utf-8')))

contentsに含まれているデータは、Base64エンコードされた文字列形式のため、これをデコードしてCSV形式に変換します。

特徴量のチェックと予測

読み込んだDataFrameに対して、機械学習モデルが必要とする列(Irisデータセットの4つの特徴量)が揃っているか確認します。

if set(['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']).issubset(df.columns):
    predictions = model.predict(df[['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']])
    df['prediction'] = predictions

指定された列が揃っている場合、その特徴量データを使って事前に学習したRandomForestClassifierモデルで予測を行い、結果をDataFrameに追加します。

結果の表示

予測結果を表示する際には、処理された行数やCSVファイルの内容をプレーンテキストとして表示します。

html.Pre(df.to_csv(index=False, header=True), style={'whiteSpace': 'pre-wrap', 'wordBreak': 'break-all'})

ここでは、DataFrameの内容をCSV形式に変換し、予測結果とともに出力しています。

アップロード用CSVファイルのサンプル

以下は、このプログラムで使用可能なCSVファイルの例です。このCSVファイルには、Irisデータセットと同様の4つの特徴量が含まれています。

sepal length (cm),sepal width (cm),petal length (cm),petal width (cm)
5.1,3.5,1.4,0.2
4.9,3.0,1.4,0.2
5.8,2.6,4.0,1.2
6.7,3.1,4.7,1.5
6.3,2.9,5.6,1.8

このファイルをアップロードすることで、各サンプルのアヤメの種類(Setosa, Versicolor, Virginica)が予測され、Dashアプリケーションに結果が表示されます。

まとめ

このサンプルプログラムは、ユーザーがアップロードしたCSVファイルを機械学習モデルに適用し、リアルタイムで結果を予測する仕組みを実装しています。データ前処理、モデルの適用、そして予測結果の表示までを一連の流れとして実現しています。

ユーザー入力データの処理とモデル適用

アップロードされたデータは、pandasを使ってDataFrameに変換し、モデルに適用されます。ここで注意したいのは、アップロードされたデータがモデルで処理可能な形式になっていることを確認する必要がある点です。たとえば、Irisデータセットでは4つの特徴量が必要なので、ユーザーのデータも同じ構造を持っている必要があります。

parse_contents関数で、アップロードされたファイルの内容をパースし、predict関数を使って予測を行います。その後、結果を元のDataFrameに追加して返します。

def parse_contents(contents):
    content_type, content_string = contents.split(',')
    decoded = base64.b64decode(content_string)
    decoded_str = io.StringIO(decoded.decode('utf-8'))
    df = pd.read_csv(decoded_str)
    return df

結果の可視化

結果は棒グラフとして表示され、dcc.Graphコンポーネントを使って視覚的にフィードバックします。これにより、ユーザーはアップロードしたデータの分類結果を即座に確認できます。

dcc.Graph(
id='result-graph',
figure={
'data': [
{'x': df.index, 'y': df['予測結果'], 'type': 'bar', 'name': '予測結果'},
],
'layout': {'title': '予測結果の可視化'}
}
)

まとめ

今回の記事では、Dashアプリケーションに機械学習モデルを統合し、ユーザーがアップロードしたデータをリアルタイムで処理し、予測結果を可視化する方法を解説しました。このようにして、scikit-learnなどのライブラリで構築したモデルを手軽にインタラクティブなアプリケーションに組み込むことが可能です。次回は、アプリケーションのデプロイと自動化について解説します。

Dash関連記事まとめ

DashはJavaScriptライブラリであるReactの上に構築されたPythonフレームワークであるが、DashはRでも動作し、最近ではJuliaもサポートしている。
Wikipedia – Plotly/Dash から引用、翻訳

コメント

タイトルとURLをコピーしました