カタカタブログ

SIerで働くITエンジニアがカタカタした記録を残す技術ブログ。Java, Oracle Database, Linuxが中心です。たまに数学やデータ分析なども。

【機械学習】Rで手書き数字データMNISTをニューラルネットで学習してみた

今日はR言語を使って機械学習の入門ということで、ニューラルネットを使ってみた。
今回の目標は、Rを使って機械学習の一連のプロセスである、トレーニングデータの学習によるモデル構築と、それを使ったテストデータの評価までを一通りやってみる。なので、手書き数字データの分類という、ごくありふれた例ではあるが、精度を高める、というところは目的としない。

手書き数字データMNISTの準備

MNISTは28*28のグレースケールの手書き数字データで、28*28=784個のピクセルと、その画像が0から9のどの数字を表しているかの正解ラベル1つの合計785次元ベクトルとして提供されているデータで、機械学習による分類器の作成の検証にあたって、よく利用されているらしい。

データはkaggleという機械学習の精度を競うサイトからダウンロードできたので、これを使ってみる。
https://www.kaggle.com/c/digit-recognizer/data

ここからトレーニングデータとテストデータを入手できるが、今回は正解ラベル付きのトレーニングデータのみを利用する。
f:id:osn_th:20160411092843p:plain

RでMNIST手書き数字を可視化してみる

まずはトレーニングデータを可視化してみる。
train.csvはlabel, pixel0, pixel2, … , pixel783 というヘッダを持つcsvファイルの形式になっている。label列は0~9のどの数字であるかの正解ラベルを表し、それ以外のピクセルは28*28のそれぞれの画素に対応する濃淡を表すグレースケールの数字が0~255の間の数値として表現されている。

なので、まず以下のRコードで各行ごとに数字をヒートマップで表示してみる。

# MNISTのトレーニングデータ読み込み
train <- read.csv("MNIST/train.csv")

# MNISTのトレーニングデータを画像表示する
view_train <- function(train, range = 1:20) {
  par(mfrow=c(length(range)/4, 5))
  par(mar=c(0,0,0,0))
  for (i in range) {
    m <- matrix(data.matrix(train[i,-1]), 28, 28)
    image(m[,28:1])
  }
}

# ラベルを表示
view_label <- function(train, range = 1:20) {
  matrix(train[range,"label"], 4, 5, byrow = TRUE)
}

range <- 1:20
view_train(train, range)
view_label(train, range)

これを実行すると、以下の結果が得られる。手書き数字の形が見て取れる。
f:id:osn_th:20160411092847p:plain
この画像と対応する正解ラベルの値も同じ4行5列の行列で表示させてみる。確かに手書き画像とラベルの対応関係は一致していそう。

> view_label(train, range)
     [,1] [,2] [,3] [,4] [,5]
[1,]    1    0    1    4    0
[2,]    0    7    3    5    3
[3,]    8    9    1    3    3
[4,]    1    2    0    7    5

ニューラルネットで学習

今回はニューラルネットワークを使って、ピクセルのデータから正解ラベルを導出するための学習を行ってみる。
Rにはnnetというニューラルネットワークのライブラリがあるので、これを使ってみる。

library(nnet)

# MNISTのトレーニングデータ読み込み
train <- read.csv("MNIST/train.csv")

# 42000件のデータを30000件のトレーニングデータと12000件のテストデータにランダム・サンプリング
training.index <- sample(1:nrow(train), 30000)
mnist.train <- train[training.index,]
mnist.test <- train[-training.index,]

# トレーニングデータをニューラルネットで学習
mnist.nnet <- nnet(label ~ ., size=3, data=mnist.train)

これを実行すると、以下の箇所でエラーになった。

> mnist.nnet <- nnet(label ~ ., size=3, data=mnist.train)
 nnet.default(x, y, w, ...) でエラー: too many (2359) weights

?nnetでヘルプを見ると、MaxNWtsというパラメータで重みの最大値を指定しているが、その値を超えたことが原因のよう。

MaxNWts
The maximum allowable number of weights. 
There is no intrinsic limit in the code, but increasing MaxNWts will probably allow fits that are very slow and time-consuming.

なので、MaxNWtsを4000として再実行する。

> mnist.nnet <- nnet(label ~ ., size=3, data=mnist.train, MaxNWts = 4000)
# weights:  2359
initial  value 756877.345042 
final  value 608029.000000 
converged

今度はエラーにならずに正常に学習できたよう。

続いて、テストデータを学習したモデルを使って予測してみる。

> mnist.result <- predict(mnist.nnet, mnist.test, type="class")
 predict.nnet(mnist.nnet, mnist.test, type = "class") でエラー: 
  inappropriate fit for class

またエラーになった。。エラーメッセージはあまり親切でなくよく分からないが、クラス分類のところがまずそう。
なんとなく、元のトレーニングデータの正解ラベルを表すlabel列がintegerなのがまずい気がする。0~9はラベルとして扱いたいのに、数値として回帰分析されているような気がする。。なので、factorにしてみる。

train[,"label"] <- as.factor(train[,"label"])

これで再度学習をやりなおす。

> mnist.nnet <- nnet(label ~ ., size=3, data=mnist.train, MaxNWts = 4000)
# weights:  2395
initial  value 70427.970781 
iter  10 value 56024.007059
iter  20 value 53612.550630
iter  30 value 53358.071108
iter  40 value 53198.640707
iter  50 value 52431.262209
iter  60 value 52263.579201
iter  70 value 51939.286749
iter  80 value 51693.656379
iter  90 value 51588.731655
iter 100 value 51332.028230
final  value 51332.028230 
stopped after 100 iterations

さっきまでと明らかに動きが違う!学習にも結構時間がかかるようになった。
これで再度テストデータを予測してみる。

> mnist.result <- predict(mnist.nnet, mnist.test, type="class")

エラーがでなくなった。成功したっぽい!

最後にテスト結果を出力する。

> table(mnist.result)
mnist.result
   0    1    4    7 
 996 1419 6762 2823

> table(mnist.test$label, mnist.result, dnn = c("Actual", "Predicted"))
      Predicted
Actual    0    1    4    7
     0  891   27   36  186
     1    0 1144  183    6
     2    4   28 1079   44
     3    8   53 1012  139
     4    1    7 1125   50
     5   60   31  438  584
     6   12   15 1152   49
     7    8   12  145 1093
     8    6   96  939  121
     9    6    6  653  551

Actualが実際の正解ラベルの値で、Predictedが今回の学習モデルが推測したラベルである。なんか0,1,4,7としか予測していない。。

もう一度学習

トレーニングデータが少なかったせいか?と思い、今度は42,000件のデータのうち40,000件を使って学習させた後、2,000件のデータでテストすることに。

最終的に以下のRコードを実行。

library(nnet)

# MNISTのトレーニングデータ読み込み
train <- read.csv("MNIST/train.csv")
train[,"label"] <- as.factor(train[,"label"])

# 42000件のデータを40000件のトレーニングデータと2000件のテストデータに分割
training.index <- 1:40000
mnist.train <- train[training.index,]
mnist.test <- train[-training.index,]

# トレーニングデータをニューラルネットで学習
mnist.nnet <- nnet(label ~ ., size=3, data=mnist.train, MaxNWts=4000)

# テストデータを使って評価
mnist.result <- predict(mnist.nnet, mnist.test, type="class")
table(mnist.test$label, mnist.result, dnn = c("Actual", "Predicted"))

# テストデータの正解、予測を表示
range <- 1:20
view_train(mnist.test, range)
view_label(mnist.test, range)
matrix(mnist.result, 4, 5, byrow = TRUE)

結果は以下。さっきよりはマシだが、やはり5, 8, 9の数値が予測結果に一つも出てきていない。

> table(mnist.test$label, mnist.result, dnn = c("Actual", "Predicted"))
      Predicted
Actual   0   1   2   3   4   6   7
     0 180   0   0   1   0  16   0
     1   2 217   1   0   1   0   4
     2   9   8  88  14   1  69   1
     3  99   3  12  50  16   8  10
     4  17   6   0   0 190  10   3
     5 124   1   9   8  11   9   0
     6  16   3   3   0   0 193   1
     7   0   7   2   4  11   0 178
     8 106  12   0   3  43  10   5
     9   9   3   0   1 166   4  22

最後に正答率を算出する。

# 正答率算出
accuracy <- function(actual, predicted) {
  ret = data.frame(actual = actual, predicted = predicted)
  ok = 0
  for (i in 0:9) {
    ok = ok + nrow(ret[ret$actual == i & ret$predicted == i, ])
  }
  ok / length(actual)
}
accuracy(actual = mnist.test$label, predicted = mnist.result)

この結果は以下。

> accuracy(actual = mnist.test$label, predicted = mnist.result)
[1] 0.548

54.8%。。結構低い。。が、最初にも言ったとおり、今回の目的は精度の高い学習モデルを作ることではないので、いったんここまでにする。

まとめ

Rを使ってMNIST手書き数字データをニューラルネットで機械学習し、テストデータを評価して正答率を出す一連の流れをやってみた。ニューラルネットのパラメータはほとんどチューニングせず適当にやっていることもあって、精度は今ひとつだったが、Rを使って学習モデルを作る流れは他の学習アルゴリズムを使ったりパラメータチューニングしても変わらないので、あとはデータの特性や学習結果、アルゴリズムの中身を知りながら精度を高める作業を進めていくのだろうが、今日はここまでとする。

以上!

参考図書

データサイエンティスト養成読本 R活用編 【ビジネスデータ分析の現場で役立つ知識が満載! 】 (Software Design plus)

データサイエンティスト養成読本 R活用編 【ビジネスデータ分析の現場で役立つ知識が満載! 】 (Software Design plus)

  • 作者: 酒巻隆治,里洋平,市川太祐,福島真太朗,安部晃生,和田計也,久本空海,西薗良太
  • 出版社/メーカー: 技術評論社
  • 発売日: 2014/12/12
  • メディア: 大型本
  • この商品を含むブログ (7件) を見る
WEB+DB PRESS Vol.89

WEB+DB PRESS Vol.89

  • 作者: 佐藤歩,泉水翔吾,村田賢太,門田芳典,多賀千夏,奥一穂,伊藤直也,鍛治匠一,中山裕司,高山温,佐藤太一,西尾泰和,中島聡,はまちや2,竹原,青木大祐,WEB+DB PRESS編集部
  • 出版社/メーカー: 技術評論社
  • 発売日: 2015/10/24
  • メディア: 大型本
  • この商品を含むブログを見る

関連記事