crispy.data

コンパイル済みモデルの読み込みで `Missing key(s) in state_dict` が発生したら

2023-03-03

前書き

torch.compile()を使用することで、深層学習モデルの訓練速度をおよそ1.2倍にできる。1 コンパイル済みのモデルの重みを読み込むには一手間必要なことが分かったので備忘録的に対処方法を記しておく。 なお2023年2月末時点で torch.compile()は nightly build でのみ利用できる。2 正式リリース前に利用したい場合は、以下のコードでインストールできる。
pip install numpy --pre torch --force-reinstall --extra-index-url <https://download.pytorch.org/whl/nightly/cu117>

エラー内容と再現手順

  1. コンパイル済みのモデルの重みを保存する。
import torch from transformers import AutoModel model = AutoModel.from_pretrained('microsoft/deberta-v3-base') path = 'ckpt.pth' torch.save( torch.compile(model).state_dict(), path )
  1. 保存した重みをモデルに読み込む。
ckpt = torch.load(path) model.load_state_dict(ckpt)
  1. エラー発生
--------------------------------------------------------------------------- RuntimeError Traceback (most recent call last) <ipython-input-9-4866a322e7ab> in <module> 1 ckpt = torch.load(path) ----> 2 model.load_state_dict(ckpt) /usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict) 2039 2040 if len(error_msgs) > 0: -> 2041 raise RuntimeError('Error(s) in loading state_dict for {}:\\n\\t{}'.format( 2042 self.__class__.__name__, "\\n\\t".join(error_msgs))) 2043 return _IncompatibleKeys(missing_keys, unexpected_keys) RuntimeError: Error(s) in loading state_dict for DebertaV2Model: Missing key(s) in state_dict: "embeddings.position_ids", "embeddings.word_embeddings.weight", "embeddings.LayerNorm.weight",

対処方法

結論だけ知りたい方用に、対処方法を先に書いておく。
ckpt = torch.load(path) restored_ckpt = {} for k,v in ckpt.items(): restored_ckpt[k.replace('_orig_mod.', '')] = v model.load_state_dict(restored_ckpt)
または、
ckpt = torch.load(path) model = torch.compile(model) model.load_state_dict(restored_ckpt)
<All keys matched successfully>

エラーの原因

モデルをコンパイルするとルートになっているtorch.nn.Moduleクラス3の直下に_orig_modというメンバ変数が追加され、そこにモデルの構造全てがコピーされるらしい。(そして学習が行われるのは_orig_mod配下のパラメータになるぽい?)
import torch from transformers import AutoModel model = AutoModel.from_pretrained('microsoft/deberta-v3-base') compiled_model = torch.compile(model)
compiled_model._orig_mod.embeddings
DebertaV2Embeddings( (word_embeddings): Embedding(128100, 768, padding_idx=0) (LayerNorm): LayerNorm((768,), eps=1e-07, elementwise_affine=True) (dropout): StableDropout() )
DebertaV2Embeddings の構造が丸々 _orig_mod の直下に存在することが分かる。
compiled_model.embeddings
DebertaV2Embeddings( (word_embeddings): Embedding(128100, 768, padding_idx=0) (LayerNorm): LayerNorm((768,), eps=1e-07, elementwise_affine=True) (dropout): StableDropout() )
元の構造が消えているわけではなく、そのまま各メンバ変数にはアクセスできる。
model._orig_mod.embeddings
--------------------------------------------------------------------------- AttributeError Traceback (most recent call last) <ipython-input-10-03ce1f919c16> in <module> ----> 1 model._orig_mod.embeddings /usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in __getattr__(self, name) 1612 if name in modules: 1613 return modules[name] -> 1614 raise AttributeError("'{}' object has no attribute '{}'".format( 1615 type(self).__name__, name)) 1616 AttributeError: 'DebertaV2Model' object has no attribute '_orig_mod'
勿論コンパイルしていないモデルには、_orig_modのメンバ変数が存在しないためAttributeErrorになる。 そしてコンパイルしたモデルのstate_dict()を保存しようとすると、_orig_mod. 自身を含む配下の名前付きパラメータが保存される。
コンパイルしていないdeberta-v3-base
import torch path = 'ckpt.pth' torch.save( model.state_dict(), path ) ckpt = torch.load(path) [k for k in ckpt.keys() ]
['embeddings.position_ids', 'embeddings.word_embeddings.weight', 'embeddings.LayerNorm.weight', . . . 'encoder.rel_embeddings.weight', 'encoder.LayerNorm.weight', 'encoder.LayerNorm.bias']
embeddingsencoderがルートになった名前付きパラメータが含まれてることが分かる。
コンパイルしたdeberta-v3-base
torch.save( compiled_model.state_dict(), path ) ckpt = torch.load(path) [k for k in ckpt.keys() ]
['_orig_mod.embeddings.position_ids', '_orig_mod.embeddings.word_embeddings.weight', '_orig_mod.embeddings.LayerNorm.weight', . . . '_orig_mod.encoder.rel_embeddings.weight', '_orig_mod.encoder.LayerNorm.weight', '_orig_mod.encoder.LayerNorm.bias']
_orig_mod. 配下のパラメータが保存されていることがわかる。 したがって、取り得る対処方法は大きく2つ。
  • 読み込むパラメータの名前をよしなに書き換えてモデルに読み込む。
  • 読み込み時のモデルの構造と保存時のモデルの構造を同じにする。 前者の場合には、読み込む名前付きパラメータの情報であるkeyから_orig_mod.を消せば良いだけで、後者の場合には重みを読み込む前にモデル自体をコンパイルしておけば良いだけ。(詳細は対処方法を参照)

Footnotes

  1. ku-nlp/deberta-v2-base-japaneseをbackboneにした系列分類モデルで実験した結果
  2. つまり正式リリースでは本記事に記載しているエラーが発生しない、若しくは対処方法が異なる場合がある。
  3. 本記事ではHuggingFaceのモデルを使用しているため、transformers.models.deberta_v2.modeling_deberta_v2.DebertaV2Model。いずれにしてもtorch.nn.Moduleクラスのサブクラスを指す。