Cleanup and refactor

2024-04-11 22:52:41 +02:00
parent 4b9d123e94
commit 18522c2dea
9 changed files with 81 additions and 72 deletions

View File

@@ -28,7 +28,7 @@ class Activation {
* @param activation Type of activation
* @param length Length of the input
*/
Activation(ActivationType activation, const unsigned int length);
Activation(ActivationType activation, const int length);
/**
* @brief Destroy the Activation object
@@ -46,8 +46,8 @@ class Activation {
private:
ActivationType activationType;
unsigned int length;
unsigned int gridSize;
int length;
int gridSize;
float* d_softmax_sum;
float* d_max;
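For orientation, a minimal usage sketch of the class declared above, assuming SOFTMAX is an ActivationType enumerator reachable through CUDANet::Layers and that activate() operates in place on a device buffer (both consistent with the implementation diff below); the buffer length and setup here are illustrative only:

// Sketch: run softmax in place on a device vector of n floats.
const int n = 1024;                                   // illustrative length
float* d_vec = nullptr;
cudaMalloc((void**)&d_vec, sizeof(float) * n);
// ... copy or compute input values into d_vec ...
CUDANet::Layers::Activation softmax(CUDANet::Layers::SOFTMAX, n);
softmax.activate(d_vec);                              // d_vec now holds values summing to 1
cudaFree(d_vec);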

View File

@@ -1,23 +1,22 @@
#include "activation.cuh"
#include "cuda_helper.cuh"
#include "activation_functions.cuh"
#include "matmul.cuh"
#include <iostream>
#include <vector>
#include "activation.cuh"
#include "activation_functions.cuh"
#include "cuda_helper.cuh"
#include "matmul.cuh"
#include "vector.cuh"
using namespace CUDANet::Layers;
Activation::Activation(ActivationType activation, const unsigned int length)
Activation::Activation(ActivationType activation, const int length)
: activationType(activation), length(length) {
if (activationType == SOFTMAX) {
d_softmax_sum = nullptr;
CUDA_CHECK(cudaMalloc((void**)&d_softmax_sum, sizeof(float) * length));
d_max = nullptr;
CUDA_CHECK(cudaMalloc((void**)&d_max, sizeof(float) * length));
d_softmax_sum = nullptr;
CUDA_CHECK(cudaMalloc((void**)&d_softmax_sum, sizeof(float) * length));
}
gridSize = (length + BLOCK_SIZE - 1) / BLOCK_SIZE;
@@ -26,10 +25,13 @@ Activation::Activation(ActivationType activation, const unsigned int length)
Activation::~Activation() {
if (activationType == SOFTMAX) {
cudaFree(d_softmax_sum);
cudaFree(d_max);
}
}
void Activation::activate(float* __restrict__ d_input) {
void Activation::activate(float* d_input) {
// float sum = 0.0f;
switch (activationType) {
case SIGMOID:
@@ -39,44 +41,36 @@ void Activation::activate(float* __restrict__ d_input) {
break;
case RELU:
Kernels::relu<<<gridSize, BLOCK_SIZE>>>(
d_input, d_input, length
);
Kernels::relu<<<gridSize, BLOCK_SIZE>>>(d_input, d_input, length);
break;
case SOFTMAX:
// Find max value
Kernels::max_reduce<<<gridSize, BLOCK_SIZE>>>(
d_input, d_max
);
Kernels::max_reduce<<<1, BLOCK_SIZE>>>(
d_max, d_max
);
Utils::max(d_input, d_max, length);
// Subtract max value to improve numerical stability
Kernels::vec_scalar_sub<<<gridSize, BLOCK_SIZE>>>(
d_input, d_max, d_input, length
d_input, d_input, d_max, length
);
// Compute softmax
Kernels::softmax_exp<<<gridSize, BLOCK_SIZE>>>(
// Compute exponentials
Kernels::vec_exp<<<gridSize, BLOCK_SIZE>>>(
d_input, d_input, length
);
Kernels::softmax_sum<<<gridSize, BLOCK_SIZE>>>(
d_input, d_softmax_sum
);
// Find sum
Utils::sum(d_input, d_softmax_sum, length);
Kernels::softmax_sum<<<1, BLOCK_SIZE>>>(
d_softmax_sum, d_softmax_sum
);
Kernels::softmax_div<<<gridSize, BLOCK_SIZE>>>(
Kernels::vec_scalar_div<<<gridSize, BLOCK_SIZE>>>(
d_input, d_input, d_softmax_sum, length
);
break;
default:
break;
}
cudaDeviceSynchronize();
}
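The kernel sequence above (max reduce, subtract the max, exponentiate, sum reduce, divide) is the standard numerically stable softmax, softmax(x)_i = exp(x_i - max(x)) / sum_j exp(x_j - max(x)). A host-side reference of the same arithmetic, useful for checking device output; this helper is illustrative and not part of the commit:

#include <algorithm>
#include <cmath>
#include <vector>

// CPU reference for the device pipeline above.
std::vector<float> softmax_reference(const std::vector<float>& x) {
    const float m = *std::max_element(x.begin(), x.end());  // subtract the max for stability
    std::vector<float> out(x.size());
    float sum = 0.0f;
    for (size_t i = 0; i < x.size(); ++i) {
        out[i] = std::exp(x[i] - m);  // shifted exponentials
        sum += out[i];
    }
    for (float& v : out) {
        v /= sum;  // normalize so the entries sum to 1
    }
    return out;
}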

View File

@@ -23,8 +23,6 @@ Dense::Dense(
weights.resize(outputSize * inputSize);
biases.resize(outputSize);
activation = Activation(activationType, outputSize);
initializeWeights();
initializeBiases();
@@ -46,6 +44,8 @@ Dense::Dense(
forwardGridSize =
(std::max(inputSize, outputSize) + BLOCK_SIZE - 1) / BLOCK_SIZE;
biasGridSize = (outputSize + BLOCK_SIZE - 1) / BLOCK_SIZE;
activation = Activation(activationType, outputSize);
}
Dense::~Dense() {

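The grid sizes in this constructor use the usual ceil-division idiom, (count + BLOCK_SIZE - 1) / BLOCK_SIZE, so the launch covers every element. A worked sketch, assuming a BLOCK_SIZE of 256 purely for illustration (the real value is defined elsewhere in the repository):

constexpr int BLOCK_SIZE = 256;                        // assumed for illustration
int count    = 1000;
int gridSize = (count + BLOCK_SIZE - 1) / BLOCK_SIZE;  // (1000 + 255) / 256 = 4 blocks
// kernel<<<gridSize, BLOCK_SIZE>>>(...);              // threads past `count` must bounds-check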
View File

@@ -50,7 +50,6 @@ void Utils::sum(float* d_vec, float* d_sum, const unsigned int length) {
int remaining = gridSize;
while (remaining > 1) {
std::cout << remaining << std::endl;
int blocks_needed = (remaining + BLOCK_SIZE - 1) / BLOCK_SIZE;
CUDANet::Kernels::sum_reduce<<<blocks_needed, BLOCK_SIZE>>>(d_sum, d_sum, remaining);
remaining = blocks_needed;
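The loop that the deleted debug print sat in keeps relaunching sum_reduce until a single partial sum is left in d_sum. As a worked trace, assuming a BLOCK_SIZE of 256: a length of 70000 gives gridSize = (70000 + 255) / 256 = 274 partial sums; the first pass launches (274 + 255) / 256 = 2 blocks, leaving 2 partials; the second pass launches 1 block, leaving 1, and the loop exits.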

View File

@@ -47,21 +47,3 @@ TEST(ActivationFunctionsTest, SigmoidSanityCheck) {
cudaDeviceReset();
}
// void print_vec(float* d_vec, int length) {
// std::vector<float> h_vec(length);
// CUDA_CHECK(cudaMemcpy(
// h_vec.data(), d_vec, sizeof(float) * length, cudaMemcpyDeviceToHost
// ));
// float sum = 0.0f;
// for (int i = 0; i < length; ++i) {
// std::cout << h_vec[i] << ", ";
// sum += h_vec[i];
// }
// std::cout << std::endl;
// }

View File

@@ -61,7 +61,7 @@ TEST(MatMulTest, MatVecMulTest) {
for (int j = 0; j < w; j++) {
sum += matrix[i * w + j] * vector[j];
}
EXPECT_NEAR(sum, output_gpu[i], 1e-5);
EXPECT_NEAR(sum, output_gpu[i], 1e-5f);
}
cudaFree(d_matrix);
@@ -151,7 +151,7 @@ TEST(MatMulTest, VecExpTest) {
EXPECT_EQ(cudaStatus, cudaSuccess);
for (int i = 0; i < 6; i++) {
EXPECT_NEAR(expected[i], output[i], 1e7);
EXPECT_NEAR(expected[i], output[i], 1e7f);
}
cudaFree(d_input);
@@ -193,7 +193,6 @@ TEST(MatMulTest, SumReduceTest) {
int remaining = gridSize;
while (remaining > 1) {
std::cout << remaining << std::endl;
int blocks_needed = (remaining + BLOCK_SIZE - 1) / BLOCK_SIZE;
CUDANet::Kernels::sum_reduce<<<blocks_needed, BLOCK_SIZE>>>(d_sum, d_sum, remaining);
remaining = blocks_needed;

View File

@@ -242,20 +242,13 @@ TEST_F(DenseLayerTest, ForwardRandomWeightMatrixSoftmax) {
EXPECT_EQ(cudaStatus, cudaSuccess);
std::vector<float> expected = {0.17124f, 0.28516f, 0.22208f, 0.32152f};
// std::vector<float> expected = {0.46f, 0.97f, 0.72f, 1.09f};
float sum = 0.0f;
for (int i = 0; i < outputSize; ++i) {
std::cout << output[i] << ", ";
}
std::cout << std::endl;
for (int i = 0; i < outputSize; ++i) {
sum += output[i];
EXPECT_NEAR(output[i], expected[i], 1e-5);
EXPECT_NEAR(output[i], expected[i], 1e-5f);
}
std::cout << std::endl;
EXPECT_NEAR(sum, 1.0f, 1e-5f);
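As a sanity check, the expected values sum to exactly one: 0.17124 + 0.28516 + 0.22208 + 0.32152 = 1.00000, so the EXPECT_NEAR(sum, 1.0f, 1e-5f) assertion is consistent with the per-element expectations.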

View File

@@ -105,11 +105,9 @@ TEST(Model, TestModelPredict) {
// float sum = 0.0f;
for (int i = 0; i < outputSize; ++i) {
sum += output[i];
std::cout << output[i] << " ";
}
std::cout << std::endl;
EXPECT_NEAR(sum, 1.0f, 1e-2f);
EXPECT_NEAR(sum, 1.0f, 1e-5f);
cudaDeviceReset();
}

tools/dense_test.py (new file, 44 lines added)
View File

@@ -0,0 +1,44 @@
import torch
from utils import print_cpp_vector


def gen_dense_softmax_test():
    input = torch.tensor([
        0.1, 0.2, 0.3, 0.4, 0.5
    ])

    weights = torch.tensor([
        0.5, 0.1, 0.1, 0.4, 0.2,
        0.4, 0.3, 0.9, 0.0, 0.8,
        0.8, 0.4, 0.6, 0.2, 0.0,
        0.1, 0.7, 0.3, 1.0, 0.1
    ]).reshape(4, 5)

    biases = torch.tensor([
        0.1, 0.2, 0.3, 0.4
    ])

    dense = torch.nn.Linear(5, 4)
    dense.weight = torch.nn.Parameter(weights)
    dense.bias = torch.nn.Parameter(biases)

    output = dense(input)
    print_cpp_vector(output)

    # Manual softmax
    softmax_exp = torch.exp(output)
    print(softmax_exp)

    softmax_sum = torch.sum(softmax_exp, dim=0)
    print(softmax_sum)

    softmax_out = softmax_exp / softmax_sum
    print(softmax_out)

    softmax = torch.nn.Softmax(dim=0)(output)
    print_cpp_vector(softmax)


if __name__ == "__main__":
    gen_dense_softmax_test()
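For reference, with these weights and biases the linear layer evaluates to [0.46, 0.97, 0.72, 1.09] (the commented-out vector in the Dense softmax test above), and the softmax of that vector is approximately [0.17124, 0.28516, 0.22208, 0.32152], the expected output used in ForwardRandomWeightMatrixSoftmax; print_cpp_vector presumably formats these tensors for pasting into the C++ tests.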