mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-12-22 14:24:22 +00:00
Migrate batch norm layer
This commit is contained in:
@@ -16,6 +16,7 @@ class Backend {
|
|||||||
// Tensor ops
|
// Tensor ops
|
||||||
virtual void print(const CUDANet::Tensor& input) = 0;
|
virtual void print(const CUDANet::Tensor& input) = 0;
|
||||||
virtual void zero(CUDANet::Tensor& input) = 0;
|
virtual void zero(CUDANet::Tensor& input) = 0;
|
||||||
|
virtual void fill(CUDANet::Tensor& input, int data) = 0;
|
||||||
|
|
||||||
virtual void
|
virtual void
|
||||||
copy_to_device(CUDANet::Tensor& tensor, void* data, size_t size) = 0;
|
copy_to_device(CUDANet::Tensor& tensor, void* data, size_t size) = 0;
|
||||||
@@ -53,7 +54,7 @@ class Backend {
|
|||||||
const CUDANet::Shape out_shape
|
const CUDANet::Shape out_shape
|
||||||
) = 0;
|
) = 0;
|
||||||
|
|
||||||
virtual CUDANet::Tensor& maxPool2d(
|
virtual CUDANet::Tensor& max_pool2d(
|
||||||
const CUDANet::Tensor& input,
|
const CUDANet::Tensor& input,
|
||||||
CUDANet::Tensor& output,
|
CUDANet::Tensor& output,
|
||||||
CUDANet::Shape input_shape,
|
CUDANet::Shape input_shape,
|
||||||
@@ -63,7 +64,7 @@ class Backend {
|
|||||||
CUDANet::Shape output_shape
|
CUDANet::Shape output_shape
|
||||||
) = 0;
|
) = 0;
|
||||||
|
|
||||||
virtual CUDANet::Tensor& avgPool2d(
|
virtual CUDANet::Tensor& avg_pool2d(
|
||||||
const CUDANet::Tensor& input,
|
const CUDANet::Tensor& input,
|
||||||
CUDANet::Tensor& output,
|
CUDANet::Tensor& output,
|
||||||
CUDANet::Shape input_shape,
|
CUDANet::Shape input_shape,
|
||||||
@@ -72,6 +73,17 @@ class Backend {
|
|||||||
CUDANet::Shape padding_shape,
|
CUDANet::Shape padding_shape,
|
||||||
CUDANet::Shape output_shape
|
CUDANet::Shape output_shape
|
||||||
) = 0;
|
) = 0;
|
||||||
|
|
||||||
|
virtual CUDANet::Tensor& batch_norm(
|
||||||
|
const CUDANet::Tensor& input,
|
||||||
|
CUDANet::Tensor& output,
|
||||||
|
CUDANet::Shape input_shape,
|
||||||
|
CUDANet::Tensor& weights,
|
||||||
|
CUDANet::Tensor& biases,
|
||||||
|
CUDANet::Tensor& running_mean,
|
||||||
|
CUDANet::Tensor& running_var,
|
||||||
|
CUDANet::Tensor& epsilon
|
||||||
|
) = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace CUDANet
|
} // namespace CUDANet
|
||||||
@@ -14,6 +14,7 @@ class CUDA : public Backend {
|
|||||||
// Tensor ops
|
// Tensor ops
|
||||||
void print(const CUDANet::Tensor& input) override;
|
void print(const CUDANet::Tensor& input) override;
|
||||||
void zero(CUDANet::Tensor& input) override;
|
void zero(CUDANet::Tensor& input) override;
|
||||||
|
void fill(CUDANet::Tensor &input, int value) override;
|
||||||
void
|
void
|
||||||
copy_to_device(CUDANet::Tensor& tensor, void* data, size_t size) override;
|
copy_to_device(CUDANet::Tensor& tensor, void* data, size_t size) override;
|
||||||
void sum(const CUDANet::Tensor& input, CUDANet::Tensor& sum) override;
|
void sum(const CUDANet::Tensor& input, CUDANet::Tensor& sum) override;
|
||||||
@@ -49,7 +50,7 @@ class CUDA : public Backend {
|
|||||||
const CUDANet::Shape out_shape
|
const CUDANet::Shape out_shape
|
||||||
) override;
|
) override;
|
||||||
|
|
||||||
CUDANet::Tensor& maxPool2d(
|
CUDANet::Tensor& max_pool2d(
|
||||||
const CUDANet::Tensor& input,
|
const CUDANet::Tensor& input,
|
||||||
CUDANet::Tensor& output,
|
CUDANet::Tensor& output,
|
||||||
CUDANet::Shape input_shape,
|
CUDANet::Shape input_shape,
|
||||||
@@ -59,7 +60,7 @@ class CUDA : public Backend {
|
|||||||
CUDANet::Shape output_shape
|
CUDANet::Shape output_shape
|
||||||
) override;
|
) override;
|
||||||
|
|
||||||
CUDANet::Tensor& avgPool2d(
|
CUDANet::Tensor& avg_pool2d(
|
||||||
const CUDANet::Tensor& input,
|
const CUDANet::Tensor& input,
|
||||||
CUDANet::Tensor& output,
|
CUDANet::Tensor& output,
|
||||||
CUDANet::Shape input_shape,
|
CUDANet::Shape input_shape,
|
||||||
@@ -67,7 +68,18 @@ class CUDA : public Backend {
|
|||||||
CUDANet::Shape stride_shape,
|
CUDANet::Shape stride_shape,
|
||||||
CUDANet::Shape padding_shape,
|
CUDANet::Shape padding_shape,
|
||||||
CUDANet::Shape output_shape
|
CUDANet::Shape output_shape
|
||||||
) = 0;
|
) override;
|
||||||
|
|
||||||
|
CUDANet::Tensor& batch_norm(
|
||||||
|
const CUDANet::Tensor& input,
|
||||||
|
CUDANet::Tensor& output,
|
||||||
|
CUDANet::Shape input_shape,
|
||||||
|
CUDANet::Tensor& weights,
|
||||||
|
CUDANet::Tensor& biases,
|
||||||
|
CUDANet::Tensor& running_mean,
|
||||||
|
CUDANet::Tensor& running_var,
|
||||||
|
CUDANet::Tensor& epsilon
|
||||||
|
) override;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace CUDANet::Backend
|
} // namespace CUDANet::Backend
|
||||||
@@ -1,170 +1,54 @@
|
|||||||
#ifndef CUDANET_BATCH_NORM_H
|
#pragma once
|
||||||
#define CUDANET_BATCH_NORM_H
|
|
||||||
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "activation.hpp"
|
|
||||||
#include "layer.hpp"
|
#include "layer.hpp"
|
||||||
|
|
||||||
namespace CUDANet::Layers {
|
namespace CUDANet::Layers {
|
||||||
|
|
||||||
class BatchNorm2d : public WeightedLayer, public TwoDLayer {
|
class BatchNorm2d : public Layer {
|
||||||
public:
|
public:
|
||||||
BatchNorm2d(
|
BatchNorm2d(CUDANet::Shape input_shape, float epsilon, CUDANet::Backend *backend);
|
||||||
shape2d inputSize,
|
|
||||||
int inputChannels,
|
|
||||||
float epsilon,
|
|
||||||
ActivationType activationType
|
|
||||||
);
|
|
||||||
|
|
||||||
~BatchNorm2d();
|
~BatchNorm2d();
|
||||||
|
|
||||||
/**
|
CUDANet::Tensor& forward(CUDANet::Tensor& input) override;
|
||||||
* @brief Compute the forward pass of the batchnorm layer
|
|
||||||
*
|
|
||||||
* @param d_input Device pointer to the input
|
|
||||||
* @return float* Device pointer to the output
|
|
||||||
*/
|
|
||||||
float* forward(const float* d_input);
|
|
||||||
|
|
||||||
/**
|
CUDANet::Shape input_shape() override;
|
||||||
* @brief Set the weights of the batchnorm layer
|
|
||||||
*
|
|
||||||
* @param weights_input Pointer to the weights
|
|
||||||
*/
|
|
||||||
void setWeights(const float* weights_input);
|
|
||||||
|
|
||||||
/**
|
CUDANet::Shape output_shape() override;
|
||||||
* @brief Get the weights of the batchnorm layer
|
|
||||||
*
|
|
||||||
* @return std::vector<float>
|
|
||||||
*/
|
|
||||||
std::vector<float> getWeights();
|
|
||||||
|
|
||||||
/**
|
size_t input_size() override;
|
||||||
* @brief Set the biases of the batchnorm layer
|
|
||||||
*
|
|
||||||
* @param biases_input Pointer to the biases
|
|
||||||
*/
|
|
||||||
void setBiases(const float* biases_input);
|
|
||||||
|
|
||||||
/**
|
size_t output_size() override;
|
||||||
* @brief Get the biases of the batchnorm layer
|
|
||||||
*
|
|
||||||
* @return std::vector<float>
|
|
||||||
*/
|
|
||||||
std::vector<float> getBiases();
|
|
||||||
|
|
||||||
/**
|
void set_weights(void* input) override;
|
||||||
* @brief Set the Running Mean
|
|
||||||
*
|
|
||||||
* @param running_mean_input
|
|
||||||
*/
|
|
||||||
void setRunningMean(const float* running_mean_input);
|
|
||||||
|
|
||||||
/**
|
CUDANet::Tensor& get_weights() override;
|
||||||
* @brief Get the Running Mean
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
std::vector<float> getRunningMean();
|
|
||||||
|
|
||||||
/**
|
void set_biases(void* input) override;
|
||||||
* @brief Set the Running Var
|
|
||||||
*
|
|
||||||
* @param running_mean_input
|
|
||||||
*/
|
|
||||||
void setRunningVar(const float* running_mean_input);
|
|
||||||
|
|
||||||
/**
|
CUDANet::Tensor& get_biases() override;
|
||||||
* @brief Get the Running Var
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
std::vector<float> getRunningVar();
|
|
||||||
|
|
||||||
/**
|
void set_running_mean(void* input);
|
||||||
* @brief Get output size
|
|
||||||
*
|
|
||||||
* @return int output size
|
|
||||||
*/
|
|
||||||
int getOutputSize();
|
|
||||||
|
|
||||||
/**
|
CUDANet::Tensor& get_running_mean();
|
||||||
* @brief Get input size
|
|
||||||
*
|
|
||||||
* @return int input size
|
|
||||||
*/
|
|
||||||
int getInputSize();
|
|
||||||
|
|
||||||
shape2d getOutputDims();
|
void set_running_var(void* input);
|
||||||
|
|
||||||
|
CUDANet::Tensor& get_running_var();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
shape2d inputSize;
|
CUDANet::Shape in_shape;
|
||||||
int inputChannels;
|
CUDANet::Tensor epsilon;
|
||||||
float epsilon;
|
|
||||||
|
|
||||||
int gridSize;
|
CUDANet::Tensor running_mean;
|
||||||
|
CUDANet::Tensor running_var;
|
||||||
|
|
||||||
#ifdef USE_CUDA
|
CUDANet::Tensor weights;
|
||||||
|
CUDANet::Tensor biases;
|
||||||
|
|
||||||
float* d_output;
|
CUDANet::Tensor output;
|
||||||
|
|
||||||
float* d_running_mean;
|
CUDANet::Backend *backend;
|
||||||
float* d_running_var;
|
|
||||||
|
|
||||||
float* d_length;
|
|
||||||
float* d_epsilon;
|
|
||||||
|
|
||||||
float* d_weights;
|
|
||||||
float* d_biases;
|
|
||||||
|
|
||||||
void initCUDA();
|
|
||||||
void delCUDA();
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Copy weights and biases to the device
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
void toCuda();
|
|
||||||
|
|
||||||
float* forwardCUDA(const float* d_input);
|
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
std::vector<float> weights;
|
|
||||||
std::vector<float> biases;
|
|
||||||
|
|
||||||
std::vector<float> running_mean;
|
|
||||||
std::vector<float> running_var;
|
|
||||||
|
|
||||||
Activation* activation;
|
|
||||||
|
|
||||||
float* forwardCPU(const float* input);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Initialize weights of the batchnorm layer with zeros
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
void initializeWeights();
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Initialize biases of the batchnorm layer with zeros
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
void initializeBiases();
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Initialize mean of the batchnorm layer with zeros
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
void initializeRunningMean();
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Initialize sqrt of variance of the batchnorm layer with ones
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
void initializeRunningVar();
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace CUDANet::Layers
|
} // namespace CUDANet::Layers
|
||||||
|
|
||||||
#endif // CUDANET_BATCH_NORM_H
|
|
||||||
@@ -45,6 +45,11 @@ public:
|
|||||||
|
|
||||||
void zero();
|
void zero();
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void fill(T value) {
|
||||||
|
backend->fill(*this, value);
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void set_data(T *data) {
|
void set_data(T *data) {
|
||||||
backend->copy_to_device(*this, data, total_size);
|
backend->copy_to_device(*this, data, total_size);
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
#include "kernels/activation_functions.cuh"
|
#include "kernels/activation_functions.cuh"
|
||||||
#include "kernels/convolution.cuh"
|
#include "kernels/convolution.cuh"
|
||||||
#include "kernels/matmul.cuh"
|
#include "kernels/matmul.cuh"
|
||||||
#include "kernels/pooling.cuh"
|
#include "kernels/pool.cuh"
|
||||||
#include "utils/cuda_helper.cuh"
|
#include "utils/cuda_helper.cuh"
|
||||||
|
|
||||||
using namespace CUDANet::Backend;
|
using namespace CUDANet::Backend;
|
||||||
@@ -112,7 +112,7 @@ CUDANet::Tensor& CUDA::conv2d(
|
|||||||
return output;
|
return output;
|
||||||
}
|
}
|
||||||
|
|
||||||
CUDANet::Tensor& CUDA::maxPool2d(
|
CUDANet::Tensor& CUDA::max_pool2d(
|
||||||
const CUDANet::Tensor& input,
|
const CUDANet::Tensor& input,
|
||||||
CUDANet::Tensor& output,
|
CUDANet::Tensor& output,
|
||||||
CUDANet::Shape input_shape,
|
CUDANet::Shape input_shape,
|
||||||
@@ -138,7 +138,7 @@ CUDANet::Tensor& CUDA::maxPool2d(
|
|||||||
return output;
|
return output;
|
||||||
}
|
}
|
||||||
|
|
||||||
CUDANet::Tensor& CUDA::avgPool2d(
|
CUDANet::Tensor& CUDA::avg_pool2d(
|
||||||
const CUDANet::Tensor& input,
|
const CUDANet::Tensor& input,
|
||||||
CUDANet::Tensor& output,
|
CUDANet::Tensor& output,
|
||||||
CUDANet::Shape input_shape,
|
CUDANet::Shape input_shape,
|
||||||
@@ -163,3 +163,52 @@ CUDANet::Tensor& CUDA::avgPool2d(
|
|||||||
|
|
||||||
return output;
|
return output;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
CUDANet::Tensor& CUDA::batch_norm(
|
||||||
|
const CUDANet::Tensor& input,
|
||||||
|
CUDANet::Tensor& output,
|
||||||
|
CUDANet::Shape input_shape,
|
||||||
|
CUDANet::Tensor& weights,
|
||||||
|
CUDANet::Tensor& biases,
|
||||||
|
CUDANet::Tensor& running_mean,
|
||||||
|
CUDANet::Tensor& running_var,
|
||||||
|
CUDANet::Tensor& epsilon
|
||||||
|
) {
|
||||||
|
auto gridSize =
|
||||||
|
(input_shape[0] * input_shape[1] + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||||
|
|
||||||
|
|
||||||
|
for (int i = 0; i < input_shape[2]; i++) {
|
||||||
|
// Subtract mean from input
|
||||||
|
Kernels::vec_scalar_sub<<<gridSize, BLOCK_SIZE>>>(
|
||||||
|
input.data<float>() + i * input_shape[0] * input_shape[1],
|
||||||
|
output.data<float>() + i * input_shape[0] * input_shape[1],
|
||||||
|
&running_mean.data<float>()[i], input_shape[0] * input_shape[1]
|
||||||
|
);
|
||||||
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
|
// Divide by sqrt(running_var + epsilon)
|
||||||
|
Kernels::vec_scale<<<gridSize, BLOCK_SIZE>>>(
|
||||||
|
output.data<float>() + i * input_shape[0] * input_shape[1],
|
||||||
|
output.data<float>() + i * input_shape[0] * input_shape[1],
|
||||||
|
&running_var.data<float>()[i], epsilon.data<float>(), input_shape[0] * input_shape[1]
|
||||||
|
);
|
||||||
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
|
// Multiply by weights
|
||||||
|
Kernels::vec_scalar_mul<<<gridSize, BLOCK_SIZE>>>(
|
||||||
|
output.data<float>() + i * input_shape[0] * input_shape[1],
|
||||||
|
output.data<float>() + i * input_shape[0] * input_shape[1], &weights.data<float>()[i],
|
||||||
|
input_shape[0] * input_shape[1]
|
||||||
|
);
|
||||||
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
|
// Add biases
|
||||||
|
Kernels::vec_scalar_add<<<gridSize, BLOCK_SIZE>>>(
|
||||||
|
output.data<float>() + i * input_shape[0] * input_shape[1],
|
||||||
|
output.data<float>() + i * input_shape[0] * input_shape[1], &biases.data<float>()[i],
|
||||||
|
input_shape[0] * input_shape[1]
|
||||||
|
);
|
||||||
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,120 +0,0 @@
|
|||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "activation.hpp"
|
|
||||||
#include "batch_norm.hpp"
|
|
||||||
#include "cuda_helper.cuh"
|
|
||||||
#include "layer.hpp"
|
|
||||||
#include "matmul.cuh"
|
|
||||||
#include "vector.cuh"
|
|
||||||
|
|
||||||
using namespace CUDANet::Layers;
|
|
||||||
|
|
||||||
void BatchNorm2d::initCUDA() {
|
|
||||||
d_output = nullptr;
|
|
||||||
CUDA_CHECK(cudaMalloc(
|
|
||||||
(void **)&d_output,
|
|
||||||
sizeof(float) * inputSize.first * inputSize.second * inputChannels
|
|
||||||
));
|
|
||||||
|
|
||||||
d_running_mean = nullptr;
|
|
||||||
CUDA_CHECK(
|
|
||||||
cudaMalloc((void **)&d_running_mean, sizeof(float) * inputChannels)
|
|
||||||
);
|
|
||||||
|
|
||||||
d_running_var = nullptr;
|
|
||||||
CUDA_CHECK(
|
|
||||||
cudaMalloc((void **)&d_running_var, sizeof(float) * inputChannels)
|
|
||||||
);
|
|
||||||
|
|
||||||
d_weights = nullptr;
|
|
||||||
CUDA_CHECK(cudaMalloc((void **)&d_weights, sizeof(float) * inputChannels));
|
|
||||||
|
|
||||||
d_biases = nullptr;
|
|
||||||
CUDA_CHECK(cudaMalloc((void **)&d_biases, sizeof(float) * inputChannels));
|
|
||||||
|
|
||||||
d_length = nullptr;
|
|
||||||
float length = (float)inputSize.first * inputSize.second;
|
|
||||||
CUDA_CHECK(cudaMalloc((void **)&d_length, sizeof(float)));
|
|
||||||
CUDA_CHECK(
|
|
||||||
cudaMemcpy(d_length, &length, sizeof(float), cudaMemcpyHostToDevice)
|
|
||||||
);
|
|
||||||
|
|
||||||
d_epsilon = nullptr;
|
|
||||||
CUDA_CHECK(cudaMalloc((void **)&d_epsilon, sizeof(float)));
|
|
||||||
CUDA_CHECK(
|
|
||||||
cudaMemcpy(d_epsilon, &epsilon, sizeof(float), cudaMemcpyHostToDevice)
|
|
||||||
);
|
|
||||||
|
|
||||||
gridSize =
|
|
||||||
(inputSize.first * inputSize.second + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
|
||||||
}
|
|
||||||
|
|
||||||
void BatchNorm2d::delCUDA() {
|
|
||||||
cudaFree(d_output);
|
|
||||||
cudaFree(d_running_mean);
|
|
||||||
cudaFree(d_running_var);
|
|
||||||
cudaFree(d_weights);
|
|
||||||
cudaFree(d_biases);
|
|
||||||
cudaFree(d_length);
|
|
||||||
cudaFree(d_epsilon);
|
|
||||||
}
|
|
||||||
|
|
||||||
void BatchNorm2d::toCuda() {
|
|
||||||
CUDA_CHECK(cudaMemcpy(
|
|
||||||
d_weights, weights.data(), sizeof(float) * inputChannels,
|
|
||||||
cudaMemcpyHostToDevice
|
|
||||||
));
|
|
||||||
CUDA_CHECK(cudaMemcpy(
|
|
||||||
d_biases, biases.data(), sizeof(float) * inputChannels,
|
|
||||||
cudaMemcpyHostToDevice
|
|
||||||
));
|
|
||||||
CUDA_CHECK(cudaMemcpy(
|
|
||||||
d_running_mean, running_mean.data(), sizeof(float) * inputChannels,
|
|
||||||
cudaMemcpyHostToDevice
|
|
||||||
));
|
|
||||||
CUDA_CHECK(cudaMemcpy(
|
|
||||||
d_running_var, running_var.data(), sizeof(float) * inputChannels,
|
|
||||||
cudaMemcpyHostToDevice
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
float *BatchNorm2d::forwardCUDA(const float *d_input) {
|
|
||||||
// Compute per-channel batch normalization
|
|
||||||
for (int i = 0; i < inputChannels; i++) {
|
|
||||||
// Subtract mean from input
|
|
||||||
Kernels::vec_scalar_sub<<<gridSize, BLOCK_SIZE>>>(
|
|
||||||
d_input + i * inputSize.first * inputSize.second,
|
|
||||||
d_output + i * inputSize.first * inputSize.second,
|
|
||||||
&d_running_mean[i], inputSize.first * inputSize.second
|
|
||||||
);
|
|
||||||
CUDA_CHECK(cudaGetLastError());
|
|
||||||
|
|
||||||
// Divide by sqrt(running_var + epsilon)
|
|
||||||
Kernels::vec_scale<<<gridSize, BLOCK_SIZE>>>(
|
|
||||||
d_output + i * inputSize.first * inputSize.second,
|
|
||||||
d_output + i * inputSize.first * inputSize.second,
|
|
||||||
&d_running_var[i], d_epsilon, inputSize.first * inputSize.second
|
|
||||||
);
|
|
||||||
CUDA_CHECK(cudaGetLastError());
|
|
||||||
|
|
||||||
// Multiply by weights
|
|
||||||
Kernels::vec_scalar_mul<<<gridSize, BLOCK_SIZE>>>(
|
|
||||||
d_output + i * inputSize.first * inputSize.second,
|
|
||||||
d_output + i * inputSize.first * inputSize.second, &d_weights[i],
|
|
||||||
inputSize.first * inputSize.second
|
|
||||||
);
|
|
||||||
CUDA_CHECK(cudaGetLastError());
|
|
||||||
|
|
||||||
// Add biases
|
|
||||||
Kernels::vec_scalar_add<<<gridSize, BLOCK_SIZE>>>(
|
|
||||||
d_output + i * inputSize.first * inputSize.second,
|
|
||||||
d_output + i * inputSize.first * inputSize.second, &d_biases[i],
|
|
||||||
inputSize.first * inputSize.second
|
|
||||||
);
|
|
||||||
CUDA_CHECK(cudaGetLastError());
|
|
||||||
}
|
|
||||||
|
|
||||||
activation->activate(d_output);
|
|
||||||
|
|
||||||
return d_output;
|
|
||||||
}
|
|
||||||
@@ -23,7 +23,12 @@ void CUDA::print(const CUDANet::Tensor &input) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void CUDA::zero(CUDANet::Tensor &input) {
|
void CUDA::zero(CUDANet::Tensor &input) {
|
||||||
CUDA_CHECK(cudaMemset(input.data<float>(), 0, sizeof(float) * input.numel()));
|
fill(input, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
void CUDA::fill(CUDANet::Tensor &input, int value) {
|
||||||
|
CUDA_CHECK(cudaMemset(input.data<float>(), value, sizeof(float) * input.numel()));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void CUDA::copy_to_device(CUDANet::Tensor &tensor, void *data, size_t size) {
|
void CUDA::copy_to_device(CUDANet::Tensor &tensor, void *data, size_t size) {
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ AvgPool2d::~AvgPool2d() {}
|
|||||||
|
|
||||||
CUDANet::Tensor& AvgPool2d::forward(CUDANet::Tensor& input) {
|
CUDANet::Tensor& AvgPool2d::forward(CUDANet::Tensor& input) {
|
||||||
output.zero();
|
output.zero();
|
||||||
backend->avgPool2d(
|
backend->avg_pool2d(
|
||||||
input,
|
input,
|
||||||
output,
|
output,
|
||||||
in_shape,
|
in_shape,
|
||||||
|
|||||||
@@ -9,125 +9,95 @@
|
|||||||
using namespace CUDANet::Layers;
|
using namespace CUDANet::Layers;
|
||||||
|
|
||||||
BatchNorm2d::BatchNorm2d(
|
BatchNorm2d::BatchNorm2d(
|
||||||
shape2d inputSize,
|
CUDANet::Shape input_shape,
|
||||||
int inputChannels,
|
float eps,
|
||||||
float epsilon,
|
CUDANet::Backend *backend
|
||||||
ActivationType activationType
|
|
||||||
)
|
)
|
||||||
: inputSize(inputSize), inputChannels(inputChannels), epsilon(epsilon) {
|
: in_shape(input_shape), backend(backend) {
|
||||||
activation = new Activation(
|
|
||||||
activationType, inputSize.first * inputSize.second * inputChannels
|
if (in_shape.size() != 3) {
|
||||||
|
throw InvalidShapeException("input", 3, in_shape.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
epsilon = CUDANet::Tensor({1}, CUDANet::DType::FLOAT32, backend);
|
||||||
|
epsilon.set_data<float>(&eps);
|
||||||
|
|
||||||
|
running_mean = CUDANet::Tensor({in_shape[2]}, CUDANet::DType::FLOAT32, backend);
|
||||||
|
running_mean.zero();
|
||||||
|
|
||||||
|
running_var = CUDANet::Tensor({in_shape[2]}, CUDANet::DType::FLOAT32, backend);
|
||||||
|
running_var.fill(1);
|
||||||
|
|
||||||
|
weights = CUDANet::Tensor({in_shape[2]}, CUDANet::DType::FLOAT32, backend);
|
||||||
|
weights.fill(1);
|
||||||
|
|
||||||
|
biases = CUDANet::Tensor({in_shape[2]}, CUDANet::DType::FLOAT32, backend);
|
||||||
|
biases.zero();
|
||||||
|
|
||||||
|
output = CUDANet::Tensor(in_shape, CUDANet::DType::FLOAT32, backend);
|
||||||
|
}
|
||||||
|
|
||||||
|
BatchNorm2d::~BatchNorm2d() {}
|
||||||
|
|
||||||
|
CUDANet::Tensor& BatchNorm2d::forward(CUDANet::Tensor& input) {
|
||||||
|
output.zero();
|
||||||
|
backend->batch_norm(
|
||||||
|
input,
|
||||||
|
output,
|
||||||
|
in_shape,
|
||||||
|
weights,
|
||||||
|
biases,
|
||||||
|
running_mean,
|
||||||
|
running_var,
|
||||||
|
epsilon
|
||||||
);
|
);
|
||||||
|
return output;
|
||||||
weights.resize(inputChannels);
|
|
||||||
biases.resize(inputChannels);
|
|
||||||
|
|
||||||
running_mean.resize(inputChannels);
|
|
||||||
running_var.resize(inputChannels);
|
|
||||||
|
|
||||||
initializeWeights();
|
|
||||||
initializeBiases();
|
|
||||||
initializeRunningMean();
|
|
||||||
initializeRunningVar();
|
|
||||||
|
|
||||||
#ifdef USE_CUDA
|
|
||||||
initCUDA();
|
|
||||||
toCuda();
|
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
BatchNorm2d::~BatchNorm2d() {
|
CUDANet::Shape BatchNorm2d::input_shape() {
|
||||||
#ifdef USE_CUDA
|
return in_shape;
|
||||||
delCUDA();
|
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void BatchNorm2d::initializeWeights() {
|
CUDANet::Shape BatchNorm2d::output_shape() {
|
||||||
std::fill(weights.begin(), weights.end(), 1.0f);
|
return in_shape;
|
||||||
}
|
}
|
||||||
|
|
||||||
void BatchNorm2d::initializeBiases() {
|
size_t BatchNorm2d::input_size() {
|
||||||
std::fill(biases.begin(), biases.end(), 0.0f);
|
return sizeof(float) * in_shape[0] * in_shape[1] * in_shape[2];
|
||||||
}
|
}
|
||||||
|
|
||||||
void BatchNorm2d::initializeRunningMean() {
|
size_t BatchNorm2d::output_size() {
|
||||||
std::fill(running_mean.begin(), running_mean.end(), 0.0f);
|
return sizeof(float) * in_shape[0] * in_shape[1] * in_shape[2];
|
||||||
}
|
}
|
||||||
|
|
||||||
void BatchNorm2d::initializeRunningVar() {
|
void BatchNorm2d::set_weights(void* input) {
|
||||||
std::fill(running_var.begin(), running_var.end(), 1.0f);
|
weights.set_data<float>(static_cast<float*>(input));
|
||||||
}
|
}
|
||||||
|
|
||||||
void BatchNorm2d::setWeights(const float* weights_input) {
|
CUDANet::Tensor& BatchNorm2d::get_weights() {
|
||||||
std::copy(weights_input, weights_input + weights.size(), weights.begin());
|
|
||||||
#ifdef USE_CUDA
|
|
||||||
toCuda();
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<float> BatchNorm2d::getWeights() {
|
|
||||||
return weights;
|
return weights;
|
||||||
}
|
}
|
||||||
|
|
||||||
void BatchNorm2d::setBiases(const float* biases_input) {
|
void BatchNorm2d::set_biases(void* input) {
|
||||||
std::copy(biases_input, biases_input + biases.size(), biases.begin());
|
biases.set_data<float>(static_cast<float*>(input));
|
||||||
#ifdef USE_CUDA
|
|
||||||
toCuda();
|
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<float> BatchNorm2d::getBiases() {
|
CUDANet::Tensor& BatchNorm2d::get_biases() {
|
||||||
return biases;
|
return biases;
|
||||||
}
|
}
|
||||||
|
|
||||||
void BatchNorm2d::setRunningMean(const float* running_mean_input) {
|
void BatchNorm2d::set_running_mean(void* input) {
|
||||||
std::copy(
|
running_mean.set_data<float>(static_cast<float*>(input));
|
||||||
running_mean_input, running_mean_input + inputChannels,
|
|
||||||
running_mean.begin()
|
|
||||||
);
|
|
||||||
#ifdef USE_CUDA
|
|
||||||
toCuda();
|
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<float> BatchNorm2d::getRunningMean() {
|
CUDANet::Tensor& BatchNorm2d::get_running_mean() {
|
||||||
return running_mean;
|
return running_mean;
|
||||||
}
|
}
|
||||||
|
|
||||||
void BatchNorm2d::setRunningVar(const float* running_var_input) {
|
void BatchNorm2d::set_running_var(void* input) {
|
||||||
std::copy(
|
running_var.set_data<float>(static_cast<float*>(input));
|
||||||
running_var_input, running_var_input + inputChannels,
|
|
||||||
running_var.begin()
|
|
||||||
);
|
|
||||||
#ifdef USE_CUDA
|
|
||||||
toCuda();
|
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<float> BatchNorm2d::getRunningVar() {
|
CUDANet::Tensor& BatchNorm2d::get_running_var() {
|
||||||
return running_var;
|
return running_var;
|
||||||
}
|
}
|
||||||
|
|
||||||
int BatchNorm2d::getInputSize() {
|
|
||||||
return inputSize.first * inputSize.second * inputChannels;
|
|
||||||
}
|
|
||||||
|
|
||||||
int BatchNorm2d::getOutputSize() {
|
|
||||||
return inputSize.first * inputSize.second * inputChannels;
|
|
||||||
}
|
|
||||||
|
|
||||||
shape2d BatchNorm2d::getOutputDims() {
|
|
||||||
return inputSize;
|
|
||||||
}
|
|
||||||
|
|
||||||
float* BatchNorm2d::forwardCPU(const float* input) {
|
|
||||||
throw std::logic_error("Not implemented");
|
|
||||||
}
|
|
||||||
|
|
||||||
float* BatchNorm2d::forward(const float* input) {
|
|
||||||
#ifdef USE_CUDA
|
|
||||||
return forwardCUDA(input);
|
|
||||||
#else
|
|
||||||
return forwardCPU(input);
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
@@ -47,7 +47,7 @@ Conv2d::Conv2d(
|
|||||||
};
|
};
|
||||||
|
|
||||||
output = CUDANet::Tensor(
|
output = CUDANet::Tensor(
|
||||||
Shape{out_shape[0] * out_shape[1] * out_shape[3]},
|
Shape{out_shape[0], out_shape[1], out_shape[3]},
|
||||||
CUDANet::DType::FLOAT32, backend
|
CUDANet::DType::FLOAT32, backend
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ MaxPool2d::~MaxPool2d() {}
|
|||||||
|
|
||||||
CUDANet::Tensor& MaxPool2d::forward(CUDANet::Tensor& input) {
|
CUDANet::Tensor& MaxPool2d::forward(CUDANet::Tensor& input) {
|
||||||
output.zero();
|
output.zero();
|
||||||
backend->maxPool2d(
|
backend->max_pool2d(
|
||||||
input, output, in_shape, pool_shape, stride_shape, padding_shape,
|
input, output, in_shape, pool_shape, stride_shape, padding_shape,
|
||||||
out_shape
|
out_shape
|
||||||
);
|
);
|
||||||
|
|||||||
Reference in New Issue
Block a user