mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-12-24 23:34:24 +00:00
Implement InvalidShapeException
This commit is contained in:
@@ -18,35 +18,19 @@ AvgPool2d::AvgPool2d(
|
||||
padding_shape(padding_shape),
|
||||
backend(backend) {
|
||||
if (in_shape.size() != 3) {
|
||||
throw std::runtime_error(
|
||||
std::format(
|
||||
"Invalid input shape. Expected 3 dims, got {}", input_shape.size()
|
||||
)
|
||||
);
|
||||
throw InvalidShapeException("input", 3, in_shape.size());
|
||||
}
|
||||
|
||||
if (pool_shape.size() != 2) {
|
||||
throw std::runtime_error(
|
||||
std::format(
|
||||
"Invalid pool shape. Expected 2 dims, got {}", pool_shape.size()
|
||||
)
|
||||
);
|
||||
throw InvalidShapeException("pool", 2, pool_shape.size());
|
||||
}
|
||||
|
||||
if (stride_shape.size() != 2) {
|
||||
throw std::runtime_error(
|
||||
std::format(
|
||||
"Invalid stride shape. Expected 2 dims, got {}", stride_shape.size()
|
||||
)
|
||||
);
|
||||
throw InvalidShapeException("stride", 2, stride_shape.size());
|
||||
}
|
||||
|
||||
if (padding_shape.size() != 2) {
|
||||
throw std::runtime_error(
|
||||
std::format(
|
||||
"Invalid padding shape. Expected 2 dims, got {}", padding_shape.size()
|
||||
)
|
||||
);
|
||||
throw InvalidShapeException("padding", 2, padding_shape.size());
|
||||
}
|
||||
|
||||
out_shape = {
|
||||
|
||||
Reference in New Issue
Block a user