Use tiling shmem for mat vec mul kernel

Mirror of https://github.com/lordmathis/CUDANet.git
@@ -66,6 +66,10 @@ class Dense : public ILayer {
     Layers::Activation activation;
 
+    // Precompute kernel launch parameters
+    int forwardGridSize;
+    int biasGridSize;
+
     /**
      * @brief Initialize the weights to zeros
      *
@@ -4,6 +4,10 @@
 #include <cuda_runtime.h>
 #include <cstdio>
 
+#ifndef BLOCK_SIZE
+#define BLOCK_SIZE 128
+#endif // BLOCK_SIZE
+
 /**
  * @brief CUDA error checking macro
  *
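BLOCK_SIZE is the one-dimensional thread-block width used by every kernel launch touched in this commit, and the #ifndef guard means it can be overridden at compile time (for example with -DBLOCK_SIZE=256). Grid sizes are then derived by ceiling division, the same arithmetic dense.cu writes out inline further down. A minimal sketch of that pattern, using a hypothetical helper name that is not part of the repository:

// Smallest number of BLOCK_SIZE-wide blocks that covers n elements.
// Hypothetical helper for illustration only; this commit inlines the expression instead.
inline int gridSizeFor(int n) {
    // e.g. n = 300 with BLOCK_SIZE = 128 gives (300 + 127) / 128 = 3 blocks
    return (n + BLOCK_SIZE - 1) / BLOCK_SIZE;
}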
@@ -1,31 +1,41 @@
+#include "cuda_helper.cuh"
 #include "matmul.cuh"
 
-__global__ void Kernels::mat_vec_mul(
-    const float* d_matrix,
-    const float* d_vector,
-    float* d_output,
-    int w,
-    int h
-) {
+#define SHARED_SIZE 128 * 4
+
+__global__ void Kernels::mat_vec_mul(
+    const float* __restrict__ d_matrix,
+    const float* __restrict__ d_vector,
+    float* __restrict__ d_output,
+    int w,
+    int h
+) {
     int tid = blockDim.x * blockIdx.x + threadIdx.x;
 
-    extern __shared__ float shared[];
+    __shared__ float shared[BLOCK_SIZE];
 
-    if (tid < w) {
-        shared[tid] = d_vector[tid];
-    }
-
-    __syncthreads();
+    float temp = 0.0f;
 
-    if (tid < h) {
-        d_output[tid] = 0.0f;
-
-        #pragma unroll
-        for (int i = 0; i < w; i++) {
-            d_output[tid] += d_matrix[tid * w + i] * shared[i];
-        }
-    }
+    #pragma unroll
+    for (unsigned int i = 0; i < (w + BLOCK_SIZE - 1) / BLOCK_SIZE; i++)
+    {
+        if (i * BLOCK_SIZE + threadIdx.x < w) {
+            shared[threadIdx.x] = d_vector[i * BLOCK_SIZE + threadIdx.x];
+        } else {
+            shared[threadIdx.x] = 0.0f;
+        }
+
+        __syncthreads();
+
+        for (unsigned int j = 0; j < BLOCK_SIZE; j++)
+        {
+            temp += d_matrix[tid * w + i * BLOCK_SIZE + j] * shared[j];
+        }
+
+        __syncthreads();
+    }
+
+    d_output[tid] = temp;
 }
 
 __global__ void Kernels::vec_vec_add(
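For clarity: d_matrix is an h x w row-major matrix, d_vector has length w, and each thread produces one output row. The new version stages the vector through shared memory in BLOCK_SIZE-wide chunks, zero-padding the last partial chunk. As a reference point, the kernel is meant to compute the same result as this straightforward host-side routine, a sketch added here for illustration and not part of the commit:

#include <cstddef>
#include <vector>

// CPU reference: output[row] = sum over col of matrix[row * w + col] * vec[col],
// useful for checking the tiled CUDA kernel against a plain implementation.
std::vector<float> matVecMulReference(
    const std::vector<float>& matrix,  // h * w, row-major
    const std::vector<float>& vec,     // length w
    int w,
    int h
) {
    std::vector<float> output(h, 0.0f);
    for (int row = 0; row < h; ++row) {
        float sum = 0.0f;
        for (int col = 0; col < w; ++col) {
            sum += matrix[static_cast<std::size_t>(row) * w + col] * vec[col];
        }
        output[row] = sum;
    }
    return output;
}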
@@ -10,7 +10,11 @@
 #include "dense.cuh"
 #include "matmul.cuh"
 
-Layers::Dense::Dense(int inputSize, int outputSize, Layers::Activation activation)
+Layers::Dense::Dense(
+    int inputSize,
+    int outputSize,
+    Layers::Activation activation
+)
     : inputSize(inputSize), outputSize(outputSize), activation(activation) {
     // Allocate memory for weights and biases
     weights.resize(outputSize * inputSize);
@@ -31,8 +35,12 @@ Layers::Dense::Dense(int inputSize, int outputSize, Layers::Activatio
         cudaMalloc((void**)&d_weights, sizeof(float) * inputSize * outputSize)
     );
     CUDA_CHECK(cudaMalloc((void**)&d_biases, sizeof(float) * outputSize));
 
     toCuda();
 
+    // Calculate block and grid sizes
+    forwardGridSize =
+        (std::max(inputSize, outputSize) + BLOCK_SIZE - 1) / BLOCK_SIZE;
+    biasGridSize = (outputSize + BLOCK_SIZE - 1) / BLOCK_SIZE;
 }
 
 Layers::Dense::~Dense() {
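As a worked example with the default BLOCK_SIZE of 128 and a hypothetical layer of inputSize = 784 and outputSize = 10: forwardGridSize = (max(784, 10) + 127) / 128 = 7, since the larger of the two dimensions drives the matrix-vector launch, and biasGridSize = (10 + 127) / 128 = 1. So mat_vec_mul runs with 7 blocks of 128 threads, while the bias and activation kernels run with a single block.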
@@ -51,21 +59,25 @@ void Layers::Dense::initializeBiases() {
 }
 
 float* Layers::Dense::forward(const float* d_input) {
-    Kernels::mat_vec_mul<<<1, std::max(inputSize, outputSize), sizeof(float) * inputSize>>>(
+    Kernels::mat_vec_mul<<<forwardGridSize, BLOCK_SIZE>>>(
         d_weights, d_input, d_output, inputSize, outputSize
     );
 
-    Kernels::vec_vec_add<<<1, outputSize>>>(
+    Kernels::vec_vec_add<<<biasGridSize, BLOCK_SIZE>>>(
         d_biases, d_output, d_output, outputSize
     );
 
     switch (activation) {
         case SIGMOID:
-            Kernels::sigmoid<<<1, outputSize>>>(d_output, d_output, outputSize);
+            Kernels::sigmoid<<<biasGridSize, BLOCK_SIZE>>>(
+                d_output, d_output, outputSize
+            );
             break;
 
         case RELU:
-            Kernels::relu<<<1, outputSize>>>(d_output, d_output, outputSize);
+            Kernels::relu<<<biasGridSize, BLOCK_SIZE>>>(
+                d_output, d_output, outputSize
+            );
             break;
 
         default:
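To tie the pieces together, here is a hypothetical usage sketch of the layer after these changes. It is not part of the commit and assumes that dense.cuh declares the constructor and forward() exactly as shown in this diff, that Activation is an unscoped enum declared in namespace Layers (so the value is spelled Layers::SIGMOID from the outside), and that the pointer returned by forward() refers to device memory owned by the layer:

#include "dense.cuh"
#include <cuda_runtime.h>
#include <vector>

int main() {
    const int inputSize  = 784;   // hypothetical sizes, matching the arithmetic above
    const int outputSize = 10;

    // The constructor allocates device memory for weights and biases and
    // precomputes forwardGridSize and biasGridSize as shown in this commit.
    Layers::Dense dense(inputSize, outputSize, Layers::SIGMOID);

    std::vector<float> h_input(inputSize, 1.0f);
    float* d_input = nullptr;
    cudaMalloc((void**)&d_input, sizeof(float) * inputSize);
    cudaMemcpy(d_input, h_input.data(), sizeof(float) * inputSize, cudaMemcpyHostToDevice);

    // forward() launches mat_vec_mul, vec_vec_add and the activation kernel
    // with the precomputed grid sizes and returns a device pointer.
    float* d_out = dense.forward(d_input);

    std::vector<float> h_output(outputSize);
    cudaMemcpy(h_output.data(), d_out, sizeof(float) * outputSize, cudaMemcpyDeviceToHost);

    cudaFree(d_input);
    return 0;
}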