Keras + JAX smoke test

Keras + JAX smoke test

import os

# Select the JAX backend. Keras reads KERAS_BACKEND when it is first imported,
# so this assignment must stay above `import keras`.
os.environ["KERAS_BACKEND"] = "jax"

import jax
import keras
import numpy as np

import kinetic


@kinetic.run(accelerator="cpu")
def train_keras_jax_model():
  """Smoke-test Keras on the JAX backend by fitting a small CNN on random data.

  Prints the Keras and JAX versions, the active Keras backend and the devices
  JAX reports, then builds, compiles and fits a small convolutional classifier
  on synthetic 28x28x1 images with random integer labels. The data carries no
  signal; the goal is only to exercise the build/compile/fit path end to end.

  All randomness (synthetic data, weight initialization, dropout) is seeded so
  repeated runs should report comparable losses.

  Returns:
    The training loss of the last epoch, as a Python float.
  """
  print(f"Keras version: {keras.__version__}")
  print(f"Keras backend: {keras.config.backend()}")
  print(f"JAX version: {jax.__version__}")
  print(f"JAX devices: {jax.devices()}")

  seed = 0
  num_classes = 10
  input_shape = (28, 28, 1)
  num_samples = 1024
  epochs = 5
  batch_size = 32

  # Seed Python, NumPy and the Keras backend RNG so weight initialization and
  # dropout are repeatable; without this the reported loss drifts run to run.
  keras.utils.set_random_seed(seed)

  model = keras.Sequential(
    [
      keras.layers.Input(shape=input_shape),
      keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
      keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
      keras.layers.MaxPooling2D(pool_size=(2, 2)),
      keras.layers.Conv2D(128, kernel_size=(3, 3), activation="relu"),
      keras.layers.Conv2D(128, kernel_size=(3, 3), activation="relu"),
      keras.layers.GlobalAveragePooling2D(),
      keras.layers.Dropout(0.5),
      # Softmax emits probabilities, matching the loss's default
      # from_logits=False below.
      keras.layers.Dense(num_classes, activation="softmax"),
    ]
  )
  print("Model defined.")

  model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(),
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    metrics=[
      keras.metrics.SparseCategoricalAccuracy(name="acc"),
    ],
  )
  print("Model compiled.")

  # Synthetic data: uniform [0, 1) float32 images and uniform integer labels
  # in [0, num_classes). A dedicated Generator keeps the data independent of
  # NumPy's global random state.
  rng = np.random.default_rng(seed)
  x_train = rng.random((num_samples, *input_shape), dtype=np.float32)
  y_train = rng.integers(0, num_classes, size=num_samples, dtype=np.int32)

  print("Starting model.fit...")
  history = model.fit(x_train, y_train, epochs=epochs, batch_size=batch_size)
  print("Model.fit finished.")

  # history.history["loss"] holds one value per epoch and epochs >= 1, so the
  # list is non-empty; float() hands the caller a plain Python scalar.
  return float(history.history["loss"][-1])


if __name__ == "__main__":
  # Script entry point: announce the demo, run the (kinetic-wrapped) smoke
  # test, and print whatever value that call returns -- expected to be the
  # final-epoch training loss.
  print("Starting Keras JAX demo...")
  final_loss = train_keras_jax_model()
  print(f"Final training loss: {final_loss}")