diff --git a/src/graphql/index.ts b/src/graphql/index.ts index 705e6ec..3892672 100644 --- a/src/graphql/index.ts +++ b/src/graphql/index.ts @@ -9,7 +9,7 @@ import { } from 'graphql'; import DataLoader from 'dataloader'; import { ResolverContextInput } from './resolvers'; -import { getTableName } from '../utils/database'; +import { getTableName, applyBlockFilter } from '../utils/database'; /** * Creates getLoader function that will return existing, or create a new dataloader @@ -25,14 +25,13 @@ export const createGetLoader = (context: ResolverContextInput) => { if (!loaders[key]) { loaders[key] = new DataLoader(async ids => { + const tableName = getTableName(name); + let query = context.knex .select('*') - .from(getTableName(name)) + .from(tableName) .whereIn(field, ids as string[]); - query = - block !== undefined - ? query.andWhereRaw('block_range @> int8(??)', [block]) - : query.andWhereRaw('upper_inf(block_range)'); + query = applyBlockFilter(query, tableName, block); context.log.debug({ sql: query.toQuery(), ids }, 'executing batched query'); diff --git a/src/graphql/resolvers.ts b/src/graphql/resolvers.ts index 7bd4690..2ab2d8b 100644 --- a/src/graphql/resolvers.ts +++ b/src/graphql/resolvers.ts @@ -16,7 +16,7 @@ import { import { Knex } from 'knex'; import { Pool as PgPool } from 'pg'; import { getNonNullType, getDerivedFromDirective } from '../utils/graphql'; -import { getTableName } from '../utils/database'; +import { getTableName, applyBlockFilter } from '../utils/database'; import { Logger } from '../utils/logger'; import type DataLoader from 'dataloader'; @@ -71,10 +71,7 @@ export async function queryMulti( const nestedEntitiesMappings = {} as Record>; let query = knex.select(`${tableName}.*`).from(tableName); - query = - args.block !== undefined - ? query.andWhereRaw(`${tableName}.block_range @> int8(??)`, [args.block]) - : query.andWhereRaw(`upper_inf(${tableName}.block_range)`); + query = applyBlockFilter(query, tableName, args.block); const handleWhere = (query: Knex.QueryBuilder, prefix: string, where: Record) => { Object.entries(where).map((w: [string, any]) => { @@ -132,10 +129,7 @@ export async function queryMulti( .columns(nestedEntitiesMappings[fieldName]) .innerJoin(nestedTableName, `${tableName}.${fieldName}`, '=', `${nestedTableName}.id`); - query = - args.block !== undefined - ? query.andWhereRaw(`${nestedTableName}.block_range @> int8(??)`, [args.block]) - : query.andWhereRaw(`upper_inf(${nestedTableName}.block_range)`); + query = applyBlockFilter(query, nestedTableName, args.block); handleWhere(query, nestedTableName, w[1]); } else { @@ -253,16 +247,10 @@ export const getNestedResolver = let result: Record[] = []; if (!derivedFromDirective) { - let query = knex - .select('*') - .from(getTableName(columnName)) - .whereIn('id', parent[info.fieldName]); - query = - block !== undefined - ? query.andWhereRaw('block_range @> int8(??)', [block]) - : query.andWhereRaw('upper_inf(block_range)'); - - result = await query; + const tableName = getTableName(columnName); + const query = knex.select('*').from(tableName).whereIn('id', parent[info.fieldName]); + + result = await applyBlockFilter(query, tableName, block); } else { const fieldArgument = derivedFromDirective.arguments?.find(arg => arg.name.value === 'field'); if (!fieldArgument || fieldArgument.value.kind !== 'StringValue') { diff --git a/src/stores/checkpoints.ts b/src/stores/checkpoints.ts index 628f9fc..4a90d87 100644 --- a/src/stores/checkpoints.ts +++ b/src/stores/checkpoints.ts @@ -49,6 +49,8 @@ export enum MetadataId { SchemaVersion = 'schema_version' } +export const INTERNAL_TABLES = Object.values(Table); + const CheckpointIdSize = 10; /** diff --git a/src/utils/database.ts b/src/utils/database.ts index dc54c9e..97800d2 100644 --- a/src/utils/database.ts +++ b/src/utils/database.ts @@ -1,7 +1,17 @@ import pluralize from 'pluralize'; +import { Knex } from 'knex'; +import { INTERNAL_TABLES } from '../stores/checkpoints'; export const getTableName = (name: string) => { if (name === '_metadata') return '_metadatas'; return pluralize(name); }; + +export function applyBlockFilter(query: Knex.QueryBuilder, tableName: string, block?: number) { + if (INTERNAL_TABLES.includes(tableName)) return query; + + return block !== undefined + ? query.andWhereRaw(`${tableName}.block_range @> int8(??)`, [block]) + : query.andWhereRaw(`upper_inf(${tableName}.block_range)`); +} diff --git a/test/unit/utils/database.test.ts b/test/unit/utils/database.test.ts new file mode 100644 index 0000000..e294bb6 --- /dev/null +++ b/test/unit/utils/database.test.ts @@ -0,0 +1,57 @@ +import knex from 'knex'; +import { getTableName, applyBlockFilter } from '../../../src/utils/database'; + +const mockKnex = knex({ + client: 'sqlite3', + connection: { + filename: ':memory:' + }, + useNullAsDefault: true +}); + +afterAll(async () => { + await mockKnex.destroy(); +}); + +describe('getTableName', () => { + it.each([ + ['table', 'tables'], + ['user', 'users'], + ['post', 'posts'], + ['space', 'spaces'], + ['vote', 'votes'], + ['comment', 'comments'] + ])('should return pluralized table name', (name, expected) => { + expect(getTableName(name)).toEqual(expected); + }); + + it('should return hardcoded table name for metadata', () => { + expect(getTableName('_metadata')).toEqual('_metadatas'); + }); +}); + +describe('applyBlockFilter', () => { + it('should not apply filter for internal tables', () => { + const query = mockKnex.select('*').from('_metadatas'); + + const result = applyBlockFilter(query, '_metadatas', 123); + + expect(result.toString()).toBe('select * from `_metadatas`'); + }); + + it('should apply capped block filter if block is provided', () => { + const query = mockKnex.select('*').from('posts'); + + const result = applyBlockFilter(query, 'posts', 123); + + expect(result.toString()).toBe('select * from `posts` where posts.block_range @> int8(123)'); + }); + + it('should apply upper_inf block filter if block is not provided', () => { + const query = mockKnex.select('*').from('posts'); + + const result = applyBlockFilter(query, 'posts'); + + expect(result.toString()).toBe('select * from `posts` where upper_inf(posts.block_range)'); + }); +});