mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-12-22 22:34:22 +00:00
Migrate model class to Tensor
This commit is contained in:
55
include/model.hpp
Normal file
55
include/model.hpp
Normal file
@@ -0,0 +1,55 @@
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "layer.hpp"
|
||||
#include "module.hpp"
|
||||
|
||||
namespace CUDANet {
|
||||
|
||||
enum TensorType {
|
||||
WEIGHT,
|
||||
BIAS,
|
||||
RUNNING_MEAN,
|
||||
RUNNING_VAR
|
||||
};
|
||||
|
||||
struct TensorInfo {
|
||||
std::string name;
|
||||
TensorType type;
|
||||
int size;
|
||||
int offset;
|
||||
};
|
||||
|
||||
class Model {
|
||||
public:
|
||||
Model(const CUDANet::Shape input_shape, const CUDANet::Shape output_shape);
|
||||
~Model();
|
||||
|
||||
virtual CUDANet::Tensor& predict(CUDANet::Tensor& input);
|
||||
|
||||
CUDANet::Layer* get_layer(const std::string& name);
|
||||
|
||||
void register_layer(const std::string& name, Layer* layer);
|
||||
|
||||
void register_module(Module& module);
|
||||
|
||||
void load_weights(const std::string& path);
|
||||
|
||||
bool validate();
|
||||
|
||||
void print_summary();
|
||||
|
||||
protected:
|
||||
CUDANet::Shape in_shape;
|
||||
CUDANet::Shape out_shape;
|
||||
|
||||
CUDANet::Tensor output;
|
||||
|
||||
std::vector<std::pair<std::string, Layer*>> layers;
|
||||
std::unordered_map<std::string, Layer*> layer_map;
|
||||
};
|
||||
|
||||
} // namespace CUDANet
|
||||
Reference in New Issue
Block a user