BigGAN Code Walkthrough (with help from GPT-3.5): the Batch Normalization (BN) Layers
The BN-layer code written by the developers of this BigGAN implementation is very tightly coupled internally, so it needs to be understood step by step.
These are the author's study notes.
Code source: the most-starred BigGAN implementation on GitHub.
First, in BigGAN.py, inside the Generator code, the BN layer is chosen as follows:
self.which_bn = functools.partial(layers.ccbn,
                                  which_linear=bn_linear,
                                  cross_replica=self.cross_replica,
                                  mybn=self.mybn,
                                  input_size=(self.shared_dim + self.z_chunk_size if self.G_shared
                                              else self.n_classes),
                                  norm_style=self.norm_style,
                                  eps=self.BN_eps)
This snippet selects the BN layer. Its core ingredient is layers.ccbn: self.which_bn is a functools.partial built around layers.ccbn, so the next step is to look at the ccbn code:
# Class-conditional bn
# output size is the number of channels, input size is for the linear layers
# Andy's Note: this class feels messy but I'm not really sure how to clean it up
# Suggestions welcome! (By which I mean, refactor this and make a pull request
# if you want to make this more readable/usable).
class ccbn(nn.Module):
  # ccbn normalizes x with some flavor of BN, then combines the result with a
  # gain and bias derived from the class information
  def __init__(self, output_size, input_size, which_linear, eps=1e-5, momentum=0.1,
               cross_replica=False, mybn=False, norm_style='bn',):
    super(ccbn, self).__init__()
    self.output_size, self.input_size = output_size, input_size
    # Prepare gain and bias layers
    self.gain = which_linear(input_size, output_size)
    self.bias = which_linear(input_size, output_size)
    # epsilon to avoid dividing by 0
    self.eps = eps
    # Momentum
    self.momentum = momentum
    # Use cross-replica batchnorm?
    # (batch normalization synchronized across multiple GPUs, which helps performance)
    self.cross_replica = cross_replica
    # Use my batchnorm?
    self.mybn = mybn
    # Norm style?
    self.norm_style = norm_style

    if self.cross_replica:
      self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False)
    elif self.mybn:
      self.bn = myBN(output_size, self.eps, self.momentum)
    elif self.norm_style in ['bn', 'in']:
      self.register_buffer('stored_mean', torch.zeros(output_size))
      self.register_buffer('stored_var', torch.ones(output_size))

  def forward(self, x, y):
    # Calculate class-conditional gains and biases
    # The class information y goes through a linear layer, which yields the
    # class-dependent gain and bias of the BN layer
    gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
    bias = self.bias(y).view(y.size(0), -1, 1, 1)
    # Several ways of normalizing x
    # If using my batchnorm
    if self.mybn or self.cross_replica:
      # gain and bias carry the class information here
      return self.bn(x, gain=gain, bias=bias)
    # else:
    else:
      if self.norm_style == 'bn':
        out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None,
                           self.training, 0.1, self.eps)
      elif self.norm_style == 'in':
        out = F.instance_norm(x, self.stored_mean, self.stored_var, None, None,
                              self.training, 0.1, self.eps)
      elif self.norm_style == 'gn':
        # self.normstyle appears to be a typo for self.norm_style in the upstream code
        out = groupnorm(x, self.normstyle)
      elif self.norm_style == 'nonorm':
        out = x
      # The normalized x is then combined with the class-dependent gain and bias
      return out * gain + bias

  def extra_repr(self):
    s = 'out: {output_size}, in: {input_size},'
    s += ' cross_replica={cross_replica}'
    return s.format(**self.__dict__)
In the forward function, the argument x is the input to the BN layer (typically the output of the previous layer), and y is the class of the sample we want to generate. self.gain and self.bias are the linear layers defined when ccbn is initialized. In class-conditional batch normalization, ccbn first transforms the class information to obtain two class-dependent parameters, gain and bias. It then selects which kind of normalization to apply to x: there are five normalization options plus an identity mapping ('nonorm'). After normalization, the result is combined with the class-dependent gain and bias, which is how class information is injected into the generated image. BigGAN uses SAGAN as its baseline, and its class information enters the generator through the BN layers.
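To make the call pattern concrete, here is a minimal usage sketch (not code from the repo): it assumes layers.py from the BigGAN-PyTorch repository is importable, uses nn.Linear in place of bn_linear, and picks illustrative sizes (a 128-dimensional class embedding, 256 channels) that are not taken from the paper:

import functools
import torch
import torch.nn as nn
import layers  # layers.py from the BigGAN-PyTorch repo (assumed importable)

# Bind everything except output_size, mirroring what BigGAN.py does with functools.partial
which_bn = functools.partial(layers.ccbn,
                             which_linear=nn.Linear,   # stand-in for bn_linear
                             cross_replica=False,
                             mybn=False,
                             input_size=128,           # e.g. shared_dim + z_chunk_size
                             norm_style='bn',
                             eps=1e-5)

bn = which_bn(256)                 # output_size = number of feature-map channels
x = torch.randn(4, 256, 8, 8)      # feature map from the previous layer
y = torch.randn(4, 128)            # class embedding (plus a chunk of z in BigGAN)
out = bn(x, y)                     # BN on x, then class-conditional gain/bias
print(out.shape)                   # torch.Size([4, 256, 8, 8])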
Next, to better understand how the class information participates in generating samples, we look at the myBN class. To understand myBN, we first need the fused_bn function:
# Fused batchnorm op
# This is really just a scale-and-shift operation
# gain and bias carry the class information here (in ordinary unconditional BN
# they would simply be learnable parameters)
def fused_bn(x, mean, var, gain=None, bias=None, eps=1e-5):
  # Apply scale and shift--if gain and bias are provided, fuse them here
  # Prepare scale
  scale = torch.rsqrt(var + eps)
  # If a gain is provided, use it
  if gain is not None:
    scale = scale * gain
  # Prepare shift
  shift = mean * scale
  # If bias is provided, use it
  if bias is not None:
    shift = shift - bias
  return x * scale - shift
  # return ((x - mean) / ((var + eps) ** 0.5)) * gain + bias # The unfused way.
Here the inputs mean and var are the precomputed mean and variance of x, and torch.rsqrt computes the reciprocal square root of a tensor. Without gain and bias, this is equivalent to subtracting the mean from x and dividing by the standard deviation, i.e. the ordinary batch-normalization step. With class information, the gain and bias (which we know from ccbn are learned from the class embedding) enter the computation as a scale and a shift, respectively.
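As a quick sanity check (my own sketch, not from the repo), the fused form x * scale - shift can be compared numerically against the familiar unfused formula (x - mean) / sqrt(var + eps) * gain + bias:

import torch

x = torch.randn(4, 3, 8, 8)
mean = x.mean([0, 2, 3], keepdim=True)
var = x.var([0, 2, 3], unbiased=False, keepdim=True)
gain = torch.randn(1, 3, 1, 1)
bias = torch.randn(1, 3, 1, 1)
eps = 1e-5

# Fused: fold gain into the scale and bias into the shift, as fused_bn does
scale = torch.rsqrt(var + eps) * gain
shift = mean * scale - bias
fused = x * scale - shift

# Unfused: the textbook normalize-then-affine form
unfused = (x - mean) / torch.sqrt(var + eps) * gain + bias
print(torch.allclose(fused, unfused, atol=1e-5))  # True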
Now we know how class information participates in batch normalization. However, ccbn does not call fused_bn directly, so the next function to understand is manual_bn:
# Manual BN
# Calculate means and variances using mean-of-squares minus mean-squared
# gain and bias carry the class information here (in ordinary unconditional BN
# they would simply be learnable parameters)
def manual_bn(x, gain=None, bias=None, return_mean_var=False, eps=1e-5):
  # First compute the mean and variance
  # Cast x to float32 if necessary
  float_x = x.float()
  # Calculate expected value of x (m) and expected value of x**2 (m2)
  # Mean of x
  m = torch.mean(float_x, [0, 2, 3], keepdim=True)
  # Mean of x squared
  m2 = torch.mean(float_x ** 2, [0, 2, 3], keepdim=True)
  # Calculate variance as mean of squared minus mean squared.
  var = (m2 - m ** 2)
  # Cast back to float 16 if necessary
  var = var.type(x.type())
  m = m.type(x.type())
  # Return mean and variance for updating stored mean/var if requested
  if return_mean_var:
    # Used during training
    return fused_bn(x, m, var, gain, bias, eps), m.squeeze(), var.squeeze()
  else:
    return fused_bn(x, m, var, gain, bias, eps)
Clearly, manual_bn computes the mean and variance of the input x and hands them to fused_bn. One point to note is that manual_bn has two modes: during training, the mean and variance also need to be returned so that the stored statistics can be updated; at test time they do not.
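The variance trick that manual_bn relies on is the identity Var[x] = E[x^2] - (E[x])^2 (the biased, population form). A quick check of this identity, again my own sketch rather than repo code:

import torch

x = torch.randn(4, 3, 8, 8)
m = x.mean([0, 2, 3], keepdim=True)            # E[x] per channel
m2 = (x ** 2).mean([0, 2, 3], keepdim=True)    # E[x^2] per channel
var_manual = m2 - m ** 2                       # what manual_bn computes
var_torch = x.var([0, 2, 3], unbiased=False, keepdim=True)
print(torch.allclose(var_manual, var_torch, atol=1e-6))  # True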
Finally, we arrive at the last intermediate class, myBN:
# My batchnorm, supports standing stats
# "Supports standing stats" means this BN layer can reuse statistics computed
# and stored during training: at inference time it uses those fixed means and
# variances directly instead of recomputing them, which also speeds up inference.
class myBN(nn.Module):
  def __init__(self, num_channels, eps=1e-5, momentum=0.1):
    super(myBN, self).__init__()
    # momentum for updating running stats
    self.momentum = momentum
    # epsilon to avoid dividing by 0
    self.eps = eps
    # Momentum
    self.momentum = momentum
    # Register buffers
    self.register_buffer('stored_mean', torch.zeros(num_channels))
    self.register_buffer('stored_var', torch.ones(num_channels))
    self.register_buffer('accumulation_counter', torch.zeros(1))
    # Accumulate running means and vars
    self.accumulate_standing = False

  # reset standing stats
  def reset_stats(self):
    self.stored_mean[:] = 0
    self.stored_var[:] = 0
    self.accumulation_counter[:] = 0

  def forward(self, x, gain, bias):
    # gain and bias carry the class information here (in ordinary unconditional
    # BN they would simply be learnable parameters)
    if self.training:
      # During training the batch statistics are also returned
      # (return_mean_var=True) so the stored stats can be updated;
      # this is not needed at test time
      out, mean, var = manual_bn(x, gain, bias, return_mean_var=True, eps=self.eps)
      # If accumulating standing stats, increment them
      if self.accumulate_standing:
        self.stored_mean[:] = self.stored_mean + mean.data
        self.stored_var[:] = self.stored_var + var.data
        self.accumulation_counter += 1.0
      # If not accumulating standing stats, take running averages
      else:
        self.stored_mean[:] = self.stored_mean * (1 - self.momentum) + mean * self.momentum
        self.stored_var[:] = self.stored_var * (1 - self.momentum) + var * self.momentum
      return out
    # If not in training mode, use the stored statistics
    else:
      mean = self.stored_mean.view(1, -1, 1, 1)
      var = self.stored_var.view(1, -1, 1, 1)
      # If using standing stats, divide them by the accumulation counter
      if self.accumulate_standing:
        mean = mean / self.accumulation_counter
        var = var / self.accumulation_counter
      return fused_bn(x, mean, var, gain, bias, self.eps)
The myBN class registers three buffers that are updated during training and looked up at test time, and it distinguishes two cases. When self.accumulate_standing is True, the mean and variance of every mini-batch are summed up; at test time these sums are divided by the number of accumulated mini-batches to obtain averaged statistics for the BN computation. When it is False, the test-time mean and variance are instead maintained as running averages during training: each mini-batch updates them through the momentum term. The two approaches have their own trade-offs and need to be tuned and validated in practice; the mini-batch size has a noticeable effect on the quality of the generated samples.
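Here is a minimal sketch of the two modes (my own illustration, assuming layers.py from the BigGAN-PyTorch repo is importable; shapes and loop counts are arbitrary):

import torch
from layers import myBN  # layers.py from the BigGAN-PyTorch repo (assumed importable)

bn = myBN(num_channels=3)
gain = torch.ones(1, 3, 1, 1)   # would come from the class embedding in ccbn
bias = torch.zeros(1, 3, 1, 1)

# Running-average mode (accumulate_standing=False, the default):
# each mini-batch nudges stored_mean / stored_var via the momentum term.
bn.train()
for _ in range(10):
    bn(torch.randn(8, 3, 16, 16), gain, bias)

# Standing-stats mode: reset, then sum the statistics over many mini-batches;
# at eval time they are divided by accumulation_counter to get their average.
bn.reset_stats()
bn.accumulate_standing = True
for _ in range(10):
    bn(torch.randn(8, 3, 16, 16), gain, bias)

bn.eval()
out = bn(torch.randn(8, 3, 16, 16), gain, bias)
print(out.shape)  # torch.Size([8, 3, 16, 16])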
Finally, ccbn calls myBN, which completes the main, class-conditional part of batch normalization in BigGAN. Beyond that, there is an unconditional bn class, which differs little from an ordinary BN layer; the SyncBN2d operation selected via self.cross_replica, a BN layer that runs across multiple GPUs and differs mainly in how its statistics are updated; and a few other normalization styles. None of these is as hard to follow as the main part, and I may write them up when I get the chance.
Finally, thanks to GPT-3.5 for its help.