yolov5中head怎么修改为decouple head(head,yolov5,开发技术)

时间:2024-05-05 00:35:07 作者 : 石家庄SEO 分类 : 开发技术
  • TAG :

yolox的decoupled head结构

yolov5中head怎么修改为decouple head

本来想将yolov5的head修改为decoupled head,与yolox的decouple head对齐,但是没注意,该成了如下结构:

yolov5中head怎么修改为decouple head

感谢少年肩上杨柳依依的指出,如还有问题欢迎指出

yolov5中head怎么修改为decouple head

1.修改models下的yolo.py文件中的Detect

classDetect(nn.Module):stride=None#stridescomputedduringbuildonnx_dynamic=False#ONNXexportparameterdef__init__(self,nc=80,anchors=(),ch=(),inplace=True):#detectionlayersuper().__init__()self.nc=nc#numberofclassesself.no=nc+5#numberofoutputsperanchorself.nl=len(anchors)#numberofdetectionlayersself.na=len(anchors[0])//2#numberofanchorsself.grid=[torch.zeros(1)]*self.nl#initgridself.anchor_grid=[torch.zeros(1)]*self.nl#initanchorgridself.register_buffer('anchors',torch.tensor(anchors).float().view(self.nl,-1,2))#shape(nl,na,2)#self.m=nn.ModuleList(nn.Conv2d(x,self.no*self.na,1)forxinch)#outputconvself.m_box=nn.ModuleList(nn.Conv2d(256,4*self.na,1)forxinch)#outputconvself.m_conf=nn.ModuleList(nn.Conv2d(256,1*self.na,1)forxinch)#outputconvself.m_labels=nn.ModuleList(nn.Conv2d(256,self.nc*self.na,1)forxinch)#outputconvself.base_conv=nn.ModuleList(BaseConv(in_channels=x,out_channels=256,ksize=1,stride=1)forxinch)self.cls_convs=nn.ModuleList(BaseConv(in_channels=256,out_channels=256,ksize=3,stride=1)forxinch)self.reg_convs=nn.ModuleList(BaseConv(in_channels=256,out_channels=256,ksize=3,stride=1)forxinch)#self.m=nn.ModuleList(nn.Conv2d(x,4*self.na,1)forxinch,nn.Conv2d(x,1*self.na,1)forxinch,nn.Conv2d(x,self.nc*self.na,1)forxinch)self.inplace=inplace#usein-placeops(e.g.sliceassignment)self.ch=chdefforward(self,x):z=[]#inferenceoutputforiinrange(self.nl):##x[i]=self.m[i](x[i])#convs#print("&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&",i)#print(x[i].shape)#print(self.base_conv[i])#print("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%")x_feature=self.base_conv[i](x[i])#x_feature=x[i]cls_feature=self.cls_convs[i](x_feature)reg_feature=self.reg_convs[i](x_feature)#reg_feature=x_featurem_box=self.m_box[i](reg_feature)m_conf=self.m_conf[i](reg_feature)m_labels=self.m_labels[i](cls_feature)x[i]=torch.cat((m_box,m_conf,m_labels),1)bs,_,ny,nx=x[i].shape#x(bs,255,20,20)tox(bs,3,20,20,85)x[i]=x[i].view(bs,self.na,self.no,ny,nx).permute(0,1,3,4,2).contiguous()ifnotself.training:#inferenceifself.onnx_dynamicorself.grid[i].shape[2:4]!=x[i].shape[2:4]:self.grid[i],self.anchor_grid[i]=self._make_grid(nx,ny,i)y=x[i].sigmoid()ifself.inplace:y[...,0:2]=(y[...,0:2]*2-0.5+self.grid[i])*self.stride[i]#xyy[...,2:4]=(y[...,2:4]*2)**2*self.anchor_grid[i]#whelse:#forYOLOv5onAWSInferentiahttps://github.com/ultralytics/yolov5/pull/2953xy=(y[...,0:2]*2-0.5+self.grid[i])*self.stride[i]#xywh=(y[...,2:4]*2)**2*self.anchor_grid[i]#why=torch.cat((xy,wh,y[...,4:]),-1)z.append(y.view(bs,-1,self.no))returnxifself.trainingelse(torch.cat(z,1),x)

2.在yolo.py中添加

defget_activation(name="silu",inplace=True):ifname=="silu":module=nn.SiLU(inplace=inplace)elifname=="relu":module=nn.ReLU(inplace=inplace)elifname=="lrelu":module=nn.LeakyReLU(0.1,inplace=inplace)else:raiseAttributeError("Unsupportedacttype:{}".format(name))returnmoduleclassBaseConv(nn.Module):"""AConv2d->Batchnorm->silu/leakyrelublock"""def__init__(self,in_channels,out_channels,ksize,stride,groups=1,bias=False,act="silu"):super().__init__()#samepaddingpad=(ksize-1)//2self.conv=nn.Conv2d(in_channels,out_channels,kernel_size=ksize,stride=stride,padding=pad,groups=groups,bias=bias,)self.bn=nn.BatchNorm2d(out_channels)self.act=get_activation(act,inplace=True)defforward(self,x):#print(self.bn(self.conv(x)).shape)returnself.act(self.bn(self.conv(x)))#returnself.bn(self.conv(x))deffuseforward(self,x):returnself.act(self.conv(x))

decouple head的特点:

由于训练模型时,应该是channels = 256的地方改成了channels = x(失误),所以在decoupled head的部分参数量比yolox要大一些,以下的结果是在channels= x的情况下得出

比yolov5s参数多,计算量大,在我自己的2.5万的数据量下map提升了3%多

1.模型给出的目标cls较高,需要将conf的阈值设置较大(0.5),不然准确率较低

parser.add_argument('--conf-thres',type=float,default=0.5,help='confidencethreshold')

2.对于少样本的检测效果较好,召回率的提升比准确率多

3.在conf设置为0.25时,召回率比yolov5s高,但是准确率低;在conf设置为0.5时,召回率与准确率比yolov5s高

4.比yolov5s参数多,计算量大,在2.5万的数据量下map提升了3%多

对于decouple head的改进

yolov5中head怎么修改为decouple head

改进:

1.将红色框中的conv去掉,缩小参数量和计算量;

2.channels =256 ,512 ,1024是考虑不增加参数,不进行featuremap的信息压缩

classDetect(nn.Module):stride=None#stridescomputedduringbuildonnx_dynamic=False#ONNXexportparameterdef__init__(self,nc=80,anchors=(),ch=(),inplace=True):#detectionlayersuper().__init__()self.nc=nc#numberofclassesself.no=nc+5#numberofoutputsperanchorself.nl=len(anchors)#numberofdetectionlayersself.na=len(anchors[0])//2#numberofanchorsself.grid=[torch.zeros(1)]*self.nl#initgridself.anchor_grid=[torch.zeros(1)]*self.nl#initanchorgridself.register_buffer('anchors',torch.tensor(anchors).float().view(self.nl,-1,2))#shape(nl,na,2)self.m=nn.ModuleList(nn.Conv2d(x,self.no*self.na,1)forxinch)#outputconvself.inplace=inplace#usein-placeops(e.g.sliceassignment)defforward(self,x):z=[]#inferenceoutputforiinrange(self.nl):x[i]=self.m[i](x[i])#convbs,_,ny,nx=x[i].shape#x(bs,255,20,20)tox(bs,3,20,20,85)x[i]=x[i].view(bs,self.na,self.no,ny,nx).permute(0,1,3,4,2).contiguous()ifnotself.training:#inferenceifself.onnx_dynamicorself.grid[i].shape[2:4]!=x[i].shape[2:4]:self.grid[i],self.anchor_grid[i]=self._make_grid(nx,ny,i)y=x[i].sigmoid()ifself.inplace:y[...,0:2]=(y[...,0:2]*2-0.5+self.grid[i])*self.stride[i]#xyy[...,2:4]=(y[...,2:4]*2)**2*self.anchor_grid[i]#whelse:#forYOLOv5onAWSInferentiahttps://github.com/ultralytics/yolov5/pull/2953xy=(y[...,0:2]*2-0.5+self.grid[i])*self.stride[i]#xywh=(y[...,2:4]*2)**2*self.anchor_grid[i]#why=torch.cat((xy,wh,y[...,4:]),-1)z.append(y.view(bs,-1,self.no))returnxifself.trainingelse(torch.cat(z,1),x)

特点

1.模型给出的目标cls较高,需要将conf的阈值设置较大(0.4),不然准确率较低

2.对于少样本的检测效果较好,准确率的提升比召回率多

3. 准确率的提升比召回率多,

该改进不如上面的模型提升多,但是参数量小,计算量小少9Gflop,占用显存少

decoupled head指标提升的原因:由于yolov5s原本的head不能完全的提取featuremap中的信息,decoupled head能够较为充分的提取featuremap的信息;

 </div> <div class="zixun-tj-product adv-bottom"></div> </div> </div> <div class="prve-next-news">
本文:yolov5中head怎么修改为decouple head的详细内容,希望对您有所帮助,信息来源于网络。
上一篇:Spring如何加载properties文件下一篇:

21 人围观 / 0 条评论 ↓快速评论↓

(必须)

(必须,保密)

阿狸1 阿狸2 阿狸3 阿狸4 阿狸5 阿狸6 阿狸7 阿狸8 阿狸9 阿狸10 阿狸11 阿狸12 阿狸13 阿狸14 阿狸15 阿狸16 阿狸17 阿狸18