diff --git a/packages/visx-grid/src/grids/GridColumns.tsx b/packages/visx-grid/src/grids/GridColumns.tsx index 98968672c..9941bf402 100644 --- a/packages/visx-grid/src/grids/GridColumns.tsx +++ b/packages/visx-grid/src/grids/GridColumns.tsx @@ -5,6 +5,7 @@ import { Group } from '@visx/group'; import { Point } from '@visx/point'; import { getTicks, ScaleInput, coerceNumber } from '@visx/scale'; import { CommonGridProps, GridScale } from '../types'; +import getScaleBandwidth from '../utils/getScaleBandwidth'; export type GridColumnsProps = CommonGridProps & { /** `@visx/scale` or `d3-scale` object used to convert value to position. */ @@ -41,8 +42,9 @@ export default function GridColumns({ ...restProps }: AllGridColumnsProps) { const ticks = tickValues ?? getTicks(scale, numTicks); + const scaleOffset = (offset ?? 0) + getScaleBandwidth(scale) / 2; const tickLines = ticks.map(d => { - const x = offset ? (coerceNumber(scale(d)) || 0) + offset : coerceNumber(scale(d)) || 0; + const x = (coerceNumber(scale(d)) ?? 0) + scaleOffset; return { from: new Point({ x, diff --git a/packages/visx-grid/src/grids/GridRows.tsx b/packages/visx-grid/src/grids/GridRows.tsx index 811834bb6..636de9729 100644 --- a/packages/visx-grid/src/grids/GridRows.tsx +++ b/packages/visx-grid/src/grids/GridRows.tsx @@ -5,6 +5,7 @@ import { Group } from '@visx/group'; import { Point } from '@visx/point'; import { getTicks, ScaleInput, coerceNumber } from '@visx/scale'; import { CommonGridProps, GridScale } from '../types'; +import getScaleBandwidth from '../utils/getScaleBandwidth'; export type GridRowsProps = CommonGridProps & { /** `@visx/scale` or `d3-scale` object used to convert value to position. */ @@ -41,8 +42,9 @@ export default function GridRows({ ...restProps }: AllGridRowsProps) { const ticks = tickValues ?? getTicks(scale, numTicks); + const scaleOffset = (offset ?? 0) + getScaleBandwidth(scale) / 2; const tickLines = ticks.map(d => { - const y = offset ? (coerceNumber(scale(d)) || 0) + offset : coerceNumber(scale(d)) || 0; + const y = (coerceNumber(scale(d)) ?? 0) + scaleOffset; return { from: new Point({ x: 0, diff --git a/packages/visx-grid/src/utils/getScaleBandwidth.ts b/packages/visx-grid/src/utils/getScaleBandwidth.ts new file mode 100644 index 000000000..14c40f007 --- /dev/null +++ b/packages/visx-grid/src/utils/getScaleBandwidth.ts @@ -0,0 +1,5 @@ +import { GridScale } from '../types'; + +export default function getScaleBandwidth(scale: GridScale) { + return 'bandwidth' in scale ? scale.bandwidth() : 0; +} diff --git a/packages/visx-grid/test/utils.test.ts b/packages/visx-grid/test/utils.test.ts index a265b1036..6994cf383 100644 --- a/packages/visx-grid/test/utils.test.ts +++ b/packages/visx-grid/test/utils.test.ts @@ -1,6 +1,8 @@ +import { scaleLinear, scaleBand } from '@visx/scale'; import polarToCartesian from '../src/utils/polarToCartesian'; +import getScaleBandwidth from '../src/utils/getScaleBandwidth'; -describe('GridUtils', () => { +describe('grid utils', () => { describe('polarToCartesian', () => { const config = { radius: 20, @@ -14,4 +16,29 @@ describe('GridUtils', () => { expect(polarToCartesian(config)).toEqual(expected); }); }); + + describe('getScaleBandwidth', () => { + it('should return 0 for non-band scales', () => { + expect( + getScaleBandwidth( + scaleLinear({ + range: [0, 90], + domain: [0, 100], + }), + ), + ).toBe(0); + }); + + it('should return the size of the band for band scales', () => { + expect( + getScaleBandwidth( + scaleBand({ + range: [0, 90], + domain: ['a', 'b', 'c'], + padding: 0, + }), + ), + ).toBe(30); + }); + }); });