- Notifications
You must be signed in to change notification settings - Fork 14.5k
[IR2Vec][NFC] Add helper methods for numeric ID mapping in Vocabulary #149212
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Uh oh!
There was an error while loading.Please reload this page.
Conversation
svkeerthy commented Jul 16, 2025 • edited
Loading Uh oh!
There was an error while loading.Please reload this page.
edited
Uh oh!
There was an error while loading.Please reload this page.
llvmbot commented Jul 16, 2025 • edited
Loading Uh oh!
There was an error while loading.Please reload this page.
edited
Uh oh!
There was an error while loading.Please reload this page.
@llvm/pr-subscribers-llvm-analysis @llvm/pr-subscribers-mlgo Author: S. VenkataKeerthy (svkeerthy) Changes: Add helper methods to IR2Vec's Vocabulary class for numeric ID mapping and vocabulary size calculation. These APIs will be useful in triplet generation for the llvm-ir2vec tool (see #149214). (Tracking issue - #141817) Full diff: https://github.com/llvm/llvm-project/pull/149212.diff 3 Files Affected:
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.hindex 3d7edf08c8807..d87457cac7642 100644--- a/llvm/include/llvm/Analysis/IR2Vec.h+++ b/llvm/include/llvm/Analysis/IR2Vec.h@@ -170,6 +170,10 @@ class Vocabulary { unsigned getDimension() const; size_t size() const;+ static size_t expectedSize() {+ return MaxOpcodes + MaxTypeIDs + MaxOperandKinds;+ }+ /// Helper function to get vocabulary key for a given Opcode static StringRef getVocabKeyForOpcode(unsigned Opcode);@@ -182,6 +186,11 @@ class Vocabulary { /// Helper function to classify an operand into OperandKind static OperandKind getOperandKind(const Value *Op);+ /// Helpers to return the IDs of a given Opcode, TypeID, or OperandKind+ static unsigned getNumericID(unsigned Opcode);+ static unsigned getNumericID(Type::TypeID TypeID);+ static unsigned getNumericID(const Value *Op);+ /// Accessors to get the embedding for a given entity. const ir2vec::Embedding &operator[](unsigned Opcode) const; const ir2vec::Embedding &operator[](Type::TypeID TypeId) const;diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cppindex 898bf5b202feb..95f30fd3f4275 100644--- a/llvm/lib/Analysis/IR2Vec.cpp+++ b/llvm/lib/Analysis/IR2Vec.cpp@@ -215,7 +215,7 @@ Vocabulary::Vocabulary(VocabVector &&Vocab) : Vocab(std::move(Vocab)), Valid(true) {} bool Vocabulary::isValid() const {- return Vocab.size() == (MaxOpcodes + MaxTypeIDs + MaxOperandKinds) && Valid;+ return Vocab.size() == Vocabulary::expectedSize() && Valid; } size_t Vocabulary::size() const {@@ -324,8 +324,24 @@ Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) { return OperandKind::VariableID; }+unsigned Vocabulary::getNumericID(unsigned Opcode) {+ assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");+ return Opcode - 1; // Convert to zero-based index+}++unsigned Vocabulary::getNumericID(Type::TypeID TypeID) {+ assert(static_cast<unsigned>(TypeID) < MaxTypeIDs && "Invalid type ID");+ return MaxOpcodes 
+ static_cast<unsigned>(TypeID);+}++unsigned Vocabulary::getNumericID(const Value *Op) {+ unsigned Index = static_cast<unsigned>(getOperandKind(Op));+ assert(Index < MaxOperandKinds && "Invalid OperandKind");+ return MaxOpcodes + MaxTypeIDs + Index;+}+ StringRef Vocabulary::getStringKey(unsigned Pos) {- assert(Pos < MaxOpcodes + MaxTypeIDs + MaxOperandKinds &&+ assert(Pos < Vocabulary::expectedSize() && "Position out of bounds in vocabulary"); // Opcode if (Pos < MaxOpcodes)diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cppindex cb6d633306a81..7c9a5464bfe1d 100644--- a/llvm/unittests/Analysis/IR2VecTest.cpp+++ b/llvm/unittests/Analysis/IR2VecTest.cpp@@ -396,6 +396,69 @@ TEST(IR2VecVocabularyTest, DummyVocabTest) { } }+TEST(IR2VecVocabularyTest, NumericIDMap) {+ // Test getNumericID for opcodes+ EXPECT_EQ(Vocabulary::getNumericID(1u), 0u);+ EXPECT_EQ(Vocabulary::getNumericID(13u), 12u);+ EXPECT_EQ(Vocabulary::getNumericID(MaxOpcodes), MaxOpcodes - 1);++ // Test getNumericID for Type IDs+ EXPECT_EQ(Vocabulary::getNumericID(Type::VoidTyID),+ MaxOpcodes + static_cast<unsigned>(Type::VoidTyID));+ EXPECT_EQ(Vocabulary::getNumericID(Type::HalfTyID),+ MaxOpcodes + static_cast<unsigned>(Type::HalfTyID));+ EXPECT_EQ(Vocabulary::getNumericID(Type::FloatTyID),+ MaxOpcodes + static_cast<unsigned>(Type::FloatTyID));+ EXPECT_EQ(Vocabulary::getNumericID(Type::IntegerTyID),+ MaxOpcodes + static_cast<unsigned>(Type::IntegerTyID));+ EXPECT_EQ(Vocabulary::getNumericID(Type::PointerTyID),+ MaxOpcodes + static_cast<unsigned>(Type::PointerTyID));++ // Test getNumericID for Value operands+ LLVMContext Ctx;+ Module M("TestM", Ctx);+ FunctionType *FTy =+ FunctionType::get(Type::getVoidTy(Ctx), {Type::getInt32Ty(Ctx)}, false);+ Function *F = Function::Create(FTy, Function::ExternalLinkage, "testFunc", M);++ // Test Function operand+ EXPECT_EQ(Vocabulary::getNumericID(F),+ MaxOpcodes + MaxTypeIDs + 0u); // Function = 0++ // Test Constant operand+ 
Constant *C = ConstantInt::get(Type::getInt32Ty(Ctx), 42);+ EXPECT_EQ(Vocabulary::getNumericID(C),+ MaxOpcodes + MaxTypeIDs + 2u); // Constant = 2++ // Test Pointer operand+ BasicBlock *BB = BasicBlock::Create(Ctx, "entry", F);+ AllocaInst *PtrVal = new AllocaInst(Type::getInt32Ty(Ctx), 0, "ptr", BB);+ EXPECT_EQ(Vocabulary::getNumericID(PtrVal),+ MaxOpcodes + MaxTypeIDs + 1u); // Pointer = 1++ // Test Variable operand (function argument)+ Argument *Arg = F->getArg(0);+ EXPECT_EQ(Vocabulary::getNumericID(Arg),+ MaxOpcodes + MaxTypeIDs + 3u); // Variable = 3+}++#if GTEST_HAS_DEATH_TEST+#ifndef NDEBUG+TEST(IR2VecVocabularyTest, NumericIDMapInvalidInputs) {+ // Test invalid opcode IDs+ EXPECT_DEATH(Vocabulary::getNumericID(0u), "Invalid opcode");+ EXPECT_DEATH(Vocabulary::getNumericID(MaxOpcodes + 1), "Invalid opcode");++ // Test invalid type IDs+ EXPECT_DEATH(Vocabulary::getNumericID(static_cast<Type::TypeID>(MaxTypeIDs)),+ "Invalid type ID");+ EXPECT_DEATH(+ Vocabulary::getNumericID(static_cast<Type::TypeID>(MaxTypeIDs + 10)),+ "Invalid type ID");+}+#endif // NDEBUG+#endif // GTEST_HAS_DEATH_TEST+ TEST(IR2VecVocabularyTest, StringKeyGeneration) { EXPECT_EQ(Vocabulary::getStringKey(0), "Ret"); EXPECT_EQ(Vocabulary::getStringKey(12), "Add"); |
6ae5021 to 3ad45e3 Compare
bc03736 to 68ae9f5 Compare
42671b8 to a395af5 Compare
68ae9f5 to 1d7ca80 Compare
a395af5 to 586947a Compare
01c6091 to f24c6f1 Compare
svkeerthy commented Jul 17, 2025 • edited
Loading Uh oh!
There was an error while loading.Please reload this page.
edited
Uh oh!
There was an error while loading.Please reload this page.
Merge activity
|
f24c6f1 to faf9baa Compare
61a45d2 into main
Uh oh!
There was an error while loading.Please reload this page.
@svkeerthy This didn't get reviewed at all? |
Right. Pushed it as it was a minor refactoring. Feel free to add any comments. Will fix it. |
LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/162/builds/27073 Here is the relevant piece of the build log for reference:
|
Uh oh!
There was an error while loading.Please reload this page.
Add helper methods to IR2Vec's Vocabulary class for numeric ID mapping and vocabulary size calculation. These APIs will be useful in triplet generation for the llvm-ir2vec tool (see #149214). (Tracking issue - #141817)