Migrate concat layer

commit fe7c16ac36
parent a0665fb05c
2024-09-09 22:16:22 +02:00
6 changed files with 80 additions and 42 deletions


@@ -0,0 +1,31 @@
#include "concat.hpp"
#include "cuda_helper.cuh"

using namespace CUDANet::Layers;

void Concat::initCUDA() {
    // Allocate a single device buffer large enough to hold both inputs
    // back to back.
    d_output = nullptr;
    CUDA_CHECK(
        cudaMalloc((void**)&d_output, sizeof(float) * (inputASize + inputBSize))
    );
}

void Concat::delCUDA() {
    cudaFree(d_output);
}

float* Concat::forwardCUDA(const float* d_input_A, const float* d_input_B) {
    // Copy input A into the front of the buffer, then input B directly
    // behind it (device-to-device, no kernel launch needed).
    CUDA_CHECK(cudaMemcpy(
        d_output, d_input_A, sizeof(float) * inputASize,
        cudaMemcpyDeviceToDevice
    ));
    CUDA_CHECK(cudaMemcpy(
        d_output + inputASize, d_input_B, sizeof(float) * inputBSize,
        cudaMemcpyDeviceToDevice
    ));
    CUDA_CHECK(cudaDeviceSynchronize());
    return d_output;
}
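For reference, the split implies a header roughly like the following. This is a hypothetical sketch reconstructed only from the symbols used in this commit; concat.hpp itself is not part of the diff shown here.

#pragma once

namespace CUDANet::Layers {

class Concat {
  public:
    Concat(const int inputASize, const int inputBSize);
    ~Concat();

    // Dispatches to the CUDA or CPU backend depending on the build.
    float* forward(const float* input_A, const float* input_B);
    int    getOutputSize();

  private:
    int    inputASize;
    int    inputBSize;
    float* d_output;  // device buffer; only used by the CUDA backend

    float* forwardCPU(const float* input_A, const float* input_B);
#ifdef USE_CUDA
    void   initCUDA();
    void   delCUDA();
    float* forwardCUDA(const float* d_input_A, const float* d_input_B);
#endif
};

}  // namespace CUDANet::Layers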

src/layers/concat.cpp Normal file

@@ -0,0 +1,34 @@
#include <stdexcept>

#include "concat.hpp"

using namespace CUDANet::Layers;

Concat::Concat(const int inputASize, const int inputBSize)
    : inputASize(inputASize), inputBSize(inputBSize) {
#ifdef USE_CUDA
    initCUDA();
#endif
}

Concat::~Concat() {
#ifdef USE_CUDA
    delCUDA();
#endif
}

float* Concat::forwardCPU(const float* input_A, const float* input_B) {
    // CPU backend is not implemented yet.
    throw std::logic_error("Not implemented");
}

float* Concat::forward(const float* input_A, const float* input_B) {
    // Dispatch to the backend selected at compile time.
#ifdef USE_CUDA
    return forwardCUDA(input_A, input_B);
#else
    return forwardCPU(input_A, input_B);
#endif
}

int Concat::getOutputSize() {
    return inputASize + inputBSize;
}
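Call sites are unaffected by the migration. A minimal usage sketch for a CUDA build; the buffer sizes and cudaMalloc'd inputs here are made up for illustration, since in practice the inputs would come from upstream layers:

#include <cuda_runtime.h>
#include "concat.hpp"

using namespace CUDANet::Layers;

int main() {
    const int sizeA = 128;
    const int sizeB = 64;

    // Hypothetical device inputs standing in for upstream layer outputs.
    float *d_a = nullptr, *d_b = nullptr;
    cudaMalloc((void**)&d_a, sizeof(float) * sizeA);
    cudaMalloc((void**)&d_b, sizeof(float) * sizeB);

    Concat concat(sizeA, sizeB);
    float* d_out = concat.forward(d_a, d_b);  // d_out holds A followed by B
    // concat.getOutputSize() == sizeA + sizeB == 192

    cudaFree(d_a);
    cudaFree(d_b);
    return 0;
}

Since concatenation is just two device-to-device copies into a preallocated buffer, the forward pass needs no kernel of its own.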


@@ -1,37 +0,0 @@
#include "concat.cuh"
#include "cuda_helper.cuh"

using namespace CUDANet::Layers;

Concat::Concat(const int inputASize, const int inputBSize)
    : inputASize(inputASize), inputBSize(inputBSize) {
    d_output = nullptr;
    CUDA_CHECK(cudaMalloc(
        (void**)&d_output, sizeof(float) * (inputASize + inputBSize)
    ));
}

Concat::~Concat() {
    cudaFree(d_output);
}

float* Concat::forward(const float* d_input_A, const float* d_input_B) {
    CUDA_CHECK(cudaMemcpy(
        d_output, d_input_A, sizeof(float) * inputASize, cudaMemcpyDeviceToDevice
    ));
    CUDA_CHECK(cudaMemcpy(
        d_output + inputASize, d_input_B,
        sizeof(float) * inputBSize, cudaMemcpyDeviceToDevice
    ));
    CUDA_CHECK(cudaDeviceSynchronize());
    return d_output;
}

int Concat::getOutputSize() {
    return inputASize + inputBSize;
};