tsn阅读
tsn 论文阅读笔记
1 论文链接
链接:https://pan.baidu.com/s/1jh0r5M9XMqB14aRAi6ZNVw
提取码:l8vx
2 论文讲了什么
这是一个视频级的预测,将一段video分成k个segment(论文中k=3),然后从每个segment里面分别随机取一帧即Snippets,将RGB和flow格式的图分别送入spatial ConvNet和Temporal ConvNet学习空间特征和时序特征,然后将不同segment的特征融合后输出分类结果。
3 tsn 函数调用关系
4 tsn函数关系调用解析
(强烈建议参照源码观看,否则跳过本节)
接下来主要介绍TSN源码的函数调用。
首先我们来看一下main.py,这是训练模型的入口。Main函数主要包括了导入模型、数据准备和训练三个部分。所以这里相应的调用了model,dataset以及transforms。
(一)
首先介绍导入模型的部分,那这里就主要涉及到了models.py。它主要就是进行了模型的设计,对之后的训练模型进行了准备。这里默认使用了resnet101作为基础模型,针对不同的输入参数,对最后的一层全连接层进行修改,得到我们所需的TSN网络模型。
然后让我们详细看一下models.py这个文件,首先调用了TSN类的init函数,进行初始化TSN模型,并且设置了一些参数和参数默认值。Init函数调用了prepare_basemodel和prepare_tsn函数来进行TSN网络模型结构的修改。这里详细的讲一下这两个函数
其中prepare_base_model函数主要是选择不同的网络结构模型时,对输入的数据集进行预处理。假设这里对resnet101进行修改,则先通过内置函数getattr获取模型的属性值,然后设置resnet最后一层为全连接层。这里往resnet输入的大小为224,之后对数据集进行处理,比如三个输入维度减去input_mean数组的对应值,然后除以input_std数组的对应值,完成数据标准化操作。
prepare_tsn函数的功能在于对已知的basemodel,也就是这里默认的resnet101网络结构进行修改,微调最后一层(全连接层)的结构,成为适合该数据集输出的形式。这里就是获取网络最后一层的输入feature_map的通道数,存储于feature_dim中,然后判断dropout是否为0,如果不为0,则添加一个dropout层后再添加一个全连接层,否则直接连接全连接层。这里用到了内置函数setattr,它是用来为输入的某个属性赋值,一般可以用来修改网络结构。全连接层的输入为feature_dim([batch,2048]),输出为数据集的类别num_class,也就是分类的标签数。最后对全连接层的网络权重,进行0均值且指定标准差的初始化操作,之后对偏差bias初始化为0。接着调用了forward函数,它的主要功能是将输入的数据通过模型得到输出,再通过聚合函数得到所需的结果。到这里,模型设计阶段就已经结束了,后来只需要在main函数里面导入这个被修改的模型即可。
(二)
接下来介绍数据准备的部分,这里主要用到了dataset.py。它的主要功能就是对数据集进行读取,并且对其稀疏采样,返回稀疏采样后得到的数据集。它主要调用了dataset里面的TSNDataSet类,首先是类里面的init函数,它的功能在于初始化TSNDataSet,并设置一些参数和参数默认值。这里同时也调用了parse_list函数用来读取list文件。我们可以看到parse_list函数同样调用了VideoRecord类,这里面有三个属性,分别是帧的相对路径,帧数以及帧标签。
这里也定义了sample_indices用来实现对训练集的稀疏随机采样,get_val_indices和get_test_indices函数对验证集和测试集进行稀疏固定采样,最后返回的均是稀疏采样的帧数列表。__getitem__函数,它主要通过parse_list函数获取帧信息,再根据两个判断标志,分别调用sample_indices用来实现对训练集的稀疏随机采样。
最后调用get函数,它的功能主要读取提取的帧图片,并且对帧图片进行并行操作,返回变形后的数据集和对应的类型标签。Get函数同时也调用了load_image函数用来读取图像数据。回归到main函数,对dataset返回的数据进行装载,这里调用了get_augmentation函数,它的功能主要是来对装载的240320的图片数据裁剪transform之后,转变为9224*224的tensor数据。接下来调用transform的GroupNormalize类来进行标准化处理。 到这里,数据准备阶段就基本完成了。
(三)
最后一部分是训练部分,这里只需看main.py即可。首先调用了adjust_learning_rate函数,它主要的功能是随着epoch的增大,自适应减小学习率。接着调用train函数,它是模型训练的部分,不过被独立封装了,所以直接调用即可。Train函数里面调用了这些函数,首先是AverageMeter类,它的主要功能是用来管理一些变量的更新,例如loss损失或者准确率。之后通过判断看是否调用partialBN函数来看是否需要部分BN。最后调用accuracy函数来计算准确率。验证部分调用和训练部分基本相同,这里也就不讲了。
5 数据流图
图一:论文数据流图
图二:resnet101模型训练图