読者です 読者をやめる 読者になる 読者になる

Gunosyデータ分析ブログ

Gunosyで働くデータエンジニアが知見を共有するブログです。

いまさら聞けない機械学習の評価関数

評価関数 機械学習 勉強会

アライアンス事業開発部の大曽根(@dr_paradi)です。 ニュースパスというアプリの分析と開発を行っております。

今回は機械学習の評価関数のお話をします。 内容は、【FiNC×プレイド】Machine Learning Meetup #1 - connpassで発表したものになります。

発表資料

www.slideshare.net

機械学習における評価

現在は機械学習ライブラリが充実しており、また、Webサービスの普及により学習に必要なデータの獲得も以前と比較して容易になっています。 そのため、機械学習のビジネス利用への敷居が下がっています。

予測や分類といった問題を解く際には、設定した課題に対してどのモデルが最も適しているかを評価するための指標(評価関数)が必要になります。 Kaggle*1などのコンペティションではあらかじめ評価指標が定まっています。しかし、実ビジネスで機械学習を応用する際には自ら評価指標を設定する必要があります。 さらに、適切な評価関数を選ぶのは初学者には難しく、ビジネスの問題設定、目的意識によっても異なります。 また、オフラインでの予測はユーザの実際の行動予測とはギャップがある
場合もあります*2(最適でない答えの中に異質なものが入るとついクリックしてしまうなど)。

そこで、本ブログでは多くある評価関数のうち代表的なものいくつかをまとめて解説します。 はじめに2値分類を評価する際に用いられる基礎的な評価関数を解説し、現在のKaggleのActive Competitionの評価指標を紹介します。

基礎編

2値分類 (例えばとある病気に罹患しているかどうか) の予測をする際の指標を解説します。

f:id:dr_paradi:20160804194809p:plain

正解率 (Accuracy)

この指標自体を精度という場合もあります*3。 予測結果全体と、答えがどれぐらい一致しているかを判断する指標。計算式は下記を参照。

 Accuracy = \frac{TP + TN}{TP + FP + FN + TN}

一見優れた指標に見えますが、発表スライドにもある通り、正解データの正が1%で負が99%のような場合、すべてのデータを不正解と予測するデータがある場合、99%の精度 を持つモデルと評価されてしまいます。これではよいモデルとは言えないので下記の指標を併せて使う場合が多いです。

適合率 (Precision)

偽陽性を低く抑えることを目的とする場合には適合率が高いモデルを採用します。犯罪の検挙を例にすると、一般市民を冤罪で逮捕してしまう率を低く抑えることができます。しかし、逆に真犯人を見つける確率も下がる場合もあります。

 Precision = \frac{TP}{TP + FP}

再現率 (Recall)

偽陰性を低く抑えたい場合に採用する指標です。

 Recall = \frac{TP}{TP + FN}

F値 (F-measure, F-score)

適合率と再現率はトレードオフの関係にある(どちらかが高くなるとどちらかが低くなる)ので調和平均をとった指標です。

 F{\unicode{x2013}} measure = \frac{2 \cdot Recall \cdot Precision}{Recall + Precision}

重み付きF値 (Weighted F-measure)

検索において、ユーザの反応を元に"Precisionをやや高める"など際に重み付けのF値を使う場合もあります。

 F{\unicode{x2013}} measure = \frac{\left( 1 + \beta ^ 2 \right)\cdot Recall \cdot Precision}{Recall + \beta ^ 2 \cdot Precision}

\( 0 < \beta < 1\) のときに再現率を重視し、\( 1 < \beta\) のときに適合率を重視します。

Kaggleの指標

8/3時点でActiveだったCompetitionのいくつかをピックアップして評価指標を見てみます。

Dice係数 (Dice Coefficient)

"Ultrasound Nerve Segmentation" で使われている指標 適合率などと近い指標で、ある集合を予測する際に、狭く予測した際に評価が高くなる指標 Jaccard係数などと同じく文書や単語の類似度を図る際にも用いられる(らしい)

 Dice Coefficient = \frac {2 \cdot | X | \cap | Y | }{| X | + | Y | }

\( |X| \): 予測した解の集合

\( |Y| \): 正解の解の集合です

f:id:dr_paradi:20160804201207p:plain

RMS●E系

"Grupo Bimbo Inventory Demand" ではRMSLEが使われており広く用いられる最小二乗誤差を拡張した指標になっています。

以下にいくつか代表的な最小二乗●●誤差系をまとめました。

RMSE (Root Mean Squared Error)

一般的な二乗誤差

 RMSE = \sqrt{ \frac{1}{N}  \sum_{i=1}^n \left(y_i - y_i' \right)^{2}}

RMSPE (Root Mean Squared Persentage Error)

割合の差の二乗誤差

 RMSPE = \sqrt{ \frac{1}{N}  \sum_{i=1}^n \left(\frac{y_i - y_i'}{y_i} \right)^{2}}

RMSLE (Root Mean Squared Logarithmic Error)

"Grupo Bimbo Inventory Demand" で使われている指標で、対数の差を取っています。例えば、売り上げが100円の商品を10,000円と予測した場合の差が小さく評価されます。 個人の資産の額などの桁が大きくなり対数正規分布に近い分布において有用です。

RMSLEは対数を取っているので一つの大きな間違いでの差が出にくくなっています。RMSPEも割合なので同様です。 (例えば店舗の売り上げ予測などで、一つの店舗のが異常に売り上げが高いと、最小二乗誤差の場合、その店の予測精度だけを上げればよい)

 RMSLE = \sqrt{ \frac{1}{N}  \sum_{i=0}^n \left(\log\left(y_i + 1\right) - \log\left(y_i' + 1 \right) \right)^{2}}

\( y_i' \): i番目の要素の予測値

\( y_i \): i番目の要素の正解値

Multi-class logarithmic loss

"TalkingData Mobile User Demographics" で使われている指標

多クラス分類の場合にはaccuracyを使うことも多いですが、予測モデルの出力が特定のクラスに属する確率であることが多いので、正解との距離を対数で取ったものの和を評価関数としています。

 log loss = - \frac{1}{N}  \sum_{i=1}^{n} \sum_{j=1}^{m} y_{ij} \log\left( p_{ij} \right)

\( p_{ij}' \): i番目の要素のj番目のクラスに属する予測確率

\( y_{ij} \): i番目の要素のj番目のクラスに属するかどうか (1 or 0)

f:id:dr_paradi:20160804201109p:plain

上図の場合には、予測モデル2の評価が高くなります(Multi-class logarithmic loss自体は0に近い方がよい)。

Accuracyで評価した場合には予測モデル1、予測モデル2の双方とも同じ評価になります。

その他

確率分布

カルバックライブラーダイバージェンスなどが確率分布同士の差を比較する際に利用されます。

まとめ

評価関数を淡々と紹介しましたが、最初に述べたように、オフラインでの予測はユーザの実際の行動予測とはギャップがある
場合もあります(最適でない答えの中に異質なものが入るとついクリックしてしまうなど)。

BtoCの実ビジネスに機械学習を導入する際にはABテスト*4などを用いて実際のユーザの反応を見つつ評価指標を変化させ (時には新しい指標を作成し)、改善させていくことが重要かと思います。

*1:有名なデータサイエンスのコンペティション

*2:Data-Driven Metric Development for Online Controlled Experiments: Seven Lessons Learned Xiaolin Shi*, Yahoo Labs; Alex Deng, Microsoft;
KDD '16

*3:数式などは主にこちらを参考にしています: 情報検索の基礎 Christopher D.Manning (著), Prabhakar Raghavan (著), Hinrich Schutze (著)岩野 和生ら (翻訳), 共立出版 2012

*4:参考: シリコンバレーのIT企業が利用しているA/Bテスト手法まとめ - データ分析エンジニアのブログ