Skip to content

Commit

Permalink
Make predictions use fewer bytes when possible.
Browse files Browse the repository at this point in the history
Under the default precision of 16 we could have used
Uint16Array to save memory, but until now we had not proved
that doing so is safe. Now that it has been proved,
we can safely increase contextBits in the default setting.
  • Loading branch information
lifthrasiir committed Aug 19, 2021
1 parent ef4d2ba commit 9b1082e
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 26 deletions.
3 changes: 1 addition & 2 deletions index.d.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
/**
 * Pool that hands out `ArrayBuffer`s (and typed arrays backed by them) on
 * behalf of a `parent` object, and accepts them back through `release`.
 */
export class ArrayBufferPool {
constructor();
/**
 * Returns an `ArrayBuffer` associated with `parent`.
 * NOTE(review): `size` appears to be a byte length (typed-array callers pass
 * `length * bytesPerElement`) — confirm against the implementation.
 */
allocate(parent: object, size: number): ArrayBuffer;
/** Returns a `Uint32Array` of `length` elements backed by a buffer from this pool. */
newUint32Array(parent: object, length: number): Uint32Array;
/** Returns a `Uint8Array` of `length` elements backed by a buffer from this pool. */
newUint8Array(parent: object, length: number): Uint8Array;
/**
 * Gives `buf` back to the pool.
 * NOTE(review): presumably `buf` must not be used by the caller afterwards — confirm.
 */
release(buf: ArrayBuffer): void;
}

Expand Down Expand Up @@ -173,6 +171,7 @@ export interface PackerOptions {

export class Packer {
constructor(inputs: Input[], options: PackerOptions);
readonly memoryUsageMB: number;
makeDecoder(): {
firstLine: string;
firstLineLengthInBytes: number;
Expand Down
77 changes: 55 additions & 22 deletions index.js
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,6 @@ export class ArrayBufferPool {
return buf;
}

newUint32Array(parent, length) {
return new Uint32Array(this.allocate(parent, length * 4));
}

newUint8Array(parent, length) {
return new Uint8Array(this.allocate(parent, length));
}

// FinalizationRegistry is also possible, but GC couldn't keep up with the memory usage
release(buf) {
let available = this.pool.get(buf.byteLength);
Expand All @@ -50,6 +42,14 @@ export class ArrayBufferPool {
}
}

const newUintArray = (pool, parent, nbits, length) => {
if (nbits <= 8) return new Uint8Array(pool ? pool.allocate(parent, length) : length);
if (nbits <= 16) return new Uint16Array(pool ? pool.allocate(parent, length * 2) : length);
if (nbits <= 32) return new Uint32Array(pool ? pool.allocate(parent, length * 4) : length);
if (nbits <= 64) return new Uint64Array(pool ? pool.allocate(parent, length * 8) : length);
throw 'newUintArray: nbits is too large';
};

//------------------------------------------------------------------------------

// Returns 2^(28 - outBits), computed with a 32-bit shift.
// NOTE(review): presumably the lower bound "L" of the ANS coder state for the
// given number of output bits — confirm against the encoder.
const getAnsL = outBits => {
    const shift = 28 - outBits;
    return 1 << shift;
};
Expand Down Expand Up @@ -184,13 +184,8 @@ export class DirectContextModel {
this.modelMaxCount = modelMaxCount;

this.arrayBufferPool = arrayBufferPool;
if (arrayBufferPool) {
this.predictions = arrayBufferPool.newUint32Array(this, 1 << contextBits);
this.counts = arrayBufferPool.newUint8Array(this, 1 << contextBits);
} else {
this.predictions = new Uint32Array(1 << contextBits);
this.counts = new Uint8Array(1 << contextBits)
}
this.predictions = newUintArray(arrayBufferPool, this, precision, 1 << contextBits);
this.counts = newUintArray(arrayBufferPool, this, Math.ceil(Math.log2(modelMaxCount)), 1 << contextBits);
this.predictions.fill(1 << (precision - 1));
this.counts.fill(0);

Expand All @@ -213,10 +208,34 @@ export class DirectContextModel {
++this.counts[context];
}

const delta = ((actualBit << this.precision) - this.predictions[context]) * (1 << (30 - this.precision));
if (delta < -0x80000000 || delta > 0x7fffffff) {
throw new Error('DirectContextModel.update: delta overflow');
}
// adjust P = predictions[context] by (actual - P) / (counts[context] + 0.5).
// when delta = (actual - P) * 2, this adjustment is equal to delta / (2 * counts[context] + 1).
// in the compact decoder (2 * counts[context] + 1) is directly stored in the typed array.
//
// claim:
// 1. the entire calculation always stays in the 32-bit signed integer.
// 2. P always stays in the [0, 2^precision) range.
//
// proof:
// assume that 0 <= P < 2^precision and P is an integer.
// counts[context] is already updated so counts[context] >= 1.
//
// if delta > 0, delta = (2^precision - P) * 2^(30-precision) < 2^30.
// then P' = P + trunc(delta / (2 * counts[context] + 1)) / 2^(29-precision)
// <= P + delta / 3 / 2^(29-precision)
// = P + (2^precision - P) * 2^(30-precision) / 2^(29-precision) / 3
// = 2/3 2^precision + 1/3 P
// <= 2/3 2^precision + 1/3 (2^precision - 1)
// = 2^precision - 1/3.
// therefore P' < 2^precision.
//
// if delta < 0, delta = -P * 2^(30-precision) > -2^30.
// then P' = P + trunc(delta / (2 * counts[context] + 1)) / 2^(29-precision)
// >= P + delta / 3 / 2^(29-precision)
// = P - 2/3 P
// > 0.
// therefore P' >= 0.
const delta = ((actualBit << this.precision) - this.predictions[context]) << (30 - this.precision);
this.predictions[context] += (delta / (2 * this.counts[context] + 1) | 0) >> (29 - this.precision);

this.bitContext = (this.bitContext << 1) | actualBit;
Expand Down Expand Up @@ -486,15 +505,21 @@ export const optimizeSparseSelectors = async (selectors, calculateSize, progress

//------------------------------------------------------------------------------

// Number of bytes needed to store one prediction for the given options:
// 1 byte up to 8 bits of precision, 2 bytes up to 16 bits, 4 bytes otherwise.
const predictionBytesPerContext = options => {
    if (options.precision <= 8) return 1;
    if (options.precision <= 16) return 2;
    return 4;
};
// Number of bytes needed to store one count for the given options:
// 1 byte when modelMaxCount < 128, 2 bytes when < 32768, 4 bytes otherwise.
// NOTE(review): the thresholds look chosen so that 2 * count + 1 still fits in
// the element (255 and 65535 respectively) — confirm against the decoder.
const countBytesPerContext = options => {
    if (options.modelMaxCount < 128) return 1;
    if (options.modelMaxCount < 32768) return 2;
    return 4;
};

export class Packer {
constructor(inputs, options = {}) {
options.sparseSelectors = options.sparseSelectors || defaultSparseSelectors();
options.maxMemoryMB = options.maxMemoryMB || 150;
options.contextBits = options.contextBits || (Math.log2(options.maxMemoryMB / options.sparseSelectors.length / 5) + 20 | 0);
options.precision = options.precision || 16;
options.modelMaxCount = options.modelMaxCount || 63;
options.learningRateNum = options.learningRateNum || 1;
options.learningRateDenom = options.learningRateDenom || 256;
if (!options.contextBits) {
const bytesPerContext = predictionBytesPerContext(options) + countBytesPerContext(options);
options.contextBits = Math.log2(options.maxMemoryMB / options.sparseSelectors.length / bytesPerContext) + 20 | 0;
}
this.options = options;

this.inputsByType = {};
Expand Down Expand Up @@ -536,6 +561,11 @@ export class Packer {
}
}

get memoryUsageMB() {
const bytesPerContext = predictionBytesPerContext(this.options) + countBytesPerContext(this.options);
return ((this.options.sparseSelectors.length * bytesPerContext) << this.options.contextBits) / 1048576;
}

prepare() {
if (this.combinedInput) return;

Expand Down Expand Up @@ -739,6 +769,9 @@ export class Packer {
compressWithModel(this.combinedInput, model, this.options);

const numModels = sparseSelectors.length;
const predictionBits = 8 * predictionBytesPerContext(this.options);
const countBits = 8 * countBytesPerContext(this.options);

const selectors = [];
for (const selector of sparseSelectors) {
const bits = [];
Expand Down Expand Up @@ -798,8 +831,8 @@ export class Packer {
`t=${state};` +
`M=1<<${precision + 1};` +
`w=${JSON.stringify(Array(numModels).fill(0))};` +
`p=new Uint32Array(${numModels}<<${contextBits}).fill(M/4);` +
`c=new Uint8Array(${numModels}<<${contextBits}).fill(1);` +
`p=new Uint${predictionBits}Array(${numModels}<<${contextBits}).fill(M/4);` +
`c=new Uint${countBits}Array(${numModels}<<${contextBits}).fill(1);` +

// z: decoded data
// r: read position in _
Expand Down
4 changes: 2 additions & 2 deletions tools/demo.html
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@
<aside id=$outputmessage></aside>
<footer><ul>
<li><!--<label>Number of Contexts: <input id=$numcontexts type=number value=12 min=1 max=64> <a href=#num-contexts title=Help>ℹ️</a></label> --><button id=$optimizecontexts>Optimize contexts</button> <a href=#optimize-contexts title=Help>ℹ️</a>
<li><label>Maximum memory usage: <input id=$maxmemory type=number value=150 min=10 max=1024> MB (<span id=$estimatedmemory>120</span> MB estimated) <a href=#max-memory title=Help>ℹ️</a>
<li><label>Maximum memory usage: <input id=$maxmemory type=number value=150 min=10 max=1024> MB (<span id=$estimatedmemory>144</span> MB estimated) <a href=#max-memory title=Help>ℹ️</a>
</ul><!--<details><summary>Advanced configuration</summary><ul>
<li><label>Output variable name: <input id=$outputvar pattern=([A-Za-z_$][0-9A-Za-z_$]*)? value=> <a href=#output-var title=Help>ℹ️</a>
<li><label>Chosen contexts: <input id=$selectors pattern=\d+(,\d+) value=> <a href=#chosen-contexts title=Help>ℹ️</a>
Expand Down Expand Up @@ -264,7 +264,7 @@ <h2>Command-line Usage and API</h2>
const initOutput = () => {
$maxmemory.oninput = () => {
const numContexts = 12; // TODO for now
$estimatedmemory.textContent = numContexts * 5 * 2**Math.floor(Math.log2($maxmemory.value / numContexts / 5));
$estimatedmemory.textContent = numContexts * 3 * 2**Math.floor(Math.log2($maxmemory.value / numContexts / 3));
refreshOutput();
};

Expand Down

0 comments on commit 9b1082e

Please sign in to comment.