diff --git a/langfuse/__init__.py b/langfuse/__init__.py
index d33febca7..08d8325cf 100644
--- a/langfuse/__init__.py
+++ b/langfuse/__init__.py
@@ -8,7 +8,7 @@
     EvaluatorStats,
     MapperFunction,
 )
-from langfuse.experiment import Evaluation
+from langfuse.experiment import Evaluation, RegressionError, RunnerContext
 
 from ._client import client as _client_module
 from ._client.attributes import LangfuseOtelSpanAttributes
@@ -63,6 +63,8 @@
     "EvaluatorStats",
     "BatchEvaluationResumeToken",
     "BatchEvaluationResult",
+    "RunnerContext",
+    "RegressionError",
     "__version__",
     "is_default_export_span",
     "is_langfuse_span",
diff --git a/langfuse/experiment.py b/langfuse/experiment.py
index 67b50a900..404c96e1d 100644
--- a/langfuse/experiment.py
+++ b/langfuse/experiment.py
@@ -6,7 +6,9 @@
 """
 
 import asyncio
+from datetime import datetime
 from typing import (
+    TYPE_CHECKING,
     Any,
     Awaitable,
     Dict,
@@ -15,12 +17,17 @@
     Protocol,
     TypedDict,
     Union,
+    overload,
 )
 
 from langfuse.api import DatasetItem
 from langfuse.logger import langfuse_logger as logger
 from langfuse.types import ExperimentScoreType
 
+if TYPE_CHECKING:
+    from langfuse._client.client import Langfuse
+    from langfuse.batch_evaluation import CompositeEvaluatorFunction
+
 
 class LocalExperimentItem(TypedDict, total=False):
     """Structure for local experiment data items (not from Langfuse datasets).
@@ -1049,3 +1056,152 @@ def langfuse_evaluator(
         )
 
     return langfuse_evaluator
+
+
+class RunnerContext:
+    """Wraps :meth:`Langfuse.run_experiment` with CI-injected defaults.
+
+    Intended for use with the ``langfuse/experiment-action`` GitHub Action
+    (https://github.com/langfuse/experiment-action). The action builds a
+    ``RunnerContext`` before invoking the user's ``experiment(context)``
+    function. Defaults set here (dataset, metadata tags) are applied when
+    the user omits them on the :meth:`run_experiment` call; users can
+    override any default by passing the corresponding argument explicitly.
+    """
+
+    def __init__(
+        self,
+        *,
+        client: "Langfuse",
+        data: Optional[ExperimentData] = None,
+        dataset_version: Optional[datetime] = None,
+        metadata: Optional[Dict[str, str]] = None,
+    ):
+        """Build a ``RunnerContext`` populated with defaults for ``run_experiment``.
+
+        Typically called by the ``langfuse/experiment-action`` GitHub Action,
+        not by end users directly. Every field except ``client`` is optional:
+        fields left as ``None`` simply mean the corresponding argument must be
+        supplied on the :meth:`run_experiment` call.
+
+        Args:
+            client: Initialized Langfuse SDK client used to execute the
+                experiment. The action creates this from the
+                ``langfuse_public_key`` / ``langfuse_secret_key`` /
+                ``langfuse_base_url`` inputs.
+            data: Default dataset items to run the experiment on. Accepts
+                either ``List[LocalExperimentItem]`` or ``List[DatasetItem]``.
+                Injected by the action when ``dataset_name`` is configured.
+                If ``None``, the user must pass ``data=`` to
+                :meth:`run_experiment`.
+            dataset_version: Optional pinned dataset version. Injected by the
+                action when ``dataset_version`` is configured.
+            metadata: Default metadata attached to every experiment trace and
+                the dataset run. The action injects GitHub-sourced tags (SHA,
+                PR link, workflow run link, branch, GH user, etc.). Merged
+                with any ``metadata`` passed to :meth:`run_experiment`, with
+                user-supplied keys winning on collision.
+ """ + self.client = client + self.data = data + self.dataset_version = dataset_version + self.metadata = metadata + + def run_experiment( + self, + *, + name: str, + run_name: Optional[str] = None, + description: Optional[str] = None, + data: Optional[ExperimentData] = None, + task: TaskFunction, + evaluators: List[EvaluatorFunction] = [], + composite_evaluator: Optional["CompositeEvaluatorFunction"] = None, + run_evaluators: List[RunEvaluatorFunction] = [], + max_concurrency: int = 50, + metadata: Optional[Dict[str, str]] = None, + _dataset_version: Optional[datetime] = None, + ) -> ExperimentResult: + resolved_data = data if data is not None else self.data + if resolved_data is None: + raise ValueError( + "`data` must be provided either on the RunnerContext or the run_experiment call" + ) + + resolved_dataset_version = ( + _dataset_version if _dataset_version is not None else self.dataset_version + ) + + merged_metadata: Optional[Dict[str, str]] + if self.metadata is None and metadata is None: + merged_metadata = None + else: + merged_metadata = {**(self.metadata or {}), **(metadata or {})} + + return self.client.run_experiment( + name=name, + run_name=run_name, + description=description, + data=resolved_data, + task=task, + evaluators=evaluators, + composite_evaluator=composite_evaluator, + run_evaluators=run_evaluators, + max_concurrency=max_concurrency, + metadata=merged_metadata, + _dataset_version=resolved_dataset_version, + ) + + +class RegressionError(Exception): + """Raised by a user's ``experiment`` function to signal a CI gate failure. + + Intended for use with the ``langfuse/experiment-action`` GitHub Action + (https://github.com/langfuse/experiment-action). The action catches this + exception and, when ``should_fail_on_error`` is enabled, fails the + workflow run and renders a callout in the PR comment using + ``metric``/``value``/``threshold`` if supplied, otherwise ``str(exc)``. + + Callers choose one of three forms: + + - ``RegressionError(result=r)`` — minimal, generic message. + - ``RegressionError(result=r, message="...")`` — free-form message. + - ``RegressionError(result=r, metric="acc", value=0.7, threshold=0.9)`` — + structured; ``metric`` and ``value`` must be provided together so the + action can render a targeted callout without ``None`` placeholders. + """ + + @overload + def __init__(self, *, result: ExperimentResult) -> None: ... + @overload + def __init__(self, *, result: ExperimentResult, message: str) -> None: ... + @overload + def __init__( + self, + *, + result: ExperimentResult, + metric: str, + value: float, + threshold: Optional[float] = None, + message: Optional[str] = None, + ) -> None: ... 
+    def __init__(
+        self,
+        *,
+        result: ExperimentResult,
+        metric: Optional[str] = None,
+        value: Optional[float] = None,
+        threshold: Optional[float] = None,
+        message: Optional[str] = None,
+    ):
+        self.result = result
+        self.metric = metric
+        self.value = value
+        self.threshold = threshold
+
+        if message is not None:
+            formatted = message
+        elif metric is not None and value is not None:
+            formatted = f"Regression on `{metric}`: {value} (threshold {threshold})"
+        else:
+            formatted = "Experiment regression detected"
+
+        super().__init__(formatted)
diff --git a/tests/unit/test_experiment.py b/tests/unit/test_experiment.py
new file mode 100644
index 000000000..c6c8465a3
--- /dev/null
+++ b/tests/unit/test_experiment.py
@@ -0,0 +1,248 @@
+"""Tests for ``langfuse.experiment`` — ``RunnerContext`` and ``RegressionError``."""
+
+import inspect
+import typing
+from datetime import datetime
+from typing import get_type_hints
+from unittest.mock import MagicMock
+
+import pytest
+
+from langfuse import RegressionError, RunnerContext
+from langfuse._client.client import Langfuse
+from langfuse.batch_evaluation import CompositeEvaluatorFunction
+
+
+def _noop_task(*, item, **kwargs):  # pragma: no cover - never invoked via mock
+    return None
+
+
+def _make_ctx(**kwargs) -> RunnerContext:
+    client = MagicMock(spec=Langfuse)
+    client.run_experiment.return_value = "result-sentinel"
+    return RunnerContext(client=client, **kwargs)
+
+
+class TestRunnerContextDefaults:
+    def test_context_defaults_flow_through(self):
+        ctx_data = [{"input": "a"}]
+        ctx_version = datetime(2026, 1, 1)
+        ctx = _make_ctx(
+            data=ctx_data,
+            dataset_version=ctx_version,
+            metadata={"sha": "abc123"},
+        )
+
+        result = ctx.run_experiment(name="exp", task=_noop_task)
+
+        assert result == "result-sentinel"
+        ctx.client.run_experiment.assert_called_once()
+        kwargs = ctx.client.run_experiment.call_args.kwargs
+        assert kwargs["name"] == "exp"
+        assert kwargs["data"] is ctx_data
+        assert kwargs["metadata"] == {"sha": "abc123"}
+        assert kwargs["_dataset_version"] == ctx_version
+        assert kwargs["task"] is _noop_task
+
+    def test_call_overrides_win(self):
+        ctx = _make_ctx(
+            data=[{"input": "ctx"}],
+            dataset_version=datetime(2026, 1, 1),
+        )
+
+        override_data = [{"input": "override"}]
+        override_version = datetime(2026, 6, 6)
+        ctx.run_experiment(
+            name="exp",
+            task=_noop_task,
+            run_name="call-run",
+            data=override_data,
+            _dataset_version=override_version,
+        )
+
+        kwargs = ctx.client.run_experiment.call_args.kwargs
+        assert kwargs["name"] == "exp"
+        assert kwargs["run_name"] == "call-run"
+        assert kwargs["data"] is override_data
+        assert kwargs["_dataset_version"] == override_version
+
+
+class TestRunnerContextMetadataMerge:
+    def test_user_keys_win_on_collision(self):
+        ctx = _make_ctx(
+            data=[{"input": "a"}],
+            metadata={"sha": "abc", "branch": "main"},
+        )
+        ctx.run_experiment(
+            name="exp", task=_noop_task, metadata={"sha": "def", "pr": "42"}
+        )
+        assert ctx.client.run_experiment.call_args.kwargs["metadata"] == {
+            "sha": "def",
+            "branch": "main",
+            "pr": "42",
+        }
+
+    def test_context_metadata_only(self):
+        ctx = _make_ctx(data=[{"input": "a"}], metadata={"sha": "abc"})
+        ctx.run_experiment(name="exp", task=_noop_task)
+        assert ctx.client.run_experiment.call_args.kwargs["metadata"] == {"sha": "abc"}
+
+    def test_call_metadata_only(self):
+        ctx = _make_ctx(data=[{"input": "a"}])
+        ctx.run_experiment(name="exp", task=_noop_task, metadata={"pr": "1"})
+        assert ctx.client.run_experiment.call_args.kwargs["metadata"] == {"pr": "1"}
+
+    def test_both_none_stays_none(self):
+        ctx = _make_ctx(data=[{"input": "a"}])
+        ctx.run_experiment(name="exp", task=_noop_task)
+        assert ctx.client.run_experiment.call_args.kwargs["metadata"] is None
+
+
+class TestRunnerContextLocalItems:
+    def test_local_items_pass_through_as_context_default(self):
+        items = [{"input": "x", "expected_output": "y"}]
+        ctx = _make_ctx(data=items)
+        ctx.run_experiment(name="exp", task=_noop_task)
+        assert ctx.client.run_experiment.call_args.kwargs["data"] is items
+
+    def test_local_items_pass_through_as_call_override(self):
+        ctx = _make_ctx()
+        items = [{"input": "x"}]
+        ctx.run_experiment(name="exp", task=_noop_task, data=items)
+        assert ctx.client.run_experiment.call_args.kwargs["data"] is items
+
+
+class TestRunnerContextValidation:
+    def test_missing_data_raises(self):
+        ctx = _make_ctx()
+        with pytest.raises(ValueError, match="data"):
+            ctx.run_experiment(name="exp", task=_noop_task)
+
+
+class TestRegressionError:
+    def test_is_exception(self):
+        result = MagicMock()
+        exc = RegressionError(result=result)
+        assert isinstance(exc, Exception)
+        assert exc.result is result
+
+    def test_default_message(self):
+        exc = RegressionError(result=MagicMock())
+        assert str(exc) == "Experiment regression detected"
+        assert exc.metric is None
+        assert exc.value is None
+        assert exc.threshold is None
+
+    def test_structured_message(self):
+        exc = RegressionError(
+            result=MagicMock(), metric="avg_accuracy", value=0.78, threshold=0.9
+        )
+        assert exc.metric == "avg_accuracy"
+        assert exc.value == 0.78
+        assert exc.threshold == 0.9
+        assert "avg_accuracy" in str(exc)
+        assert "0.78" in str(exc)
+        assert "0.9" in str(exc)
+
+    def test_free_form_message(self):
+        exc = RegressionError(
+            result=MagicMock(),
+            message="custom explanation",
+        )
+        assert str(exc) == "custom explanation"
+
+    def test_message_wins_over_structured(self):
+        exc = RegressionError(
+            result=MagicMock(),
+            metric="avg_accuracy",
+            value=0.5,
+            threshold=0.9,
+            message="custom explanation",
+        )
+        assert str(exc) == "custom explanation"
+        assert exc.metric == "avg_accuracy"
+        assert exc.value == 0.5
+        assert exc.threshold == 0.9
+
+    def test_partial_structured_falls_back_to_default(self):
+        """The structured overload requires ``metric`` and ``value`` together.
+
+        If a caller bypasses the type checker and passes only one, we fall
+        back to the default message rather than rendering misleading
+        ``None`` placeholders in the PR comment.
+        """
+        exc = RegressionError(result=MagicMock(), metric="avg_accuracy")  # type: ignore[call-overload]
+        assert str(exc) == "Experiment regression detected"
+
+
+class TestSignatureDriftGuard:
+    """Fails loudly if ``Langfuse.run_experiment`` grows a parameter that is
+    not threaded through ``RunnerContext.run_experiment``.
+
+    ``data`` is the only genuinely relaxed parameter: it is required on the
+    client but optional on the RunnerContext so the action can inject it.
+    ``run_name`` and ``_dataset_version`` are already ``Optional`` on the
+    client and must match as-is. ``name`` is required on both — the action
+    supports a directory of experiments, so each script must name itself.
+    """
+
+    RELAXED_PARAMS = {"data"}
+
+    # `CompositeEvaluatorFunction` is only imported under TYPE_CHECKING in
+    # ``langfuse.experiment`` to break the circular dependency with
+    # ``langfuse.batch_evaluation``, so its forward-ref must be resolved
+    # explicitly when inspecting annotations.
+ LOCALNS = {"CompositeEvaluatorFunction": CompositeEvaluatorFunction} + + def test_no_divergence(self): + client_param_names = self._param_names(Langfuse.run_experiment) + ctx_param_names = self._param_names(RunnerContext.run_experiment) + + assert client_param_names == ctx_param_names, ( + "RunnerContext.run_experiment params do not match " + "Langfuse.run_experiment. Missing: " + f"{client_param_names - ctx_param_names}. " + f"Extra: {ctx_param_names - client_param_names}." + ) + + client_hints = get_type_hints(Langfuse.run_experiment) + ctx_hints = get_type_hints( + RunnerContext.run_experiment, localns=self.LOCALNS + ) + + for name in client_param_names: + client_ann = client_hints.get(name, inspect.Parameter.empty) + ctx_ann = ctx_hints.get(name, inspect.Parameter.empty) + + if name in self.RELAXED_PARAMS: + # RunnerContext version must be Optional[]. + # Already-optional client annotations (``run_name``, + # ``_dataset_version``) just need to match as-is. + if self._is_optional(client_ann): + assert ctx_ann == client_ann, ( + f"param `{name}`: expected {client_ann}, got {ctx_ann}" + ) + else: + assert ctx_ann == typing.Optional[client_ann], ( + f"param `{name}`: expected Optional[{client_ann}], " + f"got {ctx_ann}" + ) + else: + assert ctx_ann == client_ann, ( + f"param `{name}`: annotation drift — " + f"client={client_ann}, context={ctx_ann}" + ) + + @staticmethod + def _param_names(func) -> set: + return { + name + for name in inspect.signature(func).parameters + if name != "self" + } + + @staticmethod + def _is_optional(annotation) -> bool: + origin = typing.get_origin(annotation) + args = typing.get_args(annotation) + return origin is typing.Union and type(None) in args