如何解决在 MXNet 1.7.0/Gluon 中从预训练的 BERT 网络中提取词嵌入 SymbolBlock
我正在尝试创建一个 Transformer seq2seq 模型,并且需要能够检索目标序列的词嵌入。这可以在 Torch 中完成,例如self.encoder.embeddings(target_ids)
这给了我词 + 位置 + 标记类型嵌入。
从我收集到的信息来看,使用 MXNet Gluon SymbolBlock 中的预训练网络实际上不可能做到这一点,所以我决定“提取”我需要的网络部分并通过它传递 target_id。我能够复制我需要的网络部分,但后来我遇到了重复参数的问题。
我的下一个想法是将网络分成两部分,这样我就可以直接访问嵌入部分并且没有重复的参数,但我正在努力寻找一种将 SymbolBlock 或 Symbol 一分为二的方法。
这是复制词嵌入部分的工作示例(由于重复参数,在整个网络中不起作用):
# load the bert-symbol.json file
symbol = mx.sym.load(symbol_file)
inputs = [mx.sym.var('data0'),mx.sym.var('data1'),mx.sym.var('data2')]
# end of the word embedding part of the model
outputs = symbol.get_internals()['bertencoder0_layernorm0_layernorm0_output']
# construct SymbolBlock
embed = mx.gluon.SymbolBlock(inputs=inputs,outputs=outputs)
# load weights into symbolblock
embed.load_parameters(weight_file,ctx=ctx,ignore_extra=True)
我进行了广泛的搜索,但没有找到解决方案,也许有更好的方法来实现这一目标?
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。