""" | |
This is a batched LSTM forward and backward pass | |
""" | |
importnumpyasnp | |
importcode | |
classLSTM: | |
@staticmethod | |
definit(input_size,hidden_size,fancy_forget_bias_init=3): | |
""" | |
Initialize parameters of the LSTM (both weights and biases in one matrix) | |
One might way to have a positive fancy_forget_bias_init number (e.g. maybe even up to 5, in some papers) | |
""" | |
# +1 for the biases, which will be the first row of WLSTM | |
WLSTM=np.random.randn(input_size+hidden_size+1,4*hidden_size)/np.sqrt(input_size+hidden_size) | |
WLSTM[0,:]=0# initialize biases to zero | |
iffancy_forget_bias_init!=0: | |
# forget gates get little bit negative bias initially to encourage them to be turned off | |
# remember that due to Xavier initialization above, the raw output activations from gates before | |
# nonlinearity are zero mean and on order of standard deviation ~1 | |
WLSTM[0,hidden_size:2*hidden_size]=fancy_forget_bias_init | |
returnWLSTM | |
  @staticmethod
  def forward(X, WLSTM, c0 = None, h0 = None):
    """
    X should be of shape (n,b,input_size), where n = length of sequence, b = batch size
    """
    n,b,input_size = X.shape
    d = WLSTM.shape[1]/4 # hidden size
    if c0 is None: c0 = np.zeros((b,d))
    if h0 is None: h0 = np.zeros((b,d))

    # Perform the LSTM forward pass with X as the input
    xphpb = WLSTM.shape[0] # x plus h plus bias, lol
    Hin = np.zeros((n, b, xphpb)) # input [1, xt, ht-1] to each tick of the LSTM
    Hout = np.zeros((n, b, d)) # hidden representation of the LSTM (gated cell content)
    IFOG = np.zeros((n, b, d * 4)) # input, forget, output, gate (IFOG)
    IFOGf = np.zeros((n, b, d * 4)) # after nonlinearity
    C = np.zeros((n, b, d)) # cell content
    Ct = np.zeros((n, b, d)) # tanh of cell content
    for t in xrange(n):
      # concat [x,h] as input to the LSTM
      prevh = Hout[t-1] if t > 0 else h0
      Hin[t,:,0] = 1 # bias
      Hin[t,:,1:input_size+1] = X[t]
      Hin[t,:,input_size+1:] = prevh
      # compute all gate activations. dots: (most work is this line)
      IFOG[t] = Hin[t].dot(WLSTM)
      # non-linearities
      IFOGf[t,:,:3*d] = 1.0/(1.0+np.exp(-IFOG[t,:,:3*d])) # sigmoids; these are the gates
      IFOGf[t,:,3*d:] = np.tanh(IFOG[t,:,3*d:]) # tanh
      # compute the cell activation
      prevc = C[t-1] if t > 0 else c0
      C[t] = IFOGf[t,:,:d] * IFOGf[t,:,3*d:] + IFOGf[t,:,d:2*d] * prevc
      Ct[t] = np.tanh(C[t])
      Hout[t] = IFOGf[t,:,2*d:3*d] * Ct[t]

    cache = {}
    cache['WLSTM'] = WLSTM
    cache['Hout'] = Hout
    cache['IFOGf'] = IFOGf
    cache['IFOG'] = IFOG
    cache['C'] = C
    cache['Ct'] = Ct
    cache['Hin'] = Hin
    cache['c0'] = c0
    cache['h0'] = h0

    # return C[t] as well, so we can continue the LSTM with a prev state init if needed
    return Hout, C[t], Hout[t], cache
  @staticmethod
  def backward(dHout_in, cache, dcn = None, dhn = None):

    WLSTM = cache['WLSTM']
    Hout = cache['Hout']
    IFOGf = cache['IFOGf']
    IFOG = cache['IFOG']
    C = cache['C']
    Ct = cache['Ct']
    Hin = cache['Hin']
    c0 = cache['c0']
    h0 = cache['h0']
    n,b,d = Hout.shape
    input_size = WLSTM.shape[0] - d - 1 # -1 due to bias

    # backprop the LSTM
    dIFOG = np.zeros(IFOG.shape)
    dIFOGf = np.zeros(IFOGf.shape)
    dWLSTM = np.zeros(WLSTM.shape)
    dHin = np.zeros(Hin.shape)
    dC = np.zeros(C.shape)
    dX = np.zeros((n,b,input_size))
    dh0 = np.zeros((b, d))
    dc0 = np.zeros((b, d))
    dHout = dHout_in.copy() # make a copy so we don't have any funny side effects
    if dcn is not None: dC[n-1] += dcn.copy() # carry over gradients from later
    if dhn is not None: dHout[n-1] += dhn.copy()
    for t in reversed(xrange(n)):

      tanhCt = Ct[t]
      dIFOGf[t,:,2*d:3*d] = tanhCt * dHout[t]
      # backprop the tanh non-linearity first, then continue the backprop
      dC[t] += (1-tanhCt**2) * (IFOGf[t,:,2*d:3*d] * dHout[t])

      if t > 0:
        dIFOGf[t,:,d:2*d] = C[t-1] * dC[t]
        dC[t-1] += IFOGf[t,:,d:2*d] * dC[t]
      else:
        dIFOGf[t,:,d:2*d] = c0 * dC[t]
        dc0 = IFOGf[t,:,d:2*d] * dC[t]
      dIFOGf[t,:,:d] = IFOGf[t,:,3*d:] * dC[t]
      dIFOGf[t,:,3*d:] = IFOGf[t,:,:d] * dC[t]

      # backprop activation functions
      dIFOG[t,:,3*d:] = (1 - IFOGf[t,:,3*d:] ** 2) * dIFOGf[t,:,3*d:]
      y = IFOGf[t,:,:3*d]
      dIFOG[t,:,:3*d] = (y*(1.0-y)) * dIFOGf[t,:,:3*d]

      # backprop matrix multiply
      dWLSTM += np.dot(Hin[t].transpose(), dIFOG[t])
      dHin[t] = dIFOG[t].dot(WLSTM.transpose())

      # backprop the identity transforms into Hin
      dX[t] = dHin[t,:,1:input_size+1]
      if t > 0:
        dHout[t-1,:] += dHin[t,:,input_size+1:]
      else:
        dh0 += dHin[t,:,input_size+1:]

    return dX, dWLSTM, dc0, dh0

# -------------------
# TEST CASES
# -------------------

def checkSequentialMatchesBatch():
  """ check LSTM I/O forward/backward interactions """

  n,b,d = (5, 3, 4) # sequence length, batch size, hidden size
  input_size = 10
  WLSTM = LSTM.init(input_size, d) # input size, hidden size
  X = np.random.randn(n,b,input_size)
  h0 = np.random.randn(b,d)
  c0 = np.random.randn(b,d)

  # sequential forward
  cprev = c0
  hprev = h0
  caches = [{} for t in xrange(n)]
  Hcat = np.zeros((n,b,d))
  for t in xrange(n):
    xt = X[t:t+1]
    _, cprev, hprev, cache = LSTM.forward(xt, WLSTM, cprev, hprev)
    caches[t] = cache
    Hcat[t] = hprev

  # sanity check: perform batch forward to check that we get the same thing
  H, _, _, batch_cache = LSTM.forward(X, WLSTM, c0, h0)
  assert np.allclose(H, Hcat), 'Sequential and Batch forward don\'t match!'

  # eval loss
  wrand = np.random.randn(*Hcat.shape)
  loss = np.sum(Hcat * wrand)
  dH = wrand

  # get the batched version gradients
  BdX, BdWLSTM, Bdc0, Bdh0 = LSTM.backward(dH, batch_cache)

  # now perform sequential backward
  dX = np.zeros_like(X)
  dWLSTM = np.zeros_like(WLSTM)
  dc0 = np.zeros_like(c0)
  dh0 = np.zeros_like(h0)
  dcnext = None
  dhnext = None
  for t in reversed(xrange(n)):
    dht = dH[t].reshape(1, b, d)
    dx, dWLSTMt, dcprev, dhprev = LSTM.backward(dht, caches[t], dcnext, dhnext)
    dhnext = dhprev
    dcnext = dcprev

    dWLSTM += dWLSTMt # accumulate LSTM gradient
    dX[t] = dx[0]
    if t == 0:
      dc0 = dcprev
      dh0 = dhprev

  # and make sure the gradients match
  print 'Making sure batched version agrees with sequential version: (should all be True)'
  print np.allclose(BdX, dX)
  print np.allclose(BdWLSTM, dWLSTM)
  print np.allclose(Bdc0, dc0)
  print np.allclose(Bdh0, dh0)

def checkBatchGradient():
  """ check that the batch gradient is correct """

  # lets gradient check this beast
  n,b,d = (5, 3, 4) # sequence length, batch size, hidden size
  input_size = 10
  WLSTM = LSTM.init(input_size, d) # input size, hidden size
  X = np.random.randn(n,b,input_size)
  h0 = np.random.randn(b,d)
  c0 = np.random.randn(b,d)

  # batch forward backward
  H, Ct, Ht, cache = LSTM.forward(X, WLSTM, c0, h0)
  wrand = np.random.randn(*H.shape)
  loss = np.sum(H * wrand) # weighted sum is a nice hash to use I think
  dH = wrand
  dX, dWLSTM, dc0, dh0 = LSTM.backward(dH, cache)

  def fwd():
    h,_,_,_ = LSTM.forward(X, WLSTM, c0, h0)
    return np.sum(h * wrand)

  # now gradient check all
  delta = 1e-5
  rel_error_thr_warning = 1e-2
  rel_error_thr_error = 1
  tocheck = [X, WLSTM, c0, h0]
  grads_analytic = [dX, dWLSTM, dc0, dh0]
  names = ['X', 'WLSTM', 'c0', 'h0']
  for j in xrange(len(tocheck)):
    mat = tocheck[j]
    dmat = grads_analytic[j]
    name = names[j]
    # gradcheck
    for i in xrange(mat.size):
      old_val = mat.flat[i]
      mat.flat[i] = old_val + delta
      loss0 = fwd()
      mat.flat[i] = old_val - delta
      loss1 = fwd()
      mat.flat[i] = old_val

      grad_analytic = dmat.flat[i]
      grad_numerical = (loss0 - loss1) / (2 * delta)

      if grad_numerical == 0 and grad_analytic == 0:
        rel_error = 0 # both are zero, OK.
        status = 'OK'
      elif abs(grad_numerical) < 1e-7 and abs(grad_analytic) < 1e-7:
        rel_error = 0 # not enough precision to check this
        status = 'VAL SMALL WARNING'
      else:
        rel_error = abs(grad_analytic - grad_numerical) / abs(grad_numerical + grad_analytic)
        status = 'OK'
        if rel_error > rel_error_thr_warning: status = 'WARNING'
        if rel_error > rel_error_thr_error: status = '!!!!! NOTOK'

      # print stats
      print '%s checking param %s index %s (val = %+8f), analytic = %+8f, numerical = %+8f, relative error = %+8f' \
            % (status, name, `np.unravel_index(i, mat.shape)`, old_val, grad_analytic, grad_numerical, rel_error)


if __name__ == "__main__":
  checkSequentialMatchesBatch()
  raw_input('check OK, press key to continue to gradient check')
  checkBatchGradient()
  print 'every line should start with OK. Have a nice day!'
pranv commented Sep 11, 2015
This is indeed very efficient. I sat down hoping to rewrite it to be faster. Early on I changed a lot of things, but later I reverted most of them.
Have you used numba? It gave me quite a bit of speedup.
upul commented Nov 24, 2015
Thanks a lot, karpathy!!!!
Could you please suggest a good reference to get familiar with the LSTM equations? Presently, I'm using
http://arxiv.org/abs/1503.04069 and https://apaszke.github.io/lstm-explained.html
Thanks
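For reference, here is the recurrence that the forward code above computes, written out in standard LSTM notation (with WLSTM's four column blocks ordered as input gate i, forget gate f, output gate o, candidate g, and the bias folded into the first row via the constant 1 input):
$$
\begin{aligned}
[a_i,\ a_f,\ a_o,\ a_g] &= [1,\ x_t,\ h_{t-1}]\, W_{\mathrm{LSTM}} \\
i_t &= \sigma(a_i), \quad f_t = \sigma(a_f), \quad o_t = \sigma(a_o), \quad g_t = \tanh(a_g) \\
c_t &= i_t \odot g_t + f_t \odot c_{t-1} \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
$$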
ghost commented Feb 4, 2016
This is fantastic and clearly written. Thanks for this
arita37 commented May 18, 2016
Hello,
Has anyone tried converting it to Cython? (it should give a boost of about 3x)
georgeblck commented Feb 14, 2017
Thanks for this!
I rewrote the code in R, if anyone is interested.
amin07 commented Nov 2, 2018
What should the shape of 'dHout_in' be in the backward pass if I only want to feed the nth (final) time step's state into my classification/softmax layer?
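One way to read this off the code above: dHout_in must have the same shape as Hout, i.e. (n, b, d). To backprop only through the last time step, zero the gradient at all other steps. A minimal sketch, where dloss_dh_last is a hypothetical (b, d) gradient coming from your softmax/classification layer:

  H, cn, hn, cache = LSTM.forward(X, WLSTM, c0, h0)
  dH = np.zeros_like(H)       # shape (n, b, d): zero gradient at every time step ...
  dH[-1] = dloss_dh_last      # ... except the last one, which gets the (b, d) gradient from the loss
  dX, dWLSTM, dc0, dh0 = LSTM.backward(dH, cache)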
Alessiobrini commented Jan 3, 2019
Can someone post an example of using this batched LSTM to train on a dataset? I'm new to the domain and have learned a lot from reading this well-written code, but I run into problems when it's time to use it in a real learning situation.
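A minimal sketch of what one training step might look like with the class above, using plain SGD and a made-up squared-error loss on the final hidden state (the sizes, the random data, the target Y, and the learning rate are all hypothetical; a real setup would add a readout/softmax layer and loop over batches of an actual dataset):

  import numpy as np

  # hypothetical sizes and data, just to make the sketch self-contained
  n, b, input_size, d = 5, 3, 10, 4
  X = np.random.randn(n, b, input_size)   # one batch of input sequences
  Y = np.random.randn(b, d)               # made-up regression targets for the last hidden state
  WLSTM = LSTM.init(input_size, d)
  learning_rate = 0.1

  for step in range(100):
    H, cn, hn, cache = LSTM.forward(X, WLSTM)
    loss = 0.5 * np.sum((hn - Y) ** 2)    # squared error on the last hidden state
    dH = np.zeros_like(H)
    dH[-1] = hn - Y                       # dloss/dHout is nonzero only at the last time step
    dX, dWLSTM, dc0, dh0 = LSTM.backward(dH, cache)
    WLSTM -= learning_rate * dWLSTM       # vanilla SGD update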