mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-12-22 06:14:22 +00:00
Update BatchNorm2d to return sizes for running mean and var
This commit is contained in:
@@ -30,11 +30,11 @@ class BatchNorm2d : public Layer {
|
||||
|
||||
void set_running_mean(void* input);
|
||||
|
||||
CUDANet::Tensor& get_running_mean();
|
||||
size_t get_running_mean_size();
|
||||
|
||||
void set_running_var(void* input);
|
||||
|
||||
CUDANet::Tensor& get_running_var();
|
||||
size_t get_running_var_size();
|
||||
|
||||
private:
|
||||
CUDANet::Shape in_shape;
|
||||
|
||||
@@ -90,14 +90,14 @@ void BatchNorm2d::set_running_mean(void* input) {
|
||||
running_mean.set_data<float>(static_cast<float*>(input));
|
||||
}
|
||||
|
||||
CUDANet::Tensor& BatchNorm2d::get_running_mean() {
|
||||
return running_mean;
|
||||
size_t BatchNorm2d::get_running_mean_size() {
|
||||
return running_mean.size();
|
||||
}
|
||||
|
||||
void BatchNorm2d::set_running_var(void* input) {
|
||||
running_var.set_data<float>(static_cast<float*>(input));
|
||||
}
|
||||
|
||||
CUDANet::Tensor& BatchNorm2d::get_running_var() {
|
||||
return running_var;
|
||||
size_t BatchNorm2d::get_running_var_size() {
|
||||
return running_var.size();
|
||||
}
|
||||
@@ -155,16 +155,16 @@ void Model::load_weights(const std::string& path) {
|
||||
}
|
||||
|
||||
if (tensor_info.type == TensorType::RUNNING_MEAN) {
|
||||
if (bn_layer->get_running_mean().size() != values.size()) {
|
||||
if (bn_layer->get_running_mean_size() != values.size()) {
|
||||
std::cerr << "Layer: " << tensor_info.name << " has incorrect number of running mean values, expected "
|
||||
<< bn_layer->get_running_mean().size() << " but got " << values.size() << ", skipping" << std::endl;
|
||||
<< bn_layer->get_running_mean_size() << " but got " << values.size() << ", skipping" << std::endl;
|
||||
continue;
|
||||
}
|
||||
bn_layer->set_running_mean(values.data());
|
||||
} else if (tensor_info.type == TensorType::RUNNING_VAR) {
|
||||
if (bn_layer->get_running_var().size() != values.size()) {
|
||||
if (bn_layer->get_running_var_size() != values.size()) {
|
||||
std::cerr << "Layer: " << tensor_info.name << " has incorrect number of running var values, expected "
|
||||
<< bn_layer->get_running_var().size() << " but got " << values.size() << ", skipping" << std::endl;
|
||||
<< bn_layer->get_running_var_size() << " but got " << values.size() << ", skipping" << std::endl;
|
||||
continue;
|
||||
}
|
||||
bn_layer->set_running_var(values.data());
|
||||
|
||||
Reference in New Issue
Block a user