A flexible package for multimodal-deep-learning to combine tabular data with text and images using Wide and Deep models in Pytorch
License: Apache-2.0, MIT
jrzaurin/pytorch-widedeep
Documentation: https://pytorch-widedeep.readthedocs.io
Companion posts and tutorials: infinitoml
Experiments and comparison with LightGBM: TabularDL vs LightGBM
Slack: if you want to contribute or just want to chat with us, join slack
`pytorch-widedeep` is based on Google's Wide and Deep Algorithm, adjusted for multi-modal datasets.
In general terms, `pytorch-widedeep` is a package to use deep learning with tabular data. In particular, it is intended to facilitate the combination of text and images with corresponding tabular data using wide and deep models. With that in mind, there are a number of architectures that can be implemented with the library. The main components of those architectures are shown in the figure below:
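In math terms, and following the notation in the paper, the expression for the architecture without a `deephead` component can be formulated as:

$$
pred = \sigma\left(W^{T}_{wide}[x, \phi(x)] + W^{T}_{deeptabular}\,a_{deeptabular} + W^{T}_{deeptext}\,a_{deeptext} + W^{T}_{deepimage}\,a_{deepimage} + b\right)
$$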
Where σ is the sigmoid function, 'W' are the weight matrices applied to the wide model and to the final activations of the deep models, 'a' are these final activations, φ(x) are the cross product transformations of the original features 'x', and 'b' is the bias term. In case you are wondering what "cross product transformations" are, here is a quote taken directly from the paper: "For binary features, a cross-product transformation (e.g., “AND(gender=female, language=en)”) is 1 if and only if the constituent features (“gender=female” and “language=en”) are all 1, and 0 otherwise".
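To make the formula and the quote concrete, here is a toy, pure-Python sketch (not the library's implementation) of a binary cross-product feature and of the wide and deep prediction for a single example. The helper names `cross_product` and `wide_and_deep_pred`, the weights and the activations are all hypothetical; in `pytorch-widedeep` the crossed columns are built by `WidePreprocessor` and the weights are learned during training.

```python
import math


def cross_product(row: dict, conditions: dict) -> int:
    """Binary cross-product transformation: 1 if every (feature == value)
    condition holds for this row, 0 otherwise (i.e. a logical AND)."""
    return int(all(row.get(feat) == val for feat, val in conditions.items()))


def wide_and_deep_pred(wide_inputs, w_wide, deep_activations, w_deep, b):
    """sigma(w_wide . [x, phi(x)] + w_deep . a + b) for a single example.

    wide_inputs: the wide features [x, phi(x)] as a flat list
    deep_activations: final activations 'a' of the deep component(s), concatenated
    """
    logit = sum(w * x for w, x in zip(w_wide, wide_inputs))
    logit += sum(w * a for w, a in zip(w_deep, deep_activations))
    logit += b
    return 1.0 / (1.0 + math.exp(-logit))


row = {"gender": "female", "language": "en"}
and_feature = cross_product(row, {"gender": "female", "language": "en"})  # 1
other_feature = cross_product(row, {"gender": "female", "language": "fr"})  # 0

# x: a one-hot 'gender=female' indicator; phi(x): the crossed feature above
wide_inputs = [1.0, float(and_feature)]
pred = wide_and_deep_pred(
    wide_inputs, w_wide=[0.5, 1.0], deep_activations=[0.2, -0.4], w_deep=[1.0, 1.0], b=-1.0
)
```

With several deep components, their final activations are simply concatenated into `deep_activations`, which matches the sum of the per-component terms in the formula.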
It is perfectly possible to use custom models (and not necessarily those in the library) as long as the custom models have a property called `output_dim` with the size of the last layer of activations, so that `WideDeep` can be constructed. Examples on how to use custom components can be found in the Examples folder and the section below.
The `pytorch-widedeep` library offers a number of different architectures. In this section we will show some of them in their simplest form (i.e. with default param values in most cases) with their corresponding code snippets. Note that all the snippets below should run locally. For a more detailed explanation of the different components and their parameters, please refer to the documentation.
For the examples below we will be using a toy dataset generated as follows:
```python
import os
import random

import numpy as np
import pandas as pd
from PIL import Image
from faker import Faker


def create_and_save_random_image(image_number, size=(32, 32)):
    if not os.path.exists("images"):
        os.makedirs("images")

    array = np.random.randint(0, 256, (size[0], size[1], 3), dtype=np.uint8)

    image = Image.fromarray(array)

    image_name = f"image_{image_number}.png"
    image.save(os.path.join("images", image_name))

    return image_name


fake = Faker()

cities = ["New York", "Los Angeles", "Chicago", "Houston"]
names = ["Alice", "Bob", "Charlie", "David", "Eva"]

data = {
    "city": [random.choice(cities) for _ in range(100)],
    "name": [random.choice(names) for _ in range(100)],
    "age": [random.uniform(18, 70) for _ in range(100)],
    "height": [random.uniform(150, 200) for _ in range(100)],
    "sentence": [fake.sentence() for _ in range(100)],
    "other_sentence": [fake.sentence() for _ in range(100)],
    "image_name": [create_and_save_random_image(i) for i in range(100)],
    "target": [random.choice([0, 1]) for _ in range(100)],
}

df = pd.DataFrame(data)
```
This will create a 100-row dataframe and a directory in your local folder called `images` with 100 random images (i.e. images with just noise).
Perhaps the simplest architecture would be just one component, `wide`, `deeptabular`, `deeptext` or `deepimage`, on its own, which is also possible, but let's start the examples with a standard Wide and Deep architecture. From there, building a model composed of only one component will be straightforward.
Note that the examples shown below would be almost identical using any of the models available in the library. For example, `TabMlp` can be replaced by `TabResnet`, `TabNet`, `TabTransformer`, etc. Similarly, `BasicRNN` can be replaced by `AttentiveRNN`, `StackedAttentiveRNN`, or `HFModel` with their corresponding parameters and preprocessor in the case of the Hugging Face models.
1. Wide and Tabular component (aka deeptabular)
```python
import numpy as np

from pytorch_widedeep.preprocessing import TabPreprocessor, WidePreprocessor
from pytorch_widedeep.models import Wide, TabMlp, WideDeep
from pytorch_widedeep.training import Trainer

# Wide
wide_cols = ["city"]
crossed_cols = [("city", "name")]

wide_preprocessor = WidePreprocessor(wide_cols=wide_cols, crossed_cols=crossed_cols)
X_wide = wide_preprocessor.fit_transform(df)

wide = Wide(input_dim=np.unique(X_wide).shape[0])

# Tabular
tab_preprocessor = TabPreprocessor(
    cat_embed_cols=["city", "name"], continuous_cols=["age", "height"]
)
X_tab = tab_preprocessor.fit_transform(df)

tab_mlp = TabMlp(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    continuous_cols=tab_preprocessor.continuous_cols,
    mlp_hidden_dims=[64, 32],
)

# WideDeep
model = WideDeep(wide=wide, deeptabular=tab_mlp)

# Train
trainer = Trainer(model, objective="binary")

trainer.fit(
    X_wide=X_wide,
    X_tab=X_tab,
    target=df["target"].values,
    n_epochs=1,
    batch_size=32,
)
```
2. Tabular and Text data
```python
from pytorch_widedeep.preprocessing import TabPreprocessor, TextPreprocessor
from pytorch_widedeep.models import TabMlp, BasicRNN, WideDeep
from pytorch_widedeep.training import Trainer

# Tabular
tab_preprocessor = TabPreprocessor(
    cat_embed_cols=["city", "name"], continuous_cols=["age", "height"]
)
X_tab = tab_preprocessor.fit_transform(df)

tab_mlp = TabMlp(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    continuous_cols=tab_preprocessor.continuous_cols,
    mlp_hidden_dims=[64, 32],
)

# Text
text_preprocessor = TextPreprocessor(
    text_col="sentence", maxlen=20, max_vocab=100, n_cpus=1
)
X_text = text_preprocessor.fit_transform(df)

rnn = BasicRNN(
    vocab_size=len(text_preprocessor.vocab.itos),
    embed_dim=16,
    hidden_dim=8,
    n_layers=1,
)

# WideDeep
model = WideDeep(deeptabular=tab_mlp, deeptext=rnn)

# Train
trainer = Trainer(model, objective="binary")

trainer.fit(
    X_tab=X_tab,
    X_text=X_text,
    target=df["target"].values,
    n_epochs=1,
    batch_size=32,
)
```
3. Tabular and text with a fully connected (FC) head on top via the `head_hidden_dims` param in `WideDeep`
```python
from pytorch_widedeep.preprocessing import TabPreprocessor, TextPreprocessor
from pytorch_widedeep.models import TabMlp, BasicRNN, WideDeep
from pytorch_widedeep.training import Trainer

# Tabular
tab_preprocessor = TabPreprocessor(
    cat_embed_cols=["city", "name"], continuous_cols=["age", "height"]
)
X_tab = tab_preprocessor.fit_transform(df)

tab_mlp = TabMlp(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    continuous_cols=tab_preprocessor.continuous_cols,
    mlp_hidden_dims=[64, 32],
)

# Text
text_preprocessor = TextPreprocessor(
    text_col="sentence", maxlen=20, max_vocab=100, n_cpus=1
)
X_text = text_preprocessor.fit_transform(df)

rnn = BasicRNN(
    vocab_size=len(text_preprocessor.vocab.itos),
    embed_dim=16,
    hidden_dim=8,
    n_layers=1,
)

# WideDeep
model = WideDeep(deeptabular=tab_mlp, deeptext=rnn, head_hidden_dims=[32, 16])

# Train
trainer = Trainer(model, objective="binary")

trainer.fit(
    X_tab=X_tab,
    X_text=X_text,
    target=df["target"].values,
    n_epochs=1,
    batch_size=32,
)
```
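As a rough sketch of what `head_hidden_dims` adds, the hypothetical helper below lists the (input, output) sizes of the dense layers in such a head. It assumes that the outputs of the deep components are concatenated before the head, that `TabMlp(mlp_hidden_dims=[64, 32])` has `output_dim == 32` and the unidirectional `BasicRNN(hidden_dim=8)` has `output_dim == 8`, and that a final layer maps to a single output for the binary objective. Treat it as an illustration of the shapes, not as the library's code.

```python
from typing import List, Tuple


def head_layer_shapes(
    component_output_dims: List[int], head_hidden_dims: List[int], pred_dim: int = 1
) -> List[Tuple[int, int]]:
    """(in_features, out_features) of each dense layer in an MLP head placed on
    top of the concatenated component outputs, plus a final prediction layer."""
    dims = [sum(component_output_dims)] + list(head_hidden_dims) + [pred_dim]
    return list(zip(dims[:-1], dims[1:]))


# assumed output_dims: TabMlp(mlp_hidden_dims=[64, 32]) -> 32, BasicRNN(hidden_dim=8) -> 8
shapes = head_layer_shapes([32, 8], head_hidden_dims=[32, 16])
```

Under these assumptions the head receives a 40-dimensional vector, which is why `head_hidden_dims` should be chosen with the combined width of the components in mind.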
4. Tabular and multiple text columns that are passed directly to `WideDeep`
```python
from pytorch_widedeep.preprocessing import TabPreprocessor, TextPreprocessor
from pytorch_widedeep.models import TabMlp, BasicRNN, WideDeep
from pytorch_widedeep.training import Trainer

# Tabular
tab_preprocessor = TabPreprocessor(
    cat_embed_cols=["city", "name"], continuous_cols=["age", "height"]
)
X_tab = tab_preprocessor.fit_transform(df)

tab_mlp = TabMlp(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    continuous_cols=tab_preprocessor.continuous_cols,
    mlp_hidden_dims=[64, 32],
)

# Text
text_preprocessor_1 = TextPreprocessor(
    text_col="sentence", maxlen=20, max_vocab=100, n_cpus=1
)
X_text_1 = text_preprocessor_1.fit_transform(df)

text_preprocessor_2 = TextPreprocessor(
    text_col="other_sentence", maxlen=20, max_vocab=100, n_cpus=1
)
X_text_2 = text_preprocessor_2.fit_transform(df)

rnn_1 = BasicRNN(
    vocab_size=len(text_preprocessor_1.vocab.itos),
    embed_dim=16,
    hidden_dim=8,
    n_layers=1,
)

rnn_2 = BasicRNN(
    vocab_size=len(text_preprocessor_2.vocab.itos),
    embed_dim=16,
    hidden_dim=8,
    n_layers=1,
)

# WideDeep
model = WideDeep(deeptabular=tab_mlp, deeptext=[rnn_1, rnn_2])

# Train
trainer = Trainer(model, objective="binary")

trainer.fit(
    X_tab=X_tab,
    X_text=[X_text_1, X_text_2],
    target=df["target"].values,
    n_epochs=1,
    batch_size=32,
)
```
5. Tabular data and multiple text columns that are fused via the library's `ModelFuser` class
```python
from pytorch_widedeep.preprocessing import TabPreprocessor, TextPreprocessor
from pytorch_widedeep.models import TabMlp, BasicRNN, WideDeep, ModelFuser
from pytorch_widedeep import Trainer

# Tabular
tab_preprocessor = TabPreprocessor(
    cat_embed_cols=["city", "name"], continuous_cols=["age", "height"]
)
X_tab = tab_preprocessor.fit_transform(df)

tab_mlp = TabMlp(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    continuous_cols=tab_preprocessor.continuous_cols,
    mlp_hidden_dims=[64, 32],
)

# Text
text_preprocessor_1 = TextPreprocessor(
    text_col="sentence", maxlen=20, max_vocab=100, n_cpus=1
)
X_text_1 = text_preprocessor_1.fit_transform(df)

text_preprocessor_2 = TextPreprocessor(
    text_col="other_sentence", maxlen=20, max_vocab=100, n_cpus=1
)
X_text_2 = text_preprocessor_2.fit_transform(df)

rnn_1 = BasicRNN(
    vocab_size=len(text_preprocessor_1.vocab.itos),
    embed_dim=16,
    hidden_dim=8,
    n_layers=1,
)

rnn_2 = BasicRNN(
    vocab_size=len(text_preprocessor_2.vocab.itos),
    embed_dim=16,
    hidden_dim=8,
    n_layers=1,
)

models_fuser = ModelFuser(models=[rnn_1, rnn_2], fusion_method="mult")

# WideDeep
model = WideDeep(deeptabular=tab_mlp, deeptext=models_fuser)

# Train
trainer = Trainer(model, objective="binary")

trainer.fit(
    X_tab=X_tab,
    X_text=[X_text_1, X_text_2],
    target=df["target"].values,
    n_epochs=1,
    batch_size=32,
)
```
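The `fusion_method` argument decides how the outputs of the fused models are combined. The pure-Python sketch below shows the intuition behind a few of these combinations for a single row; the function `fuse` is hypothetical and the set of methods is only illustrative, so see the `ModelFuser` documentation for the exact options it accepts. Element-wise methods such as `"mult"` only make sense when all fused models produce vectors of the same size, which is why both `BasicRNN`s above are built with the same `hidden_dim=8`.

```python
import math
from typing import List


def fuse(outputs: List[List[float]], method: str) -> List[float]:
    """Combine the output vectors of several models for one row (hypothetical helper)."""
    if method == "concatenate":
        # stack the vectors end to end: the output size is the sum of the sizes
        return [value for out in outputs for value in out]
    # element-wise methods require every model to have the same output size
    if len({len(out) for out in outputs}) != 1:
        raise ValueError(f"'{method}' requires models with equal output_dim")
    columns = list(zip(*outputs))  # i-th entry holds the i-th value of every model
    if method == "sum":
        return [sum(col) for col in columns]
    if method == "mean":
        return [sum(col) / len(col) for col in columns]
    if method == "mult":
        return [math.prod(col) for col in columns]
    raise ValueError(f"unknown fusion method: {method}")


text_1_out = [1.0, 2.0]  # e.g. output of rnn_1 for one row
text_2_out = [3.0, 4.0]  # e.g. output of rnn_2 for the same row
fused = fuse([text_1_out, text_2_out], "mult")
```

Note that `"concatenate"` changes the width of the fused output, while the element-wise methods keep it equal to the shared `output_dim`.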
6. Tabular and multiple text columns, with an image column. The text columns are fused via the library's `ModelFuser` and then everything is fused via the `deephead` parameter in `WideDeep`, which is a custom `ModelFuser` coded by the user
This is perhaps the least elegant solution, as it involves a custom component written by the user and slicing the 'incoming' tensor. In the future, we will include a `TextAndImageModelFuser` to make this process more straightforward. Still, it is not really complicated, and it is a good example of how to use custom components in `pytorch-widedeep`.
Note that the only requirement for the custom component is that it has a property called `output_dim` that returns the size of the last layer of activations. In other words, it does not need to inherit from `BaseWDModelComponent`. This base class simply checks the existence of such a property and avoids some typing errors internally.
```python
import torch

from pytorch_widedeep.preprocessing import (
    TabPreprocessor,
    TextPreprocessor,
    ImagePreprocessor,
)
from pytorch_widedeep.models import TabMlp, BasicRNN, WideDeep, ModelFuser, Vision
from pytorch_widedeep.models._base_wd_model_component import BaseWDModelComponent
from pytorch_widedeep import Trainer

# Tabular
tab_preprocessor = TabPreprocessor(
    cat_embed_cols=["city", "name"], continuous_cols=["age", "height"]
)
X_tab = tab_preprocessor.fit_transform(df)

tab_mlp = TabMlp(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    continuous_cols=tab_preprocessor.continuous_cols,
    mlp_hidden_dims=[16, 8],
)

# Text
text_preprocessor_1 = TextPreprocessor(
    text_col="sentence", maxlen=20, max_vocab=100, n_cpus=1
)
X_text_1 = text_preprocessor_1.fit_transform(df)

text_preprocessor_2 = TextPreprocessor(
    text_col="other_sentence", maxlen=20, max_vocab=100, n_cpus=1
)
X_text_2 = text_preprocessor_2.fit_transform(df)

rnn_1 = BasicRNN(
    vocab_size=len(text_preprocessor_1.vocab.itos),
    embed_dim=16,
    hidden_dim=8,
    n_layers=1,
)

rnn_2 = BasicRNN(
    vocab_size=len(text_preprocessor_2.vocab.itos),
    embed_dim=16,
    hidden_dim=8,
    n_layers=1,
)

models_fuser = ModelFuser(
    models=[rnn_1, rnn_2],
    fusion_method="mult",
)

# Image
image_preprocessor = ImagePreprocessor(img_col="image_name", img_path="images")
X_img = image_preprocessor.fit_transform(df)

vision = Vision(pretrained_model_setup="resnet18", head_hidden_dims=[16, 8])


# deephead (custom model fuser)
class MyModelFuser(BaseWDModelComponent):
    """
    Simply a Linear + Relu sequence on top of the text + images followed by a
    Linear -> Relu -> Linear for the concatenation of tabular slice of the
    tensor and the output of the text and image sequential model
    """

    def __init__(
        self,
        tab_incoming_dim: int,
        text_incoming_dim: int,
        image_incoming_dim: int,
        output_units: int,
    ):
        super(MyModelFuser, self).__init__()

        self.tab_incoming_dim = tab_incoming_dim
        self.text_incoming_dim = text_incoming_dim
        self.image_incoming_dim = image_incoming_dim
        self.output_units = output_units

        self.text_and_image_fuser = torch.nn.Sequential(
            torch.nn.Linear(text_incoming_dim + image_incoming_dim, output_units),
            torch.nn.ReLU(),
        )
        self.out = torch.nn.Sequential(
            torch.nn.Linear(output_units + tab_incoming_dim, output_units * 4),
            torch.nn.ReLU(),
            torch.nn.Linear(output_units * 4, output_units),
        )

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        # the incoming tensor is the concatenation [tabular | text | image]
        tab_slice = slice(0, self.tab_incoming_dim)
        text_slice = slice(
            self.tab_incoming_dim, self.tab_incoming_dim + self.text_incoming_dim
        )
        image_slice = slice(
            self.tab_incoming_dim + self.text_incoming_dim,
            self.tab_incoming_dim + self.text_incoming_dim + self.image_incoming_dim,
        )
        X_tab = X[:, tab_slice]
        X_text = X[:, text_slice]
        X_img = X[:, image_slice]
        X_text_and_image = self.text_and_image_fuser(torch.cat([X_text, X_img], dim=1))
        return self.out(torch.cat([X_tab, X_text_and_image], dim=1))

    @property
    def output_dim(self):
        return self.output_units


deephead = MyModelFuser(
    tab_incoming_dim=tab_mlp.output_dim,
    text_incoming_dim=models_fuser.output_dim,
    image_incoming_dim=vision.output_dim,
    output_units=8,
)

# WideDeep
model = WideDeep(
    deeptabular=tab_mlp,
    deeptext=models_fuser,
    deepimage=vision,
    deephead=deephead,
)

# Train
trainer = Trainer(model, objective="binary")

trainer.fit(
    X_tab=X_tab,
    X_text=[X_text_1, X_text_2],
    X_img=X_img,
    target=df["target"].values,
    n_epochs=1,
    batch_size=32,
)
```
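The slicing in `MyModelFuser.forward` relies on the incoming tensor being laid out as `[tabular | text | image]`, each block as wide as the corresponding component's `output_dim`. The small pure-Python check below reproduces that layout with plain lists instead of tensors. `fused_slices` is a hypothetical helper, and the widths of 8 assume that `tab_mlp`, `models_fuser` and `vision` each end in an 8-unit layer, as configured above.

```python
from typing import Tuple


def fused_slices(tab_dim: int, text_dim: int, image_dim: int) -> Tuple[slice, slice, slice]:
    """Column ranges of each component inside the concatenated deephead input."""
    tab = slice(0, tab_dim)
    text = slice(tab_dim, tab_dim + text_dim)
    image = slice(tab_dim + text_dim, tab_dim + text_dim + image_dim)
    return tab, text, image


tab_s, text_s, image_s = fused_slices(tab_dim=8, text_dim=8, image_dim=8)
row = list(range(24))  # stand-in for one row of the concatenated tensor
tab_part, text_part, image_part = row[tab_s], row[text_s], row[image_s]

# the three slices are contiguous and together cover the whole row
assert tab_part + text_part + image_part == row
```

If any of the component widths changes, the slices shift accordingly, which is why `deephead` above is built from the components' `output_dim` values rather than hard-coded numbers.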
7. A two-tower model
This is a popular model in the context of recommendation systems. Let's say we have a tabular dataset formed by triples (user features, item features, target). We can create a two-tower model where the user and item features are passed through two separate models and then "fused" via a dot product.
```python
import numpy as np
import pandas as pd

from pytorch_widedeep import Trainer
from pytorch_widedeep.preprocessing import TabPreprocessor
from pytorch_widedeep.models import TabMlp, WideDeep, ModelFuser

# Let's create the interaction dataset
# user_features dataframe
np.random.seed(42)
user_ids = np.arange(1, 101)
ages = np.random.randint(18, 60, size=100)
genders = np.random.choice(["male", "female"], size=100)
locations = np.random.choice(["city_a", "city_b", "city_c", "city_d"], size=100)
user_features = pd.DataFrame(
    {"id": user_ids, "age": ages, "gender": genders, "location": locations}
)

# item_features dataframe
item_ids = np.arange(1, 101)
prices = np.random.uniform(10, 500, size=100).round(2)
colors = np.random.choice(["red", "blue", "green", "black"], size=100)
categories = np.random.choice(["electronics", "clothing", "home", "toys"], size=100)

item_features = pd.DataFrame(
    {"id": item_ids, "price": prices, "color": colors, "category": categories}
)

# Interactions dataframe
interaction_user_ids = np.random.choice(user_ids, size=1000)
interaction_item_ids = np.random.choice(item_ids, size=1000)
purchased = np.random.choice([0, 1], size=1000, p=[0.7, 0.3])
interactions = pd.DataFrame(
    {
        "user_id": interaction_user_ids,
        "item_id": interaction_item_ids,
        "purchased": purchased,
    }
)
user_item_purchased = interactions.merge(
    user_features, left_on="user_id", right_on="id"
).merge(item_features, left_on="item_id", right_on="id")

# Users
tab_preprocessor_user = TabPreprocessor(
    cat_embed_cols=["gender", "location"],
    continuous_cols=["age"],
)
X_user = tab_preprocessor_user.fit_transform(user_item_purchased)
tab_mlp_user = TabMlp(
    column_idx=tab_preprocessor_user.column_idx,
    cat_embed_input=tab_preprocessor_user.cat_embed_input,
    continuous_cols=["age"],
    mlp_hidden_dims=[16, 8],
    mlp_dropout=[0.2, 0.2],
)

# Items
tab_preprocessor_item = TabPreprocessor(
    cat_embed_cols=["color", "category"],
    continuous_cols=["price"],
)
X_item = tab_preprocessor_item.fit_transform(user_item_purchased)
tab_mlp_item = TabMlp(
    column_idx=tab_preprocessor_item.column_idx,
    cat_embed_input=tab_preprocessor_item.cat_embed_input,
    continuous_cols=["price"],
    mlp_hidden_dims=[16, 8],
    mlp_dropout=[0.2, 0.2],
)

two_tower_model = ModelFuser([tab_mlp_user, tab_mlp_item], fusion_method="dot")

model = WideDeep(deeptabular=two_tower_model)

trainer = Trainer(model, objective="binary")

trainer.fit(
    X_tab=[X_user, X_item],
    # take the target from the merged dataframe so that it stays row-aligned
    # with X_user and X_item (a merge does not necessarily keep the row order)
    target=user_item_purchased["purchased"].values,
    n_epochs=1,
    batch_size=32,
)
```
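Conceptually, each tower maps its own features to a vector of the same size, and the `"dot"` fusion scores a user-item pair by the inner product of the two vectors. The pure-Python sketch below shows that idea with one-layer towers and a sigmoid on top. The weights and the names `tower` and `two_tower_score` are hypothetical, and the sketch does not attempt to reproduce how `WideDeep` builds its own prediction layer.

```python
import math
from typing import List


def tower(features: List[float], weights: List[List[float]]) -> List[float]:
    """One-layer tower: a linear projection (one weight row per output unit) + ReLU."""
    return [max(0.0, sum(w * x for w, x in zip(row, features))) for row in weights]


def two_tower_score(user_x, item_x, user_w, item_w) -> float:
    u = tower(user_x, user_w)  # user embedding
    v = tower(item_x, item_w)  # item embedding
    if len(u) != len(v):
        raise ValueError("both towers must output vectors of the same size")
    logit = sum(a * b for a, b in zip(u, v))  # dot-product fusion
    return 1.0 / (1.0 + math.exp(-logit))  # probability of a purchase


# toy user with features [1.0, 0.5] and toy item with features [2.0, 1.0]
score = two_tower_score(
    user_x=[1.0, 0.5],
    item_x=[2.0, 1.0],
    user_w=[[1.0, 0.0], [0.0, 2.0]],
    item_w=[[0.5, 0.0], [0.0, 1.0]],
)
```

This is also why both towers in the example above end in the same width (`mlp_hidden_dims=[16, 8]`): a dot product is only defined between vectors of equal size.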
8. Tabular with a multi-target loss
This one is "a bonus" to illustrate the use of multi-target losses, rather than a different architecture.
```python
import random

from pytorch_widedeep.preprocessing import TabPreprocessor
from pytorch_widedeep.models import TabMlp, WideDeep
from pytorch_widedeep.losses_multitarget import MultiTargetClassificationLoss
from pytorch_widedeep import Trainer

# let's add a second target to the dataframe
df["target2"] = [random.choice([0, 1]) for _ in range(100)]

# Tabular
tab_preprocessor = TabPreprocessor(
    cat_embed_cols=["city", "name"], continuous_cols=["age", "height"]
)
X_tab = tab_preprocessor.fit_transform(df)

tab_mlp = TabMlp(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    continuous_cols=tab_preprocessor.continuous_cols,
    mlp_hidden_dims=[64, 32],
)

# 'pred_dim=2' because we have two binary targets. For other types of targets,
# please, see the documentation
model = WideDeep(deeptabular=tab_mlp, pred_dim=2)

loss = MultiTargetClassificationLoss(binary_config=[0, 1], reduction="mean")

# When a multi-target loss is used, 'custom_loss_function' must not be None.
# See the docs
trainer = Trainer(model, objective="multitarget", custom_loss_function=loss)

trainer.fit(
    X_tab=X_tab,
    target=df[["target", "target2"]].values,
    n_epochs=1,
    batch_size=32,
)
```
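With `binary_config=[0, 1]` both output columns are treated as binary targets, so the loss is presumably a binary cross-entropy per column, averaged. The plain-Python sketch below shows that idea under two assumptions: `reduction="mean"` averages over all rows and target columns, and column `j` of the model output (hence `pred_dim=2`) is the logit for target `j`. It is not the code of `MultiTargetClassificationLoss`; the function names are hypothetical.

```python
import math
from typing import List


def bce_with_logits(logit: float, target: int) -> float:
    """Numerically stable binary cross-entropy computed from a raw logit."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def multi_target_bce(logits: List[List[float]], targets: List[List[int]]) -> float:
    """Mean BCE over every row and every binary target column."""
    losses = [
        bce_with_logits(logit, target)
        for row_logits, row_targets in zip(logits, targets)
        for logit, target in zip(row_logits, row_targets)
    ]
    return sum(losses) / len(losses)


# two rows, two binary targets each (as in df[["target", "target2"]])
loss = multi_target_bce(logits=[[2.0, -1.0], [0.5, 0.0]], targets=[[1, 0], [0, 1]])
```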
It is important to emphasize again that each individual component, `wide`, `deeptabular`, `deeptext` and `deepimage`, can be used independently and in isolation. For example, one could use only `wide`, which is simply a linear model. In fact, one of the most interesting functionalities in `pytorch-widedeep` would be the use of the `deeptabular` component on its own, i.e. what one might normally refer to as Deep Learning for Tabular Data. Currently, `pytorch-widedeep` offers the following different models for that component:
- Wide: a simple linear model where the nonlinearities are captured via cross-product transformations, as explained before.
- TabMlp: a simple MLP that receives embeddings representing the categorical features, concatenated with the continuous features, which can also be embedded.
- TabResnet: similar to the previous model but the embeddings are passed through a series of ResNet blocks built with dense layers.
- TabNet: details on TabNet can be found in TabNet: Attentive Interpretable Tabular Learning.
Two simpler attention-based models that we call:
- ContextAttentionMLP: MLP with an attention mechanism "on top" that is based on Hierarchical Attention Networks for Document Classification.
- SelfAttentionMLP: MLP with an attention mechanism that is a simplified version of a transformer block that we refer to as "query-key self-attention".
The `Tabformer` family, i.e. Transformers for Tabular data:
- TabTransformer: details on the TabTransformer can be found in TabTransformer: Tabular Data Modeling Using Contextual Embeddings.
- SAINT: details on SAINT can be found in SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training.
- FT-Transformer: details on the FT-Transformer can be found in Revisiting Deep Learning Models for Tabular Data.
- TabFastFormer: adaptation of the FastFormer for tabular data. Details on the FastFormer can be found in FastFormers: Highly Efficient Transformer Models for Natural Language Understanding.
- TabPerceiver: adaptation of the Perceiver for tabular data. Details on the Perceiver can be found in Perceiver: General Perception with Iterative Attention.
And probabilistic DL models for tabular data based on Weight Uncertainty in Neural Networks:
- BayesianWide: probabilistic adaptation of the `Wide` model.
- BayesianTabMlp: probabilistic adaptation of the `TabMlp` model.
Note that while there are scientific publications for the TabTransformer, SAINT and FT-Transformer, the TabFastFormer and TabPerceiver are our own adaptation of those algorithms for tabular data.
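As an illustration of the `TabMlp` description in the list above, the sketch below builds the vector that an MLP of this kind would consume for one row: the embedding of each categorical value, concatenated with the continuous values. The embedding tables are hypothetical hand-written numbers (in the library they are learned), and the continuous features are passed through unchanged here, whereas in practice they may be scaled or embedded depending on the preprocessing options.

```python
from typing import Dict, List

# hypothetical embedding tables: one vector per category value, per column
EMBEDDINGS: Dict[str, Dict[str, List[float]]] = {
    "city": {"Chicago": [0.1, 0.2], "Houston": [0.3, 0.4]},
    "name": {"Alice": [0.5, 0.6, 0.7], "Bob": [0.8, 0.9, 1.0]},
}


def mlp_input(row: dict, cat_cols: List[str], cont_cols: List[str]) -> List[float]:
    """Concatenate categorical embeddings with the raw continuous values."""
    vector: List[float] = []
    for col in cat_cols:
        vector.extend(EMBEDDINGS[col][row[col]])  # embedding lookup
    vector.extend(float(row[col]) for col in cont_cols)  # continuous features as-is
    return vector


x = mlp_input(
    {"city": "Houston", "name": "Alice", "age": 30, "height": 170.0},
    cat_cols=["city", "name"],
    cont_cols=["age", "height"],
)
```

The input width is therefore the sum of the embedding sizes plus the number of continuous columns, which is what the first dense layer of the MLP has to match.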
In addition, Self-Supervised pre-training can be used for all `deeptabular` models, with the exception of the `TabPerceiver`. Self-Supervised pre-training can be used via two methods or routines, which we refer to as the encoder-decoder method and the contrastive-denoising method. Please see the documentation and the examples for details on this functionality, and all other options in the library.
The `rec` module was introduced as an extension to the existing components in the library, addressing questions and issues related to recommendation systems. While still under active development, it currently includes a select number of powerful recommendation models.
It's worth noting that this library already supported the implementation ofvarious recommendation algorithms using existing components. For example,models like Wide and Deep, Two-Tower, or Neural Collaborative Filtering couldbe constructed using the library's core functionalities.
The recommendation algorithms in the `rec` module are:
- AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks
- DeepFM: A Factorization-Machine based Neural Network for CTR Prediction
- (Deep) Field Aware Factorization Machine (FFM): a Deep Learning version of the algorithm presented in Field-aware Factorization Machines in a Real-world Online Advertising System
- xDeepFM: Combining Explicit and Implicit Feature Interactions for Recommender Systems
- Deep Interest Network for Click-Through Rate Prediction
- Deep and Cross Network for Ad Click Predictions
- DCN V2: Improved Deep & Cross Network and Practical Lessons for Web-scale Learning to Rank Systems
- Towards Deeper, Lighter and Interpretable Click-through Rate Prediction
- A basic Transformer-based model for recommendation where the problem is framed as a sequence.
See the examples for details on how to use these models.
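Some of these models, such as DeepFM and the field-aware factorization machine, build on factorization-machine style pairwise interactions between feature embeddings. As background only, and not the `rec` module's code, the sketch below computes the second-order factorization machine term (the sum of dot products over all pairs of field embeddings) both naively and with the standard identity 0.5 * sum_f[(sum_i v_if)^2 - sum_i v_if^2], which reduces the cost from quadratic to linear in the number of fields. The function names and the toy embeddings are hypothetical.

```python
from itertools import combinations
from typing import List


def fm_pairwise_naive(v: List[List[float]]) -> float:
    """Sum of <v_i, v_j> over all pairs of fields i < j."""
    return sum(
        sum(a * b for a, b in zip(vi, vj)) for vi, vj in combinations(v, 2)
    )


def fm_pairwise_fast(v: List[List[float]]) -> float:
    """Same quantity in O(n * k): 0.5 * sum_f [(sum_i v_if)^2 - sum_i v_if^2]."""
    total = 0.0
    for f in range(len(v[0])):
        s = sum(vi[f] for vi in v)
        sq = sum(vi[f] ** 2 for vi in v)
        total += s * s - sq
    return 0.5 * total


# three fields, each already represented by a 2-d embedding (times its value)
fields = [[1.0, 2.0], [3.0, -1.0], [0.5, 0.5]]
naive = fm_pairwise_naive(fields)
fast = fm_pairwise_fast(fields)
```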
For the text component, `deeptext`, the library offers the following models:
- BasicRNN: a simple RNN
- AttentiveRNN: an RNN with an attention mechanism based on the Hierarchical Attention Networks for Document Classification
- StackedAttentiveRNN: a stack of AttentiveRNNs
- HFModel: a wrapper around Hugging Face Transformer-based models. At the moment only models from the families BERT, RoBERTa, DistilBERT, ALBERT and ELECTRA are supported. This is because this library is designed to address classification and regression tasks, and these are the most 'popular' encoder-only models, which have proved to work best for these tasks. If there is demand for other models, they will be included in the future.
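The attention in `AttentiveRNN` follows the Hierarchical Attention Networks idea of pooling the RNN hidden states with learned weights. A simplified pure-Python sketch of that pooling step is below: each hidden state is scored against a context vector, the scores are softmax-normalised, and the output is the weighted sum of the hidden states. The function name and inputs are hypothetical, and the tanh projection that HAN applies to each hidden state before scoring is omitted for brevity.

```python
import math
from typing import List


def attention_pool(hidden_states: List[List[float]], context: List[float]) -> List[float]:
    """Weighted sum of hidden states, with weights = softmax(h_t . context)."""
    scores = [sum(h * c for h, c in zip(h_t, context)) for h_t in hidden_states]
    top = max(scores)  # subtract the max before exp for numerical stability
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    dim = len(hidden_states[0])
    return [
        sum(w * h_t[d] for w, h_t in zip(weights, hidden_states)) for d in range(dim)
    ]


# two time steps with 2-d hidden states
states = [[1.0, 0.0], [0.0, 1.0]]
pooled_equal = attention_pool(states, context=[1.0, 1.0])  # equal scores: plain average
pooled_first = attention_pool(states, context=[10.0, 0.0])  # almost all weight on step 1
```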
For the image component, `deepimage`, the library supports models from the following families: 'resnet', 'shufflenet', 'resnext', 'wide_resnet', 'regnet', 'densenet', 'mobilenetv3', 'mobilenetv2', 'mnasnet', 'efficientnet' and 'squeezenet'. These are offered via `torchvision` and wrapped up in the `Vision` class.
Install using pip:

```shell
pip install pytorch-widedeep
```

Or install directly from GitHub:

```shell
pip install git+https://github.com/jrzaurin/pytorch-widedeep.git
```

```shell
# Clone the repository
git clone https://github.com/jrzaurin/pytorch-widedeep
cd pytorch-widedeep

# Install in dev mode
pip install -e .
```
Here is an end-to-end example of a binary classification with the adult dataset using `Wide` and `TabMlp` with default settings.
Building a wide (linear) and deep model with `pytorch-widedeep`:
```python
import numpy as np
import torch
from sklearn.model_selection import train_test_split

from pytorch_widedeep import Trainer
from pytorch_widedeep.preprocessing import WidePreprocessor, TabPreprocessor
from pytorch_widedeep.models import Wide, TabMlp, WideDeep
from pytorch_widedeep.metrics import Accuracy
from pytorch_widedeep.datasets import load_adult


df = load_adult(as_frame=True)
df["income_label"] = (df["income"].apply(lambda x: ">50K" in x)).astype(int)
df.drop("income", axis=1, inplace=True)
df_train, df_test = train_test_split(df, test_size=0.2, stratify=df.income_label)

# Define the 'column set up'
wide_cols = [
    "education",
    "relationship",
    "workclass",
    "occupation",
    "native-country",
    "gender",
]
crossed_cols = [("education", "occupation"), ("native-country", "occupation")]

cat_embed_cols = [
    "workclass",
    "education",
    "marital-status",
    "occupation",
    "relationship",
    "race",
    "gender",
    "capital-gain",
    "capital-loss",
    "native-country",
]
continuous_cols = ["age", "hours-per-week"]
target = "income_label"
target = df_train[target].values

# prepare the data
wide_preprocessor = WidePreprocessor(wide_cols=wide_cols, crossed_cols=crossed_cols)
X_wide = wide_preprocessor.fit_transform(df_train)

tab_preprocessor = TabPreprocessor(
    cat_embed_cols=cat_embed_cols, continuous_cols=continuous_cols  # type: ignore[arg-type]
)
X_tab = tab_preprocessor.fit_transform(df_train)

# build the model
wide = Wide(input_dim=np.unique(X_wide).shape[0], pred_dim=1)
tab_mlp = TabMlp(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    continuous_cols=continuous_cols,
)
model = WideDeep(wide=wide, deeptabular=tab_mlp)

# train and validate
trainer = Trainer(model, objective="binary", metrics=[Accuracy])
trainer.fit(
    X_wide=X_wide,
    X_tab=X_tab,
    target=target,
    n_epochs=5,
    batch_size=256,
)

# predict on test
X_wide_te = wide_preprocessor.transform(df_test)
X_tab_te = tab_preprocessor.transform(df_test)
preds = trainer.predict(X_wide=X_wide_te, X_tab=X_tab_te)

# Save and load

# Option 1: this will also save training history and lr history if the
# LRHistory callback is used
trainer.save(path="model_weights", save_state_dict=True)

# Option 2: save as any other torch model
torch.save(model.state_dict(), "model_weights/wd_model.pt")

# From here onwards, Options 1 and 2 are the same. I assume the user has
# prepared the data and defined the new model components:
# 1. Build the model
model_new = WideDeep(wide=wide, deeptabular=tab_mlp)
model_new.load_state_dict(torch.load("model_weights/wd_model.pt"))

# 2. Instantiate the trainer
trainer_new = Trainer(model_new, objective="binary")

# 3. Either start the fit or directly predict
preds = trainer_new.predict(X_wide=X_wide, X_tab=X_tab, batch_size=32)
```
Of course, one can do much more. See the Examples folder, the documentation or the companion posts for a better understanding of the content of the package and its functionalities.
To run the tests:

```shell
pytest tests
```
Check the CONTRIBUTING page.
This library takes from a series of other libraries, so I think it is only fair to mention them here in the README (specific mentions are also included in the code).
The `Callbacks` and `Initializers` structure and code is inspired by the `torchsample` library, which in itself is partially inspired by `Keras`.
The `TextPreprocessor` class in this library uses `fastai`'s `Tokenizer` and `Vocab`. The code at `utils.fastai_transforms` is a minor adaptation of their code so it functions within this library. In my experience, their `Tokenizer` is the best in class.
The `ImagePreprocessor` class in this library uses code from the fantastic Deep Learning for Computer Vision (DL4CV) book by Adrian Rosebrock.
This work is dual-licensed under Apache 2.0 and MIT (or any later version). You can choose either of them if you use this work.
SPDX-License-Identifier: Apache-2.0 AND MIT
```bibtex
@article{Zaurin_pytorch-widedeep_A_flexible_2023,
  author = {Zaurin, Javier Rodriguez and Mulinka, Pavol},
  doi = {10.21105/joss.05027},
  journal = {Journal of Open Source Software},
  month = jun,
  number = {86},
  pages = {5027},
  title = {{pytorch-widedeep: A flexible package for multimodal deep learning}},
  url = {https://joss.theoj.org/papers/10.21105/joss.05027},
  volume = {8},
  year = {2023}
}
```
Zaurin, J. R., & Mulinka, P. (2023). pytorch-widedeep: A flexible package for multimodal deep learning. Journal of Open Source Software, 8(86), 5027. https://doi.org/10.21105/joss.05027