Compare commits

..

3 Commits

12 changed files with 211 additions and 297 deletions

View File

@@ -1,6 +1,7 @@
#pragma once #pragma once
#include <cstddef> #include <cstddef>
#include "backend/tensor.hpp"
namespace CUDANet::Backend namespace CUDANet::Backend
{ {
@@ -13,7 +14,13 @@ public:
virtual void* allocate(size_t bytes) = 0; virtual void* allocate(size_t bytes) = 0;
virtual void deallocate(void* ptr) = 0; virtual void deallocate(void* ptr) = 0;
// Layer operations // Tensor ops
virtual void print(const CUDANet::Backend::Tensor &input) = 0;
virtual void clear(CUDANet::Backend::Tensor &input) = 0;
virtual void sum(const CUDANet::Backend::Tensor &input, CUDANet::Backend::Tensor &sum) = 0;
virtual void max(const CUDANet::Backend::Tensor &input, CUDANet::Backend::Tensor &max) = 0;
// Layer ops
virtual void relu(CUDANet::Backend::Tensor &tensor) = 0; virtual void relu(CUDANet::Backend::Tensor &tensor) = 0;
virtual void sigmoid(CUDANet::Backend::Tensor &tensor) = 0; virtual void sigmoid(CUDANet::Backend::Tensor &tensor) = 0;
virtual void softmax(CUDANet::Backend::Tensor &tensor, CUDANet::Backend::Tensor &temp_max, CUDANet::Backend::Tensor &temp_sum) = 0; virtual void softmax(CUDANet::Backend::Tensor &tensor, CUDANet::Backend::Tensor &temp_max, CUDANet::Backend::Tensor &temp_sum) = 0;

View File

@@ -11,7 +11,13 @@ public:
void* allocate(size_t bytes) override; void* allocate(size_t bytes) override;
void deallocate(void* ptr) override; void deallocate(void* ptr) override;
// Layer operations // Tensor ops
void print(const CUDANet::Backend::Tensor &input) override;
void clear(CUDANet::Backend::Tensor &input) override;
void sum(const CUDANet::Backend::Tensor &input, CUDANet::Backend::Tensor &sum) override;
void max(const CUDANet::Backend::Tensor &input, CUDANet::Backend::Tensor &max) override;
// Layer ops
void relu(CUDANet::Backend::Tensor &tensor) override; void relu(CUDANet::Backend::Tensor &tensor) override;
void sigmoid(CUDANet::Backend::Tensor &tensor) override; void sigmoid(CUDANet::Backend::Tensor &tensor) override;
void softmax(CUDANet::Backend::Tensor &tensor, CUDANet::Backend::Tensor &temp_max, CUDANet::Backend::Tensor &temp_sum) override; void softmax(CUDANet::Backend::Tensor &tensor, CUDANet::Backend::Tensor &temp_max, CUDANet::Backend::Tensor &temp_sum) override;

View File

@@ -23,22 +23,24 @@ public:
Tensor(Shape shape, DType dtype, IBackend* backend); Tensor(Shape shape, DType dtype, IBackend* backend);
~Tensor(); ~Tensor();
void* allocate();
void deallocate();
void toDevice(const void* hostPtr);
void toHost(void* hostPtr);
size_t size() const; size_t size() const;
size_t numel() const; size_t numel() const;
void* data() const;
template <typename T>
const T* data() const;
template <typename T>
T* data();
private: private:
Shape shape; Shape shape;
DType dtype; DType dtype;
size_t total_elms;
size_t total_size;
IBackend* backend; IBackend* backend;
void* devicePtr; void* d_ptr;
void* hostPtr;
}; };
} // namespace CUDANet::Backend } // namespace CUDANet::Backend

View File

@@ -30,7 +30,7 @@ class Activation {
* @param activation Type of activation * @param activation Type of activation
* @param length Length of the input * @param length Length of the input
*/ */
Activation(ActivationType activation, const int length); Activation(CUDANet::Backend::IBackend* backend, ActivationType activation, const int length);
/** /**
* @brief Destroy the Activation object * @brief Destroy the Activation object

View File

@@ -1,63 +0,0 @@
#ifndef CUDANET_VECTOR_H
#define CUDANET_VECTOR_H
namespace CUDANet::Utils {
/**
* @brief Utility function that prints a vector
*
* @param d_vec Pointer to the vector on device
* @param length Length of the vector
*/
void print_vec(const float *d_vec, const unsigned int length);
/**
* @brief Utility function that clears a vector
*
* @param d_vector Pointer to the vector on device
* @param len Length of the vector
*/
void clear(float *d_vector, const unsigned int len);
/**
* @brief Utility function that returns the sum of a vector
*
* @param d_vec Pointer to the vector
* @param length Length of the vector
*/
void sum(const float *d_vec, float *d_sum, const unsigned int length);
/**
* @brief Get the max of a vector
*
* @param d_vec Pointer to the vector
* @param length Length of the vector
*/
void max(const float *d_vec, float *d_max, const unsigned int length);
/**
* @brief Compute the mean of the vector
*
* @param d_vec Device pointer to the vector
* @param d_mean Device pointer to the mean
* @param d_length Device pointer to the length
* @param length Length of the vector
*/
void mean(const float *d_vec, float *d_mean, float *d_length, int length);
/**
* @brief Compute the variance of a vector
*
* @param d_vec
* @param d_var
* @param length
*/
void var(float *d_vec, float *d_var, float *d_length, const unsigned int length);
} // namespace CUDANet::Utils
#endif // CUDANET_VECTOR_H

View File

@@ -1,69 +1,39 @@
#include "backend/cuda_backend.cuh" #include <cuda_runtime.h>
#include "utils/cuda_helper.cuh"
#include "kernels/activation_functions.cuh" #include <cstdio>
#include "kernels/matmul.cuh" #include <cstdlib>
#include "utils/vector.cuh" #include <cuda_helper.cuh>
#include "backend/cuda.cuh"
cudaDeviceProp initializeCUDA() {
int deviceCount;
CUDA_CHECK(cudaGetDeviceCount(&deviceCount));
if (deviceCount == 0) {
std::fprintf(stderr, "No CUDA devices found. Exiting.\n");
std::exit(EXIT_FAILURE);
}
int device = 0;
CUDA_CHECK(cudaSetDevice(device));
cudaDeviceProp deviceProp;
CUDA_CHECK(cudaGetDeviceProperties(&deviceProp, device));
std::printf("Using CUDA device %d: %s\n", device, deviceProp.name);
return deviceProp;
}
using namespace CUDANet::Backend; using namespace CUDANet::Backend;
void* CUDABackend::allocate(size_t bytes) { void* CUDABackend::allocate(size_t bytes) {
void* devicePtr = nullptr; void* d_ptr = nullptr;
CUDA_CHECK(cudaMalloc(&devicePtr, bytes)); CUDA_CHECK(cudaMalloc(&d_ptr, bytes));
return devicePtr; return d_ptr;
} }
void CUDABackend::deallocate(void* ptr) { void CUDABackend::deallocate(void* ptr) {
CUDA_CHECK(cudaFree(ptr)); CUDA_CHECK(cudaFree(ptr));
} }
// void CUDABackend::copyToDevice(void* devicePtr, const void* hostPtr, size_t bytes) {
// CUDA_CHECK(cudaMemcpy(devicePtr, hostPtr, bytes, cudaMemcpyHostToDevice));
// CUDA_CHECK(cudaDeviceSynchronize());
// }
// void CUDABackend::copyToHost(void* hostPtr, const void* devicePtr, size_t bytes) {
// CUDA_CHECK(cudaMemcpy(hostPtr, devicePtr, bytes, cudaMemcpyDeviceToHost));
// CUDA_CHECK(cudaDeviceSynchronize());
// }
void CUDABackend::relu(Tensor &tensor) {
int gridSize = (tensor.numel() + BLOCK_SIZE - 1) / BLOCK_SIZE;
Kernels::relu<<<gridSize, BLOCK_SIZE>>>((float*)tensor.data(), (float*)tensor.data(), tensor.numel());
CUDA_CHECK(cudaGetLastError());
CUDA_CHECK(cudaDeviceSynchronize());
}
void CUDABackend::sigmoid(Tensor &tensor) {
int gridSize = (tensor.numel() + BLOCK_SIZE - 1) / BLOCK_SIZE;
Kernels::sigmoid<<<gridSize, BLOCK_SIZE>>>((float*)tensor.data(), (float*)tensor.data(), tensor.numel());
CUDA_CHECK(cudaGetLastError());
CUDA_CHECK(cudaDeviceSynchronize());
}
void CUDABackend::softmax(Tensor &tensor, Tensor &temp_max, Tensor &temp_sum) {
int gridSize = (tensor.numel() + BLOCK_SIZE - 1) / BLOCK_SIZE;
// Find max value
Utils::max(tensor, temp_max, tensor.numel());
// Subtract max value to improve numerical stability
Kernels::vec_scalar_sub<<<gridSize, BLOCK_SIZE>>>(
(float*)tensor.data(), (float*)tensor.data(), (float*)temp_max.data(), tensor.numel()
);
CUDA_CHECK(cudaGetLastError());
// Compute exponentials
Kernels::vec_exp<<<gridSize, BLOCK_SIZE>>>(
(float*)tensor.data(), (float*)tensor.data(), tensor.numel()
);
CUDA_CHECK(cudaGetLastError());
// Find sum
Utils::sum(tensor, temp_sum, tensor.numel());
Kernels::vec_scalar_div<<<gridSize, BLOCK_SIZE>>>(
(float*)tensor.data(), (float*)tensor.data(), (float*)temp_sum.data(), tensor.numel()
);
CUDA_CHECK(cudaGetLastError());
CUDA_CHECK(cudaDeviceSynchronize());
}

View File

@@ -0,0 +1,48 @@
#include "backend/cuda.cuh"
#include "utils/cuda_helper.cuh"
#include "kernels/activation_functions.cuh"
#include "kernels/matmul.cuh"
using namespace CUDANet::Backend;
void CUDABackend::relu(Tensor &tensor) {
int gridSize = (tensor.numel() + BLOCK_SIZE - 1) / BLOCK_SIZE;
Kernels::relu<<<gridSize, BLOCK_SIZE>>>(tensor.data<float>(), tensor.data<float>(), tensor.numel());
CUDA_CHECK(cudaGetLastError());
CUDA_CHECK(cudaDeviceSynchronize());
}
void CUDABackend::sigmoid(Tensor &tensor) {
int gridSize = (tensor.numel() + BLOCK_SIZE - 1) / BLOCK_SIZE;
Kernels::sigmoid<<<gridSize, BLOCK_SIZE>>>(tensor.data<float>(), tensor.data<float>(), tensor.numel());
CUDA_CHECK(cudaGetLastError());
CUDA_CHECK(cudaDeviceSynchronize());
}
void CUDABackend::softmax(Tensor &tensor, Tensor &temp_max, Tensor &temp_sum) {
int gridSize = (tensor.numel() + BLOCK_SIZE - 1) / BLOCK_SIZE;
// Find max value
max(tensor, temp_max);
// Subtract max value to improve numerical stability
Kernels::vec_scalar_sub<<<gridSize, BLOCK_SIZE>>>(
tensor.data<float>(), tensor.data<float>(), temp_max.data<float>(), tensor.numel()
);
CUDA_CHECK(cudaGetLastError());
// Compute exponentials
Kernels::vec_exp<<<gridSize, BLOCK_SIZE>>>(
tensor.data<float>(), tensor.data<float>(), tensor.numel()
);
CUDA_CHECK(cudaGetLastError());
// Find sum
sum(tensor, temp_sum);
Kernels::vec_scalar_div<<<gridSize, BLOCK_SIZE>>>(
tensor.data<float>(), tensor.data<float>(), temp_sum.data<float>(), tensor.numel()
);
CUDA_CHECK(cudaGetLastError());
CUDA_CHECK(cudaDeviceSynchronize());
}

View File

@@ -0,0 +1,64 @@
#include <iostream>
#include "backend/backend.hpp"
#include "backend/cuda.cuh"
#include "utils/cuda_helper.cuh"
#include "kernels/matmul.cuh"
using namespace CUDANet::Backend;
void CUDABackend::print(const CUDANet::Backend::Tensor &input) {
auto length = input.numel();
std::vector<float> h_vec(input.numel());
CUDA_CHECK(cudaMemcpy(
h_vec.data(), input.data<float>(), sizeof(float) * length, cudaMemcpyDeviceToHost
));
for (int i = 0; i < length; ++i) {
std::cout << h_vec[i] << ", ";
}
std::cout << std::endl;
}
void CUDABackend::clear(CUDANet::Backend::Tensor &input) {
CUDA_CHECK(cudaMemset(input.data<float>(), 0, sizeof(float) * input.numel()));
}
void CUDABackend::sum(const CUDANet::Backend::Tensor &input, CUDANet::Backend::Tensor &sum) {
auto length = input.numel();
const int gridSize = ( + BLOCK_SIZE - 1) / BLOCK_SIZE;
CUDANet::Kernels::sum_reduce<<<gridSize, BLOCK_SIZE>>>(
input.data<float>(), sum.data<float>(), length
);
CUDA_CHECK(cudaGetLastError());
int remaining = gridSize;
while (remaining > 1) {
int blocks_needed = (remaining + BLOCK_SIZE - 1) / BLOCK_SIZE;
CUDANet::Kernels::sum_reduce<<<blocks_needed, BLOCK_SIZE>>>(sum.data<float>(), sum.data<float>(), remaining);
CUDA_CHECK(cudaGetLastError());
remaining = blocks_needed;
}
}
void CUDABackend::max(const CUDANet::Backend::Tensor &input, CUDANet::Backend::Tensor &max) {
auto length = input.numel();
const int grid_size = (length + BLOCK_SIZE - 1) / BLOCK_SIZE;
Kernels::max_reduce<<<grid_size, BLOCK_SIZE>>>(input.data<float>(), max.data<float>(), length);
CUDA_CHECK(cudaGetLastError());
int remaining = grid_size;
while (remaining > 1) {
int blocks_needed = (remaining + BLOCK_SIZE - 1) / BLOCK_SIZE;
CUDANet::Kernels::max_reduce<<<blocks_needed, BLOCK_SIZE>>>(max.data<float>(), max.data<float>(), remaining);
CUDA_CHECK(cudaGetLastError());
remaining = blocks_needed;
}
}

View File

@@ -1,26 +0,0 @@
#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>
#include "cuda_helper.cuh"
cudaDeviceProp initializeCUDA() {
int deviceCount;
CUDA_CHECK(cudaGetDeviceCount(&deviceCount));
if (deviceCount == 0) {
std::fprintf(stderr, "No CUDA devices found. Exiting.\n");
std::exit(EXIT_FAILURE);
}
int device = 0;
CUDA_CHECK(cudaSetDevice(device));
cudaDeviceProp deviceProp;
CUDA_CHECK(cudaGetDeviceProperties(&deviceProp, device));
std::printf("Using CUDA device %d: %s\n", device, deviceProp.name);
return deviceProp;
}

View File

@@ -1,107 +0,0 @@
#include <iostream>
#include <vector>
#include "vector.cuh"
#include "matmul.cuh"
#include "cuda_helper.cuh"
using namespace CUDANet;
void Utils::print_vec(const float* d_vec, const unsigned int length) {
std::vector<float> h_vec(length);
CUDA_CHECK(cudaMemcpy(
h_vec.data(), d_vec, sizeof(float) * length, cudaMemcpyDeviceToHost
));
for (int i = 0; i < length; ++i) {
std::cout << h_vec[i] << ", ";
}
std::cout << std::endl;
}
void Utils::clear(float* d_vec, const unsigned int length) {
CUDA_CHECK(cudaMemset(d_vec, 0, sizeof(float) * length));
}
void Utils::max(const float* d_vec, float* d_max, const unsigned int length) {
const int grid_size = (length + BLOCK_SIZE - 1) / BLOCK_SIZE;
Kernels::max_reduce<<<grid_size, BLOCK_SIZE>>>(d_vec, d_max, length);
CUDA_CHECK(cudaGetLastError());
int remaining = grid_size;
while (remaining > 1) {
int blocks_needed = (remaining + BLOCK_SIZE - 1) / BLOCK_SIZE;
CUDANet::Kernels::max_reduce<<<blocks_needed, BLOCK_SIZE>>>(d_max, d_max, remaining);
CUDA_CHECK(cudaGetLastError());
remaining = blocks_needed;
}
}
void Utils::sum(const float* d_vec, float* d_sum, const unsigned int length) {
const int gridSize = (length + BLOCK_SIZE - 1) / BLOCK_SIZE;
CUDANet::Kernels::sum_reduce<<<gridSize, BLOCK_SIZE>>>(
d_vec, d_sum, length
);
CUDA_CHECK(cudaGetLastError());
int remaining = gridSize;
while (remaining > 1) {
int blocks_needed = (remaining + BLOCK_SIZE - 1) / BLOCK_SIZE;
CUDANet::Kernels::sum_reduce<<<blocks_needed, BLOCK_SIZE>>>(d_sum, d_sum, remaining);
CUDA_CHECK(cudaGetLastError());
remaining = blocks_needed;
}
}
void Utils::mean(const float* d_vec, float* d_mean, float *d_length, int length) {
Utils::sum(d_vec, d_mean, length);
const int gridSize = (length + BLOCK_SIZE - 1) / BLOCK_SIZE;
Kernels::vec_scalar_div<<<gridSize, BLOCK_SIZE>>>(
d_mean,
d_mean,
d_length,
length
);
CUDA_CHECK(cudaGetLastError());
}
void Utils::var(float* d_vec, float* d_var, float *d_length, const unsigned int length) {
const int gridSize = (length + BLOCK_SIZE - 1) / BLOCK_SIZE;
Kernels::vec_vec_mul<<<gridSize, BLOCK_SIZE>>>(
d_vec,
d_vec,
d_var,
length
);
CUDA_CHECK(cudaGetLastError());
// Sum over all differences
Utils::sum(
d_var,
d_var,
length
);
// Divide by difference sum / length -> variance
Kernels::vec_scalar_div<<<gridSize, BLOCK_SIZE>>>(
d_var,
d_var,
d_length,
length
);
CUDA_CHECK(cudaGetLastError());
}

View File

@@ -1,39 +1,52 @@
#include <stdexcept>
#include "backend/tensor.hpp" #include "backend/tensor.hpp"
#include <stdexcept>
using namespace CUDANet::Backend; using namespace CUDANet::Backend;
Tensor::Tensor(Shape shape, DType dtype, IBackend* backend) Tensor::Tensor(Shape shape, DType dtype, IBackend* backend)
: shape(shape), dtype(dtype), backend(backend), devicePtr(nullptr), hostPtr(nullptr) {} : shape(shape), dtype(dtype), backend(backend), d_ptr(nullptr) {
// Count total elements
Tensor::~Tensor() { size_t count = 1;
deallocate();
}
size_t Tensor::numel() const {
size_t totalElements = 1;
for (const auto& dim : shape) { for (const auto& dim : shape) {
totalElements *= dim; count *= dim;
}
return totalElements;
} }
total_elms = count;
size_t Tensor::size() const { // Compute total size (bytes)
size_t totalSize = numel(); size_t type_size = 0;
size_t typeSize = 0;
switch (dtype) { switch (dtype) {
case DType::FLOAT32: case DType::FLOAT32:
typeSize = 4; type_size = 4;
break; break;
default: default:
throw std::runtime_error("Unsupported data type"); throw std::runtime_error("Unsupported data type");
} }
total_size = total_elms * type_size;
return totalSize * typeSize; // Allocate memory on backend
d_ptr = backend->allocate(total_size);
} }
void* Tensor::data() const { Tensor::~Tensor() {
return devicePtr; backend->deallocate(d_ptr);
d_ptr = nullptr;
}
size_t Tensor::numel() const {
return total_elms;
}
size_t Tensor::size() const {
return total_size;
}
template <typename T>
const T* Tensor::data() const {
return static_cast<T*>(d_ptr);
}
template <typename T>
T* Tensor::data() {
return static_cast<T*>(d_ptr);
} }

View File

@@ -6,13 +6,13 @@
using namespace CUDANet::Layers; using namespace CUDANet::Layers;
Activation::Activation(ActivationType activation, const int length) Activation::Activation(CUDANet::Backend::IBackend* backend, ActivationType activation, const int length)
: activationType(activation), length(length) { : backend(backend), activationType(activation), length(length) {
if (activationType == SOFTMAX) { if (activationType == SOFTMAX) {
softmax_sum = CUDANet::Backend::Tensor({static_cast<size_t>(length)}, CUDANet::Backend::DType::FLOAT32, nullptr); softmax_sum = CUDANet::Backend::Tensor({static_cast<size_t>(length)}, CUDANet::Backend::DType::FLOAT32, backend);
tensor_max = CUDANet::Backend::Tensor({static_cast<size_t>(length)}, CUDANet::Backend::DType::FLOAT32, nullptr); tensor_max = CUDANet::Backend::Tensor({static_cast<size_t>(length)}, CUDANet::Backend::DType::FLOAT32, backend);
} }
} }
@@ -23,10 +23,10 @@ void Activation::activate(CUDANet::Backend::Tensor input) {
backend->sigmoid(input); backend->sigmoid(input);
break; break;
case ActivationType::RELU: case ActivationType::RELU:
/* code */ backend->relu(input);
break; break;
case ActivationType::SOFTMAX: case ActivationType::SOFTMAX:
/* code */ backend->softmax(input, tensor_max, softmax_sum);
break; break;
default: default:
break; break;