EDIT: I didn't set things up for perfect reproducibility and those who have re-run have seen lower scores in practice. So take the 0.064 with a pinch of salt until I can get something consistent. Also, all leaderboard scores are curiously low and I see lots of discussion around this - another thing on the 'to-investigate' list :)
Tried a different approach and I'm pretty happy with the results so far.
Blog post: https://datasciencecastnet.home.blog/2022/02/18/turtle-recall-a-contrastive-learning-approach/
Notebook: https://colab.research.google.com/drive/1AZkjlJ3oPUL-nZ03PU4d_jQL2yiKNHJp?usp=sharing
I'd love feedback! Next week (when I've had more than a day to get used to jax etc) I want to go back and improve the blog post and also hopefully make a video, so any questions or comments in the next few days will result in a better final result.
Hope this is useful, good luck everyone,
J
Glad simclr is giving good performance, was hesitant on following with similar approaches. Thanks for sharing your results :)
Nobody has been able to correlate the local validation with the leaderboard (as far as I know). Basically because the leaderboard metric is wrongly calculated.
Do you know what local validation value you had?
Thanks for sharing!
Hehe I left local validation as an exercise for the reader! First day on a comp I tend to cheat and rely on the lb score. Perhaps they're just using the total number of test items in the calc but only counting hits from the much smaller public set?
Nice, I've eyeballed your code and I suspect you have a tiny mistake in your loss function (but it's not on your post-it note, so maybe you're aware) which may make a small difference.
In your denominator, you should also exclude pos_sample_indicators and not sum over all.
Unless ... are you excluding it by '... * neg_sample_indicators'?
Nice work J
I thought about it, but I think the goal is to exclude the identity (i.e. images paired with themselves) while including the pairs highlighted by pos_sample_indicators alongside all other non-identity values in the denominator.
At the end of the notebook I visualize the sample_indicator matrices, and that's what I was going for. But I agree, my gut feeling reading the formula and approach was that we'd surely want to exclude the pos ones from the denominator as well. I even asked some other smart cookies about it and got 'shrug, that's how SimCLR did it' as the response haha. Probably works fine either way, especially since this is generally done for large batch sizes.
Thanks for the feedback though I really appreciate some sharp eyes on the code! And I could well be wrong on this 😂
Thanks J!
My reading of the formula is definitely to exclude the pos images in the denominator, but it will not change the contours of the loss function, only shift it from being log x to being log(1 + x), the latter if you include them. Note however that the latter has a much lower slope, so you'd be doing your poor optimiser a favour by excluding.
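To make the log x vs log(1 + x) point concrete, here's a tiny NumPy sketch (toy similarity values, purely illustrative, not from the notebook) comparing the two denominator choices:

```python
import numpy as np

# Toy numbers for one anchor image: one positive pair and three negatives,
# with temperature tau. The similarity values here are made up.
tau = 0.5
pos = np.exp(0.9 / tau)                          # exp(sim / tau) for the positive pair
negs = np.exp(np.array([0.1, -0.2, 0.3]) / tau)  # same, for the negatives

# SimCLR-style: positive included in the denominator -> log(1 + n/p)
loss_incl = -np.log(pos / (pos + negs.sum()))
# Variant discussed above: positive excluded -> log(n/p)
loss_excl = -np.log(pos / negs.sum())

# Same contours, just shifted: log(1 + x) vs log(x) with x = n/p
assert np.isclose(loss_incl, np.log1p(negs.sum() / pos))
assert np.isclose(loss_excl, np.log(negs.sum() / pos))
```

With the positive included, the loss stays non-negative; excluded, it can go negative, and the slope of log(x) is steeper than log(1 + x) for any x > 0.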
Also, just a sense (which I may play with given *your* code) is that a nice thing to calibrate here is the loss function. You could toy with the similarity score, perhaps using some ordinal or binary measure, or, relatedly, make some big assumptions and use a Pareto or Zipf type assumption. I think I'll try something like that; it will probably break a few rules and get me punished for it, but what the heck.
So .... how is JAX + Haiku working for you?
Good point.
Please play away! That's the goal :) Interesting ideas for sure. I wonder if the test set is *evenly balanced* or has roughly the same distribution as the train...
My favourite thing about JAX so far is seeing the estimated time left (via tqdm etc) drop dramatically after the first iteration as the JIT magic does its thing! More seriously, I like the 'everything becomes a transform' style with haiku. I want to read more code examples and get used to it, but it feels like it pushes me towards good code.
He he, reason I want to play with loss is because sample is so small, so maybe something e.g. binary that punishes heavily can help.
Nikita Churkin did some really amazing work on distro shift between test and train. One day I'm going to use that approach for some comp, especially since it involves differential evolution which I luvvv.
Fnoa probed the LB a bit and if I remember he got 2 new turtles.
Haiku pushing you - code is art, don't let it push you around, there's enough other things doing that.
resnet weights?
Hmmmmm - so those imagenet weights, did you get them from Haiku somewhere? Never used Haiku before, not sure if I should even worry. Those are easy to get with keras but Haiku? I've no idea?
Oh wait a bit, I see - those weights are provided by the host.
Did I miss the announcement? Are there any other interesting files there that we should be aware of or is the host assisting you with this example?
I also wasn't sure of the best place to get pretrained weights with Haiku but the tutorial notebook came to the rescue - the hosts presumably shared them to make that getting started notebook easier to run.
Ok, yes I looked everywhere except the starter!
Thanks for sharing Johnowhitaker.
I ran the model and obtained a much lower score of 0.026, so I'm concerned about the reproducibility of the reported 0.064. I noticed you used a random key of 42, but I'm not sure why the score isn't reproducible. Is there some non-deterministic force at play here? Any strategies to ensure maximum model reproducibility?
Did you do the yoga?
The code is probably doing more yoga than myself at the moment 😂.
Hmm looking through I rely on numpy (not jax.numpy) for some random splits and such, which I didn't set the seed for. Sorry about that! Will try to remedy next week.
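For what it's worth, a seeded NumPy split could look something like this — `make_split` is a hypothetical helper for illustration, not the notebook's actual code:

```python
import numpy as np

SEED = 42

def make_split(n_items, val_frac=0.2, seed=SEED):
    # Use a dedicated Generator rather than the global np.random state,
    # so the split is reproducible regardless of other random calls.
    rng = np.random.default_rng(seed)
    idx = rng.permutation(n_items)
    n_val = int(n_items * val_frac)
    return idx[n_val:], idx[:n_val]  # train indices, val indices

# Same seed -> identical split on every run.
train_a, val_a = make_split(100)
train_b, val_b = make_split(100)
assert (train_a == train_b).all() and (val_a == val_b).all()
```

The JAX side is already keyed, so seeding the stray np.random calls should be the remaining piece.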
My new hobby - running the models of tomorrow on the computers of yesterday.
I have to make some changes, no GPU and old computer, if I follow your yoga advice I'll be doing vertical splits by Tuesday.
There is some relationship between the number of embeddings and the batch size that I've not quite figured out, but as is, the code fails when you change one without a matching change in the other.
Since you reward similarity and punish dissimilarity, as mentioned earlier I hope to toy a bit with the way in which you calculate (or rank) 'similaritiness', but changing cosine might also break the optimiser.
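For anyone wanting to experiment with the similarity measure, a pairwise cosine-similarity matrix over a batch of embeddings looks roughly like this (plain NumPy for readability; the notebook's actual implementation may differ):

```python
import numpy as np

def cosine_similarity_matrix(emb, eps=1e-8):
    # Normalise each embedding to unit length; a single matmul then
    # gives all pairwise cosine similarities for the batch.
    norms = np.linalg.norm(emb, axis=1, keepdims=True)
    unit = emb / np.maximum(norms, eps)
    return unit @ unit.T

emb = np.array([[1.0, 0.0], [0.0, 2.0], [3.0, 0.0]])
sims = cosine_similarity_matrix(emb)
assert np.isclose(sims[0, 2], 1.0)  # parallel vectors -> similarity 1
assert np.isclose(sims[0, 1], 0.0)  # orthogonal vectors -> similarity 0
```

Swapping this for a ranked or binarised measure would change the gradients the optimiser sees, so it's worth watching training stability if you try it.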
Looking at the LB, it seems your code is gaining traction :-)