https://github.com/lordmathis/CUDANet.git

Fix compilation errors and warnings
@@ -43,6 +43,11 @@ set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
 add_library(${PROJECT_NAME} STATIC ${LIBRARY_SOURCES})
 
 if(USE_CUDA)
+    # Enable relocatable device code for proper template instantiation across translation units
+    set_target_properties(${PROJECT_NAME} PROPERTIES
+        CUDA_SEPARABLE_COMPILATION ON
+        CUDA_RUNTIME_LIBRARY Shared
+    )
     target_link_libraries(${PROJECT_NAME} CUDA::cudart)
 endif()
 
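Background on the CUDA_SEPARABLE_COMPILATION property added above: relocatable device code (nvcc's -rdc=true) is what lets a __device__ symbol defined in one translation unit be referenced from another; without it, every kernel must see the definitions it uses in its own .cu file. A minimal sketch of the situation it enables (hypothetical file names, not from this repository):

    // scale.cu -- defines a device function in its own translation unit
    __device__ float scale(float x) { return 2.0f * x; }

    // apply.cu -- references scale() across translation units; this links
    // only when the target is built with CUDA_SEPARABLE_COMPILATION ON
    __device__ float scale(float x);

    __global__ void apply_scale(float* data, int n) {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) data[i] = scale(data[i]);
    }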
@@ -8,9 +8,10 @@
 
 namespace CUDANet {
 
-// Forward declaration
-class Tensor;
+// Forward declarations
+class Backend;
+class Tensor;
+enum class DType;
 
 enum BackendType { CUDA_BACKEND, CPU_BACKEND };
 
@@ -29,6 +29,8 @@ size_t dtype_size(DType dtype) {
 }
 }
 
+class Backend;
+
 class Tensor
 {
 public:
@@ -49,6 +51,7 @@ public:
     size_t size() const;
     size_t numel() const;
 
+    void* device_ptr() const;
    void* device_ptr();
 
     void zero();
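The added const overload of device_ptr() matters because backend code frequently receives tensors as const CUDANet::Tensor& (as in print_impl and sum_impl further down); a const object cannot call a non-const member function, so having only the non-const accessor fails to compile at those call sites. A minimal stand-alone illustration (simplified stand-in type, not the real Tensor class):

    struct TensorLike {
        void* d_ptr = nullptr;
        void* device_ptr() { return d_ptr; }        // existing non-const accessor
        void* device_ptr() const { return d_ptr; }  // added const overload
    };

    void read_only_use(const TensorLike& t) {
        // Without the const overload, the compiler rejects this call
        // because it would discard the const qualifier on `t`.
        void* p = t.device_ptr();
        (void)p;
    }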
@@ -13,6 +13,7 @@ std::unique_ptr<Backend> BackendFactory::create(BackendType backend_type, const
     switch (backend_type)
     {
     case BackendType::CUDA_BACKEND:
     {
 #ifdef USE_CUDA
+
         if (!CUDANet::Backends::CUDA::is_cuda_available()) {
@@ -20,14 +21,12 @@ std::unique_ptr<Backend> BackendFactory::create(BackendType backend_type, const
         }
 
         auto cuda = std::make_unique<CUDANet::Backends::CUDA>(config);
-        cuda.initialize();
-
         return cuda;
 
 #else
         throw std::runtime_error("Library was compiled without CUDA support.");
 #endif
 
     }
     break;
 
     default:
@@ -213,7 +213,7 @@ CUDANet::Tensor& CUDA::conv2d_impl(
     );
 
     Kernels::convolution<<<grid, block>>>(
-        static_cast<T*>(input.device_ptr())(), static_cast<T*>(weights.device_ptr())(), static_cast<T*>(biases.device_ptr())(), static_cast<T*>(output.device_ptr())(),
+        static_cast<T*>(input.device_ptr()), static_cast<T*>(weights.device_ptr()), static_cast<T*>(biases.device_ptr()), static_cast<T*>(output.device_ptr()),
         in_shape, padding_shape, kernel_shape, stride_shape, out_shape
     );
     CUDA_CHECK(cudaGetLastError());
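The kernel-launch fixes in this and the following hunks all remove the same stray call operator: static_cast<T*>(x.device_ptr())() tries to invoke the resulting T* as if it were callable, which is ill-formed, while static_cast<T*>(x.device_ptr()) already yields the typed device pointer the kernels expect. A minimal illustration of the difference (hypothetical buffer, not repository code):

    #include <cstdlib>

    int main() {
        void* raw = std::malloc(4 * sizeof(float));  // stands in for Tensor::device_ptr()
        float* ok = static_cast<float*>(raw);        // fine: the cast expression is the pointer
        // float* bad = static_cast<float*>(raw)();  // error: a float* is not callable
        (void)ok;
        std::free(raw);
        return 0;
    }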
@@ -273,7 +273,7 @@ CUDANet::Tensor& CUDA::max_pool2d_impl(
     );
 
     Kernels::max_pool<<<grid, block>>>(
-        static_cast<T*>(input.device_ptr())(), static_cast<T*>(output.device_ptr())(), input_shape, output_shape,
+        static_cast<T*>(input.device_ptr()), static_cast<T*>(output.device_ptr()), input_shape, output_shape,
         pool_shape, stride_shape, padding_shape
     );
     CUDA_CHECK(cudaGetLastError());
@@ -333,7 +333,7 @@ CUDANet::Tensor& CUDA::avg_pool2d_impl(
     );
 
     Kernels::avg_pool<<<grid, block>>>(
-        static_cast<T*>(input.device_ptr())(), static_cast<T*>(output.device_ptr())(), input_shape, output_shape,
+        static_cast<T*>(input.device_ptr()), static_cast<T*>(output.device_ptr()), input_shape, output_shape,
         pool_shape, stride_shape, padding_shape
     );
     CUDA_CHECK(cudaGetLastError());
@@ -394,34 +394,34 @@ CUDANet::Tensor& CUDA::batch_norm_impl(
     for (int i = 0; i < input_shape[2]; i++) {
         // Subtract mean from input
         Kernels::vec_scalar_sub<<<gridSize, BLOCK_SIZE>>>(
-            static_cast<T*>(input.device_ptr())() + i * input_shape[0] * input_shape[1],
-            static_cast<T*>(output.device_ptr())() + i * input_shape[0] * input_shape[1],
-            &static_cast<T*>(running_mean.device_ptr())()[i], input_shape[0] * input_shape[1]
+            static_cast<T*>(input.device_ptr()) + i * input_shape[0] * input_shape[1],
+            static_cast<T*>(output.device_ptr()) + i * input_shape[0] * input_shape[1],
+            &static_cast<T*>(running_mean.device_ptr())[i], input_shape[0] * input_shape[1]
         );
         CUDA_CHECK(cudaGetLastError());
 
         // Divide by sqrt(running_var + epsilon)
         Kernels::vec_scale<<<gridSize, BLOCK_SIZE>>>(
-            static_cast<T*>(output.device_ptr())() + i * input_shape[0] * input_shape[1],
-            static_cast<T*>(output.device_ptr())() + i * input_shape[0] * input_shape[1],
-            &static_cast<T*>(running_var.device_ptr())()[i], static_cast<T*>(epsilon.device_ptr())(),
+            static_cast<T*>(output.device_ptr()) + i * input_shape[0] * input_shape[1],
+            static_cast<T*>(output.device_ptr()) + i * input_shape[0] * input_shape[1],
+            &static_cast<T*>(running_var.device_ptr())[i], static_cast<T*>(epsilon.device_ptr()),
             input_shape[0] * input_shape[1]
         );
         CUDA_CHECK(cudaGetLastError());
 
         // Multiply by weights
         Kernels::vec_scalar_mul<<<gridSize, BLOCK_SIZE>>>(
-            static_cast<T*>(output.device_ptr())() + i * input_shape[0] * input_shape[1],
-            static_cast<T*>(output.device_ptr())() + i * input_shape[0] * input_shape[1],
-            &static_cast<T*>(weights.device_ptr())()[i], input_shape[0] * input_shape[1]
+            static_cast<T*>(output.device_ptr()) + i * input_shape[0] * input_shape[1],
+            static_cast<T*>(output.device_ptr()) + i * input_shape[0] * input_shape[1],
+            &static_cast<T*>(weights.device_ptr())[i], input_shape[0] * input_shape[1]
        );
         CUDA_CHECK(cudaGetLastError());
 
         // Add biases
         Kernels::vec_scalar_add<<<gridSize, BLOCK_SIZE>>>(
-            static_cast<T*>(output.device_ptr())() + i * input_shape[0] * input_shape[1],
-            static_cast<T*>(output.device_ptr())() + i * input_shape[0] * input_shape[1],
-            &static_cast<T*>(biases.device_ptr())()[i], input_shape[0] * input_shape[1]
+            static_cast<T*>(output.device_ptr()) + i * input_shape[0] * input_shape[1],
+            static_cast<T*>(output.device_ptr()) + i * input_shape[0] * input_shape[1],
+            &static_cast<T*>(biases.device_ptr())[i], input_shape[0] * input_shape[1]
         );
         CUDA_CHECK(cudaGetLastError());
     }
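Taken together, the four per-channel launches above apply the usual batch-norm normalization with running statistics for channel i, y = weights[i] * (x - running_mean[i]) / sqrt(running_var[i] + epsilon) + biases[i], over the input_shape[0] * input_shape[1] elements of that channel; only the pointer expressions change in this hunk, not the math.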
@@ -460,12 +460,12 @@ CUDANet::Tensor& CUDA::concat_impl(
     CUDANet::Tensor& output
 ) {
     CUDA_CHECK(cudaMemcpy(
-        static_cast<T*>(output.device_ptr())(), static_cast<T*>(input_a.device_ptr())(), input_a.size(),
+        static_cast<T*>(output.device_ptr()), static_cast<T*>(input_a.device_ptr()), input_a.size(),
         cudaMemcpyDeviceToDevice
     ));
 
     CUDA_CHECK(cudaMemcpy(
-        static_cast<T*>(output.device_ptr())() + input_a.numel(), static_cast<T*>(input_b.device_ptr())(), input_b.size(),
+        static_cast<T*>(output.device_ptr()) + input_a.numel(), static_cast<T*>(input_b.device_ptr()), input_b.size(),
         cudaMemcpyDeviceToDevice
     ));
 
@@ -508,7 +508,7 @@ CUDANet::Tensor& CUDA::add_impl(
     auto gridSize = (input_a.numel() + BLOCK_SIZE - 1) / BLOCK_SIZE;
 
     Kernels::vec_vec_add<<<gridSize, BLOCK_SIZE>>>(
-        static_cast<T*>(input_a.device_ptr())(), static_cast<T*>(input_b.device_ptr())(), static_cast<T*>(output.device_ptr())(), input_a.numel()
+        static_cast<T*>(input_a.device_ptr()), static_cast<T*>(input_b.device_ptr()), static_cast<T*>(output.device_ptr()), input_a.numel()
     );
     CUDA_CHECK(cudaGetLastError());
     CUDA_CHECK(cudaDeviceSynchronize());
 
@@ -26,7 +26,7 @@ void CUDA::print_impl(const CUDANet::Tensor &input) {
     std::vector<T> h_vec(input.numel());
 
     CUDA_CHECK(cudaMemcpy(
-        h_vec.data(), static_cast<T*>(input.device_ptr())(), sizeof(T) * length, cudaMemcpyDeviceToHost
+        h_vec.data(), static_cast<T*>(input.device_ptr()), sizeof(T) * length, cudaMemcpyDeviceToHost
     ));
 
     for (int i = 0; i < length; ++i) {
@@ -56,7 +56,7 @@ template void CUDA::fill_impl<float>(CUDANet::Tensor &input, int value);
 
 template <typename T>
 void CUDA::fill_impl(CUDANet::Tensor &input, int value) {
-    CUDA_CHECK(cudaMemset(static_cast<T*>(input.device_ptr())(), value, sizeof(T) * input.numel()));
+    CUDA_CHECK(cudaMemset(static_cast<T*>(input.device_ptr()), value, sizeof(T) * input.numel()));
 }
 
 void CUDA::copy_to_device(CUDANet::Tensor &tensor, void *data, size_t size) {
@@ -75,7 +75,7 @@ template void CUDA::copy_to_device_impl<float>(CUDANet::Tensor &tensor, void *da
 
 template <typename T>
 void CUDA::copy_to_device_impl(CUDANet::Tensor &tensor, void *data, size_t size) {
-    CUDA_CHECK(cudaMemcpy(static_cast<T*>(tensor.device_ptr())(), data, size, cudaMemcpyHostToDevice));
+    CUDA_CHECK(cudaMemcpy(static_cast<T*>(tensor.device_ptr()), data, size, cudaMemcpyHostToDevice));
 }
 
 void CUDA::sum(const CUDANet::Tensor &input, CUDANet::Tensor &sum) {
@@ -98,14 +98,14 @@ void CUDA::sum_impl(const CUDANet::Tensor &input, CUDANet::Tensor &sum) {
     const int gridSize = (length + BLOCK_SIZE - 1) / BLOCK_SIZE;
 
     CUDANet::Kernels::sum_reduce<<<gridSize, BLOCK_SIZE>>>(
-        static_cast<T*>(input.device_ptr())(), static_cast<T*>(sum.device_ptr())(), length
+        static_cast<T*>(input.device_ptr()), static_cast<T*>(sum.device_ptr()), length
     );
     CUDA_CHECK(cudaGetLastError());
 
     int remaining = gridSize;
     while (remaining > 1) {
         int blocks_needed = (remaining + BLOCK_SIZE - 1) / BLOCK_SIZE;
-        CUDANet::Kernels::sum_reduce<<<blocks_needed, BLOCK_SIZE>>>(static_cast<T*>(sum.device_ptr())(), static_cast<T*>(sum.device_ptr())(), remaining);
+        CUDANet::Kernels::sum_reduce<<<blocks_needed, BLOCK_SIZE>>>(static_cast<T*>(sum.device_ptr()), static_cast<T*>(sum.device_ptr()), remaining);
         CUDA_CHECK(cudaGetLastError());
 
         remaining = blocks_needed;
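sum_impl's host loop is a standard multi-pass reduction: the first launch collapses the length input values into gridSize per-block partial sums, and later launches keep folding those partials until one value remains. A self-contained sketch of the same pattern, using a generic shared-memory block sum rather than CUDANet's Kernels::sum_reduce, and ping-ponging between two scratch buffers instead of reducing in place:

    #include <utility>

    #define BLOCK 256

    __global__ void block_sum(const float* in, float* out, int n) {
        __shared__ float buf[BLOCK];
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        buf[threadIdx.x] = (i < n) ? in[i] : 0.0f;       // pad the last block with zeros
        __syncthreads();
        for (int s = blockDim.x / 2; s > 0; s >>= 1) {   // tree reduction in shared memory
            if (threadIdx.x < s) buf[threadIdx.x] += buf[threadIdx.x + s];
            __syncthreads();
        }
        if (threadIdx.x == 0) out[blockIdx.x] = buf[0];  // one partial sum per block
    }

    // Returns the scratch buffer (one of d_a / d_b) whose element 0 holds the
    // total of d_in[0..n-1].
    float* sum_on_device(const float* d_in, float* d_a, float* d_b, int n) {
        int blocks = (n + BLOCK - 1) / BLOCK;
        block_sum<<<blocks, BLOCK>>>(d_in, d_a, n);      // n values -> `blocks` partials
        int remaining = blocks;
        while (remaining > 1) {                          // keep folding the partials
            blocks = (remaining + BLOCK - 1) / BLOCK;
            block_sum<<<blocks, BLOCK>>>(d_a, d_b, remaining);
            std::swap(d_a, d_b);                         // ping-pong to avoid in-place races
            remaining = blocks;
        }
        return d_a;
    }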
@@ -131,14 +131,14 @@ void CUDA::max_impl(const CUDANet::Tensor &input, CUDANet::Tensor &max) {
     auto length = input.numel();
     const int grid_size = (length + BLOCK_SIZE - 1) / BLOCK_SIZE;
 
-    Kernels::max_reduce<<<grid_size, BLOCK_SIZE>>>(static_cast<T*>(input.device_ptr())(), static_cast<T*>(max.device_ptr())(), length);
+    Kernels::max_reduce<<<grid_size, BLOCK_SIZE>>>(static_cast<T*>(input.device_ptr()), static_cast<T*>(max.device_ptr()), length);
     CUDA_CHECK(cudaGetLastError());
 
     int remaining = grid_size;
 
     while (remaining > 1) {
         int blocks_needed = (remaining + BLOCK_SIZE - 1) / BLOCK_SIZE;
-        CUDANet::Kernels::max_reduce<<<blocks_needed, BLOCK_SIZE>>>(static_cast<T*>(max.device_ptr())(), static_cast<T*>(max.device_ptr())(), remaining);
+        CUDANet::Kernels::max_reduce<<<blocks_needed, BLOCK_SIZE>>>(static_cast<T*>(max.device_ptr()), static_cast<T*>(max.device_ptr()), remaining);
         CUDA_CHECK(cudaGetLastError());
 
         remaining = blocks_needed;
 
@@ -1,11 +1,11 @@
-#include "activation.hpp"
-
 #include <format>
 #include <stdexcept>
 #include <vector>
 
+#include "layers/activation.hpp"
+#include "tensor.hpp"
 
 
 using namespace CUDANet::Layers;
 
 Activation::Activation(
@@ -1,4 +1,4 @@
-#include "add.hpp"
+#include "layers/add.hpp"
 
 using namespace CUDANet::Layers;
 
@@ -1,7 +1,7 @@
+#include <format>
 #include <stdexcept>
 
-#include "avg_pool.hpp"
-#include <format>
+#include "layers/avg_pool.hpp"
 
 using namespace CUDANet::Layers;
 
@@ -1,9 +1,7 @@
-#include "batch_norm.hpp"
-
 #include <stdexcept>
 #include <vector>
 
-#include "activation.hpp"
+#include "layers/batch_norm.hpp"
 #include "layer.hpp"
 
 using namespace CUDANet::Layers;
 
@@ -1,4 +1,4 @@
-#include "concat.hpp"
+#include "layers/concat.hpp"
 
 using namespace CUDANet::Layers;
 
@@ -1,8 +1,7 @@
-#include "conv2d.hpp"
-
 #include <format>
 #include <stdexcept>
 
+#include "layers/conv2d.hpp"
 #include "layer.hpp"
 #include "tensor.hpp"
 
@@ -1,8 +1,8 @@
-#include "dense.hpp"
-
 #include <format>
 #include <stdexcept>
 
+#include "layers/dense.hpp"
+
 using namespace CUDANet::Layers;
 
 Dense::Dense(CUDANet::Shape in_shape, CUDANet::Shape out_shape, CUDANet::Backend* backend)
@@ -1,7 +1,7 @@
-#include "max_pool.hpp"
-
 #include <stdexcept>
 
+#include "layers/max_pool.hpp"
+
 using namespace CUDANet::Layers;
 
 MaxPool2d::MaxPool2d(
@@ -1,5 +1,3 @@
-#include "model.hpp"
-
 #include <fstream>
 #include <iostream>
 #include <iomanip>
@@ -8,7 +6,9 @@
 #include <vector>
 
 #include "layer.hpp"
-#include "batch_norm.hpp"
+#include "layers/batch_norm.hpp"
+
+#include "model.hpp"
 
 using namespace CUDANet;
 
@@ -1,7 +1,7 @@
-#include "module.hpp"
-
 #include <algorithm>
 
+#include "module.hpp"
+
 using namespace CUDANet;
 
 CUDANet::Shape Module::input_shape() {
@@ -1,7 +1,7 @@
-#include "tensor.hpp"
-
 #include <stdexcept>
 
+#include "tensor.hpp"
+
 using namespace CUDANet;
 
 Tensor::Tensor(Shape shape, CUDANet::Backend* backend)
@@ -92,6 +92,10 @@ size_t Tensor::size() const {
     return total_size;
 }
 
+void* Tensor::device_ptr() const {
+    return d_ptr;
+}
+
 void* Tensor::device_ptr() {
     return d_ptr;
 }