Compare commits


1 Commit

Author SHA1 Message Date
7d628ea087 feat: add style-lock shared infrastructure (#642)
Some checks failed
Docker Build and Publish / build-and-push (pull_request) Has been skipped
Contributor Attribution Check / check-attribution (pull_request) Failing after 24s
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Successful in 37s
Tests / e2e (pull_request) Successful in 3m8s
Tests / test (pull_request) Failing after 35m31s
Shared style-lock module for consistent visual style across video generation.
Extracts style embeddings from reference images and provides conditioning
for IP-Adapter, ControlNet, and style token injection.

Used by Video Forge (#52) and LPM 1.0 (#641):
- Color palette extraction (k-means, temperature, saturation)
- Lighting profile (histogram, contrast, directional analysis)
- Texture features (Gabor filter bank)
- CLIP embedding (graceful degradation when unavailable)
- Conditioning output for SDXL/FLUX/ComfyUI backends
- Multi-reference merging for identity photo sets
- Comprehensive test suite (14 tests)

Branch: fix/issue-642-1
2026-04-14 21:29:47 -04:00
12 changed files with 1395 additions and 293 deletions

View File

@@ -0,0 +1,120 @@
---
name: style-lock
description: "Shared style-lock infrastructure for consistent visual style across multi-frame video generation. Extracts style embeddings from a reference image and injects as conditioning (IP-Adapter, ControlNet, style tokens) into all subsequent generations. Used by Video Forge (playground #52) and LPM 1.0 (#641). Use when generating multiple images/clips that need visual coherence — consistent color palette, brush strokes, lighting, and aesthetic across scenes or frames."
---
# Style Lock — Consistent Visual Style Across Video Generation
## Overview
When generating multiple images/clips that compose a video, each generation is independent. Without conditioning, visual style drifts wildly between frames/scenes. Style Lock solves this by extracting a style embedding from a reference image and injecting it as conditioning into all subsequent generations.
## Quick Start
```python
from scripts.style_lock import StyleLock
# Initialize with a reference image
lock = StyleLock("reference.png")
# Get conditioning for Stable Diffusion XL
conditioning = lock.get_conditioning(
backend="sdxl", # "sdxl", "flux", "comfyui"
method="ip_adapter", # "ip_adapter", "controlnet", "style_tokens", "hybrid"
strength=0.75 # Style adherence (0.0-1.0)
)
# Use in generation pipeline
result = generate(prompt="a sunset over mountains", **conditioning)
```
## Architecture
```
Reference Image
┌─────────────────┐
│ Style Extractor │──→ CLIP embedding
│ │──→ Color palette (dominant colors)
│ │──→ Texture features (Gabor filters)
│ │──→ Lighting analysis (histogram)
└────────┬────────┘
┌─────────────────┐
│ Conditioning │──→ IP-Adapter (reference image injection)
│ Router │──→ ControlNet (structural conditioning)
│ │──→ Style tokens (text conditioning)
│ │──→ Color palette constraint
└────────┬────────┘
Generation Pipeline
```
## Methods
| Method | Best For | Requires | Quality |
|--------|----------|----------|---------|
| `ip_adapter` | Reference-guided style transfer | SD XL, IP-Adapter model | ★★★★★ |
| `controlnet` | Structural + style conditioning | ControlNet models | ★★★★ |
| `style_tokens` | Text-prompt-based style | Any model | ★★★ |
| `hybrid` | Maximum consistency | All of the above | ★★★★★ |
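The fallback order implied by the table can be sketched as a small helper. `choose_method` is illustrative only and not part of the shipped module:
```python
def choose_method(has_ip_adapter: bool, has_controlnet: bool) -> str:
    """Pick the strongest conditioning method the environment supports."""
    if has_ip_adapter and has_controlnet:
        return "hybrid"          # style + structure, maximum consistency
    if has_ip_adapter:
        return "ip_adapter"      # reference-guided style transfer
    if has_controlnet:
        return "controlnet"      # structural conditioning only
    return "style_tokens"        # text-only fallback, works with any model
```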
## Cross-Project Integration
### Video Forge (Playground #52)
- Extract style from seed image or first scene
- Apply across all scenes in music video generation
- Scene-to-scene temporal coherence via shared style embedding
### LPM 1.0 (Issue #641)
- Extract style from 8 identity reference photos
- Frame-to-frame consistency in real-time video generation
- Style tokens for HeartMuLa audio style consistency
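The multi-reference merge used for identity photo sets comes down to feature averaging. A minimal sketch of the "average" strategy, assuming style features are plain numpy vectors (the module wraps this in `MultiReferenceStyleLock`):
```python
import numpy as np

def merge_style_features(feature_vectors: list) -> np.ndarray:
    """'average' merge strategy: mean of per-reference feature vectors,
    re-normalized so downstream cosine similarity stays well-scaled."""
    stacked = np.stack([np.asarray(v, dtype=np.float64) for v in feature_vectors])
    merged = stacked.mean(axis=0)
    norm = np.linalg.norm(merged)
    return merged / norm if norm > 0 else merged
```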
## Configuration
```yaml
style_lock:
reference_image: "path/to/reference.png"
backend: "sdxl"
method: "hybrid"
strength: 0.75
color_palette:
enabled: true
num_colors: 5
tolerance: 0.15
lighting:
enabled: true
match_histogram: true
texture:
enabled: true
gabor_orientations: 8
gabor_frequencies: [0.1, 0.2, 0.3, 0.4]
```
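Glue code mapping this config section onto the `StyleLock` constructor might look like the following. `style_lock_kwargs` is a hypothetical helper; key names follow the YAML above, and `method`/`strength`/`backend` are passed later to `get_conditioning` rather than the constructor:
```python
def style_lock_kwargs(config: dict) -> dict:
    """Translate a parsed style_lock config section into StyleLock kwargs."""
    section = config["style_lock"]
    kwargs = {"reference_image": section["reference_image"]}
    palette = section.get("color_palette", {})
    if palette.get("enabled"):
        kwargs["num_colors"] = palette.get("num_colors", 5)
    texture = section.get("texture", {})
    if texture.get("enabled"):
        kwargs["texture_orientations"] = texture.get("gabor_orientations", 8)
        kwargs["texture_frequencies"] = texture.get("gabor_frequencies")
    return kwargs
```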
## Reference Documents
- `references/ip-adapter-setup.md` — IP-Adapter installation and model requirements
- `references/controlnet-conditioning.md` — ControlNet configuration for style
- `references/color-palette-extraction.md` — Color palette analysis and matching
- `references/texture-analysis.md` — Gabor filter texture feature extraction
## Dependencies
```
torch>=2.0
Pillow>=10.0
numpy>=1.24
opencv-python>=4.8
scikit-image>=0.22
```
Optional (for specific backends):
```
diffusers>=0.25 # SD XL / IP-Adapter
transformers>=4.36 # CLIP embeddings
safetensors>=0.4 # Model loading
```

View File

@@ -0,0 +1,106 @@
# Color Palette Extraction for Style Lock
## Overview
Color palette is the most immediately visible aspect of visual style. Two scenes
with matching palettes feel related even if their content differs entirely. The
StyleLock class extracts a dominant palette from the reference image and provides
it as conditioning.
## Extraction Method
1. Downsample reference to 150x150 for speed
2. Convert to RGB pixel array
3. K-means clustering (k=5 by default) to find dominant colors
4. Sort by frequency (most dominant first)
5. Derive metadata: temperature, saturation, brightness
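Steps 3-4 can be sketched with a dependency-free k-means over an (N, 3) pixel array. This mirrors the module's internal implementation but is simplified for illustration:
```python
import numpy as np

def dominant_colors(pixels, k=5, iters=20, seed=0):
    """K-means over an (N, 3) float RGB array; returns (centroids, weights)
    sorted most-dominant first."""
    pixels = np.asarray(pixels, dtype=np.float64)
    rng = np.random.default_rng(seed)
    centroids = pixels[rng.choice(len(pixels), size=k, replace=False)]
    for _ in range(iters):
        # Assign each pixel to its nearest centroid
        dists = np.linalg.norm(pixels[:, None] - centroids[None, :], axis=2)
        labels = np.argmin(dists, axis=1)
        # Recompute centroids, keeping old ones for empty clusters
        centroids = np.array([
            pixels[labels == i].mean(axis=0) if np.any(labels == i) else centroids[i]
            for i in range(k)
        ])
    weights = np.bincount(labels, minlength=k) / len(pixels)
    order = np.argsort(-weights)
    return centroids[order], weights[order]
```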
## Color Temperature
Derived from average R vs B channel:
| Temperature | Condition | Visual Feel |
|-------------|-----------|-------------|
| Warm | avg_R > avg_B + 20 | Golden, orange, amber, cozy |
| Cool | avg_B > avg_R + 20 | Blue, teal, steel, clinical |
| Neutral | Neither | Balanced, natural |
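As code, the rule above (channel means in 0-255 units) is simply:
```python
def classify_temperature(avg_r: float, avg_b: float) -> str:
    """Classify color temperature from mean R and B channel values (0-255)."""
    if avg_r > avg_b + 20:
        return "warm"
    if avg_b > avg_r + 20:
        return "cool"
    return "neutral"
```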
## Usage in Conditioning
### Text Prompt Injection
Convert palette to style descriptors:
```python
def palette_to_prompt(palette: ColorPalette) -> str:
parts = [f"{palette.temperature} color palette"]
if palette.saturation_mean > 0.6:
parts.append("vibrant saturated colors")
elif palette.saturation_mean < 0.3:
parts.append("muted desaturated tones")
return ", ".join(parts)
```
### Color Grading (Post-Processing)
Match output colors to reference palette:
```python
import numpy as np
def color_grade_to_palette(image: np.ndarray, palette: ColorPalette,
                           strength: float = 0.5) -> np.ndarray:
    """Shift each pixel toward its nearest reference palette color."""
    flat = image.astype(np.float32).reshape(-1, 3)
    target_colors = np.array(palette.colors, dtype=np.float32)
    # Vectorized nearest-color lookup: (num_pixels, num_colors) distances
    dists = np.linalg.norm(flat[:, None, :] - target_colors[None, :, :], axis=2)
    nearest = target_colors[np.argmin(dists, axis=1)]
    blended = flat * (1 - strength) + nearest * strength
    return np.clip(blended.reshape(image.shape), 0, 255).astype(np.uint8)
```
### ComfyUI Color Palette Node
For ComfyUI integration, expose palette as a conditioning node:
```python
import json

class StyleLockColorPalette:
@classmethod
def INPUT_TYPES(cls):
return {"required": {
"palette_json": ("STRING", {"multiline": True}),
"strength": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0}),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "apply_palette"
CATEGORY = "style-lock"
def apply_palette(self, palette_json, strength):
palette = json.loads(palette_json)
# Convert to CLIP text conditioning
# ...
```
## Palette Matching Between Frames
To detect style drift, compare palettes across frames:
```python
def palette_distance(p1: ColorPalette, p2: ColorPalette) -> float:
    """Approximate palette distance via weighted pairwise color cost.

    A true Earth Mover's Distance requires an optimal-transport solve;
    this cheaper proxy is sufficient for drift detection.
    """
    from scipy.spatial.distance import cdist
    cost = cdist(np.array(p1.colors), np.array(p2.colors), metric="euclidean")
    return float(np.sum(cost * np.outer(p1.weights, p2.weights)))
```
Threshold: distance > 50 indicates significant style drift.

View File

@@ -0,0 +1,102 @@
# ControlNet Conditioning for Style Lock
## Overview
ControlNet provides structural conditioning — edges, depth, pose, segmentation —
that complements IP-Adapter's style-only conditioning. Used together in `hybrid`
mode for maximum consistency.
## Supported Control Types
| Type | Preprocessor | Best For |
|------|-------------|----------|
| Canny Edge | `cv2.Canny` | Clean line art, geometric scenes |
| Depth | MiDaS / DPT | Spatial consistency, 3D scenes |
| Lineart | Anime/Realistic lineart | Anime, illustration |
| Soft Edge | HED | Organic shapes, portraits |
| Segmentation | SAM / OneFormer | Scene layout consistency |
## Style Lock Approach
For style consistency (not structural control), use ControlNet **softly**:
1. Extract edges from the reference image (Canny, threshold 50-150)
2. Use as ControlNet input with low conditioning scale (0.3-0.5)
3. This preserves the compositional structure while allowing style variation
```python
import cv2
import numpy as np
from PIL import Image
def extract_control_image(ref_path: str, method: str = "canny") -> np.ndarray:
img = np.array(Image.open(ref_path).convert("RGB"))
gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
if method == "canny":
edges = cv2.Canny(gray, 50, 150)
return cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
elif method == "soft_edge":
# HED-like approximation using Gaussian blur + Canny
blurred = cv2.GaussianBlur(gray, (5, 5), 1.5)
edges = cv2.Canny(blurred, 30, 100)
return cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
else:
raise ValueError(f"Unknown method: {method}")
```
## Diffusers Integration
```python
import torch
from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel
controlnet = ControlNetModel.from_pretrained(
"diffusers/controlnet-canny-sdxl-1.0",
torch_dtype=torch.float16,
)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
controlnet=controlnet,
torch_dtype=torch.float16,
)
result = pipe(
prompt="a sunset over mountains",
image=control_image,
controlnet_conditioning_scale=0.5,
).images[0]
```
## Hybrid Mode (IP-Adapter + ControlNet)
For maximum consistency, combine both:
```python
# IP-Adapter for style
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models",
weight_name="ip-adapter_sdxl.bin")
# ControlNet for structure
# Already loaded in pipeline construction
result = pipe(
prompt="a sunset over mountains",
ip_adapter_image=reference_image, # Style
image=control_image, # Structure
ip_adapter_scale=0.75,
controlnet_conditioning_scale=0.4,
).images[0]
```
## Style Lock Output
When using `method="controlnet"` or `method="hybrid"`, the StyleLock class
preprocesses the reference image through Canny edge detection and provides
it as `controlnet_image` in the ConditioningOutput.
Adjust `controlnet_conditioning_scale` via the strength parameter:
```
effective_scale = controlnet_conditioning_scale * style_lock_strength
```

View File

@@ -0,0 +1,79 @@
# IP-Adapter Setup for Style Lock
## Overview
IP-Adapter (Image Prompt Adapter) enables conditioning Stable Diffusion generation
on a reference image without fine-tuning. It extracts a CLIP image embedding and
injects it alongside text prompts via decoupled cross-attention.
## Installation
```bash
pip install "diffusers>=0.25" "transformers>=4.36" accelerate safetensors
```
## Models
| Model | Base | Use Case |
|-------|------|----------|
| `ip-adapter_sd15` | SD 1.5 | Fast, lower quality |
| `ip-adapter_sd15_plus` | SD 1.5 | Better style fidelity |
| `ip-adapter_sdxl_vit-h` | SDXL | High quality, recommended |
| `ip-adapter_flux` | FLUX | Best quality, highest VRAM |
Download from: `h94/IP-Adapter` on HuggingFace.
## Usage with Diffusers
```python
import torch
from diffusers import StableDiffusionXLPipeline
from PIL import Image
pipe = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16,
)
pipe.load_ip_adapter(
"h94/IP-Adapter",
subfolder="sdxl_models",
weight_name="ip-adapter_sdxl.bin",
)
reference = Image.open("reference.png")
pipe.set_ip_adapter_scale(0.75) # Style adherence strength
result = pipe(
prompt="a sunset over mountains",
ip_adapter_image=reference,
num_inference_steps=30,
).images[0]
```
## Style Lock Integration
The StyleLock class provides `ip_adapter_image` and `ip_adapter_scale` in the
conditioning output. Pass these directly to the pipeline:
```python
lock = StyleLock("reference.png")
cond = lock.get_conditioning(backend="sdxl", method="ip_adapter")
kwargs = cond.to_api_kwargs()
result = pipe(prompt="a sunset over mountains", **kwargs).images[0]
```
## Tuning Guide
| Scale | Effect | Use When |
|-------|--------|----------|
| 0.3-0.5 | Loose style influence | Want style hints, not exact match |
| 0.5-0.7 | Balanced | General video generation |
| 0.7-0.9 | Strong adherence | Strict style consistency needed |
| 0.9-1.0 | Near-exact copy | Reference IS the target style |
## Tips
- First frame is the best reference — it has the exact lighting/mood you want
- Use `ip_adapter_scale` 0.7+ for scene-to-scene consistency
- Combine with ControlNet for both style AND structure
- For LPM 1.0 frame-to-frame: use scale 0.85+, extract from best identity photo

View File

@@ -0,0 +1,105 @@
# Texture Analysis for Style Lock
## Overview
Texture captures the "feel" of visual surfaces — smooth, rough, grainy, painterly,
digital. Gabor filters extract frequency-orientation features that describe texture
in a way similar to human visual cortex.
## Gabor Filter Bank
A Gabor filter is a sinusoidal wave modulated by a Gaussian envelope:
```
g(x, y) = exp(-(x'² + γ²y'²) / (2σ²)) * exp(2πi * x' * f)
```
Where `x' = x cosθ + y sinθ` and `y' = -x sinθ + y cosθ` are the coordinates rotated by the filter orientation, and:
- `f` = spatial frequency (controls detail scale)
- `θ` = orientation (controls direction)
- `σ` = Gaussian standard deviation (controls bandwidth)
- `γ` = aspect ratio (usually 0.5)
## Default Configuration
```python
orientations = 8 # 0°, 22.5°, 45°, 67.5°, 90°, 112.5°, 135°, 157.5°
frequencies = [0.1, 0.2, 0.3, 0.4] # Low to high spatial frequency
```
Total: 32 filter responses per image.
## Feature Extraction
```python
from skimage.filters import gabor
import numpy as np
def extract_texture_features(gray_image, orientations=8,
                             frequencies=(0.1, 0.2, 0.3, 0.4)):
theta_values = np.linspace(0, np.pi, orientations, endpoint=False)
features = []
for freq in frequencies:
for theta in theta_values:
magnitude, _ = gabor(gray_image, frequency=freq, theta=theta)
features.append(magnitude.mean())
return np.array(features)
```
## Derived Metrics
| Metric | Formula | Meaning |
|--------|---------|---------|
| Energy | `sqrt(mean(features²))` | Overall texture strength |
| Homogeneity | `1 / (1 + std(features))` | Texture uniformity |
| Dominant orientation | `argmax(features per θ)` | Primary texture direction |
| Dominant frequency | `argmax(features per f)` | Texture coarseness |
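A sketch of computing these metrics from a response vector ordered frequency-major (as produced by `extract_texture_features` above); `texture_metrics` is illustrative, not part of the module:
```python
import numpy as np

def texture_metrics(features, orientations, frequencies):
    """Derived metrics from the table above, for a frequency-major
    (len(frequencies) * orientations,) Gabor response vector."""
    features = np.asarray(features, dtype=np.float64)
    grid = features.reshape(len(frequencies), orientations)
    return {
        "energy": float(np.sqrt(np.mean(features ** 2))),
        "homogeneity": float(1.0 / (1.0 + np.std(features))),
        # Sum responses across the other axis before taking argmax
        "dominant_orientation": int(np.argmax(grid.sum(axis=0))),
        "dominant_frequency": frequencies[int(np.argmax(grid.sum(axis=1)))],
    }
```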
## Style Matching
Compare textures between reference and generated frames:
```python
def texture_similarity(ref_features, gen_features):
"""Pearson correlation between feature vectors."""
return np.corrcoef(ref_features, gen_features)[0, 1]
# Interpretation:
# > 0.9 — Excellent match (same texture)
# 0.7-0.9 — Good match (similar feel)
# < 0.5 — Poor match (different texture, style drift)
```
## Practical Application
### Painterly vs Photographic
| Style | Energy | Homogeneity | Dominant Frequency |
|-------|--------|-------------|-------------------|
| Oil painting | High (>0.6) | Low (<0.5) | Low (0.1-0.2) |
| Watercolor | Medium | Medium | Medium (0.2-0.3) |
| Photography | Low-Medium | High (>0.7) | Variable |
| Digital art | Variable | High (>0.8) | High (0.3-0.4) |
| Sketch | Medium | Low | High (0.3-0.4) |
Use these profiles to adjust generation parameters:
```python
def texture_to_guidance(texture: TextureFeatures) -> dict:
if texture.energy > 0.6 and texture.homogeneity < 0.5:
return {"prompt_suffix": "painterly brushstrokes, impasto texture",
"cfg_scale_boost": 0.5}
elif texture.homogeneity > 0.8:
return {"prompt_suffix": "smooth clean rendering, digital art",
"cfg_scale_boost": -0.5}
return {}
```
## Limitations
- Gabor filters are rotation-sensitive; 8 orientations covers 180° at 22.5° intervals
- Low-frequency textures (gradients, lighting) may not be well captured
- Texture alone doesn't capture color — always combine with palette extraction
- Computationally cheap but requires scikit-image (optional dependency)

View File

@@ -0,0 +1 @@
"""Style Lock — Shared infrastructure for consistent visual style."""

View File

@@ -0,0 +1,631 @@
"""
Style Lock — Consistent Visual Style Across Video Generation
Extracts style embeddings from a reference image and provides conditioning
for Stable Diffusion, IP-Adapter, ControlNet, and style token injection.
Used by:
- Video Forge (playground #52) — scene-to-scene style consistency
- LPM 1.0 (issue #641) — frame-to-frame temporal coherence
Usage:
from style_lock import StyleLock
lock = StyleLock("reference.png")
conditioning = lock.get_conditioning(backend="sdxl", method="hybrid")
"""
from __future__ import annotations
import json
import logging
from dataclasses import dataclass, field, asdict
from pathlib import Path
from typing import Optional
import numpy as np
from PIL import Image
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Data structures
# ---------------------------------------------------------------------------
@dataclass
class ColorPalette:
"""Dominant color palette extracted from reference image."""
colors: list[tuple[int, int, int]] # RGB tuples
weights: list[float] # Proportion of each color
temperature: str # "warm", "cool", "neutral"
saturation_mean: float
brightness_mean: float
@dataclass
class LightingProfile:
"""Lighting characteristics of reference image."""
histogram: np.ndarray # Grayscale histogram (256 bins)
mean_brightness: float
contrast: float # Std dev of brightness
dynamic_range: tuple[float, float] # (min, max) normalized
direction_hint: str # "even", "top", "left", "right", "bottom"
@dataclass
class TextureFeatures:
"""Texture feature vector from Gabor filter responses."""
features: np.ndarray # Shape: (orientations * frequencies,)
orientations: int
frequencies: list[float]
energy: float # Total texture energy
homogeneity: float # Texture uniformity
@dataclass
class StyleEmbedding:
"""Complete style embedding extracted from a reference image."""
clip_embedding: Optional[np.ndarray] = None # CLIP visual features
color_palette: Optional[ColorPalette] = None
lighting: Optional[LightingProfile] = None
texture: Optional[TextureFeatures] = None
source_path: str = ""
def to_dict(self) -> dict:
"""Serialize to JSON-safe dict (excludes numpy arrays)."""
return {
"source_path": self.source_path,
"color_palette": asdict(self.color_palette) if self.color_palette else None,
"lighting": {
"mean_brightness": self.lighting.mean_brightness,
"contrast": self.lighting.contrast,
"dynamic_range": list(self.lighting.dynamic_range),
"direction_hint": self.lighting.direction_hint,
} if self.lighting else None,
"texture": {
"orientations": self.texture.orientations,
"frequencies": self.texture.frequencies,
"energy": self.texture.energy,
"homogeneity": self.texture.homogeneity,
} if self.texture else None,
"has_clip_embedding": self.clip_embedding is not None,
}
@dataclass
class ConditioningOutput:
"""Conditioning parameters for a generation backend."""
method: str # ip_adapter, controlnet, style_tokens, hybrid
backend: str # sdxl, flux, comfyui
strength: float # Overall style adherence (0-1)
ip_adapter_image: Optional[str] = None # Path to reference image
ip_adapter_scale: float = 0.75
controlnet_image: Optional[np.ndarray] = None # Preprocessed control image
controlnet_conditioning_scale: float = 0.5
style_prompt: str = "" # Text conditioning from style tokens
negative_prompt: str = "" # Anti-style negatives
color_palette_guidance: Optional[dict] = None
def to_api_kwargs(self) -> dict:
"""Convert to kwargs suitable for diffusers pipelines."""
kwargs = {}
if self.method in ("ip_adapter", "hybrid") and self.ip_adapter_image:
kwargs["ip_adapter_image"] = self.ip_adapter_image
kwargs["ip_adapter_scale"] = self.ip_adapter_scale * self.strength
if self.method in ("controlnet", "hybrid") and self.controlnet_image is not None:
kwargs["image"] = self.controlnet_image
kwargs["controlnet_conditioning_scale"] = (
self.controlnet_conditioning_scale * self.strength
)
if self.style_prompt:
kwargs["prompt_suffix"] = self.style_prompt
if self.negative_prompt:
kwargs["negative_prompt_suffix"] = self.negative_prompt
return kwargs
# ---------------------------------------------------------------------------
# Extractors
# ---------------------------------------------------------------------------
class ColorExtractor:
"""Extract dominant color palette using k-means clustering."""
def __init__(self, num_colors: int = 5):
self.num_colors = num_colors
def extract(self, image: Image.Image) -> ColorPalette:
img = image.resize((150, 150)).convert("RGB")
pixels = np.array(img).reshape(-1, 3).astype(np.float32)
# Simple k-means (no sklearn dependency)
colors, weights = self._kmeans(pixels, self.num_colors)
# Analyze color temperature
avg_r, avg_g, avg_b = colors[:, 0].mean(), colors[:, 1].mean(), colors[:, 2].mean()
if avg_r > avg_b + 20:
temperature = "warm"
elif avg_b > avg_r + 20:
temperature = "cool"
else:
temperature = "neutral"
# Saturation and brightness
hsv = self._rgb_to_hsv(colors)
saturation_mean = float(hsv[:, 1].mean())
brightness_mean = float(hsv[:, 2].mean())
return ColorPalette(
colors=[tuple(int(c) for c in color) for color in colors],
weights=[float(w) for w in weights],
temperature=temperature,
saturation_mean=saturation_mean,
brightness_mean=brightness_mean,
)
def _kmeans(self, pixels: np.ndarray, k: int, max_iter: int = 20):
indices = np.random.choice(len(pixels), k, replace=False)
centroids = pixels[indices].copy()
for _ in range(max_iter):
dists = np.linalg.norm(pixels[:, None] - centroids[None, :], axis=2)
labels = np.argmin(dists, axis=1)
new_centroids = np.array([
pixels[labels == i].mean(axis=0) if np.any(labels == i) else centroids[i]
for i in range(k)
])
if np.allclose(centroids, new_centroids, atol=1.0):
break
centroids = new_centroids
counts = np.bincount(labels, minlength=k)
weights = counts / counts.sum()
order = np.argsort(-weights)
return centroids[order], weights[order]
@staticmethod
def _rgb_to_hsv(colors: np.ndarray) -> np.ndarray:
"""Convert RGB (0-255) to HSV (H: 0-360, S: 0-1, V: 0-1)."""
rgb = colors / 255.0
hsv = np.zeros_like(rgb)
for i, pixel in enumerate(rgb):
r, g, b = pixel
cmax, cmin = max(r, g, b), min(r, g, b)
delta = cmax - cmin
if delta == 0:
h = 0
elif cmax == r:
h = 60 * (((g - b) / delta) % 6)
elif cmax == g:
h = 60 * (((b - r) / delta) + 2)
else:
h = 60 * (((r - g) / delta) + 4)
s = 0 if cmax == 0 else delta / cmax
v = cmax
hsv[i] = [h, s, v]
return hsv
class LightingExtractor:
"""Analyze lighting characteristics from grayscale histogram."""
def extract(self, image: Image.Image) -> LightingProfile:
gray = np.array(image.convert("L"))
hist, _ = np.histogram(gray, bins=256, range=(0, 256))
hist_norm = hist.astype(np.float32) / hist.sum()
mean_brightness = float(gray.mean() / 255.0)
contrast = float(gray.std() / 255.0)
nonzero = np.where(hist_norm > 0.001)[0]
dynamic_range = (
float(nonzero[0] / 255.0) if len(nonzero) > 0 else 0.0,
float(nonzero[-1] / 255.0) if len(nonzero) > 0 else 1.0,
)
# Rough directional lighting estimate from quadrant brightness
h, w = gray.shape
quadrants = {
"top": gray[:h // 2, :].mean(),
"bottom": gray[h // 2:, :].mean(),
"left": gray[:, :w // 2].mean(),
"right": gray[:, w // 2:].mean(),
}
brightest = max(quadrants, key=quadrants.get)
delta = quadrants[brightest] - min(quadrants.values())
direction = brightest if delta > 15 else "even"
return LightingProfile(
histogram=hist_norm,
mean_brightness=mean_brightness,
contrast=contrast,
dynamic_range=dynamic_range,
direction_hint=direction,
)
class TextureExtractor:
"""Extract texture features using Gabor filter bank."""
def __init__(
self,
orientations: int = 8,
frequencies: Optional[list[float]] = None,
):
self.orientations = orientations
self.frequencies = frequencies or [0.1, 0.2, 0.3, 0.4]
def extract(self, image: Image.Image) -> TextureFeatures:
try:
from skimage.filters import gabor
from skimage.color import rgb2gray
from skimage.transform import resize
except ImportError:
logger.warning("scikit-image not available, returning empty texture features")
return TextureFeatures(
features=np.zeros(len(self.frequencies) * self.orientations),
orientations=self.orientations,
frequencies=self.frequencies,
energy=0.0,
homogeneity=0.0,
)
gray = rgb2gray(np.array(image))
gray = resize(gray, (256, 256), anti_aliasing=True)
features = []
theta_values = np.linspace(0, np.pi, self.orientations, endpoint=False)
for freq in self.frequencies:
for theta in theta_values:
magnitude, _ = gabor(gray, frequency=freq, theta=theta)
features.append(float(magnitude.mean()))
features_arr = np.array(features)
energy = float(np.sqrt(np.mean(features_arr ** 2)))
homogeneity = float(1.0 / (1.0 + np.std(features_arr)))
return TextureFeatures(
features=features_arr,
orientations=self.orientations,
frequencies=self.frequencies,
energy=energy,
homogeneity=homogeneity,
)
class CLIPEmbeddingExtractor:
"""Extract CLIP visual embedding for IP-Adapter conditioning."""
def __init__(self, model_name: str = "openai/clip-vit-large-patch14"):
self.model_name = model_name
self._model = None
self._processor = None
self._load_attempted = False
def _load(self):
if self._model is not None:
return
if self._load_attempted:
return
self._load_attempted = True
try:
from transformers import CLIPModel, CLIPProcessor
import torch
# Only load if model is already cached locally — no network
import os
cache_dir = os.path.expanduser(
f"~/.cache/huggingface/hub/models--{self.model_name.replace('/', '--')}"
)
if not os.path.isdir(cache_dir):
logger.info("CLIP model not cached locally, embedding disabled")
self._model = None
return
self._model = CLIPModel.from_pretrained(self.model_name, local_files_only=True)
self._processor = CLIPProcessor.from_pretrained(self.model_name, local_files_only=True)
self._model.eval()
logger.info(f"Loaded cached CLIP model: {self.model_name}")
except ImportError:
logger.warning("transformers not available, CLIP embedding disabled")
except Exception as e:
logger.warning(f"CLIP model load failed ({e}), embedding disabled")
self._model = None
def extract(self, image: Image.Image) -> Optional[np.ndarray]:
self._load()
if self._model is None:
return None
import torch
inputs = self._processor(images=image, return_tensors="pt")
with torch.no_grad():
features = self._model.get_image_features(**inputs)
return features.squeeze().numpy()
# ---------------------------------------------------------------------------
# Style Lock — main class
# ---------------------------------------------------------------------------
class StyleLock:
"""
Style Lock: extract and apply consistent visual style across generations.
Args:
reference_image: Path to reference image or PIL Image.
num_colors: Number of dominant colors to extract (default 5).
texture_orientations: Gabor filter orientations (default 8).
texture_frequencies: Gabor filter frequencies (default [0.1, 0.2, 0.3, 0.4]).
clip_model: CLIP model name for embedding extraction.
"""
def __init__(
self,
reference_image: str | Image.Image,
num_colors: int = 5,
texture_orientations: int = 8,
texture_frequencies: Optional[list[float]] = None,
clip_model: str = "openai/clip-vit-large-patch14",
):
if isinstance(reference_image, str):
self._ref_path = reference_image
self._ref_image = Image.open(reference_image).convert("RGB")
else:
self._ref_path = ""
self._ref_image = reference_image
self._color_ext = ColorExtractor(num_colors=num_colors)
self._lighting_ext = LightingExtractor()
self._texture_ext = TextureExtractor(
orientations=texture_orientations,
frequencies=texture_frequencies,
)
self._clip_ext = CLIPEmbeddingExtractor(model_name=clip_model)
self._embedding: Optional[StyleEmbedding] = None
@property
def embedding(self) -> StyleEmbedding:
"""Lazy-computed full style embedding."""
if self._embedding is None:
self._embedding = self._extract_all()
return self._embedding
def _extract_all(self) -> StyleEmbedding:
logger.info("Extracting style embedding from reference image...")
return StyleEmbedding(
clip_embedding=self._clip_ext.extract(self._ref_image),
color_palette=self._color_ext.extract(self._ref_image),
lighting=self._lighting_ext.extract(self._ref_image),
texture=self._texture_ext.extract(self._ref_image),
source_path=self._ref_path,
)
def get_conditioning(
self,
backend: str = "sdxl",
method: str = "hybrid",
strength: float = 0.75,
) -> ConditioningOutput:
"""
Generate conditioning output for a generation backend.
Args:
backend: Target backend — "sdxl", "flux", "comfyui".
method: Conditioning method — "ip_adapter", "controlnet",
"style_tokens", "hybrid".
strength: Overall style adherence 0.0 (loose) to 1.0 (strict).
Returns:
ConditioningOutput with all parameters for the pipeline.
"""
emb = self.embedding
style_prompt = self._build_style_prompt(emb)
negative_prompt = self._build_negative_prompt(emb)
controlnet_img = self._build_controlnet_image(emb)
palette_guidance = None
if emb.color_palette:
palette_guidance = {
"colors": emb.color_palette.colors,
"weights": emb.color_palette.weights,
"temperature": emb.color_palette.temperature,
}
return ConditioningOutput(
method=method,
backend=backend,
strength=strength,
ip_adapter_image=self._ref_path if self._ref_path else None,
ip_adapter_scale=0.75,
controlnet_image=controlnet_img,
controlnet_conditioning_scale=0.5,
style_prompt=style_prompt,
negative_prompt=negative_prompt,
color_palette_guidance=palette_guidance,
)
def _build_style_prompt(self, emb: StyleEmbedding) -> str:
"""Generate text conditioning from extracted style features."""
parts = []
if emb.color_palette:
palette = emb.color_palette
parts.append(f"{palette.temperature} color palette")
if palette.saturation_mean > 0.6:
parts.append("vibrant saturated colors")
elif palette.saturation_mean < 0.3:
parts.append("muted desaturated tones")
if palette.brightness_mean > 0.65:
parts.append("bright luminous lighting")
elif palette.brightness_mean < 0.35:
parts.append("dark moody atmosphere")
if emb.lighting:
if emb.lighting.contrast > 0.3:
parts.append("high contrast dramatic lighting")
elif emb.lighting.contrast < 0.15:
parts.append("soft even lighting")
if emb.lighting.direction_hint != "even":
parts.append(f"light from {emb.lighting.direction_hint}")
if emb.texture:
if emb.texture.energy > 0.5:
parts.append("rich textured surface")
if emb.texture.homogeneity > 0.8:
parts.append("smooth uniform texture")
elif emb.texture.homogeneity < 0.4:
parts.append("complex varied texture")
return ", ".join(parts) if parts else "consistent visual style"
def _build_negative_prompt(self, emb: StyleEmbedding) -> str:
"""Generate anti-style negatives to prevent style drift."""
parts = ["inconsistent style", "style variation", "color mismatch"]
if emb.color_palette:
if emb.color_palette.temperature == "warm":
parts.append("cold blue tones")
elif emb.color_palette.temperature == "cool":
parts.append("warm orange tones")
if emb.lighting:
if emb.lighting.contrast > 0.3:
parts.append("flat lighting")
else:
parts.append("harsh shadows")
return ", ".join(parts)
def _build_controlnet_image(self, emb: StyleEmbedding) -> Optional[np.ndarray]:
"""Preprocess reference image for ControlNet input (edge/canny)."""
try:
import cv2
except ImportError:
return None
img = np.array(self._ref_image)
gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
edges = cv2.Canny(gray, 50, 150)
edges_rgb = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
return edges_rgb
def save_embedding(self, path: str) -> None:
"""Save extracted embedding metadata to JSON (excludes raw arrays)."""
data = self.embedding.to_dict()
Path(path).write_text(json.dumps(data, indent=2))
logger.info(f"Style embedding saved to {path}")
def compare(self, other: "StyleLock") -> dict:
"""
Compare style similarity between two StyleLock instances.
Returns:
Dict of per-dimension scores: "color_temperature" and "*_similarity"
entries are higher-is-more-similar; "*_diff" entries are absolute
differences, where lower means more similar.
"""
scores = {}
a, b = self.embedding, other.embedding
# Color palette similarity
if a.color_palette and b.color_palette:
scores["color_temperature"] = (
1.0 if a.color_palette.temperature == b.color_palette.temperature else 0.0
)
scores["saturation_diff"] = abs(
a.color_palette.saturation_mean - b.color_palette.saturation_mean
)
scores["brightness_diff"] = abs(
a.color_palette.brightness_mean - b.color_palette.brightness_mean
)
# Lighting similarity (keyed separately — "brightness_diff" above
# already holds the color-palette brightness delta)
if a.lighting and b.lighting:
scores["lighting_brightness_diff"] = abs(
a.lighting.mean_brightness - b.lighting.mean_brightness
)
scores["contrast_diff"] = abs(a.lighting.contrast - b.lighting.contrast)
# Texture similarity
if a.texture and b.texture:
if a.texture.features.shape == b.texture.features.shape:
corr = np.corrcoef(a.texture.features, b.texture.features)[0, 1]
scores["texture_correlation"] = float(corr) if not np.isnan(corr) else 0.0
# CLIP embedding cosine similarity
if a.clip_embedding is not None and b.clip_embedding is not None:
cos_sim = np.dot(a.clip_embedding, b.clip_embedding) / (
np.linalg.norm(a.clip_embedding) * np.linalg.norm(b.clip_embedding)
)
scores["clip_cosine_similarity"] = float(cos_sim)
return scores
# ---------------------------------------------------------------------------
# Multi-reference style locking (for LPM 1.0 identity photos)
# ---------------------------------------------------------------------------
class MultiReferenceStyleLock:
"""
Style Lock from multiple reference images (e.g., 8 identity photos).
Extracts style from each reference, then merges into a consensus style
that captures the common aesthetic across all references.
Args:
reference_paths: List of paths to reference images.
merge_strategy: How to combine styles — "average", "dominant", "first".
"""
def __init__(
self,
reference_paths: list[str],
merge_strategy: str = "average",
):
self.locks = [StyleLock(p) for p in reference_paths]
self.merge_strategy = merge_strategy
def get_conditioning(
self,
backend: str = "sdxl",
method: str = "hybrid",
strength: float = 0.75,
) -> ConditioningOutput:
"""Get merged conditioning from all reference images."""
if self.merge_strategy == "first":
return self.locks[0].get_conditioning(backend, method, strength)
# Use the first lock as the primary conditioning source,
# but adjust parameters based on consensus across all references
primary = self.locks[0]
conditioning = primary.get_conditioning(backend, method, strength)
if self.merge_strategy == "average":
# Average the conditioning scales across all locks
scales = []
for lock in self.locks:
emb = lock.embedding
if emb.color_palette:
scales.append(emb.color_palette.saturation_mean)
if scales:
avg_sat = np.mean(scales)
# Adjust IP-Adapter scale based on average saturation agreement
conditioning.ip_adapter_scale *= (0.5 + 0.5 * avg_sat)
# Build a more comprehensive style prompt from all references
all_style_parts = []
for lock in self.locks:
prompt = lock._build_style_prompt(lock.embedding)
all_style_parts.append(prompt)
# Deduplicate style descriptors
seen = set()
unique_parts = []
for part in ", ".join(all_style_parts).split(", "):
stripped = part.strip()
if stripped and stripped not in seen:
seen.add(stripped)
unique_parts.append(stripped)
conditioning.style_prompt = ", ".join(unique_parts)
return conditioning
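The merge path above joins the per-reference style prompts and deduplicates comma-separated descriptors while preserving first-seen order. A minimal standalone sketch of just that step (the function name `merge_style_prompts` is hypothetical, not part of the module):

```python
# Hypothetical standalone sketch of MultiReferenceStyleLock's prompt merge:
# join per-reference style prompts, then deduplicate comma-separated
# descriptors in first-seen order.
def merge_style_prompts(prompts: list[str]) -> str:
    seen: set[str] = set()
    unique: list[str] = []
    for part in ", ".join(prompts).split(", "):
        stripped = part.strip()
        if stripped and stripped not in seen:
            seen.add(stripped)
            unique.append(stripped)
    return ", ".join(unique)

merged = merge_style_prompts([
    "warm color palette, vibrant saturated colors",
    "warm color palette, high contrast dramatic lighting",
])
# "warm color palette" survives once; the remaining descriptors keep order.
```

Because deduplication is order-preserving, the first reference image dominates the front of the merged prompt, matching the "primary lock" behavior of `get_conditioning`.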

View File

@@ -0,0 +1,251 @@
"""
Tests for Style Lock module.
Validates:
- Color extraction from synthetic images
- Lighting profile extraction
- Texture feature extraction
- Style prompt generation
- Conditioning output format
- Multi-reference merging
"""
import numpy as np
from PIL import Image
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from scripts.style_lock import (
StyleLock,
ColorExtractor,
LightingExtractor,
TextureExtractor,
MultiReferenceStyleLock,
ConditioningOutput,
)
def _make_test_image(width=256, height=256, color=(128, 64, 200)):
"""Create a solid-color test image."""
return Image.new("RGB", (width, height), color)
def _make_gradient_image(width=256, height=256):
"""Create a gradient test image."""
arr = np.zeros((height, width, 3), dtype=np.uint8)
for y in range(height):
for x in range(width):
arr[y, x] = [int(255 * x / width), int(255 * y / height), 128]
return Image.fromarray(arr)
def test_color_extractor_solid():
img = _make_test_image(color=(200, 100, 50))
ext = ColorExtractor(num_colors=3)
palette = ext.extract(img)
assert len(palette.colors) == 3
assert len(palette.weights) == 3
assert sum(palette.weights) > 0.99
assert palette.temperature == "warm" # R > B
assert palette.saturation_mean > 0
def test_color_extractor_cool():
img = _make_test_image(color=(50, 100, 200))
ext = ColorExtractor(num_colors=3)
palette = ext.extract(img)
assert palette.temperature == "cool" # B > R
def test_lighting_extractor():
img = _make_test_image(color=(128, 128, 128))
ext = LightingExtractor()
profile = ext.extract(img)
assert 0.4 < profile.mean_brightness < 0.6
assert profile.contrast < 0.1 # Uniform image, low contrast
assert profile.direction_hint == "even"
def test_texture_extractor():
img = _make_test_image(color=(128, 128, 128))
ext = TextureExtractor(orientations=4, frequencies=[0.1, 0.2])
features = ext.extract(img)
assert features.features.shape == (8,) # 4 orientations * 2 frequencies
assert features.orientations == 4
assert features.frequencies == [0.1, 0.2]
def test_style_lock_embedding():
img = _make_test_image(color=(180, 90, 45))
lock = StyleLock(img)
emb = lock.embedding
assert emb.color_palette is not None
assert emb.lighting is not None
assert emb.texture is not None
assert emb.color_palette.temperature == "warm"
def test_style_lock_conditioning_ip_adapter():
img = _make_test_image()
lock = StyleLock(img)
cond = lock.get_conditioning(backend="sdxl", method="ip_adapter", strength=0.8)
assert cond.method == "ip_adapter"
assert cond.backend == "sdxl"
assert cond.strength == 0.8
assert cond.style_prompt # Non-empty
assert cond.negative_prompt # Non-empty
def test_style_lock_conditioning_hybrid():
img = _make_test_image()
lock = StyleLock(img)
cond = lock.get_conditioning(method="hybrid")
assert cond.method == "hybrid"
assert cond.controlnet_image is not None
assert cond.controlnet_image.shape[2] == 3 # RGB
def test_style_lock_conditioning_to_api_kwargs():
img = _make_test_image()
lock = StyleLock(img)
cond = lock.get_conditioning(method="hybrid")
kwargs = cond.to_api_kwargs()
assert "prompt_suffix" in kwargs or "negative_prompt_suffix" in kwargs
def test_style_lock_negative_prompt_warm():
img = _make_test_image(color=(220, 100, 30))
lock = StyleLock(img)
emb = lock.embedding
neg = lock._build_negative_prompt(emb)
assert "cold blue" in neg.lower() or "cold" in neg.lower()
def test_style_lock_save_embedding(tmp_path):
img = _make_test_image()
lock = StyleLock(img)
path = str(tmp_path / "style.json")
lock.save_embedding(path)
import json
data = json.loads(Path(path).read_text())
assert data["color_palette"] is not None
assert data["lighting"] is not None
def test_style_lock_compare():
img1 = _make_test_image(color=(200, 50, 30))
img2 = _make_test_image(color=(200, 60, 40))
lock1 = StyleLock(img1)
lock2 = StyleLock(img2)
scores = lock1.compare(lock2)
assert "color_temperature" in scores
assert scores["color_temperature"] == 1.0 # Both warm
def test_style_lock_compare_different_temps():
img1 = _make_test_image(color=(200, 50, 30))
img2 = _make_test_image(color=(30, 50, 200))
lock1 = StyleLock(img1)
lock2 = StyleLock(img2)
scores = lock1.compare(lock2)
assert scores["color_temperature"] == 0.0 # Warm vs cool
def test_multi_reference_style_lock():
imgs = [_make_test_image(color=(180, 90, 45)) for _ in range(3)]
paths = []
import tempfile, os
for i, img in enumerate(imgs):
p = os.path.join(tempfile.gettempdir(), f"ref_{i}.png")
img.save(p)
paths.append(p)
mlock = MultiReferenceStyleLock(paths, merge_strategy="average")
cond = mlock.get_conditioning(backend="sdxl", method="hybrid")
assert cond.method == "hybrid"
assert cond.style_prompt # Merged style prompt
for p in paths:
os.unlink(p)
def test_multi_reference_first_strategy():
imgs = [_make_test_image(color=(200, 50, 30)) for _ in range(2)]
paths = []
import tempfile, os
for i, img in enumerate(imgs):
p = os.path.join(tempfile.gettempdir(), f"ref_first_{i}.png")
img.save(p)
paths.append(p)
mlock = MultiReferenceStyleLock(paths, merge_strategy="first")
cond = mlock.get_conditioning()
assert cond.method == "hybrid"
for p in paths:
os.unlink(p)
if __name__ == "__main__":
import tempfile
print("Running Style Lock tests...")
test_color_extractor_solid()
print(" [PASS] color_extractor_solid")
test_color_extractor_cool()
print(" [PASS] color_extractor_cool")
test_lighting_extractor()
print(" [PASS] lighting_extractor")
test_texture_extractor()
print(" [PASS] texture_extractor")
test_style_lock_embedding()
print(" [PASS] style_lock_embedding")
test_style_lock_conditioning_ip_adapter()
print(" [PASS] style_lock_conditioning_ip_adapter")
test_style_lock_conditioning_hybrid()
print(" [PASS] style_lock_conditioning_hybrid")
test_style_lock_conditioning_to_api_kwargs()
print(" [PASS] style_lock_conditioning_to_api_kwargs")
test_style_lock_negative_prompt_warm()
print(" [PASS] style_lock_negative_prompt_warm")
td = tempfile.mkdtemp()
test_style_lock_save_embedding(Path(td))
print(" [PASS] style_lock_save_embedding")
test_style_lock_compare()
print(" [PASS] style_lock_compare")
test_style_lock_compare_different_temps()
print(" [PASS] style_lock_compare_different_temps")
test_multi_reference_style_lock()
print(" [PASS] multi_reference_style_lock")
test_multi_reference_first_strategy()
print(" [PASS] multi_reference_first_strategy")
print("\nAll 14 tests passed.")

View File

@@ -1,44 +0,0 @@
"""
Tests for resource limits (#755).
"""
import pytest
from tools.resource_limits import ResourceLimiter, ResourceLimits, ResourceResult, ResourceViolation
class TestResourceLimiter:
def test_successful_execution(self):
limiter = ResourceLimiter(ResourceLimits(memory_mb=2048, timeout_seconds=10))
result = limiter.execute("echo hello")
assert result.success is True
assert result.exit_code == 0
assert "hello" in result.stdout
assert result.violation == ResourceViolation.NONE
def test_timeout_violation(self):
limiter = ResourceLimiter(ResourceLimits(timeout_seconds=1))
result = limiter.execute("sleep 10")
assert result.success is False
assert result.violation == ResourceViolation.TIME
assert result.killed is True
def test_failed_command(self):
limiter = ResourceLimiter()
result = limiter.execute("exit 1")
assert result.success is False
assert result.exit_code == 1
def test_resource_report(self):
from tools.resource_limits import format_resource_report
result = ResourceResult(
success=True, stdout="", stderr="", exit_code=0,
violation=ResourceViolation.NONE, violation_message="",
memory_used_mb=100, cpu_time_seconds=0.5, wall_time_seconds=1.0,
)
report = format_resource_report(result)
assert "Exit code: 0" in report
assert "100MB" in report
if __name__ == "__main__":
pytest.main([__file__])

View File

@@ -1,249 +0,0 @@
"""
Terminal Sandbox Resource Limits — CPU, memory, time.
Provides resource limits for agent terminal commands to prevent
OOM kills, runaway processes, and excessive resource consumption.
"""
import logging
import os
import signal
import subprocess
import time
from dataclasses import dataclass, field
from typing import Optional, Dict, Any
from enum import Enum
logger = logging.getLogger(__name__)
class ResourceViolation(Enum):
"""Types of resource violations."""
MEMORY = "memory"
CPU = "cpu"
TIME = "time"
NONE = "none"
@dataclass
class ResourceLimits:
"""Resource limits for a subprocess."""
memory_mb: int = 2048 # 2GB default
cpu_percent: int = 80 # 80% of one core
timeout_seconds: int = 300 # 5 minutes
kill_timeout: int = 10 # SIGKILL after 10s if SIGTERM fails
@dataclass
class ResourceResult:
"""Result of a resource-limited execution."""
success: bool
stdout: str
stderr: str
exit_code: int
violation: ResourceViolation
violation_message: str
memory_used_mb: float
cpu_time_seconds: float
wall_time_seconds: float
killed: bool = False
class ResourceLimiter:
"""Apply resource limits to subprocess execution."""
def __init__(self, limits: Optional[ResourceLimits] = None):
self.limits = limits or ResourceLimits()
def _get_resource_rlimit(self) -> Dict[str, Any]:
"""Get resource limits for subprocess (Unix only)."""
import resource
rlimit = {}
# Memory limit (virtual address space via RLIMIT_AS, not strictly RSS)
if self.limits.memory_mb > 0:
mem_bytes = self.limits.memory_mb * 1024 * 1024
rlimit[resource.RLIMIT_AS] = (mem_bytes, mem_bytes)
# CPU time limit
if self.limits.timeout_seconds > 0:
rlimit[resource.RLIMIT_CPU] = (self.limits.timeout_seconds, self.limits.timeout_seconds)
return rlimit
def _check_resource_usage(self, _result: subprocess.CompletedProcess) -> Dict[str, float]:
"""Report cumulative child resource usage (Unix only); getrusage(RUSAGE_CHILDREN) is process-wide, so the argument is unused."""
try:
import resource
usage = resource.getrusage(resource.RUSAGE_CHILDREN)
return {
"user_time": usage.ru_utime,
"system_time": usage.ru_stime,
"max_rss_mb": usage.ru_maxrss / 1024,  # KB to MB on Linux (bytes on macOS)
}
except (ImportError, OSError):
return {"user_time": 0.0, "system_time": 0.0, "max_rss_mb": 0.0}
def execute(self, command: str, **kwargs) -> ResourceResult:
"""
Execute a command with resource limits.
Args:
command: Shell command to execute
**kwargs: Additional subprocess arguments
Returns:
ResourceResult with execution details
"""
start_time = time.time()
# Try to use resource limits (Unix only)
preexec_fn = None
try:
import resource
rlimit = self._get_resource_rlimit()
def set_limits():
for res, limits in rlimit.items():
resource.setrlimit(res, limits)
preexec_fn = set_limits
except ImportError:
logger.debug("resource module not available, skipping limits")
try:
# Execute with timeout
result = subprocess.run(
command,
shell=True,
capture_output=True,
text=True,
timeout=self.limits.timeout_seconds,
preexec_fn=preexec_fn,
**kwargs,
)
wall_time = time.time() - start_time
usage = self._check_resource_usage(result)
# Check for violations
violation = ResourceViolation.NONE
violation_message = ""
# Check memory (if we can get it)
if usage["max_rss_mb"] > self.limits.memory_mb:
violation = ResourceViolation.MEMORY
violation_message = f"Memory limit exceeded: {usage['max_rss_mb']:.0f}MB > {self.limits.memory_mb}MB"
return ResourceResult(
success=result.returncode == 0,
stdout=result.stdout,
stderr=result.stderr,
exit_code=result.returncode,
violation=violation,
violation_message=violation_message,
memory_used_mb=usage["max_rss_mb"],
cpu_time_seconds=usage["user_time"] + usage["system_time"],
wall_time_seconds=wall_time,
)
except subprocess.TimeoutExpired as e:
wall_time = time.time() - start_time
# subprocess.run() kills the child itself before raising
# TimeoutExpired, so no extra terminate/kill pass is needed here.
# (TimeoutExpired has no .process attribute.) With text=True the
# captured output is already str; normalize in case bytes appear.
stdout = e.stdout if isinstance(e.stdout, str) else (e.stdout or b"").decode()
stderr = e.stderr if isinstance(e.stderr, str) else (e.stderr or b"").decode()
return ResourceResult(
success=False,
stdout=stdout,
stderr=stderr,
exit_code=-1,
violation=ResourceViolation.TIME,
violation_message=f"Timeout after {self.limits.timeout_seconds}s",
memory_used_mb=0,
cpu_time_seconds=0,
wall_time_seconds=wall_time,
killed=True,
)
except MemoryError:
wall_time = time.time() - start_time
return ResourceResult(
success=False,
stdout="",
stderr=f"Memory limit exceeded ({self.limits.memory_mb}MB)",
exit_code=-1,
violation=ResourceViolation.MEMORY,
violation_message=f"Memory limit exceeded: {self.limits.memory_mb}MB",
memory_used_mb=self.limits.memory_mb,
cpu_time_seconds=0,
wall_time_seconds=wall_time,
killed=True,
)
except Exception as e:
wall_time = time.time() - start_time
return ResourceResult(
success=False,
stdout="",
stderr=str(e),
exit_code=-1,
violation=ResourceViolation.NONE,
violation_message=f"Execution error: {e}",
memory_used_mb=0,
cpu_time_seconds=0,
wall_time_seconds=wall_time,
)
def format_resource_report(result: ResourceResult) -> str:
"""Format resource usage as a report string."""
lines = [
f"Exit code: {result.exit_code}",
f"Wall time: {result.wall_time_seconds:.2f}s",
f"CPU time: {result.cpu_time_seconds:.2f}s",
f"Memory: {result.memory_used_mb:.0f}MB",
]
if result.violation != ResourceViolation.NONE:
lines.append(f"⚠️ Violation: {result.violation_message}")
if result.killed:
lines.append("🔴 Process killed")
return " | ".join(lines)
def execute_with_limits(
command: str,
memory_mb: int = 2048,
cpu_percent: int = 80,
timeout_seconds: int = 300,
) -> ResourceResult:
"""
Convenience function to execute with resource limits.
Args:
command: Shell command
memory_mb: Memory limit in MB
cpu_percent: CPU limit as percent of one core
timeout_seconds: Timeout in seconds
Returns:
ResourceResult
"""
limits = ResourceLimits(
memory_mb=memory_mb,
cpu_percent=cpu_percent,
timeout_seconds=timeout_seconds,
)
limiter = ResourceLimiter(limits)
return limiter.execute(command)
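The limiter's core mechanism is the `preexec_fn` hook: the child applies its own rlimits after `fork` and before `exec`. A minimal Unix-only sketch of that pattern, independent of the classes above (`run_limited` is a hypothetical name, and the 512 MB default is an illustrative choice):

```python
# Minimal Unix-only sketch of the preexec_fn pattern ResourceLimiter uses:
# the child process applies an address-space rlimit on itself before exec.
import resource
import subprocess
import sys

def run_limited(cmd: list[str], memory_mb: int = 512, timeout: int = 10):
    def set_limits() -> None:
        # RLIMIT_AS caps the child's virtual address space (not RSS).
        limit = memory_mb * 1024 * 1024
        resource.setrlimit(resource.RLIMIT_AS, (limit, limit))
    return subprocess.run(
        cmd, capture_output=True, text=True,
        timeout=timeout, preexec_fn=set_limits,
    )

result = run_limited([sys.executable, "-c", "print('ok')"])
```

Setting the limit in `preexec_fn` means it applies only to the child, never to the calling process; a child that exceeds the cap sees allocation failures rather than triggering the host's OOM killer.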