Files
2026-03-31 20:02:01 +00:00

536 lines
14 KiB
Go

// Copyright 2023+ Klaus Post. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package dict
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"math/rand"
"sort"
"time"
"github.com/klauspost/compress/s2"
"github.com/klauspost/compress/zstd"
)
// match is a candidate hash together with the statistics collected for it
// while indexing the input.
type match struct {
	// hash is the value returned by hashLen for the indexed bytes.
	hash uint32

	// n is a counter for the hash; what is counted depends on the list
	// the entry is placed in (inputs containing the hash, or the number of
	// times it was seen next to another hash).
	n uint32

	// offset is the accumulated position recorded for the hash in buildDict
	// (the sum of its first position in every input that contains it).
	// It is only used as a sort key.
	offset int64
}
// matchValue holds what buildDict learned about one selected hash during
// the second pass over the input.
type matchValue struct {
	// value is a copy of the bytes (HashBytes long) that produced the hash.
	value []byte

	// followBy counts, per selected hash, how often that hash was seen
	// starting HashBytes after this one.
	followBy map[uint32]uint32

	// preceededBy counts, per selected hash, how often that hash was seen
	// starting HashBytes before this one.
	preceededBy map[uint32]uint32
}
// Options configures BuildZstdDict, BuildS2Dict and BuildRawDict.
type Options struct {
	// MaxDictSize is the max size of the backreference dictionary.
	MaxDictSize int

	// HashBytes is the minimum length to index.
	// Must be >=4 and <=8; other values are rejected with an error.
	HashBytes int

	// Output receives debug output.
	// Nothing is printed when it is nil.
	Output io.Writer

	// ZstdDictID is the Zstd dictionary ID to use.
	// Leave at zero to generate a random ID.
	ZstdDictID uint32

	// ZstdDictCompat will make the dictionary compatible with Zstd v1.5.5 and earlier.
	// See https://github.com/facebook/zstd/issues/3724
	ZstdDictCompat bool

	// ZstdLevel is the encoder level used for Zstandard dictionaries.
	// The dictionary will be built using the specified encoder level,
	// which will reflect speed and make the dictionary tailored for that level.
	// If not set zstd.SpeedBestCompression will be used.
	ZstdLevel zstd.EncoderLevel

	// outFormat selects the output container (formatRaw, formatZstd or
	// formatS2). It is set by the exported Build* functions.
	outFormat int
}
// Output formats understood by buildDict (stored in Options.outFormat).
const (
	// formatRaw returns the assembled dictionary bytes without any framing.
	formatRaw = iota
	// formatZstd passes the assembled bytes as history to zstd.BuildDict.
	formatZstd
	// formatS2 wraps the assembled bytes using s2.MakeDictManual.
	formatS2
)
// BuildZstdDict builds a Zstandard dictionary from the samples in input.
//
// When o.ZstdDictID is zero a random ID is picked from a generator seeded
// with the current time. The chosen ID is always >= 32768 and < 1<<31.
func BuildZstdDict(input [][]byte, o Options) ([]byte, error) {
	// Lowest ID handed out when the caller did not choose one.
	const firstRandomID = 32768

	o.outFormat = formatZstd
	if o.ZstdDictID != 0 {
		return buildDict(input, o)
	}
	seed := time.Now().UnixNano()
	gen := rand.New(rand.NewSource(seed))
	o.ZstdDictID = firstRandomID + uint32(gen.Int31n((1<<31)-firstRandomID))
	return buildDict(input, o)
}
// BuildS2Dict builds an S2 dictionary from the samples in input.
//
// An error is returned without doing any work when o.MaxDictSize is larger
// than s2.MaxDictSize.
func BuildS2Dict(input [][]byte, o Options) ([]byte, error) {
	if limit := s2.MaxDictSize; o.MaxDictSize > limit {
		return nil, errors.New("max dict size too large")
	}
	o.outFormat = formatS2
	return buildDict(input, o)
}
// BuildRawDict builds a dictionary from the samples in input and returns
// the bytes without any format-specific framing.
// The result can be used for deflate, lz4 and others.
func BuildRawDict(input [][]byte, o Options) ([]byte, error) {
	opts := o
	opts.outFormat = formatRaw
	return buildDict(input, opts)
}
// buildDict is the shared implementation behind BuildZstdDict, BuildS2Dict
// and BuildRawDict. o.outFormat selects how the result is packaged.
//
// The work is done in stages:
//
//  1. Every position of every input (with at least 8 bytes remaining) is
//     hashed. For each hash the number of inputs containing it is counted,
//     and the position of its first occurrence per input is accumulated.
//  2. Hashes seen more often than the average are kept, sorted, and cropped
//     to at most o.MaxDictSize entries.
//  3. The inputs are scanned again to record the bytes behind each kept hash
//     and which kept hashes are seen HashBytes after/before it.
//  4. Strings are assembled greedily: each kept entry is optionally prefixed
//     by its most common predecessor and extended by its most common
//     followers. Hashes covered by an emitted string are removed so they are
//     not emitted again.
//  5. The strings are written in reverse order, truncated to o.MaxDictSize,
//     and packaged according to o.outFormat.
//
// An error is returned when input is empty, o.HashBytes is outside 4..8,
// o.MaxDictSize is not positive, no input is long enough to be indexed, or
// no hash occurs more often than the average (for example when only a single
// sample is supplied).
//
// Map iteration order and unstable sorting are used while selecting entries,
// so the output is not guaranteed to be identical between runs.
func buildDict(input [][]byte, o Options) ([]byte, error) {
	matches := make(map[uint32]uint32)
	offsets := make(map[uint32]int64)
	var total uint64

	wantLen := o.MaxDictSize
	hashBytes := o.HashBytes
	if len(input) == 0 {
		return nil, fmt.Errorf("no input provided")
	}
	if hashBytes < 4 || hashBytes > 8 {
		return nil, fmt.Errorf("HashBytes must be >= 4 and <= 8")
	}
	if wantLen <= 0 {
		// A non-positive size would otherwise cause a slice bounds or
		// index panic further down.
		return nil, fmt.Errorf("MaxDictSize must be > 0")
	}
	println := func(args ...interface{}) {
		if o.Output != nil {
			fmt.Fprintln(o.Output, args...)
		}
	}
	printf := func(s string, args ...interface{}) {
		if o.Output != nil {
			fmt.Fprintf(o.Output, s, args...)
		}
	}

	// Stage 1: count in how many inputs each hash occurs.
	found := make(map[uint32]struct{})
	for i, b := range input {
		for k := range found {
			delete(found, k)
		}
		for pos := range b {
			rem := b[pos:]
			if len(rem) < 8 {
				// A full 8 byte load is required for hashing.
				break
			}
			h := hashLen(binary.LittleEndian.Uint64(rem), 32, uint8(hashBytes))
			if _, ok := found[h]; ok {
				// Only count first occurrence
				continue
			}
			matches[h]++
			offsets[h] += int64(pos)
			total++
			found[h] = struct{}{}
		}
		printf("\r input %d indexed...", i)
	}
	if len(matches) == 0 {
		// Avoids a division by zero below.
		return nil, fmt.Errorf("no input of at least 8 bytes provided")
	}
	threshold := uint32(total / uint64(len(matches)))
	println("\nTotal", total, "match", len(matches), "avg", threshold)

	// Stage 2: keep hashes that are more common than average.
	sorted := make([]match, 0, len(matches)/2)
	for k, v := range matches {
		if v <= threshold {
			continue
		}
		sorted = append(sorted, match{hash: k, n: v, offset: offsets[k]})
	}
	sort.Slice(sorted, func(i, j int) bool {
		// Group very similar counts together and emit low offsets first.
		// This will keep together strings that are very similar.
		deltaN := int(sorted[i].n) - int(sorted[j].n)
		if deltaN < 0 {
			deltaN = -deltaN
		}
		if uint32(deltaN) < sorted[i].n/32 {
			return sorted[i].offset < sorted[j].offset
		}
		return sorted[i].n > sorted[j].n
	})
	println("Sorted len:", len(sorted))
	if len(sorted) == 0 {
		// Happens when every hash is equally common, e.g. a single sample.
		return nil, fmt.Errorf("no recurring sequences found in input")
	}
	if len(sorted) > wantLen {
		sorted = sorted[:wantLen]
	}
	lowestOcc := sorted[len(sorted)-1].n
	println("Cropped len:", len(sorted), "Lowest occurrence:", lowestOcc)

	wantMatches := make(map[uint32]uint32, len(sorted))
	for _, v := range sorted {
		wantMatches[v.hash] = v.n
	}

	// Stage 3: collect the bytes and neighbours of every kept hash.
	output := make(map[uint32]matchValue, len(sorted))
	var firstOffsets []int
	for i, b := range input {
		for pos := range b {
			rem := b[pos:]
			if len(rem) < 8 {
				break
			}
			var prev []byte
			if pos > hashBytes {
				prev = b[pos-hashBytes:]
			}

			h := hashLen(binary.LittleEndian.Uint64(rem), 32, uint8(hashBytes))
			if _, ok := wantMatches[h]; !ok {
				continue
			}
			mv := output[h]
			if len(mv.value) == 0 {
				var tmp = make([]byte, hashBytes)
				copy(tmp[:], rem)
				mv.value = tmp[:]
			}
			if mv.followBy == nil {
				mv.followBy = make(map[uint32]uint32, 4)
				mv.preceededBy = make(map[uint32]uint32, 4)
			}
			if len(rem) > hashBytes+8 {
				// Check if we should add next as well.
				hNext := hashLen(binary.LittleEndian.Uint64(rem[hashBytes:]), 32, uint8(hashBytes))
				if _, ok := wantMatches[hNext]; ok {
					mv.followBy[hNext]++
				}
			}
			if len(prev) >= 8 {
				// Check if we should add prev as well.
				hPrev := hashLen(binary.LittleEndian.Uint64(prev), 32, uint8(hashBytes))
				if _, ok := wantMatches[hPrev]; ok {
					mv.preceededBy[hPrev]++
				}
			}
			output[h] = mv
		}
		printf("\rinput %d re-indexed...", i)
	}
	println("")

	// Stage 4: assemble strings, most significant entries first.
	dst := make([][]byte, 0, wantLen/hashBytes)
	added := 0
	const printUntil = 500
	for i, e := range sorted {
		if added > o.MaxDictSize {
			println("Ending. Next Occurrence:", e.n)
			break
		}
		m, ok := output[e.hash]
		if !ok {
			// Already added
			continue
		}
		// minCount is the neighbour count below which a predecessor is
		// ignored and a follower ends the current string.
		minCount := e.n / uint32(hashBytes) / 4
		if minCount <= lowestOcc {
			minCount = lowestOcc
		}
		var tmp = make([]byte, 0, hashBytes*2)
		{
			sortedPrev := make([]match, 0, len(m.preceededBy))
			for k, v := range m.preceededBy {
				if _, ok := output[k]; v < minCount || !ok {
					continue
				}
				sortedPrev = append(sortedPrev, match{
					hash: k,
					n:    v,
				})
			}
			if len(sortedPrev) > 0 {
				sort.Slice(sortedPrev, func(i, j int) bool {
					return sortedPrev[i].n > sortedPrev[j].n
				})
				bestPrev := output[sortedPrev[0].hash]
				tmp = append(tmp, bestPrev.value...)
			}
		}
		tmp = append(tmp, m.value...)
		delete(output, e.hash)

		sortedFollow := make([]match, 0, len(m.followBy))
		for {
			var nh uint32 // Next hash
			stopAfter := false
			{
				sortedFollow = sortedFollow[:0]
				for k, v := range m.followBy {
					if _, ok := output[k]; !ok {
						continue
					}
					sortedFollow = append(sortedFollow, match{
						hash:   k,
						n:      v,
						offset: offsets[k],
					})
				}
				if len(sortedFollow) == 0 {
					// Step back
					// Extremely small impact, but helps longer hashes a bit.
					const stepBack = 2
					if stepBack > 0 && len(tmp) >= hashBytes+stepBack {
						var t8 [8]byte
						copy(t8[:], tmp[len(tmp)-hashBytes-stepBack:])
						m, ok = output[hashLen(binary.LittleEndian.Uint64(t8[:]), 32, uint8(hashBytes))]
						if ok && len(m.followBy) > 0 {
							var next []byte
							for k := range m.followBy {
								v, ok := output[k]
								if !ok {
									continue
								}
								next = v.value
								break
							}
							if next != nil {
								tmp = tmp[:len(tmp)-stepBack]
								printf("Step back: %q + %q\n", string(tmp), string(next))
								continue
							}
						}
						break
					} else {
						if i < printUntil {
							printf("FOLLOW: none after %q\n", string(m.value))
						}
					}
					break
				}
				sort.Slice(sortedFollow, func(i, j int) bool {
					if sortedFollow[i].n == sortedFollow[j].n {
						return sortedFollow[i].offset > sortedFollow[j].offset
					}
					return sortedFollow[i].n > sortedFollow[j].n
				})
				nh = sortedFollow[0].hash
				stopAfter = sortedFollow[0].n < minCount
				if stopAfter && i < printUntil {
					printf("FOLLOW: %d < %d after %q. Stopping after this.\n", sortedFollow[0].n, minCount, string(m.value))
				}
			}
			m, ok = output[nh]
			if !ok {
				break
			}
			if len(tmp) > 0 {
				// Delete all hashes that are in the current string to avoid stuttering.
				var toDel [16 + 8]byte
				copy(toDel[:], tmp[len(tmp)-hashBytes:])
				copy(toDel[hashBytes:], m.value)
				for k := range toDel[:hashBytes*2] {
					delete(output, hashLen(binary.LittleEndian.Uint64(toDel[k:]), 32, uint8(hashBytes)))
				}
			}
			tmp = append(tmp, m.value...)
			if stopAfter {
				// Last entry was not significant.
				break
			}
		}
		if i < printUntil {
			printf("ENTRY %d: %q (%d occurrences, cutoff %d)\n", i, string(tmp), e.n, minCount)
		}
		// Delete substrings already added.
		if len(tmp) > hashBytes {
			for j := range tmp[:len(tmp)-hashBytes+1] {
				var t8 [8]byte
				copy(t8[:], tmp[j:])
				delete(output, hashLen(binary.LittleEndian.Uint64(t8[:]), 32, uint8(hashBytes)))
			}
		}
		dst = append(dst, tmp)
		added += len(tmp)

		// Find offsets
		// TODO: This can be better if done as a global search.
		if len(firstOffsets) < 3 {
			// Search for at most the first 16 bytes of the new string.
			needle := tmp
			if len(needle) > 16 {
				needle = needle[:16]
			}
			offCnt := make(map[int]int, len(input))
			// Find first offsets
			for _, b := range input {
				off := bytes.Index(b, needle)
				if off == -1 {
					continue
				}
				offCnt[off]++
			}
			for _, off := range firstOffsets {
				// Very unlikely, but we deleted it just in case
				delete(offCnt, off-added)
			}
			maxCnt := 0
			maxOffset := 0
			for k, v := range offCnt {
				if v == maxCnt && k > maxOffset {
					// Prefer the longer offset on ties, since it is more expensive to encode
					maxCnt = v
					maxOffset = k
					continue
				}
				if v > maxCnt {
					maxCnt = v
					maxOffset = k
				}
			}
			if maxCnt > 1 {
				firstOffsets = append(firstOffsets, maxOffset+added)
				println(" - Offset:", len(firstOffsets), "at", maxOffset+added, "count:", maxCnt, "total added:", added, "src index", maxOffset)
			}
		}
	}

	// Stage 5: crop to the requested size and write out.
	out := bytes.NewBuffer(nil)
	written := 0
	for i, toWrite := range dst {
		if len(toWrite)+written > wantLen {
			toWrite = toWrite[:wantLen-written]
		}
		dst[i] = toWrite
		written += len(toWrite)
		if written >= wantLen {
			dst = dst[:i+1]
			break
		}
	}
	// Write in reverse order.
	for i := range dst {
		toWrite := dst[len(dst)-i-1]
		out.Write(toWrite)
	}
	if o.outFormat == formatRaw {
		return out.Bytes(), nil
	}

	if o.outFormat == formatS2 {
		dOff := 0
		dBytes := out.Bytes()
		if len(dBytes) > s2.MaxDictSize {
			dBytes = dBytes[:s2.MaxDictSize]
		}
		// The last usable entry of firstOffsets wins.
		for _, off := range firstOffsets {
			myOff := len(dBytes) - off
			if myOff < 0 || myOff > s2.MaxDictSrcOffset {
				continue
			}
			dOff = myOff
		}

		dict := s2.MakeDictManual(dBytes, uint16(dOff))
		if dict == nil {
			return nil, fmt.Errorf("unable to create s2 dictionary")
		}
		return dict.Bytes(), nil
	}

	// Zstandard: start from default repeat offsets and override them with
	// the offsets found above, stopping at the first unusable one.
	offsetsZstd := [3]int{1, 4, 8}
	for i, off := range firstOffsets {
		if i >= 3 || off == 0 || off >= out.Len() {
			break
		}
		offsetsZstd[i] = off
	}
	println("\nCompressing. Offsets:", offsetsZstd)
	return zstd.BuildDict(zstd.BuildDictOptions{
		ID:         o.ZstdDictID,
		Contents:   input,
		History:    out.Bytes(),
		Offsets:    offsetsZstd,
		CompatV155: o.ZstdDictCompat,
		Level:      o.ZstdLevel,
		DebugOut:   o.Output,
	})
}
// Multipliers used by the hashN functions below.
// primeNbytes is the multiplier applied when hashing the lowest N bytes.
const (
	prime3bytes = 506832829
	prime4bytes = 2654435761
	prime5bytes = 889523592379
	prime6bytes = 227718039650203
	prime7bytes = 58295818150454627
	prime8bytes = 0xcf1bbcdcb7a56463
)
// hashLen returns a hash of the lowest mls bytes of u.
//
// For mls 5, 6, 7 and 8 the value is hashed by hash5, hash6, hash7 and hash8
// respectively, with hashLog passed on as the number of hash bits.
// For any other mls the lowest 4 bytes of u are passed straight through
// and hashLog is ignored.
// Preferably hashLog and mls should be constants.
func hashLen(u uint64, hashLog, mls uint8) uint32 {
	if mls == 5 {
		return hash5(u, hashLog)
	}
	if mls == 6 {
		return hash6(u, hashLog)
	}
	if mls == 7 {
		return hash7(u, hashLog)
	}
	if mls == 8 {
		return hash8(u, hashLog)
	}
	// Length 4 (and anything unsupported): use the low 32 bits as-is.
	return uint32(u)
}
// hash3 hashes the lowest 3 bytes of u for a hash table with h bits.
// The upper byte of u is shifted out before multiplying, so it does not
// influence the result.
// h should preferably be a constant and always be <32.
func hash3(u uint32, h uint8) uint32 {
	drop := (32 - h) & 31
	low3 := u << 8
	return (low3 * prime3bytes) >> drop
}
// hash4 hashes all 4 bytes of u for a hash table with h bits.
// h should preferably be a constant and always be <32.
func hash4(u uint32, h uint8) uint32 {
	drop := (32 - h) & 31
	product := u * prime4bytes
	return product >> drop
}
// hash4x64 hashes the lowest 4 bytes of u for a hash table with h bits.
// The upper 4 bytes of u are discarded before hashing.
// h should preferably be a constant and always be <32.
func hash4x64(u uint64, h uint8) uint32 {
	low4 := uint32(u)
	drop := (32 - h) & 31
	return (low4 * prime4bytes) >> drop
}
// hash5 hashes the lowest 5 bytes of u for a hash table with h bits.
// The upper 3 bytes of u are shifted out before multiplying.
// h should preferably be a constant and always be <64.
func hash5(u uint64, h uint8) uint32 {
	drop := (64 - h) & 63
	low5 := u << 24
	return uint32((low5 * prime5bytes) >> drop)
}
// hash6 hashes the lowest 6 bytes of u for a hash table with h bits.
// The upper 2 bytes of u are shifted out before multiplying.
// h should preferably be a constant and always be <64.
func hash6(u uint64, h uint8) uint32 {
	drop := (64 - h) & 63
	low6 := u << 16
	return uint32((low6 * prime6bytes) >> drop)
}
// hash7 hashes the lowest 7 bytes of u for a hash table with h bits.
// The upper byte of u is shifted out before multiplying.
// h should preferably be a constant and always be <64.
func hash7(u uint64, h uint8) uint32 {
	drop := (64 - h) & 63
	low7 := u << 8
	return uint32((low7 * prime7bytes) >> drop)
}
// hash8 hashes all 8 bytes of u for a hash table with h bits.
// h should preferably be a constant and always be <64.
func hash8(u uint64, h uint8) uint32 {
	drop := (64 - h) & 63
	product := u * prime8bytes
	return uint32(product >> drop)
}