openai/CLIP


[Blog] [Paper] [Model Card] [Colab]

CLIP (Contrastive Language-Image Pre-Training) is a neural network trained on a variety of (image, text) pairs. It can be instructed in natural language to predict the most relevant text snippet given an image, without directly optimizing for the task, similarly to the zero-shot capabilities of GPT-2 and GPT-3. We found CLIP matches the performance of the original ResNet50 on ImageNet "zero-shot", without using any of the original 1.28M labeled examples, overcoming several major challenges in computer vision.

Approach

[Figure: overview of the CLIP approach]
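As a rough, self-contained sketch of the symmetric contrastive objective the figure summarizes (illustrative only, not the repository's training code; the batch size, embedding dimension, and random tensors below are placeholders for real encoder outputs):

import torch
import torch.nn.functional as F

# Placeholder batch: in training these would be the projected outputs of the
# image encoder and text encoder for n aligned (image, text) pairs.
n, d = 8, 512
image_features = F.normalize(torch.randn(n, d), dim=-1)
text_features = F.normalize(torch.randn(n, d), dim=-1)
logit_scale = torch.tensor(100.0)  # learned temperature; roughly 100 in released models

# Scaled pairwise cosine similarities [n, n]
logits_per_image = logit_scale * image_features @ text_features.T
logits_per_text = logits_per_image.T

# The i-th image matches the i-th text; all other pairs in the batch are negatives.
labels = torch.arange(n)
loss = (F.cross_entropy(logits_per_image, labels) +
        F.cross_entropy(logits_per_text, labels)) / 2
print(loss.item())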

Usage

First, install PyTorch 1.7.1 (or later) and torchvision, as well as small additional dependencies, and then install this repo as a Python package. On a CUDA GPU machine, the following will do the trick:

$ conda install --yes -c pytorch pytorch=1.7.1 torchvision cudatoolkit=11.0
$ pip install ftfy regex tqdm
$ pip install git+https://github.com/openai/CLIP.git

Replace cudatoolkit=11.0 above with the appropriate CUDA version on your machine, or cpuonly when installing on a machine without a GPU.

import torch
import clip
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)
text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)

    logits_per_image, logits_per_text = model(image, text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

print("Label probs:", probs)  # prints: [[0.9927937  0.00421068 0.00299572]]

API

The CLIP module clip provides the following methods:

clip.available_models()

Returns the names of the available CLIP models.

clip.load(name, device=..., jit=False)

Returns the model and the TorchVision transform needed by the model, specified by the model name returned by clip.available_models(). It will download the model as necessary. The name argument can also be a path to a local checkpoint.

The device on which to run the model can be optionally specified; the default is the first CUDA device if one is available, otherwise the CPU. When jit is False, a non-JIT version of the model will be loaded.

clip.tokenize(text: Union[str, List[str]], context_length=77)

Returns a LongTensor containing tokenized sequences of the given text input(s). This can be used as the input to the model.
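For illustration, the three functions compose as follows (the model name and prompt strings are arbitrary examples; the exact list printed by clip.available_models() depends on the installed version):

import torch
import clip

print(clip.available_models())  # e.g. ['RN50', ..., 'ViT-B/32', 'ViT-B/16', ...]

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

# Tokenize a batch of prompts; each row is padded to context_length tokens.
tokens = clip.tokenize(["a photo of a cat", "a photo of a dog"])
print(tokens.shape)  # torch.Size([2, 77]) with the default context_length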


The model returned by clip.load() supports the following methods:

model.encode_image(image: Tensor)

Given a batch of images, returns the image features encoded by the vision portion of the CLIP model.

model.encode_text(text: Tensor)

Given a batch of text tokens, returns the text features encoded by the language portion of the CLIP model.

model(image: Tensor, text: Tensor)

Given a batch of images and a batch of text tokens, returns two Tensors, containing the logit scores corresponding to each image and text input. The values are cosine similarities between the corresponding image and text features, times 100.
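For example, the forward pass can be cross-checked against the encoded features directly (a sketch assuming an image file named CLIP.png, as in the quick-start example above; the two printed tensors should agree up to floating-point precision):

import torch
import clip
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)
text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)   # [1, feature_dim]
    text_features = model.encode_text(text)      # [3, feature_dim]
    logits_per_image, logits_per_text = model(image, text)

    # Reproduce the logits by hand: cosine similarity of the features, times 100
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    print(logits_per_image)
    print(100.0 * image_features @ text_features.T)  # should match the line above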

More Examples

Zero-Shot Prediction

The code below performs zero-shot prediction using CLIP, as shown in Appendix B in the paper. This example takes an image from the CIFAR-100 dataset, and predicts the most likely labels among the 100 textual labels from the dataset.

import os
import clip
import torch
from torchvision.datasets import CIFAR100

# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)

# Download the dataset
cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)

# Prepare the inputs
image, class_id = cifar100[3637]
image_input = preprocess(image).unsqueeze(0).to(device)
text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device)

# Calculate features
with torch.no_grad():
    image_features = model.encode_image(image_input)
    text_features = model.encode_text(text_inputs)

# Pick the top 5 most similar labels for the image
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
values, indices = similarity[0].topk(5)

# Print the result
print("\nTop predictions:\n")
for value, index in zip(values, indices):
    print(f"{cifar100.classes[index]:>16s}: {100 * value.item():.2f}%")

The output will look like the following (the exact numbers may be slightly different depending on the compute device):

Top predictions:

           snake: 65.31%
          turtle: 12.29%
    sweet_pepper: 3.83%
          lizard: 1.88%
       crocodile: 1.75%

Note that this example uses the encode_image() and encode_text() methods, which return the encoded features of the given inputs.

Linear-probe evaluation

The example below uses scikit-learn to perform logistic regression on image features.

import os
import clip
import torch

import numpy as np
from sklearn.linear_model import LogisticRegression
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR100
from tqdm import tqdm

# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)

# Load the dataset
root = os.path.expanduser("~/.cache")
train = CIFAR100(root, download=True, train=True, transform=preprocess)
test = CIFAR100(root, download=True, train=False, transform=preprocess)


def get_features(dataset):
    all_features = []
    all_labels = []

    with torch.no_grad():
        for images, labels in tqdm(DataLoader(dataset, batch_size=100)):
            features = model.encode_image(images.to(device))

            all_features.append(features)
            all_labels.append(labels)

    return torch.cat(all_features).cpu().numpy(), torch.cat(all_labels).cpu().numpy()


# Calculate the image features
train_features, train_labels = get_features(train)
test_features, test_labels = get_features(test)

# Perform logistic regression
classifier = LogisticRegression(random_state=0, C=0.316, max_iter=1000, verbose=1)
classifier.fit(train_features, train_labels)

# Evaluate using the logistic regression classifier
predictions = classifier.predict(test_features)
accuracy = np.mean((test_labels == predictions).astype(float)) * 100.
print(f"Accuracy = {accuracy:.3f}")

Note that the C value should be determined via a hyperparameter sweep using a validation split.
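A minimal sketch of such a sweep, reusing train_features and train_labels from the example above (the candidate C grid and the 90/10 split are arbitrary choices, not values from the paper):

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# Hold out part of the training features as a validation split.
tr_x, val_x, tr_y, val_y = train_test_split(
    train_features, train_labels, test_size=0.1, random_state=0)

best_c, best_acc = None, -1.0
for c in [0.001, 0.01, 0.1, 0.316, 1.0, 10.0]:
    clf = LogisticRegression(random_state=0, C=c, max_iter=1000)
    clf.fit(tr_x, tr_y)
    acc = np.mean(clf.predict(val_x) == val_y) * 100.
    if acc > best_acc:
        best_c, best_acc = c, acc

print(f"Best C = {best_c} (validation accuracy = {best_acc:.3f}%)")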

See Also
