Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings
/jaxPublic

Commitfdedc80

Browse files
[Mosaic GPU] Improve error handling in Mosaic GPU custom call.
Replace `fprintf(stderr)` and `abort()` calls with `absl::Status` returns for various error conditions.It seems it was copied from the legacy custom call implementation which did not support error handling.PiperOrigin-RevId: 837475367
1 parent9138c20 commitfdedc80

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

‎jaxlib/mosaic/gpu/custom_call.cc‎

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -639,20 +639,20 @@ absl::Status MosaicGpuExecute(gpuStream_t stream, ffi::RemainingArgs inputs,
639639
// Updated version using the new FFI API supporting custom barrier
640640
// for distributed kernels
641641
if (use_custom_barrier) {
642-
fprintf(stderr,"Custom barrier is not supported on GPUs.\n");
643-
abort();
642+
returnabsl::UnimplementedError("Custom barrier is not supported on GPUs.");
644643
}
645644
if (reinterpret_cast<constuintptr_t>(kernel_hash.data()) %
646645
alignof(KernelHash) ||
647646
kernel_hash.size() !=sizeof(KernelHash)) {
648-
fprintf(stderr,"Misaligned opaque pointer\n");
649-
abort();
647+
returnabsl::InvalidArgumentError("Misaligned opaque pointer");
650648
}
651649
auto hash = *reinterpret_cast<const KernelHash*>(kernel_hash.data());
652650
CUcontext ctx;
653-
if (cuCtxGetCurrent(&ctx) != CUDA_SUCCESS) {
654-
fprintf(stderr,"Failed to get current CUDA context\n");
655-
abort();
651+
if (auto result =cuCtxGetCurrent(&ctx); result != CUDA_SUCCESS) {
652+
constchar* error;
653+
cuGetErrorString(result, &error);
654+
returnabsl::InternalError(
655+
absl::StrFormat("Failed to get current CUDA context: %s", error));
656656
}
657657
CacheKeykey(hash,reinterpret_cast<uintptr_t>(ctx));
658658
TF_ASSIGN_OR_RETURN(auto compiled_kernel,

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp