Test code for PyTorch `nn.Conv2d`:
# pytorch
# Reference run: build a deterministic input, push it through a Conv2d with
# 2 input channels, 1 output channel and a 3x3 kernel, back-propagate a random
# upstream gradient, and print every intermediate tensor so a hand-written
# implementation can be compared against it.
import torch
import torch.nn as nn

# 50 consecutive float32 values laid out as one sample of 2 channels x 5x5 (N, C, H, W).
x = torch.arange(50, dtype=torch.float).reshape(1, 2, 5, 5).requires_grad_()
for shown in (x.shape, x):
    print(shown)

# Weights and bias come from PyTorch's default (random) initialisation.
# NOTE: this must run before torch.rand below so the RNG draws stay in order.
conv = nn.Conv2d(2, 1, 3)
y = conv(x)
print(y)
print(y.shape)

# Upstream gradient dL/dy, shaped like y: (1, 1, 3, 3).
o = torch.rand(1, 1, 3, 3)
y.backward(o)
for grad in (x.grad, conv.weight.grad, conv.bias.grad):
    print(grad)
The output is:
torch.Size([1, 2, 5, 5])
tensor([[[[ 0., 1., 2., 3., 4.],
[ 5., 6., 7., 8., 9.],
[10., 11., 12., 13., 14.],
[15., 16., 17., 18., 19.],
[20., 21., 22., 23., 24.]],
[[25., 26., 27., 28., 29.],
[30., 31., 32., 33., 34.],
[35., 36., 37., 38., 39.],
[40., 41., 42., 43., 44.],
[45., 46., 47., 48., 49.]]]], requires_grad=True)
tensor([[[[17.1784, 17.8333, 18.4882],
[20.4529, 21.1078, 21.7627],
[23.7273, 24.3822, 25.0371]]]], grad_fn=<ThnnConv2DBackward>)
torch.Size([1, 1, 3, 3])
tensor([[[[-1.8901e-02, -5.3357e-02, 2.5044e-02, 1.3101e-01, 3.9638e-02],
[ 4.0727e-02, 2.3310e-02, -1.3061e-01, -8.1674e-02, -3.5341e-02],
[ 1.6612e-02, 4.8658e-03, 5.5961e-02, 1.7897e-01, 1.2505e-01],
[ 1.1835e-01, 6.1245e-02, -8.6392e-02, -1.7587e-01, -9.4656e-02],
[ 8.3893e-05, 8.6943e-02, 2.1793e-01, 2.0295e-01, 8.1738e-02]],
[[-2.6953e-02, 1.3573e-02, 1.5071e-01, 6.0323e-02, 4.0366e-03],
[ 2.5628e-02, 2.0234e-01, 3.0339e-01, 7.0121e-02, 4.1564e-03],
[-7.2647e-02, 1.1603e-01, 3.6021e-01, 2.5557e-01, 3.5036e-02],
[ 4.1276e-02, 1.9398e-01, 3.5394e-01, 1.8174e-01, 1.2092e-02],
[-1.3326e-01, -8.9466e-02, 8.0237e-02, 1.6254e-01, 6.0368e-02]]]])
tensor([[[[ 32.0852, 36.8167, 41.5482],
[ 55.7426, 60.4741, 65.2056],
[ 79.4001, 84.1316, 88.8631]],
[[150.3724, 155.1039, 159.8354],
[174.0299, 178.7614, 183.4928],
[197.6873, 202.4188, 207.1503]]]])
tensor([4.7315])
Test code for my implementation of the convolution operation:
# Repeat the PyTorch experiment with the hand-written Conv_fast layer.
# NOTE: this script reuses `conv` (the nn.Conv2d) and `o` (the upstream
# gradient) created by the PyTorch script above; run that script first.
import numpy as np

from Layers import *

# Same 0..49 input in (N, C, H, W) layout. np.float was deprecated in NumPy
# 1.20 and removed in 1.24; np.float64 is the type it aliased.
x = np.arange(50, dtype=np.float64).reshape(1, 2, 5, 5)
print(x)
convf = Conv_fast(2, 1, 3)
# Copy PyTorch's parameters into the custom layer so both use identical weights.
convf.K = conv.weight.detach().numpy()
convf.b = conv.bias.detach().numpy()
print(convf(x))
# Feed the same upstream gradient that was passed to y.backward(o).
do = o.detach().numpy()
dx = convf.backward(do)
print(dx)
# Presumably the weight and bias gradients, to compare with conv.weight.grad /
# conv.bias.grad above. TODO: confirm the ordering in Layers.Conv_fast.
print(convf.grads[0])
print(convf.grads[1])
The output is:
[[[[ 0. 1. 2. 3. 4.]
[ 5. 6. 7. 8. 9.]
[10. 11. 12. 13. 14.]
[15. 16. 17. 18. 19.]
[20. 21. 22. 23. 24.]]
[[25. 26. 27. 28. 29.]
[30. 31. 32. 33. 34.]
[35. 36. 37. 38. 39.]
[40. 41. 42. 43. 44.]
[45. 46. 47. 48. 49.]]]]
[[[[17.17845082 17.83333966 18.4882285 ]
[20.45289502 21.10778385 21.76267269]
[23.72733921 24.38222805 25.03711689]]]]
[[[[-1.89012289e-02 -5.33568785e-02 2.50437073e-02 1.31009877e-01
3.96380052e-02]
[ 4.07270789e-02 2.33095456e-02 -1.30606547e-01 -8.16744715e-02
-3.53411548e-02]
[ 1.66122355e-02 4.86576185e-03 5.59609681e-02 1.78969383e-01
1.25054792e-01]
[ 1.18346602e-01 6.12446666e-02 -8.63919258e-02 -1.75870165e-01
-9.46561024e-02]
[ 8.38925480e-05 8.69431645e-02 2.17928350e-01 2.02951357e-01
8.17382038e-02]]
[[-2.69525535e-02 1.35725178e-02 1.50713146e-01 6.03230894e-02
4.03656950e-03]
[ 2.56275106e-02 2.02344760e-01 3.03391725e-01 7.01210722e-02
4.15636040e-03]
[-7.26472586e-02 1.16026394e-01 3.60213071e-01 2.55567610e-01
3.50362472e-02]
[ 4.12757248e-02 1.93978712e-01 3.53935063e-01 1.81737542e-01
1.20918062e-02]
[-1.33261606e-01 -8.94659385e-02 8.02367032e-02 1.62543684e-01
6.03683963e-02]]]]
[[[[ 32.08518082 36.81667 41.54815918]
[ 55.74262673 60.47411591 65.20560509]
[ 79.40007263 84.13156182 88.863051 ]]
[[150.37241036 155.10389954 159.83538872]
[174.02985626 178.76134545 183.49283463]
[197.68730217 202.41879135 207.15028054]]]]
[[4.73148918]]
您的打赏是对我最大的鼓励!