Reputation: 2008
According to the TensorFlow documentation, a graph can be built from the built-in ops as follows:
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor.h"
// Builds a small graph that multiplies a constant matrix by a transposed
// constant vector, runs it through a ClientSession, and logs the result.
int main() {
  using namespace tensorflow;
  using namespace tensorflow::ops;

  Scope root = Scope::NewRootScope();

  // A is the 2x2 matrix [3 2; -1 0].
  auto A = Const(root, {{3.f, 2.f}, {-1.f, 0.f}});
  // b is the row vector [3 5].
  auto b = Const(root, {{3.f, 5.f}});
  // v = A * b^T; the node is given the name "v".
  auto v = MatMul(root.WithOpName("v"), A, b, MatMul::TransposeB(true));

  ClientSession session(root);
  std::vector<Tensor> fetched;
  // Run the graph and fetch v into `fetched`.
  TF_CHECK_OK(session.Run({v}, &fetched));

  // fetched[0] is expected to be [19; -3].
  LOG(INFO) << fetched[0].matrix<float>();
  return 0;
}
It seems that the MatMul class is auto-generated, since there is no
tensorflow/cc/ops/math_ops.h
in the GitHub source code.
How can I do the same thing for a custom op, such as the ZeroOut op from here?
Upvotes: 2
Views: 526
Reputation: 2008
Taking the ZeroOut op
from here as an example, you have to do the following:
// Hand-written graph-construction wrapper in the style of the generated
// C++ op classes: it holds the op's single output `y` and converts
// implicitly to ::tensorflow::Output / ::tensorflow::Input so an instance
// can be passed wherever those types are accepted.
class ZeroOut {
 public:
  // Output handle of the op; assigned by the constructor.
  ::tensorflow::Output y;

  // Declared here, defined out of line. Takes the scope to build in and
  // the op's single input `x`.
  ZeroOut(const ::tensorflow::Scope& scope, ::tensorflow::Input x);

  // Graph node that produces `y`.
  ::tensorflow::Node* node() const { return y.node(); }

  // Implicit conversions, both yielding `y`.
  operator ::tensorflow::Output() const { return y; }
  operator ::tensorflow::Input() const { return y; }
};
// Adds a "ZeroOut" node to the graph owned by `scope` and stores its first
// output in `y`.
//
// scope: scope the node is created in. If the scope is already in an error
//        state, or enters one while the node is being built, the constructor
//        returns early and leaves `y` default-constructed.
// x:     the op's single input.
ZeroOut::ZeroOut(const ::tensorflow::Scope& scope, ::tensorflow::Input x) {
  if (!scope.ok()) return;
  auto _x = ::tensorflow::ops::AsNodeOut(scope, x);
  if (!scope.ok()) return;

  // Initialized so the pointer is never indeterminate if Finalize fails.
  ::tensorflow::Node* ret = nullptr;
  const auto unique_name = scope.GetUniqueNameForOp("ZeroOut");
  auto builder =
      ::tensorflow::NodeBuilder(unique_name, "ZeroOut").Input(_x);
  scope.UpdateBuilder(&builder);
  scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
  if (!scope.ok()) return;

  scope.UpdateStatus(scope.DoShapeInference(ret));
  // Fully qualified like every other TensorFlow name here, so the definition
  // also compiles outside namespace tensorflow and without a using-directive.
  this->y = ::tensorflow::Output(ret, 0);
}
Then you can use it to build a graph:
Scope root = Scope::NewRootScope();
// A is the 2x2 integer matrix [3 2; -1 0].
auto A = Const(root, {{3, 2}, {-1, 0}});
// Feed A through the ZeroOut wrapper; the node is given the name "v".
auto v = ZeroOut(root.WithOpName("v"), A);
ClientSession session(root);
std::vector<Tensor> fetched;
// Run the graph and fetch v into `fetched`.
TF_CHECK_OK(session.Run({v}, &fetched));
LOG(INFO) << fetched[0].matrix<int>();
Note: For TensorFlow's built-in ops, code like the ZeroOut class above
is auto-generated by a Bazel rule. We can imitate that generated code (e.g. tensorflow/cc/ops/math_ops.h
) and write our own classes by hand if we only have a few custom ops.
Upvotes: 1