Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings
jax / jax — Public

Commit 9138c20

Browse files
[Mosaic GPU] Remove legacy mosaic GPU FFI.
PiperOrigin-RevId: 837463751
1 parent b4dfa55 · commit 9138c20

File tree

4 files changed

+11
-50
lines changed

4 files changed

+11
-50
lines changed

‎jax/_src/internal_test_util/export_back_compat_test_data/pallas/mosaic_gpu_add_one.py‎

Lines changed: 9 additions & 9 deletions
Large diffs are not rendered by default.

‎jaxlib/mosaic/gpu/BUILD‎

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -329,8 +329,6 @@ cc_library(
329329
"@xla//xla:executable_run_options",
330330
"@xla//xla/ffi",
331331
"@xla//xla/ffi:ffi_api",
332-
"@xla//xla/service:custom_call_status",
333-
"@xla//xla/service:custom_call_target_registry",
334332
"@xla//xla/service/gpu/llvm_gpu_backend:nvptx_libdevice_path",
335333
"@xla//xla/service/llvm_ir:llvm_command_line_options",
336334
"@xla//xla/stream_executor/cuda:assemble_compilation_provider",

‎jaxlib/mosaic/gpu/custom_call.cc‎

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ limitations under the License.
2323
#include<cstdint>
2424
#include<cstdio>
2525
#include<cstdlib>
26-
#include<cstring>
2726
#include<functional>
2827
#include<memory>
2928
#include<optional>
@@ -113,8 +112,6 @@ limitations under the License.
113112
#include"xla/executable_run_options.h"
114113
#include"xla/ffi/ffi.h"
115114
#include"xla/ffi/ffi_api.h"
116-
#include"xla/service/custom_call_status.h"
117-
#include"xla/service/custom_call_target_registry.h"
118115
#include"xla/service/gpu/llvm_gpu_backend/nvptx_libdevice_path.h"
119116
#include"xla/service/llvm_ir/llvm_command_line_options.h"
120117
#include"xla/stream_executor/cuda/assemble_compilation_provider.h"
@@ -634,40 +631,6 @@ absl::StatusOr<CompiledKernel*> CachedCompileAndInit(CacheKey key,
634631
return &cache->at(key);
635632
}
636633

637-
voidMosaicGPUCustomCall(void* stream,void** buffers,char* opaque,
638-
size_t opaque_len, XlaCustomCallStatus* status) {
639-
// Forward-compatible version using the legacy FFI API
640-
if (reinterpret_cast<uintptr_t>(opaque) %alignof(KernelHash)) {
641-
fprintf(stderr,"Misaligned opaque pointer\n");
642-
abort();
643-
}
644-
auto hash = *reinterpret_cast<KernelHash*>(opaque);
645-
CUcontext ctx;
646-
if (cuCtxGetCurrent(&ctx) != CUDA_SUCCESS) {
647-
fprintf(stderr,"Failed to get current CUDA context\n");
648-
abort();
649-
}
650-
CacheKeykey(hash,reinterpret_cast<uintptr_t>(ctx));
651-
auto compiled_kernel =CachedCompileAndInit(key, opaque +sizeof(KernelHash));
652-
if (!compiled_kernel.ok()) {
653-
XlaCustomCallStatusSetFailure(status,
654-
compiled_kernel.status().message().data(),
655-
compiled_kernel.status().message().size());
656-
return;
657-
}
658-
auto ctx_kernel_comm = (*compiled_kernel)->GetHostLaunch();
659-
bool is_comm_used = std::get<2>(ctx_kernel_comm);
660-
void* args[4] = {&std::get<0>(ctx_kernel_comm), &stream, &buffers};
661-
if (is_comm_used) {
662-
mosaic::gpu::NvshmemApi::Default().barrier_all_on_stream(
663-
reinterpret_cast<cudaStream_t>(stream));
664-
}
665-
std::get<1>(ctx_kernel_comm)(args);
666-
}
667-
668-
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("mosaic_gpu", &MosaicGPUCustomCall,
669-
"CUDA");
670-
671634
absl::StatusMosaicGpuExecute(gpuStream_t stream, ffi::RemainingArgs inputs,
672635
ffi::RemainingRets results,
673636
std::string_view kernel_hash,

‎tests/pallas/export_back_compat_pallas_test.py‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,8 @@ def test_mosaic_gpu_add_one(self):
8181
defadd_one(x_ref,o_ref):
8282
o_ref[...]=x_ref[...]+1
8383

84-
data=self.load_testdata(mosaic_gpu_add_one.data_2025_04_22)
85-
self.run_one_test(add_one,data,expect_current_custom_calls=["mosaic_gpu_v2"])
84+
data=self.load_testdata(mosaic_gpu_add_one.data_2025_11_27)
85+
self.run_one_test(add_one,data)
8686

8787
deftest_mosaic_gpu_kernel_add_one(self):
8888
ifnotjtu.is_cuda_compute_capability_at_least("9.0"):

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp