TensorFlow Model Export
To deploy trained machine learning models to various target platforms (e.g. servers, mobile and embedded devices, browsers), the first step is often to export (serialize) the entire trained model into a set of files in a standard format. TensorFlow provides a unified model export format, SavedModel, which serves as an intermediary for deploying trained models on a variety of platforms. It is the main export format in TensorFlow 2. In addition, for historical reasons, Keras Sequential and Functional models have their own export formats, which are introduced later.
Exporting models with SavedModel
In the previous section we introduced Checkpoint, which saves and restores the weights of a model. SavedModel, as a model export format, goes one step further and contains complete information about a TensorFlow program: not only the weights of the model, but also the computation process (i.e., the dataflow graph). Once a model is exported as a SavedModel, it can be run again without the original source code, which makes SavedModel especially suitable for model sharing and deployment. This format is used later with TensorFlow Serving (server-side deployment), TensorFlow Lite (mobile deployment), and TensorFlow.js.
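As a rough illustration of "weights plus computation graph", the following minimal sketch exports a toy module and loads it back. The class name Scale, the weight value, and the temporary export path are assumptions made for this example only. It uses tf.Module rather than a Keras model to keep the sketch small.

```python
import os
import tempfile

import tensorflow as tf


class Scale(tf.Module):
    """Hypothetical toy module: multiplies its input by a stored weight."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(3.0)

    # The input signature tells TensorFlow which graph to trace and export.
    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
    def __call__(self, x):
        return self.w * x


# Assumed location: a fresh temporary folder.
export_dir = os.path.join(tempfile.mkdtemp(), "scale")
tf.saved_model.save(Scale(), export_dir)

# A SavedModel is a directory: the graph is stored in saved_model.pb
# and the weights in the variables/ subfolder (among other entries).
exported_files = sorted(os.listdir(export_dir))

# Loading rebuilds a callable object from the files alone.
restored = tf.saved_model.load(export_dir)
result = restored(tf.constant([1.0, 2.0])).numpy()
print(exported_files)
print(result)
```

The restored object is reconstructed from the exported directory, not from the Scale class, which is what makes the format convenient for sharing and deployment.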
All Keras models can be easily exported to the SavedModel format. Note, however, that because SavedModel is based on graph execution, any method (e.g. call) that needs to be exported must be decorated with @tf.function (see the earlier section on @tf.function for its usage; models built with the Sequential or Functional API do not need the decoration). Then, assuming we have a Keras model named model, it can be exported as a SavedModel with the following code:
tf.saved_model.save(model, "target export folder")
When you need to load a SavedModel file, use
model = tf.saved_model.load("target export folder")
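To make the role of @tf.function in export more concrete, here is a minimal sketch under stated assumptions: Doubler, double, and triple are hypothetical names, the export path is a temporary folder, and a plain tf.Module stands in for a Keras model. Only the decorated method is traced into a graph and written to the SavedModel.

```python
import os
import tempfile

import tensorflow as tf


class Doubler(tf.Module):
    """Hypothetical module with one exported and one plain Python method."""

    @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.float32)])
    def double(self, x):
        # Decorated: traced into a graph and written to the SavedModel.
        return 2.0 * x

    def triple(self, x):
        # Not decorated: plain Python, so it is not part of the export.
        return 3.0 * x


# Assumed location: a fresh temporary folder.
export_dir = os.path.join(tempfile.mkdtemp(), "doubler")
tf.saved_model.save(Doubler(), export_dir)

restored = tf.saved_model.load(export_dir)
doubled = float(restored.double(tf.constant(4.0)))
has_double = hasattr(restored, "double")
has_triple = hasattr(restored, "triple")
print(doubled, has_double, has_triple)
```

After loading, the decorated method is available on the restored object, while the undecorated one is missing, since it never became part of a dataflow graph.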
For Keras models that subclass tf.keras.Model, the instance loaded from a SavedModel does not allow direct inference with model(); model.call() must be used instead.
Here is a simple example of exporting and importing the model from the earlier MNIST digit classification task.
Export the model to the saved/1 folder:
import tensorflow as tf
from zh.model.utils import MNISTLoader

num_epochs = 1
batch_size = 50
learning_rate = 0.001

# A Sequential model needs no @tf.function decoration before export.
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(100, activation=tf.nn.relu),
    tf.keras.layers.Dense(10),
    tf.keras.layers.Softmax()
])

data_loader = MNISTLoader()
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
    loss=tf.keras.losses.sparse_categorical_crossentropy,
    metrics=[tf.keras.metrics.sparse_categorical_accuracy]
)
model.fit(data_loader.train_data, data_loader.train_label, epochs=num_epochs, batch_size=batch_size)

# Export the trained model (weights and graph) to the saved/1 folder.
tf.saved_model.save(model, "saved/1")
Import the exported model from saved/1 and test its performance:
import tensorflow as tf
from zh.model.utils import MNISTLoader

batch_size = 50

# Load the exported model; the model-building code is not needed here.
model = tf.saved_model.load("saved/1")
data_loader = MNISTLoader()
sparse_categorical_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
num_batches = int(data_loader.num_test_data // batch_size)
for batch_index in range(num_batches):
    start_index, end_index = batch_index * batch_size, (batch_index + 1) * batch_size
    y_pred = model(data_loader.test_data[start_index: end_index])
    # Accumulate accuracy statistics over all test batches.
    sparse_categorical_accuracy.update_state(y_true=data_loader.test_label[start_index: end_index], y_pred=y_pred)
print("test accuracy: %f" % sparse_categorical_accuracy.result())
test accuracy: 0.952000
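The batched evaluation loop relies on the metric object accumulating statistics across calls to update_state. A minimal sketch of that behavior, using hypothetical toy data (three samples, two classes) in place of MNIST and the loaded model:

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy data: (labels, predicted probabilities) for two batches.
batches = [
    (np.array([0, 1]), np.array([[0.9, 0.1], [0.8, 0.2]], dtype=np.float32)),
    (np.array([1]), np.array([[0.1, 0.9]], dtype=np.float32)),
]

metric = tf.keras.metrics.SparseCategoricalAccuracy()
for y_true, y_pred in batches:
    # Each call adds this batch's correct/total counts to the running totals.
    metric.update_state(y_true=y_true, y_pred=y_pred)

# result() reports accuracy over every sample seen so far.
accuracy = float(metric.result())
print("toy accuracy: %f" % accuracy)
```

Two of the three toy predictions match their labels, so the accumulated accuracy is about 0.667, regardless of how the samples are split into batches.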
Keras models that subclass tf.keras.Model can be exported in the same way, but note that the call method must be decorated with @tf.function so that it can be translated into a dataflow graph supported by SavedModel. The following code is an example:
class MLP(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.flatten = tf.keras.layers.Flatten()
        self.dense1 = tf.keras.layers.Dense(units=100, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(units=10)

    @tf.function
    def call(self, inputs):         # [batch_size, 28, 28, 1]
        x = self.flatten(inputs)    # [batch_size, 784]
        x = self.dense1(x)          # [batch_size, 100]
        x = self.dense2(x)          # [batch_size, 10]
        output = tf.nn.softmax(x)
        return output

model = MLP()
...
The process of importing the model is the same, except that inference requires an explicit call to the call method, i.e.:
y_pred = model.call(data_loader.test_data[start_index: end_index])