如何解决在torch.jit.tracing中跟踪Python对象以进行保存和加载
我正在尝试将pytorch模型从一台服务器发送到另一台服务器。我选择使用torch.jit.tracing保存模型以进行传输。现在的问题是,我想发送加密的参数,这是一个int python对象的列表(为每个密文保存100多个数字)。我试图将其保存在dict()对象中并对其进行跟踪。
# Model
class Net(nn.Module):
def __init__(self,input_size,output_size):
super(Net,self).__init__()
self.fc1 = nn.Linear(input_size,50)
# nn.init.normal_(self.fc1.weight,mean=0,std=1)
self.fc2 = nn.Linear(50,10)
# nn.init.normal_(self.fc2.weight,std=1)
self.fc3 = nn.Linear(10,output_size)
nn.init.normal_(self.fc3.weight,std=1)
self.fc1_enc = None
self.fc2_enc = None
self.fc3_enc = None
# a dict of python int objects to save encrypted big ints.
self.encrypted_params = {"int_wei": [[74812937489217492138741294723198479123749321874,11111112222222222222888888888888888888888888888]]}
self.exponent_encrypted_params = dict()
def forward(self,x):
x = F.elu(self.fc1(x))
x = F.elu(self.fc2(x))
return F.log_softmax(self.fc3(x),dim=-1)
然后我通过以下方式跟踪模型:
model = Net(119,2).to(device)
traced_model = torch.jit.trace(model,torch.zeros(119,dtype=torch.float))
但是,我无法加载dict wit错误:AttributeError: 'ScriptModule' object has no attribute 'encrypted_params'
。
如何使用大整数对象作为参数序列化模型?谢谢!
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。