Remove cublas dependency

This commit is contained in:
2024-03-05 18:41:35 +01:00
parent 98ad84c659
commit f4257afd5a
16 changed files with 65 additions and 141 deletions

View File

@@ -12,6 +12,7 @@ set(LIBRARY_SOURCES
src/utils/cuda_helper.cu
src/kernels/activations.cu
src/kernels/padding.cu
src/kernels/matrix_math.cu
src/layers/dense.cu
src/layers/conv2d.cu
)
@@ -23,8 +24,7 @@ set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS} -arch=sm_75)
# Build static library
add_library(${PROJECT_NAME} STATIC ${LIBRARY_SOURCES})
# Link cuBLAS library to the library
target_link_libraries(${PROJECT_NAME} CUDA::cublas CUDA::cudart)
target_link_libraries(${PROJECT_NAME} CUDA::cudart)
# Set include directories for the library
target_include_directories(${PROJECT_NAME} PUBLIC

View File

@@ -1,5 +1,5 @@
# CUDANet
requirements:
- CUDA, cuBLAS
- CUDA
- Google Test

View File

@@ -1,8 +1,6 @@
#ifndef CONV_LAYER_H
#define CONV_LAYER_H
#include <cublas_v2.h>
#include <string>
#include <vector>
@@ -19,8 +17,7 @@ class Conv2d {
int stride,
std::string padding,
int numFilters,
Activation activation,
cublasHandle_t cublasHandle
Activation activation
);
~Conv2d();
@@ -44,7 +41,6 @@ class Conv2d {
std::vector<float> kernels;
// Cuda
cublasHandle_t cublasHandle;
float* d_kernels;
float* d_padded;

View File

@@ -1,8 +1,6 @@
#ifndef DENSE_LAYER_H
#define DENSE_LAYER_H
#include <cublas_v2.h>
#include <functional>
#include <string>
#include <vector>
@@ -16,8 +14,7 @@ class Dense : public ILayer {
Dense(
int inputSize,
int outputSize,
Activation activation,
cublasHandle_t cublasHandle
Activation activation
);
~Dense();
@@ -29,8 +26,6 @@ class Dense : public ILayer {
int inputSize;
int outputSize;
cublasHandle_t cublasHandle;
float* d_weights;
float* d_biases;

View File

@@ -2,8 +2,6 @@
#ifndef I_LAYER_H
#define I_LAYER_H
#include <cublas_v2.h>
#include <vector>
namespace Layers {

View File

@@ -2,9 +2,6 @@
#define CUDA_HELPER_H
#include <cuda_runtime.h>
#include <cublas_v2.h>
#define IDX2C(i,j,ld) (((j)*(ld))+(i))
// CUDA error checking macro
#define CUDA_CHECK(call) \
@@ -18,15 +15,4 @@ do { \
} \
} while (0)
// cuBLAS error checking macro
#define CUBLAS_CHECK(call) \
do { \
cublasStatus_t result = call; \
if (result != CUBLAS_STATUS_SUCCESS) { \
fprintf(stderr, "cuBLAS error at %s:%d code=%d\n", \
__FILE__, __LINE__, static_cast<unsigned int>(result)); \
exit(EXIT_FAILURE); \
} \
} while (0)
#endif // CUDA_HELPER_H

View File

@@ -14,6 +14,8 @@ __global__ void mat_vec_mul_kernel(
return;
}
d_output[tid] = 0.0f;
for (int i = 0; i < w; i++) {
d_output[tid] += d_matrix[tid * w + i] * d_vector[i];
}

View File

@@ -1,5 +1,3 @@
#include <cublas_v2.h>
#include <string>
#include "activations.cuh"
@@ -14,15 +12,13 @@ Layers::Conv2d::Conv2d(
int stride,
std::string padding,
int numFilters,
Activation activation,
cublasHandle_t cublasHandle
Activation activation
)
: inputSize(inputSize),
inputChannels(inputChannels),
kernelSize(kernelSize),
stride(stride),
numFilters(numFilters),
cublasHandle(cublasHandle),
activation(activation) {
// Allocate memory for kernels

View File

@@ -1,4 +1,3 @@
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <cstdio>
@@ -9,16 +8,15 @@
#include "activations.cuh"
#include "cuda_helper.cuh"
#include "dense.cuh"
#include "matrix_math.cuh"
Layers::Dense::Dense(
int inputSize,
int outputSize,
Activation activation,
cublasHandle_t cublasHandle
Activation activation
)
: inputSize(inputSize),
outputSize(outputSize),
cublasHandle(cublasHandle),
activation(activation) {
// Allocate memory for weights and biases
weights.resize(outputSize * inputSize);
@@ -54,35 +52,30 @@ void Layers::Dense::initializeBiases() {
}
void Layers::Dense::forward(const float* d_input, float* d_output) {
const float alpha = 1.0f;
const float beta = 0.0f;
CUBLAS_CHECK(cublasSgemv(
cublasHandle, CUBLAS_OP_N, outputSize, inputSize, &alpha, d_weights,
outputSize, d_input, 1, &beta, d_output, 1
));
CUBLAS_CHECK(
cublasSaxpy(cublasHandle, outputSize, &alpha, d_biases, 1, d_output, 1)
mat_vec_mul_kernel<<<1, outputSize>>>(
d_weights, d_input, d_output, inputSize, outputSize
);
int threadsPerBlock = 256;
int blocksPerGrid = (outputSize + threadsPerBlock - 1) / threadsPerBlock;
vec_vec_add_kernel<<<1, outputSize>>>(
d_biases, d_output, d_output, outputSize
);
switch (activation) {
case SIGMOID:
sigmoid_kernel<<<blocksPerGrid, threadsPerBlock>>>(
sigmoid_kernel<<<1, outputSize>>>(
d_output, d_output, outputSize
);
break;
case RELU:
relu_kernel<<<blocksPerGrid, threadsPerBlock>>>(
relu_kernel<<<1, outputSize>>>(
d_output, d_output, outputSize
);
break;
default:
linear_kernel<<<blocksPerGrid, threadsPerBlock>>>(
linear_kernel<<<1, outputSize>>>(
d_output, d_output, outputSize
);
break;
@@ -92,12 +85,13 @@ void Layers::Dense::forward(const float* d_input, float* d_output) {
}
void Layers::Dense::toCuda() {
CUBLAS_CHECK(cublasSetMatrix(
outputSize, inputSize, sizeof(float), weights.data(), outputSize,
d_weights, outputSize
CUDA_CHECK(cudaMemcpy(
d_weights, weights.data(), sizeof(float) * inputSize * outputSize,
cudaMemcpyHostToDevice
));
CUBLAS_CHECK(cublasSetVector(
biases.size(), sizeof(float), biases.data(), 1, d_biases, 1
CUDA_CHECK(cudaMemcpy(
d_biases, biases.data(), sizeof(float) * outputSize,
cudaMemcpyHostToDevice
));
}
@@ -111,10 +105,9 @@ void Layers::Dense::setWeights(
exit(EXIT_FAILURE);
}
for (int j = 0; j < inputSize; ++j) {
for (int i = 0; i < outputSize; ++i) {
int idx = IDX2C(i, j, outputSize);
weights[idx] = weights_input[i][j];
for (int j = 0; j < inputSize; ++j) {
weights[i * inputSize + j] = weights_input[i][j];
}
}

View File

@@ -1,4 +1,3 @@
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <cstdio>
@@ -6,7 +5,7 @@
#include "cuda_helper.cuh"
cudaDeviceProp initializeCUDA(cublasHandle_t& cublasHandle) {
cudaDeviceProp initializeCUDA() {
int deviceCount;
CUDA_CHECK(cudaGetDeviceCount(&deviceCount));
@@ -23,8 +22,5 @@ cudaDeviceProp initializeCUDA(cublasHandle_t& cublasHandle) {
std::printf("Using CUDA device %d: %s\n", device, deviceProp.name);
// Initialize cuBLAS
CUBLAS_CHECK(cublasCreate(&cublasHandle));
return deviceProp;
}

View File

@@ -7,12 +7,6 @@ add_executable(test_main
kernels/test_padding.cu
)
add_library(test_utils
test_utils/test_cublas_fixture.cu
)
target_include_directories(test_utils PUBLIC test_utils)
target_link_libraries(test_main ${GTEST_BOTH_LIBRARIES} CUDANet test_utils)
target_link_libraries(test_main ${GTEST_BOTH_LIBRARIES} CUDANet)
add_test(NAME TestMain COMMAND test_main)

View File

@@ -4,15 +4,11 @@
#include <iostream>
#include "activations.cuh"
#include "test_cublas_fixture.cuh"
class ActivationsTest : public CublasTestFixture {
protected:
TEST(ActivationsTest, SigmoidSanityCheck) {
cudaError_t cudaStatus;
cublasStatus_t cublasStatus;
};
TEST_F(ActivationsTest, SigmoidSanityCheck) {
float input[3] = {-100.0f, 0.0f, 100.0f};
std::vector<float> expected_output = {0.0f, 0.5f, 1.0f};
@@ -26,8 +22,8 @@ TEST_F(ActivationsTest, SigmoidSanityCheck) {
cudaStatus = cudaMalloc((void**)&d_output, sizeof(float) * 3);
EXPECT_EQ(cudaStatus, cudaSuccess);
cublasStatus = cublasSetVector(3, sizeof(float), input, 1, d_input, 1);
EXPECT_EQ(cublasStatus, CUBLAS_STATUS_SUCCESS);
cudaStatus = cudaMemcpy(d_input, input, sizeof(float) * 3, cudaMemcpyHostToDevice);
EXPECT_EQ(cudaStatus, cudaSuccess);
sigmoid_kernel<<<1, 3>>>(d_input, d_output, 3);
cudaStatus = cudaDeviceSynchronize();
@@ -35,9 +31,8 @@ TEST_F(ActivationsTest, SigmoidSanityCheck) {
std::vector<float> output(3);
cublasStatus =
cublasGetVector(3, sizeof(float), d_output, 1, output.data(), 1);
EXPECT_EQ(cublasStatus, CUBLAS_STATUS_SUCCESS);
cudaStatus = cudaMemcpy(output.data(), d_output, sizeof(float) * 3, cudaMemcpyDeviceToHost);
EXPECT_EQ(cudaStatus, cudaSuccess);
for (int i = 0; i < 3; i++) {
EXPECT_NEAR(expected_output[i], output[i], 1e-5);

View File

@@ -4,15 +4,10 @@
#include <iostream>
#include "padding.cuh"
#include "test_cublas_fixture.cuh"
class PaddingTest : public CublasTestFixture {
protected:
TEST(PaddingTest, SimplePaddingTest) {
cudaError_t cudaStatus;
cublasStatus_t cublasStatus;
};
TEST_F(PaddingTest, SimplePaddingTest) {
int w = 2;
int h = 3;
int n = 2;
@@ -48,9 +43,10 @@ TEST_F(PaddingTest, SimplePaddingTest) {
std::vector<float> input = {0.0f, 2.0f, 4.0f, 1.0f, 3.0f, 5.0f,
6.0f, 8.0f, 10.0f, 7.0f, 9.0f, 11.0f};
cublasStatus =
cublasSetVector(inputSize, sizeof(float), input.data(), 1, d_input, 1);
EXPECT_EQ(cublasStatus, CUBLAS_STATUS_SUCCESS);
cudaStatus = cudaMemcpy(
d_input, input.data(), sizeof(float) * inputSize, cudaMemcpyHostToDevice
);
EXPECT_EQ(cudaStatus, cudaSuccess);
int THREADS_PER_BLOCK = 64;
int BLOCKS = paddedSize / THREADS_PER_BLOCK + 1;
@@ -69,9 +65,12 @@ TEST_F(PaddingTest, SimplePaddingTest) {
};
std::vector<float> output(paddedSize);
cublasStatus = cublasGetVector(
paddedSize, sizeof(float), d_padded, 1, output.data(), 1
cudaStatus = cudaMemcpy(
output.data(), d_padded, sizeof(float) * paddedSize,
cudaMemcpyDeviceToHost
);
EXPECT_EQ(cudaStatus, cudaSuccess);
for (int i = 0; i < paddedSize; i++) {
EXPECT_NEAR(expectedOutput[i], output[i], 1e-5);

View File

@@ -5,9 +5,9 @@
#include "activations.cuh"
#include "dense.cuh"
#include "test_cublas_fixture.cuh"
class DenseLayerTest : public CublasTestFixture {
class DenseLayerTest : public::testing::Test {
protected:
Layers::Dense commonTestSetup(
int inputSize,
@@ -21,7 +21,7 @@ class DenseLayerTest : public CublasTestFixture {
) {
// Create Dense layer
Layers::Dense denseLayer(
inputSize, outputSize, activation, cublasHandle
inputSize, outputSize, activation
);
// Set weights and biases
@@ -36,10 +36,11 @@ class DenseLayerTest : public CublasTestFixture {
EXPECT_EQ(cudaStatus, cudaSuccess);
// Copy input to device
cublasStatus = cublasSetVector(
input.size(), sizeof(float), input.data(), 1, d_input, 1
cudaStatus = cudaMemcpy(
d_input, input.data(), sizeof(float) * input.size(), cudaMemcpyHostToDevice
);
EXPECT_EQ(cublasStatus, CUBLAS_STATUS_SUCCESS);
EXPECT_EQ(cudaStatus, cudaSuccess);
return denseLayer;
}
@@ -51,7 +52,6 @@ class DenseLayerTest : public CublasTestFixture {
}
cudaError_t cudaStatus;
cublasStatus_t cublasStatus;
};
TEST_F(DenseLayerTest, Init) {
@@ -60,10 +60,8 @@ TEST_F(DenseLayerTest, Init) {
int inputSize = i;
int outputSize = j;
// std::cout << "Dense layer: input size = " << inputSize << ",
// output size = " << outputSize << std::endl;
Layers::Dense denseLayer(
inputSize, outputSize, SIGMOID, cublasHandle
inputSize, outputSize, SIGMOID
);
}
}
@@ -81,7 +79,7 @@ TEST_F(DenseLayerTest, setWeights) {
{1.3f, 0.5f, 0.0f, 1.7f}
};
Layers::Dense denseLayer(inputSize, outputSize, SIGMOID, cublasHandle);
Layers::Dense denseLayer(inputSize, outputSize, SIGMOID);
denseLayer.setWeights(weights);
}
@@ -113,10 +111,10 @@ TEST_F(DenseLayerTest, ForwardUnitWeightMatrixLinear) {
denseLayer.forward(d_input, d_output);
std::vector<float> output(outputSize);
cublasStatus = cublasGetVector(
outputSize, sizeof(float), d_output, 1, output.data(), 1
cudaStatus = cudaMemcpy(
output.data(), d_output, sizeof(float) * outputSize, cudaMemcpyDeviceToHost
);
EXPECT_EQ(cublasStatus, CUBLAS_STATUS_SUCCESS);
EXPECT_EQ(cudaStatus, cudaSuccess);
// Check if the output is a zero vector
EXPECT_FLOAT_EQ(output[0], 2.0f);
@@ -150,10 +148,10 @@ TEST_F(DenseLayerTest, ForwardRandomWeightMatrixRelu) {
denseLayer.forward(d_input, d_output);
std::vector<float> output(outputSize);
cublasStatus = cublasGetVector(
outputSize, sizeof(float), d_output, 1, output.data(), 1
cudaStatus = cudaMemcpy(
output.data(), d_output, sizeof(float) * outputSize, cudaMemcpyDeviceToHost
);
EXPECT_EQ(cublasStatus, CUBLAS_STATUS_SUCCESS);
EXPECT_EQ(cudaStatus, cudaSuccess);
// weights * inputs = 0.1, 12.5, 8.3, -2.2
// + biases = 0.3, 13, 9, -3.3
@@ -193,10 +191,10 @@ TEST_F(DenseLayerTest, ForwardRandomWeightMatrixSigmoid) {
denseLayer.forward(d_input, d_output);
std::vector<float> output(outputSize);
cublasStatus = cublasGetVector(
outputSize, sizeof(float), d_output, 1, output.data(), 1
cudaStatus = cudaMemcpy(
output.data(), d_output, sizeof(float) * outputSize, cudaMemcpyDeviceToHost
);
EXPECT_EQ(cublasStatus, CUBLAS_STATUS_SUCCESS);
EXPECT_EQ(cudaStatus, cudaSuccess);
// weights * input = 0.95, 0.43, 0.45, 0.93
// + biases = 1.05, 0.63, 0.75, 1.33

View File

@@ -1,14 +0,0 @@
#include <cublas_v2.h>
#include <gtest/gtest.h>
#include "test_cublas_fixture.cuh"
cublasHandle_t CublasTestFixture::cublasHandle;
void CublasTestFixture::SetUpTestSuite() {
cublasCreate(&cublasHandle);
}
void CublasTestFixture::TearDownTestSuite() {
cublasDestroy(cublasHandle);
}

View File

@@ -1,10 +0,0 @@
#include <cublas_v2.h>
#include <gtest/gtest.h>
class CublasTestFixture : public ::testing::Test {
protected:
static cublasHandle_t cublasHandle;
static void SetUpTestSuite();
static void TearDownTestSuite();
};