Overview
A Gating Network (also called a Router) is the decision-making component in Mixture of Experts (MoE) architectures. It examines each input token and produces a probability distribution over available experts, determining which specialists should handle the computation.
How It Works
- Input Processing: The gating network receives the same input representation as the expert networks
- Score Computation: Produces a score for each expert indicating relevance to the current input
- Selection: Applies a softmax over the scores and keeps the top-k experts to activate
- Weight Assignment: Assigns combination weights to selected experts for the final output
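The four steps above can be sketched for a single token. The logit values here are purely illustrative, assuming four experts and top-2 routing:

```python
import torch
import torch.nn.functional as F

# Score Computation: hypothetical router scores for one token over 4 experts
logits = torch.tensor([2.0, 0.5, 1.0, -1.0])

# Selection: keep the two highest-scoring experts
top_vals, top_idx = torch.topk(logits, k=2)   # values [2.0, 1.0], experts [0, 2]

# Weight Assignment: softmax over the selected scores only
weights = F.softmax(top_vals, dim=-1)         # ≈ [0.731, 0.269]
```

Note that the softmax is taken over only the selected scores, so the combination weights of the active experts always sum to 1.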
Routing Strategies
| Strategy | Description | Use Case |
|---|---|---|
| Top-1 | Single expert per token | Maximum efficiency |
| Top-2 | Two experts combined | Balance of diversity and efficiency |
| Top-K | K experts with learned weights | Higher quality, more compute |
| Expert Choice | Experts select tokens | Better load balancing |
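The contrast between token-driven routing (Top-1/Top-2/Top-K) and Expert Choice can be made concrete with score matrices. A minimal sketch, with random scores standing in for real router outputs and a hypothetical capacity of 2 tokens per expert:

```python
import torch

torch.manual_seed(0)
num_tokens, num_experts, capacity = 8, 4, 2

# Token-expert affinity scores (illustrative random values)
scores = torch.randn(num_tokens, num_experts)

# Token choice (Top-1): each token picks its best expert,
# so a popular expert may receive far more than `capacity` tokens
token_choice = scores.argmax(dim=-1)                  # shape (num_tokens,)

# Expert Choice: each expert picks its top-`capacity` tokens,
# so load is balanced by construction
expert_choice = scores.topk(capacity, dim=0).indices  # shape (capacity, num_experts)
```

The trade-off: under Expert Choice some tokens may be selected by no expert at all, which token-driven routing avoids.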
Load Balancing
A key challenge is ensuring all experts receive similar amounts of work:
- Auxiliary Loss: Penalty term encouraging balanced expert utilization
- Capacity Factor: Limits how many tokens each expert can process
- Random Routing: Adds noise to the routing scores during training so tokens occasionally reach less-favored experts, preventing expert collapse
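One common form of the auxiliary loss (the formulation used by the Switch Transformer; shown here as a sketch, not necessarily the source's exact variant) multiplies, per expert, the fraction of tokens routed to it by the mean routing probability it receives:

```python
import torch
import torch.nn.functional as F

def load_balancing_loss(logits, indices, num_experts):
    """Auxiliary loss: num_experts * sum(frac_tokens * mean_prob) per expert.

    logits:  (num_tokens, num_experts) raw router scores
    indices: (num_tokens,) top-1 expert index chosen per token
    """
    probs = F.softmax(logits, dim=-1)
    # Fraction of tokens dispatched to each expert
    frac_tokens = F.one_hot(indices, num_experts).float().mean(dim=0)
    # Mean routing probability assigned to each expert
    mean_probs = probs.mean(dim=0)
    return num_experts * torch.sum(frac_tokens * mean_probs)

# Perfectly balanced routing: uniform probs, tokens spread evenly
balanced = load_balancing_loss(torch.zeros(8, 4), torch.tensor([0, 1, 2, 3] * 2), 4)
# → 1.0; any imbalance pushes the loss above 1.0
```

Because `frac_tokens` comes from a hard argmax, gradients flow only through `mean_probs`; the loss is minimized when both distributions are uniform.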
Implementation
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GatingNetwork(nn.Module):
    def __init__(self, input_dim, num_experts, top_k=2):
        super().__init__()
        self.gate = nn.Linear(input_dim, num_experts)
        self.top_k = top_k

    def forward(self, x):
        logits = self.gate(x)                       # (..., num_experts)
        weights, indices = torch.topk(logits, self.top_k, dim=-1)
        # Renormalize over the selected experts only
        return F.softmax(weights, dim=-1), indices
```
The gating network is typically a simple linear layer or shallow MLP, keeping routing overhead minimal.
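To show how the router's output is consumed downstream, here is a minimal sketch of dispatching tokens and combining expert outputs. The expert networks (plain linear layers) and dimensions are hypothetical stand-ins, not part of the source:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical setup: 4 small experts plus a linear gate
input_dim, num_experts, top_k = 16, 4, 2
experts = nn.ModuleList([nn.Linear(input_dim, input_dim) for _ in range(num_experts)])
gate = nn.Linear(input_dim, num_experts)

x = torch.randn(3, input_dim)                     # 3 tokens
weights, indices = torch.topk(gate(x), top_k, dim=-1)
weights = F.softmax(weights, dim=-1)              # per-token combination weights

# Combine: weighted sum of each token's selected experts' outputs
out = torch.zeros_like(x)
for t in range(x.size(0)):
    for slot in range(top_k):
        e = indices[t, slot].item()
        out[t] += weights[t, slot] * experts[e](x[t])
```

The per-token loop is for clarity; production implementations batch tokens per expert (scatter/gather dispatch) so each expert runs one matrix multiply.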