diff --git a/__tests__/unit/plots/scatter/regression-line-spec.ts b/__tests__/unit/plots/scatter/regression-line-spec.ts new file mode 100644 index 0000000000..74d7bc70c0 --- /dev/null +++ b/__tests__/unit/plots/scatter/regression-line-spec.ts @@ -0,0 +1,110 @@ +import { Scatter } from '../../../../src'; +import { createDiv } from '../../../utils/dom'; + +const data = [ + { x: 1, y: 4.181 }, + { x: 2, y: 4.665 }, + { x: 3, y: 5.296 }, + { x: 4, y: 5.365 }, + { x: 5, y: 5.448 }, + { x: 6, y: 5.744 }, + { x: 7, y: 5.653 }, + { x: 8, y: 5.844 }, + { x: 9, y: 6.362 }, + { x: 10, y: 6.38 }, + { x: 11, y: 6.311 }, + { x: 12, y: 6.457 }, + { x: 13, y: 6.479 }, + { x: 14, y: 6.59 }, + { x: 15, y: 6.74 }, + { x: 16, y: 6.58 }, + { x: 17, y: 6.852 }, + { x: 18, y: 6.531 }, + { x: 19, y: 6.682 }, + { x: 20, y: 7.013 }, + { x: 21, y: 6.82 }, + { x: 22, y: 6.647 }, + { x: 23, y: 6.951 }, + { x: 24, y: 7.121 }, + { x: 25, y: 7.143 }, + { x: 26, y: 6.914 }, + { x: 27, y: 6.941 }, + { x: 28, y: 7.226 }, + { x: 29, y: 6.898 }, + { x: 30, y: 7.392 }, + { x: 31, y: 6.938 }, +]; + +describe('scatter', () => { + it('regressionLine: type', () => { + const scatter = new Scatter(createDiv('regressionLine'), { + data, + width: 400, + height: 300, + appendPadding: 10, + xField: 'x', + yField: 'y', + size: 5, + pointStyle: { + stroke: '#777777', + lineWidth: 1, + fill: '#5B8FF9', + }, + regressionLine: { + type: 'quad', // linear, exp, loess, log, poly, pow, quad + style: { + stroke: 'red', + }, + }, + }); + + scatter.render(); + + const geometry = scatter.chart.geometries[0]; + const annotation = scatter.chart.annotation(); + // @ts-ignore + const options = annotation.option[0]; + expect(options.type).toBe('shape'); + const elements = geometry.elements; + expect(elements.length).toBe(31); + // @ts-ignore + expect(elements[0].getModel().style.fill).toBe('#5B8FF9'); + expect(elements[0].getModel().size).toBe(5); + }); + + it('regressionLine: algorithm', () => { + const scatter = new Scatter(createDiv('regressionLine*algorithm'), { + data, + width: 400, + height: 300, + appendPadding: 10, + xField: 'x', + yField: 'y', + size: 5, + pointStyle: { + stroke: '#777777', + lineWidth: 1, + fill: '#5B8FF9', + }, + regressionLine: { + algorithm: [ + [0, 0], + [200, 200], + ], + }, + }); + + scatter.render(); + + const geometry = scatter.chart.geometries[0]; + const annotation = scatter.chart.annotation(); + // @ts-ignore + const options = annotation.option[0]; + expect(options.type).toBe('shape'); + const elements = geometry.elements; + expect(elements.length).toBe(31); + // @ts-ignore + expect(elements[0].getModel().style.fill).toBe('#5B8FF9'); + expect(elements[0].getModel().size).toBe(5); + }); +}); diff --git a/examples/scatter/basic/API.en.md b/examples/scatter/basic/API.en.md index af68f4ad83..2896bad95f 100644 --- a/examples/scatter/basic/API.en.md +++ b/examples/scatter/basic/API.en.md @@ -141,6 +141,42 @@ scatterPlot.render(); } ``` +#### regressionLine + +**可选**, _object_ + +功能描述: 设置回归线。 + +| 细分配置 | 类型 | 功能描述 | +| --------- | --------------------------------------------------- | ----------------------------------- | +| type | linear \| exp \| loess \| log\| poly \| pow \| quad | 回归线类型 | +| style | object | 回归线样式,参考绘图属性 | +| algorithm | [][] \| (data) => [][] | 自定义回归算法 返回值自定义数据数组 | + +```ts +regressionLine: { + algorithm: [ + [0, 0], + [200, 200], + ]; +} +``` + +#### quadrant + +**可选**, _object_ + +功能描述: 四象限组件。 + +| 细分配置 | 类型 | 功能描述 | +| ----------- | ------- | ---------------------------------- | +| xBaseline | number | x 方向上的象限分割基准线,默认为 0 | +| yBaseline | number | y 方向上的象限分割基准线,默认为 0 | +| lineStyle | object | 配置象限分割线的样式,参考绘图属性 | +| regionStyle | object | 象限样式,参考绘图属性 | +| lineStyle | object | 回归线样式,参考绘图属性 | +| labels | label[] | 象限文本配置,参考绘图属性 | + #### pointStyle ✨ **可选**, _object_ diff --git a/examples/scatter/basic/API.zh.md b/examples/scatter/basic/API.zh.md index 671bf0a1ac..03b25ed07a 100644 --- a/examples/scatter/basic/API.zh.md +++ b/examples/scatter/basic/API.zh.md @@ -142,6 +142,42 @@ scatterPlot.render(); } ``` +#### regressionLine + +**可选**, _object_ + +功能描述: 设置回归线。 + +| 细分配置 | 类型 | 功能描述 | +| --------- | --------------------------------------------------- | ----------------------------------- | +| type | linear \| exp \| loess \| log\| poly \| pow \| quad | 回归线类型 | +| style | object | 回归线样式,参考绘图属性 | +| algorithm | [][] \| (data) => [][] | 自定义回归算法 返回值自定义数据数组 | + +```ts +regressionLine: { + algorithm: [ + [0, 0], + [200, 200], + ]; +} +``` + +#### quadrant + +**可选**, _object_ + +功能描述: 四象限组件。 + +| 细分配置 | 类型 | 功能描述 | +| ----------- | ------- | ---------------------------------- | +| xBaseline | number | x 方向上的象限分割基准线,默认为 0 | +| yBaseline | number | y 方向上的象限分割基准线,默认为 0 | +| lineStyle | object | 配置象限分割线的样式,参考绘图属性 | +| regionStyle | object | 象限样式,参考绘图属性 | +| lineStyle | object | 回归线样式,参考绘图属性 | +| labels | label[] | 象限文本配置,参考绘图属性 | + #### pointStyle ✨ **可选**, _object_ diff --git a/examples/scatter/basic/demo/line.ts b/examples/scatter/basic/demo/line.ts new file mode 100644 index 0000000000..1ee2f17f66 --- /dev/null +++ b/examples/scatter/basic/demo/line.ts @@ -0,0 +1,50 @@ +import { Scatter } from '@antv/g2plot'; + +const data = [ + { x: 1, y: 4.181 }, + { x: 2, y: 4.665 }, + { x: 3, y: 5.296 }, + { x: 4, y: 5.365 }, + { x: 5, y: 5.448 }, + { x: 6, y: 5.744 }, + { x: 7, y: 5.653 }, + { x: 8, y: 5.844 }, + { x: 9, y: 6.362 }, + { x: 10, y: 6.38 }, + { x: 11, y: 6.311 }, + { x: 12, y: 6.457 }, + { x: 13, y: 6.479 }, + { x: 14, y: 6.59 }, + { x: 15, y: 6.74 }, + { x: 16, y: 6.58 }, + { x: 17, y: 6.852 }, + { x: 18, y: 6.531 }, + { x: 19, y: 6.682 }, + { x: 20, y: 7.013 }, + { x: 21, y: 6.82 }, + { x: 22, y: 6.647 }, + { x: 23, y: 6.951 }, + { x: 24, y: 7.121 }, + { x: 25, y: 7.143 }, + { x: 26, y: 6.914 }, + { x: 27, y: 6.941 }, + { x: 28, y: 7.226 }, + { x: 29, y: 6.898 }, + { x: 30, y: 7.392 }, + { x: 31, y: 6.938 }, +]; +const scatterPlot = new Scatter('container', { + data, + xField: 'x', + yField: 'y', + size: 5, + pointStyle: { + stroke: '#777777', + lineWidth: 1, + fill: '#5B8FF9', + }, + regressionLine: { + type: 'quad', // linear, exp, loess, log, poly, pow, quad + }, +}); +scatterPlot.render(); diff --git a/examples/scatter/basic/demo/meta.json b/examples/scatter/basic/demo/meta.json index 10c16a49c0..0108ef19e0 100644 --- a/examples/scatter/basic/demo/meta.json +++ b/examples/scatter/basic/demo/meta.json @@ -28,6 +28,14 @@ }, "screenshot": "https://gw.alipayobjects.com/mdn/rms_d314dd/afts/img/A*tdedT4uaPaYAAAAAAAAAAABkARQnAQ" }, + { + "filename": "line.ts", + "title": { + "zh": "散点图-回归线", + "en": "Bubble chart regression line" + }, + "screenshot": "https://gw.alipayobjects.com/mdn/rms_d314dd/afts/img/A*JWiDQIYm09AAAAAAAAAAAAAAARQnAQ" + }, { "filename": "axis-right.ts", "title": { diff --git a/package.json b/package.json index 9599c2cbf0..e602dead3b 100644 --- a/package.json +++ b/package.json @@ -57,6 +57,7 @@ "@antv/event-emitter": "^0.1.2", "@antv/g2": "^4.1.0-beta.14", "d3-hierarchy": "^2.0.0", + "d3-regression": "^1.3.5", "dayjs": "^1.8.36", "size-sensor": "^1.0.1", "tslib": "^1.13.0" diff --git a/src/plots/scatter/adaptor.ts b/src/plots/scatter/adaptor.ts index 0b271e4aa5..ef4d288e5c 100644 --- a/src/plots/scatter/adaptor.ts +++ b/src/plots/scatter/adaptor.ts @@ -4,6 +4,7 @@ import { flow } from '../../utils'; import { point } from '../../adaptor/geometries'; import { tooltip, interaction, animation, theme, scale, annotation } from '../../adaptor/common'; import { findGeometry, transformLabel } from '../../utils'; +import { regressionLine } from './annotations/path'; import { getQuadrantDefaultConfig } from './util'; import { ScatterOptions } from './types'; @@ -180,5 +181,17 @@ function scatterAnnotation(params: Params): Params) { // flow 的方式处理所有的配置到 G2 API - return flow(geometry, meta, axis, legend, tooltip, label, interaction, scatterAnnotation, animation, theme)(params); + return flow( + geometry, + meta, + axis, + legend, + tooltip, + label, + interaction, + scatterAnnotation, + animation, + theme, + regressionLine + )(params); } diff --git a/src/plots/scatter/annotations/path.ts b/src/plots/scatter/annotations/path.ts new file mode 100644 index 0000000000..06e0fbf005 --- /dev/null +++ b/src/plots/scatter/annotations/path.ts @@ -0,0 +1,135 @@ +import { + regressionLinear, + regressionExp, + regressionLoess, + regressionLog, + regressionPoly, + regressionPow, + regressionQuad, +} from 'd3-regression'; +import { IGroup, Scale } from '@antv/g2/lib/dependents'; +import { minBy, maxBy, isArray } from '@antv/util'; +import { getScale } from '@antv/scale'; +import { View } from '@antv/g2'; +import { getSplinePath } from '../../../utils'; +import { Params } from '../../../core/adaptor'; +import { ScatterOptions } from '../types'; + +const REGRESSION_MAP = { + exp: regressionExp, + linear: regressionLinear, + loess: regressionLoess, + log: regressionLog, + poly: regressionPoly, + pow: regressionPow, + quad: regressionQuad, +}; + +type renderOptions = { + view: View; + group: IGroup; + options: ScatterOptions; +}; + +type path = { + x: number; + y: number; +}; + +// 处理用户自行配置 min max的情况 +function adjustScale(viewScale: Scale, pathData: path[], dim: string, config: renderOptions) { + const { min, max } = viewScale; + const { + options: { data, xField, yField }, + } = config; + const field = dim === 'x' ? xField : yField; + const dataMin = minBy(data, field)[field]; + const dataMax = maxBy(data, field)[field]; + const minRatio = (min - dataMin) / (dataMax - dataMin); + const maxRatio = (max - dataMax) / (dataMax - dataMin); + const trendMin = minBy(pathData, dim)[dim]; + const trendMax = maxBy(pathData, dim)[dim]; + return { + min: trendMin + minRatio * (trendMax - trendMin), + max: trendMax + maxRatio * (trendMax - trendMin), + }; +} + +function getPath(data: number[][], config: renderOptions) { + const { + view, + options: { xField, yField }, + } = config; + const pathData = data.map((d: [number, number]) => ({ x: d[0], y: d[1] })); + const xScaleView = view.getScaleByField(xField); + const yScaleView = view.getScaleByField(yField); + const coordinate = view.getCoordinate(); + const linearScale = getScale('linear'); + const xRange = adjustScale(xScaleView, pathData, 'x', config); + const xScale = new linearScale({ + min: xRange.min, + max: xRange.max, + }); + const yRange = adjustScale(yScaleView, pathData, 'y', config); + const yScale = new linearScale({ + min: yRange.min, + max: yRange.max, + }); + const points = pathData.map((d) => ({ + x: coordinate.start.x + coordinate['width'] * xScale.scale(d.x), + y: coordinate.start.y - coordinate['height'] * yScale.scale(d.y), + })); + return getSplinePath(points, false); +} + +function renderPath(config: renderOptions) { + const { group, options } = config; + const { xField, yField, data, regressionLine } = options; + const { type = 'linear', style, algorithm } = regressionLine; + let pathData: Array<[number, number]>; + if (algorithm) { + pathData = isArray(algorithm) ? algorithm : algorithm(data); + } else { + const reg = REGRESSION_MAP[type]() + .x((d) => d[xField]) + .y((d) => d[yField]); + pathData = reg(data); + } + const path = getPath(pathData, config); + const defaultStyle = { + stroke: '#9ba29a', + lineWidth: 2, + opacity: 0.5, + }; + group.addShape('path', { + name: 'regression-line', + attrs: { + path, + ...defaultStyle, + ...style, + }, + }); +} + +// 使用 shape annotation 绘制回归线 +export function regressionLine(params: Params) { + const { options, chart } = params; + const { regressionLine } = options; + if (regressionLine) { + chart.annotation().shape({ + render: (container, view) => { + const group = container.addGroup({ + id: `${chart.id}-regression-line`, + name: 'regression-line-group', + }); + renderPath({ + view, + group, + options, + }); + }, + }); + } + + return params; +} diff --git a/src/plots/scatter/types.ts b/src/plots/scatter/types.ts index e799b22b54..7664c19b5e 100644 --- a/src/plots/scatter/types.ts +++ b/src/plots/scatter/types.ts @@ -13,7 +13,7 @@ interface Labels extends Omit { position?: AnnotationPosition; } -interface Quadrant { +interface QuadrantOptions { /** x 方向上的象限分割基准线,默认为 0 */ readonly xBaseline?: number; /** y 方向上的象限分割基准线,默认为 0 */ @@ -26,17 +26,13 @@ interface Quadrant { readonly labels?: Labels[]; } -interface TrendLine { - /** 是否显示 */ - readonly visible?: boolean; - /** 趋势线类型 */ +export interface RegressionLineOptions { + /** 回归线类型 */ readonly type?: string; - /** 配置趋势线样式 */ + /** 配置回归线样式 */ readonly style?: ShapeStyle; - /** 是否绘制置信区间曲线 */ - readonly showConfidence?: boolean; - /** 配置置信区间样式 */ - readonly confidenceStyle?: ShapeStyle; + /** 自定义算法 [[0,0],[100,100]] */ + readonly algorithm?: Array<[number, number]> | ((data: any) => Array<[number, number]>); } export interface ScatterOptions extends Options { @@ -59,7 +55,7 @@ export interface ScatterOptions extends Options { /** 点颜色映射对应的数据字段名 */ readonly colorField?: string; /** 四象限组件 */ - readonly quadrant?: Quadrant; - /** 趋势线组件,为图表添加回归曲线 */ - readonly trendLine?: TrendLine; + readonly quadrant?: QuadrantOptions; + /** 归曲线 */ + readonly regressionLine?: RegressionLineOptions; } diff --git a/src/utils/index.ts b/src/utils/index.ts index 9bc558f73f..dbc57a9157 100644 --- a/src/utils/index.ts +++ b/src/utils/index.ts @@ -6,3 +6,4 @@ export { getContainerSize } from './dom'; export { findGeometry, getAllElements } from './geometry'; export { findViewById } from './view'; export { transformLabel } from './label'; +export { getSplinePath } from './path'; diff --git a/src/utils/path.ts b/src/utils/path.ts new file mode 100644 index 0000000000..39b77824e5 --- /dev/null +++ b/src/utils/path.ts @@ -0,0 +1,164 @@ +import { vec2 } from '@antv/matrix-util'; +import { Position, Point } from '@antv/g2/lib/interface'; + +function points2path(points: Point[], isInCircle: boolean) { + const path = []; + if (points.length) { + path.push(['M', points[0].x, points[0].y]); + for (let i = 1, length = points.length; i < length; i += 1) { + const item = points[i]; + path.push(['L', item.x, item.y]); + } + + if (isInCircle) { + path.push(['Z']); + } + } + + return path; +} + +/** + * @ignore + * 计算光滑的贝塞尔曲线 + */ +export const smoothBezier = ( + points: Position[], + smooth: number, + isLoop: boolean, + constraint: Position[] +): Position[] => { + const cps = []; + + let prevPoint: Position; + let nextPoint: Position; + const hasConstraint = !!constraint; + let min: Position; + let max: Position; + if (hasConstraint) { + min = [Infinity, Infinity]; + max = [-Infinity, -Infinity]; + + for (let i = 0, l = points.length; i < l; i++) { + const point = points[i]; + min = vec2.min([0, 0], min, point) as [number, number]; + max = vec2.max([0, 0], max, point) as [number, number]; + } + min = vec2.min([0, 0], min, constraint[0]) as [number, number]; + max = vec2.max([0, 0], max, constraint[1]) as [number, number]; + } + + for (let i = 0, len = points.length; i < len; i++) { + const point = points[i]; + if (isLoop) { + prevPoint = points[i ? i - 1 : len - 1]; + nextPoint = points[(i + 1) % len]; + } else { + if (i === 0 || i === len - 1) { + cps.push(point); + continue; + } else { + prevPoint = points[i - 1]; + nextPoint = points[i + 1]; + } + } + let v: [number, number] = [0, 0]; + v = vec2.sub(v, nextPoint, prevPoint) as [number, number]; + v = vec2.scale(v, v, smooth) as [number, number]; + + let d0 = vec2.distance(point, prevPoint); + let d1 = vec2.distance(point, nextPoint); + + const sum = d0 + d1; + if (sum !== 0) { + d0 /= sum; + d1 /= sum; + } + + const v1 = vec2.scale([0, 0], v, -d0); + const v2 = vec2.scale([0, 0], v, d1); + + let cp0 = vec2.add([0, 0], point, v1); + let cp1 = vec2.add([0, 0], point, v2); + + if (hasConstraint) { + cp0 = vec2.max([0, 0], cp0, min); + cp0 = vec2.min([0, 0], cp0, max); + cp1 = vec2.max([0, 0], cp1, min); + cp1 = vec2.min([0, 0], cp1, max); + } + + cps.push(cp0); + cps.push(cp1); + } + + if (isLoop) { + cps.push(cps.shift()); + } + return cps; +}; + +/** + * @ignore + * 贝塞尔曲线 + */ +export function catmullRom2bezier(crp: number[], z: boolean, constraint: Position[]) { + const isLoop = !!z; + const pointList = []; + for (let i = 0, l = crp.length; i < l; i += 2) { + pointList.push([crp[i], crp[i + 1]]); + } + + const controlPointList = smoothBezier(pointList, 0.4, isLoop, constraint); + const len = pointList.length; + const d1 = []; + + let cp1: Position; + let cp2: Position; + let p: Position; + + for (let i = 0; i < len - 1; i++) { + cp1 = controlPointList[i * 2]; + cp2 = controlPointList[i * 2 + 1]; + p = pointList[i + 1]; + d1.push(['C', cp1[0], cp1[1], cp2[0], cp2[1], p[0], p[1]]); + } + + if (isLoop) { + cp1 = controlPointList[len]; + cp2 = controlPointList[len + 1]; + p = pointList[0]; + d1.push(['C', cp1[0], cp1[1], cp2[0], cp2[1], p[0], p[1]]); + } + return d1; +} + +/** + * @ignore + * 根据关键点获取限定了范围的平滑线 + */ +export function getSplinePath(points: Point[], isInCircle?: boolean, constaint?: Position[]) { + const data = []; + const first = points[0]; + let prePoint = null; + if (points.length <= 2) { + // 两点以内直接绘制成路径 + return points2path(points, isInCircle); + } + for (let i = 0, len = points.length; i < len; i++) { + const point = points[i]; + if (!prePoint || !(prePoint.x === point.x && prePoint.y === point.y)) { + data.push(point.x); + data.push(point.y); + prePoint = point; + } + } + const constraint = constaint || [ + // 范围 + [0, 0], + [1, 1], + ]; + const splinePath = catmullRom2bezier(data, isInCircle, constraint); + splinePath.unshift(['M', first.x, first.y]); + return splinePath; +}