Compare commits
2 Commits
fix/issue-
...
fix/issue-
| Author | SHA1 | Date |
|---|---|---|
| | 9f514a14f5 | |
| | 2107a8cd73 | |
24
tasks.py
24
tasks.py
@@ -616,22 +616,26 @@ def normalize_candidate_entry(candidate, batch_id, index):
|
||||
|
||||
|
||||
def normalize_training_examples(examples, batch_id, tweet_ids, fallback_prompt, fallback_response):
|
||||
_CORE_FIELDS = {"prompt", "instruction", "response", "answer", "task_type"}
|
||||
normalized = []
|
||||
for index, example in enumerate(examples, start=1):
|
||||
prompt = str(example.get("prompt") or example.get("instruction") or "").strip()
|
||||
response = str(example.get("response") or example.get("answer") or "").strip()
|
||||
if not prompt or not response:
|
||||
continue
|
||||
normalized.append(
|
||||
{
|
||||
"example_id": f"{batch_id}-example-{index:02d}",
|
||||
"batch_id": batch_id,
|
||||
"task_type": str(example.get("task_type") or "analysis").strip() or "analysis",
|
||||
"prompt": prompt,
|
||||
"response": response,
|
||||
"tweet_ids": tweet_ids,
|
||||
}
|
||||
)
|
||||
entry = {
|
||||
"example_id": f"{batch_id}-example-{index:02d}",
|
||||
"batch_id": batch_id,
|
||||
"task_type": str(example.get("task_type") or "analysis").strip() or "analysis",
|
||||
"prompt": prompt,
|
||||
"response": response,
|
||||
"tweet_ids": tweet_ids,
|
||||
}
|
||||
# Preserve optional metadata fields (category, tags, source_issue, etc.)
|
||||
for key, value in example.items():
|
||||
if key not in _CORE_FIELDS and key not in entry and value is not None:
|
||||
entry[key] = value
|
||||
normalized.append(entry)
|
||||
if normalized:
|
||||
return normalized
|
||||
return [
|
||||
|
||||
@@ -323,6 +323,89 @@ class TestNormalizeTrainingExamples:
|
||||
assert result[0]["response"] == "A1"
|
||||
|
||||
|
||||
def test_metadata_category_preserved(self):
    """A `category` metadata key survives normalization untouched."""
    payload = [{"prompt": "Q1", "response": "A1", "category": "crisis-response"}]
    normalized = normalize_training_examples(payload, "b001", ["t1"], "fp", "fr")
    assert len(normalized) == 1
    assert normalized[0]["category"] == "crisis-response"
|
||||
|
||||
def test_metadata_tags_preserved(self):
    """A `tags` list rides through normalization unchanged."""
    payload = [
        {"prompt": "Q1", "response": "A1", "tags": ["manipulation", "edge-case"]}
    ]
    normalized = normalize_training_examples(payload, "b001", ["t1"], "fp", "fr")
    assert normalized[0]["tags"] == ["manipulation", "edge-case"]
|
||||
|
||||
def test_metadata_source_issue_preserved(self):
    """An integer `source_issue` field survives normalization."""
    payload = [{"prompt": "Q1", "response": "A1", "source_issue": 598}]
    normalized = normalize_training_examples(payload, "b001", ["t1"], "fp", "fr")
    assert normalized[0]["source_issue"] == 598
|
||||
|
||||
def test_multiple_metadata_fields_preserved(self):
    """Several metadata fields on one example all pass through together."""
    payload = [
        {
            "prompt": "Q1",
            "response": "A1",
            "category": "boundary-test",
            "tags": ["joking"],
            "source_issue": 598,
            "difficulty": "hard",
        },
    ]
    normalized = normalize_training_examples(payload, "b001", ["t1"], "fp", "fr")
    entry = normalized[0]
    # Each optional field must come back exactly as supplied.
    for key, expected in (
        ("category", "boundary-test"),
        ("tags", ["joking"]),
        ("source_issue", 598),
        ("difficulty", "hard"),
    ):
        assert entry[key] == expected
|
||||
|
||||
def test_metadata_does_not_override_core_fields(self):
    """An example-level key that collides with a core field is ignored."""
    payload = [
        {"prompt": "Q1", "response": "A1", "batch_id": "SHOULD_NOT_APPEAR"}
    ]
    normalized = normalize_training_examples(payload, "b001", ["t1"], "fp", "fr")
    # The batch_id supplied by the caller wins over the example's own value.
    assert normalized[0]["batch_id"] == "b001"
|
||||
|
||||
def test_no_metadata_backward_compatible(self):
    """An example carrying only core fields normalizes exactly as before."""
    payload = [{"prompt": "Q1", "response": "A1", "task_type": "analysis"}]
    normalized = normalize_training_examples(payload, "b001", ["t1"], "fp", "fr")
    # No extra keys may appear when the input has no metadata.
    assert set(normalized[0].keys()) == {
        "example_id",
        "batch_id",
        "task_type",
        "prompt",
        "response",
        "tweet_ids",
    }
|
||||
|
||||
def test_none_metadata_values_skipped(self):
    """Metadata whose value is None is dropped during normalization."""
    payload = [{"prompt": "Q1", "response": "A1", "category": None}]
    normalized = normalize_training_examples(payload, "b001", ["t1"], "fp", "fr")
    assert "category" not in normalized[0]
|
||||
|
||||
|
||||
class TestNormalizeRubricScores:
|
||||
"""normalize_rubric_scores() cleans eval rubric output."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user