【PyTorch】サンプル⑦ 〜 optim パッケージ 〜


1. 目的
2. 前準備
3. optim パッケージ
4. PyTorchのインポート
5. 使用するデータ
6. ニューラルネットワークのモデルを定義
7. 損失(loss)の定義
8. 学習パラメータ
9. 最適化関数(optimizer)の定義
10. 学習回数
11. モデルへのデータ入出力 (Forward pass)
12. 損失(loss)の計算
13. 勾配の初期化
14. 勾配の計算
15. パラメータ(Weight)の更新
16. 実行
16.1. 7_optim_package.py


1. 目的

  • PyTorch: optimを参考にPyTorchのoptimパッケージを使って最適化関数(optimizer)を定義する。

2. 前準備

PyTorchのインストールはこちらから。

初めて、Google Colaboratoryを使いたい方は、こちらをご覧ください。

3. optim パッケージ

optimは、最適化アルゴリズムの定義に用います。

これまでは、モデルのパラメータ(weight)を更新するために、自分の手で確率勾配降下法(SGD: stochastic gradient descent)のコードを作っていました。

今回のチュートリアルでは、optimパッケージを使ってパラメータを更新する最適化関数(optimizer)を定義していきます。

optimパッケージには、SGD+momentum、RMSProp, Adamなどディープラーニング界でよく使われる最適化アルゴリズムが複数あります。

4. PyTorchのインポート

import torch

5. 使用するデータ

バッチサイズNを64、入力の次元D_inを1000、隠れ層の次元Hを100、出力の次元D_outを10とします。

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

入力(x)と予測したい(y)を乱数で定義します。

# Create random input and output data
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

6. ニューラルネットワークのモデルを定義

前回のnnパッケージのチュートリアルと同じように、ニューラルネットワークのモデルをnnパッケージを用いて定義します。

input > Linear(線型結合) > ReLU(活性化関数) > Linear(線型結合) > outputの順に層を積み重ねます。

model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)

7. 損失(loss)の定義

二乗誤差をnnパッケージを用いて計算します。

reductionのデフォルトはmeanですので、何も指定しなければtorch.nn.MSELossは平均二乗誤差を返します。

reduction=sumとした場合は、累積二乗誤差を算出します。

loss_fn = torch.nn.MSELoss(reduction='sum')

8. 学習パラメータ

学習率learning_rate1e-4とします。

learning_rate = 1e-4

9. 最適化関数(optimizer)の定義

最適化関数の定義は、torch.optimを使って簡単にできます。

このチュートリアルでは、最適化関数としてAdamを選択しています。

torch.optim.Adamの引数は、モデルのパラメータmodel.parameters()と学習率lr=learning_rateです。

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

10. 学習回数

学習回数を500回とします。

for t in range(500):

11. モデルへのデータ入出力 (Forward pass)

定義したニューラルネットワークモデルmodelへデータxを入力し、予測値y_predを取得します。

    y_pred = model(x)

12. 損失(loss)の計算

定義した損失関数で予測値y_predと真値yとの間の損失を計算します。

    loss = loss_fn(y_pred, y)
    if t % 100 == 99:
        print(t, loss.item())

13. 勾配の初期化

逆伝播(backward)させる前に、モデルのパラメータが持つ勾配を0(ゼロ)で初期化します。

    optimizer.zero_grad()

(参考) optimパッケージを使わない場合、以下のように記述していました。

    model.zero_grad()

14. 勾配の計算

backwardメソッドでモデルパラメータ(Weight)の勾配を算出します。

    loss.backward()

15. パラメータ(Weight)の更新

stepメソッドでモデルのパラメータを更新します。

    optimizer.step()

(参考) optimパッケージを使わない場合、以下のように記述していました。

    with torch.no_grad():
        for param in model.parameters():
            param -= learning_rate param.grad

16. 実行

以下のコードを7_optim_package.pyとして保存します。

16.1. 7_optim_package.py

import torch

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Use the nn package to define our model and loss function.
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)
loss_fn = torch.nn.MSELoss(reduction='sum')

# Use the optim package to define an Optimizer that will update the weights of
# the model for us.
learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for t in range(500):
    # Forward pass: compute predicted y by passing x to the model.
    y_pred = model(x)

    # Compute and print loss.
    loss = loss_fn(y_pred, y)
    if t % 100 == 99:
        print(t, loss.item())

    # Before the backward pass, use the optimizer object to zero all of the
    # gradients for the variables it will update (which are the learnable
    # weights of the model).
    optimizer.zero_grad()

    # Backward pass: compute gradient of the loss with respect to model
    # parameters
    loss.backward()

    # Calling the step function on an Optimizer makes an update to its
    # parameters
    optimizer.step()

保存ができたら実行しましょう。

左の数字が学習回数、右の数値がパーセプトロンの推定値と実際の答えと二乗誤差です。

学習を重ねるごとに、二乗誤差が小さくなることがわかります。

In:

python3 7_optim_package.py 

Out:

99 54.53140640258789
199 0.6615686416625977
299 0.004912459757179022
399 3.133853169856593e-05
499 1.0647939063801459e-07

コメントを残す

このサイトはスパムを低減するために Akismet を使っています。コメントデータの処理方法の詳細はこちらをご覧ください