mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-11-05 17:34:21 +00:00
Implement modules
This commit is contained in:
@@ -50,6 +50,18 @@ float* Model::predict(const float* input) {
|
||||
}
|
||||
|
||||
void Model::addLayer(const std::string& name, Layers::SequentialLayer* layer) {
|
||||
|
||||
Module* module = dynamic_cast<Module*>(layer);
|
||||
|
||||
if (module != nullptr) {
|
||||
layers.push_back({ name, module });
|
||||
for (const auto& moduleLayer : module->getLayers()) {
|
||||
layerMap[moduleLayer.first] = moduleLayer.second;
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
layers.push_back({ name, layer });
|
||||
layerMap[name] = layer;
|
||||
}
|
||||
|
||||
49
src/model/module.cpp
Normal file
49
src/model/module.cpp
Normal file
@@ -0,0 +1,49 @@
|
||||
#include "module.hpp"
|
||||
|
||||
#include "cuda_helper.cuh"
|
||||
|
||||
using namespace CUDANet;
|
||||
|
||||
Module::Module(
|
||||
const int inputSize,
|
||||
const int inputChannels,
|
||||
const int outputSize,
|
||||
const int outputChannels
|
||||
)
|
||||
: inputSize(inputSize),
|
||||
inputChannels(inputChannels),
|
||||
outputSize(outputSize),
|
||||
outputChannels(outputChannels),
|
||||
layers(std::vector<std::pair<std::string, Layers::SequentialLayer*>>()),
|
||||
layerMap(std::unordered_map<std::string, Layers::SequentialLayer*>()) {
|
||||
d_output = nullptr;
|
||||
CUDA_CHECK(cudaMalloc(
|
||||
(void**)&d_output,
|
||||
sizeof(float) * outputSize * outputSize * outputChannels
|
||||
));
|
||||
}
|
||||
|
||||
Module::~Module() {
|
||||
cudaFree(d_output);
|
||||
}
|
||||
|
||||
void Module::addLayer(const std::string& name, Layers::SequentialLayer* layer) {
|
||||
layers.push_back({ name, layer });
|
||||
layerMap[name] = layer;
|
||||
}
|
||||
|
||||
Layers::SequentialLayer* Module::getLayer(const std::string& name) {
|
||||
return layerMap[name];
|
||||
}
|
||||
|
||||
const std::unordered_map<std::string, Layers::SequentialLayer*>& Module::getLayers() const {
|
||||
return layerMap;
|
||||
}
|
||||
|
||||
int Module::getInputSize() {
|
||||
return inputSize * inputSize * inputChannels;
|
||||
}
|
||||
|
||||
int Module::getOutputSize() {
|
||||
return outputSize * outputSize * outputChannels;
|
||||
}
|
||||
Reference in New Issue
Block a user