Import a JAX model using JAX2TF

View on TensorFlow.org Run in Google Colab View source on GitHub Download notebook

This notebook provides a complete, runnable example of creating a model using JAX and bringing it into TensorFlow to continue training. This is made possible by JAX2TF, a lightweight API that provides a pathway from the JAX ecosystem to the TensorFlow ecosystem.

JAX is a high-performance array computing library. To create the model, this notebook uses Flax, a neural network library for JAX. To train it, it uses Optax, an optimization library for JAX.

If you're a researcher using JAX, JAX2TF gives you a path to production using TensorFlow's proven tools.

There are many ways this can be useful. Here are just a few:

  • Inference: Taking a model written for JAX and deploying it either on a server using TF Serving, on-device using TFLite, or on the web using TensorFlow.js.

  • Fine-tuning: Taking a model that was trained using JAX, bringing its components to TensorFlow using JAX2TF, and continuing to train it in TensorFlow with your existing training data and setup.

  • Fusion: Combining parts of models that were trained using JAX with those trained using TensorFlow, for maximum flexibility.

The key to enabling this kind of interoperation between JAX and TensorFlow is jax2tf.convert, which takes in model components created on top of JAX (your loss function, prediction function, etc) and creates equivalent representations of them as TensorFlow functions, which can then be exported as a TensorFlow SavedModel.

Setup

import tensorflow as tf
import numpy as np
import jax
import jax.numpy as jnp
import flax
import optax
import os
from matplotlib import pyplot as plt
from jax.experimental import jax2tf
from threading import Lock # Only used in the visualization utility.
from functools import partial
# TensorFlow and JAX share the same GPUs in this notebook, so neither library
# may grab all of the GPU memory up front.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = "false"
gpus = tf.config.list_physical_devices('GPU')
try:
  # Looping over an empty list is a no-op, so CPU-only machines skip this.
  for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as err:
  # Memory growth must be set before GPUs have been initialized; report the
  # problem and carry on instead of aborting the notebook.
  print(err)

Visualization utilities

# Default size, as (width, height) in inches, for every figure created below.
plt.rcParams["figure.figsize"] = (20,8)

# The utility for displaying training and validation curves.
def display_train_curves(loss, avg_loss, eval_loss, eval_accuracy, epochs, steps_per_epochs, ignore_first_n=10):
  """Plot the loss curves and the evaluation accuracy in two side-by-side panels.

  Args:
    loss: Optional per-step training losses. Pass None to draw only the
      per-epoch curves.
    avg_loss: Average training loss, one value per epoch.
    eval_loss: Evaluation loss, one value per epoch.
    eval_accuracy: Evaluation accuracy, one value per epoch.
    epochs: Number of epochs; per-epoch curves are drawn at x = 1..epochs.
    steps_per_epochs: Number of training steps in one epoch, used to place the
      per-step losses on the epoch axis.
    ignore_first_n: Number of initial training steps left out when choosing
      the y-axis range, so that early outliers do not flatten the curves. The
      skipped values are still plotted.
  """

  def scaling_window(values, skip):
    # Values used to pick the y-range. When the history is shorter than the
    # skip window the slice is empty and np.min/np.max would raise, so fall
    # back to the whole series.
    window = values[skip:]
    return window if len(window) > 0 else values

  ignore_first_n_epochs = int(ignore_first_n/steps_per_epochs)
  epoch_axis = range(1, epochs+1)

  # The losses.
  ax = plt.subplot(121)
  if loss is not None:
    # Express step indices in (fractional) epochs to share the x axis.
    x = np.arange(len(loss)) / steps_per_epochs
    ax.plot(x, loss)
  ax.plot(epoch_axis, avg_loss, "-o", linewidth=3)
  ax.plot(epoch_axis, eval_loss, "-o", linewidth=3)
  ax.set_title('Loss')
  ax.set_ylabel('loss')
  ax.set_xlabel('epoch')
  if loss is not None:
    ax.set_ylim(0, np.max(scaling_window(loss, ignore_first_n)))
    ax.legend(['train', 'avg train', 'eval'])
  else:
    scaled = scaling_window(avg_loss, ignore_first_n_epochs)
    ymin = np.min(scaled)
    ymax = np.max(scaled)
    # Pad the range by 10% on each side so the curve does not touch the frame.
    ax.set_ylim(ymin-(ymax-ymin)/10, ymax+(ymax-ymin)/10)
    ax.legend(['avg train', 'eval'])

  # The accuracy.
  ax = plt.subplot(122)
  ax.set_title('Eval Accuracy')
  ax.set_ylabel('accuracy')
  ax.set_xlabel('epoch')
  scaled = scaling_window(eval_accuracy, ignore_first_n_epochs)
  ymin = np.min(scaled)
  ymax = np.max(scaled)
  ax.set_ylim(ymin-(ymax-ymin)/10, ymax+(ymax-ymin)/10)
  ax.plot(epoch_axis, eval_accuracy, "-o", linewidth=3)

class Progress:
    """Text mode progress bar.
    Usage:
            p = Progress(30)
            p.step()
            p.step()
            p.step(reset=True) # to restart from 0%
    The progress bar displays a new header at each restart."""
    def __init__(self, maxi, size=100, msg=""):
        """
        :param maxi: the number of steps required to reach 100%
        :param size: the number of characters taken on the screen by the progress bar
        :param msg: the message displayed in the header of the progress bar
        """
        self.maxi = maxi
        self.msg = msg
        self.size = size
        self.lock = Lock()
        self.__restart()

    def __restart(self):
        """Rewind the bar to 0% and schedule the header to be printed again."""
        # `()`: to get the iterator from the generator.
        self.p = self.__start_progress(self.maxi)()
        self.header_printed = False

    def step(self, reset=False):
        """Advance the bar by one step, printing the header first if needed.

        :param reset: when True, restart the bar from 0% before stepping
        """
        with self.lock:
            if reset:
                # Only the drawing state is reset. The lock itself must stay
                # the same object: replacing it while it is held would let
                # other threads enter this section concurrently.
                self.__restart()
            if not self.header_printed:
                self.__print_header()
            next(self.p)

    def __print_header(self):
        """Print a blank line followed by the `0% ... 100%` header line."""
        print()
        # "0%" and "100%" take 6 characters; the message is centered in the rest.
        format_string = "0%{: ^" + str(self.size - 6) + "}100%"
        print(format_string.format(self.msg))
        self.header_printed = True

    def __start_progress(self, maxi):
        """Return a generator function that draws the bar over `maxi` steps."""
        def print_progress():
            # Bresenham's algorithm. Yields the number of characters printed.
            # This prints `self.size` '=' characters in total over `maxi`
            # invocations, spread as evenly as possible.
            dx = maxi
            dy = self.size
            d = dy - dx
            # Bound up front so the trailing loop also works when `maxi` is 0
            # and the `for` body never runs.
            k = 0
            for x in range(maxi):
                k = 0
                while d >= 0:
                    print('=', end="", flush=True)
                    k += 1
                    d -= dx
                d += dy
                yield k
            # Keep yielding the last result if there are too many steps.
            while True:
                yield k

        return print_progress