mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-11-06 01:34:22 +00:00
Implement simple input layer
This commit is contained in:
@@ -7,7 +7,6 @@ project(CUDANet
|
|||||||
find_package(CUDAToolkit REQUIRED)
|
find_package(CUDAToolkit REQUIRED)
|
||||||
include_directories(${CUDAToolkit_INCLUDE_DIRS})
|
include_directories(${CUDAToolkit_INCLUDE_DIRS})
|
||||||
|
|
||||||
# Add project source files for the library
|
|
||||||
set(LIBRARY_SOURCES
|
set(LIBRARY_SOURCES
|
||||||
src/utils/cuda_helper.cu
|
src/utils/cuda_helper.cu
|
||||||
src/kernels/activations.cu
|
src/kernels/activations.cu
|
||||||
@@ -15,6 +14,7 @@ set(LIBRARY_SOURCES
|
|||||||
src/kernels/matmul.cu
|
src/kernels/matmul.cu
|
||||||
src/layers/dense.cu
|
src/layers/dense.cu
|
||||||
src/layers/conv2d.cu
|
src/layers/conv2d.cu
|
||||||
|
src/layers/input.cu
|
||||||
)
|
)
|
||||||
|
|
||||||
set(CMAKE_CUDA_ARCHITECTURES 75)
|
set(CMAKE_CUDA_ARCHITECTURES 75)
|
||||||
|
|||||||
30
include/layers/input.cuh
Normal file
30
include/layers/input.cuh
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
#ifndef INPUT_LAYER_H
|
||||||
|
#define INPUT_LAYER_H
|
||||||
|
|
||||||
|
#include <ilayer.cuh>
|
||||||
|
|
||||||
|
namespace Layers {
|
||||||
|
|
||||||
|
class Input : public ILayer {
|
||||||
|
public:
|
||||||
|
Input(int inputSize);
|
||||||
|
~Input();
|
||||||
|
|
||||||
|
float* forward(const float* input);
|
||||||
|
|
||||||
|
void setWeights(const float* weights);
|
||||||
|
void setBiases(const float* biases);
|
||||||
|
|
||||||
|
private:
|
||||||
|
void initializeWeights();
|
||||||
|
void initializeBiases();
|
||||||
|
|
||||||
|
void toCuda();
|
||||||
|
|
||||||
|
int inputSize;
|
||||||
|
float* d_output;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace Layers
|
||||||
|
|
||||||
|
#endif // INPUT_LAYER_H
|
||||||
34
src/layers/input.cu
Normal file
34
src/layers/input.cu
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
#include "cuda_helper.cuh"
|
||||||
|
#include "input.cuh"
|
||||||
|
|
||||||
|
Layers::Input::Input(int inputSize) : inputSize(inputSize) {
|
||||||
|
d_output = nullptr;
|
||||||
|
CUDA_CHECK(cudaMalloc((void**)&d_output, sizeof(float) * inputSize));
|
||||||
|
}
|
||||||
|
|
||||||
|
Layers::Input::~Input() {
|
||||||
|
cudaFree(d_output);
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
Copies host input to device d_output
|
||||||
|
|
||||||
|
Args
|
||||||
|
const float* input Host pointer to input data
|
||||||
|
float* d_output Device pointer to input data copied to device
|
||||||
|
*/
|
||||||
|
float* Layers::Input::forward(const float* input) {
|
||||||
|
CUDA_CHECK(cudaMemcpy(
|
||||||
|
d_output, input, sizeof(float) * inputSize, cudaMemcpyHostToDevice
|
||||||
|
));
|
||||||
|
|
||||||
|
return d_output;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Layers::Input::setWeights(const float* weights) {}
|
||||||
|
void Layers::Input::setBiases(const float* biases) {}
|
||||||
|
|
||||||
|
void Layers::Input::initializeWeights() {}
|
||||||
|
void Layers::Input::initializeBiases() {}
|
||||||
|
|
||||||
|
void Layers::Input::toCuda() {}
|
||||||
@@ -4,6 +4,7 @@ include_directories(${GTEST_INCLUDE_DIRS})
|
|||||||
add_executable(test_main
|
add_executable(test_main
|
||||||
layers/test_dense.cu
|
layers/test_dense.cu
|
||||||
layers/test_conv2d.cu
|
layers/test_conv2d.cu
|
||||||
|
layers/test_input.cu
|
||||||
kernels/test_activations.cu
|
kernels/test_activations.cu
|
||||||
kernels/test_padding.cu
|
kernels/test_padding.cu
|
||||||
)
|
)
|
||||||
|
|||||||
16
test/layers/test_input.cu
Normal file
16
test/layers/test_input.cu
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include "input.cuh"
|
||||||
|
#include "cuda_helper.cuh"
|
||||||
|
|
||||||
|
|
||||||
|
TEST(InputLayerTest, Init) {
|
||||||
|
std::vector<float> input = {0.573f, 0.619f, 0.732f, 0.055f, 0.243f, 0.316f};
|
||||||
|
Layers::Input inputLayer(6);
|
||||||
|
float* d_output = inputLayer.forward(input.data());
|
||||||
|
|
||||||
|
std::vector<float> output(6);
|
||||||
|
CUDA_CHECK(cudaMemcpy(output.data(), d_output, sizeof(float) * 6, cudaMemcpyDeviceToHost));
|
||||||
|
EXPECT_EQ(input, output);
|
||||||
|
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user