diff --git a/CLAUDE.md b/CLAUDE.md index b15fad77..7830c4d1 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -42,8 +42,7 @@ Tests use **Vitest** for the frontend (React/TypeScript with React Testing Libra **100% code coverage is mandatory.** Any new or modified code — frontend or backend — must maintain 100% coverage across lines, functions, branches, and statements. PRs that drop below 100% coverage will not be merged. -- **Frontend:** Run `bun run test:coverage` and verify all metrics are 100%. -- **Backend:** Run `bun run test:backend:coverage` to enforce 100% line coverage (identical to what CI runs). Functions excluded from coverage with `#[cfg_attr(coverage_nightly, coverage(off))]` must be thin wrappers (Tauri commands, filesystem I/O) whose logic is tested through the functions they delegate to. +**Always run `bun run test:all:coverage` (never the bare `bun run test` / `bun run test:all`).** This single command runs both Vitest with coverage and the cargo llvm-cov gate that CI enforces. If it does not exit cleanly, the task is not done. Functions excluded from coverage with `#[cfg_attr(coverage_nightly, coverage(off))]` must be thin wrappers (Tauri commands, filesystem I/O) whose logic is tested through the functions they delegate to. ## Architecture @@ -143,7 +142,7 @@ When extending the system, preserve this contract: **never panic on user input** After making any code changes and before ending your response, you must: -1. Run `bun run test` — all tests must pass +1. Run `bun run test:all:coverage` — frontend + backend tests must pass AND 100% coverage gate must hold 2. Run `bun run validate-build` — must complete with **zero warnings and zero errors** Do not consider the task done if either step produces any warnings or errors. Fix all issues first. @@ -152,6 +151,16 @@ Do not consider the task done if either step produces any warnings or errors. Fi Never commit files generated by superpowers skills (design specs, implementation plans, brainstorming docs). These live under `docs/superpowers/` which is gitignored. Do not stage or commit anything under that path. +## GStack Design Tooling Fallback + +When invoking GStack design skills (`/design-shotgun`, `/design-html`, `/design-review`, etc.) inside Claude Code on this project: if the design CLI fails because no OpenAI API key is configured (e.g. `setup` not run, `OPENAI_API_KEY` unset, `~/.gstack/openai.json` missing), do not block the user with a setup prompt. Automatically fall back to hand-crafted HTML wireframes that use the real Thuki design tokens read directly from the source files (`src/view/onboarding/PermissionsStep.tsx`, `src/view/onboarding/IntroStep.tsx`, `src/components/`). These wireframes are strictly more accurate to the final UI than image generation because they use the exact CSS values rather than a model's interpretation of them. + +Workflow: +1. Read the relevant source files to extract the actual design tokens (colors, spacing, fonts, border radii, gradients, shadows). +2. Write the wireframes as static HTML files in `~/.gstack/projects/quiet-node-thuki/designs/-/` so they live alongside any future image-based mockups. +3. Open the wireframes in the browser via `open file://...` for review. +4. Only mention the missing API key as a one-line aside, not as a blocker. The user can opt back into image generation later. + ## Key Design Constraints - **macOS only** — uses NSPanel, Core Graphics event taps, macOS Control key diff --git a/README.md b/README.md index b66dd707..4dddb116 100644 --- a/README.md +++ b/README.md @@ -79,7 +79,7 @@ Most AI tools require accounts, API keys, or subscriptions that bill you per tok ### Step 1: Set Up Your AI Engine -> **Default model:** Thuki ships with [`gemma4:e2b`](https://ollama.com/library/gemma4) by default, an effective 2B parameter edge model from Google. It runs comfortably on most modern Macs with 8 GB of RAM and delivers strong performance on reasoning, coding, and vision tasks. To use a different model, edit `~/Library/Application Support/com.quietnode.thuki/config.toml` and reorder the `[model] available` list so your preferred model is first. See [Configurations](docs/configurations.md) for the full schema. +> **Default model:** Thuki ships with [`gemma4:e2b`](https://ollama.com/library/gemma4) by default, an effective 2B parameter edge model from Google. It runs comfortably on most modern Macs with 8 GB of RAM and delivers strong performance on reasoning, coding, and vision tasks. The ask-bar model picker lists the models currently installed in your local Ollama and lets you switch the active model without leaving the overlay. To change the bootstrap default itself, edit `~/Library/Application Support/com.quietnode.thuki/config.toml` and reorder the `[model] available` list so your preferred model is first. See [Configurations](docs/configurations.md) for the full schema. Choose one of the two options below to set up your AI engine before installing Thuki. @@ -256,7 +256,6 @@ The big leap: from answering questions to taking action. More flexibility over the model powering Thuki. - **Native settings panel (⌘,):** a proper macOS preferences window to configure your model, Ollama endpoint, activation shortcut, slash commands, and system prompt. No config files needed. -- **In-app model switching:** swap between any Ollama model from the UI without restarting (the backend already supports multiple models via the `[model] available` list in `config.toml`; the picker UI is next) - **Multiple provider support:** opt in to OpenAI, Anthropic, or any OpenAI-compatible endpoint as an alternative to local Ollama - **Custom activation shortcut:** change the double-tap trigger to any key or combo you prefer diff --git a/docs/configurations.md b/docs/configurations.md index c8c2aa67..21f177a2 100644 --- a/docs/configurations.md +++ b/docs/configurations.md @@ -27,10 +27,9 @@ open ~/Library/Application\ Support/com.quietnode.thuki/config.toml ```toml [model] -# First entry is the ACTIVE model used for all inference. -# Reorder the list to switch models (requires app restart in this release). -# Run `ollama pull ` before adding a model you haven't used. -available = ["gemma4:e2b", "gemma4:e4b"] +# Where Thuki finds your local Ollama server. The active model itself is +# selected from the in-app picker (which lists whatever is installed in +# Ollama via /api/tags) and is stored in Thuki's local database, not here. ollama_url = "http://127.0.0.1:11434" [prompt] @@ -81,16 +80,24 @@ Every domain below is shown as a single table that lists **all** constants Thuki ## Reference -### `[LLM models]` +### `[model]` -Which AI model Thuki uses and where to find your local Ollama server. +Where to find your local Ollama server. The active model itself is **not** a TOML setting: Thuki discovers installed models live from Ollama's `/api/tags` endpoint, lets you pick one from the in-app model picker, and stores that selection in its local SQLite database (`app_config` table). Storing the active slug in TOML would duplicate ground truth from Ollama and break the moment you remove a model with `ollama rm`, so it lives next to the conversation history instead. -| Constant | Default | Tunable? | Why not tunable | Bounds | Description | -| :----------- | :------------------------- | :------- | :-------------- | :------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `available` | `["gemma4:e2b"]` | Yes | — | non-empty list | The list of Ollama models Thuki knows about. **The first model in the list is the one Thuki actually uses.** To switch models, reorder the list. Make sure to run `ollama pull ` before adding a new entry here. | -| `ollama_url` | `"http://127.0.0.1:11434"` | Yes | — | non-empty URL | The web address where Thuki finds your local Ollama server. The default works if you run Ollama on this machine with its standard port. Change this only if you moved Ollama to a different port or another machine. | +| Constant | Default | Tunable? | Why not tunable | Bounds | Description | +| :----------- | :------------------------- | :------- | :-------------- | :------------ | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `ollama_url` | `"http://127.0.0.1:11434"` | Yes | — | non-empty URL | The web address where Thuki finds your local Ollama server. The default works if you run Ollama on this machine with its standard port. Change this only if you moved Ollama to a different port or another machine. | -If the active model has not been pulled, the next request surfaces a "Model not found" error with the exact `ollama pull ` command to run. +If the active model has been removed from Ollama between launches, Thuki silently falls back to the first installed model the next time you open the picker. If no models are installed at all, the next request surfaces a "Model not found" error with the exact `ollama pull ` command to run. + +The table below also lists the baked-in safety limits that govern Thuki's communication with the Ollama HTTP API. None are tunable. + +| Constant | Default | Tunable? | Why not tunable | Bounds | Description | +| :------------------------------------------ | :------- | :------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :----- | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `DEFAULT_OLLAMA_TAGS_REQUEST_TIMEOUT_SECS` | `5 s` | No | Protocol cap on a hung daemon to keep the UI responsive. A longer timeout would wedge the model picker; a shorter one would false-trigger on a momentarily slow daemon. | — | How long Thuki waits for Ollama's `/api/tags` endpoint to respond before giving up. If Ollama accepts the connection but never replies, this prevents the picker from stalling. | +| `DEFAULT_OLLAMA_SHOW_REQUEST_TIMEOUT_SECS` | `5 s` | No | Protocol cap on a hung daemon to keep the UI responsive. Same rationale as the tags timeout above. | — | How long Thuki waits for Ollama's `/api/show` endpoint to respond before giving up. Used when fetching capability flags (vision, thinking) for each installed model. | +| `MAX_OLLAMA_TAGS_BODY_BYTES` | `4 MiB` | No | Defense-in-depth bound on attacker-controlled response body. A misbehaving or compromised Ollama could otherwise stream an unbounded payload and exhaust memory. | — | The largest `/api/tags` response body Thuki will accept. 4 MiB fits thousands of model entries; anything larger is rejected immediately and the request returns an error. | +| `MAX_OLLAMA_SHOW_BODY_BYTES` | `4 MiB` | No | Defense-in-depth bound on attacker-controlled response body. Same rationale as `MAX_OLLAMA_TAGS_BODY_BYTES`. | — | The largest `/api/show` response body Thuki will accept. Full Modelfiles and parameters can be sizable, but 4 MiB is well above any real model; larger responses are rejected. | ### `[prompt]` diff --git a/src-tauri/src/commands.rs b/src-tauri/src/commands.rs index b5eb213f..17702ab6 100644 --- a/src-tauri/src/commands.rs +++ b/src-tauri/src/commands.rs @@ -30,19 +30,42 @@ pub struct OllamaError { pub message: String, } -/// Maps an HTTP status code to a user-friendly `OllamaError`. The `model_name` -/// is woven into the `ModelNotFound` hint so the user sees the exact command -/// to run, whatever their active model happens to be. -pub fn classify_http_error(status: u16, model_name: &str) -> OllamaError { +/// Pulls the human-readable reason out of an Ollama error payload. Ollama +/// returns `{"error":"..."}` on every non-2xx status from `/api/chat`; when +/// the body is empty, malformed, or missing the `error` key we return +/// `None` so the caller can fall back to the bare status code. +pub fn extract_ollama_error_message(body: &str) -> Option { + let trimmed = body.trim(); + if trimmed.is_empty() { + return None; + } + serde_json::from_str::(trimmed) + .ok() + .and_then(|v| v.get("error").and_then(|e| e.as_str()).map(str::to_string)) + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) +} + +/// Maps an HTTP status code (plus the response body for non-404 paths) to a +/// user-friendly `OllamaError`. The `model_name` is woven into the +/// `ModelNotFound` hint so the user sees the exact command to run; for every +/// other status we surface the concrete reason Ollama returned (e.g. "this +/// model only supports one image while more than one image requested") so +/// the user can act on it instead of staring at a bare HTTP code. +pub fn classify_http_error(status: u16, model_name: &str, body: &str) -> OllamaError { match status { 404 => OllamaError { kind: OllamaErrorKind::ModelNotFound, message: format!("Model not found\nRun: ollama pull {model_name} in a terminal."), }, - _ => OllamaError { - kind: OllamaErrorKind::Other, - message: format!("Something went wrong\nHTTP {status}"), - }, + _ => { + let detail = + extract_ollama_error_message(body).unwrap_or_else(|| format!("HTTP {status}")); + OllamaError { + kind: OllamaErrorKind::Other, + message: format!("Something went wrong\n{detail}"), + } + } } } @@ -226,7 +249,15 @@ pub async fn stream_ollama_chat( Ok(response) => { if !response.status().is_success() { let status = response.status().as_u16(); - on_chunk(StreamChunk::Error(classify_http_error(status, model))); + // Drain the body so the user sees Ollama's own reason + // (e.g. "this model only supports one image while more + // than one image requested") instead of a bare HTTP code. + // A failed read collapses to an empty string and the + // classifier falls back to the status code. + let body = response.text().await.unwrap_or_default(); + on_chunk(StreamChunk::Error(classify_http_error( + status, model, &body, + ))); return accumulated; } @@ -318,8 +349,14 @@ pub async fn ask_ollama( generation: State<'_, GenerationState>, history: State<'_, ConversationHistory>, config: State<'_, AppConfig>, + active_model: State<'_, crate::models::ActiveModelState>, ) -> Result<(), String> { let endpoint = format!("{}/api/chat", config.model.ollama_url.trim_end_matches('/')); + // Snapshot the active model slug; drop the guard before any `.await`. + let model_name = { + let guard = active_model.0.lock().map_err(|e| e.to_string())?; + guard.clone() + }; let cancel_token = CancellationToken::new(); generation.set_token(cancel_token.clone()); @@ -366,7 +403,7 @@ pub async fn ask_ollama( let accumulated = stream_ollama_chat( &endpoint, - config.model.active(), + &model_name, messages, think, &client, @@ -1153,28 +1190,78 @@ mod tests { #[test] fn classify_http_404_returns_model_not_found() { - let err = classify_http_error(404, "gemma4:e2b"); + let err = classify_http_error(404, "gemma4:e2b", ""); assert_eq!(err.kind, OllamaErrorKind::ModelNotFound); assert!(err.message.contains("gemma4:e2b")); } #[test] fn classify_http_404_includes_requested_model_name_in_hint() { - let err = classify_http_error(404, "custom:model"); + let err = classify_http_error(404, "custom:model", ""); assert_eq!(err.kind, OllamaErrorKind::ModelNotFound); assert!(err.message.contains("custom:model")); } #[test] - fn classify_http_500_returns_other_with_status() { - let err = classify_http_error(500, "gemma4:e2b"); + fn classify_http_500_with_empty_body_falls_back_to_status_code() { + let err = classify_http_error(500, "gemma4:e2b", ""); + assert_eq!(err.kind, OllamaErrorKind::Other); + assert!(err.message.contains("500")); + } + + #[test] + fn classify_http_500_surfaces_ollama_error_text_when_present() { + let body = + r#"{"error":"this model only supports one image while more than one image requested"}"#; + let err = classify_http_error(500, "llama3.2-vision:11b", body); + assert_eq!(err.kind, OllamaErrorKind::Other); + assert!(err + .message + .contains("only supports one image while more than one image requested")); + assert!(!err.message.contains("HTTP 500")); + } + + #[test] + fn classify_http_500_falls_back_to_status_when_body_is_not_json() { + let err = classify_http_error(500, "any", "oops"); + assert_eq!(err.kind, OllamaErrorKind::Other); + assert!(err.message.contains("500")); + } + + #[test] + fn classify_http_500_falls_back_to_status_when_error_field_is_missing() { + let err = classify_http_error(500, "any", r#"{"detail":"nope"}"#); + assert_eq!(err.kind, OllamaErrorKind::Other); + assert!(err.message.contains("500")); + } + + #[test] + fn classify_http_500_falls_back_to_status_when_error_field_is_blank() { + let err = classify_http_error(500, "any", r#"{"error":" "}"#); assert_eq!(err.kind, OllamaErrorKind::Other); assert!(err.message.contains("500")); } + #[test] + fn extract_ollama_error_message_handles_known_shapes() { + assert_eq!(extract_ollama_error_message(""), None); + assert_eq!(extract_ollama_error_message(" "), None); + assert_eq!(extract_ollama_error_message("not json"), None); + assert_eq!(extract_ollama_error_message(r#"{}"#), None); + assert_eq!( + extract_ollama_error_message(r#"{"error":""}"#), + None, + "blank error string should be treated as missing", + ); + assert_eq!( + extract_ollama_error_message(r#"{"error":"boom"}"#).as_deref(), + Some("boom"), + ); + } + #[test] fn classify_http_401_returns_other_with_status() { - let err = classify_http_error(401, "gemma4:e2b"); + let err = classify_http_error(401, "gemma4:e2b", ""); assert_eq!(err.kind, OllamaErrorKind::Other); assert!(err.message.contains("401")); } @@ -1327,6 +1414,46 @@ mod tests { ); } + #[tokio::test] + async fn http_500_surfaces_ollama_error_body_through_stream() { + let mut server = mockito::Server::new_async().await; + let body = + r#"{"error":"this model only supports one image while more than one image requested"}"#; + let mock = server + .mock("POST", "/api/chat") + .with_status(500) + .with_header("content-type", "application/json") + .with_body(body) + .create_async() + .await; + + let client = reqwest::Client::new(); + let token = CancellationToken::new(); + let (chunks, callback) = collect_chunks(); + + stream_ollama_chat( + &format!("{}/api/chat", server.url()), + "llama3.2-vision:11b", + vec![], + false, + &client, + token, + callback, + ) + .await; + + mock.assert_async().await; + let chunks = chunks.lock().unwrap(); + assert_eq!(chunks.len(), 1); + assert!(matches!( + &chunks[0], + StreamChunk::Error(e) + if e.kind == OllamaErrorKind::Other + && e.message.contains("only supports one image") + && !e.message.contains("HTTP 500") + )); + } + /// Helper: builds a `/api/chat` response line with both thinking and content fields. fn chat_line_with_thinking(thinking: &str, content: &str, done: bool) -> String { format!( diff --git a/src-tauri/src/config/defaults.rs b/src-tauri/src/config/defaults.rs index c8c02bb8..be3af79e 100644 --- a/src-tauri/src/config/defaults.rs +++ b/src-tauri/src/config/defaults.rs @@ -103,3 +103,32 @@ pub const BOUNDS_SEARXNG_MAX_RESULTS: (u32, u32) = (1, 20); /// ceiling: a timeout longer than that indicates a misconfiguration, not a /// slow service. pub const BOUNDS_TIMEOUT_S: (u64, u64) = (1, 300); + +// Ollama API baked-in limits: not exposed in config.toml because they bound +// attacker-controlled data (response bodies from the local Ollama daemon) and +// keep the UI responsive when the daemon is hung. Changing either timeout +// value would require re-tuning the UX; changing the byte caps would require +// re-evaluating the memory budget. + +/// Per-request timeout (in seconds) for the Ollama `/api/tags` GET. Guards +/// the IPC boundary: if the daemon accepts the TCP connection but never +/// responds, `get_model_picker_state` would otherwise block indefinitely and +/// wedge the UI. 5 seconds is generous for a localhost call. +pub const DEFAULT_OLLAMA_TAGS_REQUEST_TIMEOUT_SECS: u64 = 5; + +/// Per-request timeout (in seconds) for the Ollama `/api/show` POST. Same +/// rationale as `DEFAULT_OLLAMA_TAGS_REQUEST_TIMEOUT_SECS`: local-loopback +/// HTTP is normally instant, but capping prevents a wedged daemon from +/// blocking picker rendering. +pub const DEFAULT_OLLAMA_SHOW_REQUEST_TIMEOUT_SECS: u64 = 5; + +/// Maximum accepted body size for the Ollama `/api/tags` response. Guards +/// against a misbehaving or compromised localhost Ollama streaming an +/// unbounded response that would exhaust memory. 4 MiB comfortably fits +/// thousands of model entries. +pub const MAX_OLLAMA_TAGS_BODY_BYTES: usize = 4 * 1024 * 1024; + +/// Maximum accepted body size for the Ollama `/api/show` response. The full +/// Modelfile and parameters can be sizable, but 4 MiB is comfortably above +/// any real model and bounds attacker-controlled inputs. +pub const MAX_OLLAMA_SHOW_BODY_BYTES: usize = 4 * 1024 * 1024; diff --git a/src-tauri/src/config/loader.rs b/src-tauri/src/config/loader.rs index 954220a7..8149ead2 100644 --- a/src-tauri/src/config/loader.rs +++ b/src-tauri/src/config/loader.rs @@ -26,13 +26,12 @@ use super::defaults::{ BOUNDS_MAX_ITERATIONS, BOUNDS_OVERLAY_WIDTH, BOUNDS_QUOTE_MAX_CONTEXT_LENGTH, BOUNDS_QUOTE_MAX_DISPLAY_CHARS, BOUNDS_QUOTE_MAX_DISPLAY_LINES, BOUNDS_SEARXNG_MAX_RESULTS, BOUNDS_TIMEOUT_S, BOUNDS_TOP_K_URLS, DEFAULT_COLLAPSED_HEIGHT, DEFAULT_HIDE_COMMIT_DELAY_MS, - DEFAULT_JUDGE_TIMEOUT_S, DEFAULT_MAX_CHAT_HEIGHT, DEFAULT_MAX_ITERATIONS, DEFAULT_MODEL_NAME, - DEFAULT_OLLAMA_URL, DEFAULT_OVERLAY_WIDTH, DEFAULT_QUOTE_MAX_CONTEXT_LENGTH, - DEFAULT_QUOTE_MAX_DISPLAY_CHARS, DEFAULT_QUOTE_MAX_DISPLAY_LINES, - DEFAULT_READER_BATCH_TIMEOUT_S, DEFAULT_READER_PER_URL_TIMEOUT_S, DEFAULT_READER_URL, - DEFAULT_ROUTER_TIMEOUT_S, DEFAULT_SEARCH_TIMEOUT_S, DEFAULT_SEARXNG_MAX_RESULTS, - DEFAULT_SEARXNG_URL, DEFAULT_SYSTEM_PROMPT_BASE, DEFAULT_TOP_K_URLS, - SLASH_COMMAND_PROMPT_APPENDIX, + DEFAULT_JUDGE_TIMEOUT_S, DEFAULT_MAX_CHAT_HEIGHT, DEFAULT_MAX_ITERATIONS, DEFAULT_OLLAMA_URL, + DEFAULT_OVERLAY_WIDTH, DEFAULT_QUOTE_MAX_CONTEXT_LENGTH, DEFAULT_QUOTE_MAX_DISPLAY_CHARS, + DEFAULT_QUOTE_MAX_DISPLAY_LINES, DEFAULT_READER_BATCH_TIMEOUT_S, + DEFAULT_READER_PER_URL_TIMEOUT_S, DEFAULT_READER_URL, DEFAULT_ROUTER_TIMEOUT_S, + DEFAULT_SEARCH_TIMEOUT_S, DEFAULT_SEARXNG_MAX_RESULTS, DEFAULT_SEARXNG_URL, + DEFAULT_SYSTEM_PROMPT_BASE, DEFAULT_TOP_K_URLS, SLASH_COMMAND_PROMPT_APPENDIX, }; use super::error::ConfigError; use super::schema::AppConfig; @@ -110,19 +109,9 @@ fn rename_corrupt(path: &Path) { /// and composes the system prompt appendix into `prompt.resolved_system`. /// After this runs, every `AppConfig` field holds a usable value. pub(crate) fn resolve(config: &mut AppConfig) { - // Model section: empty available list or empty/whitespace entries -> default. - let cleaned: Vec = config - .model - .available - .iter() - .map(|m| m.trim().to_string()) - .filter(|m| !m.is_empty()) - .collect(); - config.model.available = if cleaned.is_empty() { - vec![DEFAULT_MODEL_NAME.to_string()] - } else { - cleaned - }; + // Model section: only the Ollama endpoint is configurable here. The + // active model is runtime UI state owned by SQLite app_config, see + // crate::models::ActiveModelState. if config.model.ollama_url.trim().is_empty() { config.model.ollama_url = DEFAULT_OLLAMA_URL.to_string(); } diff --git a/src-tauri/src/config/schema.rs b/src-tauri/src/config/schema.rs index 125374e9..2b8e2735 100644 --- a/src-tauri/src/config/schema.rs +++ b/src-tauri/src/config/schema.rs @@ -15,22 +15,26 @@ use serde::{Deserialize, Serialize}; use super::defaults::{ DEFAULT_COLLAPSED_HEIGHT, DEFAULT_HIDE_COMMIT_DELAY_MS, DEFAULT_JUDGE_TIMEOUT_S, - DEFAULT_MAX_CHAT_HEIGHT, DEFAULT_MAX_ITERATIONS, DEFAULT_MODEL_NAME, DEFAULT_OLLAMA_URL, - DEFAULT_OVERLAY_WIDTH, DEFAULT_QUOTE_MAX_CONTEXT_LENGTH, DEFAULT_QUOTE_MAX_DISPLAY_CHARS, + DEFAULT_MAX_CHAT_HEIGHT, DEFAULT_MAX_ITERATIONS, DEFAULT_OLLAMA_URL, DEFAULT_OVERLAY_WIDTH, + DEFAULT_QUOTE_MAX_CONTEXT_LENGTH, DEFAULT_QUOTE_MAX_DISPLAY_CHARS, DEFAULT_QUOTE_MAX_DISPLAY_LINES, DEFAULT_READER_BATCH_TIMEOUT_S, DEFAULT_READER_PER_URL_TIMEOUT_S, DEFAULT_READER_URL, DEFAULT_ROUTER_TIMEOUT_S, DEFAULT_SEARCH_TIMEOUT_S, DEFAULT_SEARXNG_MAX_RESULTS, DEFAULT_SEARXNG_URL, DEFAULT_TOP_K_URLS, }; -/// Model configuration. The first entry of `available` is the active model -/// used for all inference. Reorder the list (or use the future settings panel) -/// to switch models. Keeping a single list instead of separate `active` and -/// `available` fields eliminates the mismatch failure mode entirely. +/// Static, user-tunable model configuration. +/// +/// The active model selection is NOT stored here. Active-model state is +/// runtime UI state owned by [`crate::models::ActiveModelState`] and +/// persisted in the SQLite `app_config` table under +/// [`crate::models::ACTIVE_MODEL_KEY`]. Storing a model slug in TOML would +/// duplicate ground truth from Ollama's `/api/tags` and create a staleness +/// trap: the file would happily reference a model the user has since +/// removed. This section keeps only the truly static knob, the Ollama +/// endpoint URL. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] #[serde(default)] pub struct ModelSection { - /// Ollama models Thuki knows about. First entry is active. - pub available: Vec, /// HTTP base URL of the local Ollama instance. pub ollama_url: String, } @@ -38,24 +42,11 @@ pub struct ModelSection { impl Default for ModelSection { fn default() -> Self { Self { - available: vec![DEFAULT_MODEL_NAME.to_string()], ollama_url: DEFAULT_OLLAMA_URL.to_string(), } } } -impl ModelSection { - /// Returns the active model (first entry). Falls back to the compiled - /// default if the list is somehow empty at call time; the loader also - /// guarantees this never happens by calling `resolve` during load. - pub fn active(&self) -> &str { - self.available - .first() - .map(String::as_str) - .unwrap_or(DEFAULT_MODEL_NAME) - } -} - /// Prompt configuration. `system` holds only the user-editable base text. /// The slash-command appendix is composed at load time into `resolved_system` /// and is never written back to the file. `resolved_system` is computed, not diff --git a/src-tauri/src/config/tests.rs b/src-tauri/src/config/tests.rs index d738ef90..2dd67a9b 100644 --- a/src-tauri/src/config/tests.rs +++ b/src-tauri/src/config/tests.rs @@ -14,8 +14,8 @@ use std::path::PathBuf; use super::defaults::{ DEFAULT_COLLAPSED_HEIGHT, DEFAULT_HIDE_COMMIT_DELAY_MS, DEFAULT_JUDGE_TIMEOUT_S, - DEFAULT_MAX_CHAT_HEIGHT, DEFAULT_MAX_ITERATIONS, DEFAULT_MODEL_NAME, DEFAULT_OLLAMA_URL, - DEFAULT_OVERLAY_WIDTH, DEFAULT_QUOTE_MAX_CONTEXT_LENGTH, DEFAULT_QUOTE_MAX_DISPLAY_CHARS, + DEFAULT_MAX_CHAT_HEIGHT, DEFAULT_MAX_ITERATIONS, DEFAULT_OLLAMA_URL, DEFAULT_OVERLAY_WIDTH, + DEFAULT_QUOTE_MAX_CONTEXT_LENGTH, DEFAULT_QUOTE_MAX_DISPLAY_CHARS, DEFAULT_QUOTE_MAX_DISPLAY_LINES, DEFAULT_READER_BATCH_TIMEOUT_S, DEFAULT_READER_PER_URL_TIMEOUT_S, DEFAULT_READER_URL, DEFAULT_ROUTER_TIMEOUT_S, DEFAULT_SEARCH_TIMEOUT_S, DEFAULT_SEARXNG_MAX_RESULTS, DEFAULT_SEARXNG_URL, @@ -47,7 +47,6 @@ fn defaults_const_values_match_schema_defaults() { // Guard rail: a change to a default in defaults.rs must flow through to // AppConfig::default(). If this test fails, someone changed one but not both. let c = AppConfig::default(); - assert_eq!(c.model.available, vec![DEFAULT_MODEL_NAME.to_string()]); assert_eq!(c.model.ollama_url, DEFAULT_OLLAMA_URL); assert_eq!(c.prompt.system, ""); assert_eq!(c.prompt.resolved_system, ""); @@ -87,8 +86,7 @@ fn defaults_prompt_base_is_nonempty() { #[test] fn section_defaults_are_sensible() { let m = ModelSection::default(); - assert_eq!(m.available, vec![DEFAULT_MODEL_NAME.to_string()]); - assert_eq!(m.active(), DEFAULT_MODEL_NAME); + assert_eq!(m.ollama_url, DEFAULT_OLLAMA_URL); let p = PromptSection::default(); assert!(p.system.is_empty()); @@ -100,26 +98,6 @@ fn section_defaults_are_sensible() { assert_eq!(q.max_display_lines, DEFAULT_QUOTE_MAX_DISPLAY_LINES); } -#[test] -fn model_section_active_falls_back_when_list_empty() { - // Guard: loader should prevent this, but active() has a defensive fallback - // so the struct can't explode if a caller bypasses the loader. - let m = ModelSection { - available: vec![], - ollama_url: DEFAULT_OLLAMA_URL.to_string(), - }; - assert_eq!(m.active(), DEFAULT_MODEL_NAME); -} - -#[test] -fn model_section_active_returns_first() { - let m = ModelSection { - available: vec!["custom:model".to_string(), "other:model".to_string()], - ollama_url: DEFAULT_OLLAMA_URL.to_string(), - }; - assert_eq!(m.active(), "custom:model"); -} - #[test] fn app_config_serde_round_trip_matches_defaults() { let original = AppConfig::default(); @@ -138,11 +116,10 @@ fn app_config_partial_file_fills_missing_fields_with_defaults() { // Only declare one field; serde(default) fills the rest. let partial = r#" [model] - available = ["custom:only"] + ollama_url = "http://localhost:9999" "#; let parsed: AppConfig = toml::from_str(partial).expect("partial file parses"); - assert_eq!(parsed.model.available, vec!["custom:only".to_string()]); - assert_eq!(parsed.model.ollama_url, DEFAULT_OLLAMA_URL); + assert_eq!(parsed.model.ollama_url, "http://localhost:9999"); assert_eq!(parsed.window.overlay_width, DEFAULT_OVERLAY_WIDTH); assert_eq!( parsed.quote.max_display_lines, @@ -187,7 +164,7 @@ fn load_missing_file_seeds_defaults_and_returns_them() { let config = load_from_path(&path).expect("seed on first run"); assert!(path.exists(), "file should be seeded"); - assert_eq!(config.model.active(), DEFAULT_MODEL_NAME); + assert_eq!(config.model.ollama_url, DEFAULT_OLLAMA_URL); // Resolved system prompt composed from default base plus appendix. assert!(config .prompt @@ -206,7 +183,7 @@ fn load_missing_file_in_missing_parent_dir_creates_dir() { let path = config_path_in(&nested); let config = load_from_path(&path).expect("creates parent dir and seeds"); assert!(path.exists()); - assert_eq!(config.model.active(), DEFAULT_MODEL_NAME); + assert_eq!(config.model.ollama_url, DEFAULT_OLLAMA_URL); } #[test] @@ -233,18 +210,12 @@ fn load_existing_valid_file_returns_resolved_config() { &path, r#" [model] - available = ["custom:a", "custom:b"] ollama_url = "http://localhost:99999" "#, ) .unwrap(); let config = load_from_path(&path).unwrap(); - assert_eq!( - config.model.available, - vec!["custom:a".to_string(), "custom:b".to_string()] - ); - assert_eq!(config.model.active(), "custom:a"); assert_eq!(config.model.ollama_url, "http://localhost:99999"); } @@ -277,7 +248,7 @@ fn load_corrupt_file_is_renamed_and_reseeded() { std::fs::write(&path, "this is = definitely not [ valid toml").unwrap(); let config = load_from_path(&path).expect("recover from corrupt file"); - assert_eq!(config.model.active(), DEFAULT_MODEL_NAME); + assert_eq!(config.model.ollama_url, DEFAULT_OLLAMA_URL); // Original file renamed with .corrupt- prefix. let renamed_exists = std::fs::read_dir(&dir) @@ -315,7 +286,7 @@ fn load_unreadable_file_returns_in_memory_defaults() { } let config = load_from_path(&path).expect("fallback to in-memory defaults"); - assert_eq!(config.model.active(), DEFAULT_MODEL_NAME); + assert_eq!(config.model.ollama_url, DEFAULT_OLLAMA_URL); // Restore so cleanup works. let _ = std::fs::set_permissions(&path, std::fs::Permissions::from_mode(0o644)); } @@ -323,55 +294,23 @@ fn load_unreadable_file_returns_in_memory_defaults() { // ── loader: resolve (empties and bounds) ──────────────────────────────────── #[test] -fn resolve_empty_available_list_falls_back_to_default_model() { - let dir = fresh_temp_dir(); - let path = config_path_in(&dir); - std::fs::write( - &path, - r#" - [model] - available = [] - "#, - ) - .unwrap(); - let config = load_from_path(&path).unwrap(); - assert_eq!(config.model.available, vec![DEFAULT_MODEL_NAME.to_string()]); - assert_eq!(config.model.active(), DEFAULT_MODEL_NAME); -} - -#[test] -fn resolve_whitespace_only_entries_are_filtered() { - let dir = fresh_temp_dir(); - let path = config_path_in(&dir); - std::fs::write( - &path, - r#" - [model] - available = [" ", "custom:x", " ", "custom:y"] - "#, - ) - .unwrap(); - let config = load_from_path(&path).unwrap(); - assert_eq!( - config.model.available, - vec!["custom:x".to_string(), "custom:y".to_string()] - ); -} - -#[test] -fn resolve_entry_whitespace_is_trimmed() { +fn resolve_unknown_model_field_is_ignored() { + // Older config files seeded a `[model] available = [...]` list. After + // removing that field from the schema, serde must silently drop it + // rather than refusing to parse the file. let dir = fresh_temp_dir(); let path = config_path_in(&dir); std::fs::write( &path, r#" [model] - available = [" spaced:model "] + available = ["legacy:model", "another:model"] + ollama_url = "http://localhost:11434" "#, ) .unwrap(); let config = load_from_path(&path).unwrap(); - assert_eq!(config.model.available, vec!["spaced:model".to_string()]); + assert_eq!(config.model.ollama_url, "http://localhost:11434"); } #[test] diff --git a/src-tauri/src/database.rs b/src-tauri/src/database.rs index d875891f..c6f1065f 100644 --- a/src-tauri/src/database.rs +++ b/src-tauri/src/database.rs @@ -14,7 +14,7 @@ use serde::Serialize; /// Tuple representing a message for batch insertion: /// (role, content, quoted_text, image_paths, thinking_content, search_sources, -/// search_warnings, search_metadata). +/// search_warnings, search_metadata, model_name). pub type MessageBatchRow = ( String, String, @@ -24,6 +24,7 @@ pub type MessageBatchRow = ( Option, Option, Option, + Option, ); /// Summary of a conversation for the history dropdown list. @@ -54,6 +55,9 @@ pub struct PersistedMessage { /// JSON-serialized `SearchMetadata` (iteration traces, timing) for this /// search turn. `None` for non-search messages and pre-Task-17 rows. pub search_metadata: Option, + /// Slug of the Ollama model that produced this assistant message. `None` + /// for user messages and rows written before the model_name migration. + pub model_name: Option, pub created_at: i64, } @@ -128,10 +132,48 @@ fn migrate_legacy_db(new_path: &std::path::Path) { } } +/// Returns true if `s` is a safe SQL identifier: non-empty, starts with an +/// ASCII letter or underscore, and contains only ASCII alphanumerics and +/// underscores thereafter. This subset covers every identifier Thuki uses and +/// excludes metacharacters that could turn a DDL statement into an injection. +fn is_safe_sql_ident(s: &str) -> bool { + !s.is_empty() + && s.chars().enumerate().all(|(i, c)| { + if i == 0 { + c.is_ascii_alphabetic() || c == '_' + } else { + c.is_ascii_alphanumeric() || c == '_' + } + }) +} + /// Idempotently adds a column to a SQLite table. A no-op when the column /// already exists. SQLite does not support `ALTER TABLE ... ADD COLUMN IF NOT /// EXISTS`, so we inspect `PRAGMA table_info` first. +/// +/// `col_type` may contain spaces (e.g. `"TEXT NOT NULL"`); each +/// whitespace-separated token is validated individually as a safe SQL +/// identifier. `table` and `column` must each be a single safe identifier. +/// Returns `Err` if any argument fails the allowlist check. fn ensure_column(conn: &Connection, table: &str, column: &str, col_type: &str) -> SqlResult<()> { + if !is_safe_sql_ident(table) { + return Err(rusqlite::Error::InvalidParameterName(format!( + "unsafe table name: {table:?}" + ))); + } + if !is_safe_sql_ident(column) { + return Err(rusqlite::Error::InvalidParameterName(format!( + "unsafe column name: {column:?}" + ))); + } + for token in col_type.split_whitespace() { + if !is_safe_sql_ident(token) { + return Err(rusqlite::Error::InvalidParameterName(format!( + "unsafe col_type token: {token:?}" + ))); + } + } + let exists: bool = conn .prepare(&format!("PRAGMA table_info({table})"))? .query_map([], |row| row.get::<_, String>(1))? @@ -176,6 +218,10 @@ fn run_migrations(conn: &Connection) -> SqlResult<()> { // JSON-encoded Vec and SearchMetadata (Task 17). ensure_column(conn, "messages", "search_warnings", "TEXT")?; ensure_column(conn, "messages", "search_metadata", "TEXT")?; + // Per-message model attribution (slug of the Ollama model that produced + // the assistant response). NULL for user messages and rows written before + // this migration. + ensure_column(conn, "messages", "model_name", "TEXT")?; Ok(()) } @@ -297,6 +343,7 @@ pub fn insert_message( search_sources: Option<&str>, search_warnings: Option<&str>, search_metadata: Option<&str>, + model_name: Option<&str>, ) -> SqlResult { let id = uuid::Uuid::new_v4().to_string(); let now = now_millis(); @@ -304,8 +351,8 @@ pub fn insert_message( "INSERT INTO messages \ (id, conversation_id, role, content, quoted_text, image_paths, \ thinking_content, search_sources, search_warnings, search_metadata, \ - created_at) \ - VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11)", + model_name, created_at) \ + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12)", params![ id, conversation_id, @@ -317,6 +364,7 @@ pub fn insert_message( search_sources, search_warnings, search_metadata, + model_name, now ], )?; @@ -341,8 +389,8 @@ pub fn insert_messages_batch( "INSERT INTO messages \ (id, conversation_id, role, content, quoted_text, image_paths, \ thinking_content, search_sources, search_warnings, search_metadata, \ - created_at) \ - VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11)", + model_name, created_at) \ + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12)", )?; for ( role, @@ -353,6 +401,7 @@ pub fn insert_messages_batch( search_sources, search_warnings, search_metadata, + model_name, ) in messages { let id = uuid::Uuid::new_v4().to_string(); @@ -367,6 +416,7 @@ pub fn insert_messages_batch( search_sources.as_deref(), search_warnings.as_deref(), search_metadata.as_deref(), + model_name.as_deref(), now ])?; } @@ -383,7 +433,7 @@ pub fn insert_messages_batch( pub fn load_messages(conn: &Connection, conversation_id: &str) -> SqlResult> { let mut stmt = conn.prepare( "SELECT id, role, content, quoted_text, image_paths, thinking_content, \ - search_sources, search_warnings, search_metadata, created_at + search_sources, search_warnings, search_metadata, model_name, created_at FROM messages WHERE conversation_id = ?1 ORDER BY created_at ASC", @@ -399,7 +449,8 @@ pub fn load_messages(conn: &Connection, conversation_id: &str) -> SqlResult = conn + .prepare("PRAGMA table_info(messages)") + .unwrap() + .query_map([], |row| row.get::<_, String>(1)) + .unwrap() + .filter_map(|r| r.ok()) + .collect(); + assert!(cols.contains(&"collated_col".to_string())); + } + + #[test] + fn ensure_column_rejects_empty_table_name() { + let conn = open_in_memory().unwrap(); + let result = ensure_column(&conn, "", "col", "TEXT"); + assert!(result.is_err(), "expected error for empty table name"); + } + + #[test] + fn ensure_column_rejects_empty_column_name() { + let conn = open_in_memory().unwrap(); + let result = ensure_column(&conn, "messages", "", "TEXT"); + assert!(result.is_err(), "expected error for empty column name"); + } + // ── search_warnings / search_metadata round-trip ───────────────────────── #[test] @@ -1095,6 +1257,7 @@ mod tests { None, Some(warnings_json), Some(metadata_json), + None, ) .unwrap(); @@ -1111,7 +1274,7 @@ mod tests { // No warnings or metadata (ordinary non-search message). insert_message( - &conn, &conv_id, "user", "hello", None, None, None, None, None, None, + &conn, &conv_id, "user", "hello", None, None, None, None, None, None, None, ) .unwrap(); @@ -1120,4 +1283,121 @@ mod tests { assert!(msgs[0].search_warnings.is_none()); assert!(msgs[0].search_metadata.is_none()); } + + // ── model_name column + round-trip ─────────────────────────────────────── + + #[test] + fn model_name_column_exists_after_migration() { + let conn = open_in_memory().unwrap(); + let cols: Vec = conn + .prepare("PRAGMA table_info(messages)") + .unwrap() + .query_map([], |row| row.get::<_, String>(1)) + .unwrap() + .filter_map(|r| r.ok()) + .collect(); + assert!(cols.contains(&"model_name".to_string())); + } + + #[test] + fn insert_message_with_model_name_round_trips() { + let conn = open_in_memory().unwrap(); + let id = create_conversation(&conn, None, "gemma4:e2b").unwrap(); + + insert_message( + &conn, + &id, + "assistant", + "Hello from gemma.", + None, + None, + None, + None, + None, + None, + Some("gemma4:e2b"), + ) + .unwrap(); + + let msgs = load_messages(&conn, &id).unwrap(); + assert_eq!(msgs.len(), 1); + assert_eq!(msgs[0].model_name.as_deref(), Some("gemma4:e2b")); + } + + #[test] + fn insert_message_with_null_model_name() { + let conn = open_in_memory().unwrap(); + let id = create_conversation(&conn, None, "gemma4:e2b").unwrap(); + + insert_message( + &conn, &id, "user", "hi there", None, None, None, None, None, None, None, + ) + .unwrap(); + + let msgs = load_messages(&conn, &id).unwrap(); + assert_eq!(msgs.len(), 1); + assert!(msgs[0].model_name.is_none()); + } + + #[test] + fn insert_messages_batch_includes_model_name() { + let conn = open_in_memory().unwrap(); + let id = create_conversation(&conn, None, "gemma4:e2b").unwrap(); + + let batch = vec![ + ( + "assistant".to_string(), + "answer from gemma".to_string(), + None, + None, + None, + None, + None, + None, + Some("gemma4:e2b".to_string()), + ), + ( + "assistant".to_string(), + "answer from qwen".to_string(), + None, + None, + None, + None, + None, + None, + Some("qwen2.5:7b".to_string()), + ), + ]; + insert_messages_batch(&conn, &id, &batch).unwrap(); + + let msgs = load_messages(&conn, &id).unwrap(); + assert_eq!(msgs.len(), 2); + assert_eq!(msgs[0].model_name.as_deref(), Some("gemma4:e2b")); + assert_eq!(msgs[1].model_name.as_deref(), Some("qwen2.5:7b")); + } + + #[test] + fn load_messages_handles_null_model_name_for_legacy_rows() { + let conn = open_in_memory().unwrap(); + let id = create_conversation(&conn, None, "gemma4:e2b").unwrap(); + + // Simulate a row written before the model_name migration by inserting + // with an explicit column list that omits model_name entirely. + conn.execute( + "INSERT INTO messages (id, conversation_id, role, content, created_at) \ + VALUES (?1, ?2, ?3, ?4, ?5)", + params![ + uuid::Uuid::new_v4().to_string(), + &id, + "assistant", + "legacy row", + now_millis(), + ], + ) + .unwrap(); + + let msgs = load_messages(&conn, &id).unwrap(); + assert_eq!(msgs.len(), 1); + assert!(msgs[0].model_name.is_none()); + } } diff --git a/src-tauri/src/history.rs b/src-tauri/src/history.rs index df16e845..864a109a 100644 --- a/src-tauri/src/history.rs +++ b/src-tauri/src/history.rs @@ -16,6 +16,7 @@ use tauri::State; use crate::commands::{ChatMessage, ConversationHistory}; use crate::config::AppConfig; use crate::database; +use crate::models::ActiveModelState; /// Thread-safe wrapper around the SQLite connection. pub struct Database(pub Mutex); @@ -47,6 +48,10 @@ pub struct SaveMessagePayload { /// Already-serialised `SearchMetadata` JSON string for search turns. /// Passed through verbatim to `messages.search_metadata`. pub search_metadata: Option, + /// Slug of the Ollama model that produced this response. Frontend stamps + /// assistant payloads with the active model at generation time; `None` + /// for user payloads. Accepted as missing via serde Option default. + pub model_name: Option, } /// Response returned when saving a conversation. @@ -63,9 +68,13 @@ pub struct SaveConversationResponse { pub fn save_conversation( messages: Vec, db: State<'_, Database>, - app_config: State<'_, AppConfig>, + active_model: State<'_, ActiveModelState>, ) -> Result { let conn = db.0.lock().map_err(|e| e.to_string())?; + let model_slug = { + let guard = active_model.0.lock().map_err(|e| e.to_string())?; + guard.clone() + }; // Use the first user message (truncated) as the initial title placeholder. let placeholder_title = messages.iter().find(|m| m.role == "user").map(|m| { @@ -85,12 +94,9 @@ pub fn save_conversation( } }); - let conversation_id = database::create_conversation( - &conn, - placeholder_title.as_deref(), - app_config.model.active(), - ) - .map_err(|e| e.to_string())?; + let conversation_id = + database::create_conversation(&conn, placeholder_title.as_deref(), &model_slug) + .map_err(|e| e.to_string())?; let batch: Vec = messages .into_iter() @@ -111,6 +117,7 @@ pub fn save_conversation( sources_json, m.search_warnings, m.search_metadata, + m.model_name, ) }) .collect(); @@ -134,6 +141,7 @@ pub fn persist_message( search_sources: Option>, search_warnings: Option, search_metadata: Option, + model_name: Option, db: State<'_, Database>, ) -> Result<(), String> { let conn = db.0.lock().map_err(|e| e.to_string())?; @@ -154,6 +162,7 @@ pub fn persist_message( sources_json.as_deref(), search_warnings.as_deref(), search_metadata.as_deref(), + model_name.as_deref(), ) .map_err(|e| e.to_string())?; Ok(()) @@ -244,6 +253,7 @@ pub fn delete_conversation( pub async fn generate_title( conversation_id: String, messages: Vec, + model: String, db: State<'_, Database>, client: State<'_, reqwest::Client>, app_config: State<'_, AppConfig>, @@ -289,7 +299,7 @@ pub async fn generate_title( let cancel_token = tokio_util::sync::CancellationToken::new(); let accumulated = crate::commands::stream_ollama_chat( &endpoint, - app_config.model.active(), + &model, title_messages, false, &client, @@ -342,6 +352,7 @@ mod tests { search_sources: None, search_warnings: None, search_metadata: None, + model_name: None, }, SaveMessagePayload { role: "assistant".to_string(), @@ -363,6 +374,7 @@ mod tests { search_metadata: Some( r#"{"iterations":[],"total_duration_ms":10,"retries_performed":0}"#.to_string(), ), + model_name: Some("gemma4:e2b".to_string()), }, ]; @@ -398,6 +410,7 @@ mod tests { sources_json, m.search_warnings, m.search_metadata, + m.model_name, ) }) .collect(); @@ -435,6 +448,8 @@ mod tests { .contains("total_duration_ms")); assert!(loaded[0].search_warnings.is_none()); assert!(loaded[0].search_metadata.is_none()); + assert!(loaded[0].model_name.is_none()); + assert_eq!(loaded[1].model_name.as_deref(), Some("gemma4:e2b")); } #[test] diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index 3a2a035e..ba55310a 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -20,6 +20,7 @@ pub mod config; pub mod database; pub mod history; pub mod images; +pub mod models; pub mod onboarding; pub mod screenshot; pub mod search; @@ -390,35 +391,46 @@ fn notify_frontend_ready(app_handle: tauri::AppHandle, db: tauri::State, + app_handle: tauri::AppHandle, +) -> Result<(), String> { + let conn = db.0.lock().map_err(|e| format!("db lock poisoned: {e}"))?; + onboarding::set_stage(&conn, &onboarding::OnboardingStage::Intro) + .map_err(|e| format!("db write failed: {e}"))?; + drop(conn); + + let _ = app_handle.emit( + ONBOARDING_EVENT, + OnboardingPayload { + stage: onboarding::OnboardingStage::Intro, + }, + ); + Ok(()) +} + // ─── Onboarding completion ─────────────────────────────────────────────────── /// Called when the user clicks "Get Started" on the intro screen. @@ -733,6 +776,27 @@ pub fn run() { .expect("failed to resolve app data directory"); let db_conn = database::open_database(&app_data_dir) .expect("failed to initialise SQLite database"); + + // ── Active-model state: seed from SQLite app_config table ── + // The installed list isn't queried here (no async runtime yet). + // get_model_picker_state reconciles against the live /api/tags + // inventory on first picker open and may replace this seed. + // The placeholder DEFAULT_MODEL_NAME bootstrap is a transient + // value used only until that first reconciliation, and is the + // last-resort fallback when both the persisted slug and the + // live installed list are absent. Phase 3 will gate the + // overlay on a real installed model so that placeholder is + // never streamed to Ollama. + let persisted_active = database::get_config(&db_conn, models::ACTIVE_MODEL_KEY) + .expect("failed to read active_model from app_config"); + let initial_active_model = models::resolve_seed_active_model( + persisted_active.as_deref(), + crate::config::defaults::DEFAULT_MODEL_NAME, + ); + app.manage(models::ActiveModelState(std::sync::Mutex::new( + initial_active_model, + ))); + app.manage(models::ModelCapabilitiesCache::default()); app.manage(history::Database(std::sync::Mutex::new(db_conn))); // ── Orphaned image cleanup (startup + periodic) ───────── @@ -755,6 +819,14 @@ pub fn run() { #[cfg(not(coverage))] commands::get_config, #[cfg(not(coverage))] + models::get_model_picker_state, + #[cfg(not(coverage))] + models::set_active_model, + #[cfg(not(coverage))] + models::check_model_setup, + #[cfg(not(coverage))] + models::get_model_capabilities, + #[cfg(not(coverage))] history::save_conversation, #[cfg(not(coverage))] history::persist_message, @@ -793,7 +865,8 @@ pub fn run() { permissions::check_screen_recording_tcc_granted, #[cfg(not(coverage))] permissions::quit_and_relaunch, - finish_onboarding + finish_onboarding, + advance_past_model_check ]) .build(tauri::generate_context!()) .expect("error while building tauri application") diff --git a/src-tauri/src/models.rs b/src-tauri/src/models.rs new file mode 100644 index 00000000..e02c62be --- /dev/null +++ b/src-tauri/src/models.rs @@ -0,0 +1,1895 @@ +/*! + * Active-model state module. + * + * Single source of truth for the locally-selected Ollama model. The "active" + * model is whichever slug the user last picked via the picker popup, + * persisted across launches in `app_config` under [`ACTIVE_MODEL_KEY`] and + * mirrored in [`ActiveModelState`] for fast reads from Tauri commands. + * + * The backend treats Ollama's `/api/tags` response as authoritative: a + * persisted model is only honored if it still appears in the live installed + * list. If not, we fall back to the first installed model, then to the + * bootstrap default from `THUKI_SUPPORTED_AI_MODELS`. + */ + +use std::collections::HashMap; +use std::sync::Mutex; + +use futures_util::StreamExt; +use serde::{Deserialize, Serialize}; + +use crate::config::defaults::{ + DEFAULT_OLLAMA_SHOW_REQUEST_TIMEOUT_SECS, DEFAULT_OLLAMA_TAGS_REQUEST_TIMEOUT_SECS, + MAX_OLLAMA_SHOW_BODY_BYTES, MAX_OLLAMA_TAGS_BODY_BYTES, +}; +use crate::config::AppConfig; +use crate::database::{get_config, set_config}; +use crate::history::Database; + +/// `app_config` key used to persist the user's selected model slug. +pub const ACTIVE_MODEL_KEY: &str = "active_model"; + +/// Maximum accepted byte length for a model slug passed to `set_active_model`. +/// Real Ollama slugs are a handful of characters; 256 is generous while still +/// capping adversarial inputs long before any network or database work. +pub const MAX_MODEL_SLUG_LEN: usize = 256; + +/// Shared error-message prefix used when a requested slug is not present in +/// the live Ollama inventory. Exported so the frontend and tests can match +/// against a stable constant instead of a prose string. +pub const MODEL_NOT_INSTALLED_ERR_PREFIX: &str = "Model is not installed in Ollama: "; + +/// In-memory cache of the currently active model slug. Written once at +/// startup (after `resolve_seed_active_model`) and updated every time the +/// user picks a new model via `set_active_model`. +#[derive(Default)] +pub struct ActiveModelState(pub Mutex); + +/// Top-level shape of the Ollama `/api/tags` response. Only the `models` +/// array is consumed; all other fields are ignored. +#[derive(Deserialize)] +struct TagsResponse { + models: Vec, +} + +/// A single entry in the `/api/tags` `models` array. Only the `name` slug +/// is needed; everything else (size, digest, modified_at, details) is +/// deliberately ignored to keep the schema surface small. +#[derive(Deserialize)] +struct TagsModel { + name: String, +} + +/// Chooses which model slug should be active given a persisted preference, +/// the live installed list from Ollama, and an env-derived bootstrap value. +/// +/// Resolution rules, in order: +/// 1. If `persisted` is `Some` and still appears in `installed`, use it. +/// 2. Otherwise use the first entry in `installed`. +/// 3. Otherwise fall back to `bootstrap` (the compiled-in / env default). +/// +/// This helper assumes `installed` reflects real Ollama ground truth. At +/// startup when no ground truth is available, use +/// [`resolve_seed_active_model`] instead so a valid persisted choice is +/// never overridden by the bootstrap default just because Ollama has not +/// been queried yet. +pub fn resolve_active_model( + persisted: Option<&str>, + installed: &[String], + bootstrap: &str, +) -> String { + if let Some(p) = persisted { + if installed.iter().any(|m| m == p) { + return p.to_string(); + } + } + if let Some(first) = installed.first() { + return first.clone(); + } + bootstrap.to_string() +} + +/// Startup-time resolver that never cross-checks against an installed list. +/// +/// At process start we cannot call Ollama (no async runtime yet), so the +/// safe behavior is to trust the persisted value when present and only fall +/// back to the bootstrap default when nothing was ever persisted. The first +/// `get_model_picker_state` call from the frontend reconciles against the +/// real installed list and may replace this seed. +pub fn resolve_seed_active_model(persisted: Option<&str>, bootstrap: &str) -> String { + match persisted { + Some(slug) if !slug.is_empty() => slug.to_string(), + _ => bootstrap.to_string(), + } +} + +/// Returns true when the resolved slug should be written back to persistent +/// storage. Only writes when Ollama actually reported some inventory AND the +/// resolved slug differs from the currently-persisted value. This prevents a +/// partially-up Ollama returning `models:[]` from clobbering a valid +/// persisted user preference with the bootstrap fallback. +pub fn should_persist_resolved( + installed: &[String], + persisted: Option<&str>, + resolved: &str, +) -> bool { + !installed.is_empty() && persisted != Some(resolved) +} + +/// Verifies that `model` is present in `installed`. Returns an `Err` with +/// a stable prefix (see [`MODEL_NOT_INSTALLED_ERR_PREFIX`]) so the frontend +/// can match against a constant rather than a verbatim prose string. +pub fn validate_model_installed(model: &str, installed: &[String]) -> Result<(), String> { + if installed.iter().any(|m| m == model) { + Ok(()) + } else { + Err(format!("{MODEL_NOT_INSTALLED_ERR_PREFIX}{model}")) + } +} + +/// Validates shape of a model slug coming across the IPC boundary before any +/// network work. Rejects empty, over-length, and out-of-charset inputs. +/// Accepted charset covers everything real Ollama slugs use: +/// `A-Z a-z 0-9 : . _ / -`. +pub fn validate_model_slug(model: &str) -> Result<(), String> { + if model.is_empty() { + return Err("Model name cannot be empty".to_string()); + } + if model.len() > MAX_MODEL_SLUG_LEN { + return Err(format!( + "Model name exceeds maximum length of {MAX_MODEL_SLUG_LEN} bytes" + )); + } + if !model + .chars() + .all(|c| c.is_ascii_alphanumeric() || matches!(c, ':' | '.' | '_' | '/' | '-')) + { + return Err("Model name contains invalid characters".to_string()); + } + Ok(()) +} + +/// GETs `{base_url}/api/tags` and returns the list of installed model slugs. +/// +/// Every failure mode (transport error, non-2xx status, oversized body, +/// JSON decode error) is translated to `Err(String)` so the Tauri command +/// layer can propagate it verbatim to the frontend without panicking. +pub async fn fetch_installed_model_names( + client: &reqwest::Client, + base_url: &str, +) -> Result, String> { + fetch_installed_model_names_with_timeout( + client, + base_url, + std::time::Duration::from_secs(DEFAULT_OLLAMA_TAGS_REQUEST_TIMEOUT_SECS), + ) + .await +} + +/// Internal variant of [`fetch_installed_model_names`] with a configurable +/// per-request timeout. Exists so tests can exercise the timeout branch +/// deterministically without waiting the production 5s. +async fn fetch_installed_model_names_with_timeout( + client: &reqwest::Client, + base_url: &str, + timeout: std::time::Duration, +) -> Result, String> { + fetch_installed_model_names_inner(client, base_url, timeout, MAX_OLLAMA_TAGS_BODY_BYTES).await +} + +/// Innermost implementation of the tags fetcher with both timeout and body +/// size cap configurable. Exists so the size-cap branches can be exercised +/// deterministically in tests without allocating production-scale buffers. +/// +/// The cap is enforced incrementally during the streaming read: each chunk +/// is checked before being appended, so the connection is aborted the moment +/// the running total would exceed `max_body_bytes` rather than after the full +/// body has been buffered. +async fn fetch_installed_model_names_inner( + client: &reqwest::Client, + base_url: &str, + timeout: std::time::Duration, + max_body_bytes: usize, +) -> Result, String> { + let url = format!("{}/api/tags", base_url.trim_end_matches('/')); + let response = client + .get(&url) + .timeout(timeout) + .send() + .await + .map_err(|e| format!("failed to reach Ollama: {e}"))?; + + if !response.status().is_success() { + return Err(format!( + "Ollama /api/tags returned HTTP {}", + response.status().as_u16() + )); + } + + if let Some(declared_len) = response.content_length() { + if declared_len as usize > max_body_bytes { + return Err(format!( + "/api/tags response exceeded {max_body_bytes} bytes" + )); + } + } + + let mut stream = response.bytes_stream(); + let mut buf: Vec = Vec::new(); + while let Some(chunk) = stream.next().await { + let chunk = chunk.map_err(|e| format!("failed to read /api/tags body: {e}"))?; + if buf.len() + chunk.len() > max_body_bytes { + return Err(format!( + "/api/tags response exceeded {max_body_bytes} bytes" + )); + } + buf.extend_from_slice(&chunk); + } + + let body: TagsResponse = serde_json::from_slice(&buf) + .map_err(|e| format!("failed to decode /api/tags response: {e}"))?; + + Ok(body.models.into_iter().map(|m| m.name).collect()) +} + +/// Returns the currently active model and the full list of installed models, +/// persisting the resolved active model so future launches see it. +/// +/// Shape: `{ "active": "", "all": ["", ...] }`. +/// +/// Coalesces the read + conditional write into a single database critical +/// section to avoid a TOCTOU window where a concurrent `set_active_model` +/// could be clobbered, and refuses to persist when Ollama reports an empty +/// inventory so a partially-up daemon cannot corrupt the persisted choice. +#[cfg_attr(coverage_nightly, coverage(off))] +#[cfg_attr(not(coverage), tauri::command)] +pub async fn get_model_picker_state( + client: tauri::State<'_, reqwest::Client>, + db: tauri::State<'_, Database>, + active_model: tauri::State<'_, ActiveModelState>, + config: tauri::State<'_, AppConfig>, +) -> Result { + let installed = fetch_installed_model_names(&client, &config.model.ollama_url).await?; + + let resolved = { + let conn = db.0.lock().map_err(|e| e.to_string())?; + let persisted = get_config(&conn, ACTIVE_MODEL_KEY).map_err(|e| e.to_string())?; + let resolved = resolve_active_model( + persisted.as_deref(), + &installed, + crate::config::defaults::DEFAULT_MODEL_NAME, + ); + if should_persist_resolved(&installed, persisted.as_deref(), &resolved) { + set_config(&conn, ACTIVE_MODEL_KEY, &resolved).map_err(|e| e.to_string())?; + } + resolved + }; + + { + let mut guard = active_model.0.lock().map_err(|e| e.to_string())?; + *guard = resolved.clone(); + } + + Ok(serde_json::json!({ "active": resolved, "all": installed })) +} + +/// Persists `model` as the active model after validating its shape and +/// confirming Ollama still reports it as installed. Rejects uninstalled +/// slugs with an error that starts with [`MODEL_NOT_INSTALLED_ERR_PREFIX`]. +#[cfg_attr(coverage_nightly, coverage(off))] +#[cfg_attr(not(coverage), tauri::command)] +pub async fn set_active_model( + model: String, + client: tauri::State<'_, reqwest::Client>, + db: tauri::State<'_, Database>, + active_model: tauri::State<'_, ActiveModelState>, + config: tauri::State<'_, AppConfig>, +) -> Result<(), String> { + validate_model_slug(&model)?; + + let installed = fetch_installed_model_names(&client, &config.model.ollama_url).await?; + validate_model_installed(&model, &installed)?; + + { + let conn = db.0.lock().map_err(|e| e.to_string())?; + set_config(&conn, ACTIVE_MODEL_KEY, &model).map_err(|e| e.to_string())?; + } + + { + let mut guard = active_model.0.lock().map_err(|e| e.to_string())?; + *guard = model; + } + + Ok(()) +} + +// ─── Model setup gate (Phase 3 onboarding) ────────────────────────────────── + +/// Result of probing the local Ollama daemon for setup readiness. +/// +/// Drives the Phase 3 onboarding gate that fires after the user grants +/// macOS permissions but before the chat overlay is allowed to open. +/// Variants are emitted to the frontend in `snake_case` with an +/// internally-tagged `state` discriminator so the React side can route +/// on a single string field without inspecting payload shape. +#[derive(Debug, Clone, PartialEq, serde::Serialize)] +#[serde(tag = "state", rename_all = "snake_case")] +pub enum ModelSetupState { + /// `/api/tags` could not be reached. Treat as "Ollama is not installed + /// or not running"; the UI must guide the user to install or start it. + OllamaUnreachable, + /// `/api/tags` responded successfully but the installed list is empty. + /// The UI must guide the user to `ollama pull `. + NoModelsInstalled, + /// Ollama is running with at least one installed model. `active_slug` + /// is the slug we resolved (persisted preference if still installed, + /// else first installed) and `installed` is the live list for the + /// frontend to render in the picker. + Ready { + active_slug: String, + installed: Vec, + }, +} + +/// Pure state-machine derivation: maps the result of probing `/api/tags` +/// plus the persisted active-slug preference into a [`ModelSetupState`]. +/// +/// Exists as a free function so the three branches can be unit-tested +/// without spinning up an HTTP server or a Tauri runtime. The fetch +/// result and persisted preference are the only inputs; no I/O happens +/// here. The Tauri command is a thin wrapper that calls the fetcher, +/// reads the persisted slug from SQLite, then delegates here. +/// +/// Resolution rules for the Ready arm match +/// [`resolve_active_model`]: prefer the persisted slug when it is still +/// installed; otherwise fall back to the first installed slug. The +/// `bootstrap` argument is the compile-time fallback used only when +/// both inputs are absent, which by definition cannot happen on the +/// Ready arm (it would have routed to NoModelsInstalled). +pub fn derive_model_setup_state( + installed_result: Result, String>, + persisted: Option<&str>, + bootstrap: &str, +) -> ModelSetupState { + match installed_result { + Err(_) => ModelSetupState::OllamaUnreachable, + Ok(installed) if installed.is_empty() => ModelSetupState::NoModelsInstalled, + Ok(installed) => { + let active_slug = resolve_active_model(persisted, &installed, bootstrap); + ModelSetupState::Ready { + active_slug, + installed, + } + } + } +} + +/// Probes Ollama for setup readiness and returns the typed +/// [`ModelSetupState`] for the frontend onboarding gate. +/// +/// Idempotent: safe to call on every overlay open. The Ready arm also +/// commits two side effects, both intentionally bounded: +/// +/// 1. If the resolved slug differs from the persisted slug AND the live +/// installed list is non-empty, persist the resolved slug. This heals +/// the case where a user removed their previously-selected model with +/// `ollama rm` between launches. +/// 2. Mirror the resolved slug into the in-memory [`ActiveModelState`] so +/// `ask_ollama` and `search_pipeline` see it on the next request +/// without an extra DB read. +/// +/// Both writes are gated through [`should_persist_resolved`] which +/// refuses to persist when Ollama reports an empty inventory (i.e. +/// daemon is up but mid-restart), so a transient empty response cannot +/// clobber a valid persisted choice. +#[cfg_attr(coverage_nightly, coverage(off))] +#[cfg_attr(not(coverage), tauri::command)] +pub async fn check_model_setup( + client: tauri::State<'_, reqwest::Client>, + db: tauri::State<'_, Database>, + active_model: tauri::State<'_, ActiveModelState>, + config: tauri::State<'_, AppConfig>, +) -> Result { + let installed_result = fetch_installed_model_names(&client, &config.model.ollama_url).await; + + let persisted = { + let conn = db.0.lock().map_err(|e| e.to_string())?; + get_config(&conn, ACTIVE_MODEL_KEY).map_err(|e| e.to_string())? + }; + + let state = derive_model_setup_state( + installed_result, + persisted.as_deref(), + crate::config::defaults::DEFAULT_MODEL_NAME, + ); + + if let ModelSetupState::Ready { + ref active_slug, + ref installed, + } = state + { + if should_persist_resolved(installed, persisted.as_deref(), active_slug) { + let conn = db.0.lock().map_err(|e| e.to_string())?; + set_config(&conn, ACTIVE_MODEL_KEY, active_slug).map_err(|e| e.to_string())?; + } + let mut guard = active_model.0.lock().map_err(|e| e.to_string())?; + *guard = active_slug.clone(); + } + + Ok(state) +} + +// ─── Model capabilities (vision, thinking) ────────────────────────────────── + +/// Per-model capability flags surfaced to the frontend so the picker can +/// label rows and the submit-time gate can refuse mismatched messages +/// (image attached + text-only model). Booleans are derived from Ollama's +/// `/api/show` `capabilities` array; unknown strings are ignored so future +/// Ollama additions cannot break the schema. +/// +/// Thuki surfaces exactly two capability flags. `completion` is implicit +/// (every chat model supports it; absence is rendered as the "text" tag +/// on the frontend). `tools`, embedding, and any future Ollama additions +/// are intentionally dropped so the picker stays focused on the +/// distinctions Thuki actually drives behavior off of. +#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Capabilities { + /// Model accepts image inputs alongside text prompts. Drives the + /// submit-time vision gate. + #[serde(default)] + pub vision: bool, + /// Model emits explicit reasoning tokens that Thuki renders in the + /// ThinkingBlock UI. + #[serde(default)] + pub thinking: bool, + /// Maximum number of images the model accepts in a single request, when + /// known. `None` means "unknown / unbounded by Thuki" and the gate lets + /// the request through. Today this is keyed off the model architecture + /// reported by `/api/show` (e.g. `mllama` → 1) because Ollama does not + /// surface a declarative max-image count anywhere in its metadata. + #[serde(default)] + pub max_images: Option, +} + +/// Architecture-keyed cap on the number of images accepted per request. +/// Ollama runners enforce these limits internally and answer with an HTTP +/// 500 when violated; mirroring them here lets the frontend gate refuse +/// the submit before the round-trip. +/// +/// Unknown architectures fall through to `None`, which the gate interprets +/// as "no Thuki-side cap", trusting Ollama's runner as the final authority. +/// New architectures only need to be added when we observe a hard, +/// model-specific limit (today: `mllama`, used by llama3.2-vision). +pub fn max_images_for_architecture(arch: &str) -> Option { + match arch { + "mllama" => Some(1), + _ => None, + } +} + +/// Subset of the `/api/show` response that Thuki consumes. All other fields +/// (modelfile, parameters, template, etc.) are ignored. +#[derive(Deserialize)] +struct ShowResponse { + #[serde(default)] + capabilities: Vec, + /// `details.family` (e.g. "mllama", "gemma4"). Older Ollama versions + /// omit this; the field stays optional so decoding never fails on a + /// model that pre-dates the field. + #[serde(default)] + details: Option, + /// Detailed `model_info` map. We only read `general.architecture` from + /// it. Stored as raw JSON so the rest of the (sometimes tens of fields, + /// arbitrary types) payload does not have to be modelled. + #[serde(default)] + model_info: Option>, +} + +/// Subset of `details` from `/api/show`. Only `family` is consumed today; +/// the rest of the object (parameter_size, quantization_level, etc.) is +/// ignored so unrelated changes upstream cannot break decoding. +#[derive(Deserialize)] +struct ShowDetails { + #[serde(default)] + family: Option, +} + +/// Reads the model architecture string from a parsed `/api/show` payload. +/// Prefers `model_info["general.architecture"]` (the canonical source); +/// falls back to `details.family` for older Ollama builds that did not +/// surface the structured `model_info` map. Returns `None` when neither +/// source is populated. +fn architecture_from_show(body: &ShowResponse) -> Option<&str> { + if let Some(mi) = &body.model_info { + if let Some(arch) = mi.get("general.architecture").and_then(|v| v.as_str()) { + if !arch.is_empty() { + return Some(arch); + } + } + } + body.details + .as_ref() + .and_then(|d| d.family.as_deref()) + .filter(|s| !s.is_empty()) +} + +/// Pure mapping from Ollama's capability strings into the typed +/// [`Capabilities`] struct. Unknown strings are silently dropped so a +/// future Ollama version that adds e.g. `"audio"` does not poison the +/// frontend payload. The `max_images` field is left at `None` here and +/// populated by the caller once the architecture is known. +pub fn capabilities_from_strings(items: &[String]) -> Capabilities { + let mut caps = Capabilities::default(); + for c in items { + match c.as_str() { + "vision" => caps.vision = true, + "thinking" => caps.thinking = true, + _ => {} + } + } + caps +} + +/// POSTs `{base_url}/api/show {"name": ""}` and returns the parsed +/// [`Capabilities`] for that model. +/// +/// Every failure mode (transport error, non-2xx status, oversized body, +/// JSON decode error) is translated to `Err(String)` so the Tauri command +/// layer can propagate it verbatim without panicking. +pub async fn fetch_model_capabilities( + client: &reqwest::Client, + base_url: &str, + name: &str, +) -> Result { + fetch_model_capabilities_with_timeout( + client, + base_url, + name, + std::time::Duration::from_secs(DEFAULT_OLLAMA_SHOW_REQUEST_TIMEOUT_SECS), + ) + .await +} + +/// Internal variant of [`fetch_model_capabilities`] with a configurable +/// per-request timeout. Exists so tests can exercise the timeout branch +/// deterministically without waiting the production 5s. +async fn fetch_model_capabilities_with_timeout( + client: &reqwest::Client, + base_url: &str, + name: &str, + timeout: std::time::Duration, +) -> Result { + fetch_model_capabilities_inner(client, base_url, name, timeout, MAX_OLLAMA_SHOW_BODY_BYTES) + .await +} + +/// Innermost implementation of the `/api/show` fetcher. Both timeout and +/// body size cap are configurable so the size-cap branches can be +/// exercised in tests without allocating production-scale buffers. +/// +/// The cap is enforced incrementally during the streaming read: each chunk +/// is checked before being appended, so the connection is aborted the moment +/// the running total would exceed `max_body_bytes` rather than after the full +/// body has been buffered. +async fn fetch_model_capabilities_inner( + client: &reqwest::Client, + base_url: &str, + name: &str, + timeout: std::time::Duration, + max_body_bytes: usize, +) -> Result { + let url = format!("{}/api/show", base_url.trim_end_matches('/')); + let response = client + .post(&url) + .json(&serde_json::json!({ "name": name })) + .timeout(timeout) + .send() + .await + .map_err(|e| format!("failed to reach Ollama: {e}"))?; + + if !response.status().is_success() { + return Err(format!( + "Ollama /api/show returned HTTP {}", + response.status().as_u16() + )); + } + + if let Some(declared_len) = response.content_length() { + if declared_len as usize > max_body_bytes { + return Err(format!( + "/api/show response exceeded {max_body_bytes} bytes" + )); + } + } + + let mut stream = response.bytes_stream(); + let mut buf: Vec = Vec::new(); + while let Some(chunk) = stream.next().await { + let chunk = chunk.map_err(|e| format!("failed to read /api/show body: {e}"))?; + if buf.len() + chunk.len() > max_body_bytes { + return Err(format!( + "/api/show response exceeded {max_body_bytes} bytes" + )); + } + buf.extend_from_slice(&chunk); + } + + let body: ShowResponse = serde_json::from_slice(&buf) + .map_err(|e| format!("failed to decode /api/show response: {e}"))?; + + let mut caps = capabilities_from_strings(&body.capabilities); + // Only attach max_images for vision models. There is no point capping a + // text-only model on an image count; the vision gate refuses those + // submits before the count check ever runs. + if caps.vision { + if let Some(arch) = architecture_from_show(&body) { + caps.max_images = max_images_for_architecture(arch); + } + } + Ok(caps) +} + +/// In-memory cache of capabilities keyed by model slug. Populated lazily +/// the first time a model is queried. Cleared on app restart, which is +/// the simplest valid invalidation strategy: re-pulling a model under the +/// same slug requires a process restart anyway because Tauri's reqwest +/// client is process-scoped, and capabilities for a given (slug, digest) +/// pair never change. +#[derive(Default)] +pub struct ModelCapabilitiesCache(pub Mutex>); + +/// Fetches `/api/tags` for the installed list, then returns a map of +/// `model name -> Capabilities` covering every installed model. Uses the +/// cache for hits and POSTs `/api/show` sequentially for misses, writing +/// results through to the cache. +/// +/// Sequential fetch is intentional: localhost Ollama responds in tens of +/// milliseconds, the typical user has fewer than ten models installed, +/// and sequential keeps lifetime / borrow plumbing simple. Per-model +/// fetch failures are skipped (the offending entry is just absent from +/// the result map) so a single bad model cannot blank out the whole +/// picker. +#[cfg_attr(coverage_nightly, coverage(off))] +#[cfg_attr(not(coverage), tauri::command)] +pub async fn get_model_capabilities( + client: tauri::State<'_, reqwest::Client>, + cache: tauri::State<'_, ModelCapabilitiesCache>, + config: tauri::State<'_, AppConfig>, +) -> Result, String> { + let base_url = &config.model.ollama_url; + let installed = fetch_installed_model_names(&client, base_url).await?; + Ok(reconcile_capabilities(&client, &cache, base_url, &installed).await) +} + +/// Pure-ish helper extracted so tests can drive the cache + fetch loop +/// against a `mockito` server without going through the Tauri command +/// boundary. Honors the cache for already-known slugs and fetches the +/// rest from `base_url`. +/// +/// Defense-in-depth: every miss is shape-checked via [`validate_model_slug`] +/// before being sent in the `/api/show` JSON body. Slugs that come from +/// `/api/tags` should already be well-formed, but a compromised or +/// misbehaving Ollama could return a slug containing control characters +/// or shell metacharacters; this guard keeps such inputs out of the +/// request entirely. Invalid slugs are silently dropped so they are +/// simply absent from the result map. +/// +/// Concurrency: the read snapshot, the per-miss fetch, and the +/// write-back each take their own short-lived `Mutex` guard. Two +/// concurrent calls for the same miss may both fetch and both write the +/// same value. This is benign because the operation is idempotent (the +/// same `(slug, /api/show)` always yields the same `Capabilities`); the +/// only cost is a duplicate POST. +async fn reconcile_capabilities( + client: &reqwest::Client, + cache: &ModelCapabilitiesCache, + base_url: &str, + installed: &[String], +) -> HashMap { + let mut hits: HashMap = HashMap::new(); + let mut misses: Vec = Vec::new(); + match cache.0.lock() { + Ok(guard) => { + for name in installed { + if let Some(c) = guard.get(name) { + hits.insert(name.clone(), c.clone()); + } else { + misses.push(name.clone()); + } + } + } + Err(_) => { + // Poisoned lock: treat every requested slug as a miss so the + // caller still gets a best-effort result. + misses.extend(installed.iter().cloned()); + } + } + for name in &misses { + if validate_model_slug(name).is_err() { + continue; + } + if let Ok(caps) = fetch_model_capabilities(client, base_url, name).await { + if let Ok(mut guard) = cache.0.lock() { + guard.insert(name.clone(), caps.clone()); + } + hits.insert(name.clone(), caps); + } + } + hits +} + +// ─── Tests ────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + // ── resolve_active_model ───────────────────────────────────────────────── + + #[test] + fn resolve_prefers_persisted_when_still_installed() { + let installed = vec!["gemma4:e2b".to_string(), "gemma4:e4b".to_string()]; + let result = resolve_active_model(Some("gemma4:e4b"), &installed, "gemma4:e2b"); + assert_eq!(result, "gemma4:e4b"); + } + + #[test] + fn resolve_falls_back_to_first_installed_when_persisted_missing() { + let installed = vec!["gemma4:e2b".to_string(), "gemma4:e4b".to_string()]; + let result = resolve_active_model(Some("llama3:8b"), &installed, "bootstrap-model"); + assert_eq!(result, "gemma4:e2b"); + } + + #[test] + fn resolve_falls_back_to_bootstrap_when_nothing_installed() { + let installed: Vec = vec![]; + let result = resolve_active_model(None, &installed, "bootstrap-model"); + assert_eq!(result, "bootstrap-model"); + } + + #[test] + fn resolve_with_no_persisted_uses_first_installed() { + let installed = vec!["gemma4:e2b".to_string()]; + let result = resolve_active_model(None, &installed, "bootstrap-model"); + assert_eq!(result, "gemma4:e2b"); + } + + #[test] + fn resolve_with_empty_persisted_bootstrap_used_when_installed_empty() { + let installed: Vec = vec![]; + let result = resolve_active_model(Some("gemma4:e2b"), &installed, "fallback"); + assert_eq!(result, "fallback"); + } + + // ── resolve_seed_active_model ──────────────────────────────────────────── + + #[test] + fn seed_resolve_prefers_persisted() { + let result = resolve_seed_active_model(Some("llama3:8b"), "bootstrap-model"); + assert_eq!(result, "llama3:8b"); + } + + #[test] + fn seed_resolve_falls_back_to_bootstrap_when_none() { + let result = resolve_seed_active_model(None, "bootstrap-model"); + assert_eq!(result, "bootstrap-model"); + } + + #[test] + fn seed_resolve_falls_back_to_bootstrap_when_empty_persisted() { + let result = resolve_seed_active_model(Some(""), "bootstrap-model"); + assert_eq!(result, "bootstrap-model"); + } + + // ── should_persist_resolved ───────────────────────────────────────────── + + #[test] + fn should_persist_true_when_resolved_differs_and_inventory_present() { + let installed = vec!["gemma4:e2b".to_string()]; + assert!(should_persist_resolved( + &installed, + Some("llama3:8b"), + "gemma4:e2b" + )); + } + + #[test] + fn should_persist_false_when_resolved_matches_persisted() { + let installed = vec!["gemma4:e2b".to_string()]; + assert!(!should_persist_resolved( + &installed, + Some("gemma4:e2b"), + "gemma4:e2b" + )); + } + + #[test] + fn should_persist_false_when_inventory_empty() { + let installed: Vec = vec![]; + assert!(!should_persist_resolved(&installed, None, "bootstrap")); + } + + #[test] + fn should_persist_true_when_nothing_previously_persisted_but_resolved_available() { + let installed = vec!["gemma4:e2b".to_string()]; + assert!(should_persist_resolved(&installed, None, "gemma4:e2b")); + } + + // ── validate_model_installed ───────────────────────────────────────────── + + #[test] + fn validate_accepts_installed_model() { + let installed = vec!["gemma4:e2b".to_string(), "gemma4:e4b".to_string()]; + assert!(validate_model_installed("gemma4:e4b", &installed).is_ok()); + } + + #[test] + fn validate_rejects_uninstalled_model_with_stable_prefix() { + let installed = vec!["gemma4:e2b".to_string()]; + let err = validate_model_installed("llama3:8b", &installed).unwrap_err(); + assert!( + err.starts_with(MODEL_NOT_INSTALLED_ERR_PREFIX), + "expected stable prefix, got: {err}" + ); + assert!(err.ends_with("llama3:8b")); + } + + #[test] + fn validate_rejects_when_installed_list_empty() { + let installed: Vec = vec![]; + let err = validate_model_installed("gemma4:e2b", &installed).unwrap_err(); + assert_eq!(err, format!("{MODEL_NOT_INSTALLED_ERR_PREFIX}gemma4:e2b")); + } + + // ── validate_model_slug ────────────────────────────────────────────────── + + #[test] + fn validate_slug_accepts_valid_forms() { + assert!(validate_model_slug("gemma4:e2b").is_ok()); + assert!(validate_model_slug("llama3.1:8b").is_ok()); + assert!(validate_model_slug("registry.example.com/user/model:tag").is_ok()); + assert!(validate_model_slug("my_model-v2").is_ok()); + } + + #[test] + fn validate_slug_rejects_empty() { + let err = validate_model_slug("").unwrap_err(); + assert!(err.contains("empty")); + } + + #[test] + fn validate_slug_rejects_oversized() { + let oversized = "a".repeat(MAX_MODEL_SLUG_LEN + 1); + let err = validate_model_slug(&oversized).unwrap_err(); + assert!(err.contains("maximum length")); + } + + #[test] + fn validate_slug_accepts_max_length() { + let at_limit = "a".repeat(MAX_MODEL_SLUG_LEN); + assert!(validate_model_slug(&at_limit).is_ok()); + } + + #[test] + fn validate_slug_rejects_shell_metacharacters() { + assert!(validate_model_slug("bad; rm -rf /").is_err()); + assert!(validate_model_slug("../etc/passwd").is_ok()); // `.` `/` `-` allowed individually + assert!(validate_model_slug("bad name").is_err()); // whitespace rejected + assert!(validate_model_slug("bad\nname").is_err()); + assert!(validate_model_slug("bad$(whoami)").is_err()); + assert!(validate_model_slug("bad`whoami`").is_err()); + } + + #[test] + fn validate_slug_rejects_non_ascii() { + assert!(validate_model_slug("gëmma").is_err()); + } + + // ── fetch_installed_model_names ────────────────────────────────────────── + + #[tokio::test] + async fn fetch_parses_valid_tags_response() { + let mut server = mockito::Server::new_async().await; + let body = r#"{"models":[ + {"name":"gemma4:e2b"}, + {"name":"gemma4:e4b"} + ]}"#; + let mock = server + .mock("GET", "/api/tags") + .with_status(200) + .with_header("content-type", "application/json") + .with_body(body) + .create_async() + .await; + + let client = reqwest::Client::new(); + let result = fetch_installed_model_names(&client, &server.url()).await; + + mock.assert_async().await; + let names = result.unwrap(); + assert_eq!( + names, + vec!["gemma4:e2b".to_string(), "gemma4:e4b".to_string()] + ); + } + + #[tokio::test] + async fn fetch_returns_empty_when_no_models_installed() { + let mut server = mockito::Server::new_async().await; + let mock = server + .mock("GET", "/api/tags") + .with_status(200) + .with_header("content-type", "application/json") + .with_body(r#"{"models":[]}"#) + .create_async() + .await; + + let client = reqwest::Client::new(); + let result = fetch_installed_model_names(&client, &server.url()).await; + + mock.assert_async().await; + assert_eq!(result.unwrap(), Vec::::new()); + } + + #[tokio::test] + async fn fetch_maps_http_error_to_err_string() { + let mut server = mockito::Server::new_async().await; + let mock = server + .mock("GET", "/api/tags") + .with_status(500) + .with_body("server blew up") + .create_async() + .await; + + let client = reqwest::Client::new(); + let result = fetch_installed_model_names(&client, &server.url()).await; + + mock.assert_async().await; + let err = result.unwrap_err(); + assert!( + err.contains("500"), + "expected status code in error, got: {err}" + ); + } + + #[tokio::test] + async fn fetch_maps_invalid_json_to_err_string() { + let mut server = mockito::Server::new_async().await; + let mock = server + .mock("GET", "/api/tags") + .with_status(200) + .with_header("content-type", "application/json") + .with_body("not json at all") + .create_async() + .await; + + let client = reqwest::Client::new(); + let result = fetch_installed_model_names(&client, &server.url()).await; + + mock.assert_async().await; + let err = result.unwrap_err(); + assert!( + err.contains("failed to decode"), + "expected decode error, got: {err}" + ); + } + + #[tokio::test] + async fn fetch_maps_transport_error_to_err_string() { + // Port 1 is reserved and will refuse connections; tests the `send()` + // error branch without a live server. + let client = reqwest::Client::new(); + let result = fetch_installed_model_names(&client, "http://127.0.0.1:1").await; + + let err = result.unwrap_err(); + assert!( + err.contains("failed to reach Ollama"), + "expected transport error, got: {err}" + ); + } + + #[tokio::test] + async fn fetch_installed_model_names_times_out_when_ollama_hangs() { + let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = listener.local_addr().unwrap(); + std::thread::spawn(move || { + let _held = listener.accept().ok(); + std::thread::sleep(std::time::Duration::from_secs(10)); + }); + + let client = reqwest::Client::new(); + let base = format!("http://{addr}"); + let result = fetch_installed_model_names_with_timeout( + &client, + &base, + std::time::Duration::from_millis(100), + ) + .await; + + let err = result.unwrap_err(); + assert!( + err.contains("failed to reach Ollama"), + "expected timeout to surface as transport error, got: {err}" + ); + } + + #[tokio::test] + async fn fetch_trims_trailing_slash_from_base_url() { + let mut server = mockito::Server::new_async().await; + let mock = server + .mock("GET", "/api/tags") + .with_status(200) + .with_header("content-type", "application/json") + .with_body(r#"{"models":[{"name":"x"}]}"#) + .create_async() + .await; + + let client = reqwest::Client::new(); + let url_with_slash = format!("{}/", server.url()); + let result = fetch_installed_model_names(&client, &url_with_slash).await; + + mock.assert_async().await; + assert_eq!(result.unwrap(), vec!["x".to_string()]); + } + + #[tokio::test] + async fn fetch_rejects_body_exceeding_size_cap_via_content_length() { + let mut server = mockito::Server::new_async().await; + // Tight cap (32 bytes) + a declared Content-Length that matches a + // 100-byte payload; the pre-read guard on `content_length` must + // reject before the bytes() call is issued. + let body = "x".repeat(100); + server + .mock("GET", "/api/tags") + .with_status(200) + .with_header("content-type", "application/json") + .with_body(body) + .create_async() + .await; + + let client = reqwest::Client::new(); + let result = fetch_installed_model_names_inner( + &client, + &server.url(), + std::time::Duration::from_secs(5), + 32, + ) + .await; + + let err = result.unwrap_err(); + assert!( + err.contains("exceeded"), + "expected size-cap error, got: {err}" + ); + } + + #[tokio::test] + async fn fetch_maps_body_read_error_to_err_string() { + // Headers advertise Content-Length but the server closes the socket + // before sending any body bytes. reqwest's bytes() surfaces this as + // a transport error; the helper must map it to the documented prose. + let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = listener.local_addr().unwrap(); + std::thread::spawn(move || { + let (mut stream, _) = listener.accept().unwrap(); + use std::io::{Read, Write}; + let mut buf = [0u8; 1024]; + let _ = stream.read(&mut buf); + // Promise 100 body bytes, then immediately hang up. + let _ = stream.write_all( + b"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: 100\r\nConnection: close\r\n\r\n", + ); + }); + + let client = reqwest::Client::new(); + let base = format!("http://{addr}"); + let result = fetch_installed_model_names(&client, &base).await; + + let err = result.unwrap_err(); + assert!( + err.contains("failed to read /api/tags body"), + "expected body-read error, got: {err}" + ); + } + + #[tokio::test] + async fn fetch_rejects_body_exceeding_size_cap_when_no_content_length() { + // Chunked-encoding response (no Content-Length); the incremental stream + // cap must reject when the running total exceeds the limit. + let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = listener.local_addr().unwrap(); + std::thread::spawn(move || { + let (mut stream, _) = listener.accept().unwrap(); + use std::io::{Read, Write}; + let mut buf = [0u8; 1024]; + let _ = stream.read(&mut buf); + let body = "x".repeat(200); + let response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nTransfer-Encoding: chunked\r\n\r\n{:x}\r\n{}\r\n0\r\n\r\n", + body.len(), + body + ); + let _ = stream.write_all(response.as_bytes()); + }); + + let client = reqwest::Client::new(); + let base = format!("http://{addr}"); + let result = fetch_installed_model_names_inner( + &client, + &base, + std::time::Duration::from_secs(5), + 32, + ) + .await; + + let err = result.unwrap_err(); + assert!( + err.contains("exceeded"), + "expected incremental stream cap error, got: {err}" + ); + } + + #[tokio::test] + async fn fetch_tags_chunked_early_abort_incremental() { + // Explicit test of the incremental streaming abort: the response has NO + // Content-Length header and sends chunks whose cumulative size exceeds + // the cap. The abort must fire during the streaming read, not after. + let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = listener.local_addr().unwrap(); + std::thread::spawn(move || { + let (mut conn, _) = listener.accept().unwrap(); + use std::io::{Read, Write}; + let mut request_buf = [0u8; 1024]; + let _ = conn.read(&mut request_buf); + // Send two small chunks without Content-Length (chunked encoding). + // Each chunk alone is under the cap of 20 bytes, but together + // they exceed it, exercising the incremental buf.len() + chunk.len() + // check inside the stream loop. + let _ = conn.write_all( + b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n\ + 0a\r\n0123456789\r\n\ + 0a\r\n0123456789\r\n\ + 0a\r\n0123456789\r\n\ + 0\r\n\r\n", + ); + }); + let client = reqwest::Client::new(); + let base = format!("http://{addr}"); + let err = fetch_installed_model_names_inner( + &client, + &base, + std::time::Duration::from_secs(5), + 20, + ) + .await + .unwrap_err(); + assert!( + err.contains("exceeded"), + "expected incremental abort error, got: {err}" + ); + } + + // ── ActiveModelState ───────────────────────────────────────────────────── + + #[test] + fn active_model_state_defaults_to_empty_string() { + let state = ActiveModelState::default(); + assert_eq!(*state.0.lock().unwrap(), ""); + } + + #[test] + fn active_model_state_round_trip_write_read() { + let state = ActiveModelState::default(); + { + let mut guard = state.0.lock().unwrap(); + *guard = "gemma4:e2b".to_string(); + } + assert_eq!(*state.0.lock().unwrap(), "gemma4:e2b"); + } + + // ── Persistence round-trip through app_config ─────────────────────────── + + #[test] + fn active_model_key_persists_via_set_and_get_config() { + let conn = crate::database::open_in_memory().unwrap(); + set_config(&conn, ACTIVE_MODEL_KEY, "gemma4:e4b").unwrap(); + let back = get_config(&conn, ACTIVE_MODEL_KEY).unwrap(); + assert_eq!(back.as_deref(), Some("gemma4:e4b")); + } + + #[test] + fn active_model_key_constant_matches_expected_value() { + assert_eq!(ACTIVE_MODEL_KEY, "active_model"); + } + + #[test] + fn model_not_installed_err_prefix_is_stable() { + assert_eq!( + MODEL_NOT_INSTALLED_ERR_PREFIX, + "Model is not installed in Ollama: " + ); + } + + // ── derive_model_setup_state (Phase 3 onboarding gate) ────────────────── + + #[test] + fn derive_setup_state_returns_unreachable_on_fetch_error() { + let state = + derive_model_setup_state(Err("connection refused".to_string()), None, "gemma4:e2b"); + assert_eq!(state, ModelSetupState::OllamaUnreachable); + } + + #[test] + fn derive_setup_state_returns_unreachable_even_when_persisted_choice_exists() { + // Past selection must NOT mask a current outage. The user needs to + // see the "Ollama not detected" screen even if SQLite remembers a slug. + let state = + derive_model_setup_state(Err("timeout".to_string()), Some("gemma4:e4b"), "gemma4:e2b"); + assert_eq!(state, ModelSetupState::OllamaUnreachable); + } + + #[test] + fn derive_setup_state_returns_no_models_when_inventory_empty() { + let state = derive_model_setup_state(Ok(vec![]), None, "gemma4:e2b"); + assert_eq!(state, ModelSetupState::NoModelsInstalled); + } + + #[test] + fn derive_setup_state_returns_no_models_even_with_stale_persisted_slug() { + // Daemon up but the user removed every model with `ollama rm`. The + // persisted slug is no longer valid; the gate must re-engage. + let state = derive_model_setup_state(Ok(vec![]), Some("removed-model:7b"), "gemma4:e2b"); + assert_eq!(state, ModelSetupState::NoModelsInstalled); + } + + #[test] + fn derive_setup_state_ready_keeps_persisted_when_still_installed() { + let state = derive_model_setup_state( + Ok(vec!["gemma4:e2b".to_string(), "llama3:8b".to_string()]), + Some("llama3:8b"), + "gemma4:e2b", + ); + assert_eq!( + state, + ModelSetupState::Ready { + active_slug: "llama3:8b".to_string(), + installed: vec!["gemma4:e2b".to_string(), "llama3:8b".to_string()], + } + ); + } + + #[test] + fn derive_setup_state_ready_falls_back_to_first_when_persisted_gone() { + let state = derive_model_setup_state( + Ok(vec!["gemma4:e4b".to_string(), "llama3:8b".to_string()]), + Some("removed-model:7b"), + "gemma4:e2b", + ); + assert_eq!( + state, + ModelSetupState::Ready { + active_slug: "gemma4:e4b".to_string(), + installed: vec!["gemma4:e4b".to_string(), "llama3:8b".to_string()], + } + ); + } + + #[test] + fn derive_setup_state_ready_uses_first_when_no_persisted_choice() { + // First-time user who somehow has models installed already (rare: + // they used Ollama for something else first). Pick the first. + let state = + derive_model_setup_state(Ok(vec!["qwen2.5:7b".to_string()]), None, "gemma4:e2b"); + assert_eq!( + state, + ModelSetupState::Ready { + active_slug: "qwen2.5:7b".to_string(), + installed: vec!["qwen2.5:7b".to_string()], + } + ); + } + + #[test] + fn model_setup_state_serializes_with_state_tag_for_frontend() { + // Wire format must be discriminated on a `state` field so the + // React side can route on a single string before pattern-matching + // payload shape. Drift here breaks the frontend dispatch. + let unreachable = serde_json::to_value(ModelSetupState::OllamaUnreachable).unwrap(); + assert_eq!( + unreachable, + serde_json::json!({"state": "ollama_unreachable"}) + ); + + let none = serde_json::to_value(ModelSetupState::NoModelsInstalled).unwrap(); + assert_eq!(none, serde_json::json!({"state": "no_models_installed"})); + + let ready = serde_json::to_value(ModelSetupState::Ready { + active_slug: "gemma4:e2b".to_string(), + installed: vec!["gemma4:e2b".to_string()], + }) + .unwrap(); + assert_eq!( + ready, + serde_json::json!({ + "state": "ready", + "active_slug": "gemma4:e2b", + "installed": ["gemma4:e2b"], + }) + ); + } + + // ── capabilities_from_strings ──────────────────────────────────────────── + + #[test] + fn capabilities_from_strings_recognises_all_known_flags() { + let caps = capabilities_from_strings(&["vision".to_string(), "thinking".to_string()]); + assert!(caps.vision); + assert!(caps.thinking); + } + + #[test] + fn capabilities_from_strings_defaults_to_all_false_on_empty() { + let caps = capabilities_from_strings(&[]); + assert!(!caps.vision); + assert!(!caps.thinking); + } + + #[test] + fn capabilities_from_strings_drops_unknown_flags_silently() { + let caps = capabilities_from_strings(&[ + "vision".to_string(), + "tools".to_string(), + "audio".to_string(), + "completion".to_string(), + "future-thing".to_string(), + ]); + assert!(caps.vision); + assert!(!caps.thinking); + } + + #[test] + fn capabilities_serialize_uses_camel_case_field_names() { + let caps = Capabilities { + vision: true, + thinking: false, + max_images: Some(1), + }; + let v = serde_json::to_value(&caps).unwrap(); + assert_eq!( + v, + serde_json::json!({ + "vision": true, + "thinking": false, + "maxImages": 1, + }) + ); + } + + #[test] + fn capabilities_serialize_emits_null_max_images_when_unknown() { + let caps = Capabilities { + vision: true, + thinking: false, + max_images: None, + }; + let v = serde_json::to_value(&caps).unwrap(); + assert_eq!(v["maxImages"], serde_json::Value::Null); + } + + #[test] + fn capabilities_deserialize_tolerates_missing_fields() { + let caps: Capabilities = serde_json::from_value(serde_json::json!({})).unwrap(); + assert_eq!(caps, Capabilities::default()); + } + + #[test] + fn capabilities_deserialize_round_trips_max_images() { + let caps: Capabilities = serde_json::from_value(serde_json::json!({ + "vision": true, + "thinking": false, + "maxImages": 3 + })) + .unwrap(); + assert!(caps.vision); + assert_eq!(caps.max_images, Some(3)); + } + + // ── max_images_for_architecture ───────────────────────────────────────── + + #[test] + fn max_images_caps_mllama_at_one() { + assert_eq!(max_images_for_architecture("mllama"), Some(1)); + } + + #[test] + fn max_images_returns_none_for_unknown_arch() { + assert_eq!(max_images_for_architecture("gemma4"), None); + assert_eq!(max_images_for_architecture(""), None); + assert_eq!(max_images_for_architecture("future-arch"), None); + } + + // ── architecture_from_show ────────────────────────────────────────────── + + #[test] + fn architecture_prefers_model_info_general_architecture() { + let body: ShowResponse = serde_json::from_value(serde_json::json!({ + "capabilities": ["completion","vision"], + "details": {"family": "fallback-family"}, + "model_info": {"general.architecture": "mllama"} + })) + .unwrap(); + assert_eq!(architecture_from_show(&body), Some("mllama")); + } + + #[test] + fn architecture_falls_back_to_details_family_when_model_info_absent() { + let body: ShowResponse = serde_json::from_value(serde_json::json!({ + "capabilities": ["completion","vision"], + "details": {"family": "mllama"} + })) + .unwrap(); + assert_eq!(architecture_from_show(&body), Some("mllama")); + } + + #[test] + fn architecture_falls_back_when_model_info_arch_is_blank() { + let body: ShowResponse = serde_json::from_value(serde_json::json!({ + "capabilities": [], + "details": {"family": "mllama"}, + "model_info": {"general.architecture": ""} + })) + .unwrap(); + assert_eq!(architecture_from_show(&body), Some("mllama")); + } + + #[test] + fn architecture_returns_none_when_neither_source_populated() { + let body: ShowResponse = serde_json::from_value(serde_json::json!({ + "capabilities": [] + })) + .unwrap(); + assert_eq!(architecture_from_show(&body), None); + } + + #[test] + fn architecture_returns_none_when_details_family_blank() { + let body: ShowResponse = serde_json::from_value(serde_json::json!({ + "capabilities": [], + "details": {"family": ""} + })) + .unwrap(); + assert_eq!(architecture_from_show(&body), None); + } + + #[test] + fn architecture_ignores_non_string_general_architecture() { + let body: ShowResponse = serde_json::from_value(serde_json::json!({ + "capabilities": [], + "details": {"family": "mllama"}, + "model_info": {"general.architecture": 7} + })) + .unwrap(); + // Non-string in model_info falls through; details.family wins. + assert_eq!(architecture_from_show(&body), Some("mllama")); + } + + // ── fetch_model_capabilities ───────────────────────────────────────────── + + #[tokio::test] + async fn fetch_capabilities_parses_full_response() { + let mut server = mockito::Server::new_async().await; + let body = r#"{"capabilities":["completion","vision","thinking"],"modelfile":"…"}"#; + let _m = server + .mock("POST", "/api/show") + .with_status(200) + .with_header("content-type", "application/json") + .with_body(body) + .create_async() + .await; + let client = reqwest::Client::new(); + let caps = fetch_model_capabilities(&client, &server.url(), "llama3.2-vision") + .await + .unwrap(); + assert!(caps.vision); + assert!(caps.thinking); + } + + #[tokio::test] + async fn fetch_capabilities_attaches_max_images_for_mllama_vision_models() { + let mut server = mockito::Server::new_async().await; + let body = r#"{ + "capabilities":["completion","vision"], + "details":{"family":"mllama"}, + "model_info":{"general.architecture":"mllama"} + }"#; + let _m = server + .mock("POST", "/api/show") + .with_status(200) + .with_header("content-type", "application/json") + .with_body(body) + .create_async() + .await; + let client = reqwest::Client::new(); + let caps = fetch_model_capabilities(&client, &server.url(), "llama3.2-vision:11b") + .await + .unwrap(); + assert!(caps.vision); + assert_eq!(caps.max_images, Some(1)); + } + + #[tokio::test] + async fn fetch_capabilities_leaves_max_images_unset_for_unknown_arch() { + let mut server = mockito::Server::new_async().await; + let body = r#"{ + "capabilities":["completion","vision","thinking"], + "details":{"family":"gemma4"}, + "model_info":{"general.architecture":"gemma4"} + }"#; + let _m = server + .mock("POST", "/api/show") + .with_status(200) + .with_header("content-type", "application/json") + .with_body(body) + .create_async() + .await; + let client = reqwest::Client::new(); + let caps = fetch_model_capabilities(&client, &server.url(), "gemma4:e2b") + .await + .unwrap(); + assert!(caps.vision); + assert!(caps.thinking); + assert_eq!(caps.max_images, None); + } + + #[tokio::test] + async fn fetch_capabilities_skips_max_images_for_text_only_models() { + // No point capping a text-only model on image count; vision gate + // will refuse the submit before max_images is consulted anyway. + let mut server = mockito::Server::new_async().await; + let body = r#"{ + "capabilities":["completion"], + "details":{"family":"mllama"}, + "model_info":{"general.architecture":"mllama"} + }"#; + let _m = server + .mock("POST", "/api/show") + .with_status(200) + .with_header("content-type", "application/json") + .with_body(body) + .create_async() + .await; + let client = reqwest::Client::new(); + let caps = fetch_model_capabilities(&client, &server.url(), "x") + .await + .unwrap(); + assert!(!caps.vision); + assert_eq!(caps.max_images, None); + } + + #[tokio::test] + async fn fetch_capabilities_handles_missing_array() { + let mut server = mockito::Server::new_async().await; + let _m = server + .mock("POST", "/api/show") + .with_status(200) + .with_body(r#"{"modelfile":"…"}"#) + .create_async() + .await; + let client = reqwest::Client::new(); + let caps = fetch_model_capabilities(&client, &server.url(), "x") + .await + .unwrap(); + assert_eq!(caps, Capabilities::default()); + } + + #[tokio::test] + async fn fetch_capabilities_returns_err_on_non_2xx() { + let mut server = mockito::Server::new_async().await; + let _m = server + .mock("POST", "/api/show") + .with_status(404) + .with_body("not found") + .create_async() + .await; + let client = reqwest::Client::new(); + let err = fetch_model_capabilities(&client, &server.url(), "missing") + .await + .unwrap_err(); + assert!(err.contains("404")); + } + + #[tokio::test] + async fn fetch_capabilities_returns_err_on_invalid_json() { + let mut server = mockito::Server::new_async().await; + let _m = server + .mock("POST", "/api/show") + .with_status(200) + .with_body("not json") + .create_async() + .await; + let client = reqwest::Client::new(); + let err = fetch_model_capabilities(&client, &server.url(), "x") + .await + .unwrap_err(); + assert!(err.contains("decode")); + } + + #[tokio::test] + async fn fetch_capabilities_returns_err_on_unreachable() { + let client = reqwest::Client::new(); + let err = fetch_model_capabilities(&client, "http://127.0.0.1:1", "x") + .await + .unwrap_err(); + assert!(err.contains("failed to reach Ollama")); + } + + #[tokio::test] + async fn fetch_capabilities_rejects_oversized_via_content_length() { + // Tight cap + 100-byte body; mockito sets Content-Length: 100, the + // pre-read guard on `content_length` must reject before bytes() is + // issued. + let mut server = mockito::Server::new_async().await; + let body = "x".repeat(100); + server + .mock("POST", "/api/show") + .with_status(200) + .with_header("content-type", "application/json") + .with_body(body) + .create_async() + .await; + let client = reqwest::Client::new(); + let err = fetch_model_capabilities_inner( + &client, + &server.url(), + "x", + std::time::Duration::from_secs(5), + 32, + ) + .await + .unwrap_err(); + assert!(err.contains("exceeded")); + } + + #[tokio::test] + async fn fetch_capabilities_rejects_oversized_when_no_content_length() { + // Chunked-encoding response (no Content-Length); the incremental stream + // cap must reject when the running total exceeds the limit. + let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = listener.local_addr().unwrap(); + std::thread::spawn(move || { + let (mut stream, _) = listener.accept().unwrap(); + use std::io::{Read, Write}; + let mut buf = [0u8; 1024]; + let _ = stream.read(&mut buf); + let body = "x".repeat(200); + let response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nTransfer-Encoding: chunked\r\n\r\n{:x}\r\n{}\r\n0\r\n\r\n", + body.len(), + body + ); + let _ = stream.write_all(response.as_bytes()); + }); + let client = reqwest::Client::new(); + let base = format!("http://{addr}"); + let err = fetch_model_capabilities_inner( + &client, + &base, + "x", + std::time::Duration::from_secs(5), + 32, + ) + .await + .unwrap_err(); + assert!(err.contains("exceeded")); + } + + #[tokio::test] + async fn fetch_show_chunked_early_abort_incremental() { + // Explicit test of the incremental streaming abort for /api/show: the + // response has NO Content-Length header and sends chunks whose + // cumulative size exceeds the cap. The abort must fire during the + // streaming read, not after. + let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = listener.local_addr().unwrap(); + std::thread::spawn(move || { + let (mut conn, _) = listener.accept().unwrap(); + use std::io::{Read, Write}; + let mut request_buf = [0u8; 1024]; + let _ = conn.read(&mut request_buf); + // Send three 10-byte chunks without Content-Length (chunked + // encoding). Each chunk alone is under the cap of 20 bytes, but + // together they exceed it, exercising the incremental check. + let _ = conn.write_all( + b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n\ + 0a\r\n0123456789\r\n\ + 0a\r\n0123456789\r\n\ + 0a\r\n0123456789\r\n\ + 0\r\n\r\n", + ); + }); + let client = reqwest::Client::new(); + let base = format!("http://{addr}"); + let err = fetch_model_capabilities_inner( + &client, + &base, + "x", + std::time::Duration::from_secs(5), + 20, + ) + .await + .unwrap_err(); + assert!( + err.contains("exceeded"), + "expected incremental abort error, got: {err}" + ); + } + + #[tokio::test] + async fn fetch_capabilities_maps_body_read_error_to_err_string() { + // Headers promise body but the server hangs up. + let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = listener.local_addr().unwrap(); + std::thread::spawn(move || { + let (mut stream, _) = listener.accept().unwrap(); + use std::io::{Read, Write}; + let mut buf = [0u8; 1024]; + let _ = stream.read(&mut buf); + let _ = stream.write_all( + b"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: 100\r\nConnection: close\r\n\r\n", + ); + }); + let client = reqwest::Client::new(); + let base = format!("http://{addr}"); + let err = fetch_model_capabilities(&client, &base, "x") + .await + .unwrap_err(); + assert!(err.contains("failed to read /api/show body")); + } + + #[tokio::test] + async fn fetch_capabilities_with_custom_timeout_branch_runs() { + let mut server = mockito::Server::new_async().await; + let _m = server + .mock("POST", "/api/show") + .with_status(200) + .with_body(r#"{"capabilities":["vision"]}"#) + .create_async() + .await; + let client = reqwest::Client::new(); + let caps = fetch_model_capabilities_with_timeout( + &client, + &server.url(), + "x", + std::time::Duration::from_millis(500), + ) + .await + .unwrap(); + assert!(caps.vision); + } + + // ── reconcile_capabilities ─────────────────────────────────────────────── + + /// `reconcile_capabilities` calls `DEFAULT_OLLAMA_URL` directly which + /// points at 127.0.0.1:11434. To keep the test deterministic without a + /// running Ollama we exercise the helper in cache-only mode: pre-seed + /// every requested name into the cache so no network call is issued. + #[tokio::test] + async fn reconcile_returns_cached_entries_without_network() { + let cache = ModelCapabilitiesCache::default(); + cache.0.lock().unwrap().insert( + "a".to_string(), + Capabilities { + vision: true, + ..Default::default() + }, + ); + cache.0.lock().unwrap().insert( + "b".to_string(), + Capabilities { + thinking: true, + ..Default::default() + }, + ); + let client = reqwest::Client::new(); + let installed = vec!["a".to_string(), "b".to_string()]; + let result = reconcile_capabilities(&client, &cache, "http://unused", &installed).await; + assert_eq!(result.len(), 2); + assert!(result["a"].vision); + assert!(result["b"].thinking); + } + + #[tokio::test] + async fn reconcile_with_empty_installed_returns_empty_map() { + let cache = ModelCapabilitiesCache::default(); + let client = reqwest::Client::new(); + let result = reconcile_capabilities(&client, &cache, "http://unused", &[]).await; + assert!(result.is_empty()); + } + + #[tokio::test] + async fn reconcile_fetches_misses_and_writes_through_cache() { + let mut server = mockito::Server::new_async().await; + let _m = server + .mock("POST", "/api/show") + .with_status(200) + .with_body(r#"{"capabilities":["completion","vision"]}"#) + .expect_at_least(1) + .create_async() + .await; + let cache = ModelCapabilitiesCache::default(); + let client = reqwest::Client::new(); + let installed = vec!["fresh".to_string()]; + let result = reconcile_capabilities(&client, &cache, &server.url(), &installed).await; + assert!(result["fresh"].vision); + // Cache must now hold the fetched entry. + let guard = cache.0.lock().unwrap(); + assert!(guard.contains_key("fresh")); + assert!(guard["fresh"].vision); + } + + #[tokio::test] + async fn reconcile_drops_unreachable_misses_without_failing() { + let cache = ModelCapabilitiesCache::default(); + cache.0.lock().unwrap().insert( + "cached".to_string(), + Capabilities { + vision: true, + ..Default::default() + }, + ); + let client = reqwest::Client::new(); + let installed = vec!["cached".to_string(), "missing".to_string()]; + // Point base_url at a port nothing listens on so misses fail fast. + let result = + reconcile_capabilities(&client, &cache, "http://127.0.0.1:1", &installed).await; + assert!(result.contains_key("cached")); + assert!(!result.contains_key("missing")); + } + + #[tokio::test] + async fn reconcile_skips_misses_with_invalid_slugs() { + // Defense in depth: a compromised Ollama returning a slug with + // shell metacharacters or whitespace must be dropped before any + // network work, never make it into the `/api/show` request. + let mut server = mockito::Server::new_async().await; + let m = server + .mock("POST", "/api/show") + .with_status(200) + .with_body(r#"{"capabilities":["vision"]}"#) + .expect(0) + .create_async() + .await; + let cache = ModelCapabilitiesCache::default(); + let client = reqwest::Client::new(); + let installed = vec!["bad name".to_string(), "bad$(whoami)".to_string()]; + let result = reconcile_capabilities(&client, &cache, &server.url(), &installed).await; + assert!(result.is_empty()); + m.assert_async().await; + } + + #[tokio::test] + async fn reconcile_when_cache_poisoned_still_attempts_fetches() { + let mut server = mockito::Server::new_async().await; + let _m = server + .mock("POST", "/api/show") + .with_status(200) + .with_body(r#"{"capabilities":["vision"]}"#) + .create_async() + .await; + let cache = ModelCapabilitiesCache::default(); + // Poison the mutex so the read-snapshot branch falls back to + // treating every slug as a miss. + let cache_ref = std::panic::AssertUnwindSafe(&cache.0); + let _ = std::panic::catch_unwind(|| { + let _guard = cache_ref.0.lock().unwrap(); + panic!("poison"); + }); + let client = reqwest::Client::new(); + let installed = vec!["x".to_string()]; + let result = reconcile_capabilities(&client, &cache, &server.url(), &installed).await; + // Cache writes silently fail on the poisoned lock, but the + // result map still carries the freshly-fetched value. + assert!(result["x"].vision); + } +} diff --git a/src-tauri/src/onboarding.rs b/src-tauri/src/onboarding.rs index 18b1bc5a..497a47ad 100644 --- a/src-tauri/src/onboarding.rs +++ b/src-tauri/src/onboarding.rs @@ -5,10 +5,19 @@ * persisted value in the `app_config` table. * * Stages progress linearly: - * "permissions" -> "intro" -> "complete" + * "permissions" -> "model_check" -> "intro" -> "complete" * * "permissions" is the implicit default when no value has been written yet. - * Once "complete", onboarding is never shown again regardless of permissions. + * "model_check" gates the user on having Ollama running with at least one + * installed model. Both stages are skipped on every subsequent launch once + * advanced past. Once "complete", onboarding is never shown again regardless + * of permissions or installed models. + * + * Backward compatibility: existing installs with persisted stages of + * "permissions", "intro", or "complete" all parse correctly. The new + * "model_check" value is unknown to older installs but the file format is + * forward-compatible (unknown stages fall back to Permissions, the safe + * default that re-runs the full flow). */ use rusqlite::Connection; @@ -19,18 +28,26 @@ use crate::database::{get_config, set_config}; const STAGE_KEY: &str = "onboarding_stage"; /// Serializable stage value sent to the frontend via the onboarding event. +/// +/// Variants are emitted in `snake_case` for the frontend to match the +/// `OnboardingStage` TypeScript union exactly. The persisted SQLite value +/// uses the same string form, so the on-disk format is identical to the +/// wire format. #[derive(Debug, Clone, PartialEq, serde::Serialize)] -#[serde(rename_all = "lowercase")] +#[serde(rename_all = "snake_case")] pub enum OnboardingStage { Permissions, + ModelCheck, Intro, Complete, } /// Reads the persisted onboarding stage. Returns `Permissions` if no value -/// has been written yet (i.e. first-ever launch). +/// has been written yet (first-ever launch) or if the persisted value is +/// not recognised (forward-compatible with future stage names). pub fn get_stage(conn: &Connection) -> rusqlite::Result { match get_config(conn, STAGE_KEY)?.as_deref() { + Some("model_check") => Ok(OnboardingStage::ModelCheck), Some("intro") => Ok(OnboardingStage::Intro), Some("complete") => Ok(OnboardingStage::Complete), _ => Ok(OnboardingStage::Permissions), @@ -41,6 +58,7 @@ pub fn get_stage(conn: &Connection) -> rusqlite::Result { pub fn set_stage(conn: &Connection, stage: &OnboardingStage) -> rusqlite::Result<()> { let value = match stage { OnboardingStage::Permissions => "permissions", + OnboardingStage::ModelCheck => "model_check", OnboardingStage::Intro => "intro", OnboardingStage::Complete => "complete", }; @@ -98,6 +116,13 @@ mod tests { assert_eq!(get_stage(&conn).unwrap(), OnboardingStage::Intro); } + #[test] + fn set_and_get_stage_round_trips_model_check() { + let conn = open_in_memory().unwrap(); + set_stage(&conn, &OnboardingStage::ModelCheck).unwrap(); + assert_eq!(get_stage(&conn).unwrap(), OnboardingStage::ModelCheck); + } + #[test] fn set_and_get_stage_round_trips_complete() { let conn = open_in_memory().unwrap(); @@ -105,6 +130,48 @@ mod tests { assert_eq!(get_stage(&conn).unwrap(), OnboardingStage::Complete); } + #[test] + fn get_stage_falls_back_to_permissions_on_unknown_value() { + // Forward-compat guard: if a future build wrote an unrecognised + // stage and the user downgrades, we must safely re-run the flow + // rather than panic or pick an arbitrary stage. + let conn = open_in_memory().unwrap(); + crate::database::set_config(&conn, STAGE_KEY, "future_stage").unwrap(); + assert_eq!(get_stage(&conn).unwrap(), OnboardingStage::Permissions); + } + + #[test] + fn compute_startup_stage_shows_model_check_when_stage_is_model_check() { + let conn = open_in_memory().unwrap(); + set_stage(&conn, &OnboardingStage::ModelCheck).unwrap(); + assert_eq!( + compute_startup_stage(&conn).unwrap(), + Some(OnboardingStage::ModelCheck) + ); + } + + #[test] + fn stage_serializes_to_snake_case_for_frontend() { + // Wire format must match the TypeScript OnboardingStage union exactly. + // Frontend routes on these strings, so any drift breaks the dispatch. + assert_eq!( + serde_json::to_string(&OnboardingStage::Permissions).unwrap(), + "\"permissions\"" + ); + assert_eq!( + serde_json::to_string(&OnboardingStage::ModelCheck).unwrap(), + "\"model_check\"" + ); + assert_eq!( + serde_json::to_string(&OnboardingStage::Intro).unwrap(), + "\"intro\"" + ); + assert_eq!( + serde_json::to_string(&OnboardingStage::Complete).unwrap(), + "\"complete\"" + ); + } + #[test] fn set_stage_overwrites_previous_value() { let conn = open_in_memory().unwrap(); diff --git a/src-tauri/src/search/mod.rs b/src-tauri/src/search/mod.rs index 62ea789e..05c34e90 100644 --- a/src-tauri/src/search/mod.rs +++ b/src-tauri/src/search/mod.rs @@ -17,6 +17,7 @@ use tokio_util::sync::CancellationToken; use crate::commands::{ConversationHistory, GenerationState}; use crate::config::AppConfig; +use crate::models::ActiveModelState; pub mod chunker; pub mod config; @@ -59,6 +60,7 @@ pub async fn search_pipeline( generation: State<'_, GenerationState>, history: State<'_, ConversationHistory>, app_config: State<'_, AppConfig>, + active_model_state: State<'_, ActiveModelState>, ) -> Result<(), String> { // Resolve the runtime search view from the loaded TOML. The single // source of truth lives in `config::defaults`; the loader has already @@ -66,6 +68,14 @@ pub async fn search_pipeline( let runtime_config = config::SearchRuntimeConfig::from_app_config(&app_config); let searxng_endpoint = runtime_config.searxng_endpoint(); + // Snapshot the active model slug once from the picker-backed + // ActiveModelState; drop the guard before any `.await` so we never + // hold a `MutexGuard` across an await point. + let model_name = { + let guard = active_model_state.0.lock().map_err(|e| e.to_string())?; + guard.clone() + }; + // Pre-flight: verify both sandbox services are reachable before touching // the LLM or SearXNG. A 2-second probe prevents a long wait when the // containers are simply not running. @@ -84,7 +94,6 @@ pub async fn search_pipeline( "{}/api/chat", app_config.model.ollama_url.trim_end_matches('/') ); - let active_model = app_config.model.active().to_string(); let cancel_token = CancellationToken::new(); generation.set_token(cancel_token.clone()); @@ -92,7 +101,7 @@ pub async fn search_pipeline( let router = pipeline::DefaultRouterJudge::new( ollama_endpoint.clone(), - active_model.clone(), + model_name.clone(), (*client).clone(), cancel_token.clone(), today.clone(), @@ -100,7 +109,7 @@ pub async fn search_pipeline( ); let judge = pipeline::DefaultJudge::new( ollama_endpoint.clone(), - active_model.clone(), + model_name.clone(), (*client).clone(), cancel_token.clone(), runtime_config.judge_timeout_s, @@ -110,7 +119,7 @@ pub async fn search_pipeline( &ollama_endpoint, &searxng_endpoint, &runtime_config.reader_url, - &active_model, + &model_name, &client, cancel_token.clone(), &app_config.prompt.resolved_system, diff --git a/src/App.css b/src/App.css index 522bd923..e64b8e1c 100644 --- a/src/App.css +++ b/src/App.css @@ -170,6 +170,21 @@ body { scrollbar-gutter: stable; } +/* ─── AskBar Textarea ─── + * Hide the native scrollbar so the textarea's wrapping width matches the + * highlight mirror div behind it. Without this, once content exceeds the + * 144px cap and the textarea begins scrolling, the system scrollbar + * consumes a few pixels of content width, causing wrapped lines in the + * textarea and the mirror to diverge — the caret then floats above the + * visible text. The hook syncs scrollTop so caret-follow still works. + */ +.askbar-textarea { + scrollbar-width: none; +} +.askbar-textarea::-webkit-scrollbar { + display: none; +} + /* ─── Markdown Body: Prose Defaults ─── * Tailwind's preflight resets list-style to none globally. Streamdown adds no * list CSS of its own, so
    /
      inside .markdown-body lose their markers. diff --git a/src/App.tsx b/src/App.tsx index 3932450f..12b412e8 100644 --- a/src/App.tsx +++ b/src/App.tsx @@ -4,6 +4,7 @@ import { useState, useEffect, useCallback, + useMemo, useRef, useLayoutEffect, } from 'react'; @@ -14,11 +15,16 @@ import { LogicalSize } from '@tauri-apps/api/dpi'; import { useOllama } from './hooks/useOllama'; import type { Message } from './hooks/useOllama'; import { useConversationHistory } from './hooks/useConversationHistory'; +import { useModelSelection } from './hooks/useModelSelection'; +import { useModelCapabilities } from './hooks/useModelCapabilities'; +import { getCapabilityConflict } from './utils/capabilityConflicts'; +import { Toast } from './components/Toast'; import { ConversationView } from './view/ConversationView'; import { AskBarView, MAX_IMAGES } from './view/AskBarView'; import { OnboardingView } from './view/onboarding/index'; import type { OnboardingStage } from './view/onboarding/index'; import { HistoryPanel } from './components/HistoryPanel'; +import { ModelPickerPanel } from './components/ModelPickerPanel'; import { ImagePreviewModal } from './components/ImagePreviewModal'; import type { AttachedImage } from './types/image'; import { MAX_IMAGE_SIZE_BYTES } from './types/image'; @@ -30,6 +36,9 @@ import { } from './config/commands'; import './App.css'; +/** Fallback model name used before get_model_picker_state resolves at startup. */ +const DEFAULT_MODEL_FALLBACK = 'gemma4:e2b'; + const OVERLAY_VISIBILITY_EVENT = 'thuki://visibility'; const ONBOARDING_EVENT = 'thuki://onboarding'; @@ -109,6 +118,9 @@ function App() { * but rendered differently based on `isChatMode`). */ const [isHistoryOpen, setIsHistoryOpen] = useState(false); + + /** Whether the model picker panel is currently open. Mutually exclusive with `isHistoryOpen`. */ + const [isModelPickerOpen, setIsModelPickerOpen] = useState(false); /** * True when the user clicked + while an unsaved conversation is active. * Causes the history dropdown to show a SwitchConfirmation prompt instead @@ -123,6 +135,44 @@ function App() { */ const morphingContainerNodeRef = useRef(null); + const { activeModel, availableModels, refreshModels, setActiveModel } = + useModelSelection(); + + const { capabilities: modelCapabilities, refresh: refreshModelCapabilities } = + useModelCapabilities(); + + /** Capability flags for the currently active model, or undefined if not loaded yet. */ + const activeModelCapabilities = activeModel + ? modelCapabilities[activeModel] + : undefined; + + /** + * Toast text shown by the submit-time capability gate. Set to a non-null + * string when the user attempts to send a message whose attached content + * the active model cannot handle (e.g. images on a text-only model). + * Cleared by the toast's auto-dismiss or on next submit attempt. + */ + const [capabilityToast, setCapabilityToast] = useState(null); + + /** + * Pulses true to trigger the ask-bar shake animation when the + * submit-time gate refuses a message, then resets so the next blocked + * submit gets its own animation. Reset is set just over the 500 ms + * keyframe duration in `AskBarView` so the bar never snaps back + * mid-animation if React schedules the state flip on the exact frame + * Framer is finishing. + */ + const [shakeAskBar, setShakeAskBar] = useState(false); + useEffect(() => { + if (!shakeAskBar) return; + const timer = setTimeout(() => setShakeAskBar(false), 600); + return () => clearTimeout(timer); + }, [shakeAskBar]); + + const dismissCapabilityToast = useCallback(() => { + setCapabilityToast(null); + }, []); + const { conversationId, isSaved, @@ -158,7 +208,7 @@ function App() { searchStage, reset, loadMessages, - } = useOllama(handleTurnComplete); + } = useOllama(activeModel, handleTurnComplete); /** * Sticky flag: once the user invokes `/search`, subsequent submits in the @@ -362,6 +412,16 @@ function App() { } }, [isGenerating]); + /* eslint-disable @eslint-react/set-state-in-effect -- intentional: close + the picker when the user triggers generation so it can't stay open over + a streaming response. No secondary effects are triggered by this reset. */ + useEffect(() => { + if (isGenerating || isSubmitPending) { + setIsModelPickerOpen(false); + } + }, [isGenerating, isSubmitPending]); + /* eslint-enable @eslint-react/set-state-in-effect */ + /** * Replays the entrance sequence by transitioning the overlay to the visible state. * Clears conversation state for a fresh session each time the overlay appears. @@ -390,6 +450,7 @@ function App() { setQuery(''); setSelectedContext(context); setIsHistoryOpen(false); + setIsModelPickerOpen(false); setAttachedImages((prev) => { for (const img of prev) URL.revokeObjectURL(img.blobUrl); return []; @@ -402,11 +463,12 @@ function App() { setCaptureError(null); setSearchActive(false); + void refreshModels(); reset(); resetHistory(); setOverlayState('visible'); }, - [reset, resetHistory], + [reset, resetHistory, refreshModels], ); /** @@ -437,9 +499,39 @@ function App() { /** Ref attached to the chat-mode history dropdown for click-outside detection. */ const historyDropdownRef = useRef(null); - /** Toggles the history panel open/closed. */ + /** Ref attached to the chat-mode model picker dropdown for click-outside detection. */ + const modelPickerDropdownRef = useRef(null); + /** Ref attached to the ask-bar mode model picker drawer for click-outside detection. */ + const modelPickerAskBarRef = useRef(null); + + /** + * Close the model picker when the user clicks outside it, in either mode. + * Clicks on any pill trigger (data-model-picker-toggle) are excluded so the + * trigger's own onClick can manage the toggle without a double-close race. + */ + useEffect(() => { + if (!isModelPickerOpen) return; + + const handleMouseDown = (e: MouseEvent) => { + const target = e.target as Element; + if ( + modelPickerDropdownRef.current?.contains(target) || + modelPickerAskBarRef.current?.contains(target) || + target.closest?.('[data-model-picker-toggle]') + ) { + return; + } + setIsModelPickerOpen(false); + }; + + document.addEventListener('mousedown', handleMouseDown); + return () => document.removeEventListener('mousedown', handleMouseDown); + }, [isModelPickerOpen]); + + /** Toggles the history panel open/closed. Closes model picker (mutually exclusive). */ const handleHistoryToggle = useCallback(() => { setIsHistoryOpen((prev) => !prev); + setIsModelPickerOpen(false); }, []); /** @@ -575,12 +667,14 @@ function App() { if (isSaved) { await unsave(); } else { - await save(messages); + // activeModel is empty string until the model picker hook resolves on first + // load; fall back to the bootstrap default during that brief window. + await save(messages, activeModel || DEFAULT_MODEL_FALLBACK); } } catch { // State stays unchanged on failure; feedback is implicit in the icon. } - }, [isSaved, unsave, save, messages]); + }, [isSaved, unsave, save, messages, activeModel]); /** * Loads a conversation from history, replacing the current session. @@ -616,7 +710,7 @@ function App() { const handleSaveAndLoad = useCallback( async (id: string) => { try { - await save(messages); + await save(messages, activeModel || DEFAULT_MODEL_FALLBACK); } catch { // Save failed - abort to avoid leaving the current session unprotected. return; @@ -631,7 +725,7 @@ function App() { setIsHistoryOpen(false); } }, - [save, messages, loadConversation, loadMessages], + [save, messages, loadConversation, loadMessages, activeModel], ); /** @@ -690,12 +784,12 @@ function App() { /** Saves the current conversation then starts a fresh one. */ const handleSaveAndNew = useCallback(async () => { try { - await save(messages); + await save(messages, activeModel || DEFAULT_MODEL_FALLBACK); } catch { return; } resetForNewConversation(); - }, [save, messages, resetForNewConversation]); + }, [save, messages, resetForNewConversation, activeModel]); /** Discards the current conversation and starts a fresh one. */ const handleJustNew = useCallback(() => { @@ -984,6 +1078,23 @@ function App() { ], ); + /** + * Live capability conflict for the current compose state. Drives the + * inline `CapabilityMismatchStrip` so the user sees the mismatch as + * soon as incompatible content lands in compose, not only at submit + * time. The strip is purely informational: recovery happens through + * the model picker chip. + */ + const liveCapabilityConflictMessage = useMemo(() => { + const trimmed = query.trim(); + const { found } = parseCommands(trimmed); + return getCapabilityConflict(activeModel, activeModelCapabilities, { + hasScreenCommand: found.has('/screen'), + hasThinkCommand: found.has('/think'), + imageCount: attachedImages.length, + }); + }, [query, attachedImages, activeModel, activeModelCapabilities]); + const handleSubmit = useCallback(() => { if ( (query.trim().length === 0 && attachedImages.length === 0) || @@ -1001,6 +1112,18 @@ function App() { const hasThink = found.has('/think'); const hasSearch = found.has('/search'); + // Submit-time capability gate. Refuses messages whose attached content + // the active model cannot handle (images on a text-only model). The + // gate is the only gate: input affordances stay live so the user can + // compose freely and recover via the model picker chip. When refused + // the ask bar shakes and a toast surfaces the reason. Compose state is + // preserved so the user does not lose their typing. + if (liveCapabilityConflictMessage !== null) { + setShakeAskBar(true); + setCapabilityToast(liveCapabilityConflictMessage); + return; + } + // `/search` entry point AND sticky follow-ups. Once a search turn is in // flight, subsequent submits without an explicit slash command continue // to route through the backend search pipeline so the LLM can clarify, @@ -1179,6 +1302,7 @@ function App() { askSearch, searchActive, quote.maxContextLength, + liveCapabilityConflictMessage, ]); // When a pending submit exists and all images finish processing, fire it. @@ -1275,6 +1399,49 @@ function App() { requestAnimationFrame(() => inputRef.current?.focus()); }, [isSubmitPending, cancel, setSearchActive, setSelectedContext]); + /** + * Persists the user's model choice via the backend and closes the picker panel. + * On rejection (e.g. the chosen model was uninstalled between render and click), + * triggers a refresh so the picker list and the active chip resync with the + * actual backend state instead of silently drifting. + */ + const handleModelSelect = useCallback( + (model: string) => { + setIsModelPickerOpen(false); + void setActiveModel(model).catch(() => { + void refreshModels(); + }); + }, + [setActiveModel, refreshModels], + ); + + /** Closes the model picker panel. Wired to Escape key inside the panel. */ + const handleModelPickerClose = useCallback(() => { + setIsModelPickerOpen(false); + }, []); + + /** + * Toggles the model picker panel. Closes history panel (mutually exclusive). + * + * On open we re-pull both the installed-model list and the per-model + * capability map so newly-pulled models (e.g. user ran `ollama pull + * deepseek-r1:1.5b` while Thuki was running) appear with their full + * capability label without needing an app restart. Backend + * `reconcile_capabilities` honors its cache for already-known slugs and + * only fetches `/api/show` for genuinely new entries, so this is cheap. + */ + const handleModelPickerToggle = useCallback(() => { + setIsModelPickerOpen((prev) => { + const opening = !prev; + if (opening) { + void refreshModels(); + void refreshModelCapabilities(); + } + return opening; + }); + setIsHistoryOpen(false); + }, [refreshModels, refreshModelCapabilities]); + /** * Synchronizes the React animation state with Tauri-driven overlay visibility * requests emitted from the Rust backend. @@ -1468,7 +1635,7 @@ function App() { : 'rounded-2xl shadow-bar' }`} > - {/* Chat Messages Area - morphs in when in chat mode */} + {/* Chat Messages Area - morphs in when in chat mode. */} {isChatMode ? ( ) : null} + {/* Ask-bar mode model picker drawer - above the input bar. + In chat mode the trigger and drawer move to the header area above. */} + {!isChatMode && ( + + {isModelPickerOpen && + activeModel && + availableModels && + availableModels.length > 0 ? ( + + + + ) : null} + + )} + {/* Ask-bar mode history panel - inline below the input bar. The !isChatMode gate lives OUTSIDE AnimatePresence so that when a conversation is loaded (isChatMode → true) the panel unmounts @@ -1559,9 +1765,50 @@ function App() { onImagePreview={handleAskBarImagePreview} onScreenshot={handleScreenshot} isDragOver={isDragOver ?? undefined} + activeModel={activeModel} + availableModels={availableModels} + onModelPickerToggle={handleModelPickerToggle} + isModelPickerOpen={isModelPickerOpen} + capabilityConflictMessage={liveCapabilityConflictMessage} + shake={shakeAskBar} + /> + + {/* Chat-mode model picker dropdown - floating card identical in style + to the chat-history dropdown. Anchored absolute right-3 top-10 + so it appears just below the header pill trigger without pushing + the conversation content. Click-outside closes it. */} + + {isChatMode && + isModelPickerOpen && + activeModel && + availableModels && + availableModels.length > 0 ? ( + + + + ) : null} + + {/* Chat-mode history dropdown - sibling of the morphing container so it is never clipped by its overflow-hidden. Positioned absolutely within this relative wrapper (same coordinate space as the diff --git a/src/__tests__/App.test.tsx b/src/__tests__/App.test.tsx index 0a2f5c1a..a0c816df 100644 --- a/src/__tests__/App.test.tsx +++ b/src/__tests__/App.test.tsx @@ -28,6 +28,441 @@ describe('App', () => { enableChannelCapture(); }); + it('fetches model picker state on mount and refreshes it when the overlay shows', async () => { + invoke.mockReset(); + enableChannelCaptureWithResponses({ + get_model_picker_state: { + active: 'gemma4:e2b', + all: ['gemma4:e2b', 'qwen2.5:7b'], + }, + }); + + render(); + await act(async () => {}); + + expect(invoke).toHaveBeenCalledWith('get_model_picker_state'); + + invoke.mockClear(); + + await showOverlay(); + + expect(invoke).toHaveBeenCalledWith('get_model_picker_state'); + }); + + it('renders the model picker when the overlay is visible and models load', async () => { + enableChannelCaptureWithResponses({ + get_model_picker_state: { + active: 'gemma4:e2b', + all: ['gemma4:e2b', 'qwen2.5:7b'], + }, + }); + + render(); + await act(async () => {}); + await showOverlay(); + + expect( + screen.getByRole('button', { name: 'Choose model' }), + ).toBeInTheDocument(); + }); + + it('saves the conversation with the currently selected model', async () => { + enableChannelCaptureWithResponses({ + get_model_picker_state: { + active: 'gemma4:e2b', + all: ['gemma4:e2b', 'qwen2.5:7b'], + }, + save_conversation: { conversation_id: 'conv-1' }, + generate_title: undefined, + set_active_model: undefined, + }); + + render(); + await act(async () => {}); + await showOverlay(); + + fireEvent.click(screen.getByRole('button', { name: 'Choose model' })); + await act(async () => { + fireEvent.click(screen.getByRole('option', { name: 'qwen2.5:7b' })); + }); + + const textarea = screen.getByPlaceholderText('Ask Thuki anything...'); + fireEvent.change(textarea, { target: { value: 'hello there' } }); + fireEvent.keyDown(textarea, { key: 'Enter', shiftKey: false }); + await act(async () => {}); + + act(() => { + getLastChannel()?.simulateMessage({ type: 'Token', data: 'Hi there!' }); + getLastChannel()?.simulateMessage({ type: 'Done' }); + }); + + await act(async () => { + fireEvent.click(screen.getByLabelText('Save conversation')); + }); + + // The picker selection is threaded into `generate_title` (which uses the + // active slug as the title-generation model) and stamped onto the + // assistant message via `model_name`. `save_conversation` itself does + // not take a top-level `model` arg; the active model is sourced + // backend-side from the loaded TOML AppConfig. + expect(invoke).toHaveBeenCalledWith( + 'generate_title', + expect.objectContaining({ model: 'qwen2.5:7b' }), + ); + }); + + it('opens model picker panel when trigger is clicked', async () => { + enableChannelCaptureWithResponses({ + get_model_picker_state: { + active: 'gemma4:e2b', + all: ['gemma4:e2b', 'qwen2.5:7b'], + }, + }); + render(); + await act(async () => {}); + await showOverlay(); + + fireEvent.click(screen.getByRole('button', { name: 'Choose model' })); + await act(async () => {}); + + expect( + screen.getByRole('option', { name: 'gemma4:e2b' }), + ).toBeInTheDocument(); + expect( + screen.getByRole('option', { name: 'qwen2.5:7b' }), + ).toBeInTheDocument(); + }); + + it('closes model picker and opens history when history toggle clicked', async () => { + enableChannelCaptureWithResponses({ + get_model_picker_state: { + active: 'gemma4:e2b', + all: ['gemma4:e2b', 'qwen2.5:7b'], + }, + list_conversations: [], + }); + render(); + await act(async () => {}); + await showOverlay(); + + fireEvent.click(screen.getByRole('button', { name: 'Choose model' })); + await act(async () => {}); + expect( + screen.getByRole('option', { name: 'gemma4:e2b' }), + ).toBeInTheDocument(); + + fireEvent.click(screen.getByRole('button', { name: 'Open history' })); + await act(async () => {}); + expect(screen.queryByRole('option', { name: 'gemma4:e2b' })).toBeNull(); + expect( + screen.getByPlaceholderText(/search past chats/i), + ).toBeInTheDocument(); + }); + + it('closes history and opens model picker when model picker trigger clicked', async () => { + enableChannelCaptureWithResponses({ + get_model_picker_state: { + active: 'gemma4:e2b', + all: ['gemma4:e2b', 'qwen2.5:7b'], + }, + list_conversations: [], + }); + render(); + await act(async () => {}); + await showOverlay(); + + fireEvent.click(screen.getByRole('button', { name: 'Open history' })); + await act(async () => {}); + expect( + screen.getByPlaceholderText(/search past chats/i), + ).toBeInTheDocument(); + + fireEvent.click(screen.getByRole('button', { name: 'Choose model' })); + await act(async () => {}); + expect(screen.queryByPlaceholderText(/search past chats/i)).toBeNull(); + expect( + screen.getByRole('option', { name: 'gemma4:e2b' }), + ).toBeInTheDocument(); + }); + + it('closes model picker when a model is selected', async () => { + enableChannelCaptureWithResponses({ + get_model_picker_state: { + active: 'gemma4:e2b', + all: ['gemma4:e2b', 'qwen2.5:7b'], + }, + set_active_model: undefined, + }); + render(); + await act(async () => {}); + await showOverlay(); + + fireEvent.click(screen.getByRole('button', { name: 'Choose model' })); + await act(async () => {}); + expect( + screen.getByRole('option', { name: 'qwen2.5:7b' }), + ).toBeInTheDocument(); + + fireEvent.click(screen.getByRole('option', { name: 'qwen2.5:7b' })); + await act(async () => {}); + expect(screen.queryByRole('option', { name: 'qwen2.5:7b' })).toBeNull(); + }); + + it('closes model picker when the trigger is clicked while open', async () => { + enableChannelCaptureWithResponses({ + get_model_picker_state: { + active: 'gemma4:e2b', + all: ['gemma4:e2b', 'qwen2.5:7b'], + }, + }); + render(); + await act(async () => {}); + await showOverlay(); + + const trigger = screen.getByRole('button', { name: 'Choose model' }); + fireEvent.click(trigger); + await act(async () => {}); + expect( + screen.getByRole('option', { name: 'gemma4:e2b' }), + ).toBeInTheDocument(); + + // Second click on the trigger toggles the panel closed; this exercises + // the "opening = false" branch of handleModelPickerToggle. + fireEvent.click(trigger); + await act(async () => {}); + expect(screen.queryByRole('option', { name: 'gemma4:e2b' })).toBeNull(); + }); + + it('closes model picker when generation starts', async () => { + enableChannelCaptureWithResponses({ + get_model_picker_state: { + active: 'gemma4:e2b', + all: ['gemma4:e2b', 'qwen2.5:7b'], + }, + }); + render(); + await act(async () => {}); + await showOverlay(); + + fireEvent.click(screen.getByRole('button', { name: 'Choose model' })); + await act(async () => {}); + expect( + screen.getByRole('option', { name: 'gemma4:e2b' }), + ).toBeInTheDocument(); + + const textarea = screen.getByPlaceholderText('Ask Thuki anything...'); + fireEvent.change(textarea, { target: { value: 'hi' } }); + fireEvent.keyDown(textarea, { key: 'Enter', shiftKey: false }); + await act(async () => {}); + + expect(screen.queryByRole('option', { name: 'gemma4:e2b' })).toBeNull(); + }); + + it('shows active model pill in chat mode header and opens picker from there', async () => { + enableChannelCaptureWithResponses({ + get_model_picker_state: { + active: 'gemma4:e2b', + all: ['gemma4:e2b', 'qwen2.5:7b'], + }, + }); + render(); + await act(async () => {}); + await showOverlay(); + + // Transition to chat mode by submitting a message + const textarea = screen.getByPlaceholderText('Ask Thuki anything...'); + fireEvent.change(textarea, { target: { value: 'hi' } }); + fireEvent.keyDown(textarea, { key: 'Enter', shiftKey: false }); + await act(async () => {}); + + act(() => { + getLastChannel()?.simulateMessage({ type: 'Token', data: 'Hello!' }); + getLastChannel()?.simulateMessage({ type: 'Done' }); + }); + await act(async () => {}); + + // Pill button should now be in the header (WindowControls), showing the model name + const pill = screen.getByRole('button', { name: 'Choose model' }); + expect(pill).toBeInTheDocument(); + expect(pill.textContent).toContain('gemma4:e2b'); + + // Click pill → model picker panel opens ABOVE the conversation + fireEvent.click(pill); + await act(async () => {}); + expect( + screen.getByRole('option', { name: 'gemma4:e2b' }), + ).toBeInTheDocument(); + expect( + screen.getByRole('option', { name: 'qwen2.5:7b' }), + ).toBeInTheDocument(); + }); + + it('closes chat-mode model picker when clicking outside the dropdown', async () => { + enableChannelCaptureWithResponses({ + get_model_picker_state: { + active: 'gemma4:e2b', + all: ['gemma4:e2b', 'qwen2.5:7b'], + }, + }); + render(); + await act(async () => {}); + await showOverlay(); + + const textarea = screen.getByPlaceholderText('Ask Thuki anything...'); + fireEvent.change(textarea, { target: { value: 'hi' } }); + fireEvent.keyDown(textarea, { key: 'Enter', shiftKey: false }); + await act(async () => {}); + act(() => { + getLastChannel()?.simulateMessage({ type: 'Token', data: 'Hello!' }); + getLastChannel()?.simulateMessage({ type: 'Done' }); + }); + await act(async () => {}); + + fireEvent.click(screen.getByRole('button', { name: 'Choose model' })); + await act(async () => {}); + expect( + screen.getByRole('option', { name: 'gemma4:e2b' }), + ).toBeInTheDocument(); + + fireEvent.mouseDown(document.body); + await act(async () => {}); + expect(screen.queryByRole('option', { name: 'gemma4:e2b' })).toBeNull(); + }); + + it('chat-mode click-outside does NOT close when clicking inside the dropdown or on the pill', async () => { + enableChannelCaptureWithResponses({ + get_model_picker_state: { + active: 'gemma4:e2b', + all: ['gemma4:e2b', 'qwen2.5:7b'], + }, + }); + render(); + await act(async () => {}); + await showOverlay(); + + const textarea = screen.getByPlaceholderText('Ask Thuki anything...'); + fireEvent.change(textarea, { target: { value: 'hi' } }); + fireEvent.keyDown(textarea, { key: 'Enter', shiftKey: false }); + await act(async () => {}); + act(() => { + getLastChannel()?.simulateMessage({ type: 'Token', data: 'Hello!' }); + getLastChannel()?.simulateMessage({ type: 'Done' }); + }); + await act(async () => {}); + + const pill = screen.getByRole('button', { name: 'Choose model' }); + fireEvent.click(pill); + await act(async () => {}); + const option = screen.getByRole('option', { name: 'gemma4:e2b' }); + expect(option).toBeInTheDocument(); + + // mousedown inside the dropdown must not close the picker + fireEvent.mouseDown(option); + await act(async () => {}); + expect( + screen.getByRole('option', { name: 'gemma4:e2b' }), + ).toBeInTheDocument(); + + // mousedown on the pill trigger must not close the picker either + fireEvent.mouseDown(pill); + await act(async () => {}); + expect( + screen.getByRole('option', { name: 'gemma4:e2b' }), + ).toBeInTheDocument(); + }); + + it('ask-bar mode click-outside closes the model picker drawer', async () => { + enableChannelCaptureWithResponses({ + get_model_picker_state: { + active: 'gemma4:e2b', + all: ['gemma4:e2b', 'qwen2.5:7b'], + }, + }); + render(); + await act(async () => {}); + await showOverlay(); + + fireEvent.click(screen.getByRole('button', { name: 'Choose model' })); + await act(async () => {}); + expect( + screen.getByRole('option', { name: 'gemma4:e2b' }), + ).toBeInTheDocument(); + + // Clicking inside the drawer must NOT close it + fireEvent.mouseDown(screen.getByRole('option', { name: 'gemma4:e2b' })); + await act(async () => {}); + expect( + screen.getByRole('option', { name: 'gemma4:e2b' }), + ).toBeInTheDocument(); + + // Clicking outside closes the drawer + fireEvent.mouseDown(document.body); + await act(async () => {}); + expect(screen.queryByRole('option', { name: 'gemma4:e2b' })).toBeNull(); + }); + + it('refreshes model list when set_active_model rejects', async () => { + let rejectionSeen = false; + let refreshesAfterRejection = 0; + invoke.mockImplementation(async (cmd: string) => { + if (cmd === 'get_model_picker_state') { + if (rejectionSeen) { + refreshesAfterRejection += 1; + return { active: 'gemma4:e2b', all: ['gemma4:e2b'] }; + } + return { active: 'gemma4:e2b', all: ['gemma4:e2b', 'qwen2.5:7b'] }; + } + if (cmd === 'set_active_model') { + rejectionSeen = true; + throw new Error('Model is not installed in Ollama: qwen2.5:7b'); + } + return undefined; + }); + render(); + await act(async () => {}); + await showOverlay(); + + fireEvent.click(screen.getByRole('button', { name: 'Choose model' })); + await act(async () => {}); + fireEvent.click(screen.getByRole('option', { name: 'qwen2.5:7b' })); + await act(async () => {}); + + // The rejection handler must have triggered at least one refresh fetch. + expect(refreshesAfterRejection).toBeGreaterThanOrEqual(1); + + // Reopen to confirm the list is the post-refresh one (qwen was removed). + fireEvent.click(screen.getByRole('button', { name: 'Choose model' })); + await act(async () => {}); + expect( + screen.getByRole('option', { name: 'gemma4:e2b' }), + ).toBeInTheDocument(); + expect(screen.queryByRole('option', { name: 'qwen2.5:7b' })).toBeNull(); + }); + + it('closes the model picker drawer when Escape is pressed in the filter input', async () => { + enableChannelCaptureWithResponses({ + get_model_picker_state: { + active: 'gemma4:e2b', + all: ['gemma4:e2b', 'qwen2.5:7b'], + }, + }); + render(); + await act(async () => {}); + await showOverlay(); + + fireEvent.click(screen.getByRole('button', { name: 'Choose model' })); + await act(async () => {}); + expect( + screen.getByRole('option', { name: 'gemma4:e2b' }), + ).toBeInTheDocument(); + + fireEvent.keyDown(screen.getByPlaceholderText(/filter models/i), { + key: 'Escape', + }); + await act(async () => {}); + expect(screen.queryByRole('option', { name: 'gemma4:e2b' })).toBeNull(); + }); + it('grows upward when near bottom screen edge', async () => { const { container } = render(); await act(async () => {}); @@ -658,6 +1093,10 @@ describe('App', () => { it('closes history panel when a conversation is loaded', async () => { enableChannelCaptureWithResponses({ + get_model_picker_state: { + active: 'gemma4:e2b', + all: ['gemma4:e2b'], + }, list_conversations: [], }); @@ -2790,6 +3229,187 @@ describe('App', () => { }); }); + // ─── Capability gate (vision mismatch) ───────────────────────────────────── + + describe('capability gate', () => { + /** Helper: paste an image file into the textarea and wait for thumbnails. */ + async function pasteImage() { + const textarea = screen.getByPlaceholderText('Ask Thuki anything...'); + const file = new File(['fake-img-data'], 'photo.png', { + type: 'image/png', + }); + const clipboardData = { + items: [{ type: 'image/png', getAsFile: () => file }], + }; + await act(async () => { + fireEvent.paste(textarea, { clipboardData }); + }); + await vi.waitFor(() => { + expect( + screen.getByRole('list', { name: /attached images/i }), + ).toBeInTheDocument(); + }); + } + + it('shows the live mismatch strip when a text-only model has an image attached', async () => { + enableChannelCaptureWithResponses({ + get_model_picker_state: { + active: 'llama3', + all: ['llama3', 'llama3.2-vision'], + }, + get_model_capabilities: { + llama3: { + vision: false, + thinking: false, + }, + 'llama3.2-vision': { + vision: true, + thinking: false, + }, + }, + save_image_command: '/tmp/staged/img1.jpg', + }); + render(); + await act(async () => {}); + await showOverlay(); + await pasteImage(); + await vi.waitFor(() => { + expect( + screen.getByTestId('capability-mismatch-strip'), + ).toBeInTheDocument(); + }); + expect(screen.getByTestId('capability-mismatch-strip')).toHaveTextContent( + 'llama3 reads text only', + ); + }); + + it('refuses submit and surfaces a toast when a text-only model has an image attached', async () => { + enableChannelCaptureWithResponses({ + get_model_picker_state: { + active: 'llama3', + all: ['llama3'], + }, + get_model_capabilities: { + llama3: { + vision: false, + thinking: false, + }, + }, + save_image_command: '/tmp/staged/img1.jpg', + }); + render(); + await act(async () => {}); + await showOverlay(); + await pasteImage(); + await act(async () => { + await vi.waitFor(() => { + expect(invoke).toHaveBeenCalledWith( + 'save_image_command', + expect.anything(), + ); + }); + }); + + // Type and submit. + const textarea = screen.getByPlaceholderText('Ask Thuki anything...'); + act(() => { + fireEvent.change(textarea, { target: { value: 'summarise these' } }); + }); + invoke.mockClear(); + await act(async () => { + fireEvent.click(screen.getByRole('button', { name: /send message/i })); + }); + + // Toast surfaces the reason. + await vi.waitFor(() => { + expect(screen.getByTestId('toast')).toHaveTextContent( + 'llama3 reads text only', + ); + }); + // ask_ollama is NOT invoked. + const askInvocations = invoke.mock.calls.filter( + (call) => call[0] === 'ask_ollama', + ); + expect(askInvocations.length).toBe(0); + // Compose state survives. + expect(screen.getByPlaceholderText('Ask Thuki anything...')).toHaveValue( + 'summarise these', + ); + }); + + it('toast auto-dismisses after the default duration', async () => { + vi.useFakeTimers(); + try { + enableChannelCaptureWithResponses({ + get_model_picker_state: { active: 'llama3', all: ['llama3'] }, + get_model_capabilities: { + llama3: { + vision: false, + thinking: false, + }, + }, + save_image_command: '/tmp/staged/img1.jpg', + }); + render(); + await act(async () => {}); + await showOverlay(); + // Paste (real timers were running; pasteImage uses waitFor which + // works under fake timers if we advance them). + const textarea = screen.getByPlaceholderText('Ask Thuki anything...'); + const file = new File(['x'], 'p.png', { type: 'image/png' }); + await act(async () => { + fireEvent.paste(textarea, { + clipboardData: { + items: [{ type: 'image/png', getAsFile: () => file }], + }, + }); + }); + await act(async () => { + await vi.advanceTimersByTimeAsync(50); + }); + // Submit with no text but an image (canSubmit honors images alone). + await act(async () => { + fireEvent.click( + screen.getByRole('button', { name: /send message/i }), + ); + }); + await act(async () => { + await vi.advanceTimersByTimeAsync(10); + }); + expect(screen.queryByTestId('toast')).not.toBeNull(); + // Advance past the 3000ms default duration. + await act(async () => { + await vi.advanceTimersByTimeAsync(3100); + }); + expect(screen.queryByTestId('toast')).toBeNull(); + } finally { + vi.useRealTimers(); + } + }); + + it('does not gate submit when the active model has vision', async () => { + enableChannelCaptureWithResponses({ + get_model_picker_state: { + active: 'llama3.2-vision', + all: ['llama3.2-vision'], + }, + get_model_capabilities: { + 'llama3.2-vision': { + vision: true, + thinking: false, + }, + }, + save_image_command: '/tmp/staged/img1.jpg', + }); + render(); + await act(async () => {}); + await showOverlay(); + await pasteImage(); + // Strip must not appear. + expect(screen.queryByTestId('capability-mismatch-strip')).toBeNull(); + }); + }); + // ─── Screenshot integration ──────────────────────────────────────────────── describe('screenshot integration', () => { @@ -4479,13 +5099,13 @@ describe('App', () => { emitTauriEvent('thuki://onboarding', { stage: 'intro' }); }); - expect(screen.getByText('Before you dive in')).toBeInTheDocument(); + expect(screen.getByText("You're all set")).toBeInTheDocument(); await act(async () => { fireEvent.click(screen.getByRole('button', { name: /get started/i })); }); - expect(screen.queryByText('Before you dive in')).toBeNull(); + expect(screen.queryByText("You're all set")).toBeNull(); }); }); }); diff --git a/src/components/CapabilityMismatchStrip.tsx b/src/components/CapabilityMismatchStrip.tsx new file mode 100644 index 00000000..bb329813 --- /dev/null +++ b/src/components/CapabilityMismatchStrip.tsx @@ -0,0 +1,48 @@ +/** Props for the {@link CapabilityMismatchStrip} component. */ +export interface CapabilityMismatchStripProps { + /** + * Human-readable reason rendered as the strip body. The strip renders + * only when this is a non-empty string. + */ + message: string; +} + +/** + * Inline informational strip that surfaces a capability mismatch between + * the user's compose state (image attached, `/screen` queued) and the + * active model. Passive: the strip carries no action button. Recovery + * happens through the existing model picker chip in WindowControls so + * the picker remains the single source of truth for switching models. + * + * The host is responsible for rendering the strip only when there is a + * real conflict (use `getCapabilityConflict` to compute the message). + * The strip itself does not animate; the host can wrap it in + * AnimatePresence if a fade-in / fade-out is desired. + */ +export function CapabilityMismatchStrip({ + message, +}: CapabilityMismatchStripProps) { + return ( +
      +
      + ); +} diff --git a/src/components/ChatBubble.tsx b/src/components/ChatBubble.tsx index b7c7b3b3..b6e6aa73 100644 --- a/src/components/ChatBubble.tsx +++ b/src/components/ChatBubble.tsx @@ -76,6 +76,38 @@ function avatarColor(domain: string): string { /** Regex matching inline `[N]` citation markers in plain text. Captures the N. */ const CITATION_RE = /\[(\d+)\]/g; +/** + * Hoisted static SVG glyph for the model attribution chip. Mirrors the + * chip icon used by the model picker so the attribution visually couples + * to the picker UI. Rendered as a child of a color-controlled span. + */ +const ATTRIB_CHIP_ICON = ( + +); + /** * Walks the rendered answer DOM and replaces every plain-text `[N]` occurrence * with an anchor element that links to the matching source URL. Called inside @@ -228,6 +260,8 @@ interface ChatBubbleProps { /** Whether the search pipeline is currently running. When true, renders a * `SearchTraceBlock` in loading state even before any traces arrive. */ isSearching?: boolean; + /** When set on an assistant message, renders a chip-style attribution badge beside the CopyButton so the user sees which model produced this response. */ + modelName?: string; } /** @@ -277,6 +311,7 @@ export function ChatBubble({ sandboxUnavailable = false, searchTraces, isSearching = false, + modelName, }: ChatBubbleProps) { const isUser = role === 'user'; const [sourcesOpen, setSourcesOpen] = useState(false); @@ -534,6 +569,19 @@ export function ChatBubble({ )} + {/* Model attribution chip: visually couples the response to the + model-picker UI so users can see which model produced it. */} + {modelName && ( + + + {ATTRIB_CHIP_ICON} + + {modelName} + + )} )} diff --git a/src/components/ModelPicker.tsx b/src/components/ModelPicker.tsx new file mode 100644 index 00000000..2db24fa8 --- /dev/null +++ b/src/components/ModelPicker.tsx @@ -0,0 +1,60 @@ +const CHIP_ICON = ( + +); + +/** Props for the {@link ModelPicker} trigger button. */ +export interface ModelPickerProps { + /** Called when the user clicks the trigger to toggle the picker panel. */ + onClick: () => void; + /** When true, the button is inert (e.g. during generation). */ + disabled: boolean; + /** Reflects whether the picker panel is currently open (drives aria-expanded). */ + isOpen: boolean; +} + +/** + * Chip-style trigger button that opens/closes the model picker panel. + * + * The panel itself is rendered by App.tsx as an inline drawer (same + * grow/shrink animation as the history panel) so the ResizeObserver drives + * natural window growth without any portal or frame-manipulation logic. + */ +export function ModelPicker({ onClick, disabled, isOpen }: ModelPickerProps) { + return ( + + ); +} diff --git a/src/components/ModelPickerPanel.tsx b/src/components/ModelPickerPanel.tsx new file mode 100644 index 00000000..a0ba50f2 --- /dev/null +++ b/src/components/ModelPickerPanel.tsx @@ -0,0 +1,285 @@ +import { useEffect, useMemo, useRef, useState } from 'react'; +import { invoke } from '@tauri-apps/api/core'; +import type { ModelCapabilitiesMap } from '../types/model'; +import { Tooltip } from './Tooltip'; + +/** + * Public Ollama library URL opened by the "Browse Ollama" pill. Lives + * here as a module constant so tests can match it without importing the + * pill's render path. + */ +export const OLLAMA_LIBRARY_URL = 'https://ollama.com/library'; + +/** + * Tooltip body shown when the user hovers the pill. Multi-line so the + * Tooltip component renders it as the wider variant and the user gets + * the full sentence in one read. + */ +export const OLLAMA_PILL_TOOLTIP = + 'Browse and pull any model on Ollama. Thuki auto-detects it.'; + +const CHECK_ICON_PATH = ( + +); + +const LISTBOX_ID = 'thuki-model-picker-listbox'; + +/** + * Builds the capability caption rendered beneath each picker row's model + * name. Always leads with "text" (every chat-completion model handles + * text), then appends "vision" and/or "thinking" when the model supports + * them. Returns `null` only when capabilities for the model are unknown + * (not yet loaded), which lets the caller suppress the caption line + * entirely during cold start. + * + * Exported for direct unit testing. + */ +export function formatCapabilityLabel( + capabilities: ModelCapabilitiesMap | undefined, + model: string, +): string | null { + const caps = capabilities?.[model]; + if (!caps) return null; + const flags: string[] = ['text']; + if (caps.vision) flags.push('vision'); + if (caps.thinking) flags.push('thinking'); + return flags.join(' · '); +} + +/** Props for the {@link ModelPickerPanel} content panel. */ +export interface ModelPickerPanelProps { + /** Full list of available model slugs. */ + models: string[]; + /** Currently active model slug. */ + activeModel: string; + /** Called with the chosen slug when the user clicks or keyboard-selects a row. */ + onSelect: (model: string) => void; + /** + * Called when the user presses Escape inside the panel. The host is + * responsible for closing the drawer/dropdown in response. + */ + onClose?: () => void; + /** + * Per-model capability map keyed by slug. When provided, each row + * renders a small capability suffix ("· vision · thinking"). Omit or + * pass an empty map to render plain rows (legacy / loading states). + */ + capabilities?: ModelCapabilitiesMap; + /** + * When true, the picker is rendered inside the chat-mode chip drawer + * (narrower, ~224px wide) and the Browse Ollama pill drops the + * "Ollama" word so the row stays uncluttered. The tooltip still + * spells out the full meaning on hover. Defaults to false (overlay + * mode, full-width "Browse Ollama" label). + */ + compact?: boolean; +} + +/** + * Inline model picker panel rendered as a drawer above the ask bar or as a + * floating dropdown in chat mode. + * + * Combobox-style keyboard model: focus stays in the filter input, ArrowDown/ + * ArrowUp move the `aria-activedescendant` marker through the visible rows, + * Enter commits the highlighted row, and Escape asks the host to close. + */ +export function ModelPickerPanel({ + models, + activeModel, + onSelect, + onClose, + capabilities, + compact = false, +}: ModelPickerPanelProps) { + const [filter, setFilter] = useState(''); + const [highlightedIndex, setHighlightedIndex] = useState(0); + const listboxRef = useRef(null); + + const filtered = useMemo(() => { + const trimmed = filter.trim(); + if (trimmed === '') return models; + const needle = trimmed.toLowerCase(); + return models.filter((m) => m.toLowerCase().includes(needle)); + }, [filter, models]); + + // Inline clamp: derive the safe render index without a useEffect so + // aria-activedescendant is consistent on the same render that filtered shrinks. + const safeHighlightedIndex = + filtered.length === 0 ? 0 : Math.min(highlightedIndex, filtered.length - 1); + + const activeId = + filtered.length > 0 + ? `${LISTBOX_ID}-option-${safeHighlightedIndex}` + : undefined; + + // Keep the highlighted row visible when it scrolls off-view. scrollIntoView + // is absent in happy-dom/jsdom; the optional call becomes a no-op there. + useEffect(() => { + if (!activeId) return; + const el = listboxRef.current?.querySelector(`#${activeId}`); + /* v8 ignore next -- scrollIntoView is a host API not available in jsdom */ + el?.scrollIntoView?.({ block: 'nearest' }); + }, [activeId]); + + const commit = (index: number) => { + if (index < 0 || index >= filtered.length) return; + onSelect(filtered[index]); + }; + + return ( +
      +
      + setFilter(e.target.value)} + onKeyDown={(e) => { + if (e.key === 'ArrowDown') { + e.preventDefault(); + if (filtered.length === 0) return; + setHighlightedIndex((i) => (i + 1) % filtered.length); + return; + } + if (e.key === 'ArrowUp') { + e.preventDefault(); + if (filtered.length === 0) return; + setHighlightedIndex( + (i) => (i - 1 + filtered.length) % filtered.length, + ); + return; + } + if (e.key === 'Home') { + e.preventDefault(); + if (filtered.length > 0) setHighlightedIndex(0); + return; + } + if (e.key === 'End') { + e.preventDefault(); + if (filtered.length > 0) setHighlightedIndex(filtered.length - 1); + return; + } + if (e.key === 'Enter') { + e.preventDefault(); + commit(safeHighlightedIndex); + return; + } + if (e.key === 'Escape') { + e.preventDefault(); + onClose?.(); + return; + } + }} + placeholder="Filter models..." + autoFocus + className="flex-1 min-w-0 bg-transparent text-xs text-text-primary placeholder:text-text-secondary outline-none" + /> + + + +
      + +
      + {models.length === 0 ? ( +

      + No models available. +

      + ) : filtered.length === 0 ? ( +

      + No models found. +

      + ) : ( + filtered.map((model, index) => { + const active = model === activeModel; + const highlighted = index === safeHighlightedIndex; + const capLabel = formatCapabilityLabel(capabilities, model); + return ( + + ); + }) + )} +
      +
      + ); +} diff --git a/src/components/Toast.tsx b/src/components/Toast.tsx new file mode 100644 index 00000000..d2132125 --- /dev/null +++ b/src/components/Toast.tsx @@ -0,0 +1,61 @@ +import { useEffect } from 'react'; + +/** Props for the {@link Toast} component. */ +export interface ToastProps { + /** Body text shown in the toast. The toast renders only when truthy. */ + message: string | null; + /** Called when the auto-dismiss timer fires or the user closes the toast. */ + onDismiss: () => void; + /** Auto-dismiss delay in ms. Defaults to 3000. */ + durationMs?: number; +} + +const DEFAULT_DURATION_MS = 3000; + +/** + * Bottom-anchored transient toast used by the submit-time capability + * gate. Renders nothing when `message` is null. Schedules a single + * auto-dismiss timer per non-null `message` and clears it on unmount or + * before the next message replaces it. + * + * Positioning is `absolute` against the nearest positioned ancestor; the + * caller wraps it in a `relative` container to anchor. + */ +export function Toast({ + message, + onDismiss, + durationMs = DEFAULT_DURATION_MS, +}: ToastProps) { + useEffect(() => { + if (!message) return; + const timer = setTimeout(onDismiss, durationMs); + return () => clearTimeout(timer); + }, [message, onDismiss, durationMs]); + + if (!message) return null; + + return ( +
      +
      + ); +} diff --git a/src/components/Tooltip.tsx b/src/components/Tooltip.tsx index c5dc26e1..00e3b66f 100644 --- a/src/components/Tooltip.tsx +++ b/src/components/Tooltip.tsx @@ -50,9 +50,13 @@ export function Tooltip({ /* v8 ignore stop */ const rect = triggerRef.current.getBoundingClientRect(); const rawLeft = rect.left + rect.width / 2; - // Conservative half-width estimate. Single-line tooltips fit "Conversation - // history" worst-case; multiline tooltips may use the full 320px max width. - const tooltipHalfWidth = multiline ? 160 : 90; + // Half-width estimate matched to the rendered max-width of each + // variant. Single-line tooltips fit "Conversation history" + // worst-case (~180px wide). Multiline tooltips render at + // max-w-[220px], so a 110px halfWidth keeps the centered box + // directly under the trigger even when the trigger sits near the + // right edge of a typical Thuki overlay (600px wide). + const tooltipHalfWidth = multiline ? 110 : 90; const edgePadding = 8; const left = Math.max( tooltipHalfWidth + edgePadding, @@ -71,6 +75,12 @@ export function Tooltip({ setIsVisible(true); }; + /** + * Hides the tooltip. Fired on both `mouseleave` and `mousedown` so a click + * on a tooltipped trigger that opens a popup (e.g. the model picker) + * dismisses the tooltip instead of letting it overlap the popup. The + * tooltip reappears normally on the next fresh hover. + */ const handleMouseLeave = () => { setIsVisible(false); }; @@ -86,6 +96,7 @@ export function Tooltip({ ref={triggerRef} onMouseEnter={handleMouseEnter} onMouseLeave={handleMouseLeave} + onMouseDown={handleMouseLeave} className={`inline-flex${className ? ` ${className}` : ''}`} > {children} @@ -126,9 +137,10 @@ export function Tooltip({ className="absolute -top-1.5 h-3 w-3 -translate-x-1/2 rotate-45 border-l border-t border-surface-border bg-surface-base" />
      diff --git a/src/components/WindowControls.tsx b/src/components/WindowControls.tsx index a95f4aee..db706819 100644 --- a/src/components/WindowControls.tsx +++ b/src/components/WindowControls.tsx @@ -66,6 +66,24 @@ const NEW_CONVERSATION_ICON = ( ); +/** Hoisted chip icon for the active-model pill trigger. */ +const CHIP_ICON = ( + +); + /** Hoisted history (clock) icon. */ const HISTORY_ICON = ( void; + /** + * Currently active model slug displayed in the pill trigger. + * Requires `onModelPickerToggle` to be present; omit either to hide the pill. + */ + activeModel?: string; + /** + * Called when the user clicks the active-model pill to open/close the picker. + * Requires `activeModel` to be present; omit either to hide the pill. + */ + onModelPickerToggle?: () => void; + /** Drives `aria-expanded` on the pill button. */ + isModelPickerOpen?: boolean; } /** Decorative dot color for inactive buttons. */ @@ -124,6 +154,9 @@ export const WindowControls = memo(function WindowControls({ canSave = false, onHistoryOpen, onNewConversation, + activeModel, + onModelPickerToggle, + isModelPickerOpen = false, }: WindowControlsProps) { // Disabled only when there is nothing to save yet and the conversation hasn't // been saved. Once saved the button stays active so the user can unsave. @@ -173,8 +206,44 @@ export const WindowControls = memo(function WindowControls({ aria-hidden="true" /> - {/* Right-side header controls - save bookmark + history dropdown */} + {/* Right-side header controls */}
      + {/* Active model pill trigger: leftmost, before save */} + {activeModel !== undefined && onModelPickerToggle !== undefined && ( + + + + )} + {onSave !== undefined && ( {NEW_CONVERSATION_ICON} @@ -220,7 +289,7 @@ export const WindowControls = memo(function WindowControls({ onClick={onHistoryOpen} aria-label="Open history" data-history-toggle - className="w-7 h-7 flex items-center justify-center rounded-lg text-text-secondary hover:text-text-primary hover:bg-white/5 transition-colors duration-150 cursor-pointer" + className="w-7 h-7 flex items-center justify-center rounded-lg text-text-secondary hover:text-primary hover:bg-primary/8 transition-colors duration-150 cursor-pointer" > {HISTORY_ICON} diff --git a/src/components/__tests__/CapabilityMismatchStrip.test.tsx b/src/components/__tests__/CapabilityMismatchStrip.test.tsx new file mode 100644 index 00000000..35b4e5f2 --- /dev/null +++ b/src/components/__tests__/CapabilityMismatchStrip.test.tsx @@ -0,0 +1,19 @@ +import { render, screen } from '@testing-library/react'; +import { describe, it, expect } from 'vitest'; +import { CapabilityMismatchStrip } from '../CapabilityMismatchStrip'; + +describe('CapabilityMismatchStrip', () => { + it('renders the message verbatim', () => { + render(); + const strip = screen.getByTestId('capability-mismatch-strip'); + expect(strip).toHaveTextContent("llama3 can't see images."); + }); + + it('exposes role=status for assistive tech', () => { + render(); + expect(screen.getByTestId('capability-mismatch-strip')).toHaveAttribute( + 'role', + 'status', + ); + }); +}); diff --git a/src/components/__tests__/ChatBubble.test.tsx b/src/components/__tests__/ChatBubble.test.tsx index 7858e637..3b723e86 100644 --- a/src/components/__tests__/ChatBubble.test.tsx +++ b/src/components/__tests__/ChatBubble.test.tsx @@ -1111,4 +1111,77 @@ describe('ChatBubble', () => { ).toBeTruthy(); }); }); + + describe('model attribution', () => { + it('renders the attribution chip when modelName is provided on assistant messages', () => { + render( + , + ); + const chip = screen.getByTestId('model-attribution'); + expect(chip).toBeInTheDocument(); + expect(chip).toHaveTextContent('gemma4:e2b'); + }); + + it('does not render the attribution chip when modelName is absent', () => { + render(); + expect(screen.queryByTestId('model-attribution')).toBeNull(); + }); + + it('does not render the attribution chip on user messages even with modelName', () => { + render( + , + ); + expect(screen.queryByTestId('model-attribution')).toBeNull(); + }); + + it('does not render the attribution chip when the message is an error callout', () => { + render( + , + ); + expect(screen.queryByTestId('model-attribution')).toBeNull(); + }); + + it('does not render the attribution chip when sandbox is unavailable', () => { + render( + , + ); + expect(screen.queryByTestId('model-attribution')).toBeNull(); + }); + + it('does not render the attribution chip while the assistant is still streaming', () => { + render( + , + ); + // Footer row including the attribution is hidden during streaming. + expect(screen.queryByTestId('model-attribution')).toBeNull(); + }); + }); }); diff --git a/src/components/__tests__/ModelPicker.test.tsx b/src/components/__tests__/ModelPicker.test.tsx new file mode 100644 index 00000000..a3c661db --- /dev/null +++ b/src/components/__tests__/ModelPicker.test.tsx @@ -0,0 +1,53 @@ +import { render, screen, fireEvent } from '@testing-library/react'; +import { describe, it, expect, vi } from 'vitest'; +import { ModelPicker } from '../ModelPicker'; + +function renderTrigger( + overrides: Partial> = {}, +) { + const props: React.ComponentProps = { + onClick: vi.fn(), + disabled: false, + isOpen: false, + ...overrides, + }; + return { props, ...render() }; +} + +describe('ModelPicker', () => { + it('renders the Choose model trigger button with chip icon', () => { + const { container } = renderTrigger(); + const trigger = screen.getByRole('button', { name: 'Choose model' }); + expect(trigger).toBeInTheDocument(); + expect(container.querySelector('svg')).not.toBeNull(); + }); + + it('sets aria-expanded false when isOpen is false', () => { + renderTrigger({ isOpen: false }); + expect( + screen.getByRole('button', { name: 'Choose model' }), + ).toHaveAttribute('aria-expanded', 'false'); + }); + + it('sets aria-expanded true when isOpen is true', () => { + renderTrigger({ isOpen: true }); + expect( + screen.getByRole('button', { name: 'Choose model' }), + ).toHaveAttribute('aria-expanded', 'true'); + }); + + it('calls onClick when clicked', () => { + const onClick = vi.fn(); + renderTrigger({ onClick }); + fireEvent.click(screen.getByRole('button', { name: 'Choose model' })); + expect(onClick).toHaveBeenCalledTimes(1); + }); + + it('is disabled and does not call onClick when disabled', () => { + const onClick = vi.fn(); + renderTrigger({ disabled: true, onClick }); + fireEvent.click(screen.getByRole('button', { name: 'Choose model' })); + expect(onClick).not.toHaveBeenCalled(); + expect(screen.getByRole('button', { name: 'Choose model' })).toBeDisabled(); + }); +}); diff --git a/src/components/__tests__/ModelPickerPanel.test.tsx b/src/components/__tests__/ModelPickerPanel.test.tsx new file mode 100644 index 00000000..7eb3a394 --- /dev/null +++ b/src/components/__tests__/ModelPickerPanel.test.tsx @@ -0,0 +1,370 @@ +import { render, screen, fireEvent } from '@testing-library/react'; +import { describe, it, expect, vi } from 'vitest'; +import { + ModelPickerPanel, + formatCapabilityLabel, + OLLAMA_LIBRARY_URL, + OLLAMA_PILL_TOOLTIP, +} from '../ModelPickerPanel'; +import type { ModelCapabilitiesMap } from '../../types/model'; +import { invoke } from '@tauri-apps/api/core'; + +vi.mock('@tauri-apps/api/core', () => ({ + invoke: vi.fn(), +})); + +const MODELS = ['gemma4:e2b', 'qwen2.5:7b', 'llama3.2:3b']; + +function renderPanel( + overrides: Partial> = {}, +) { + const props: React.ComponentProps = { + models: MODELS, + activeModel: 'gemma4:e2b', + onSelect: vi.fn(), + ...overrides, + }; + return { props, ...render() }; +} + +describe('ModelPickerPanel', () => { + it('renders filter input', () => { + renderPanel(); + expect(screen.getByPlaceholderText(/filter models/i)).toBeInTheDocument(); + }); + + it('shows all models on first render', () => { + renderPanel(); + for (const model of MODELS) { + expect(screen.getByRole('option', { name: model })).toBeInTheDocument(); + } + }); + + it('marks active model with aria-selected true, others false', () => { + renderPanel({ activeModel: 'qwen2.5:7b' }); + expect(screen.getByRole('option', { name: 'qwen2.5:7b' })).toHaveAttribute( + 'aria-selected', + 'true', + ); + expect(screen.getByRole('option', { name: 'gemma4:e2b' })).toHaveAttribute( + 'aria-selected', + 'false', + ); + expect(screen.getByRole('option', { name: 'llama3.2:3b' })).toHaveAttribute( + 'aria-selected', + 'false', + ); + }); + + it('shows visible checkmark on active model, hidden on others', () => { + renderPanel({ activeModel: 'gemma4:e2b' }); + const activeItem = screen.getByRole('option', { name: 'gemma4:e2b' }); + const inactiveItem = screen.getByRole('option', { name: 'qwen2.5:7b' }); + const activeCheck = activeItem.querySelector('svg')!; + const inactiveCheck = inactiveItem.querySelector('svg')!; + expect((activeCheck as SVGElement).style.opacity).toBe('1'); + expect((inactiveCheck as SVGElement).style.opacity).toBe('0'); + }); + + it('calls onSelect with slug when row clicked', () => { + const onSelect = vi.fn(); + renderPanel({ onSelect }); + fireEvent.click(screen.getByRole('option', { name: 'qwen2.5:7b' })); + expect(onSelect).toHaveBeenCalledWith('qwen2.5:7b'); + expect(onSelect).toHaveBeenCalledTimes(1); + }); + + it('filters models as user types', () => { + renderPanel(); + fireEvent.change(screen.getByPlaceholderText(/filter models/i), { + target: { value: 'qwen' }, + }); + expect( + screen.getByRole('option', { name: 'qwen2.5:7b' }), + ).toBeInTheDocument(); + expect(screen.queryByRole('option', { name: 'gemma4:e2b' })).toBeNull(); + expect(screen.queryByRole('option', { name: 'llama3.2:3b' })).toBeNull(); + }); + + it('shows no-models-found message when filter matches nothing', () => { + renderPanel(); + fireEvent.change(screen.getByPlaceholderText(/filter models/i), { + target: { value: 'zzz' }, + }); + expect(screen.getByText(/no models found/i)).toBeInTheDocument(); + expect(screen.queryByRole('option')).toBeNull(); + }); + + it('restores full list when filter is cleared', () => { + renderPanel(); + const input = screen.getByPlaceholderText(/filter models/i); + fireEvent.change(input, { target: { value: 'qwen' } }); + fireEvent.change(input, { target: { value: '' } }); + for (const model of MODELS) { + expect(screen.getByRole('option', { name: model })).toBeInTheDocument(); + } + }); + + it('shows no-models-available message when models list is empty', () => { + renderPanel({ models: [] }); + expect(screen.getByText(/no models available/i)).toBeInTheDocument(); + expect(screen.queryByRole('option')).toBeNull(); + }); + + it('marks the filter input as an aria-activedescendant combobox', () => { + renderPanel(); + const input = screen.getByPlaceholderText(/filter models/i); + expect(input).toHaveAttribute('role', 'combobox'); + expect(input).toHaveAttribute( + 'aria-activedescendant', + expect.stringContaining('option-0'), + ); + }); + + it('ArrowDown advances the highlighted descendant', () => { + renderPanel(); + const input = screen.getByPlaceholderText(/filter models/i); + fireEvent.keyDown(input, { key: 'ArrowDown' }); + expect(input).toHaveAttribute( + 'aria-activedescendant', + expect.stringContaining('option-1'), + ); + }); + + it('ArrowUp wraps to the last row from the first', () => { + renderPanel(); + const input = screen.getByPlaceholderText(/filter models/i); + fireEvent.keyDown(input, { key: 'ArrowUp' }); + expect(input).toHaveAttribute( + 'aria-activedescendant', + expect.stringContaining(`option-${MODELS.length - 1}`), + ); + }); + + it('Home/End jump to the first and last rows', () => { + renderPanel(); + const input = screen.getByPlaceholderText(/filter models/i); + fireEvent.keyDown(input, { key: 'End' }); + expect(input).toHaveAttribute( + 'aria-activedescendant', + expect.stringContaining(`option-${MODELS.length - 1}`), + ); + fireEvent.keyDown(input, { key: 'Home' }); + expect(input).toHaveAttribute( + 'aria-activedescendant', + expect.stringContaining('option-0'), + ); + }); + + it('Enter commits the highlighted row via onSelect', () => { + const onSelect = vi.fn(); + renderPanel({ onSelect }); + const input = screen.getByPlaceholderText(/filter models/i); + fireEvent.keyDown(input, { key: 'ArrowDown' }); + fireEvent.keyDown(input, { key: 'Enter' }); + expect(onSelect).toHaveBeenCalledWith('qwen2.5:7b'); + }); + + it('Escape fires onClose when provided', () => { + const onClose = vi.fn(); + renderPanel({ onClose }); + const input = screen.getByPlaceholderText(/filter models/i); + fireEvent.keyDown(input, { key: 'Escape' }); + expect(onClose).toHaveBeenCalledTimes(1); + }); + + it('Escape without onClose is a safe no-op', () => { + const onSelect = vi.fn(); + renderPanel({ onSelect }); + const input = screen.getByPlaceholderText(/filter models/i); + fireEvent.keyDown(input, { key: 'Escape' }); + // Escape must never select a model. + expect(onSelect).not.toHaveBeenCalled(); + // Focus must remain on the filter input. + expect(document.activeElement).toBe(screen.getByRole('combobox')); + // Filter value must be unchanged (Escape does not clear input). + expect((document.activeElement as HTMLInputElement).value).toBe(''); + }); + + it('keyboard nav on empty filter result is a safe no-op', () => { + const onSelect = vi.fn(); + renderPanel({ onSelect }); + const input = screen.getByPlaceholderText(/filter models/i); + fireEvent.change(input, { target: { value: 'zzz' } }); + fireEvent.keyDown(input, { key: 'ArrowDown' }); + fireEvent.keyDown(input, { key: 'ArrowUp' }); + fireEvent.keyDown(input, { key: 'Home' }); + fireEvent.keyDown(input, { key: 'End' }); + fireEvent.keyDown(input, { key: 'Enter' }); + expect(onSelect).not.toHaveBeenCalled(); + expect(input).not.toHaveAttribute('aria-activedescendant'); + }); + + it('clamps highlighted index when the filtered list shrinks', () => { + renderPanel(); + const input = screen.getByPlaceholderText(/filter models/i); + fireEvent.keyDown(input, { key: 'End' }); + // Narrow the visible set to one row; the activedescendant must clamp to 0. + fireEvent.change(input, { target: { value: 'qwen' } }); + expect(input).toHaveAttribute( + 'aria-activedescendant', + expect.stringContaining('option-0'), + ); + }); + + it('mouse-over updates the highlighted descendant', () => { + renderPanel(); + fireEvent.mouseEnter(screen.getByRole('option', { name: 'llama3.2:3b' })); + const input = screen.getByPlaceholderText(/filter models/i); + expect(input).toHaveAttribute( + 'aria-activedescendant', + expect.stringContaining('option-2'), + ); + }); + + it('ignores other keys without preventing default or firing handlers', () => { + const onSelect = vi.fn(); + const onClose = vi.fn(); + renderPanel({ onSelect, onClose }); + const input = screen.getByPlaceholderText(/filter models/i); + fireEvent.keyDown(input, { key: 'a' }); + expect(onSelect).not.toHaveBeenCalled(); + expect(onClose).not.toHaveBeenCalled(); + }); + + it('renders capability labels per row when capabilities prop is provided', () => { + const capabilities: ModelCapabilitiesMap = { + 'gemma4:e2b': { + vision: true, + thinking: false, + }, + 'qwen2.5:7b': { + vision: false, + thinking: true, + }, + 'llama3.2:3b': { + vision: false, + thinking: false, + }, + }; + renderPanel({ capabilities }); + // Every row leads with "text" (every chat model handles text), then + // appends vision/thinking when supported. Plain models render just "text". + const labels = screen.getAllByTestId('model-capability-label'); + expect(labels.length).toBe(3); + expect(labels[0]).toHaveTextContent('text · vision'); + expect(labels[1]).toHaveTextContent('text · thinking'); + expect(labels[2]).toHaveTextContent('text'); + }); + + it('row aria-label includes capability summary when present', () => { + const capabilities: ModelCapabilitiesMap = { + 'gemma4:e2b': { + vision: true, + thinking: false, + }, + }; + renderPanel({ models: ['gemma4:e2b'], capabilities }); + const row = screen.getByRole('option', { + name: /gemma4:e2b, text, vision/i, + }); + expect(row).toBeInTheDocument(); + }); +}); + +describe('formatCapabilityLabel', () => { + it('returns null when capabilities map is undefined', () => { + expect(formatCapabilityLabel(undefined, 'x')).toBeNull(); + }); + + it('returns null when the model is not in the map', () => { + expect(formatCapabilityLabel({}, 'x')).toBeNull(); + }); + + it('returns "text" for plain models with no surface-worthy capabilities', () => { + const map: ModelCapabilitiesMap = { + x: { vision: false, thinking: false }, + }; + expect(formatCapabilityLabel(map, 'x')).toBe('text'); + }); + + it('leads with "text" and appends every supported flag, joined with " · "', () => { + const map: ModelCapabilitiesMap = { + x: { vision: true, thinking: true }, + }; + expect(formatCapabilityLabel(map, 'x')).toBe('text · vision · thinking'); + }); + + it('appends "vision" after the leading "text" when only vision is present', () => { + const map: ModelCapabilitiesMap = { + x: { vision: true, thinking: false }, + }; + expect(formatCapabilityLabel(map, 'x')).toBe('text · vision'); + }); + + it('appends "thinking" after the leading "text" when only thinking is present', () => { + const map: ModelCapabilitiesMap = { + x: { vision: false, thinking: true }, + }; + expect(formatCapabilityLabel(map, 'x')).toBe('text · thinking'); + }); +}); + +describe('ModelPickerPanel "Browse Ollama" pill', () => { + it('renders the Browse Ollama button next to the filter input', () => { + render( + , + ); + const pill = screen.getByTestId('model-picker-ollama-link'); + expect(pill).toBeInTheDocument(); + expect(pill).toHaveTextContent(/Browse Ollama/i); + expect(pill).toHaveAttribute('aria-label', 'Browse Ollama models'); + }); + + it('opens the Ollama library URL via open_url when clicked', () => { + render( + , + ); + fireEvent.click(screen.getByTestId('model-picker-ollama-link')); + expect(invoke).toHaveBeenCalledWith('open_url', { + url: OLLAMA_LIBRARY_URL, + }); + }); + + it('exports a stable Ollama library URL constant', () => { + expect(OLLAMA_LIBRARY_URL).toBe('https://ollama.com/library'); + }); + + it('exports a stable tooltip body constant', () => { + expect(OLLAMA_PILL_TOOLTIP).toMatch(/Browse and pull any model on Ollama/i); + expect(OLLAMA_PILL_TOOLTIP).toMatch(/Thuki auto-detects it/i); + }); + + it('uses no em dashes in the tooltip body', () => { + expect(OLLAMA_PILL_TOOLTIP).not.toContain('—'); + }); + + it('drops the "Ollama" word in compact mode so the chip drawer stays uncluttered', () => { + render( + , + ); + const pill = screen.getByTestId('model-picker-ollama-link'); + expect(pill).toHaveTextContent(/^Browse$/); + expect(pill).not.toHaveTextContent(/Ollama/); + // Aria-label still spells it out for assistive tech. + expect(pill).toHaveAttribute('aria-label', 'Browse Ollama models'); + }); +}); diff --git a/src/components/__tests__/Toast.test.tsx b/src/components/__tests__/Toast.test.tsx new file mode 100644 index 00000000..1821d575 --- /dev/null +++ b/src/components/__tests__/Toast.test.tsx @@ -0,0 +1,81 @@ +import { render, screen, act } from '@testing-library/react'; +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; +import { Toast } from '../Toast'; + +describe('Toast', () => { + beforeEach(() => { + vi.useFakeTimers(); + }); + afterEach(() => { + vi.useRealTimers(); + }); + + it('renders nothing when message is null', () => { + render( {}} />); + expect(screen.queryByTestId('toast')).toBeNull(); + }); + + it('renders the message when provided', () => { + render( {}} />); + expect(screen.getByTestId('toast')).toHaveTextContent('hello'); + }); + + it('auto-dismisses after the default 3000ms', () => { + const onDismiss = vi.fn(); + render(); + expect(onDismiss).not.toHaveBeenCalled(); + act(() => { + vi.advanceTimersByTime(3000); + }); + expect(onDismiss).toHaveBeenCalledTimes(1); + }); + + it('honors a custom durationMs', () => { + const onDismiss = vi.fn(); + render(); + act(() => { + vi.advanceTimersByTime(499); + }); + expect(onDismiss).not.toHaveBeenCalled(); + act(() => { + vi.advanceTimersByTime(1); + }); + expect(onDismiss).toHaveBeenCalledTimes(1); + }); + + it('clears the timer when message changes to null before timeout', () => { + const onDismiss = vi.fn(); + const { rerender } = render( + , + ); + act(() => { + vi.advanceTimersByTime(500); + }); + rerender(); + act(() => { + vi.advanceTimersByTime(2000); + }); + expect(onDismiss).not.toHaveBeenCalled(); + }); + + it('resets the timer when message changes to a new value', () => { + const onDismiss = vi.fn(); + const { rerender } = render( + , + ); + act(() => { + vi.advanceTimersByTime(900); + }); + rerender( + , + ); + act(() => { + vi.advanceTimersByTime(900); + }); + expect(onDismiss).not.toHaveBeenCalled(); + act(() => { + vi.advanceTimersByTime(200); + }); + expect(onDismiss).toHaveBeenCalledTimes(1); + }); +}); diff --git a/src/components/__tests__/Tooltip.test.tsx b/src/components/__tests__/Tooltip.test.tsx index f2c861bb..8a3c5a36 100644 --- a/src/components/__tests__/Tooltip.test.tsx +++ b/src/components/__tests__/Tooltip.test.tsx @@ -110,6 +110,21 @@ describe('Tooltip', () => { expect(wrapper?.classList.contains('inline-flex')).toBe(true); }); + it('hides on mouseDown so the tooltip does not overlap a popup the click opens', () => { + render( + + + , + ); + const wrapper = screen.getByRole('button', { + name: 'Trigger', + }).parentElement!; + fireEvent.mouseEnter(wrapper); + expect(screen.getByText('Choose model')).toBeInTheDocument(); + fireEvent.mouseDown(wrapper); + expect(screen.queryByText('Choose model')).not.toBeInTheDocument(); + }); + it('applies extra className to the wrapper div', () => { const { container } = render( @@ -121,4 +136,25 @@ describe('Tooltip', () => { expect(wrapper?.classList.contains('min-w-0')).toBe(true); expect(wrapper?.classList.contains('inline-flex')).toBe(true); }); + + it('renders multiline tooltips at a fixed 225px width so the box stays directly below the trigger near edges', () => { + render( + + + , + ); + const wrapper = screen.getByRole('button', { + name: 'Trigger', + }).parentElement!; + fireEvent.mouseEnter(wrapper); + const fixedBox = document.body.querySelector( + '[style*="position: fixed"]', + ) as HTMLElement | null; + expect(fixedBox).not.toBeNull(); + // The inner content div (under the fixed-positioned outer + motion + // wrapper) carries the explicit 225px width style. + const inner = fixedBox?.querySelector('div[style*="width"]'); + expect(inner).not.toBeNull(); + expect((inner as HTMLElement).style.width).toBe('225px'); + }); }); diff --git a/src/components/__tests__/WindowControls.test.tsx b/src/components/__tests__/WindowControls.test.tsx index 21e35f72..51e1a29d 100644 --- a/src/components/__tests__/WindowControls.test.tsx +++ b/src/components/__tests__/WindowControls.test.tsx @@ -72,4 +72,90 @@ describe('WindowControls', () => { ); expect(onSave).toHaveBeenCalledTimes(1); }); + + it('renders active model pill when activeModel and onModelPickerToggle provided', () => { + render( + , + ); + expect( + screen.getByRole('button', { name: 'Choose model' }), + ).toBeInTheDocument(); + expect(screen.getByText('gemma4:e2b')).toBeInTheDocument(); + }); + + it('hides model pill when activeModel is not provided', () => { + render(); + expect(screen.queryByRole('button', { name: 'Choose model' })).toBeNull(); + }); + + it('hides model pill when onModelPickerToggle is not provided', () => { + render(); + expect(screen.queryByRole('button', { name: 'Choose model' })).toBeNull(); + }); + + it('calls onModelPickerToggle when pill is clicked', () => { + const onModelPickerToggle = vi.fn(); + render( + , + ); + fireEvent.click(screen.getByRole('button', { name: 'Choose model' })); + expect(onModelPickerToggle).toHaveBeenCalledTimes(1); + }); + + it('sets aria-expanded false on pill when isModelPickerOpen is false', () => { + render( + , + ); + expect( + screen.getByRole('button', { name: 'Choose model' }), + ).toHaveAttribute('aria-expanded', 'false'); + }); + + it('sets aria-expanded true on pill when isModelPickerOpen is true', () => { + render( + , + ); + expect( + screen.getByRole('button', { name: 'Choose model' }), + ).toHaveAttribute('aria-expanded', 'true'); + }); + + it('pill renders before save button in DOM order', () => { + render( + , + ); + const pill = screen.getByRole('button', { name: 'Choose model' }); + const save = screen.getByRole('button', { name: 'Save conversation' }); + const relation = pill.compareDocumentPosition(save); + expect(relation & Node.DOCUMENT_POSITION_FOLLOWING).toBe( + Node.DOCUMENT_POSITION_FOLLOWING, + ); + }); }); diff --git a/src/hooks/__tests__/useConversationHistory.test.tsx b/src/hooks/__tests__/useConversationHistory.test.tsx index 0403cab9..c841f973 100644 --- a/src/hooks/__tests__/useConversationHistory.test.tsx +++ b/src/hooks/__tests__/useConversationHistory.test.tsx @@ -4,6 +4,8 @@ import { useConversationHistory } from '../useConversationHistory'; import { invoke } from '../../testUtils/mocks/tauri'; import type { Message } from '../useOllama'; +const MODEL = 'gemma4:e2b'; + const MESSAGES: Message[] = [ { id: 'u1', role: 'user', content: 'Hello', quotedText: undefined }, { id: 'a1', role: 'assistant', content: 'Hi there' }, @@ -32,7 +34,7 @@ describe('useConversationHistory', () => { const { result } = renderHook(() => useConversationHistory()); await act(async () => { - await result.current.save(MESSAGES); + await result.current.save(MESSAGES, MODEL); }); expect(invoke).toHaveBeenCalledWith('save_conversation', { @@ -46,6 +48,7 @@ describe('useConversationHistory', () => { search_sources: null, search_warnings: null, search_metadata: null, + model_name: null, }, { role: 'assistant', @@ -56,6 +59,7 @@ describe('useConversationHistory', () => { search_sources: null, search_warnings: null, search_metadata: null, + model_name: null, }, ], }); @@ -68,7 +72,7 @@ describe('useConversationHistory', () => { const { result } = renderHook(() => useConversationHistory()); await act(async () => { - await result.current.save(MESSAGES); + await result.current.save(MESSAGES, MODEL); }); expect(result.current.isSaved).toBe(true); @@ -82,7 +86,7 @@ describe('useConversationHistory', () => { const { result } = renderHook(() => useConversationHistory()); await act(async () => { - await result.current.save(MESSAGES); + await result.current.save(MESSAGES, MODEL); }); expect(invoke).toHaveBeenCalledWith('generate_title', { @@ -97,6 +101,7 @@ describe('useConversationHistory', () => { search_sources: null, search_warnings: null, search_metadata: null, + model_name: null, }, { role: 'assistant', @@ -107,8 +112,10 @@ describe('useConversationHistory', () => { search_sources: null, search_warnings: null, search_metadata: null, + model_name: null, }, ], + model: MODEL, }); }); @@ -119,13 +126,13 @@ describe('useConversationHistory', () => { const { result } = renderHook(() => useConversationHistory()); await act(async () => { - await result.current.save(MESSAGES); + await result.current.save(MESSAGES, MODEL); }); invoke.mockClear(); await act(async () => { - await result.current.save(MESSAGES); + await result.current.save(MESSAGES, MODEL); }); expect(invoke).not.toHaveBeenCalled(); @@ -148,7 +155,7 @@ describe('useConversationHistory', () => { const { result } = renderHook(() => useConversationHistory()); await act(async () => { - await result.current.save(MESSAGES); + await result.current.save(MESSAGES, MODEL); }); invoke.mockClear(); @@ -179,6 +186,7 @@ describe('useConversationHistory', () => { searchSources: null, searchWarnings: null, searchMetadata: null, + modelName: null, }); expect(invoke).toHaveBeenCalledWith('persist_message', { conversationId: 'conv-123', @@ -190,6 +198,7 @@ describe('useConversationHistory', () => { searchSources: null, searchWarnings: null, searchMetadata: null, + modelName: null, }); }); @@ -200,7 +209,7 @@ describe('useConversationHistory', () => { const { result } = renderHook(() => useConversationHistory()); await act(async () => { - await result.current.save(MESSAGES); + await result.current.save(MESSAGES, MODEL); }); invoke.mockClear(); @@ -280,12 +289,14 @@ describe('useConversationHistory', () => { role: 'user', content: 'Saved question', quotedText: undefined, + modelName: undefined, }, { id: 'm2', role: 'assistant', content: 'Saved answer', quotedText: 'ctx', + modelName: undefined, }, ]); }); @@ -385,7 +396,7 @@ describe('useConversationHistory', () => { const { result } = renderHook(() => useConversationHistory()); await act(async () => { - await result.current.save(MESSAGES); + await result.current.save(MESSAGES, MODEL); }); invoke.mockClear(); @@ -416,6 +427,7 @@ describe('useConversationHistory', () => { searchSources: null, searchWarnings: null, searchMetadata: null, + modelName: null, }); }); @@ -436,7 +448,7 @@ describe('useConversationHistory', () => { const { result } = renderHook(() => useConversationHistory()); await act(async () => { - await result.current.save(messagesWithWarnings); + await result.current.save(messagesWithWarnings, MODEL); }); expect(invoke).toHaveBeenCalledWith('save_conversation', { @@ -450,6 +462,7 @@ describe('useConversationHistory', () => { search_sources: null, search_warnings: null, search_metadata: null, + model_name: null, }, { role: 'assistant', @@ -460,6 +473,7 @@ describe('useConversationHistory', () => { search_sources: null, search_warnings: JSON.stringify(['reader_unavailable']), search_metadata: null, + model_name: null, }, ], }); @@ -482,7 +496,7 @@ describe('useConversationHistory', () => { const { result } = renderHook(() => useConversationHistory()); await act(async () => { - await result.current.save(messagesWithThinking); + await result.current.save(messagesWithThinking, MODEL); }); expect(invoke).toHaveBeenCalledWith('save_conversation', { @@ -496,6 +510,7 @@ describe('useConversationHistory', () => { search_sources: null, search_warnings: null, search_metadata: null, + model_name: null, }, { role: 'assistant', @@ -506,6 +521,7 @@ describe('useConversationHistory', () => { search_sources: null, search_warnings: null, search_metadata: null, + model_name: null, }, ], }); @@ -607,7 +623,7 @@ describe('useConversationHistory', () => { const { result } = renderHook(() => useConversationHistory()); await act(async () => { - await result.current.save(MESSAGES); + await result.current.save(MESSAGES, MODEL); }); invoke.mockClear(); @@ -643,7 +659,7 @@ describe('useConversationHistory', () => { const { result } = renderHook(() => useConversationHistory()); await act(async () => { - await result.current.save(MESSAGES); + await result.current.save(MESSAGES, MODEL); }); expect(result.current.isSaved).toBe(true); @@ -662,7 +678,7 @@ describe('useConversationHistory', () => { const { result } = renderHook(() => useConversationHistory()); await act(async () => { - await result.current.save(MESSAGES); + await result.current.save(MESSAGES, MODEL); }); expect(result.current.isSaved).toBe(true); @@ -686,7 +702,7 @@ describe('useConversationHistory', () => { const { result } = renderHook(() => useConversationHistory()); await act(async () => { - await result.current.save(MESSAGES); + await result.current.save(MESSAGES, MODEL); }); invoke.mockClear(); @@ -716,6 +732,7 @@ describe('useConversationHistory', () => { searchSources: null, searchWarnings: JSON.stringify(['reader_unavailable']), searchMetadata: null, + modelName: null, }); }); @@ -772,7 +789,7 @@ describe('useConversationHistory', () => { const { result } = renderHook(() => useConversationHistory()); await act(async () => { - await result.current.save(MESSAGES); + await result.current.save(MESSAGES, MODEL); }); invoke.mockClear(); @@ -819,7 +836,7 @@ describe('useConversationHistory', () => { const { result } = renderHook(() => useConversationHistory()); await act(async () => { - await result.current.save(messagesWithMeta); + await result.current.save(messagesWithMeta, MODEL); }); expect(invoke).toHaveBeenCalledWith( @@ -863,7 +880,7 @@ describe('useConversationHistory', () => { const { result } = renderHook(() => useConversationHistory()); await act(async () => { - await result.current.save(messagesWithTraces); + await result.current.save(messagesWithTraces, MODEL); }); expect(invoke).toHaveBeenCalledWith( @@ -886,7 +903,7 @@ describe('useConversationHistory', () => { const { result } = renderHook(() => useConversationHistory()); await act(async () => { - await result.current.save(MESSAGES); + await result.current.save(MESSAGES, MODEL); }); invoke.mockClear(); @@ -927,6 +944,7 @@ describe('useConversationHistory', () => { searchSources: null, searchWarnings: null, searchMetadata: JSON.stringify(metadata), + modelName: null, }); }); @@ -937,7 +955,7 @@ describe('useConversationHistory', () => { const { result } = renderHook(() => useConversationHistory()); await act(async () => { - await result.current.save(MESSAGES); + await result.current.save(MESSAGES, MODEL); }); invoke.mockClear(); @@ -973,6 +991,7 @@ describe('useConversationHistory', () => { searchSources: null, searchWarnings: null, searchMetadata: JSON.stringify(traces), + modelName: null, }); }); @@ -1320,4 +1339,120 @@ describe('useConversationHistory', () => { ]), ); }); + + // ─── model_name round trip ─────────────────────────────────────────────────── + + it('save() stamps model_name on payloads when Message has modelName', async () => { + invoke.mockResolvedValueOnce({ conversation_id: 'conv-model-save' }); + invoke.mockResolvedValue(undefined); + + const messagesWithModel: Message[] = [ + { id: 'u1', role: 'user', content: 'Hi' }, + { + id: 'a1', + role: 'assistant', + content: 'Hello', + modelName: 'gemma4:e2b', + }, + ]; + + const { result } = renderHook(() => useConversationHistory()); + + await act(async () => { + await result.current.save(messagesWithModel, MODEL); + }); + + expect(invoke).toHaveBeenCalledWith( + 'save_conversation', + expect.objectContaining({ + messages: [ + expect.objectContaining({ role: 'user', model_name: null }), + expect.objectContaining({ + role: 'assistant', + model_name: 'gemma4:e2b', + }), + ], + }), + ); + }); + + it('persistTurn() sends modelName for assistant, null for user', async () => { + invoke.mockResolvedValueOnce({ conversation_id: 'conv-model-persist' }); + invoke.mockResolvedValue(undefined); + + const { result } = renderHook(() => useConversationHistory()); + + await act(async () => { + await result.current.save(MESSAGES, MODEL); + }); + invoke.mockClear(); + + const userMsg: Message = { id: 'u-m', role: 'user', content: 'q' }; + const assistantMsg: Message = { + id: 'a-m', + role: 'assistant', + content: 'answer', + modelName: 'qwen2.5:7b', + }; + + await act(async () => { + await result.current.persistTurn(userMsg, assistantMsg); + }); + + expect(invoke).toHaveBeenCalledWith( + 'persist_message', + expect.objectContaining({ + role: 'user', + modelName: null, + }), + ); + expect(invoke).toHaveBeenCalledWith( + 'persist_message', + expect.objectContaining({ + role: 'assistant', + modelName: 'qwen2.5:7b', + }), + ); + }); + + it('loadConversation() maps model_name back to modelName on restore', async () => { + invoke.mockResolvedValueOnce([ + { + id: 'u1', + role: 'user', + content: 'Hi', + quoted_text: null, + image_paths: null, + thinking_content: null, + search_sources: null, + search_warnings: null, + search_metadata: null, + model_name: null, + created_at: 1, + }, + { + id: 'a1', + role: 'assistant', + content: 'Hello', + quoted_text: null, + image_paths: null, + thinking_content: null, + search_sources: null, + search_warnings: null, + search_metadata: null, + model_name: 'gemma4:e2b', + created_at: 2, + }, + ]); + + const { result } = renderHook(() => useConversationHistory()); + let loaded: Message[] = []; + + await act(async () => { + loaded = await result.current.loadConversation('conv-model-load'); + }); + + expect(loaded[0].modelName).toBeUndefined(); + expect(loaded[1].modelName).toBe('gemma4:e2b'); + }); }); diff --git a/src/hooks/__tests__/useModelCapabilities.test.tsx b/src/hooks/__tests__/useModelCapabilities.test.tsx new file mode 100644 index 00000000..1ed4da6a --- /dev/null +++ b/src/hooks/__tests__/useModelCapabilities.test.tsx @@ -0,0 +1,113 @@ +import { renderHook, act } from '@testing-library/react'; +import { describe, it, expect, beforeEach } from 'vitest'; +import { useModelCapabilities } from '../useModelCapabilities'; +import { invoke } from '../../testUtils/mocks/tauri'; + +const FULL = { + vision: true, + thinking: true, +}; + +const TEXT_ONLY = { + vision: false, + thinking: false, +}; + +describe('useModelCapabilities', () => { + beforeEach(() => { + invoke.mockReset(); + }); + + it('loads the capability map from the backend', async () => { + invoke.mockResolvedValueOnce({ + 'llama3.2-vision': FULL, + llama3: TEXT_ONLY, + }); + const { result } = renderHook(() => useModelCapabilities()); + await act(async () => {}); + expect(result.current.capabilities).toEqual({ + 'llama3.2-vision': FULL, + llama3: TEXT_ONLY, + }); + }); + + it('clears state on backend reject', async () => { + invoke.mockRejectedValueOnce(new Error('backend offline')); + const { result } = renderHook(() => useModelCapabilities()); + await act(async () => {}); + expect(result.current.capabilities).toEqual({}); + }); + + it('clears state when payload is not an object', async () => { + invoke.mockResolvedValueOnce('not-a-map'); + const { result } = renderHook(() => useModelCapabilities()); + await act(async () => {}); + expect(result.current.capabilities).toEqual({}); + }); + + it('clears state when payload is null', async () => { + invoke.mockResolvedValueOnce(null); + const { result } = renderHook(() => useModelCapabilities()); + await act(async () => {}); + expect(result.current.capabilities).toEqual({}); + }); + + it('clears state when an entry has the wrong shape', async () => { + invoke.mockResolvedValueOnce({ + llama3: { + vision: 'yes', + thinking: false, + }, + }); + const { result } = renderHook(() => useModelCapabilities()); + await act(async () => {}); + expect(result.current.capabilities).toEqual({}); + }); + + it('clears state when an entry is null', async () => { + invoke.mockResolvedValueOnce({ llama3: null }); + const { result } = renderHook(() => useModelCapabilities()); + await act(async () => {}); + expect(result.current.capabilities).toEqual({}); + }); + + it('refresh re-fetches the map', async () => { + invoke + .mockResolvedValueOnce({ a: TEXT_ONLY }) + .mockResolvedValueOnce({ a: TEXT_ONLY, b: FULL }); + const { result } = renderHook(() => useModelCapabilities()); + await act(async () => {}); + expect(result.current.capabilities).toEqual({ a: TEXT_ONLY }); + await act(async () => { + await result.current.refresh(); + }); + expect(result.current.capabilities).toEqual({ a: TEXT_ONLY, b: FULL }); + }); + + it('drops a stale rejection from a superseded fetch', async () => { + // First mount-call hangs and is later rejected. A second refresh in + // the meantime resolves successfully and bumps the token. The first + // call's late rejection must be ignored so the resolved state from + // the second call is preserved. + let rejectFirst: (err: Error) => void = () => {}; + const firstPromise = new Promise((_, reject) => { + rejectFirst = reject; + }); + invoke.mockReturnValueOnce(firstPromise).mockResolvedValueOnce({ b: FULL }); + + const { result } = renderHook(() => useModelCapabilities()); + await act(async () => {}); + // Kick off a second refresh that supersedes the first. + await act(async () => { + await result.current.refresh(); + }); + expect(result.current.capabilities).toEqual({ b: FULL }); + // Now reject the first hanging call. Its catch must short-circuit + // because the token is stale and so leave state untouched. + await act(async () => { + rejectFirst(new Error('late')); + await new Promise((r) => setTimeout(r, 0)); + }); + expect(result.current.capabilities).toEqual({ b: FULL }); + }); +}); diff --git a/src/hooks/__tests__/useModelSelection.test.tsx b/src/hooks/__tests__/useModelSelection.test.tsx new file mode 100644 index 00000000..aa83a65a --- /dev/null +++ b/src/hooks/__tests__/useModelSelection.test.tsx @@ -0,0 +1,272 @@ +import { renderHook, act } from '@testing-library/react'; +import { describe, it, expect, beforeEach } from 'vitest'; +import { useModelSelection } from '../useModelSelection'; +import { invoke } from '../../testUtils/mocks/tauri'; + +describe('useModelSelection', () => { + beforeEach(() => { + invoke.mockReset(); + }); + + it('loads active and installed models from the backend', async () => { + invoke.mockResolvedValueOnce({ + active: 'gemma4:e2b', + all: ['gemma4:e2b', 'qwen2.5:7b'], + }); + + const { result } = renderHook(() => useModelSelection()); + await act(async () => {}); + + expect(result.current.activeModel).toBe('gemma4:e2b'); + expect(result.current.availableModels).toEqual([ + 'gemma4:e2b', + 'qwen2.5:7b', + ]); + }); + + it('persists a new active model and updates local state', async () => { + invoke + .mockResolvedValueOnce({ + active: 'gemma4:e2b', + all: ['gemma4:e2b', 'qwen2.5:7b'], + }) + .mockResolvedValueOnce(undefined); + + const { result } = renderHook(() => useModelSelection()); + await act(async () => {}); + + await act(async () => { + await result.current.setActiveModel('qwen2.5:7b'); + }); + + expect(invoke).toHaveBeenCalledWith('set_active_model', { + model: 'qwen2.5:7b', + }); + expect(result.current.activeModel).toBe('qwen2.5:7b'); + }); + + it('clears available models when backend fetch fails', async () => { + invoke.mockRejectedValueOnce(new Error('backend offline')); + + const { result } = renderHook(() => useModelSelection()); + await act(async () => {}); + + expect(result.current.availableModels).toEqual([]); + expect(result.current.activeModel).toBe(''); + }); + + it('falls back to empty state when payload shape is invalid', async () => { + invoke.mockResolvedValueOnce({ active: 42, all: 'not-an-array' }); + + const { result } = renderHook(() => useModelSelection()); + await act(async () => {}); + + expect(result.current.availableModels).toEqual([]); + expect(result.current.activeModel).toBe(''); + }); + + it('re-fetches models when refreshModels is called', async () => { + invoke + .mockResolvedValueOnce({ active: 'gemma4:e2b', all: ['gemma4:e2b'] }) + .mockResolvedValueOnce({ + active: 'qwen2.5:7b', + all: ['gemma4:e2b', 'qwen2.5:7b'], + }); + + const { result } = renderHook(() => useModelSelection()); + await act(async () => {}); + + await act(async () => { + await result.current.refreshModels(); + }); + + expect(result.current.activeModel).toBe('qwen2.5:7b'); + expect(result.current.availableModels).toEqual([ + 'gemma4:e2b', + 'qwen2.5:7b', + ]); + }); + + it('rejects null payloads from the backend', async () => { + invoke.mockResolvedValueOnce(null); + + const { result } = renderHook(() => useModelSelection()); + await act(async () => {}); + + expect(result.current.availableModels).toEqual([]); + expect(result.current.activeModel).toBe(''); + }); + + it('rejects non-object payloads from the backend', async () => { + invoke.mockResolvedValueOnce('gemma4:e2b'); + + const { result } = renderHook(() => useModelSelection()); + await act(async () => {}); + + expect(result.current.availableModels).toEqual([]); + expect(result.current.activeModel).toBe(''); + }); + + it('rejects payloads whose `all` array contains non-string entries', async () => { + invoke.mockResolvedValueOnce({ active: 'gemma4:e2b', all: ['ok', 7] }); + + const { result } = renderHook(() => useModelSelection()); + await act(async () => {}); + + expect(result.current.availableModels).toEqual([]); + expect(result.current.activeModel).toBe(''); + }); + + it('surfaces backend errors and leaves active model unchanged on rejection', async () => { + invoke + .mockResolvedValueOnce({ + active: 'gemma4:e2b', + all: ['gemma4:e2b', 'qwen2.5:7b'], + }) + .mockRejectedValueOnce( + new Error('Model is not installed in Ollama: mystery'), + ); + + const { result } = renderHook(() => useModelSelection()); + await act(async () => {}); + + await expect( + act(async () => { + await result.current.setActiveModel('mystery'); + }), + ).rejects.toThrow('Model is not installed in Ollama: mystery'); + + expect(result.current.activeModel).toBe('gemma4:e2b'); + }); + + it('clears active model when a later refresh returns a malformed payload', async () => { + invoke + .mockResolvedValueOnce({ + active: 'gemma4:e2b', + all: ['gemma4:e2b', 'qwen2.5:7b'], + }) + .mockResolvedValueOnce({ active: 42, all: 'not-an-array' }); + + const { result } = renderHook(() => useModelSelection()); + await act(async () => {}); + expect(result.current.activeModel).toBe('gemma4:e2b'); + + await act(async () => { + await result.current.refreshModels(); + }); + + expect(result.current.activeModel).toBe(''); + expect(result.current.availableModels).toEqual([]); + }); + + it('drops a stale setActiveModel resolution when a newer call supersedes it', async () => { + invoke.mockResolvedValueOnce({ + active: 'A', + all: ['A', 'B', 'C'], + }); + + const { result } = renderHook(() => useModelSelection()); + await act(async () => {}); + + let resolveSlow!: () => void; + invoke + .mockImplementationOnce( + () => + new Promise((r) => { + resolveSlow = () => r(); + }), + ) + .mockResolvedValueOnce(undefined); + + let slowPromise: Promise; + await act(async () => { + slowPromise = result.current.setActiveModel('B'); + await result.current.setActiveModel('C'); + }); + + // "C" wins because it was the latest call; "B"'s pending promise must be + // a silent no-op when it finally resolves. + expect(result.current.activeModel).toBe('C'); + + await act(async () => { + resolveSlow(); + await slowPromise; + }); + + expect(result.current.activeModel).toBe('C'); + }); + + it('drops a stale setActiveModel rejection when a newer call supersedes it', async () => { + invoke.mockResolvedValueOnce({ + active: 'A', + all: ['A', 'B', 'C'], + }); + + const { result } = renderHook(() => useModelSelection()); + await act(async () => {}); + + let rejectSlow!: (err: unknown) => void; + invoke + .mockImplementationOnce( + () => + new Promise((_resolve, reject) => { + rejectSlow = reject; + }), + ) + .mockResolvedValueOnce(undefined); + + let slowPromise: Promise; + await act(async () => { + slowPromise = result.current.setActiveModel('B'); + await result.current.setActiveModel('C'); + }); + + expect(result.current.activeModel).toBe('C'); + + // The stale rejection must not bubble up to callers or revert state. + await act(async () => { + rejectSlow(new Error('stale')); + await slowPromise; + }); + + expect(result.current.activeModel).toBe('C'); + }); + + it('drops a late refresh resolution after unmount', async () => { + let resolveLate!: (value: unknown) => void; + invoke.mockImplementationOnce( + () => + new Promise((resolve) => { + resolveLate = resolve; + }), + ); + + const { unmount } = renderHook(() => useModelSelection()); + unmount(); + + // Resolving after unmount would setState on an unmounted component without + // the mounted guard, producing a React warning / test failure. + await act(async () => { + resolveLate({ active: 'A', all: ['A'] }); + }); + }); + + it('drops a late refresh rejection after unmount', async () => { + let rejectLate!: (err: unknown) => void; + invoke.mockImplementationOnce( + () => + new Promise((_resolve, reject) => { + rejectLate = reject; + }), + ); + + const { unmount } = renderHook(() => useModelSelection()); + unmount(); + + // Same shape as the late-resolve test but exercises the catch branch of + // refreshModels so the post-unmount guard is covered in both arms. + await act(async () => { + rejectLate(new Error('late')); + }); + }); +}); diff --git a/src/hooks/__tests__/useOllama.test.tsx b/src/hooks/__tests__/useOllama.test.tsx index 388de991..5dcb930a 100644 --- a/src/hooks/__tests__/useOllama.test.tsx +++ b/src/hooks/__tests__/useOllama.test.tsx @@ -25,7 +25,7 @@ describe('useOllama', () => { describe('ask()', () => { it('sends message via invoke with correct command name and args', async () => { - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); await act(async () => { await result.current.ask('hello world'); @@ -55,7 +55,7 @@ describe('useOllama', () => { }, ); - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); // Start ask but don't await so we can read state while in-flight act(() => { @@ -72,7 +72,7 @@ describe('useOllama', () => { }); it('adds user message and empty assistant placeholder immediately on ask', async () => { - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); await act(async () => { await result.current.ask('my question'); @@ -95,7 +95,7 @@ describe('useOllama', () => { }); it('stores quotedText on user message when provided', async () => { - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); await act(async () => { await result.current.ask('what is this?', 'code snippet'); @@ -111,7 +111,7 @@ describe('useOllama', () => { }); it('sends quotedText to invoke when provided', async () => { - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); await act(async () => { await result.current.ask('summarize', 'selected text'); @@ -127,7 +127,7 @@ describe('useOllama', () => { }); it('accumulates streaming tokens into the assistant message', async () => { - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); await act(async () => { await result.current.ask('hello'); @@ -148,7 +148,7 @@ describe('useOllama', () => { }); it('keeps assistant message in place on Done chunk', async () => { - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); await act(async () => { await result.current.ask('hello'); @@ -172,7 +172,7 @@ describe('useOllama', () => { }); it('does nothing for empty prompt', async () => { - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); await act(async () => { await result.current.ask(''); @@ -183,7 +183,7 @@ describe('useOllama', () => { }); it('does nothing for whitespace-only prompt', async () => { - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); await act(async () => { await result.current.ask(' '); @@ -205,7 +205,7 @@ describe('useOllama', () => { }, ); - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); // Start the first ask (stalls) act(() => { @@ -230,7 +230,7 @@ describe('useOllama', () => { }); it('sends promptOverride as message to backend when provided', async () => { - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); await act(async () => { await result.current.ask( @@ -255,7 +255,7 @@ describe('useOllama', () => { }); it('sends displayContent as message when no promptOverride provided', async () => { - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); await act(async () => { await result.current.ask('hello world'); @@ -270,7 +270,7 @@ describe('useOllama', () => { }); it('sends displayContent when promptOverride is undefined', async () => { - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); await act(async () => { await result.current.ask( @@ -295,7 +295,7 @@ describe('useOllama', () => { describe('imagePaths handling', () => { it('allows ask() with empty text but valid imagePaths', async () => { - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); await act(async () => { await result.current.ask('', undefined, ['/tmp/img1.jpg']); @@ -320,7 +320,7 @@ describe('useOllama', () => { }); it('returns early for empty text AND no imagePaths', async () => { - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); await act(async () => { await result.current.ask('', undefined, undefined); @@ -331,7 +331,7 @@ describe('useOllama', () => { }); it('returns early for empty text AND empty imagePaths array', async () => { - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); await act(async () => { await result.current.ask('', undefined, []); @@ -342,7 +342,7 @@ describe('useOllama', () => { }); it('includes imagePaths in message and invoke when text AND imagePaths are provided', async () => { - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); await act(async () => { await result.current.ask('describe this', undefined, [ @@ -368,7 +368,7 @@ describe('useOllama', () => { }); it('sets message.imagePaths to undefined and invoke imagePaths to null when no imagePaths', async () => { - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); await act(async () => { await result.current.ask('hello'); @@ -388,7 +388,7 @@ describe('useOllama', () => { describe('error handling', () => { it('Error chunk sets isGenerating to false', async () => { - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); await act(async () => { await result.current.ask('test'); @@ -413,7 +413,7 @@ describe('useOllama', () => { it('invoke rejection sets isGenerating to false', async () => { invoke.mockRejectedValueOnce(new Error('connection refused')); - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); await act(async () => { await result.current.ask('test'); @@ -423,7 +423,7 @@ describe('useOllama', () => { }); it('Error chunk updates assistant placeholder with errorKind', async () => { - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); await act(async () => { await result.current.ask('test'); @@ -452,7 +452,7 @@ describe('useOllama', () => { }); it('Error chunk with partial tokens replaces content with error', async () => { - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); await act(async () => { await result.current.ask('test'); @@ -479,7 +479,7 @@ describe('useOllama', () => { it('invoke rejection creates assistant message with Other errorKind', async () => { invoke.mockRejectedValueOnce(new Error('network error')); - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); await act(async () => { await result.current.ask('test'); @@ -497,7 +497,7 @@ describe('useOllama', () => { describe('streaming edge cases', () => { it('handles Token with empty string', async () => { - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); await act(async () => { await result.current.ask('hello'); @@ -518,7 +518,7 @@ describe('useOllama', () => { }); it('drops the placeholder when only an empty ThinkingToken arrives before cancellation', async () => { - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); await act(async () => { await result.current.ask('hello', undefined, undefined, true); @@ -567,7 +567,7 @@ describe('useOllama', () => { } }); - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); let secondAsk!: Promise; let thirdAsk!: Promise; @@ -621,7 +621,7 @@ describe('useOllama', () => { } }); - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); act(() => { void result.current.ask('late failure'); @@ -663,7 +663,7 @@ describe('useOllama', () => { }, ); - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); act(() => { void result.current.ask('hello'); @@ -708,7 +708,7 @@ describe('useOllama', () => { }, ); - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); act(() => { void result.current.askSearch('rust'); @@ -752,7 +752,7 @@ describe('useOllama', () => { }); it('does nothing when not generating', async () => { - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); await act(async () => { await result.current.cancel(); @@ -767,7 +767,7 @@ describe('useOllama', () => { describe('Cancelled chunk', () => { it('keeps partial content as assistant message on Cancelled', async () => { - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); await act(async () => { await result.current.ask('hello'); @@ -792,7 +792,7 @@ describe('useOllama', () => { }); it('removes assistant placeholder when cancelled with no tokens', async () => { - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); await act(async () => { await result.current.ask('hello'); @@ -816,7 +816,7 @@ describe('useOllama', () => { describe('reset()', () => { it('clears all state', async () => { - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); // Build up some state await act(async () => { @@ -847,7 +847,7 @@ describe('useOllama', () => { describe('onTurnComplete callback', () => { it('is called with user and assistant messages on Done', async () => { const onTurnComplete = vi.fn(); - const { result } = renderHook(() => useOllama(onTurnComplete)); + const { result } = renderHook(() => useOllama('', onTurnComplete)); await act(async () => { await result.current.ask('ping'); @@ -870,7 +870,7 @@ describe('useOllama', () => { it('is not called when Cancelled', async () => { const onTurnComplete = vi.fn(); - const { result } = renderHook(() => useOllama(onTurnComplete)); + const { result } = renderHook(() => useOllama('', onTurnComplete)); await act(async () => { await result.current.ask('ping'); @@ -887,7 +887,7 @@ describe('useOllama', () => { it('is not called when an Error chunk is received', async () => { const onTurnComplete = vi.fn(); - const { result } = renderHook(() => useOllama(onTurnComplete)); + const { result } = renderHook(() => useOllama('', onTurnComplete)); await act(async () => { await result.current.ask('ping'); @@ -905,11 +905,82 @@ describe('useOllama', () => { }); }); + // ─── modelName attribution ─────────────────────────────────────────────────── + + describe('modelName attribution', () => { + it('stamps the assistant message with activeModel on ask() completion', async () => { + const onTurnComplete = vi.fn(); + const { result } = renderHook(() => + useOllama('gemma4:e2b', onTurnComplete), + ); + + await act(async () => { + await result.current.ask('hi'); + }); + + const channel = getChannel(); + act(() => { + channel!.simulateMessage({ type: 'Token', data: 'hello' }); + channel!.simulateMessage({ type: 'Done' }); + }); + + const [, assistantMsg] = onTurnComplete.mock.calls[0]; + expect(assistantMsg.modelName).toBe('gemma4:e2b'); + expect(result.current.messages[1]).toMatchObject({ + role: 'assistant', + modelName: 'gemma4:e2b', + }); + }); + + it('leaves modelName undefined when activeModel is an empty string', async () => { + const onTurnComplete = vi.fn(); + const { result } = renderHook(() => useOllama('', onTurnComplete)); + + await act(async () => { + await result.current.ask('hi'); + }); + + const channel = getChannel(); + act(() => { + channel!.simulateMessage({ type: 'Token', data: 'hello' }); + channel!.simulateMessage({ type: 'Done' }); + }); + + const [, assistantMsg] = onTurnComplete.mock.calls[0]; + expect(assistantMsg.modelName).toBeUndefined(); + }); + + it('stamps the assistant message with activeModel on askSearch() turns', async () => { + const onTurnComplete = vi.fn(); + const { result } = renderHook(() => + useOllama('qwen2.5:7b', onTurnComplete), + ); + + let pending: Promise | undefined; + await act(async () => { + pending = result.current.askSearch('rust async'); + }); + + const channel = getChannel(); + act(() => { + channel!.simulateMessage({ type: 'Token', content: 'answer' }); + channel!.simulateMessage({ type: 'Done' }); + }); + + await act(async () => { + await pending; + }); + + const [, assistantMsg] = onTurnComplete.mock.calls[0]; + expect(assistantMsg.modelName).toBe('qwen2.5:7b'); + }); + }); + // ─── loadMessages() ────────────────────────────────────────────────────────── describe('loadMessages()', () => { it('replaces messages state with provided array', async () => { - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); await act(async () => { await result.current.ask('original question'); @@ -934,7 +1005,7 @@ describe('useOllama', () => { it('clears generating state when loading messages', async () => { invoke.mockRejectedValueOnce(new Error('boom')); - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); await act(async () => { await result.current.ask('fail'); @@ -953,7 +1024,7 @@ describe('useOllama', () => { describe('ThinkingToken handling', () => { it('marks the assistant placeholder as a /think turn when think is true', async () => { - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); await act(async () => { await result.current.ask('hello', undefined, undefined, true); @@ -966,7 +1037,7 @@ describe('useOllama', () => { }); it('accumulates ThinkingTokens into thinkingContent', async () => { - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); await act(async () => { await result.current.ask('hello', undefined, undefined, true); @@ -987,7 +1058,7 @@ describe('useOllama', () => { }); it('passes think parameter to invoke', async () => { - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); await act(async () => { await result.current.ask('hello', undefined, undefined, true); @@ -1002,7 +1073,7 @@ describe('useOllama', () => { }); it('passes think as false by default', async () => { - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); await act(async () => { await result.current.ask('hello'); @@ -1018,7 +1089,7 @@ describe('useOllama', () => { it('includes thinkingContent in onTurnComplete on Done', async () => { const onTurnComplete = vi.fn(); - const { result } = renderHook(() => useOllama(onTurnComplete)); + const { result } = renderHook(() => useOllama('', onTurnComplete)); await act(async () => { await result.current.ask('hello', undefined, undefined, true); @@ -1043,7 +1114,7 @@ describe('useOllama', () => { it('does not set thinkingContent when no thinking happened', async () => { const onTurnComplete = vi.fn(); - const { result } = renderHook(() => useOllama(onTurnComplete)); + const { result } = renderHook(() => useOllama('', onTurnComplete)); await act(async () => { await result.current.ask('hello'); @@ -1061,7 +1132,7 @@ describe('useOllama', () => { }); it('preserves thinking content when cancelled with thinking but no regular tokens', async () => { - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); await act(async () => { await result.current.ask('hello', undefined, undefined, true); @@ -1091,7 +1162,7 @@ describe('useOllama', () => { describe('history', () => { it('maintains message history across multiple sequential asks', async () => { - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); // First ask + response await act(async () => { @@ -1133,7 +1204,7 @@ describe('useOllama', () => { describe('askSearch()', () => { it('invokes search_pipeline with the trimmed query', async () => { - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); let pending!: Promise<{ final: boolean }>; await act(async () => { pending = result.current.askSearch(' rust async '); @@ -1153,7 +1224,7 @@ describe('useOllama', () => { }); it('stores quotedText on the /search user message when provided', async () => { - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); let pending!: Promise<{ final: boolean }>; await act(async () => { pending = result.current.askSearch( @@ -1180,7 +1251,7 @@ describe('useOllama', () => { }); it('resolves immediately with final=true on empty query', async () => { - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); let outcome: { final: boolean } | undefined; await act(async () => { outcome = await result.current.askSearch(' '); @@ -1190,7 +1261,7 @@ describe('useOllama', () => { }); it('resolves with final=true when a token is received followed by Done', async () => { - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); const metadata = { iterations: [ { @@ -1233,7 +1304,7 @@ describe('useOllama', () => { it('resolves with final=false when a clarify trace is followed by question tokens and Done', async () => { const onTurnComplete = vi.fn(); - const { result } = renderHook(() => useOllama(onTurnComplete)); + const { result } = renderHook(() => useOllama('', onTurnComplete)); let pending!: Promise<{ final: boolean }>; await act(async () => { pending = result.current.askSearch('who is him'); @@ -1270,7 +1341,7 @@ describe('useOllama', () => { }); it('updates searchStage through the pipeline phases', async () => { - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); let pending!: Promise<{ final: boolean }>; await act(async () => { pending = result.current.askSearch('q'); @@ -1321,7 +1392,7 @@ describe('useOllama', () => { }); it('handles FetchingUrl, finalizes traces on IterationComplete, and ignores empty tokens', async () => { - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); let pending!: Promise<{ final: boolean }>; await act(async () => { @@ -1383,7 +1454,7 @@ describe('useOllama', () => { }); it('ignores IterationComplete events when no trace steps have started', async () => { - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); let pending!: Promise<{ final: boolean }>; await act(async () => { @@ -1418,7 +1489,7 @@ describe('useOllama', () => { }); it('drops the empty placeholder on Cancelled with no content', async () => { - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); let pending!: Promise<{ final: boolean }>; await act(async () => { pending = result.current.askSearch('q'); @@ -1438,7 +1509,7 @@ describe('useOllama', () => { }); it('keeps partial content on Cancelled after tokens arrived', async () => { - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); let pending!: Promise<{ final: boolean }>; await act(async () => { pending = result.current.askSearch('q'); @@ -1459,7 +1530,7 @@ describe('useOllama', () => { it('renders an Error event as an error bubble', async () => { const onTurnComplete = vi.fn(); - const { result } = renderHook(() => useOllama(onTurnComplete)); + const { result } = renderHook(() => useOllama('', onTurnComplete)); let pending!: Promise<{ final: boolean }>; await act(async () => { pending = result.current.askSearch('q'); @@ -1483,7 +1554,7 @@ describe('useOllama', () => { }); it('guards against concurrent invocations', async () => { - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); let firstPending!: Promise<{ final: boolean }>; await act(async () => { firstPending = result.current.askSearch('first'); @@ -1532,7 +1603,7 @@ describe('useOllama', () => { } }); - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); let firstPending!: Promise<{ final: boolean }>; let secondPending!: Promise<{ final: boolean }>; @@ -1578,7 +1649,7 @@ describe('useOllama', () => { invoke.mockImplementationOnce(async () => { throw new Error('ipc failed'); }); - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); let outcome: { final: boolean } | undefined; await act(async () => { outcome = await result.current.askSearch('q'); @@ -1611,7 +1682,7 @@ describe('useOllama', () => { } }); - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); let pending!: Promise<{ final: boolean }>; act(() => { @@ -1642,7 +1713,7 @@ describe('useOllama', () => { it('does not persist an empty turn on Done', async () => { const onTurnComplete = vi.fn(); - const { result } = renderHook(() => useOllama(onTurnComplete)); + const { result } = renderHook(() => useOllama('', onTurnComplete)); let pending!: Promise<{ final: boolean }>; await act(async () => { pending = result.current.askSearch('q'); @@ -1660,7 +1731,7 @@ describe('useOllama', () => { it('persists searchSources to the assistant message on Sources + Token + Done', async () => { const onTurnComplete = vi.fn(); - const { result } = renderHook(() => useOllama(onTurnComplete)); + const { result } = renderHook(() => useOllama('', onTurnComplete)); const metadata = { iterations: [ { @@ -1706,7 +1777,7 @@ describe('useOllama', () => { }); it('Warning event accumulates into message.searchWarnings while streaming continues', async () => { - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); let pending!: Promise<{ final: boolean }>; await act(async () => { pending = result.current.askSearch('q'); @@ -1731,7 +1802,7 @@ describe('useOllama', () => { it('askSearch accumulates warnings from Warning events into the persisted turn', async () => { const onTurnComplete = vi.fn(); - const { result } = renderHook(() => useOllama(onTurnComplete)); + const { result } = renderHook(() => useOllama('', onTurnComplete)); let pending!: Promise<{ final: boolean }>; await act(async () => { pending = result.current.askSearch('q'); @@ -1763,7 +1834,7 @@ describe('useOllama', () => { it('askSearch passes multiple warnings through in order', async () => { const onTurnComplete = vi.fn(); - const { result } = renderHook(() => useOllama(onTurnComplete)); + const { result } = renderHook(() => useOllama('', onTurnComplete)); let pending!: Promise<{ final: boolean }>; await act(async () => { pending = result.current.askSearch('q'); @@ -1793,7 +1864,7 @@ describe('useOllama', () => { }); it('Trace events accumulate steps on the assistant message', async () => { - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); let pending!: Promise<{ final: boolean }>; await act(async () => { pending = result.current.askSearch('q'); @@ -1840,7 +1911,7 @@ describe('useOllama', () => { }); it('Trace updates replace earlier steps with the same id', async () => { - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); let pending!: Promise<{ final: boolean }>; await act(async () => { pending = result.current.askSearch('q'); @@ -1889,7 +1960,7 @@ describe('useOllama', () => { it('Trace events are passed to onTurnComplete', async () => { const onTurnComplete = vi.fn(); - const { result } = renderHook(() => useOllama(onTurnComplete)); + const { result } = renderHook(() => useOllama('', onTurnComplete)); let pending!: Promise<{ final: boolean }>; await act(async () => { pending = result.current.askSearch('q'); @@ -1924,7 +1995,7 @@ describe('useOllama', () => { it('preserves completed traces on Done when no running steps need finalization', async () => { const onTurnComplete = vi.fn(); - const { result } = renderHook(() => useOllama(onTurnComplete)); + const { result } = renderHook(() => useOllama('', onTurnComplete)); let pending!: Promise<{ final: boolean }>; await act(async () => { @@ -1966,7 +2037,7 @@ describe('useOllama', () => { it('searchTraces is undefined when no Trace event is received', async () => { const onTurnComplete = vi.fn(); - const { result } = renderHook(() => useOllama(onTurnComplete)); + const { result } = renderHook(() => useOllama('', onTurnComplete)); let pending!: Promise<{ final: boolean }>; await act(async () => { pending = result.current.askSearch('q'); @@ -1988,7 +2059,7 @@ describe('useOllama', () => { describe('search state cleanup', () => { it('reset clears the search stage indicator', async () => { - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); let pending!: Promise<{ final: boolean }>; await act(async () => { pending = result.current.askSearch('q'); @@ -2011,7 +2082,7 @@ describe('useOllama', () => { }); it('loadMessages clears the search stage indicator', async () => { - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); let pending!: Promise<{ final: boolean }>; await act(async () => { pending = result.current.askSearch('q'); @@ -2034,7 +2105,7 @@ describe('useOllama', () => { }); it('Searching after RefiningSearch sets gap:true stage', async () => { - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); let pending!: Promise<{ final: boolean }>; await act(async () => { pending = result.current.askSearch('q'); @@ -2061,7 +2132,7 @@ describe('useOllama', () => { }); it('ReadingSources after RefiningSearch sets gap:true stage', async () => { - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); let pending!: Promise<{ final: boolean }>; await act(async () => { pending = result.current.askSearch('q'); @@ -2089,7 +2160,7 @@ describe('useOllama', () => { it('SandboxUnavailable event sets sandboxUnavailable on assistant message', async () => { const onTurnComplete = vi.fn(); - const { result } = renderHook(() => useOllama(onTurnComplete)); + const { result } = renderHook(() => useOllama('', onTurnComplete)); let pending!: Promise<{ final: boolean }>; await act(async () => { pending = result.current.askSearch('q'); @@ -2110,7 +2181,7 @@ describe('useOllama', () => { }); it('SandboxUnavailable event does not set errorKind', async () => { - const { result } = renderHook(() => useOllama()); + const { result } = renderHook(() => useOllama('')); let pending!: Promise<{ final: boolean }>; await act(async () => { pending = result.current.askSearch('q'); diff --git a/src/hooks/useConversationHistory.ts b/src/hooks/useConversationHistory.ts index 0e1ac8a4..d04a294b 100644 --- a/src/hooks/useConversationHistory.ts +++ b/src/hooks/useConversationHistory.ts @@ -41,6 +41,7 @@ function toPayload(msg: Message): SaveMessagePayload { : msg.searchTraces && msg.searchTraces.length > 0 ? JSON.stringify(msg.searchTraces) : null, + model_name: msg.modelName ?? null, }; } @@ -241,6 +242,7 @@ function fromPersisted(msg: PersistedMessage): Message { searchMetadata, searchTraces: searchTraces && searchTraces.length > 0 ? searchTraces : undefined, + modelName: msg.model_name ?? undefined, }; } @@ -266,18 +268,17 @@ export function useConversationHistory() { * Subsequent calls while `isSaved` is true are no-ops - the bookmark * icon on the frontend enforces single-save semantics. * - * The active model name is sourced by the Rust `save_conversation` command - * from the managed `AppConfig` state; the frontend no longer tracks or - * forwards it. - * - * Fires `generate_title` as a fire-and-forget background task after saving; - * the frontend should schedule a `listConversations` refresh to pick up the - * AI-generated title once it arrives (~2-5 seconds). + * Fires `generate_title` as a fire-and-forget background task after saving, + * threading the active model slug through so the title is produced by the + * same model that produced the conversation. The frontend should schedule a + * `listConversations` refresh to pick up the AI-generated title once it + * arrives (~2-5 seconds). * * @param messages The complete message history to persist. + * @param model The active Ollama model slug used for title generation. */ const save = useCallback( - async (messages: Message[]): Promise => { + async (messages: Message[], model: string): Promise => { if (isSaved) return; const payloads = messages.map(toPayload); @@ -296,6 +297,7 @@ export function useConversationHistory() { void invoke('generate_title', { conversationId: response.conversation_id, messages: payloads, + model, }); }, [isSaved], @@ -324,6 +326,7 @@ export function useConversationHistory() { searchSources: null, searchWarnings: null, searchMetadata: null, + modelName: null, }), invoke('persist_message', { conversationId, @@ -345,6 +348,7 @@ export function useConversationHistory() { assistantMsg.searchTraces.length > 0 ? JSON.stringify(assistantMsg.searchTraces) : null, + modelName: assistantMsg.modelName ?? null, }), ]); }, diff --git a/src/hooks/useModelCapabilities.ts b/src/hooks/useModelCapabilities.ts new file mode 100644 index 00000000..3198cd91 --- /dev/null +++ b/src/hooks/useModelCapabilities.ts @@ -0,0 +1,88 @@ +import { useCallback, useEffect, useRef, useState } from 'react'; +import { invoke } from '@tauri-apps/api/core'; +import type { ModelCapabilities, ModelCapabilitiesMap } from '../types/model'; + +/** + * Runtime guard that the IPC payload is a `{ [name]: Capabilities }` map. + * Mirrors the defensive shape check in `useModelSelection` so the hook + * stays robust against a backend / mock that returns the wrong shape. + */ +function isCapabilities(value: unknown): value is ModelCapabilities { + if (typeof value !== 'object' || value === null) return false; + const candidate = value as Record; + return ( + typeof candidate.vision === 'boolean' && + typeof candidate.thinking === 'boolean' + ); +} + +function isCapabilitiesMap(value: unknown): value is ModelCapabilitiesMap { + if (typeof value !== 'object' || value === null) return false; + return Object.values(value).every(isCapabilities); +} + +/** Shape returned by {@link useModelCapabilities}. */ +export interface UseModelCapabilitiesResult { + /** + * Map of model slug to its capability flags. Empty until the first + * fetch resolves or if the backend rejects. + */ + capabilities: ModelCapabilitiesMap; + /** + * Re-fetches the capabilities map. Callers are the single trigger: + * the hook fetches once on mount and never auto-retries. + */ + refresh: () => Promise; +} + +/** + * React hook that pulls the per-model capability map from the Rust + * `get_model_capabilities` Tauri command. Used by the picker to render + * capability labels and by the submit gate to refuse messages whose + * attached content does not match the active model's capabilities. + * + * The same monotonic-token pattern as `useModelSelection` keeps rapid + * out-of-order responses from overwriting newer state and drops + * resolutions that fire after unmount. + */ +export function useModelCapabilities(): UseModelCapabilitiesResult { + const [capabilities, setCapabilities] = useState({}); + const mountedRef = useRef(true); + const latestTokenRef = useRef(0); + + useEffect(() => { + mountedRef.current = true; + return () => { + mountedRef.current = false; + }; + }, []); + + const isLatest = useCallback( + (token: number): boolean => + mountedRef.current && token === latestTokenRef.current, + [], + ); + + const refresh = useCallback(async (): Promise => { + latestTokenRef.current += 1; + const token = latestTokenRef.current; + try { + const payload = await invoke('get_model_capabilities'); + if (!isLatest(token)) return; + if (!isCapabilitiesMap(payload)) { + setCapabilities({}); + return; + } + setCapabilities(payload); + } catch { + if (!isLatest(token)) return; + setCapabilities({}); + } + }, [isLatest]); + + useEffect(() => { + void refresh(); + }, [refresh]); + + return { capabilities, refresh }; +} diff --git a/src/hooks/useModelSelection.ts b/src/hooks/useModelSelection.ts new file mode 100644 index 00000000..97df0bcb --- /dev/null +++ b/src/hooks/useModelSelection.ts @@ -0,0 +1,119 @@ +import { useCallback, useEffect, useRef, useState } from 'react'; +import { invoke } from '@tauri-apps/api/core'; +import type { ModelPickerState } from '../types/model'; + +/** + * Runtime guard for the IPC boundary. The Rust backend is trusted, but this + * keeps the hook robust against shape drift (schema changes, legacy builds, + * mocks) without pulling in a schema library. + */ +function isModelPickerState(value: unknown): value is ModelPickerState { + if (typeof value !== 'object' || value === null) return false; + const candidate = value as { active?: unknown; all?: unknown }; + return ( + typeof candidate.active === 'string' && + Array.isArray(candidate.all) && + candidate.all.every((entry) => typeof entry === 'string') + ); +} + +/** + * Shape returned by {@link useModelSelection}. + */ +export interface UseModelSelectionResult { + /** The currently active Ollama model name. Empty string until loaded. */ + activeModel: string; + /** All locally installed Ollama model names available for selection. */ + availableModels: string[]; + /** + * Re-fetch the model picker state from the backend. Clears both + * `activeModel` and `availableModels` when the backend returns a malformed + * payload or the call rejects. Callers are the single trigger: this hook + * does not auto-retry. + */ + refreshModels: () => Promise; + /** + * Persist a new active model through the backend and sync local state + * after the backend acknowledges the change. Rejects with the backend + * error string so callers can surface the failure and trigger a refresh + * to resync the UI. + */ + setActiveModel: (model: string) => Promise; +} + +/** + * React hook that manages the active Ollama model selection. Loads the + * current model + the installed model list from the Rust backend on mount, + * and exposes imperative helpers for refresh and selection. + * + * Request serialization: every refresh and selection increments a monotonic + * token. Resolutions that belong to a stale token are dropped so rapid + * out-of-order responses cannot overwrite newer state. Resolutions that fire + * after unmount are also dropped to avoid React 18 StrictMode warnings. + */ +export function useModelSelection(): UseModelSelectionResult { + // The state setter is intentionally renamed because `setActiveModel` is the + // public async callback returned by this hook. + // eslint-disable-next-line @eslint-react/use-state + const [activeModel, setActiveModelState] = useState(''); + const [availableModels, setAvailableModels] = useState([]); + + const mountedRef = useRef(true); + const latestTokenRef = useRef(0); + + useEffect(() => { + mountedRef.current = true; + return () => { + mountedRef.current = false; + }; + }, []); + + const isLatest = useCallback((token: number): boolean => { + return mountedRef.current && token === latestTokenRef.current; + }, []); + + const refreshModels = useCallback(async (): Promise => { + latestTokenRef.current += 1; + const token = latestTokenRef.current; + try { + const state = await invoke('get_model_picker_state'); + if (!isLatest(token)) return; + if (!isModelPickerState(state)) { + setActiveModelState(''); + setAvailableModels([]); + return; + } + setActiveModelState(state.active); + setAvailableModels(state.all); + } catch { + if (!isLatest(token)) return; + setActiveModelState(''); + setAvailableModels([]); + } + }, [isLatest]); + + useEffect(() => { + void refreshModels(); + }, [refreshModels]); + + const setActiveModel = useCallback( + async (model: string): Promise => { + latestTokenRef.current += 1; + const token = latestTokenRef.current; + try { + await invoke('set_active_model', { model }); + } catch (err) { + if (isLatest(token)) { + throw err; + } + return; + } + if (isLatest(token)) { + setActiveModelState(model); + } + }, + [isLatest], + ); + + return { activeModel, availableModels, refreshModels, setActiveModel }; +} diff --git a/src/hooks/useOllama.ts b/src/hooks/useOllama.ts index 3e86e0c2..f292864e 100644 --- a/src/hooks/useOllama.ts +++ b/src/hooks/useOllama.ts @@ -18,6 +18,10 @@ export interface Message { id: string; role: 'user' | 'assistant'; content: string; + /** Ollama model slug attributed to this assistant message at creation time. + * Remains stable even if the user switches models mid-stream. Undefined for + * user messages and for legacy conversations loaded from pre-migration rows. */ + modelName?: string; /** Selected text from the host app that was quoted with this message, if any. */ quotedText?: string; /** Absolute file paths of images attached to this message, if any. */ @@ -120,8 +124,17 @@ function finalizeSearchTraceSteps( * * Manages message history, streaming state, and the Tauri IPC channels used by * both the normal chat path and the `/search` pipeline. + * + * @param activeModel Ollama model slug that should be attributed to each + * assistant message produced by this hook. Passed as a hook parameter (not + * a per-call argument) so the latest App-level selection is captured via + * closure on every render. An empty string (briefly possible on startup, + * before the model list resolves) is coerced to `undefined` on the emitted + * `Message`, so no attribution chip is rendered rather than a blank one. + * @param onTurnComplete Optional callback invoked after each completed turn. */ export function useOllama( + activeModel: string, onTurnComplete?: (userMsg: Message, assistantMsg: Message) => void, ) { const [messages, setMessages] = useState([]); @@ -220,6 +233,7 @@ export function useOllama( role: 'assistant', content: '', fromThink: think ? true : undefined, + modelName: activeModel || undefined, }; setMessages((prev) => [...prev, userMsg, assistantMsg]); @@ -336,7 +350,7 @@ export function useOllama( setSearchStage(null); } }, - [onTurnComplete], + [onTurnComplete, activeModel], ); /** @@ -374,6 +388,7 @@ export function useOllama( role: 'assistant', content: '', fromSearch: true, + modelName: activeModel || undefined, }; setMessages((prev) => [...prev, userMsg, assistantMsg]); @@ -564,7 +579,7 @@ export function useOllama( }); }); }, - [onTurnComplete], + [onTurnComplete, activeModel], ); /** Cancels the currently active generation. */ diff --git a/src/types/history.ts b/src/types/history.ts index fd8360f2..182aa8c7 100644 --- a/src/types/history.ts +++ b/src/types/history.ts @@ -45,6 +45,9 @@ export interface PersistedMessage { * `SearchTraceStep[]` or legacy iteration traces. Null for non-search * messages. */ search_metadata: string | null; + /** Ollama model slug attributed to this message. Null for user messages + * and legacy messages written before the model_name migration. */ + model_name: string | null; /** Unix timestamp (seconds) the message was created. */ created_at: number; } @@ -71,4 +74,7 @@ export interface SaveMessagePayload { search_warnings: string | null; /** Pre-serialized JSON string of SearchMetadata or a legacy trace payload. */ search_metadata: string | null; + /** Ollama model slug that produced this response. Null for user messages + * and messages from pre-migration conversations. */ + model_name: string | null; } diff --git a/src/types/model.ts b/src/types/model.ts new file mode 100644 index 00000000..4c449e3d --- /dev/null +++ b/src/types/model.ts @@ -0,0 +1,44 @@ +/* v8 ignore file -- type-only declarations, no runtime code */ + +/** + * Snapshot of model picker state returned by the Rust + * `get_model_picker_state` Tauri command. + * + * - `active` is the currently selected Ollama model name. Never empty once + * the backend has completed startup seeding. + * - `all` is the full list of locally installed Ollama model names, in the + * order the backend chose to surface them (typically matches `ollama list`). + */ +export interface ModelPickerState { + /** The currently active Ollama model name. */ + active: string; + /** All locally installed Ollama model names available for selection. */ + all: string[]; +} + +/** + * Per-model capability flags returned by the Rust `get_model_capabilities` + * Tauri command. Mirrors the `Capabilities` struct in `src-tauri/src/models.rs`. + */ +export interface ModelCapabilities { + vision: boolean; + thinking: boolean; + /** + * Maximum number of images the model accepts in a single request, when + * known. `null` (or omitted) means Thuki has no architecture-specific + * cap and trusts Ollama's runner as the final authority. Today this is + * set to `1` for `mllama`-family models (e.g. llama3.2-vision) which + * reject multi-image requests with HTTP 500. + */ + maxImages?: number | null; +} + +/** + * Map of model slug to its capabilities. Built from the Rust command's + * `HashMap` payload. + * + * Modelled as `Partial>` so that lookups on unknown slugs + * yield `undefined` instead of being silently typed as `ModelCapabilities`. + * Every consumer is forced to handle the missing-metadata case. + */ +export type ModelCapabilitiesMap = Partial>; diff --git a/src/utils/__tests__/capabilityConflicts.test.ts b/src/utils/__tests__/capabilityConflicts.test.ts new file mode 100644 index 00000000..45e0eaef --- /dev/null +++ b/src/utils/__tests__/capabilityConflicts.test.ts @@ -0,0 +1,255 @@ +import { describe, it, expect } from 'vitest'; +import { getCapabilityConflict } from '../capabilityConflicts'; +import type { ModelCapabilities } from '../../types/model'; +import type { ComposeCapabilityState } from '../capabilityConflicts'; + +const VISION: ModelCapabilities = { + vision: true, + thinking: false, + maxImages: null, +}; +const VISION_SINGLE_IMAGE: ModelCapabilities = { + vision: true, + thinking: false, + maxImages: 1, +}; +const VISION_TWO_IMAGES: ModelCapabilities = { + vision: true, + thinking: false, + maxImages: 2, +}; +const TEXT_ONLY: ModelCapabilities = { + vision: false, + thinking: false, + maxImages: null, +}; +const THINKING_ONLY: ModelCapabilities = { + vision: false, + thinking: true, + maxImages: null, +}; +const VISION_AND_THINKING: ModelCapabilities = { + vision: true, + thinking: true, + maxImages: null, +}; + +const EMPTY: ComposeCapabilityState = { + hasScreenCommand: false, + hasThinkCommand: false, + imageCount: 0, +}; + +describe('getCapabilityConflict', () => { + it('returns null when nothing is queued', () => { + expect(getCapabilityConflict('llama3', TEXT_ONLY, EMPTY)).toBeNull(); + }); + + it('returns null when capabilities are unknown (defaults permissive)', () => { + const result = getCapabilityConflict('llama3', undefined, { + ...EMPTY, + imageCount: 1, + }); + expect(result).toBeNull(); + }); + + it('returns null when capabilities is null', () => { + const result = getCapabilityConflict('llama3', null, { + ...EMPTY, + imageCount: 1, + }); + expect(result).toBeNull(); + }); + + it('returns null when active model can see images and has no max-images cap', () => { + const result = getCapabilityConflict('llava', VISION, { + ...EMPTY, + hasScreenCommand: true, + imageCount: 3, + }); + expect(result).toBeNull(); + }); + + it('returns conflict when images attached and model is text-only', () => { + const result = getCapabilityConflict('llama3', TEXT_ONLY, { + ...EMPTY, + imageCount: 1, + }); + expect(result).toBe( + 'llama3 reads text only. Try a vision model for images.', + ); + }); + + it('returns conflict when /screen is queued and model is text-only', () => { + const result = getCapabilityConflict('llama3', TEXT_ONLY, { + ...EMPTY, + hasScreenCommand: true, + }); + expect(result).toContain('reads text only'); + }); + + it('falls back to a generic name when model name is empty', () => { + const result = getCapabilityConflict('', TEXT_ONLY, { + ...EMPTY, + imageCount: 1, + }); + expect(result).toBe( + 'this model reads text only. Try a vision model for images.', + ); + }); + + it('falls back to a generic name when model name is null', () => { + const result = getCapabilityConflict(null, TEXT_ONLY, { + ...EMPTY, + imageCount: 1, + }); + expect(result?.startsWith('this model')).toBe(true); + }); + + it('falls back to a generic name when model name is undefined', () => { + const result = getCapabilityConflict(undefined, TEXT_ONLY, { + ...EMPTY, + imageCount: 1, + }); + expect(result?.startsWith('this model')).toBe(true); + }); + + // ── max-images gate ─────────────────────────────────────────────────────── + + it('returns null when single-image vision model has exactly one image', () => { + const result = getCapabilityConflict( + 'llama3.2-vision', + VISION_SINGLE_IMAGE, + { ...EMPTY, imageCount: 1 }, + ); + expect(result).toBeNull(); + }); + + it('refuses two attached images on a single-image vision model', () => { + const result = getCapabilityConflict( + 'llama3.2-vision', + VISION_SINGLE_IMAGE, + { ...EMPTY, imageCount: 2 }, + ); + expect(result).toBe( + 'llama3.2-vision accepts one image at a time. Remove the extras to send.', + ); + }); + + it('counts /screen as one image toward the cap', () => { + // Single-image vision model + one attached image + /screen queued = + // effective count of 2, exceeds the cap of 1. + const result = getCapabilityConflict( + 'llama3.2-vision', + VISION_SINGLE_IMAGE, + { ...EMPTY, hasScreenCommand: true, imageCount: 1 }, + ); + expect(result).toBe( + 'llama3.2-vision accepts one image at a time. Remove the extras to send.', + ); + }); + + it('allows /screen alone on a single-image vision model', () => { + const result = getCapabilityConflict( + 'llama3.2-vision', + VISION_SINGLE_IMAGE, + { ...EMPTY, hasScreenCommand: true }, + ); + expect(result).toBeNull(); + }); + + it('pluralizes the noun for a multi-image cap', () => { + const result = getCapabilityConflict('multi-cap', VISION_TWO_IMAGES, { + ...EMPTY, + imageCount: 5, + }); + expect(result).toBe( + 'multi-cap accepts 2 images at a time. Remove the extras to send.', + ); + }); + + it('allows submits at the cap exactly', () => { + const result = getCapabilityConflict('multi-cap', VISION_TWO_IMAGES, { + ...EMPTY, + imageCount: 2, + }); + expect(result).toBeNull(); + }); + + it('ignores a max-images cap below 1 (defensive)', () => { + const odd: ModelCapabilities = { + vision: true, + thinking: false, + maxImages: 0, + }; + const result = getCapabilityConflict('odd', odd, { + ...EMPTY, + imageCount: 3, + }); + expect(result).toBeNull(); + }); + + // ── /think gate ─────────────────────────────────────────────────────────── + + it('refuses /think on a non-thinking model', () => { + const result = getCapabilityConflict('llama3', TEXT_ONLY, { + ...EMPTY, + hasThinkCommand: true, + }); + expect(result).toBe( + "llama3 doesn't show reasoning. Try a thinking model for /think.", + ); + }); + + it('allows /think on a thinking-capable model', () => { + const result = getCapabilityConflict('reasoner', THINKING_ONLY, { + ...EMPTY, + hasThinkCommand: true, + }); + expect(result).toBeNull(); + }); + + it('falls back to a generic name when /think mismatches and name is empty', () => { + const result = getCapabilityConflict('', TEXT_ONLY, { + ...EMPTY, + hasThinkCommand: true, + }); + expect(result).toBe( + "this model doesn't show reasoning. Try a thinking model for /think.", + ); + }); + + it('prefers the vision message when /think and images both mismatch', () => { + // Vision is the more fundamental constraint and recovery from it + // (switching to a vision model) is also more likely to satisfy the + // /think requirement than the other way around. + const result = getCapabilityConflict('llama3', TEXT_ONLY, { + ...EMPTY, + imageCount: 1, + hasThinkCommand: true, + }); + expect(result).toBe( + 'llama3 reads text only. Try a vision model for images.', + ); + }); + + it('still fires the /think gate when vision is satisfied but thinking is not', () => { + const result = getCapabilityConflict('llava', VISION, { + ...EMPTY, + imageCount: 1, + hasThinkCommand: true, + }); + expect(result).toBe( + "llava doesn't show reasoning. Try a thinking model for /think.", + ); + }); + + it('returns null when both vision and thinking are satisfied', () => { + const result = getCapabilityConflict('omnimodel', VISION_AND_THINKING, { + ...EMPTY, + imageCount: 1, + hasThinkCommand: true, + }); + expect(result).toBeNull(); + }); +}); diff --git a/src/utils/capabilityConflicts.ts b/src/utils/capabilityConflicts.ts new file mode 100644 index 00000000..2d03ba74 --- /dev/null +++ b/src/utils/capabilityConflicts.ts @@ -0,0 +1,82 @@ +import type { ModelCapabilities } from '../types/model'; + +/** + * Compose-state inputs the gate inspects. `imageCount` covers manually + * attached + pasted + dragged images. `hasScreenCommand` covers the + * `/screen` slash command (which produces an image after capture and so + * has the same vision-required constraint as a non-zero imageCount). + * `hasThinkCommand` covers the `/think` slash command, which requires a + * model that emits reasoning tokens for the ThinkingBlock UI to render + * anything meaningful. + */ +export interface ComposeCapabilityState { + /** True if the message contains the `/screen` slash command. */ + hasScreenCommand: boolean; + /** True if the message contains the `/think` slash command. */ + hasThinkCommand: boolean; + /** + * Number of images attached to the compose state. Used by the + * max-images gate to refuse multi-image submits to single-image + * vision models (e.g. llama3.2-vision). The `/screen` command adds + * exactly one image at capture time so callers should fold it into + * this count when both are true. + */ + imageCount: number; +} + +/** + * Returns a single human-readable reason why the active model cannot + * send the current compose state, or `null` if the message is sendable. + * + * The strip and the submit-time toast both render the returned string + * verbatim so the wording lives in exactly one place. + * + * Defaults to permissive: an unknown active model (capabilities not yet + * fetched, or fetch failed) returns `null` so the user is never blocked + * by missing metadata. The backend is the final authority and will + * surface a real error if the model truly cannot accept the payload. + */ +export function getCapabilityConflict( + modelName: string | undefined | null, + capabilities: ModelCapabilities | undefined | null, + state: ComposeCapabilityState, +): string | null { + const needsVision = state.imageCount > 0 || state.hasScreenCommand; + const needsThinking = state.hasThinkCommand; + if (!needsVision && !needsThinking) return null; + if (!capabilities) return null; + const name = modelName && modelName.length > 0 ? modelName : 'this model'; + + // Vision is checked first when both apply because it is the more + // fundamental constraint: a text-only model cannot consume the image + // payload at all, while /think on a non-thinking model just degrades + // to a normal answer. Picking the vision message keeps the user + // pointed at the action that unblocks the most. + if (needsVision) { + if (!capabilities.vision) { + return `${name} reads text only. Try a vision model for images.`; + } + // Vision model, but it may cap the number of images per request + // (today: mllama-family models such as llama3.2-vision are 1-image + // only). Fold the /screen command into the effective count so a + // queued capture counts toward the cap exactly like an attached + // image. + const max = capabilities.maxImages; + if (max != null && max >= 1) { + const effective = state.imageCount + (state.hasScreenCommand ? 1 : 0); + if (effective > max) { + const noun = max === 1 ? 'one image' : `${max} images`; + return `${name} accepts ${noun} at a time. Remove the extras to send.`; + } + } + } + + // /think requires a model that emits reasoning tokens; otherwise the + // command is silently ignored and the user gets a normal answer with + // no ThinkingBlock, which feels broken. Surface the mismatch instead. + if (needsThinking && !capabilities.thinking) { + return `${name} doesn't show reasoning. Try a thinking model for /think.`; + } + + return null; +} diff --git a/src/view/AskBarView.tsx b/src/view/AskBarView.tsx index e4bd6ba1..d7a8bab3 100644 --- a/src/view/AskBarView.tsx +++ b/src/view/AskBarView.tsx @@ -5,7 +5,9 @@ import { formatQuotedText } from '../utils/formatQuote'; import { useConfig } from '../contexts/ConfigContext'; import { ImageThumbnails } from '../components/ImageThumbnails'; import { CommandSuggestion } from '../components/CommandSuggestion'; +import { ModelPicker } from '../components/ModelPicker'; import { Tooltip } from '../components/Tooltip'; +import { CapabilityMismatchStrip } from '../components/CapabilityMismatchStrip'; import type { AttachedImage } from '../types/image'; import { MAX_IMAGE_SIZE_BYTES } from '../types/image'; import { COMMANDS } from '../config/commands'; @@ -95,8 +97,8 @@ const BORDER_TRACE_RING = ( /** Hoisted static history (clock) icon - prevents re-allocation on every render. */ const HISTORY_ICON = ( ); -/** - * Renders text with command triggers highlighted in violet for the mirror div. - * Only the first occurrence of each command is highlighted; duplicates render plain. - */ -export function renderHighlightedText(text: string): React.ReactNode { - const parts: React.ReactNode[] = []; - let remaining = text; - const highlighted = new Set(); - - while (remaining.length > 0) { - let earliest = -1; - let matchedTrigger = ''; - for (const cmd of COMMANDS) { - if (highlighted.has(cmd.trigger)) continue; - const idx = remaining.indexOf(cmd.trigger); - if (idx !== -1 && (earliest === -1 || idx < earliest)) { - const before = idx === 0 || remaining[idx - 1] === ' '; - const after = - idx + cmd.trigger.length >= remaining.length || - remaining[idx + cmd.trigger.length] === ' '; - if (before && after) { - earliest = idx; - matchedTrigger = cmd.trigger; - } - } - } - - if (earliest === -1) { - parts.push({remaining}); - break; - } - - if (earliest > 0) { - parts.push( - {remaining.slice(0, earliest)}, - ); - } - parts.push( - - {matchedTrigger} - , - ); - highlighted.add(matchedTrigger); - remaining = remaining.slice(earliest + matchedTrigger.length); - } - - return <>{parts}; -} - /** * Maximum number of manually attached images per message. The backend allows * one additional image from /screen capture, for a total of 4 per message @@ -236,6 +189,80 @@ interface AskBarViewProps { * "normal" = violet ring; "max" = red ring + label; undefined = no ring. */ isDragOver?: 'normal' | 'max'; + /** Currently active Ollama model slug. Enables the model picker when set. */ + activeModel?: string; + /** Full list of model slugs available for selection in the picker. */ + availableModels?: string[]; + /** + * Called when the user clicks the model picker trigger. App.tsx owns the + * open/close state and renders the ModelPickerPanel as an inline drawer. + */ + onModelPickerToggle?: () => void; + /** Whether the model picker panel is currently open (drives aria-expanded). */ + isModelPickerOpen?: boolean; + /** + * Capability mismatch message to render between the attachments row and + * the input. `null` (or undefined) renders nothing. The host computes + * this string via `getCapabilityConflict` and passes it down. + */ + capabilityConflictMessage?: string | null; + /** + * When true, the input row plays a brief horizontal shake animation. + * The host pulses this true / false to signal a refused submit. + */ + shake?: boolean; +} + +/** + * Renders text with command triggers highlighted in violet for the mirror div. + * Only the first occurrence of each command is highlighted; duplicates render + * as plain text. Word-boundary aware: `/searching` does not match `/search`. + * + * Exported for direct unit testing. + */ +export function renderHighlightedText(text: string): React.ReactNode { + const parts: React.ReactNode[] = []; + let remaining = text; + const highlighted = new Set(); + + while (remaining.length > 0) { + let earliest = -1; + let matchedTrigger = ''; + for (const cmd of COMMANDS) { + if (highlighted.has(cmd.trigger)) continue; + const idx = remaining.indexOf(cmd.trigger); + if (idx !== -1 && (earliest === -1 || idx < earliest)) { + const before = idx === 0 || remaining[idx - 1] === ' '; + const after = + idx + cmd.trigger.length >= remaining.length || + remaining[idx + cmd.trigger.length] === ' '; + if (before && after) { + earliest = idx; + matchedTrigger = cmd.trigger; + } + } + } + + if (earliest === -1) { + parts.push({remaining}); + break; + } + + if (earliest > 0) { + parts.push( + {remaining.slice(0, earliest)}, + ); + } + parts.push( + + {matchedTrigger} + , + ); + highlighted.add(matchedTrigger); + remaining = remaining.slice(earliest + matchedTrigger.length); + } + + return <>{parts}; } /** @@ -261,12 +288,28 @@ export function AskBarView({ onImagePreview, onScreenshot, isDragOver, + activeModel, + availableModels, + onModelPickerToggle, + isModelPickerOpen, + capabilityConflictMessage, + shake = false, }: AskBarViewProps) { + /** Quote display limits resolved from the managed AppConfig. */ + const quote = useConfig().quote; + /** Ref to the mirror div behind the textarea for command highlighting. */ const mirrorRef = useRef(null); - /** Quote display limits resolved from the managed AppConfig. */ - const quote = useConfig().quote; + /** Syncs the mirror div scroll position with the textarea so the colored + * spans stay aligned with the caret on long inputs. */ + const handleTextareaScroll = useCallback(() => { + /* v8 ignore start -- both refs are always set by React when this fires */ + if (!mirrorRef.current || !inputRef.current) return; + /* v8 ignore stop */ + mirrorRef.current.scrollTop = inputRef.current.scrollTop; + mirrorRef.current.scrollLeft = inputRef.current.scrollLeft; + }, [inputRef]); /** True when the UI should be locked - either generating or waiting for images. */ const isBusy = isGenerating || isSubmitPending; @@ -283,6 +326,20 @@ export function AskBarView({ return () => clearTimeout(timer); }, [pasteMaxError]); + // ─── Model picker availability gate ─────────────────────────────────────── + + /** + * Prerequisites for rendering the chip trigger in the input bar. + * Hidden in chat mode — the pill trigger moves to the WindowControls header. + */ + const modelPickerAvailable = Boolean( + !isChatMode && + activeModel && + availableModels && + availableModels.length > 0 && + onModelPickerToggle, + ); + // ─── Command suggestion state ───────────────────────────────────────────── /** @@ -473,14 +530,6 @@ export function AskBarView({ ], ); - /** Syncs the mirror div scroll position with the textarea. */ - const handleTextareaScroll = useCallback(() => { - /* v8 ignore start -- both refs are always set by React when this fires */ - if (!mirrorRef.current || !inputRef.current) return; - /* v8 ignore stop */ - mirrorRef.current.scrollTop = inputRef.current.scrollTop; - }, [inputRef]); - /** Handles clipboard paste - extracts image items from clipboardData. */ const handlePaste = useCallback( (e: React.ClipboardEvent) => { @@ -556,6 +605,9 @@ export function AskBarView({ />
      )} + {capabilityConflictMessage && ( + + )} {/* Command suggestion renders above the input row in the normal DOM flow. Being inside the morphing container means the ResizeObserver detects the added height and grows the native window upward to reveal @@ -582,7 +634,14 @@ export function AskBarView({ )} -
      +
      {HISTORY_ICON} )}
      - {/* Mirror div: renders the same text with highlighted commands. - Sits behind the transparent textarea so colored spans show through. */} + {/* Mirror div: renders the same text with highlighted slash + commands. Sits behind the transparent textarea so colored + spans show through. Metrics (font, size, padding, leading, + wrap) MUST mirror the textarea exactly so the caret never + drifts off the rendered glyphs. */} @@ -627,7 +690,7 @@ export function AskBarView({ autoFocus rows={1} placeholder={isChatMode ? 'Reply...' : 'Ask Thuki anything...'} - className="relative w-full bg-transparent border-none outline-none text-transparent text-sm placeholder:text-text-secondary py-2 px-1 disabled:opacity-50 resize-none leading-relaxed" + className="askbar-textarea relative w-full bg-transparent border-none outline-none text-transparent text-sm placeholder:text-text-secondary py-2 px-1 disabled:opacity-50 resize-none leading-5" style={{ caretColor: 'var(--color-text-primary)' }} />
      @@ -651,13 +714,23 @@ export function AskBarView({ onClick={onScreenshot} disabled={isBusy} aria-label="Take screenshot" - className="shrink-0 w-7 h-7 flex items-center justify-center rounded-lg text-text-secondary hover:text-text-primary hover:bg-white/8 transition-colors duration-150 disabled:opacity-40 disabled:cursor-default cursor-pointer" + className="shrink-0 w-7 h-7 flex items-center justify-center rounded-lg text-text-secondary hover:text-primary hover:bg-primary/10 transition-colors duration-150 disabled:opacity-40 disabled:cursor-default cursor-pointer" > {CAMERA_ICON} )} + {modelPickerAvailable && onModelPickerToggle && ( + + + + )} +
      -
      +
      ); } diff --git a/src/view/ConversationView.tsx b/src/view/ConversationView.tsx index 15bf40b2..f3c5badc 100644 --- a/src/view/ConversationView.tsx +++ b/src/view/ConversationView.tsx @@ -74,6 +74,12 @@ interface ConversationViewProps { * of the typing indicator. */ searchStage?: SearchStage; + /** Currently active model slug forwarded to the WindowControls pill trigger. */ + activeModel?: string; + /** Toggles the model picker panel; forwarded to WindowControls. */ + onModelPickerToggle?: () => void; + /** Whether the model picker panel is open; drives aria-expanded on the pill. */ + isModelPickerOpen?: boolean; } /** @@ -97,6 +103,9 @@ export function ConversationView({ onNewConversation, onImagePreview, searchStage = null, + activeModel, + onModelPickerToggle, + isModelPickerOpen, }: ConversationViewProps) { const scrollContainerRef = useRef(null); @@ -197,6 +206,9 @@ export function ConversationView({ canSave={canSave} onNewConversation={onNewConversation} onHistoryOpen={onHistoryOpen} + activeModel={activeModel} + onModelPickerToggle={onModelPickerToggle} + isModelPickerOpen={isModelPickerOpen} />
      { @@ -212,6 +212,131 @@ describe('AskBarView', () => { ).toBeInTheDocument(); }); + it('renders a model picker trigger in ask-bar mode when models are available', () => { + render( + , + ); + expect( + screen.getByRole('button', { name: 'Choose model' }), + ).toBeInTheDocument(); + }); + + it('hides model picker trigger in chat mode (trigger moves to WindowControls header)', () => { + render( + , + ); + expect(screen.queryByRole('button', { name: 'Choose model' })).toBeNull(); + }); + + it('calls onModelPickerToggle when the Choose model button is clicked', () => { + const onModelPickerToggle = vi.fn(); + render( + , + ); + fireEvent.click(screen.getByRole('button', { name: 'Choose model' })); + expect(onModelPickerToggle).toHaveBeenCalledTimes(1); + }); + + it('sets aria-expanded on model picker trigger from isModelPickerOpen prop', () => { + render( + , + ); + expect( + screen.getByRole('button', { name: 'Choose model' }), + ).toHaveAttribute('aria-expanded', 'true'); + }); + + it('renders the model picker inside a Choose model tooltip wrapper in ask-bar mode', () => { + render( + , + ); + const trigger = screen.getByRole('button', { name: 'Choose model' }); + fireEvent.mouseEnter(trigger.parentElement!); + expect(screen.getAllByText('Choose model').length).toBeGreaterThanOrEqual( + 1, + ); + }); + + it('hides the model picker trigger in ask-bar mode when no models are available', () => { + render( + , + ); + expect(screen.queryByRole('button', { name: 'Choose model' })).toBeNull(); + }); + it('displays selectedText when provided', () => { render( { />, ); const btn = screen.getByRole('button', { name: 'Take screenshot' }); - expect(btn.className).not.toContain('hover:text-text-primary'); - expect(btn.className).not.toContain('hover:bg-white/8'); + expect(btn.className).not.toContain('hover:text-primary'); + expect(btn.className).not.toContain('hover:bg-primary/10'); }); it('has hover classes when below max images', () => { @@ -1015,8 +1140,8 @@ describe('AskBarView', () => { />, ); const btn = screen.getByRole('button', { name: 'Take screenshot' }); - expect(btn.className).toContain('hover:text-text-primary'); - expect(btn.className).toContain('hover:bg-white/8'); + expect(btn.className).toContain('hover:text-primary'); + expect(btn.className).toContain('hover:bg-primary/10'); }); it('shows tooltip explaining limit when camera button is hovered at max images', () => { @@ -1453,105 +1578,84 @@ describe('AskBarView', () => { }); }); - describe('Command highlighting mirror div', () => { - it('renders a mirror div with aria-hidden behind the textarea', () => { - const { container } = render( + describe('capability gate UI', () => { + it('renders the capability mismatch strip when message provided', () => { + render( , ); - const mirror = container.querySelector('[aria-hidden="true"]'); - expect(mirror).not.toBeNull(); - expect(mirror!.classList.contains('pointer-events-none')).toBe(true); - }); - - it('highlights /screen command in violet in the mirror div', () => { - const { container } = render( - , + expect(screen.getByTestId('capability-mismatch-strip')).toHaveTextContent( + "llama3 can't see images.", ); - const mirror = container.querySelector('[aria-hidden="true"]'); - const highlighted = mirror!.querySelector('.text-violet-400'); - expect(highlighted).not.toBeNull(); - expect(highlighted!.textContent).toBe('/screen'); }); - it('highlights multiple commands in the mirror div', () => { - const { container } = render( + it('omits the strip when message is null', () => { + render( , ); - const mirror = container.querySelector('[aria-hidden="true"]'); - const highlighted = mirror!.querySelectorAll('.text-violet-400'); - expect(highlighted).toHaveLength(2); - expect(highlighted[0].textContent).toBe('/screen'); - expect(highlighted[1].textContent).toBe('/think'); + expect(screen.queryByTestId('capability-mismatch-strip')).toBeNull(); }); - it('does not highlight partial command matches like /screensaver', () => { - const { container } = render( + it('mounts the shake animation branch when shake is true', () => { + render( , ); - const mirror = container.querySelector('[aria-hidden="true"]'); - expect(mirror!.querySelector('.text-violet-400')).toBeNull(); - expect(mirror!.textContent).toBe('/screensaver is nice'); + expect(screen.getByTestId('ask-bar-row')).toBeInTheDocument(); }); - it('renders plain text in mirror div when no commands present', () => { - const { container } = render( + it('keeps the no-shake branch when shake is false', () => { + render( , ); - const mirror = container.querySelector('[aria-hidden="true"]'); - expect(mirror!.querySelector('.text-violet-400')).toBeNull(); - expect(mirror!.textContent).toBe('hello world'); + expect(screen.getByTestId('ask-bar-row')).toBeInTheDocument(); }); + }); - it('makes the textarea text transparent for the mirror overlay', () => { - const { container } = render( + describe('slash command highlighting', () => { + it('mirror div renders the query so colored spans show through the textarea', () => { + render( { inputRef={makeRef()} />, ); - const textarea = container.querySelector('textarea'); - expect(textarea!.classList.contains('text-transparent')).toBe(true); - expect(textarea!.style.caretColor).toBe('var(--color-text-primary)'); - }); - - it('handles scroll event gracefully when refs are not yet set', () => { - // Render with a ref that has current = null to exercise the null guard - const nullRef = { current: null }; - const { container } = render( - , + const mirror = screen.getByTestId('askbar-mirror'); + expect(mirror).toHaveTextContent('/search what is Rust?'); + // The trigger token sits in its own span with the violet utility class. + const tokenSpan = Array.from(mirror.querySelectorAll('span')).find( + (s) => s.textContent === '/search', ); - // The textarea is rendered by React, but the inputRef.current is null - // because we passed a ref with current=null that was not wired via callback. - // The scroll handler should not throw. - const textarea = container.querySelector('textarea')!; - expect(() => fireEvent.scroll(textarea)).not.toThrow(); + expect(tokenSpan).toBeDefined(); + expect(tokenSpan?.className).toContain('text-violet-400'); }); - it('syncs mirror div scroll with textarea scroll', () => { + it('syncs mirror scrollTop with the textarea so the highlight tracks the caret', () => { const ref = makeRef(); - const { container } = render( + render( { inputRef={ref} />, ); - const textarea = container.querySelector('textarea')!; - const mirror = container.querySelector('[aria-hidden="true"]')!; - - // Simulate scrolling - Object.defineProperty(textarea, 'scrollTop', { - value: 42, - writable: true, + const mirror = screen.getByTestId('askbar-mirror') as HTMLDivElement; + // Pretend the textarea has been scrolled. + Object.defineProperty(ref.current, 'scrollTop', { + configurable: true, + value: 24, }); - fireEvent.scroll(textarea); - expect(mirror.scrollTop).toBe(42); + Object.defineProperty(ref.current, 'scrollLeft', { + configurable: true, + value: 6, + }); + fireEvent.scroll(ref.current!); + expect(mirror.scrollTop).toBe(24); + expect(mirror.scrollLeft).toBe(6); + }); + }); + + describe('renderHighlightedText (pure)', () => { + it('returns a single span when no command trigger is present', () => { + const node = renderHighlightedText('plain text only'); + const { container } = render(<>{node}); + const violet = container.querySelector('.text-violet-400'); + expect(violet).toBeNull(); + expect(container).toHaveTextContent('plain text only'); + }); + + it('wraps the first valid trigger occurrence in the violet utility class', () => { + const node = renderHighlightedText('/search what is Rust?'); + const { container } = render(<>{node}); + const tokens = container.querySelectorAll('.text-violet-400'); + expect(tokens.length).toBe(1); + expect(tokens[0].textContent).toBe('/search'); + }); + + it('only highlights the first occurrence of any given trigger', () => { + const node = renderHighlightedText('/search foo /search bar'); + const { container } = render(<>{node}); + const tokens = container.querySelectorAll('.text-violet-400'); + expect(tokens.length).toBe(1); + }); + + it('does not match a trigger embedded inside a longer word', () => { + // /searching contains /search but is not a standalone trigger token. + const node = renderHighlightedText('/searching'); + const { container } = render(<>{node}); + expect(container.querySelector('.text-violet-400')).toBeNull(); + }); + + it('returns an empty fragment for an empty string without throwing', () => { + const node = renderHighlightedText(''); + const { container } = render(<>{node}); + expect(container.textContent).toBe(''); }); }); }); diff --git a/src/view/onboarding/IntroStep.tsx b/src/view/onboarding/IntroStep.tsx index ccaeb971..47673e2c 100644 --- a/src/view/onboarding/IntroStep.tsx +++ b/src/view/onboarding/IntroStep.tsx @@ -64,7 +64,7 @@ export function IntroStep({ onComplete }: Props) { margin: '0 0 6px', }} > - Before you dive in + {"You're all set"}

      - {"You'll get the hang of it quickly."} + {"Five quick tips and you're chatting in seconds."}

      diff --git a/src/view/onboarding/ModelCheckStep.tsx b/src/view/onboarding/ModelCheckStep.tsx new file mode 100644 index 00000000..510dca3e --- /dev/null +++ b/src/view/onboarding/ModelCheckStep.tsx @@ -0,0 +1,951 @@ +/** + * Onboarding step that gates the chat overlay on a working local Ollama + * setup with at least one installed model. + * + * Layout: + * - Vertical timeline rail with numbered nodes connected by a thin line. + * - Step 1 active shows a single title row, then a two-tab install hero + * (Install Ollama / Already Installed?) above a single code box that + * swaps its command per tab. A short sub-line below the box invites + * the user to paste the command or visit the Ollama docs. + * - Step 2 active hosts a compact list of starter models, all rendered + * equal — no badge, no hierarchy. The user picks whichever fits. + * + * Probes Ollama via the `check_model_setup` Tauri command on mount and on + * every Re-check click. Background polling is intentionally absent so + * idle CPU and IPC stay at zero between explicit user actions. + */ + +import { AnimatePresence, motion } from 'framer-motion'; +import type React from 'react'; +import { useState, useEffect, useRef, useCallback } from 'react'; +import { invoke } from '@tauri-apps/api/core'; +import thukiLogo from '../../../src-tauri/icons/128x128.png'; +import { useConfig } from '../../contexts/ConfigContext'; +import { Badge } from './_shared'; + +const OLLAMA_DOCS_URL = 'https://ollama.com/download'; +const OLLAMA_SEARCH_URL = 'https://ollama.com/search'; + +/** + * Extracts the `host:port` segment from an Ollama daemon URL for display. + * Falls back to the raw input when the URL cannot be parsed (e.g. user + * config holds a non-URL string), so the UI never shows a confusing + * empty value. + */ +function formatListenAddr(url: string): string { + try { + return new URL(url).host; + } catch { + return url; + } +} + +type ModelSetupState = + | { state: 'ollama_unreachable' } + | { state: 'no_models_installed' } + | { state: 'ready'; active_slug: string; installed: string[] }; + +interface InstallTab { + id: string; + label: string; + command: string; +} + +/** + * Install routes shown above the Step 1 code box. The first entry is the + * default selection. `command` is the exact string copied to the + * clipboard when the copy pill is clicked. + */ +const INSTALL_TABS: InstallTab[] = [ + { + id: 'install', + label: 'Install Ollama', + command: 'curl -fsSL https://ollama.com/install.sh | sh', + }, + { + id: 'already-installed', + label: 'Already Installed?', + command: 'open -a Ollama', + }, +]; + +/** + * Starter models offered in Step 2. All entries support text and image + * input (vision / multimodal). Sizes are pulled from the official Ollama + * library (ollama.com/library) and reflect the default tag at time of + * authoring. All entries are intentionally peers — no recommended + * badge — so the user picks whichever fits their hardware. + */ +const STARTER_MODELS: Array<{ + slug: string; + description: string; + size: string; +}> = [ + { slug: 'gemma4:e4b', description: 'Google · vision', size: '9.6 GB' }, + { + slug: 'llama3.2-vision:11b', + description: 'Meta · vision', + size: '7.8 GB', + }, + { slug: 'phi4:14b', description: 'Microsoft · text', size: '9.1 GB' }, +]; + +/** + * Builds the public Ollama library URL for a model slug. Drops the `:tag` + * suffix so the destination shows every available variant rather than + * pinning the user to one quantisation. Both `gemma4` and `gemma4:e4b` + * resolve, but the bare-name URL is the more useful landing. + */ +function buildOllamaLibraryUrl(slug: string): string { + const base = slug.split(':')[0]; + return `https://ollama.com/library/${base}`; +} + +function buildPullCommand(slug: string): string { + return `ollama pull ${slug}`; +} + +async function copyToClipboard(text: string): Promise { + try { + await navigator.clipboard.writeText(text); + return true; + } catch { + return false; + } +} + +export function ModelCheckStep() { + const [setupState, setSetupState] = useState(null); + const [isRechecking, setIsRechecking] = useState(false); + const mountedRef = useRef(true); + + const probe = useCallback(async () => { + try { + const next = await invoke('check_model_setup'); + if (!mountedRef.current) return; + if (next.state === 'ready') { + await invoke('advance_past_model_check'); + return; + } + setSetupState(next); + } catch { + if (!mountedRef.current) return; + setSetupState({ state: 'ollama_unreachable' }); + } + }, []); + + useEffect(() => { + mountedRef.current = true; + void probe(); + return () => { + mountedRef.current = false; + }; + }, [probe]); + + const handleRecheck = useCallback(async () => { + setIsRechecking(true); + try { + await probe(); + } finally { + if (mountedRef.current) { + setIsRechecking(false); + } + } + }, [probe]); + + const ollamaConnected = setupState?.state === 'no_models_installed'; + const isWaitingForOllama = setupState?.state === 'ollama_unreachable'; + const isProbing = setupState === null; + + const titleSub = isProbing + ? 'Checking your local Ollama setup…' + : ollamaConnected + ? "Almost there. Let's pick a brain for Thuki." + : 'Runs Ollama locally. Your chats stay on this machine.'; + + return ( +
      + + {/* Top edge highlight, identical to PermissionsStep / IntroStep. */} +
      + +
      + Thuki +
      + +

      + Set up your local AI +

      +

      + {titleSub} +

      + + {!isProbing ? ( + + ) : null} + + + +

      + Private by default · All inference runs on your machine +

      + +
      + ); +} + +// ─── Rail ──────────────────────────────────────────────────────────────────── + +interface RailProps { + stepOneActive: boolean; + stepOneDone: boolean; + stepTwoActive: boolean; +} + +/** + * Two-step vertical timeline. The connecting line is rendered once as an + * absolute element behind the node column so it spans the full rail + * regardless of how tall each row's content grows. + */ +function Rail({ stepOneActive, stepOneDone, stepTwoActive }: RailProps) { + return ( +
      + + ); +} + +type NodeVariant = 'active' | 'done' | 'wait'; + +interface RailNodeProps { + number: number; + variant: NodeVariant; + topGap?: number; +} + +function RailNode({ number, variant, topGap = 0 }: RailNodeProps) { + const palette: Record< + NodeVariant, + { bg: string; border: string; color: string } + > = { + active: { + bg: 'rgba(255,141,92,0.1)', + border: 'rgba(255,141,92,0.4)', + color: '#ff8d5c', + }, + done: { + bg: 'rgba(34,197,94,0.12)', + border: 'rgba(34,197,94,0.4)', + color: '#22c55e', + }, + wait: { + bg: 'rgba(255,255,255,0.03)', + border: 'rgba(255,255,255,0.1)', + color: 'rgba(255,255,255,0.4)', + }, + }; + const p = palette[variant]; + return ( +
      +
      + {variant === 'done' ? '✓' : number} +
      +
      + ); +} + +// ─── Row 1: install Ollama ─────────────────────────────────────────────────── + +interface RowOneProps { + active: boolean; + done: boolean; +} + +function RowOne({ active, done }: RowOneProps) { + const config = useConfig(); + const [selectedTabIdx, setSelectedTabIdx] = useState(0); + const tab = INSTALL_TABS[selectedTabIdx]; + + return ( +
      +
      +
      +

      + {done ? 'Ollama is running' : 'Install & start Ollama'} +

      + {done ? ( +

      + Listening on {formatListenAddr(config.model.ollamaUrl)} +

      + ) : null} +
      + {done ? live : null} +
      + + {active ? ( + <> +
      +
      + {INSTALL_TABS.map((t, i) => ( + setSelectedTabIdx(i)} + /> + ))} +
      +
      + + + $ + + {tab.command} + + +
      +
      +
      + + Paste this in Terminal or visit + + + Ollama docs ↗ + +
      + + ) : null} +
      + ); +} + +// ─── Row 2: pull a starter model ───────────────────────────────────────────── + +function RowTwo({ active }: { active: boolean }) { + return ( +
      +

      + Pull a starter model +

      + + {active ? ( + <> +

      + You can swap or add more later. +

      +
      + {STARTER_MODELS.map((m, i) => ( + + ))} +
      + +
      + + Paste the command in Terminal + + or + + Browse all models on ollama.com ↗ + +
      + + ) : null} +
      + ); +} + +interface ModelRowProps { + slug: string; + description: string; + size: string; + isLast: boolean; +} + +function ModelRow({ slug, description, size, isLast }: ModelRowProps) { + return ( +
      +
      + +

      + {description} · {size} +

      +
      + +
      + ); +} + +/** + * Renders the model slug as an inline button styled like text. Click + * opens the model's Ollama library page in the user's default browser + * via the `open_url` Tauri command. Hover lifts the slug to brand + * orange with a subtle underline so it reads as discoverable without + * shouting. + */ +function SlugLink({ slug }: { slug: string }) { + const [hover, setHover] = useState(false); + return ( + + ); +} + +// ─── Tab + copy + docs link ────────────────────────────────────────────────── + +interface DocsLinkProps { + ariaLabel: string; + url: string; + children: React.ReactNode; +} + +function DocsLink({ ariaLabel, url, children }: DocsLinkProps) { + const [hover, setHover] = useState(false); + return ( + + ); +} + +interface TabButtonProps { + label: string; + selected: boolean; + onClick: () => void; +} + +function TabButton({ label, selected, onClick }: TabButtonProps) { + const [hover, setHover] = useState(false); + const borderColor = selected + ? 'rgba(255, 141, 92, 0.28)' + : hover + ? 'rgba(255, 255, 255, 0.1)' + : 'transparent'; + const bg = selected + ? 'rgba(255, 141, 92, 0.1)' + : hover + ? 'rgba(255, 255, 255, 0.04)' + : 'rgba(255, 255, 255, 0.025)'; + const color = selected + ? '#ff8d5c' + : hover + ? 'rgba(255,255,255,0.85)' + : 'rgba(255,255,255,0.55)'; + + return ( + + ); +} + +const COPIED_RESET_MS = 1500; + +interface CopyButtonProps { + command: string; + ariaLabel: string; + label?: string; + iconOnly?: boolean; +} + +function CopyButton({ + command, + ariaLabel, + label = 'Copy', + iconOnly = false, +}: CopyButtonProps) { + const [hover, setHover] = useState(false); + const [copied, setCopied] = useState(false); + const timeoutRef = useRef(null); + + useEffect(() => { + return () => { + if (timeoutRef.current !== null) { + window.clearTimeout(timeoutRef.current); + } + }; + }, []); + + const handleClick = useCallback(async () => { + const ok = await copyToClipboard(command); + if (!ok) return; + setCopied(true); + if (timeoutRef.current !== null) { + window.clearTimeout(timeoutRef.current); + } + timeoutRef.current = window.setTimeout(() => { + setCopied(false); + timeoutRef.current = null; + }, COPIED_RESET_MS); + }, [command]); + + const borderColor = copied + ? 'rgba(34,197,94,0.55)' + : hover + ? 'rgba(255,141,92,0.55)' + : 'rgba(255,255,255,0.12)'; + const labelColor = + hover || copied ? 'rgba(255,255,255,0.95)' : 'rgba(255,255,255,0.7)'; + const glyphColor = copied + ? '#22c55e' + : hover + ? '#ff8d5c' + : 'rgba(255,255,255,0.7)'; + + return ( + + ); +} + +// ─── Glyphs ────────────────────────────────────────────────────────────────── + +function CopyGlyph() { + return ( + + + + + ); +} + +function CheckGlyph() { + return ( + + + + ); +} diff --git a/src/view/onboarding/PermissionsStep.tsx b/src/view/onboarding/PermissionsStep.tsx index e8cd5f40..b49c4408 100644 --- a/src/view/onboarding/PermissionsStep.tsx +++ b/src/view/onboarding/PermissionsStep.tsx @@ -3,6 +3,7 @@ import type React from 'react'; import { useState, useEffect, useRef, useCallback } from 'react'; import { invoke } from '@tauri-apps/api/core'; import thukiLogo from '../../../src-tauri/icons/128x128.png'; +import { StepCard, Badge } from './_shared'; /** How often to poll for permission grants after the user requests them. */ const POLL_INTERVAL_MS = 500; @@ -541,72 +542,3 @@ function CTAButton({ ); } - -interface StepCardProps { - active: boolean; - done: boolean; - children: React.ReactNode; -} - -function StepCard({ active, done, children }: StepCardProps) { - const borderColor = done - ? 'rgba(34,197,94,0.2)' - : active - ? 'rgba(255,141,92,0.4)' - : 'rgba(255,255,255,0.06)'; - - const background = done - ? 'rgba(34,197,94,0.05)' - : active - ? 'rgba(255,141,92,0.07)' - : 'rgba(255,255,255,0.03)'; - - return ( -
      - {children} -
      - ); -} - -interface BadgeProps { - color: 'green'; - children: React.ReactNode; -} - -function Badge({ color, children }: BadgeProps) { - const styles: Record = { - green: { - color: '#22c55e', - background: 'rgba(34,197,94,0.1)', - border: '1px solid rgba(34,197,94,0.2)', - }, - }; - - return ( - - {children} - - ); -} diff --git a/src/view/onboarding/__tests__/IntroStep.test.tsx b/src/view/onboarding/__tests__/IntroStep.test.tsx index 742dcc07..80b2a752 100644 --- a/src/view/onboarding/__tests__/IntroStep.test.tsx +++ b/src/view/onboarding/__tests__/IntroStep.test.tsx @@ -10,13 +10,13 @@ describe('IntroStep', () => { it('renders the title', () => { render(); - expect(screen.getByText('Before you dive in')).toBeInTheDocument(); + expect(screen.getByText("You're all set")).toBeInTheDocument(); }); it('renders the subtitle', () => { render(); expect( - screen.getByText("You'll get the hang of it quickly."), + screen.getByText("Five quick tips and you're chatting in seconds."), ).toBeInTheDocument(); }); diff --git a/src/view/onboarding/__tests__/ModelCheckStep.test.tsx b/src/view/onboarding/__tests__/ModelCheckStep.test.tsx new file mode 100644 index 00000000..a8dffc42 --- /dev/null +++ b/src/view/onboarding/__tests__/ModelCheckStep.test.tsx @@ -0,0 +1,692 @@ +import { + render, + screen, + fireEvent, + act, + waitFor, + cleanup, +} from '@testing-library/react'; +import { describe, it, expect, beforeEach, beforeAll, vi } from 'vitest'; +import { ModelCheckStep } from '../ModelCheckStep'; +import { + ConfigProviderForTest, + DEFAULT_CONFIG, +} from '../../../contexts/ConfigContext'; +import { + invoke, + enableChannelCaptureWithResponses, +} from '../../../testUtils/mocks/tauri'; + +const READY_RESPONSE = { + state: 'ready', + active_slug: 'gemma4:e4b', + installed: ['gemma4:e4b'], +}; + +const writeText = vi.fn().mockResolvedValue(undefined); + +beforeAll(() => { + if (!('clipboard' in navigator)) { + Object.defineProperty(navigator, 'clipboard', { + configurable: true, + writable: true, + value: { writeText }, + }); + } else { + Object.assign(navigator.clipboard, { writeText }); + } +}); + +describe('ModelCheckStep', () => { + beforeEach(() => { + invoke.mockClear(); + writeText.mockReset(); + writeText.mockResolvedValue(undefined); + }); + + it('shows Step 1 active and Step 2 waiting on Ollama unreachable', async () => { + enableChannelCaptureWithResponses({ + check_model_setup: { state: 'ollama_unreachable' }, + }); + + render(); + await act(async () => {}); + + expect(screen.getByText('Set up your local AI')).toBeInTheDocument(); + expect( + screen.getByText('Runs Ollama locally. Your chats stay on this machine.'), + ).toBeInTheDocument(); + expect(screen.getByText('Install & start Ollama')).toBeInTheDocument(); + expect( + screen.queryByText('STEP 1 · ACTION NEEDED'), + ).not.toBeInTheDocument(); + expect(screen.queryByText('STEP 2 · WAITING')).not.toBeInTheDocument(); + expect(screen.getByText('Pull a starter model')).toBeInTheDocument(); + expect( + screen.getByText('curl -fsSL https://ollama.com/install.sh | sh'), + ).toBeInTheDocument(); + }); + + it('shows Step 1 done and Step 2 active on no_models_installed', async () => { + enableChannelCaptureWithResponses({ + check_model_setup: { state: 'no_models_installed' }, + }); + + render(); + await act(async () => {}); + + expect(screen.getByText('Ollama is running')).toBeInTheDocument(); + expect( + screen.getByText('Listening on 127.0.0.1:11434'), + ).toBeInTheDocument(); + expect(screen.getByText('live')).toBeInTheDocument(); + expect(screen.queryByText('Connected')).not.toBeInTheDocument(); + expect(screen.queryByText('STEP 1 · DONE')).not.toBeInTheDocument(); + expect( + screen.queryByText('STEP 2 · ACTION NEEDED'), + ).not.toBeInTheDocument(); + expect( + screen.getByText("Almost there. Let's pick a brain for Thuki."), + ).toBeInTheDocument(); + expect( + screen.getByText('You can swap or add more later.'), + ).toBeInTheDocument(); + expect(screen.getByText('gemma4:e4b')).toBeInTheDocument(); + expect(screen.getByText('llama3.2-vision:11b')).toBeInTheDocument(); + expect(screen.getByText('phi4:14b')).toBeInTheDocument(); + expect(screen.queryByText('RECOMMENDED')).not.toBeInTheDocument(); + }); + + it('renders the configured Ollama URL host:port in the listening line', async () => { + enableChannelCaptureWithResponses({ + check_model_setup: { state: 'no_models_installed' }, + }); + + render( + + + , + ); + await act(async () => {}); + + expect(screen.getByText('Listening on 10.0.0.5:9000')).toBeInTheDocument(); + }); + + it('falls back to the raw Ollama URL string when it is not parseable', async () => { + enableChannelCaptureWithResponses({ + check_model_setup: { state: 'no_models_installed' }, + }); + + render( + + + , + ); + await act(async () => {}); + + expect(screen.getByText('Listening on not-a-url')).toBeInTheDocument(); + }); + + it('fires advance_past_model_check when Ready', async () => { + enableChannelCaptureWithResponses({ + check_model_setup: READY_RESPONSE, + advance_past_model_check: undefined, + }); + + render(); + await act(async () => {}); + + await waitFor(() => { + expect(invoke).toHaveBeenCalledWith('advance_past_model_check'); + }); + }); + + it('treats IPC failure as Ollama unreachable so the user sees a recovery path', async () => { + invoke.mockRejectedValueOnce(new Error('ipc broken')); + + render(); + await act(async () => {}); + + expect(screen.getByText('Install & start Ollama')).toBeInTheDocument(); + }); + + it('Re-check button re-runs the probe and updates state', async () => { + let calls = 0; + invoke.mockImplementation(async (name: string) => { + if (name === 'check_model_setup') { + calls += 1; + return calls === 1 + ? { state: 'ollama_unreachable' } + : { state: 'no_models_installed' }; + } + return undefined; + }); + + render(); + await act(async () => {}); + + expect(screen.getByText('Install & start Ollama')).toBeInTheDocument(); + + await act(async () => { + fireEvent.click(screen.getByLabelText('Verify setup')); + }); + + expect(screen.getByText('Ollama is running')).toBeInTheDocument(); + expect(screen.getByText('live')).toBeInTheDocument(); + }); + + it('Re-check button is no-op while a probe is in flight', async () => { + let probeCalls = 0; + let resolveSecond: (value: unknown) => void = () => {}; + invoke.mockImplementation(async (name: string) => { + if (name === 'check_model_setup') { + probeCalls += 1; + if (probeCalls === 1) return { state: 'ollama_unreachable' }; + return new Promise((resolve) => { + resolveSecond = resolve; + }); + } + return undefined; + }); + + render(); + await act(async () => {}); + + await act(async () => { + fireEvent.click(screen.getByLabelText('Verify setup')); + }); + expect(probeCalls).toBe(2); + + await act(async () => { + fireEvent.click(screen.getByLabelText('Verify setup')); + }); + expect(probeCalls).toBe(2); + + await act(async () => { + resolveSecond({ state: 'no_models_installed' }); + }); + }); + + it('copies the selected install command (Install Ollama default)', async () => { + enableChannelCaptureWithResponses({ + check_model_setup: { state: 'ollama_unreachable' }, + }); + + render(); + await act(async () => {}); + + await act(async () => { + fireEvent.click(screen.getByLabelText('Copy install ollama command')); + }); + expect(writeText).toHaveBeenCalledWith( + 'curl -fsSL https://ollama.com/install.sh | sh', + ); + }); + + it('switching tabs swaps the displayed install command', async () => { + enableChannelCaptureWithResponses({ + check_model_setup: { state: 'ollama_unreachable' }, + }); + + render(); + await act(async () => {}); + + expect( + screen.getByText('curl -fsSL https://ollama.com/install.sh | sh'), + ).toBeInTheDocument(); + + await act(async () => { + fireEvent.click( + screen.getByRole('button', { name: 'Already Installed?' }), + ); + }); + expect(screen.getByText('open -a Ollama')).toBeInTheDocument(); + + await act(async () => { + fireEvent.click(screen.getByRole('button', { name: 'Install Ollama' })); + }); + expect( + screen.getByText('curl -fsSL https://ollama.com/install.sh | sh'), + ).toBeInTheDocument(); + }); + + it('copies the open command after switching to the Already Installed? tab', async () => { + enableChannelCaptureWithResponses({ + check_model_setup: { state: 'ollama_unreachable' }, + }); + + render(); + await act(async () => {}); + + await act(async () => { + fireEvent.click( + screen.getByRole('button', { name: 'Already Installed?' }), + ); + }); + await act(async () => { + fireEvent.click(screen.getByLabelText('Copy already installed? command')); + }); + expect(writeText).toHaveBeenCalledWith('open -a Ollama'); + }); + + it('lights up the active tab with the brand orange', async () => { + enableChannelCaptureWithResponses({ + check_model_setup: { state: 'ollama_unreachable' }, + }); + + render(); + await act(async () => {}); + + const installTab = screen.getByRole('button', { name: 'Install Ollama' }); + expect(installTab.style.color).toContain('255, 141, 92'); + + const alreadyTab = screen.getByRole('button', { + name: 'Already Installed?', + }); + expect(alreadyTab.style.color).not.toContain('255, 141, 92'); + }); + + it('hovering an inactive tab brightens the label', async () => { + enableChannelCaptureWithResponses({ + check_model_setup: { state: 'ollama_unreachable' }, + }); + + render(); + await act(async () => {}); + + const alreadyTab = screen.getByRole('button', { + name: 'Already Installed?', + }); + const before = alreadyTab.style.color; + fireEvent.mouseEnter(alreadyTab); + expect(alreadyTab.style.color).not.toBe(before); + fireEvent.mouseLeave(alreadyTab); + expect(alreadyTab.style.color).toBe(before); + }); + + it('copies the pull command for a starter model', async () => { + enableChannelCaptureWithResponses({ + check_model_setup: { state: 'no_models_installed' }, + }); + + render(); + await act(async () => {}); + + await act(async () => { + fireEvent.click( + screen.getByLabelText('Copy install command for phi4:14b'), + ); + }); + expect(writeText).toHaveBeenCalledWith('ollama pull phi4:14b'); + }); + + it('renders each starter model with its description and size', async () => { + enableChannelCaptureWithResponses({ + check_model_setup: { state: 'no_models_installed' }, + }); + + render(); + await act(async () => {}); + + expect(screen.getByText('Google · vision · 9.6 GB')).toBeInTheDocument(); + expect(screen.getByText('Meta · vision · 7.8 GB')).toBeInTheDocument(); + expect(screen.getByText('Microsoft · text · 9.1 GB')).toBeInTheDocument(); + }); + + it('clicking a model slug opens its Ollama library page', async () => { + enableChannelCaptureWithResponses({ + check_model_setup: { state: 'no_models_installed' }, + open_url: undefined, + }); + + render(); + await act(async () => {}); + + await act(async () => { + fireEvent.click(screen.getByLabelText('Open gemma4:e4b on Ollama')); + }); + + expect(invoke).toHaveBeenCalledWith('open_url', { + url: 'https://ollama.com/library/gemma4', + }); + }); + + it('lights up the slug link on pointer hover', async () => { + enableChannelCaptureWithResponses({ + check_model_setup: { state: 'no_models_installed' }, + }); + + render(); + await act(async () => {}); + + const link = screen.getByLabelText('Open phi4:14b on Ollama'); + const initialColor = link.style.color; + fireEvent.mouseEnter(link); + expect(link.style.color).not.toBe(initialColor); + fireEvent.mouseLeave(link); + expect(link.style.color).toBe(initialColor); + }); + + it('swallows clipboard write errors silently', async () => { + writeText.mockReset(); + writeText.mockRejectedValue(new Error('denied')); + enableChannelCaptureWithResponses({ + check_model_setup: { state: 'ollama_unreachable' }, + }); + + render(); + await act(async () => {}); + + await expect( + act(async () => { + fireEvent.click(screen.getByLabelText('Copy install ollama command')); + }), + ).resolves.not.toThrow(); + }); + + it('renders the privacy footer', async () => { + enableChannelCaptureWithResponses({ + check_model_setup: { state: 'ollama_unreachable' }, + }); + + render(); + await act(async () => {}); + + expect( + screen.getByText( + 'Private by default · All inference runs on your machine', + ), + ).toBeInTheDocument(); + }); + + it('renders the Step 1 sub-line below the code box with the Ollama docs link', async () => { + enableChannelCaptureWithResponses({ + check_model_setup: { state: 'ollama_unreachable' }, + }); + + render(); + await act(async () => {}); + + expect( + screen.getByText('Paste this in Terminal or visit'), + ).toBeInTheDocument(); + expect( + screen.getByLabelText('Open Ollama documentation'), + ).toBeInTheDocument(); + }); + + it('opens the Ollama docs URL when its sub-line link is clicked', async () => { + enableChannelCaptureWithResponses({ + check_model_setup: { state: 'ollama_unreachable' }, + open_url: undefined, + }); + + render(); + await act(async () => {}); + + await act(async () => { + fireEvent.click(screen.getByLabelText('Open Ollama documentation')); + }); + + expect(invoke).toHaveBeenCalledWith('open_url', { + url: 'https://ollama.com/download', + }); + }); + + it('opens the Ollama library URL when the Browse link is clicked', async () => { + enableChannelCaptureWithResponses({ + check_model_setup: { state: 'no_models_installed' }, + open_url: undefined, + }); + + render(); + await act(async () => {}); + + await act(async () => { + fireEvent.click(screen.getByLabelText('Browse all models on Ollama')); + }); + + expect(invoke).toHaveBeenCalledWith('open_url', { + url: 'https://ollama.com/search', + }); + }); + + it('renders the Step 2 helper block under the model list', async () => { + enableChannelCaptureWithResponses({ + check_model_setup: { state: 'no_models_installed' }, + }); + + render(); + await act(async () => {}); + + expect( + screen.getByText('Paste the command in Terminal'), + ).toBeInTheDocument(); + expect(screen.getByText('or')).toBeInTheDocument(); + expect( + screen.getByText('Browse all models on ollama.com ↗'), + ).toBeInTheDocument(); + }); + + it('lights up sub-line doc links on pointer hover', async () => { + enableChannelCaptureWithResponses({ + check_model_setup: { state: 'ollama_unreachable' }, + }); + + render(); + await act(async () => {}); + + const link = screen.getByLabelText('Open Ollama documentation'); + const initialColor = link.style.color; + fireEvent.mouseEnter(link); + expect(link.style.color).not.toBe(initialColor); + fireEvent.mouseLeave(link); + expect(link.style.color).toBe(initialColor); + }); + + it('icon-only install copy button shows only the green check on success (no Copied text)', async () => { + vi.useFakeTimers(); + try { + enableChannelCaptureWithResponses({ + check_model_setup: { state: 'ollama_unreachable' }, + }); + + render(); + await act(async () => {}); + + await act(async () => { + fireEvent.click(screen.getByLabelText('Copy install ollama command')); + }); + + expect(screen.queryByText('Copied')).not.toBeInTheDocument(); + const button = screen.getByLabelText('Copy install ollama command'); + expect(button.style.borderColor).toContain('34, 197, 94'); + + await act(async () => { + vi.advanceTimersByTime(1500); + }); + + expect(button.style.borderColor).not.toContain('34, 197, 94'); + } finally { + vi.useRealTimers(); + } + }); + + it('model-row copy button swaps into a Copied confirmation after a successful copy', async () => { + vi.useFakeTimers(); + try { + enableChannelCaptureWithResponses({ + check_model_setup: { state: 'no_models_installed' }, + }); + + render(); + await act(async () => {}); + + await act(async () => { + fireEvent.click( + screen.getByLabelText('Copy install command for gemma4:e4b'), + ); + }); + + expect(screen.getByText('Copied')).toBeInTheDocument(); + + await act(async () => { + vi.advanceTimersByTime(1500); + }); + + expect(screen.queryByText('Copied')).not.toBeInTheDocument(); + expect(screen.getAllByText('Copy').length).toBeGreaterThan(0); + } finally { + vi.useRealTimers(); + } + }); + + it('clears the previous Copied timer when the model-row copy button is clicked twice quickly', async () => { + vi.useFakeTimers(); + try { + enableChannelCaptureWithResponses({ + check_model_setup: { state: 'no_models_installed' }, + }); + + render(); + await act(async () => {}); + + const button = screen.getByLabelText('Copy install command for phi4:14b'); + + await act(async () => { + fireEvent.click(button); + }); + expect(screen.getByText('Copied')).toBeInTheDocument(); + + await act(async () => { + vi.advanceTimersByTime(800); + }); + await act(async () => { + fireEvent.click(button); + }); + expect(screen.getByText('Copied')).toBeInTheDocument(); + + await act(async () => { + vi.advanceTimersByTime(800); + }); + expect(screen.getByText('Copied')).toBeInTheDocument(); + + await act(async () => { + vi.advanceTimersByTime(800); + }); + expect(screen.queryByText('Copied')).not.toBeInTheDocument(); + } finally { + vi.useRealTimers(); + } + }); + + it('lights up the copy button border on pointer hover', async () => { + enableChannelCaptureWithResponses({ + check_model_setup: { state: 'ollama_unreachable' }, + }); + + render(); + await act(async () => {}); + + const button = screen.getByLabelText('Copy install ollama command'); + fireEvent.mouseEnter(button); + expect(button.style.borderColor).toContain('255, 141, 92'); + fireEvent.mouseLeave(button); + expect(button.style.borderColor).toContain('255, 255, 255'); + }); + + it('drops the probe success when the component unmounts mid-flight', async () => { + let resolveProbe: (value: unknown) => void = () => {}; + invoke.mockImplementation(async (name: string) => { + if (name === 'check_model_setup') { + return new Promise((resolve) => { + resolveProbe = resolve; + }); + } + return undefined; + }); + + const { unmount } = render(); + unmount(); + + await act(async () => { + resolveProbe({ state: 'no_models_installed' }); + }); + + expect(invoke).not.toHaveBeenCalledWith('advance_past_model_check'); + }); + + it('drops the probe failure when the component unmounts mid-flight', async () => { + let rejectProbe: (reason: unknown) => void = () => {}; + invoke.mockImplementation(async (name: string) => { + if (name === 'check_model_setup') { + return new Promise((_resolve, reject) => { + rejectProbe = reject; + }); + } + return undefined; + }); + + const { unmount } = render(); + unmount(); + + await act(async () => { + rejectProbe(new Error('late failure')); + }); + }); + + it('skips re-render when the recheck probe finishes after unmount', async () => { + let calls = 0; + let resolveSecond: (value: unknown) => void = () => {}; + invoke.mockImplementation(async (name: string) => { + if (name === 'check_model_setup') { + calls += 1; + if (calls === 1) return { state: 'ollama_unreachable' }; + return new Promise((resolve) => { + resolveSecond = resolve; + }); + } + return undefined; + }); + + render(); + await act(async () => {}); + + await act(async () => { + fireEvent.click(screen.getByLabelText('Verify setup')); + }); + + cleanup(); + + await act(async () => { + resolveSecond({ state: 'no_models_installed' }); + }); + }); + + it('does not show the Copied confirmation when the clipboard write fails', async () => { + writeText.mockReset(); + writeText.mockRejectedValue(new Error('denied')); + enableChannelCaptureWithResponses({ + check_model_setup: { state: 'ollama_unreachable' }, + }); + + render(); + await act(async () => {}); + + await act(async () => { + fireEvent.click(screen.getByLabelText('Copy install ollama command')); + }); + + expect(screen.queryByText('Copied')).not.toBeInTheDocument(); + }); +}); diff --git a/src/view/onboarding/__tests__/index.test.tsx b/src/view/onboarding/__tests__/index.test.tsx index 751fe408..28efe58b 100644 --- a/src/view/onboarding/__tests__/index.test.tsx +++ b/src/view/onboarding/__tests__/index.test.tsx @@ -17,6 +17,12 @@ describe('OnboardingView (orchestrator)', () => { it('renders IntroStep when stage is intro', () => { render(); - expect(screen.getByText('Before you dive in')).toBeInTheDocument(); + expect(screen.getByText("You're all set")).toBeInTheDocument(); + }); + + it('renders ModelCheckStep when stage is model_check', async () => { + render(); + await act(async () => {}); + expect(screen.getByText('Set up your local AI')).toBeInTheDocument(); }); }); diff --git a/src/view/onboarding/_shared.tsx b/src/view/onboarding/_shared.tsx new file mode 100644 index 00000000..68b82ad7 --- /dev/null +++ b/src/view/onboarding/_shared.tsx @@ -0,0 +1,100 @@ +/** + * Shared building blocks for onboarding steps. + * + * Extracted from PermissionsStep so ModelCheckStep (and any future + * onboarding screen) can reuse the same active / done / waiting visual + * language. The token values here are the source of truth for the + * onboarding visual system; do not duplicate them inline in step + * components. + */ + +import type React from 'react'; + +export interface StepCardProps { + /** Orange-glow treatment indicating the user must act on this step now. */ + active: boolean; + /** Green-tinted "done" treatment with a thin success border. */ + done: boolean; + children: React.ReactNode; +} + +/** + * Container that applies the onboarding step visual treatment. + * + * Three mutually exclusive states: + * - done: green border + green-tint background, no glow. + * - active && !done: warm orange border + orange-tint background + + * soft outer glow + 1px inner top highlight. + * - !active && !done: subtle white border + faint white-tint + * background, no glow. Used for "waiting" steps that the user + * cannot act on yet. + */ +export function StepCard({ active, done, children }: StepCardProps) { + const borderColor = done + ? 'rgba(34,197,94,0.2)' + : active + ? 'rgba(255,141,92,0.4)' + : 'rgba(255,255,255,0.06)'; + + const background = done + ? 'rgba(34,197,94,0.05)' + : active + ? 'rgba(255,141,92,0.07)' + : 'rgba(255,255,255,0.03)'; + + return ( +
      + {children} +
      + ); +} + +export interface BadgeProps { + color: 'green'; + children: React.ReactNode; +} + +/** + * Inline status pill rendered to the right of a done step's title. + * + * Single-color today (`green` for the success / connected state). Add + * new colors as discrete variants rather than accepting arbitrary CSS, + * which keeps the badge palette under one rule. + */ +export function Badge({ color, children }: BadgeProps) { + const styles: Record = { + green: { + color: '#22c55e', + background: 'rgba(34,197,94,0.1)', + border: '1px solid rgba(34,197,94,0.2)', + }, + }; + + return ( + + {children} + + ); +} diff --git a/src/view/onboarding/index.tsx b/src/view/onboarding/index.tsx index 9b2987e5..c5a20042 100644 --- a/src/view/onboarding/index.tsx +++ b/src/view/onboarding/index.tsx @@ -1,7 +1,13 @@ import { IntroStep } from './IntroStep'; +import { ModelCheckStep } from './ModelCheckStep'; import { PermissionsStep } from './PermissionsStep'; -export type OnboardingStage = 'permissions' | 'intro'; +/** + * Stage values mirror the Rust `OnboardingStage` enum exactly. The + * backend emits these strings as the `stage` field on the + * `thuki://onboarding` event; any drift here breaks the dispatch. + */ +export type OnboardingStage = 'permissions' | 'model_check' | 'intro'; interface Props { stage: OnboardingStage; @@ -14,7 +20,7 @@ interface Props { * Renders the correct step based on the persisted onboarding stage emitted * by the backend at startup. The stage advances on the backend: * - * permissions -> (quit+reopen) -> intro -> complete (normal app) + * permissions -> (quit+reopen) -> model_check -> (advance) -> intro -> complete * * When stage is "complete" the backend never emits the onboarding event, * so this component is never rendered. @@ -23,5 +29,12 @@ export function OnboardingView({ stage, onComplete }: Props) { if (stage === 'intro') { return ; } + if (stage === 'model_check') { + // ModelCheckStep advances to `intro` via the backend + // `advance_past_model_check` command, which re-emits the onboarding + // event. No callback wiring needed here. + void onComplete; // referenced for parity; unused by ModelCheckStep + return ; + } return ; }