如何使用Tensorflow的C API遍历图形?

下面的小程序创建了一个简单的tf图形。我需要遍历这个图形,并在遍历过程中打印节点的信息。

是否可以假设每个图形都有一个根节点(或特定的节点)?我认为这个图形有3个节点,我听说边是张量。

#include<stdio.h>#include<stdlib.h>#include<string.h>#include"tensorflow/c/c_api.h"TF_Graph* g;TF_Status* s;#define CHECK_OK(x) if(TF_OK != TF_GetCode(s))return printf("%s\n",TF_Message(s)),(void*)0TF_Tensor* FloatTensor2x2(const float* values) {  const int64_t dims[2] = {2, 2};  TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, dims, 2, sizeof(float) * 4);  memcpy(TF_TensorData(t), values, sizeof(float) * 4);  return t;}TF_Operation* FloatConst2x2(TF_Graph* graph, TF_Status* s, const float* values, const char* name) {  TF_Tensor* tensor=FloatTensor2x2(values);  TF_OperationDescription* desc = TF_NewOperation(graph, "Const", name);  TF_SetAttrTensor(desc, "value", tensor, s);  if (TF_GetCode(s) != TF_OK) return 0;  TF_SetAttrType(desc, "dtype", TF_FLOAT);  TF_Operation* op = TF_FinishOperation(desc, s);  CHECK_OK(s);  return op;}TF_Operation* MatMul(TF_Graph* graph, TF_Status* s, TF_Operation* l, TF_Operation* r, const char* name,                     char transpose_a, char transpose_b) {  TF_OperationDescription* desc = TF_NewOperation(graph, "MatMul", name);  if (transpose_a) {    TF_SetAttrBool(desc, "transpose_a", 1);  }  if (transpose_b) {    TF_SetAttrBool(desc, "transpose_b", 1);  }  TF_AddInput(desc,(TF_Output){l, 0});  TF_AddInput(desc,(TF_Output){r, 0});  TF_Operation* op = TF_FinishOperation(desc, s);  CHECK_OK(s);  return op;}TF_Graph* BuildSuccessGraph(TF_Output* inputs, TF_Output* outputs) {  //            |  //           z|  //            |  //          MatMul  //         /       \  //        ^         ^  //        |         |  //    x Const_0  y Const_1  //  float const0_val[] = {1.0, 2.0, 3.0, 4.0};  float const1_val[] = {1.0, 0.0, 0.0, 1.0};  TF_Operation* const0 = FloatConst2x2(g, s, const0_val, "Const_0");  TF_Operation* const1 = FloatConst2x2(g, s, const1_val, "Const_1");  TF_Operation* matmul = MatMul(g, s, const0, const1, "MatMul",0,0);  inputs[0] = (TF_Output){const0, 0};  inputs[1] = (TF_Output){const1, 0};  outputs[0] = (TF_Output){matmul, 0};  CHECK_OK(s);  return g;}int main(int argc, char const *argv[]) {  g = TF_NewGraph();  s = TF_NewStatus();  TF_Output inputs[2],outputs[1];  BuildSuccessGraph(inputs,outputs);  /* HERE traverse g -- maybe with {inputs,outputs} -- to print the graph */  fprintf(stdout, "OK\n");}

如果有人能帮助提供一些函数来获取图形的信息,将不胜感激。


回答:

来自c_api.h:

// 遍历图形的操作。使用方法如下:// size_t pos = 0;// TF_Operation* oper;// while ((oper = TF_GraphNextOperation(graph, &pos)) != nullptr) {//   DoSomethingWithOperation(oper);// }TF_CAPI_EXPORT extern TF_Operation* TF_GraphNextOperation(TF_Graph* graph,                                                      size_t* pos);

请注意,这仅返回操作,并没有定义从一个节点(操作)导航到下一个节点的方法——这种边关系存储在节点本身(作为指针)。

Related Posts

Keras Dense层输入未被展平

这是我的测试代码: from keras import…

无法将分类变量输入随机森林

我有10个分类变量和3个数值变量。我在分割后直接将它们…

如何在Keras中对每个输出应用Sigmoid函数?

这是我代码的一部分。 model = Sequenti…

如何选择类概率的最佳阈值?

我的神经网络输出是一个用于多标签分类的预测类概率表: …

在Keras中使用深度学习得到不同的结果

我按照一个教程使用Keras中的深度神经网络进行文本分…

‘MatMul’操作的输入’b’类型为float32,与参数’a’的类型float64不匹配

我写了一个简单的TensorFlow代码,但不断遇到T…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注