Pytorch中torch.flatten()和torch.nn.Flatten()怎么用
导读:本文共1121字符,通常情况下阅读需要4分钟。同时您也可以点击右侧朗读,来听本文内容。按键盘←(左) →(右) 方向键可以翻页。
摘要: torch.flatten(x)等于torch.flatten(x,0)默认将张量拉成一维的向量,也就是说从第一维开始平坦化,torch.flatten(x,1)代表从第二维开始平坦化。importtorchx=torch.randn(2,4,2)print(x)z=torch.flatten(x)print(z)w=torch.flatten(x... ...
音频解说
目录
(为您整理了一些要点),点击可以直达。torch.flatten(x)等于torch.flatten(x,0)默认将张量拉成一维的向量,也就是说从第一维开始平坦化,torch.flatten(x,1)代表从第二维开始平坦化。
torch.flatten(x,0,1)代表在第一维和第二维之间平坦化
对于torch.nn.Flatten(),因为其被用在神经网络中,输入为一批数据,第一维为batch,通常要把一个数据拉成一维,而不是将一批数据拉为一维。所以torch.nn.Flatten()默认从第二维开始平坦化。
本文:
Pytorch中torch.flatten()和torch.nn.Flatten()怎么用的详细内容,希望对您有所帮助,信息来源于网络。