Mirror of https://github.com/lordmathis/CUDANet.git
Change forward function to return output pointer
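With this change a layer allocates its own device output buffer in the constructor (cudaMalloc), frees it in the destructor (cudaFree), and forward() returns a pointer to that buffer instead of writing into a caller-provided one. A minimal call-site sketch of the difference, using the d_input/d_output names from the tests in this diff (illustrative only, not part of the commit):

// Old API: the caller allocates d_output and passes it in.
//     cudaMalloc((void**)&d_output, sizeof(float) * outputSize);
//     denseLayer.forward(d_input, d_output);

// New API: the layer owns its output buffer and returns a device pointer.
float* d_output = denseLayer.forward(d_input);
// The buffer is released by the layer's destructor; the caller must not free it.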
@@ -26,10 +26,10 @@ class Conv2d : public ILayer {
     // Outputs
     int outputSize;

-    void forward(const float* d_input, float* d_output);
-    void setWeights(const float* weights_input);
-    void setBiases(const float* biases_input);
-    void host_conv(const float* input, float* output);
+    float* forward(const float* d_input);
+    void setWeights(const float* weights_input);
+    void setBiases(const float* biases_input);
+    void host_conv(const float* input, float* output);

   private:
     // Inputs
@@ -47,6 +47,7 @@ class Conv2d : public ILayer {
     std::vector<float> biases;

     // Cuda
+    float* d_output;
     float* d_weights;
     float* d_biases;
     float* d_padded;
@@ -18,7 +18,7 @@ class Dense : public ILayer {
     );
     ~Dense();

-    void forward(const float* d_input, float* d_output);
+    float* forward(const float* d_input);
     void setWeights(const float* weights);
     void setBiases(const float* biases);

@@ -26,6 +26,8 @@ class Dense : public ILayer {
     int inputSize;
     int outputSize;

+    float* d_output;
+
     float* d_weights;
     float* d_biases;

@@ -6,24 +6,17 @@

 namespace Layers {

-enum Activation {
-    SIGMOID,
-    RELU,
-    NONE
-};
+enum Activation { SIGMOID, RELU, NONE };

-enum Padding {
-    SAME,
-    VALID
-};
+enum Padding { SAME, VALID };

 class ILayer {
   public:
     virtual ~ILayer() {}

-    virtual void forward(const float* input, float* output) = 0;
-    virtual void setWeights(const float* weights) = 0;
-    virtual void setBiases(const float* biases) = 0;
+    virtual float* forward(const float* input) = 0;
+    virtual void setWeights(const float* weights) = 0;
+    virtual void setBiases(const float* biases) = 0;

   private:
     virtual void initializeWeights() = 0;
@@ -34,6 +27,8 @@ class ILayer {
     int inputSize;
     int outputSize;

+    float* d_output;
+
     float* d_weights;
     float* d_biases;

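Since ILayer's forward() now returns the device output pointer, one layer's output can be fed directly into the next layer. A hypothetical chaining helper, for illustration only (this function is not part of the commit or the repository; it assumes the ILayer header above is included):

#include <vector>

// Hypothetical helper: feeds each layer's returned device pointer
// into the next layer in the stack.
float* runForward(const std::vector<Layers::ILayer*>& layers, const float* d_input) {
    const float* d_current = d_input;
    float* d_out = nullptr;
    for (Layers::ILayer* layer : layers) {
        d_out = layer->forward(d_current);  // returned buffer is owned by the layer
        d_current = d_out;
    }
    return d_out;  // device pointer owned by the last layer in the stack
}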
@@ -2,6 +2,7 @@
 #define CUDA_HELPER_H

 #include <cuda_runtime.h>
+#include <cstdio>

 // CUDA error checking macro
 #define CUDA_CHECK(call) \
@@ -37,6 +37,12 @@ Layers::Conv2d::Conv2d(
             break;
     }

+    d_output = nullptr;
+    CUDA_CHECK(cudaMalloc(
+        (void**)&d_output,
+        sizeof(float) * outputSize * outputSize * numFilters
+    ));
+
     weights.resize(kernelSize * kernelSize * inputChannels * numFilters);
     initializeWeights();

@@ -64,6 +70,7 @@ Layers::Conv2d::Conv2d(
 }

 Layers::Conv2d::~Conv2d() {
+    cudaFree(d_output);
     cudaFree(d_weights);
     cudaFree(d_biases);
     cudaFree(d_padded);
@@ -101,7 +108,7 @@ void Layers::Conv2d::toCuda() {
     ));
 }

-void Layers::Conv2d::forward(const float* d_input, float* d_output) {
+float* Layers::Conv2d::forward(const float* d_input) {
     // Pad input
     int THREADS_PER_BLOCK = (inputSize + 2 * paddingSize) *
                             (inputSize + 2 * paddingSize) * inputChannels;
@@ -136,44 +143,6 @@ void Layers::Conv2d::forward(const float* d_input, float* d_output) {
     }

     CUDA_CHECK(cudaDeviceSynchronize());
+
+    return d_output;
 }
-
-/*
-Convolves input vector with kernel and stores result in output
-
-input: matrix (inputSize + paddingSize) x (inputSize + paddingSize) x
-inputChannels represented as a vector output: output matrix outputSize x
-outputSize x numFilters
-
-*/
-void Layers::Conv2d::host_conv(const float* input, float* output) {
-    // Iterate over output matrix
-    for (int tid = 0; tid < outputSize * outputSize * numFilters; tid++) {
-        // Get output index
-        int f = tid / (outputSize * outputSize);
-        int i = tid % (outputSize * outputSize) / outputSize;
-        int j = tid % outputSize;
-
-        float sum = 0.0f;
-
-        // Iterate over kernel and input matrix
-        for (int k = 0; k < kernelSize; k++) {
-            for (int l = 0; l < kernelSize; l++) {
-                for (int c = 0; c < inputChannels; c++) {
-                    int kernelIndex =
-                        f * kernelSize * kernelSize * inputChannels +
-                        c * kernelSize * kernelSize + k * kernelSize + l;
-                    int inputIndex = c * inputSize * inputSize +
-                                     (i * stride + k) * inputSize +
-                                     (j * stride + l);
-
-                    sum += weights[kernelIndex] * input[inputIndex];
-                }
-            }
-        }
-
-        int outputIndex = f * outputSize * outputSize + i * outputSize + j;
-
-        output[outputIndex] = sum;
-    }
-}
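For reference, the host_conv routine removed above (and the device path that replaces it) treat every tensor as a flat, channel-major array, as the removed comment describes. A small sketch of that index arithmetic, illustrative only and not part of the commit:

// Flattened, channel-major indexing as used above (illustrative sketch).
// An input value at (channel c, row y, col x) of an N x N map lives at
//     c * N * N + y * N + x
// and an output value at (filter f, row i, col j) lives at
//     f * outputSize * outputSize + i * outputSize + j
inline int inputIdx(int c, int y, int x, int N) { return c * N * N + y * N + x; }
inline int outputIdx(int f, int i, int j, int outputSize) {
    return f * outputSize * outputSize + i * outputSize + j;
}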
@@ -19,6 +19,10 @@ Layers::Dense::Dense(int inputSize, int outputSize, Layers::Activation activatio
     initializeWeights();
     initializeBiases();

+    d_output = nullptr;
+
+    CUDA_CHECK(cudaMalloc((void**)&d_output, sizeof(float) * outputSize));
+
     d_weights = nullptr;
     d_biases = nullptr;

@@ -33,6 +37,7 @@ Layers::Dense::Dense(int inputSize, int outputSize, Layers::Activation activatio

 Layers::Dense::~Dense() {
     // Free GPU memory
+    cudaFree(d_output);
     cudaFree(d_weights);
     cudaFree(d_biases);
 }
@@ -45,7 +50,7 @@ void Layers::Dense::initializeBiases() {
     std::fill(biases.begin(), biases.end(), 0.0f);
 }

-void Layers::Dense::forward(const float* d_input, float* d_output) {
+float* Layers::Dense::forward(const float* d_input) {
     Kernels::mat_vec_mul<<<1, outputSize>>>(
         d_weights, d_input, d_output, inputSize, outputSize
     );
@@ -68,6 +73,8 @@ void Layers::Dense::forward(const float* d_input, float* d_output) {
     }

     CUDA_CHECK(cudaDeviceSynchronize());
+
+    return d_output;
 }

 void Layers::Dense::toCuda() {
@@ -17,8 +17,7 @@ class Conv2dTest : public ::testing::Test {
         Layers::Activation activation,
         std::vector<float>& input,
         float* kernels,
-        float*& d_input,
-        float*& d_output
+        float*& d_input
     ) {
         // Create Conv2d layer
         Layers::Conv2d conv2d(
@@ -35,12 +34,6 @@ class Conv2dTest : public ::testing::Test {
         );
         EXPECT_EQ(cudaStatus, cudaSuccess);

-        cudaStatus = cudaMalloc(
-            (void**)&d_output,
-            sizeof(float) * conv2d.outputSize * conv2d.outputSize * numFilters
-        );
-        EXPECT_EQ(cudaStatus, cudaSuccess);
-
         // // Copy input to device
         cudaStatus = cudaMemcpy(
             d_input, input.data(), sizeof(float) * input.size(),
@@ -51,10 +44,9 @@ class Conv2dTest : public ::testing::Test {
         return conv2d;
     }

-    void commonTestTeardown(float* d_input, float* d_output) {
+    void commonTestTeardown(float* d_input) {
         // Free device memory
         cudaFree(d_input);
-        cudaFree(d_output);
     }

     cudaError_t cudaStatus;
@@ -84,13 +76,13 @@ TEST_F(Conv2dTest, SimpleTest) {

     Layers::Conv2d conv2d = commonTestSetup(
         inputSize, inputChannels, kernelSize, stride, padding, numFilters,
-        activation, input, kernels.data(), d_input, d_output
+        activation, input, kernels.data(), d_input
     );

     int outputSize = (inputSize - kernelSize) / stride + 1;
     EXPECT_EQ(outputSize, conv2d.outputSize);

-    conv2d.forward(d_input, d_output);
+    d_output = conv2d.forward(d_input);

     std::vector<float> expected = {44.0f, 54.0f, 64.0f, 84.0f, 94.0f,
                                    104.0f, 124.0f, 134.0f, 144.0f};
@@ -106,7 +98,7 @@ TEST_F(Conv2dTest, SimpleTest) {
         EXPECT_FLOAT_EQ(expected[i], output[i]);
     }

-    commonTestTeardown(d_input, d_output);
+    commonTestTeardown(d_input);
 }

 TEST_F(Conv2dTest, PaddedTest) {
@@ -173,12 +165,12 @@ TEST_F(Conv2dTest, PaddedTest) {

     Layers::Conv2d conv2d = commonTestSetup(
         inputSize, inputChannels, kernelSize, stride, padding, numFilters,
-        activation, input, kernels.data(), d_input, d_output
+        activation, input, kernels.data(), d_input
     );

     EXPECT_EQ(inputSize, conv2d.outputSize);

-    conv2d.forward(d_input, d_output);
+    d_output = conv2d.forward(d_input);

     std::vector<float> output(
         conv2d.outputSize * conv2d.outputSize * numFilters
@@ -192,23 +184,21 @@ TEST_F(Conv2dTest, PaddedTest) {
     // Generated by tools/generate_conv2d_test.py
     std::vector<float> expected = {
         // Channel 1
-        2.29426f, 3.89173f, 4.17634f, 3.25501f, 2.07618f,
-        5.41483f, 7.09971f, 6.39811f, 5.71432f, 3.10928f,
-        5.12973f, 6.29638f, 5.26962f, 5.21997f, 3.05852f,
-        6.17517f, 7.19311f, 6.69771f, 6.2142f, 4.03242f,
-        3.3792f, 4.36444f, 4.396f, 4.69905f, 3.62061f,
+        2.29426f, 3.89173f, 4.17634f, 3.25501f, 2.07618f, 5.41483f, 7.09971f,
+        6.39811f, 5.71432f, 3.10928f, 5.12973f, 6.29638f, 5.26962f, 5.21997f,
+        3.05852f, 6.17517f, 7.19311f, 6.69771f, 6.2142f, 4.03242f, 3.3792f,
+        4.36444f, 4.396f, 4.69905f, 3.62061f,
         // Channel 2
-        2.87914f, 3.71743f, 3.51854f, 2.98413f, 1.46579f,
-        4.94951f, 6.18983f, 4.98187f, 4.38372f, 3.35386f,
-        5.0364f, 5.3756f, 4.05993f, 4.89299f, 2.78625f,
-        5.33763f, 5.80899f, 5.89785f, 5.51095f, 3.74287f,
-        2.64053f, 4.05895f, 3.96482f, 4.30177f, 1.94269f
+        2.87914f, 3.71743f, 3.51854f, 2.98413f, 1.46579f, 4.94951f, 6.18983f,
+        4.98187f, 4.38372f, 3.35386f, 5.0364f, 5.3756f, 4.05993f, 4.89299f,
+        2.78625f, 5.33763f, 5.80899f, 5.89785f, 5.51095f, 3.74287f, 2.64053f,
+        4.05895f, 3.96482f, 4.30177f, 1.94269f
     };
     for (int i = 0; i < output.size(); i++) {
         EXPECT_NEAR(output[i], expected[i], 0.0001f);
     }

-    commonTestTeardown(d_input, d_output);
+    commonTestTeardown(d_input);
 }

 TEST_F(Conv2dTest, StridedPaddedConvolution) {
@@ -260,12 +250,12 @@ TEST_F(Conv2dTest, StridedPaddedConvolution) {

     Layers::Conv2d conv2d = commonTestSetup(
         inputSize, inputChannels, kernelSize, stride, padding, numFilters,
-        activation, input, kernels.data(), d_input, d_output
+        activation, input, kernels.data(), d_input
     );

     EXPECT_EQ(inputSize, conv2d.outputSize);

-    conv2d.forward(d_input, d_output);
+    d_output = conv2d.forward(d_input);

     std::vector<float> output(
         conv2d.outputSize * conv2d.outputSize * numFilters
@@ -279,22 +269,18 @@ TEST_F(Conv2dTest, StridedPaddedConvolution) {
     // Generated by tools/generate_conv2d_test.py
     std::vector<float> expected = {
         // Channel 1
-        0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
-        0.0f, 1.59803f, 2.84444f, 1.6201f, 0.0f,
-        0.0f, 2.38937f, 3.80762f, 3.39679f, 0.0f,
-        0.0f, 1.13102f, 2.33335f, 1.98488f, 0.0f,
-        0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
+        0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.59803f, 2.84444f, 1.6201f, 0.0f,
+        0.0f, 2.38937f, 3.80762f, 3.39679f, 0.0f, 0.0f, 1.13102f, 2.33335f,
+        1.98488f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
         // Channel 2
-        0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
-        0.0f, 2.57732f, 3.55543f, 2.24675f, 0.0f,
-        0.0f, 3.36842f, 3.41373f, 3.14804f, 0.0f,
-        0.0f, 1.17963f, 2.55005f, 1.63218f, 0.0f,
-        0.0f, 0.0f, 0.0f, 0.0f, 0.0f
+        0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 2.57732f, 3.55543f, 2.24675f, 0.0f,
+        0.0f, 3.36842f, 3.41373f, 3.14804f, 0.0f, 0.0f, 1.17963f, 2.55005f,
+        1.63218f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f
     };

     for (int i = 0; i < output.size(); i++) {
         EXPECT_NEAR(output[i], expected[i], 0.0001f);
     }

-    commonTestTeardown(d_input, d_output);
+    commonTestTeardown(d_input);
 }
@@ -15,7 +15,6 @@ class DenseLayerTest : public ::testing::Test {
         float* weights,
         float* biases,
         float*& d_input,
-        float*& d_output,
         Layers::Activation activation
     ) {
         // Create Dense layer
@@ -29,9 +28,6 @@ class DenseLayerTest : public ::testing::Test {
         cudaStatus = cudaMalloc((void**)&d_input, sizeof(float) * input.size());
         EXPECT_EQ(cudaStatus, cudaSuccess);

-        cudaStatus = cudaMalloc((void**)&d_output, sizeof(float) * outputSize);
-        EXPECT_EQ(cudaStatus, cudaSuccess);
-
         // Copy input to device
         cudaStatus = cudaMemcpy(
             d_input, input.data(), sizeof(float) * input.size(),
@@ -42,10 +38,9 @@ class DenseLayerTest : public ::testing::Test {
         return denseLayer;
     }

-    void commonTestTeardown(float* d_input, float* d_output) {
+    void commonTestTeardown(float* d_input) {
         // Free device memory
         cudaFree(d_input);
-        cudaFree(d_output);
     }

     cudaError_t cudaStatus;
@@ -106,9 +101,9 @@ TEST_F(DenseLayerTest, ForwardUnitWeightMatrixLinear) {

     Layers::Dense denseLayer = commonTestSetup(
         inputSize, outputSize, input, weights.data(), biases.data(), d_input,
-        d_output, Layers::Activation::NONE
+        Layers::Activation::NONE
     );
-    denseLayer.forward(d_input, d_output);
+    d_output = denseLayer.forward(d_input);

     std::vector<float> output(outputSize);
     cudaStatus = cudaMemcpy(
@@ -122,7 +117,7 @@ TEST_F(DenseLayerTest, ForwardUnitWeightMatrixLinear) {
     EXPECT_FLOAT_EQ(output[1], 3.0f);
     EXPECT_FLOAT_EQ(output[2], 4.0f);

-    commonTestTeardown(d_input, d_output);
+    commonTestTeardown(d_input);
 }

 TEST_F(DenseLayerTest, ForwardRandomWeightMatrixRelu) {
@@ -147,10 +142,10 @@ TEST_F(DenseLayerTest, ForwardRandomWeightMatrixRelu) {

     Layers::Dense denseLayer = commonTestSetup(
         inputSize, outputSize, input, weights.data(), biases.data(), d_input,
-        d_output, Layers::Activation::RELU
+        Layers::Activation::RELU
     );

-    denseLayer.forward(d_input, d_output);
+    d_output = denseLayer.forward(d_input);

     std::vector<float> output(outputSize);
     cudaStatus = cudaMemcpy(
@@ -169,7 +164,7 @@ TEST_F(DenseLayerTest, ForwardRandomWeightMatrixRelu) {
         ); // Allow small tolerance for floating-point comparison
     }

-    commonTestTeardown(d_input, d_output);
+    commonTestTeardown(d_input);
 }

 TEST_F(DenseLayerTest, ForwardRandomWeightMatrixSigmoid) {
@@ -192,10 +187,10 @@ TEST_F(DenseLayerTest, ForwardRandomWeightMatrixSigmoid) {

     Layers::Dense denseLayer = commonTestSetup(
         inputSize, outputSize, input, weights.data(), biases.data(), d_input,
-        d_output, Layers::Activation::SIGMOID
+        Layers::Activation::SIGMOID
     );

-    denseLayer.forward(d_input, d_output);
+    d_output = denseLayer.forward(d_input);

     std::vector<float> output(outputSize);
     cudaStatus = cudaMemcpy(
@@ -216,5 +211,5 @@ TEST_F(DenseLayerTest, ForwardRandomWeightMatrixSigmoid) {
         EXPECT_NEAR(output[i], expectedOutput[i], 1e-5);
     }

-    commonTestTeardown(d_input, d_output);
+    commonTestTeardown(d_input);
 }