Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit08fcddf

Browse files
chohk88laikhtewari
authored andcommitted
feat: support aten._cdist_forward converter (#2726)
1 parent08fbf8c commit08fcddf

File tree

3 files changed

+318
-0
lines changed

3 files changed

+318
-0
lines changed

‎py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py‎

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2186,6 +2186,26 @@ def aten_ops_linear(
21862186
)
21872187

21882188

2189+
@dynamo_tensorrt_converter(torch.ops.aten._cdist_forward.default)
2190+
defaten_ops_cdist_forward(
2191+
ctx:ConversionContext,
2192+
target:Target,
2193+
args:Tuple[Argument, ...],
2194+
kwargs:Dict[str,Argument],
2195+
name:str,
2196+
)->Union[TRTTensor,Sequence[TRTTensor]]:
2197+
returnimpl.normalization.cdist_forward(
2198+
ctx,
2199+
target,
2200+
SourceIR.ATEN,
2201+
name,
2202+
x1=args[0],
2203+
x2=args[1],
2204+
p=args[2],
2205+
compute_mode=args_bounds_check(args,3,None),
2206+
)
2207+
2208+
21892209
defavg_pool_param_validator(pool_node:Node)->bool:
21902210
ceil_mode=args_bounds_check(pool_node.args,4,False)
21912211
divisor_override=args_bounds_check(pool_node.args,6)

‎py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py‎

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
importlogging
12
fromtypingimportAny,List,Optional,Sequence,Tuple,Union,cast
23

34
importnumpyasnp
@@ -21,6 +22,8 @@
2122
fromtorch_tensorrt.fx.typesimportTRTTensor
2223
fromtorch_tensorrt.fx.utilsimportget_dynamic_dims
2324

25+
_LOGGER:logging.Logger=logging.getLogger(__name__)
26+
2427

2528
defbatch_norm(
2629
ctx:ConversionContext,
@@ -446,3 +449,201 @@ def pdist(
446449
)
447450
indices=np.triu_indices(shape[0],k=1)
448451
returnimpl.select.index(ctx,target,source_ir,f"{name}_index",norm,indices)
452+
453+
454+
defcdist_forward(
455+
ctx:ConversionContext,
456+
target:Target,
457+
source_ir:Optional[SourceIR],
458+
name:str,
459+
x1:TRTTensor,
460+
x2:TRTTensor,
461+
p:float,
462+
compute_mode:Optional[int],
463+
)->Union[TRTTensor,Sequence[TRTTensor]]:
464+
"""
465+
Computes pairwise distances between sets of vectors in tensors x1 and x2 using the p-norm. The function treats the last dimension
466+
of x1 and x2 as feature dimensions, which must be identical for both inputs. The second-to-last dimensions can differ, reflecting
467+
the number of vectors in each tensor. The dimensions preceding the last are considered as batch dimensions, and pairwise distances
468+
are computed for each matching set in these dimensions.
469+
470+
The output tensor's shape is derived by matching the batch dimensions of x1 and x2, where the mismatched batch dimensions are
471+
merged, and the resulting shape reflects the computed distances for each pair of vectors. It's crucial that the batch dimensions
472+
(except for the size of sets of vectors to compare) of x1 and x2 either match or one of them is 1 (broadcasting).
473+
474+
Args:
475+
x1 (Tensor): input tensor of shape B x P x M.
476+
x2 (Tensor): input tensor of shape B x R x M.
477+
p (float): p value for the p-norm distance to calculate between each vector pair
478+
compute_mode (int): Controls the computation method based on the size of the input sets:
479+
- None ('use_mm_for_euclid_dist_if_necessary'): Default mode. Uses matrix multiplication to calculate
480+
Euclidean distance (p=2) if either the number of vectors in x1 or x2 exceeds 25 (P > 25 or R > 25).
481+
- 1 ('use_mm_for_euclid_dist'): Always use matrix multiplication approach to calculate
482+
euclidean distance (p = 2)
483+
- 2 ('donot_use_mm_for_euclid_dist'): Never use matrix multiplication approach to calculate
484+
euclidean distance (p = 2)
485+
486+
Example:
487+
- If x1.shape = [2, 3, 10, 5] and x2.shape = [2, 3, 20, 5], both having the same batch dimensions [2, 3], the output shape will be [2, 3, 10, 20].
488+
This represents computing distances in two batches of three groups, each comparing 10 vectors from x1 with 20 vectors from x2.
489+
- For x1.shape = [10, 5] (10 vectors, each of 5 features) and x2.shape = [20, 5] (20 vectors, each of 5 features),
490+
since there are no batch dimensions to match, the output shape is simply [10, 20], comparing all vectors from x1 against all vectors from x2.
491+
492+
Note: The `compute_mode` parameter is designed to optimize the performance of the Euclidean distance calculation,
493+
especially useful when working with large datasets. This parameter allows you to control how the distances are computed,
494+
with different modes available to leverage matrix multiplication for speed improvements.
495+
"""
496+
ifcompute_modeisNone:
497+
compute_mode=0
498+
499+
x1_expand_shape=list(x1.shape[:-1])+ [1,x1.shape[-1]]
500+
x2_expand_shape=list(x2.shape[:-2])+ [1]+list(x2.shape[-2:])
501+
502+
# Reshape x1 and x2 for broadcasting
503+
x1_expanded=impl.shuffle.reshape(
504+
ctx,target,source_ir,f"{name}_x1_expand",x1,x1_expand_shape
505+
)
506+
x2_expanded=impl.shuffle.reshape(
507+
ctx,target,source_ir,f"{name}_x2_expand",x2,x2_expand_shape
508+
)
509+
510+
diff=impl.elementwise.sub(
511+
ctx,target,source_ir,f"{name}_diff",x1_expanded,x2_expanded
512+
)
513+
514+
ifp==0:
515+
diff_non_zero=impl.elementwise.ne(
516+
ctx,target,source_ir,f"{name}_diff_non_zero",diff,0
517+
)
518+
diff_non_zero=cast_trt_tensor(
519+
ctx,diff_non_zero,torch.float32,f"{name}_cast",target,source_ir
520+
)
521+
dist=impl.reduce.sum(
522+
ctx,
523+
target,
524+
source_ir,
525+
f"{name}_sum",
526+
diff_non_zero,
527+
dim=-1,
528+
keepdim=False,
529+
)
530+
elifp==1:
531+
abs_val=impl.unary.abs(ctx,target,source_ir,f"{name}_abs_val",diff)
532+
dist=impl.reduce.sum(
533+
ctx,target,source_ir,f"{name}_sum",abs_val,dim=-1,keepdim=False
534+
)
535+
elifp==2:
536+
if (
537+
compute_mode==0and (x1.shape[-2]>25orx2.shape[-2]>25)
538+
)orcompute_mode==1:
539+
# Compute squared elements
540+
x1_squared=impl.elementwise.pow(
541+
ctx,target,source_ir,f"{name}_x1_squared",x1,2
542+
)
543+
x2_squared=impl.elementwise.pow(
544+
ctx,target,source_ir,f"{name}_x2_squared",x2,2
545+
)
546+
547+
# Sum squares along the last dimension
548+
x1_sum_squared=impl.reduce.sum(
549+
ctx,
550+
target,
551+
source_ir,
552+
f"{name}_x1_sum",
553+
x1_squared,
554+
dim=-1,
555+
keepdim=True,
556+
)
557+
x2_sum_squared=impl.reduce.sum(
558+
ctx,
559+
target,
560+
source_ir,
561+
f"{name}_x2_sum",
562+
x2_squared,
563+
dim=-1,
564+
keepdim=True,
565+
)
566+
567+
# Reshape sums for broadcasting
568+
rank=len(x2.shape)
569+
permute_shape=list(range(rank-2))+ [rank-1,rank-2]
570+
x1_sum_expanded=x1_sum_squared
571+
x2_sum_expanded=impl.permutation.permute(
572+
ctx,target,source_ir,f"{name}_permute",x2_sum_squared,permute_shape
573+
)
574+
575+
# Compute dot product of x1 and transposed x2
576+
x2_tr=impl.permutation.permute(
577+
ctx,target,source_ir,f"{name}_permute_mm",x2,permute_shape
578+
)
579+
dot_product=impl.matmul.matrix_multiply(
580+
ctx,
581+
target,
582+
source_ir,
583+
f"{name}_dot_product",
584+
x1,
585+
x2_tr,
586+
input_matrix_op=trt.MatrixOperation.NONE,
587+
other_matrix_op=trt.MatrixOperation.NONE,
588+
)
589+
590+
# Combine results to get squared distances
591+
dist_squared=impl.elementwise.add(
592+
ctx,
593+
target,
594+
source_ir,
595+
f"{name}_dist_squared_initial",
596+
x1_sum_expanded,
597+
x2_sum_expanded,
598+
)
599+
dist_squared=impl.elementwise.sub(
600+
ctx,
601+
target,
602+
source_ir,
603+
f"{name}_dist_squared",
604+
dist_squared,
605+
impl.elementwise.mul(
606+
ctx,target,source_ir,f"{name}_dot_product_scaled",dot_product,2
607+
),
608+
)
609+
610+
# Compute the Euclidean distances
611+
dist=impl.unary.sqrt(ctx,target,source_ir,f"{name}_dist",dist_squared)
612+
else:
613+
diff_squared=impl.elementwise.pow(
614+
ctx,target,source_ir,f"{name}_diff_squared",diff,2
615+
)
616+
dist_squared=impl.reduce.sum(
617+
ctx,
618+
target,
619+
source_ir,
620+
f"{name}_dist_sq_sum",
621+
diff_squared,
622+
dim=-1,
623+
keepdim=False,
624+
)
625+
dist=impl.unary.sqrt(ctx,target,source_ir,f"{name}_sqrt",dist_squared)
626+
elif0<p<1or1<p<2or2<p<float("inf"):
627+
abs_val=impl.unary.abs(ctx,target,source_ir,f"{name}_abs_val",diff)
628+
pow_val=impl.elementwise.pow(
629+
ctx,target,source_ir,f"{name}_pow_val_1",abs_val,p
630+
)
631+
sum_val=impl.reduce.sum(
632+
ctx,target,source_ir,f"{name}_sum",pow_val,dim=-1,keepdim=False
633+
)
634+
dist=impl.elementwise.pow(
635+
ctx,target,source_ir,f"{name}_pow_val_2",sum_val,1/p
636+
)
637+
elifp==float("inf"):
638+
abs_val=impl.unary.abs(ctx,target,source_ir,f"{name}_abs_val",diff)
639+
dist=impl.reduce.max(
640+
ctx,
641+
target,
642+
source_ir,
643+
f"{name}_max",
644+
abs_val,
645+
dim=-1,
646+
keepdim=False,
647+
return_indices=False,
648+
)
649+
returndist
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
importtorch
2+
importtorch.nnasnn
3+
fromparameterizedimportparameterized
4+
fromtorch.testing._internal.common_utilsimportrun_tests
5+
6+
from .harnessimportDispatchTestCase
7+
8+
9+
classTestCdistConverter(DispatchTestCase):
10+
@parameterized.expand(
11+
[
12+
("p_0", (4,3,4),0,0),
13+
("p>0_p<1_1", (10,3,5,2,6),0.5,1),
14+
("p>0_p<1_2", (10,2,15,2,7,2),0.5,1),
15+
("p_1", (15,10,5),1,None),
16+
("p>1_p<2", (19,11,5),1.5,None),
17+
("small_p_2_mode_1", (6,6,5),2.0,1),
18+
("large_p_2_mode_0", (35,35,5),2.0,0),
19+
("p>2", (15,10,5),2.99,None),
20+
("p_inf", (5,15,5),float("inf"),0),
21+
]
22+
)
23+
deftest_cdist_float_same_shape(self,name,shape,p,compute_mode):
24+
classCdist(nn.Module):
25+
defforward(self,x1,x2):
26+
returntorch.ops.aten._cdist_forward.default(x1,x2,p,compute_mode)
27+
28+
inputs= [torch.randn(shape),torch.randn(shape)]
29+
self.run_test(
30+
Cdist(),
31+
inputs,
32+
)
33+
34+
@parameterized.expand(
35+
[
36+
("p_0", (1,5), (2,3,5),0,0),
37+
("p_1", (4,5), (2,3,5),1,None),
38+
("diff_shape_p_0", (2,5,4,5), (2,5,8,5),0,2),
39+
("diff_shape_p_1", (2,4,5), (2,3,5),1,1),
40+
("p>0_p<1", (2,2,4,5), (2,3,5),0.5,None),
41+
("p>1_p<2", (5,2,12,5), (2,3,5),1.5,1),
42+
("p_2", (2,2,14,5), (2,3,5),2,0),
43+
("p>2", (2,2,4,5), (2,10,5),2.99,2),
44+
("p_inf", (2,2,3,5), (2,8,5),float("inf"),None),
45+
]
46+
)
47+
deftest_cdist_float_broadcast_and_diff_shape(
48+
self,name,shape_1,shape_2,p,compute_mode
49+
):
50+
classCdist(nn.Module):
51+
defforward(self,x1,x2):
52+
returntorch.ops.aten._cdist_forward.default(x1,x2,p,compute_mode)
53+
54+
inputs= [torch.randn(shape_1),torch.randn(shape_2)]
55+
self.run_test(
56+
Cdist(),
57+
inputs,
58+
)
59+
60+
@parameterized.expand(
61+
[
62+
("compute_mode_0", (15,10,5), (15,35,5),2.0,0),
63+
("compute_mode_1", (35,35,5), (35,45,5),2.0,0),
64+
("compute_mode_2", (15,10,5), (15,35,5),2.0,1),
65+
("compute_mode_3", (35,35,5), (35,45,5),2.0,2),
66+
("p_2_mm_shape_1", (2,2,14,5), (3,5),2,1),
67+
("p_2_mm_shape_2", (2,2,14,5), (2,3,5),2,1),
68+
("p_2_mm_shape_3", (2,2,14,5), (2,2,3,5),2,1),
69+
]
70+
)
71+
deftest_cdist_p_2_compute_mode(self,name,shape_1,shape_2,p,compute_mode):
72+
classCdist(nn.Module):
73+
defforward(self,x1,x2):
74+
returntorch.ops.aten._cdist_forward.default(x1,x2,p,compute_mode)
75+
76+
inputs= [torch.randn(shape_1),torch.randn(shape_2)]
77+
self.run_test(Cdist(),inputs)
78+
79+
@parameterized.expand(
80+
[
81+
("p_2_matmul", (50,40,30,30), (50,40,35,30),2,1),
82+
("p_2_elementwise_pow", (50,40,30,50), (50,40,35,50),2,2),
83+
]
84+
)
85+
deftest_cdist_efficiency_p_2_compute_mode(
86+
self,name,shape_1,shape_2,p,compute_mode
87+
):
88+
classCdist(nn.Module):
89+
defforward(self,x1,x2):
90+
returntorch.ops.aten._cdist_forward.default(x1,x2,p,compute_mode)
91+
92+
inputs= [torch.randn(shape_1),torch.randn(shape_2)]
93+
self.run_test(Cdist(),inputs)
94+
95+
96+
if__name__=="__main__":
97+
run_tests()

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp