User-defined embedding functions
To use your own custom embedding function, follow these two steps:
- Create your embedding function by implementing the EmbeddingFunction interface.
- Register your embedding function in the global EmbeddingFunctionRegistry.
Let's see what this looks like in action.
EmbeddingFunction and EmbeddingFunctionRegistry handle the low-level details of serializing schema and model information as metadata. To build a custom embedding function, you don't have to worry about these details: focus on setting up the model and leave the rest to LanceDB.
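The register-then-serialize idea can be sketched in plain Python. This is a toy illustration only, not LanceDB's implementation: ToyEmbeddingFunction, ToyRegistry, ConstantEmbeddings and the metadata layout are all hypothetical names invented for this sketch. It shows why registering a class under an alias lets a table store only a small description of the function and rebuild it later.

```python
import json
from typing import Callable, Dict, Type


class ToyEmbeddingFunction:
    """Hypothetical stand-in for an embedding function base class."""

    def __init__(self, **config):
        self.config = config

    def ndims(self) -> int:
        raise NotImplementedError


class ToyRegistry:
    """Hypothetical registry that maps an alias to an embedding function class."""

    def __init__(self):
        self._functions: Dict[str, Type[ToyEmbeddingFunction]] = {}

    def register(self, alias: str) -> Callable:
        def decorator(cls):
            self._functions[alias] = cls
            cls.alias = alias
            return cls

        return decorator

    def get(self, alias: str) -> Type[ToyEmbeddingFunction]:
        return self._functions[alias]

    def to_metadata(self, func: ToyEmbeddingFunction) -> str:
        # Store only the alias and the config; the model is rebuilt on load.
        return json.dumps({"name": func.alias, "model": func.config}, sort_keys=True)

    def from_metadata(self, raw: str) -> ToyEmbeddingFunction:
        record = json.loads(raw)
        return self.get(record["name"])(**record["model"])


registry = ToyRegistry()


@registry.register("constant")
class ConstantEmbeddings(ToyEmbeddingFunction):
    def ndims(self) -> int:
        return self.config.get("dims", 3)


original = ConstantEmbeddings(dims=4)
metadata = registry.to_metadata(original)
restored = registry.from_metadata(metadata)
```

Because only the alias and the configuration are written out, the class must be registered again in any process that reads the metadata back.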
TextEmbeddingFunction interface
There is another optional layer of abstraction available: TextEmbeddingFunction. You can use this abstraction if your model isn't multi-modal and only needs to operate on text. In such cases, source inputs and queries go through the same vectorization step, so you only need to set up the model and the rest is handled by TextEmbeddingFunction. You can read more about the class and its attributes in the class reference.
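A minimal sketch of why a text-only abstraction saves work, assuming nothing about LanceDB's internals: ToyTextEmbeddingFunction and CharCountEmbeddings are hypothetical classes made up for illustration. The base class routes both source data and queries through a single generate_embeddings() method, so a subclass only has to supply the model logic.

```python
from abc import ABC, abstractmethod
from typing import List, Sequence, Union


class ToyTextEmbeddingFunction(ABC):
    """Hypothetical text-only base class; not LanceDB's actual implementation."""

    def compute_query_embeddings(self, query: str) -> List[List[float]]:
        # A query is a single text, so it reuses the source code path.
        return self.compute_source_embeddings([query])

    def compute_source_embeddings(
        self, texts: Union[str, Sequence[str]]
    ) -> List[List[float]]:
        if isinstance(texts, str):
            texts = [texts]
        return self.generate_embeddings(list(texts))

    @abstractmethod
    def generate_embeddings(self, texts: List[str]) -> List[List[float]]:
        ...

    @abstractmethod
    def ndims(self) -> int:
        ...


class CharCountEmbeddings(ToyTextEmbeddingFunction):
    """Toy model: embeds a text as [length, number of vowels]."""

    def generate_embeddings(self, texts: List[str]) -> List[List[float]]:
        return [
            [float(len(t)), float(sum(c in "aeiou" for c in t.lower()))]
            for t in texts
        ]

    def ndims(self) -> int:
        return 2


model = CharCountEmbeddings()
source_vectors = model.compute_source_embeddings(["hello", "world"])
query_vectors = model.compute_query_embeddings("hello")
```

The subclass never touches the query or source entry points; it implements only the two methods the surrounding text names, generate_embeddings() and ndims().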
Let's implement the SentenceTransformerEmbeddings class. All you need to do is implement the generate_embeddings() and ndims() functions to handle the input types you expect, and register the class in the global EmbeddingFunctionRegistry.
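    # Assumption: @cached is the decorator from the cachetools package
    from cachetools import cached

    from lancedb.embeddings import TextEmbeddingFunction, register
    from lancedb.util import attempt_import_or_raise


    @register("sentence-transformers")
    class SentenceTransformerEmbeddings(TextEmbeddingFunction):
        name: str = "all-MiniLM-L6-v2"
        # set more default instance vars like device, etc.

        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self._ndims = None

        def generate_embeddings(self, texts):
            # "..." stands in for any extra encode() options elided here
            return self._embedding_model().encode(list(texts), ...).tolist()

        def ndims(self):
            # Embed one sample text to discover the vector size, then cache it
            if self._ndims is None:
                self._ndims = len(self.generate_embeddings(["foo"])[0])
            return self._ndims

        @cached(cache={})
        def _embedding_model(self):
            # Import lazily so the dependency is only needed when the model is used
            sentence_transformers = attempt_import_or_raise(
                "sentence_transformers", "sentence-transformers"
            )
            return sentence_transformers.SentenceTransformer(self.name)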
    import * as lancedb from "@lancedb/lancedb";
    import {
      LanceSchema,
      TextEmbeddingFunction,
      getRegistry,
      register,
    } from "@lancedb/lancedb/embedding";
    // Assumption: pipeline and FeatureExtractionPipeline come from Transformers.js
    import { type FeatureExtractionPipeline, pipeline } from "@huggingface/transformers";

    @register("sentence-transformers")
    class SentenceTransformersEmbeddings extends TextEmbeddingFunction {
      name = "Xenova/all-miniLM-L6-v2";
      #ndims!: number;
      extractor!: FeatureExtractionPipeline;

      async init() {
        this.extractor = await pipeline("feature-extraction", this.name, {
          dtype: "fp32",
        });
        // Embed one sample text to discover the vector size
        this.#ndims = await this.generateEmbeddings(["hello"]).then(
          (e) => e[0].length,
        );
      }

      ndims() {
        return this.#ndims;
      }

      toJSON() {
        return {
          name: this.name,
        };
      }

      async generateEmbeddings(texts: string[]) {
        const output = await this.extractor(texts, {
          pooling: "mean",
          normalize: true,
        });
        return output.tolist();
      }
    }
This is a stripped-down version of our implementation of SentenceTransformerEmbeddings that removes certain optimizations and default settings.
Use sensitive keys to prevent leaking secrets
To prevent leaking secrets, such as API keys, you should add any sensitive parameters of an embedding function to the output of the sensitive_keys() / getSensitiveKeys() method. This prevents users from accidentally instantiating the embedding function with hard-coded secrets.
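The idea behind sensitive keys can be sketched as follows. This is a self-contained toy, not LanceDB's code: ToyApiEmbeddings, the "$var:" reference convention and the variables dictionary are assumptions made for this illustration. Only the method name sensitive_keys() comes from the text above.

```python
from typing import Dict, List


class ToyApiEmbeddings:
    """Hypothetical embedding function that needs an API key."""

    def __init__(self, variables: Dict[str, str], **config: str):
        for key in self.sensitive_keys():
            value = config.get(key)
            # Reject literal secrets; only indirect references are accepted.
            if value is not None and not value.startswith("$var:"):
                raise ValueError(
                    f"'{key}' is sensitive; pass a '$var:' reference instead"
                )
        self.config = config  # safe to serialize: holds references only
        self._variables = variables  # runtime-only, never serialized

    @staticmethod
    def sensitive_keys() -> List[str]:
        return ["api_key"]

    def resolve(self, key: str) -> str:
        """Return the real value of a config entry, following references."""
        value = self.config[key]
        if value.startswith("$var:"):
            return self._variables[value[len("$var:"):]]
        return value


runtime_vars = {"my_key": "sk-not-a-real-key"}
func = ToyApiEmbeddings(runtime_vars, api_key="$var:my_key", model="toy-model")

try:
    ToyApiEmbeddings(runtime_vars, api_key="sk-hard-coded")
    rejected = False
except ValueError:
    rejected = True
```

In this sketch the serializable config never contains the secret itself, so writing it into table metadata cannot leak the key.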
Now you can use this embedding function to create your table schema, and that's it! You can then ingest data and run queries without manually vectorizing the inputs.
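    import lancedb
    import pandas as pd
    from lancedb.embeddings import EmbeddingFunctionRegistry
    from lancedb.pydantic import LanceModel, Vector

    registry = EmbeddingFunctionRegistry.get_instance()
    stransformer = registry.get("sentence-transformers").create()

    class TextModelSchema(LanceModel):
        # The vector column is filled in automatically from the source column
        vector: Vector(stransformer.ndims()) = stransformer.VectorField()
        text: str = stransformer.SourceField()

    # Example path (an assumption); point this at your own database location
    db = lancedb.connect("/tmp/lancedb")
    tbl = db.create_table("table", schema=TextModelSchema)

    tbl.add(pd.DataFrame({"text": ["halo", "world"]}))
    result = tbl.search("world").limit(5)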
    const registry = getRegistry();

    const sentenceTransformer = await registry
      .get<SentenceTransformersEmbeddings>("sentence-transformers")!
      .create();

    const schema = LanceSchema({
      vector: sentenceTransformer.vectorField(),
      text: sentenceTransformer.sourceField(),
    });

    const db = await lancedb.connect(databaseDir);
    const table = await db.createEmptyTable("table", schema, {
      mode: "overwrite",
    });

    await table.add([{ text: "hello" }, { text: "world" }]);

    const results = await table.search("greeting").limit(1).toArray();
Note
You can always implement the EmbeddingFunction interface directly if you want or need to. TextEmbeddingFunction just makes it simpler and faster to do so by setting up the boilerplate for text-specific use cases.
Multi-modal embedding function example
You can also use the EmbeddingFunction interface to implement more complex workflows, such as multi-modal embedding functions.
LanceDB implements the OpenClipEmbeddings class, which supports multi-modal search. Here's the implementation, which you can use as a reference to build your own multi-modal embedding functions.
    import concurrent.futures
    import io
    import os
    import urllib.parse as urlparse
    from typing import List, Union

    import numpy as np
    import pyarrow as pa
    from pydantic import PrivateAttr

    from lancedb.embeddings import EmbeddingFunction, register
    # Assumption: IMAGES and url_retrieve live in lancedb.embeddings.utils
    from lancedb.embeddings.utils import IMAGES, url_retrieve
    from lancedb.util import attempt_import_or_raise


    @register("open-clip")
    class OpenClipEmbeddings(EmbeddingFunction):
        name: str = "ViT-B-32"
        pretrained: str = "laion2b_s34b_b79k"
        device: str = "cpu"
        batch_size: int = 64
        normalize: bool = True
        _model = PrivateAttr()
        _preprocess = PrivateAttr()
        _tokenizer = PrivateAttr()

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # EmbeddingFunction util to import external libs and raise if not found
            open_clip = attempt_import_or_raise("open_clip", "open-clip")
            model, _, preprocess = open_clip.create_model_and_transforms(
                self.name, pretrained=self.pretrained
            )
            model.to(self.device)
            self._model, self._preprocess = model, preprocess
            self._tokenizer = open_clip.get_tokenizer(self.name)
            self._ndims = None

        def ndims(self):
            if self._ndims is None:
                self._ndims = self.generate_text_embeddings("foo").shape[0]
            return self._ndims

        def compute_query_embeddings(
            self, query: Union[str, "PIL.Image.Image"], *args, **kwargs
        ) -> List[np.ndarray]:
            """
            Compute the embeddings for a given user query

            Parameters
            ----------
            query : Union[str, PIL.Image.Image]
                The query to embed. A query can be either text or an image.
            """
            if isinstance(query, str):
                return [self.generate_text_embeddings(query)]
            else:
                PIL = attempt_import_or_raise("PIL", "pillow")
                if isinstance(query, PIL.Image.Image):
                    return [self.generate_image_embedding(query)]
                else:
                    raise TypeError("OpenClip supports str or PIL Image as query")

        def generate_text_embeddings(self, text: str) -> np.ndarray:
            torch = attempt_import_or_raise("torch")
            text = self.sanitize_input(text)
            text = self._tokenizer(text)
            with torch.no_grad():
                text_features = self._model.encode_text(text.to(self.device))
                if self.normalize:
                    text_features /= text_features.norm(dim=-1, keepdim=True)
                return text_features.cpu().numpy().squeeze()

        def sanitize_input(self, images: IMAGES) -> Union[List[bytes], np.ndarray]:
            """
            Sanitize the input to the embedding function.
            """
            if isinstance(images, (str, bytes)):
                images = [images]
            elif isinstance(images, pa.Array):
                images = images.to_pylist()
            elif isinstance(images, pa.ChunkedArray):
                images = images.combine_chunks().to_pylist()
            return images

        def compute_source_embeddings(
            self, images: IMAGES, *args, **kwargs
        ) -> List[np.ndarray]:
            """
            Get the embeddings for the given images
            """
            images = self.sanitize_input(images)
            embeddings = []
            # Process the images in batches of batch_size
            for i in range(0, len(images), self.batch_size):
                j = min(i + self.batch_size, len(images))
                batch = images[i:j]
                embeddings.extend(self._parallel_get(batch))
            return embeddings

        def _parallel_get(
            self, images: Union[List[str], List[bytes]]
        ) -> List[np.ndarray]:
            """
            Issue concurrent requests to retrieve the image data
            """
            with concurrent.futures.ThreadPoolExecutor() as executor:
                futures = [
                    executor.submit(self.generate_image_embedding, image)
                    for image in images
                ]
                return [future.result() for future in futures]

        def generate_image_embedding(
            self, image: Union[str, bytes, "PIL.Image.Image"]
        ) -> np.ndarray:
            """
            Generate the embedding for a single image

            Parameters
            ----------
            image : Union[str, bytes, PIL.Image.Image]
                The image to embed. If the image is a str, it is treated as a uri.
                If the image is bytes, it is treated as the raw image bytes.
            """
            torch = attempt_import_or_raise("torch")
            # TODO handle retry and errors for https
            image = self._to_pil(image)
            image = self._preprocess(image).unsqueeze(0)
            with torch.no_grad():
                return self._encode_and_normalize_image(image)

        def _to_pil(self, image: Union[str, bytes]):
            PIL = attempt_import_or_raise("PIL", "pillow")
            if isinstance(image, bytes):
                return PIL.Image.open(io.BytesIO(image))
            if isinstance(image, PIL.Image.Image):
                return image
            elif isinstance(image, str):
                parsed = urlparse.urlparse(image)
                # TODO handle drive letter on windows.
                if parsed.scheme == "file":
                    return PIL.Image.open(parsed.path)
                elif parsed.scheme == "":
                    return PIL.Image.open(image if os.name == "nt" else parsed.path)
                elif parsed.scheme.startswith("http"):
                    return PIL.Image.open(io.BytesIO(url_retrieve(image)))
                else:
                    raise NotImplementedError(
                        "Only local and http(s) urls are supported"
                    )

        def _encode_and_normalize_image(self, image_tensor: "torch.Tensor"):
            """
            encode a single image tensor and optionally normalize the output
            """
            # Move the tensor to the model's device before encoding
            image_features = self._model.encode_image(image_tensor.to(self.device))
            if self.normalize:
                image_features /= image_features.norm(dim=-1, keepdim=True)
            return image_features.cpu().numpy().squeeze()
Coming soon! See this issue to track the status.
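The Python reference implementation above combines three ideas: branching on the input type, splitting the inputs into batches, and embedding each batch concurrently. The following dependency-free sketch isolates that control flow. Everything in it is hypothetical (embed_one, embed_all, and the toy "encoders" that just measure input length); it uses str to stand in for text and bytes to stand in for raw image data, and it is not how OpenCLIP computes embeddings.

```python
import math
from concurrent.futures import ThreadPoolExecutor
from typing import List, Sequence, Union

Item = Union[str, bytes]  # str stands in for text, bytes for raw image data


def l2_normalize(vector: List[float]) -> List[float]:
    """Scale a vector to unit length; leave an all-zero vector unchanged."""
    norm = math.sqrt(sum(x * x for x in vector))
    return [x / norm for x in vector] if norm else vector


def embed_one(item: Item) -> List[float]:
    """Hypothetical per-item encoder that branches on the input type."""
    if isinstance(item, str):
        raw = [float(len(item)), 0.0]  # toy "text encoder"
    elif isinstance(item, bytes):
        raw = [0.0, float(len(item))]  # toy "image encoder"
    else:
        raise TypeError("only str or bytes inputs are supported")
    return l2_normalize(raw)


def embed_all(items: Sequence[Item], batch_size: int = 2) -> List[List[float]]:
    """Embed items batch by batch, running each batch concurrently."""
    embeddings: List[List[float]] = []
    with ThreadPoolExecutor() as executor:
        for start in range(0, len(items), batch_size):
            batch = items[start:start + batch_size]
            # executor.map preserves input order, so results line up with items.
            embeddings.extend(executor.map(embed_one, batch))
    return embeddings


vectors = embed_all(["cat", b"\x00\x01", "a photo of a dog"])
```

Both modalities end up in the same vector space and with the same dimensionality, which is what allows a text query to be compared against image embeddings in a multi-modal setup.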