MiniDNN
Callback.h
1 #ifndef CALLBACK_H_
2 #define CALLBACK_H_
3 
4 #include <Eigen/Core>
5 #include "Config.h"
6 
7 namespace MiniDNN {
8 
9 
10 class Network;
11 
15 
28 class Callback
29 {
30 protected:
31  typedef Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic> Matrix;
32  typedef Eigen::RowVectorXi IntegerVector;
33 
34 public:
35  // Public members that will be set by the network during the training process
36  int m_nbatch; // Number of total batches
37  int m_batch_id; // The index for the current mini-batch (0, 1, ..., m_nbatch-1)
38  int m_nepoch; // Total number of epochs (one run on the whole data set) in the training process
39  int m_epoch_id; // The index for the current epoch (0, 1, ..., m_nepoch-1)
40 
41  Callback() :
42  m_nbatch(0), m_batch_id(0), m_nepoch(0), m_epoch_id(0)
43  {}
44 
45  virtual ~Callback() {}
46 
47  // Before training a mini-batch
48  virtual void pre_training_batch(const Network* net, const Matrix& x, const Matrix& y) {}
49  virtual void pre_training_batch(const Network* net, const Matrix& x, const IntegerVector& y) {}
50 
51  // After a mini-batch is trained
52  virtual void post_training_batch(const Network* net, const Matrix& x, const Matrix& y) {}
53  virtual void post_training_batch(const Network* net, const Matrix& x, const IntegerVector& y) {}
54 };
55 
56 
57 } // namespace MiniDNN
58 
59 
60 #endif /* CALLBACK_H_ */