@@ -153,6 +153,86 @@ def generate_proposals_ref():
153153torch .testing .assert_allclose (rois ,a )
154154torch .testing .assert_allclose (rois_probs ,b )
155155
156+ @given (
157+ bsz = st .integers (1 ,5 ),
158+ seq_lens = st .integers (1 ,6 ),
159+ emb_lens = st .integers (5 ,10 ),
160+ hidden_size = st .integers (3 ,7 ),
161+ num_layers = st .integers (1 ,4 ),
162+ has_biases = st .booleans (),
163+ is_bidirectional = st .booleans (),
164+ batch_first = st .booleans (),
165+ )
166+ def test_inference_lstm (
167+ self ,
168+ bsz ,
169+ seq_lens ,
170+ emb_lens ,
171+ hidden_size ,
172+ num_layers ,
173+ has_biases ,
174+ is_bidirectional ,
175+ batch_first ,
176+ ):
177+ num_directions = 2 if is_bidirectional else 1
178+ hx = np .zeros ((num_layers * num_directions ,bsz ,hidden_size ),dtype = np .float32 )
179+
180+ if batch_first :
181+ inputs = np .random .randn (bsz ,seq_lens ,emb_lens ).astype (np .float32 )
182+ else :
183+ inputs = np .random .randn (seq_lens ,bsz ,emb_lens ).astype (np .float32 )
184+
185+ torch_lstm = torch .nn .LSTM (
186+ emb_lens ,
187+ hidden_size ,
188+ batch_first = batch_first ,
189+ bidirectional = is_bidirectional ,
190+ bias = has_biases ,
191+ num_layers = num_layers ,
192+ )
193+
194+ def inference_lstm_ref ():
195+ input_names = ["inputs" ,"hidden_0" ,"hidden_1" ]
196+ workspace .FeedBlob ("inputs" ,inputs )
197+ workspace .FeedBlob ("hidden_0" ,hx )
198+ workspace .FeedBlob ("hidden_1" ,hx )
199+ for i ,param in enumerate (torch_lstm ._flat_weights ):
200+ input_names .append ("param_{}" .format (i ))
201+ workspace .FeedBlob ("param_{}" .format (i ),param .detach ().numpy ())
202+
203+ ref_op = core .CreateOperator (
204+ "InferenceLSTM" ,
205+ input_names ,
206+ ["output" ,"hidden" ,"cell" ],
207+ num_layers = num_layers ,
208+ has_biases = has_biases ,
209+ batch_first = batch_first ,
210+ bidirectional = is_bidirectional ,
211+ )
212+ workspace .RunOperatorOnce (ref_op )
213+ return (
214+ workspace .FetchBlob ("output" ),
215+ workspace .FetchBlob ("hidden" ),
216+ workspace .FetchBlob ("cell" )
217+ )
218+
219+ output ,hidden ,cell = inference_lstm_ref ()
220+ output = torch .tensor (output )
221+ hidden = torch .tensor (hidden )
222+ cell = torch .tensor (cell )
223+ lstm_in = [
224+ torch .from_numpy (inputs ),
225+ torch .from_numpy (hx ),
226+ torch .from_numpy (hx ),
227+ ]+ [param .detach ()for param in torch_lstm ._flat_weights ]
228+
229+ a ,b ,c = torch .ops ._caffe2 .InferenceLSTM (
230+ lstm_in ,num_layers ,has_biases ,batch_first ,is_bidirectional
231+ )
232+ torch .testing .assert_allclose (output ,a )
233+ torch .testing .assert_allclose (hidden ,b )
234+ torch .testing .assert_allclose (cell ,c )
235+
156236# Test case is using workspace.has_cuda_support and not workspace.has_gpu_support
157237# to exclude it from HIP because tensor interop doesn't work for HIP tensors yet
158238@unittest .skipIf (not workspace .has_cuda_support ,"No cuda support" )