-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Use command line options to set up & run the NN.
Configurations can be passed using command line options. --help lists all acceptable commands and their possible values. Passed arguments are verified before use.
- Loading branch information
Showing
9 changed files
with
546 additions
and
299 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
#ifndef ANN_COST_FUNCTION_H | ||
#define ANN_COST_FUNCTION_H | ||
|
||
namespace ANN {
  /// Abstract binary evaluator used as a boolean cost/target function.
  /// init_value is the fold identity supplied by each concrete evaluator
  /// (true for AND, false for OR/XOR) when a set of inputs is reduced
  /// pairwise.
  template<typename T>
  struct Evaluate {
    typedef T init_type;
    const init_type init_value;
    Evaluate(init_type i)
      : init_value(i)
    { }
    // Instances are owned and deleted through Evaluate<bool>* (see the
    // Evaluators map in NetworkConfiguration), so the destructor must be
    // virtual; without it, delete-through-base is undefined behavior.
    virtual ~Evaluate() { }
    virtual T operator()(T t1, T t2) const = 0;
  };

  /// Logical AND of two inputs; fold identity is true.
  struct BoolAnd : public Evaluate<bool> {
    BoolAnd()
      : Evaluate(true)
    { }
    bool operator()(bool i1, bool i2) const override {
      return i1 && i2;
    }
  };

  /// Logical OR of two inputs; fold identity is false.
  struct BoolOr : public Evaluate<bool> {
    BoolOr()
      : Evaluate(false)
    { }
    bool operator()(bool i1, bool i2) const override {
      return i1 || i2;
    }
  };

  /// Logical XOR of two inputs; fold identity is false.
  struct BoolXor : public Evaluate<bool> {
    BoolXor()
      : Evaluate(false)
    { }
    bool operator()(bool i1, bool i2) const override {
      return i1 ^ i2;
    }
  };
} // namespace ANN
|
||
#endif // ANN_COST_FUNCTION_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
#include "NetworkConfiguration.h" | ||
#include "Debug.h" | ||
#include <getopt.h> | ||
#include <map> | ||
#include <string> | ||
|
||
/// Print usage to stderr: the option summary line followed by every
/// long option and its accepted values.
void PrintCmdline() {
  // The leading space before "--network" matters: without it the two
  // adjacent literals fused into "cf--network" in the printed summary.
  std::cerr << "\n\noptions: --activation Act --converge-method conv --cost-function cf"
            << " --network net --training-algo ta";
  std::cerr << "\n\n--activation [LinearAct|SigmoidAct]"
            << "\n--converge-method [GradientDescent|SimpleDelta]"
            << "\n--cost-function [BoolAnd|BoolOr|BoolXor]"
            << "\n--help"
            << "\n--network [SLFFN|TLFFN]"
            << "\n--training-algo [BackProp|FeedForward]"
            << "\n";
  std::cerr << "\nUsing GNU opt-parser, "
            << "either space or equals(=) works with long options.\n";
}
|
||
void PrintOptMap(OPT::OptmapType m) { | ||
auto& os = dbgs(); | ||
std::for_each(m.begin(), m.end(), | ||
[&os](OPT::OptmapType::value_type v) { | ||
std::cerr<< "\n" <<v.first << ":" << v.second; | ||
}); | ||
} | ||
|
||
/// Entry point: parse the command line into an option map, validate it,
/// then set up the network and alternate training/verification until
/// the network passes verification.
int main(int argc, char* argv[]) {
  using namespace OPT;
  int opt= 0;
  // Long-option table for getopt_long_only. The fourth column is the
  // short key returned to the switch below; option names come from the
  // OPT string constants so parser and option map share one spelling.
  static struct option lopt[] = {
    {activation.c_str(), required_argument, 0, 'a' },
    {converge_method.c_str(), required_argument, 0, 'c' },
    {cost_function.c_str(), required_argument, 0, 'C' },
    {help.c_str(), no_argument, 0, 'h' },
    {network.c_str(), required_argument, 0, 'n' },
    {training_algo.c_str(), required_argument, 0, 't' },
    { 0, 0, 0, 0 }  // sentinel terminating the table
  };
  OptmapType optmap;
  int lidx =0;
  // Empty optstring: only long options (--name value / --name=value)
  // are recognized.
  while ((opt = getopt_long_only(argc, argv,"",
                                 lopt, &lidx )) != -1) {
    switch (opt) {
    case 'a':
      DEBUG2(dbgs() << "\nActivation Function:" << optarg);
      optmap[activation] = optarg;
      break;
    case 'c':
      DEBUG2(dbgs() << "\nConvergence Method:" << optarg);
      optmap[converge_method] = optarg;
      break;
    case 'C':
      DEBUG2(dbgs() << "\nCost Function:" << optarg);
      optmap[cost_function] = optarg;
      break;
    case 'h':
      // --help prints usage and exits successfully.
      PrintCmdline();
      return 0;
    case 'n':
      DEBUG2(dbgs() << "\nNetwork:" << optarg);
      optmap[network] = optarg;
      break;
    case 't':
      DEBUG2(dbgs() << "\nTraining Algorithm:" << optarg);
      optmap[training_algo] = optarg;
      break;
    default:
      // getopt reported an unrecognized option.
      std::cerr << "\nUnknown arguments. exiting...";
      PrintCmdline();
      return -1;
    }
  }

  PrintOptMap(optmap);

  // Validate the collected options, build the network, then keep
  // training (run) until a full verification pass succeeds.
  ANN::NetworkConfiguration nc;
  if(nc.ValidateOptmap(optmap)) {
    nc.setup(optmap);
    while (!nc.VerifyTraining())
      nc.run();
  } else {
    DEBUG0(dbgs() << "\nInvalid/Insufficient options. exiting...");
    PrintCmdline();
    return -1;
  }
  return 0;
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
#ifndef NETWORKCONFIGURATION_H | ||
#define NETWORKCONFIGURATION_H | ||
|
||
#include "TrainingAlgorithms.h" | ||
|
||
#include <map> | ||
#include <string> | ||
|
||
namespace OPT {
  // Canonical long-option names shared by the getopt table in main and
  // the parsed option map, so both sides agree on a single spelling.
  typedef std::map<std::string, std::string> OptmapType;
  // const gives these internal linkage: the originals were non-const
  // globals defined in a header, which violates the ODR as soon as the
  // header is included from more than one translation unit.
  const std::string activation = "activation";
  const std::string converge_method = "converge-method";
  const std::string cost_function = "cost-function";
  const std::string help = "help";
  const std::string network = "network";
  const std::string training_algo = "training-algo";
} // namespace OPT
|
||
namespace ANN { | ||
class NetworkConfiguration { | ||
typedef ValidateOutput<std::vector<bool>, bool> ValidatorType; | ||
typedef std::map<std::string, Activation<NeuronWeightType>* > | ||
ActivationFunctionsType; | ||
typedef std::map<std::string, ConvergenceMethod*> ConvergenceMethodsType; | ||
typedef std::map<std::string, Evaluate<bool>* > EvaluatorsType; | ||
Evaluate<bool>* CostFunction; | ||
|
||
bool validated; | ||
Trainer* T; | ||
ValidatorType* Validator; | ||
float alpha; | ||
const int ip_size; | ||
unsigned times_trained; | ||
ActivationFunctionsType ActivationFunctions; | ||
//std::map<std::string, TrainingAlgorithm> TrainingAlgorithms; | ||
ConvergenceMethodsType ConvergenceMethods; | ||
EvaluatorsType Evaluators; | ||
NeuralNetwork NN; | ||
public: | ||
typedef std::map<std::string, std::string> OptMapType; | ||
NetworkConfiguration() | ||
: validated(false), T(nullptr), | ||
Validator(nullptr), alpha(0.01), | ||
ip_size(3), times_trained(0) { | ||
/// @todo Rather than generating by default, | ||
/// make the construction on-demand. | ||
ActivationFunctions["LinearAct"] = new LinearAct<NeuronWeightType>; | ||
ActivationFunctions["SigmoidAct"] = new SigmoidAct<NeuronWeightType>; | ||
ConvergenceMethods["SimpleDelta"] = new SimpleDelta; | ||
ConvergenceMethods["GradientDescent"] = new GradientDescent; | ||
Evaluators["BoolAnd"] = new BoolAnd; | ||
Evaluators["BoolOr"] = new BoolOr; | ||
Evaluators["BoolXor"] = new BoolXor; | ||
} | ||
~NetworkConfiguration() { | ||
if (T) | ||
delete T; | ||
if (Validator) | ||
delete Validator; | ||
std::for_each(ActivationFunctions.begin(), ActivationFunctions.end(), | ||
[](ActivationFunctionsType::value_type v){ | ||
delete v.second; | ||
}); | ||
std::for_each(ConvergenceMethods.begin(), ConvergenceMethods.end(), | ||
[](ConvergenceMethodsType::value_type v){ | ||
delete v.second; | ||
}); | ||
} | ||
|
||
virtual void setup(OptMapType& optmap) { | ||
// Create a single layer feed forward neural network. | ||
NN = CreateTLFFN(ip_size, ActivationFunctions[optmap[OPT::activation]]); | ||
NN.PrintNNDigraph(*NN.GetRoot(), std::cout); | ||
// Choose the training algorithm. | ||
ConvergenceMethod* CM = ConvergenceMethods[optmap[OPT::converge_method]]; | ||
assert(CM); | ||
T = new Trainer(NN, CM, alpha); | ||
// Validation of the output. | ||
Validator = new ValidatorType(NN); | ||
CostFunction = Evaluators[optmap[OPT::cost_function]]; | ||
assert(CostFunction); | ||
} | ||
|
||
virtual void run() { | ||
using namespace utilities; | ||
T->SetAlpha(alpha); | ||
DEBUG0(dbgs() << "\nTraining with alpha:" << alpha); | ||
for (unsigned i = 0; i < 10;) { | ||
std::vector<bool> RS = GetRandomizedSet(BooleanSampleSpace, ip_size-1); | ||
std::vector<float> RSF = BoolsToFloats(RS); | ||
// The last input is the bias. | ||
RSF.insert(RSF.begin(), -1); | ||
DEBUG0(dbgs() << "\nSample Inputs:"; PrintElements(dbgs(), RSF)); | ||
//NN.PrintNNDigraph(*NN.GetRoot(), std::cout); | ||
auto op = NN.GetOutput(RSF); | ||
auto bool_op = FloatToBool(op); | ||
auto desired_op = Validator->GetDesiredOutput(CostFunction, RS); | ||
// Is the output same as desired output? | ||
if (!Validator->Validate(CostFunction, RS, bool_op)) { | ||
DEBUG0(dbgs() << "\nLearning (" << op << ", " | ||
<< bool_op << ", " | ||
<< desired_op << ")"); | ||
//NN.PrintNNDigraph(*NN.GetRoot(), std::cout); | ||
// No => Train | ||
T->TrainNetworkBackProp(RSF, desired_op); | ||
++times_trained; | ||
i = 0; | ||
//NN.PrintNNDigraph(*NN.GetRoot(), std::cout); | ||
} else { | ||
++i; // Increment trained counter. | ||
DEBUG0(dbgs() << "\tTrained (" << op << ", " << bool_op << ")"); | ||
} | ||
} | ||
} | ||
|
||
virtual bool VerifyTraining() { | ||
using namespace utilities; | ||
bool trained = true; | ||
DEBUG0(dbgs() << "\nPrinting after training"); | ||
for (unsigned i = 0; i < 20; ++i) { | ||
std::vector<bool> RS = GetRandomizedSet(BooleanSampleSpace, ip_size-1); | ||
std::vector<float> RSF = BoolsToFloats(RS); | ||
// The last input is the bias. | ||
RSF.insert(RSF.begin(), -1); | ||
auto op = NN.GetOutput(RSF); | ||
DEBUG0(dbgs() << "\nSample Inputs:"; PrintElements(dbgs(), RSF)); | ||
if (Validator->Validate(CostFunction, RS, FloatToBool(op))) | ||
DEBUG0(dbgs() << "\tTrained (" << op << ", " << FloatToBool(op) << ")"); | ||
else { | ||
// double the training rate. | ||
alpha = alpha < 0.4 ? 2*alpha : alpha; | ||
trained = false; | ||
DEBUG0(dbgs() << "\tUnTrained: " << op); | ||
break; | ||
} | ||
} | ||
DEBUG0(dbgs() << "\nTrained for " << times_trained << " cycles."); | ||
return trained; | ||
} | ||
|
||
bool ValidateOptmap(OptMapType& optmap) { | ||
using namespace OPT; | ||
if (optmap[activation] != "LinearAct" || | ||
optmap[activation] != "SigmoidAct") | ||
return false; | ||
if (optmap[converge_method] != "GradientDescent" || | ||
optmap[converge_method] != "SimpleDelta") | ||
return false; | ||
if (optmap[cost_function] != "BoolAnd" || | ||
optmap[cost_function] != "BoolOr" || | ||
optmap[cost_function] != "BoolXor") | ||
return false; | ||
if (optmap[network] != "SLFFN" || | ||
optmap[network] != "TLFFN") | ||
return false; | ||
if (optmap[training_algo] != "BackProp" || | ||
optmap[training_algo] != "FeedForward") | ||
return false; | ||
validated = true; | ||
return true; | ||
} | ||
}; | ||
} | ||
#endif // NETWORKCONFIGURATION_H |
Oops, something went wrong.