Implement simple test for host conv2d

This commit is contained in:
2024-03-08 23:12:04 +01:00
parent 69ccba2dad
commit 4b6fcbc191
3 changed files with 39 additions and 12 deletions

View File

@@ -25,6 +25,9 @@ class Conv2d {
int outputSize;
void forward(const float* d_input, float* d_output);
void setKernels(const std::vector<float>& kernels_input);
void host_conv(const float* input, float* output);
private:
// Inputs
@@ -49,10 +52,6 @@ class Conv2d {
void initializeKernels();
void toCuda();
void setKernels(const std::vector<float>& kernels_input);
void host_conv(const float* input, float* output);
};
} // namespace Layers