Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

[NVPTX] Add PRMT constant folding and cleanup usage of PRMT node #148906

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

AlexMaclean
Copy link
Member

No description provided.

@llvmbot
Copy link
Member

@llvm/pr-subscribers-backend-nvptx

Author: Alex MacLean (AlexMaclean)

Changes

Patch is 130.66 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/148906.diff

5 Files Affected:

  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+138-23)
  • (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+19-4)
  • (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (-18)
  • (modified) llvm/test/CodeGen/NVPTX/i8x4-instructions.ll (+1758-872)
  • (added) llvm/test/CodeGen/NVPTX/prmt-const-folding.ll (+171)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cppindex 14f05250ad6b8..e8f3b322ed90e 100644--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp@@ -1048,9 +1048,12 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,                       MVT::v32i32, MVT::v64i32, MVT::v128i32},                      Custom);-  setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::Other, Custom);-  // Enable custom lowering for the i128 bit operand with clusterlaunchcontrol-  setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i128, Custom);+  // Enable custom lowering for the following:+  //   * MVT::i128 - clusterlaunchcontrol+  //   * MVT::i32 - prmt+  //   * MVT::Other - internal.addrspace.wrap+  setOperationAction(ISD::INTRINSIC_WO_CHAIN, {MVT::i32, MVT::i128, MVT::Other},+                     Custom); }  const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {@@ -2060,6 +2063,13 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {   return DAG.getBuildVector(Node->getValueType(0), dl, Ops); }+static SDValue getPRMT(SDValue A, SDValue B, SDValue Selector, SDLoc DL,+                       SelectionDAG &DAG,+                       unsigned Mode = NVPTX::PTXPrmtMode::NONE) {+  return DAG.getNode(NVPTXISD::PRMT, DL, MVT::i32,+                     {A, B, Selector, DAG.getConstant(Mode, DL, MVT::i32)});+}+ SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {   // Handle bitcasting from v2i8 without hitting the default promotion   // strategy which goes through stack memory.@@ -2111,15 +2121,13 @@ SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op,         L = DAG.getAnyExtOrTrunc(L, DL, MVT::i32);         R = DAG.getAnyExtOrTrunc(R, DL, MVT::i32);       }-      return DAG.getNode(-          NVPTXISD::PRMT, DL, MVT::v4i8,-          {L, R, DAG.getConstant(SelectionValue, DL, MVT::i32),-          
 DAG.getConstant(NVPTX::PTXPrmtMode::NONE, DL, MVT::i32)});+      return getPRMT(L, R, DAG.getConstant(SelectionValue, DL, MVT::i32), DL,+                     DAG);     };     auto PRMT__10 = GetPRMT(Op->getOperand(0), Op->getOperand(1), true, 0x3340);     auto PRMT__32 = GetPRMT(Op->getOperand(2), Op->getOperand(3), true, 0x3340);     auto PRMT3210 = GetPRMT(PRMT__10, PRMT__32, false, 0x5410);-    return DAG.getNode(ISD::BITCAST, DL, VT, PRMT3210);+    return DAG.getBitcast(VT, PRMT3210);   }    // Get value or the Nth operand as an APInt(32). Undef values treated as 0.@@ -2176,11 +2184,14 @@ SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,     SDValue Selector = DAG.getNode(ISD::OR, DL, MVT::i32,                                    DAG.getZExtOrTrunc(Index, DL, MVT::i32),                                    DAG.getConstant(0x7770, DL, MVT::i32));-    SDValue PRMT = DAG.getNode(-        NVPTXISD::PRMT, DL, MVT::i32,-        {DAG.getBitcast(MVT::i32, Vector), DAG.getConstant(0, DL, MVT::i32),-         Selector, DAG.getConstant(NVPTX::PTXPrmtMode::NONE, DL, MVT::i32)});-    return DAG.getAnyExtOrTrunc(PRMT, DL, Op->getValueType(0));+    SDValue PRMT = getPRMT(DAG.getBitcast(MVT::i32, Vector),+                           DAG.getConstant(0, DL, MVT::i32), Selector, DL, DAG);+    SDValue Ext = DAG.getAnyExtOrTrunc(PRMT, DL, Op->getValueType(0));+    SDNodeFlags Flags;+    Flags.setNoSignedWrap(Ext.getScalarValueSizeInBits() > 8);+    Flags.setNoUnsignedWrap(Ext.getScalarValueSizeInBits() >= 8);+    Ext->setFlags(Flags);+    return Ext;   }    // Constant index will be matched by tablegen.@@ -2242,9 +2253,10 @@ SDValue NVPTXTargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,   }    SDLoc DL(Op);-  return DAG.getNode(NVPTXISD::PRMT, DL, MVT::v4i8, V1, V2,-                     DAG.getConstant(Selector, DL, MVT::i32),-                     DAG.getConstant(NVPTX::PTXPrmtMode::NONE, DL, MVT::i32));+  SDValue PRMT =+      getPRMT(DAG.getBitcast(MVT::i32, V1), 
DAG.getBitcast(MVT::i32, V2),+              DAG.getConstant(Selector, DL, MVT::i32), DL, DAG);+  return DAG.getBitcast(Op.getValueType(), PRMT); } /// LowerShiftRightParts - Lower SRL_PARTS, SRA_PARTS, which /// 1) returns two i32 values and take a 2 x i32 value to shift plus a shift@@ -2729,10 +2741,46 @@ static SDValue LowerClusterLaunchControlQueryCancel(SDValue Op,                      {TryCancelResponse0, TryCancelResponse1}); }+static SDValue lowerPrmtIntrinsic(SDValue Op, SelectionDAG &DAG) {+  const unsigned Mode = [&]() {+    switch (Op->getConstantOperandVal(0)) {+    case Intrinsic::nvvm_prmt:+      return NVPTX::PTXPrmtMode::NONE;+    case Intrinsic::nvvm_prmt_b4e:+      return NVPTX::PTXPrmtMode::B4E;+    case Intrinsic::nvvm_prmt_ecl:+      return NVPTX::PTXPrmtMode::ECL;+    case Intrinsic::nvvm_prmt_ecr:+      return NVPTX::PTXPrmtMode::ECR;+    case Intrinsic::nvvm_prmt_f4e:+      return NVPTX::PTXPrmtMode::F4E;+    case Intrinsic::nvvm_prmt_rc16:+      return NVPTX::PTXPrmtMode::RC16;+    case Intrinsic::nvvm_prmt_rc8:+      return NVPTX::PTXPrmtMode::RC8;+    default:+      llvm_unreachable("unsupported/unhandled intrinsic");+    }+  }();+  SDLoc DL(Op);+  SDValue A = Op->getOperand(1);+  SDValue B = Op.getNumOperands() == 4 ? 
Op.getOperand(2)+                                       : DAG.getConstant(0, DL, MVT::i32);+  SDValue Selector = (Op->op_end() - 1)->get();+  return getPRMT(A, B, Selector, DL, DAG, Mode);+} static SDValue lowerIntrinsicWOChain(SDValue Op, SelectionDAG &DAG) {   switch (Op->getConstantOperandVal(0)) {   default:     return Op;+  case Intrinsic::nvvm_prmt:+  case Intrinsic::nvvm_prmt_b4e:+  case Intrinsic::nvvm_prmt_ecl:+  case Intrinsic::nvvm_prmt_ecr:+  case Intrinsic::nvvm_prmt_f4e:+  case Intrinsic::nvvm_prmt_rc16:+  case Intrinsic::nvvm_prmt_rc8:+    return lowerPrmtIntrinsic(Op, DAG);   case Intrinsic::nvvm_internal_addrspace_wrap:     return Op.getOperand(1);   case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_is_canceled:@@ -5775,11 +5823,10 @@ PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {   SDLoc DL(N);   auto &DAG = DCI.DAG;-  auto PRMT = DAG.getNode(-      NVPTXISD::PRMT, DL, MVT::v4i8,-      {Op0, Op1, DAG.getConstant((Op1Bytes << 8) | Op0Bytes, DL, MVT::i32),-       DAG.getConstant(NVPTX::PTXPrmtMode::NONE, DL, MVT::i32)});-  return DAG.getNode(ISD::BITCAST, DL, VT, PRMT);+  auto PRMT = getPRMT(+      DAG.getBitcast(MVT::i32, Op0), DAG.getBitcast(MVT::i32, Op1),+      DAG.getConstant((Op1Bytes << 8) | Op0Bytes, DL, MVT::i32), DL, DAG);+  return DAG.getBitcast(VT, PRMT); }  static SDValue combineADDRSPACECAST(SDNode *N,@@ -5797,6 +5844,72 @@ static SDValue combineADDRSPACECAST(SDNode *N,   return SDValue(); }+static APInt getPRMTSelector(APInt Selector, unsigned Mode) {+  if (Mode == NVPTX::PTXPrmtMode::NONE)+    return Selector;++  unsigned V = Selector.trunc(2).getZExtValue();++  const auto GetSelector = [](unsigned S0, unsigned S1, unsigned S2,+                              unsigned S3) {+    return APInt(32, S0 | (S1 << 4) | (S2 << 8) | (S3 << 12));+  };++  switch (Mode) {+  case NVPTX::PTXPrmtMode::F4E:+    return GetSelector(V, V + 1, V + 2, V + 3);+  case NVPTX::PTXPrmtMode::B4E:+    return GetSelector(V, (V - 
1) & 7, (V - 2) & 7, (V - 3) & 7);+  case NVPTX::PTXPrmtMode::RC8:+    return GetSelector(V, V, V, V);+  case NVPTX::PTXPrmtMode::ECL:+    return GetSelector(V, std::max(V, 1U), std::max(V, 2U), 3U);+  case NVPTX::PTXPrmtMode::ECR:+    return GetSelector(0, std::min(V, 1U), std::min(V, 2U), V);+  case NVPTX::PTXPrmtMode::RC16: {+    unsigned V1 = (V & 1) << 1;+    return GetSelector(V1, V1 + 1, V1, V1 + 1);+  }+  default:+    llvm_unreachable("Invalid PRMT mode");+  }+}++static APInt computePRMT(APInt A, APInt B, APInt Selector, unsigned Mode) {+  // {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}+  APInt BitField = B.concat(A);+  APInt SelectorVal = getPRMTSelector(Selector, Mode);+  APInt Result(32, 0);+  for (unsigned I : llvm::seq(4U)) {+    APInt Sel = SelectorVal.extractBits(4, I * 4);+    unsigned Idx = Sel.getLoBits(3).getZExtValue();+    unsigned Sign = Sel.getHiBits(1).getZExtValue();+    APInt Byte = BitField.extractBits(8, Idx * 8);+    if (Sign)+      Byte = Byte.ashr(8);+    Result.insertBits(Byte, I * 8);+  }+  return Result;+}++static SDValue combinePRMT(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,+                           CodeGenOptLevel OptLevel) {+  if (OptLevel == CodeGenOptLevel::None)+    return SDValue();++  // Constant fold PRMT+  if (isa<ConstantSDNode>(N->getOperand(0)) &&+      isa<ConstantSDNode>(N->getOperand(1)) &&+      isa<ConstantSDNode>(N->getOperand(2)))+    return DCI.DAG.getConstant(computePRMT(N->getConstantOperandAPInt(0),+                                           N->getConstantOperandAPInt(1),+                                           N->getConstantOperandAPInt(2),+                                           N->getConstantOperandVal(3)),+                               SDLoc(N), N->getValueType(0));++  return SDValue();+}+ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,                                                DAGCombinerInfo &DCI) const {   CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();@@ 
-5838,6 +5951,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,       return PerformBUILD_VECTORCombine(N, DCI);     case ISD::ADDRSPACECAST:       return combineADDRSPACECAST(N, DCI);+    case NVPTXISD::PRMT:+      return combinePRMT(N, DCI, OptLevel);   }   return SDValue(); }@@ -6385,7 +6500,7 @@ static void computeKnownBitsForPRMT(const SDValue Op, KnownBits &Known,   ConstantSDNode *Selector = dyn_cast<ConstantSDNode>(Op.getOperand(2));   unsigned Mode = Op.getConstantOperandVal(3);-  if (Mode != NVPTX::PTXPrmtMode::NONE || !Selector)+  if (!Selector)     return;    KnownBits AKnown = DAG.computeKnownBits(A, Depth);@@ -6394,7 +6509,7 @@ static void computeKnownBitsForPRMT(const SDValue Op, KnownBits &Known,   // {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}   KnownBits BitField = BKnown.concat(AKnown);-  APInt SelectorVal = Selector->getAPIntValue();+  APInt SelectorVal = getPRMTSelector(Selector->getAPIntValue(), Mode);   for (unsigned I : llvm::seq(std::min(4U, Known.getBitWidth() / 8))) {     APInt Sel = SelectorVal.extractBits(4, I * 4);     unsigned Idx = Sel.getLoBits(3).getZExtValue();diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.tdindex ecae03e77aa83..6741ccbb43abc 100644--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td@@ -1453,18 +1453,33 @@ let hasSideEffects = false in {                 (ins PrmtMode:$mode),                 "prmt.b32$mode",                 [(set i32:$d, (prmt i32:$a, i32:$b, imm:$c, imm:$mode))]>;+  def PRMT_B32rir+  : BasicFlagsNVPTXInst<(outs B32:$d),+              (ins B32:$a, i32imm:$b, B32:$c),+              (ins PrmtMode:$mode),+              "prmt.b32$mode",+              [(set i32:$d, (prmt i32:$a, imm:$b, i32:$c, imm:$mode))]>;   def PRMT_B32rii     : BasicFlagsNVPTXInst<(outs B32:$d),                 (ins B32:$a, i32imm:$b, Hexu32imm:$c),                 (ins PrmtMode:$mode),                 "prmt.b32$mode",                 
[(set i32:$d, (prmt i32:$a, imm:$b, imm:$c, imm:$mode))]>;-  def PRMT_B32rir+  def PRMT_B32irr     : BasicFlagsNVPTXInst<(outs B32:$d),-                (ins B32:$a, i32imm:$b, B32:$c),-                (ins PrmtMode:$mode),+                (ins i32imm:$a, B32:$b, B32:$c), (ins PrmtMode:$mode),+                "prmt.b32$mode",+                [(set i32:$d, (prmt imm:$a, i32:$b, i32:$c, imm:$mode))]>;+  def PRMT_B32iri+    : BasicFlagsNVPTXInst<(outs B32:$d),+                (ins i32imm:$a, B32:$b, Hexu32imm:$c), (ins PrmtMode:$mode),+                "prmt.b32$mode",+                [(set i32:$d, (prmt imm:$a, i32:$b, imm:$c, imm:$mode))]>;+  def PRMT_B32iir+    : BasicFlagsNVPTXInst<(outs B32:$d),+                (ins i32imm:$a, i32imm:$b, B32:$c), (ins PrmtMode:$mode),                 "prmt.b32$mode",-                [(set i32:$d, (prmt i32:$a, imm:$b, i32:$c, imm:$mode))]>;+                [(set i32:$d, (prmt imm:$a, imm:$b, i32:$c, imm:$mode))]>;  }diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.tdindex 93827be5c2811..bdddf3f56cb13 100644--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td@@ -1007,24 +1007,6 @@ class F_MATH_3<string OpcStr, NVPTXRegClass t_regclass, // MISC //-class PRMT3Pat<Intrinsic prmt_intrinsic, PatLeaf prmt_mode>-    : Pat<(prmt_intrinsic i32:$a, i32:$b, i32:$c),-          (PRMT_B32rrr $a, $b, $c, prmt_mode)>;--class PRMT2Pat<Intrinsic prmt_intrinsic, PatLeaf prmt_mode>-    : Pat<(prmt_intrinsic i32:$a, i32:$c),-          (PRMT_B32rir $a, (i32 0), $c, prmt_mode)>;--def : PRMT3Pat<int_nvvm_prmt,      PrmtNONE>;-def : PRMT3Pat<int_nvvm_prmt_f4e,  PrmtF4E>;-def : PRMT3Pat<int_nvvm_prmt_b4e,  PrmtB4E>;--def : PRMT2Pat<int_nvvm_prmt_rc8,  PrmtRC8>;-def : PRMT2Pat<int_nvvm_prmt_ecl,  PrmtECL>;-def : PRMT2Pat<int_nvvm_prmt_ecr,  PrmtECR>;-def : PRMT2Pat<int_nvvm_prmt_rc16, PrmtRC16>;-- def INT_NVVM_NANOSLEEP_I : BasicNVPTXInst<(outs), (ins i32imm:$i), 
"nanosleep.u32",                              [(int_nvvm_nanosleep imm:$i)]>,         Requires<[hasPTX<63>, hasSM<70>]>;diff --git a/llvm/test/CodeGen/NVPTX/i8x4-instructions.ll b/llvm/test/CodeGen/NVPTX/i8x4-instructions.llindex 410c0019c7222..cbc9f700b1f01 100644--- a/llvm/test/CodeGen/NVPTX/i8x4-instructions.ll+++ b/llvm/test/CodeGen/NVPTX/i8x4-instructions.ll@@ -1,14 +1,19 @@ ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 3 ; ## Support i16x2 instructions-; RUN: llc < %s -mtriple=nvptx64-nvidia-cuda -mcpu=sm_90 -mattr=+ptx80 \-; RUN:          -O0 -disable-post-ra -frame-pointer=all -verify-machineinstrs \-; RUN: | FileCheck -allow-deprecated-dag-overlap %s-; RUN: %if ptxas %{                                                           \-; RUN:   llc < %s -mtriple=nvptx64-nvidia-cuda -mcpu=sm_90 \-; RUN:          -O0 -disable-post-ra -frame-pointer=all -verify-machineinstrs \-; RUN:   | %ptxas-verify -arch=sm_90                                          \+; RUN: llc < %s -mcpu=sm_90 -mattr=+ptx80 -O0 -disable-post-ra -frame-pointer=all -verify-machineinstrs \+; RUN: | FileCheck %s --check-prefixes=CHECK,O0+; RUN: llc < %s -mcpu=sm_90 -mattr=+ptx80 -verify-machineinstrs \+; RUN: | FileCheck %s --check-prefixes=CHECK,O3+; RUN: %if ptxas %{                                                            \+; RUN:   llc < %s -mcpu=sm_90 -mattr=+ptx80 -O0 -disable-post-ra -frame-pointer=all -verify-machineinstrs \+; RUN:   | %ptxas-verify -arch=sm_90                                           \+; RUN: %}+; RUN: %if ptxas %{                                                            \+; RUN:   llc < %s -mcpu=sm_90 -mattr=+ptx80 -verify-machineinstrs              \+; RUN:   | %ptxas-verify -arch=sm_90                                           \ ; RUN: %}+target triple = "nvptx64-nvidia-cuda" target datalayout = "e-m:o-i64:64-i128:128-n32:64-S128"  define <4 x i8> @test_ret_const() #0 {@@ -79,61 +84,111 @@ define i8 
@test_extract_3(<4 x i8> %a) #0 { }  define i8 @test_extract_i(<4 x i8> %a, i64 %idx) #0 {-; CHECK-LABEL: test_extract_i(-; CHECK:       {-; CHECK-NEXT:    .reg .b32 %r<5>;-; CHECK-NEXT:    .reg .b64 %rd<2>;-; CHECK-EMPTY:-; CHECK-NEXT:  // %bb.0:-; CHECK-NEXT:    ld.param.b64 %rd1, [test_extract_i_param_1];-; CHECK-NEXT:    ld.param.b32 %r1, [test_extract_i_param_0];-; CHECK-NEXT:    cvt.u32.u64 %r2, %rd1;-; CHECK-NEXT:    or.b32 %r3, %r2, 30576;-; CHECK-NEXT:    prmt.b32 %r4, %r1, 0, %r3;-; CHECK-NEXT:    st.param.b32 [func_retval0], %r4;-; CHECK-NEXT:    ret;+; O0-LABEL: test_extract_i(+; O0:       {+; O0-NEXT:    .reg .b32 %r<5>;+; O0-NEXT:    .reg .b64 %rd<2>;+; O0-EMPTY:+; O0-NEXT:  // %bb.0:+; O0-NEXT:    ld.param.b64 %rd1, [test_extract_i_param_1];+; O0-NEXT:    ld.param.b32 %r1, [test_extract_i_param_0];+; O0-NEXT:    cvt.u32.u64 %r2, %rd1;+; O0-NEXT:    or.b32 %r3, %r2, 30576;+; O0-NEXT:    prmt.b32 %r4, %r1, 0, %r3;+; O0-NEXT:    st.param.b32 [func_retval0], %r4;+; O0-NEXT:    ret;+;+; O3-LABEL: test_extract_i(+; O3:       {+; O3-NEXT:    .reg .b32 %r<5>;+; O3-EMPTY:+; O3-NEXT:  // %bb.0:+; O3-NEXT:    ld.param.b32 %r1, [test_extract_i_param_0];+; O3-NEXT:    ld.param.b32 %r2, [test_extract_i_param_1];+; O3-NEXT:    or.b32 %r3, %r2, 30576;+; O3-NEXT:    prmt.b32 %r4, %r1, 0, %r3;+; O3-NEXT:    st.param.b32 [func_retval0], %r4;+; O3-NEXT:    ret;   %e = extractelement <4 x i8> %a, i64 %idx   ret i8 %e }  define <4 x i8> @test_add(<4 x i8> %a, <4 x i8> %b) #0 {-; CHECK-LABEL: test_add(-; CHECK:       {-; CHECK-NEXT:    .reg .b16 %rs<13>;-; CHECK-NEXT:    .reg .b32 %r<18>;-; CHECK-EMPTY:-; CHECK-NEXT:  // %bb.0:-; CHECK-NEXT:    ld.param.b32 %r2, [test_add_param_1];-; CHECK-NEXT:    ld.param.b32 %r1, [test_add_param_0];-; CHECK-NEXT:    prmt.b32 %r3, %r2, 0, 0x7773U;-; CHECK-NEXT:    cvt.u16.u32 %rs1, %r3;-; CHECK-NEXT:    prmt.b32 %r4, %r1, 0, 0x7773U;-; CHECK-NEXT:    cvt.u16.u32 %rs2, %r4;-; CHECK-NEXT:    add.s16 %rs3, %rs2, %rs1;-; CHECK-NEXT:    
cvt.u32.u16 %r5, %rs3;-; CHECK-NEXT:    prmt.b32 %r6, %r2, 0, 0x7772U;-; CHECK-NEXT:    cvt.u16.u32 %rs4, %r6;-; CHECK-NEXT:    prmt.b32 %r7, %r1, 0, 0x7772U;-; CHECK-NEXT:    cvt.u16.u32 %rs5, %r7;-; CHECK-NEXT:    add.s16 %rs6, %rs5, %rs4;-; CHECK-NEXT:    cvt.u32.u16 %r8, %rs6;-; CHECK-NEXT:    prmt.b32 %r9, %r8, %r5, 0x3340U;-; CHECK-NEXT:    prmt.b32 %r10, %r2, 0, 0x7771U;-; CHECK-NEXT:    cvt.u16.u32 %rs7, %r10;-; CHECK-NEXT:    prmt.b32 %r11, %r1, 0, 0x7771U;-; CHECK-NEXT:    cvt.u16.u32 %rs8, %r11;-; CHECK-NEXT:    add.s16 %rs9, %rs8, %rs7;-; CHECK-NEXT:    cvt.u32.u16 %r12, %rs9;-; CHECK-NEXT:    prmt.b32 %r13, %r2, 0, 0x7770U;-; CHECK-NEXT:    cvt.u16.u32 %rs10, %r13;-; CHECK-NEXT:    prmt.b32 %r14, %r1, 0, 0x7770U;-; CHECK-NEXT:    cvt.u16.u32 %rs11, %r14;-; CHECK-NEXT:    add.s16 %rs12, %rs11, %rs10;-; CHECK-NEXT:    cvt.u32.u16 %r15, %rs12;-; CHECK-NEXT:    prmt.b32 %r16, %r15, %r12, 0x3340U;-; CHECK-NEXT:    prmt.b32 %r17, %r16, %r9, 0x5410U;-; CHECK-NEXT:    st.param.b32 [func_retval0], %r17;-; CHECK-NEXT:    ret;+; O0-LABEL: test_add(+; O0:       {+; O0-NEXT:    .reg .b16 %rs<13>;+; O0-NEXT:    .reg .b32 %r<18>;+; O0-EMPTY:+; O0-NEXT:  // %bb.0:+; O0-NEXT:    ld.param.b32 %r2, [test_add_param_1];+; O0-NEXT:    ld.param.b32 %r1, [test_add_param_0];+; O0-NEXT:    prmt.b32 %r3, %r2, 0, 0x7773U;+; O0-NEXT:    cvt.u16.u32 %rs1, %r3;+; O0-NEXT:    prmt.b32 %r4, %r1, 0, 0x7773U;+; O0-NEXT:    cvt.u16.u32 %rs2, %r4;+; O0-NEXT:    add.s16 %rs3, %rs2, %rs1;+; O0-NEXT:    cvt.u32.u16 %r5, %rs3;+; O0-NEXT:    prmt.b32 %r6, %r2, 0, 0x7772U;+; O0-NEXT:    cvt.u16.u32 %rs4, %r6;+; O0-NEXT:    prmt.b32 %r7, %r1, 0, 0x7772U;+; O0-NEXT:    cvt.u16.u32 %rs5, %r7;+; O0-NEXT:    add.s16 %rs6, %rs5, %rs4;+; O0-NEXT:    cvt.u32.u16 %r8, %rs6;+; O0-NEXT:    prmt.b32 %r9, %r8, %r5, 0x3340U;+; O0-NEXT:    prmt.b32 %r10, %r2, 0, 0x7771U;+; O0-NEXT:    cvt.u16.u32 %rs7, %r10;+; O0-NEXT:    prmt.b32 %r11, %r1, 0, 0x7771U;+; O0-NEXT:    cvt.u16.u32 %rs8, %r11;+; O0-NEXT:    
add.s16 %rs9, %rs8, %rs7;+; O0-NEXT:    cvt.u32.u16 %r12, %rs9;+; O0-NEXT:    prmt.b32 %r13, %r2, 0, 0x7770U;+; O0-NEXT:    cvt.u16.u32 %rs10, %r13;+; O0-NEXT:    prmt.b32 %r14, %r1, 0, 0x7770U;+; O0-NEXT:    cvt.u16.u32 %rs11, %r14;+; O0-NEXT:    add.s16 %rs12, %rs11, %rs10;+; O0-NEXT:    cvt.u32.u16 %r15, %rs12;+; O0-NEXT:    prmt.b32 %r16, %r15, %r12, 0x3340U;+; O0-NEXT:    prmt.b32 %r17, %r16, %r9, 0x5410U;+; O0-NEXT:    st.param.b32 [func_retval0], %r17;+; O0-NEXT:    ret;+;+; O3-LABEL: test_add(+; O3:       {+; O3-NEXT:    .reg .b16 %rs<13>;+; O3-NEXT:    .reg .b32 %r<18>;+; O3-EMPTY:+; O3-NEXT:  // %bb.0:+; O3-NEXT:    ld.param.b32 %r1, [test_add_param_0];+; O3-NEXT:    ld.param.b32 %r2, [test_add_param_1];+; O3-NEXT:    prmt.b32 %r3, %r2, 0, 0x7773U;+; O3...[truncated]

@github-actionsGitHub Actions
Copy link

github-actions bot commented Jul 15, 2025
edited
Loading

✅ With the latest revision this PR passed the C/C++ code formatter.

@AlexMaclean AlexMaclean force-pushed the dev/amaclean/upstream/prmt-2 branch from 931aecd to 43ba2f9 Compare July 15, 2025 17:44
@AlexMaclean AlexMaclean requested a review from kalxr July 16, 2025 16:10
Copy link
Member

@Artem-BArtem-B left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

LGTM with a few nits.

Comment on lines +2066 to +2068
static SDValue getPRMT(SDValue A, SDValue B, SDValue Selector, SDLoc DL,
SelectionDAG &DAG,
unsigned Mode = NVPTX::PTXPrmtMode::NONE) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

I'd add another overload with Selector provided as an integer. That seems to be a common pattern that forces us to sprinkle DAG.getConstant(X, DL, MVT::i32) in numerous places.

Copy link
MemberAuthor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

Added

@@ -5797,47 +5844,116 @@ static SDValue combineADDRSPACECAST(SDNode *N,
return SDValue();
}

static APInt getPRMTSelector(APInt Selector, unsigned Mode) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

Nit: a pointer to the selector encoding docs would be useful.

Copy link
MemberAuthor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

Added

Comment on lines 5955 to 5956
case NVPTXISD::PRMT:
return combinePRMT(N, DCI, OptLevel);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

Nit: Perhaps it's a good opportunity to sort the case values.

Copy link
MemberAuthor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

Done

@AlexMaclean AlexMaclean force-pushed the dev/amaclean/upstream/prmt-2 branch from 43ba2f9 to 6ccb74c Compare July 17, 2025 16:50
@AlexMaclean AlexMaclean force-pushed the dev/amaclean/upstream/prmt-2 branch from 6ccb74c to 81a2e32 Compare July 17, 2025 16:58
@AlexMaclean AlexMaclean merged commit f480e1b into llvm:main Jul 17, 2025
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Reviewers

@Artem-BArtem-BArtem-B approved these changes

@justinfargnolijustinfargnolijustinfargnoli approved these changes

@kalxrkalxrAwaiting requested review from kalxr

Assignees

@AlexMacleanAlexMaclean

Projects
None yet
Milestone
No milestone
Development

Successfully merging this pull request may close these issues.

4 participants
@AlexMaclean@llvmbot@Artem-B@justinfargnoli

[8]ページ先頭

©2009-2025 Movatter.jp