Update BatchNorm2d to return sizes for running mean and var

This commit is contained in:
2025-11-23 20:48:41 +01:00
parent 9f1a56c699
commit 4161caf3e1
3 changed files with 10 additions and 10 deletions

View File

@@ -90,14 +90,14 @@ void BatchNorm2d::set_running_mean(void* input) {
running_mean.set_data<float>(static_cast<float*>(input));
}
CUDANet::Tensor& BatchNorm2d::get_running_mean() {
return running_mean;
size_t BatchNorm2d::get_running_mean_size() {
return running_mean.size();
}
void BatchNorm2d::set_running_var(void* input) {
running_var.set_data<float>(static_cast<float*>(input));
}
CUDANet::Tensor& BatchNorm2d::get_running_var() {
return running_var;
size_t BatchNorm2d::get_running_var_size() {
return running_var.size();
}

View File

@@ -155,16 +155,16 @@ void Model::load_weights(const std::string& path) {
}
if (tensor_info.type == TensorType::RUNNING_MEAN) {
if (bn_layer->get_running_mean().size() != values.size()) {
if (bn_layer->get_running_mean_size() != values.size()) {
std::cerr << "Layer: " << tensor_info.name << " has incorrect number of running mean values, expected "
<< bn_layer->get_running_mean().size() << " but got " << values.size() << ", skipping" << std::endl;
<< bn_layer->get_running_mean_size() << " but got " << values.size() << ", skipping" << std::endl;
continue;
}
bn_layer->set_running_mean(values.data());
} else if (tensor_info.type == TensorType::RUNNING_VAR) {
if (bn_layer->get_running_var().size() != values.size()) {
if (bn_layer->get_running_var_size() != values.size()) {
std::cerr << "Layer: " << tensor_info.name << " has incorrect number of running var values, expected "
<< bn_layer->get_running_var().size() << " but got " << values.size() << ", skipping" << std::endl;
<< bn_layer->get_running_var_size() << " but got " << values.size() << ", skipping" << std::endl;
continue;
}
bn_layer->set_running_var(values.data());