スポンサーリンク

[pytorch]Error(s) in loading state_dict for DataParallel: Missing key(s) in state_dictと出るときの解決法

pytorch
OpenClipart-Vectors / Pixabay
スポンサーリンク

どうもフジワラです。

今日は、pytorchで

Error(s) in loading state_dict for DataParallel:
Missing key(s) in state_dict
とエラーを吐くときの解決法を載せます。
スポンサーリンク

前提

まず、前提条件として、t7ファイルやpthファイルで保存するときに、

net.state_dict()
の形で保存していることが前提です。
この形にせずに、netのまま保存すると、使用gpu番号とかも一緒に保存されてしまうので、
ロードするときにクッソ使いづらくなります。

エラーの解決法。

保存するときに、

「module.」が先頭について保存されている場合は、

先頭のmodule.を消してしまえばいいのです。


from collections import OrderedDict

checkpoint=torch.load("ロードしたいファイル")

state_dict=checkpoint
new_state_dict=OrderedDict()

for k, v in state_dict.items():
  name=k[7:]; #これで先頭のmodule.を消す

  new_state_dict[name]=v

net.load_state_dict(new_state_dict) #これで新たにつくったorderdictを読み込み


こういう感じです。
参考リンクは
https://discuss.pytorch.org/t/solved-keyerror-unexpected-key-module-encoder-embedding-weight-in-state-dict/1686/4

コメント

タイトルとURLをコピーしました