スポンサーリンク

[pytorch]重みの保存方法

pytorch
OpenClipart-Vectors / Pixabay
スポンサーリンク

どうもフジワラです。

今日は、pytorchでニューラルネットで学習した重みを保存する方法について書きます。

スポンサーリンク

前提

netが学習したネットワークモデルです。

t7ファイルで保存する方法

t7ファイルで保存する方法は、


import torch

state = {
'net': net.state_dict(), #保存するモデルの重み
'acc': acc, #精度
'epoch': epoch, #エポック数
}

torch.save(state, '保存ファイル.t7')

とまあ、こういう感じで保存します。

t7ファイルだと、重みだけでなく、精度、エポック数やいろいろ一緒に保存できるところが便利です。

いわばjsonや辞書型です。

 

pthファイルで保存する方法

pthファイルで保存する方法は、


import torch

torch.save(net.state_dict(), "保存ファイル.pth")

で保存します。

 

まとめ

後ろの.state_dict()をつけないと、gpu番号まで保存されてしまうので、違う環境で実行することがかなりむずくなったりするので、必ず.state_dict()をつけましょう。

 

 

 

コメント

タイトルとURLをコピーしました