diff --git a/matrix.d.ts b/matrix.d.ts index e57f8398..fee6bd7a 100644 --- a/matrix.d.ts +++ b/matrix.d.ts @@ -1,4 +1,4 @@ -type MaybeMatrix = AbstractMatrix | number[][]; +type MaybeMatrix = AbstractMatrix | ArrayLike>; type ScalarOrMatrix = number | MaybeMatrix; type MatrixDimension = 'row' | 'column'; @@ -59,21 +59,21 @@ export interface IVarianceOptions { } export interface IVarianceByOptions { unbiased?: boolean; - mean?: number[]; + mean?: ArrayLike; } export interface ICenterOptions { center?: number; } export interface ICenterByOptions { - center?: number[]; + center?: ArrayLike; } export interface IScaleOptions { scale?: number; } export interface IScaleByOptions { - scale?: number[]; + scale?: ArrayLike; } export interface ICovarianceOptions { @@ -138,7 +138,7 @@ export abstract class AbstractMatrix { static from1DArray( newRows: number, newColumns: number, - newData: number[], + newData: ArrayLike, ): Matrix; /** @@ -146,14 +146,14 @@ export abstract class AbstractMatrix { * @param newData - A 1D array containing data for the vector. * @returns The new matrix. */ - static rowVector(newData: number[]): Matrix; + static rowVector(newData: ArrayLike): Matrix; /** * Creates a column vector, a matrix with only one column. * @param newData - A 1D array containing data for the vector. * @returns The new matrix. */ - static columnVector(newData: number[]): Matrix; + static columnVector(newData: ArrayLike): Matrix; /** * Creates a matrix with the given dimensions. Values will be set to zero. @@ -219,12 +219,16 @@ export abstract class AbstractMatrix { * @param columns - Number of columns. Default: `rows`. * @returns - The new diagonal matrix. */ - static diag(data: number[], rows?: number, columns?: number): Matrix; + static diag(data: ArrayLike, rows?: number, columns?: number): Matrix; /** * Alias for {@link AbstractMatrix.diag}. */ - static diagonal(data: number[], rows?: number, columns?: number): Matrix; + static diagonal( + data: ArrayLike, + rows?: number, + columns?: number, + ): Matrix; /** * Returns a matrix whose elements are the minimum between `matrix1` and `matrix2`. @@ -379,7 +383,7 @@ export abstract class AbstractMatrix { * @param index - Row index. * @param array - Array or vector to set. */ - setRow(index: number, array: number[] | AbstractMatrix): this; + setRow(index: number, array: ArrayLike | AbstractMatrix): this; /** * Swap two rows. @@ -405,7 +409,7 @@ export abstract class AbstractMatrix { * @param index - Column index. * @param array - Array or vector to set. */ - setColumn(index: number, array: number[] | AbstractMatrix): this; + setColumn(index: number, array: ArrayLike | AbstractMatrix): this; /** * Swap two columns. @@ -418,49 +422,49 @@ export abstract class AbstractMatrix { * Adds the values of a vector to each row. * @param vector - Array or vector. */ - addRowVector(vector: number[] | AbstractMatrix): this; + addRowVector(vector: ArrayLike | AbstractMatrix): this; /** * Subtracts the values of a vector from each row. * @param vector - Array or vector. */ - subRowVector(vector: number[] | AbstractMatrix): this; + subRowVector(vector: ArrayLike | AbstractMatrix): this; /** * Multiplies the values of a vector with each row. * @param vector - Array or vector. */ - mulRowVector(vector: number[] | AbstractMatrix): this; + mulRowVector(vector: ArrayLike | AbstractMatrix): this; /** * Divides the values of each row by those of a vector. * @param vector - Array or vector. */ - divRowVector(vector: number[] | AbstractMatrix): this; + divRowVector(vector: ArrayLike | AbstractMatrix): this; /** * Adds the values of a vector to each column. * @param vector - Array or vector. */ - addColumnVector(vector: number[] | AbstractMatrix): this; + addColumnVector(vector: ArrayLike | AbstractMatrix): this; /** * Subtracts the values of a vector from each column. * @param vector - Array or vector. */ - subColumnVector(vector: number[] | AbstractMatrix): this; + subColumnVector(vector: ArrayLike | AbstractMatrix): this; /** * Multiplies the values of a vector with each column. * @param vector - Array or vector. */ - mulColumnVector(vector: number[] | AbstractMatrix): this; + mulColumnVector(vector: ArrayLike | AbstractMatrix): this; /** * Divides the values of each column by those of a vector. * @param vector - Array or vector. */ - divColumnVector(vector: number[] | AbstractMatrix): this; + divColumnVector(vector: ArrayLike | AbstractMatrix): this; /** * Multiplies the values of a row with a scalar. @@ -608,7 +612,7 @@ export abstract class AbstractMatrix { * @param other - Other matrix. */ kroneckerProduct(other: MaybeMatrix): Matrix; - + /** * Returns the Kronecker sum between `this` and `other`. * @link https://en.wikipedia.org/wiki/Kronecker_product#Kronecker_sum @@ -659,7 +663,7 @@ export abstract class AbstractMatrix { * @param endColumn - Last column index. Default: `this.columns - 1`. */ subMatrixRow( - indices: number[], + indices: ArrayLike, startColumn?: number, endColumn?: number, ): Matrix; @@ -671,7 +675,7 @@ export abstract class AbstractMatrix { * @param endRow - Last row index. Default: `this.rows - 1`. */ subMatrixColumn( - indices: number[], + indices: ArrayLike, startRow?: number, endRow?: number, ): Matrix; @@ -683,7 +687,7 @@ export abstract class AbstractMatrix { * @param startColumn - The index of the first column to set. */ setSubMatrix( - matrix: MaybeMatrix | number[], + matrix: MaybeMatrix, startRow: number, startColumn: number, ): this; @@ -694,7 +698,10 @@ export abstract class AbstractMatrix { * @param rowIndices - The row indices to select. * @param columnIndices - The column indices to select. */ - selection(rowIndices: number[], columnIndices: number[]): Matrix; + selection( + rowIndices: ArrayLike, + columnIndices: ArrayLike, + ): Matrix; /** * Returns the trace of the matrix (sum of the diagonal elements). @@ -912,7 +919,7 @@ export abstract class AbstractMatrix { export class Matrix extends AbstractMatrix { constructor(nRows: number, nColumns: number); - constructor(data: number[][]); + constructor(data: ArrayLike>); constructor(otherMatrix: AbstractMatrix); /** @@ -932,14 +939,14 @@ export class Matrix extends AbstractMatrix { * @param index - Column index. Default: `this.columns`. * @param array - Column to add. */ - addColumn(index: number, array: number[] | AbstractMatrix): this; + addColumn(index: number, array: ArrayLike | AbstractMatrix): this; /** * Adds a new row to the matrix (in place). * @param index - Row index. Default: `this.rows`. * @param array - Row to add. */ - addRow(index: number, array: number[] | AbstractMatrix): this; + addRow(index: number, array: ArrayLike | AbstractMatrix): this; } export default Matrix; @@ -949,7 +956,7 @@ export class MatrixColumnView extends AbstractMatrix { } export class MatrixColumnSelectionView extends AbstractMatrix { - constructor(matrix: AbstractMatrix, columnIndices: number[]); + constructor(matrix: AbstractMatrix, columnIndices: ArrayLike); } export class MatrixFlipColumnView extends AbstractMatrix { @@ -965,14 +972,14 @@ export class MatrixRowView extends AbstractMatrix { } export class MatrixRowSelectionView extends AbstractMatrix { - constructor(matrix: AbstractMatrix, rowIndices: number[]); + constructor(matrix: AbstractMatrix, rowIndices: ArrayLike); } export class MatrixSelectionView extends AbstractMatrix { constructor( matrix: AbstractMatrix, - rowIndices: number[], - columnIndices: number[], + rowIndices: ArrayLike, + columnIndices: ArrayLike, ); } @@ -998,18 +1005,18 @@ export interface IWrap1DOptions { } export function wrap( - array: number[], + array: ArrayLike, options?: IWrap1DOptions, ): WrapperMatrix1D; -export function wrap(twoDAray: number[][]): WrapperMatrix2D; +export function wrap(twoDAray: ArrayLike>): WrapperMatrix2D; export class WrapperMatrix1D extends AbstractMatrix { - constructor(data: number[], options?: IWrap1DOptions); + constructor(data: ArrayLike, options?: IWrap1DOptions); } export class WrapperMatrix2D extends AbstractMatrix { - constructor(data: number[][]); + constructor(data: ArrayLike>); } /** @@ -1133,7 +1140,7 @@ export class SingularValueDecomposition { * @returns - The vector x. */ solve(value: Matrix): Matrix; - solveForDiagonal(value: number[]): Matrix; + solveForDiagonal(value: ArrayLike): Matrix; readonly norm2: number; readonly threshold: number; readonly leftSingularVectors: Matrix; @@ -1223,7 +1230,7 @@ export interface INipalsOptions { /** * A column vector of length `X.rows` that contains known labels for supervised PLS. */ - Y?: MaybeMatrix | number[]; + Y?: MaybeMatrix | ArrayLike; /** * The maximum number of allowed iterations before beraking the loop if convergence is not achieved. * @default 1000 diff --git a/package.json b/package.json index d9b34d6e..847aa17c 100644 --- a/package.json +++ b/package.json @@ -73,6 +73,7 @@ "rollup-plugin-terser": "^7.0.2" }, "dependencies": { + "is-any-array": "^2.0.0", "ml-array-rescale": "^1.3.7" } } diff --git a/src/__tests__/matrix/creation.js b/src/__tests__/matrix/creation.js index 3d70531c..24c06257 100644 --- a/src/__tests__/matrix/creation.js +++ b/src/__tests__/matrix/creation.js @@ -11,6 +11,20 @@ describe('Matrix creation', () => { expect(Matrix.isMatrix(matrix)).toBe(true); }); + it('should work with a typed array', () => { + const array = [ + Float64Array.of(1, 2, 3), + Float64Array.of(4, 5, 6), + Float64Array.of(7, 8, 9), + ]; + const matrix = new Matrix(array); + expect(matrix.to2DArray()).toStrictEqual([ + [1, 2, 3], + [4, 5, 6], + [7, 8, 9], + ]); + }); + it('should clone existing matrix', () => { let original = util.getSquareMatrix(); let matrix = new Matrix(original); diff --git a/src/__tests__/views/columnSelection.js b/src/__tests__/views/columnSelection.js index 6b736521..e21a3733 100644 --- a/src/__tests__/views/columnSelection.js +++ b/src/__tests__/views/columnSelection.js @@ -17,7 +17,7 @@ describe('Selection column view', () => { 'column indices are out of range', ); expect(() => new MatrixColumnSelectionView(m, 1)).toThrow( - 'unexpected type for column indices', + 'column indices must be an array', ); }); }); diff --git a/src/__tests__/views/rowSelection.js b/src/__tests__/views/rowSelection.js index 9f2b2854..9cb8a9a4 100644 --- a/src/__tests__/views/rowSelection.js +++ b/src/__tests__/views/rowSelection.js @@ -17,7 +17,7 @@ describe('Selection view', () => { 'row indices are out of range', ); expect(() => new MatrixRowSelectionView(m, 1)).toThrow( - 'unexpected type for row indices', + 'row indices must be an array', ); }); }); diff --git a/src/correlation.js b/src/correlation.js index 8d32c068..b54c34c5 100644 --- a/src/correlation.js +++ b/src/correlation.js @@ -1,3 +1,5 @@ +import { isAnyArray } from 'is-any-array'; + import Matrix from './matrix'; export function correlation(xMatrix, yMatrix = xMatrix, options = {}) { @@ -6,7 +8,7 @@ export function correlation(xMatrix, yMatrix = xMatrix, options = {}) { if ( typeof yMatrix === 'object' && !Matrix.isMatrix(yMatrix) && - !Array.isArray(yMatrix) + !isAnyArray(yMatrix) ) { options = yMatrix; yMatrix = xMatrix; diff --git a/src/covariance.js b/src/covariance.js index e7bad437..cef262c2 100644 --- a/src/covariance.js +++ b/src/covariance.js @@ -1,3 +1,5 @@ +import { isAnyArray } from 'is-any-array'; + import Matrix from './matrix'; export function covariance(xMatrix, yMatrix = xMatrix, options = {}) { @@ -6,7 +8,7 @@ export function covariance(xMatrix, yMatrix = xMatrix, options = {}) { if ( typeof yMatrix === 'object' && !Matrix.isMatrix(yMatrix) && - !Array.isArray(yMatrix) + !isAnyArray(yMatrix) ) { options = yMatrix; yMatrix = xMatrix; diff --git a/src/dc/nipals.js b/src/dc/nipals.js index 9894b1b1..8fd440af 100644 --- a/src/dc/nipals.js +++ b/src/dc/nipals.js @@ -1,3 +1,5 @@ +import { isAnyArray } from 'is-any-array'; + import Matrix from '../matrix'; import WrapperMatrix2D from '../wrap/WrapperMatrix2D'; @@ -13,7 +15,7 @@ export default class nipals { let u; if (Y) { - if (Array.isArray(Y) && typeof Y[0] === 'number') { + if (isAnyArray(Y) && typeof Y[0] === 'number') { Y = Matrix.columnVector(Y); } else { Y = WrapperMatrix2D.checkMatrix(Y); diff --git a/src/matrix.js b/src/matrix.js index 42835269..15f1d66f 100644 --- a/src/matrix.js +++ b/src/matrix.js @@ -1,3 +1,4 @@ +import { isAnyArray } from 'is-any-array'; import rescale from 'ml-array-rescale'; import { inspectMatrix, inspectMatrixWithOptions } from './inspect'; @@ -28,8 +29,9 @@ import { checkColumnIndex, checkColumnVector, checkRange, - checkIndices, checkNonEmpty, + checkRowIndices, + checkColumnIndices, } from './util'; export class AbstractMatrix { @@ -1237,12 +1239,13 @@ export class AbstractMatrix { } selection(rowIndices, columnIndices) { - let indices = checkIndices(this, rowIndices, columnIndices); + checkRowIndices(this, rowIndices); + checkColumnIndices(this, columnIndices); let newMatrix = new Matrix(rowIndices.length, columnIndices.length); - for (let i = 0; i < indices.row.length; i++) { - let rowIndex = indices.row[i]; - for (let j = 0; j < indices.column.length; j++) { - let columnIndex = indices.column[j]; + for (let i = 0; i < rowIndices.length; i++) { + let rowIndex = rowIndices[i]; + for (let j = 0; j < columnIndices.length; j++) { + let columnIndex = columnIndices[j]; newMatrix.set(i, j, this.get(rowIndex, columnIndex)); } } @@ -1330,13 +1333,13 @@ export class AbstractMatrix { } switch (by) { case 'row': { - if (!Array.isArray(mean)) { + if (!isAnyArray(mean)) { throw new TypeError('mean must be an array'); } return varianceByRow(this, unbiased, mean); } case 'column': { - if (!Array.isArray(mean)) { + if (!isAnyArray(mean)) { throw new TypeError('mean must be an array'); } return varianceByColumn(this, unbiased, mean); @@ -1379,14 +1382,14 @@ export class AbstractMatrix { const { center = this.mean(by) } = options; switch (by) { case 'row': { - if (!Array.isArray(center)) { + if (!isAnyArray(center)) { throw new TypeError('center must be an array'); } centerByRow(this, center); return this; } case 'column': { - if (!Array.isArray(center)) { + if (!isAnyArray(center)) { throw new TypeError('center must be an array'); } centerByColumn(this, center); @@ -1417,7 +1420,7 @@ export class AbstractMatrix { case 'row': { if (scale === undefined) { scale = getScaleByRow(this); - } else if (!Array.isArray(scale)) { + } else if (!isAnyArray(scale)) { throw new TypeError('scale must be an array'); } scaleByRow(this, scale); @@ -1426,7 +1429,7 @@ export class AbstractMatrix { case 'column': { if (scale === undefined) { scale = getScaleByColumn(this); - } else if (!Array.isArray(scale)) { + } else if (!isAnyArray(scale)) { throw new TypeError('scale must be an array'); } scaleByColumn(this, scale); @@ -1487,7 +1490,7 @@ export default class Matrix extends AbstractMatrix { } else { throw new TypeError('nColumns must be a positive integer'); } - } else if (Array.isArray(nRows)) { + } else if (isAnyArray(nRows)) { // Copy the values from the 2D array const arrayData = nRows; nRows = arrayData.length; diff --git a/src/util.js b/src/util.js index c70d99b6..2ac556b2 100644 --- a/src/util.js +++ b/src/util.js @@ -1,3 +1,5 @@ +import { isAnyArray } from 'is-any-array'; + /** * @private * Check that a row index is not out of bounds @@ -64,46 +66,28 @@ export function checkColumnVector(matrix, vector) { return vector; } -export function checkIndices(matrix, rowIndices, columnIndices) { - return { - row: checkRowIndices(matrix, rowIndices), - column: checkColumnIndices(matrix, columnIndices), - }; -} - export function checkRowIndices(matrix, rowIndices) { - if (typeof rowIndices !== 'object') { - throw new TypeError('unexpected type for row indices'); + if (!isAnyArray(rowIndices)) { + throw new TypeError('row indices must be an array'); } - let rowOut = rowIndices.some((r) => { - return r < 0 || r >= matrix.rows; - }); - - if (rowOut) { - throw new RangeError('row indices are out of range'); + for (let i = 0; i < rowIndices.length; i++) { + if (rowIndices[i] < 0 || rowIndices[i] >= matrix.rows) { + throw new RangeError('row indices are out of range'); + } } - - if (!Array.isArray(rowIndices)) rowIndices = Array.from(rowIndices); - - return rowIndices; } export function checkColumnIndices(matrix, columnIndices) { - if (typeof columnIndices !== 'object') { - throw new TypeError('unexpected type for column indices'); + if (!isAnyArray(columnIndices)) { + throw new TypeError('column indices must be an array'); } - let columnOut = columnIndices.some((c) => { - return c < 0 || c >= matrix.columns; - }); - - if (columnOut) { - throw new RangeError('column indices are out of range'); + for (let i = 0; i < columnIndices.length; i++) { + if (columnIndices[i] < 0 || columnIndices[i] >= matrix.columns) { + throw new RangeError('column indices are out of range'); + } } - if (!Array.isArray(columnIndices)) columnIndices = Array.from(columnIndices); - - return columnIndices; } export function checkRange(matrix, startRow, endRow, startColumn, endColumn) { diff --git a/src/views/columnSelection.js b/src/views/columnSelection.js index fb515e29..76675aec 100644 --- a/src/views/columnSelection.js +++ b/src/views/columnSelection.js @@ -4,7 +4,7 @@ import BaseView from './base'; export default class MatrixColumnSelectionView extends BaseView { constructor(matrix, columnIndices) { - columnIndices = checkColumnIndices(matrix, columnIndices); + checkColumnIndices(matrix, columnIndices); super(matrix, matrix.rows, columnIndices.length); this.columnIndices = columnIndices; } diff --git a/src/views/rowSelection.js b/src/views/rowSelection.js index 42dfd46d..ff32f7ce 100644 --- a/src/views/rowSelection.js +++ b/src/views/rowSelection.js @@ -4,7 +4,7 @@ import BaseView from './base'; export default class MatrixRowSelectionView extends BaseView { constructor(matrix, rowIndices) { - rowIndices = checkRowIndices(matrix, rowIndices); + checkRowIndices(matrix, rowIndices); super(matrix, rowIndices.length, matrix.columns); this.rowIndices = rowIndices; } diff --git a/src/views/selection.js b/src/views/selection.js index c10d9200..679b412a 100644 --- a/src/views/selection.js +++ b/src/views/selection.js @@ -1,13 +1,14 @@ -import { checkIndices } from '../util'; +import { checkRowIndices, checkColumnIndices } from '../util'; import BaseView from './base'; export default class MatrixSelectionView extends BaseView { constructor(matrix, rowIndices, columnIndices) { - let indices = checkIndices(matrix, rowIndices, columnIndices); - super(matrix, indices.row.length, indices.column.length); - this.rowIndices = indices.row; - this.columnIndices = indices.column; + checkRowIndices(matrix, rowIndices); + checkColumnIndices(matrix, columnIndices); + super(matrix, rowIndices.length, columnIndices.length); + this.rowIndices = rowIndices; + this.columnIndices = columnIndices; } set(rowIndex, columnIndex, value) { diff --git a/src/wrap/wrap.js b/src/wrap/wrap.js index dc68bee6..83de8faa 100644 --- a/src/wrap/wrap.js +++ b/src/wrap/wrap.js @@ -1,9 +1,11 @@ +import { isAnyArray } from 'is-any-array'; + import WrapperMatrix1D from './WrapperMatrix1D'; import WrapperMatrix2D from './WrapperMatrix2D'; export function wrap(array, options) { - if (Array.isArray(array)) { - if (array[0] && Array.isArray(array[0])) { + if (isAnyArray(array)) { + if (array[0] && isAnyArray(array[0])) { return new WrapperMatrix2D(array); } else { return new WrapperMatrix1D(array, options);