Add support for non square matrices

This commit is contained in:
2024-05-20 15:20:43 +02:00
parent 6f8b5f4081
commit 74098b24e3
21 changed files with 314 additions and 299 deletions

View File

@@ -10,7 +10,7 @@ namespace CUDANet::Layers {
class BatchNorm2D : public WeightedLayer {
public:
BatchNorm2D(int inputSize, int inputChannels, float epsilon, ActivationType activationType);
BatchNorm2D(dim2d inputSize, int inputChannels, float epsilon, ActivationType activationType);
~BatchNorm2D();
@@ -66,7 +66,7 @@ class BatchNorm2D : public WeightedLayer {
private:
int inputSize;
dim2d inputSize;
int inputChannels;
int gridSize;