Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ version = "0.2.10"
description = "Mem0 Dify plugin"
requires-python = ">=3.12"
dependencies = [
"mem0ai>=1.0.2",
"mem0ai>=1.0.2,<=1.0.11",
"openai",
"azure-identity",
"langchain-neo4j",
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
mem0ai>=1.0.2
mem0ai>=1.0.2,<=1.0.11
openai
azure-identity
langchain-neo4j
Expand Down
45 changes: 45 additions & 0 deletions tests/unit/provider/test_mem0_provider_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from __future__ import annotations

import asyncio
from unittest.mock import MagicMock

from provider.mem0ai import Mem0Provider


def test_validate_credentials_async_mode_uses_async_client_search(monkeypatch) -> None:
import provider.mem0ai as provider_mod

captured: dict[str, object] = {}
fake_loop = object()
fake_future = MagicMock()
fake_future.result.return_value = {"results": []}

class FakeClient:
def ensure_bg_loop(self) -> object:
captured["ensure_bg_loop_called"] = True
return fake_loop

async def search(self, payload: dict[str, object], timeout_s: int) -> dict[str, object]:
captured["search_payload"] = payload
captured["search_timeout"] = timeout_s
return {"results": []}

def _fake_run_coroutine_threadsafe(coro, loop): # noqa: ANN001
assert asyncio.iscoroutine(coro)
captured["loop"] = loop
coro.close()
return fake_future

monkeypatch.setattr(provider_mod, "get_async_client", lambda _credentials: FakeClient())
monkeypatch.setattr(
provider_mod.asyncio,
"run_coroutine_threadsafe",
_fake_run_coroutine_threadsafe,
)

provider = object.__new__(Mem0Provider)
provider._validate_credentials({"async_mode": True, "log_level": "INFO"})

assert captured["ensure_bg_loop_called"] is True
assert captured["loop"] is fake_loop
assert fake_future.result.call_count == 1
57 changes: 57 additions & 0 deletions tests/unit/utils/test_async_memory_init_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from __future__ import annotations

from unittest.mock import MagicMock

import pytest

from utils.mem0_client import AsyncMem0Client


@pytest.mark.asyncio
async def test_create_supports_async_from_config(monkeypatch: pytest.MonkeyPatch) -> None:
import utils.mem0_client as mem0_client

fake_memory = MagicMock()
fake_memory.llm = None

monkeypatch.setattr(mem0_client, "build_local_mem0_config", lambda _c: {})

async def _fake_from_config(config: dict[str, object]) -> object:
assert config == {}
return fake_memory

monkeypatch.setattr(mem0_client.AsyncMemory, "from_config", _fake_from_config)

client = AsyncMem0Client({})

try:
created = await client.create()
assert created is fake_memory
assert client.memory is fake_memory
finally:
await client.aclose()


@pytest.mark.asyncio
async def test_create_supports_sync_from_config(monkeypatch: pytest.MonkeyPatch) -> None:
import utils.mem0_client as mem0_client

fake_memory = MagicMock()
fake_memory.llm = None

monkeypatch.setattr(mem0_client, "build_local_mem0_config", lambda _c: {})

def _fake_from_config(config: dict[str, object]) -> object:
assert config == {}
return fake_memory

monkeypatch.setattr(mem0_client.AsyncMemory, "from_config", _fake_from_config)

client = AsyncMem0Client({})

try:
created = await client.create()
assert created is fake_memory
assert client.memory is fake_memory
finally:
await client.aclose()
16 changes: 15 additions & 1 deletion utils/mem0_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,20 @@
_mem0_init_lock = threading.Lock()


async def _resolve_async_memory_from_config(config: dict[str, Any]) -> AsyncMemory:
"""Support both old and new mem0 AsyncMemory.from_config semantics.

Older mem0 releases exposed ``AsyncMemory.from_config`` as an async
classmethod, while newer releases return an ``AsyncMemory`` instance
directly. Dify's async validation path always calls ``create()``, so we
normalize both forms here and keep the rest of the client code unchanged.
"""
memory_or_awaitable = AsyncMemory.from_config(config)
if asyncio.iscoroutine(memory_or_awaitable):
return await memory_or_awaitable
return memory_or_awaitable


def _patch_llm_compat(llm: Any) -> None:
"""Patch LLM instances that lack _parse_response (e.g., structured providers)."""
if llm is None or hasattr(llm, "_parse_response"):
Expand Down Expand Up @@ -834,7 +848,7 @@ async def create(self) -> AsyncMemory:
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, _mem0_init_lock.acquire)
try:
self.memory = await AsyncMemory.from_config(self.config)
self.memory = await _resolve_async_memory_from_config(self.config)
finally:
_mem0_init_lock.release()
_patch_llm_compat(getattr(self.memory, "llm", None))
Expand Down
4 changes: 2 additions & 2 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading