Fix some dense layer issues

2025-11-18 22:17:08 +01:00
parent 7f203b8947
commit 4c26efe826
8 changed files with 110 additions and 44 deletions


@@ -32,12 +32,12 @@ class Backend {
     ) = 0;
     virtual CUDANet::Tensor& dense(
-        CUDANet::Tensor& weights,
-        CUDANet::Tensor& biases,
-        CUDANet::Tensor& input,
+        const CUDANet::Tensor& weights,
+        const CUDANet::Tensor& biases,
+        const CUDANet::Tensor& input,
         CUDANet::Tensor& output,
-        size_t input_size,
-        size_t output_size
+        const size_t input_size,
+        const size_t output_size
     ) = 0;
 };
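
The const qualifiers pin down the contract of the virtual entry point: weights, biases, and input are read-only, and only output is written. For reference, this is the computation the interface describes, as a hypothetical host-side sketch (not part of this commit; the real implementation is the CUDA one further down, and the row-major weight layout is an assumption):

#include <cstddef>

// Hypothetical reference for the dense contract: y = W * x + b, with W
// assumed row-major with shape [output_size x input_size]. Only `y` is
// written, mirroring the const-qualified interface above.
void dense_reference(
    const float* w, const float* b, const float* x, float* y,
    const std::size_t input_size, const std::size_t output_size
) {
    for (std::size_t i = 0; i < output_size; ++i) {
        float acc = b[i];
        for (std::size_t j = 0; j < input_size; ++j) {
            acc += w[i * input_size + j] * x[j];
        }
        y[i] = acc;
    }
}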


@@ -29,12 +29,12 @@ class CUDA : public Backend {
     ) override;
     CUDANet::Tensor& dense(
-        CUDANet::Tensor& weights,
-        CUDANet::Tensor& biases,
-        CUDANet::Tensor& input,
+        const CUDANet::Tensor& weights,
+        const CUDANet::Tensor& biases,
+        const CUDANet::Tensor& input,
         CUDANet::Tensor& output,
-        size_t input_size,
-        size_t output_size
+        const size_t input_size,
+        const size_t output_size
     ) override;
 };


@@ -20,7 +20,7 @@ class Layer {
     virtual ~Layer(){};
-    virtual CUDANet::Tensor& forward(CUDANet::Tensor &input) = 0;
+    virtual CUDANet::Tensor& forward(const CUDANet::Tensor &input) = 0;
     virtual CUDANet::Shape input_shape() = 0;


@@ -18,7 +18,7 @@ class Dense : public Layer {
     ~Dense();
-    CUDANet::Tensor& forward(CUDANet::Tensor &input) override;
+    CUDANet::Tensor& forward(const CUDANet::Tensor &input) override;
     CUDANet::Shape input_shape() override;


@@ -22,6 +22,12 @@ public:
     Tensor() = default;
     Tensor(Shape shape, DType dtype, CUDANet::Backend* backend);
+
+    Tensor(Tensor&& other) noexcept;
+    Tensor& operator=(Tensor&& other) noexcept;
+
+    Tensor(const Tensor&) = delete;
+    Tensor& operator=(const Tensor&) = delete;

     ~Tensor();
     size_t size() const;
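
Deleting the copy operations and adding the move pair makes Tensor an explicitly move-only RAII handle over the device allocation: two tensors can never alias the same d_ptr, so a double deallocate is impossible. A hypothetical usage sketch of what this enables (assuming a live backend pointer; note it also relies on a destructor that tolerates the moved-from state, see the note on the Tensor destructor at the end of this page):

#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical usage: move-only tensors can be returned by value and stored
// in containers without ever duplicating the device allocation.
CUDANet::Tensor make_buffer(CUDANet::Backend* backend, size_t len) {
    CUDANet::Tensor t({len}, CUDANet::DType::FLOAT32, backend);
    t.zero();
    return t;  // moved out (or elided); never copied
}

void collect(CUDANet::Backend* backend) {
    std::vector<CUDANet::Tensor> buffers;
    buffers.push_back(make_buffer(backend, 128));    // move into the vector
    // CUDANet::Tensor copy = buffers[0];            // ill-formed: copying is deleted
    CUDANet::Tensor owner = std::move(buffers[0]);   // explicit ownership transfer
}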


@@ -1,20 +1,24 @@
 #include "backend/cuda.cuh"
-#include "utils/cuda_helper.cuh"
 #include "kernels/activation_functions.cuh"
 #include "kernels/matmul.cuh"
+#include "utils/cuda_helper.cuh"

 using namespace CUDANet::Backend;

 void CUDA::relu(Tensor& tensor) {
     int gridSize = (tensor.numel() + BLOCK_SIZE - 1) / BLOCK_SIZE;
-    Kernels::relu<<<gridSize, BLOCK_SIZE>>>(tensor.data<float>(), tensor.data<float>(), tensor.numel());
+    Kernels::relu<<<gridSize, BLOCK_SIZE>>>(
+        tensor.data<float>(), tensor.data<float>(), tensor.numel()
+    );
     CUDA_CHECK(cudaGetLastError());
     CUDA_CHECK(cudaDeviceSynchronize());
 }

 void CUDA::sigmoid(Tensor& tensor) {
     int gridSize = (tensor.numel() + BLOCK_SIZE - 1) / BLOCK_SIZE;
-    Kernels::sigmoid<<<gridSize, BLOCK_SIZE>>>(tensor.data<float>(), tensor.data<float>(), tensor.numel());
+    Kernels::sigmoid<<<gridSize, BLOCK_SIZE>>>(
+        tensor.data<float>(), tensor.data<float>(), tensor.numel()
+    );
     CUDA_CHECK(cudaGetLastError());
     CUDA_CHECK(cudaDeviceSynchronize());
 }

@@ -27,7 +31,8 @@ void CUDA::softmax(Tensor &tensor, Tensor &temp_max, Tensor &temp_sum) {
     // Subtract max value to improve numerical stability
     Kernels::vec_scalar_sub<<<gridSize, BLOCK_SIZE>>>(
-        tensor.data<float>(), tensor.data<float>(), temp_max.data<float>(), tensor.numel()
+        tensor.data<float>(), tensor.data<float>(), temp_max.data<float>(),
+        tensor.numel()
     );

     CUDA_CHECK(cudaGetLastError());

@@ -41,25 +46,34 @@ void CUDA::softmax(Tensor &tensor, Tensor &temp_max, Tensor &temp_sum) {
     sum(tensor, temp_sum);

     Kernels::vec_scalar_div<<<gridSize, BLOCK_SIZE>>>(
-        tensor.data<float>(), tensor.data<float>(), temp_sum.data<float>(), tensor.numel()
+        tensor.data<float>(), tensor.data<float>(), temp_sum.data<float>(),
+        tensor.numel()
     );
     CUDA_CHECK(cudaGetLastError());
     CUDA_CHECK(cudaDeviceSynchronize());
 }

-CUDANet::Tensor& CUDA::dense(CUDANet::Tensor &weights, CUDANet::Tensor &biases, CUDANet::Tensor &input, CUDANet::Tensor &output, size_t input_size, size_t output_size) {
+CUDANet::Tensor& CUDA::dense(
+    const CUDANet::Tensor& weights,
+    const CUDANet::Tensor& biases,
+    const CUDANet::Tensor& input,
+    CUDANet::Tensor& output,
+    const size_t input_size,
+    const size_t output_size
+) {
     auto forwardGridSize =
         (std::max(input_size, output_size) + BLOCK_SIZE - 1) / BLOCK_SIZE;
     auto biasGridSize = (output_size + BLOCK_SIZE - 1) / BLOCK_SIZE;

     Kernels::mat_vec_mul<<<forwardGridSize, BLOCK_SIZE>>>(
-        weights.data<float>(), input.data<float>(), output.data<float>(), input_size, output_size
+        weights.data<float>(), input.data<float>(), output.data<float>(),
+        input_size, output_size
     );
     CUDA_CHECK(cudaGetLastError());

     Kernels::vec_vec_add<<<biasGridSize, BLOCK_SIZE>>>(
-        biases.data<float>(), output.data<float>(), output.data<float>(), output_size
+        biases.data<float>(), output.data<float>(), output.data<float>(),
+        output_size
     );
     CUDA_CHECK(cudaGetLastError());
     CUDA_CHECK(cudaDeviceSynchronize());
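
Note that forwardGridSize is derived from max(input_size, output_size), which hints the kernel may also put threads to work on the input side (for example, staging x cooperatively). The actual Kernels::mat_vec_mul lives in kernels/matmul.cuh and is not part of this diff; a minimal kernel consistent with the launch and the y = W * x contract would look like this sketch (one thread per output row, row-major layout assumed):

// Minimal sketch only; the real kernel in kernels/matmul.cuh may tile or use
// shared memory. W is assumed row-major, [output_size x input_size].
__global__ void mat_vec_mul_sketch(
    const float* __restrict__ w, const float* __restrict__ x,
    float* __restrict__ y, const size_t input_size, const size_t output_size
) {
    size_t row = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= output_size) return;

    float acc = 0.0f;
    for (size_t j = 0; j < input_size; ++j) {
        acc += w[row * input_size + j] * x[j];
    }
    y[row] = acc;
}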


@@ -1,34 +1,41 @@
-#include "dense.hpp"
-
 #include <format>
 #include <stdexcept>

+#include "dense.hpp"
+
 using namespace CUDANet::Layers;

-Dense::Dense(CUDANet::Backend *backend, CUDANet::Shape input_shape, CUDANet::Shape output_shape)
-    : backend(backend), in_shape(input_shape), out_shape(output_shape) {
+Dense::Dense(CUDANet::Backend* backend, CUDANet::Shape in, CUDANet::Shape out)
+    : backend(backend),
+      in_shape(in),
+      out_shape(out),
+      weights(
+          CUDANet::Tensor{{in[0] * out[0]}, CUDANet::DType::FLOAT32, backend}
+      ),
+      biases(CUDANet::Tensor({out[0]}, CUDANet::DType::FLOAT32, backend)),
+      output(CUDANet::Tensor({out[0]}, CUDANet::DType::FLOAT32, backend)) {
     // Allocate memory for weights and biases
-    if (input_shape.size() != 1) {
-        throw std::runtime_error(std::format("Invalid shape. Expected [1], got {}", input_shape));
+    if (in.size() != 1) {
+        throw std::runtime_error(
+            std::format("Invalid shape. Expected [1], got {}", in)
+        );
     }
-    if (output_shape.size() != 1) {
-        throw std::runtime_error(std::format("Invalid shape. Expected [1], got {}", output_shape));
+    if (out.size() != 1) {
+        throw std::runtime_error(
+            std::format("Invalid shape. Expected [1], got {}", out)
+        );
     }

-    auto input_len = input_shape[0];
-    auto output_len = output_shape[0];
-
-    auto weights = CUDANet::Tensor{Shape(input_len * output_len), CUDANet::DType::FLOAT32, backend};
-    auto biases = CUDANet::Tensor(Shape(output_len), CUDANet::DType::FLOAT32, backend);
-    auto output = CUDANet::Tensor(Shape(output_len), CUDANet::DType::FLOAT32, backend);
+    auto input_len = in[0];
+    auto output_len = out[0];

     weights.zero();
     biases.zero();
 }

-CUDANet::Tensor& Dense::forward(CUDANet::Tensor &input) {
+CUDANet::Tensor& Dense::forward(const CUDANet::Tensor& input) {
     backend->dense(weights, biases, input, output, in_shape[0], out_shape[0]);
     return output;
 }
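
This is the core fix the commit message refers to: the old constructor built weights, biases, and output as locals that shadowed the members and died at the end of the constructor, so forward() ran against never-allocated member tensors. They are now constructed in the member initializer list, which the new move-only Tensor supports. A hypothetical end-to-end usage (backend setup assumed; the 784/10 sizes are just an example):

// Hypothetical usage of the fixed layer; `backend` is assumed to point at a
// constructed CUDA backend.
void run_dense(CUDANet::Backend* backend) {
    // 784 -> 10 fully connected layer; weights and biases start zeroed.
    CUDANet::Layers::Dense fc(backend, {784}, {10});

    CUDANet::Tensor input({784}, CUDANet::DType::FLOAT32, backend);
    input.zero();

    // Runs mat_vec_mul + vec_vec_add on the device and returns the
    // layer-owned output tensor by reference.
    CUDANet::Tensor& logits = fc.forward(input);
    (void)logits;
}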


@@ -6,6 +6,11 @@ using namespace CUDANet;
 Tensor::Tensor(Shape shape, DType dtype, Backend* backend)
     : shape(shape), dtype(dtype), backend(backend), d_ptr(nullptr) {
+    if (shape.empty()) {
+        throw std::runtime_error("Tensor shape cannot be empty");
+    }
+
     // Count total elements
     size_t count = 1;
     for (const auto& dim : shape) {

@@ -28,6 +33,40 @@ Tensor::Tensor(Shape shape, DType dtype, Backend* backend)
     d_ptr = backend->allocate(total_size);
 }

+Tensor::Tensor(Tensor&& other) noexcept
+    : shape(std::move(other.shape)),
+      dtype(other.dtype),
+      total_elms(other.total_elms),
+      total_size(other.total_size),
+      backend(other.backend),
+      d_ptr(other.d_ptr) {
+    other.d_ptr   = nullptr;
+    other.backend = nullptr;
+}
+
+Tensor& Tensor::operator=(Tensor&& other) noexcept {
+    if (this != &other) {
+        // Clean up our current resources
+        if (d_ptr != nullptr && backend != nullptr) {
+            backend->deallocate(d_ptr);
+        }
+
+        // Steal other's resources
+        shape      = std::move(other.shape);
+        dtype      = other.dtype;
+        total_elms = other.total_elms;
+        total_size = other.total_size;
+        backend    = other.backend;
+        d_ptr      = other.d_ptr;
+
+        // Leave other in a valid but empty state
+        other.d_ptr   = nullptr;
+        other.backend = nullptr;
+    }
+    return *this;
+}
+
 Tensor::~Tensor() {
     backend->deallocate(d_ptr);
     d_ptr = nullptr;

@@ -57,5 +96,5 @@ void Tensor::zero() {
 template <typename T>
 void Tensor::set_data(T *data) {
-    backend->copy_to_device(*this, data, total_size)
+    backend->copy_to_device(*this, data, total_size);
 }
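
One loose end the new move support exposes: the destructor, unchanged in this commit, still calls backend->deallocate(d_ptr) unconditionally, while both move operations leave the source object with backend == nullptr. Destroying a moved-from tensor would therefore dereference a null backend pointer. A guarded destructor, as a sketch of the likely follow-up fix (not part of this commit):

// Sketch only: tolerate the moved-from state the new move operations
// introduce, where both d_ptr and backend are nullptr.
Tensor::~Tensor() {
    if (d_ptr != nullptr && backend != nullptr) {
        backend->deallocate(d_ptr);
    }
    d_ptr = nullptr;
}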