如何解决Tensorflow自定义OP C ++:Tensorflow使用负值测试SetShapeFn导致崩溃
我正在尝试实现自己的池化操作,并且可以正常工作。但是,在初始化期间,TensorFlow似乎正在测试不同的输入形状,特别是-1。在python中,“未知值”可以从其他维度计算得出(将10个元素数组重塑为(-1,5)意味着将其重塑为(2,5))。我不知道如何正确修改我的SetShapeFn
函数。这是一个片段:
REGISTER_OP("FtPool")
.Input("x: T")
.Attr("stride: list(float)")
.Attr("pool_size: list(float)")
.Attr("T: realnumbertype = DT_FLOAT")
.Output("output: T")
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
std::cout << "a" << std::endl;
ShapeHandle shape_hnd;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0),4,&shape_hnd));
std::cout << "b" << std::endl;
// save tensor shape into separated variables
auto N = c->Dim(c->input(0),0);
auto H = c->Dim(c->input(0),1);
auto W = c->Dim(c->input(0),2);
auto C = c->Dim(c->input(0),3);
// tensorflow,during init,tries inputshape with w = -1 and h = -1 which crashes in MakeDim
if (c->Value(H) <= 0 || c->Value(W) <= 0)
return Status::OK();
// save stride so we can use it to calculate output shape
std::vector<float> stride;
c->GetAttr("stride",&stride);
H = c->MakeDim(round(c->Value(H)/stride[0]));
W = c->MakeDim(round(c->Value(W)/stride[1]));
c->set_output(0,c->MakeShape({N,H,W,C}));
return Status::OK();
})
解决方法是以下几行:
if (c->Value(H) <= 0 || c->Value(W) <= 0)
return Status::OK();
但是,这显然是错误的,应该有一种对无用输入尺寸做出反应的正确方法。 SetShapeFn应该从Status
返回一些内容。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。