下面的小程序创建了一个简单的tf图形。我需要遍历这个图形,并在遍历过程中打印节点的信息。
是否可以假设每个图形都有一个根节点(或特定的节点)?我认为这个图形有3个节点,我听说边是张量。
#include<stdio.h>#include<stdlib.h>#include<string.h>#include"tensorflow/c/c_api.h"TF_Graph* g;TF_Status* s;#define CHECK_OK(x) if(TF_OK != TF_GetCode(s))return printf("%s\n",TF_Message(s)),(void*)0TF_Tensor* FloatTensor2x2(const float* values) { const int64_t dims[2] = {2, 2}; TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, dims, 2, sizeof(float) * 4); memcpy(TF_TensorData(t), values, sizeof(float) * 4); return t;}TF_Operation* FloatConst2x2(TF_Graph* graph, TF_Status* s, const float* values, const char* name) { TF_Tensor* tensor=FloatTensor2x2(values); TF_OperationDescription* desc = TF_NewOperation(graph, "Const", name); TF_SetAttrTensor(desc, "value", tensor, s); if (TF_GetCode(s) != TF_OK) return 0; TF_SetAttrType(desc, "dtype", TF_FLOAT); TF_Operation* op = TF_FinishOperation(desc, s); CHECK_OK(s); return op;}TF_Operation* MatMul(TF_Graph* graph, TF_Status* s, TF_Operation* l, TF_Operation* r, const char* name, char transpose_a, char transpose_b) { TF_OperationDescription* desc = TF_NewOperation(graph, "MatMul", name); if (transpose_a) { TF_SetAttrBool(desc, "transpose_a", 1); } if (transpose_b) { TF_SetAttrBool(desc, "transpose_b", 1); } TF_AddInput(desc,(TF_Output){l, 0}); TF_AddInput(desc,(TF_Output){r, 0}); TF_Operation* op = TF_FinishOperation(desc, s); CHECK_OK(s); return op;}TF_Graph* BuildSuccessGraph(TF_Output* inputs, TF_Output* outputs) { // | // z| // | // MatMul // / \ // ^ ^ // | | // x Const_0 y Const_1 // float const0_val[] = {1.0, 2.0, 3.0, 4.0}; float const1_val[] = {1.0, 0.0, 0.0, 1.0}; TF_Operation* const0 = FloatConst2x2(g, s, const0_val, "Const_0"); TF_Operation* const1 = FloatConst2x2(g, s, const1_val, "Const_1"); TF_Operation* matmul = MatMul(g, s, const0, const1, "MatMul",0,0); inputs[0] = (TF_Output){const0, 0}; inputs[1] = (TF_Output){const1, 0}; outputs[0] = (TF_Output){matmul, 0}; CHECK_OK(s); return g;}int main(int argc, char const *argv[]) { g = TF_NewGraph(); s = TF_NewStatus(); TF_Output inputs[2],outputs[1]; BuildSuccessGraph(inputs,outputs); /* HERE traverse g -- maybe with {inputs,outputs} -- to print the graph */ fprintf(stdout, "OK\n");}
如果有人能帮助提供一些函数来获取图形的信息,将不胜感激。
回答:
来自c_api.h:
// 遍历图形的操作。使用方法如下:// size_t pos = 0;// TF_Operation* oper;// while ((oper = TF_GraphNextOperation(graph, &pos)) != nullptr) {// DoSomethingWithOperation(oper);// }TF_CAPI_EXPORT extern TF_Operation* TF_GraphNextOperation(TF_Graph* graph, size_t* pos);
请注意,这仅返回操作,并没有定义从一个节点(操作)导航到下一个节点的方法——这种边关系存储在节点本身(作为指针)。