Start conv test implementation

This commit is contained in:
2024-03-07 22:03:05 +01:00
parent 7e75943a6b
commit 69ccba2dad
3 changed files with 48 additions and 12 deletions

View File

@@ -11,16 +11,19 @@ namespace Layers {
class Conv2d { class Conv2d {
public: public:
Conv2d( Conv2d(
int inputSize, int inputSize,
int inputChannels, int inputChannels,
int kernelSize, int kernelSize,
int stride, int stride,
std::string padding, std::string padding,
int numFilters, int numFilters,
Activation activation Activation activation
); );
~Conv2d(); ~Conv2d();
// Outputs
int outputSize;
void forward(const float* d_input, float* d_output); void forward(const float* d_input, float* d_output);
private: private:
@@ -34,15 +37,12 @@ class Conv2d {
int paddingSize; int paddingSize;
int numFilters; int numFilters;
// Outputs
int outputSize;
// Kernels // Kernels
std::vector<float> kernels; std::vector<float> kernels;
// Cuda // Cuda
float* d_kernels; float* d_kernels;
float* d_padded; float* d_padded;
// Kernels // Kernels
Activation activation; Activation activation;

View File

@@ -3,6 +3,7 @@ include_directories(${GTEST_INCLUDE_DIRS})
add_executable(test_main add_executable(test_main
layers/test_dense.cu layers/test_dense.cu
layers/test_conv2d.cu
kernels/test_activations.cu kernels/test_activations.cu
kernels/test_padding.cu kernels/test_padding.cu
) )

View File

@@ -0,0 +1,35 @@
#include <cuda_runtime_api.h>
#include <gtest/gtest.h>
#include <iostream>
#include "conv2d.cuh"
TEST(Conv2dTest, ValidPadding) {
int inputSize = 3;
int inputChannels = 1;
int kernelSize = 3;
int stride = 1;
std::string padding = "VALID";
int numFilters = 1;
Activation activation = LINEAR;
Layers::Conv2d conv2d(
inputSize,
inputChannels,
kernelSize,
stride,
padding,
numFilters,
activation
);
int outputSize = (inputSize - kernelSize) / stride + 1;
EXPECT_EQ(outputSize, conv2d.outputSize);
std::vector<float> input(inputSize * inputSize * inputChannels);
std::vector<float> output(outputSize * outputSize * numFilters);
std::vector<float> kernels(kernelSize * kernelSize * inputChannels * numFilters);
}