mirror of https://github.com/lordmathis/CUDANet.git
Migrate Activation layer to Tensor
@@ -30,7 +30,7 @@ class Activation {
      * @param activation Type of activation
      * @param length Length of the input
      */
-    Activation(ActivationType activation, const int length);
+    Activation(CUDANet::Backend::IBackend* backend, ActivationType activation, const int length);

    /**
     * @brief Destroy the Activation object
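The header hunk above adds a backend parameter to the constructor. A minimal usage sketch follows, assuming only what the diff shows; the include path is hypothetical and no concrete IBackend implementation appears in this commit:

    // Sketch only: the include path below is an assumption, and the
    // backend is taken as a parameter because this commit does not
    // show how a concrete IBackend is created.
    #include "CUDANet/layers/activation.hpp"  // hypothetical path

    void buildSoftmaxLayer(CUDANet::Backend::IBackend* backend) {
        // The layer now receives the backend up front, so its scratch
        // tensors (tensor_max, softmax_sum) are allocated through it.
        CUDANet::Layers::Activation softmax(
            backend, CUDANet::Layers::ActivationType::SOFTMAX, /*length=*/128
        );
    }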
@@ -6,13 +6,13 @@

 using namespace CUDANet::Layers;

-Activation::Activation(ActivationType activation, const int length)
-    : activationType(activation), length(length) {
+Activation::Activation(CUDANet::Backend::IBackend* backend, ActivationType activation, const int length)
+    : backend(backend), activationType(activation), length(length) {


     if (activationType == SOFTMAX) {
-        softmax_sum = CUDANet::Backend::Tensor({static_cast<size_t>(length)}, CUDANet::Backend::DType::FLOAT32, nullptr);
-        tensor_max = CUDANet::Backend::Tensor({static_cast<size_t>(length)}, CUDANet::Backend::DType::FLOAT32, nullptr);
+        softmax_sum = CUDANet::Backend::Tensor({static_cast<size_t>(length)}, CUDANet::Backend::DType::FLOAT32, backend);
+        tensor_max = CUDANet::Backend::Tensor({static_cast<size_t>(length)}, CUDANet::Backend::DType::FLOAT32, backend);
     }
 }

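The scratch tensors are now allocated against the layer's backend instead of nullptr. Inferring from the call sites alone (the Tensor class itself is not part of this diff), its constructor presumably looks something like the following sketch; every detail beyond the three-argument shape/dtype/backend signature is an assumption:

    #include <cstddef>
    #include <vector>

    namespace CUDANet::Backend {

    enum class DType { FLOAT32 };

    class IBackend;  // performs the actual device allocation

    // Assumed shape of the Tensor constructor used above: a shape list,
    // an element dtype, and the backend that owns the allocation.
    // The real class is not shown in this commit.
    class Tensor {
     public:
        Tensor() = default;
        Tensor(std::vector<std::size_t> shape, DType dtype, IBackend* backend)
            : shape_(std::move(shape)), dtype_(dtype), backend_(backend) {}

     private:
        std::vector<std::size_t> shape_;
        DType dtype_{DType::FLOAT32};
        IBackend* backend_{nullptr};
    };

    }  // namespace CUDANet::Backend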
@@ -23,10 +23,10 @@ void Activation::activate(CUDANet::Backend::Tensor input) {
             backend->sigmoid(input);
             break;
         case ActivationType::RELU:
-            /* code */
+            backend->relu(input);
             break;
         case ActivationType::SOFTMAX:
-            /* code */
+            backend->softmax(input, tensor_max, softmax_sum);
             break;
         default:
             break;
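With this hunk, activate() dispatches every case through the backend, replacing the remaining /* code */ stubs. The call sites imply roughly the following subset of the IBackend interface; it is reconstructed from this hunk, not copied from the real header, and parameter names and pass-by-reference are assumptions:

    namespace CUDANet::Backend {

    class Tensor;

    // Interface subset implied by the activate() call sites above.
    class IBackend {
    public:
        virtual ~IBackend() = default;
        virtual void sigmoid(Tensor& input) = 0;
        virtual void relu(Tensor& input) = 0;
        // Softmax needs scratch tensors for the running max and the
        // exp-sum, which the layer pre-allocates in its constructor.
        virtual void softmax(Tensor& input, Tensor& max, Tensor& sum) = 0;
    };

    }  // namespace CUDANet::Backend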