ナイーブベイズの概要
ナイーブベイズは教師ありの分類アルゴリズムの一つである。
計算アルゴリズムとしてはベイズ定理を用いている。
機械学習における特徴としてはいかがある
- 実装が簡単であり、複雑なハイパーパラメータの調整は必要ない
- 大規模なトレーニングデータを必要としない
- 処理が高速であり、利用するコンピュータ資源が少ない
- 大きなデータセットについても有効である(3による)
- 取り急ぎのベンチマークとして利用できる(1, 2, 3による)
分類問題
ナイーブベイズが得意とするのは分類問題である。分類問題とは与えられたデータをあるルールによって定義されたカテゴリに分類することである。
分類問題のもっとも簡単な例として2値分類である。これはYes-No、プラス-マイナスなど2項に分類する。
分類問題の例として最も知られているのが受け取ったメールがスパムメールであるのか普通のメールであるかの分類である。また商品に関するアンケートを取って、それが商品に対してプラスのイメージかマイナスのイメージを持つか判断するのも分類問題である。
似た概念としてクラスタリングがある。しかし機械学習では分類問題とクラスタリングは異なる。
分類問題ではモデル作成者が正解のラベルを指定する。それに対して、クラスタリングでは機械学習のプロセスを通して各データの分類ラベルを生成する。
分類問題と単回帰分析の関係
機械学習で最も簡単なアルゴリズムである単回帰分析は分類問題に使えるのだろうか。これは機械学習の初心者にとっても少々わかりにく概念である。
単回帰分析は目的変数は連続値となる。そのために得られる値は理論的には-∞ < y < ∞である。それに対して分類問題の場合もっとも単純な2値分類であれば目的変数は2値である。今簡単に0と1とする。その場合には単回帰分析の結果を0もしくは1にマッピングしなければならない。
しかしそもそも0と1には数字的な関係(大小、計算)はないので、これをマッピングすることはできない。
よって単回帰分析は分類問題に適用することはできず、別のアルゴリズムを考える必要がある。分類に関しては様々なアルゴリズムがあり、この中でもっとも簡単で分かりやすい
ナイーブベイズモデル
モデルの構築
# iris データの読み込み from sklearn.datasets import load_iris iris = load_iris() print(iris.target_names) # ['setosa' 'versicolor' 'virginica']
sklearn.datasetsはdataに説明変数、targetに目的変数が入っている。これを使ってトレーニングデータと検証データを作成する。
X_iris = iris.data y_iris = iris.target from sklearn.model_selection import train_test_split X_train, X_test, y_train, y_test = train_test_split(X_iris, y_iris,random_state=1)
今回はナイーブベイズで分類する。
sklearn.naive_bayes.GaussianNB — scikit-learn 0.21.3 …
こちらを見るとわかるがハイパーパラメータがない。つまりデータをナイーブベイズ分類器に放り込めばよい。
モデルからの推定値はpredictにより取得できる。
from sklearn.naive_bayes import GaussianNB model = GaussianNB() model.fit(X_train, y_train) y_model = model.predict(X_test)
検証
検証する。結果として97%の高い正解率が算出された。
from sklearn.metrics import accuracy_score accuracy_score(y_test, y_model) # 0.9736842105263158
どこでずれが生じているかを知るために混同行列を生成する。
混同行列はconfusion_matrix()を利用する。
sklearn.metrics.confusion_matrix — scikit-learn 0.21.3 …
第一パラメータに正解ラベルを、第二パラメータにモデルから得られたラベルを与える。
from sklearn.metrics import confusion_matrix cm = confusion_matrix(y_test, y_model) print(cm) # [[13 0 0] [ 0 15 1] [ 0 0 9]]
さらにconfusion_matrixを可視化する。可視化のためにはseaborn.heatmapを利用する。
seaborn.heatmap — seaborn 0.9.0 documentation
import matplotlib.pyplot as plt import seaborn as sns %matplotlib inline plt.figure(figsize = (6,4)) sns.heatmap(cm, annot=True,xticklabels=iris.target_names, yticklabels=iris.target_names ) plt.show()