mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-12-24 07:14:22 +00:00
Migrate batch norm layer
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
#include "kernels/activation_functions.cuh"
|
||||
#include "kernels/convolution.cuh"
|
||||
#include "kernels/matmul.cuh"
|
||||
#include "kernels/pooling.cuh"
|
||||
#include "kernels/pool.cuh"
|
||||
#include "utils/cuda_helper.cuh"
|
||||
|
||||
using namespace CUDANet::Backend;
|
||||
@@ -112,7 +112,7 @@ CUDANet::Tensor& CUDA::conv2d(
|
||||
return output;
|
||||
}
|
||||
|
||||
CUDANet::Tensor& CUDA::maxPool2d(
|
||||
CUDANet::Tensor& CUDA::max_pool2d(
|
||||
const CUDANet::Tensor& input,
|
||||
CUDANet::Tensor& output,
|
||||
CUDANet::Shape input_shape,
|
||||
@@ -138,7 +138,7 @@ CUDANet::Tensor& CUDA::maxPool2d(
|
||||
return output;
|
||||
}
|
||||
|
||||
CUDANet::Tensor& CUDA::avgPool2d(
|
||||
CUDANet::Tensor& CUDA::avg_pool2d(
|
||||
const CUDANet::Tensor& input,
|
||||
CUDANet::Tensor& output,
|
||||
CUDANet::Shape input_shape,
|
||||
@@ -162,4 +162,53 @@ CUDANet::Tensor& CUDA::avgPool2d(
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
CUDANet::Tensor& CUDA::batch_norm(
|
||||
const CUDANet::Tensor& input,
|
||||
CUDANet::Tensor& output,
|
||||
CUDANet::Shape input_shape,
|
||||
CUDANet::Tensor& weights,
|
||||
CUDANet::Tensor& biases,
|
||||
CUDANet::Tensor& running_mean,
|
||||
CUDANet::Tensor& running_var,
|
||||
CUDANet::Tensor& epsilon
|
||||
) {
|
||||
auto gridSize =
|
||||
(input_shape[0] * input_shape[1] + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
|
||||
|
||||
for (int i = 0; i < input_shape[2]; i++) {
|
||||
// Subtract mean from input
|
||||
Kernels::vec_scalar_sub<<<gridSize, BLOCK_SIZE>>>(
|
||||
input.data<float>() + i * input_shape[0] * input_shape[1],
|
||||
output.data<float>() + i * input_shape[0] * input_shape[1],
|
||||
&running_mean.data<float>()[i], input_shape[0] * input_shape[1]
|
||||
);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
// Divide by sqrt(running_var + epsilon)
|
||||
Kernels::vec_scale<<<gridSize, BLOCK_SIZE>>>(
|
||||
output.data<float>() + i * input_shape[0] * input_shape[1],
|
||||
output.data<float>() + i * input_shape[0] * input_shape[1],
|
||||
&running_var.data<float>()[i], epsilon.data<float>(), input_shape[0] * input_shape[1]
|
||||
);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
// Multiply by weights
|
||||
Kernels::vec_scalar_mul<<<gridSize, BLOCK_SIZE>>>(
|
||||
output.data<float>() + i * input_shape[0] * input_shape[1],
|
||||
output.data<float>() + i * input_shape[0] * input_shape[1], &weights.data<float>()[i],
|
||||
input_shape[0] * input_shape[1]
|
||||
);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
// Add biases
|
||||
Kernels::vec_scalar_add<<<gridSize, BLOCK_SIZE>>>(
|
||||
output.data<float>() + i * input_shape[0] * input_shape[1],
|
||||
output.data<float>() + i * input_shape[0] * input_shape[1], &biases.data<float>()[i],
|
||||
input_shape[0] * input_shape[1]
|
||||
);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
}
|
||||
@@ -1,120 +0,0 @@
|
||||
#include <vector>
|
||||
|
||||
#include "activation.hpp"
|
||||
#include "batch_norm.hpp"
|
||||
#include "cuda_helper.cuh"
|
||||
#include "layer.hpp"
|
||||
#include "matmul.cuh"
|
||||
#include "vector.cuh"
|
||||
|
||||
using namespace CUDANet::Layers;
|
||||
|
||||
void BatchNorm2d::initCUDA() {
|
||||
d_output = nullptr;
|
||||
CUDA_CHECK(cudaMalloc(
|
||||
(void **)&d_output,
|
||||
sizeof(float) * inputSize.first * inputSize.second * inputChannels
|
||||
));
|
||||
|
||||
d_running_mean = nullptr;
|
||||
CUDA_CHECK(
|
||||
cudaMalloc((void **)&d_running_mean, sizeof(float) * inputChannels)
|
||||
);
|
||||
|
||||
d_running_var = nullptr;
|
||||
CUDA_CHECK(
|
||||
cudaMalloc((void **)&d_running_var, sizeof(float) * inputChannels)
|
||||
);
|
||||
|
||||
d_weights = nullptr;
|
||||
CUDA_CHECK(cudaMalloc((void **)&d_weights, sizeof(float) * inputChannels));
|
||||
|
||||
d_biases = nullptr;
|
||||
CUDA_CHECK(cudaMalloc((void **)&d_biases, sizeof(float) * inputChannels));
|
||||
|
||||
d_length = nullptr;
|
||||
float length = (float)inputSize.first * inputSize.second;
|
||||
CUDA_CHECK(cudaMalloc((void **)&d_length, sizeof(float)));
|
||||
CUDA_CHECK(
|
||||
cudaMemcpy(d_length, &length, sizeof(float), cudaMemcpyHostToDevice)
|
||||
);
|
||||
|
||||
d_epsilon = nullptr;
|
||||
CUDA_CHECK(cudaMalloc((void **)&d_epsilon, sizeof(float)));
|
||||
CUDA_CHECK(
|
||||
cudaMemcpy(d_epsilon, &epsilon, sizeof(float), cudaMemcpyHostToDevice)
|
||||
);
|
||||
|
||||
gridSize =
|
||||
(inputSize.first * inputSize.second + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
}
|
||||
|
||||
void BatchNorm2d::delCUDA() {
|
||||
cudaFree(d_output);
|
||||
cudaFree(d_running_mean);
|
||||
cudaFree(d_running_var);
|
||||
cudaFree(d_weights);
|
||||
cudaFree(d_biases);
|
||||
cudaFree(d_length);
|
||||
cudaFree(d_epsilon);
|
||||
}
|
||||
|
||||
void BatchNorm2d::toCuda() {
|
||||
CUDA_CHECK(cudaMemcpy(
|
||||
d_weights, weights.data(), sizeof(float) * inputChannels,
|
||||
cudaMemcpyHostToDevice
|
||||
));
|
||||
CUDA_CHECK(cudaMemcpy(
|
||||
d_biases, biases.data(), sizeof(float) * inputChannels,
|
||||
cudaMemcpyHostToDevice
|
||||
));
|
||||
CUDA_CHECK(cudaMemcpy(
|
||||
d_running_mean, running_mean.data(), sizeof(float) * inputChannels,
|
||||
cudaMemcpyHostToDevice
|
||||
));
|
||||
CUDA_CHECK(cudaMemcpy(
|
||||
d_running_var, running_var.data(), sizeof(float) * inputChannels,
|
||||
cudaMemcpyHostToDevice
|
||||
));
|
||||
}
|
||||
|
||||
float *BatchNorm2d::forwardCUDA(const float *d_input) {
|
||||
// Compute per-channel batch normalization
|
||||
for (int i = 0; i < inputChannels; i++) {
|
||||
// Subtract mean from input
|
||||
Kernels::vec_scalar_sub<<<gridSize, BLOCK_SIZE>>>(
|
||||
d_input + i * inputSize.first * inputSize.second,
|
||||
d_output + i * inputSize.first * inputSize.second,
|
||||
&d_running_mean[i], inputSize.first * inputSize.second
|
||||
);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
// Divide by sqrt(running_var + epsilon)
|
||||
Kernels::vec_scale<<<gridSize, BLOCK_SIZE>>>(
|
||||
d_output + i * inputSize.first * inputSize.second,
|
||||
d_output + i * inputSize.first * inputSize.second,
|
||||
&d_running_var[i], d_epsilon, inputSize.first * inputSize.second
|
||||
);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
// Multiply by weights
|
||||
Kernels::vec_scalar_mul<<<gridSize, BLOCK_SIZE>>>(
|
||||
d_output + i * inputSize.first * inputSize.second,
|
||||
d_output + i * inputSize.first * inputSize.second, &d_weights[i],
|
||||
inputSize.first * inputSize.second
|
||||
);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
// Add biases
|
||||
Kernels::vec_scalar_add<<<gridSize, BLOCK_SIZE>>>(
|
||||
d_output + i * inputSize.first * inputSize.second,
|
||||
d_output + i * inputSize.first * inputSize.second, &d_biases[i],
|
||||
inputSize.first * inputSize.second
|
||||
);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
activation->activate(d_output);
|
||||
|
||||
return d_output;
|
||||
}
|
||||
@@ -23,7 +23,12 @@ void CUDA::print(const CUDANet::Tensor &input) {
|
||||
}
|
||||
|
||||
void CUDA::zero(CUDANet::Tensor &input) {
|
||||
CUDA_CHECK(cudaMemset(input.data<float>(), 0, sizeof(float) * input.numel()));
|
||||
fill(input, 0);
|
||||
}
|
||||
|
||||
void CUDA::fill(CUDANet::Tensor &input, int value) {
|
||||
CUDA_CHECK(cudaMemset(input.data<float>(), value, sizeof(float) * input.numel()));
|
||||
|
||||
}
|
||||
|
||||
void CUDA::copy_to_device(CUDANet::Tensor &tensor, void *data, size_t size) {
|
||||
|
||||
Reference in New Issue
Block a user