Add adaptive avg pooling

This commit is contained in:
2024-05-30 17:17:31 +02:00
parent 9faf20876a
commit 8168f02f58
5 changed files with 143 additions and 8 deletions

View File

@@ -126,6 +126,15 @@ def gen_avg_pool_non_square_padding_test_result():
print_cpp_vector(output)
def gen_adaptive_avg_pool_test_result():
input = _get_pool_input()
output = torch.nn.AdaptiveAvgPool2d((2, 2))(input)
output = torch.flatten(output)
print_cpp_vector(output)
if __name__ == "__main__":
print("Generating test results...")
@@ -152,3 +161,8 @@ if __name__ == "__main__":
gen_avg_pool_non_square_stride_test_result()
print("Avg pool non square padding test:")
gen_avg_pool_non_square_padding_test_result()
print("--------------")
print("Adaptive avg pool test:")
gen_adaptive_avg_pool_test_result()