Reviewer's Guide

This PR enhances the scheduler's memory leak check to account for decode context parallelism (DCP), adds conditional DCP enable/disable logging in the parallel state initializer, and updates the attention kernel to use base-2 exponential and logarithm functions.

Sequence diagram for the DCP-aware memory leak check in the scheduler:

```mermaid
sequenceDiagram
    participant Scheduler
    participant DCP
    Scheduler->>DCP: get_dcp_world_size()
    alt DCP world size > 1
        Scheduler->>Scheduler: real_available_size < real_total_num_tokens?
    else DCP world size == 1
        Scheduler->>Scheduler: real_available_size != real_total_num_tokens?
    end
    Scheduler->>Scheduler: If memory leak, log warning
```
Class diagram for the updated attention kernel math functions:

```mermaid
classDiagram
    class _correct_attn_cp_out_kernel {
        -lse
        -lse_max
        -lse_exp (now uses tl.exp2)
        -lse_acc
        -lse (now uses tl.log2)
        -factor (now uses tl.exp2)
        +output
    }
```
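The switch from `tl.exp`/`tl.log` to `tl.exp2`/`tl.log2` in the kernel rests on a standard change-of-base identity; the base-2 intrinsics typically map to faster hardware instructions on GPUs. A minimal plain-Python sketch of the equivalence (the kernel itself would apply the same identity via Triton's `tl.exp2`/`tl.log2`; this is an illustration, not the kernel code):

```python
import math

# Change-of-base identities behind the exp -> exp2 / log -> log2 rewrite:
#   exp(x) == 2 ** (x * log2(e))
#   log(x) == log2(x) / log2(e)
LOG2_E = math.log2(math.e)

def exp_via_exp2(x: float) -> float:
    # Base-e exponential expressed through a base-2 exponential.
    return 2.0 ** (x * LOG2_E)

def log_via_log2(x: float) -> float:
    # Natural logarithm expressed through a base-2 logarithm.
    return math.log2(x) / LOG2_E
```

Note that correctness of such a rewrite depends on multiplying by `log2(e)` (and dividing for the logarithm); omitting the scale factor silently changes the attention output.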
Class diagram for DCP logging in the parallel state initializer:

```mermaid
classDiagram
    class ParallelState {
        +initialize_model_parallel()
        +get_dcp_size_from_env()
        +get_tensor_model_parallel_rank()
        +logger.info() // logs DCP enabled/disabled based on dcp_size
    }
```
Summary of Changes

Hello @Rythsman, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request focuses on refining the distributed computing aspects, particularly concerning Decode Context Parallel (DCP). It introduces clearer logging for DCP's operational status, adjusts core numerical computations within an attention mechanism, and updates memory leak detection to be more robust and accurate when DCP is in use. The overall aim is to improve the stability, observability, and correctness of the system under parallel execution configurations.
Code Review
This pull request introduces changes related to Decode Context Parallelism (DCP), including logging for its status and adjustments to memory leak detection logic. It also updates attention kernels to use base-2 logarithm functions, likely for performance. My review focuses on improving code clarity and maintainability in the new DCP-related logic.
```python
if decode_context_model_parallel_size > 1:
    if get_tensor_model_parallel_rank() == 0:
        logger.info(f"DCP enabled, dcp_size={decode_context_model_parallel_size}, tp_size={tensor_model_parallel_size}")
else:
    if get_tensor_model_parallel_rank() == 0:
        logger.info(f"DCP disabled, dcp_size={decode_context_model_parallel_size}, tp_size={tensor_model_parallel_size}")
```
The logging logic for DCP status can be simplified to avoid code duplication and improve readability. The check for `get_tensor_model_parallel_rank() == 0` can be performed once, and a conditional expression can be used to determine the status string.
Suggested change:

```diff
-if decode_context_model_parallel_size > 1:
-    if get_tensor_model_parallel_rank() == 0:
-        logger.info(f"DCP enabled, dcp_size={decode_context_model_parallel_size}, tp_size={tensor_model_parallel_size}")
-else:
-    if get_tensor_model_parallel_rank() == 0:
-        logger.info(f"DCP disabled, dcp_size={decode_context_model_parallel_size}, tp_size={tensor_model_parallel_size}")
+if get_tensor_model_parallel_rank() == 0:
+    status = "enabled" if decode_context_model_parallel_size > 1 else "disabled"
+    logger.info(f"DCP {status}, dcp_size={decode_context_model_parallel_size}, tp_size={tensor_model_parallel_size}")
```
```python
real_available_size = available_size + evictable_size
real_total_num_tokens = self.max_total_num_tokens - protected_size
dcp_world_size = get_dcp_world_size()
# TODO(wh): currently, enable_dcp with get more avalibale_size, check later
token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"
if dcp_world_size > 1:
    memory_leak = real_available_size < real_total_num_tokens
else:
    memory_leak = real_available_size != real_total_num_tokens
```
The logic for memory leak detection can be refactored for better clarity. It's good practice to state the general rule first and then handle the exception. This makes the intention of the code more explicit.
Also, there's a typo in the TODO comment: `avalibale_size` should be `available_size`.
Suggested change:

```diff
 real_available_size = available_size + evictable_size
 real_total_num_tokens = self.max_total_num_tokens - protected_size
 dcp_world_size = get_dcp_world_size()
-# TODO(wh): currently, enable_dcp with get more avalibale_size, check later
+# TODO(wh): currently, enable_dcp gets more available_size, check later
 token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"
-if dcp_world_size > 1:
-    memory_leak = real_available_size < real_total_num_tokens
-else:
-    memory_leak = real_available_size != real_total_num_tokens
+memory_leak = real_available_size != real_total_num_tokens
+if memory_leak and dcp_world_size > 1 and real_available_size > real_total_num_tokens:
+    # This is a known issue with DCP, not a leak.
+    memory_leak = False
```
Motivation
Modifications
Accuracy Tests
Benchmarking and Profiling
Checklist
Summary by Sourcery
Introduce DCP-awareness by adjusting memory leak detection logic and adding DCP status logs in model parallel initialization, and update attention kernels to use base-2 exp/log operations
Enhancements: