I have a TensorFlow .pb file that I would like to load into a Python DNN, restore the graph, and get predictions. I am doing this to test whether the .pb file I created makes predictions similar to those of the model saved with the normal Saver.save().
My basic problem is that I am getting very different prediction values when I run the above-mentioned .pb file on Android.
My .pb file creation code:
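```python
import tensorflow as tf

# `session` is the tf.Session that holds the trained variables.
frozen_graph = tf.graph_util.convert_variables_to_constants(
    session, session.graph_def, ['outputLayer/Softmax'])
with open('frozen_model.pb', 'wb') as f:
    f.write(frozen_graph.SerializeToString())
```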
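For reference, here is a self-contained sketch of the same freezing step that runs under TensorFlow 2 through the `tf.compat.v1` API. The toy one-layer model, the `input` placeholder, the helper name `freeze_tiny_model` and the file path are hypothetical stand-ins for the real network; only `convert_variables_to_constants`, `SerializeToString` and the output name `outputLayer/Softmax` come from the question.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1


def freeze_tiny_model(pb_path):
    """Build a toy softmax layer with one variable, freeze it, and write it to pb_path."""
    graph = tf.Graph()
    with graph.as_default():
        x = tf1.placeholder(tf.float32, shape=[None, 3], name="input")
        with tf1.name_scope("outputLayer"):
            w = tf.Variable(np.arange(6, dtype=np.float32).reshape(3, 2), name="weights")
            # Named so the op ends up as "outputLayer/Softmax", as in the question.
            tf.nn.softmax(tf.matmul(x, w), name="Softmax")
        with tf1.Session(graph=graph) as session:
            session.run(tf1.global_variables_initializer())
            # Replace each variable with a Const node holding its current value.
            frozen_graph = tf1.graph_util.convert_variables_to_constants(
                session, session.graph_def, ["outputLayer/Softmax"])
    with open(pb_path, "wb") as f:
        f.write(frozen_graph.SerializeToString())
    return frozen_graph


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        frozen = freeze_tiny_model(os.path.join(tmp, "frozen_model.pb"))
        print(sorted({node.op for node in frozen.node}))
```

Because the variables must be initialized (or restored) in the session before the call, freezing a session in which the trained weights were never loaded would bake in initial values rather than trained ones.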
So I have two major concerns:
- How can I load the above-mentioned .pb file into a Python TensorFlow model?
- Why am I getting completely different prediction values in Python and on Android?
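To put a number on "completely different", the same input batch can be run through both models and the two output arrays compared. A minimal NumPy sketch, assuming both outputs are softmax probabilities of shape (batch, classes) already collected as arrays; `compare_predictions` and its default tolerance are hypothetical choices, not part of the question.

```python
import numpy as np


def compare_predictions(reference, candidate, atol=1e-5):
    """Compare two (batch, classes) probability arrays.

    Returns the largest absolute difference, the fraction of rows whose
    top-1 class agrees, and whether every element is within `atol`.
    """
    reference = np.asarray(reference, dtype=np.float64)
    candidate = np.asarray(candidate, dtype=np.float64)
    if reference.shape != candidate.shape:
        raise ValueError(f"shape mismatch: {reference.shape} vs {candidate.shape}")
    max_abs_diff = float(np.max(np.abs(reference - candidate)))
    top1_agreement = float(np.mean(reference.argmax(axis=1) == candidate.argmax(axis=1)))
    within_tol = bool(np.allclose(reference, candidate, atol=atol, rtol=0.0))
    return max_abs_diff, top1_agreement, within_tol


if __name__ == "__main__":
    python_out = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]]   # e.g. from the Saver model
    android_out = [[0.2, 0.7, 0.1], [0.1, 0.8, 0.1]]  # e.g. from the .pb on Android
    print(compare_predictions(python_out, android_out))
```

If the top-1 agreement is low rather than the values merely drifting by small amounts, one common thing to check is whether the input preprocessing (resizing, normalization, channel order) is identical on both platforms.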
2 Answers
The following code will read the model and print out the names of the nodes in the graph.
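```python
import tensorflow as tf
from tensorflow.python.platform import gfile

GRAPH_PB_PATH = './frozen_model.pb'
with tf.Session() as sess:
    print("load graph")
    with gfile.FastGFile(GRAPH_PB_PATH, 'rb') as f:
        # Parse the binary protobuf into a GraphDef.
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    # No effect on its own: as_default() returns a context manager that is never
    # entered. `with tf.Session()` already made sess.graph the default graph.
    sess.graph.as_default()
    # name='' keeps the original node names (no "import/" prefix).
    tf.import_graph_def(graph_def, name='')
    graph_nodes = [n for n in graph_def.node]
    names = []
    for t in graph_nodes:
        names.append(t.name)
    print(names)
```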
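The snippet above only lists node names. To get predictions, the imported graph can be run by tensor name. A minimal sketch, assuming TensorFlow 2 with `tf.compat.v1`, an input placeholder named `input` and an output node named `outputLayer/Softmax` (the real names can be read off the printed list); `load_frozen_graph`, `predict` and the toy graph writer are hypothetical helpers. It imports into a fresh `tf.Graph` instead of relying on `sess.graph.as_default()`.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1


def load_frozen_graph(pb_path):
    """Read a binary GraphDef from pb_path and import it into a new tf.Graph."""
    graph_def = tf1.GraphDef()
    with tf.io.gfile.GFile(pb_path, "rb") as f:
        graph_def.ParseFromString(f.read())
    graph = tf.Graph()
    with graph.as_default():
        # name="" keeps the original node names (no "import/" prefix).
        tf1.import_graph_def(graph_def, name="")
    return graph


def predict(graph, batch, input_name="input:0", output_name="outputLayer/Softmax:0"):
    """Feed `batch` into the input tensor and fetch the output tensor."""
    with tf1.Session(graph=graph) as sess:
        return sess.run(output_name, feed_dict={input_name: batch})


def write_toy_frozen_graph(pb_path):
    """Write a constants-only (already frozen) graph so the example runs on its own."""
    graph = tf.Graph()
    with graph.as_default():
        x = tf1.placeholder(tf.float32, shape=[None, 2], name="input")
        w = tf.constant([[2.0, 0.0], [0.0, 1.0]], name="weights")
        with tf1.name_scope("outputLayer"):
            tf.nn.softmax(tf.matmul(x, w), name="Softmax")
    with open(pb_path, "wb") as f:
        f.write(graph.as_graph_def().SerializeToString())


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "frozen_model.pb")
        write_toy_frozen_graph(path)
        batch = np.array([[1.0, 1.0]], dtype=np.float32)
        print(predict(load_frozen_graph(path), batch))
```

Running the same batch through this path and through the Saver-restored model gives two arrays that can be compared directly.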
You are not freezing the graph properly; that is why you are getting different results: the weights are not being stored in your model. You can use the freeze_graph.py script (link) to get a correctly frozen graph.
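One way to check whether the weights really made it into the .pb file is to inspect the op types in its GraphDef: a properly frozen graph holds the weights in `Const` nodes and has no variable ops left. A minimal sketch, assuming TensorFlow 2; `remaining_variable_nodes` is a hypothetical helper, and the set of op types it looks for covers the common variable ops rather than being exhaustive.

```python
import tensorflow as tf

# Op types that hold state instead of baked-in values (common cases, not exhaustive).
VARIABLE_OPS = {"Variable", "VariableV2", "VarHandleOp"}


def remaining_variable_nodes(graph_def):
    """Return the names of nodes in graph_def whose op is a variable op."""
    return [node.name for node in graph_def.node if node.op in VARIABLE_OPS]


if __name__ == "__main__":
    # A graph that still reads its weights from a variable (not frozen).
    unfrozen = tf.Graph()
    with unfrozen.as_default():
        tf.Variable([1.0, 2.0], name="weights")
    # A graph whose weights are constants, the shape a frozen export should have.
    frozen_like = tf.Graph()
    with frozen_like.as_default():
        tf.constant([1.0, 2.0], name="weights")
    print(remaining_variable_nodes(unfrozen.as_graph_def()))
    print(remaining_variable_nodes(frozen_like.as_graph_def()))
```

Applied to a GraphDef parsed from `frozen_model.pb`, a non-empty result would suggest the file still depends on variables whose values are not stored in it.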
6 Comments
What is sess.graph.as_default() doing?

graph_def.ParseFromString(f.read()) fails with: DecodeError: Error parsing message

In 2019, use tf.gfile.GFile instead of gfile.FastGFile.

ValueError: Input 0 of node InceptionResnetV1/Conv2d_1a_3x3/BatchNorm/cond/Switch was passed float from phase_train:0 incompatible with expected bool. Do you have any idea why this is happening? Thanks

Here is the updated code for TensorFlow 2.
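```python
import tensorflow as tf

GRAPH_PB_PATH = './frozen_model.pb'
with tf.compat.v1.Session() as sess:
    print("load graph")
    with tf.io.gfile.GFile(GRAPH_PB_PATH, 'rb') as f:
        # Parse the binary protobuf into a GraphDef.
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())
    # No effect on its own: as_default() returns a context manager that is never
    # entered. `with tf.compat.v1.Session()` already made sess.graph the default graph.
    sess.graph.as_default()
    # name='' keeps the original node names (no "import/" prefix).
    tf.import_graph_def(graph_def, name='')
    graph_nodes = [n for n in graph_def.node]
    names = []
    for t in graph_nodes:
        names.append(t.name)
    print(names)
```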
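In TensorFlow 2 a frozen GraphDef can also be called without a Session by wrapping the import in `tf.compat.v1.wrap_function` and pruning it to the wanted input and output tensors, a pattern shown in TensorFlow's TF1-to-TF2 migration material. A sketch, assuming the same hypothetical tensor names `input:0` and `outputLayer/Softmax:0` (replace them with names from the printed list); the constants-only `toy_graph_def` stands in for the contents of `frozen_model.pb` so the example runs on its own.

```python
import tensorflow as tf


def wrap_frozen_graph(graph_def, inputs, outputs):
    """Turn a frozen GraphDef into a callable concrete function."""
    def _imports_graph_def():
        tf.compat.v1.import_graph_def(graph_def, name="")

    wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
    import_graph = wrapped_import.graph
    # Keep only the part of the graph between the given inputs and outputs.
    return wrapped_import.prune(
        tf.nest.map_structure(import_graph.as_graph_element, inputs),
        tf.nest.map_structure(import_graph.as_graph_element, outputs))


def toy_graph_def():
    """Constants-only stand-in for a frozen model."""
    graph = tf.Graph()
    with graph.as_default():
        x = tf.compat.v1.placeholder(tf.float32, shape=[None, 2], name="input")
        w = tf.constant([[2.0, 0.0], [0.0, 1.0]])
        with tf.compat.v1.name_scope("outputLayer"):
            tf.nn.softmax(tf.matmul(x, w), name="Softmax")
    return graph.as_graph_def()


if __name__ == "__main__":
    predict_fn = wrap_frozen_graph(
        toy_graph_def(), inputs="input:0", outputs="outputLayer/Softmax:0")
    print(predict_fn(tf.constant([[1.0, 1.0]])).numpy())
```

With a real model, the GraphDef would be the one parsed from the .pb file exactly as in the code above.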
1 Comment
DecodeError: Error parsing message with type 'tensorflow.GraphDef'
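This error means the bytes in the file are not a binary-serialized GraphDef. Two common reasons are that the file is in protobuf text format (often saved as `.pbtxt`), or that it is a SavedModel's `saved_model.pb`, which holds a different message type and is meant to be loaded with `tf.saved_model.load`. A minimal sketch of a loader that falls back to text format, assuming TensorFlow 2; `read_graph_def` and `toy_graph_def` are hypothetical helpers, and the fallback does not help in the SavedModel case.

```python
import os
import tempfile

import tensorflow as tf
from google.protobuf import message, text_format


def read_graph_def(path):
    """Parse path as a binary GraphDef, falling back to protobuf text format."""
    with tf.io.gfile.GFile(path, "rb") as f:
        data = f.read()
    graph_def = tf.compat.v1.GraphDef()
    try:
        graph_def.ParseFromString(data)
    except message.DecodeError:
        # Not binary GraphDef; try text format. If this also raises, the file
        # may be a SavedModel or otherwise not a GraphDef at all.
        graph_def = tf.compat.v1.GraphDef()
        text_format.Parse(data.decode("utf-8"), graph_def)
    return graph_def


def toy_graph_def():
    """Tiny graph used only to produce example files."""
    graph = tf.Graph()
    with graph.as_default():
        tf.constant([1.0, 2.0], name="weights")
    return graph.as_graph_def()


if __name__ == "__main__":
    gd = toy_graph_def()
    with tempfile.TemporaryDirectory() as tmp:
        binary_path = os.path.join(tmp, "model.pb")
        text_path = os.path.join(tmp, "model.pbtxt")
        with open(binary_path, "wb") as f:
            f.write(gd.SerializeToString())
        with open(text_path, "w") as f:
            f.write(text_format.MessageToString(gd))
        print([n.name for n in read_graph_def(binary_path).node])
        print([n.name for n in read_graph_def(text_path).node])
```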