Skip to content

3 b training prep#452

Merged
le1nux merged 8 commits into
mainfrom
3B_training_prep
Jun 19, 2026
Merged

3 b training prep#452
le1nux merged 8 commits into
mainfrom
3B_training_prep

Conversation

@le1nux

@le1nux le1nux commented Jun 15, 2026

Copy link
Copy Markdown
Member

What does this PR do?

This PR prepares the 3B training path by tightening weight-tying behavior for parallel training and by making DCP checkpoint restores more flexible.

General Changes

  • Add selective AppState component loading so checkpoint restore can load only the model, optimizer, and/or LR scheduler as needed.
  • Thread components_to_load and allow_partial_load through the app-state factory and DCP checkpoint loading path.
  • Add has_tied_word_embeddings model capability checks and centralize tied-embedding validation helpers.
  • Reject tied word embeddings for Tensor Parallelism and Pipeline Parallelism configs.
  • Update the Llama3-like initializer so lm_head is only initialized separately when weight tying is disabled.
  • Expose has_tied_word_embeddings on GPT-2 models and add a default implementation on the base model class.
  • Add tests covering selective checkpoint component loading and tied-embedding validation behavior.

Breaking Changes

  • Tensor Parallelism and Pipeline Parallelism configs now fail validation when tied word embeddings are enabled.
  • DCP app-state loading config now includes allow_partial_load, which changes how partial checkpoint restores can be configured explicitly.

Checklist before submitting final PR

  • My PR is minimal and addresses one issue in isolation
  • I have merged the latest version of the target branch into this feature branch
  • I have reviewed my own code w.r.t. correct implementation, missing type hints, proper documentation, etc.
  • I have run a sample config for model training
  • I have checked that all tests run through (python tests/tests.py)
  • I have updated the internal changelog (CHANGELOG_DEV.md)

@le1nux le1nux marked this pull request as draft June 15, 2026 09:51
… weights separately from the input embedding weights, since they will be tied together and should share the same initialization. The lm head weights will be initialized as part of the input embedding weights initialization, so we can remove the separate initialization for the lm head weights when weight tying is enabled.
@le1nux le1nux marked this pull request as ready for review June 19, 2026 13:49
app_state=self,
state_dict=state_dict[StatefulComponents.OPTIMIZER.value],
)
if self._lr_scheduler is not None and StatefulComponents.LR_SCHEDULER in self._components_to_load:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we raise an error if self._components_to_load contains something unexpected?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added a check and also a test case for this.

Comment thread src/modalities/checkpointing/fsdp/fsdp_checkpoint_loading.py Outdated
Comment thread src/modalities/checkpointing/stateful/app_state_factory.py Outdated
Comment thread src/modalities/config/config.py Outdated
@@ -0,0 +1,137 @@
from unittest.mock import MagicMock

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add a test for invalid combinations of allow_partial_load and components_to_load

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@le1nux le1nux merged commit 8db7d24 into main Jun 19, 2026
3 checks passed
@le1nux le1nux deleted the 3B_training_prep branch June 19, 2026 15:17
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants