MTP Integration: Unexpectedly High Loss with Loaded Weights
Hey,
I am currently working on integrating MTP (multi-token prediction) into MaxText. To verify the implementation, I loaded the open MTP weights and ran a few training steps on the C4 dataset. The main model's loss on DeepSeek-V3 is low, around 2.5, but the MTP module's loss is over 12. Since the MTP weights are already trained, I expected the MTP loss to be about the same as the main model's loss, or at least noticeably below 12.
I am not sure whether my expectation is right, i.e., whether the MTP module with the loaded weights should already give a low loss on C4. Is this expected behavior, or is something wrong with my code? Any help would be appreciated.
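For anyone trying to reproduce this, here is a minimal JAX sketch of the per-depth MTP loss as I understand it from the DeepSeek-V3 description. It is not my actual MaxText code; `mtp_depth_loss` and `uniform_baseline_loss` are hypothetical helpers. It assumes the depth-k module's logits at position i predict token i + k + 1, so the targets are shifted by k + 1. It also compares the loss against ln(V), the cross-entropy of a model that predicts uniformly over the vocabulary. With the vocabulary size I am assuming for DeepSeek-V3 (129,280), that baseline is about 11.77, so a loss over 12 would be no better than uniform guessing.

```python
import math

import jax
import jax.numpy as jnp


def uniform_baseline_loss(vocab_size: int) -> float:
    """Cross-entropy of a model that assigns equal probability to every token."""
    return math.log(vocab_size)


def mtp_depth_loss(mtp_logits: jnp.ndarray, tokens: jnp.ndarray, depth: int) -> jnp.ndarray:
    """Mean cross-entropy for a hypothetical MTP module at `depth` (1-based).

    Assumption: logits at position i of the depth-k module predict token
    i + k + 1, so targets are shifted by depth + 1 and the last depth + 1
    positions have no target and are dropped.

    mtp_logits: [batch, seq, vocab] float logits.
    tokens:     [batch, seq] integer token ids (the same sequence the main model sees).
    """
    shift = depth + 1
    logits = mtp_logits[:, :-shift, :]  # positions that still have a target
    targets = tokens[:, shift:]         # token i + depth + 1 for each position i
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Pick out the log-probability assigned to each target token.
    nll = -jnp.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
    return nll.mean()
```

If the targets are shifted by the wrong amount (for example, reusing the main model's shift of 1 for a depth-1 MTP module), even perfectly trained logits score far above the uniform baseline. Checking the shift is a cheap first step before suspecting the loaded weights.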