From f7f7328d03f8aeb65559c1559f9fbee1ad98a60a Mon Sep 17 00:00:00 2001 From: arun1729 Date: Mon, 13 Apr 2026 11:39:01 -0400 Subject: [PATCH] Adding cloud client, some refactoring and adding Graph.ls() --- .github/workflows/python-tests.yml | 28 +- .gitignore | 2 + cog/cloud_client.py | 318 ++++++++++ cog/config.py | 4 + cog/core.py | 6 +- cog/embedding_providers.py | 1 - cog/embeddings.py | 458 ++++++++++++++ cog/export.py | 13 +- cog/search.py | 197 ++++++ cog/torque.py | 923 ++++++++++------------------- cog/view.py | 29 +- pytest.ini | 3 + scripts/local_wheel_server.py | 19 +- setup.py | 2 +- test/test_batch_mode.py | 5 +- test/test_cloud.py | 440 ++++++++++++++ test/test_cloud_parity.py | 622 +++++++++++++++++++ test/test_db_2.py | 9 +- test/test_loopback_parity.py | 534 +++++++++++++++++ test/test_ls_use.py | 200 +++++++ test/test_torque2.py | 3 +- 21 files changed, 3170 insertions(+), 646 deletions(-) create mode 100644 cog/cloud_client.py create mode 100644 cog/embeddings.py create mode 100644 cog/search.py create mode 100644 pytest.ini create mode 100644 test/test_cloud.py create mode 100644 test/test_cloud_parity.py create mode 100644 test/test_loopback_parity.py create mode 100644 test/test_ls_use.py diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index 073e065..c44f64f 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -30,7 +30,7 @@ jobs: - name: Run tests with coverage run: | - python -m pytest test/ -v --ignore=test/bench.py --ignore=test/benchmark.py --cov=cog --cov-report=xml --junitxml=junit.xml -o junit_family=legacy + python -m pytest test/ -v --ignore=test/bench.py --ignore=test/benchmark.py -m "not cloud" --cov=cog --cov-report=xml --junitxml=junit.xml -o junit_family=legacy - name: Upload coverage to Codecov if: matrix.python-version == '3.12' @@ -54,3 +54,29 @@ jobs: run: | python test/benchmark.py --quick --skip-individual continue-on-error: true + + cloud-parity: 
+ runs-on: ubuntu-latest + if: github.event_name == 'push' # only on merge, not on every PR + needs: test # run after unit tests pass + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python 3.12 + uses: actions/setup-python@v5 + with: + python-version: '3.12' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + pip install pytest + + - name: Run cloud parity tests + if: env.COGDB_API_KEY != '' + env: + COGDB_API_KEY: ${{ secrets.COGDB_API_KEY }} + run: | + python -m pytest test/test_cloud_parity.py -v diff --git a/.gitignore b/.gitignore index 806798f..4a2293a 100644 --- a/.gitignore +++ b/.gitignore @@ -12,5 +12,7 @@ test/.coverage PRIVATE_NOTES.md FEATURE_ROADMAP.md RELEASE_v3.3.0.md +private/ +scripts/* docs/architecture.md .vscode/settings.json diff --git a/cog/cloud_client.py b/cog/cloud_client.py new file mode 100644 index 0000000..74f37f0 --- /dev/null +++ b/cog/cloud_client.py @@ -0,0 +1,318 @@ +""" +HTTP transport for CogDB Cloud. +""" + +import json +import urllib.request +import urllib.error +from . 
import config as cfg + + +class CloudClient: + """Authenticated HTTP client for a single CogDB Cloud graph.""" + + MAX_BATCH_SIZE = 500 # server-side limit per request + + def __init__(self, graph_name, api_key, flush_interval=1): + self._graph_name = graph_name + self._api_key = api_key + self._base_url = f"{cfg.CLOUD_URL}{cfg.CLOUD_API_PREFIX}/{graph_name}" + self._account_url = f"{cfg.CLOUD_URL}{cfg.CLOUD_API_PREFIX}/_cog_sys__" + self._flush_interval = flush_interval + self._pending = [] # buffered mutations awaiting flush + + def _request(self, method, path, body=None): + """Make an authenticated request to a graph-scoped endpoint.""" + return self._do_request(method, f"{self._base_url}{path}", body) + + def _account_request(self, method, path, body=None): + """Make an authenticated request to an account-scoped endpoint.""" + return self._do_request(method, f"{self._account_url}{path}", body) + + def _do_request(self, method, full_url, body=None): + """Shared HTTP logic for all authenticated requests.""" + data = json.dumps(body).encode("utf-8") if body else None + req = urllib.request.Request(full_url, data=data, method=method) + req.add_header("Authorization", self._api_key) + req.add_header("Content-Type", "application/json") + req.add_header("User-Agent", "cogdb-python") + + try: + with urllib.request.urlopen(req) as resp: + return json.loads(resp.read().decode("utf-8")) + except urllib.error.HTTPError as e: + if e.code in (401, 403): + raise PermissionError("Invalid API key") + try: + detail = json.loads(e.read().decode("utf-8")).get("detail", "") + except Exception: + detail = "" + if e.code in (400, 422): + raise ValueError(detail or f"Bad request ({e.code})") + raise RuntimeError( + f"CogDB Cloud error ({e.code})" + (f": {detail}" if detail else "") + ) + except urllib.error.URLError as e: + raise ConnectionError( + f"Cannot reach CogDB Cloud at {cfg.CLOUD_URL}: {e.reason}" + ) + + def _mutate_batch(self, mutations): + """Send mutations via the batch 
endpoint, chunking at MAX_BATCH_SIZE.""" + total_count = 0 + for i in range(0, len(mutations), self.MAX_BATCH_SIZE): + chunk = mutations[i:i + self.MAX_BATCH_SIZE] + result = self._request("POST", "/mutate_batch", { + "mutations": chunk, + }) + total_count += result.get("count", len(chunk)) + return {"ok": True, "count": total_count} + + def _mutate_one(self, mutation): + """Send a single mutation immediately (bypasses buffer).""" + return self._mutate_batch([mutation]) + + def _enqueue(self, mutation): + """Buffer a mutation; auto-flush when flush_interval threshold is reached.""" + self._pending.append(mutation) + if self._flush_interval > 0 and len(self._pending) >= self._flush_interval: + self.sync() + + def sync(self): + """Flush all pending mutations to cloud.""" + if not self._pending: + return + self._mutate_batch(list(self._pending)) + self._pending.clear() + + def mutate_put(self, subject, predicate, obj, update=False, create_new_edge=False): + self._enqueue({ + "op": "PUT", "s": str(subject), "p": str(predicate), "o": str(obj), + "update": update, "create_new_edge": create_new_edge, + }) + + def mutate_put_batch(self, triples): + """triples: list of {"s": ..., "p": ..., "o": ...} dicts.""" + self.sync() # flush pending before direct batch send + mutations = [ + {"op": "PUT", "s": t["s"], "p": t["p"], "o": t["o"]} + for t in triples + ] + return self._mutate_batch(mutations) + + def mutate_delete(self, subject, predicate, obj): + self._enqueue({ + "op": "DELETE", "s": str(subject), "p": str(predicate), "o": str(obj), + }) + + def mutate_drop(self): + self.sync() # flush pending before destructive operation + return self._mutate_one({"op": "DROP"}) + + def mutate_truncate(self): + self.sync() # flush pending before destructive operation + return self._mutate_one({"op": "TRUNCATE"}) + + def mutate_put_embedding(self, word, embedding): + return self._mutate_one({ + "op": "PUT_EMBEDDING", "word": word, "embedding": embedding, + }) + + def 
mutate_delete_embedding(self, word): + return self._mutate_one({ + "op": "DELETE_EMBEDDING", "word": word, + }) + + def mutate_put_embeddings_batch(self, embeddings): + """embeddings: list of {"word": ..., "embedding": ...} dicts.""" + mutations = [ + {"op": "PUT_EMBEDDING", "word": e["word"], "embedding": e["embedding"]} + for e in embeddings + ] + return self._mutate_batch(mutations) + + def mutate_vectorize(self, words, provider, batch_size): + return self._mutate_one({ + "op": "VECTORIZE", "words": words, "provider": provider, + "batch_size": batch_size, + }) + + @staticmethod + def _quote(value): + """Quote a string value for the query string, escaping internal quotes and backslashes.""" + escaped = str(value).replace('\\', '\\\\').replace('"', '\\"') + return f'"{escaped}"' + + @classmethod + def _chain_to_query_string(cls, chain): + """Convert a list of chain steps into a query string. + + Each step is a dict with 'method' and optional 'args'. + Example chain: + [{"method": "v", "args": {"vertex": "alice"}}, + {"method": "out", "args": {"predicates": ["knows"]}}, + {"method": "all"}] + Result: v("alice").out("knows").all() + """ + parts = [] + for step in chain: + method = step["method"] + args = step.get("args", {}) + param_str = cls._serialize_step(method, args) + parts.append(f"{method}({param_str})") + return ".".join(parts) + + @classmethod + def _serialize_step(cls, method, args): + """Serialize a step's args into its parameter string.""" + if not args: + return "" + + if method == "v": + vertex = args.get("vertex") + if vertex is None: + return "" + if isinstance(vertex, list): + items = ", ".join(cls._quote(v) for v in vertex) + return f"[{items}]" + return cls._quote(vertex) + + if method in ("out", "inc", "both"): + predicates = args.get("predicates") + if not predicates: + return "" + if len(predicates) == 1: + return cls._quote(predicates[0]) + items = ", ".join(cls._quote(p) for p in predicates) + return f"[{items}]" + + if method in ("has", 
"hasr"): + predicates = args.get("predicates", []) + vertex = args.get("vertex", "") + if predicates and len(predicates) == 1: + return f'{cls._quote(predicates[0])}, {cls._quote(vertex)}' + if predicates: + items = ", ".join(cls._quote(p) for p in predicates) + return f'[{items}], {cls._quote(vertex)}' + return cls._quote(vertex) + + if method == "is_": + nodes = args.get("nodes", []) + items = ", ".join(cls._quote(n) for n in nodes) + return items + + if method == "tag": + names = args.get("tag_names", []) + return ", ".join(cls._quote(n) for n in names) + + if method == "back": + return cls._quote(args.get("tag", "")) + + if method in ("limit", "skip"): + return str(args.get("n", "")) + + if method == "order": + return cls._quote(args.get("direction", "asc")) + + if method == "scan": + parts = [] + if "limit" in args: + parts.append(str(args["limit"])) + if "scan_type" in args: + parts.append(cls._quote(args["scan_type"])) + return ", ".join(parts) + + if method == "all": + options = args.get("options") + if options: + return cls._quote(options) + return "" + + if method in ("bfs", "dfs"): + parts = [] + predicates = args.get("predicates") + if predicates: + if len(predicates) == 1: + parts.append(cls._quote(predicates[0])) + else: + items = ", ".join(cls._quote(p) for p in predicates) + parts.append(f"[{items}]") + if args.get("max_depth") is not None: + parts.append(str(args["max_depth"])) + min_depth = args.get("min_depth") + direction = args.get("direction") + unique = args.get("unique") + # Emit min_depth whenever a later positional arg is non-default + has_later = ((direction is not None and direction != "out") + or (unique is not None and unique is not True)) + if (min_depth is not None and min_depth != 0) or has_later: + parts.append(str(min_depth or 0)) + if direction is not None and direction != "out": + parts.append(cls._quote(direction)) + if unique is not None and unique is not True: + parts.append(str(unique).lower()) + return ", ".join(parts) + + 
if method == "sim": + parts = [cls._quote(args.get("word", ""))] + if args.get("operator"): + parts.append(cls._quote(args["operator"])) + if args.get("threshold") is not None: + parts.append(str(args["threshold"])) + if args.get("strict"): + parts.append("true") + return ", ".join(parts) + + if method == "k_nearest": + parts = [cls._quote(args.get("word", ""))] + if args.get("k") is not None: + parts.append(str(args["k"])) + return ", ".join(parts) + + # Fallback: serialize any remaining simple args + return ", ".join( + cls._quote(v) if isinstance(v, str) else str(v) + for v in args.values() + ) + + def query_chain(self, chain): + self.sync() # flush pending for read-your-writes + q = self._chain_to_query_string(chain) + return self._request("POST", "/query", {"q": q}) + + def query_scan(self, limit, scan_type): + self.sync() + q = f'scan({limit}, {self._quote(scan_type)})' + return self._request("POST", "/query", {"q": q}) + + def query_triples(self): + self.sync() + return self._request("POST", "/query", {"q": "triples()"}) + + def query_get_embedding(self, word): + self.sync() + return self._request("POST", "/query", { + "q": f'get_embedding({self._quote(word)})', + }) + + def query_scan_embeddings(self, limit): + self.sync() + return self._request("POST", "/query", { + "q": f'scan_embeddings({limit})', + }) + + def query_embedding_stats(self): + self.sync() + return self._request("POST", "/query", {"q": "embedding_stats()"}) + + + def list_graphs(self): + """List all graphs accessible by this API key. + + Returns: + list[str]: Sorted list of graph names. 
+ """ + data = self._account_request("POST", "", {"fn": "ls"}) + graphs = data.get("graphs", data) + return sorted(graphs) if isinstance(graphs, list) else graphs + diff --git a/cog/config.py b/cog/config.py index 2c0c639..68bbc95 100644 --- a/cog/config.py +++ b/cog/config.py @@ -104,6 +104,10 @@ def cog_store(self, db_name, table_name, instance_id): ''' VECTORIZE ''' COGDB_EMBED_URL = "https://vectors.cogdb.io/embed" +''' CLOUD ''' +CLOUD_URL = "https://api.cogdb.io" +CLOUD_API_PREFIX = "/api/v1" + def cog_db_path(): if CUSTOM_COG_DB_PATH: diff --git a/cog/core.py b/cog/core.py index c163152..3e7648d 100644 --- a/cog/core.py +++ b/cog/core.py @@ -189,13 +189,13 @@ def __init__(self, table_meta, config, logger, index_id=0): f = open(self.name, 'wb+') i = 0 e_blocks = [] - while i < config.INDEX_CAPACITY: + while i < self.config.INDEX_CAPACITY: e_blocks.append(self.empty_block) i += 1 f.write(b''.join(e_blocks)) self.file_limit = f.tell() f.close() - self.logger.info("new index with capacity" + str(config.INDEX_CAPACITY) + "created: " + self.name) + self.logger.info("new index with capacity" + str(self.config.INDEX_CAPACITY) + "created: " + self.name) else: self.logger.info("Index: "+self.name+" already exists.") @@ -695,7 +695,7 @@ def __init__(self, tablemeta, config, logger): self.load_indexes() # if no index currenlty exist, create new live index. if len(self.index_list) == 0: - self.index_list.append(Index(tablemeta, config, logger, self.index_id)) + self.index_list.append(Index(tablemeta, self.config, logger, self.index_id)) self.live_index = self.index_list[self.index_id] def close(self): diff --git a/cog/embedding_providers.py b/cog/embedding_providers.py index e21cda0..ad138fe 100644 --- a/cog/embedding_providers.py +++ b/cog/embedding_providers.py @@ -2,7 +2,6 @@ import json import ssl import urllib.request -import urllib.error import logging from . 
import config as cfg diff --git a/cog/embeddings.py b/cog/embeddings.py new file mode 100644 index 0000000..bdc3a83 --- /dev/null +++ b/cog/embeddings.py @@ -0,0 +1,458 @@ +import array +import math +import heapq +import logging +from math import isclose + +from cog.core import Record +from cog.embedding_providers import EMBEDDING_PROVIDERS, _chunked + +# Optional simsimd for SIMD-optimized similarity +try: + import simsimd + _HAS_SIMSIMD = True +except ImportError: + _HAS_SIMSIMD = False + +logger = logging.getLogger(__name__) + + +class EmbeddingMixin: + """Mixin providing embedding/vector methods for Graph.""" + + def put_embedding(self, word, embedding): + """ + Saves a word embedding. + """ + assert isinstance(word, str), "word must be a string" + if self._cloud: + self._cloud_client.mutate_put_embedding(word, embedding) + return + self.cog.use_namespace(self.graph_name).use_table(self.config.EMBEDDING_SET_TABLE_NAME).put(Record( + word, embedding)) + + def get_embedding(self, word): + """ + Returns a word embedding. + """ + assert isinstance(word, str), "word must be a string" + if self._cloud: + result = self._cloud_client.query_get_embedding(word) + return result.get("embedding") + record = self.cog.use_namespace(self.graph_name).use_table(self.config.EMBEDDING_SET_TABLE_NAME).get( + word) + if record is None: + return None + return record.value + + def delete_embedding(self, word): + """ + Deletes a word embedding. + """ + assert isinstance(word, str), "word must be a string" + if self._cloud: + self._cloud_client.mutate_delete_embedding(word) + return + self.cog.use_namespace(self.graph_name).use_table(self.config.EMBEDDING_SET_TABLE_NAME).delete( + word) + + def put_embeddings_batch(self, word_embedding_pairs): + """ + Bulk insert multiple embeddings efficiently. 
+ + :param word_embedding_pairs: List of (word, embedding) tuples + :return: self for method chaining + + Example: + g.put_embeddings_batch([ + ("apple", [0.1, 0.2, ...]), + ("orange", [0.3, 0.4, ...]), + ]) + """ + if self._cloud: + batch = [{"word": w, "embedding": e} for w, e in word_embedding_pairs] + self._cloud_client.mutate_put_embeddings_batch(batch) + return self + self.cog.use_namespace(self.graph_name) + self.cog.begin_batch() + try: + for word, embedding in word_embedding_pairs: + if not isinstance(word, str): + raise TypeError("word must be a string") + self.cog.use_table(self.config.EMBEDDING_SET_TABLE_NAME).put(Record( + word, embedding)) + finally: + self.cog.end_batch() + return self + + def scan_embeddings(self, limit=100): + """ + Scan and return a list of words that have embeddings stored. + + :param limit: Maximum number of embeddings to return + :return: Dictionary with 'result' containing list of words with embeddings + + Note: This scans the graph vertices and checks which have embeddings. + """ + if self._cloud: + result = self._cloud_client.query_scan_embeddings(limit) + result.pop("ok", None) + return result + result = [] + self.cog.use_namespace(self.graph_name).use_table(self.config.GRAPH_NODE_SET_TABLE_NAME) + count = 0 + for r in self.cog.scanner(): + if count >= limit: + break + word = r.key + if self.get_embedding(word) is not None: + result.append({"id": word}) + count += 1 + return {"result": result} + + def embedding_stats(self): + """ + Return statistics about stored embeddings. 
+ + :return: Dictionary with count and dimensions (if available) + """ + if self._cloud: + result = self._cloud_client.query_embedding_stats() + result.pop("ok", None) + return result + count = 0 + dimensions = None + # Scan the embedding table directly + self.cog.use_namespace(self.graph_name).use_table(self.config.EMBEDDING_SET_TABLE_NAME) + for r in self.cog.scanner(): + count += 1 + if dimensions is None and r.value is not None: + dimensions = len(r.value) + return {"count": count, "dimensions": dimensions} + + def k_nearest(self, word, k=10): + """ + Find the k vertices most similar to the given word based on embeddings. + + :param word: The word to find similar vertices for + :param k: Number of nearest neighbors to return (default 10) + :return: self for method chaining + + Example: + g.v().k_nearest("machine_learning", k=5).all() + """ + if self._cloud: + return self._cloud_append("k_nearest", word=word, k=k) + # Auto-embed query word if missing + self._auto_embed(word) + + target_embedding = self.get_embedding(word) + if target_embedding is None: + self.last_visited_vertices = [] + return self + + # simsimd/fallback requires buffer protocol (e.g. numpy array or python array) + target_vec = array.array('f', target_embedding) + similarities = [] + + # None = no prior traversal, scan entire embedding table + # [] = prior traversal returned empty, preserve empty semantics + # [...] 
= search within visited vertices + if self.last_visited_vertices is None: + # Scan embedding table directly for all embeddings + self.cog.use_namespace(self.graph_name).use_table(self.config.EMBEDDING_SET_TABLE_NAME) + for r in self.cog.scanner(): + if r.value is not None: + v_vec = array.array('f', r.value) + distance = self._cosine_distance(target_vec, v_vec) + similarity = 1.0 - float(distance) + from cog.torque import Vertex + similarities.append((similarity, Vertex(r.key))) + elif self.last_visited_vertices: + # Search within visited vertices + for v in self.last_visited_vertices: + v_embedding = self.get_embedding(v.id) + if v_embedding is not None: + v_vec = array.array('f', v_embedding) + distance = self._cosine_distance(target_vec, v_vec) + similarity = 1.0 - float(distance) + similarities.append((similarity, v)) + # else: empty list, similarities stays empty + + # Get top k using heap for efficiency + top_k = heapq.nlargest(k, similarities, key=lambda x: x[0]) + self.last_visited_vertices = [v for _, v in top_k] + return self + + def sim(self, word, operator, threshold, strict=False): + """ + Applies cosine similarity filter to the vertices and removes any vertices that do not pass the filter. + + Parameters: + ----------- + word: str + The word to compare to the other vertices. + operator: str + The comparison operator to use. One of "==", ">", "<", ">=", "<=", or "in". + threshold: float or list of 2 floats + The threshold value(s) to use for the comparison. If operator is "==", ">", "<", ">=", or "<=", threshold should be a float. If operator is "in", threshold should be a list of 2 floats. + strict: bool, optional + If True, raises an exception if a word embedding is not found for either word. If False, assigns a similarity of 0.0 to any word embedding that is not found. + + Returns: + -------- + self: GraphTraversal + Returns self to allow for method chaining. 
+ + Raises: + ------- + ValueError: + If operator is not a valid comparison operator or if threshold is not a valid threshold value for the given operator. + If strict is True and a word embedding is not found for either word. + """ + if self._cloud: + return self._cloud_append("sim", word=word, operator=operator, + threshold=threshold, strict=strict) + if not isinstance(threshold, (float, int, list)): + raise ValueError("Invalid threshold value: {}".format(threshold)) + + if operator == 'in': + if not isinstance(threshold, list) or len(threshold) != 2: + raise ValueError("Invalid threshold value: {}".format(threshold)) + if not all(isinstance(t, (float, int)) for t in threshold): + raise ValueError("Invalid threshold value: {}".format(threshold)) + + # Auto-embed query word if missing + self._auto_embed(word) + + filtered_vertices = [] + for v in self.last_visited_vertices: + similarity = self._cosine_similarity(word, v.id) + if not similarity: + # similarity is None if a word embedding is not found for either word. + if strict: + raise ValueError("Missing word embedding for either '{}' or '{}'".format(word, v.id)) + else: + # Treat vertices without word embeddings as if they have no similarity to any other vertex. 
+ similarity = 0.0 + if operator == '=': + if isclose(similarity, threshold): + filtered_vertices.append(v) + elif operator == '>': + if similarity > threshold: + filtered_vertices.append(v) + elif operator == '<': + if similarity < threshold: + filtered_vertices.append(v) + elif operator == '>=': + if similarity >= threshold: + filtered_vertices.append(v) + elif operator == '<=': + if similarity <= threshold: + filtered_vertices.append(v) + elif operator == 'in': + if not threshold[0] <= similarity <= threshold[1]: + continue + filtered_vertices.append(v) + else: + raise ValueError("Invalid operator: {}".format(operator)) + self.last_visited_vertices = filtered_vertices + return self + + def _cosine_distance(self, x, y): + """Compute cosine distance (1 - similarity) with simsimd or pure Python fallback.""" + if _HAS_SIMSIMD: + return simsimd.cosine(x, y) + else: + # Pure Python fallback for Pyodide/environments without simsimd + dot = sum(a * b for a, b in zip(x, y)) + norm_x = math.sqrt(sum(a * a for a in x)) + norm_y = math.sqrt(sum(b * b for b in y)) + if norm_x == 0 or norm_y == 0: + return 1.0 # Max distance if either vector is zero + return 1.0 - (dot / (norm_x * norm_y)) + + def _cosine_similarity(self, word1, word2): + """Compute cosine similarity using SIMD-optimized simsimd library or pure Python fallback.""" + x_list = self.get_embedding(word1) + y_list = self.get_embedding(word2) + + if x_list is None or y_list is None: + return None + + # Use python array for buffer protocol (compatible with simsimd) + x = array.array('f', x_list) + y = array.array('f', y_list) + + # cosine distance = 1 - similarity, so we convert + distance = self._cosine_distance(x, y) + return 1.0 - float(distance) + + def load_glove(self, filepath, limit=None, batch_size=1000): + """ + Load GloVe embeddings from a text file. 
+ + :param filepath: Path to GloVe file (e.g., 'glove.6B.100d.txt') + :param limit: Maximum number of embeddings to load (None for all) + :param batch_size: Number of embeddings to batch before writing (default 1000) + :return: Number of embeddings loaded + + Example: + count = g.load_glove("glove.6B.100d.txt", limit=50000) + """ + count = 0 + batch = [] + + with open(filepath, 'r', encoding='utf-8') as f: + for line in f: + if limit is not None and count >= limit: + break + parts = line.strip().split() + if len(parts) < 2: + continue + word = parts[0] + embedding = [float(x) for x in parts[1:]] + batch.append((word, embedding)) + count += 1 + + if len(batch) >= batch_size: + self.put_embeddings_batch(batch) + batch = [] + + # Load remaining batch + if batch: + self.put_embeddings_batch(batch) + + return count + + def load_gensim(self, model, limit=None, batch_size=1000): + """ + Load embeddings from a Gensim Word2Vec or FastText model. + + :param model: A Gensim model with a 'wv' attribute (Word2Vec, FastText) + :param limit: Maximum number of embeddings to load (None for all) + :param batch_size: Number of embeddings to batch before writing (default 1000) + :return: Number of embeddings loaded + + Example: + from gensim.models import Word2Vec + model = Word2Vec(sentences) + count = g.load_gensim(model) + """ + count = 0 + batch = [] + + # Get word vectors from model + if hasattr(model, 'wv'): + wv = model.wv + else: + wv = model # Already a KeyedVectors object + + for word in wv.index_to_key: + if limit is not None and count >= limit: + break + embedding = wv[word].tolist() + batch.append((word, embedding)) + count += 1 + + if len(batch) >= batch_size: + self.put_embeddings_batch(batch) + batch = [] + + if batch: + self.put_embeddings_batch(batch) + + return count + + def _auto_embed(self, word): + """Auto-fetch and store embedding for a word if missing. 
+ Only active after vectorize() has been explicitly called.""" + if not self._vectorize_configured: + return + if self.get_embedding(word) is not None: + return + try: + provider_fn = EMBEDDING_PROVIDERS[self._default_provider] + pairs = provider_fn([word], **self._default_provider_kwargs) + if pairs: + self.put_embeddings_batch(pairs) + except Exception as e: + self.logger.debug("auto-embed for '{}' failed: {}".format(word, e)) + + def vectorize(self, words=None, provider="cogdb", batch_size=100, **kwargs): + """ + Auto-generate vector embeddings using a provider. + + Can embed all graph nodes, a single word, or a list of words. + Skips words that already have embeddings. + + :param words: Optional — a string or list of strings to embed. + If None, embeds all nodes in the graph. + :param provider: Provider name — "cogdb" (default), "openai", or "custom". + :param batch_size: Number of words per provider request (default 100). + :param kwargs: Passed to the provider (e.g. url=, api_key=, model=). + :return: Summary dict {"vectorized": N, "skipped": M, "total": T} + + Example: + g.vectorize() # all nodes + g.vectorize("europa") # single word + g.vectorize(["europa", "ocean"]) # specific words + g.vectorize(provider="openai", api_key="sk-...") + """ + if self._cloud: + w = words + if isinstance(w, str): + w = [w] + return self._cloud_client.mutate_vectorize(w, provider, batch_size) + if not isinstance(batch_size, int) or batch_size < 1: + raise ValueError("batch_size must be a positive integer, got: {}".format(batch_size)) + + if provider not in EMBEDDING_PROVIDERS: + raise ValueError("Unknown provider '{}'. 
Choose from: {}".format( + provider, ", ".join(EMBEDDING_PROVIDERS.keys()))) + + # Store provider config for auto-embed in queries + self._default_provider = provider + self._default_provider_kwargs = kwargs + self._vectorize_configured = True + + provider_fn = EMBEDDING_PROVIDERS[provider] + + # Determine which words to embed + if words is not None: + # Explicit word(s) + if isinstance(words, str): + words = [words] + all_words = words + else: + # All graph nodes + all_words = [] + self.cog.use_namespace(self.graph_name).use_table(self.config.GRAPH_NODE_SET_TABLE_NAME) + for r in self.cog.scanner(): + all_words.append(r.key) + + total = len(all_words) + + # Skip words that already have embeddings + to_embed = [w for w in all_words if self.get_embedding(w) is None] + skipped = total - len(to_embed) + + if not to_embed: + return {"vectorized": 0, "skipped": skipped, "total": total} + + # Send to provider in batches and store results + vectorized = 0 + errors = [] + for chunk in _chunked(to_embed, batch_size): + try: + pairs = provider_fn(chunk, **kwargs) + self.put_embeddings_batch(pairs) + vectorized += len(pairs) + except Exception as e: + self.logger.error("vectorize batch failed: {}".format(e)) + errors.append(str(e)) + + result = {"vectorized": vectorized, "skipped": skipped, "total": total} + if errors: + result["errors"] = errors + return result diff --git a/cog/export.py b/cog/export.py index 2ab3bcc..ad1eadf 100644 --- a/cog/export.py +++ b/cog/export.py @@ -82,19 +82,21 @@ def get_triples(graph): yield (vertex, predicate_name, obj) -def export_triples(graph, filepath, fmt="nt", strict=False): +def export_triples(graph, filepath, fmt="nt", strict=False, triples_iter=None): """ Export all triples in the graph to a file. Writes one triple per line in the specified format. - :param graph: A Graph instance. + :param graph: A Graph instance (unused if triples_iter is provided). :param filepath: Path to the output file. 
:param fmt: Format string — "nt" (N-Triples, default), "csv", or "tsv". :param strict: If True and fmt is "nt", output W3C-compliant N-Triples where IRIs are wrapped in <>, blank nodes use _: prefix, and plain literals are quoted with "". See https://www.w3.org/TR/n-triples/ + :param triples_iter: Optional iterable of (s, p, o) tuples. If provided, + used instead of scanning the graph. :return: Number of triples written. Example: @@ -105,18 +107,19 @@ def export_triples(graph, filepath, fmt="nt", strict=False): """ fmt = fmt.lower() count = 0 + triples = triples_iter if triples_iter is not None else get_triples(graph) with open(filepath, 'w', newline='') as f: if fmt in ("csv", "tsv"): delimiter = '\t' if fmt == "tsv" else ',' writer = csv_module.writer(f, delimiter=delimiter) writer.writerow(["subject", "predicate", "object"]) - for s, p, o in get_triples(graph): + for s, p, o in triples: writer.writerow([s, p, o]) count += 1 elif fmt == "nt": if strict: - for s, p, o in get_triples(graph): + for s, p, o in triples: f.write('{} {} {} .\n'.format( _to_nt_term(s, "subject"), _to_nt_term(p, "predicate"), @@ -124,7 +127,7 @@ def export_triples(graph, filepath, fmt="nt", strict=False): )) count += 1 else: - for s, p, o in get_triples(graph): + for s, p, o in triples: f.write('{} {} {} .\n'.format(s, p, o)) count += 1 else: diff --git a/cog/search.py b/cog/search.py new file mode 100644 index 0000000..9bf0696 --- /dev/null +++ b/cog/search.py @@ -0,0 +1,197 @@ +from collections import deque +import logging + +from cog.database import hash_predicate + +logger = logging.getLogger(__name__) + + +class TraversalMixin: + """Mixin providing BFS/DFS traversal methods for Graph.""" + + def __get_adjacent(self, vertex, predicates, direction): + """Get adjacent vertices based on direction: 'out', 'inc', or 'both'.""" + adjacent = [] + if direction in ("out", "both"): + adjacent.extend(self._Graph__adjacent_vertices(vertex, predicates, 'out')) + if direction in ("inc", "both"): + 
adjacent.extend(self._Graph__adjacent_vertices(vertex, predicates, 'in')) + return adjacent + + def bfs(self, predicates=None, max_depth=None, min_depth=0, + direction="out", until=None, unique=True): + """ + Traverse the graph breadth-first from current vertices. + + BFS explores level-by-level, visiting all neighbors at the current depth + before moving deeper. Guarantees shortest path in unweighted graphs. + + :param predicates: Edge type(s) to follow: str, list, or None (all edges) + :param max_depth: Maximum traversal depth (None = unlimited) + :param min_depth: Minimum depth to include in results (default 0) + :param direction: Traversal direction: "out", "inc", or "both" + :param until: Stop condition lambda: func(vertex_id) -> bool + :param unique: If True, visit each vertex only once (prevents cycles) + :return: self for method chaining + + Example: + g.v("alice").bfs(predicates="follows", max_depth=2).all() + g.v("alice").bfs(max_depth=3, min_depth=2).all() # depths 2-3 only + g.v("alice").bfs(until=lambda v: v == "target").all() + """ + if self._cloud: + if until is not None: + raise RuntimeError("bfs() with an 'until' lambda is not supported in cloud mode.") + p = predicates + if p is not None and not isinstance(p, list): + p = [p] + return self._cloud_append("bfs", predicates=p, max_depth=max_depth, + min_depth=min_depth, direction=direction, unique=unique) + + # Normalize predicates + if predicates is not None: + if not isinstance(predicates, list): + predicates = [predicates] + predicates = list(map(hash_predicate, predicates)) + else: + predicates = self.all_predicates + + from cog.torque import Vertex + result_vertices = [] + visited = set() + queue = deque() # (vertex, depth) + + # Initialize with current vertices at depth 0 + for v in self.last_visited_vertices: + queue.append((v, 0)) + if unique: + visited.add(v.id) + + while queue: + current, depth = queue.popleft() + + if until and until(current.id): + if depth >= min_depth: + result_vertex = 
Vertex(current.id) + result_vertex.tags = current.tags.copy() + result_vertex.edges = current.edges.copy() + result_vertex._path = current._path + result_vertices.append(result_vertex) + continue + + if depth > 0 and depth >= min_depth: + if max_depth is None or depth <= max_depth: + result_vertex = Vertex(current.id) + result_vertex.tags = current.tags.copy() + result_vertex.edges = current.edges.copy() + result_vertex._path = current._path + result_vertices.append(result_vertex) + + # Stop exploring if at max depth + if max_depth is not None and depth >= max_depth: + continue + + adjacent = self.__get_adjacent(current, predicates, direction) + for adj in adjacent: + if unique: + if adj.id in visited: + continue + visited.add(adj.id) + adj.tags = current.tags.copy() + # Build path for neighbor from parent's path + parent_path = current._path or [{'vertex': current.id}] + edge_hash = next(iter(adj.edges)) if adj.edges else None + edge_name = self._predicate_reverse_lookup_cache.get(edge_hash, edge_hash) if edge_hash else None + adj._path = list(parent_path) + ([{'edge': edge_name}] if edge_name else []) + [{'vertex': adj.id}] + queue.append((adj, depth + 1)) + + self.last_visited_vertices = result_vertices + return self + + def dfs(self, predicates=None, max_depth=None, min_depth=0, + direction="out", until=None, unique=True): + """ + Traverse the graph depth-first from current vertices. + + DFS explores as deep as possible along each branch before backtracking. + More memory-efficient than BFS for deep graphs. 
+ + :param predicates: Edge type(s) to follow: str, list, or None (all edges) + :param max_depth: Maximum traversal depth (None = unlimited) + :param min_depth: Minimum depth to include in results (default 0) + :param direction: Traversal direction: "out", "inc", or "both" + :param until: Stop condition lambda: func(vertex_id) -> bool + :param unique: If True, visit each vertex only once (prevents cycles) + :return: self for method chaining + + Example: + g.v("alice").dfs(predicates="follows", max_depth=3).all() + g.v("alice").dfs(direction="both", max_depth=2).all() + """ + if self._cloud: + if until is not None: + raise RuntimeError("dfs() with an 'until' lambda is not supported in cloud mode.") + p = predicates + if p is not None and not isinstance(p, list): + p = [p] + return self._cloud_append("dfs", predicates=p, max_depth=max_depth, + min_depth=min_depth, direction=direction, unique=unique) + # Normalize predicates + if predicates is not None: + if not isinstance(predicates, list): + predicates = [predicates] + predicates = list(map(hash_predicate, predicates)) + else: + predicates = self.all_predicates + + from cog.torque import Vertex + result_vertices = [] + visited = set() + stack = [] # (vertex, depth) + + # Initialize with current vertices at depth 0 + for v in self.last_visited_vertices: + stack.append((v, 0)) + if unique: + visited.add(v.id) + + while stack: + current, depth = stack.pop() # LIFO for DFS + + if until and until(current.id): + if depth >= min_depth: + result_vertex = Vertex(current.id) + result_vertex.tags = current.tags.copy() + result_vertex.edges = current.edges.copy() + result_vertex._path = current._path + result_vertices.append(result_vertex) + continue + + if depth > 0 and depth >= min_depth: + if max_depth is None or depth <= max_depth: + result_vertex = Vertex(current.id) + result_vertex.tags = current.tags.copy() + result_vertex.edges = current.edges.copy() + result_vertex._path = current._path + 
result_vertices.append(result_vertex) + + # Stop exploring if at max depth + if max_depth is not None and depth >= max_depth: + continue + + adjacent = self.__get_adjacent(current, predicates, direction) + for adj in adjacent: + if unique: + if adj.id in visited: + continue + visited.add(adj.id) + adj.tags = current.tags.copy() + # Build path for neighbor from parent's path + parent_path = current._path or [{'vertex': current.id}] + edge_hash = next(iter(adj.edges)) if adj.edges else None + edge_name = self._predicate_reverse_lookup_cache.get(edge_hash, edge_hash) if edge_hash else None + adj._path = list(parent_path) + ([{'edge': edge_name}] if edge_name else []) + [{'vertex': adj.id}] + stack.append((adj, depth + 1)) + + self.last_visited_vertices = result_vertices + return self diff --git a/cog/torque.py b/cog/torque.py index 60e80b5..4a025df 100644 --- a/cog/torque.py +++ b/cog/torque.py @@ -1,29 +1,19 @@ from cog.database import Cog from cog.database import in_nodes, out_nodes, hash_predicate, parse_tripple -from cog.core import cog_hash, Record import json import logging from . 
import config as cfg from .config import CogConfig -from cog.view import graph_template, script_part1, script_part2, graph_lib_src -from cog.embedding_providers import EMBEDDING_PROVIDERS, _chunked +from cog.view import graph_template, script_part1, script_part2, graph_lib_src, View +from cog.embeddings import EmbeddingMixin +from cog.search import TraversalMixin import os import shutil from os import listdir +from cog.cloud_client import CloudClient import time import random -from math import isclose import warnings -import heapq -import array -import math - -# Optional simsimd for SIMD-optimized similarity -try: - import simsimd - _HAS_SIMSIMD = True -except ImportError: - _HAS_SIMSIMD = False NOTAG = "NOTAG" @@ -72,34 +62,69 @@ def is_id(cls, label): return label.startswith("_:" + BlankNode.ID_PREFIX) -class Graph: +class Graph(EmbeddingMixin, TraversalMixin): """ Creates a graph object. Args: - graph_name: Name of the graph + graph_name: Name of the graph (default: "default") cog_home: Home directory name for the database cog_path_prefix: Root directory location for Cog db enable_caching: Enable in-memory caching for faster reads - flush_interval: Number of writes before auto-flush per store. + flush_interval: Number of writes before auto-flush (local and cloud). 1 = flush every write (safest, default) 0 = manual flush only (fastest, use sync()) N>1 = flush every N writes with async background threads config: Optional CogConfig instance. When provided, overrides all other config options (cog_home, cog_path_prefix) and prevents mutation of global config state. Each Graph gets its own isolated copy. + api_key: API key for CogDB Cloud. When provided (or set via + COGDB_API_KEY env var), the graph operates in cloud mode — + all operations go over HTTP and no local files are created. 
""" - def __init__(self, graph_name, cog_home="cog_home", cog_path_prefix=None, enable_caching=True, - flush_interval=1, config=None): + def __init__(self, graph_name="default", cog_home="cog_home", cog_path_prefix=None, enable_caching=True, + flush_interval=1, config=None, api_key=None): """ - :param graph_name: + :param graph_name: Name of the graph (default: "default") :param cog_home: Home directory name, for most use cases use default. :param cog_path_prefix: sets the root directory location for Cog db. Default: '/tmp' set in cog.Config. Change this to current directory when running in an IPython environment. :param flush_interval: Number of writes before auto-flush. 1 = every write (safest). :param config: Optional CogConfig instance. Overrides cog_home and cog_path_prefix when provided. + :param api_key: API key for CogDB Cloud mode. """ + + self.graph_name = graph_name + self.logger = logging.getLogger(__name__) + + # Resolve API key: explicit param > env var > None + resolved_key = api_key or os.environ.get("COGDB_API_KEY") + + if resolved_key: + # Cloud mode — all operations go over HTTP + self._cloud = True + self._api_key = resolved_key + self._flush_interval = flush_interval + self._cloud_client = CloudClient(graph_name, resolved_key, flush_interval=flush_interval) + self._cloud_chain = [] # accumulates traversal steps + self.config = cfg + self.last_visited_vertices = None + self._server_port = None + self.views_dir = None + self._predicate_reverse_lookup_cache = {} + self._default_provider = "cogdb" + self._default_provider_kwargs = {} + self._vectorize_configured = False + self.logger.debug(f"Torque cloud mode on graph: {graph_name}") + # No local storage initialized + return + + # Local mode (existing behavior, unchanged) + self._cloud = False + self._api_key = None + self._cloud_client = None + if config is not None: self.config = config else: @@ -107,15 +132,17 @@ def __init__(self, graph_name, cog_home="cog_home", cog_path_prefix=None, enable if 
cog_path_prefix: self.config.COG_PATH_PREFIX = cog_path_prefix - self.graph_name = graph_name + if config is None: + # Keep module-level globals in sync for backward compat (single-graph usage) + cfg.COG_HOME = cog_home + if cog_path_prefix: + cfg.COG_PATH_PREFIX = cog_path_prefix if enable_caching: self.cache = {} else: self.cache = None - self.logger = logging.getLogger(__name__) - self.logger.debug(f"Torque init on graph: {graph_name} (flush_interval={flush_interval})") self.cog = Cog(self.cache, flush_interval=flush_interval, config=self.config) @@ -144,6 +171,33 @@ def __init__(self, graph_name, cog_home="cog_home", cog_path_prefix=None, enable self._default_provider_kwargs = {} # Provider kwargs (e.g. api_key) self._vectorize_configured = False # True after explicit vectorize() call + # === Cloud Traversal Helpers === + + def _cloud_reset_chain(self): + """Reset the cloud traversal chain for a new query.""" + self._cloud_chain = [] + + def _cloud_append(self, method, **kwargs): + """Append a traversal step to the cloud chain.""" + step = {"method": method} + if kwargs: + step["args"] = {k: v for k, v in kwargs.items() if v is not None} + self._cloud_chain.append(step) + return self + + def _cloud_execute_chain(self, terminal_method, **kwargs): + """Send accumulated chain + terminal method to cloud and return results.""" + chain = list(self._cloud_chain) + step = {"method": terminal_method} + if kwargs: + step["args"] = {k: v for k, v in kwargs.items() if v is not None} + chain.append(step) + result = self._cloud_client.query_chain(chain) + self._cloud_reset_chain() + # Strip cloud envelope to match local response format + result.pop("ok", None) + return result + # === Network Methods === def serve(self, port=8080, host="0.0.0.0", blocking=False, writable=False, share=False): @@ -177,6 +231,11 @@ def serve(self, port=8080, host="0.0.0.0", blocking=False, writable=False, share # Share graph publicly g.serve(port=8080, share=True) """ + if self._cloud: + raise 
RuntimeError( + "g.serve() is not available in cloud mode. " + "The graph is already hosted on CogDB Cloud." + ) from cog.server import get_or_create_server if self._server_port is not None: @@ -267,17 +326,97 @@ def connect(cls, url, timeout=30): def sync(self): """ - Force flush all pending writes to disk. + Force flush all pending writes to disk (local) or cloud. Blocks until all flushes are complete. Use this when flush_interval > 1 or when you need to ensure data durability at a specific point. """ + if self._cloud: + self._cloud_client.sync() + return self.cog.sync() def refresh(self): + if self._cloud: + return # No-op in cloud mode self.cog.refresh_all() + def ls(self): + """ + List all graph names accessible from this connection. + + In cloud mode, queries the server for all graphs under this API key. + In local mode, scans the cog_home directory for graph subdirectories. + + Returns: + list[str]: Sorted list of graph names. + + Example: + g = Graph(api_key="sk-...") + print(g.ls()) # ['default', 'products', 'social'] + + g = Graph() + print(g.ls()) # ['default', 'my_graph'] + """ + if self._cloud: + return self._cloud_client.list_graphs() + + # Local mode: each graph is a subdirectory under cog_db_path + db_path = self.config.cog_db_path() + if not os.path.exists(db_path): + return [] + skip = {self.config.COG_SYS_DIR, self.config.VIEWS} + return sorted([ + d for d in os.listdir(db_path) + if os.path.isdir(os.path.join(db_path, d)) and d not in skip + ]) + + def use(self, graph_name): + """ + Switch this instance to a different graph. + + Flushes any pending writes before switching. The graph is created + if it does not already exist (same behavior as the constructor). + + Args: + graph_name: Name of the graph to switch to. + + Returns: + self for method chaining. 
+ + Example: + g = Graph(api_key="sk-...") + g.ls() # ['default', 'social'] + g.use("social").v("alice").out("knows").all() + """ + if self._cloud: + self._cloud_client.sync() # flush pending mutations + self.graph_name = graph_name + self._cloud_client = CloudClient( + graph_name, self._api_key, flush_interval=self._flush_interval + ) + self._cloud_reset_chain() + return self + + # Local mode: switch namespace + self.graph_name = graph_name + self.cog.create_or_load_namespace(graph_name) + self.cog.use_namespace(graph_name) + self.all_predicates = self.cog.list_tables() + # Rebuild predicate reverse lookup cache for the new graph + self._predicate_reverse_lookup_cache = {} + try: + for pred_hash in self.all_predicates: + edge_record = self.cog.use_table( + self.config.GRAPH_EDGE_SET_TABLE_NAME + ).get(pred_hash) + if edge_record is not None: + self._predicate_reverse_lookup_cache[pred_hash] = edge_record.value + except Exception: + pass # Edge set table may not exist yet for new graphs + return self + def updatej(self, json_object): self.put_json(json_object, True) @@ -364,6 +503,20 @@ def load_triples(self, graph_data_path, graph_name=None): :param graph_name: :return: None """ + if self._cloud: + # Read triples from file and send to cloud in batches + batch = [] + batch_size = 1000 + with open(graph_data_path) as f: + for line in f: + subject, predicate, obj, _ = parse_tripple(line) + batch.append({"s": subject, "p": predicate, "o": obj}) + if len(batch) >= batch_size: + self._cloud_client.mutate_put_batch(batch) + batch = [] + if batch: + self._cloud_client.mutate_put_batch(batch) + return None graph_name = self.graph_name if graph_name is None else graph_name self.cog.load_triples(graph_data_path, graph_name) @@ -389,6 +542,23 @@ def load_csv(self, csv_path, id_column_name, graph_name=None): if id_column_name is None: raise Exception("id_column_name must not be None") + if self._cloud: + # Read CSV locally and send triples to cloud in batches + batch = [] + 
batch_size = 1000 + with open(csv_path) as csv_file: + reader = csv.DictReader(csv_file) + for row in reader: + subject = row[id_column_name] + for col, val in row.items(): + if col != id_column_name: + batch.append({"s": subject, "p": col, "o": val}) + if len(batch) >= batch_size: + self._cloud_client.mutate_put_batch(batch) + batch = [] + if batch: + self._cloud_client.mutate_put_batch(batch) + return None graph_name = self.graph_name if graph_name is None else graph_name self.cog.load_csv(csv_path, id_column_name, graph_name) self.all_predicates = self.cog.list_tables() @@ -401,10 +571,17 @@ def load_csv(self, csv_path, id_column_name, graph_name=None): self._predicate_reverse_lookup_cache[hash_predicate(col)] = col def close(self): + if self._cloud: + self._cloud_client.sync() # flush any pending mutations + return self.logger.info("closing graph: " + self.graph_name) self.cog.close() def put(self, vertex1, predicate, vertex2, update=False, create_new_edge=False): + if self._cloud: + self._cloud_client.mutate_put(vertex1, predicate, vertex2, + update=update, create_new_edge=create_new_edge) + return self self._predicate_reverse_lookup_cache[hash_predicate(predicate)] = predicate self.cog.use_namespace(self.graph_name) if update: @@ -432,6 +609,16 @@ def put_batch(self, triples): ("charlie", "follows", "alice") ]) """ + if self._cloud: + batch = [] + for v1, pred, v2 in triples: + batch.append({"s": str(v1), "p": str(pred), "o": str(v2)}) + if len(batch) >= 1000: + self._cloud_client.mutate_put_batch(batch) + batch = [] + if batch: + self._cloud_client.mutate_put_batch(batch) + return self self.cog.use_namespace(self.graph_name) self.cog.begin_batch() try: @@ -456,6 +643,9 @@ def delete(self, vertex1, predicate, vertex2): g.put("alice", "knows", "bob") g.delete("alice", "knows", "bob") """ + if self._cloud: + self._cloud_client.mutate_delete(vertex1, predicate, vertex2) + return self self.cog.delete_edge(vertex1, predicate, vertex2) return self @@ -476,6 +666,9 
@@ def drop(self, *args): "drop(s, p, o) is deprecated. Use delete(s, p, o) for edges. " "Use drop() with no arguments to delete the entire graph." ) + if self._cloud: + self._cloud_client.mutate_drop() + return # Clear the cache if self.cache is not None: @@ -506,6 +699,9 @@ def truncate(self): g.truncate() # Graph is now empty but still usable g.put("new", "data", "here") # Works fine """ + if self._cloud: + self._cloud_client.mutate_truncate() + return self # Get the graph directory path using the same method Cog uses # This correctly handles CUSTOM_COG_DB_PATH if set graph_path = self.config.cog_data_dir(self.graph_name) @@ -544,6 +740,11 @@ def update(self, vertex1, predicate, vertex2): return self def v(self, vertex=None, func=None): + if self._cloud: + self._cloud_reset_chain() + if isinstance(vertex, list): + return self._cloud_append("v", vertex=vertex) + return self._cloud_append("v", vertex=vertex) if func: warnings.warn("The use of func is deprecated, please use filter instead.", DeprecationWarning) if vertex is not None: @@ -567,6 +768,9 @@ def out(self, predicates=None, func=None): :param predicates: A string or a List of strings. :return: self for method chaining. """ + if self._cloud: + p = predicates if isinstance(predicates, list) else ([predicates] if predicates else None) + return self._cloud_append("out", predicates=p) if func: warnings.warn("The use of func is deprecated, please use filter instead.", DeprecationWarning) @@ -590,6 +794,9 @@ def inc(self, predicates=None, func=None): :param predicates: List of predicates :return: self for method chaining. """ + if self._cloud: + p = predicates if isinstance(predicates, list) else ([predicates] if predicates else None) + return self._cloud_append("inc", predicates=p) if func: warnings.warn("The use of func is deprecated, please use filter instead.", DeprecationWarning) @@ -630,6 +837,9 @@ def has(self, predicates, vertex): :param vertex: Vertex ID :return: self for method chaining. 
""" + if self._cloud: + p = predicates if isinstance(predicates, list) else ([predicates] if predicates else None) + return self._cloud_append("has", predicates=p, vertex=vertex) if predicates is not None: if not isinstance(predicates, list): @@ -653,6 +863,9 @@ def hasr(self, predicates, vertex): :param vertex: Vertex ID :return: self for method chaining. """ + if self._cloud: + p = predicates if isinstance(predicates, list) else ([predicates] if predicates else None) + return self._cloud_append("hasr", predicates=p, vertex=vertex) if predicates is not None: if not isinstance(predicates, list): @@ -676,6 +889,10 @@ def scan(self, limit=10, scan_type='v'): :param scan_type: use 'v' to scan the vertex set or 'e' to scan the edge set :return: A dictionary containing a list of scanned item(vertex) IDs, e.g., `{'result': [{'id': '...'}]}`. """ + if self._cloud: + result = self._cloud_client.query_scan(limit, scan_type) + result.pop("ok", None) + return result assert type(scan_type) is str, "Scan type must be either 'v' for vertices or 'e' for edges." if scan_type == 'e': self.cog.use_namespace(self.graph_name).use_table(self.config.GRAPH_EDGE_SET_TABLE_NAME) @@ -739,6 +956,11 @@ def filter(self, func): """ Applies a filter function to the vertices and removes any vertices that do not pass the filter. """ + if self._cloud: + raise RuntimeError( + "filter() with a Python lambda is not supported in cloud mode. " + "Use has()/hasr()/is_() for server-side filtering." + ) self.last_visited_vertices = [v for v in self.last_visited_vertices if func(v.id)] return self @@ -748,6 +970,10 @@ def both(self, predicates=None): :param predicates: A string or list of predicate strings to follow. :return: self for method chaining. 
""" + if self._cloud: + p = predicates if isinstance(predicates, list) else ([predicates] if predicates else None) + return self._cloud_append("both", predicates=p) + if predicates is not None: if not isinstance(predicates, list): predicates = [predicates] @@ -817,6 +1043,9 @@ def is_(self, *nodes): :param nodes: One or more node IDs to filter to. :return: self for method chaining. """ + if self._cloud: + node_list = list(nodes[0]) if (len(nodes) == 1 and isinstance(nodes[0], list)) else list(nodes) + return self._cloud_append("is_", nodes=node_list) if len(nodes) == 1 and isinstance(nodes[0], list): node_set = set(nodes[0]) else: @@ -829,6 +1058,8 @@ def unique(self): Remove duplicate vertices from the result set. :return: self for method chaining. """ + if self._cloud: + return self._cloud_append("unique") seen = set() unique_vertices = [] for v in self.last_visited_vertices: @@ -844,6 +1075,8 @@ def limit(self, n): :param n: Maximum number of vertices to return. :return: self for method chaining. """ + if self._cloud: + return self._cloud_append("limit", n=n) self.last_visited_vertices = self.last_visited_vertices[:n] return self @@ -853,6 +1086,8 @@ def skip(self, n): :param n: Number of vertices to skip. :return: self for method chaining. """ + if self._cloud: + return self._cloud_append("skip", n=n) self.last_visited_vertices = self.last_visited_vertices[n:] return self @@ -866,6 +1101,8 @@ def order(self, direction="asc"): g.v("Person").out("created_at").order().all() # ascending (default) g.v("Person").out("created_at").order(desc).all() # descending ''' + if self._cloud: + return self._cloud_append("order", direction=direction) reverse = (direction == "desc") self.last_visited_vertices = sorted(self.last_visited_vertices, key=lambda v: v.id, reverse=reverse) return self @@ -876,6 +1113,8 @@ def back(self, tag): :param tag: A previous tag in the query to jump back to. :return: self for method chaining. 
""" + if self._cloud: + return self._cloud_append("back", tag=tag) vertices = [] for v in self.last_visited_vertices: if tag in v.tags: @@ -887,277 +1126,6 @@ def back(self, tag): self.last_visited_vertices = vertices return self - def __get_adjacent(self, vertex, predicates, direction): - """Get adjacent vertices based on direction: 'out', 'inc', or 'both'.""" - adjacent = [] - if direction in ("out", "both"): - adjacent.extend(self.__adjacent_vertices(vertex, predicates, 'out')) - if direction in ("inc", "both"): - adjacent.extend(self.__adjacent_vertices(vertex, predicates, 'in')) - return adjacent - - def bfs(self, predicates=None, max_depth=None, min_depth=0, - direction="out", until=None, unique=True): - """ - Traverse the graph breadth-first from current vertices. - - BFS explores level-by-level, visiting all neighbors at the current depth - before moving deeper. Guarantees shortest path in unweighted graphs. - - :param predicates: Edge type(s) to follow: str, list, or None (all edges) - :param max_depth: Maximum traversal depth (None = unlimited) - :param min_depth: Minimum depth to include in results (default 0) - :param direction: Traversal direction: "out", "inc", or "both" - :param until: Stop condition lambda: func(vertex_id) -> bool - :param unique: If True, visit each vertex only once (prevents cycles) - :return: self for method chaining - - Example: - g.v("alice").bfs(predicates="follows", max_depth=2).all() - g.v("alice").bfs(max_depth=3, min_depth=2).all() # depths 2-3 only - g.v("alice").bfs(until=lambda v: v == "target").all() - """ - from collections import deque - - # Normalize predicates - if predicates is not None: - if not isinstance(predicates, list): - predicates = [predicates] - predicates = list(map(hash_predicate, predicates)) - else: - predicates = self.all_predicates - - result_vertices = [] - visited = set() - queue = deque() # (vertex, depth) - - # Initialize with current vertices at depth 0 - for v in self.last_visited_vertices: - 
queue.append((v, 0)) - if unique: - visited.add(v.id) - - while queue: - current, depth = queue.popleft() - - if until and until(current.id): - if depth >= min_depth: - result_vertex = Vertex(current.id) - result_vertex.tags = current.tags.copy() - result_vertex.edges = current.edges.copy() - result_vertex._path = current._path - result_vertices.append(result_vertex) - continue - - if depth > 0 and depth >= min_depth: - if max_depth is None or depth <= max_depth: - result_vertex = Vertex(current.id) - result_vertex.tags = current.tags.copy() - result_vertex.edges = current.edges.copy() - result_vertex._path = current._path - result_vertices.append(result_vertex) - - # Stop exploring if at max depth - if max_depth is not None and depth >= max_depth: - continue - - adjacent = self.__get_adjacent(current, predicates, direction) - for adj in adjacent: - if unique: - if adj.id in visited: - continue - visited.add(adj.id) - adj.tags = current.tags.copy() - # Build path for neighbor from parent's path - parent_path = current._path or [{'vertex': current.id}] - edge_hash = next(iter(adj.edges)) if adj.edges else None - edge_name = self._predicate_reverse_lookup_cache.get(edge_hash, edge_hash) if edge_hash else None - adj._path = list(parent_path) + ([{'edge': edge_name}] if edge_name else []) + [{'vertex': adj.id}] - queue.append((adj, depth + 1)) - - self.last_visited_vertices = result_vertices - return self - - def dfs(self, predicates=None, max_depth=None, min_depth=0, - direction="out", until=None, unique=True): - """ - Traverse the graph depth-first from current vertices. - - DFS explores as deep as possible along each branch before backtracking. - More memory-efficient than BFS for deep graphs. 
- - :param predicates: Edge type(s) to follow: str, list, or None (all edges) - :param max_depth: Maximum traversal depth (None = unlimited) - :param min_depth: Minimum depth to include in results (default 0) - :param direction: Traversal direction: "out", "inc", or "both" - :param until: Stop condition lambda: func(vertex_id) -> bool - :param unique: If True, visit each vertex only once (prevents cycles) - :return: self for method chaining - - Example: - g.v("alice").dfs(predicates="follows", max_depth=3).all() - g.v("alice").dfs(direction="both", max_depth=2).all() - """ - # Normalize predicates - if predicates is not None: - if not isinstance(predicates, list): - predicates = [predicates] - predicates = list(map(hash_predicate, predicates)) - else: - predicates = self.all_predicates - - result_vertices = [] - visited = set() - stack = [] # (vertex, depth) - - # Initialize with current vertices at depth 0 - for v in self.last_visited_vertices: - stack.append((v, 0)) - if unique: - visited.add(v.id) - - while stack: - current, depth = stack.pop() # LIFO for DFS - - if until and until(current.id): - if depth >= min_depth: - result_vertex = Vertex(current.id) - result_vertex.tags = current.tags.copy() - result_vertex.edges = current.edges.copy() - result_vertex._path = current._path - result_vertices.append(result_vertex) - continue - - if depth > 0 and depth >= min_depth: - if max_depth is None or depth <= max_depth: - result_vertex = Vertex(current.id) - result_vertex.tags = current.tags.copy() - result_vertex.edges = current.edges.copy() - result_vertex._path = current._path - result_vertices.append(result_vertex) - - # Stop exploring if at max depth - if max_depth is not None and depth >= max_depth: - continue - - adjacent = self.__get_adjacent(current, predicates, direction) - for adj in adjacent: - if unique: - if adj.id in visited: - continue - visited.add(adj.id) - adj.tags = current.tags.copy() - # Build path for neighbor from parent's path - parent_path = 
current._path or [{'vertex': current.id}] - edge_hash = next(iter(adj.edges)) if adj.edges else None - edge_name = self._predicate_reverse_lookup_cache.get(edge_hash, edge_hash) if edge_hash else None - adj._path = list(parent_path) + ([{'edge': edge_name}] if edge_name else []) + [{'vertex': adj.id}] - stack.append((adj, depth + 1)) - - self.last_visited_vertices = result_vertices - return self - - def sim(self, word, operator, threshold, strict=False): - """ - Applies cosine similarity filter to the vertices and removes any vertices that do not pass the filter. - - Parameters: - ----------- - word: str - The word to compare to the other vertices. - operator: str - The comparison operator to use. One of "==", ">", "<", ">=", "<=", or "in". - threshold: float or list of 2 floats - The threshold value(s) to use for the comparison. If operator is "==", ">", "<", ">=", or "<=", threshold should be a float. If operator is "in", threshold should be a list of 2 floats. - strict: bool, optional - If True, raises an exception if a word embedding is not found for either word. If False, assigns a similarity of 0.0 to any word embedding that is not found. - - Returns: - -------- - self: GraphTraversal - Returns self to allow for method chaining. - - Raises: - ------- - ValueError: - If operator is not a valid comparison operator or if threshold is not a valid threshold value for the given operator. - If strict is True and a word embedding is not found for either word. 
- """ - if not isinstance(threshold, (float, int, list)): - raise ValueError("Invalid threshold value: {}".format(threshold)) - - if operator == 'in': - if not isinstance(threshold, list) or len(threshold) != 2: - raise ValueError("Invalid threshold value: {}".format(threshold)) - if not all(isinstance(t, (float, int)) for t in threshold): - raise ValueError("Invalid threshold value: {}".format(threshold)) - - # Auto-embed query word if missing - self._auto_embed(word) - - filtered_vertices = [] - for v in self.last_visited_vertices: - similarity = self.__cosine_similarity(word, v.id) - if not similarity: - # similarity is None if a word embedding is not found for either word. - if strict: - raise ValueError("Missing word embedding for either '{}' or '{}'".format(word, v.id)) - else: - # Treat vertices without word embeddings as if they have no similarity to any other vertex. - similarity = 0.0 - if operator == '=': - if isclose(similarity, threshold): - filtered_vertices.append(v) - elif operator == '>': - if similarity > threshold: - filtered_vertices.append(v) - elif operator == '<': - if similarity < threshold: - filtered_vertices.append(v) - elif operator == '>=': - if similarity >= threshold: - filtered_vertices.append(v) - elif operator == '<=': - if similarity <= threshold: - filtered_vertices.append(v) - elif operator == 'in': - if not threshold[0] <= similarity <= threshold[1]: - continue - filtered_vertices.append(v) - else: - raise ValueError("Invalid operator: {}".format(operator)) - self.last_visited_vertices = filtered_vertices - return self - - def __cosine_distance(self, x, y): - """Compute cosine distance (1 - similarity) with simsimd or pure Python fallback.""" - if _HAS_SIMSIMD: - return simsimd.cosine(x, y) - else: - # Pure Python fallback for Pyodide/environments without simsimd - dot = sum(a * b for a, b in zip(x, y)) - norm_x = math.sqrt(sum(a * a for a in x)) - norm_y = math.sqrt(sum(b * b for b in y)) - if norm_x == 0 or norm_y == 0: - 
return 1.0 # Max distance if either vector is zero - return 1.0 - (dot / (norm_x * norm_y)) - - def __cosine_similarity(self, word1, word2): - """Compute cosine similarity using SIMD-optimized simsimd library or pure Python fallback.""" - x_list = self.get_embedding(word1) - y_list = self.get_embedding(word2) - - if x_list is None or y_list is None: - return None - - # Use python array for buffer protocol (compatible with simsimd) - x = array.array('f', x_list) - y = array.array('f', y_list) - - # cosine distance = 1 - similarity, so we convert - distance = self.__cosine_distance(x, y) - return 1.0 - float(distance) - def tag(self, tag_names): """ Saves vertices with tag name(s). Used to capture vertices while traversing a graph. @@ -1167,6 +1135,9 @@ def tag(self, tag_names): :param tag_names: A string or list of strings. :return: self for method chaining. """ + if self._cloud: + names = tag_names if isinstance(tag_names, list) else [tag_names] + return self._cloud_append("tag", tag_names=names) if not isinstance(tag_names, list): tag_names = [tag_names] for v in self.last_visited_vertices: @@ -1178,6 +1149,9 @@ def tag(self, tag_names): return self def count(self): + if self._cloud: + result = self._cloud_execute_chain("count") + return result.get("result", result.get("count", 0)) return len(self.last_visited_vertices) def all(self, options=None): @@ -1186,6 +1160,8 @@ def all(self, options=None): https://github.com/cayleygraph/cayley/blob/master/docs/GizmoAPI.md :return: """ + if self._cloud: + return self._cloud_execute_chain("all", options=options) result = [] show_edge = True if options is not None and 'e' in options else False for v in self.last_visited_vertices: @@ -1213,6 +1189,8 @@ def graph(self): g.v("bob").out("follows").tag("from").out("works_at").tag("to").graph() # {'nodes': [{'id': 'bob'}, ...], 'links': [{'source': 'bob', 'target': 'fred', 'label': 'follows'}, ...]} """ + if self._cloud: + return self._cloud_execute_chain("graph") nodes = {} links 
= {} @@ -1251,6 +1229,18 @@ def triples(self): g.triples() # [("alice", "follows", "bob"), ("bob", "follows", "charlie")] """ + if self._cloud: + result = self._cloud_client.query_triples() + result.pop("ok", None) + # Support both {"triples": [...]} and standard {"result": [...]} envelopes + raw = result.get("triples", result.get("result", [])) + triples_list = [] + for item in raw: + if isinstance(item, (list, tuple)) and len(item) >= 3: + triples_list.append(tuple(item[:3])) + elif isinstance(item, dict) and "s" in item: + triples_list.append((item["s"], item["p"], item["o"])) + return triples_list from cog.export import get_triples return list(get_triples(self)) @@ -1274,6 +1264,10 @@ def export(self, filepath, fmt="nt", strict=False): g.export("graph.csv", fmt="csv") # CSV with header g.export("graph.tsv", fmt="tsv") # TSV with header """ + if self._cloud: + cloud_triples = self.triples() + from cog.export import export_triples + return export_triples(self, filepath, fmt=fmt, strict=strict, triples_iter=cloud_triples) from cog.export import export_triples return export_triples(self, filepath, fmt=fmt, strict=strict) @@ -1283,6 +1277,8 @@ def view(self, view_name, Returns html view of the resulting graph from a query. :return: """ + if self._cloud: + raise RuntimeError("view() is not supported in cloud mode") assert view_name is not None, "a view name is required to create a view, it can be any string." 
result = self.graph() # Escape HTML special characters to prevent XSS when embedding in script tag @@ -1296,6 +1292,8 @@ def view(self, view_name, return view def getv(self, view_name): + if self._cloud: + raise RuntimeError("getv() is not supported in cloud mode") view = self.views_dir + "/{view_name}.html".format(view_name=view_name) assert os.path.isfile(view), "view not found, create a view by calling .view()" with open(view, 'r') as f: @@ -1304,337 +1302,10 @@ def getv(self, view_name): return view def lsv(self): + if self._cloud: + raise RuntimeError("lsv() is not supported in cloud mode") return [f.split(".")[0] for f in listdir(self.views_dir)] - def put_embedding(self, word, embedding): - """ - Saves a word embedding. - """ - - assert isinstance(word, str), "word must be a string" - self.cog.use_namespace(self.graph_name).use_table(self.config.EMBEDDING_SET_TABLE_NAME).put(Record( - word, embedding)) - - def get_embedding(self, word): - """ - Returns a word embedding. - """ - assert isinstance(word, str), "word must be a string" - record = self.cog.use_namespace(self.graph_name).use_table(self.config.EMBEDDING_SET_TABLE_NAME).get( - word) - if record is None: - return None - return record.value - - def delete_embedding(self, word): - """ - Deletes a word embedding. - """ - assert isinstance(word, str), "word must be a string" - self.cog.use_namespace(self.graph_name).use_table(self.config.EMBEDDING_SET_TABLE_NAME).delete( - word) - - def put_embeddings_batch(self, word_embedding_pairs): - """ - Bulk insert multiple embeddings efficiently. 
- - :param word_embedding_pairs: List of (word, embedding) tuples - :return: self for method chaining - - Example: - g.put_embeddings_batch([ - ("apple", [0.1, 0.2, ...]), - ("orange", [0.3, 0.4, ...]), - ]) - """ - self.cog.use_namespace(self.graph_name) - self.cog.begin_batch() - try: - for word, embedding in word_embedding_pairs: - if not isinstance(word, str): - raise TypeError("word must be a string") - self.cog.use_table(self.config.EMBEDDING_SET_TABLE_NAME).put(Record( - word, embedding)) - finally: - self.cog.end_batch() - return self - - def scan_embeddings(self, limit=100): - """ - Scan and return a list of words that have embeddings stored. - - :param limit: Maximum number of embeddings to return - :return: Dictionary with 'result' containing list of words with embeddings - - Note: This scans the graph vertices and checks which have embeddings. - """ - result = [] - self.cog.use_namespace(self.graph_name).use_table(self.config.GRAPH_NODE_SET_TABLE_NAME) - count = 0 - for r in self.cog.scanner(): - if count >= limit: - break - word = r.key - if self.get_embedding(word) is not None: - result.append({"id": word}) - count += 1 - return {"result": result} - - def embedding_stats(self): - """ - Return statistics about stored embeddings. - - :return: Dictionary with count and dimensions (if available) - """ - count = 0 - dimensions = None - # Scan the embedding table directly - self.cog.use_namespace(self.graph_name).use_table(self.config.EMBEDDING_SET_TABLE_NAME) - for r in self.cog.scanner(): - count += 1 - if dimensions is None and r.value is not None: - dimensions = len(r.value) - return {"count": count, "dimensions": dimensions} - - def k_nearest(self, word, k=10): - """ - Find the k vertices most similar to the given word based on embeddings. 
- - :param word: The word to find similar vertices for - :param k: Number of nearest neighbors to return (default 10) - :return: self for method chaining - - Example: - g.v().k_nearest("machine_learning", k=5).all() - """ - # Auto-embed query word if missing - self._auto_embed(word) - - target_embedding = self.get_embedding(word) - if target_embedding is None: - self.last_visited_vertices = [] - return self - - # simsimd/fallback requires buffer protocol (e.g. numpy array or python array) - target_vec = array.array('f', target_embedding) - similarities = [] - - # None = no prior traversal, scan entire embedding table - # [] = prior traversal returned empty, preserve empty semantics - # [...] = search within visited vertices - if self.last_visited_vertices is None: - # Scan embedding table directly for all embeddings - self.cog.use_namespace(self.graph_name).use_table(self.config.EMBEDDING_SET_TABLE_NAME) - for r in self.cog.scanner(): - if r.value is not None: - v_vec = array.array('f', r.value) - distance = self.__cosine_distance(target_vec, v_vec) - similarity = 1.0 - float(distance) - similarities.append((similarity, Vertex(r.key))) - elif self.last_visited_vertices: - # Search within visited vertices - for v in self.last_visited_vertices: - v_embedding = self.get_embedding(v.id) - if v_embedding is not None: - v_vec = array.array('f', v_embedding) - distance = self.__cosine_distance(target_vec, v_vec) - similarity = 1.0 - float(distance) - similarities.append((similarity, v)) - # else: empty list, similarities stays empty - - # Get top k using heap for efficiency - top_k = heapq.nlargest(k, similarities, key=lambda x: x[0]) - self.last_visited_vertices = [v for _, v in top_k] - return self - - def load_glove(self, filepath, limit=None, batch_size=1000): - """ - Load GloVe embeddings from a text file. 
- - :param filepath: Path to GloVe file (e.g., 'glove.6B.100d.txt') - :param limit: Maximum number of embeddings to load (None for all) - :param batch_size: Number of embeddings to batch before writing (default 1000) - :return: Number of embeddings loaded - - Example: - count = g.load_glove("glove.6B.100d.txt", limit=50000) - """ - count = 0 - batch = [] - - with open(filepath, 'r', encoding='utf-8') as f: - for line in f: - if limit is not None and count >= limit: - break - parts = line.strip().split() - if len(parts) < 2: - continue - word = parts[0] - embedding = [float(x) for x in parts[1:]] - batch.append((word, embedding)) - count += 1 - - if len(batch) >= batch_size: - self.put_embeddings_batch(batch) - batch = [] - - # Load remaining batch - if batch: - self.put_embeddings_batch(batch) - - return count - - def load_gensim(self, model, limit=None, batch_size=1000): - """ - Load embeddings from a Gensim Word2Vec or FastText model. - - :param model: A Gensim model with a 'wv' attribute (Word2Vec, FastText) - :param limit: Maximum number of embeddings to load (None for all) - :param batch_size: Number of embeddings to batch before writing (default 1000) - :return: Number of embeddings loaded - - Example: - from gensim.models import Word2Vec - model = Word2Vec(sentences) - count = g.load_gensim(model) - """ - count = 0 - batch = [] - - # Get word vectors from model - if hasattr(model, 'wv'): - wv = model.wv - else: - wv = model # Already a KeyedVectors object - - for word in wv.index_to_key: - if limit is not None and count >= limit: - break - embedding = wv[word].tolist() - batch.append((word, embedding)) - count += 1 - - if len(batch) >= batch_size: - self.put_embeddings_batch(batch) - batch = [] - - if batch: - self.put_embeddings_batch(batch) - - return count - - def _auto_embed(self, word): - """Auto-fetch and store embedding for a word if missing. 
- Only active after vectorize() has been explicitly called.""" - if not self._vectorize_configured: - return - if self.get_embedding(word) is not None: - return - try: - provider_fn = EMBEDDING_PROVIDERS[self._default_provider] - pairs = provider_fn([word], **self._default_provider_kwargs) - if pairs: - self.put_embeddings_batch(pairs) - except Exception as e: - self.logger.debug("auto-embed for '{}' failed: {}".format(word, e)) - - def vectorize(self, words=None, provider="cogdb", batch_size=100, **kwargs): - """ - Auto-generate vector embeddings using a provider. - - Can embed all graph nodes, a single word, or a list of words. - Skips words that already have embeddings. - - :param words: Optional — a string or list of strings to embed. - If None, embeds all nodes in the graph. - :param provider: Provider name — "cogdb" (default), "openai", or "custom". - :param batch_size: Number of words per provider request (default 100). - :param kwargs: Passed to the provider (e.g. url=, api_key=, model=). - :return: Summary dict {"vectorized": N, "skipped": M, "total": T} - - Example: - g.vectorize() # all nodes - g.vectorize("europa") # single word - g.vectorize(["europa", "ocean"]) # specific words - g.vectorize(provider="openai", api_key="sk-...") - """ - if not isinstance(batch_size, int) or batch_size < 1: - raise ValueError("batch_size must be a positive integer, got: {}".format(batch_size)) - - if provider not in EMBEDDING_PROVIDERS: - raise ValueError("Unknown provider '{}'. 
Choose from: {}".format( - provider, ", ".join(EMBEDDING_PROVIDERS.keys()))) - - # Store provider config for auto-embed in queries - self._default_provider = provider - self._default_provider_kwargs = kwargs - self._vectorize_configured = True - - provider_fn = EMBEDDING_PROVIDERS[provider] - - # Determine which words to embed - if words is not None: - # Explicit word(s) - if isinstance(words, str): - words = [words] - all_words = words - else: - # All graph nodes - all_words = [] - self.cog.use_namespace(self.graph_name).use_table(self.config.GRAPH_NODE_SET_TABLE_NAME) - for r in self.cog.scanner(): - all_words.append(r.key) - - total = len(all_words) - - # Skip words that already have embeddings - to_embed = [w for w in all_words if self.get_embedding(w) is None] - skipped = total - len(to_embed) - - if not to_embed: - return {"vectorized": 0, "skipped": skipped, "total": total} - - # Send to provider in batches and store results - vectorized = 0 - errors = [] - for chunk in _chunked(to_embed, batch_size): - try: - pairs = provider_fn(chunk, **kwargs) - self.put_embeddings_batch(pairs) - vectorized += len(pairs) - except Exception as e: - self.logger.error("vectorize batch failed: {}".format(e)) - errors.append(str(e)) - - result = {"vectorized": vectorized, "skipped": skipped, "total": total} - if errors: - result["errors"] = errors - return result - - -class View(object): - - def __init__(self, url, html): - self.url = url - self.html = html - - def render(self, height=700, width=700): - """ - :param self: - :param height: - :param width: - :return: - """ - iframe_html = r""" """.format(self.html, width, - height) - from IPython.core.display import display, HTML - display(HTML(iframe_html)) - - def persist(self): - f = open(self.url, "w") - f.write(self.html) - f.close() - - def __str__(self): - return self.url + def get_new_graph_instance(self): + return Graph(self.graph_name, self.config.COG_HOME, self.config.COG_PATH_PREFIX) diff --git a/cog/view.py 
b/cog/view.py index af72821..12ad8d6 100644 --- a/cog/view.py +++ b/cog/view.py @@ -95,4 +95,31 @@ -""" \ No newline at end of file +""" + + +class View(object): + + def __init__(self, url, html): + self.url = url + self.html = html + + def render(self, height=700, width=700): + """ + :param self: + :param height: + :param width: + :return: + """ + iframe_html = r""" """.format(self.html, width, + height) + from IPython.core.display import display, HTML + display(HTML(iframe_html)) + + def persist(self): + f = open(self.url, "w") + f.write(self.html) + f.close() + + def __str__(self): + return self.url \ No newline at end of file diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..f43dad2 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +markers = + cloud: tests that require COGDB_API_KEY and hit the live cloud backend diff --git a/scripts/local_wheel_server.py b/scripts/local_wheel_server.py index d5c1fff..6112b1c 100644 --- a/scripts/local_wheel_server.py +++ b/scripts/local_wheel_server.py @@ -26,6 +26,20 @@ def do_OPTIONS(self): self.end_headers() +def clean_build_artifacts(project_root, dist_dir): + """Remove old build artifacts to ensure a fresh build.""" + import shutil + + for d in [dist_dir, project_root / "build"]: + if d.exists(): + print(f"Cleaning {d.relative_to(project_root)}/...") + shutil.rmtree(d) + + for egg_info in project_root.glob("*.egg-info"): + print(f"Cleaning {egg_info.name}/...") + shutil.rmtree(egg_info) + + def main(): # Find project root (where setup.py is) script_dir = Path(__file__).parent @@ -34,11 +48,14 @@ def main(): os.chdir(project_root) + # Clean old artifacts first + clean_build_artifacts(project_root, dist_dir) + # Build the wheel using pip wheel print("Building wheel...") dist_dir.mkdir(exist_ok=True) result = subprocess.run( - [sys.executable, "-m", "pip", "wheel", "--no-deps", "-w", "dist", "."], + [sys.executable, "-m", "pip", "wheel", "--no-deps", "--no-cache-dir", "-w", "dist", "."], 
capture_output=True, text=True ) diff --git a/setup.py b/setup.py index 8862b1e..44fa48a 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name='cogdb', - version='3.7.5', + version='3.8.0', description='Persistent Embedded Graph Database', url='http://github.com/arun1729/cog', author='Arun Mahendra', diff --git a/test/test_batch_mode.py b/test/test_batch_mode.py index 341fe9b..2b54541 100644 --- a/test/test_batch_mode.py +++ b/test/test_batch_mode.py @@ -144,8 +144,9 @@ def test_batch_mode_performance_improvement(self): @classmethod def tearDownClass(cls): - # Clean up after this test class - pass + if os.path.exists("/tmp/" + DIR_NAME): + shutil.rmtree("/tmp/" + DIR_NAME) + os.makedirs("/tmp/" + DIR_NAME) class TestCogBatchMode(unittest.TestCase): diff --git a/test/test_cloud.py b/test/test_cloud.py new file mode 100644 index 0000000..3ad7ac4 --- /dev/null +++ b/test/test_cloud.py @@ -0,0 +1,440 @@ +from cog.torque import Graph +from cog import config as cfg +import unittest +import os +import shutil +import json +from unittest.mock import patch, MagicMock +from io import BytesIO +import urllib.error + + +def _mock_response(body_dict, status=200): + """Create a mock HTTP response that works as a context manager.""" + resp = MagicMock() + resp.read.return_value = json.dumps(body_dict).encode("utf-8") + resp.status = status + resp.__enter__ = MagicMock(return_value=resp) + resp.__exit__ = MagicMock(return_value=False) + return resp + + +class TestCloudModeActivation(unittest.TestCase): + """Test cloud mode activation via api_key param and env var.""" + + def test_explicit_api_key_activates_cloud_mode(self): + g = Graph("test-graph", api_key="cog_test123") + self.assertTrue(g._cloud) + self.assertEqual(g._api_key, "cog_test123") + self.assertEqual(g.graph_name, "test-graph") + + def test_env_var_activates_cloud_mode(self): + with patch.dict(os.environ, {"COGDB_API_KEY": "cog_env_key"}): + g = Graph("test-graph") + self.assertTrue(g._cloud) + 
self.assertEqual(g._api_key, "cog_env_key") + + def test_explicit_key_overrides_env_var(self): + with patch.dict(os.environ, {"COGDB_API_KEY": "cog_env_key"}): + g = Graph("test-graph", api_key="cog_explicit") + self.assertTrue(g._cloud) + self.assertEqual(g._api_key, "cog_explicit") + + def test_no_key_stays_local(self): + with patch.dict(os.environ, {}, clear=True): + # Remove COGDB_API_KEY if present + os.environ.pop("COGDB_API_KEY", None) + g = Graph("test-graph", cog_path_prefix="/tmp") + self.assertFalse(g._cloud) + self.assertIsNone(g._api_key) + g.close() + shutil.rmtree("/tmp/cog_home", ignore_errors=True) + + def test_cloud_mode_does_not_create_local_files(self): + g = Graph("cloud-test-no-files", api_key="cog_test123") + self.assertFalse(hasattr(g, 'cog')) + + +class TestCloudServeBlocked(unittest.TestCase): + """Test that serve() raises in cloud mode.""" + + def test_serve_raises_runtime_error(self): + g = Graph("test-graph", api_key="cog_test123") + with self.assertRaises(RuntimeError) as ctx: + g.serve() + self.assertIn("cloud mode", str(ctx.exception)) + + +class TestCloudWriteMethods(unittest.TestCase): + """Test that write methods send correct HTTP requests.""" + + @patch("urllib.request.urlopen") + def test_put_sends_mutate_request(self, mock_urlopen): + mock_urlopen.return_value = _mock_response({"ok": True, "count": 1}) + + g = Graph("my-graph", api_key="cog_key123") + result = g.put("alice", "knows", "bob") + + # put() returns self for method chaining + self.assertIs(result, g) + + mock_urlopen.assert_called_once() + req = mock_urlopen.call_args[0][0] + self.assertEqual(req.method, "POST") + self.assertIn(f"{cfg.CLOUD_API_PREFIX}/my-graph/mutate_batch", req.full_url) + self.assertEqual(req.get_header("Authorization"), "cog_key123") + self.assertEqual(req.get_header("Content-type"), "application/json") + + body = json.loads(req.data.decode("utf-8")) + self.assertIn("mutations", body) + self.assertEqual(len(body["mutations"]), 1) + m = 
body["mutations"][0] + self.assertEqual(m["op"], "PUT") + self.assertEqual(m["s"], "alice") + self.assertEqual(m["p"], "knows") + self.assertEqual(m["o"], "bob") + + @patch("urllib.request.urlopen") + def test_delete_sends_mutate_request(self, mock_urlopen): + mock_urlopen.return_value = _mock_response({"ok": True, "count": 1}) + + g = Graph("my-graph", api_key="cog_key123") + result = g.delete("alice", "knows", "bob") + + body = json.loads(mock_urlopen.call_args[0][0].data.decode("utf-8")) + self.assertEqual(body["mutations"][0]["op"], "DELETE") + # Should return self for chaining + self.assertIs(result, g) + + @patch("urllib.request.urlopen") + def test_put_batch_sends_batch(self, mock_urlopen): + mock_urlopen.return_value = _mock_response({"ok": True, "count": 2}) + + g = Graph("my-graph", api_key="cog_key123") + g.put_batch([ + ("alice", "knows", "bob"), + ("bob", "knows", "charlie"), + ]) + + body = json.loads(mock_urlopen.call_args[0][0].data.decode("utf-8")) + self.assertIn("mutations", body) + self.assertEqual(len(body["mutations"]), 2) + self.assertTrue(all(m["op"] == "PUT" for m in body["mutations"])) + + +class TestCloudTraversalChain(unittest.TestCase): + """Test that traversal chain accumulates and sends at terminal method.""" + + @patch("urllib.request.urlopen") + def test_v_out_all_sends_chain(self, mock_urlopen): + mock_urlopen.return_value = _mock_response({ + "ok": True, "result": [{"id": "bob"}] + }) + + g = Graph("my-graph", api_key="cog_key123") + result = g.v("alice").out("knows").all() + + # Should have made exactly one HTTP call (at all()) + mock_urlopen.assert_called_once() + body = json.loads(mock_urlopen.call_args[0][0].data.decode("utf-8")) + + # Cloud client now sends a query string, not a raw chain + self.assertIn("q", body) + self.assertEqual(body["q"], 'v("alice").out("knows").all()') + + # Response should have 'ok' stripped + self.assertEqual(result, {"result": [{"id": "bob"}]}) + self.assertNotIn("ok", result) + + 
@patch("urllib.request.urlopen") + def test_v_out_tag_all_sends_chain(self, mock_urlopen): + mock_urlopen.return_value = _mock_response({ + "ok": True, "result": [{"id": "bob", "source": ":(bob)"}] + }) + + g = Graph("my-graph", api_key="cog_key123") + g.v("alice").out("knows").tag("source").all() + + body = json.loads(mock_urlopen.call_args[0][0].data.decode("utf-8")) + self.assertEqual(body["q"], 'v("alice").out("knows").tag("source").all()') + + @patch("urllib.request.urlopen") + def test_count_sends_chain(self, mock_urlopen): + mock_urlopen.return_value = _mock_response({"ok": True, "result": 42}) + + g = Graph("my-graph", api_key="cog_key123") + result = g.v("alice").out("knows").count() + + self.assertEqual(result, 42) + + @patch("urllib.request.urlopen") + def test_v_all_vertices(self, mock_urlopen): + mock_urlopen.return_value = _mock_response({ + "ok": True, "result": [{"id": "alice"}, {"id": "bob"}] + }) + + g = Graph("my-graph", api_key="cog_key123") + result = g.v().all() + + body = json.loads(mock_urlopen.call_args[0][0].data.decode("utf-8")) + self.assertEqual(body["q"], 'v().all()') + self.assertEqual(result["result"], [{"id": "alice"}, {"id": "bob"}]) + self.assertNotIn("ok", result) + + @patch("urllib.request.urlopen") + def test_chain_resets_between_queries(self, mock_urlopen): + mock_urlopen.return_value = _mock_response({"ok": True, "result": []}) + + g = Graph("my-graph", api_key="cog_key123") + g.v("alice").out("knows").all() # First query + + # Start second query - chain should reset + mock_urlopen.return_value = _mock_response({"ok": True, "result": [{"id": "charlie"}]}) + g.v("bob").all() + + body = json.loads(mock_urlopen.call_args[0][0].data.decode("utf-8")) + # Second query should only contain v("bob").all(), not the previous chain + self.assertEqual(body["q"], 'v("bob").all()') + + @patch("urllib.request.urlopen") + def test_bfs_in_chain(self, mock_urlopen): + mock_urlopen.return_value = _mock_response({"ok": True, "result": [{"id": 
"charlie"}]}) + + g = Graph("my-graph", api_key="cog_key123") + g.v("alice").bfs(predicates="follows", max_depth=2).all() + + body = json.loads(mock_urlopen.call_args[0][0].data.decode("utf-8")) + self.assertEqual(body["q"], 'v("alice").bfs("follows", 2).all()') + + @patch("urllib.request.urlopen") + def test_graph_terminal(self, mock_urlopen): + mock_urlopen.return_value = _mock_response({ + "ok": True, "nodes": [{"id": "alice"}], + "links": [] + }) + + g = Graph("my-graph", api_key="cog_key123") + result = g.v("alice").out("knows").tag("from").graph() + + body = json.loads(mock_urlopen.call_args[0][0].data.decode("utf-8")) + self.assertIn('graph()', body["q"]) + + +class TestCloudTriples(unittest.TestCase): + """Test triples() in cloud mode.""" + + @patch("urllib.request.urlopen") + def test_triples_returns_tuples(self, mock_urlopen): + mock_urlopen.return_value = _mock_response({ + "triples": [["alice", "knows", "bob"], ["bob", "knows", "charlie"]] + }) + + g = Graph("my-graph", api_key="cog_key123") + result = g.triples() + + self.assertEqual(result, [ + ("alice", "knows", "bob"), + ("bob", "knows", "charlie"), + ]) + + +class TestCloudErrorHandling(unittest.TestCase): + """Test HTTP error mapping.""" + + def _make_http_error(self, code, body=None): + if body is None: + body = {"detail": "test error"} + return urllib.error.HTTPError( + url="https://api.cogdb.io/test", + code=code, + msg="Error", + hdrs={}, + fp=BytesIO(json.dumps(body).encode("utf-8")) + ) + + @patch("urllib.request.urlopen") + def test_401_raises_permission_error(self, mock_urlopen): + mock_urlopen.side_effect = self._make_http_error(401) + g = Graph("my-graph", api_key="cog_bad_key") + with self.assertRaises(PermissionError) as ctx: + g.put("a", "b", "c") + self.assertIn("Invalid API key", str(ctx.exception)) + + @patch("urllib.request.urlopen") + def test_403_raises_permission_error(self, mock_urlopen): + mock_urlopen.side_effect = self._make_http_error(403) + g = Graph("my-graph", 
api_key="cog_bad_key") + with self.assertRaises(PermissionError): + g.put("a", "b", "c") + + @patch("urllib.request.urlopen") + def test_400_raises_value_error(self, mock_urlopen): + mock_urlopen.side_effect = self._make_http_error(400, {"detail": "missing field"}) + g = Graph("my-graph", api_key="cog_key") + with self.assertRaises(ValueError) as ctx: + g.put("a", "b", "c") + self.assertIn("missing field", str(ctx.exception)) + + @patch("urllib.request.urlopen") + def test_500_raises_runtime_error(self, mock_urlopen): + mock_urlopen.side_effect = self._make_http_error(500) + g = Graph("my-graph", api_key="cog_key") + with self.assertRaises(RuntimeError) as ctx: + g.put("a", "b", "c") + self.assertIn("500", str(ctx.exception)) + + @patch("urllib.request.urlopen") + def test_connection_error_raises(self, mock_urlopen): + mock_urlopen.side_effect = urllib.error.URLError("Connection refused") + g = Graph("my-graph", api_key="cog_key") + with self.assertRaises(ConnectionError) as ctx: + g.put("a", "b", "c") + self.assertIn("Cannot reach CogDB Cloud", str(ctx.exception)) + + +class TestCloudNoOps(unittest.TestCase): + """Test that lifecycle methods are safe no-ops in cloud mode.""" + + def test_close_is_noop(self): + g = Graph("my-graph", api_key="cog_key") + g.close() # Should not raise + + def test_sync_without_pending_is_safe(self): + g = Graph("my-graph", api_key="cog_key") + g.sync() # Should not raise with no pending mutations + + def test_refresh_is_noop(self): + g = Graph("my-graph", api_key="cog_key") + g.refresh() # Should not raise + + def test_filter_raises_clear_error(self): + g = Graph("my-graph", api_key="cog_key") + g.v("alice") # start a chain + with self.assertRaises(RuntimeError) as ctx: + g.filter(lambda x: True) + self.assertIn("cloud mode", str(ctx.exception)) + + +class TestCloudDropTruncate(unittest.TestCase): + """Test drop() and truncate() in cloud mode.""" + + @patch("urllib.request.urlopen") + def test_drop_sends_mutate(self, mock_urlopen): + 
mock_urlopen.return_value = _mock_response({"ok": True, "count": 1}) + g = Graph("my-graph", api_key="cog_key") + g.drop() + + body = json.loads(mock_urlopen.call_args[0][0].data.decode("utf-8")) + self.assertEqual(body["mutations"][0]["op"], "DROP") + + @patch("urllib.request.urlopen") + def test_truncate_sends_mutate(self, mock_urlopen): + mock_urlopen.return_value = _mock_response({"ok": True, "count": 1}) + g = Graph("my-graph", api_key="cog_key") + result = g.truncate() + + body = json.loads(mock_urlopen.call_args[0][0].data.decode("utf-8")) + self.assertEqual(body["mutations"][0]["op"], "TRUNCATE") + self.assertIs(result, g) # returns self for chaining + + +class TestCloudUrl(unittest.TestCase): + """Test that cloud URL is set from config.""" + + def test_cloud_url_from_config(self): + g = Graph("my-graph", api_key="cog_key") + self.assertIn("https://api.cogdb.io", g._cloud_client._base_url) + + +class TestCloudFlushInterval(unittest.TestCase): + """Test that flush_interval controls write batching in cloud mode.""" + + @patch("urllib.request.urlopen") + def test_default_flush_interval_sends_immediately(self, mock_urlopen): + """With flush_interval=1 (default), each put() sends immediately.""" + mock_urlopen.return_value = _mock_response({"ok": True, "count": 1}) + g = Graph("my-graph", api_key="cog_key") + g.put("alice", "knows", "bob") + mock_urlopen.assert_called_once() + + @patch("urllib.request.urlopen") + def test_high_flush_interval_buffers_writes(self, mock_urlopen): + """With flush_interval > count, writes are buffered until sync().""" + mock_urlopen.return_value = _mock_response({"ok": True, "count": 3}) + g = Graph("my-graph", api_key="cog_key", flush_interval=10) + g.put("alice", "knows", "bob") + g.put("bob", "knows", "charlie") + g.put("charlie", "knows", "alice") + mock_urlopen.assert_not_called() + # sync() flushes all pending + g.sync() + mock_urlopen.assert_called_once() + body = json.loads(mock_urlopen.call_args[0][0].data.decode("utf-8")) + 
self.assertEqual(len(body["mutations"]), 3) + + @patch("urllib.request.urlopen") + def test_auto_flush_on_threshold(self, mock_urlopen): + """Buffer auto-flushes when flush_interval threshold is reached.""" + mock_urlopen.return_value = _mock_response({"ok": True, "count": 2}) + g = Graph("my-graph", api_key="cog_key", flush_interval=2) + g.put("alice", "knows", "bob") + mock_urlopen.assert_not_called() + g.put("bob", "knows", "charlie") # triggers auto-flush + mock_urlopen.assert_called_once() + body = json.loads(mock_urlopen.call_args[0][0].data.decode("utf-8")) + self.assertEqual(len(body["mutations"]), 2) + + @patch("urllib.request.urlopen") + def test_close_flushes_pending(self, mock_urlopen): + """close() flushes any pending mutations.""" + mock_urlopen.return_value = _mock_response({"ok": True, "count": 1}) + g = Graph("my-graph", api_key="cog_key", flush_interval=10) + g.put("alice", "knows", "bob") + mock_urlopen.assert_not_called() + g.close() + mock_urlopen.assert_called_once() + + @patch("urllib.request.urlopen") + def test_query_flushes_pending(self, mock_urlopen): + """Queries flush pending mutations for read-your-writes consistency.""" + flush_resp = _mock_response({"ok": True, "count": 1}) + query_resp = _mock_response({"ok": True, "result": [{"id": "bob"}]}) + mock_urlopen.side_effect = [flush_resp, query_resp] + g = Graph("my-graph", api_key="cog_key", flush_interval=10) + g.put("alice", "knows", "bob") + mock_urlopen.assert_not_called() + g.v("alice").out("knows").all() + self.assertEqual(mock_urlopen.call_count, 2) + + @patch("urllib.request.urlopen") + def test_flush_interval_zero_manual_only(self, mock_urlopen): + """With flush_interval=0, writes only send on explicit sync().""" + mock_urlopen.return_value = _mock_response({"ok": True, "count": 3}) + g = Graph("my-graph", api_key="cog_key", flush_interval=0) + g.put("a", "b", "c") + g.put("d", "e", "f") + g.put("g", "h", "i") + mock_urlopen.assert_not_called() + g.sync() + 
mock_urlopen.assert_called_once() + body = json.loads(mock_urlopen.call_args[0][0].data.decode("utf-8")) + self.assertEqual(len(body["mutations"]), 3) + + @patch("urllib.request.urlopen") + def test_delete_uses_buffer(self, mock_urlopen): + """delete() also goes through the buffer.""" + mock_urlopen.return_value = _mock_response({"ok": True, "count": 2}) + g = Graph("my-graph", api_key="cog_key", flush_interval=10) + g.put("alice", "knows", "bob") + g.delete("alice", "knows", "bob") + mock_urlopen.assert_not_called() + g.sync() + mock_urlopen.assert_called_once() + body = json.loads(mock_urlopen.call_args[0][0].data.decode("utf-8")) + self.assertEqual(len(body["mutations"]), 2) + self.assertEqual(body["mutations"][0]["op"], "PUT") + self.assertEqual(body["mutations"][1]["op"], "DELETE") + + +if __name__ == '__main__': + unittest.main() diff --git a/test/test_cloud_parity.py b/test/test_cloud_parity.py new file mode 100644 index 0000000..59a5436 --- /dev/null +++ b/test/test_cloud_parity.py @@ -0,0 +1,622 @@ +""" +Cloud ↔ Local Parity Tests +=========================== + +Runs the **same graph operations** against both a local Graph and a cloud +Graph, then asserts return types, response shapes, and values match. + +This ensures the cloud response-normalisation layer in torque.py keeps the +two backends perfectly aligned. + +Usage +----- + # Locally (macOS may need SSL_CERT_FILE): + SSL_CERT_FILE=$(python3 -c "import certifi; print(certifi.where())") \\ + COGDB_API_KEY=cog_xxx python3 -m pytest test/test_cloud_parity.py -v + + # Without a key the whole suite is auto-skipped: + python3 -m pytest test/test_cloud_parity.py -v + +CI +-- + Add COGDB_API_KEY as a GitHub Actions secret, then the workflow step + passes it via env. See .github/workflows/python-tests.yml. 
+""" + +import os +import shutil +import time +import unittest +from unittest.mock import patch + +import pytest + +CLOUD_API_KEY = os.environ.get("COGDB_API_KEY") +HAS_CLOUD = bool(CLOUD_API_KEY) + +from cog.torque import Graph + +# --------------------------------------------------------------------------- # +# Helpers +# --------------------------------------------------------------------------- # + +DIR_NAME = "CloudParityTest" + + +def ordered(obj): + """Recursively sort dicts/lists for deterministic comparison.""" + if isinstance(obj, dict): + return sorted((k, ordered(v)) for k, v in obj.items()) + if isinstance(obj, list): + return sorted(ordered(x) for x in obj) + return obj + + +# --------------------------------------------------------------------------- # +# Test class +# --------------------------------------------------------------------------- # + +@pytest.mark.cloud +@unittest.skipUnless(HAS_CLOUD, "COGDB_API_KEY not set — skipping cloud parity tests") +class TestCloudLocalParity(unittest.TestCase): + """ + Seeds identical data into a local graph and a cloud graph, runs the + same operations on both, and asserts the responses are identical. 
+ """ + + maxDiff = None + + # ── fixtures ───────────────────────────────────────────────────────────── # + + @classmethod + def setUpClass(cls): + # Clean local directory + local_path = "/tmp/" + DIR_NAME + if os.path.exists(local_path): + shutil.rmtree(local_path) + os.makedirs(local_path, exist_ok=True) + + # Unique graph name per run so cloud data doesn't collide + cls.graph_name = f"parity_{int(time.time())}" + + # Local graph — must NOT pick up COGDB_API_KEY from env + with patch.dict(os.environ, {}, clear=False): + os.environ.pop("COGDB_API_KEY", None) + cls.local = Graph(graph_name=cls.graph_name, cog_home=DIR_NAME) + assert not cls.local._cloud, "Local graph unexpectedly in cloud mode" + + # Cloud graph + cls.cloud = Graph(graph_name=cls.graph_name, api_key=CLOUD_API_KEY) + assert cls.cloud._cloud, "Cloud graph not in cloud mode" + + # Seed identical data on both + cls.triples = [ + ("alice", "knows", "bob"), + ("bob", "knows", "charlie"), + ("charlie", "knows", "alice"), + ("alice", "works_at", "acme"), + ("bob", "works_at", "globex"), + ("charlie", "works_at", "acme"), + ("alice", "age", "30"), + ("bob", "age", "25"), + ("charlie", "age", "35"), + ] + for s, p, o in cls.triples: + cls.local.put(s, p, o) + cls.cloud.put(s, p, o) + + # Flush and wait for cloud backend to index before tests run + cls.cloud.sync() + time.sleep(2) + + @classmethod + def tearDownClass(cls): + cls.local.close() + shutil.rmtree("/tmp/" + DIR_NAME, ignore_errors=True) + + # ── assertion helpers ──────────────────────────────────────────────────── # + + def assert_same_type(self, local_result, cloud_result, ctx=""): + self.assertEqual( + type(local_result), type(cloud_result), + f"Type mismatch ({ctx}): " + f"local={type(local_result).__name__}, " + f"cloud={type(cloud_result).__name__}" + ) + + def assert_same_keys(self, local_result, cloud_result, ctx=""): + if isinstance(local_result, dict) and isinstance(cloud_result, dict): + self.assertEqual( + set(local_result.keys()), 
set(cloud_result.keys()), + f"Key mismatch ({ctx}): " + f"local={set(local_result.keys())}, " + f"cloud={set(cloud_result.keys())}" + ) + + def assert_same_shape(self, local_result, cloud_result, ctx=""): + self.assert_same_type(local_result, cloud_result, ctx) + self.assert_same_keys(local_result, cloud_result, ctx) + + def assert_same_result_set(self, local_result, cloud_result, ctx=""): + self.assert_same_shape(local_result, cloud_result, ctx) + self.assertEqual( + ordered(local_result), ordered(cloud_result), + f"Value mismatch ({ctx})" + ) + + # ====================================================================== # + # Mutations — return types # + # ====================================================================== # + + def test_put_returns_self(self): + """put() returns the Graph object for method chaining.""" + lr = self.local.put("_tmp", "r", "v") + cr = self.cloud.put("_tmp", "r", "v") + self.assertIsInstance(lr, Graph) + self.assertIsInstance(cr, Graph) + self.assertIs(lr, self.local) + self.assertIs(cr, self.cloud) + self.local.delete("_tmp", "r", "v") + self.cloud.delete("_tmp", "r", "v") + + def test_delete_returns_self(self): + """delete() returns the Graph object for method chaining.""" + self.local.put("_d", "r", "v") + self.cloud.put("_d", "r", "v") + lr = self.local.delete("_d", "r", "v") + cr = self.cloud.delete("_d", "r", "v") + self.assertIsInstance(lr, Graph) + self.assertIsInstance(cr, Graph) + + def test_method_chaining_put(self): + """g.put(...).put(...) 
works on both backends.""" + lr = self.local.put("_c1", "r", "v1").put("_c2", "r", "v2") + cr = self.cloud.put("_c1", "r", "v1").put("_c2", "r", "v2") + self.assertIsInstance(lr, Graph) + self.assertIsInstance(cr, Graph) + for s, o in [("_c1", "v1"), ("_c2", "v2")]: + self.local.delete(s, "r", o) + self.cloud.delete(s, "r", o) + + # ====================================================================== # + # Traversals — v(), out(), inc(), has(), hasr() # + # ====================================================================== # + + def test_v_all(self): + """g.v().all() has {'result': [...]} with no extra keys.""" + lr = self.local.v().all() + cr = self.cloud.v().all() + self.assert_same_shape(lr, cr, "v().all()") + self.assertEqual(set(lr.keys()), {"result"}) + self.assertEqual(set(cr.keys()), {"result"}) + for item in cr["result"]: + self.assertIn("id", item) + + def test_v_vertex_out_all(self): + """g.v('alice').out('knows').all()""" + lr = self.local.v("alice").out("knows").all() + cr = self.cloud.v("alice").out("knows").all() + self.assert_same_result_set(lr, cr, "v('alice').out('knows').all()") + + def test_v_inc_all(self): + """g.v('bob').inc('knows').all()""" + lr = self.local.v("bob").inc("knows").all() + cr = self.cloud.v("bob").inc("knows").all() + self.assert_same_result_set(lr, cr, "v('bob').inc('knows').all()") + + def test_has_filter(self): + """g.v().has('works_at', 'acme').all()""" + lr = self.local.v().has("works_at", "acme").all() + cr = self.cloud.v().has("works_at", "acme").all() + self.assert_same_result_set(lr, cr, "has('works_at','acme')") + + def test_hasr(self): + """g.v().hasr('knows', 'alice').all()""" + lr = self.local.v().hasr("knows", "alice").all() + cr = self.cloud.v().hasr("knows", "alice").all() + self.assert_same_result_set(lr, cr, "hasr('knows','alice')") + + def test_chained_out(self): + """g.v('alice').out('knows').out('knows').all()""" + lr = self.local.v("alice").out("knows").out("knows").all() + cr = 
self.cloud.v("alice").out("knows").out("knows").all() + self.assert_same_result_set(lr, cr, "chained out().out()") + + def test_v_list(self): + """g.v(['alice', 'bob']).all()""" + lr = self.local.v(["alice", "bob"]).all() + cr = self.cloud.v(["alice", "bob"]).all() + self.assert_same_result_set(lr, cr, "v([list]).all()") + + # ====================================================================== # + # count() # + # ====================================================================== # + + def test_count(self): + """g.v('alice').out('knows').count() returns same int.""" + lr = self.local.v("alice").out("knows").count() + cr = self.cloud.v("alice").out("knows").count() + self.assertIsInstance(lr, int) + self.assertIsInstance(cr, int) + self.assertEqual(lr, cr) + + def test_v_count_all(self): + """g.v().count() returns same int.""" + lr = self.local.v().count() + cr = self.cloud.v().count() + self.assertIsInstance(lr, int) + self.assertIsInstance(cr, int) + self.assertEqual(lr, cr) + + def test_count_empty(self): + """count() on empty result returns 0 on both.""" + lr = self.local.v("nonexistent_xyz").out("knows").count() + cr = self.cloud.v("nonexistent_xyz").out("knows").count() + self.assertEqual(lr, 0) + self.assertEqual(cr, 0) + + # ====================================================================== # + # scan() # + # ====================================================================== # + + def test_scan_shape(self): + """g.scan() returns {'result': [...]} with no extra keys.""" + lr = self.local.scan(limit=5) + cr = self.cloud.scan(limit=5) + self.assert_same_shape(lr, cr, "scan()") + self.assertEqual(set(lr.keys()), {"result"}) + self.assertEqual(set(cr.keys()), {"result"}) + self.assertEqual(len(lr["result"]), len(cr["result"])) + + def test_scan_items_have_id(self): + """Each scan result item has an 'id' key.""" + for item in self.local.scan()["result"]: + self.assertIn("id", item) + for item in self.cloud.scan()["result"]: + self.assertIn("id", 
item) + + # ====================================================================== # + # Edge cases # + # ====================================================================== # + + def test_empty_result(self): + """Query with no matches returns same empty structure.""" + lr = self.local.v("nonexistent_xyz").out("knows").all() + cr = self.cloud.v("nonexistent_xyz").out("knows").all() + self.assert_same_shape(lr, cr, "empty result") + self.assertEqual(lr["result"], []) + self.assertEqual(cr["result"], []) + + # ====================================================================== # + # Lifecycle no-ops # + # ====================================================================== # + + def test_sync_noop(self): + """sync() does not raise on either backend.""" + self.local.sync() + self.cloud.sync() + + def test_refresh_noop(self): + """refresh() does not raise on either backend.""" + self.local.refresh() + self.cloud.refresh() + + def test_close_safe(self): + """close() does not raise on cloud.""" + tmp = Graph(graph_name="parity_close_test", api_key=CLOUD_API_KEY) + tmp.close() + + # ====================================================================== # + # Deep / complex tests # + # ====================================================================== # + + # ── tag() + back() round-trip ─────────────────────────────────────────── # + + def test_tag_appears_in_all_results(self): + """tag('x') labels propagate identically into all() dicts.""" + lr = self.local.v("alice").tag("origin").out("knows").all() + cr = self.cloud.v("alice").tag("origin").out("knows").all() + self.assert_same_shape(lr, cr, "tag in all()") + # Every result should carry the 'origin' tag + for item in lr["result"]: + self.assertIn("origin", item) + for item in cr["result"]: + self.assertIn("origin", item) + self.assert_same_result_set(lr, cr, "tag values") + + def test_tag_back_returns_to_origin(self): + """v().tag('start').out().back('start') returns the starting vertices.""" + lr = 
self.local.v("alice").tag("start").out("knows").back("start").all() + cr = self.cloud.v("alice").tag("start").out("knows").back("start").all() + self.assert_same_result_set(lr, cr, "tag/back round-trip") + # Should return alice (the tagged vertex), not the traversed neighbours + local_ids = {item["id"] for item in lr["result"]} + cloud_ids = {item["id"] for item in cr["result"]} + self.assertEqual(local_ids, cloud_ids) + self.assertIn("alice", local_ids) + + def test_multi_tag_back(self): + """Two tags at different depths, back() to first.""" + lr = (self.local.v("alice").tag("t1") + .out("knows").tag("t2") + .out("works_at") + .back("t1").all()) + cr = (self.cloud.v("alice").tag("t1") + .out("knows").tag("t2") + .out("works_at") + .back("t1").all()) + self.assert_same_result_set(lr, cr, "multi-tag back(t1)") + + # ── order(), limit(), skip() ──────────────────────────────────────────── # + + def test_order_asc(self): + """v().order('asc').all() returns vertices sorted ascending.""" + lr = self.local.v().order("asc").all() + cr = self.cloud.v().order("asc").all() + self.assert_same_shape(lr, cr, "order asc") + local_ids = [item["id"] for item in lr["result"]] + cloud_ids = [item["id"] for item in cr["result"]] + self.assertEqual(local_ids, sorted(local_ids)) + self.assertEqual(cloud_ids, sorted(cloud_ids)) + self.assertEqual(local_ids, cloud_ids) + + def test_order_desc(self): + """v().order('desc').all() returns vertices sorted descending.""" + lr = self.local.v().order("desc").all() + cr = self.cloud.v().order("desc").all() + local_ids = [item["id"] for item in lr["result"]] + cloud_ids = [item["id"] for item in cr["result"]] + self.assertEqual(local_ids, sorted(local_ids, reverse=True)) + self.assertEqual(local_ids, cloud_ids) + + def test_limit(self): + """v().order('asc').limit(2).all() returns exactly 2 items.""" + lr = self.local.v().order("asc").limit(2).all() + cr = self.cloud.v().order("asc").limit(2).all() + self.assertEqual(len(lr["result"]), 2) + 
self.assertEqual(len(cr["result"]), 2) + self.assert_same_result_set(lr, cr, "limit(2)") + + def test_skip(self): + """v().order('asc').skip(2).all() skips first 2 items.""" + full_lr = self.local.v().order("asc").all() + lr = self.local.v().order("asc").skip(2).all() + cr = self.cloud.v().order("asc").skip(2).all() + expected_count = len(full_lr["result"]) - 2 + self.assertEqual(len(lr["result"]), expected_count) + self.assertEqual(len(cr["result"]), expected_count) + self.assert_same_result_set(lr, cr, "skip(2)") + + def test_limit_skip_pagination(self): + """order + skip + limit simulates pagination identically.""" + lr_page1 = self.local.v().order("asc").limit(3).all() + cr_page1 = self.cloud.v().order("asc").limit(3).all() + lr_page2 = self.local.v().order("asc").skip(3).limit(3).all() + cr_page2 = self.cloud.v().order("asc").skip(3).limit(3).all() + self.assert_same_result_set(lr_page1, cr_page1, "page1") + self.assert_same_result_set(lr_page2, cr_page2, "page2") + # Pages must not overlap + ids_p1 = {item["id"] for item in lr_page1["result"]} + ids_p2 = {item["id"] for item in lr_page2["result"]} + self.assertTrue(ids_p1.isdisjoint(ids_p2), "Pages overlap!") + + # ── both() traversal ──────────────────────────────────────────────────── # + + def test_both(self): + """v('bob').both('knows').all() follows edges in both directions.""" + lr = self.local.v("bob").both("knows").all() + cr = self.cloud.v("bob").both("knows").all() + self.assert_same_result_set(lr, cr, "both('knows')") + # bob knows charlie, and alice knows bob → both should appear + ids = {item["id"] for item in lr["result"]} + self.assertIn("charlie", ids) + self.assertIn("alice", ids) + + # ── is_() filtering ───────────────────────────────────────────────────── # + + def test_is_filter(self): + """v('alice').out('knows').is_('bob').all() filters to bob only.""" + lr = self.local.v("alice").out("knows").is_("bob").all() + cr = self.cloud.v("alice").out("knows").is_("bob").all() + 
self.assert_same_result_set(lr, cr, "is_('bob')") + self.assertEqual(len(lr["result"]), 1) + self.assertEqual(lr["result"][0]["id"], "bob") + + def test_is_multiple(self): + """is_() with multiple args.""" + lr = self.local.v().is_("alice", "charlie").all() + cr = self.cloud.v().is_("alice", "charlie").all() + self.assert_same_result_set(lr, cr, "is_ multi") + ids = {item["id"] for item in lr["result"]} + self.assertEqual(ids, {"alice", "charlie"}) + + # ── unique() deduplication ────────────────────────────────────────────── # + + def test_unique(self): + """unique() removes duplicate vertices from multi-path results.""" + # alice and charlie both work_at acme; traversing inc may yield dupes + lr = self.local.v("acme").inc("works_at").unique().all() + cr = self.cloud.v("acme").inc("works_at").unique().all() + self.assert_same_result_set(lr, cr, "unique()") + local_ids = [item["id"] for item in lr["result"]] + self.assertEqual(len(local_ids), len(set(local_ids)), "Duplicates in local") + + # ── BFS / DFS traversals ──────────────────────────────────────────────── # + + def test_bfs_basic(self): + """BFS from alice over 'knows' edges with max_depth=2.""" + lr = self.local.v("alice").bfs(predicates="knows", max_depth=2).all() + cr = self.cloud.v("alice").bfs(predicates="knows", max_depth=2).all() + self.assert_same_result_set(lr, cr, "bfs depth=2") + + def test_bfs_min_depth(self): + """BFS with min_depth=2 skips depth-1 neighbours.""" + lr = self.local.v("alice").bfs(predicates="knows", max_depth=2, min_depth=2).all() + cr = self.cloud.v("alice").bfs(predicates="knows", max_depth=2, min_depth=2).all() + self.assert_same_result_set(lr, cr, "bfs min_depth=2") + # depth-1 neighbour (bob) should not appear + ids = {item["id"] for item in lr["result"]} + self.assertNotIn("bob", ids) + + def test_dfs_basic(self): + """DFS from alice over 'knows' edges with max_depth=2.""" + lr = self.local.v("alice").dfs(predicates="knows", max_depth=2).all() + cr = 
self.cloud.v("alice").dfs(predicates="knows", max_depth=2).all() + self.assert_same_result_set(lr, cr, "dfs depth=2") + + def test_bfs_both_direction(self): + """BFS with direction='both' follows edges in both directions.""" + lr = self.local.v("bob").bfs(predicates="knows", max_depth=1, direction="both").all() + cr = self.cloud.v("bob").bfs(predicates="knows", max_depth=1, direction="both").all() + self.assert_same_result_set(lr, cr, "bfs both") + + # ── graph() terminal ──────────────────────────────────────────────────── # + + def test_graph_structure(self): + """graph() returns {nodes, links} with matching sets.""" + lr = self.local.v("alice").out("knows").graph() + cr = self.cloud.v("alice").out("knows").graph() + self.assert_same_shape(lr, cr, "graph()") + self.assertIn("nodes", lr) + self.assertIn("links", lr) + self.assertIn("nodes", cr) + self.assertIn("links", cr) + # Compare node id sets + local_node_ids = {n["id"] for n in lr["nodes"]} + cloud_node_ids = {n["id"] for n in cr["nodes"]} + self.assertEqual(local_node_ids, cloud_node_ids) + + # ── triples() terminal ────────────────────────────────────────────────── # + + def test_triples(self): + """triples() returns the same set of (s, p, o) tuples.""" + lr = self.local.triples() + cr = self.cloud.triples() + self.assertIsInstance(lr, list) + self.assertIsInstance(cr, list) + self.assertEqual(sorted(lr), sorted(cr)) + + # ── put_batch + query verification ────────────────────────────────────── # + + def test_put_batch_and_query(self): + """put_batch() inserts are query-visible on both backends.""" + batch = [ + ("_pb_x", "rel", "_pb_y"), + ("_pb_y", "rel", "_pb_z"), + ("_pb_x", "rel", "_pb_z"), + ] + self.local.put_batch(batch) + self.cloud.put_batch(batch) + self.cloud.sync() + time.sleep(1) + + lr = self.local.v("_pb_x").out("rel").all() + cr = self.cloud.v("_pb_x").out("rel").all() + self.assert_same_result_set(lr, cr, "put_batch query") + self.assertEqual(len(lr["result"]), 2) + + # Clean up + for s, 
p, o in batch: + self.local.delete(s, p, o) + self.cloud.delete(s, p, o) + + # ── delete + verify ───────────────────────────────────────────────────── # + + def test_delete_removes_triple(self): + """Deleting a triple makes it invisible on both backends.""" + self.local.put("_del_a", "link", "_del_b") + self.cloud.put("_del_a", "link", "_del_b") + self.cloud.sync() + time.sleep(1) + + # Verify it exists + lr = self.local.v("_del_a").out("link").all() + cr = self.cloud.v("_del_a").out("link").all() + self.assertEqual(len(lr["result"]), 1) + self.assertEqual(len(cr["result"]), 1) + + # Delete and verify gone + self.local.delete("_del_a", "link", "_del_b") + self.cloud.delete("_del_a", "link", "_del_b") + self.cloud.sync() + time.sleep(1) + + lr = self.local.v("_del_a").out("link").all() + cr = self.cloud.v("_del_a").out("link").all() + self.assertEqual(lr["result"], []) + self.assertEqual(cr["result"], []) + + # ── Complex multi-step traversals ─────────────────────────────────────── # + + def test_out_then_has(self): + """v('alice').out('knows').has('works_at', 'acme') — traverse then filter.""" + lr = self.local.v("alice").out("knows").has("works_at", "acme").all() + cr = self.cloud.v("alice").out("knows").has("works_at", "acme").all() + self.assert_same_result_set(lr, cr, "out+has") + + def test_inc_then_out(self): + """Reverse then forward: inc('works_at').out('knows').""" + lr = self.local.v("acme").inc("works_at").out("knows").all() + cr = self.cloud.v("acme").inc("works_at").out("knows").all() + self.assert_same_result_set(lr, cr, "inc+out chain") + + def test_deep_chain_tag_is(self): + """Deep chain: v → tag → out → out → is_ → all with tag in results.""" + lr = (self.local.v("alice").tag("root") + .out("knows").out("works_at").is_("acme").all()) + cr = (self.cloud.v("alice").tag("root") + .out("knows").out("works_at").is_("acme").all()) + self.assert_same_result_set(lr, cr, "deep chain tag+is_") + for item in lr["result"]: + 
self.assertEqual(item.get("root"), "alice") + for item in cr["result"]: + self.assertEqual(item.get("root"), "alice") + + def test_v_list_out_order_limit(self): + """v([list]).out().order().limit() — combined pipeline.""" + lr = (self.local.v(["alice", "bob"]).out("knows") + .order("asc").limit(2).all()) + cr = (self.cloud.v(["alice", "bob"]).out("knows") + .order("asc").limit(2).all()) + self.assert_same_result_set(lr, cr, "v-list+out+order+limit") + self.assertEqual(len(lr["result"]), 2) + + def test_count_after_complex_traversal(self): + """count() at end of multi-step chain.""" + lr = self.local.v("alice").out("knows").out("works_at").count() + cr = self.cloud.v("alice").out("knows").out("works_at").count() + self.assertIsInstance(lr, int) + self.assertIsInstance(cr, int) + self.assertEqual(lr, cr) + + # ── ls() graph listing ────────────────────────────────────────────────── # + + def test_ls_contains_graph(self): + """ls() includes the current graph name on both backends.""" + lr = self.local.ls() + cr = self.cloud.ls() + self.assertIsInstance(lr, list) + self.assertIsInstance(cr, list) + self.assertIn(self.graph_name, cr) + + # ── scan with limit and type ──────────────────────────────────────────── # + + def test_scan_limit_respected(self): + """scan(limit=3) returns at most 3 items on both.""" + lr = self.local.scan(limit=3) + cr = self.cloud.scan(limit=3) + self.assertLessEqual(len(lr["result"]), 3) + self.assertLessEqual(len(cr["result"]), 3) + self.assert_same_shape(lr, cr, "scan limit=3") + + def test_scan_edges(self): + """scan(scan_type='e') returns edge results on both.""" + lr = self.local.scan(limit=5, scan_type="e") + cr = self.cloud.scan(limit=5, scan_type="e") + self.assert_same_shape(lr, cr, "scan edges") + self.assertEqual(set(lr.keys()), {"result"}) + self.assertEqual(set(cr.keys()), {"result"}) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_db_2.py b/test/test_db_2.py index 788e71b..fcc86d5 100644 --- 
a/test/test_db_2.py +++ b/test/test_db_2.py @@ -82,8 +82,9 @@ def test_put_same_value_multiple_times(self): cogdb.close() - def test_zzz_after_all_tests(self): - shutil.rmtree('/tmp/cogtestdb2') - shutil.rmtree('/tmp/cogtestdb3') - shutil.rmtree('/tmp/cogtestdb4') + @classmethod + def tearDownClass(cls): + shutil.rmtree('/tmp/cogtestdb2', ignore_errors=True) + shutil.rmtree('/tmp/cogtestdb3', ignore_errors=True) + shutil.rmtree('/tmp/cogtestdb4', ignore_errors=True) print("*** deleted test data.") diff --git a/test/test_loopback_parity.py b/test/test_loopback_parity.py new file mode 100644 index 0000000..d1c73bb --- /dev/null +++ b/test/test_loopback_parity.py @@ -0,0 +1,534 @@ +""" +Loopback Parity Tests + +Tests the full code path in torque.py (chain accumulation, +CloudClient serialization, response normalization) by replacing the HTTP +transport with a loopback that executes queries on a second local Graph. +""" + +import os +import re +import shutil +import unittest + +from cog.torque import Graph + +# Whitelist mirroring cog/server.py _execute_query +_ALLOWED_METHODS = { + 'v', 'out', 'inc', 'both', 'has', 'hasr', 'tag', 'back', + 'all', 'count', 'first', 'one', 'scan', 'filter', 'unique', 'limit', 'skip', + 'is_', 'bfs', 'dfs', 'sim', 'k_nearest', 'order', +} + +_METHOD_RE = re.compile(r'\.?([a-zA-Z_][a-zA-Z0-9_]*)\s*\(') + + +def _execute_query(graph, query_str): + """Safely eval a Torque query string — mirrors server._execute_query.""" + query_str = query_str.strip() + + allowed_starts = ('v(', 'scan(') + if not any(query_str.startswith(s) for s in allowed_starts): + raise ValueError(f"Query must start with one of: {list(allowed_starts)}") + + if '__' in query_str: + raise ValueError("Query contains forbidden pattern '__'") + + methods_used = set(_METHOD_RE.findall(query_str)) + invalid = methods_used - _ALLOWED_METHODS + if invalid: + raise ValueError(f"Disallowed methods: {invalid}") + + full_query = f"graph.{query_str}" + compile(full_query, '', 'eval') 
+ result = eval(full_query, {"__builtins__": {}}, {"graph": graph}) # noqa: S307 + + if isinstance(result, dict): + return result + if isinstance(result, int): + return {"result": result} + return {"result": []} + + +class LoopbackTransport: + """ + Routes cloud HTTP calls to a second *local* Graph so that the full + serialization ↔ deserialization round-trip is exercised without + touching the network. + """ + + def __init__(self, backing_graph): + self.g = backing_graph + + def __call__(self, method, path, body=None): + if path == "/mutate_batch": + return self._mutate_batch(body) + if path == "/query": + return self._query(body) + raise ValueError(f"LoopbackTransport: unhandled path {path}") + + # ── mutations ──────────────────────────────────────────────────── + + def _mutate_batch(self, body): + mutations = body.get("mutations", []) + for m in mutations: + self._apply_one(m) + return {"ok": True, "count": len(mutations)} + + def _apply_one(self, m): + op = m.get("op", "") + if op == "PUT": + self.g.put( + m["s"], m["p"], m["o"], + update=m.get("update", False), + create_new_edge=m.get("create_new_edge", False), + ) + elif op == "DELETE": + self.g.delete(m["s"], m["p"], m["o"]) + elif op == "DROP": + self.g.drop() + elif op == "TRUNCATE": + self.g.truncate() + elif op == "PUT_EMBEDDING": + self.g.put_embedding(m["word"], m["embedding"]) + elif op == "DELETE_EMBEDDING": + self.g.delete_embedding(m["word"]) + else: + raise ValueError(f"LoopbackTransport: unknown mutation op={op}") + + # ── queries ────────────────────────────────────────────────────── + + def _query(self, body): + q = body["q"] + + # Embedding helpers — not standard traversals + if q == "embedding_stats()": + result = self.g.embedding_stats() + return {"ok": True, **result} + + m = re.match(r'^get_embedding\("(.+)"\)$', q) + if m: + emb = self.g.get_embedding(m.group(1)) + return {"ok": True, "embedding": emb} + + m = re.match(r'^scan_embeddings\((\d+)\)$', q) + if m: + result = 
self.g.scan_embeddings(limit=int(m.group(1))) + return {"ok": True, **result} + + # Standard traversal queries (v(...), scan(...)) + result = _execute_query(self.g, q) + # Mirror the server handler: {"ok": True, "result": } + return {"ok": True, "result": result.get("result", result)} + + +# ─────────────────────────────────────────────────────────────────────────── # +# Helpers # +# ─────────────────────────────────────────────────────────────────────────── # + +LOCAL_DIR = "/tmp/LoopbackParityLocal" +CLOUD_DIR = "/tmp/LoopbackParityCloud" + + +def ordered(obj): + """Recursively sort dicts/lists for deterministic comparison.""" + if isinstance(obj, dict): + return sorted((k, ordered(v)) for k, v in obj.items()) + if isinstance(obj, list): + return sorted(ordered(x) for x in obj) + return obj + + +def _make_cloud_graph(graph_name, backing_graph): + """Create a Graph in cloud mode whose CloudClient routes to *backing_graph*.""" + g = Graph(graph_name=graph_name, api_key="loopback-key") + # Patch the transport layer + g._cloud_client._request = LoopbackTransport(backing_graph) + return g + + +# ─────────────────────────────────────────────────────────────────────────── # +# Test class # +# ─────────────────────────────────────────────────────────────────────────── # + +class TestLoopbackParity(unittest.TestCase): + """ + Seeds identical data into a local graph and a loopback-cloud graph, + runs the same operations, and asserts the responses are identical. 
+ """ + + maxDiff = None + + # ── fixtures ───────────────────────────────────────────────────── + + @classmethod + def setUpClass(cls): + for d in (LOCAL_DIR, CLOUD_DIR): + if os.path.exists(d): + shutil.rmtree(d) + os.makedirs(d, exist_ok=True) + + graph_name = "loopback_parity" + + # Reference local graph + cls.local = Graph(graph_name=graph_name, cog_home="LoopbackParityLocal") + assert not cls.local._cloud + + # Backing graph for the loopback transport (separate storage) + cls._backing = Graph(graph_name=graph_name, cog_home="LoopbackParityCloud") + assert not cls._backing._cloud + + # Cloud graph wired to the backing local graph + cls.cloud = _make_cloud_graph(graph_name, cls._backing) + assert cls.cloud._cloud + + # Seed identical data + cls.triples = [ + ("alice", "knows", "bob"), + ("bob", "knows", "charlie"), + ("charlie", "knows", "alice"), + ("alice", "works_at", "acme"), + ("bob", "works_at", "globex"), + ("charlie", "works_at", "acme"), + ("alice", "age", "30"), + ("bob", "age", "25"), + ("charlie", "age", "35"), + ] + for s, p, o in cls.triples: + cls.local.put(s, p, o) + cls.cloud.put(s, p, o) + + @classmethod + def tearDownClass(cls): + cls.local.close() + cls._backing.close() + for d in (LOCAL_DIR, CLOUD_DIR): + shutil.rmtree(d, ignore_errors=True) + + # ── assertion helpers ──────────────────────────────────────────── + + def assert_same_type(self, local_result, cloud_result, ctx=""): + self.assertEqual( + type(local_result), type(cloud_result), + f"Type mismatch ({ctx}): local={type(local_result).__name__}, " + f"cloud={type(cloud_result).__name__}", + ) + + def assert_same_shape(self, local_result, cloud_result, ctx=""): + self.assert_same_type(local_result, cloud_result, ctx) + if isinstance(local_result, dict) and isinstance(cloud_result, dict): + self.assertEqual(set(local_result.keys()), set(cloud_result.keys()), + f"Key mismatch ({ctx})") + + def assert_same_result_set(self, local_result, cloud_result, ctx=""): + 
self.assert_same_shape(local_result, cloud_result, ctx) + self.assertEqual(ordered(local_result), ordered(cloud_result), + f"Value mismatch ({ctx})") + + # ================================================================ # + # Mutations — return types # + # ================================================================ # + + def test_put_returns_self(self): + lr = self.local.put("_tmp", "r", "v") + cr = self.cloud.put("_tmp", "r", "v") + self.assertIsInstance(lr, Graph) + self.assertIsInstance(cr, Graph) + self.assertIs(lr, self.local) + self.assertIs(cr, self.cloud) + self.local.delete("_tmp", "r", "v") + self.cloud.delete("_tmp", "r", "v") + + def test_delete_returns_self(self): + self.local.put("_d", "r", "v") + self.cloud.put("_d", "r", "v") + lr = self.local.delete("_d", "r", "v") + cr = self.cloud.delete("_d", "r", "v") + self.assertIsInstance(lr, Graph) + self.assertIsInstance(cr, Graph) + + def test_method_chaining_put(self): + lr = self.local.put("_c1", "r", "v1").put("_c2", "r", "v2") + cr = self.cloud.put("_c1", "r", "v1").put("_c2", "r", "v2") + self.assertIsInstance(lr, Graph) + self.assertIsInstance(cr, Graph) + for s, o in [("_c1", "v1"), ("_c2", "v2")]: + self.local.delete(s, "r", o) + self.cloud.delete(s, "r", o) + + def test_put_batch(self): + batch = [("_b1", "r", "x"), ("_b2", "r", "y")] + lr = self.local.put_batch(batch) + cr = self.cloud.put_batch(batch) + self.assertIsInstance(lr, Graph) + self.assertIsInstance(cr, Graph) + # Verify data landed + lr_data = self.local.v("_b1").out("r").all() + cr_data = self.cloud.v("_b1").out("r").all() + self.assert_same_result_set(lr_data, cr_data, "put_batch verify") + for s, _, o in batch: + self.local.delete(s, "r", o) + self.cloud.delete(s, "r", o) + + # ================================================================ # + # Traversals # + # ================================================================ # + + def test_v_all(self): + lr = self.local.v().all() + cr = self.cloud.v().all() + 
self.assert_same_shape(lr, cr, "v().all()") + self.assertEqual(set(lr.keys()), {"result"}) + self.assertEqual(set(cr.keys()), {"result"}) + + def test_v_vertex_out_all(self): + lr = self.local.v("alice").out("knows").all() + cr = self.cloud.v("alice").out("knows").all() + self.assert_same_result_set(lr, cr, "v('alice').out('knows').all()") + + def test_v_inc_all(self): + lr = self.local.v("bob").inc("knows").all() + cr = self.cloud.v("bob").inc("knows").all() + self.assert_same_result_set(lr, cr, "v('bob').inc('knows').all()") + + def test_has_filter(self): + lr = self.local.v().has("works_at", "acme").all() + cr = self.cloud.v().has("works_at", "acme").all() + self.assert_same_result_set(lr, cr, "has('works_at','acme')") + + def test_hasr(self): + lr = self.local.v().hasr("knows", "alice").all() + cr = self.cloud.v().hasr("knows", "alice").all() + self.assert_same_result_set(lr, cr, "hasr('knows','alice')") + + def test_both(self): + lr = self.local.v("bob").both("knows").all() + cr = self.cloud.v("bob").both("knows").all() + self.assert_same_result_set(lr, cr, "both('knows')") + + def test_chained_out(self): + lr = self.local.v("alice").out("knows").out("knows").all() + cr = self.cloud.v("alice").out("knows").out("knows").all() + self.assert_same_result_set(lr, cr, "chained out()") + + def test_v_list(self): + lr = self.local.v(["alice", "bob"]).all() + cr = self.cloud.v(["alice", "bob"]).all() + self.assert_same_result_set(lr, cr, "v([list]).all()") + + def test_v_all_vertices(self): + """v() with no args returns all vertices.""" + lr = self.local.v().all() + cr = self.cloud.v().all() + self.assertEqual(len(lr["result"]), len(cr["result"])) + + # ================================================================ # + # Intermediate ops — unique, limit, skip, order, is_, tag, back # + # ================================================================ # + + def test_unique(self): + lr = self.local.v("alice").out("knows").out("knows").unique().all() + cr = 
self.cloud.v("alice").out("knows").out("knows").unique().all() + self.assert_same_result_set(lr, cr, "unique()") + + def test_limit(self): + lr = self.local.v().limit(2).all() + cr = self.cloud.v().limit(2).all() + self.assert_same_shape(lr, cr, "limit()") + self.assertEqual(len(lr["result"]), len(cr["result"])) + + def test_skip(self): + lr = self.local.v().skip(2).all() + cr = self.cloud.v().skip(2).all() + self.assert_same_shape(lr, cr, "skip()") + + def test_order_asc(self): + lr = self.local.v().order("asc").all() + cr = self.cloud.v().order("asc").all() + # Order should match exactly + self.assertEqual( + [item["id"] for item in lr["result"]], + [item["id"] for item in cr["result"]], + "order(asc) mismatch", + ) + + def test_order_desc(self): + lr = self.local.v().order("desc").all() + cr = self.cloud.v().order("desc").all() + self.assertEqual( + [item["id"] for item in lr["result"]], + [item["id"] for item in cr["result"]], + "order(desc) mismatch", + ) + + def test_is_(self): + lr = self.local.v().out("knows").is_("bob").all() + cr = self.cloud.v().out("knows").is_("bob").all() + self.assert_same_result_set(lr, cr, "is_('bob')") + + def test_tag_all(self): + lr = self.local.v("alice").tag("start").out("knows").all() + cr = self.cloud.v("alice").tag("start").out("knows").all() + self.assert_same_result_set(lr, cr, "tag('start').out().all()") + + def test_back(self): + lr = (self.local.v("alice").tag("origin") + .out("knows").out("works_at").back("origin").all()) + cr = (self.cloud.v("alice").tag("origin") + .out("knows").out("works_at").back("origin").all()) + self.assert_same_result_set(lr, cr, "back('origin')") + + # ================================================================ # + # count() # + # ================================================================ # + + def test_count(self): + lr = self.local.v("alice").out("knows").count() + cr = self.cloud.v("alice").out("knows").count() + self.assertIsInstance(lr, int) + self.assertIsInstance(cr, int) + 
self.assertEqual(lr, cr) + + def test_v_count(self): + lr = self.local.v().count() + cr = self.cloud.v().count() + self.assertIsInstance(lr, int) + self.assertIsInstance(cr, int) + self.assertEqual(lr, cr) + + def test_count_empty(self): + lr = self.local.v("nonexistent").out("knows").count() + cr = self.cloud.v("nonexistent").out("knows").count() + self.assertEqual(lr, 0) + self.assertEqual(cr, 0) + + # ================================================================ # + # scan() # + # ================================================================ # + + def test_scan_shape(self): + lr = self.local.scan(limit=5) + cr = self.cloud.scan(limit=5) + self.assert_same_shape(lr, cr, "scan()") + self.assertEqual(set(lr.keys()), {"result"}) + self.assertEqual(len(lr["result"]), len(cr["result"])) + + def test_scan_items_have_id(self): + for item in self.local.scan()["result"]: + self.assertIn("id", item) + for item in self.cloud.scan()["result"]: + self.assertIn("id", item) + + # ================================================================ # + # Edge cases # + # ================================================================ # + + def test_empty_result(self): + lr = self.local.v("nonexistent").out("knows").all() + cr = self.cloud.v("nonexistent").out("knows").all() + self.assert_same_shape(lr, cr, "empty result") + self.assertEqual(lr["result"], []) + self.assertEqual(cr["result"], []) + + def test_empty_v_all(self): + """v() on an empty graph would return result list.""" + lr = self.local.v("nobody_here").all() + cr = self.cloud.v("nobody_here").all() + self.assert_same_shape(lr, cr, "v(nonexistent).all()") + + # ================================================================ # + # BFS / DFS # + # ================================================================ # + + def test_bfs(self): + lr = self.local.v("alice").bfs("knows", max_depth=2).all() + cr = self.cloud.v("alice").bfs("knows", max_depth=2).all() + self.assert_same_result_set(lr, cr, "bfs(knows, 2)") + + 
def test_dfs(self): + lr = self.local.v("alice").dfs("knows", max_depth=2).all() + cr = self.cloud.v("alice").dfs("knows", max_depth=2).all() + self.assert_same_result_set(lr, cr, "dfs(knows, 2)") + + # ================================================================ # + # Embeddings # + # ================================================================ # + + def test_put_get_embedding(self): + emb = [0.1, 0.2, 0.3, 0.4] + self.local.put_embedding("emb_word", emb) + self.cloud.put_embedding("emb_word", emb) + + lr = self.local.get_embedding("emb_word") + cr = self.cloud.get_embedding("emb_word") + self.assertEqual(lr, cr) + + self.local.delete_embedding("emb_word") + self.cloud.delete_embedding("emb_word") + + def test_embedding_stats(self): + self.local.put_embedding("stat_w", [1.0, 2.0]) + self.cloud.put_embedding("stat_w", [1.0, 2.0]) + + lr = self.local.embedding_stats() + cr = self.cloud.embedding_stats() + self.assert_same_shape(lr, cr, "embedding_stats()") + self.assertEqual(lr["count"], cr["count"]) + + self.local.delete_embedding("stat_w") + self.cloud.delete_embedding("stat_w") + + # ================================================================ # + # Lifecycle no-ops # + # ================================================================ # + + def test_sync_noop(self): + self.local.sync() + self.cloud.sync() + + def test_refresh_noop(self): + self.local.refresh() + self.cloud.refresh() + + def test_close_safe(self): + """close() on cloud graph is a no-op and doesn't raise.""" + tmp = _make_cloud_graph("close_test", self._backing) + tmp.close() + + # ================================================================ # + # Truncate # + # ================================================================ # + + def test_truncate(self): + """truncate() empties the graph; both backends return 0 after.""" + # Separate graphs so we don't affect other tests + ld = "/tmp/LoopbackTruncLocal" + cd = "/tmp/LoopbackTruncCloud" + for d in (ld, cd): + if os.path.exists(d): 
+ shutil.rmtree(d) + os.makedirs(d, exist_ok=True) + + lg = Graph(graph_name="trunc_test", cog_home="LoopbackTruncLocal") + bg = Graph(graph_name="trunc_test", cog_home="LoopbackTruncCloud") + cg = _make_cloud_graph("trunc_test", bg) + + lg.put("x", "r", "y") + cg.put("x", "r", "y") + + lr = lg.truncate() + cr = cg.truncate() + self.assertIsInstance(lr, Graph) + self.assertIsInstance(cr, Graph) + + self.assertEqual(lg.v().count(), 0) + self.assertEqual(cg.v().count(), 0) + + lg.close() + bg.close() + for d in (ld, cd): + shutil.rmtree(d, ignore_errors=True) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_ls_use.py b/test/test_ls_use.py new file mode 100644 index 0000000..f9e6d42 --- /dev/null +++ b/test/test_ls_use.py @@ -0,0 +1,200 @@ +from cog.torque import Graph +import unittest +import os +import shutil + +DIR_NAME = "TestLsUse" + + +class TestDefaultGraphName(unittest.TestCase): + """Test that graph_name defaults to 'default'.""" + + @classmethod + def setUpClass(cls): + if os.path.exists("/tmp/" + DIR_NAME): + shutil.rmtree("/tmp/" + DIR_NAME) + + def test_default_graph_name(self): + """Graph() with no name should use 'default'.""" + g = Graph(cog_home=DIR_NAME) + self.assertEqual(g.graph_name, "default") + g.put("a", "rel", "b") + self.assertEqual(g.v("a").out("rel").count(), 1) + g.close() + + def test_explicit_graph_name_still_works(self): + """Graph('myname') should still work as before.""" + g = Graph("explicit_test", cog_home=DIR_NAME) + self.assertEqual(g.graph_name, "explicit_test") + g.put("x", "rel", "y") + self.assertEqual(g.v("x").out("rel").count(), 1) + g.close() + + @classmethod + def tearDownClass(cls): + if os.path.exists("/tmp/" + DIR_NAME): + shutil.rmtree("/tmp/" + DIR_NAME) + + +class TestLs(unittest.TestCase): + """Test ls() lists all graphs.""" + + @classmethod + def setUpClass(cls): + if os.path.exists("/tmp/" + DIR_NAME): + shutil.rmtree("/tmp/" + DIR_NAME) + + def test_ls_single_graph(self): + """ls() should 
list the current graph.""" + g = Graph("graph_one", cog_home=DIR_NAME) + g.put("a", "rel", "b") + graphs = g.ls() + self.assertIn("graph_one", graphs) + g.close() + + def test_ls_multiple_graphs(self): + """ls() should list all graphs in the cog_home.""" + g1 = Graph("alpha", cog_home=DIR_NAME) + g1.put("a", "rel", "b") + + g2 = Graph("beta", cog_home=DIR_NAME) + g2.put("x", "rel", "y") + + g3 = Graph("gamma", cog_home=DIR_NAME) + g3.put("m", "rel", "n") + + graphs = g1.ls() + self.assertIn("alpha", graphs) + self.assertIn("beta", graphs) + self.assertIn("gamma", graphs) + # Should be sorted + self.assertEqual(graphs, sorted(graphs)) + + g1.close() + g2.close() + g3.close() + + def test_ls_excludes_sys_and_views(self): + """ls() should not include 'sys' or 'views' directories.""" + g = Graph("real_graph", cog_home=DIR_NAME) + g.put("a", "rel", "b") + graphs = g.ls() + self.assertNotIn("sys", graphs) + self.assertNotIn("views", graphs) + self.assertIn("real_graph", graphs) + g.close() + + def test_ls_returns_sorted(self): + """ls() should return graph names in sorted order.""" + g = Graph("zebra", cog_home=DIR_NAME) + g.put("a", "rel", "b") + Graph("aardvark", cog_home=DIR_NAME).put("x", "rel", "y") + + graphs = g.ls() + self.assertEqual(graphs, sorted(graphs)) + g.close() + + @classmethod + def tearDownClass(cls): + if os.path.exists("/tmp/" + DIR_NAME): + shutil.rmtree("/tmp/" + DIR_NAME) + + +class TestUse(unittest.TestCase): + """Test use() switches between graphs.""" + + @classmethod + def setUpClass(cls): + if os.path.exists("/tmp/" + DIR_NAME): + shutil.rmtree("/tmp/" + DIR_NAME) + + def test_use_switches_graph(self): + """use() should switch to a different graph.""" + g = Graph("first", cog_home=DIR_NAME) + g.put("alice", "knows", "bob") + self.assertEqual(g.graph_name, "first") + self.assertEqual(g.v("alice").out("knows").count(), 1) + + # Switch to a new graph + g.use("second") + self.assertEqual(g.graph_name, "second") + # New graph should be empty + 
self.assertEqual(g.v().count(), 0) + + # Add data to second graph + g.put("charlie", "knows", "dave") + self.assertEqual(g.v("charlie").out("knows").count(), 1) + + # Switch back — first graph should still have its data + g.use("first") + self.assertEqual(g.v("alice").out("knows").count(), 1) + # charlie should not be in first graph + self.assertEqual(g.v("charlie").out("knows").count(), 0) + g.close() + + def test_use_returns_self(self): + """use() should return self for method chaining.""" + g = Graph("chain_test", cog_home=DIR_NAME) + g.put("a", "rel", "b") + result = g.use("chain_test") + self.assertIs(result, g) + g.close() + + def test_use_chaining(self): + """use() should support method chaining.""" + g = Graph("chained", cog_home=DIR_NAME) + g.put("alice", "knows", "bob") + + result = g.use("chained").v("alice").out("knows").all() + self.assertEqual(result, {"result": [{"id": "bob"}]}) + g.close() + + def test_use_creates_new_graph(self): + """use() should create the graph if it doesn't exist.""" + g = Graph("starter", cog_home=DIR_NAME) + g.put("a", "rel", "b") + + g.use("brand_new") + self.assertEqual(g.graph_name, "brand_new") + # Should be usable immediately + g.put("x", "rel", "y") + self.assertEqual(g.v("x").out("rel").count(), 1) + + # brand_new should now appear in ls() + self.assertIn("brand_new", g.ls()) + g.close() + + def test_use_with_ls_workflow(self): + """Full workflow: default graph → ls → use.""" + g = Graph(cog_home=DIR_NAME) + self.assertEqual(g.graph_name, "default") + + # Create some named graphs via use + g.use("social") + g.put("alice", "follows", "bob") + g.use("products") + g.put("widget", "category", "tools") + + # List all graphs + graphs = g.ls() + self.assertIn("social", graphs) + self.assertIn("products", graphs) + + # Switch back and verify data isolation + g.use("social") + self.assertEqual(g.v("alice").out("follows").count(), 1) + self.assertEqual(g.v("widget").out("category").count(), 0) + + g.use("products") + 
self.assertEqual(g.v("widget").out("category").count(), 1) + self.assertEqual(g.v("alice").out("follows").count(), 0) + g.close() + + @classmethod + def tearDownClass(cls): + if os.path.exists("/tmp/" + DIR_NAME): + shutil.rmtree("/tmp/" + DIR_NAME) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/test_torque2.py b/test/test_torque2.py index 44021be..aa704e6 100644 --- a/test/test_torque2.py +++ b/test/test_torque2.py @@ -52,7 +52,8 @@ def test_torque_load_csv(self): @classmethod def tearDownClass(cls): - shutil.rmtree("/tmp/"+DIR_NAME) + shutil.rmtree("/tmp/"+DIR_NAME, ignore_errors=True) + shutil.rmtree("/tmp/cog_home", ignore_errors=True) print("*** deleted test data.")