先日書いた、TensorFrow使ってみました的なエントリには、
実行したスクリプトの説明が全くなかったので、
改めて書いてみることにしました。
使ったチュートリアルはこちらです。
https://www.tensorflow.org/versions/r0.10/tutorials/mnist/beginners/index.html
利用したスクリプトはこちらです。
from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) import tensorflow as tf x = tf.placeholder(tf.float32, [None, 784]) W = tf.Variable(tf.zeros([784, 10])) b = tf.Variable(tf.zeros([10])) y = tf.nn.softmax(tf.matmul(x, W) + b) y_ = tf.placeholder(tf.float32, [None, 10]) cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1])) train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) init = tf.initialize_all_variables() sess = tf.Session() sess.run(init) for i in range(1000): batch_xs, batch_ys = mnist.train.next_batch(100) sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys}) correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))
キーワードとしては、
・ニューラルネットワーク(畳み込みではない) ・尤度関数:ソフトマックス関数 ・誤差関数:クロスエントロピー関数 ・MNISTのテストデータを利用 ・勾配降下法
というところです。
今回は1行ごとに説明を付けてみました。
それではいってみましょー。
○ 準備(モジュールインポートなど)
from tensorflow.examples.tutorials.mnist import input_data
まず表面的に読めば、
tensorflow.examples.tutorials.mnist
というライブラリから、
input_data
というモジュールを読み込んでくる、
ということになります。
で、input_dataモジュールってなんなの?
というと、
今回の目的に限って言えば、
read_data_sets関数(※)の使用が目的です。
※MNISTのデータをWEBからダウンロードしてロードするための関数。
input_dataモジュールの実体は、
/usr/local/lib/python2.7/dist-packages/tensorflow/examples/tutorials/mnist/input_data.py
です。
この中身を見てみると、
(略) """Functions for downloading and reading MNIST data.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import gzip import os import tempfile import numpy from six.moves import urllib from six.moves import xrange # pylint: disable=redefined-builtin import tensorflow as tf from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
となっています。
まあ要はこれから使うであろうモジュール・ライブラリをインポートするための
ライブラリです。
大まかに、互換性維持のためのものと、
MNISTのテストデータセットをインポートしてくるためのものがあります。
肝心なのは、最後のread_data_setsで、
これはMNISTデータをロードしてくるモジュールです。
実体は、
/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py
の中で関数として定義されています。
全部書くときりがないので概略だけ書くと、
http://yann.lecun.com/exdb/mnist/
からテストデータをダウンロードしてきて、
スクリプトを実行したカレントディレクトリにMNIST_dataディレクトリとして配備して、
学習用データ、学習用データのラベル、テスト用データ、テスト用データのラベルを
ロードしてきます。
引数は
def read_data_sets(train_dir, fake_data=False, one_hot=False, dtype=dtypes.float32, reshape=True):
であり、
リターンは
return base.Datasets(train=train, validation=validation, test=test)
です。
※リターンの各変数の内訳は、
train = DataSet(train_images, train_labels, dtype=dtype, reshape=reshape) validation = DataSet(validation_images, validation_labels, dtype=dtype, reshape=reshape) test = DataSet(test_images, test_labels, dtype=dtype, reshape=reshape)
です。
長くなりましたが、
そんなわけでこの行は、
read_data_sets関数を使うためにinput_dataモジュールをロードしてくる、
というための行でした。
mnist = input_data.read_data_sets(“MNIST_data/”, one_hot=True)
1行目が分かっていれば、ここはすぐでしょう。
input_dataモジュールのread_data_setsを呼んで、mnist変数に代入するわけです。
テストデータのダウンロード先は、カレントディレクトリのMNIST_data、
one_hot、すなわち、
教師信号はベクトル中のただ1つの要素のみが1で、それ以外が0である、
という条件をTrueにしています。
import tensorflow as tf
tensorflowライブラリをtfの別名でインポートしています。
実体は
/usr/local/lib/python2.7/dist-packages/tensorflow
です。
x = tf.placeholder(tf.float32, [None, 784])
ここは学習対象データの型と次元数を定義していて、
placeholder関数がまさにそのための関数です。
リファレンスはこちら。
https://www.tensorflow.org/versions/r0.10/api_docs/python/io_ops.html#placeholder
この例では単精度浮動小数点のデータであり、
784次元である、と定義しています。
784、という数字は、
MNISTの画像データが縦28 × 横28 = 784画素である点に由来しています。
○ 領域定義
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
ここは、結合重みWと、バイアスbの領域を定義しています。
結合重みは784個のインプット(画素)に対して、10通りの出力(0~9)があるから、
784×10、となっています。
言い方を変えれば、入力層が784次元、出力層が10次元、ということです。
バイアスは、学習データxと結合重みWの内積に加えるものなので、10次元でOKです。
y = tf.nn.softmax(tf.matmul(x, W) + b)
softmax関数により計算した、出力を定義しています。
softmax関数のリファレンスはこちら。
https://www.tensorflow.org/versions/r0.10/api_docs/python/nn.html#softmax
matmulはベクトルの内積を計算する関数です。
y_ = tf.placeholder(tf.float32, [None, 10])
これは教師信号用の領域です。
0~9の10通りの値を持ちうるので、10次元です。
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
誤差関数を定義しています。
ここではクロスエントロピー関数を使うため、
変数名はcross_entropyとなっています。
使用している関数reduce_sum、reduce_meanのリファレンスはこちら。
https://www.tensorflow.org/versions/r0.10/api_docs/python/math_ops.html#reduce_sum
https://www.tensorflow.org/versions/r0.10/api_docs/python/math_ops.html#reduce_mean
○ トレーニング
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
トレーニングの核となる処理を定義します。
今回のチュートリアルでは、勾配降下法により学習を行います。
GradientDescentOptimizerはそのための関数です。
リファレンスはこちら
https://www.tensorflow.org/versions/r0.10/api_docs/python/train.html#GradientDescentOptimizer
学習レート(結合重みの更新量)を0.5とし、
先ほど定義したクロスエントロピー(誤差関数)を最小化するように学習していきます。
init = tf.initialize_all_variables()
全ての変数を初期化しています。
まあおまじないみたいなもんでしょう。
sess = tf.Session()
セッションの定義です。
Sessionクラスの定義はこちら。
https://www.tensorflow.org/versions/r0.10/api_docs/python/client.html#Session
sess.run(init)
セッションの開始です。
for i in range(1000):
今回は学習を1000回繰り返すので、
for文で1000回繰り返すよう指定しています。
batch_xs, batch_ys = mnist.train.next_batch(100)
batch_xsとbatch_ysに、next_batch関数で取得した100個のミニバッチを
入力します。
next_batch関数は、
mnistモジュールで定義されており、mnistモジュールの実体は、
/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py
にあります。
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
先ほど定義したtrain_stepにしたがって、
前行で定義したミニバッチを使って学習します。
○ 性能評価
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
正しく予想できたケースをブーリアンのテンソルで返します。
使用しているequal関数のリファレンスはこちら。
https://www.tensorflow.org/versions/r0.10/api_docs/python/control_flow_ops.html#equal
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
正答率の計算です。
前行のcorrect_predictionはまだテンソルなので、
cast関数で単精度浮動小数点に変換します。
cast関数のリファレンスはこちら。
https://www.tensorflow.org/versions/r0.10/api_docs/python/array_ops.html#cast
元がブーリアン型と分かっているので、reduce_mean関数で平均を出せば、
それが正答率になる、ということですね。
print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))
計算した正答率を出力します。
こんな感じですね。
○ まとめ
ぶっちゃけ、pythonさわったのも、機械学習のライブラリさわったのも、
NN実装したのも、ubuntu使ったのも、ぜーんぶ初めてだったので、
完全に手探りです。
ざーっとなら、理論的な知識がなくても書けるだろうな、
と思いますが、一歩踏み込んで理解しようとすると、
やはりNNの考え方は知らないと無理ですし、
尤度関数、誤差関数など、
なんのためにあるのかが説明できないでしょう。
今後はCNN(畳み込みニューラルネットワーク)も手を付けていきたいなあ、
と思っていますが、
ただpythonでコード書いてみました、
だけではなく、しっかり理論的な裏付けもとりながら
やっていきたいですね。
では。