Noltibus

Reputation: 1360

How to get only the first n layers of a network loaded from a .pb file

What I want: A protobuf file which contains all the layers of a pretrained AlexNet up until the pool5 layer.

What I have: I downloaded the AlexNet weights file here and, using this code, converted it into a protobuf file of the model and a frozen protobuf file. I loaded the resulting protobuf file with this code:

import tensorflow as tf
from tensorflow.python.platform import gfile
from tensorflow.contrib import graph_editor as editor

GRAPH_PB_PATH = 'alexnet.pb'
with tf.Session() as sess:
    print("load graph")
    # Parse the serialized GraphDef while the file is still open
    with gfile.FastGFile(GRAPH_PB_PATH, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    # Import the nodes into the session's (already default) graph
    tf.import_graph_def(graph_def, name='')
    # Write the graph so it can be inspected in TensorBoard
    writer = tf.summary.FileWriter('logs', sess.graph)
    writer.close()
    # List the names of all nodes in the graph
    names = [n.name for n in graph_def.node]
    print(names)

Now I want to throw away all layers after the pool5 layer, so that the input of the network is an image and the output is whatever pool5 returns (i.e. some vector). I would like to save the resulting, much smaller network to a protobuf file again. So how do I delete the unnecessary layers? Thanks in advance!
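Keeping everything from the image input up to pool5, and nothing else, amounts to keeping pool5 plus every node it transitively depends on. Below is a minimal pure-Python sketch of that idea. The `Node` tuple, `node_name` and `prune_to` are hypothetical illustration helpers, not TensorFlow APIs; the sketch assumes each node exposes a `.name` and a list of `.input` references written in the GraphDef style (`"^x"` for a control dependency, `"x:1"` for a specific output of `x`).

```python
from collections import namedtuple

# Hypothetical toy stand-in for a NodeDef: only the two fields pruning needs.
Node = namedtuple("Node", ["name", "input"])


def node_name(input_ref):
    """Turn an input reference like '^x' or 'x:1' into the node name 'x'."""
    name = input_ref[1:] if input_ref.startswith("^") else input_ref
    return name.split(":")[0]


def prune_to(nodes, output_name):
    """Keep output_name and every node it transitively depends on.

    Follows input references backwards, so the order of `nodes` does not
    matter; the surviving nodes are returned in their original order.
    Raises KeyError if output_name (or a referenced input) is not present.
    """
    by_name = {n.name: n for n in nodes}
    keep = set()
    stack = [output_name]
    while stack:
        name = stack.pop()
        if name in keep:
            continue
        keep.add(name)
        stack.extend(node_name(ref) for ref in by_name[name].input)
    return [n for n in nodes if n.name in keep]


# Toy graph: fc6's weights are stored before pool5 but are not needed by it.
nodes = [
    Node("input", []),
    Node("conv1/weights", []),
    Node("fc6/weights", []),
    Node("conv1", ["input", "conv1/weights"]),
    Node("pool5", ["conv1"]),
    Node("fc6", ["pool5", "fc6/weights"]),
    Node("prob", ["^fc6", "fc6:0"]),
]
print([n.name for n in prune_to(nodes, "pool5")])
```

Because the walk follows inputs rather than list position, `fc6/weights` is dropped even though it appears before pool5. A real `NodeDef` also has `.name` and `.input`, so `prune_to(graph_def.node, 'pool5')` should presumably work on the loaded graph as well, and TensorFlow 1.x ships `tf.graph_util.extract_sub_graph(graph_def, ['pool5'])` for this kind of pruning.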

Upvotes: 0

Views: 1197

Answers (1)

schil

Reputation: 322

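import tensorflow as tf

# Load the frozen graph
graph_def = tf.GraphDef()
with open('alexnet.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name='')

new_model = tf.GraphDef()

with tf.Session(graph=graph) as sess:
    # Copy nodes in order and stop right after pool5. This assumes the nodes
    # are stored in topological order (every input before its consumers).
    for n in sess.graph_def.node:
        nn = new_model.node.add()
        nn.CopyFrom(n)
        if n.name == 'pool5':  # n.op is the op type string; the node name is n.name
            break

# Serialize the truncated graph as a binary protobuf
tf.train.write_graph(new_model, '.', 'cut_model.pb', as_text=False)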
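This loop keeps a prefix of the node list, so the result is only a valid graph if every node up to pool5 depends solely on nodes inside that prefix. A minimal pure-Python sketch for checking that assumption is below; `cut_after` and `node_name` are hypothetical helpers, and a toy `Node` tuple stands in for `NodeDef`, which also exposes `.name` and `.input`.

```python
from collections import namedtuple

# Hypothetical toy stand-in for a NodeDef.
Node = namedtuple("Node", ["name", "input"])


def node_name(input_ref):
    """Turn an input reference like '^x' or 'x:1' into the node name 'x'."""
    name = input_ref[1:] if input_ref.startswith("^") else input_ref
    return name.split(":")[0]


def cut_after(nodes, output_name):
    """Copy nodes in order up to and including output_name (as the loop
    above does) and report (node, input) pairs whose input was cut off."""
    kept, seen = [], set()
    for n in nodes:
        kept.append(n)
        seen.add(n.name)
        if n.name == output_name:
            break
    else:
        raise ValueError(f"{output_name!r} not found")
    missing = [
        (n.name, node_name(ref))
        for n in kept
        for ref in n.input
        if node_name(ref) not in seen
    ]
    return kept, missing


ordered = [Node("input", []), Node("conv1", ["input"]),
           Node("pool5", ["conv1"]), Node("fc6", ["pool5"])]
unordered = [Node("input", []), Node("pool5", ["conv1:0"]),
             Node("conv1", ["input"])]
print(cut_after(ordered, "pool5")[1])
print(cut_after(unordered, "pool5")[1])
```

An empty `missing` list means the copied prefix is self-contained; a non-empty one means the truncated graph would reference nodes that were cut off. Note that a valid prefix can still carry nodes pool5 does not need, such as weights of later layers that happen to be stored earlier.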

Upvotes: 1
