Skip to content

Commit

Permalink
Use command line options to set up & run the NN.
Browse files Browse the repository at this point in the history
Configurations can be passed using command line options.
--help lists all the acceptable options and their possible values.
Added verification of the passed arguments.
  • Loading branch information
hiraditya committed Oct 13, 2013
1 parent 10f8bad commit d9f9600
Show file tree
Hide file tree
Showing 9 changed files with 546 additions and 299 deletions.
22 changes: 7 additions & 15 deletions AI/ANN/Activation.h
Original file line number Diff line number Diff line change
@@ -1,33 +1,25 @@
#ifndef ANN_ACTIVATION_FUNCTION_H
#define ANN_ACTIVATION_FUNCTION_H
#include<Debug.h>
#include<cmath>
#include <Debug.h>
#include <cmath>

namespace ANN {
// Activation functions designed in the form of CRTP.
template<typename T>
template<typename WeightType>
class Activation {
public:
// @TODO: Try this: T operator()(T t) {
// FAQ: Even if I do not write this function as
// template of WeightType and put T instead of WeightType, it works.
// I don't quite understand why?
//
template<typename WeightType>
WeightType Act(WeightType t) const {
const T& ref = static_cast<const T&>(*this);
return ref.Act(t);
}
virtual WeightType Act(WeightType t) const = 0;
// Derivative of linear activation function w.r.t. weight.
template<typename WeightType>
WeightType Deriv(WeightType t) const {
const T& ref = static_cast<const T&>(*this);
return ref.Deriv(t);
}
virtual WeightType Deriv(WeightType t) const = 0;
};

template<typename WeightType>
class LinearAct : public Activation<LinearAct<WeightType> > {
class LinearAct : public Activation<WeightType> {
public:
WeightType Act(WeightType w) const {
return w;
Expand All @@ -39,7 +31,7 @@ namespace ANN {
};

template<typename WeightType>
class SigmoidAct : public Activation<SigmoidAct<WeightType> > {
class SigmoidAct : public Activation<WeightType> {
public:
WeightType Act(WeightType w) const {
DEBUG2(dbgs() << "\nSigmoid function Input: " << w
Expand Down
6 changes: 6 additions & 0 deletions AI/ANN/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,15 @@ cmake_minimum_required(VERSION 2.8)
set(NeuralNetwork
Activation.h
TrainingAlgorithms.h
CostFunction.h
NetworkConfiguration.h
NeuralNetwork.h
NeuralNetwork.cpp
)

set(Main
Main.cpp
)

add_library(NeuralNetwork ${NeuralNetwork})
add_executable(Main ${Main})
43 changes: 43 additions & 0 deletions AI/ANN/CostFunction.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#ifndef ANN_COST_FUNCTION_H
#define ANN_COST_FUNCTION_H

namespace ANN {
  /// Abstract binary evaluator: combines two values of type T.
  /// Used as the "cost function" that defines the boolean operation the
  /// network is trained to learn (see NetworkConfiguration).
  template<typename T>
  struct Evaluate {
    typedef T init_type;
    /// Identity element of the operation (e.g. true for AND, false for OR).
    const init_type init_value;
    Evaluate(init_type i)
      : init_value(i)
    { }
    /// BUGFIX: concrete evaluators are stored and destroyed through
    /// Evaluate<T>* pointers; without a virtual destructor that delete
    /// is undefined behavior.
    virtual ~Evaluate() { }
    /// Apply the binary operation to t1 and t2.
    virtual T operator()(T t1, T t2) const = 0;
  };

  /// Logical AND; identity element is true.
  struct BoolAnd : public Evaluate<bool> {
    BoolAnd()
      : Evaluate(true)
    { }
    bool operator()(bool i1, bool i2) const {
      return i1 && i2;
    }
  };

  /// Logical OR; identity element is false.
  struct BoolOr : public Evaluate<bool> {
    BoolOr()
      : Evaluate(false)
    { }
    bool operator()(bool i1, bool i2) const {
      return i1 || i2;
    }
  };

  /// Logical XOR; identity element is false.
  struct BoolXor : public Evaluate<bool> {
    BoolXor()
      : Evaluate(false)
    { }
    bool operator()(bool i1, bool i2) const {
      return i1 ^ i2;
    }
  };
} // namespace ANN

#endif // ANN_COST_FUNCTION_H
89 changes: 89 additions & 0 deletions AI/ANN/Main.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
#include "NetworkConfiguration.h"
#include "Debug.h"
#include <getopt.h>
#include <map>
#include <string>

/// Print the accepted command line options and their legal values to stderr.
void PrintCmdline() {
  // BUGFIX: a separator was missing between "cf" and "--network", so the
  // summary line printed as "...cost-function cf--network net...".
  std::cerr << "\n\noptions: --activation Act --converge-method conv --cost-function cf"
            << " --network net --training-algo ta";
  std::cerr << "\n\n--activation [LinearAct|SigmoidAct]"
            << "\n--converge-method [GradientDescent|SimpleDelta]"
            << "\n--cost-function [BoolAnd|BoolOr|BoolXor]"
            << "\n--help"
            << "\n--network [SLFFN|TLFFN]"
            << "\n--training-algo [BackProp|FeedForward]"
            << "\n";
  std::cerr << "\nUsing GNU opt-parser, "
            << "either space or equals(=) works with long options.\n";
}

/// Dump each parsed option to stderr as "name:value", one per line.
/// BUGFIX: the previous version captured an unused stream reference
/// (`auto& os = dbgs();`) in the lambda while actually writing to
/// std::cerr; the dead variable is removed and the for_each+lambda is
/// replaced by a range-for.
void PrintOptMap(OPT::OptmapType m) {
  for (const OPT::OptmapType::value_type& kv : m)
    std::cerr << "\n" << kv.first << ":" << kv.second;
}

// Parse command line options, validate them, then set up and run the
// neural network until VerifyTraining reports success.
// Returns 0 on success/--help, -1 on unknown or invalid options.
int main(int argc, char* argv[]) {
using namespace OPT;
int opt= 0;
// Long-option table for getopt_long_only. The name strings come from the
// global std::strings in namespace OPT, which outlive this table, so the
// c_str() pointers stay valid. The all-zero entry terminates the table.
static struct option lopt[] = {
{activation.c_str(), required_argument, 0, 'a' },
{converge_method.c_str(), required_argument, 0, 'c' },
{cost_function.c_str(), required_argument, 0, 'C' },
{help.c_str(), no_argument, 0, 'h' },
{network.c_str(), required_argument, 0, 'n' },
{training_algo.c_str(), required_argument, 0, 't' },
{ 0, 0, 0, 0 }
};
// Collected option-name -> argument-value pairs, filled by the loop below.
OptmapType optmap;
int lidx =0;
// Empty short-option string: only long options (with - or --) are accepted.
// getopt_long_only returns the 'val' field ('a', 'c', ...) of the matched
// entry, or -1 when the arguments are exhausted; optarg points at the
// option's argument.
while ((opt = getopt_long_only(argc, argv,"",
lopt, &lidx )) != -1) {
switch (opt) {
case 'a':
DEBUG2(dbgs() << "\nActivation Function:" << optarg);
optmap[activation] = optarg;
break;
case 'c':
DEBUG2(dbgs() << "\nConvergence Method:" << optarg);
optmap[converge_method] = optarg;
break;
case 'C':
DEBUG2(dbgs() << "\nCost Function:" << optarg);
optmap[cost_function] = optarg;
break;
case 'h':
// --help: print usage and exit successfully.
PrintCmdline();
return 0;
case 'n':
DEBUG2(dbgs() << "\nNetwork:" << optarg);
optmap[network] = optarg;
break;
case 't':
DEBUG2(dbgs() << "\nTraining Algorithm:" << optarg);
optmap[training_algo] = optarg;
break;
default:
// getopt returned '?' (or anything unhandled): bad option.
std::cerr << "\nUnknown arguments. exiting...";
PrintCmdline();
return -1;
}
}

PrintOptMap(optmap);

// Validate the collected options, then train until verification passes.
ANN::NetworkConfiguration nc;
if(nc.ValidateOptmap(optmap)) {
nc.setup(optmap);
while (!nc.VerifyTraining())
nc.run();
} else {
DEBUG0(dbgs() << "\nInvalid/Insufficient options. exiting...");
PrintCmdline();
return -1;
}
return 0;
}
164 changes: 164 additions & 0 deletions AI/ANN/NetworkConfiguration.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
#ifndef NETWORKCONFIGURATION_H
#define NETWORKCONFIGURATION_H

#include "TrainingAlgorithms.h"

#include <map>
#include <string>

namespace OPT {
  /// Map from long-option name to its parsed argument value.
  typedef std::map<std::string, std::string> OptmapType;
  // Long-option names; also used as the keys of OptmapType.
  // BUGFIX: these were non-const namespace-scope definitions in a header,
  // which violates the one-definition rule as soon as the header is
  // included in more than one translation unit. 'const' gives them
  // internal linkage, so each TU gets its own private copy.
  const std::string activation = "activation";
  const std::string converge_method = "converge-method";
  const std::string cost_function = "cost-function";
  const std::string help = "help";
  const std::string network = "network";
  const std::string training_algo = "training-algo";
} // namespace OPT

namespace ANN {
class NetworkConfiguration {
typedef ValidateOutput<std::vector<bool>, bool> ValidatorType;
typedef std::map<std::string, Activation<NeuronWeightType>* >
ActivationFunctionsType;
typedef std::map<std::string, ConvergenceMethod*> ConvergenceMethodsType;
typedef std::map<std::string, Evaluate<bool>* > EvaluatorsType;
Evaluate<bool>* CostFunction;

bool validated;
Trainer* T;
ValidatorType* Validator;
float alpha;
const int ip_size;
unsigned times_trained;
ActivationFunctionsType ActivationFunctions;
//std::map<std::string, TrainingAlgorithm> TrainingAlgorithms;
ConvergenceMethodsType ConvergenceMethods;
EvaluatorsType Evaluators;
NeuralNetwork NN;
public:
typedef std::map<std::string, std::string> OptMapType;
NetworkConfiguration()
: validated(false), T(nullptr),
Validator(nullptr), alpha(0.01),
ip_size(3), times_trained(0) {
/// @todo Rather than generating by default,
/// make the construction on-demand.
ActivationFunctions["LinearAct"] = new LinearAct<NeuronWeightType>;
ActivationFunctions["SigmoidAct"] = new SigmoidAct<NeuronWeightType>;
ConvergenceMethods["SimpleDelta"] = new SimpleDelta;
ConvergenceMethods["GradientDescent"] = new GradientDescent;
Evaluators["BoolAnd"] = new BoolAnd;
Evaluators["BoolOr"] = new BoolOr;
Evaluators["BoolXor"] = new BoolXor;
}
~NetworkConfiguration() {
if (T)
delete T;
if (Validator)
delete Validator;
std::for_each(ActivationFunctions.begin(), ActivationFunctions.end(),
[](ActivationFunctionsType::value_type v){
delete v.second;
});
std::for_each(ConvergenceMethods.begin(), ConvergenceMethods.end(),
[](ConvergenceMethodsType::value_type v){
delete v.second;
});
}

virtual void setup(OptMapType& optmap) {
// Create a single layer feed forward neural network.
NN = CreateTLFFN(ip_size, ActivationFunctions[optmap[OPT::activation]]);
NN.PrintNNDigraph(*NN.GetRoot(), std::cout);
// Choose the training algorithm.
ConvergenceMethod* CM = ConvergenceMethods[optmap[OPT::converge_method]];
assert(CM);
T = new Trainer(NN, CM, alpha);
// Validation of the output.
Validator = new ValidatorType(NN);
CostFunction = Evaluators[optmap[OPT::cost_function]];
assert(CostFunction);
}

virtual void run() {
using namespace utilities;
T->SetAlpha(alpha);
DEBUG0(dbgs() << "\nTraining with alpha:" << alpha);
for (unsigned i = 0; i < 10;) {
std::vector<bool> RS = GetRandomizedSet(BooleanSampleSpace, ip_size-1);
std::vector<float> RSF = BoolsToFloats(RS);
// The last input is the bias.
RSF.insert(RSF.begin(), -1);
DEBUG0(dbgs() << "\nSample Inputs:"; PrintElements(dbgs(), RSF));
//NN.PrintNNDigraph(*NN.GetRoot(), std::cout);
auto op = NN.GetOutput(RSF);
auto bool_op = FloatToBool(op);
auto desired_op = Validator->GetDesiredOutput(CostFunction, RS);
// Is the output same as desired output?
if (!Validator->Validate(CostFunction, RS, bool_op)) {
DEBUG0(dbgs() << "\nLearning (" << op << ", "
<< bool_op << ", "
<< desired_op << ")");
//NN.PrintNNDigraph(*NN.GetRoot(), std::cout);
// No => Train
T->TrainNetworkBackProp(RSF, desired_op);
++times_trained;
i = 0;
//NN.PrintNNDigraph(*NN.GetRoot(), std::cout);
} else {
++i; // Increment trained counter.
DEBUG0(dbgs() << "\tTrained (" << op << ", " << bool_op << ")");
}
}
}

virtual bool VerifyTraining() {
using namespace utilities;
bool trained = true;
DEBUG0(dbgs() << "\nPrinting after training");
for (unsigned i = 0; i < 20; ++i) {
std::vector<bool> RS = GetRandomizedSet(BooleanSampleSpace, ip_size-1);
std::vector<float> RSF = BoolsToFloats(RS);
// The last input is the bias.
RSF.insert(RSF.begin(), -1);
auto op = NN.GetOutput(RSF);
DEBUG0(dbgs() << "\nSample Inputs:"; PrintElements(dbgs(), RSF));
if (Validator->Validate(CostFunction, RS, FloatToBool(op)))
DEBUG0(dbgs() << "\tTrained (" << op << ", " << FloatToBool(op) << ")");
else {
// double the training rate.
alpha = alpha < 0.4 ? 2*alpha : alpha;
trained = false;
DEBUG0(dbgs() << "\tUnTrained: " << op);
break;
}
}
DEBUG0(dbgs() << "\nTrained for " << times_trained << " cycles.");
return trained;
}

bool ValidateOptmap(OptMapType& optmap) {
using namespace OPT;
if (optmap[activation] != "LinearAct" ||
optmap[activation] != "SigmoidAct")
return false;
if (optmap[converge_method] != "GradientDescent" ||
optmap[converge_method] != "SimpleDelta")
return false;
if (optmap[cost_function] != "BoolAnd" ||
optmap[cost_function] != "BoolOr" ||
optmap[cost_function] != "BoolXor")
return false;
if (optmap[network] != "SLFFN" ||
optmap[network] != "TLFFN")
return false;
if (optmap[training_algo] != "BackProp" ||
optmap[training_algo] != "FeedForward")
return false;
validated = true;
return true;
}
};
}
#endif // NETWORKCONFIGURATION_H
Loading

0 comments on commit d9f9600

Please sign in to comment.