springboot实现查询数据以图片形式导出+图片下载

zblog批量添加文章标签,sql语句根据文章标题添加标签

  返回  

BN代码实现

2021/8/20 11:08:35 浏览:

BN代码实现

import torch
from torch import nn

def batch_norm(X,gamma,beta,moving_mean,moving_var,eps,momentum):
  # 推理
  if not torch.is_grad_enabled():
    X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
  # 训练
  else:
    assert len(X.shape) in (2,4):
    # 全连接
    if len(X.shape) == 2:
      mean = X.mean(dim=0)
      var = ((X - mean)**2).mean(dim=0)
    # 卷积
    else:
      mean = X.mean(dim=(0,2,3),keepdim=True)
      var = ((X - mean)**2).mean(dim=(0,2,3),keepdim=True)

    X_hat = (X - mean) / torch.sqrt(var + eps)
    # 滑动平均记录训练过程中BN的均值和方差
    moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
    moving_var = momentum * moving_var + (1.0 - momentum) * var
  Y = gamma * X_hat + beta
  return Y, moving_mean.data, moving_var.data
class BatchNorm(nn.module):
  def __init__(self,num_features,num_dims):
    super().__init__()
    if num_dim == 2:
      shape = (1,num_features)
    else:
      shape = (1,num_features,1,1)
    # 参数初始化
    self.gamma = nn.Parameter(torch.ones(shape))
    self.beta = nn.Parameter(torch.ones(shape))
    self.moving_mean = torch.zeros(shape)
    self.moving_var = torch.zeros(shape)

  def forward(self,X):
    if self.moving_mean.device != X.device:
      self.moving_mean = self.moving_mean.to(X.device)
      self.moving_var = self.moving_var.to(X.device)
    Y, self.gamma, self.beta, self.moving_mean, self.moving_var = batch_norm(X,self.gamma,self.beta,self.moving_mean,self.moving_var,eps=1e-5,momentum=0.9)

    return Y

联系我们

如果您对我们的服务有兴趣,请及时和我们联系!

服务热线:18288888888
座机:18288888888
传真:
邮箱:888888@qq.com
地址:郑州市文化路红专路93号