Skip to content

Commit

Permalink
feat: allow to pass any array in all APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
targos committed Feb 26, 2022
1 parent 00794f4 commit 362d8a1
Show file tree
Hide file tree
Showing 14 changed files with 112 additions and 94 deletions.
81 changes: 44 additions & 37 deletions matrix.d.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
type MaybeMatrix = AbstractMatrix | number[][];
type MaybeMatrix = AbstractMatrix | ArrayLike<ArrayLike<number>>;
type ScalarOrMatrix = number | MaybeMatrix;
type MatrixDimension = 'row' | 'column';

Expand Down Expand Up @@ -59,21 +59,21 @@ export interface IVarianceOptions {
}
export interface IVarianceByOptions {
unbiased?: boolean;
mean?: number[];
mean?: ArrayLike<number>;
}

export interface ICenterOptions {
center?: number;
}
export interface ICenterByOptions {
center?: number[];
center?: ArrayLike<number>;
}

export interface IScaleOptions {
scale?: number;
}
export interface IScaleByOptions {
scale?: number[];
scale?: ArrayLike<number>;
}

export interface ICovarianceOptions {
Expand Down Expand Up @@ -138,22 +138,22 @@ export abstract class AbstractMatrix {
static from1DArray(
newRows: number,
newColumns: number,
newData: number[],
newData: ArrayLike<number>,
): Matrix;

/**
* Creates a row vector, a matrix with only one row.
* @param newData - A 1D array containing data for the vector.
* @returns The new matrix.
*/
static rowVector(newData: number[]): Matrix;
static rowVector(newData: ArrayLike<number>): Matrix;

/**
* Creates a column vector, a matrix with only one column.
* @param newData - A 1D array containing data for the vector.
* @returns The new matrix.
*/
static columnVector(newData: number[]): Matrix;
static columnVector(newData: ArrayLike<number>): Matrix;

/**
* Creates a matrix with the given dimensions. Values will be set to zero.
Expand Down Expand Up @@ -219,12 +219,16 @@ export abstract class AbstractMatrix {
* @param columns - Number of columns. Default: `rows`.
* @returns - The new diagonal matrix.
*/
static diag(data: number[], rows?: number, columns?: number): Matrix;
static diag(data: ArrayLike<number>, rows?: number, columns?: number): Matrix;

/**
* Alias for {@link AbstractMatrix.diag}.
*/
static diagonal(data: number[], rows?: number, columns?: number): Matrix;
static diagonal(
data: ArrayLike<number>,
rows?: number,
columns?: number,
): Matrix;

/**
* Returns a matrix whose elements are the minimum between `matrix1` and `matrix2`.
Expand Down Expand Up @@ -379,7 +383,7 @@ export abstract class AbstractMatrix {
* @param index - Row index.
* @param array - Array or vector to set.
*/
setRow(index: number, array: number[] | AbstractMatrix): this;
setRow(index: number, array: ArrayLike<number> | AbstractMatrix): this;

/**
* Swap two rows.
Expand All @@ -405,7 +409,7 @@ export abstract class AbstractMatrix {
* @param index - Column index.
* @param array - Array or vector to set.
*/
setColumn(index: number, array: number[] | AbstractMatrix): this;
setColumn(index: number, array: ArrayLike<number> | AbstractMatrix): this;

/**
* Swap two columns.
Expand All @@ -418,49 +422,49 @@ export abstract class AbstractMatrix {
* Adds the values of a vector to each row.
* @param vector - Array or vector.
*/
addRowVector(vector: number[] | AbstractMatrix): this;
addRowVector(vector: ArrayLike<number> | AbstractMatrix): this;

/**
* Subtracts the values of a vector from each row.
* @param vector - Array or vector.
*/
subRowVector(vector: number[] | AbstractMatrix): this;
subRowVector(vector: ArrayLike<number> | AbstractMatrix): this;

/**
* Multiplies the values of a vector with each row.
* @param vector - Array or vector.
*/
mulRowVector(vector: number[] | AbstractMatrix): this;
mulRowVector(vector: ArrayLike<number> | AbstractMatrix): this;

/**
* Divides the values of each row by those of a vector.
* @param vector - Array or vector.
*/
divRowVector(vector: number[] | AbstractMatrix): this;
divRowVector(vector: ArrayLike<number> | AbstractMatrix): this;

/**
* Adds the values of a vector to each column.
* @param vector - Array or vector.
*/
addColumnVector(vector: number[] | AbstractMatrix): this;
addColumnVector(vector: ArrayLike<number> | AbstractMatrix): this;

/**
* Subtracts the values of a vector from each column.
* @param vector - Array or vector.
*/
subColumnVector(vector: number[] | AbstractMatrix): this;
subColumnVector(vector: ArrayLike<number> | AbstractMatrix): this;

/**
* Multiplies the values of a vector with each column.
* @param vector - Array or vector.
*/
mulColumnVector(vector: number[] | AbstractMatrix): this;
mulColumnVector(vector: ArrayLike<number> | AbstractMatrix): this;

/**
* Divides the values of each column by those of a vector.
* @param vector - Array or vector.
*/
divColumnVector(vector: number[] | AbstractMatrix): this;
divColumnVector(vector: ArrayLike<number> | AbstractMatrix): this;

/**
* Multiplies the values of a row with a scalar.
Expand Down Expand Up @@ -608,7 +612,7 @@ export abstract class AbstractMatrix {
* @param other - Other matrix.
*/
kroneckerProduct(other: MaybeMatrix): Matrix;

/**
* Returns the Kronecker sum between `this` and `other`.
* @link https://en.wikipedia.org/wiki/Kronecker_product#Kronecker_sum
Expand Down Expand Up @@ -659,7 +663,7 @@ export abstract class AbstractMatrix {
* @param endColumn - Last column index. Default: `this.columns - 1`.
*/
subMatrixRow(
indices: number[],
indices: ArrayLike<number>,
startColumn?: number,
endColumn?: number,
): Matrix;
Expand All @@ -671,7 +675,7 @@ export abstract class AbstractMatrix {
* @param endRow - Last row index. Default: `this.rows - 1`.
*/
subMatrixColumn(
indices: number[],
indices: ArrayLike<number>,
startRow?: number,
endRow?: number,
): Matrix;
Expand All @@ -683,7 +687,7 @@ export abstract class AbstractMatrix {
* @param startColumn - The index of the first column to set.
*/
setSubMatrix(
matrix: MaybeMatrix | number[],
matrix: MaybeMatrix,
startRow: number,
startColumn: number,
): this;
Expand All @@ -694,7 +698,10 @@ export abstract class AbstractMatrix {
* @param rowIndices - The row indices to select.
* @param columnIndices - The column indices to select.
*/
selection(rowIndices: number[], columnIndices: number[]): Matrix;
selection(
rowIndices: ArrayLike<number>,
columnIndices: ArrayLike<number>,
): Matrix;

/**
* Returns the trace of the matrix (sum of the diagonal elements).
Expand Down Expand Up @@ -912,7 +919,7 @@ export abstract class AbstractMatrix {

export class Matrix extends AbstractMatrix {
constructor(nRows: number, nColumns: number);
constructor(data: number[][]);
constructor(data: ArrayLike<ArrayLike<number>>);
constructor(otherMatrix: AbstractMatrix);

/**
Expand All @@ -932,14 +939,14 @@ export class Matrix extends AbstractMatrix {
* @param index - Column index. Default: `this.columns`.
* @param array - Column to add.
*/
addColumn(index: number, array: number[] | AbstractMatrix): this;
addColumn(index: number, array: ArrayLike<number> | AbstractMatrix): this;

/**
* Adds a new row to the matrix (in place).
* @param index - Row index. Default: `this.rows`.
* @param array - Row to add.
*/
addRow(index: number, array: number[] | AbstractMatrix): this;
addRow(index: number, array: ArrayLike<number> | AbstractMatrix): this;
}

export default Matrix;
Expand All @@ -949,7 +956,7 @@ export class MatrixColumnView extends AbstractMatrix {
}

export class MatrixColumnSelectionView extends AbstractMatrix {
constructor(matrix: AbstractMatrix, columnIndices: number[]);
constructor(matrix: AbstractMatrix, columnIndices: ArrayLike<number>);
}

export class MatrixFlipColumnView extends AbstractMatrix {
Expand All @@ -965,14 +972,14 @@ export class MatrixRowView extends AbstractMatrix {
}

export class MatrixRowSelectionView extends AbstractMatrix {
constructor(matrix: AbstractMatrix, rowIndices: number[]);
constructor(matrix: AbstractMatrix, rowIndices: ArrayLike<number>);
}

export class MatrixSelectionView extends AbstractMatrix {
constructor(
matrix: AbstractMatrix,
rowIndices: number[],
columnIndices: number[],
rowIndices: ArrayLike<number>,
columnIndices: ArrayLike<number>,
);
}

Expand All @@ -998,18 +1005,18 @@ export interface IWrap1DOptions {
}

export function wrap(
array: number[],
array: ArrayLike<number>,
options?: IWrap1DOptions,
): WrapperMatrix1D;

export function wrap(twoDAray: number[][]): WrapperMatrix2D;
export function wrap(twoDAray: ArrayLike<ArrayLike<number>>): WrapperMatrix2D;

export class WrapperMatrix1D extends AbstractMatrix {
constructor(data: number[], options?: IWrap1DOptions);
constructor(data: ArrayLike<number>, options?: IWrap1DOptions);
}

export class WrapperMatrix2D extends AbstractMatrix {
constructor(data: number[][]);
constructor(data: ArrayLike<ArrayLike<number>>);
}

/**
Expand Down Expand Up @@ -1133,7 +1140,7 @@ export class SingularValueDecomposition {
* @returns - The vector x.
*/
solve(value: Matrix): Matrix;
solveForDiagonal(value: number[]): Matrix;
solveForDiagonal(value: ArrayLike<number>): Matrix;
readonly norm2: number;
readonly threshold: number;
readonly leftSingularVectors: Matrix;
Expand Down Expand Up @@ -1223,7 +1230,7 @@ export interface INipalsOptions {
/**
* A column vector of length `X.rows` that contains known labels for supervised PLS.
*/
Y?: MaybeMatrix | number[];
Y?: MaybeMatrix | ArrayLike<number>;
/**
* The maximum number of allowed iterations before beraking the loop if convergence is not achieved.
* @default 1000
Expand Down
1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
"rollup-plugin-terser": "^7.0.2"
},
"dependencies": {
"is-any-array": "^2.0.0",
"ml-array-rescale": "^1.3.7"
}
}
14 changes: 14 additions & 0 deletions src/__tests__/matrix/creation.js
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,20 @@ describe('Matrix creation', () => {
expect(Matrix.isMatrix(matrix)).toBe(true);
});

it('should work with a typed array', () => {
const array = [
Float64Array.of(1, 2, 3),
Float64Array.of(4, 5, 6),
Float64Array.of(7, 8, 9),
];
const matrix = new Matrix(array);
expect(matrix.to2DArray()).toStrictEqual([
[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
]);
});

it('should clone existing matrix', () => {
let original = util.getSquareMatrix();
let matrix = new Matrix(original);
Expand Down
2 changes: 1 addition & 1 deletion src/__tests__/views/columnSelection.js
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ describe('Selection column view', () => {
'column indices are out of range',
);
expect(() => new MatrixColumnSelectionView(m, 1)).toThrow(
'unexpected type for column indices',
'column indices must be an array',
);
});
});
2 changes: 1 addition & 1 deletion src/__tests__/views/rowSelection.js
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ describe('Selection view', () => {
'row indices are out of range',
);
expect(() => new MatrixRowSelectionView(m, 1)).toThrow(
'unexpected type for row indices',
'row indices must be an array',
);
});
});
4 changes: 3 additions & 1 deletion src/correlation.js
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import { isAnyArray } from 'is-any-array';

import Matrix from './matrix';

export function correlation(xMatrix, yMatrix = xMatrix, options = {}) {
Expand All @@ -6,7 +8,7 @@ export function correlation(xMatrix, yMatrix = xMatrix, options = {}) {
if (
typeof yMatrix === 'object' &&
!Matrix.isMatrix(yMatrix) &&
!Array.isArray(yMatrix)
!isAnyArray(yMatrix)
) {
options = yMatrix;
yMatrix = xMatrix;
Expand Down
4 changes: 3 additions & 1 deletion src/covariance.js
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import { isAnyArray } from 'is-any-array';

import Matrix from './matrix';

export function covariance(xMatrix, yMatrix = xMatrix, options = {}) {
Expand All @@ -6,7 +8,7 @@ export function covariance(xMatrix, yMatrix = xMatrix, options = {}) {
if (
typeof yMatrix === 'object' &&
!Matrix.isMatrix(yMatrix) &&
!Array.isArray(yMatrix)
!isAnyArray(yMatrix)
) {
options = yMatrix;
yMatrix = xMatrix;
Expand Down
4 changes: 3 additions & 1 deletion src/dc/nipals.js
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import { isAnyArray } from 'is-any-array';

import Matrix from '../matrix';
import WrapperMatrix2D from '../wrap/WrapperMatrix2D';

Expand All @@ -13,7 +15,7 @@ export default class nipals {

let u;
if (Y) {
if (Array.isArray(Y) && typeof Y[0] === 'number') {
if (isAnyArray(Y) && typeof Y[0] === 'number') {
Y = Matrix.columnVector(Y);
} else {
Y = WrapperMatrix2D.checkMatrix(Y);
Expand Down
Loading

0 comments on commit 362d8a1

Please sign in to comment.