diff --git a/include/layers/batch_norm.hpp b/include/layers/batch_norm.hpp
index 601e443..1b605e9 100644
--- a/include/layers/batch_norm.hpp
+++ b/include/layers/batch_norm.hpp
@@ -30,11 +30,11 @@ class BatchNorm2d : public Layer {
 
     void set_running_mean(void* input);
 
-    CUDANet::Tensor& get_running_mean();
+    size_t get_running_mean_size();
 
     void set_running_var(void* input);
 
-    CUDANet::Tensor& get_running_var();
+    size_t get_running_var_size();
 
   private:
     CUDANet::Shape in_shape;
diff --git a/src/layers/batch_norm.cpp b/src/layers/batch_norm.cpp
index cd2293d..213c16b 100644
--- a/src/layers/batch_norm.cpp
+++ b/src/layers/batch_norm.cpp
@@ -90,14 +90,14 @@ void BatchNorm2d::set_running_mean(void* input) {
     running_mean.set_data(static_cast<float*>(input));
 }
 
-CUDANet::Tensor& BatchNorm2d::get_running_mean() {
-    return running_mean;
+size_t BatchNorm2d::get_running_mean_size() {
+    return running_mean.size();
 }
 
 void BatchNorm2d::set_running_var(void* input) {
     running_var.set_data(static_cast<float*>(input));
 }
 
-CUDANet::Tensor& BatchNorm2d::get_running_var() {
-    return running_var;
+size_t BatchNorm2d::get_running_var_size() {
+    return running_var.size();
 }
\ No newline at end of file
diff --git a/src/model.cpp b/src/model.cpp
index ace690e..d04adc0 100644
--- a/src/model.cpp
+++ b/src/model.cpp
@@ -155,16 +155,16 @@ void Model::load_weights(const std::string& path) {
         }
 
         if (tensor_info.type == TensorType::RUNNING_MEAN) {
-            if (bn_layer->get_running_mean().size() != values.size()) {
+            if (bn_layer->get_running_mean_size() != values.size()) {
                 std::cerr << "Layer: " << tensor_info.name << " has incorrect number of running mean values, expected "
-                          << bn_layer->get_running_mean().size() << " but got " << values.size() << ", skipping" << std::endl;
+                          << bn_layer->get_running_mean_size() << " but got " << values.size() << ", skipping" << std::endl;
                 continue;
             }
             bn_layer->set_running_mean(values.data());
         } else if (tensor_info.type == TensorType::RUNNING_VAR) {
-            if (bn_layer->get_running_var().size() != values.size()) {
+            if (bn_layer->get_running_var_size() != values.size()) {
                 std::cerr << "Layer: " << tensor_info.name << " has incorrect number of running var values, expected "
-                          << bn_layer->get_running_var().size() << " but got " << values.size() << ", skipping" << std::endl;
+                          << bn_layer->get_running_var_size() << " but got " << values.size() << ", skipping" << std::endl;
                 continue;
             }
             bn_layer->set_running_var(values.data());