科学の箱

科学・IT・登山の話題

機械学習

手書き数字データについて次元縮約および教師ありモデルの構築

投稿日:

前回手書き数字データについてイメージで確認した結果、人の目で確認する分には区別ができる。では機械学習ではどのように実施していくのか。

今回は以下の内容について説明する。

  • 多様体学習による次元縮約
  • ナイーブベイズによる分類

多様体学習による次元縮約

さてこのイメージデータは64次元であり、取り扱うには不便である。そこで次元縮約により2次元に落とす。次元縮約といえばPCAであるが、残念ながらPCAの前提はデータが線形であること。イメージのデータ、つまりバイナリであることから、線形を仮定するには無理がある。非線形の次元縮約に利用できるのが多様体の学習(Manifold Learning)である。

2.2. Manifold learning — scikit-learn 0.21.3 documentation

多様体学習のアルゴリズムとしてはいくつかあるが今回はIsomapを利用する。

 

from sklearn.datasets import load_digits
digits = load_digits()

from sklearn.manifold import Isomap
iso = Isomap(n_components=2)
# 次元縮約 n_components=2を指定にする。
iso.fit(digits.data)
# 次元縮約
data_projected = iso.transform(digits.data)
# 縮約されたデータ
data_projected.shape
# (1797,2)

Isomapの結果、各イメージデータは2次元に縮約されていることがわかる。

次のこのデータを2次元でプロットし、データが区別されているかを確認する。Isomapの結果から取得した2次元データをそれぞれX, Y軸でプロットする。数字データにより色分けをする。

import matplotlib.pyplot as plt
import matplotlib.cm as cm
%matplotlib inline
plt.style.use('seaborn-darkgrid')
plt.scatter(data_projected[:, 0], data_projected[:, 1], c=digits.target,alpha=1, cmap = cm.Accent)
plt.colorbar(ticks=range(10))

 

数字の分類

手書き数字を分類する。正解データが存在しているので教師ありモデルのナイーブベイズを利用する。

データを訓練用および検証用に分割する。

from sklearn.model_selection import train_test_split
X = digits.data
y = digits.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

 

次に、ナイーブベイズモデルを構築する

from sklearn.naive_bayes import GaussianNB
model = GaussianNB()
model.fit(X_train, y_train)
y_model = model.predict(X_test)

 

最後にモデルを検証する。

from sklearn.metrics import accuracy_score
accuracy_score(y_test, y_model)
# 0.8333333333333334

結果として最も単純な分類アルゴリズムでありナイーブベイズであっても83%という高い制度であることが分かった。

メタ情報

inarticle



メタ情報

inarticle



-機械学習
-

執筆者:


comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です

関連記事

no image

logistic regressionでの失敗

ロジスティック回帰でデータを分析しようとしたところうまくいかず。どうにもこうにもおかしな値が出るし、他の回帰分析との結果と明らか矛盾している。よくわからないのでとりあえずirisデータを使って手順を追 …

no image

Core Concept in Data Analysis – Week 1

パート Data Mining Core Analysis Visualization Illustrate Data Mining data mining = patterns in data + …

no image

数値項目の分析テンプレート

数値フィールド1 数値フィールド1 rate – 1 rate – 2 データ型 算術平均 中央値 分散 トップ3 ボトム3 足切 時系列分析 層別候補 ヒストグラム カウント …

no image

pythonでEDAを実施する – 記述統計

データを取り込む data frameに変換する desdribe()メソッドで要約統計量を出力 各項目について残差分析(ここでは各データが平均値からどの程度離れているか、要するに分散の傾向を把握する …

no image

Mahout環境構築

Mahoutのシステム要件を確認する。 Java 1.6.x or greater. Maven 3.x to build the source code. CPU, Disk and Memory …

2019年10月
« 9月   11月 »
 123456
78910111213
14151617181920
21222324252627
28293031  

side bar top



アーカイブ

カテゴリー