Improved SparseVector constructor
ankane committed Jun 26, 2024
1 parent 10a72fb commit f8610a2
Showing 14 changed files with 79 additions and 76 deletions.
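For context: this commit replaces the static factories SparseVector.fromSql, SparseVector.fromDense, and SparseVector.fromMap with a single constructor that dispatches on its arguments — a string is parsed as sparsevec SQL text, a value plus an explicit dimensions argument is treated as an index-to-value map, and anything else is treated as a dense array. A minimal usage sketch (not part of the commit; values taken from the tests in this diff):

  import { SparseVector } from 'pgvector/utils';

  // dense array: zero entries are dropped
  const dense = new SparseVector([1, 0, 2, 0, 3, 0]);
  dense.toSql(); // '{1:1,3:2,5:3}/6'

  // sparsevec SQL text: parsed back into dimensions/indices/values
  const parsed = new SparseVector('{1:1,3:2,5:3}/6');
  parsed.toArray(); // [1, 0, 2, 0, 3, 0]

  // Map of index -> value, with dimensions given explicitly
  const mapped = new SparseVector(new Map([[2, 2], [4, 3], [0, 1]]), 6);
  mapped.indices; // [2, 4, 0]
  mapped.values;  // [2, 3, 1]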
2 changes: 1 addition & 1 deletion src/utils/index.js
@@ -25,7 +25,7 @@ function sparsevecFromSql(value) {
  if (value === null) {
    return null;
  }
-  return SparseVector.fromSql(value);
+  return new SparseVector(value);
}

function sparsevecToSql(value) {
55 changes: 29 additions & 26 deletions src/utils/sparse-vector.js
@@ -1,13 +1,14 @@
const util = require('node:util');

class SparseVector {
-  constructor(dimensions, indices, values) {
-    if (indices.length != values.length) {
-      throw new Error('indices and values must be the same length');
+  constructor(value, dimensions) {
+    if (typeof value === 'string') {
+      this.#fromSql(value);
+    } else if (dimensions !== undefined) {
+      this.#fromMap(value, dimensions);
+    } else {
+      this.#fromDense(value);
    }
-    this.dimensions = dimensions;
-    this.indices = indices;
-    this.values = values;
  }

  toString() {
@@ -32,43 +33,45 @@ class SparseVector {
    return arr;
  }

-  static fromSql(value) {
+  #fromSql(value) {
    const parts = value.split('/', 2);
+
+    this.dimensions = parseInt(parts[1]);
+    this.indices = [];
+    this.values = [];
+
    const elements = parts[0].slice(1, -1).split(',');
-    const dimensions = parseInt(parts[1]);
-    const indices = [];
-    const values = [];
    for (const element of elements) {
      const ep = element.split(':', 2);
-      indices.push(parseInt(ep[0]) - 1);
-      values.push(parseFloat(ep[1]));
+      this.indices.push(parseInt(ep[0]) - 1);
+      this.values.push(parseFloat(ep[1]));
    }
-    return new SparseVector(dimensions, indices, values);
  }

-  static fromDense(value) {
-    const dimensions = value.length;
-    const indices = [];
-    const values = [];
+  #fromDense(value) {
+    this.dimensions = value.length;
+    this.indices = [];
+    this.values = [];
+
    for (const [i, v] of value.entries()) {
      if (v != 0) {
-        indices.push(Number(i));
-        values.push(Number(v));
+        this.indices.push(Number(i));
+        this.values.push(Number(v));
      }
    }
-    return new SparseVector(dimensions, indices, values);
  }

-  static fromMap(map, dimensions) {
-    const indices = [];
-    const values = [];
+  #fromMap(map, dimensions) {
+    this.dimensions = Number(dimensions);
+    this.indices = [];
+    this.values = [];
+
    for (const [i, v] of map.entries()) {
      if (v != 0) {
-        indices.push(Number(i));
-        values.push(Number(v));
+        this.indices.push(Number(i));
+        this.values.push(Number(v));
      }
    }
-    return new SparseVector(Number(dimensions), indices, values);
  }
}

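As a concrete trace of the new #fromSql path (values match the unit test at the bottom of this diff): the input is split on '/' so the trailing part becomes the dimensions, and each 'index:value' element yields a zero-based index and a float value.

  const vec = new SparseVector('{1:1,3:2,5:3}/6');
  // vec.dimensions -> 6
  // vec.indices    -> [0, 2, 4]  (SQL indices are 1-based, stored 0-based)
  // vec.values     -> [1, 2, 3]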
8 changes: 4 additions & 4 deletions tests/drizzle-orm/index.test.mjs
@@ -21,9 +21,9 @@ test('example', async () => {
});

const newItems = [
- {embedding: [1, 1, 1], halfEmbedding: [1, 1, 1], binaryEmbedding: '000', sparseEmbedding: SparseVector.fromDense([1, 1, 1])},
- {embedding: [2, 2, 2], halfEmbedding: [2, 2, 2], binaryEmbedding: '101', sparseEmbedding: SparseVector.fromDense([2, 2, 2])},
- {embedding: [1, 1, 2], halfEmbedding: [1, 1, 2], binaryEmbedding: '111', sparseEmbedding: SparseVector.fromDense([1, 1, 2])},
+ {embedding: [1, 1, 1], halfEmbedding: [1, 1, 1], binaryEmbedding: '000', sparseEmbedding: new SparseVector([1, 1, 1])},
+ {embedding: [2, 2, 2], halfEmbedding: [2, 2, 2], binaryEmbedding: '101', sparseEmbedding: new SparseVector([2, 2, 2])},
+ {embedding: [1, 1, 2], halfEmbedding: [1, 1, 2], binaryEmbedding: '111', sparseEmbedding: new SparseVector([1, 1, 2])},
{embedding: null}
];
await db.insert(items).values(newItems);
@@ -48,7 +48,7 @@ test('example', async () => {
// L2 distance - sparsevec
allItems = await db.select()
.from(items)
- .orderBy(l2Distance(items.sparseEmbedding, SparseVector.fromDense([1, 1, 1])))
+ .orderBy(l2Distance(items.sparseEmbedding, new SparseVector([1, 1, 1])))
.limit(5);
expect(allItems.map(v => v.id)).toStrictEqual([1, 3, 2, 4]);

8 changes: 4 additions & 4 deletions tests/knex/index.test.mjs
@@ -19,9 +19,9 @@ test('example', async () => {
});

const newItems = [
- {embedding: pgvector.toSql([1, 1, 1]), half_embedding: pgvector.toSql([1, 1, 1]), binary_embedding: '000', sparse_embedding: SparseVector.fromDense([1, 1, 1])},
- {embedding: pgvector.toSql([2, 2, 2]), half_embedding: pgvector.toSql([2, 2, 2]), binary_embedding: '101', sparse_embedding: SparseVector.fromDense([2, 2, 2])},
- {embedding: pgvector.toSql([1, 1, 2]), half_embedding: pgvector.toSql([1, 1, 2]), binary_embedding: '111', sparse_embedding: SparseVector.fromDense([1, 1, 2])},
+ {embedding: pgvector.toSql([1, 1, 1]), half_embedding: pgvector.toSql([1, 1, 1]), binary_embedding: '000', sparse_embedding: new SparseVector([1, 1, 1])},
+ {embedding: pgvector.toSql([2, 2, 2]), half_embedding: pgvector.toSql([2, 2, 2]), binary_embedding: '101', sparse_embedding: new SparseVector([2, 2, 2])},
+ {embedding: pgvector.toSql([1, 1, 2]), half_embedding: pgvector.toSql([1, 1, 2]), binary_embedding: '111', sparse_embedding: new SparseVector([1, 1, 2])},
{embedding: null}
];
await knex('knex_items').insert(newItems);
@@ -43,7 +43,7 @@ test('example', async () => {

// L2 distance - sparsevec
items = await knex('knex_items')
- .orderBy(knex.l2Distance('sparse_embedding', SparseVector.fromDense([1, 1, 1])))
+ .orderBy(knex.l2Distance('sparse_embedding', new SparseVector([1, 1, 1])))
.limit(5);
expect(items.map(v => v.id)).toStrictEqual([1, 3, 2, 4]);

8 changes: 4 additions & 4 deletions tests/kysely/index.test.mjs
@@ -30,9 +30,9 @@ test('example', async () => {
.execute();

const newItems = [
- {embedding: pgvector.toSql([1, 1, 1]), half_embedding: pgvector.toSql([1, 1, 1]), binary_embedding: '000', sparse_embedding: SparseVector.fromDense([1, 1, 1])},
- {embedding: pgvector.toSql([2, 2, 2]), half_embedding: pgvector.toSql([2, 2, 2]), binary_embedding: '101', sparse_embedding: SparseVector.fromDense([2, 2, 2])},
- {embedding: pgvector.toSql([1, 1, 2]), half_embedding: pgvector.toSql([1, 1, 2]), binary_embedding: '111', sparse_embedding: SparseVector.fromDense([1, 1, 2])},
+ {embedding: pgvector.toSql([1, 1, 1]), half_embedding: pgvector.toSql([1, 1, 1]), binary_embedding: '000', sparse_embedding: new SparseVector([1, 1, 1])},
+ {embedding: pgvector.toSql([2, 2, 2]), half_embedding: pgvector.toSql([2, 2, 2]), binary_embedding: '101', sparse_embedding: new SparseVector([2, 2, 2])},
+ {embedding: pgvector.toSql([1, 1, 2]), half_embedding: pgvector.toSql([1, 1, 2]), binary_embedding: '111', sparse_embedding: new SparseVector([1, 1, 2])},
{embedding: null}
];
await db.insertInto('kysely_items')
@@ -61,7 +61,7 @@ test('example', async () => {
// L2 distance - sparsevec
items = await db.selectFrom('kysely_items')
.selectAll()
- .orderBy(l2Distance('sparse_embedding', SparseVector.fromDense([1, 1, 1])))
+ .orderBy(l2Distance('sparse_embedding', new SparseVector([1, 1, 1])))
.limit(5)
.execute();
expect(items.map(v => v.id)).toStrictEqual([1, 3, 2, 4]);
8 changes: 4 additions & 4 deletions tests/mikro-orm/index.test.mjs
@@ -27,9 +27,9 @@ test('example', async () => {
const generator = orm.getSchemaGenerator();
await generator.refreshDatabase();

- em.create(Item, {embedding: [1, 1, 1], half_embedding: [1, 1, 1], binary_embedding: '000', sparse_embedding: SparseVector.fromDense([1, 1, 1])});
- em.create(Item, {embedding: [2, 2, 2], half_embedding: [2, 2, 2], binary_embedding: '101', sparse_embedding: SparseVector.fromDense([2, 2, 2])});
- em.create(Item, {embedding: [1, 1, 2], half_embedding: [1, 1, 2], binary_embedding: '111', sparse_embedding: SparseVector.fromDense([1, 1, 2])});
+ em.create(Item, {embedding: [1, 1, 1], half_embedding: [1, 1, 1], binary_embedding: '000', sparse_embedding: new SparseVector([1, 1, 1])});
+ em.create(Item, {embedding: [2, 2, 2], half_embedding: [2, 2, 2], binary_embedding: '101', sparse_embedding: new SparseVector([2, 2, 2])});
+ em.create(Item, {embedding: [1, 1, 2], half_embedding: [1, 1, 2], binary_embedding: '111', sparse_embedding: new SparseVector([1, 1, 2])});
em.create(Item, {embedding: null});

// L2 distance
@@ -51,7 +51,7 @@ test('example', async () => {

// L2 distance - sparsevec
items = await em.createQueryBuilder(Item)
- .orderBy({[l2Distance('sparse_embedding', SparseVector.fromDense([1, 1, 1]))]: 'ASC'})
+ .orderBy({[l2Distance('sparse_embedding', new SparseVector([1, 1, 1]))]: 'ASC'})
.limit(5)
.getResult();
expect(items.map(v => v.id)).toStrictEqual([1, 3, 2, 4]);
8 changes: 4 additions & 4 deletions tests/objection/index.test.mjs
@@ -29,9 +29,9 @@ test('example', async () => {
});

const newItems = [
- {embedding: pgvector.toSql([1, 1, 1]), half_embedding: pgvector.toSql([1, 1, 1]), binary_embedding: '000', sparse_embedding: SparseVector.fromDense([1, 1, 1])},
- {embedding: pgvector.toSql([2, 2, 2]), half_embedding: pgvector.toSql([2, 2, 2]), binary_embedding: '101', sparse_embedding: SparseVector.fromDense([2, 2, 2])},
- {embedding: pgvector.toSql([1, 1, 2]), half_embedding: pgvector.toSql([1, 1, 2]), binary_embedding: '111', sparse_embedding: SparseVector.fromDense([1, 1, 2])},
+ {embedding: pgvector.toSql([1, 1, 1]), half_embedding: pgvector.toSql([1, 1, 1]), binary_embedding: '000', sparse_embedding: new SparseVector([1, 1, 1])},
+ {embedding: pgvector.toSql([2, 2, 2]), half_embedding: pgvector.toSql([2, 2, 2]), binary_embedding: '101', sparse_embedding: new SparseVector([2, 2, 2])},
+ {embedding: pgvector.toSql([1, 1, 2]), half_embedding: pgvector.toSql([1, 1, 2]), binary_embedding: '111', sparse_embedding: new SparseVector([1, 1, 2])},
{embedding: null}
];
await Item.query().insert(newItems);
@@ -53,7 +53,7 @@ test('example', async () => {

// L2 distance - sparsevec
items = await Item.query()
- .orderBy(l2Distance('sparse_embedding', SparseVector.fromDense([1, 1, 1])))
+ .orderBy(l2Distance('sparse_embedding', new SparseVector([1, 1, 1])))
.limit(5);
expect(items.map(v => v.id)).toStrictEqual([1, 3, 2, 4]);

6 changes: 3 additions & 3 deletions tests/pg-promise/index.test.mjs
@@ -16,9 +16,9 @@ test('example', async () => {
await db.none('CREATE TABLE pg_promise_items (id serial PRIMARY KEY, embedding vector(3), half_embedding halfvec(3), binary_embedding bit(3), sparse_embedding sparsevec(3))');

const params = [
- pgvector.toSql([1, 1, 1]), pgvector.toSql([1, 1, 1]), '000', SparseVector.fromDense([1, 1, 1]),
- pgvector.toSql([2, 2, 2]), pgvector.toSql([2, 2, 2]), '101', SparseVector.fromDense([2, 2, 2]),
- pgvector.toSql([1, 1, 2]), pgvector.toSql([1, 1, 2]), '111', SparseVector.fromDense([1, 1, 2]),
+ pgvector.toSql([1, 1, 1]), pgvector.toSql([1, 1, 1]), '000', new SparseVector([1, 1, 1]),
+ pgvector.toSql([2, 2, 2]), pgvector.toSql([2, 2, 2]), '101', new SparseVector([2, 2, 2]),
+ pgvector.toSql([1, 1, 2]), pgvector.toSql([1, 1, 2]), '111', new SparseVector([1, 1, 2]),
null, null, null, null
];
await db.none('INSERT INTO pg_promise_items (embedding, half_embedding, binary_embedding, sparse_embedding) VALUES ($1, $2, $3, $4), ($5, $6, $7, $8), ($9, $10, $11, $12), ($13, $14, $15, $16)', params);
6 changes: 3 additions & 3 deletions tests/pg/index.test.mjs
@@ -13,9 +13,9 @@ test('example', async () => {
await client.query('CREATE TABLE pg_items (id serial PRIMARY KEY, embedding vector(3), half_embedding halfvec(3), binary_embedding bit(3), sparse_embedding sparsevec(3))');

const params = [
- pgvector.toSql([1, 1, 1]), pgvector.toSql([1, 1, 1]), '000', SparseVector.fromDense([1, 1, 1]),
- pgvector.toSql([2, 2, 2]), pgvector.toSql([2, 2, 2]), '101', SparseVector.fromDense([2, 2, 2]),
- pgvector.toSql([1, 1, 2]), pgvector.toSql([1, 1, 2]), '111', SparseVector.fromDense([1, 1, 2]),
+ pgvector.toSql([1, 1, 1]), pgvector.toSql([1, 1, 1]), '000', new SparseVector([1, 1, 1]),
+ pgvector.toSql([2, 2, 2]), pgvector.toSql([2, 2, 2]), '101', new SparseVector([2, 2, 2]),
+ pgvector.toSql([1, 1, 2]), pgvector.toSql([1, 1, 2]), '111', new SparseVector([1, 1, 2]),
null, null, null, null
];
await client.query('INSERT INTO pg_items (embedding, half_embedding, binary_embedding, sparse_embedding) VALUES ($1, $2, $3, $4), ($5, $6, $7, $8), ($9, $10, $11, $12), ($13, $14, $15, $16)', params);
8 changes: 4 additions & 4 deletions tests/postgres/index.test.mjs
@@ -10,9 +10,9 @@ test('example', async () => {
await sql`CREATE TABLE postgres_items (id serial PRIMARY KEY, embedding vector(3), half_embedding halfvec(3), binary_embedding bit(3), sparse_embedding sparsevec(3))`;

const newItems = [
- {embedding: pgvector.toSql([1, 1, 1]), half_embedding: pgvector.toSql([1, 1, 1]), binary_embedding: '000', sparse_embedding: SparseVector.fromDense([1, 1, 1])},
- {embedding: pgvector.toSql([2, 2, 2]), half_embedding: pgvector.toSql([2, 2, 2]), binary_embedding: '101', sparse_embedding: SparseVector.fromDense([2, 2, 2])},
- {embedding: pgvector.toSql([1, 1, 2]), half_embedding: pgvector.toSql([1, 1, 2]), binary_embedding: '111', sparse_embedding: SparseVector.fromDense([1, 1, 2])}
+ {embedding: pgvector.toSql([1, 1, 1]), half_embedding: pgvector.toSql([1, 1, 1]), binary_embedding: '000', sparse_embedding: new SparseVector([1, 1, 1])},
+ {embedding: pgvector.toSql([2, 2, 2]), half_embedding: pgvector.toSql([2, 2, 2]), binary_embedding: '101', sparse_embedding: new SparseVector([2, 2, 2])},
+ {embedding: pgvector.toSql([1, 1, 2]), half_embedding: pgvector.toSql([1, 1, 2]), binary_embedding: '111', sparse_embedding: new SparseVector([1, 1, 2])}
];
await sql`INSERT INTO postgres_items ${ sql(newItems, 'embedding', 'half_embedding', 'binary_embedding', 'sparse_embedding') }`;

@@ -22,7 +22,7 @@ test('example', async () => {
expect(pgvector.fromSql(items[0].embedding)).toStrictEqual([1, 1, 1]);
expect(pgvector.fromSql(items[0].half_embedding)).toStrictEqual([1, 1, 1]);
expect(items[0].binary_embedding).toStrictEqual('000');
- expect(SparseVector.fromSql(items[0].sparse_embedding).toArray()).toStrictEqual([1, 1, 1]);
+ expect((new SparseVector(items[0].sparse_embedding)).toArray()).toStrictEqual([1, 1, 1]);

await sql`CREATE INDEX ON postgres_items USING hnsw (embedding vector_l2_ops)`;

14 changes: 7 additions & 7 deletions tests/prisma/index.test.mjs
@@ -64,18 +64,18 @@ test('sparsevec', async () => {

// TODO use create when possible (field is not available in the generated client)
// https://www.prisma.io/docs/concepts/components/prisma-schema/features-without-psl-equivalent#unsupported-field-types
- const embedding1 = SparseVector.fromDense([1, 1, 1]).toSql();
- const embedding2 = SparseVector.fromDense([2, 2, 2]).toSql();
- const embedding3 = SparseVector.fromDense([1, 1, 2]).toSql();
+ const embedding1 = (new SparseVector([1, 1, 1])).toSql();
+ const embedding2 = (new SparseVector([2, 2, 2])).toSql();
+ const embedding3 = (new SparseVector([1, 1, 2])).toSql();
await prisma.$executeRaw`INSERT INTO prisma_items (sparse_embedding) VALUES (${embedding1}::sparsevec), (${embedding2}::sparsevec), (${embedding3}::sparsevec)`;

// TODO use raw orderBy when available
// https://github.com/prisma/prisma/issues/5848
- const embedding = SparseVector.fromDense([1, 1, 1]).toSql();
+ const embedding = (new SparseVector([1, 1, 1])).toSql();
const items = await prisma.$queryRaw`SELECT id, sparse_embedding::text FROM prisma_items ORDER BY sparse_embedding <-> ${embedding}::sparsevec LIMIT 5`;
- expect(SparseVector.fromSql(items[0].sparse_embedding).toArray()).toStrictEqual([1, 1, 1]);
- expect(SparseVector.fromSql(items[1].sparse_embedding).toArray()).toStrictEqual([1, 1, 2]);
- expect(SparseVector.fromSql(items[2].sparse_embedding).toArray()).toStrictEqual([2, 2, 2]);
+ expect((new SparseVector(items[0].sparse_embedding)).toArray()).toStrictEqual([1, 1, 1]);
+ expect((new SparseVector(items[1].sparse_embedding)).toArray()).toStrictEqual([1, 1, 2]);
+ expect((new SparseVector(items[2].sparse_embedding)).toArray()).toStrictEqual([2, 2, 2]);
});

beforeEach(async () => {
8 changes: 4 additions & 4 deletions tests/slonik/index.test.mjs
@@ -18,9 +18,9 @@ test('example', async () => {
const binaryEmbedding1 = '000';
const binaryEmbedding2 = '101';
const binaryEmbedding3 = '111';
- const sparseEmbedding1 = SparseVector.fromDense([1, 1, 1]).toSql();
- const sparseEmbedding2 = SparseVector.fromDense([2, 2, 2]).toSql();
- const sparseEmbedding3 = SparseVector.fromDense([1, 1, 2]).toSql();
+ const sparseEmbedding1 = (new SparseVector([1, 1, 1])).toSql();
+ const sparseEmbedding2 = (new SparseVector([2, 2, 2])).toSql();
+ const sparseEmbedding3 = (new SparseVector([1, 1, 2])).toSql();
await pool.query(sql.unsafe`INSERT INTO slonik_items (embedding, half_embedding, binary_embedding, sparse_embedding) VALUES (${embedding1}, ${halfEmbedding1}, ${binaryEmbedding1}, ${sparseEmbedding1}), (${embedding2}, ${halfEmbedding2}, ${binaryEmbedding2}, ${sparseEmbedding2}), (${embedding3}, ${halfEmbedding3}, ${binaryEmbedding3}, ${sparseEmbedding3})`);

const embedding = pgvector.toSql([1, 1, 1]);
@@ -29,7 +29,7 @@ test('example', async () => {
expect(pgvector.fromSql(items.rows[0].embedding)).toStrictEqual([1, 1, 1]);
expect(pgvector.fromSql(items.rows[0].half_embedding)).toStrictEqual([1, 1, 1]);
expect(items.rows[0].binary_embedding).toStrictEqual('000');
- expect(SparseVector.fromSql(items.rows[0].sparse_embedding).toArray()).toStrictEqual([1, 1, 1]);
+ expect((new SparseVector(items.rows[0].sparse_embedding)).toArray()).toStrictEqual([1, 1, 1]);

await pool.query(sql.unsafe`CREATE INDEX ON slonik_items USING hnsw (embedding vector_l2_ops)`);

8 changes: 4 additions & 4 deletions tests/typeorm/index.test.mjs
@@ -44,9 +44,9 @@ test('example', async () => {
await AppDataSource.query('CREATE TABLE typeorm_items (id bigserial PRIMARY KEY, embedding vector(3), half_embedding halfvec(3), binary_embedding bit(3), sparse_embedding sparsevec(3))');

const itemRepository = AppDataSource.getRepository(Item);
- await itemRepository.save({embedding: pgvector.toSql([1, 1, 1]), half_embedding: pgvector.toSql([1, 1, 1]), binary_embedding: '000', sparse_embedding: SparseVector.fromDense([1, 1, 1])});
- await itemRepository.save({embedding: pgvector.toSql([2, 2, 2]), half_embedding: pgvector.toSql([2, 2, 2]), binary_embedding: '101', sparse_embedding: SparseVector.fromDense([2, 2, 2])});
- await itemRepository.save({embedding: pgvector.toSql([1, 1, 2]), half_embedding: pgvector.toSql([1, 1, 2]), binary_embedding: '111', sparse_embedding: SparseVector.fromDense([1, 1, 2])});
+ await itemRepository.save({embedding: pgvector.toSql([1, 1, 1]), half_embedding: pgvector.toSql([1, 1, 1]), binary_embedding: '000', sparse_embedding: new SparseVector([1, 1, 1])});
+ await itemRepository.save({embedding: pgvector.toSql([2, 2, 2]), half_embedding: pgvector.toSql([2, 2, 2]), binary_embedding: '101', sparse_embedding: new SparseVector([2, 2, 2])});
+ await itemRepository.save({embedding: pgvector.toSql([1, 1, 2]), half_embedding: pgvector.toSql([1, 1, 2]), binary_embedding: '111', sparse_embedding: new SparseVector([1, 1, 2])});

const items = await itemRepository
.createQueryBuilder('item')
@@ -58,7 +58,7 @@ test('example', async () => {
expect(pgvector.fromSql(items[0].embedding)).toStrictEqual([1, 1, 1]);
expect(pgvector.fromSql(items[0].half_embedding)).toStrictEqual([1, 1, 1]);
expect(items[0].binary_embedding).toStrictEqual('000');
- expect(SparseVector.fromSql(items[0].sparse_embedding).toArray()).toStrictEqual([1, 1, 1]);
+ expect((new SparseVector(items[0].sparse_embedding).toArray())).toStrictEqual([1, 1, 1]);

await AppDataSource.destroy();
});
8 changes: 4 additions & 4 deletions tests/utils/sparse-vector.test.mjs
@@ -1,15 +1,15 @@
import { SparseVector } from 'pgvector/utils';

test('fromSql', () => {
- const vec = SparseVector.fromSql('{1:1,3:2,5:3}/6');
+ const vec = new SparseVector('{1:1,3:2,5:3}/6');
expect(vec.toArray()).toStrictEqual([1, 0, 2, 0, 3, 0]);
expect(vec.dimensions).toStrictEqual(6);
expect(vec.indices).toStrictEqual([0, 2, 4]);
expect(vec.values).toStrictEqual([1, 2, 3]);
});

test('fromDense', () => {
- const vec = SparseVector.fromDense([1, 0, 2, 0, 3, 0]);
+ const vec = new SparseVector([1, 0, 2, 0, 3, 0]);
expect(vec.toSql()).toStrictEqual('{1:1,3:2,5:3}/6');
expect(vec.dimensions).toStrictEqual(6);
expect(vec.indices).toStrictEqual([0, 2, 4]);
@@ -22,13 +22,13 @@ test('fromMap', () => {
map.set(4, 3);
map.set(0, 1);
map.set(3, 0);
- const vec = SparseVector.fromMap(map, 6);
+ const vec = new SparseVector(map, 6);
expect(vec.dimensions).toStrictEqual(6);
expect(vec.indices).toStrictEqual([2, 4, 0]);
expect(vec.values).toStrictEqual([2, 3, 1]);
});

test('toSql', () => {
- const vec = SparseVector.fromDense([1.23456789]);
+ const vec = new SparseVector([1.23456789]);
expect(vec.toSql()).toStrictEqual('{1:1.23456789}/1');
});

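Taken together, the call-site updates above follow one pattern; as a summary sketch (not additional code from the commit):

  // before
  SparseVector.fromSql('{1:1,3:2,5:3}/6');
  SparseVector.fromDense([1, 0, 2, 0, 3, 0]);
  SparseVector.fromMap(map, 6);

  // after
  new SparseVector('{1:1,3:2,5:3}/6');
  new SparseVector([1, 0, 2, 0, 3, 0]);
  new SparseVector(map, 6);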