[CUDA] Fixes for backwards in memefficient attn for large tensors #154663

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to ourterms of service andprivacy statement. We’ll occasionally send you account related emails.

Already on GitHub?Sign in to your account

Closed
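As the diff below shows, this change lets the mem-efficient SDPA backend run both forward and backward for batch sizes above 65535, as long as dropout is disabled. A minimal sketch of the newly supported scenario, mirroring the shapes in the updated test; it assumes a CUDA device with enough memory and the torch.nn.attention APIs:

# Sketch of the case this PR targets, following the shapes in the updated test;
# not part of the PR itself.
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

batch_size = 2**16  # exceeds the kernel's 65535-batch limit, so chunking kicks in
query = torch.rand(batch_size, 2, 2, 8, device="cuda", dtype=torch.float16, requires_grad=True)
key = torch.rand(batch_size, 2, 2, 8, device="cuda", dtype=torch.float16, requires_grad=True)
value = torch.rand(batch_size, 2, 2, 8, device="cuda", dtype=torch.float16, requires_grad=True)

with sdpa_kernel(backends=SDPBackend.EFFICIENT_ATTENTION):
    out = F.scaled_dot_product_attention(query, key, value)

# Before this change the kernel refused to compute logsumexp for batches above
# 65535, so the forward already errored when gradients were needed; with it,
# the output, logsumexp and gradients are assembled chunk by chunk.
out.backward(torch.rand_like(out))
print(query.grad.shape)  # torch.Size([65536, 2, 2, 8])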
20 changes: 17 additions & 3 deletions aten/src/ATen/native/transformers/cuda/attention.cu
@@ -968,8 +968,8 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_efficient_attenti
   int64_t batch_size = query.size(0);

   if (batch_size > MAX_BATCH_SIZE) {
-    TORCH_CHECK(!compute_log_sumexp && (dropout_p == 0.0),
-        "Efficient attention cannot produce valid seed, logsumexp and offset outputs when "
+    TORCH_CHECK(dropout_p == 0.0,
+        "Efficient attention cannot produce valid seed and offset outputs when "
         "the batch size exceeds (", MAX_BATCH_SIZE, ").");
   }
   auto process_chunk = [&](const Tensor& q_chunk,
@@ -1030,6 +1030,17 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_efficient_attenti
     }
     Tensor final_attention = at::empty_strided(sizes, attn.strides(), attn.options());
     final_attention.slice(0, start, end).copy_(attn);
+    Tensor final_log_sumexp;
+    if (compute_log_sumexp && log_sumexp.numel() > 0) {
+      std::vector<int64_t> lse_sizes;
+      lse_sizes.reserve(log_sumexp.dim());
+      lse_sizes.push_back(batch_size);
+      for (int i = 1; i < log_sumexp.dim(); i++) {
+        lse_sizes.push_back(log_sumexp.size(i));
+      }
+      final_log_sumexp = at::empty(std::move(lse_sizes), log_sumexp.options());
+      final_log_sumexp.slice(0, start, end).copy_(log_sumexp);
+    }

     for (start = end; start < batch_size; start += MAX_BATCH_SIZE) {
       end = std::min(start + MAX_BATCH_SIZE, batch_size);
@@ -1045,10 +1056,13 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_efficient_attenti
       auto [chunk_attn, chunk_log_sumexp, chunk_seed, chunk_offset] =
           process_chunk(query_chunk, key_chunk, value_chunk, bias_chunk);
       final_attention.slice(0, start, end).copy_(chunk_attn);
+      if (compute_log_sumexp && chunk_log_sumexp.numel() > 0) {
+        final_log_sumexp.slice(0, start, end).copy_(chunk_log_sumexp);
+      }
     }

     return std::make_tuple(std::move(final_attention),
-                           std::move(log_sumexp),
+                           std::move(final_log_sumexp),
                            std::move(seed),
                            std::move(offset));
   }
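With this change the forward's chunked path stitches the per-chunk logsumexp tensors into one full-batch tensor (instead of returning only the first chunk's), so the value saved for backward covers every batch element. Below is a small Python model of that chunk-and-stitch pattern, for orientation only; `run_chunk` is a hypothetical stand-in for the per-chunk kernel call, and the real code copies into preallocated buffers rather than concatenating.

# Illustrative Python model of the batch chunking in attention.cu; `run_chunk`
# is a hypothetical stand-in for the per-chunk efficient-attention call.
import torch

MAX_BATCH_SIZE = 2**16 - 1  # 65535, the kernel's per-launch batch limit

def chunked_forward(q, k, v, run_chunk):
    batch_size = q.size(0)
    outs, lses = [], []
    for start in range(0, batch_size, MAX_BATCH_SIZE):
        end = min(start + MAX_BATCH_SIZE, batch_size)
        out_c, lse_c = run_chunk(q[start:end], k[start:end], v[start:end])
        outs.append(out_c)
        lses.append(lse_c)
    # attention.cu writes each chunk into a preallocated full-batch buffer;
    # concatenating along the batch dim is the equivalent here.
    return torch.cat(outs, dim=0), torch.cat(lses, dim=0)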
131 changes: 116 additions & 15 deletions aten/src/ATen/native/transformers/cuda/attention_backward.cu
@@ -24,6 +24,8 @@
 #include <ATen/Functions.h>
 #include <ATen/NativeFunctions.h>
 #else
+#include <ATen/ops/zeros_like.h>
+#include <ATen/ops/empty_strided.h>
 #include <ATen/ops/_flash_attention_backward.h>
 #include <ATen/ops/_flash_attention_backward_native.h>
 #include <ATen/ops/_efficient_attention_backward.h>
@@ -905,40 +907,56 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> _scaled_dot_product_e
   if (!grad_out_.defined()) {
     return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{});
   }
-  auto grad_out = grad_out_.transpose(1, 2);
+  constexpr int64_t MAX_BATCH_SIZE = (1LL << 16) - 1;
+  int64_t batch_size = query.size(0);
+
+  if (batch_size > MAX_BATCH_SIZE) {
+    TORCH_CHECK(dropout_p == 0.0,
+        "Efficient attention backward cannot handle dropout when "
+        "the batch size exceeds (", MAX_BATCH_SIZE, ").");
+  }
+  auto grad_out_t = grad_out_.transpose(1, 2);
+  auto query_t = query.transpose(1, 2);
+  auto key_t = key.transpose(1, 2);
+  auto value_t = value.transpose(1, 2);
   auto out_t = out.transpose(1, 2);
-  auto q_t = query.transpose(1, 2);
-  auto k_t = key.transpose(1, 2);
-  auto v_t = value.transpose(1, 2);

+  auto process_chunk = [&](const Tensor& grad_out_chunk,
+                           const Tensor& query_chunk,
+                           const Tensor& key_chunk,
+                           const Tensor& value_chunk,
+                           const std::optional<Tensor>& attn_bias_chunk,
+                           const Tensor& out_chunk,
+                           const Tensor& logsumexp_chunk)
+      -> std::tuple<Tensor, Tensor, Tensor, Tensor> {
   // This is needed because SaveVariable automatically converts
   // std::optional to undefined tensor
   std::optional<Tensor> kernel_bias;
-  if (attn_bias.defined()) {
-    kernel_bias = attn_bias;
+  if (attn_bias_chunk.has_value() && attn_bias_chunk.value().defined()) {
+    kernel_bias = attn_bias_chunk.value();
   }
   // Will add with signauter changes for dropout and bias
   // We are only handling Dense inputs, but this should be passed
   // from forward to backward
-  int64_t max_seqlen_q = q_t.size(1);
-  int64_t max_seqlen_k = k_t.size(1);
+  int64_t max_seqlen_q = query_chunk.size(2);
+  int64_t max_seqlen_k = key_chunk.size(2);

   sdp::CustomMaskType custom_mask_type = causal
       ? sdp::CustomMaskType::CausalFromTopLeft
       : sdp::CustomMaskType::NoCustomMask;
   auto [grad_q, grad_k, grad_v, grad_bias] =
       at::_efficient_attention_backward(
-          grad_out,
-          q_t,
-          k_t,
-          v_t,
+          grad_out_chunk,
+          query_chunk,
+          key_chunk,
+          value_chunk,
           kernel_bias,
-          out_t,
+          out_chunk,
           std::nullopt,
           std::nullopt,
           max_seqlen_q,
           max_seqlen_k,
-          logsumexp,
+          logsumexp_chunk,
           dropout_p,
           philox_seed,
           philox_offset,
@@ -947,7 +965,90 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> _scaled_dot_product_e
           scale,
           std::nullopt); // num_split_keys
   return std::make_tuple(
-      grad_q.transpose(1, 2), grad_k.transpose(1, 2), grad_v.transpose(1, 2), grad_bias);
+      grad_q.transpose(1, 2), grad_k.transpose(1, 2), grad_v.transpose(1, 2), std::move(grad_bias));
+  };
+
+  // process in chunks if batch size exceeds maximum
+  if (batch_size > MAX_BATCH_SIZE) {
+    Tensor final_grad_q, final_grad_k, final_grad_v, final_grad_bias;
+
+    auto create_strided_output = [batch_size](const Tensor& tensor) -> Tensor {
+      if (!tensor.defined()) {
+        return Tensor{};
+      }
+      int dim = tensor.dim();
+      std::vector<int64_t> sizes;
+      sizes.reserve(dim);
+      sizes.push_back(batch_size);
+      for (int i = 1; i < dim; i++) {
+        sizes.push_back(tensor.size(i));
+      }
+      return at::empty_strided(std::move(sizes), tensor.strides(), tensor.options());
+    };
+
+    if (grad_input_mask[0]) {
+      final_grad_q = create_strided_output(query);
+    }
+
+    if (grad_input_mask[1]) {
+      final_grad_k = create_strided_output(key);
+    }
+
+    if (grad_input_mask[2]) {
+      final_grad_v = create_strided_output(value);
+    }
+    if (grad_input_mask[3] && attn_bias.defined()) {
+      final_grad_bias = at::zeros_like(attn_bias);
+    }
+
+    for (int64_t start = 0; start < batch_size; start += MAX_BATCH_SIZE) {
+      int64_t end = std::min(start + MAX_BATCH_SIZE, batch_size);
+
+      Tensor grad_out_chunk = grad_out_t.slice(0, start, end);
+      Tensor query_chunk = query_t.slice(0, start, end);
+      Tensor key_chunk = key_t.slice(0, start, end);
+      Tensor value_chunk = value_t.slice(0, start, end);
+      Tensor attn_bias_chunk;
+      if (attn_bias.defined()) {
+        attn_bias_chunk = attn_bias.slice(0, start, end);
+      } else {
+        attn_bias_chunk.reset();
+      }
+      Tensor out_chunk = out_t.slice(0, start, end);
+      Tensor logsumexp_chunk = logsumexp.numel() > 0 ? logsumexp.slice(0, start, end) : logsumexp;
+
+      auto [chunk_grad_q, chunk_grad_k, chunk_grad_v, chunk_grad_bias] =
+          process_chunk(grad_out_chunk, query_chunk, key_chunk, value_chunk,
+                        attn_bias_chunk, out_chunk, logsumexp_chunk);
+
+      if (grad_input_mask[0] && chunk_grad_q.defined()) {
+        final_grad_q.slice(0, start, end).copy_(chunk_grad_q);
+      }
+      if (grad_input_mask[1] && chunk_grad_k.defined()) {
+        final_grad_k.slice(0, start, end).copy_(chunk_grad_k);
+      }
+      if (grad_input_mask[2] && chunk_grad_v.defined()) {
+        final_grad_v.slice(0, start, end).copy_(chunk_grad_v);
+      }
+      if (grad_input_mask[3] && chunk_grad_bias.defined()) {
+        final_grad_bias.add_(chunk_grad_bias);
+      }
+    }
+
+    return std::make_tuple(
+        std::move(final_grad_q),
+        std::move(final_grad_k),
+        std::move(final_grad_v),
+        std::move(final_grad_bias));
+  }
+  // when batch size is within allowed size, no chunking needed
+  else {
+    std::optional<Tensor> attn_bias_opt;
+    if (attn_bias.defined()) {
+      attn_bias_opt = attn_bias;
+    }
+    return process_chunk(grad_out_t, query_t, key_t, value_t, attn_bias_opt, out_t, logsumexp);
+  }
 }

 } // namespace at::native
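The backward now mirrors the forward's chunking: when the batch exceeds 65535, grad_out, query, key, value, out, and logsumexp are sliced along the batch dimension, at::_efficient_attention_backward runs per chunk, the per-chunk dq/dk/dv are copied into full-batch buffers, and a requested bias gradient is accumulated across chunks with add_(). A minimal Python sketch of that assembly, illustrative only; `backward_chunk` is a hypothetical stand-in for the per-chunk call and the bias path is omitted:

# Illustrative Python model of the chunked backward assembly added in
# attention_backward.cu; `backward_chunk` stands in for the per-chunk
# at::_efficient_attention_backward call.
import torch

MAX_BATCH_SIZE = 2**16 - 1  # 65535

def chunked_backward(grad_out, q, k, v, out, lse, backward_chunk):
    batch_size = q.size(0)
    # full-batch gradient buffers, filled one batch slice at a time
    grad_q, grad_k, grad_v = (torch.empty_like(t) for t in (q, k, v))
    for start in range(0, batch_size, MAX_BATCH_SIZE):
        end = min(start + MAX_BATCH_SIZE, batch_size)
        dq, dk, dv = backward_chunk(
            grad_out[start:end], q[start:end], k[start:end], v[start:end],
            out[start:end], lse[start:end])
        grad_q[start:end].copy_(dq)
        grad_k[start:end].copy_(dk)
        grad_v[start:end].copy_(dv)
    return grad_q, grad_k, grad_v

In the kernel code the buffers additionally preserve the input strides via at::empty_strided and are only materialized for entries requested in grad_input_mask.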
29 changes: 21 additions & 8 deletions test/test_transformers.py
@@ -1900,23 +1900,36 @@ def test_flash_attention_fail_with_non_square_causal_attention(self, device):

     @onlyCUDA
     def test_mem_eff_attention_fail_with_batch_size_geq_65536(self):
-        query = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
-        key = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
-        value = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
+        batch_size = 2**16
+        query = torch.rand([batch_size, 2, 2, 8], device='cuda', dtype=torch.float16, requires_grad=True)
+        key = torch.rand([batch_size, 2, 2, 8], device='cuda', dtype=torch.float16, requires_grad=True)
+        value = torch.rand([batch_size, 2, 2, 8], device='cuda', dtype=torch.float16, requires_grad=True)
+        q_cpu, k_cpu, v_cpu = (query.detach().cpu().requires_grad_(True),
+                               key.detach().cpu().requires_grad_(True),
+                               value.detach().cpu().requires_grad_(True))
         with sdpa_kernel(backends=SDPBackend.EFFICIENT_ATTENTION):
             out = F.scaled_dot_product_attention(query, key, value)
-        out_cpu = F.scaled_dot_product_attention(query.cpu(), key.cpu(), value.cpu())
-        self.assertEqual(out, out_cpu, atol=1e-3, rtol=1e-4)
+        out_cpu = F.scaled_dot_product_attention(q_cpu, k_cpu, v_cpu)
+        grad_out = torch.rand_like(out)
+        out.backward(grad_out)
+        out_cpu.backward(grad_out.cpu())
+
+        self.assertEqual(out, out_cpu, atol=2e-3, rtol=1e-4)
+        self.assertEqual(query.grad, q_cpu.grad, atol=2e-3, rtol=1e-4)
+        self.assertEqual(key.grad, k_cpu.grad, atol=2e-3, rtol=1e-4)
+        self.assertEqual(value.grad, v_cpu.grad, atol=2e-3, rtol=1e-4)

     @onlyCUDA
     def test_mem_eff_attention_fail_with_batch_size_geq_65536_error(self):
         query = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
         key = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
         value = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
-        error_str = (r"Efficient attention cannot produce valid seed, "
-                     r"logsumexp and offset outputs when the batch size exceeds \(65535\)\.")
+        error_str = (r"Efficient attention cannot produce valid seed and offset outputs when "
+                     r"the batch size exceeds \(65535\)\.")
         with self.assertRaisesRegex(RuntimeError, error_str):
-            torch._scaled_dot_product_efficient_attention(query, key, value, attn_bias=None, compute_log_sumexp=True)
+            torch._scaled_dot_product_efficient_attention(query, key, value,
+                                                          attn_bias=None, compute_log_sumexp=True,
+                                                          dropout_p=0.01)

 def _get_block_size_n(device, head_dim, is_dropout, is_causal):
     # This should match the block sizes in the CUDA kernel