Exclude norm/bias/scale parameters from weight decay; make it configurable #545

Open
jlamypoirier wants to merge 1 commit into main from jlp_fix_weight_decay
@jlamypoirier

Collaborator

Claude Opus 4.8 authored this PR.

Closes #542.

Problem

The normalization refactor dropped weight_decay=False when creating norm weights/biases, so they fell back to the weight_decay=True default and became subject to AdamW weight decay — contrary to the prior behavior and the usual convention of decaying only weight matrices, not 1-D biases/gains/state parameters. Auditing the other get_parameter call sites surfaced the same issue in newer code (linear biases, causal-conv bias, SSM dt_bias/A_log, and the init-1 scales).

Change

  1. Restore the convention as the structural default at each call site — pass weight_decay=False for: norm weight + bias, RMSNorm weight, all linear biases (AffineLinearConfig + CausalConv1d), dt_bias/A_log in gdn/kda (matching mamba's existing A_log/D treatment), and the init-1 scales output_scale / router_scale / router_per_expert_scale. Embeddings, linear/head weights, and conv weights stay decayed.
  2. Make it configurable instead of hardcoded, so the values can be experimented with — new per-parameter field ParameterConfig.weight_decay: float | bool | None:
    • None (default) → keep the parent layer's structural default
    • True / False → enable (global value) / disable
    • a float → use that exact coefficient
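The three-way semantics of the new field can be sketched as a small resolution function. This is a minimal illustration, not the code in this PR: `resolve_weight_decay` and its arguments are hypothetical names, and only the `None` / bool / float meanings are taken from the description above.

```python
from typing import Union

WeightDecaySetting = Union[float, bool, None]


def resolve_weight_decay(
    override: WeightDecaySetting, structural_default: bool
) -> Union[float, bool]:
    """Combine a per-parameter override with the parent layer's default.

    Hypothetical helper: None defers to the layer's structural default,
    while an explicit bool or float replaces it.
    """
    if override is None:
        return structural_default
    return override


# A norm gain defaults to "no decay" unless the config says otherwise.
norm_default = resolve_weight_decay(None, structural_default=False)
norm_forced = resolve_weight_decay(True, structural_default=False)
norm_custom = resolve_weight_decay(0.01, structural_default=False)
```

Keeping `None` distinct from `False` is what lets a layer change its structural default later without silently altering configs that never set the field.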

The convention follows standard practice (BERT/GPT-2/GPT-3, HF default no_decay = ['bias', 'LayerNorm.weight']); decaying a gain/bias initialized to 1/0 pulls it toward 0 and fights the normalization. The init-1 scales are treated like norm gains, consistent with recent work on learnable multipliers that excludes such scalars from decay.
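To make the "pulls it toward 0" point concrete, here is a small arithmetic sketch. It assumes decoupled (AdamW-style) decay, where each step multiplies the parameter by `1 - lr * weight_decay`, and it isolates the decay term by pretending the gradient is zero; the hyperparameter values are illustrative only.

```python
def decay_only_trajectory(initial: float, lr: float, weight_decay: float, steps: int) -> float:
    """Apply only the decoupled weight-decay shrinkage for `steps` updates.

    Assumption: zero gradient, so each step reduces to
    value <- value * (1 - lr * weight_decay).
    """
    value = initial
    for _ in range(steps):
        value *= 1.0 - lr * weight_decay
    return value


# A gain initialized to 1 shrinks geometrically when nothing pushes back.
gain_with_decay = decay_only_trajectory(1.0, lr=1e-3, weight_decay=0.1, steps=10_000)
gain_without_decay = decay_only_trajectory(1.0, lr=1e-3, weight_decay=0.0, steps=10_000)
```

With these numbers the decayed gain ends up near `exp(-1)`, so the loss gradient has to keep working against the shrinkage just to hold the gain at its initial value.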

Plumbing

ParameterMeta.param_weight_decay now carries float | bool. The param-group builder maps True→global / False→0.0 / float→itself; distinct values form distinct optimizer groups for free, because the group key is already keyed on this value. _reorder_parameter_metas reduces the value to a bool inside its sort key, so the meta ordering (and the sequence-parallel buffer contiguity it guarantees) is byte-identical for the existing bool values.
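The mapping and grouping described above can be sketched as follows. This is a simplified stand-in rather than the actual param-group builder: `effective_weight_decay` and `build_param_groups` are hypothetical names, parameters are represented by their names only, and the sketch keys groups on the resolved coefficient rather than on the raw setting.

```python
from collections import defaultdict
from typing import Dict, Iterable, List, Tuple, Union


def effective_weight_decay(setting: Union[float, bool], global_weight_decay: float) -> float:
    """Map True -> global value, False -> 0.0, float -> itself."""
    # bool is checked first because bool is a subclass of int in Python.
    if isinstance(setting, bool):
        return global_weight_decay if setting else 0.0
    return float(setting)


def build_param_groups(
    params: Iterable[Tuple[str, Union[float, bool]]], global_weight_decay: float
) -> List[Dict[str, object]]:
    """Bucket parameter names by their resolved weight-decay coefficient."""
    grouped: Dict[float, List[str]] = defaultdict(list)
    for name, setting in params:
        grouped[effective_weight_decay(setting, global_weight_decay)].append(name)
    return [
        {"params": names, "weight_decay": coefficient}
        for coefficient, names in grouped.items()
    ]


groups = build_param_groups(
    [
        ("linear.weight", True),
        ("linear.bias", False),
        ("norm.weight", False),
        ("router_scale", 0.01),
    ],
    global_weight_decay=0.1,
)
```

Each dictionary has the shape that PyTorch optimizers accept as a parameter group, so a float override naturally lands in its own group without extra bookkeeping.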

Testing

Verified on CPU:

  • Config validation: bool stays bool, float stays float, int→float, no bool↔float ambiguity.
  • Built models (gpt_2, mixtral, gdn, kda, mamba): norm/bias/A_log/dt_bias/D → no decay; weights/embeddings/head → decay.
  • Per-parameter override (float + bool) flows through real config parsing and is correctly scoped.
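The first check above (bool stays bool, int becomes float) hinges on one Python detail: `bool` is a subclass of `int`, so the bool test has to run before any numeric coercion. A minimal sketch of such a normalizer, with the hypothetical name `normalize_weight_decay` and no claim to match the real config validation:

```python
from typing import Union


def normalize_weight_decay(value: object) -> Union[float, bool, None]:
    """Normalize a raw config value to None, bool, or float.

    Hypothetical validator: bools and None pass through untouched,
    ints and floats become floats, anything else is rejected.
    """
    # Test bool before int: isinstance(True, int) is True in Python.
    if value is None or isinstance(value, bool):
        return value
    if isinstance(value, (int, float)):
        return float(value)
    raise TypeError(f"weight_decay must be None, bool, or a number, got {type(value).__name__}")


normalized = [normalize_weight_decay(v) for v in (None, True, False, 0, 1, 0.05)]
```

Without the ordering, `True` would be coerced to `1.0` and silently turn "use the global value" into "use a coefficient of 1.0".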

The tensor-parallel sequence-parallel path and float-valued optimizer grouping require the GPU suite; the reorder change is behavior-identical for all non-override (bool) usage by construction.
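The "by construction" argument can be illustrated with a toy reorder. The sketch assumes that reducing the value to a bool means applying `bool(...)` inside the sort key, and that the reorder is a stable sort on that key; the real `_reorder_parameter_metas` may sort on more fields, and the names below are hypothetical.

```python
from typing import List, Tuple, Union

Meta = Tuple[str, Union[float, bool]]


def reorder_old(metas: List[Meta]) -> List[Meta]:
    """Toy version of the previous behaviour: sort on the bool directly."""
    return sorted(metas, key=lambda meta: meta[1])


def reorder_new(metas: List[Meta]) -> List[Meta]:
    """Toy version of the new behaviour: reduce the value to a bool first."""
    return sorted(metas, key=lambda meta: bool(meta[1]))


bool_only = [("w1", True), ("b1", False), ("w2", True), ("norm", False)]
same_for_bools = reorder_old(bool_only) == reorder_new(bool_only)

# With a float override, only the truthiness of the value affects placement.
with_override = reorder_new([("w1", True), ("scale", 0.01), ("b1", False)])
```

Since `bool(x)` returns `x` unchanged for any bool, the two keys agree on every pre-existing value, and a stable sort then yields the same order. Under this assumption a float override of `0.0` would sort alongside the non-decayed parameters.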

🤖 Generated with Claude Code

…rable

The normalization refactor dropped `weight_decay=False` on norm weights/biases,
so they (and several other 1-D parameters) were silently weight-decayed against
the usual convention of decaying only weight matrices (#542).

Restore the convention as the structural default at each call site — norm
weight+bias, all linear biases (incl. causal conv), the SSM `dt_bias`/`A_log`
state params, and the init-1 scales (`output_scale`, router scales) — and expose
it as a tunable per-parameter knob `ParameterConfig.weight_decay` so the values
can be experimented with: `None` keeps the default, `True`/`False` enable/disable,
a float sets a specific coefficient. Embeddings and weight matrices stay decayed.

`ParameterMeta.param_weight_decay` now carries `float | bool`; the param-group
builder maps `True`->global / `False`->0.0 / float->itself, and the meta reorder
reduces it to a bool so sequence-parallel buffer contiguity is unchanged.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

Successfully merging this pull request may close these issues.

Norm weights/biases no longer excluded from weight decay (regression from normalization refactor)
