MiniDNN
Callback.h
1
#ifndef CALLBACK_H_
2
#define CALLBACK_H_
3
4
#include <Eigen/Core>
5
#include "Config.h"
6
7
namespace
MiniDNN
{
8
9
10
class
Network;
11
15
28
class
Callback
29
{
30
protected
:
31
typedef
Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic> Matrix;
32
typedef
Eigen::RowVectorXi IntegerVector;
33
34
public
:
35
// Public members that will be set by the network during the training process
36
int
m_nbatch;
// Number of total batches
37
int
m_batch_id;
// The index for the current mini-batch (0, 1, ..., m_nbatch-1)
38
int
m_nepoch;
// Total number of epochs (one run on the whole data set) in the training process
39
int
m_epoch_id;
// The index for the current epoch (0, 1, ..., m_nepoch-1)
40
41
Callback
() :
42
m_nbatch(0), m_batch_id(0), m_nepoch(0), m_epoch_id(0)
43
{}
44
45
virtual
~
Callback
() {}
46
47
// Before training a mini-batch
48
virtual
void
pre_training_batch(
const
Network
* net,
const
Matrix& x,
const
Matrix& y) {}
49
virtual
void
pre_training_batch(
const
Network
* net,
const
Matrix& x,
const
IntegerVector& y) {}
50
51
// After a mini-batch is trained
52
virtual
void
post_training_batch(
const
Network
* net,
const
Matrix& x,
const
Matrix& y) {}
53
virtual
void
post_training_batch(
const
Network
* net,
const
Matrix& x,
const
IntegerVector& y) {}
54
};
55
56
57
}
// namespace MiniDNN
58
59
60
#endif
/* CALLBACK_H_ */
MiniDNN::Callback
Definition:
Callback.h:28
MiniDNN::Network
Definition:
Network.h:28
MiniDNN
Definition:
Callback.h:7
Callback.h
Generated by
1.8.13