Update BatchNorm2d to return sizes for running mean and var

This commit is contained in:
2025-11-23 20:48:41 +01:00
parent 9f1a56c699
commit 4161caf3e1
3 changed files with 10 additions and 10 deletions

View File

@@ -30,11 +30,11 @@ class BatchNorm2d : public Layer {
void set_running_mean(void* input); void set_running_mean(void* input);
CUDANet::Tensor& get_running_mean(); size_t get_running_mean_size();
void set_running_var(void* input); void set_running_var(void* input);
CUDANet::Tensor& get_running_var(); size_t get_running_var_size();
private: private:
CUDANet::Shape in_shape; CUDANet::Shape in_shape;

View File

@@ -90,14 +90,14 @@ void BatchNorm2d::set_running_mean(void* input) {
running_mean.set_data<float>(static_cast<float*>(input)); running_mean.set_data<float>(static_cast<float*>(input));
} }
CUDANet::Tensor& BatchNorm2d::get_running_mean() { size_t BatchNorm2d::get_running_mean_size() {
return running_mean; return running_mean.size();
} }
void BatchNorm2d::set_running_var(void* input) { void BatchNorm2d::set_running_var(void* input) {
running_var.set_data<float>(static_cast<float*>(input)); running_var.set_data<float>(static_cast<float*>(input));
} }
CUDANet::Tensor& BatchNorm2d::get_running_var() { size_t BatchNorm2d::get_running_var_size() {
return running_var; return running_var.size();
} }

View File

@@ -155,16 +155,16 @@ void Model::load_weights(const std::string& path) {
} }
if (tensor_info.type == TensorType::RUNNING_MEAN) { if (tensor_info.type == TensorType::RUNNING_MEAN) {
if (bn_layer->get_running_mean().size() != values.size()) { if (bn_layer->get_running_mean_size() != values.size()) {
std::cerr << "Layer: " << tensor_info.name << " has incorrect number of running mean values, expected " std::cerr << "Layer: " << tensor_info.name << " has incorrect number of running mean values, expected "
<< bn_layer->get_running_mean().size() << " but got " << values.size() << ", skipping" << std::endl; << bn_layer->get_running_mean_size() << " but got " << values.size() << ", skipping" << std::endl;
continue; continue;
} }
bn_layer->set_running_mean(values.data()); bn_layer->set_running_mean(values.data());
} else if (tensor_info.type == TensorType::RUNNING_VAR) { } else if (tensor_info.type == TensorType::RUNNING_VAR) {
if (bn_layer->get_running_var().size() != values.size()) { if (bn_layer->get_running_var_size() != values.size()) {
std::cerr << "Layer: " << tensor_info.name << " has incorrect number of running var values, expected " std::cerr << "Layer: " << tensor_info.name << " has incorrect number of running var values, expected "
<< bn_layer->get_running_var().size() << " but got " << values.size() << ", skipping" << std::endl; << bn_layer->get_running_var_size() << " but got " << values.size() << ", skipping" << std::endl;
continue; continue;
} }
bn_layer->set_running_var(values.data()); bn_layer->set_running_var(values.data());