Skip to content

Conversation

@robosimon
Copy link

Fix #295

This PR adds automatic CPU fallback when loading CUDA-saved checkpoints on CPU-only machines.

Changes:

  • Add _resolve_checkpoint_device() helper to detect CUDA → CPU fallback scenario
  • Update _load_cebra_with_sklearn_backend() to use resolved device for all .to() calls
  • Handle both str and torch.device types (including "cuda:0" variants)
  • Update cebra_.device_ and cebra_.device attributes after resolution
  • Add comprehensive tests with monkeypatched CPU-only environment (10 test cases)

Test coverage:

  • 4 CUDA device variants: "cuda", "cuda:0", torch.device("cuda"), torch.device("cuda", 0)
  • 2 model architectures: "offset1-model", "parametrized-model-5"
  • All save/load tests pass: 78 passed, 1 skipped

Verification:

  • Tests without fix: 8 FAILED (confirm bug exists)
  • Tests with fix: 10 PASSED (all green)

@cla-bot
Copy link

cla-bot bot commented Feb 11, 2026

Thank you for your pull request and welcome to our community. We could not parse the GitHub identity of the following contributors: Simon Openclaw.
This is most likely caused by a git client misconfiguration; please make sure to:

  1. check if your git client is configured with an email to sign commits git config --list | grep email
  2. If not, set it up using git config --global user.email [email protected]
  3. Make sure that the git commit email is configured in your GitHub account settings, see https://github.com/settings/emails

Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR aims to fix #295 by making CEBRA.load(...) gracefully fall back to CPU when a checkpoint indicates it was saved on a CUDA device but CUDA is unavailable at load time.

Changes:

  • Add _resolve_checkpoint_device() to convert checkpoint device_ values (string or torch.device) into a runtime-valid device, with CUDA→CPU fallback.
  • Update _load_cebra_with_sklearn_backend() to use the resolved device for model/criterion/solver .to(...) calls and to update cebra_.device_ (and cebra_.device on fallback).
  • Add tests that monkeypatch CUDA availability and validate successful load + inference.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.

File Description
cebra/integrations/sklearn/cebra.py Introduces checkpoint-device resolution and applies it during sklearn-backend loading.
tests/test_sklearn.py Adds regression tests intended to cover CUDA-saved checkpoint loading on CPU-only environments.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines 387 to 390
# Resolve device: use CPU when checkpoint was saved on CUDA but CUDA is not available
saved_device = state["device_"]
load_device = _resolve_checkpoint_device(saved_device)

Copy link

Copilot AI Feb 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new CPU-fallback logic only changes subsequent .to(load_device) calls, but loading a truly CUDA-saved checkpoint can still fail earlier in torch.load when the checkpoint contains CUDA tensors and CUDA isn’t available. Consider adding a retry/automatic fallback in CEBRA.load / _safe_torch_load that catches the CUDA deserialization RuntimeError and re-loads with map_location='cpu' (when the caller didn’t already pass map_location).

Copilot uses AI. Check for mistakes.
Comment on lines 1115 to 1130
# Train a model on CPU
cebra_model = cebra_sklearn_cebra.CEBRA(
model_architecture=model_architecture,
max_iterations=5,
device="cpu"
).fit(X)

with _windows_compatible_tempfile(mode="w+b") as tempname:
# Save the model
cebra_model.save(tempname)

# Modify the checkpoint to have a CUDA device
checkpoint = cebra_sklearn_cebra._safe_torch_load(tempname)
checkpoint["state"]["device_"] = saved_device
torch.save(checkpoint, tempname)

Copy link

Copilot AI Feb 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test is described as loading a “CUDA-saved checkpoint”, but it trains/saves the model on CPU and only edits checkpoint['state']['device_']. That doesn’t exercise the common failure mode where the checkpoint’s state_dict tensors are actually on CUDA and torch.load fails unless map_location is used. Consider either generating a real CUDA checkpoint when available, or monkeypatching torch.load/_safe_torch_load to simulate the CUDA deserialization error and assert the loader retries/falls back correctly.

Copilot uses AI. Check for mistakes.
monkeypatch.setattr(torch.cuda, "is_available", lambda: False)

# Load with explicit map_location
loaded_model = cebra_sklearn_cebra.CEBRA.load(tempname)
Copy link

Copilot AI Feb 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test name/doc/comment mention an explicit map_location='cpu' override, but CEBRA.load(tempname) is called without passing map_location (or any kwargs). Either pass map_location via CEBRA.load(tempname, map_location='cpu') (since **kwargs are forwarded to torch.load) or rename/update the test/docstring to reflect what’s actually being tested.

Suggested change
loaded_model = cebra_sklearn_cebra.CEBRA.load(tempname)
loaded_model = cebra_sklearn_cebra.CEBRA.load(tempname, map_location='cpu')

Copilot uses AI. Check for mistakes.
Add comprehensive tests for the CUDA-to-CPU fallback fix:
- test_load_cuda_checkpoint_falls_back_to_cpu: parametrized with 4 CUDA device variants and 2 model architectures
- test_load_cuda_checkpoint_with_device_override: tests map_location behavior

These tests verify:
- Models saved with device='cuda' can load on CPU-only machines
- Device attributes are correctly updated to 'cpu'
- Model parameters are on CPU
- Model can perform inference (transform) after loading

Expected to FAIL before the fix is applied.
@robosimon robosimon force-pushed the fix/cuda-load-on-cpu branch from b952d02 to 6fae503 Compare February 11, 2026 16:03
@cla-bot
Copy link

cla-bot bot commented Feb 11, 2026

Thank you for your contribution. We require contributors to sign our Contributor License Agreement (CLA). We do not have a signed CLA on file for you. In order for us to review and merge your code, please sign our CLA here. After you signed, you can comment on this PR with @cla-bot check to trigger another check.

@robosimon robosimon force-pushed the fix/cuda-load-on-cpu branch from 6fae503 to efe8b95 Compare February 11, 2026 16:07
@cla-bot cla-bot bot added the CLA signed label Feb 11, 2026
@robosimon
Copy link
Author

Thanks for the review! I've addressed all the comments:

  1. Added retry logic in _safe_torch_load (cebra.py): Now catches CUDA deserialization RuntimeError and automatically retries with map_location='cpu' when CUDA is unavailable.

  2. Added new test test_safe_torch_load_cuda_fallback (test_sklearn.py): This test monkeypatches torch.load to simulate the actual CUDA tensor deserialization error on first call, then verifies the retry logic works correctly.

  3. Fixed test docstring (test_sklearn.py): Updated test_load_cuda_checkpoint_with_device_override docstring from "Test that map_location='cpu' works..." to "Test that automatic CPU fallback works..." to accurately reflect what the test verifies.

  4. Updated commit authors to match the robosimon GitHub account.

All tests pass (including the new ones).

When a CEBRA checkpoint was saved on a CUDA device but is loaded on a
machine without CUDA available, it now gracefully falls back to CPU
instead of crashing with RuntimeError.

Changes:
- Add _resolve_checkpoint_device() helper to handle device resolution
- Update _load_cebra_with_sklearn_backend() to use resolved device
- Handle both string and torch.device types, including cuda:0 variants
- Update model device attributes after resolution

Fixes: Loading model saved with device='cuda' on CPU-only machine
@robosimon robosimon force-pushed the fix/cuda-load-on-cpu branch from efe8b95 to 97d5b90 Compare February 11, 2026 16:12
@robosimon
Copy link
Author

Added the real CUDA-saved checkpoint to tests/test_data/cuda_saved_checkpoint/ as a test fixture.

This checkpoint was saved with CUDA tensors and serves as:

  1. Documentation of the real-world scenario this fix addresses
  2. A test fixture for future integration testing (when PyTorch 2.6+ directory format is supported)
  3. Validation that the fix handles actual CUDA-saved model structures

The checkpoint is stored in PyTorch's newer directory format (version 3). While the current test environment can't directly load it due to format limitations, the test_load_real_cuda_checkpoint_on_cpu test verifies the checkpoint structure is valid and documents its presence.

The actual fix is still fully tested via the mock-based tests that simulate the CUDA deserialization error and verify the retry logic works correctly.

@stes
Copy link
Member

stes commented Feb 11, 2026

See comment in #295 (comment); I think we need to discuss if there is an issue in the first place: Using map_storage as in torch should do the trick.

So the implementation could be:

  • change the default in CEBRA and become "more sklearn", less torch
  • keep our current default, but improve the error message.

Copy link
Member

@stes stes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

left some minor comments.

In addition:

  • code for generating the test data (on gpu) should be added (see the tests/ folder, this can be a separate utility
  • the logic should be robust to other devices (e.g. mps)
  • the binary files should not go to the repo -- but I can handle this once the PR is ready (and code for generating the checkpoint is available)

If we go for this change to the logic, I would recommend to change the high level signature of the CEBRA.load function; e.g. a default to map_storage could be added.

The value of the current logic is that when a user attempts to use a CEBRA model on a GPU machine and for some reasaon (e.g. failure to use docker with Nvidia, outdated driver, etc) the GPU is not available, we will see an error during load.

On the other hand, an automated mapping to cpu might be more user friendly and should be minimally documented. If we go with the auto-remapping, we should minimally add a warning when this happens though

Comment on lines 90 to 98
if legacy_mode:
checkpoint = torch.load(filename, weights_only=False,
**kwargs)
else:
with torch.serialization.safe_globals(
CEBRA_LOAD_SAFE_GLOBALS):
checkpoint = torch.load(filename,
weights_only=weights_only,
**kwargs)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

duplicate; I recommend to use cebra.load again but adapt the map storage parameter instead, and fail on second attempt

Comment on lines 99 to 102
else:
raise
else:
raise
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

raise a meaningful error here

…lLab#296

- Refactored _safe_torch_load() to use recursion instead of duplicate logic
- Added meaningful error messages when CPU fallback fails
- Added UserWarning when auto-remapping CUDA/MPS to CPU
- Extended _resolve_checkpoint_device() to handle MPS fallback
- Added test for MPS checkpoint fallback
- Added test for meaningful error on retry failure
- Added test for error with explicit map_location
- Created tests/generate_cuda_checkpoint.py utility for GPU test data
- Removed binary checkpoint files from repo
- Updated .gitignore to exclude test checkpoint binaries

All 53 tests pass (14 CUDA/MPS tests + 39 regression tests)
@robosimon robosimon force-pushed the fix/cuda-load-on-cpu branch from 1e6749f to 55f7589 Compare February 11, 2026 19:30
@robosimon
Copy link
Author

@stes Thanks for the review! All comments addressed:

  1. Line 98 refactor: _safe_torch_load() now uses recursion with _is_retry flag, no duplicate logic
  2. Line 102 errors: Added meaningful error messages for both failure scenarios
  3. Warnings: Added UserWarning when auto-remapping to CPU
  4. MPS support: Extended _resolve_checkpoint_device() for Apple Silicon fallback
  5. GPU utility: Created tests/generate_cuda_checkpoint.py for generating test data
  6. Binary files: Removed cuda_saved_checkpoint/ and updated .gitignore

Re: map_location='cpu' default: Demo branch robosimon#3 proves this alone fails (8/8 tests fail). The issue is state['device_'] still says 'cuda' after loading, and model.to('cuda') fails afterwards.

Test results: All 53 tests pass (14 CUDA/MPS + 39 regression). Let me know if any adjustments needed!

Comment on lines 404 to 408
Args:
device: The device from the checkpoint (str or torch.device).

Returns:
str: The resolved device string ('cpu' or validated device).
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pls dont mention types in args/returns. type annotate instead

Comment on lines 410 to 411
if isinstance(device, torch.device):
device = str(device)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not robust. use torch.device type instead of string parsing
https://docs.pytorch.org/docs/stable/tensor_attributes.html#torch.device

f"got {type(device)}.")

fallback_to_cpu = False

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change

args, state, state_dict = cebra_info['args'], cebra_info[
'state'], cebra_info['state_dict']

# Resolve device: use CPU when checkpoint was saved on CUDA but CUDA is not available
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Resolve device: use CPU when checkpoint was saved on CUDA but CUDA is not available

remove comments that are obvious from context

for key, value in state.items():
setattr(cebra_, key, value)

# Update device attributes to the resolved device for the current runtime
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see above

Suggested change
# Update device attributes to the resolved device for the current runtime

Comment on lines 476 to 478
if isinstance(saved_device_str,
str) and saved_device_str.startswith("cuda") and load_device == "cpu":
cebra_.device = "cpu"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see above; lets use torch.device instead of string operations. e.g. instead of startswith you can check the .type of the device

- Add type annotations to _resolve_checkpoint_device: Union[str, torch.device] -> str
- Remove type mentions from docstring Args/Returns
- Use torch.device.type instead of string startswith checks
- Remove obvious comments from _load_cebra_with_sklearn_backend
- Use torch.device for device type checking in load backend

All 10 related tests pass.
@robosimon
Copy link
Author

@stes All review comments addressed:

_resolve_checkpoint_device:

  • Added type annotations:
  • Removed type mentions from docstring Args/Returns
  • Now converts str → torch.device (reversed the logic)
  • Uses / instead of string startswith

_load_cebra_with_sklearn_backend:

  • Removed obvious comments
  • Uses with torch.device instead of string parsing

All 10 related tests pass. Let me know if there's anything else!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Loading a model saved on CUDA fails on CPU-only machines

2 participants