コンパイル済みモデルの読み込みで `Missing key(s) in state_dict` が発生したら
2023-03-03
前書き
torch.compile()
を使用することで、深層学習モデルの訓練速度をおよそ1.2倍にできる。1
コンパイル済みのモデルの重みを読み込むには一手間必要なことが分かったので備忘録的に対処方法を記しておく。
なお2023年2月末時点で torch.compile()
は nightly build でのみ利用できる。2
正式リリース前に利用したい場合は、以下のコードでインストールできる。pip install numpy --pre torch --force-reinstall --extra-index-url <https://download.pytorch.org/whl/nightly/cu117>
エラー内容と再現手順
- コンパイル済みのモデルの重みを保存する。
import torch from transformers import AutoModel model = AutoModel.from_pretrained('microsoft/deberta-v3-base') path = 'ckpt.pth' torch.save( torch.compile(model).state_dict(), path )
- 保存した重みをモデルに読み込む。
ckpt = torch.load(path) model.load_state_dict(ckpt)
- エラー発生
--------------------------------------------------------------------------- RuntimeError Traceback (most recent call last) <ipython-input-9-4866a322e7ab> in <module> 1 ckpt = torch.load(path) ----> 2 model.load_state_dict(ckpt) /usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict) 2039 2040 if len(error_msgs) > 0: -> 2041 raise RuntimeError('Error(s) in loading state_dict for {}:\\n\\t{}'.format( 2042 self.__class__.__name__, "\\n\\t".join(error_msgs))) 2043 return _IncompatibleKeys(missing_keys, unexpected_keys) RuntimeError: Error(s) in loading state_dict for DebertaV2Model: Missing key(s) in state_dict: "embeddings.position_ids", "embeddings.word_embeddings.weight", "embeddings.LayerNorm.weight",
対処方法
結論だけ知りたい方用に、対処方法を先に書いておく。
ckpt = torch.load(path) restored_ckpt = {} for k,v in ckpt.items(): restored_ckpt[k.replace('_orig_mod.', '')] = v model.load_state_dict(restored_ckpt)
または、
ckpt = torch.load(path) model = torch.compile(model) model.load_state_dict(restored_ckpt)
↓
<All keys matched successfully>
エラーの原因
モデルをコンパイルするとルートになっている
torch.nn.Module
クラス3の直下に_orig_mod
というメンバ変数が追加され、そこにモデルの構造全てがコピーされるらしい。(そして学習が行われるのは_orig_mod
配下のパラメータになるぽい?)import torch from transformers import AutoModel model = AutoModel.from_pretrained('microsoft/deberta-v3-base') compiled_model = torch.compile(model)
compiled_model._orig_mod.embeddings
DebertaV2Embeddings( (word_embeddings): Embedding(128100, 768, padding_idx=0) (LayerNorm): LayerNorm((768,), eps=1e-07, elementwise_affine=True) (dropout): StableDropout() )
DebertaV2Embeddings
の構造が丸々 _orig_mod
の直下に存在することが分かる。compiled_model.embeddings
DebertaV2Embeddings( (word_embeddings): Embedding(128100, 768, padding_idx=0) (LayerNorm): LayerNorm((768,), eps=1e-07, elementwise_affine=True) (dropout): StableDropout() )
元の構造が消えているわけではなく、そのまま各メンバ変数にはアクセスできる。
model._orig_mod.embeddings
--------------------------------------------------------------------------- AttributeError Traceback (most recent call last) <ipython-input-10-03ce1f919c16> in <module> ----> 1 model._orig_mod.embeddings /usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in __getattr__(self, name) 1612 if name in modules: 1613 return modules[name] -> 1614 raise AttributeError("'{}' object has no attribute '{}'".format( 1615 type(self).__name__, name)) 1616 AttributeError: 'DebertaV2Model' object has no attribute '_orig_mod'
勿論コンパイルしていないモデルには、
_orig_mod
のメンバ変数が存在しないためAttributeErrorになる。
そしてコンパイルしたモデルのstate_dict()を保存しようとすると、_orig_mod.
自身を含む配下の名前付きパラメータが保存される。コンパイルしていない
deberta-v3-base
import torch path = 'ckpt.pth' torch.save( model.state_dict(), path ) ckpt = torch.load(path) [k for k in ckpt.keys() ]
['embeddings.position_ids', 'embeddings.word_embeddings.weight', 'embeddings.LayerNorm.weight', . . . 'encoder.rel_embeddings.weight', 'encoder.LayerNorm.weight', 'encoder.LayerNorm.bias']
embeddings
やencoder
がルートになった名前付きパラメータが含まれてることが分かる。コンパイルした
deberta-v3-base
torch.save( compiled_model.state_dict(), path ) ckpt = torch.load(path) [k for k in ckpt.keys() ]
['_orig_mod.embeddings.position_ids', '_orig_mod.embeddings.word_embeddings.weight', '_orig_mod.embeddings.LayerNorm.weight', . . . '_orig_mod.encoder.rel_embeddings.weight', '_orig_mod.encoder.LayerNorm.weight', '_orig_mod.encoder.LayerNorm.bias']
_orig_mod.
配下のパラメータが保存されていることがわかる。
したがって、取り得る対処方法は大きく2つ。- 読み込むパラメータの名前をよしなに書き換えてモデルに読み込む。
- 読み込み時のモデルの構造と保存時のモデルの構造を同じにする。
前者の場合には、読み込む名前付きパラメータの情報であるkeyから
_orig_mod.
を消せば良いだけで、後者の場合には重みを読み込む前にモデル自体をコンパイルしておけば良いだけ。(詳細は対処方法を参照)