Using Torch Autograd ctx to Optimize Memory Leaking Issue #202
Open
Mr-Philo wants to merge 1 commit into Azure:main
Conversation
Contributor
Author
@microsoft-github-policy-service agree company="Microsoft"
See Issue #201
This PR proposes a potential solution to the memory leaking issue when using the MS-AMP custom GeMM.
Currently, the custom GeMM function uses the `ctx` object to save the input tensor x and the weight tensor W, which are needed during backward gradient computation. Writing `ctx.input_fp8` saves the tensor directly as an attribute. However, `input_fp8` is a `ScalingTensor`, and in practice this saving method does not fully leverage the advantage of FP8 tensors!

Instead, I suggest using `ctx.save_for_backward()`. This method is specially designed for better memory management. Change the saved context from a `ScalingTensor` to a `torch.Tensor` plus a `ScalingMeta`. This is proved to be efficient in memory saving! A minimal sketch of the idea follows.
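The sketch below is illustrative, not the actual MS-AMP implementation: the `ScalingMeta` stand-in, the `Gemm` class name, and the dequantize-then-matmul "GEMM" body are assumptions made so the example runs standalone. What it shows is the pattern this PR describes: hand the raw `torch.Tensor` payloads to `ctx.save_for_backward()` and keep only the lightweight metadata as `ctx` attributes.

```python
import torch
from dataclasses import dataclass


@dataclass
class ScalingMeta:
    """Stand-in for MS-AMP's ScalingMeta (which carries FP8 scaling state)."""
    scale: torch.Tensor


class Gemm(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, x_meta, w, w_meta):
        # Leak-prone pattern this PR replaces:
        #   ctx.input_fp8 = input_fp8    # whole ScalingTensor pinned on ctx,
        #   ctx.weight_fp8 = weight_fp8  # outside autograd's bookkeeping
        # Proposed pattern: save the raw tensors through save_for_backward()
        # so autograd can track and free them, and keep only the small
        # metadata objects as ctx attributes.
        ctx.save_for_backward(x, w)
        ctx.x_meta, ctx.w_meta = x_meta, w_meta
        # Dummy "GEMM": dequantize with the scales, then matmul.
        return (x * x_meta.scale) @ (w * w_meta.scale).t()

    @staticmethod
    def backward(ctx, grad_out):
        x, w = ctx.saved_tensors  # retrieved via autograd's managed path
        grad_x = (grad_out @ (w * ctx.w_meta.scale)) * ctx.x_meta.scale
        grad_w = (grad_out.t() @ (x * ctx.x_meta.scale)) * ctx.w_meta.scale
        return grad_x, None, grad_w, None


# Smoke test: forward + backward run end to end.
x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(16, 8, requires_grad=True)
meta = ScalingMeta(scale=torch.tensor(1.0))
Gemm.apply(x, meta, w, meta).sum().backward()
```

In the real code the payload and metadata would come from the `ScalingTensor` being saved; only the save/restore pattern above is the point of the change.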
Effect for deit-base (86M) model training, batch size 256:

Effect for deit 570M model training, batch size 256: