44 changes: 18 additions & 26 deletions src/MaxText/elastic_train.py
@@ -115,19 +115,15 @@ def elastic_handler(
checkpoint_manager.close()

with jax.default_device(elastic_manager.default_device):
(
init_rng,
checkpoint_manager,
state_mesh_shardings,
model,
mesh,
learning_rate_schedule,
data_iterator,
_,
_,
_,
state,
) = setup_train_loop(config, recorder, elastic_manager.good_devices)
ctx = setup_train_loop(config, recorder, elastic_manager.good_devices)
init_rng = ctx.init_rng
checkpoint_manager = ctx.checkpoint_manager
state_mesh_shardings = ctx.state_mesh_shardings
model = ctx.model
mesh = ctx.mesh
learning_rate_schedule = ctx.learning_rate_schedule
data_iterator = ctx.data_iterator
state = ctx.state

p_train_step, _ = train_utils.jit_train_and_eval_step(config, model, mesh, state, state_mesh_shardings, train_step)

@@ -171,19 +167,15 @@ def elastic_handler(

def train_loop(config, elastic_manager, recorder, state=None):
"""Main Training loop."""
(
init_rng,
checkpoint_manager,
state_mesh_shardings,
model,
mesh,
learning_rate_schedule,
data_iterator,
_,
_,
_,
state,
) = setup_train_loop(config, recorder)
ctx = setup_train_loop(config, recorder)
init_rng = ctx.init_rng
checkpoint_manager = ctx.checkpoint_manager
state_mesh_shardings = ctx.state_mesh_shardings
model = ctx.model
mesh = ctx.mesh
learning_rate_schedule = ctx.learning_rate_schedule
data_iterator = ctx.data_iterator
state = ctx.state

p_train_step, _ = train_utils.jit_train_and_eval_step(config, model, mesh, state, state_mesh_shardings, train_step)
with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
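For readers skimming the diff, here is a minimal, self-contained sketch (plain Python, hypothetical names, not MaxText code) of the pattern this PR applies at every call site: positional unpacking of an 11-element tuple is replaced by named attribute access on a returned context object.

# Toy illustration only; ToyContext and toy_setup are hypothetical, not MaxText APIs.
from dataclasses import dataclass


@dataclass
class ToyContext:
    model: str
    mesh: str
    state: str


def toy_setup():
    # Stands in for setup_train_loop(), which previously returned an 11-tuple.
    return ToyContext(model="model", mesh="2x2-mesh", state="state-0")


# Before: every positional element had to be unpacked, with `_` for unused ones.
# After: callers read only the fields they need, by name.
ctx = toy_setup()
model, mesh, state = ctx.model, ctx.mesh, ctx.state
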
23 changes: 10 additions & 13 deletions src/MaxText/sft_trainer.py
@@ -58,19 +58,16 @@ def train_loop(config, recorder, state=None):
if not config.use_sft:
raise TypeError("Set use_sft to True to run Supervised Fine Tuning.")

(
init_rng,
checkpoint_manager,
state_mesh_shardings,
model,
mesh,
learning_rate_schedule,
data_iterator,
_,
_,
eval_data_iterator,
state,
) = setup_train_loop(config, recorder)
ctx = setup_train_loop(config, recorder)
init_rng = ctx.init_rng
checkpoint_manager = ctx.checkpoint_manager
state_mesh_shardings = ctx.state_mesh_shardings
model = ctx.model
mesh = ctx.mesh
learning_rate_schedule = ctx.learning_rate_schedule
data_iterator = ctx.data_iterator
eval_data_iterator = ctx.eval_data_iterator
state = ctx.state

params_shardings, state_mesh_shardings = sharding.maybe_update_params_sharding_with_opt(config, state_mesh_shardings)

53 changes: 22 additions & 31 deletions src/MaxText/train.py
@@ -389,52 +389,43 @@ def eval_step(model, config, state, data, dropout_rng):

def train_loop(config, recorder, state=None):
"""Main Training loop."""
(
init_rng,
checkpoint_manager,
state_mesh_shardings,
model,
mesh,
learning_rate_schedule,
data_iterator,
data_loader,
rampup_manager,
eval_data_iterator,
state,
) = train_utils.setup_train_loop(config, recorder)
ctx = train_utils.setup_train_loop(config, recorder)
state = ctx.state

if config.use_dpo:
if "reference_params" not in state.params:
reference_params = jax.tree.map(jnp.copy, state.params["params"])
state = _merge_dpo_state(state, reference_params)
state_mesh_shardings = _merge_dpo_state(state_mesh_shardings, state_mesh_shardings.params["params"])
state_mesh_shardings = _merge_dpo_state(ctx.state_mesh_shardings, ctx.state_mesh_shardings.params["params"])
else:
state_mesh_shardings = ctx.state_mesh_shardings

params_shardings, state_mesh_shardings = sharding.maybe_update_params_sharding_with_opt(config, state_mesh_shardings)

p_train_step, p_eval_step = train_utils.jit_train_and_eval_step(
config,
model,
mesh,
ctx.model,
ctx.mesh,
state,
state_mesh_shardings,
train_step,
eval_step,
eval_data_iterator,
ctx.eval_data_iterator,
params_shardings,
)

with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
with jax.set_mesh(ctx.mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
shaped_batch = maxtext_utils.get_shaped_batch(config)
if config.shard_optimizer_over_data:
state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode)
if config.compiled_trainstep_file == "": # compile only when there is no pre-compiled file loaded
compiled = p_train_step.lower(state, shaped_batch, init_rng).compile()
compiled = p_train_step.lower(state, shaped_batch, ctx.init_rng).compile()
compiled_stats = compiled.memory_analysis()
max_utils.print_compiled_memory_stats(compiled_stats)

start_step = get_first_step(state) # this is the start_step for training
prof = profiler.Profiler(config, offset_step=start_step)
metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule)
metric_logger = MetricLogger(config=config, learning_rate_schedule=ctx.learning_rate_schedule)

# Write train config params, num model params, and XLA flags to tensorboard
metric_logger.write_setup_info_to_tensorboard(state.params)
@@ -445,11 +436,11 @@ def train_loop(config, recorder, state=None):
prof.maybe_activate_profiler(step, state)

with jax.profiler.StepTraceAnnotation("train", step_num=step):
example_batch = data_loader.load_next_batch(rampup_manager=rampup_manager)
example_batch = ctx.data_loader.load_next_batch(rampup_manager=ctx.rampup_manager)
# pylint: disable=not-callable
nextrng = jax.jit(jax.random.fold_in)(init_rng, step)
nextrng = jax.jit(jax.random.fold_in)(ctx.init_rng, step)
with maybe_record_goodput(recorder, GoodputEvent.STEP, step):
with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
with jax.set_mesh(ctx.mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
if config.shard_optimizer_over_data:
state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode)
state, metrics = p_train_step(state, example_batch, nextrng)
@@ -458,7 +449,7 @@ def train_loop(config, recorder, state=None):
last_step_completion = datetime.datetime.now()

state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0]
checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator, step)
checkpointing.maybe_save_checkpoint(ctx.checkpoint_manager, state_to_save, config, ctx.data_iterator, step)

if config.dump_hlo and step == (config.dump_step if config.dump_step >= 0 else start_step):
jax.block_until_ready(state) # Ensure compilation has finished.
@@ -471,17 +462,17 @@ def train_loop(config, recorder, state=None):
)

if config.eval_interval > 0 and step > start_step and (step + 1) % config.eval_interval == 0:
assert eval_data_iterator
assert ctx.eval_data_iterator
# Explicitly reset the eval iterator and counters before starting the eval loop
eval_data_iterator.reset()
ctx.eval_data_iterator.reset()
metric_logger.reset_eval_metrics()

eval_step_count = 0
# pylint: disable=not-callable
for eval_batch in eval_data_iterator:
for eval_batch in ctx.eval_data_iterator:
if config.eval_steps > 0 and eval_step_count >= config.eval_steps:
break
with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
with jax.set_mesh(ctx.mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
eval_metrics = p_eval_step(state, eval_batch, nextrng)
metric_logger.record_eval_metrics(step, metrics=eval_metrics)
max_logging.log(f"Completed eval step {eval_step_count}")
@@ -500,10 +491,10 @@ def train_loop(config, recorder, state=None):

if config.save_checkpoint_on_completion:
state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0]
checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator)
if checkpoint_manager is not None:
checkpointing.maybe_save_checkpoint(ctx.checkpoint_manager, state_to_save, config, ctx.data_iterator)
if ctx.checkpoint_manager is not None:
# in case the last checkpoint_period checkpoint is still in progress
checkpoint_manager.wait_until_finished()
ctx.checkpoint_manager.wait_until_finished()
except exceptions.StopTraining as e:
max_logging.log(f"Training stopped: {str(e)}")
finally:
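As an aside on the loop above (not part of this PR's change), a small sketch of the per-step RNG derivation it relies on: jax.random.fold_in deterministically derives a fresh key from the loop-invariant init_rng and the step index, so no mutated key needs to be threaded through the loop. Shapes and usage below are illustrative only.

# Illustrative only; the real loop feeds the derived key into p_train_step.
import jax

init_rng = jax.random.PRNGKey(0)
for step in range(3):
    # Same derivation as the training loop: one deterministic key per step.
    step_rng = jax.jit(jax.random.fold_in)(init_rng, step)
    noise = jax.random.normal(step_rng, (2,))  # e.g. per-step dropout noise
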
90 changes: 59 additions & 31 deletions src/MaxText/train_utils.py
@@ -16,20 +16,54 @@
""" Utils that are only interesting for training in MaxText. """

import os
from typing import Any, Iterator

import jax

import optax

from pydantic import BaseModel, Field, ConfigDict

from MaxText import checkpointing
from MaxText import max_logging
from MaxText import max_utils
from MaxText import maxtext_utils
from MaxText import sharding
from MaxText import model_creation_utils
from MaxText import optimizers
from MaxText import sharding
from MaxText.data_loader import create_dataloader, DataLoader
from MaxText.dpo_utils import _merge_dpo_state
from MaxText.data_loader import create_dataloader
from MaxText.rampup_batch import create_rampup_manager
from MaxText.input_pipeline.input_pipeline_interface import create_data_iterator
from MaxText.rampup_batch import create_rampup_manager, RampupBatchManager
from MaxText.utils.goodput_utils import GoodputEvent
from MaxText.utils.goodput_utils import maybe_record_goodput
from MaxText import model_creation_utils

# Fix for Pydantic resolving TrainState annotations
ArrayTree = Any


class TrainContext(BaseModel):
"""Training context."""

model_config = ConfigDict(arbitrary_types_allowed=True)

init_rng: jax.Array = Field(description="PRNG key initialized for the training loop.")
checkpoint_manager: checkpointing.CheckpointManager | None = Field(
description="Orbax CheckpointManager for saving/restoring checkpoints."
)
state_mesh_shardings: Any = Field(description="TrainState containing sharding specifications for the model state.")
model: Any = Field(description="The initialized Flax (Linen or NNX) model instance.")
mesh: jax.sharding.Mesh = Field(description="JAX Mesh object defining the device topology.")
learning_rate_schedule: optax.Schedule | None = Field(description="Optax schedule function for learning rate.")
data_iterator: Iterator[Any] = Field(description="Iterator for training data.")
data_loader: DataLoader = Field(description="DataLoader instance handling sharding and batching.")
rampup_manager: RampupBatchManager | None = Field(description="Manager for handling batch-size ramp-up.")
eval_data_iterator: Iterator[Any] | None = Field(description="Iterator for evaluation data.")
state: Any = Field(description="Current TrainState containing parameters and optimizer state.")


# Explicitly rebuild the model to resolve possible ForwardRefs
TrainContext.model_rebuild()


def create_training_tools(config, model, mesh):
@@ -157,25 +191,19 @@ def jit_train_and_eval_step(
return p_train_step, p_eval_step


def setup_train_loop(config, recorder, devices=None):
"""Set up prerequisites for the training loop -
def setup_train_loop(config, recorder, devices=None) -> TrainContext:
"""Sets up prerequisites for the training loop.

checkpoint_manager, PRNG keys, Mesh, Model and optimizer.
Set up data iterator and tokenizer, initialize the model.
Sets up checkpoint_manager, PRNG keys, Mesh, Model and optimizer.
Sets up data iterator and tokenizer, initializes the model.

Args: config recorder
Args:
config: pyconfig.HyperParameters
recorder: GoodputRecorder
devices: List of devices to use.

Returns:
init_rng:
checkpoint_manager: Orbax checkpointer
state_mesh_annotations: the mesh annotations for the train state
model:
mesh:
learning_rate_schedule:
data_iterator:
data_loader:
rampup_manager: the class managing rampup batch sizes
state: the initialized train state
TrainContext: A Pydantic model containing the initialized objects needed by the training loop.
"""

with maybe_record_goodput(recorder, GoodputEvent.TPU_INIT):
@@ -252,18 +280,18 @@ def setup_train_loop(config, recorder, devices=None):
"Could not restore reference parameters for DPO from" f" '{os.path.join(str(config.checkpoint_dir), str(0))}'"
)

return (
init_rng,
checkpoint_manager,
state_mesh_shardings,
model,
mesh,
learning_rate_schedule,
data_iterator,
data_loader,
rampup_manager,
eval_data_iterator,
state,
return TrainContext(
init_rng=init_rng,
checkpoint_manager=checkpoint_manager,
state_mesh_shardings=state_mesh_shardings,
model=model,
mesh=mesh,
learning_rate_schedule=learning_rate_schedule,
data_iterator=data_iterator,
data_loader=data_loader,
rampup_manager=rampup_manager,
eval_data_iterator=eval_data_iterator,
state=state,
)


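The new TrainContext leans on two Pydantic features: ConfigDict(arbitrary_types_allowed=True), so the model can hold non-Pydantic objects such as JAX meshes, iterators, and the train state, and model_rebuild(), which resolves forward references once module-level aliases such as ArrayTree exist. A minimal standalone sketch of that pattern (toy class names, not the MaxText definitions):

from typing import Any

from pydantic import BaseModel, ConfigDict, Field


class Mesh:  # stands in for an arbitrary framework object (e.g. jax.sharding.Mesh)
    pass


class ToyContext(BaseModel):
    """Holds both validated fields and opaque framework objects."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    mesh: Mesh = Field(description="Checked only with isinstance(), not coerced.")
    state: Any = Field(description="Opaque object; accepted as-is.")
    step: int = Field(default=0, description="Ordinary validated field.")


ToyContext.model_rebuild()  # no-op here; resolves ForwardRefs when annotations are strings

ctx = ToyContext(mesh=Mesh(), state={"params": {}})
assert ctx.step == 0
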
3 changes: 2 additions & 1 deletion tests/max_utils_test.py
@@ -130,7 +130,8 @@ def test_unscan_train_state_params(self):
# Initialize a configuration for an 8B model.
config = self.init_pyconfig()

_, _, sharding, _, mesh, *_, state = setup_train_loop(config, None)
ctx = setup_train_loop(config, None)
sharding, mesh, state = ctx.state_mesh_shardings, ctx.mesh, ctx.state

scan_axis = config.param_scan_axis
num_layers = config.base_num_decoder_layers
5 changes: 3 additions & 2 deletions tools/gcs_benchmarks/standalone_dataloader.py
@@ -38,8 +38,9 @@ def data_load_loop(config, state=None):
"""Main data loader loop.
Loads batches of data for each training step.
"""
_, _, _, _, mesh, _, data_iterator, _, _, _, state = setup_train_loop(config, recorder=None)
data_loader = DataLoader(config, mesh, data_iterator, None)
ctx = setup_train_loop(config, recorder=None)
state = ctx.state
data_loader = DataLoader(config, ctx.mesh, ctx.data_iterator, None)

example_batch = None
