どうもフジワラです。
今日は、pytorchで
Error(s) in loading state_dict for DataParallel:
Missing key(s) in state_dict
Missing key(s) in state_dict
とエラーを吐くときの解決法を載せます。
前提
まず、前提条件として、t7ファイルやpthファイルで保存するときに、
net.state_dict()
の形で保存していることが前提です。
この形にせずに、netのまま保存すると、使用gpu番号とかも一緒に保存されてしまうので、
ロードするときにクッソ使いづらくなります。
エラーの解決法。
保存するときに、
「module.」が先頭について保存されている場合は、
先頭のmodule.を消してしまえばいいのです。
from collections import OrderedDict checkpoint=torch.load("ロードしたいファイル") state_dict=checkpoint new_state_dict=OrderedDict() for k, v in state_dict.items(): name=k[7:]; #これで先頭のmodule.を消す new_state_dict[name]=v net.load_state_dict(new_state_dict) #これで新たにつくったorderdictを読み込み
こういう感じです。
参考リンクは
https://discuss.pytorch.org/t/solved-keyerror-unexpected-key-module-encoder-embedding-weight-in-state-dict/1686/4
コメント