Hello everyone, good to see you again. It's your friend 全栈君.
References:
https://zhuanlan.zhihu.com/p/100672008
https://www.jianshu.com/p/2b94da24af3b
https://github.com/ptrblck/pytorch_misc
# python3.8
# -*- coding: utf-8 -*-
# ---
# @Software: PyCharm
# @File: test2.py
# @Author: ---
# @Institution: BeiJing, China
# @E-mail: lgdyangninghua@163.com
# @Site:
# @Time: May 19, 2021
# ---
import torch
import torch.nn as nn
import numpy as np
np.random.seed(10)
torch.manual_seed(10)
data = np.array([[1, 2, 7],
                 [1, 3, 9],
                 [1, 4, 6]]).astype(np.float32)
bn_torch = nn.BatchNorm1d(num_features=3)
data_torch = torch.from_numpy(data)
bn_output_torch = bn_torch(data_torch)
print("bn_output_torch:", bn_output_torch)
def forward_bn(x, gam, beta):
    '''
    x: (N, D) input
    '''
    momentum = 0.1
    eps = 1e-05
    running_mean = 0
    running_var = 1
    # Running statistics use the unbiased variance (PyTorch's convention)
    running_mean = (1 - momentum) * running_mean + momentum * x.mean(dim=0)
    running_var = (1 - momentum) * running_var + momentum * x.var(dim=0)
    # Normalization uses the biased batch variance (unbiased=False)
    mean = x.mean(dim=0)
    var = x.var(dim=0, unbiased=False)
    # bnmiddle_buffer = (input - mean) / ((var + eps) ** 0.5).data
    x_hat = (x - mean) / torch.sqrt(var + eps)
    out = gam * x_hat + beta
    print("x_mean:", mean, "x_var:", var, "self._gamma:", gam, "self._beta:", beta)
    cache = (x, gam, beta, x_hat, mean, var, eps)
    return out, cache
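# Where naive reimplementations diverge from PyTorch (see the jianshu link
# above): normalization uses the biased batch variance, while running_var is
# updated with the unbiased one. A quick check on the same data tensor:
print("unbiased var (feeds running_var):", data_torch.var(dim=0))
print("biased var (used to normalize):  ", data_torch.var(dim=0, unbiased=False))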
class MyBN:
    def __init__(self, momentum, eps, num_features):
        """
        Initialize the parameters.
        :param momentum: momentum for tracking the running mean and variance
        :param eps: guards against division by zero in the normalization
        :param num_features: number of features
        """
        # Track the running mean and var across batches
        self._running_mean = 0
        self._running_var = 1
        # Momentum used when updating self._running_xxx
        self._momentum = momentum
        # Keeps the denominator from being zero
        self._eps = eps
        # The beta and gamma the paper learns, initialized with the same
        # values as in the PyTorch docs
        self._beta = np.zeros(shape=(num_features, ))
        self._gamma = np.ones(shape=(num_features, ))

    def batch_norm(self, x):
        """
        BN forward pass
        :param x: input data
        :return: BN output
        """
        x_mean = x.mean(axis=0)
        x_var = x.var(axis=0)
        # Update formulas for the running statistics (PyTorch convention:
        # new = (1 - momentum) * running + momentum * batch)
        self._running_mean = (1 - self._momentum) * self._running_mean + self._momentum * x_mean
        self._running_var = (1 - self._momentum) * self._running_var + self._momentum * x_var
        # The BN formula from the paper
        x_hat = (x - x_mean) / np.sqrt(x_var + self._eps)
        y = self._gamma * x_hat + self._beta
        print("x_mean:", x_mean, "x_var:", x_var, "self._gamma:", self._gamma, "self._beta:", self._beta)
        return y
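# At inference time BN switches to the tracked running statistics instead of
# the batch statistics. A minimal sketch of what MyBN would do in eval mode
# (hypothetical helper, not part of the original class):
def my_bn_eval(bn, x):
    x_hat = (x - bn._running_mean) / np.sqrt(bn._running_var + bn._eps)
    return bn._gamma * x_hat + bn._beta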
my_bn = MyBN(momentum=0.1, eps=1e-05, num_features=3)
my_bn._beta = bn_torch.bias.detach().numpy()
my_bn._gamma = bn_torch.weight.detach().numpy()
bn_output = my_bn.batch_norm(data)
print("MyBN bn_output:", bn_output)
out, cache = forward_bn(data_torch.detach(), bn_torch.weight.detach(), bn_torch.bias.detach())
print("forward_bn out2:", out)
# python3.8
# -*- coding: utf-8 -*-
# ---
# @Software: PyCharm
# @File: test.py
# @Author: ---
# @Institution: BeiJing, China
# @E-mail: lgdyangninghua@163.com
# @Site:
# @Time: May 19, 2021
# ---
import numpy as np
np.set_printoptions(suppress = True)
import torch
import torch.nn as nn
np.random.seed(10)
torch.manual_seed(10)
# import pprint
# np.random.seed(10)
# norm = np.random.normal(size=(5, 5))
# pprint.pprint(norm)
data = [
[0.1, 0.3, 0.4],
[0.5, 0.3, 0.2],
[0.4, 0.6, 0.1],
[0.5, 0.3, 0.2],
]
data_np = np.array(data, dtype=np.float32) * 10
print("data_np.shape:", data_np.shape)
data_np = data_np.reshape((3, -1))
print("data_np.shape:", data_np.shape)
t_data = torch.from_numpy(data_np)
t_data = torch.unsqueeze(t_data, dim=0)
print("t_data.shape:", t_data.shape)
print(t_data)
class PointNet(nn.Module):
    def __init__(self):
        super(PointNet, self).__init__()
        # nn.Conv1d explained: https://blog.csdn.net/sunny_xsc1994/article/details/82969867
        self.conv1 = torch.nn.Conv1d(3, 5, 1)
        self.bn1 = nn.BatchNorm1d(5)
        # Weight initialization and parameter groups in PyTorch: https://blog.csdn.net/Bear_Kai/article/details/99302341
        # Implementing weight initialization in PyTorch: https://www.jb51.net/article/177617.htm
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                m.weight.data.normal_(0, 1)
                if m.bias is not None:
                    m.bias.data.zero_()
                self.weight = np.asarray(m.weight.data)
                # print("nn.Conv1d:", m.weight.data)
            elif isinstance(m, nn.BatchNorm1d):
                m.weight.data.fill_(5)  # 1
                m.bias.data.zero_()

    def forward(self, x):
        result1 = self.conv1(x)
        result2 = self.bn1(result1)
        return result1, result2, self.weight
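# The in-place .data manipulation above works, but an equivalent and more
# idiomatic sketch uses the public torch.nn.init API (same initialization):
def init_weights(m):
    if isinstance(m, nn.Conv1d):
        nn.init.normal_(m.weight, mean=0.0, std=1.0)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm1d):
        nn.init.constant_(m.weight, 5.0)
        nn.init.zeros_(m.bias)
# usage: model.apply(init_weights)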
pn = PointNet()
result1, result2, weight = pn(t_data)
weight = torch.from_numpy(weight)
print("weight.shape:", weight.shape)
print("weight:", weight)
print("result1.shape:", result1.shape)
print(result1)
print("result2.shape:", result2.shape)
print(result2)
#print("result1_end:", pn.bn1(result1))
# PointNet paper reproduction and code walkthrough: https://zhuanlan.zhihu.com/p/86331508
for n in range(t_data.shape[2]):
    sums = []
    for m in range(weight.shape[0]):
        # Multiplication in PyTorch, summarized: https://zhuanlan.zhihu.com/p/212461087
        # sums += (torch.mul(t_data[0, :, 0], weight[m, :, 0]))  # element-wise product
        sums.append(torch.dot(t_data[0, :, n], weight[m, :, 0]))  # dot product
    print("sum:", sums)
# Why pytorch nn.BatchNorm1d differs from a hand-rolled Python version, and the fix: https://www.jianshu.com/p/2b94da24af3b
#https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html
def forward_bn(x, gam, beta, dim=0):
    '''
    x: (N, D) input
    '''
    momentum = 0.1
    eps = 1e-05
    running_mean = 0
    running_var = 5  # 1
    running_mean = (1 - momentum) * running_mean + momentum * x.mean(dim=0)
    running_var = (1 - momentum) * running_var + momentum * x.var(dim=0)
    mean = x.mean(dim=dim)
    var = x.var(dim=dim, unbiased=False)
    # bnmiddle_buffer = (input - mean) / ((var + eps) ** 0.5).data
    x_hat = (x - mean) / torch.sqrt(var + eps)
    out = gam * x_hat + beta
    print("x_mean:", mean, "x_var:", var, "self._gamma:", gam, "self._beta:", beta)
    cache = (x, gam, beta, x_hat, mean, var, eps)
    return out, cache
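# A direct route (sketch, assuming result1 has shape (N, C, L) and a recent
# PyTorch where var() accepts a tuple of dims): BatchNorm1d normalizes each
# channel over the N and L dimensions jointly, so the permutes below are only
# reshaping tricks that let us reuse forward_bn.
_mean = result1.mean(dim=(0, 2), keepdim=True)
_var = result1.var(dim=(0, 2), unbiased=False, keepdim=True)
direct = pn.bn1.weight.view(1, -1, 1) * (result1 - _mean) / torch.sqrt(_var + 1e-05) \
         + pn.bn1.bias.view(1, -1, 1)
print("direct matches bn1:", torch.allclose(direct, result2, atol=1e-5))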
# If the layout is B*C*(H*W):
# input:  (1, 3_in, 4)
# kernel: (3_in, 5_out, 1)
# output: (1, 5_out (channels), 4)
bn_re = result1.permute(0, 2, 1)
out, cache = forward_bn(bn_re, pn.bn1.weight, pn.bn1.bias, dim=1)
out = out.permute(0, 2, 1)
print("out1", out)
bn_re = result1.squeeze()
bn_re = bn_re.permute(1, 0)
out, cache = forward_bn(bn_re, pn.bn1.weight, pn.bn1.bias, dim=0)
out = out.permute(1, 0)
print("out2", out)
x = np.array([[-1.2089, 6.8342, -0.3317, -5.2298],
              [ 2.5075, 9.6109, 8.8057, 9.0995],
              [ 4.2763, 1.2605, 6.7774, 11.4138],
              [ 1.0103, 1.0549, 0.3408, 0.0656],
              [-2.2381, 1.9428, -3.6522, -7.8491]])
x = x.mean(axis=1)
y = np.array([-2.2381, 1.9428, -3.6522, -7.8491])
y = y.mean(axis=0)
print(x)
print(y)
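# For reference, which axis gets reduced: axis=0 averages down the rows (one
# mean per column), axis=1 averages across the columns (one mean per row).
a = np.arange(6, dtype=np.float64).reshape(2, 3)
print(a.mean(axis=0))  # shape (3,): column means
print(a.mean(axis=1))  # shape (2,): row means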