@@ -527,8 +527,8 @@ def rank(transformer, query, documents, kwargs):
     return rank_using(model, query, documents, kwargs)


-def create_embedding(transformer):
-    return SentenceTransformer(transformer)
+def create_embedding(transformer, kwargs):
+    return SentenceTransformer(transformer, **kwargs)


 def embed_using(model, transformer, inputs, kwargs):
@@ -545,16 +545,32 @@ def embed_using(model, transformer, inputs, kwargs):

 def embed(transformer, inputs, kwargs):
     kwargs = orjson.loads(kwargs)
-
     ensure_device(kwargs)

+    init_kwarg_keys = [
+        "device",
+        "trust_remote_code",
+        "revision",
+        "model_kwargs",
+        "tokenizer_kwargs",
+        "config_kwargs",
+        "truncate_dim",
+        "token",
+    ]
+    init_kwargs = {
+        key: value for key, value in kwargs.items() if key in init_kwarg_keys
+    }
+    encode_kwargs = {
+        key: value for key, value in kwargs.items() if key not in init_kwarg_keys
+    }
+
     if transformer not in __cache_sentence_transformer_by_name:
         __cache_sentence_transformer_by_name[transformer] = create_embedding(
-            transformer
+            transformer, init_kwargs
         )
     model = __cache_sentence_transformer_by_name[transformer]

-    return embed_using(model, transformer, inputs, kwargs)
+    return embed_using(model, transformer, inputs, encode_kwargs)


 def clear_gpu_cache(memory_usage: None):
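For illustration, here is a minimal, self-contained sketch of the kwargs partitioning this hunk introduces; the helper `split_kwargs` and the sample payload are hypothetical, not part of the patch. Keys listed in `init_kwarg_keys` are routed to the `SentenceTransformer` constructor, and everything else is forwarded to the encode call:

```python
# Hypothetical standalone sketch of the init/encode kwargs split above.
import orjson

# Mirrors init_kwarg_keys from the patch: options accepted by the
# SentenceTransformer constructor rather than by encode().
INIT_KWARG_KEYS = {
    "device",
    "trust_remote_code",
    "revision",
    "model_kwargs",
    "tokenizer_kwargs",
    "config_kwargs",
    "truncate_dim",
    "token",
}


def split_kwargs(raw: bytes) -> tuple[dict, dict]:
    """Partition a JSON kwargs payload into constructor vs. encode kwargs."""
    kwargs = orjson.loads(raw)
    init_kwargs = {k: v for k, v in kwargs.items() if k in INIT_KWARG_KEYS}
    encode_kwargs = {k: v for k, v in kwargs.items() if k not in INIT_KWARG_KEYS}
    return init_kwargs, encode_kwargs


init_kwargs, encode_kwargs = split_kwargs(
    b'{"device": "cuda", "trust_remote_code": true, "normalize_embeddings": true}'
)
assert init_kwargs == {"device": "cuda", "trust_remote_code": True}
assert encode_kwargs == {"normalize_embeddings": True}
```

One design note worth flagging: `__cache_sentence_transformer_by_name` is keyed only by the transformer name, so a second call with different `init_kwargs` for the same transformer appears to reuse the model instance created by the first call.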