Implementation of a CRF layer in Keras.

cxf2015/keras-crf-layer

The Keras-CRF-Layer module implements a linear-chain CRF layer for learning to predict tag sequences. This variant of the CRF is factored into unary potentials for every element in the sequence and binary potentials for every transition between output tags.
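To make this factorization concrete, here is a minimal NumPy sketch. It is independent of the `crf` module and is not taken from it. It assumes unary scores of shape `(T, K)` for a sequence of length `T` with `K` tags, plus a `(K, K)` transition matrix. The helper names (`sequence_score`, `log_partition`, `viterbi`) are hypothetical.

A tag path's score is the sum of its unary and binary potentials. The forward algorithm computes the log-normalizer over all paths, and Viterbi decoding finds the highest-scoring path. Both results are checked against brute-force enumeration.

```python
import itertools
import numpy as np

def logsumexp(a, axis):
    # Numerically stable log(sum(exp(a))) along `axis`.
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))

def sequence_score(unary, transitions, tags):
    # unary: (T, K) per-position tag scores (unary potentials).
    # transitions: (K, K), transitions[i, j] scores moving from tag i to tag j.
    score = sum(unary[t, tags[t]] for t in range(len(tags)))
    score += sum(transitions[tags[t - 1], tags[t]] for t in range(1, len(tags)))
    return score

def log_partition(unary, transitions):
    # Forward algorithm: alpha[j] is the log-sum of the scores of all
    # prefixes that end in tag j.
    alpha = unary[0]
    for t in range(1, unary.shape[0]):
        alpha = logsumexp(alpha[:, None] + transitions, axis=0) + unary[t]
    return logsumexp(alpha, axis=0)

def neg_log_likelihood(unary, transitions, tags):
    # Training objective: log Z minus the score of the gold path.
    return log_partition(unary, transitions) - sequence_score(unary, transitions, tags)

def viterbi(unary, transitions):
    # Max-product dynamic programming over the same factorization.
    delta = unary[0]
    backpointers = []
    for t in range(1, unary.shape[0]):
        candidates = delta[:, None] + transitions  # (K_prev, K_next)
        backpointers.append(np.argmax(candidates, axis=0))
        delta = np.max(candidates, axis=0) + unary[t]
    best = [int(np.argmax(delta))]
    for bp in reversed(backpointers):
        best.append(int(bp[best[-1]]))
    return best[::-1]

rng = np.random.default_rng(0)
T, K = 4, 3
unary = rng.normal(size=(T, K))
transitions = rng.normal(size=(K, K))

# Brute-force check over all K**T tag sequences.
all_paths = list(itertools.product(range(K), repeat=T))
scores = np.array([sequence_score(unary, transitions, p) for p in all_paths])
print(np.isclose(log_partition(unary, transitions), logsumexp(scores, axis=0)))
print(viterbi(unary, transitions) == list(all_paths[int(np.argmax(scores))]))
```

Training a linear-chain CRF typically minimizes a loss like `neg_log_likelihood`. The two dynamic programs keep the cost at O(T * K^2), whereas enumerating every path costs O(K^T).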

Usage

Below is an example of the API, which learns a CRF for some random data. The Embedding layer that produces the per-tag scores in this example can be replaced by any neural network.

```python
import numpy as np
from keras.layers import Embedding, Input
from keras.models import Model

from crf import CRFLayer

# Hyperparameter settings.
vocab_size = 20
n_classes = 11
batch_size = 2
maxlen = 2

# Random features.
x = np.random.randint(1, vocab_size, size=(batch_size, maxlen))

# Random tag indices representing the gold sequence, one-hot encoded
# to shape (batch_size, maxlen, n_classes).
y = np.random.randint(n_classes, size=(batch_size, maxlen))
y = np.eye(n_classes)[y]

# All sequences in this example have the same length, but they can be variable in a real model.
# Shape (batch_size, 1) to match the sequence_lengths input below.
s = np.asarray([maxlen] * batch_size, dtype='int32').reshape(batch_size, 1)

# Build an example model.
word_ids = Input(batch_shape=(batch_size, maxlen), dtype='int32')
sequence_lengths = Input(batch_shape=[batch_size, 1], dtype='int32')
word_embeddings = Embedding(vocab_size, n_classes)(word_ids)
crf = CRFLayer()
pred = crf(inputs=[word_embeddings, sequence_lengths])
model = Model(inputs=[word_ids, sequence_lengths], outputs=[pred])
model.compile(loss=crf.loss, optimizer='sgd')

# Train on a single batch.
model.train_on_batch([x, s], y)

# Save the model.
model.save('model.h5')
```
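The usage example feeds equal-length sequences. For variable-length data, one plausible approach is to right-pad each batch when building `x`, `y` and `s`. The helper `pad_batch` is hypothetical and rests on two assumptions:

- word id 0 is reserved for padding (the example draws ids from 1 upward);
- the layer ignores positions beyond each entry of `sequence_lengths`, so the tag used at padded positions does not matter.

```python
import numpy as np

def pad_batch(word_id_seqs, tag_seqs, n_classes, maxlen=None, pad_id=0):
    # Hypothetical helper: right-pads variable-length sequences into the
    # arrays the usage example feeds to the model.
    # Assumes pad_id is never a real word id; does not truncate long sequences.
    lengths = [len(seq) for seq in word_id_seqs]
    maxlen = maxlen or max(lengths)
    batch_size = len(word_id_seqs)
    x = np.full((batch_size, maxlen), pad_id, dtype='int32')
    tags = np.zeros((batch_size, maxlen), dtype='int32')
    for i, (words, labels) in enumerate(zip(word_id_seqs, tag_seqs)):
        x[i, :len(words)] = words
        tags[i, :len(labels)] = labels
    # One-hot gold tags: (batch_size, maxlen, n_classes).
    y = np.eye(n_classes)[tags]
    # True lengths, shaped (batch_size, 1) like the sequence_lengths input.
    s = np.asarray(lengths, dtype='int32').reshape(batch_size, 1)
    return x, y, s

x, y, s = pad_batch([[3, 7, 5], [4, 9]], [[1, 0, 2], [2, 2]], n_classes=11)
print(x.shape, y.shape, s.ravel())
```

Because `maxlen` is fixed per batch here, the `batch_shape` of the model's inputs would also need to match. That may require padding every batch to one global `maxlen`.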

Model loading

When you load a saved model that has a CRF output, loading it with `keras.models.load_model` alone won't work properly, because the loss function's reference to the transition parameters is lost. To fix this, pass the `custom_objects` parameter as follows:

```python
from keras.models import load_model

from crf import create_custom_objects

model = load_model('model.h5', custom_objects=create_custom_objects())
```