700字范文,内容丰富有趣,生活中的好帮手!
700字范文 > pytorch实现常用的一些即插即用模块(长期更新)

pytorch实现常用的一些即插即用模块(长期更新)

时间:2022-01-11 22:04:58

相关推荐

pytorch实现常用的一些即插即用模块(长期更新)

1.可分离卷积

# coding: utf-8
import torch.nn as nn


class DWConv(nn.Module):
    """Depthwise-separable convolution.

    A per-channel (``groups=in_plane``) spatial convolution followed by a
    1x1 pointwise convolution that mixes channels.

    Args:
        in_plane: Number of input channels.
        out_plane: Number of output channels of the pointwise convolution.
        kernel_size: Spatial kernel size of the depthwise convolution.
            Padding is ``kernel_size // 2``, so an odd kernel with
            ``stride=1`` keeps the spatial size unchanged.
        stride: Stride of the depthwise convolution.
    """

    def __init__(self, in_plane: int, out_plane: int,
                 kernel_size: int = 3, stride: int = 1):
        super().__init__()
        # Depthwise step: one filter per input channel.
        self.depth_conv = nn.Conv2d(
            in_channels=in_plane,
            out_channels=in_plane,
            kernel_size=kernel_size,
            stride=stride,
            padding=kernel_size // 2,
            groups=in_plane,
        )
        # Pointwise step: 1x1 convolution across channels.
        self.point_conv = nn.Conv2d(
            in_channels=in_plane,
            out_channels=out_plane,
            kernel_size=1,
            stride=1,
            padding=0,
            groups=1,
        )

    def forward(self, x):
        """Apply the depthwise then the pointwise convolution to ``x``."""
        return self.point_conv(self.depth_conv(x))


def debug_dw():
    """Push a random batch through ``DWConv`` and print the output shape."""
    import torch  # local import: the module itself only needs torch.nn

    model = DWConv(3, 32)
    x = torch.rand((32, 3, 320, 320))
    out = model(x)
    print(out.shape)


# Backward-compatible alias for the original (misspelled) helper name.
deubg_dw = debug_dw


if __name__ == '__main__':
    debug_dw()

2.DBnet论文中的DBhead

# coding: utf-8
import torch
from torch import nn


class DBHead(nn.Module):
    """Prediction head of DBNet (shrink map, threshold map, binary map).

    Both branches upsample the input feature map by 4x and end in a
    sigmoid, producing one channel each.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Accepted for interface compatibility; not used by
            this implementation.
        k: Steepness of the step function used to build the approximate
            binary map during training.
    """

    def __init__(self, in_channels, out_channels, k=50):
        super().__init__()
        self.k = k
        quarter = in_channels // 4
        self.binarize = nn.Sequential(
            nn.Conv2d(in_channels, quarter, 3, padding=1),
            nn.BatchNorm2d(quarter),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(quarter, quarter, 2, 2),
            nn.BatchNorm2d(quarter),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(quarter, 1, 2, 2),
            nn.Sigmoid(),
        )
        self.binarize.apply(self.weights_init)

        self.thresh = self._init_thresh(in_channels)
        self.thresh.apply(self.weights_init)

    def forward(self, x):
        """Return the concatenated prediction maps along the channel axis.

        In training mode the result has 3 channels (shrink, threshold,
        approximate binary map); in eval mode it has 2 channels (shrink,
        threshold).
        """
        shrink_maps = self.binarize(x)
        threshold_maps = self.thresh(x)
        # ``self.training`` is inherited from nn.Module: True after
        # ``train()``, False after ``eval()``.
        if self.training:
            binary_maps = self.step_function(shrink_maps, threshold_maps)
            return torch.cat((shrink_maps, threshold_maps, binary_maps), dim=1)
        return torch.cat((shrink_maps, threshold_maps), dim=1)

    def weights_init(self, m):
        """Initialise conv weights (Kaiming normal) and batch-norm affine terms."""
        classname = m.__class__.__name__
        if 'Conv' in classname:
            nn.init.kaiming_normal_(m.weight.data)
        elif 'BatchNorm' in classname:
            m.weight.data.fill_(1.)
            m.bias.data.fill_(1e-4)

    def _init_thresh(self, inner_channels, serial=False, smooth=False, bias=False):
        """Build and return the threshold branch.

        Args:
            inner_channels: Channel count of the input feature map.
            serial: If True the first conv expects one extra input channel.
                NOTE(review): ``forward`` never feeds that extra channel, so
                ``serial=True`` looks unsupported here - confirm before use.
            smooth: Use nearest-neighbour upsampling followed by convs
                instead of transposed convolutions.
            bias: Bias flag for the convs created in this branch (the
                transposed convs of the non-smooth path keep their default).
        """
        in_channels = inner_channels + 1 if serial else inner_channels
        quarter = inner_channels // 4
        return nn.Sequential(
            nn.Conv2d(in_channels, quarter, 3, padding=1, bias=bias),
            nn.BatchNorm2d(quarter),
            nn.ReLU(inplace=True),
            self._init_upsample(quarter, quarter, smooth=smooth, bias=bias),
            nn.BatchNorm2d(quarter),
            nn.ReLU(inplace=True),
            self._init_upsample(quarter, 1, smooth=smooth, bias=bias),
            nn.Sigmoid(),
        )

    def _init_upsample(self, in_channels, out_channels, smooth=False, bias=False):
        """Return a module that doubles the spatial resolution."""
        if not smooth:
            return nn.ConvTranspose2d(in_channels, out_channels, 2, 2)

        # When collapsing to a single channel, keep the width through the
        # 3x3 conv and reduce with a separate 1x1 conv afterwards.
        inter_out_channels = in_channels if out_channels == 1 else out_channels
        layers = [
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(in_channels, inter_out_channels, 3, 1, 1, bias=bias),
        ]
        if out_channels == 1:
            # padding=0: a padded 1x1 conv would enlarge the map by 2 pixels
            # and no longer match the shrink map in ``forward``.
            layers.append(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1,
                          padding=0, bias=True))
        # nn.Sequential takes modules as positional arguments, not a list.
        return nn.Sequential(*layers)

    def step_function(self, x, y):
        """Differentiable step: ``1 / (1 + exp(-k * (x - y)))``.

        Written with ``torch.sigmoid``, which is the same function but does
        not overflow for large ``k * (y - x)``.
        """
        return torch.sigmoid(self.k * (x - y))


def debug_main():
    """Run the head in train and eval mode and print both output shapes."""
    x = torch.rand((8, 256, 160, 160))
    head_model = DBHead(in_channels=256, out_channels=2)

    head_model.train()
    y = head_model(x)
    print('==y.shape:', y.shape)

    head_model.eval()
    y = head_model(x)
    print('==y.shape:', y.shape)


if __name__ == '__main__':
    debug_main()

3.SENet中的attention

目的是对不同通道进行加权:先squeeze,即通过global average pooling将h*w*c的特征压缩成1*1*c的特征,再经过两层线性层,最后通过sigmoid输出各通道的权重,并加权到不同通道上。

import torch
import torch.nn as nn
import torch.nn.functional as F


class SELayer(nn.Module):
    """Squeeze-and-Excitation channel attention.

    Global-average-pools each channel to a scalar, passes the resulting
    vector through a two-layer bottleneck ending in a sigmoid, and rescales
    the input channels with the resulting weights.

    Args:
        channel: Number of channels of the input feature map.
        reduction: Bottleneck reduction ratio. The hidden width is
            ``channel // reduction`` but never less than 1, so small
            channel counts do not produce an empty bottleneck.
    """

    def __init__(self, channel: int, reduction: int = 16):
        super().__init__()
        hidden = max(1, channel // reduction)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # squeeze the spatial dims
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` with every channel scaled by its attention weight."""
        batch, channels, _, _ = x.size()
        squeezed = self.avg_pool(x).view(batch, channels)
        weights = self.fc(squeezed).view(batch, channels, 1, 1)
        return x * weights


def debug_attention():
    """Run a random (B, C, H, W) batch through ``SELayer`` and print the shape."""
    attention_module = SELayer(channel=128, reduction=16)
    x = torch.rand((2, 128, 100, 100))
    out = attention_module(x)
    print('==out.shape:', out.shape)


if __name__ == '__main__':
    debug_attention()

4.cv中的self-attention

(1).feature map通过1*1卷积获得q,k,v三个向量,q的转置与k相乘得到attention矩阵,经softmax归一化到0到1,再作用于v,得到每个像素的加权结果。

(2).softmax

(3).加权求和

import torch
import torch.nn as nn
import torch.nn.functional as F


class Self_Attn(nn.Module):
    """Self-attention layer over the spatial positions of a feature map.

    Args:
        in_dim: Number of channels of the input feature map. Query and key
            projections use ``in_dim // 8`` channels (at least 1); the value
            projection keeps ``in_dim`` channels.
    """

    def __init__(self, in_dim: int):
        super().__init__()
        self.chanel_in = in_dim  # attribute name kept as in the original
        # Clamp so that in_dim < 8 does not yield zero-channel projections.
        reduced = max(1, in_dim // 8)
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=reduced, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=reduced, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        # Residual scale, initialised to zero so the layer starts as identity.
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x: torch.Tensor):
        """Apply self-attention.

        Args:
            x: Input feature map of shape (B, C, H, W).

        Returns:
            A tuple ``(out, attention)`` where ``out`` is
            ``gamma * attended + x`` with the shape of ``x`` and
            ``attention`` has shape (B, N, N) with ``N = H * W``; each row
            is softmax-normalised.
        """
        batch, channels, height, width = x.size()
        n = height * width

        proj_query = self.query_conv(x).view(batch, -1, n).permute(0, 2, 1)  # B, N, C'
        proj_key = self.key_conv(x).view(batch, -1, n)  # B, C', N
        energy = torch.bmm(proj_query, proj_key)  # B, N, N
        attention = self.softmax(energy)  # normalised over the last axis

        proj_value = self.value_conv(x).view(batch, -1, n)  # B, C, N
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))  # B, C, N
        out = out.view(batch, channels, height, width)

        out = self.gamma * out + x
        return out, attention


def debug_attention():
    """Run a random (B, C, H, W) batch through ``Self_Attn`` and print shapes.

    The map is kept small on purpose: the attention tensor holds
    ``B * (H * W) ** 2`` values, which grows very quickly with resolution.
    """
    attention_module = Self_Attn(in_dim=128)
    x = torch.rand((2, 128, 32, 32))
    out, attention = attention_module(x)
    print('==out.shape:', out.shape)
    print('==attention.shape:', attention.shape)


if __name__ == '__main__':
    debug_attention()

5.spp多窗口pooling

import torch
import torch.nn as nn
import torch.nn.functional as F


class SPP(nn.Module):
    """Spatial Pyramid Pooling with stride-1 max-pool windows.

    The input is max-pooled with every window size in ``kernel_sizes``
    (stride 1, padding ``k // 2`` so the spatial size is preserved) and the
    results are concatenated with the input along the channel axis.

    Args:
        kernel_sizes: Odd, positive pooling window sizes. The default
            ``(5, 9, 13)`` reproduces the original fixed configuration.

    Raises:
        ValueError: If a kernel size is not a positive odd integer; an even
            window would change the spatial size and break concatenation.
    """

    def __init__(self, kernel_sizes=(5, 9, 13)):
        super().__init__()
        sizes = tuple(kernel_sizes)
        for k in sizes:
            if k <= 0 or k % 2 == 0:
                raise ValueError(
                    f"kernel sizes must be positive odd integers, got {k}")
        self.kernel_sizes = sizes

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` concatenated with its pooled versions (channel axis)."""
        pooled = [
            F.max_pool2d(x, kernel_size=k, stride=1, padding=k // 2)
            for k in self.kernel_sizes
        ]
        return torch.cat([x, *pooled], dim=1)


def debug_spp():
    """Run a random batch through ``SPP`` and print the output shape."""
    x = torch.rand((8, 3, 256, 256))
    spp = SPP()
    x = spp(x)
    print('==x.shape:', x.shape)


if __name__ == '__main__':
    debug_spp()

6.RetinaFPN

# coding: utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F


class RetinaFPN(nn.Module):
    """Feature pyramid producing five levels (P3-P7) from C3, C4 and C5.

    Args:
        C3_inplanes: Channel count of the C3 input.
        C4_inplanes: Channel count of the C4 input.
        C5_inplanes: Channel count of the C5 input.
        planes: Channel count of every output level.
        use_p5: If True, P6 is computed from P5; otherwise from C5.
    """

    def __init__(self, C3_inplanes, C4_inplanes, C5_inplanes, planes, use_p5=False):
        super().__init__()
        self.use_p5 = use_p5

        # 1x1 lateral convs (``*_1``) and 3x3 output convs (``*_2``).
        self.P3_1 = nn.Conv2d(C3_inplanes, planes, kernel_size=1, stride=1, padding=0)
        self.P3_2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1)
        self.P4_1 = nn.Conv2d(C4_inplanes, planes, kernel_size=1, stride=1, padding=0)
        self.P4_2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1)
        self.P5_1 = nn.Conv2d(C5_inplanes, planes, kernel_size=1, stride=1, padding=0)
        self.P5_2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1)

        # P6 downsamples either P5 (``planes`` channels) or C5.
        p6_inplanes = planes if self.use_p5 else C5_inplanes
        self.P6 = nn.Conv2d(p6_inplanes, planes, kernel_size=3, stride=2, padding=1)
        self.P7 = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(planes, planes, kernel_size=3, stride=2, padding=1),
        )

    def forward(self, inputs):
        """Build the pyramid.

        Args:
            inputs: Sequence ``[C3, C4, C5]`` of feature maps.

        Returns:
            The list ``[P3, P4, P5, P6, P7]``.
        """
        C3, C4, C5 = inputs

        # Top-down pathway: resize the coarser level to the finer level's
        # spatial size (nearest neighbour) and add the lateral projection.
        P5 = self.P5_1(C5)
        P4 = self.P4_1(C4)
        P4 = F.interpolate(P5, size=(P4.shape[2], P4.shape[3]), mode='nearest') + P4
        P3 = self.P3_1(C3)
        P3 = F.interpolate(P4, size=(P3.shape[2], P3.shape[3]), mode='nearest') + P3

        P5 = self.P5_2(P5)
        P4 = self.P4_2(P4)
        P3 = self.P3_2(P3)

        P6 = self.P6(P5 if self.use_p5 else C5)
        P7 = self.P7(P6)
        return [P3, P4, P5, P6, P7]


def debug_fpn():
    """Run random C3/C4/C5 maps through ``RetinaFPN`` and print level shapes."""
    fpn = RetinaFPN(512, 1024, 2048, 256)
    C3 = torch.randn(3, 512, 80, 80)
    C4 = torch.randn(3, 1024, 40, 40)
    C5 = torch.randn(3, 2048, 20, 20)
    P3, P4, P5, P6, P7 = fpn([C3, C4, C5])
    print("P3", P3.shape)
    print("P4", P4.shape)
    print("P5", P5.shape)
    print("P6", P6.shape)
    print("P7", P7.shape)


if __name__ == '__main__':
    debug_fpn()

7.Focus

import torch
import torch.nn as nn


def autopad(k, p=None):
    """Return the padding for kernel ``k``.

    If ``p`` is given it is returned unchanged; otherwise the padding is
    ``k // 2`` for an int kernel, or the per-dimension ``x // 2`` list for
    an iterable kernel.
    """
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
    return p


class Conv(nn.Module):
    """Convolution + batch norm + activation.

    Args:
        c1: Input channels.
        c2: Output channels.
        k: Kernel size.
        s: Stride.
        p: Padding; ``None`` selects ``autopad(k)``.
        g: Convolution groups.
        act: If truthy use ``nn.Hardswish``, otherwise ``nn.Identity``.
    """

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):
        super(Conv, self).__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.Hardswish() if act else nn.Identity()

    def forward(self, x):
        """Apply conv -> batch norm -> activation."""
        return self.act(self.bn(self.conv(x)))

    def fuseforward(self, x):
        """Apply conv -> activation, skipping the batch-norm layer.

        NOTE(review): presumably meant for use after the batch norm has
        been folded into the conv weights - confirm against the caller.
        """
        return self.act(self.conv(x))


class Focus(nn.Module):
    """Move spatial information into the channel dimension.

    The input is sampled at the four pixel phases of every 2x2 cell, the
    four slices are concatenated along the channel axis (4x channels, half
    the height and width) and passed through a ``Conv`` block. The four
    slices only have matching sizes when both spatial dimensions are even.

    Args:
        c1: Input channels.
        c2: Output channels.
        k, s, p, g, act: Forwarded to ``Conv``.
    """

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):
        super(Focus, self).__init__()
        self.conv = Conv(c1 * 4, c2, k, s, p, g, act)

    def forward(self, x):
        """Map ``x`` from (b, c, h, w) to (b, c2, h/2, w/2) for the default k and s."""
        patches = [
            x[..., ::2, ::2],
            x[..., 1::2, ::2],
            x[..., ::2, 1::2],
            x[..., 1::2, 1::2],
        ]
        return self.conv(torch.cat(patches, 1))


def debug_focus():
    """Run a random image batch through ``Focus`` and print in/out shapes."""
    model = Focus(c1=3, c2=24)
    img = torch.rand((8, 3, 124, 124))
    print('==img.shape', img.shape)
    out = model(img)
    print('===out.shape', out.shape)


# Guarded so that importing this module does not run the demo.
if __name__ == '__main__':
    debug_focus()

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。