show cache pricing as well (if supported)

This commit is contained in:
emozilla
2026-04-05 22:39:02 -04:00
committed by Teknium
parent 0365f6202c
commit 3962bc84b7
2 changed files with 62 additions and 19 deletions

View File

@@ -2167,24 +2167,35 @@ def _prompt_model_selection(
has_pricing = bool(pricing and any(pricing.get(m) for m in ordered))
name_col = max((len(m) for m in ordered), default=0) + 2 if has_pricing else 0
# Pre-compute formatted prices and dynamic column width
_price_cache: dict[str, tuple[str, str]] = {}
# Pre-compute formatted prices and dynamic column widths
_price_cache: dict[str, tuple[str, str, str]] = {}
price_col = 3 # minimum width
cache_col = 0 # only set if any model has cache pricing
has_cache = False
if has_pricing:
for mid in ordered:
p = pricing.get(mid) # type: ignore[union-attr]
if p:
inp = _format_price_per_mtok(p.get("prompt", ""))
out = _format_price_per_mtok(p.get("completion", ""))
cache_read = p.get("input_cache_read", "")
cache = _format_price_per_mtok(cache_read) if cache_read else ""
if cache:
has_cache = True
else:
inp, out = "", ""
_price_cache[mid] = (inp, out)
inp, out, cache = "", "", ""
_price_cache[mid] = (inp, out, cache)
price_col = max(price_col, len(inp), len(out))
cache_col = max(cache_col, len(cache))
if has_cache:
cache_col = max(cache_col, 5) # minimum: "Cache" header
def _label(mid):
if has_pricing:
inp, out = _price_cache.get(mid, ("", ""))
inp, out, cache = _price_cache.get(mid, ("", "", ""))
price_part = f" {inp:>{price_col}} {out:>{price_col}}"
if has_cache:
price_part += f" {cache:>{cache_col}}"
base = f"{mid:<{name_col}}{price_part}"
else:
base = mid
@@ -2198,8 +2209,14 @@ def _prompt_model_selection(
# Build a pricing header hint for the menu title
menu_title = "Select default model:"
if has_pricing:
# Align the header with the model column
menu_title += f"\n {'':>{name_col}} {'In':>{price_col}} {'Out':>{price_col}} /Mtok"
# Align the header with the model column.
# Each choice is " {label}" (2 spaces) and simple_term_menu prepends
# a 3-char cursor region ("-> " or " "), so content starts at col 5.
pad = " " * 5
header = f"\n{pad}{'':>{name_col}} {'In':>{price_col}} {'Out':>{price_col}}"
if has_cache:
header += f" {'Cache':>{cache_col}}"
menu_title += header + " /Mtok"
# Try arrow-key menu first, fall back to number input
try:

View File

@@ -360,7 +360,7 @@ def _format_price_per_mtok(per_token_str: str) -> str:
def format_pricing_label(pricing: dict[str, str] | None) -> str:
"""Build a compact pricing label like '$3/$15' (input/output per Mtok).
"""Build a compact pricing label like 'in $3 · out $15 · cache $0.30/Mtok'.
Returns empty string when pricing is unavailable.
"""
@@ -374,9 +374,14 @@ def format_pricing_label(pricing: dict[str, str] | None) -> str:
out = _format_price_per_mtok(completion_price)
if inp == "free" and out == "free":
return "free"
if inp == out:
cache_read = pricing.get("input_cache_read", "")
cache_str = _format_price_per_mtok(cache_read) if cache_read else ""
if inp == out and not cache_str:
return f"{inp}/Mtok"
return f"in {inp} · out {out}/Mtok"
parts = [f"in {inp}", f"out {out}"]
if cache_str and cache_str != "?" and cache_str != inp:
parts.append(f"cache {cache_str}")
return " · ".join(parts) + "/Mtok"
def format_model_pricing_table(
@@ -393,17 +398,22 @@ def format_model_pricing_table(
if not models:
return []
# Build rows: (model_id, input_price, output_price, is_current)
rows: list[tuple[str, str, str, bool]] = []
# Build rows: (model_id, input_price, output_price, cache_price, is_current)
rows: list[tuple[str, str, str, str, bool]] = []
has_cache = False
for mid, _desc in models:
is_cur = mid == current_model
p = pricing_map.get(mid)
if p:
inp = _format_price_per_mtok(p.get("prompt", ""))
out = _format_price_per_mtok(p.get("completion", ""))
cache_read = p.get("input_cache_read", "")
cache = _format_price_per_mtok(cache_read) if cache_read else ""
if cache:
has_cache = True
else:
inp, out = "", ""
rows.append((mid, inp, out, is_cur))
inp, out, cache = "", "", ""
rows.append((mid, inp, out, cache, is_cur))
name_col = max(len(r[0]) for r in rows) + 2
# Compute price column widths from the actual data so decimals align
@@ -412,15 +422,26 @@ def format_model_pricing_table(
max((len(r[2]) for r in rows if r[2]), default=4),
3, # minimum: "In" / "Out" header
)
cache_col = max(
max((len(r[3]) for r in rows if r[3]), default=4),
5, # minimum: "Cache" header
) if has_cache else 0
lines: list[str] = []
# Header
lines.append(f"{indent}{'Model':<{name_col}} {'In':>{price_col}} {'Out':>{price_col}} /Mtok")
lines.append(f"{indent}{'-' * name_col} {'-' * price_col} {'-' * price_col}")
if has_cache:
lines.append(f"{indent}{'Model':<{name_col}} {'In':>{price_col}} {'Out':>{price_col}} {'Cache':>{cache_col}} /Mtok")
lines.append(f"{indent}{'-' * name_col} {'-' * price_col} {'-' * price_col} {'-' * cache_col}")
else:
lines.append(f"{indent}{'Model':<{name_col}} {'In':>{price_col}} {'Out':>{price_col}} /Mtok")
lines.append(f"{indent}{'-' * name_col} {'-' * price_col} {'-' * price_col}")
for mid, inp, out, is_cur in rows:
for mid, inp, out, cache, is_cur in rows:
marker = " ← current" if is_cur else ""
lines.append(f"{indent}{mid:<{name_col}} {inp:>{price_col}} {out:>{price_col}}{marker}")
if has_cache:
lines.append(f"{indent}{mid:<{name_col}} {inp:>{price_col}} {out:>{price_col}} {cache:>{cache_col}}{marker}")
else:
lines.append(f"{indent}{mid:<{name_col}} {inp:>{price_col}} {out:>{price_col}}{marker}")
return lines
@@ -459,10 +480,15 @@ def fetch_models_with_pricing(
mid = item.get("id")
pricing = item.get("pricing")
if mid and isinstance(pricing, dict):
result[mid] = {
entry: dict[str, str] = {
"prompt": str(pricing.get("prompt", "")),
"completion": str(pricing.get("completion", "")),
}
if pricing.get("input_cache_read"):
entry["input_cache_read"] = str(pricing["input_cache_read"])
if pricing.get("input_cache_write"):
entry["input_cache_write"] = str(pricing["input_cache_write"])
result[mid] = entry
_pricing_cache[cache_key] = result
return result