diff --git a/README.md b/README.md index 06f052b..971355a 100644 --- a/README.md +++ b/README.md @@ -123,6 +123,7 @@ semantra [OPTIONS] [FILENAME(S)]... - `--model [openai|minilm|mpnet|sgpt|sgpt-1.3B]`: Preset model to use for embedding. See [the models guide](docs/guide_models.md) for more info (default: mpnet) - `--transformer-model TEXT`: Custom Huggingface transformers model name to use for embedding (only one of `--model` and `--transformer-model` should be specified). See [the models guide](docs/guide_models.md) for more info +- `--cpu`: Run local transformers models on CPU even if CUDA is available - `--windows TEXT`: Embedding windows to extract. A comma-separated list of the format "size[\_offset=0][_rewind=0]. A window with size 128, offset 0, and rewind of 16 (128_0_16) will embed the document in chunks of 128 tokens which partially overlap by 16. Only the first window is used for search. See the [windows concept doc](docs/concept_windows.md) for more information (default: 128_0_16) - `--encoding`: Encoding to use for reading text files [default: utf-8] - `--no-server`: Do not start the UI server (only process) diff --git a/docs/guide_models.md b/docs/guide_models.md index 581e536..9c50d69 100644 --- a/docs/guide_models.md +++ b/docs/guide_models.md @@ -2,6 +2,12 @@ Semantra comes with a few preset models along with the ability to run almost any custom [Hugging Face](https://huggingface.co/) [transformers](https://huggingface.co/docs/transformers/index) model. If your computer has a compatible GPU (graphics processing unit, often found in video cards), Semantra will leverage it via [PyTorch](https://pytorch.org/) to dramatically speed up computation. 
+If PyTorch detects a CUDA device that is incompatible with the installed PyTorch build, you can force local transformers models to run on CPU with `--cpu`: + +```sh +semantra --cpu doc.pdf +``` + ## Using preset models The models Semantra comes with out-of-the-box include: diff --git a/src/semantra/models.py b/src/semantra/models.py index 5ff77b1..d0a0925 100644 --- a/src/semantra/models.py +++ b/src/semantra/models.py @@ -317,7 +317,7 @@ def embed(self, tokens, offsets, is_query=False) -> "list[list[float]]": "cost_per_token": 0.0004 / 1000, "pool_size": 50000, "pool_count": 2000, - "get_model": lambda: OpenAIModel( + "get_model": lambda cuda=None: OpenAIModel( model_name="text-embedding-ada-002", num_dimensions=1536, tokenizer_name="cl100k_base", @@ -326,35 +326,43 @@ def embed(self, tokens, offsets, is_query=False) -> "list[list[float]]": "minilm": { "cost_per_token": None, "pool_size": 50000, - "get_model": lambda: TransformerModel(model_name=minilm_model_name), + "get_model": lambda cuda=None: TransformerModel( + model_name=minilm_model_name, + cuda=cuda, + ), }, "mpnet": { "cost_per_token": None, "pool_size": 15000, - "get_model": lambda: TransformerModel(model_name=mpnet_model_name), + "get_model": lambda cuda=None: TransformerModel( + model_name=mpnet_model_name, + cuda=cuda, + ), }, "sgpt": { "cost_per_token": None, "pool_size": 10000, - "get_model": lambda: TransformerModel( + "get_model": lambda cuda=None: TransformerModel( model_name=sgpt_model_name, query_token_pre="[", query_token_post="]", doc_token_pre="{", doc_token_post="}", asymmetric=True, + cuda=cuda, ), }, "sgpt-1.3B": { "cost_per_token": None, "pool_size": 1000, - "get_model": lambda: TransformerModel( + "get_model": lambda cuda=None: TransformerModel( model_name=sgpt_1_3B_model_name, query_token_pre="[", query_token_post="]", doc_token_pre="{", doc_token_post="}", asymmetric=True, + cuda=cuda, ), }, } diff --git a/src/semantra/semantra.py b/src/semantra/semantra.py index 9636117..cade2de 100644 --- 
a/src/semantra/semantra.py +++ b/src/semantra/semantra.py @@ -367,6 +367,12 @@ def process_windows(windows: str) -> "list[tuple[int, int, int]]": type=str, help="Custom Huggingface transformers model name to use for embedding", ) +@click.option( + "--cpu", + is_flag=True, + default=False, + help="Run local transformers models on CPU even if CUDA is available", +) @click.option( "--windows", type=str, @@ -538,6 +544,7 @@ def main( doc_token_post=None, query_token_pre=None, query_token_post=None, + cpu=False, model="mpnet", transformer_model=None, encoding=DEFAULT_ENCODING, @@ -581,6 +588,7 @@ def main( raise click.UsageError("Must provide a filename to process/query") processed_windows = list(process_windows(windows)) + cuda = False if cpu else None if transformer_model is not None: # Handle custom transformers model @@ -594,6 +602,7 @@ def main( doc_token_post=doc_token_post, query_token_pre=query_token_pre, query_token_post=query_token_post, + cuda=cuda, ) else: # Pull preset model @@ -603,7 +612,7 @@ def main( pool_size = model_config["pool_size"] if pool_count is None: pool_count = model_config.get("pool_count", None) - model: BaseModel = model_config["get_model"]() + model: BaseModel = model_config["get_model"](cuda=cuda) # Check if model is compatible if svm and model.is_asymmetric():