diff --git a/internnav/dataset/navdp_dataset_lerobot.py b/internnav/dataset/navdp_dataset_lerobot.py index 4437e7c..95753d9 100644 --- a/internnav/dataset/navdp_dataset_lerobot.py +++ b/internnav/dataset/navdp_dataset_lerobot.py @@ -41,6 +41,7 @@ def __init__( image_size=224, scene_data_scale=1.0, trajectory_data_scale=1.0, + pixel_channel=7, debug=False, preload=False, random_digit=False, @@ -61,6 +62,7 @@ def __init__( self.trajectory_afford_path = [] self.random_digit = random_digit self.prior_sample = prior_sample + self.pixel_channel = pixel_channel self.item_cnt = 0 self.batch_size = batch_size self.batch_time_sum = 0.0 @@ -509,7 +511,12 @@ def __getitem__(self, index): camera_intrinsic, trajectory_base_extrinsic, ) - pixel_goal = np.concatenate((pixel_goal, memory_images[-1]), axis=-1) + # pixel_channel == 7: NavDP runs pixel-goal navigation at an asynchronous pace, so the input is + # pixel_mask (1) + the history image on which the pixel goal was assigned (3) + current image (3). + # pixel_channel == 4: the pixel goal is assigned on the current frame, so + # only pixel_mask (1) and the current image (3) are needed. + if self.pixel_channel == 7: + pixel_goal = np.concatenate((pixel_goal, memory_images[-1]), axis=-1) pred_actions = (pred_actions[1:] - pred_actions[:-1]) * 4.0 augment_actions = (augment_actions[1:] - augment_actions[:-1]) * 4.0 diff --git a/internnav/model/basemodel/navdp/navdp_policy.py b/internnav/model/basemodel/navdp/navdp_policy.py index 6a8da42..e235352 100644 --- a/internnav/model/basemodel/navdp/navdp_policy.py +++ b/internnav/model/basemodel/navdp/navdp_policy.py @@ -7,7 +7,13 @@ from internnav.configs.model.base_encoders import ModelCfg from internnav.configs.trainer.exp import ExpCfg -from internnav.model.encoder.navdp_backbone import * +from internnav.model.encoder.navdp_backbone import ( + ImageGoalBackbone, + LearnablePositionalEncoding, + PixelGoalBackbone, + RGBDBackbone, + SinusoidalPosEmb, +) class NavDPModelConfig(PretrainedConfig): @@ -51,9 
+57,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): elif pretrained_model_name_or_path is None or len(pretrained_model_name_or_path) == 0: pass else: - incompatible_keys, _ = model.load_state_dict( - torch.load(pretrained_model_name_or_path)['state_dict'], strict=False - ) + incompatible_keys, _ = model.load_state_dict(torch.load(pretrained_model_name_or_path), strict=False) if len(incompatible_keys) > 0: print(f'Incompatible keys: {incompatible_keys}') @@ -66,13 +70,12 @@ def __init__(self, config: NavDPModelConfig): self.model_config = ModelCfg(**config.model_cfg['model']) else: self.model_config = config - self.config.model_cfg['il'] - self._device = torch.device(f"cuda:{config.model_cfg['local_rank']}") self.image_size = self.config.model_cfg['il']['image_size'] self.memory_size = self.config.model_cfg['il']['memory_size'] self.predict_size = self.config.model_cfg['il']['predict_size'] + self.pixel_channel = self.config.model_cfg['il']['pixel_channel'] self.temporal_depth = self.config.model_cfg['il']['temporal_depth'] self.attention_heads = self.config.model_cfg['il']['heads'] self.input_channels = self.config.model_cfg['il']['channels'] @@ -80,17 +83,19 @@ def __init__(self, config: NavDPModelConfig): self.token_dim = self.config.model_cfg['il']['token_dim'] self.scratch = self.config.model_cfg['il']['scratch'] self.finetune = self.config.model_cfg['il']['finetune'] - self.rgbd_encoder = NavDP_RGBD_Backbone( + self.rgbd_encoder = RGBDBackbone( self.image_size, self.token_dim, memory_size=self.memory_size, finetune=self.finetune, device=self._device ) - self.pixel_encoder = NavDP_PixelGoal_Backbone(self.image_size, self.token_dim, device=self._device) - self.image_encoder = NavDP_ImageGoal_Backbone(self.image_size, self.token_dim, device=self._device) + self.pixel_encoder = PixelGoalBackbone( + self.image_size, self.token_dim, pixel_channel=self.pixel_channel, device=self._device + ) + self.image_encoder = 
ImageGoalBackbone(self.image_size, self.token_dim, device=self._device) self.point_encoder = nn.Linear(3, self.token_dim) if not self.finetune: - for p in self.rgbd_encoder.parameters(): + for p in self.rgbd_encoder.rgb_model.parameters(): p.requires_grad = False - self.rgbd_encoder.eval() + self.rgbd_encoder.rgb_model.eval() decoder_layer = nn.TransformerDecoderLayer( d_model=self.token_dim, @@ -180,23 +185,6 @@ def predict_critic(self, predict_trajectory, rgbd_embed): return critic_output def forward(self, goal_point, goal_image, goal_pixel, input_images, input_depths, output_actions, augment_actions): - # """get device safely""" - # # get device safely - # try: - # # try to get device through model parameters - # device = next(self.parameters()).device - # except StopIteration: - # # model has no parameters, use the default device - # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - # # move all inputs to model device - # goal_point = goal_point.to(device) - # goal_image = goal_image.to(device) - # input_images = input_images.to(device) - # input_depths = input_depths.to(device) - # output_actions = output_actions.to(device) - # augment_actions = augment_actions.to(device) - # device = self._device - # print(f"self.parameters() is:{self.parameters()}") device = next(self.parameters()).device assert input_images.shape[1] == self.memory_size @@ -325,7 +313,6 @@ def predict_pointgoal_batch_action_vel(self, goal_point, input_images, input_dep naction = self.noise_scheduler.step(model_output=noise_pred, timestep=k, sample=naction).prev_sample critic_values = self.predict_critic(naction, rgbd_embed) - all_trajectory = torch.cumsum(naction / 4.0, dim=1) negative_trajectory = torch.cumsum(naction / 4.0, dim=1)[(critic_values).argsort()[0:8]] positive_trajectory = torch.cumsum(naction / 4.0, dim=1)[(-critic_values).argsort()[0:8]] @@ -344,7 +331,6 @@ def predict_nogoal_batch_action_vel(self, input_images, input_depths, sample_num naction = 
self.noise_scheduler.step(model_output=noise_pred, timestep=k, sample=naction).prev_sample critic_values = self.predict_critic(naction, rgbd_embed) - all_trajectory = torch.cumsum(naction / 4.0, dim=1) negative_trajectory = torch.cumsum(naction / 4.0, dim=1)[(critic_values).argsort()[0:8]] positive_trajectory = torch.cumsum(naction / 4.0, dim=1)[(-critic_values).argsort()[0:8]] diff --git a/internnav/model/encoder/navdp_backbone.py b/internnav/model/encoder/navdp_backbone.py index cd2e879..34cb2c9 100644 --- a/internnav/model/encoder/navdp_backbone.py +++ b/internnav/model/encoder/navdp_backbone.py @@ -202,7 +202,7 @@ def forward(self, images, depths): return memory_token -class NavDP_RGBD_Backbone(nn.Module): +class RGBDBackbone(nn.Module): def __init__( self, image_size=224, @@ -313,7 +313,7 @@ def _get_device(self): return torch.device("cuda" if torch.cuda.is_available() else "cpu") -class NavDP_ImageGoal_Backbone(nn.Module): +class ImageGoalBackbone(nn.Module): def __init__(self, image_size=224, embed_size=512, device='cuda:0'): super().__init__() if device is None: @@ -376,8 +376,8 @@ def _get_device(self): return torch.device("cuda" if torch.cuda.is_available() else "cpu") -class NavDP_PixelGoal_Backbone(nn.Module): - def __init__(self, image_size=224, embed_size=512, device='cuda:0'): +class PixelGoalBackbone(nn.Module): + def __init__(self, image_size=224, embed_size=512, pixel_channel=7, device='cuda:0'): super().__init__() if device is None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -392,7 +392,7 @@ def __init__(self, image_size=224, embed_size=512, device='cuda:0'): self.pixelgoal_encoder = DepthAnythingV2(**model_configs['vits']) self.pixelgoal_encoder = self.pixelgoal_encoder.pretrained.float() self.pixelgoal_encoder.patch_embed.proj = nn.Conv2d( - in_channels=7, + in_channels=pixel_channel, out_channels=self.pixelgoal_encoder.patch_embed.proj.out_channels, kernel_size=self.pixelgoal_encoder.patch_embed.proj.kernel_size, 
stride=self.pixelgoal_encoder.patch_embed.proj.stride, diff --git a/scripts/train/configs/navdp.py b/scripts/train/configs/navdp.py index 4085f8a..4329da2 100644 --- a/scripts/train/configs/navdp.py +++ b/scripts/train/configs/navdp.py @@ -29,7 +29,7 @@ ), il=IlCfg( epochs=1000, - batch_size=16, + batch_size=32, lr=1e-4, num_workers=8, weight_decay=1e-4, # TODO @@ -57,6 +57,7 @@ prior_sample=False, memory_size=8, predict_size=24, + pixel_channel=4, temporal_depth=16, heads=8, token_dim=384, diff --git a/scripts/train/train.py b/scripts/train/train.py index becab27..d2cc965 100755 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -77,12 +77,12 @@ def main(config, model_class, model_config_class): """Main training function.""" _make_dir(config) - print(f"=== Start training ===") + print("=== Start training ===") print(f"Current time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") print(f"PyTorch version: {torch.__version__}") print(f"CUDA available: {torch.cuda.is_available()}") print(f"CUDA device count: {torch.cuda.device_count()}") - print(f"Environment variables:") + print("Environment variables:") print(f" RANK: {os.getenv('RANK', 'Not set')}") print(f" LOCAL_RANK: {os.getenv('LOCAL_RANK', 'Not set')}") print(f" WORLD_SIZE: {os.getenv('WORLD_SIZE', 'Not set')}") @@ -172,6 +172,7 @@ def main(config, model_class, model_config_class): config.il.batch_size, config.il.image_size, config.il.scene_scale, + pixel_channel=config.il.pixel_channel, preload=config.il.preload, random_digit=config.il.random_digit, prior_sample=config.il.prior_sample,