diff --git a/src/AnomalyDetectors/OneClassSVM.php b/src/AnomalyDetectors/OneClassSVM.php index a7ee88df2..93e0a0369 100644 --- a/src/AnomalyDetectors/OneClassSVM.php +++ b/src/AnomalyDetectors/OneClassSVM.php @@ -79,6 +79,7 @@ public function __construct( new ExtensionIsLoaded('svm'), new ExtensionMinimumVersion('svm', '0.2.0'), ])->check(); + if ($nu < 0.0 or $nu > 1.0) { throw new InvalidArgumentException('Nu must be between' @@ -182,7 +183,14 @@ public function train(Dataset $dataset) : void new SamplesAreCompatibleWithEstimator($dataset, $this), ])->check(); - $this->model = $this->svm->train($dataset->samples()); + $data = []; + + foreach ($dataset->samples() as $sample) { + array_unshift($sample, 1); + $data[] = $sample; + } + + $this->model = $this->svm->train($data); } /** @@ -211,7 +219,13 @@ public function predictSample(array $sample) : int throw new RuntimeException('Estimator has not been trained.'); } - return $this->model->predict($sample) !== 1.0 ? 0 : 1; + $sampleWithOffset = []; + + foreach ($sample as $key => $value) { + $sampleWithOffset[$key + 1] = $value; + } + + return $this->model->predict($sampleWithOffset) == 1 ? 0 : 1; } /**