Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit612af49

Browse files
committed
Add trace with JIT example
1 parent610a481 commit612af49

File tree

5 files changed

+1775
-0
lines changed

5 files changed

+1775
-0
lines changed
Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
# vector_scalar_mul/vector_scalar_mul_jit.py -*- Python -*-
2+
#
3+
# This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
# See https://llvm.org/LICENSE.txt for license information.
5+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
#
7+
# (c) Copyright 2024-2025 Advanced Micro Devices, Inc. or its affiliates
8+
9+
importargparse
10+
importsys
11+
importnumpyasnp
12+
importaie.ironasiron
13+
importos
14+
15+
fromaie.ironimportExternalFunction,ObjectFifo,Program,Runtime,Worker
16+
fromaie.iron.placersimportSequentialPlacer
17+
fromaie.iron.deviceimportNPU1Col1,NPU2Col1
18+
fromaie.iron.controlflowimportrange_
19+
fromaie.iron.dtypeimportstr_to_dtype
20+
importargparse
21+
importsys
22+
importnumpyasnp
23+
importaie.ironasiron
24+
25+
fromaie.ironimportObjectFifo,Program,Runtime,Worker
26+
fromaie.iron.placersimportSequentialPlacer
27+
fromaie.iron.deviceimportNPU1Col1,NPU2Col1
28+
fromaie.iron.controlflowimportrange_
29+
fromaie.ironimporttrace
30+
31+
32+
@iron.jit(is_placed=False)
def vector_scalar_mul(input0, input1, output):
    """Describe an IRON design that scales a 1-D vector by a scalar factor.

    The vector is processed in four equally sized tiles.  Each tile is passed
    to the externally defined ``vector_scalar_mul_vector`` kernel (built from
    ``scale.cc``) together with the scalar factor, and the scaled tile is
    written to the output.  When ``trace.get_trace_size()`` reports a positive
    size, tracing is enabled on the worker and on the runtime sequence.

    Args:
        input0: 1-D input tensor.
        input1: Buffer holding exactly one element, the scale factor.  It is
            described to the device as a ``(1,)`` array of ``np.int32``.
        output: Output tensor; must have the same shape and dtype as
            ``input0``.

    Returns:
        The result of ``Program(...).resolve_program(SequentialPlacer())``
        (the placed program / generated MLIR module).

    Raises:
        ValueError: If the shapes or dtypes of ``input0`` and ``output``
            differ, if ``input0`` is not one-dimensional, if ``input1`` does
            not hold exactly one element, or if the number of elements is not
            a multiple of the number of tiles.
    """
    # --- Argument validation -------------------------------------------------
    if input0.shape != output.shape:
        raise ValueError(
            f"Input and output shapes are not the same ({input0.shape} != {output.shape})."
        )
    if len(np.shape(input0)) != 1:
        raise ValueError("Function only supports vectors.")

    # The factor buffer carries a single value.
    if np.size(input1) != 1:
        raise ValueError("2nd input buffer must be size 1 (1 integer).")

    # NOTE: no separate element-count comparison between input0 and output is
    # needed here; equal shapes (checked above) already imply equal counts.

    num_elements = np.size(input0)
    num_sub_vectors = 4
    if num_elements % num_sub_vectors != 0:
        raise ValueError(
            f"Number of elements ({num_elements}) must be a multiple of {num_sub_vectors}."
        )
    # Safe to divide only after the divisibility check above.
    tile_size = num_elements // num_sub_vectors

    if input0.dtype != output.dtype:
        raise ValueError(
            f"Input and output data types are not the same ({input0.dtype} != {output.dtype})."
        )
    dtype = input0.dtype

    # --- Tensor types --------------------------------------------------------
    # The factor is always a single np.int32, independent of the vector dtype.
    tensor_ty = np.ndarray[(num_elements,), np.dtype[dtype]]
    tile_ty = np.ndarray[(tile_size,), np.dtype[dtype]]
    scalar_ty = np.ndarray[(1,), np.dtype[np.int32]]

    # --- External kernel -----------------------------------------------------
    # The kernel source lives in the shared aie_kernels directory, located
    # relative to this file.
    current_dir = os.path.dirname(__file__)
    kernel_dir = os.path.join(current_dir, "../../../aie_kernels/aie2")
    kernel_path = os.path.join(current_dir, "../../../aie_kernels/aie2", "scale.cc")
    # Element width in bits, handed to the kernel build as -DBIT_WIDTH.
    bit_width = np.dtype(input0.dtype).itemsize * 8

    scale = ExternalFunction(
        "vector_scalar_mul_vector",
        source_file=kernel_path,
        arg_types=[
            tile_ty,  # input tensor
            tile_ty,  # output tensor
            scalar_ty,  # scalar factor
            np.int32,  # N
        ],
        compile_flags=[f"-DBIT_WIDTH={bit_width}"],
        include_dirs=[kernel_dir],
    )

    # --- Data movement with object fifos -------------------------------------
    of_in = ObjectFifo(tile_ty, name="in")
    of_factor = ObjectFifo(scalar_ty, name="infactor")
    of_out = ObjectFifo(tile_ty, name="out")

    # Task that runs on a compute tile.
    def core_body(of_in, of_factor, of_out, scale_fn):
        # The factor is acquired once and reused for every tile.
        elem_factor = of_factor.acquire(1)

        # One iteration per sub-vector tile.
        for _ in range_(num_sub_vectors):
            elem_in = of_in.acquire(1)
            elem_out = of_out.acquire(1)
            scale_fn(elem_in, elem_out, elem_factor, tile_size)
            of_in.release(1)
            of_out.release(1)

        # Release the factor only after all tiles have been processed.
        of_factor.release(1)

    # Read the configured trace size once; it drives both the worker flag and
    # the runtime trace setup below.
    trace_size = trace.get_trace_size()
    tracing_enabled = trace_size > 0

    # Worker that runs the task on a compute tile.
    worker = Worker(
        core_body,
        fn_args=[of_in.cons(), of_factor.cons(), of_out.prod(), scale],
        trace=1 if tracing_enabled else 0,
    )

    # --- Runtime sequence: move data to/from the AIE array -------------------
    rt = Runtime()
    with rt.sequence(tensor_ty, scalar_ty, tensor_ty) as (A, F, C):
        if tracing_enabled:
            rt.enable_trace(trace_size)
        rt.start(worker)
        rt.fill(of_in.prod(), A)
        rt.fill(of_factor.prod(), F)
        rt.drain(of_out.cons(), C, wait=True)

    # Place program components (assign them resources on the device) and
    # generate an MLIR module.
    return Program(iron.get_current_device(), rt).resolve_program(SequentialPlacer())
134+
135+
136+
def main():
    """Run the vector-scalar multiplication example from the command line.

    Parses the command-line options, creates random input data and a zeroed
    output on the NPU, optionally starts tracing, invokes
    ``vector_scalar_mul``, optionally stops tracing and saves the trace file,
    and finally compares the device result with a NumPy reference.

    Exits the process with status 0 when all elements match and with status
    -1 otherwise.

    Raises:
        ValueError: If ``--num-elements`` is not a multiple of 128 or is
            smaller than 1024.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-v", "--verbose", action="store_true", help="Enable verbose output"
    )
    parser.add_argument(
        "-n",
        "--num-elements",
        type=int,
        default=1024,
        help="Number of elements (default: 1024, must be multiple of 128 and >= 1024)",
    )
    parser.add_argument(
        "-t",
        "--trace-size",
        type=int,
        default=1024,
        # %(default)s keeps the help text in sync with the actual default
        # (the previous text claimed a default of 0).
        help="Trace buffer size (0 = no tracing, default: %(default)s)",
    )
    parser.add_argument(
        "-z",
        "--data_type",
        choices=["i16", "i32"],
        default="i16",
        help="Data type (default: i16)",
    )
    args = parser.parse_args()

    # Buffer size validation; the message travels with the exception instead
    # of being printed separately before a bare ValueError.
    if args.num_elements % 128 != 0 or args.num_elements < 1024:
        raise ValueError(
            "Number of elements must be a multiple of 128 (so len is multiple of 64) "
            "and greater than or equal to 1024 (so len >= 512)"
        )

    # Construct random input tensors and a zeroed output tensor.
    # The tensors are in memory accessible to the NPU.
    datatype = str_to_dtype(args.data_type)
    input0 = iron.randint(0, 100, (args.num_elements,), dtype=datatype, device="npu")
    scalar = iron.randint(0, 100, (1,), dtype=np.int32, device="npu")
    output = iron.zeros_like(input0)

    tracing_requested = args.trace_size > 0

    # Enable tracing if requested.
    if tracing_requested:
        trace.set_trace_size(args.trace_size)
        trace.start_trace()

    # JIT-compile the kernel, then launch it with the given arguments.
    vector_scalar_mul(input0, scalar, output)

    # Stop tracing and save results if tracing was enabled.
    if tracing_requested:
        trace_filename = f"trace_output_{args.num_elements}_{args.data_type}.json"
        trace.stop_trace(trace_filename)
        print(f"Tracing completed and saved to {trace_filename}")

    # Check the result against a NumPy reference (vector times scalar).
    factor = scalar.numpy()[0]
    expected = input0.numpy() * factor
    actual = output.numpy()
    errors = int(np.count_nonzero(expected != actual))

    # Optionally, print the results.
    if args.verbose:
        print(f"{'input0':>4} * {'factor':>4} = {'output':>4}")
        print("-" * 34)
        count = input0.numel()
        for idx, (a, c) in enumerate(zip(input0[:count], output[:count])):
            print(f"{idx:2}: {a:4} * {factor:4} = {c:4}")

    # Exit with a success code when the result is correct, a failure code
    # otherwise.
    if not errors:
        print("\nPASS!\n")
        sys.exit(0)
    else:
        print("\nError count: ", errors)
        print("\nFailed.\n")
        sys.exit(-1)
216+
217+
218+
# Script entry point: run the example only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()

‎python/iron/__init__.py‎

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,5 +24,7 @@
2424
arange,
2525
zeros_like,
2626
)
27+
28+
from .importtrace
2729
exceptImportError:
2830
pass# silently ignore if pyxrt or .jit can't be imported

‎python/iron/jit.py‎

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,13 @@
2121
from .compileimportcompile_mlir_module
2222
from .configimportget_current_device
2323
fromaie.dialects.aieimportAIEDevice
24+
from .tensorimportzeros
25+
from .traceimport (
26+
_get_trace_active,
27+
_get_trace_tensor,
28+
_get_dummy_tensor,
29+
set_mlir_module,
30+
)
2431

2532

2633
# The `iron.jit` decorator below caches compiled kernels inside the `IRON_CACHE_DIR` directory.
@@ -108,6 +115,26 @@ def __call__(self, *args):
108115
)
109116
kernel_args.append(tensor.buffer_object())
110117

118+
if_get_trace_active():
119+
# We always put the trace tensor at the 5th argument to match backend tracing logic
120+
# So we only enable tracing if we have at most 4 user arguments
121+
trace_tensor=_get_trace_tensor()
122+
iftrace_tensorisNone:
123+
raiseRuntimeError("Tracing active but no trace tensor found")
124+
125+
iflen(kernel_args)>=5:
126+
raiseValueError(
127+
f"Tracing can only be done for kernels with 4 or fewer arguments. Got{len(kernel_args)} arguments."
128+
)
129+
130+
# Pad with dummy tensors if needed and add them to kernel_args
131+
whilelen(kernel_args)<4:
132+
dummy_tensor=_get_dummy_tensor()
133+
kernel_args.append(dummy_tensor.buffer_object())
134+
135+
# Add trace tensor as the 5th argument
136+
kernel_args.append(trace_tensor.buffer_object())
137+
111138
h=self.__kernel(opcode,self.__insts_buffer_bo,self.__n_insts,*kernel_args)
112139
r=h.wait()
113140
ifr!=xrt.ert_cmd_state.ERT_CMD_STATE_COMPLETED:
@@ -229,12 +256,16 @@ def decorator(*args, **kwargs):
229256
xclbin_path=xclbin_path,
230257
work_dir=kernel_dir,
231258
)
259+
232260
exceptExceptionase:
233261
# Clean up cache directory on any compilation failure to avoid any corrupted objects in the cache
234262
ifos.path.exists(kernel_dir):
235263
shutil.rmtree(kernel_dir)
236264
raisee
237265

266+
# Set the MLIR module globally for tracing to use
267+
set_mlir_module(str(mlir_module))
268+
238269
kernel_name="MLIR_AIE"
239270
try:
240271
kernel=NPUKernel(xclbin_path,inst_path,kernel_name=kernel_name)

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp