Migrate Activation layer to Tensor

Commit 25670f90c4, parent d231e515b1
Date: 2025-11-17 22:51:54 +01:00
2 changed files with 7 additions and 7 deletions

File 1: Activation class header

@@ -30,7 +30,7 @@ class Activation {
      * @param activation Type of activation
      * @param length Length of the input
      */
-    Activation(ActivationType activation, const int length);
+    Activation(CUDANet::Backend::IBackend* backend, ActivationType activation, const int length);
     /**
      * @brief Destroy the Activation object
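For call sites, the backend is now injected through the constructor instead of the layer building its Tensors on a null backend. A minimal sketch of the new call shape, assuming some concrete IBackend implementation; the name CudaBackend and the include path are placeholders, not part of this commit:

    #include "CUDANet/layers/activation.hpp"  // hypothetical include path

    // Hypothetical concrete backend; any CUDANet::Backend::IBackend
    // implementation would work here.
    CUDANet::Backend::IBackend* backend = new CudaBackend();

    // The layer's scratch Tensors (softmax_sum, tensor_max) are now
    // allocated on this backend rather than on nullptr.
    CUDANet::Layers::Activation softmax(
        backend, CUDANet::Layers::ActivationType::SOFTMAX, 1000);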

File 2: Activation implementation

@@ -6,13 +6,13 @@
 using namespace CUDANet::Layers;
 
-Activation::Activation(ActivationType activation, const int length)
-    : activationType(activation), length(length) {
+Activation::Activation(CUDANet::Backend::IBackend* backend, ActivationType activation, const int length)
+    : backend(backend), activationType(activation), length(length) {
     if (activationType == SOFTMAX) {
-        softmax_sum = CUDANet::Backend::Tensor({static_cast<size_t>(length)}, CUDANet::Backend::DType::FLOAT32, nullptr);
-        tensor_max = CUDANet::Backend::Tensor({static_cast<size_t>(length)}, CUDANet::Backend::DType::FLOAT32, nullptr);
+        softmax_sum = CUDANet::Backend::Tensor({static_cast<size_t>(length)}, CUDANet::Backend::DType::FLOAT32, backend);
+        tensor_max = CUDANet::Backend::Tensor({static_cast<size_t>(length)}, CUDANet::Backend::DType::FLOAT32, backend);
     }
 }
@@ -23,10 +23,10 @@ void Activation::activate(CUDANet::Backend::Tensor input) {
             backend->sigmoid(input);
             break;
         case ActivationType::RELU:
-            /* code */
+            backend->relu(input);
             break;
         case ActivationType::SOFTMAX:
-            /* code */
+            backend->softmax(input, tensor_max, softmax_sum);
             break;
         default:
             break;
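The two scratch Tensors pre-allocated in the constructor are what a numerically stable softmax needs: the input's maximum and the sum of shifted exponentials. A CPU reference sketch of what backend->softmax(input, tensor_max, softmax_sum) plausibly computes; the real backend presumably does this on-device, and only the names and argument order here come from the diff:

    #include <algorithm>
    #include <cmath>
    #include <cstddef>

    // Reference semantics for softmax over a flat float buffer of `length`
    // elements; max_buf and sum_buf stand in for tensor_max / softmax_sum.
    void softmax_reference(float* data, std::size_t length,
                           float* max_buf, float* sum_buf) {
        // 1. Find the max for numerical stability (keeps exp() from overflowing).
        float m = *std::max_element(data, data + length);
        max_buf[0] = m;

        // 2. Exponentiate the shifted values and accumulate the normalizer.
        float sum = 0.0f;
        for (std::size_t i = 0; i < length; ++i) {
            data[i] = std::exp(data[i] - m);
            sum += data[i];
        }
        sum_buf[0] = sum;

        // 3. Normalize so the outputs sum to 1.
        for (std::size_t i = 0; i < length; ++i) {
            data[i] /= sum;
        }
    }

Threading the IBackend through the constructor also means the same layer code dispatches to whichever device the injected backend targets, which appears to be the point of this migration.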