Migrating TensorFlow Weights to PyTorch

This article shows how to migrate the weights of TensorFlow Conv2D, Dense, and BatchNorm layers to PyTorch. For the reverse direction, migrating PyTorch weights to TensorFlow, see the following post:

PyTorch and TensorFlow weight interconversion (太阳花的小绿豆's blog on CSDN)

Basic Idea

Each layer in a TensorFlow model generally carries a name. A layer can be looked up with model.get_layer(name), and its weights retrieved with layer.get_weights(). PyTorch modules have no built-in field for such a name, but we can wrap the relevant modules and store a name in each wrapper. We can then build the PyTorch model to mirror the structure of the TensorFlow model and migrate the weights layer by layer.
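As a minimal sketch of the TensorFlow side (tf_model and the layer name 'conv1_conv' are placeholders here, assuming a built tf.keras model):

# Assumption: tf_model is a built tf.keras Model with a layer named 'conv1_conv'.
layer = tf_model.get_layer('conv1_conv')   # look the layer up by its name
weights = layer.get_weights()              # list of numpy arrays, e.g. [kernel, bias]
print([w.shape for w in weights])          # inspect the shapes before copying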


Conv2D Layer

TensorFlow tensors are laid out as (B, H, W, C), while PyTorch uses (B, C, H, W), so the two frameworks also store convolution weights differently: PyTorch uses (out_channels, in_channels, H, W), whereas TensorFlow uses (H, W, in_channels, out_channels). The weight tensor therefore has to be permuted during migration.

In addition, if the convolution has a bias, layer.get_weights() returns a list of length 2: the first element is the kernel and the second is the bias.

import torch
import torch.nn as nn


class Conv2dWithName(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1,
                 padding=0, groups=1, use_bias=True, dilation=1, name=None):
        super(Conv2dWithName, self).__init__()
        self.conv2d = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                                padding=padding, groups=groups, bias=use_bias, dilation=dilation)
        self.name = name  # store the module name
        self.use_bias = use_bias

    def forward(self, x):
        return self.conv2d(x)

    def set_weight(self, layer):
        with torch.no_grad():
            print('INFO: init layer %s with tf weights' % self.name)
            weights = layer.get_weights()
            weight = weights[0]
            weight = torch.from_numpy(weight)
            # TF kernel is (H, W, in, out); PyTorch expects (out, in, H, W)
            weight = weight.permute((3, 2, 0, 1))
            self.conv2d.weight.copy_(weight)
            if self.use_bias:
                bias = weights[1]
                bias = torch.from_numpy(bias)
                self.conv2d.bias.copy_(bias)
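As a quick numeric sanity check of the kernel permutation, here is a sketch with a throwaway layer (the shapes and the name 'demo_conv' are made up for illustration); the two outputs should match to within float tolerance:

import numpy as np
import tensorflow as tf
import torch

tf_conv = tf.keras.layers.Conv2D(8, 3, padding='same', name='demo_conv')
x = np.random.rand(1, 16, 16, 4).astype(np.float32)
_ = tf_conv(x)  # build the layer so get_weights() returns [kernel, bias]

pt_conv = Conv2dWithName(4, 8, kernel_size=3, padding=1, name='demo_conv')
pt_conv.set_weight(tf_conv)

y_tf = tf_conv(x).numpy()                                # (B, H, W, C)
y_pt = pt_conv(torch.from_numpy(x).permute(0, 3, 1, 2))  # (B, C, H, W)
y_pt = y_pt.permute(0, 2, 3, 1).detach().numpy()
print(np.abs(y_tf - y_pt).max())  # should be on the order of 1e-6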

Dense Layer

Similarly, a Dense layer has two weight arrays, weight and bias. Note that the PyTorch weight has shape (out_dims, in_dims), while TensorFlow's is the opposite, (in_dims, out_dims), so the weight must be transposed.

class DenseWithName(nn.Module):
    def __init__(self, in_dim, out_dim, name=None):
        super(DenseWithName, self).__init__()
        self.dense = nn.Linear(in_dim, out_dim)
        self.name = name

    def set_weight(self, layer):
        print('INFO: init layer %s with tf weights' % self.name)
        with torch.no_grad():
            weights = layer.get_weights()
            # TF kernel is (in, out); nn.Linear expects (out, in)
            weight = torch.from_numpy(weights[0]).transpose(0, 1)
            self.dense.weight.copy_(weight)
            bias = weights[1]
            bias = torch.from_numpy(bias)
            self.dense.bias.copy_(bias)

    def forward(self, x):
        return self.dense(x)
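A missing transpose fails silently when in_dim equals out_dim, so a quick numeric check is worthwhile. A sketch with a throwaway layer (dimensions and the name 'demo_fc' are made up):

import numpy as np
import tensorflow as tf
import torch

tf_dense = tf.keras.layers.Dense(5, name='demo_fc')
x = np.random.rand(2, 7).astype(np.float32)
_ = tf_dense(x)  # build the layer to materialize its weights

pt_dense = DenseWithName(7, 5, name='demo_fc')
pt_dense.set_weight(tf_dense)

y_tf = tf_dense(x).numpy()
y_pt = pt_dense(torch.from_numpy(x)).detach().numpy()
print(np.abs(y_tf - y_pt).max())  # should be on the order of 1e-7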

BatchNorm Layer

BatchNorm requires migrating four arrays: weight (gamma), bias (beta), running_mean, and running_var.

class BatchNorm2dWithName(nn.Module):
    def __init__(self, n_channels, name=None):
        super(BatchNorm2dWithName, self).__init__()
        self.bn = nn.BatchNorm2d(n_channels)
        self.name = name

    def forward(self, x):
        return self.bn(x)

    def set_weight(self, layer):
        with torch.no_grad():
            print('INFO: init layer %s with tf weights' % self.name)
            # Keras order: [gamma, beta, moving_mean, moving_variance]
            weights = layer.get_weights()
            gamma = torch.from_numpy(weights[0])
            beta = torch.from_numpy(weights[1])
            run_mean = torch.from_numpy(weights[2])
            run_var = torch.from_numpy(weights[3])
            self.bn.weight.copy_(gamma)
            self.bn.bias.copy_(beta)
            self.bn.running_mean.copy_(run_mean)
            self.bn.running_var.copy_(run_var)
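Two caveats. First, the array order above assumes a standard Keras BatchNormalization layer with center=True and scale=True, whose get_weights() returns [gamma, beta, moving_mean, moving_variance]. Second, the TF ResNet50 below builds its BatchNormalization layers with epsilon=1.001e-5, while nn.BatchNorm2d defaults to eps=1e-5; the difference is tiny, but for closer agreement you could expose eps in the wrapper, e.g. as a sketch:

class BatchNorm2dWithName(nn.Module):
    def __init__(self, n_channels, eps=1e-5, name=None):
        super(BatchNorm2dWithName, self).__init__()
        self.bn = nn.BatchNorm2d(n_channels, eps=eps)  # pass eps=1.001e-5 to mirror the TF model
        self.name = name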

Layer-by-Layer Weight Migration

Following the structure of an existing TensorFlow model, we can assemble a deep model from the wrapped layers above. To migrate the weights, we iterate over all modules of the PyTorch model and copy the weights layer by layer:

for m in self.modules():  # iterate over every module in the model
    if isinstance(m, (Conv2dWithName, BatchNorm2dWithName, DenseWithName)):
        layer = tf_model.get_layer(m.name)
        m.set_weight(layer)
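Note that get_layer() raises a ValueError for an unknown name, so any mismatch between the two models surfaces immediately. When that happens, dumping the names on both sides is a quick way to find the offender; a sketch (tf_model and torch_model stand for whatever pair of models you are migrating between):

# Sketch: compare layer names on the two sides before migrating.
tf_names = {l.name for l in tf_model.layers}
pt_names = {m.name for m in torch_model.modules()
            if isinstance(m, (Conv2dWithName, BatchNorm2dWithName, DenseWithName))}
print(sorted(pt_names - tf_names))  # PyTorch-side names missing from the TF model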

Below we use ResNet50 as an example to test the weight migration:

TensorFlow Model

from tensorflow.keras import layers, models


def block1(x, filters, kernel_size=3, stride=1,
           conv_shortcut=True, name=None):
    """A residual block.

    # Arguments
        x: input tensor.
        filters: integer, filters of the bottleneck layer.
        kernel_size: default 3, kernel size of the bottleneck layer.
        stride: default 1, stride of the first layer.
        conv_shortcut: default True, use convolution shortcut if True,
            otherwise identity shortcut.
        name: string, block label.

    # Returns
        Output tensor for the residual block.
    """
    bn_axis = 3
    if conv_shortcut is True:
        shortcut = layers.Conv2D(4 * filters, 1, strides=stride,
                                 name=name + '_0_conv')(x)
        shortcut = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5,
                                             name=name + '_0_bn')(shortcut)
    else:
        shortcut = x

    x = layers.Conv2D(filters, 1, strides=stride, name=name + '_1_conv')(x)
    x = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5,
                                  name=name + '_1_bn')(x)
    x = layers.Activation('relu', name=name + '_1_relu')(x)

    x = layers.Conv2D(filters, kernel_size, padding='SAME',
                      name=name + '_2_conv')(x)
    x = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5,
                                  name=name + '_2_bn')(x)
    x = layers.Activation('relu', name=name + '_2_relu')(x)

    x = layers.Conv2D(4 * filters, 1, name=name + '_3_conv')(x)
    x = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5,
                                  name=name + '_3_bn')(x)

    x = layers.Add(name=name + '_add')([shortcut, x])
    x = layers.Activation('relu', name=name + '_out')(x)
    return x


def stack1(x, filters, blocks, stride1=2, name=None):
    """A set of stacked residual blocks.

    # Arguments
        x: input tensor.
        filters: integer, filters of the bottleneck layer in a block.
        blocks: integer, blocks in the stacked blocks.
        stride1: default 2, stride of the first layer in the first block.
        name: string, stack label.

    # Returns
        Output tensor for the stacked blocks.
    """
    x = block1(x, filters, stride=stride1, name=name + '_block1')
    for i in range(2, blocks + 1):
        x = block1(x, filters, conv_shortcut=False, name=name + '_block' + str(i))
    return x


def ResNet50_TF(inputs,
                preact=False,
                use_bias=True,
                model_name='resnet50'):
    bn_axis = 3
    x = layers.ZeroPadding2D(padding=((3, 3), (3, 3)), name='conv1_pad')(inputs)
    x = layers.Conv2D(64, 7, strides=2, use_bias=use_bias, name='conv1_conv')(x)
    if preact is False:
        x = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5,
                                      name='conv1_bn')(x)
        x = layers.Activation('relu', name='conv1_relu')(x)
    x = layers.ZeroPadding2D(padding=((1, 1), (1, 1)), name='pool1_pad')(x)
    x = layers.MaxPooling2D(3, strides=2, name='pool1_pool')(x)

    x = stack1(x, 64, 3, stride1=1, name='conv2')
    x = stack1(x, 128, 4, name='conv3')
    x = stack1(x, 256, 6, name='conv4')
    x = stack1(x, 512, 3, name='conv5')

    x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
    x = layers.Dense(1, activation='linear', name='final_fc')(x)

    # Create model.
    model = models.Model(inputs, x, name=model_name)
    return model

Note that the ResNet50 above is not the original classification network: its final Dense layer outputs a single value rather than 1000 class scores. Note also that it uses explicit ZeroPadding2D before the stride-2 stem convolution and pooling, which makes the padding symmetric and therefore exactly reproducible with PyTorch's padding=3 and padding=1 below.

Corresponding PyTorch Model

import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint  # used in the forward passes below


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1, name=None):
    """3x3 convolution with padding"""
    return Conv2dWithName(in_planes, out_planes, kernel_size=3, stride=stride,
                          padding=dilation, groups=groups, use_bias=True,
                          dilation=dilation, name=name)


def conv1x1(in_planes, out_planes, stride=1, name=None):
    """1x1 convolution"""
    return Conv2dWithName(in_planes, out_planes, kernel_size=1, stride=stride,
                          use_bias=True, name=name)

class BasicBlock(nn.Module):
    # Not used by ResNet50 (which is built from Bottleneck blocks); kept for completeness.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None, name=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = BatchNorm2dWithName
        width = int(planes * (base_width / 64.)) * groups
        # Unlike torchvision's ResNet, the stride sits on conv1 here,
        # matching the TF block1 above.
        self.conv1 = conv1x1(inplanes, width, stride=stride, name=name + '_1_conv')
        self.bn1 = norm_layer(width, name=name + '_1_bn')
        self.conv2 = conv3x3(width, width, name=name + '_2_conv')
        self.bn2 = norm_layer(width, name=name + '_2_bn')
        self.conv3 = conv1x1(width, planes * self.expansion, name=name + '_3_conv')
        self.bn3 = norm_layer(planes * self.expansion, name=name + '_3_bn')
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        if self.downsample is not None:
            # Name the shortcut branch after the TF '_0_conv'/'_0_bn' layers.
            self.downsample[0].name = name + '_0_conv'
            self.downsample[1].name = name + '_0_bn'
        self.stride = stride
        self.name = name

    def forward(self, x):
        identity = x

        # checkpoint (torch.utils.checkpoint.checkpoint) saves memory by
        # recomputing activations in the backward pass; at inference it
        # simply calls the wrapped module.
        out = checkpoint(self.conv1, x)
        out = checkpoint(self.bn1, out)
        out = self.relu(out)

        out = checkpoint(self.conv2, out)
        out = checkpoint(self.bn2, out)
        out = self.relu(out)

        out = checkpoint(self.conv3, out)
        out = checkpoint(self.bn3, out)

        if self.downsample is not None:
            identity = checkpoint(self.downsample, x)

        out += identity
        out = self.relu(out)

        return out

class ResNet(nn.Module):

    def __init__(self, block, layers, width_per_group=64):
        super(ResNet, self).__init__()
        norm_layer = BatchNorm2dWithName
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        self.base_width = width_per_group
        self.conv1 = Conv2dWithName(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                                    use_bias=True, name='conv1_conv')
        self.bn1 = norm_layer(self.inplanes, name='conv1_bn')
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], name='conv2')
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       name='conv3')
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       name='conv4')
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       name='conv5')
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.final_fc = DenseWithName(2048, 1, name='final_fc')

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1, name=None):
        norm_layer = self._norm_layer
        # The first block of every stack has a conv shortcut in the TF model
        # (even conv2, whose stride is 1), so the downsample branch is created
        # unconditionally here.
        downsample = nn.Sequential(
            conv1x1(self.inplanes, planes * block.expansion, stride=stride),
            norm_layer(planes * block.expansion),
        )

        layers = []
        layers.append(block(inplanes=self.inplanes, planes=planes, stride=stride,
                            downsample=downsample, name=name + '_block1'))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, base_width=self.base_width,
                                dilation=self.dilation, name=name + '_block%d' % (i + 1)))

        return nn.Sequential(*layers)

    def init_from_tf(self, tf_model):
        for m in self.modules():
            if isinstance(m, (Conv2dWithName, BatchNorm2dWithName, DenseWithName)):
                layer = tf_model.get_layer(m.name)
                m.set_weight(layer)

    def _forward_impl(self, x):
        x = checkpoint(self.conv1, x)
        x = checkpoint(self.bn1, x)
        x = self.relu(x)
        x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x).squeeze(-1).squeeze(-1)
        x = self.final_fc(x)

        return x

    def forward(self, x):
        return self._forward_impl(x)

def _resnet(arch, block, layers, **kwargs):
    model = ResNet(block, layers, **kwargs)
    return model


def resnet50_torch(**kwargs):
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], **kwargs)
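Before loading any TF weights, you can already list every name the PyTorch model will look up, which is handy for eyeballing the correspondence with the TF model. A sketch:

# Sketch: list every layer name the PyTorch model expects to find in the TF model.
model = resnet50_torch()
names = [m.name for m in model.modules()
         if isinstance(m, (Conv2dWithName, BatchNorm2dWithName, DenseWithName))]
print(len(names))  # one entry per conv/bn/dense layer
print(names[:4])   # should start with 'conv1_conv', 'conv1_bn', 'conv2_block1_1_conv', ...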

Test Code

import numpy as np
from tensorflow.keras.layers import Input

input_shape = (None, None, 3)
inputs = Input(shape=input_shape)
res50_tf = ResNet50_TF(inputs)
res50_tf.load_weights('./src/Resnet——weights.h5', by_name=True)

res50_torch = resnet50_torch().float()
res50_torch.init_from_tf(res50_tf)
res50_torch.eval()  # use the migrated running stats rather than batch statistics

img = np.random.rand(1, 224, 224, 3)
img2 = torch.from_numpy(img).permute([0, 3, 1, 2]).float()

p_tf = res50_tf.predict(img)
p_torch = res50_torch(img2).data.numpy()

print('tensorflow predict: %f' % p_tf[0, 0])
print('pytorch predict: %f' % p_torch[0, 0])
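Beyond eyeballing the two printed numbers, a stricter check is an explicit tolerance assertion (atol=1e-4 is an arbitrary but comfortable tolerance that absorbs the epsilon and accumulation-order differences between the frameworks):

# Sketch: assert that the two predictions agree within a small tolerance.
assert np.allclose(p_tf, p_torch, atol=1e-4), 'weight migration mismatch'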

Output

If the migration succeeded, the two printed predictions agree up to floating-point error.
