[RLlib] Optimize rnn_sequencing performance #46502
Conversation
This pull request has been automatically marked as stale because it has not had recent activity. It will be closed in 14 days if no further activity occurs. Thank you for your contributions.
This pull request has been automatically closed because there has been no activity in the last 14 days. Please feel free to reopen or open a new pull request if you'd still like this to be addressed. Again, you can always ask for help on our discussion forum or Ray's public Slack channel. Thanks again for your contribution!
Why are these changes needed?
We found LSTMs in RLlib to be extremely slow to train compared to other methods, with a single PPO training iteration taking 179 seconds (compared to ~9 seconds with a similarly sized MLP network). This made RNNs/LSTMs, as well as some transformer implementations, completely unusable for our purposes.
However, when profiling, we found that this was primarily due to a very slow copy operation. Further investigation revealed that most of this runtime was spent copying the `infos` dict. We determined that the root cause was inconsistent handling of this field in `rnn_sequencing`: while the non-recurrent implementation stores the list of info dictionaries as a NumPy array of objects, `rnn_sequencing` stores it as a plain Python list. We applied a one-line fix to make the behavior consistent and store the list as a NumPy array.
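The actual diff isn't reproduced in this description; as a minimal sketch of the idea (variable names here are illustrative stand-ins, not the real `rnn_sequencing` internals):

```python
import numpy as np

# Hypothetical stand-in for the per-timestep "infos" column that
# rnn_sequencing builds up (illustrative names, not RLlib internals).
infos_column = [{"step": i, "score": float(i)} for i in range(4)]

# Before: the column was kept as a plain Python list of dicts.
feature = infos_column

# After (the one-line fix): store it as a NumPy object array instead,
# matching how the non-recurrent code path already stores it.
feature = np.array(infos_column)
print(feature.dtype)  # object
```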
With the list stored as an object array, the `copy` function performs a shallow copy instead of deep-copying every info dict, drastically improving performance by ~6x, to around 29 seconds per iteration.
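To see why this matters, here is a standalone micro-benchmark (synthetic data, not RLlib code) contrasting a deep copy of a Python list of dicts with the shallow `.copy()` of an object-dtype array, under the assumption that the slow path deep-copied each element:

```python
import copy
import time

import numpy as np

# Synthetic stand-in for an episode's "infos": one dict per timestep.
infos = [{"step": i, "reward": float(i)} for i in range(200_000)]

# Deep-copying a Python list recursively copies every nested dict.
t0 = time.perf_counter()
copy.deepcopy(infos)
print(f"deepcopy(list): {time.perf_counter() - t0:.3f}s")

# An object-dtype ndarray's .copy() duplicates only the array of
# references; the dicts themselves are shared (a shallow copy).
infos_arr = np.array(infos)
t0 = time.perf_counter()
infos_arr.copy()
print(f"ndarray.copy(): {time.perf_counter() - t0:.3f}s")
```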
However, we found that the training loop was still spending a lot of time in `rnn_sequencing`. We traced this down to a slow element-wise copy into an array, which we replaced with a vectorized copy.
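Again as a sketch rather than the exact diff (function and variable names are hypothetical), the change is the classic replacement of a per-element loop with a slice assignment:

```python
import numpy as np

def pad_sequences(flat: np.ndarray, seq_lens: list) -> np.ndarray:
    """Copy variable-length sequences from a flat buffer into a padded
    (num_seqs, max_seq_len) array. Hypothetical sketch, not RLlib code."""
    max_len = max(seq_lens)
    out = np.zeros((len(seq_lens), max_len) + flat.shape[1:], dtype=flat.dtype)
    start = 0
    for row, length in enumerate(seq_lens):
        # Before: one assignment per element.
        # for offset in range(length):
        #     out[row, offset] = flat[start + offset]
        # After: a single vectorized slice copy per sequence.
        out[row, :length] = flat[start:start + length]
        start += length
    return out

# Example: three sequences of lengths 2, 3, and 1 packed into a flat buffer.
flat = np.arange(6, dtype=np.float32)
print(pad_sequences(flat, [2, 3, 1]))
```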
This further improved performance. (In this profiling sample we had also removed the one-time summary logging; that change is not included in this PR.)
Combined, these changes reduced the runtime of the training step from 179 seconds to 15 seconds, approximately a 12x speedup, making LSTM training competitive with training an MLP.
Related issue number
Checks
- I've signed off every commit (by using the -s flag, i.e., `git commit -s`) in this PR.
- I've run `scripts/format.sh` to lint the changes in this PR.
- If I've added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file.