Checkpointing Deep Dive: Current Limitations And Future Plans

by ADMIN

Hey guys! Let's dive into the nitty-gritty of checkpointing in our systems, specifically focusing on why we can't easily save everything just yet and what we're cooking up for the future. This is a critical topic for ensuring the resilience and reproducibility of our machine learning models, so let's get into it.

TL;DR: The Current State of Checkpointing

We can't yet cleanly checkpoint everything (the dataloader, replay buffer, random number generator (RNG) state, and other crucial components) because of where the checkpointer lives. The checkpointer in the Titan system sits inside trainer.engine and is the sole owner of the step-<N> folders, while the other components run in separate actors. There's currently no safe, centralized way to write all of their state into the same step folder, so we're putting off full multi-component checkpointing until after PTC (I’ll explain this later). The good news is that, thanks to this pull request, we can already save and resume model weights, optimizer state, and learning rate schedulers using Titan's checkpointer.


Context: How Checkpointing Works Today

Let’s start with how things work currently. Our system spins up various components in the main script:

(
    dataloader,       # DatasetActor (actor)
    policy,           # Policy (service)
    trainer,          # RLTrainer (actor)
    replay_buffer,    # ReplayBuffer (actor)
    compute_advantages,  # ComputeAdvantages (actor)
    ref_model,        # ReferenceModel (service)
    reward_actor,     # RewardActor (service)
) = await asyncio.gather(...)

The model checkpointing is handled by the TorchTitan library, specifically through the trainer's engine. The trainer is responsible for creating the engine and loading any existing checkpoints:

self.engine = ForgeEngine(ForgeJobConfig(**engine_config))
self.engine.checkpointer.load(step=self.step)

During each training step (inside the train_step function), the trainer instructs Titan to save the necessary information:

self.engine.checkpointer.save(
    curr_step=self.step,
    last_step=self.step == self.num_training_steps,
)

Titan's checkpointer then writes the model weights, optimizer state, and learning rate schedulers into a folder structure that looks like this:

<folder>/step-<N>/__0_0.distcp

We also want to save and load other essential state, like the data step, replay buffer contents, and RNG state. With these additions we could fully restore a training run from any checkpoint, which makes runs more robust to interruption and easier to reproduce. This is where the core challenge lies: the current architecture doesn't naturally support saving all these different components together.
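To make "saving RNG state" a bit more concrete, here's a minimal sketch using only Python's standard-library random module. The helper names capture_rng_state and restore_rng_state are made up for this example and are not part of Titan or our codebase. A real trainer would also need to capture the torch and NumPy generators, which this sketch leaves out.

```python
import json
import random


def capture_rng_state() -> str:
    """Serialize the state of Python's global RNG to a JSON string."""
    version, internal, gauss_next = random.getstate()
    return json.dumps(
        {"version": version, "internal": list(internal), "gauss_next": gauss_next}
    )


def restore_rng_state(blob: str) -> None:
    """Restore the global RNG from a string produced by capture_rng_state."""
    data = json.loads(blob)
    # setstate expects the internal state as a tuple, not a list.
    random.setstate((data["version"], tuple(data["internal"]), data["gauss_next"]))


random.seed(0)
saved = capture_rng_state()                    # taken at the "checkpoint"
before = [random.random() for _ in range(3)]   # draws made after the checkpoint

restore_rng_state(saved)                       # the "resume"
after = [random.random() for _ in range(3)]    # should replay the same draws
assert before == after
```

The same capture/restore pattern is what each component (dataloader, replay buffer, and so on) would need to expose before any coordinator can write it to disk.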


The Problem: Obstacles to Comprehensive Checkpointing

So, what's the deal? Why can't we just save everything? The problem boils down to a couple of key issues.

Problem 1: Step-Folder Ownership

Currently, we have one Titan-owned directory per step (e.g., step-200), which is created internally within the trainer.engine.checkpointer. Other actors, like the dataloader and replay buffer, don't have access to the trainer's engine or the internal folder naming conventions of Titan. This setup leads to two not-so-great choices.

  1. Two folders per step:

    checkpoint/
      step-200/          # Titan
        __0_0.distcp
      step-200-other/    # Ours
        dataloader.json
        replay_buffer.bin
        rng.json
    

    This option is clunky. It creates a messy user experience, makes it difficult to atomically purge or retain checkpoints, and is prone to errors. You can easily get the components out of sync.

  2. Single folder per step (the preferred option):

    To save everything in the same step-200/ folder, we'd have to jump through some hoops:

    • Call Titan's private _create_checkpoint_id to get the folder name. However, other components (like the dataloader) don't have access to the engine, and depending on a private method is fragile. It is a big NO-NO.
    • Reimplement a similar function and hope it never deviates from the original. This is also not a good approach.
    • (Preferred Solution) Add a path parameter to Titan's checkpointer.save function. This would let us specify the save location ourselves, so every component can target the same folder. This is the most promising option.
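To show how easily the two-folder layout from choice 1 can drift out of sync, here's a small sketch that scans a checkpoint directory and reports steps where only one of the two folders exists. The function name find_unpaired_steps is hypothetical, and the folder names simply follow the step-200/ and step-200-other/ example above.

```python
import re
import tempfile
from pathlib import Path

STEP_DIR = re.compile(r"^step-(\d+)$")          # Titan-owned folder
OTHER_DIR = re.compile(r"^step-(\d+)-other$")   # our companion folder


def find_unpaired_steps(root: Path) -> dict[str, list[int]]:
    """Return the steps that have only one of the two folders."""
    titan, other = set(), set()
    for entry in root.iterdir():
        if not entry.is_dir():
            continue
        if m := STEP_DIR.match(entry.name):
            titan.add(int(m.group(1)))
        elif m := OTHER_DIR.match(entry.name):
            other.add(int(m.group(1)))
    return {
        "missing_other": sorted(titan - other),
        "missing_titan": sorted(other - titan),
    }


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "step-100").mkdir()
    (root / "step-100-other").mkdir()
    (root / "step-200").mkdir()  # e.g. a crash before step-200-other was written
    report = find_unpaired_steps(root)

assert report == {"missing_other": [200], "missing_titan": []}
```

With a single folder per step this whole class of check goes away, which is a big part of why the single-folder layout is preferred.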

Problem 2: Lack of a Unified Saving Mechanism

Currently, the different states are managed by separate actors or services (e.g., the dataloader), and the trainer isn't aware of them. Moreover, Titan's checkpointer resides within trainer.engine.checkpointer and is only responsible for the model, optimizer, and learning rate scheduler. There's no central place to coordinate the saving of all these components into the same step-<N> directory. This is a critical problem.


Proposed Solutions: How We Can Move Forward

Now, let's explore some solutions to address these problems.

Option 1: Trainer as the Central Owner

class RLTrainer:
    def __init__(self):
        # The trainer would construct and own every stateful component.
        self.dataloader = ...
        self.replay_buffer = ...
        self.rng = ...

This would mean that the trainer would own all the other components. It's a quick fix in terms of implementation, but it causes tight coupling, violates the actor/service separation, and ultimately hurts scalability and reusability. It's not a great long-term solution.

Option 2: Reimplement Checkpointing

We could create our own model, optimizer, and learning rate scheduler checkpointing from scratch. This would give us full control over the layout and atomicity of the process. However, this is a risky, high-effort task, and we're guaranteed to diverge from Titan over time. It's not the best approach.

Option 3: Checkpoint Coordinator

This solution involves introducing a Checkpoint Coordinator. The coordinator would sit above the existing actors and handle the checkpointing process. Here’s how it would work:

  • The coordinator calls Titan to save the model, learning rate scheduler, and optimizers to a specified path (this will require a small API update to Titan's save function to accept a path parameter).
  • The coordinator then asks each actor for its state_dict() and writes it to the same folder (e.g., step-200/dataloader.json, etc.).
  • On loading, after Titan resolves the step to load, the coordinator attempts to load each component's states by calling their load_state() function.

from typing import Dict, Optional

# RLTrainer, ForgeActor, and ServiceInterface come from the Forge codebase.


class CheckpointCoordinator:
    def __init__(self):
        self._trainer: Optional[RLTrainer] = None
        self._components: Dict[str, ForgeActor | ServiceInterface] = {}

    def set_trainer(self, trainer: RLTrainer):
        self._trainer = trainer

    def register(self, name: str, comp: ForgeActor | ServiceInterface):
        self._components[name] = comp

    async def save(self, step: int, path: str):
        # `path` is the step directory (e.g. <folder>/step-200), as passed by main.
        if self._trainer:
            # Relies on the proposed `path` parameter for Titan's save().
            self._trainer.engine.checkpointer.save(curr_step=step, path=path)
        for name, comp in self._components.items():
            states = comp.state_dict()
            # `save` stands for a serialization helper that is not defined here.
            save(states, f"{path}/{name}.json")

    async def load(self, step: int, path: str):
        ...

The changes needed in grpo/main would look like this:

coord = CheckpointCoordinator()
coord.set_trainer(trainer)

coord.register("dataloader", dataloader)
coord.register("replay_buffer", replay_buffer)
...
await coord.load(step, path=step_dir)
await coord.save(step, path=step_dir)

This option is relatively easy to implement and leverages much of the existing infrastructure. However, it still has some drawbacks:

  • It has a nested structure: coord.save calls self._trainer.engine.checkpointer.save.
  • It is slightly specific to our current grpo script.
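For anyone who wants to poke at the save/load round trip without spinning up actors, here's a toy version of the coordinator idea. It's a simplified sketch under several assumptions: everything is synchronous, Titan is left out entirely, ToyDataloader and ToyCoordinator are made-up stand-ins, and the choice to skip components whose file is missing is only one possible reading of "attempts to load".

```python
import json
import tempfile
from pathlib import Path
from typing import Any, Protocol


class Stateful(Protocol):
    def state_dict(self) -> dict[str, Any]: ...
    def load_state(self, state: dict[str, Any]) -> None: ...


class ToyDataloader:
    """Stand-in for an actor; only tracks how many batches were consumed."""

    def __init__(self) -> None:
        self.data_step = 0

    def state_dict(self) -> dict[str, Any]:
        return {"data_step": self.data_step}

    def load_state(self, state: dict[str, Any]) -> None:
        self.data_step = state["data_step"]


class ToyCoordinator:
    def __init__(self) -> None:
        self._components: dict[str, Stateful] = {}

    def register(self, name: str, comp: Stateful) -> None:
        self._components[name] = comp

    def save(self, step_dir: Path) -> None:
        # Every component writes into the same step folder.
        step_dir.mkdir(parents=True, exist_ok=True)
        for name, comp in self._components.items():
            (step_dir / f"{name}.json").write_text(json.dumps(comp.state_dict()))

    def load(self, step_dir: Path) -> list[str]:
        """Restore each component that has a file; return the names skipped."""
        skipped = []
        for name, comp in self._components.items():
            file = step_dir / f"{name}.json"
            if file.exists():
                comp.load_state(json.loads(file.read_text()))
            else:
                skipped.append(name)
        return skipped


with tempfile.TemporaryDirectory() as tmp:
    step_dir = Path(tmp) / "step-200"

    loader = ToyDataloader()
    loader.data_step = 200
    coord = ToyCoordinator()
    coord.register("dataloader", loader)
    coord.save(step_dir)

    fresh = ToyDataloader()  # simulates a restarted process
    resumed = ToyCoordinator()
    resumed.register("dataloader", fresh)
    resumed.register("replay_buffer", ToyDataloader())  # has no file on disk
    skipped = resumed.load(step_dir)

assert fresh.data_step == 200
assert skipped == ["replay_buffer"]
```

In the real system the components are actors and services, so the state_dict() and load_state() calls would go through their endpoints rather than plain method calls.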

Option 4: Standalone ForgeCheckpointManager

In the long run, we could create a standalone manager, ForgeCheckpointManager, that inherits from Titan's CheckpointManager. It would orchestrate both Titan and any additional components in a single save()/load() call. Actors would register their export_state/import_state functions with the manager, and main would only need to call that one manager. This is a more elegant and scalable solution.

  • Open Question: Where would the ForgeCheckpointManager reside if the engine remains within the trainer? Also, how can it read/write the model, optimizer, and learning rate scheduler states without re-nesting the trainer or breaking the decoupling of actors?
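As a rough illustration of the registration idea only, here's a toy manager that collects export_state/import_state callables and publishes each step folder with a single rename. It does not inherit from Titan's CheckpointManager and does not touch model, optimizer, or scheduler state, so it says nothing about the open question above. ToyCheckpointManager, the staging-folder naming, and the JSON layout are all assumptions made for this sketch.

```python
import json
import os
import tempfile
from pathlib import Path
from typing import Any, Callable

ExportFn = Callable[[], dict[str, Any]]
ImportFn = Callable[[dict[str, Any]], None]


class ToyCheckpointManager:
    def __init__(self, folder: Path) -> None:
        self._folder = folder
        self._hooks: dict[str, tuple[ExportFn, ImportFn]] = {}

    def register(self, name: str, export_state: ExportFn, import_state: ImportFn) -> None:
        self._hooks[name] = (export_state, import_state)

    def save(self, step: int) -> Path:
        final = self._folder / f"step-{step}"
        staging = self._folder / f"step-{step}.tmp"
        staging.mkdir(parents=True)
        for name, (export_state, _) in self._hooks.items():
            (staging / f"{name}.json").write_text(json.dumps(export_state()))
        # Publish the finished folder in one rename, so readers never see a
        # half-written step-<N> directory.
        os.rename(staging, final)
        return final

    def load(self, step: int) -> None:
        step_dir = self._folder / f"step-{step}"
        for name, (_, import_state) in self._hooks.items():
            import_state(json.loads((step_dir / f"{name}.json").read_text()))


buffer: list[int] = [1, 2, 3]  # stand-in for replay buffer contents


def export_buffer() -> dict[str, Any]:
    return {"items": list(buffer)}


def import_buffer(state: dict[str, Any]) -> None:
    buffer[:] = state["items"]


with tempfile.TemporaryDirectory() as tmp:
    manager = ToyCheckpointManager(Path(tmp))
    manager.register("replay_buffer", export_buffer, import_buffer)
    saved_dir = manager.save(step=200)
    names = sorted(p.name for p in saved_dir.iterdir())

    buffer.clear()          # simulate losing in-memory state
    manager.load(step=200)  # restore it from the step folder

assert names == ["replay_buffer.json"]
assert buffer == [1, 2, 3]
```

Registering plain callables keeps the manager unaware of what an actor is, which is the decoupling Option 4 is after.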

That's the plan, folks! We're actively working on improving our checkpointing capabilities to make our models even more robust and reliable. Stay tuned for updates as we continue to refine this process.