From c7782082a21dcd7227cdba8f2eb28e33244afb03 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Mon, 1 Dec 2025 14:42:04 -0800 Subject: [PATCH 01/24] PoC of the changes Signed-off-by: Przemek Tredak --- tests/pytorch/attention/test_attention.py | 2 +- .../dot_product_attention.py | 2 +- transformer_engine/pytorch/module/base.py | 108 +++++---- .../pytorch/module/grouped_linear.py | 110 ++++----- .../pytorch/module/layernorm_linear.py | 156 ++++++------- .../pytorch/module/layernorm_mlp.py | 214 +++++++++--------- transformer_engine/pytorch/module/linear.py | 142 ++++++------ 7 files changed, 380 insertions(+), 354 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index eb7905bcd5..b29d1289f8 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -2751,7 +2751,7 @@ def forward( cu_seqlens, max_s, ) -> torch.Tensor: - with self.prepare_forward(inp, num_gemms=3) as inp: + with self.prepare_forward_ctx(inp, num_gemms=3) as inp: out = _custom_mha_fp8.apply( inp, self.qkv_weight, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 6e5a12a103..384a163b6f 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -1000,7 +1000,7 @@ def forward( cases. It is ignored for other backends and when context parallelism is enabled. """ - with self.prepare_forward( + with self.prepare_forward_ctx( query_layer, num_gemms=3, allow_non_contiguous=True, diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index ad5cd04341..c1e24871ac 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -50,6 +50,7 @@ is_non_tn_fp8_gemm_supported, torch_get_autocast_gpu_dtype, get_nvtx_range_context, + _nvtx_enabled, ) from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage from ...common.recipe import DelayedScaling, Recipe @@ -641,12 +642,15 @@ def __init__(self) -> None: "fp8_parameters", } + def fast_set_attr(self, name: str, value: Any) -> None: + self.__dict__[name] = value + def __setattr__(self, name: str, value: Any) -> None: if name in TransformerEngineBaseModule._fast_setattr_names: # torch.nn.Module has a custom __setattr__ that handles # modules, parameters, and buffers. This is unnecessary # overhead when setting plain attrs. - self.__dict__[name] = value + self.fast_set_attr(name, value) else: # Default case super().__setattr__(name, value) @@ -927,7 +931,7 @@ def set_activation_dtype(self, inp: torch.Tensor) -> None: """Get activation data type for AMP.""" # Native AMP (`torch.autocast`) gets highest priority if torch.is_autocast_enabled(): - self.activation_dtype = torch_get_autocast_gpu_dtype() + self.fast_set_attr("activation_dtype", torch_get_autocast_gpu_dtype()) return # All checks after this have already been performed once, thus skip @@ -942,7 +946,7 @@ def set_activation_dtype(self, inp: torch.Tensor) -> None: "Data types for parameters must match when outside of autocasted region. 
" f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}" ) - self.activation_dtype = dtype + self.fast_set_attr("activation_dtype", dtype) def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None: """ @@ -971,48 +975,51 @@ def _get_fp8_params(self) -> Union[List[torch.Tensor], None]: # assume FP8 execution. def init_fp8_metadata(self, num_gemms: int = 1) -> None: """Initialize fp8 related metadata and tensors during fprop.""" - _original_recipe = self.fp8_meta.get("recipe", None) - - self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters() - self.fp8 = FP8GlobalStateManager.is_fp8_enabled() - self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration() - fp8_enabled = self.fp8 or self.fp8_calibration - self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration - - if self.fp8_parameters or fp8_enabled: - if ( - self.fp8_initialized - and FP8GlobalStateManager.get_fp8_recipe() == self.fp8_meta["recipe"] - ): + meta = self.fp8_meta + + fp8 = FP8GlobalStateManager.is_fp8_enabled() + fp8_parameters = FP8GlobalStateManager.with_fp8_parameters() + fp8_calibration = FP8GlobalStateManager.is_fp8_calibration() + self.fast_set_attr("fp8_parameters", fp8_parameters) + self.fast_set_attr("fp8", fp8) + self.fast_set_attr("fp8_calibration", fp8_calibration) + fp8_enabled = fp8 or fp8_calibration + meta["fp8_checkpoint"] = fp8_enabled + + _original_recipe = None + + if fp8_parameters or fp8_enabled: + _original_recipe = meta.get("recipe", None) + if self.fp8_initialized and FP8GlobalStateManager.get_fp8_recipe() == _original_recipe: # FP8 init has already been run and recipe is the same, don't do anything. return - self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() + meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() else: # If fp8 isn't enabled, turn off and return. 
- self.fp8_initialized = False + self.fast_set_attr("fp8_initialized", False) return - if self.fp8_parameters and not self.fp8_initialized: - self.fp8_meta["num_gemms"] = num_gemms - self.init_fp8_meta_tensors(self.fp8_meta["recipe"]) + if fp8_parameters and not self.fp8_initialized: + meta["num_gemms"] = num_gemms + self.init_fp8_meta_tensors(meta["recipe"]) if fp8_enabled: # Set FP8 and other FP8 metadata - self.fp8_meta["num_gemms"] = num_gemms - self.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() + meta["num_gemms"] = num_gemms + meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() # Set FP8_MAX per tensor according to recipe - if hasattr(self.fp8_meta["recipe"], "fp8_format"): - self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd - self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd + if hasattr(meta["recipe"], "fp8_format"): + meta["fp8_max_fwd"] = meta["recipe"].fp8_format.value.max_fwd + meta["fp8_max_bwd"] = meta["recipe"].fp8_format.value.max_bwd # Allocate scales and amaxes - self.init_fp8_meta_tensors(self.fp8_meta["recipe"]) + self.init_fp8_meta_tensors(meta["recipe"]) self.fp8_initialized = True - self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() + meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() - _current_recipe = self.fp8_meta["recipe"] + _current_recipe = meta["recipe"] if _original_recipe is not None and not ( issubclass(_current_recipe.__class__, _original_recipe.__class__) or issubclass(_original_recipe.__class__, _current_recipe.__class__) @@ -1025,22 +1032,18 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: # Clear cached workspaces as they were created with the old recipe/quantizer type self._fp8_workspaces.clear() - @contextmanager def prepare_forward( self, inp: torch.Tensor, num_gemms: int = 1, allow_non_contiguous: bool = False, allow_different_data_and_param_types: bool = False, - ) -> Generator[torch.Tensor, None, None]: - """Checks and prep for FWD. - The context manager is needed because there isn't a way for a module to know - if it's the last FP8 module in the forward autocast. It is useful - to setup the forward aggregated amax reduction for every module - just in case. The autocast exit will pick up the most recent one. - """ - self.allow_different_data_and_param_types = allow_different_data_and_param_types - self.forwarded_at_least_once = True + ) -> torch.Tensor: + """Checks and prepare for FWD execution.""" + self.fast_set_attr( + "allow_different_data_and_param_types", allow_different_data_and_param_types + ) + self.fast_set_attr("forwarded_at_least_once", True) # Activation recomputation is used and this is the second forward phase. 
if self.fp8 and in_fp8_activation_recompute_phase(): @@ -1071,13 +1074,32 @@ def prepare_forward( if self.training and is_fp8_activation_recompute_enabled(): FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta) - with get_nvtx_range_context(self.__class__.__name__ + " forward"): - if not allow_non_contiguous and not inp.is_contiguous(): - inp = inp.contiguous() - yield inp + # with get_nvtx_range_context(self.__class__.__name__ + " forward"): + if _nvtx_enabled(): + torch.cuda.nvtx.range_push(self.__class__.__name__ + " forward") + if not allow_non_contiguous and not inp.is_contiguous(): + inp = inp.contiguous() + return inp + def end_forward(self): + delayed_scaling_recipe = self.fp8 and self.fp8_meta["recipe"].delayed() if delayed_scaling_recipe and self.fp8 and in_fp8_activation_recompute_phase(): FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta) + if _nvtx_enabled(): + torch.cuda.nvtx.range_pop() + + @contextmanager + def prepare_forward_ctx( + self, + inp: torch.Tensor, + num_gemms: int = 1, + allow_non_contiguous: bool = False, + allow_different_data_and_param_types: bool = False, + ) -> Generator[torch.Tensor, None, None]: + yield self.prepare_forward( + inp, num_gemms, allow_non_contiguous, allow_different_data_and_param_types + ) + self.end_forward() def set_nccl_overlap_warning_if_tp(self) -> None: """When using TP, the NCCL communication needs to be scheduled diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 1e6f0b00ab..9f7f01dbb1 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -793,60 +793,62 @@ def forward( is_grad_enabled = torch.is_grad_enabled() - with self.prepare_forward(inp, num_gemms=self.num_gemms) as inp: - weight_tensors = self._get_weight_tensors() - bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] - - quantizers = self._get_quantizers() if not debug else self._get_debug_quantizers() - - if debug: - if self.no_debug_features_active(list(chain(*quantizers))): - debug = False - quantizers = self._get_quantizers() - - if isinstance(weight_tensors, QuantizedTensorStorage): - raise RuntimeError("FP8 weights are not supported in debug mode.") - - ( - input_quantizers, - weight_quantizers, - output_quantizers, - grad_input_quantizers, - grad_weight_quantizers, - grad_output_quantizers, - ) = quantizers - - if is_grad_enabled: - linear_fn = _GroupedLinear.apply - autograd_ctx = [] - else: - linear_fn = _GroupedLinear.forward - autograd_ctx = [None] - - non_tensor_args = ( - m_splits, - self.apply_bias, - is_first_microbatch, - self.fp8, - self.fp8_calibration, - self.wgrad_store, - input_quantizers, - weight_quantizers, - output_quantizers, - grad_input_quantizers, - grad_weight_quantizers, - grad_output_quantizers, - self.fuse_wgrad_accumulation, - is_cpu_offload_enabled(), - self.sequence_parallel, - self.activation_dtype, - is_grad_enabled, - self, - None, # skip_fp8_weight_update - self.save_original_input, - debug, - ) - out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors) + inp = self.prepare_forward(inp, num_gemms=self.num_gemms) + weight_tensors = self._get_weight_tensors() + bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] + + quantizers = self._get_quantizers() if not debug else self._get_debug_quantizers() + + if debug: + if self.no_debug_features_active(list(chain(*quantizers))): + debug = 
False + quantizers = self._get_quantizers() + + if isinstance(weight_tensors, QuantizedTensorStorage): + raise RuntimeError("FP8 weights are not supported in debug mode.") + + ( + input_quantizers, + weight_quantizers, + output_quantizers, + grad_input_quantizers, + grad_weight_quantizers, + grad_output_quantizers, + ) = quantizers + + if is_grad_enabled: + linear_fn = _GroupedLinear.apply + autograd_ctx = [] + else: + linear_fn = _GroupedLinear.forward + autograd_ctx = [None] + + non_tensor_args = ( + m_splits, + self.apply_bias, + is_first_microbatch, + self.fp8, + self.fp8_calibration, + self.wgrad_store, + input_quantizers, + weight_quantizers, + output_quantizers, + grad_input_quantizers, + grad_weight_quantizers, + grad_output_quantizers, + self.fuse_wgrad_accumulation, + is_cpu_offload_enabled(), + self.sequence_parallel, + self.activation_dtype, + is_grad_enabled, + self, + None, # skip_fp8_weight_update + self.save_original_input, + debug, + ) + out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors) + + self.end_forward() if self.return_bias: return out, [cast_if_needed(b, self.activation_dtype) for b in bias_tensors] diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 13b94f2327..2322fff935 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -1514,87 +1514,89 @@ def forward( ).is_fp8_ubuf(): fp8_grad = True - with self.prepare_forward( + inp = self.prepare_forward( inp, allow_non_contiguous=False # removed .contiguous from inside the layer - ) as inp: + ) - # Get concatenated weight and bias tensors - weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() + # Get concatenated weight and bias tensors + weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() - quantizers = ( - self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) - if not debug - else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) - ) - if debug: - if self.no_debug_features_active(quantizers): - debug = False - quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) + quantizers = ( + self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) + if not debug + else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) + ) + if debug: + if self.no_debug_features_active(quantizers): + debug = False + quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) - ( - input_quantizer, - weight_quantizer, - output_quantizer, - grad_input_quantizer, - grad_weight_quantizer, - grad_output_quantizer, - ) = quantizers - - if is_grad_enabled: - fwd_fn = _LayerNormLinear.apply - autograd_ctx = [] - else: - fwd_fn = _LayerNormLinear.forward - autograd_ctx = [None] - non_tensor_args = ( - self.eps, - is_first_microbatch, - self.fp8, - self.fp8_calibration, - self.wgrad_store, - self.fuse_wgrad_accumulation, - input_quantizer, - weight_quantizer, - output_quantizer, - grad_input_quantizer, - grad_weight_quantizer, - grad_output_quantizer, - is_cpu_offload_enabled(), - self.tp_group, - self.tp_size, - self.sequence_parallel, - self.tp_size > 1, - self.activation_dtype, - self.parallel_mode, - self.return_layernorm_output, - self.return_layernorm_output_gathered, - is_grad_enabled, - self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin, - self.bwd_ln_sm_margin, - self.zero_centered_gamma, - self.normalization, - self.ub_overlap_ag_fprop, - 
self.ub_overlap_rs_fprop, - self.ub_overlap_ag_dgrad, - self.ub_overlap_rs_dgrad, - self.ub_bulk_wgrad, - self.ub_bulk_dgrad, - self.ub_name, - self.fsdp_group, - self, - skip_fp8_weight_update, - self.symmetric_ar_type, - debug, - ) - out = fwd_fn( - *autograd_ctx, - inp, - self.layer_norm_weight, - self.layer_norm_bias, - weight_tensor, - bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None, - non_tensor_args, - ) + ( + input_quantizer, + weight_quantizer, + output_quantizer, + grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, + ) = quantizers + + if is_grad_enabled: + fwd_fn = _LayerNormLinear.apply + autograd_ctx = [] + else: + fwd_fn = _LayerNormLinear.forward + autograd_ctx = [None] + non_tensor_args = ( + self.eps, + is_first_microbatch, + self.fp8, + self.fp8_calibration, + self.wgrad_store, + self.fuse_wgrad_accumulation, + input_quantizer, + weight_quantizer, + output_quantizer, + grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, + is_cpu_offload_enabled(), + self.tp_group, + self.tp_size, + self.sequence_parallel, + self.tp_size > 1, + self.activation_dtype, + self.parallel_mode, + self.return_layernorm_output, + self.return_layernorm_output_gathered, + is_grad_enabled, + self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin, + self.bwd_ln_sm_margin, + self.zero_centered_gamma, + self.normalization, + self.ub_overlap_ag_fprop, + self.ub_overlap_rs_fprop, + self.ub_overlap_ag_dgrad, + self.ub_overlap_rs_dgrad, + self.ub_bulk_wgrad, + self.ub_bulk_dgrad, + self.ub_name, + self.fsdp_group, + self, + skip_fp8_weight_update, + self.symmetric_ar_type, + debug, + ) + out = fwd_fn( + *autograd_ctx, + inp, + self.layer_norm_weight, + self.layer_norm_bias, + weight_tensor, + bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None, + non_tensor_args, + ) + + self.end_forward() if self.return_layernorm_output: out, ln_out = out diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 4256028c8b..b40f933b16 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -2052,115 +2052,117 @@ def forward( if get_ub("fc2_fprop", FP8GlobalStateManager.is_fp8_enabled()).is_fp8_ubuf(): fp8_output = True - with self.prepare_forward(inp, num_gemms=2) as inp: + inp = self.prepare_forward(inp, num_gemms=2) - quantizers = ( - self._get_quantizers(fp8_output, is_grad_enabled) - if not debug - else self._get_debug_quantizers(fp8_output, is_grad_enabled) - ) - if debug: - if self.no_debug_features_active(quantizers): - debug = False - quantizers = self._get_quantizers(fp8_output, is_grad_enabled) + quantizers = ( + self._get_quantizers(fp8_output, is_grad_enabled) + if not debug + else self._get_debug_quantizers(fp8_output, is_grad_enabled) + ) + if debug: + if self.no_debug_features_active(quantizers): + debug = False + quantizers = self._get_quantizers(fp8_output, is_grad_enabled) - # Get quantizers - ( - fc1_input_quantizer, - fc1_weight_quantizer, - fc1_output_quantizer, - fc1_grad_input_quantizer, - fc1_grad_weight_quantizer, - fc1_grad_output_quantizer, - fc2_input_quantizer, - fc2_weight_quantizer, - fc2_output_quantizer, - fc2_grad_input_quantizer, - fc2_grad_weight_quantizer, - fc2_grad_output_quantizer, - ) = quantizers - - # Get weight tensors - fc1_weight, fc2_weight = self._get_weight_tensors() - fc1_bias = self.fc1_bias if self.use_bias else None - fc2_bias = 
self.fc2_bias if self.use_bias else None - if not self.fp8: - if isinstance(fc1_weight, Float8Tensor): - fc1_weight = fc1_weight.dequantize() - if isinstance(fc2_weight, Float8Tensor): - fc2_weight = fc2_weight.dequantize() - - # Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode - if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute(): - self.bias_gelu_nvfusion = False + # Get quantizers + ( + fc1_input_quantizer, + fc1_weight_quantizer, + fc1_output_quantizer, + fc1_grad_input_quantizer, + fc1_grad_weight_quantizer, + fc1_grad_output_quantizer, + fc2_input_quantizer, + fc2_weight_quantizer, + fc2_output_quantizer, + fc2_grad_input_quantizer, + fc2_grad_weight_quantizer, + fc2_grad_output_quantizer, + ) = quantizers - if is_grad_enabled: - fwd_fn = _LayerNormMLP.apply - autograd_ctx = [] - else: - fwd_fn = _LayerNormMLP.forward - autograd_ctx = [None] - - non_tensor_args = ( - self.eps, - is_first_microbatch, - self.fp8, - self.fp8_calibration, - self.wgrad_store, - self.fuse_wgrad_accumulation, - fc1_input_quantizer, - fc1_weight_quantizer, - fc1_output_quantizer, - fc1_grad_input_quantizer, - fc1_grad_weight_quantizer, - fc1_grad_output_quantizer, - fc2_input_quantizer, - fc2_weight_quantizer, - fc2_output_quantizer, - fc2_grad_input_quantizer, - fc2_grad_weight_quantizer, - fc2_grad_output_quantizer, - is_cpu_offload_enabled(), - self.tp_group, - self.tp_size, - self.sequence_parallel, - self.tp_size > 1, - self.activation_dtype, - self.return_layernorm_output, - self.return_layernorm_output_gathered, - self.bias_gelu_nvfusion and not self.fp8 and not debug, - self.set_parallel_mode, - is_grad_enabled, - self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin, - self.bwd_ln_sm_margin, - self.zero_centered_gamma, - self.activation, - self.activation_params, - self.normalization, - self.ub_overlap_ag, - self.ub_overlap_rs, - self.ub_overlap_rs_dgrad, - self.ub_bulk_dgrad, - self.ub_bulk_wgrad, - self.gemm_gelu_fusion and not debug, - self.fsdp_group, - self, - skip_fp8_weight_update, - self.symmetric_ar_type, - self.checkpoint, - debug, - ) - out = fwd_fn( - *autograd_ctx, - inp, - self.layer_norm_weight, - self.layer_norm_bias, - fc1_weight, - fc1_bias, - fc2_weight, - fc2_bias if self.apply_bias and not self.gemm_bias_unfused_add else None, - non_tensor_args, - ) + # Get weight tensors + fc1_weight, fc2_weight = self._get_weight_tensors() + fc1_bias = self.fc1_bias if self.use_bias else None + fc2_bias = self.fc2_bias if self.use_bias else None + if not self.fp8: + if isinstance(fc1_weight, Float8Tensor): + fc1_weight = fc1_weight.dequantize() + if isinstance(fc2_weight, Float8Tensor): + fc2_weight = fc2_weight.dequantize() + + # Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode + if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute(): + self.bias_gelu_nvfusion = False + + if is_grad_enabled: + fwd_fn = _LayerNormMLP.apply + autograd_ctx = [] + else: + fwd_fn = _LayerNormMLP.forward + autograd_ctx = [None] + + non_tensor_args = ( + self.eps, + is_first_microbatch, + self.fp8, + self.fp8_calibration, + self.wgrad_store, + self.fuse_wgrad_accumulation, + fc1_input_quantizer, + fc1_weight_quantizer, + fc1_output_quantizer, + fc1_grad_input_quantizer, + fc1_grad_weight_quantizer, + fc1_grad_output_quantizer, + fc2_input_quantizer, + fc2_weight_quantizer, + fc2_output_quantizer, + fc2_grad_input_quantizer, + fc2_grad_weight_quantizer, + fc2_grad_output_quantizer, + 
is_cpu_offload_enabled(), + self.tp_group, + self.tp_size, + self.sequence_parallel, + self.tp_size > 1, + self.activation_dtype, + self.return_layernorm_output, + self.return_layernorm_output_gathered, + self.bias_gelu_nvfusion and not self.fp8 and not debug, + self.set_parallel_mode, + is_grad_enabled, + self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin, + self.bwd_ln_sm_margin, + self.zero_centered_gamma, + self.activation, + self.activation_params, + self.normalization, + self.ub_overlap_ag, + self.ub_overlap_rs, + self.ub_overlap_rs_dgrad, + self.ub_bulk_dgrad, + self.ub_bulk_wgrad, + self.gemm_gelu_fusion and not debug, + self.fsdp_group, + self, + skip_fp8_weight_update, + self.symmetric_ar_type, + self.checkpoint, + debug, + ) + out = fwd_fn( + *autograd_ctx, + inp, + self.layer_norm_weight, + self.layer_norm_bias, + fc1_weight, + fc1_bias, + fc2_weight, + fc2_bias if self.apply_bias and not self.gemm_bias_unfused_add else None, + non_tensor_args, + ) + + self.end_forward() if self.return_layernorm_output: out, ln_out = out diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index b8349f84a0..fd9129765f 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -1398,81 +1398,79 @@ def forward( ).is_fp8_ubuf(): fp8_grad = True - with self.prepare_forward( - inp, - allow_non_contiguous=isinstance(inp, QuantizedTensor), - ) as inp: + inp = self.prepare_forward(inp, allow_non_contiguous=isinstance(inp, QuantizedTensor)) - weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() + weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() - quantizers = ( - self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) - if not debug - else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) - ) - if debug: - if self.no_debug_features_active(quantizers): - debug = False - quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) + quantizers = ( + self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) + if not debug + else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) + ) + if debug: + if self.no_debug_features_active(quantizers): + debug = False + quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) - ( - input_quantizer, - weight_quantizer, - output_quantizer, - grad_input_quantizer, - grad_weight_quantizer, - grad_output_quantizer, - ) = quantizers - - if is_grad_enabled: - linear_fn = _Linear.apply - autograd_ctx = [] - else: - linear_fn = _Linear.forward - autograd_ctx = [None] - - non_tensor_args = ( - is_first_microbatch, - self.fp8, - self.fp8_calibration, - self.wgrad_store, - input_quantizer, - weight_quantizer, - output_quantizer, - grad_input_quantizer, - grad_weight_quantizer, - grad_output_quantizer, - self.fuse_wgrad_accumulation, - is_cpu_offload_enabled(), - self.tp_group, - self.tp_size, - self.sequence_parallel, - self.tp_size > 1, - self.activation_dtype, - self.parallel_mode, - is_grad_enabled, - self.ub_overlap_rs_fprop, - self.ub_overlap_ag_dgrad, - self.ub_overlap_ag_fprop, - self.ub_overlap_rs_dgrad, - self.ub_bulk_dgrad, - self.ub_bulk_wgrad, - self.ub_name, - fp8_output, - self.fsdp_group, - self, - skip_fp8_weight_update, - self.symmetric_ar_type, - self.save_original_input, - debug, - ) - out = linear_fn( - *autograd_ctx, - weight_tensor, - inp, - bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None, - non_tensor_args, - ) + ( + 
input_quantizer, + weight_quantizer, + output_quantizer, + grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, + ) = quantizers + + if is_grad_enabled: + linear_fn = _Linear.apply + autograd_ctx = [] + else: + linear_fn = _Linear.forward + autograd_ctx = [None] + + non_tensor_args = ( + is_first_microbatch, + self.fp8, + self.fp8_calibration, + self.wgrad_store, + input_quantizer, + weight_quantizer, + output_quantizer, + grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, + self.fuse_wgrad_accumulation, + is_cpu_offload_enabled(), + self.tp_group, + self.tp_size, + self.sequence_parallel, + self.tp_size > 1, + self.activation_dtype, + self.parallel_mode, + is_grad_enabled, + self.ub_overlap_rs_fprop, + self.ub_overlap_ag_dgrad, + self.ub_overlap_ag_fprop, + self.ub_overlap_rs_dgrad, + self.ub_bulk_dgrad, + self.ub_bulk_wgrad, + self.ub_name, + fp8_output, + self.fsdp_group, + self, + skip_fp8_weight_update, + self.symmetric_ar_type, + self.save_original_input, + debug, + ) + out = linear_fn( + *autograd_ctx, + weight_tensor, + inp, + bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None, + non_tensor_args, + ) + self.end_forward() if self.gemm_bias_unfused_add: out = out + cast_if_needed(bias_tensor, self.activation_dtype) From 1b1c574a2601ea2e0fed134b954745d14651a556 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Tue, 2 Dec 2025 14:45:07 -0800 Subject: [PATCH 02/24] Early exit from the Free function for the empty tensor Signed-off-by: Przemek Tredak --- transformer_engine/common/transformer_engine.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 370d9723cf..11c8af92f0 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -450,9 +450,9 @@ class TensorAllocator { } void Free(NVTETensor t) { - std::lock_guard lock(mutex); uintptr_t index = reinterpret_cast(t); if (index == 0) return; + std::lock_guard lock(mutex); NVTE_CHECK(index <= memory.size(), "Invalid tensor."); free_list.push_back(index); // Clean up @@ -560,9 +560,9 @@ class GroupedTensorAllocator { } void Free(NVTEGroupedTensor t) { - std::lock_guard lock(mutex); uintptr_t index = reinterpret_cast(t); if (index == 0) return; + std::lock_guard lock(mutex); NVTE_CHECK(index <= memory.size(), "Invalid grouped tensor."); free_list.push_back(index); // Clean up From ac14119d6cc723eb3d513de1c09171a7ca95412d Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Tue, 2 Dec 2025 14:49:01 -0800 Subject: [PATCH 03/24] Use the proper function for nvtx range Signed-off-by: Przemek Tredak --- transformer_engine/pytorch/module/base.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index c1e24871ac..deea09702c 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -50,7 +50,8 @@ is_non_tn_fp8_gemm_supported, torch_get_autocast_gpu_dtype, get_nvtx_range_context, - _nvtx_enabled, + nvtx_range_push, + nvtx_range_pop, ) from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage from ...common.recipe import DelayedScaling, Recipe @@ -1074,9 +1075,7 @@ def prepare_forward( if self.training and is_fp8_activation_recompute_enabled(): 
FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta) - # with get_nvtx_range_context(self.__class__.__name__ + " forward"): - if _nvtx_enabled(): - torch.cuda.nvtx.range_push(self.__class__.__name__ + " forward") + nvtx_range_push(self.__class__.__name__ + " forward") if not allow_non_contiguous and not inp.is_contiguous(): inp = inp.contiguous() return inp @@ -1085,8 +1084,7 @@ def end_forward(self): delayed_scaling_recipe = self.fp8 and self.fp8_meta["recipe"].delayed() if delayed_scaling_recipe and self.fp8 and in_fp8_activation_recompute_phase(): FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta) - if _nvtx_enabled(): - torch.cuda.nvtx.range_pop() + nvtx_range_pop() @contextmanager def prepare_forward_ctx( From 616af19dc1bc4aa27b38c05f3f3088a54f6abbd8 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Fri, 12 Dec 2025 11:06:39 -0800 Subject: [PATCH 04/24] Only do mark_not_offload when the cpu_offloading is enabled Signed-off-by: Przemek Tredak --- transformer_engine/pytorch/module/linear.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index fd9129765f..6adf479d82 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -428,8 +428,8 @@ def forward( # weights if weights are externally touched outside this module ctx.weight_object = weight - if cpu_offloading: mark_not_offload(weight, weightmat, bias) + # TODO(ksivamani): Check memory usage tensors_to_save, tensor_objects = prepare_for_saving( saved_inputmat, From c0d2cccb916032eb2ac509832b1a9beb881660db Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Mon, 15 Dec 2025 15:40:23 -0800 Subject: [PATCH 05/24] First pass on making the setattr issue not come back Signed-off-by: Przemek Tredak --- qa/L0_pytorch_debug_unittest/test.sh | 16 +++---- qa/L0_pytorch_unittest/test.sh | 52 +++++++++++----------- qa/L1_pytorch_distributed_unittest/test.sh | 24 +++++----- qa/L1_pytorch_onnx_unittest/test.sh | 2 +- qa/L1_pytorch_thunder_integration/test.sh | 2 +- transformer_engine/pytorch/module/base.py | 36 ++++++++------- 6 files changed, 67 insertions(+), 65 deletions(-) diff --git a/qa/L0_pytorch_debug_unittest/test.sh b/qa/L0_pytorch_debug_unittest/test.sh index ce65bc4305..6d40f44b08 100644 --- a/qa/L0_pytorch_debug_unittest/test.sh +++ b/qa/L0_pytorch_debug_unittest/test.sh @@ -28,16 +28,16 @@ mkdir -p "$XML_LOG_DIR" pip install pytest==8.2.1 || error_exit "Failed to install pytest" -pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity.xml $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_sanity.py" -pytest -v -s --junitxml=$XML_LOG_DIR/test_config.xml $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_config.py" -pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics.xml $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_numerics.py" -pytest -v -s --junitxml=$XML_LOG_DIR/test_log.xml $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_log.py" -NVTE_TORCH_COMPILE=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_api_features.xml $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR 
|| test_fail "test_api_features.py" -pytest -v -s --junitxml=$XML_LOG_DIR/test_perf.xml $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_perf.py" +pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/test_sanity.xml $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_sanity.py" +pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/test_config.xml $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_config.py" +pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/test_numerics.xml $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_numerics.py" +pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/test_log.xml $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_log.py" +NVTE_TORCH_COMPILE=0 pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/test_api_features.xml $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_api_features.py" +pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/test_perf.xml $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_perf.py" # standard sanity and numerics tests with initialized debug -NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity_2.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "debug test_sanity.py" -NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics_2.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "debug test_numerics.py" +NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/test_sanity_2.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "debug test_sanity.py" +NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/test_numerics_2.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "debug test_numerics.py" if [ "$RET" -ne 0 ]; then echo "Error in the following test cases:$FAILED_CASES" diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 21eed28367..454536358c 
100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -24,32 +24,32 @@ mkdir -p "$XML_LOG_DIR" pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_deferred_init.xml $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py" -PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py" -PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_nvfp4.xml $TE_PATH/tests/pytorch/nvfp4 || test_fail "test_nvfp4" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_quantized_tensor.xml $TE_PATH/tests/pytorch/test_quantized_tensor.py || test_fail "test_quantized_tensor.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py" -NVTE_FLASH_ATTN=0 NVTE_CPU_OFFLOAD_V1=1 python3 -m pytest --tb=auto 
--junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading_v1.xml $TE_PATH/tests/pytorch/test_cpu_offloading_v1.py || test_fail "test_cpu_offloading_v1.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py" -NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_partial_cast.xml $TE_PATH/tests/pytorch/test_partial_cast.py || test_fail "test_partial_cast.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_deferred_init.xml $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py" +PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py" +PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_nvfp4.xml $TE_PATH/tests/pytorch/nvfp4 || test_fail "test_nvfp4" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_quantized_tensor.xml $TE_PATH/tests/pytorch/test_quantized_tensor.py || test_fail "test_quantized_tensor.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py" 
+python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py" +NVTE_FLASH_ATTN=0 NVTE_CPU_OFFLOAD_V1=1 python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading_v1.xml $TE_PATH/tests/pytorch/test_cpu_offloading_v1.py || test_fail "test_cpu_offloading_v1.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py" +NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_partial_cast.xml $TE_PATH/tests/pytorch/test_partial_cast.py || test_fail "test_partial_cast.py" if [ "$RET" -ne 0 ]; then echo "Error in the following test cases:$FAILED_CASES" diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index 9d868d99cf..12439422c4 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -22,16 +22,16 @@ mkdir -p "$XML_LOG_DIR" pip3 install 
pytest==8.2.1 || error_exit "Failed to install pytest" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/distributed/test_sanity.py || test_fail "test_sanity.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics_exact.xml $TE_PATH/tests/pytorch/distributed/test_numerics_exact.py || test_fail "test_numerics_exact.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || test_fail "test_fusible_ops.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cp_utils.xml $TE_PATH/tests/pytorch/attention/test_cp_utils.py || test_fail "test_cp_utils.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/distributed/test_sanity.py || test_fail "test_sanity.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics_exact.xml $TE_PATH/tests/pytorch/distributed/test_numerics_exact.py || test_fail "test_numerics_exact.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || test_fail "test_fusible_ops.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py" +python3 -m pytest -c 
$TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/pytest_test_cp_utils.xml $TE_PATH/tests/pytorch/attention/test_cp_utils.py || test_fail "test_cp_utils.py" +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py" # debug tests @@ -42,9 +42,9 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_ : ${NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE:=$TE_PATH/tests/pytorch/debug/test_configs/dummy_feature.yaml} : ${NVTE_TEST_NVINSPECT_FEATURE_DIRS:=$TE_PATH/transformer_engine/debug/features} -pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_distributed.xml $TE_PATH/tests/pytorch/debug/test_distributed.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "debug test_distributed.py" +pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/pytest_test_distributed.xml $TE_PATH/tests/pytorch/debug/test_distributed.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "debug test_distributed.py" # standard numerics tests with initialized debug -NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics_2.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "debug test_numerics.py" +NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics_2.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "debug test_numerics.py" if [ "$RET" -ne 0 ]; then echo "Error in the following test cases:$FAILED_CASES" diff --git a/qa/L1_pytorch_onnx_unittest/test.sh b/qa/L1_pytorch_onnx_unittest/test.sh index b3a520e129..0e84a5ca5e 100644 --- a/qa/L1_pytorch_onnx_unittest/test.sh +++ b/qa/L1_pytorch_onnx_unittest/test.sh @@ -6,4 +6,4 @@ : ${XML_LOG_DIR:=/logs} mkdir -p "$XML_LOG_DIR" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py diff --git a/qa/L1_pytorch_thunder_integration/test.sh b/qa/L1_pytorch_thunder_integration/test.sh index 8c3fdc8cdb..a5fd33cda9 100644 --- a/qa/L1_pytorch_thunder_integration/test.sh +++ b/qa/L1_pytorch_thunder_integration/test.sh @@ -9,7 +9,7 @@ set -x mkdir -p "$XML_LOG_DIR" pip3 install pytest==8.1.1 pytest-benchmark==5.1.0 -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest.xml ${THUNDER_PATH}/thunder/tests/test_transformer_engine_executor.py +python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/pytest.xml ${THUNDER_PATH}/thunder/tests/test_transformer_engine_executor.py # Check return code # Note: Return code 5 is fine. 
Lightning tests are skipped on systems diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index deea09702c..837b8cb959 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -643,18 +643,20 @@ def __init__(self) -> None: "fp8_parameters", } - def fast_set_attr(self, name: str, value: Any) -> None: + def fast_setattr(self, name: str, value: Any) -> None: self.__dict__[name] = value + def module_setattr(self, name: str, value: Any) -> None: + super().__setattr__(name, value) + def __setattr__(self, name: str, value: Any) -> None: - if name in TransformerEngineBaseModule._fast_setattr_names: - # torch.nn.Module has a custom __setattr__ that handles - # modules, parameters, and buffers. This is unnecessary - # overhead when setting plain attrs. - self.fast_set_attr(name, value) - else: - # Default case - super().__setattr__(name, value) + warnings.warn( + """The default implementation of torch.nn.Module introduces significant CPU overhead + when setting attributes and is therefore not recommended. Please use the explicit calls + (fast_setattr for setting regular values and module_setattr for setting parameters, + children modules and buffers).""", + RuntimeWarning) + self.module_setattr(name, value) def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None: """ @@ -932,7 +934,7 @@ def set_activation_dtype(self, inp: torch.Tensor) -> None: """Get activation data type for AMP.""" # Native AMP (`torch.autocast`) gets highest priority if torch.is_autocast_enabled(): - self.fast_set_attr("activation_dtype", torch_get_autocast_gpu_dtype()) + self.fast_setattr("activation_dtype", torch_get_autocast_gpu_dtype()) return # All checks after this have already been performed once, thus skip @@ -947,7 +949,7 @@ def set_activation_dtype(self, inp: torch.Tensor) -> None: "Data types for parameters must match when outside of autocasted region. " f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}" ) - self.fast_set_attr("activation_dtype", dtype) + self.fast_setattr("activation_dtype", dtype) def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None: """ @@ -981,9 +983,9 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: fp8 = FP8GlobalStateManager.is_fp8_enabled() fp8_parameters = FP8GlobalStateManager.with_fp8_parameters() fp8_calibration = FP8GlobalStateManager.is_fp8_calibration() - self.fast_set_attr("fp8_parameters", fp8_parameters) - self.fast_set_attr("fp8", fp8) - self.fast_set_attr("fp8_calibration", fp8_calibration) + self.fast_setattr("fp8_parameters", fp8_parameters) + self.fast_setattr("fp8", fp8) + self.fast_setattr("fp8_calibration", fp8_calibration) fp8_enabled = fp8 or fp8_calibration meta["fp8_checkpoint"] = fp8_enabled @@ -997,7 +999,7 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() else: # If fp8 isn't enabled, turn off and return. 
- self.fast_set_attr("fp8_initialized", False) + self.fast_setattr("fp8_initialized", False) return if fp8_parameters and not self.fp8_initialized: @@ -1041,10 +1043,10 @@ def prepare_forward( allow_different_data_and_param_types: bool = False, ) -> torch.Tensor: """Checks and prepare for FWD execution.""" - self.fast_set_attr( + self.fast_setattr( "allow_different_data_and_param_types", allow_different_data_and_param_types ) - self.fast_set_attr("forwarded_at_least_once", True) + self.fast_setattr("forwarded_at_least_once", True) # Activation recomputation is used and this is the second forward phase. if self.fp8 and in_fp8_activation_recompute_phase(): From 68b6f74d6fe065560cdacf684af2b2666897e585 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Mon, 15 Dec 2025 15:51:30 -0800 Subject: [PATCH 06/24] Actually add pytest.ini Signed-off-by: Przemek Tredak --- tests/pytorch/pytest.ini | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 tests/pytorch/pytest.ini diff --git a/tests/pytorch/pytest.ini b/tests/pytorch/pytest.ini new file mode 100644 index 0000000000..49111713b8 --- /dev/null +++ b/tests/pytorch/pytest.ini @@ -0,0 +1,8 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +[pytest] +filterWarnings= + error:RuntimeWarning + From 1c5434cb8bbf01d62e59a9fba23e9aa681be7bff Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Mon, 15 Dec 2025 16:28:11 -0800 Subject: [PATCH 07/24] Changes to __init__ Signed-off-by: Przemek Tredak --- tests/pytorch/pytest.ini | 4 +- transformer_engine/pytorch/module/base.py | 50 +++++++++-------------- 2 files changed, 22 insertions(+), 32 deletions(-) diff --git a/tests/pytorch/pytest.ini b/tests/pytorch/pytest.ini index 49111713b8..e90989721b 100644 --- a/tests/pytorch/pytest.ini +++ b/tests/pytorch/pytest.ini @@ -3,6 +3,6 @@ # See LICENSE for license information. [pytest] -filterWarnings= - error:RuntimeWarning +filterwarnings= + error::RuntimeWarning diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 837b8cb959..3e234fde45 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -607,42 +607,32 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): def __init__(self) -> None: super().__init__() assert torch.cuda.is_available(), "TransformerEngine needs CUDA." 
- self.name = None - self.next_iter_when_debug_should_be_run = 0 - self.fp8_initialized = False - self.fp8 = False - self.fp8_calibration = False - self.fp8_meta = {} + self.fast_setattr("name", None) + self.fast_setattr("next_iter_when_debug_should_be_run", 0) + self.fast_setattr("fp8_initialized", False) + self.fast_setattr("fp8", False) + self.fast_setattr("fp8_calibration", False) + self.fast_setattr("fp8_meta", {}) self.fp8_meta["fp8_checkpoint"] = False self.fp8_meta["fp8_group"] = None - self.fp8_meta_tensors_initialized = False - self.quantizers = {"scaling_fwd": {}, "scaling_bwd": {}} - self.tp_group = None - self.tp_size = 1 - self.sequence_parallel = False - self.param_init_meta = {} - self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters() - self.preserve_high_precision_init_val = FP8GlobalStateManager.with_high_precision_init_val() - self.fsdp_wrapped = False - self.fsdp_group = None - self._fp8_workspaces: Dict[str, QuantizedTensor] = {} - self.activation_dtype: Optional[torch.dtype] = None - self.wgrad_accumulation_and_reduce_hooks = [] - self.wgrad_store = None + self.fast_setattr("fp8_meta_tensors_initialized", False) + self.fast_setattr("quantizers", {"scaling_fwd": {}, "scaling_bwd": {}}) + self.fast_setattr("tp_group", None) + self.fast_setattr("tp_size", 1) + self.fast_setattr("sequence_parallel", False) + self.fast_setattr("param_init_meta", {}) + self.fast_setattr("primary_weights_in_fp8", FP8GlobalStateManager.with_fp8_parameters()) + self.fast_setattr("preserve_high_precision_init_val", FP8GlobalStateManager.with_high_precision_init_val()) + self.fast_setattr("fsdp_wrapped", False) + self.fast_setattr("fsdp_group", None) + self.fast_setattr("_fp8_workspaces", {}) + self.fast_setattr("activation_dtype", None) + self.fast_setattr("wgrad_accumulation_and_reduce_hooks", []) + self.fast_setattr("wgrad_store", None) if not TEDebugState.debug_enabled: TEDebugState.initialize() - # Names of attributes that can be set quickly (see __setattr__ - # method) - _fast_setattr_names: Set[str] = { - "activation_dtype", - "fp8", - "fp8_initialized", - "fp8_calibration", - "fp8_parameters", - } - def fast_setattr(self, name: str, value: Any) -> None: self.__dict__[name] = value From 778019db6fd216ef518cdff744fcdb883c2e8f1d Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Mon, 15 Dec 2025 17:08:07 -0800 Subject: [PATCH 08/24] A different way Signed-off-by: Przemek Tredak --- .../dot_product_attention.py | 2 + transformer_engine/pytorch/module/base.py | 45 ++++++++++--------- .../pytorch/module/grouped_linear.py | 2 + .../pytorch/module/layernorm_linear.py | 2 + .../pytorch/module/layernorm_mlp.py | 2 + transformer_engine/pytorch/module/linear.py | 2 + 6 files changed, 33 insertions(+), 22 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 384a163b6f..9fac831961 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -482,6 +482,8 @@ def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unuse self.register_load_state_dict_post_hook(remove_extra_states_check) + self.__setattr__ = self.default_setattr + def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ): diff --git 
a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 3e234fde45..d5e352aa65 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -607,28 +607,28 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): def __init__(self) -> None: super().__init__() assert torch.cuda.is_available(), "TransformerEngine needs CUDA." - self.fast_setattr("name", None) - self.fast_setattr("next_iter_when_debug_should_be_run", 0) - self.fast_setattr("fp8_initialized", False) - self.fast_setattr("fp8", False) - self.fast_setattr("fp8_calibration", False) - self.fast_setattr("fp8_meta", {}) + self.name = None + self.next_iter_when_debug_should_be_run = 0 + self.fp8_initialized = False + self.fp8 = False + self.fp8_calibration = False + self.fp8_meta = {} self.fp8_meta["fp8_checkpoint"] = False self.fp8_meta["fp8_group"] = None - self.fast_setattr("fp8_meta_tensors_initialized", False) - self.fast_setattr("quantizers", {"scaling_fwd": {}, "scaling_bwd": {}}) - self.fast_setattr("tp_group", None) - self.fast_setattr("tp_size", 1) - self.fast_setattr("sequence_parallel", False) - self.fast_setattr("param_init_meta", {}) - self.fast_setattr("primary_weights_in_fp8", FP8GlobalStateManager.with_fp8_parameters()) - self.fast_setattr("preserve_high_precision_init_val", FP8GlobalStateManager.with_high_precision_init_val()) - self.fast_setattr("fsdp_wrapped", False) - self.fast_setattr("fsdp_group", None) - self.fast_setattr("_fp8_workspaces", {}) - self.fast_setattr("activation_dtype", None) - self.fast_setattr("wgrad_accumulation_and_reduce_hooks", []) - self.fast_setattr("wgrad_store", None) + self.fp8_meta_tensors_initialized = False + self.quantizers = {"scaling_fwd": {}, "scaling_bwd": {}} + self.tp_group = None + self.tp_size = 1 + self.sequence_parallel = False + self.param_init_meta = {} + self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters() + self.preserve_high_precision_init_val = FP8GlobalStateManager.with_high_precision_init_val() + self.fsdp_wrapped = False + self.fsdp_group = None + self._fp8_workspaces: Dict[str, QuantizedTensor] = {} + self.activation_dtype: Optional[torch.dtype] = None + self.wgrad_accumulation_and_reduce_hooks = [] + self.wgrad_store = None if not TEDebugState.debug_enabled: TEDebugState.initialize() @@ -639,13 +639,14 @@ def fast_setattr(self, name: str, value: Any) -> None: def module_setattr(self, name: str, value: Any) -> None: super().__setattr__(name, value) - def __setattr__(self, name: str, value: Any) -> None: + def default_setattr(self, name: str, value: Any) -> None: warnings.warn( """The default implementation of torch.nn.Module introduces significant CPU overhead when setting attributes and is therefore not recommended. 
Please use the explicit calls (fast_setattr for setting regular values and module_setattr for setting parameters, children modules and buffers).""", - RuntimeWarning) + RuntimeWarning, + ) self.module_setattr(name, value) def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None: diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 9f7f01dbb1..f59772f956 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -716,6 +716,8 @@ def __init__( if name in (f"weight{i}", f"bias{i}"): param.skip_backward_post_hook = True + self.__setattr__ = self.default_setattr + def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" super().set_meta_tensor(fwd, recipe) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 2322fff935..242feb5ae5 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -1405,6 +1405,8 @@ def __init__( if name in self.weight_names or name in self.bias_names: param.skip_backward_post_hook = True + self.__setattr__ = self.default_setattr + def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" super().set_meta_tensor(fwd, recipe) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index b40f933b16..9834a22cbd 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1960,6 +1960,8 @@ def __init__( self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0")) + self.__setattr__ = self.default_setattr + def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" super().set_meta_tensor(fwd, recipe) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 6adf479d82..92e37b9476 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -1309,6 +1309,8 @@ def __init__( if name in self.weight_names or name in self.bias_names: param.skip_backward_post_hook = True + self.__setattr__ = self.default_setattr + def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" super().set_meta_tensor(fwd, recipe) From 0fc2a62ce519dbd2ebf1e994f71c43272df44d4b Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Tue, 16 Dec 2025 15:58:06 -0800 Subject: [PATCH 09/24] WAR the fact that it is not possible to set __setattr__ dynamically Signed-off-by: Przemek Tredak --- .../dot_product_attention/dot_product_attention.py | 2 +- transformer_engine/pytorch/module/base.py | 8 +++++++- transformer_engine/pytorch/module/grouped_linear.py | 2 +- transformer_engine/pytorch/module/layernorm_linear.py | 2 +- transformer_engine/pytorch/module/layernorm_mlp.py | 2 +- transformer_engine/pytorch/module/linear.py | 2 +- 6 files changed, 12 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 9fac831961..fd32e83666 100644 --- 
a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -482,7 +482,7 @@ def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unuse self.register_load_state_dict_post_hook(remove_extra_states_check) - self.__setattr__ = self.default_setattr + self._default_setattr = self._warning_setattr def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index d5e352aa65..185ec7a5c2 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -639,7 +639,7 @@ def fast_setattr(self, name: str, value: Any) -> None: def module_setattr(self, name: str, value: Any) -> None: super().__setattr__(name, value) - def default_setattr(self, name: str, value: Any) -> None: + def _warning_setattr(self, name: str, value: Any) -> None: warnings.warn( """The default implementation of torch.nn.Module introduces significant CPU overhead when setting attributes and is therefore not recommended. Please use the explicit calls @@ -649,6 +649,12 @@ def default_setattr(self, name: str, value: Any) -> None: ) self.module_setattr(name, value) + def _default_setattr(self, name: str, value: Any) -> None: + return self.module_setattr(name, value) + + def __setattr__(self, name: str, value: Any) -> None: + return self._default_setattr(name, value) + def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None: """ Delayed scaling only. diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index f59772f956..3fadce918d 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -716,7 +716,7 @@ def __init__( if name in (f"weight{i}", f"bias{i}"): param.skip_backward_post_hook = True - self.__setattr__ = self.default_setattr + self._default_setattr = self._warning_setattr def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 242feb5ae5..3f2037adce 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -1405,7 +1405,7 @@ def __init__( if name in self.weight_names or name in self.bias_names: param.skip_backward_post_hook = True - self.__setattr__ = self.default_setattr + self._default_setattr = self._warning_setattr def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 9834a22cbd..a637963893 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1960,7 +1960,7 @@ def __init__( self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0")) - self.__setattr__ = self.default_setattr + self._default_setattr = self._warning_setattr def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" diff --git 
a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 92e37b9476..c5c3990bd1 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -1309,7 +1309,7 @@ def __init__( if name in self.weight_names or name in self.bias_names: param.skip_backward_post_hook = True - self.__setattr__ = self.default_setattr + self._default_setattr = self._warning_setattr def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" From 2fb6ee3590e475a9f8bec14a9e917ea54f50b434 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Fri, 9 Jan 2026 16:21:45 -0800 Subject: [PATCH 10/24] Simpler solution and fixes Signed-off-by: Przemek Tredak --- tests/pytorch/test_numerics.py | 124 ++++++++++-------- .../dot_product_attention.py | 12 +- transformer_engine/pytorch/distributed.py | 8 +- transformer_engine/pytorch/graph.py | 8 +- transformer_engine/pytorch/module/base.py | 54 +++++--- .../pytorch/module/grouped_linear.py | 2 +- .../pytorch/module/layernorm_linear.py | 2 +- .../pytorch/module/layernorm_mlp.py | 4 +- transformer_engine/pytorch/module/linear.py | 2 +- 9 files changed, 124 insertions(+), 92 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index abe2806e66..9a1942f30d 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -5,6 +5,7 @@ import math import os from typing import Dict, List, Tuple, Optional +import warnings import pytest import random @@ -1296,14 +1297,15 @@ def test_linear_accuracy_delay_wgrad_compute(dtype, bs, model, bias, fuse_wgrad_ ).eval() # Share params - with torch.no_grad(): - te_linear_ref.weight = Parameter(te_linear.weight.clone()) - if bias: - te_linear_ref.bias = Parameter(te_linear.bias.clone()) - if fuse_wgrad_accumulation: - weight = getattr(te_linear, f"weight") - weight.main_grad = torch.rand_like(weight, dtype=torch.float32) - te_linear_ref.weight.main_grad = weight.main_grad.clone() + with warnings.catch_warnings(action="ignore", category=RuntimeWarning): + with torch.no_grad(): + te_linear_ref.weight = Parameter(te_linear.weight.clone()) + if bias: + te_linear_ref.bias = Parameter(te_linear.bias.clone()) + if fuse_wgrad_accumulation: + weight = getattr(te_linear, f"weight") + weight.main_grad = torch.rand_like(weight, dtype=torch.float32) + te_linear_ref.weight.main_grad = weight.main_grad.clone() te_outputs = _test_granular_accuracy(te_linear, bs, dtype, config, delay_wgrad_compute=True) te_outputs_ref = _test_granular_accuracy( @@ -1359,12 +1361,13 @@ def test_linear_accuracy_save_original_input(dtype, model, recipe): ).eval() # Share params - with torch.no_grad(): - te_linear_ref.weight = Parameter(te_linear.weight.clone()) - if fuse_wgrad_accumulation: - weight = getattr(te_linear, f"weight") - weight.main_grad = torch.rand_like(weight, dtype=torch.float32) - te_linear_ref.weight.main_grad = weight.main_grad.clone() + with warnings.catch_warnings(action="ignore", category=RuntimeWarning): + with torch.no_grad(): + te_linear_ref.weight = Parameter(te_linear.weight.clone()) + if fuse_wgrad_accumulation: + weight = getattr(te_linear, f"weight") + weight.main_grad = torch.rand_like(weight, dtype=torch.float32) + te_linear_ref.weight.main_grad = weight.main_grad.clone() te_outputs = _test_granular_accuracy(te_linear, bs, dtype, config, recipe=recipe) te_outputs_ref = _test_granular_accuracy(te_linear_ref, bs, dtype, config, recipe=recipe) @@ -1601,17 +1604,18 @@ def 
test_layernorm_linear_accuracy_delay_wgrad_compute( ).eval() # Share params - with torch.no_grad(): - ln_linear_ref.layer_norm_weight = Parameter(ln_linear.layer_norm_weight.clone()) - if normalization != "RMSNorm": - ln_linear_ref.layer_norm_bias = Parameter(ln_linear.layer_norm_bias.clone()) - ln_linear_ref.weight = Parameter(ln_linear.weight.clone()) - if bias: - ln_linear_ref.bias = Parameter(ln_linear.bias.clone()) - if fuse_wgrad_accumulation: - weight = getattr(ln_linear, f"weight") - weight.main_grad = torch.rand_like(weight, dtype=torch.float32) - ln_linear_ref.weight.main_grad = weight.main_grad.clone() + with warnings.catch_warnings(action="ignore", category=RuntimeWarning): + with torch.no_grad(): + ln_linear_ref.layer_norm_weight = Parameter(ln_linear.layer_norm_weight.clone()) + if normalization != "RMSNorm": + ln_linear_ref.layer_norm_bias = Parameter(ln_linear.layer_norm_bias.clone()) + ln_linear_ref.weight = Parameter(ln_linear.weight.clone()) + if bias: + ln_linear_ref.bias = Parameter(ln_linear.bias.clone()) + if fuse_wgrad_accumulation: + weight = getattr(ln_linear, f"weight") + weight.main_grad = torch.rand_like(weight, dtype=torch.float32) + ln_linear_ref.weight.main_grad = weight.main_grad.clone() te_outputs = _test_granular_accuracy(ln_linear, bs, dtype, config, delay_wgrad_compute=True) te_outputs_ref = _test_granular_accuracy( @@ -1739,19 +1743,24 @@ def test_layernorm_mlp_accuracy_delay_wgrad_compute( ).eval() # Share params - with torch.no_grad(): - ln_mlp_ref.layer_norm_weight = Parameter(ln_mlp.layer_norm_weight.clone()) - ln_mlp_ref.layer_norm_bias = Parameter(ln_mlp.layer_norm_bias.clone()) - ln_mlp_ref.fc1_weight = Parameter(ln_mlp.fc1_weight.clone()) - ln_mlp_ref.fc2_weight = Parameter(ln_mlp.fc2_weight.clone()) - if bias: - ln_mlp_ref.fc1_bias = Parameter(ln_mlp.fc1_bias.clone()) - ln_mlp_ref.fc2_bias = Parameter(ln_mlp.fc2_bias.clone()) - if fuse_wgrad_accumulation: - ln_mlp.fc1_weight.main_grad = torch.rand_like(ln_mlp.fc1_weight, dtype=torch.float32) - ln_mlp_ref.fc1_weight.main_grad = ln_mlp.fc1_weight.main_grad.clone() - ln_mlp.fc2_weight.main_grad = torch.rand_like(ln_mlp.fc2_weight, dtype=torch.float32) - ln_mlp_ref.fc2_weight.main_grad = ln_mlp.fc2_weight.main_grad.clone() + with warnings.catch_warnings(action="ignore", category=RuntimeWarning): + with torch.no_grad(): + ln_mlp_ref.layer_norm_weight = Parameter(ln_mlp.layer_norm_weight.clone()) + ln_mlp_ref.layer_norm_bias = Parameter(ln_mlp.layer_norm_bias.clone()) + ln_mlp_ref.fc1_weight = Parameter(ln_mlp.fc1_weight.clone()) + ln_mlp_ref.fc2_weight = Parameter(ln_mlp.fc2_weight.clone()) + if bias: + ln_mlp_ref.fc1_bias = Parameter(ln_mlp.fc1_bias.clone()) + ln_mlp_ref.fc2_bias = Parameter(ln_mlp.fc2_bias.clone()) + if fuse_wgrad_accumulation: + ln_mlp.fc1_weight.main_grad = torch.rand_like( + ln_mlp.fc1_weight, dtype=torch.float32 + ) + ln_mlp_ref.fc1_weight.main_grad = ln_mlp.fc1_weight.main_grad.clone() + ln_mlp.fc2_weight.main_grad = torch.rand_like( + ln_mlp.fc2_weight, dtype=torch.float32 + ) + ln_mlp_ref.fc2_weight.main_grad = ln_mlp.fc2_weight.main_grad.clone() te_outputs = _test_granular_accuracy(ln_mlp, bs, dtype, config, delay_wgrad_compute=True) te_outputs_ref = _test_granular_accuracy( @@ -1796,14 +1805,15 @@ def test_layernorm_mlp_accuracy_checkpoint( ).eval() # Share params - with torch.no_grad(): - ln_mlp_ref.layer_norm_weight = Parameter(ln_mlp.layer_norm_weight.clone()) - ln_mlp_ref.layer_norm_bias = Parameter(ln_mlp.layer_norm_bias.clone()) - ln_mlp_ref.fc1_weight = 
Parameter(ln_mlp.fc1_weight.clone()) - ln_mlp_ref.fc2_weight = Parameter(ln_mlp.fc2_weight.clone()) - if bias: - ln_mlp_ref.fc1_bias = Parameter(ln_mlp.fc1_bias.clone()) - ln_mlp_ref.fc2_bias = Parameter(ln_mlp.fc2_bias.clone()) + with warnings.catch_warnings(action="ignore", category=RuntimeWarning): + with torch.no_grad(): + ln_mlp_ref.layer_norm_weight = Parameter(ln_mlp.layer_norm_weight.clone()) + ln_mlp_ref.layer_norm_bias = Parameter(ln_mlp.layer_norm_bias.clone()) + ln_mlp_ref.fc1_weight = Parameter(ln_mlp.fc1_weight.clone()) + ln_mlp_ref.fc2_weight = Parameter(ln_mlp.fc2_weight.clone()) + if bias: + ln_mlp_ref.fc1_bias = Parameter(ln_mlp.fc1_bias.clone()) + ln_mlp_ref.fc2_bias = Parameter(ln_mlp.fc2_bias.clone()) te_outputs = _test_granular_accuracy(ln_mlp, bs, dtype, config, delay_wgrad_compute=False) te_outputs_ref = _test_granular_accuracy( @@ -1952,9 +1962,13 @@ def test_grouped_linear_accuracy( # Share params with torch.no_grad(): for i in range(num_gemms): - sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone()) + sequential_linear[i].module_setattr( + "weight", Parameter(getattr(grouped_linear, f"weight{i}").clone()) + ) if bias: - sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone()) + sequential_linear[i].module_setattr( + "bias", Parameter(getattr(grouped_linear, f"bias{i}").clone()) + ) if fuse_wgrad_accumulation: weight_i = getattr(grouped_linear, f"weight{i}") weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32) @@ -2096,9 +2110,13 @@ def test_grouped_linear_accuracy_save_original_input( # Share params with torch.no_grad(): for i in range(num_gemms): - sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone()) + sequential_linear[i].module_setattr( + "weight", Parameter(getattr(grouped_linear, f"weight{i}").clone()) + ) if bias: - sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone()) + sequential_linear[i].module_setattr( + "bias", Parameter(getattr(grouped_linear, f"bias{i}").clone()) + ) if fuse_wgrad_accumulation: weight_i = getattr(grouped_linear, f"weight{i}") weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32) @@ -2298,8 +2316,7 @@ def test_padding_grouped_linear_accuracy( with torch.no_grad(): inner_grouped_linear = grouped_linear.linear_fn for i in range(num_gemms): - setattr( - ref_grouped_linear, + ref_grouped_linear.module_setattr( f"weight{i}", Parameter(getattr(inner_grouped_linear, f"weight{i}").clone()), ) @@ -2375,8 +2392,7 @@ def test_padding_grouped_linear_accuracy_save_original_input( with torch.no_grad(): inner_grouped_linear = grouped_linear.linear_fn for i in range(num_gemms): - setattr( - ref_grouped_linear, + ref_grouped_linear.module_setattr( f"weight{i}", Parameter(getattr(inner_grouped_linear, f"weight{i}").clone()), ) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index fd32e83666..f19e69a508 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -482,7 +482,7 @@ def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unuse self.register_load_state_dict_post_hook(remove_extra_states_check) - self._default_setattr = self._warning_setattr + self._initialized = True def _load_from_state_dict( self, state_dict, 
prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs @@ -678,9 +678,9 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: # assume attention uses the same fp8_group as GEMMs fp8_group = FP8GlobalStateManager.get_fp8_group() - self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters() - self.fp8 = FP8GlobalStateManager.is_fp8_enabled() - self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration() + self.fast_setattr("fp8_parameters", FP8GlobalStateManager.with_fp8_parameters()) + self.fast_setattr("fp8", FP8GlobalStateManager.is_fp8_enabled()) + self.fast_setattr("fp8_calibration", FP8GlobalStateManager.is_fp8_calibration()) fp8_enabled = self.fp8 or self.fp8_calibration self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration if self.fp8_parameters or fp8_enabled: @@ -705,7 +705,7 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: ) else: # If fp8 isn't enabled, turn off and return. - self.fp8_initialized = False + self.fast_setattr("fp8_initialized", False) return if self.fp8_parameters and not self.fp8_initialized: @@ -723,7 +723,7 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: # Allocate scales and amaxes self.init_fp8_meta_tensors(fp8_recipes) - self.fp8_initialized = True + self.fast_setattr("fp8_initialized", True) self.fp8_meta["recipe"] = fp8_recipe_dpa if fp8_recipe != fp8_recipe_dpa: diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 5497ee7967..b38725e8be 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -729,8 +729,8 @@ def checkpoint( if isinstance(function, TransformerEngineBaseModule): # If this TE module is FSDP-wrapped, clear its FSDP group information because there's no need # to scatter/gather activations that we will recompute anyway. - setattr(function, "fsdp_wrapped", False) - setattr(function, "fsdp_group", None) + function.fast_setattr("fsdp_wrapped", False) + function.fast_setattr("fsdp_group", None) # Otherwise discard unused te.utils.checkpoint.checkpoint() arguments # and execute TE's own checkpointing @@ -2046,7 +2046,7 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None: ) root_state = _get_module_fsdp_state(fsdp_root) assert root_state is not None, "Root module does not have a valid _FSDPState." - setattr(fsdp_root.module, "fsdp_group", root_state.process_group) + fsdp_root.module.fast_setattr("fsdp_group", root_state.process_group) # Iterate through all FSDP-wrapped submodules and inject FSDP information into TE modules fsdp_states, fsdp_modules = _get_fsdp_states_with_modules(fsdp_root) @@ -2057,7 +2057,7 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None: "TE modules with primary weights in FP8 cannot be FSDP-wrapped. " "Please initialize your model without the te.quantized_model_init(...) context." 
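The test and framework changes in this patch wrap deliberate parameter rebinds in warnings.catch_warnings(action="ignore", category=RuntimeWarning) so the new warning (which the test suite promotes to an error) does not fire on intentional assignments. A self-contained usage sketch with a stand-in torch module; note that the action/category keyword arguments of catch_warnings require Python 3.11 or newer:

    import warnings
    import torch

    lin = torch.nn.Linear(4, 4)   # stand-in; in the tests above it is a TE layer

    # Silence only RuntimeWarning around an intentional parameter rebind.
    # On interpreters older than 3.11, use a bare catch_warnings() block and
    # call warnings.simplefilter("ignore", RuntimeWarning) inside it instead.
    with warnings.catch_warnings(action="ignore", category=RuntimeWarning):
        with torch.no_grad():
            lin.weight = torch.nn.Parameter(torch.zeros(4, 4))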
) - setattr(fsdp_module.module, "fsdp_group", state.process_group) + fsdp_module.module.fast_setattr("fsdp_group", state.process_group) class FullyShardedDataParallel(FSDP): diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index f587ca9946..322ad52723 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -935,7 +935,8 @@ def new_fwd(*user_args, **user_kwargs): forward = make_graphed_forward(func, func.training, graphed, func.forward, te_modules) if _order is None: - func.forward = forward + with warnings.catch_warnings(action="ignore", category=RuntimeWarning): + func.forward = forward ret.append(func) else: ret.append(forward) @@ -943,8 +944,9 @@ def new_fwd(*user_args, **user_kwargs): ret.append(graphed) backward_dw_func, reset_func = make_graphed_attribute_functions(i) - setattr(ret[-1], "backward_dw", backward_dw_func) - setattr(ret[-1], "reset", reset_func) + with warnings.catch_warnings(action="ignore", category=RuntimeWarning): + setattr(ret[-1], "backward_dw", backward_dw_func) + setattr(ret[-1], "reset", reset_func) if just_one_callable: return ret[0] diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 185ec7a5c2..937bd32772 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -10,7 +10,8 @@ import warnings from enum import Enum from abc import ABC, abstractmethod -from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union +from typing import Any, Dict, Generator, List, Optional, Tuple, Union +from typing_extensions import Self from contextlib import contextmanager import logging from types import MethodType @@ -634,26 +635,30 @@ def __init__(self) -> None: TEDebugState.initialize() def fast_setattr(self, name: str, value: Any) -> None: + """ + Fast version of the Module's set attribute function. + Should be used for regular attributes, but not properties nor parameters/buffers. + """ self.__dict__[name] = value def module_setattr(self, name: str, value: Any) -> None: + """ + Regular version of the Module's set attribute function. + Should be used only when the fast version cannot be used - for the properties, + parameters and buffers. + """ super().__setattr__(name, value) - def _warning_setattr(self, name: str, value: Any) -> None: - warnings.warn( - """The default implementation of torch.nn.Module introduces significant CPU overhead - when setting attributes and is therefore not recommended. Please use the explicit calls - (fast_setattr for setting regular values and module_setattr for setting parameters, - children modules and buffers).""", - RuntimeWarning, - ) - self.module_setattr(name, value) - - def _default_setattr(self, name: str, value: Any) -> None: - return self.module_setattr(name, value) - def __setattr__(self, name: str, value: Any) -> None: - return self._default_setattr(name, value) + if "_initialized" in self.__dict__ and self._initialized: + warnings.warn( + """The default implementation of torch.nn.Module introduces significant CPU overhead + when setting attributes and is therefore not recommended. 
Please use the explicit + calls (fast_setattr for setting regular values and module_setattr for setting + parameters, children modules and buffers).""", + RuntimeWarning, + ) + super().__setattr__(name, value) def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None: """ @@ -774,7 +779,7 @@ def init_fp8_meta_tensors(self, recipe: Recipe) -> None: self.set_meta_tensor(True, recipe) self.set_meta_tensor(False, recipe) - self.fp8_meta_tensors_initialized = True + self.fast_setattr("fp8_meta_tensors_initialized", True) def get_fp8_meta_tensors(self) -> None: """Get scales and amaxes.""" @@ -1015,7 +1020,7 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: # Allocate scales and amaxes self.init_fp8_meta_tensors(meta["recipe"]) - self.fp8_initialized = True + self.fast_setattr("fp8_initialized", True) meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() @@ -1039,7 +1044,7 @@ def prepare_forward( allow_non_contiguous: bool = False, allow_different_data_and_param_types: bool = False, ) -> torch.Tensor: - """Checks and prepare for FWD execution.""" + """Checks and prepares for FWD execution.""" self.fast_setattr( "allow_different_data_and_param_types", allow_different_data_and_param_types ) @@ -1080,6 +1085,10 @@ def prepare_forward( return inp def end_forward(self): + """ + Required to be called at the end of the forward function to properly handle + DelayedScaling metadata handling and the NVTX ranges. + """ delayed_scaling_recipe = self.fp8 and self.fp8_meta["recipe"].delayed() if delayed_scaling_recipe and self.fp8 and in_fp8_activation_recompute_phase(): FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta) @@ -1093,11 +1102,16 @@ def prepare_forward_ctx( allow_non_contiguous: bool = False, allow_different_data_and_param_types: bool = False, ) -> Generator[torch.Tensor, None, None]: + """Checks and prepares for FWD execution.""" yield self.prepare_forward( inp, num_gemms, allow_non_contiguous, allow_different_data_and_param_types ) self.end_forward() + def train(self, mode: bool = True) -> Self: + with warnings.catch_warnings(action="ignore", category=RuntimeWarning): + return super().train(mode) + def set_nccl_overlap_warning_if_tp(self) -> None: """When using TP, the NCCL communication needs to be scheduled before the GEMM for there to be a guaranteed overlap. 
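The design above gates the warning on an _initialized flag kept in the instance __dict__ rather than rebinding __setattr__ per instance as in [PATCH 08/24], because Python resolves dunder methods on the type, never on the instance, so an instance-level __setattr__ assignment is silently ignored. A toy sketch of the same gate, with no Transformer Engine dependencies:

    import warnings

    class Gated:
        def __init__(self):
            self.x = 0                      # silent: _initialized is not set yet
            self.__dict__["_initialized"] = True

        def fast_setattr(self, name, value):
            self.__dict__[name] = value     # always silent

        def __setattr__(self, name, value):
            if self.__dict__.get("_initialized", False):
                warnings.warn("use fast_setattr for plain attributes", RuntimeWarning)
            object.__setattr__(self, name, value)

    g = Gated()
    g.fast_setattr("x", 1)                  # fast path, no warning
    g.x = 2                                 # default path, emits RuntimeWarning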
From the @@ -1331,9 +1345,9 @@ def clear(self): # Update the parameter based on its type if not is_dtensor: - setattr(self, name, param) + self.module_setattr(name, param) else: - setattr(self, name, dtensor_param) + self.module_setattr(name, dtensor_param) @abstractmethod def forward(self): diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 3fadce918d..b5fbb0b8fd 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -716,7 +716,7 @@ def __init__( if name in (f"weight{i}", f"bias{i}"): param.skip_backward_post_hook = True - self._default_setattr = self._warning_setattr + self._initialized = True def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 3f2037adce..dd0e1a4c61 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -1405,7 +1405,7 @@ def __init__( if name in self.weight_names or name in self.bias_names: param.skip_backward_post_hook = True - self._default_setattr = self._warning_setattr + self._initialized = True def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index a637963893..a283640c7e 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1960,7 +1960,7 @@ def __init__( self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0")) - self._default_setattr = self._warning_setattr + self._initialized = True def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" @@ -2094,7 +2094,7 @@ def forward( # Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute(): - self.bias_gelu_nvfusion = False + self.fast_setattr("bias_gelu_nvfusion", False) if is_grad_enabled: fwd_fn = _LayerNormMLP.apply diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index c5c3990bd1..cf85a1d836 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -1309,7 +1309,7 @@ def __init__( if name in self.weight_names or name in self.bias_names: param.skip_backward_post_hook = True - self._default_setattr = self._warning_setattr + self._initialized = True def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" From 49407242cedfd7b20aff5b23e56053ef22eb8cd7 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Tue, 13 Jan 2026 16:30:07 -0800 Subject: [PATCH 11/24] Fix for the inference mode DPA Signed-off-by: Przemek Tredak --- .../dot_product_attention/dot_product_attention.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index f19e69a508..ad7dac108a 100644 --- 
a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -1147,10 +1147,11 @@ def forward( if attn_mask_type == "padding_causal": attn_mask_type = attn_mask_type + "_bottom_right" - self.attention_type = "cross" - self.flash_attention.attention_type = self.attention_type - self.fused_attention.attention_type = self.attention_type - self.unfused_attention.attention_type = self.attention_type + if self.attention_type != "cross": + self.fast_setattr("attention_type", "cross") + self.flash_attention.attention_type = self.attention_type + self.fused_attention.attention_type = self.attention_type + self.unfused_attention.attention_type = self.attention_type query_layer, key_layer, value_layer = [ x.contiguous() if not x.is_contiguous() else x From 8704a59710ca1603ce859d62b01e6430254f0579 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Tue, 13 Jan 2026 16:39:46 -0800 Subject: [PATCH 12/24] Start of debugging debug tools Signed-off-by: Przemek Tredak --- transformer_engine/pytorch/module/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 937bd32772..f42a8251d3 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1610,7 +1610,7 @@ def _validate_name(self): " insight. ", level=logging.WARNING, ) - self.name = f"Layer_{TEDebugState.get_layer_count()}" + self.fast_setattr("name", f"Layer_{TEDebugState.get_layer_count()}") def _check_weight_tensor_recipe_correspondence(self) -> None: """ From aa783a7d56ee1dc5098e772cd919a0ab5c333e1d Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Tue, 13 Jan 2026 16:42:20 -0800 Subject: [PATCH 13/24] More fixes in debug Signed-off-by: Przemek Tredak --- transformer_engine/pytorch/module/base.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index f42a8251d3..3856bf95bf 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1560,14 +1560,14 @@ def is_debug_iter(self) -> bool: debug = False else: debug = TEDebugState.get_iteration() >= self.next_iter_when_debug_should_be_run - self.debug_last_iteration = TEDebugState.get_iteration() - self.debug_enabled_in_this_iteration = debug + self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration()) + self.fast_setattr("debug_enabled_in_this_iteration", debug) else: # If this is the same iteration as previous invocation of the module, # we use the debug value from the first invocation in the iteration. debug = self.debug_enabled_in_this_iteration - self.debug_last_iteration = TEDebugState.get_iteration() + self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration()) if self.wgrad_store is not None: if debug and self.wgrad_store.delay_wgrad_compute(): @@ -1583,7 +1583,8 @@ def no_debug_features_active(self, quantizers): # Sometimes features inform that they will not be enabled for particular layer # for multiple next iterations. 
- self.next_iter_when_debug_should_be_run = next_iter_when_debug_should_be_run(quantizers) + self.fast_setattr("next_iter_when_debug_should_be_run", + next_iter_when_debug_should_be_run(quantizers)) if not run_current: return True From fd158f61f3f6a320d7754bec2e53e997aff15af5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 14 Jan 2026 00:43:14 +0000 Subject: [PATCH 14/24] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/base.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 3856bf95bf..f1767cdeea 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1583,8 +1583,9 @@ def no_debug_features_active(self, quantizers): # Sometimes features inform that they will not be enabled for particular layer # for multiple next iterations. - self.fast_setattr("next_iter_when_debug_should_be_run", - next_iter_when_debug_should_be_run(quantizers)) + self.fast_setattr( + "next_iter_when_debug_should_be_run", next_iter_when_debug_should_be_run(quantizers) + ) if not run_current: return True From 2601e5e8a9ca4fa4cbd8a4d770cc14a1d52c34a4 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Tue, 13 Jan 2026 16:54:35 -0800 Subject: [PATCH 15/24] Speculative moving the validate_name to the constructor Signed-off-by: Przemek Tredak --- .../pytorch/attention/multi_head_attention.py | 2 -- transformer_engine/pytorch/module/base.py | 15 ++++----------- transformer_engine/pytorch/transformer.py | 3 --- 3 files changed, 4 insertions(+), 16 deletions(-) diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index f875fd1e0a..d7bd0f1237 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -739,8 +739,6 @@ def forward( core_attention_bias_type in AttnBiasTypes ), f"core_attention_bias_type {core_attention_bias_type} is not supported!" - if TEDebugState.debug_enabled: - TransformerEngineBaseModule._validate_name(self) # ================================================= # Pre-allocate memory for key-value cache for inference diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index f1767cdeea..a9dfd4b6ec 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -633,6 +633,7 @@ def __init__(self) -> None: if not TEDebugState.debug_enabled: TEDebugState.initialize() + self._validate_name() def fast_setattr(self, name: str, value: Any) -> None: """ @@ -1546,7 +1547,6 @@ def is_debug_iter(self) -> bool: debug = TEDebugState.debug_enabled if not debug: return False - self._validate_name() # If layer is run first time in new iteration, # we need to check if the debug should be enabled for this layer - @@ -1597,22 +1597,15 @@ def no_debug_features_active(self, quantizers): def _validate_name(self): """ Validate name passed to the module. - This is invoked in the forward() method as module names are assigned after Model is initialized in Megatron-LM. - If no name is assigned, it creates a default name with layer count as the variable. + It creates a default name with layer count as the variable + which may be changed by the user of the module. 
""" if self.name is not None: return assert TEDebugState.debug_enabled import nvdlfw_inspect.api as debug_api - if self.name is None: - debug_api.log_message( - "Names are not provided to debug modules. ", - "Creating and using generic names. Pass names to debug modules for better" - " insight. ", - level=logging.WARNING, - ) - self.fast_setattr("name", f"Layer_{TEDebugState.get_layer_count()}") + self.name = f"Layer_{TEDebugState.get_layer_count()}" def _check_weight_tensor_recipe_correspondence(self) -> None: """ diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 9b9ccc5185..a2eb68dcc8 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -768,9 +768,6 @@ def forward( enc_dec_attn_mask[i].dtype == torch.bool for i in range(len(enc_dec_attn_mask)) ), "Encoder-decoder attention mask must be boolean tensor(s)" - if TEDebugState.debug_enabled: - TransformerEngineBaseModule._validate_name(self) - # For AMP if torch.is_autocast_enabled(): hidden_states = cast_if_needed(hidden_states, torch_get_autocast_gpu_dtype()) From 09a3a7c11c41ae589d4d9cd7474f4258c96ba62b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 14 Jan 2026 00:55:51 +0000 Subject: [PATCH 16/24] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/attention/multi_head_attention.py | 1 - 1 file changed, 1 deletion(-) diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index d7bd0f1237..5ee3316a09 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -739,7 +739,6 @@ def forward( core_attention_bias_type in AttnBiasTypes ), f"core_attention_bias_type {core_attention_bias_type} is not supported!" 
- # ================================================= # Pre-allocate memory for key-value cache for inference # ================================================= From 026b4b12f5d380fbf8597aa6fe747a54995e91b0 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Tue, 13 Jan 2026 17:02:25 -0800 Subject: [PATCH 17/24] Fix Signed-off-by: Przemek Tredak --- transformer_engine/pytorch/attention/multi_head_attention.py | 1 + transformer_engine/pytorch/transformer.py | 1 + 2 files changed, 2 insertions(+) diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index 5ee3316a09..143252640b 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -335,6 +335,7 @@ def __init__( self.hidden_size_kv = self.hidden_size_per_attention_head * self.num_gqa_groups self.name = name + TransformerEngineBaseModule._validate_name(self) common_gemm_kwargs = { "fuse_wgrad_accumulation": fuse_wgrad_accumulation, diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index a2eb68dcc8..fb2676d422 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -398,6 +398,7 @@ def __init__( self.softmax_type = softmax_type self.name = name + TransformerEngineBaseModule._validate_name(self) attention_args = ( hidden_size, From e83e8165df965e98e22e781518b49cab85205bbb Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Tue, 13 Jan 2026 17:22:03 -0800 Subject: [PATCH 18/24] Making the debug tools names saner Signed-off-by: Przemek Tredak --- transformer_engine/pytorch/module/base.py | 6 ++---- transformer_engine/pytorch/module/grouped_linear.py | 3 +-- transformer_engine/pytorch/module/layernorm_linear.py | 5 ++--- transformer_engine/pytorch/module/layernorm_mlp.py | 5 ++--- transformer_engine/pytorch/module/linear.py | 3 +-- transformer_engine/pytorch/transformer.py | 6 +++--- 6 files changed, 11 insertions(+), 17 deletions(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index a9dfd4b6ec..eb83722512 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -605,10 +605,10 @@ def fill_userbuffers_buffer_for_all_gather( class TransformerEngineBaseModule(torch.nn.Module, ABC): """Base TE module.""" - def __init__(self) -> None: + def __init__(self, name: Optional[str] = None) -> None: super().__init__() assert torch.cuda.is_available(), "TransformerEngine needs CUDA." 
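With the base constructor now taking name, callers can label layers explicitly instead of relying on the auto-generated Layer_<count> fallback, and the rest of this patch threads the parent name into submodules (e.g. ".self_attention", ".layernorm_mlp"). A hypothetical construction, with argument values chosen only for illustration:

    import transformer_engine.pytorch as te

    # Explicit names feed the debug-tool identifiers; nested modules derive theirs
    # from the parent, as in the TransformerLayer changes later in this patch.
    layer = te.TransformerLayer(
        hidden_size=1024,
        ffn_hidden_size=4096,
        num_attention_heads=16,
        name="decoder.block0",
    )
    proj = te.Linear(1024, 1024, name="decoder.block0.out_proj")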
- self.name = None + self.name = name self.next_iter_when_debug_should_be_run = 0 self.fp8_initialized = False self.fp8 = False @@ -1602,8 +1602,6 @@ def _validate_name(self): """ if self.name is not None: return - assert TEDebugState.debug_enabled - import nvdlfw_inspect.api as debug_api self.name = f"Layer_{TEDebugState.get_layer_count()}" diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index b5fbb0b8fd..ec1da1e02e 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -614,7 +614,7 @@ def __init__( save_original_input: bool = False, name: Optional[str] = None, ) -> None: - super().__init__() + super().__init__(name) params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype self.num_gemms = num_gemms @@ -633,7 +633,6 @@ def __init__( ), "GroupedLinear doesn't support Userbuffer overlap." self.get_rng_state_tracker = get_rng_state_tracker self.rng_tracker_name = rng_tracker_name - self.name = name self.wgrad_store = WeightGradStore(delay_wgrad_compute) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index dd0e1a4c61..2b821e38a5 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -1161,9 +1161,9 @@ def __init__( ub_name: Optional[str] = None, delay_wgrad_compute: bool = False, symmetric_ar_type: Optional[str] = None, - name: str = None, + name: Optional[str] = None, ) -> None: - super().__init__() + super().__init__(name) params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype self.in_features = in_features @@ -1182,7 +1182,6 @@ def __init__( self.symmetric_ar_type = symmetric_ar_type self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad) - self.name = name if tp_group is None: self.tp_size = tp_size diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index a283640c7e..0311092449 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1789,7 +1789,7 @@ def __init__( zero_centered_gamma: bool = False, device: Union[torch.device, str] = "cuda", ub_overlap_ag: bool = False, - name: str = None, + name: Optional[str] = None, ub_overlap_rs: bool = False, ub_overlap_rs_dgrad: bool = False, ub_bulk_dgrad: bool = False, @@ -1798,7 +1798,7 @@ def __init__( symmetric_ar_type: Optional[str] = None, checkpoint: bool = False, ) -> None: - super().__init__() + super().__init__(name) params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype self.fuse_wgrad_accumulation = fuse_wgrad_accumulation @@ -1829,7 +1829,6 @@ def __init__( for use_fp8 in [False, True] ) ) - self.name = name self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index cf85a1d836..68c9758bbc 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -1098,7 +1098,7 @@ def __init__( save_original_input: bool = False, name: Optional[str] = None, ) -> None: - super().__init__() + super().__init__(name) params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype self.in_features = in_features @@ -1111,7 +1111,6 @@ def __init__( 
self.rng_tracker_name = rng_tracker_name self.symmetric_ar_type = symmetric_ar_type self.save_original_input = save_original_input - self.name = name self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad) diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index fb2676d422..1f856d386f 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -447,7 +447,7 @@ def __init__( qk_norm_type=qk_norm_type, qk_norm_eps=qk_norm_eps, qk_norm_before_rope=qk_norm_before_rope, - name=name + ".self_attention" if name is not None else None, + name=self.name + ".self_attention" if self.name is not None else None, ) if layer_type == "decoder": @@ -464,7 +464,7 @@ def __init__( qk_norm_type=qk_norm_type, qk_norm_eps=qk_norm_eps, qk_norm_before_rope=qk_norm_before_rope, - name=name + ".inter_attention" if name is not None else None, + name=self.name + ".inter_attention" if self.name is not None else None, ) # LayerNorm -> activation(Linear + Bias) -> Linear @@ -500,7 +500,7 @@ def __init__( activation_params=activation_params, normalization=normalization, device=device, - name=name + ".layernorm_mlp" if name is not None else None, + name=self.name + ".layernorm_mlp" if self.name is not None else None, ) self.hidden_dropout = hidden_dropout From 8843a462b5fffddc8fcfb414585bc8a7e48fa58a Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Wed, 14 Jan 2026 13:12:37 -0800 Subject: [PATCH 19/24] Change the setattr usage in the tensor parallel group setting Signed-off-by: Przemek Tredak --- transformer_engine/pytorch/module/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index eb83722512..aa78ccaa2a 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -964,8 +964,8 @@ def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> N tp_group : ProcessGroup, default = None tensor parallel process group. 
""" - self.tp_group = tp_group - self.tp_group_initialized = True + self.fast_setattr("tp_group", tp_group) + self.fast_setattr("tp_group_initialized", True) def _get_fp8_params(self) -> Union[List[torch.Tensor], None]: """returns the FP8 weights.""" From 084847f3e0c1157391863509cae68ca7211a54ab Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Wed, 14 Jan 2026 15:41:30 -0800 Subject: [PATCH 20/24] Adding try/finally - it does not seem to impact the time in observable way Signed-off-by: Przemek Tredak --- .../pytorch/module/grouped_linear.py | 110 ++++----- .../pytorch/module/layernorm_linear.py | 156 ++++++------- .../pytorch/module/layernorm_mlp.py | 214 +++++++++--------- transformer_engine/pytorch/module/linear.py | 141 ++++++------ 4 files changed, 314 insertions(+), 307 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index ec1da1e02e..10a28e01f8 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -795,61 +795,63 @@ def forward( is_grad_enabled = torch.is_grad_enabled() inp = self.prepare_forward(inp, num_gemms=self.num_gemms) - weight_tensors = self._get_weight_tensors() - bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] - - quantizers = self._get_quantizers() if not debug else self._get_debug_quantizers() - - if debug: - if self.no_debug_features_active(list(chain(*quantizers))): - debug = False - quantizers = self._get_quantizers() - - if isinstance(weight_tensors, QuantizedTensorStorage): - raise RuntimeError("FP8 weights are not supported in debug mode.") - - ( - input_quantizers, - weight_quantizers, - output_quantizers, - grad_input_quantizers, - grad_weight_quantizers, - grad_output_quantizers, - ) = quantizers - - if is_grad_enabled: - linear_fn = _GroupedLinear.apply - autograd_ctx = [] - else: - linear_fn = _GroupedLinear.forward - autograd_ctx = [None] - - non_tensor_args = ( - m_splits, - self.apply_bias, - is_first_microbatch, - self.fp8, - self.fp8_calibration, - self.wgrad_store, - input_quantizers, - weight_quantizers, - output_quantizers, - grad_input_quantizers, - grad_weight_quantizers, - grad_output_quantizers, - self.fuse_wgrad_accumulation, - is_cpu_offload_enabled(), - self.sequence_parallel, - self.activation_dtype, - is_grad_enabled, - self, - None, # skip_fp8_weight_update - self.save_original_input, - debug, - ) - out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors) + try: + weight_tensors = self._get_weight_tensors() + bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] + + quantizers = self._get_quantizers() if not debug else self._get_debug_quantizers() + + if debug: + if self.no_debug_features_active(list(chain(*quantizers))): + debug = False + quantizers = self._get_quantizers() + + if isinstance(weight_tensors, QuantizedTensorStorage): + raise RuntimeError("FP8 weights are not supported in debug mode.") + + ( + input_quantizers, + weight_quantizers, + output_quantizers, + grad_input_quantizers, + grad_weight_quantizers, + grad_output_quantizers, + ) = quantizers + + if is_grad_enabled: + linear_fn = _GroupedLinear.apply + autograd_ctx = [] + else: + linear_fn = _GroupedLinear.forward + autograd_ctx = [None] + + non_tensor_args = ( + m_splits, + self.apply_bias, + is_first_microbatch, + self.fp8, + self.fp8_calibration, + self.wgrad_store, + input_quantizers, + weight_quantizers, + output_quantizers, + 
grad_input_quantizers, + grad_weight_quantizers, + grad_output_quantizers, + self.fuse_wgrad_accumulation, + is_cpu_offload_enabled(), + self.sequence_parallel, + self.activation_dtype, + is_grad_enabled, + self, + None, # skip_fp8_weight_update + self.save_original_input, + debug, + ) + out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors) - self.end_forward() + finally: + self.end_forward() if self.return_bias: return out, [cast_if_needed(b, self.activation_dtype) for b in bias_tensors] diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 2b821e38a5..484ad8ffba 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -1519,85 +1519,87 @@ def forward( inp, allow_non_contiguous=False # removed .contiguous from inside the layer ) - # Get concatenated weight and bias tensors - weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() - - quantizers = ( - self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) - if not debug - else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) - ) - if debug: - if self.no_debug_features_active(quantizers): - debug = False - quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) - - ( - input_quantizer, - weight_quantizer, - output_quantizer, - grad_input_quantizer, - grad_weight_quantizer, - grad_output_quantizer, - ) = quantizers + try: + # Get concatenated weight and bias tensors + weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() + + quantizers = ( + self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) + if not debug + else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) + ) + if debug: + if self.no_debug_features_active(quantizers): + debug = False + quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) - if is_grad_enabled: - fwd_fn = _LayerNormLinear.apply - autograd_ctx = [] - else: - fwd_fn = _LayerNormLinear.forward - autograd_ctx = [None] - non_tensor_args = ( - self.eps, - is_first_microbatch, - self.fp8, - self.fp8_calibration, - self.wgrad_store, - self.fuse_wgrad_accumulation, - input_quantizer, - weight_quantizer, - output_quantizer, - grad_input_quantizer, - grad_weight_quantizer, - grad_output_quantizer, - is_cpu_offload_enabled(), - self.tp_group, - self.tp_size, - self.sequence_parallel, - self.tp_size > 1, - self.activation_dtype, - self.parallel_mode, - self.return_layernorm_output, - self.return_layernorm_output_gathered, - is_grad_enabled, - self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin, - self.bwd_ln_sm_margin, - self.zero_centered_gamma, - self.normalization, - self.ub_overlap_ag_fprop, - self.ub_overlap_rs_fprop, - self.ub_overlap_ag_dgrad, - self.ub_overlap_rs_dgrad, - self.ub_bulk_wgrad, - self.ub_bulk_dgrad, - self.ub_name, - self.fsdp_group, - self, - skip_fp8_weight_update, - self.symmetric_ar_type, - debug, - ) - out = fwd_fn( - *autograd_ctx, - inp, - self.layer_norm_weight, - self.layer_norm_bias, - weight_tensor, - bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None, - non_tensor_args, - ) + ( + input_quantizer, + weight_quantizer, + output_quantizer, + grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, + ) = quantizers + + if is_grad_enabled: + fwd_fn = _LayerNormLinear.apply + autograd_ctx = [] + else: + fwd_fn = _LayerNormLinear.forward + autograd_ctx = [None] + 
non_tensor_args = ( + self.eps, + is_first_microbatch, + self.fp8, + self.fp8_calibration, + self.wgrad_store, + self.fuse_wgrad_accumulation, + input_quantizer, + weight_quantizer, + output_quantizer, + grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, + is_cpu_offload_enabled(), + self.tp_group, + self.tp_size, + self.sequence_parallel, + self.tp_size > 1, + self.activation_dtype, + self.parallel_mode, + self.return_layernorm_output, + self.return_layernorm_output_gathered, + is_grad_enabled, + self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin, + self.bwd_ln_sm_margin, + self.zero_centered_gamma, + self.normalization, + self.ub_overlap_ag_fprop, + self.ub_overlap_rs_fprop, + self.ub_overlap_ag_dgrad, + self.ub_overlap_rs_dgrad, + self.ub_bulk_wgrad, + self.ub_bulk_dgrad, + self.ub_name, + self.fsdp_group, + self, + skip_fp8_weight_update, + self.symmetric_ar_type, + debug, + ) + out = fwd_fn( + *autograd_ctx, + inp, + self.layer_norm_weight, + self.layer_norm_bias, + weight_tensor, + bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None, + non_tensor_args, + ) - self.end_forward() + finally: + self.end_forward() if self.return_layernorm_output: out, ln_out = out diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 0311092449..bf9f4bd08d 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -2055,115 +2055,117 @@ def forward( inp = self.prepare_forward(inp, num_gemms=2) - quantizers = ( - self._get_quantizers(fp8_output, is_grad_enabled) - if not debug - else self._get_debug_quantizers(fp8_output, is_grad_enabled) - ) - if debug: - if self.no_debug_features_active(quantizers): - debug = False - quantizers = self._get_quantizers(fp8_output, is_grad_enabled) - - # Get quantizers - ( - fc1_input_quantizer, - fc1_weight_quantizer, - fc1_output_quantizer, - fc1_grad_input_quantizer, - fc1_grad_weight_quantizer, - fc1_grad_output_quantizer, - fc2_input_quantizer, - fc2_weight_quantizer, - fc2_output_quantizer, - fc2_grad_input_quantizer, - fc2_grad_weight_quantizer, - fc2_grad_output_quantizer, - ) = quantizers + try: + quantizers = ( + self._get_quantizers(fp8_output, is_grad_enabled) + if not debug + else self._get_debug_quantizers(fp8_output, is_grad_enabled) + ) + if debug: + if self.no_debug_features_active(quantizers): + debug = False + quantizers = self._get_quantizers(fp8_output, is_grad_enabled) - # Get weight tensors - fc1_weight, fc2_weight = self._get_weight_tensors() - fc1_bias = self.fc1_bias if self.use_bias else None - fc2_bias = self.fc2_bias if self.use_bias else None - if not self.fp8: - if isinstance(fc1_weight, Float8Tensor): - fc1_weight = fc1_weight.dequantize() - if isinstance(fc2_weight, Float8Tensor): - fc2_weight = fc2_weight.dequantize() - - # Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode - if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute(): - self.fast_setattr("bias_gelu_nvfusion", False) - - if is_grad_enabled: - fwd_fn = _LayerNormMLP.apply - autograd_ctx = [] - else: - fwd_fn = _LayerNormMLP.forward - autograd_ctx = [None] + # Get quantizers + ( + fc1_input_quantizer, + fc1_weight_quantizer, + fc1_output_quantizer, + fc1_grad_input_quantizer, + fc1_grad_weight_quantizer, + fc1_grad_output_quantizer, + fc2_input_quantizer, + fc2_weight_quantizer, + fc2_output_quantizer, + fc2_grad_input_quantizer, + 
fc2_grad_weight_quantizer, + fc2_grad_output_quantizer, + ) = quantizers + + # Get weight tensors + fc1_weight, fc2_weight = self._get_weight_tensors() + fc1_bias = self.fc1_bias if self.use_bias else None + fc2_bias = self.fc2_bias if self.use_bias else None + if not self.fp8: + if isinstance(fc1_weight, Float8Tensor): + fc1_weight = fc1_weight.dequantize() + if isinstance(fc2_weight, Float8Tensor): + fc2_weight = fc2_weight.dequantize() + + # Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode + if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute(): + self.fast_setattr("bias_gelu_nvfusion", False) - non_tensor_args = ( - self.eps, - is_first_microbatch, - self.fp8, - self.fp8_calibration, - self.wgrad_store, - self.fuse_wgrad_accumulation, - fc1_input_quantizer, - fc1_weight_quantizer, - fc1_output_quantizer, - fc1_grad_input_quantizer, - fc1_grad_weight_quantizer, - fc1_grad_output_quantizer, - fc2_input_quantizer, - fc2_weight_quantizer, - fc2_output_quantizer, - fc2_grad_input_quantizer, - fc2_grad_weight_quantizer, - fc2_grad_output_quantizer, - is_cpu_offload_enabled(), - self.tp_group, - self.tp_size, - self.sequence_parallel, - self.tp_size > 1, - self.activation_dtype, - self.return_layernorm_output, - self.return_layernorm_output_gathered, - self.bias_gelu_nvfusion and not self.fp8 and not debug, - self.set_parallel_mode, - is_grad_enabled, - self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin, - self.bwd_ln_sm_margin, - self.zero_centered_gamma, - self.activation, - self.activation_params, - self.normalization, - self.ub_overlap_ag, - self.ub_overlap_rs, - self.ub_overlap_rs_dgrad, - self.ub_bulk_dgrad, - self.ub_bulk_wgrad, - self.gemm_gelu_fusion and not debug, - self.fsdp_group, - self, - skip_fp8_weight_update, - self.symmetric_ar_type, - self.checkpoint, - debug, - ) - out = fwd_fn( - *autograd_ctx, - inp, - self.layer_norm_weight, - self.layer_norm_bias, - fc1_weight, - fc1_bias, - fc2_weight, - fc2_bias if self.apply_bias and not self.gemm_bias_unfused_add else None, - non_tensor_args, - ) + if is_grad_enabled: + fwd_fn = _LayerNormMLP.apply + autograd_ctx = [] + else: + fwd_fn = _LayerNormMLP.forward + autograd_ctx = [None] + + non_tensor_args = ( + self.eps, + is_first_microbatch, + self.fp8, + self.fp8_calibration, + self.wgrad_store, + self.fuse_wgrad_accumulation, + fc1_input_quantizer, + fc1_weight_quantizer, + fc1_output_quantizer, + fc1_grad_input_quantizer, + fc1_grad_weight_quantizer, + fc1_grad_output_quantizer, + fc2_input_quantizer, + fc2_weight_quantizer, + fc2_output_quantizer, + fc2_grad_input_quantizer, + fc2_grad_weight_quantizer, + fc2_grad_output_quantizer, + is_cpu_offload_enabled(), + self.tp_group, + self.tp_size, + self.sequence_parallel, + self.tp_size > 1, + self.activation_dtype, + self.return_layernorm_output, + self.return_layernorm_output_gathered, + self.bias_gelu_nvfusion and not self.fp8 and not debug, + self.set_parallel_mode, + is_grad_enabled, + self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin, + self.bwd_ln_sm_margin, + self.zero_centered_gamma, + self.activation, + self.activation_params, + self.normalization, + self.ub_overlap_ag, + self.ub_overlap_rs, + self.ub_overlap_rs_dgrad, + self.ub_bulk_dgrad, + self.ub_bulk_wgrad, + self.gemm_gelu_fusion and not debug, + self.fsdp_group, + self, + skip_fp8_weight_update, + self.symmetric_ar_type, + self.checkpoint, + debug, + ) + out = fwd_fn( + *autograd_ctx, + inp, + self.layer_norm_weight, + 
self.layer_norm_bias, + fc1_weight, + fc1_bias, + fc2_weight, + fc2_bias if self.apply_bias and not self.gemm_bias_unfused_add else None, + non_tensor_args, + ) - self.end_forward() + finally: + self.end_forward() if self.return_layernorm_output: out, ln_out = out diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 68c9758bbc..98d90aca20 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -1400,78 +1400,79 @@ def forward( fp8_grad = True inp = self.prepare_forward(inp, allow_non_contiguous=isinstance(inp, QuantizedTensor)) + try: + weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() - weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() - - quantizers = ( - self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) - if not debug - else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) - ) - if debug: - if self.no_debug_features_active(quantizers): - debug = False - quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) - - ( - input_quantizer, - weight_quantizer, - output_quantizer, - grad_input_quantizer, - grad_weight_quantizer, - grad_output_quantizer, - ) = quantizers - - if is_grad_enabled: - linear_fn = _Linear.apply - autograd_ctx = [] - else: - linear_fn = _Linear.forward - autograd_ctx = [None] + quantizers = ( + self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) + if not debug + else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled) + ) + if debug: + if self.no_debug_features_active(quantizers): + debug = False + quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) - non_tensor_args = ( - is_first_microbatch, - self.fp8, - self.fp8_calibration, - self.wgrad_store, - input_quantizer, - weight_quantizer, - output_quantizer, - grad_input_quantizer, - grad_weight_quantizer, - grad_output_quantizer, - self.fuse_wgrad_accumulation, - is_cpu_offload_enabled(), - self.tp_group, - self.tp_size, - self.sequence_parallel, - self.tp_size > 1, - self.activation_dtype, - self.parallel_mode, - is_grad_enabled, - self.ub_overlap_rs_fprop, - self.ub_overlap_ag_dgrad, - self.ub_overlap_ag_fprop, - self.ub_overlap_rs_dgrad, - self.ub_bulk_dgrad, - self.ub_bulk_wgrad, - self.ub_name, - fp8_output, - self.fsdp_group, - self, - skip_fp8_weight_update, - self.symmetric_ar_type, - self.save_original_input, - debug, - ) - out = linear_fn( - *autograd_ctx, - weight_tensor, - inp, - bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None, - non_tensor_args, - ) - self.end_forward() + ( + input_quantizer, + weight_quantizer, + output_quantizer, + grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, + ) = quantizers + + if is_grad_enabled: + linear_fn = _Linear.apply + autograd_ctx = [] + else: + linear_fn = _Linear.forward + autograd_ctx = [None] + + non_tensor_args = ( + is_first_microbatch, + self.fp8, + self.fp8_calibration, + self.wgrad_store, + input_quantizer, + weight_quantizer, + output_quantizer, + grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, + self.fuse_wgrad_accumulation, + is_cpu_offload_enabled(), + self.tp_group, + self.tp_size, + self.sequence_parallel, + self.tp_size > 1, + self.activation_dtype, + self.parallel_mode, + is_grad_enabled, + self.ub_overlap_rs_fprop, + self.ub_overlap_ag_dgrad, + self.ub_overlap_ag_fprop, + self.ub_overlap_rs_dgrad, + self.ub_bulk_dgrad, + self.ub_bulk_wgrad, + self.ub_name, + 
fp8_output, + self.fsdp_group, + self, + skip_fp8_weight_update, + self.symmetric_ar_type, + self.save_original_input, + debug, + ) + out = linear_fn( + *autograd_ctx, + weight_tensor, + inp, + bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None, + non_tensor_args, + ) + finally: + self.end_forward() if self.gemm_bias_unfused_add: out = out + cast_if_needed(bias_tensor, self.activation_dtype) From d004747eedb4272075b1c7fd13f79d8540cd5d18 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Thu, 15 Jan 2026 10:25:57 -0800 Subject: [PATCH 21/24] Fixing lint issues and the thunder test Signed-off-by: Przemek Tredak --- qa/L1_pytorch_thunder_integration/test.sh | 2 +- transformer_engine/pytorch/attention/multi_head_attention.py | 1 - transformer_engine/pytorch/module/base.py | 3 +-- transformer_engine/pytorch/transformer.py | 1 - 4 files changed, 2 insertions(+), 5 deletions(-) diff --git a/qa/L1_pytorch_thunder_integration/test.sh b/qa/L1_pytorch_thunder_integration/test.sh index a5fd33cda9..8c3fdc8cdb 100644 --- a/qa/L1_pytorch_thunder_integration/test.sh +++ b/qa/L1_pytorch_thunder_integration/test.sh @@ -9,7 +9,7 @@ set -x mkdir -p "$XML_LOG_DIR" pip3 install pytest==8.1.1 pytest-benchmark==5.1.0 -python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/pytest.xml ${THUNDER_PATH}/thunder/tests/test_transformer_engine_executor.py +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest.xml ${THUNDER_PATH}/thunder/tests/test_transformer_engine_executor.py # Check return code # Note: Return code 5 is fine. Lightning tests are skipped on systems diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index 143252640b..d813e7c8f1 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -8,7 +8,6 @@ from typing import Callable, List, Optional, Tuple, Union import torch -from transformer_engine.debug.pytorch.debug_state import TEDebugState from transformer_engine.pytorch.quantization import FP8GlobalStateManager from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor from transformer_engine.pytorch.module.base import TransformerEngineBaseModule diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index aa78ccaa2a..acb730cf83 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -11,11 +11,10 @@ from enum import Enum from abc import ABC, abstractmethod from typing import Any, Dict, Generator, List, Optional, Tuple, Union -from typing_extensions import Self from contextlib import contextmanager -import logging from types import MethodType +from typing_extensions import Self import torch import torch.nn.functional as F from torch.distributed.tensor import DTensor diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 1f856d386f..7c3125a165 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -12,7 +12,6 @@ from transformer_engine.pytorch.torch_version import torch_version from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm -from transformer_engine.debug.pytorch.debug_state import TEDebugState from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention from transformer_engine.pytorch.attention.inference import 
InferenceParams from transformer_engine.pytorch.jit import ( From 523944a5c70bc1e5bbc92b0358612bf10f3b9165 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Thu, 15 Jan 2026 14:47:48 -0800 Subject: [PATCH 22/24] Fix 1 of the debug tests Signed-off-by: Przemek Tredak --- tests/pytorch/debug/test_sanity.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pytorch/debug/test_sanity.py b/tests/pytorch/debug/test_sanity.py index aee5474e76..4d80882aea 100644 --- a/tests/pytorch/debug/test_sanity.py +++ b/tests/pytorch/debug/test_sanity.py @@ -33,7 +33,7 @@ LogFp8TensorStats: enabled: True tensors: [activation, gradient, weight] - stats: [underflows, overflows] + stats: [underflows%] start_step : 0 end_step: 1 """, From 9873235d0c3d6dfd5d0e06f9c6c6ed95b43d8a78 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Thu, 15 Jan 2026 15:08:18 -0800 Subject: [PATCH 23/24] Removed the warning and enforcement in the CI Signed-off-by: Przemek Tredak --- qa/L0_pytorch_debug_unittest/test.sh | 16 +-- qa/L0_pytorch_unittest/test.sh | 52 ++++---- qa/L1_pytorch_distributed_unittest/test.sh | 24 ++-- qa/L1_pytorch_onnx_unittest/test.sh | 2 +- tests/pytorch/pytest.ini | 8 -- tests/pytorch/test_numerics.py | 124 ++++++++---------- .../dot_product_attention.py | 2 - transformer_engine/pytorch/graph.py | 8 +- transformer_engine/pytorch/module/base.py | 16 --- .../pytorch/module/grouped_linear.py | 2 - .../pytorch/module/layernorm_linear.py | 2 - .../pytorch/module/layernorm_mlp.py | 2 - transformer_engine/pytorch/module/linear.py | 2 - 13 files changed, 104 insertions(+), 156 deletions(-) delete mode 100644 tests/pytorch/pytest.ini diff --git a/qa/L0_pytorch_debug_unittest/test.sh b/qa/L0_pytorch_debug_unittest/test.sh index 6d40f44b08..ce65bc4305 100644 --- a/qa/L0_pytorch_debug_unittest/test.sh +++ b/qa/L0_pytorch_debug_unittest/test.sh @@ -28,16 +28,16 @@ mkdir -p "$XML_LOG_DIR" pip install pytest==8.2.1 || error_exit "Failed to install pytest" -pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/test_sanity.xml $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_sanity.py" -pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/test_config.xml $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_config.py" -pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/test_numerics.xml $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_numerics.py" -pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/test_log.xml $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_log.py" -NVTE_TORCH_COMPILE=0 pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/test_api_features.xml $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_api_features.py" -pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/test_perf.xml $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_perf.py" +pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity.xml $TE_PATH/tests/pytorch/debug/test_sanity.py 
--feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_sanity.py" +pytest -v -s --junitxml=$XML_LOG_DIR/test_config.xml $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_config.py" +pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics.xml $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_numerics.py" +pytest -v -s --junitxml=$XML_LOG_DIR/test_log.xml $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_log.py" +NVTE_TORCH_COMPILE=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_api_features.xml $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_api_features.py" +pytest -v -s --junitxml=$XML_LOG_DIR/test_perf.xml $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_perf.py" # standard sanity and numerics tests with initialized debug -NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/test_sanity_2.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "debug test_sanity.py" -NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/test_numerics_2.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "debug test_numerics.py" +NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity_2.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "debug test_sanity.py" +NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics_2.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "debug test_numerics.py" if [ "$RET" -ne 0 ]; then echo "Error in the following test cases:$FAILED_CASES" diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 454536358c..21eed28367 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -24,32 +24,32 @@ mkdir -p "$XML_LOG_DIR" pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" -python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py" -python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py" -python3 -m pytest -c 
$TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_deferred_init.xml $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py" -PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py" -PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py" -python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py" -python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py" -python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_nvfp4.xml $TE_PATH/tests/pytorch/nvfp4 || test_fail "test_nvfp4" -python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_quantized_tensor.xml $TE_PATH/tests/pytorch/test_quantized_tensor.py || test_fail "test_quantized_tensor.py" -python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py" -python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py" -python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py" -python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py" -python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py" -python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py" -python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py" -python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py" -python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py" -python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml 
$TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py" -NVTE_FLASH_ATTN=0 NVTE_CPU_OFFLOAD_V1=1 python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading_v1.xml $TE_PATH/tests/pytorch/test_cpu_offloading_v1.py || test_fail "test_cpu_offloading_v1.py" -python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py" -python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py" -python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py" -NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py" -python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py" -python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_partial_cast.xml $TE_PATH/tests/pytorch/test_partial_cast.py || test_fail "test_partial_cast.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_deferred_init.xml $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py" +PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py" +PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_nvfp4.xml $TE_PATH/tests/pytorch/nvfp4 || test_fail "test_nvfp4" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_quantized_tensor.xml $TE_PATH/tests/pytorch/test_quantized_tensor.py || test_fail "test_quantized_tensor.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py" 
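The -c flag disappears from these invocations because the shared tests/pytorch/pytest.ini, deleted later in this patch, existed only to turn RuntimeWarning into a test error; with the RuntimeWarning removed from the base module's __setattr__ in the same patch, the custom config has nothing left to enforce. For illustration only (not part of the patch), a roughly equivalent warnings policy in plain Python:

    import warnings

    # Rough equivalent of the deleted pytest.ini entry
    # "filterwarnings = error::RuntimeWarning": any RuntimeWarning raised
    # during the run is escalated to an exception.
    warnings.simplefilter("error", category=RuntimeWarning)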
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py" +NVTE_FLASH_ATTN=0 NVTE_CPU_OFFLOAD_V1=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading_v1.xml $TE_PATH/tests/pytorch/test_cpu_offloading_v1.py || test_fail "test_cpu_offloading_v1.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py" +NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_partial_cast.xml $TE_PATH/tests/pytorch/test_partial_cast.py || test_fail "test_partial_cast.py" if [ "$RET" -ne 0 ]; then echo "Error in the following test cases:$FAILED_CASES" diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index 12439422c4..9d868d99cf 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -22,16 +22,16 @@ mkdir -p "$XML_LOG_DIR" pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" -python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/distributed/test_sanity.py || test_fail "test_sanity.py" -python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py" -python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics_exact.xml 
$TE_PATH/tests/pytorch/distributed/test_numerics_exact.py || test_fail "test_numerics_exact.py" -python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || test_fail "test_fusible_ops.py" -python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py" -python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py" -python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py" -python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py" -python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/pytest_test_cp_utils.xml $TE_PATH/tests/pytorch/attention/test_cp_utils.py || test_fail "test_cp_utils.py" -python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/distributed/test_sanity.py || test_fail "test_sanity.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics_exact.xml $TE_PATH/tests/pytorch/distributed/test_numerics_exact.py || test_fail "test_numerics_exact.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || test_fail "test_fusible_ops.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cp_utils.xml $TE_PATH/tests/pytorch/attention/test_cp_utils.py || test_fail "test_cp_utils.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py" # debug tests @@ -42,9 +42,9 @@ python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG : 
${NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE:=$TE_PATH/tests/pytorch/debug/test_configs/dummy_feature.yaml} : ${NVTE_TEST_NVINSPECT_FEATURE_DIRS:=$TE_PATH/transformer_engine/debug/features} -pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/pytest_test_distributed.xml $TE_PATH/tests/pytorch/debug/test_distributed.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "debug test_distributed.py" +pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_distributed.xml $TE_PATH/tests/pytorch/debug/test_distributed.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "debug test_distributed.py" # standard numerics tests with initialized debug -NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS pytest -c $TE_PATH/tests/pytorch/pytest.ini -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics_2.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "debug test_numerics.py" +NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics_2.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "debug test_numerics.py" if [ "$RET" -ne 0 ]; then echo "Error in the following test cases:$FAILED_CASES" diff --git a/qa/L1_pytorch_onnx_unittest/test.sh b/qa/L1_pytorch_onnx_unittest/test.sh index 0e84a5ca5e..b3a520e129 100644 --- a/qa/L1_pytorch_onnx_unittest/test.sh +++ b/qa/L1_pytorch_onnx_unittest/test.sh @@ -6,4 +6,4 @@ : ${XML_LOG_DIR:=/logs} mkdir -p "$XML_LOG_DIR" -python3 -m pytest -c $TE_PATH/tests/pytorch/pytest.ini --tb=auto --junitxml=$XML_LOG_DIR/test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py diff --git a/tests/pytorch/pytest.ini b/tests/pytorch/pytest.ini deleted file mode 100644 index e90989721b..0000000000 --- a/tests/pytorch/pytest.ini +++ /dev/null @@ -1,8 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
- -[pytest] -filterwarnings= - error::RuntimeWarning - diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 9a1942f30d..abe2806e66 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -5,7 +5,6 @@ import math import os from typing import Dict, List, Tuple, Optional -import warnings import pytest import random @@ -1297,15 +1296,14 @@ def test_linear_accuracy_delay_wgrad_compute(dtype, bs, model, bias, fuse_wgrad_ ).eval() # Share params - with warnings.catch_warnings(action="ignore", category=RuntimeWarning): - with torch.no_grad(): - te_linear_ref.weight = Parameter(te_linear.weight.clone()) - if bias: - te_linear_ref.bias = Parameter(te_linear.bias.clone()) - if fuse_wgrad_accumulation: - weight = getattr(te_linear, f"weight") - weight.main_grad = torch.rand_like(weight, dtype=torch.float32) - te_linear_ref.weight.main_grad = weight.main_grad.clone() + with torch.no_grad(): + te_linear_ref.weight = Parameter(te_linear.weight.clone()) + if bias: + te_linear_ref.bias = Parameter(te_linear.bias.clone()) + if fuse_wgrad_accumulation: + weight = getattr(te_linear, f"weight") + weight.main_grad = torch.rand_like(weight, dtype=torch.float32) + te_linear_ref.weight.main_grad = weight.main_grad.clone() te_outputs = _test_granular_accuracy(te_linear, bs, dtype, config, delay_wgrad_compute=True) te_outputs_ref = _test_granular_accuracy( @@ -1361,13 +1359,12 @@ def test_linear_accuracy_save_original_input(dtype, model, recipe): ).eval() # Share params - with warnings.catch_warnings(action="ignore", category=RuntimeWarning): - with torch.no_grad(): - te_linear_ref.weight = Parameter(te_linear.weight.clone()) - if fuse_wgrad_accumulation: - weight = getattr(te_linear, f"weight") - weight.main_grad = torch.rand_like(weight, dtype=torch.float32) - te_linear_ref.weight.main_grad = weight.main_grad.clone() + with torch.no_grad(): + te_linear_ref.weight = Parameter(te_linear.weight.clone()) + if fuse_wgrad_accumulation: + weight = getattr(te_linear, f"weight") + weight.main_grad = torch.rand_like(weight, dtype=torch.float32) + te_linear_ref.weight.main_grad = weight.main_grad.clone() te_outputs = _test_granular_accuracy(te_linear, bs, dtype, config, recipe=recipe) te_outputs_ref = _test_granular_accuracy(te_linear_ref, bs, dtype, config, recipe=recipe) @@ -1604,18 +1601,17 @@ def test_layernorm_linear_accuracy_delay_wgrad_compute( ).eval() # Share params - with warnings.catch_warnings(action="ignore", category=RuntimeWarning): - with torch.no_grad(): - ln_linear_ref.layer_norm_weight = Parameter(ln_linear.layer_norm_weight.clone()) - if normalization != "RMSNorm": - ln_linear_ref.layer_norm_bias = Parameter(ln_linear.layer_norm_bias.clone()) - ln_linear_ref.weight = Parameter(ln_linear.weight.clone()) - if bias: - ln_linear_ref.bias = Parameter(ln_linear.bias.clone()) - if fuse_wgrad_accumulation: - weight = getattr(ln_linear, f"weight") - weight.main_grad = torch.rand_like(weight, dtype=torch.float32) - ln_linear_ref.weight.main_grad = weight.main_grad.clone() + with torch.no_grad(): + ln_linear_ref.layer_norm_weight = Parameter(ln_linear.layer_norm_weight.clone()) + if normalization != "RMSNorm": + ln_linear_ref.layer_norm_bias = Parameter(ln_linear.layer_norm_bias.clone()) + ln_linear_ref.weight = Parameter(ln_linear.weight.clone()) + if bias: + ln_linear_ref.bias = Parameter(ln_linear.bias.clone()) + if fuse_wgrad_accumulation: + weight = getattr(ln_linear, f"weight") + weight.main_grad = torch.rand_like(weight, dtype=torch.float32) + 
ln_linear_ref.weight.main_grad = weight.main_grad.clone() te_outputs = _test_granular_accuracy(ln_linear, bs, dtype, config, delay_wgrad_compute=True) te_outputs_ref = _test_granular_accuracy( @@ -1743,24 +1739,19 @@ def test_layernorm_mlp_accuracy_delay_wgrad_compute( ).eval() # Share params - with warnings.catch_warnings(action="ignore", category=RuntimeWarning): - with torch.no_grad(): - ln_mlp_ref.layer_norm_weight = Parameter(ln_mlp.layer_norm_weight.clone()) - ln_mlp_ref.layer_norm_bias = Parameter(ln_mlp.layer_norm_bias.clone()) - ln_mlp_ref.fc1_weight = Parameter(ln_mlp.fc1_weight.clone()) - ln_mlp_ref.fc2_weight = Parameter(ln_mlp.fc2_weight.clone()) - if bias: - ln_mlp_ref.fc1_bias = Parameter(ln_mlp.fc1_bias.clone()) - ln_mlp_ref.fc2_bias = Parameter(ln_mlp.fc2_bias.clone()) - if fuse_wgrad_accumulation: - ln_mlp.fc1_weight.main_grad = torch.rand_like( - ln_mlp.fc1_weight, dtype=torch.float32 - ) - ln_mlp_ref.fc1_weight.main_grad = ln_mlp.fc1_weight.main_grad.clone() - ln_mlp.fc2_weight.main_grad = torch.rand_like( - ln_mlp.fc2_weight, dtype=torch.float32 - ) - ln_mlp_ref.fc2_weight.main_grad = ln_mlp.fc2_weight.main_grad.clone() + with torch.no_grad(): + ln_mlp_ref.layer_norm_weight = Parameter(ln_mlp.layer_norm_weight.clone()) + ln_mlp_ref.layer_norm_bias = Parameter(ln_mlp.layer_norm_bias.clone()) + ln_mlp_ref.fc1_weight = Parameter(ln_mlp.fc1_weight.clone()) + ln_mlp_ref.fc2_weight = Parameter(ln_mlp.fc2_weight.clone()) + if bias: + ln_mlp_ref.fc1_bias = Parameter(ln_mlp.fc1_bias.clone()) + ln_mlp_ref.fc2_bias = Parameter(ln_mlp.fc2_bias.clone()) + if fuse_wgrad_accumulation: + ln_mlp.fc1_weight.main_grad = torch.rand_like(ln_mlp.fc1_weight, dtype=torch.float32) + ln_mlp_ref.fc1_weight.main_grad = ln_mlp.fc1_weight.main_grad.clone() + ln_mlp.fc2_weight.main_grad = torch.rand_like(ln_mlp.fc2_weight, dtype=torch.float32) + ln_mlp_ref.fc2_weight.main_grad = ln_mlp.fc2_weight.main_grad.clone() te_outputs = _test_granular_accuracy(ln_mlp, bs, dtype, config, delay_wgrad_compute=True) te_outputs_ref = _test_granular_accuracy( @@ -1805,15 +1796,14 @@ def test_layernorm_mlp_accuracy_checkpoint( ).eval() # Share params - with warnings.catch_warnings(action="ignore", category=RuntimeWarning): - with torch.no_grad(): - ln_mlp_ref.layer_norm_weight = Parameter(ln_mlp.layer_norm_weight.clone()) - ln_mlp_ref.layer_norm_bias = Parameter(ln_mlp.layer_norm_bias.clone()) - ln_mlp_ref.fc1_weight = Parameter(ln_mlp.fc1_weight.clone()) - ln_mlp_ref.fc2_weight = Parameter(ln_mlp.fc2_weight.clone()) - if bias: - ln_mlp_ref.fc1_bias = Parameter(ln_mlp.fc1_bias.clone()) - ln_mlp_ref.fc2_bias = Parameter(ln_mlp.fc2_bias.clone()) + with torch.no_grad(): + ln_mlp_ref.layer_norm_weight = Parameter(ln_mlp.layer_norm_weight.clone()) + ln_mlp_ref.layer_norm_bias = Parameter(ln_mlp.layer_norm_bias.clone()) + ln_mlp_ref.fc1_weight = Parameter(ln_mlp.fc1_weight.clone()) + ln_mlp_ref.fc2_weight = Parameter(ln_mlp.fc2_weight.clone()) + if bias: + ln_mlp_ref.fc1_bias = Parameter(ln_mlp.fc1_bias.clone()) + ln_mlp_ref.fc2_bias = Parameter(ln_mlp.fc2_bias.clone()) te_outputs = _test_granular_accuracy(ln_mlp, bs, dtype, config, delay_wgrad_compute=False) te_outputs_ref = _test_granular_accuracy( @@ -1962,13 +1952,9 @@ def test_grouped_linear_accuracy( # Share params with torch.no_grad(): for i in range(num_gemms): - sequential_linear[i].module_setattr( - "weight", Parameter(getattr(grouped_linear, f"weight{i}").clone()) - ) + sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone()) if 
bias: - sequential_linear[i].module_setattr( - "bias", Parameter(getattr(grouped_linear, f"bias{i}").clone()) - ) + sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone()) if fuse_wgrad_accumulation: weight_i = getattr(grouped_linear, f"weight{i}") weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32) @@ -2110,13 +2096,9 @@ def test_grouped_linear_accuracy_save_original_input( # Share params with torch.no_grad(): for i in range(num_gemms): - sequential_linear[i].module_setattr( - "weight", Parameter(getattr(grouped_linear, f"weight{i}").clone()) - ) + sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone()) if bias: - sequential_linear[i].module_setattr( - "bias", Parameter(getattr(grouped_linear, f"bias{i}").clone()) - ) + sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone()) if fuse_wgrad_accumulation: weight_i = getattr(grouped_linear, f"weight{i}") weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32) @@ -2316,7 +2298,8 @@ def test_padding_grouped_linear_accuracy( with torch.no_grad(): inner_grouped_linear = grouped_linear.linear_fn for i in range(num_gemms): - ref_grouped_linear.module_setattr( + setattr( + ref_grouped_linear, f"weight{i}", Parameter(getattr(inner_grouped_linear, f"weight{i}").clone()), ) @@ -2392,7 +2375,8 @@ def test_padding_grouped_linear_accuracy_save_original_input( with torch.no_grad(): inner_grouped_linear = grouped_linear.linear_fn for i in range(num_gemms): - ref_grouped_linear.module_setattr( + setattr( + ref_grouped_linear, f"weight{i}", Parameter(getattr(inner_grouped_linear, f"weight{i}").clone()), ) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index ad7dac108a..51ffbc2e48 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -482,8 +482,6 @@ def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unuse self.register_load_state_dict_post_hook(remove_extra_states_check) - self._initialized = True - def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ): diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 322ad52723..f587ca9946 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -935,8 +935,7 @@ def new_fwd(*user_args, **user_kwargs): forward = make_graphed_forward(func, func.training, graphed, func.forward, te_modules) if _order is None: - with warnings.catch_warnings(action="ignore", category=RuntimeWarning): - func.forward = forward + func.forward = forward ret.append(func) else: ret.append(forward) @@ -944,9 +943,8 @@ def new_fwd(*user_args, **user_kwargs): ret.append(graphed) backward_dw_func, reset_func = make_graphed_attribute_functions(i) - with warnings.catch_warnings(action="ignore", category=RuntimeWarning): - setattr(ret[-1], "backward_dw", backward_dw_func) - setattr(ret[-1], "reset", reset_func) + setattr(ret[-1], "backward_dw", backward_dw_func) + setattr(ret[-1], "reset", reset_func) if just_one_callable: return ret[0] diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index acb730cf83..2770d9bb7b 100644 --- 
a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -14,7 +14,6 @@ from contextlib import contextmanager from types import MethodType -from typing_extensions import Self import torch import torch.nn.functional as F from torch.distributed.tensor import DTensor @@ -649,17 +648,6 @@ def module_setattr(self, name: str, value: Any) -> None: """ super().__setattr__(name, value) - def __setattr__(self, name: str, value: Any) -> None: - if "_initialized" in self.__dict__ and self._initialized: - warnings.warn( - """The default implementation of torch.nn.Module introduces significant CPU overhead - when setting attributes and is therefore not recommended. Please use the explicit - calls (fast_setattr for setting regular values and module_setattr for setting - parameters, children modules and buffers).""", - RuntimeWarning, - ) - super().__setattr__(name, value) - def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None: """ Delayed scaling only. @@ -1108,10 +1096,6 @@ def prepare_forward_ctx( ) self.end_forward() - def train(self, mode: bool = True) -> Self: - with warnings.catch_warnings(action="ignore", category=RuntimeWarning): - return super().train(mode) - def set_nccl_overlap_warning_if_tp(self) -> None: """When using TP, the NCCL communication needs to be scheduled before the GEMM for there to be a guaranteed overlap. From the diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 10a28e01f8..e095019330 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -715,8 +715,6 @@ def __init__( if name in (f"weight{i}", f"bias{i}"): param.skip_backward_post_hook = True - self._initialized = True - def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" super().set_meta_tensor(fwd, recipe) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 484ad8ffba..90e5a95c13 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -1404,8 +1404,6 @@ def __init__( if name in self.weight_names or name in self.bias_names: param.skip_backward_post_hook = True - self._initialized = True - def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" super().set_meta_tensor(fwd, recipe) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index bf9f4bd08d..28f8558d34 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1959,8 +1959,6 @@ def __init__( self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0")) - self._initialized = True - def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" super().set_meta_tensor(fwd, recipe) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 98d90aca20..1099de79ec 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -1308,8 +1308,6 @@ def __init__( if name in self.weight_names or name in self.bias_names: param.skip_backward_post_hook = True - 
self._initialized = True - def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" super().set_meta_tensor(fwd, recipe) From 7b55639f2dd31ae735e6f2b179586b642e233e5e Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Fri, 16 Jan 2026 12:27:51 -0800 Subject: [PATCH 24/24] try-finally in the context manager Signed-off-by: Przemek Tredak --- transformer_engine/pytorch/module/base.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 2770d9bb7b..3b8cc5db27 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1091,10 +1091,13 @@ def prepare_forward_ctx( allow_different_data_and_param_types: bool = False, ) -> Generator[torch.Tensor, None, None]: """Checks and prepares for FWD execution.""" - yield self.prepare_forward( + inp = self.prepare_forward( inp, num_gemms, allow_non_contiguous, allow_different_data_and_param_types ) - self.end_forward() + try: + yield inp + finally: + self.end_forward() def set_nccl_overlap_warning_if_tp(self) -> None: """When using TP, the NCCL communication needs to be scheduled
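This last change makes prepare_forward_ctx exception-safe: because it is a generator-based context manager, an exception raised inside the caller's with-block propagates at the yield, and without try/finally the end_forward() call after the yield would be skipped. A minimal, self-contained sketch of the same shape and its intended use (the _Module stub and placeholder input below are illustrative, not the real TE classes):

    from contextlib import contextmanager


    class _Module:
        def prepare_forward(self, inp):
            # validation and setup before the forward body
            return inp

        def end_forward(self):
            # teardown that must not be skipped when the body raises
            pass

        @contextmanager
        def prepare_forward_ctx(self, inp):
            inp = self.prepare_forward(inp)
            try:
                yield inp
            finally:
                self.end_forward()


    # Intended use: end_forward() runs whether or not the body raises.
    module = _Module()
    with module.prepare_forward_ctx("inp") as inp:
        out = inp  # forward computation would go here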