Skip to content

Commit

Permalink
Merge branch '2.5' of github.com:/RubixML/ML into 2.5
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewdalpino committed Sep 20, 2023
2 parents 8792955 + 665a3d1 commit a14e45d
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 2 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
- 2.5.0
- Blob Generator can now `simulate()` a Dataset object

- 2.4.1
- Sentence Tokenizer fix Arabic and Farsi language support
- Optimize online variance updating
Expand Down
12 changes: 10 additions & 2 deletions docs/datasets/generators/blob.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,16 @@ A normally distributed (Gaussian) n-dimensional blob of samples centered at a gi
```php
use Rubix\ML\Datasets\Generators\Blob;

$generator = new Blob([-1.2, -5., 2.6, 0.8, 10.], 0.25);
$generator = new Blob([-1.2, -5.0, 2.6, 0.8, 10.0], 0.25);
```

## Additional Methods
This generator does not have any additional methods.
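Fit a Blob generator to the samples in a dataset. The dataset must contain only continuous features, otherwise an `InvalidArgumentException` is thrown.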
```php
public static simulate(Dataset $dataset) : self
```

Return the center coordinates of the Blob.
```php
public center() : array
```
42 changes: 42 additions & 0 deletions src/Datasets/Generators/Blob.php
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@

use Tensor\Matrix;
use Tensor\Vector;
use Rubix\ML\DataType;
use Rubix\ML\Helpers\Stats;
use Rubix\ML\Datasets\Dataset;
use Rubix\ML\Datasets\Unlabeled;
use Rubix\ML\Exceptions\InvalidArgumentException;

use function count;
use function sqrt;

/**
* Blob
Expand Down Expand Up @@ -37,6 +41,34 @@ class Blob implements Generator
*/
protected $stdDev;

/**
* Fit a Blob generator to the samples in a dataset.
*
* @param \Rubix\ML\Datasets\Dataset $dataset
* @throws \Rubix\ML\Exceptions\InvalidArgumentException If the dataset contains any non-continuous features.
* @return self
*/
public static function simulate(Dataset $dataset) : self
{
    // Ask the dataset for its continuous feature columns only.
    $columns = $dataset->featuresByType(DataType::continuous());

    // A count mismatch is treated as the dataset holding at least one
    // non-continuous feature, which this generator rejects.
    if ($dataset->numFeatures() !== count($columns)) {
        throw new InvalidArgumentException('Dataset must only contain'
            . ' continuous features.');
    }

    $center = [];
    $spread = [];

    foreach ($columns as $column) {
        // Stats::meanVar() yields a [mean, variance] pair for the column.
        $moments = Stats::meanVar($column);

        $center[] = $moments[0];

        // The constructor is given standard deviations, so take the
        // square root of each variance.
        $spread[] = sqrt($moments[1]);
    }

    return new self($center, $spread);
}

/**
* @param (int|float)[] $center
* @param int|float|(int|float)[] $stdDev
Expand Down Expand Up @@ -74,6 +106,16 @@ public function __construct(array $center = [0, 0], $stdDev = 1.0)
$this->stdDev = $stdDev;
}

/**
* Return the center coordinates of the Blob.
*
* @return list<int|float>
*/
public function center() : array
{
    // Convert the stored center object into a plain PHP array for callers.
    $coordinates = $this->center->asArray();

    return $coordinates;
}

/**
* Return the dimensionality of the data this generates.
*
Expand Down
21 changes: 21 additions & 0 deletions tests/Datasets/Generators/BlobTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,19 @@ protected function setUp() : void
$this->generator = new Blob([0, 0, 0], 1.0);
}

/**
* @test
*/
public function simulate() : void
{
    // Draw 100 samples from the fixture generator and fit a new Blob to them.
    $fitted = Blob::simulate($this->generator->generate(100));

    // Only the type of the fitted generator is checked here, not its
    // estimated parameters.
    $this->assertInstanceOf(Blob::class, $fitted);
    $this->assertInstanceOf(Generator::class, $fitted);
}

/**
* @test
*/
Expand All @@ -38,6 +51,14 @@ public function build() : void
$this->assertInstanceOf(Generator::class, $this->generator);
}

/**
* @test
*/
public function center() : void
{
    $expected = [0, 0, 0];

    $actual = $this->generator->center();

    // Loose equality is kept on purpose; NOTE(review): the returned
    // coordinates may be floats rather than ints - confirm before
    // tightening this to assertSame().
    $this->assertEquals($expected, $actual);
}

/**
* @test
*/
Expand Down

0 comments on commit a14e45d

Please sign in to comment.