From 6723e976d0eced0d1d694641f7c682ccf036eca9 Mon Sep 17 00:00:00 2001 From: Amrit Krishnan Date: Thu, 3 Oct 2024 08:50:33 -0400 Subject: [PATCH 1/5] Add skeleton for data2vec class --- mmlearn/tasks/data2vec.py | 381 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 381 insertions(+) create mode 100644 mmlearn/tasks/data2vec.py diff --git a/mmlearn/tasks/data2vec.py b/mmlearn/tasks/data2vec.py new file mode 100644 index 0000000..d3c758b --- /dev/null +++ b/mmlearn/tasks/data2vec.py @@ -0,0 +1,381 @@ +"""Data2Vec task.""" + +import inspect +from dataclasses import dataclass +from functools import partial +from typing import Any, Dict, Literal, Optional, Union + +import lightning as L # noqa: N812 +import torch +from hydra_zen import store +from lightning.pytorch.utilities.types import OptimizerLRScheduler +from lightning_utilities.core.rank_zero import rank_zero_warn +from torch import nn + +from mmlearn.datasets.processors.masking import apply_masks +from mmlearn.modules.ema import ExponentialMovingAverage +from mmlearn.modules.losses.data2vec import Data2VecLoss + + +@dataclass +class ModuleKeySpec: + """Module key specification for mapping modules to modalities.""" + + encoder_key: Optional[str] = None + head_key: Optional[str] = None + postprocessor_key: Optional[str] = None + + +@dataclass +class EvaluationSpec: + """Specification for an evaluation task.""" + + task: Any # `EvaluationHooks` expected + run_on_validation: bool = True + run_on_test: bool = True + + +@store(group="task", provider="mmlearn") +class Data2VecTask(L.LightningModule): + """Data2Vec task. + + This class implements the Data2Vec self-supervised learning approach for a single + modality. It can be used as an auxiliary task in multi-modal learning setups. + + Parameters + ---------- + encoder : nn.Module + The encoder for the modality. + optimizer : partial[torch.optim.Optimizer], optional + The optimizer to use for training. + lr_scheduler : Union[ + Dict[str, Union[partial[torch.optim.lr_scheduler.LRScheduler], Any]], + partial[torch.optim.lr_scheduler.LRScheduler] + ], optional + The learning rate scheduler to use for training. + loss : Data2VecLoss, optional + The loss function to use. + ema_decay : float + The initial decay value for EMA. + ema_end_decay : float + The final decay value for EMA. + ema_anneal_end_step : int + The number of steps to anneal the decay from `ema_decay` to `ema_end_decay`. + mask_generator : Any + The mask generator to use for creating masked inputs. + compute_validation_loss : bool + Whether to compute the validation loss. + compute_test_loss : bool + Whether to compute the test loss. + evaluation_tasks : Dict[str, EvaluationSpec], optional + Evaluation tasks to run during validation and testing. + """ + + def __init__( + self, + encoder: nn.Module, + optimizer: Optional[partial[torch.optim.Optimizer]] = None, + lr_scheduler: Optional[ + Union[ + Dict[str, Union[partial[torch.optim.lr_scheduler.LRScheduler], Any]], + partial[torch.optim.lr_scheduler.LRScheduler], + ] + ] = None, + loss: Optional[Data2VecLoss] = None, + ema_decay: float = 0.999, + ema_end_decay: float = 0.9999, + ema_anneal_end_step: int = 300000, + mask_generator: Any = None, + compute_validation_loss: bool = True, + compute_test_loss: bool = True, + evaluation_tasks: Optional[Dict[str, EvaluationSpec]] = None, + ) -> None: + super().__init__() + + self.encoder = encoder + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + self.loss_fn = loss or Data2VecLoss() + + self.ema = ExponentialMovingAverage( + self, + ema_decay, + ema_end_decay, + ema_anneal_end_step, + ) + + self.mask_generator = mask_generator + self.compute_validation_loss = compute_validation_loss + self.compute_test_loss = compute_test_loss + self.evaluation_tasks = evaluation_tasks + + def encode(self, inputs: torch.Tensor) -> torch.Tensor: + """Encode the input values. + + Parameters + ---------- + inputs : torch.Tensor + The input values to encode. + + Returns + ------- + torch.Tensor + The encoded values. + """ + return self.encoder(inputs) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + """Run the forward pass. + + Parameters + ---------- + inputs : torch.Tensor + The input values to forward pass. + + Returns + ------- + torch.Tensor + The forward pass output. + """ + return self.encode(inputs) + + def _compute_loss(self, batch: torch.Tensor) -> torch.Tensor: + """Compute the loss for the batch. + + Parameters + ---------- + batch : torch.Tensor + The batch of data to compute the loss for. + + Returns + ------- + torch.Tensor + The loss for the batch. + """ + masked_input = apply_masks(batch, self.mask_generator()) + student_output = self.encode(masked_input) + with torch.no_grad(): + teacher_output = self.ema.model.encode(batch) + return self.loss_fn(student_output, teacher_output) + + def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor: + """Compute the loss for the batch. + + Parameters + ---------- + batch : torch.Tensor + The batch of data to compute the loss for. + + Returns + ------- + torch.Tensor + The loss for the batch. + """ + loss = self._compute_loss(batch) + + self.log("train/loss", loss, prog_bar=True, sync_dist=True) + self.ema.step(self) + + return loss + + def on_validation_epoch_start(self) -> None: + """Prepare for the validation epoch.""" + self._on_eval_epoch_start("val") + + def validation_step( + self, batch: torch.Tensor, batch_idx: int + ) -> Optional[torch.Tensor]: + """Run a single validation step. + + Parameters + ---------- + batch : torch.Tensor + The batch of data to process. + batch_idx : int + The index of the batch. + + Returns + ------- + torch.Tensor or None + The loss for the batch or None if the loss function is not provided. + """ + return self._shared_eval_step(batch, batch_idx, "val") + + def on_validation_epoch_end(self) -> None: + """Compute and log epoch-level metrics at the end of the validation epoch.""" + self._on_eval_epoch_end("val") + + def on_test_epoch_start(self) -> None: + """Prepare for the test epoch.""" + self._on_eval_epoch_start("test") + + def test_step(self, batch: torch.Tensor, batch_idx: int) -> Optional[torch.Tensor]: + """Run a single test step. + + Parameters + ---------- + batch : torch.Tensor + The batch of data to process. + batch_idx : int + The index of the batch. + + Returns + ------- + torch.Tensor or None + The loss for the batch or None if the loss function is not provided. + """ + return self._shared_eval_step(batch, batch_idx, "test") + + def on_test_epoch_end(self) -> None: + """Compute and log epoch-level metrics at the end of the test epoch.""" + self._on_eval_epoch_end("test") + + def _shared_eval_step( + self, + batch: torch.Tensor, + batch_idx: int, + eval_type: Literal["val", "test"], + ) -> Optional[torch.Tensor]: + """Run a single evaluation step. + + Parameters + ---------- + batch : torch.Tensor + The batch of data to process. + batch_idx : int + The index of the batch. + eval_type : Literal["val", "test"] + The type of evaluation to run. + + Returns + ------- + torch.Tensor or None + The loss for the batch or None if the loss function is not provided. + """ + loss = None + + if (eval_type == "val" and self.compute_validation_loss) or ( + eval_type == "test" and self.compute_test_loss + ): + loss = self._compute_loss(batch) + self.log(f"{eval_type}/loss", loss, prog_bar=True, sync_dist=True) + + if self.evaluation_tasks: + for task_spec in self.evaluation_tasks.values(): + if (eval_type == "val" and task_spec.run_on_validation) or ( + eval_type == "test" and task_spec.run_on_test + ): + batch_result = task_spec.task.evaluation_step( + self.trainer, self, batch, batch_idx + ) + if batch_result: + for key, value in batch_result.items(): + self.log( + f"{eval_type}/{key}_step", + value, + on_step=True, + on_epoch=False, + sync_dist=True, + ) + + return loss + + def _on_eval_epoch_start(self, eval_type: Literal["val", "test"]) -> None: + """Prepare for the evaluation epoch. + + Parameters + ---------- + eval_type : Literal["val", "test"] + The type of evaluation to run. + """ + if self.evaluation_tasks: + for task_spec in self.evaluation_tasks.values(): + if (eval_type == "val" and task_spec.run_on_validation) or ( + eval_type == "test" and task_spec.run_on_test + ): + task_spec.task.on_evaluation_epoch_start(self) + + def _on_eval_epoch_end(self, eval_type: Literal["val", "test"]) -> None: + """Compute and log epoch-level metrics at the end of the evaluation epoch. + + Parameters + ---------- + eval_type : Literal["val", "test"] + The type of evaluation to run. + """ + if self.evaluation_tasks: + for task_spec in self.evaluation_tasks.values(): + if (eval_type == "val" and task_spec.run_on_validation) or ( + eval_type == "test" and task_spec.run_on_test + ): + results = task_spec.task.on_evaluation_epoch_end(self) + if results: + for key, value in results.items(): + self.log(f"{eval_type}/{key}", value) + + def configure_optimizers(self) -> OptimizerLRScheduler: + """Configure the optimizer and learning rate scheduler. + + Returns + ------- + OptimizerLRScheduler + The optimizer and learning rate scheduler. + """ + if self.optimizer is None: + rank_zero_warn( + "Optimizer not provided. Training will continue without an optimizer. " + "LR scheduler will not be used." + ) + return None + + weight_decay: Optional[float] = self.optimizer.keywords.get( + "weight_decay", None + ) + if weight_decay is None: + kw_param = inspect.signature(self.optimizer.func).parameters.get( + "weight_decay" + ) + if kw_param is not None and kw_param.default != inspect.Parameter.empty: + weight_decay = kw_param.default + + parameters = [param for param in self.parameters() if param.requires_grad] + + if weight_decay is not None: + decay_params = [] + no_decay_params = [] + + for param in self.parameters(): + if not param.requires_grad: + continue + + if param.ndim < 2: + no_decay_params.append(param) + else: + decay_params.append(param) + + parameters = [ + { + "params": decay_params, + "weight_decay": weight_decay, + "name": "weight_decay_params", + }, + { + "params": no_decay_params, + "weight_decay": 0.0, + "name": "no_weight_decay_params", + }, + ] + + optimizer = self.optimizer(parameters) + + if self.lr_scheduler is None: + return optimizer + + if isinstance(self.lr_scheduler, dict): + scheduler = self.lr_scheduler["scheduler"](optimizer) + lr_scheduler_dict = {"scheduler": scheduler} + lr_scheduler_dict.update(self.lr_scheduler.get("extras", {})) + return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_dict} + + scheduler = self.lr_scheduler(optimizer) + return [optimizer], [scheduler] From 6645be6565bb0080fb17325977baff3f8ed23c40 Mon Sep 17 00:00:00 2001 From: Amrit Krishnan Date: Thu, 3 Oct 2024 09:26:33 -0400 Subject: [PATCH 2/5] Fix the code to be compatible with EMA, vision encoder modules --- mmlearn/tasks/data2vec.py | 99 ++++++++++++++++++++------------------- 1 file changed, 51 insertions(+), 48 deletions(-) diff --git a/mmlearn/tasks/data2vec.py b/mmlearn/tasks/data2vec.py index d3c758b..ebe1e48 100644 --- a/mmlearn/tasks/data2vec.py +++ b/mmlearn/tasks/data2vec.py @@ -11,21 +11,14 @@ from lightning.pytorch.utilities.types import OptimizerLRScheduler from lightning_utilities.core.rank_zero import rank_zero_warn from torch import nn +from transformers.modeling_outputs import BaseModelOutput +from mmlearn.datasets.core.modalities import Modalities, Modality from mmlearn.datasets.processors.masking import apply_masks from mmlearn.modules.ema import ExponentialMovingAverage from mmlearn.modules.losses.data2vec import Data2VecLoss -@dataclass -class ModuleKeySpec: - """Module key specification for mapping modules to modalities.""" - - encoder_key: Optional[str] = None - head_key: Optional[str] = None - postprocessor_key: Optional[str] = None - - @dataclass class EvaluationSpec: """Specification for an evaluation task.""" @@ -98,7 +91,7 @@ def __init__( self.loss_fn = loss or Data2VecLoss() self.ema = ExponentialMovingAverage( - self, + self.encoder, ema_decay, ema_end_decay, ema_anneal_end_step, @@ -109,42 +102,27 @@ def __init__( self.compute_test_loss = compute_test_loss self.evaluation_tasks = evaluation_tasks - def encode(self, inputs: torch.Tensor) -> torch.Tensor: - """Encode the input values. - - Parameters - ---------- - inputs : torch.Tensor - The input values to encode. - - Returns - ------- - torch.Tensor - The encoded values. - """ - return self.encoder(inputs) - - def forward(self, inputs: torch.Tensor) -> torch.Tensor: + def forward(self, inputs: Dict[Union[str, Modality], Any]) -> BaseModelOutput: """Run the forward pass. Parameters ---------- - inputs : torch.Tensor - The input values to forward pass. + inputs : Dict[Union[str, Modality], Any] + The input tensors to encode. Returns ------- - torch.Tensor - The forward pass output. + BaseModelOutput + The output of the encoder. """ - return self.encode(inputs) + return self.encoder(inputs) - def _compute_loss(self, batch: torch.Tensor) -> torch.Tensor: + def _compute_loss(self, batch: Dict[Union[str, Modality], Any]) -> torch.Tensor: """Compute the loss for the batch. Parameters ---------- - batch : torch.Tensor + batch : Dict[Union[str, Modality], Any] The batch of data to compute the loss for. Returns @@ -152,19 +130,44 @@ def _compute_loss(self, batch: torch.Tensor) -> torch.Tensor: torch.Tensor The loss for the batch. """ - masked_input = apply_masks(batch, self.mask_generator()) - student_output = self.encode(masked_input) + # Generate mask + mask = self.mask_generator() + + # Apply mask to input (only supported for RGB modality) + masked_input = { + k: apply_masks(v, mask) if k == Modalities.RGB else v + for k, v in batch.items() + } + + # Get student output + student_output = self(masked_input) + student_hidden_states = student_output.hidden_states + + # Get teacher output (without gradients) with torch.no_grad(): - teacher_output = self.ema.model.encode(batch) - return self.loss_fn(student_output, teacher_output) + teacher_output = self.ema.model(batch) + teacher_hidden_states = teacher_output.hidden_states + + # Compute loss + loss = 0 + for student_layer, teacher_layer in zip( + student_hidden_states, teacher_hidden_states + ): + loss += self.loss_fn(student_layer, teacher_layer) - def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor: + return loss / len(student_hidden_states) + + def training_step( + self, batch: Dict[Union[str, Modality], Any], batch_idx: int + ) -> torch.Tensor: """Compute the loss for the batch. Parameters ---------- - batch : torch.Tensor + batch : Dict[Union[str, Modality], Any] The batch of data to compute the loss for. + batch_idx : int + The index of the batch. Returns ------- @@ -172,10 +175,8 @@ def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor: The loss for the batch. """ loss = self._compute_loss(batch) - self.log("train/loss", loss, prog_bar=True, sync_dist=True) - self.ema.step(self) - + self.ema.step(self.encoder) return loss def on_validation_epoch_start(self) -> None: @@ -183,13 +184,13 @@ def on_validation_epoch_start(self) -> None: self._on_eval_epoch_start("val") def validation_step( - self, batch: torch.Tensor, batch_idx: int + self, batch: Dict[Union[str, Modality], Any], batch_idx: int ) -> Optional[torch.Tensor]: """Run a single validation step. Parameters ---------- - batch : torch.Tensor + batch : Dict[Union[str, Modality], Any] The batch of data to process. batch_idx : int The index of the batch. @@ -209,12 +210,14 @@ def on_test_epoch_start(self) -> None: """Prepare for the test epoch.""" self._on_eval_epoch_start("test") - def test_step(self, batch: torch.Tensor, batch_idx: int) -> Optional[torch.Tensor]: + def test_step( + self, batch: Dict[Union[str, Modality], Any], batch_idx: int + ) -> Optional[torch.Tensor]: """Run a single test step. Parameters ---------- - batch : torch.Tensor + batch : Dict[Union[str, Modality], Any] The batch of data to process. batch_idx : int The index of the batch. @@ -232,7 +235,7 @@ def on_test_epoch_end(self) -> None: def _shared_eval_step( self, - batch: torch.Tensor, + batch: Dict[Union[str, Modality], Any], batch_idx: int, eval_type: Literal["val", "test"], ) -> Optional[torch.Tensor]: @@ -240,7 +243,7 @@ def _shared_eval_step( Parameters ---------- - batch : torch.Tensor + batch : Dict[Union[str, Modality], Any] The batch of data to process. batch_idx : int The index of the batch. From 4e61d0122fe754afc3fc90a99ad4d134b7bb1ff4 Mon Sep 17 00:00:00 2001 From: Amrit Krishnan Date: Thu, 31 Oct 2024 10:09:44 -0400 Subject: [PATCH 3/5] Fixes to data2vec task class --- mmlearn/tasks/data2vec.py | 135 ++++++++++++++++++++++---------------- 1 file changed, 79 insertions(+), 56 deletions(-) diff --git a/mmlearn/tasks/data2vec.py b/mmlearn/tasks/data2vec.py index ebe1e48..8a48acc 100644 --- a/mmlearn/tasks/data2vec.py +++ b/mmlearn/tasks/data2vec.py @@ -3,18 +3,21 @@ import inspect from dataclasses import dataclass from functools import partial -from typing import Any, Dict, Literal, Optional, Union +from typing import Any, Dict, List, Literal, Optional, Union import lightning as L # noqa: N812 import torch +import torch.nn.functional as F from hydra_zen import store from lightning.pytorch.utilities.types import OptimizerLRScheduler from lightning_utilities.core.rank_zero import rank_zero_warn from torch import nn -from transformers.modeling_outputs import BaseModelOutput from mmlearn.datasets.core.modalities import Modalities, Modality -from mmlearn.datasets.processors.masking import apply_masks +from mmlearn.datasets.processors.masking import ( + BlockwiseImagePatchMaskGenerator, + apply_masks, +) from mmlearn.modules.ema import ExponentialMovingAverage from mmlearn.modules.losses.data2vec import Data2VecLoss @@ -30,39 +33,7 @@ class EvaluationSpec: @store(group="task", provider="mmlearn") class Data2VecTask(L.LightningModule): - """Data2Vec task. - - This class implements the Data2Vec self-supervised learning approach for a single - modality. It can be used as an auxiliary task in multi-modal learning setups. - - Parameters - ---------- - encoder : nn.Module - The encoder for the modality. - optimizer : partial[torch.optim.Optimizer], optional - The optimizer to use for training. - lr_scheduler : Union[ - Dict[str, Union[partial[torch.optim.lr_scheduler.LRScheduler], Any]], - partial[torch.optim.lr_scheduler.LRScheduler] - ], optional - The learning rate scheduler to use for training. - loss : Data2VecLoss, optional - The loss function to use. - ema_decay : float - The initial decay value for EMA. - ema_end_decay : float - The final decay value for EMA. - ema_anneal_end_step : int - The number of steps to anneal the decay from `ema_decay` to `ema_end_decay`. - mask_generator : Any - The mask generator to use for creating masked inputs. - compute_validation_loss : bool - Whether to compute the validation loss. - compute_test_loss : bool - Whether to compute the test loss. - evaluation_tasks : Dict[str, EvaluationSpec], optional - Evaluation tasks to run during validation and testing. - """ + """Data2Vec task implementation.""" def __init__( self, @@ -78,10 +49,16 @@ def __init__( ema_decay: float = 0.999, ema_end_decay: float = 0.9999, ema_anneal_end_step: int = 300000, - mask_generator: Any = None, + mask_generator: Optional[BlockwiseImagePatchMaskGenerator] = None, compute_validation_loss: bool = True, compute_test_loss: bool = True, evaluation_tasks: Optional[Dict[str, EvaluationSpec]] = None, + average_top_k_layers: int = 6, + target_instance_norm: bool = False, + target_batch_norm: bool = False, + target_layer_norm_last: bool = False, + post_target_instance_norm: bool = False, + post_target_layer_norm: bool = False, ) -> None: super().__init__() @@ -102,20 +79,68 @@ def __init__( self.compute_test_loss = compute_test_loss self.evaluation_tasks = evaluation_tasks - def forward(self, inputs: Dict[Union[str, Modality], Any]) -> BaseModelOutput: - """Run the forward pass. + # Data2Vec specific parameters + self.average_top_k_layers = average_top_k_layers + self.target_instance_norm = target_instance_norm + self.target_batch_norm = target_batch_norm + self.target_layer_norm_last = target_layer_norm_last + self.post_target_instance_norm = post_target_instance_norm + self.post_target_layer_norm = post_target_layer_norm + + def _get_teacher_targets(self, hidden_states: List[torch.Tensor]) -> torch.Tensor: + """Get teacher targets by averaging top k layers with normalization. Parameters ---------- - inputs : Dict[Union[str, Modality], Any] - The input tensors to encode. + hidden_states : List[torch.Tensor] + The hidden states to average. Returns ------- - BaseModelOutput - The output of the encoder. + torch.Tensor + The averaged hidden states. """ - return self.encoder(inputs) + top_k_hidden_states = hidden_states[-self.average_top_k_layers :] + + if self.target_instance_norm or self.target_batch_norm: + top_k_hidden_states = [ + val.permute(0, 2, 1) for val in top_k_hidden_states + ] # btc => bct + + if self.target_batch_norm: + top_k_hidden_states = [ + F.batch_norm( + val.float(), running_mean=None, running_var=None, training=True + ) + for val in top_k_hidden_states + ] + + if self.target_instance_norm: + top_k_hidden_states = [ + F.instance_norm(val.float()) for val in top_k_hidden_states + ] + + if self.target_instance_norm or self.target_batch_norm: + top_k_hidden_states = [ + val.permute(0, 2, 1) for val in top_k_hidden_states + ] # bct => btc + + if self.target_layer_norm_last: + top_k_hidden_states = [ + F.layer_norm(val.float(), val.shape[-1:]) for val in top_k_hidden_states + ] + + targets = sum(top_k_hidden_states) / len(top_k_hidden_states) + + if self.post_target_instance_norm: + targets = targets.permute(0, 2, 1) + targets = F.instance_norm(targets.float()) + targets = targets.permute(0, 2, 1) + + if self.post_target_layer_norm: + targets = F.layer_norm(targets.float(), targets.shape[-1:]) + + return targets def _compute_loss(self, batch: Dict[Union[str, Modality], Any]) -> torch.Tensor: """Compute the loss for the batch. @@ -139,23 +164,21 @@ def _compute_loss(self, batch: Dict[Union[str, Modality], Any]) -> torch.Tensor: for k, v in batch.items() } - # Get student output + # Get student output with masked input student_output = self(masked_input) - student_hidden_states = student_output.hidden_states + student_hidden = student_output.last_hidden_state - # Get teacher output (without gradients) + # Get teacher output with original input with torch.no_grad(): - teacher_output = self.ema.model(batch) - teacher_hidden_states = teacher_output.hidden_states + teacher_output = self.ema.model(batch, output_hidden_states=True) + teacher_targets = self._get_teacher_targets(teacher_output.hidden_states) - # Compute loss - loss = 0 - for student_layer, teacher_layer in zip( - student_hidden_states, teacher_hidden_states - ): - loss += self.loss_fn(student_layer, teacher_layer) + # Get masked indices + if isinstance(mask, torch.Tensor): + mask = mask.bool() - return loss / len(student_hidden_states) + # Compute loss only on masked positions + return self.loss_fn(student_hidden[mask], teacher_targets[mask]) def training_step( self, batch: Dict[Union[str, Modality], Any], batch_idx: int From 3fdd5dd08a4e6bd863634cdef0fab7c555fe8ae4 Mon Sep 17 00:00:00 2001 From: Amrit Krishnan Date: Tue, 19 Nov 2024 08:57:13 -0500 Subject: [PATCH 4/5] Move regression head and get_hidden_states in data2vec task --- mmlearn/tasks/data2vec.py | 127 +++++++++++++++++++++++++------------- 1 file changed, 84 insertions(+), 43 deletions(-) diff --git a/mmlearn/tasks/data2vec.py b/mmlearn/tasks/data2vec.py index 8a48acc..ce3b3b8 100644 --- a/mmlearn/tasks/data2vec.py +++ b/mmlearn/tasks/data2vec.py @@ -7,13 +7,13 @@ import lightning as L # noqa: N812 import torch -import torch.nn.functional as F +import torch.nn.functional as F # noqa: N812 from hydra_zen import store from lightning.pytorch.utilities.types import OptimizerLRScheduler from lightning_utilities.core.rank_zero import rank_zero_warn from torch import nn -from mmlearn.datasets.core.modalities import Modalities, Modality +from mmlearn.datasets.core.modalities import Modality from mmlearn.datasets.processors.masking import ( BlockwiseImagePatchMaskGenerator, apply_masks, @@ -31,6 +31,42 @@ class EvaluationSpec: run_on_test: bool = True +class RegressionHead(nn.Module): + """Regression head for Data2Vec.""" + + def __init__(self, embed_dim: int, num_layers: int = 1) -> None: + """Initialize the regression head.""" + super().__init__() + if num_layers < 1: + raise ValueError("num_layers must be >= 1") + + layers = [] + curr_dim = embed_dim + + for i in range(num_layers - 1): + next_dim = embed_dim * 2 if i == 0 else curr_dim + layers.extend([nn.Linear(curr_dim, next_dim), nn.GELU()]) + curr_dim = next_dim + + layers.append(nn.Linear(curr_dim, embed_dim)) + self.layers = nn.Sequential(*layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Parameters + ---------- + x : torch.Tensor + The input tensor. + + Returns + ------- + torch.Tensor + The output tensor. + """ + return self.layers(x) + + @store(group="task", provider="mmlearn") class Data2VecTask(L.LightningModule): """Data2Vec task implementation.""" @@ -38,6 +74,7 @@ class Data2VecTask(L.LightningModule): def __init__( self, encoder: nn.Module, + head_layers: int = 1, optimizer: Optional[partial[torch.optim.Optimizer]] = None, lr_scheduler: Optional[ Union[ @@ -63,6 +100,12 @@ def __init__( super().__init__() self.encoder = encoder + # Build regression head + self.regression_head = RegressionHead( + self.encoder.model.config.hidden_size, + num_layers=head_layers, + ) + self.optimizer = optimizer self.lr_scheduler = lr_scheduler self.loss_fn = loss or Data2VecLoss() @@ -87,51 +130,52 @@ def __init__( self.post_target_instance_norm = post_target_instance_norm self.post_target_layer_norm = post_target_layer_norm - def _get_teacher_targets(self, hidden_states: List[torch.Tensor]) -> torch.Tensor: - """Get teacher targets by averaging top k layers with normalization. - - Parameters - ---------- - hidden_states : List[torch.Tensor] - The hidden states to average. - - Returns - ------- - torch.Tensor - The averaged hidden states. - """ - top_k_hidden_states = hidden_states[-self.average_top_k_layers :] + def _process_hidden_states( + self, hidden_states: List[torch.Tensor], remove_cls_token: bool = False + ) -> List[torch.Tensor]: + """Process hidden states with normalization.""" + if remove_cls_token: + hidden_states = [h[:, 1:] for h in hidden_states] if self.target_instance_norm or self.target_batch_norm: - top_k_hidden_states = [ - val.permute(0, 2, 1) for val in top_k_hidden_states - ] # btc => bct - - if self.target_batch_norm: - top_k_hidden_states = [ - F.batch_norm( - val.float(), running_mean=None, running_var=None, training=True - ) - for val in top_k_hidden_states - ] + hidden_states = [h.permute(0, 2, 1) for h in hidden_states] - if self.target_instance_norm: - top_k_hidden_states = [ - F.instance_norm(val.float()) for val in top_k_hidden_states - ] + if self.target_batch_norm: + hidden_states = [ + F.batch_norm(h.float(), None, None, training=True) + for h in hidden_states + ] - if self.target_instance_norm or self.target_batch_norm: - top_k_hidden_states = [ - val.permute(0, 2, 1) for val in top_k_hidden_states - ] # bct => btc + if self.target_instance_norm: + hidden_states = [F.instance_norm(h.float()) for h in hidden_states] + + hidden_states = [h.permute(0, 2, 1) for h in hidden_states] if self.target_layer_norm_last: - top_k_hidden_states = [ - F.layer_norm(val.float(), val.shape[-1:]) for val in top_k_hidden_states + hidden_states = [ + F.layer_norm(h.float(), h.shape[-1:]) for h in hidden_states ] + return hidden_states + + def _get_teacher_targets( + self, + hidden_states: List[torch.Tensor], + ) -> torch.Tensor: + """Get teacher targets following reference implementation.""" + # Remove final layer as per reference + hidden_states = hidden_states[:-1] + + # Get top k layers + top_k_hidden_states = self._process_hidden_states( + hidden_states[-self.average_top_k_layers :], + remove_cls_token=True, + ) + + # Average the layers targets = sum(top_k_hidden_states) / len(top_k_hidden_states) + # Apply post-processing if self.post_target_instance_norm: targets = targets.permute(0, 2, 1) targets = F.instance_norm(targets.float()) @@ -156,13 +200,10 @@ def _compute_loss(self, batch: Dict[Union[str, Modality], Any]) -> torch.Tensor: The loss for the batch. """ # Generate mask - mask = self.mask_generator() + mask = self.mask_generator() # type: ignore - # Apply mask to input (only supported for RGB modality) - masked_input = { - k: apply_masks(v, mask) if k == Modalities.RGB else v - for k, v in batch.items() - } + # Apply mask to input + masked_input = {k: apply_masks(v, mask) for k, v in batch.items()} # Get student output with masked input student_output = self(masked_input) From 27411c81d05e2a26357884f86896658eb1ede802 Mon Sep 17 00:00:00 2001 From: Amrit Krishnan Date: Tue, 19 Nov 2024 09:07:07 -0500 Subject: [PATCH 5/5] Re-use the MLP class --- mmlearn/tasks/data2vec.py | 42 +++++++++++++++++++++++++-------------- 1 file changed, 27 insertions(+), 15 deletions(-) diff --git a/mmlearn/tasks/data2vec.py b/mmlearn/tasks/data2vec.py index ce3b3b8..77d499e 100644 --- a/mmlearn/tasks/data2vec.py +++ b/mmlearn/tasks/data2vec.py @@ -19,6 +19,7 @@ apply_masks, ) from mmlearn.modules.ema import ExponentialMovingAverage +from mmlearn.modules.layers.mlp import MLP from mmlearn.modules.losses.data2vec import Data2VecLoss @@ -32,37 +33,48 @@ class EvaluationSpec: class RegressionHead(nn.Module): - """Regression head for Data2Vec.""" + """Regression head for Data2Vec text encoder.""" def __init__(self, embed_dim: int, num_layers: int = 1) -> None: - """Initialize the regression head.""" + """Initialize the regression head. + + Parameters + ---------- + embed_dim : int + Dimension of the input embeddings + num_layers : int, optional + Number of layers in the regression head, by default 1 + """ super().__init__() if num_layers < 1: raise ValueError("num_layers must be >= 1") - layers = [] - curr_dim = embed_dim - - for i in range(num_layers - 1): - next_dim = embed_dim * 2 if i == 0 else curr_dim - layers.extend([nn.Linear(curr_dim, next_dim), nn.GELU()]) - curr_dim = next_dim - - layers.append(nn.Linear(curr_dim, embed_dim)) - self.layers = nn.Sequential(*layers) + if num_layers == 1: + hidden_dims = [] + else: + hidden_dims = [embed_dim * 2] + [embed_dim * 2] * (num_layers - 2) + + self.layers = MLP( + in_dim=embed_dim, + out_dim=embed_dim, + hidden_dims=hidden_dims, + activation_layer=nn.GELU, + norm_layer=None, + dropout=0.0, + ) def forward(self, x: torch.Tensor) -> torch.Tensor: - """Forward pass. + """Run the forward pass. Parameters ---------- x : torch.Tensor - The input tensor. + Input tensor. Returns ------- torch.Tensor - The output tensor. + Output tensor. """ return self.layers(x)