Transformer Roadmap
The problem this chapter solves is:
The repository name points toward Transformers, but the current code is a foundation course. This chapter explains exactly how the current objects and morphisms point toward a future attention-based model.
The current code is not a full Transformer.
It teaches the typed pieces you need first:
tokens
vectors
logits
probabilities
loss
training updates
composition
This distinction matters. A roadmap should not pretend the current crate is a production Transformer or a full sequence model. It should show how the current typed skeleton can grow without losing the discipline that made the small examples understandable.
Reader orientation: Read this chapter as an engineering migration plan, not as a promise that the current code already contains every Transformer component.
The source path for this roadmap is:
current Rust pipeline
-> original Transformer architecture
-> implementation-oriented attention tutorials
-> future typed Rust milestones
The original Transformer paper introduced an architecture based on attention instead of recurrence or convolution for sequence transduction. Dive into Deep Learning gives a practical learning path through queries, keys, values, multi-head attention, self-attention, positional encoding, and the full Transformer architecture. Implementation tutorials such as The Annotated Transformer and visual explainers such as The Illustrated Transformer are useful bridges from paper notation to code and diagrams.
Framework documentation such as PyTorch’s attention and Transformer layer APIs is useful for one narrower purpose: checking the public shapes that production tools expose. This chapter does not copy those APIs. It uses them as a sanity check while keeping the teaching path smaller and typed.
The Hugging Face course also gives a useful distinction for this roadmap:
architecture, checkpoint, and model are not the same idea. This repository is
working on architecture pieces: named states, typed boundaries, and update
rules. It is not loading a pretrained checkpoint, and it is not wrapping a
large framework model output. When this chapter uses words such as
HiddenSequence, AttentionWeights, or SequenceLogits, read them as tiny
Rust-owned teaching objects that make the same roles inspectable.
There is also advanced category-theory work that studies attention more directly. One recent source introduces a category-theoretic diagrammatic formalism for decomposing attention mechanisms into anatomical components and comparing attention variants. Another treats the linear query, key, and value maps through a parametric categorical lens and studies how layered self-attention structure can be composed. These sources are useful precision support, but they are not a license to call every part of a Transformer the same categorical object. The parametric-endofunctor paper itself separates its linear focus from nonlinear pieces such as softmax and layer normalization. This roadmap follows the same caution: name the linear maps, product-input boundaries, shape-preserving endomorphisms, and state updates separately.
A broader categorical deep-learning source gives the same warning at the architecture level: a model can be described by constraints it should satisfy and by the implementation that realizes those constraints. This roadmap uses that distinction as a practical rule. Do not treat a compiled Rust boundary as proof that the whole architecture satisfies a mathematical constraint. Do not treat an architecture diagram as a substitute for a concrete type, constructor, example, and test.
This chapter keeps those sources in view, but it does not import their full complexity all at once. The rule is: add one typed concept only when the tiny Rust version can explain its boundary.
Chapter Outcomes
By the end of this chapter, you should be able to:
- trace the attention example from query/key scoring through masking, softmax, value mixing, projection, residual addition, normalization, and feed-forward refinement,
- classify Transformer boundaries by counting inputs before naming morphisms, product-input morphisms, endomorphisms, or illegal attempted compositions,
- separate architecture constraints from implementation evidence in the tiny Rust roadmap.
What You Already Know
If you understand the current prediction path, you already know the skeleton a Transformer will extend. Tokens become vectors, vectors move through typed transformations, and probabilities feed a loss. The future work is to replace the one-token middle with sequence-aware structure.
Transformer Role Ownership Map
Before reading the implementation status table, separate the roles. A Transformer explanation becomes hard when query, key, value, score, weight, mask, and hidden-state roles all look like raw vectors or matrices. This roadmap assigns each role to a named Rust type or boundary.
| Transformer role | Rust owner | Boundary shape | Confusion prevented |
|---|---|---|---|
| hidden state sequence | HiddenSequence | model-width rows over sequence positions | treating one token vector as a full sequence |
| query role | QuerySequence and HiddenToQuery | HiddenSequence -> QuerySequence | passing values where queries are expected |
| key role | KeySequence and HiddenToKey | HiddenSequence -> KeySequence | comparing against value vectors instead of keys |
| value role | ValueSequence and HiddenToValue | HiddenSequence -> ValueSequence | mixing scores directly instead of value vectors |
| raw attention scores | AttentionScores | QuerySequence x KeySequence -> AttentionScores | treating unnormalized scores as probabilities |
| mask | AttentionMask | AttentionScores x AttentionMask -> AttentionScores | allowing illegal positions into softmax |
| normalized attention weights | AttentionWeights | AttentionScores -> AttentionWeights | forgetting that each query row is a distribution over source positions |
| value mixing | AttentionOutput | AttentionWeights x ValueSequence -> AttentionOutput | multiplying weights without saying what information is read |
| multiple heads | AttentionHeadOutputs and MultiHeadOutput | head outputs -> concatenated model-width rows | losing head count and head dimension |
| output projection | ProjectedAttentionOutput | MultiHeadOutput -> ProjectedAttentionOutput | leaving concatenated heads at the wrong width |
| residual boundary | ResidualConnection | HiddenSequence x ProjectedAttentionOutput -> HiddenSequence | adding tensors that cannot return to the block input shape |
| layer normalization | LayerNormalization | HiddenSequence -> HiddenSequence | changing values while accidentally changing the public object |
| feed-forward sublayer | PositionWiseFeedForward | HiddenSequence -> HiddenSequence | forgetting that the sublayer is position-wise and shape-preserving |
| block mask boundary | MaskedMultiHeadTransformerBlock | HiddenSequence x AttentionMask -> HiddenSequence | hiding the mask inside loose optional state |
| sequence readout | TransformerReadout and SequenceLogits | HiddenSequence -> SequenceLogits | confusing hidden states with vocabulary scores |
| training state | TransformerTrainingState | state plus learning rate plus step count | passing loose parameters without optimizer context |
This table is the chapter’s first debugging tool. If a later attention formula feels vague, point to the row that owns the role. The typed roadmap should make the question concrete:
Which object owns this role?
Which boundary produces it?
Which invalid connection should fail?
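The role table can be made concrete with newtypes. The sketch below is a minimal, hypothetical version. Here `QuerySequence`, `KeySequence`, `ValueSequence`, and `AttentionScores` are plain tuple structs over `Vec<Vec<f64>>`, not the crate's validated types, and `score` is an illustrative name. It shows how a function signature answers "which object owns this role?". It also makes one invalid connection, values passed where keys are expected, fail at compile time.

```rust
// Hypothetical stand-ins named after the roadmap roles; the crate's real
// constructors validate shapes and may store data differently.
#[derive(Debug, Clone)]
struct QuerySequence(Vec<Vec<f64>>);
#[derive(Debug, Clone)]
struct KeySequence(Vec<Vec<f64>>);
#[derive(Debug, Clone)]
struct ValueSequence(Vec<Vec<f64>>);
#[derive(Debug, Clone, PartialEq)]
struct AttentionScores(Vec<Vec<f64>>);

/// QuerySequence x KeySequence -> AttentionScores, scaled by 1/sqrt(d).
/// Passing a ValueSequence where a KeySequence is expected is a type error.
fn score(queries: &QuerySequence, keys: &KeySequence) -> AttentionScores {
    let rows: Vec<Vec<f64>> = queries
        .0
        .iter()
        .map(|q| {
            let scale = (q.len() as f64).sqrt();
            keys.0
                .iter()
                .map(|k| q.iter().zip(k).map(|(a, b)| a * b).sum::<f64>() / scale)
                .collect::<Vec<f64>>()
        })
        .collect();
    AttentionScores(rows)
}

fn main() {
    let queries = QuerySequence(vec![vec![1.0, 0.0], vec![0.0, 1.0]]);
    let keys = KeySequence(vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]]);
    let values = ValueSequence(vec![vec![1.0, 10.0], vec![2.0, 20.0], vec![3.0, 30.0]]);
    println!("{:?}", score(&queries, &keys));
    // score(&queries, &values) is rejected: expected `&KeySequence`, found `&ValueSequence`.
    println!("values stay in their own role: {:?}", values);
}
```

All three wrappers hold the same raw representation. Only the type name carries the role, which is exactly the confusion the table is meant to prevent.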
Category Naming Contract
Before this chapter calls an attention boundary an endomorphism, count its inputs. The original Transformer architecture, the query-key-value teaching path, and framework attention APIs all expose the same warning: attention is not one vague arrow from a sequence to itself. Some stages need a query side and a source side. Some stages need a mask. Some stages need the previous hidden stream and a sublayer output.
Use this contract while reading the roadmap:
| If the boundary has shape | Name it as | Example | Do not call it |
|---|---|---|---|
| A -> B | ordinary morphism | AttentionScores -> AttentionWeights | an endomorphism |
| A -> A | endomorphism | LayerNormalization : HiddenSequence -> HiddenSequence | a product boundary |
| A x B -> C | product-input morphism | AttentionWeights x ValueSequence -> AttentionOutput | a unary transform |
| A x B -> A | product-input morphism returning A | HiddenSequence x ProjectedAttentionOutput -> HiddenSequence | an endomorphism unless the whole input object and output object are identical |
| missing projection or wrong role | illegal attempted composition | HiddenSequence x MultiHeadOutput -> HiddenSequence | a clever shortcut |
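One way to see the contract in Rust is to let the signature count the inputs. In this hypothetical sketch, `Endomorphism`, `ProductMorphism`, and `Identity` are illustrative names, not traits or types from the crate. `ResidualConnection` returns the same object it receives on the left, yet it still takes two arguments and can reject operands whose shapes disagree.

```rust
// Hypothetical sketch: the Rust signature makes the input count visible.
#[derive(Debug, Clone, PartialEq)]
struct HiddenSequence(Vec<Vec<f64>>);
#[derive(Debug, Clone, PartialEq)]
struct ProjectedAttentionOutput(Vec<Vec<f64>>);

/// A -> A: one input object and the same output object.
trait Endomorphism<A> {
    fn apply(&self, input: &A) -> A;
}

/// A x B -> C: two input objects. Choosing C = A does not make it unary.
trait ProductMorphism<A, B, C> {
    fn apply(&self, left: &A, right: &B) -> Result<C, String>;
}

/// The identity on HiddenSequence is the simplest honest endomorphism.
struct Identity;

impl Endomorphism<HiddenSequence> for Identity {
    fn apply(&self, input: &HiddenSequence) -> HiddenSequence {
        input.clone()
    }
}

/// HiddenSequence x ProjectedAttentionOutput -> HiddenSequence.
struct ResidualConnection;

impl ProductMorphism<HiddenSequence, ProjectedAttentionOutput, HiddenSequence> for ResidualConnection {
    fn apply(
        &self,
        hidden: &HiddenSequence,
        update: &ProjectedAttentionOutput,
    ) -> Result<HiddenSequence, String> {
        // Addition can only return to the block input object if the shapes agree.
        let same_shape = hidden.0.len() == update.0.len()
            && hidden.0.iter().zip(&update.0).all(|(h, u)| h.len() == u.len());
        if !same_shape {
            return Err("residual operands must share sequence length and model width".to_string());
        }
        let rows: Vec<Vec<f64>> = hidden
            .0
            .iter()
            .zip(&update.0)
            .map(|(h, u)| h.iter().zip(u).map(|(a, b)| a + b).collect::<Vec<f64>>())
            .collect();
        Ok(HiddenSequence(rows))
    }
}

fn main() {
    let hidden = HiddenSequence(vec![vec![1.0, 2.0]]);
    let update = ProjectedAttentionOutput(vec![vec![0.5, -0.5]]);
    println!("{:?}", ResidualConnection.apply(&hidden, &update));
    println!("{:?}", Identity.apply(&hidden));
}
```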
One more context rule matters for learned layers. When this roadmap writes
LayerNormalization : HiddenSequence -> HiddenSequence or
PositionWiseFeedForward : HiddenSequence -> HiddenSequence, it means:
for this fixed layer instance, with its current parameters already stored
inside the Rust object
If the parameters themselves are allowed to vary, name the larger boundary
instead. For example, a parameter-learning story belongs to
TransformerTrainingState -> TransformerTrainingState, not to a hidden
sequence endomorphism that silently changes weights.
This rule keeps the category-theory vocabulary proportional to the code. The linear query, key, value, positional, and layered pieces can be compared with advanced categorical work on self-attention. Masking, softmax, residual addition, normalization, feed-forward refinement, and training updates still need their own typed boundaries in this teaching project.
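The fixed-instance reading can be sketched as follows. The sketch assumes a simplified `LayerNormalization` with per-feature `scale`, `shift`, and `epsilon` fields. It also uses a hypothetical `NormTrainingState` in place of the larger training-state boundary. The forward call only reads stored parameters, and changing a parameter is a separate state-to-state function.

```rust
#[derive(Debug, Clone, PartialEq)]
struct HiddenSequence(Vec<Vec<f64>>);

/// Hypothetical fixed layer instance: scale, shift, and epsilon live inside the value.
#[derive(Debug, Clone)]
struct LayerNormalization {
    scale: Vec<f64>,
    shift: Vec<f64>,
    epsilon: f64,
}

impl LayerNormalization {
    /// HiddenSequence -> HiddenSequence for THIS instance; parameters are read, never written.
    fn apply(&self, input: &HiddenSequence) -> HiddenSequence {
        let rows: Vec<Vec<f64>> = input
            .0
            .iter()
            .map(|row| {
                let n = row.len() as f64;
                let mean = row.iter().sum::<f64>() / n;
                let var = row.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
                let denom = (var + self.epsilon).sqrt();
                row.iter()
                    .enumerate()
                    .map(|(i, x)| self.scale[i] * (x - mean) / denom + self.shift[i])
                    .collect::<Vec<f64>>()
            })
            .collect();
        HiddenSequence(rows)
    }
}

/// Hypothetical training state: changing parameters is a state -> state arrow.
#[derive(Debug, Clone)]
struct NormTrainingState {
    norm: LayerNormalization,
    learning_rate: f64,
    step: usize,
}

/// NormTrainingState -> NormTrainingState: one gradient step on the shift vector only.
fn update_shift(state: NormTrainingState, shift_grad: &[f64]) -> NormTrainingState {
    let mut next = state;
    let lr = next.learning_rate;
    for (s, g) in next.norm.shift.iter_mut().zip(shift_grad) {
        *s -= lr * g;
    }
    next.step += 1;
    next
}

fn main() {
    let norm = LayerNormalization { scale: vec![1.0, 1.0], shift: vec![0.0, 0.0], epsilon: 1e-5 };
    println!("{:?}", norm.apply(&HiddenSequence(vec![vec![1.0, 3.0]])));
    let next = update_shift(NormTrainingState { norm, learning_rate: 0.1, step: 0 }, &[1.0, -1.0]);
    println!("step {} shift {:?}", next.step, next.norm.shift);
}
```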
Fixed-Value Endomorphism Ledger
Use this ledger whenever a roadmap boundary looks like
HiddenSequence -> HiddenSequence. The shape is not enough by itself; the
stored context must also be stable for the forward call.
| Boundary | Fixed value that makes the unary view valid | If that value changes |
|---|---|---|
| PositionalEncoding : HiddenSequence -> HiddenSequence | one position table with fixed row count and model dimension | name the table update or rebuild path separately |
| LayerNormalization : HiddenSequence -> HiddenSequence | one layer-normalization value with fixed scale, shift, and epsilon | move to TransformerTrainingState -> TransformerTrainingState |
| PositionWiseFeedForward : HiddenSequence -> HiddenSequence | one feed-forward value with fixed weights, biases, and activation rule | move to TransformerTrainingState -> TransformerTrainingState |
| MultiHeadTransformerBlock : HiddenSequence -> HiddenSequence | one block value with fixed heads, projections, residual path, normalization, and feed-forward layers | move to TransformerTrainingState -> TransformerTrainingState |
| fixed-mask view of MaskedMultiHeadTransformerBlock | one named AttentionMask selected before the hidden-sequence call | return to HiddenSequence x AttentionMask -> HiddenSequence, or name a larger state carrying the changing mask |
This table rests on the same sources as the precision rules below. Parameter-management references explain why model components own parameters. Optimizer references explain why changing parameters belongs to the training loop. Rust closure references explain the local analogy for fixing a mask before the remaining call.
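The closure analogy for the fixed-mask row can be sketched like this. `masked_block` is a hypothetical stand-in for a masked block. It averages the hidden rows each position may read, which is uniform attention over the allowed positions, not the crate's block. The point is the capture: once `move` takes one mask, the remaining call has a single `HiddenSequence` input.

```rust
// Hypothetical stand-ins; the crate's real types and block internals differ.
#[derive(Debug, Clone, PartialEq)]
struct HiddenSequence(Vec<Vec<f64>>);
#[derive(Debug, Clone)]
struct AttentionMask(Vec<Vec<bool>>);

/// Open product-input boundary: HiddenSequence x AttentionMask -> HiddenSequence.
/// Toy rule: each output row is the mean of the hidden rows its mask row allows.
fn masked_block(hidden: &HiddenSequence, mask: &AttentionMask) -> HiddenSequence {
    let width = hidden.0.first().map_or(0, |r| r.len());
    let rows: Vec<Vec<f64>> = mask
        .0
        .iter()
        .map(|allowed| {
            let mut out = vec![0.0; width];
            let count = allowed.iter().filter(|&&a| a).count() as f64;
            for (row, &ok) in hidden.0.iter().zip(allowed) {
                if ok {
                    for (o, x) in out.iter_mut().zip(row) {
                        *o += x / count;
                    }
                }
            }
            out
        })
        .collect();
    HiddenSequence(rows)
}

fn main() {
    let causal = AttentionMask(vec![vec![true, false], vec![true, true]]);
    // Fixing the mask: the closure captures one mask by move, leaving a unary view.
    let fixed_view = move |h: &HiddenSequence| masked_block(h, &causal);
    let hidden = HiddenSequence(vec![vec![2.0, 0.0], vec![4.0, 2.0]]);
    println!("{:?}", fixed_view(&hidden));
}
```

If the mask must change between calls, the closure is the wrong tool. The open `masked_block` signature, or a larger state that carries the mask, is the honest boundary.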
Add-Norm Order Ledger
Residual addition and layer normalization are not only two names that happen in the same neighborhood. Their order is part of the block boundary.
The original Transformer uses residual addition around each sublayer followed
by layer normalization. Dive into Deep Learning teaches the same AddNorm
shape as residual addition followed by layer normalization. PyTorch exposes the
order as a configurable boundary: TransformerEncoderLayer has norm_first,
where layer normalization can happen before attention and feed-forward
operations instead of after them. Research on layer-normalization placement
also treats the difference between Post-LN and Pre-LN as an optimization
question, not a cosmetic rewrite.
This repository currently teaches the post-add normalization path:
attention sublayer output
-> ResidualConnection
-> attention_norm
feed-forward sublayer output
-> ResidualConnection
-> feed_forward_norm
Use this ledger when reading or extending the block:
| Order question | Source signal | Current Rust reading | Safe category statement |
|---|---|---|---|
| original post-norm shape | original Transformer and D2L AddNorm place normalization after residual addition | ResidualConnection runs before attention_norm and feed_forward_norm | fixed block is still HiddenSequence -> HiddenSequence |
| configurable framework shape | PyTorch norm_first can move normalization before attention and feed-forward operations | no pre-norm block is implemented here yet | a future pre-norm block needs a named constructor or type |
| optimization meaning | layer-normalization placement affects gradient behavior in Transformer training | current tests validate the local post-add path only | same source and target object does not imply same morphism |
| teaching boundary | order is visible in MultiHeadTransformerBlock::apply and MaskedMultiHeadTransformerBlock::apply_with_cache | residual output is normalized before feed-forward runs | do not erase order when explaining composition |
The important category-theory lesson is modest:
post-norm block : HiddenSequence -> HiddenSequence
pre-norm block : HiddenSequence -> HiddenSequence
Those two arrows can have the same source and target while being different morphisms. Shape compatibility permits composition. It does not say the two implementations are interchangeable.
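A small numeric check makes the same-type, different-morphism point concrete. The sketch assumes one row and layer normalization with no learned scale, shift, or epsilon. It also uses a stand-in sublayer that doubles each feature. None of these are the crate's layers.

```rust
/// Layer normalization for one row, without learned scale/shift or epsilon.
/// Assumes the row is not constant, so the variance is non-zero.
fn layer_norm(row: &[f64]) -> Vec<f64> {
    let n = row.len() as f64;
    let mean = row.iter().sum::<f64>() / n;
    let var = row.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
    row.iter().map(|x| (x - mean) / var.sqrt()).collect()
}

/// Stand-in sublayer: a fixed, shape-preserving map (doubles each feature).
fn sublayer(row: &[f64]) -> Vec<f64> {
    row.iter().map(|x| 2.0 * x).collect()
}

fn add(a: &[f64], b: &[f64]) -> Vec<f64> {
    a.iter().zip(b).map(|(x, y)| x + y).collect()
}

/// Post-norm: residual addition first, then normalization (original / AddNorm order).
fn post_norm(row: &[f64]) -> Vec<f64> {
    layer_norm(&add(row, &sublayer(row)))
}

/// Pre-norm: normalize the sublayer input, then add the residual (norm_first-style order).
fn pre_norm(row: &[f64]) -> Vec<f64> {
    add(row, &sublayer(&layer_norm(row)))
}

fn main() {
    let x = [1.0, 3.0];
    println!("post-norm: {:?}", post_norm(&x));
    println!("pre-norm:  {:?}", pre_norm(&x));
}
```

Both functions have type `&[f64] -> Vec<f64>`, yet on `[1.0, 3.0]` they return `[-1.0, 1.0]` and `[-1.0, 5.0]`.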
Source-Backed Precision Rules
Use this table as a citation-to-claim guard while reading the rest of the roadmap. Each source supports a local teaching rule. None of them should be used as a shortcut around the typed Rust boundary.
| Source signal | Local rule in this roadmap | Rust evidence to inspect |
|---|---|---|
| Attention Is All You Need introduces the Transformer around attention instead of recurrence or convolution | treat attention as the architecture target, not as proof that the current crate is a full Transformer | examples/06_attention_scores.rs is a shape lab, not a production model |
| Dive into Deep Learning: Scaled Dot Product Attention writes attention with n queries and m key-value pairs | keep query-side length and source-side length visible before naming the morphism | QuerySequence x KeySequence -> AttentionScores and AttentionWeights x ValueSequence -> AttentionOutput |
| PyTorch MultiheadAttention exposes separate query, key, and value inputs with target length L and source length S | do not collapse self-attention and cross-attention into one vague HiddenSequence -> HiddenSequence arrow | TargetHiddenSequence -> QuerySequence and SourceHiddenSequence -> KeySequence, ValueSequence are the future cross-attention shape |
| PyTorch scaled dot product attention says the attention mask must broadcast to the attention-weight shape, a boolean True means the element participates in attention, and a float mask is added to attention scores | a mask modifies the score table before probability normalization; it is not a token sequence and not attention weights | AttentionScores x AttentionMask -> AttentionScores runs before AttentionScores -> AttentionWeights |
| PyTorch Transformer and PyTorch MultiheadAttention expose mask arguments where boolean True can mean “not allowed” or “ignore this key” | mask shape and mask polarity are separate ideas; translate polarity before comparing APIs with this Rust roadmap | AttentionMask::new(vec![vec![true, false, true], ...]) uses true for “this source position is allowed” |
| TensorFlow Keras MultiHeadAttention uses query shape (B, T, dim), value/key shape (B, S, dim), mask shape (B, T, S), and a boolean attention mask where 1 means attention is allowed | treat target/query length, source/key-value length, and allow-mask polarity as framework-neutral shape evidence | AttentionMask answers which source positions each target position may read |
| PyTorch Transformer building blocks separates dense tensors, nested tensors, masks, scaled dot-product attention, and cross-attention concerns | production masking and variable-length behavior are framework boundary choices; the tiny Rust mask is deliberately stricter | AttentionMask::new rejects a row with no legal keys |
| PyTorch TransformerEncoderLayer exposes norm_first and the original encoder-layer reference shape | residual-normalization order is a named architecture choice, not a detail to hide behind HiddenSequence -> HiddenSequence | MultiHeadTransformerBlock::apply uses post-add normalization today; a future pre-norm variant needs a named boundary |
| On Layer Normalization in the Transformer Architecture distinguishes Post-LN and Pre-LN Transformer variants and studies their training behavior | same source and target object can still mean different morphisms when the internal order changes | local tests validate the current post-add path, not every normalization-order variant |
| Dive into Deep Learning: Parameter Management treats parameters as named model components that can be accessed and updated | a forward sublayer may be an endomorphism only for a fixed layer instance; parameter-changing claims belong to the training-state boundary | LayerNormalization stores scale and shift parameters; TransformerTrainingState owns mutable training context |
| CS231n Neural Networks Part 3 and PyTorch gradcheck compare numerical finite differences with analytical gradients under tolerance and precision caveats | a finite-difference match is local evidence for one selected parameter path, not proof of every gradient, dataset, optimizer, or future training loop | transformer_block_train_step_matches_finite_difference_for_readout_weight, feed-forward, layer-normalization, output-projection, and attention-projection tests |
| Rust Book: Closures explains closures as callable values that can capture values from their surrounding environment | use closure capture as the Rust analogy for fixing a mask context before applying a unary view | a `move` closure that captures one AttentionMask before the HiddenSequence call |
| On the Anatomy of Attention studies attention by decomposing variants into components | decompose attention first, then compare variants | the roadmap names scores, masks, weights, values, heads, projection, residuals, normalization, and feed-forward separately |
| Self-Attention as a Parametric Endofunctor focuses on linear self-attention structure and explicitly separates nonlinear pieces | use “endofunctor” language only after naming the linear scope; do not carry it through softmax, masking, residuals, normalization, or training state without a new argument | HiddenSequence -> QuerySequence is a linear role-producing morphism; AttentionScores x AttentionMask -> AttentionScores is still product-input context |
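The finite-difference row can be illustrated with the smallest possible parameter path. This hypothetical sketch checks the analytical cross-entropy gradient with respect to the logits against a central difference. The analytical gradient is p_i minus the target indicator. As the table warns, agreement here is evidence for this one path only, not for the crate's readout, feed-forward, or attention gradients.

```rust
fn softmax(logits: &[f64]) -> Vec<f64> {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.iter().map(|z| (z - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

fn cross_entropy(logits: &[f64], target: usize) -> f64 {
    -softmax(logits)[target].ln()
}

/// Analytical gradient of cross-entropy with respect to logit i: p_i - [i == target].
fn analytic_grad(logits: &[f64], target: usize, i: usize) -> f64 {
    let indicator = if i == target { 1.0 } else { 0.0 };
    softmax(logits)[i] - indicator
}

/// Central finite difference for one selected logit.
fn numeric_grad(logits: &[f64], target: usize, i: usize, h: f64) -> f64 {
    let mut plus = logits.to_vec();
    let mut minus = logits.to_vec();
    plus[i] += h;
    minus[i] -= h;
    (cross_entropy(&plus, target) - cross_entropy(&minus, target)) / (2.0 * h)
}

fn main() {
    let logits = [0.5, -1.0, 2.0];
    for i in 0..logits.len() {
        println!(
            "i={i} analytic={:.6} numeric={:.6}",
            analytic_grad(&logits, 0, i),
            numeric_grad(&logits, 0, i, 1e-5)
        );
    }
}
```

The step size `h` and the tolerance a caller picks are themselves assumptions. Too large a step adds truncation error, and too small a step adds floating-point cancellation.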
Attention Mental Model Repair Table
Use this table before the first attention example. Each row repairs a tempting mental model with one source-backed rule and one local Rust checkpoint.
| Tempting mental model | Safer model | Rust checkpoint |
|---|---|---|
| query turns into key, then key turns into value | query, key, and value are roles; in self-attention they are parallel projections from the same hidden source, and in cross-attention the query side and key-value side may come from different sequence objects | HiddenSequence -> QuerySequence, HiddenSequence -> KeySequence, and HiddenSequence -> ValueSequence are siblings, not a pipeline |
| raw scores are already attention probabilities | scores become probabilities only after mask handling and row-wise softmax; value mixing happens after weights exist | AttentionScores x AttentionMask -> AttentionScores -> AttentionWeights comes before AttentionWeights x ValueSequence -> AttentionOutput |
| same output shape means endomorphism | count the whole source object first; A x B -> A returns A, but it is still a product-input boundary while B is open | HiddenSequence x ProjectedAttentionOutput -> HiddenSequence is not the same shape as HiddenSequence -> HiddenSequence |
| fixing a mask means the mask disappeared | a fixed-context view is a new named view after one mask has been chosen; the open boundary remains product-input | MaskedMultiHeadTransformerBlock[M] : HiddenSequence -> HiddenSequence is valid only after naming the fixed AttentionMask M |
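The first repair row can be written as three sibling projections out of the same hidden rows. In this sketch, `QkvProjections`, `project`, and the weight layout (model-width rows by output columns) are assumptions for illustration. The crate's `HiddenToQuery`, `HiddenToKey`, and `HiddenToValue` may be organized differently.

```rust
/// Hypothetical weight layout: model_dim rows x out_dim columns.
type Matrix = Vec<Vec<f64>>;

#[derive(Debug, PartialEq)]
struct QuerySequence(Vec<Vec<f64>>);
#[derive(Debug, PartialEq)]
struct KeySequence(Vec<Vec<f64>>);
#[derive(Debug, PartialEq)]
struct ValueSequence(Vec<Vec<f64>>);

/// Multiply each hidden row by one weight matrix.
fn project(hidden: &[Vec<f64>], weight: &Matrix) -> Vec<Vec<f64>> {
    let out_dim = weight.first().map_or(0, |w| w.len());
    hidden
        .iter()
        .map(|row| {
            (0..out_dim)
                .map(|j| row.iter().zip(weight).map(|(x, w)| x * w[j]).sum::<f64>())
                .collect::<Vec<f64>>()
        })
        .collect()
}

struct QkvProjections {
    w_q: Matrix,
    w_k: Matrix,
    w_v: Matrix,
}

impl QkvProjections {
    /// Three sibling arrows out of one hidden sequence, not a pipeline Q -> K -> V.
    fn apply(&self, hidden: &[Vec<f64>]) -> (QuerySequence, KeySequence, ValueSequence) {
        (
            QuerySequence(project(hidden, &self.w_q)),
            KeySequence(project(hidden, &self.w_k)),
            ValueSequence(project(hidden, &self.w_v)),
        )
    }
}

fn main() {
    let p = QkvProjections {
        w_q: vec![vec![1.0, 0.0], vec![0.0, 1.0]],
        w_k: vec![vec![0.0, 1.0], vec![1.0, 0.0]],
        w_v: vec![vec![2.0, 0.0], vec![0.0, 2.0]],
    };
    let (q, k, v) = p.apply(&[vec![1.0, 3.0]]);
    println!("{:?}\n{:?}\n{:?}", q, k, v);
}
```

Each output reads the hidden rows directly, so changing `w_q` cannot change the keys or values.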
The repair pattern is:
bad shortcut -> source-backed role or shape rule -> local Rust boundary
Do this before using a category-theory label. The label should describe the boundary the reader can inspect, not the shortcut the reader is trying to remember.
If a future chapter cites a stronger categorical result, it should add the same three pieces:
source claim
local typed boundary
validation command or test
That keeps the roadmap useful for both readers: the ML reader can see which shape is being implemented, and the category-theory reader can see which formal claim is being used and where it stops.
Worked Example Priority
The roadmap now has many typed attention boundaries. A reader does not need all of them expanded at the same depth on a first pass. Use this priority table to decide which sections deserve worked examples before more implementation is added.
| Priority | Boundary | Why this comes first | Evidence to ask from a reader |
|---|---|---|---|
| 1 | AttentionScores x AttentionMask -> AttentionScores -> AttentionWeights | readers often confuse raw scores, masked scores, and probabilities | Can the reader explain which positions were removed before softmax? |
| 2 | HiddenSequence -> QuerySequence, KeySequence, ValueSequence | query, key, and value are numerically similar but semantically different roles | Can the reader say which role asks, which role is compared, and which role is mixed? |
| 3 | AttentionWeights x ValueSequence -> AttentionOutput | attention becomes useful only when weights read values | Can the reader trace one output row as a weighted sum of value rows? |
| 4 | AttentionHeadOutputs -> MultiHeadOutput -> ProjectedAttentionOutput | multi-head attention adds shape arithmetic that can hide mistakes | Can the reader compute head_count * head_dimension and name the projection input width? |
| 5 | HiddenSequence x ProjectedAttentionOutput -> HiddenSequence | residual addition explains why many sublayers return to the same object | Can the reader explain why mismatched sequence length or model dimension must fail? |
| 6 | HiddenSequence -> HiddenSequence for normalization and feed-forward | these are shape-preserving sublayers, not new sequence objects | Can the reader name what changes and what stays invariant? |
| 7 | TransformerTrainingState -> TransformerTrainingState | training is important, but it should come after forward shape ownership is clear | Can the reader separate readout-only, local feed-forward, and composed block updates? |
This table is not a ranking of importance. It is a ranking of teaching risk. The first three rows protect the core attention story:
roles -> scores -> masked weights -> mixed values
If a reader cannot trace that path, the later block and training sections will feel like a list of names. If the reader can trace it, residuals, normalization, feed-forward layers, and training state have a stable place to attach.
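Rows 4 and 5 of the priority table can be checked with a few lines of shape arithmetic. In this hypothetical sketch, `concat_heads`, `project`, and `residual_add` work on plain `Vec<Vec<f64>>` rows. Two heads of width 1 concatenate to width `head_count * head_dimension = 2`. That output cannot be added to a width-3 hidden sequence until an output projection maps it to model width.

```rust
type Rows = Vec<Vec<f64>>;

/// AttentionHeadOutputs -> MultiHeadOutput: heads[h][t] is head h's row at position t.
fn concat_heads(heads: &[Rows]) -> Result<Rows, String> {
    let seq_len = heads.first().ok_or("no heads")?.len();
    let head_dim = heads[0].first().map_or(0, |row| row.len());
    let consistent = heads
        .iter()
        .all(|h| h.len() == seq_len && h.iter().all(|row| row.len() == head_dim));
    if !consistent {
        return Err("heads must share sequence length and head dimension".to_string());
    }
    // Output width is head_count * head_dim.
    Ok((0..seq_len)
        .map(|t| heads.iter().flat_map(|h| h[t].iter().copied()).collect::<Vec<f64>>())
        .collect())
}

/// MultiHeadOutput -> ProjectedAttentionOutput: weight is (head_count * head_dim) x model_dim.
fn project(rows: &Rows, weight: &Rows) -> Rows {
    let out_dim = weight.first().map_or(0, |w| w.len());
    rows.iter()
        .map(|row| {
            (0..out_dim)
                .map(|j| row.iter().zip(weight).map(|(x, w)| x * w[j]).sum::<f64>())
                .collect::<Vec<f64>>()
        })
        .collect()
}

/// HiddenSequence x ProjectedAttentionOutput -> HiddenSequence, with the shape check.
fn residual_add(hidden: &Rows, update: &Rows) -> Result<Rows, String> {
    let same_shape = hidden.len() == update.len()
        && hidden.iter().zip(update).all(|(h, u)| h.len() == u.len());
    if !same_shape {
        return Err("sequence length and model width must match".to_string());
    }
    Ok(hidden
        .iter()
        .zip(update)
        .map(|(h, u)| h.iter().zip(u).map(|(a, b)| a + b).collect::<Vec<f64>>())
        .collect())
}

fn main() {
    let heads = vec![vec![vec![1.0], vec![2.0]], vec![vec![10.0], vec![20.0]]];
    let multi = concat_heads(&heads).unwrap();
    let hidden = vec![vec![0.0; 3]; 2];
    println!("{:?}", residual_add(&hidden, &multi)); // width 2 vs 3: rejected
    let w_o = vec![vec![1.0, 0.0, 1.0], vec![0.0, 1.0, 0.0]];
    println!("{:?}", residual_add(&hidden, &project(&multi, &w_o)));
}
```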
Worked Example: Mask Before Softmax
The original Transformer formula and implementation-oriented attention references put softmax after query-key scoring. Practical implementations add the attention mask to the score table before softmax. The reason is simple: only legal positions should compete for probability mass.
The runnable attention example starts with two query positions and three key positions:
```rust
// Runs inside a function that returns Result, because each constructor validates its rows.
let queries = QuerySequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
let keys = KeySequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]])?;
let values = ValueSequence::new(vec![vec![1.0, 10.0], vec![2.0, 20.0], vec![3.0, 30.0]])?;
let mask = AttentionMask::new(vec![vec![true, false, true], vec![true, true, true]])?;
```
For the first query row, scaled dot product produces:
raw scores:
[0.7071, 0.0000, 0.7071]
mask:
[true, false, true]
masked scores:
[0.7071, very negative, 0.7071]
row-wise softmax:
[0.5, 0.0, 0.5]
The middle key position has a real raw score, but the mask says this query position is not allowed to read it. The mask must therefore act before softmax. After softmax, the illegal position would already have received probability mass.
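The order claim can be checked numerically on the first query row. This sketch assumes `-1e9` as the very negative masked score. It compares masking before softmax with zeroing weights after softmax. The second function is the mistaken order, shown only for contrast.

```rust
/// Hypothetical "very negative" finite score for disallowed positions.
const MASKED_SCORE: f64 = -1.0e9;

fn softmax(row: &[f64]) -> Vec<f64> {
    let max = row.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = row.iter().map(|s| (s - max).exp()).collect();
    let total: f64 = exps.iter().sum();
    exps.iter().map(|e| e / total).collect()
}

/// Correct order: mask the scores, then normalize over the legal positions.
fn mask_before(scores: &[f64], allowed: &[bool]) -> Vec<f64> {
    let masked: Vec<f64> = scores
        .iter()
        .zip(allowed)
        .map(|(&s, &ok)| if ok { s } else { MASKED_SCORE })
        .collect();
    softmax(&masked)
}

/// Mistaken order, for contrast: normalize first, then zero illegal weights.
fn mask_after(scores: &[f64], allowed: &[bool]) -> Vec<f64> {
    softmax(scores)
        .iter()
        .zip(allowed)
        .map(|(&p, &ok)| if ok { p } else { 0.0 })
        .collect()
}

fn main() {
    let s = std::f64::consts::FRAC_1_SQRT_2; // 1/sqrt(2), about 0.7071
    let scores = [s, 0.0, s];
    let allowed = [true, false, true];
    let before = mask_before(&scores, &allowed);
    let after = mask_after(&scores, &allowed);
    println!("mask before softmax: {:?} (sum {})", before, before.iter().sum::<f64>());
    println!("mask after softmax:  {:?} (sum {})", after, after.iter().sum::<f64>());
}
```

Masking first keeps the row a distribution over the two legal keys: `[0.5, 0.0, 0.5]`. Zeroing afterwards leaves roughly `[0.401, 0.0, 0.401]`, which sums to about 0.80. The gap exists because the illegal middle key already took part of the probability mass.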
The value-mixing step then reads only the allowed value rows:
0.5 * [1.0, 10.0]
+ 0.0 * [2.0, 20.0]
+ 0.5 * [3.0, 30.0]
= [2.0, 20.0]
That is why the typed path is:
AttentionScores x AttentionMask -> AttentionScores
AttentionScores -> AttentionWeights
AttentionWeights x ValueSequence -> AttentionOutput
The mask boundary is not a cosmetic option. It protects the meaning of the
probability row. AttentionWeights should answer:
among the positions this query may read, how much should each one contribute?
In this repository, masked-out scores become a very negative finite value instead of a non-finite value so that the pedagogical constructors can keep the “all scores are finite” invariant. The teaching meaning is the same as the standard attention implementation pattern: make disallowed positions effectively impossible before row-wise softmax.
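The whole worked example can be reproduced end to end. This is a minimal sketch over plain rows, assuming `-1e9` as the very negative finite masked score. The function names are hypothetical. They mirror the typed path `AttentionScores x AttentionMask -> AttentionScores -> AttentionWeights`, then `AttentionWeights x ValueSequence -> AttentionOutput`, rather than the crate's API.

```rust
/// Hypothetical very negative finite value for masked-out scores.
const MASKED_SCORE: f64 = -1.0e9;

type Rows = Vec<Vec<f64>>;

fn dot(a: &[f64], b: &[f64]) -> f64 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// QuerySequence x KeySequence -> AttentionScores, scaled by 1/sqrt(d).
fn attention_scores(queries: &Rows, keys: &Rows) -> Rows {
    queries
        .iter()
        .map(|q| {
            let scale = (q.len() as f64).sqrt();
            keys.iter().map(|k| dot(q, k) / scale).collect::<Vec<f64>>()
        })
        .collect()
}

/// AttentionScores x AttentionMask -> AttentionScores, applied before softmax.
fn apply_mask(scores: &Rows, mask: &[Vec<bool>]) -> Rows {
    scores
        .iter()
        .zip(mask)
        .map(|(row, allowed)| {
            row.iter()
                .zip(allowed)
                .map(|(&s, &ok)| if ok { s } else { MASKED_SCORE })
                .collect::<Vec<f64>>()
        })
        .collect()
}

/// AttentionScores -> AttentionWeights: row-wise, numerically stable softmax.
fn softmax_rows(scores: &Rows) -> Rows {
    scores
        .iter()
        .map(|row| {
            let max = row.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
            let exps: Vec<f64> = row.iter().map(|s| (s - max).exp()).collect();
            let total: f64 = exps.iter().sum();
            exps.iter().map(|e| e / total).collect::<Vec<f64>>()
        })
        .collect()
}

/// AttentionWeights x ValueSequence -> AttentionOutput: weighted sum of value rows.
fn mix_values(weights: &Rows, values: &Rows) -> Rows {
    let width = values.first().map_or(0, |v| v.len());
    weights
        .iter()
        .map(|w| {
            let mut out = vec![0.0; width];
            for (weight, value) in w.iter().zip(values) {
                for (o, v) in out.iter_mut().zip(value) {
                    *o += weight * v;
                }
            }
            out
        })
        .collect()
}

fn main() {
    let queries = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
    let keys = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];
    let values = vec![vec![1.0, 10.0], vec![2.0, 20.0], vec![3.0, 30.0]];
    let mask = vec![vec![true, false, true], vec![true, true, true]];
    let weights = softmax_rows(&apply_mask(&attention_scores(&queries, &keys), &mask));
    println!("weights = {:?}", weights);
    println!("output  = {:?}", mix_values(&weights, &values));
}
```

The first row gives weights `[0.5, 0.0, 0.5]` and output `[2.0, 20.0]`, matching the hand calculation. The unmasked second row spreads its weight as roughly `[0.198, 0.401, 0.401]`.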
Mask Polarity Ledger
Mask shape answers:
which query row and source column is this mask cell about?
Mask polarity answers:
does true mean allowed, or does true mean blocked?
Those are different questions. This repository chooses the smaller teaching polarity:
true -> this query may read this source position
false -> this query may not read this source position
That choice matches the boolean mask meaning used by PyTorch’s
scaled_dot_product_attention, where True means the element participates in
attention. It also matches the Keras MultiHeadAttention attention-mask rule
where 1 marks a query-key pair that may attend. It does not match every
PyTorch attention API. In MultiheadAttention padding masks, and in the
boolean masks described by torch.nn.Transformer, True can mean the
position is blocked or ignored.
So translate a framework mask in two steps:
| Question | Rust roadmap answer | Framework caution |
|---|---|---|
| What is the shape? | one cell per query-source score position | L x S, (B, T, S), and padding masks point at different axes |
| What is the polarity? | true means allowed | some APIs use true to mean blocked or padding |
| When is it applied? | before softmax, while values are still scores | after-softmax masking would change the meaning of the probability row |
The safe translation rule is:
first match the mask cells to score cells,
then translate boolean polarity,
then apply the mask before softmax
Do not carry a raw boolean mask from a framework into this Rust roadmap without stating its polarity. Two masks can have the same shape and opposite meaning.
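The translation rule can be sketched as two small checks. `Polarity`, `to_allow_mask`, and `matches_score_shape` are hypothetical helpers for this roadmap and do not call any framework API. A mask written in the true-means-blocked convention is flipped before it is compared with a mask in this roadmap's allow polarity.

```rust
/// Hypothetical label for a boolean mask's meaning.
#[derive(Clone, Copy, Debug)]
enum Polarity {
    TrueMeansAllowed,
    TrueMeansBlocked,
}

/// Step 1: one mask cell per score cell (query rows x source columns).
fn matches_score_shape(mask: &[Vec<bool>], query_len: usize, source_len: usize) -> bool {
    mask.len() == query_len && mask.iter().all(|row| row.len() == source_len)
}

/// Step 2: translate into this roadmap's polarity (true -> may read).
fn to_allow_mask(mask: &[Vec<bool>], polarity: Polarity) -> Vec<Vec<bool>> {
    mask.iter()
        .map(|row| {
            row.iter()
                .map(|&cell| match polarity {
                    Polarity::TrueMeansAllowed => cell,
                    Polarity::TrueMeansBlocked => !cell,
                })
                .collect::<Vec<bool>>()
        })
        .collect()
}

fn main() {
    // A mask in the "true means blocked" style: query 0 may read only source 0.
    let blocked = vec![vec![false, true, true], vec![false, false, true]];
    assert!(matches_score_shape(&blocked, 2, 3));
    let allowed = to_allow_mask(&blocked, Polarity::TrueMeansBlocked);
    println!("{:?}", allowed);
}
```

Step 3, applying the translated mask to scores before softmax, is the `AttentionScores x AttentionMask -> AttentionScores` boundary described earlier.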
Production Masking Caveat
The tiny AttentionMask in src/attention.rs is stricter than a production
framework boundary. For example, this constructor call is rejected:
AttentionMask::new(vec![vec![false, false]])
The reason is pedagogical. In this book, every attention-weight row should mean:
among at least one legal source position, how much should each one contribute?
If a row allows no source positions, there is no probability support for that row. The constructor therefore returns:
Err(CtError::EmptyInput("attention mask row allows no keys"))
That boundary is intentionally less general than production Transformer
libraries. PyTorch’s Transformer building-blocks tutorial discusses nested
tensors, variable sequence lengths, padding masks, and the production problem
of fully masked rows. It notes that softmax over an empty set is undefined and
that newer scaled_dot_product_attention behavior returns zero output for
fully masked rows.
The contrast is useful:
| Concern | Production framework boundary | Tiny teaching boundary |
|---|---|---|
| variable sequence lengths | ragged batches, padding, nested tensors, and mask ergonomics | each example uses one explicit rectangular mask |
| fully masked query row | framework must decide a stable output convention | constructor rejects the row before softmax |
| performance | fused kernels, compilation, and memory-aware representations | small values the reader can inspect by hand |
This is not a disagreement with framework behavior. It is a scope decision. It
preserves the invariant that AttentionWeights is a row-wise
distribution over at least one source position. A future production-oriented
chapter can relax that boundary only if it also names the new output convention
for rows with no legal source positions.
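A constructor that enforces the stricter invariant might look like this. The sketch is hypothetical. Its empty-row error matches the `CtError::EmptyInput("attention mask row allows no keys")` value quoted above. The `ShapeMismatch` variant, the other messages, and the `allows` method are assumptions.

```rust
/// Hypothetical mirror of the quoted error shape; other variants are assumptions.
#[derive(Debug, PartialEq)]
enum CtError {
    EmptyInput(&'static str),
    ShapeMismatch(&'static str),
}

#[derive(Debug)]
struct AttentionMask {
    rows: Vec<Vec<bool>>,
}

impl AttentionMask {
    fn new(rows: Vec<Vec<bool>>) -> Result<Self, CtError> {
        let width = rows
            .first()
            .map(Vec::len)
            .ok_or(CtError::EmptyInput("attention mask has no rows"))?;
        if rows.iter().any(|row| row.len() != width) {
            return Err(CtError::ShapeMismatch("attention mask rows must have equal length"));
        }
        // Teaching invariant: every row keeps non-empty support for softmax.
        if rows.iter().any(|row| !row.iter().any(|&allowed| allowed)) {
            return Err(CtError::EmptyInput("attention mask row allows no keys"));
        }
        Ok(Self { rows })
    }

    /// May query position `query` read source position `source`?
    fn allows(&self, query: usize, source: usize) -> bool {
        self.rows[query][source]
    }
}

fn main() {
    println!("{:?}", AttentionMask::new(vec![vec![false, false]]));
    println!("{:?}", AttentionMask::new(vec![vec![true, false, true], vec![true, true, true]]));
}
```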
What Exists Now
The current model has this prediction path:
TokenId -> Vector -> Logits -> Distribution
The implementation status is:
| Concept | Current status | Reason |
|---|---|---|
| Token ids | implemented | TokenId and TokenSequence already exist |
| Vectors | implemented | Vector is the current hidden representation |
| Logits and probabilities | implemented | LinearToLogits and Softmax are executable |
| Loss | implemented | CrossEntropy evaluates prediction against target |
| Parameter update | implemented | TrainStep updates Parameters |
| Query-key score boundary | implemented as a tiny roadmap sketch | QuerySequence x KeySequence -> AttentionScores is executable |
| Attention mask boundary | implemented as a tiny roadmap sketch | AttentionScores x AttentionMask -> AttentionScores is executable |
| Attention score-to-weight boundary | implemented as a tiny roadmap sketch | AttentionScores -> AttentionWeights is executable |
| Value-mixing boundary | implemented as a tiny roadmap sketch | AttentionWeights x ValueSequence -> AttentionOutput is executable |
| Multi-head concatenation boundary | implemented as a tiny roadmap sketch | AttentionHeadOutputs -> MultiHeadOutput is executable |
| Attention output projection boundary | implemented as a tiny roadmap sketch | MultiHeadOutput -> ProjectedAttentionOutput is executable |
| Sequence hidden states | implemented as a tiny roadmap sketch | HiddenSequence is executable for residual addition |
| Residual addition boundary | implemented as a tiny roadmap sketch | HiddenSequence x ProjectedAttentionOutput -> HiddenSequence is executable |
| Layer normalization boundary | implemented as a tiny roadmap sketch | HiddenSequence -> HiddenSequence is executable through LayerNormalization |
| Position-wise feed-forward boundary | implemented as a tiny roadmap sketch | HiddenSequence -> HiddenSequence is executable through PositionWiseFeedForward |
| Hidden-to-query/key/value projections | implemented as a tiny roadmap sketch | HiddenSequence -> QuerySequence, HiddenSequence -> KeySequence, and HiddenSequence -> ValueSequence are executable |
| Single-head block boundary | implemented as a tiny roadmap sketch | SingleHeadTransformerBlock : HiddenSequence -> HiddenSequence composes the current boundaries |
| Multi-head block boundary | implemented as a tiny roadmap sketch | MultiHeadTransformerBlock : HiddenSequence -> HiddenSequence composes several SelfAttentionHead values |
| Positional encoding | implemented as a tiny roadmap sketch | PositionalEncoding : HiddenSequence -> HiddenSequence adds position rows while preserving shape |
| Masked block variants | implemented as a tiny roadmap sketch | MaskedMultiHeadTransformerBlock : HiddenSequence x AttentionMask -> HiddenSequence accepts a block-level mask |
| Sequence logits and readout | implemented as a tiny roadmap sketch | TransformerReadout : HiddenSequence -> SequenceLogits produces vocabulary scores at each sequence position |
| Structured Transformer parameter object | implemented as a tiny roadmap sketch | TinyTransformerParameters : HiddenSequence x AttentionMask -> SequenceLogits owns position, masked block, and readout pieces |
| Structured Transformer training state | implemented as a tiny roadmap sketch | TransformerTrainingState owns parameters, learning rate, and step count |
| Readout-only training step | implemented as a tiny roadmap sketch | TransformerReadoutTrainStep : TransformerTrainingState -> TransformerTrainingState updates only the sequence readout |
| Local feed-forward training step | implemented as a tiny roadmap sketch | TransformerFeedForwardTrainStep : TransformerTrainingState -> TransformerTrainingState updates only the position-wise feed-forward sublayer against hidden targets |
| Composed block training step | implemented as a tiny roadmap sketch | TransformerBlockTrainStep : TransformerTrainingState -> TransformerTrainingState updates readout, feed-forward, attention-output-projection, query/key/value, and layer-normalization parameters from sequence targets through residual, normalization, and attention paths |
This table is a guardrail. When extending the project, do not present planned items as implemented content. Add the type, example, test, chapter prose, and reference link together.
Rust Syntax
The path is implemented with:
Embedding
LinearToLogits
Softmax
Compose
The main domain objects are:
TokenId
Vector
Logits
Distribution
Parameters
The training update is:
TrainStep : Parameters -> Parameters
ML Concept
This is a tiny next-token model.
It predicts from one token at a time.
The main training example is still that small. The roadmap module now sketches attention blocks and structured Transformer state, but it does not yet train a production Transformer.
Still, it already teaches the core path:
discrete token
-> dense representation
-> vocabulary scores
-> next-token probabilities
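The four-step path can be traced without the crate's types. The sketch below is standalone and makes two assumptions: the embedding table has one row per token, and the readout matrix has one row per vocabulary entry. The function names (embed, to_logits, softmax, next_token_distribution) are hypothetical. They are not the crate's Embedding, LinearToLogits, and Softmax APIs.

```rust
/// Discrete token -> dense representation: look up one row of the embedding table.
fn embed(table: &[Vec<f32>], token_id: usize) -> Option<Vec<f32>> {
    table.get(token_id).cloned()
}

/// Dense representation -> vocabulary scores: one dot product per readout row.
fn to_logits(readout: &[Vec<f32>], hidden: &[f32]) -> Vec<f32> {
    readout
        .iter()
        .map(|row| row.iter().zip(hidden).map(|(w, h)| w * h).sum::<f32>())
        .collect()
}

/// Vocabulary scores -> probabilities, subtracting the max for numerical stability.
fn softmax(logits: &[f32]) -> Vec<f32> {
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|x| (x - max).exp()).collect();
    let total: f32 = exps.iter().sum();
    exps.iter().map(|e| e / total).collect()
}

/// The composed path TokenId -> Vector -> Logits -> Distribution.
/// Returns None when the token id is outside the embedding table.
fn next_token_distribution(
    embedding: &[Vec<f32>],
    readout: &[Vec<f32>],
    token_id: usize,
) -> Option<Vec<f32>> {
    let hidden = embed(embedding, token_id)?;
    Some(softmax(&to_logits(readout, &hidden)))
}
```

Each function is one arrow of the path, and next_token_distribution is only their composition.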
Category Theory Concept
The current system teaches composition:
TokenId -> Vector -> Logits -> Distribution
and endomorphism:
Parameters -> Parameters
Those two shapes remain central in Transformers.
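A one-weight toy model can illustrate the endomorphism shape Parameters -> Parameters. This is a hypothetical sketch. The Parameters struct, the squared-error loss against a fixed target, and the train_step and train functions are assumptions made for illustration, not the crate's TrainStep. Repeating an endomorphism n times still gives an endomorphism, so a training loop keeps the same type at every step.

```rust
/// Toy parameter object: one weight plus a count of completed updates.
#[derive(Debug, Clone, Copy, PartialEq)]
struct Parameters {
    weight: f32,
    steps: u32,
}

/// Squared-error loss (weight - target)^2 for a fixed, hypothetical target.
fn loss(params: Parameters, target: f32) -> f32 {
    (params.weight - target).powi(2)
}

/// One gradient-descent update, typed as Parameters -> Parameters.
/// The gradient of (w - t)^2 with respect to w is 2 (w - t).
fn train_step(params: Parameters, target: f32, learning_rate: f32) -> Parameters {
    let gradient = 2.0 * (params.weight - target);
    Parameters {
        weight: params.weight - learning_rate * gradient,
        steps: params.steps + 1,
    }
}

/// Composing the endomorphism with itself n times is still Parameters -> Parameters.
fn train(params: Parameters, target: f32, learning_rate: f32, n: u32) -> Parameters {
    (0..n).fold(params, |current, _| train_step(current, target, learning_rate))
}
```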
Step 1: Sequences As First-Class Objects
The future problem:
Attention does not operate on one token alone. It operates on a sequence of hidden states.
The current code already has TokenSequence, but that is a sequence of token
ids. Attention needs a sequence of hidden vectors, usually with position and
mask information attached. That is a different object with different
invariants.
Worked Example: Validating Sequence Length
The first-principles Rust move is the same one used throughout the book: do not let a meaningful value travel as a raw primitive once it crosses a conceptual boundary. The roadmap module now starts with a small validating type:
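```rust
use crate::error::{CtError, CtResult};

/// Number of positions in a sequence.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SequenceLength(usize);

impl SequenceLength {
    pub fn new(value: usize) -> CtResult<Self> {
        // A zero-length sequence is rejected at construction time.
        if value == 0 {
            return Err(CtError::EmptyInput("sequence length"));
        }
        Ok(Self(value))
    }

    pub fn value(&self) -> usize {
        self.0
    }
}
```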
Self-Check
Before reading the roadmap steps, explain why SequenceLength should not be
passed around as a bare usize.
Rust Syntax
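The roadmap module already implements validating versions of SequenceLength, HiddenSequence, and AttentionMask. In simplified form, a sequence-aware extension needs types such as: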
```rust
pub struct Position(usize);
pub struct SequenceLength(usize);
pub struct HiddenSequence(Vec<Vector>);
pub struct AttentionMask(/* validated mask representation */);
```
The important rule is the same as this course:
do not pass raw vectors across architectural boundaries
ML Concept
Attention needs a representation like:
[hidden_0, hidden_1, hidden_2, ...]
plus position and mask information.
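One way to attach position information while keeping the hidden-sequence shape is to add a fixed table of position rows. This minimal sketch assumes the caller supplies the position rows, whether learned or precomputed. It does not compute the sinusoidal formula from the original paper. add_positions is a hypothetical helper, not the crate's PositionalEncoding.

```rust
/// Adds row i of a position table to hidden row i, preserving the hidden shape.
/// Fails if the sequence is longer than the table or if row widths disagree.
fn add_positions(hidden: &[Vec<f32>], positions: &[Vec<f32>]) -> Result<Vec<Vec<f32>>, String> {
    if hidden.len() > positions.len() {
        return Err(format!(
            "sequence length {} exceeds position table length {}",
            hidden.len(),
            positions.len()
        ));
    }
    hidden
        .iter()
        .zip(positions)
        .map(|(row, position)| {
            if row.len() != position.len() {
                return Err(format!(
                    "row width {} does not match position width {}",
                    row.len(),
                    position.len()
                ));
            }
            // Element-wise sum: same number of rows, same width.
            Ok(row.iter().zip(position).map(|(h, p)| h + p).collect::<Vec<f32>>())
        })
        .collect()
}
```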
Category Theory Concept
The object changes from:
Vector
to:
Sequence(Vector)
The next morphisms operate on structured sequences.
Design contract:
TokenSequence -> HiddenSequence
should not be represented as:
Vec<usize> -> Vec<Vec<f32>>
The second shape hides every domain distinction the course has worked to make visible.
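The design contract can be made concrete with a standalone sketch. The TokenSequence, HiddenSequence, and EmbeddingTable types below are simplified stand-ins written for this illustration. They assume a plain row-per-token table, and the crate's own types carry more validation. The point is that the lookup is typed TokenSequence -> HiddenSequence, and every failure mode is named.

```rust
/// Token ids for one input sequence (simplified stand-in).
#[derive(Debug, Clone, PartialEq)]
struct TokenSequence(Vec<usize>);

/// One hidden vector per sequence position, all with the same width.
#[derive(Debug, Clone, PartialEq)]
struct HiddenSequence {
    model_dimension: usize,
    rows: Vec<Vec<f32>>,
}

impl HiddenSequence {
    fn sequence_len(&self) -> usize {
        self.rows.len()
    }
}

/// Row-per-token embedding table (hypothetical).
struct EmbeddingTable {
    model_dimension: usize,
    rows: Vec<Vec<f32>>,
}

impl EmbeddingTable {
    fn new(rows: Vec<Vec<f32>>) -> Result<Self, String> {
        let model_dimension = rows.first().map(|row| row.len()).unwrap_or(0);
        if model_dimension == 0 || rows.iter().any(|row| row.len() != model_dimension) {
            return Err("embedding rows must be non-empty and equally wide".to_string());
        }
        Ok(Self { model_dimension, rows })
    }

    /// TokenSequence -> HiddenSequence: rejects empty sequences and unknown ids.
    fn embed(&self, tokens: &TokenSequence) -> Result<HiddenSequence, String> {
        if tokens.0.is_empty() {
            return Err("empty token sequence".to_string());
        }
        let rows = tokens
            .0
            .iter()
            .map(|&id| {
                self.rows
                    .get(id)
                    .cloned()
                    .ok_or_else(|| format!("token id {id} is outside the vocabulary"))
            })
            .collect::<Result<Vec<_>, _>>()?;
        Ok(HiddenSequence { model_dimension: self.model_dimension, rows })
    }
}
```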
Step 2: Query, Key, And Value Projections
The current problem:
Attention compares tokens by projecting hidden states into query, key, and value spaces.
The important design move is not only three matrices. It is three roles. A query vector, key vector, and value vector may share the same numeric representation, but they should not share the same Rust type once they cross a module boundary.
Rust Syntax
The current projection morphisms have shapes:
HiddenSequence -> QuerySequence
HiddenSequence -> KeySequence
HiddenSequence -> ValueSequence
Each output type should be distinct.
The current roadmap code models both the role objects and the hidden-state projection morphisms:
HiddenToQuery
HiddenToKey
HiddenToValue
QuerySequence
KeySequence
ValueSequence
Queries, keys, and values are all vectors underneath, but they have different roles.
Worked Example: Same Hidden Row, Three Roles
The query-key-value split is not about three mysterious kinds of vector. It is about three uses of a hidden state.
Start with two hidden rows:
hidden_0 = [1.0, 2.0]
hidden_1 = [3.0, 4.0]
A tiny set of projections can send the same hidden rows into three role-specific objects:
```rust
let hidden = HiddenSequence::new(vec![vec![1.0, 2.0], vec![3.0, 4.0]])?;
let to_query = HiddenToQuery::new(
    vec![vec![1.0, 0.0], vec![0.0, 1.0]],
    vec![0.0, 0.0],
)?;
let to_key = HiddenToKey::new(
    vec![vec![0.0, 1.0], vec![1.0, 0.0]],
    vec![0.0, 0.0],
)?;
let to_value = HiddenToValue::new(
    vec![vec![10.0, 0.0], vec![0.0, 10.0]],
    vec![0.0, 0.0],
)?;
```
For hidden_0, those projections produce:
query_0 = [1.0, 2.0]
key_0 = [2.0, 1.0]
value_0 = [10.0, 20.0]
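A standalone sketch can reproduce the numbers above. It assumes the weight convention output[j] = sum_i weights[j][i] * hidden[i] + bias[j]. For these symmetric example matrices, the transposed convention gives the same results. The Query, Key, Value, Projections, affine, and score names are hypothetical and are not the crate's HiddenToQuery, HiddenToKey, or HiddenToValue.

```rust
/// Role wrappers: the same numeric content, but different meanings.
#[derive(Debug, Clone, PartialEq)]
struct Query(Vec<f32>);
#[derive(Debug, Clone, PartialEq)]
struct Key(Vec<f32>);
#[derive(Debug, Clone, PartialEq)]
struct Value(Vec<f32>);

/// Affine map: output[j] = sum_i weights[j][i] * input[i] + bias[j].
fn affine(weights: &[Vec<f32>], bias: &[f32], input: &[f32]) -> Vec<f32> {
    weights
        .iter()
        .zip(bias)
        .map(|(row, b)| row.iter().zip(input).map(|(w, x)| w * x).sum::<f32>() + b)
        .collect()
}

/// (weights, bias) pairs for the three role projections.
struct Projections {
    query: (Vec<Vec<f32>>, Vec<f32>),
    key: (Vec<Vec<f32>>, Vec<f32>),
    value: (Vec<Vec<f32>>, Vec<f32>),
}

impl Projections {
    /// Three parallel maps out of one hidden row; none of them feeds another.
    fn project(&self, hidden: &[f32]) -> (Query, Key, Value) {
        (
            Query(affine(&self.query.0, &self.query.1, hidden)),
            Key(affine(&self.key.0, &self.key.1, hidden)),
            Value(affine(&self.value.0, &self.value.1, hidden)),
        )
    }
}

/// Scoring accepts only a Query and a Key, so passing a Value here is a type error.
fn score(query: &Query, key: &Key) -> f32 {
    query.0.iter().zip(&key.0).map(|(q, k)| q * k).sum()
}
```

Passing a Value to score does not compile, which is the role-separation argument in executable form.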
The numbers are deliberately simple. The important lesson is the role separation:
| Role | Question it answers | Used for |
|---|---|---|
| query | what is this position looking for? | compared with keys |
| key | what can this source position be matched by? | compared with queries |
| value | what information can this source position contribute? | mixed after weights exist |
If all three values were passed around as Vec<Vec<f32>>, the compiler could
not help a reader notice a role mistake. ValueSequence could accidentally be
fed into query-key scoring. KeySequence could accidentally be mixed as if it
were content. The typed split makes that confusion harder to express.
This also explains why the attention path has two phases:
QuerySequence x KeySequence -> AttentionScores
AttentionWeights x ValueSequence -> AttentionOutput
Queries and keys decide where to look. Values provide what gets read.
Self-Attention And Cross-Attention Boundary
The current roadmap example is self-attention: query, key, and value roles all
come from the same HiddenSequence before they are projected into separate
role objects.
That is only one attention case.
Official framework documentation exposes a more general boundary. PyTorch’s
multi-head attention API accepts query, key, and value as separate
inputs. Its shape language distinguishes target sequence length L for
queries from source sequence length S for keys and values. Dive into Deep
Learning makes the same teaching distinction when it writes attention over
n queries and m key-value pairs.
TensorFlow/Keras exposes the same split with different letters: query has
target length T, value and key have source length S, and the attention mask
has shape (B, T, S). That cross-framework agreement is useful because it
keeps the rule from sounding like a PyTorch naming quirk. A target/query row
asks a question. A source/key-value column is something that can be read.
This matters for the book because it prevents a subtle category mistake. The attention scoring boundary is not automatically:
HiddenSequence -> HiddenSequence
The more honest shape is:
Target positions x Source positions -> attention weights
or, in the current Rust vocabulary:
QuerySequence x KeySequence -> AttentionScores
AttentionWeights x ValueSequence -> AttentionOutput
Self-attention is the special case where the target positions and source positions come from the same hidden sequence:
HiddenSequence -> QuerySequence
HiddenSequence -> KeySequence
HiddenSequence -> ValueSequence
These are parallel projections, not a pipeline where queries turn into keys and keys turn into values. The shared source is what makes the case “self-attention”; the role split is still real after projection.
Cross-attention is the case where the query side and the key-value side come from different sequence objects:
TargetHiddenSequence -> QuerySequence
SourceHiddenSequence -> KeySequence
SourceHiddenSequence -> ValueSequence
The tiny repository does not implement a full cross-attention module yet. But the naming rule should already be clear:
same source for Q, K, V -> self-attention case
separate query and key-value sources -> cross-attention case
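A cross-attention score table can be sketched with target and source lengths that differ. This standalone function is hypothetical: it uses plain Vec rows instead of the crate's QuerySequence and KeySequence and assumes a single head. It only illustrates that score rows follow the query side and score columns follow the key side.

```rust
/// Scaled dot-product scores with one row per query (target) position
/// and one column per key (source) position.
fn cross_attention_scores(
    queries: &[Vec<f32>],
    keys: &[Vec<f32>],
) -> Result<Vec<Vec<f32>>, String> {
    let dimension = queries.first().map(|row| row.len()).unwrap_or(0);
    if dimension == 0 || keys.is_empty() {
        return Err("queries and keys must be non-empty".to_string());
    }
    // Lengths may differ between the two sides; head dimensions may not.
    if queries.iter().chain(keys).any(|row| row.len() != dimension) {
        return Err("query and key rows must share one head dimension".to_string());
    }
    let scale = (dimension as f32).sqrt();
    Ok(queries
        .iter()
        .map(|q| {
            keys.iter()
                .map(|k| q.iter().zip(k).map(|(a, b)| a * b).sum::<f32>() / scale)
                .collect::<Vec<f32>>()
        })
        .collect())
}
```

With 2 query rows and 3 key rows, the result is a 2 x 3 table, the same L x S (or T x S) shape the framework documentation describes.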
Use this Q/K/V source diagnostic before reading a framework call:
| Question | Self-attention answer | Cross-attention answer |
|---|---|---|
| Which sequence owns the query side? | the same hidden sequence | the target hidden sequence |
| Which sequence owns the key side? | the same hidden sequence | the source hidden sequence |
| Which sequence owns the value side? | the same hidden sequence | the source hidden sequence |
| Which length counts score rows? | target/query length | target/query length |
| Which length counts score columns? | source/key-value length, equal to target length in the simple self-attention case | source/key-value length, possibly different from target length |
This table prevents a common framework-reading mistake. Passing the same hidden sequence into Q, K, and V means the source object is shared. It does not mean the projected query, key, and value roles have become the same role.
When you run the attention example, the first lines now anchor that diagnostic before any probabilities appear:
Q/K/V source diagnostic:
query rows own score rows; key/value rows own score columns
self-attention shares the hidden source before projection; projected roles stay distinct
mask polarity here: true = allowed, false = blocked
Read those four lines before interpreting the shape line “attention shape: 2 query positions x 3 key positions”. The terminal output gives the learner one inspectable signal for the source-backed rule above: score rows come from the query side, score columns come from the key-value side, and the local mask polarity must be translated before comparing the Rust example with a framework API.
PyTorch and TensorFlow/Keras use different names but expose the same shape split:
| Framework cue | Query side | Key-value side | Mask cue |
|---|---|---|---|
| PyTorch | target length L | source length S | attention weights and masks use L x S |
| TensorFlow/Keras | target length T | source length S | mask shape is (B, T, S) |
| Rust roadmap | QuerySequence | KeySequence and ValueSequence | AttentionMask says which source positions each query may read |
Use this ledger when reading the Rust types:
| Ledger item | Meaning in framework docs | Meaning in this roadmap | Category-shape consequence |
|---|---|---|---|
| target length | PyTorch L, Keras T | number of QuerySequence rows | score rows belong to the query-side object |
| source length | PyTorch/Keras S | number of KeySequence and ValueSequence rows | score columns belong to the key-value source object |
| attention mask | PyTorch L x S, Keras (B, T, S) | one permission table from query rows to source rows | the mask is context over a product boundary |
| attention output | target-side output rows | one AttentionOutput row for each query row | value mixing returns information to the query side |
The shape ledger gives a quick sanity check:
score table rows == query positions
score table columns == key-value positions
mask cells == query-position/source-position permissions
output rows == query positions after reading values
If those four statements are not true, the explanation has probably collapsed source ownership, role ownership, or mask context too early.
Mask Role Ledger: Permissions, Not Tokens
Framework APIs make the mask shape look like another tensor argument, but the teaching question is more specific:
Which query rows may read which source columns before softmax?
That is why the roadmap names the mask separately from the token sequence, score table, and attention weights. The mask is permission context over the score table.
| Mask misreading | Correct local boundary | What to inspect |
|---|---|---|
| the mask is a shorter token sequence | AttentionScores x AttentionMask -> AttentionScores | the score table keeps query rows and source columns |
| the mask directly produces probabilities | AttentionScores -> AttentionWeights still happens after masking | query 0 attends with [0.5, 0.0, 0.5] |
| the mask is hidden global state | MaskedMultiHeadTransformerBlock : HiddenSequence x AttentionMask -> HiddenSequence | the block boundary keeps the mask visible |
| a fixed mask means no mask exists | MaskedMultiHeadTransformerBlock[M] : HiddenSequence -> HiddenSequence | the chosen mask M is fixed context for that view |
The local rule is:
mask cells select legal score cells;
softmax turns remaining score rows into weights;
weights read value rows.
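The three-line rule can be traced for one query row. This standalone sketch makes four assumptions:

- permissions are booleans, with true = allowed,
- blocked cells get a large negative constant, as the roadmap's MASKED_SCORE does,
- the raw scores are equal, which reproduces the [0.5, 0.0, 0.5] row from the ledger table,
- at least one value row exists.

apply_mask, softmax, and read_values are hypothetical helpers.

```rust
/// Large negative stand-in for a blocked score (mirrors the MASKED_SCORE idea).
const BLOCKED_SCORE: f32 = -1_000_000.0;

/// Mask cells select legal score cells: true keeps the score, false blocks it.
fn apply_mask(scores: &[f32], allowed: &[bool]) -> Vec<f32> {
    scores
        .iter()
        .zip(allowed)
        .map(|(&score, &ok)| if ok { score } else { BLOCKED_SCORE })
        .collect()
}

/// Softmax turns the remaining score row into weights.
fn softmax(row: &[f32]) -> Vec<f32> {
    let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = row.iter().map(|x| (x - max).exp()).collect();
    let total: f32 = exps.iter().sum();
    exps.iter().map(|e| e / total).collect()
}

/// Weights read value rows: a weighted sum over source positions.
/// Assumes at least one value row and equally wide value rows.
fn read_values(weights: &[f32], values: &[Vec<f32>]) -> Vec<f32> {
    let mut output = vec![0.0; values[0].len()];
    for (weight, value) in weights.iter().zip(values) {
        for (out, v) in output.iter_mut().zip(value) {
            *out += weight * v;
        }
    }
    output
}
```

The blocked source row still exists in values. It simply receives zero weight, which is the sense in which the mask removes score cells rather than tokens.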
Do not say “the mask removes tokens” unless you also say which score cells were removed from probability competition. The source sequence still owns the value rows. The mask only says which of those rows each query is allowed to read.
The category-theory reading follows the input count. Self-attention can be wrapped inside a shape-preserving block because projection, masking, value mixing, output projection, residual addition, and normalization together return to the hidden stream. The core scoring and mixing steps are still product-input morphisms. Cross-attention makes that product input impossible to ignore, because the target side and source side may have different sequence lengths.
When a framework call reports an attention mask of shape L x S, read it as a
typed reminder:
for each target position, which source positions may be read?
That is why this roadmap names QuerySequence, KeySequence, ValueSequence,
AttentionScores, AttentionMask, AttentionWeights, and AttentionOutput
separately. The names keep target-side questions, source-side comparison, and
source-side information from collapsing into one raw tensor.
ML Concept
Queries ask:
what am I looking for?
Keys answer:
what do I contain?
Values provide:
what information should be mixed?
Category Theory Concept
These are parallel morphisms out of the same object:
HiddenSequence -> QuerySequence
HiddenSequence -> KeySequence
HiddenSequence -> ValueSequence
The current attention example combines query and key roles to produce scores, then uses value roles to produce output vectors.
Design contract:
HiddenSequence -> QuerySequence
HiddenSequence -> KeySequence
HiddenSequence -> ValueSequence
should be three explicit morphisms. A single untyped vector list would make it too easy to pass values into the wrong part of the attention computation.
Step 3: Scaled Dot-Product Attention
The future problem:
Convert query-key similarity into a probability distribution over positions, then use it to mix values.
Rust Syntax
The typed shapes are:
QuerySequence x KeySequence -> AttentionScores
AttentionScores x AttentionMask -> AttentionScores
AttentionScores -> AttentionWeights
AttentionWeights x ValueSequence -> AttentionOutput
AttentionHeadOutputs -> MultiHeadOutput
MultiHeadOutput -> ProjectedAttentionOutput
HiddenSequence x ProjectedAttentionOutput -> HiddenSequence
HiddenSequence -> HiddenSequence
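The core of this list, from QuerySequence x KeySequence to AttentionOutput, can be sketched over plain rows. This version leaves out masking and multiple heads. It is a hypothetical standalone function, not the crate's ScaledDotProductScores morphism. It assumes a single head and uses the same 1/sqrt(d) scaling.

```rust
/// Dot product of two equally long rows.
fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// Row-wise softmax with max subtraction for numerical stability.
fn softmax(row: &[f32]) -> Vec<f32> {
    let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = row.iter().map(|x| (x - max).exp()).collect();
    let total: f32 = exps.iter().sum();
    exps.iter().map(|e| e / total).collect()
}

/// Single-head scaled dot-product attention over plain rows.
/// scores = Q K^T / sqrt(d); weights = row-wise softmax(scores); output = weights V.
/// Returns (weights, output) so both stages can be inspected.
fn attention(
    queries: &[Vec<f32>],
    keys: &[Vec<f32>],
    values: &[Vec<f32>],
) -> Result<(Vec<Vec<f32>>, Vec<Vec<f32>>), String> {
    if queries.is_empty() || keys.is_empty() || keys.len() != values.len() {
        return Err("need at least one query and one value row per key row".to_string());
    }
    let dimension = queries[0].len();
    if dimension == 0 || queries.iter().chain(keys).any(|row| row.len() != dimension) {
        return Err("queries and keys must share one non-zero head dimension".to_string());
    }
    let value_width = values[0].len();
    if values.iter().any(|row| row.len() != value_width) {
        return Err("value rows must share one width".to_string());
    }
    let scale = (dimension as f32).sqrt();
    // QuerySequence x KeySequence -> AttentionScores -> AttentionWeights
    let weights: Vec<Vec<f32>> = queries
        .iter()
        .map(|q| {
            let scores: Vec<f32> = keys.iter().map(|k| dot(q, k) / scale).collect();
            softmax(&scores)
        })
        .collect();
    // AttentionWeights x ValueSequence -> AttentionOutput (one row per query)
    let output: Vec<Vec<f32>> = weights
        .iter()
        .map(|row| {
            (0..value_width)
                .map(|j| row.iter().zip(values).map(|(w, v)| w * v[j]).sum::<f32>())
                .collect()
        })
        .collect();
    Ok((weights, output))
}
```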
Read the current roadmap code through this shape trace:
```mermaid
flowchart LR
    H["HiddenSequence"] --> Q["QuerySequence"]
    H --> K["KeySequence"]
    H --> V["ValueSequence"]
    Q --> S["AttentionScores"]
    K --> S
    S --> M["Masked Scores"]
    Mask["AttentionMask"] --> M
    M --> W["AttentionWeights"]
    W --> O["AttentionOutput"]
    V --> O
    O --> MH["MultiHeadOutput"]
    MH --> P["ProjectedAttentionOutput"]
    H --> R["Residual HiddenSequence"]
    P --> R
    R --> N["Normalized HiddenSequence"]
    N --> FF["FeedForward HiddenSequence"]
```
The same attention core as a compact rendered math view:
\[
\begin{array}{rcl}
\mathrm{QuerySequence} \times \mathrm{KeySequence} & \to & \mathrm{AttentionScores} \\
\mathrm{AttentionScores} \times \mathrm{AttentionMask} & \to & \mathrm{MaskedScores} \\
\mathrm{MaskedScores} & \to & \mathrm{AttentionWeights} \\
\mathrm{AttentionWeights} \times \mathrm{ValueSequence} & \to & \mathrm{AttentionOutput} \\
\mathrm{AttentionOutput} & \to & \mathrm{ProjectedAttentionOutput} \\
\mathrm{HiddenSequence} \times \mathrm{ProjectedAttentionOutput} & \to & \mathrm{HiddenSequence}
\end{array}
\]
How to read this diagram:
- every product input means two roles must stay visible,
- masking happens before row-wise softmax produces weights,
- value mixing is separate from score calculation,
- the residual step is the first row here that explicitly returns to
HiddenSequence.
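The last two steps of the trace, residual addition and normalization, both preserve the hidden-sequence shape. This is a minimal sketch over plain rows. It assumes layer normalization without the learned gain and bias that full implementations usually include. residual_add and layer_norm are hypothetical helpers, not the crate's LayerNormalization.

```rust
/// Residual addition: HiddenSequence x ProjectedAttentionOutput -> HiddenSequence.
fn residual_add(hidden: &[Vec<f32>], update: &[Vec<f32>]) -> Result<Vec<Vec<f32>>, String> {
    if hidden.len() != update.len()
        || hidden.iter().zip(update).any(|(h, u)| h.len() != u.len())
    {
        return Err("residual inputs must have the same shape".to_string());
    }
    Ok(hidden
        .iter()
        .zip(update)
        .map(|(h, u)| h.iter().zip(u).map(|(a, b)| a + b).collect::<Vec<f32>>())
        .collect())
}

/// Per-row layer normalization without learned gain or bias:
/// (x - mean) / sqrt(variance + epsilon). Assumes non-empty rows.
fn layer_norm(rows: &[Vec<f32>], epsilon: f32) -> Vec<Vec<f32>> {
    rows.iter()
        .map(|row| {
            let n = row.len() as f32;
            let mean = row.iter().sum::<f32>() / n;
            let variance = row.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / n;
            let denom = (variance + epsilon).sqrt();
            row.iter().map(|x| (x - mean) / denom).collect::<Vec<f32>>()
        })
        .collect()
}
```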
What to notice:
Rust reading:
each box is a named type or a named typed boundary in src/attention.rs
ML reading:
scores choose positions, weights mix values, projection and residual return to
the hidden-state width
Category-theory reading:
the middle of attention is a composition with product inputs, and the enclosing
block keeps returning to HiddenSequence
AttentionWeights should be validated like Distribution, but over sequence
positions instead of vocabulary tokens.
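That sentence can be read as a constructor check. The standalone AttentionWeights below is a hypothetical stand-in that assumes a caller-chosen tolerance for the row sums. The crate's own AttentionWeights::new, in the snapshot that follows, instead takes Distribution rows and checks that they have equal widths.

```rust
/// Row-wise probabilities over key positions (simplified stand-in).
#[derive(Debug, Clone, PartialEq)]
struct AttentionWeights {
    rows: Vec<Vec<f32>>,
}

impl AttentionWeights {
    /// Accepts only non-empty, equally wide rows of finite, non-negative
    /// entries that sum to 1 within `tolerance`.
    fn new(rows: Vec<Vec<f32>>, tolerance: f32) -> Result<Self, String> {
        let key_len = rows.first().map(|row| row.len()).unwrap_or(0);
        if key_len == 0 {
            return Err("attention weights need at least one non-empty row".to_string());
        }
        for (index, row) in rows.iter().enumerate() {
            if row.len() != key_len {
                return Err(format!("row {index} has {} columns, expected {key_len}", row.len()));
            }
            if row.iter().any(|w| !w.is_finite() || *w < 0.0) {
                return Err(format!("row {index} has a negative or non-finite weight"));
            }
            let total: f32 = row.iter().sum();
            if (total - 1.0).abs() > tolerance {
                return Err(format!("row {index} sums to {total}, not 1"));
            }
        }
        Ok(Self { rows })
    }

    fn rows(&self) -> &[Vec<f32>] {
        &self.rows
    }
}
```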
The current roadmap code implements the query-key score boundary, the mask boundary, the score-to-weight boundary, the value-mixing boundary, the multi-head concatenation boundary, the output projection boundary, and the residual addition and normalization boundaries:
Source snapshot: src/attention.rs
//! Tiny typed attention boundary for the Transformer roadmap.
//!
//! This module does not implement a full Transformer. It makes the first small
//! attention-specific shapes explicit:
//!
//! - projected queries and keys become query-by-key scores,
//! - masks turn illegal score positions into negligible softmax inputs,
//! - query-by-key scores become row-wise attention probabilities,
//! - attention probabilities mix value vectors into output vectors,
//! - multiple head outputs concatenate into a multi-head output,
//! - the concatenated heads project back into a hidden sequence width,
//! - residual addition preserves the hidden sequence boundary,
//! - layer normalization preserves the hidden sequence boundary,
//! - a position-wise feed-forward map preserves the hidden sequence boundary,
//! - positional encoding adds position information while preserving shape,
//! - a single-head block sketch composes those boundaries end to end,
//! - a masked multi-head block accepts attention masks at the block boundary,
//! - a structured parameter object owns position, block, and readout pieces,
//! - a training-state object owns parameters, learning rate, and step count,
//! - a composed block train step updates readout, feed-forward,
//! normalization, and attention projection parameters from sequence targets.
use crate::category::{Morphism, StepCount};
use crate::domain::{
Distribution, LearningRate, Logits, Loss, ModelDimension, Product, TokenSequence, Vector,
VocabSize,
};
use crate::error::{CtError, CtResult};
use crate::ml::Softmax;
const MASKED_SCORE: f32 = -1_000_000.0;
/// Number of positions in a sequence.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SequenceLength(usize);
impl SequenceLength {
pub fn new(value: usize) -> CtResult<Self> {
if value == 0 {
return Err(CtError::EmptyInput("sequence length"));
}
Ok(Self(value))
}
pub fn value(&self) -> usize {
self.0
}
}
/// Number of parallel attention heads.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HeadCount(usize);
impl HeadCount {
pub fn new(value: usize) -> CtResult<Self> {
if value == 0 {
return Err(CtError::EmptyInput("head count"));
}
Ok(Self(value))
}
pub fn value(&self) -> usize {
self.0
}
}
/// Width of one attention head.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HeadDimension(usize);
impl HeadDimension {
pub fn new(value: usize) -> CtResult<Self> {
if value == 0 {
return Err(CtError::EmptyInput("head dimension"));
}
Ok(Self(value))
}
pub fn value(&self) -> usize {
self.0
}
}
/// Positive stabilizer used in layer normalization.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct NormalizationEpsilon(f32);
impl NormalizationEpsilon {
pub fn new(value: f32) -> CtResult<Self> {
if !value.is_finite() || value <= 0.0 {
return Err(CtError::ShapeMismatch {
op: "normalization epsilon",
expected: "positive finite epsilon".to_string(),
got: format!("epsilon {value}"),
});
}
Ok(Self(value))
}
pub fn value(&self) -> f32 {
self.0
}
}
/// Projected query vectors for one attention head.
#[derive(Debug, Clone, PartialEq)]
pub struct QuerySequence {
sequence_len: SequenceLength,
head_dimension: HeadDimension,
rows: Vec<Vector>,
}
impl QuerySequence {
pub fn new(rows: Vec<Vec<f32>>) -> CtResult<Self> {
let matrix = AttentionVectorRows::new("query sequence", rows)?;
Ok(Self {
sequence_len: matrix.sequence_len,
head_dimension: matrix.head_dimension,
rows: matrix.rows,
})
}
pub fn sequence_len(&self) -> SequenceLength {
self.sequence_len
}
pub fn head_dimension(&self) -> HeadDimension {
self.head_dimension
}
pub fn rows(&self) -> &[Vector] {
&self.rows
}
}
/// Projected key vectors for one attention head.
#[derive(Debug, Clone, PartialEq)]
pub struct KeySequence {
sequence_len: SequenceLength,
head_dimension: HeadDimension,
rows: Vec<Vector>,
}
impl KeySequence {
pub fn new(rows: Vec<Vec<f32>>) -> CtResult<Self> {
let matrix = AttentionVectorRows::new("key sequence", rows)?;
Ok(Self {
sequence_len: matrix.sequence_len,
head_dimension: matrix.head_dimension,
rows: matrix.rows,
})
}
pub fn sequence_len(&self) -> SequenceLength {
self.sequence_len
}
pub fn head_dimension(&self) -> HeadDimension {
self.head_dimension
}
pub fn rows(&self) -> &[Vector] {
&self.rows
}
}
/// Projected value vectors for one attention head.
#[derive(Debug, Clone, PartialEq)]
pub struct ValueSequence {
sequence_len: SequenceLength,
head_dimension: HeadDimension,
rows: Vec<Vector>,
}
impl ValueSequence {
pub fn new(rows: Vec<Vec<f32>>) -> CtResult<Self> {
let matrix = AttentionVectorRows::new("value sequence", rows)?;
Ok(Self {
sequence_len: matrix.sequence_len,
head_dimension: matrix.head_dimension,
rows: matrix.rows,
})
}
pub fn sequence_len(&self) -> SequenceLength {
self.sequence_len
}
pub fn head_dimension(&self) -> HeadDimension {
self.head_dimension
}
pub fn rows(&self) -> &[Vector] {
&self.rows
}
}
/// Hidden vectors over sequence positions.
#[derive(Debug, Clone, PartialEq)]
pub struct HiddenSequence {
sequence_len: SequenceLength,
model_dimension: ModelDimension,
rows: Vec<Vector>,
}
impl HiddenSequence {
pub fn new(rows: Vec<Vec<f32>>) -> CtResult<Self> {
let matrix = AttentionVectorRows::new("hidden sequence", rows)?;
Ok(Self {
sequence_len: matrix.sequence_len,
model_dimension: ModelDimension::new(matrix.head_dimension.value())?,
rows: matrix.rows,
})
}
pub fn sequence_len(&self) -> SequenceLength {
self.sequence_len
}
pub fn model_dimension(&self) -> ModelDimension {
self.model_dimension
}
pub fn rows(&self) -> &[Vector] {
&self.rows
}
}
/// A finite table of position vectors added to hidden states.
#[derive(Debug, Clone, PartialEq)]
pub struct PositionalEncoding {
max_sequence_len: SequenceLength,
model_dimension: ModelDimension,
rows: Vec<Vector>,
}
impl PositionalEncoding {
pub fn new(rows: Vec<Vec<f32>>) -> CtResult<Self> {
let matrix = AttentionVectorRows::new("positional encoding", rows)?;
Ok(Self {
max_sequence_len: matrix.sequence_len,
model_dimension: ModelDimension::new(matrix.head_dimension.value())?,
rows: matrix.rows,
})
}
pub fn max_sequence_len(&self) -> SequenceLength {
self.max_sequence_len
}
pub fn model_dimension(&self) -> ModelDimension {
self.model_dimension
}
}
#[derive(Debug, Clone, PartialEq)]
struct AttentionVectorRows {
sequence_len: SequenceLength,
head_dimension: HeadDimension,
rows: Vec<Vector>,
}
impl AttentionVectorRows {
fn new(kind: &'static str, rows: Vec<Vec<f32>>) -> CtResult<Self> {
if rows.is_empty() {
return Err(CtError::EmptyInput(kind));
}
let head_dimension = rows[0].len();
if head_dimension == 0 {
return Err(CtError::EmptyInput("attention vector row"));
}
for row in &rows {
if row.len() != head_dimension {
return Err(CtError::ShapeMismatch {
op: kind,
expected: format!("all rows have {head_dimension} columns"),
got: format!("row with {} columns", row.len()),
});
}
if row.iter().any(|value| !value.is_finite()) {
return Err(CtError::ShapeMismatch {
op: kind,
expected: "all vector values are finite".to_string(),
got: "non-finite vector value".to_string(),
});
}
}
Ok(Self {
sequence_len: SequenceLength::new(rows.len())?,
head_dimension: HeadDimension::new(head_dimension)?,
rows: rows.into_iter().map(Vector::new).collect(),
})
}
}
/// Query-by-key scores before row-wise softmax.
#[derive(Debug, Clone, PartialEq)]
pub struct AttentionScores {
query_len: SequenceLength,
key_len: SequenceLength,
rows: Vec<Logits>,
}
impl AttentionScores {
pub fn new(rows: Vec<Vec<f32>>) -> CtResult<Self> {
if rows.is_empty() {
return Err(CtError::EmptyInput("attention scores"));
}
let key_len = rows[0].len();
if key_len == 0 {
return Err(CtError::EmptyInput("attention score row"));
}
for row in &rows {
if row.len() != key_len {
return Err(CtError::ShapeMismatch {
op: "attention scores",
expected: format!("all rows have {key_len} columns"),
got: format!("row with {} columns", row.len()),
});
}
if row.iter().any(|value| !value.is_finite()) {
return Err(CtError::ShapeMismatch {
op: "attention scores",
expected: "all score values are finite".to_string(),
got: "non-finite score value".to_string(),
});
}
}
Ok(Self {
query_len: SequenceLength::new(rows.len())?,
key_len: SequenceLength::new(key_len)?,
rows: rows.into_iter().map(Logits::new).collect(),
})
}
pub fn query_len(&self) -> SequenceLength {
self.query_len
}
pub fn key_len(&self) -> SequenceLength {
self.key_len
}
pub fn rows(&self) -> &[Logits] {
&self.rows
}
}
/// Allowed query-by-key positions before attention softmax.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AttentionMask {
query_len: SequenceLength,
key_len: SequenceLength,
rows: Vec<Vec<bool>>,
}
impl AttentionMask {
pub fn new(rows: Vec<Vec<bool>>) -> CtResult<Self> {
if rows.is_empty() {
return Err(CtError::EmptyInput("attention mask"));
}
let key_len = rows[0].len();
if key_len == 0 {
return Err(CtError::EmptyInput("attention mask row"));
}
for row in &rows {
if row.len() != key_len {
return Err(CtError::ShapeMismatch {
op: "attention mask",
expected: format!("all rows have {key_len} columns"),
got: format!("row with {} columns", row.len()),
});
}
if !row.iter().any(|allowed| *allowed) {
return Err(CtError::EmptyInput("attention mask row allows no keys"));
}
}
Ok(Self {
query_len: SequenceLength::new(rows.len())?,
key_len: SequenceLength::new(key_len)?,
rows,
})
}
pub fn query_len(&self) -> SequenceLength {
self.query_len
}
pub fn key_len(&self) -> SequenceLength {
self.key_len
}
pub fn rows(&self) -> &[Vec<bool>] {
&self.rows
}
}
/// Computes scaled query-key dot-product scores.
#[derive(Debug, Clone)]
pub struct ScaledDotProductScores;
impl Morphism<Product<QuerySequence, KeySequence>, AttentionScores> for ScaledDotProductScores {
fn name(&self) -> &'static str {
"scaled_dot_product_scores"
}
fn apply(&self, input: Product<QuerySequence, KeySequence>) -> CtResult<AttentionScores> {
let (queries, keys) = input.into_parts();
let query_dimension = queries.head_dimension();
let key_dimension = keys.head_dimension();
if query_dimension != key_dimension {
return Err(CtError::ShapeMismatch {
op: "scaled dot-product attention scores",
expected: format!("query head dimension {}", query_dimension.value()),
got: format!("key head dimension {}", key_dimension.value()),
});
}
let scale = (query_dimension.value() as f32).sqrt();
let rows = queries
.rows()
.iter()
.map(|query| {
keys.rows()
.iter()
.map(|key| dot_product(query.as_slice(), key.as_slice()) / scale)
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();
AttentionScores::new(rows)
}
}
fn dot_product(left: &[f32], right: &[f32]) -> f32 {
left.iter()
.zip(right.iter())
.map(|(left, right)| left * right)
.sum()
}
/// Applies a boolean attention mask to score rows before softmax.
#[derive(Debug, Clone)]
pub struct MaskedAttentionScores;
impl Morphism<Product<AttentionScores, AttentionMask>, AttentionScores> for MaskedAttentionScores {
fn name(&self) -> &'static str {
"masked_attention_scores"
}
fn apply(&self, input: Product<AttentionScores, AttentionMask>) -> CtResult<AttentionScores> {
let (scores, mask) = input.into_parts();
if scores.query_len() != mask.query_len() || scores.key_len() != mask.key_len() {
return Err(CtError::ShapeMismatch {
op: "masked attention scores",
expected: format!(
"{} query rows x {} key columns",
scores.query_len().value(),
scores.key_len().value()
),
got: format!(
"{} query rows x {} key columns",
mask.query_len().value(),
mask.key_len().value()
),
});
}
let rows = scores
.rows()
.iter()
.zip(mask.rows())
.map(|(score_row, mask_row)| {
score_row
.as_slice()
.iter()
.zip(mask_row)
.map(|(score, allowed)| if *allowed { *score } else { MASKED_SCORE })
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();
AttentionScores::new(rows)
}
}
/// Row-wise attention probabilities over key positions.
#[derive(Debug, Clone, PartialEq)]
pub struct AttentionWeights {
query_len: SequenceLength,
key_len: SequenceLength,
rows: Vec<Distribution>,
}
impl AttentionWeights {
pub fn new(rows: Vec<Distribution>) -> CtResult<Self> {
if rows.is_empty() {
return Err(CtError::EmptyInput("attention weights"));
}
let key_len = rows[0].as_slice().len();
if key_len == 0 {
return Err(CtError::EmptyInput("attention weight row"));
}
for row in &rows {
if row.as_slice().len() != key_len {
return Err(CtError::ShapeMismatch {
op: "attention weights",
expected: format!("all rows have {key_len} columns"),
got: format!("row with {} columns", row.as_slice().len()),
});
}
}
Ok(Self {
query_len: SequenceLength::new(rows.len())?,
key_len: SequenceLength::new(key_len)?,
rows,
})
}
pub fn query_len(&self) -> SequenceLength {
self.query_len
}
pub fn key_len(&self) -> SequenceLength {
self.key_len
}
pub fn rows(&self) -> &[Distribution] {
&self.rows
}
}
/// Weighted value vectors, one output row per query position.
#[derive(Debug, Clone, PartialEq)]
pub struct AttentionOutput {
sequence_len: SequenceLength,
head_dimension: HeadDimension,
rows: Vec<Vector>,
}
impl AttentionOutput {
pub fn new(rows: Vec<Vec<f32>>) -> CtResult<Self> {
let matrix = AttentionVectorRows::new("attention output", rows)?;
Ok(Self {
sequence_len: matrix.sequence_len,
head_dimension: matrix.head_dimension,
rows: matrix.rows,
})
}
pub fn sequence_len(&self) -> SequenceLength {
self.sequence_len
}
pub fn head_dimension(&self) -> HeadDimension {
self.head_dimension
}
pub fn rows(&self) -> &[Vector] {
&self.rows
}
}
/// Validated collection of single-head attention outputs.
#[derive(Debug, Clone, PartialEq)]
pub struct AttentionHeadOutputs {
head_count: HeadCount,
sequence_len: SequenceLength,
head_dimension: HeadDimension,
heads: Vec<AttentionOutput>,
}
impl AttentionHeadOutputs {
pub fn new(heads: Vec<AttentionOutput>) -> CtResult<Self> {
if heads.is_empty() {
return Err(CtError::EmptyInput("attention head outputs"));
}
let sequence_len = heads[0].sequence_len();
let head_dimension = heads[0].head_dimension();
for head in &heads {
if head.sequence_len() != sequence_len {
return Err(CtError::ShapeMismatch {
op: "attention head outputs",
expected: format!("all heads have {} sequence rows", sequence_len.value()),
got: format!("head with {} sequence rows", head.sequence_len().value()),
});
}
if head.head_dimension() != head_dimension {
return Err(CtError::ShapeMismatch {
op: "attention head outputs",
expected: format!("all heads have dimension {}", head_dimension.value()),
got: format!("head dimension {}", head.head_dimension().value()),
});
}
}
Ok(Self {
head_count: HeadCount::new(heads.len())?,
sequence_len,
head_dimension,
heads,
})
}
pub fn head_count(&self) -> HeadCount {
self.head_count
}
pub fn sequence_len(&self) -> SequenceLength {
self.sequence_len
}
pub fn head_dimension(&self) -> HeadDimension {
self.head_dimension
}
pub fn heads(&self) -> &[AttentionOutput] {
&self.heads
}
}
/// Concatenated output of several attention heads.
#[derive(Debug, Clone, PartialEq)]
pub struct MultiHeadOutput {
sequence_len: SequenceLength,
head_count: HeadCount,
head_dimension: HeadDimension,
model_dimension: ModelDimension,
rows: Vec<Vector>,
}
impl MultiHeadOutput {
fn new(
rows: Vec<Vec<f32>>,
head_count: HeadCount,
head_dimension: HeadDimension,
) -> CtResult<Self> {
let matrix = AttentionVectorRows::new("multi-head output", rows)?;
let expected_dimension = head_count.value() * head_dimension.value();
if matrix.head_dimension.value() != expected_dimension {
return Err(CtError::ShapeMismatch {
op: "multi-head output",
expected: format!("row dimension {expected_dimension}"),
got: format!("row dimension {}", matrix.head_dimension.value()),
});
}
Ok(Self {
sequence_len: matrix.sequence_len,
head_count,
head_dimension,
model_dimension: ModelDimension::new(expected_dimension)?,
rows: matrix.rows,
})
}
pub fn sequence_len(&self) -> SequenceLength {
self.sequence_len
}
pub fn head_count(&self) -> HeadCount {
self.head_count
}
pub fn head_dimension(&self) -> HeadDimension {
self.head_dimension
}
pub fn model_dimension(&self) -> ModelDimension {
self.model_dimension
}
pub fn rows(&self) -> &[Vector] {
&self.rows
}
}
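// Sketch (hypothetical helper, not part of this crate's API): the shape rule
// that `MultiHeadOutput::new` enforces. Concatenating heads places the rows
// for the same position side by side, so each output row has
// head_count * head_dimension columns. The body of `ConcatenateHeads` is not
// shown in this listing. Plain `Vec<Vec<f32>>` stands in for the typed
// `AttentionOutput` rows.
#[allow(dead_code)]
mod roadmap_sketch_concat_heads {
    /// `heads[h][position]` is the row that head `h` produced for `position`.
    /// Returns `None` when the heads disagree on sequence length or row width.
    pub fn concatenate_heads(heads: &[Vec<Vec<f32>>]) -> Option<Vec<Vec<f32>>> {
        let first = heads.first()?;
        let sequence_len = first.len();
        let head_dimension = first.first()?.len();
        let shapes_agree = heads.iter().all(|head| {
            head.len() == sequence_len && head.iter().all(|row| row.len() == head_dimension)
        });
        if !shapes_agree {
            return None;
        }
        // Row `position` of the output is head 0's row, then head 1's row, and so on.
        let rows = (0..sequence_len)
            .map(|position| {
                heads
                    .iter()
                    .flat_map(|head| head[position].iter().copied())
                    .collect::<Vec<f32>>()
            })
            .collect();
        Some(rows)
    }
}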
/// Output sequence after the multi-head output projection.
#[derive(Debug, Clone, PartialEq)]
pub struct ProjectedAttentionOutput {
sequence_len: SequenceLength,
model_dimension: ModelDimension,
rows: Vec<Vector>,
}
impl ProjectedAttentionOutput {
pub fn new(rows: Vec<Vec<f32>>) -> CtResult<Self> {
let matrix = AttentionVectorRows::new("projected attention output", rows)?;
Ok(Self {
sequence_len: matrix.sequence_len,
model_dimension: ModelDimension::new(matrix.head_dimension.value())?,
rows: matrix.rows,
})
}
pub fn sequence_len(&self) -> SequenceLength {
self.sequence_len
}
pub fn model_dimension(&self) -> ModelDimension {
self.model_dimension
}
pub fn rows(&self) -> &[Vector] {
&self.rows
}
}
/// Vocabulary logits for every position in a hidden sequence.
#[derive(Debug, Clone, PartialEq)]
pub struct SequenceLogits {
sequence_len: SequenceLength,
vocab_size: VocabSize,
rows: Vec<Logits>,
}
impl SequenceLogits {
pub fn new(rows: Vec<Vec<f32>>) -> CtResult<Self> {
if rows.is_empty() {
return Err(CtError::EmptyInput("sequence logits"));
}
let vocab_size = rows[0].len();
if vocab_size == 0 {
return Err(CtError::EmptyInput("sequence logits row"));
}
for row in &rows {
if row.len() != vocab_size {
return Err(CtError::ShapeMismatch {
op: "sequence logits",
expected: format!("all rows have {vocab_size} columns"),
got: format!("row with {} columns", row.len()),
});
}
if row.iter().any(|value| !value.is_finite()) {
return Err(CtError::ShapeMismatch {
op: "sequence logits",
expected: "all logit values are finite".to_string(),
got: "non-finite logit value".to_string(),
});
}
}
Ok(Self {
sequence_len: SequenceLength::new(rows.len())?,
vocab_size: VocabSize::new(vocab_size)?,
rows: rows.into_iter().map(Logits::new).collect(),
})
}
pub fn sequence_len(&self) -> SequenceLength {
self.sequence_len
}
pub fn vocab_size(&self) -> VocabSize {
self.vocab_size
}
pub fn rows(&self) -> &[Logits] {
&self.rows
}
}
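// Sketch (hypothetical helper): `SequenceLogits` holds one row of vocabulary
// logits per position, and the readout training step later converts each row
// into probabilities with `Softmax`. The body of `Softmax` is not shown here.
// This sketch assumes the common max-subtraction form. It gives the same
// result as exp(z_i) / sum_j exp(z_j) but does not overflow on large logits.
#[allow(dead_code)]
mod roadmap_sketch_row_softmax {
    pub fn softmax(row: &[f32]) -> Vec<f32> {
        // Subtracting the row maximum leaves the ratios unchanged and keeps exp() finite.
        let max = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = row.iter().map(|&z| (z - max).exp()).collect();
        let total: f32 = exps.iter().sum();
        exps.iter().map(|&e| e / total).collect()
    }

    /// Applies the softmax to each position's logit row independently.
    pub fn softmax_rows(rows: &[Vec<f32>]) -> Vec<Vec<f32>> {
        rows.iter().map(|row| softmax(row)).collect()
    }
}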
/// Learned language-model readout applied to each hidden position.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerReadout {
input_dimension: ModelDimension,
vocab_size: VocabSize,
weight: Vec<Vec<f32>>,
bias: Vec<f32>,
}
impl TransformerReadout {
pub fn new(weight: Vec<Vec<f32>>, bias: Vec<f32>) -> CtResult<Self> {
let (input_dimension, output_dimension) =
validate_linear_parts("transformer readout", &weight, &bias)?;
Ok(Self {
input_dimension,
vocab_size: VocabSize::new(output_dimension.value())?,
weight,
bias,
})
}
pub fn input_dimension(&self) -> ModelDimension {
self.input_dimension
}
pub fn vocab_size(&self) -> VocabSize {
self.vocab_size
}
pub fn weight(&self) -> &[Vec<f32>] {
&self.weight
}
pub fn bias(&self) -> &[f32] {
&self.bias
}
}
/// Learned output projection after head concatenation.
#[derive(Debug, Clone, PartialEq)]
pub struct AttentionOutputProjection {
input_dimension: ModelDimension,
output_dimension: ModelDimension,
weight: Vec<Vec<f32>>,
bias: Vec<f32>,
}
impl AttentionOutputProjection {
pub fn new(weight: Vec<Vec<f32>>, bias: Vec<f32>) -> CtResult<Self> {
if weight.is_empty() {
return Err(CtError::EmptyInput("attention output projection weight"));
}
if bias.is_empty() {
return Err(CtError::EmptyInput("attention output projection bias"));
}
let output_dimension = bias.len();
if bias.iter().any(|value| !value.is_finite()) {
return Err(CtError::ShapeMismatch {
op: "attention output projection",
expected: "finite bias values".to_string(),
got: "non-finite bias value".to_string(),
});
}
for row in &weight {
if row.len() != output_dimension {
return Err(CtError::ShapeMismatch {
op: "attention output projection",
expected: format!("weight rows have {output_dimension} columns"),
got: format!("weight row with {} columns", row.len()),
});
}
if row.iter().any(|value| !value.is_finite()) {
return Err(CtError::ShapeMismatch {
op: "attention output projection",
expected: "finite weight values".to_string(),
got: "non-finite weight value".to_string(),
});
}
}
Ok(Self {
input_dimension: ModelDimension::new(weight.len())?,
output_dimension: ModelDimension::new(output_dimension)?,
weight,
bias,
})
}
pub fn input_dimension(&self) -> ModelDimension {
self.input_dimension
}
pub fn output_dimension(&self) -> ModelDimension {
self.output_dimension
}
pub fn weight(&self) -> &[Vec<f32>] {
&self.weight
}
pub fn bias(&self) -> &[f32] {
&self.bias
}
}
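// Sketch (hypothetical helper): the weight layout that
// `AttentionOutputProjection::new` validates. `weight` has one row per input
// feature and one column per output feature. The readout training step
// indexes `grad_weight[feature][vocab_id]` the same way. An input row x
// therefore maps to y[j] = bias[j] + sum_i x[i] * weight[i][j]. The crate's
// own `project_row` is not shown here, and this sketch assumes it uses the
// same layout.
#[allow(dead_code)]
mod roadmap_sketch_linear_row {
    pub fn project_row(input: &[f32], weight: &[Vec<f32>], bias: &[f32]) -> Option<Vec<f32>> {
        // Reject inputs whose width or column count does not match the parameters.
        if input.len() != weight.len() || weight.iter().any(|row| row.len() != bias.len()) {
            return None;
        }
        let mut output = bias.to_vec();
        for (x, weight_row) in input.iter().zip(weight) {
            for (y, w) in output.iter_mut().zip(weight_row) {
                *y += x * w;
            }
        }
        Some(output)
    }
}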
/// Scale, shift, and epsilon parameters for layer normalization.
#[derive(Debug, Clone, PartialEq)]
pub struct LayerNormParameters {
model_dimension: ModelDimension,
scale: Vec<f32>,
shift: Vec<f32>,
epsilon: NormalizationEpsilon,
}
impl LayerNormParameters {
pub fn new(scale: Vec<f32>, shift: Vec<f32>, epsilon: NormalizationEpsilon) -> CtResult<Self> {
if scale.is_empty() {
return Err(CtError::EmptyInput("layer norm scale"));
}
if shift.is_empty() {
return Err(CtError::EmptyInput("layer norm shift"));
}
if scale.len() != shift.len() {
return Err(CtError::ShapeMismatch {
op: "layer norm parameters",
expected: format!("scale and shift length {}", scale.len()),
got: format!("shift length {}", shift.len()),
});
}
if scale.iter().any(|value| !value.is_finite()) {
return Err(CtError::ShapeMismatch {
op: "layer norm parameters",
expected: "finite scale values".to_string(),
got: "non-finite scale value".to_string(),
});
}
if shift.iter().any(|value| !value.is_finite()) {
return Err(CtError::ShapeMismatch {
op: "layer norm parameters",
expected: "finite shift values".to_string(),
got: "non-finite shift value".to_string(),
});
}
Ok(Self {
model_dimension: ModelDimension::new(scale.len())?,
scale,
shift,
epsilon,
})
}
pub fn identity(model_dimension: ModelDimension) -> Self {
Self {
model_dimension,
scale: vec![1.0; model_dimension.value()],
shift: vec![0.0; model_dimension.value()],
epsilon: NormalizationEpsilon(1e-5),
}
}
pub fn model_dimension(&self) -> ModelDimension {
self.model_dimension
}
pub fn scale(&self) -> &[f32] {
&self.scale
}
pub fn shift(&self) -> &[f32] {
&self.shift
}
pub fn epsilon(&self) -> NormalizationEpsilon {
self.epsilon
}
}
/// Layer normalization over each hidden vector independently.
#[derive(Debug, Clone, PartialEq)]
pub struct LayerNormalization {
parameters: LayerNormParameters,
}
impl LayerNormalization {
pub fn new(parameters: LayerNormParameters) -> Self {
Self { parameters }
}
pub fn model_dimension(&self) -> ModelDimension {
self.parameters.model_dimension()
}
pub fn parameters(&self) -> &LayerNormParameters {
&self.parameters
}
}
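// Sketch (hypothetical helper): the per-row computation that
// `LayerNormalization` stands for. It assumes the usual definition with the
// biased (divide-by-n) variance:
// y_i = scale_i * (x_i - mean) / sqrt(var + epsilon) + shift_i.
// `LayerNormParameters::identity` corresponds to scale = 1 and shift = 0.
// The crate's `apply` body is not shown in this listing.
#[allow(dead_code)]
mod roadmap_sketch_layer_norm {
    pub fn layer_norm_row(row: &[f32], scale: &[f32], shift: &[f32], epsilon: f32) -> Vec<f32> {
        let n = row.len() as f32;
        let mean = row.iter().sum::<f32>() / n;
        let variance = row.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / n;
        // epsilon keeps the division finite when every feature is equal.
        let inv_std = 1.0 / (variance + epsilon).sqrt();
        row.iter()
            .zip(scale.iter().zip(shift))
            .map(|(x, (g, b))| g * (x - mean) * inv_std + b)
            .collect()
    }
}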
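/// Per-position intermediates of the feed-forward sublayer (input, first-layer
/// output before the activation, activation, and final output), kept from the
/// forward pass for use during training.
#[derive(Debug, Clone, PartialEq)]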
struct FeedForwardRowCache {
input: Vec<f32>,
pre_activation: Vec<f32>,
activation: Vec<f32>,
output: Vec<f32>,
}
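/// Forward-pass intermediates of one attention head (projected queries, keys,
/// and values, the attention weights, and the head output), kept for training.
#[derive(Debug, Clone, PartialEq)]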
struct AttentionHeadTrainingCache {
queries: QuerySequence,
keys: KeySequence,
values: ValueSequence,
weights: AttentionWeights,
output: AttentionOutput,
}
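/// Forward-pass intermediates of one masked multi-head block: the block output,
/// the two residual sums taken before each layer normalization, the concatenated
/// head outputs, and the per-head and per-row caches.
#[derive(Debug, Clone, PartialEq)]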
struct MaskedBlockTrainingCache {
output: HiddenSequence,
with_feed_forward: HiddenSequence,
with_attention: HiddenSequence,
multi_head_output: MultiHeadOutput,
attention_heads: Vec<AttentionHeadTrainingCache>,
feed_forward_rows: Vec<FeedForwardRowCache>,
}
/// Position-wise two-layer feed-forward sublayer.
#[derive(Debug, Clone, PartialEq)]
pub struct PositionWiseFeedForward {
input_dimension: ModelDimension,
hidden_dimension: ModelDimension,
output_dimension: ModelDimension,
first_weight: Vec<Vec<f32>>,
first_bias: Vec<f32>,
second_weight: Vec<Vec<f32>>,
second_bias: Vec<f32>,
}
impl PositionWiseFeedForward {
pub fn new(
first_weight: Vec<Vec<f32>>,
first_bias: Vec<f32>,
second_weight: Vec<Vec<f32>>,
second_bias: Vec<f32>,
) -> CtResult<Self> {
let (input_dimension, hidden_dimension) = validate_linear_parts(
"position-wise feed-forward first layer",
&first_weight,
&first_bias,
)?;
let (second_input_dimension, output_dimension) = validate_linear_parts(
"position-wise feed-forward second layer",
&second_weight,
&second_bias,
)?;
if second_input_dimension != hidden_dimension {
return Err(CtError::ShapeMismatch {
op: "position-wise feed-forward",
expected: format!("second input dimension {}", hidden_dimension.value()),
got: format!("second input dimension {}", second_input_dimension.value()),
});
}
if output_dimension != input_dimension {
return Err(CtError::ShapeMismatch {
op: "position-wise feed-forward",
expected: format!("output dimension {}", input_dimension.value()),
got: format!("output dimension {}", output_dimension.value()),
});
}
Ok(Self {
input_dimension,
hidden_dimension,
output_dimension,
first_weight,
first_bias,
second_weight,
second_bias,
})
}
pub fn input_dimension(&self) -> ModelDimension {
self.input_dimension
}
pub fn hidden_dimension(&self) -> ModelDimension {
self.hidden_dimension
}
pub fn output_dimension(&self) -> ModelDimension {
self.output_dimension
}
pub fn first_weight(&self) -> &[Vec<f32>] {
&self.first_weight
}
pub fn first_bias(&self) -> &[f32] {
&self.first_bias
}
pub fn second_weight(&self) -> &[Vec<f32>] {
&self.second_weight
}
pub fn second_bias(&self) -> &[f32] {
&self.second_bias
}
}
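// Sketch (hypothetical helpers): the position-wise feed-forward sublayer
// applies the same two linear maps to every row, with a nonlinearity between
// them. In row-vector form: FFN(x) = act(x W1 + b1) W2 + b2. This listing does
// not show which activation the crate uses. ReLU, the choice in the original
// Transformer paper, is an assumption here. Weights use the
// weight[input][output] layout that the readout update also uses.
#[allow(dead_code)]
mod roadmap_sketch_feed_forward {
    fn linear(input: &[f32], weight: &[Vec<f32>], bias: &[f32]) -> Vec<f32> {
        let mut output = bias.to_vec();
        for (x, row) in input.iter().zip(weight) {
            for (y, w) in output.iter_mut().zip(row) {
                *y += x * w;
            }
        }
        output
    }

    pub fn feed_forward_row(
        input: &[f32],
        first_weight: &[Vec<f32>],
        first_bias: &[f32],
        second_weight: &[Vec<f32>],
        second_bias: &[f32],
    ) -> Vec<f32> {
        let pre_activation = linear(input, first_weight, first_bias);
        // Assumed ReLU activation.
        let activation: Vec<f32> = pre_activation.iter().map(|&z| z.max(0.0)).collect();
        linear(&activation, second_weight, second_bias)
    }

    /// "Position-wise": each row passes through the same parameters independently.
    pub fn feed_forward_rows(
        rows: &[Vec<f32>],
        first_weight: &[Vec<f32>],
        first_bias: &[f32],
        second_weight: &[Vec<f32>],
        second_bias: &[f32],
    ) -> Vec<Vec<f32>> {
        rows.iter()
            .map(|row| feed_forward_row(row, first_weight, first_bias, second_weight, second_bias))
            .collect()
    }
}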
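/// Shared implementation behind `HiddenToQuery`, `HiddenToKey`, and
/// `HiddenToValue`; `op` names the projection in error messages.
#[derive(Debug, Clone, PartialEq)]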
struct HiddenProjection {
op: &'static str,
input_dimension: ModelDimension,
head_dimension: HeadDimension,
weight: Vec<Vec<f32>>,
bias: Vec<f32>,
}
impl HiddenProjection {
fn new(op: &'static str, weight: Vec<Vec<f32>>, bias: Vec<f32>) -> CtResult<Self> {
let (input_dimension, output_dimension) = validate_linear_parts(op, &weight, &bias)?;
Ok(Self {
op,
input_dimension,
head_dimension: HeadDimension::new(output_dimension.value())?,
weight,
bias,
})
}
fn project(&self, input: &HiddenSequence) -> CtResult<Vec<Vec<f32>>> {
if input.model_dimension() != self.input_dimension {
return Err(CtError::ShapeMismatch {
op: self.op,
expected: format!("input dimension {}", self.input_dimension.value()),
got: format!("input dimension {}", input.model_dimension().value()),
});
}
Ok(input
.rows()
.iter()
.map(|row| project_row(row.as_slice(), &self.weight, &self.bias))
.collect::<Vec<_>>())
}
fn input_dimension(&self) -> ModelDimension {
self.input_dimension
}
fn head_dimension(&self) -> HeadDimension {
self.head_dimension
}
fn weight(&self) -> &[Vec<f32>] {
&self.weight
}
fn bias(&self) -> &[f32] {
&self.bias
}
}
/// Learned projection from hidden states to query vectors.
#[derive(Debug, Clone, PartialEq)]
pub struct HiddenToQuery {
projection: HiddenProjection,
}
impl HiddenToQuery {
pub fn new(weight: Vec<Vec<f32>>, bias: Vec<f32>) -> CtResult<Self> {
Ok(Self {
projection: HiddenProjection::new("hidden-to-query projection", weight, bias)?,
})
}
pub fn input_dimension(&self) -> ModelDimension {
self.projection.input_dimension()
}
pub fn head_dimension(&self) -> HeadDimension {
self.projection.head_dimension()
}
pub fn weight(&self) -> &[Vec<f32>] {
self.projection.weight()
}
pub fn bias(&self) -> &[f32] {
self.projection.bias()
}
}
/// Learned projection from hidden states to key vectors.
#[derive(Debug, Clone, PartialEq)]
pub struct HiddenToKey {
projection: HiddenProjection,
}
impl HiddenToKey {
pub fn new(weight: Vec<Vec<f32>>, bias: Vec<f32>) -> CtResult<Self> {
Ok(Self {
projection: HiddenProjection::new("hidden-to-key projection", weight, bias)?,
})
}
pub fn input_dimension(&self) -> ModelDimension {
self.projection.input_dimension()
}
pub fn head_dimension(&self) -> HeadDimension {
self.projection.head_dimension()
}
pub fn weight(&self) -> &[Vec<f32>] {
self.projection.weight()
}
pub fn bias(&self) -> &[f32] {
self.projection.bias()
}
}
/// Learned projection from hidden states to value vectors.
#[derive(Debug, Clone, PartialEq)]
pub struct HiddenToValue {
projection: HiddenProjection,
}
impl HiddenToValue {
pub fn new(weight: Vec<Vec<f32>>, bias: Vec<f32>) -> CtResult<Self> {
Ok(Self {
projection: HiddenProjection::new("hidden-to-value projection", weight, bias)?,
})
}
pub fn input_dimension(&self) -> ModelDimension {
self.projection.input_dimension()
}
pub fn head_dimension(&self) -> HeadDimension {
self.projection.head_dimension()
}
pub fn weight(&self) -> &[Vec<f32>] {
self.projection.weight()
}
pub fn bias(&self) -> &[f32] {
self.projection.bias()
}
}
/// A tiny single-head Transformer block sketch.
#[derive(Debug, Clone, PartialEq)]
pub struct SingleHeadTransformerBlock {
model_dimension: ModelDimension,
query_projection: HiddenToQuery,
key_projection: HiddenToKey,
value_projection: HiddenToValue,
output_projection: AttentionOutputProjection,
attention_norm: LayerNormalization,
feed_forward: PositionWiseFeedForward,
feed_forward_norm: LayerNormalization,
}
impl SingleHeadTransformerBlock {
pub fn new(
query_projection: HiddenToQuery,
key_projection: HiddenToKey,
value_projection: HiddenToValue,
output_projection: AttentionOutputProjection,
attention_norm: LayerNormalization,
feed_forward: PositionWiseFeedForward,
feed_forward_norm: LayerNormalization,
) -> CtResult<Self> {
let model_dimension = query_projection.input_dimension();
validate_projection_input(
"single-head block key projection",
model_dimension,
key_projection.input_dimension(),
)?;
validate_projection_input(
"single-head block value projection",
model_dimension,
value_projection.input_dimension(),
)?;
if query_projection.head_dimension() != key_projection.head_dimension() {
return Err(CtError::ShapeMismatch {
op: "single-head block",
expected: format!(
"query/key head dimension {}",
query_projection.head_dimension().value()
),
got: format!(
"key head dimension {}",
key_projection.head_dimension().value()
),
});
}
if output_projection.input_dimension().value() != value_projection.head_dimension().value()
{
return Err(CtError::ShapeMismatch {
op: "single-head block",
expected: format!(
"output projection input dimension {}",
value_projection.head_dimension().value()
),
got: format!(
"output projection input dimension {}",
output_projection.input_dimension().value()
),
});
}
validate_projection_input(
"single-head block output projection",
model_dimension,
output_projection.output_dimension(),
)?;
validate_projection_input(
"single-head block attention normalization",
model_dimension,
attention_norm.model_dimension(),
)?;
validate_projection_input(
"single-head block feed-forward",
model_dimension,
feed_forward.input_dimension(),
)?;
validate_projection_input(
"single-head block feed-forward normalization",
model_dimension,
feed_forward_norm.model_dimension(),
)?;
Ok(Self {
model_dimension,
query_projection,
key_projection,
value_projection,
output_projection,
attention_norm,
feed_forward,
feed_forward_norm,
})
}
pub fn model_dimension(&self) -> ModelDimension {
self.model_dimension
}
}
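// Sketch (hypothetical helper): the order of operations used by the masked
// block's training pass (`apply_with_training_cache`, later in this listing).
// It is the post-norm arrangement of the original Transformer:
// y = norm1(x + attention(x)), output = norm2(y + feed_forward(y)).
// The forward pass of `SingleHeadTransformerBlock` is not part of this
// excerpt. The sublayers are passed in as closures over plain rows so that
// the sketch does not depend on the crate's types.
#[allow(dead_code)]
mod roadmap_sketch_post_norm_block {
    pub type Rows = Vec<Vec<f32>>;

    /// Element-wise sum used for both residual connections.
    fn add(left: &Rows, right: &Rows) -> Rows {
        left.iter()
            .zip(right)
            .map(|(a, b)| a.iter().zip(b).map(|(x, y)| x + y).collect())
            .collect()
    }

    pub fn post_norm_block(
        hidden: &Rows,
        attention: impl Fn(&Rows) -> Rows,
        attention_norm: impl Fn(&Rows) -> Rows,
        feed_forward: impl Fn(&Rows) -> Rows,
        feed_forward_norm: impl Fn(&Rows) -> Rows,
    ) -> Rows {
        let with_attention = add(hidden, &attention(hidden));
        let normalized = attention_norm(&with_attention);
        let with_feed_forward = add(&normalized, &feed_forward(&normalized));
        feed_forward_norm(&with_feed_forward)
    }
}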
/// Learned query, key, and value projections for one self-attention head.
#[derive(Debug, Clone, PartialEq)]
pub struct SelfAttentionHead {
query_projection: HiddenToQuery,
key_projection: HiddenToKey,
value_projection: HiddenToValue,
}
impl SelfAttentionHead {
pub fn new(
query_projection: HiddenToQuery,
key_projection: HiddenToKey,
value_projection: HiddenToValue,
) -> CtResult<Self> {
let input_dimension = query_projection.input_dimension();
validate_projection_input(
"self-attention head key projection",
input_dimension,
key_projection.input_dimension(),
)?;
validate_projection_input(
"self-attention head value projection",
input_dimension,
value_projection.input_dimension(),
)?;
if query_projection.head_dimension() != key_projection.head_dimension() {
return Err(CtError::ShapeMismatch {
op: "self-attention head",
expected: format!(
"query/key head dimension {}",
query_projection.head_dimension().value()
),
got: format!(
"key head dimension {}",
key_projection.head_dimension().value()
),
});
}
Ok(Self {
query_projection,
key_projection,
value_projection,
})
}
pub fn input_dimension(&self) -> ModelDimension {
self.query_projection.input_dimension()
}
pub fn query_key_dimension(&self) -> HeadDimension {
self.query_projection.head_dimension()
}
pub fn value_dimension(&self) -> HeadDimension {
self.value_projection.head_dimension()
}
pub fn query_projection(&self) -> &HiddenToQuery {
&self.query_projection
}
pub fn key_projection(&self) -> &HiddenToKey {
&self.key_projection
}
pub fn value_projection(&self) -> &HiddenToValue {
&self.value_projection
}
}
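// Sketch (hypothetical helper): the work one `SelfAttentionHead` does after
// its three projections have produced query, key, and value rows. Each query
// scores every key with q.k / sqrt(d_k). Masked positions are excluded before
// the softmax, and the output row is the weighted sum of the value rows. This
// follows the scaled dot-product attention of the original paper. The crate's
// `apply_self_attention_head_with_mask_cache` is not shown, so treat this
// sketch as a reference for the math rather than a copy of that function.
#[allow(dead_code)]
mod roadmap_sketch_attention_head {
    /// `allowed[q][k]` is true when query position q may attend to key position k.
    /// Every query row must allow at least one key. Otherwise the softmax is undefined.
    pub fn scaled_dot_product_attention(
        queries: &[Vec<f32>],
        keys: &[Vec<f32>],
        values: &[Vec<f32>],
        allowed: &[Vec<bool>],
    ) -> Vec<Vec<f32>> {
        let scale = 1.0 / (queries[0].len() as f32).sqrt();
        let value_dimension = values[0].len();
        queries
            .iter()
            .zip(allowed)
            .map(|(query, allowed_row)| {
                // Masked scores are -inf, so exp() gives them zero weight.
                let scores: Vec<f32> = keys
                    .iter()
                    .zip(allowed_row)
                    .map(|(key, &ok)| {
                        if ok {
                            query.iter().zip(key).map(|(q, k)| q * k).sum::<f32>() * scale
                        } else {
                            f32::NEG_INFINITY
                        }
                    })
                    .collect();
                let max = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
                let exps: Vec<f32> = scores.iter().map(|&s| (s - max).exp()).collect();
                let total: f32 = exps.iter().sum();
                let mut output = vec![0.0; value_dimension];
                for (weight, value) in exps.iter().zip(values) {
                    for (o, v) in output.iter_mut().zip(value) {
                        *o += weight / total * v;
                    }
                }
                output
            })
            .collect()
    }
}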
/// A tiny multi-head Transformer block sketch.
#[derive(Debug, Clone, PartialEq)]
pub struct MultiHeadTransformerBlock {
model_dimension: ModelDimension,
head_count: HeadCount,
value_dimension: HeadDimension,
heads: Vec<SelfAttentionHead>,
output_projection: AttentionOutputProjection,
attention_norm: LayerNormalization,
feed_forward: PositionWiseFeedForward,
feed_forward_norm: LayerNormalization,
}
impl MultiHeadTransformerBlock {
pub fn new(
heads: Vec<SelfAttentionHead>,
output_projection: AttentionOutputProjection,
attention_norm: LayerNormalization,
feed_forward: PositionWiseFeedForward,
feed_forward_norm: LayerNormalization,
) -> CtResult<Self> {
if heads.is_empty() {
return Err(CtError::EmptyInput("multi-head block heads"));
}
let model_dimension = heads[0].input_dimension();
let value_dimension = heads[0].value_dimension();
for head in &heads {
validate_projection_input(
"multi-head block head projection",
model_dimension,
head.input_dimension(),
)?;
if head.value_dimension() != value_dimension {
return Err(CtError::ShapeMismatch {
op: "multi-head block",
expected: format!("value head dimension {}", value_dimension.value()),
got: format!("value head dimension {}", head.value_dimension().value()),
});
}
}
let head_count = HeadCount::new(heads.len())?;
let concatenated_dimension =
ModelDimension::new(head_count.value() * value_dimension.value())?;
validate_projection_input(
"multi-head block output projection input",
concatenated_dimension,
output_projection.input_dimension(),
)?;
validate_projection_input(
"multi-head block output projection",
model_dimension,
output_projection.output_dimension(),
)?;
validate_projection_input(
"multi-head block attention normalization",
model_dimension,
attention_norm.model_dimension(),
)?;
validate_projection_input(
"multi-head block feed-forward",
model_dimension,
feed_forward.input_dimension(),
)?;
validate_projection_input(
"multi-head block feed-forward normalization",
model_dimension,
feed_forward_norm.model_dimension(),
)?;
Ok(Self {
model_dimension,
head_count,
value_dimension,
heads,
output_projection,
attention_norm,
feed_forward,
feed_forward_norm,
})
}
pub fn model_dimension(&self) -> ModelDimension {
self.model_dimension
}
pub fn head_count(&self) -> HeadCount {
self.head_count
}
pub fn value_dimension(&self) -> HeadDimension {
self.value_dimension
}
pub fn heads(&self) -> &[SelfAttentionHead] {
&self.heads
}
fn with_heads(self, heads: Vec<SelfAttentionHead>) -> CtResult<Self> {
Self::new(
heads,
self.output_projection,
self.attention_norm,
self.feed_forward,
self.feed_forward_norm,
)
}
fn with_feed_forward(self, feed_forward: PositionWiseFeedForward) -> CtResult<Self> {
Self::new(
self.heads,
self.output_projection,
self.attention_norm,
feed_forward,
self.feed_forward_norm,
)
}
fn with_output_projection(
self,
output_projection: AttentionOutputProjection,
) -> CtResult<Self> {
Self::new(
self.heads,
output_projection,
self.attention_norm,
self.feed_forward,
self.feed_forward_norm,
)
}
fn with_layer_norms(
self,
attention_norm: LayerNormalization,
feed_forward_norm: LayerNormalization,
) -> CtResult<Self> {
Self::new(
self.heads,
self.output_projection,
attention_norm,
self.feed_forward,
feed_forward_norm,
)
}
}
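// Sketch (hypothetical types): why the private `with_*` helpers above consume
// `self` and call `Self::new` instead of assigning a field. Every replacement
// goes back through the constructor, so the shape checks run again. A training
// step therefore cannot install, for example, an output projection whose input
// width no longer equals head_count * value_dimension.
#[allow(dead_code)]
mod roadmap_sketch_rebuild_through_new {
    #[derive(Debug)]
    pub struct Block {
        head_count: usize,
        value_dimension: usize,
        projection_input: usize,
    }

    impl Block {
        pub fn new(head_count: usize, value_dimension: usize, projection_input: usize) -> Result<Self, String> {
            if projection_input != head_count * value_dimension {
                return Err(format!(
                    "expected projection input {}, got {projection_input}",
                    head_count * value_dimension
                ));
            }
            Ok(Self { head_count, value_dimension, projection_input })
        }

        /// Replacing one part validates the whole object again.
        pub fn with_projection_input(self, projection_input: usize) -> Result<Self, String> {
            Self::new(self.head_count, self.value_dimension, projection_input)
        }

        pub fn projection_input(&self) -> usize {
            self.projection_input
        }
    }
}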
/// A tiny masked multi-head Transformer block sketch.
#[derive(Debug, Clone, PartialEq)]
pub struct MaskedMultiHeadTransformerBlock {
block: MultiHeadTransformerBlock,
}
impl MaskedMultiHeadTransformerBlock {
pub fn new(
heads: Vec<SelfAttentionHead>,
output_projection: AttentionOutputProjection,
attention_norm: LayerNormalization,
feed_forward: PositionWiseFeedForward,
feed_forward_norm: LayerNormalization,
) -> CtResult<Self> {
Ok(Self {
block: MultiHeadTransformerBlock::new(
heads,
output_projection,
attention_norm,
feed_forward,
feed_forward_norm,
)?,
})
}
pub fn model_dimension(&self) -> ModelDimension {
self.block.model_dimension()
}
pub fn head_count(&self) -> HeadCount {
self.block.head_count()
}
pub fn value_dimension(&self) -> HeadDimension {
self.block.value_dimension()
}
pub fn heads(&self) -> &[SelfAttentionHead] {
self.block.heads()
}
pub fn feed_forward(&self) -> &PositionWiseFeedForward {
&self.block.feed_forward
}
fn with_feed_forward(self, feed_forward: PositionWiseFeedForward) -> CtResult<Self> {
Ok(Self {
block: self.block.with_feed_forward(feed_forward)?,
})
}
fn with_heads(self, heads: Vec<SelfAttentionHead>) -> CtResult<Self> {
Ok(Self {
block: self.block.with_heads(heads)?,
})
}
fn with_output_projection(
self,
output_projection: AttentionOutputProjection,
) -> CtResult<Self> {
Ok(Self {
block: self.block.with_output_projection(output_projection)?,
})
}
fn with_layer_norms(
self,
attention_norm: LayerNormalization,
feed_forward_norm: LayerNormalization,
) -> CtResult<Self> {
Ok(Self {
block: self
.block
.with_layer_norms(attention_norm, feed_forward_norm)?,
})
}
fn apply_with_training_cache(
&self,
hidden: HiddenSequence,
mask: AttentionMask,
) -> CtResult<MaskedBlockTrainingCache> {
if hidden.model_dimension() != self.block.model_dimension {
return Err(CtError::ShapeMismatch {
op: "masked multi-head block",
expected: format!("model dimension {}", self.block.model_dimension.value()),
got: format!("model dimension {}", hidden.model_dimension().value()),
});
}
let head_caches = self
.block
.heads
.iter()
.map(|head| apply_self_attention_head_with_mask_cache(&hidden, head, Some(&mask)))
.collect::<CtResult<Vec<_>>>()?;
let attention_outputs = head_caches
.iter()
.map(|cache| cache.output.clone())
.collect::<Vec<_>>();
let head_outputs = AttentionHeadOutputs::new(attention_outputs)?;
let multi_head_output = ConcatenateHeads.apply(head_outputs)?;
let projected_attention = self
.block
.output_projection
.apply(multi_head_output.clone())?;
let with_attention = ResidualConnection.apply(Product::new(hidden, projected_attention))?;
let normalized_attention = self.block.attention_norm.apply(with_attention.clone())?;
let (feed_forward_output, feed_forward_rows) =
feed_forward_with_cache(&self.block.feed_forward, &normalized_attention)?;
let with_feed_forward =
ResidualConnection.apply(Product::new(normalized_attention, feed_forward_output))?;
let output = self
.block
.feed_forward_norm
.apply(with_feed_forward.clone())?;
Ok(MaskedBlockTrainingCache {
output,
with_feed_forward,
with_attention,
multi_head_output,
attention_heads: head_caches,
feed_forward_rows,
})
}
}
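// Sketch (hypothetical helper): the kind of `AttentionMask` a decoder-style
// masked block usually receives. For next-token prediction, query position q
// may see key positions 0..=q and nothing after them, so the allowed pattern
// is lower triangular. This listing does not show how the crate's
// `AttentionMask` stores its entries. The `Vec<Vec<bool>>` form below is an
// assumption.
#[allow(dead_code)]
mod roadmap_sketch_causal_mask {
    /// `mask[query][key]` is true when `key <= query`.
    pub fn causal_mask(sequence_len: usize) -> Vec<Vec<bool>> {
        (0..sequence_len)
            .map(|query| (0..sequence_len).map(|key| key <= query).collect())
            .collect()
    }
}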
/// Tiny structured Transformer parameter object for the roadmap.
#[derive(Debug, Clone, PartialEq)]
pub struct TinyTransformerParameters {
positional_encoding: PositionalEncoding,
block: MaskedMultiHeadTransformerBlock,
readout: TransformerReadout,
}
impl TinyTransformerParameters {
pub fn new(
positional_encoding: PositionalEncoding,
block: MaskedMultiHeadTransformerBlock,
readout: TransformerReadout,
) -> CtResult<Self> {
let model_dimension = positional_encoding.model_dimension();
validate_projection_input(
"tiny transformer parameters block",
model_dimension,
block.model_dimension(),
)?;
validate_projection_input(
"tiny transformer parameters readout",
model_dimension,
readout.input_dimension(),
)?;
Ok(Self {
positional_encoding,
block,
readout,
})
}
pub fn model_dimension(&self) -> ModelDimension {
self.positional_encoding.model_dimension()
}
pub fn max_sequence_len(&self) -> SequenceLength {
self.positional_encoding.max_sequence_len()
}
pub fn vocab_size(&self) -> VocabSize {
self.readout.vocab_size()
}
pub fn encode(&self, hidden: HiddenSequence, mask: AttentionMask) -> CtResult<HiddenSequence> {
let positioned = self.positional_encoding.apply(hidden)?;
self.block.apply(Product::new(positioned, mask))
}
pub fn readout(&self) -> &TransformerReadout {
&self.readout
}
pub fn feed_forward(&self) -> &PositionWiseFeedForward {
self.block.feed_forward()
}
pub fn output_projection(&self) -> &AttentionOutputProjection {
&self.block.block.output_projection
}
pub fn attention_heads(&self) -> &[SelfAttentionHead] {
self.block.heads()
}
pub fn attention_norm(&self) -> &LayerNormalization {
&self.block.block.attention_norm
}
pub fn feed_forward_norm(&self) -> &LayerNormalization {
&self.block.block.feed_forward_norm
}
fn with_readout(self, readout: TransformerReadout) -> CtResult<Self> {
Self::new(self.positional_encoding, self.block, readout)
}
fn with_feed_forward(self, feed_forward: PositionWiseFeedForward) -> CtResult<Self> {
Self::new(
self.positional_encoding,
self.block.with_feed_forward(feed_forward)?,
self.readout,
)
}
fn with_attention_heads(self, heads: Vec<SelfAttentionHead>) -> CtResult<Self> {
Self::new(
self.positional_encoding,
self.block.with_heads(heads)?,
self.readout,
)
}
fn with_output_projection(
self,
output_projection: AttentionOutputProjection,
) -> CtResult<Self> {
Self::new(
self.positional_encoding,
self.block.with_output_projection(output_projection)?,
self.readout,
)
}
fn with_layer_norms(
self,
attention_norm: LayerNormalization,
feed_forward_norm: LayerNormalization,
) -> CtResult<Self> {
Self::new(
self.positional_encoding,
self.block
.with_layer_norms(attention_norm, feed_forward_norm)?,
self.readout,
)
}
}
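// Sketch (hypothetical helpers): `TinyTransformerParameters::encode` adds
// position information before the block runs. This listing does not show how
// `PositionalEncoding` is built. If it uses the fixed sinusoidal scheme of the
// original paper, row `pos` has PE[pos][2i] = sin(pos / 10000^(2i/d)) and
// PE[pos][2i+1] = cos(pos / 10000^(2i/d)), and it is added to the hidden row
// at the same position. Both the scheme and the additive `apply` are
// assumptions.
#[allow(dead_code)]
mod roadmap_sketch_positional_encoding {
    pub fn sinusoidal_table(max_sequence_len: usize, model_dimension: usize) -> Vec<Vec<f32>> {
        (0..max_sequence_len)
            .map(|position| {
                (0..model_dimension)
                    .map(|feature| {
                        // Features 2i and 2i+1 share one frequency.
                        let pair = (feature / 2) as f32;
                        let angle = position as f32 / 10_000f32.powf(2.0 * pair / model_dimension as f32);
                        if feature % 2 == 0 { angle.sin() } else { angle.cos() }
                    })
                    .collect()
            })
            .collect()
    }

    /// Adds the encoding row for each position. Returns None if the sequence is longer than the table.
    pub fn add_positions(hidden: &[Vec<f32>], table: &[Vec<f32>]) -> Option<Vec<Vec<f32>>> {
        if hidden.len() > table.len() {
            return None;
        }
        Some(
            hidden
                .iter()
                .zip(table)
                .map(|(row, encoding)| row.iter().zip(encoding).map(|(h, p)| h + p).collect())
                .collect(),
        )
    }
}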
/// Structured state owned by a future Transformer training loop.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerTrainingState {
parameters: TinyTransformerParameters,
learning_rate: LearningRate,
step_count: StepCount,
}
impl TransformerTrainingState {
pub fn new(parameters: TinyTransformerParameters, learning_rate: LearningRate) -> Self {
Self::from_parts(parameters, learning_rate, StepCount::new(0))
}
pub fn from_parts(
parameters: TinyTransformerParameters,
learning_rate: LearningRate,
step_count: StepCount,
) -> Self {
Self {
parameters,
learning_rate,
step_count,
}
}
pub fn parameters(&self) -> &TinyTransformerParameters {
&self.parameters
}
pub fn learning_rate(&self) -> LearningRate {
self.learning_rate
}
pub fn step_count(&self) -> StepCount {
self.step_count
}
pub fn record_updated_parameters(self, parameters: TinyTransformerParameters) -> Self {
Self {
parameters,
learning_rate: self.learning_rate,
step_count: StepCount::new(self.step_count.value() + 1),
}
}
}
/// One supervised sequence example for a readout-only Transformer update.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerReadoutTrainingExample {
hidden: HiddenSequence,
mask: AttentionMask,
targets: TokenSequence,
}
impl TransformerReadoutTrainingExample {
pub fn new(
hidden: HiddenSequence,
mask: AttentionMask,
targets: TokenSequence,
) -> CtResult<Self> {
let sequence_len = hidden.sequence_len();
if targets.as_slice().len() != sequence_len.value() {
return Err(CtError::ShapeMismatch {
op: "transformer readout training targets",
expected: format!("{} target tokens", sequence_len.value()),
got: format!("{} target tokens", targets.as_slice().len()),
});
}
if mask.query_len() != sequence_len || mask.key_len() != sequence_len {
return Err(CtError::ShapeMismatch {
op: "transformer readout training mask",
expected: format!(
"{} query rows x {} key columns",
sequence_len.value(),
sequence_len.value()
),
got: format!(
"{} query rows x {} key columns",
mask.query_len().value(),
mask.key_len().value()
),
});
}
Ok(Self {
hidden,
mask,
targets,
})
}
pub fn hidden(&self) -> &HiddenSequence {
&self.hidden
}
pub fn mask(&self) -> &AttentionMask {
&self.mask
}
pub fn targets(&self) -> &TokenSequence {
&self.targets
}
}
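// Sketch (hypothetical helper): `TransformerReadoutTrainingExample::new`
// requires exactly one target token per hidden position. One common way to
// meet that requirement for next-token training is to shift a token stream by
// one. Position i is fed token i and must predict token i + 1. Whether the
// crate builds its examples this way is not shown here. The inputs would
// still need embedding into a `HiddenSequence` and pairing with a mask of the
// same length.
#[allow(dead_code)]
mod roadmap_sketch_next_token_targets {
    /// Splits a token stream into equal-length (inputs, targets), with
    /// targets[i] being the token that follows inputs[i].
    pub fn shift_for_next_token(tokens: &[usize]) -> Option<(Vec<usize>, Vec<usize>)> {
        // At least two tokens are needed to form a single prediction.
        if tokens.len() < 2 {
            return None;
        }
        let inputs = tokens[..tokens.len() - 1].to_vec();
        let targets = tokens[1..].to_vec();
        Some((inputs, targets))
    }
}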
/// Non-empty set of supervised readout examples.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerReadoutTrainingSet(Vec<TransformerReadoutTrainingExample>);
impl TransformerReadoutTrainingSet {
pub fn new(
examples: impl IntoIterator<Item = TransformerReadoutTrainingExample>,
) -> CtResult<Self> {
let examples = examples.into_iter().collect::<Vec<_>>();
if examples.is_empty() {
return Err(CtError::EmptyInput("transformer readout training set"));
}
Ok(Self(examples))
}
pub fn examples(&self) -> &[TransformerReadoutTrainingExample] {
&self.0
}
}
/// One full-batch update of the sequence readout parameters.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerReadoutTrainStep {
dataset: TransformerReadoutTrainingSet,
}
impl TransformerReadoutTrainStep {
pub fn new(dataset: TransformerReadoutTrainingSet) -> Self {
Self { dataset }
}
}
impl Morphism<TransformerTrainingState, TransformerTrainingState> for TransformerReadoutTrainStep {
fn name(&self) -> &'static str {
"transformer_readout_train_step"
}
fn apply(&self, state: TransformerTrainingState) -> CtResult<TransformerTrainingState> {
let input_dimension = state.parameters.readout().input_dimension().value();
let vocab_size = state.parameters.vocab_size().value();
let mut grad_weight = vec![vec![0.0; vocab_size]; input_dimension];
let mut grad_bias = vec![0.0; vocab_size];
let mut position_count = 0usize;
for example in self.dataset.examples() {
let encoded = state
.parameters
.encode(example.hidden().clone(), example.mask().clone())?;
let logits = state.parameters.readout().apply(encoded.clone())?;
for ((hidden_row, logit_row), target) in encoded
.rows()
.iter()
.zip(logits.rows())
.zip(example.targets().as_slice())
{
let target_index = target.index();
if target_index >= vocab_size {
return Err(CtError::OutOfRange {
kind: "sequence target",
index: target_index,
limit: vocab_size,
});
}
let probabilities = Softmax.apply(logit_row.clone())?;
let mut dlogits = probabilities.as_slice().to_vec();
dlogits[target_index] -= 1.0;
for (vocab_id, dlogit) in dlogits.iter().copied().enumerate() {
grad_bias[vocab_id] += dlogit;
for (feature, hidden_value) in hidden_row.as_slice().iter().copied().enumerate()
{
grad_weight[feature][vocab_id] += hidden_value * dlogit;
}
}
position_count += 1;
}
}
let scale = state.learning_rate().value() / position_count as f32;
let mut updated_weight = state.parameters.readout().weight().to_vec();
let mut updated_bias = state.parameters.readout().bias().to_vec();
for (row, grad_row) in updated_weight.iter_mut().zip(&grad_weight) {
for (weight, grad) in row.iter_mut().zip(grad_row) {
*weight -= scale * grad;
}
}
for (bias, grad) in updated_bias.iter_mut().zip(&grad_bias) {
*bias -= scale * grad;
}
let updated_readout = TransformerReadout::new(updated_weight, updated_bias)?;
let updated_parameters = state.parameters.clone().with_readout(updated_readout)?;
Ok(state.record_updated_parameters(updated_parameters))
}
}
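// Sketch (hypothetical helpers): the arithmetic of `TransformerReadoutTrainStep`
// on plain rows, so the update can be checked in isolation. The rows stand in
// for the encoded hidden states, and the weight layout is
// weight[feature][vocab_id], as in the step above. For softmax cross-entropy,
// the gradient with respect to the logits is the probabilities minus the
// one-hot target. That is why the step subtracts 1.0 at the target index.
// Dividing the learning rate by the number of positions makes the update
// follow the mean per-position loss.
#[allow(dead_code)]
mod roadmap_sketch_readout_step {
    fn softmax(row: &[f32]) -> Vec<f32> {
        let max = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = row.iter().map(|&z| (z - max).exp()).collect();
        let total: f32 = exps.iter().sum();
        exps.iter().map(|&e| e / total).collect()
    }

    fn logits(hidden: &[f32], weight: &[Vec<f32>], bias: &[f32]) -> Vec<f32> {
        let mut output = bias.to_vec();
        for (h, weight_row) in hidden.iter().zip(weight) {
            for (z, w) in output.iter_mut().zip(weight_row) {
                *z += h * w;
            }
        }
        output
    }

    /// Mean cross-entropy of the targets over all positions.
    pub fn mean_loss(rows: &[Vec<f32>], targets: &[usize], weight: &[Vec<f32>], bias: &[f32]) -> f32 {
        let total: f32 = rows
            .iter()
            .zip(targets)
            .map(|(row, &target)| -softmax(&logits(row, weight, bias))[target].ln())
            .sum();
        total / rows.len() as f32
    }

    /// One full-batch gradient step on the readout weight and bias.
    pub fn readout_step(
        rows: &[Vec<f32>],
        targets: &[usize],
        weight: &mut [Vec<f32>],
        bias: &mut [f32],
        learning_rate: f32,
    ) {
        let vocab_size = bias.len();
        let mut grad_weight = vec![vec![0.0f32; vocab_size]; weight.len()];
        let mut grad_bias = vec![0.0f32; vocab_size];
        for (row, &target) in rows.iter().zip(targets) {
            // dL/dlogits = softmax(logits) - one_hot(target)
            let mut dlogits = softmax(&logits(row, weight, bias));
            dlogits[target] -= 1.0;
            for (vocab_id, &dlogit) in dlogits.iter().enumerate() {
                grad_bias[vocab_id] += dlogit;
                for (feature, &hidden_value) in row.iter().enumerate() {
                    grad_weight[feature][vocab_id] += hidden_value * dlogit;
                }
            }
        }
        let scale = learning_rate / rows.len() as f32;
        for (weight_row, grad_row) in weight.iter_mut().zip(&grad_weight) {
            for (w, g) in weight_row.iter_mut().zip(grad_row) {
                *w -= scale * g;
            }
        }
        for (b, g) in bias.iter_mut().zip(&grad_bias) {
            *b -= scale * g;
        }
    }
}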
/// One supervised sequence example for local feed-forward sublayer training.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerFeedForwardTrainingExample {
input: HiddenSequence,
target: HiddenSequence,
}
impl TransformerFeedForwardTrainingExample {
pub fn new(input: HiddenSequence, target: HiddenSequence) -> CtResult<Self> {
if input.sequence_len() != target.sequence_len() {
return Err(CtError::ShapeMismatch {
op: "transformer feed-forward training sequence",
expected: format!("{} target rows", input.sequence_len().value()),
got: format!("{} target rows", target.sequence_len().value()),
});
}
if input.model_dimension() != target.model_dimension() {
return Err(CtError::ShapeMismatch {
op: "transformer feed-forward training dimension",
expected: format!("target dimension {}", input.model_dimension().value()),
got: format!("target dimension {}", target.model_dimension().value()),
});
}
Ok(Self { input, target })
}
pub fn input(&self) -> &HiddenSequence {
&self.input
}
pub fn target(&self) -> &HiddenSequence {
&self.target
}
}
/// Non-empty set of local feed-forward training examples.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerFeedForwardTrainingSet(Vec<TransformerFeedForwardTrainingExample>);
impl TransformerFeedForwardTrainingSet {
pub fn new(
examples: impl IntoIterator<Item = TransformerFeedForwardTrainingExample>,
) -> CtResult<Self> {
let examples = examples.into_iter().collect::<Vec<_>>();
if examples.is_empty() {
return Err(CtError::EmptyInput("transformer feed-forward training set"));
}
Ok(Self(examples))
}
pub fn examples(&self) -> &[TransformerFeedForwardTrainingExample] {
&self.0
}
}
/// One full-batch update of the position-wise feed-forward sublayer.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerFeedForwardTrainStep {
dataset: TransformerFeedForwardTrainingSet,
}
impl TransformerFeedForwardTrainStep {
pub fn new(dataset: TransformerFeedForwardTrainingSet) -> Self {
Self { dataset }
}
}
impl Morphism<TransformerTrainingState, TransformerTrainingState>
for TransformerFeedForwardTrainStep
{
fn name(&self) -> &'static str {
"transformer_feed_forward_train_step"
}
fn apply(&self, state: TransformerTrainingState) -> CtResult<TransformerTrainingState> {
let feed_forward = state.parameters.feed_forward();
let mut gradients = FeedForwardGradients::new(feed_forward);
let mut row_count = 0usize;
for example in self.dataset.examples() {
if example.input().model_dimension() != feed_forward.input_dimension() {
return Err(CtError::ShapeMismatch {
op: "transformer feed-forward train step",
expected: format!("input dimension {}", feed_forward.input_dimension().value()),
got: format!(
"input dimension {}",
example.input().model_dimension().value()
),
});
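}
// Targets are compared against the sublayer output, so their width must match it;
// otherwise the row-wise `zip` below would silently drop components.
if example.target().model_dimension() != feed_forward.output_dimension() {
return Err(CtError::ShapeMismatch {
op: "transformer feed-forward train step",
expected: format!("target dimension {}", feed_forward.output_dimension().value()),
got: format!(
"target dimension {}",
example.target().model_dimension().value()
),
});
}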
let (_output, cache_rows) = feed_forward_with_cache(feed_forward, example.input())?;
for (cache, target_row) in cache_rows.iter().zip(example.target().rows()) {
let d_output = cache
.output
.iter()
.zip(target_row.as_slice())
.map(|(output_value, target_value)| output_value - target_value)
.collect::<Vec<_>>();
gradients.accumulate(feed_forward, cache, &d_output);
row_count += 1;
}
}
let updated_feed_forward =
gradients.apply_to(feed_forward, state.learning_rate(), row_count)?;
let updated_parameters = state
.parameters
.clone()
.with_feed_forward(updated_feed_forward)?;
Ok(state.record_updated_parameters(updated_parameters))
}
}
/// One supervised sequence example for a composed block-level update.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerBlockTrainingExample {
hidden: HiddenSequence,
mask: AttentionMask,
targets: TokenSequence,
}
impl TransformerBlockTrainingExample {
pub fn new(
hidden: HiddenSequence,
mask: AttentionMask,
targets: TokenSequence,
) -> CtResult<Self> {
let sequence_len = hidden.sequence_len();
if targets.as_slice().len() != sequence_len.value() {
return Err(CtError::ShapeMismatch {
op: "transformer block training targets",
expected: format!("{} target tokens", sequence_len.value()),
got: format!("{} target tokens", targets.as_slice().len()),
});
}
if mask.query_len() != sequence_len || mask.key_len() != sequence_len {
return Err(CtError::ShapeMismatch {
op: "transformer block training mask",
expected: format!(
"{} query rows x {} key columns",
sequence_len.value(),
sequence_len.value()
),
got: format!(
"{} query rows x {} key columns",
mask.query_len().value(),
mask.key_len().value()
),
});
}
Ok(Self {
hidden,
mask,
targets,
})
}
pub fn hidden(&self) -> &HiddenSequence {
&self.hidden
}
pub fn mask(&self) -> &AttentionMask {
&self.mask
}
pub fn targets(&self) -> &TokenSequence {
&self.targets
}
}
/// Non-empty set of supervised block-level training examples.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerBlockTrainingSet(Vec<TransformerBlockTrainingExample>);
impl TransformerBlockTrainingSet {
pub fn new(
examples: impl IntoIterator<Item = TransformerBlockTrainingExample>,
) -> CtResult<Self> {
let examples = examples.into_iter().collect::<Vec<_>>();
if examples.is_empty() {
return Err(CtError::EmptyInput("transformer block training set"));
}
Ok(Self(examples))
}
pub fn examples(&self) -> &[TransformerBlockTrainingExample] {
&self.0
}
}
/// One full-batch update through the readout, block sublayers, and attention heads.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerBlockTrainStep {
dataset: TransformerBlockTrainingSet,
}
impl TransformerBlockTrainStep {
pub fn new(dataset: TransformerBlockTrainingSet) -> Self {
Self { dataset }
}
}
impl Morphism<TransformerTrainingState, TransformerTrainingState> for TransformerBlockTrainStep {
fn name(&self) -> &'static str {
"transformer_block_train_step"
}
fn apply(&self, state: TransformerTrainingState) -> CtResult<TransformerTrainingState> {
let readout = state.parameters.readout();
let feed_forward = state.parameters.feed_forward();
let output_projection = state.parameters.output_projection();
let attention_norm = state.parameters.attention_norm();
let feed_forward_norm = state.parameters.feed_forward_norm();
let mut readout_gradients = ReadoutGradients::new(readout);
let mut feed_forward_gradients = FeedForwardGradients::new(feed_forward);
let mut output_projection_gradients =
AttentionOutputProjectionGradients::new(output_projection);
let mut attention_norm_gradients = LayerNormGradients::new(attention_norm.parameters());
let mut feed_forward_norm_gradients =
LayerNormGradients::new(feed_forward_norm.parameters());
let mut attention_head_gradients = state
.parameters
.attention_heads()
.iter()
.map(AttentionHeadGradients::new)
.collect::<Vec<_>>();
let mut position_count = 0usize;
for example in self.dataset.examples() {
let positioned = state
.parameters
.positional_encoding
.apply(example.hidden().clone())?;
let block_cache = state
.parameters
.block
.apply_with_training_cache(positioned.clone(), example.mask().clone())?;
let logits = readout.apply(block_cache.output.clone())?;
let vocab_size = logits.vocab_size().value();
let mut d_multi_head_rows = vec![
vec![0.0; output_projection.input_dimension().value()];
block_cache.multi_head_output.sequence_len().value()
];
for (
position,
(
(((encoded_row, logit_row), with_feed_forward_row), with_attention_row),
((feed_forward_cache, multi_head_row), target),
),
) in block_cache
.output
.rows()
.iter()
.zip(logits.rows())
.zip(block_cache.with_feed_forward.rows())
.zip(block_cache.with_attention.rows())
.zip(
block_cache
.feed_forward_rows
.iter()
.zip(block_cache.multi_head_output.rows())
.zip(example.targets().as_slice()),
)
.enumerate()
{
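// Backward pass for one position, in reverse order of the forward block:
// readout -> feed-forward norm -> feed-forward residual -> attention norm
// -> output projection. Each `accumulate` call adds parameter gradients into
// its accumulator and returns the gradient with respect to its input.
let dlogits =
softmax_cross_entropy_logits_gradient(logit_row, target.index(), vocab_size)?;
let d_encoded =
readout_gradients.accumulate(readout, encoded_row.as_slice(), &dlogits);
let d_with_feed_forward = feed_forward_norm_gradients.accumulate(
&d_encoded,
with_feed_forward_row.as_slice(),
feed_forward_norm.parameters(),
);
// `with_feed_forward = normalized_attention + feed_forward(normalized_attention)`,
// so the residual branch passes `d_with_feed_forward` through unchanged and the
// feed-forward branch adds its input gradient on top.
let d_feed_forward_input = feed_forward_gradients.accumulate(
feed_forward,
feed_forward_cache,
&d_with_feed_forward,
);
let d_normalized_attention = add_rows(&d_with_feed_forward, &d_feed_forward_input);
let d_with_attention = attention_norm_gradients.accumulate(
&d_normalized_attention,
with_attention_row.as_slice(),
attention_norm.parameters(),
);
// `with_attention = positioned + projected_attention`. Only the projected branch
// holds trainable parameters, so the residual gradient into `positioned` is dropped.
d_multi_head_rows[position] = output_projection_gradients.accumulate(
output_projection,
multi_head_row.as_slice(),
&d_with_attention,
);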
position_count += 1;
}
let value_dimension = block_cache.multi_head_output.head_dimension().value();
for (head_index, ((head_gradient, head), head_cache)) in attention_head_gradients
.iter_mut()
.zip(state.parameters.attention_heads())
.zip(&block_cache.attention_heads)
.enumerate()
{
let start = head_index * value_dimension;
let end = start + value_dimension;
let d_head_output_rows = d_multi_head_rows
.iter()
.map(|row| row[start..end].to_vec())
.collect::<Vec<_>>();
head_gradient.accumulate(
head,
&positioned,
head_cache,
example.mask(),
&d_head_output_rows,
)?;
}
}
let updated_readout =
readout_gradients.apply_to(readout, state.learning_rate(), position_count)?;
let updated_feed_forward =
feed_forward_gradients.apply_to(feed_forward, state.learning_rate(), position_count)?;
let updated_output_projection = output_projection_gradients.apply_to(
output_projection,
state.learning_rate(),
position_count,
)?;
let updated_attention_norm = LayerNormalization::new(attention_norm_gradients.apply_to(
attention_norm.parameters(),
state.learning_rate(),
position_count,
)?);
let updated_feed_forward_norm =
LayerNormalization::new(feed_forward_norm_gradients.apply_to(
feed_forward_norm.parameters(),
state.learning_rate(),
position_count,
)?);
let updated_heads = state
.parameters
.attention_heads()
.iter()
.zip(attention_head_gradients)
.map(|(head, gradients)| {
gradients.apply_to(head, state.learning_rate(), position_count)
})
.collect::<CtResult<Vec<_>>>()?;
let updated_parameters = state
.parameters
.clone()
.with_readout(updated_readout)?
.with_feed_forward(updated_feed_forward)?
.with_output_projection(updated_output_projection)?
.with_layer_norms(updated_attention_norm, updated_feed_forward_norm)?
.with_attention_heads(updated_heads)?;
Ok(state.record_updated_parameters(updated_parameters))
}
}
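// Teaching sketch (hypothetical module, not part of the crate's API): the block
// step above calls `LayerNormGradients::accumulate` to carry a gradient back
// through each post-residual layer norm. That body is not shown in this excerpt,
// so the module below is an independent sketch of the standard input gradient for
// `y = gamma * x_hat + beta` with biased (divide-by-n) variance. It assumes an
// epsilon of 1e-5 inside the square root; the crate's `normalize_row` may use a
// different value. The analytic gradient is checked against finite differences.
#[allow(dead_code)]
mod roadmap_sketch_layer_norm_backward {
    /// Assumed epsilon; the crate's `normalize_row` may use a different value.
    const EPS: f32 = 1e-5;

    /// Normalizes one row with biased variance; also returns `1 / sqrt(var + EPS)`.
    fn normalize(x: &[f32]) -> (Vec<f32>, f32) {
        let n = x.len() as f32;
        let mean = x.iter().sum::<f32>() / n;
        let var = x.iter().map(|v| (v - mean) * (v - mean)).sum::<f32>() / n;
        let inv_std = 1.0 / (var + EPS).sqrt();
        (x.iter().map(|v| (v - mean) * inv_std).collect(), inv_std)
    }

    /// Forward pass for one row: `gamma * x_hat + beta`.
    pub fn forward(x: &[f32], gamma: &[f32], beta: &[f32]) -> Vec<f32> {
        let (x_hat, _) = normalize(x);
        x_hat.iter().zip(gamma).zip(beta).map(|((h, g), b)| g * h + b).collect()
    }

    /// Input gradient `inv_std * (g - mean(g) - x_hat * mean(g * x_hat))`, where `g = dy * gamma`.
    pub fn input_gradient(x: &[f32], gamma: &[f32], dy: &[f32]) -> Vec<f32> {
        let n = x.len() as f32;
        let (x_hat, inv_std) = normalize(x);
        let g: Vec<f32> = dy.iter().zip(gamma).map(|(d, w)| d * w).collect();
        let mean_g = g.iter().sum::<f32>() / n;
        let mean_g_x_hat = g.iter().zip(&x_hat).map(|(a, h)| a * h).sum::<f32>() / n;
        g.iter()
            .zip(&x_hat)
            .map(|(gi, hi)| inv_std * (gi - mean_g - hi * mean_g_x_hat))
            .collect()
    }

    /// Central finite difference of the scalar loss `sum_i dy_i * y_i`.
    pub fn numeric_input_gradient(
        x: &[f32],
        gamma: &[f32],
        beta: &[f32],
        dy: &[f32],
        step: f32,
    ) -> Vec<f32> {
        let loss = |row: &[f32]| -> f32 {
            forward(row, gamma, beta).iter().zip(dy).map(|(y, d)| y * d).sum()
        };
        (0..x.len())
            .map(|i| {
                let mut plus = x.to_vec();
                let mut minus = x.to_vec();
                plus[i] += step;
                minus[i] -= step;
                (loss(plus.as_slice()) - loss(minus.as_slice())) / (2.0 * step)
            })
            .collect()
    }
}
// The input gradient always sums to zero: shifting every input by the same
// amount leaves the normalized row, and therefore the loss, unchanged.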
/// Average sequence cross-entropy for the structured Transformer readout.
pub fn transformer_readout_average_loss(
state: &TransformerTrainingState,
dataset: &TransformerReadoutTrainingSet,
) -> CtResult<Loss> {
let mut total = 0.0;
let mut position_count = 0usize;
for example in dataset.examples() {
let logits = state.apply(Product::new(
example.hidden().clone(),
example.mask().clone(),
))?;
let vocab_size = logits.vocab_size().value();
for (logit_row, target) in logits.rows().iter().zip(example.targets().as_slice()) {
let target_index = target.index();
if target_index >= vocab_size {
return Err(CtError::OutOfRange {
kind: "sequence target",
index: target_index,
limit: vocab_size,
});
}
let probabilities = Softmax.apply(logit_row.clone())?;
let probability = probabilities.as_slice()[target_index].max(1e-9);
total += -probability.ln();
position_count += 1;
}
}
Loss::new(total / position_count as f32)
}
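/// Mean of `0.5 * error^2` over every output component, for local feed-forward
/// sublayer training. `TransformerFeedForwardTrainStep` averages its gradient over
/// rows rather than components, so its step is scaled by the output dimension
/// relative to the gradient of this loss.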
pub fn transformer_feed_forward_average_loss(
state: &TransformerTrainingState,
dataset: &TransformerFeedForwardTrainingSet,
) -> CtResult<Loss> {
let feed_forward = state.parameters.feed_forward();
let mut total = 0.0;
let mut value_count = 0usize;
for example in dataset.examples() {
let output = feed_forward.apply(example.input().clone())?;
for (output_row, target_row) in output.rows().iter().zip(example.target().rows()) {
for (output_value, target_value) in
output_row.as_slice().iter().zip(target_row.as_slice())
{
let error = output_value - target_value;
total += 0.5 * error * error;
value_count += 1;
}
}
}
Loss::new(total / value_count as f32)
}
/// Average sequence cross-entropy for the composed block-level training set.
pub fn transformer_block_average_loss(
state: &TransformerTrainingState,
dataset: &TransformerBlockTrainingSet,
) -> CtResult<Loss> {
let mut total = 0.0;
let mut position_count = 0usize;
for example in dataset.examples() {
let logits = state.apply(Product::new(
example.hidden().clone(),
example.mask().clone(),
))?;
let vocab_size = logits.vocab_size().value();
for (logit_row, target) in logits.rows().iter().zip(example.targets().as_slice()) {
let target_index = target.index();
if target_index >= vocab_size {
return Err(CtError::OutOfRange {
kind: "sequence target",
index: target_index,
limit: vocab_size,
});
}
let probabilities = Softmax.apply(logit_row.clone())?;
let probability = probabilities.as_slice()[target_index].max(1e-9);
total += -probability.ln();
position_count += 1;
}
}
Loss::new(total / position_count as f32)
}
/// Applies softmax independently to each query row.
#[derive(Debug, Clone)]
pub struct AttentionSoftmax;
impl Morphism<AttentionScores, AttentionWeights> for AttentionSoftmax {
fn name(&self) -> &'static str {
"attention_softmax"
}
fn apply(&self, scores: AttentionScores) -> CtResult<AttentionWeights> {
let rows = scores
.rows
.into_iter()
.map(|row| Softmax.apply(row))
.collect::<CtResult<Vec<_>>>()?;
AttentionWeights::new(rows)
}
}
/// Mixes value vectors with row-wise attention weights.
#[derive(Debug, Clone)]
pub struct WeightedValueMixing;
impl Morphism<Product<AttentionWeights, ValueSequence>, AttentionOutput> for WeightedValueMixing {
fn name(&self) -> &'static str {
"weighted_value_mixing"
}
fn apply(&self, input: Product<AttentionWeights, ValueSequence>) -> CtResult<AttentionOutput> {
let (weights, values) = input.into_parts();
if weights.key_len() != values.sequence_len() {
return Err(CtError::ShapeMismatch {
op: "weighted value mixing",
expected: format!("{} value rows", weights.key_len().value()),
got: format!("{} value rows", values.sequence_len().value()),
});
}
let value_width = values.head_dimension().value();
let rows = weights
.rows()
.iter()
.map(|weight_row| weighted_sum(weight_row.as_slice(), values.rows(), value_width))
.collect::<Vec<_>>();
AttentionOutput::new(rows)
}
}
fn weighted_sum(weights: &[f32], values: &[Vector], value_width: usize) -> Vec<f32> {
let mut output = vec![0.0; value_width];
for (weight, value) in weights.iter().zip(values.iter()) {
for (output_value, value_component) in output.iter_mut().zip(value.as_slice()) {
*output_value += weight * value_component;
}
}
output
}
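// Teaching sketch (hypothetical module, not part of the crate's API): for one
// query row, scoring, optional masking, `AttentionSoftmax`, and
// `WeightedValueMixing` compose to softmax(q . k_j / sqrt(d_k)) followed by the
// same weighted sum that `weighted_sum` computes. The sketch assumes the usual
// 1 / sqrt(d_k) scale and masks by setting disallowed scores to negative
// infinity before the softmax, a common convention; `ScaledDotProductScores` and
// `MaskedAttentionScores` are not shown in this excerpt and may differ.
// Precondition: at least one key is allowed, otherwise every score is -inf and
// the softmax yields NaN.
#[allow(dead_code)]
mod roadmap_sketch_attention_row {
    /// Attention weights for one query; `allowed[j] == false` masks key `j`.
    pub fn attention_weights(query: &[f32], keys: &[Vec<f32>], allowed: &[bool]) -> Vec<f32> {
        let scale = 1.0 / (query.len() as f32).sqrt();
        let scores: Vec<f32> = keys
            .iter()
            .zip(allowed)
            .map(|(key, &ok)| {
                if ok {
                    query.iter().zip(key).map(|(q, k)| q * k).sum::<f32>() * scale
                } else {
                    f32::NEG_INFINITY
                }
            })
            .collect();
        // Stable softmax; exp(-inf) is exactly 0, so masked keys get zero weight.
        let max = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
        let total: f32 = exps.iter().sum();
        exps.iter().map(|e| e / total).collect()
    }

    /// Same contraction as `weighted_sum`: `output = sum_j w_j * v_j`.
    pub fn mix_values(weights: &[f32], values: &[Vec<f32>]) -> Vec<f32> {
        let width = values.first().map_or(0, |v| v.len());
        let mut output = vec![0.0; width];
        for (w, value) in weights.iter().zip(values) {
            for (out, component) in output.iter_mut().zip(value) {
                *out += w * component;
            }
        }
        output
    }
}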
/// Concatenates single-head outputs along the feature dimension.
#[derive(Debug, Clone)]
pub struct ConcatenateHeads;
impl Morphism<AttentionHeadOutputs, MultiHeadOutput> for ConcatenateHeads {
fn name(&self) -> &'static str {
"concatenate_heads"
}
fn apply(&self, heads: AttentionHeadOutputs) -> CtResult<MultiHeadOutput> {
let rows = (0..heads.sequence_len().value())
.map(|position| {
heads
.heads()
.iter()
.flat_map(|head| head.rows()[position].as_slice().iter().copied())
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();
MultiHeadOutput::new(rows, heads.head_count(), heads.head_dimension())
}
}
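// Teaching sketch (hypothetical module, not part of the crate's API):
// `ConcatenateHeads` lays each position's head outputs side by side in head
// order, so head `h` occupies columns `h * d_v .. (h + 1) * d_v` of the
// concatenated row. `TransformerBlockTrainStep` relies on the same layout when
// it slices `d_multi_head_rows` with `start = head_index * value_dimension`.
// The module shows the round trip on plain vectors, assuming every head has the
// same width `d_v`.
#[allow(dead_code)]
mod roadmap_sketch_head_layout {
    /// Concatenates one position's head rows in head order.
    pub fn concat(head_rows: &[Vec<f32>]) -> Vec<f32> {
        head_rows.iter().flat_map(|row| row.iter().copied()).collect()
    }

    /// Recovers head `head`'s columns from a concatenated row.
    pub fn head_slice(row: &[f32], head: usize, head_dim: usize) -> &[f32] {
        &row[head * head_dim..(head + 1) * head_dim]
    }
}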
impl Morphism<MultiHeadOutput, ProjectedAttentionOutput> for AttentionOutputProjection {
fn name(&self) -> &'static str {
"attention_output_projection"
}
fn apply(&self, input: MultiHeadOutput) -> CtResult<ProjectedAttentionOutput> {
if input.model_dimension() != self.input_dimension {
return Err(CtError::ShapeMismatch {
op: "attention output projection",
expected: format!("input dimension {}", self.input_dimension.value()),
got: format!("input dimension {}", input.model_dimension().value()),
});
}
let rows = input
.rows()
.iter()
.map(|row| project_row(row.as_slice(), &self.weight, &self.bias))
.collect::<Vec<_>>();
ProjectedAttentionOutput::new(rows)
}
}
/// Adds a same-shaped sublayer output back to the hidden sequence.
#[derive(Debug, Clone)]
pub struct ResidualConnection;
impl Morphism<Product<HiddenSequence, ProjectedAttentionOutput>, HiddenSequence>
for ResidualConnection
{
fn name(&self) -> &'static str {
"residual_connection"
}
fn apply(
&self,
input: Product<HiddenSequence, ProjectedAttentionOutput>,
) -> CtResult<HiddenSequence> {
let (hidden, sublayer_output) = input.into_parts();
if hidden.sequence_len() != sublayer_output.sequence_len() {
return Err(CtError::ShapeMismatch {
op: "residual connection",
expected: format!("{} sequence rows", hidden.sequence_len().value()),
got: format!("{} sequence rows", sublayer_output.sequence_len().value()),
});
}
if hidden.model_dimension() != sublayer_output.model_dimension() {
return Err(CtError::ShapeMismatch {
op: "residual connection",
expected: format!("model dimension {}", hidden.model_dimension().value()),
got: format!(
"model dimension {}",
sublayer_output.model_dimension().value()
),
});
}
let rows = hidden
.rows()
.iter()
.zip(sublayer_output.rows())
.map(|(left, right)| add_rows(left.as_slice(), right.as_slice()))
.collect::<Vec<_>>();
HiddenSequence::new(rows)
}
}
impl Morphism<HiddenSequence, HiddenSequence> for LayerNormalization {
fn name(&self) -> &'static str {
"layer_normalization"
}
fn apply(&self, input: HiddenSequence) -> CtResult<HiddenSequence> {
if input.model_dimension() != self.parameters.model_dimension() {
return Err(CtError::ShapeMismatch {
op: "layer normalization",
expected: format!(
"model dimension {}",
self.parameters.model_dimension().value()
),
got: format!("model dimension {}", input.model_dimension().value()),
});
}
let rows = input
.rows()
.iter()
.map(|row| normalize_row(row.as_slice(), &self.parameters))
.collect::<Vec<_>>();
HiddenSequence::new(rows)
}
}
impl Morphism<HiddenSequence, HiddenSequence> for PositionWiseFeedForward {
fn name(&self) -> &'static str {
"position_wise_feed_forward"
}
fn apply(&self, input: HiddenSequence) -> CtResult<HiddenSequence> {
let (output, _cache) = feed_forward_with_cache(self, &input)?;
Ok(output)
}
}
impl Morphism<HiddenSequence, HiddenSequence> for PositionalEncoding {
fn name(&self) -> &'static str {
"positional_encoding"
}
fn apply(&self, input: HiddenSequence) -> CtResult<HiddenSequence> {
if input.sequence_len().value() > self.max_sequence_len.value() {
return Err(CtError::ShapeMismatch {
op: "positional encoding",
expected: format!("at most {} sequence rows", self.max_sequence_len.value()),
got: format!("{} sequence rows", input.sequence_len().value()),
});
}
if input.model_dimension() != self.model_dimension {
return Err(CtError::ShapeMismatch {
op: "positional encoding",
expected: format!("model dimension {}", self.model_dimension.value()),
got: format!("model dimension {}", input.model_dimension().value()),
});
}
let rows = input
.rows()
.iter()
.zip(&self.rows)
.map(|(hidden_row, position_row)| {
add_rows(hidden_row.as_slice(), position_row.as_slice())
})
.collect::<Vec<_>>();
HiddenSequence::new(rows)
}
}
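// Teaching sketch (hypothetical module, not part of the crate's API):
// `PositionalEncoding` adds a stored row to each hidden row and rejects
// sequences longer than its table. How the crate fills that table is not shown
// in this excerpt. One option, from the original Transformer paper, is the
// sinusoidal table PE(pos, 2i) = sin(pos / 10000^(2i / d_model)) and
// PE(pos, 2i + 1) = cos(pos / 10000^(2i / d_model)). The module builds that
// table and applies it row by row, returning `None` where the crate returns a
// shape error.
#[allow(dead_code)]
mod roadmap_sketch_sinusoidal_positions {
    /// Builds a `max_len x model_dim` table of sinusoidal position rows.
    pub fn table(max_len: usize, model_dim: usize) -> Vec<Vec<f32>> {
        (0..max_len)
            .map(|pos| {
                (0..model_dim)
                    .map(|j| {
                        // Columns 2i and 2i + 1 share the frequency 1 / 10000^(2i / d_model).
                        let i = (j / 2) as f32;
                        let angle = pos as f32 / 10_000.0_f32.powf(2.0 * i / model_dim as f32);
                        if j % 2 == 0 {
                            angle.sin()
                        } else {
                            angle.cos()
                        }
                    })
                    .collect()
            })
            .collect()
    }

    /// Adds position rows to hidden rows; `None` when the sequence is longer than the table.
    pub fn add_positions(hidden: &[Vec<f32>], table: &[Vec<f32>]) -> Option<Vec<Vec<f32>>> {
        if hidden.len() > table.len() {
            return None;
        }
        Some(
            hidden
                .iter()
                .zip(table)
                .map(|(row, position)| row.iter().zip(position).map(|(h, p)| h + p).collect())
                .collect(),
        )
    }
}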
impl Morphism<HiddenSequence, SequenceLogits> for TransformerReadout {
fn name(&self) -> &'static str {
"transformer_readout"
}
fn apply(&self, input: HiddenSequence) -> CtResult<SequenceLogits> {
if input.model_dimension() != self.input_dimension {
return Err(CtError::ShapeMismatch {
op: "transformer readout",
expected: format!("input dimension {}", self.input_dimension.value()),
got: format!("input dimension {}", input.model_dimension().value()),
});
}
let rows = input
.rows()
.iter()
.map(|row| project_row(row.as_slice(), &self.weight, &self.bias))
.collect::<Vec<_>>();
SequenceLogits::new(rows)
}
}
impl Morphism<HiddenSequence, QuerySequence> for HiddenToQuery {
fn name(&self) -> &'static str {
"hidden_to_query"
}
fn apply(&self, input: HiddenSequence) -> CtResult<QuerySequence> {
QuerySequence::new(self.projection.project(&input)?)
}
}
impl Morphism<HiddenSequence, KeySequence> for HiddenToKey {
fn name(&self) -> &'static str {
"hidden_to_key"
}
fn apply(&self, input: HiddenSequence) -> CtResult<KeySequence> {
KeySequence::new(self.projection.project(&input)?)
}
}
impl Morphism<HiddenSequence, ValueSequence> for HiddenToValue {
fn name(&self) -> &'static str {
"hidden_to_value"
}
fn apply(&self, input: HiddenSequence) -> CtResult<ValueSequence> {
ValueSequence::new(self.projection.project(&input)?)
}
}
fn apply_self_attention_head(
input: &HiddenSequence,
head: &SelfAttentionHead,
) -> CtResult<AttentionOutput> {
apply_self_attention_head_with_mask(input, head, None)
}
fn apply_self_attention_head_with_mask(
input: &HiddenSequence,
head: &SelfAttentionHead,
mask: Option<&AttentionMask>,
) -> CtResult<AttentionOutput> {
Ok(apply_self_attention_head_with_mask_cache(input, head, mask)?.output)
}
fn apply_self_attention_head_with_mask_cache(
input: &HiddenSequence,
head: &SelfAttentionHead,
mask: Option<&AttentionMask>,
) -> CtResult<AttentionHeadTrainingCache> {
let queries = head.query_projection.apply(input.clone())?;
let keys = head.key_projection.apply(input.clone())?;
let values = head.value_projection.apply(input.clone())?;
let scores = ScaledDotProductScores.apply(Product::new(queries.clone(), keys.clone()))?;
let scores = if let Some(mask) = mask {
MaskedAttentionScores.apply(Product::new(scores, mask.clone()))?
} else {
scores
};
let weights = AttentionSoftmax.apply(scores)?;
let output = WeightedValueMixing.apply(Product::new(weights.clone(), values.clone()))?;
Ok(AttentionHeadTrainingCache {
queries,
keys,
values,
weights,
output,
})
}
impl Morphism<Product<HiddenSequence, HiddenSequence>, HiddenSequence> for ResidualConnection {
fn name(&self) -> &'static str {
"hidden_residual_connection"
}
fn apply(&self, input: Product<HiddenSequence, HiddenSequence>) -> CtResult<HiddenSequence> {
let (left, right) = input.into_parts();
if left.sequence_len() != right.sequence_len() {
return Err(CtError::ShapeMismatch {
op: "hidden residual connection",
expected: format!("{} sequence rows", left.sequence_len().value()),
got: format!("{} sequence rows", right.sequence_len().value()),
});
}
if left.model_dimension() != right.model_dimension() {
return Err(CtError::ShapeMismatch {
op: "hidden residual connection",
expected: format!("model dimension {}", left.model_dimension().value()),
got: format!("model dimension {}", right.model_dimension().value()),
});
}
let rows = left
.rows()
.iter()
.zip(right.rows())
.map(|(left, right)| add_rows(left.as_slice(), right.as_slice()))
.collect::<Vec<_>>();
HiddenSequence::new(rows)
}
}
impl Morphism<HiddenSequence, HiddenSequence> for SingleHeadTransformerBlock {
fn name(&self) -> &'static str {
"single_head_transformer_block"
}
fn apply(&self, input: HiddenSequence) -> CtResult<HiddenSequence> {
if input.model_dimension() != self.model_dimension {
return Err(CtError::ShapeMismatch {
op: "single-head block",
expected: format!("model dimension {}", self.model_dimension.value()),
got: format!("model dimension {}", input.model_dimension().value()),
});
}
let head = SelfAttentionHead::new(
self.query_projection.clone(),
self.key_projection.clone(),
self.value_projection.clone(),
)?;
let attention_output = apply_self_attention_head(&input, &head)?;
let head_outputs = AttentionHeadOutputs::new(vec![attention_output])?;
let multi_head_output = ConcatenateHeads.apply(head_outputs)?;
let projected_attention = self.output_projection.apply(multi_head_output)?;
let with_attention = ResidualConnection.apply(Product::new(input, projected_attention))?;
let normalized_attention = self.attention_norm.apply(with_attention)?;
let feed_forward_output = self.feed_forward.apply(normalized_attention.clone())?;
let with_feed_forward =
ResidualConnection.apply(Product::new(normalized_attention, feed_forward_output))?;
self.feed_forward_norm.apply(with_feed_forward)
}
}
impl Morphism<HiddenSequence, HiddenSequence> for MultiHeadTransformerBlock {
fn name(&self) -> &'static str {
"multi_head_transformer_block"
}
fn apply(&self, input: HiddenSequence) -> CtResult<HiddenSequence> {
if input.model_dimension() != self.model_dimension {
return Err(CtError::ShapeMismatch {
op: "multi-head block",
expected: format!("model dimension {}", self.model_dimension.value()),
got: format!("model dimension {}", input.model_dimension().value()),
});
}
let attention_outputs = self
.heads
.iter()
.map(|head| apply_self_attention_head(&input, head))
.collect::<CtResult<Vec<_>>>()?;
let head_outputs = AttentionHeadOutputs::new(attention_outputs)?;
let multi_head_output = ConcatenateHeads.apply(head_outputs)?;
let projected_attention = self.output_projection.apply(multi_head_output)?;
let with_attention = ResidualConnection.apply(Product::new(input, projected_attention))?;
let normalized_attention = self.attention_norm.apply(with_attention)?;
let feed_forward_output = self.feed_forward.apply(normalized_attention.clone())?;
let with_feed_forward =
ResidualConnection.apply(Product::new(normalized_attention, feed_forward_output))?;
self.feed_forward_norm.apply(with_feed_forward)
}
}
impl Morphism<Product<HiddenSequence, AttentionMask>, HiddenSequence>
for MaskedMultiHeadTransformerBlock
{
fn name(&self) -> &'static str {
"masked_multi_head_transformer_block"
}
fn apply(&self, input: Product<HiddenSequence, AttentionMask>) -> CtResult<HiddenSequence> {
let (hidden, mask) = input.into_parts();
Ok(self.apply_with_training_cache(hidden, mask)?.output)
}
}
impl Morphism<Product<HiddenSequence, AttentionMask>, SequenceLogits>
for TinyTransformerParameters
{
fn name(&self) -> &'static str {
"tiny_transformer_parameters"
}
fn apply(&self, input: Product<HiddenSequence, AttentionMask>) -> CtResult<SequenceLogits> {
let (hidden, mask) = input.into_parts();
let positioned = self.positional_encoding.apply(hidden)?;
let encoded = self.block.apply(Product::new(positioned, mask))?;
self.readout.apply(encoded)
}
}
impl Morphism<Product<HiddenSequence, AttentionMask>, SequenceLogits> for TransformerTrainingState {
fn name(&self) -> &'static str {
"transformer_training_state_forward"
}
fn apply(&self, input: Product<HiddenSequence, AttentionMask>) -> CtResult<SequenceLogits> {
self.parameters.apply(input)
}
}
#[derive(Debug, Clone, PartialEq)]
struct ReadoutGradients {
weight: Vec<Vec<f32>>,
bias: Vec<f32>,
}
impl ReadoutGradients {
fn new(readout: &TransformerReadout) -> Self {
Self {
weight: vec![
vec![0.0; readout.vocab_size().value()];
readout.input_dimension().value()
],
bias: vec![0.0; readout.vocab_size().value()],
}
}
fn accumulate(
&mut self,
readout: &TransformerReadout,
hidden_row: &[f32],
dlogits: &[f32],
) -> Vec<f32> {
let mut d_hidden = vec![0.0; readout.input_dimension().value()];
for (vocab_id, dlogit) in dlogits.iter().copied().enumerate() {
self.bias[vocab_id] += dlogit;
for (feature, hidden_value) in hidden_row.iter().copied().enumerate() {
self.weight[feature][vocab_id] += hidden_value * dlogit;
d_hidden[feature] += readout.weight()[feature][vocab_id] * dlogit;
}
}
d_hidden
}
fn apply_to(
self,
readout: &TransformerReadout,
learning_rate: LearningRate,
position_count: usize,
) -> CtResult<TransformerReadout> {
if position_count == 0 {
return Err(CtError::EmptyInput("readout gradient positions"));
}
let scale = learning_rate.value() / position_count as f32;
let mut updated_weight = readout.weight().to_vec();
let mut updated_bias = readout.bias().to_vec();
for (row, grad_row) in updated_weight.iter_mut().zip(&self.weight) {
for (weight, grad) in row.iter_mut().zip(grad_row) {
*weight -= scale * grad;
}
}
for (bias, grad) in updated_bias.iter_mut().zip(&self.bias) {
*bias -= scale * grad;
}
TransformerReadout::new(updated_weight, updated_bias)
}
}
#[derive(Debug, Clone, PartialEq)]
struct FeedForwardGradients {
first_weight: Vec<Vec<f32>>,
first_bias: Vec<f32>,
second_weight: Vec<Vec<f32>>,
second_bias: Vec<f32>,
}
impl FeedForwardGradients {
fn new(feed_forward: &PositionWiseFeedForward) -> Self {
Self {
first_weight: vec![
vec![0.0; feed_forward.hidden_dimension().value()];
feed_forward.input_dimension().value()
],
first_bias: vec![0.0; feed_forward.hidden_dimension().value()],
second_weight: vec![
vec![0.0; feed_forward.output_dimension().value()];
feed_forward.hidden_dimension().value()
],
second_bias: vec![0.0; feed_forward.output_dimension().value()],
}
}
fn accumulate(
&mut self,
feed_forward: &PositionWiseFeedForward,
cache: &FeedForwardRowCache,
d_output: &[f32],
) -> Vec<f32> {
let mut d_activation = vec![0.0; feed_forward.hidden_dimension().value()];
for (output_id, d_output_value) in d_output.iter().copied().enumerate() {
self.second_bias[output_id] += d_output_value;
for (hidden_id, activation_value) in cache.activation.iter().copied().enumerate() {
self.second_weight[hidden_id][output_id] += activation_value * d_output_value;
d_activation[hidden_id] +=
feed_forward.second_weight()[hidden_id][output_id] * d_output_value;
}
}
let mut d_input = vec![0.0; feed_forward.input_dimension().value()];
for (hidden_id, pre_activation_value) in cache.pre_activation.iter().copied().enumerate() {
let d_pre_activation = if pre_activation_value > 0.0 {
d_activation[hidden_id]
} else {
0.0
};
self.first_bias[hidden_id] += d_pre_activation;
for (input_id, input_value) in cache.input.iter().copied().enumerate() {
self.first_weight[input_id][hidden_id] += input_value * d_pre_activation;
d_input[input_id] +=
feed_forward.first_weight()[input_id][hidden_id] * d_pre_activation;
}
}
d_input
}
fn apply_to(
self,
feed_forward: &PositionWiseFeedForward,
learning_rate: LearningRate,
row_count: usize,
) -> CtResult<PositionWiseFeedForward> {
if row_count == 0 {
return Err(CtError::EmptyInput("feed-forward gradient rows"));
}
let scale = learning_rate.value() / row_count as f32;
let mut updated_first_weight = feed_forward.first_weight().to_vec();
let mut updated_first_bias = feed_forward.first_bias().to_vec();
let mut updated_second_weight = feed_forward.second_weight().to_vec();
let mut updated_second_bias = feed_forward.second_bias().to_vec();
for (row, grad_row) in updated_first_weight.iter_mut().zip(&self.first_weight) {
for (weight, grad) in row.iter_mut().zip(grad_row) {
*weight -= scale * grad;
}
}
for (bias, grad) in updated_first_bias.iter_mut().zip(&self.first_bias) {
*bias -= scale * grad;
}
for (row, grad_row) in updated_second_weight.iter_mut().zip(&self.second_weight) {
for (weight, grad) in row.iter_mut().zip(grad_row) {
*weight -= scale * grad;
}
}
for (bias, grad) in updated_second_bias.iter_mut().zip(&self.second_bias) {
*bias -= scale * grad;
}
PositionWiseFeedForward::new(
updated_first_weight,
updated_first_bias,
updated_second_weight,
updated_second_bias,
)
}
}
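// Teaching sketch (hypothetical module, not part of the crate's API):
// `FeedForwardGradients::accumulate` backpropagates through
// relu(x W1 + b1) W2 + b2 with weights stored as `first_weight[input][hidden]`
// and `second_weight[hidden][output]`, letting gradient through only where the
// pre-activation is positive. The module re-derives the input gradient of the
// half squared error that `TransformerFeedForwardTrainStep` minimizes and checks
// it by finite differences. Plain nested `Vec`s stand in for
// `PositionWiseFeedForward`; the test keeps every pre-activation away from zero,
// where ReLU has no derivative.
#[allow(dead_code)]
mod roadmap_sketch_feed_forward_backward {
    /// Layout mirrors `FeedForwardGradients`: `w1[input][hidden]`, `w2[hidden][output]`.
    pub struct Ffn {
        pub w1: Vec<Vec<f32>>,
        pub b1: Vec<f32>,
        pub w2: Vec<Vec<f32>>,
        pub b2: Vec<f32>,
    }

    impl Ffn {
        /// One row forward: returns (pre-activation, output) of relu(x W1 + b1) W2 + b2.
        pub fn forward(&self, x: &[f32]) -> (Vec<f32>, Vec<f32>) {
            let pre: Vec<f32> = (0..self.b1.len())
                .map(|h| self.b1[h] + x.iter().enumerate().map(|(i, xi)| xi * self.w1[i][h]).sum::<f32>())
                .collect();
            let act: Vec<f32> = pre.iter().map(|p| p.max(0.0)).collect();
            let out: Vec<f32> = (0..self.b2.len())
                .map(|o| self.b2[o] + act.iter().enumerate().map(|(h, a)| a * self.w2[h][o]).sum::<f32>())
                .collect();
            (pre, out)
        }

        /// Half squared error `0.5 * sum (y - t)^2` for one row.
        pub fn loss(&self, x: &[f32], target: &[f32]) -> f32 {
            let (_, out) = self.forward(x);
            out.iter().zip(target).map(|(y, t)| 0.5 * (y - t) * (y - t)).sum()
        }

        /// Gradient of `loss` with respect to `x`, following the same chain as `accumulate`.
        pub fn input_gradient(&self, x: &[f32], target: &[f32]) -> Vec<f32> {
            let (pre, out) = self.forward(x);
            // Matches `d_output = output - target` in the feed-forward train step.
            let d_out: Vec<f32> = out.iter().zip(target).map(|(y, t)| y - t).collect();
            // ReLU gate: gradient flows only where the pre-activation was positive.
            let d_pre: Vec<f32> = pre
                .iter()
                .enumerate()
                .map(|(h, &p)| {
                    if p > 0.0 {
                        d_out.iter().enumerate().map(|(o, d)| self.w2[h][o] * d).sum::<f32>()
                    } else {
                        0.0
                    }
                })
                .collect();
            (0..x.len())
                .map(|i| d_pre.iter().enumerate().map(|(h, d)| self.w1[i][h] * d).sum::<f32>())
                .collect()
        }
    }
}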
#[derive(Debug, Clone, PartialEq)]
struct AttentionOutputProjectionGradients {
weight: Vec<Vec<f32>>,
bias: Vec<f32>,
}
impl AttentionOutputProjectionGradients {
fn new(output_projection: &AttentionOutputProjection) -> Self {
Self {
weight: vec![
vec![0.0; output_projection.output_dimension().value()];
output_projection.input_dimension().value()
],
bias: vec![0.0; output_projection.output_dimension().value()],
}
}
fn accumulate(
&mut self,
output_projection: &AttentionOutputProjection,
multi_head_row: &[f32],
d_projected_attention: &[f32],
) -> Vec<f32> {
let mut d_multi_head = vec![0.0; output_projection.input_dimension().value()];
for (output_id, d_value) in d_projected_attention.iter().copied().enumerate() {
self.bias[output_id] += d_value;
for (input_id, input_value) in multi_head_row.iter().copied().enumerate() {
self.weight[input_id][output_id] += input_value * d_value;
d_multi_head[input_id] += output_projection.weight()[input_id][output_id] * d_value;
}
}
d_multi_head
}
fn apply_to(
self,
output_projection: &AttentionOutputProjection,
learning_rate: LearningRate,
row_count: usize,
) -> CtResult<AttentionOutputProjection> {
if row_count == 0 {
return Err(CtError::EmptyInput(
"attention output projection gradient rows",
));
}
let scale = learning_rate.value() / row_count as f32;
let mut updated_weight = output_projection.weight().to_vec();
let mut updated_bias = output_projection.bias().to_vec();
for (row, grad_row) in updated_weight.iter_mut().zip(&self.weight) {
for (weight, grad) in row.iter_mut().zip(grad_row) {
*weight -= scale * grad;
}
}
for (bias, grad) in updated_bias.iter_mut().zip(&self.bias) {
*bias -= scale * grad;
}
AttentionOutputProjection::new(updated_weight, updated_bias)
}
}
#[derive(Debug, Clone, PartialEq)]
struct HiddenProjectionGradients {
weight: Vec<Vec<f32>>,
bias: Vec<f32>,
}
impl HiddenProjectionGradients {
fn new(projection: &HiddenProjection) -> Self {
Self {
weight: vec![
vec![0.0; projection.head_dimension().value()];
projection.input_dimension().value()
],
bias: vec![0.0; projection.head_dimension().value()],
}
}
fn accumulate(
&mut self,
projection: &HiddenProjection,
hidden_row: &[f32],
d_output: &[f32],
) -> CtResult<()> {
if hidden_row.len() != projection.input_dimension().value() {
return Err(CtError::ShapeMismatch {
op: projection.op,
expected: format!("input dimension {}", projection.input_dimension().value()),
got: format!("input dimension {}", hidden_row.len()),
});
}
if d_output.len() != projection.head_dimension().value() {
return Err(CtError::ShapeMismatch {
op: projection.op,
expected: format!("output dimension {}", projection.head_dimension().value()),
got: format!("output dimension {}", d_output.len()),
});
}
for (output_id, d_output_value) in d_output.iter().copied().enumerate() {
self.bias[output_id] += d_output_value;
for (input_id, input_value) in hidden_row.iter().copied().enumerate() {
self.weight[input_id][output_id] += input_value * d_output_value;
}
}
Ok(())
}
fn updated_parts(
self,
projection: &HiddenProjection,
learning_rate: LearningRate,
row_count: usize,
) -> CtResult<(Vec<Vec<f32>>, Vec<f32>)> {
if row_count == 0 {
return Err(CtError::EmptyInput("hidden projection gradient rows"));
}
let scale = learning_rate.value() / row_count as f32;
let mut updated_weight = projection.weight().to_vec();
let mut updated_bias = projection.bias().to_vec();
for (row, grad_row) in updated_weight.iter_mut().zip(&self.weight) {
for (weight, grad) in row.iter_mut().zip(grad_row) {
*weight -= scale * grad;
}
}
for (bias, grad) in updated_bias.iter_mut().zip(&self.bias) {
*bias -= scale * grad;
}
Ok((updated_weight, updated_bias))
}
}
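// Illustrative sketch (hypothetical helper, not part of the crate's API):
// this is the update rule shared by `apply_to` and `updated_parts` above.
// Gradients are summed over `row_count` rows, so the step divides the learning
// rate by the row count. That makes it a mean-gradient SGD step:
// p <- p - (lr / n) * g_sum. The sketch returns None for zero rows, which
// mirrors the `EmptyInput` guard. The length check is an extra assumption
// added here; the crate code relies on `zip` and on earlier shape validation.
#[allow(dead_code)]
fn sketch_mean_sgd_step(
    parameters: &[f32],
    gradient_sums: &[f32],
    learning_rate: f32,
    row_count: usize,
) -> Option<Vec<f32>> {
    if row_count == 0 || parameters.len() != gradient_sums.len() {
        return None;
    }
    let scale = learning_rate / row_count as f32;
    Some(
        parameters
            .iter()
            .zip(gradient_sums)
            .map(|(p, g)| p - scale * g)
            .collect(),
    )
}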
#[derive(Debug, Clone, PartialEq)]
struct AttentionHeadGradients {
query: HiddenProjectionGradients,
key: HiddenProjectionGradients,
value: HiddenProjectionGradients,
}
impl AttentionHeadGradients {
fn new(head: &SelfAttentionHead) -> Self {
Self {
query: HiddenProjectionGradients::new(&head.query_projection.projection),
key: HiddenProjectionGradients::new(&head.key_projection.projection),
value: HiddenProjectionGradients::new(&head.value_projection.projection),
}
}
fn accumulate(
&mut self,
head: &SelfAttentionHead,
input: &HiddenSequence,
cache: &AttentionHeadTrainingCache,
mask: &AttentionMask,
d_output_rows: &[Vec<f32>],
) -> CtResult<()> {
let sequence_len = input.sequence_len().value();
let value_dimension = head.value_dimension().value();
let query_key_dimension = head.query_key_dimension().value();
if d_output_rows.len() != sequence_len {
return Err(CtError::ShapeMismatch {
op: "attention head gradients",
expected: format!("{sequence_len} output rows"),
got: format!("{} output rows", d_output_rows.len()),
});
}
if mask.query_len().value() != sequence_len || mask.key_len().value() != sequence_len {
return Err(CtError::ShapeMismatch {
op: "attention head gradient mask",
expected: format!("{sequence_len} query rows x {sequence_len} key columns"),
got: format!(
"{} query rows x {} key columns",
mask.query_len().value(),
mask.key_len().value()
),
});
}
let mut d_weights = vec![vec![0.0; sequence_len]; sequence_len];
let mut d_values = vec![vec![0.0; value_dimension]; sequence_len];
for (query_id, d_output) in d_output_rows.iter().enumerate() {
if d_output.len() != value_dimension {
return Err(CtError::ShapeMismatch {
op: "attention head output gradient",
expected: format!("value dimension {value_dimension}"),
got: format!("value dimension {}", d_output.len()),
});
}
for key_id in 0..sequence_len {
let value_row = cache.values.rows()[key_id].as_slice();
let weight = cache.weights.rows()[query_id].as_slice()[key_id];
for value_id in 0..value_dimension {
d_weights[query_id][key_id] += d_output[value_id] * value_row[value_id];
d_values[key_id][value_id] += weight * d_output[value_id];
}
}
}
let mut d_scores = vec![vec![0.0; sequence_len]; sequence_len];
for query_id in 0..sequence_len {
let weight_row = cache.weights.rows()[query_id].as_slice();
let row_dot = d_weights[query_id]
.iter()
.zip(weight_row)
.map(|(grad, weight)| grad * weight)
.sum::<f32>();
for key_id in 0..sequence_len {
if mask.rows()[query_id][key_id] {
d_scores[query_id][key_id] =
weight_row[key_id] * (d_weights[query_id][key_id] - row_dot);
}
}
}
let score_scale = (query_key_dimension as f32).sqrt();
let mut d_queries = vec![vec![0.0; query_key_dimension]; sequence_len];
let mut d_keys = vec![vec![0.0; query_key_dimension]; sequence_len];
for query_id in 0..sequence_len {
let query_row = cache.queries.rows()[query_id].as_slice();
for key_id in 0..sequence_len {
let score_gradient = d_scores[query_id][key_id] / score_scale;
let key_row = cache.keys.rows()[key_id].as_slice();
for feature in 0..query_key_dimension {
d_queries[query_id][feature] += score_gradient * key_row[feature];
d_keys[key_id][feature] += score_gradient * query_row[feature];
}
}
}
for position in 0..sequence_len {
let hidden_row = input.rows()[position].as_slice();
self.query.accumulate(
&head.query_projection.projection,
hidden_row,
&d_queries[position],
)?;
self.key.accumulate(
&head.key_projection.projection,
hidden_row,
&d_keys[position],
)?;
self.value.accumulate(
&head.value_projection.projection,
hidden_row,
&d_values[position],
)?;
}
Ok(())
}
fn apply_to(
self,
head: &SelfAttentionHead,
learning_rate: LearningRate,
row_count: usize,
) -> CtResult<SelfAttentionHead> {
let (query_weight, query_bias) = self.query.updated_parts(
&head.query_projection.projection,
learning_rate,
row_count,
)?;
let (key_weight, key_bias) =
self.key
.updated_parts(&head.key_projection.projection, learning_rate, row_count)?;
let (value_weight, value_bias) = self.value.updated_parts(
&head.value_projection.projection,
learning_rate,
row_count,
)?;
SelfAttentionHead::new(
HiddenToQuery::new(query_weight, query_bias)?,
HiddenToKey::new(key_weight, key_bias)?,
HiddenToValue::new(value_weight, value_bias)?,
)
}
}
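// Illustrative sketch (hypothetical helpers, not part of the crate's API):
// this is the softmax Jacobian-vector product used in
// `AttentionHeadGradients::accumulate`. For one query row with weights
// w = softmax(s) and upstream gradient dw, the score gradient is
// ds_k = w_k * (dw_k - sum_j w_j * dw_j). A masked key has w_k = 0, so its
// score gradient is zero either way. The max-subtraction in `sketch_softmax`
// is this sketch's choice for numerical stability.
#[allow(dead_code)]
fn sketch_softmax(scores: &[f32]) -> Vec<f32> {
    let max = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
    let total: f32 = exps.iter().sum();
    exps.iter().map(|e| e / total).collect()
}

#[allow(dead_code)]
fn sketch_softmax_backward(weights: &[f32], d_weights: &[f32]) -> Vec<f32> {
    // row_dot = sum_j w_j * dw_j, the same quantity as `row_dot` above.
    let row_dot: f32 = weights.iter().zip(d_weights).map(|(w, g)| w * g).sum();
    weights
        .iter()
        .zip(d_weights)
        .map(|(w, g)| w * (g - row_dot))
        .collect()
}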
#[derive(Debug, Clone, PartialEq)]
struct LayerNormGradients {
scale: Vec<f32>,
shift: Vec<f32>,
}
impl LayerNormGradients {
fn new(parameters: &LayerNormParameters) -> Self {
Self {
scale: vec![0.0; parameters.model_dimension().value()],
shift: vec![0.0; parameters.model_dimension().value()],
}
}
fn accumulate(
&mut self,
d_output: &[f32],
input: &[f32],
parameters: &LayerNormParameters,
) -> Vec<f32> {
let stats = layer_norm_stats(input, parameters);
for (feature, (grad, normalized_value)) in
d_output.iter().zip(&stats.normalized).enumerate()
{
self.scale[feature] += grad * normalized_value;
self.shift[feature] += grad;
}
layer_norm_backward_from_stats(d_output, &stats, parameters)
}
fn apply_to(
self,
parameters: &LayerNormParameters,
learning_rate: LearningRate,
row_count: usize,
) -> CtResult<LayerNormParameters> {
if row_count == 0 {
return Err(CtError::EmptyInput("layer norm gradient rows"));
}
let scale = learning_rate.value() / row_count as f32;
let mut updated_scale = parameters.scale().to_vec();
let mut updated_shift = parameters.shift().to_vec();
for (value, grad) in updated_scale.iter_mut().zip(&self.scale) {
*value -= scale * grad;
}
for (value, grad) in updated_shift.iter_mut().zip(&self.shift) {
*value -= scale * grad;
}
LayerNormParameters::new(updated_scale, updated_shift, parameters.epsilon())
}
}
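// Illustrative sketch (hypothetical helper, not part of the crate's API):
// this shows what `LayerNormGradients::accumulate` sums for the affine part of
// layer norm, y = scale * x_hat + shift. Over all rows,
// d_scale[f] = sum of dy[f] * x_hat[f] and d_shift[f] = sum of dy[f].
// The inputs are the already-normalized rows x_hat (the `normalized` field
// that `layer_norm_stats` produces).
#[allow(dead_code)]
fn sketch_layer_norm_affine_gradients(
    normalized_rows: &[Vec<f32>],
    d_output_rows: &[Vec<f32>],
) -> (Vec<f32>, Vec<f32>) {
    let width = normalized_rows.first().map_or(0, Vec::len);
    let mut d_scale = vec![0.0; width];
    let mut d_shift = vec![0.0; width];
    for (x_hat, d_out) in normalized_rows.iter().zip(d_output_rows) {
        for feature in 0..width {
            d_scale[feature] += d_out[feature] * x_hat[feature];
            d_shift[feature] += d_out[feature];
        }
    }
    (d_scale, d_shift)
}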
fn feed_forward_with_cache(
feed_forward: &PositionWiseFeedForward,
input: &HiddenSequence,
) -> CtResult<(HiddenSequence, Vec<FeedForwardRowCache>)> {
if input.model_dimension() != feed_forward.input_dimension() {
return Err(CtError::ShapeMismatch {
op: "position-wise feed-forward",
expected: format!("input dimension {}", feed_forward.input_dimension().value()),
got: format!("input dimension {}", input.model_dimension().value()),
});
}
let mut output_rows = Vec::with_capacity(input.rows().len());
let mut cache_rows = Vec::with_capacity(input.rows().len());
for row in input.rows() {
let input_row = row.as_slice().to_vec();
let pre_activation = project_row(
&input_row,
feed_forward.first_weight(),
feed_forward.first_bias(),
);
let activation = pre_activation
.iter()
.map(|value| value.max(0.0))
.collect::<Vec<_>>();
let output = project_row(
&activation,
feed_forward.second_weight(),
feed_forward.second_bias(),
);
output_rows.push(output.clone());
cache_rows.push(FeedForwardRowCache {
input: input_row,
pre_activation,
activation,
output,
});
}
Ok((HiddenSequence::new(output_rows)?, cache_rows))
}
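// Illustrative sketch (hypothetical helper, not part of the crate's API):
// this shows why `feed_forward_with_cache` keeps `pre_activation`. For one row,
// out = relu(x W1 + b1) W2 + b2, so the input gradient is
// d_act = W2 d_out, then d_pre = d_act gated by the ReLU derivative, then
// d_x = W1 d_pre. Weights are input-major, as in `project_row`. Treating the
// ReLU derivative at exactly 0.0 as zero is this sketch's assumption.
#[allow(dead_code)]
fn sketch_feed_forward_row_input_gradient(
    pre_activation: &[f32],
    first_weight: &[Vec<f32>],
    second_weight: &[Vec<f32>],
    d_output: &[f32],
) -> Vec<f32> {
    // d_act[k] = sum_j W2[k][j] * d_out[j], then the ReLU derivative gates it.
    let d_pre: Vec<f32> = second_weight
        .iter()
        .zip(pre_activation)
        .map(|(row, pre)| {
            let d_act: f32 = row.iter().zip(d_output).map(|(w, g)| w * g).sum();
            if *pre > 0.0 { d_act } else { 0.0 }
        })
        .collect();
    // d_x[i] = sum_k W1[i][k] * d_pre[k].
    first_weight
        .iter()
        .map(|row| row.iter().zip(&d_pre).map(|(w, g)| w * g).sum::<f32>())
        .collect()
}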
fn softmax_cross_entropy_logits_gradient(
logits: &Logits,
target_index: usize,
vocab_size: usize,
) -> CtResult<Vec<f32>> {
if target_index >= vocab_size {
return Err(CtError::OutOfRange {
kind: "sequence target",
index: target_index,
limit: vocab_size,
});
}
let probabilities = Softmax.apply(logits.clone())?;
let mut dlogits = probabilities.as_slice().to_vec();
dlogits[target_index] -= 1.0;
Ok(dlogits)
}
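// Illustrative sketch (hypothetical helper, not part of the crate's API):
// this is the identity behind `softmax_cross_entropy_logits_gradient`. For
// loss = -ln softmax(z)[t], the gradient d loss / d z = softmax(z) - onehot(t),
// so its entries sum to zero. The max-subtraction here is this sketch's
// choice; the crate's `Softmax` implementation is not shown in this section.
#[allow(dead_code)]
fn sketch_cross_entropy_gradient(logits: &[f32], target: usize) -> Option<Vec<f32>> {
    if target >= logits.len() {
        return None;
    }
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|z| (z - max).exp()).collect();
    let total: f32 = exps.iter().sum();
    let mut gradient: Vec<f32> = exps.iter().map(|e| e / total).collect();
    gradient[target] -= 1.0;
    Some(gradient)
}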
#[derive(Debug, Clone, PartialEq)]
struct LayerNormStats {
dimension: f32,
inverse_std: f32,
normalized: Vec<f32>,
}
fn layer_norm_stats(input: &[f32], parameters: &LayerNormParameters) -> LayerNormStats {
let dimension = input.len() as f32;
let mean = input.iter().sum::<f32>() / dimension;
let variance = input
.iter()
.map(|value| {
let centered = value - mean;
centered * centered
})
.sum::<f32>()
/ dimension;
let inverse_std = 1.0 / (variance + parameters.epsilon().value()).sqrt();
let normalized = input
.iter()
.map(|value| (value - mean) * inverse_std)
.collect::<Vec<_>>();
LayerNormStats {
dimension,
inverse_std,
normalized,
}
}
fn layer_norm_backward_from_stats(
d_output: &[f32],
stats: &LayerNormStats,
parameters: &LayerNormParameters,
) -> Vec<f32> {
let d_normalized = d_output
.iter()
.zip(parameters.scale())
.map(|(grad, scale)| grad * scale)
.collect::<Vec<_>>();
let sum_d_normalized = d_normalized.iter().sum::<f32>();
let sum_d_normalized_times_normalized = d_normalized
.iter()
.zip(&stats.normalized)
.map(|(grad, normalized_value)| grad * normalized_value)
.sum::<f32>();
d_normalized
.iter()
.zip(&stats.normalized)
.map(|(grad, normalized_value)| {
(stats.dimension * grad
- sum_d_normalized
- normalized_value * sum_d_normalized_times_normalized)
* stats.inverse_std
/ stats.dimension
})
.collect()
}
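// Illustrative sketch (hypothetical helpers, not part of the crate's API):
// this is a finite-difference check of the formula in
// `layer_norm_backward_from_stats`:
// dx = inv_std / n * (n * g - sum(g) - x_hat * sum(g * x_hat)), where g = dy * scale.
// The shift does not affect dx, so these helpers omit it.
#[allow(dead_code)]
fn sketch_layer_norm(input: &[f32], scale: &[f32], epsilon: f32) -> Vec<f32> {
    let n = input.len() as f32;
    let mean = input.iter().sum::<f32>() / n;
    let variance = input.iter().map(|x| (x - mean) * (x - mean)).sum::<f32>() / n;
    let inverse_std = 1.0 / (variance + epsilon).sqrt();
    input.iter().zip(scale).map(|(x, s)| (x - mean) * inverse_std * s).collect()
}

#[allow(dead_code)]
fn sketch_layer_norm_backward(
    input: &[f32],
    scale: &[f32],
    epsilon: f32,
    d_output: &[f32],
) -> Vec<f32> {
    let n = input.len() as f32;
    let mean = input.iter().sum::<f32>() / n;
    let variance = input.iter().map(|x| (x - mean) * (x - mean)).sum::<f32>() / n;
    let inverse_std = 1.0 / (variance + epsilon).sqrt();
    let x_hat: Vec<f32> = input.iter().map(|x| (x - mean) * inverse_std).collect();
    // g plays the role of `d_normalized` in the crate code.
    let g: Vec<f32> = d_output.iter().zip(scale).map(|(d, s)| d * s).collect();
    let sum_g: f32 = g.iter().sum();
    let sum_g_x_hat: f32 = g.iter().zip(&x_hat).map(|(a, b)| a * b).sum();
    g.iter()
        .zip(&x_hat)
        .map(|(gi, xi)| inverse_std / n * (n * gi - sum_g - xi * sum_g_x_hat))
        .collect()
}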
fn validate_projection_input(
op: &'static str,
expected: ModelDimension,
got: ModelDimension,
) -> CtResult<()> {
if expected != got {
return Err(CtError::ShapeMismatch {
op,
expected: format!("model dimension {}", expected.value()),
got: format!("model dimension {}", got.value()),
});
}
Ok(())
}
fn add_rows(left: &[f32], right: &[f32]) -> Vec<f32> {
left.iter()
.zip(right.iter())
.map(|(left, right)| left + right)
.collect()
}
fn validate_linear_parts(
op: &'static str,
weight: &[Vec<f32>],
bias: &[f32],
) -> CtResult<(ModelDimension, ModelDimension)> {
if weight.is_empty() {
return Err(CtError::EmptyInput("linear weight"));
}
if bias.is_empty() {
return Err(CtError::EmptyInput("linear bias"));
}
let output_dimension = bias.len();
if bias.iter().any(|value| !value.is_finite()) {
return Err(CtError::ShapeMismatch {
op,
expected: "finite bias values".to_string(),
got: "non-finite bias value".to_string(),
});
}
for row in weight {
if row.len() != output_dimension {
return Err(CtError::ShapeMismatch {
op,
expected: format!("weight rows have {output_dimension} columns"),
got: format!("weight row with {} columns", row.len()),
});
}
if row.iter().any(|value| !value.is_finite()) {
return Err(CtError::ShapeMismatch {
op,
expected: "finite weight values".to_string(),
got: "non-finite weight value".to_string(),
});
}
}
Ok((
ModelDimension::new(weight.len())?,
ModelDimension::new(output_dimension)?,
))
}
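// Illustrative sketch (hypothetical helper, not part of the crate's API):
// this is the shape contract that `validate_linear_parts` enforces, written
// with plain `usize` and `String` instead of the crate's `ModelDimension` and
// `CtError`. An input-major weight has one row per input feature. Every row
// must have one column per bias entry, and every value must be finite.
#[allow(dead_code)]
fn sketch_linear_dimensions(weight: &[Vec<f32>], bias: &[f32]) -> Result<(usize, usize), String> {
    if weight.is_empty() || bias.is_empty() {
        return Err("empty weight or bias".to_string());
    }
    if bias.iter().any(|v| !v.is_finite()) {
        return Err("non-finite bias value".to_string());
    }
    for row in weight {
        if row.len() != bias.len() {
            return Err(format!(
                "weight row with {} columns, expected {}",
                row.len(),
                bias.len()
            ));
        }
        if row.iter().any(|v| !v.is_finite()) {
            return Err("non-finite weight value".to_string());
        }
    }
    // (input dimension, output dimension)
    Ok((weight.len(), bias.len()))
}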
fn normalize_row(input: &[f32], parameters: &LayerNormParameters) -> Vec<f32> {
let mean = input.iter().sum::<f32>() / input.len() as f32;
let variance = input
.iter()
.map(|value| {
let centered = value - mean;
centered * centered
})
.sum::<f32>()
/ input.len() as f32;
let denominator = (variance + parameters.epsilon.value()).sqrt();
input
.iter()
.zip(parameters.scale.iter().zip(&parameters.shift))
.map(|(value, (scale, shift))| ((value - mean) / denominator) * scale + shift)
.collect()
}
fn project_row(input: &[f32], weight: &[Vec<f32>], bias: &[f32]) -> Vec<f32> {
let mut output = bias.to_vec();
for (feature, input_value) in input.iter().enumerate() {
for (output_value, weight_value) in output.iter_mut().zip(&weight[feature]) {
*output_value += input_value * weight_value;
}
}
output
}
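// Illustrative sketch (hypothetical helper, not part of the crate's API):
// `project_row` computes output[j] = bias[j] + sum_i input[i] * weight[i][j]
// and trusts its callers to pass matching shapes. It panics if `input` has
// more features than `weight` has rows. If a weight row is shorter than
// `bias`, `zip` silently skips the extra columns. This variant returns None
// on either mismatch instead.
#[allow(dead_code)]
fn sketch_checked_project_row(input: &[f32], weight: &[Vec<f32>], bias: &[f32]) -> Option<Vec<f32>> {
    if input.len() != weight.len() || weight.iter().any(|row| row.len() != bias.len()) {
        return None;
    }
    let mut output = bias.to_vec();
    for (input_value, weight_row) in input.iter().zip(weight) {
        for (output_value, weight_value) in output.iter_mut().zip(weight_row) {
            *output_value += input_value * weight_value;
        }
    }
    Some(output)
}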
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn scaled_dot_product_scores_build_query_by_key_rows() -> CtResult<()> {
let queries = QuerySequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
let keys = KeySequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]])?;
let scores = ScaledDotProductScores.apply(Product::new(queries, keys))?;
assert_eq!(scores.query_len().value(), 2);
assert_eq!(scores.key_len().value(), 3);
assert!(crate::domain::approx_eq(
scores.rows()[0].as_slice()[0],
std::f32::consts::FRAC_1_SQRT_2,
1e-4
));
assert!(crate::domain::approx_eq(
scores.rows()[0].as_slice()[1],
0.0,
1e-4
));
assert!(crate::domain::approx_eq(
scores.rows()[1].as_slice()[2],
std::f32::consts::FRAC_1_SQRT_2,
1e-4
));
Ok(())
}
#[test]
fn weighted_value_mixing_builds_one_output_per_query() -> CtResult<()> {
let queries = QuerySequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
let keys = KeySequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]])?;
let values = ValueSequence::new(vec![vec![1.0, 10.0], vec![2.0, 20.0], vec![3.0, 30.0]])?;
let scores = ScaledDotProductScores.apply(Product::new(queries, keys))?;
let weights = AttentionSoftmax.apply(scores)?;
let output = WeightedValueMixing.apply(Product::new(weights, values))?;
assert_eq!(output.sequence_len().value(), 2);
assert_eq!(output.head_dimension().value(), 2);
assert!(crate::domain::approx_eq(
output.rows()[0].as_slice()[0],
2.0,
1e-4
));
assert!(crate::domain::approx_eq(
output.rows()[0].as_slice()[1],
20.0,
1e-4
));
Ok(())
}
#[test]
fn concatenate_heads_preserves_sequence_and_concatenates_features() -> CtResult<()> {
let head_a = AttentionOutput::new(vec![vec![1.0, 10.0], vec![2.0, 20.0]])?;
let head_b = AttentionOutput::new(vec![vec![3.0, 30.0], vec![4.0, 40.0]])?;
let heads = AttentionHeadOutputs::new(vec![head_a, head_b])?;
let output = ConcatenateHeads.apply(heads)?;
assert_eq!(output.sequence_len().value(), 2);
assert_eq!(output.head_count().value(), 2);
assert_eq!(output.head_dimension().value(), 2);
assert_eq!(output.model_dimension().value(), 4);
assert_eq!(output.rows()[0].as_slice(), &[1.0, 10.0, 3.0, 30.0]);
assert_eq!(output.rows()[1].as_slice(), &[2.0, 20.0, 4.0, 40.0]);
Ok(())
}
#[test]
fn attention_output_projection_maps_multi_head_rows() -> CtResult<()> {
let head_a = AttentionOutput::new(vec![vec![1.0, 10.0], vec![2.0, 20.0]])?;
let head_b = AttentionOutput::new(vec![vec![3.0, 30.0], vec![4.0, 40.0]])?;
let multi_head =
ConcatenateHeads.apply(AttentionHeadOutputs::new(vec![head_a, head_b])?)?;
let projection = AttentionOutputProjection::new(
vec![
vec![1.0, 0.0],
vec![0.0, 0.1],
vec![0.5, 0.0],
vec![0.0, 0.01],
],
vec![0.0, 1.0],
)?;
let output = projection.apply(multi_head)?;
assert_eq!(output.sequence_len().value(), 2);
assert_eq!(output.model_dimension().value(), 2);
assert!(crate::domain::approx_eq(
output.rows()[0].as_slice()[0],
2.5,
1e-4
));
assert!(crate::domain::approx_eq(
output.rows()[0].as_slice()[1],
2.3,
1e-4
));
assert!(crate::domain::approx_eq(
output.rows()[1].as_slice()[0],
4.0,
1e-4
));
assert!(crate::domain::approx_eq(
output.rows()[1].as_slice()[1],
3.4,
1e-4
));
Ok(())
}
#[test]
fn residual_connection_adds_matching_sequences() -> CtResult<()> {
let hidden = HiddenSequence::new(vec![vec![1.0, 2.0], vec![3.0, 4.0]])?;
let sublayer_output = ProjectedAttentionOutput::new(vec![vec![0.5, 1.5], vec![2.5, 3.5]])?;
let output = ResidualConnection.apply(Product::new(hidden, sublayer_output))?;
assert_eq!(output.sequence_len().value(), 2);
assert_eq!(output.model_dimension().value(), 2);
assert_eq!(output.rows()[0].as_slice(), &[1.5, 3.5]);
assert_eq!(output.rows()[1].as_slice(), &[5.5, 7.5]);
Ok(())
}
#[test]
fn layer_normalization_preserves_shape_and_normalizes_each_row() -> CtResult<()> {
let hidden = HiddenSequence::new(vec![vec![1.0, 3.0], vec![2.0, 4.0]])?;
let norm = LayerNormalization::new(LayerNormParameters::identity(ModelDimension::new(2)?));
let output = norm.apply(hidden)?;
assert_eq!(output.sequence_len().value(), 2);
assert_eq!(output.model_dimension().value(), 2);
assert!(crate::domain::approx_eq(
output.rows()[0].as_slice()[0],
-0.999995,
1e-4
));
assert!(crate::domain::approx_eq(
output.rows()[0].as_slice()[1],
0.999995,
1e-4
));
assert!(crate::domain::approx_eq(
output.rows()[1].as_slice()[0],
-0.999995,
1e-4
));
assert!(crate::domain::approx_eq(
output.rows()[1].as_slice()[1],
0.999995,
1e-4
));
Ok(())
}
#[test]
fn layer_normalization_applies_scale_and_shift() -> CtResult<()> {
let hidden = HiddenSequence::new(vec![vec![1.0, 3.0]])?;
let params = LayerNormParameters::new(
vec![2.0, 0.5],
vec![1.0, -1.0],
NormalizationEpsilon::new(1e-5)?,
)?;
let norm = LayerNormalization::new(params);
let output = norm.apply(hidden)?;
assert!(crate::domain::approx_eq(
output.rows()[0].as_slice()[0],
-0.99999,
1e-4
));
assert!(crate::domain::approx_eq(
output.rows()[0].as_slice()[1],
-0.5000025,
1e-4
));
Ok(())
}
#[test]
fn position_wise_feed_forward_maps_each_row_and_preserves_shape() -> CtResult<()> {
let hidden = HiddenSequence::new(vec![vec![1.0, 2.0], vec![3.0, 4.0]])?;
let feed_forward = PositionWiseFeedForward::new(
vec![vec![1.0, -1.0, 0.5], vec![0.0, 1.0, 0.5]],
vec![0.0, 0.0, 0.0],
vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![0.5, 0.5]],
vec![0.0, 0.0],
)?;
let output = feed_forward.apply(hidden)?;
assert_eq!(output.sequence_len().value(), 2);
assert_eq!(output.model_dimension().value(), 2);
assert!(crate::domain::approx_eq(
output.rows()[0].as_slice()[0],
1.75,
1e-4
));
assert!(crate::domain::approx_eq(
output.rows()[0].as_slice()[1],
1.75,
1e-4
));
assert!(crate::domain::approx_eq(
output.rows()[1].as_slice()[0],
4.75,
1e-4
));
assert!(crate::domain::approx_eq(
output.rows()[1].as_slice()[1],
2.75,
1e-4
));
Ok(())
}
#[test]
fn attention_mask_removes_disallowed_positions_before_softmax() -> CtResult<()> {
let scores = AttentionScores::new(vec![vec![2.0, 1.0, 2.0]])?;
let mask = AttentionMask::new(vec![vec![true, false, true]])?;
let masked_scores = MaskedAttentionScores.apply(Product::new(scores, mask))?;
let weights = AttentionSoftmax.apply(masked_scores)?;
assert!(crate::domain::approx_eq(
weights.rows()[0].as_slice()[0],
0.5,
1e-4
));
assert!(crate::domain::approx_eq(
weights.rows()[0].as_slice()[1],
0.0,
1e-4
));
assert!(crate::domain::approx_eq(
weights.rows()[0].as_slice()[2],
0.5,
1e-4
));
Ok(())
}
#[test]
fn attention_softmax_normalizes_each_query_row() -> CtResult<()> {
let scores = AttentionScores::new(vec![vec![2.0, 1.0], vec![0.0, 3.0]])?;
let weights = AttentionSoftmax.apply(scores)?;
assert_eq!(weights.query_len().value(), 2);
assert_eq!(weights.key_len().value(), 2);
for row in weights.rows() {
let sum: f32 = row.as_slice().iter().sum();
assert!(crate::domain::approx_eq(sum, 1.0, 1e-4));
}
Ok(())
}
#[test]
fn attention_scores_reject_non_finite_values() {
assert!(matches!(
AttentionScores::new(vec![vec![1.0, f32::NAN]]),
Err(CtError::ShapeMismatch {
op: "attention scores",
..
})
));
}
#[test]
fn attention_scores_reject_ragged_rows() {
assert!(matches!(
AttentionScores::new(vec![vec![1.0, 2.0], vec![3.0]]),
Err(CtError::ShapeMismatch {
op: "attention scores",
..
})
));
}
#[test]
fn attention_mask_rejects_rows_with_no_allowed_keys() {
assert!(matches!(
AttentionMask::new(vec![vec![false, false]]),
Err(CtError::EmptyInput("attention mask row allows no keys"))
));
}
#[test]
fn masked_attention_scores_reject_shape_mismatch() -> CtResult<()> {
let scores = AttentionScores::new(vec![vec![1.0, 2.0]])?;
let mask = AttentionMask::new(vec![vec![true, true], vec![true, true]])?;
assert!(matches!(
MaskedAttentionScores.apply(Product::new(scores, mask)),
Err(CtError::ShapeMismatch {
op: "masked attention scores",
..
})
));
Ok(())
}
#[test]
fn query_sequence_rejects_ragged_rows() {
assert!(matches!(
QuerySequence::new(vec![vec![1.0, 2.0], vec![3.0]]),
Err(CtError::ShapeMismatch {
op: "query sequence",
..
})
));
}
#[test]
fn value_sequence_rejects_empty_rows() {
assert!(matches!(
ValueSequence::new(vec![Vec::new()]),
Err(CtError::EmptyInput("attention vector row"))
));
}
#[test]
fn key_sequence_rejects_non_finite_values() {
assert!(matches!(
KeySequence::new(vec![vec![1.0, f32::NAN]]),
Err(CtError::ShapeMismatch {
op: "key sequence",
..
})
));
}
#[test]
fn scaled_dot_product_rejects_mismatched_head_dimensions() -> CtResult<()> {
let queries = QuerySequence::new(vec![vec![1.0, 0.0]])?;
let keys = KeySequence::new(vec![vec![1.0, 0.0, 1.0]])?;
assert!(matches!(
ScaledDotProductScores.apply(Product::new(queries, keys)),
Err(CtError::ShapeMismatch {
op: "scaled dot-product attention scores",
..
})
));
Ok(())
}
#[test]
fn weighted_value_mixing_rejects_value_length_mismatch() -> CtResult<()> {
let weights = AttentionWeights::new(vec![Distribution::new(vec![0.5, 0.5])?])?;
let values = ValueSequence::new(vec![vec![1.0, 10.0]])?;
assert!(matches!(
WeightedValueMixing.apply(Product::new(weights, values)),
Err(CtError::ShapeMismatch {
op: "weighted value mixing",
..
})
));
Ok(())
}
#[test]
fn sequence_and_head_dimensions_reject_zero() {
assert!(matches!(
SequenceLength::new(0),
Err(CtError::EmptyInput("sequence length"))
));
assert!(matches!(
HeadDimension::new(0),
Err(CtError::EmptyInput("head dimension"))
));
assert!(matches!(
HeadCount::new(0),
Err(CtError::EmptyInput("head count"))
));
}
#[test]
fn attention_head_outputs_reject_sequence_mismatch() -> CtResult<()> {
let head_a = AttentionOutput::new(vec![vec![1.0, 10.0]])?;
let head_b = AttentionOutput::new(vec![vec![2.0, 20.0], vec![3.0, 30.0]])?;
assert!(matches!(
AttentionHeadOutputs::new(vec![head_a, head_b]),
Err(CtError::ShapeMismatch {
op: "attention head outputs",
..
})
));
Ok(())
}
#[test]
fn attention_head_outputs_reject_head_dimension_mismatch() -> CtResult<()> {
let head_a = AttentionOutput::new(vec![vec![1.0, 10.0]])?;
let head_b = AttentionOutput::new(vec![vec![2.0, 20.0, 200.0]])?;
assert!(matches!(
AttentionHeadOutputs::new(vec![head_a, head_b]),
Err(CtError::ShapeMismatch {
op: "attention head outputs",
..
})
));
Ok(())
}
#[test]
fn attention_output_projection_rejects_input_dimension_mismatch() -> CtResult<()> {
let head_a = AttentionOutput::new(vec![vec![1.0, 10.0]])?;
let head_b = AttentionOutput::new(vec![vec![2.0, 20.0]])?;
let multi_head =
ConcatenateHeads.apply(AttentionHeadOutputs::new(vec![head_a, head_b])?)?;
let projection =
AttentionOutputProjection::new(vec![vec![1.0], vec![1.0], vec![1.0]], vec![0.0])?;
assert!(matches!(
projection.apply(multi_head),
Err(CtError::ShapeMismatch {
op: "attention output projection",
..
})
));
Ok(())
}
#[test]
fn attention_output_projection_rejects_bad_weight_shapes() {
assert!(matches!(
AttentionOutputProjection::new(vec![vec![1.0, 0.0], vec![1.0]], vec![0.0, 0.0]),
Err(CtError::ShapeMismatch {
op: "attention output projection",
..
})
));
}
#[test]
fn attention_output_projection_rejects_non_finite_values() {
assert!(matches!(
AttentionOutputProjection::new(vec![vec![1.0]], vec![f32::NAN]),
Err(CtError::ShapeMismatch {
op: "attention output projection",
..
})
));
assert!(matches!(
AttentionOutputProjection::new(vec![vec![f32::INFINITY]], vec![0.0]),
Err(CtError::ShapeMismatch {
op: "attention output projection",
..
})
));
}
#[test]
fn residual_connection_rejects_sequence_mismatch() -> CtResult<()> {
let hidden = HiddenSequence::new(vec![vec![1.0, 2.0]])?;
let sublayer_output = ProjectedAttentionOutput::new(vec![vec![0.5, 1.5], vec![2.5, 3.5]])?;
assert!(matches!(
ResidualConnection.apply(Product::new(hidden, sublayer_output)),
Err(CtError::ShapeMismatch {
op: "residual connection",
..
})
));
Ok(())
}
#[test]
fn residual_connection_rejects_model_dimension_mismatch() -> CtResult<()> {
let hidden = HiddenSequence::new(vec![vec![1.0, 2.0]])?;
let sublayer_output = ProjectedAttentionOutput::new(vec![vec![0.5, 1.5, 2.5]])?;
assert!(matches!(
ResidualConnection.apply(Product::new(hidden, sublayer_output)),
Err(CtError::ShapeMismatch {
op: "residual connection",
..
})
));
Ok(())
}
#[test]
fn layer_normalization_rejects_model_dimension_mismatch() -> CtResult<()> {
let hidden = HiddenSequence::new(vec![vec![1.0, 2.0, 3.0]])?;
let norm = LayerNormalization::new(LayerNormParameters::identity(ModelDimension::new(2)?));
assert!(matches!(
norm.apply(hidden),
Err(CtError::ShapeMismatch {
op: "layer normalization",
..
})
));
Ok(())
}
#[test]
fn layer_norm_parameters_reject_bad_shapes_and_values() {
assert!(matches!(
LayerNormParameters::new(vec![1.0, 1.0], vec![0.0], NormalizationEpsilon(1e-5)),
Err(CtError::ShapeMismatch {
op: "layer norm parameters",
..
})
));
assert!(matches!(
LayerNormParameters::new(vec![f32::NAN], vec![0.0], NormalizationEpsilon(1e-5)),
Err(CtError::ShapeMismatch {
op: "layer norm parameters",
..
})
));
assert!(matches!(
LayerNormParameters::new(vec![1.0], vec![f32::INFINITY], NormalizationEpsilon(1e-5)),
Err(CtError::ShapeMismatch {
op: "layer norm parameters",
..
})
));
assert!(matches!(
NormalizationEpsilon::new(0.0),
Err(CtError::ShapeMismatch {
op: "normalization epsilon",
..
})
));
}
#[test]
fn position_wise_feed_forward_rejects_input_dimension_mismatch() -> CtResult<()> {
let hidden = HiddenSequence::new(vec![vec![1.0, 2.0, 3.0]])?;
let feed_forward = PositionWiseFeedForward::new(
vec![vec![1.0, 0.0], vec![0.0, 1.0]],
vec![0.0, 0.0],
vec![vec![1.0, 0.0], vec![0.0, 1.0]],
vec![0.0, 0.0],
)?;
assert!(matches!(
feed_forward.apply(hidden),
Err(CtError::ShapeMismatch {
op: "position-wise feed-forward",
..
})
));
Ok(())
}
#[test]
fn position_wise_feed_forward_rejects_incompatible_layer_shapes() {
assert!(matches!(
PositionWiseFeedForward::new(
vec![vec![1.0, 0.0]],
vec![0.0, 0.0],
vec![vec![1.0], vec![1.0], vec![1.0]],
vec![0.0],
),
Err(CtError::ShapeMismatch {
op: "position-wise feed-forward",
..
})
));
assert!(matches!(
PositionWiseFeedForward::new(
vec![vec![1.0, 0.0]],
vec![0.0, 0.0],
vec![vec![1.0, 0.0], vec![1.0, 0.0]],
vec![0.0, 0.0],
),
Err(CtError::ShapeMismatch {
op: "position-wise feed-forward",
..
})
));
}
#[test]
fn position_wise_feed_forward_rejects_non_finite_values() {
assert!(matches!(
PositionWiseFeedForward::new(
vec![vec![f32::NAN]],
vec![0.0],
vec![vec![1.0]],
vec![0.0],
),
Err(CtError::ShapeMismatch {
op: "position-wise feed-forward first layer",
..
})
));
assert!(matches!(
PositionWiseFeedForward::new(
vec![vec![1.0]],
vec![0.0],
vec![vec![1.0]],
vec![f32::INFINITY],
),
Err(CtError::ShapeMismatch {
op: "position-wise feed-forward second layer",
..
})
));
}
#[test]
fn positional_encoding_adds_position_rows_and_preserves_shape() -> CtResult<()> {
let hidden = HiddenSequence::new(vec![vec![1.0, 2.0], vec![3.0, 4.0]])?;
let positions = PositionalEncoding::new(vec![vec![0.1, 0.2], vec![0.3, 0.4]])?;
let output = positions.apply(hidden)?;
assert_eq!(output.sequence_len().value(), 2);
assert_eq!(output.model_dimension().value(), 2);
assert!(crate::domain::approx_eq(
output.rows()[0].as_slice()[0],
1.1,
1e-4
));
assert!(crate::domain::approx_eq(
output.rows()[1].as_slice()[1],
4.4,
1e-4
));
Ok(())
}
#[test]
fn positional_encoding_rejects_model_dimension_mismatch() -> CtResult<()> {
let hidden = HiddenSequence::new(vec![vec![1.0, 2.0]])?;
let positions = PositionalEncoding::new(vec![vec![0.1, 0.2, 0.3]])?;
assert!(matches!(
positions.apply(hidden),
Err(CtError::ShapeMismatch {
op: "positional encoding",
..
})
));
Ok(())
}
#[test]
fn positional_encoding_rejects_sequence_too_long() -> CtResult<()> {
let hidden = HiddenSequence::new(vec![vec![1.0], vec![2.0]])?;
let positions = PositionalEncoding::new(vec![vec![0.1]])?;
assert!(matches!(
positions.apply(hidden),
Err(CtError::ShapeMismatch {
op: "positional encoding",
..
})
));
Ok(())
}
#[test]
fn hidden_to_query_projects_hidden_rows() -> CtResult<()> {
let hidden = HiddenSequence::new(vec![vec![1.0, 2.0], vec![3.0, 4.0]])?;
let projection = HiddenToQuery::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.5, -0.5])?;
let queries = projection.apply(hidden)?;
assert_eq!(queries.sequence_len().value(), 2);
assert_eq!(queries.head_dimension().value(), 2);
assert!(crate::domain::approx_eq(
queries.rows()[0].as_slice()[0],
1.5,
1e-4
));
assert!(crate::domain::approx_eq(
queries.rows()[1].as_slice()[1],
3.5,
1e-4
));
Ok(())
}
#[test]
fn hidden_projection_rejects_input_dimension_mismatch() -> CtResult<()> {
let hidden = HiddenSequence::new(vec![vec![1.0, 2.0, 3.0]])?;
let projection = HiddenToValue::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?;
assert!(matches!(
projection.apply(hidden),
Err(CtError::ShapeMismatch {
op: "hidden-to-value projection",
..
})
));
Ok(())
}
#[test]
fn residual_connection_adds_hidden_sequences() -> CtResult<()> {
let left = HiddenSequence::new(vec![vec![1.0, 2.0]])?;
let right = HiddenSequence::new(vec![vec![3.0, 4.0]])?;
let output = ResidualConnection.apply(Product::new(left, right))?;
assert_eq!(output.rows()[0].as_slice(), &[4.0, 6.0]);
Ok(())
}
#[test]
fn single_head_transformer_block_preserves_hidden_sequence_shape() -> CtResult<()> {
let block = tiny_single_head_block()?;
let hidden = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
let output = block.apply(hidden)?;
assert_eq!(output.sequence_len().value(), 2);
assert_eq!(output.model_dimension().value(), 2);
assert!(
output
.rows()
.iter()
.flat_map(|row| row.as_slice())
.all(|value| value.is_finite())
);
Ok(())
}
#[test]
fn single_head_transformer_block_rejects_constructor_dimension_mismatch() -> CtResult<()> {
let query = HiddenToQuery::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?;
let key = HiddenToKey::new(
vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![0.0, 0.0]],
vec![0.0, 0.0],
)?;
let value = HiddenToValue::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?;
let output_projection =
AttentionOutputProjection::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?;
let model_dimension = ModelDimension::new(2)?;
let attention_norm =
LayerNormalization::new(LayerNormParameters::identity(model_dimension));
let feed_forward = PositionWiseFeedForward::new(
vec![vec![1.0, 0.0], vec![0.0, 1.0]],
vec![0.0, 0.0],
vec![vec![1.0, 0.0], vec![0.0, 1.0]],
vec![0.0, 0.0],
)?;
let feed_forward_norm =
LayerNormalization::new(LayerNormParameters::identity(model_dimension));
assert!(matches!(
SingleHeadTransformerBlock::new(
query,
key,
value,
output_projection,
attention_norm,
feed_forward,
feed_forward_norm,
),
Err(CtError::ShapeMismatch {
op: "single-head block key projection",
..
})
));
Ok(())
}
#[test]
fn single_head_transformer_block_rejects_apply_dimension_mismatch() -> CtResult<()> {
let block = tiny_single_head_block()?;
let hidden = HiddenSequence::new(vec![vec![1.0, 0.0, 0.0]])?;
assert!(matches!(
block.apply(hidden),
Err(CtError::ShapeMismatch {
op: "single-head block",
..
})
));
Ok(())
}
#[test]
fn self_attention_head_rejects_query_key_head_mismatch() -> CtResult<()> {
let query = HiddenToQuery::new(vec![vec![1.0], vec![0.0]], vec![0.0])?;
let key = HiddenToKey::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?;
let value = HiddenToValue::new(vec![vec![1.0], vec![0.0]], vec![0.0])?;
assert!(matches!(
SelfAttentionHead::new(query, key, value),
Err(CtError::ShapeMismatch {
op: "self-attention head",
..
})
));
Ok(())
}
#[test]
fn multi_head_transformer_block_preserves_hidden_sequence_shape() -> CtResult<()> {
let block = tiny_multi_head_block()?;
let hidden = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
let output = block.apply(hidden)?;
assert_eq!(block.head_count().value(), 2);
assert_eq!(block.value_dimension().value(), 1);
assert_eq!(output.sequence_len().value(), 2);
assert_eq!(output.model_dimension().value(), 2);
assert!(
output
.rows()
.iter()
.flat_map(|row| row.as_slice())
.all(|value| value.is_finite())
);
Ok(())
}
#[test]
fn multi_head_transformer_block_rejects_value_dimension_mismatch() -> CtResult<()> {
let head_a = tiny_self_attention_head_first_feature()?;
let head_b = SelfAttentionHead::new(
HiddenToQuery::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
HiddenToKey::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
HiddenToValue::new(vec![vec![0.0, 1.0], vec![1.0, 0.0]], vec![0.0, 0.0])?,
)?;
let model_dimension = ModelDimension::new(2)?;
assert!(matches!(
MultiHeadTransformerBlock::new(
vec![head_a, head_b],
AttentionOutputProjection::new(
vec![vec![1.0, 0.0], vec![0.0, 1.0]],
vec![0.0, 0.0],
)?,
LayerNormalization::new(LayerNormParameters::identity(model_dimension)),
identity_feed_forward()?,
LayerNormalization::new(LayerNormParameters::identity(model_dimension)),
),
Err(CtError::ShapeMismatch {
op: "multi-head block",
..
})
));
Ok(())
}
#[test]
fn multi_head_transformer_block_rejects_output_projection_input_mismatch() -> CtResult<()> {
let model_dimension = ModelDimension::new(2)?;
assert!(matches!(
MultiHeadTransformerBlock::new(
vec![
tiny_self_attention_head_first_feature()?,
tiny_self_attention_head_second_feature()?,
],
AttentionOutputProjection::new(
vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![0.0, 0.0]],
vec![0.0, 0.0],
)?,
LayerNormalization::new(LayerNormParameters::identity(model_dimension)),
identity_feed_forward()?,
LayerNormalization::new(LayerNormParameters::identity(model_dimension)),
),
Err(CtError::ShapeMismatch {
op: "multi-head block output projection input",
..
})
));
Ok(())
}
#[test]
fn multi_head_transformer_block_rejects_apply_dimension_mismatch() -> CtResult<()> {
let block = tiny_multi_head_block()?;
let hidden = HiddenSequence::new(vec![vec![1.0, 0.0, 0.0]])?;
assert!(matches!(
block.apply(hidden),
Err(CtError::ShapeMismatch {
op: "multi-head block",
..
})
));
Ok(())
}
#[test]
fn masked_multi_head_transformer_block_preserves_hidden_sequence_shape() -> CtResult<()> {
let block = tiny_masked_multi_head_block()?;
let hidden = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
let mask = AttentionMask::new(vec![vec![true, false], vec![true, true]])?;
let output = block.apply(Product::new(hidden, mask))?;
assert_eq!(block.head_count().value(), 2);
assert_eq!(output.sequence_len().value(), 2);
assert_eq!(output.model_dimension().value(), 2);
assert!(
output
.rows()
.iter()
.flat_map(|row| row.as_slice())
.all(|value| value.is_finite())
);
Ok(())
}
#[test]
fn masked_multi_head_transformer_block_rejects_mask_shape_mismatch() -> CtResult<()> {
let block = tiny_masked_multi_head_block()?;
let hidden = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
let mask = AttentionMask::new(vec![vec![true, true, true], vec![true, true, true]])?;
assert!(matches!(
block.apply(Product::new(hidden, mask)),
Err(CtError::ShapeMismatch {
op: "masked attention scores",
..
})
));
Ok(())
}
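// Illustrative sketch: the hand-written mask `[[true, false], [true, true]]` in
// the tests above is the two-position case of a causal mask, where position i may
// attend to positions 0..=i. Building the mask from the sequence length keeps it
// square (sequence_len x sequence_len), which is what the "masked attention scores"
// check demands. `sketch_causal_mask_rows` is a hypothetical helper and is not
// part of this crate's `AttentionMask` API.
fn sketch_causal_mask_rows(sequence_len: usize) -> Vec<Vec<bool>> {
    (0..sequence_len)
        .map(|query_position| {
            (0..sequence_len)
                .map(|key_position| key_position <= query_position)
                .collect()
        })
        .collect()
}

#[test]
fn sketch_causal_mask_matches_hand_written_two_position_mask() {
    assert_eq!(
        sketch_causal_mask_rows(2),
        vec![vec![true, false], vec![true, true]]
    );
}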
#[test]
fn transformer_readout_maps_each_hidden_position_to_logits() -> CtResult<()> {
let readout = TransformerReadout::new(
vec![vec![1.0, 0.0, 0.5], vec![0.0, 1.0, -0.5]],
vec![0.0, 0.1, -0.1],
)?;
let hidden = HiddenSequence::new(vec![vec![2.0, 3.0], vec![4.0, 5.0]])?;
let logits = readout.apply(hidden)?;
assert_eq!(logits.sequence_len().value(), 2);
assert_eq!(logits.vocab_size().value(), 3);
assert_eq!(logits.rows()[0].as_slice(), &[2.0, 3.1, -0.6]);
assert_eq!(logits.rows()[1].as_slice(), &[4.0, 5.1, -0.6]);
Ok(())
}
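// Illustrative sketch of the arithmetic the readout test checks. The weight is
// stored as (model_dimension x vocab_size), so each position computes
// logits[v] = bias[v] + sum_d hidden[d] * weight[d][v].
// Example: hidden row [2, 3] gives [2 + 0, 3 + 0.1, 1 - 1.5 - 0.1] = [2.0, 3.1, -0.6].
// `sketch_readout_logits` is hypothetical, and the layout is inferred from the
// test's expected values.
fn sketch_readout_logits(hidden: &[Vec<f32>], weight: &[Vec<f32>], bias: &[f32]) -> Vec<Vec<f32>> {
    hidden
        .iter()
        .map(|row| {
            // Start each position from the bias, then add each feature's weight row.
            let mut logits = bias.to_vec();
            for (feature, weight_row) in row.iter().zip(weight) {
                for (logit, w) in logits.iter_mut().zip(weight_row) {
                    *logit += feature * w;
                }
            }
            logits
        })
        .collect()
}

#[test]
fn sketch_readout_logits_reproduces_test_values() {
    let weight = vec![vec![1.0, 0.0, 0.5], vec![0.0, 1.0, -0.5]];
    let logits = sketch_readout_logits(&[vec![2.0, 3.0]], &weight, &[0.0, 0.1, -0.1]);
    for (got, want) in logits[0].iter().zip([2.0_f32, 3.1, -0.6]) {
        assert!((got - want).abs() < 1e-5);
    }
}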
#[test]
fn tiny_transformer_parameters_forward_maps_hidden_and_mask_to_sequence_logits() -> CtResult<()>
{
let parameters = tiny_transformer_parameters()?;
let hidden = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
let mask = AttentionMask::new(vec![vec![true, false], vec![true, true]])?;
let logits = parameters.apply(Product::new(hidden, mask))?;
assert_eq!(parameters.model_dimension().value(), 2);
assert_eq!(parameters.max_sequence_len().value(), 2);
assert_eq!(logits.sequence_len().value(), 2);
assert_eq!(logits.vocab_size().value(), 3);
assert!(
logits
.rows()
.iter()
.flat_map(|row| row.as_slice())
.all(|value| value.is_finite())
);
Ok(())
}
#[test]
fn tiny_transformer_parameters_rejects_readout_dimension_mismatch() -> CtResult<()> {
let positional_encoding = PositionalEncoding::new(vec![vec![0.1, 0.0], vec![0.0, 0.1]])?;
let block = tiny_masked_multi_head_block()?;
let readout = TransformerReadout::new(vec![vec![1.0], vec![0.0], vec![0.5]], vec![0.0])?;
assert!(matches!(
TinyTransformerParameters::new(positional_encoding, block, readout),
Err(CtError::ShapeMismatch {
op: "tiny transformer parameters readout",
..
})
));
Ok(())
}
#[test]
fn transformer_training_state_records_updated_parameters_and_step_count() -> CtResult<()> {
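let initial_parameters = tiny_transformer_parameters()?;
// Shift one readout bias so the updated parameters differ from the initial ones.
// Without this, the equality check below would pass even if the new parameters were ignored.
let initial_readout = initial_parameters.readout();
let mut shifted_bias = initial_readout.bias().to_vec();
shifted_bias[0] += 1.0;
let updated_parameters = initial_parameters
.clone()
.with_readout(TransformerReadout::new(
initial_readout.weight().to_vec(),
shifted_bias,
)?)?;
assert_ne!(initial_parameters, updated_parameters);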
let state = TransformerTrainingState::new(initial_parameters, LearningRate::new(0.25)?);
let next_state = state.record_updated_parameters(updated_parameters.clone());
assert_eq!(next_state.parameters(), &updated_parameters);
assert_eq!(next_state.learning_rate().value(), 0.25);
assert_eq!(next_state.step_count().value(), 1);
Ok(())
}
#[test]
fn transformer_training_state_forward_uses_structured_parameters() -> CtResult<()> {
let state =
TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.1)?);
let hidden = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
let mask = AttentionMask::new(vec![vec![true, false], vec![true, true]])?;
let logits = state.apply(Product::new(hidden, mask))?;
assert_eq!(logits.sequence_len().value(), 2);
assert_eq!(logits.vocab_size().value(), 3);
assert_eq!(state.step_count().value(), 0);
Ok(())
}
#[test]
fn transformer_readout_training_example_rejects_target_length_mismatch() -> CtResult<()> {
let hidden = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
let mask = AttentionMask::new(vec![vec![true, false], vec![true, true]])?;
let targets = TokenSequence::from_indices([0])?;
assert!(matches!(
TransformerReadoutTrainingExample::new(hidden, mask, targets),
Err(CtError::ShapeMismatch {
op: "transformer readout training targets",
..
})
));
Ok(())
}
#[test]
fn transformer_readout_training_example_rejects_mask_shape_mismatch() -> CtResult<()> {
let hidden = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
let mask = AttentionMask::new(vec![vec![true, true, true], vec![true, true, true]])?;
let targets = TokenSequence::from_indices([0, 1])?;
assert!(matches!(
TransformerReadoutTrainingExample::new(hidden, mask, targets),
Err(CtError::ShapeMismatch {
op: "transformer readout training mask",
..
})
));
Ok(())
}
#[test]
fn transformer_readout_train_step_reduces_sequence_loss() -> CtResult<()> {
let dataset = tiny_transformer_training_set()?;
let state =
TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.5)?);
let before = transformer_readout_average_loss(&state, &dataset)?;
let train_step = TransformerReadoutTrainStep::new(dataset.clone());
let trained =
crate::category::apply_endomorphism_n_times(&train_step, state, StepCount::new(40))?;
let after = transformer_readout_average_loss(&trained, &dataset)?;
assert!(after.value() < before.value());
assert_eq!(trained.step_count().value(), 40);
Ok(())
}
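// Illustrative sketch of an average sequence loss: per-position cross-entropy,
// -log softmax(logits)[target], averaged over positions. A target index at or
// beyond the vocabulary size is rejected, which mirrors the
// `OutOfRange { kind: "sequence target", .. }` case tested nearby. This helper
// is hypothetical. It returns `String` errors instead of `CtError` so that it
// stands alone, and it is not claimed to be how
// `transformer_readout_average_loss` is written.
fn sketch_average_sequence_cross_entropy(logits: &[Vec<f32>], targets: &[usize]) -> Result<f32, String> {
    if logits.is_empty() || logits.len() != targets.len() {
        return Err(format!("{} logit rows for {} targets", logits.len(), targets.len()));
    }
    let mut total = 0.0_f32;
    for (row, &target) in logits.iter().zip(targets) {
        if target >= row.len() {
            return Err(format!("target {target} outside vocabulary of size {}", row.len()));
        }
        // log-sum-exp with the row maximum subtracted, for numerical stability.
        let max_logit = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let log_sum_exp = max_logit + row.iter().map(|l| (l - max_logit).exp()).sum::<f32>().ln();
        total += log_sum_exp - row[target];
    }
    Ok(total / targets.len() as f32)
}

#[test]
fn sketch_cross_entropy_of_uniform_logits_is_log_vocab_size() {
    let loss = sketch_average_sequence_cross_entropy(&[vec![0.0, 0.0, 0.0]], &[1]).unwrap();
    assert!((loss - 3.0_f32.ln()).abs() < 1e-5);
    assert!(sketch_average_sequence_cross_entropy(&[vec![0.0, 0.0, 0.0]], &[9]).is_err());
}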
#[test]
fn transformer_readout_train_step_rejects_target_outside_vocabulary() -> CtResult<()> {
let example = TransformerReadoutTrainingExample::new(
HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?,
AttentionMask::new(vec![vec![true, false], vec![true, true]])?,
TokenSequence::from_indices([0, 9])?,
)?;
let dataset = TransformerReadoutTrainingSet::new([example])?;
let state =
TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.1)?);
let train_step = TransformerReadoutTrainStep::new(dataset);
assert!(matches!(
train_step.apply(state),
Err(CtError::OutOfRange {
kind: "sequence target",
index: 9,
limit: 3,
})
));
Ok(())
}
#[test]
fn transformer_feed_forward_training_example_rejects_target_shape_mismatch() -> CtResult<()> {
let input = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
let target = HiddenSequence::new(vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]])?;
assert!(matches!(
TransformerFeedForwardTrainingExample::new(input, target),
Err(CtError::ShapeMismatch {
op: "transformer feed-forward training dimension",
..
})
));
Ok(())
}
#[test]
fn transformer_feed_forward_training_set_rejects_empty_input() {
assert!(matches!(
TransformerFeedForwardTrainingSet::new([]),
Err(CtError::EmptyInput("transformer feed-forward training set"))
));
}
#[test]
fn transformer_feed_forward_train_step_reduces_local_hidden_loss() -> CtResult<()> {
let dataset = tiny_feed_forward_training_set()?;
let state =
TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
let before = transformer_feed_forward_average_loss(&state, &dataset)?;
let train_step = TransformerFeedForwardTrainStep::new(dataset.clone());
let trained =
crate::category::apply_endomorphism_n_times(&train_step, state, StepCount::new(60))?;
let after = transformer_feed_forward_average_loss(&trained, &dataset)?;
assert!(after.value() < before.value());
assert_eq!(trained.step_count().value(), 60);
Ok(())
}
#[test]
fn transformer_block_training_example_rejects_mask_shape_mismatch() -> CtResult<()> {
let hidden = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
let mask = AttentionMask::new(vec![vec![true, true, true], vec![true, true, true]])?;
let targets = TokenSequence::from_indices([0, 1])?;
assert!(matches!(
TransformerBlockTrainingExample::new(hidden, mask, targets),
Err(CtError::ShapeMismatch {
op: "transformer block training mask",
..
})
));
Ok(())
}
#[test]
fn transformer_block_training_set_rejects_empty_input() {
assert!(matches!(
TransformerBlockTrainingSet::new([]),
Err(CtError::EmptyInput("transformer block training set"))
));
}
#[test]
fn transformer_block_train_step_rejects_target_outside_vocabulary() -> CtResult<()> {
let example = TransformerBlockTrainingExample::new(
HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?,
AttentionMask::new(vec![vec![true, false], vec![true, true]])?,
TokenSequence::from_indices([0, 9])?,
)?;
let dataset = TransformerBlockTrainingSet::new([example])?;
let state =
TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.1)?);
let train_step = TransformerBlockTrainStep::new(dataset);
assert!(matches!(
train_step.apply(state),
Err(CtError::OutOfRange {
kind: "sequence target",
index: 9,
limit: 3,
})
));
Ok(())
}
#[test]
fn transformer_block_train_step_reduces_sequence_loss() -> CtResult<()> {
let dataset = tiny_transformer_block_training_set()?;
let state =
TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
let before = transformer_block_average_loss(&state, &dataset)?;
let train_step = TransformerBlockTrainStep::new(dataset.clone());
let trained =
crate::category::apply_endomorphism_n_times(&train_step, state, StepCount::new(40))?;
let after = transformer_block_average_loss(&trained, &dataset)?;
assert!(after.value() < before.value());
assert_eq!(trained.step_count().value(), 40);
Ok(())
}
#[test]
fn transformer_block_train_step_updates_attention_output_projection() -> CtResult<()> {
let dataset = tiny_transformer_block_training_set()?;
let state =
TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
let before = state.parameters().output_projection().weight().to_vec();
let train_step = TransformerBlockTrainStep::new(dataset);
let trained = train_step.apply(state)?;
let after = trained.parameters().output_projection().weight().to_vec();
assert_ne!(before, after);
assert_eq!(trained.step_count().value(), 1);
Ok(())
}
#[test]
fn transformer_block_train_step_updates_layer_norm_parameters() -> CtResult<()> {
let dataset = tiny_transformer_block_training_set()?;
let state =
TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
let before_attention_scale = state
.parameters()
.attention_norm()
.parameters()
.scale()
.to_vec();
let before_feed_forward_shift = state
.parameters()
.feed_forward_norm()
.parameters()
.shift()
.to_vec();
let train_step = TransformerBlockTrainStep::new(dataset);
let trained = train_step.apply(state)?;
let after_attention_scale = trained
.parameters()
.attention_norm()
.parameters()
.scale()
.to_vec();
let after_feed_forward_shift = trained
.parameters()
.feed_forward_norm()
.parameters()
.shift()
.to_vec();
assert_ne!(before_attention_scale, after_attention_scale);
assert_ne!(before_feed_forward_shift, after_feed_forward_shift);
assert_eq!(trained.step_count().value(), 1);
Ok(())
}
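// Illustrative sketch of layer normalization for one position:
// (x - mean) / sqrt(variance + epsilon) * scale + shift, computed over features.
// `scale` and `shift` enter linearly. The output's derivative with respect to
// shift is 1, and with respect to scale it is the normalized input. So a nonzero
// upstream gradient moves both, which is what the test above observes. Assumption:
// the variance is biased (divided by n), as in common LayerNorm implementations.
// This crate's exact choice is not shown here, and `sketch_layer_norm_row` is
// hypothetical.
fn sketch_layer_norm_row(row: &[f32], scale: &[f32], shift: &[f32], epsilon: f32) -> Vec<f32> {
    let n = row.len() as f32;
    let mean = row.iter().sum::<f32>() / n;
    let variance = row.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / n;
    let inv_std = (variance + epsilon).sqrt().recip();
    row.iter()
        .zip(scale)
        .zip(shift)
        .map(|((x, g), b)| (x - mean) * inv_std * g + b)
        .collect()
}

#[test]
fn sketch_layer_norm_identity_parameters_standardize_row() {
    let out = sketch_layer_norm_row(&[1.0, 3.0], &[1.0, 1.0], &[0.0, 0.0], 0.0);
    assert!((out[0] + 1.0).abs() < 1e-6 && (out[1] - 1.0).abs() < 1e-6);
}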
#[test]
fn transformer_block_train_step_updates_query_key_value_projections() -> CtResult<()> {
let dataset = tiny_transformer_block_training_set()?;
let state =
TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
let before_query = state
.parameters()
.attention_heads()
.iter()
.map(|head| head.query_projection().weight().to_vec())
.collect::<Vec<_>>();
let before_key = state
.parameters()
.attention_heads()
.iter()
.map(|head| head.key_projection().weight().to_vec())
.collect::<Vec<_>>();
let before_value = state
.parameters()
.attention_heads()
.iter()
.map(|head| head.value_projection().weight().to_vec())
.collect::<Vec<_>>();
let train_step = TransformerBlockTrainStep::new(dataset);
let trained = train_step.apply(state)?;
let after_query = trained
.parameters()
.attention_heads()
.iter()
.map(|head| head.query_projection().weight().to_vec())
.collect::<Vec<_>>();
let after_key = trained
.parameters()
.attention_heads()
.iter()
.map(|head| head.key_projection().weight().to_vec())
.collect::<Vec<_>>();
let after_value = trained
.parameters()
.attention_heads()
.iter()
.map(|head| head.value_projection().weight().to_vec())
.collect::<Vec<_>>();
assert_ne!(before_query, after_query);
assert_ne!(before_key, after_key);
assert_ne!(before_value, after_value);
assert_eq!(trained.step_count().value(), 1);
Ok(())
}
#[test]
fn transformer_block_train_step_matches_finite_difference_for_attention_projection()
-> CtResult<()> {
let dataset = tiny_transformer_block_training_set()?;
let state =
TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
let train_step = TransformerBlockTrainStep::new(dataset.clone());
let trained = train_step.apply(state.clone())?;
let selection = largest_changed_attention_projection(&state, &trained)?;
let before_value = attention_projection_weight(&state, selection)?;
let after_value = attention_projection_weight(&trained, selection)?;
assert_block_gradient_matches_finite_difference(
&state,
&dataset,
before_value,
after_value,
|value| state_with_attention_projection_weight(&state, selection, value),
)
}
#[test]
fn transformer_block_train_step_matches_finite_difference_for_readout_weight() -> CtResult<()> {
let dataset = tiny_transformer_block_training_set()?;
let state =
TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
let train_step = TransformerBlockTrainStep::new(dataset.clone());
let trained = train_step.apply(state.clone())?;
let selection = largest_changed_readout_weight(&state, &trained)?;
let before_value = readout_weight(&state, selection);
let after_value = readout_weight(&trained, selection);
assert_block_gradient_matches_finite_difference(
&state,
&dataset,
before_value,
after_value,
|value| state_with_readout_weight(&state, selection, value),
)
}
#[test]
fn transformer_block_train_step_matches_finite_difference_for_feed_forward_weight()
-> CtResult<()> {
let dataset = tiny_transformer_block_training_set()?;
let state =
TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
let train_step = TransformerBlockTrainStep::new(dataset.clone());
let trained = train_step.apply(state.clone())?;
let selection = largest_changed_feed_forward_weight(&state, &trained)?;
let before_value = feed_forward_weight(&state, selection);
let after_value = feed_forward_weight(&trained, selection);
assert_block_gradient_matches_finite_difference(
&state,
&dataset,
before_value,
after_value,
|value| state_with_feed_forward_weight(&state, selection, value),
)
}
#[test]
fn transformer_block_train_step_matches_finite_difference_for_layer_norm_parameter()
-> CtResult<()> {
let dataset = tiny_transformer_block_training_set()?;
let state =
TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
let train_step = TransformerBlockTrainStep::new(dataset.clone());
let trained = train_step.apply(state.clone())?;
let selection = largest_changed_layer_norm_parameter(&state, &trained)?;
let before_value = layer_norm_parameter_value(&state, selection);
let after_value = layer_norm_parameter_value(&trained, selection);
assert_block_gradient_matches_finite_difference(
&state,
&dataset,
before_value,
after_value,
|value| state_with_layer_norm_parameter(&state, selection, value),
)
}
#[test]
fn transformer_block_train_step_matches_finite_difference_for_readout_bias() -> CtResult<()> {
let dataset = tiny_transformer_block_training_set()?;
let state =
TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
let train_step = TransformerBlockTrainStep::new(dataset.clone());
let trained = train_step.apply(state.clone())?;
let selection = largest_changed_vector_index(
state.parameters().readout().bias(),
trained.parameters().readout().bias(),
"changed readout bias",
)?;
let before_value = state.parameters().readout().bias()[selection];
let after_value = trained.parameters().readout().bias()[selection];
assert_block_gradient_matches_finite_difference(
&state,
&dataset,
before_value,
after_value,
|value| state_with_readout_bias(&state, selection, value),
)
}
#[test]
fn transformer_block_train_step_matches_finite_difference_for_feed_forward_bias() -> CtResult<()>
{
let dataset = tiny_transformer_block_training_set()?;
let state =
TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
let train_step = TransformerBlockTrainStep::new(dataset.clone());
let trained = train_step.apply(state.clone())?;
let selection = largest_changed_feed_forward_bias(&state, &trained)?;
let before_value = feed_forward_bias(&state, selection);
let after_value = feed_forward_bias(&trained, selection);
assert_block_gradient_matches_finite_difference(
&state,
&dataset,
before_value,
after_value,
|value| state_with_feed_forward_bias(&state, selection, value),
)
}
#[test]
fn transformer_block_train_step_matches_finite_difference_for_output_projection_bias()
-> CtResult<()> {
let dataset = tiny_transformer_block_training_set()?;
let state =
TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
let train_step = TransformerBlockTrainStep::new(dataset.clone());
let trained = train_step.apply(state.clone())?;
let selection = largest_changed_vector_index(
state.parameters().output_projection().bias(),
trained.parameters().output_projection().bias(),
"changed attention output projection bias",
)?;
let before_value = state.parameters().output_projection().bias()[selection];
let after_value = trained.parameters().output_projection().bias()[selection];
assert_block_gradient_matches_finite_difference(
&state,
&dataset,
before_value,
after_value,
|value| state_with_output_projection_bias(&state, selection, value),
)
}
#[test]
fn transformer_block_train_step_matches_finite_difference_for_attention_projection_bias()
-> CtResult<()> {
let dataset = tiny_transformer_block_training_set()?;
let state =
TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
let train_step = TransformerBlockTrainStep::new(dataset.clone());
let trained = train_step.apply(state.clone())?;
let selection = largest_changed_attention_projection_bias(&state, &trained)?;
let before_value = attention_projection_bias(&state, selection)?;
let after_value = attention_projection_bias(&trained, selection)?;
assert_block_gradient_matches_finite_difference(
&state,
&dataset,
before_value,
after_value,
|value| state_with_attention_projection_bias(&state, selection, value),
)
}
fn assert_block_gradient_matches_finite_difference(
state: &TransformerTrainingState,
dataset: &TransformerBlockTrainingSet,
before_value: f32,
after_value: f32,
mut state_with_value: impl FnMut(f32) -> CtResult<TransformerTrainingState>,
) -> CtResult<()> {
let inferred_gradient = (before_value - after_value) / state.learning_rate().value();
let epsilon = 1e-3;
let loss_plus =
transformer_block_average_loss(&state_with_value(before_value + epsilon)?, dataset)?
.value();
let loss_minus =
transformer_block_average_loss(&state_with_value(before_value - epsilon)?, dataset)?
.value();
let finite_difference = (loss_plus - loss_minus) / (2.0 * epsilon);
assert!(
(inferred_gradient - finite_difference).abs() < 1e-2,
"inferred gradient {inferred_gradient} should match finite difference {finite_difference}"
);
Ok(())
}
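// Illustrative scalar version of the check above. The test assumes plain SGD,
// after = before - learning_rate * gradient, so the gradient can be read back as
// (before - after) / learning_rate. It is then compared with a central finite
// difference, (f(w + eps) - f(w - eps)) / (2 eps), whose truncation error is
// O(eps^2). With f32, eps cannot be much smaller than about 1e-3 before
// cancellation dominates. Here f(w) = (w - 2)^2 has the known derivative 2(w - 2).
// Both `sketch_*` helpers are hypothetical.
fn sketch_central_difference(f: impl Fn(f32) -> f32, w: f32, epsilon: f32) -> f32 {
    (f(w + epsilon) - f(w - epsilon)) / (2.0 * epsilon)
}

fn sketch_inferred_sgd_gradient(before: f32, after: f32, learning_rate: f32) -> f32 {
    (before - after) / learning_rate
}

#[test]
fn sketch_sgd_step_gradient_matches_central_difference() {
    let loss = |w: f32| (w - 2.0) * (w - 2.0);
    let (before, learning_rate) = (3.0_f32, 0.2_f32);
    // One SGD step using the analytic gradient 2(w - 2).
    let after = before - learning_rate * 2.0 * (before - 2.0);
    let inferred = sketch_inferred_sgd_gradient(before, after, learning_rate);
    let finite_difference = sketch_central_difference(loss, before, 1e-3);
    assert!((inferred - finite_difference).abs() < 1e-2);
}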
fn largest_changed_vector_index(
before: &[f32],
after: &[f32],
label: &'static str,
) -> CtResult<usize> {
let mut selected = None;
let mut largest_delta = 0.0;
for (index, (before_value, after_value)) in before.iter().zip(after).enumerate() {
let delta = (before_value - after_value).abs();
if delta > largest_delta {
largest_delta = delta;
selected = Some(index);
}
}
selected.ok_or(CtError::EmptyInput(label))
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct MatrixSelection {
input_index: usize,
output_index: usize,
}
fn largest_changed_matrix_weight(
before: &[Vec<f32>],
after: &[Vec<f32>],
label: &'static str,
) -> CtResult<MatrixSelection> {
let mut selected = None;
let mut largest_delta = 0.0;
for (input_index, (before_row, after_row)) in before.iter().zip(after).enumerate() {
for (output_index, (before_value, after_value)) in
before_row.iter().zip(after_row).enumerate()
{
let delta = (before_value - after_value).abs();
if delta > largest_delta {
largest_delta = delta;
selected = Some(MatrixSelection {
input_index,
output_index,
});
}
}
}
selected.ok_or(CtError::EmptyInput(label))
}
fn largest_changed_readout_weight(
before: &TransformerTrainingState,
after: &TransformerTrainingState,
) -> CtResult<MatrixSelection> {
largest_changed_matrix_weight(
before.parameters().readout().weight(),
after.parameters().readout().weight(),
"changed readout weight",
)
}
fn readout_weight(state: &TransformerTrainingState, selection: MatrixSelection) -> f32 {
state.parameters().readout().weight()[selection.input_index][selection.output_index]
}
fn state_with_readout_weight(
state: &TransformerTrainingState,
selection: MatrixSelection,
value: f32,
) -> CtResult<TransformerTrainingState> {
let readout = state.parameters().readout();
let mut weight = readout.weight().to_vec();
weight[selection.input_index][selection.output_index] = value;
let readout = TransformerReadout::new(weight, readout.bias().to_vec())?;
let parameters = state.parameters().clone().with_readout(readout)?;
Ok(TransformerTrainingState::from_parts(
parameters,
state.learning_rate(),
state.step_count(),
))
}
fn state_with_readout_bias(
state: &TransformerTrainingState,
selection: usize,
value: f32,
) -> CtResult<TransformerTrainingState> {
let readout = state.parameters().readout();
let mut bias = readout.bias().to_vec();
bias[selection] = value;
let readout = TransformerReadout::new(readout.weight().to_vec(), bias)?;
let parameters = state.parameters().clone().with_readout(readout)?;
Ok(TransformerTrainingState::from_parts(
parameters,
state.learning_rate(),
state.step_count(),
))
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum FeedForwardWeightKind {
First,
Second,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct FeedForwardWeightSelection {
kind: FeedForwardWeightKind,
matrix: MatrixSelection,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct FeedForwardBiasSelection {
kind: FeedForwardWeightKind,
index: usize,
}
fn largest_changed_feed_forward_weight(
before: &TransformerTrainingState,
after: &TransformerTrainingState,
) -> CtResult<FeedForwardWeightSelection> {
let first = largest_changed_matrix_weight(
before.parameters().feed_forward().first_weight(),
after.parameters().feed_forward().first_weight(),
"changed first feed-forward weight",
);
let second = largest_changed_matrix_weight(
before.parameters().feed_forward().second_weight(),
after.parameters().feed_forward().second_weight(),
"changed second feed-forward weight",
);
match (first, second) {
(Ok(first), Ok(second)) => {
let first_delta = feed_forward_weight_delta(
before,
after,
FeedForwardWeightSelection {
kind: FeedForwardWeightKind::First,
matrix: first,
},
);
let second_delta = feed_forward_weight_delta(
before,
after,
FeedForwardWeightSelection {
kind: FeedForwardWeightKind::Second,
matrix: second,
},
);
if first_delta >= second_delta {
Ok(FeedForwardWeightSelection {
kind: FeedForwardWeightKind::First,
matrix: first,
})
} else {
Ok(FeedForwardWeightSelection {
kind: FeedForwardWeightKind::Second,
matrix: second,
})
}
}
(Ok(first), Err(_)) => Ok(FeedForwardWeightSelection {
kind: FeedForwardWeightKind::First,
matrix: first,
}),
(Err(_), Ok(second)) => Ok(FeedForwardWeightSelection {
kind: FeedForwardWeightKind::Second,
matrix: second,
}),
(Err(_), Err(_)) => Err(CtError::EmptyInput("changed feed-forward weight")),
}
}
fn feed_forward_weight_delta(
before: &TransformerTrainingState,
after: &TransformerTrainingState,
selection: FeedForwardWeightSelection,
) -> f32 {
(feed_forward_weight(before, selection) - feed_forward_weight(after, selection)).abs()
}
fn feed_forward_weight(
state: &TransformerTrainingState,
selection: FeedForwardWeightSelection,
) -> f32 {
let feed_forward = state.parameters().feed_forward();
match selection.kind {
FeedForwardWeightKind::First => feed_forward.first_weight()
[selection.matrix.input_index][selection.matrix.output_index],
FeedForwardWeightKind::Second => feed_forward.second_weight()
[selection.matrix.input_index][selection.matrix.output_index],
}
}
fn state_with_feed_forward_weight(
state: &TransformerTrainingState,
selection: FeedForwardWeightSelection,
value: f32,
) -> CtResult<TransformerTrainingState> {
let feed_forward = state.parameters().feed_forward();
let mut first_weight = feed_forward.first_weight().to_vec();
let mut second_weight = feed_forward.second_weight().to_vec();
match selection.kind {
FeedForwardWeightKind::First => {
first_weight[selection.matrix.input_index][selection.matrix.output_index] = value;
}
FeedForwardWeightKind::Second => {
second_weight[selection.matrix.input_index][selection.matrix.output_index] = value;
}
}
let feed_forward = PositionWiseFeedForward::new(
first_weight,
feed_forward.first_bias().to_vec(),
second_weight,
feed_forward.second_bias().to_vec(),
)?;
let parameters = state.parameters().clone().with_feed_forward(feed_forward)?;
Ok(TransformerTrainingState::from_parts(
parameters,
state.learning_rate(),
state.step_count(),
))
}
fn largest_changed_feed_forward_bias(
before: &TransformerTrainingState,
after: &TransformerTrainingState,
) -> CtResult<FeedForwardBiasSelection> {
let first = largest_changed_vector_index(
before.parameters().feed_forward().first_bias(),
after.parameters().feed_forward().first_bias(),
"changed first feed-forward bias",
);
let second = largest_changed_vector_index(
before.parameters().feed_forward().second_bias(),
after.parameters().feed_forward().second_bias(),
"changed second feed-forward bias",
);
match (first, second) {
(Ok(first), Ok(second)) => {
let first_delta = feed_forward_bias_delta(
before,
after,
FeedForwardBiasSelection {
kind: FeedForwardWeightKind::First,
index: first,
},
);
let second_delta = feed_forward_bias_delta(
before,
after,
FeedForwardBiasSelection {
kind: FeedForwardWeightKind::Second,
index: second,
},
);
if first_delta >= second_delta {
Ok(FeedForwardBiasSelection {
kind: FeedForwardWeightKind::First,
index: first,
})
} else {
Ok(FeedForwardBiasSelection {
kind: FeedForwardWeightKind::Second,
index: second,
})
}
}
(Ok(first), Err(_)) => Ok(FeedForwardBiasSelection {
kind: FeedForwardWeightKind::First,
index: first,
}),
(Err(_), Ok(second)) => Ok(FeedForwardBiasSelection {
kind: FeedForwardWeightKind::Second,
index: second,
}),
(Err(_), Err(_)) => Err(CtError::EmptyInput("changed feed-forward bias")),
}
}
fn feed_forward_bias_delta(
before: &TransformerTrainingState,
after: &TransformerTrainingState,
selection: FeedForwardBiasSelection,
) -> f32 {
(feed_forward_bias(before, selection) - feed_forward_bias(after, selection)).abs()
}
fn feed_forward_bias(
state: &TransformerTrainingState,
selection: FeedForwardBiasSelection,
) -> f32 {
let feed_forward = state.parameters().feed_forward();
match selection.kind {
FeedForwardWeightKind::First => feed_forward.first_bias()[selection.index],
FeedForwardWeightKind::Second => feed_forward.second_bias()[selection.index],
}
}
fn state_with_feed_forward_bias(
state: &TransformerTrainingState,
selection: FeedForwardBiasSelection,
value: f32,
) -> CtResult<TransformerTrainingState> {
let feed_forward = state.parameters().feed_forward();
let mut first_bias = feed_forward.first_bias().to_vec();
let mut second_bias = feed_forward.second_bias().to_vec();
match selection.kind {
FeedForwardWeightKind::First => {
first_bias[selection.index] = value;
}
FeedForwardWeightKind::Second => {
second_bias[selection.index] = value;
}
}
let feed_forward = PositionWiseFeedForward::new(
feed_forward.first_weight().to_vec(),
first_bias,
feed_forward.second_weight().to_vec(),
second_bias,
)?;
let parameters = state.parameters().clone().with_feed_forward(feed_forward)?;
Ok(TransformerTrainingState::from_parts(
parameters,
state.learning_rate(),
state.step_count(),
))
}
fn state_with_output_projection_bias(
state: &TransformerTrainingState,
selection: usize,
value: f32,
) -> CtResult<TransformerTrainingState> {
let output_projection = state.parameters().output_projection();
let mut bias = output_projection.bias().to_vec();
bias[selection] = value;
let output_projection =
AttentionOutputProjection::new(output_projection.weight().to_vec(), bias)?;
let parameters = state
.parameters()
.clone()
.with_output_projection(output_projection)?;
Ok(TransformerTrainingState::from_parts(
parameters,
state.learning_rate(),
state.step_count(),
))
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum LayerNormParameterKind {
AttentionScale,
AttentionShift,
FeedForwardScale,
FeedForwardShift,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct LayerNormParameterSelection {
kind: LayerNormParameterKind,
feature_index: usize,
}
fn largest_changed_layer_norm_parameter(
before: &TransformerTrainingState,
after: &TransformerTrainingState,
) -> CtResult<LayerNormParameterSelection> {
let mut selected = None;
let mut largest_delta = 0.0;
for kind in [
LayerNormParameterKind::AttentionScale,
LayerNormParameterKind::AttentionShift,
LayerNormParameterKind::FeedForwardScale,
LayerNormParameterKind::FeedForwardShift,
] {
let before_values = layer_norm_parameter_values(before, kind);
let after_values = layer_norm_parameter_values(after, kind);
for (feature_index, (before_value, after_value)) in
before_values.iter().zip(after_values).enumerate()
{
let delta = (before_value - after_value).abs();
if delta > largest_delta {
largest_delta = delta;
selected = Some(LayerNormParameterSelection {
kind,
feature_index,
});
}
}
}
selected.ok_or(CtError::EmptyInput("changed layer norm parameter"))
}
fn layer_norm_parameter_values(
state: &TransformerTrainingState,
kind: LayerNormParameterKind,
) -> &[f32] {
match kind {
LayerNormParameterKind::AttentionScale => {
state.parameters().attention_norm().parameters().scale()
}
LayerNormParameterKind::AttentionShift => {
state.parameters().attention_norm().parameters().shift()
}
LayerNormParameterKind::FeedForwardScale => {
state.parameters().feed_forward_norm().parameters().scale()
}
LayerNormParameterKind::FeedForwardShift => {
state.parameters().feed_forward_norm().parameters().shift()
}
}
}
fn layer_norm_parameter_value(
state: &TransformerTrainingState,
selection: LayerNormParameterSelection,
) -> f32 {
layer_norm_parameter_values(state, selection.kind)[selection.feature_index]
}
fn state_with_layer_norm_parameter(
state: &TransformerTrainingState,
selection: LayerNormParameterSelection,
value: f32,
) -> CtResult<TransformerTrainingState> {
let attention_parameters = state.parameters().attention_norm().parameters();
let feed_forward_parameters = state.parameters().feed_forward_norm().parameters();
let mut attention_scale = attention_parameters.scale().to_vec();
let mut attention_shift = attention_parameters.shift().to_vec();
let mut feed_forward_scale = feed_forward_parameters.scale().to_vec();
let mut feed_forward_shift = feed_forward_parameters.shift().to_vec();
match selection.kind {
LayerNormParameterKind::AttentionScale => {
attention_scale[selection.feature_index] = value;
}
LayerNormParameterKind::AttentionShift => {
attention_shift[selection.feature_index] = value;
}
LayerNormParameterKind::FeedForwardScale => {
feed_forward_scale[selection.feature_index] = value;
}
LayerNormParameterKind::FeedForwardShift => {
feed_forward_shift[selection.feature_index] = value;
}
}
let attention_norm = LayerNormalization::new(LayerNormParameters::new(
attention_scale,
attention_shift,
attention_parameters.epsilon(),
)?);
let feed_forward_norm = LayerNormalization::new(LayerNormParameters::new(
feed_forward_scale,
feed_forward_shift,
feed_forward_parameters.epsilon(),
)?);
let parameters = state
.parameters()
.clone()
.with_layer_norms(attention_norm, feed_forward_norm)?;
Ok(TransformerTrainingState::from_parts(
parameters,
state.learning_rate(),
state.step_count(),
))
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum AttentionProjectionKind {
Query,
Key,
Value,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct AttentionProjectionSelection {
head_index: usize,
kind: AttentionProjectionKind,
input_index: usize,
output_index: usize,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct AttentionProjectionBiasSelection {
head_index: usize,
kind: AttentionProjectionKind,
output_index: usize,
}
fn largest_changed_attention_projection(
before: &TransformerTrainingState,
after: &TransformerTrainingState,
) -> CtResult<AttentionProjectionSelection> {
let head_count = before.parameters().attention_heads().len();
let mut selected = None;
let mut largest_delta = 0.0;
for head_index in 0..head_count {
for kind in [
AttentionProjectionKind::Query,
AttentionProjectionKind::Key,
AttentionProjectionKind::Value,
] {
let before_weight = attention_projection_weight_matrix(before, head_index, kind)?;
let after_weight = attention_projection_weight_matrix(after, head_index, kind)?;
for (input_index, (before_row, after_row)) in
before_weight.iter().zip(after_weight).enumerate()
{
for (output_index, (before_value, after_value)) in
before_row.iter().zip(after_row).enumerate()
{
let delta = (before_value - after_value).abs();
if delta > largest_delta {
largest_delta = delta;
selected = Some(AttentionProjectionSelection {
head_index,
kind,
input_index,
output_index,
});
}
}
}
}
}
selected.ok_or(CtError::EmptyInput("changed attention projection"))
}
fn attention_projection_weight(
state: &TransformerTrainingState,
selection: AttentionProjectionSelection,
) -> CtResult<f32> {
let weight =
attention_projection_weight_matrix(state, selection.head_index, selection.kind)?;
Ok(weight[selection.input_index][selection.output_index])
}
fn attention_projection_weight_matrix(
state: &TransformerTrainingState,
head_index: usize,
kind: AttentionProjectionKind,
) -> CtResult<&[Vec<f32>]> {
let head =
state
.parameters()
.attention_heads()
.get(head_index)
.ok_or(CtError::OutOfRange {
kind: "attention head",
index: head_index,
limit: state.parameters().attention_heads().len(),
})?;
Ok(match kind {
AttentionProjectionKind::Query => head.query_projection().weight(),
AttentionProjectionKind::Key => head.key_projection().weight(),
AttentionProjectionKind::Value => head.value_projection().weight(),
})
}
fn largest_changed_attention_projection_bias(
before: &TransformerTrainingState,
after: &TransformerTrainingState,
) -> CtResult<AttentionProjectionBiasSelection> {
let head_count = before.parameters().attention_heads().len();
let mut selected = None;
let mut largest_delta = 0.0;
for head_index in 0..head_count {
for kind in [
AttentionProjectionKind::Query,
AttentionProjectionKind::Key,
AttentionProjectionKind::Value,
] {
let before_bias = attention_projection_bias_values(before, head_index, kind)?;
let after_bias = attention_projection_bias_values(after, head_index, kind)?;
for (output_index, (before_value, after_value)) in
before_bias.iter().zip(after_bias).enumerate()
{
let delta = (before_value - after_value).abs();
if delta > largest_delta {
largest_delta = delta;
selected = Some(AttentionProjectionBiasSelection {
head_index,
kind,
output_index,
});
}
}
}
}
selected.ok_or(CtError::EmptyInput("changed attention projection bias"))
}
fn attention_projection_bias(
state: &TransformerTrainingState,
selection: AttentionProjectionBiasSelection,
) -> CtResult<f32> {
let bias = attention_projection_bias_values(state, selection.head_index, selection.kind)?;
Ok(bias[selection.output_index])
}
fn attention_projection_bias_values(
state: &TransformerTrainingState,
head_index: usize,
kind: AttentionProjectionKind,
) -> CtResult<&[f32]> {
let head =
state
.parameters()
.attention_heads()
.get(head_index)
.ok_or(CtError::OutOfRange {
kind: "attention head",
index: head_index,
limit: state.parameters().attention_heads().len(),
})?;
Ok(match kind {
AttentionProjectionKind::Query => head.query_projection().bias(),
AttentionProjectionKind::Key => head.key_projection().bias(),
AttentionProjectionKind::Value => head.value_projection().bias(),
})
}
fn state_with_attention_projection_weight(
state: &TransformerTrainingState,
selection: AttentionProjectionSelection,
value: f32,
) -> CtResult<TransformerTrainingState> {
let mut heads = state.parameters().attention_heads().to_vec();
let head = heads.get(selection.head_index).ok_or(CtError::OutOfRange {
kind: "attention head",
index: selection.head_index,
limit: heads.len(),
})?;
let mut query_weight = head.query_projection().weight().to_vec();
let mut key_weight = head.key_projection().weight().to_vec();
let mut value_weight = head.value_projection().weight().to_vec();
match selection.kind {
AttentionProjectionKind::Query => {
query_weight[selection.input_index][selection.output_index] = value;
}
AttentionProjectionKind::Key => {
key_weight[selection.input_index][selection.output_index] = value;
}
AttentionProjectionKind::Value => {
value_weight[selection.input_index][selection.output_index] = value;
}
}
heads[selection.head_index] = SelfAttentionHead::new(
HiddenToQuery::new(query_weight, head.query_projection().bias().to_vec())?,
HiddenToKey::new(key_weight, head.key_projection().bias().to_vec())?,
HiddenToValue::new(value_weight, head.value_projection().bias().to_vec())?,
)?;
let parameters = state.parameters().clone().with_attention_heads(heads)?;
Ok(TransformerTrainingState::from_parts(
parameters,
state.learning_rate(),
state.step_count(),
))
}
fn state_with_attention_projection_bias(
state: &TransformerTrainingState,
selection: AttentionProjectionBiasSelection,
value: f32,
) -> CtResult<TransformerTrainingState> {
let mut heads = state.parameters().attention_heads().to_vec();
let head = heads.get(selection.head_index).ok_or(CtError::OutOfRange {
kind: "attention head",
index: selection.head_index,
limit: heads.len(),
})?;
let mut query_bias = head.query_projection().bias().to_vec();
let mut key_bias = head.key_projection().bias().to_vec();
let mut value_bias = head.value_projection().bias().to_vec();
match selection.kind {
AttentionProjectionKind::Query => {
query_bias[selection.output_index] = value;
}
AttentionProjectionKind::Key => {
key_bias[selection.output_index] = value;
}
AttentionProjectionKind::Value => {
value_bias[selection.output_index] = value;
}
}
heads[selection.head_index] = SelfAttentionHead::new(
HiddenToQuery::new(head.query_projection().weight().to_vec(), query_bias)?,
HiddenToKey::new(head.key_projection().weight().to_vec(), key_bias)?,
HiddenToValue::new(head.value_projection().weight().to_vec(), value_bias)?,
)?;
let parameters = state.parameters().clone().with_attention_heads(heads)?;
Ok(TransformerTrainingState::from_parts(
parameters,
state.learning_rate(),
state.step_count(),
))
}
fn tiny_single_head_block() -> CtResult<SingleHeadTransformerBlock> {
let query = HiddenToQuery::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?;
let key = HiddenToKey::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?;
let value = HiddenToValue::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?;
let output_projection =
AttentionOutputProjection::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?;
let model_dimension = ModelDimension::new(2)?;
let attention_norm =
LayerNormalization::new(LayerNormParameters::identity(model_dimension));
let feed_forward = PositionWiseFeedForward::new(
vec![vec![1.0, 0.0], vec![0.0, 1.0]],
vec![0.0, 0.0],
vec![vec![1.0, 0.0], vec![0.0, 1.0]],
vec![0.0, 0.0],
)?;
let feed_forward_norm =
LayerNormalization::new(LayerNormParameters::identity(model_dimension));
SingleHeadTransformerBlock::new(
query,
key,
value,
output_projection,
attention_norm,
feed_forward,
feed_forward_norm,
)
}
fn tiny_multi_head_block() -> CtResult<MultiHeadTransformerBlock> {
let model_dimension = ModelDimension::new(2)?;
MultiHeadTransformerBlock::new(
vec![
tiny_self_attention_head_first_feature()?,
tiny_self_attention_head_second_feature()?,
],
AttentionOutputProjection::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?,
LayerNormalization::new(LayerNormParameters::identity(model_dimension)),
identity_feed_forward()?,
LayerNormalization::new(LayerNormParameters::identity(model_dimension)),
)
}
fn tiny_masked_multi_head_block() -> CtResult<MaskedMultiHeadTransformerBlock> {
let model_dimension = ModelDimension::new(2)?;
MaskedMultiHeadTransformerBlock::new(
vec![
tiny_self_attention_head_first_feature()?,
tiny_self_attention_head_second_feature()?,
],
AttentionOutputProjection::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?,
LayerNormalization::new(LayerNormParameters::identity(model_dimension)),
identity_feed_forward()?,
LayerNormalization::new(LayerNormParameters::identity(model_dimension)),
)
}
fn tiny_transformer_parameters() -> CtResult<TinyTransformerParameters> {
TinyTransformerParameters::new(
PositionalEncoding::new(vec![vec![0.1, 0.0], vec![0.0, 0.1]])?,
tiny_masked_multi_head_block()?,
TransformerReadout::new(
vec![vec![1.0, 0.0, 0.5], vec![0.0, 1.0, -0.5]],
vec![0.0, 0.0, 0.0],
)?,
)
}
fn tiny_transformer_training_set() -> CtResult<TransformerReadoutTrainingSet> {
let example = TransformerReadoutTrainingExample::new(
HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?,
AttentionMask::new(vec![vec![true, false], vec![true, true]])?,
TokenSequence::from_indices([0, 1])?,
)?;
TransformerReadoutTrainingSet::new([example])
}
fn tiny_feed_forward_training_set() -> CtResult<TransformerFeedForwardTrainingSet> {
let example = TransformerFeedForwardTrainingExample::new(
HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?,
HiddenSequence::new(vec![vec![2.0, 0.0], vec![0.0, 2.0]])?,
)?;
TransformerFeedForwardTrainingSet::new([example])
}
fn tiny_transformer_block_training_set() -> CtResult<TransformerBlockTrainingSet> {
let example = TransformerBlockTrainingExample::new(
HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?,
AttentionMask::new(vec![vec![true, false], vec![true, true]])?,
TokenSequence::from_indices([0, 1])?,
)?;
TransformerBlockTrainingSet::new([example])
}
fn tiny_self_attention_head_first_feature() -> CtResult<SelfAttentionHead> {
SelfAttentionHead::new(
HiddenToQuery::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
HiddenToKey::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
HiddenToValue::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
)
}
fn tiny_self_attention_head_second_feature() -> CtResult<SelfAttentionHead> {
SelfAttentionHead::new(
HiddenToQuery::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
HiddenToKey::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
HiddenToValue::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
)
}
fn identity_feed_forward() -> CtResult<PositionWiseFeedForward> {
PositionWiseFeedForward::new(
vec![vec![1.0, 0.0], vec![0.0, 1.0]],
vec![0.0, 0.0],
vec![vec![1.0, 0.0], vec![0.0, 1.0]],
vec![0.0, 0.0],
)
}
}
The full runnable companion is:
Source snapshot: examples/06_attention_scores.rs
use category_theory_transformer_rs::{
AttentionHeadOutputs, AttentionMask, AttentionOutput, AttentionOutputProjection,
AttentionSoftmax, ConcatenateHeads, CtResult, HiddenSequence, HiddenToKey, HiddenToQuery,
HiddenToValue, KeySequence, LayerNormParameters, LayerNormalization, LearningRate,
MaskedAttentionScores, MaskedMultiHeadTransformerBlock, Morphism, MultiHeadTransformerBlock,
PositionWiseFeedForward, PositionalEncoding, Product, QuerySequence, ResidualConnection,
ScaledDotProductScores, SelfAttentionHead, SingleHeadTransformerBlock,
TinyTransformerParameters, TokenSequence, TransformerBlockTrainStep,
TransformerBlockTrainingExample, TransformerBlockTrainingSet, TransformerFeedForwardTrainStep,
TransformerFeedForwardTrainingExample, TransformerFeedForwardTrainingSet, TransformerReadout,
TransformerReadoutTrainStep, TransformerReadoutTrainingExample, TransformerReadoutTrainingSet,
TransformerTrainingState, ValueSequence, WeightedValueMixing, transformer_block_average_loss,
transformer_feed_forward_average_loss, transformer_readout_average_loss,
};
fn main() -> CtResult<()> {
let queries = QuerySequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
let keys = KeySequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]])?;
let values = ValueSequence::new(vec![vec![1.0, 10.0], vec![2.0, 20.0], vec![3.0, 30.0]])?;
let mask = AttentionMask::new(vec![vec![true, false, true], vec![true, true, true]])?;
println!("Q/K/V source diagnostic:");
println!("query rows own score rows; key/value rows own score columns");
println!(
"self-attention shares the hidden source before projection; projected roles stay distinct"
);
println!("mask polarity here: true = allowed, false = blocked\n");
let scores = ScaledDotProductScores.apply(Product::new(queries, keys))?;
let masked_scores = MaskedAttentionScores.apply(Product::new(scores, mask))?;
let weights = AttentionSoftmax.apply(masked_scores)?;
println!(
"attention shape: {} query positions x {} key positions",
weights.query_len().value(),
weights.key_len().value()
);
for (query_position, row) in weights.rows().iter().enumerate() {
println!("query {query_position} attends with {:?}", row.as_slice());
}
let output = WeightedValueMixing.apply(Product::new(weights, values))?;
for (query_position, row) in output.rows().iter().enumerate() {
println!("query {query_position} output vector {:?}", row.as_slice());
}
let second_head = AttentionOutput::new(vec![vec![10.0, 1.0], vec![20.0, 2.0]])?;
let head_outputs = AttentionHeadOutputs::new(vec![output, second_head])?;
let multi_head = ConcatenateHeads.apply(head_outputs)?;
println!(
"multi-head shape: {} heads x {} features -> model dimension {}",
multi_head.head_count().value(),
multi_head.head_dimension().value(),
multi_head.model_dimension().value()
);
for (query_position, row) in multi_head.rows().iter().enumerate() {
println!("query {query_position} multi-head row {:?}", row.as_slice());
}
let output_projection = AttentionOutputProjection::new(
vec![
vec![1.0, 0.0],
vec![0.0, 0.1],
vec![0.5, 0.0],
vec![0.0, 1.0],
],
vec![0.0, 0.0],
)?;
let projected = output_projection.apply(multi_head)?;
println!(
"projected attention shape: {} positions x model dimension {}",
projected.sequence_len().value(),
projected.model_dimension().value()
);
for (query_position, row) in projected.rows().iter().enumerate() {
println!(
"query {query_position} projected attention row {:?}",
row.as_slice()
);
}
let hidden_input = HiddenSequence::new(vec![vec![0.5, 0.5], vec![1.0, 1.0]])?;
let residual = ResidualConnection.apply(Product::new(hidden_input, projected))?;
println!(
"residual shape: {} positions x model dimension {}",
residual.sequence_len().value(),
residual.model_dimension().value()
);
for (query_position, row) in residual.rows().iter().enumerate() {
println!("query {query_position} residual row {:?}", row.as_slice());
}
let layer_norm =
LayerNormalization::new(LayerNormParameters::identity(residual.model_dimension()));
let normalized = layer_norm.apply(residual)?;
println!(
"normalized shape: {} positions x model dimension {}",
normalized.sequence_len().value(),
normalized.model_dimension().value()
);
for (query_position, row) in normalized.rows().iter().enumerate() {
println!("query {query_position} normalized row {:?}", row.as_slice());
}
let feed_forward = PositionWiseFeedForward::new(
vec![vec![1.0, -1.0, 0.5], vec![0.0, 1.0, 0.5]],
vec![0.0, 0.0, 0.0],
vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![0.5, 0.5]],
vec![0.0, 0.0],
)?;
let fed_forward = feed_forward.apply(normalized)?;
println!(
"feed-forward shape: {} positions x model dimension {}",
fed_forward.sequence_len().value(),
fed_forward.model_dimension().value()
);
for (query_position, row) in fed_forward.rows().iter().enumerate() {
println!(
"query {query_position} feed-forward row {:?}",
row.as_slice()
);
}
let positional_encoding = PositionalEncoding::new(vec![vec![0.1, 0.0], vec![0.0, 0.1]])?;
let positioned_hidden =
positional_encoding.apply(HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?)?;
println!(
"positioned hidden shape: {} positions x model dimension {}",
positioned_hidden.sequence_len().value(),
positioned_hidden.model_dimension().value()
);
let block = SingleHeadTransformerBlock::new(
HiddenToQuery::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?,
HiddenToKey::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?,
HiddenToValue::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?,
AttentionOutputProjection::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?,
LayerNormalization::new(LayerNormParameters::identity(fed_forward.model_dimension())),
PositionWiseFeedForward::new(
vec![vec![1.0, 0.0], vec![0.0, 1.0]],
vec![0.0, 0.0],
vec![vec![1.0, 0.0], vec![0.0, 1.0]],
vec![0.0, 0.0],
)?,
LayerNormalization::new(LayerNormParameters::identity(fed_forward.model_dimension())),
)?;
let block_output = block.apply(positioned_hidden.clone())?;
println!(
"single-head block shape: {} positions x model dimension {}",
block_output.sequence_len().value(),
block_output.model_dimension().value()
);
let multi_head_block = MultiHeadTransformerBlock::new(
vec![
SelfAttentionHead::new(
HiddenToQuery::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
HiddenToKey::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
HiddenToValue::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
)?,
SelfAttentionHead::new(
HiddenToQuery::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
HiddenToKey::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
HiddenToValue::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
)?,
],
AttentionOutputProjection::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?,
LayerNormalization::new(LayerNormParameters::identity(
block_output.model_dimension(),
)),
PositionWiseFeedForward::new(
vec![vec![1.0, 0.0], vec![0.0, 1.0]],
vec![0.0, 0.0],
vec![vec![1.0, 0.0], vec![0.0, 1.0]],
vec![0.0, 0.0],
)?,
LayerNormalization::new(LayerNormParameters::identity(
block_output.model_dimension(),
)),
)?;
let multi_head_output = multi_head_block.apply(positioned_hidden)?;
println!(
"multi-head block shape: {} positions x {} heads x value dimension {} -> model dimension {}",
multi_head_output.sequence_len().value(),
multi_head_block.head_count().value(),
multi_head_block.value_dimension().value(),
multi_head_output.model_dimension().value()
);
let masked_multi_head_block = MaskedMultiHeadTransformerBlock::new(
vec![
SelfAttentionHead::new(
HiddenToQuery::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
HiddenToKey::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
HiddenToValue::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
)?,
SelfAttentionHead::new(
HiddenToQuery::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
HiddenToKey::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
HiddenToValue::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
)?,
],
AttentionOutputProjection::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?,
LayerNormalization::new(LayerNormParameters::identity(
multi_head_output.model_dimension(),
)),
PositionWiseFeedForward::new(
vec![vec![1.0, 0.0], vec![0.0, 1.0]],
vec![0.0, 0.0],
vec![vec![1.0, 0.0], vec![0.0, 1.0]],
vec![0.0, 0.0],
)?,
LayerNormalization::new(LayerNormParameters::identity(
multi_head_output.model_dimension(),
)),
)?;
let masked_block_output = masked_multi_head_block.apply(Product::new(
HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?,
AttentionMask::new(vec![vec![true, false], vec![true, true]])?,
))?;
println!(
"masked multi-head block shape: {} positions x model dimension {}",
masked_block_output.sequence_len().value(),
masked_block_output.model_dimension().value()
);
let transformer_parameters = TinyTransformerParameters::new(
PositionalEncoding::new(vec![vec![0.1, 0.0], vec![0.0, 0.1]])?,
masked_multi_head_block,
TransformerReadout::new(
vec![vec![1.0, 0.0, 0.5], vec![0.0, 1.0, -0.5]],
vec![0.0, 0.0, 0.0],
)?,
)?;
let transformer_state =
TransformerTrainingState::new(transformer_parameters, LearningRate::new(0.1)?);
let training_set =
TransformerReadoutTrainingSet::new([TransformerReadoutTrainingExample::new(
HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?,
AttentionMask::new(vec![vec![true, false], vec![true, true]])?,
TokenSequence::from_indices([0, 1])?,
)?])?;
let sequence_logits = transformer_state.apply(Product::new(
HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?,
AttentionMask::new(vec![vec![true, false], vec![true, true]])?,
))?;
let loss_before = transformer_readout_average_loss(&transformer_state, &training_set)?;
let train_step = TransformerReadoutTrainStep::new(training_set.clone());
let next_state = train_step.apply(transformer_state.clone())?;
let loss_after = transformer_readout_average_loss(&next_state, &training_set)?;
println!(
"structured transformer logits shape: {} positions x vocabulary size {}",
sequence_logits.sequence_len().value(),
sequence_logits.vocab_size().value()
);
println!(
"training state step: {} -> {}",
transformer_state.step_count().value(),
next_state.step_count().value()
);
println!(
"readout loss after one update: {:.6} -> {:.6}",
loss_before.value(),
loss_after.value()
);
let feed_forward_training_set =
TransformerFeedForwardTrainingSet::new([TransformerFeedForwardTrainingExample::new(
HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?,
HiddenSequence::new(vec![vec![2.0, 0.0], vec![0.0, 2.0]])?,
)?])?;
let feed_forward_loss_before =
transformer_feed_forward_average_loss(&next_state, &feed_forward_training_set)?;
let feed_forward_train_step =
TransformerFeedForwardTrainStep::new(feed_forward_training_set.clone());
let feed_forward_state = feed_forward_train_step.apply(next_state.clone())?;
let feed_forward_loss_after =
transformer_feed_forward_average_loss(&feed_forward_state, &feed_forward_training_set)?;
println!(
"feed-forward loss after one local update: {:.6} -> {:.6}",
feed_forward_loss_before.value(),
feed_forward_loss_after.value()
);
let block_training_set =
TransformerBlockTrainingSet::new([TransformerBlockTrainingExample::new(
HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?,
AttentionMask::new(vec![vec![true, false], vec![true, true]])?,
TokenSequence::from_indices([0, 1])?,
)?])?;
let block_loss_before =
transformer_block_average_loss(&feed_forward_state, &block_training_set)?;
let block_train_step = TransformerBlockTrainStep::new(block_training_set.clone());
let block_trained_state = block_train_step.apply(feed_forward_state)?;
let block_loss_after =
transformer_block_average_loss(&block_trained_state, &block_training_set)?;
println!(
"block loss after one composed update: {:.6} -> {:.6}",
block_loss_before.value(),
block_loss_after.value()
);
println!();
println!("Typed transformation:");
println!("HiddenSequence -> QuerySequence");
println!("HiddenSequence -> KeySequence");
println!("HiddenSequence -> ValueSequence");
println!("QuerySequence x KeySequence -> AttentionScores");
println!("AttentionScores x AttentionMask -> AttentionScores");
println!("AttentionScores -> AttentionWeights");
println!("AttentionWeights x ValueSequence -> AttentionOutput");
println!("AttentionHeadOutputs -> MultiHeadOutput");
println!("MultiHeadOutput -> ProjectedAttentionOutput");
println!("HiddenSequence x ProjectedAttentionOutput -> HiddenSequence");
println!("LayerNormalization : HiddenSequence -> HiddenSequence");
println!("PositionWiseFeedForward : HiddenSequence -> HiddenSequence");
println!("PositionalEncoding : HiddenSequence -> HiddenSequence");
println!("SingleHeadTransformerBlock : HiddenSequence -> HiddenSequence");
println!("MultiHeadTransformerBlock : HiddenSequence -> HiddenSequence");
println!("MaskedMultiHeadTransformerBlock : HiddenSequence x AttentionMask -> HiddenSequence");
println!("TinyTransformerParameters : HiddenSequence x AttentionMask -> SequenceLogits");
println!("TransformerTrainingState owns parameters, learning rate, and step count");
println!("TransformerReadoutTrainStep : TransformerTrainingState -> TransformerTrainingState");
println!(
"TransformerFeedForwardTrainStep : TransformerTrainingState -> TransformerTrainingState"
);
println!("TransformerBlockTrainStep : TransformerTrainingState -> TransformerTrainingState");
Ok(())
}
The smaller state-only companion is:
Source snapshot: examples/07_transformer_training_state.rs
use category_theory_transformer_rs::{
AttentionMask, AttentionOutputProjection, CtResult, HiddenSequence, HiddenToKey, HiddenToQuery,
HiddenToValue, LayerNormParameters, LayerNormalization, LearningRate,
MaskedMultiHeadTransformerBlock, ModelDimension, Morphism, PositionWiseFeedForward,
PositionalEncoding, Product, SelfAttentionHead, TinyTransformerParameters, TokenSequence,
TransformerBlockTrainStep, TransformerBlockTrainingExample, TransformerBlockTrainingSet,
TransformerFeedForwardTrainStep, TransformerFeedForwardTrainingExample,
TransformerFeedForwardTrainingSet, TransformerReadout, TransformerReadoutTrainStep,
TransformerReadoutTrainingExample, TransformerReadoutTrainingSet, TransformerTrainingState,
transformer_block_average_loss, transformer_feed_forward_average_loss,
transformer_readout_average_loss,
};
fn main() -> CtResult<()> {
let hidden = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
let mask = AttentionMask::new(vec![vec![true, false], vec![true, true]])?;
let targets = TokenSequence::from_indices([0, 1])?;
let initial_state = tiny_training_state()?;
let logits = initial_state.apply(Product::new(hidden.clone(), mask.clone()))?;
println!(
"initial state: step={}, learning_rate={:.3}, model_dimension={}, vocab_size={}",
initial_state.step_count().value(),
initial_state.learning_rate().value(),
initial_state.parameters().model_dimension().value(),
initial_state.parameters().vocab_size().value()
);
println!(
"forward shape: {} positions x vocabulary size {}",
logits.sequence_len().value(),
logits.vocab_size().value()
);
let readout_set =
TransformerReadoutTrainingSet::new([TransformerReadoutTrainingExample::new(
hidden.clone(),
mask.clone(),
targets.clone(),
)?])?;
let readout_loss_before = transformer_readout_average_loss(&initial_state, &readout_set)?;
let readout_state =
TransformerReadoutTrainStep::new(readout_set.clone()).apply(initial_state)?;
let readout_loss_after = transformer_readout_average_loss(&readout_state, &readout_set)?;
print_update(
"readout update",
0,
&readout_state,
readout_loss_before.value(),
readout_loss_after.value(),
);
let feed_forward_set =
TransformerFeedForwardTrainingSet::new([TransformerFeedForwardTrainingExample::new(
hidden.clone(),
HiddenSequence::new(vec![vec![2.0, 0.0], vec![0.0, 2.0]])?,
)?])?;
let feed_forward_loss_before =
transformer_feed_forward_average_loss(&readout_state, &feed_forward_set)?;
let feed_forward_state =
TransformerFeedForwardTrainStep::new(feed_forward_set.clone()).apply(readout_state)?;
let feed_forward_loss_after =
transformer_feed_forward_average_loss(&feed_forward_state, &feed_forward_set)?;
print_update(
"feed-forward update",
1,
&feed_forward_state,
feed_forward_loss_before.value(),
feed_forward_loss_after.value(),
);
let block_set = TransformerBlockTrainingSet::new([TransformerBlockTrainingExample::new(
hidden, mask, targets,
)?])?;
let block_loss_before = transformer_block_average_loss(&feed_forward_state, &block_set)?;
let block_state =
TransformerBlockTrainStep::new(block_set.clone()).apply(feed_forward_state)?;
let block_loss_after = transformer_block_average_loss(&block_state, &block_set)?;
print_update(
"composed block update",
2,
&block_state,
block_loss_before.value(),
block_loss_after.value(),
);
println!();
println!("Typed transformation:");
println!("TinyTransformerParameters : HiddenSequence x AttentionMask -> SequenceLogits");
println!("TransformerTrainingState owns parameters, learning rate, and step count");
println!("TransformerReadoutTrainStep : TransformerTrainingState -> TransformerTrainingState");
println!(
"TransformerFeedForwardTrainStep : TransformerTrainingState -> TransformerTrainingState"
);
println!("TransformerBlockTrainStep : TransformerTrainingState -> TransformerTrainingState");
println!("Every update returns a full training state, not loose changed weights.");
Ok(())
}
fn print_update(
label: &str,
previous_step: usize,
state: &TransformerTrainingState,
loss_before: f32,
loss_after: f32,
) {
println!(
"{label}: step {} -> {}, loss {:.6} -> {:.6}",
previous_step,
state.step_count().value(),
loss_before,
loss_after
);
}
fn tiny_training_state() -> CtResult<TransformerTrainingState> {
let model_dimension = ModelDimension::new(2)?;
let block = MaskedMultiHeadTransformerBlock::new(
vec![
SelfAttentionHead::new(
HiddenToQuery::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
HiddenToKey::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
HiddenToValue::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
)?,
SelfAttentionHead::new(
HiddenToQuery::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
HiddenToKey::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
HiddenToValue::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
)?,
],
AttentionOutputProjection::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?,
LayerNormalization::new(LayerNormParameters::identity(model_dimension)),
PositionWiseFeedForward::new(
vec![vec![1.0, 0.0], vec![0.0, 1.0]],
vec![0.0, 0.0],
vec![vec![1.0, 0.0], vec![0.0, 1.0]],
vec![0.0, 0.0],
)?,
LayerNormalization::new(LayerNormParameters::identity(model_dimension)),
)?;
let parameters = TinyTransformerParameters::new(
PositionalEncoding::new(vec![vec![0.1, 0.0], vec![0.0, 0.1]])?,
block,
TransformerReadout::new(
vec![vec![1.0, 0.0, 0.5], vec![0.0, 1.0, -0.5]],
vec![0.0, 0.0, 0.0],
)?,
)?;
Ok(TransformerTrainingState::new(
parameters,
LearningRate::new(0.1)?,
))
}
Run it with:
cargo run --example 06_attention_scores
If you only want to inspect the training-state update shape, run the smaller companion example:
cargo run --example 07_transformer_training_state
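You should see two query positions and three key positions. Each query is scored against every key, producing one score row per query. The mask blocks one score position (query 0, key 1). Softmax then normalizes each row independently, and the resulting weights mix value vectors. The example carries that mixed output through the rest of the typed chain: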
QuerySequence x KeySequence -> AttentionScores
AttentionScores x AttentionMask -> AttentionScores
AttentionScores -> AttentionWeights
AttentionWeights x ValueSequence -> AttentionOutput
AttentionHeadOutputs -> MultiHeadOutput
MultiHeadOutput -> ProjectedAttentionOutput
HiddenSequence x ProjectedAttentionOutput -> HiddenSequence
HiddenSequence -> HiddenSequence (layer norm, feed-forward, positional encoding, and whole blocks)
This is the bridge from the softmax chapter. The crate's existing Distribution answers:
which next token is likely?
Attention weights answer:
which source positions should this position read from?
Both are probability-like objects. They differ in what their support means.
ML Concept
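Attention computes, for one head:
scores = QK^T / sqrt(d_k)
weights = softmax(scores), row by row, after any mask is applied
output = weights V
Here d_k is the shared width of the query and key vectors.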
This is softmax again, but applied to token-to-token interaction scores.
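The formula can be sketched as one Rust function. This is a minimal sketch, assuming plain Vec<Vec<f64>> rows instead of the repository's QuerySequence, KeySequence, ValueSequence, and AttentionMask types. The function name masked_attention and the boolean mask convention (true means the key position is visible) are hypothetical. Shapes are checked first, masked scores become negative infinity before the row-wise softmax, and only then are value rows mixed.

```rust
/// Hypothetical helper: one attention head over row-major matrices.
/// `mask[q][k] == true` means query position `q` may read key position `k`.
fn masked_attention(
    queries: &[Vec<f64>],
    keys: &[Vec<f64>],
    values: &[Vec<f64>],
    mask: &[Vec<bool>],
) -> Result<Vec<Vec<f64>>, String> {
    let d_k = queries.first().ok_or("no query positions")?.len();
    let width = values.first().ok_or("no value positions")?.len();
    if keys.len() != values.len() || values.iter().any(|v| v.len() != width) {
        return Err("value rows must match the key count and share one width".into());
    }
    if queries.iter().chain(keys).any(|row| row.len() != d_k) {
        return Err("query and key rows must all have width d_k".into());
    }
    if mask.len() != queries.len() || mask.iter().any(|row| row.len() != keys.len()) {
        return Err("mask must be query positions x key positions".into());
    }
    let scale = (d_k as f64).sqrt();
    let mut output = Vec::with_capacity(queries.len());
    for (q, allowed) in queries.iter().zip(mask) {
        // Scaled dot-product scores; masked keys become -infinity before softmax.
        let scores: Vec<f64> = keys
            .iter()
            .zip(allowed)
            .map(|(k, &ok)| {
                if ok {
                    q.iter().zip(k).map(|(a, b)| a * b).sum::<f64>() / scale
                } else {
                    f64::NEG_INFINITY
                }
            })
            .collect();
        // Row-wise softmax, shifted by the row maximum for numerical stability.
        let max = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        if max == f64::NEG_INFINITY {
            return Err("a query row has no visible key positions".into());
        }
        let exps: Vec<f64> = scores.iter().map(|s| (s - max).exp()).collect();
        let total: f64 = exps.iter().sum();
        // Mix value rows with the normalized weights.
        let mut mixed = vec![0.0_f64; width];
        for (e, v) in exps.iter().zip(values) {
            let weight = e / total;
            for (m, x) in mixed.iter_mut().zip(v) {
                *m += weight * x;
            }
        }
        output.push(mixed);
    }
    Ok(output)
}
```

Setting a score to negative infinity makes its exponential zero, so a masked key receives exactly zero weight. This is one common way to apply a mask before softmax, not necessarily the repository's exact mechanism.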
Category Theory Concept
The attention block is a composition of typed maps with a product input:
(Q, K, V) -> scores -> weights -> mixed values
Design contract:
Attention should have a positive test showing that valid shapes compose and a negative test showing that mismatched head dimensions, mask shapes, or value lengths are rejected at construction or composition time.
Step 4: Multi-Head Concatenation
The current problem:
One attention head sees one interaction pattern. Multiple heads let the model carry several patterns in parallel. Their outputs must be recombined without losing shape information.
Rust Syntax
The recombination boundary is:
AttentionHeadOutputs -> MultiHeadOutput
HeadCount rejects zero. AttentionHeadOutputs rejects an empty collection,
sequence-length mismatches, and head-dimension mismatches. MultiHeadOutput
records:
sequence length
head count
head dimension
model dimension
This boundary is not the whole block by itself. It is the place where separate head outputs become one wider object before the output projection.
Worked Example: Concatenate Heads, Then Project
Multi-head attention adds one shape calculation that readers should be able to do without a framework:
model_dimension = head_count * head_dimension
In the runnable attention example, the first head has two output features per query position:
head_0 query_0 = [2.0, 20.0]
head_0 query_1 = [2.2033, 22.0334]
The example then adds a second head with the same sequence length and head dimension:
head_1 query_0 = [10.0, 1.0]
head_1 query_1 = [20.0, 2.0]
Concatenation does not average the heads. It places their feature rows side by side:
query_0 multi-head row = [2.0, 20.0, 10.0, 1.0]
query_1 multi-head row = [2.2033, 22.0334, 20.0, 2.0]
The recorded shape is:
2 heads x 2 features = model dimension 4
The next boundary is the learned output projection:
MultiHeadOutput -> ProjectedAttentionOutput
For the first query row, the tiny projection in the example uses:
input row:
[2.0, 20.0, 10.0, 1.0]
projection rows:
[1.0, 0.0]
[0.0, 0.1]
[0.5, 0.0]
[0.0, 1.0]
The projected row is:
first output feature = 2.0 * 1.0 + 20.0 * 0.0 + 10.0 * 0.5 + 1.0 * 0.0 = 7.0
second output feature = 2.0 * 0.0 + 20.0 * 0.1 + 10.0 * 0.0 + 1.0 * 1.0 = 3.0
projected query_0 = [7.0, 3.0]
This is the reason the repository keeps two separate boundaries:
AttentionHeadOutputs -> MultiHeadOutput
MultiHeadOutput -> ProjectedAttentionOutput
The first boundary proves that head outputs can be concatenated. The second boundary proves that the concatenated width matches the projection input width. If either relationship fails, the model should fail at the boundary, before a later residual connection receives the wrong shape.
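The two boundaries can be sketched in Rust with plain row vectors. The function names concatenate_heads and project are hypothetical, and the weight layout (one row per input feature, each row as wide as the bias) is inferred from the worked example above rather than taken from the repository's AttentionOutputProjection.

```rust
/// Hypothetical sketch: each head output is indexed as [position][feature].
fn concatenate_heads(heads: &[Vec<Vec<f64>>]) -> Result<Vec<Vec<f64>>, String> {
    let first = heads.first().ok_or("at least one head is required")?;
    let seq_len = first.len();
    let head_dim = first.first().map_or(0, |row| row.len());
    for head in heads {
        if head.len() != seq_len {
            return Err("heads disagree on sequence length".into());
        }
        if head.iter().any(|row| row.len() != head_dim) {
            return Err("heads disagree on head dimension".into());
        }
    }
    // Rows are placed side by side, so width = head_count * head_dim.
    Ok((0..seq_len)
        .map(|pos| {
            heads
                .iter()
                .flat_map(|head| head[pos].iter().copied())
                .collect::<Vec<f64>>()
        })
        .collect())
}

/// `weights` has one row per input feature; each row has `bias.len()` entries.
/// Both shape checks run before any multiplication.
fn project(rows: &[Vec<f64>], weights: &[Vec<f64>], bias: &[f64]) -> Result<Vec<Vec<f64>>, String> {
    if weights.iter().any(|w| w.len() != bias.len()) {
        return Err("every weight row must have the output dimension".into());
    }
    if rows.iter().any(|r| r.len() != weights.len()) {
        return Err("concatenated width does not match projection input".into());
    }
    Ok(rows
        .iter()
        .map(|r| {
            let mut out = bias.to_vec();
            for (x, w_row) in r.iter().zip(weights) {
                for (o, w) in out.iter_mut().zip(w_row) {
                    *o += x * w;
                }
            }
            out
        })
        .collect())
}
```

Feeding the sketch the worked example's head rows and projection gives [7.0, 3.0] for query_0 and about [12.2033, 4.2033] for query_1. Dropping one projection row makes project fail before any multiplication.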
ML Concept
Each head performs attention separately.
The outputs are concatenated, then an output projection maps the combined row back into the hidden-state width. This repository now implements that projection as a separate typed boundary.
Category Theory Concept
This is parallel composition followed by recombination:
head_1 x head_2 x ... x head_n -> MultiHeadOutput
Design contract:
HeadCount, HeadDimension, and ModelDimension are not bare usize values.
The arithmetic relationship between them is part of the architecture: if there
are two heads of width two, the concatenated model dimension is four.
Step 5: Output Projection
The current problem:
Concatenated heads are wider than a single head. A later Transformer block expects a coherent hidden width again.
Rust Syntax
The current code models the projection as:
MultiHeadOutput -> ProjectedAttentionOutput
AttentionOutputProjection validates:
non-empty weight rows
non-empty bias
finite weight and bias values
each weight row's length matching the output dimension (the bias length)
input dimension matching MultiHeadOutput model dimension
That shape follows the multi-head attention reference path: heads are concatenated, then another learned linear projection produces the output sequence.
ML Concept
The projection is a learned linear map after concatenation. It lets the model mix features across heads and return to the width expected by the surrounding block.
Category Theory Concept
This is another typed morphism:
MultiHeadOutput -> ProjectedAttentionOutput
Design contract:
The projection should fail before multiplication if the concatenated head width does not match the projection’s input dimension. That is a boundary invariant, not an indexing accident.
Step 6: Residual Addition
The current problem:
Transformer sublayers need to add their output back to the hidden sequence they received.
Rust Syntax
The current code models residual addition as:
HiddenSequence x ProjectedAttentionOutput -> HiddenSequence
ResidualConnection rejects sequence-length mismatches and model-dimension
mismatches before adding rows. This follows the Transformer requirement that a
sublayer output must have the same dimension as its input for residual addition
to be feasible.
Worked Example: Residual Addition Needs The Same Shape
The previous worked example ended with a projected attention row:
projected query_0 = [7.0, 3.0]
Residual addition can only happen because that projected row has the same model dimension as the hidden row it will be added to:
hidden query_0 = [0.5, 0.5]
projected query_0 = [7.0, 3.0]
residual query_0 = [7.5, 3.5]
The second row follows the same rule:
hidden query_1 = [1.0, 1.0]
projected query_1 = [12.2033, 4.2033]
residual query_1 = [13.2033, 5.2033]
The shape is preserved:
HiddenSequence:
2 positions x model dimension 2
ProjectedAttentionOutput:
2 positions x model dimension 2
Residual HiddenSequence:
2 positions x model dimension 2
This is why the output projection matters. Multi-head concatenation produced a four-feature row. Residual addition needs a two-feature row because the hidden sequence has model dimension two in this tiny example. The projection is the bridge that makes the residual boundary legal.
The invalid shortcut would be:
HiddenSequence x MultiHeadOutput -> HiddenSequence
For the example above, that would try to add a two-feature hidden row to a four-feature concatenated row. The repository avoids that by making the legal path explicit:
MultiHeadOutput -> ProjectedAttentionOutput
HiddenSequence x ProjectedAttentionOutput -> HiddenSequence
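A sketch of the residual boundary under the same plain-vector assumption follows. The name residual_add is hypothetical. Both shape checks run before any row is added; without them, zipping a two-feature row with a four-feature row would silently truncate instead of failing.

```rust
/// Hypothetical sketch of HiddenSequence x ProjectedAttentionOutput -> HiddenSequence.
fn residual_add(hidden: &[Vec<f64>], sublayer: &[Vec<f64>]) -> Result<Vec<Vec<f64>>, String> {
    if hidden.len() != sublayer.len() {
        return Err(format!("sequence length {} vs {}", hidden.len(), sublayer.len()));
    }
    hidden
        .iter()
        .zip(sublayer)
        .map(|(h, s)| {
            if h.len() != s.len() {
                // Reject before adding: a model-dimension mismatch is a boundary error.
                Err(format!("model dimension {} vs {}", h.len(), s.len()))
            } else {
                Ok(h.iter().zip(s).map(|(a, b)| a + b).collect::<Vec<f64>>())
            }
        })
        .collect()
}
```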
ML Concept
Residual addition preserves the hidden sequence shape while allowing a sublayer to contribute a learned change:
hidden + sublayer_output
The repository currently implements the addition boundary, the layer normalization boundary, the position-wise feed-forward boundary, and compact single-head and multi-head block boundaries.
Category Theory Concept
The residual boundary consumes a product object and returns the same hidden sequence object:
HiddenSequence x ProjectedAttentionOutput -> HiddenSequence
For a fixed block value, the larger unmasked block is also an endomorphism:
HiddenSequence -> HiddenSequence
Design contract:
Residual addition should fail before addition if either the sequence length or model dimension differs. Without that check, a later Transformer block would be silently mixing incompatible objects.
Step 7: Layer Normalization
The current problem:
After residual addition, each sequence position should be normalized across its feature dimension while preserving the hidden sequence shape.
Rust Syntax
The current code models layer normalization as:
HiddenSequence -> HiddenSequence
LayerNormParameters validates non-empty scale and shift vectors, matching
parameter lengths, finite parameter values, and positive finite epsilon.
LayerNormalization rejects hidden sequences whose model dimension does not
match its parameter dimension.
ML Concept
Layer normalization recenters and rescales each hidden vector across its feature dimension. It is batch-size independent and preserves the sequence shape:
same positions
same model dimension
new normalized values
The Layer Normalization paper is useful here because it frames the operation
around statistics inside a single training case rather than statistics
collected across a batch. In this roadmap’s tiny Rust object, that means each
row can be normalized while the public HiddenSequence boundary remains the
same.
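A per-row layer normalization sketch is shown next. It assumes population variance (dividing by the row width) and epsilon added inside the square root, which is the usual formulation; whether the repository's LayerNormalization makes exactly these choices, and which epsilon it stores, is an assumption here. The name layer_norm is hypothetical.

```rust
/// Hypothetical sketch: statistics are computed inside each row (one position),
/// never across rows, so the sequence shape is unchanged.
fn layer_norm(
    rows: &[Vec<f64>],
    scale: &[f64],
    shift: &[f64],
    epsilon: f64,
) -> Result<Vec<Vec<f64>>, String> {
    if scale.is_empty() || scale.len() != shift.len() {
        return Err("scale and shift must be non-empty and equal length".into());
    }
    if !(epsilon.is_finite() && epsilon > 0.0) {
        return Err("epsilon must be positive and finite".into());
    }
    rows.iter()
        .map(|row| {
            if row.len() != scale.len() {
                return Err("row width does not match parameter dimension".to_string());
            }
            let n = row.len() as f64;
            let mean = row.iter().sum::<f64>() / n;
            let variance = row.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
            let inv_std = 1.0 / (variance + epsilon).sqrt();
            // Recenter, rescale, then apply the learned scale and shift.
            Ok(row
                .iter()
                .zip(scale.iter().zip(shift))
                .map(|(x, (g, b))| g * (x - mean) * inv_std + b)
                .collect::<Vec<f64>>())
        })
        .collect()
}
```

With scale [1.0, 1.0], shift [0.0, 0.0], and an assumed epsilon of 1e-5, the residual row [7.5, 3.5] normalizes to approximately [0.9999988, -0.9999988].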
Category Theory Concept
The normalization boundary is an endomorphism:
HiddenSequence -> HiddenSequence
Read that as a forward call for one fixed LayerNormalization value. Its
scale, shift, and epsilon are already stored in the layer. If those values are
being learned, the changing object is the larger training state, not the
hidden sequence alone.
Design contract:
Normalization should not change the object type. If a later block expects a hidden sequence, the normalized result should still be a hidden sequence.
Step 8: Position-Wise Feed-Forward
The current problem:
After attention and normalization, each sequence position needs a learned non-linear transformation that preserves the hidden sequence shape.
Rust Syntax
The current code models this as:
HiddenSequence -> HiddenSequence
PositionWiseFeedForward validates two linear layers:
model dimension -> feed-forward hidden dimension -> model dimension
It also checks finite weights, finite biases, and compatible intermediate dimensions before any row is projected.
ML Concept
A position-wise feed-forward network applies the same two-layer non-linear map to each hidden vector independently:
hidden row -> expanded row -> activated row -> hidden row
It changes feature values, not the sequence length or public model dimension.
That is why the public boundary stays:
HiddenSequence -> HiddenSequence
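The two-layer map can be sketched as a small Rust struct. The struct name FeedForward and its helper linear are hypothetical; the argument order (first weights, first bias, second weights, second bias) mirrors the four arguments passed to PositionWiseFeedForward::new in the listing at the top of this section, but the row layout is an assumption.

```rust
/// Hypothetical sketch: model_dim -> hidden_dim (ReLU) -> model_dim per row.
struct FeedForward {
    w1: Vec<Vec<f64>>, // model_dim rows, each hidden_dim wide
    b1: Vec<f64>,      // hidden_dim
    w2: Vec<Vec<f64>>, // hidden_dim rows, each model_dim wide
    b2: Vec<f64>,      // model_dim
}

fn linear(x: &[f64], w: &[Vec<f64>], b: &[f64]) -> Vec<f64> {
    let mut out = b.to_vec();
    for (xi, row) in x.iter().zip(w) {
        for (o, wij) in out.iter_mut().zip(row) {
            *o += xi * wij;
        }
    }
    out
}

impl FeedForward {
    fn new(w1: Vec<Vec<f64>>, b1: Vec<f64>, w2: Vec<Vec<f64>>, b2: Vec<f64>) -> Result<Self, String> {
        let finite = w1.iter().chain(&w2).flatten().chain(&b1).chain(&b2).all(|v| v.is_finite());
        if !finite {
            return Err("weights and biases must be finite".into());
        }
        if w1.is_empty() || b1.is_empty() || w1.iter().any(|row| row.len() != b1.len()) {
            return Err("first layer must map model_dim -> hidden_dim".into());
        }
        if w2.len() != b1.len() || w2.iter().any(|row| row.len() != b2.len()) {
            return Err("second layer must start at hidden_dim".into());
        }
        if b2.len() != w1.len() {
            return Err("second layer must return to model_dim".into());
        }
        Ok(Self { w1, b1, w2, b2 })
    }

    /// Applies the same map to every position; sequence length is unchanged.
    fn apply(&self, rows: &[Vec<f64>]) -> Result<Vec<Vec<f64>>, String> {
        rows.iter()
            .map(|row| {
                if row.len() != self.w1.len() {
                    return Err("row width does not match model_dim".to_string());
                }
                // ReLU between the two linear layers.
                let expanded: Vec<f64> = linear(row, &self.w1, &self.b1)
                    .into_iter()
                    .map(|h| h.max(0.0))
                    .collect();
                Ok(linear(&expanded, &self.w2, &self.b2))
            })
            .collect()
    }
}
```

With identity weights and zero biases, the ReLU alone clips negative features, and a second layer that does not return to model_dim is rejected at construction.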
Category Theory Concept
The feed-forward sublayer is another endomorphism:
HiddenSequence -> HiddenSequence
Read that the same way: for this fixed PositionWiseFeedForward value, the
call receives a hidden sequence and returns a hidden sequence. The layer’s
weights and biases are context already stored inside the Rust object. Training
those weights belongs to a state update, not to this forward boundary.
It is not the whole Transformer block. It is the next shape-preserving sublayer that a later block can compose.
Design contract:
The second linear layer must return to the original model dimension. Otherwise the next residual or block boundary would receive the wrong object.
Worked Example: Values Change, Shape Stays HiddenSequence
The residual example produced this hidden sequence:
residual query_0 = [7.5, 3.5]
residual query_1 = [13.2033, 5.2033]
Layer normalization changes the row values while keeping the public object the same:
normalized query_0 = [0.9999988, -0.9999988]
normalized query_1 = [0.99999976, -0.99999976]
The sequence still has two positions and model dimension two:
Residual HiddenSequence
2 positions x model dimension 2
LayerNormalization
HiddenSequence -> HiddenSequence
Normalized HiddenSequence
2 positions x model dimension 2
The feed-forward sublayer then applies the same two-layer map to each position. In this tiny example, the ReLU step clips the negative feature before the second linear layer returns to the public model dimension:
feed-forward query_0 = [0.9999988, 0.0]
feed-forward query_1 = [0.99999976, 0.0]
Again, the object has not changed:
HiddenSequence
2 positions x model dimension 2
PositionWiseFeedForward
HiddenSequence -> HiddenSequence
HiddenSequence
2 positions x model dimension 2
The values changed twice. The sequence length and model dimension did not.
That distinction matters because normalization and feed-forward computation are not new sequence objects in this roadmap. They are shape-preserving maps over hidden rows:
Residual HiddenSequence -> LayerNormalization -> HiddenSequence
HiddenSequence -> PositionWiseFeedForward -> HiddenSequence
The invalid mental shortcut is:
normalization creates a special normalized object
feed-forward creates a special feed-forward object
The useful engineering view is stricter:
both stages return HiddenSequence so the next block boundary can compose
Step 9: Positional Encoding
The current problem:
Self-attention sees a set of hidden rows. A sequence model also needs to know where each row sits in the sequence.
Rust Syntax
The current code models position as another shape-preserving morphism:
PositionalEncoding : HiddenSequence -> HiddenSequence
The encoding table validates non-empty finite rows and a fixed model dimension. Applying it rejects hidden sequences that are too long for the table or have the wrong model width.
ML Concept
Position information is added to hidden vectors before attention so identical tokens in different positions can become distinguishable to later transformations.
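A minimal sketch of one fixed position table follows, assuming row i of the table is added to hidden row i. The struct name PositionTable is hypothetical; the validation mirrors the rules listed above (non-empty finite rows of one width, no sequence longer than the table, no width mismatch).

```rust
/// Hypothetical sketch of PositionalEncoding : HiddenSequence -> HiddenSequence.
struct PositionTable {
    rows: Vec<Vec<f64>>,
}

impl PositionTable {
    fn new(rows: Vec<Vec<f64>>) -> Result<Self, String> {
        let width = rows.first().map_or(0, |r| r.len());
        let bad_row = |r: &Vec<f64>| r.len() != width || r.iter().any(|v| !v.is_finite());
        if width == 0 || rows.iter().any(bad_row) {
            return Err("table rows must be non-empty, finite, and equally wide".into());
        }
        Ok(Self { rows })
    }

    fn apply(&self, hidden: &[Vec<f64>]) -> Result<Vec<Vec<f64>>, String> {
        if hidden.len() > self.rows.len() {
            return Err("sequence is longer than the position table".into());
        }
        let width = self.rows[0].len();
        if hidden.iter().any(|h| h.len() != width) {
            return Err("hidden width does not match the table width".into());
        }
        // Values change; sequence length and model dimension do not.
        Ok(hidden
            .iter()
            .zip(&self.rows)
            .map(|(h, p)| h.iter().zip(p).map(|(a, b)| a + b).collect::<Vec<f64>>())
            .collect())
    }
}
```

Two identical hidden rows come out different after the table is applied, which is the whole point of adding position.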
Category Theory Concept
For one fixed positional-encoding table, the public shape is still an endomorphism:
HiddenSequence -> HiddenSequence
Read that as a forward call with the encoding table already selected. If a
future chapter learns, swaps, or rebuilds the position table, that changing
context must be named separately instead of being hidden inside the
HiddenSequence -> HiddenSequence arrow.
Design contract:
Adding position should change values, not the hidden sequence object. If the position table has the wrong width or not enough rows, composition should fail before attention starts.
Step 10: Single-Head And Multi-Head Blocks
The current problem:
A block should compose attention, residual addition, normalization, and feed-forward computation while keeping the public shape simple.
Rust Syntax
The current single-head and multi-head sketches both have shape:
HiddenSequence -> HiddenSequence
They are implemented as:
SingleHeadTransformerBlock
MultiHeadTransformerBlock
The single-head sketch proves the compact block boundary. The multi-head sketch
extends that boundary by collecting several SelfAttentionHead values,
concatenating their outputs, and applying the output projection.
ML Concept
Transformer blocks combine:
attention
residual connection
normalization
feed-forward network
The block output has the same shape as the input.
This is where the current training chapter becomes useful again. A block with
shape HiddenSequence -> HiddenSequence can be stacked for the same reason a
training step with shape Parameters -> Parameters can be repeated: output and
input live in the same object.
Category Theory Concept
For one fixed single-head or multi-head block value, this is another endomorphism:
HiddenSequence -> HiddenSequence
Stacking layers is repeated endomorphism application.
If the block’s heads, projections, normalization parameters, or feed-forward weights are changing, the boundary is no longer only this forward call. The changing object is the larger training state:
TransformerTrainingState -> TransformerTrainingState
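The stacking claim can be sketched with fixed blocks as boxed functions. This assumes a HiddenSequence alias over Vec<Vec<f64>>; the names stack, Block, and toy_block are hypothetical, and toy_block only rescales values as a stand-in for a real Transformer block.

```rust
/// Hypothetical sketch: a fixed block is a function HiddenSequence -> HiddenSequence.
type HiddenSequence = Vec<Vec<f64>>;
type Block = Box<dyn Fn(HiddenSequence) -> HiddenSequence>;

/// Composes fixed blocks in order; the result has the same endomorphism shape.
fn stack(blocks: Vec<Block>) -> impl Fn(HiddenSequence) -> HiddenSequence {
    move |input: HiddenSequence| blocks.iter().fold(input, |hidden, block| block(hidden))
}

/// Toy stand-in for a real block: rescales values and preserves the shape.
fn toy_block(factor: f64) -> Block {
    Box::new(move |hidden: HiddenSequence| {
        hidden
            .into_iter()
            .map(|row| row.into_iter().map(|x| x * factor).collect::<Vec<f64>>())
            .collect::<HiddenSequence>()
    })
}
```

Because the stacked model has the same input and output type as each block, its output can be fed back into it. Training is deliberately absent: the boxed closures never change their captured parameters.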
Design contract:
Internal complexity does not leak into every caller. The single-head and multi-head sketches contain several sublayers, but callers see one typed boundary.
For the multi-head sketch, the output-projection input dimension must equal:
head_count * value_head_dimension
That check is the difference between a typed block and a loose pile of matrix multiplications.
Step 11: Masked Blocks
The current problem:
Some sequence positions should not attend to other positions. The block boundary needs a mask, not only the lower-level score operation.
Rust Syntax
The current code models the masked block as:
MaskedMultiHeadTransformerBlock : HiddenSequence x AttentionMask -> HiddenSequence
The mask is part of the input product. The block applies it after query-key scoring and before row-wise softmax for each head.
ML Concept
A mask controls which source positions each query position may use. The same shape-preserving block can now represent selective attention.
Category Theory Concept
This is a product-to-object morphism:
HiddenSequence x AttentionMask -> HiddenSequence
Design contract:
The mask must have the same query and key dimensions as the score table inside the block. If the shape does not match, the block fails before softmax.
Worked Example: Fixed Mask Versus Open Mask
The masked block is where stacking language can become imprecise. The unmasked multi-head block has the simple public shape:
MultiHeadTransformerBlock : HiddenSequence -> HiddenSequence
That boundary can compose directly with another boundary of the same shape. The masked block is different:
MaskedMultiHeadTransformerBlock : HiddenSequence x AttentionMask -> HiddenSequence
While the mask is still an open input, the boundary is not a unary endomorphism. It needs the hidden sequence and the mask. There are two precise ways to use it repeatedly:
| Use case | Boundary to name | Why this is precise |
|---|---|---|
| the caller supplies a mask for each block call | HiddenSequence x AttentionMask -> HiddenSequence | the mask remains visible as required context |
| one example fixes a specific mask before applying the block | HiddenSequence -> HiddenSequence under fixed mask context | the unary map is induced by a named fixed mask |
The second row is useful in a lesson or a single training example, but only if the fixed context is named. The mask did not disappear. It became part of the chosen environment for that run.
The invalid shortcut is:
MaskedMultiHeadTransformerBlock returns HiddenSequence, so it is automatically
an endomorphism.
The output type is not enough. Count the whole input object. The open masked boundary has product input. A fixed-mask view can be treated as a unary shape-preserving map only after the mask has been selected and kept stable for that call path.
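The second table row can be sketched as partial application: the open boundary takes two inputs, and a unary map exists only after a named mask is moved into a closure. The names fix_mask and toy_masked_block are hypothetical; the toy block averages the rows each position may see and stands in for the real masked multi-head block.

```rust
/// Hypothetical sketch of fixing mask context for a product-input boundary.
type HiddenSequence = Vec<Vec<f64>>;
type AttentionMask = Vec<Vec<bool>>;

fn fix_mask<F>(block: F, mask: AttentionMask) -> impl Fn(HiddenSequence) -> HiddenSequence
where
    F: Fn(HiddenSequence, &AttentionMask) -> HiddenSequence,
{
    // The mask is moved into the closure: it becomes named, stable context.
    move |hidden: HiddenSequence| block(hidden, &mask)
}

/// Toy masked block: each position averages the rows it may see.
/// Assumes a validated positions x positions mask with a visible key in every row.
fn toy_masked_block(hidden: HiddenSequence, mask: &AttentionMask) -> HiddenSequence {
    let width = hidden.first().map_or(0, |row| row.len());
    mask.iter()
        .map(|allowed| {
            let mut sum = vec![0.0_f64; width];
            let mut count = 0.0_f64;
            for (row, &visible) in hidden.iter().zip(allowed) {
                if visible {
                    for (s, x) in sum.iter_mut().zip(row) {
                        *s += x;
                    }
                    count += 1.0;
                }
            }
            sum.into_iter().map(|s| s / count).collect::<Vec<f64>>()
        })
        .collect()
}
```

The fixed view can be applied repeatedly, but every call still uses the same captured mask; a different mask requires building a different unary map.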
Step 12: Structured State For Training And Evaluation
The current problem:
Once the model has attention parameters, evaluation and future training need one structured object instead of loose matrices passed through the code.
Rust Syntax
The earlier training chapter used:
Parameters
The roadmap code now adds the structured attention-side version:
TransformerReadout : HiddenSequence -> SequenceLogits
TinyTransformerParameters : HiddenSequence x AttentionMask -> SequenceLogits
TransformerTrainingState
TinyTransformerParameters owns:
positional encoding
masked multi-head block
sequence readout
TransformerTrainingState owns that parameter object plus LearningRate and
StepCount. Its record_updated_parameters method records a new parameter
object and increments the step count.
The roadmap code also adds three supervised updates:
TransformerReadoutTrainStep : TransformerTrainingState -> TransformerTrainingState
TransformerFeedForwardTrainStep : TransformerTrainingState -> TransformerTrainingState
TransformerBlockTrainStep : TransformerTrainingState -> TransformerTrainingState
The readout step updates the sequence readout with a softmax cross-entropy gradient. The feed-forward step updates the position-wise feed-forward sublayer against hidden-sequence targets. The block step composes those ideas: it starts from sequence targets, computes readout gradients, carries the hidden gradient through the final layer-normalization and residual boundary, then through the attention-normalization and residual boundary, and updates the feed-forward sublayer, attention output projection, query/key/value projections, and both layer-normalization scale/shift vectors from the same supervised example. These are real gradient steps, but deliberately tiny ones.
ML Concept
A Transformer training loop still has the same outer structure:
predict
compute loss
backpropagate
update parameters
The internal model is richer, so the parameter object must be richer. A useful training state keeps three questions separate:
what parameters define the model?
what optimizer settings control the update?
which update step are we on?
The current code answers those questions structurally, then adds small full-batch gradient steps through the current trainable block components. The readout update answers a smaller first question:
if the hidden sequence is fixed, can the vocabulary readout learn?
Yes. The update computes probabilities at each sequence position, subtracts one from the target class probability term, accumulates weight and bias gradients for the readout, applies the learning rate, and increments the step count.
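The readout update just described can be sketched in Rust. This assumes logits[v] = bias[v] + sum over i of hidden[i] * weights[i][v], a mean cross-entropy over positions, plain gradient descent, and already-validated shapes and targets; the function names softmax and readout_train_step are hypothetical, not the repository's TransformerReadoutTrainStep.

```rust
/// Numerically stable softmax over one logit row.
fn softmax(logits: &[f64]) -> Vec<f64> {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.iter().map(|z| (z - max).exp()).collect();
    let total: f64 = exps.iter().sum();
    exps.into_iter().map(|e| e / total).collect()
}

/// Hypothetical sketch: returns the loss measured before the update,
/// then updates readout weights and bias in place.
fn readout_train_step(
    hidden: &[Vec<f64>],
    targets: &[usize],
    weights: &mut [Vec<f64>],
    bias: &mut [f64],
    learning_rate: f64,
) -> f64 {
    let n = hidden.len() as f64;
    let vocab = bias.len();
    let mut grad_w = vec![vec![0.0; vocab]; weights.len()];
    let mut grad_b = vec![0.0; vocab];
    let mut loss = 0.0;
    for (h, &target) in hidden.iter().zip(targets) {
        let logits: Vec<f64> = (0..vocab)
            .map(|v| bias[v] + h.iter().zip(weights.iter()).map(|(x, row)| x * row[v]).sum::<f64>())
            .collect();
        let mut delta = softmax(&logits);
        loss -= delta[target].ln() / n;
        // dLoss/dlogits = probabilities - one_hot(target): subtract one at the target.
        delta[target] -= 1.0;
        for v in 0..vocab {
            grad_b[v] += delta[v] / n;
            for (i, x) in h.iter().enumerate() {
                grad_w[i][v] += x * delta[v] / n;
            }
        }
    }
    // Gradient descent on every readout parameter.
    for (row, grad_row) in weights.iter_mut().zip(&grad_w) {
        for (w, g) in row.iter_mut().zip(grad_row) {
            *w -= learning_rate * g;
        }
    }
    for (b, g) in bias.iter_mut().zip(&grad_b) {
        *b -= learning_rate * g;
    }
    loss
}
```

With all-zero readout parameters every position starts at the uniform distribution, so the first reported loss is ln(vocab_size); after one step the loss on the same fixed hidden sequence is lower.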
The local feed-forward update asks a different small question:
if the attention output is treated as fixed, can the feed-forward sublayer
learn a hidden-sequence target?
It computes a squared-error gradient through the two feed-forward linear layers and the ReLU between them. That teaches a real block-internal update without pretending to use token targets.
The composed block update asks the next question:
if the model predicts target tokens, can the readout loss also update the
feed-forward sublayer through the final residual and normalization path?
Yes. The update follows the actual forward cache used by the block, computes the softmax cross-entropy gradient at the readout, applies the standard layer normalization backward pass for the final normalization boundary, splits the residual path, passes through the feed-forward sublayer and attention normalization boundary, then updates the readout, feed-forward parameters, attention output projection, and both layer-normalization scale/shift vectors together. It also backpropagates through value mixing, attention softmax, and scaled query-key scores to update the query, key, and value projections.
Worked Example: Three Updates, One State Shape
The attention example prints one structured state transition first:
training state step: 0 -> 1
The smaller training-state example isolates the same contract:
initial state: step=0, learning_rate=0.100, model_dimension=2, vocab_size=3
readout update: step 0 -> 1
feed-forward update: step 1 -> 2
composed block update: step 2 -> 3
That line is small, but it carries the whole contract:
TransformerTrainingState
owns TinyTransformerParameters
owns LearningRate
owns StepCount
An update is not allowed to return loose matrices. It must return another
TransformerTrainingState, because the next update needs the same state shape.
The readout-only step asks the smallest supervised question:
fixed hidden sequence
-> vocabulary logits
-> sequence loss
-> updated readout parameters
loss: 0.499085 -> 0.456495
The state shape is unchanged:
TransformerReadoutTrainStep
TransformerTrainingState -> TransformerTrainingState
The local feed-forward step asks a different question:
fixed hidden-sequence input
-> feed-forward output
-> squared-error hidden target
-> updated feed-forward parameters
loss: 0.250000 -> 0.160633
The same outer shape holds:
TransformerFeedForwardTrainStep
TransformerTrainingState -> TransformerTrainingState
The composed block step asks the broader question:
hidden sequence and mask
-> attention block
-> readout logits
-> sequence loss
-> updated block and readout parameters
loss: 0.456495 -> 0.409737
Again, the outside of the system is stable:
TransformerBlockTrainStep
TransformerTrainingState -> TransformerTrainingState
The internal gradient path grows from readout-only, to local feed-forward, to a composed block update. The public training shape does not grow:
state_0 -> state_1 -> state_2 -> state_3
That is the engineering version of the earlier endomorphism idea. A training step may touch different fields, but it should return the same kind of state so the loop can keep running.
The invalid shortcut would be:
readout update returns readout weights
feed-forward update returns feed-forward weights
block update returns a bag of changed matrices
That makes the next training step guess how to rebuild the model. The roadmap uses one structured state object instead, so every update must preserve the state boundary.
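The state contract can be sketched with a trait whose only method maps a whole state to a whole state. All names here (Parameters, TrainingState, TrainStep, ReadoutStep, FeedForwardStep, descend) are hypothetical simplifications; the gradients are precomputed constants so the sketch isolates the shape of the update rather than the model math.

```rust
/// Hypothetical simplified parameter object with two named roles.
#[derive(Clone, Debug, PartialEq)]
struct Parameters {
    readout: Vec<f64>,
    feed_forward: Vec<f64>,
}

#[derive(Clone, Debug, PartialEq)]
struct TrainingState {
    parameters: Parameters,
    learning_rate: f64,
    step: u64,
}

impl TrainingState {
    /// Records new parameters and advances the step count in one place.
    fn record_updated_parameters(self, parameters: Parameters) -> Self {
        Self { parameters, step: self.step + 1, ..self }
    }
}

/// Every update has the endomorphism shape TrainingState -> TrainingState.
trait TrainStep {
    fn apply(&self, state: TrainingState) -> TrainingState;
}

/// Toy updates that each touch one field, given a precomputed gradient.
struct ReadoutStep {
    gradient: Vec<f64>,
}
struct FeedForwardStep {
    gradient: Vec<f64>,
}

fn descend(values: &mut [f64], gradient: &[f64], learning_rate: f64) {
    for (v, g) in values.iter_mut().zip(gradient) {
        *v -= learning_rate * g;
    }
}

impl TrainStep for ReadoutStep {
    fn apply(&self, state: TrainingState) -> TrainingState {
        let mut parameters = state.parameters.clone();
        descend(&mut parameters.readout, &self.gradient, state.learning_rate);
        state.record_updated_parameters(parameters)
    }
}

impl TrainStep for FeedForwardStep {
    fn apply(&self, state: TrainingState) -> TrainingState {
        let mut parameters = state.parameters.clone();
        descend(&mut parameters.feed_forward, &self.gradient, state.learning_rate);
        state.record_updated_parameters(parameters)
    }
}
```

Each step touches a different field, yet both return a TrainingState, so state_0 -> state_1 -> state_2 chains without any caller rebuilding the model.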
Gradient Evidence Ledger
The block training step has finite-difference tests. They are important, but they are not magic certificates. Read them as local evidence checks.
The test shape is:
one selected parameter
-> perturb it by +epsilon and -epsilon
-> measure two nearby losses
-> compute a central finite difference
one training step
-> compare before and after parameter values
-> infer the gradient used by the update
Those two paths should agree for the selected parameter:
central finite difference of loss ~= inferred update gradient
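The comparison can be sketched on a one-parameter problem. This assumes a plain gradient-descent update w' = w - learning_rate * g, so the gradient the update used can be inferred as (w - w') / learning_rate. The helper names central_difference, inferred_gradient, and relative_error are hypothetical, and the toy loss (w * x - y)^2 stands in for the real sequence loss.

```rust
/// Central finite difference: (L(w + eps) - L(w - eps)) / (2 * eps).
fn central_difference(loss: impl Fn(f64) -> f64, w: f64, epsilon: f64) -> f64 {
    (loss(w + epsilon) - loss(w - epsilon)) / (2.0 * epsilon)
}

/// Gradient implied by one plain gradient-descent step (an assumption).
fn inferred_gradient(before: f64, after: f64, learning_rate: f64) -> f64 {
    (before - after) / learning_rate
}

/// Relative error, guarded against division by zero near a zero gradient.
fn relative_error(a: f64, b: f64) -> f64 {
    (a - b).abs() / a.abs().max(b.abs()).max(1e-12)
}
```

A correct update agrees with the numerical gradient to many digits; an update with a flipped sign disagrees by a relative error of about 2, which is the kind of mismatch the decision rule below says to trace back to a gradient path rather than to a tolerance.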
CS231n uses gradient checking this way: compare a numerical gradient with an
analytic gradient, preferably with a centered finite-difference formula and
careful error interpretation. PyTorch’s gradcheck documentation gives the
framework version of the same idea: finite differences are compared with
analytical gradients, and the result depends on tolerance, precision,
differentiability, and memory-layout assumptions.
The Rust roadmap keeps the claim smaller:
| Test family | Parameter path checked | What it can catch | What it cannot prove |
|---|---|---|---|
| transformer_block_train_step_matches_finite_difference_for_readout_weight | sequence readout weight | wrong sign, missing target-class gradient, wrong averaging scale | correctness of attention gradients |
| transformer_block_train_step_matches_finite_difference_for_feed_forward_weight | feed-forward weight | dropped ReLU or hidden-layer path | correctness of every feed-forward configuration |
| transformer_block_train_step_matches_finite_difference_for_layer_norm_parameter | normalization scale or shift | wrong layer-normalization backward path | correctness of all normalization behavior |
| transformer_block_train_step_matches_finite_difference_for_attention_projection | query, key, value, or output projection weight | dropped attention projection path | correctness of every attention variant |
| bias finite-difference tests | readout, feed-forward, output projection, or attention projection bias | missing bias gradient | correctness of all trainable fields |
The category-theory reading is also modest. These tests compare two local morphisms around one selected coordinate:
loss measurement around current state
parameter update inside TransformerTrainingState -> TransformerTrainingState
They support the implementation of the current state endomorphism. They do not prove that every possible dataset, learning rate, optimizer, mask, sequence length, or future Transformer block is correct.
Use this decision rule when reading a gradient-check result:
match -> local evidence for this parameter path
mismatch -> inspect sign, scaling, dropped path, nonsmooth point, or tolerance
Do not respond to a mismatch by only loosening the tolerance. First ask which typed boundary or gradient path failed.
Category Theory Concept
The forward path is now a typed morphism:
HiddenSequence x AttentionMask -> SequenceLogits
The training updates all have the same endomorphism shape:
TransformerTrainingState -> TransformerTrainingState
The block step is more global than the readout-only and local feed-forward steps because the loss starts at vocabulary logits and reaches an internal sublayer, the attention output projection, query/key/value projections, and both layer-normalization parameter sets. The update still preserves the same outer endomorphism shape even as the internal gradient path becomes richer.
Design contract:
The parameter object separates substructures:
position information
attention block
language-model readout
That separation is pedagogical and architectural. A reader can point at one field and say which mathematical role it plays. A future optimizer can update the same object without erasing the roles.
Core Mental Model
The current course teaches the typed skeleton:
TokenId -> Vector -> Logits -> Distribution
Distribution x TokenId -> Loss
Parameters -> Parameters
A Transformer extension grows the middle:
TokenSequence
-> HiddenSequence
-> HiddenSequence with position
-> QuerySequence x KeySequence x ValueSequence
-> AttentionOutput
-> HiddenSequence
-> HiddenSequence after one or more SingleHeadTransformerBlock or MultiHeadTransformerBlock calls
-> SequenceLogits
-> sequence-level probabilities
The practical rule stays the same:
Make every intermediate object explicit, then compose only arrows whose types actually match.
Where This Leaves Us
The roadmap keeps the book honest. The current implementation is a tiny next-token system, not a production Transformer. Its value is that it gives the future system a typed foundation: tokens become vectors, vectors become logits, logits become probabilities, probabilities become loss, and training updates parameters through a repeatable endomorphism.
A future Transformer should extend that foundation with stronger optimizer checks, more realistic datasets, and clearer diagrams. Each new concept should enter the codebase the same way the current concepts did: as a named type, a validated boundary, a typed morphism, a compiled example, and a law or regression test where the concept has a law worth checking.
Roadmap Reference Path
Use References as a staged path. This roadmap comes first because it says what the current code has and does not have. The original Transformer paper then gives the architectural target. Dive into Deep Learning gives the practical sequence from attention scoring to full Transformer blocks. Implementation and visual tutorials help translate paper notation into code structure and diagrams.
After that, return to this repository and add one typed boundary at a time. A future contribution should not start by copying a full architecture into one large module. It should start by making one Transformer concept explicit enough to construct, compose, test, and explain.
The central question for every future contribution is:
What invalid Transformer state should this type make harder to express?
Terminal Output Checkpoint Map
The companion example prints many lines because it is acting as a shape lab. Before reading the typed transformation list, group the terminal output into checkpoints.
| Printed checkpoint | What changed | What stayed true | Boundary to protect |
|---|---|---|---|
| attention shape: 2 query positions x 3 key positions | query-key scoring produced one row per query and one column per key | score rows are still tied to query positions | QuerySequence x KeySequence -> AttentionScores |
| query 0 attends with [0.5, 0.0, 0.5] | masked scores became row-wise weights | the masked key position has zero contribution | AttentionScores x AttentionMask -> AttentionScores -> AttentionWeights |
| query 0 output vector [2.0, 20.0] | weights mixed value rows into one output row | one output row still belongs to one query position | AttentionWeights x ValueSequence -> AttentionOutput |
| multi-head shape: 2 heads x 2 features -> model dimension 4 | separate head outputs were concatenated | sequence length stayed two positions | AttentionHeadOutputs -> MultiHeadOutput |
| projected attention shape: 2 positions x model dimension 2 | concatenated width returned to model width | each row can now rejoin the residual stream | MultiHeadOutput -> ProjectedAttentionOutput |
| residual shape: 2 positions x model dimension 2 | projected attention was added to the input hidden rows | public object is still HiddenSequence | HiddenSequence x ProjectedAttentionOutput -> HiddenSequence |
| normalized shape and feed-forward shape | values changed inside each row | sequence length and model dimension stayed stable | HiddenSequence -> HiddenSequence |
| structured transformer logits shape | hidden rows became vocabulary scores per position | logits are not probabilities yet | HiddenSequence x AttentionMask -> SequenceLogits |
| training state step: 0 -> 1 | parameters and step count advanced | the training object stayed whole | TransformerTrainingState -> TransformerTrainingState |
| readout loss after one update | readout parameters learned from token targets | the update still returns full training state | readout endomorphism |
| feed-forward loss after one local update | the feed-forward sublayer learned a hidden target | the update still returns full training state | local feed-forward endomorphism |
| block loss after one composed update | readout, feed-forward, attention projections, and normalization parameters moved together | the outer state shape stayed stable | composed block endomorphism |
Use this map to avoid a common Transformer-reading mistake: treating every printed vector as “attention.” The output actually moves through three different ideas:
where to read
-> what information to read
-> how to return to the hidden-state stream
Then the training lines ask a separate question:
which parameters moved, and did the update preserve the state object?
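The split between "where to read" and "what information to read" can be made concrete with a small std-only Rust sketch of one masked attention pass. The query, key, value, and mask numbers are hypothetical. They were chosen so that query 0 reproduces the printed weights [0.5, 0.0, 0.5] and output [2.0, 20.0], and they are not the companion example's actual inputs. The sqrt(d_k) scaling follows the original paper; it is an assumption about this crate. `softmax_then_hide` is the rejected shortcut of normalizing first and hiding afterwards, included only so its effect can be measured.

```rust
/// Row-major matrix: one inner Vec per sequence position.
type Matrix = Vec<Vec<f64>>;

fn dot(a: &[f64], b: &[f64]) -> f64 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// QuerySequence x KeySequence -> AttentionScores (scaled dot product).
fn scores(queries: &[Vec<f64>], keys: &[Vec<f64>]) -> Matrix {
    let scale = (queries[0].len() as f64).sqrt();
    queries
        .iter()
        .map(|q| keys.iter().map(|k| dot(q, k) / scale).collect::<Vec<f64>>())
        .collect()
}

/// AttentionScores x AttentionMask -> AttentionScores: forbidden keys become -inf
/// before normalization, so they can never receive probability mass.
fn apply_mask(raw: &[Vec<f64>], allowed: &[Vec<bool>]) -> Matrix {
    raw.iter()
        .zip(allowed)
        .map(|(row, ok)| {
            row.iter()
                .zip(ok)
                .map(|(&s, &keep)| if keep { s } else { f64::NEG_INFINITY })
                .collect::<Vec<f64>>()
        })
        .collect()
}

/// AttentionScores -> AttentionWeights: row-wise softmax.
/// Assumes every row keeps at least one key (a fully masked row gives NaN).
fn softmax_rows(raw: &[Vec<f64>]) -> Matrix {
    raw.iter()
        .map(|row| {
            let max = row.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
            let exps: Vec<f64> = row.iter().map(|&s| (s - max).exp()).collect();
            let total: f64 = exps.iter().sum();
            exps.iter().map(|e| e / total).collect::<Vec<f64>>()
        })
        .collect()
}

/// AttentionWeights x ValueSequence -> AttentionOutput.
fn mix_values(weights: &[Vec<f64>], values: &[Vec<f64>]) -> Matrix {
    weights
        .iter()
        .map(|w| {
            let mut out = vec![0.0; values[0].len()];
            for (weight, value_row) in w.iter().zip(values) {
                for (o, v) in out.iter_mut().zip(value_row) {
                    *o += weight * v;
                }
            }
            out
        })
        .collect()
}

/// Rejected shortcut: normalize first, then zero forbidden keys.
fn softmax_then_hide(raw: &[Vec<f64>], allowed: &[Vec<bool>]) -> Matrix {
    softmax_rows(raw)
        .iter()
        .zip(allowed)
        .map(|(row, ok)| {
            row.iter()
                .zip(ok)
                .map(|(&w, &keep)| if keep { w } else { 0.0 })
                .collect::<Vec<f64>>()
        })
        .collect()
}

/// Hypothetical inputs: returns (weights, output, shortcut weights).
fn demo() -> (Matrix, Matrix, Matrix) {
    let queries = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
    let keys = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0]];
    // Value row 1 is large so any leak through the masked key would show up.
    let values = vec![vec![1.0, 10.0], vec![100.0, 100.0], vec![3.0, 30.0]];
    // Query 0 may not read key 1; query 1 may read every key.
    let allowed = vec![vec![true, false, true], vec![true, true, true]];
    let raw = scores(&queries, &keys);
    let weights = softmax_rows(&apply_mask(&raw, &allowed));
    let output = mix_values(&weights, &values);
    let shortcut = softmax_then_hide(&raw, &allowed);
    (weights, output, shortcut)
}
```

In the shortcut's first row, the surviving weights sum to less than one. The missing amount is the probability mass that masking before softmax would have redistributed to the allowed keys.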
Example Output Transfer Checklist
After running the companion example, read the printed transformation list as a boundary report. For each line, ask four questions:
What Rust object is being produced?
What ML role does it play?
What shape must remain true?
What shortcut would break the next composition?
| Example output line | Boundary to own | Shortcut to reject |
|---|---|---|
| HiddenSequence -> QuerySequence | Hidden rows become question-like vectors for scoring. | Reusing raw hidden rows as queries without a named projection. |
| HiddenSequence -> KeySequence | The same hidden rows become comparison vectors. | Treating keys and queries as the same role because they share a source. |
| HiddenSequence -> ValueSequence | The same hidden rows become information vectors to be mixed. | Mixing keys or queries as if they were values. |
| QuerySequence x KeySequence -> AttentionScores | Scores are unnormalized similarity numbers. | Reading scores as probabilities or final attention output. |
| AttentionScores x AttentionMask -> AttentionScores | The mask removes illegal positions before normalization. | Applying softmax first, then hiding positions after probability mass has already moved. |
| AttentionScores -> AttentionWeights | Row-wise softmax turns scores into weights. | Combining values before a normalized weight object exists. |
| AttentionWeights x ValueSequence -> AttentionOutput | Values are mixed only after weights exist. | Asking scores to carry both similarity and information. |
| AttentionHeadOutputs -> MultiHeadOutput | Head outputs concatenate, so width is head_count * head_dimension. | Pretending the concatenated width is already the model dimension. |
| MultiHeadOutput -> ProjectedAttentionOutput | The output projection returns concatenated heads to model width. | Adding unprojected multi-head output directly to the residual stream. |
| HiddenSequence x ProjectedAttentionOutput -> HiddenSequence | Residual addition preserves sequence length and model dimension. | Adding two objects whose row widths do not match. |
| LayerNormalization : HiddenSequence -> HiddenSequence | Values change while the public hidden-sequence shape stays stable. | Treating normalization as a projection to a new domain object. |
| PositionWiseFeedForward : HiddenSequence -> HiddenSequence | Each row may expand internally, but the sublayer returns model width. | Letting the hidden expansion leak past the sublayer boundary. |
| TransformerTrainingState -> TransformerTrainingState | Training updates parameters and advances the step count, while the learning rate travels with the returned state. | Returning only changed weights and forcing the next step to reconstruct state. |
This checklist is the transfer bridge from paper notation and framework documentation to this repository’s Rust style. The original Transformer uses query, key, value, masking, softmax, value mixing, multi-head concatenation, output projection, residual paths, normalization, and feed-forward sublayers. Diagrammatic attention research also treats attention as something that can be decomposed into recurring components before variants are compared. Framework APIs compress much of that into one function call. This chapter uncompresses the path so the reader can point at each intermediate Rust type and say what invalid connection it prevents.
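One way to enforce the role rows of this checklist is to give each role its own newtype. The sketch below is a hypothetical illustration and not the crate's API. `RoleProjections`, `project_rows`, and `mix` are invented names, and the weights are toy values. Because `mix` accepts only a `ValueSequence`, the "mix keys as if they were values" shortcut becomes a compile error instead of a silent numeric swap.

```rust
/// Hidden rows: one Vec<f64> per position.
#[derive(Clone, Debug, PartialEq)]
struct HiddenSequence(Vec<Vec<f64>>);

// Separate role types, even though all three start from the same hidden rows.
#[derive(Clone, Debug, PartialEq)]
struct QuerySequence(Vec<Vec<f64>>);
#[derive(Clone, Debug, PartialEq)]
struct KeySequence(Vec<Vec<f64>>);
#[derive(Clone, Debug, PartialEq)]
struct ValueSequence(Vec<Vec<f64>>);
#[derive(Clone, Debug, PartialEq)]
struct AttentionWeights(Vec<Vec<f64>>);
#[derive(Clone, Debug, PartialEq)]
struct AttentionOutput(Vec<Vec<f64>>);

/// Right-multiplies each row by `weight` (in_dim x out_dim).
fn project_rows(rows: &[Vec<f64>], weight: &[Vec<f64>]) -> Vec<Vec<f64>> {
    rows.iter()
        .map(|row| {
            (0..weight[0].len())
                .map(|j| row.iter().zip(weight).map(|(x, w_row)| x * w_row[j]).sum::<f64>())
                .collect::<Vec<f64>>()
        })
        .collect()
}

/// Three named projections: the only way this sketch produces each role object.
struct RoleProjections {
    w_query: Vec<Vec<f64>>,
    w_key: Vec<Vec<f64>>,
    w_value: Vec<Vec<f64>>,
}

impl RoleProjections {
    fn queries(&self, h: &HiddenSequence) -> QuerySequence {
        QuerySequence(project_rows(&h.0, &self.w_query))
    }
    fn keys(&self, h: &HiddenSequence) -> KeySequence {
        KeySequence(project_rows(&h.0, &self.w_key))
    }
    fn values(&self, h: &HiddenSequence) -> ValueSequence {
        ValueSequence(project_rows(&h.0, &self.w_value))
    }
}

/// AttentionWeights x ValueSequence -> AttentionOutput.
/// Passing a KeySequence here does not compile.
fn mix(weights: &AttentionWeights, values: &ValueSequence) -> AttentionOutput {
    // (positions x keys) times (keys x features).
    AttentionOutput(project_rows(&weights.0, &values.0))
}

fn demo() -> (QuerySequence, KeySequence, ValueSequence, AttentionOutput) {
    let hidden = HiddenSequence(vec![vec![1.0, 2.0]]);
    let roles = RoleProjections {
        w_query: vec![vec![1.0, 0.0], vec![0.0, 1.0]], // identity
        w_key: vec![vec![0.0, 1.0], vec![1.0, 0.0]],   // swap features
        w_value: vec![vec![10.0, 0.0], vec![0.0, 10.0]], // scale by 10
    };
    let q = roles.queries(&hidden);
    let k = roles.keys(&hidden);
    let v = roles.values(&hidden);
    // One position attending fully to itself (hypothetical weights).
    let out = mix(&AttentionWeights(vec![vec![1.0]]), &v);
    (q, k, v, out)
}
```

Queries, keys, and values share one source row but come out as three different objects. Only the value object can reach `mix`.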
Category Shape Diagnostic
The printed Transformer path uses several category-theory shapes that look similar if you only read the arrows. Before naming a boundary, ask two questions:
How many inputs does this boundary require?
Does it return the same public object, or a different object?
Those two questions prevent a common mistake: calling every shape-preserving
line an endomorphism. A true endomorphism in this book has the form
A -> A. A product-input boundary such as A x B -> A may return the same
object as its left input, but it still needs extra information.
There is a second kind of extra information: learned parameters stored inside
a layer object. When this diagnostic names
LayerNormalization : HiddenSequence -> HiddenSequence,
PositionWiseFeedForward : HiddenSequence -> HiddenSequence, or
MultiHeadTransformerBlock : HiddenSequence -> HiddenSequence, read it as a
forward call for one fixed layer or block instance. The scale, shift, weights,
and biases are already inside that Rust value. If those parameters are being
changed, the boundary has moved to training state:
TransformerTrainingState -> TransformerTrainingState
That distinction keeps the chapter honest. It allows a fixed module call to be shape-preserving without pretending parameter learning has disappeared.
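The sketch below contrasts a fixed forward call with a training update. It assumes a single layer-normalization sublayer and a squared-error loss against a stored target. Both are illustrative assumptions; the crate's actual readout, feed-forward, and loss code may differ. `forward` borrows one fixed instance. `update` consumes and returns the whole `TransformerTrainingState`, so the data and learning rate live inside the state rather than arriving as extra inputs.

```rust
#[derive(Clone, Debug, PartialEq)]
struct HiddenSequence(Vec<Vec<f64>>);

/// One fixed layer instance: scale and shift are stored inside the value.
#[derive(Clone, Debug, PartialEq)]
struct LayerNormalization {
    scale: Vec<f64>,
    shift: Vec<f64>,
    eps: f64,
}

impl LayerNormalization {
    /// Zero-mean, unit-variance version of one row (before scale and shift).
    fn normalized_row(&self, row: &[f64]) -> Vec<f64> {
        let n = row.len() as f64;
        let mean = row.iter().sum::<f64>() / n;
        let var = row.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
        row.iter().map(|x| (x - mean) / (var + self.eps).sqrt()).collect()
    }

    /// LayerNormalization : HiddenSequence -> HiddenSequence for this instance.
    fn forward(&self, hidden: &HiddenSequence) -> HiddenSequence {
        let rows = hidden
            .0
            .iter()
            .map(|row| {
                self.normalized_row(row)
                    .iter()
                    .zip(&self.scale)
                    .zip(&self.shift)
                    .map(|((x, g), b)| g * x + b)
                    .collect::<Vec<f64>>()
            })
            .collect();
        HiddenSequence(rows)
    }
}

/// Hypothetical training object: parameters, data, learning rate, step.
#[derive(Clone, Debug)]
struct TransformerTrainingState {
    layer_norm: LayerNormalization,
    inputs: HiddenSequence,
    targets: HiddenSequence,
    learning_rate: f64,
    step: u64,
}

/// Plain gradient-descent step for one parameter vector.
fn descend(params: &[f64], grads: &[f64], lr: f64) -> Vec<f64> {
    params.iter().zip(grads).map(|(p, g)| p - lr * g).collect()
}

impl TransformerTrainingState {
    /// Assumed loss: 0.5 * squared error between layer output and target.
    fn loss(&self) -> f64 {
        let out = self.layer_norm.forward(&self.inputs);
        out.0
            .iter()
            .zip(&self.targets.0)
            .flat_map(|(y, t)| y.iter().zip(t).map(|(a, b)| 0.5 * (a - b).powi(2)))
            .sum()
    }

    /// TransformerTrainingState -> TransformerTrainingState: one update of scale and shift.
    fn update(self) -> Self {
        let width = self.layer_norm.scale.len();
        let mut grad_scale = vec![0.0; width];
        let mut grad_shift = vec![0.0; width];
        let out = self.layer_norm.forward(&self.inputs);
        for ((x_row, y_row), t_row) in self.inputs.0.iter().zip(&out.0).zip(&self.targets.0) {
            let x_hat = self.layer_norm.normalized_row(x_row);
            for j in 0..width {
                // For the assumed loss, d(loss)/d(output) = output - target.
                let err = y_row[j] - t_row[j];
                grad_scale[j] += err * x_hat[j];
                grad_shift[j] += err;
            }
        }
        let lr = self.learning_rate;
        let layer_norm = LayerNormalization {
            scale: descend(&self.layer_norm.scale, &grad_scale, lr),
            shift: descend(&self.layer_norm.shift, &grad_shift, lr),
            eps: self.layer_norm.eps,
        };
        // Data and learning rate move into the returned state unchanged.
        TransformerTrainingState { layer_norm, step: self.step + 1, ..self }
    }
}

/// Hypothetical one-row example with a target the layer can move toward.
fn demo_state() -> TransformerTrainingState {
    TransformerTrainingState {
        layer_norm: LayerNormalization { scale: vec![1.0, 1.0], shift: vec![0.0, 0.0], eps: 1e-5 },
        inputs: HiddenSequence(vec![vec![1.0, 3.0]]),
        targets: HiddenSequence(vec![vec![0.0, 2.0]]),
        learning_rate: 0.1,
        step: 0,
    }
}
```

The same `LayerNormalization` value gives the same forward result every time. After `update`, a different value holds the new scale and shift, and the state that carries it can be updated again.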
The same two questions also prevent a second mistake: importing an advanced categorical name too early. Research on self-attention as a parametric endofunctor is useful for the linear portions of self-attention, especially query, key, value, positional, and layered structure. It does not make the whole pedagogical block a single endofunctor in this book. Softmax, masking, residual addition, normalization, feed-forward refinement, and training state each still need their own typed boundary.
Research on the anatomy of attention supports the opposite teaching move: decompose attention first, then compare variants. In this book, the decomposition is not a full diagrammatic formalism. It is a Rust teaching contract: every component must have a named type, a boundary shape, and a failure it prevents.
Categorical deep-learning research also separates architecture constraints from implementations. That distinction is useful here because the Rust code is an implementation witness for one small boundary at a time, not a proof that a future full Transformer satisfies every intended constraint. A good chapter claim should say which side it is on:
architecture constraint:
what should remain true?
implementation boundary:
which Rust type, constructor, example, or test currently enforces it?
Do not call the whole block an endofunctor when the explanation only checked one internal linear path. In this chapter, use the smaller safe name first: ordinary morphism, product-input morphism, shape-preserving endomorphism, state endomorphism, or illegal attempted composition.
The decision flow is:
```mermaid
flowchart TD
B["Boundary shape"] --> T{"Does it type-check?"}
T -->|"no"| I["Illegal attempted composition: name the missing conversion"]
T -->|"yes"| C{"How many inputs are visible?"}
C -->|"one input"| O{"Same whole source and target object?"}
O -->|"yes"| E["Endomorphism: A -> A"]
O -->|"no"| M["Ordinary morphism: A -> B"]
C -->|"product input"| F{"Was one context fixed first?"}
F -->|"yes"| U["Induced unary view: name the fixed context"]
F -->|"no"| P["Product-input morphism: keep A x B visible"]
```
The same naming rule as a compact rendered math view:
\[
\begin{array}{rcl}
A \to B &:& \text{ordinary morphism} \\
A \to A &:& \text{endomorphism} \\
A \times B \to C &:& \text{product-input morphism} \\
A \times B \to A &:& \text{not automatically an endomorphism} \\
A \xrightarrow{f_b} A &:& \text{fixed-context induced endomorphism, after } b \text{ is fixed}
\end{array}
\]
How to read this diagram:
- count the visible inputs before naming the category shape,
- compare the whole source object with the whole target object,
- fix context explicitly before using a unary view,
- reject same-output shortcuts that ignore product inputs.
Read the diagram from top to bottom before naming an attention boundary. It is only a local naming aid, but it prevents three common shortcuts:
| Shortcut | Safer move |
|---|---|
| output shape matches, so the boundary is an endomorphism | compare the whole source object with the target object |
| the product can be read as one source, so the boundary is an endomorphism | check whether the target is the same product object |
| context was fixed in prose, so the original boundary had one input | name the open boundary first, then name the fixed context |
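The decision flow can be encoded as a small classification function. This is a hypothetical teaching helper, not part of the crate. Object names are plain strings, and "does it type-check?" is recorded as a named missing conversion rather than checked by a compiler. The function reaches each leaf of the flowchart in the same order: type check, input count, whole-object comparison, then fixed context.

```rust
/// Safe category names from the decision flow.
#[derive(Debug, PartialEq)]
enum BoundaryShape {
    IllegalAttempt { missing: &'static str },
    OrdinaryMorphism,
    Endomorphism,
    ProductInputMorphism,
    InducedUnaryView { fixed: &'static str },
}

/// Hypothetical description of one boundary, using object names as strings.
struct Boundary {
    inputs: Vec<&'static str>,
    output: &'static str,
    /// Some(conversion) when the boundary does not type-check without it.
    missing_conversion: Option<&'static str>,
    /// Some(name) when one context input was fixed before the call.
    fixed_context: Option<&'static str>,
}

impl Boundary {
    /// An open boundary: type-checks, nothing fixed.
    fn open(inputs: &[&'static str], output: &'static str) -> Self {
        Boundary { inputs: inputs.to_vec(), output, missing_conversion: None, fixed_context: None }
    }
}

fn classify(b: &Boundary) -> BoundaryShape {
    // 1. Does it type-check? If not, name the missing conversion.
    if let Some(missing) = b.missing_conversion {
        return BoundaryShape::IllegalAttempt { missing };
    }
    // 2. Count the visible inputs.
    match b.inputs.len() {
        0 => panic!("a boundary needs at least one input"),
        // 3. One input: compare the whole source object with the target object.
        1 if b.inputs[0] == b.output => BoundaryShape::Endomorphism,
        1 => BoundaryShape::OrdinaryMorphism,
        // 4. Product input: only a named fixed context gives a unary view.
        _ => match b.fixed_context {
            Some(fixed) => BoundaryShape::InducedUnaryView { fixed },
            None => BoundaryShape::ProductInputMorphism,
        },
    }
}
```

Returning the left-hand object never reaches the `Endomorphism` arm, because that arm requires exactly one input.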
Two-Minute Classification Drill
Before reading the longer table, classify these boundaries yourself. Cover the right column, count the inputs, then decide whether the output returns to the same whole object.
| Boundary | Question to ask first | Safe classification |
|---|---|---|
| HiddenSequence -> QuerySequence | one input, different output object? | ordinary morphism |
| AttentionScores x AttentionMask -> AttentionScores | two inputs, returns the left object? | product-input morphism returning the score object |
| LayerNormalization : HiddenSequence -> HiddenSequence | one input, same output object? | shape-preserving endomorphism |
| HiddenSequence x ProjectedAttentionOutput -> HiddenSequence | two inputs, returns the left object? | product-input morphism returning hidden state |
| TransformerTrainingState -> TransformerTrainingState | one input, same whole training object? | state endomorphism |
The trap is the second and fourth rows. Returning the left-hand object is not
the same as being an endomorphism. A boundary that still needs a mask, value
sequence, projected sublayer output, dataset, or learning rate is not a pure
A -> A story until that context is explicitly fixed.
Source-Target Audit Card
Use this card when a row still feels ambiguous. Do not start from the output type. Name the whole source object, then name the target object.
| Boundary | Whole source object | Target object | Context status | Safe conclusion |
|---|---|---|---|---|
| HiddenSequence -> QuerySequence | HiddenSequence | QuerySequence | no extra context in the boundary | ordinary morphism |
| AttentionScores x AttentionMask -> AttentionScores | AttentionScores x AttentionMask | AttentionScores | mask is open context | product-input morphism, not an endomorphism on scores |
| MaskedMultiHeadTransformerBlock[M] : HiddenSequence -> HiddenSequence | HiddenSequence | HiddenSequence | one mask M was fixed first | induced endomorphism for that named mask |
| HiddenSequence x ProjectedAttentionOutput -> HiddenSequence | HiddenSequence x ProjectedAttentionOutput | HiddenSequence | residual input is open context | product-input morphism returning hidden state |
| TransformerTrainingState -> TransformerTrainingState | TransformerTrainingState | TransformerTrainingState | update context is inside the state object | state endomorphism |
The second and fourth rows are unary only if you choose to regard the product as one source object, but they are still not endomorphisms. Their targets are not the same product object. The fixed-mask row is different because the mask has been selected before the remaining call.
Linear Scope Diagnostic
Use this when an external source gives a categorical reading of self-attention. First ask which part of the attention path the source actually classified.
| Attention part | Boundary in this roadmap | Safe reading here |
|---|---|---|
| query projection | HiddenSequence -> QuerySequence | linear role-producing morphism |
| key projection | HiddenSequence -> KeySequence | linear role-producing morphism |
| value projection | HiddenSequence -> ValueSequence | linear role-producing morphism |
| score construction | QuerySequence x KeySequence -> AttentionScores | product-input boundary |
| mask application | AttentionScores x AttentionMask -> AttentionScores | product-input boundary, not a pure score endomorphism |
| score normalization | AttentionScores -> AttentionWeights | nonlinear normalization boundary |
| value mixing | AttentionWeights x ValueSequence -> AttentionOutput | product-input boundary |
| residual addition | HiddenSequence x ProjectedAttentionOutput -> HiddenSequence | product-input boundary returning hidden state |
| layer normalization for a fixed layer instance | HiddenSequence -> HiddenSequence | shape-preserving but nonlinear endomorphism |
| parameter-changing training update | TransformerTrainingState -> TransformerTrainingState | state endomorphism over the whole training object |
Safe rule:
If a claim was checked for linear Q/K/V maps, do not carry it through softmax,
masking, residual addition, layer normalization, feed-forward refinement, or
training state without naming the next boundary.
This keeps the chapter usable for two readers at once. The category-theory reader sees where a stronger formal story might attach. The ML engineer sees which implemented boundary still needs its own shape, invariant, and test.
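The safe rule can be exercised numerically. The sketch below uses hypothetical helper names and toy weights to spot-check additivity, f(a + b) = f(a) + f(b), on one pair of inputs. A passing check on one pair is only evidence, not a proof of linearity. A failing check is a genuine counterexample, and row-wise softmax produces one immediately.

```rust
/// Row vector times weight matrix (in_dim x out_dim): a linear role projection.
fn project(row: &[f64], weight: &[Vec<f64>]) -> Vec<f64> {
    (0..weight[0].len())
        .map(|j| row.iter().zip(weight).map(|(x, w_row)| x * w_row[j]).sum::<f64>())
        .collect()
}

/// Row-wise softmax: the nonlinear normalization boundary.
fn softmax(row: &[f64]) -> Vec<f64> {
    let max = row.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = row.iter().map(|x| (x - max).exp()).collect();
    let total: f64 = exps.iter().sum();
    exps.iter().map(|e| e / total).collect()
}

fn add(a: &[f64], b: &[f64]) -> Vec<f64> {
    a.iter().zip(b).map(|(x, y)| x + y).collect()
}

/// Spot check of f(a + b) == f(a) + f(b) on one pair of inputs.
/// Passing is evidence only; failing is a counterexample to linearity.
fn additive_on(f: impl Fn(&[f64]) -> Vec<f64>, a: &[f64], b: &[f64]) -> bool {
    let sum = add(a, b);
    let lhs = f(sum.as_slice());
    let rhs = add(&f(a), &f(b));
    lhs.iter().zip(&rhs).all(|(x, y)| (x - y).abs() < 1e-12)
}
```

With weights [[1, 2], [3, 4]] and inputs [1, 0] and [0, 1], the projection passes. Softmax fails: softmax([1, 1]) is [0.5, 0.5], but softmax([1, 0]) + softmax([0, 1]) is [1, 1]. This is why a claim checked on the Q/K/V maps stops at the score normalization boundary.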
Worked Classification: Same Output, Different Shape
The most tempting mistake is to look only at the output type. Three boundaries
below all end with HiddenSequence, but they do not have the same category
shape.
| Boundary | Count the inputs | Classification | Why |
|---|---|---|---|
| LayerNormalization : HiddenSequence -> HiddenSequence | one input | endomorphism | the whole input object and output object are both HiddenSequence |
| HiddenSequence x ProjectedAttentionOutput -> HiddenSequence | two inputs | product-input morphism returning HiddenSequence | residual addition needs both the old stream and the projected sublayer output |
| HiddenSequence x MultiHeadOutput -> HiddenSequence | two inputs, wrong second object | illegal attempted boundary | residual addition needs projected model-width output, not raw concatenated heads |
The first boundary can safely be named an endomorphism in this book. The second
cannot, even though it returns HiddenSequence, because the full input object
is not HiddenSequence; it is a product of two objects. The third should not
receive a category-theory name yet. It is missing the output projection that
makes the residual path well typed.
This gives a short decision tree:
```text
Does the boundary type-check?
  no  -> name the missing conversion first
  yes -> count the inputs
         one input  -> compare input object and output object
         two inputs -> keep the product-input boundary visible
```
Use this naming rule before reading the table:
1. Count the inputs.
2. If there is one input, compare the input object and output object.
3. If there is a product input, keep the product in the name.
4. If a required projection or conversion is missing, call it illegal before
giving it a category-theory name.
That gives five safe cases:
| Shape | Safe name | Example |
|---|---|---|
| A -> B | ordinary morphism | AttentionScores -> AttentionWeights |
| A -> A | endomorphism | LayerNormalization : HiddenSequence -> HiddenSequence |
| A x B -> C | product-input morphism | AttentionWeights x ValueSequence -> AttentionOutput |
| A x B -> A | product-input morphism returning A | HiddenSequence x ProjectedAttentionOutput -> HiddenSequence |
| A x B -> A with the wrong B | illegal attempted boundary | HiddenSequence x MultiHeadOutput -> HiddenSequence |
The fourth row is the trap. Returning the left object is not enough to make a boundary an endomorphism. The whole input must be one object, and the output must be that same object.
There is a second safe reading that is useful but different. You may choose to treat the product itself as one source object:
(A x B) -> A
That makes the arrow unary from the product object, but it still is not an
endomorphism. The source object is A x B; the target object is A. An
endomorphism on the product would have shape:
(A x B) -> (A x B)
This is why the roadmap keeps the phrase “product-input morphism returning
A” instead of shortening it to “endomorphism on A.”
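The illegal row HiddenSequence x MultiHeadOutput -> HiddenSequence can be made unrepresentable rather than merely discouraged. The following sketch uses hypothetical names and toy numbers that mirror the printed path of 2 heads x 2 features, then width 4, then model width 2. Its `residual` function accepts only `ProjectedAttentionOutput`, so passing raw `MultiHeadOutput` does not compile. The remaining width checks are runtime `Result`s because plain `Vec` rows do not carry their dimensions in the type.

```rust
#[derive(Clone, Debug, PartialEq)]
struct HiddenSequence(Vec<Vec<f64>>);
/// One head's rows: positions x head_dimension.
#[derive(Clone, Debug, PartialEq)]
struct AttentionHeadOutput(Vec<Vec<f64>>);
/// Concatenated heads: positions x (head_count * head_dimension).
#[derive(Clone, Debug, PartialEq)]
struct MultiHeadOutput(Vec<Vec<f64>>);
/// Back at model width: positions x model_dimension.
#[derive(Clone, Debug, PartialEq)]
struct ProjectedAttentionOutput(Vec<Vec<f64>>);

/// AttentionHeadOutputs -> MultiHeadOutput: concatenate head rows per position.
fn concat_heads(heads: &[AttentionHeadOutput]) -> MultiHeadOutput {
    let positions = heads[0].0.len();
    MultiHeadOutput(
        (0..positions)
            .map(|p| heads.iter().flat_map(|h| h.0[p].iter().copied()).collect::<Vec<f64>>())
            .collect(),
    )
}

/// Output projection weight: (head_count * head_dimension) x model_dimension.
struct OutputProjection {
    weight: Vec<Vec<f64>>,
}

impl OutputProjection {
    /// MultiHeadOutput -> ProjectedAttentionOutput, rejecting a width mismatch.
    fn apply(&self, m: &MultiHeadOutput) -> Result<ProjectedAttentionOutput, String> {
        let width = m.0[0].len();
        if width != self.weight.len() {
            return Err(format!("projection expects width {}, got {}", self.weight.len(), width));
        }
        let model_dim = self.weight[0].len();
        let rows = m.0.iter().map(|row| {
            (0..model_dim)
                .map(|j| row.iter().zip(&self.weight).map(|(x, w)| x * w[j]).sum::<f64>())
                .collect::<Vec<f64>>()
        });
        Ok(ProjectedAttentionOutput(rows.collect()))
    }
}

/// HiddenSequence x ProjectedAttentionOutput -> HiddenSequence.
/// Deliberately, no overload accepts MultiHeadOutput.
fn residual(h: &HiddenSequence, p: &ProjectedAttentionOutput) -> Result<HiddenSequence, String> {
    if h.0.len() != p.0.len() || h.0[0].len() != p.0[0].len() {
        return Err("residual inputs must share positions and model width".to_string());
    }
    Ok(HiddenSequence(
        h.0.iter()
            .zip(&p.0)
            .map(|(a, b)| a.iter().zip(b).map(|(x, y)| x + y).collect::<Vec<f64>>())
            .collect(),
    ))
}

/// Toy numbers: 2 heads x 2 features -> width 4 -> model width 2.
fn demo() -> (MultiHeadOutput, HiddenSequence) {
    let heads = [
        AttentionHeadOutput(vec![vec![1.0, 0.0], vec![0.0, 1.0]]),
        AttentionHeadOutput(vec![vec![2.0, 0.0], vec![0.0, 2.0]]),
    ];
    let concatenated = concat_heads(&heads);
    // Illustrative weights: sum the matching feature of each head.
    let projection = OutputProjection {
        weight: vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0], vec![0.0, 1.0]],
    };
    let projected = projection.apply(&concatenated).expect("width 4 matches");
    let hidden = HiddenSequence(vec![vec![10.0, 10.0], vec![20.0, 20.0]]);
    let updated = residual(&hidden, &projected).expect("shapes match");
    (concatenated, updated)
}
```

The only route from `MultiHeadOutput` back to the residual stream runs through `OutputProjection::apply`, which is the missing conversion the illegal row names.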
Terminal Output Audit: Shape Line Is Not Boundary Shape
The runnable example prints several lines with the same public dimensions:
cargo run --example 06_attention_scores
Those lines are useful evidence, but they are not category names by themselves. A printed shape tells you something about the target object. The typed transformation line tells you the whole source object and the target object.
| Printed output line | What the line proves | What it does not prove | Boundary to name |
|---|---|---|---|
| projected attention shape: 2 positions x model dimension 2 | raw head output has been projected back to model width | the residual connection has already happened | MultiHeadOutput -> ProjectedAttentionOutput |
| residual shape: 2 positions x model dimension 2 | the result has returned to hidden-sequence shape | residual addition was unary | HiddenSequence x ProjectedAttentionOutput -> HiddenSequence |
| masked multi-head block shape: 2 positions x model dimension 2 | the block output can feed the next hidden-sequence layer | the open masked block is a pure endomorphism | MaskedMultiHeadTransformerBlock : HiddenSequence x AttentionMask -> HiddenSequence |
| training state step: 0 -> 1 | the update returns a state that can be updated again | training is a loose Loss -> Parameters shortcut | TransformerTrainingState -> TransformerTrainingState |
Use this three-step audit whenever the terminal output seems to settle the category name too quickly:
printed shape line -> evidence about the target object
typed transformation line -> evidence about the source and target objects
category name -> only after both source and target are known
That is why residual shape and masked multi-head block shape can both show
model-width hidden rows while still having different safe category readings.
Same printed dimensions are not the same boundary.
Stackability With Context
Stacking means the output of one boundary can feed the next boundary without inventing missing inputs. A direct endomorphism can stack by itself. A product-input boundary can stack only if the extra context is carried along or fixed explicitly.
For learned sublayers, “direct” means the layer instance is fixed for the
forward call. A different LayerNormalization or PositionWiseFeedForward
value is a different morphism. Changing those parameters is training-state
work, so the safe outer name is TransformerTrainingState -> TransformerTrainingState.
| Boundary | Can it stack directly as HiddenSequence -> HiddenSequence? | Safe reading |
|---|---|---|
| LayerNormalization : HiddenSequence -> HiddenSequence | yes | direct shape-preserving endomorphism |
| MultiHeadTransformerBlock : HiddenSequence -> HiddenSequence | yes | direct block endomorphism |
| MaskedMultiHeadTransformerBlock : HiddenSequence x AttentionMask -> HiddenSequence | no, not while the mask is open | product-input morphism that still needs mask context |
| fixed-mask view of MaskedMultiHeadTransformerBlock | yes, for that named mask context | induced endomorphism after context is fixed |
| TransformerTrainingState -> TransformerTrainingState | yes | state endomorphism over the whole training object |
This is the same discipline as the rest of the chapter. Do not erase context to make a category name fit. If a mask, dataset, learning rate, or parameter object is part of the boundary, either keep it in the type shape or say exactly where it was fixed.
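Here is a std-only Rust sketch of the two stackable cases, using hypothetical toy layers. `stack` folds fixed HiddenSequence -> HiddenSequence layers directly. `stack_with_mask` keeps the product-input shape and carries the same mask into every layer. `uniform_mix` is a toy stand-in for a masked block that adds the mean of the rows the mask allows; it is not real attention.

```rust
#[derive(Clone, Debug, PartialEq)]
struct HiddenSequence(Vec<Vec<f64>>);
/// mask.0[i][j] == true means position i may read position j.
#[derive(Clone, Debug, PartialEq)]
struct AttentionMask(Vec<Vec<bool>>);

type Layer = Box<dyn Fn(HiddenSequence) -> HiddenSequence>;
type MaskedLayer = fn(HiddenSequence, &AttentionMask) -> HiddenSequence;

/// Direct stacking: every layer is already HiddenSequence -> HiddenSequence.
fn stack(layers: &[Layer], input: HiddenSequence) -> HiddenSequence {
    layers.iter().fold(input, |h, layer| layer(h))
}

/// Carried context: the mask is threaded into every product-input layer.
fn stack_with_mask(layers: &[MaskedLayer], input: HiddenSequence, mask: &AttentionMask) -> HiddenSequence {
    layers.iter().fold(input, |h, layer| layer(h, mask))
}

/// Toy product-input layer. Assumes every row allows at least one position.
fn uniform_mix(hidden: HiddenSequence, mask: &AttentionMask) -> HiddenSequence {
    let rows = &hidden.0;
    let mixed = rows
        .iter()
        .zip(&mask.0)
        .map(|(row, allowed)| {
            let picked: Vec<&Vec<f64>> =
                rows.iter().zip(allowed).filter(|(_, ok)| **ok).map(|(r, _)| r).collect();
            let n = picked.len() as f64;
            (0..row.len())
                .map(|j| row[j] + picked.iter().map(|r| r[j]).sum::<f64>() / n)
                .collect::<Vec<f64>>()
        })
        .collect();
    HiddenSequence(mixed)
}

fn add_one(h: HiddenSequence) -> HiddenSequence {
    HiddenSequence(h.0.into_iter().map(|r| r.into_iter().map(|x| x + 1.0).collect::<Vec<f64>>()).collect())
}

fn double(h: HiddenSequence) -> HiddenSequence {
    HiddenSequence(h.0.into_iter().map(|r| r.into_iter().map(|x| x * 2.0).collect::<Vec<f64>>()).collect())
}

fn demo_direct() -> HiddenSequence {
    let layers: [Layer; 2] = [Box::new(add_one), Box::new(double)];
    stack(&layers, HiddenSequence(vec![vec![1.0, 2.0]]))
}

fn demo_carried() -> HiddenSequence {
    let layers: [MaskedLayer; 2] = [uniform_mix, uniform_mix];
    let causal = AttentionMask(vec![vec![true, false], vec![true, true]]);
    stack_with_mask(&layers, HiddenSequence(vec![vec![1.0, 1.0], vec![3.0, 3.0]]), &causal)
}
```

`uniform_mix` cannot be placed in the `Layer` array as it stands, because its type still names the mask. Either the caller threads the mask, as in `stack_with_mask`, or a mask is fixed first.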
Context Fixing Drill
The open masked block and a fixed-mask view are related, but they are not the same boundary:
open boundary:
MaskedMultiHeadTransformerBlock : HiddenSequence x AttentionMask -> HiddenSequence
fixed-context boundary:
choose one mask M
MaskedMultiHeadTransformerBlock[M] : HiddenSequence -> HiddenSequence
The fixed-context boundary is a new named view after selecting M.
It is not a claim that the original block had only one input. The source of
the context must stay visible in the prose, the exercise, or the type that
carries it.
| Case | What is fixed? | Safe category shape | Can it stack as HiddenSequence -> HiddenSequence? | Overclaim to avoid |
|---|---|---|---|---|
| open masked block | nothing | product-input morphism | no | “it returns HiddenSequence, so it is an endomorphism” |
| fixed-mask view | one named AttentionMask | induced endomorphism for that mask | yes, while the same mask context remains fixed | “the mask disappeared” |
| changing mask per call | the mask is supplied again each call | repeated product-input calls, or a larger state carrying the mask | only if the caller threads or fixes the context | “this is the same as a fixed-mask view” |
| residual addition | no input is fixed; both hidden stream and projected output are supplied | product-input morphism returning hidden state | no | “a binary operation is unary because the result is hidden state” |
The residual row is a negative contrast. It is not a context-fixing example. The hidden stream is still an input, and the projected sublayer output is still an input. Nothing has been selected in advance. The boundary therefore remains:
HiddenSequence x ProjectedAttentionOutput -> HiddenSequence
It returns HiddenSequence, but it does not become a unary
HiddenSequence -> HiddenSequence boundary unless one input has actually been
fixed. If the product object itself is named as the source, the arrow is unary
from that product object:
(HiddenSequence x ProjectedAttentionOutput) -> HiddenSequence
That is still not an endomorphism, because the target is not the same product object. An endomorphism on the named product would have to return the whole product again:
(HiddenSequence x ProjectedAttentionOutput)
-> (HiddenSequence x ProjectedAttentionOutput)
This is only a local teaching use of fixing context. It is not a proof that the whole attention block lives in a closed category, and it is not permission to hide arbitrary inputs. The practical rule stays simple:
name the open boundary first
name exactly what was fixed
then name the induced unary view
Rust already has a familiar mechanism for this idea: a closure can capture a value from the surrounding environment. The official Rust Book uses closures to teach how a callable value can remember values from its environment. In this roadmap, a fixed-mask view can be read the same way:
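```rust
// Capture one chosen mask; the remaining call only needs HiddenSequence.
let fixed_mask = mask.clone();
let fixed_mask_view = move |hidden: HiddenSequence| {
    // The open product-input boundary still runs underneath.
    masked_block.apply(Product::new(hidden, fixed_mask.clone()))
};
```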
That closure-shaped explanation is only an analogy for this local boundary. It does not change the original open type:
MaskedMultiHeadTransformerBlock : HiddenSequence x AttentionMask -> HiddenSequence
It says how one chosen AttentionMask can be captured before the remaining
call receives HiddenSequence. If a reader cannot point to the captured
fixed_mask, the text has hidden context instead of fixing it.
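The closure fragment above can also be written as a self-contained sketch. Here `Product`, `MaskedMultiHeadTransformerBlock`, and `fixed_mask_view` are hypothetical stand-ins, not the crate's real types. The block has a toy body in which each row adds the mean of the rows its mask allows. The property worth checking is that the fixed view agrees with the open call for the same mask, and that the fixed view composes with itself.

```rust
#[derive(Clone, Debug, PartialEq)]
struct HiddenSequence(Vec<Vec<f64>>);
#[derive(Clone, Debug, PartialEq)]
struct AttentionMask(Vec<Vec<bool>>);

/// Hypothetical product object: both inputs of a two-input boundary.
struct Product<A, B> {
    left: A,
    right: B,
}

impl<A, B> Product<A, B> {
    fn new(left: A, right: B) -> Self {
        Product { left, right }
    }
}

/// Hypothetical stand-in for the open boundary
/// HiddenSequence x AttentionMask -> HiddenSequence (toy body, not attention).
struct MaskedMultiHeadTransformerBlock;

impl MaskedMultiHeadTransformerBlock {
    fn apply(&self, input: Product<HiddenSequence, AttentionMask>) -> HiddenSequence {
        let Product { left: hidden, right: mask } = input;
        let rows = &hidden.0;
        let mixed = rows
            .iter()
            .zip(&mask.0)
            .map(|(row, allowed)| {
                // Assumes every row allows at least one position.
                let picked: Vec<&Vec<f64>> =
                    rows.iter().zip(allowed).filter(|(_, ok)| **ok).map(|(r, _)| r).collect();
                let n = picked.len() as f64;
                (0..row.len())
                    .map(|j| row[j] + picked.iter().map(|r| r[j]).sum::<f64>() / n)
                    .collect::<Vec<f64>>()
            })
            .collect();
        HiddenSequence(mixed)
    }
}

/// Induced unary view: the closure captures one chosen mask, so the remaining
/// call is HiddenSequence -> HiddenSequence for that mask only.
fn fixed_mask_view(
    masked_block: MaskedMultiHeadTransformerBlock,
    fixed_mask: AttentionMask,
) -> impl Fn(HiddenSequence) -> HiddenSequence {
    move |hidden: HiddenSequence| masked_block.apply(Product::new(hidden, fixed_mask.clone()))
}
```

The captured `fixed_mask` is the value a reader should be able to point to. The open `apply` signature is unchanged and still takes the product.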
| Boundary | Category shape to name | Why this is the right name | Common misread |
|---|---|---|---|
| QuerySequence x KeySequence -> AttentionScores | product-input morphism | scoring needs query rows and key rows | treating scores as a unary query transform |
| AttentionScores x AttentionMask -> AttentionScores | product-input morphism returning the score object | the mask is extra evidence used before softmax | calling it a pure endomorphism on scores |
| AttentionScores -> AttentionWeights | ordinary morphism | raw scores become normalized rows | treating weights as the same object as scores |
| AttentionWeights x ValueSequence -> AttentionOutput | product-input morphism | weights decide which value rows to read | treating keys and values as interchangeable |
| MultiHeadOutput -> ProjectedAttentionOutput | ordinary morphism | concatenated head width returns to model width | adding multi-head output directly to residual state |
| HiddenSequence x ProjectedAttentionOutput -> HiddenSequence | product-input morphism returning hidden state | residual addition needs both the old stream and the sublayer output | calling the binary residual operation a unary endomorphism |
| LayerNormalization : HiddenSequence -> HiddenSequence | shape-preserving endomorphism | values change while the public hidden object stays the same | treating normalization as a new sequence domain |
| PositionWiseFeedForward : HiddenSequence -> HiddenSequence | shape-preserving endomorphism | internal width may expand, but the public object returns unchanged | leaking the internal expansion into the next block |
| TransformerTrainingState -> TransformerTrainingState | state endomorphism | one update returns a complete object that can be updated again | returning only changed weights or only loss |
| HiddenSequence x MultiHeadOutput -> HiddenSequence | not a legal composed boundary | residual addition needs projected model-width output | skipping the output projection |
This diagnostic is the category-theory version of shape checking. The ML question is:
what information must this stage receive?
The Rust question is:
which type should own that information before the next call?
The category-theory question is:
is this a unary morphism, a product-input morphism, an endomorphism, or not
composable yet?
If a boundary needs two objects, write both. If it returns to the same public object, say whether that return is unary or product-input. Precision here is what keeps the roadmap from turning attention into a single vague arrow.
Reader Evidence Handoff
If this diagnostic becomes unclear, the most useful report is not “attention is confusing.” The useful report names the exact rule that failed.
Use this shape:
Command: cargo run --example 06_attention_scores
Page: Transformer Roadmap -> Category Shape Diagnostic
Evidence signal: one boundary row or printed output line
Last clear idea: the last boundary name that still made sense
First unclear rule: input count, fixed context, legal composition, source role,
target role, or linear-scope limit
Smallest useful fix: one sentence, table row, diagram, or exercise check
Good evidence signals are small:
AttentionScores x AttentionMask -> AttentionScores
MaskedMultiHeadTransformerBlock[M] : HiddenSequence -> HiddenSequence
HiddenSequence x MultiHeadOutput -> HiddenSequence
query 0 attends with [0.5, 0.0, 0.5]
A report like that gives the next rewrite a concrete target: which boundary, which rule, and which reader expectation failed.
Open the chapter clarity feedback form with those fields filled from your own run or reading.
Retrieval Practice
Run the attention example before answering:
cargo run --example 06_attention_scores
Recall
Recover the named objects and boundaries before explaining them.
- Which three role objects are produced from HiddenSequence before attention scores are computed?
- Which printed line is the first point where attention scores become row-wise normalized weights?
- Which boundaries in the example preserve the public shape HiddenSequence -> HiddenSequence?
- Which three training steps share the outer shape TransformerTrainingState -> TransformerTrainingState?
- Which printed line tells you that multi-head width must be projected before it can rejoin the residual stream?
Explain
Use the type boundary to explain the reason for the design.
- Why must the mask act before row-wise softmax?
- Why does multi-head attention need an output projection before residual addition?
- Why is returning a full TransformerTrainingState safer than returning only changed readout or feed-forward weights?
The next questions check the scope of the evidence. A local gradient check and a shape-preserving sublayer are useful only when their boundaries are named.
- Why is a finite-difference check useful for one selected parameter without proving that every future training loop is correct?
- Why does PositionWiseFeedForward : HiddenSequence -> HiddenSequence permit an internal hidden expansion but not an expanded public output?
Apply
Change the numbers and check whether the same typed rule still holds.
- A block has three heads and each head produces four features per position. What input width must the output projection accept?
- A feed-forward sublayer expands a model-dimension-two row to six hidden features, then returns five features. Which public boundary has been broken?
- A training step updates the feed-forward weights but drops the learning rate. Why can the next training step no longer compose safely?
- A two-token sequence has raw scores [1.0, 9.0] for the first row, but the second position is masked out. Which object should record the forbidden position before softmax?
- A learner sees the line AttentionWeights x ValueSequence -> AttentionOutput and wants to replace ValueSequence with KeySequence. Which ML role has been lost?
- A learner sees HiddenSequence x ProjectedAttentionOutput -> HiddenSequence and calls it an endomorphism because the output is HiddenSequence. Which step of the naming rule corrects that mistake?
Debug
For each invalid shortcut, name the missing boundary:
HiddenSequence x MultiHeadOutput -> HiddenSequence
AttentionScores -> AttentionOutput
readout update -> changed readout weights
softmax scores -> masked weights
feed-forward hidden expansion -> next block input
finite-difference agreement -> full optimizer proof
Good answers should point back to a concrete type or transformation in this chapter, not only to a phrase such as “shape mismatch” or “training update.”