どうもフジワラです。
今日は、pytorchでニューラルネットで学習した重みを保存する方法について書きます。
前提
netが学習したネットワークモデルです。
t7ファイルで保存する方法
t7ファイルで保存する方法は、
import torch state = { 'net': net.state_dict(), #保存するモデルの重み 'acc': acc, #精度 'epoch': epoch, #エポック数 } torch.save(state, '保存ファイル.t7')
とまあ、こういう感じで保存します。
t7ファイルだと、重みだけでなく、精度、エポック数やいろいろ一緒に保存できるところが便利です。
いわばjsonや辞書型です。
pthファイルで保存する方法
pthファイルで保存する方法は、
import torch torch.save(net.state_dict(), "保存ファイル.pth")
で保存します。
まとめ
後ろの.state_dict()をつけないと、gpu番号まで保存されてしまうので、違う環境で実行することがかなりむずくなったりするので、必ず.state_dict()をつけましょう。