科学の箱

科学・IT・登山の話題

機械学習

irisをナイーブベイズで分類

投稿日:

ナイーブベイズの概要

ナイーブベイズは教師ありの分類アルゴリズムの一つである。

計算アルゴリズムとしてはベイズ定理を用いている。

機械学習における特徴としてはいかがある

  • 実装が簡単であり、複雑なハイパーパラメータの調整は必要ない
  • 大規模なトレーニングデータを必要としない
  • 処理が高速であり、利用するコンピュータ資源が少ない
  • 大きなデータセットについても有効である(3による)
  • 取り急ぎのベンチマークとして利用できる(1, 2, 3による)

分類問題

ナイーブベイズが得意とするのは分類問題である。分類問題とは与えられたデータをあるルールによって定義されたカテゴリに分類することである。

分類問題のもっとも簡単な例として2値分類である。これはYes-No、プラス-マイナスなど2項に分類する。

分類問題の例として最も知られているのが受け取ったメールがスパムメールであるのか普通のメールであるかの分類である。また商品に関するアンケートを取って、それが商品に対してプラスのイメージかマイナスのイメージを持つか判断するのも分類問題である。

似た概念としてクラスタリングがある。しかし機械学習では分類問題とクラスタリングは異なる。

分類問題ではモデル作成者が正解のラベルを指定する。それに対して、クラスタリングでは機械学習のプロセスを通して各データの分類ラベルを生成する。

分類問題と単回帰分析の関係

機械学習で最も簡単なアルゴリズムである単回帰分析は分類問題に使えるのだろうか。これは機械学習の初心者にとっても少々わかりにく概念である。

単回帰分析は目的変数は連続値となる。そのために得られる値は理論的には-∞ < y < ∞である。それに対して分類問題の場合もっとも単純な2値分類であれば目的変数は2値である。今簡単に0と1とする。その場合には単回帰分析の結果を0もしくは1にマッピングしなければならない。

しかしそもそも0と1には数字的な関係(大小、計算)はないので、これをマッピングすることはできない。

よって単回帰分析は分類問題に適用することはできず、別のアルゴリズムを考える必要がある。分類に関しては様々なアルゴリズムがあり、この中でもっとも簡単で分かりやすい

ナイーブベイズモデル

モデルの構築

# iris データの読み込み
from sklearn.datasets import load_iris 
iris = load_iris()
print(iris.target_names)
# ['setosa' 'versicolor' 'virginica']

sklearn.datasetsはdataに説明変数、targetに目的変数が入っている。これを使ってトレーニングデータと検証データを作成する。

X_iris = iris.data
y_iris = iris.target
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X_iris, y_iris,random_state=1)

 

今回はナイーブベイズで分類する。

sklearn.naive_bayes.GaussianNB — scikit-learn 0.21.3 …

こちらを見るとわかるがハイパーパラメータがない。つまりデータをナイーブベイズ分類器に放り込めばよい。

モデルからの推定値はpredictにより取得できる。

from sklearn.naive_bayes import GaussianNB
model = GaussianNB() 
model.fit(X_train, y_train) 
y_model = model.predict(X_test)

 

検証

検証する。結果として97%の高い正解率が算出された。

from sklearn.metrics import accuracy_score
accuracy_score(y_test, y_model)
# 0.9736842105263158

どこでずれが生じているかを知るために混同行列を生成する。

混同行列はconfusion_matrix()を利用する。

sklearn.metrics.confusion_matrix — scikit-learn 0.21.3 …

第一パラメータに正解ラベルを、第二パラメータにモデルから得られたラベルを与える。

 

from sklearn.metrics import confusion_matrix
cm = confusion_matrix(y_test, y_model)
print(cm)

#
[[13  0  0]
 [ 0 15  1]
 [ 0  0  9]]

さらにconfusion_matrixを可視化する。可視化のためにはseaborn.heatmapを利用する。

seaborn.heatmap — seaborn 0.9.0 documentation

import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

plt.figure(figsize = (6,4))
sns.heatmap(cm, annot=True,xticklabels=iris.target_names, yticklabels=iris.target_names )
plt.show()

 

メタ情報

inarticle



メタ情報

inarticle



-機械学習
-

執筆者:


comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です

関連記事

no image

データサイエンス超入門

シンプソンのパラドックス レコメンドエンジン コンテンツベースフィルタリング 強調 アイテム ユーザー ビジネスにおけるデータ分析の手順 課題→ビジネスケース→仮説→分析→用途 データ分析の手順 デー …

no image

モンベルダウンジャケットについて売れ筋商品を分析してみる。

幾何平均が値付けに利用されているという話を聞いたので実際の商品を例にして分析してみる。 調査の目的 適切なダウンジャケットを選ぶことで冬季にあるいは夏季の3000m級の高山で気持ちよく過ご巣ことができ …

no image

SIGNATE お弁当の需要予測-2

データの内容を確認する。 期間を調べる d_train[‘datetime’].min() ‘ ‘2013-11-18’ d_train[‘datetime’].max() ‘ ‘2014-9-9’ …

no image

データ取り込み後に確認すること

# tidyデータの原則 # 1カラム = 1変数 # 1行 = 1観察 # 1テーブル = 1 unique key # foreign key to link # テーブル全体で見ること # カラ …

no image

K近傍法と決定木の比較

One of the most comprehensible non-parametric methods is k-nearest-neighbors: find the points which …

2019年10月
« 9月   11月 »
 123456
78910111213
14151617181920
21222324252627
28293031  

side bar top



アーカイブ

カテゴリー