Update BatchNorm2d to return sizes for running mean and var

This commit is contained in:
2025-11-23 20:48:41 +01:00
parent 9f1a56c699
commit 4161caf3e1
3 changed files with 10 additions and 10 deletions

View File

@@ -30,11 +30,11 @@ class BatchNorm2d : public Layer {
void set_running_mean(void* input);
CUDANet::Tensor& get_running_mean();
size_t get_running_mean_size();
void set_running_var(void* input);
CUDANet::Tensor& get_running_var();
size_t get_running_var_size();
private:
CUDANet::Shape in_shape;