
@karpathy
Last active: May 14, 2025 00:08
    An efficient, batched LSTM.
    """
    This is a batched LSTM forward and backward pass
    """
    importnumpyasnp
    importcode
class LSTM:

  @staticmethod
  def init(input_size, hidden_size, fancy_forget_bias_init = 3):
    """
    Initialize parameters of the LSTM (both weights and biases in one matrix)
    One may want to use a positive fancy_forget_bias_init number (e.g. maybe even up to 5, in some papers)
    """
    # +1 for the biases, which will be the first row of WLSTM
    WLSTM = np.random.randn(input_size + hidden_size + 1, 4 * hidden_size) / np.sqrt(input_size + hidden_size)
    WLSTM[0,:] = 0 # initialize biases to zero
    if fancy_forget_bias_init != 0:
      # forget gates get a bit of positive bias initially to encourage them to stay open (i.e. to remember);
      # remember that due to the Xavier initialization above, the raw output activations from gates before
      # the nonlinearity are zero mean and on the order of standard deviation ~1
      WLSTM[0, hidden_size:2*hidden_size] = fancy_forget_bias_init
    return WLSTM
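
  # The columns of WLSTM are laid out as four contiguous blocks of width hidden_size:
  # [input gate | forget gate | output gate | candidate g], and row 0 holds the biases
  # (the input at each time step is the concatenation [1, x_t, h_{t-1}]).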
  @staticmethod
  def forward(X, WLSTM, c0 = None, h0 = None):
    """
    X should be of shape (n,b,input_size), where n = length of sequence, b = batch size
    """
    n,b,input_size = X.shape
    d = WLSTM.shape[1]/4 # hidden size
    if c0 is None: c0 = np.zeros((b,d))
    if h0 is None: h0 = np.zeros((b,d))

    # Perform the LSTM forward pass with X as the input
    xphpb = WLSTM.shape[0] # x plus h plus bias, lol
    Hin = np.zeros((n, b, xphpb)) # input [1, xt, ht-1] to each tick of the LSTM
    Hout = np.zeros((n, b, d)) # hidden representation of the LSTM (gated cell content)
    IFOG = np.zeros((n, b, d * 4)) # input, forget, output, gate (IFOG)
    IFOGf = np.zeros((n, b, d * 4)) # after nonlinearity
    C = np.zeros((n, b, d)) # cell content
    Ct = np.zeros((n, b, d)) # tanh of cell content
    for t in xrange(n):
      # concat [x,h] as input to the LSTM
      prevh = Hout[t-1] if t > 0 else h0
      Hin[t,:,0] = 1 # bias
      Hin[t,:,1:input_size+1] = X[t]
      Hin[t,:,input_size+1:] = prevh
      # compute all gate activations. dots: (most work is this line)
      IFOG[t] = Hin[t].dot(WLSTM)
      # non-linearities
      IFOGf[t,:,:3*d] = 1.0/(1.0+np.exp(-IFOG[t,:,:3*d])) # sigmoids; these are the gates
      IFOGf[t,:,3*d:] = np.tanh(IFOG[t,:,3*d:]) # tanh
      # compute the cell activation
      prevc = C[t-1] if t > 0 else c0
      C[t] = IFOGf[t,:,:d] * IFOGf[t,:,3*d:] + IFOGf[t,:,d:2*d] * prevc
      Ct[t] = np.tanh(C[t])
      Hout[t] = IFOGf[t,:,2*d:3*d] * Ct[t]
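      # (in equations: i, f, o are the sigmoid gates and g the tanh candidate;
      #  c_t = i * g + f * c_{t-1},  h_t = o * tanh(c_t))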
    cache = {}
    cache['WLSTM'] = WLSTM
    cache['Hout'] = Hout
    cache['IFOGf'] = IFOGf
    cache['IFOG'] = IFOG
    cache['C'] = C
    cache['Ct'] = Ct
    cache['Hin'] = Hin
    cache['c0'] = c0
    cache['h0'] = h0

    # return C[t] as well, so we can continue the LSTM with prev state init if needed
    return Hout, C[t], Hout[t], cache
  @staticmethod
  def backward(dHout_in, cache, dcn = None, dhn = None):

    WLSTM = cache['WLSTM']
    Hout = cache['Hout']
    IFOGf = cache['IFOGf']
    IFOG = cache['IFOG']
    C = cache['C']
    Ct = cache['Ct']
    Hin = cache['Hin']
    c0 = cache['c0']
    h0 = cache['h0']
    n,b,d = Hout.shape
    input_size = WLSTM.shape[0] - d - 1 # -1 due to bias

    # backprop the LSTM
    dIFOG = np.zeros(IFOG.shape)
    dIFOGf = np.zeros(IFOGf.shape)
    dWLSTM = np.zeros(WLSTM.shape)
    dHin = np.zeros(Hin.shape)
    dC = np.zeros(C.shape)
    dX = np.zeros((n,b,input_size))
    dh0 = np.zeros((b, d))
    dc0 = np.zeros((b, d))
    dHout = dHout_in.copy() # make a copy so we don't have any funny side effects
    if dcn is not None: dC[n-1] += dcn.copy() # carry over gradients from later
    if dhn is not None: dHout[n-1] += dhn.copy()
    for t in reversed(xrange(n)):

      tanhCt = Ct[t]
      dIFOGf[t,:,2*d:3*d] = tanhCt * dHout[t]
      # backprop tanh non-linearity first then continue backprop
      dC[t] += (1-tanhCt**2) * (IFOGf[t,:,2*d:3*d] * dHout[t])

      if t > 0:
        dIFOGf[t,:,d:2*d] = C[t-1] * dC[t]
        dC[t-1] += IFOGf[t,:,d:2*d] * dC[t]
      else:
        dIFOGf[t,:,d:2*d] = c0 * dC[t]
        dc0 = IFOGf[t,:,d:2*d] * dC[t]
      dIFOGf[t,:,:d] = IFOGf[t,:,3*d:] * dC[t]
      dIFOGf[t,:,3*d:] = IFOGf[t,:,:d] * dC[t]

      # backprop activation functions
      dIFOG[t,:,3*d:] = (1 - IFOGf[t,:,3*d:] ** 2) * dIFOGf[t,:,3*d:]
      y = IFOGf[t,:,:3*d]
      dIFOG[t,:,:3*d] = (y*(1.0-y)) * dIFOGf[t,:,:3*d]

      # backprop matrix multiply
      dWLSTM += np.dot(Hin[t].transpose(), dIFOG[t])
      dHin[t] = dIFOG[t].dot(WLSTM.transpose())

      # backprop the identity transforms into Hin
      dX[t] = dHin[t,:,1:input_size+1]
      if t > 0:
        dHout[t-1,:] += dHin[t,:,input_size+1:]
      else:
        dh0 += dHin[t,:,input_size+1:]

    return dX, dWLSTM, dc0, dh0
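
# Note: backward can be run on sub-sequences and chained: feed the dc0/dh0 returned for a
# later chunk back in as the dcn/dhn arguments when backpropagating the chunk before it,
# mirroring how forward is chained via its returned final cell and hidden states
# (checkSequentialMatchesBatch below exercises exactly this).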
# -------------------
# TEST CASES
# -------------------

def checkSequentialMatchesBatch():
  """ check LSTM I/O forward/backward interactions """

  n,b,d = (5, 3, 4) # sequence length, batch size, hidden size
  input_size = 10
  WLSTM = LSTM.init(input_size, d) # input size, hidden size
  X = np.random.randn(n,b,input_size)
  h0 = np.random.randn(b,d)
  c0 = np.random.randn(b,d)

  # sequential forward
  cprev = c0
  hprev = h0
  caches = [{} for t in xrange(n)]
  Hcat = np.zeros((n,b,d))
  for t in xrange(n):
    xt = X[t:t+1]
    _, cprev, hprev, cache = LSTM.forward(xt, WLSTM, cprev, hprev)
    caches[t] = cache
    Hcat[t] = hprev

  # sanity check: perform batch forward to check that we get the same thing
  H, _, _, batch_cache = LSTM.forward(X, WLSTM, c0, h0)
  assert np.allclose(H, Hcat), "Sequential and Batch forward don't match!"

  # eval loss
  wrand = np.random.randn(*Hcat.shape)
  loss = np.sum(Hcat * wrand)
  dH = wrand

  # get the batched version gradients
  BdX, BdWLSTM, Bdc0, Bdh0 = LSTM.backward(dH, batch_cache)

  # now perform sequential backward
  dX = np.zeros_like(X)
  dWLSTM = np.zeros_like(WLSTM)
  dc0 = np.zeros_like(c0)
  dh0 = np.zeros_like(h0)
  dcnext = None
  dhnext = None
  for t in reversed(xrange(n)):
    dht = dH[t].reshape(1, b, d)
    dx, dWLSTMt, dcprev, dhprev = LSTM.backward(dht, caches[t], dcnext, dhnext)
    dhnext = dhprev
    dcnext = dcprev

    dWLSTM += dWLSTMt # accumulate LSTM gradient
    dX[t] = dx[0]
    if t == 0:
      dc0 = dcprev
      dh0 = dhprev

  # and make sure the gradients match
  print 'Making sure batched version agrees with sequential version: (should all be True)'
  print np.allclose(BdX, dX)
  print np.allclose(BdWLSTM, dWLSTM)
  print np.allclose(Bdc0, dc0)
  print np.allclose(Bdh0, dh0)
def checkBatchGradient():
  """ check that the batch gradient is correct """

  # lets gradient check this beast
  n,b,d = (5, 3, 4) # sequence length, batch size, hidden size
  input_size = 10
  WLSTM = LSTM.init(input_size, d) # input size, hidden size
  X = np.random.randn(n,b,input_size)
  h0 = np.random.randn(b,d)
  c0 = np.random.randn(b,d)

  # batch forward backward
  H, Ct, Ht, cache = LSTM.forward(X, WLSTM, c0, h0)
  wrand = np.random.randn(*H.shape)
  loss = np.sum(H * wrand) # weighted sum is a nice hash to use I think
  dH = wrand
  dX, dWLSTM, dc0, dh0 = LSTM.backward(dH, cache)

  def fwd():
    h,_,_,_ = LSTM.forward(X, WLSTM, c0, h0)
    return np.sum(h * wrand)

  # now gradient check all
  delta = 1e-5
  rel_error_thr_warning = 1e-2
  rel_error_thr_error = 1
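
  # the loop below perturbs each element by +/- delta, takes the central difference
  # (loss0 - loss1) / (2 * delta) as the numerical gradient, and compares it to the
  # analytic gradient using the relative error |analytic - numerical| / |analytic + numerical|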
  tocheck = [X, WLSTM, c0, h0]
  grads_analytic = [dX, dWLSTM, dc0, dh0]
  names = ['X', 'WLSTM', 'c0', 'h0']
  for j in xrange(len(tocheck)):
    mat = tocheck[j]
    dmat = grads_analytic[j]
    name = names[j]
    # gradcheck
    for i in xrange(mat.size):
      old_val = mat.flat[i]
      mat.flat[i] = old_val + delta
      loss0 = fwd()
      mat.flat[i] = old_val - delta
      loss1 = fwd()
      mat.flat[i] = old_val

      grad_analytic = dmat.flat[i]
      grad_numerical = (loss0 - loss1) / (2 * delta)

      if grad_numerical == 0 and grad_analytic == 0:
        rel_error = 0 # both are zero, OK.
        status = 'OK'
      elif abs(grad_numerical) < 1e-7 and abs(grad_analytic) < 1e-7:
        rel_error = 0 # not enough precision to check this
        status = 'VAL SMALL WARNING'
      else:
        rel_error = abs(grad_analytic - grad_numerical) / abs(grad_numerical + grad_analytic)
        status = 'OK'
        if rel_error > rel_error_thr_warning: status = 'WARNING'
        if rel_error > rel_error_thr_error: status = '!!!!! NOTOK'

      # print stats
      print '%s checking param %s index %s (val = %+8f), analytic = %+8f, numerical = %+8f, relative error = %+8f' \
            % (status, name, `np.unravel_index(i, mat.shape)`, old_val, grad_analytic, grad_numerical, rel_error)
if __name__ == "__main__":

  checkSequentialMatchesBatch()
  raw_input('check OK, press key to continue to gradient check')

  checkBatchGradient()
  print 'every line should start with OK. Have a nice day!'
    @pranv

    This is indeed very efficient. I sat down hoping to rewrite this faster. Early on I changed a lot of things. But later, I reverted most of it.

Have you used numba? It gave me quite a bit of speedup.

    @upul

    Thanks a lot, karpathy!!!!

Could you please suggest a good reference to get familiar with the LSTM equations? Presently, I'm using

http://arxiv.org/abs/1503.04069 and https://apaszke.github.io/lstm-explained.html

    Thanks


    This is fantastic and clearly written. Thanks for this

    @arita37

Hello,
Has anyone tried converting it to Cython? (it should give a speedup of about 3x)

    @georgeblck

    Thanks for this!
    I rewrote the code in R, if anyone is interested.

    @amin07

What should the shape of 'dHout_in' be in the backward pass if I want to feed only the nth (final) time step's state into my classification/softmax layer?
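One way to read this off the code above (a minimal sketch, not an authoritative answer): backward expects 'dHout_in' with the same shape as 'Hout', i.e. (n, b, d). If only the final hidden state feeds the softmax, the gradient is simply zero at every other time step. Here 'dlast' is a hypothetical name for the upstream gradient from the classification layer:

    # dlast has shape (b, d): gradient of the loss w.r.t. the final hidden state
    dHout_in = np.zeros_like(cache['Hout'])  # (n, b, d), all zeros
    dHout_in[-1] = dlast                     # only the last time step receives gradient
    dX, dWLSTM, dc0, dh0 = LSTM.backward(dHout_in, cache)
    # equivalently, pass an all-zero dHout_in and supply dlast via the dhn argument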

    @Alessiobrini

Can someone post an example of using this batched LSTM for training on a dataset? I'm new to the domain and have learned a lot from reading this well-written code, but when it comes to using it in a real learning situation I run into problems.
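A minimal, illustrative sketch of one way to train with this code (not from the gist): it assumes a toy regression setup where a target array Y with the same shape as the hidden states is predicted directly from Hout with a squared-error loss; the names Y and learning_rate, the random stand-in data, and the loss choice are all made up for the example.

    # toy setup: n time steps, batch of b sequences, input_size-dim inputs, d-dim hidden/targets
    n, b, input_size, d = 20, 8, 10, 16
    WLSTM = LSTM.init(input_size, d)
    learning_rate = 0.1
    for step in range(100):
      X = np.random.randn(n, b, input_size)   # stand-in for a real minibatch
      Y = np.random.randn(n, b, d)            # stand-in targets
      Hout, _, _, cache = LSTM.forward(X, WLSTM)
      loss = 0.5 * np.sum((Hout - Y) ** 2)    # squared-error loss on the hidden states
      dHout = Hout - Y                        # gradient of the loss w.r.t. Hout
      dX, dWLSTM, dc0, dh0 = LSTM.backward(dHout, cache)
      WLSTM -= learning_rate * dWLSTM         # plain SGD update on the single weight matrix

In practice you would usually put an output projection (and e.g. a softmax) on top of Hout and backprop its gradient into dHout rather than regressing the hidden states directly; the pattern of forward -> loss gradient on Hout -> backward -> update WLSTM stays the same.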
