Description
First of all, thank you for your work.
The method is promising and your article is very interesting, so I tried to use it in two ways:
- determining whether a detected object is a False Positive
- determining the absence of an object in an image
I'm using the .pt weights you kindly provided, and I tried to implement the ATD and the CTW methods.
However, the results were very poor, which makes me think I missed something. In my first use case the only prompt was:
"A photo of a person with a {}" ("A photo of a person without a {}") with "hat", "cap", "helmet" as the class names.
With ATD, everything is classified as OOD; with CTW, almost everything is classified as ID.
I have some questions regarding your paper:
Do you have a reference or a paper explaining where Eq. 4 comes from? As I understand the CTW method, the value of Eq. 4 should be above 0.5 for the sample to be classified as OOD.
And where does Eq. 8 come from?
As for Eq. 6, computing pij is a kind of softmax, right? Just with an added temperature parameter?
If so, wouldn't the ATD method be unusable when there is only one class and you just want to discard false positives, since pij is then always equal to 1?
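To make the single-class concern concrete, here is a minimal sketch in plain Python. It assumes the same formulas as in my ATD function further down (a softmax with temperature for pij, and a two-way "yes"/"no" softmax for the "no" probability); the function name `atd_single_class` and the example logits are made up for illustration.

```python
import math


def atd_single_class(logit_yes, logit_no, tau=1.0):
    """ATD decision for a single class, under the formulas assumed above."""
    # With one class, the softmax over the "yes" logits is trivially 1.
    p_yes = math.exp(logit_yes / tau) / math.exp(logit_yes / tau)
    # Two-way softmax between the "no" and "yes" logits of that class.
    p_no = math.exp(logit_no / tau) / (
        math.exp(logit_yes / tau) + math.exp(logit_no / tau)
    )
    # Probability left over for the "unknown" class.
    p_unknown = 1.0 - (1.0 - p_no) * p_yes
    # ID (1) if the class probability beats the unknown-class probability.
    return p_yes, p_unknown, int(p_yes > p_unknown)


# Hypothetical logits, including cases where the "no" logit is clearly larger.
for logit_yes, logit_no in [(0.30, 0.10), (0.10, 0.30), (-0.50, 0.90)]:
    print(atd_single_class(logit_yes, logit_no))
```

In each of these cases the decision is ID, even when the "no" logit is much larger than the "yes" logit, which is why I suspect ATD cannot reject anything when there is a single class.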
The first thing that came to mind was to find the index of the maximum value in logits and check whether logits[index] > logits_no[index] to decide between ID and OOD. However, I suppose this is mathematically incorrect, since you didn't mention it in your paper, and the tests I ran also gave bad results.
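Thinking about it again, with the "no" probability written as in my CTW function, this direct comparison should give the same decision as CTW: 1 - p_no > p_no means p_no < 0.5, which holds when exp(no / tau) < exp(yes / tau), that is, when yes > no, for any positive tau. The sketch below checks this numerically on a small grid. The function names and the grid values are only illustrative, and the formula is my reading of the paper, not something I have confirmed.

```python
import math


def ctw_is_id(logit_yes, logit_no, tau=1.0):
    """CTW-style decision for the top class: ID if the "yes" side wins."""
    p_no = math.exp(logit_no / tau) / (
        math.exp(logit_yes / tau) + math.exp(logit_no / tau)
    )
    return 1 - p_no > p_no


def direct_is_id(logit_yes, logit_no):
    """Direct comparison of the two logits of the top class."""
    return logit_yes > logit_no


# Hypothetical cosine-similarity-like logits and temperatures.
values = [-1.0, -0.5, 0.0, 0.25, 0.5, 1.0]
taus = [0.01, 0.1, 1.0]

mismatches = sum(
    ctw_is_id(y, n, tau) != direct_is_id(y, n)
    for y in values
    for n in values
    for tau in taus
)
print("mismatches:", mismatches)
```

If the two checks really are equivalent, then the bad results I got with the direct comparison and with CTW would share the same cause rather than being two separate problems.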
Here are the functions I wrote for ATD and CTW, based on my understanding of your paper. They are rough, as this is a work in progress. I used the code in the "handcrafted" folder; from what I understood, this is the one to use with custom prompts rather than learned ones.
Both of them take the logits and logits_no computed this way:

    logits = F.normalize(feat, dim=-1, p=2) @ fc_yes.T
    logits_no = F.normalize(feat, dim=-1, p=2) @ fc_no.T

They also take a tau parameter, which I have set to 1 for now.
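One thing I am unsure about is whether tau = 1 is a reasonable choice. Since the features are L2-normalized, the logits are presumably cosine similarities in [-1, 1] (assuming fc_yes and fc_no are normalized as well), and a softmax with tau = 1 over such values is almost uniform. Below is a minimal sketch with made-up similarities; tau = 0.01 is my own assumption, based on the logit scale of roughly 100 that CLIP-style models typically use, not a value taken from your code.

```python
import math


def softmax(logits, tau):
    """Softmax with temperature tau, computed in a numerically stable way."""
    m = max(logits)  # subtracting the max does not change the result
    exps = [math.exp((x - m) / tau) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical cosine similarities for "hat", "cap", "helmet".
cos_sims = [0.28, 0.24, 0.21]

for tau in (1.0, 0.01):
    probs = softmax(cos_sims, tau)
    print(f"tau={tau}:", [round(p, 3) for p in probs])
```

With tau = 1 the three probabilities all stay close to 1/3, so the unknown-class term in my ATD function can easily exceed every one of them, which might explain why everything comes out as OOD.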

    import math


    def CTW(logits_yes, logits_no, tau):
        yes = logits_yes[0].detach().tolist()
        no = logits_no[0].detach().tolist()
        # pij: softmax over the "yes" logits with temperature tau
        exp_yes = [math.exp(y / tau) for y in yes]
        denominator = sum(exp_yes)
        pij = [e / denominator for e in exp_yes]
        # pijno: per-class "no" probability against the matching "yes" logit
        pijno = [
            math.exp(n / tau) / (math.exp(y / tau) + math.exp(n / tau))
            for y, n in zip(yes, no)
        ]
        index = pij.index(max(pij))
        bestood = pijno[index]
        # ID if the "yes" side wins for the most likely class
        return (index, 1 - bestood > bestood)


    def ATD(logits_yes, logits_no, tau):
        yes = logits_yes[0].detach().tolist()
        no = logits_no[0].detach().tolist()
        # pijno: per-class "no" probability against the matching "yes" logit
        pijno = [
            math.exp(n / tau) / (math.exp(y / tau) + math.exp(n / tau))
            for y, n in zip(yes, no)
        ]
        # pij: softmax over the "yes" logits with temperature tau
        exp_yes = [math.exp(y / tau) for y in yes]
        denominator = sum(exp_yes)
        pij = [e / denominator for e in exp_yes]
        index = pij.index(max(pij))
        # ood: probability mass left for the unknown class
        ood = 1.0
        for pno, p in zip(pijno, pij):
            ood -= (1 - pno) * p
        # ID (1) if any class probability exceeds the unknown-class probability
        res = 1 if any(p > ood for p in pij) else 0
        return (index, res)

The return value is 1 if it's an ID and 0 otherwise.
The model is in eval mode, and I use the process_test function returned by load_model() to preprocess the images, which I load with PIL's Image.open().
So I don't know whether I did something wrong or whether I "just" need to retrain the model.
Thanks for your help!