Skip to content

Commit

Permalink
Add WrapperAware trait
Browse files Browse the repository at this point in the history
  • Loading branch information
ElGigi committed Jan 16, 2024
1 parent 7e262ff commit 241abc4
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 135 deletions.
44 changes: 2 additions & 42 deletions src/GridSearch.php
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
use Rubix\ML\Specifications\EstimatorIsCompatibleWithMetric;
use Rubix\ML\Specifications\SamplesAreCompatibleWithEstimator;
use Rubix\ML\Exceptions\InvalidArgumentException;
use Rubix\ML\Traits\WrapperAware;

/**
* Grid Search
Expand All @@ -41,7 +42,7 @@
*/
class GridSearch implements Wrapper, Learner, Parallel, Verbose, Persistable
{
use AutotrackRevisions, Multiprocessing, LoggerAware;
use AutotrackRevisions, Multiprocessing, LoggerAware, WrapperAware;

/**
* The class name of the base estimator.
Expand Down Expand Up @@ -71,13 +72,6 @@ class GridSearch implements Wrapper, Learner, Parallel, Verbose, Persistable
*/
protected \Rubix\ML\CrossValidation\Validator $validator;

/**
* The base estimator instance.
*
* @var Learner
*/
protected \Rubix\ML\Learner $base;

/**
* The validation scores obtained from the last search.
*
Expand Down Expand Up @@ -179,18 +173,6 @@ public function __construct(
$this->backend = new Serial();
}

/**
* Return the estimator type.
*
* @internal
*
* @return EstimatorType
*/
public function type() : EstimatorType
{
return $this->base->type();
}

/**
* Return the data types that the estimator is compatible with.
*
Expand Down Expand Up @@ -232,16 +214,6 @@ public function trained() : bool
return $this->base->trained();
}

/**
* Return the base learner instance.
*
* @return Estimator
*/
public function base() : Estimator
{
return $this->base;
}

/**
* Train one estimator per combination of parameters given by the grid and
* assign the best one as the base estimator of this instance.
Expand Down Expand Up @@ -304,18 +276,6 @@ public function train(Dataset $dataset) : void
}
}

/**
* Make a prediction on a given sample dataset.
*
* @param Dataset $dataset
* @throws Exceptions\RuntimeException
* @return mixed[]
*/
public function predict(Dataset $dataset) : array
{
return $this->base->predict($dataset);
}

/**
* The callback that executes after the cross validation task.
*
Expand Down
53 changes: 2 additions & 51 deletions src/PersistentModel.php
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
use Rubix\ML\AnomalyDetectors\Scoring;
use Rubix\ML\Exceptions\InvalidArgumentException;
use Rubix\ML\Exceptions\RuntimeException;
use Rubix\ML\Traits\WrapperAware;

/**
* Persistent Model
Expand All @@ -23,12 +24,7 @@
*/
class PersistentModel implements Wrapper, Learner, Probabilistic, Scoring
{
/**
* The persistable base learner.
*
* @var Learner
*/
protected \Rubix\ML\Learner $base;
use WrapperAware;

/**
* The persister used to interface with the storage layer.
Expand Down Expand Up @@ -84,30 +80,6 @@ public function __construct(Learner $base, Persister $persister, ?Serializer $se
$this->serializer = $serializer ?? new RBX();
}

/**
* Return the estimator type.
*
* @internal
*
* @return EstimatorType
*/
public function type() : EstimatorType
{
return $this->base->type();
}

/**
* Return the data types that the estimator is compatible with.
*
* @internal
*
* @return list<\Rubix\ML\DataType>
*/
public function compatibility() : array
{
return $this->base->compatibility();
}

/**
* Return the settings of the hyper-parameters in an associative array.
*
Expand All @@ -134,16 +106,6 @@ public function trained() : bool
return $this->base->trained();
}

/**
* Return the base estimator instance.
*
* @return Estimator
*/
public function base() : Estimator
{
return $this->base;
}

/**
* Save the model to storage.
*/
Expand All @@ -168,17 +130,6 @@ public function train(Dataset $dataset) : void
$this->base->train($dataset);
}

/**
* Make a prediction on a given sample dataset.
*
* @param Dataset $dataset
* @return mixed[]
*/
public function predict(Dataset $dataset) : array
{
return $this->base->predict($dataset);
}

/**
* Estimate the joint probabilities for each possible outcome.
*
Expand Down
44 changes: 2 additions & 42 deletions src/Pipeline.php
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

use Rubix\ML\Helpers\Params;
use Rubix\ML\Datasets\Dataset;
use Rubix\ML\Traits\WrapperAware;
use Rubix\ML\Transformers\Elastic;
use Rubix\ML\Transformers\Stateful;
use Rubix\ML\Transformers\Transformer;
Expand All @@ -27,7 +28,7 @@
*/
class Pipeline implements Online, Probabilistic, Scoring, Persistable, Wrapper
{
use AutotrackRevisions;
use AutotrackRevisions, WrapperAware;

/**
* A list of transformers to be applied in series.
Expand All @@ -38,13 +39,6 @@ class Pipeline implements Online, Probabilistic, Scoring, Persistable, Wrapper
//
];

/**
* An instance of a base estimator to receive the transformed data.
*
* @var Estimator
*/
protected \Rubix\ML\Estimator $base;

/**
* Should we update the elastic transformers during partial train?
*
Expand Down Expand Up @@ -72,30 +66,6 @@ public function __construct(array $transformers, Estimator $base, bool $elastic
$this->elastic = $elastic;
}

/**
* Return the estimator type.
*
* @internal
*
* @return EstimatorType
*/
public function type() : EstimatorType
{
return $this->base->type();
}

/**
* Return the data types that the estimator is compatible with.
*
* @internal
*
* @return list<\Rubix\ML\DataType>
*/
public function compatibility() : array
{
return $this->base->compatibility();
}

/**
* Return the settings of the hyper-parameters in an associative array.
*
Expand Down Expand Up @@ -124,16 +94,6 @@ public function trained() : bool
: true;

Check failure on line 94 in src/Pipeline.php

View workflow job for this annotation

GitHub Actions / PHP 7.4 on ubuntu-latest

Else branch is unreachable because ternary operator condition is always true.

Check failure on line 94 in src/Pipeline.php

View workflow job for this annotation

GitHub Actions / PHP 8.0 on ubuntu-latest

Else branch is unreachable because ternary operator condition is always true.

Check failure on line 94 in src/Pipeline.php

View workflow job for this annotation

GitHub Actions / PHP 8.1 on ubuntu-latest

Else branch is unreachable because ternary operator condition is always true.

Check failure on line 94 in src/Pipeline.php

View workflow job for this annotation

GitHub Actions / PHP 8.2 on ubuntu-latest

Else branch is unreachable because ternary operator condition is always true.
}

/**
* Return the base estimator instance.
*
* @return Estimator
*/
public function base() : Estimator
{
return $this->base;
}

/**
* Run the training dataset through all transformers in order and use the
* transformed dataset to train the estimator.
Expand Down
73 changes: 73 additions & 0 deletions src/Traits/WrapperAware.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
<?php

namespace Rubix\ML\Traits;

use Rubix\ML\Datasets\Dataset;
use Rubix\ML\Estimator;
use Rubix\ML\EstimatorType;
use Rubix\ML\Exceptions\RuntimeException;
use Rubix\ML\Learner;

/**
* Wrapper Aware
*
* This trait fulfills the requirements of the Wrapper interface and is suitable for most implementations.
*
* @category Machine Learning
* @package Rubix/ML
*/
trait WrapperAware
{
/**
* The base estimator.
*
* @var Learner
*/
protected Estimator $base;

/**
* Return the base estimator instance.
*
* @return Estimator
*/
public function base(): Estimator
{
return $this->base;
}

/**
* Return the estimator type.
*
* @internal
*
* @return EstimatorType
*/
public function type() : EstimatorType
{
return $this->base->type();
}

/**
* Return the data types that the estimator is compatible with.
*
* @internal
*
* @return list<\Rubix\ML\DataType>
*/
public function compatibility() : array
{
return $this->base->compatibility();
}

/**
* Make a prediction on a given sample dataset.
*
* @param Dataset $dataset
* @throws RuntimeException
* @return mixed[]
*/
public function predict(Dataset $dataset) : array
{
return $this->base->predict($dataset);
}
}
20 changes: 20 additions & 0 deletions src/Wrapper.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
<?php

namespace Rubix\ML;

/**
* Wrapper
*
* @category Machine Learning
* @package Rubix/ML
* @author Andrew DalPino
*/
interface Wrapper extends Estimator
{
/**
* Return the base estimator instance.
*
* @return Estimator
*/
public function base(): Estimator;
}

0 comments on commit 241abc4

Please sign in to comment.