mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-11-06 09:44:28 +00:00
Add Kernels namespace
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
|
||||
#include "activations.cuh"
|
||||
|
||||
__global__ void sigmoid_kernel(
|
||||
__global__ void Kernels::sigmoid(
|
||||
const float* __restrict__ src,
|
||||
float* __restrict__ dst,
|
||||
int len
|
||||
@@ -16,7 +16,7 @@ __global__ void sigmoid_kernel(
|
||||
}
|
||||
|
||||
__global__ void
|
||||
relu_kernel(const float* __restrict__ src, float* __restrict__ dst, int len) {
|
||||
Kernels::relu(const float* __restrict__ src, float* __restrict__ dst, int len) {
|
||||
int stride = gridDim.x * blockDim.x;
|
||||
int tid = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
|
||||
|
||||
@@ -1,7 +1,84 @@
|
||||
#include "convolution.cuh"
|
||||
#include <iostream>
|
||||
|
||||
__global__ void convolution_kernel(
|
||||
/*
|
||||
Pads matrix width x height x n_channels to width + 2 * padding x height + 2 *
|
||||
padding x n_channels Matrix is represented as a pointer to a vector
|
||||
|
||||
For example:
|
||||
|
||||
w = 2
|
||||
h = 3
|
||||
n = 2
|
||||
p = 1
|
||||
|
||||
Channel 0:
|
||||
0 1
|
||||
2 3
|
||||
4 5
|
||||
Channel 1:
|
||||
6 7
|
||||
8 9
|
||||
10 11
|
||||
|
||||
Is represented as:
|
||||
|
||||
0 1 2 3 4 5 6 7 8 9 10 11
|
||||
|
||||
Padded result (as a continuous vector):
|
||||
|
||||
0.0f, 0.0f, 0.0f, 0.0f,
|
||||
0.0f, 0.0f, 1.0f, 0.0f,
|
||||
0.0f, 2.0f, 3.0f, 0.0f,
|
||||
0.0f, 4.0f, 5.0f, 0.0f,
|
||||
0.0f, 0.0f, 0.0f, 0.0f,
|
||||
0.0f, 0.0f, 0.0f, 0.0f,
|
||||
0.0f, 6.0f, 7.0f, 0.0f,
|
||||
0.0f, 8.0f, 9.0f, 0.0f,
|
||||
9.0f, 10.0f, 11.0f, 0.0f,
|
||||
0.0f, 0.0f, 0.0f, 0.0f
|
||||
|
||||
Args:
|
||||
d_input: Pointer to input vector representing matrix
|
||||
d_padded: Pointer to output vector representing padded matrix (needs to be
|
||||
pre-allocated)
|
||||
w: Width of input matrix
|
||||
h: Height of input matrix
|
||||
n: Number of channels in input matrix
|
||||
p: Padding
|
||||
*/
|
||||
__global__ void Kernels::padding(
|
||||
const float* d_input,
|
||||
float* d_padded,
|
||||
int w,
|
||||
int h,
|
||||
int n,
|
||||
int p
|
||||
) {
|
||||
int tid = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
|
||||
if (tid >= (w + 2 * p) * (h + 2 * p) * n) {
|
||||
return;
|
||||
}
|
||||
|
||||
int idx = tid;
|
||||
|
||||
// unravel index into padded matrix
|
||||
int i_n = idx / ((w + 2 * p) * (h + 2 * p));
|
||||
int i_h = idx % ((w + 2 * p) * (h + 2 * p)) / (w + 2 * p);
|
||||
int i_w = idx % (w + 2 * p);
|
||||
|
||||
// if i is in the padding region
|
||||
if (i_w < p || i_w >= (w + p) || i_h < p || i_h >= (h + p)) {
|
||||
d_padded[tid] = 0.0f;
|
||||
} else {
|
||||
// Get index into input vector
|
||||
int i_orig = i_n * w * h + (i_h - p) * w + (i_w - p);
|
||||
d_padded[tid] = d_input[i_orig];
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void Kernels::convolution(
|
||||
const float* d_input,
|
||||
const float* d_kernel,
|
||||
float* d_output,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#include "matrix_math.cuh"
|
||||
#include "matmul.cuh"
|
||||
|
||||
__global__ void mat_vec_mul_kernel(
|
||||
__global__ void Kernels::mat_vec_mul(
|
||||
const float* d_matrix,
|
||||
const float* d_vector,
|
||||
float* d_output,
|
||||
@@ -22,7 +22,7 @@ __global__ void mat_vec_mul_kernel(
|
||||
|
||||
}
|
||||
|
||||
__global__ void vec_vec_add_kernel(
|
||||
__global__ void Kernels::vec_vec_add(
|
||||
const float* d_vector1,
|
||||
const float* d_vector2,
|
||||
float* d_output,
|
||||
@@ -1,78 +0,0 @@
|
||||
#include <vector>
|
||||
|
||||
/*
|
||||
Pads matrix width x height x n_channels to width + 2 * padding x height + 2 *
|
||||
padding x n_channels Matrix is represented as a pointer to a vector
|
||||
|
||||
For example:
|
||||
|
||||
w = 2
|
||||
h = 3
|
||||
n = 2
|
||||
p = 1
|
||||
|
||||
Channel 0:
|
||||
0 1
|
||||
2 3
|
||||
4 5
|
||||
Channel 1:
|
||||
6 7
|
||||
8 9
|
||||
10 11
|
||||
|
||||
Is represented as:
|
||||
|
||||
0 1 2 3 4 5 6 7 8 9 10 11
|
||||
|
||||
Padded result (as a continuous vector):
|
||||
|
||||
0.0f, 0.0f, 0.0f, 0.0f,
|
||||
0.0f, 0.0f, 1.0f, 0.0f,
|
||||
0.0f, 2.0f, 3.0f, 0.0f,
|
||||
0.0f, 4.0f, 5.0f, 0.0f,
|
||||
0.0f, 0.0f, 0.0f, 0.0f,
|
||||
0.0f, 0.0f, 0.0f, 0.0f,
|
||||
0.0f, 6.0f, 7.0f, 0.0f,
|
||||
0.0f, 8.0f, 9.0f, 0.0f,
|
||||
9.0f, 10.0f, 11.0f, 0.0f,
|
||||
0.0f, 0.0f, 0.0f, 0.0f
|
||||
|
||||
Args:
|
||||
d_input: Pointer to input vector representing matrix
|
||||
d_padded: Pointer to output vector representing padded matrix (needs to be
|
||||
pre-allocated)
|
||||
w: Width of input matrix
|
||||
h: Height of input matrix
|
||||
n: Number of channels in input matrix
|
||||
p: Padding
|
||||
*/
|
||||
__global__ void pad_matrix_kernel(
|
||||
const float* d_input,
|
||||
float* d_padded,
|
||||
int w,
|
||||
int h,
|
||||
int n,
|
||||
int p
|
||||
) {
|
||||
int tid = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
|
||||
if (tid >= (w + 2 * p) * (h + 2 * p) * n) {
|
||||
return;
|
||||
}
|
||||
|
||||
int idx = tid;
|
||||
|
||||
// unravel index into padded matrix
|
||||
int i_n = idx / ((w + 2 * p) * (h + 2 * p));
|
||||
int i_h = idx % ((w + 2 * p) * (h + 2 * p)) / (w + 2 * p);
|
||||
int i_w = idx % (w + 2 * p);
|
||||
|
||||
// if i is in the padding region
|
||||
if (i_w < p || i_w >= (w + p) || i_h < p || i_h >= (h + p)) {
|
||||
d_padded[tid] = 0.0f;
|
||||
} else {
|
||||
// Get index into input vector
|
||||
int i_orig = i_n * w * h + (i_h - p) * w + (i_w - p);
|
||||
d_padded[tid] = d_input[i_orig];
|
||||
}
|
||||
}
|
||||
@@ -5,17 +5,16 @@
|
||||
#include "conv2d.cuh"
|
||||
#include "convolution.cuh"
|
||||
#include "cuda_helper.cuh"
|
||||
#include "matrix_math.cuh"
|
||||
#include "padding.cuh"
|
||||
#include "matmul.cuh"
|
||||
|
||||
Layers::Conv2d::Conv2d(
|
||||
int inputSize,
|
||||
int inputChannels,
|
||||
int kernelSize,
|
||||
int stride,
|
||||
Padding padding,
|
||||
int numFilters,
|
||||
Activation activation
|
||||
int inputSize,
|
||||
int inputChannels,
|
||||
int kernelSize,
|
||||
int stride,
|
||||
Layers::Padding padding,
|
||||
int numFilters,
|
||||
Layers::Activation activation
|
||||
)
|
||||
: inputSize(inputSize),
|
||||
inputChannels(inputChannels),
|
||||
@@ -23,21 +22,19 @@ Layers::Conv2d::Conv2d(
|
||||
stride(stride),
|
||||
numFilters(numFilters),
|
||||
activation(activation) {
|
||||
switch (padding) {
|
||||
case SAME:
|
||||
outputSize = inputSize;
|
||||
paddingSize = ((stride - 1) * inputSize - stride + kernelSize) / 2;
|
||||
break;
|
||||
|
||||
switch (padding)
|
||||
{
|
||||
case SAME:
|
||||
outputSize = inputSize;
|
||||
paddingSize = ((stride - 1) * inputSize - stride + kernelSize) / 2;
|
||||
break;
|
||||
case VALID:
|
||||
paddingSize = 0;
|
||||
outputSize = (inputSize - kernelSize) / stride + 1;
|
||||
break;
|
||||
|
||||
case VALID:
|
||||
paddingSize = 0;
|
||||
outputSize = (inputSize - kernelSize) / stride + 1;
|
||||
break;
|
||||
|
||||
default:
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
weights.resize(kernelSize * kernelSize * inputChannels * numFilters);
|
||||
@@ -109,19 +106,19 @@ void Layers::Conv2d::forward(const float* d_input, float* d_output) {
|
||||
int THREADS_PER_BLOCK = (inputSize + 2 * paddingSize) *
|
||||
(inputSize + 2 * paddingSize) * inputChannels;
|
||||
|
||||
pad_matrix_kernel<<<1, THREADS_PER_BLOCK>>>(
|
||||
Kernels::padding<<<1, THREADS_PER_BLOCK>>>(
|
||||
d_input, d_padded, inputSize, inputSize, inputChannels, paddingSize
|
||||
);
|
||||
|
||||
// Convolve
|
||||
THREADS_PER_BLOCK = outputSize * outputSize * numFilters;
|
||||
convolution_kernel<<<1, THREADS_PER_BLOCK>>>(
|
||||
Kernels::convolution<<<1, THREADS_PER_BLOCK>>>(
|
||||
d_padded, d_weights, d_output, inputSize + (2 * paddingSize),
|
||||
inputChannels, kernelSize, stride, numFilters, outputSize
|
||||
);
|
||||
|
||||
// Add bias
|
||||
vec_vec_add_kernel<<<1, biases.size()>>>(
|
||||
Kernels::vec_vec_add<<<1, biases.size()>>>(
|
||||
d_biases, d_output, d_output, biases.size()
|
||||
);
|
||||
|
||||
@@ -138,8 +135,7 @@ outputSize x numFilters
|
||||
*/
|
||||
void Layers::Conv2d::host_conv(const float* input, float* output) {
|
||||
// Iterate over output matrix
|
||||
for (int tid = 0; tid < outputSize * outputSize * numFilters; tid++)
|
||||
{
|
||||
for (int tid = 0; tid < outputSize * outputSize * numFilters; tid++) {
|
||||
// Get output index
|
||||
int f = tid / (outputSize * outputSize);
|
||||
int i = tid % (outputSize * outputSize) / outputSize;
|
||||
@@ -153,19 +149,17 @@ void Layers::Conv2d::host_conv(const float* input, float* output) {
|
||||
for (int c = 0; c < inputChannels; c++) {
|
||||
int kernelIndex =
|
||||
f * kernelSize * kernelSize * inputChannels +
|
||||
c * kernelSize * kernelSize + k * kernelSize +
|
||||
l;
|
||||
c * kernelSize * kernelSize + k * kernelSize + l;
|
||||
int inputIndex = c * inputSize * inputSize +
|
||||
(i * stride + k) * inputSize +
|
||||
(j * stride + l);
|
||||
(i * stride + k) * inputSize +
|
||||
(j * stride + l);
|
||||
|
||||
sum += weights[kernelIndex] * input[inputIndex];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int outputIndex =
|
||||
f * outputSize * outputSize + i * outputSize + j;
|
||||
int outputIndex = f * outputSize * outputSize + i * outputSize + j;
|
||||
|
||||
output[outputIndex] = sum;
|
||||
}
|
||||
|
||||
@@ -8,9 +8,9 @@
|
||||
#include "activations.cuh"
|
||||
#include "cuda_helper.cuh"
|
||||
#include "dense.cuh"
|
||||
#include "matrix_math.cuh"
|
||||
#include "matmul.cuh"
|
||||
|
||||
Layers::Dense::Dense(int inputSize, int outputSize, Activation activation)
|
||||
Layers::Dense::Dense(int inputSize, int outputSize, Layers::Activation activation)
|
||||
: inputSize(inputSize), outputSize(outputSize), activation(activation) {
|
||||
// Allocate memory for weights and biases
|
||||
weights.resize(outputSize * inputSize);
|
||||
@@ -46,21 +46,21 @@ void Layers::Dense::initializeBiases() {
|
||||
}
|
||||
|
||||
void Layers::Dense::forward(const float* d_input, float* d_output) {
|
||||
mat_vec_mul_kernel<<<1, outputSize>>>(
|
||||
Kernels::mat_vec_mul<<<1, outputSize>>>(
|
||||
d_weights, d_input, d_output, inputSize, outputSize
|
||||
);
|
||||
|
||||
vec_vec_add_kernel<<<1, outputSize>>>(
|
||||
Kernels::vec_vec_add<<<1, outputSize>>>(
|
||||
d_biases, d_output, d_output, outputSize
|
||||
);
|
||||
|
||||
switch (activation) {
|
||||
case SIGMOID:
|
||||
sigmoid_kernel<<<1, outputSize>>>(d_output, d_output, outputSize);
|
||||
Kernels::sigmoid<<<1, outputSize>>>(d_output, d_output, outputSize);
|
||||
break;
|
||||
|
||||
case RELU:
|
||||
relu_kernel<<<1, outputSize>>>(d_output, d_output, outputSize);
|
||||
Kernels::relu<<<1, outputSize>>>(d_output, d_output, outputSize);
|
||||
break;
|
||||
|
||||
default:
|
||||
|
||||
Reference in New Issue
Block a user