Vision Transformer (ViT)
Vision Transformer (ViT) is a transformer adapted for computer vision tasks. An image is split into smaller, fixed-size patches that are treated as a sequence of tokens, similar to words in NLP tasks. ViT requires fewer resources to pretrain than convolutional architectures, and its performance on large datasets transfers well to smaller downstream tasks.
You can find all the original ViT checkpoints under the Google organization.
Click on the ViT models in the right sidebar for more examples of how to apply ViT to different computer vision tasks.
The example below demonstrates how to classify an image with Pipeline or the AutoModel class.
import torch
from transformers import pipeline

pipeline = pipeline(
    task="image-classification",
    model="google/vit-base-patch16-224",
    torch_dtype=torch.float16,
    device=0
)
pipeline(images="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg")
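For the AutoModel approach, a minimal sketch is shown below; it reuses the same checkpoint and image URL as the Pipeline example.

```py
# A minimal AutoModel sketch reusing the checkpoint and image from the pipeline example.
import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
image = Image.open(requests.get(url, stream=True).raw)

processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224")

inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

predicted_label = logits.argmax(-1).item()
print(model.config.id2label[predicted_label])
```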
Notes
- The best results are obtained with supervised pretraining, and during fine-tuning, it may be better to use images with a resolution higher than 224x224.
- Use ViTImageProcessorFast to resize (or rescale) and normalize images to the expected size (a short sketch follows this list).
- The patch and image resolution are reflected in the checkpoint name. For example, google/vit-base-patch16-224 is the base-sized architecture with a patch resolution of 16x16 and a fine-tuning resolution of 224x224.
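A minimal preprocessing sketch with ViTImageProcessorFast; the randomly generated image is only a stand-in for real data.

```py
# A minimal sketch: ViTImageProcessorFast resizes, rescales, and normalizes an
# image to the 224x224 input the checkpoint expects. The random image is a stand-in.
import numpy as np
from PIL import Image
from transformers import ViTImageProcessorFast

processor = ViTImageProcessorFast.from_pretrained("google/vit-base-patch16-224")
image = Image.fromarray(np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8))

inputs = processor(images=image, return_tensors="pt")
print(inputs["pixel_values"].shape)  # torch.Size([1, 3, 224, 224])
```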
ViTConfig
class transformers.ViTConfig
<source>( hidden_size = 768, num_hidden_layers = 12, num_attention_heads = 12, intermediate_size = 3072, hidden_act = 'gelu', hidden_dropout_prob = 0.0, attention_probs_dropout_prob = 0.0, initializer_range = 0.02, layer_norm_eps = 1e-12, image_size = 224, patch_size = 16, num_channels = 3, qkv_bias = True, encoder_stride = 16, pooler_output_size = None, pooler_act = 'tanh', **kwargs )
Parameters
- hidden_size (`int`, *optional*, defaults to 768) — Dimensionality of the encoder layers and the pooler layer.
- num_hidden_layers (`int`, *optional*, defaults to 12) — Number of hidden layers in the Transformer encoder.
- num_attention_heads (`int`, *optional*, defaults to 12) — Number of attention heads for each attention layer in the Transformer encoder.
- intermediate_size (`int`, *optional*, defaults to 3072) — Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
- hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`) — The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
- hidden_dropout_prob (`float`, *optional*, defaults to 0.0) — The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
- attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0) — The dropout ratio for the attention probabilities.
- initializer_range (`float`, *optional*, defaults to 0.02) — The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
- layer_norm_eps (`float`, *optional*, defaults to 1e-12) — The epsilon used by the layer normalization layers.
- image_size (`int`, *optional*, defaults to 224) — The size (resolution) of each image.
- patch_size (`int`, *optional*, defaults to 16) — The size (resolution) of each patch.
- num_channels (`int`, *optional*, defaults to 3) — The number of input channels.
- qkv_bias (`bool`, *optional*, defaults to `True`) — Whether to add a bias to the queries, keys and values.
- encoder_stride (`int`, *optional*, defaults to 16) — Factor to increase the spatial resolution by in the decoder head for masked image modeling.
- pooler_output_size (`int`, *optional*) — Dimensionality of the pooler layer. If `None`, defaults to `hidden_size`.
- pooler_act (`str`, *optional*, defaults to `"tanh"`) — The activation function to be used by the pooler. Keys of ACT2FN are supported for Flax and PyTorch, and elements of https://www.tensorflow.org/api_docs/python/tf/keras/activations are supported for TensorFlow.
This is the configuration class to store the configuration of a ViTModel. It is used to instantiate a ViT model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a configuration similar to that of the ViT google/vit-base-patch16-224 architecture.
Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.
Example:
>>> from transformers import ViTConfig, ViTModel

>>> # Initializing a ViT vit-base-patch16-224 style configuration
>>> configuration = ViTConfig()

>>> # Initializing a model (with random weights) from the vit-base-patch16-224 style configuration
>>> model = ViTModel(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config
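The configuration arguments above can also be overridden to define a custom, randomly initialized architecture. A small illustrative sketch follows; the chosen values are arbitrary.

```py
# An illustrative sketch: a randomly initialized ViT with a custom patch and image size.
from transformers import ViTConfig, ViTModel

config = ViTConfig(image_size=384, patch_size=32, hidden_size=768, num_hidden_layers=6)
model = ViTModel(config)

# 384/32 = 12 patches per side -> 144 patches + 1 [CLS] token = 145 positions
print(model.embeddings.position_embeddings.shape)  # torch.Size([1, 145, 768])
```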
ViTFeatureExtractor
class transformers.ViTFeatureExtractor
<source>( *args, **kwargs )
__call__
<source>( images, **kwargs )
Preprocess an image or a batch of images.
ViTImageProcessor
class transformers.ViTImageProcessor
<source>( do_resize: bool = True, size: typing.Optional[dict[str, int]] = None, resample: Resampling = Resampling.BILINEAR, do_rescale: bool = True, rescale_factor: typing.Union[int, float] = 0.00392156862745098, do_normalize: bool = True, image_mean: typing.Union[float, list[float], NoneType] = None, image_std: typing.Union[float, list[float], NoneType] = None, do_convert_rgb: typing.Optional[bool] = None, **kwargs )
Parameters
- do_resize (`bool`, *optional*, defaults to `True`) — Whether to resize the image's (height, width) dimensions to the specified `(size["height"], size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method.
- size (`dict`, *optional*, defaults to `{"height": 224, "width": 224}`) — Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess` method.
- resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`) — Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the `preprocess` method.
- do_rescale (`bool`, *optional*, defaults to `True`) — Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale` parameter in the `preprocess` method.
- rescale_factor (`int` or `float`, *optional*, defaults to `1/255`) — Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the `preprocess` method.
- do_normalize (`bool`, *optional*, defaults to `True`) — Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess` method.
- image_mean (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`) — Mean to use if normalizing the image. This is a float or list of floats the length of the number of channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
- image_std (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`) — Standard deviation to use if normalizing the image. This is a float or list of floats the length of the number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
- do_convert_rgb (`bool`, *optional*) — Whether to convert the image to RGB.
Constructs a ViT image processor.
preprocess
<source>( images: typing.Union[ForwardRef('PIL.Image.Image'), numpy.ndarray, ForwardRef('torch.Tensor'), list['PIL.Image.Image'], list[numpy.ndarray], list['torch.Tensor']], do_resize: typing.Optional[bool] = None, size: typing.Optional[dict[str, int]] = None, resample: Resampling = None, do_rescale: typing.Optional[bool] = None, rescale_factor: typing.Optional[float] = None, do_normalize: typing.Optional[bool] = None, image_mean: typing.Union[float, list[float], NoneType] = None, image_std: typing.Union[float, list[float], NoneType] = None, return_tensors: typing.Union[str, transformers.utils.generic.TensorType, NoneType] = None, data_format: typing.Union[str, transformers.image_utils.ChannelDimension] = ChannelDimension.FIRST, input_data_format: typing.Union[str, transformers.image_utils.ChannelDimension, NoneType] = None, do_convert_rgb: typing.Optional[bool] = None )
Parameters
- images (`ImageInput`) — Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If passing in images with pixel values between 0 and 1, set `do_rescale=False`.
- do_resize (`bool`, *optional*, defaults to `self.do_resize`) — Whether to resize the image.
- size (`dict[str, int]`, *optional*, defaults to `self.size`) — Dictionary in the format `{"height": h, "width": w}` specifying the size of the output image after resizing.
- resample (`PILImageResampling` filter, *optional*, defaults to `self.resample`) — `PILImageResampling` filter to use if resizing the image, e.g. `PILImageResampling.BILINEAR`. Only has an effect if `do_resize` is set to `True`.
- do_rescale (`bool`, *optional*, defaults to `self.do_rescale`) — Whether to rescale the image values between [0 - 1].
- rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`) — Rescale factor to rescale the image by if `do_rescale` is set to `True`.
- do_normalize (`bool`, *optional*, defaults to `self.do_normalize`) — Whether to normalize the image.
- image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`) — Image mean to use if `do_normalize` is set to `True`.
- image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`) — Image standard deviation to use if `do_normalize` is set to `True`.
- return_tensors (`str` or `TensorType`, *optional*) — The type of tensors to return. Can be one of:
  - Unset: Return a list of `np.ndarray`.
  - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
  - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
  - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
  - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
- data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`) — The channel dimension format for the output image. Can be one of:
  - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
  - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
  - Unset: Use the channel dimension format of the input image.
- input_data_format (`ChannelDimension` or `str`, *optional*) — The channel dimension format for the input image. If unset, the channel dimension format is inferred from the input image. Can be one of:
  - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
  - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
  - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
- do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`) — Whether to convert the image to RGB.
Preprocess an image or batch of images.
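A minimal usage sketch; calling the processor runs preprocess, and the random image is only a stand-in for real data.

```py
# A minimal sketch of ViTImageProcessor usage; calling the processor runs preprocess().
import numpy as np
from transformers import ViTImageProcessor

processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)

batch = processor(images=[image, image], return_tensors="pt")
print(batch["pixel_values"].shape)  # torch.Size([2, 3, 224, 224])
```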
ViTImageProcessorFast
class transformers.ViTImageProcessorFast
<source>(**kwargs: typing_extensions.Unpack[transformers.image_processing_utils_fast.DefaultFastImageProcessorKwargs])
Constructs a fast ViT image processor.
preprocess
<source>( images: typing.Union[ForwardRef('PIL.Image.Image'), numpy.ndarray, ForwardRef('torch.Tensor'), list['PIL.Image.Image'], list[numpy.ndarray], list['torch.Tensor']], *args, **kwargs: typing_extensions.Unpack[transformers.image_processing_utils_fast.DefaultFastImageProcessorKwargs] ) → <class 'transformers.image_processing_base.BatchFeature'>
Parameters
- images (`Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, list['PIL.Image.Image'], list[numpy.ndarray], list['torch.Tensor']]`) — Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If passing in images with pixel values between 0 and 1, set `do_rescale=False`.
- do_resize (`bool`, *optional*) — Whether to resize the image.
- size (`dict[str, int]`, *optional*) — Describes the maximum input dimensions to the model.
- default_to_square (`bool`, *optional*) — Whether to default to a square image when resizing, if size is an int.
- resample (`Union[PILImageResampling, F.InterpolationMode, NoneType]`) — Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only has an effect if `do_resize` is set to `True`.
- do_center_crop (`bool`, *optional*) — Whether to center crop the image.
- crop_size (`dict[str, int]`, *optional*) — Size of the output image after applying `center_crop`.
- do_rescale (`bool`, *optional*) — Whether to rescale the image.
- rescale_factor (`Union[int, float, NoneType]`) — Rescale factor to rescale the image by if `do_rescale` is set to `True`.
- do_normalize (`bool`, *optional*) — Whether to normalize the image.
- image_mean (`Union[float, list[float], NoneType]`) — Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
- image_std (`Union[float, list[float], NoneType]`) — Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to `True`.
- do_convert_rgb (`bool`, *optional*) — Whether to convert the image to RGB.
- return_tensors (`Union[str, ~utils.generic.TensorType, NoneType]`) — Returns stacked tensors if set to `"pt"`, otherwise returns a list of tensors.
- data_format (`~image_utils.ChannelDimension`, *optional*) — Only `ChannelDimension.FIRST` is supported. Added for compatibility with slow processors.
- input_data_format (`Union[str, ~image_utils.ChannelDimension, NoneType]`) — The channel dimension format for the input image. If unset, the channel dimension format is inferred from the input image. Can be one of:
  - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
  - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
  - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
- device (`torch.device`, *optional*) — The device to process the images on. If unset, the device is inferred from the input images.
- disable_grouping (`bool`, *optional*) — Whether to disable grouping of images by size to process them individually and not in batches. If None, will be set to True if the images are on CPU, and False otherwise. This choice is based on empirical observations, as detailed here: https://github.com/huggingface/transformers/pull/38157
Returns
<class 'transformers.image_processing_base.BatchFeature'>
- data (`dict`) — Dictionary of lists/arrays/tensors returned by the __call__ method ('pixel_values', etc.).
- tensor_type (`Union[None, str, TensorType]`, *optional*) — You can give a tensor_type here to convert the lists of integers into PyTorch/TensorFlow/NumPy tensors at initialization.
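A minimal sketch of the fast processor; the random images are stand-ins, and the optional device argument (shown in a comment) moves preprocessing to GPU.

```py
# A minimal sketch: the fast processor returns a BatchFeature with stacked tensors.
import numpy as np
from transformers import ViTImageProcessorFast

processor = ViTImageProcessorFast.from_pretrained("google/vit-base-patch16-224")
images = [np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8) for _ in range(4)]

# add device="cuda" to preprocess on GPU (assumes a CUDA device is available)
batch = processor(images=images, return_tensors="pt")
print(batch["pixel_values"].shape)  # torch.Size([4, 3, 224, 224])
```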
ViTModel
class transformers.ViTModel
<source>( config: ViTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False )
Parameters
- config (ViTConfig) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.
- add_pooling_layer (`bool`, *optional*, defaults to `True`) — Whether to add a pooling layer.
- use_mask_token (`bool`, *optional*, defaults to `False`) — Whether to use a mask token for masked image modeling.
The bare ViT Model outputting raw hidden-states without any specific head on top.
This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).
This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.
forward
<source>( pixel_values: typing.Optional[torch.Tensor] = None, bool_masked_pos: typing.Optional[torch.BoolTensor] = None, head_mask: typing.Optional[torch.Tensor] = None, output_attentions: typing.Optional[bool] = None, output_hidden_states: typing.Optional[bool] = None, interpolate_pos_encoding: typing.Optional[bool] = None, return_dict: typing.Optional[bool] = None ) → transformers.modeling_outputs.BaseModelOutputWithPooling or tuple(torch.FloatTensor)
Parameters
- pixel_values (`torch.Tensor` of shape `(batch_size, num_channels, image_size, image_size)`, *optional*) — The tensors corresponding to the input images. Pixel values can be obtained using ViTImageProcessor. See ViTImageProcessor.__call__() for details.
- bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*) — Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
- head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*) — Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
  - 1 indicates the head is not masked,
  - 0 indicates the head is masked.
- output_attentions (`bool`, *optional*) — Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail.
- output_hidden_states (`bool`, *optional*) — Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail.
- interpolate_pos_encoding (`bool`, *optional*) — Whether to interpolate the pre-trained position encodings.
- return_dict (`bool`, *optional*) — Whether or not to return a ModelOutput instead of a plain tuple.
Returns
transformers.modeling_outputs.BaseModelOutputWithPooling ortuple(torch.FloatTensor)
Atransformers.modeling_outputs.BaseModelOutputWithPooling or a tuple oftorch.FloatTensor
(ifreturn_dict=False
is passed or whenconfig.return_dict=False
) comprising variouselements depending on the configuration (ViTConfig) and inputs.
last_hidden_state (
torch.FloatTensor
of shape(batch_size, sequence_length, hidden_size)
) — Sequence of hidden-states at the output of the last layer of the model.pooler_output (
torch.FloatTensor
of shape(batch_size, hidden_size)
) — Last layer hidden-state of the first token of the sequence (classification token) after further processingthrough the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returnsthe classification token after processing through a linear layer and a tanh activation function. The linearlayer weights are trained from the next sentence prediction (classification) objective during pretraining.hidden_states (
tuple(torch.FloatTensor)
,optional, returned whenoutput_hidden_states=True
is passed or whenconfig.output_hidden_states=True
) — Tuple oftorch.FloatTensor
(one for the output of the embeddings, if the model has an embedding layer, +one for the output of each layer) of shape(batch_size, sequence_length, hidden_size)
.Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (
tuple(torch.FloatTensor)
,optional, returned whenoutput_attentions=True
is passed or whenconfig.output_attentions=True
) — Tuple oftorch.FloatTensor
(one for each layer) of shape(batch_size, num_heads, sequence_length, sequence_length)
.Attentions weights after the attention softmax, used to compute the weighted average in the self-attentionheads.
The ViTModel forward method overrides the `__call__` special method.
Although the recipe for the forward pass needs to be defined within this function, one should call the `Module` instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.
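A minimal feature-extraction sketch with ViTModel; the checkpoint and image follow the other examples on this page.

```py
# A minimal sketch: extract last hidden states with the bare ViTModel.
import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, ViTModel

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")

inputs = image_processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# 196 patch tokens + 1 [CLS] token, hidden size 768
print(outputs.last_hidden_state.shape)  # torch.Size([1, 197, 768])
```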
ViTForMaskedImageModeling
class transformers.ViTForMaskedImageModeling
<source>(config: ViTConfig)
Parameters
- config (ViTConfig) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.
ViT Model with a decoder on top for masked image modeling, as proposed in SimMIM.
Note that we provide a script to pre-train this model on custom data in our examples directory.
This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).
This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.
forward
<source>( pixel_values: typing.Optional[torch.Tensor] = None, bool_masked_pos: typing.Optional[torch.BoolTensor] = None, head_mask: typing.Optional[torch.Tensor] = None, output_attentions: typing.Optional[bool] = None, output_hidden_states: typing.Optional[bool] = None, interpolate_pos_encoding: typing.Optional[bool] = None, return_dict: typing.Optional[bool] = None ) → transformers.modeling_outputs.MaskedImageModelingOutput or tuple(torch.FloatTensor)
Parameters
- pixel_values (`torch.Tensor` of shape `(batch_size, num_channels, image_size, image_size)`, *optional*) — The tensors corresponding to the input images. Pixel values can be obtained using ViTImageProcessor. See ViTImageProcessor.__call__() for details.
- bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`) — Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
- head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*) — Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
  - 1 indicates the head is not masked,
  - 0 indicates the head is masked.
- output_attentions (`bool`, *optional*) — Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail.
- output_hidden_states (`bool`, *optional*) — Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail.
- interpolate_pos_encoding (`bool`, *optional*) — Whether to interpolate the pre-trained position encodings.
- return_dict (`bool`, *optional*) — Whether or not to return a ModelOutput instead of a plain tuple.
Returns
`transformers.modeling_outputs.MaskedImageModelingOutput` or `tuple(torch.FloatTensor)`
A `transformers.modeling_outputs.MaskedImageModelingOutput` or a tuple of `torch.FloatTensor` (if `return_dict=False` is passed or when `config.return_dict=False`) comprising various elements depending on the configuration (ViTConfig) and inputs.
- loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided) — Reconstruction loss.
- reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`) — Reconstructed / completed images.
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`) — Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states (also called feature maps) of the model at the output of each stage.
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`) — Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, sequence_length)`. Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
The ViTForMaskedImageModeling forward method overrides the `__call__` special method.
Although the recipe for the forward pass needs to be defined within this function, one should call the `Module` instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.
Examples:
>>> from transformers import AutoImageProcessor, ViTForMaskedImageModeling
>>> import torch
>>> from PIL import Image
>>> import requests

>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
>>> model = ViTForMaskedImageModeling.from_pretrained("google/vit-base-patch16-224-in21k")

>>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
>>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
>>> # create random boolean mask of shape (batch_size, num_patches)
>>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()

>>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
>>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
>>> list(reconstructed_pixel_values.shape)
[1, 3, 224, 224]
ViTForImageClassification
class transformers.ViTForImageClassification
<source>(config: ViTConfig)
Parameters
- config (ViTConfig) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.
ViT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of the [CLS] token), e.g. for ImageNet.
Note that it's possible to fine-tune ViT on higher resolution images than the ones it has been trained on by setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained position embeddings to the higher resolution.
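A minimal sketch of that workflow; the random 500x500 input image and the 384x384 size override are illustrative.

```py
# A minimal sketch: run a 224-trained checkpoint on higher-resolution inputs by
# interpolating the position embeddings. The 384x384 size is an arbitrary example.
import numpy as np
import torch
from PIL import Image
from transformers import AutoImageProcessor, ViTForImageClassification

processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")

image = Image.fromarray(np.random.randint(0, 256, (500, 500, 3), dtype=np.uint8))
inputs = processor(images=image, size={"height": 384, "width": 384}, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs, interpolate_pos_encoding=True).logits
print(logits.shape)  # torch.Size([1, 1000])
```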
This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).
This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.
forward
<source>( pixel_values: typing.Optional[torch.Tensor] = None, head_mask: typing.Optional[torch.Tensor] = None, labels: typing.Optional[torch.Tensor] = None, output_attentions: typing.Optional[bool] = None, output_hidden_states: typing.Optional[bool] = None, interpolate_pos_encoding: typing.Optional[bool] = None, return_dict: typing.Optional[bool] = None ) → transformers.modeling_outputs.ImageClassifierOutput or tuple(torch.FloatTensor)
Parameters
- pixel_values (`torch.Tensor` of shape `(batch_size, num_channels, image_size, image_size)`, *optional*) — The tensors corresponding to the input images. Pixel values can be obtained using ViTImageProcessor. See ViTImageProcessor.__call__() for details.
- head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*) — Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
  - 1 indicates the head is not masked,
  - 0 indicates the head is masked.
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*) — Labels for computing the image classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), if `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
- output_attentions (`bool`, *optional*) — Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail.
- output_hidden_states (`bool`, *optional*) — Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail.
- interpolate_pos_encoding (`bool`, *optional*) — Whether to interpolate the pre-trained position encodings.
- return_dict (`bool`, *optional*) — Whether or not to return a ModelOutput instead of a plain tuple.
Returns
transformers.modeling_outputs.ImageClassifierOutput or `tuple(torch.FloatTensor)`
A transformers.modeling_outputs.ImageClassifierOutput or a tuple of `torch.FloatTensor` (if `return_dict=False` is passed or when `config.return_dict=False`) comprising various elements depending on the configuration (ViTConfig) and inputs.
- loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided) — Classification (or regression if config.num_labels==1) loss.
- logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`) — Classification (or regression if config.num_labels==1) scores (before SoftMax).
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`) — Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states (also called feature maps) of the model at the output of each stage.
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`) — Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, sequence_length)`. Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
The ViTForImageClassification forward method overrides the `__call__` special method.
Although the recipe for the forward pass needs to be defined within this function, one should call the `Module` instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.
Example:
>>> from transformers import AutoImageProcessor, ViTForImageClassification
>>> import torch
>>> from datasets import load_dataset

>>> dataset = load_dataset("huggingface/cats-image")
>>> image = dataset["test"]["image"][0]

>>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
>>> model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")

>>> inputs = image_processor(image, return_tensors="pt")

>>> with torch.no_grad():
...     logits = model(**inputs).logits

>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_label = logits.argmax(-1).item()
>>> print(model.config.id2label[predicted_label])
...
TFViTModel
class transformers.TFViTModel
<source>( config: ViTConfig, *inputs, add_pooling_layer = True, **kwargs )
Parameters
- config (ViTConfig) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.
The bare ViT Model transformer outputting raw hidden-states without any specific head on top.
This model inherits from TFPreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).
This model is also a keras.Model subclass. Use it as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to general usage and behavior.
TensorFlow models and layers in `transformers` accept two formats as input:
- having all inputs as keyword arguments (like PyTorch models), or
- having all inputs as a list, tuple or dict in the first positional argument.
The reason the second format is supported is that Keras methods prefer this format when passing inputs to models and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first positional argument:
- a single Tensor with `pixel_values` only and nothing else: `model(pixel_values)`
- a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: `model([pixel_values, attention_mask])` or `model([pixel_values, attention_mask, token_type_ids])`
- a dictionary with one or several input Tensors associated to the input names given in the docstring: `model({"pixel_values": pixel_values, "token_type_ids": token_type_ids})`
Note that when creating models and layers with subclassing, you don't need to worry about any of this, as you can just pass inputs like you would to any other Python function!
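A brief sketch of the keyword-argument and dictionary formats described above; the random pixel values stand in for preprocessed images.

```py
# A brief sketch of the input formats: keyword arguments vs. a dict in the first
# positional argument. Random pixel values stand in for preprocessed images.
import tensorflow as tf
from transformers import TFViTModel

model = TFViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
pixel_values = tf.random.uniform((1, 3, 224, 224))

out_kwargs = model(pixel_values=pixel_values)      # keyword arguments
out_dict = model({"pixel_values": pixel_values})   # dict in the first positional argument
print(out_kwargs.last_hidden_state.shape, out_dict.last_hidden_state.shape)
```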
call
<source>( pixel_values: TFModelInputType | None = None, head_mask: np.ndarray | tf.Tensor | None = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, interpolate_pos_encoding: Optional[bool] = None, return_dict: Optional[bool] = None, training: bool = False ) → transformers.modeling_tf_outputs.TFBaseModelOutputWithPooling or tuple(tf.Tensor)
Parameters
- pixel_values (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]`, `dict[str, tf.Tensor]` or `dict[str, np.ndarray]`, and each example must have the shape `(batch_size, num_channels, height, width)`) — Pixel values. Pixel values can be obtained using AutoImageProcessor. See ViTImageProcessor.__call__() for details.
- head_mask (`np.ndarray` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*) — Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
  - 1 indicates the head is not masked,
  - 0 indicates the head is masked.
- output_attentions (`bool`, *optional*) — Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. This argument can be used only in eager mode; in graph mode the value in the config will be used instead.
- output_hidden_states (`bool`, *optional*) — Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. This argument can be used only in eager mode; in graph mode the value in the config will be used instead.
- interpolate_pos_encoding (`bool`, *optional*) — Whether to interpolate the pre-trained position encodings.
- return_dict (`bool`, *optional*) — Whether or not to return a ModelOutput instead of a plain tuple. This argument can be used in eager mode; in graph mode the value will always be set to `True`.
- training (`bool`, *optional*, defaults to `False`) — Whether or not to use the model in training mode (some modules like dropout modules have different behaviors between training and evaluation).
Returns
transformers.modeling_tf_outputs.TFBaseModelOutputWithPooling ortuple(tf.Tensor)
Atransformers.modeling_tf_outputs.TFBaseModelOutputWithPooling or a tuple oftf.Tensor
(ifreturn_dict=False
is passed or whenconfig.return_dict=False
) comprising various elements depending on theconfiguration (ViTConfig) and inputs.
last_hidden_state (
tf.Tensor
of shape(batch_size, sequence_length, hidden_size)
) — Sequence of hidden-states at the output of the last layer of the model.pooler_output (
tf.Tensor
of shape(batch_size, hidden_size)
) — Last layer hidden-state of the first token of the sequence (classification token) further processed by aLinear layer and a Tanh activation function. The Linear layer weights are trained from the next sentenceprediction (classification) objective during pretraining.This output is usuallynot a good summary of the semantic content of the input, you’re often better withaveraging or pooling the sequence of hidden-states for the whole input sequence.
hidden_states (
tuple(tf.Tensor)
,optional, returned whenoutput_hidden_states=True
is passed or whenconfig.output_hidden_states=True
) — Tuple oftf.Tensor
(one for the output of the embeddings + one for the output of each layer) of shape(batch_size, sequence_length, hidden_size)
.Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (
tuple(tf.Tensor)
,optional, returned whenoutput_attentions=True
is passed or whenconfig.output_attentions=True
) — Tuple oftf.Tensor
(one for each layer) of shape(batch_size, num_heads, sequence_length, sequence_length)
.Attentions weights after the attention softmax, used to compute the weighted average in the self-attentionheads.
The TFViTModel forward method overrides the `__call__` special method.
Although the recipe for the forward pass needs to be defined within this function, one should call the `Module` instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.
Example:
>>> from transformers import AutoImageProcessor, TFViTModel
>>> from datasets import load_dataset

>>> dataset = load_dataset("huggingface/cats-image")
>>> image = dataset["test"]["image"][0]

>>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
>>> model = TFViTModel.from_pretrained("google/vit-base-patch16-224-in21k")

>>> inputs = image_processor(image, return_tensors="tf")
>>> outputs = model(**inputs)

>>> last_hidden_states = outputs.last_hidden_state
>>> list(last_hidden_states.shape)
[1, 197, 768]
TFViTForImageClassification
class transformers.TFViTForImageClassification
<source>( config: ViTConfig, *inputs, **kwargs )
Parameters
- config (ViTConfig) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.
ViT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of the [CLS] token), e.g. for ImageNet.
Note that it's possible to fine-tune ViT on higher resolution images than the ones it has been trained on by setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained position embeddings to the higher resolution.
This model inherits from TFPreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).
This model is also a keras.Model subclass. Use it as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to general usage and behavior.
TensorFlow models and layers in `transformers` accept two formats as input:
- having all inputs as keyword arguments (like PyTorch models), or
- having all inputs as a list, tuple or dict in the first positional argument.
The reason the second format is supported is that Keras methods prefer this format when passing inputs to models and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first positional argument:
- a single Tensor with `pixel_values` only and nothing else: `model(pixel_values)`
- a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: `model([pixel_values, attention_mask])` or `model([pixel_values, attention_mask, token_type_ids])`
- a dictionary with one or several input Tensors associated to the input names given in the docstring: `model({"pixel_values": pixel_values, "token_type_ids": token_type_ids})`
Note that when creating models and layers with subclassing, you don't need to worry about any of this, as you can just pass inputs like you would to any other Python function!
call
<source>( pixel_values: TFModelInputType | None = None, head_mask: np.ndarray | tf.Tensor | None = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, interpolate_pos_encoding: Optional[bool] = None, return_dict: Optional[bool] = None, labels: np.ndarray | tf.Tensor | None = None, training: Optional[bool] = False ) → transformers.modeling_tf_outputs.TFSequenceClassifierOutput or tuple(tf.Tensor)
Parameters
- pixel_values (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]`, `dict[str, tf.Tensor]` or `dict[str, np.ndarray]`, and each example must have the shape `(batch_size, num_channels, height, width)`) — Pixel values. Pixel values can be obtained using AutoImageProcessor. See ViTImageProcessor.__call__() for details.
- head_mask (`np.ndarray` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*) — Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
  - 1 indicates the head is not masked,
  - 0 indicates the head is masked.
- output_attentions (`bool`, *optional*) — Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. This argument can be used only in eager mode; in graph mode the value in the config will be used instead.
- output_hidden_states (`bool`, *optional*) — Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. This argument can be used only in eager mode; in graph mode the value in the config will be used instead.
- interpolate_pos_encoding (`bool`, *optional*) — Whether to interpolate the pre-trained position encodings.
- return_dict (`bool`, *optional*) — Whether or not to return a ModelOutput instead of a plain tuple. This argument can be used in eager mode; in graph mode the value will always be set to `True`.
- training (`bool`, *optional*, defaults to `False`) — Whether or not to use the model in training mode (some modules like dropout modules have different behaviors between training and evaluation).
- labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*) — Labels for computing the image classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), if `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
Returns
transformers.modeling_tf_outputs.TFSequenceClassifierOutput ortuple(tf.Tensor)
Atransformers.modeling_tf_outputs.TFSequenceClassifierOutput or a tuple oftf.Tensor
(ifreturn_dict=False
is passed or whenconfig.return_dict=False
) comprising various elements depending on theconfiguration (ViTConfig) and inputs.
loss (
tf.Tensor
of shape(batch_size, )
,optional, returned whenlabels
is provided) — Classification (or regression if config.num_labels==1) loss.logits (
tf.Tensor
of shape(batch_size, config.num_labels)
) — Classification (or regression if config.num_labels==1) scores (before SoftMax).hidden_states (
tuple(tf.Tensor)
,optional, returned whenoutput_hidden_states=True
is passed or whenconfig.output_hidden_states=True
) — Tuple oftf.Tensor
(one for the output of the embeddings + one for the output of each layer) of shape(batch_size, sequence_length, hidden_size)
.Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (
tuple(tf.Tensor)
,optional, returned whenoutput_attentions=True
is passed or whenconfig.output_attentions=True
) — Tuple oftf.Tensor
(one for each layer) of shape(batch_size, num_heads, sequence_length, sequence_length)
.Attentions weights after the attention softmax, used to compute the weighted average in the self-attentionheads.
The TFViTForImageClassification forward method overrides the `__call__` special method.
Although the recipe for the forward pass needs to be defined within this function, one should call the `Module` instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.
Example:
>>> from transformers import AutoImageProcessor, TFViTForImageClassification
>>> import tensorflow as tf
>>> from datasets import load_dataset

>>> dataset = load_dataset("huggingface/cats-image")
>>> image = dataset["test"]["image"][0]

>>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
>>> model = TFViTForImageClassification.from_pretrained("google/vit-base-patch16-224")

>>> inputs = image_processor(image, return_tensors="tf")
>>> logits = model(**inputs).logits

>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_label = int(tf.math.argmax(logits, axis=-1))
>>> print(model.config.id2label[predicted_label])
Egyptian cat
FlaxViTModel
class transformers.FlaxViTModel
<source>( config: ViTConfig, input_shape = None, seed: int = 0, dtype: dtype = jax.numpy.float32, _do_init: bool = True, **kwargs )
Parameters
- config (ViTConfig) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.
- dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`) — The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and `jax.numpy.bfloat16` (on TPUs). This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If specified, all the computation will be performed with the given `dtype`. Note that this only specifies the dtype of the computation and does not influence the dtype of model parameters. If you wish to change the dtype of the model parameters, see to_fp16() and to_bf16().
The bare ViT Model transformer outputting raw hidden-states without any specific head on top.
This model inherits from FlaxPreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading, saving and converting weights from PyTorch models).
This model is also a flax.linen.Module subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matters related to general usage and behavior.
Finally, this model supports inherent JAX features such as Just-In-Time (JIT) compilation, automatic differentiation, vectorization and parallelization.
__call__
<source>( pixel_values, params: typing.Optional[dict] = None, dropout_rng: jax.random.PRNGKey = None, train: bool = False, output_attentions: typing.Optional[bool] = None, output_hidden_states: typing.Optional[bool] = None, return_dict: typing.Optional[bool] = None ) → transformers.modeling_flax_outputs.FlaxBaseModelOutputWithPooling or tuple(torch.FloatTensor)
Returns
transformers.modeling_flax_outputs.FlaxBaseModelOutputWithPooling or `tuple(torch.FloatTensor)`
A transformers.modeling_flax_outputs.FlaxBaseModelOutputWithPooling or a tuple of `torch.FloatTensor` (if `return_dict=False` is passed or when `config.return_dict=False`) comprising various elements depending on the configuration (ViTConfig) and inputs.
- last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`) — Sequence of hidden-states at the output of the last layer of the model.
- pooler_output (`jnp.ndarray` of shape `(batch_size, hidden_size)`) — Last layer hidden-state of the first token of the sequence (classification token) further processed by a Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence prediction (classification) objective during pretraining.
- hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`) — Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
- attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`) — Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
The `FlaxViTPreTrainedModel` forward method overrides the `__call__` special method.
Although the recipe for the forward pass needs to be defined within this function, one should call the `Module` instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.
Examples:
>>> from transformers import AutoImageProcessor, FlaxViTModel
>>> from PIL import Image
>>> import requests

>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
>>> model = FlaxViTModel.from_pretrained("google/vit-base-patch16-224-in21k")

>>> inputs = image_processor(images=image, return_tensors="np")
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
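As the dtype parameter above notes, computation can run in half precision. A minimal sketch follows; the bfloat16 choice is illustrative and best suited to TPUs.

```py
# A minimal sketch: load the Flax model with a half-precision computation dtype.
# This affects the computation only, not the stored parameter dtype.
import jax.numpy as jnp
from transformers import FlaxViTModel

model = FlaxViTModel.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    dtype=jnp.bfloat16,  # bfloat16 is an illustrative choice; float16 targets GPUs
)
```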
FlaxViTForImageClassification
class transformers.FlaxViTForImageClassification
<source>( config: ViTConfig, input_shape = None, seed: int = 0, dtype: dtype = jax.numpy.float32, _do_init: bool = True, **kwargs )
Parameters
- config (ViTConfig) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.
- dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`) — The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and `jax.numpy.bfloat16` (on TPUs). This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If specified, all the computation will be performed with the given `dtype`. Note that this only specifies the dtype of the computation and does not influence the dtype of model parameters. If you wish to change the dtype of the model parameters, see to_fp16() and to_bf16().
ViT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of the [CLS] token), e.g. for ImageNet.
This model inherits from FlaxPreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading, saving and converting weights from PyTorch models).
This model is also a flax.linen.Module subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matters related to general usage and behavior.
Finally, this model supports inherent JAX features such as Just-In-Time (JIT) compilation, automatic differentiation, vectorization and parallelization.
__call__
<source>( pixel_values, params: typing.Optional[dict] = None, dropout_rng: jax.random.PRNGKey = None, train: bool = False, output_attentions: typing.Optional[bool] = None, output_hidden_states: typing.Optional[bool] = None, return_dict: typing.Optional[bool] = None ) → transformers.modeling_flax_outputs.FlaxSequenceClassifierOutput or tuple(torch.FloatTensor)
Returns
transformers.modeling_flax_outputs.FlaxSequenceClassifierOutput or `tuple(torch.FloatTensor)`
A transformers.modeling_flax_outputs.FlaxSequenceClassifierOutput or a tuple of `torch.FloatTensor` (if `return_dict=False` is passed or when `config.return_dict=False`) comprising various elements depending on the configuration (ViTConfig) and inputs.
- logits (`jnp.ndarray` of shape `(batch_size, config.num_labels)`) — Classification (or regression if config.num_labels==1) scores (before SoftMax).
- hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`) — Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
- attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`) — Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
The `FlaxViTPreTrainedModel` forward method overrides the `__call__` special method.
Although the recipe for the forward pass needs to be defined within this function, one should call the `Module` instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.
Example:
>>> from transformers import AutoImageProcessor, FlaxViTForImageClassification
>>> from PIL import Image
>>> import jax
>>> import requests

>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
>>> model = FlaxViTForImageClassification.from_pretrained("google/vit-base-patch16-224")

>>> inputs = image_processor(images=image, return_tensors="np")
>>> outputs = model(**inputs)
>>> logits = outputs.logits

>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_class_idx = jax.numpy.argmax(logits, axis=-1)
>>> print("Predicted class:", model.config.id2label[predicted_class_idx.item()])