Python tensorflow与pytorch的浮点运算数怎么计算(python,pytorch,tensorflow,开发技术)

时间:2024-05-06 00:20:36 作者 : 石家庄SEO 分类 : 开发技术
  • TAG :

    1. 引言

    FLOPs 是 floating point operations 的缩写,指浮点运算数,可以用来衡量模型/算法的计算复杂度。本文主要讨论如何在 tensorflow 1.x, tensorflow 2.x 以及 pytorch 中利用相关工具计算对应模型的 FLOPs。

    2. 模型结构

    为了说明方便,先搭建一个简单的神经网络模型,其模型结构以及主要参数如表1 所示。

    表 1 模型结构及主要参数

    LayerschannelsKernelsStridesUnitsActivationConv2D32(4,4)(1,2)\reluGRU\\96\Dense\\256sigmoid

    用 tensorflow(实际使用 tensorflow 中的 keras 模块)实现该模型的代码为:

    fromtensorflow.keras.layersimport*
    fromtensorflow.keras.modelsimportload_model,Model
    deftest_model_tf(Input_shape):

    shape:[B,C,T,F]

    main_input=Input(batch_shape=Input_shape,name='main_inputs')
    conv=Conv2D(32,kernel_size=(4,4),strides=(1,2),activation='relu',data_format='channels_first',name='conv')(main_input)

    shape:[B,T,FC]

    gru=Reshape((conv.shape[2],conv.shape[1]*conv.shape[3]))(conv)
    gru=GRU(units=96,reset_after=True,return_sequences=True,name='gru')(gru)
    output=Dense(256,activation='sigmoid',name='output')(gru)
    model=Model(inputs=[main_input],outputs=[output])
    returnmodel

    用 pytorch 实现该模型的代码为:

    importtorch
    importtorch.nnasnn
    classtest_model_torch(nn.Module):
    definit(self):
    super(test_model_torch,self).init()
    self.conv2d=nn.Conv2d(in_channels=1,out_channels=32,kernel_size=(4,4),stride=(1,2))
    self.relu=nn.ReLU()
    self.gru=nn.GRU(input_size=4064,hidden_size=96)
    self.fc=nn.Linear(96,256)
    self.sigmoid=nn.Sigmoid()
    defforward(self,inputs):

    shape:[B,C,T,F]

    out=self.conv2d(inputs)
    out=self.relu(out)

    shape:[B,T,FC]

    batch,channel,frame,freq=out.size()
    out=torch.reshape(out,(batch,frame,freqchannel))
    out,_=self.gru(out)
    out=self.fc(out)
    out=self.sigmoid(out)
    returnout

    3. 计算模型的 FLOPs

    本节讨论的版本具体为:tensorflow 1.12.0, tensorflow 2.3.1 以及 pytorch 1.10.1+cu102。

    3.1. tensorflow 1.12.0

    在 tensorflow 1.12.0 环境中,可以使用以下代码计算模型的 FLOPs:

    importtensorflowastf
    importtensorflow.keras.backendasK
    defget_flops(model):
    run_meta=tf.RunMetadata()
    opts=tf.profiler.ProfileOptionBuilder.float_operation()
    flops=tf.profiler.profile(graph=K.get_session().graph,
    run_meta=run_meta,cmd='op',options=opts)
    returnflops.total_float_ops
    ifname=="main":
    x=K.random_normal(shape=(1,1,100,256))
    model=test_model_tf(x.shape)
    print('FLOPsoftensorflow1.12.0:',get_flops(model))

    3.2. tensorflow 2.3.1

    在 tensorflow 2.3.1 环境中,可以使用以下代码计算模型的 FLOPs :

    importtensorflow.compat.v1astf
    importtensorflow.compat.v1.keras.backendasK
    tf.disable_eager_execution()
    defget_flops(model):
    run_meta=tf.RunMetadata()
    opts=tf.profiler.ProfileOptionBuilder.float_operation()
    flops=tf.profiler.profile(graph=K.get_session().graph,
    run_meta=run_meta,cmd='op',options=opts)
    returnflops.total_float_ops
    ifname=="main":
    x=K.random_normal(shape=(1,1,100,256))
    model=test_model_tf(x.shape)
    print('FLOPsoftensorflow2.3.1:',get_flops(model))

    3.3. pytorch 1.10.1+cu102

    在 pytorch 1.10.1+cu102 环境中,可以使用以下代码计算模型的 FLOPs(需要安装 thop):

    importthop
    x=torch.randn(1,1,100,256)
    model=test_modeltorch()
    flops,
    =thop.profile(model,inputs=(x,))
    print('FLOPsofpytorch1.10.1:',flops
    2)

    需要注意的是,thop 返回的是 MACs (Multiply–Accumulate Operations),其等于 2 2 2 倍的 FLOPs,所以上述代码有乘 2 2 2 操作。

    3.4. 结果对比

    三者计算出的 FLOPs 分别为:

    tensorflow 1.12.0:

    Python tensorflow与pytorch的浮点运算数怎么计算

    tensorflow 2.3.1:

    Python tensorflow与pytorch的浮点运算数怎么计算

    pytorch 1.10.1:

    Python tensorflow与pytorch的浮点运算数怎么计算

    可以看到 tensorflow 1.12.0 和 tensorflow 2.3.1 的结果基本在同一个量级,而与 pytorch 1.10.1 计算出来的相差甚远。但如果将上述模型结构改为只包含第一层 Conv2D,三者计算出来的 FLOPs 却又是一致的。所以推断差异主要来自于 GRU 的 FLOPs。

    本文:Python tensorflow与pytorch的浮点运算数怎么计算的详细内容,希望对您有所帮助,信息来源于网络。
    上一篇:Android Java try catch失效问题如何解决下一篇:

    18 人围观 / 0 条评论 ↓快速评论↓

    (必须)

    (必须,保密)

    阿狸1 阿狸2 阿狸3 阿狸4 阿狸5 阿狸6 阿狸7 阿狸8 阿狸9 阿狸10 阿狸11 阿狸12 阿狸13 阿狸14 阿狸15 阿狸16 阿狸17 阿狸18