tianyapiaozi


How to use Custom OP to build TensorFlow Graph in C++?

According to the TensorFlow documentation, a graph can be built from built-in ops as follows:

#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor.h"

int main() {
  using namespace tensorflow;
  using namespace tensorflow::ops;
  Scope root = Scope::NewRootScope();
  // Matrix A = [3 2; -1 0]
  auto A = Const(root, { {3.f, 2.f}, {-1.f, 0.f} });
  // Vector b = [3 5]
  auto b = Const(root, { {3.f, 5.f} });
  // v = Ab^T
  auto v = MatMul(root.WithOpName("v"), A, b, MatMul::TransposeB(true));
  std::vector<Tensor> outputs;
  ClientSession session(root);
  // Run and fetch v
  TF_CHECK_OK(session.Run({v}, &outputs));
  // Expect outputs[0] == [19; -3]
  LOG(INFO) << outputs[0].matrix<float>();
  return 0;
}

It seems that the MatMul class is auto-generated, since there is no tensorflow/cc/ops/math_ops.h in the GitHub source code. How can I do the same thing for a custom op, such as the ZeroOut op from here?


Answers (1)

tianyapiaozi


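Taking ZeroOut from here as an example, you can write a wrapper class by hand. This assumes the op and its kernel (its REGISTER_OP and REGISTER_KERNEL_BUILDER definitions) are compiled into the same binary; otherwise builder.Finalize() fails because the op type is not registered.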

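#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/const_op.h"  // ::tensorflow::ops::AsNodeOut
#include "tensorflow/core/graph/node_builder.h"

// Hand-written wrapper, modeled on the classes generated for built-in ops.
class ZeroOut {
 public:
  ZeroOut(const ::tensorflow::Scope& scope, ::tensorflow::Input x);
  // Implicit conversions let a ZeroOut be used wherever an Output/Input is expected.
  operator ::tensorflow::Output() const { return y; }
  operator ::tensorflow::Input() const { return y; }
  ::tensorflow::Node* node() const { return y.node(); }

  ::tensorflow::Output y;
};

ZeroOut::ZeroOut(const ::tensorflow::Scope& scope, ::tensorflow::Input x) {
  if (!scope.ok()) return;
  // Turn the Input (an Output or a constant initializer) into a NodeBuilder::NodeOut.
  auto _x = ::tensorflow::ops::AsNodeOut(scope, x);
  if (!scope.ok()) return;
  ::tensorflow::Node* ret;
  // "ZeroOut" must match the op name passed to REGISTER_OP.
  const auto unique_name = scope.GetUniqueNameForOp("ZeroOut");
  auto builder = ::tensorflow::NodeBuilder(unique_name, "ZeroOut")
                     .Input(_x);
  // Apply the scope's device, control dependencies and similar settings.
  scope.UpdateBuilder(&builder);
  scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
  if (!scope.ok()) return;
  scope.UpdateStatus(scope.DoShapeInference(ret));
  this->y = ::tensorflow::Output(ret, 0);
}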
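The two conversion operators are what allow a ZeroOut object to be passed straight to session.Run or used as the input of another op: the compiler converts it implicitly wherever an ::tensorflow::Output or ::tensorflow::Input is expected. The same C++ idiom in isolation, as a sketch using hypothetical stand-in types (FakeOutput, FakeInput, FakeZeroOut) and hypothetical helper functions rather than the real TensorFlow classes:

```cpp
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for ::tensorflow::Output: just names the node.
struct FakeOutput {
  std::string node_name;
};

// Hypothetical stand-in for ::tensorflow::Input, which can be built from an Output.
struct FakeInput {
  FakeInput(const FakeOutput& out) : node_name(out.node_name) {}
  std::string node_name;
};

// Shaped like the ZeroOut wrapper: holds one output and converts to both types.
class FakeZeroOut {
 public:
  explicit FakeZeroOut(std::string name) : y{std::move(name)} {}
  operator FakeOutput() const { return y; }
  operator FakeInput() const { return FakeInput(y); }

  FakeOutput y;
};

// Takes an input by value, the way generated op constructors take ::tensorflow::Input.
std::string InputName(FakeInput in) { return in.node_name; }

// Takes a list of outputs, the way ClientSession::Run takes its fetches.
std::vector<std::string> FetchNames(const std::vector<FakeOutput>& fetches) {
  std::vector<std::string> names;
  for (const auto& f : fetches) names.push_back(f.node_name);
  return names;
}
```

For example, InputName(FakeZeroOut("v")) returns "v", and FetchNames({FakeZeroOut("a"), FakeZeroOut("b")}) accepts the wrappers directly, much like session.Run({v}, &outputs) does with the real class.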

Then you can use it to build a graph:

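// Inside main(), with the includes and using-directives from the question
Scope root = Scope::NewRootScope();
// Integer matrix A = [3 2; -1 0]; the guide's ZeroOut takes int32 input
auto A = Const(root, { {3, 2}, {-1, 0} });
auto v = ZeroOut(root.WithOpName("v"), A);
std::vector<Tensor> outputs;
ClientSession session(root);
// Run and fetch v
TF_CHECK_OK(session.Run({v}, &outputs));
// With the guide's kernel, expect outputs[0] == [3 0; 0 0]
LOG(INFO) << outputs[0].matrix<int>();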
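To check the logged matrix by hand: the guide's ZeroOut kernel copies the first element of the flattened input and sets every other element to 0, so A = [3 2; -1 0] (flattened row-major as {3, 2, -1, 0}) should come back as [3 0; 0 0]. The rule itself, as a plain C++ sketch on a row-major buffer; ZeroOutReference is a hypothetical helper for checking results, not part of TensorFlow:

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helper: applies ZeroOut's rule to a flattened (row-major)
// int32 buffer. Every element except the first becomes 0.
std::vector<std::int32_t> ZeroOutReference(const std::vector<std::int32_t>& input) {
  std::vector<std::int32_t> output(input.size(), 0);
  if (!input.empty()) output[0] = input[0];  // keep the first value if there is one
  return output;
}
```

ZeroOutReference({3, 2, -1, 0}) yields {3, 0, 0, 0}, which reshaped to 2x2 is the [3 0; 0 0] the session run should log.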

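Note: for TensorFlow's built-in ops, a class like ZeroOut is auto-generated by a Bazel rule (tf_gen_op_wrappers_cc), which is why tensorflow/cc/ops/math_ops.h only appears in the build output and not in the GitHub source tree. If you only have a few custom ops, you can imitate that generated code (e.g. tensorflow/cc/ops/math_ops.h) and write the wrapper classes by hand.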
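The hand-written wrapper is mechanical: only the op name and the input names change from one op to the next. As an illustration of that point, here is a hypothetical toy generator (not TensorFlow's real op-wrapper generator, which is driven by the registered op definitions) that emits a class shaped like the ZeroOut wrapper above for an op with the given inputs and a single output y:

```cpp
#include <string>
#include <vector>

// Hypothetical toy generator: emits wrapper source for an op with the given
// inputs and one output named y, following the hand-written ZeroOut class.
std::string GenerateWrapper(const std::string& op_name,
                            const std::vector<std::string>& inputs) {
  std::string params = "const ::tensorflow::Scope& scope";
  for (const auto& in : inputs) params += ", ::tensorflow::Input " + in;

  // Class declaration.
  std::string s = "class " + op_name + " {\n public:\n";
  s += "  " + op_name + "(" + params + ");\n";
  s += "  operator ::tensorflow::Output() const { return y; }\n";
  s += "  operator ::tensorflow::Input() const { return y; }\n";
  s += "  ::tensorflow::Node* node() const { return y.node(); }\n";
  s += "  ::tensorflow::Output y;\n};\n\n";

  // Constructor: convert inputs, build the node, run shape inference.
  s += op_name + "::" + op_name + "(" + params + ") {\n";
  s += "  if (!scope.ok()) return;\n";
  for (const auto& in : inputs)
    s += "  auto _" + in + " = ::tensorflow::ops::AsNodeOut(scope, " + in + ");\n";
  s += "  if (!scope.ok()) return;\n";
  s += "  ::tensorflow::Node* ret;\n";
  s += "  const auto unique_name = scope.GetUniqueNameForOp(\"" + op_name + "\");\n";
  s += "  auto builder = ::tensorflow::NodeBuilder(unique_name, \"" + op_name + "\")";
  for (const auto& in : inputs) s += ".Input(_" + in + ")";
  s += ";\n";
  s += "  scope.UpdateBuilder(&builder);\n";
  s += "  scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));\n";
  s += "  if (!scope.ok()) return;\n";
  s += "  scope.UpdateStatus(scope.DoShapeInference(ret));\n";
  s += "  this->y = ::tensorflow::Output(ret, 0);\n}\n";
  return s;
}
```

GenerateWrapper("ZeroOut", {"x"}) reproduces the ZeroOut wrapper's structure; ops with several outputs, attributes or list inputs would need more than this sketch handles.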

