From aa05f5f083a2612a1d687a910057e5a303952cf4 Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Fri, 8 Nov 2024 13:28:11 +0100 Subject: [PATCH] tl.mofa: add train_kwargs argument --- muon/_core/tools.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/muon/_core/tools.py b/muon/_core/tools.py index d80c9fa..e281f41 100644 --- a/muon/_core/tools.py +++ b/muon/_core/tools.py @@ -308,6 +308,7 @@ def mofa( use_float32: bool = False, gpu_mode: bool = False, gpu_device: Optional[bool] = None, + train_kwargs: Optional[Mapping[str, Any]] = None, svi_mode: bool = False, svi_batch_size: float = 0.5, svi_learning_rate: float = 1.0, @@ -370,8 +371,11 @@ def mofa( use reduced precision (float32) gpu_mode : optional if to use GPU mode - gpu_mode : optional + gpu_device : optional which GPU device to use + train_kwargs: optional + additional parameters for MOFA (startELBO, freqELBO, startSparsity, tolerance, startDrop, freqDrop, + dropR2, nostop, schedule, weight_views) svi_mode : optional if to use Stochastic Variational Inference (SVI) svi_batch_size : optional @@ -489,6 +493,8 @@ def mofa( ) logging.info(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Setting training options...") + if train_kwargs is None: + train_kwargs = {} try: ent.set_train_options( iter=n_iterations, @@ -500,6 +506,7 @@ def mofa( quiet=quiet, outfile=outfile, save_interrupted=save_interrupted, + **train_kwargs, ) except TypeError: # mofapy2 <0.7 does not have a gpu_device argument @@ -516,6 +523,7 @@ def mofa( quiet=quiet, outfile=outfile, save_interrupted=save_interrupted, + **train_kwargs, ) if svi_mode: