|
| 1 | +importlogging |
1 | 2 | fromtypingimportAny,List,Optional,Sequence,Tuple,Union,cast |
2 | 3 |
|
3 | 4 | importnumpyasnp |
|
21 | 22 | fromtorch_tensorrt.fx.typesimportTRTTensor |
22 | 23 | fromtorch_tensorrt.fx.utilsimportget_dynamic_dims |
23 | 24 |
|
| 25 | +_LOGGER:logging.Logger=logging.getLogger(__name__) |
| 26 | + |
24 | 27 |
|
25 | 28 | defbatch_norm( |
26 | 29 | ctx:ConversionContext, |
@@ -446,3 +449,201 @@ def pdist( |
446 | 449 | ) |
447 | 450 | indices=np.triu_indices(shape[0],k=1) |
448 | 451 | returnimpl.select.index(ctx,target,source_ir,f"{name}_index",norm,indices) |
| 452 | + |
| 453 | + |
| 454 | +defcdist_forward( |
| 455 | +ctx:ConversionContext, |
| 456 | +target:Target, |
| 457 | +source_ir:Optional[SourceIR], |
| 458 | +name:str, |
| 459 | +x1:TRTTensor, |
| 460 | +x2:TRTTensor, |
| 461 | +p:float, |
| 462 | +compute_mode:Optional[int], |
| 463 | +)->Union[TRTTensor,Sequence[TRTTensor]]: |
| 464 | +""" |
| 465 | + Computes pairwise distances between sets of vectors in tensors x1 and x2 using the p-norm. The function treats the last dimension |
| 466 | + of x1 and x2 as feature dimensions, which must be identical for both inputs. The second-to-last dimensions can differ, reflecting |
| 467 | + the number of vectors in each tensor. The dimensions preceding the last are considered as batch dimensions, and pairwise distances |
| 468 | + are computed for each matching set in these dimensions. |
| 469 | +
|
| 470 | + The output tensor's shape is derived by matching the batch dimensions of x1 and x2, where the mismatched batch dimensions are |
| 471 | + merged, and the resulting shape reflects the computed distances for each pair of vectors. It's crucial that the batch dimensions |
| 472 | + (except for the size of sets of vectors to compare) of x1 and x2 either match or one of them is 1 (broadcasting). |
| 473 | +
|
| 474 | + Args: |
| 475 | + x1 (Tensor): input tensor of shape B x P x M. |
| 476 | + x2 (Tensor): input tensor of shape B x R x M. |
| 477 | + p (float): p value for the p-norm distance to calculate between each vector pair |
| 478 | + compute_mode (int): Controls the computation method based on the size of the input sets: |
| 479 | + - None ('use_mm_for_euclid_dist_if_necessary'): Default mode. Uses matrix multiplication to calculate |
| 480 | + Euclidean distance (p=2) if either the number of vectors in x1 or x2 exceeds 25 (P > 25 or R > 25). |
| 481 | + - 1 ('use_mm_for_euclid_dist'): Always use matrix multiplication approach to calculate |
| 482 | + euclidean distance (p = 2) |
| 483 | + - 2 ('donot_use_mm_for_euclid_dist'): Never use matrix multiplication approach to calculate |
| 484 | + euclidean distance (p = 2) |
| 485 | +
|
| 486 | + Example: |
| 487 | + - If x1.shape = [2, 3, 10, 5] and x2.shape = [2, 3, 20, 5], both having the same batch dimensions [2, 3], the output shape will be [2, 3, 10, 20]. |
| 488 | + This represents computing distances in two batches of three groups, each comparing 10 vectors from x1 with 20 vectors from x2. |
| 489 | + - For x1.shape = [10, 5] (10 vectors, each of 5 features) and x2.shape = [20, 5] (20 vectors, each of 5 features), |
| 490 | + since there are no batch dimensions to match, the output shape is simply [10, 20], comparing all vectors from x1 against all vectors from x2. |
| 491 | +
|
| 492 | + Note: The `compute_mode` parameter is designed to optimize the performance of the Euclidean distance calculation, |
| 493 | + especially useful when working with large datasets. This parameter allows you to control how the distances are computed, |
| 494 | + with different modes available to leverage matrix multiplication for speed improvements. |
| 495 | + """ |
| 496 | +ifcompute_modeisNone: |
| 497 | +compute_mode=0 |
| 498 | + |
| 499 | +x1_expand_shape=list(x1.shape[:-1])+ [1,x1.shape[-1]] |
| 500 | +x2_expand_shape=list(x2.shape[:-2])+ [1]+list(x2.shape[-2:]) |
| 501 | + |
| 502 | +# Reshape x1 and x2 for broadcasting |
| 503 | +x1_expanded=impl.shuffle.reshape( |
| 504 | +ctx,target,source_ir,f"{name}_x1_expand",x1,x1_expand_shape |
| 505 | + ) |
| 506 | +x2_expanded=impl.shuffle.reshape( |
| 507 | +ctx,target,source_ir,f"{name}_x2_expand",x2,x2_expand_shape |
| 508 | + ) |
| 509 | + |
| 510 | +diff=impl.elementwise.sub( |
| 511 | +ctx,target,source_ir,f"{name}_diff",x1_expanded,x2_expanded |
| 512 | + ) |
| 513 | + |
| 514 | +ifp==0: |
| 515 | +diff_non_zero=impl.elementwise.ne( |
| 516 | +ctx,target,source_ir,f"{name}_diff_non_zero",diff,0 |
| 517 | + ) |
| 518 | +diff_non_zero=cast_trt_tensor( |
| 519 | +ctx,diff_non_zero,torch.float32,f"{name}_cast",target,source_ir |
| 520 | + ) |
| 521 | +dist=impl.reduce.sum( |
| 522 | +ctx, |
| 523 | +target, |
| 524 | +source_ir, |
| 525 | +f"{name}_sum", |
| 526 | +diff_non_zero, |
| 527 | +dim=-1, |
| 528 | +keepdim=False, |
| 529 | + ) |
| 530 | +elifp==1: |
| 531 | +abs_val=impl.unary.abs(ctx,target,source_ir,f"{name}_abs_val",diff) |
| 532 | +dist=impl.reduce.sum( |
| 533 | +ctx,target,source_ir,f"{name}_sum",abs_val,dim=-1,keepdim=False |
| 534 | + ) |
| 535 | +elifp==2: |
| 536 | +if ( |
| 537 | +compute_mode==0and (x1.shape[-2]>25orx2.shape[-2]>25) |
| 538 | + )orcompute_mode==1: |
| 539 | +# Compute squared elements |
| 540 | +x1_squared=impl.elementwise.pow( |
| 541 | +ctx,target,source_ir,f"{name}_x1_squared",x1,2 |
| 542 | + ) |
| 543 | +x2_squared=impl.elementwise.pow( |
| 544 | +ctx,target,source_ir,f"{name}_x2_squared",x2,2 |
| 545 | + ) |
| 546 | + |
| 547 | +# Sum squares along the last dimension |
| 548 | +x1_sum_squared=impl.reduce.sum( |
| 549 | +ctx, |
| 550 | +target, |
| 551 | +source_ir, |
| 552 | +f"{name}_x1_sum", |
| 553 | +x1_squared, |
| 554 | +dim=-1, |
| 555 | +keepdim=True, |
| 556 | + ) |
| 557 | +x2_sum_squared=impl.reduce.sum( |
| 558 | +ctx, |
| 559 | +target, |
| 560 | +source_ir, |
| 561 | +f"{name}_x2_sum", |
| 562 | +x2_squared, |
| 563 | +dim=-1, |
| 564 | +keepdim=True, |
| 565 | + ) |
| 566 | + |
| 567 | +# Reshape sums for broadcasting |
| 568 | +rank=len(x2.shape) |
| 569 | +permute_shape=list(range(rank-2))+ [rank-1,rank-2] |
| 570 | +x1_sum_expanded=x1_sum_squared |
| 571 | +x2_sum_expanded=impl.permutation.permute( |
| 572 | +ctx,target,source_ir,f"{name}_permute",x2_sum_squared,permute_shape |
| 573 | + ) |
| 574 | + |
| 575 | +# Compute dot product of x1 and transposed x2 |
| 576 | +x2_tr=impl.permutation.permute( |
| 577 | +ctx,target,source_ir,f"{name}_permute_mm",x2,permute_shape |
| 578 | + ) |
| 579 | +dot_product=impl.matmul.matrix_multiply( |
| 580 | +ctx, |
| 581 | +target, |
| 582 | +source_ir, |
| 583 | +f"{name}_dot_product", |
| 584 | +x1, |
| 585 | +x2_tr, |
| 586 | +input_matrix_op=trt.MatrixOperation.NONE, |
| 587 | +other_matrix_op=trt.MatrixOperation.NONE, |
| 588 | + ) |
| 589 | + |
| 590 | +# Combine results to get squared distances |
| 591 | +dist_squared=impl.elementwise.add( |
| 592 | +ctx, |
| 593 | +target, |
| 594 | +source_ir, |
| 595 | +f"{name}_dist_squared_initial", |
| 596 | +x1_sum_expanded, |
| 597 | +x2_sum_expanded, |
| 598 | + ) |
| 599 | +dist_squared=impl.elementwise.sub( |
| 600 | +ctx, |
| 601 | +target, |
| 602 | +source_ir, |
| 603 | +f"{name}_dist_squared", |
| 604 | +dist_squared, |
| 605 | +impl.elementwise.mul( |
| 606 | +ctx,target,source_ir,f"{name}_dot_product_scaled",dot_product,2 |
| 607 | + ), |
| 608 | + ) |
| 609 | + |
| 610 | +# Compute the Euclidean distances |
| 611 | +dist=impl.unary.sqrt(ctx,target,source_ir,f"{name}_dist",dist_squared) |
| 612 | +else: |
| 613 | +diff_squared=impl.elementwise.pow( |
| 614 | +ctx,target,source_ir,f"{name}_diff_squared",diff,2 |
| 615 | + ) |
| 616 | +dist_squared=impl.reduce.sum( |
| 617 | +ctx, |
| 618 | +target, |
| 619 | +source_ir, |
| 620 | +f"{name}_dist_sq_sum", |
| 621 | +diff_squared, |
| 622 | +dim=-1, |
| 623 | +keepdim=False, |
| 624 | + ) |
| 625 | +dist=impl.unary.sqrt(ctx,target,source_ir,f"{name}_sqrt",dist_squared) |
| 626 | +elif0<p<1or1<p<2or2<p<float("inf"): |
| 627 | +abs_val=impl.unary.abs(ctx,target,source_ir,f"{name}_abs_val",diff) |
| 628 | +pow_val=impl.elementwise.pow( |
| 629 | +ctx,target,source_ir,f"{name}_pow_val_1",abs_val,p |
| 630 | + ) |
| 631 | +sum_val=impl.reduce.sum( |
| 632 | +ctx,target,source_ir,f"{name}_sum",pow_val,dim=-1,keepdim=False |
| 633 | + ) |
| 634 | +dist=impl.elementwise.pow( |
| 635 | +ctx,target,source_ir,f"{name}_pow_val_2",sum_val,1/p |
| 636 | + ) |
| 637 | +elifp==float("inf"): |
| 638 | +abs_val=impl.unary.abs(ctx,target,source_ir,f"{name}_abs_val",diff) |
| 639 | +dist=impl.reduce.max( |
| 640 | +ctx, |
| 641 | +target, |
| 642 | +source_ir, |
| 643 | +f"{name}_max", |
| 644 | +abs_val, |
| 645 | +dim=-1, |
| 646 | +keepdim=False, |
| 647 | +return_indices=False, |
| 648 | + ) |
| 649 | +returndist |