新闻中心
轻量级Vision-Transformer:EdgeViTs复现
本文聚焦轻量级Vision-Transformer模型EdgeViTs的复现。EdgeViTs为适配移动设备,采用分层金字塔结构,设计Local-Global-Local(LGL)瓶颈,通过局部聚合、全局稀疏注意力和局部传播操作,在减少计算量的同时保留全局与局部上下文信息。文中给出模型各组件及整体架构的Paddle实现代码,并基于Flowers数据集进行训练验证。
☞☞☞AI 智能聊天, 问答助手, AI 智能搜索, 免费无限量使用 DeepSeek R1 模型☜☜☜

轻量级Vision-Transformer:EdgeViTs复现
摘要
在计算机视觉领域,基于Self-attention的模型(如(ViTs))已经成为CNN之外的一种极具竞争力的架构。尽管越来越强的变种具有越来越高的识别精度,但由于Self-attention的二次复杂度,现有的ViT在计算和模型大小方面都有较高的要求。 虽然之前的CNN的一些成功的设计选择(例如,卷积和分层结构)已经被引入到最近的ViT中,但它们仍然不足以满足移动设备有限的计算资源需求。这促使人们最近尝试开发基于最先进的MobileNet-v2的轻型MobileViT,但MobileViT与MobileNet-v2仍然存在性能差距。 在这项工作中,作者进一步推进这一研究方向,引入了EdgeViTs,一个新的轻量级ViTs家族,也是首次使基于Self-attention的视觉模型在准确性和设备效率之间的权衡中达到最佳轻量级CNN的性能。
1 EdgeViTs
1.1 总体架构
为了设计适用于移动/边缘设备的轻量级ViT,作者采用了最近ViT变体中使用的分层金字塔结构(图2(a))。Pyramid Transformer模型通常在不同阶段降低了空间分辨率同时也扩展了通道维度。每个阶段由多个基于Transformer Block处理相同形状的张量,类似ResNet的层次设计结构。
在这项工作中,作者深入到Transformer Block,并引入了一个比较划算的Bottlneck,Local-Global-Local(LGL)(图2(b))。LGL通过一个稀疏注意力模块进一步减少了Self-attention的开销(图2(c)),实现了更好的准确性-延迟平衡。
1.2 Local-Global-Local bottleneck(LGL)
与以前在每个空间位置执行Self-attention的Transformer Block相比,LGL Bottleneck只对输入Token的子集计算Self-attention,但支持完整的空间交互,如在标准的Multi-Head Self-attention(MHSA)中。既会减少Token的作用域,同时也保留建模全局和局部上下文的底层信息流。
为了实现这一点,作者将Self-attention分解为连续的模块,处理不同范围内的空间Token(图2(b))。
这里引入了3种有效的操作:
轻量级J*ript复选框动画插件Checkbix
轻量级J*ript复选框动画插件Checkbix
23
查看详情
- Local aggregation:仅集成来自局部近似Token信号的局部聚合
- Global sparse attention:建模一组代表性Token之间的长期关系,其中每个Token都被视为一个局部窗口的代表;
- Local propagation:将委托学习到的全局上下文信息扩散到具有相同窗口的非代表Token。
- Local aggregation
对于每个Token,利用Depth-wise和Point-wise卷积在大小为k×k的局部窗口中聚合信息(图3(a))。
- Global sparse attention
对均匀分布在空间中的稀疏代表性Token集进行采样,每个r×r窗口有一个代表性Token。这里,r表示子样本率。然后,只对这些被选择的Token应用Self-attention(图3(b))。这与所有现有的ViTs不同,在那里,所有的空间Token都作为Self-attention计算中的query被涉及到。
- Local propagation
通过转置卷积将代表性 Token 中编码的全局上下文信息传播到它们的相邻的 Token 中(图 3(c))。
2 代码复现
In [1]import paddleimport paddle.nn as nnfrom paddle.nn import Conv2D as Conv2dfrom paddle.nn import BatchNorm2D as BatchNorm2dfrom paddle.nn import Linearfrom paddle.nn import AvgPool2D as AvgPool2dfrom paddle.nn import Conv2DTranspose as ConvTranspose2dfrom paddle.nn import LayerNorm, GELUIn [2]
class Residual(nn.Layer):
def __init__(self, module):
super().__init__()
self.module = module
def forward(self, x):
return x + self.module(x)class LocalAgg(nn.Layer):
def __init__(self, dim):
super().__init__()
self.conv1 = Conv2d(dim, dim, 1)
self.conv2 = Conv2d(dim, dim, 3, padding=1, groups=dim)
self.conv3 = Conv2d(dim, dim, 1)
self.norm1 = BatchNorm2d(dim)
self.norm2 = BatchNorm2d(dim)
def forward(self, x):
"""
[B, C, H, W] = x.shape
"""
x = self.conv1(self.norm1(x))
x = self.conv2(x)
x = self.conv3(self.norm2(x))
return x
class GlobalSparseAttn(nn.Layer):
def __init__(self, dim, sample_rate = 4, scale = 1):
super().__init__()
self.head_dim = int(48)//int(1)
self.num_heads = int(1)
self.scale = scale
self.qkv = Linear(dim, dim * 3)
self.sampler = AvgPool2d(1, stride=sample_rate)
self.LocalProp = ConvTranspose2d(dim, dim, kernel_size=sample_rate, stride=sample_rate, groups=dim
)
self.proj = Linear(dim, dim)
def forward(self, x):
"""
[B, C, H, W] = x.shape
"""
x = self.sampler(x)
[B, C, H, W] = x.shape
x = x.flatten(2)
x = x.transpose([0,2,1])
x = self.qkv(x)
x = x.transpose([0, 2, 1])
x = x.reshape([1, 144, 14, 14])
q, k, v = x.reshape([B, self.num_heads, -1, H*W]).split([self.head_dim, self.head_dim, self.head_dim], axis=2)
attn = (q.transpose([0, 1, 3, 2]) @ k)
attn = nn.functional.softmax(attn)
x = v @ attn.transpose([0, 1, 3, 2])
x = x.reshape([B, -1, H, W])
x = self.LocalProp(x)
x = paddle.nn.functional.layer_norm(x, x.shape[1:])
x = x.flatten(2)
x = x.transpose([0,2,1])
x = self.proj(x)
x = x.transpose([0,2,1])
x = x.reshape([1, 48, 56, 56]) return x
class DownSampleLayer(nn.Layer):
def __init__(self, dim_in=3, dim_out=48, downsample_rate=4):
super().__init__()
self.downsample = Conv2d(dim_in, dim_out, kernel_size=downsample_rate, stride=
downsample_rate)
def forward(self, x):
x = self.downsample(x)
x = paddle.nn.functional.layer_norm(x, x.shape[1:]) return x
class PatchEmbed(nn.Layer):
def __init__(self, dim):
super().__init__()
self.embed = Conv2d(dim, dim, 3, padding=1, groups=dim)
def forward(self, x):
return x + self.embed(x)
class FFN(nn.Layer):
def __init__(self, dim=3156):
super().__init__()
self.fc1 = nn.Linear(dim, dim*4)
self.fc2 = nn.Linear(dim*4, dim)
def forward(self, x):
x = x.flatten(2)
x = x.transpose([0,2,1])
x = self.fc1(x)
x = nn.functional.gelu(x)
x = self.fc2(x)
x = x.transpose([0,2,1])
x = x.reshape([1, 48, 56, 56]) return x
In [ ]
class EdgeViT(nn.Layer):
def __init__(self, dim_in=3, dim_out=48, downsample_rate=4, dim=48):
super().__init__()
self.downsample1 = DownSampleLayer(dim_in=3, dim_out=48, downsample_rate=4)
self.patchembeding1 = PatchEmbed(dim=48)
self.residual_add1 = Residual(LocalAgg(dim=48))
self.residual_add1_1 = Residual(FFN(dim=48))
self.patchembeding2 = PatchEmbed(dim=48)
self.residual_add2 = Residual(GlobalSparseAttn(dim=48))
self.fc = nn.Linear(150528,103) def forward(self, x):
x = self.downsample1(x)
x = self.patchembeding1(x)
x = self.residual_add1(x)
x = self.residual_add1_1(x)
x = self.patchembeding2(x)
x = self.residual_add2(x)
x = paddle.reshape(x,shape=[-1,48*56*56]) # x = x.transpose([0,2,1])
# print(x.shape)
x = self.fc(x) return x
In [4]
cnn = EdgeViT() paddle.summary(cnn,(1,3,224,224))
[1, 150528]
------------------------------------------------------------------------------
Layer (type) Input Shape Output Shape Param #
==============================================================================
Conv2D-1 [[1, 3, 224, 224]] [1, 48, 56, 56] 2,352
DownSampleLayer-1 [[1, 3, 224, 224]] [1, 48, 56, 56] 0
Conv2D-2 [[1, 48, 56, 56]] [1, 48, 56, 56] 480
PatchEmbed-1 [[1, 48, 56, 56]] [1, 48, 56, 56] 0
BatchNorm2D-1 [[1, 48, 56, 56]] [1, 48, 56, 56] 192
Conv2D-3 [[1, 48, 56, 56]] [1, 48, 56, 56] 2,352
Conv2D-4 [[1, 48, 56, 56]] [1, 48, 56, 56] 480
BatchNorm2D-2 [[1, 48, 56, 56]] [1, 48, 56, 56] 192
Conv2D-5 [[1, 48, 56, 56]] [1, 48, 56, 56] 2,352
LocalAgg-1 [[1, 48, 56, 56]] [1, 48, 56, 56] 0
Residual-1 [[1, 48, 56, 56]] [1, 48, 56, 56] 0
Linear-1 [[1, 3136, 48]] [1, 3136, 192] 9,408
Linear-2 [[1, 3136, 192]] [1, 3136, 48] 9,264
FFN-1 [[1, 48, 56, 56]] [1, 48, 56, 56] 0
Residual-2 [[1, 48, 56, 56]] [1, 48, 56, 56] 0
Conv2D-6 [[1, 48, 56, 56]] [1, 48, 56, 56] 480
PatchEmbed-2 [[1, 48, 56, 56]] [1, 48, 56, 56] 0
AvgPool2D-1 [[1, 48, 56, 56]] [1, 48, 14, 14] 0
Linear-3 [[1, 196, 48]] [1, 196, 144] 7,056
Conv2DTranspose-1 [[1, 48, 14, 14]] [1, 48, 56, 56] 816
Linear-4 [[1, 3136, 48]] [1, 3136, 48] 2,352
GlobalSparseAttn-1 [[1, 48, 56, 56]] [1, 48, 56, 56] 0
Residual-3 [[1, 48, 56, 56]] [1, 48, 56, 56] 0
Linear-5 [[1, 150528]] [1, 103] 15,504,487
==============================================================================
Total params: 15,542,263
Trainable params: 15,541,879
Non-trainable params: 384
------------------------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 27.85
Params size (MB): 59.29
Estimated Total Size (MB): 87.71
------------------------------------------------------------------------------
{'total_params': 15542263, 'trainable_params': 15541879}
3 模型训练
论文的实验是基于ImageNet数据集进行的,但是目前平台不具备拉取该数据集的能力,故这里采用了Cifar10作为模型验证数据集,仅做调通,不设置对比实验,因为在小数据集上无对比性。
In [5]import paddlefrom paddle.vision.datasets import Flowersfrom paddle.vision.transforms import Compose, Normalize, Resize, Transpose, ToTensor
normalize = Normalize(mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5],
data_format='HWC')
transform = Compose([ToTensor(), Normalize(), Resize(size=(224,224))])
cifar10_train = paddle.vision.datasets.Flowers(mode='train',
transform=transform)
cifar10_test = paddle.vision.datasets.Flowers(mode='test',
transform=transform)# 构建训练集数据加载器train_loader = paddle.io.DataLoader(cifar10_train, batch_size=1, shuffle=True)# 构建测试集数据加载器test_loader = paddle.io.DataLoader(cifar10_test, batch_size=1, shuffle=True)print('=============train dataset=============')for image, label in cifar10_train: print('image shape: {}, label: {}'.format(image.shape, label)) break
=============train dataset============= image shape: [3, 224, 224], label: [1]In [ ]
from paddle.metric import Accuracy
model = paddle.Model(EdgeViT())
optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())
model.prepare(
optim,
paddle.nn.CrossEntropyLoss(),
Accuracy()
)
model.fit(train_data=train_loader,
eval_data=test_loader,
epochs=2,
verbose=1
)
以上就是轻量级Vision-Transformer:EdgeViTs复现的详细内容,更多请关注其它相关文章!
# 戛纳
# 网站建设在线咨询
# 丽江网站推广营销
# 太原seo 公司
# seo开源
# 如何培养网站seo优化思路
# 自考本科推广营销方案
# 金融营销软件推广
# 免费网站建设可信吗
# 微博推广自己的网站
# 营销短视频如何突围推广
# 引入了
# 科大
# ai
# 开源
# 只对
# 首款
# 系列产品
# 采用了
# 复选框
# 中文网
# type
# latte
# lsp
# 作用域
相关栏目:
【
行业资讯67740 】
【
技术百科0 】
【
网络运营39195 】
相关推荐:
闪光灯power闪烁是什么意思
为什么夸克下载不到
cmd如何定时执行命令
.asm如何在命令行运行
新装固态硬盘如何安装
春运抢票极速版怎么抢票
如何弄坏固态硬盘
系统如何装在固态硬盘
j*a数组怎么比较abc
智能锁type-c接口是什么
电脑5G怎么上传手机
税负是什么意思
如何管理员打开cmd命令行窗口
excel中datediff函数怎么用
雅迪电动车上的power是什么意思
课程伴侣电脑怎么登录
学typescript需要什么基础么
阿里云盘扩容是什么_扩容阿里云盘方法是什么教程
j*a怎么用json数组
夸克内测有什么好处
ai如何重复使用上一命令
如何创建sql命令
更换固态硬盘如何检查
哪些编程软件需要typescript
电脑命令如何删除账号
如何用命令下载服务器网站
苹果16有哪些亮点功能
折叠屏手机哪个牌子性价比高
市盈率静是什么意思
如何在命令行执行一个jar
linux如何调出命令行
固态硬盘电脑如何设置
折叠手机屏易坏吗为什么
云笔记本电脑有什么用
typescript接口有什么用
如何寻找和修复无法在 AI 中找到文件的问题
如何用命令查看数据库日志文件
如何以管理员身份打开命令提示符
单片机面包板怎么插
春运抢票需要抢几天
网络光刻机是干什么用的
如何安装固态硬盘win10
vue中datediff函数怎么用
typescript怎么拼接
破太岁是什么意思
怎么把手机里爱奇艺的视频下载到u盘里
夸克网盘是什么都有吗
折叠屏手机哪个卖得最好
md5解密是什么意思
html怎么使用typescript


2025-07-29
浏览次数:次
返回列表