Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion muon/_core/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,7 @@ def mofa(
use_float32: bool = False,
gpu_mode: bool = False,
gpu_device: Optional[bool] = None,
train_kwargs: Optional[Mapping[str, Any]] = None,
svi_mode: bool = False,
svi_batch_size: float = 0.5,
svi_learning_rate: float = 1.0,
Expand Down Expand Up @@ -370,8 +371,11 @@ def mofa(
use reduced precision (float32)
gpu_mode : optional
if to use GPU mode
gpu_mode : optional
gpu_device : optional
which GPU device to use
train_kwargs: optional
additional parameters for MOFA (startELBO, freqELBO, startSparsity, tolerance, startDrop, freqDrop,
dropR2, nostop, schedule, weight_views)
svi_mode : optional
if to use Stochastic Variational Inference (SVI)
svi_batch_size : optional
Expand Down Expand Up @@ -489,6 +493,8 @@ def mofa(
)
logging.info(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Setting training options...")

if train_kwargs is None:
train_kwargs = {}
try:
ent.set_train_options(
iter=n_iterations,
Expand All @@ -500,6 +506,7 @@ def mofa(
quiet=quiet,
outfile=outfile,
save_interrupted=save_interrupted,
**train_kwargs,
)
except TypeError:
# mofapy2 <0.7 does not have a gpu_device argument
Expand All @@ -516,6 +523,7 @@ def mofa(
quiet=quiet,
outfile=outfile,
save_interrupted=save_interrupted,
**train_kwargs,
)

if svi_mode:
Expand Down