diff --git a/apps/openenv/README.md b/apps/openenv/README.md new file mode 100644 index 000000000..cfb48755a --- /dev/null +++ b/apps/openenv/README.md @@ -0,0 +1,410 @@ +# OpenEnv - Generic GRPO Training Framework + +A centralized framework for training language models on any OpenEnv task using GRPO (Grouped Relative Policy Optimization) or DAPO (Decoupled Clip and Dynamic Sampling Policy Optimization). + +## Key Features + +- **GenericEnvClient**: Works with ANY OpenEnv Docker image without requiring environment-specific packages locally +- **GenericAction**: Simple dict wrapper that maps to environment-specific actions at runtime +- **Single Main Script**: One `main.py` works for all OpenEnv tasks +- **Circuit Breaker Pattern**: Automatic detection and restart of unhealthy Docker containers +- **Episode Dropout**: Configurable filtering of low-quality training batches +- **GRPO/DAPO Loss**: Switchable loss functions with configurable parameters +- **Parallel Evaluation**: Multiple env_actors for isolated, parallel reward evaluation + +## Folder Structure + +``` +apps/openenv/ + ├── main.py # Generic training script (use this) + ├── julia_utils.py # Julia task utilities (GenericAction) + ├── python_utils.py # Python task utilities (GenericAction) + ├── llama3_8b_julia.yaml # Julia training config + ├── llama3_8b_coding.yaml# Python coding training config + └── README.md # This file +``` + +## Quick Start + +### Run Julia Training + +```bash +python -m apps.openenv.main --config apps/openenv/llama3_8b_julia.yaml +``` + +### Run Python Coding Training + +```bash +python -m apps.openenv.main --config apps/openenv/llama3_8b_coding.yaml +``` + +## YAML Configuration + +### Minimal Configuration + +Each task config needs at minimum: + +```yaml +# Task-specific configuration +task: + env_name: "julia" # Environment name + build_action: !function apps.openenv.julia_utils.build_julia_action + evaluate_response: !function apps.openenv.julia_utils.evaluate_julia_response + 
transform_sample: !function apps.openenv.julia_utils.transform_julia_sample + +# OpenEnv configuration - only docker_image is required! +openenv_config: + docker_image: "julia-env:latest" +``` + +### Full Configuration Reference + +```yaml +# Global configuration +group_size: 8 # Number of responses per prompt +batch_size: 2 # Batches per training step +max_req_tokens: 1024 # Max prompt tokens +max_res_tokens: 1024 # Max response tokens +model: "path/to/model" # Model path +off_by_n: 1 # Max policy version age for episodes + +# Loss configuration (GRPO or DAPO) +grpo: + loss_type: grpo # "grpo" or "dapo" + clip_eps_low: 0.2 # Lower clipping bound + clip_eps_high: 0.28 # Upper clipping bound + agg_type: fixed_horizon # "fixed_horizon" (GRPO) or "token_mean" (DAPO) + beta: 0.1 # KL penalty (GRPO only) + dual_clip_c: 3.0 # Dual-clip constant (DAPO only) + +# Episode dropout configuration +episode_dropout: + enable_variance_dropout: true # Drop low-variance batches + enable_truncation_dropout: true # Drop batches with truncated responses + variance_threshold: 0.001 # Std threshold for variance dropout + +# Main loop configuration +rollout_threads: 1 # Parallel rollout threads +evaluation_timeout_s: 20.0 # Timeout for environment evaluation + +# Circuit breaker configuration +circuit_breaker: + threshold: 5 # Timeouts before tripping + window_s: 60.0 # Time window for counting timeouts + cooldown_s: 60.0 # Wait time after container restart + +# Task configuration +task: + env_name: "julia" + build_action: !function apps.openenv.julia_utils.build_julia_action + evaluate_response: !function apps.openenv.julia_utils.evaluate_julia_response + transform_sample: !function apps.openenv.julia_utils.transform_julia_sample + +# Dataset configuration +dataset: + path: "path/to/dataset.parquet" # Supports .parquet, .json, or HF datasets + data_split: "train" + streaming: false + +# OpenEnv configuration +openenv_config: + docker_image: "julia-env:latest" + container_timeout_s: 
180.0 # Container startup timeout + container_memory_gb: 1024 # Container memory limit + port: 8000 # Starting port for containers + num_env_actors: 2 # Number of parallel reward actors + num_containers: 2 # Containers per actor + num_connections: 12 # WebSocket connections per container + request_timeout_s: 20.0 # Per-request timeout + env_vars: # Environment variables for containers + JULIA_EXECUTION_TIMEOUT: "15" + JULIA_MAX_WORKERS: "16" +``` + +## Adding a New Language + +To add support for a new language (e.g., Rust): + +### 1. Create Utils File + +Create `apps/openenv/rust_utils.py`: + +```python +from typing import Any, Dict +from openenv import GenericAction +from forge.observability.metrics import record_metric, Reduce + + +def get_rust_system_prompt() -> str: + """Get system prompt for Rust coding tasks.""" + return """You are an expert Rust programmer. +Write correct, safe Rust code that compiles and runs. +""".strip() + + +def build_rust_prompt(sample: Dict[str, Any], tokenizer) -> str: + """Build prompt for Rust code generation.""" + system_prompt = get_rust_system_prompt() + request = sample.get("prompt", "") + + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": request}, + ] + + return tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + + +def build_rust_action(response: str, sample: Dict[str, Any]) -> GenericAction: + """Build GenericAction from model response.""" + code = extract_rust_code(response) + test_code = sample.get("target", "") + + # GenericAction fields must match what RustEnv expects + return GenericAction( + code=code, + test_code=test_code, + ) + + +def evaluate_rust_response(result, response: str, sample: Dict[str, Any]) -> float: + """Evaluate Rust code execution and return reward.""" + obs = result.observation + if isinstance(obs, dict): + exit_code = obs.get("exit_code", -1) + else: + exit_code = obs.exit_code + + reward = 1.0 if exit_code == 0 
else 0.0 + record_metric("reward/rust/reward", reward, Reduce.MEAN) + return reward + + +def extract_rust_code(response: str) -> str: + """Extract Rust code from markdown blocks.""" + import re + pattern = r"```rust\n(.*?)```" + match = re.search(pattern, response, re.DOTALL) + if match: + return match.group(1).strip() + return response.strip() + + +def transform_rust_sample(sample: Dict[str, Any], tokenizer) -> Dict[str, Any] | None: + """Transform dataset sample for Rust tasks.""" + if not sample.get("prompt"): + return None + + return { + "request": build_rust_prompt(sample, tokenizer), + "target": sample.get("test", ""), + "task_id": sample.get("task_id", ""), + } +``` + +### 2. Create YAML Config + +Create `apps/openenv/llama3_8b_rust.yaml`: + +```yaml +# Rust training config using GenericEnvClient +group_size: 8 +batch_size: 2 +max_req_tokens: 1024 +max_res_tokens: 1024 +model: "path/to/model" + +grpo: + loss_type: grpo + clip_eps_low: 0.2 + clip_eps_high: 0.28 + beta: 0.1 + +task: + env_name: "rust" + build_action: !function apps.openenv.rust_utils.build_rust_action + evaluate_response: !function apps.openenv.rust_utils.evaluate_rust_response + transform_sample: !function apps.openenv.rust_utils.transform_rust_sample + +dataset: + path: "path/to/rust/dataset" + data_split: "train" + +openenv_config: + docker_image: "rust-env:latest" + container_timeout_s: 180.0 + num_env_actors: 2 + num_containers: 2 + num_connections: 8 + +# ... rest of config (copy from existing configs) +``` + +### 3. Run Training + +```bash +python -m apps.openenv.main --config apps/openenv/llama3_8b_rust.yaml +``` + +## Task Utils API + +Each task utils file should implement these functions: + +### Required Functions + +1. **`build__action(response: str, sample: dict) -> GenericAction`** + - Builds GenericAction from model response + - GenericAction fields must match what the environment expects + +2. 
**`evaluate_<lang>_response(result, response: str, sample: dict) -> float`**
+   - Evaluates execution result and returns reward (0.0 to 1.0)
+   - Works with both typed observations and raw dicts
+
+3. **`transform_<lang>_sample(sample: dict, tokenizer) -> dict | None`**
+   - Transforms raw dataset sample into training format
+   - Returns dict with 'request', 'target', 'task_id' or None if invalid
+
+### Optional Helper Functions
+
+- **`get_<lang>_system_prompt() -> str`**: Get system prompt for the language
+- **`build_<lang>_prompt(sample: dict, tokenizer) -> str`**: Build formatted prompt
+- **`extract_<lang>_code(response: str) -> str`**: Extract code from markdown
+
+## Architecture
+
+### GenericEnvClient
+
+The `OpenEnvActor` manages Docker containers and WebSocket connections:
+
+```
+┌─────────────────────────────────────────────────────────────┐
+│                    GenericRewardActor                       │
+│  ┌─────────────┐  ┌─────────────┐  ┌─────────────┐          │
+│  │ env_actor_0 │  │ env_actor_1 │  │ env_actor_2 │  ...     │
+│  │ ┌─────────┐ │  │ ┌─────────┐ │  │ ┌─────────┐ │          │
+│  │ │Container│ │  │ │Container│ │  │ │Container│ │          │
+│  │ │ WS x12  │ │  │ │ WS x12  │ │  │ │ WS x12  │ │          │
+│  │ └─────────┘ │  │ └─────────┘ │  │ └─────────┘ │          │
+│  └─────────────┘  └─────────────┘  └─────────────┘          │
+└─────────────────────────────────────────────────────────────┘
+```
+
+- Each `env_actor` manages its own container pool
+- Circuit breaker isolates failures per actor
+- Unhealthy actors trigger automatic container restart
+
+### Circuit Breaker
+
+The circuit breaker pattern prevents cascading failures:
+
+1. **Closed**: Normal operation, requests flow through
+2. **Open**: Too many timeouts detected, actor marked unhealthy
+3. **Half-Open**: After cooldown, actor retries with fresh container
+
+Configuration:
+```yaml
+circuit_breaker:
+  threshold: 5      # Timeouts before opening
+  window_s: 60.0    # Counting window
+  cooldown_s: 60.0  # Time before retry
+```
+
+### Episode Dropout
+
+Batches are filtered based on quality:
+
+1. 
**Variance Dropout**: Drops batches where all rewards are similar (e.g., all 0 or all 1) +2. **Truncation Dropout**: Drops batches with truncated responses (hit max_tokens) + +This prevents training on uninformative gradients. + + +## Observability + +### Metrics + +Key metrics tracked: +- `reward/*/avg_reward`: Average reward per task +- `reward/*/pass_rate`: Test pass rate +- `circuit_breaker/*/tripped`: Circuit breaker activations +- `episode/avg_response_tokens`: Average response length +- `training/weight_update_duration_s`: Weight sync time + +### Logging + +Set log level via environment variable: +```bash +LOG_LEVEL=DEBUG python -m apps.openenv.main --config ... +``` + +### Weights & Biases + +Configure in YAML: +```yaml +metric_logging: + wandb: + entity: "your-team" + project: "your-project" + logging_mode: global_reduce +``` + +## Performance Tuning + +### GPU Memory + +Enable expandable segments (set automatically in main.py): +```python +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +``` + +### Timeout Configuration + +For Julia (with internal worker pool): +```yaml +# Julia kills workers at 15s, we wait 20s to allow recovery +evaluation_timeout_s: 20.0 +openenv_config: + request_timeout_s: 20.0 + env_vars: + JULIA_EXECUTION_TIMEOUT: "15" +``` + +### Buffer Starvation + +If training stalls waiting for episodes: +1. Increase `off_by_n` to accept older episodes +2. Increase `rollout_threads` for more parallel generation +3. Increase policy `num_replicas` for more generation capacity + +Environment variables for debugging: +```bash +FORGE_MAX_EMPTY_BUFFER_WAIT_S=120 # Max wait before error +FORGE_BACKPRESSURE_TIMEOUT_S=30 # Max backpressure wait +``` + +## Debugging + +### Common Issues + +1. **No code extracted**: Model not following format + - Check system prompt in utils file + - Verify `extract_*_code()` handles model output format + +2. 
**All evaluations timeout**: Container issues + - Check container logs + - Reduce `num_connections` to prevent overload + - Increase `container_memory_gb` + +3. **Circuit breaker keeps tripping**: Environment instability + - Increase `threshold` for more tolerance + - Check for memory leaks in environment + - Add more `num_containers` for redundancy + +4. **Buffer starvation**: Training faster than rollouts + - Increase `off_by_n` (accept older episodes) + - Increase `rollout_threads` + - Add more policy replicas diff --git a/apps/openenv/julia_utils.py b/apps/openenv/julia_utils.py new file mode 100644 index 000000000..aba59964c --- /dev/null +++ b/apps/openenv/julia_utils.py @@ -0,0 +1,348 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Julia task-specific utilities using GenericAction. + +This version uses OpenEnv's GenericAction (a simple dict wrapper) instead of +the environment-specific JuliaAction class. This means you don't need to +install the julia_env package locally. + +Usage: + # In your YAML config: + task: + env_name: "julia" + build_action: !function apps.openenv.julia_utils.build_julia_action + evaluate_response: !function apps.openenv.julia_utils.evaluate_julia_response + transform_sample: !function apps.openenv.julia_utils.transform_julia_sample +""" + +import re +from typing import Any, Dict + +from openenv import GenericAction + +from forge.observability.metrics import record_metric, Reduce + + +def get_julia_system_prompt() -> str: + """Get system prompt for Julia coding tasks.""" + return """You are a precise and pragmatic Julia programmer. + +Write a **single Julia function** that correctly solves the problem described below. + +CRITICAL - Julia is NOT Python! 
Use correct Julia syntax: +- Use `lowercase()` NOT `tolower()` +- Use `uppercase()` NOT `upper()` +- Use `reverse()` NOT `rev()` or `reversed()` +- Use `parse(Int, x)` or `Int(x)` for type conversion, NOT `int(x)` +- Use `string()` for string conversion, NOT `str()` +- Use `filter()` NOT `subset()` +- Use `length()` NOT `len()` +- Use `push!()` to append to arrays, NOT `append()` +- String indexing: `str[i]` returns a Char, use `str[i:i]` for single-char String +- Arrays are 1-indexed, NOT 0-indexed +- Use `println()` NOT `print()` for line output +- Use `Dict()` NOT `dict()` +- Boolean operators: `&&` for AND, `||` for OR, `!` for NOT +- Check string contains: `occursin(needle, haystack)` NOT `in` or `contains(haystack, needle)` + +Example - Convert string to uppercase and reverse: +```julia +function process_text(text::String) + upper_text = uppercase(text) # NOT upper() + reversed_text = reverse(upper_text) # NOT rev() + return reversed_text +end +``` + +Example - Work with integers and arrays: +```julia +function sum_digits(n::Int) + total = 0 + digits_arr = Int[] # Empty array + while n > 0 + digit = n % 10 + push!(digits_arr, digit) # NOT append() + total += digit + n = div(n, 10) + end + return total +end +``` + +Rules: +- The code must be syntactically correct and runnable as is +- Use only the Julia standard library +- Do **not** wrap the code in a module or add a `main` function +- Do **not** include any test code in your response +- Do **not** hardcode specific test cases or outputs — the function must work for general inputs +- The **function name must exactly match** the one used in the provided tests +- Respond with **only the Julia function** and nothing else (no explanations, no comments, no extra text) +- Character literal should not contain multiple characters +- Take care of object types and mind that spaces matter in Julia + +Passing tests and clean, compilable code are rewarded. Hardcoding or failing tests is penalized. 
+ +FORMAT YOUR RESPONSE AS: + +```julia +function () + +end +``` +""".strip() + + +def build_julia_prompt(sample: Dict[str, Any], tokenizer) -> str: + """ + Build prompt for Julia code generation. + + Args: + sample: Dataset sample with 'julia_prompt', 'julia_test', 'first_test_case', 'task_id' + tokenizer: HuggingFace tokenizer for chat template + + Returns: + Formatted prompt string ready for model generation + """ + system_prompt = get_julia_system_prompt() + request = sample.get("julia_prompt", "") + + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": request}, + ] + + formatted_request = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + + return formatted_request + + +def build_julia_action(response: str, sample: Dict[str, Any]) -> GenericAction: + """ + Build GenericAction from model response and dataset sample. + + This uses GenericAction (a simple dict wrapper) instead of JuliaAction, + so you don't need to install julia_env locally. + + Args: + response: Model's generated response + sample: Dataset sample with 'target' field containing test code + + Returns: + GenericAction instance with core_code and test_code fields + """ + # Extract code from markdown if present + code = extract_julia_code(response) + + # Get test code from dataset + test_code = sample.get("target", "") + + # GenericAction is just a dict wrapper - same fields as JuliaAction + return GenericAction( + core_code=code, + test_code=test_code, + ) + + +def evaluate_julia_response(result, response: str, sample: Dict[str, Any]) -> float: + """ + Evaluate Julia code execution result and return reward. + + Works with both typed observations (JuliaObservation) and raw dicts + returned by GenericEnvClient. 
+ + Args: + result: StepResult from environment execution + response: Model's response (for logging) + sample: Dataset sample (for logging) + + Returns: + Reward score (0.0 to 1.0) + """ + try: + print("=" * 80) + print("RAW RESPONSE FROM MODEL:") + print("-" * 80) + print(response) + print("-" * 80) + + # Extract code for validation + code = extract_julia_code(response) + + if not code: + print("No Julia code extracted - Reward: 0.0") + print("=" * 80) + record_metric("reward/julia/no_code_extracted", 1, Reduce.SUM) + return 0.0 + + # Get test code from sample for logging + test_code = sample.get("target", "") + + # Log both code and test_code together for extraction + print("EXTRACTED JULIA CODE:") + print("-" * 80) + print(code) + print("-" * 80) + print("TEST CODE:") + print("-" * 80) + print(test_code) + print("-" * 80) + print("END OF SAMPLE") + print("=" * 80) + + # Validate for common Python-like syntax errors + is_valid, validation_warnings = validate_julia_syntax(code) + if not is_valid: + print("SYNTAX VALIDATION WARNINGS:") + for warning in validation_warnings: + print(f" {warning}") + print("-" * 80) + record_metric("reward/julia/syntax_warnings", len(validation_warnings), Reduce.SUM) + + # Extract reward from result + reward = result.reward if result.reward is not None else 0.0 + + # Handle both typed observation and dict observation (from GenericEnvClient) + obs = result.observation + if isinstance(obs, dict): + # GenericEnvClient returns dicts + passed = obs.get("tests_passed", 0) + failed = obs.get("tests_failed", 0) + exit_code = obs.get("exit_code", -1) + code_compiles = obs.get("code_compiles", False) + stderr = obs.get("stderr", "") + stdout = obs.get("stdout", "") + else: + # Typed observation (JuliaObservation) + passed = obs.tests_passed + failed = obs.tests_failed + exit_code = obs.exit_code + code_compiles = obs.code_compiles + stderr = obs.stderr + stdout = obs.stdout + + total = passed + failed + + # Log execution details + 
print("JuliaEnv Execution Result:") + print(f" Reward: {reward:.3f}") + print(f" Tests Passed: {passed}") + print(f" Tests Failed: {failed}") + print(f" Total Tests: {total}") + print(f" Exit Code: {exit_code}") + print(f" Code Compiles: {code_compiles}") + + if stderr: + print(f" Stderr: {stderr[:500]}") + record_metric("reward/julia/has_errors", 1, Reduce.SUM) + + if stdout: + print(f" Stdout (first 200 chars): {stdout[:200]}") + + # Log metrics + pass_rate = passed / total if total > 0 else 0.0 + record_metric("reward/julia/pass_rate", pass_rate, Reduce.MEAN) + + print(f"Final Reward: {reward:.3f}") + print("=" * 80) + + return reward + + except Exception as e: + print(f"✗ Error evaluating response: {e} - Reward: 0.0") + print("=" * 80) + record_metric("reward/julia/evaluation_errors", 1, Reduce.SUM) + return 0.0 + + +def extract_julia_code(response: str) -> str: + """ + Extract Julia code from markdown code blocks. + + Args: + response: Model's response text + + Returns: + Extracted Julia code + """ + text = re.sub(r"^```julia\s*\n?", "", response, flags=re.IGNORECASE) + text = re.sub(r"\n?```\s*$", "", text) + return text.strip() + + +def validate_julia_syntax(code: str) -> tuple[bool, list[str]]: + """ + Validate Julia code for common Python-like syntax errors. 
+
+    Args:
+        code: Julia code string to validate
+
+    Returns:
+        Tuple of (is_valid, list of warning messages)
+    """
+    warnings = []
+
+    python_functions = {
+        r'\btolower\(': 'tolower() -> use lowercase()',
+        r'\bupper\(': 'upper() -> use uppercase()',
+        r'\brev\(': 'rev() -> use reverse()',
+        r'\bint\(': 'int() -> use parse(Int, x) or Int(x)',
+        r'\bstr\(': 'str() -> use string()',
+        r'\blen\(': 'len() -> use length()',
+        r'\bsubset\(': 'subset() -> use filter()',
+        r'\bappend\(': 'append() -> use push!()',
+        r'\bdict\(': 'dict() -> use Dict()',
+        r'\breversed\(': 'reversed() -> use reverse()',
+        r'\.append\(': '.append() -> use push!()',
+        r'\.lower\(': '.lower() -> use lowercase()',
+        r'\.upper\(': '.upper() -> use uppercase()',
+    }
+
+    for pattern, suggestion in python_functions.items():
+        if re.search(pattern, code):  # case-sensitive: Int() and Dict() are valid Julia
+            warnings.append(f"⚠ Found Python-like syntax: {suggestion}")
+
+    if re.search(r'\[\s*0\s*\]', code):
+        warnings.append("⚠ Found [0] indexing - Julia arrays are 1-indexed")
+
+    if 'function' in code and not re.search(r'\bend\b', code):
+        warnings.append("⚠ Function missing 'end' keyword")
+
+    is_valid = len(warnings) == 0
+    return is_valid, warnings
+
+
+def transform_julia_sample(sample: Dict[str, Any], tokenizer) -> Dict[str, Any] | None:
+    """
+    Transform raw dataset sample into training format.
+
+    Args:
+        sample: Raw dataset sample
+        tokenizer: HuggingFace tokenizer
+
+    Returns:
+        Transformed sample with 'request', 'target', 'task_id' or None if invalid
+    """
+    if not sample.get("julia_test") or not sample.get("first_test_case"):
+        if not hasattr(transform_julia_sample, "_warned"):
+            print(
+                f"WARNING: Sample rejected - missing 'julia_test' or 'first_test_case' field. 
Sample keys: {list(sample.keys())}" + ) + transform_julia_sample._warned = True + return None + + formatted_request = build_julia_prompt(sample, tokenizer) + + return { + "request": formatted_request, + "target": sample.get("julia_test", ""), + "task_id": sample.get("task_id", ""), + } diff --git a/apps/openenv/llama3_8b_coding.yaml b/apps/openenv/llama3_8b_coding.yaml new file mode 100644 index 000000000..9f4b2d180 --- /dev/null +++ b/apps/openenv/llama3_8b_coding.yaml @@ -0,0 +1,210 @@ +# Grouped Relative Policy Optimization (GRPO) for Python Code Generation +# Using GenericEnvClient (no environment-specific packages required) +# >>> python -m apps.openenv.main --config apps/openenv/llama3_8b_coding.yaml + +# Global configuration +group_size: 8 +batch_size: 2 +max_req_tokens: 1024 +max_res_tokens: 2048 +model: "/home/kaiwu/.cache/huggingface/hub/models--meta-llama--Meta-Llama-3.1-8B-Instruct/snapshots/0e9e39f249a16976918f6564b8830bc894c89659" +off_by_n: 1 # Allow episodes from previous 2 policy versions to prevent buffer starvation + +# GRPO/DAPO loss configuration +grpo: + loss_type: dapo # Options: "grpo" or "dapo" (default: dapo) + # Common parameters + clip_eps_low: 0.2 # Lower clipping bound (alias: clip_low) + clip_eps_high: 0.28 # Upper clipping bound (alias: clip_high) + agg_type: token_mean # Aggregation: "token_mean" (DAPO) or "fixed_horizon" (GRPO) + # GRPO-specific parameters + beta: 0.1 # KL penalty coefficient (only used when loss_type: grpo) + # DAPO-specific parameters + dual_clip_c: 3.0 # Dual-clip constant (only used when loss_type: dapo) + +# Episode dropout configuration (aligned with GRPO reference) +episode_dropout: + enable_variance_dropout: true # Drop batches with low reward variance + enable_truncation_dropout: true # Drop batches with truncated responses + variance_threshold: 0.001 # std threshold for variance dropout + +# Main loop configuration +rollout_threads: 2 +sample_timeout_s: 300 +evaluation_timeout_s: 60.0 + +# Task-specific 
configuration using GenericAction (no coding_env import needed) +task: + env_name: "coding" + # Use python_utils which uses GenericAction instead of CodeAction + build_action: !function apps.openenv.python_utils.build_python_action + evaluate_response: !function apps.openenv.python_utils.evaluate_python_response + transform_sample: !function apps.openenv.python_utils.transform_python_sample + +# Observability configuration +metric_logging: + wandb: + entity: "torchforge" + project: "kaiwu-openenv-grpo" + logging_mode: global_reduce + console: + logging_mode: global_reduce + log_per_rank: True + +# Dataset configuration +dataset: + # TODO: Replace with your dataset path + path: "/home/kaiwu/work/kaiwu/AceCoder/train/train_rl/OpenRLHF/scripts/data/acecode_89k/acecode_hard02.json" + revision: "main" + data_split: "train" + streaming: false + model: ${model} + +# OpenEnv configuration - only docker_image is required! +# GenericEnvClient handles everything else automatically +openenv_config: + docker_image: "coding-env:latest" # Just specify the image, no env_class needed + env_name: "coding" # Environment name for logging paths + container_timeout_s: 180.0 + container_memory_gb: 128 + port: 8000 + num_worker: 16 + request_timeout_s: 60.0 + # Environment-specific env vars + env_vars: + PYTHON_ADDITIONAL_IMPORTS: "sys,os,functools,typing,numpy,dataclasses,copy,heapq,enum,string,ast,json,struct,base64,csv,math,cmath,abc,contextlib,inspect,secrets,uuid,pathlib,io,threading,asyncio,concurrent.futures,urllib.parse,socket,random,itertools,collections,time,datetime,re,statistics" + +# Policy configuration +policy: + prefetch_weights_to_shm: false + engine_args: + model: ${model} + tensor_parallel_size: 1 + pipeline_parallel_size: 1 + enforce_eager: false + gpu_memory_utilization: 0.85 + max_model_len: 8192 + enable_chunked_prefill: true + max_num_batched_tokens: 8192 + max_num_seqs: 32 + sampling_params: + n: ${group_size} + max_tokens: ${max_res_tokens} + temperature: 1.0 + 
top_p: 1.0 + min_tokens: 10 + logprobs: 1 # Required for both GRPOLoss and DAPOLoss + truncate_prompt_tokens: null + include_stop_str_in_output: false + +# Trainer configuration +trainer: + model: + name: llama3 + flavor: 8B + hf_assets_path: ${model} + optimizer: + name: AdamW + lr: 1e-5 + eps: 1e-8 + weight_decay: 0.01 + lr_scheduler: + warmup_steps: 50 + training: + local_batch_size: ${multiply:${batch_size},${group_size}} + seq_len: ${sum:${max_req_tokens},${max_res_tokens}} + max_norm: 1.0 + steps: 1500 + dtype: bfloat16 + gc_freq: 1 + compile: + enable: false + parallelism: + data_parallel_replicate_degree: 1 + data_parallel_shard_degree: 1 + tensor_parallel_degree: 1 + pipeline_parallel_degree: 1 + context_parallel_degree: 1 + expert_parallel_degree: 1 + disable_loss_parallel: true + checkpoint: + enable: true + folder: "checkpoint_llama3_8b_coding" + initial_load_path: ${model} + initial_load_in_hf: true + last_save_in_hf: true + interval: 150 + async_mode: "async" + activation_checkpoint: + mode: selective + selective_ac_option: op + +# Replay buffer configuration +replay_buffer: + batch_size: ${trainer.training.local_batch_size} + max_policy_age: ${off_by_n} + dp_size: ${trainer.parallelism.data_parallel_shard_degree} + +# Reference model configuration +ref_model: + model: + name: llama3 + flavor: 8B + hf_assets_path: ${model} + training: + seq_len: ${trainer.training.seq_len} + dtype: bfloat16 + gc_freq: 1 + compile: + enable: false + parallelism: + data_parallel_replicate_degree: 1 + data_parallel_shard_degree: 1 + tensor_parallel_degree: 1 + pipeline_parallel_degree: 1 + context_parallel_degree: 1 + expert_parallel_degree: 1 + checkpoint: + enable: true + initial_load_path: ${model} + initial_load_in_hf: true + +# All resource allocations +services: + policy: + procs: ${policy.engine_args.tensor_parallel_size} + num_replicas: 2 + mesh_name: policy + with_gpus: true + ref_model: + procs: 1 + num_replicas: 1 + mesh_name: ref_model + with_gpus: true + 
reward_actor: + procs: 1 + num_replicas: 1 + mesh_name: reward_actor + with_gpus: false + +actors: + dataset: + procs: 1 + with_gpus: false + mesh_name: dataset + trainer: + procs: 1 + with_gpus: true + mesh_name: trainer + replay_buffer: + procs: 1 + with_gpus: false + mesh_name: replay_buffer + compute_advantages: + procs: 1 + with_gpus: false + mesh_name: compute_advantages + coding_env: + procs: 1 + with_gpus: false + mesh_name: coding_env diff --git a/apps/openenv/llama3_8b_julia.yaml b/apps/openenv/llama3_8b_julia.yaml new file mode 100644 index 000000000..bb96d8e28 --- /dev/null +++ b/apps/openenv/llama3_8b_julia.yaml @@ -0,0 +1,223 @@ +# Grouped Relative Policy Optimization (GRPO) for Julia Code Generation +# Using GenericEnvClient (no environment-specific packages required) +# >>> python -m apps.openenv.main --config apps/openenv/llama3_8b_julia.yaml + +# Global configuration +group_size: 8 +batch_size: 2 +max_req_tokens: 1024 +max_res_tokens: 1024 +model: "meta-llama/Meta-Llama-3.1-8B-Instruct" +off_by_n: 1 +compile: true # Enable torch.compile for trainer/ref_model, and CUDA graphs for vLLM + +# GRPO/DAPO loss configuration - using DAPO (recommended, no ref_model KL needed) +grpo: + loss_type: dapo # Options: "grpo" or "dapo" (default: dapo) + # Common parameters + agg_type: token_mean # Aggregation: "fixed_horizon" (GRPO) or "token_mean" (DAPO) + # DAPO-specific parameters + clip_low: 0.2 # Lower clip bound (ratio clamped to min 1 - clip_low = 0.8) + clip_high: 0.28 # Upper clip bound (ratio clamped to max 1 + clip_high = 1.28) + dual_clip_c: 3.0 # Dual-clip cap constant for negative advantages + +# Episode dropout configuration (aligned with GRPO reference) +episode_dropout: + enable_variance_dropout: true # Drop batches with low reward variance + enable_truncation_dropout: true # Drop batches with truncated responses + variance_threshold: 0.001 # std threshold for variance dropout + +# Main loop configuration +rollout_threads: 1 +sample_timeout_s: 15 
+# IMPORTANT: evaluation_timeout_s must be > JULIA_EXECUTION_TIMEOUT to allow Julia +# to kill stuck workers and recover. Julia timeout = 15s, so we wait 20s. +evaluation_timeout_s: 20.0 + +# Circuit breaker configuration - triggers container restart on repeated failures +circuit_breaker: + threshold: 5 # Number of timeouts before tripping + window_s: 60.0 # Time window for counting timeouts + cooldown_s: 60.0 # Wait time for container restart to complete + +# Task-specific configuration using GenericAction (no julia_env import needed) +task: + env_name: "julia" + # Use julia_utils which uses GenericAction instead of JuliaAction + build_action: !function apps.openenv.julia_utils.build_julia_action + evaluate_response: !function apps.openenv.julia_utils.evaluate_julia_response + transform_sample: !function apps.openenv.julia_utils.transform_julia_sample + +#Observability configuration +metric_logging: + wandb: + entity: "torchforge" + project: "kaiwu-openenv-grpo" + logging_mode: global_reduce + console: + logging_mode: global_reduce + log_per_rank: True + +# Dataset configuration +dataset: + # TODO: Replace with your dataset path + path: "/home/kaiwu/work/amd/amd-submission/julia_trainset.parquet" + revision: "main" + data_split: "train" + streaming: false + model: ${model} + +# OpenEnv configuration - only docker_image is required! 
+# GenericEnvClient handles everything else automatically +openenv_config: + docker_image: "julia-env:latest" # Just specify the image, no env_class needed + env_name: "julia" # Environment name for logging paths + container_timeout_s: 180.0 + container_memory_gb: 1024 + port: 8000 + num_env_actors: 2 # Multiple actors for circuit breaker isolation + num_containers: 2 # 1 container per actor (2 total) + num_connections: 12 # Connections per container (24 total across 2 actors) + request_timeout_s: 20.0 # Should be > JULIA_EXECUTION_TIMEOUT to allow recovery + env_vars: + JULIAUP_OFFLINE: "1" # Disable juliaup version checks (no internet on cluster) + JULIA_EXECUTION_TIMEOUT: "15" # Julia-side timeout (must be < request_timeout_s) + JULIA_MAX_WORKERS: "8" # Reduced to avoid resource exhaustion with multiple containers + +# Policy configuration +policy: + prefetch_weights_to_shm: false # GPU Direct RDMA is now fixed with proper timeout + engine_args: + model: ${model} + tensor_parallel_size: 1 + pipeline_parallel_size: 1 + enforce_eager: false + gpu_memory_utilization: 0.85 + max_model_len: 8192 + enable_chunked_prefill: true + max_num_batched_tokens: 8192 + max_num_seqs: 32 + sampling_params: + n: ${group_size} + max_tokens: ${max_res_tokens} + temperature: 1.0 + top_p: 1.0 + min_tokens: 10 + logprobs: 1 # Required for both GRPOLoss and DAPOLoss + truncate_prompt_tokens: null + include_stop_str_in_output: false + +# Trainer configuration +trainer: + model: + name: llama3 + flavor: 8B + hf_assets_path: hf://${model} + optimizer: + name: AdamW + lr: 5e-6 # Reduced from 1e-5 for training stability + eps: 1e-8 + weight_decay: 0.01 + lr_scheduler: + warmup_steps: 50 + training: + local_batch_size: ${multiply:${batch_size},${group_size}} + seq_len: ${sum:${max_req_tokens},${max_res_tokens}} + max_norm: 1.0 + steps: 1500 + dtype: bfloat16 + gc_freq: 1 + compile: + enable: ${compile} + parallelism: + data_parallel_replicate_degree: 1 + data_parallel_shard_degree: 1 + 
tensor_parallel_degree: 1 + pipeline_parallel_degree: 1 + context_parallel_degree: 1 + expert_parallel_degree: 1 + disable_loss_parallel: true + checkpoint: + enable: true + folder: "checkpoint_llama3_8b_julia" + initial_load_path: hf://${model} # The path to load the initial checkpoint from. Ignored if `folder` exists. + initial_load_in_hf: true + last_save_in_hf: true + interval: 150 + async_mode: "async" + activation_checkpoint: + mode: selective + selective_ac_option: op + +# Replay buffer configuration +replay_buffer: + batch_size: ${trainer.training.local_batch_size} + max_policy_age: ${off_by_n} + dp_size: ${trainer.parallelism.data_parallel_shard_degree} + +# Reference model configuration (only used when services.ref_model is enabled) +# Kept for documentation - uncomment services.ref_model to enable for GRPO +ref_model: + model: + name: llama3 + flavor: 8B + hf_assets_path: hf://${model} + training: + seq_len: ${trainer.training.seq_len} + dtype: bfloat16 + gc_freq: 1 + compile: + enable: false + parallelism: + data_parallel_replicate_degree: 1 + data_parallel_shard_degree: 1 + tensor_parallel_degree: 1 + pipeline_parallel_degree: 1 + context_parallel_degree: 1 + expert_parallel_degree: 1 + checkpoint: + enable: true + initial_load_path: hf://${model} + initial_load_in_hf: true + +# All resource allocations +services: + policy: + procs: ${policy.engine_args.tensor_parallel_size} + num_replicas: 1 + mesh_name: policy + with_gpus: true + # ref_model: Only needed for GRPO with beta > 0 (KL penalty) + # Commented out for DAPO to save a GPU + # ref_model: + # procs: 1 + # num_replicas: 1 + # mesh_name: ref_model + # with_gpus: true + reward_actor: + procs: 1 + num_replicas: 1 + mesh_name: reward_actor + with_gpus: false + +actors: + dataset: + procs: 1 + with_gpus: false + mesh_name: dataset + trainer: + procs: 1 + with_gpus: true + mesh_name: trainer + replay_buffer: + procs: 1 + with_gpus: false + mesh_name: replay_buffer + compute_advantages: + procs: 1 + 
with_gpus: false + mesh_name: compute_advantages + julia_env: + procs: 1 + with_gpus: false + mesh_name: julia_env diff --git a/apps/openenv/main.py b/apps/openenv/main.py new file mode 100644 index 000000000..3dd69151f --- /dev/null +++ b/apps/openenv/main.py @@ -0,0 +1,1425 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +OpenEnv GRPO Training Script using GenericEnvClient. + +This version uses GenericEnvClient and GenericAction to work with ANY +OpenEnv environment without requiring environment-specific packages. + +Usage: + python -m apps.openenv.main --config apps/openenv/llama3_8b_julia.yaml + python -m apps.openenv.main --config apps/openenv/llama3_8b_coding.yaml +""" + +from __future__ import annotations + +# CRITICAL: Set CUDA allocator config BEFORE any PyTorch imports +# This enables expandable segments which: +# 1. Reduces GPU memory fragmentation +# 2. Enables GPU Direct RDMA for faster weight updates (~4s vs ~10s) +# 3. 
Prevents OOM errors when storage volume uses GPU memory +import os +os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True") + +import asyncio +import importlib +import logging +import sys +import time +import uuid +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Callable, Dict, TYPE_CHECKING + +# Type-only imports to avoid runtime import of openenv (which pulls in fastmcp/docket +# and conflicts with monarch's OpenTelemetry meter provider) +if TYPE_CHECKING: + from openenv import GenericAction + from openenv.core.client_types import StepResult + +# CRITICAL: Add openenv directory to sys.path at module level +_appdir = Path(__file__).parent +if str(_appdir) not in sys.path: + sys.path.insert(0, str(_appdir)) + +import torch +import torch.nn.functional as F +import torchstore as ts +import yaml +from datasets import load_dataset +from forge.actors.generator import Generator +from forge.actors.openenv import OpenEnvActor +# find_available_port is used by OpenEnvActor internally +from forge.actors.reference_model import ReferenceModel +from forge.actors.replay_buffer import ReplayBuffer +from forge.actors.trainer import TitanTrainer +from forge.controller.actor import ForgeActor +from forge.controller.provisioner import init_provisioner, shutdown +from forge.data_models.completion import Completion +from forge.observability.metric_actors import get_or_create_metric_logger +from forge.observability.metrics import record_metric, Reduce +from forge.observability.perf_tracker import Tracer +from forge.rl.loss import GRPOLoss, DAPOLoss +from forge.types import LauncherConfig, ProvisionerConfig, TrainBatch +from forge.util.checkpoint import drop_weights +from forge.util.config import parse +from monarch.actor import endpoint +from omegaconf import DictConfig, ListConfig +from vllm.transformers_utils.tokenizer import get_tokenizer + + +# Set up module logger +logger = logging.getLogger(__name__) +log_level = 
os.getenv("LOG_LEVEL", "INFO").upper() +logging.basicConfig( + level=getattr(logging, log_level, logging.INFO), + format="[%(levelname)s %(name)s] %(message)s", +) + + +@dataclass +class Episode: + episode_id: str + pad_id: int + request_len: int + response_len: int + target: Any | None = None + completion: Completion | None = None + ref_logprobs: torch.Tensor | None = None + generator_logprobs: torch.Tensor | None = None # For GRPOLoss + loss_mask: torch.Tensor | None = None # For GRPOLoss + reward: float | None = None + advantage: float | None = None + + @property + def policy_version(self) -> int | None: + return self.completion.generator_version if self.completion else None + + @property + def stop_reason(self) -> str | None: + """Get stop reason from completion for truncation detection.""" + return self.completion.stop_reason if self.completion else None + + @property + def request_tensor(self) -> torch.Tensor: + tensor: torch.Tensor = self.completion.prompt_ids.to(torch.long) + if tensor.shape[0] > self.request_len: # truncate from left (keep end) + tensor = tensor[-self.request_len :] + elif tensor.shape[0] < self.request_len: # left pad + diff = self.request_len - tensor.shape[0] + tensor = F.pad(tensor, (diff, 0), value=self.pad_id) + return tensor + + @property + def response_tensor(self) -> torch.Tensor: + tensor: torch.Tensor = self.completion.token_ids.to(torch.long) + if tensor.shape[0] > self.response_len: # truncate from right (keep beginning) + tensor = tensor[: self.response_len] + elif tensor.shape[0] < self.response_len: # right pad + diff = self.response_len - tensor.shape[0] + tensor = F.pad(tensor, (0, diff), value=self.pad_id) + return tensor + + +Group = list[Episode] +Policy = Generator + + +def load_function_from_string(func_ref: str) -> Callable: + """Load a function from a string reference like 'module.function_name'.""" + openenv_dir = Path(__file__).parent + if str(openenv_dir) not in sys.path: + sys.path.insert(0, str(openenv_dir)) + 
+ module_name, func_name = func_ref.rsplit(".", 1) + module = importlib.import_module(module_name) + return getattr(module, func_name) + + +def function_constructor(loader, node): + """YAML constructor for !function tag.""" + value = loader.construct_scalar(node) + return ("!function", value) + + +yaml.add_constructor("!function", function_constructor, Loader=yaml.SafeLoader) + + +def collate( + batches: list[Group], +) -> list[TrainBatch]: + """Collates a list of batches into TrainBatch objects. + + Supports both GRPOLoss (requires generator_logprobs, loss_mask) and DAPOLoss. + """ + result = [] + for batch_idx, batch in enumerate(batches): + logger.debug(f"collate Processing batch {batch_idx}, len={len(batch)}") + + request = [e.request_tensor for e in batch] + request = torch.stack(request) + + response = [e.response_tensor for e in batch] + response = torch.stack(response) + + input_ids = torch.cat([request, response], dim=1) + seq_len = input_ids.shape[1] + + # ref_logprobs is optional - only stack if all episodes have it + ref_logprobs = None + if all(e.ref_logprobs is not None for e in batch): + ref_logprobs = torch.stack([e.ref_logprobs for e in batch]) + + advantages = [e.advantage for e in batch] + advantages = torch.tensor(advantages).unsqueeze(-1) # [b x 1] + advantages = advantages.expand(-1, seq_len) # [b x s] + + generator_logprobs = torch.stack([e.generator_logprobs for e in batch]) + loss_mask = torch.stack([e.loss_mask for e in batch]) + + loss_inputs = { + "generator_logprobs": generator_logprobs, + "loss_mask": loss_mask, + "advantages": advantages, + } + # Include ref_logprobs for GRPOLoss (uses it for KL penalty when beta > 0) + if ref_logprobs is not None: + loss_inputs["ref_logprobs"] = ref_logprobs + + result.append( + TrainBatch( + model_inputs={"tokens": input_ids}, + loss_inputs=loss_inputs, + ) + ) + return result + + +def make_loss(cfg: DictConfig): + """Factory function to create loss based on config. 
+ + Supports both GRPOLoss and DAPOLoss based on `loss_type` config. + + Args: + cfg: Configuration dict containing `grpo` section with: + - loss_type: "grpo" or "dapo" (default: "dapo") + - beta: KL penalty coefficient (for GRPOLoss only) + - clip_eps_low / clip_low: Lower clipping bound + - clip_eps_high / clip_high: Upper clipping bound + - agg_type: Aggregation type + - dual_clip_c: Dual-clip constant (for DAPOLoss only) + + Returns: + Loss function (GRPOLoss or DAPOLoss instance) + """ + grpo_cfg = cfg.get("grpo", {}) + loss_type = grpo_cfg.get("loss_type", "dapo").lower() + + # Support both naming conventions + clip_low = grpo_cfg.get("clip_eps_low", grpo_cfg.get("clip_low", 0.2)) + clip_high = grpo_cfg.get("clip_eps_high", grpo_cfg.get("clip_high", 0.28)) + + if loss_type == "grpo": + beta = grpo_cfg.get("beta", 0.1) + agg_type = grpo_cfg.get("agg_type", "fixed_horizon") + logger.info( + f"Using GRPOLoss with clip_low={clip_low}, clip_high={clip_high}, " + f"beta={beta}, agg_type={agg_type}" + ) + return GRPOLoss( + clip_low=clip_low, + clip_high=clip_high, + beta=beta, + agg_type=agg_type, + ) + elif loss_type == "dapo": + dual_clip_c = grpo_cfg.get("dual_clip_c", 3.0) + agg_type = grpo_cfg.get("agg_type", "token_mean") + logger.info( + f"Using DAPOLoss with clip_low={clip_low}, clip_high={clip_high}, " + f"dual_clip_c={dual_clip_c}, agg_type={agg_type}" + ) + return DAPOLoss( + clip_low=clip_low, + clip_high=clip_high, + dual_clip_c=dual_clip_c, + agg_type=agg_type, + ) + else: + raise ValueError( + f"Unknown loss_type: {loss_type}. Supported: 'grpo', 'dapo'" + ) + + +@dataclass +class GenericRewardActor(ForgeActor): + """Generic reward actor that uses GenericEnvClient and GenericAction. + + Supports multiple env_actors for parallel evaluation across different + WebSocket connections. Includes circuit breaker pattern to detect and + restart unhealthy containers. 
+ """ + + env_actors: list # List of OpenEnvActor instances + build_action_fn: Callable[[str, Dict[str, Any]], GenericAction] + evaluate_response_fn: Callable[[StepResult, str, Dict[str, Any]], float] + evaluation_timeout_s: float = 60.0 + + # Circuit breaker configuration + circuit_breaker_threshold: int = 10 # Timeouts before marking unhealthy + circuit_breaker_window_s: float = 60.0 # Time window for counting timeouts + circuit_breaker_cooldown_s: float = 30.0 # Cooldown before retrying unhealthy actor + + _request_counter: int = 0 # For round-robin distribution + + # Circuit breaker state (initialized in setup using field defaults for safety) + _actor_timeout_counts: list = field(default_factory=list) # Timeout count per actor + _actor_timeout_timestamps: list = field(default_factory=list) # Recent timeout timestamps per actor + _actor_healthy: list = field(default_factory=list) # Health status per actor + _actor_cooldown_until: list = field(default_factory=list) # Cooldown end time per actor + _restart_in_progress: list = field(default_factory=list) # Restart lock per actor + _restart_tasks: list = field(default_factory=list) # Track restart tasks for cleanup + + @endpoint + async def setup(self): + """Ensure the openenv directory is in sys.path for imports.""" + logger.debug("GenericRewardActor.setup Starting setup...") + openenv_dir = Path(__file__).parent + if str(openenv_dir) not in sys.path: + sys.path.insert(0, str(openenv_dir)) + + # Initialize circuit breaker state + num_actors = len(self.env_actors) + self._actor_timeout_counts = [0] * num_actors + self._actor_timeout_timestamps = [[] for _ in range(num_actors)] + self._actor_healthy = [True] * num_actors + self._actor_cooldown_until = [0.0] * num_actors + self._restart_in_progress = [False] * num_actors + + logger.debug( + f"GenericRewardActor.setup Timeout set to {self.evaluation_timeout_s}s" + ) + logger.debug(f"GenericRewardActor.setup Using {num_actors} env_actors for parallel evaluation") + 
        logger.debug(
            f"GenericRewardActor.setup Circuit breaker: threshold={self.circuit_breaker_threshold}, "
            f"window={self.circuit_breaker_window_s}s"
        )
        logger.debug("GenericRewardActor.setup Setup complete!")

    def _get_healthy_actor_idx(self) -> int:
        """Get the next healthy actor index using round-robin with health awareness.

        Returns:
            Index of a healthy actor, or the least-bad unhealthy actor if all are unhealthy
        """
        num_actors = len(self.env_actors)
        current_time = time.time()

        # Try to find a healthy actor using round-robin.
        # _request_counter grows without bound; the modulo keeps the index valid.
        for _ in range(num_actors):
            idx = self._request_counter % num_actors
            self._request_counter += 1

            if self._actor_healthy[idx]:
                return idx

            # Check if cooldown has expired for unhealthy actor
            if current_time >= self._actor_cooldown_until[idx]:
                # Cooldown expired, give it another chance: mark healthy again and
                # reset its timeout bookkeeping so stale failures don't re-trip it.
                logger.info(f"Circuit breaker: Actor {idx} cooldown expired, retrying")
                self._actor_healthy[idx] = True
                self._actor_timeout_counts[idx] = 0
                self._actor_timeout_timestamps[idx] = []
                return idx

        # All actors are unhealthy and in cooldown - use the one with earliest cooldown end
        earliest_idx = min(range(num_actors), key=lambda i: self._actor_cooldown_until[i])
        logger.warning(f"Circuit breaker: All actors unhealthy, using actor {earliest_idx}")
        return earliest_idx

    def _record_timeout(self, actor_idx: int):
        """Record a timeout for an actor and check if circuit should trip."""
        current_time = time.time()

        # Add timestamp to recent timeouts
        self._actor_timeout_timestamps[actor_idx].append(current_time)

        # Remove old timestamps outside the sliding window.
        # NOTE(review): the comprehension variable `ts` shadows the module-level
        # `torchstore as ts` alias; scoping is comprehension-local so behavior is
        # unaffected, but a different name would be clearer.
        window_start = current_time - self.circuit_breaker_window_s
        self._actor_timeout_timestamps[actor_idx] = [
            ts for ts in self._actor_timeout_timestamps[actor_idx]
            if ts >= window_start
        ]

        # Count recent timeouts
        recent_timeouts = len(self._actor_timeout_timestamps[actor_idx])
        self._actor_timeout_counts[actor_idx] = recent_timeouts

record_metric(f"circuit_breaker/actor_{actor_idx}/timeout_count", recent_timeouts, Reduce.MAX) + + # Check if circuit should trip + if recent_timeouts >= self.circuit_breaker_threshold and self._actor_healthy[actor_idx]: + logger.error( + f"Circuit breaker TRIPPED for actor {actor_idx}: " + f"{recent_timeouts} timeouts in {self.circuit_breaker_window_s}s" + ) + self._actor_healthy[actor_idx] = False + self._actor_cooldown_until[actor_idx] = current_time + self.circuit_breaker_cooldown_s + record_metric(f"circuit_breaker/actor_{actor_idx}/tripped", 1, Reduce.SUM) + + # Trigger async restart with error handling callback + def _restart_done_callback(task, idx=actor_idx): + try: + exc = task.exception() + if exc is not None: + logger.error(f"Circuit breaker: Restart task for actor {idx} failed: {exc}") + except asyncio.CancelledError: + pass # Task was cancelled, not an error + # Remove completed task from tracking list + if task in self._restart_tasks: + self._restart_tasks.remove(task) + + restart_task = asyncio.create_task(self._restart_actor(actor_idx)) + restart_task.add_done_callback(_restart_done_callback) + self._restart_tasks.append(restart_task) # Track for cleanup during shutdown + + def _record_success(self, actor_idx: int): + """Record a successful execution for an actor.""" + # Successful execution reduces timeout pressure + # We don't clear all timeouts, but the window will naturally expire them + + async def _restart_actor(self, actor_idx: int): + """Restart an unhealthy actor's container pool.""" + if self._restart_in_progress[actor_idx]: + logger.debug(f"Restart already in progress for actor {actor_idx}") + return + + self._restart_in_progress[actor_idx] = True + record_metric(f"circuit_breaker/actor_{actor_idx}/restart_attempt", 1, Reduce.SUM) + + try: + logger.warning(f"Circuit breaker: Initiating FULL POOL restart for actor {actor_idx}") + + env_actor = self.env_actors[actor_idx] + result = await env_actor.restart_container.call_one() + + if 
result.get("success"): + logger.info( + f"Circuit breaker: Actor {actor_idx} pool restarted successfully - " + f"{result.get('num_containers')} containers, {result.get('num_connections')} connections" + ) + self._actor_healthy[actor_idx] = True + self._actor_timeout_counts[actor_idx] = 0 + self._actor_timeout_timestamps[actor_idx] = [] + self._actor_cooldown_until[actor_idx] = 0.0 + record_metric(f"circuit_breaker/actor_{actor_idx}/restart_success", 1, Reduce.SUM) + else: + logger.error( + f"Circuit breaker: Actor {actor_idx} restart failed: {result.get('error')}" + ) + # Extend cooldown on failure + self._actor_cooldown_until[actor_idx] = time.time() + self.circuit_breaker_cooldown_s * 2 + record_metric(f"circuit_breaker/actor_{actor_idx}/restart_failure", 1, Reduce.SUM) + + except Exception as e: + logger.error(f"Circuit breaker: Exception during restart of actor {actor_idx}: {e}") + import traceback + traceback.print_exc() + self._actor_cooldown_until[actor_idx] = time.time() + self.circuit_breaker_cooldown_s * 2 + record_metric(f"circuit_breaker/actor_{actor_idx}/restart_failure", 1, Reduce.SUM) + + finally: + self._restart_in_progress[actor_idx] = False + + @endpoint + async def evaluate_response(self, prompt: str, response: str, target: Any) -> float: + """ + Evaluate response using task-specific functions with timeout protection. + + Uses health-aware round-robin distribution across env_actors with + circuit breaker pattern to detect and restart unhealthy containers. 
+ + Args: + prompt: The problem description + response: The model's generated response + target: The target/test data from dataset + + Returns: + Reward score (0.0 if timeout or error) + """ + # Initialize actor index for error logging (may be updated below) + env_actor_idx = -1 + + try: + # Build action using task-specific function (returns GenericAction) + sample = {"target": target} + action = self.build_action_fn(response, sample) + + # Get healthy actor using circuit breaker logic + env_actor_idx = self._get_healthy_actor_idx() + env_actor = self.env_actors[env_actor_idx] + + # Execute in environment with timeout protection + result = await asyncio.wait_for( + env_actor.execute.call_one( + dict(action) + ), # Convert to dict for serialization + timeout=self.evaluation_timeout_s, + ) + + # Record success + self._record_success(env_actor_idx) + + # Evaluate result using task-specific function + reward = self.evaluate_response_fn(result, response, sample) + + record_metric("reward/evaluate_response/sum_reward", reward, Reduce.SUM) + record_metric("reward/evaluate_response/avg_reward", reward, Reduce.MEAN) + record_metric("reward/evaluate_response/count_calls", 1, Reduce.SUM) + + return reward + + except asyncio.TimeoutError: + logger.warning( + f"Evaluation timeout after {self.evaluation_timeout_s}s on actor {env_actor_idx} " + f"- likely infinite loop in generated code" + ) + # Record timeout for circuit breaker + self._record_timeout(env_actor_idx) + record_metric("reward/evaluate_response/timeout_count", 1, Reduce.SUM) + return 0.0 + + except Exception as e: + logger.error(f"Evaluation error on actor {env_actor_idx}: {e}") + # Connection errors also count towards circuit breaker + # Only record timeout if we actually got an actor (env_actor_idx >= 0) + # to avoid recording against wrong actor when build_action_fn fails + if env_actor_idx >= 0 and ("connection" in str(e).lower() or "websocket" in str(e).lower()): + self._record_timeout(env_actor_idx) + 
record_metric("reward/evaluate_response/error_count", 1, Reduce.SUM) + return 0.0 + + @endpoint + async def cancel_restart_tasks(self) -> int: + """Cancel all pending restart tasks during shutdown. + + Returns: + Number of tasks that were cancelled + """ + cancelled_count = 0 + for task in self._restart_tasks: + if not task.done(): + task.cancel() + cancelled_count += 1 + # Wait for all tasks to complete cancellation + if self._restart_tasks: + await asyncio.gather(*self._restart_tasks, return_exceptions=True) + self._restart_tasks.clear() + return cancelled_count + + @endpoint + async def get_health_status(self) -> Dict[str, Any]: + """Get health status of all env_actors for monitoring.""" + return { + "actors": [ + { + "index": i, + "healthy": self._actor_healthy[i], + "timeout_count": self._actor_timeout_counts[i], + "cooldown_remaining": max(0, self._actor_cooldown_until[i] - time.time()), + "restart_in_progress": self._restart_in_progress[i], + } + for i in range(len(self.env_actors)) + ], + "healthy_count": sum(self._actor_healthy), + "total_count": len(self.env_actors), + } + + +@dataclass +class ComputeAdvantages(ForgeActor): + @endpoint + async def setup(self): + logger.debug("ComputeAdvantages.setup Setup complete!") + + @endpoint + async def compute(self, group: Group) -> list[float]: + rewards = torch.tensor([[e.reward for e in group]], dtype=torch.float32) + mean = rewards.mean(1, keepdim=True) + std = rewards.std(1, keepdim=True) + advantages = (rewards - mean) / (std + 1e-4) + return advantages.squeeze(0).tolist() + + +@dataclass +class GenericDatasetActor(ForgeActor): + """Generic dataset actor that uses task-specific transformation function.""" + + path: str + revision: str = "main" + data_split: str = "train" + streaming: bool = False + model: str = "meta-llama/Meta-Llama-3.1-8B-Instruct" + transform_sample_fn: Callable | None = None + + @endpoint + async def setup(self): + openenv_dir = Path(__file__).parent + if str(openenv_dir) not in sys.path: 
+ sys.path.insert(0, str(openenv_dir)) + + self._tokenizer = get_tokenizer(self.model) + + logger.info(f"Loading dataset from: {self.path}") + if os.path.isfile(self.path): + if self.path.endswith(".parquet"): + ds = load_dataset( + "parquet", + data_files={"train": self.path}, + split=self.data_split, + ) + elif self.path.endswith(".json"): + ds = load_dataset( + "json", + data_files={"train": self.path}, + split=self.data_split, + ) + else: + raise ValueError(f"Unsupported file format: {self.path}") + else: + ds = load_dataset( + self.path, + split=self.data_split, + streaming=self.streaming, + revision=self.revision, + ) + + if len(ds) == 0: + raise ValueError(f"Dataset is empty after loading from {self.path}.") + + logger.info(f"Dataset loaded, size: {len(ds)}") + + if self.transform_sample_fn: + def transform_wrapper(sample): + return self.transform_sample_fn(sample, self._tokenizer) + + original_size = len(ds) + ds = ds.filter(lambda x: transform_wrapper(x) is not None) + filtered_size = len(ds) + logger.info(f"Dataset filtered: {original_size} -> {filtered_size}") + + if filtered_size == 0: + raise ValueError( + f"Dataset transform filtered out ALL {original_size} samples!" + ) + + ds = ds.map(transform_wrapper) + + ds = ds.shuffle() + + if len(ds) == 0: + raise ValueError("Dataset is empty after all transformations!") + + self._dataset = ds # Keep reference for looping + self._iterator = iter(ds) + self._loop_count = 0 + logger.info(f"Dataset setup complete! 
Size: {len(ds)}") + + @endpoint + async def sample(self) -> dict[str, str] | None: + try: + sample = next(self._iterator) + record_metric("dataset/sample/count_samples_generated", 1, Reduce.SUM) + if "request" in sample: + record_metric( + "dataset/sample/avg_sample_len", + len(sample["request"]), + Reduce.MEAN, + ) + return sample + except StopIteration: + # Loop the dataset instead of returning None + self._loop_count += 1 + logger.info(f"[DATASET] Completed epoch {self._loop_count}, reshuffling and restarting...") + record_metric("dataset/epoch_completed", 1, Reduce.SUM) + self._dataset = self._dataset.shuffle() + self._iterator = iter(self._dataset) + # Return first sample from new epoch + sample = next(self._iterator) + record_metric("dataset/sample/count_samples_generated", 1, Reduce.SUM) + return sample + + @endpoint + async def pad_token(self): + if self._tokenizer.pad_token_id is not None: + return self._tokenizer.pad_token_id + return self._tokenizer.eos_token_id + + +async def main(cfg: DictConfig): + """Main GRPO training loop using GenericEnvClient.""" + group_size = cfg.group_size + max_req_tokens = cfg.max_req_tokens + max_res_tokens = cfg.max_res_tokens + + # Load task-specific functions + logger.debug("main Loading task-specific functions...") + task_config = cfg.task + + build_action_fn = None + evaluate_response_fn = None + transform_sample_fn = None + + if ( + isinstance(task_config.build_action, (tuple, list, ListConfig)) + and len(task_config.build_action) == 2 + and task_config.build_action[0] == "!function" + ): + build_action_fn = load_function_from_string(task_config.build_action[1]) + + if ( + isinstance(task_config.evaluate_response, (tuple, list, ListConfig)) + and len(task_config.evaluate_response) == 2 + and task_config.evaluate_response[0] == "!function" + ): + evaluate_response_fn = load_function_from_string( + task_config.evaluate_response[1] + ) + + if hasattr(task_config, "transform_sample"): + if ( + 
isinstance(task_config.transform_sample, (tuple, list, ListConfig)) + and len(task_config.transform_sample) == 2 + and task_config.transform_sample[0] == "!function" + ): + transform_sample_fn = load_function_from_string( + task_config.transform_sample[1] + ) + + logger.debug("main All task-specific functions loaded successfully") + + # Global setups + provisioner = None + if cfg.get("provisioner", None) is not None: + provisioner = await init_provisioner( + ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner)) + ) + else: + provisioner = await init_provisioner() + + metric_logging_cfg = cfg.get("metric_logging", {}) + mlogger = await get_or_create_metric_logger(process_name="Controller") + await mlogger.init_backends.call_one(metric_logging_cfg) + + # ---- Setup loss function ---- # + loss_fn = make_loss(cfg) + + # Fail-fast: Check loss/ref_model compatibility before spawning actors + uses_ref_model = cfg.get("services", {}).get("ref_model") is not None + if uses_ref_model and not isinstance(loss_fn, GRPOLoss): + logger.warning( + f"ref_model is configured but {type(loss_fn).__name__} does not use ref_logprobs. " + "Consider removing the ref_model service config to save GPU resources." + ) + if isinstance(loss_fn, GRPOLoss) and loss_fn.beta > 0 and not uses_ref_model: + raise ValueError( + f"GRPOLoss with beta={loss_fn.beta} requires ref_logprobs, but ref_model is not configured. " + "Either add ref_model to services config or set beta=0." 
+ ) + + # Setup OpenEnvActor - works with ANY OpenEnv Docker image + openenv_config = cfg.get("openenv_config", {}) + docker_image = openenv_config.get("docker_image") + env_vars = openenv_config.get("env_vars", {}) + container_timeout_s = openenv_config.get("container_timeout_s", 180.0) + request_timeout_s = openenv_config.get("request_timeout_s", 120.0) + container_memory_gb = openenv_config.get("container_memory_gb", 4) + + # Set environment variables from config + if "PORT" not in env_vars: + env_vars["PORT"] = str(openenv_config.get("port", 8000)) + if "NUM_WORKER" not in env_vars: + env_vars["NUM_WORKER"] = str(openenv_config.get("num_worker", 4)) + + # Get env_name for actor mesh naming and logging paths + env_name = openenv_config.get("env_name", task_config.get("env_name", "generic")) + + logger.debug( + f"main Initializing OpenEnvActor with image={docker_image}..." + ) + + # Smart container allocation: Create one actor per concurrent evaluation needed + # Each actor manages its own container(s) with connection pooling + + num_env_actors = openenv_config.get("num_env_actors", cfg.get("group_size", 8)) + num_containers_per_actor = openenv_config.get("num_containers", 1) + num_connections_per_container = openenv_config.get("num_connections", 1) + + logger.info( + f"Creating {num_env_actors} env_actors, each with {num_containers_per_actor} containers " + f"and {num_connections_per_container} connections per container" + ) + + # Create env_actors + env_actors = [] + base_port = openenv_config.get("port", 8000) + + for i in range(num_env_actors): + actor_env_vars = env_vars.copy() + # Each actor starts from a different port range to avoid conflicts + actor_port = base_port - (i * num_containers_per_actor * 2) + + logger.debug( + f"Creating env_actor {i + 1}/{num_env_actors} starting at port {actor_port}" + ) + + env_actor = await OpenEnvActor.options( + **cfg.actors.get(f"{env_name}_env", cfg.actors.get("env", {})) + ).as_actor( + docker_image=docker_image, + 
env_name=env_name, + env_vars=actor_env_vars, + container_timeout_s=container_timeout_s, + request_timeout_s=request_timeout_s, + container_memory_gb=container_memory_gb, + port=actor_port, + num_containers=num_containers_per_actor, + num_connections=num_connections_per_container, + ) + env_actors.append(env_actor) + + total_containers = num_env_actors * num_containers_per_actor + logger.info( + f"All {num_env_actors} env_actors initialized successfully " + f"({total_containers} total containers)" + ) + + # Create all other actors + async def noop(): + return None + + ( + dataloader, + policy, + trainer, + replay_buffer, + compute_advantages, + ref_model, + reward_actor, + ) = await asyncio.gather( + GenericDatasetActor.options(**cfg.actors.dataset).as_actor( + path=cfg.dataset.path, + revision=cfg.dataset.get("revision", "main"), + data_split=cfg.dataset.get("data_split", "train"), + streaming=cfg.dataset.get("streaming", False), + model=cfg.model, + transform_sample_fn=transform_sample_fn, + ), + Policy.options(**cfg.services.policy).as_service(**cfg.policy), + TitanTrainer.options(**cfg.actors.trainer).as_actor( + **cfg.trainer, + loss=loss_fn, + ), + ReplayBuffer.options(**cfg.actors.replay_buffer).as_actor( + **cfg.replay_buffer, collate=collate + ), + ComputeAdvantages.options(**cfg.actors.compute_advantages).as_actor(), + ( + ReferenceModel.options(**cfg.services.ref_model).as_service(**cfg.ref_model) + if uses_ref_model + else noop() + ), + GenericRewardActor.options(**cfg.services.reward_actor).as_service( + env_actors=env_actors, + build_action_fn=build_action_fn, + evaluate_response_fn=evaluate_response_fn, + evaluation_timeout_s=cfg.get("evaluation_timeout_s", 60.0), + # Circuit breaker configuration + circuit_breaker_threshold=cfg.get("circuit_breaker", {}).get("threshold", 10), + circuit_breaker_window_s=cfg.get("circuit_breaker", {}).get("window_s", 60.0), + circuit_breaker_cooldown_s=cfg.get("circuit_breaker", {}).get("cooldown_s", 30.0), + ), + ) + 
logger.debug("main asyncio.gather completed successfully!") + + max_steps = cfg.trainer.training.steps or -1 + + print("All services initialized successfully!") + shutdown_event = asyncio.Event() + + # Initialize torchstore + trainer_num_procs = cfg.actors.trainer["procs"] + trainer_host_mesh_name = cfg.actors.trainer["mesh_name"] + trainer_hosts = await provisioner.get_host_mesh(trainer_host_mesh_name) + await ts.initialize( + mesh=trainer_hosts.spawn_procs(per_host={"procs": trainer_num_procs}), + strategy=ts.LocalRankStrategy(), + ) + print("Torchstore successfully initialized with local rank strategy") + + # Episode dropout configuration + dropout_cfg = cfg.get("episode_dropout", {}) + enable_variance_dropout = dropout_cfg.get("enable_variance_dropout", True) + enable_truncation_dropout = dropout_cfg.get("enable_truncation_dropout", True) + variance_threshold = dropout_cfg.get("variance_threshold", 1e-3) + + # Core RL loops + async def continuous_rollouts(): + try: + rollout_count = 0 + consecutive_errors = 0 + max_consecutive_errors = int(os.environ.get("FORGE_MAX_ROLLOUT_ERRORS", "50")) + rollout_timeout_s = float(os.environ.get("FORGE_ROLLOUT_TIMEOUT_S", "300")) + + pad_id = await dataloader.pad_token.call_one() + + # Rollout-side backpressure settings + # Only produce new episodes when buffer needs them (prevents sample waste) + batch_size = cfg.batch_size + episodes_per_step = batch_size * group_size + # Buffer target: enough for N training steps (configurable via env var) + buffer_target_steps = int(os.environ.get("FORGE_BUFFER_TARGET_STEPS", "4")) + max_buffer_episodes = episodes_per_step * buffer_target_steps + backpressure_check_interval = float(os.environ.get("FORGE_BACKPRESSURE_CHECK_INTERVAL", "0.5")) + + while not shutdown_event.is_set(): + try: + t = Tracer("main_perf/continuous_rollouts") + t.start() + + # ROLLOUT BACKPRESSURE: Check if buffer needs more episodes + # This prevents overproduction and sample waste + try: + buffer_size = await 
replay_buffer._numel.call_one() + if buffer_size >= max_buffer_episodes: + # Buffer is full enough, wait before producing more + record_metric("rollout/backpressure/paused", 1, Reduce.SUM) + record_metric("rollout/backpressure/buffer_size", buffer_size, Reduce.MAX) + await asyncio.sleep(backpressure_check_interval) + t.stop() # Don't count this as a rollout iteration + continue + except Exception as e: + # If buffer check fails, continue with rollout + logger.debug(f"Buffer size check failed: {e}") + + t.step("backpressure_check") + + # Timeout on data loading + try: + sample = await asyncio.wait_for( + dataloader.sample.call_one(), + timeout=30.0, + ) + except asyncio.TimeoutError: + logger.warning("[ROLLOUT] Timeout waiting for dataloader sample") + record_metric("main/continuous_rollouts/dataloader_timeout", 1, Reduce.SUM) + continue + + if sample is None: + print("Dataloader is empty, exiting continuous rollout") + return + + t.step("data_loading") + + prompt, target = sample["request"], sample["target"] + + # Timeout on policy generation + try: + responses: list[Completion] = await asyncio.wait_for( + policy.generate.route(prompt), + timeout=rollout_timeout_s, + ) + except asyncio.TimeoutError as timeout_err: + logger.error( + f"[ROLLOUT] Timeout after {rollout_timeout_s}s waiting for policy.generate(). " + f"Generator may be stuck during weight update." + ) + record_metric("main/continuous_rollouts/generation_timeout", 1, Reduce.SUM) + consecutive_errors += 1 + if consecutive_errors >= max_consecutive_errors: + raise RuntimeError( + f"[ROLLOUT FAILURE] {consecutive_errors} consecutive rollout errors. " + f"Generator appears to be unresponsive." 
+ ) from timeout_err + continue + + t.step("policy_generation") + + episodes = [] + input_ids = torch.ones( + (group_size, max_req_tokens + max_res_tokens), + dtype=torch.long, + ) + seq_len = max_req_tokens + max_res_tokens + + # Track evaluation errors for circuit breaker + eval_errors_this_batch = 0 + + # Create episodes first + for i, response in enumerate(responses): + # Both GRPOLoss and DAPOLoss need generator_logprobs and loss_mask + # Validate logprobs exist + if response.logprobs is None: + raise ValueError( + "Completion.logprobs is None. " + "Ensure Generator returns logprobs by setting 'logprobs: 1' in sampling_params config." + ) + + # Prepare generator_logprobs (shifted for next-token prediction) + actual_response_len = response.token_ids.shape[0] + generator_logprobs = torch.zeros(seq_len, dtype=response.logprobs.dtype) + generator_logprobs[ + max_req_tokens : max_req_tokens + actual_response_len + ] = response.logprobs + generator_logprobs = torch.roll(generator_logprobs, shifts=-1, dims=0) + generator_logprobs[-1] = 0.0 + + # Prepare loss_mask + response_mask = torch.zeros(seq_len, dtype=torch.float32) + response_mask[max_req_tokens : max_req_tokens + actual_response_len] = 1.0 + loss_mask = torch.roll(response_mask, shifts=-1, dims=0) + loss_mask[-1] = 0.0 + + episode = Episode( + episode_id=str(uuid.uuid4()), + pad_id=pad_id, + request_len=max_req_tokens, + response_len=max_res_tokens, + target=target, + completion=response, + generator_logprobs=generator_logprobs, + loss_mask=loss_mask, + ) + episodes.append(episode) + + # Parallel reward evaluation using asyncio.gather + async def evaluate_single( + idx, episode, response, *, _prompt=prompt, _target=target + ): + try: + reward = await reward_actor.evaluate_response.route( + prompt=_prompt, response=response.text, target=_target + ) + return idx, reward, None + except Exception as eval_exc: + return idx, 0.0, eval_exc + + eval_tasks = [ + evaluate_single(i, ep, resp) + for i, (ep, resp) in 
enumerate( + zip(episodes, responses, strict=True) + ) + ] + eval_results = await asyncio.gather(*eval_tasks) + + # Process results + for idx, reward, error in eval_results: + episodes[idx].reward = reward + if error is not None: + logger.warning(f"[ROLLOUT] Reward evaluation failed: {error}") + eval_errors_this_batch += 1 + record_metric("main/continuous_rollouts/eval_error", 1, Reduce.SUM) + + # Build input_ids after rewards are assigned + for i, episode in enumerate(episodes): + input_ids[i, :max_req_tokens] = episode.request_tensor + input_ids[i, max_req_tokens:] = episode.response_tensor + + t.step("reward_evaluation") + + # Episode dropout logic (aligned with GRPO reference implementation) + # Drop entire batch if: + # 1. Reward variance is too low (including all 0s and all 1s) + # 2. Any response was truncated (didn't end with EOS) + rewards = [e.reward for e in episodes] + rewards_std = torch.std(torch.tensor(rewards)) + is_low_variance = rewards_std < variance_threshold + + # DAPO/GRPO aggressive truncation dropout: Drop entire batch if ANY + # response was truncated (stop_reason == "length"). This is intentional + # per DAPO paper recommendations - truncated responses provide incomplete + # signal and can hurt training. The dropout is batch-level rather than + # per-episode to maintain advantage computation correctness within groups. 
+ num_truncated = sum( + 1 for e in episodes if e.stop_reason == "length" + ) + is_truncated = num_truncated > 0 + + # Record dropout metrics + n = len(episodes) + if enable_variance_dropout: + record_metric( + "main/continuous_rollouts/episodes_dropped/low_variance", + n if is_low_variance else 0, + Reduce.SUM, + ) + + if enable_truncation_dropout: + record_metric( + "main/continuous_rollouts/episodes_dropped/truncated", + num_truncated, + Reduce.SUM, + ) + + # Determine if we should drop this batch + should_drop = ( + (enable_variance_dropout and is_low_variance) or + (enable_truncation_dropout and is_truncated) + ) + + record_metric( + "main/continuous_rollouts/episodes_dropped/total", + n if should_drop else 0, + Reduce.SUM, + ) + + if should_drop: + if is_low_variance: + logger.debug( + f"[DROPOUT] Dropping batch: low reward variance " + f"(std={rewards_std:.4f} < {variance_threshold})" + ) + if is_truncated: + logger.debug( + f"[DROPOUT] Dropping batch: {num_truncated}/{n} episodes truncated" + ) + del input_ids, episodes + continue + + # Circuit breaker: if ALL evaluations failed, something is wrong + if eval_errors_this_batch == len(responses): + consecutive_errors += 1 + logger.warning( + f"[CIRCUIT BREAKER] All {len(responses)} evaluations failed. " + f"Consecutive error batches: {consecutive_errors}/{max_consecutive_errors}" + ) + if consecutive_errors >= max_consecutive_errors: + raise RuntimeError( + f"[ROLLOUT FAILURE] {consecutive_errors} consecutive batches with all evaluations failing. " + f"Environment actor appears to be unresponsive. Check container health." 
+ ) + else: + # Reset error counter on partial success + consecutive_errors = 0 + + # Compute ref_logprobs only if ref_model is configured + if ref_model is not None: + try: + ref_logprobs = await asyncio.wait_for( + ref_model.forward.route(input_ids, return_logprobs=True), + timeout=60.0, + ) + except asyncio.TimeoutError: + logger.error("[ROLLOUT] Timeout waiting for ref_model.forward()") + record_metric("main/continuous_rollouts/ref_model_timeout", 1, Reduce.SUM) + continue + + t.step("reference_model_calculate_logprobs") + + if not isinstance(ref_logprobs, torch.Tensor): + raise TypeError( + f"ref_model.forward.route() returned {type(ref_logprobs)} instead of torch.Tensor" + ) + + for i, episode in enumerate(episodes): + episode.ref_logprobs = ref_logprobs[i] + del ref_logprobs + + del input_ids + + advantages = await compute_advantages.compute.call_one(episodes) + for episode, advantage in zip(episodes, advantages, strict=True): + episode.advantage = advantage + await replay_buffer.add.call_one(episode) + + # Track token-based metrics (aligned with GRPO) + prompt_tokens = episode.completion.prompt_ids.shape[0] + response_tokens = episode.completion.token_ids.shape[0] + + record_metric("episode/avg_prompt_tokens", prompt_tokens, Reduce.MEAN) + record_metric("episode/max_prompt_tokens", prompt_tokens, Reduce.MAX) + record_metric("episode/min_prompt_tokens", prompt_tokens, Reduce.MIN) + record_metric("episode/avg_response_tokens", response_tokens, Reduce.MEAN) + record_metric("episode/max_response_tokens", response_tokens, Reduce.MAX) + record_metric("episode/min_response_tokens", response_tokens, Reduce.MIN) + record_metric("episode/avg_reward", episode.reward, Reduce.MEAN) + + rollout_count += 1 + record_metric( + "main/continuous_rollouts/count_rollout_iterations", 1, Reduce.SUM + ) + t.stop() + + except Exception as rollout_err: + # Catch any unexpected errors in rollout loop to prevent thread crash + logger.error(f"[ROLLOUT] Unexpected error in rollout loop: 
{rollout_err}") + record_metric("main/continuous_rollouts/unexpected_error", 1, Reduce.SUM) + consecutive_errors += 1 + if consecutive_errors >= max_consecutive_errors: + raise RuntimeError( + f"[ROLLOUT FAILURE] {consecutive_errors} consecutive errors in rollout loop. " + f"Last error: {rollout_err}" + ) from rollout_err + # Brief pause before retry + await asyncio.sleep(1.0) + except Exception as e: + import traceback + logger.error(f"[ROLLOUT FATAL] continuous_rollouts() crashed with error: {e}") + logger.error(f"[ROLLOUT FATAL] Traceback:\n{traceback.format_exc()}") + raise + + async def continuous_training(): + training_step = 0 + restart_tracer = True + consecutive_empty_samples = 0 + # Configurable via environment variable for advanced tuning + max_empty_samples_before_error = int( + os.environ.get("FORGE_MAX_EMPTY_BUFFER_WAIT_S", "120") + ) * 10 # Convert seconds to 0.1s intervals + + while max_steps == -1 or training_step < max_steps: + if restart_tracer: + t = Tracer("main_perf/continuous_training") + t.start() + restart_tracer = False + + batch = await replay_buffer.sample.call_one( + curr_policy_version=training_step + ) + if batch is None: + consecutive_empty_samples += 1 + + # Log warning at increasing intervals + if consecutive_empty_samples == 10: # 1 second + logger.warning( + f"[BUFFER STARVATION] Buffer empty for 1s at step {training_step}. " + f"Rollouts may be blocked during weight update." + ) + elif consecutive_empty_samples == 100: # 10 seconds + logger.warning( + f"[BUFFER STARVATION] Buffer empty for 10s at step {training_step}. " + f"Consider increasing max_policy_age or rollout_threads." + ) + elif consecutive_empty_samples == 300: # 30 seconds + logger.error( + f"[BUFFER STARVATION] Buffer empty for 30s at step {training_step}. " + f"This indicates a likely deadlock. Check generator weight updates." 
+ ) + + # Fail after max wait to prevent infinite hangs + if consecutive_empty_samples >= max_empty_samples_before_error: + raise RuntimeError( + f"[BUFFER STARVATION DEADLOCK] Replay buffer has been empty for " + f"{consecutive_empty_samples * 0.1:.1f} seconds at training step {training_step}. " + f"This typically indicates that:\n" + f" 1. All policy replicas are blocked during weight updates\n" + f" 2. max_policy_age ({cfg.get('off_by_n', 1)}) is too aggressive\n" + f" 3. rollout_threads ({num_rollout_threads}) is insufficient\n" + f"Solutions:\n" + f" - Increase 'off_by_n' in config (recommended: 2-3)\n" + f" - Increase 'rollout_threads' in config\n" + f" - Increase policy service 'num_replicas'\n" + f" - Set FORGE_MAX_EMPTY_BUFFER_WAIT_S env var to increase timeout" + ) + + logger.debug("Running out of batch, now waiting") + await asyncio.sleep(0.1) + else: + # Reset starvation counter on successful sample + consecutive_empty_samples = 0 + t.step("waiting_for_buffer") + + await trainer.train_step.call(batch) + training_step += 1 + t.step("train_step") + + # Push and update weights every step + await trainer.push_weights.call(training_step) + t.step("push_weights") + + # Backpressure: Check buffer health for NEXT policy version. + # Weight updates block all rollouts, so we need enough buffer headroom + # to survive the blocking period without starving. + # CRITICAL: Check training_step + 1 because after weight update, + # episodes from current version will be evicted! 
+ buffer_health = await replay_buffer.health_check.call_one( + curr_policy_version=training_step + 1 # Check NEXT version survivability + ) + required_surviving = buffer_health["required"] * 2 # Need 2x batch for safety margin + surviving = buffer_health["surviving_after_eviction"] + + if surviving < required_surviving: + backpressure_start = time.time() + max_backpressure_wait = float(os.environ.get("FORGE_BACKPRESSURE_TIMEOUT_S", "30")) + logger.warning( + f"[BACKPRESSURE] Buffer low before weight update at step {training_step}. " + f"surviving={surviving}, required={required_surviving}. " + f"Waiting up to {max_backpressure_wait}s for more episodes." + ) + record_metric("backpressure/triggered", 1, Reduce.SUM) + + # Wait with exponential backoff + wait_interval = 0.5 + while (time.time() - backpressure_start) < max_backpressure_wait: + await asyncio.sleep(wait_interval) + wait_interval = min(wait_interval * 1.5, 5.0) # Cap at 5s intervals + + buffer_health = await replay_buffer.health_check.call_one( + curr_policy_version=training_step + 1 + ) + if buffer_health["surviving_after_eviction"] >= required_surviving: + wait_duration = time.time() - backpressure_start + logger.info(f"[BACKPRESSURE] Buffer recovered after {wait_duration:.1f}s") + record_metric("backpressure/wait_duration_s", wait_duration, Reduce.MEAN) + break + else: + wait_duration = time.time() - backpressure_start + logger.warning( + f"[BACKPRESSURE] Buffer still low after {wait_duration:.1f}s. " + f"Proceeding with weight update to prevent complete stall." 
+ ) + record_metric("backpressure/timeout", 1, Reduce.SUM) + t.step("backpressure_check") + + # Track weight update duration for monitoring + weight_update_start = time.time() + await policy.update_weights.fanout(training_step) + weight_update_duration = time.time() - weight_update_start + record_metric("training/weight_update_duration_s", weight_update_duration, Reduce.MEAN) + if weight_update_duration > 20.0: + logger.warning( + f"[SLOW WEIGHT UPDATE] Step {training_step} took {weight_update_duration:.1f}s. " + f"Consider increasing off_by_n or policy replicas." + ) + record_metric("training/slow_weight_update_count", 1, Reduce.SUM) + t.step("update_weights") + + if training_step >= 2: + await drop_weights(training_step - 1) + t.step("drop_weights") + + t.stop() + restart_tracer = True + + await mlogger.flush.call_one(training_step) + + # Periodic health monitoring every 10 steps + if training_step % 10 == 0: + health_buffer = await replay_buffer.health_check.call_one( + curr_policy_version=training_step + ) + record_metric("health/buffer_size", health_buffer["size"], Reduce.MAX) + record_metric("health/buffer_surviving", health_buffer["surviving_after_eviction"], Reduce.MAX) + record_metric("health/buffer_freshness_ratio", health_buffer["freshness_ratio"], Reduce.MEAN) + record_metric("health/buffer_required", health_buffer["required"], Reduce.MAX) + + # Log reward actor health + try: + reward_health = await reward_actor.get_health_status.route() + record_metric("health/env_actors_healthy", reward_health["healthy_count"], Reduce.MAX) + record_metric("health/env_actors_total", reward_health["total_count"], Reduce.MAX) + except Exception as health_err: + logger.debug(f"Could not get reward actor health: {health_err}") + + # Log training progress + record_metric("training/step", training_step, Reduce.MAX) + progress_pct = 100.0 * training_step / max_steps if max_steps > 0 else 0 + record_metric("training/progress_pct", progress_pct, Reduce.MAX) + + logger.info( + 
f"[HEALTH] Step {training_step}/{max_steps} ({progress_pct:.1f}%) | " + f"Buffer: {health_buffer['size']} ({health_buffer['surviving_after_eviction']} surviving) | " + f"Freshness: {health_buffer['freshness_ratio']:.2f}" + ) + + print( + f"Reached training limit ({max_steps} steps). Exiting continuous_training loop." + ) + + num_rollout_threads = cfg.get("rollout_threads", 1) + print(f"Starting OpenEnv GRPO with {num_rollout_threads} rollout threads") + + # Callback to immediately report rollout task failures + def rollout_task_done_callback(task): + try: + exc = task.exception() + if exc is not None: + import traceback + logger.error(f"[ROLLOUT TASK FAILED] Rollout task crashed: {exc}") + tb_str = "".join( + traceback.format_exception(type(exc), exc, exc.__traceback__) + ) + logger.error(f"[ROLLOUT TASK FAILED] Traceback:\n{tb_str}") + except asyncio.CancelledError: + pass # Task was cancelled, not an error + + # Start rollout tasks first + rollout_tasks = [] + for _ in range(num_rollout_threads): + task = asyncio.create_task(continuous_rollouts()) + task.add_done_callback(rollout_task_done_callback) + rollout_tasks.append(task) + + # Start training immediately (no warmup) + training_task = asyncio.create_task(continuous_training()) + + try: + await training_task + except KeyboardInterrupt: + print("Training interrupted by user") + except Exception as e: + import traceback + + print(f"Training failed with error: {e}") + traceback.print_exc() + raise + finally: + print("Shutting down...") + shutdown_event.set() + + try: + await asyncio.wait_for( + asyncio.gather(*rollout_tasks, return_exceptions=True), + timeout=5, + ) + except asyncio.TimeoutError: + for t in rollout_tasks: + t.cancel() + await asyncio.gather(*rollout_tasks, return_exceptions=True) + + training_task.cancel() + + # Cancel any pending circuit breaker restart tasks + try: + cancelled = await reward_actor.cancel_restart_tasks.route() + if cancelled > 0: + print(f"Cancelled {cancelled} pending circuit 
breaker restart tasks") + except Exception as cancel_err: + print(f"Warning: Error cancelling restart tasks: {cancel_err}") + + print(f"Cleaning up {len(env_actors)} environment Docker containers...") + for i, env_actor in enumerate(env_actors): + try: + await env_actor.teardown.call_one() + print(f"Environment Docker container {i + 1}/{len(env_actors)} stopped successfully") + except Exception as teardown_error: + print(f"Warning: Error during environment teardown {i + 1}: {teardown_error}") + + await shutdown() + + +if __name__ == "__main__": + + @parse + def _main(cfg): + os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1" + os.environ["NCCL_TIMEOUT_MS"] = "60000" + os.environ["MONARCH_HOSTMESH_V1"] = "1" + os.environ["TORCHSTORE_RDMA_ENABLED"] = "1" + asyncio.run(main(cfg)) + + _main() diff --git a/apps/openenv/python_utils.py b/apps/openenv/python_utils.py new file mode 100644 index 000000000..b59233065 --- /dev/null +++ b/apps/openenv/python_utils.py @@ -0,0 +1,277 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Python coding task-specific utilities using GenericAction. + +This version uses OpenEnv's GenericAction (a simple dict wrapper) instead of +the environment-specific CodeAction class. This means you don't need to +install the coding_env package locally. 
+ +Usage: + # In your YAML config: + task: + env_name: "coding" + build_action: !function apps.openenv.python_utils.build_python_action + evaluate_response: !function apps.openenv.python_utils.evaluate_python_response + transform_sample: !function apps.openenv.python_utils.transform_python_sample +""" + +import re +from typing import Any, Dict + +from openenv import GenericAction + +from forge.observability.metrics import record_metric, Reduce + + +def get_python_system_prompt() -> str: + """Get system prompt for Python coding tasks.""" + return """You are an expert Python programmer. + +Write a Python function that correctly solves the problem described below. + +Rules: +- The code must be syntactically correct and runnable +- Use proper Python conventions and best practices +- Include necessary imports +- Do not include test code in your response +- Return only the Python code + +FORMAT YOUR RESPONSE AS: + +```python +def function_name(args): + # implementation + return result +``` +""".strip() + + +def build_python_prompt(sample: Dict[str, Any], tokenizer) -> str: + """ + Build prompt for Python code generation. + + Args: + sample: Dataset sample with 'prompt' field (e.g., from HumanEval) + tokenizer: HuggingFace tokenizer for chat template + + Returns: + Formatted prompt string ready for model generation + """ + system_prompt = get_python_system_prompt() + + # Support both HumanEval and AceCode formats + request = sample.get("prompt") or sample.get("question", "") + + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": request}, + ] + + formatted_request = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + + return formatted_request + + +def build_python_action(response: str, sample: Dict[str, Any]) -> GenericAction: + """ + Build GenericAction from model response and dataset sample. 
+ + This uses GenericAction (a simple dict wrapper) instead of CodeAction, + so you don't need to install coding_env locally. + + The coding environment only accepts GenericAction(code=...), so we combine + the model's generated code with the test code into a single code string. + + Args: + response: Model's generated response + sample: Dataset sample with test information + + Returns: + GenericAction instance with combined code (model code + test code) + """ + # Extract code from markdown if present + model_code = extract_python_code(response) + + # Get test code if available + test_code = sample.get("target", "") + + # Combine model code and test code into a single executable script + # The test code typically contains assertions or function calls that test the model's code + if test_code: + combined_code = f"{model_code}\n\n# Test code\n{test_code}" + else: + combined_code = model_code + + # GenericAction only accepts 'code' field (maps to CodeAction) + return GenericAction(code=combined_code) + + +def evaluate_python_response(result, response: str, sample: Dict[str, Any]) -> float: + """ + Evaluate Python code execution result and return reward. + + Since the coding environment executes combined code (model code + test code), + we determine success based on the execution output: + - exit_code == 0 means all tests passed -> reward = 1.0 + - exit_code != 0 means tests failed or code error -> reward = 0.0 + + Works with both typed observations (CodeObservation) and raw dicts + returned by GenericEnvClient. 
+ + Args: + result: StepResult from environment execution + response: Model's response (for logging) + sample: Dataset sample (for logging) + + Returns: + Reward score: 1.0 if all tests pass (exit_code == 0), 0.0 otherwise + """ + try: + print("=" * 80) + print("RAW RESPONSE FROM MODEL:") + print("-" * 80) + print(response) + print("-" * 80) + + # Extract code for validation + code = extract_python_code(response) + + if not code: + print("No Python code extracted - Reward: 0.0") + print("=" * 80) + record_metric("reward/python/no_code_extracted", 1, Reduce.SUM) + return 0.0 + + print("EXTRACTED PYTHON CODE:") + print("-" * 80) + print(code) + print("-" * 80) + + # Handle both typed observation and dict observation (from GenericEnvClient) + obs = result.observation + if isinstance(obs, dict): + # GenericEnvClient returns dicts + exit_code = obs.get("exit_code", -1) + stderr = obs.get("stderr", "") + stdout = obs.get("stdout", "") + else: + # Typed observation (CodeObservation) + exit_code = getattr(obs, "exit_code", -1) + stderr = getattr(obs, "stderr", "") + stdout = getattr(obs, "stdout", "") + + # Log execution details + print("CodingEnv Execution Result:") + print(f" Exit Code: {exit_code}") + + if stdout: + print(" Stdout (first 500 chars):") + print("-" * 40) + print(stdout[:500]) + print("-" * 40) + + if stderr: + print(" Stderr (first 500 chars):") + print("-" * 40) + print(stderr[:500]) + print("-" * 40) + + # Compute reward based on exit code + # exit_code == 0 means the combined code (model code + tests) ran successfully + # This indicates all assertions passed + if exit_code == 0: + reward = 1.0 + record_metric("reward/python/tests_passed", 1, Reduce.SUM) + else: + reward = 0.0 + record_metric("reward/python/tests_failed", 1, Reduce.SUM) + if "AssertionError" in stderr: + record_metric("reward/python/assertion_errors", 1, Reduce.SUM) + elif "SyntaxError" in stderr: + record_metric("reward/python/syntax_errors", 1, Reduce.SUM) + elif "Error" in stderr or 
"error" in stderr: + record_metric("reward/python/other_errors", 1, Reduce.SUM) + + record_metric("reward/python/reward", reward, Reduce.MEAN) + + print(f"Final Reward: {reward:.3f}") + print("=" * 80) + + return reward + + except Exception as e: + print(f"✗ Error evaluating response: {e} - Reward: 0.0") + print("=" * 80) + record_metric("reward/python/evaluation_errors", 1, Reduce.SUM) + return 0.0 + + +def extract_python_code(response: str) -> str: + """ + Extract Python code from markdown code blocks. + + Args: + response: Model's response text + + Returns: + Extracted Python code + """ + # Try to find ```python code block + pattern = r"```python\n(.*?)```" + match = re.search(pattern, response, re.DOTALL) + if match: + return match.group(1).strip() + + # Try generic code block + pattern = r"```\n(.*?)```" + match = re.search(pattern, response, re.DOTALL) + if match: + return match.group(1).strip() + + # No markdown block, return as-is + return response.strip() + + +def transform_python_sample(sample: Dict[str, Any], tokenizer) -> Dict[str, Any] | None: + """ + Transform raw dataset sample into training format. + + Args: + sample: Raw dataset sample (e.g., from HumanEval) + tokenizer: HuggingFace tokenizer + + Returns: + Transformed sample with 'request', 'target', 'task_id' or None if invalid + """ + # Validate required fields - support both HumanEval and AceCode formats + prompt_text = sample.get("prompt") or sample.get("question") + if not prompt_text: + # Debug: log why sample was rejected (only for first few) + if not hasattr(transform_python_sample, '_warned'): + print(f"WARNING: Sample rejected - missing 'prompt' or 'question' field. 
Sample keys: {list(sample.keys())}") + transform_python_sample._warned = True + return None + + # Build prompt + formatted_request = build_python_prompt(sample, tokenizer) + + # Get test code - support both formats + test_code = sample.get("test") or sample.get("test_cases", "") + if isinstance(test_code, list): + # AceCode format: list of test cases + test_code = "\n".join(test_code) + + return { + "request": formatted_request, + "target": test_code, # Test code for reward function + "task_id": sample.get("task_id") or sample.get("id", ""), + } diff --git a/src/forge/actors/openenv.py b/src/forge/actors/openenv.py new file mode 100644 index 000000000..d4160b66f --- /dev/null +++ b/src/forge/actors/openenv.py @@ -0,0 +1,321 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +OpenEnv Actor for sandboxed code execution. + +This actor works with ANY OpenEnv environment using only raw dictionaries, +without requiring environment-specific packages (like julia_env or coding_env). 
"""
OpenEnv Actor for sandboxed code execution.

This actor works with ANY OpenEnv environment using only raw dictionaries,
without requiring environment-specific packages (like julia_env or coding_env).

Usage:
    from openenv import GenericEnvClient, GenericAction

    # Create actor for any environment - just specify the Docker image
    actor = OpenEnvActor(
        docker_image="julia-env:latest",
        env_name="julia",
    )
    await actor.setup()

    # Execute with GenericAction (just a dict wrapper)
    action = GenericAction(core_code="println('hello')", test_code="@test true")
    result = await actor.execute(action)  # Returns StepResult with dict observation

    await actor.teardown()
"""

import asyncio
import logging
from typing import Any, Dict, Optional, TYPE_CHECKING

from monarch.actor import endpoint

if TYPE_CHECKING:
    from openenv import GenericEnvClient
    from openenv.core.client_types import StepResult

from forge.controller import ForgeActor
from forge.observability.metrics import record_metric, Reduce
from forge.actors.openenv_utils import (
    ContainerConfig,
    ContainerManager,
    ConnectionPool,
    is_connection_error,
)

logger = logging.getLogger(__name__)
# NOTE(review): forcing DEBUG on a library logger overrides application
# logging config; kept for backward compatibility.
logger.setLevel(logging.DEBUG)


class OpenEnvActor(ForgeActor):
    """A generic sandboxed execution environment using GenericEnvClient.

    This actor manages WebSocket connections to Docker containers,
    with connection pooling for high concurrency. Sync client calls are run
    on the pool's thread-pool executor so the event loop is never blocked.

    Args:
        docker_image: Docker image name (e.g., "julia-env:latest")
        env_name: Environment name for logging (e.g., "julia", "python")
        env_vars: Environment variables to pass to containers
        container_timeout_s: Timeout for container startup
        request_timeout_s: Timeout for individual requests
        port: Starting port for containers
        container_memory_gb: Memory limit per container in GB
        enable_zombie_cleanup: Whether to enable zombie process cleanup
        num_connections: Total WebSocket connections to create
        num_containers: Number of Docker containers

    Usage:
        >>> actor = OpenEnvActor(
        ...     docker_image="julia-env:latest",
        ...     env_name="julia",
        ...     num_connections=16,
        ...     num_containers=2,
        ... )
        >>> await actor.setup()
        >>> action = {"core_code": "...", "test_code": "..."}
        >>> result = await actor.execute(action)
        >>> await actor.teardown()
    """

    def __init__(
        self,
        docker_image: str,
        env_name: str = "openenv",
        env_vars: Optional[Dict[str, str]] = None,
        container_timeout_s: float = 180.0,
        request_timeout_s: float = 120.0,
        port: int = 8000,
        container_memory_gb: int = 4,
        enable_zombie_cleanup: bool = False,
        num_connections: int = 1,
        num_containers: int = 1,
    ):
        self.num_connections = num_connections
        self.num_containers = num_containers
        self.request_timeout_s = request_timeout_s
        self.enable_zombie_cleanup = enable_zombie_cleanup

        # Container management
        self._container_config = ContainerConfig(
            docker_image=docker_image,
            env_name=env_name,
            env_vars=env_vars or {},
            port=port,
            memory_gb=container_memory_gb,
            timeout_s=container_timeout_s,
        )
        self._container_manager = ContainerManager(self._container_config)

        # Connection pool (lazily populated in setup())
        self._pool = ConnectionPool(request_timeout_s=request_timeout_s)

        # Backward compatibility: first pool client / base port, set in setup()
        self.client = None
        self.actual_port = None

    @endpoint
    async def setup(self):
        """Initialize containers and create sync connection pool with thread pool executor."""
        logger.info(
            f"Setting up: {self.num_connections} sync connections "
            f"across {self.num_containers} containers (with thread pool)"
        )

        # Create containers
        container_urls = self._container_manager.create_containers(self.num_containers)

        # Initialize thread pool and create sync connections
        await self._pool.initialize(num_connections=self.num_connections)
        self._pool.create_connections(container_urls, self.num_connections)

        # Backward compatibility
        if self._pool.clients:
            self.client = self._pool.clients[0]
            self.actual_port = self._container_config.port

    @endpoint
    async def recreate(self):
        """Resets the environment to a clean state.

        Note: only resets the backward-compat first client, matching prior
        behavior; other pooled connections are untouched.
        """
        if not self.client:
            raise RuntimeError("Client not initialized. Call setup() first.")
        # get_running_loop() is the correct call inside a coroutine
        # (get_event_loop() here is deprecated since Python 3.10).
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(self._pool._executor, self.client.reset)

    @endpoint
    async def execute(self, action: Dict[str, Any]) -> "StepResult[Dict[str, Any]]":
        """Execute an action using an available connection from the pool.

        Args:
            action: Dictionary action with keys like core_code, test_code.

        Returns:
            StepResult with observation dict.

        Raises:
            RuntimeError: If the pool was never initialized, or the action
                keeps failing with connection errors after retries.
        """
        if not self._pool.clients:
            raise RuntimeError("Connection pool not initialized. Call setup() first.")

        client_idx, client = await self._pool.acquire(timeout=self.request_timeout_s)

        try:
            return await self._execute_with_retry(client_idx, client, action)
        finally:
            # Always return the connection, even on failure.
            await self._pool.release(client_idx)

    async def _execute_with_retry(
        self, client_idx: int, client: "GenericEnvClient", action: Dict[str, Any]
    ) -> "StepResult[Dict[str, Any]]":
        """Execute action with retry logic for connection errors.

        Uses thread pool to run sync WebSocket calls without blocking event loop.
        Connection errors trigger a reconnect and retry (up to max_retries);
        non-connection errors propagate immediately.
        """
        max_retries = 3

        for attempt in range(max_retries):
            try:
                # Execute in thread pool - doesn't block event loop
                result = await self._pool.execute_step(client_idx, action)
                record_metric("pool/execute_success", 1, Reduce.SUM)
                return result

            except Exception as e:
                is_conn_error, error_type = is_connection_error(str(e))

                if is_conn_error and attempt < max_retries - 1:
                    record_metric(f"pool/{error_type}_error_count", 1, Reduce.SUM)
                    logger.error(f"{error_type} error on client {client_idx}: {e}")

                    try:
                        client = await self._pool.reconnect(
                            client_idx, self._container_manager.container_urls
                        )
                        record_metric("pool/reconnect_success", 1, Reduce.SUM)
                        continue
                    except Exception as reconnect_error:
                        # Fall through to the raise below: if we cannot
                        # reconnect, further retries are pointless.
                        logger.error(f"Reconnect failed: {reconnect_error}")
                        record_metric("pool/reconnect_failure", 1, Reduce.SUM)

                if is_conn_error:
                    raise RuntimeError(
                        f"Client {client_idx} failed after {max_retries} attempts: {e}"
                    ) from e
                # Non-connection errors are not retried.
                raise

        raise RuntimeError("Execution failed after all retry attempts")

    @endpoint
    async def health_check(self) -> Dict[str, Any]:
        """Check if the environment is healthy.

        Returns:
            Dict with overall 'healthy' flag (true if at least one client
            responds), counts, and per-client status.
        """
        if not self._pool.clients:
            return {"healthy": False, "error": "Pool not initialized"}

        healthy_count = 0
        pool_status = []
        loop = asyncio.get_running_loop()

        for i, client in enumerate(self._pool.clients):
            try:
                # Run sync state() in thread pool
                await loop.run_in_executor(self._pool._executor, client.state)
                pool_status.append({"index": i, "healthy": True})
                healthy_count += 1
            except Exception as e:
                pool_status.append({"index": i, "healthy": False, "error": str(e)})

        return {
            "healthy": healthy_count > 0,
            "healthy_count": healthy_count,
            "total_clients": len(self._pool.clients),
            "pool_status": pool_status,
        }

    @endpoint
    async def get_pool_status(self) -> Dict[str, Any]:
        """Get connection pool status, augmented with container info."""
        status = self._pool.get_status()
        status["num_containers"] = len(self._container_manager.container_urls)
        status["container_urls"] = self._container_manager.container_urls
        return status

    @endpoint
    async def get_state(self) -> Dict[str, Any]:
        """Get current environment state (queried via the first client)."""
        if not self._pool.clients:
            raise RuntimeError("Pool not initialized. Call setup() first.")

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            self._pool._executor,
            self._pool.clients[0].state,
        )

    @endpoint
    async def teardown(self):
        """Clean up all connections and containers."""
        logger.debug("Tearing down...")
        await self._pool.close_all()
        self._container_manager.stop_all()
        self.client = None
        logger.debug("Teardown complete.")

    @endpoint
    async def restart_container(self) -> Dict[str, Any]:
        """Restart all containers and reconnect the pool.

        Returns:
            Dict with 'success' flag and, on success, the new container and
            connection counts; on failure, the error string. Never raises.
        """
        logger.warning("Restarting all containers...")

        try:
            # Cleanup existing
            await self._pool.close_all()
            self._container_manager.stop_all()
            # Brief pause to let Docker release ports before recreating.
            await asyncio.sleep(2)

            # Recreate
            container_urls = self._container_manager.create_containers(self.num_containers)
            await self._pool.initialize(num_connections=self.num_connections)
            self._pool.create_connections(container_urls, self.num_connections)

            if self._pool.clients:
                self.client = self._pool.clients[0]

            logger.info(f"Restart complete: {len(self._pool.clients)} sync connections")
            return {
                "success": True,
                "num_containers": len(container_urls),
                "num_connections": len(self._pool.clients),
            }

        except Exception as e:
            logger.error(f"Restart failed: {e}")
            return {"success": False, "error": str(e)}

    def create_action(self, **kwargs) -> "GenericAction":
        """Create a GenericAction instance (deferred import of openenv)."""
        from openenv import GenericAction
        return GenericAction(**kwargs)

    # Expose internal state for backward compatibility
    @property
    def clients(self) -> list:
        return self._pool.clients
return self._pool.client_available + + @property + def container_urls(self) -> list: + return self._container_manager.container_urls + + @property + def providers(self) -> list: + return self._container_manager.providers diff --git a/src/forge/actors/openenv_utils.py b/src/forge/actors/openenv_utils.py new file mode 100644 index 000000000..9beb7e9bd --- /dev/null +++ b/src/forge/actors/openenv_utils.py @@ -0,0 +1,475 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Shared utilities for OpenEnv actors. + +This module contains common utility functions used by OpenEnvActor. +""" + +import logging +import socket + +logger = logging.getLogger(__name__) + + +# Error keywords for connection issue detection +WEBSOCKET_ERROR_KEYWORDS = [ + "connectionclosederror", + "keepalive ping timeout", + "websocket", + "connection closed", + "connection reset", + "broken pipe", +] + +CONTAINER_ERROR_KEYWORDS = [ + "no such container", + "container not found", + "container is not running", + "container has stopped", + "container exited", + "exec session", + "state improper", + "oci runtime error", + "docker daemon", + "cannot connect to docker", + "connection refused", +] + +HTTP_ERROR_KEYWORDS = [ + "connection timeout", + "read timeout", + "http error", + "status code", +] + + +def is_connection_error(error_msg: str) -> tuple: + """Check if error is a connection-related error. + + Args: + error_msg: The error message to check. + + Returns: + Tuple of (is_error, error_type) where error_type is 'websocket', 'container', or None. 
+ """ + error_lower = error_msg.lower() + if any(kw in error_lower for kw in WEBSOCKET_ERROR_KEYWORDS): + return True, "websocket" + if any(kw in error_lower for kw in CONTAINER_ERROR_KEYWORDS): + return True, "container" + return False, None + + +def is_http_error(error_msg: str) -> bool: + """Check if error is an HTTP-level error (not requiring container recreation).""" + error_lower = error_msg.lower() + return any(kw in error_lower for kw in HTTP_ERROR_KEYWORDS) + + +def is_port_in_use(port: int, host: str = "127.0.0.1") -> bool: + """ + Check if a port is already in use on the specified host. + + Args: + port: Port number to check + host: Host address to check (default: 127.0.0.1) + + Returns: + True if port is in use, False if available + """ + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + try: + s.bind((host, port)) + return False + except OSError: + return True + + +def find_available_port( + preferred_port: int, min_port: int = 5000, max_attempts: int = 100 +) -> int: + """ + Find an available port starting from preferred_port and decrementing. + + Args: + preferred_port: The preferred port to use + min_port: Minimum port number to try (default: 5000) + max_attempts: Maximum number of ports to try (default: 100) + + Returns: + An available port number + + Raises: + RuntimeError: If no available port is found after max_attempts + """ + port = preferred_port + attempts = 0 + + while attempts < max_attempts: + if port < min_port: + raise RuntimeError( + f"No available port found after trying {attempts} ports. " + f"Reached minimum port {min_port}." + ) + + if not is_port_in_use(port): + logger.info(f"Found available port: {port}") + return port + + logger.debug(f"Port {port} is in use, trying {port - 1}") + port -= 1 + attempts += 1 + + raise RuntimeError( + f"No available port found after trying {max_attempts} ports " + f"(from {preferred_port} down to {port})." 
+ ) + + +class ContainerConfig: + """Configuration for container setup.""" + + def __init__( + self, + docker_image: str, + env_name: str = "openenv", + env_vars: dict = None, + port: int = 8000, + memory_gb: int = 4, + timeout_s: float = 180.0, + ): + self.docker_image = docker_image + self.env_name = env_name + self.env_vars = env_vars or {} + self.port = port + self.memory_gb = memory_gb + self.timeout_s = timeout_s + + +class ContainerManager: + """Manages Docker container lifecycle for OpenEnv environments.""" + + def __init__(self, config: ContainerConfig): + self.config = config + self.providers = [] + self.container_urls = [] + self._logs_dir = None + + def _setup_logs_dir(self): + """Create the logs directory for container output.""" + import os + self._logs_dir = os.path.expanduser(f"~/{self.config.env_name}_container_logs") + os.makedirs(self._logs_dir, exist_ok=True) + return self._logs_dir + + def _build_container_env_vars(self, container_port: int) -> dict: + """Build environment variables for a container.""" + env_vars = self.config.env_vars.copy() + env_vars["PORT"] = str(container_port) + + env_name_upper = self.config.env_name.upper() + log_filename = f"{self.config.env_name}_env_port_{container_port}.log" + container_log_path = f"/tmp/{self.config.env_name}_logs/{log_filename}" + env_vars[f"{env_name_upper}_LOG_FILE"] = container_log_path + env_vars[f"{env_name_upper}_LOG_LEVEL"] = "DEBUG" + + return env_vars + + def _build_volumes(self) -> dict: + """Build volume mappings for containers.""" + if not self._logs_dir: + self._setup_logs_dir() + return {self._logs_dir: f"/tmp/{self.config.env_name}_logs"} + + def create_containers(self, num_containers: int): + """Create and start Docker containers. + + Args: + num_containers: Number of containers to create. + + Returns: + List of container base URLs. 
+ """ + from openenv.core.containers.runtime import LocalDockerProvider + + self._setup_logs_dir() + self.providers = [] + self.container_urls = [] + + for i in range(num_containers): + try: + container_port = find_available_port(self.config.port - i) + logger.info(f"Creating container {i + 1}/{num_containers} on port {container_port}") + + env_vars = self._build_container_env_vars(container_port) + volumes = self._build_volumes() + + provider = LocalDockerProvider() + base_url = provider.start_container( + self.config.docker_image, + port=container_port, + env_vars=env_vars, + volumes=volumes, + memory_gb=self.config.memory_gb, + ) + + provider.wait_for_ready(base_url, timeout_s=self.config.timeout_s) + + self.providers.append(provider) + self.container_urls.append(base_url) + + logger.info(f"Container {i + 1}/{num_containers} ready at {base_url}") + + except Exception as e: + logger.error(f"Failed to create container {i + 1}: {e}") + self.stop_all() + raise + + return self.container_urls + + def stop_all(self): + """Stop all managed containers.""" + import subprocess + + for i, provider in enumerate(self.providers): + try: + if hasattr(provider, 'container_id') and provider.container_id: + # Use docker kill directly for faster shutdown + # docker stop can hang on stuck processes + try: + subprocess.run( + ['docker', 'kill', provider.container_id], + timeout=5, + capture_output=True + ) + except subprocess.TimeoutExpired: + logger.warning(f"docker kill timed out for container {i}") + else: + provider.stop_container() + logger.debug(f"Stopped container {i}") + except Exception as e: + logger.warning(f"Error stopping container {i}: {e}") + + self.providers = [] + self.container_urls = [] + + +class ConnectionPool: + """Manages a pool of sync WebSocket connections to OpenEnv containers. + + Uses thread pool execution to avoid blocking the asyncio event loop + while maintaining simple sync WebSocket clients. 
+ """ + + def __init__(self, request_timeout_s: float = 120.0, max_workers: int | None = None): + self.request_timeout_s = request_timeout_s + self._max_workers = max_workers # If None, will be derived from num_connections + self.clients = [] + self.client_available = [] + self._lock = None + self._condition = None + self._executor = None + + async def initialize(self, num_connections: int = 16): + """Initialize async primitives and thread pool. + + Args: + num_connections: Number of connections that will be created. + Used to derive thread pool size if max_workers not set. + """ + import asyncio + from concurrent.futures import ThreadPoolExecutor + + self._lock = asyncio.Lock() + self._condition = asyncio.Condition(self._lock) + + # Derive thread pool size: 2 threads per connection, capped at 64 + # This allows for concurrent execute + health check per connection + max_workers = self._max_workers or min(64, max(4, num_connections * 2)) + logger.info(f"ConnectionPool: Creating thread pool with {max_workers} workers for {num_connections} connections") + self._executor = ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix="ws_pool") + + def create_connections(self, container_urls: list, num_connections: int): + """Create sync WebSocket connections distributed across containers. + + Args: + container_urls: List of container base URLs. + num_connections: Total number of connections to create. 
+ """ + from openenv import GenericEnvClient + + self.clients = [] + self.client_available = [] + num_containers = len(container_urls) + + for i in range(num_connections): + try: + container_idx = i % num_containers + base_url = container_urls[container_idx] + + logger.debug(f"Creating sync connection {i + 1}/{num_connections} → container {container_idx}") + + client = GenericEnvClient( + base_url=base_url, + connect_timeout_s=10.0, + message_timeout_s=self.request_timeout_s, + ) + client.reset() + + self.clients.append(client) + self.client_available.append(True) + + except Exception as e: + logger.error(f"Failed to create connection {i + 1}: {e}") + self.close_all_sync() + raise + + logger.info(f"Connection pool ready: {len(self.clients)} sync connections") + + async def acquire(self, timeout: float = 30.0) -> tuple: + """Acquire an available client from the pool. + + Args: + timeout: Maximum wait time in seconds. + + Returns: + Tuple of (client_index, client). + + Raises: + TimeoutError: If no client available within timeout. + """ + import asyncio + + start_time = asyncio.get_event_loop().time() + + async with self._condition: + while True: + for i, available in enumerate(self.client_available): + if available: + self.client_available[i] = False + logger.debug(f"Acquired client {i} from pool") + return i, self.clients[i] + + elapsed = asyncio.get_event_loop().time() - start_time + remaining = timeout - elapsed + if remaining <= 0: + raise TimeoutError( + f"No client available after {timeout}s. " + f"All {len(self.clients)} clients busy." + ) + + try: + await asyncio.wait_for(self._condition.wait(), timeout=remaining) + except asyncio.TimeoutError as timeout_err: + raise TimeoutError( + f"No client available after {timeout}s. " + f"All {len(self.clients)} clients busy." 
+ ) from timeout_err + + async def release(self, client_idx: int): + """Release a client back to the pool.""" + async with self._condition: + self.client_available[client_idx] = True + logger.debug(f"Released client {client_idx}") + self._condition.notify() + + async def reconnect(self, client_idx: int, container_urls: list) -> "GenericEnvClient": + """Reconnect a failed client. + + Args: + client_idx: Index of client to reconnect. + container_urls: List of container URLs. + + Returns: + New client instance. + """ + from openenv import GenericEnvClient + import asyncio + + num_containers = len(container_urls) + container_idx = client_idx % num_containers + base_url = container_urls[container_idx] + + logger.info(f"Reconnecting sync client {client_idx} to {base_url}") + + # Close old client in thread pool + old_client = self.clients[client_idx] + if old_client: + try: + loop = asyncio.get_event_loop() + await loop.run_in_executor(self._executor, old_client.close) + except Exception as e: + logger.debug(f"Error closing old client: {e}") + + await asyncio.sleep(1) + + # Create new client in thread pool + def create_client(): + client = GenericEnvClient( + base_url=base_url, + connect_timeout_s=10.0, + message_timeout_s=self.request_timeout_s, + ) + client.reset() + return client + + loop = asyncio.get_event_loop() + new_client = await loop.run_in_executor(self._executor, create_client) + + self.clients[client_idx] = new_client + logger.info(f"Sync client {client_idx} reconnected") + return new_client + + async def execute_step(self, client_idx: int, action: dict): + """Execute step on client using thread pool to avoid blocking event loop. + + Args: + client_idx: Index of client to use. + action: Action dictionary to execute. + + Returns: + StepResult from the client. 
+ """ + import asyncio + + client = self.clients[client_idx] + loop = asyncio.get_event_loop() + + # Run sync WebSocket call in thread pool - doesn't block event loop + return await loop.run_in_executor( + self._executor, + client.step, + action + ) + + async def close_all(self): + """Close all connections and shutdown thread pool.""" + self.close_all_sync() + if self._executor: + self._executor.shutdown(wait=False) + self._executor = None + + def close_all_sync(self): + """Close all connections synchronously.""" + for i, client in enumerate(self.clients): + try: + client.close() + logger.debug(f"Closed sync client {i}") + except Exception as e: + logger.warning(f"Error closing client {i}: {e}") + + self.clients = [] + self.client_available = [] + + def get_status(self) -> dict: + """Get pool status.""" + return { + "total": len(self.clients), + "available": sum(1 for a in self.client_available if a), + "busy": sum(1 for a in self.client_available if not a), + } diff --git a/src/forge/actors/reference_model.py b/src/forge/actors/reference_model.py index 29ac06651..38347e28f 100644 --- a/src/forge/actors/reference_model.py +++ b/src/forge/actors/reference_model.py @@ -145,6 +145,7 @@ async def forward( t = Tracer("reference_perf/forward", timer="gpu", track_memory=True) t.start() + self.engine.gc_handler.run(self.step) model_parts = self.engine.model_parts diff --git a/src/forge/actors/replay_buffer.py b/src/forge/actors/replay_buffer.py index edec45b57..db0ed1a21 100644 --- a/src/forge/actors/replay_buffer.py +++ b/src/forge/actors/replay_buffer.py @@ -162,10 +162,27 @@ def _evict(self, curr_policy_version): evicted_count = buffer_len_before_evict - len(self.buffer) record_metric("buffer/evict/sum_episodes_evicted", evicted_count, Reduce.SUM) - logger.debug( - f"maximum policy age: {self.max_policy_age}, current policy version: {curr_policy_version}, " - f"{evicted_count} episodes expired, {len(self.buffer)} episodes left" - ) + # Enhanced debug logging to detect 
hang conditions + if evicted_count > 0 or len(self.buffer) == 0: + policy_versions_in_buffer = [ep.data.policy_version for ep in self.buffer] + logger.debug( + f"[BUFFER EVICTION] max_policy_age: {self.max_policy_age}, " + f"curr_policy_version: {curr_policy_version}, " + f"evicted: {evicted_count}, remaining: {len(self.buffer)}, " + f"versions_in_buffer: {set(policy_versions_in_buffer) if policy_versions_in_buffer else 'EMPTY'}" + ) + + # Log buffer status at debug level (starvation warnings are in main loop) + if len(self.buffer) == 0: + logger.debug( + f"[BUFFER] Buffer empty after eviction. " + f"curr_policy_version={curr_policy_version}, max_policy_age={self.max_policy_age}" + ) + elif len(self.buffer) < self.batch_size * self.dp_size: + logger.debug( + f"[BUFFER] Buffer low: {len(self.buffer)} episodes, " + f"need {self.batch_size * self.dp_size} for sampling." + ) def _collect(self, indices: list[int]): """Efficiently traverse deque and collect elements at each requested index""" @@ -220,3 +237,38 @@ async def state_dict(self) -> dict[str, Any]: async def load_state_dict(self, state_dict: dict[str, Any]) -> None: self.buffer = state_dict["buffer"] random.setstate(state_dict["rng_state"]) + + @endpoint + async def health_check(self, curr_policy_version: int) -> dict[str, Any]: + """Check buffer health without modifying state. + + Returns a dict with: + - size: current buffer size + - required: episodes needed for one batch + - healthy: True if buffer has enough fresh episodes + - versions: set of policy versions in buffer + - freshness_ratio: ratio of on-policy episodes to total + + This is useful for implementing backpressure - training can check + buffer health before triggering weight updates that block rollouts. 
+ """ + required = self.dp_size * self.batch_size + + # Count episodes that would survive eviction + surviving_count = 0 + versions_in_buffer = [] + for entry in self.buffer: + age = curr_policy_version - entry.data.policy_version + if self.max_policy_age is None or age <= self.max_policy_age: + if self.max_resample_count is None or entry.sample_count <= self.max_resample_count: + surviving_count += 1 + versions_in_buffer.append(entry.data.policy_version) + + return { + "size": len(self.buffer), + "surviving_after_eviction": surviving_count, + "required": required, + "healthy": surviving_count >= required * 2, # 2x margin for safety + "versions": set(versions_in_buffer) if versions_in_buffer else set(), + "freshness_ratio": surviving_count / len(self.buffer) if self.buffer else 0.0, + } diff --git a/src/forge/actors/vllm/v0/generator.py b/src/forge/actors/vllm/v0/generator.py index cd456d40e..0945b8884 100644 --- a/src/forge/actors/vllm/v0/generator.py +++ b/src/forge/actors/vllm/v0/generator.py @@ -449,9 +449,28 @@ async def update_weights(self, version: int) -> None: wait_start = time.perf_counter() # Wait until all pending requests have been processed - # TODO: If generating long sequences, this might be long and will block - # generator weight updates - await self.request_lock.wait_for(lambda: len(self.requests) == 0) + # Use timeout to prevent deadlock if requests are stuck + # Default timeout is 60 seconds, which should be sufficient for most sequences + weight_update_timeout = float( + os.environ.get("FORGE_WEIGHT_UPDATE_TIMEOUT_S", "60") + ) + try: + await asyncio.wait_for( + self.request_lock.wait_for(lambda: len(self.requests) == 0), + timeout=weight_update_timeout, + ) + except asyncio.TimeoutError: + pending_count = len(self.requests) + logger.warning( + f"[WEIGHT UPDATE] Timeout after {weight_update_timeout}s waiting for " + f"{pending_count} pending requests to complete. " + f"Proceeding with weight update to prevent deadlock." 
+ ) + record_metric( + "generator_perf/update_weights/timeout_count", + 1, + Reduce.SUM, + ) if curr_requests: wait_duration = time.perf_counter() - wait_start @@ -684,14 +703,90 @@ async def update_weights( loaded_weights = set() logger.info("[GeneratorWorker] Updating weights from torchstore.") hf_param_names = [extract_param_name(key) for key in matching_keys] - # We can't pass a generator since vllm load_weights is not async. - # Instead, we just call load_weights with one parameter at a time. - for name in hf_param_names: - param_key = get_param_key(version, name) - param = await ts.get(param_key) - loaded = model.load_weights([(name, param)]) - del param - loaded_weights.update(loaded) + + # Check if GPU Direct RDMA is enabled + import os + gpu_direct_enabled = os.environ.get("TORCHSTORE_GPU_DIRECT_RDMA", "1") == "1" + use_gpu_direct = gpu_direct_enabled and torch.cuda.is_available() + + if use_gpu_direct: + logger.info( + f"[GeneratorWorker] GPU Direct RDMA enabled - fetching {len(hf_param_names)} " + "parameters directly to GPU memory" + ) + + # Chunked parallel fetching: fetch in batches to balance parallelism with memory + # Too many parallel fetches causes memory pressure; too few loses parallelism benefit + # Default batch size of 16 = ~1GB per batch for 8B model (16 * 64MB avg param size) + # Configurable via FORGE_WEIGHT_FETCH_BATCH_SIZE for different model sizes + BATCH_SIZE = int(os.environ.get("FORGE_WEIGHT_FETCH_BATCH_SIZE", "16")) + param_keys = [get_param_key(version, name) for name in hf_param_names] + logger.info( + f"[GeneratorWorker] Fetching {len(param_keys)} parameters in batches of {BATCH_SIZE}..." 
+ ) + + # Retry configuration for RDMA resilience + MAX_RETRIES = 3 + BASE_DELAY = 1.0 + + async def _fetch_with_retry(key: str) -> torch.Tensor: + """Fetch a single weight with exponential backoff retry.""" + last_exception = None + for attempt in range(MAX_RETRIES): + try: + return await ts.get(key) + except Exception as e: + last_exception = e + if attempt < MAX_RETRIES - 1: + delay = BASE_DELAY * (2**attempt) + logger.warning( + f"[GeneratorWorker] Weight fetch failed (attempt {attempt + 1}/{MAX_RETRIES}), " + f"retrying in {delay:.1f}s: {e}" + ) + record_metric("generator/weight_update/retries", 1, Reduce.SUM) + await asyncio.sleep(delay) + # Record failure metric before raising + record_metric("generator/weight_update/fetch_failures", 1, Reduce.SUM) + raise last_exception + + for batch_start in range(0, len(hf_param_names), BATCH_SIZE): + batch_end = min(batch_start + BATCH_SIZE, len(hf_param_names)) + batch_names = hf_param_names[batch_start:batch_end] + batch_keys = param_keys[batch_start:batch_end] + + # Fetch batch in parallel with retry logic + # With GPU Direct RDMA enabled, torchstore will allocate directly on GPU + batch_results = await asyncio.gather( + *[_fetch_with_retry(key) for key in batch_keys], + return_exceptions=True, + ) + + # Check for any failures that persisted after retries + failed_indices = [ + i for i, result in enumerate(batch_results) if isinstance(result, Exception) + ] + if failed_indices: + failed_names = [batch_names[i] for i in failed_indices] + failed_exceptions = [batch_results[i] for i in failed_indices] + record_metric( + "generator/weight_update/batch_failures", len(failed_indices), Reduce.SUM + ) + logger.error( + f"[GeneratorWorker] Failed to fetch {len(failed_indices)} weights after retries: " + f"{failed_names}" + ) + raise RuntimeError( + f"Failed to fetch {len(failed_indices)} weights after {MAX_RETRIES} retries: " + f"{failed_names}. 
First error: {failed_exceptions[0]}" + ) + + # Load batch to model and free memory immediately + for name, param in zip(batch_names, batch_results): + # If param is on GPU (from GPU Direct RDMA), it's already ready + # If on CPU, load_weights will handle the transfer + loaded = model.load_weights([(name, param)]) + del param + loaded_weights.update(loaded) @endpoint async def save_model_params(self): @@ -725,18 +820,69 @@ async def fetch( param_names: list[str], ) -> dict[str, SharedTensorHandle]: """Fetch weights from torchstore and load them into shared memory.""" + # Chunked parallel fetching to balance parallelism with memory pressure + # Configurable via FORGE_WEIGHT_FETCH_BATCH_SIZE for different model sizes + BATCH_SIZE = int(os.environ.get("FORGE_WEIGHT_FETCH_BATCH_SIZE", "16")) sd = {} - for name in param_names: - param_key = get_param_key(version, name) + + # Retry configuration for RDMA resilience + MAX_RETRIES = 3 + BASE_DELAY = 1.0 + + async def _fetch_with_retry(key: str) -> torch.Tensor: + """Fetch a single weight with exponential backoff retry.""" + last_exception = None + for attempt in range(MAX_RETRIES): + try: + return await ts.get(key) + except Exception as e: + last_exception = e + if attempt < MAX_RETRIES - 1: + delay = BASE_DELAY * (2**attempt) + logger.warning( + f"[WeightFetcher] Weight fetch failed (attempt {attempt + 1}/{MAX_RETRIES}), " + f"retrying in {delay:.1f}s: {e}" + ) + record_metric("generator/weight_fetcher/retries", 1, Reduce.SUM) + await asyncio.sleep(delay) + record_metric("generator/weight_fetcher/fetch_failures", 1, Reduce.SUM) + raise last_exception + + for batch_start in range(0, len(param_names), BATCH_SIZE): + batch_end = min(batch_start + BATCH_SIZE, len(param_names)) + batch_names = param_names[batch_start:batch_end] + batch_keys = [get_param_key(version, name) for name in batch_names] + + # Fetch batch in parallel with retry logic + batch_results = await asyncio.gather( + *[_fetch_with_retry(key) for key in 
batch_keys], + return_exceptions=True, + ) + + # Check for any failures that persisted after retries + failed_indices = [ + i for i, result in enumerate(batch_results) if isinstance(result, Exception) + ] + if failed_indices: + failed_names = [batch_names[i] for i in failed_indices] + record_metric( + "generator/weight_fetcher/batch_failures", len(failed_indices), Reduce.SUM + ) + raise RuntimeError( + f"[WeightFetcher] Failed to fetch {len(failed_indices)} weights after {MAX_RETRIES} retries: " + f"{failed_names}" + ) + + # Create shared tensors and free memory immediately # Use explicit resource handling instead of context manager because # ownership is transferred to the Generator (which calls handle.drop() # to clean up). We must unregister from resource_tracker here, otherwise # the fetcher process will try to clean up the shared memory on exit. - param = await ts.get(param_key) - shared_tensor = SharedTensor(tensor=param) - handle = shared_tensor.get_handle() - resource_tracker.unregister(f"/{handle.shm_name}", "shared_memory") - sd[name] = handle - shared_tensor.close() - del param # Explicitly free the tensor after copying to shared memory + for name, param in zip(batch_names, batch_results): + shared_tensor = SharedTensor(tensor=param) + handle = shared_tensor.get_handle() + resource_tracker.unregister(f"/{handle.shm_name}", "shared_memory") + sd[name] = handle + shared_tensor.close() + del param return sd diff --git a/src/forge/observability/metric_actors.py b/src/forge/observability/metric_actors.py index 479df7e5c..712357383 100644 --- a/src/forge/observability/metric_actors.py +++ b/src/forge/observability/metric_actors.py @@ -119,6 +119,10 @@ async def get_or_create_metric_logger( local_fetcher_actor = proc.spawn( "local_fetcher_actor", LocalFetcherActor, global_logger, process_name ) + # Wait for the actor to be ready on all ranks before registering. + # This prevents race conditions when multiple actors spawn in parallel. 
+ await local_fetcher_actor.setup.call() + # Generate a unique ID to map procmesh to fetcher proc._uid = str(uuid.uuid4()) proc._local_fetcher = local_fetcher_actor # pyre-ignore diff --git a/src/forge/rl/types.py b/src/forge/rl/types.py index fec18f587..9eb4ce662 100644 --- a/src/forge/rl/types.py +++ b/src/forge/rl/types.py @@ -32,12 +32,19 @@ class Episode: @property def policy_version(self) -> int | None: - return self.completion.generator_version + return self.completion.generator_version if self.completion else None + + @property + def stop_reason(self) -> str | None: + """Get stop reason from completion for truncation detection.""" + return self.completion.stop_reason if self.completion else None @property def request_tensor(self) -> torch.Tensor: tensor: torch.Tensor = self.completion.prompt_ids.to(torch.long) - if tensor.shape[0] < self.request_len: # left pad + if tensor.shape[0] > self.request_len: # truncate from left (keep end) + tensor = tensor[-self.request_len :] + elif tensor.shape[0] < self.request_len: # left pad diff = self.request_len - tensor.shape[0] tensor = F.pad(tensor, (diff, 0), value=self.pad_id) return tensor @@ -45,7 +52,9 @@ def request_tensor(self) -> torch.Tensor: @property def response_tensor(self) -> torch.Tensor: tensor: torch.Tensor = self.completion.token_ids.to(torch.long) - if tensor.shape[0] < self.response_len: # right pad + if tensor.shape[0] > self.response_len: # truncate from right (keep beginning) + tensor = tensor[: self.response_len] + elif tensor.shape[0] < self.response_len: # right pad diff = self.response_len - tensor.shape[0] tensor = F.pad(tensor, (0, diff), value=self.pad_id) return tensor @@ -67,7 +76,7 @@ def to_dict(self, exclude: list[str] | None = None) -> dict[str, Any]: "completion": self.completion, } - if self.reward_breakdown is not None and "reward_breakdown" not in exclude: + if self.reward_breakdown is not None and (exclude is None or "reward_breakdown" not in exclude): 
result.update(self.reward_breakdown) if exclude: diff --git a/src/forge/util/config.py b/src/forge/util/config.py index 577fa6457..196955082 100644 --- a/src/forge/util/config.py +++ b/src/forge/util/config.py @@ -16,6 +16,7 @@ # Add support for summing lists of numbers, e.g. ${sum:${max_req_tokens},${max_res_tokens}} OmegaConf.register_new_resolver("sum", lambda *args: sum(args), replace=True) +OmegaConf.register_new_resolver("multiply", lambda a, b: a * b, replace=True) # Add support for boolean negation, e.g. ${not:${compile}} OmegaConf.register_new_resolver("not", lambda x: not x, replace=True)