Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

huff0: Speed up compression of short blocks #744

Merged
merged 1 commit into from
Jan 20, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 55 additions & 35 deletions huff0/compress.go
Original file line number Diff line number Diff line change
Expand Up @@ -484,53 +484,55 @@ func (s *Scratch) buildCTable() error {
// Different from reference implementation.
huffNode0 := s.nodes[0 : huffNodesLen+1]

for huffNode[nonNullRank].count == 0 {
for huffNode[nonNullRank].count() == 0 {
nonNullRank--
}

lowS := int16(nonNullRank)
nodeRoot := nodeNb + lowS - 1
lowN := nodeNb
huffNode[nodeNb].count = huffNode[lowS].count + huffNode[lowS-1].count
huffNode[lowS].parent, huffNode[lowS-1].parent = uint16(nodeNb), uint16(nodeNb)
huffNode[nodeNb].setCount(huffNode[lowS].count() + huffNode[lowS-1].count())
huffNode[lowS].setParent(nodeNb)
huffNode[lowS-1].setParent(nodeNb)
nodeNb++
lowS -= 2
for n := nodeNb; n <= nodeRoot; n++ {
huffNode[n].count = 1 << 30
huffNode[n].setCount(1 << 30)
}
// fake entry, strong barrier
huffNode0[0].count = 1 << 31
huffNode0[0].setCount(1 << 31)

// create parents
for nodeNb <= nodeRoot {
var n1, n2 int16
if huffNode0[lowS+1].count < huffNode0[lowN+1].count {
if huffNode0[lowS+1].count() < huffNode0[lowN+1].count() {
n1 = lowS
lowS--
} else {
n1 = lowN
lowN++
}
if huffNode0[lowS+1].count < huffNode0[lowN+1].count {
if huffNode0[lowS+1].count() < huffNode0[lowN+1].count() {
n2 = lowS
lowS--
} else {
n2 = lowN
lowN++
}

huffNode[nodeNb].count = huffNode0[n1+1].count + huffNode0[n2+1].count
huffNode0[n1+1].parent, huffNode0[n2+1].parent = uint16(nodeNb), uint16(nodeNb)
huffNode[nodeNb].setCount(huffNode0[n1+1].count() + huffNode0[n2+1].count())
huffNode0[n1+1].setParent(nodeNb)
huffNode0[n2+1].setParent(nodeNb)
nodeNb++
}

// distribute weights (unlimited tree height)
huffNode[nodeRoot].nbBits = 0
huffNode[nodeRoot].setNbBits(0)
for n := nodeRoot - 1; n >= startNode; n-- {
huffNode[n].nbBits = huffNode[huffNode[n].parent].nbBits + 1
huffNode[n].setNbBits(huffNode[huffNode[n].parent()].nbBits() + 1)
}
for n := uint16(0); n <= nonNullRank; n++ {
huffNode[n].nbBits = huffNode[huffNode[n].parent].nbBits + 1
huffNode[n].setNbBits(huffNode[huffNode[n].parent()].nbBits() + 1)
}
s.actualTableLog = s.setMaxHeight(int(nonNullRank))
maxNbBits := s.actualTableLog
Expand All @@ -542,7 +544,7 @@ func (s *Scratch) buildCTable() error {
var nbPerRank [tableLogMax + 1]uint16
var valPerRank [16]uint16
for _, v := range huffNode[:nonNullRank+1] {
nbPerRank[v.nbBits]++
nbPerRank[v.nbBits()]++
}
// determine stating value per rank
{
Expand All @@ -557,7 +559,7 @@ func (s *Scratch) buildCTable() error {

// push nbBits per symbol, symbol order
for _, v := range huffNode[:nonNullRank+1] {
s.cTable[v.symbol].nBits = v.nbBits
s.cTable[v.symbol()].nBits = v.nbBits()
}

// assign value within rank, symbol order
Expand Down Expand Up @@ -603,12 +605,12 @@ func (s *Scratch) huffSort() {
pos := rank[r].current
rank[r].current++
prev := nodes[(pos-1)&huffNodesMask]
for pos > rank[r].base && c > prev.count {
for pos > rank[r].base && c > prev.count() {
nodes[pos&huffNodesMask] = prev
pos--
prev = nodes[(pos-1)&huffNodesMask]
}
nodes[pos&huffNodesMask] = nodeElt{count: c, symbol: byte(n)}
nodes[pos&huffNodesMask] = makeNodeElt(c, byte(n))
}
}

Expand All @@ -617,7 +619,7 @@ func (s *Scratch) setMaxHeight(lastNonNull int) uint8 {
huffNode := s.nodes[1 : huffNodesLen+1]
//huffNode = huffNode[: huffNodesLen]

largestBits := huffNode[lastNonNull].nbBits
largestBits := huffNode[lastNonNull].nbBits()

// early exit : no elt > maxNbBits
if largestBits <= maxNbBits {
Expand All @@ -627,14 +629,14 @@ func (s *Scratch) setMaxHeight(lastNonNull int) uint8 {
baseCost := int(1) << (largestBits - maxNbBits)
n := uint32(lastNonNull)

for huffNode[n].nbBits > maxNbBits {
totalCost += baseCost - (1 << (largestBits - huffNode[n].nbBits))
huffNode[n].nbBits = maxNbBits
for huffNode[n].nbBits() > maxNbBits {
totalCost += baseCost - (1 << (largestBits - huffNode[n].nbBits()))
huffNode[n].setNbBits(maxNbBits)
n--
}
// n stops at huffNode[n].nbBits <= maxNbBits

for huffNode[n].nbBits == maxNbBits {
for huffNode[n].nbBits() == maxNbBits {
n--
}
// n end at index of smallest symbol using < maxNbBits
Expand All @@ -655,10 +657,10 @@ func (s *Scratch) setMaxHeight(lastNonNull int) uint8 {
{
currentNbBits := maxNbBits
for pos := int(n); pos >= 0; pos-- {
if huffNode[pos].nbBits >= currentNbBits {
if huffNode[pos].nbBits() >= currentNbBits {
continue
}
currentNbBits = huffNode[pos].nbBits // < maxNbBits
currentNbBits = huffNode[pos].nbBits() // < maxNbBits
rankLast[maxNbBits-currentNbBits] = uint32(pos)
}
}
Expand All @@ -675,8 +677,8 @@ func (s *Scratch) setMaxHeight(lastNonNull int) uint8 {
if lowPos == noSymbol {
break
}
highTotal := huffNode[highPos].count
lowTotal := 2 * huffNode[lowPos].count
highTotal := huffNode[highPos].count()
lowTotal := 2 * huffNode[lowPos].count()
if highTotal <= lowTotal {
break
}
Expand All @@ -692,39 +694,57 @@ func (s *Scratch) setMaxHeight(lastNonNull int) uint8 {
// this rank is no longer empty
rankLast[nBitsToDecrease-1] = rankLast[nBitsToDecrease]
}
huffNode[rankLast[nBitsToDecrease]].nbBits++
huffNode[rankLast[nBitsToDecrease]].setNbBits(1 +
huffNode[rankLast[nBitsToDecrease]].nbBits())
if rankLast[nBitsToDecrease] == 0 {
/* special case, reached largest symbol */
rankLast[nBitsToDecrease] = noSymbol
} else {
rankLast[nBitsToDecrease]--
if huffNode[rankLast[nBitsToDecrease]].nbBits != maxNbBits-nBitsToDecrease {
if huffNode[rankLast[nBitsToDecrease]].nbBits() != maxNbBits-nBitsToDecrease {
rankLast[nBitsToDecrease] = noSymbol /* this rank is now empty */
}
}
}

for totalCost < 0 { /* Sometimes, cost correction overshoot */
if rankLast[1] == noSymbol { /* special case : no rank 1 symbol (using maxNbBits-1); let's create one from largest rank 0 (using maxNbBits) */
for huffNode[n].nbBits == maxNbBits {
for huffNode[n].nbBits() == maxNbBits {
n--
}
huffNode[n+1].nbBits--
huffNode[n+1].setNbBits(huffNode[n+1].nbBits() - 1)
rankLast[1] = n + 1
totalCost++
continue
}
huffNode[rankLast[1]+1].nbBits--
huffNode[rankLast[1]+1].setNbBits(huffNode[rankLast[1]+1].nbBits() - 1)
rankLast[1]++
totalCost++
}
}
return maxNbBits
}

type nodeElt struct {
count uint32
parent uint16
symbol byte
nbBits uint8
// A nodeElt is the fields
//
// count uint32
// parent uint16
// symbol byte
// nbBits uint8
//
// in some order, all squashed into an integer so that the compiler
// always loads and stores entire nodeElts instead of separate fields.
type nodeElt uint64

func makeNodeElt(count uint32, symbol byte) nodeElt {
return nodeElt(count) | nodeElt(symbol)<<48
}

func (e *nodeElt) count() uint32 { return uint32(*e) }
func (e *nodeElt) parent() uint16 { return uint16(*e >> 32) }
func (e *nodeElt) symbol() byte { return byte(*e >> 48) }
func (e *nodeElt) nbBits() uint8 { return uint8(*e >> 56) }

func (e *nodeElt) setCount(c uint32) { *e = (*e)&0xffffffff00000000 | nodeElt(c) }
func (e *nodeElt) setParent(p int16) { *e = (*e)&0xffff0000ffffffff | nodeElt(uint16(p))<<32 }
func (e *nodeElt) setNbBits(n uint8) { *e = (*e)&0x00ffffffffffffff | nodeElt(n)<<56 }