I am trying to follow the official guide for defining new operators in TensorFlow: https://www.tensorflow.org/extend/adding_an_op
```cpp
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

using namespace tensorflow;

REGISTER_OP("ZeroOut")
    .Input("to_zero: int32")
    .Output("zeroed: int32")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));
      return Status::OK();
    });
```
However, I cannot find a line-by-line explanation of this code. In particular, I do not understand the role of `.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c)` or its syntax. I am also puzzled by `InferenceContext`; my guess is that it is a way to pass the elements of an array one by one in succession. I could not find explicit definitions anywhere, so maybe I am looking in the wrong places. Can someone help me with either an explanation or a reference? I would like to understand in depth what this piece of code does under the hood.
Did you spot the section on shape functions here? https://www.tensorflow.org/extend/adding_an_op#shape_functions_in_c
It discusses the `shape_inference::InferenceContext` class and the mechanics of writing your own shape functions in some detail. If that doesn't cover what you're interested in, could you give more details?