Migrate output layer

This commit is contained in:
2024-09-10 19:20:00 +02:00
parent f7b525e494
commit 74f49d6a00
6 changed files with 34 additions and 12 deletions

View File

@@ -19,7 +19,7 @@
#include "input.hpp" #include "input.hpp"
#include "layer.hpp" #include "layer.hpp"
#include "max_pooling.hpp" #include "max_pooling.hpp"
#include "output.cuh" #include "output.hpp"
// Models // Models
#include "model.hpp" #include "model.hpp"

View File

@@ -46,6 +46,12 @@ class Output : public SequentialLayer {
private: private:
int inputSize; int inputSize;
float* h_output; float* h_output;
float* forwardCPU(const float* input);
#ifdef USE_CUDA
float* forwardCUDA(const float* input);
#endif
}; };
} // namespace CUDANet::Layers } // namespace CUDANet::Layers

View File

@@ -8,7 +8,7 @@
#include "input.hpp" #include "input.hpp"
#include "layer.hpp" #include "layer.hpp"
#include "module.hpp" #include "module.hpp"
#include "output.cuh" #include "output.hpp"
namespace CUDANet { namespace CUDANet {

View File

@@ -0,0 +1,14 @@
#include "output.hpp"
#include "cuda_helper.cuh"
using namespace CUDANet::Layers;
float* Output::forwardCUDA(const float* input) {
CUDA_CHECK(cudaMemcpy(
h_output, input, sizeof(float) * inputSize, cudaMemcpyDeviceToHost
));
CUDA_CHECK(cudaDeviceSynchronize());
return h_output;
}

View File

@@ -1,6 +1,5 @@
#include "output.cuh" #include "output.hpp"
#include <stdexcept>
#include "cuda_helper.cuh"
using namespace CUDANet::Layers; using namespace CUDANet::Layers;
@@ -13,13 +12,16 @@ Output::~Output() {
free(h_output); free(h_output);
} }
float* Output::forward(const float* input) { float* Output::forwardCPU(const float* input) {
CUDA_CHECK(cudaMemcpy( throw std::logic_error("Not implemented");
h_output, input, sizeof(float) * inputSize, cudaMemcpyDeviceToHost }
));
CUDA_CHECK(cudaDeviceSynchronize());
return h_output; float* Output::forward(const float* input) {
#ifdef USE_CUDA
return forwardCUDA(input);
#else
return forwardCPU(input);
#endif
} }
int Output::getOutputSize() { int Output::getOutputSize() {

View File

@@ -1,7 +1,7 @@
#include <cuda_runtime_api.h> #include <cuda_runtime_api.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "output.cuh" #include "output.hpp"
TEST(OutputLayerTest, OutputForward) { TEST(OutputLayerTest, OutputForward) {
cudaError_t cudaStatus; cudaError_t cudaStatus;