mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-11-06 01:34:22 +00:00
Implement concat layer
This commit is contained in:
@@ -14,6 +14,7 @@ set(LIBRARY_SOURCES
|
|||||||
src/kernels/matmul.cu
|
src/kernels/matmul.cu
|
||||||
src/layers/dense.cu
|
src/layers/dense.cu
|
||||||
src/layers/conv2d.cu
|
src/layers/conv2d.cu
|
||||||
|
src/layers/concat.cu
|
||||||
src/layers/input.cu
|
src/layers/input.cu
|
||||||
src/layers/activation.cu
|
src/layers/activation.cu
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ Convolutional Neural Network inference library running on CUDA.
|
|||||||
- [x] Conv2d layer
|
- [x] Conv2d layer
|
||||||
- [ ] Max pooling
|
- [ ] Max pooling
|
||||||
- [ ] Average pooling
|
- [ ] Average pooling
|
||||||
- [ ] Concat layer
|
- [x] Concat layer
|
||||||
- [x] Sigmoid activation
|
- [x] Sigmoid activation
|
||||||
- [x] ReLU activation
|
- [x] ReLU activation
|
||||||
- [x] Softmax activation
|
- [x] Softmax activation
|
||||||
|
|||||||
45
include/layers/concat.cuh
Normal file
45
include/layers/concat.cuh
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
#ifndef CUDANET_CONCAT_LAYER_H
|
||||||
|
#define CUDANET_CONCAT_LAYER_H
|
||||||
|
|
||||||
|
#include <ilayer.cuh>
|
||||||
|
|
||||||
|
namespace CUDANet::Layers {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Concatenate layers
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
class Concat {
|
||||||
|
public:
|
||||||
|
/**
|
||||||
|
* @brief Create a new Concat layer
|
||||||
|
*
|
||||||
|
* @param layers Layers to concatenate
|
||||||
|
*/
|
||||||
|
Concat(const unsigned int inputASize, const unsigned int inputBSize);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Destroy the Concat layer
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
~Concat();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Forward pass of the concat layer
|
||||||
|
*
|
||||||
|
* @param d_input_A Device pointer to the first input
|
||||||
|
* @param d_input_B Device pointer to the second input
|
||||||
|
* @return Device pointer to the output
|
||||||
|
*/
|
||||||
|
float* forward(const float* d_input_A, const float* d_input_B);
|
||||||
|
|
||||||
|
private:
|
||||||
|
unsigned int inputASize;
|
||||||
|
unsigned int inputBSize;
|
||||||
|
|
||||||
|
float* d_output;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace CUDANet::Layers
|
||||||
|
|
||||||
|
#endif // CUDANET_CONCAT_LAYER_H
|
||||||
@@ -64,18 +64,6 @@ class ILayer {
|
|||||||
* @brief Copy the weights and biases to the device
|
* @brief Copy the weights and biases to the device
|
||||||
*/
|
*/
|
||||||
virtual void toCuda() = 0;
|
virtual void toCuda() = 0;
|
||||||
|
|
||||||
int inputSize;
|
|
||||||
int outputSize;
|
|
||||||
|
|
||||||
float* d_output;
|
|
||||||
|
|
||||||
float* d_weights;
|
|
||||||
float* d_biases;
|
|
||||||
|
|
||||||
std::vector<float> weights;
|
|
||||||
std::vector<float> biases;
|
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace CUDANet::Layers
|
} // namespace CUDANet::Layers
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ namespace CUDANet::Layers {
|
|||||||
* @brief Input layer, just copies the input to the device
|
* @brief Input layer, just copies the input to the device
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
class Input : public ILayer {
|
class Input {
|
||||||
public:
|
public:
|
||||||
/**
|
/**
|
||||||
* @brief Create a new Input layer
|
* @brief Create a new Input layer
|
||||||
@@ -32,15 +32,7 @@ class Input : public ILayer {
|
|||||||
*/
|
*/
|
||||||
float* forward(const float* input);
|
float* forward(const float* input);
|
||||||
|
|
||||||
void setWeights(const float* weights);
|
|
||||||
void setBiases(const float* biases);
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void initializeWeights();
|
|
||||||
void initializeBiases();
|
|
||||||
|
|
||||||
void toCuda();
|
|
||||||
|
|
||||||
int inputSize;
|
int inputSize;
|
||||||
float* d_output;
|
float* d_output;
|
||||||
};
|
};
|
||||||
|
|||||||
32
src/layers/concat.cu
Normal file
32
src/layers/concat.cu
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
#include "concat.cuh"
|
||||||
|
#include "cuda_helper.cuh"
|
||||||
|
|
||||||
|
using namespace CUDANet;
|
||||||
|
|
||||||
|
|
||||||
|
Layers::Concat::Concat(const unsigned int inputASize, const unsigned int inputBSize)
|
||||||
|
: inputASize(inputASize), inputBSize(inputBSize) {
|
||||||
|
|
||||||
|
d_output = nullptr;
|
||||||
|
CUDA_CHECK(cudaMalloc(
|
||||||
|
(void**)&d_output, sizeof(float) * (inputASize + inputBSize)
|
||||||
|
));
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
Layers::Concat::~Concat() {
|
||||||
|
cudaFree(d_output);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
float* Layers::Concat::forward(const float* d_input_A, const float* d_input_B) {
|
||||||
|
CUDA_CHECK(cudaMemcpy(
|
||||||
|
d_output, d_input_A, sizeof(float) * inputASize, cudaMemcpyDeviceToDevice
|
||||||
|
));
|
||||||
|
CUDA_CHECK(cudaMemcpy(
|
||||||
|
d_output + inputASize, d_input_B,
|
||||||
|
sizeof(float) * inputBSize, cudaMemcpyDeviceToDevice
|
||||||
|
));
|
||||||
|
|
||||||
|
return d_output;
|
||||||
|
}
|
||||||
@@ -26,11 +26,3 @@ float* Layers::Input::forward(const float* input) {
|
|||||||
|
|
||||||
return d_output;
|
return d_output;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Layers::Input::setWeights(const float* weights) {}
|
|
||||||
void Layers::Input::setBiases(const float* biases) {}
|
|
||||||
|
|
||||||
void Layers::Input::initializeWeights() {}
|
|
||||||
void Layers::Input::initializeBiases() {}
|
|
||||||
|
|
||||||
void Layers::Input::toCuda() {}
|
|
||||||
@@ -3,10 +3,11 @@ include_directories(${GTEST_INCLUDE_DIRS})
|
|||||||
|
|
||||||
add_executable(test_main
|
add_executable(test_main
|
||||||
EXCLUDE_FROM_ALL
|
EXCLUDE_FROM_ALL
|
||||||
layers/test_dense.cu
|
|
||||||
layers/test_conv2d.cu
|
|
||||||
layers/test_input.cu
|
|
||||||
layers/test_activation.cu
|
layers/test_activation.cu
|
||||||
|
layers/test_concat.cu
|
||||||
|
layers/test_conv2d.cu
|
||||||
|
layers/test_dense.cu
|
||||||
|
layers/test_input.cu
|
||||||
kernels/test_activation_functions.cu
|
kernels/test_activation_functions.cu
|
||||||
kernels/test_padding.cu
|
kernels/test_padding.cu
|
||||||
kernels/test_matmul.cu
|
kernels/test_matmul.cu
|
||||||
|
|||||||
37
test/layers/test_concat.cu
Normal file
37
test/layers/test_concat.cu
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
#include "concat.cuh"
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
TEST(ConcatLayerTest, Init) {
|
||||||
|
std::vector<float> inputA = {0.573f, 0.619f, 0.732f, 0.055f, 0.243f};
|
||||||
|
std::vector<float> inputB = {0.123f, 0.321f, 0.456f, 0.789f, 0.654f, 0.123f};
|
||||||
|
|
||||||
|
CUDANet::Layers::Concat concat(5, 6);
|
||||||
|
|
||||||
|
float* d_inputA;
|
||||||
|
float* d_inputB;
|
||||||
|
cudaMalloc((void**)&d_inputA, sizeof(float) * 5);
|
||||||
|
cudaMalloc((void**)&d_inputB, sizeof(float) * 6);
|
||||||
|
cudaMemcpy(
|
||||||
|
d_inputA, inputA.data(), sizeof(float) * 5, cudaMemcpyHostToDevice
|
||||||
|
);
|
||||||
|
cudaMemcpy(
|
||||||
|
d_inputB, inputB.data(), sizeof(float) * 6, cudaMemcpyHostToDevice
|
||||||
|
);
|
||||||
|
|
||||||
|
float* d_output = concat.forward(d_inputA, d_inputB);
|
||||||
|
|
||||||
|
std::vector<float> output(11);
|
||||||
|
cudaMemcpy(
|
||||||
|
output.data(), d_output, sizeof(float) * 11, cudaMemcpyDeviceToHost
|
||||||
|
);
|
||||||
|
|
||||||
|
for (int i = 0; i < 5; ++i) {
|
||||||
|
EXPECT_EQ(output[i], inputA[i]);
|
||||||
|
}
|
||||||
|
for (int i = 0; i < 6; ++i) {
|
||||||
|
EXPECT_EQ(output[i + 5], inputB[i]);
|
||||||
|
}
|
||||||
|
cudaFree(d_output);
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user