Compare commits
1 Commits
fix/687-tr
...
fix/691-tr
| Author | SHA1 | Date | |
|---|---|---|---|
| 8e14c1b7ec |
260
scripts/training_provenance.py
Normal file
260
scripts/training_provenance.py
Normal file
@@ -0,0 +1,260 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
[PROVENANCE] Training Pair Provenance Tracker
|
||||
Part of the Timmy Foundation tooling.
|
||||
|
||||
Adds, filters, and reports provenance metadata for JSONL training pairs.
|
||||
Tracks source_session_id, model, and timestamp for quality auditing.
|
||||
|
||||
Usage:
|
||||
# Tag pairs with provenance
|
||||
python3 scripts/training_provenance.py tag input.jsonl -o tagged.jsonl \
|
||||
--session abc123 --model nous/hermes-3
|
||||
|
||||
# Filter by model (exclude Anthropic-sourced)
|
||||
python3 scripts/training_provenance.py filter input.jsonl -o filtered.jsonl \
|
||||
--exclude-model anthropic
|
||||
|
||||
# Report: pair count by source model
|
||||
python3 scripts/training_provenance.py report input.jsonl
|
||||
|
||||
# Pipe support
|
||||
cat pairs.jsonl | python3 scripts/training_provenance.py report -
|
||||
"""
|
||||
|
||||
import sys
|
||||
import json
|
||||
import argparse
|
||||
from datetime import datetime, timezone
|
||||
from collections import Counter
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
|
||||
PROVENANCE_KEYS = ["source_session_id", "source_model", "source_timestamp"]
|
||||
|
||||
|
||||
def tag_pair(pair: Dict[str, Any], session_id: Optional[str] = None,
|
||||
model: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Add provenance metadata to a training pair."""
|
||||
meta = pair.get("_provenance", {})
|
||||
|
||||
if session_id:
|
||||
meta["source_session_id"] = session_id
|
||||
if model:
|
||||
meta["source_model"] = model
|
||||
meta["source_timestamp"] = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
if meta:
|
||||
pair["_provenance"] = meta
|
||||
|
||||
return pair
|
||||
|
||||
|
||||
def filter_pairs(input_path: str, output_path: str,
|
||||
include_models: Optional[list] = None,
|
||||
exclude_models: Optional[list] = None,
|
||||
min_session_age: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Filter pairs by provenance metadata."""
|
||||
kept = []
|
||||
removed = []
|
||||
errors = 0
|
||||
|
||||
source = sys.stdin if input_path == "-" else open(input_path, "r")
|
||||
|
||||
try:
|
||||
for line in source:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
pair = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
errors += 1
|
||||
continue
|
||||
|
||||
prov = pair.get("_provenance", {})
|
||||
model = prov.get("source_model", "unknown")
|
||||
|
||||
should_keep = True
|
||||
|
||||
if include_models:
|
||||
should_keep = should_keep and model in include_models
|
||||
|
||||
if exclude_models:
|
||||
should_keep = should_keep and model not in exclude_models
|
||||
|
||||
if should_keep:
|
||||
kept.append(pair)
|
||||
else:
|
||||
removed.append(pair)
|
||||
finally:
|
||||
if source is not sys.stdin:
|
||||
source.close()
|
||||
|
||||
# Write output
|
||||
if output_path:
|
||||
out = sys.stdout if output_path == "-" else open(output_path, "w")
|
||||
try:
|
||||
for pair in kept:
|
||||
out.write(json.dumps(pair, ensure_ascii=False) + "\n")
|
||||
finally:
|
||||
if out is not sys.stdin:
|
||||
out.close()
|
||||
|
||||
return {
|
||||
"total": len(kept) + len(removed),
|
||||
"kept": len(kept),
|
||||
"filtered_out": len(removed),
|
||||
"errors": errors,
|
||||
}
|
||||
|
||||
|
||||
def report(input_path: str) -> Dict[str, Any]:
|
||||
"""Report pair counts by source model and session."""
|
||||
model_counts = Counter()
|
||||
session_counts = Counter()
|
||||
tagged = 0
|
||||
untagged = 0
|
||||
total = 0
|
||||
errors = 0
|
||||
|
||||
source = sys.stdin if input_path == "-" else open(input_path, "r")
|
||||
|
||||
try:
|
||||
for line in source:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
pair = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
errors += 1
|
||||
continue
|
||||
|
||||
total += 1
|
||||
prov = pair.get("_provenance", {})
|
||||
|
||||
if prov:
|
||||
tagged += 1
|
||||
model = prov.get("source_model", "unknown")
|
||||
session = prov.get("source_session_id", "unknown")
|
||||
model_counts[model] += 1
|
||||
session_counts[session] += 1
|
||||
else:
|
||||
untagged += 1
|
||||
finally:
|
||||
if source is not sys.stdin:
|
||||
source.close()
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"tagged": tagged,
|
||||
"untagged": untagged,
|
||||
"tag_rate": round(tagged / max(total, 1) * 100, 1),
|
||||
"by_model": dict(model_counts.most_common(20)),
|
||||
"by_session": dict(session_counts.most_common(10)),
|
||||
"errors": errors,
|
||||
}
|
||||
|
||||
|
||||
def stamp_command(input_path: str, output_path: str,
|
||||
session_id: Optional[str], model: Optional[str]) -> Dict[str, Any]:
|
||||
"""Tag all pairs in a file with provenance metadata."""
|
||||
tagged = 0
|
||||
skipped = 0
|
||||
errors = 0
|
||||
|
||||
source = sys.stdin if input_path == "-" else open(input_path, "r")
|
||||
out = sys.stdout if output_path == "-" else open(output_path, "w")
|
||||
|
||||
try:
|
||||
for line in source:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
pair = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
errors += 1
|
||||
continue
|
||||
|
||||
# Skip if already tagged with same model
|
||||
existing = pair.get("_provenance", {})
|
||||
if existing.get("source_model") == model and existing.get("source_session_id") == session_id:
|
||||
skipped += 1
|
||||
out.write(line + "\n")
|
||||
continue
|
||||
|
||||
pair = tag_pair(pair, session_id=session_id, model=model)
|
||||
out.write(json.dumps(pair, ensure_ascii=False) + "\n")
|
||||
tagged += 1
|
||||
finally:
|
||||
if source is not sys.stdin:
|
||||
source.close()
|
||||
if out is not sys.stdin:
|
||||
out.close()
|
||||
|
||||
return {"tagged": tagged, "skipped": skipped, "errors": errors}
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Training pair provenance tracking")
|
||||
sub = parser.add_subparsers(dest="command", required=True)
|
||||
|
||||
# tag subcommand
|
||||
tag_p = sub.add_parser("tag", help="Tag pairs with provenance metadata")
|
||||
tag_p.add_argument("input", help="Input JSONL file (use - for stdin)")
|
||||
tag_p.add_argument("-o", "--output", default="-", help="Output JSONL file")
|
||||
tag_p.add_argument("--session", help="Source session ID")
|
||||
tag_p.add_argument("--model", help="Source model name")
|
||||
|
||||
# filter subcommand
|
||||
filt_p = sub.add_parser("filter", help="Filter pairs by provenance")
|
||||
filt_p.add_argument("input", help="Input JSONL file (use - for stdin)")
|
||||
filt_p.add_argument("-o", "--output", default="-", help="Output JSONL file")
|
||||
filt_p.add_argument("--include-model", action="append", help="Only include these models")
|
||||
filt_p.add_argument("--exclude-model", action="append", help="Exclude these models")
|
||||
|
||||
# report subcommand
|
||||
rpt_p = sub.add_parser("report", help="Report provenance statistics")
|
||||
rpt_p.add_argument("input", help="Input JSONL file (use - for stdin)")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.command == "tag":
|
||||
result = stamp_command(args.input, args.output, args.session, args.model)
|
||||
print(f"Tagged: {result['tagged']} Skipped: {result['skipped']} Errors: {result['errors']}", file=sys.stderr)
|
||||
|
||||
elif args.command == "filter":
|
||||
result = filter_pairs(
|
||||
args.input, args.output,
|
||||
include_models=args.include_model,
|
||||
exclude_models=args.exclude_model,
|
||||
)
|
||||
print(f"Total: {result['total']} Kept: {result['kept']} Filtered: {result['filtered_out']}", file=sys.stderr)
|
||||
|
||||
elif args.command == "report":
|
||||
result = report(args.input)
|
||||
print(f"Training Pair Provenance Report", file=sys.stderr)
|
||||
print(f"{'='*40}", file=sys.stderr)
|
||||
print(f"Total pairs: {result['total']}", file=sys.stderr)
|
||||
print(f"Tagged: {result['tagged']} ({result['tag_rate']}%)", file=sys.stderr)
|
||||
print(f"Untagged: {result['untagged']}", file=sys.stderr)
|
||||
|
||||
if result['by_model']:
|
||||
print(f"\nBy source model:", file=sys.stderr)
|
||||
for model, count in result['by_model'].items():
|
||||
print(f" {model}: {count}", file=sys.stderr)
|
||||
|
||||
if result['by_session']:
|
||||
print(f"\nBy source session (top 10):", file=sys.stderr)
|
||||
for session, count in result['by_session'].items():
|
||||
session_short = session[:12] + "..." if len(session) > 12 else session
|
||||
print(f" {session_short}: {count}", file=sys.stderr)
|
||||
|
||||
# Output JSON to stdout
|
||||
print(json.dumps(result, indent=2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,266 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
[QUALITY] Training Data Quality Filter
|
||||
Part of the Timmy Foundation tooling.
|
||||
|
||||
Scores and filters JSONL training pairs on specificity, length ratio,
|
||||
and code correctness. Removes low-quality pairs and reports results.
|
||||
|
||||
Usage:
|
||||
python3 scripts/training_quality_filter.py input.jsonl -o filtered.jsonl
|
||||
python3 scripts/training_quality_filter.py input.jsonl --threshold 0.4
|
||||
cat input.jsonl | python3 scripts/training_quality_filter.py -
|
||||
"""
|
||||
|
||||
import sys
|
||||
import json
|
||||
import argparse
|
||||
import re
|
||||
from typing import Dict, Any, Tuple
|
||||
|
||||
DEFAULT_THRESHOLD = 0.35
|
||||
MIN_TERSE_LEN = 3
|
||||
MIN_RICH_LEN = 10
|
||||
|
||||
|
||||
def score_specificity(terse: str, rich: str) -> float:
|
||||
"""Score how specific the rich response is vs the terse prompt.
|
||||
|
||||
Higher score = more specific, actionable detail in the rich version.
|
||||
"""
|
||||
if not terse or not rich:
|
||||
return 0.0
|
||||
|
||||
# Ratio of unique words (higher = more varied/specific language)
|
||||
rich_words = rich.lower().split()
|
||||
terse_words = terse.lower().split()
|
||||
|
||||
if len(rich_words) < 3:
|
||||
return 0.1
|
||||
|
||||
unique_ratio = len(set(rich_words)) / len(rich_words)
|
||||
|
||||
# Check for concrete details: numbers, file paths, commands, code refs
|
||||
concrete_patterns = [
|
||||
r"\b\d+\b", # numbers
|
||||
r"[/\\]\w+", # file paths
|
||||
r"`[^`]+`", # inline code
|
||||
r"\b(fix|add|remove|update|create|delete|check|run|use)\b", # action verbs
|
||||
]
|
||||
concrete_count = sum(
|
||||
len(re.findall(p, rich, re.IGNORECASE)) for p in concrete_patterns
|
||||
)
|
||||
concrete_score = min(concrete_count / 5.0, 1.0)
|
||||
|
||||
# Length expansion ratio (rich should be meaningfully longer than terse)
|
||||
expansion = len(rich_words) / max(len(terse_words), 1)
|
||||
expansion_score = min(expansion / 5.0, 1.0)
|
||||
|
||||
return round(0.3 * unique_ratio + 0.4 * concrete_score + 0.3 * expansion_score, 3)
|
||||
|
||||
|
||||
def score_length_ratio(terse: str, rich: str) -> float:
|
||||
"""Score the length ratio between terse and rich.
|
||||
|
||||
Too short rich = low quality. Too long = possibly padded.
|
||||
Sweet spot: 3-15x expansion.
|
||||
"""
|
||||
if not terse or not rich:
|
||||
return 0.0
|
||||
|
||||
t_len = len(terse.split())
|
||||
r_len = len(rich.split())
|
||||
|
||||
if t_len < MIN_TERSE_LEN or r_len < MIN_RICH_LEN:
|
||||
return 0.1
|
||||
|
||||
ratio = r_len / max(t_len, 1)
|
||||
|
||||
if ratio < 1.5:
|
||||
return 0.2 # barely expanded
|
||||
elif ratio < 3.0:
|
||||
return 0.5 # some expansion
|
||||
elif ratio <= 15.0:
|
||||
return 1.0 # good expansion
|
||||
elif ratio <= 30.0:
|
||||
return 0.7 # possibly padded
|
||||
else:
|
||||
return 0.4 # very padded
|
||||
|
||||
|
||||
def score_code_correctness(terse: str, rich: str) -> float:
|
||||
"""Score code blocks in the rich response for basic correctness.
|
||||
|
||||
Checks for matching brackets, valid-looking syntax patterns.
|
||||
"""
|
||||
if not rich:
|
||||
return 0.5 # no code = neutral
|
||||
|
||||
code_blocks = re.findall(r"```(?:\w*)\n(.*?)```", rich, re.DOTALL)
|
||||
if not code_blocks:
|
||||
return 0.5 # no code blocks = neutral
|
||||
|
||||
scores = []
|
||||
for block in code_blocks:
|
||||
block_score = 1.0
|
||||
|
||||
# Check bracket balance
|
||||
for open_c, close_c in [("(", ")"), ("[", "]"), ("{", "}")]:
|
||||
if block.count(open_c) != block.count(close_c):
|
||||
block_score -= 0.3
|
||||
|
||||
# Check for common syntax errors
|
||||
if re.search(r"def \w+[^:]*\n(?!\s)", block):
|
||||
block_score -= 0.2 # missing colon or body
|
||||
|
||||
# Minimum viable code length
|
||||
if len(block.strip()) < 10:
|
||||
block_score -= 0.3
|
||||
|
||||
scores.append(max(block_score, 0.0))
|
||||
|
||||
return round(sum(scores) / len(scores), 3) if scores else 0.5
|
||||
|
||||
|
||||
def score_pair(pair: Dict[str, Any]) -> Tuple[float, Dict[str, float]]:
|
||||
"""Score a single training pair. Returns (total_score, breakdown)."""
|
||||
terse = pair.get("terse", "") or pair.get("prompt", "") or ""
|
||||
rich = pair.get("rich", "") or pair.get("response", "") or ""
|
||||
|
||||
spec = score_specificity(terse, rich)
|
||||
length = score_length_ratio(terse, rich)
|
||||
code = score_code_correctness(terse, rich)
|
||||
|
||||
# Weighted total
|
||||
total = round(0.4 * spec + 0.3 * length + 0.3 * code, 3)
|
||||
|
||||
return total, {"specificity": spec, "length_ratio": length, "code_correctness": code}
|
||||
|
||||
|
||||
def filter_pairs(input_path: str, output_path: str, threshold: float,
|
||||
report: bool = False) -> Dict[str, Any]:
|
||||
"""Filter JSONL training pairs by quality score."""
|
||||
kept = []
|
||||
removed = []
|
||||
errors = 0
|
||||
|
||||
source = sys.stdin if input_path == "-" else open(input_path, "r")
|
||||
|
||||
try:
|
||||
for line_num, line in enumerate(source, 1):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
pair = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
errors += 1
|
||||
continue
|
||||
|
||||
score, breakdown = score_pair(pair)
|
||||
entry = {**pair, "_quality_score": score, "_quality_breakdown": breakdown}
|
||||
|
||||
if score >= threshold:
|
||||
kept.append(entry)
|
||||
else:
|
||||
removed.append(entry)
|
||||
finally:
|
||||
if source is not sys.stdin:
|
||||
source.close()
|
||||
|
||||
# Write filtered output
|
||||
if output_path:
|
||||
out = sys.stdout if output_path == "-" else open(output_path, "w")
|
||||
try:
|
||||
for pair in kept:
|
||||
# Strip internal scoring fields before output
|
||||
clean = {k: v for k, v in pair.items() if not k.startswith("_quality")}
|
||||
out.write(json.dumps(clean, ensure_ascii=False) + "\n")
|
||||
finally:
|
||||
if out is not sys.stdin:
|
||||
out.close()
|
||||
|
||||
result = {
|
||||
"total": len(kept) + len(removed),
|
||||
"kept": len(kept),
|
||||
"filtered_out": len(removed),
|
||||
"errors": errors,
|
||||
"threshold": threshold,
|
||||
"filter_rate": round(len(removed) / max(len(kept) + len(removed), 1) * 100, 1),
|
||||
}
|
||||
|
||||
if report and removed:
|
||||
# Show worst offenders
|
||||
removed_sorted = sorted(removed, key=lambda x: x["_quality_score"])
|
||||
result["worst_5"] = [
|
||||
{
|
||||
"score": e["_quality_score"],
|
||||
"terse": (e.get("terse", "") or e.get("prompt", ""))[:80],
|
||||
"breakdown": e["_quality_breakdown"],
|
||||
}
|
||||
for e in removed_sorted[:5]
|
||||
]
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Filter training data pairs by quality")
|
||||
parser.add_argument("input", help="Input JSONL file (use - for stdin)")
|
||||
parser.add_argument("-o", "--output", default="-", help="Output JSONL file (default: stdout)")
|
||||
parser.add_argument("-t", "--threshold", type=float, default=DEFAULT_THRESHOLD,
|
||||
help=f"Quality threshold (0.0-1.0, default: {DEFAULT_THRESHOLD})")
|
||||
parser.add_argument("--report", action="store_true", help="Show quality report")
|
||||
parser.add_argument("--dry-run", action="store_true", help="Score only, dont filter")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.dry_run:
|
||||
# Just score and report, no filtering
|
||||
source = sys.stdin if args.input == "-" else open(args.input, "r")
|
||||
scores = []
|
||||
try:
|
||||
for line in source:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
pair = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
score, breakdown = score_pair(pair)
|
||||
scores.append(score)
|
||||
finally:
|
||||
if source is not sys.stdin:
|
||||
source.close()
|
||||
|
||||
if scores:
|
||||
avg = sum(scores) / len(scores)
|
||||
below = sum(1 for s in scores if s < args.threshold)
|
||||
print(f"Total pairs: {len(scores)}")
|
||||
print(f"Average score: {avg:.3f}")
|
||||
print(f"Below threshold ({args.threshold}): {below} ({below/len(scores)*100:.1f}%)")
|
||||
print(f"Min: {min(scores):.3f} Max: {max(scores):.3f} Median: {sorted(scores)[len(scores)//2]:.3f}")
|
||||
return
|
||||
|
||||
result = filter_pairs(args.input, args.output, args.threshold, report=args.report)
|
||||
|
||||
print(f"Training Data Quality Filter", file=sys.stderr)
|
||||
print(f"{'='*40}", file=sys.stderr)
|
||||
print(f"Total pairs: {result['total']}", file=sys.stderr)
|
||||
print(f"Kept: {result['kept']}", file=sys.stderr)
|
||||
print(f"Filtered out: {result['filtered_out']} ({result['filter_rate']}%)", file=sys.stderr)
|
||||
print(f"Errors: {result['errors']}", file=sys.stderr)
|
||||
print(f"Threshold: {result['threshold']}", file=sys.stderr)
|
||||
|
||||
if args.report and "worst_5" in result:
|
||||
print(f"\nWorst 5 pairs:", file=sys.stderr)
|
||||
for w in result["worst_5"]:
|
||||
terse_preview = w["terse"][:60]
|
||||
print(f" [{w['score']:.3f}] {terse_preview}...", file=sys.stderr)
|
||||
bd = w["breakdown"]
|
||||
print(f" spec={bd['specificity']} length={bd['length_ratio']} code={bd['code_correctness']}", file=sys.stderr)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user