Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/e2e.yml
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ jobs:
- name: Run E2E tests
env:
RUNPOD_API_KEY: ${{ secrets.RUNPOD_API_KEY }}
FLASH_SDK_GIT_REF: ${{ github.sha }}
run: |
uv run pytest e2e/ \
${{ inputs.tests != '' && format('-k "{0}"', inputs.tests) || '' }} \
Expand Down
60 changes: 44 additions & 16 deletions e2e/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,21 @@
import asyncio
import os
import pickle
import sys
from pathlib import Path

import pytest
# Ensure the e2e/ directory is on sys.path so test files can import local
# modules (provisioner, etc.) regardless of how pytest resolves the rootdir.
_E2E_DIR = str(Path(__file__).parent)
if _E2E_DIR not in sys.path:
sys.path.insert(0, _E2E_DIR)

import pytest # noqa: E402

try:
import tomllib
import tomllib # noqa: E402
except ImportError:
import tomli as tomllib # type: ignore[no-redef]
import tomli as tomllib # type: ignore[no-redef] # noqa: E402


def _api_key_from_config() -> str | None:
Expand All @@ -25,7 +32,8 @@ def _api_key_from_config() -> str | None:
try:
data = tomllib.loads(config_file.read_text())
return data.get("default", {}).get("api_key")
except Exception:
except Exception as exc:
print(f"Warning: could not parse ~/.runpod/config.toml: {exc}")
return None


Expand All @@ -38,29 +46,38 @@ def endpoint_id_from_state(project_dir: Path) -> str:

The state file is a (resources_dict, config_hashes_dict) tuple.
resources_dict keys are "ResourceType:name", values are resource objects with .id.

Raises FileNotFoundError if the state file is missing (deploy did not complete).
Raises ValueError if the file exists but contains no endpoint ID (format may have changed).
"""
state_file = project_dir / ".flash" / "resources.pkl"
if not state_file.exists():
raise FileNotFoundError(f"State file not found: {state_file}")
with open(state_file, "rb") as f:
data = pickle.load(f)
try:
with open(state_file, "rb") as f:
data = pickle.load(f)
except Exception as exc:
raise ValueError(
f"Failed to deserialize state file {state_file} — "
f"the .flash/resources.pkl format may have changed: {exc}"
) from exc
resources = data[0] if isinstance(data, tuple) else data
for _key, resource in resources.items():
endpoint_id = getattr(resource, "id", None)
if endpoint_id:
return endpoint_id
raise ValueError(f"No endpoint ID found in state file. Keys: {list(resources)}")
raise ValueError(
f"No endpoint ID found in state file {state_file}. "
f"Keys present: {list(resources)}. "
f"Check that the resource object has an 'id' attribute."
)


def sweep_endpoints(api_key: str) -> None:
"""Delete all endpoints on the account.
def sweep_endpoints(api_key: str, *, prefix: str = "flash-qa-") -> None:
"""Delete endpoints whose names start with prefix.

The e2e RUNPOD_API_KEY is dedicated to testing. Call this in every test's
finally block to ensure quota is fully released regardless of whether the
graceful undeploy succeeded.

To restrict cleanup to smoke-test endpoints only, swap the list comprehension:
endpoints = [ep for ep in endpoints if ep.get("name", "").startswith("flash-qa-smoke-")]
Defaults to "flash-qa-" so only test-created endpoints are removed.
Pass prefix="" to delete all endpoints on the account (use with caution).
"""
from runpod_flash.core.api.runpod import RunpodGraphQLClient

Expand All @@ -69,7 +86,12 @@ async def _run(key: str) -> None:
result = await client._execute_graphql(
"query { myself { endpoints { id name } } }"
)
endpoints = result.get("myself", {}).get("endpoints", [])
all_endpoints = result.get("myself", {}).get("endpoints", [])
endpoints = [
ep
for ep in all_endpoints
if not prefix or ep.get("name", "").startswith(prefix)
]
for ep in endpoints:
eid, ename = ep["id"], ep.get("name", ep["id"])
try:
Expand All @@ -95,3 +117,9 @@ def restore_real_credentials(monkeypatch: pytest.MonkeyPatch) -> None:
)
else:
pytest.skip("No credentials available — skipping E2E test")


@pytest.fixture
def api_key() -> str:
    """Return the RunPod API key for tests that need to pass it explicitly.

    The value is the module-level ``_REAL_API_KEY``. Per the inline note
    below, the autouse ``restore_real_credentials`` fixture is what guarantees
    it is set; that fixture skips the test when no credentials are available.
    """
    return _REAL_API_KEY  # type: ignore[return-value]  # guaranteed set by restore_real_credentials autouse
122 changes: 122 additions & 0 deletions e2e/provisioner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
"""Endpoint provisioner for E2E session-scoped fixtures.

provision() deploys a Flash worker and returns (endpoint_id, project_dir).
All shared endpoints are deployed in parallel at session start.

Git ref injection
-----------------
Set FLASH_SDK_GIT_REF to a commit SHA or branch name to install that exact
version of runpod-flash inside the worker container instead of the latest
PyPI release. In CI, set this to github.sha so workers run the branch under
test rather than the last published release.

FLASH_SDK_GIT_REF=${{ github.sha }} # in CI workflow
"""

import os
import shutil
import subprocess
import tempfile
from pathlib import Path

from conftest import endpoint_id_from_state

# ---------------------------------------------------------------------------
# Git ref injection
# ---------------------------------------------------------------------------

# Both overrides are read once at import time; an empty string means "unset".
FLASH_GIT_REF: str = os.environ.get("FLASH_SDK_GIT_REF", "")
FLASH_LOCAL_PATH: str = os.environ.get("FLASH_SDK_LOCAL_PATH", "")
_FLASH_REPO = "https://github.com/runpod/runpod-flash"


def flash_dep() -> str:
    """Return the runpod-flash pip requirement string for worker pyproject.toml.

    Resolution order (first match wins):

    1. ``FLASH_SDK_LOCAL_PATH`` set: install from a local checkout. Useful when
       the fix is not yet on PyPI and the git repo is private.
    2. ``FLASH_SDK_GIT_REF`` set (CI): install the exact commit under test.
    3. Neither set: install the latest PyPI release.

    Returns:
        A PEP 508 requirement string, e.g. ``"runpod-flash"`` or
        ``"runpod-flash @ git+https://...@<ref>"``.
    """
    if FLASH_LOCAL_PATH:
        # The requirement ends up in a pyproject.toml inside a temp directory,
        # so the path must be absolute; as_uri() also percent-encodes spaces
        # and yields a well-formed file:// URL on every platform.
        local_uri = Path(FLASH_LOCAL_PATH).expanduser().resolve().as_uri()
        return f"runpod-flash @ {local_uri}"
    if FLASH_GIT_REF:
        return f"runpod-flash @ git+{_FLASH_REPO}@{FLASH_GIT_REF}"
    return "runpod-flash"


# ---------------------------------------------------------------------------
# Provisioner
# ---------------------------------------------------------------------------

# Minimal pyproject.toml written into each worker project directory.
# provision() fills {name} and {deps} via str.format(); {deps} must already be
# a comma-separated list of double-quoted requirement strings.
_PYPROJECT_TMPL = """\
[project]
name = "{name}"
version = "0.1.0"
requires-python = ">=3.11,<3.13"
dependencies = [{deps}]
"""


def provision(
    worker_code: str,
    *,
    name: str,
    api_key: str,
    extra_deps: list[str] | None = None,
    deploy_timeout: int = 600,
) -> tuple[str, Path]:
    """Deploy a Flash worker and return (endpoint_id, project_dir).

    The returned project_dir is a temporary directory that owns the .flash
    state. On success the caller is responsible for cleanup — call
    shutil.rmtree() on project_dir when the endpoint is no longer needed.
    On any failure the directory is removed before the exception propagates.

    Args:
        worker_code: Python source of the worker file.
        name: Endpoint name (must be unique per CI run).
        api_key: RunPod API key passed explicitly to the subprocess env.
        extra_deps: Additional pip requirements (beyond runpod-flash).
        deploy_timeout: Seconds before subprocess.run times out.

    Returns:
        (endpoint_id, project_dir)

    Raises:
        RuntimeError: If flash deploy exits non-zero.
        subprocess.TimeoutExpired: If the deploy exceeds deploy_timeout.
        FileNotFoundError: If the deploy left no state file behind.
        ValueError: If the state file is unreadable or holds no endpoint ID.
    """
    deps = [flash_dep(), *(extra_deps or [])]
    deps_quoted = ", ".join(f'"{d}"' for d in deps)
    pyproject = _PYPROJECT_TMPL.format(name=name, deps=deps_quoted)

    env = os.environ.copy()
    env["RUNPOD_API_KEY"] = api_key  # explicit — does not depend on autouse fixture

    tmp_dir = Path(tempfile.mkdtemp(prefix=f"flash-e2e-{name}-"))
    try:
        # TOML is UTF-8 by specification, so do not rely on the locale default.
        (tmp_dir / "worker.py").write_text(worker_code, encoding="utf-8")
        (tmp_dir / "pyproject.toml").write_text(pyproject, encoding="utf-8")

        result = subprocess.run(
            ["uv", "run", "flash", "deploy"],
            cwd=tmp_dir,
            env=env,
            capture_output=True,
            text=True,
            timeout=deploy_timeout,
        )
        if result.returncode != 0:
            raise RuntimeError(
                f"flash deploy failed for '{name}' (exit {result.returncode}):\n"
                f"stdout: {result.stdout}\nstderr: {result.stderr}"
            )

        # Reading the state is part of provisioning: if it fails the caller
        # never receives tmp_dir, so it must be cleaned up here as well.
        endpoint_id = endpoint_id_from_state(tmp_dir)
    except Exception:
        shutil.rmtree(tmp_dir, ignore_errors=True)
        raise

    return endpoint_id, tmp_dir
13 changes: 5 additions & 8 deletions e2e/test_cpu_smoke.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import runpod

from conftest import endpoint_id_from_state, sweep_endpoints
from provisioner import flash_dep

WORKER_NAME = f"flash-qa-smoke-{uuid.uuid4().hex[:8]}"

Expand All @@ -28,22 +29,20 @@ async def echo(msg: str = "") -> dict:
name = "{WORKER_NAME}"
version = "0.1.0"
requires-python = ">=3.11,<3.13"
dependencies = ["runpod-flash"]
dependencies = ["{flash_dep()}"]
'''


class TestCpuSmoke:
"""CPU smoke: deploy → invoke → undeploy."""

def test_deploy_invoke_undeploy(self, tmp_path: Path) -> None:
"""Deploy a minimal CPU worker, invoke it, verify output, undeploy."""
env = os.environ.copy()

(tmp_path / "worker.py").write_text(WORKER_CODE)
(tmp_path / "pyproject.toml").write_text(PYPROJECT_TOML)

try:
# Deploy
result = subprocess.run(
["uv", "run", "flash", "deploy"],
cwd=tmp_path,
Expand All @@ -59,8 +58,7 @@ def test_deploy_invoke_undeploy(self, tmp_path: Path) -> None:

endpoint_id = endpoint_id_from_state(tmp_path)

# Invoke
runpod.api_key = env.get("RUNPOD_API_KEY")
runpod.api_key = env["RUNPOD_API_KEY"]
output = runpod.Endpoint(endpoint_id).run_sync(
{"msg": "smoke"}, timeout=180
)
Expand All @@ -70,7 +68,7 @@ def test_deploy_invoke_undeploy(self, tmp_path: Path) -> None:
assert output.get("status") == "ok", f"Unexpected status: {output}"

finally:
# Attempt graceful undeploy first
# Exercise the undeploy CLI path; sweep catches any quota leak if this fails.
try:
undeploy = subprocess.run(
["uv", "run", "flash", "undeploy", WORKER_NAME, "--force"],
Expand All @@ -88,6 +86,5 @@ def test_deploy_invoke_undeploy(self, tmp_path: Path) -> None:
except subprocess.TimeoutExpired:
print("WARNING: undeploy timed out after 60s")

# Always sweep all endpoints — dedicated e2e account, stale
# endpoints hit the worker quota on subsequent runs.
# Sweep flash-qa-* endpoints — stale endpoints exhaust worker quota.
sweep_endpoints(env["RUNPOD_API_KEY"])
Loading
Loading