From 7e27c876733b6a518d94c23ff046d848f87ac7de Mon Sep 17 00:00:00 2001 From: LordMathis Date: Thu, 27 Nov 2025 22:41:49 +0100 Subject: [PATCH] Fix compilation errors and warnings --- CMakeLists.txt | 5 +++++ include/backend.hpp | 5 +++-- include/tensor.hpp | 3 +++ src/backend_factory.cpp | 5 ++--- src/backends/cuda/layer_ops.cu | 36 ++++++++++++++++----------------- src/backends/cuda/tensor_ops.cu | 14 ++++++------- src/layers/activation.cpp | 4 ++-- src/layers/add.cpp | 2 +- src/layers/avg_pooling.cpp | 4 ++-- src/layers/batch_norm.cpp | 4 +--- src/layers/concat.cpp | 2 +- src/layers/conv2d.cpp | 3 +-- src/layers/dense.cpp | 4 ++-- src/layers/max_pool.cpp | 4 ++-- src/model.cpp | 6 +++--- src/module.cpp | 4 ++-- src/tensor.cpp | 8 ++++++-- 17 files changed, 61 insertions(+), 52 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index bf8f7a8..b39144b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -43,6 +43,11 @@ set(CMAKE_EXPORT_COMPILE_COMMANDS ON) add_library(${PROJECT_NAME} STATIC ${LIBRARY_SOURCES}) if(USE_CUDA) + # Enable relocatable device code for proper template instantiation across translation units + set_target_properties(${PROJECT_NAME} PROPERTIES + CUDA_SEPARABLE_COMPILATION ON + CUDA_RUNTIME_LIBRARY Shared + ) target_link_libraries(${PROJECT_NAME} CUDA::cudart) endif() diff --git a/include/backend.hpp b/include/backend.hpp index 938be54..b1d97a2 100644 --- a/include/backend.hpp +++ b/include/backend.hpp @@ -8,9 +8,10 @@ namespace CUDANet { -// Forward declaration -class Tensor; +// Forward declarations class Backend; +class Tensor; +enum class DType; enum BackendType { CUDA_BACKEND, CPU_BACKEND }; diff --git a/include/tensor.hpp b/include/tensor.hpp index 691c2aa..3b329c0 100644 --- a/include/tensor.hpp +++ b/include/tensor.hpp @@ -29,6 +29,8 @@ size_t dtype_size(DType dtype) { } } +class Backend; + class Tensor { public: @@ -49,6 +51,7 @@ public: size_t size() const; size_t numel() const; + void* device_ptr() const; void* device_ptr(); void zero(); diff --git a/src/backend_factory.cpp b/src/backend_factory.cpp index 705468b..d2301d4 100644 --- a/src/backend_factory.cpp +++ b/src/backend_factory.cpp @@ -13,6 +13,7 @@ std::unique_ptr BackendFactory::create(BackendType backend_type, const switch (backend_type) { case BackendType::CUDA_BACKEND: + { #ifdef USE_CUDA if (!CUDANet::Backends::CUDA::is_cuda_available()) { @@ -20,14 +21,12 @@ std::unique_ptr BackendFactory::create(BackendType backend_type, const } auto cuda = std::make_unique(config); - cuda.initialize(); - return cuda; #else throw std::runtime_error("Library was compiled without CUDA support."); #endif - + } break; default: diff --git a/src/backends/cuda/layer_ops.cu b/src/backends/cuda/layer_ops.cu index 3d07fc9..4d927f3 100644 --- a/src/backends/cuda/layer_ops.cu +++ b/src/backends/cuda/layer_ops.cu @@ -213,7 +213,7 @@ CUDANet::Tensor& CUDA::conv2d_impl( ); Kernels::convolution<<>>( - static_cast(input.device_ptr())(), static_cast(weights.device_ptr())(), static_cast(biases.device_ptr())(), static_cast(output.device_ptr())(), + static_cast(input.device_ptr()), static_cast(weights.device_ptr()), static_cast(biases.device_ptr()), static_cast(output.device_ptr()), in_shape, padding_shape, kernel_shape, stride_shape, out_shape ); CUDA_CHECK(cudaGetLastError()); @@ -273,7 +273,7 @@ CUDANet::Tensor& CUDA::max_pool2d_impl( ); Kernels::max_pool<<>>( - static_cast(input.device_ptr())(), static_cast(output.device_ptr())(), input_shape, output_shape, + static_cast(input.device_ptr()), static_cast(output.device_ptr()), input_shape, output_shape, pool_shape, stride_shape, padding_shape ); CUDA_CHECK(cudaGetLastError()); @@ -333,7 +333,7 @@ CUDANet::Tensor& CUDA::avg_pool2d_impl( ); Kernels::avg_pool<<>>( - static_cast(input.device_ptr())(), static_cast(output.device_ptr())(), input_shape, output_shape, + static_cast(input.device_ptr()), static_cast(output.device_ptr()), input_shape, output_shape, pool_shape, stride_shape, padding_shape ); CUDA_CHECK(cudaGetLastError()); @@ -394,34 +394,34 @@ CUDANet::Tensor& CUDA::batch_norm_impl( for (int i = 0; i < input_shape[2]; i++) { // Subtract mean from input Kernels::vec_scalar_sub<<>>( - static_cast(input.device_ptr())() + i * input_shape[0] * input_shape[1], - static_cast(output.device_ptr())() + i * input_shape[0] * input_shape[1], - &static_cast(running_mean.device_ptr())()[i], input_shape[0] * input_shape[1] + static_cast(input.device_ptr()) + i * input_shape[0] * input_shape[1], + static_cast(output.device_ptr()) + i * input_shape[0] * input_shape[1], + &static_cast(running_mean.device_ptr())[i], input_shape[0] * input_shape[1] ); CUDA_CHECK(cudaGetLastError()); // Divide by sqrt(running_var + epsilon) Kernels::vec_scale<<>>( - static_cast(output.device_ptr())() + i * input_shape[0] * input_shape[1], - static_cast(output.device_ptr())() + i * input_shape[0] * input_shape[1], - &static_cast(running_var.device_ptr())()[i], static_cast(epsilon.device_ptr())(), + static_cast(output.device_ptr()) + i * input_shape[0] * input_shape[1], + static_cast(output.device_ptr()) + i * input_shape[0] * input_shape[1], + &static_cast(running_var.device_ptr())[i], static_cast(epsilon.device_ptr()), input_shape[0] * input_shape[1] ); CUDA_CHECK(cudaGetLastError()); // Multiply by weights Kernels::vec_scalar_mul<<>>( - static_cast(output.device_ptr())() + i * input_shape[0] * input_shape[1], - static_cast(output.device_ptr())() + i * input_shape[0] * input_shape[1], - &static_cast(weights.device_ptr())()[i], input_shape[0] * input_shape[1] + static_cast(output.device_ptr()) + i * input_shape[0] * input_shape[1], + static_cast(output.device_ptr()) + i * input_shape[0] * input_shape[1], + &static_cast(weights.device_ptr())[i], input_shape[0] * input_shape[1] ); CUDA_CHECK(cudaGetLastError()); // Add biases Kernels::vec_scalar_add<<>>( - static_cast(output.device_ptr())() + i * input_shape[0] * input_shape[1], - static_cast(output.device_ptr())() + i * input_shape[0] * input_shape[1], - &static_cast(biases.device_ptr())()[i], input_shape[0] * input_shape[1] + static_cast(output.device_ptr()) + i * input_shape[0] * input_shape[1], + static_cast(output.device_ptr()) + i * input_shape[0] * input_shape[1], + &static_cast(biases.device_ptr())[i], input_shape[0] * input_shape[1] ); CUDA_CHECK(cudaGetLastError()); } @@ -460,12 +460,12 @@ CUDANet::Tensor& CUDA::concat_impl( CUDANet::Tensor& output ) { CUDA_CHECK(cudaMemcpy( - static_cast(output.device_ptr())(), static_cast(input_a.device_ptr())(), input_a.size(), + static_cast(output.device_ptr()), static_cast(input_a.device_ptr()), input_a.size(), cudaMemcpyDeviceToDevice )); CUDA_CHECK(cudaMemcpy( - static_cast(output.device_ptr())() + input_a.numel(), static_cast(input_b.device_ptr())(), input_b.size(), + static_cast(output.device_ptr()) + input_a.numel(), static_cast(input_b.device_ptr()), input_b.size(), cudaMemcpyDeviceToDevice )); @@ -508,7 +508,7 @@ CUDANet::Tensor& CUDA::add_impl( auto gridSize = (input_a.numel() + BLOCK_SIZE - 1) / BLOCK_SIZE; Kernels::vec_vec_add<<>>( - static_cast(input_a.device_ptr())(), static_cast(input_b.device_ptr())(), static_cast(output.device_ptr())(), input_a.numel() + static_cast(input_a.device_ptr()), static_cast(input_b.device_ptr()), static_cast(output.device_ptr()), input_a.numel() ); CUDA_CHECK(cudaGetLastError()); CUDA_CHECK(cudaDeviceSynchronize()); diff --git a/src/backends/cuda/tensor_ops.cu b/src/backends/cuda/tensor_ops.cu index eeaf77e..901a52a 100644 --- a/src/backends/cuda/tensor_ops.cu +++ b/src/backends/cuda/tensor_ops.cu @@ -26,7 +26,7 @@ void CUDA::print_impl(const CUDANet::Tensor &input) { std::vector h_vec(input.numel()); CUDA_CHECK(cudaMemcpy( - h_vec.data(), static_cast(input.device_ptr())(), sizeof(T) * length, cudaMemcpyDeviceToHost + h_vec.data(), static_cast(input.device_ptr()), sizeof(T) * length, cudaMemcpyDeviceToHost )); for (int i = 0; i < length; ++i) { @@ -56,7 +56,7 @@ template void CUDA::fill_impl(CUDANet::Tensor &input, int value); template void CUDA::fill_impl(CUDANet::Tensor &input, int value) { - CUDA_CHECK(cudaMemset(static_cast(input.device_ptr())(), value, sizeof(T) * input.numel())); + CUDA_CHECK(cudaMemset(static_cast(input.device_ptr()), value, sizeof(T) * input.numel())); } void CUDA::copy_to_device(CUDANet::Tensor &tensor, void *data, size_t size) { @@ -75,7 +75,7 @@ template void CUDA::copy_to_device_impl(CUDANet::Tensor &tensor, void *da template void CUDA::copy_to_device_impl(CUDANet::Tensor &tensor, void *data, size_t size) { - CUDA_CHECK(cudaMemcpy(static_cast(tensor.device_ptr())(), data, size, cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(static_cast(tensor.device_ptr()), data, size, cudaMemcpyHostToDevice)); } void CUDA::sum(const CUDANet::Tensor &input, CUDANet::Tensor &sum) { @@ -98,14 +98,14 @@ void CUDA::sum_impl(const CUDANet::Tensor &input, CUDANet::Tensor &sum) { const int gridSize = (length + BLOCK_SIZE - 1) / BLOCK_SIZE; CUDANet::Kernels::sum_reduce<<>>( - static_cast(input.device_ptr())(), static_cast(sum.device_ptr())(), length + static_cast(input.device_ptr()), static_cast(sum.device_ptr()), length ); CUDA_CHECK(cudaGetLastError()); int remaining = gridSize; while (remaining > 1) { int blocks_needed = (remaining + BLOCK_SIZE - 1) / BLOCK_SIZE; - CUDANet::Kernels::sum_reduce<<>>(static_cast(sum.device_ptr())(), static_cast(sum.device_ptr())(), remaining); + CUDANet::Kernels::sum_reduce<<>>(static_cast(sum.device_ptr()), static_cast(sum.device_ptr()), remaining); CUDA_CHECK(cudaGetLastError()); remaining = blocks_needed; @@ -131,14 +131,14 @@ void CUDA::max_impl(const CUDANet::Tensor &input, CUDANet::Tensor &max) { auto length = input.numel(); const int grid_size = (length + BLOCK_SIZE - 1) / BLOCK_SIZE; - Kernels::max_reduce<<>>(static_cast(input.device_ptr())(), static_cast(max.device_ptr())(), length); + Kernels::max_reduce<<>>(static_cast(input.device_ptr()), static_cast(max.device_ptr()), length); CUDA_CHECK(cudaGetLastError()); int remaining = grid_size; while (remaining > 1) { int blocks_needed = (remaining + BLOCK_SIZE - 1) / BLOCK_SIZE; - CUDANet::Kernels::max_reduce<<>>(static_cast(max.device_ptr())(), static_cast(max.device_ptr())(), remaining); + CUDANet::Kernels::max_reduce<<>>(static_cast(max.device_ptr()), static_cast(max.device_ptr()), remaining); CUDA_CHECK(cudaGetLastError()); remaining = blocks_needed; diff --git a/src/layers/activation.cpp b/src/layers/activation.cpp index 8a4caa2..c447c0c 100644 --- a/src/layers/activation.cpp +++ b/src/layers/activation.cpp @@ -1,11 +1,11 @@ -#include "activation.hpp" - #include #include #include +#include "layers/activation.hpp" #include "tensor.hpp" + using namespace CUDANet::Layers; Activation::Activation( diff --git a/src/layers/add.cpp b/src/layers/add.cpp index be989bd..8f71660 100644 --- a/src/layers/add.cpp +++ b/src/layers/add.cpp @@ -1,4 +1,4 @@ -#include "add.hpp" +#include "layers/add.hpp" using namespace CUDANet::Layers; diff --git a/src/layers/avg_pooling.cpp b/src/layers/avg_pooling.cpp index 90719d4..0855abc 100644 --- a/src/layers/avg_pooling.cpp +++ b/src/layers/avg_pooling.cpp @@ -1,7 +1,7 @@ +#include #include -#include "avg_pool.hpp" -#include +#include "layers/avg_pool.hpp" using namespace CUDANet::Layers; diff --git a/src/layers/batch_norm.cpp b/src/layers/batch_norm.cpp index d73d40b..5351984 100644 --- a/src/layers/batch_norm.cpp +++ b/src/layers/batch_norm.cpp @@ -1,9 +1,7 @@ -#include "batch_norm.hpp" - #include #include -#include "activation.hpp" +#include "layers/batch_norm.hpp" #include "layer.hpp" using namespace CUDANet::Layers; diff --git a/src/layers/concat.cpp b/src/layers/concat.cpp index db1383b..899977e 100644 --- a/src/layers/concat.cpp +++ b/src/layers/concat.cpp @@ -1,4 +1,4 @@ -#include "concat.hpp" +#include "layers/concat.hpp" using namespace CUDANet::Layers; diff --git a/src/layers/conv2d.cpp b/src/layers/conv2d.cpp index 628b2b5..6712d2a 100644 --- a/src/layers/conv2d.cpp +++ b/src/layers/conv2d.cpp @@ -1,8 +1,7 @@ -#include "conv2d.hpp" - #include #include +#include "layers/conv2d.hpp" #include "layer.hpp" #include "tensor.hpp" diff --git a/src/layers/dense.cpp b/src/layers/dense.cpp index db160c9..855606c 100644 --- a/src/layers/dense.cpp +++ b/src/layers/dense.cpp @@ -1,8 +1,8 @@ -#include "dense.hpp" - #include #include +#include "layers/dense.hpp" + using namespace CUDANet::Layers; Dense::Dense(CUDANet::Shape in_shape, CUDANet::Shape out_shape, CUDANet::Backend* backend) diff --git a/src/layers/max_pool.cpp b/src/layers/max_pool.cpp index ddf2935..9bc4888 100644 --- a/src/layers/max_pool.cpp +++ b/src/layers/max_pool.cpp @@ -1,7 +1,7 @@ -#include "max_pool.hpp" - #include +#include "layers/max_pool.hpp" + using namespace CUDANet::Layers; MaxPool2d::MaxPool2d( diff --git a/src/model.cpp b/src/model.cpp index d04adc0..3c083d2 100644 --- a/src/model.cpp +++ b/src/model.cpp @@ -1,5 +1,3 @@ -#include "model.hpp" - #include #include #include @@ -8,7 +6,9 @@ #include #include "layer.hpp" -#include "batch_norm.hpp" +#include "layers/batch_norm.hpp" + +#include "model.hpp" using namespace CUDANet; diff --git a/src/module.cpp b/src/module.cpp index dda1b78..1d7ddd6 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -1,7 +1,7 @@ -#include "module.hpp" - #include +#include "module.hpp" + using namespace CUDANet; CUDANet::Shape Module::input_shape() { diff --git a/src/tensor.cpp b/src/tensor.cpp index 70f5d9d..fbfd2b4 100644 --- a/src/tensor.cpp +++ b/src/tensor.cpp @@ -1,7 +1,7 @@ -#include "tensor.hpp" - #include +#include "tensor.hpp" + using namespace CUDANet; Tensor::Tensor(Shape shape, CUDANet::Backend* backend) @@ -92,6 +92,10 @@ size_t Tensor::size() const { return total_size; } +void* Tensor::device_ptr() const { + return d_ptr; +} + void* Tensor::device_ptr() { return d_ptr; }