Rename ILayer to WeightedLayer

This commit is contained in:
2024-03-18 20:36:52 +01:00
parent 6cf604423a
commit d9c6c663c8
5 changed files with 10 additions and 12 deletions

View File

@@ -1,8 +1,6 @@
#ifndef CUDANET_CONCAT_LAYER_H
#define CUDANET_CONCAT_LAYER_H
#include <ilayer.cuh>
namespace CUDANet::Layers {
/**
@@ -14,7 +12,8 @@ class Concat {
/**
* @brief Create a new Concat layer
*
* @param layers Layers to concatenate
* @param inputASize Size of the first input
* @param inputBSize Size of the second input
*/
Concat(const unsigned int inputASize, const unsigned int inputBSize);
@@ -25,10 +24,11 @@ class Concat {
~Concat();
/**
* @brief Forward pass of the concat layer
* @brief Concatenates the two inputs
*
* @param d_input_A Device pointer to the first input
* @param d_input_B Device pointer to the second input
*
* @return Device pointer to the output
*/
float* forward(const float* d_input_A, const float* d_input_B);

View File

@@ -6,7 +6,7 @@
#include "activation.cuh"
#include "convolution.cuh"
#include "ilayer.cuh"
#include "weighted_layer.cuh"
namespace CUDANet::Layers {
@@ -14,7 +14,7 @@ namespace CUDANet::Layers {
* @brief 2D convolutional layer
*
*/
class Conv2d : public ILayer {
class Conv2d : public WeightedLayer {
public:
/**
* @brief Construct a new Conv 2d layer

View File

@@ -5,7 +5,7 @@
#include <string>
#include <vector>
#include "ilayer.cuh"
#include "weighted_layer.cuh"
namespace CUDANet::Layers {
@@ -13,7 +13,7 @@ namespace CUDANet::Layers {
* @brief Dense (fully connected) layer
*
*/
class Dense : public ILayer {
class Dense : public WeightedLayer {
public:
/**
* @brief Construct a new Dense layer

View File

@@ -1,8 +1,6 @@
#ifndef CUDANET_INPUT_LAYER_H
#define CUDANET_INPUT_LAYER_H
#include <ilayer.cuh>
namespace CUDANet::Layers {
/**

View File

@@ -18,13 +18,13 @@ enum Padding { SAME, VALID };
/**
* @brief Base class for all layers
*/
class ILayer {
class WeightedLayer {
public:
/**
* @brief Destroy the ILayer object
*
*/
virtual ~ILayer() {}
virtual ~WeightedLayer() {}
/**
* @brief Virtual function for forward pass