
[ICLR 2022] RegionViT: A Regional-to-Local ViT

2025-08-01
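RegionViT proposes a regional-to-local vision Transformer. It adopts a pyramid structure and replaces global self-attention with regional-to-local attention. It first generates regional tokens and local tokens with different patch sizes. Regional self-attention extracts global information among the regional tokens, and local self-attention then passes that information to the local tokens, with relative position encoding added. The model performs well on multiple vision tasks and is efficient while keeping both a global receptive field and locality.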


Abstract

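In recent years, Vision Transformers (ViT) have shown strong image classification capability, comparable to convolutional neural networks (CNNs). However, the original ViT directly inherits its architecture from natural language processing, and that architecture is usually not optimized for vision applications. Motivated by this, the paper proposes a new vision Transformer architecture. It adopts a pyramid structure and introduces a novel regional-to-local attention in place of global self-attention.

More specifically, the model first generates regional tokens and local tokens from the image using different patch sizes. Each regional token is associated with a set of local tokens according to spatial location. Regional-to-local attention consists of two steps:

1. Regional self-attention extracts global information among all regional tokens.
2. Local self-attention exchanges information between each regional token and its associated local tokens.

Therefore, even though local self-attention is confined to a local region, it still receives global information. Extensive experiments on four vision tasks (image classification, object and keypoint detection, semantic segmentation, and action recognition) show that the approach outperforms or is on par with existing ViT variants, including many concurrent works.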

1. RegionViT

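Global self-attention is expensive, so many works use local self-attention instead, i.e., self-attention restricted to a small window. Local self-attention, however, introduces another problem: the receptive field becomes too small. To address this, the paper proposes RegionViT, a new coarse-to-fine Transformer. Regional tokens interact globally, and local self-attention passes the global information they carry to their corresponding local tokens. The overall architecture is shown in Figure 2: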

[Figure 2: Overall architecture of RegionViT]
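The cost argument can be made concrete by counting query-key pairs in one attention layer on a 224×224 input. The sketch assumes the tiny configuration used later in this notebook: 4×4 local patches, giving a 56×56 local grid, and 7×7 local tokens per region. The helper names are hypothetical. Projections and the FFN are ignored, so these are illustrative pair counts, not FLOPs.

```python
def global_pairs(h, w):
    """Query-key pairs for full self-attention over an h x w token grid."""
    n = h * w
    return n * n


def r2l_pairs(h, w, window):
    """Query-key pairs for one regional-to-local step (illustrative count).

    Regional self-attention runs over all regional tokens; local self-attention
    runs inside each window over 1 regional token + window*window local tokens.
    Assumes h and w are divisible by the window size.
    """
    num_regions = (h // window) * (w // window)
    rsa = num_regions ** 2
    lsa = num_regions * (1 + window * window) ** 2
    return rsa + lsa


# 224x224 image with 4x4 local patches -> 56x56 local tokens, 7x7 locals per region.
h = w = 224 // 4
print(global_pairs(h, w))  # 9834496
print(r2l_pairs(h, w, 7))  # 164096
```

Under these assumptions, the regional-to-local step touches roughly 60× fewer pairs than global attention, and every window still sees one globally informed regional token.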

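The core module of the paper is the regional-to-local Transformer encoder. Regional tokens first interact globally. Local self-attention then passes the global information they carry to the corresponding local tokens: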

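$$
\begin{aligned}
y_r^{d} &= x_r^{d-1} + \mathrm{RSA}\big(\mathrm{LN}(x_r^{d-1})\big), \\
y_{i,j}^{d} &= \big[\, y_{r_{i,j}}^{d} \,\|\, \{x_{l_{i,j,m,n}}^{d-1}\}_{m,n \in M} \,\big], \\
z_{i,j}^{d} &= y_{i,j}^{d} + \mathrm{LSA}\big(\mathrm{LN}(y_{i,j}^{d})\big), \\
x_{i,j}^{d} &= z_{i,j}^{d} + \mathrm{FFN}\big(\mathrm{LN}(z_{i,j}^{d})\big),
\end{aligned}
$$

The symbols are:

- $d$ is the layer index and $x_r$ are the regional tokens.
- $x_{l_{i,j,m,n}}$ is the local token at position $(m, n)$ inside the window of the regional token at $(i, j)$, and $M$ is the set of positions in that window.
- $\|$ denotes concatenation.
- RSA, LSA, LN and FFN are regional self-attention, local self-attention, layer normalization and the feed-forward network.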
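The data flow in these equations can be traced in a small framework-free sketch. RSA runs over all regional tokens, then LSA runs over each [regional ∥ local] sequence, both with residual connections. To stay short, the sketch assumes single-head attention with identity Q/K/V projections and omits LN and the FFN. All function names are hypothetical. It only illustrates how global information reaches local tokens and is not the notebook's implementation.

```python
import math


def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def self_attention(tokens):
    """Single-head scaled dot-product attention with identity Q/K/V (toy)."""
    d = len(tokens[0])
    out = []
    for q in tokens:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in tokens]
        weights = softmax(scores)
        out.append([sum(w * t[c] for w, t in zip(weights, tokens)) for c in range(d)])
    return out


def add(a, b):
    """Element-wise residual addition of two token lists."""
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def r2l_step(regional, local_windows):
    """regional[r]: token of region r; local_windows[r]: local tokens of region r."""
    # Step 1 (RSA): every regional token attends to all regional tokens.
    y_r = add(regional, self_attention(regional))
    outputs = []
    for r, window in enumerate(local_windows):
        # Step 2: prepend the updated regional token to its own local tokens.
        y = [y_r[r]] + window
        # Step 3 (LSA): attention only inside this window, plus the residual.
        outputs.append(add(y, self_attention(y)))
    return outputs


regional = [[1.0, 0.0], [0.0, 1.0]]
windows = [[[0.5, 0.5], [1.0, 1.0]], [[0.2, 0.1], [0.0, 0.3]]]
out = r2l_step(regional, windows)
print(len(out), len(out[0]))  # 2 3
```

Changing the regional token of region 1 changes the outputs of the local tokens in region 0, because that information flows through RSA. Changing the local tokens of region 0 leaves region 1 untouched in this single layer.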


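Locality is an important cue for understanding visual content, so the paper uses relative position encoding. This position bias is added only among local tokens. No position bias is added between a regional token and its local tokens. Specifically: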

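$$
a_{(x_m,y_m),(x_n,y_n)} = \mathrm{softmax}\Big(q_{(x_m,y_m)}\, k_{(x_n,y_n)}^{T} + b_{(x_m - x_n,\; y_m - y_n)}\Big),
$$

where $b$ is a learnable bias indexed by the relative offset between the local tokens at $(x_m, y_m)$ and $(x_n, y_n)$.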
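A small sketch shows how the bias $b$ can be indexed. It uses the common offset M−1 (the convention popularized by Swin Transformer) to map each relative offset into a (2M−1)² bias table. This convention is an assumption of the sketch and is not guaranteed to match the index construction in `AttentionWithRelPos` below. The regional token is excluded, matching the note above. `relative_position_index` and `biased_logits` are hypothetical helpers.

```python
def relative_position_index(M):
    """Index into a (2M-1)^2 bias table for every (query, key) pair in an M x M window.

    Tokens are enumerated row-major as (x, y); the offset (x_m - x_n, y_m - y_n)
    lies in [-(M-1), M-1] and is shifted by M-1 so it can index the table.
    """
    coords = [(x, y) for y in range(M) for x in range(M)]
    size = 2 * M - 1
    index = []
    for xm, ym in coords:
        row = []
        for xn, yn in coords:
            dx = xm - xn + (M - 1)
            dy = ym - yn + (M - 1)
            row.append(dy * size + dx)
        index.append(row)
    return index


def biased_logits(q, k, bias_table, index):
    """q_m . k_n + b[(x_m - x_n, y_m - y_n)] over local tokens; softmax is applied per row afterwards."""
    return [
        [sum(a * b for a, b in zip(q[m], k[n])) + bias_table[index[m][n]] for n in range(len(k))]
        for m in range(len(q))
    ]


idx = relative_position_index(2)
print(idx[0])  # [4, 3, 1, 0]
```

Pairs with the same offset share one table entry. For example, (0,0)→(1,0) and (0,1)→(1,1) both map to index 3, so the bias depends only on relative position.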


2. Code Reproduction

2.1 Download and import the required libraries

In [ ]
%matplotlib inline
import os
import copy
import itertools
from functools import partial

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure

import paddle
from paddle import nn
import paddle.nn.functional as F
import paddle.vision.transforms as transforms
from paddle.vision.datasets import Cifar10
from paddle.vision.transforms import Transpose
from paddle.io import Dataset, DataLoader
   

2.2 Create the dataset

In [3]
train_tfm = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.6, 1.0)),
    transforms.ColorJitter(brightness=0.2,contrast=0.2, saturation=0.2),
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomRotation(20),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

test_tfm = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
    In [4]
paddle.vision.set_image_backend('cv2')
# Use the CIFAR-10 dataset
train_dataset = Cifar10(data_file='data/data152754/cifar-10-python.tar.gz', mode='train', transform=train_tfm)
val_dataset = Cifar10(data_file='data/data152754/cifar-10-python.tar.gz', mode='test', transform=test_tfm)
print("train_dataset: %d" % len(train_dataset))
print("val_dataset: %d" % len(val_dataset))
       
train_dataset: 50000
val_dataset: 10000
        In [5]
batch_size=64
    In [6]
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=4)
   

2.3 Build the model

2.3.1 Label smoothing

In [7]
class LabelSmoothingCrossEntropy(nn.Layer):
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, pred, target):

        confidence = 1. - self.smoothing
        log_probs = F.log_softmax(pred, axis=-1)
        idx = paddle.stack([paddle.arange(log_probs.shape[0]), target], axis=1)
        nll_loss = paddle.gather_nd(-log_probs, index=idx)
        smooth_loss = paddle.mean(-log_probs, axis=-1)
        loss = confidence * nll_loss + self.smoothing * smooth_loss
        return loss.mean()
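`LabelSmoothingCrossEntropy` mixes two terms. The first is the negative log-likelihood of the true class, with weight 1 − smoothing. The second is the mean negative log-probability over all classes, with weight smoothing. A dependency-free, single-sample version of the same formula can be used to sanity-check the Paddle layer. `label_smoothing_ce` is a hypothetical helper, not part of Paddle.

```python
import math


def label_smoothing_ce(logits, target, smoothing=0.1):
    """Same formula as LabelSmoothingCrossEntropy, for a single sample."""
    # Numerically stable log-softmax.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    log_probs = [v - log_z for v in logits]
    nll = -log_probs[target]                   # loss on the true class
    smooth = -sum(log_probs) / len(log_probs)  # mean loss over all classes
    return (1.0 - smoothing) * nll + smoothing * smooth


print(round(label_smoothing_ce([2.0, 0.5, -1.0], target=0), 4))
```

With uniform logits both terms equal log(num_classes), so the smoothing weight has no effect. For a confident, correct prediction, smoothing raises the loss slightly, which is what discourages over-confident logits.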
   

2.3.2 DropPath

In [8]
def drop_path(x, drop_prob=0.0, training=False):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = paddle.to_tensor(1 - drop_prob)
    shape = (paddle.shape(x)[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype)
    random_tensor = paddle.floor(random_tensor)  # binarize
    output = x.divide(keep_prob) * random_tensor
    return output


class DropPath(nn.Layer):
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
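`drop_path` keeps each sample's residual branch with probability keep_prob = 1 − drop_prob and rescales kept samples by 1/keep_prob. As a result, the expected output equals the input. The plain-Python sketch below mirrors that logic on a batch given as nested lists. `drop_path_batch` and its explicit `rng` argument are hypothetical, added so the randomness is reproducible.

```python
import math
import random


def drop_path_batch(batch, drop_prob, training, rng):
    """Per-sample stochastic depth on a batch given as a list of feature lists.

    Mirrors drop_path above: each sample is kept with probability 1 - drop_prob
    and, when kept, rescaled by 1 / keep_prob so the expected value is unchanged.
    """
    if drop_prob == 0.0 or not training:
        return batch
    keep_prob = 1.0 - drop_prob
    out = []
    for sample in batch:
        # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
        keep = math.floor(keep_prob + rng.random())
        out.append([v / keep_prob * keep for v in sample])
    return out


rng = random.Random(0)
batch = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
print(drop_path_batch(batch, 0.5, training=True, rng=rng))
```

Each sample is either zeroed entirely or scaled as a whole. This is what distinguishes stochastic depth from element-wise dropout.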
   

2.3.3 Build the RegionViT model

In [9]
class LayerNorm2D(nn.Layer):
    def __init__(self, channels, eps=1e-5, elementwise_affine=True):
        super().__init__()

        self.channels = channels
        self.eps = paddle.to_tensor(eps)
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = self.create_parameter(shape=(1, channels, 1, 1), default_initializer=nn.initializer.Constant(1.0))
            self.bias = self.create_parameter(shape=(1, channels, 1, 1), default_initializer=nn.initializer.Constant(0.0))
        else:
            self.register_buffer('weight', None)
            self.register_buffer('bias', None)

    def forward(self, input):
        mean = input.mean(1, keepdim=True)
        std = paddle.sqrt(input.var(1, unbiased=False, keepdim=True) + self.eps)
        out = (input - mean) / std
        if self.elementwise_affine:
            out = out * self.weight + self.bias
        return out
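`LayerNorm2D` normalizes each spatial position across channels. This is the NCHW equivalent of applying `nn.LayerNorm` to the channel vector of every pixel. The plain-Python sketch below reproduces it at initialization (weight 1, bias 0), using the biased variance and `eps` inside the square root as the layer does. `layer_norm_2d` is a hypothetical helper operating on a C×H×W nested list.

```python
import math


def layer_norm_2d(x, eps=1e-5):
    """Normalize a C x H x W nested list over the channel axis at every (h, w)."""
    C, H, W = len(x), len(x[0]), len(x[0][0])
    out = [[[0.0] * W for _ in range(H)] for _ in range(C)]
    for h in range(H):
        for w in range(W):
            vals = [x[c][h][w] for c in range(C)]
            mean = sum(vals) / C
            var = sum((v - mean) ** 2 for v in vals) / C  # biased variance (unbiased=False)
            std = math.sqrt(var + eps)
            for c in range(C):
                out[c][h][w] = (vals[c] - mean) / std
    return out


x = [[[1.0, 2.0]], [[3.0, 6.0]], [[5.0, 10.0]]]  # C=3, H=1, W=2
print(layer_norm_2d(x))
```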
    In [10]
class AttentionWithRelPos(nn.Layer):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.,
                 attn_map_dim=None, num_cls_tokens=1):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.num_cls_tokens = num_cls_tokens
        if attn_map_dim is not None:
            one_dim = attn_map_dim[0]
            rel_pos_dim = (2 * one_dim - 1)
            self.rel_pos = self.create_parameter(shape=(num_heads, rel_pos_dim ** 2), default_initializer=nn.initializer.Constant(0.0))
            tmp = paddle.arange(rel_pos_dim ** 2).reshape((rel_pos_dim, rel_pos_dim))
            out = []
            offset_x = offset_y = one_dim // 2
            for y in range(one_dim):
                for x in range(one_dim):
                    for dy in range(one_dim):
                        for dx in range(one_dim):
                            out.append(tmp[dy - y + offset_y, dx - x + offset_x])
            self.rel_pos_index = paddle.to_tensor(out, dtype=paddle.int32)
            tn = nn.initializer.TruncatedNormal(std=.02)
            tn(self.rel_pos)
        else:
            self.rel_pos = None

    def forward(self, x, patch_attn=False, mask=None):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape((B, N, 3, self.num_heads, C // self.num_heads)).transpose([2, 0, 3, 1, 4])
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose([0, 1, 3, 2])) * self.scale
        if self.rel_pos is not None and patch_attn:
            # use for the indicating patch + cls:
            rel_pos = self.rel_pos[:, self.rel_pos_index].reshape((self.num_heads, N - self.num_cls_tokens, N - self.num_cls_tokens))
            attn[:, :, self.num_cls_tokens:, self.num_cls_tokens:] = attn[:, :, self.num_cls_tokens:, self.num_cls_tokens:] + rel_pos
        if mask is not None:
            ## mask is only (BH_sW_s)(ksks)(ksks), need to expand it
            mask = mask.unsqueeze(1).expand((-1, self.num_heads, -1, -1))
            # Masked positions get a large negative logit; paddle.where is used because
            # Tensor.masked_fill may not be available in older Paddle releases.
            attn = paddle.where(mask == 0, paddle.full_like(attn, -1e9), attn)

        attn = F.softmax(attn, axis=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose([0, 2, 1, 3]).reshape((B, N, C))
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
    In [11]
def to_2tuple(x):
    return (x, x)


class PatchEmbed(nn.Layer):
    """ Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, patch_conv_type='linear'):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        if patch_conv_type == '3conv':
            if patch_size[0] == 4:
                tmp = [
                    nn.Conv2D(in_chans, embed_dim // 4, kernel_size=3, stride=2, padding=1),
                    LayerNorm2D(embed_dim // 4),
                    nn.GELU(),
                    nn.Conv2D(embed_dim // 4, embed_dim // 2, kernel_size=3, stride=2, padding=1),
                    LayerNorm2D(embed_dim // 2),
                    nn.GELU(),
                    nn.Conv2D(embed_dim // 2, embed_dim, kernel_size=3, stride=1, padding=1),
                ]
            else:
                raise ValueError(f"Unknown patch size {patch_size[0]}")
            self.proj = nn.Sequential(*tmp)
        else:
            if patch_conv_type == '1conv':
                kernel_size = (2 * patch_size[0], 2 * patch_size[1])
                stride = (patch_size[0], patch_size[1])
                padding = (patch_size[0] - 1, patch_size[1] - 1)
            else:
                kernel_size = patch_size
                stride = patch_size
                padding = 0

            self.proj = nn.Conv2D(in_chans, embed_dim, kernel_size=kernel_size,
                                  stride=stride, padding=padding)

    def forward(self, x, extra_padding=False):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        # assert H == self.img_size[0] and W == self.img_size[1], \
        #     f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        if extra_padding and (H % self.patch_size[0] != 0 or W % self.patch_size[1] != 0):
            p_l = (self.patch_size[1] - W % self.patch_size[1]) // 2
            p_r = (self.patch_size[1] - W % self.patch_size[1]) - p_l
            p_t = (self.patch_size[0] - H % self.patch_size[0]) // 2
            p_b = (self.patch_size[0] - H % self.patch_size[0]) - p_t
            x = F.pad(x, (p_l, p_r, p_t, p_b))
        x = self.proj(x)
        return x
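With `patch_size=4`, all three `patch_conv_type` branches produce the same 56×56 local grid from a 224×224 image. Applying the standard convolution output-size formula floor((n + 2·padding − kernel)/stride) + 1 to each branch confirms this. The regional tokens come from a `'linear'` `PatchEmbed` whose patch size is `patch_size * kernel_sizes[0]` (28 for window 7). The helper names below are hypothetical.

```python
def conv_out(n, kernel, stride, padding):
    """Output length of a convolution along one spatial axis (dilation 1)."""
    return (n + 2 * padding - kernel) // stride + 1


def three_conv(n):
    # '3conv' stem for patch_size=4: 3x3 s2 p1 -> 3x3 s2 p1 -> 3x3 s1 p1
    n = conv_out(n, 3, 2, 1)
    n = conv_out(n, 3, 2, 1)
    return conv_out(n, 3, 1, 1)


def one_conv(n, p):
    # '1conv': kernel 2p, stride p, padding p - 1
    return conv_out(n, 2 * p, p, p - 1)


def linear(n, p):
    # 'linear': kernel = stride = p, no padding
    return conv_out(n, p, p, 0)


print(three_conv(224), one_conv(224, 4), linear(224, 4 * 7))  # 56 56 8
```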
    In [12]
class Mlp(nn.Layer):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)

        self.fc1 = nn.Linear(in_features, hidden_features, bias_attr=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.fc2 = nn.Linear(hidden_features, out_features, bias_attr=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
    In [13]
class R2LAttentionPlusFFN(nn.Layer):

    def __init__(self, input_channels, output_channels, kernel_size, num_heads, mlp_ratio=1., qkv_bias=False, qk_scale=None,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, drop_path=0., attn_drop=0., drop=0.,
                 cls_attn=True):
        super().__init__()
        if not isinstance(kernel_size, (tuple, list)):
            kernel_size = [(kernel_size, kernel_size), (kernel_size, kernel_size), 0]
        self.kernel_size = kernel_size
        if cls_attn:
            self.norm0 = norm_layer(input_channels)
        else:
            self.norm0 = None

        self.norm1 = norm_layer(input_channels)
        self.attn = AttentionWithRelPos(
            input_channels, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,
            attn_map_dim=(kernel_size[0][0], kernel_size[0][1]), num_cls_tokens=1)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(input_channels)
        self.mlp = Mlp(in_features=input_channels, hidden_features=int(output_channels * mlp_ratio), out_features=output_channels, act_layer=act_layer, drop=drop)

        self.expand = nn.Sequential(
            norm_layer(input_channels),
            act_layer(),
            nn.Linear(input_channels, output_channels)
        ) if input_channels != output_channels else None

        self.output_channels = output_channels
        self.input_channels = input_channels

    def forward(self, xs):
        out, B, H, W, mask = xs
        cls_tokens = out[:, 0:1, ...]

        C = cls_tokens.shape[-1]
        cls_tokens = cls_tokens.reshape((B, -1, C))  # (N)x(H/sxW/s)xC

        if self.norm0 is not None:
            cls_tokens = cls_tokens + self.drop_path(self.attn(self.norm0(cls_tokens)))  # (N)x(H/sxW/s)xC

        # ks, stride, padding = self.kernel_size
        cls_tokens = cls_tokens.reshape((-1, 1, C))  # (NxH/sxW/s)x1xC

        out = paddle.concat((cls_tokens, out[:, 1:, ...]), axis=1)
        tmp = out

        tmp = tmp + self.drop_path(self.attn(self.norm1(tmp), patch_attn=True, mask=mask))
        identity = self.expand(tmp) if self.expand is not None else tmp
        tmp = identity + self.drop_path(self.mlp(self.norm2(tmp)))
        return tmp
    In [14]
class Projection(nn.Layer):
    def __init__(self, input_channels, output_channels, act_layer, mode='sc'):
        super().__init__()
        tmp = []
        if 'c' in mode:
            ks = 2 if 's' in mode else 1
            if ks == 2:
                stride = ks
                ks = ks + 1
                padding = ks // 2
            else:
                stride = ks
                padding = 0

            if input_channels == output_channels and ks == 1:
                tmp.append(nn.Identity())
            else:
                tmp.extend([
                    LayerNorm2D(input_channels),
                    act_layer(),
                ])
                tmp.append(nn.Conv2D(in_channels=input_channels, out_channels=output_channels, kernel_size=ks, stride=stride, padding=padding, groups=input_channels))

        self.proj = nn.Sequential(*tmp)
        self.proj_cls = self.proj

    def forward(self, xs):
        cls_tokens, patch_tokens = xs
        # x: BxCxHxW
        cls_tokens = self.proj_cls(cls_tokens)
        patch_tokens = self.proj(patch_tokens)
        return cls_tokens, patch_tokens
    In [15]
def convert_to_flatten_layout(cls_tokens, patch_tokens, ws):
    """
    Convert the token layer in a flatten form, it will speed up the model.

    Furthermore, it also handle the case that if the size between regional tokens and local tokens are not consistent.
    """
    # padding if needed, and all paddings are happened at bottom and right.
    B, C, H, W = patch_tokens.shape
    _, _, H_ks, W_ks = cls_tokens.shape
    need_mask = False
    p_l, p_r, p_t, p_b = 0, 0, 0, 0
    if H % (H_ks * ws) != 0 or W % (W_ks * ws) != 0:
        p_l, p_r = 0, W_ks * ws - W
        p_t, p_b = 0, H_ks * ws - H
        patch_tokens = F.pad(patch_tokens, (p_l, p_r, p_t, p_b))
        need_mask = True

    B, C, H, W = patch_tokens.shape
    kernel_size = [H // H_ks, W // W_ks]
    tmp = F.unfold(patch_tokens, kernel_sizes=kernel_size, strides=kernel_size, paddings=[0, 0])  # Nx(Cxksxks)x(H/sxW/s)
    patch_tokens = tmp.transpose([0, 2, 1]).reshape((-1, C, kernel_size[0] * kernel_size[1])).transpose([0, 2, 1])  # (NxH/sxW/s)x(ksxks)xC

    if need_mask:
        BH_sK_s, ksks, C = patch_tokens.shape
        H_s, W_s = H // ws, W // ws
        # Paddle APIs: shapes are passed as lists, Tensor.tile replaces torch's repeat,
        # and clone() replaces copy.deepcopy.
        mask = paddle.ones([BH_sK_s // B, 1 + ksks, 1 + ksks], dtype='float32')
        right = paddle.zeros([1 + ksks, 1 + ksks], dtype='float32')
        tmp = paddle.zeros([ws, ws], dtype='float32')
        tmp[0:(ws - p_r), 0:(ws - p_r)] = 1.
        tmp = tmp.tile([ws, ws])
        right[1:, 1:] = tmp
        right[0, 0] = 1
        right[0, 1:] = paddle.to_tensor([1.] * (ws - p_r) + [0.] * p_r).tile([ws])
        right[1:, 0] = paddle.to_tensor([1.] * (ws - p_r) + [0.] * p_r).tile([ws])
        bottom = paddle.zeros_like(right)
        bottom[0:ws * (ws - p_b) + 1, 0:ws * (ws - p_b) + 1] = 1.
        bottom_right = right.clone()
        bottom_right[0:ws * (ws - p_b) + 1, 0:ws * (ws - p_b) + 1] = 1.

        if H_s > 1:  # this slice is empty when there is only one row of regions
            mask[W_s - 1:(H_s - 1) * W_s:W_s, ...] = right
        mask[(H_s - 1) * W_s:, ...] = bottom
        mask[-1, ...] = bottom_right
        mask = mask.tile([B, 1, 1])
    else:
        mask = None

    cls_tokens = cls_tokens.flatten(2).transpose([0, 2, 1])  # (N)x(H/sxW/s)xC
    cls_tokens = cls_tokens.reshape((-1, 1, cls_tokens.shape[-1]))  # (NxH/sxW/s)x1xC

    out = paddle.concat((cls_tokens, patch_tokens), axis=1)
    return out, mask, p_l, p_r, p_t, p_b, B, C, H, W


def convert_to_spatial_layout(out, output_channels, B, H, W, kernel_size, mask, p_l, p_r, p_t, p_b):
    """
    Convert the token layer from flatten into 2-D, will be used to downsample the spatial dimension.
    """
    cls_tokens = out[:, 0:1, ...]
    patch_tokens = out[:, 1:, ...]
    # cls_tokens: (BxH/sxW/s)x(1)xC, patch_tokens: (BxH/sxW/s)x(ksxks)xC
    C = output_channels
    kernel_size = kernel_size[0]
    H_ks = H // kernel_size[0]
    W_ks = W // kernel_size[1]
    # reorganize data, need to convert back to cls_tokens: BxCxH/sxW/s, patch_tokens: BxCxHxW
    cls_tokens = cls_tokens.reshape((B, -1, C)).transpose([0, 2, 1]).reshape((B, C, H_ks, W_ks))
    patch_tokens = patch_tokens.transpose([0, 2, 1]).reshape((B, -1, kernel_size[0] * kernel_size[1] * C)).transpose([0, 2, 1])
    patch_tokens = F.fold(patch_tokens, [H, W], kernel_sizes=kernel_size, strides=kernel_size, paddings=[0, 0])
    if mask is not None:
        if p_b > 0:
            patch_tokens = patch_tokens[:, :, :-p_b, :]
        if p_r > 0:
            patch_tokens = patch_tokens[:, :, :, :-p_r]
    return cls_tokens, patch_tokens
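The flatten layout can be pictured without tensors. Local tokens are cut into ws×ws windows in row-major order, which is what `F.unfold` does when kernel and stride both equal the window size. The r-th regional token is then prepended to the r-th window, and folding scatters the windows back. The plain-Python sketch below uses hypothetical helper names and omits the padding/mask path.

```python
def to_windows(grid, ws):
    """Split an H x W grid (list of rows) into non-overlapping ws x ws windows, row-major."""
    H, W = len(grid), len(grid[0])
    assert H % ws == 0 and W % ws == 0, "pad the grid first (see convert_to_flatten_layout)"
    windows = []
    for top in range(0, H, ws):
        for left in range(0, W, ws):
            windows.append([grid[i][j] for i in range(top, top + ws) for j in range(left, left + ws)])
    return windows


def from_windows(windows, H, W, ws):
    """Inverse of to_windows: scatter window tokens back to their grid positions."""
    grid = [[None] * W for _ in range(H)]
    per_row = W // ws
    for r, win in enumerate(windows):
        top, left = (r // per_row) * ws, (r % per_row) * ws
        for t, v in enumerate(win):
            grid[top + t // ws][left + t % ws] = v
    return grid


def flatten_layout(regional, grid, ws):
    """Prepend regional token r (row-major region order) to the local tokens of window r."""
    return [[regional[r]] + win for r, win in enumerate(to_windows(grid, ws))]


grid = [[f"l{i}{j}" for j in range(4)] for i in range(4)]
seqs = flatten_layout(["r0", "r1", "r2", "r3"], grid, 2)
print(seqs[1])  # ['r1', 'l02', 'l03', 'l12', 'l13']
```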
    In [16]
class ConvAttBlock(nn.Layer):
    def __init__(self, input_channels, output_channels, kernel_size, num_blocks, num_heads, mlp_ratio=1., qkv_bias=False, qk_scale=None, pool='sc',
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, drop_path_rate=(0.,), attn_drop_rate=0., drop_rate=0.,
                 cls_attn=True, peg=False):
        super().__init__()
        tmp = []
        if pool:
            tmp.append(Projection(input_channels, output_channels, act_layer=act_layer, mode=pool))
        for i in range(num_blocks):
            kernel_size_ = kernel_size
            tmp.append(R2LAttentionPlusFFN(output_channels, output_channels, kernel_size_, num_heads, mlp_ratio, qkv_bias, qk_scale,
                                           act_layer=act_layer, norm_layer=norm_layer, drop_path=drop_path_rate[i], attn_drop=attn_drop_rate, drop=drop_rate,
                                           cls_attn=cls_attn))

        self.block = nn.LayerList(tmp)
        self.output_channels = output_channels
        self.ws = kernel_size
        if not isinstance(kernel_size, (tuple, list)):
            kernel_size = [[kernel_size, kernel_size], [kernel_size, kernel_size], 0]
        self.kernel_size = kernel_size

        self.peg = nn.Conv2D(output_channels, output_channels, kernel_size=3, padding=1, groups=output_channels, bias_attr=False) if peg else None

    def forward(self, xs):
        cls_tokens, patch_tokens = xs
        cls_tokens, patch_tokens = self.block[0]((cls_tokens, patch_tokens))
        out, mask, p_l, p_r, p_t, p_b, B, C, H, W = convert_to_flatten_layout(cls_tokens, patch_tokens, self.ws)
        for i in range(1, len(self.block)):
            blk = self.block[i]

            out = blk((out, B, H, W, mask))
            if self.peg is not None and i == 1:
                cls_tokens, patch_tokens = convert_to_spatial_layout(out, self.output_channels, B, H, W, self.kernel_size, mask, p_l, p_r, p_t, p_b)
                cls_tokens = cls_tokens + self.peg(cls_tokens)
                patch_tokens = patch_tokens + self.peg(patch_tokens)
                out, mask, p_l, p_r, p_t, p_b, B, C, H, W = convert_to_flatten_layout(cls_tokens, patch_tokens, self.ws)

        cls_tokens, patch_tokens = convert_to_spatial_layout(out, self.output_channels, B, H, W, self.kernel_size, mask, p_l, p_r, p_t, p_b)
        return cls_tokens, patch_tokens
    In [17]
class RegionViT(nn.Layer):
    """
    Note:
        The variable naming mapping between codes and papers:
        - cls_tokens -> regional tokens
        - patch_tokens -> local tokens
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=(768,), depth=(12,),
                 num_heads=(12,), mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=partial(nn.LayerNorm, epsilon=1e-6),
                 # regionvit parameters
                 kernel_sizes=None, downsampling=None,
                 patch_conv_type='3conv',
                 computed_cls_token=True, peg=False,
                 det_norm=False):

        super().__init__()
        self.num_classes = num_classes
        self.kernel_sizes = kernel_sizes
        self.num_features = embed_dim[-1]  # num_features for consistency with other models
        self.embed_dim = embed_dim
        self.patch_size = patch_size
        self.img_size = img_size
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim[0],
            patch_conv_type=patch_conv_type)
        if not isinstance(mlp_ratio, (list, tuple)):
            mlp_ratio = [mlp_ratio] * len(depth)

        self.computed_cls_token = computed_cls_token
        self.cls_token = PatchEmbed(
            img_size=img_size, patch_size=patch_size * kernel_sizes[0], in_chans=in_chans, embed_dim=embed_dim[0],
            patch_conv_type='linear'
        )
        self.pos_drop = nn.Dropout(p=drop_rate)
        total_depth = sum(depth)
        dpr = [x.item() for x in paddle.linspace(0, drop_path_rate, total_depth)]  # stochastic depth decay rule
        dpr_ptr = 0
        self.layers = nn.LayerList()
        for i in range(len(embed_dim) - 1):
            curr_depth = depth[i]
            dpr_ = dpr[dpr_ptr: dpr_ptr + curr_depth]

            self.layers.append(
                ConvAttBlock(embed_dim[i], embed_dim[i + 1], kernel_size=kernel_sizes[i], num_blocks=depth[i], drop_path_rate=dpr_,
                             num_heads=num_heads[i], mlp_ratio=mlp_ratio[i], qkv_bias=qkv_bias, qk_scale=qk_scale,
                             pool=downsampling[i], norm_layer=norm_layer, attn_drop_rate=attn_drop_rate, drop_rate=drop_rate,
                             cls_attn=True, peg=peg)
            )
            dpr_ptr += curr_depth
        self.norm = norm_layer(embed_dim[-1])
        # Classifier head
        self.head = nn.Linear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
        if not computed_cls_token:
            tn = nn.initializer.TruncatedNormal(std=.02)
            tn(self.cls_token)

        self.det_norm = det_norm
        if self.det_norm:
            # add a norm layer for the outputs at each stage, for detection
            for i in range(4):
                layer = LayerNorm2D(embed_dim[1 + i])
                layer_name = f'norm{i}'
                self.add_sublayer(layer_name, layer)  # Paddle's counterpart of add_module

        self.apply(self._init_weights)

    def _init_weights(self, m):
        tn = nn.initializer.TruncatedNormal(std=.02)
        ones = nn.initializer.Constant(1.0)
        zeros = nn.initializer.Constant(0.0)
        if isinstance(m, nn.Linear):
            tn(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                zeros(m.bias)
        elif isinstance(m, nn.LayerNorm):
            zeros(m.bias)
            ones(m.weight)

    def forward_features(self, x, detection=False):
        o_x = x
        x = self.patch_embed(x)
        # B x branches x classes
        cls_tokens = self.cls_token(o_x, extra_padding=True)
        x = self.pos_drop(x)  # N C H W
        tmp_out = []
        for idx, layer in enumerate(self.layers):
            cls_tokens, x = layer((cls_tokens, x))
            if self.det_norm:
                norm_layer = getattr(self, f'norm{idx}')
                x = norm_layer(x)
            tmp_out.append(x)
        if detection:
            return tmp_out

        N, C, H, W = cls_tokens.shape
        cls_tokens = cls_tokens.reshape((N, C, -1)).transpose([0, 2, 1])
        cls_tokens = self.norm(cls_tokens)
        out = paddle.mean(cls_tokens, axis=1)
        return out

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x
    In [18]
_model_cfg = {
    'tiny': {
        'img_size': 224, 'patch_conv_type': '3conv', 'patch_size': 4, 'mlp_ratio': 4.,
        'embed_dim': [64, 64, 128, 256, 512], 'num_heads': [2, 4, 8, 16], 'depth': [2, 2, 8, 2],
        'kernel_sizes': [7, 7, 7, 7],  # regional tokens per stage: 8x8, 4x4, 2x2, 1x1
        'downsampling': ['c', 'sc', 'sc', 'sc'],
    },
    'small': {
        'img_size': 224, 'patch_conv_type': '3conv', 'patch_size': 4, 'mlp_ratio': 4.,
        'embed_dim': [96, 96, 192, 384, 768], 'num_heads': [3, 6, 12, 24], 'depth': [2, 2, 8, 2],
        'kernel_sizes': [7, 7, 7, 7],  # regional tokens per stage: 8x8, 4x4, 2x2, 1x1
        'downsampling': ['c', 'sc', 'sc', 'sc'],
    },
    'medium': {
        'img_size': 224, 'patch_conv_type': '1conv', 'patch_size': 4, 'mlp_ratio': 4.,
        'embed_dim': [96] + [96 * (2 ** i) for i in range(4)], 'num_heads': [3, 6, 12, 24], 'depth': [2, 2, 14, 2],
        'kernel_sizes': [7, 7, 7, 7],  # regional tokens per stage: 8x8, 4x4, 2x2, 1x1
        'downsampling': ['c', 'sc', 'sc', 'sc'],
    },
    'base': {
        'img_size': 224, 'patch_conv_type': '1conv', 'patch_size': 4, 'mlp_ratio': 4.,
        'embed_dim': [128, 128, 256, 512, 1024], 'num_heads': [4, 8, 16, 32], 'depth': [2, 2, 14, 2],
        'kernel_sizes': [7, 7, 7, 7],  # regional tokens per stage: 8x8, 4x4, 2x2, 1x1
        'downsampling': ['c', 'sc', 'sc', 'sc'],
    },
    'small_w14': {
        'img_size': 224, 'patch_conv_type': '3conv', 'patch_size': 4, 'mlp_ratio': 4.,
        'embed_dim': [96, 96, 192, 384, 768], 'num_heads': [3, 6, 12, 24], 'depth': [2, 2, 8, 2],
        'kernel_sizes': [14, 14, 14, 14],  # regional tokens per stage: 4x4, 2x2, 1x1, 1x1
        'downsampling': ['c', 'sc', 'sc', 'sc'],
    },
    'small_w14_peg': {
        'img_size': 224, 'patch_conv_type': '3conv', 'patch_size': 4, 'mlp_ratio': 4.,
        'embed_dim': [96, 96, 192, 384, 768], 'num_heads': [3, 6, 12, 24], 'depth': [2, 2, 8, 2],
        'kernel_sizes': [14, 14, 14, 14],  # regional tokens per stage: 4x4, 2x2, 1x1, 1x1
        'downsampling': ['c', 'sc', 'sc', 'sc'], 'peg': True,
    },
    'base_w14': {
        'img_size': 224, 'patch_conv_type': '1conv', 'patch_size': 4, 'mlp_ratio': 4.,
        'embed_dim': [128, 128, 256, 512, 1024], 'num_heads': [4, 8, 16, 32], 'depth': [2, 2, 14, 2],
        'kernel_sizes': [14, 14, 14, 14],  # regional tokens per stage: 4x4, 2x2, 1x1, 1x1
        'downsampling': ['c', 'sc', 'sc', 'sc'],
    },
    'base_w14_peg': {
        'img_size': 224, 'patch_conv_type': '1conv', 'patch_size': 4, 'mlp_ratio': 4.,
        'embed_dim': [128, 128, 256, 512, 1024], 'num_heads': [4, 8, 16, 32], 'depth': [2, 2, 14, 2],
        'kernel_sizes': [14, 14, 14, 14],  # regional tokens per stage: 4x4, 2x2, 1x1, 1x1
        'downsampling': ['c', 'sc', 'sc', 'sc'], 'peg': True,
    },
}
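The per-stage token grids implied by a configuration can be worked out with the same output-size arithmetic. This also shows when `convert_to_flatten_layout` must pad. The sketch is hypothetical (`stage_shapes` is not part of the notebook) and relies on three assumptions, all read off the code above: the local stem maps `img_size` to `img_size // patch_size`; regional tokens come from a `'linear'` `PatchEmbed` with patch size `patch_size * kernel_sizes[0]`, rounded up by `extra_padding`; and an `'s'` in a downsampling mode means a 3×3, stride-2, padding-1 convolution.

```python
def conv_out(n, kernel, stride, padding):
    """Output length of a convolution along one spatial axis."""
    return (n + 2 * padding - kernel) // stride + 1


def stage_shapes(cfg):
    """Per-stage (local grid, regional grid, needs_padding) for a square input."""
    local = cfg['img_size'] // cfg['patch_size']
    region_patch = cfg['patch_size'] * cfg['kernel_sizes'][0]
    regional = (cfg['img_size'] + region_patch - 1) // region_patch  # extra_padding rounds up
    shapes = []
    for ws, mode in zip(cfg['kernel_sizes'], cfg['downsampling']):
        if 's' in mode:
            local = conv_out(local, 3, 2, 1)
            regional = conv_out(regional, 3, 2, 1)
        # Same test as convert_to_flatten_layout: pad when windows do not tile the grid.
        needs_padding = local % (regional * ws) != 0
        shapes.append((local, regional, needs_padding))
    return shapes


w14 = {'img_size': 224, 'patch_size': 4, 'kernel_sizes': [14, 14, 14, 14],
       'downsampling': ['c', 'sc', 'sc', 'sc']}
print(stage_shapes(w14))  # [(56, 4, False), (28, 2, False), (14, 1, False), (7, 1, True)]
```

In the notebook, `stage_shapes(_model_cfg['small_w14'])` should give the same result. This suggests that the last stage of the `*_w14` variants goes through the padding and mask branch, whereas the window-7 variants tile exactly at 224×224.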
    In [19]
num_classes = 10


def regionvit_tiny_224():
    return RegionViT(**_model_cfg['tiny'], num_classes=num_classes)


def regionvit_small_224():
    return RegionViT(**_model_cfg['small'], num_classes=num_classes)


def regionvit_small_w14_224():
    return RegionViT(**_model_cfg['small_w14'], num_classes=num_classes)


def regionvit_small_w14_peg_224():
    return RegionViT(**_model_cfg['small_w14_peg'], num_classes=num_classes)


def regionvit_medium_224():
    return RegionViT(**_model_cfg['medium'], num_classes=num_classes)


def regionvit_base_224():
    return RegionViT(**_model_cfg['base'], num_classes=num_classes)


def regionvit_base_w14_224():
    return RegionViT(**_model_cfg['base_w14'], num_classes=num_classes)


def regionvit_base_w14_peg_224():
    return RegionViT(**_model_cfg['base_w14_peg'], num_classes=num_classes)
   

2.3.4 Model parameters

In [ ]
model = regionvit_tiny_224()
paddle.summary(model, (1, 3, 224, 224))
   


2.4 Training

In [22]
learning_rate = 0.0001
n_epochs = 100
paddle.seed(42)
np.random.seed(42)
    In [ ]
work_path = 'work/model'

# RegionViT-Tiny
model = regionvit_tiny_224()

criterion = LabelSmoothingCrossEntropy()

scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=learning_rate, T_max=50000 // batch_size * n_epochs, verbose=False)
optimizer = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=scheduler, weight_decay=1e-5)

gate = 0.0
threshold = 0.0
best_acc = 0.0
val_acc = 0.0
loss_record = {'train': {'loss': [], 'iter': []}, 'val': {'loss': [], 'iter': []}}   # for recording loss
acc_record = {'train': {'acc': [], 'iter': []}, 'val': {'acc': [], 'iter': []}}      # for recording accuracy

loss_iter = 0
acc_iter = 0

for epoch in range(n_epochs):
    # ---------- Training ----------
    model.train()
    train_num = 0.0
    train_loss = 0.0

    val_num = 0.0
    val_loss = 0.0
    accuracy_manager = paddle.metric.Accuracy()
    val_accuracy_manager = paddle.metric.Accuracy()
    print("#===epoch: {}, lr={:.10f}===#".format(epoch, optimizer.get_lr()))
    for batch_id, data in enumerate(train_loader):
        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)

        logits = model(x_data)

        loss = criterion(logits, y_data)

        acc = accuracy_manager.compute(logits, labels)
        accuracy_manager.update(acc)
        if batch_id % 10 == 0:
            loss_record['train']['loss'].append(loss.numpy())
            loss_record['train']['iter'].append(loss_iter)
            loss_iter += 1

        loss.backward()

        optimizer.step()
        scheduler.step()
        optimizer.clear_grad()

        train_loss += loss
        train_num += len(y_data)

    total_train_loss = (train_loss / train_num) * batch_size
    train_acc = accuracy_manager.accumulate()
    acc_record['train']['acc'].append(train_acc)
    acc_record['train']['iter'].append(acc_iter)
    acc_iter += 1
    # Print the information.
    print("#===epoch: {}, train loss is: {}, train acc is: {:2.2f}%===#".format(epoch, total_train_loss.numpy(), train_acc*100))

    # ---------- Validation ----------
    model.eval()
    for batch_id, data in enumerate(val_loader):

        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)        with paddle.no_grad():
          logits = model(x_data)

        loss = criterion(logits, y_data)

        acc = val_accuracy_manager.compute(logits, labels)
        val_accuracy_manager.update(acc)

        val_loss += loss
        val_num += len(y_data)

    total_val_loss = (val_loss / val_num) * batch_size
    loss_record['val']['loss'].append(total_val_loss.numpy())
    loss_record['val']['iter'].append(loss_iter)
    val_acc = val_accuracy_manager.accumulate()
    acc_record['val']['acc'].append(val_acc)
    acc_record['val']['iter'].append(acc_iter)
    print("#===epoch: {}, val loss is: {}, val acc is: {:2.2f}%===#".format(epoch, total_val_loss.numpy(), val_acc*100))

    # =================== save ====================
    if val_acc > best_acc:
        best_acc = val_acc
        paddle.save(model.state_dict(), os.path.join(work_path, 'best_model.pdparams'))
        paddle.save(optimizer.state_dict(), os.path.join(work_path, 'best_optimizer.pdopt'))

print(best_acc)
paddle.save(model.state_dict(), os.path.join(work_path, 'final_model.pdparams'))
paddle.save(optimizer.state_dict(), os.path.join(work_path, 'final_optimizer.pdopt'))
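The training cell calls `scheduler.step()` after every batch, so `T_max = 50000 // batch_size * n_epochs` is meant to span the whole run (50,000 is the size of the CIFAR-10 training set). Paddle documents `CosineAnnealingDecay` with the closed form η_t = η_min + ½(η_max − η_min)(1 + cos(π·t / T_max)), and `eta_min` is left at its default of 0 here. The sketch below evaluates that formula in plain Python. The `demo_*` values are placeholders, since the notebook sets its own hyperparameters outside this section. One side effect is worth checking: if `train_loader` keeps the last partial batch, each epoch has `ceil(50000 / batch_size)` steps, so the run slightly overshoots `T_max` and the learning rate begins to rise again, by a tiny amount, for the final few steps.

```python
import math


def cosine_annealing_lr(step: int, base_lr: float, t_max: int, eta_min: float = 0.0) -> float:
    """Closed-form cosine annealing: base_lr at step 0, eta_min at step t_max."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + math.cos(math.pi * step / t_max))


# Placeholder hyperparameters (named demo_* so they do not overwrite notebook variables)
demo_lr = 1e-3
demo_batch_size = 128
demo_epochs = 10
t_max = 50000 // demo_batch_size * demo_epochs  # same expression as the training cell -> 3900

# Number of optimizer steps per epoch if the last partial batch is kept
steps_per_epoch = math.ceil(50000 / demo_batch_size)  # 391
for step in (0, t_max // 2, t_max, steps_per_epoch * demo_epochs):
    print(step, cosine_annealing_lr(step, demo_lr, t_max))
```

The last printed value lies past `T_max`, which shows the small upturn when the step count exceeds the schedule length.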
   


2.5 Results analysis

In [24]
def plot_learning_curve(record, title='loss', ylabel='CE Loss'):
    ''' Plot learning curve of your CNN '''
    maxtrain = max(map(float, record['train'][title]))
    maxval = max(map(float, record['val'][title]))
    ymax = max(maxtrain, maxval) * 1.1
    mintrain = min(map(float, record['train'][title]))
    minval = min(map(float, record['val'][title]))
    ymin = min(mintrain, minval) * 0.9

    total_steps = len(record['train'][title])
    x_1 = list(map(int, record['train']['iter']))
    x_2 = list(map(int, record['val']['iter']))
    figure(figsize=(10, 6))
    plt.plot(x_1, record['train'][title], c='tab:red', label='train')
    plt.plot(x_2, record['val'][title], c='tab:cyan', label='val')
    plt.ylim(ymin, ymax)
    plt.xlabel('Training steps')
    plt.ylabel(ylabel)
    plt.title('Learning curve of {}'.format(title))
    plt.legend()
    plt.show()
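The values plotted by `plot_learning_curve` come from the training cell. There, `train_loss` and `val_loss` sum each batch's mean loss, and the epoch value is `(loss_sum / num_samples) * batch_size`. That equals the average batch loss only when every batch holds exactly `batch_size` samples; a smaller final batch skews it slightly. A plain-Python comparison with a sample-weighted mean, using hypothetical batch sizes and losses:

```python
def notebook_epoch_loss(batch_losses, batch_sizes, batch_size):
    """Mirror of the training cell: (sum of batch means / num samples) * batch_size."""
    return sum(batch_losses) / sum(batch_sizes) * batch_size


def sample_weighted_loss(batch_losses, batch_sizes):
    """Mean loss per sample: each batch mean weighted by its sample count."""
    return sum(loss * n for loss, n in zip(batch_losses, batch_sizes)) / sum(batch_sizes)


# Hypothetical epoch: 78 full batches of 128 plus a final batch of 16 (10,000 samples)
sizes = [128] * 78 + [16]
losses = [1.0] * 78 + [2.0]
print(notebook_epoch_loss(losses, sizes, 128))  # 1.024
print(sample_weighted_loss(losses, sizes))      # 1.0016
```

With full batches the two agree exactly. Note also that the train series records a raw batch loss every 10 batches (`loss_iter` counts those logging events, which is what the "Training steps" axis shows), whereas the validation series records one epoch-level value, so the two curves are not directly comparable point by point.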
    In [25]
plot_learning_curve(loss_record, title='loss', ylabel='CE Loss')
       
<Figure size 1000x600 with 1 Axes>
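The y-axis is labeled "CE Loss", but the curve shows `LabelSmoothingCrossEntropy`, whose definition lies outside this section. Assuming the common formulation (1 − ε)·NLL + ε·(mean over classes of −log p), which is cross-entropy against a smoothed target, the loss cannot reach zero: its minimum is the entropy of the smoothed target, about 0.50 for ε = 0.1 and 10 classes. A NumPy sketch under those assumptions (the function name and ε = 0.1 are hypothetical):

```python
import numpy as np


def label_smoothing_ce(logits: np.ndarray, targets: np.ndarray, smoothing: float = 0.1) -> float:
    """Mean of (1 - eps) * NLL + eps * (mean over classes of -log p)."""
    # Numerically stable log-softmax over the class axis
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Negative log-likelihood of the true class
    nll = -log_probs[np.arange(len(targets)), targets]
    # Uniform-target term: average -log p over all classes
    uniform = -log_probs.mean(axis=1)
    return float(((1.0 - smoothing) * nll + smoothing * uniform).mean())


# The best achievable prediction is the smoothed target itself:
# 1 - eps + eps/K on the true class and eps/K elsewhere (K = 10, eps = 0.1)
best = np.full((1, 10), 0.01)
best[0, 3] = 0.91
print(label_smoothing_ce(np.log(best), np.array([3])))  # about 0.50

# An over-confident correct prediction scores worse than the smoothed target
confident = np.zeros((1, 10))
confident[0, 3] = 10.0
print(label_smoothing_ce(confident, np.array([3])))
```

So a training curve that levels off well above zero is expected with label smoothing and does not by itself indicate underfitting.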
                In [26]
plot_learning_curve(acc_record, title='acc', ylabel='Accuracy')
       
<Figure size 1000x600 with 1 Axes>
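Each point on the accuracy curve is `paddle.metric.Accuracy().accumulate()` for one epoch. `Accuracy` keeps running totals of correct predictions and of samples across `update` calls, so the epoch accuracy is weighted by samples, not averaged over batches. The NumPy sketch below reproduces that bookkeeping for top-1 accuracy; `RunningTop1` is a hypothetical name, not Paddle's class.

```python
import numpy as np


class RunningTop1:
    """Sample-weighted running top-1 accuracy (hypothetical helper, not Paddle's API)."""

    def __init__(self) -> None:
        self.correct = 0
        self.count = 0

    def update(self, logits: np.ndarray, labels: np.ndarray) -> float:
        hits = logits.argmax(axis=1) == labels
        self.correct += int(hits.sum())
        self.count += int(hits.size)
        return float(hits.mean())  # accuracy of this batch alone

    def accumulate(self) -> float:
        return self.correct / self.count if self.count else 0.0


meter = RunningTop1()
# Batch 1: two samples, both correct; batch 2: one sample, wrong
meter.update(np.array([[0.9, 0.1], [0.2, 0.8]]), np.array([0, 1]))
meter.update(np.array([[0.7, 0.3]]), np.array([1]))
print(meter.accumulate())  # 2 correct out of 3 samples, not the batch mean 0.5
```

Also note that `acc_iter` is incremented between the train and validation appends in the training cell, so each validation point sits one step to the right of the matching train point.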
                In [27]
import time

work_path = 'work/model'
model = regionvit_tiny_224()
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
aa = time.time()
for batch_id, data in enumerate(val_loader):

    x_data, y_data = data
    labels = paddle.unsqueeze(y_data, axis=1)
    with paddle.no_grad():
        logits = model(x_data)
bb = time.time()
print("Throughput:{}".format(int(len(val_dataset)//(bb - aa))))

Throughput:678
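The throughput figure above is the validation-set size divided by the wall-clock time of one pass over `val_loader`. It therefore includes data loading and first-iteration overhead, and on a GPU the clock may be read before queued kernels have finished. A framework-agnostic sketch of a more controlled measurement: a few untimed warm-up batches, a synchronization hook before each clock read, and `time.perf_counter`. The `run_batch` and `sync` callables are placeholders; with Paddle on a GPU, `sync` could call `paddle.device.cuda.synchronize()`.

```python
import time
from typing import Callable, List, Sequence


def images_per_second(
    run_batch: Callable[[Sequence], None],
    batches: List[Sequence],
    sync: Callable[[], None] = lambda: None,
    warmup: int = 2,
) -> float:
    """Average images/s over all batches, after `warmup` untimed batches."""
    # Untimed warm-up (e.g. lazy initialization, allocator growth)
    for batch in batches[:warmup]:
        run_batch(batch)
    sync()  # make sure warm-up work has finished before starting the clock
    start = time.perf_counter()
    num_images = 0
    for batch in batches:
        run_batch(batch)
        num_images += len(batch)
    sync()  # wait for any asynchronous work before stopping the clock
    return num_images / (time.perf_counter() - start)


# Dummy workload: 5 batches of 100 "images", each taking at least 10 ms
fake_batches = [list(range(100)) for _ in range(5)]
rate = images_per_second(lambda b: time.sleep(0.01), fake_batches)
print(f"{rate:.0f} images/s")
```

In the notebook, `run_batch` would wrap the `model(x_data)` call under `paddle.no_grad()`; preloading the batches into a list keeps data loading out of the timed region.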
        In [28]
def get_cifar10_labels(labels):
    """Map CIFAR-10 class indices to their text labels."""
    text_labels = [
        'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog',
        'horse', 'ship', 'truck']
    return [text_labels[int(i)] for i in labels]
    In [29]
def show_images(imgs, num_rows, num_cols, pred=None, gt=None, scale=1.5):
    """Plot a list of images."""
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if paddle.is_tensor(img):
            ax.imshow(img.numpy())
        else:
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if pred or gt:
            ax.set_title("pt: " + pred[i] + "\ngt: " + gt[i])
    return axes
    In [30]
work_path = 'work/model'
X, y = next(iter(DataLoader(val_dataset, batch_size=18)))
model = regionvit_tiny_224()
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
logits = model(X)
y_pred = paddle.argmax(logits, -1)
X = paddle.transpose(X, [0, 2, 3, 1])
axes = show_images(X.reshape((18, 224, 224, 3)), 1, 18, pred=get_cifar10_labels(y_pred), gt=get_cifar10_labels(y))
plt.show()
       
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
       
<Figure size 2700x150 with 18 Axes>
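The `Clipping input data to the valid range for imshow` warning appears because `X` still holds normalized pixel values, which fall outside the [0, 1] range `imshow` expects for float RGB data, so the displayed colors are distorted. Undoing the normalization before plotting avoids this. The sketch assumes the images were scaled to [0, 1] and then normalized per channel as `(x - mean) / std`; the notebook's actual transform is defined outside this section, so the `MEAN`/`STD` values below are placeholders to replace with the ones used there.

```python
import numpy as np


def denormalize(images_nhwc: np.ndarray, mean, std) -> np.ndarray:
    """Invert per-channel (x - mean) / std and clip to [0, 1] for imshow."""
    mean = np.asarray(mean, dtype=np.float32).reshape(1, 1, 1, -1)
    std = np.asarray(std, dtype=np.float32).reshape(1, 1, 1, -1)
    return np.clip(images_nhwc * std + mean, 0.0, 1.0)


# Placeholder statistics; use the values from the notebook's Normalize transform
MEAN = [0.5, 0.5, 0.5]
STD = [0.5, 0.5, 0.5]

# Round trip on random data: normalize, then denormalize
rng = np.random.default_rng(0)
original = rng.random((2, 4, 4, 3), dtype=np.float32)
normalized = (original - np.array(MEAN, dtype=np.float32)) / np.array(STD, dtype=np.float32)
restored = denormalize(normalized, MEAN, STD)
print(np.abs(restored - original).max())
```

In the visualization cell, passing `denormalize(X.numpy(), MEAN, STD)` to `show_images` instead of `X` would display the images with their original colors.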
               

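Summary

        This paper proposes RegionViT, a coarse-to-fine Transformer that works from regions to local tokens. It keeps a global receptive field while preserving locality, and it is simple and efficient to implement.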
