diff --git a/src/MaxText/elastic_train.py b/src/MaxText/elastic_train.py
index e2ee2ec958..f8d6facd11 100644
--- a/src/MaxText/elastic_train.py
+++ b/src/MaxText/elastic_train.py
@@ -115,19 +115,15 @@ def elastic_handler(
     checkpoint_manager.close()

   with jax.default_device(elastic_manager.default_device):
-    (
-        init_rng,
-        checkpoint_manager,
-        state_mesh_shardings,
-        model,
-        mesh,
-        learning_rate_schedule,
-        data_iterator,
-        _,
-        _,
-        _,
-        state,
-    ) = setup_train_loop(config, recorder, elastic_manager.good_devices)
+    ctx = setup_train_loop(config, recorder, elastic_manager.good_devices)
+    init_rng = ctx.init_rng
+    checkpoint_manager = ctx.checkpoint_manager
+    state_mesh_shardings = ctx.state_mesh_shardings
+    model = ctx.model
+    mesh = ctx.mesh
+    learning_rate_schedule = ctx.learning_rate_schedule
+    data_iterator = ctx.data_iterator
+    state = ctx.state

     p_train_step, _ = train_utils.jit_train_and_eval_step(config, model, mesh, state, state_mesh_shardings, train_step)
@@ -171,19 +167,15 @@ def elastic_handler(

 def train_loop(config, elastic_manager, recorder, state=None):
   """Main Training loop."""
-  (
-      init_rng,
-      checkpoint_manager,
-      state_mesh_shardings,
-      model,
-      mesh,
-      learning_rate_schedule,
-      data_iterator,
-      _,
-      _,
-      _,
-      state,
-  ) = setup_train_loop(config, recorder)
+  ctx = setup_train_loop(config, recorder)
+  init_rng = ctx.init_rng
+  checkpoint_manager = ctx.checkpoint_manager
+  state_mesh_shardings = ctx.state_mesh_shardings
+  model = ctx.model
+  mesh = ctx.mesh
+  learning_rate_schedule = ctx.learning_rate_schedule
+  data_iterator = ctx.data_iterator
+  state = ctx.state

   p_train_step, _ = train_utils.jit_train_and_eval_step(config, model, mesh, state, state_mesh_shardings, train_step)

   with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
diff --git a/src/MaxText/sft_trainer.py b/src/MaxText/sft_trainer.py
index 272d95d2dc..09f7ec3fea 100644
--- a/src/MaxText/sft_trainer.py
+++ b/src/MaxText/sft_trainer.py
@@ -58,19 +58,16 @@ def train_loop(config, recorder, state=None):
   if not config.use_sft:
     raise TypeError("Set use_sft to True to run Supervised Fine Tuning.")

-  (
-      init_rng,
-      checkpoint_manager,
-      state_mesh_shardings,
-      model,
-      mesh,
-      learning_rate_schedule,
-      data_iterator,
-      _,
-      _,
-      eval_data_iterator,
-      state,
-  ) = setup_train_loop(config, recorder)
+  ctx = setup_train_loop(config, recorder)
+  init_rng = ctx.init_rng
+  checkpoint_manager = ctx.checkpoint_manager
+  state_mesh_shardings = ctx.state_mesh_shardings
+  model = ctx.model
+  mesh = ctx.mesh
+  learning_rate_schedule = ctx.learning_rate_schedule
+  data_iterator = ctx.data_iterator
+  eval_data_iterator = ctx.eval_data_iterator
+  state = ctx.state

   params_shardings, state_mesh_shardings = sharding.maybe_update_params_sharding_with_opt(config, state_mesh_shardings)
diff --git a/src/MaxText/train.py b/src/MaxText/train.py
index 3fae3e056a..f7f067dca7 100644
--- a/src/MaxText/train.py
+++ b/src/MaxText/train.py
@@ -389,52 +389,43 @@ def eval_step(model, config, state, data, dropout_rng):

 def train_loop(config, recorder, state=None):
   """Main Training loop."""
-  (
-      init_rng,
-      checkpoint_manager,
-      state_mesh_shardings,
-      model,
-      mesh,
-      learning_rate_schedule,
-      data_iterator,
-      data_loader,
-      rampup_manager,
-      eval_data_iterator,
-      state,
-  ) = train_utils.setup_train_loop(config, recorder)
+  ctx = train_utils.setup_train_loop(config, recorder)
+  state = ctx.state

   if config.use_dpo:
     if "reference_params" not in state.params:
       reference_params = jax.tree.map(jnp.copy, state.params["params"])
state.params["params"]) state = _merge_dpo_state(state, reference_params) - state_mesh_shardings = _merge_dpo_state(state_mesh_shardings, state_mesh_shardings.params["params"]) + state_mesh_shardings = _merge_dpo_state(ctx.state_mesh_shardings, ctx.state_mesh_shardings.params["params"]) + else: + state_mesh_shardings = ctx.state_mesh_shardings params_shardings, state_mesh_shardings = sharding.maybe_update_params_sharding_with_opt(config, state_mesh_shardings) p_train_step, p_eval_step = train_utils.jit_train_and_eval_step( config, - model, - mesh, + ctx.model, + ctx.mesh, state, state_mesh_shardings, train_step, eval_step, - eval_data_iterator, + ctx.eval_data_iterator, params_shardings, ) - with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): + with jax.set_mesh(ctx.mesh), nn_partitioning.axis_rules(config.logical_axis_rules): shaped_batch = maxtext_utils.get_shaped_batch(config) if config.shard_optimizer_over_data: state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode) if config.compiled_trainstep_file == "": # compile only when there is no pre-compiled file loaded - compiled = p_train_step.lower(state, shaped_batch, init_rng).compile() + compiled = p_train_step.lower(state, shaped_batch, ctx.init_rng).compile() compiled_stats = compiled.memory_analysis() max_utils.print_compiled_memory_stats(compiled_stats) start_step = get_first_step(state) # this is the start_step for training prof = profiler.Profiler(config, offset_step=start_step) - metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule) + metric_logger = MetricLogger(config=config, learning_rate_schedule=ctx.learning_rate_schedule) # Write train config params, num model params, and XLA flags to tensorboard metric_logger.write_setup_info_to_tensorboard(state.params) @@ -445,11 +436,11 @@ def train_loop(config, recorder, state=None): prof.maybe_activate_profiler(step, state) with jax.profiler.StepTraceAnnotation("train", step_num=step): - example_batch = data_loader.load_next_batch(rampup_manager=rampup_manager) + example_batch = ctx.data_loader.load_next_batch(rampup_manager=ctx.rampup_manager) # pylint: disable=not-callable - nextrng = jax.jit(jax.random.fold_in)(init_rng, step) + nextrng = jax.jit(jax.random.fold_in)(ctx.init_rng, step) with maybe_record_goodput(recorder, GoodputEvent.STEP, step): - with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): + with jax.set_mesh(ctx.mesh), nn_partitioning.axis_rules(config.logical_axis_rules): if config.shard_optimizer_over_data: state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode) state, metrics = p_train_step(state, example_batch, nextrng) @@ -458,7 +449,7 @@ def train_loop(config, recorder, state=None): last_step_completion = datetime.datetime.now() state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0] - checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator, step) + checkpointing.maybe_save_checkpoint(ctx.checkpoint_manager, state_to_save, config, ctx.data_iterator, step) if config.dump_hlo and step == (config.dump_step if config.dump_step >= 0 else start_step): jax.block_until_ready(state) # Ensure compilation has finished. 
@@ -471,17 +462,17 @@ def train_loop(config, recorder, state=None):
         )

       if config.eval_interval > 0 and step > start_step and (step + 1) % config.eval_interval == 0:
-        assert eval_data_iterator
+        assert ctx.eval_data_iterator
         # Explicitly reset the eval iterator and counters before starting the eval loop
-        eval_data_iterator.reset()
+        ctx.eval_data_iterator.reset()
         metric_logger.reset_eval_metrics()
         eval_step_count = 0
         # pylint: disable=not-callable
-        for eval_batch in eval_data_iterator:
+        for eval_batch in ctx.eval_data_iterator:
           if config.eval_steps > 0 and eval_step_count >= config.eval_steps:
             break
-          with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
+          with jax.set_mesh(ctx.mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
             eval_metrics = p_eval_step(state, eval_batch, nextrng)
           metric_logger.record_eval_metrics(step, metrics=eval_metrics)
           max_logging.log(f"Completed eval step {eval_step_count}")
@@ -500,10 +491,10 @@ def train_loop(config, recorder, state=None):

     if config.save_checkpoint_on_completion:
       state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0]
-      checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator)
-      if checkpoint_manager is not None:
+      checkpointing.maybe_save_checkpoint(ctx.checkpoint_manager, state_to_save, config, ctx.data_iterator)
+      if ctx.checkpoint_manager is not None:
         # in case the last checkpoint_period checkpoint is still in progress
-        checkpoint_manager.wait_until_finished()
+        ctx.checkpoint_manager.wait_until_finished()
   except exceptions.StopTraining as e:
     max_logging.log(f"Training stopped: {str(e)}")
   finally:
diff --git a/src/MaxText/train_utils.py b/src/MaxText/train_utils.py
index 992091b9c4..80c9b96565 100644
--- a/src/MaxText/train_utils.py
+++ b/src/MaxText/train_utils.py
@@ -16,20 +16,54 @@
 """ Utils that are only interesting for training in MaxText. """

 import os
+from typing import Any, Iterator
+
 import jax
+
+import optax
+
+from pydantic import BaseModel, Field, ConfigDict
+
 from MaxText import checkpointing
 from MaxText import max_logging
 from MaxText import max_utils
 from MaxText import maxtext_utils
-from MaxText import sharding
+from MaxText import model_creation_utils
 from MaxText import optimizers
+from MaxText import sharding
+from MaxText.data_loader import create_dataloader, DataLoader
 from MaxText.dpo_utils import _merge_dpo_state
-from MaxText.data_loader import create_dataloader
-from MaxText.rampup_batch import create_rampup_manager
 from MaxText.input_pipeline.input_pipeline_interface import create_data_iterator
+from MaxText.rampup_batch import create_rampup_manager, RampupBatchManager
 from MaxText.utils.goodput_utils import GoodputEvent
 from MaxText.utils.goodput_utils import maybe_record_goodput
-from MaxText import model_creation_utils
+
+# Fix for Pydantic resolving TrainState annotations
+ArrayTree = Any
+
+
+class TrainContext(BaseModel):
+  """Training context."""
+
+  model_config = ConfigDict(arbitrary_types_allowed=True)
+
+  init_rng: jax.Array = Field(description="PRNG key initialized for the training loop.")
+  checkpoint_manager: checkpointing.CheckpointManager | None = Field(
+      description="Orbax CheckpointManager for saving/restoring checkpoints."
+  )
+  state_mesh_shardings: Any = Field(description="TrainState containing sharding specifications for the model state.")
+  model: Any = Field(description="The initialized Flax (Linen or NNX) model instance.")
+  mesh: jax.sharding.Mesh = Field(description="JAX Mesh object defining the device topology.")
+  learning_rate_schedule: optax.Schedule | None = Field(description="Optax schedule function for learning rate.")
+  data_iterator: Iterator[Any] = Field(description="Iterator for training data.")
+  data_loader: DataLoader = Field(description="DataLoader instance handling sharding and batching.")
+  rampup_manager: RampupBatchManager | None = Field(description="Manager class for handling batch size rampup.")
+  eval_data_iterator: Iterator[Any] | None = Field(description="Iterator for evaluation data.")
+  state: Any = Field(description="Current TrainState containing parameters and optimizer state.")
+
+
+# Explicitly rebuild the model to resolve possible ForwardRefs
+TrainContext.model_rebuild()


 def create_training_tools(config, model, mesh):
@@ -157,25 +191,19 @@ def jit_train_and_eval_step(
   return p_train_step, p_eval_step


-def setup_train_loop(config, recorder, devices=None):
-  """Set up prerequisites for the training loop -
+def setup_train_loop(config, recorder, devices=None) -> TrainContext:
+  """Sets up prerequisites for the training loop.
-  checkpoint_manager, PRNG keys, Mesh, Model and optimizer.
-  Set up data iterator and tokenizer, initialize the model.
+  Sets up checkpoint_manager, PRNG keys, Mesh, Model and optimizer.
+  Sets up data iterator and tokenizer, initializes the model.

-  Args: config recorder
+  Args:
+    config: pyconfig.HyperParameters
+    recorder: GoodputRecorder
+    devices: List of devices to use.

   Returns:
-    init_rng:
-    checkpoint_manager: Orbax checkpointer
-    state_mesh_annotations: the mesh annotations for the train state
-    model:
-    mesh:
-    learning_rate_schedule:
-    data_iterator:
-    data_loader:
-    rampup_manager: the class managing rampup batch sizes
-    state: the initialized train state
+    TrainContext: A Pydantic model containing the training context.
   """

   with maybe_record_goodput(recorder, GoodputEvent.TPU_INIT):
@@ -252,18 +280,18 @@ def setup_train_loop(config, recorder, devices=None):
           "Could not restore reference parameters for DPO from"
           f" '{os.path.join(str(config.checkpoint_dir), str(0))}'"
       )

-  return (
-      init_rng,
-      checkpoint_manager,
-      state_mesh_shardings,
-      model,
-      mesh,
-      learning_rate_schedule,
-      data_iterator,
-      data_loader,
-      rampup_manager,
-      eval_data_iterator,
-      state,
+  return TrainContext(
+      init_rng=init_rng,
+      checkpoint_manager=checkpoint_manager,
+      state_mesh_shardings=state_mesh_shardings,
+      model=model,
+      mesh=mesh,
+      learning_rate_schedule=learning_rate_schedule,
+      data_iterator=data_iterator,
+      data_loader=data_loader,
+      rampup_manager=rampup_manager,
+      eval_data_iterator=eval_data_iterator,
+      state=state,
   )
diff --git a/tests/max_utils_test.py b/tests/max_utils_test.py
index 3e9b1dac6d..e0d7e6c13c 100644
--- a/tests/max_utils_test.py
+++ b/tests/max_utils_test.py
@@ -130,7 +130,8 @@ def test_unscan_train_state_params(self):
     # Initialize a configuration for an 8B model.
     config = self.init_pyconfig()
-    _, _, sharding, _, mesh, *_, state = setup_train_loop(config, None)
+    ctx = setup_train_loop(config, None)
+    sharding, mesh, state = ctx.state_mesh_shardings, ctx.mesh, ctx.state

     scan_axis = config.param_scan_axis
     num_layers = config.base_num_decoder_layers
diff --git a/tools/gcs_benchmarks/standalone_dataloader.py b/tools/gcs_benchmarks/standalone_dataloader.py
index dd789e2e3a..3686887bff 100644
--- a/tools/gcs_benchmarks/standalone_dataloader.py
+++ b/tools/gcs_benchmarks/standalone_dataloader.py
@@ -38,8 +38,9 @@ def data_load_loop(config, state=None):
   """Main data loader loop.

   Loads batches of data for each training step.
   """
-  _, _, _, _, mesh, _, data_iterator, _, _, _, state = setup_train_loop(config, recorder=None)
-  data_loader = DataLoader(config, mesh, data_iterator, None)
+  ctx = setup_train_loop(config, recorder=None)
+  state = ctx.state
+  data_loader = DataLoader(config, ctx.mesh, ctx.data_iterator, None)
   example_batch = None
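For reviewers, a minimal, self-contained sketch of the pattern this change introduces follows: bundling setup outputs in a Pydantic model with arbitrary_types_allowed so callers read fields by name instead of unpacking an 11-tuple. MiniTrainContext, mini_setup, and the placeholder state dict are invented for illustration and are not MaxText code; only the ConfigDict usage and the attribute-style access at the call site mirror the TrainContext added above.

# Illustrative sketch only; not part of the patch.
from typing import Any

import jax
import numpy as np
from pydantic import BaseModel, ConfigDict, Field


class MiniTrainContext(BaseModel):
  """Bundles setup outputs so callers read fields by name, not tuple position."""

  # Pydantic cannot introspect JAX/Flax objects, so store them as opaque values.
  model_config = ConfigDict(arbitrary_types_allowed=True)

  init_rng: jax.Array = Field(description="PRNG key for the training loop.")
  mesh: jax.sharding.Mesh = Field(description="Device mesh.")
  data_iterator: Any = Field(description="Training data iterator (Iterator[Any] in the patch).")
  state: Any = Field(description="Placeholder for a real TrainState.")


def mini_setup() -> MiniTrainContext:
  """Stand-in for setup_train_loop(); returns named fields instead of a tuple."""
  devices = np.array(jax.devices())
  mesh = jax.sharding.Mesh(devices, ("data",))
  return MiniTrainContext(
      init_rng=jax.random.PRNGKey(0),
      mesh=mesh,
      data_iterator=iter([{"inputs": np.zeros((1, 8))}]),
      state={"params": {}},
  )


if __name__ == "__main__":
  ctx = mini_setup()
  # Callers take only what they need; no positional `_, _, ...` bookkeeping.
  mesh, state = ctx.mesh, ctx.state
  print(mesh.shape, jax.random.fold_in(ctx.init_rng, 0))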