mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-12-22 14:24:22 +00:00
Fix some dense layer issues
This commit is contained in:
@@ -32,12 +32,12 @@ class Backend {
|
|||||||
) = 0;
|
) = 0;
|
||||||
|
|
||||||
virtual CUDANet::Tensor& dense(
|
virtual CUDANet::Tensor& dense(
|
||||||
CUDANet::Tensor& weights,
|
const CUDANet::Tensor& weights,
|
||||||
CUDANet::Tensor& biases,
|
const CUDANet::Tensor& biases,
|
||||||
CUDANet::Tensor& input,
|
const CUDANet::Tensor& input,
|
||||||
CUDANet::Tensor& output,
|
CUDANet::Tensor& output,
|
||||||
size_t input_size,
|
const size_t input_size,
|
||||||
size_t output_size
|
const size_t output_size
|
||||||
) = 0;
|
) = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -29,12 +29,12 @@ class CUDA : public Backend {
|
|||||||
) override;
|
) override;
|
||||||
|
|
||||||
CUDANet::Tensor& dense(
|
CUDANet::Tensor& dense(
|
||||||
CUDANet::Tensor& weights,
|
const CUDANet::Tensor& weights,
|
||||||
CUDANet::Tensor& biases,
|
const CUDANet::Tensor& biases,
|
||||||
CUDANet::Tensor& input,
|
const CUDANet::Tensor& input,
|
||||||
CUDANet::Tensor& output,
|
CUDANet::Tensor& output,
|
||||||
size_t input_size,
|
const size_t input_size,
|
||||||
size_t output_size
|
const size_t output_size
|
||||||
) override;
|
) override;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ class Layer {
|
|||||||
|
|
||||||
virtual ~Layer(){};
|
virtual ~Layer(){};
|
||||||
|
|
||||||
virtual CUDANet::Tensor& forward(CUDANet::Tensor &input) = 0;
|
virtual CUDANet::Tensor& forward(const CUDANet::Tensor &input) = 0;
|
||||||
|
|
||||||
virtual CUDANet::Shape input_shape() = 0;
|
virtual CUDANet::Shape input_shape() = 0;
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ class Dense : public Layer {
|
|||||||
|
|
||||||
~Dense();
|
~Dense();
|
||||||
|
|
||||||
CUDANet::Tensor& forward(CUDANet::Tensor &input) override;
|
CUDANet::Tensor& forward(const CUDANet::Tensor &input) override;
|
||||||
|
|
||||||
CUDANet::Shape input_shape() override;
|
CUDANet::Shape input_shape() override;
|
||||||
|
|
||||||
|
|||||||
@@ -22,6 +22,12 @@ public:
|
|||||||
|
|
||||||
Tensor() = default;
|
Tensor() = default;
|
||||||
Tensor(Shape shape, DType dtype, CUDANet::Backend* backend);
|
Tensor(Shape shape, DType dtype, CUDANet::Backend* backend);
|
||||||
|
|
||||||
|
Tensor(Tensor&& other) noexcept;
|
||||||
|
Tensor& operator=(Tensor&& other) noexcept;
|
||||||
|
Tensor(const Tensor&) = delete;
|
||||||
|
Tensor& operator=(const Tensor&) = delete;
|
||||||
|
|
||||||
~Tensor();
|
~Tensor();
|
||||||
|
|
||||||
size_t size() const;
|
size_t size() const;
|
||||||
|
|||||||
@@ -1,20 +1,24 @@
|
|||||||
#include "backend/cuda.cuh"
|
#include "backend/cuda.cuh"
|
||||||
#include "utils/cuda_helper.cuh"
|
|
||||||
#include "kernels/activation_functions.cuh"
|
#include "kernels/activation_functions.cuh"
|
||||||
#include "kernels/matmul.cuh"
|
#include "kernels/matmul.cuh"
|
||||||
|
#include "utils/cuda_helper.cuh"
|
||||||
|
|
||||||
using namespace CUDANet::Backend;
|
using namespace CUDANet::Backend;
|
||||||
|
|
||||||
void CUDA::relu(Tensor& tensor) {
|
void CUDA::relu(Tensor& tensor) {
|
||||||
int gridSize = (tensor.numel() + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
int gridSize = (tensor.numel() + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||||
Kernels::relu<<<gridSize, BLOCK_SIZE>>>(tensor.data<float>(), tensor.data<float>(), tensor.numel());
|
Kernels::relu<<<gridSize, BLOCK_SIZE>>>(
|
||||||
|
tensor.data<float>(), tensor.data<float>(), tensor.numel()
|
||||||
|
);
|
||||||
CUDA_CHECK(cudaGetLastError());
|
CUDA_CHECK(cudaGetLastError());
|
||||||
CUDA_CHECK(cudaDeviceSynchronize());
|
CUDA_CHECK(cudaDeviceSynchronize());
|
||||||
}
|
}
|
||||||
|
|
||||||
void CUDA::sigmoid(Tensor& tensor) {
|
void CUDA::sigmoid(Tensor& tensor) {
|
||||||
int gridSize = (tensor.numel() + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
int gridSize = (tensor.numel() + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||||
Kernels::sigmoid<<<gridSize, BLOCK_SIZE>>>(tensor.data<float>(), tensor.data<float>(), tensor.numel());
|
Kernels::sigmoid<<<gridSize, BLOCK_SIZE>>>(
|
||||||
|
tensor.data<float>(), tensor.data<float>(), tensor.numel()
|
||||||
|
);
|
||||||
CUDA_CHECK(cudaGetLastError());
|
CUDA_CHECK(cudaGetLastError());
|
||||||
CUDA_CHECK(cudaDeviceSynchronize());
|
CUDA_CHECK(cudaDeviceSynchronize());
|
||||||
}
|
}
|
||||||
@@ -27,7 +31,8 @@ void CUDA::softmax(Tensor &tensor, Tensor &temp_max, Tensor &temp_sum) {
|
|||||||
|
|
||||||
// Subtract max value to improve numerical stability
|
// Subtract max value to improve numerical stability
|
||||||
Kernels::vec_scalar_sub<<<gridSize, BLOCK_SIZE>>>(
|
Kernels::vec_scalar_sub<<<gridSize, BLOCK_SIZE>>>(
|
||||||
tensor.data<float>(), tensor.data<float>(), temp_max.data<float>(), tensor.numel()
|
tensor.data<float>(), tensor.data<float>(), temp_max.data<float>(),
|
||||||
|
tensor.numel()
|
||||||
);
|
);
|
||||||
CUDA_CHECK(cudaGetLastError());
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
@@ -41,25 +46,34 @@ void CUDA::softmax(Tensor &tensor, Tensor &temp_max, Tensor &temp_sum) {
|
|||||||
sum(tensor, temp_sum);
|
sum(tensor, temp_sum);
|
||||||
|
|
||||||
Kernels::vec_scalar_div<<<gridSize, BLOCK_SIZE>>>(
|
Kernels::vec_scalar_div<<<gridSize, BLOCK_SIZE>>>(
|
||||||
tensor.data<float>(), tensor.data<float>(), temp_sum.data<float>(), tensor.numel()
|
tensor.data<float>(), tensor.data<float>(), temp_sum.data<float>(),
|
||||||
|
tensor.numel()
|
||||||
);
|
);
|
||||||
CUDA_CHECK(cudaGetLastError());
|
CUDA_CHECK(cudaGetLastError());
|
||||||
CUDA_CHECK(cudaDeviceSynchronize());
|
CUDA_CHECK(cudaDeviceSynchronize());
|
||||||
}
|
}
|
||||||
|
|
||||||
CUDANet::Tensor& CUDA::dense(CUDANet::Tensor &weights, CUDANet::Tensor &biases, CUDANet::Tensor &input, CUDANet::Tensor &output, size_t input_size, size_t output_size) {
|
CUDANet::Tensor& CUDA::dense(
|
||||||
|
const CUDANet::Tensor& weights,
|
||||||
|
const CUDANet::Tensor& biases,
|
||||||
|
const CUDANet::Tensor& input,
|
||||||
|
CUDANet::Tensor& output,
|
||||||
|
const size_t input_size,
|
||||||
|
const size_t output_size
|
||||||
|
) {
|
||||||
auto forwardGridSize =
|
auto forwardGridSize =
|
||||||
(std::max(input_size, output_size) + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
(std::max(input_size, output_size) + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||||
auto biasGridSize = (output_size + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
auto biasGridSize = (output_size + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||||
|
|
||||||
Kernels::mat_vec_mul<<<forwardGridSize, BLOCK_SIZE>>>(
|
Kernels::mat_vec_mul<<<forwardGridSize, BLOCK_SIZE>>>(
|
||||||
weights.data<float>(), input.data<float>(), output.data<float>(), input_size, output_size
|
weights.data<float>(), input.data<float>(), output.data<float>(),
|
||||||
|
input_size, output_size
|
||||||
);
|
);
|
||||||
CUDA_CHECK(cudaGetLastError());
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
Kernels::vec_vec_add<<<biasGridSize, BLOCK_SIZE>>>(
|
Kernels::vec_vec_add<<<biasGridSize, BLOCK_SIZE>>>(
|
||||||
biases.data<float>(), output.data<float>(), output.data<float>(), output_size
|
biases.data<float>(), output.data<float>(), output.data<float>(),
|
||||||
|
output_size
|
||||||
);
|
);
|
||||||
CUDA_CHECK(cudaGetLastError());
|
CUDA_CHECK(cudaGetLastError());
|
||||||
CUDA_CHECK(cudaDeviceSynchronize());
|
CUDA_CHECK(cudaDeviceSynchronize());
|
||||||
|
|||||||
@@ -1,34 +1,41 @@
|
|||||||
|
#include "dense.hpp"
|
||||||
|
|
||||||
#include <format>
|
#include <format>
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
|
|
||||||
#include "dense.hpp"
|
|
||||||
|
|
||||||
using namespace CUDANet::Layers;
|
using namespace CUDANet::Layers;
|
||||||
|
|
||||||
Dense::Dense(CUDANet::Backend *backend, CUDANet::Shape input_shape, CUDANet::Shape output_shape)
|
Dense::Dense(CUDANet::Backend* backend, CUDANet::Shape in, CUDANet::Shape out)
|
||||||
: backend(backend), in_shape(input_shape), out_shape(output_shape) {
|
: backend(backend),
|
||||||
|
in_shape(in),
|
||||||
|
out_shape(out),
|
||||||
|
weights(
|
||||||
|
CUDANet::Tensor{{in[0] * out[0]}, CUDANet::DType::FLOAT32, backend}
|
||||||
|
),
|
||||||
|
biases(CUDANet::Tensor({out[0]}, CUDANet::DType::FLOAT32, backend)),
|
||||||
|
output(CUDANet::Tensor({out[0]}, CUDANet::DType::FLOAT32, backend)) {
|
||||||
// Allocate memory for weights and biases
|
// Allocate memory for weights and biases
|
||||||
|
|
||||||
if (input_shape.size() != 1) {
|
if (in.size() != 1) {
|
||||||
throw std::runtime_error(std::format("Invalid shape. Expected [1], got {}", input_shape));
|
throw std::runtime_error(
|
||||||
|
std::format("Invalid shape. Expected [1], got {}", in)
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (output_shape.size() != 1) {
|
if (out.size() != 1) {
|
||||||
throw std::runtime_error(std::format("Invalid shape. Expected [1], got {}", output_shape));
|
throw std::runtime_error(
|
||||||
|
std::format("Invalid shape. Expected [1], got {}", out)
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto input_len = input_shape[0];
|
auto input_len = in[0];
|
||||||
auto output_len = output_shape[0];
|
auto output_len = out[0];
|
||||||
|
|
||||||
auto weights = CUDANet::Tensor{Shape(input_len * output_len), CUDANet::DType::FLOAT32, backend};
|
|
||||||
auto biases = CUDANet::Tensor(Shape(output_len), CUDANet::DType::FLOAT32, backend);
|
|
||||||
auto output = CUDANet::Tensor(Shape(output_len), CUDANet::DType::FLOAT32, backend);
|
|
||||||
|
|
||||||
weights.zero();
|
weights.zero();
|
||||||
biases.zero();
|
biases.zero();
|
||||||
}
|
}
|
||||||
|
|
||||||
CUDANet::Tensor& Dense::forward(CUDANet::Tensor &input) {
|
CUDANet::Tensor& Dense::forward(const CUDANet::Tensor& input) {
|
||||||
backend->dense(weights, biases, input, output, in_shape[0], out_shape[0]);
|
backend->dense(weights, biases, input, output, in_shape[0], out_shape[0]);
|
||||||
return output;
|
return output;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,6 +6,11 @@ using namespace CUDANet;
|
|||||||
|
|
||||||
Tensor::Tensor(Shape shape, DType dtype, Backend* backend)
|
Tensor::Tensor(Shape shape, DType dtype, Backend* backend)
|
||||||
: shape(shape), dtype(dtype), backend(backend), d_ptr(nullptr) {
|
: shape(shape), dtype(dtype), backend(backend), d_ptr(nullptr) {
|
||||||
|
|
||||||
|
if (shape.empty()) {
|
||||||
|
throw std::runtime_error("Tensor shape cannot be empty");
|
||||||
|
}
|
||||||
|
|
||||||
// Count total elements
|
// Count total elements
|
||||||
size_t count = 1;
|
size_t count = 1;
|
||||||
for (const auto& dim : shape) {
|
for (const auto& dim : shape) {
|
||||||
@@ -28,6 +33,40 @@ Tensor::Tensor(Shape shape, DType dtype, Backend* backend)
|
|||||||
d_ptr = backend->allocate(total_size);
|
d_ptr = backend->allocate(total_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Tensor::Tensor(Tensor&& other) noexcept
|
||||||
|
: shape(std::move(other.shape)),
|
||||||
|
dtype(other.dtype),
|
||||||
|
total_elms(other.total_elms),
|
||||||
|
total_size(other.total_size),
|
||||||
|
backend(other.backend),
|
||||||
|
d_ptr(other.d_ptr)
|
||||||
|
{
|
||||||
|
other.d_ptr = nullptr;
|
||||||
|
other.backend = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor& Tensor::operator=(Tensor&& other) noexcept {
|
||||||
|
if (this != &other) {
|
||||||
|
// Clean up our current resources
|
||||||
|
if (d_ptr != nullptr && backend != nullptr) {
|
||||||
|
backend->deallocate(d_ptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Steal other's resources
|
||||||
|
shape = std::move(other.shape);
|
||||||
|
dtype = other.dtype;
|
||||||
|
total_elms = other.total_elms;
|
||||||
|
total_size = other.total_size;
|
||||||
|
backend = other.backend;
|
||||||
|
d_ptr = other.d_ptr;
|
||||||
|
|
||||||
|
// Leave other in valid but empty state
|
||||||
|
other.d_ptr = nullptr;
|
||||||
|
other.backend = nullptr;
|
||||||
|
}
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
Tensor::~Tensor() {
|
Tensor::~Tensor() {
|
||||||
backend->deallocate(d_ptr);
|
backend->deallocate(d_ptr);
|
||||||
d_ptr = nullptr;
|
d_ptr = nullptr;
|
||||||
@@ -57,5 +96,5 @@ void Tensor::zero() {
|
|||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void Tensor::set_data(T *data) {
|
void Tensor::set_data(T *data) {
|
||||||
backend->copy_to_device(*this, data, total_size)
|
backend->copy_to_device(*this, data, total_size);
|
||||||
}
|
}
|
||||||
Reference in New Issue
Block a user