Train a miniGPT language model with JAX#


This tutorial demonstrates how to use JAX, Flax NNX and Optax for language model (pre)training using data and tensor parallelism for Single-Program Multiple-Data (SPMD). It was originally inspired by the Keras miniGPT tutorial.

Here, you will learn how to:

  • Define the miniGPT model with Flax and JAX automatic parallelism

  • Load and preprocess the dataset

  • Create the loss and training step functions

  • Train the model on Google Colab’s Cloud TPU v2

  • Profile for hyperparameter tuning

If you are new to JAX for AI, check out the introductory tutorial, which covers neural network building with Flax NNX.

Setup#

JAX installation is covered in this guide on the JAX documentation site. We will use Tiktoken for tokenization and Grain for data loading.

!pip install -Uq tiktoken grain matplotlib

Note: If you are using Google Colab, select the free Google Cloud TPU v2 as the hardware accelerator.

Check the available JAX devices (jax.Device objects) with jax.devices(). On Colab’s TPU v2 runtime, the output of the cell below shows a list of 8 (eight) devices.

import jax
jax.devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

Get the TinyStories dataset from Hugging Face. We only use the training split.

!wget https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories-train.txt?download=true -O TinyStories-train.txt
Length: 1924281556 (1.8G) [text/plain]
Saving to: ‘TinyStories-train.txt’


2025-04-25 01:37:20 (251 MB/s) - ‘TinyStories-train.txt’ saved [1924281556/1924281556]

Import the necessary modules, including JAX NumPy, Flax NNX, Optax, Grain, pandas, and Tiktoken:

import jax
import jax.numpy as jnp

from jax.sharding import Mesh, PartitionSpec as P, NamedSharding # For data and model parallelism (explained in more detail later)
from jax.experimental import mesh_utils

import flax.nnx as nnx
import optax

from dataclasses import dataclass
import grain.python as pygrain
import pandas as pd
import tiktoken
import time

Define the miniGPT model with Flax and JAX automatic parallelism#

Leveraging JAX’s data and tensor parallelism#

One of the most powerful features of JAX is device parallelism for SPMD.

  • Data parallelism splits (shards) each batch of training data across multiple devices, such as GPUs and Google TPUs, which process their shards simultaneously. This allows you to use larger batch sizes to speed up training.

  • Tensor parallelism allows us to split the model parameter tensors across several devices (sharding model tensors).

  • You can learn more about the basics of JAX parallelism in the Introduction to parallel programming on the JAX documentation site.

In this example, we’ll utilize a 4-way data parallel and 2-way tensor parallel setup. The free Google Cloud TPU v2 on Google Colab offers 4 chips, each with 2 TPU cores, for 8 devices in total. The TPU v2 architecture aligns with the proposed setup.

jax.sharding.Mesh#

Earlier, we imported jax.sharding.Mesh. A Mesh is a multidimensional NumPy array of JAX devices, where each axis of the mesh has a name, such as 'x' or 'y'. It encapsulates how the TPU resources are organized, so that computations can be distributed across the devices.

We will pass two arguments to Mesh:

  • devices: This will take the value of mesh_utils.create_device_mesh((4, 2)) (imported from jax.experimental). This call builds a device mesh, which is a NumPy ndarray of JAX devices (as obtained from jax.devices()) arranged in a (4, 2) grid.

  • axis_names, where:

    • batch: 4 devices along the first axis - i.e. sharded into 4 - for data parallelism; and

    • model: 2 devices along the second axis (sharded into 2) for tensor parallelism, mapping to the TPU v2 cores.

This matches the (4, 2) structure of Colab’s TPU v2 setup.

Let’s instantiate Mesh as mesh and declare the TPU configuration to define how data and model parameters are distributed across the devices:

# Create a `Mesh` object representing TPU device arrangement.
mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('batch', 'model'))

### Alternatively, we could use the 8-way data parallelism with only one line of code change.
### JAX enables quick experimentation with different partitioning strategies
### like this. We will come back to this point at the end of this tutorial.
# mesh = Mesh(mesh_utils.create_device_mesh((8, 1)), ('batch', 'model'))
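The cell above assumes exactly 8 devices. On other hardware, such as a single CPU or GPU, `create_device_mesh((4, 2))` raises an error. The sketch below uses a hypothetical helper, `make_batch_model_mesh`, which is not part of JAX. It builds a `('batch', 'model')` mesh from whatever `jax.devices()` returns. If the device count is not divisible by the requested model-parallel size, it falls back to pure data parallelism. The helper assumes that a flat reshape of the device list is good enough. Unlike `mesh_utils.create_device_mesh`, it does not take the physical TPU interconnect into account.

```python
import jax
import numpy as np
from jax.sharding import Mesh


def make_batch_model_mesh(model_parallel: int = 2) -> Mesh:
    """Hypothetical helper: build a ('batch', 'model') mesh from the available devices."""
    devices = jax.devices()
    n = len(devices)
    # Fall back to pure data parallelism if the devices cannot be split evenly.
    if n % model_parallel != 0:
        model_parallel = 1
    # Flat reshape of the device list; ignores the physical interconnect topology.
    grid = np.array(devices).reshape(n // model_parallel, model_parallel)
    return Mesh(grid, ('batch', 'model'))


mesh = make_batch_model_mesh()
print(dict(mesh.shape))
```

On Colab’s TPU v2 with 8 devices, this produces a mesh with the same (4, 2) shape as the cell above. The device order may differ.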

We will use the GPT-2 tokenizer from the Tiktoken library:

tokenizer = tiktoken.get_encoding("gpt2")

To leverage model parallelism, we need to instruct the JAX compiler how to shard the model tensors across the TPU devices. Earlier, we also imported jax.sharding.PartitionSpec and jax.sharding.NamedSharding:

  • PartitionSpec (using alias P) defines how tensors are sharded across the devices in our Mesh. Each of its elements describes how the corresponding input dimension is partitioned across mesh axes. For example, in PartitionSpec('x', 'y') the first dimension of the data is sharded across the x axis of the mesh, and the second dimension across the y axis.

    • We’ll use PartitionSpec to describe how to shard a tensor, for example across the model axis, and which dimensions are left unsplit (denoted by None).

  • NamedSharding is a (Mesh, PartitionSpec) pair that describes how to shard a model tensor across our mesh.

  • We combine Mesh (the TPU resources) with PartitionSpec and create a NamedSharding, which instructs how to shard each model tensor across the TPU devices.
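To see what a NamedSharding does to an actual array, the sketch below places an array on a mesh with `jax.device_put`. It then inspects the per-device pieces through `addressable_shards`. So that the sketch runs on any machine, it builds a `(num_devices, 1)` mesh from `jax.devices()` instead of the (4, 2) TPU mesh. `P('batch', None)` splits the rows across the `batch` axis and keeps the columns whole.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

n = len(jax.devices())
# Assumed layout: every device on the 'batch' axis, and a 'model' axis of size 1.
mesh = Mesh(np.array(jax.devices()).reshape(n, 1), ('batch', 'model'))

# A global (8 * n, 4) array, i.e. 8 rows per device.
x = jnp.arange(8 * n * 4, dtype=jnp.float32).reshape(8 * n, 4)

# Shard the first dimension across 'batch'; None means the second dimension is not split.
x_sharded = jax.device_put(x, NamedSharding(mesh, P('batch', None)))

# Each device holds a block of 8 full rows.
shard_shapes = {shard.data.shape for shard in x_sharded.addressable_shards}
print(shard_shapes)

# The sharded array still represents the same global values.
print(bool(jnp.array_equal(x_sharded, x)))
```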

Additionally, we’ll use Flax NNX’s flax.nnx.with_partitioning to annotate each layer’s weights with the sharding we want, so that they are sharded according to our specification. We do this for every layer whose parameters should be sharded. In the model below, the embedding layers are left unannotated.
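The specs passed to `nnx.with_partitioning` determine how much of each tensor lives on one device. The pure-Python helper below, `local_shard_shape`, is hypothetical and is not a JAX or Flax API. It only does the shape arithmetic. It assumes that each spec entry is a single mesh axis name or None, and that every sharded dimension divides evenly. Consider the (4, 2) `('batch', 'model')` mesh with this tutorial’s `embed_dim = feed_forward_dim = 256`. Here `P(None, 'model')` on a (256, 256) feed-forward kernel leaves a (256, 128) slice on each device. That kernel is replicated along the `batch` axis.

```python
def local_shard_shape(global_shape, spec, mesh_shape):
    """Hypothetical helper: per-device shape of a tensor laid out by `spec` on a named mesh.

    Each spec entry is a mesh axis name or None; missing trailing entries mean "not split".
    """
    if len(spec) > len(global_shape):
        raise ValueError(f"spec {spec} has more entries than the tensor has dimensions {global_shape}")
    spec = tuple(spec) + (None,) * (len(global_shape) - len(spec))
    local = []
    for dim, axis in zip(global_shape, spec):
        if axis is None:
            local.append(dim)  # Not split: every device holds the full dimension.
        else:
            size = mesh_shape[axis]
            if dim % size != 0:
                raise ValueError(f"dimension {dim} is not divisible by mesh axis {axis!r} of size {size}")
            local.append(dim // size)
    return tuple(local)


mesh_shape = {'batch': 4, 'model': 2}  # The Colab TPU v2 mesh used in this tutorial.
print(local_shard_shape((256, 256), (None, 'model'), mesh_shape))  # feed-forward kernel
print(local_shard_shape((256,), ('model',), mesh_shape))           # bias or LayerNorm scale
print(local_shard_shape((256, 256), ('batch', None), mesh_shape))  # a (batch, maxlen) token batch
```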

# Define a triangular mask for causal attention with `jax.numpy.tril` and `jax.numpy.ones`.
def causal_attention_mask(seq_len):
    return jnp.tril(jnp.ones((seq_len, seq_len)))

class TransformerBlock(nnx.Module):
    """ A single Transformer block.

    Each Transformer block processes input sequences via self-attention and feed-forward networks.

    Args:
        embed_dim (int): Embedding dimensionality.
        num_heads (int): Number of attention heads.
        ff_dim (int): Dimensionality of the feed-forward network.
        rngs (flax.nnx.Rngs): A Flax NNX stream of JAX PRNG keys.
        rate (float): Dropout rate. Defaults to 0.1.
    """
    def __init__(self, embed_dim: int, num_heads: int, ff_dim: int, *, rngs: nnx.Rngs, rate: float = 0.1):
        # Multi-Head Attention (MHA) with `flax.nnx.MultiHeadAttention`.
        # Specifies tensor sharding (depending on the mesh configuration)
        # where we shard the weights across devices for parallel computation.
        self.mha = nnx.MultiHeadAttention(num_heads=num_heads,
                                          in_features=embed_dim,
                                          kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), NamedSharding(mesh, P(None, 'model'))),
                                          bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P('model'))),
                                          rngs=rngs)
        # The first dropout with `flax.nnx.Dropout`.
        self.dropout1 = nnx.Dropout(rate=rate)
        # First layer normalization with `flax.nnx.LayerNorm`.
        self.layer_norm1 = nnx.LayerNorm(epsilon=1e-6,
                                         num_features=embed_dim,
                                         scale_init=nnx.with_partitioning(nnx.initializers.ones_init(), NamedSharding(mesh, P('model'))),
                                         bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P('model'))),
                                         rngs=rngs)
        # The first linear transformation for the feed-forward network with `flax.nnx.Linear`.
        self.linear1 = nnx.Linear(in_features=embed_dim,
                                  out_features=ff_dim,
                                  kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), NamedSharding(mesh, P(None, 'model'))),
                                  bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P('model'))),
                                  rngs=rngs)
        # The second linear transformation for the feed-forward network with `flax.nnx.Linear`.
        self.linear2 = nnx.Linear(in_features=ff_dim,
                                  out_features=embed_dim,
                                  kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), NamedSharding(mesh, P(None, 'model'))),
                                  bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P('model'))),
                                  rngs=rngs)
        # The second dropout with `flax.nnx.Dropout`.
        self.dropout2 = nnx.Dropout(rate=rate)
        # Second layer normalization with `flax.nnx.LayerNorm`.
        self.layer_norm2 = nnx.LayerNorm(epsilon=1e-6,
                                         num_features=embed_dim,
                                         scale_init=nnx.with_partitioning(nnx.initializers.ones_init(), NamedSharding(mesh, P('model'))),
                                         bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P('model'))),
                                         rngs=rngs)


    # Apply the Transformer block to the input sequence.
    def __call__(self, inputs, training: bool = False):
        input_shape = inputs.shape
        _, seq_len, _ = input_shape

        # Instantiate the causal attention mask.
        mask = causal_attention_mask(seq_len)

        # Apply Multi-Head Attention with the causal attention mask.
        attention_output = self.mha(
            inputs_q=inputs,
            mask=mask,
            decode=False
        )
        # Apply the first dropout.
        attention_output = self.dropout1(attention_output, deterministic=not training)
        # Apply the first layer normalization.
        out1 = self.layer_norm1(inputs + attention_output)

        # The feed-forward network.
        # Apply the first linear transformation.
        ffn_output = self.linear1(out1)
        # Apply the ReLU activation with `flax.nnx.relu`.
        ffn_output = nnx.relu(ffn_output)
        # Apply the second linear transformation.
        ffn_output = self.linear2(ffn_output)
        # Apply the second dropout.
        ffn_output = self.dropout2(ffn_output, deterministic=not training)
        # Apply the second layer normalization and return the output of the Transformer block.
        return self.layer_norm2(out1 + ffn_output)

class TokenAndPositionEmbedding(nnx.Module):
    """ Combines token embeddings (words in an input sentence) with
    positional embeddings (the position of each word in a sentence).

    Args:
        maxlen (int): Maximum sequence length.
        vocab_size (int): Vocabulary size.
        embed_dim (int): Embedding dimensionality.
        rngs (flax.nnx.Rngs): A Flax NNX stream of JAX PRNG keys.
    """
    def __init__(self, maxlen: int, vocab_size: int, embed_dim: int, *, rngs: nnx.Rngs):
        # Initialize token embeddings (using `flax.nnx.Embed`).
        # Each unique word has an embedding vector.
        self.token_emb = nnx.Embed(num_embeddings=vocab_size, features=embed_dim, rngs=rngs)
        # Initialize positional embeddings (using `flax.nnx.Embed`).
        self.pos_emb = nnx.Embed(num_embeddings=maxlen, features=embed_dim, rngs=rngs)

    # Takes a token sequence (integers) and returns the combined token and positional embeddings.
    def __call__(self, x):
        # Generate a sequence of positions for the input tokens.
        positions = jnp.arange(0, x.shape[1])[None, :]
        # Look up the positional embeddings for each position in the input sequence.
        position_embedding = self.pos_emb(positions)
        # Look up the token embeddings for each token in the input sequence.
        token_embedding = self.token_emb(x)
        # Combine token and positional embeddings.
        return token_embedding + position_embedding

class MiniGPT(nnx.Module):
    """ A miniGPT transformer model, inherits from `flax.nnx.Module`.

    Args:
        maxlen (int): Maximum sequence length.
        vocab_size (int): Vocabulary size.
        embed_dim (int): Embedding dimensionality.
        num_heads (int): Number of attention heads.
        feed_forward_dim (int): Dimensionality of the feed-forward network.
        num_transformer_blocks (int): Number of transformer blocks. Each block contains attention and feed-forward networks.
        rngs (nnx.Rngs): A Flax NNX stream of JAX PRNG keys.
    """
    # Initialize miniGPT model components.
    def __init__(self, maxlen: int, vocab_size: int, embed_dim: int, num_heads: int, feed_forward_dim: int, num_transformer_blocks: int, rngs: nnx.Rngs):
        # Initialize the `TokenAndPositionEmbedding` that combines token and positional embeddings.
        self.embedding_layer = TokenAndPositionEmbedding(
                    maxlen, vocab_size, embed_dim, rngs=rngs
                )
        # Create a list of `TransformerBlock` instances.
        # Each block processes input sequences using attention and feed-forward networks.
        self.transformer_blocks = [TransformerBlock(
            embed_dim, num_heads, feed_forward_dim, rngs=rngs
        ) for _ in range(num_transformer_blocks)]
        # Initialize the output `flax.nnx.Linear` layer producing logits over the vocabulary for next-token prediction.
        self.output_layer = nnx.Linear(in_features=embed_dim,
                                       out_features=vocab_size,
                                       kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), NamedSharding(mesh, P(None, 'model'))),
                                       bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P('model'))),
                                       rngs=rngs)

    def __call__(self, inputs, training: bool = False):
        # Pass the input tokens through the `embedding_layer` to get token embeddings.
        # Apply each transformer block sequentially to the embedded input, use the `training` flag for the behavior of `flax.nnx.Dropout`.
        x = self.embedding_layer(inputs)
        for transformer_block in self.transformer_blocks:
            x = transformer_block(x, training=training)
        # Pass the output of the transformer blocks through the output layer,
        # and obtain logits for each token in the vocabulary (for next token prediction).
        outputs = self.output_layer(x)
        return outputs

    @nnx.jit
    def sample_from(self, logits):
        logits, indices = jax.lax.top_k(logits, k=top_k)
        logits = nnx.softmax(logits)
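        # Note: the same fixed key, `jax.random.PRNGKey(0)`, is used on every call, so identical logits always yield the same token.
        return jax.random.choice(jax.random.PRNGKey(0), indices, p=logits)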

    @nnx.jit
    def generate_step(self, padded_tokens, sample_index):
        logits = self(padded_tokens)
        next_token = self.sample_from(logits[0][sample_index])
        return next_token

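    def generate_text(self, max_tokens, start_tokens):
        generated = []
        # Token id of the end-of-text marker; sampling it stops generation.
        end_token = tokenizer.encode('<|endoftext|>', allowed_special={'<|endoftext|>'})[0]
        print(tokenizer.decode(start_tokens), flush=True, end='')
        for _ in range(max_tokens):
            # Stop before the sequence would exceed the model's context length `maxlen`.
            if len(start_tokens) + len(generated) >= maxlen:
                break
            # Index of the last real token; its logits predict the next token.
            sample_index = len(start_tokens) + len(generated) - 1
            # Right-pad with 0s to a fixed length `maxlen`, so `generate_step` always sees the same input shape.
            padded_tokens = jnp.array((start_tokens + generated + [0] * (maxlen - len(start_tokens) - len(generated))))[None, :]
            next_token = int(self.generate_step(padded_tokens, sample_index))
            if next_token == end_token:
                break
            generated.append(next_token)
            # Decode and print next_token.
            print(tokenizer.decode([next_token]), flush=True, end='')
        return tokenizer.decode(start_tokens + generated)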

# Creates the miniGPT model with `num_transformer_blocks` transformer blocks (set with the hyperparameters below).
def create_model(rngs):
    return MiniGPT(maxlen, vocab_size, embed_dim, num_heads, feed_forward_dim, num_transformer_blocks=num_transformer_blocks, rngs=rngs)
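`causal_attention_mask` returns a lower-triangular matrix. Row `i` has ones in columns `0..i`, so position `i` can attend only to itself and to earlier tokens. The sketch below builds the mask for a 4-token sequence. It then applies the mask to a toy score matrix in the way attention masking is commonly done: disallowed positions are set to a very large negative value before the softmax. The `jnp.where` step illustrates the idea and is not Flax’s internal implementation.

```python
import jax
import jax.numpy as jnp


def causal_attention_mask(seq_len):
    # Same definition as in the model: ones on and below the diagonal.
    return jnp.tril(jnp.ones((seq_len, seq_len)))


mask = causal_attention_mask(4)
print(mask)

# Toy attention scores: all zeros, so each row spreads its weight evenly over allowed positions.
scores = jnp.zeros((4, 4))
masked_scores = jnp.where(mask > 0, scores, jnp.finfo(scores.dtype).min)
weights = jax.nn.softmax(masked_scores, axis=-1)
# Row 1 puts weight 0.5 on tokens 0 and 1, and 0 on the future tokens 2 and 3.
print(weights[1])
```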

Set some hyperparameters.

vocab_size = tokenizer.n_vocab
num_transformer_blocks = 4
maxlen = 256
embed_dim = 256
num_heads = 8
feed_forward_dim = 256
batch_size = 256 # You can set a bigger batch size if you use Kaggle's Cloud TPU.
num_epochs = 1
top_k = 10
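`top_k = 10` is used by `sample_from`. The model keeps only the 10 highest logits, renormalizes them with a softmax, and samples a token from those 10 candidates. `sample_from` always uses `jax.random.PRNGKey(0)`, so it draws with the same key on every call. The standalone sketch below uses a hypothetical function, `sample_top_k`, which takes the key as an argument so that each step can use a fresh split key. It is a variant for illustration, not the tutorial’s code.

```python
import jax
import jax.numpy as jnp


def sample_top_k(key, logits, k):
    """Hypothetical helper: sample one token id from the k most likely entries of a 1-D logits vector."""
    top_logits, top_indices = jax.lax.top_k(logits, k)
    probs = jax.nn.softmax(top_logits)  # Renormalize over the k candidates only.
    return jax.random.choice(key, top_indices, p=probs)


logits = jnp.array([0.1, 5.0, -2.0, 4.0, 0.0])
key = jax.random.PRNGKey(0)
tokens = []
for _ in range(20):
    key, subkey = jax.random.split(key)  # A fresh key for every sampling step.
    tokens.append(int(sample_top_k(subkey, logits, k=2)))
# Only ids 1 and 3, the two largest logits, can ever be sampled.
print(sorted(set(tokens)))
```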

Loading and preprocessing the data#

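Load the data and preprocess it with Grain. Each story is tokenized with Tiktoken, truncated or right-padded with 0s to maxlen tokens, and grouped into batches of batch_size examples.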

@dataclass
class TextDataset:
    data: list
    maxlen: int

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx: int):
        # Use Tiktoken for tokenization
        encoding = tokenizer.encode(self.data[idx], allowed_special={'<|endoftext|>'})[:self.maxlen]  # Tokenize and truncate
        return encoding + [0] * (self.maxlen - len(encoding))  # Pad to maxlen

def load_and_preprocess_data(file_path, batch_size, maxlen):

    with open(file_path, 'r') as f:
        text = f.read()

    stories = text.split('<|endoftext|>')
    stories = [story+'<|endoftext|>' for story in stories if story.strip()]
    df = pd.DataFrame({'text': stories})
    data = df['text'].dropna().tolist()
    dataset = TextDataset(data, maxlen)

    sampler = pygrain.IndexSampler(
        len(dataset),
        shuffle=False,
        seed=42,
        shard_options=pygrain.NoSharding(),
        num_epochs=num_epochs,
    )

    dl = pygrain.DataLoader(
        data_source=dataset,
        sampler=sampler,
        operations=[pygrain.Batch(batch_size=batch_size, drop_remainder=True)],
    )

    return dl

text_dl = load_and_preprocess_data('TinyStories-train.txt', batch_size, maxlen)
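`TextDataset.__getitem__` truncates each tokenized story to `maxlen` tokens and right-pads shorter ones with token id `0`. Every example then has the same length and can be batched. The sketch below reproduces that logic without downloading anything. It uses a hypothetical toy "tokenizer", `toy_encode`, in place of Tiktoken. The story-splitting step mirrors `load_and_preprocess_data`.

```python
def toy_encode(text):
    """Hypothetical stand-in for `tokenizer.encode`: maps each whitespace-separated word to its length.

    Unlike Tiktoken with `allowed_special`, it does not treat '<|endoftext|>' as one special token.
    """
    return [len(word) for word in text.split()]


def encode_fixed_length(text, maxlen):
    """Truncate to `maxlen` tokens, then right-pad with 0, as in `TextDataset.__getitem__`."""
    encoding = toy_encode(text)[:maxlen]
    return encoding + [0] * (maxlen - len(encoding))


raw = "Once upon a time<|endoftext|>The cat sat on the warm mat all day<|endoftext|>  "
# Split on the end-of-text marker, drop empty pieces, and re-append the marker to each story.
stories = [s + '<|endoftext|>' for s in raw.split('<|endoftext|>') if s.strip()]
examples = [encode_fixed_length(s, maxlen=8) for s in stories]
print(examples)  # The short story is padded with 0s; the long one is truncated to 8 tokens.
```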

Defining the loss function and training step function#

# Defines the loss function using `optax.softmax_cross_entropy_with_integer_labels`.
def loss_fn(model, batch):
    logits = model(batch[0])
    loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=batch[1]).mean()
    return loss, logits

# Define the training step with the `flax.nnx.jit` transformation decorator.
@nnx.jit
def train_step(model: MiniGPT, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch):
    grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(model, batch)
    metrics.update(loss=loss, logits=logits, labels=batch[1])
    optimizer.update(grads)
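For each position, `optax.softmax_cross_entropy_with_integer_labels` computes the negative log-probability that the softmax over the vocabulary assigns to the integer target token. `loss_fn` then averages this value over the batch and sequence positions. The sketch below writes the same computation directly in `jax.numpy` on toy shapes. It illustrates what the Optax call computes and is not meant to replace it. As written, `loss_fn` also includes padded positions (target id `0`) in the mean.

```python
import jax
import jax.numpy as jnp


def cross_entropy_integer_labels(logits, labels):
    """Per-position cross-entropy: -log(softmax(logits)[label])."""
    log_probs = jax.nn.log_softmax(logits, axis=-1)  # (batch, seq, vocab)
    # Pick the log-probability of each target id along the vocabulary axis.
    picked = jnp.take_along_axis(log_probs, labels[..., None], axis=-1)[..., 0]
    return -picked  # (batch, seq)


batch, seq_len, vocab = 2, 3, 5
logits = jax.random.normal(jax.random.PRNGKey(0), (batch, seq_len, vocab))
labels = jnp.array([[1, 2, 0], [4, 4, 3]])
loss = cross_entropy_integer_labels(logits, labels).mean()  # Mean over batch and positions, like `loss_fn`.
print(loss)

# With uniform logits, every token has probability 1/vocab, so the loss is log(vocab).
uniform = cross_entropy_integer_labels(jnp.zeros((1, 1, vocab)), jnp.array([[2]]))
print(uniform)
```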

Training the model#

Start training. It takes ~50 minutes on Colab.

Note that for data parallelism, we shard the training data along the batch axis using jax.device_put with NamedSharding.

We also use the jax.vmap transformation to build the target sequences (the inputs shifted left by one token) for the whole batch at once, instead of looping over sequences in Python.
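`prep_target_batch` builds next-token targets. For each sequence, it drops the first token and appends `0`, so the target at position `t` is the input token at position `t + 1`. `jax.vmap` maps this per-sequence function over the batch dimension. The sketch below runs the same function on a toy batch. It then checks the result against an equivalent formulation that slices and pads the whole array.

```python
import jax
import jax.numpy as jnp

# Same per-sequence shift as `prep_target_batch`: drop the first token, append a 0.
prep_target_batch = jax.vmap(lambda tokens: jnp.concatenate((tokens[1:], jnp.array([0]))))

input_batch = jnp.array([[10, 11, 12, 13],
                         [20, 21, 22, 0]])
target_batch = prep_target_batch(input_batch)
print(target_batch)  # Row 0 becomes 11, 12, 13, 0 and row 1 becomes 21, 22, 0, 0.

# Equivalent batched formulation without vmap: shift left along the sequence axis, pad one 0 at the end.
shifted = jnp.pad(input_batch[:, 1:], ((0, 0), (0, 1)))
print(bool(jnp.array_equal(target_batch, shifted)))
```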

model = create_model(rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-3))
metrics = nnx.MultiMetric(
  loss=nnx.metrics.Average('loss'),
)
rng = jax.random.PRNGKey(0)

start_prompt = "Once upon a time"
start_tokens = tokenizer.encode(start_prompt)[:maxlen]
print(f"Initial generated text:")
generated_text = model.generate_text(
    maxlen, start_tokens
)

metrics_history = {
  'train_loss': [],
}

prep_target_batch = jax.vmap(lambda tokens: jnp.concatenate((tokens[1:], jnp.array([0]))))

step = 0
for epoch in range(num_epochs):
    start_time = time.time()
    for batch in text_dl:
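        # Each example is a Python list of token ids, so Grain's `Batch` stacks it position by position:
        # `batch` holds `maxlen` arrays of shape (batch_size,). Transposing gives (batch_size, maxlen).
        input_batch = jnp.array(jnp.array(batch).T)
        if input_batch.shape[0] % len(jax.devices()) != 0:
            continue  # Skip a batch that cannot be split evenly across the devices.
        target_batch = prep_target_batch(input_batch)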
        train_step(model, optimizer, metrics, jax.device_put((input_batch, target_batch), NamedSharding(mesh, P('batch', None))))

        if (step + 1) % 200 == 0:
          for metric, value in metrics.compute().items():
              metrics_history[f'train_{metric}'].append(value)
          metrics.reset()

          elapsed_time = time.time() - start_time
          print(f"\n\nStep {step + 1}, Loss: {metrics_history['train_loss'][-1]}, Elapsed Time: {elapsed_time:.2f} seconds")
          start_time = time.time()

          print(f"Generated text:")
          generated_text = model.generate_text(
              maxlen, start_tokens
          )

        step += 1

# Final text generation
print(f"Final generated text:")
generated_text = model.generate_text(
    maxlen, start_tokens
)
Initial generated text:
Once upon a timeaciaGender gearuser Analysisval {} Bruce Lauren helic Lauren Bruce againstliterally SQU retire Path {}valascript northwest {} Bruceuit Pathascript northwestdrops freelyvic996 curated hysteria survivor {}sclaxteradvert Sitting qualifiers snack {} scenariovalameron {} Path {}Nick VeganExcept peasantascript Whites retire {} retire {} Analysisrest {} Mine psychedelic flankForgeModLoader Path Bravo {} inflic {} strutConnector psychedelic beyond Beforeocker interesting Dani {}sclaxter retire {}Nick sorrow Typesrest interestingUV FSyrus resorts {} Dani {} perished {} retire interesting sorrow reversibleurned {} Womanlast 118 reass gentlestudyManager {} retire {} verb Captain forbid Bruce {} Analysis ox {} inexplicable tumor psychedelic {} serverpelrest Sky {} cropDisclaimeruti Nortonocated twins Path {} psychedeliccre motionsundrum {} northwestroid variable {} Whites {} dancers iPod {} {} verb retire {} Fred Noble {}ampionscre lineman servesShould decision1024� serveraez {} retire interesting Tangrest Carly juice,. allowsmodulerest Antarsumameroncre Flesh --> northwestroidENN {} Gustav rolledMuch challengundrum {}val retire {}scl less {} perished Brigham Analysis developersSomething hiding {}scl Houthval {} northwest appease miles {} escalationManager {} northwest {} {} Cube psychedelic {} inflic {} retire {} Whites dancers {}scl FS lore appease Din {} Whites abnorm[] {} {}scl FS appease dangling Bruce abnormcre97 psychedeliccre!!!

Step 200, Loss: 4.653054714202881, Elapsed Time: 100.71 seconds
Generated text:
Once upon a time, there a little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little little!!!!

Step 400, Loss: 3.0780816078186035, Elapsed Time: 59.71 seconds
Generated text:
Once upon a time there was a little girl named Lily. She loved to play outside and play with her mommy was very much fun. She loved to play with her mommy and her mommy's mommy's mommy said, but she went to her mommy's mommy's mommy's mommy said, "I'm going to play with her mommy.
Lily said, "I'm going to the park. She said, "I'm going to play with her mommy said, "I'm going to play with you want to play with you can't want to play with her mommy and said, "I'm so happy to play with her mommy and said, "I'm sorry, "I'm sorry, "I'm sorry, I can't worry, "I'm sorry, I can't want to be a good."


Step 600, Loss: 2.4993953704833984, Elapsed Time: 31.43 seconds
Generated text:
Once upon a time, there was a little girl named Lily. She loved to play with her toys and play with her toys. One day, she went to the park with her mommy and daddy.
"Let's go to the park!" Lily asked.
"Let's go to the park," her mommy said.
"No, we can't have to go to the park," said.
"Sure, Lily.
"Sure, I want to play with you!" Lily said.
"Yes, I'm going to play with you!" Lily said.
"Yes, I want to play with you!" Lily said.
Lily and Lily went to the park. She saw a big dog. She was so happy and had a lot of fun.
"Wow, Lily, Lily, Lily, Lily, Lily!" she said.
"I'm sorry, Lily. "I'm sorry, Lily. I'm sorry, Lily. I can't have a good friend."
Lily and Lily said, "Yes, Lily. I can't like to play with you."
Lily and Lily and Lily and Lily went to play with her mommy. They played together and Lily played together. They played together and played!!!!

Step 800, Loss: 2.1457183361053467, Elapsed Time: 32.42 seconds
Generated text:
Once upon a time, there was a little girl named Lily. She loved to play with her toys and her toys. One day, she went to the park with her mommy and dad. She saw a big dog named Max. Lily was scared and ran to Max.
"Hi, Lily! What's wrong?" asked Lily.
"I'm sorry, but I saw a big dog. I found a big dog. I found a big dog. I found a big dog and Lily said, "Don't worry, I will help you."
Lily was happy and said, "Thank you, Max. You are a good friend."
Lily felt happy and said, "Thank you, Lily. You are a good friend."


Step 1000, Loss: 1.9495660066604614, Elapsed Time: 31.16 seconds
Generated text:
Once upon a time, there was a little girl named Lily. She loved to play with her toys and her toys. One day, she found a big box in the box. She was very happy and wanted to show her mom her mom.
Lily was very happy. She said, "Mom, can I have a letter?"
Her mom said, "Sure, but you can't find a letter. It's a letter. It's a letter." Lily was so happy and thanked her mom.
After they finished, Lily's mom said, "Thank you, Lily. It's a letter." Lily was so happy to see her mom. She said, "Thank you, Lily. I love you, my letter."
Lily was happy to have a new letter. She said, "Thank you, Lily. I love you, my letter!"


Step 1200, Loss: 1.8355252742767334, Elapsed Time: 31.50 seconds
Generated text:
Once upon a time, there was a little girl named Lily. She loved to play with her toys and play with her toys. One day, she found a big box in her room. She was so happy and wanted to play with it.
Lily's mom said, "Lily, you can play with your toys. It's a toy car." Lily was so happy and said, "Thank you, Lily. I am so happy to have her toy car."
Lily was happy and said, "Thank you, Lily. I love you, my car is my car."
Lily was happy and said, "Thank you, Lily. I love you, Lily. I love you too!"


Step 1400, Loss: 1.7667107582092285, Elapsed Time: 31.08 seconds
Generated text:
Once upon a time, there was a little girl named Lily. She loved to play outside in the sunshine. One day, she saw a big red ball on the ground. She wanted to play with it, but her mom said no.
Lily went to the park and saw a big tree. She wanted to climb it. She climbed up and down the tree. She climbed up and down. She climbed up the tree and climbed up. She felt the wind on her face. She felt so happy and free.
But then, she heard a loud noise. It was coming from the tree. She was scared and ran away. She ran back to her mom and told her to go back to the tree. Her mom said she had to go to the tree and play on the swings. She was so happy that she had found a new friend. She was so happy to have found her new friend.


Step 1600, Loss: 1.7037931680679321, Elapsed Time: 31.64 seconds
Generated text:
Once upon a time, there was a little girl named Lily. She loved to play with her toys and play with her toys. One day, she went to the park with her mommy and daddy. She saw a big dog and wanted to pet it. She asked her mommy, "Can I pet the dog, please?" Her mommy said, "No, you can't. I want to pet the dog."
Lily was sad and didn't want to be sad. She wanted to pet the dog, but she was not happy. She said, "No, you have to be naughty and not to play with the dog. You have to be kind and gentle. You have to be kind and gentle and gentle."
The dog was happy and said, "Thank you, Lily. You are a good dog. I'm sorry for you. I'm sorry for the dog."
Lily felt bad and said, "I'm sorry, but I didn't mean to hurt you. I'm sorry for the dog. I didn't mean to hurt you. I'm sorry for the dog. I was mean to be rude and not to hurt you. I'm sorry for you. I'm sorry for you. I'm sorry to forgive you!!!!

Step 1800, Loss: 1.661496639251709, Elapsed Time: 32.36 seconds
Generated text:
Once upon a time, there was a little girl named Lucy. She was very excited to go on a tour with her mommy and daddy.
When they arrived at the park, Lucy saw a big tree with lots of leaves. She wanted to climb it and see what was on the tree.
Lucy climbed the tree and climbed the tree. She climbed up the tree and climbed the tree. She climbed higher and higher until she reached the top.
When Lucy reached the top, she saw a big, scary dog. The dog was barking loudly and Lucy was scared. She ran away, but the dog was too fast.
Lucy was scared and ran away. She tried to get back, but she couldn't. She was stuck in the tree, but she was too scared. She tried to get back, but the dog was too fast.
The dog was too fast and Lucy was scared. She tried to run away, but the dog was too fast. The dog was too fast and Lucy was too scared to get back. She was safe and happy.


Step 2000, Loss: 1.635001301765442, Elapsed Time: 32.21 seconds
Generated text:
Once upon a time, there was a little girl named Lucy. She was three years old and loved to play with her toys. One day, Lucy went to the park with her mom. She saw a big tree with lots of leaves. Lucy wanted to climb it.
Lucy asked her mom, "Mommy, can I climb the tree?" Her mom said, "Yes, but you can climb the tree. It's too high up." Lucy was so excited. She climbed up the tree and climbed up. She climbed higher and higher until she reached the top.
When she got to the top, she saw a big tree with a big tree. She climbed up the tree and climbed up. She climbed up the tree and climbed up. She climbed up the tree and climbed up. She climbed higher and higher, higher and higher until she was almost up.
When she was done, Lucy was so happy. She had a great adventure and was happy to have helped her mom.


Step 2200, Loss: 1.5896366834640503, Elapsed Time: 32.09 seconds
Generated text:
Once upon a time there was a little girl named Lucy. She was very happy and loved to play with her toys. One day, she found a big box in the box. She was so excited to open it.
Lucy opened the box and found a big box. Inside the box was a toy car. She was so happy and played with it all day long.
But then, Lucy accidentally knocked over the box. She was very sad. She wanted to see what was inside the box.
Lucy's mom saw her and said, "Lucy, you can't open the box. It's a toy car." Lucy was so happy and said, "Thank you, Mom. I'm so glad I could help."


Step 2400, Loss: 1.5428193807601929, Elapsed Time: 31.43 seconds
Generated text:
Once upon a time, there was a little girl named Sarah. She was three years old and loved to play outside. One day, Sarah went to the park to play. She saw a big slide and wanted to go on it. She climbed up and slid down the slide. She was so happy to see the slide.
But then, Sarah saw a big dog. She was scared and ran away. She ran to her mom and said, "Mommy, look! It's so big!" Her mom smiled and said, "Yes, it's okay. We can go on the slide again."
Sarah was so happy and said, "Thank you, Mommy! I'm so glad I could go to the slide again!" Her mom smiled and said, "Me too! Let's go!"


Step 2600, Loss: 1.5633305311203003, Elapsed Time: 31.28 seconds
Generated text:
Once upon a time, there was a little girl named Lucy. She was very curious and loved to explore. One day, she found a big box in the attic. She was curious and wanted to see what was inside.
Lucy opened the box and saw many things inside. She was curious and wanted to see what was inside. She opened the box and saw many things inside. She saw a big, scary monster.
Lucy was scared and ran away. She hid behind the box and waited for the monster to come back. The monster was hiding behind the box and the box was gone. Lucy was sad and scared. She wanted to go back to the box and play with the monster.
The monster was very angry and scared. It ran away from the box and Lucy never came back. The end.


Step 2800, Loss: 1.541506052017212, Elapsed Time: 31.95 seconds
Generated text:
Once upon a time, there was a little girl named Lucy. She was very happy and loved to play with her friends. One day, Lucy's mom said, "Lucy, you have to go to the park today. It's time for lunch." Lucy was sad and didn't want to go to the park.
But then, a big dog came and started to play. Lucy's mom saw the dog and said, "Don't worry, Lucy. We can go to the park and play." Lucy was so happy and ran to the dog. She ran and ran, but the dog was too fast. Lucy was scared and ran away.
The next day, Lucy and her mom went back to the park. They saw a big slide and ran to it. Lucy was so happy and said, "Thank you, mom! You're the best!" The dog smiled and said, "You're welcome, Lucy. I'm glad you're safe." Lucy smiled and said, "I'm glad you're safe."


Step 3000, Loss: 1.538221001625061, Elapsed Time: 31.71 seconds
Generated text:
Once upon a time, there was a little girl named Lucy. She was three years old and she loved to play with her toys. One day, Lucy's mommy said, "Lucy, I want to play with your toys." Lucy was very excited and said, "Yes, please!"
So, Lucy and her mommy went to the store to buy some candy. Lucy was so happy and said, "Thank you, Mommy!" Her mommy said, "You're welcome, Lucy. I'm glad you like candy." Lucy smiled and said, "I'm glad you like it."
Lucy's mommy said, "That's a great idea, Lucy. Let's go home and have some fun together." Lucy was so happy and said, "Yay! I love playing with my toys!"


Step 3200, Loss: 1.4795030355453491, Elapsed Time: 31.69 seconds
Generated text:
Once upon a time, there was a little girl named Lucy. She was three years old and loved to play with her toys. One day, Lucy's mom said, "Lucy, let's go for a walk!" Lucy was so excited. She ran to the park and saw a big, shiny rock. She picked it up and showed it to her mom.
"Look, Mommy! I found a rock!" said Lucy. "It's so pretty!"
Her mom smiled and said, "That's a great rock, Lucy. It's a pretty rock. It's a pretty rock."
Lucy put the rock on the rock and rocked back and forth. She felt happy and proud. She had found a new rock and showed it to her mom.


Step 3400, Loss: 1.510359764099121, Elapsed Time: 31.32 seconds
Generated text:
Once upon a time there was a little girl named Lucy. She was very excited because she was going to the park. She was going to the park and she saw a big slide. She wanted to go on the slide, but she was too scared to go down.
Suddenly, Lucy saw a big, scary monster. It was scary and Lucy was scared. She ran to her mom and said, "Mom, what is that scary monster in the park?"
Her mom smiled and said, "That's a monster, Lucy. It's just a big, scary monster."
The monster said, "I'm scared of the monster. It's just a monster."
Lucy was scared, but she was brave. She said, "Don't worry, I'll help you."
The monster was friendly and Lucy was happy to hear the monster. She said, "Thank you, monster!"
The monster smiled and said, "You're welcome, little one. I'm glad you're safe."
The monster smiled and said, "You're welcome, Lucy. I'm glad you're safe."


Step 3600, Loss: 1.4839898347854614, Elapsed Time: 32.21 seconds
Generated text:
Once upon a time, there was a little girl named Lucy. She was three years old and loved to play outside. One day, Lucy's mommy said she had to go to the store to buy some groceries. Lucy was so excited! She ran to the store and bought some groceries.
When Lucy got there, she saw a big, red apple. She was so happy! She picked it up and started to eat it. It was so yummy! Lucy was so happy! She ate the apple and felt so good.
After she finished eating, Lucy's mommy said she had to go home. Lucy was sad because she couldn't play with her toys anymore. She was sad because she couldn't play with her toys anymore.


Step 3800, Loss: 1.4641729593276978, Elapsed Time: 31.44 seconds
Generated text:
Once upon a time, there was a little girl called Lucy. She was three years old and loved to play outside. One day, she was playing in the garden when she saw a big, shiny rock. She picked it up and showed it to her mom.
"Look, Mommy, I found a rock!" she said.
Her mom smiled and said, "That's a great rock, Lucy. It's very pretty."
Lucy was so happy to have a new friend. She showed it to her mom and said, "Look, Mommy, I found a rock!"
Her mom smiled and said, "That's great, Lucy. Let's keep it safe."
So, they went inside and found a nice spot to keep it safe. They were so happy to have found a rock and they played together all day.


Step 4000, Loss: 1.46944260597229, Elapsed Time: 31.43 seconds
Generated text:
Once upon a time, there was a little girl named Lucy. She was three years old and she loved to play outside. One day, she went to the park with her mom. She saw a big slide and wanted to go on it.
Lucy went up to the slide and climbed up. She was very excited to go on the slide. She climbed up the ladder and slid down the slide. Whe!
When she got to the top, she saw a big slide. She was so happy! She climbed up the ladder and slid down the slide. She was so happy!
When she got to the top, she saw a big slide. She climbed up the ladder and slid down the slide. She was so happy! She laughed and slid down the slide.
At the end of the day, Lucy was tired but happy. She had a fun day at the park.


Step 4200, Loss: 1.4679889678955078, Elapsed Time: 31.67 seconds
Generated text:
Once upon a time there was a little girl named Lucy. She was very happy and loved to play with her friends. One day, she saw a big, red ball in the park. She wanted to play with it, but she was too scared to move.
Lucy tried to move the ball, but it was too fast. She tried to move it, but it was too fast. She tried to move it, but it was too fast. She tried to move it, but it was too fast.
Then, a kind lady saw Lucy and asked her if she could help her. She said yes, and Lucy was so happy. She gave Lucy a big hug and a big hug.
The lady was so happy that she hugged Lucy and thanked her. She said she was a good girl and she was very happy.


Step 4400, Loss: 1.4406870603561401, Elapsed Time: 31.44 seconds
Generated text:
Once upon a time, there was a little girl named Lucy. She was very happy and she loved to play with her toys. One day, she was playing with her toys when she saw a big, scary monster. She screamed and ran away.
Lucy was very scared and ran away. She ran and hid behind a tree. Suddenly, a big dog came running towards her. Lucy was very scared and ran away.
The big dog saw her and ran away. Lucy was very sad and scared. She ran back to her mom and dad and told them what happened. Her mom hugged her and said, "It's okay, Lucy. We will always be careful and not hurt you."
The big dog stopped and looked at Lucy. He said, "I'm sorry, Lucy. I was just scared of you."
Lucy was happy and hugged her mom. She said, "It's okay, I'm here. I'm glad I was safe."


Step 4600, Loss: 1.4200118780136108, Elapsed Time: 32.36 seconds
Generated text:
Once upon a time, there was a little girl named Jane. She was very happy and loved to play outside. One day, she saw a big, scary monster. It was so scary that she started to scream.
The monster was so scared that it started to run away. Jane was so scared that she ran away. She ran and ran until she was safe.
The monster was so scared that it ran away. Jane was so scared that she ran away. She never got to be scared.
The monster chased Jane and ran away. But she was too fast. She was too scared to run away.
The monster chased Jane and ran until she was safe. She was safe and happy.


Step 4800, Loss: 1.4398597478866577, Elapsed Time: 31.73 seconds
Generated text:
Once upon a time there was a little girl named Lucy. She was very happy and loved to play with her friends. One day, she was playing in the park when she saw a big, red ball. She wanted to play with it, so she ran to her mom.
"Mom, what is this?" asked Lucy.
"It's a ball, sweetheart," her mom replied.
"It's a ball, but it's not a toy," Lucy said.
Her mom smiled and said, "Yes, it's a good idea. Let's go to the park and play with the ball."
So, Lucy and her friends played with the ball all day long. They had so much fun and laughed together.


Step 5000, Loss: 1.4402837753295898, Elapsed Time: 31.41 seconds
Generated text:
Once upon a time, there was a little girl named Lucy. She was three years old and loved to play in the park. One day, she saw a big, shiny thing in the grass. She wanted to take it home, but it was too heavy for her. She tried to lift it, but it was too heavy. She tried and tried, but it was too heavy. She tried and tried, but it was too heavy. She tried and tried, but it was too heavy. She tried and tried, but it was too heavy. She tried and tried, but it still couldn't. She felt sad and frustrated. She wished she had a friend to help her.


Step 5200, Loss: 1.3984873294830322, Elapsed Time: 31.73 seconds
Generated text:
Once upon a time, there was a little girl named Lucy. She was three years old and loved to play with her toys. One day, she was playing with her toys when she accidentally broke one of the pieces. She was very sad and didn't know what to do.
Her mom saw her crying and asked her what was wrong. Lucy told her that she had broken pieces. Her mom said, "Don't worry, we can fix the pieces."
So Lucy and her mom went to the kitchen and fixed the pieces. When they got to the kitchen, Lucy was so happy and thanked her mom. She said, "I'm sorry I broke my piece."
Lucy was so sad and cried. She said, "I'm sorry, Mom. I didn't mean to break my piece."
Her mom hugged her and said, "It's okay, Lucy. Accidents happen. Let's clean it up again and we can fix it again."
Lucy was so happy and hugged her mom. She hugged her and said, "I'm sorry, Mom. I didn't mean to break the pieces. I'm sorry I didn't mean to break."
Her mom hugged her and said, "It's okay, Lucy. I forgive!!!!

Step 5400, Loss: 1.4043035507202148, Elapsed Time: 32.24 seconds
Generated text:
Once upon a time, there was a little girl named Lucy. She was three years old and loved to play with her toys. One day, she was playing with her toys when she heard a loud noise. She looked around and saw a big, scary monster. The monster was roaring and roaring. Lucy was scared and ran away.
The monster chased her and Lucy ran away. She ran and hid behind a tree. She never saw the monster again. The monster was gone and Lucy was safe.
The monster was very strong and brave. He chased Lucy and ran away. Lucy was safe and sound. She never saw the monster again.


Step 5600, Loss: 1.3977570533752441, Elapsed Time: 31.21 seconds
Generated text:
Once upon a time, there was a little girl named Lucy. She was very excited because she was going to the park. She was going to the park and she was going to the park.
When she arrived at the park, she saw a big slide. She wanted to slide down the slide, but she was scared. She started to cry.
Suddenly, a voice came from the park. It was a little girl who was crying. She had lost her toy. She looked around and saw a man. He was crying and crying.
The little girl asked the man if he had seen her toy. The man said, "I'm sorry, but I can't find my toy."
The little girl was so happy and thanked the man. She said, "I'm sorry I lost my toy. I can't find it."
The man smiled and said, "It's okay. I'm here. I'll help you find your toy."
The little girl was so happy and thanked the man. She went back home and told her mom about the park. Her mom was very happy and said, "I'm glad you're here. I'm glad you're here."


Step 5800, Loss: 1.38675856590271, Elapsed Time: 32.10 seconds
Generated text:
Once upon a time, there was a little girl named Lucy. She was three years old and loved to play outside. One day, Lucy went to the park with her mom. She saw a big, red ball and wanted to play with it.
Lucy ran to her mom and said, "Mommy, can I play with the ball?"
Her mom said, "No, Lucy. It's not safe. It's not safe to play with."
Lucy was sad, but she understood. She said, "Okay, Mommy. I will play with the ball."
Her mom smiled and said, "Okay, Lucy. Let's play with the ball."
So Lucy and her mom played with the ball and had lots of fun. They had a lot of fun and Lucy was happy.


Step 6000, Loss: 1.4106464385986328, Elapsed Time: 31.60 seconds
Generated text:
Once upon a time, there was a little girl named Lucy. She loved to play with her toys and her friends. One day, Lucy's mom asked her to clean up her toys. Lucy didn't want to clean up, so she said, "No, Mommy! I don't want to clean up."
Her mom said, "But Lucy, you need to clean up your toys. You can do it if you don't clean up."
Lucy didn't want to clean up, so she said, "But Mommy, I want to clean up my toys. I want to clean up my toys."
Her mom said, "Okay, but you have to clean up your toys first. You can do it again tomorrow."
Lucy nodded and said, "Okay, Mommy. I will clean up."
So, Lucy and her mom cleaned up and cleaned up all the toys. They were happy and clean.


Step 6200, Loss: 1.3815696239471436, Elapsed Time: 31.65 seconds
Generated text:
Once upon a time, there was a little girl named Lily. She loved to play outside in the park. One day, she saw a big, scary dog. The dog was very fast and ran away.
Lily was scared and ran after the dog. She ran as fast as she could. The dog was fast and ran away.
Lily was safe and happy. She learned that it's important to be careful when she played in the park. She also learned that it's important to be careful when playing in the park.


Step 6400, Loss: 1.3922070264816284, Elapsed Time: 31.11 seconds
Generated text:
Once upon a time, there was a little girl named Lucy. She was three years old and loved to play with her toys. One day, she was playing with her toys when she heard a loud noise. She looked up and saw a big, scary monster!
The monster was very scary and it started to shake. Lucy was scared and didn't know what to do. She wanted to run away, but the monster was too fast.
Suddenly, the monster started to shake and growl. Lucy was scared and ran away. She never saw the monster again.
The monster was gone forever. Lucy was very sad and wished she had never gone to play with her toys. She never forgot the monster and the monster was never seen again.


Step 6600, Loss: 1.3984136581420898, Elapsed Time: 31.23 seconds
Generated text:
Once upon a time, there was a little girl named Lucy. She was three years old and loved to play with her toys. One day, she was playing with her toys when she heard a loud noise. She looked up and saw a big, scary monster. She was scared and didn't know what to do.
Suddenly, a big, scary monster appeared. It was the monster's owner, who was watching her. She was scared and ran away. Lucy was scared and ran home.
The monster was very angry and started to chase her. Lucy was scared and ran as fast as she could. She ran as fast as she could, but the monster was too fast and caught her.
The monster was caught and took her away. Lucy was very sad and cried. She never saw her toys again.


Step 6800, Loss: 1.4159687757492065, Elapsed Time: 31.67 seconds
Generated text:
Once upon a time, there was a little girl named Lucy. She loved to play with her toys, but one day she found a big box in the attic. She was so excited to open it and see what was inside.
Inside the box, she found a shiny new toy. It was a toy car. She wanted to play with it, but it was too expensive. She asked her mom if she could have it, but her mom said no.
Lucy was sad and didn't know what to do. She wanted to play with the car, but her mom said no. She said it was too expensive and she should not have it.
Lucy was sad and didn't know what to do. She wanted to play with the car, but her mom said no. Lucy was sad and didn't understand why her mom was so sad.
Then, her mom came to her room and saw the car. She was so happy and hugged Lucy. She hugged her mom and said, "I'm sorry, Mom. I didn't mean to spoil you."
Lucy smiled and hugged her mom. She was so happy and grateful. She learned that sometimes things can be expensive, but it's important to be kind and share with others.
!!!!

Step 7000, Loss: 1.406973123550415, Elapsed Time: 32.53 seconds
Generated text:
Once upon a time, there was a little girl named Lucy. She was three years old and loved to play outside. One day, Lucy was playing in the park when she saw a big, scary dog. She was scared and ran away.
Lucy was very scared. She didn't know what to do. She looked around and saw a big, scary dog. The dog was friendly and wagged its tail. Lucy was so scared that she ran away.
The dog was very friendly and wagged its tail. Lucy was so happy that she had a friend. She hugged the dog and said, "I'm here to play with you!"
The dog was so happy that Lucy had a friend. They played together all day long. They laughed and laughed until it was time to go home.


Step 7200, Loss: 1.3830863237380981, Elapsed Time: 31.75 seconds
Generated text:
Once upon a time, there was a little girl named Lucy. She was three years old and loved to play outside.
One day, Lucy was playing in the park when she saw a big, scary dog. She was scared and ran away.
Lucy's mom saw her and said, "Don't worry, Lucy. The dog is just a little bit. He is just playing."
Lucy was so happy that she started to laugh and play. She ran around the park, laughing and having fun.
Suddenly, Lucy heard a loud noise. She looked up and saw a big, scary dog. Lucy was scared and ran back home.
The next day, Lucy was playing in the park again. She was so scared that she ran back home.


Step 7400, Loss: 1.390000581741333, Elapsed Time: 31.51 seconds
Generated text:
Once upon a time, there was a little girl named Lucy. She was very happy and loved to play outside. One day, she went to the park with her mommy and daddy. She saw a big tree and wanted to climb it. She asked her mommy if she could go. Her mommy said yes and Lucy was so excited. She ran up the tree and started to climb. She was so high that she could see the whole park. She was so happy to be outside and she wanted to go home. She ran and ran until she reached the top of the tree. She looked around and saw a big, beautiful butterfly. She was so happy to see it and she wanted to touch it. She ran back to her mommy and daddy and told them about the butterfly. They said they were so lucky to see it and they all hugged. Lucy was so happy to be home with her new friend.


Step 7600, Loss: 1.3659923076629639, Elapsed Time: 31.61 seconds
Generated text:
Once upon a time, there was a little girl named Amy. She was three years old and loved to play. One day, she was playing in the park when she saw a big, shiny ball. She wanted to play with it, so she ran to get it.
When she got close, she saw a big, shiny ball. She was so excited and wanted to play with it. She ran to the ball and picked it up. She threw the ball and it went very fast.
When she got back to the park, she saw a big, shiny ball. She was so happy and ran to play with it. She kicked the ball and it flew high in the sky. She was so happy and she had so much fun with the big, shiny ball.


Step 7800, Loss: 1.361992359161377, Elapsed Time: 31.49 seconds
Generated text:
Once upon a time, there was a little girl named Lucy. She was very excited because today was a special day. She had a big box of makeup that she wanted to try and play with.
So, she put on her makeup and went outside to play. She saw a big, scary monster! The monster was very scary and it was so scary.
Lucy was scared and started to cry. She wanted to get her makeup, but she was too scared to go inside. She tried to run away, but the monster was too fast.
Then, the monster came closer and closer. It bit Lucy's finger and she was so scared. She screamed and cried until the monster was gone.
The monster was so scared that it ran away. Lucy was so sad and scared. She wished she had been more careful.
The next day, the monster came back to Lucy and she was so happy. She had a new friend and she was so happy to be safe.


Step 8000, Loss: 1.3433030843734741, Elapsed Time: 31.83 seconds
Generated text:
Once upon a time, there was a little girl named Lucy. She was very excited because she was going to the park. She ran to the park and saw a big tree. She wanted to climb it.
She asked her mom if she could climb the tree. Her mom said yes, so she climbed the tree. She climbed higher and higher until she reached the top.
When she got to the top, she saw a big tree. She was so happy! She climbed up the tree and climbed up. She felt so good on the top.
When she got to the top, she saw a big, beautiful view. She was so happy! She ran to the top and looked down. She felt so lucky to be able to climb the tree.
Lucy was so glad she got to climb the tree. She was so glad she got to climb the tree. She was so glad she got to climb the tree.


Step 8200, Loss: 1.3881804943084717, Elapsed Time: 31.76 seconds
Generated text:
Once upon a time, there was a little girl named Lucy. She was very happy and loved to play with her friends. One day, Lucy's mom asked her to help her. Lucy was very excited and said yes.
Lucy's mom said, "Let's go to the store and buy some food." Lucy was very happy and said, "Yay! I want to buy some food!"
So, Lucy and her mom went to the store to buy food. When they got to the store, Lucy saw a big, yummy food. She asked her mom if she could buy it. Her mom said, "No, Lucy. We can't buy it."
Lucy was sad and said, "But I want to buy something else. I want to buy something else."
Her mom said, "Okay, but you have to wait until you get to buy something else. It's a nice, but you have to wait until it's time to get something else."
Lucy was so excited and she waited patiently. When it was time to go, she was so happy and thanked her mom for the yummy food.
Final generated text:
Once upon a time, there was a little girl named Lucy. She was very excited to go to the park. She put on her shoes and ran to the park.
When she got to the park, she saw a big slide. She wanted to go on it. She ran to the slide and started to slide down. She was so fast!
Suddenly, she heard a loud noise. It was a big, scary dog. The dog was barking loudly. Lucy was scared. She ran back to the park and started to run.
The dog was so fast that it ran away. Lucy was safe. She was so happy she had gone on the slide.

Visualize the training loss.

import matplotlib.pyplot as plt
plt.plot(metrics_history['train_loss'])
plt.title('Training Loss')
plt.xlabel('Step')
plt.ylabel('Loss')
plt.show()
[Figure: Training Loss (x-axis: Step, y-axis: Loss)]

As you can see, the model goes from generating essentially random words at the beginning of training to generating sensible tiny stories by the end, while the logged loss falls from about 1.84 at step 1200 to about 1.34 at step 8000. In effect, we have pretrained a small LLM to write tiny stories for us.
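Another way to read these loss values is as perplexity. Assuming the logged loss is the mean token-level cross-entropy in nats (natural log), for example the mean of `optax.softmax_cross_entropy_with_integer_labels`, perplexity is simply `exp(loss)`. The sketch below converts two entries from the log above; the helper name `perplexity` is illustrative.

```python
import math

def perplexity(mean_cross_entropy_nats: float) -> float:
    """Convert a mean token-level cross-entropy (natural log) to perplexity."""
    return math.exp(mean_cross_entropy_nats)

# Loss values copied from the training log above (step -> logged loss).
logged_losses = {1200: 1.8355252742767334, 8000: 1.3433030843734741}

for step, loss in logged_losses.items():
    print(f"Step {step}: loss {loss:.3f}, perplexity {perplexity(loss):.2f}")
```

For these two log entries, perplexity drops from roughly 6.3 to roughly 3.8: the model's uncertainty about the next token is comparable to choosing uniformly among about four tokens instead of about six.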

Saving the checkpoint

Save the model checkpoint.

import orbax.checkpoint as orbax

state = nnx.state(model)

checkpointer = orbax.PyTreeCheckpointer()
checkpointer.save('/content/save', state)

# Make sure the files are there
!ls /content/save/
array_metadatas       d		      _METADATA        _sharding
_CHECKPOINT_METADATA  manifest.ocdbt  ocdbt.process_0
WARNING:absl:[process=0][thread=MainThread] Skipped cross-host ArrayMetadata validation because only one process is found: process_index=0.

Profiling for hyperparameter tuning

!pip install -Uq tensorboard-plugin-profile tensorflow tensorboard

Load the TensorBoard Colab extension.

%load_ext tensorboard

Because we’re going to run this model several times, we need some scaffolding to make the runs easier to compare. Each run first performs a few warmup steps so that the training step is already JIT-compiled and the TPUs are warm. Tracing starts only after warmup finishes, so that compilation time does not distort the comparison.

trace_dir = "/tmp/jax-trace/"

def loop_step(batch, step):
    # Convert the batch to a JAX array and transpose it.
    input_batch = jnp.array(batch).T
    target_batch = prep_target_batch(input_batch)
    # Shard inputs and targets along the 'batch' mesh axis, then run one training step.
    train_step(model, optimizer, metrics, jax.device_put((input_batch, target_batch), NamedSharding(mesh, P('batch', None))))

def generate_trace():
    tracing_steps = 30
    warmup_steps = 5
    for current_step in range(warmup_steps + tracing_steps):
        # Start tracing only once the warmup steps (JIT compilation, TPU warmup) are done.
        if current_step == warmup_steps:
            jax.profiler.start_trace(trace_dir)
        # Mark each step so the profiler can group trace events by training step.
        with jax.profiler.StepTraceAnnotation("train", step_num=current_step):
            batch = next(text_dl)
            loop_step(batch, current_step)

    jax.profiler.stop_trace()
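The scaffolding above follows a warmup-then-measure pattern. The same idea can be sketched without the profiler as a plain wall-clock timer. Because JAX dispatches work asynchronously, this sketch assumes you pass a `block` callback (for example `jax.block_until_ready`) so that each step has actually finished before the clock is read. `mean_step_time` is a hypothetical helper, not part of JAX.

```python
import time
from typing import Any, Callable, Iterable

def mean_step_time(
    step_fn: Callable[[Any], Any],
    batches: Iterable[Any],
    warmup_steps: int = 5,
    timed_steps: int = 30,
    block: Callable[[Any], Any] = lambda out: out,
) -> float:
    """Average wall-clock seconds per step, excluding the warmup steps."""
    it = iter(batches)
    # Warmup: the first calls trigger JIT compilation and are not timed.
    for _ in range(warmup_steps):
        block(step_fn(next(it)))
    start = time.perf_counter()
    for _ in range(timed_steps):
        # Wait for the (possibly asynchronous) result before starting the next step.
        block(step_fn(next(it)))
    return (time.perf_counter() - start) / timed_steps
```

A step function that returns JAX arrays (such as the loss) could be timed with `mean_step_time(step_fn, batches, block=jax.block_until_ready)`. The profiler-based approach in this tutorial gives far more detail; this sketch only answers "how long does one step take".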

Now we’ll record traces for two different batch sizes. This takes several minutes, because the input data has to be reprocessed into new batches for each batch size.

trace_dir = "/tmp/jax-trace-batch-comparison/"

batch_size = 64
text_dl = iter(load_and_preprocess_data('TinyStories-train.txt', batch_size, maxlen))
generate_trace()

batch_size = 256
text_dl = iter(load_and_preprocess_data('TinyStories-train.txt', batch_size, maxlen))
generate_trace()

Run TensorBoard with the Profiler plugin to compare the runs. Runs are listed from newest to oldest, so the top run in the list is the one with batch_size = 256.

The key metrics to focus on for this hyperparameter are FLOPS Utilization and Average Step Time.

In general, we want to maximize FLOPS Utilization while minimizing the step time per training example. In this case, increasing the batch size from 64 to 256 achieves both. FLOPS Utilization increases from 16% to 27%. Average Step Time increases from 100 ms to 260 ms, but the batch size grew by 300% (a factor of 4), so the time per training example drops from about 1.56 ms (100 ms / 64) to about 1.02 ms (260 ms / 256).

%tensorboard --logdir=$trace_dir
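The per-example cost is the average step time divided by the batch size. The sketch below reproduces that arithmetic with the step times quoted above; `ms_per_example` is just an illustrative name.

```python
# Average step times (ms) reported by the profiler for each batch size, as quoted above.
step_time_ms = {64: 100.0, 256: 260.0}

def ms_per_example(step_ms: float, batch_size: int) -> float:
    """Average step time divided by the number of examples processed per step."""
    return step_ms / batch_size

for batch_size, step_ms in step_time_ms.items():
    print(f"batch_size={batch_size}: {ms_per_example(step_ms, batch_size):.2f} ms per example")

# Relative growth in batch size: (256 - 64) / 64 = 3.0, i.e. a 300% increase.
batch_growth = (256 - 64) / 64
```

This gives about 1.56 ms per example at batch size 64 and about 1.02 ms at batch size 256. A longer step can therefore still be the better configuration when it processes proportionally more examples.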

Next, we can explore alternative parallelism strategies. In cell #4, we used 4-way data parallelism combined with 2-way tensor parallelism. Pure 8-way data parallelism is another popular option, so let’s compare the two. To switch to 8-way data parallelism, replace the Mesh definition with:

mesh = Mesh(mesh_utils.create_device_mesh((8, 1)), ('batch', 'model'))

JAX will automatically figure out how to shard the model and data for the new partitioning strategy, and nothing else needs to change. Reconnect the TPU runtime and run the notebook again to see how it performs.

How simple and powerful this is! That’s the beauty of JAX automatic parallelism.
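What changes between the two meshes is how each array is split into per-device blocks. The pure-Python sketch below imitates, for evenly divisible shapes, the block shape that a `NamedSharding(mesh, PartitionSpec(...))` assigns to each device. `shard_shape` here is a hypothetical helper, and the array shapes are made up for illustration rather than taken from the tutorial’s configuration. In JAX itself you can ask a sharding directly with `NamedSharding(mesh, spec).shard_shape(global_shape)`.

```python
from typing import Dict, Optional, Sequence, Tuple

def shard_shape(
    global_shape: Sequence[int],
    mesh_shape: Dict[str, int],
    spec: Sequence[Optional[str]],
) -> Tuple[int, ...]:
    """Per-device block shape when dimension i is split over mesh axis spec[i].

    Illustrative only: None (or a missing trailing entry) means the dimension is
    replicated, and every sharded dimension is assumed to divide evenly.
    """
    dims = []
    for size, axis in zip(global_shape, spec):
        if axis is None:
            dims.append(size)  # replicated: every device holds the full dimension
            continue
        n = mesh_shape[axis]
        if size % n:
            raise ValueError(f"dimension {size} is not divisible by mesh axis {axis!r} of size {n}")
        dims.append(size // n)  # split evenly across the devices along `axis`
    dims.extend(global_shape[len(spec):])  # dimensions not covered by spec are replicated
    return tuple(dims)

meshes = {"4x2": {"batch": 4, "model": 2}, "8x1": {"batch": 8, "model": 1}}
batch_shape = (256, 128)    # hypothetical (batch_size, maxlen)
kernel_shape = (256, 1024)  # hypothetical (embed_dim, feed_forward_dim)

for name, mesh_shape in meshes.items():
    print(name,
          "input block:", shard_shape(batch_shape, mesh_shape, ("batch", None)),
          "kernel block:", shard_shape(kernel_shape, mesh_shape, (None, "model")))
```

With the 4×2 mesh, each device holds a quarter of the batch and half of a kernel annotated with `P(None, 'model')`. With the 8×1 mesh, the `'model'` axis has size 1, so each device holds an eighth of the batch but the whole kernel. In the pure data-parallel case, gradients of these replicated parameters have to be summed across all eight devices every step.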

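trace_dir = "/tmp/jax-trace-parallelism-comparison/"

# Note: loop_step reads the global `mesh` when sharding each input batch, so rebinding
# `mesh` changes how the data is sharded. Parameters keep the sharding they were created
# with unless the model and optimizer are re-created under the new mesh.
mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('batch', 'model'))
generate_trace()

mesh = Mesh(mesh_utils.create_device_mesh((8, 1)), ('batch', 'model'))
generate_trace()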

Once again, we’ll run TensorBoard.

Looking at the results, we see that the step times are nearly the same, but FLOPS Utilization is 13% for 8-way data parallelism compared to 27% for 4-way data parallelism combined with 2-way tensor parallelism.

In the Trace Viewer tool, under each TPU’s ops, we can see that the TPUs spend a large amount of time idle waiting for the host, as well as a good amount of time in reduce_sum operations.

%tensorboard --logdir=$trace_dir
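One common way to reduce the time accelerators spend waiting on the host is to prepare upcoming batches on a background thread while the current step runs. The sketch below is a generic, hypothetical `prefetch` wrapper built only on the standard library (`threading`, `queue`). It is not a Grain or JAX API, and Grain’s own multi-process loading (for example the `worker_count` argument of `grain.DataLoader`) may be a better fit in practice.

```python
import queue
import threading
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")
_DONE = object()  # sentinel marking the end of the source iterable

def prefetch(iterable: Iterable[T], buffer_size: int = 2) -> Iterator[T]:
    """Yield items from `iterable` while a background thread prepares up to `buffer_size` ahead."""
    buf: "queue.Queue[object]" = queue.Queue(maxsize=buffer_size)
    errors: List[BaseException] = []

    def producer() -> None:
        try:
            for item in iterable:
                buf.put(item)  # blocks while the buffer is full
        except BaseException as exc:  # hand errors over to the consuming thread
            errors.append(exc)
        finally:
            buf.put(_DONE)

    # Daemon thread, so an abandoned iterator does not keep the process alive.
    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = buf.get()
        if item is _DONE:
            if errors:
                raise errors[0]  # re-raise the producer's error in the caller
            return
        yield item
```

In this notebook it could wrap the existing iterator, e.g. `text_dl = prefetch(iter(load_and_preprocess_data('TinyStories-train.txt', batch_size, maxlen)))`. Because of Python’s GIL, a thread helps mainly when the producer spends its time in I/O or in native code that releases the GIL; a new trace is the way to confirm whether host idle time actually shrinks.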

By changing hyperparameters and comparing profiles, we can gain significant insight into our bottlenecks and limitations. These are just two examples of hyperparameters to tune; many others can also significantly affect training speed and resource utilization.