Exclude norm/bias/scale parameters from weight decay; make it configurable#545
Open
jlamypoirier wants to merge 1 commit into
Open
Exclude norm/bias/scale parameters from weight decay; make it configurable#545jlamypoirier wants to merge 1 commit into
jlamypoirier wants to merge 1 commit into
Conversation
…rable The normalization refactor dropped `weight_decay=False` on norm weights/biases, so they (and several other 1-D parameters) were silently weight-decayed against the usual convention of decaying only weight matrices (#542). Restore the convention as the structural default at each call site — norm weight+bias, all linear biases (incl. causal conv), the SSM `dt_bias`/`A_log` state params, and the init-1 scales (`output_scale`, router scales) — and expose it as a tunable per-parameter knob `ParameterConfig.weight_decay` so the values can be experimented with: `None` keeps the default, `True`/`False` enable/disable, a float sets a specific coefficient. Embeddings and weight matrices stay decayed. `ParameterMeta.param_weight_decay` now carries `float | bool`; the param-group builder maps `True`->global / `False`->0.0 / float->itself, and the meta reorder reduces it to a bool so sequence-parallel buffer contiguity is unchanged. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Claude Opus 4.8 authored this PR.
Closes #542.
Problem
The normalization refactor dropped
weight_decay=Falsewhen creating norm weights/biases, so they fell back to theweight_decay=Truedefault and became subject to AdamW weight decay — contrary to the prior behavior and the usual convention of decaying only weight matrices, not 1-D biases/gains/state parameters. Auditing the otherget_parametercall sites surfaced the same issue in newer code (linear biases, causal-conv bias, SSMdt_bias/A_log, and the init-1 scales).Change
weight_decay=Falsefor: norm weight + bias, RMSNorm weight, all linear biases (AffineLinearConfig+CausalConv1d),dt_bias/A_logingdn/kda(matchingmamba's existingA_log/Dtreatment), and the init-1 scalesoutput_scale/router_scale/router_per_expert_scale. Embeddings, linear/head weights, and conv weights stay decayed.ParameterConfig.weight_decay: float | bool | None:None(default) → keep the parent layer's structural defaultTrue/False→ enable (global value) / disableThe convention follows standard practice (BERT/GPT-2/GPT-3, HF default
no_decay = ['bias', 'LayerNorm.weight']); decaying a gain/bias initialized to 1/0 pulls it toward 0 and fights the normalization. The init-1 scales are treated like norm gains, consistent with recent work on learnable multipliers that excludes such scalars from decay.Plumbing
ParameterMeta.param_weight_decaynow carriesfloat | bool. The param-group builder mapsTrue→global /False→0.0/float→itself; distinct values form distinct optimizer groups for free (the group key already keyed on this value)._reorder_parameter_metasreduces the value to a bool inside its sort key, so the meta ordering — and the sequence-parallel buffer contiguity it guarantees — is byte-identical for the existing bool values.Testing
Verified on CPU:
boolstays bool, float stays float,int→float, no bool↔float ambiguity.A_log/dt_bias/D→ no decay; weights/embeddings/head → decay.The tensor-parallel sequence-parallel path and float-valued optimizer grouping require the GPU suite; the reorder change is behavior-identical for all non-override (bool) usage by construction.
🤖 Generated with Claude Code