mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-11-06 09:44:28 +00:00
Cleanup and refactor
This commit is contained in:
@@ -6,10 +6,11 @@
|
|||||||
|
|
||||||
#include "activations.cuh"
|
#include "activations.cuh"
|
||||||
#include "padding.cuh"
|
#include "padding.cuh"
|
||||||
|
#include "ilayer.cuh"
|
||||||
|
|
||||||
namespace Layers {
|
namespace Layers {
|
||||||
|
|
||||||
class Conv2d {
|
class Conv2d : public ILayer {
|
||||||
public:
|
public:
|
||||||
Conv2d(
|
Conv2d(
|
||||||
int inputSize,
|
int inputSize,
|
||||||
@@ -26,8 +27,8 @@ class Conv2d {
|
|||||||
int outputSize;
|
int outputSize;
|
||||||
|
|
||||||
void forward(const float* d_input, float* d_output);
|
void forward(const float* d_input, float* d_output);
|
||||||
void setKernels(const std::vector<float>& kernels_input);
|
void setWeights(const float* weights_input);
|
||||||
|
void setBiases(const float* biases_input);
|
||||||
void host_conv(const float* input, float* output);
|
void host_conv(const float* input, float* output);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@@ -42,18 +43,18 @@ class Conv2d {
|
|||||||
int numFilters;
|
int numFilters;
|
||||||
|
|
||||||
// Kernels
|
// Kernels
|
||||||
std::vector<float> kernels;
|
std::vector<float> weights;
|
||||||
std::vector<float> biases;
|
std::vector<float> biases;
|
||||||
|
|
||||||
// Cuda
|
// Cuda
|
||||||
float* d_kernels;
|
float* d_weights;
|
||||||
float* d_biases;
|
float* d_biases;
|
||||||
float* d_padded;
|
float* d_padded;
|
||||||
|
|
||||||
// Kernels
|
// Kernels
|
||||||
Activation activation;
|
Activation activation;
|
||||||
|
|
||||||
void initializeKernels();
|
void initializeWeights();
|
||||||
void initializeBiases();
|
void initializeBiases();
|
||||||
void toCuda();
|
void toCuda();
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -19,8 +19,8 @@ class Dense : public ILayer {
|
|||||||
~Dense();
|
~Dense();
|
||||||
|
|
||||||
void forward(const float* d_input, float* d_output);
|
void forward(const float* d_input, float* d_output);
|
||||||
void setWeights(const std::vector<std::vector<float>>& weights);
|
void setWeights(const float* weights);
|
||||||
void setBiases(const std::vector<float>& biases);
|
void setBiases(const float* biases);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int inputSize;
|
int inputSize;
|
||||||
|
|||||||
@@ -11,8 +11,25 @@ class ILayer {
|
|||||||
virtual ~ILayer() {}
|
virtual ~ILayer() {}
|
||||||
|
|
||||||
virtual void forward(const float* input, float* output) = 0;
|
virtual void forward(const float* input, float* output) = 0;
|
||||||
virtual void setWeights(const std::vector<std::vector<float>>& weights) = 0;
|
virtual void setWeights(const float* weights) = 0;
|
||||||
virtual void setBiases(const std::vector<float>& biases) = 0;
|
virtual void setBiases(const float* biases) = 0;
|
||||||
|
|
||||||
|
private:
|
||||||
|
virtual void initializeWeights() = 0;
|
||||||
|
virtual void initializeBiases() = 0;
|
||||||
|
|
||||||
|
virtual void toCuda() = 0;
|
||||||
|
|
||||||
|
int inputSize;
|
||||||
|
int outputSize;
|
||||||
|
|
||||||
|
float* d_weights;
|
||||||
|
float* d_biases;
|
||||||
|
|
||||||
|
std::vector<float> weights;
|
||||||
|
std::vector<float> biases;
|
||||||
|
|
||||||
|
Activation activation;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace Layers
|
} // namespace Layers
|
||||||
|
|||||||
@@ -23,7 +23,6 @@ Layers::Conv2d::Conv2d(
|
|||||||
stride(stride),
|
stride(stride),
|
||||||
numFilters(numFilters),
|
numFilters(numFilters),
|
||||||
activation(activation) {
|
activation(activation) {
|
||||||
// Allocate memory for kernels
|
|
||||||
|
|
||||||
switch (padding)
|
switch (padding)
|
||||||
{
|
{
|
||||||
@@ -41,12 +40,12 @@ Layers::Conv2d::Conv2d(
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
kernels.resize(kernelSize * kernelSize * inputChannels * numFilters);
|
weights.resize(kernelSize * kernelSize * inputChannels * numFilters);
|
||||||
initializeKernels();
|
initializeWeights();
|
||||||
|
|
||||||
d_kernels = nullptr;
|
d_weights = nullptr;
|
||||||
CUDA_CHECK(cudaMalloc(
|
CUDA_CHECK(cudaMalloc(
|
||||||
(void**)&d_kernels,
|
(void**)&d_weights,
|
||||||
sizeof(float) * kernelSize * kernelSize * inputChannels * numFilters
|
sizeof(float) * kernelSize * kernelSize * inputChannels * numFilters
|
||||||
));
|
));
|
||||||
|
|
||||||
@@ -68,27 +67,32 @@ Layers::Conv2d::Conv2d(
|
|||||||
}
|
}
|
||||||
|
|
||||||
Layers::Conv2d::~Conv2d() {
|
Layers::Conv2d::~Conv2d() {
|
||||||
cudaFree(d_kernels);
|
cudaFree(d_weights);
|
||||||
cudaFree(d_biases);
|
cudaFree(d_biases);
|
||||||
cudaFree(d_padded);
|
cudaFree(d_padded);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Layers::Conv2d::initializeKernels() {
|
void Layers::Conv2d::initializeWeights() {
|
||||||
std::fill(kernels.begin(), kernels.end(), 0.0f);
|
std::fill(weights.begin(), weights.end(), 0.0f);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Layers::Conv2d::initializeBiases() {
|
void Layers::Conv2d::initializeBiases() {
|
||||||
std::fill(biases.begin(), biases.end(), 0.0f);
|
std::fill(biases.begin(), biases.end(), 0.0f);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Layers::Conv2d::setKernels(const std::vector<float>& kernels_input) {
|
void Layers::Conv2d::setWeights(const float* weights_input) {
|
||||||
std::copy(kernels_input.begin(), kernels_input.end(), kernels.begin());
|
std::copy(weights_input, weights_input + weights.size(), weights.begin());
|
||||||
|
toCuda();
|
||||||
|
}
|
||||||
|
|
||||||
|
void Layers::Conv2d::setBiases(const float* biases_input) {
|
||||||
|
std::copy(biases_input, biases_input + biases.size(), biases.begin());
|
||||||
toCuda();
|
toCuda();
|
||||||
}
|
}
|
||||||
|
|
||||||
void Layers::Conv2d::toCuda() {
|
void Layers::Conv2d::toCuda() {
|
||||||
CUDA_CHECK(cudaMemcpy(
|
CUDA_CHECK(cudaMemcpy(
|
||||||
d_kernels, kernels.data(),
|
d_weights, weights.data(),
|
||||||
sizeof(float) * kernelSize * kernelSize * inputChannels * numFilters,
|
sizeof(float) * kernelSize * kernelSize * inputChannels * numFilters,
|
||||||
cudaMemcpyHostToDevice
|
cudaMemcpyHostToDevice
|
||||||
));
|
));
|
||||||
@@ -112,7 +116,7 @@ void Layers::Conv2d::forward(const float* d_input, float* d_output) {
|
|||||||
// Convolve
|
// Convolve
|
||||||
THREADS_PER_BLOCK = outputSize * outputSize * numFilters;
|
THREADS_PER_BLOCK = outputSize * outputSize * numFilters;
|
||||||
convolution_kernel<<<1, THREADS_PER_BLOCK>>>(
|
convolution_kernel<<<1, THREADS_PER_BLOCK>>>(
|
||||||
d_padded, d_kernels, d_output, inputSize + (2 * paddingSize),
|
d_padded, d_weights, d_output, inputSize + (2 * paddingSize),
|
||||||
inputChannels, kernelSize, stride, numFilters, outputSize
|
inputChannels, kernelSize, stride, numFilters, outputSize
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -155,7 +159,7 @@ void Layers::Conv2d::host_conv(const float* input, float* output) {
|
|||||||
(i * stride + k) * inputSize +
|
(i * stride + k) * inputSize +
|
||||||
(j * stride + l);
|
(j * stride + l);
|
||||||
|
|
||||||
sum += kernels[kernelIndex] * input[inputIndex];
|
sum += weights[kernelIndex] * input[inputIndex];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,14 +10,8 @@
|
|||||||
#include "dense.cuh"
|
#include "dense.cuh"
|
||||||
#include "matrix_math.cuh"
|
#include "matrix_math.cuh"
|
||||||
|
|
||||||
Layers::Dense::Dense(
|
Layers::Dense::Dense(int inputSize, int outputSize, Activation activation)
|
||||||
int inputSize,
|
: inputSize(inputSize), outputSize(outputSize), activation(activation) {
|
||||||
int outputSize,
|
|
||||||
Activation activation
|
|
||||||
)
|
|
||||||
: inputSize(inputSize),
|
|
||||||
outputSize(outputSize),
|
|
||||||
activation(activation) {
|
|
||||||
// Allocate memory for weights and biases
|
// Allocate memory for weights and biases
|
||||||
weights.resize(outputSize * inputSize);
|
weights.resize(outputSize * inputSize);
|
||||||
biases.resize(outputSize);
|
biases.resize(outputSize);
|
||||||
@@ -52,7 +46,6 @@ void Layers::Dense::initializeBiases() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Layers::Dense::forward(const float* d_input, float* d_output) {
|
void Layers::Dense::forward(const float* d_input, float* d_output) {
|
||||||
|
|
||||||
mat_vec_mul_kernel<<<1, outputSize>>>(
|
mat_vec_mul_kernel<<<1, outputSize>>>(
|
||||||
d_weights, d_input, d_output, inputSize, outputSize
|
d_weights, d_input, d_output, inputSize, outputSize
|
||||||
);
|
);
|
||||||
@@ -63,15 +56,11 @@ void Layers::Dense::forward(const float* d_input, float* d_output) {
|
|||||||
|
|
||||||
switch (activation) {
|
switch (activation) {
|
||||||
case SIGMOID:
|
case SIGMOID:
|
||||||
sigmoid_kernel<<<1, outputSize>>>(
|
sigmoid_kernel<<<1, outputSize>>>(d_output, d_output, outputSize);
|
||||||
d_output, d_output, outputSize
|
|
||||||
);
|
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case RELU:
|
case RELU:
|
||||||
relu_kernel<<<1, outputSize>>>(
|
relu_kernel<<<1, outputSize>>>(d_output, d_output, outputSize);
|
||||||
d_output, d_output, outputSize
|
|
||||||
);
|
|
||||||
break;
|
break;
|
||||||
|
|
||||||
default:
|
default:
|
||||||
@@ -92,26 +81,12 @@ void Layers::Dense::toCuda() {
|
|||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
void Layers::Dense::setWeights(
|
void Layers::Dense::setWeights(const float* weights_input) {
|
||||||
const std::vector<std::vector<float>>& weights_input
|
std::copy(weights_input, weights_input + weights.size(), weights.begin());
|
||||||
) {
|
|
||||||
int numWeights = inputSize * outputSize;
|
|
||||||
|
|
||||||
if (weights.size() != numWeights) {
|
|
||||||
std::cerr << "Invalid number of weights" << std::endl;
|
|
||||||
exit(EXIT_FAILURE);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int i = 0; i < outputSize; ++i) {
|
|
||||||
for (int j = 0; j < inputSize; ++j) {
|
|
||||||
weights[i * inputSize + j] = weights_input[i][j];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
toCuda();
|
toCuda();
|
||||||
}
|
}
|
||||||
|
|
||||||
void Layers::Dense::setBiases(const std::vector<float>& biases_input) {
|
void Layers::Dense::setBiases(const float* biases_input) {
|
||||||
std::copy(biases_input.begin(), biases_input.end(), biases.begin());
|
std::copy(biases_input, biases_input + biases.size(), biases.begin());
|
||||||
toCuda();
|
toCuda();
|
||||||
}
|
}
|
||||||
@@ -16,7 +16,7 @@ class Conv2dTest : public ::testing::Test {
|
|||||||
int numFilters,
|
int numFilters,
|
||||||
Activation activation,
|
Activation activation,
|
||||||
std::vector<float>& input,
|
std::vector<float>& input,
|
||||||
std::vector<float>& kernels,
|
float* kernels,
|
||||||
float*& d_input,
|
float*& d_input,
|
||||||
float*& d_output
|
float*& d_output
|
||||||
) {
|
) {
|
||||||
@@ -26,7 +26,7 @@ class Conv2dTest : public ::testing::Test {
|
|||||||
activation
|
activation
|
||||||
);
|
);
|
||||||
|
|
||||||
conv2d.setKernels(kernels);
|
conv2d.setWeights(kernels);
|
||||||
|
|
||||||
// Allocate device memory
|
// Allocate device memory
|
||||||
cudaStatus = cudaMalloc(
|
cudaStatus = cudaMalloc(
|
||||||
@@ -84,7 +84,7 @@ TEST_F(Conv2dTest, SimpleTest) {
|
|||||||
|
|
||||||
Layers::Conv2d conv2d = commonTestSetup(
|
Layers::Conv2d conv2d = commonTestSetup(
|
||||||
inputSize, inputChannels, kernelSize, stride, padding, numFilters,
|
inputSize, inputChannels, kernelSize, stride, padding, numFilters,
|
||||||
activation, input, kernels, d_input, d_output
|
activation, input, kernels.data(), d_input, d_output
|
||||||
);
|
);
|
||||||
|
|
||||||
int outputSize = (inputSize - kernelSize) / stride + 1;
|
int outputSize = (inputSize - kernelSize) / stride + 1;
|
||||||
@@ -173,7 +173,7 @@ TEST_F(Conv2dTest, ComplexTest) {
|
|||||||
|
|
||||||
Layers::Conv2d conv2d = commonTestSetup(
|
Layers::Conv2d conv2d = commonTestSetup(
|
||||||
inputSize, inputChannels, kernelSize, stride, padding, numFilters,
|
inputSize, inputChannels, kernelSize, stride, padding, numFilters,
|
||||||
activation, input, kernels, d_input, d_output
|
activation, input, kernels.data(), d_input, d_output
|
||||||
);
|
);
|
||||||
|
|
||||||
EXPECT_EQ(inputSize, conv2d.outputSize);
|
EXPECT_EQ(inputSize, conv2d.outputSize);
|
||||||
|
|||||||
@@ -6,23 +6,20 @@
|
|||||||
#include "activations.cuh"
|
#include "activations.cuh"
|
||||||
#include "dense.cuh"
|
#include "dense.cuh"
|
||||||
|
|
||||||
|
class DenseLayerTest : public ::testing::Test {
|
||||||
class DenseLayerTest : public::testing::Test {
|
|
||||||
protected:
|
protected:
|
||||||
Layers::Dense commonTestSetup(
|
Layers::Dense commonTestSetup(
|
||||||
int inputSize,
|
int inputSize,
|
||||||
int outputSize,
|
int outputSize,
|
||||||
std::vector<float>& input,
|
std::vector<float>& input,
|
||||||
std::vector<std::vector<float>>& weights,
|
float* weights,
|
||||||
std::vector<float>& biases,
|
float* biases,
|
||||||
float*& d_input,
|
float*& d_input,
|
||||||
float*& d_output,
|
float*& d_output,
|
||||||
Activation activation
|
Activation activation
|
||||||
) {
|
) {
|
||||||
// Create Dense layer
|
// Create Dense layer
|
||||||
Layers::Dense denseLayer(
|
Layers::Dense denseLayer(inputSize, outputSize, activation);
|
||||||
inputSize, outputSize, activation
|
|
||||||
);
|
|
||||||
|
|
||||||
// Set weights and biases
|
// Set weights and biases
|
||||||
denseLayer.setWeights(weights);
|
denseLayer.setWeights(weights);
|
||||||
@@ -37,11 +34,11 @@ class DenseLayerTest : public::testing::Test {
|
|||||||
|
|
||||||
// Copy input to device
|
// Copy input to device
|
||||||
cudaStatus = cudaMemcpy(
|
cudaStatus = cudaMemcpy(
|
||||||
d_input, input.data(), sizeof(float) * input.size(), cudaMemcpyHostToDevice
|
d_input, input.data(), sizeof(float) * input.size(),
|
||||||
|
cudaMemcpyHostToDevice
|
||||||
);
|
);
|
||||||
EXPECT_EQ(cudaStatus, cudaSuccess);
|
EXPECT_EQ(cudaStatus, cudaSuccess);
|
||||||
|
|
||||||
|
|
||||||
return denseLayer;
|
return denseLayer;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -60,9 +57,7 @@ TEST_F(DenseLayerTest, Init) {
|
|||||||
int inputSize = i;
|
int inputSize = i;
|
||||||
int outputSize = j;
|
int outputSize = j;
|
||||||
|
|
||||||
Layers::Dense denseLayer(
|
Layers::Dense denseLayer(inputSize, outputSize, SIGMOID);
|
||||||
inputSize, outputSize, SIGMOID
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -71,17 +66,19 @@ TEST_F(DenseLayerTest, setWeights) {
|
|||||||
int inputSize = 4;
|
int inputSize = 4;
|
||||||
int outputSize = 5;
|
int outputSize = 5;
|
||||||
|
|
||||||
std::vector<std::vector<float>> weights = {
|
// clang-format off
|
||||||
{0.5f, 1.0f, 0.2f, 0.8f},
|
std::vector<float> weights = {
|
||||||
{1.2f, 0.3f, 1.5f, 0.4f},
|
0.5f, 1.0f, 0.2f, 0.8f,
|
||||||
{0.7f, 1.8f, 0.9f, 0.1f},
|
1.2f, 0.3f, 1.5f, 0.4f,
|
||||||
{0.4f, 2.0f, 0.6f, 1.1f},
|
0.7f, 1.8f, 0.9f, 0.1f,
|
||||||
{1.3f, 0.5f, 0.0f, 1.7f}
|
0.4f, 2.0f, 0.6f, 1.1f,
|
||||||
|
1.3f, 0.5f, 0.0f, 1.7f
|
||||||
};
|
};
|
||||||
|
// clang-format on
|
||||||
|
|
||||||
Layers::Dense denseLayer(inputSize, outputSize, SIGMOID);
|
Layers::Dense denseLayer(inputSize, outputSize, SIGMOID);
|
||||||
|
|
||||||
denseLayer.setWeights(weights);
|
denseLayer.setWeights(weights.data());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DenseLayerTest, ForwardUnitWeightMatrixLinear) {
|
TEST_F(DenseLayerTest, ForwardUnitWeightMatrixLinear) {
|
||||||
@@ -90,13 +87,11 @@ TEST_F(DenseLayerTest, ForwardUnitWeightMatrixLinear) {
|
|||||||
|
|
||||||
std::vector<float> input = {1.0f, 2.0f, 3.0f};
|
std::vector<float> input = {1.0f, 2.0f, 3.0f};
|
||||||
|
|
||||||
std::vector<std::vector<float>> weights(
|
std::vector<float> weights(outputSize * inputSize, 0.0f);
|
||||||
inputSize, std::vector<float>(outputSize, 0.0f)
|
|
||||||
);
|
|
||||||
for (int i = 0; i < inputSize; ++i) {
|
for (int i = 0; i < inputSize; ++i) {
|
||||||
for (int j = 0; j < outputSize; ++j) {
|
for (int j = 0; j < outputSize; ++j) {
|
||||||
if (i == j) {
|
if (i == j) {
|
||||||
weights[i][j] = 1.0f;
|
weights[i * outputSize + j] = 1.0f;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -106,13 +101,15 @@ TEST_F(DenseLayerTest, ForwardUnitWeightMatrixLinear) {
|
|||||||
float* d_output;
|
float* d_output;
|
||||||
|
|
||||||
Layers::Dense denseLayer = commonTestSetup(
|
Layers::Dense denseLayer = commonTestSetup(
|
||||||
inputSize, outputSize, input, weights, biases, d_input, d_output, LINEAR
|
inputSize, outputSize, input, weights.data(), biases.data(), d_input,
|
||||||
|
d_output, LINEAR
|
||||||
);
|
);
|
||||||
denseLayer.forward(d_input, d_output);
|
denseLayer.forward(d_input, d_output);
|
||||||
|
|
||||||
std::vector<float> output(outputSize);
|
std::vector<float> output(outputSize);
|
||||||
cudaStatus = cudaMemcpy(
|
cudaStatus = cudaMemcpy(
|
||||||
output.data(), d_output, sizeof(float) * outputSize, cudaMemcpyDeviceToHost
|
output.data(), d_output, sizeof(float) * outputSize,
|
||||||
|
cudaMemcpyDeviceToHost
|
||||||
);
|
);
|
||||||
EXPECT_EQ(cudaStatus, cudaSuccess);
|
EXPECT_EQ(cudaStatus, cudaSuccess);
|
||||||
|
|
||||||
@@ -130,26 +127,30 @@ TEST_F(DenseLayerTest, ForwardRandomWeightMatrixRelu) {
|
|||||||
|
|
||||||
std::vector<float> input = {1.0f, 2.0f, 3.0f, 4.0f, -5.0f};
|
std::vector<float> input = {1.0f, 2.0f, 3.0f, 4.0f, -5.0f};
|
||||||
|
|
||||||
std::vector<std::vector<float>> weights = {
|
// clang-format off
|
||||||
{0.5f, 1.2f, 0.7f, 0.4f, 1.3f},
|
std::vector<float> weights = {
|
||||||
{1.0f, 0.3f, 1.8f, 2.0f, 0.5f},
|
0.5f, 1.2f, 0.7f, 0.4f,
|
||||||
{0.2f, 1.5f, 0.9f, 0.6f, 0.0f},
|
1.3f, 1.0f, 0.3f, 1.8f,
|
||||||
{0.8f, 0.4f, 0.1f, 1.1f, 1.7f}
|
2.0f, 0.5f, 0.2f, 1.5f,
|
||||||
|
0.9f, 0.6f, 0.0f, 0.8f,
|
||||||
|
0.4f, 0.1f, 1.1f, 1.7f
|
||||||
};
|
};
|
||||||
std::vector<float> biases = {0.2f, 0.5f, 0.7f, -1.1f};
|
std::vector<float> biases = {0.2f, 0.5f, 0.7f, -1.1f};
|
||||||
|
// clang-format on
|
||||||
|
|
||||||
float* d_input;
|
float* d_input;
|
||||||
float* d_output;
|
float* d_output;
|
||||||
|
|
||||||
Layers::Dense denseLayer = commonTestSetup(
|
Layers::Dense denseLayer = commonTestSetup(
|
||||||
inputSize, outputSize, input, weights, biases, d_input, d_output, RELU
|
inputSize, outputSize, input, weights.data(), biases.data(), d_input, d_output, RELU
|
||||||
);
|
);
|
||||||
|
|
||||||
denseLayer.forward(d_input, d_output);
|
denseLayer.forward(d_input, d_output);
|
||||||
|
|
||||||
std::vector<float> output(outputSize);
|
std::vector<float> output(outputSize);
|
||||||
cudaStatus = cudaMemcpy(
|
cudaStatus = cudaMemcpy(
|
||||||
output.data(), d_output, sizeof(float) * outputSize, cudaMemcpyDeviceToHost
|
output.data(), d_output, sizeof(float) * outputSize,
|
||||||
|
cudaMemcpyDeviceToHost
|
||||||
);
|
);
|
||||||
EXPECT_EQ(cudaStatus, cudaSuccess);
|
EXPECT_EQ(cudaStatus, cudaSuccess);
|
||||||
|
|
||||||
@@ -170,21 +171,22 @@ TEST_F(DenseLayerTest, ForwardRandomWeightMatrixSigmoid) {
|
|||||||
int inputSize = 5;
|
int inputSize = 5;
|
||||||
int outputSize = 4;
|
int outputSize = 4;
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
std::vector<float> input = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f};
|
std::vector<float> input = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f};
|
||||||
|
std::vector<float> weights = {
|
||||||
std::vector<std::vector<float>> weights = {
|
0.8f, 0.7f, 0.7f, 0.3f, 0.8f,
|
||||||
{0.8f, 0.7f, 0.7f, 0.3f, 0.8f},
|
0.1f, 0.4f, 0.8f, 0.0f, 0.2f,
|
||||||
{0.1f, 0.4f, 0.8f, 0.0f, 0.2f},
|
0.2f, 0.5f, 0.7f, 0.3f, 0.0f,
|
||||||
{0.2f, 0.5f, 0.7f, 0.3f, 0.0f},
|
0.1f, 0.7f, 0.6f, 1.0f, 0.4f
|
||||||
{0.1f, 0.7f, 0.6f, 1.0f, 0.4f}
|
|
||||||
};
|
};
|
||||||
std::vector<float> biases = {0.1f, 0.2f, 0.3f, 0.4f};
|
std::vector<float> biases = {0.1f, 0.2f, 0.3f, 0.4f};
|
||||||
|
// clang-format on
|
||||||
|
|
||||||
float* d_input;
|
float* d_input;
|
||||||
float* d_output;
|
float* d_output;
|
||||||
|
|
||||||
Layers::Dense denseLayer = commonTestSetup(
|
Layers::Dense denseLayer = commonTestSetup(
|
||||||
inputSize, outputSize, input, weights, biases, d_input, d_output,
|
inputSize, outputSize, input, weights.data(), biases.data(), d_input, d_output,
|
||||||
SIGMOID
|
SIGMOID
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -192,7 +194,8 @@ TEST_F(DenseLayerTest, ForwardRandomWeightMatrixSigmoid) {
|
|||||||
|
|
||||||
std::vector<float> output(outputSize);
|
std::vector<float> output(outputSize);
|
||||||
cudaStatus = cudaMemcpy(
|
cudaStatus = cudaMemcpy(
|
||||||
output.data(), d_output, sizeof(float) * outputSize, cudaMemcpyDeviceToHost
|
output.data(), d_output, sizeof(float) * outputSize,
|
||||||
|
cudaMemcpyDeviceToHost
|
||||||
);
|
);
|
||||||
EXPECT_EQ(cudaStatus, cudaSuccess);
|
EXPECT_EQ(cudaStatus, cudaSuccess);
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user