@garyzhang99
Collaborator

Description

Implement the corrected KL estimator with importance sampling, as described in the DeepseekV3.2 paper.

Checklist

Please check the following items before code is ready to be reviewed.

  • Code has passed all tests
  • Docstrings have been added/updated in Google Style
  • Documentation has been updated
  • Code is ready for review

@gemini-code-assist
Contributor

Summary of Changes

Hello @garyzhang99, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the KL divergence calculation framework by integrating a "corrected K3" function that leverages importance sampling. This modification, inspired by the DeepseekV3.2 paper, allows for more robust and accurate KL divergence estimation, particularly when comparing current policy log probabilities against a reference and an old policy. The changes involve extending the KL function API, implementing the new corrected K3 logic, and updating the training actors to pass the necessary old policy log probabilities.
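
For reference, the estimator described above can be written out as follows. This is a sketch based on a reading of the importance-sampled K3 estimate in the DeepseekV3.2 paper, not a transcription of the code in this pull request. The symbols are notation introduced here: \pi_\theta is the current policy, \pi_{\mathrm{old}} is the old policy that generated the samples, \pi_{\mathrm{ref}} is the reference policy, q is the prompt, and o_t is the token at position t.

```latex
\hat{D}_{\mathrm{KL}}\left(\pi_\theta \,\|\, \pi_{\mathrm{ref}}\right)
= \frac{\pi_\theta(o_t \mid q, o_{<t})}{\pi_{\mathrm{old}}(o_t \mid q, o_{<t})}
\left(
  \frac{\pi_{\mathrm{ref}}(o_t \mid q, o_{<t})}{\pi_\theta(o_t \mid q, o_{<t})}
  - \log \frac{\pi_{\mathrm{ref}}(o_t \mid q, o_{<t})}{\pi_\theta(o_t \mid q, o_{<t})}
  - 1
\right)
```

The bracketed term is the standard K3 estimate. The leading ratio reweights it because the tokens were sampled from the old policy rather than the current one; when the two policies coincide, the ratio is 1 and the expression reduces to plain K3.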

Highlights

  • New KL Function: Introduced CorrectedK3Fn for KL divergence calculation with importance sampling, based on the DeepseekV3.2 paper. This new function provides a more robust KL estimation by incorporating an 'old policy' log probability.
  • API Extension for Importance Sampling: The calculate_kl and calculate_kl_loss methods in the base KLFn class and all its existing implementations (DummyKLFn, K1Fn, K2Fn, K3Fn, LowVarKLFn, AbsFn) now accept an optional old_logprob parameter to facilitate importance sampling.
  • Integration into Training Actors: Updated dp_actor.py and megatron_actor.py to pass the old_logprob parameter to the kl_fn.calculate_kl_loss method, enabling the use of importance sampling in the training loop.
  • Comprehensive Unit Testing: Added a new test file tests/algorithm/kl_fn_test.py with extensive unit tests for the new corrected_k3 function, covering its fallback mechanism to standard K3 when old_logprob is absent, its behavior with old_logprob, and various loss aggregation modes.
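
The behavior listed in these highlights can be sketched per token in plain Python. This is a minimal illustration under stated assumptions, not the merged implementation: the function names `k3` and `corrected_k3` are hypothetical stand-ins for `K3Fn` and `CorrectedK3Fn`, the inputs are scalar log-probabilities rather than tensors, and the importance ratio is left unclipped (the actual code may clamp it).

```python
import math
from typing import Optional


def k3(logprob: float, ref_logprob: float) -> float:
    """Standard K3 estimate for one token: exp(d) - d - 1, with d = ref - current."""
    d = ref_logprob - logprob
    return math.exp(d) - d - 1.0


def corrected_k3(
    logprob: float,
    ref_logprob: float,
    old_logprob: Optional[float] = None,
) -> float:
    """K3 reweighted by the importance ratio pi_theta / pi_old (hypothetical helper)."""
    base = k3(logprob, ref_logprob)
    if old_logprob is None:
        # No old-policy log-probability supplied: fall back to plain K3.
        return base
    # Assumption: unclipped ratio; a real implementation may clamp it for stability.
    ratio = math.exp(logprob - old_logprob)
    return ratio * base


if __name__ == "__main__":
    print(corrected_k3(-1.0, -1.5))        # fallback path, same value as k3(-1.0, -1.5)
    print(corrected_k3(-1.0, -1.5, -1.2))  # K3 scaled by exp(-1.0 - (-1.2))
```

Making `old_logprob` optional mirrors the API extension described above: callers that do not pass it get the existing K3 value, so the other KL functions can accept the parameter and ignore it.
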

Contributor

@gemini-code-assist bot left a comment

Code Review

This pull request introduces a corrected KL divergence calculation using importance sampling, CorrectedK3Fn. The changes are well-structured: the new logic is encapsulated in its own class, the base KLFn and its subclasses are updated for a consistent interface, and the trainers are modified to pass the necessary old_logprob. The addition of a comprehensive test suite is excellent, covering new functionality, fallback behavior, and edge cases, which increases confidence in the implementation's correctness. My main feedback is a minor suggestion to improve code clarity by replacing hardcoded values with named constants. Overall, this is a high-quality contribution.

@garyzhang99
Collaborator Author

/unittest-module-algorithm

@github-actions

github-actions bot commented Dec 2, 2025

Summary

| Tests 📝 | Passed ✅ | Failed ❌ | Skipped ⏭️ | Other ❓ | Flaky 🍂 | Duration ⏱️ |
| --- | --- | --- | --- | --- | --- | --- |
| 25 | 22 | 3 | 0 | 0 | 0 | 18.3s |

Failed Tests

| Failed Tests ❌ | Fail Message |
| --- | --- |
| ❌ tests/algorithm/kl_fn_test.py::KLFnTest::test_corrected_k3_fallback | The test failed in the call phase due to an assertion error |
| ❌ tests/algorithm/kl_fn_test.py::KLFnTest::test_corrected_k3_same_policy | The test failed in the call phase due to an assertion error |
| ❌ tests/algorithm/kl_fn_test.py::KLFnTest::test_kl_loss_aggregation_modes | The test failed in the call phase |

Tests

| Test Name | Status | Flaky | Duration |
| --- | --- | --- | --- |
| tests/algorithm/advantage_fn_test.py::TestGroupedAdvantageFn::test_batch_level_std_grpo | ✅ | | 41ms |
| tests/algorithm/advantage_fn_test.py::TestGroupedAdvantageFn::test_batch_level_step_wise_grpo_advantage | ✅ | | 2ms |
| tests/algorithm/advantage_fn_test.py::TestGroupedAdvantageFn::test_duplicate_grpo | ✅ | | 5ms |
| tests/algorithm/advantage_fn_test.py::TestGroupedAdvantageFn::test_grpo_advantage | ✅ | | 3ms |
| tests/algorithm/advantage_fn_test.py::TestGroupedAdvantageFn::test_grpo_correct_bias | ✅ | | 2ms |
| tests/algorithm/advantage_fn_test.py::TestGroupedAdvantageFn::test_grpo_reward_std | ✅ | | 1ms |
| tests/algorithm/advantage_fn_test.py::TestGroupedAdvantageFn::test_step_wise_grpo_advantage | ✅ | | 2ms |
| tests/algorithm/advantage_fn_test.py::TestGroupedAdvantageFn::test_step_wise_grpo_with_std_threshold | ✅ | | 2ms |
| tests/algorithm/kl_fn_test.py::KLFnTest::test_abs_kl_fn | ✅ | | 1ms |
| tests/algorithm/kl_fn_test.py::KLFnTest::test_corrected_k3_fallback | ❌ | | 1ms |
| tests/algorithm/kl_fn_test.py::KLFnTest::test_corrected_k3_loss | ✅ | | 1ms |
| tests/algorithm/kl_fn_test.py::KLFnTest::test_corrected_k3_same_policy | ❌ | | 1ms |
| tests/algorithm/kl_fn_test.py::KLFnTest::test_corrected_k3_with_old_logprob | ✅ | | 1ms |
| tests/algorithm/kl_fn_test.py::KLFnTest::test_dummy_kl_fn | ✅ | | 1ms |
| tests/algorithm/kl_fn_test.py::KLFnTest::test_k1_kl_fn | ✅ | | 1ms |
| tests/algorithm/kl_fn_test.py::KLFnTest::test_k2_kl_fn | ✅ | | 1ms |
| tests/algorithm/kl_fn_test.py::KLFnTest::test_k3_kl_fn | ✅ | | 1ms |
| tests/algorithm/kl_fn_test.py::KLFnTest::test_kl_loss_aggregation_modes | ❌ | | 1ms |
| tests/algorithm/kl_fn_test.py::KLFnTest::test_low_var_kl_fn | ✅ | | 1ms |
| tests/algorithm/policy_loss_test.py::VerlPolicyLossTest::test_dpo_policy_loss | ✅ | | 1ms |
| tests/algorithm/policy_loss_test.py::VerlPolicyLossTest::test_gspo_policy_loss | ✅ | | 1ms |
| tests/algorithm/policy_loss_test.py::VerlPolicyLossTest::test_mix_policy_loss | ✅ | | 1ms |
| tests/algorithm/policy_loss_test.py::VerlPolicyLossTest::test_opmd_policy_loss | ✅ | | 1ms |
| tests/algorithm/policy_loss_test.py::VerlPolicyLossTest::test_ppo_policy_loss | ✅ | | 1ms |
| tests/algorithm/policy_loss_test.py::VerlPolicyLossTest::test_sft_policy_loss | ✅ | | 1ms |

Github Test Reporter by CTRF 💚

@garyzhang99
Collaborator Author

/unittest-module-algorithm

@github-actions

github-actions bot commented Dec 2, 2025

Summary

| Tests 📝 | Passed ✅ | Failed ❌ | Skipped ⏭️ | Other ❓ | Flaky 🍂 | Duration ⏱️ |
| --- | --- | --- | --- | --- | --- | --- |
| 25 | 25 | 0 | 0 | 0 | 0 | 18.1s |

Tests

| Test Name | Status | Flaky | Duration |
| --- | --- | --- | --- |
| tests/algorithm/advantage_fn_test.py::TestGroupedAdvantageFn::test_batch_level_std_grpo | ✅ | | 42ms |
| tests/algorithm/advantage_fn_test.py::TestGroupedAdvantageFn::test_batch_level_step_wise_grpo_advantage | ✅ | | 2ms |
| tests/algorithm/advantage_fn_test.py::TestGroupedAdvantageFn::test_duplicate_grpo | ✅ | | 5ms |
| tests/algorithm/advantage_fn_test.py::TestGroupedAdvantageFn::test_grpo_advantage | ✅ | | 3ms |
| tests/algorithm/advantage_fn_test.py::TestGroupedAdvantageFn::test_grpo_correct_bias | ✅ | | 2ms |
| tests/algorithm/advantage_fn_test.py::TestGroupedAdvantageFn::test_grpo_reward_std | ✅ | | 1ms |
| tests/algorithm/advantage_fn_test.py::TestGroupedAdvantageFn::test_step_wise_grpo_advantage | ✅ | | 2ms |
| tests/algorithm/advantage_fn_test.py::TestGroupedAdvantageFn::test_step_wise_grpo_with_std_threshold | ✅ | | 2ms |
| tests/algorithm/kl_fn_test.py::KLFnTest::test_abs_kl_fn | ✅ | | 1ms |
| tests/algorithm/kl_fn_test.py::KLFnTest::test_corrected_k3_fallback | ✅ | | 1ms |
| tests/algorithm/kl_fn_test.py::KLFnTest::test_corrected_k3_loss | ✅ | | 1ms |
| tests/algorithm/kl_fn_test.py::KLFnTest::test_corrected_k3_same_policy | ✅ | | 1ms |
| tests/algorithm/kl_fn_test.py::KLFnTest::test_corrected_k3_with_old_logprob | ✅ | | 1ms |
| tests/algorithm/kl_fn_test.py::KLFnTest::test_dummy_kl_fn | ✅ | | 1ms |
| tests/algorithm/kl_fn_test.py::KLFnTest::test_k1_kl_fn | ✅ | | 1ms |
| tests/algorithm/kl_fn_test.py::KLFnTest::test_k2_kl_fn | ✅ | | 1ms |
| tests/algorithm/kl_fn_test.py::KLFnTest::test_k3_kl_fn | ✅ | | 1ms |
| tests/algorithm/kl_fn_test.py::KLFnTest::test_kl_loss_aggregation_modes | ✅ | | 1ms |
| tests/algorithm/kl_fn_test.py::KLFnTest::test_low_var_kl_fn | ✅ | | 1ms |
| tests/algorithm/policy_loss_test.py::VerlPolicyLossTest::test_dpo_policy_loss | ✅ | | 1ms |
| tests/algorithm/policy_loss_test.py::VerlPolicyLossTest::test_gspo_policy_loss | ✅ | | 1ms |
| tests/algorithm/policy_loss_test.py::VerlPolicyLossTest::test_mix_policy_loss | ✅ | | 1ms |
| tests/algorithm/policy_loss_test.py::VerlPolicyLossTest::test_opmd_policy_loss | ✅ | | 1ms |
| tests/algorithm/policy_loss_test.py::VerlPolicyLossTest::test_ppo_policy_loss | ✅ | | 1ms |
| tests/algorithm/policy_loss_test.py::VerlPolicyLossTest::test_sft_policy_loss | ✅ | | 1ms |

@garyzhang99
Collaborator Author

/unittest-module-trainer

@github-actions

github-actions bot commented Dec 2, 2025

Summary

| Tests 📝 | Passed ✅ | Failed ❌ | Skipped ⏭️ | Other ❓ | Flaky 🍂 | Duration ⏱️ |
| --- | --- | --- | --- | --- | --- | --- |
| 22 | 20 | 0 | 2 | 0 | 0 | 42m 38s |

Skipped

| Tests | Status |
| --- | --- |
| tests/trainer/trainer_test.py::TestMultiModalGRPO::test_trainer | skipped ⏭️ |
| tests/trainer/trainer_test.py::TestMultiModalSFT::test_trainer | skipped ⏭️ |

Tests

| Test Name | Status | Flaky | Duration |
| --- | --- | --- | --- |
| tests/trainer/trainer_test.py::TestTrainerCountdown_0_fsdp::test_trainer | ✅ | | 3m 16s |
| tests/trainer/trainer_test.py::TestTrainerCountdown_1_megatron::test_trainer | ✅ | | 4m 45s |
| tests/trainer/trainer_test.py::TestStepAheadAsyncRL::test_trainer | ✅ | | 1m 29s |
| tests/trainer/trainer_test.py::TestTrainerGSM8K_0_fsdp::test_trainer | ✅ | | 1m 19s |
| tests/trainer/trainer_test.py::TestTrainerGSM8K_1_fsdp2::test_trainer | ✅ | | 1m 18s |
| tests/trainer/trainer_test.py::TestTrainerGSM8K_2_fsdp::test_trainer | ✅ | | 1m 24s |
| tests/trainer/trainer_test.py::TestTrainerGSM8K_3_fsdp2::test_trainer | ✅ | | 1m 33s |
| tests/trainer/trainer_test.py::TestTrainerSFTWarmupGSM8K::test_trainer | ✅ | | 2m 31s |
| tests/trainer/trainer_test.py::TestTrainerDPO::test_trainer | ✅ | | 1m |
| tests/trainer/trainer_test.py::TestTrainerSFT::test_trainer | ✅ | | 57.7s |
| tests/trainer/trainer_test.py::TestTrainerToolsSFT::test_trainer_tools | ✅ | | 57.8s |
| tests/trainer/trainer_test.py::TestFullyAsyncMode_0_fsdp::test_fully_async_mode | ✅ | | 1m 55s |
| tests/trainer/trainer_test.py::TestFullyAsyncMode_1_fsdp::test_fully_async_mode | ✅ | | 1m 53s |
| tests/trainer/trainer_test.py::TestFullyAsyncMode_2_megatron::test_fully_async_mode | ✅ | | 2m 37s |
| tests/trainer/trainer_test.py::TestTrainerCheckpointSave_0_fsdp::test_trainer | ✅ | | 2m 19s |
| tests/trainer/trainer_test.py::TestTrainerCheckpointSave_1_megatron::test_trainer | ✅ | | 4m 25s |
| tests/trainer/trainer_test.py::TestTrainerMIX::test_trainer | ✅ | | 2m 28s |
| tests/trainer/trainer_test.py::TestMultiModalGRPO::test_trainer | ⏭️ | | 810ms |
| tests/trainer/trainer_test.py::TestMultiModalSFT::test_trainer | ⏭️ | | 807ms |
| tests/trainer/trainer_test.py::TestTrainerLoRA::test_trainer | ✅ | | 3m 31s |
| tests/trainer/trainer_test.py::TestOverRollout::test_trainer | ✅ | | 1m 22s |
| tests/trainer/trainer_test.py::TestTrainerPromptTruncation::test_trainer | ✅ | | 1m 11s |

@pan-x-c merged commit b9ff286 into modelscope:main on Dec 2, 2025
1 check passed