Enhance batch processing and image generation tools

- Updated batch processing to include robust resume functionality by scanning completed prompts based on content rather than indices, improving recovery from failures.
- Implemented retry logic for image downloads with exponential backoff to handle transient failures effectively.
- Refined image generation tool to utilize the FLUX 2 Pro model, updating descriptions and parameters for clarity and consistency.
- Added new configuration scripts for GLM 4.7 and Imagen tasks, enhancing usability and logging capabilities.
- Removed outdated scripts and test files to streamline the codebase.
This commit is contained in:
teknium
2026-01-18 10:11:59 +00:00
parent b32cc4b09d
commit 6eb76c7c1a
14 changed files with 293 additions and 233 deletions

1
.gitignore vendored
View File

@@ -30,3 +30,4 @@ run_datagen_megascience_glm4-6.sh
run_datagen_sonnet.sh
source-data/*
run_datagen_megascience_glm4-6.sh
data/*

View File

@@ -379,8 +379,13 @@ def _process_batch_worker(args: Tuple) -> Dict[str, Any]:
batch_tool_stats[tool_name]["success"] += stats["success"]
batch_tool_stats[tool_name]["failure"] += stats["failure"]
completed_in_batch.append(prompt_index)
print(f" ✅ Prompt {prompt_index} completed")
# Only mark as completed if successfully saved (failed prompts can be retried on resume)
if result["success"] and result["trajectory"]:
completed_in_batch.append(prompt_index)
status = "⚠️ partial" if result.get("partial") else ""
print(f" {status} Prompt {prompt_index} completed")
else:
print(f" ❌ Prompt {prompt_index} failed (will retry on resume)")
print(f"✅ Batch {batch_num}: Completed ({len(prompts_to_process)} prompts processed)")
@@ -578,6 +583,83 @@ class BatchRunner:
with open(self.checkpoint_file, 'w', encoding='utf-8') as f:
json.dump(checkpoint_data, f, indent=2, ensure_ascii=False)
def _scan_completed_prompts_by_content(self) -> set:
"""
Scan all batch files and extract completed prompts by their actual content.
This provides a more robust resume mechanism that matches on prompt text
rather than indices, allowing recovery even if indices don't match.
Returns:
set: Set of prompt texts that have been successfully processed
"""
completed_prompts = set()
batch_files = sorted(self.output_dir.glob("batch_*.jsonl"))
if not batch_files:
return completed_prompts
print(f"📂 Scanning {len(batch_files)} batch files for completed prompts...")
for batch_file in batch_files:
try:
with open(batch_file, 'r', encoding='utf-8') as f:
for line in f:
try:
entry = json.loads(line.strip())
# Skip failed entries - we want to retry these
if entry.get("failed", False):
continue
# Extract the human/user prompt from conversations
conversations = entry.get("conversations", [])
for msg in conversations:
if msg.get("from") == "human":
prompt_text = msg.get("value", "").strip()
if prompt_text:
completed_prompts.add(prompt_text)
break # Only need the first human message
except json.JSONDecodeError:
continue
except Exception as e:
print(f" ⚠️ Warning: Error reading {batch_file.name}: {e}")
return completed_prompts
def _filter_dataset_by_completed(self, completed_prompts: set) -> Tuple[List[Dict], List[int]]:
"""
Filter the dataset to exclude prompts that have already been completed.
Args:
completed_prompts: Set of prompt texts that have been completed
Returns:
Tuple of (filtered_dataset, skipped_indices)
"""
filtered_dataset = []
skipped_indices = []
for idx, entry in enumerate(self.dataset):
# Extract prompt from the dataset entry
prompt_text = entry.get("prompt", "").strip()
# Also check conversations format
if not prompt_text:
conversations = entry.get("conversations", [])
for msg in conversations:
role = msg.get("role") or msg.get("from")
if role in ("user", "human"):
prompt_text = (msg.get("content") or msg.get("value", "")).strip()
break
if prompt_text in completed_prompts:
skipped_indices.append(idx)
else:
# Keep original index for tracking
filtered_dataset.append((idx, entry))
return filtered_dataset, skipped_indices
def run(self, resume: bool = False):
"""
@@ -590,17 +672,48 @@ class BatchRunner:
print("🚀 Starting Batch Processing")
print("=" * 70)
# Load checkpoint
checkpoint_data = self._load_checkpoint() if resume else {
# Smart resume: scan batch files by content to find completed prompts
completed_prompt_texts = set()
if resume:
completed_prompt_texts = self._scan_completed_prompts_by_content()
if completed_prompt_texts:
print(f" Found {len(completed_prompt_texts)} already-completed prompts by content matching")
# Filter dataset to only include unprocessed prompts
if resume and completed_prompt_texts:
filtered_entries, skipped_indices = self._filter_dataset_by_completed(completed_prompt_texts)
if not filtered_entries:
print("\n✅ All prompts have already been processed!")
return
# Recreate batches from filtered entries (keeping original indices for tracking)
batches_to_process = []
for i in range(0, len(filtered_entries), self.batch_size):
batch = filtered_entries[i:i + self.batch_size]
batches_to_process.append(batch)
self.batches = batches_to_process
# Print prominent resume summary
print("\n" + "=" * 70)
print("📊 RESUME SUMMARY")
print("=" * 70)
print(f" Original dataset size: {len(self.dataset):,} prompts")
print(f" Already completed: {len(skipped_indices):,} prompts")
print(f" ─────────────────────────────────────────")
print(f" 🎯 RESUMING WITH: {len(filtered_entries):,} prompts")
print(f" New batches created: {len(batches_to_process)}")
print("=" * 70 + "\n")
# Initialize checkpoint data (needed for saving at the end)
checkpoint_data = {
"run_name": self.run_name,
"completed_prompts": [],
"batch_stats": {},
"last_updated": None
}
if resume and checkpoint_data.get("completed_prompts"):
print(f"📂 Resuming from checkpoint ({len(checkpoint_data['completed_prompts'])} prompts already completed)")
# Prepare configuration for workers
config = {
"distribution": self.distribution,
@@ -617,8 +730,8 @@ class BatchRunner:
"provider_sort": self.provider_sort,
}
# Get completed prompts set
completed_prompts_set = set(checkpoint_data.get("completed_prompts", []))
# For backward compatibility, still track by index (but this is secondary to content matching)
completed_prompts_set = set()
# Aggregate statistics across all batches
total_tool_stats = {}
@@ -709,45 +822,51 @@ class BatchRunner:
stats["success_rate"] = 0.0
stats["failure_rate"] = 0.0
# Combine all batch files into a single trajectories.jsonl file
# Combine ALL batch files in directory into a single trajectories.jsonl file
# This includes both old batches (from previous runs) and new batches (from resume)
# Also filter out corrupted entries (where model generated invalid tool names)
combined_file = self.output_dir / "trajectories.jsonl"
print(f"\n📦 Combining batch files into {combined_file.name}...")
print(f"\n📦 Combining ALL batch files into {combined_file.name}...")
VALID_TOOLS = {'web_search', 'web_extract', 'web_crawl', 'terminal', 'vision_analyze',
'image_generate', 'mixture_of_agents'}
total_entries = 0
filtered_entries = 0
batch_files_found = 0
# Find ALL batch files in the output directory (handles resume merging old + new)
all_batch_files = sorted(self.output_dir.glob("batch_*.jsonl"))
with open(combined_file, 'w', encoding='utf-8') as outfile:
for batch_num in range(len(self.batches)):
batch_file = self.output_dir / f"batch_{batch_num}.jsonl"
if batch_file.exists():
with open(batch_file, 'r', encoding='utf-8') as infile:
for line in infile:
total_entries += 1
try:
data = json.loads(line)
tool_stats = data.get('tool_stats', {})
# Check for invalid tool names (model hallucinations)
invalid_tools = [k for k in tool_stats.keys() if k not in VALID_TOOLS]
if invalid_tools:
filtered_entries += 1
invalid_preview = invalid_tools[0][:50] + "..." if len(invalid_tools[0]) > 50 else invalid_tools[0]
print(f" ⚠️ Filtering corrupted entry (batch {batch_num}): invalid tool '{invalid_preview}'")
continue
outfile.write(line)
except json.JSONDecodeError:
for batch_file in all_batch_files:
batch_files_found += 1
batch_num = batch_file.stem.split("_")[1] # Extract batch number for logging
with open(batch_file, 'r', encoding='utf-8') as infile:
for line in infile:
total_entries += 1
try:
data = json.loads(line)
tool_stats = data.get('tool_stats', {})
# Check for invalid tool names (model hallucinations)
invalid_tools = [k for k in tool_stats.keys() if k not in VALID_TOOLS]
if invalid_tools:
filtered_entries += 1
print(f" ⚠️ Filtering invalid JSON entry (batch {batch_num})")
invalid_preview = invalid_tools[0][:50] + "..." if len(invalid_tools[0]) > 50 else invalid_tools[0]
print(f" ⚠️ Filtering corrupted entry (batch {batch_num}): invalid tool '{invalid_preview}'")
continue
outfile.write(line)
except json.JSONDecodeError:
filtered_entries += 1
print(f" ⚠️ Filtering invalid JSON entry (batch {batch_num})")
if filtered_entries > 0:
print(f"⚠️ Filtered {filtered_entries} corrupted entries out of {total_entries} total")
print(f"✅ Combined {len(self.batches)} batch files into trajectories.jsonl ({total_entries - filtered_entries} entries)")
print(f"✅ Combined {batch_files_found} batch files into trajectories.jsonl ({total_entries - filtered_entries} entries)")
# Save final statistics
final_stats = {
@@ -769,8 +888,9 @@ class BatchRunner:
print("\n" + "=" * 70)
print("📊 BATCH PROCESSING COMPLETE")
print("=" * 70)
print(f"Total prompts processed: {len(self.dataset)}")
print(f"✅ Total batches: {len(self.batches)}")
print(f"Prompts processed this run: {sum(r.get('processed', 0) for r in results)}")
print(f"✅ Total trajectories in merged file: {total_entries - filtered_entries}")
print(f"✅ Total batch files merged: {batch_files_found}")
print(f"⏱️ Total duration: {round(time.time() - start_time, 2)}s")
print(f"\n📈 Tool Usage Statistics:")
print("-" * 70)

View File

@@ -11,13 +11,13 @@ echo "📝 Logging output to: $LOG_FILE"
python batch_runner.py \
--dataset_file="source-data/hermes-agent-imagen-data/hermes_agent_imagen_train_sft.jsonl" \
--batch_size=10 \
--batch_size=20 \
--run_name="imagen_train_sft_glm4.7" \
--distribution="image_gen" \
--model="z-ai/glm-4.7" \
--base_url="https://openrouter.ai/api/v1" \
--providers_allowed="gmicloud,siliconflow,atlas-cloud,z-ai,novita" \
--num_workers=1 \
--num_workers=50 \
--max_turns=25 \
--ephemeral_system_prompt="When generating an image for the user view the image by using the vision_analyze tool to ensure it is what the user wanted. If it isn't feel free to retry a few times. If none are perfect, choose the best option that is the closest match, and explain its imperfections. If the image generation tool fails, try again a few times. If the vision analyze tool fails, provide the image to the user and explain it is your best effort attempt." \
2>&1 | tee "$LOG_FILE"

View File

@@ -0,0 +1,27 @@
#!/bin/bash
# Batch-generation run: GLM 4.7 (thinking) on MegaScience SFT split 1
# (prompts 1-10k) via OpenRouter, with output mirrored to a timestamped log.

# Create logs directory if it doesn't exist
mkdir -p logs

# Generate log filename with timestamp
LOG_FILE="logs/glm4.7-thinking-sft1-10k_$(date +%Y%m%d_%H%M%S).log"
echo "📝 Logging output to: $LOG_FILE"

# Run the batch processor. --resume skips prompts already completed in a
# previous run (matched by content). Note: no comment lines may be placed
# between the backslash-continued arguments below.
python batch_runner.py \
--dataset_file="source-data/hermes-agent-megascience-data/hermes_agent_megascience_sft_train_1_10k.jsonl" \
--batch_size=20 \
--run_name="megascience_glm4.7-thinking-sft1" \
--distribution="science" \
--model="z-ai/glm-4.7" \
--base_url="https://openrouter.ai/api/v1" \
--providers_allowed="gmicloud,siliconflow,atlas-cloud,z-ai,novita" \
--num_workers=50 \
--max_turns=60 \
--resume \
--ephemeral_system_prompt="You have access to a variety of tools to help you solve scientific, math, and technology problems presented to you. You can use them in sequence and build off of the results of prior tools you've used for furthering results. Always use the terminal or search tool if it can provide additional context, verify formulas, double check concepts and recent studies and understanding, doing all calculations, etc. You should only be confident in your own reasoning, knowledge, or calculations if you've exhaustively used all tools available to you to that can help you verify or validate your work. Always pip install any packages you need to use the python scripts you want to run. If you need to use a tool that isn't available, you can use the terminal tool to install or create it in many cases as well. Do not use the terminal tool to communicate with the user, as they cannot see your commands, only your final response after completing the task. Search for at least 3 sources, but not more than 12, so you can maintain a focused context." \
2>&1 | tee "$LOG_FILE"
echo "✅ Log saved to: $LOG_FILE"
# Parked optional flag; move above "2>&1" (before the prompt line) to enable.
# --verbose \

View File

@@ -0,0 +1,28 @@
#!/bin/bash
# Batch-generation run: GLM 4.7 (thinking) on raw terminal-task prompts
# via OpenRouter, with output mirrored to a timestamped log.

# Create logs directory if it doesn't exist
mkdir -p logs

# Generate log filename with timestamp
LOG_FILE="logs/glm4.7-terminal-tasks_$(date +%Y%m%d_%H%M%S).log"
echo "📝 Logging output to: $LOG_FILE"

# Run the batch processor. Note: no comment lines may be placed between
# the backslash-continued arguments below.
python batch_runner.py \
--dataset_file="source-data/raw_tasks_prompts.jsonl" \
--batch_size=20 \
--run_name="terminal-tasks-glm4.7-thinking" \
--distribution="default" \
--model="z-ai/glm-4.7" \
--base_url="https://openrouter.ai/api/v1" \
--providers_allowed="gmicloud,siliconflow,atlas-cloud,z-ai,novita" \
--num_workers=50 \
--max_turns=60 \
--ephemeral_system_prompt="You have access to a variety of tools to help you complete coding, system administration, and general computing tasks. You can use them in sequence and build off of the results of prior tools you've used. Always use the terminal tool to execute commands, write code, install packages, and verify your work. You should test and validate everything you create. Always pip install any packages you need (use --break-system-packages if needed). If you need a tool that isn't available, you can use the terminal to install or create it. Do not use the terminal tool to communicate with the user, as they cannot see your commands, only your final response after completing the task. Use web search when you need to look up documentation, APIs, or current best practices." \
2>&1 | tee "$LOG_FILE"
echo "✅ Log saved to: $LOG_FILE"
# Parked optional flags; move above "2>&1" (before the prompt line) to enable.
# --verbose \
# --resume \

View File

@@ -220,7 +220,7 @@ def get_image_tool_definitions() -> List[Dict[str, Any]]:
"type": "function",
"function": {
"name": "image_generate",
"description": "Generate high-quality images from text prompts using FLUX Krea model with automatic 2x upscaling. Creates detailed, artistic images that are automatically enhanced for superior quality. Returns a single upscaled image URL that can be displayed using <img src=\"{URL}\"></img> tags.",
"description": "Generate high-quality images from text prompts using FLUX 2 Pro model with automatic 2x upscaling. Creates detailed, artistic images that are automatically upscaled for hi-rez results. Returns a single upscaled image URL that can be displayed using <img src=\"{URL}\"></img> tags.",
"parameters": {
"type": "object",
"properties": {
@@ -228,11 +228,11 @@ def get_image_tool_definitions() -> List[Dict[str, Any]]:
"type": "string",
"description": "The text prompt describing the desired image. Be detailed and descriptive."
},
"image_size": {
"aspect_ratio": {
"type": "string",
"enum": ["square","portrait_16_9", "landscape_16_9"],
"description": "The size/aspect ratio of the generated image (default: landscape_4_3)",
"default": "landscape_16_9"
"enum": ["landscape", "square", "portrait"],
"description": "The aspect ratio of the generated image. 'landscape' is 16:9 wide, 'portrait' is 16:9 tall, 'square' is 1:1.",
"default": "landscape"
}
},
"required": ["prompt"]
@@ -560,16 +560,13 @@ def handle_image_function_call(function_name: str, function_args: Dict[str, Any]
if not prompt:
return json.dumps({"success": False, "image": None}, ensure_ascii=False)
image_size = function_args.get("image_size", "landscape_16_9")
aspect_ratio = function_args.get("aspect_ratio", "landscape")
# Use fixed internal defaults for all other parameters (not exposed to model)
num_inference_steps = 50
guidance_scale = 4.5
num_images = 1
enable_safety_checker = True
output_format = "png"
acceleration = "none"
allow_nsfw_images = True
seed = None
# Run async function in event loop with proper handling for multiprocessing
@@ -588,14 +585,11 @@ def handle_image_function_call(function_name: str, function_args: Dict[str, Any]
# Run the coroutine in the event loop
result = loop.run_until_complete(image_generate_tool(
prompt=prompt,
image_size=image_size,
aspect_ratio=aspect_ratio,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
num_images=num_images,
enable_safety_checker=enable_safety_checker,
output_format=output_format,
acceleration=acceleration,
allow_nsfw_images=allow_nsfw_images,
seed=seed
))

View File

@@ -1,12 +0,0 @@
# Evaluation run: GLM 4.6 on the MegaScience eval split via OpenRouter.
# NOTE(review): unlike the other run scripts this has no shebang and no
# log redirection — presumably invoked directly from an interactive
# shell; confirm. Requires OPENROUTER_API_KEY in the environment.
python batch_runner.py \
--dataset_file="hermes-agent-megascience-data/hermes_agent_megascience_eval.jsonl" \
--batch_size=10 \
--run_name="megascience_eval_glm4-6-fixedterminal-2" \
--distribution="science" \
--model="z-ai/glm-4.6" \
--base_url="https://openrouter.ai/api/v1" \
--api_key="${OPENROUTER_API_KEY}" \
--num_workers=5 \
--max_turns=30 \
--verbose \
--ephemeral_system_prompt="You have access to a variety of tools to help you solve scientific, math, and technology problems presented to you. You can use them in sequence and build off of the results of prior tools you've used results. Always use a tool if it can provide additional context, verify formulas, double check concepts and recent studies and understanding, doing all calculations, etc. You should only be confident in your own reasoning, knowledge, or calculations if you've exhaustively used all tools available to you to that can help you verify or validate your work. Always pip install any packages you need to use the python scripts you want to run."

View File

@@ -1,124 +0,0 @@
#!/usr/bin/env python3
"""
Test script to see how minimax-m2.1 responds to a tool-calling request via OpenRouter.
"""
import os
import json
from pathlib import Path
from openai import OpenAI
from dotenv import load_dotenv

# Load environment variables from a .env file next to this script, if present.
env_path = Path(__file__).parent / '.env'
if env_path.exists():
    load_dotenv(dotenv_path=env_path)
    print(f"✅ Loaded .env from {env_path}")

# Get API key (required; abort early with a clear message if missing).
api_key = os.getenv("OPENROUTER_API_KEY")
if not api_key:
    print("❌ OPENROUTER_API_KEY not found in environment")
    exit(1)

# Print a redacted preview of the key for sanity-checking which key is in use.
print(f"🔑 Using API key: {api_key[:12]}...{api_key[-4:]}")

# Initialize client pointed at the OpenRouter OpenAI-compatible endpoint.
client = OpenAI(
    base_url="https://openrouter.ai/api/v1",
    api_key=api_key
)

# Define a single simple tool so the model has exactly one obvious choice.
tools = [
    {
        "type": "function",
        "function": {
            "name": "web_search",
            "description": "Search the web for information on any topic. Returns relevant results with titles and URLs.",
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "The search query to look up on the web"
                    }
                },
                "required": ["query"]
            }
        }
    }
]

# Messages: a system prompt steering toward tool use, plus a query that
# should require a web lookup (current price → model cannot know it).
messages = [
    {
        "role": "system",
        "content": "You are a helpful assistant with access to tools. Use the web_search tool when you need to find information."
    },
    {
        "role": "user",
        "content": "What is the current price of Bitcoin?"
    }
]

print("\n" + "="*60)
print("📤 SENDING REQUEST")
print("="*60)
print(f"Model: minimax/minimax-m2.1")
print(f"Messages: {len(messages)}")
print(f"Tools: {len(tools)}")
print(f"User query: {messages[-1]['content']}")

# Make the request; any failure falls through to the except block below.
try:
    # extra_body pins the OpenRouter provider so only minimax serves the call.
    response = client.chat.completions.create(
        model="minimax/minimax-m2.1",
        messages=messages,
        tools=tools,
        extra_body={
            "provider": {
                "only": ["minimax"]
            }
        },
        timeout=120.0
    )

    print("\n" + "="*60)
    print("📥 RESPONSE RECEIVED")
    print("="*60)

    # Print raw response info
    print(f"\nModel: {response.model}")
    print(f"ID: {response.id}")
    print(f"Created: {response.created}")

    if response.usage:
        print(f"\n📊 Usage:")
        print(f" Prompt tokens: {response.usage.prompt_tokens}")
        print(f" Completion tokens: {response.usage.completion_tokens}")
        print(f" Total tokens: {response.usage.total_tokens}")

    # Print the message — the interesting part is whether tool_calls is set.
    msg = response.choices[0].message
    print(f"\n🤖 Assistant Response:")
    print(f" Role: {msg.role}")
    print(f" Content: {msg.content}")
    print(f" Tool calls: {msg.tool_calls}")

    if msg.tool_calls:
        print(f"\n🔧 Tool Calls Detail:")
        for i, tc in enumerate(msg.tool_calls):
            print(f" [{i}] ID: {tc.id}")
            print(f" Function: {tc.function.name}")
            print(f" Arguments: {tc.function.arguments}")

    # Print full raw response as JSON (default=str handles datetimes etc.)
    print("\n" + "="*60)
    print("📝 RAW RESPONSE (JSON)")
    print("="*60)
    print(json.dumps(response.model_dump(), indent=2, default=str))

except Exception as e:
    print(f"\n❌ Error: {type(e).__name__}: {e}")
    import traceback
    traceback.print_exc()

View File

@@ -2,14 +2,14 @@
"""
Image Generation Tools Module
This module provides image generation tools using FAL.ai's FLUX.1 Krea model with
This module provides image generation tools using FAL.ai's FLUX 2 Pro model with
automatic upscaling via FAL.ai's Clarity Upscaler for enhanced image quality.
Available tools:
- image_generate_tool: Generate images from text prompts with automatic upscaling
Features:
- High-quality image generation using FLUX.1 Krea model
- High-quality image generation using FLUX 2 Pro model
- Automatic 2x upscaling using Clarity Upscaler for enhanced quality
- Comprehensive parameter control (size, steps, guidance, etc.)
- Proper error handling and validation with fallback to original images
@@ -38,13 +38,25 @@ from typing import Dict, Any, Optional, Union
import fal_client
# Configuration for image generation
DEFAULT_MODEL = "fal-ai/flux/krea"
DEFAULT_IMAGE_SIZE = "landscape_4_3"
DEFAULT_MODEL = "fal-ai/flux-2-pro"
DEFAULT_ASPECT_RATIO = "landscape"
DEFAULT_NUM_INFERENCE_STEPS = 50
DEFAULT_GUIDANCE_SCALE = 4.5
DEFAULT_NUM_IMAGES = 1
DEFAULT_OUTPUT_FORMAT = "png"
# Safety settings
ENABLE_SAFETY_CHECKER = False
SAFETY_TOLERANCE = "5" # Maximum tolerance (1-5, where 5 is most permissive)
# Aspect ratio mapping - simplified choices for model to select
ASPECT_RATIO_MAP = {
"landscape": "landscape_16_9",
"square": "square_hd",
"portrait": "portrait_16_9"
}
VALID_ASPECT_RATIOS = list(ASPECT_RATIO_MAP.keys())
# Configuration for automatic upscaling
UPSCALER_MODEL = "fal-ai/clarity-upscaler"
UPSCALER_FACTOR = 2
@@ -56,7 +68,7 @@ UPSCALER_RESEMBLANCE = 0.6
UPSCALER_GUIDANCE_SCALE = 4
UPSCALER_NUM_INFERENCE_STEPS = 18
# Valid parameter values for validation based on FLUX Krea documentation
# Valid parameter values for validation based on FLUX 2 Pro documentation
VALID_IMAGE_SIZES = [
"square_hd", "square", "portrait_4_3", "portrait_16_9", "landscape_4_3", "landscape_16_9"
]
@@ -133,7 +145,7 @@ def _validate_parameters(
acceleration: str = "none"
) -> Dict[str, Any]:
"""
Validate and normalize image generation parameters for FLUX Krea model.
Validate and normalize image generation parameters for FLUX 2 Pro model.
Args:
image_size: Either a preset string or custom size dict
@@ -174,7 +186,7 @@ def _validate_parameters(
raise ValueError("num_inference_steps must be an integer between 1 and 100")
validated["num_inference_steps"] = num_inference_steps
# Validate guidance_scale (FLUX Krea default is 4.5)
# Validate guidance_scale (FLUX 2 Pro default is 4.5)
if not isinstance(guidance_scale, (int, float)) or guidance_scale < 0.1 or guidance_scale > 20.0:
raise ValueError("guidance_scale must be a number between 0.1 and 20.0")
validated["guidance_scale"] = float(guidance_scale)
@@ -254,34 +266,28 @@ async def _upscale_image(image_url: str, original_prompt: str) -> Dict[str, Any]
async def image_generate_tool(
prompt: str,
image_size: Union[str, Dict[str, int]] = DEFAULT_IMAGE_SIZE,
aspect_ratio: str = DEFAULT_ASPECT_RATIO,
num_inference_steps: int = DEFAULT_NUM_INFERENCE_STEPS,
guidance_scale: float = DEFAULT_GUIDANCE_SCALE,
num_images: int = DEFAULT_NUM_IMAGES,
enable_safety_checker: bool = True,
output_format: str = DEFAULT_OUTPUT_FORMAT,
acceleration: str = "none",
allow_nsfw_images: bool = True,
seed: Optional[int] = None
) -> str:
"""
Generate images from text prompts using FAL.ai's FLUX.1 Krea model with automatic upscaling.
Generate images from text prompts using FAL.ai's FLUX 2 Pro model with automatic upscaling.
This tool uses FAL.ai's FLUX.1 Krea model for high-quality text-to-image generation
This tool uses FAL.ai's FLUX 2 Pro model for high-quality text-to-image generation
with extensive customization options. Generated images are automatically upscaled 2x
using FAL.ai's Clarity Upscaler for enhanced quality. The final upscaled images are
returned as URLs that can be displayed using <img src="{URL}"></img> tags.
Args:
prompt (str): The text prompt describing the desired image
image_size (Union[str, Dict[str, int]]): Preset size or custom {"width": int, "height": int}
num_inference_steps (int): Number of denoising steps (1-50, default: 28)
aspect_ratio (str): Image aspect ratio - "landscape", "square", or "portrait" (default: "landscape")
num_inference_steps (int): Number of denoising steps (1-50, default: 50)
guidance_scale (float): How closely to follow prompt (0.1-20.0, default: 4.5)
num_images (int): Number of images to generate (1-4, default: 1)
enable_safety_checker (bool): Enable content safety filtering (default: True)
output_format (str): Image format "jpeg" or "png" (default: "png")
acceleration (str): Generation speed "none", "regular", or "high" (default: "none")
allow_nsfw_images (bool): Allow generation of NSFW content (default: True)
seed (Optional[int]): Random seed for reproducible results (optional)
Returns:
@@ -291,17 +297,22 @@ async def image_generate_tool(
"image": str or None # URL of the upscaled image, or None if failed
}
"""
# Validate and map aspect_ratio to actual image_size
aspect_ratio_lower = aspect_ratio.lower().strip() if aspect_ratio else DEFAULT_ASPECT_RATIO
if aspect_ratio_lower not in ASPECT_RATIO_MAP:
print(f"⚠️ Invalid aspect_ratio '{aspect_ratio}', defaulting to '{DEFAULT_ASPECT_RATIO}'")
aspect_ratio_lower = DEFAULT_ASPECT_RATIO
image_size = ASPECT_RATIO_MAP[aspect_ratio_lower]
debug_call_data = {
"parameters": {
"prompt": prompt,
"aspect_ratio": aspect_ratio,
"image_size": image_size,
"num_inference_steps": num_inference_steps,
"guidance_scale": guidance_scale,
"num_images": num_images,
"enable_safety_checker": enable_safety_checker,
"output_format": output_format,
"acceleration": acceleration,
"allow_nsfw_images": allow_nsfw_images,
"seed": seed
},
"error": None,
@@ -313,7 +324,7 @@ async def image_generate_tool(
start_time = datetime.datetime.now()
try:
print(f"🎨 Generating {num_images} image(s) with FLUX Krea: {prompt[:80]}{'...' if len(prompt) > 80 else ''}")
print(f"🎨 Generating {num_images} image(s) with FLUX 2 Pro: {prompt[:80]}{'...' if len(prompt) > 80 else ''}")
# Validate prompt
if not prompt or not isinstance(prompt, str) or len(prompt.strip()) == 0:
@@ -323,22 +334,21 @@ async def image_generate_tool(
if not os.getenv("FAL_KEY"):
raise ValueError("FAL_KEY environment variable not set")
# Validate parameters
# Validate other parameters
validated_params = _validate_parameters(
image_size, num_inference_steps, guidance_scale, num_images, output_format, acceleration
image_size, num_inference_steps, guidance_scale, num_images, output_format, "none"
)
# Prepare arguments for FAL.ai FLUX Krea API
# Prepare arguments for FAL.ai FLUX 2 Pro API
arguments = {
"prompt": prompt.strip(),
"image_size": validated_params["image_size"],
"num_inference_steps": validated_params["num_inference_steps"],
"guidance_scale": validated_params["guidance_scale"],
"num_images": validated_params["num_images"],
"enable_safety_checker": enable_safety_checker,
"output_format": validated_params["output_format"],
"acceleration": validated_params["acceleration"],
"allow_nsfw_images": allow_nsfw_images,
"enable_safety_checker": ENABLE_SAFETY_CHECKER,
"safety_tolerance": SAFETY_TOLERANCE,
"sync_mode": True # Use sync mode for immediate results
}
@@ -346,12 +356,11 @@ async def image_generate_tool(
if seed is not None and isinstance(seed, int):
arguments["seed"] = seed
print(f"🚀 Submitting generation request to FAL.ai FLUX Krea...")
print(f"🚀 Submitting generation request to FAL.ai FLUX 2 Pro...")
print(f" Model: {DEFAULT_MODEL}")
print(f" Size: {validated_params['image_size']}")
print(f" Aspect Ratio: {aspect_ratio_lower}{image_size}")
print(f" Steps: {validated_params['num_inference_steps']}")
print(f" Guidance: {validated_params['guidance_scale']}")
print(f" Acceleration: {validated_params['acceleration']}")
# Submit request to FAL.ai
handler = await fal_client.submit_async(
@@ -492,7 +501,7 @@ if __name__ == "__main__":
"""
Simple test/demo when run directly
"""
print("🎨 Image Generation Tools Module - FLUX.1 Krea + Auto Upscaling")
print("🎨 Image Generation Tools Module - FLUX 2 Pro + Auto Upscaling")
print("=" * 60)
# Check if API key is available

View File

@@ -131,35 +131,52 @@ def _validate_image_url(url: str) -> bool:
return True # Allow all HTTP/HTTPS URLs for flexibility
async def _download_image(image_url: str, destination: Path, max_retries: int = 3) -> Path:
    """
    Download an image from a URL to a local destination (async) with retry logic.

    Failed attempts are retried with exponential backoff (2s, 4s, 8s between
    attempts). On total failure the last error is raised rather than
    returning a destination path that was never written.

    Args:
        image_url (str): The URL of the image to download
        destination (Path): The path where the image should be saved
        max_retries (int): Maximum number of retry attempts (default: 3)

    Returns:
        Path: The path to the downloaded image

    Raises:
        Exception: If download fails after all retries
    """
    import asyncio

    # Create parent directories if they don't exist
    destination.parent.mkdir(parents=True, exist_ok=True)

    last_error = None
    for attempt in range(max_retries):
        try:
            # Download the image with appropriate headers using async httpx
            async with httpx.AsyncClient(timeout=30.0) as client:
                response = await client.get(
                    image_url,
                    headers={"User-Agent": "hermes-agent-vision/1.0"},
                )
                response.raise_for_status()
                # Save the image content
                destination.write_bytes(response.content)
                return destination
        except Exception as e:
            last_error = e
            if attempt < max_retries - 1:
                wait_time = 2 ** (attempt + 1)  # 2s, 4s, 8s
                print(f"⚠️ Image download failed (attempt {attempt + 1}/{max_retries}): {str(e)[:50]}")
                print(f"   Retrying in {wait_time}s...")
                await asyncio.sleep(wait_time)
            else:
                # Fix: previously `return destination` here made the raise
                # below unreachable and handed callers a path to a file
                # that was never written. Surface the failure instead.
                print(f"❌ Image download failed after {max_retries} attempts: {str(e)[:100]}")

    raise last_error
def _determine_mime_type(image_path: Path) -> str: