
[Memo] TensorFlow: how to save/restore a model?

In (and after) TensorFlow version 0.11:

Save the model:

import tensorflow as tf
# Prepare to feed input, i.e. feed_dict and placeholders
w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1 = tf.Variable(2.0, name="bias")
feed_dict = {w1: 4, w2: 8}
# Define a test operation that we will restore
w3 = tf.add(w1, w2)
w4 = tf.multiply(w3, b1, name="op_to_restore")
sess = tf.Session()
# Variables must be initialized before they can be used or saved
sess.run(tf.global_variables_initializer())
# Create a saver object which will save all the variables
saver = tf.train.Saver()
# Run the operation by feeding input
print(sess.run(w4, feed_dict))
# Prints 24.0, which is (w1 + w2) * b1
# Now, save the graph
saver.save(sess, 'my_test_model', global_step=1000)

Restore the model:

import tensorflow as tf
sess = tf.Session()
# First, load the meta graph and restore the weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))
# Access saved variables directly
print(sess.run('bias:0'))
# This will print 2.0, which is the value of bias that we saved
# Now, look up the placeholders and
# create a feed_dict to feed new data
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict = {w1: 13.0, w2: 17.0}
# Now, access the op that you want to run
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")
print(sess.run(op_to_restore, feed_dict))
# This will print 60.0, which is (13 + 17) * 2 using the restored bias

This and some more advanced use-cases have been explained very well here.
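
The name passed to import_meta_graph above follows from the arguments given to saver.save: the checkpoint prefix, a hyphen, the global_step, and a .meta suffix. A small helper makes that naming explicit, which can be handy when the step number changes between runs. Note that meta_graph_path is a hypothetical function written for this memo, not part of TensorFlow, and it only mirrors the file names seen in these examples.

```python
def meta_graph_path(prefix, global_step=None):
    """Return the .meta file name matching saver.save(sess, prefix, global_step).

    Hypothetical helper for illustration; it only reproduces the naming
    pattern used in the examples, it does not touch the file system.
    """
    if global_step is None:
        return prefix + ".meta"
    return "{}-{}.meta".format(prefix, global_step)


print(meta_graph_path("my_test_model", 1000))  # my_test_model-1000.meta
print(meta_graph_path("/tmp/my_great_model"))  # /tmp/my_great_model.meta
```
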




For TensorFlow version < 0.11.0RC1:

The checkpoints that are saved contain values for the Variables in your model, not the model/graph itself, which means that the graph should be the same when you restore the checkpoint.
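
As a rough mental model, a checkpoint of this kind behaves like a mapping from variable names to values. The plain-Python sketch that follows is an analogy only, not TensorFlow code; save_checkpoint, restore_checkpoint, and the variable names are hypothetical. It shows why the model has to be rebuilt with the same variables before restoring.

```python
def save_checkpoint(variables):
    """Snapshot variable values by name; nothing about the graph is stored."""
    return dict(variables)


def restore_checkpoint(variables, checkpoint):
    """Copy saved values into an already-built set of variables.

    Raises KeyError if the rebuilt model has a variable the checkpoint lacks.
    """
    missing = [name for name in variables if name not in checkpoint]
    if missing:
        raise KeyError("not found in checkpoint: {}".format(missing))
    for name in variables:
        variables[name] = checkpoint[name]


# "Training" run: two named variables with learned values.
trained = {"w": 0.7, "b": 1.3}
ckpt = save_checkpoint(trained)

# Later run: rebuild the same variables first, then restore values into them.
same_model = {"w": 0.0, "b": 1.0}
restore_checkpoint(same_model, ckpt)
print(same_model)

# A model with a different variable set cannot be restored from this checkpoint.
other_model = {"w": 0.0, "bias": 1.0}
try:
    restore_checkpoint(other_model, ckpt)
except KeyError as err:
    print("restore failed:", err)
```
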

Here's an example for a linear regression where there's a training loop that saves variable checkpoints and an evaluation section that will restore variables saved in a prior run and compute predictions. Of course, you can also restore variables and continue training if you'd like.

x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)
w = tf.Variable(tf.zeros([1, 1], dtype=tf.float32))
b = tf.Variable(tf.ones([1, 1], dtype=tf.float32))
y_hat = tf.add(b, tf.matmul(x, w))
...more setup for optimization and what not...
saver = tf.train.Saver()  # defaults to saving all variables - in this case w and b
with tf.Session() as sess:
    if FLAGS.train:
        for i in xrange(FLAGS.training_steps):
            ...training loop...
            if (i + 1) % FLAGS.checkpoint_steps == 0:
                saver.save(sess, FLAGS.checkpoint_dir + 'model.ckpt',
                           global_step=i + 1)
    else:
        # Here's where you're restoring the variables w and b.
        # Note that the graph is exactly as it was when the variables were
        # saved in a prior training run.
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            ...no checkpoint found...
        # Now you can run the model to get predictions
        batch_x = ...load some data...
        predictions = sess.run(y_hat, feed_dict={x: batch_x})

Here are the docs for Variables, which cover saving and restoring. And here are the docs for the Saver.



You can also take this easier way.

Step 1 - Initialize all your variables

W1 = tf.Variable(tf.truncated_normal([6, 6, 1, K], stddev=0.1), name="W1")
B1 = tf.Variable(tf.constant(0.1, tf.float32, [K]), name="B1")

Similarly, W2, B2, W3, ...

Step 2 - Create a Saver for the variables and save the model

model_saver = tf.train.Saver()

# Train the model and save it at the end
model_saver.save(session, "saved_models/CNN_New.ckpt")

Step 3 - Restore the model

with tf.Session(graph=graph_cnn) as session:
    model_saver.restore(session, "saved_models/CNN_New.ckpt")
    print("Model restored.")

Step 4 - Check a variable (still inside the session block from Step 3)

    W1_value = session.run(W1)


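When running in a different Python instance, use

with tf.Session() as sess:
    # Rebuild the graph from the .meta file written next to the checkpoint
    # (assumes the model was saved as "saved_models/CNN_New.ckpt" as above)
    saver = tf.train.import_meta_graph('saved_models/CNN_New.ckpt.meta')
    # Restore the latest checkpoint; restored variables need no further
    # initialization (running an initializer now would overwrite them)
    saver.restore(sess, tf.train.latest_checkpoint('saved_models/'))
    # Get default graph (supply your custom graph if you have one)
    graph = tf.get_default_graph()
    # This gives the tensor object
    W1 = graph.get_tensor_by_name('W1:0')
    # To get the value (numpy array)
    W1_value = sess.run(W1)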



In most cases, saving and restoring from disk using a tf.train.Saver is your best option:

... # build your model
saver = tf.train.Saver()
with tf.Session() as sess:
    ... # train the model
    saver.save(sess, "/tmp/my_great_model")
with tf.Session() as sess:
    saver.restore(sess, "/tmp/my_great_model")
    ... # use the model

You can also save/restore the graph structure itself (see the MetaGraph documentation for details). By default, the Saver saves the graph structure into a .meta file. You can call import_meta_graph() to restore it. It restores the graph structure and returns a Saver that you can use to restore the model's state:

saver = tf.train.import_meta_graph("/tmp/my_great_model.meta")
with tf.Session() as sess:
    saver.restore(sess, "/tmp/my_great_model")
    ... # use the model

However, there are cases where you need something much faster. For example, if you implement early stopping, you want to save checkpoints every time the model improves during training (as measured on the validation set), then if there is no progress for some time, you want to roll back to the best model. If you save the model to disk every time it improves, it will tremendously slow down training. The trick is to save the variable states to memory, then just restore them later:

... # build your model
# get a handle on the graph nodes we need to save/restore the model
graph = tf.get_default_graph()
gvars = graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
assign_ops = [graph.get_operation_by_name(v.op.name + "/Assign") for v in gvars]
init_values = [assign_op.inputs[1] for assign_op in assign_ops]
with tf.Session() as sess:
    ... # train the model
    # when needed, save the model state to memory
    gvars_state = sess.run(gvars)
    # when needed, restore the model state
    feed_dict = {init_value: val
                 for init_value, val in zip(init_values, gvars_state)}
    sess.run(assign_ops, feed_dict=feed_dict)

A quick explanation: when you create a variable X, TensorFlow automatically creates an assignment operation X/Assign to set the variable's initial value. Instead of creating placeholders and extra assignment ops (which would just make the graph messy), we reuse these existing assignment ops. The first input of each assignment op is a reference to the variable it is supposed to initialize, and the second input (assign_op.inputs[1]) is the initial value. So to set any value we want (instead of the initial value), we use a feed_dict to override the initial value. This works because TensorFlow lets you feed a value for any tensor in the graph, not just for placeholders.
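
The early-stopping bookkeeping around that trick can be sketched without any framework. In this minimal sketch, copying a dict stands in for sess.run(gvars), and writing the copy back stands in for sess.run(assign_ops, feed_dict=feed_dict). The function name, the params dict, and the train_step and validate callables are all hypothetical, and patience (the number of non-improving steps tolerated) is an assumed stopping rule, since the answer only says "no progress for some time".

```python
import copy


def train_with_early_stopping(params, train_step, validate, max_steps, patience):
    """Train, keep the best parameters in memory, and roll back at the end.

    train_step(params) updates params in place; validate(params) returns a
    validation loss where lower is better. Both are hypothetical callables.
    """
    best_loss = float("inf")
    best_state = copy.deepcopy(params)  # plays the role of sess.run(gvars)
    steps_without_progress = 0

    for _ in range(max_steps):
        train_step(params)
        loss = validate(params)
        if loss < best_loss:
            best_loss = loss
            best_state = copy.deepcopy(params)  # in-memory "checkpoint"
            steps_without_progress = 0
        else:
            steps_without_progress += 1
            if steps_without_progress >= patience:
                break

    # Roll back to the best state; plays the role of running the assign ops.
    params.clear()
    params.update(best_state)
    return best_loss


# Toy problem: w grows by 1 each step, and the validation loss is lowest at w = 3.
params = {"w": 0.0}


def train_step(p):
    p["w"] += 1.0


def validate(p):
    return (p["w"] - 3.0) ** 2


best = train_with_early_stopping(params, train_step, validate,
                                 max_steps=20, patience=2)
print(params, best)  # {'w': 3.0} 0.0
```

Training overshoots to w = 5 before the two non-improving steps trigger the stop, and the rollback returns the parameters to the best state seen, without any disk access.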


