diff --git a/setup.py b/setup.py index 3e1a416..905acc3 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,7 @@ setup( name="adaptive-classifier", - version="0.1.1", + version="0.1.2", author="codelion", author_email="codelion@okyasoft.com", description="A flexible, adaptive classification system for dynamic text classification", diff --git a/src/adaptive_classifier/__init__.py b/src/adaptive_classifier/__init__.py index 1fa7f18..1ba635d 100644 --- a/src/adaptive_classifier/__init__.py +++ b/src/adaptive_classifier/__init__.py @@ -4,7 +4,7 @@ from .multilabel import MultiLabelAdaptiveClassifier, MultiLabelAdaptiveHead from huggingface_hub import ModelHubMixin -__version__ = "0.1.1" +__version__ = "0.1.2" __all__ = [ "AdaptiveClassifier", diff --git a/src/adaptive_classifier/classifier.py b/src/adaptive_classifier/classifier.py index 35e4352..1c5ca9f 100644 --- a/src/adaptive_classifier/classifier.py +++ b/src/adaptive_classifier/classifier.py @@ -33,7 +33,8 @@ def __init__( device: Optional[str] = None, config: Optional[Dict[str, Any]] = None, seed: int = 42, # Add seed parameter - use_onnx: Optional[Union[bool, str]] = "auto" # "auto", True, False + use_onnx: Optional[Union[bool, str]] = "auto", # "auto", True, False + trust_remote_code: bool = False ): """Initialize the adaptive classifier. @@ -44,6 +45,7 @@ def __init__( seed: Random seed for initialization use_onnx: Whether to use ONNX Runtime ("auto", True, False). "auto" uses ONNX on CPU, PyTorch on GPU. + trust_remote_code: Whether to trust remote code when loading models (default: False) """ # Set seed for initialization torch.manual_seed(seed) @@ -60,7 +62,8 @@ def __init__( logger.info(f"Initializing ONNX model for {model_name}") self.model = ORTModelForFeatureExtraction.from_pretrained( model_name, - export=True # Auto-export to ONNX if not already in ONNX format + export=True, # Auto-export to ONNX if not already in ONNX format + trust_remote_code=trust_remote_code ) logger.info("Successfully loaded ONNX model") except ImportError: @@ -69,17 +72,17 @@ def __init__( "Install with: pip install optimum[onnxruntime]" ) self.use_onnx = False - self.model = AutoModel.from_pretrained(model_name).to(self.device) + self.model = AutoModel.from_pretrained(model_name, trust_remote_code=trust_remote_code).to(self.device) except Exception as e: logger.warning( f"Failed to load ONNX model: {e}. Falling back to PyTorch." ) self.use_onnx = False - self.model = AutoModel.from_pretrained(model_name).to(self.device) + self.model = AutoModel.from_pretrained(model_name, trust_remote_code=trust_remote_code).to(self.device) else: - self.model = AutoModel.from_pretrained(model_name).to(self.device) + self.model = AutoModel.from_pretrained(model_name, trust_remote_code=trust_remote_code).to(self.device) - self.tokenizer = AutoTokenizer.from_pretrained(model_name) + self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=trust_remote_code) # Initialize memory system self.embedding_dim = self.model.config.hidden_size @@ -637,6 +640,7 @@ def _from_pretrained( token: Optional[Union[str, bool]] = None, use_onnx: Optional[Union[bool, str]] = "auto", prefer_quantized: bool = True, + trust_remote_code: bool = False, **kwargs ) -> "AdaptiveClassifier": """Load a model from the HuggingFace Hub or local directory. @@ -653,6 +657,7 @@ def _from_pretrained( use_onnx: Whether to use ONNX Runtime ("auto", True, False) prefer_quantized: Use quantized ONNX model if available (default: True) Set to False to use unquantized model for maximum accuracy + trust_remote_code: Whether to trust remote code when loading models (default: False) **kwargs: Additional arguments passed to from_pretrained Returns: @@ -667,6 +672,9 @@ def _from_pretrained( >>> >>> # Force PyTorch (no ONNX) >>> classifier = AdaptiveClassifier.load("adaptive-classifier/llm-router", use_onnx=False) + >>> + >>> # Load model requiring custom code + >>> classifier = AdaptiveClassifier.load("model-with-custom-code", trust_remote_code=True) """ # Check if model_id is a local directory @@ -814,9 +822,10 @@ def _from_pretrained( classifier.model = ORTModelForFeatureExtraction.from_pretrained( onnx_path, - file_name=onnx_file + file_name=onnx_file, + trust_remote_code=trust_remote_code ) - classifier.tokenizer = AutoTokenizer.from_pretrained(config_dict['model_name']) + classifier.tokenizer = AutoTokenizer.from_pretrained(config_dict['model_name'], trust_remote_code=trust_remote_code) # Initialize memory and other components classifier.embedding_dim = classifier.model.config.hidden_size @@ -852,7 +861,8 @@ def _from_pretrained( config_dict['model_name'], device=device, config=config_dict.get('config', None), - use_onnx=final_use_onnx if isinstance(final_use_onnx, bool) else False + use_onnx=final_use_onnx if isinstance(final_use_onnx, bool) else False, + trust_remote_code=trust_remote_code ) # Restore label mappings @@ -1187,7 +1197,7 @@ def save(self, save_dir: str, include_onnx: bool = True, quantize_onnx: bool = T ) @classmethod - def load(cls, save_dir: str, device: Optional[str] = None, use_onnx: Optional[Union[bool, str]] = "auto", prefer_quantized: bool = True) -> 'AdaptiveClassifier': + def load(cls, save_dir: str, device: Optional[str] = None, use_onnx: Optional[Union[bool, str]] = "auto", prefer_quantized: bool = True, trust_remote_code: bool = False) -> 'AdaptiveClassifier': """Legacy load method for backwards compatibility. Args: @@ -1195,11 +1205,12 @@ def load(cls, save_dir: str, device: Optional[str] = None, use_onnx: Optional[Un device: Device to load model on use_onnx: Whether to use ONNX Runtime ("auto", True, False) prefer_quantized: Use quantized ONNX model if available (default: True) + trust_remote_code: Whether to trust remote code when loading models (default: False) """ kwargs = {} if device is not None: kwargs['device'] = device - return cls._from_pretrained(save_dir, use_onnx=use_onnx, prefer_quantized=prefer_quantized, **kwargs) + return cls._from_pretrained(save_dir, use_onnx=use_onnx, prefer_quantized=prefer_quantized, trust_remote_code=trust_remote_code, **kwargs) def to(self, device: str) -> 'AdaptiveClassifier': """Move the model to specified device.