536 lines
14 KiB
Go
536 lines
14 KiB
Go
// Copyright 2023+ Klaus Post. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
|
|
|
|
package dict
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/binary"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"math/rand"
|
|
"sort"
|
|
"time"
|
|
|
|
"github.com/klauspost/compress/s2"
|
|
"github.com/klauspost/compress/zstd"
|
|
)
|
|
|
|
// match is a candidate dictionary entry identified by the hash of its bytes.
type match struct {
	// hash is the hashLen value of the indexed byte sequence.
	hash uint32
	// n is an occurrence count. Depending on where the match is built it is
	// either the number of inputs containing the hash, or the number of times
	// the hash was seen following/preceding another entry.
	n uint32
	// offset is the sum of the positions at which the hash was first seen in
	// each input. It is only used as a sorting tie-breaker.
	offset int64
}
|
|
|
|
// matchValue holds the bytes behind a hash together with its observed
// neighbours in the input.
type matchValue struct {
	// value is a copy of the HashBytes bytes that produced the hash.
	value []byte
	// followBy counts, per hash, how often that entry was seen directly
	// after this one.
	followBy map[uint32]uint32
	// preceededBy counts, per hash, how often that entry was seen directly
	// before this one.
	preceededBy map[uint32]uint32
}
|
|
|
|
type Options struct {
|
|
// MaxDictSize is the max size of the backreference dictionary.
|
|
MaxDictSize int
|
|
|
|
// HashBytes is the minimum length to index.
|
|
// Must be >=4 and <=8
|
|
HashBytes int
|
|
|
|
// Debug output
|
|
Output io.Writer
|
|
|
|
// ZstdDictID is the Zstd dictionary ID to use.
|
|
// Leave at zero to generate a random ID.
|
|
ZstdDictID uint32
|
|
|
|
// ZstdDictCompat will make the dictionary compatible with Zstd v1.5.5 and earlier.
|
|
// See https://github.com/facebook/zstd/issues/3724
|
|
ZstdDictCompat bool
|
|
|
|
// Use the specified encoder level for Zstandard dictionaries.
|
|
// The dictionary will be built using the specified encoder level,
|
|
// which will reflect speed and make the dictionary tailored for that level.
|
|
// If not set zstd.SpeedBestCompression will be used.
|
|
ZstdLevel zstd.EncoderLevel
|
|
|
|
outFormat int
|
|
}
|
|
|
|
// Output formats, stored in Options.outFormat.
const (
	// formatRaw emits the dictionary content without any container.
	formatRaw = iota
	// formatZstd emits a Zstandard dictionary.
	formatZstd
	// formatS2 emits an S2 dictionary.
	formatS2
)
|
|
|
|
// BuildZstdDict will build a Zstandard dictionary from the provided input.
|
|
func BuildZstdDict(input [][]byte, o Options) ([]byte, error) {
|
|
o.outFormat = formatZstd
|
|
if o.ZstdDictID == 0 {
|
|
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
|
|
o.ZstdDictID = 32768 + uint32(rng.Int31n((1<<31)-32768))
|
|
}
|
|
return buildDict(input, o)
|
|
}
|
|
|
|
// BuildS2Dict will build a S2 dictionary from the provided input.
|
|
func BuildS2Dict(input [][]byte, o Options) ([]byte, error) {
|
|
o.outFormat = formatS2
|
|
if o.MaxDictSize > s2.MaxDictSize {
|
|
return nil, errors.New("max dict size too large")
|
|
}
|
|
return buildDict(input, o)
|
|
}
|
|
|
|
// BuildRawDict will build a raw dictionary from the provided input.
|
|
// This can be used for deflate, lz4 and others.
|
|
func BuildRawDict(input [][]byte, o Options) ([]byte, error) {
|
|
o.outFormat = formatRaw
|
|
return buildDict(input, o)
|
|
}
|
|
|
|
func buildDict(input [][]byte, o Options) ([]byte, error) {
|
|
matches := make(map[uint32]uint32)
|
|
offsets := make(map[uint32]int64)
|
|
var total uint64
|
|
|
|
wantLen := o.MaxDictSize
|
|
hashBytes := o.HashBytes
|
|
if len(input) == 0 {
|
|
return nil, fmt.Errorf("no input provided")
|
|
}
|
|
if hashBytes < 4 || hashBytes > 8 {
|
|
return nil, fmt.Errorf("HashBytes must be >= 4 and <= 8")
|
|
}
|
|
println := func(args ...interface{}) {
|
|
if o.Output != nil {
|
|
fmt.Fprintln(o.Output, args...)
|
|
}
|
|
}
|
|
printf := func(s string, args ...interface{}) {
|
|
if o.Output != nil {
|
|
fmt.Fprintf(o.Output, s, args...)
|
|
}
|
|
}
|
|
found := make(map[uint32]struct{})
|
|
for i, b := range input {
|
|
for k := range found {
|
|
delete(found, k)
|
|
}
|
|
for i := range b {
|
|
rem := b[i:]
|
|
if len(rem) < 8 {
|
|
break
|
|
}
|
|
h := hashLen(binary.LittleEndian.Uint64(rem), 32, uint8(hashBytes))
|
|
if _, ok := found[h]; ok {
|
|
// Only count first occurrence
|
|
continue
|
|
}
|
|
matches[h]++
|
|
offsets[h] += int64(i)
|
|
total++
|
|
found[h] = struct{}{}
|
|
}
|
|
printf("\r input %d indexed...", i)
|
|
}
|
|
threshold := uint32(total / uint64(len(matches)))
|
|
println("\nTotal", total, "match", len(matches), "avg", threshold)
|
|
sorted := make([]match, 0, len(matches)/2)
|
|
for k, v := range matches {
|
|
if v <= threshold {
|
|
continue
|
|
}
|
|
sorted = append(sorted, match{hash: k, n: v, offset: offsets[k]})
|
|
}
|
|
sort.Slice(sorted, func(i, j int) bool {
|
|
if true {
|
|
// Group very similar counts together and emit low offsets first.
|
|
// This will keep together strings that are very similar.
|
|
deltaN := int(sorted[i].n) - int(sorted[j].n)
|
|
if deltaN < 0 {
|
|
deltaN = -deltaN
|
|
}
|
|
if uint32(deltaN) < sorted[i].n/32 {
|
|
return sorted[i].offset < sorted[j].offset
|
|
}
|
|
} else {
|
|
if sorted[i].n == sorted[j].n {
|
|
return sorted[i].offset < sorted[j].offset
|
|
}
|
|
}
|
|
return sorted[i].n > sorted[j].n
|
|
})
|
|
println("Sorted len:", len(sorted))
|
|
if len(sorted) > wantLen {
|
|
sorted = sorted[:wantLen]
|
|
}
|
|
lowestOcc := sorted[len(sorted)-1].n
|
|
println("Cropped len:", len(sorted), "Lowest occurrence:", lowestOcc)
|
|
|
|
wantMatches := make(map[uint32]uint32, len(sorted))
|
|
for _, v := range sorted {
|
|
wantMatches[v.hash] = v.n
|
|
}
|
|
|
|
output := make(map[uint32]matchValue, len(sorted))
|
|
var remainCnt [256]int
|
|
var remainTotal int
|
|
var firstOffsets []int
|
|
for i, b := range input {
|
|
for i := range b {
|
|
rem := b[i:]
|
|
if len(rem) < 8 {
|
|
break
|
|
}
|
|
var prev []byte
|
|
if i > hashBytes {
|
|
prev = b[i-hashBytes:]
|
|
}
|
|
|
|
h := hashLen(binary.LittleEndian.Uint64(rem), 32, uint8(hashBytes))
|
|
if _, ok := wantMatches[h]; !ok {
|
|
remainCnt[rem[0]]++
|
|
remainTotal++
|
|
continue
|
|
}
|
|
mv := output[h]
|
|
if len(mv.value) == 0 {
|
|
var tmp = make([]byte, hashBytes)
|
|
copy(tmp[:], rem)
|
|
mv.value = tmp[:]
|
|
}
|
|
if mv.followBy == nil {
|
|
mv.followBy = make(map[uint32]uint32, 4)
|
|
mv.preceededBy = make(map[uint32]uint32, 4)
|
|
}
|
|
if len(rem) > hashBytes+8 {
|
|
// Check if we should add next as well.
|
|
hNext := hashLen(binary.LittleEndian.Uint64(rem[hashBytes:]), 32, uint8(hashBytes))
|
|
if _, ok := wantMatches[hNext]; ok {
|
|
mv.followBy[hNext]++
|
|
}
|
|
}
|
|
if len(prev) >= 8 {
|
|
// Check if we should prev next as well.
|
|
hPrev := hashLen(binary.LittleEndian.Uint64(prev), 32, uint8(hashBytes))
|
|
if _, ok := wantMatches[hPrev]; ok {
|
|
mv.preceededBy[hPrev]++
|
|
}
|
|
}
|
|
output[h] = mv
|
|
}
|
|
printf("\rinput %d re-indexed...", i)
|
|
}
|
|
println("")
|
|
dst := make([][]byte, 0, wantLen/hashBytes)
|
|
added := 0
|
|
const printUntil = 500
|
|
for i, e := range sorted {
|
|
if added > o.MaxDictSize {
|
|
println("Ending. Next Occurrence:", e.n)
|
|
break
|
|
}
|
|
m, ok := output[e.hash]
|
|
if !ok {
|
|
// Already added
|
|
continue
|
|
}
|
|
wantLen := e.n / uint32(hashBytes) / 4
|
|
if wantLen <= lowestOcc {
|
|
wantLen = lowestOcc
|
|
}
|
|
|
|
var tmp = make([]byte, 0, hashBytes*2)
|
|
{
|
|
sortedPrev := make([]match, 0, len(m.followBy))
|
|
for k, v := range m.preceededBy {
|
|
if _, ok := output[k]; v < wantLen || !ok {
|
|
continue
|
|
}
|
|
sortedPrev = append(sortedPrev, match{
|
|
hash: k,
|
|
n: v,
|
|
})
|
|
}
|
|
if len(sortedPrev) > 0 {
|
|
sort.Slice(sortedPrev, func(i, j int) bool {
|
|
return sortedPrev[i].n > sortedPrev[j].n
|
|
})
|
|
bestPrev := output[sortedPrev[0].hash]
|
|
tmp = append(tmp, bestPrev.value...)
|
|
}
|
|
}
|
|
tmp = append(tmp, m.value...)
|
|
delete(output, e.hash)
|
|
|
|
sortedFollow := make([]match, 0, len(m.followBy))
|
|
for {
|
|
var nh uint32 // Next hash
|
|
stopAfter := false
|
|
{
|
|
sortedFollow = sortedFollow[:0]
|
|
for k, v := range m.followBy {
|
|
if _, ok := output[k]; !ok {
|
|
continue
|
|
}
|
|
sortedFollow = append(sortedFollow, match{
|
|
hash: k,
|
|
n: v,
|
|
offset: offsets[k],
|
|
})
|
|
}
|
|
if len(sortedFollow) == 0 {
|
|
// Step back
|
|
// Extremely small impact, but helps longer hashes a bit.
|
|
const stepBack = 2
|
|
if stepBack > 0 && len(tmp) >= hashBytes+stepBack {
|
|
var t8 [8]byte
|
|
copy(t8[:], tmp[len(tmp)-hashBytes-stepBack:])
|
|
m, ok = output[hashLen(binary.LittleEndian.Uint64(t8[:]), 32, uint8(hashBytes))]
|
|
if ok && len(m.followBy) > 0 {
|
|
found := []byte(nil)
|
|
for k := range m.followBy {
|
|
v, ok := output[k]
|
|
if !ok {
|
|
continue
|
|
}
|
|
found = v.value
|
|
break
|
|
}
|
|
if found != nil {
|
|
tmp = tmp[:len(tmp)-stepBack]
|
|
printf("Step back: %q + %q\n", string(tmp), string(found))
|
|
continue
|
|
}
|
|
}
|
|
break
|
|
} else {
|
|
if i < printUntil {
|
|
printf("FOLLOW: none after %q\n", string(m.value))
|
|
}
|
|
}
|
|
break
|
|
}
|
|
sort.Slice(sortedFollow, func(i, j int) bool {
|
|
if sortedFollow[i].n == sortedFollow[j].n {
|
|
return sortedFollow[i].offset > sortedFollow[j].offset
|
|
}
|
|
return sortedFollow[i].n > sortedFollow[j].n
|
|
})
|
|
nh = sortedFollow[0].hash
|
|
stopAfter = sortedFollow[0].n < wantLen
|
|
if stopAfter && i < printUntil {
|
|
printf("FOLLOW: %d < %d after %q. Stopping after this.\n", sortedFollow[0].n, wantLen, string(m.value))
|
|
}
|
|
}
|
|
m, ok = output[nh]
|
|
if !ok {
|
|
break
|
|
}
|
|
if len(tmp) > 0 {
|
|
// Delete all hashes that are in the current string to avoid stuttering.
|
|
var toDel [16 + 8]byte
|
|
copy(toDel[:], tmp[len(tmp)-hashBytes:])
|
|
copy(toDel[hashBytes:], m.value)
|
|
for i := range toDel[:hashBytes*2] {
|
|
delete(output, hashLen(binary.LittleEndian.Uint64(toDel[i:]), 32, uint8(hashBytes)))
|
|
}
|
|
}
|
|
tmp = append(tmp, m.value...)
|
|
//delete(output, nh)
|
|
if stopAfter {
|
|
// Last entry was no significant.
|
|
break
|
|
}
|
|
}
|
|
if i < printUntil {
|
|
printf("ENTRY %d: %q (%d occurrences, cutoff %d)\n", i, string(tmp), e.n, wantLen)
|
|
}
|
|
// Delete substrings already added.
|
|
if len(tmp) > hashBytes {
|
|
for j := range tmp[:len(tmp)-hashBytes+1] {
|
|
var t8 [8]byte
|
|
copy(t8[:], tmp[j:])
|
|
if i < printUntil {
|
|
//printf("* POST DELETE %q\n", string(t8[:hashBytes]))
|
|
}
|
|
delete(output, hashLen(binary.LittleEndian.Uint64(t8[:]), 32, uint8(hashBytes)))
|
|
}
|
|
}
|
|
dst = append(dst, tmp)
|
|
added += len(tmp)
|
|
// Find offsets
|
|
// TODO: This can be better if done as a global search.
|
|
if len(firstOffsets) < 3 {
|
|
if len(tmp) > 16 {
|
|
tmp = tmp[:16]
|
|
}
|
|
offCnt := make(map[int]int, len(input))
|
|
// Find first offsets
|
|
for _, b := range input {
|
|
off := bytes.Index(b, tmp)
|
|
if off == -1 {
|
|
continue
|
|
}
|
|
offCnt[off]++
|
|
}
|
|
for _, off := range firstOffsets {
|
|
// Very unlikely, but we deleted it just in case
|
|
delete(offCnt, off-added)
|
|
}
|
|
maxCnt := 0
|
|
maxOffset := 0
|
|
for k, v := range offCnt {
|
|
if v == maxCnt && k > maxOffset {
|
|
// Prefer the longer offset on ties , since it is more expensive to encode
|
|
maxCnt = v
|
|
maxOffset = k
|
|
continue
|
|
}
|
|
|
|
if v > maxCnt {
|
|
maxCnt = v
|
|
maxOffset = k
|
|
}
|
|
}
|
|
if maxCnt > 1 {
|
|
firstOffsets = append(firstOffsets, maxOffset+added)
|
|
println(" - Offset:", len(firstOffsets), "at", maxOffset+added, "count:", maxCnt, "total added:", added, "src index", maxOffset)
|
|
}
|
|
}
|
|
}
|
|
out := bytes.NewBuffer(nil)
|
|
written := 0
|
|
for i, toWrite := range dst {
|
|
if len(toWrite)+written > wantLen {
|
|
toWrite = toWrite[:wantLen-written]
|
|
}
|
|
dst[i] = toWrite
|
|
written += len(toWrite)
|
|
if written >= wantLen {
|
|
dst = dst[:i+1]
|
|
break
|
|
}
|
|
}
|
|
// Write in reverse order.
|
|
for i := range dst {
|
|
toWrite := dst[len(dst)-i-1]
|
|
out.Write(toWrite)
|
|
}
|
|
if o.outFormat == formatRaw {
|
|
return out.Bytes(), nil
|
|
}
|
|
|
|
if o.outFormat == formatS2 {
|
|
dOff := 0
|
|
dBytes := out.Bytes()
|
|
if len(dBytes) > s2.MaxDictSize {
|
|
dBytes = dBytes[:s2.MaxDictSize]
|
|
}
|
|
for _, off := range firstOffsets {
|
|
myOff := len(dBytes) - off
|
|
if myOff < 0 || myOff > s2.MaxDictSrcOffset {
|
|
continue
|
|
}
|
|
dOff = myOff
|
|
}
|
|
|
|
dict := s2.MakeDictManual(dBytes, uint16(dOff))
|
|
if dict == nil {
|
|
return nil, fmt.Errorf("unable to create s2 dictionary")
|
|
}
|
|
return dict.Bytes(), nil
|
|
}
|
|
|
|
offsetsZstd := [3]int{1, 4, 8}
|
|
for i, off := range firstOffsets {
|
|
if i >= 3 || off == 0 || off >= out.Len() {
|
|
break
|
|
}
|
|
offsetsZstd[i] = off
|
|
}
|
|
println("\nCompressing. Offsets:", offsetsZstd)
|
|
return zstd.BuildDict(zstd.BuildDictOptions{
|
|
ID: o.ZstdDictID,
|
|
Contents: input,
|
|
History: out.Bytes(),
|
|
Offsets: offsetsZstd,
|
|
CompatV155: o.ZstdDictCompat,
|
|
Level: o.ZstdLevel,
|
|
DebugOut: o.Output,
|
|
})
|
|
}
|
|
|
|
// Multipliers used by the hashN functions.
// primeNbytes is the multiplier used when hashing N bytes.
const (
	prime3bytes = 506832829
	prime4bytes = 2654435761
	prime5bytes = 889523592379
	prime6bytes = 227718039650203
	prime7bytes = 58295818150454627
	prime8bytes = 0xcf1bbcdcb7a56463
)
|
|
|
|
// hashLen returns a hash of the lowest l bytes of u for a size size of h bytes.
|
|
// l must be >=4 and <=8. Any other value will return hash for 4 bytes.
|
|
// h should always be <32.
|
|
// Preferably h and l should be a constant.
|
|
// LENGTH 4 is passed straight through
|
|
func hashLen(u uint64, hashLog, mls uint8) uint32 {
|
|
switch mls {
|
|
case 5:
|
|
return hash5(u, hashLog)
|
|
case 6:
|
|
return hash6(u, hashLog)
|
|
case 7:
|
|
return hash7(u, hashLog)
|
|
case 8:
|
|
return hash8(u, hashLog)
|
|
default:
|
|
return uint32(u)
|
|
}
|
|
}
|
|
|
|
// hash3 returns the hash of the lower 3 bytes of u to fit in a hash table with h bits.
|
|
// Preferably h should be a constant and should always be <32.
|
|
func hash3(u uint32, h uint8) uint32 {
|
|
return ((u << (32 - 24)) * prime3bytes) >> ((32 - h) & 31)
|
|
}
|
|
|
|
// hash4 returns the hash of u to fit in a hash table with h bits.
|
|
// Preferably h should be a constant and should always be <32.
|
|
func hash4(u uint32, h uint8) uint32 {
|
|
return (u * prime4bytes) >> ((32 - h) & 31)
|
|
}
|
|
|
|
// hash4x64 returns the hash of the lowest 4 bytes of u to fit in a hash table with h bits.
|
|
// Preferably h should be a constant and should always be <32.
|
|
func hash4x64(u uint64, h uint8) uint32 {
|
|
return (uint32(u) * prime4bytes) >> ((32 - h) & 31)
|
|
}
|
|
|
|
// hash5 returns the hash of the lowest 5 bytes of u to fit in a hash table with h bits.
|
|
// Preferably h should be a constant and should always be <64.
|
|
func hash5(u uint64, h uint8) uint32 {
|
|
return uint32(((u << (64 - 40)) * prime5bytes) >> ((64 - h) & 63))
|
|
}
|
|
|
|
// hash6 returns the hash of the lowest 6 bytes of u to fit in a hash table with h bits.
|
|
// Preferably h should be a constant and should always be <64.
|
|
func hash6(u uint64, h uint8) uint32 {
|
|
return uint32(((u << (64 - 48)) * prime6bytes) >> ((64 - h) & 63))
|
|
}
|
|
|
|
// hash7 returns the hash of the lowest 7 bytes of u to fit in a hash table with h bits.
|
|
// Preferably h should be a constant and should always be <64.
|
|
func hash7(u uint64, h uint8) uint32 {
|
|
return uint32(((u << (64 - 56)) * prime7bytes) >> ((64 - h) & 63))
|
|
}
|
|
|
|
// hash8 returns the hash of u to fit in a hash table with h bits.
|
|
// Preferably h should be a constant and should always be <64.
|
|
func hash8(u uint64, h uint8) uint32 {
|
|
return uint32((u * prime8bytes) >> ((64 - h) & 63))
|
|
}
|