MiniDNN
Output.h
1 #ifndef OUTPUT_H_
2 #define OUTPUT_H_
3 
4 #include <Eigen/Core>
5 #include <stdexcept>
6 #include "Config.h"
7 
8 namespace MiniDNN {
9 
10 
14 
22 class Output
23 {
24 protected:
25  typedef Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic> Matrix;
26  typedef Eigen::Matrix<Scalar, Eigen::Dynamic, 1> Vector;
27  typedef Eigen::RowVectorXi IntegerVector;
28 
29 public:
30  virtual ~Output() {}
31 
32  // Check the format of target data, e.g. in classification problems the
33  // target data should be binary (either 0 or 1)
34  virtual void check_target_data(const Matrix& target) {}
35 
36  // Another type of target data where each element is a class label
37  // This version may not be sensible for regression tasks, so by default
38  // we raise an exception
39  virtual void check_target_data(const IntegerVector& target)
40  {
41  throw std::invalid_argument("[class Output]: This output type cannot take class labels as target data");
42  }
43 
44  // A combination of the forward stage and the back-propagation stage for the output layer
45  // The computed derivative of the input should be stored in this layer, and can be retrieved by
46  // the backprop_data() function
47  virtual void evaluate(const Matrix& prev_layer_data, const Matrix& target) = 0;
48 
49  // Another type of target data where each element is a class label
50  // This version may not be sensible for regression tasks, so by default
51  // we raise an exception
52  virtual void evaluate(const Matrix& prev_layer_data, const IntegerVector& target)
53  {
54  throw std::invalid_argument("[class Output]: This output type cannot take class labels as target data");
55  }
56 
57  // The derivative of the input of this layer, which is also the derivative
58  // of the output of previous layer
59  virtual const Matrix& backprop_data() const = 0;
60 
61  // Return the loss function value after the evaluation
62  // This function can be assumed to be called after evaluate(), so that it can make use of the
63  // intermediate result to save some computation
64  virtual Scalar loss() const = 0;
65 };
66 
67 
68 } // namespace MiniDNN
69 
70 
71 #endif /* OUTPUT_H_ */