PyTorchとMNIST手書き数字の、学習と分類

はてなブックマーク - PyTorchとMNIST手書き数字の、学習と分類
LINEで送る
Pocket

前回の記事で、scikit-learnの手書き数字の学習の内容を紹介しましt。

今日の記事は、PyTorch+MNISTの手書き数字データセットを使って学習とその後の分類(推論)を紹介します。

PyTorchとは

PyTorch(http://pytorch.org/)はFacebookの人工知能研究グループにより開発されたPython向けのオープンソース機械学習のライブラリです。

PyTorchは元々「Torch(トーチ)」「Lua言語」書かれていましたが、PyTorchはそれのPython版です。2018年10月前半にリリースされ、今安定しているのはPyTorch 0.4.1です。

PyTorchは2016年後半に発表された比較的新しいライブラリです。他の機械学習のフレームワークと比べると、後発ではありますが、最近人気が増え続けています。注目すべき機械学習フレームワークの一つだと言えるでしょう。

MNISTとは

MNIST(http://yann.lecun.com/exdb/mnist/)とは、「Mixed National Institute of Standards and Technology database」の略で、手書きの数字「0~9」とその正解ラベルがセットになっているデータセットです。手書き数字の画像データセットが70,000個あります。

データセットにある手書き文字は28ピクセル×28ピクセルの画像になっており、784次元のデータとなります。この784次元のデータを使って0~9を分類します。

scikit-learnのレ記事でも手書き数字の認識を機械学習のアプローチで体験してみましたが、今回は、ニューラルネットワークの手法を用いて、学習させていきます。

機械学習入門チュートリアルscikit-learnで手書き数字学習と認識

また、scikit-learnの手書き数字の一つの画像は8×8の画像で、全部で64ピクセルということに対して、MINISTのデータは、28×28の画像で、全部で784ピクセルがあるということで、MNISTデータの「粒度」はscikit-learnのデータと比べ、より細かく、ディテールが多いです。

MINISTの手書き数字データセットは様々な機械学習フレームワークから簡単に利用することができます。

例えば、Keras、PyTorch、Chainerも一行でデータセットのロードすることができます。

今回の記事の内容はGoogle Colabで実行できます。

PyTorchのインストール

下記のようにGoogle Colabが提供している標準な方法を使ってインストールします。

実行結果:0.4.1

必要なパケージを導入する

まず、torchvisionが必要です。torchvisionとは,PyTorchのコンピュータービジョン用のパッケージで,データセットのロードや画像の前処理の関数などが入っています。

今回使うMNISTのデータセット以外に、16種類のデータセットが予め入っています。こちらのデータを使って、練習するのもいいですね。

今回は、torchvision.datasetsを使って、MNISTのデータを導入します。

データセットのダウンロード

データの中身を見てみる

1つ目のデータを可視化してみる

実行結果:

ラベル: tensor(5)

 

学習データと検証データを用意する

学習のデータか検証のデータかは、trainという引数で決めています。

train=Trueの場合は、学習データとしてロードします。 train=Falseの場合は、検証データとしてロードします。

PyTorchでは、こんなやり方が分かりやすいですね!

ニューラルネットワークを定義

ニューラルネットワーク(MLP)は torch.nn パッケージを使用して構築できます。

親クラスはPyTorchでは、nn.Moduleとなります。

そして、マルチレイヤーパーセプトロン (ニューラルネットワーク)の定義をしていきます。

ネットワークをFeed Forwardネットワークにします。

モデル

コスト関数と最適化手法を定義

PyTorchでは代表的なコスト関数や最適化手法はあらかじめ提供されています。

コスト関数にクロスエントロピー、最適化手法にSGDをします。

optimizer.SGDに引数として最適化対象のパラメータ一覧を渡しています。

PyTorchでは

optimizer = optimizer.SGD(model.parameters(), lr=0.01)

という形で、モデルに最適手法を適用していますね。

学習

早速学習させてみましょう!まず、学習回数を4回にしてみましょう

学習ループ内では次のような作業を順次実行していきます:

実行結果

検証データを使った検証

実行結果:

正解率:9204 / 10000 = 0.920400

検証データを決める

実行結果:

ラベル: tensor(6)

まとめ

いかがですか?この記事の内容でPyTorchの使い方のイメージを摘むことができましたでしょうか?

ニューラルネットワークの定義の仕方がとても分かりやすいかと思います。

では、また別の機会でPyTorchの使い方を紹介していきたいと思います。

はてなブックマーク - PyTorchとMNIST手書き数字の、学習と分類
LINEで送る
Pocket

Add a Comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です

By continuing to use the site, you agree to the use of cookies. more information

The cookie settings on this website are set to "allow cookies" to give you the best browsing experience possible. If you continue to use this website without changing your cookie settings or you click "Accept" below then you are consenting to this.

Close