Make conv2d work again

This commit is contained in:
2024-03-10 19:13:22 +01:00
parent 6bbc036f62
commit f3112311da
6 changed files with 146 additions and 98 deletions

View File

@@ -10,4 +10,9 @@ __global__ void pad_matrix_kernel(
int p
);
enum Padding {
SAME,
VALID
};
#endif // PADDING_H

View File

@@ -5,19 +5,20 @@
#include <vector>
#include "activations.cuh"
#include "padding.cuh"
namespace Layers {
class Conv2d {
public:
Conv2d(
int inputSize,
int inputChannels,
int kernelSize,
int stride,
std::string padding,
int numFilters,
Activation activation
int inputSize,
int inputChannels,
int kernelSize,
int stride,
Padding padding,
int numFilters,
Activation activation
);
~Conv2d();