Point track项目源代码分析

我们这篇文章来分析分析Point track的源代码。这个项目的工程量可以说非常非常大。

(最终生成的视频截图,这里这个图是倒着的)

image.png

从官网拿到代码之后:Point Track ,按照要求配置好环境:

配置环境的方法:

git clone --recursive https://github.com/Large-Trajectory-Model/ATM.git

cd ATM/
conda env create -f environment.yml
conda activate atm

pip install -e third_party/robosuite/
pip install -e third_party/robomimic/

然后,处理数据。

mkdir data
python -m scripts.download_libero_datasets

当然,它这里的CoTracker是在现使用的(就是直接从hub上拿下来去使用的)。如果你的网络不好,你也可以将其换成离线的。

换成离线的方法:

1、需要自己手动下载CoTracker v2的代码。

2、把代码中的这个地方换掉(这里是在script/ preprocess_libero下):

    # setup cotracker

    # 获得cotracker,这是在线获取cotracker的方式
    cotracker = torch.hub.load(os.path.join(os.path.expanduser("~"), ".cache/torch/hub/facebookresearch_co-tracker_main/"), "cotracker2", source="local")

    # 切换到评估模式,dropout层不会再丢弃神经元,bn也使用预定义好的方差和均值
    cotracker = cotracker.eval().cuda()

替换的方式其实很简单,就是CoTracker里面有一个CoTrackerPredict类,用那个构建出一个CoTracker,然后把这个给换掉就可以了。具体的可以参考我之前的一些文章。

紧接着,数据处理好之后。需要再分割一下数据:

python -m scripts.split_libero_dataset

做完以后应该是这个样子的:

image.png

然后就到了跑第一个TrackTransformer了。这一步也比较简单,直接这样子就可以:

python -m scripts.train_libero_track_transformer --suite $SUITE_NAME

这里的 SUITE_NAME 可以是 libero_spatiallibero_objectlibero_goal, 或者 libero_100

这里是预训练那个大的Track Transformer模型。

接着这里做完之后,就进行第二阶段的训练,可以理解为用该Track Transformer来去指导下游demo的学习。

训练方法:

python -m scripts.train_libero_policy_atm --suite $SUITE_NAME -tt $PATH_TO_TT

这里的--suite后面跟的参数仍然是 libero_spatiallibero_objectlibero_goal, 或者 libero_100, 然后-tt就是track transformer的意思,它后面跟着的就是你第一阶段训练的Track Transformer模型的路径。

例如:

python -m scripts.train_libero_policy_atm --suite libero_spatial -tt results/track_transformer/libero_track_transformer_libero-spatial/

作者这里也同时提供了传统克隆(BC)的做法,也就是没有前面的Track Transformer生成的Track作为指导来去训练学习,直接通过模仿学习来完成的方式:

python -m scripts.train_libero_policy_bc --suite $SUITE_NAME

这里就不再需要后面的-tt参数,也就是不需要Track Transformer的参数了。

那么,这个就是一整个的训练过程。训练完你应该有这些东西:

image.png

评估过程,通过 scripts/eval_libero_policy来去实现。那具体的语法如下:

python -m scripts.eval_libero_policy --suite $SUITE_NAME --exp-dir $PATH_TO_EXP

例如:

python -m scripts.eval_libero_policy --suite libero_spatial --exp-dir results/policy/atm-policy_libero-spatial_demo10

OK。这样以后,就会在结果中生成在Libero仿真中的视频。

以上是关于代码运行的介绍。这是我运行了libero_goal的五个任务来去训练,然后在十个任务上进行测试的结果:

image.png

下面我们对代码本身来去进行简单的分析:

首先,模型最核心的部分(TrackTransformer),在目录atm/model/track_transformer.py中:

image.png

class TrackTransformer(nn.Module):
    """
    flow video model using a BERT transformer

    dim: int, dimension of the model
    depth: int, number of layers
    heads: int, number of heads
    dim_head: int, dimension of each head
    attn_dropout: float, dropout for attention layers
    ff_dropout: float, dropout for feedforward layers
    """

    def __init__(self,
                 transformer_cfg,
                 track_cfg,
                 vid_cfg,
                 language_encoder_cfg,
                 load_path=None):
        super().__init__()
        self.dim = dim = transformer_cfg.dim
        self.transformer = self._init_transformer(**transformer_cfg)
        self.track_proj_encoder, self.track_decoder = self._init_track_modules(**track_cfg, dim=dim)
        self.img_proj_encoder, self.img_decoder = self._init_video_modules(**vid_cfg, dim=dim)
        self.language_encoder = self._init_language_encoder(output_size=dim, **language_encoder_cfg)
        self._init_weights(self.dim, self.num_img_patches)
    
    def forward(self, vid, track, task_emb, p_img):
        """
        track: (b, tl, n, 2), which means current time step t0 -> t0 + tl
        vid: (b, t, c, h, w), which means the past time step t0 - t -> t0
        task_emb, (b, emb_size)
        """
        assert torch.max(vid) <=1.
        B, T, _, _ = track.shape
        patches = self._encode_video(vid, p_img)  # (b, n_image, d)
        enc_track = self._encode_track(track)

        text_encoded = self.language_encoder(task_emb)  # (b, c)
        text_encoded = rearrange(text_encoded, 'b c -> b 1 c')

        x = torch.cat([enc_track, patches, text_encoded], dim=1)
        x = self.transformer(x)

        rec_track, rec_patches = x[:, :self.num_track_patches], x[:, self.num_track_patches:-1]
        rec_patches = self.img_decoder(rec_patches)  # (b, n_image, 3 * t * patch_size ** 2)
        rec_track = self.track_decoder(rec_track)  # (b, (t n), 2 * patch_size)
        num_track_h = self.num_track_ts // self.track_patch_size
        rec_track = rearrange(rec_track, 'b (t n) (p c) -> b (t p) n c', p=self.track_patch_size, t=num_track_h)

        return rec_track, rec_patches

那么,主要看它的self函数和forward函数,可以从这两个函数作为入口来去看。

那么关于维度的变化,都在上面的注释中解释过了。

通过这段代码可以简单地看到它是怎么样去做的。

首先是经过了编码,分别是_encode_video_encode_track,然后,对文本也进行了编码。编码之后将三者cat在一起,丢进Transformer模型中。最后是取得rec_track和rec_patches两部分。即这里不仅计算了track,还计算了下一帧的frame。

这一点可以在下文的forward_loss中看到:

点击上海AI实验室Point Track项目代码实战:从代码配置到核心模型架构全攻略查看全文。

Logo

欢迎加入DeepSeek 技术社区。在这里,你可以找到志同道合的朋友,共同探索AI技术的奥秘。

更多推荐