ricvo

Reputation: 93

Understanding the Definition of New Tensorflow Operators in C++

I am trying to follow the official guide for defining new operators in TensorFlow: https://www.tensorflow.org/extend/adding_an_op. The guide includes the following op registration:

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

using namespace tensorflow;

REGISTER_OP("ZeroOut")
    .Input("to_zero: int32")
    .Output("zeroed: int32")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c){
      c->set_output(0, c->input(0));
      return Status::OK();
    });

However, I cannot find a line-by-line explanation of this code. In particular, I do not understand the role of .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {...}) or its syntax. I am also puzzled by InferenceContext: my guess is that it is a way to pass the elements of an array one by one, in succession. I could not find explicit definitions anywhere, though maybe I am looking in the wrong places. Can someone help with either an explanation or a reference? I would like to understand in depth what this piece of code is doing under the hood.

Upvotes: 6

Views: 971

Answers (1)

Pete Warden

Reputation: 2878

Did you spot the section on shape inference functions here? https://www.tensorflow.org/extend/adding_an_op#shape_functions_in_c

That section has quite a lot of discussion of the shape_inference::InferenceContext class and the mechanics of writing your own shape functions. If that doesn't cover what you're interested in, could you give more details?
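
On the syntax itself: the argument to .SetShapeFn is a C++11 lambda, i.e. an unnamed function written inline. The [] is an empty capture list, (InferenceContext* c) is the parameter list, and the braces hold the body. The context object describes the shapes of the op's inputs and outputs, not the tensor values, so it does not walk over array elements. Here is a minimal standalone sketch of that mechanism. Everything in the mock namespace, as well as the helper InferZeroOutShape, is a hypothetical stand-in written for illustration and is not the TensorFlow implementation, which is considerably more involved.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-ins for illustration only, NOT the TensorFlow classes.
namespace mock {

// A shape is just a list of dimension sizes, e.g. {2, 3} for a 2x3 tensor.
using Shape = std::vector<int>;

struct Status {
  bool ok;
  static Status OK() { return Status{true}; }
};

// Holds the shapes (not the values) of one op's inputs and outputs.
class InferenceContext {
 public:
  InferenceContext(std::vector<Shape> inputs, int num_outputs)
      : inputs_(std::move(inputs)), outputs_(num_outputs) {}

  const Shape& input(int i) const { return inputs_.at(i); }
  void set_output(int i, const Shape& s) { outputs_.at(i) = s; }
  const Shape& output(int i) const { return outputs_.at(i); }

 private:
  std::vector<Shape> inputs_;
  std::vector<Shape> outputs_;
};

// A shape function is any callable taking a context pointer and returning a Status.
using ShapeFn = std::function<Status(InferenceContext*)>;

// Builder that stores the callable; returning *this is what allows chained calls.
class OpBuilder {
 public:
  explicit OpBuilder(std::string name) : name_(std::move(name)) {}

  OpBuilder& SetShapeFn(ShapeFn fn) {
    shape_fn_ = std::move(fn);
    return *this;
  }

  const std::string& name() const { return name_; }
  const ShapeFn& shape_fn() const { return shape_fn_; }

 private:
  std::string name_;
  ShapeFn shape_fn_;
};

}  // namespace mock

// Registers a shape function with the same body as the one in the question,
// then invokes it on a context that has one input and one output.
inline mock::Shape InferZeroOutShape(const mock::Shape& in) {
  mock::OpBuilder op("ZeroOut");
  op.SetShapeFn([](mock::InferenceContext* c) {
    c->set_output(0, c->input(0));  // output 0 gets the same shape as input 0
    return mock::Status::OK();
  });

  mock::InferenceContext ctx(std::vector<mock::Shape>{in}, 1);
  mock::Status s = op.shape_fn()(&ctx);  // the stored lambda runs only here
  assert(s.ok);
  return ctx.output(0);
}
```

In this sketch, InferZeroOutShape({2, 3}) returns {2, 3}. The point to take away is that SetShapeFn only stores the lambda; it runs later, when something hands it a context. In TensorFlow the analogous call happens while the graph is being built, which is how the output shape of ZeroOut is known before any tensor data exists.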

Upvotes: 2
