TensorFlow Model Export

In order to deploy trained machine learning models to various target platforms (e.g. servers, mobile devices, embedded devices and browsers), our first step is usually to export (serialize) the entire trained model into a set of files in a standard format. TensorFlow provides a unified model export format, SavedModel, which allows us to deploy trained models on a variety of platforms using this format as an intermediary. It is the main export format we use in TensorFlow 2. In addition, for historical reasons, Keras's Sequential and Functional models have their own export format, which we will also introduce later.

Export models with SavedModel

In the previous section we introduced Checkpoint, which saves and restores the weights of a model. SavedModel, as a model export format, goes one step further and contains complete information about a TensorFlow program: not only the weights of the model, but also the computation process (i.e. the dataflow graph). When a model is exported as SavedModel files, it can be run again without the source code that built it, which makes SavedModel especially suitable for model sharing and deployment. We will use this format later for TensorFlow Serving (server-side deployment), TensorFlow Lite (mobile deployment) and TensorFlow.js (browser deployment).

All Keras models can easily be exported to the SavedModel format. Note, however, that since SavedModel is based on graph execution, any method that needs to be exported (e.g. call) must be decorated with @tf.function (see the :ref:`previous section <tffunction>` for the usage of @tf.function; models built with the Sequential or Functional API do not require this decoration). Then, assuming we have a Keras model named model, it can be exported as SavedModel with the following code.

tf.saved_model.save(model, "target export folder")

When you need to load a SavedModel file, use

model = tf.saved_model.load("target export folder")
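
The object returned by tf.saved_model.load also exposes the serving signatures traced at export time, which is how deployment tools such as TensorFlow Serving invoke the model. A minimal sketch for inspecting them (the attributes below belong to the ConcreteFunction API; the exact input and output keys depend on the model):

import tensorflow as tf

loaded = tf.saved_model.load("target export folder")
# Keras models exported with tf.saved_model.save carry a default serving signature
infer = loaded.signatures["serving_default"]
print(infer.structured_input_signature)   # expected input shapes and dtypes
print(infer.structured_outputs)           # output tensor specification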

Hint

For a Keras model that inherits from the tf.keras.Model class, the instance loaded with tf.saved_model.load does not support direct inference via model(); the model.call() method must be used instead.
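
For illustration, a minimal sketch of this behaviour (the folder saved/2 and the input shape are hypothetical, matching the subclassed MNIST model shown later in this section):

import tensorflow as tf

loaded = tf.saved_model.load("saved/2")   # hypothetical export of a subclassed model
x = tf.random.uniform([1, 28, 28, 1])     # dummy input matching the traced shape
# y_pred = loaded(x)                      # fails: the loaded instance is not directly callable
y_pred = loaded.call(x)                   # works: call() was traced by @tf.function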

Here is a simple example of exporting and importing the model from the previous MNIST digit classification task.

Export the model to the saved/1 folder.

import tensorflow as tf
from zh.model.utils import MNISTLoader

num_epochs = 1
batch_size = 50
learning_rate = 0.001

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(100, activation=tf.nn.relu),
    tf.keras.layers.Dense(10),
    tf.keras.layers.Softmax()
])

data_loader = MNISTLoader()
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
    loss=tf.keras.losses.sparse_categorical_crossentropy,
    metrics=[tf.keras.metrics.sparse_categorical_accuracy]
)
model.fit(data_loader.train_data, data_loader.train_label, epochs=num_epochs, batch_size=batch_size)
tf.saved_model.save(model, "saved/1")
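
If you want a quick sanity check that the export succeeded, tf.saved_model.contains_saved_model reports whether a directory holds a valid SavedModel (an optional check, not required for deployment):

import os
import tensorflow as tf

# True if the folder contains a valid saved_model.pb plus its variables
print(tf.saved_model.contains_saved_model("saved/1"))
print(os.listdir("saved/1"))   # typically: ['saved_model.pb', 'variables', 'assets']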

Import and test the performance of the exported model in saved/1.

import tensorflow as tf
from zh.model.utils import MNISTLoader

batch_size = 50

model = tf.saved_model.load("saved/1")
data_loader = MNISTLoader()
sparse_categorical_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
num_batches = int(data_loader.num_test_data // batch_size)
for batch_index in range(num_batches):
    start_index, end_index = batch_index * batch_size, (batch_index + 1) * batch_size
    y_pred = model(data_loader.test_data[start_index: end_index])
    sparse_categorical_accuracy.update_state(y_true=data_loader.test_label[start_index: end_index], y_pred=y_pred)
print("test accuracy: %f" % sparse_categorical_accuracy.result())

Output:

test accuracy: 0.952000

Keras models that inherit from the tf.keras.Model class can be exported in the same way; just note that the call method must be decorated with @tf.function so that it can be converted into the dataflow graph that SavedModel requires. The following code is an example.

class MLP(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.flatten = tf.keras.layers.Flatten()
        self.dense1 = tf.keras.layers.Dense(units=100, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(units=10)

    @tf.function
    def call(self, inputs):         # [batch_size, 28, 28, 1]
        x = self.flatten(inputs)    # [batch_size, 784]
        x = self.dense1(x)          # [batch_size, 100]
        x = self.dense2(x)          # [batch_size, 10]
        output = tf.nn.softmax(x)
        return output

model = MLP()
...
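
Training proceeds exactly as in the Sequential example above (elided here); the export call itself is unchanged, with saved/2 as an illustrative target folder:

tf.saved_model.save(model, "saved/2")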

The import process is the same, except that model inference now requires an explicit call to the call method, i.e.

y_pred = model.call(data_loader.test_data[start_index: end_index])
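
As a final note: by default, the exported call only accepts the input shapes and dtypes that were traced while the model was being used. If you want to pin these down explicitly (e.g. to guarantee an arbitrary batch size at serving time), you can pass an input_signature to @tf.function. A sketch, assuming MNIST-shaped float inputs:

import tensorflow as tf

class MLP(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.flatten = tf.keras.layers.Flatten()
        self.dense1 = tf.keras.layers.Dense(units=100, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(units=10)

    # Fixing the signature makes the exported call accept any batch size
    @tf.function(input_signature=[tf.TensorSpec([None, 28, 28, 1], tf.float32)])
    def call(self, inputs):
        x = self.flatten(inputs)
        x = self.dense1(x)
        return tf.nn.softmax(self.dense2(x))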