diff --git a/src/laravel/Distance.php b/src/laravel/Distance.php index 3ef9436..d698d81 100644 --- a/src/laravel/Distance.php +++ b/src/laravel/Distance.php @@ -2,10 +2,9 @@ namespace Pgvector\Laravel; -// TODO use enum when PHP 8.0 reaches EOL -class Distance +enum Distance { - public const L2 = 0; - public const InnerProduct = 1; - public const Cosine = 2; + case L2; + case InnerProduct; + case Cosine; } diff --git a/src/laravel/HasNeighbors.php b/src/laravel/HasNeighbors.php index 1805b7b..66f1978 100644 --- a/src/laravel/HasNeighbors.php +++ b/src/laravel/HasNeighbors.php @@ -6,7 +6,7 @@ trait HasNeighbors { - public function scopeNearestNeighbors(Builder $query, string $column, mixed $value, int $distance): void + public function scopeNearestNeighbors(Builder $query, string $column, mixed $value, Distance $distance): void { switch ($distance) { case Distance::L2: @@ -34,7 +34,7 @@ public function scopeNearestNeighbors(Builder $query, string $column, mixed $val ->orderByRaw($order, [$vector]); } - public function nearestNeighbors(string $column, int $distance): Builder + public function nearestNeighbors(string $column, Distance $distance): Builder { $id = $this->getKey(); if (!array_key_exists($column, $this->attributes)) { diff --git a/tests/LaravelTest.php b/tests/LaravelTest.php index 294771b..0635426 100644 --- a/tests/LaravelTest.php +++ b/tests/LaravelTest.php @@ -108,8 +108,7 @@ public function testMissingAttribute() public function testInvalidDistance() { - $this->expectException(InvalidArgumentException::class); - $this->expectExceptionMessage('Invalid distance'); + $this->expectException(TypeError::class); Item::query()->nearestNeighbors('embedding', [1, 2, 3], 4); }