
Transformer Roadmap

The problem this chapter solves is:

The repository name points toward Transformers, but the current code is a foundation course. This chapter explains exactly how the current objects and morphisms point toward a future attention-based model.

The current code is not a full Transformer.

It teaches the typed pieces you need first:

tokens
vectors
logits
probabilities
loss
training updates
composition

This distinction matters. A roadmap should not pretend the current crate is a production Transformer or a full sequence model. It should show how the current typed skeleton can grow without losing the discipline that made the small examples understandable.

Reader orientation: Read this chapter as an engineering migration plan, not as a promise that the current code already contains every Transformer component.

The source path for this roadmap is:

current Rust pipeline
  -> original Transformer architecture
  -> implementation-oriented attention tutorials
  -> future typed Rust milestones

The original Transformer paper introduced an architecture based on attention instead of recurrence or convolution for sequence transduction. Dive into Deep Learning gives a practical learning path through queries, keys, values, multi-head attention, self-attention, positional encoding, and the full Transformer architecture. Implementation tutorials such as The Annotated Transformer and visual explainers such as The Illustrated Transformer are useful bridges from paper notation to code and diagrams.

Framework documentation such as PyTorch’s attention and Transformer layer APIs is useful for one narrower purpose: checking the public shapes that production tools expose. This chapter does not copy those APIs. It uses them as a sanity check while keeping the teaching path smaller and typed.

The Hugging Face course also gives a useful distinction for this roadmap: architecture, checkpoint, and model are not the same idea. This repository is working on architecture pieces: named states, typed boundaries, and update rules. It is not loading a pretrained checkpoint, and it is not wrapping a large framework model output. When this chapter uses words such as HiddenSequence, AttentionWeights, or SequenceLogits, read them as tiny Rust-owned teaching objects that make the same roles inspectable.

There is also advanced category-theory work that studies attention more directly. One recent source introduces a category-theoretic diagrammatic formalism for decomposing attention mechanisms into anatomical components and comparing attention variants. Another treats the linear query, key, and value maps through a parametric categorical lens and studies how layered self-attention structure can be composed. These sources are useful precision support, but they are not a license to call every part of a Transformer the same categorical object. The parametric-endofunctor paper itself separates its linear focus from nonlinear pieces such as softmax and layer normalization. This roadmap follows the same caution: name the linear maps, product-input boundaries, shape-preserving endomorphisms, and state updates separately.

A broader categorical deep-learning source makes the same warning at the architecture level: a model can be described by constraints it should satisfy and by the implementation that realizes those constraints. This roadmap uses that distinction as a practical rule. Do not treat a compiled Rust boundary as proof that the whole architecture satisfies a mathematical constraint. Do not treat an architecture diagram as a substitute for a concrete type, constructor, example, and test.

This chapter keeps those sources in view, but it does not import their full complexity all at once. The rule is: add one typed concept only when the tiny Rust version can explain its boundary.

Chapter Outcomes

By the end of this chapter, you should be able to:

  • trace the attention example from query/key scoring through masking, softmax, value mixing, projection, residual addition, normalization, and feed-forward refinement,
  • classify Transformer boundaries by counting inputs before naming morphisms, product-input morphisms, endomorphisms, or illegal attempted compositions,
  • separate architecture constraints from implementation evidence in the tiny Rust roadmap.

What You Already Know

If you understand the current prediction path, you already know the skeleton a Transformer will extend. Tokens become vectors, vectors move through typed transformations, and probabilities feed a loss. The future work is to replace the one-token middle with sequence-aware structure.

Transformer Role Ownership Map

Before reading the implementation status table, separate the roles. A Transformer explanation becomes hard when query, key, value, score, weight, mask, and hidden-state roles all look like raw vectors or matrices. This roadmap assigns each role to a named Rust type or boundary.

| Transformer role | Rust owner | Boundary shape | Confusion prevented |
|---|---|---|---|
| hidden state sequence | HiddenSequence | model-width rows over sequence positions | treating one token vector as a full sequence |
| query role | QuerySequence and HiddenToQuery | HiddenSequence -> QuerySequence | passing values where queries are expected |
| key role | KeySequence and HiddenToKey | HiddenSequence -> KeySequence | comparing against value vectors instead of keys |
| value role | ValueSequence and HiddenToValue | HiddenSequence -> ValueSequence | mixing scores directly instead of value vectors |
| raw attention scores | AttentionScores | QuerySequence x KeySequence -> AttentionScores | treating unnormalized scores as probabilities |
| mask | AttentionMask | AttentionScores x AttentionMask -> AttentionScores | allowing illegal positions into softmax |
| normalized attention weights | AttentionWeights | AttentionScores -> AttentionWeights | forgetting that each query row is a distribution over source positions |
| value mixing | AttentionOutput | AttentionWeights x ValueSequence -> AttentionOutput | multiplying weights without saying what information is read |
| multiple heads | AttentionHeadOutputs and MultiHeadOutput | head outputs -> concatenated model-width rows | losing head count and head dimension |
| output projection | ProjectedAttentionOutput | MultiHeadOutput -> ProjectedAttentionOutput | leaving concatenated heads at the wrong width |
| residual boundary | ResidualConnection | HiddenSequence x ProjectedAttentionOutput -> HiddenSequence | adding tensors that cannot return to the block input shape |
| layer normalization | LayerNormalization | HiddenSequence -> HiddenSequence | changing values while accidentally changing the public object |
| feed-forward sublayer | PositionWiseFeedForward | HiddenSequence -> HiddenSequence | forgetting that the sublayer is position-wise and shape-preserving |
| block mask boundary | MaskedMultiHeadTransformerBlock | HiddenSequence x AttentionMask -> HiddenSequence | hiding the mask inside loose optional state |
| sequence readout | TransformerReadout and SequenceLogits | HiddenSequence -> SequenceLogits | confusing hidden states with vocabulary scores |
| training state | TransformerTrainingState | state plus learning rate plus step count | passing loose parameters without optimizer context |

This table is the chapter’s first debugging tool. If a later attention formula feels vague, point to the row that owns the role. The typed roadmap should make the question concrete:

Which object owns this role?
Which boundary produces it?
Which invalid connection should fail?
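To make the role table concrete, here is a minimal sketch, assuming plain `Vec<Vec<f64>>` rows and hypothetical helper names (`project`, `to_query`, `to_key`, `to_value`, `dot_scores`) rather than the crate's real constructors. The point is the type boundary: the three projections are siblings from one `HiddenSequence`, and a scoring function that asks for a `KeySequence` cannot be handed a `ValueSequence`.

```rust
// Hypothetical minimal role types. The crate's real types also validate
// row counts and widths; these only carry rows.
#[derive(Debug, Clone, PartialEq)]
struct HiddenSequence(Vec<Vec<f64>>);
#[derive(Debug, Clone, PartialEq)]
struct QuerySequence(Vec<Vec<f64>>);
#[derive(Debug, Clone, PartialEq)]
struct KeySequence(Vec<Vec<f64>>);
#[derive(Debug, Clone, PartialEq)]
struct ValueSequence(Vec<Vec<f64>>);

/// Multiply each hidden row by a (model_dim x role_dim) weight matrix.
fn project(hidden: &HiddenSequence, weight: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let role_dim = weight[0].len();
    hidden
        .0
        .iter()
        .map(|row| {
            (0..role_dim)
                .map(|j| row.iter().zip(weight).map(|(x, w_row)| x * w_row[j]).sum::<f64>())
                .collect::<Vec<f64>>()
        })
        .collect()
}

// Three sibling projections from the same hidden source (self-attention).
fn to_query(h: &HiddenSequence, w: &[Vec<f64>]) -> QuerySequence {
    QuerySequence(project(h, w))
}
fn to_key(h: &HiddenSequence, w: &[Vec<f64>]) -> KeySequence {
    KeySequence(project(h, w))
}
fn to_value(h: &HiddenSequence, w: &[Vec<f64>]) -> ValueSequence {
    ValueSequence(project(h, w))
}

/// QuerySequence x KeySequence -> unscaled dot-product scores.
/// The signature accepts only a KeySequence on the source side.
fn dot_scores(q: &QuerySequence, k: &KeySequence) -> Vec<Vec<f64>> {
    q.0.iter()
        .map(|qr| {
            k.0.iter()
                .map(|kr| qr.iter().zip(kr).map(|(a, b)| a * b).sum::<f64>())
                .collect::<Vec<f64>>()
        })
        .collect()
}

fn main() {
    let hidden = HiddenSequence(vec![vec![1.0, 0.0], vec![0.0, 1.0]]);
    let identity = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
    let queries = to_query(&hidden, &identity);
    let keys = to_key(&hidden, &identity);
    let values = to_value(&hidden, &identity);
    println!("{:?}", dot_scores(&queries, &keys));
    // dot_scores(&queries, &values) is rejected at compile time:
    // a ValueSequence is not a KeySequence.
    println!("{:?}", values);
}
```

Because each role is a distinct newtype, the invalid connection in the last question fails at compile time instead of producing a plausible-looking score table.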

Category Naming Contract

Before this chapter calls an attention boundary an endomorphism, count its inputs. The original Transformer architecture, the query-key-value teaching path, and framework attention APIs all expose the same warning: attention is not one vague arrow from a sequence to itself. Some stages need a query side and a source side. Some stages need a mask. Some stages need the previous hidden stream and a sublayer output.

Use this contract while reading the roadmap:

| If the boundary has shape | Name it as | Example | Do not call it |
|---|---|---|---|
| A -> B | ordinary morphism | AttentionScores -> AttentionWeights | an endomorphism |
| A -> A | endomorphism | LayerNormalization : HiddenSequence -> HiddenSequence | a product boundary |
| A x B -> C | product-input morphism | AttentionWeights x ValueSequence -> AttentionOutput | a unary transform |
| A x B -> A | product-input morphism returning A | HiddenSequence x ProjectedAttentionOutput -> HiddenSequence | an endomorphism unless the whole input object and output object are identical |
| missing projection or wrong role | illegal attempted composition | HiddenSequence x MultiHeadOutput -> HiddenSequence | a clever shortcut |

One more context rule matters for learned layers. When this roadmap writes LayerNormalization : HiddenSequence -> HiddenSequence or PositionWiseFeedForward : HiddenSequence -> HiddenSequence, it means:

for this fixed layer instance, with its current parameters already stored
inside the Rust object

If the parameters themselves are allowed to vary, name the larger boundary instead. For example, a parameter-learning story belongs to TransformerTrainingState -> TransformerTrainingState, not to a hidden sequence endomorphism that silently changes weights.

This rule keeps the category-theory vocabulary proportional to the code. The linear query, key, value, positional, and layered pieces can be compared with advanced categorical work on self-attention. Masking, softmax, residual addition, normalization, feed-forward refinement, and training updates still need their own typed boundaries in this teaching project.
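The fixed-instance rule can be sketched in Rust through the receiver type. This is an illustration under stated assumptions: `LayerNormalization`, `TrainingState`, and `train_step` here are simplified hypothetical stand-ins (only the scale is updated, and the caller supplies the gradient), not the crate's real types. The forward call takes `&self`, so it can read parameters but not change them; the update consumes one state and returns the next.

```rust
// Hypothetical stand-ins, not the crate's LayerNormalization or
// TransformerTrainingState.
#[derive(Debug, Clone, PartialEq)]
struct HiddenSequence(Vec<Vec<f64>>);

#[derive(Debug, Clone, PartialEq)]
struct LayerNormalization {
    scale: Vec<f64>,
    shift: Vec<f64>,
    epsilon: f64,
}

impl LayerNormalization {
    /// HiddenSequence -> HiddenSequence for this fixed instance:
    /// `&self` means the forward call only reads the stored parameters.
    fn apply(&self, hidden: &HiddenSequence) -> HiddenSequence {
        let rows = hidden.0.iter().map(|row| {
            let n = row.len() as f64;
            let mean = row.iter().sum::<f64>() / n;
            let variance = row.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
            let denom = (variance + self.epsilon).sqrt();
            row.iter()
                .zip(self.scale.iter().zip(&self.shift))
                .map(|(x, (g, b))| g * (x - mean) / denom + b)
                .collect::<Vec<f64>>()
        });
        HiddenSequence(rows.collect())
    }
}

/// Changing parameters is a different arrow: TrainingState -> TrainingState.
#[derive(Debug, Clone, PartialEq)]
struct TrainingState {
    norm: LayerNormalization,
    learning_rate: f64,
    step: usize,
}

/// One gradient-descent update of the scale parameters (gradient supplied).
fn train_step(state: TrainingState, scale_grad: &[f64]) -> TrainingState {
    let scale: Vec<f64> = state
        .norm
        .scale
        .iter()
        .zip(scale_grad)
        .map(|(g, d)| g - state.learning_rate * d)
        .collect();
    TrainingState {
        norm: LayerNormalization { scale, ..state.norm },
        learning_rate: state.learning_rate,
        step: state.step + 1,
    }
}

fn main() {
    // epsilon = 0.0 only to keep the demo arithmetic exact.
    let norm = LayerNormalization { scale: vec![1.0, 1.0], shift: vec![0.0, 0.0], epsilon: 0.0 };
    let hidden = HiddenSequence(vec![vec![1.0, 3.0]]);
    println!("{:?}", norm.apply(&hidden));
    let state = TrainingState { norm, learning_rate: 0.5, step: 0 };
    println!("{:?}", train_step(state, &[2.0, 0.0]));
}
```

Epsilon is 0.0 in the demo only so the output is exactly [-1.0, 1.0]; a real layer would use a small positive value.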

Fixed-Value Endomorphism Ledger

Use this ledger whenever a roadmap boundary looks like HiddenSequence -> HiddenSequence. The shape is not enough by itself; the stored context must also be stable for the forward call.

| Boundary | Fixed value that makes the unary view valid | If that value changes |
|---|---|---|
| PositionalEncoding : HiddenSequence -> HiddenSequence | one position table with fixed row count and model dimension | name the table update or rebuild path separately |
| LayerNormalization : HiddenSequence -> HiddenSequence | one layer-normalization value with fixed scale, shift, and epsilon | move to TransformerTrainingState -> TransformerTrainingState |
| PositionWiseFeedForward : HiddenSequence -> HiddenSequence | one feed-forward value with fixed weights, biases, and activation rule | move to TransformerTrainingState -> TransformerTrainingState |
| MultiHeadTransformerBlock : HiddenSequence -> HiddenSequence | one block value with fixed heads, projections, residual path, normalization, and feed-forward layers | move to TransformerTrainingState -> TransformerTrainingState |
| fixed-mask view of MaskedMultiHeadTransformerBlock | one named AttentionMask selected before the hidden-sequence call | return to HiddenSequence x AttentionMask -> HiddenSequence, or name a larger state carrying the changing mask |

This table is backed by the same source roles as the precision rules below: parameter-management references explain why model components own parameters, optimizer references explain why changing parameters belongs to the training loop, and Rust closure references explain the local analogy for fixing a mask before the remaining call.
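The closure analogy can be sketched as follows. The block body is a toy assumption: each output row averages the hidden rows its mask row allows, which is uniform attention, not the crate's `MaskedMultiHeadTransformerBlock`. The names `masked_block` and `fix_mask` are hypothetical. The open function keeps the product input `HiddenSequence x AttentionMask`; the `move` closure captures one named mask and serves as the fixed-mask unary view.

```rust
// Hypothetical stand-ins: hidden rows and allow-flags (true = may read).
type Hidden = Vec<Vec<f64>>;
type Mask = Vec<Vec<bool>>;

/// Open boundary: HiddenSequence x AttentionMask -> HiddenSequence.
/// Toy body: average the allowed hidden rows (assumes each mask row
/// allows at least one position and `hidden` is non-empty).
fn masked_block(hidden: &Hidden, mask: &Mask) -> Hidden {
    let width = hidden[0].len();
    mask.iter()
        .map(|allowed| {
            let picked: Vec<&Vec<f64>> = hidden
                .iter()
                .zip(allowed)
                .filter(|(_, ok)| **ok)
                .map(|(row, _)| row)
                .collect();
            let count = picked.len() as f64;
            (0..width)
                .map(|j| picked.iter().map(|row| row[j]).sum::<f64>() / count)
                .collect::<Vec<f64>>()
        })
        .collect()
}

/// Fix one mask and return a unary view: HiddenSequence -> HiddenSequence.
fn fix_mask(mask: Mask) -> impl Fn(&Hidden) -> Hidden {
    move |hidden: &Hidden| masked_block(hidden, &mask)
}

fn main() {
    let mask = vec![vec![true, false], vec![true, true]];
    let hidden = vec![vec![2.0, 0.0], vec![4.0, 2.0]];
    // The captured mask is the "fixed value" from the ledger.
    let view = fix_mask(mask.clone());
    println!("{:?}", view(&hidden));
}
```

If the mask needs to change between calls, the sketch has to go back to calling `masked_block` with both arguments, which is the ledger's "return to the product boundary" row.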

Add-Norm Order Ledger

Residual addition and layer normalization are not only two names that happen in the same neighborhood. Their order is part of the block boundary.

The original Transformer uses residual addition around each sublayer followed by layer normalization. Dive into Deep Learning teaches the same AddNorm shape as residual addition followed by layer normalization. PyTorch exposes the order as a configurable boundary: TransformerEncoderLayer has norm_first, where layer normalization can happen before attention and feed-forward operations instead of after them. Research on layer-normalization placement also treats the difference between Post-LN and Pre-LN as an optimization question, not a cosmetic rewrite.

This repository currently teaches the post-add normalization path:

attention sublayer output
-> ResidualConnection
-> attention_norm

feed-forward sublayer output
-> ResidualConnection
-> feed_forward_norm

Use this ledger when reading or extending the block:

| Order question | Source signal | Current Rust reading | Safe category statement |
|---|---|---|---|
| original post-norm shape | original Transformer and D2L AddNorm place normalization after residual addition | ResidualConnection runs before attention_norm and feed_forward_norm | fixed block is still HiddenSequence -> HiddenSequence |
| configurable framework shape | PyTorch norm_first can move normalization before attention and feed-forward operations | no pre-norm block is implemented here yet | a future pre-norm block needs a named constructor or type |
| optimization meaning | layer-normalization placement affects gradient behavior in Transformer training | current tests validate the local post-add path only | same source and target object does not imply same morphism |
| teaching boundary | order is visible in MultiHeadTransformerBlock::apply and MaskedMultiHeadTransformerBlock::apply_with_cache | residual output is normalized before feed-forward runs | do not erase order when explaining composition |

The important category-theory lesson is modest:

post-norm block : HiddenSequence -> HiddenSequence
pre-norm block  : HiddenSequence -> HiddenSequence

Those two arrows can have the same source and target while being different morphisms. Shape compatibility permits composition. It does not say the two implementations are interchangeable.
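A small numeric sketch makes the same point: two functions with identical Rust signatures can compute different results. The sublayer (doubling each entry) and the normalization (no learned scale or shift, epsilon 1e-5) are toy assumptions chosen for hand-checkable arithmetic; neither is the crate's block.

```rust
/// Toy normalization over one row: zero mean, unit variance, no scale/shift.
fn normalize(row: &[f64]) -> Vec<f64> {
    let n = row.len() as f64;
    let mean = row.iter().sum::<f64>() / n;
    let var = row.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
    let denom = (var + 1e-5).sqrt();
    row.iter().map(|x| (x - mean) / denom).collect()
}

/// Toy sublayer standing in for attention or feed-forward.
fn sublayer(row: &[f64]) -> Vec<f64> {
    row.iter().map(|x| 2.0 * x).collect()
}

/// Residual addition; widths must match to return to the same object.
fn add(a: &[f64], b: &[f64]) -> Vec<f64> {
    assert_eq!(a.len(), b.len(), "residual needs matching widths");
    a.iter().zip(b).map(|(x, y)| x + y).collect()
}

/// Post-norm (original Transformer, D2L AddNorm): norm(x + sublayer(x)).
fn post_norm(x: &[f64]) -> Vec<f64> {
    normalize(&add(x, &sublayer(x)))
}

/// Pre-norm (the norm_first style): x + sublayer(norm(x)).
fn pre_norm(x: &[f64]) -> Vec<f64> {
    add(x, &sublayer(&normalize(x)))
}

fn main() {
    let x = [1.0, 3.0];
    println!("post-norm: {:?}", post_norm(&x));
    println!("pre-norm:  {:?}", pre_norm(&x));
}
```

For x = [1.0, 3.0], the post-norm output is close to [-1.0, 1.0] while the pre-norm output is close to [-1.0, 5.0]: same source and target type, different morphism.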

Source-Backed Precision Rules

Use this table as a citation-to-claim guard while reading the rest of the roadmap. Each source supports a local teaching rule. None of them should be used as a shortcut around the typed Rust boundary.

| Source signal | Local rule in this roadmap | Rust evidence to inspect |
|---|---|---|
| Attention Is All You Need introduces the Transformer around attention instead of recurrence or convolution | treat attention as the architecture target, not as proof that the current crate is a full Transformer | examples/06_attention_scores.rs is a shape lab, not a production model |
| Dive into Deep Learning: Scaled Dot Product Attention writes attention with n queries and m key-value pairs | keep query-side length and source-side length visible before naming the morphism | QuerySequence x KeySequence -> AttentionScores and AttentionWeights x ValueSequence -> AttentionOutput |
| PyTorch MultiheadAttention exposes separate query, key, and value inputs with target length L and source length S | do not collapse self-attention and cross-attention into one vague HiddenSequence -> HiddenSequence arrow | TargetHiddenSequence -> QuerySequence and SourceHiddenSequence -> KeySequence, ValueSequence are the future cross-attention shape |
| PyTorch scaled dot product attention says the attention mask must broadcast to the attention-weight shape, a boolean True means the element participates in attention, and a float mask is added to attention scores | a mask modifies the score table before probability normalization; it is not a token sequence and not attention weights | AttentionScores x AttentionMask -> AttentionScores runs before AttentionScores -> AttentionWeights |
| PyTorch Transformer and PyTorch MultiheadAttention expose mask arguments where boolean True can mean “not allowed” or “ignore this key” | mask shape and mask polarity are separate ideas; translate polarity before comparing APIs with this Rust roadmap | AttentionMask::new(vec![vec![true, false, true], ...]) uses true for “this source position is allowed” |
| TensorFlow Keras MultiHeadAttention uses query shape (B, T, dim), value/key shape (B, S, dim), mask shape (B, T, S), and a boolean attention mask where 1 means attention is allowed | treat target/query length, source/key-value length, and allow-mask polarity as framework-neutral shape evidence | AttentionMask answers which source positions each target position may read |
| PyTorch Transformer building blocks separates dense tensors, nested tensors, masks, scaled dot-product attention, and cross-attention concerns | production masking and variable-length behavior are framework boundary choices; the tiny Rust mask is deliberately stricter | AttentionMask::new rejects a row with no legal keys |
| PyTorch TransformerEncoderLayer exposes norm_first and the original encoder-layer reference shape | residual-normalization order is a named architecture choice, not a detail to hide behind HiddenSequence -> HiddenSequence | MultiHeadTransformerBlock::apply uses post-add normalization today; a future pre-norm variant needs a named boundary |
| On Layer Normalization in the Transformer Architecture distinguishes Post-LN and Pre-LN Transformer variants and studies their training behavior | same source and target object can still mean different morphisms when the internal order changes | local tests validate the current post-add path, not every normalization-order variant |
| Dive into Deep Learning: Parameter Management treats parameters as named model components that can be accessed and updated | a forward sublayer may be an endomorphism only for a fixed layer instance; parameter-changing claims belong to the training-state boundary | LayerNormalization stores scale and shift parameters; TransformerTrainingState owns mutable training context |
| CS231n Neural Networks Part 3 and PyTorch gradcheck compare numerical finite differences with analytical gradients under tolerance and precision caveats | a finite-difference match is local evidence for one selected parameter path, not proof of every gradient, dataset, optimizer, or future training loop | transformer_block_train_step_matches_finite_difference_for_readout_weight, feed-forward, layer-normalization, output-projection, and attention-projection tests |
| Rust Book: Closures explains closures as callable values that can capture values from their surrounding environment | use closure capture as the Rust analogy for fixing a mask context before applying a unary view | `move |
| On the Anatomy of Attention studies attention by decomposing variants into components | decompose attention first, then compare variants | the roadmap names scores, masks, weights, values, heads, projection, residuals, normalization, and feed-forward separately |
| Self-Attention as a Parametric Endofunctor focuses on linear self-attention structure and explicitly separates nonlinear pieces | use “endofunctor” language only after naming the linear scope; do not carry it through softmax, masking, residuals, normalization, or training state without a new argument | HiddenSequence -> QuerySequence is a linear role-producing morphism; AttentionScores x AttentionMask -> AttentionScores is still product-input context |
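The finite-difference row can be illustrated on a toy readout. This sketch assumes `logits = W h` followed by softmax cross-entropy, for which the analytical gradient of one weight is `(p_i - y_i) * h_j`; the function names are hypothetical, and the crate's own tests check a composed block, not this toy. A passing comparison is local evidence for the weights it checks, exactly as the table warns.

```rust
/// Toy readout: logits[i] = sum_j w[i][j] * h[j].
fn logits(w: &[Vec<f64>], h: &[f64]) -> Vec<f64> {
    w.iter()
        .map(|row| row.iter().zip(h).map(|(a, b)| a * b).sum::<f64>())
        .collect()
}

/// Numerically stable softmax.
fn softmax(z: &[f64]) -> Vec<f64> {
    let max = z.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = z.iter().map(|v| (v - max).exp()).collect();
    let total: f64 = exps.iter().sum();
    exps.iter().map(|e| e / total).collect()
}

/// Cross-entropy of the target class under softmax(W h).
fn loss(w: &[Vec<f64>], h: &[f64], target: usize) -> f64 {
    -softmax(&logits(w, h))[target].ln()
}

/// Analytical gradient dL/dW[i][j] = (p_i - y_i) * h_j.
fn analytic_grad(w: &[Vec<f64>], h: &[f64], target: usize, i: usize, j: usize) -> f64 {
    let p = softmax(&logits(w, h));
    let y = if i == target { 1.0 } else { 0.0 };
    (p[i] - y) * h[j]
}

/// Central finite difference for the same single weight.
fn numeric_grad(w: &[Vec<f64>], h: &[f64], target: usize, i: usize, j: usize, eps: f64) -> f64 {
    let mut plus = w.to_vec();
    let mut minus = w.to_vec();
    plus[i][j] += eps;
    minus[i][j] -= eps;
    (loss(&plus, h, target) - loss(&minus, h, target)) / (2.0 * eps)
}

fn main() {
    let w = vec![vec![0.5, -0.2], vec![0.1, 0.3], vec![-0.4, 0.2]];
    let h = [1.0, 2.0];
    for i in 0..w.len() {
        for j in 0..h.len() {
            let a = analytic_grad(&w, &h, 1, i, j);
            let n = numeric_grad(&w, &h, 1, i, j, 1e-5);
            println!("w[{i}][{j}]: analytic {a:.8}, numeric {n:.8}");
        }
    }
}
```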

Attention Mental Model Repair Table

Use this table before the first attention example. Each row repairs a tempting mental model with one source-backed rule and one local Rust checkpoint.

| Tempting mental model | Safer model | Rust checkpoint |
|---|---|---|
| query turns into key, then key turns into value | query, key, and value are roles; in self-attention they are parallel projections from the same hidden source, and in cross-attention the query side and key-value side may come from different sequence objects | HiddenSequence -> QuerySequence, HiddenSequence -> KeySequence, and HiddenSequence -> ValueSequence are siblings, not a pipeline |
| raw scores are already attention probabilities | scores become probabilities only after mask handling and row-wise softmax; value mixing happens after weights exist | AttentionScores x AttentionMask -> AttentionScores -> AttentionWeights comes before AttentionWeights x ValueSequence -> AttentionOutput |
| same output shape means endomorphism | count the whole source object first; A x B -> A returns A, but it is still a product-input boundary while B is open | HiddenSequence x ProjectedAttentionOutput -> HiddenSequence is not the same shape as HiddenSequence -> HiddenSequence |
| fixing a mask means the mask disappeared | a fixed-context view is a new named view after one mask has been chosen; the open boundary remains product-input | MaskedMultiHeadTransformerBlock[M] : HiddenSequence -> HiddenSequence is valid only after naming the fixed AttentionMask M |

The repair pattern is:

bad shortcut -> source-backed role or shape rule -> local Rust boundary

Do this before using a category-theory label. The label should describe the boundary the reader can inspect, not the shortcut the reader is trying to remember.

If a future chapter cites a stronger categorical result, it should add the same three pieces:

source claim
local typed boundary
validation command or test

That keeps the roadmap useful for both readers: the ML reader can see which shape is being implemented, and the category-theory reader can see which formal claim is being used and where it stops.

Worked Example Priority

The roadmap now has many typed attention boundaries. A reader does not need all of them expanded at the same depth on a first pass. Use this priority table to decide which sections deserve worked examples before more implementation is added.

| Priority | Boundary | Why this comes first | Evidence to ask from a reader |
|---|---|---|---|
| 1 | AttentionScores x AttentionMask -> AttentionScores -> AttentionWeights | readers often confuse raw scores, masked scores, and probabilities | Can the reader explain which positions were removed before softmax? |
| 2 | HiddenSequence -> QuerySequence, KeySequence, ValueSequence | query, key, and value are numerically similar but semantically different roles | Can the reader say which role asks, which role is compared, and which role is mixed? |
| 3 | AttentionWeights x ValueSequence -> AttentionOutput | attention becomes useful only when weights read values | Can the reader trace one output row as a weighted sum of value rows? |
| 4 | AttentionHeadOutputs -> MultiHeadOutput -> ProjectedAttentionOutput | multi-head attention adds shape arithmetic that can hide mistakes | Can the reader compute head_count * head_dimension and name the projection input width? |
| 5 | HiddenSequence x ProjectedAttentionOutput -> HiddenSequence | residual addition explains why many sublayers return to the same object | Can the reader explain why mismatched sequence length or model dimension must fail? |
| 6 | HiddenSequence -> HiddenSequence for normalization and feed-forward | these are shape-preserving sublayers, not new sequence objects | Can the reader name what changes and what stays invariant? |
| 7 | TransformerTrainingState -> TransformerTrainingState | training is important, but it should come after forward shape ownership is clear | Can the reader separate readout-only, local feed-forward, and composed block updates? |

This table is not a ranking of importance. It is a ranking of teaching risk. The first three rows protect the core attention story:

roles -> scores -> masked weights -> mixed values

If a reader cannot trace that path, the later block and training sections will feel like a list of names. If the reader can trace it, residuals, normalization, feed-forward layers, and training state have a stable place to attach.
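Priority row 4 asks for the concatenation arithmetic. A minimal sketch, assuming each head's output is stored as `Vec<Vec<f64>>` (positions by head dimension) and using hypothetical helper names: concatenation joins heads position by position, so each output row has width `head_count * head_dimension`, and the output projection must accept exactly that width.

```rust
/// Hypothetical helper: heads[h][t] is the output row of head h at position t.
/// Returns one row per position with width head_count * head_dimension.
fn concat_heads(heads: &[Vec<Vec<f64>>]) -> Result<Vec<Vec<f64>>, String> {
    let (seq_len, head_dim) = match heads.first() {
        Some(first) if !first.is_empty() => (first.len(), first[0].len()),
        _ => return Err("need at least one non-empty head".to_string()),
    };
    // Every head must agree on sequence length and head dimension.
    let consistent = heads
        .iter()
        .all(|h| h.len() == seq_len && h.iter().all(|row| row.len() == head_dim));
    if !consistent {
        return Err("every head needs the same sequence length and head dimension".to_string());
    }
    Ok((0..seq_len)
        .map(|t| heads.iter().flat_map(|h| h[t].iter().copied()).collect::<Vec<f64>>())
        .collect())
}

/// The output projection's input width must equal the concatenated width.
fn projection_accepts(concat: &[Vec<f64>], projection_input_width: usize) -> bool {
    concat.iter().all(|row| row.len() == projection_input_width)
}

fn main() {
    let head_a = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
    let head_b = vec![vec![5.0, 6.0], vec![7.0, 8.0]];
    let head_c = vec![vec![9.0, 10.0], vec![11.0, 12.0]];
    let joined = concat_heads(&[head_a, head_b, head_c]).expect("consistent heads");
    // 3 heads * 2 dimensions = 6 columns per position.
    println!("{:?}", joined);
    println!("projection with input width 6 fits: {}", projection_accepts(&joined, 6));
}
```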

Worked Example: Mask Before Softmax

The original Transformer formula and implementation-oriented attention references put softmax after query-key scoring. Practical implementations apply the attention mask to the score table before softmax, either by adding a float mask or by replacing disallowed scores with a very negative value. The reason is simple: only legal positions should compete for probability mass.

The runnable attention example starts with two query positions and three key positions:

let queries = QuerySequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
let keys = KeySequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]])?;
let values = ValueSequence::new(vec![vec![1.0, 10.0], vec![2.0, 20.0], vec![3.0, 30.0]])?;
let mask = AttentionMask::new(vec![vec![true, false, true], vec![true, true, true]])?;

For the first query row [1.0, 0.0], the scaled dot product (each dot product divided by sqrt(2), because the query and key width is 2) produces:

raw scores:
[0.7071, 0.0000, 0.7071]

mask:
[true, false, true]

masked scores:
[0.7071, very negative, 0.7071]

row-wise softmax:
[0.5, 0.0, 0.5]

The middle key position has a real raw score, but the mask says this query position is not allowed to read it. The mask must therefore act before softmax. After softmax, the illegal position would already have received probability mass.

The value-mixing step then reads only the allowed value rows:

0.5 * [1.0, 10.0]
+ 0.0 * [2.0, 20.0]
+ 0.5 * [3.0, 30.0]
= [2.0, 20.0]

That is why the typed path is:

AttentionScores x AttentionMask -> AttentionScores
AttentionScores -> AttentionWeights
AttentionWeights x ValueSequence -> AttentionOutput

The mask boundary is not a cosmetic option. It protects the meaning of the probability row. AttentionWeights should answer:

among the positions this query may read, how much should each one contribute?

In this repository, masked-out scores become a very negative finite value instead of a non-finite value so that the pedagogical constructors can keep the “all scores are finite” invariant. The teaching meaning is the same as the standard attention implementation pattern: make disallowed positions effectively impossible before row-wise softmax.
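The worked example can be reproduced numerically. This is a minimal sketch with plain vectors and hypothetical function names, not the crate's attention types; the constant `-1.0e9` is an assumption standing in for the repository's "very negative finite value", whose exact value is not given here. It checks the first-row weights [0.5, 0.0, 0.5] and output [2.0, 20.0], and shows that masking after softmax leaves a row that no longer sums to one.

```rust
/// Assumed stand-in for a blocked score: very negative but finite.
const MASKED_SCORE: f64 = -1.0e9;

/// QuerySequence x KeySequence -> scores, each dot product divided by sqrt(d).
fn scaled_scores(queries: &[Vec<f64>], keys: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let scale = 1.0 / (queries[0].len() as f64).sqrt();
    queries
        .iter()
        .map(|q| {
            keys.iter()
                .map(|k| scale * q.iter().zip(k).map(|(a, b)| a * b).sum::<f64>())
                .collect::<Vec<f64>>()
        })
        .collect()
}

/// AttentionScores x AttentionMask -> AttentionScores (true = allowed).
fn apply_mask(scores: &[Vec<f64>], mask: &[Vec<bool>]) -> Vec<Vec<f64>> {
    scores
        .iter()
        .zip(mask)
        .map(|(row, allow)| {
            row.iter()
                .zip(allow)
                .map(|(s, ok)| if *ok { *s } else { MASKED_SCORE })
                .collect::<Vec<f64>>()
        })
        .collect()
}

/// AttentionScores -> AttentionWeights via row-wise stable softmax.
fn softmax_rows(scores: &[Vec<f64>]) -> Vec<Vec<f64>> {
    scores
        .iter()
        .map(|row| {
            let max = row.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
            let exps: Vec<f64> = row.iter().map(|s| (s - max).exp()).collect();
            let total: f64 = exps.iter().sum();
            exps.iter().map(|e| e / total).collect::<Vec<f64>>()
        })
        .collect()
}

/// AttentionWeights x ValueSequence -> AttentionOutput.
fn mix_values(weights: &[Vec<f64>], values: &[Vec<f64>]) -> Vec<Vec<f64>> {
    weights
        .iter()
        .map(|w| {
            (0..values[0].len())
                .map(|j| w.iter().zip(values).map(|(wi, v)| wi * v[j]).sum::<f64>())
                .collect::<Vec<f64>>()
        })
        .collect()
}

fn main() {
    let q = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
    let k = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];
    let v = vec![vec![1.0, 10.0], vec![2.0, 20.0], vec![3.0, 30.0]];
    let mask = vec![vec![true, false, true], vec![true, true, true]];

    let scores = scaled_scores(&q, &k);
    let weights = softmax_rows(&apply_mask(&scores, &mask));
    println!("weights: {:?}", weights);
    println!("output:  {:?}", mix_values(&weights, &v));

    // Masking after softmax: zeroing the middle entry breaks the distribution.
    let naive = softmax_rows(&scores);
    let leftover = naive[0][0] + naive[0][2];
    println!("row sum after late masking: {leftover:.4}");
}
```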

Mask Polarity Ledger

Mask shape answers:

which query row and source column is this mask cell about?

Mask polarity answers:

does true mean allowed, or does true mean blocked?

Those are different questions. This repository chooses the smaller teaching polarity:

true  -> this query may read this source position
false -> this query may not read this source position

That choice matches the boolean mask meaning used by PyTorch’s scaled_dot_product_attention, where True means the element participates in attention. It also matches the Keras MultiHeadAttention attention-mask rule where 1 marks a query-key pair that may attend. It does not match every PyTorch attention API. In MultiheadAttention padding masks, and in the boolean masks described by torch.nn.Transformer, True can mean the position is blocked or ignored.

So translate a framework mask in two steps:

| Question | Rust roadmap answer | Framework caution |
|---|---|---|
| What is the shape? | one cell per query-source score position | L x S, (B, T, S), and padding masks point at different axes |
| What is the polarity? | true means allowed | some APIs use true to mean blocked or padding |
| When is it applied? | before softmax, while values are still scores | after-softmax masking would change the meaning of the probability row |

The safe translation rule is:

first match the mask cells to score cells,
then translate boolean polarity,
then apply the mask before softmax

Do not carry a raw boolean mask from a framework into this Rust roadmap without stating its polarity. Two masks can have the same shape and opposite meaning.
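The first two translation steps can be sketched as a helper that checks shape first and flips polarity second. `blocked_to_allowed` is a hypothetical name; it assumes the incoming mask uses the "true means blocked" convention described above for some PyTorch APIs, and it leaves the third step (applying the mask before softmax) to the score boundary.

```rust
/// Hypothetical translation from a "true = blocked" framework mask to this
/// roadmap's "true = allowed" polarity.
fn blocked_to_allowed(
    blocked: &[Vec<bool>],
    query_len: usize,
    source_len: usize,
) -> Result<Vec<Vec<bool>>, String> {
    // Step 1: match mask cells to score cells (query rows by source columns).
    if blocked.len() != query_len || blocked.iter().any(|row| row.len() != source_len) {
        return Err(format!("mask must be {query_len} x {source_len}"));
    }
    // Step 2: translate boolean polarity.
    Ok(blocked
        .iter()
        .map(|row| row.iter().map(|cell| !cell).collect::<Vec<bool>>())
        .collect())
}

fn main() {
    // Blocks the middle source position for the first query row only.
    let blocked = vec![vec![false, true, false], vec![false, false, false]];
    println!("{:?}", blocked_to_allowed(&blocked, 2, 3));
    println!("{:?}", blocked_to_allowed(&blocked, 3, 2));
}
```

The translated result is the same allow-mask used in the mask-before-softmax worked example, which is the point: identical intent, opposite raw booleans.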

Production Masking Caveat

The tiny AttentionMask in src/attention.rs is stricter than a production framework boundary. For example, this constructor call is rejected:

AttentionMask::new(vec![vec![false, false]])

The reason is pedagogical. In this book, every attention-weight row should mean:

among at least one legal source position, how much should each one contribute?

If a row allows no source positions, there is no probability support for that row. The constructor therefore returns:

Err(CtError::EmptyInput("attention mask row allows no keys"))

That boundary is intentionally less general than production Transformer libraries. PyTorch’s Transformer building-blocks tutorial discusses nested tensors, variable sequence lengths, padding masks, and the production problem of fully masked rows. It notes that softmax over an empty set is undefined and that newer scaled_dot_product_attention behavior returns zero output for fully masked rows.

The contrast is useful:

| Concern | Production framework boundary | Tiny teaching boundary |
|---|---|---|
| variable sequence lengths | ragged batches, padding, nested tensors, and mask ergonomics | each example uses one explicit rectangular mask |
| fully masked query row | framework must decide a stable output convention | constructor rejects the row before softmax |
| performance | fused kernels, compilation, and memory-aware representations | small values the reader can inspect by hand |

This is not a disagreement with framework behavior. It is a scope decision. It preserves the invariant that AttentionWeights is a row-wise distribution over at least one source position. A future production-oriented chapter can relax that boundary only if it also names the new output convention for rows with no legal source positions.
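A sketch of the stricter constructor, assuming a local `MaskError` enum in place of the crate's `CtError` (whose full definition is not shown in this chapter). The empty-row error text mirrors the message quoted earlier in this section; the ragged-row check is an added assumption so that every accepted mask is rectangular, and `allows` is a hypothetical accessor.

```rust
/// Local stand-in for the crate's error type (assumption).
#[derive(Debug, PartialEq)]
enum MaskError {
    EmptyInput(&'static str),
    Ragged,
}

/// Allow-mask: `true` means "this query position may read this source position".
#[derive(Debug, Clone, PartialEq)]
struct AttentionMask(Vec<Vec<bool>>);

impl AttentionMask {
    fn new(rows: Vec<Vec<bool>>) -> Result<Self, MaskError> {
        if rows.is_empty() || rows[0].is_empty() {
            return Err(MaskError::EmptyInput("attention mask"));
        }
        let width = rows[0].len();
        if rows.iter().any(|row| row.len() != width) {
            return Err(MaskError::Ragged);
        }
        // A row with no legal source position has no probability support.
        if rows.iter().any(|row| !row.iter().any(|&ok| ok)) {
            return Err(MaskError::EmptyInput("attention mask row allows no keys"));
        }
        Ok(Self(rows))
    }

    fn allows(&self, query: usize, source: usize) -> bool {
        self.0[query][source]
    }
}

fn main() {
    println!("{:?}", AttentionMask::new(vec![vec![false, false]]));
    let mask = AttentionMask::new(vec![vec![true, false, true], vec![true, true, true]])
        .expect("every row allows at least one key");
    println!("query 0 may read source 1: {}", mask.allows(0, 1));
}
```

A production-oriented relaxation would replace the third check with an explicit output convention for empty rows rather than silently dropping it.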

What Exists Now

The current model has this prediction path:

TokenId -> Vector -> Logits -> Distribution

The implementation status is:

| Concept | Current status | Reason |
|---|---|---|
| Token ids | implemented | TokenId and TokenSequence already exist |
| Vectors | implemented | Vector is the current hidden representation |
| Logits and probabilities | implemented | LinearToLogits and Softmax are executable |
| Loss | implemented | CrossEntropy evaluates prediction against target |
| Parameter update | implemented | TrainStep updates Parameters |
| Query-key score boundary | implemented as a tiny roadmap sketch | QuerySequence x KeySequence -> AttentionScores is executable |
| Attention mask boundary | implemented as a tiny roadmap sketch | AttentionScores x AttentionMask -> AttentionScores is executable |
| Attention score-to-weight boundary | implemented as a tiny roadmap sketch | AttentionScores -> AttentionWeights is executable |
| Value-mixing boundary | implemented as a tiny roadmap sketch | AttentionWeights x ValueSequence -> AttentionOutput is executable |
| Multi-head concatenation boundary | implemented as a tiny roadmap sketch | AttentionHeadOutputs -> MultiHeadOutput is executable |
| Attention output projection boundary | implemented as a tiny roadmap sketch | MultiHeadOutput -> ProjectedAttentionOutput is executable |
| Sequence hidden states | implemented as a tiny roadmap sketch | HiddenSequence is executable for residual addition |
| Residual addition boundary | implemented as a tiny roadmap sketch | HiddenSequence x ProjectedAttentionOutput -> HiddenSequence is executable |
| Layer normalization boundary | implemented as a tiny roadmap sketch | HiddenSequence -> HiddenSequence is executable through LayerNormalization |
| Position-wise feed-forward boundary | implemented as a tiny roadmap sketch | HiddenSequence -> HiddenSequence is executable through PositionWiseFeedForward |
| Hidden-to-query/key/value projections | implemented as a tiny roadmap sketch | HiddenSequence -> QuerySequence, HiddenSequence -> KeySequence, and HiddenSequence -> ValueSequence are executable |
| Single-head block boundary | implemented as a tiny roadmap sketch | SingleHeadTransformerBlock : HiddenSequence -> HiddenSequence composes the current boundaries |
| Multi-head block boundary | implemented as a tiny roadmap sketch | MultiHeadTransformerBlock : HiddenSequence -> HiddenSequence composes several SelfAttentionHead values |
| Positional encoding | implemented as a tiny roadmap sketch | PositionalEncoding : HiddenSequence -> HiddenSequence adds position rows while preserving shape |
| Masked block variants | implemented as a tiny roadmap sketch | MaskedMultiHeadTransformerBlock : HiddenSequence x AttentionMask -> HiddenSequence accepts a block-level mask |
| Sequence logits and readout | implemented as a tiny roadmap sketch | TransformerReadout : HiddenSequence -> SequenceLogits produces vocabulary scores at each sequence position |
| Structured Transformer parameter object | implemented as a tiny roadmap sketch | TinyTransformerParameters : HiddenSequence x AttentionMask -> SequenceLogits owns position, masked block, and readout pieces |
| Structured Transformer training state | implemented as a tiny roadmap sketch | TransformerTrainingState owns parameters, learning rate, and step count |
| Readout-only training step | implemented as a tiny roadmap sketch | TransformerReadoutTrainStep : TransformerTrainingState -> TransformerTrainingState updates only the sequence readout |
| Local feed-forward training step | implemented as a tiny roadmap sketch | TransformerFeedForwardTrainStep : TransformerTrainingState -> TransformerTrainingState updates only the position-wise feed-forward sublayer against hidden targets |
| Composed block training step | implemented as a tiny roadmap sketch | TransformerBlockTrainStep : TransformerTrainingState -> TransformerTrainingState updates readout, feed-forward, attention-output-projection, query/key/value, and layer-normalization parameters from sequence targets through residual, normalization, and attention paths |

This table is a guardrail. When extending the project, do not present planned items as implemented content. Add the type, example, test, chapter prose, and reference link together.

Rust Syntax

The path is implemented with:

Embedding
LinearToLogits
Softmax
Compose

The main domain objects are:

TokenId
Vector
Logits
Distribution
Parameters

The training update is:

TrainStep : Parameters -> Parameters
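A minimal sketch of why the Parameters -> Parameters shape matters, assuming a hypothetical one-weight Parameters struct and a squared-error loss rather than the crate's real TrainStep: because a step returns the same type it consumes, the same step can be applied again and again.

```rust
/// Hypothetical stand-in for the crate's Parameters: a single scalar weight.
#[derive(Debug, Clone, Copy, PartialEq)]
struct Parameters {
    weight: f32,
}

/// One gradient-descent step on the loss (weight * input - target)^2.
/// Input and output are both Parameters, so this is an endomorphism.
fn train_step(params: Parameters, input: f32, target: f32, learning_rate: f32) -> Parameters {
    let prediction = params.weight * input;
    // d/dw (w * x - y)^2 = 2 * (w * x - y) * x
    let gradient = 2.0 * (prediction - target) * input;
    Parameters {
        weight: params.weight - learning_rate * gradient,
    }
}

/// Endomorphisms compose with themselves: apply the same step `steps` times.
fn train(mut params: Parameters, steps: usize, input: f32, target: f32, learning_rate: f32) -> Parameters {
    for _ in 0..steps {
        params = train_step(params, input, target, learning_rate);
    }
    params
}
```

With input 1.0, target 2.0, and learning rate 0.25, one step moves the weight from 0.0 to 1.0, and each further step halves the remaining error. The design point is the signature, not the arithmetic: any update whose input and output types match can be iterated.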

ML Concept

This is a tiny next-token model.

It predicts from one token at a time.

The main training example is still that small. The roadmap module now sketches attention blocks and structured Transformer state, but it does not yet train a production Transformer.

Still, it already teaches the core path:

discrete token
  -> dense representation
  -> vocabulary scores
  -> next-token probabilities
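The core path can be sketched as ordinary function composition. Everything below is a hypothetical stand-in (plain newtypes, a generic compose helper, a dot-product readout), not the crate's Embedding, LinearToLogits, Softmax, or Compose types.

```rust
/// Hypothetical newtypes standing in for TokenId, Vector, Logits, Distribution.
#[derive(Debug, Clone, Copy)]
struct TokenId(usize);
#[derive(Debug, Clone)]
struct Vector(Vec<f32>);
#[derive(Debug, Clone)]
struct Logits(Vec<f32>);
#[derive(Debug, Clone)]
struct Distribution(Vec<f32>);

/// Generic composition: A -> B followed by B -> C gives A -> C.
fn compose<A, B, C>(f: impl Fn(A) -> B, g: impl Fn(B) -> C) -> impl Fn(A) -> C {
    move |a| g(f(a))
}

/// TokenId -> Vector: look up one row of the embedding table (panics on an unknown id).
fn embed(table: &[Vec<f32>], token: TokenId) -> Vector {
    Vector(table[token.0].clone())
}

/// Vector -> Logits: one score per vocabulary row, dot(row, vector).
fn to_logits(readout: &[Vec<f32>], vector: Vector) -> Logits {
    Logits(
        readout
            .iter()
            .map(|row| row.iter().zip(&vector.0).map(|(w, x)| w * x).sum::<f32>())
            .collect(),
    )
}

/// Logits -> Distribution: softmax shifted by the maximum for numerical stability.
fn softmax(logits: Logits) -> Distribution {
    let max = logits.0.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.0.iter().map(|x| (x - max).exp()).collect();
    let total: f32 = exps.iter().sum();
    Distribution(exps.iter().map(|e| e / total).collect())
}
```

A usage sketch: `compose(compose(|t: TokenId| embed(&table, t), |v: Vector| to_logits(&readout, v)), softmax)` builds one TokenId -> Distribution function out of three smaller typed steps.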

Category Theory Concept

The current system teaches composition:

TokenId -> Vector -> Logits -> Distribution

and endomorphism:

Parameters -> Parameters

Those two shapes remain central in Transformers.

Step 1: Sequences As First-Class Objects

The future problem:

Attention does not operate on one token alone. It operates on a sequence of hidden states.

The current code already has TokenSequence, but that is a sequence of token ids. Attention needs a sequence of hidden vectors, usually with position and mask information attached. That is a different object with different invariants.

Worked Example: Validating Sequence Length

The first-principles Rust move is the same one used throughout the book: do not let a meaningful value travel as a raw primitive once it crosses a conceptual boundary. The roadmap module now starts with a small validating type:

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SequenceLength(usize);

impl SequenceLength {
    pub fn new(value: usize) -> CtResult<Self> {
        if value == 0 {
            return Err(CtError::EmptyInput("sequence length"));
        }

        Ok(Self(value))
    }

    pub fn value(&self) -> usize {
        self.0
    }
}

Self-Check

Before reading the roadmap steps, explain why SequenceLength should not be passed around as a bare usize.

Rust Syntax

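The sequence step needs types such as the following; src/attention.rs already defines fuller, validated versions of SequenceLength, HiddenSequence, and AttentionMask: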

pub struct Position(usize);
pub struct SequenceLength(usize);
pub struct HiddenSequence(Vec<Vector>);
pub struct AttentionMask(/* validated mask representation */);

The important rule is the same as this course:

do not pass raw vectors across architectural boundaries

ML Concept

Attention needs a representation like:

[hidden_0, hidden_1, hidden_2, ...]

plus position and mask information.

Category Theory Concept

The object changes from:

Vector

to:

Sequence(Vector)

The next morphisms operate on structured sequences.
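One way to read Sequence(Vector) is as a list of vectors: a per-position morphism Vector -> Vector lifts to a whole-sequence morphism by applying it to every row. The sketch below uses hypothetical names and checks that lifting respects composition. This is the shape a position-wise sublayer, such as the roadmap's feed-forward map, has.

```rust
/// Hypothetical Sequence(Vector) object: one row of floats per position.
#[derive(Debug, Clone, PartialEq)]
struct HiddenSequence {
    rows: Vec<Vec<f32>>,
}

/// Lifts a per-position morphism Vector -> Vector to
/// Sequence(Vector) -> Sequence(Vector) by applying it to every row.
/// The number of positions is preserved by construction.
fn map_positions(sequence: &HiddenSequence, f: impl Fn(&[f32]) -> Vec<f32>) -> HiddenSequence {
    HiddenSequence {
        rows: sequence.rows.iter().map(|row| f(row.as_slice())).collect(),
    }
}
```

Mapping f and then g gives the same sequence as mapping "f then g" once, so per-position pipelines can be reasoned about one row at a time.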

Design contract:

TokenSequence -> HiddenSequence

should not be represented as:

Vec<usize> -> Vec<Vec<f32>>

The second shape hides every domain distinction the course has worked to make visible.
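A minimal sketch of the typed contract, assuming a hypothetical EmbeddingTable and error enum (the crate's TokenSequence and HiddenSequence are richer). The lookup either returns one hidden row per token position or names the boundary rule that failed.

```rust
/// Hypothetical typed token ids for one sequence.
#[derive(Debug, Clone, PartialEq)]
struct TokenSequence(Vec<usize>);

/// Hypothetical hidden sequence: one embedding row per token position.
#[derive(Debug, Clone, PartialEq)]
struct HiddenSequence {
    rows: Vec<Vec<f32>>,
}

/// Boundary failures are named instead of surfacing as index panics.
#[derive(Debug, PartialEq)]
enum EmbedError {
    EmptySequence,
    UnknownToken(usize),
}

/// Hypothetical embedding table: row `id` is the vector for token `id`.
struct EmbeddingTable {
    rows: Vec<Vec<f32>>,
}

impl EmbeddingTable {
    /// TokenSequence -> HiddenSequence, one hidden row per token position.
    fn embed(&self, tokens: &TokenSequence) -> Result<HiddenSequence, EmbedError> {
        if tokens.0.is_empty() {
            return Err(EmbedError::EmptySequence);
        }
        let rows = tokens
            .0
            .iter()
            .map(|&id| self.rows.get(id).cloned().ok_or(EmbedError::UnknownToken(id)))
            .collect::<Result<Vec<_>, _>>()?;
        Ok(HiddenSequence { rows })
    }
}
```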

Step 2: Query, Key, And Value Projections

The current problem:

Attention compares tokens by projecting hidden states into query, key, and value spaces.

The important design move is not only three matrices. It is three roles. A query vector, key vector, and value vector may share the same numeric representation, but they should not share the same Rust type once they cross a module boundary.

Rust Syntax

The current projection morphisms have shapes:

HiddenSequence -> QuerySequence
HiddenSequence -> KeySequence
HiddenSequence -> ValueSequence

Each output type should be distinct.

The current roadmap code models both the role objects and the hidden-state projection morphisms:

HiddenToQuery
HiddenToKey
HiddenToValue
QuerySequence
KeySequence
ValueSequence

Queries, keys, and values are all vectors underneath, but they have different roles.

Worked Example: Same Hidden Row, Three Roles

The query-key-value split is not about three mysterious kinds of vector. It is about three uses of a hidden state.

Start with two hidden rows:

hidden_0 = [1.0, 2.0]
hidden_1 = [3.0, 4.0]

A tiny set of projections can send the same hidden rows into three role-specific objects:

let hidden = HiddenSequence::new(vec![vec![1.0, 2.0], vec![3.0, 4.0]])?;

let to_query = HiddenToQuery::new(
    vec![vec![1.0, 0.0], vec![0.0, 1.0]],
    vec![0.0, 0.0],
)?;
let to_key = HiddenToKey::new(
    vec![vec![0.0, 1.0], vec![1.0, 0.0]],
    vec![0.0, 0.0],
)?;
let to_value = HiddenToValue::new(
    vec![vec![10.0, 0.0], vec![0.0, 10.0]],
    vec![0.0, 0.0],
)?;

For hidden_0, those projections produce:

query_0 = [1.0, 2.0]
key_0   = [2.0, 1.0]
value_0 = [10.0, 20.0]
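The arithmetic can be checked with a plain affine projection. This sketch assumes each weight row produces one output coordinate (output[i] = dot(weights[i], hidden) + bias[i]). The crate's HiddenToQuery, HiddenToKey, and HiddenToValue may store their matrices differently, but these example matrices are symmetric, so both layouts give the same numbers.

```rust
/// Affine projection: output[i] = dot(weights[i], hidden) + bias[i].
/// Assumption: each weight row produces one output coordinate.
fn project(weights: &[Vec<f32>], bias: &[f32], hidden: &[f32]) -> Vec<f32> {
    weights
        .iter()
        .zip(bias)
        .map(|(row, b)| row.iter().zip(hidden).map(|(w, x)| w * x).sum::<f32>() + b)
        .collect()
}
```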

The numbers are deliberately simple. The important lesson is the role separation:

Role | Question it answers | Used for
query | what is this position looking for? | compared with keys
key | what can this source position be matched by? | compared with queries
value | what information can this source position contribute? | mixed after weights exist

If all three values were passed around as Vec<Vec<f32>>, the compiler could not help a reader notice a role mistake. ValueSequence could accidentally be fed into query-key scoring. KeySequence could accidentally be mixed as if it were content. The typed split makes that confusion harder to express.

This also explains why the attention path has two phases:

QuerySequence x KeySequence -> AttentionScores
AttentionWeights x ValueSequence -> AttentionOutput

Queries and keys decide where to look. Values provide what gets read.
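The two phases can be sketched with hypothetical newtypes that carry identical numeric payloads but distinct roles. Because scores only accepts a query sequence and a key sequence, and mix only accepts weights and values, swapping roles becomes a compile error rather than a silent numeric bug. Scores here are unscaled; scaling belongs to Step 3.

```rust
/// Hypothetical role newtypes: same numeric payload, different meaning.
struct QuerySequence(Vec<Vec<f32>>);
struct KeySequence(Vec<Vec<f32>>);
struct ValueSequence(Vec<Vec<f32>>);
/// One row per query, one column per key position.
struct AttentionWeights(Vec<Vec<f32>>);

fn dot(left: &[f32], right: &[f32]) -> f32 {
    left.iter().zip(right).map(|(a, b)| a * b).sum()
}

/// Phase 1: QuerySequence x KeySequence -> scores (unscaled in this sketch).
/// A ValueSequence argument in either position does not compile.
fn scores(queries: &QuerySequence, keys: &KeySequence) -> Vec<Vec<f32>> {
    queries
        .0
        .iter()
        .map(|q| keys.0.iter().map(|k| dot(q, k)).collect::<Vec<f32>>())
        .collect()
}

/// Phase 2: AttentionWeights x ValueSequence -> output rows.
/// Each output row is the weighted sum of the value rows.
fn mix(weights: &AttentionWeights, values: &ValueSequence) -> Vec<Vec<f32>> {
    let width = values.0.first().map_or(0, |row| row.len());
    weights
        .0
        .iter()
        .map(|weight_row| {
            let mut out = vec![0.0f32; width];
            for (w, value_row) in weight_row.iter().zip(&values.0) {
                for (o, v) in out.iter_mut().zip(value_row) {
                    *o += w * v;
                }
            }
            out
        })
        .collect()
}
```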

Self-Attention And Cross-Attention Boundary

The current roadmap example is self-attention: query, key, and value roles all come from the same HiddenSequence before they are projected into separate role objects.

That is only one attention case.

Official framework documentation exposes a more general boundary. PyTorch’s multi-head attention API accepts query, key, and value as separate inputs. Its shape language distinguishes target sequence length L for queries from source sequence length S for keys and values. Dive into Deep Learning makes the same teaching distinction when it writes attention over n queries and m key-value pairs.

TensorFlow/Keras exposes the same split with different letters: query has target length T, value and key have source length S, and the attention mask has shape (B, T, S). That cross-framework agreement is useful because it keeps the rule from sounding like a PyTorch naming quirk. A target/query row asks a question. A source/key-value column is something that can be read.

This matters for the book because it prevents a subtle category mistake. The attention scoring boundary is not automatically:

HiddenSequence -> HiddenSequence

The more honest shape is:

Target positions x Source positions -> attention weights

or, in the current Rust vocabulary:

QuerySequence x KeySequence -> AttentionScores
AttentionWeights x ValueSequence -> AttentionOutput

Self-attention is the special case where the target positions and source positions come from the same hidden sequence:

HiddenSequence -> QuerySequence
HiddenSequence -> KeySequence
HiddenSequence -> ValueSequence

These are parallel projections, not a pipeline where queries turn into keys and keys turn into values. The shared source is what makes the case “self-attention”; the role split is still real after projection.

Cross-attention is the case where the query side and the key-value side come from different sequence objects:

TargetHiddenSequence -> QuerySequence
SourceHiddenSequence -> KeySequence
SourceHiddenSequence -> ValueSequence

The tiny repository does not implement a full cross-attention module yet. But the naming rule should already be clear:

same source for Q, K, V  -> self-attention case
separate query and key-value sources -> cross-attention case
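A sketch of the source rule, using hypothetical Projections and project_rows names: self-attention is the cross-attention projection called with the same rows on both sides, while cross-attention lets the key-value side have a different length from the query side.

```rust
type Rows = Vec<Vec<f32>>;

/// Linear projection of every row: output_row[i] = dot(weights[i], row).
fn project_rows(weights: &[Vec<f32>], rows: &[Vec<f32>]) -> Rows {
    rows.iter()
        .map(|row| {
            weights
                .iter()
                .map(|w| w.iter().zip(row).map(|(a, b)| a * b).sum::<f32>())
                .collect::<Vec<f32>>()
        })
        .collect()
}

/// Hypothetical bundle of the three role-specific projection matrices.
struct Projections {
    query: Rows,
    key: Rows,
    value: Rows,
}

impl Projections {
    /// Cross-attention case: queries from the target sequence,
    /// keys and values from the source sequence.
    fn cross_attention(&self, target: &[Vec<f32>], source: &[Vec<f32>]) -> (Rows, Rows, Rows) {
        (
            project_rows(&self.query, target),
            project_rows(&self.key, source),
            project_rows(&self.value, source),
        )
    }

    /// Self-attention case: one hidden sequence feeds all three projections.
    /// The roles stay separate after projection.
    fn self_attention(&self, hidden: &[Vec<f32>]) -> (Rows, Rows, Rows) {
        self.cross_attention(hidden, hidden)
    }
}
```

With a 2-row target and a 3-row source, the query side has 2 rows and the key-value side has 3, so any score table built from them is 2 x 3.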

Use this Q/K/V source diagnostic before reading a framework call:

Question | Self-attention answer | Cross-attention answer
Which sequence owns the query side? | the same hidden sequence | the target hidden sequence
Which sequence owns the key side? | the same hidden sequence | the source hidden sequence
Which sequence owns the value side? | the same hidden sequence | the source hidden sequence
Which length counts score rows? | target/query length | target/query length
Which length counts score columns? | source/key-value length, equal to target length in the simple self-attention case | source/key-value length, possibly different from target length

This table prevents a common framework-reading mistake. Passing the same hidden sequence into Q, K, and V means the source object is shared. It does not mean the projected query, key, and value roles have become the same role.

When you run the attention example, the first lines now anchor that diagnostic before any probabilities appear:

Q/K/V source diagnostic:
query rows own score rows; key/value rows own score columns
self-attention shares the hidden source before projection; projected roles stay distinct
mask polarity here: true = allowed, false = blocked

Read those four lines before interpreting the attention shape line (attention shape: 2 query positions x 3 key positions). Together they give the learner one inspectable signal for the source-backed rule above: score rows come from the query side, score columns come from the key-value side, and the local mask polarity must be translated before comparing the Rust example with a framework API.

PyTorch and TensorFlow/Keras use different names but expose the same shape split:

Framework cue | Query side | Key-value side | Mask cue
PyTorch | target length L | source length S | attention weights and masks use L x S
TensorFlow/Keras | target length T | source length S | mask shape is (B, T, S)
Rust roadmap | QuerySequence | KeySequence and ValueSequence | AttentionMask says which source positions each query may read

Use the same ledger when reading the Rust types:

Ledger item | Meaning in framework docs | Meaning in this roadmap | Category-shape consequence
target length | PyTorch L, Keras T | number of QuerySequence rows | score rows belong to the query-side object
source length | PyTorch/Keras S | number of KeySequence and ValueSequence rows | score columns belong to the key-value source object
attention mask | PyTorch L x S, Keras (B, T, S) | one permission table from query rows to source rows | the mask is context over a product boundary
attention output | target-side output rows | one AttentionOutput row for each query row | value mixing returns information to the query side

The shape ledger gives a quick sanity check:

score table rows == query positions
score table columns == key-value positions
mask cells == query-position/source-position permissions
output rows == query positions after reading values

If those four statements are not true, the explanation has probably collapsed source ownership, role ownership, or mask context too early.
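The four ledger statements can be turned into a small check. This is a hypothetical helper over raw row counts, useful when comparing the Rust example with framework output; the crate's typed boundaries enforce parts of this themselves (for example, MaskedAttentionScores rejects a mask whose shape differs from the score table).

```rust
/// Which ledger statement failed.
#[derive(Debug, PartialEq)]
enum LedgerError {
    ScoreRows,
    ScoreColumns,
    MaskCells,
    OutputRows,
}

/// Checks the four ledger statements; every row is inspected, not just the first.
fn check_ledger(
    query_len: usize,
    source_len: usize,
    scores: &[Vec<f32>],
    mask: &[Vec<bool>],
    output: &[Vec<f32>],
) -> Result<(), LedgerError> {
    // score table rows == query positions
    if scores.len() != query_len {
        return Err(LedgerError::ScoreRows);
    }
    // score table columns == key-value positions
    if scores.iter().any(|row| row.len() != source_len) {
        return Err(LedgerError::ScoreColumns);
    }
    // mask cells == query-position/source-position permissions
    if mask.len() != query_len || mask.iter().any(|row| row.len() != source_len) {
        return Err(LedgerError::MaskCells);
    }
    // output rows == query positions after reading values
    if output.len() != query_len {
        return Err(LedgerError::OutputRows);
    }
    Ok(())
}
```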

Mask Role Ledger: Permissions, Not Tokens

Framework APIs make the mask shape look like another tensor argument, but the teaching question is more specific:

Which query rows may read which source columns before softmax?

That is why the roadmap names the mask separately from the token sequence, score table, and attention weights. The mask is permission context over the score table.

Mask misreading | Correct local boundary | What to inspect
the mask is a shorter token sequence | AttentionScores x AttentionMask -> AttentionScores | the score table keeps query rows and source columns
the mask directly produces probabilities | AttentionScores -> AttentionWeights still happens after masking | query 0 attends with [0.5, 0.0, 0.5]
the mask is hidden global state | MaskedMultiHeadTransformerBlock : HiddenSequence x AttentionMask -> HiddenSequence | the block boundary keeps the mask visible
a fixed mask means no mask exists | MaskedMultiHeadTransformerBlock[M] : HiddenSequence -> HiddenSequence | the chosen mask M is fixed context for that view

The local rule is:

mask cells select legal score cells;
softmax turns remaining score rows into weights;
weights read value rows.

Do not say “the mask removes tokens” unless you also say which score cells were removed from probability competition. The source sequence still owns the value rows. The mask only says which of those rows each query is allowed to read.
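The three-line rule can be sketched directly. The blocked-score constant has the same value as MASKED_SCORE in src/attention.rs; the helper names and the scores are made up for illustration. With scores [2.0, 5.0, 2.0] and mask [true, false, true], the large middle score is removed from probability competition and the row becomes [0.5, 0.0, 0.5].

```rust
/// Same value as MASKED_SCORE in src/attention.rs: exp() of it underflows to 0.
const MASKED_SCORE: f32 = -1_000_000.0;

/// Mask cells select legal score cells (polarity: true = allowed, false = blocked).
/// Assumes equal shapes; the crate's MaskedAttentionScores rejects mismatches instead.
fn apply_mask(scores: &[Vec<f32>], mask: &[Vec<bool>]) -> Vec<Vec<f32>> {
    scores
        .iter()
        .zip(mask)
        .map(|(score_row, mask_row)| {
            score_row
                .iter()
                .zip(mask_row)
                .map(|(&score, &allowed)| if allowed { score } else { MASKED_SCORE })
                .collect::<Vec<f32>>()
        })
        .collect()
}

/// Softmax turns each remaining score row into weights.
fn softmax_rows(scores: &[Vec<f32>]) -> Vec<Vec<f32>> {
    scores
        .iter()
        .map(|row| {
            let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let exps: Vec<f32> = row.iter().map(|x| (x - max).exp()).collect();
            let total: f32 = exps.iter().sum();
            exps.iter().map(|e| e / total).collect::<Vec<f32>>()
        })
        .collect()
}
```

Without the mask, the same row would put about 0.91 of its weight on the middle position, which is exactly the competition the mask removes.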

The category-theory reading follows the input count. Self-attention can be wrapped inside a shape-preserving block after projection, masking, value mixing, output projection, residual addition, and normalization return to the hidden stream. The core scoring and mixing steps are still product-input morphisms. Cross-attention makes that product input impossible to ignore, because the target side and source side may have different sequence lengths.

When a framework call reports an attention mask of shape L x S, read it as a typed reminder:

for each target position, which source positions may be read?

That is why this roadmap names QuerySequence, KeySequence, ValueSequence, AttentionScores, AttentionMask, AttentionWeights, and AttentionOutput separately. The names keep target-side questions, source-side comparison, and source-side information from collapsing into one raw tensor.
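For self-attention with a causal (left-to-right) restriction, the permission table can be generated rather than typed by hand. This hypothetical causal_mask helper keeps the local polarity (true = allowed). Framework polarity can differ: PyTorch's nn.MultiheadAttention, for example, documents True in a boolean mask as a position that is not allowed to attend, so translate before comparing.

```rust
/// Hypothetical helper: a causal self-attention mask of shape len x len.
/// Row i (a target position) may read source positions 0..=i.
/// Polarity: true = allowed, false = blocked.
fn causal_mask(len: usize) -> Vec<Vec<bool>> {
    (0..len)
        .map(|query| (0..len).map(|source| source <= query).collect::<Vec<bool>>())
        .collect()
}

/// Answers "for this target position, which source positions may be read?"
fn readable_sources(mask: &[Vec<bool>], query: usize) -> Vec<usize> {
    mask[query]
        .iter()
        .enumerate()
        .filter(|(_, allowed)| **allowed)
        .map(|(source, _)| source)
        .collect()
}
```

Every row of a causal mask allows at least its own position, so it also satisfies the AttentionMask rule that no row may block every key.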

ML Concept

Queries ask:

what am I looking for?

Keys answer:

what do I contain?

Values provide:

what information should be mixed?

Category Theory Concept

These are parallel morphisms out of the same object:

HiddenSequence -> QuerySequence
HiddenSequence -> KeySequence
HiddenSequence -> ValueSequence

The current attention example combines query and key roles to produce scores, then uses value roles to produce output vectors.

Design contract:

HiddenSequence -> QuerySequence
HiddenSequence -> KeySequence
HiddenSequence -> ValueSequence

should be three explicit morphisms. A single untyped vector list would make it too easy to pass values into the wrong part of the attention computation.

Step 3: Scaled Dot-Product Attention

The future problem:

Convert query-key similarity into a probability distribution over positions, then use it to mix values.
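A minimal plain-row sketch of the whole computation, assuming non-empty rectangular inputs with matching query and key widths and no mask. The typed version splits these steps across ScaledDotProductScores, the score-to-weight boundary, and value mixing; like ScaledDotProductScores, this sketch divides by the square root of the head dimension.

```rust
/// Single-head scaled dot-product attention over plain rows (no mask).
/// Returns (attention weights, output rows). Assumes non-empty, rectangular
/// inputs, query width == key width, and as many value rows as key rows.
fn scaled_dot_product_attention(
    queries: &[Vec<f32>],
    keys: &[Vec<f32>],
    values: &[Vec<f32>],
) -> (Vec<Vec<f32>>, Vec<Vec<f32>>) {
    let scale = (queries[0].len() as f32).sqrt();

    let weights: Vec<Vec<f32>> = queries
        .iter()
        .map(|query| {
            // Scores for this query: dot(query, key) / sqrt(head dimension).
            let scores: Vec<f32> = keys
                .iter()
                .map(|key| query.iter().zip(key).map(|(a, b)| a * b).sum::<f32>() / scale)
                .collect();
            // Row-wise softmax, shifted by the row maximum for stability.
            let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
            let total: f32 = exps.iter().sum();
            exps.iter().map(|e| e / total).collect::<Vec<f32>>()
        })
        .collect();

    // Each output row is the weighted sum of the value rows.
    let width = values[0].len();
    let outputs: Vec<Vec<f32>> = weights
        .iter()
        .map(|weight_row| {
            let mut out = vec![0.0f32; width];
            for (w, value_row) in weight_row.iter().zip(values) {
                for (o, v) in out.iter_mut().zip(value_row) {
                    *o += w * v;
                }
            }
            out
        })
        .collect();

    (weights, outputs)
}
```

A query of all zeros scores every key equally, so its weights are uniform and its output is the plain average of the value rows.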

Rust Syntax

A typed shape could be:

QuerySequence x KeySequence -> AttentionScores
AttentionScores x AttentionMask -> AttentionScores
AttentionScores -> AttentionWeights
AttentionWeights x ValueSequence -> AttentionOutput
AttentionHeadOutputs -> MultiHeadOutput
MultiHeadOutput -> ProjectedAttentionOutput
HiddenSequence x ProjectedAttentionOutput -> HiddenSequence
HiddenSequence -> HiddenSequence
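The last two arrows can be sketched as follows, assuming equal shapes and a simplified layer normalization without learned gain and bias. The crate's normalization boundary uses a validated NormalizationEpsilon and may differ in detail.

```rust
/// Residual addition: HiddenSequence x ProjectedAttentionOutput -> HiddenSequence.
/// Assumes both tables have the same number of rows and the same width.
fn residual_add(hidden: &[Vec<f32>], update: &[Vec<f32>]) -> Vec<Vec<f32>> {
    hidden
        .iter()
        .zip(update)
        .map(|(h, u)| h.iter().zip(u).map(|(a, b)| a + b).collect::<Vec<f32>>())
        .collect()
}

/// Simplified layer normalization, HiddenSequence -> HiddenSequence:
/// per row, subtract the mean and divide by sqrt(variance + epsilon).
/// No learned gain or bias in this sketch.
fn layer_norm(rows: &[Vec<f32>], epsilon: f32) -> Vec<Vec<f32>> {
    rows.iter()
        .map(|row| {
            let n = row.len() as f32;
            let mean = row.iter().sum::<f32>() / n;
            let variance = row.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / n;
            let denom = (variance + epsilon).sqrt();
            row.iter().map(|x| (x - mean) / denom).collect::<Vec<f32>>()
        })
        .collect()
}
```

Both functions return a table with the same number of rows and the same width they received, which is why the block can keep returning to HiddenSequence.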

Read the current roadmap code through this shape trace:

flowchart LR
    H["HiddenSequence"] --> Q["QuerySequence"]
    H --> K["KeySequence"]
    H --> V["ValueSequence"]
    Q --> S["AttentionScores"]
    K --> S
    S --> M["Masked Scores"]
    Mask["AttentionMask"] --> M
    M --> W["AttentionWeights"]
    W --> O["AttentionOutput"]
    V --> O
    O --> MH["MultiHeadOutput"]
    MH --> P["ProjectedAttentionOutput"]
    H --> R["Residual HiddenSequence"]
    P --> R
    R --> N["Normalized HiddenSequence"]
    N --> FF["FeedForward HiddenSequence"]

The same attention core as a compact rendered math view:

\[
\begin{array}{rcl}
\mathrm{QuerySequence} \times \mathrm{KeySequence} & \to & \mathrm{AttentionScores} \\
\mathrm{AttentionScores} \times \mathrm{AttentionMask} & \to & \mathrm{MaskedScores} \\
\mathrm{MaskedScores} & \to & \mathrm{AttentionWeights} \\
\mathrm{AttentionWeights} \times \mathrm{ValueSequence} & \to & \mathrm{AttentionOutput} \\
\mathrm{AttentionOutput} & \to & \mathrm{ProjectedAttentionOutput} \\
\mathrm{HiddenSequence} \times \mathrm{ProjectedAttentionOutput} & \to & \mathrm{HiddenSequence}
\end{array}
\]

How to read this diagram:

  • every product input means two roles must stay visible,
  • masking happens before row-wise softmax produces weights,
  • value mixing is separate from score calculation,
  • the residual step is the first row here that explicitly returns to HiddenSequence.

What to notice:

Rust reading:
each box is a named type or a named typed boundary in src/attention.rs

ML reading:
scores choose positions, weights mix values, projection and residual return to
the hidden-state width

Category-theory reading:
the middle of attention is a composition with product inputs, and the enclosing
block keeps returning to HiddenSequence
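The return to hidden width can be sketched with two hypothetical helpers: concatenating head rows position by position (AttentionHeadOutputs -> MultiHeadOutput) and a linear map whose number of weight rows fixes the output width (MultiHeadOutput -> ProjectedAttentionOutput). Here two heads of width 1 concatenate to width 2, and a 2 x 2 projection keeps the model width at 2.

```rust
/// AttentionHeadOutputs -> MultiHeadOutput: for each position, concatenate
/// that position's row from every head (head 0 first).
/// Assumes every head has the same number of positions.
fn concat_heads(heads: &[Vec<Vec<f32>>]) -> Vec<Vec<f32>> {
    let positions = heads.first().map_or(0, |head| head.len());
    (0..positions)
        .map(|position| {
            heads
                .iter()
                .flat_map(|head| head[position].iter().copied())
                .collect::<Vec<f32>>()
        })
        .collect()
}

/// MultiHeadOutput -> ProjectedAttentionOutput: output[i] = dot(weights[i], row).
/// The number of weight rows fixes the output width.
fn project_rows(weights: &[Vec<f32>], rows: &[Vec<f32>]) -> Vec<Vec<f32>> {
    rows.iter()
        .map(|row| {
            weights
                .iter()
                .map(|w| w.iter().zip(row).map(|(a, b)| a * b).sum::<f32>())
                .collect::<Vec<f32>>()
        })
        .collect()
}
```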

AttentionWeights should be validated like Distribution, but over sequence positions instead of vocabulary tokens.
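A standalone sketch of what "validated like Distribution" could mean for raw weight rows, assuming the invariants are finite, non-negative entries that sum to one within a tolerance. The crate builds AttentionWeights from Distribution values, whose exact checks may differ. The hypothetical validator returns the (query length, key length) pair on success.

```rust
/// Why a weight table was rejected.
#[derive(Debug, PartialEq)]
enum WeightsError {
    Empty,
    Ragged,
    NotAProbability { row: usize },
}

/// Each row must be a distribution over key positions:
/// every entry finite and non-negative, and the row summing to 1 within `tolerance`.
fn validate_attention_weights(rows: &[Vec<f32>], tolerance: f32) -> Result<(usize, usize), WeightsError> {
    let key_len = rows.first().map(|row| row.len()).ok_or(WeightsError::Empty)?;
    if key_len == 0 {
        return Err(WeightsError::Empty);
    }
    for (index, row) in rows.iter().enumerate() {
        if row.len() != key_len {
            return Err(WeightsError::Ragged);
        }
        let entries_ok = row.iter().all(|w| w.is_finite() && *w >= 0.0);
        let sum: f32 = row.iter().sum();
        if !entries_ok || (sum - 1.0).abs() > tolerance {
            return Err(WeightsError::NotAProbability { row: index });
        }
    }
    Ok((rows.len(), key_len))
}
```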

The current roadmap code implements the query-key score boundary, the mask boundary, the score-to-weight boundary, the value-mixing boundary, the multi-head concatenation boundary, the output projection boundary, and the residual addition and normalization boundaries:

Source snapshot: src/attention.rs
//! Tiny typed attention boundary for the Transformer roadmap.
//!
//! This module does not implement a full Transformer. It makes the first small
//! attention-specific shapes explicit:
//!
//! - projected queries and keys become query-by-key scores,
//! - masks turn illegal score positions into negligible softmax inputs,
//! - query-by-key scores become row-wise attention probabilities,
//! - attention probabilities mix value vectors into output vectors,
//! - multiple head outputs concatenate into a multi-head output,
//! - the concatenated heads project back into a hidden sequence width,
//! - residual addition preserves the hidden sequence boundary,
//! - layer normalization preserves the hidden sequence boundary,
//! - a position-wise feed-forward map preserves the hidden sequence boundary,
//! - positional encoding adds position information while preserving shape,
//! - a single-head block sketch composes those boundaries end to end,
//! - a masked multi-head block accepts attention masks at the block boundary,
//! - a structured parameter object owns position, block, and readout pieces,
//! - a training-state object owns parameters, learning rate, and step count,
//! - a composed block train step updates readout, feed-forward,
//!   normalization, and attention projection parameters from sequence targets.

use crate::category::{Morphism, StepCount};
use crate::domain::{
    Distribution, LearningRate, Logits, Loss, ModelDimension, Product, TokenSequence, Vector,
    VocabSize,
};
use crate::error::{CtError, CtResult};
use crate::ml::Softmax;

const MASKED_SCORE: f32 = -1_000_000.0;

/// Number of positions in a sequence.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SequenceLength(usize);

impl SequenceLength {
    pub fn new(value: usize) -> CtResult<Self> {
        if value == 0 {
            return Err(CtError::EmptyInput("sequence length"));
        }

        Ok(Self(value))
    }

    pub fn value(&self) -> usize {
        self.0
    }
}

/// Number of parallel attention heads.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HeadCount(usize);

impl HeadCount {
    pub fn new(value: usize) -> CtResult<Self> {
        if value == 0 {
            return Err(CtError::EmptyInput("head count"));
        }

        Ok(Self(value))
    }

    pub fn value(&self) -> usize {
        self.0
    }
}

/// Width of one attention head.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HeadDimension(usize);

impl HeadDimension {
    pub fn new(value: usize) -> CtResult<Self> {
        if value == 0 {
            return Err(CtError::EmptyInput("head dimension"));
        }

        Ok(Self(value))
    }

    pub fn value(&self) -> usize {
        self.0
    }
}

/// Positive stabilizer used in layer normalization.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct NormalizationEpsilon(f32);

impl NormalizationEpsilon {
    pub fn new(value: f32) -> CtResult<Self> {
        if !value.is_finite() || value <= 0.0 {
            return Err(CtError::ShapeMismatch {
                op: "normalization epsilon",
                expected: "positive finite epsilon".to_string(),
                got: format!("epsilon {value}"),
            });
        }

        Ok(Self(value))
    }

    pub fn value(&self) -> f32 {
        self.0
    }
}

/// Projected query vectors for one attention head.
#[derive(Debug, Clone, PartialEq)]
pub struct QuerySequence {
    sequence_len: SequenceLength,
    head_dimension: HeadDimension,
    rows: Vec<Vector>,
}

impl QuerySequence {
    pub fn new(rows: Vec<Vec<f32>>) -> CtResult<Self> {
        let matrix = AttentionVectorRows::new("query sequence", rows)?;

        Ok(Self {
            sequence_len: matrix.sequence_len,
            head_dimension: matrix.head_dimension,
            rows: matrix.rows,
        })
    }

    pub fn sequence_len(&self) -> SequenceLength {
        self.sequence_len
    }

    pub fn head_dimension(&self) -> HeadDimension {
        self.head_dimension
    }

    pub fn rows(&self) -> &[Vector] {
        &self.rows
    }
}

/// Projected key vectors for one attention head.
#[derive(Debug, Clone, PartialEq)]
pub struct KeySequence {
    sequence_len: SequenceLength,
    head_dimension: HeadDimension,
    rows: Vec<Vector>,
}

impl KeySequence {
    pub fn new(rows: Vec<Vec<f32>>) -> CtResult<Self> {
        let matrix = AttentionVectorRows::new("key sequence", rows)?;

        Ok(Self {
            sequence_len: matrix.sequence_len,
            head_dimension: matrix.head_dimension,
            rows: matrix.rows,
        })
    }

    pub fn sequence_len(&self) -> SequenceLength {
        self.sequence_len
    }

    pub fn head_dimension(&self) -> HeadDimension {
        self.head_dimension
    }

    pub fn rows(&self) -> &[Vector] {
        &self.rows
    }
}

/// Projected value vectors for one attention head.
#[derive(Debug, Clone, PartialEq)]
pub struct ValueSequence {
    sequence_len: SequenceLength,
    head_dimension: HeadDimension,
    rows: Vec<Vector>,
}

impl ValueSequence {
    pub fn new(rows: Vec<Vec<f32>>) -> CtResult<Self> {
        let matrix = AttentionVectorRows::new("value sequence", rows)?;

        Ok(Self {
            sequence_len: matrix.sequence_len,
            head_dimension: matrix.head_dimension,
            rows: matrix.rows,
        })
    }

    pub fn sequence_len(&self) -> SequenceLength {
        self.sequence_len
    }

    pub fn head_dimension(&self) -> HeadDimension {
        self.head_dimension
    }

    pub fn rows(&self) -> &[Vector] {
        &self.rows
    }
}

/// Hidden vectors over sequence positions.
#[derive(Debug, Clone, PartialEq)]
pub struct HiddenSequence {
    sequence_len: SequenceLength,
    model_dimension: ModelDimension,
    rows: Vec<Vector>,
}

impl HiddenSequence {
    pub fn new(rows: Vec<Vec<f32>>) -> CtResult<Self> {
        let matrix = AttentionVectorRows::new("hidden sequence", rows)?;

        Ok(Self {
            sequence_len: matrix.sequence_len,
            model_dimension: ModelDimension::new(matrix.head_dimension.value())?,
            rows: matrix.rows,
        })
    }

    pub fn sequence_len(&self) -> SequenceLength {
        self.sequence_len
    }

    pub fn model_dimension(&self) -> ModelDimension {
        self.model_dimension
    }

    pub fn rows(&self) -> &[Vector] {
        &self.rows
    }
}

/// A finite table of position vectors added to hidden states.
#[derive(Debug, Clone, PartialEq)]
pub struct PositionalEncoding {
    max_sequence_len: SequenceLength,
    model_dimension: ModelDimension,
    rows: Vec<Vector>,
}

impl PositionalEncoding {
    pub fn new(rows: Vec<Vec<f32>>) -> CtResult<Self> {
        let matrix = AttentionVectorRows::new("positional encoding", rows)?;

        Ok(Self {
            max_sequence_len: matrix.sequence_len,
            model_dimension: ModelDimension::new(matrix.head_dimension.value())?,
            rows: matrix.rows,
        })
    }

    pub fn max_sequence_len(&self) -> SequenceLength {
        self.max_sequence_len
    }

    pub fn model_dimension(&self) -> ModelDimension {
        self.model_dimension
    }
}

#[derive(Debug, Clone, PartialEq)]
struct AttentionVectorRows {
    sequence_len: SequenceLength,
    head_dimension: HeadDimension,
    rows: Vec<Vector>,
}

impl AttentionVectorRows {
    fn new(kind: &'static str, rows: Vec<Vec<f32>>) -> CtResult<Self> {
        if rows.is_empty() {
            return Err(CtError::EmptyInput(kind));
        }

        let head_dimension = rows[0].len();

        if head_dimension == 0 {
            return Err(CtError::EmptyInput("attention vector row"));
        }

        for row in &rows {
            if row.len() != head_dimension {
                return Err(CtError::ShapeMismatch {
                    op: kind,
                    expected: format!("all rows have {head_dimension} columns"),
                    got: format!("row with {} columns", row.len()),
                });
            }

            if row.iter().any(|value| !value.is_finite()) {
                return Err(CtError::ShapeMismatch {
                    op: kind,
                    expected: "all vector values are finite".to_string(),
                    got: "non-finite vector value".to_string(),
                });
            }
        }

        Ok(Self {
            sequence_len: SequenceLength::new(rows.len())?,
            head_dimension: HeadDimension::new(head_dimension)?,
            rows: rows.into_iter().map(Vector::new).collect(),
        })
    }
}

/// Query-by-key scores before row-wise softmax.
#[derive(Debug, Clone, PartialEq)]
pub struct AttentionScores {
    query_len: SequenceLength,
    key_len: SequenceLength,
    rows: Vec<Logits>,
}

impl AttentionScores {
    pub fn new(rows: Vec<Vec<f32>>) -> CtResult<Self> {
        if rows.is_empty() {
            return Err(CtError::EmptyInput("attention scores"));
        }

        let key_len = rows[0].len();

        if key_len == 0 {
            return Err(CtError::EmptyInput("attention score row"));
        }

        for row in &rows {
            if row.len() != key_len {
                return Err(CtError::ShapeMismatch {
                    op: "attention scores",
                    expected: format!("all rows have {key_len} columns"),
                    got: format!("row with {} columns", row.len()),
                });
            }

            if row.iter().any(|value| !value.is_finite()) {
                return Err(CtError::ShapeMismatch {
                    op: "attention scores",
                    expected: "all score values are finite".to_string(),
                    got: "non-finite score value".to_string(),
                });
            }
        }

        Ok(Self {
            query_len: SequenceLength::new(rows.len())?,
            key_len: SequenceLength::new(key_len)?,
            rows: rows.into_iter().map(Logits::new).collect(),
        })
    }

    pub fn query_len(&self) -> SequenceLength {
        self.query_len
    }

    pub fn key_len(&self) -> SequenceLength {
        self.key_len
    }

    pub fn rows(&self) -> &[Logits] {
        &self.rows
    }
}

/// Allowed query-by-key positions before attention softmax.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AttentionMask {
    query_len: SequenceLength,
    key_len: SequenceLength,
    rows: Vec<Vec<bool>>,
}

impl AttentionMask {
    pub fn new(rows: Vec<Vec<bool>>) -> CtResult<Self> {
        if rows.is_empty() {
            return Err(CtError::EmptyInput("attention mask"));
        }

        let key_len = rows[0].len();

        if key_len == 0 {
            return Err(CtError::EmptyInput("attention mask row"));
        }

        for row in &rows {
            if row.len() != key_len {
                return Err(CtError::ShapeMismatch {
                    op: "attention mask",
                    expected: format!("all rows have {key_len} columns"),
                    got: format!("row with {} columns", row.len()),
                });
            }

            if !row.iter().any(|allowed| *allowed) {
                return Err(CtError::EmptyInput("attention mask row allows no keys"));
            }
        }

        Ok(Self {
            query_len: SequenceLength::new(rows.len())?,
            key_len: SequenceLength::new(key_len)?,
            rows,
        })
    }

    pub fn query_len(&self) -> SequenceLength {
        self.query_len
    }

    pub fn key_len(&self) -> SequenceLength {
        self.key_len
    }

    pub fn rows(&self) -> &[Vec<bool>] {
        &self.rows
    }
}

/// Computes scaled query-key dot-product scores.
#[derive(Debug, Clone)]
pub struct ScaledDotProductScores;

impl Morphism<Product<QuerySequence, KeySequence>, AttentionScores> for ScaledDotProductScores {
    fn name(&self) -> &'static str {
        "scaled_dot_product_scores"
    }

    fn apply(&self, input: Product<QuerySequence, KeySequence>) -> CtResult<AttentionScores> {
        let (queries, keys) = input.into_parts();
        let query_dimension = queries.head_dimension();
        let key_dimension = keys.head_dimension();

        if query_dimension != key_dimension {
            return Err(CtError::ShapeMismatch {
                op: "scaled dot-product attention scores",
                expected: format!("query head dimension {}", query_dimension.value()),
                got: format!("key head dimension {}", key_dimension.value()),
            });
        }

        let scale = (query_dimension.value() as f32).sqrt();
        let rows = queries
            .rows()
            .iter()
            .map(|query| {
                keys.rows()
                    .iter()
                    .map(|key| dot_product(query.as_slice(), key.as_slice()) / scale)
                    .collect::<Vec<_>>()
            })
            .collect::<Vec<_>>();

        AttentionScores::new(rows)
    }
}

fn dot_product(left: &[f32], right: &[f32]) -> f32 {
    left.iter()
        .zip(right.iter())
        .map(|(left, right)| left * right)
        .sum()
}

/// Applies a boolean attention mask to score rows before softmax.
#[derive(Debug, Clone)]
pub struct MaskedAttentionScores;

impl Morphism<Product<AttentionScores, AttentionMask>, AttentionScores> for MaskedAttentionScores {
    fn name(&self) -> &'static str {
        "masked_attention_scores"
    }

    fn apply(&self, input: Product<AttentionScores, AttentionMask>) -> CtResult<AttentionScores> {
        let (scores, mask) = input.into_parts();

        if scores.query_len() != mask.query_len() || scores.key_len() != mask.key_len() {
            return Err(CtError::ShapeMismatch {
                op: "masked attention scores",
                expected: format!(
                    "{} query rows x {} key columns",
                    scores.query_len().value(),
                    scores.key_len().value()
                ),
                got: format!(
                    "{} query rows x {} key columns",
                    mask.query_len().value(),
                    mask.key_len().value()
                ),
            });
        }

        let rows = scores
            .rows()
            .iter()
            .zip(mask.rows())
            .map(|(score_row, mask_row)| {
                score_row
                    .as_slice()
                    .iter()
                    .zip(mask_row)
                    .map(|(score, allowed)| if *allowed { *score } else { MASKED_SCORE })
                    .collect::<Vec<_>>()
            })
            .collect::<Vec<_>>();

        AttentionScores::new(rows)
    }
}

/// Row-wise attention probabilities over key positions.
#[derive(Debug, Clone, PartialEq)]
pub struct AttentionWeights {
    query_len: SequenceLength,
    key_len: SequenceLength,
    rows: Vec<Distribution>,
}

impl AttentionWeights {
    pub fn new(rows: Vec<Distribution>) -> CtResult<Self> {
        if rows.is_empty() {
            return Err(CtError::EmptyInput("attention weights"));
        }

        let key_len = rows[0].as_slice().len();

        if key_len == 0 {
            return Err(CtError::EmptyInput("attention weight row"));
        }

        for row in &rows {
            if row.as_slice().len() != key_len {
                return Err(CtError::ShapeMismatch {
                    op: "attention weights",
                    expected: format!("all rows have {key_len} columns"),
                    got: format!("row with {} columns", row.as_slice().len()),
                });
            }
        }

        Ok(Self {
            query_len: SequenceLength::new(rows.len())?,
            key_len: SequenceLength::new(key_len)?,
            rows,
        })
    }

    pub fn query_len(&self) -> SequenceLength {
        self.query_len
    }

    pub fn key_len(&self) -> SequenceLength {
        self.key_len
    }

    pub fn rows(&self) -> &[Distribution] {
        &self.rows
    }
}

/// Weighted value vectors, one output row per query position.
#[derive(Debug, Clone, PartialEq)]
pub struct AttentionOutput {
    sequence_len: SequenceLength,
    head_dimension: HeadDimension,
    rows: Vec<Vector>,
}

impl AttentionOutput {
    pub fn new(rows: Vec<Vec<f32>>) -> CtResult<Self> {
        let matrix = AttentionVectorRows::new("attention output", rows)?;

        Ok(Self {
            sequence_len: matrix.sequence_len,
            head_dimension: matrix.head_dimension,
            rows: matrix.rows,
        })
    }

    pub fn sequence_len(&self) -> SequenceLength {
        self.sequence_len
    }

    pub fn head_dimension(&self) -> HeadDimension {
        self.head_dimension
    }

    pub fn rows(&self) -> &[Vector] {
        &self.rows
    }
}
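
// Hypothetical sketch, not part of this crate's API: how attention weights and
// value vectors could combine into attention output rows. Each output row is
// the weighted sum of all value rows, using that query's weight row. Plain
// `Vec<Vec<f32>>` stands in for the crate's validated `AttentionWeights`,
// `ValueSequence`, and `AttentionOutput` types.
#[allow(dead_code)]
mod weighted_values_sketch {
    /// Returns one row per query: the sum over keys j of weights[query][j] * values[j].
    pub fn weighted_values(weights: &[Vec<f32>], values: &[Vec<f32>]) -> Vec<Vec<f32>> {
        let value_dimension = values.first().map_or(0, |row| row.len());

        weights
            .iter()
            .map(|weight_row| {
                let mut output = vec![0.0_f32; value_dimension];
                for (weight, value_row) in weight_row.iter().zip(values) {
                    for (slot, value) in output.iter_mut().zip(value_row) {
                        *slot += weight * value;
                    }
                }
                output
            })
            .collect()
    }
}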

/// Validated collection of single-head attention outputs.
#[derive(Debug, Clone, PartialEq)]
pub struct AttentionHeadOutputs {
    head_count: HeadCount,
    sequence_len: SequenceLength,
    head_dimension: HeadDimension,
    heads: Vec<AttentionOutput>,
}

impl AttentionHeadOutputs {
    pub fn new(heads: Vec<AttentionOutput>) -> CtResult<Self> {
        if heads.is_empty() {
            return Err(CtError::EmptyInput("attention head outputs"));
        }

        let sequence_len = heads[0].sequence_len();
        let head_dimension = heads[0].head_dimension();

        for head in &heads {
            if head.sequence_len() != sequence_len {
                return Err(CtError::ShapeMismatch {
                    op: "attention head outputs",
                    expected: format!("all heads have {} sequence rows", sequence_len.value()),
                    got: format!("head with {} sequence rows", head.sequence_len().value()),
                });
            }

            if head.head_dimension() != head_dimension {
                return Err(CtError::ShapeMismatch {
                    op: "attention head outputs",
                    expected: format!("all heads have dimension {}", head_dimension.value()),
                    got: format!("head dimension {}", head.head_dimension().value()),
                });
            }
        }

        Ok(Self {
            head_count: HeadCount::new(heads.len())?,
            sequence_len,
            head_dimension,
            heads,
        })
    }

    pub fn head_count(&self) -> HeadCount {
        self.head_count
    }

    pub fn sequence_len(&self) -> SequenceLength {
        self.sequence_len
    }

    pub fn head_dimension(&self) -> HeadDimension {
        self.head_dimension
    }

    pub fn heads(&self) -> &[AttentionOutput] {
        &self.heads
    }
}

/// Concatenated output of several attention heads; each row holds
/// `head_count * head_dimension` values, the heads' rows joined end to end.
#[derive(Debug, Clone, PartialEq)]
pub struct MultiHeadOutput {
    sequence_len: SequenceLength,
    head_count: HeadCount,
    head_dimension: HeadDimension,
    model_dimension: ModelDimension,
    rows: Vec<Vector>,
}

impl MultiHeadOutput {
    fn new(
        rows: Vec<Vec<f32>>,
        head_count: HeadCount,
        head_dimension: HeadDimension,
    ) -> CtResult<Self> {
        let matrix = AttentionVectorRows::new("multi-head output", rows)?;
        let expected_dimension = head_count.value() * head_dimension.value();

        if matrix.head_dimension.value() != expected_dimension {
            return Err(CtError::ShapeMismatch {
                op: "multi-head output",
                expected: format!("row dimension {expected_dimension}"),
                got: format!("row dimension {}", matrix.head_dimension.value()),
            });
        }

        Ok(Self {
            sequence_len: matrix.sequence_len,
            head_count,
            head_dimension,
            model_dimension: ModelDimension::new(expected_dimension)?,
            rows: matrix.rows,
        })
    }

    pub fn sequence_len(&self) -> SequenceLength {
        self.sequence_len
    }

    pub fn head_count(&self) -> HeadCount {
        self.head_count
    }

    pub fn head_dimension(&self) -> HeadDimension {
        self.head_dimension
    }

    pub fn model_dimension(&self) -> ModelDimension {
        self.model_dimension
    }

    pub fn rows(&self) -> &[Vector] {
        &self.rows
    }
}
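
// Hypothetical sketch, not part of this crate's API: head concatenation as the
// `MultiHeadOutput` shape check describes it. For every sequence position, the
// rows of all heads are joined end to end, so each output row has
// head_count * head_dimension entries. The function name is illustrative.
#[allow(dead_code)]
mod concatenate_heads_sketch {
    /// `heads[h][position]` is head h's output row at that position.
    pub fn concatenate_heads(heads: &[Vec<Vec<f32>>]) -> Vec<Vec<f32>> {
        let sequence_len = heads.first().map_or(0, |head| head.len());

        (0..sequence_len)
            .map(|position| {
                heads
                    .iter()
                    .flat_map(|head| head[position].iter().copied())
                    .collect::<Vec<f32>>()
            })
            .collect()
    }
}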

/// Output sequence after the multi-head output projection.
#[derive(Debug, Clone, PartialEq)]
pub struct ProjectedAttentionOutput {
    sequence_len: SequenceLength,
    model_dimension: ModelDimension,
    rows: Vec<Vector>,
}

impl ProjectedAttentionOutput {
    pub fn new(rows: Vec<Vec<f32>>) -> CtResult<Self> {
        let matrix = AttentionVectorRows::new("projected attention output", rows)?;

        Ok(Self {
            sequence_len: matrix.sequence_len,
            model_dimension: ModelDimension::new(matrix.head_dimension.value())?,
            rows: matrix.rows,
        })
    }

    pub fn sequence_len(&self) -> SequenceLength {
        self.sequence_len
    }

    pub fn model_dimension(&self) -> ModelDimension {
        self.model_dimension
    }

    pub fn rows(&self) -> &[Vector] {
        &self.rows
    }
}

/// Vocabulary logits for every position in a hidden sequence.
#[derive(Debug, Clone, PartialEq)]
pub struct SequenceLogits {
    sequence_len: SequenceLength,
    vocab_size: VocabSize,
    rows: Vec<Logits>,
}

impl SequenceLogits {
    pub fn new(rows: Vec<Vec<f32>>) -> CtResult<Self> {
        if rows.is_empty() {
            return Err(CtError::EmptyInput("sequence logits"));
        }

        let vocab_size = rows[0].len();

        if vocab_size == 0 {
            return Err(CtError::EmptyInput("sequence logits row"));
        }

        for row in &rows {
            if row.len() != vocab_size {
                return Err(CtError::ShapeMismatch {
                    op: "sequence logits",
                    expected: format!("all rows have {vocab_size} columns"),
                    got: format!("row with {} columns", row.len()),
                });
            }

            if row.iter().any(|value| !value.is_finite()) {
                return Err(CtError::ShapeMismatch {
                    op: "sequence logits",
                    expected: "all logit values are finite".to_string(),
                    got: "non-finite logit value".to_string(),
                });
            }
        }

        Ok(Self {
            sequence_len: SequenceLength::new(rows.len())?,
            vocab_size: VocabSize::new(vocab_size)?,
            rows: rows.into_iter().map(Logits::new).collect(),
        })
    }

    pub fn sequence_len(&self) -> SequenceLength {
        self.sequence_len
    }

    pub fn vocab_size(&self) -> VocabSize {
        self.vocab_size
    }

    pub fn rows(&self) -> &[Logits] {
        &self.rows
    }
}

/// Learned language-model readout applied to each hidden position.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerReadout {
    input_dimension: ModelDimension,
    vocab_size: VocabSize,
    weight: Vec<Vec<f32>>,
    bias: Vec<f32>,
}

impl TransformerReadout {
    pub fn new(weight: Vec<Vec<f32>>, bias: Vec<f32>) -> CtResult<Self> {
        let (input_dimension, output_dimension) =
            validate_linear_parts("transformer readout", &weight, &bias)?;

        Ok(Self {
            input_dimension,
            vocab_size: VocabSize::new(output_dimension.value())?,
            weight,
            bias,
        })
    }

    pub fn input_dimension(&self) -> ModelDimension {
        self.input_dimension
    }

    pub fn vocab_size(&self) -> VocabSize {
        self.vocab_size
    }

    pub fn weight(&self) -> &[Vec<f32>] {
        &self.weight
    }

    pub fn bias(&self) -> &[f32] {
        &self.bias
    }
}

/// Learned output projection after head concatenation.
#[derive(Debug, Clone, PartialEq)]
pub struct AttentionOutputProjection {
    input_dimension: ModelDimension,
    output_dimension: ModelDimension,
    weight: Vec<Vec<f32>>,
    bias: Vec<f32>,
}

impl AttentionOutputProjection {
    pub fn new(weight: Vec<Vec<f32>>, bias: Vec<f32>) -> CtResult<Self> {
        if weight.is_empty() {
            return Err(CtError::EmptyInput("attention output projection weight"));
        }

        if bias.is_empty() {
            return Err(CtError::EmptyInput("attention output projection bias"));
        }

        let output_dimension = bias.len();

        if bias.iter().any(|value| !value.is_finite()) {
            return Err(CtError::ShapeMismatch {
                op: "attention output projection",
                expected: "finite bias values".to_string(),
                got: "non-finite bias value".to_string(),
            });
        }

        for row in &weight {
            if row.len() != output_dimension {
                return Err(CtError::ShapeMismatch {
                    op: "attention output projection",
                    expected: format!("weight rows have {output_dimension} columns"),
                    got: format!("weight row with {} columns", row.len()),
                });
            }

            if row.iter().any(|value| !value.is_finite()) {
                return Err(CtError::ShapeMismatch {
                    op: "attention output projection",
                    expected: "finite weight values".to_string(),
                    got: "non-finite weight value".to_string(),
                });
            }
        }

        Ok(Self {
            input_dimension: ModelDimension::new(weight.len())?,
            output_dimension: ModelDimension::new(output_dimension)?,
            weight,
            bias,
        })
    }

    pub fn input_dimension(&self) -> ModelDimension {
        self.input_dimension
    }

    pub fn output_dimension(&self) -> ModelDimension {
        self.output_dimension
    }

    pub fn weight(&self) -> &[Vec<f32>] {
        &self.weight
    }

    pub fn bias(&self) -> &[f32] {
        &self.bias
    }
}

/// Scale, shift, and epsilon parameters for layer normalization.
#[derive(Debug, Clone, PartialEq)]
pub struct LayerNormParameters {
    model_dimension: ModelDimension,
    scale: Vec<f32>,
    shift: Vec<f32>,
    epsilon: NormalizationEpsilon,
}

impl LayerNormParameters {
    pub fn new(scale: Vec<f32>, shift: Vec<f32>, epsilon: NormalizationEpsilon) -> CtResult<Self> {
        if scale.is_empty() {
            return Err(CtError::EmptyInput("layer norm scale"));
        }

        if shift.is_empty() {
            return Err(CtError::EmptyInput("layer norm shift"));
        }

        if scale.len() != shift.len() {
            return Err(CtError::ShapeMismatch {
                op: "layer norm parameters",
                expected: format!("scale and shift length {}", scale.len()),
                got: format!("shift length {}", shift.len()),
            });
        }

        if scale.iter().any(|value| !value.is_finite()) {
            return Err(CtError::ShapeMismatch {
                op: "layer norm parameters",
                expected: "finite scale values".to_string(),
                got: "non-finite scale value".to_string(),
            });
        }

        if shift.iter().any(|value| !value.is_finite()) {
            return Err(CtError::ShapeMismatch {
                op: "layer norm parameters",
                expected: "finite shift values".to_string(),
                got: "non-finite shift value".to_string(),
            });
        }

        Ok(Self {
            model_dimension: ModelDimension::new(scale.len())?,
            scale,
            shift,
            epsilon,
        })
    }

    pub fn identity(model_dimension: ModelDimension) -> Self {
        Self {
            model_dimension,
            scale: vec![1.0; model_dimension.value()],
            shift: vec![0.0; model_dimension.value()],
            epsilon: NormalizationEpsilon(1e-5),
        }
    }

    pub fn model_dimension(&self) -> ModelDimension {
        self.model_dimension
    }

    pub fn scale(&self) -> &[f32] {
        &self.scale
    }

    pub fn shift(&self) -> &[f32] {
        &self.shift
    }

    pub fn epsilon(&self) -> NormalizationEpsilon {
        self.epsilon
    }
}

/// Layer normalization over each hidden vector independently.
#[derive(Debug, Clone, PartialEq)]
pub struct LayerNormalization {
    parameters: LayerNormParameters,
}

impl LayerNormalization {
    pub fn new(parameters: LayerNormParameters) -> Self {
        Self { parameters }
    }

    pub fn model_dimension(&self) -> ModelDimension {
        self.parameters.model_dimension()
    }

    pub fn parameters(&self) -> &LayerNormParameters {
        &self.parameters
    }
}
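
// Hypothetical sketch, not part of this crate's API: the usual layer
// normalization formula applied to one hidden vector,
//     y_i = scale_i * (x_i - mean) / sqrt(variance + epsilon) + shift_i.
// Assumption: population variance (divide by n), which is common; this listing
// does not show how the crate computes it. With the values from
// `LayerNormParameters::identity` (scale 1, shift 0, epsilon 1e-5) the output
// has mean close to 0 and variance close to 1.
#[allow(dead_code)]
mod layer_norm_sketch {
    pub fn layer_norm(x: &[f32], scale: &[f32], shift: &[f32], epsilon: f32) -> Vec<f32> {
        let n = x.len() as f32;
        let mean = x.iter().sum::<f32>() / n;
        let variance = x.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
        let inv_std = 1.0 / (variance + epsilon).sqrt();

        x.iter()
            .zip(scale)
            .zip(shift)
            .map(|((v, g), b)| g * (v - mean) * inv_std + b)
            .collect()
    }
}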

/// Per-position feed-forward intermediates recorded during a training forward pass.
#[derive(Debug, Clone, PartialEq)]
struct FeedForwardRowCache {
    input: Vec<f32>,
    pre_activation: Vec<f32>,
    activation: Vec<f32>,
    output: Vec<f32>,
}

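/// Queries, keys, values, attention weights, and output recorded for one head
/// during a training forward pass.
#[derive(Debug, Clone, PartialEq)]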
struct AttentionHeadTrainingCache {
    queries: QuerySequence,
    keys: KeySequence,
    values: ValueSequence,
    weights: AttentionWeights,
    output: AttentionOutput,
}

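/// Intermediate states of one masked block forward pass, kept so the training
/// step can reuse them instead of recomputing them.
#[derive(Debug, Clone, PartialEq)]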
struct MaskedBlockTrainingCache {
    output: HiddenSequence,
    with_feed_forward: HiddenSequence,
    with_attention: HiddenSequence,
    multi_head_output: MultiHeadOutput,
    attention_heads: Vec<AttentionHeadTrainingCache>,
    feed_forward_rows: Vec<FeedForwardRowCache>,
}

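/// Position-wise two-layer feed-forward sublayer, applied independently to each sequence position.
/// The output dimension must equal the input dimension so the residual connection around it type-checks.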
#[derive(Debug, Clone, PartialEq)]
pub struct PositionWiseFeedForward {
    input_dimension: ModelDimension,
    hidden_dimension: ModelDimension,
    output_dimension: ModelDimension,
    first_weight: Vec<Vec<f32>>,
    first_bias: Vec<f32>,
    second_weight: Vec<Vec<f32>>,
    second_bias: Vec<f32>,
}

impl PositionWiseFeedForward {
    pub fn new(
        first_weight: Vec<Vec<f32>>,
        first_bias: Vec<f32>,
        second_weight: Vec<Vec<f32>>,
        second_bias: Vec<f32>,
    ) -> CtResult<Self> {
        let (input_dimension, hidden_dimension) = validate_linear_parts(
            "position-wise feed-forward first layer",
            &first_weight,
            &first_bias,
        )?;
        let (second_input_dimension, output_dimension) = validate_linear_parts(
            "position-wise feed-forward second layer",
            &second_weight,
            &second_bias,
        )?;

        if second_input_dimension != hidden_dimension {
            return Err(CtError::ShapeMismatch {
                op: "position-wise feed-forward",
                expected: format!("second input dimension {}", hidden_dimension.value()),
                got: format!("second input dimension {}", second_input_dimension.value()),
            });
        }

        if output_dimension != input_dimension {
            return Err(CtError::ShapeMismatch {
                op: "position-wise feed-forward",
                expected: format!("output dimension {}", input_dimension.value()),
                got: format!("output dimension {}", output_dimension.value()),
            });
        }

        Ok(Self {
            input_dimension,
            hidden_dimension,
            output_dimension,
            first_weight,
            first_bias,
            second_weight,
            second_bias,
        })
    }

    pub fn input_dimension(&self) -> ModelDimension {
        self.input_dimension
    }

    pub fn hidden_dimension(&self) -> ModelDimension {
        self.hidden_dimension
    }

    pub fn output_dimension(&self) -> ModelDimension {
        self.output_dimension
    }

    pub fn first_weight(&self) -> &[Vec<f32>] {
        &self.first_weight
    }

    pub fn first_bias(&self) -> &[f32] {
        &self.first_bias
    }

    pub fn second_weight(&self) -> &[Vec<f32>] {
        &self.second_weight
    }

    pub fn second_bias(&self) -> &[f32] {
        &self.second_bias
    }
}
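
// Hypothetical sketch, not part of this crate's API: the position-wise
// feed-forward map for one row, second(activation(first(x))). Assumptions:
// ReLU as the activation (the original Transformer's choice; the crate's
// activation is not shown here) and the weight layout that
// `AttentionOutputProjection::new` validates: one weight row per input
// feature, one column per output feature.
#[allow(dead_code)]
mod feed_forward_sketch {
    /// One affine map: output[j] = bias[j] + sum over i of input[i] * weight[i][j].
    fn affine(input: &[f32], weight: &[Vec<f32>], bias: &[f32]) -> Vec<f32> {
        let mut output = bias.to_vec();
        for (x_i, weight_row) in input.iter().zip(weight) {
            for (slot, w_ij) in output.iter_mut().zip(weight_row) {
                *slot += x_i * w_ij;
            }
        }
        output
    }

    /// Feed-forward for one position: affine, ReLU (assumed), affine.
    pub fn feed_forward_row(
        input: &[f32],
        first_weight: &[Vec<f32>],
        first_bias: &[f32],
        second_weight: &[Vec<f32>],
        second_bias: &[f32],
    ) -> Vec<f32> {
        let activation: Vec<f32> = affine(input, first_weight, first_bias)
            .into_iter()
            .map(|value| value.max(0.0))
            .collect();
        affine(&activation, second_weight, second_bias)
    }
}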

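/// Shared validation and row projection behind `HiddenToQuery`, `HiddenToKey`, and `HiddenToValue`.
#[derive(Debug, Clone, PartialEq)]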
struct HiddenProjection {
    op: &'static str,
    input_dimension: ModelDimension,
    head_dimension: HeadDimension,
    weight: Vec<Vec<f32>>,
    bias: Vec<f32>,
}

impl HiddenProjection {
    fn new(op: &'static str, weight: Vec<Vec<f32>>, bias: Vec<f32>) -> CtResult<Self> {
        let (input_dimension, output_dimension) = validate_linear_parts(op, &weight, &bias)?;

        Ok(Self {
            op,
            input_dimension,
            head_dimension: HeadDimension::new(output_dimension.value())?,
            weight,
            bias,
        })
    }

    fn project(&self, input: &HiddenSequence) -> CtResult<Vec<Vec<f32>>> {
        if input.model_dimension() != self.input_dimension {
            return Err(CtError::ShapeMismatch {
                op: self.op,
                expected: format!("input dimension {}", self.input_dimension.value()),
                got: format!("input dimension {}", input.model_dimension().value()),
            });
        }

        Ok(input
            .rows()
            .iter()
            .map(|row| project_row(row.as_slice(), &self.weight, &self.bias))
            .collect::<Vec<_>>())
    }

    fn input_dimension(&self) -> ModelDimension {
        self.input_dimension
    }

    fn head_dimension(&self) -> HeadDimension {
        self.head_dimension
    }

    fn weight(&self) -> &[Vec<f32>] {
        &self.weight
    }

    fn bias(&self) -> &[f32] {
        &self.bias
    }
}

/// Learned projection from hidden states to query vectors.
#[derive(Debug, Clone, PartialEq)]
pub struct HiddenToQuery {
    projection: HiddenProjection,
}

impl HiddenToQuery {
    pub fn new(weight: Vec<Vec<f32>>, bias: Vec<f32>) -> CtResult<Self> {
        Ok(Self {
            projection: HiddenProjection::new("hidden-to-query projection", weight, bias)?,
        })
    }

    pub fn input_dimension(&self) -> ModelDimension {
        self.projection.input_dimension()
    }

    pub fn head_dimension(&self) -> HeadDimension {
        self.projection.head_dimension()
    }

    pub fn weight(&self) -> &[Vec<f32>] {
        self.projection.weight()
    }

    pub fn bias(&self) -> &[f32] {
        self.projection.bias()
    }
}

/// Learned projection from hidden states to key vectors.
#[derive(Debug, Clone, PartialEq)]
pub struct HiddenToKey {
    projection: HiddenProjection,
}

impl HiddenToKey {
    pub fn new(weight: Vec<Vec<f32>>, bias: Vec<f32>) -> CtResult<Self> {
        Ok(Self {
            projection: HiddenProjection::new("hidden-to-key projection", weight, bias)?,
        })
    }

    pub fn input_dimension(&self) -> ModelDimension {
        self.projection.input_dimension()
    }

    pub fn head_dimension(&self) -> HeadDimension {
        self.projection.head_dimension()
    }

    pub fn weight(&self) -> &[Vec<f32>] {
        self.projection.weight()
    }

    pub fn bias(&self) -> &[f32] {
        self.projection.bias()
    }
}

/// Learned projection from hidden states to value vectors.
#[derive(Debug, Clone, PartialEq)]
pub struct HiddenToValue {
    projection: HiddenProjection,
}

impl HiddenToValue {
    pub fn new(weight: Vec<Vec<f32>>, bias: Vec<f32>) -> CtResult<Self> {
        Ok(Self {
            projection: HiddenProjection::new("hidden-to-value projection", weight, bias)?,
        })
    }

    pub fn input_dimension(&self) -> ModelDimension {
        self.projection.input_dimension()
    }

    pub fn head_dimension(&self) -> HeadDimension {
        self.projection.head_dimension()
    }

    pub fn weight(&self) -> &[Vec<f32>] {
        self.projection.weight()
    }

    pub fn bias(&self) -> &[f32] {
        self.projection.bias()
    }
}

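/// A tiny single-head Transformer block sketch: query/key/value projections, an attention
/// output projection, a position-wise feed-forward sublayer, and one layer normalization
/// for each sublayer.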
#[derive(Debug, Clone, PartialEq)]
pub struct SingleHeadTransformerBlock {
    model_dimension: ModelDimension,
    query_projection: HiddenToQuery,
    key_projection: HiddenToKey,
    value_projection: HiddenToValue,
    output_projection: AttentionOutputProjection,
    attention_norm: LayerNormalization,
    feed_forward: PositionWiseFeedForward,
    feed_forward_norm: LayerNormalization,
}

impl SingleHeadTransformerBlock {
    pub fn new(
        query_projection: HiddenToQuery,
        key_projection: HiddenToKey,
        value_projection: HiddenToValue,
        output_projection: AttentionOutputProjection,
        attention_norm: LayerNormalization,
        feed_forward: PositionWiseFeedForward,
        feed_forward_norm: LayerNormalization,
    ) -> CtResult<Self> {
        let model_dimension = query_projection.input_dimension();

        validate_projection_input(
            "single-head block key projection",
            model_dimension,
            key_projection.input_dimension(),
        )?;
        validate_projection_input(
            "single-head block value projection",
            model_dimension,
            value_projection.input_dimension(),
        )?;

        if query_projection.head_dimension() != key_projection.head_dimension() {
            return Err(CtError::ShapeMismatch {
                op: "single-head block",
                expected: format!(
                    "query/key head dimension {}",
                    query_projection.head_dimension().value()
                ),
                got: format!(
                    "key head dimension {}",
                    key_projection.head_dimension().value()
                ),
            });
        }

        if output_projection.input_dimension().value() != value_projection.head_dimension().value()
        {
            return Err(CtError::ShapeMismatch {
                op: "single-head block",
                expected: format!(
                    "output projection input dimension {}",
                    value_projection.head_dimension().value()
                ),
                got: format!(
                    "output projection input dimension {}",
                    output_projection.input_dimension().value()
                ),
            });
        }

        validate_projection_input(
            "single-head block output projection",
            model_dimension,
            output_projection.output_dimension(),
        )?;
        validate_projection_input(
            "single-head block attention normalization",
            model_dimension,
            attention_norm.model_dimension(),
        )?;
        validate_projection_input(
            "single-head block feed-forward",
            model_dimension,
            feed_forward.input_dimension(),
        )?;
        validate_projection_input(
            "single-head block feed-forward normalization",
            model_dimension,
            feed_forward_norm.model_dimension(),
        )?;

        Ok(Self {
            model_dimension,
            query_projection,
            key_projection,
            value_projection,
            output_projection,
            attention_norm,
            feed_forward,
            feed_forward_norm,
        })
    }

    pub fn model_dimension(&self) -> ModelDimension {
        self.model_dimension
    }
}

/// Learned query, key, and value projections for one self-attention head.
#[derive(Debug, Clone, PartialEq)]
pub struct SelfAttentionHead {
    query_projection: HiddenToQuery,
    key_projection: HiddenToKey,
    value_projection: HiddenToValue,
}

impl SelfAttentionHead {
    pub fn new(
        query_projection: HiddenToQuery,
        key_projection: HiddenToKey,
        value_projection: HiddenToValue,
    ) -> CtResult<Self> {
        let input_dimension = query_projection.input_dimension();

        validate_projection_input(
            "self-attention head key projection",
            input_dimension,
            key_projection.input_dimension(),
        )?;
        validate_projection_input(
            "self-attention head value projection",
            input_dimension,
            value_projection.input_dimension(),
        )?;

        if query_projection.head_dimension() != key_projection.head_dimension() {
            return Err(CtError::ShapeMismatch {
                op: "self-attention head",
                expected: format!(
                    "query/key head dimension {}",
                    query_projection.head_dimension().value()
                ),
                got: format!(
                    "key head dimension {}",
                    key_projection.head_dimension().value()
                ),
            });
        }

        Ok(Self {
            query_projection,
            key_projection,
            value_projection,
        })
    }

    pub fn input_dimension(&self) -> ModelDimension {
        self.query_projection.input_dimension()
    }

    pub fn query_key_dimension(&self) -> HeadDimension {
        self.query_projection.head_dimension()
    }

    pub fn value_dimension(&self) -> HeadDimension {
        self.value_projection.head_dimension()
    }

    pub fn query_projection(&self) -> &HiddenToQuery {
        &self.query_projection
    }

    pub fn key_projection(&self) -> &HiddenToKey {
        &self.key_projection
    }

    pub fn value_projection(&self) -> &HiddenToValue {
        &self.value_projection
    }
}
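
// Hypothetical sketch, not part of this crate's API: raw attention scores for
// one head, score[i][j] = dot(query_i, key_j) / sqrt(d_k), where d_k is the
// shared query/key head dimension that `SelfAttentionHead::new` checks. The
// 1/sqrt(d_k) scaling follows the original Transformer paper; whether the
// crate's score step scales the same way is not shown in this listing.
#[allow(dead_code)]
mod scaled_scores_sketch {
    /// Returns a query_len x key_len score matrix.
    pub fn scaled_scores(queries: &[Vec<f32>], keys: &[Vec<f32>]) -> Vec<Vec<f32>> {
        // d_k is the shared query/key width; clamp to 1 so the scale stays finite.
        let d_k = queries.first().map_or(1, |query| query.len()).max(1) as f32;
        let scale = 1.0 / d_k.sqrt();

        queries
            .iter()
            .map(|query| {
                keys.iter()
                    .map(|key| query.iter().zip(key).map(|(q, k)| q * k).sum::<f32>() * scale)
                    .collect::<Vec<f32>>()
            })
            .collect()
    }
}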

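/// A tiny multi-head Transformer block sketch: several self-attention heads whose outputs are
/// concatenated and projected back to the model dimension, followed by a position-wise
/// feed-forward sublayer, with one layer normalization for each sublayer.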
#[derive(Debug, Clone, PartialEq)]
pub struct MultiHeadTransformerBlock {
    model_dimension: ModelDimension,
    head_count: HeadCount,
    value_dimension: HeadDimension,
    heads: Vec<SelfAttentionHead>,
    output_projection: AttentionOutputProjection,
    attention_norm: LayerNormalization,
    feed_forward: PositionWiseFeedForward,
    feed_forward_norm: LayerNormalization,
}

impl MultiHeadTransformerBlock {
    pub fn new(
        heads: Vec<SelfAttentionHead>,
        output_projection: AttentionOutputProjection,
        attention_norm: LayerNormalization,
        feed_forward: PositionWiseFeedForward,
        feed_forward_norm: LayerNormalization,
    ) -> CtResult<Self> {
        if heads.is_empty() {
            return Err(CtError::EmptyInput("multi-head block heads"));
        }

        let model_dimension = heads[0].input_dimension();
        let value_dimension = heads[0].value_dimension();

        for head in &heads {
            validate_projection_input(
                "multi-head block head projection",
                model_dimension,
                head.input_dimension(),
            )?;

            if head.value_dimension() != value_dimension {
                return Err(CtError::ShapeMismatch {
                    op: "multi-head block",
                    expected: format!("value head dimension {}", value_dimension.value()),
                    got: format!("value head dimension {}", head.value_dimension().value()),
                });
            }
        }

        let head_count = HeadCount::new(heads.len())?;
        let concatenated_dimension =
            ModelDimension::new(head_count.value() * value_dimension.value())?;

        validate_projection_input(
            "multi-head block output projection input",
            concatenated_dimension,
            output_projection.input_dimension(),
        )?;
        validate_projection_input(
            "multi-head block output projection",
            model_dimension,
            output_projection.output_dimension(),
        )?;
        validate_projection_input(
            "multi-head block attention normalization",
            model_dimension,
            attention_norm.model_dimension(),
        )?;
        validate_projection_input(
            "multi-head block feed-forward",
            model_dimension,
            feed_forward.input_dimension(),
        )?;
        validate_projection_input(
            "multi-head block feed-forward normalization",
            model_dimension,
            feed_forward_norm.model_dimension(),
        )?;

        Ok(Self {
            model_dimension,
            head_count,
            value_dimension,
            heads,
            output_projection,
            attention_norm,
            feed_forward,
            feed_forward_norm,
        })
    }

    pub fn model_dimension(&self) -> ModelDimension {
        self.model_dimension
    }

    pub fn head_count(&self) -> HeadCount {
        self.head_count
    }

    pub fn value_dimension(&self) -> HeadDimension {
        self.value_dimension
    }

    pub fn heads(&self) -> &[SelfAttentionHead] {
        &self.heads
    }

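    /// Rebuilds the block with new heads. Like the other `with_*` helpers, it goes back
    /// through `new`, so every shape check runs again.
    fn with_heads(self, heads: Vec<SelfAttentionHead>) -> CtResult<Self> {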
        Self::new(
            heads,
            self.output_projection,
            self.attention_norm,
            self.feed_forward,
            self.feed_forward_norm,
        )
    }

    fn with_feed_forward(self, feed_forward: PositionWiseFeedForward) -> CtResult<Self> {
        Self::new(
            self.heads,
            self.output_projection,
            self.attention_norm,
            feed_forward,
            self.feed_forward_norm,
        )
    }

    fn with_output_projection(
        self,
        output_projection: AttentionOutputProjection,
    ) -> CtResult<Self> {
        Self::new(
            self.heads,
            output_projection,
            self.attention_norm,
            self.feed_forward,
            self.feed_forward_norm,
        )
    }

    fn with_layer_norms(
        self,
        attention_norm: LayerNormalization,
        feed_forward_norm: LayerNormalization,
    ) -> CtResult<Self> {
        Self::new(
            self.heads,
            self.output_projection,
            attention_norm,
            self.feed_forward,
            feed_forward_norm,
        )
    }
}

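/// A tiny masked multi-head Transformer block sketch. It reuses the parameters and shape
/// checks of `MultiHeadTransformerBlock` and applies an `AttentionMask` (for example, a causal
/// mask) inside every head.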
#[derive(Debug, Clone, PartialEq)]
pub struct MaskedMultiHeadTransformerBlock {
    block: MultiHeadTransformerBlock,
}

impl MaskedMultiHeadTransformerBlock {
    pub fn new(
        heads: Vec<SelfAttentionHead>,
        output_projection: AttentionOutputProjection,
        attention_norm: LayerNormalization,
        feed_forward: PositionWiseFeedForward,
        feed_forward_norm: LayerNormalization,
    ) -> CtResult<Self> {
        Ok(Self {
            block: MultiHeadTransformerBlock::new(
                heads,
                output_projection,
                attention_norm,
                feed_forward,
                feed_forward_norm,
            )?,
        })
    }

    pub fn model_dimension(&self) -> ModelDimension {
        self.block.model_dimension()
    }

    pub fn head_count(&self) -> HeadCount {
        self.block.head_count()
    }

    pub fn value_dimension(&self) -> HeadDimension {
        self.block.value_dimension()
    }

    pub fn heads(&self) -> &[SelfAttentionHead] {
        self.block.heads()
    }

    pub fn feed_forward(&self) -> &PositionWiseFeedForward {
        &self.block.feed_forward
    }

    fn with_feed_forward(self, feed_forward: PositionWiseFeedForward) -> CtResult<Self> {
        Ok(Self {
            block: self.block.with_feed_forward(feed_forward)?,
        })
    }

    fn with_heads(self, heads: Vec<SelfAttentionHead>) -> CtResult<Self> {
        Ok(Self {
            block: self.block.with_heads(heads)?,
        })
    }

    fn with_output_projection(
        self,
        output_projection: AttentionOutputProjection,
    ) -> CtResult<Self> {
        Ok(Self {
            block: self.block.with_output_projection(output_projection)?,
        })
    }

    fn with_layer_norms(
        self,
        attention_norm: LayerNormalization,
        feed_forward_norm: LayerNormalization,
    ) -> CtResult<Self> {
        Ok(Self {
            block: self
                .block
                .with_layer_norms(attention_norm, feed_forward_norm)?,
        })
    }

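    /// Runs the masked block forward pass and keeps the intermediates needed for training.
    /// Order (post-norm): masked multi-head attention, output projection, residual + layer norm,
    /// then feed-forward, residual + layer norm.
    fn apply_with_training_cache(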
        &self,
        hidden: HiddenSequence,
        mask: AttentionMask,
    ) -> CtResult<MaskedBlockTrainingCache> {
        if hidden.model_dimension() != self.block.model_dimension {
            return Err(CtError::ShapeMismatch {
                op: "masked multi-head block",
                expected: format!("model dimension {}", self.block.model_dimension.value()),
                got: format!("model dimension {}", hidden.model_dimension().value()),
            });
        }

        let head_caches = self
            .block
            .heads
            .iter()
            .map(|head| apply_self_attention_head_with_mask_cache(&hidden, head, Some(&mask)))
            .collect::<CtResult<Vec<_>>>()?;
        let attention_outputs = head_caches
            .iter()
            .map(|cache| cache.output.clone())
            .collect::<Vec<_>>();
        let head_outputs = AttentionHeadOutputs::new(attention_outputs)?;
        let multi_head_output = ConcatenateHeads.apply(head_outputs)?;
        let projected_attention = self
            .block
            .output_projection
            .apply(multi_head_output.clone())?;
        let with_attention = ResidualConnection.apply(Product::new(hidden, projected_attention))?;
        let normalized_attention = self.block.attention_norm.apply(with_attention.clone())?;
        let (feed_forward_output, feed_forward_rows) =
            feed_forward_with_cache(&self.block.feed_forward, &normalized_attention)?;
        let with_feed_forward =
            ResidualConnection.apply(Product::new(normalized_attention, feed_forward_output))?;
        let output = self
            .block
            .feed_forward_norm
            .apply(with_feed_forward.clone())?;

        Ok(MaskedBlockTrainingCache {
            output,
            with_feed_forward,
            with_attention,
            multi_head_output,
            attention_heads: head_caches,
            feed_forward_rows,
        })
    }
}
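
// Roadmap sketch (hypothetical helpers, not part of this crate's API): the block above
// uses the post-norm order `LayerNorm(x + Sublayer(x))` for both the attention and the
// feed-forward sublayer. This per-row version assumes a plain layer norm with a learnable
// `gain` and `bias` and a small epsilon. The crate's `LayerNormalization` may use a
// different epsilon or parameter layout.
#[allow(dead_code)]
fn sketch_layer_norm_row(row: &[f32], gain: &[f32], bias: &[f32], epsilon: f32) -> Vec<f32> {
    let n = row.len() as f32;
    let mean = row.iter().sum::<f32>() / n;
    let variance = row.iter().map(|v| (v - mean) * (v - mean)).sum::<f32>() / n;
    let inv_std = 1.0 / (variance + epsilon).sqrt();
    row.iter()
        .zip(gain.iter().zip(bias))
        .map(|(v, (g, b))| g * (v - mean) * inv_std + b)
        .collect()
}

// Residual add followed by normalization, mirroring one sublayer of
// `apply_with_training_cache` for a single position.
#[allow(dead_code)]
fn sketch_post_norm_sublayer(
    row: &[f32],
    sublayer_output: &[f32],
    gain: &[f32],
    bias: &[f32],
) -> Vec<f32> {
    let residual: Vec<f32> = row.iter().zip(sublayer_output).map(|(x, s)| x + s).collect();
    sketch_layer_norm_row(&residual, gain, bias, 1e-5)
}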

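/// Parameters of the tiny roadmap Transformer: a positional encoding, one masked
/// multi-head block, and a vocabulary readout. `new` checks that all three agree on
/// the model dimension, and every `with_*` update goes back through `new`.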
#[derive(Debug, Clone, PartialEq)]
pub struct TinyTransformerParameters {
    positional_encoding: PositionalEncoding,
    block: MaskedMultiHeadTransformerBlock,
    readout: TransformerReadout,
}

impl TinyTransformerParameters {
    pub fn new(
        positional_encoding: PositionalEncoding,
        block: MaskedMultiHeadTransformerBlock,
        readout: TransformerReadout,
    ) -> CtResult<Self> {
        let model_dimension = positional_encoding.model_dimension();

        validate_projection_input(
            "tiny transformer parameters block",
            model_dimension,
            block.model_dimension(),
        )?;
        validate_projection_input(
            "tiny transformer parameters readout",
            model_dimension,
            readout.input_dimension(),
        )?;

        Ok(Self {
            positional_encoding,
            block,
            readout,
        })
    }

    pub fn model_dimension(&self) -> ModelDimension {
        self.positional_encoding.model_dimension()
    }

    pub fn max_sequence_len(&self) -> SequenceLength {
        self.positional_encoding.max_sequence_len()
    }

    pub fn vocab_size(&self) -> VocabSize {
        self.readout.vocab_size()
    }

    pub fn encode(&self, hidden: HiddenSequence, mask: AttentionMask) -> CtResult<HiddenSequence> {
        let positioned = self.positional_encoding.apply(hidden)?;

        self.block.apply(Product::new(positioned, mask))
    }

    pub fn readout(&self) -> &TransformerReadout {
        &self.readout
    }

    pub fn feed_forward(&self) -> &PositionWiseFeedForward {
        self.block.feed_forward()
    }

    pub fn output_projection(&self) -> &AttentionOutputProjection {
        &self.block.block.output_projection
    }

    pub fn attention_heads(&self) -> &[SelfAttentionHead] {
        self.block.heads()
    }

    pub fn attention_norm(&self) -> &LayerNormalization {
        &self.block.block.attention_norm
    }

    pub fn feed_forward_norm(&self) -> &LayerNormalization {
        &self.block.block.feed_forward_norm
    }

    fn with_readout(self, readout: TransformerReadout) -> CtResult<Self> {
        Self::new(self.positional_encoding, self.block, readout)
    }

    fn with_feed_forward(self, feed_forward: PositionWiseFeedForward) -> CtResult<Self> {
        Self::new(
            self.positional_encoding,
            self.block.with_feed_forward(feed_forward)?,
            self.readout,
        )
    }

    fn with_attention_heads(self, heads: Vec<SelfAttentionHead>) -> CtResult<Self> {
        Self::new(
            self.positional_encoding,
            self.block.with_heads(heads)?,
            self.readout,
        )
    }

    fn with_output_projection(
        self,
        output_projection: AttentionOutputProjection,
    ) -> CtResult<Self> {
        Self::new(
            self.positional_encoding,
            self.block.with_output_projection(output_projection)?,
            self.readout,
        )
    }

    fn with_layer_norms(
        self,
        attention_norm: LayerNormalization,
        feed_forward_norm: LayerNormalization,
    ) -> CtResult<Self> {
        Self::new(
            self.positional_encoding,
            self.block
                .with_layer_norms(attention_norm, feed_forward_norm)?,
            self.readout,
        )
    }
}
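
// Roadmap sketch (hypothetical helpers): `encode` adds position information before the
// block runs. One common choice is the fixed sinusoidal table from the original
// Transformer paper:
//   PE[pos][2i]     = sin(pos / 10000^(2i / d_model))
//   PE[pos][2i + 1] = cos(pos / 10000^(2i / d_model))
// The crate's `PositionalEncoding` may store a different (for example, learned) table.
// Only the "add one same-shaped row per position" contract is assumed here.
#[allow(dead_code)]
fn sketch_sinusoidal_positions(sequence_len: usize, model_dimension: usize) -> Vec<Vec<f32>> {
    (0..sequence_len)
        .map(|pos| {
            (0..model_dimension)
                .map(|feature| {
                    let pair = (feature / 2) as f32;
                    let angle = pos as f32 / 10000f32.powf(2.0 * pair / model_dimension as f32);
                    if feature % 2 == 0 {
                        angle.sin()
                    } else {
                        angle.cos()
                    }
                })
                .collect::<Vec<f32>>()
        })
        .collect()
}

// Adds the table row for each position to the matching hidden row.
#[allow(dead_code)]
fn sketch_add_positions(hidden: &[Vec<f32>]) -> Vec<Vec<f32>> {
    let model_dimension = hidden.first().map_or(0, |row| row.len());
    let table = sketch_sinusoidal_positions(hidden.len(), model_dimension);
    hidden
        .iter()
        .zip(&table)
        .map(|(row, pe)| row.iter().zip(pe).map(|(h, p)| h + p).collect::<Vec<f32>>())
        .collect()
}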

/// Structured state threaded through the Transformer train steps: the parameters,
/// the learning rate, and the number of updates recorded so far.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerTrainingState {
    parameters: TinyTransformerParameters,
    learning_rate: LearningRate,
    step_count: StepCount,
}

impl TransformerTrainingState {
    pub fn new(parameters: TinyTransformerParameters, learning_rate: LearningRate) -> Self {
        Self::from_parts(parameters, learning_rate, StepCount::new(0))
    }

    pub fn from_parts(
        parameters: TinyTransformerParameters,
        learning_rate: LearningRate,
        step_count: StepCount,
    ) -> Self {
        Self {
            parameters,
            learning_rate,
            step_count,
        }
    }

    pub fn parameters(&self) -> &TinyTransformerParameters {
        &self.parameters
    }

    pub fn learning_rate(&self) -> LearningRate {
        self.learning_rate
    }

    pub fn step_count(&self) -> StepCount {
        self.step_count
    }

    pub fn record_updated_parameters(self, parameters: TinyTransformerParameters) -> Self {
        Self {
            parameters,
            learning_rate: self.learning_rate,
            step_count: StepCount::new(self.step_count.value() + 1),
        }
    }
}
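
// Roadmap sketch (hypothetical types): `record_updated_parameters` consumes the old
// state and returns a new one with `step_count + 1`, so a training loop is a fold of a
// `State -> State` update. The scalar "parameter" and quadratic objective below are
// illustrative assumptions, not crate types.
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq)]
struct SketchTrainingState {
    parameter: f32,
    learning_rate: f32,
    step_count: usize,
}

#[allow(dead_code)]
fn sketch_gradient_step(state: SketchTrainingState) -> SketchTrainingState {
    // Minimizes 0.5 * (parameter - 3)^2, whose gradient is (parameter - 3).
    let gradient = state.parameter - 3.0;
    SketchTrainingState {
        parameter: state.parameter - state.learning_rate * gradient,
        learning_rate: state.learning_rate,
        step_count: state.step_count + 1,
    }
}

// Applies the same update `steps` times, threading ownership of the state through.
#[allow(dead_code)]
fn sketch_train(initial: SketchTrainingState, steps: usize) -> SketchTrainingState {
    (0..steps).fold(initial, |state, _| sketch_gradient_step(state))
}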

/// One supervised sequence example for a readout-only Transformer update.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerReadoutTrainingExample {
    hidden: HiddenSequence,
    mask: AttentionMask,
    targets: TokenSequence,
}

impl TransformerReadoutTrainingExample {
    pub fn new(
        hidden: HiddenSequence,
        mask: AttentionMask,
        targets: TokenSequence,
    ) -> CtResult<Self> {
        let sequence_len = hidden.sequence_len();

        if targets.as_slice().len() != sequence_len.value() {
            return Err(CtError::ShapeMismatch {
                op: "transformer readout training targets",
                expected: format!("{} target tokens", sequence_len.value()),
                got: format!("{} target tokens", targets.as_slice().len()),
            });
        }

        if mask.query_len() != sequence_len || mask.key_len() != sequence_len {
            return Err(CtError::ShapeMismatch {
                op: "transformer readout training mask",
                expected: format!(
                    "{} query rows x {} key columns",
                    sequence_len.value(),
                    sequence_len.value()
                ),
                got: format!(
                    "{} query rows x {} key columns",
                    mask.query_len().value(),
                    mask.key_len().value()
                ),
            });
        }

        Ok(Self {
            hidden,
            mask,
            targets,
        })
    }

    pub fn hidden(&self) -> &HiddenSequence {
        &self.hidden
    }

    pub fn mask(&self) -> &AttentionMask {
        &self.mask
    }

    pub fn targets(&self) -> &TokenSequence {
        &self.targets
    }
}
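
// Roadmap sketch (hypothetical helpers): `new` above requires a square mask whose side
// equals the sequence length, because self-attention compares every position with every
// position of the same sequence. A causal (decoder-style) mask is one such square mask:
// query `q` may attend to key `k` only when `k <= q`. Storing the mask as
// `Vec<Vec<bool>>` is an assumption; the crate's `AttentionMask` layout is not shown here.
#[allow(dead_code)]
fn sketch_causal_mask(sequence_len: usize) -> Vec<Vec<bool>> {
    (0..sequence_len)
        .map(|query| (0..sequence_len).map(|key| key <= query).collect::<Vec<bool>>())
        .collect()
}

// Same shape rule as the constructor: sequence_len query rows x sequence_len key columns.
#[allow(dead_code)]
fn sketch_check_square_mask(mask: &[Vec<bool>], sequence_len: usize) -> Result<(), String> {
    if mask.len() != sequence_len || mask.iter().any(|row| row.len() != sequence_len) {
        return Err(format!("expected a {sequence_len} x {sequence_len} mask"));
    }
    Ok(())
}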

/// Non-empty set of supervised readout examples.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerReadoutTrainingSet(Vec<TransformerReadoutTrainingExample>);

impl TransformerReadoutTrainingSet {
    pub fn new(
        examples: impl IntoIterator<Item = TransformerReadoutTrainingExample>,
    ) -> CtResult<Self> {
        let examples = examples.into_iter().collect::<Vec<_>>();

        if examples.is_empty() {
            return Err(CtError::EmptyInput("transformer readout training set"));
        }

        Ok(Self(examples))
    }

    pub fn examples(&self) -> &[TransformerReadoutTrainingExample] {
        &self.0
    }
}

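/// One full-batch update of the sequence readout parameters. The positional encoding
/// and block still run forward to produce the encoded rows, but only the readout
/// weight and bias change.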
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerReadoutTrainStep {
    dataset: TransformerReadoutTrainingSet,
}

impl TransformerReadoutTrainStep {
    pub fn new(dataset: TransformerReadoutTrainingSet) -> Self {
        Self { dataset }
    }
}

impl Morphism<TransformerTrainingState, TransformerTrainingState> for TransformerReadoutTrainStep {
    fn name(&self) -> &'static str {
        "transformer_readout_train_step"
    }

    fn apply(&self, state: TransformerTrainingState) -> CtResult<TransformerTrainingState> {
        let input_dimension = state.parameters.readout().input_dimension().value();
        let vocab_size = state.parameters.vocab_size().value();
        let mut grad_weight = vec![vec![0.0; vocab_size]; input_dimension];
        let mut grad_bias = vec![0.0; vocab_size];
        let mut position_count = 0usize;

        for example in self.dataset.examples() {
            let encoded = state
                .parameters
                .encode(example.hidden().clone(), example.mask().clone())?;
            let logits = state.parameters.readout().apply(encoded.clone())?;

            for ((hidden_row, logit_row), target) in encoded
                .rows()
                .iter()
                .zip(logits.rows())
                .zip(example.targets().as_slice())
            {
                let target_index = target.index();

                if target_index >= vocab_size {
                    return Err(CtError::OutOfRange {
                        kind: "sequence target",
                        index: target_index,
                        limit: vocab_size,
                    });
                }

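                // Softmax cross-entropy gradient with respect to the logits:
                // dL/dz = softmax(z) - one_hot(target).
                let probabilities = Softmax.apply(logit_row.clone())?;
                let mut dlogits = probabilities.as_slice().to_vec();
                dlogits[target_index] -= 1.0;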

                for (vocab_id, dlogit) in dlogits.iter().copied().enumerate() {
                    grad_bias[vocab_id] += dlogit;

                    for (feature, hidden_value) in hidden_row.as_slice().iter().copied().enumerate()
                    {
                        grad_weight[feature][vocab_id] += hidden_value * dlogit;
                    }
                }

                position_count += 1;
            }
        }

        // Average the accumulated gradients over every supervised position in the batch.
        let scale = state.learning_rate().value() / position_count as f32;
        let mut updated_weight = state.parameters.readout().weight().to_vec();
        let mut updated_bias = state.parameters.readout().bias().to_vec();

        for (row, grad_row) in updated_weight.iter_mut().zip(&grad_weight) {
            for (weight, grad) in row.iter_mut().zip(grad_row) {
                *weight -= scale * grad;
            }
        }

        for (bias, grad) in updated_bias.iter_mut().zip(&grad_bias) {
            *bias -= scale * grad;
        }

        let updated_readout = TransformerReadout::new(updated_weight, updated_bias)?;
        let updated_parameters = state.parameters.clone().with_readout(updated_readout)?;

        Ok(state.record_updated_parameters(updated_parameters))
    }
}
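
// Roadmap sketch (hypothetical helpers): the readout step above uses the softmax
// cross-entropy identity dL/dlogits = softmax(logits) - one_hot(target), accumulates
// `grad_weight[feature][vocab] += hidden[feature] * dlogit[vocab]`, and divides by the
// number of supervised positions. This standalone version uses plain `Vec<Vec<f32>>`
// weights laid out input x vocab, matching the loop above. It is not the crate's
// `TransformerReadout` API, and each example here stands for one position.
#[allow(dead_code)]
fn sketch_readout_logits(hidden: &[f32], weight: &[Vec<f32>], bias: &[f32]) -> Vec<f32> {
    let mut logits = bias.to_vec();
    for (h, row) in hidden.iter().zip(weight) {
        for (logit, w) in logits.iter_mut().zip(row) {
            *logit += h * w;
        }
    }
    logits
}

#[allow(dead_code)]
fn sketch_readout_softmax(logits: &[f32]) -> Vec<f32> {
    // Subtract the largest logit for numerical stability.
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|z| (z - max).exp()).collect();
    let total: f32 = exps.iter().sum();
    exps.iter().map(|e| e / total).collect()
}

#[allow(dead_code)]
fn sketch_readout_loss(examples: &[(Vec<f32>, usize)], weight: &[Vec<f32>], bias: &[f32]) -> f32 {
    let total: f32 = examples
        .iter()
        .map(|(hidden, target)| {
            -sketch_readout_softmax(&sketch_readout_logits(hidden, weight, bias))[*target].ln()
        })
        .sum();
    total / examples.len() as f32
}

// One full-batch gradient-descent step on the readout weight and bias.
#[allow(dead_code)]
fn sketch_readout_step(
    examples: &[(Vec<f32>, usize)],
    weight: &mut [Vec<f32>],
    bias: &mut [f32],
    learning_rate: f32,
) {
    let mut grad_weight = vec![vec![0.0f32; bias.len()]; weight.len()];
    let mut grad_bias = vec![0.0f32; bias.len()];
    for (hidden, target) in examples {
        let mut dlogits = sketch_readout_softmax(&sketch_readout_logits(hidden, weight, bias));
        dlogits[*target] -= 1.0;
        for (vocab, dlogit) in dlogits.iter().copied().enumerate() {
            grad_bias[vocab] += dlogit;
            for (feature, h) in hidden.iter().enumerate() {
                grad_weight[feature][vocab] += h * dlogit;
            }
        }
    }
    let scale = learning_rate / examples.len() as f32;
    for (row, grad_row) in weight.iter_mut().zip(&grad_weight) {
        for (w, g) in row.iter_mut().zip(grad_row) {
            *w -= scale * g;
        }
    }
    for (b, g) in bias.iter_mut().zip(&grad_bias) {
        *b -= scale * g;
    }
}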

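/// One supervised sequence example for local feed-forward sublayer training: the
/// sublayer is fit directly from `input` rows to `target` rows, without attention,
/// residual connections, or layer normalization.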
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerFeedForwardTrainingExample {
    input: HiddenSequence,
    target: HiddenSequence,
}

impl TransformerFeedForwardTrainingExample {
    pub fn new(input: HiddenSequence, target: HiddenSequence) -> CtResult<Self> {
        if input.sequence_len() != target.sequence_len() {
            return Err(CtError::ShapeMismatch {
                op: "transformer feed-forward training sequence",
                expected: format!("{} target rows", input.sequence_len().value()),
                got: format!("{} target rows", target.sequence_len().value()),
            });
        }

        if input.model_dimension() != target.model_dimension() {
            return Err(CtError::ShapeMismatch {
                op: "transformer feed-forward training dimension",
                expected: format!("target dimension {}", input.model_dimension().value()),
                got: format!("target dimension {}", target.model_dimension().value()),
            });
        }

        Ok(Self { input, target })
    }

    pub fn input(&self) -> &HiddenSequence {
        &self.input
    }

    pub fn target(&self) -> &HiddenSequence {
        &self.target
    }
}

/// Non-empty set of local feed-forward training examples.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerFeedForwardTrainingSet(Vec<TransformerFeedForwardTrainingExample>);

impl TransformerFeedForwardTrainingSet {
    pub fn new(
        examples: impl IntoIterator<Item = TransformerFeedForwardTrainingExample>,
    ) -> CtResult<Self> {
        let examples = examples.into_iter().collect::<Vec<_>>();

        if examples.is_empty() {
            return Err(CtError::EmptyInput("transformer feed-forward training set"));
        }

        Ok(Self(examples))
    }

    pub fn examples(&self) -> &[TransformerFeedForwardTrainingExample] {
        &self.0
    }
}

/// One full-batch update of the position-wise feed-forward sublayer.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerFeedForwardTrainStep {
    dataset: TransformerFeedForwardTrainingSet,
}

impl TransformerFeedForwardTrainStep {
    pub fn new(dataset: TransformerFeedForwardTrainingSet) -> Self {
        Self { dataset }
    }
}

impl Morphism<TransformerTrainingState, TransformerTrainingState>
    for TransformerFeedForwardTrainStep
{
    fn name(&self) -> &'static str {
        "transformer_feed_forward_train_step"
    }

    fn apply(&self, state: TransformerTrainingState) -> CtResult<TransformerTrainingState> {
        let feed_forward = state.parameters.feed_forward();
        let mut gradients = FeedForwardGradients::new(feed_forward);
        let mut row_count = 0usize;

        for example in self.dataset.examples() {
            if example.input().model_dimension() != feed_forward.input_dimension() {
                return Err(CtError::ShapeMismatch {
                    op: "transformer feed-forward train step",
                    expected: format!("input dimension {}", feed_forward.input_dimension().value()),
                    got: format!(
                        "input dimension {}",
                        example.input().model_dimension().value()
                    ),
                });
            }

            let (_output, cache_rows) = feed_forward_with_cache(feed_forward, example.input())?;

            for (cache, target_row) in cache_rows.iter().zip(example.target().rows()) {
                // Gradient of 0.5 * (output - target)^2 with respect to each output value.
                let d_output = cache
                    .output
                    .iter()
                    .zip(target_row.as_slice())
                    .map(|(output_value, target_value)| output_value - target_value)
                    .collect::<Vec<_>>();
                gradients.accumulate(feed_forward, cache, &d_output);

                row_count += 1;
            }
        }

        let updated_feed_forward =
            gradients.apply_to(feed_forward, state.learning_rate(), row_count)?;
        let updated_parameters = state
            .parameters
            .clone()
            .with_feed_forward(updated_feed_forward)?;

        Ok(state.record_updated_parameters(updated_parameters))
    }
}
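
// Roadmap sketch (hypothetical helpers): the feed-forward step seeds backpropagation with
// `output - target`, the gradient of the half squared error 0.5 * (output - target)^2
// that `transformer_feed_forward_average_loss` averages. This single-row version assumes
// the original paper's form `W2 * relu(W1 * x + b1) + b2`, with weights laid out
// input x output. The crate's `PositionWiseFeedForward` activation and layout are not
// shown in this section.
#[allow(dead_code)]
fn sketch_ffn_forward(
    x: &[f32],
    w1: &[Vec<f32>],
    b1: &[f32],
    w2: &[Vec<f32>],
    b2: &[f32],
) -> (Vec<f32>, Vec<f32>) {
    // Hidden pre-activation: z[j] = b1[j] + sum_i x[i] * w1[i][j].
    let mut z = b1.to_vec();
    for (xi, row) in x.iter().zip(w1) {
        for (zj, w) in z.iter_mut().zip(row) {
            *zj += xi * w;
        }
    }
    let h: Vec<f32> = z.iter().map(|v| v.max(0.0)).collect();
    // Output: y[k] = b2[k] + sum_j relu(z[j]) * w2[j][k].
    let mut y = b2.to_vec();
    for (hj, row) in h.iter().zip(w2) {
        for (yk, w) in y.iter_mut().zip(row) {
            *yk += hj * w;
        }
    }
    (z, y)
}

#[allow(dead_code)]
fn sketch_half_squared_error(output: &[f32], target: &[f32]) -> f32 {
    output.iter().zip(target).map(|(o, t)| 0.5 * (o - t) * (o - t)).sum()
}

// One SGD step on a single (input, target) row.
#[allow(dead_code)]
fn sketch_ffn_sgd_step(
    x: &[f32],
    target: &[f32],
    w1: &mut [Vec<f32>],
    b1: &mut [f32],
    w2: &mut [Vec<f32>],
    b2: &mut [f32],
    learning_rate: f32,
) {
    let (z, y) = sketch_ffn_forward(x, w1, b1, w2, b2);
    let h: Vec<f32> = z.iter().map(|v| v.max(0.0)).collect();
    let dy: Vec<f32> = y.iter().zip(target).map(|(o, t)| o - t).collect();
    // Backpropagate into the hidden layer before `w2` is modified.
    let dh: Vec<f32> = w2
        .iter()
        .map(|row| row.iter().zip(&dy).map(|(w, d)| w * d).sum::<f32>())
        .collect();
    // ReLU passes gradient only where the pre-activation was positive.
    let dz: Vec<f32> = dh.iter().zip(&z).map(|(d, zj)| if *zj > 0.0 { *d } else { 0.0 }).collect();
    for (j, row) in w2.iter_mut().enumerate() {
        for (k, w) in row.iter_mut().enumerate() {
            *w -= learning_rate * h[j] * dy[k];
        }
    }
    for (b, d) in b2.iter_mut().zip(&dy) {
        *b -= learning_rate * d;
    }
    for (i, row) in w1.iter_mut().enumerate() {
        for (j, w) in row.iter_mut().enumerate() {
            *w -= learning_rate * x[i] * dz[j];
        }
    }
    for (b, d) in b1.iter_mut().zip(&dz) {
        *b -= learning_rate * d;
    }
}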

/// One supervised sequence example for a composed block-level update.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerBlockTrainingExample {
    hidden: HiddenSequence,
    mask: AttentionMask,
    targets: TokenSequence,
}

impl TransformerBlockTrainingExample {
    pub fn new(
        hidden: HiddenSequence,
        mask: AttentionMask,
        targets: TokenSequence,
    ) -> CtResult<Self> {
        let sequence_len = hidden.sequence_len();

        if targets.as_slice().len() != sequence_len.value() {
            return Err(CtError::ShapeMismatch {
                op: "transformer block training targets",
                expected: format!("{} target tokens", sequence_len.value()),
                got: format!("{} target tokens", targets.as_slice().len()),
            });
        }

        if mask.query_len() != sequence_len || mask.key_len() != sequence_len {
            return Err(CtError::ShapeMismatch {
                op: "transformer block training mask",
                expected: format!(
                    "{} query rows x {} key columns",
                    sequence_len.value(),
                    sequence_len.value()
                ),
                got: format!(
                    "{} query rows x {} key columns",
                    mask.query_len().value(),
                    mask.key_len().value()
                ),
            });
        }

        Ok(Self {
            hidden,
            mask,
            targets,
        })
    }

    pub fn hidden(&self) -> &HiddenSequence {
        &self.hidden
    }

    pub fn mask(&self) -> &AttentionMask {
        &self.mask
    }

    pub fn targets(&self) -> &TokenSequence {
        &self.targets
    }
}

/// Non-empty set of supervised block-level training examples.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerBlockTrainingSet(Vec<TransformerBlockTrainingExample>);

impl TransformerBlockTrainingSet {
    pub fn new(
        examples: impl IntoIterator<Item = TransformerBlockTrainingExample>,
    ) -> CtResult<Self> {
        let examples = examples.into_iter().collect::<Vec<_>>();

        if examples.is_empty() {
            return Err(CtError::EmptyInput("transformer block training set"));
        }

        Ok(Self(examples))
    }

    pub fn examples(&self) -> &[TransformerBlockTrainingExample] {
        &self.0
    }
}

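/// One full-batch update that backpropagates from the readout through both layer norms,
/// the feed-forward sublayer, the attention output projection, and every attention head.
/// The positional encoding is not updated.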
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerBlockTrainStep {
    dataset: TransformerBlockTrainingSet,
}

impl TransformerBlockTrainStep {
    pub fn new(dataset: TransformerBlockTrainingSet) -> Self {
        Self { dataset }
    }
}

impl Morphism<TransformerTrainingState, TransformerTrainingState> for TransformerBlockTrainStep {
    fn name(&self) -> &'static str {
        "transformer_block_train_step"
    }

    fn apply(&self, state: TransformerTrainingState) -> CtResult<TransformerTrainingState> {
        let readout = state.parameters.readout();
        let feed_forward = state.parameters.feed_forward();
        let output_projection = state.parameters.output_projection();
        let attention_norm = state.parameters.attention_norm();
        let feed_forward_norm = state.parameters.feed_forward_norm();
        let mut readout_gradients = ReadoutGradients::new(readout);
        let mut feed_forward_gradients = FeedForwardGradients::new(feed_forward);
        let mut output_projection_gradients =
            AttentionOutputProjectionGradients::new(output_projection);
        let mut attention_norm_gradients = LayerNormGradients::new(attention_norm.parameters());
        let mut feed_forward_norm_gradients =
            LayerNormGradients::new(feed_forward_norm.parameters());
        let mut attention_head_gradients = state
            .parameters
            .attention_heads()
            .iter()
            .map(AttentionHeadGradients::new)
            .collect::<Vec<_>>();
        let mut position_count = 0usize;

        for example in self.dataset.examples() {
            let positioned = state
                .parameters
                .positional_encoding
                .apply(example.hidden().clone())?;
            let block_cache = state
                .parameters
                .block
                .apply_with_training_cache(positioned.clone(), example.mask().clone())?;
            let logits = readout.apply(block_cache.output.clone())?;
            let vocab_size = logits.vocab_size().value();
            let mut d_multi_head_rows = vec![
                vec![0.0; output_projection.input_dimension().value()];
                block_cache.multi_head_output.sequence_len().value()
            ];

            for (
                position,
                (
                    (((encoded_row, logit_row), with_feed_forward_row), with_attention_row),
                    ((feed_forward_cache, multi_head_row), target),
                ),
            ) in block_cache
                .output
                .rows()
                .iter()
                .zip(logits.rows())
                .zip(block_cache.with_feed_forward.rows())
                .zip(block_cache.with_attention.rows())
                .zip(
                    block_cache
                        .feed_forward_rows
                        .iter()
                        .zip(block_cache.multi_head_output.rows())
                        .zip(example.targets().as_slice()),
                )
                .enumerate()
            {
                let dlogits =
                    softmax_cross_entropy_logits_gradient(logit_row, target.index(), vocab_size)?;
                let d_encoded =
                    readout_gradients.accumulate(readout, encoded_row.as_slice(), &dlogits);
                let d_with_feed_forward = feed_forward_norm_gradients.accumulate(
                    &d_encoded,
                    with_feed_forward_row.as_slice(),
                    feed_forward_norm.parameters(),
                );

                let d_feed_forward_input = feed_forward_gradients.accumulate(
                    feed_forward,
                    feed_forward_cache,
                    &d_with_feed_forward,
                );
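                // `with_feed_forward = normalized_attention + FF(normalized_attention)`, so the
                // gradient reaching `normalized_attention` is the sum of the skip path and the
                // feed-forward path.
                let d_normalized_attention = add_rows(&d_with_feed_forward, &d_feed_forward_input);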
                let d_with_attention = attention_norm_gradients.accumulate(
                    &d_normalized_attention,
                    with_attention_row.as_slice(),
                    attention_norm.parameters(),
                );
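                // `with_attention = positioned + projected_attention`. The positioned input is
                // not trainable, so only the projected-attention path uses `d_with_attention`.
                d_multi_head_rows[position] = output_projection_gradients.accumulate(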
                    output_projection,
                    multi_head_row.as_slice(),
                    &d_with_attention,
                );
                position_count += 1;
            }

            let value_dimension = block_cache.multi_head_output.head_dimension().value();
            for (head_index, ((head_gradient, head), head_cache)) in attention_head_gradients
                .iter_mut()
                .zip(state.parameters.attention_heads())
                .zip(&block_cache.attention_heads)
                .enumerate()
            {
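                // `ConcatenateHeads` lays heads side by side in each row, so head `i` owns
                // columns `i * value_dimension .. (i + 1) * value_dimension`.
                let start = head_index * value_dimension;
                let end = start + value_dimension;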
                let d_head_output_rows = d_multi_head_rows
                    .iter()
                    .map(|row| row[start..end].to_vec())
                    .collect::<Vec<_>>();

                head_gradient.accumulate(
                    head,
                    &positioned,
                    head_cache,
                    example.mask(),
                    &d_head_output_rows,
                )?;
            }
        }

        let updated_readout =
            readout_gradients.apply_to(readout, state.learning_rate(), position_count)?;
        let updated_feed_forward =
            feed_forward_gradients.apply_to(feed_forward, state.learning_rate(), position_count)?;
        let updated_output_projection = output_projection_gradients.apply_to(
            output_projection,
            state.learning_rate(),
            position_count,
        )?;
        let updated_attention_norm = LayerNormalization::new(attention_norm_gradients.apply_to(
            attention_norm.parameters(),
            state.learning_rate(),
            position_count,
        )?);
        let updated_feed_forward_norm =
            LayerNormalization::new(feed_forward_norm_gradients.apply_to(
                feed_forward_norm.parameters(),
                state.learning_rate(),
                position_count,
            )?);
        let updated_heads = state
            .parameters
            .attention_heads()
            .iter()
            .zip(attention_head_gradients)
            .map(|(head, gradients)| {
                gradients.apply_to(head, state.learning_rate(), position_count)
            })
            .collect::<CtResult<Vec<_>>>()?;
        let updated_parameters = state
            .parameters
            .clone()
            .with_readout(updated_readout)?
            .with_feed_forward(updated_feed_forward)?
            .with_output_projection(updated_output_projection)?
            .with_layer_norms(updated_attention_norm, updated_feed_forward_norm)?
            .with_attention_heads(updated_heads)?;

        Ok(state.record_updated_parameters(updated_parameters))
    }
}
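
// Roadmap sketch (hypothetical helpers): the block step adds two gradients at every
// residual join, because for `y = x + f(x)` the chain rule gives
// dL/dx = dL/dy + J_f(x)^T dL/dy. The check below assumes a scalar sublayer
// `f(x) = w * x` and loss `L = 0.5 * y^2`, and compares that sum against a central
// finite difference. It is a gradient check, not the crate's training code.
#[allow(dead_code)]
fn sketch_residual_loss(x: f64, w: f64) -> f64 {
    let y = x + w * x;
    0.5 * y * y
}

// Returns (analytic, numeric) derivatives of the loss with respect to `x`.
#[allow(dead_code)]
fn sketch_residual_gradients(x: f64, w: f64) -> (f64, f64) {
    let y = x + w * x;
    let d_y = y; // dL/dy for L = 0.5 * y^2
    let d_through_sublayer = w * d_y; // J_f^T dL/dy for f(x) = w * x
    let analytic = d_y + d_through_sublayer; // skip path + sublayer path
    let step = 1e-5;
    let numeric =
        (sketch_residual_loss(x + step, w) - sketch_residual_loss(x - step, w)) / (2.0 * step);
    (analytic, numeric)
}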

/// Average sequence cross-entropy for the structured Transformer readout.
pub fn transformer_readout_average_loss(
    state: &TransformerTrainingState,
    dataset: &TransformerReadoutTrainingSet,
) -> CtResult<Loss> {
    let mut total = 0.0;
    let mut position_count = 0usize;

    for example in dataset.examples() {
        let logits = state.apply(Product::new(
            example.hidden().clone(),
            example.mask().clone(),
        ))?;
        let vocab_size = logits.vocab_size().value();

        for (logit_row, target) in logits.rows().iter().zip(example.targets().as_slice()) {
            let target_index = target.index();

            if target_index >= vocab_size {
                return Err(CtError::OutOfRange {
                    kind: "sequence target",
                    index: target_index,
                    limit: vocab_size,
                });
            }

            let probabilities = Softmax.apply(logit_row.clone())?;
            // Clamp so a vanishing probability yields a large finite loss rather than infinity.
            let probability = probabilities.as_slice()[target_index].max(1e-9);
            total += -probability.ln();
            position_count += 1;
        }
    }

    Loss::new(total / position_count as f32)
}
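
// Roadmap sketch (hypothetical helper): the readout and block average-loss functions both
// compute mean(-ln p[target]) over every supervised position, clamping p at 1e-9. This
// version returns a `String` error where the crate returns `CtError::OutOfRange`, and
// computes each probability with a max-shifted softmax for numerical stability.
#[allow(dead_code)]
fn sketch_average_cross_entropy(logit_rows: &[Vec<f32>], targets: &[usize]) -> Result<f32, String> {
    if logit_rows.is_empty() || logit_rows.len() != targets.len() {
        return Err("need one target per non-empty logit row".to_string());
    }
    let mut total = 0.0f32;
    for (row, &target) in logit_rows.iter().zip(targets) {
        if target >= row.len() {
            return Err(format!("target {target} out of range for vocab {}", row.len()));
        }
        let max = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let denominator: f32 = row.iter().map(|z| (z - max).exp()).sum();
        let probability = ((row[target] - max).exp() / denominator).max(1e-9);
        total += -probability.ln();
    }
    Ok(total / logit_rows.len() as f32)
}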

/// Average half squared error, `0.5 * (output - target)^2` per value, for local
/// feed-forward sublayer training.
pub fn transformer_feed_forward_average_loss(
    state: &TransformerTrainingState,
    dataset: &TransformerFeedForwardTrainingSet,
) -> CtResult<Loss> {
    let feed_forward = state.parameters.feed_forward();
    let mut total = 0.0;
    let mut value_count = 0usize;

    for example in dataset.examples() {
        let output = feed_forward.apply(example.input().clone())?;

        for (output_row, target_row) in output.rows().iter().zip(example.target().rows()) {
            for (output_value, target_value) in
                output_row.as_slice().iter().zip(target_row.as_slice())
            {
                let error = output_value - target_value;
                total += 0.5 * error * error;
                value_count += 1;
            }
        }
    }

    Loss::new(total / value_count as f32)
}

/// Average sequence cross-entropy for the composed block-level training set.
pub fn transformer_block_average_loss(
    state: &TransformerTrainingState,
    dataset: &TransformerBlockTrainingSet,
) -> CtResult<Loss> {
    let mut total = 0.0;
    let mut position_count = 0usize;

    for example in dataset.examples() {
        let logits = state.apply(Product::new(
            example.hidden().clone(),
            example.mask().clone(),
        ))?;
        let vocab_size = logits.vocab_size().value();

        for (logit_row, target) in logits.rows().iter().zip(example.targets().as_slice()) {
            let target_index = target.index();

            if target_index >= vocab_size {
                return Err(CtError::OutOfRange {
                    kind: "sequence target",
                    index: target_index,
                    limit: vocab_size,
                });
            }

            let probabilities = Softmax.apply(logit_row.clone())?;
            let probability = probabilities.as_slice()[target_index].max(1e-9);
            total += -probability.ln();
            position_count += 1;
        }
    }

    Loss::new(total / position_count as f32)
}

/// Applies softmax independently to each query row.
#[derive(Debug, Clone)]
pub struct AttentionSoftmax;

impl Morphism<AttentionScores, AttentionWeights> for AttentionSoftmax {
    fn name(&self) -> &'static str {
        "attention_softmax"
    }

    fn apply(&self, scores: AttentionScores) -> CtResult<AttentionWeights> {
        let rows = scores
            .rows
            .into_iter()
            .map(|row| Softmax.apply(row))
            .collect::<CtResult<Vec<_>>>()?;

        AttentionWeights::new(rows)
    }
}
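
// Roadmap sketch (hypothetical helper): `AttentionSoftmax` normalizes each query row on
// its own. When a mask hides some keys, a common approach (assumed here; the crate's
// masked attention path is not shown in this section) is to set hidden scores to
// negative infinity before the softmax so their weights become exactly zero.
#[allow(dead_code)]
fn sketch_masked_row_softmax(scores: &[Vec<f32>], mask: &[Vec<bool>]) -> Vec<Vec<f32>> {
    scores
        .iter()
        .zip(mask)
        .map(|(row, allowed)| {
            let masked: Vec<f32> = row
                .iter()
                .zip(allowed)
                .map(|(&s, &keep)| if keep { s } else { f32::NEG_INFINITY })
                .collect();
            // Assumes every query row keeps at least one key; otherwise the maximum
            // is -inf and the row would become NaN.
            let max = masked.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let exps: Vec<f32> = masked.iter().map(|s| (s - max).exp()).collect();
            let total: f32 = exps.iter().sum();
            exps.iter().map(|e| e / total).collect::<Vec<f32>>()
        })
        .collect()
}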

/// Mixes value vectors with row-wise attention weights: each output row is the weighted
/// sum of all value rows, using that query row's weights.
#[derive(Debug, Clone)]
pub struct WeightedValueMixing;

impl Morphism<Product<AttentionWeights, ValueSequence>, AttentionOutput> for WeightedValueMixing {
    fn name(&self) -> &'static str {
        "weighted_value_mixing"
    }

    fn apply(&self, input: Product<AttentionWeights, ValueSequence>) -> CtResult<AttentionOutput> {
        let (weights, values) = input.into_parts();

        if weights.key_len() != values.sequence_len() {
            return Err(CtError::ShapeMismatch {
                op: "weighted value mixing",
                expected: format!("{} value rows", weights.key_len().value()),
                got: format!("{} value rows", values.sequence_len().value()),
            });
        }

        let value_width = values.head_dimension().value();
        let rows = weights
            .rows()
            .iter()
            .map(|weight_row| weighted_sum(weight_row.as_slice(), values.rows(), value_width))
            .collect::<Vec<_>>();

        AttentionOutput::new(rows)
    }
}

/// Returns `sum_k weights[k] * values[k]` as one output row of width `value_width`.
fn weighted_sum(weights: &[f32], values: &[Vector], value_width: usize) -> Vec<f32> {
    let mut output = vec![0.0; value_width];

    for (weight, value) in weights.iter().zip(values.iter()) {
        for (output_value, value_component) in output.iter_mut().zip(value.as_slice()) {
            *output_value += weight * value_component;
        }
    }

    output
}
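
// Roadmap sketch (hypothetical helper): `AttentionSoftmax` followed by
// `WeightedValueMixing` is the second half of scaled dot-product attention,
// softmax(Q K^T / sqrt(d_k)) V. This standalone single-head version works on plain rows,
// leaves out masking, and assumes non-empty keys and values. It is illustrative, not the
// crate's code path.
#[allow(dead_code)]
fn sketch_scaled_dot_product_attention(
    queries: &[Vec<f32>],
    keys: &[Vec<f32>],
    values: &[Vec<f32>],
) -> Vec<Vec<f32>> {
    let scale = 1.0 / (keys[0].len() as f32).sqrt();
    queries
        .iter()
        .map(|query| {
            // One scaled dot-product score per key, then a max-shifted softmax over the row.
            let scores: Vec<f32> = keys
                .iter()
                .map(|key| query.iter().zip(key).map(|(q, k)| q * k).sum::<f32>() * scale)
                .collect();
            let max = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
            let total: f32 = exps.iter().sum();
            // Same accumulation as `weighted_sum`: output += weight * value_row.
            let mut output = vec![0.0f32; values[0].len()];
            for (e, value) in exps.iter().zip(values) {
                for (out, v) in output.iter_mut().zip(value) {
                    *out += (e / total) * v;
                }
            }
            output
        })
        .collect()
}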

/// Concatenates single-head outputs along the feature dimension.
#[derive(Debug, Clone)]
pub struct ConcatenateHeads;

impl Morphism<AttentionHeadOutputs, MultiHeadOutput> for ConcatenateHeads {
    fn name(&self) -> &'static str {
        "concatenate_heads"
    }

    fn apply(&self, heads: AttentionHeadOutputs) -> CtResult<MultiHeadOutput> {
        let rows = (0..heads.sequence_len().value())
            .map(|position| {
                heads
                    .heads()
                    .iter()
                    .flat_map(|head| head.rows()[position].as_slice().iter().copied())
                    .collect::<Vec<_>>()
            })
            .collect::<Vec<_>>();

        MultiHeadOutput::new(rows, heads.head_count(), heads.head_dimension())
    }
}
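
// Roadmap sketch (hypothetical helpers): `ConcatenateHeads` places head outputs side by
// side in every row, and `TransformerBlockTrainStep` later slices head `i` back out as
// columns `i * d_v .. (i + 1) * d_v`. The pair below shows that the slice exactly inverts
// the concatenation, assuming every head has the same width `d_v`.
#[allow(dead_code)]
fn sketch_concat_heads(heads: &[Vec<Vec<f32>>]) -> Vec<Vec<f32>> {
    let sequence_len = heads.first().map_or(0, |head| head.len());
    (0..sequence_len)
        .map(|position| {
            heads
                .iter()
                .flat_map(|head| head[position].iter().copied())
                .collect::<Vec<f32>>()
        })
        .collect()
}

#[allow(dead_code)]
fn sketch_split_heads(
    rows: &[Vec<f32>],
    head_count: usize,
    value_dimension: usize,
) -> Vec<Vec<Vec<f32>>> {
    (0..head_count)
        .map(|head| {
            let start = head * value_dimension;
            rows.iter()
                .map(|row| row[start..start + value_dimension].to_vec())
                .collect::<Vec<Vec<f32>>>()
        })
        .collect()
}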

impl Morphism<MultiHeadOutput, ProjectedAttentionOutput> for AttentionOutputProjection {
    fn name(&self) -> &'static str {
        "attention_output_projection"
    }

    fn apply(&self, input: MultiHeadOutput) -> CtResult<ProjectedAttentionOutput> {
        if input.model_dimension() != self.input_dimension {
            return Err(CtError::ShapeMismatch {
                op: "attention output projection",
                expected: format!("input dimension {}", self.input_dimension.value()),
                got: format!("input dimension {}", input.model_dimension().value()),
            });
        }

        let rows = input
            .rows()
            .iter()
            .map(|row| project_row(row.as_slice(), &self.weight, &self.bias))
            .collect::<Vec<_>>();

        ProjectedAttentionOutput::new(rows)
    }
}

/// Adds a same-shaped sublayer output back to the hidden sequence.
#[derive(Debug, Clone)]
pub struct ResidualConnection;

impl Morphism<Product<HiddenSequence, ProjectedAttentionOutput>, HiddenSequence>
    for ResidualConnection
{
    fn name(&self) -> &'static str {
        "residual_connection"
    }

    fn apply(
        &self,
        input: Product<HiddenSequence, ProjectedAttentionOutput>,
    ) -> CtResult<HiddenSequence> {
        let (hidden, sublayer_output) = input.into_parts();

        if hidden.sequence_len() != sublayer_output.sequence_len() {
            return Err(CtError::ShapeMismatch {
                op: "residual connection",
                expected: format!("{} sequence rows", hidden.sequence_len().value()),
                got: format!("{} sequence rows", sublayer_output.sequence_len().value()),
            });
        }

        if hidden.model_dimension() != sublayer_output.model_dimension() {
            return Err(CtError::ShapeMismatch {
                op: "residual connection",
                expected: format!("model dimension {}", hidden.model_dimension().value()),
                got: format!(
                    "model dimension {}",
                    sublayer_output.model_dimension().value()
                ),
            });
        }

        let rows = hidden
            .rows()
            .iter()
            .zip(sublayer_output.rows())
            .map(|(left, right)| add_rows(left.as_slice(), right.as_slice()))
            .collect::<Vec<_>>();

        HiddenSequence::new(rows)
    }
}
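
// A minimal sketch, assuming plain rows and a `String` error in place of
// `CtError::ShapeMismatch`. The module and function are hypothetical. A
// residual connection is element-wise addition, so it is only defined when
// both sides agree on sequence length and on model dimension. Those are the
// same two checks `ResidualConnection` performs before calling `add_rows`.
#[allow(dead_code)]
mod roadmap_sketch_residual {
    /// Hypothetical helper: `hidden + sublayer`, row by row, with shape checks.
    pub fn add_residual(
        hidden: &[Vec<f32>],
        sublayer: &[Vec<f32>],
    ) -> Result<Vec<Vec<f32>>, String> {
        if hidden.len() != sublayer.len() {
            return Err(format!("expected {} rows, got {}", hidden.len(), sublayer.len()));
        }
        hidden
            .iter()
            .zip(sublayer)
            .map(|(left, right)| {
                // Each row must also match in width before it can be added.
                if left.len() != right.len() {
                    return Err(format!("expected width {}, got {}", left.len(), right.len()));
                }
                Ok(left.iter().zip(right).map(|(a, b)| a + b).collect::<Vec<f32>>())
            })
            .collect()
    }
}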

impl Morphism<HiddenSequence, HiddenSequence> for LayerNormalization {
    fn name(&self) -> &'static str {
        "layer_normalization"
    }

    fn apply(&self, input: HiddenSequence) -> CtResult<HiddenSequence> {
        if input.model_dimension() != self.parameters.model_dimension() {
            return Err(CtError::ShapeMismatch {
                op: "layer normalization",
                expected: format!(
                    "model dimension {}",
                    self.parameters.model_dimension().value()
                ),
                got: format!("model dimension {}", input.model_dimension().value()),
            });
        }

        let rows = input
            .rows()
            .iter()
            .map(|row| normalize_row(row.as_slice(), &self.parameters))
            .collect::<Vec<_>>();

        HiddenSequence::new(rows)
    }
}

impl Morphism<HiddenSequence, HiddenSequence> for PositionWiseFeedForward {
    fn name(&self) -> &'static str {
        "position_wise_feed_forward"
    }

    fn apply(&self, input: HiddenSequence) -> CtResult<HiddenSequence> {
        let (output, _cache) = feed_forward_with_cache(self, &input)?;

        Ok(output)
    }
}
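
// A minimal sketch of the position-wise feed-forward layer for one row:
// relu(x W1 + b1) W2 + b2. The module and functions are hypothetical. Weights
// use the [input][output] layout that the gradient code in this file indexes
// as `weight[input_id][output_id]`. Every sequence position goes through the
// same two projections independently, which is why the layer is called
// "position-wise".
#[allow(dead_code)]
mod roadmap_sketch_feed_forward {
    /// Hypothetical helper: `bias + row * weight` with weight[input][output].
    fn project(row: &[f32], weight: &[Vec<f32>], bias: &[f32]) -> Vec<f32> {
        let mut output = bias.to_vec();
        for (input_value, weight_row) in row.iter().zip(weight) {
            for (out, w) in output.iter_mut().zip(weight_row) {
                *out += input_value * w;
            }
        }
        output
    }

    /// Hypothetical helper: first projection, ReLU, then second projection.
    pub fn feed_forward_row(
        row: &[f32],
        w1: &[Vec<f32>],
        b1: &[f32],
        w2: &[Vec<f32>],
        b2: &[f32],
    ) -> Vec<f32> {
        let activation: Vec<f32> = project(row, w1, b1)
            .into_iter()
            .map(|value| value.max(0.0))
            .collect();
        project(&activation, w2, b2)
    }
}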

impl Morphism<HiddenSequence, HiddenSequence> for PositionalEncoding {
    fn name(&self) -> &'static str {
        "positional_encoding"
    }

    fn apply(&self, input: HiddenSequence) -> CtResult<HiddenSequence> {
        if input.sequence_len().value() > self.max_sequence_len.value() {
            return Err(CtError::ShapeMismatch {
                op: "positional encoding",
                expected: format!("at most {} sequence rows", self.max_sequence_len.value()),
                got: format!("{} sequence rows", input.sequence_len().value()),
            });
        }

        if input.model_dimension() != self.model_dimension {
            return Err(CtError::ShapeMismatch {
                op: "positional encoding",
                expected: format!("model dimension {}", self.model_dimension.value()),
                got: format!("model dimension {}", input.model_dimension().value()),
            });
        }

        let rows = input
            .rows()
            .iter()
            .zip(&self.rows)
            .map(|(hidden_row, position_row)| {
                add_rows(hidden_row.as_slice(), position_row.as_slice())
            })
            .collect::<Vec<_>>();

        HiddenSequence::new(rows)
    }
}
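
// A minimal sketch under an explicit assumption. `PositionalEncoding` adds one
// stored row per position, but this listing does not show how `self.rows` is
// built. One common choice, from the original Transformer paper, is the
// sinusoidal table:
//   PE[pos][2i] = sin(pos / 10000^(2i/d)),  PE[pos][2i+1] = cos(pos / 10000^(2i/d)).
// The module and function below are hypothetical and only illustrate that
// table. They do not describe what this crate actually stores.
#[allow(dead_code)]
mod roadmap_sketch_sinusoidal_positions {
    /// Hypothetical helper: builds a `max_len x model_dimension` sinusoidal table.
    pub fn sinusoidal_table(max_len: usize, model_dimension: usize) -> Vec<Vec<f32>> {
        (0..max_len)
            .map(|position| {
                (0..model_dimension)
                    .map(|feature| {
                        // Features 2i and 2i+1 share one frequency.
                        let pair = (feature / 2) as f32;
                        let exponent = 2.0 * pair / model_dimension as f32;
                        let angle = position as f32 / 10_000_f32.powf(exponent);
                        if feature % 2 == 0 { angle.sin() } else { angle.cos() }
                    })
                    .collect::<Vec<f32>>()
            })
            .collect()
    }
}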

impl Morphism<HiddenSequence, SequenceLogits> for TransformerReadout {
    fn name(&self) -> &'static str {
        "transformer_readout"
    }

    fn apply(&self, input: HiddenSequence) -> CtResult<SequenceLogits> {
        if input.model_dimension() != self.input_dimension {
            return Err(CtError::ShapeMismatch {
                op: "transformer readout",
                expected: format!("input dimension {}", self.input_dimension.value()),
                got: format!("input dimension {}", input.model_dimension().value()),
            });
        }

        let rows = input
            .rows()
            .iter()
            .map(|row| project_row(row.as_slice(), &self.weight, &self.bias))
            .collect::<Vec<_>>();

        SequenceLogits::new(rows)
    }
}

impl Morphism<HiddenSequence, QuerySequence> for HiddenToQuery {
    fn name(&self) -> &'static str {
        "hidden_to_query"
    }

    fn apply(&self, input: HiddenSequence) -> CtResult<QuerySequence> {
        QuerySequence::new(self.projection.project(&input)?)
    }
}

impl Morphism<HiddenSequence, KeySequence> for HiddenToKey {
    fn name(&self) -> &'static str {
        "hidden_to_key"
    }

    fn apply(&self, input: HiddenSequence) -> CtResult<KeySequence> {
        KeySequence::new(self.projection.project(&input)?)
    }
}

impl Morphism<HiddenSequence, ValueSequence> for HiddenToValue {
    fn name(&self) -> &'static str {
        "hidden_to_value"
    }

    fn apply(&self, input: HiddenSequence) -> CtResult<ValueSequence> {
        ValueSequence::new(self.projection.project(&input)?)
    }
}

/// Runs one unmasked self-attention head over the hidden sequence.
fn apply_self_attention_head(
    input: &HiddenSequence,
    head: &SelfAttentionHead,
) -> CtResult<AttentionOutput> {
    apply_self_attention_head_with_mask(input, head, None)
}

fn apply_self_attention_head_with_mask(
    input: &HiddenSequence,
    head: &SelfAttentionHead,
    mask: Option<&AttentionMask>,
) -> CtResult<AttentionOutput> {
    Ok(apply_self_attention_head_with_mask_cache(input, head, mask)?.output)
}

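/// Runs one self-attention head forward and keeps the queries, keys, values,
/// and attention weights that the backward pass reuses.
fn apply_self_attention_head_with_mask_cache(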
    input: &HiddenSequence,
    head: &SelfAttentionHead,
    mask: Option<&AttentionMask>,
) -> CtResult<AttentionHeadTrainingCache> {
    let queries = head.query_projection.apply(input.clone())?;
    let keys = head.key_projection.apply(input.clone())?;
    let values = head.value_projection.apply(input.clone())?;
    let scores = ScaledDotProductScores.apply(Product::new(queries.clone(), keys.clone()))?;
    let scores = if let Some(mask) = mask {
        MaskedAttentionScores.apply(Product::new(scores, mask.clone()))?
    } else {
        scores
    };
    let weights = AttentionSoftmax.apply(scores)?;
    let output = WeightedValueMixing.apply(Product::new(weights.clone(), values.clone()))?;

    Ok(AttentionHeadTrainingCache {
        queries,
        keys,
        values,
        weights,
        output,
    })
}
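
// A minimal sketch of the forward path in
// `apply_self_attention_head_with_mask_cache`, using plain rows and
// already-projected queries, keys, and values. The module and function are
// hypothetical. Scores are q.k / sqrt(d_k). The sketch assumes a `true` mask
// entry means "may attend" and gives masked keys weight 0. That reading is
// consistent with the gradient code writing `d_scores` only where the mask is
// `true`. It also assumes non-empty inputs and at least one allowed key per
// query row.
#[allow(dead_code)]
mod roadmap_sketch_masked_attention {
    /// Hypothetical helper: masked scaled dot-product attention for one head.
    pub fn attend(
        queries: &[Vec<f32>],
        keys: &[Vec<f32>],
        values: &[Vec<f32>],
        mask: &[Vec<bool>],
    ) -> Vec<Vec<f32>> {
        let scale = (queries[0].len() as f32).sqrt();
        queries
            .iter()
            .zip(mask)
            .map(|(query, allowed)| {
                let scores: Vec<f32> = keys
                    .iter()
                    .zip(allowed)
                    .map(|(key, &ok)| {
                        if ok {
                            query.iter().zip(key).map(|(q, k)| q * k).sum::<f32>() / scale
                        } else {
                            f32::NEG_INFINITY
                        }
                    })
                    .collect();
                // Subtract the row max before exponentiating for numerical stability.
                let max = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
                let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
                let total: f32 = exps.iter().sum();
                let mut output = vec![0.0_f32; values[0].len()];
                for (e, value_row) in exps.iter().zip(values) {
                    for (out, v) in output.iter_mut().zip(value_row) {
                        *out += (e / total) * v;
                    }
                }
                output
            })
            .collect()
    }
}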

impl Morphism<Product<HiddenSequence, HiddenSequence>, HiddenSequence> for ResidualConnection {
    fn name(&self) -> &'static str {
        "hidden_residual_connection"
    }

    fn apply(&self, input: Product<HiddenSequence, HiddenSequence>) -> CtResult<HiddenSequence> {
        let (left, right) = input.into_parts();

        if left.sequence_len() != right.sequence_len() {
            return Err(CtError::ShapeMismatch {
                op: "hidden residual connection",
                expected: format!("{} sequence rows", left.sequence_len().value()),
                got: format!("{} sequence rows", right.sequence_len().value()),
            });
        }

        if left.model_dimension() != right.model_dimension() {
            return Err(CtError::ShapeMismatch {
                op: "hidden residual connection",
                expected: format!("model dimension {}", left.model_dimension().value()),
                got: format!("model dimension {}", right.model_dimension().value()),
            });
        }

        let rows = left
            .rows()
            .iter()
            .zip(right.rows())
            .map(|(left, right)| add_rows(left.as_slice(), right.as_slice()))
            .collect::<Vec<_>>();

        HiddenSequence::new(rows)
    }
}

impl Morphism<HiddenSequence, HiddenSequence> for SingleHeadTransformerBlock {
    fn name(&self) -> &'static str {
        "single_head_transformer_block"
    }

    fn apply(&self, input: HiddenSequence) -> CtResult<HiddenSequence> {
        if input.model_dimension() != self.model_dimension {
            return Err(CtError::ShapeMismatch {
                op: "single-head block",
                expected: format!("model dimension {}", self.model_dimension.value()),
                got: format!("model dimension {}", input.model_dimension().value()),
            });
        }

        let head = SelfAttentionHead::new(
            self.query_projection.clone(),
            self.key_projection.clone(),
            self.value_projection.clone(),
        )?;
        let attention_output = apply_self_attention_head(&input, &head)?;
        let head_outputs = AttentionHeadOutputs::new(vec![attention_output])?;
        let multi_head_output = ConcatenateHeads.apply(head_outputs)?;
        let projected_attention = self.output_projection.apply(multi_head_output)?;
        let with_attention = ResidualConnection.apply(Product::new(input, projected_attention))?;
        let normalized_attention = self.attention_norm.apply(with_attention)?;
        let feed_forward_output = self.feed_forward.apply(normalized_attention.clone())?;
        let with_feed_forward =
            ResidualConnection.apply(Product::new(normalized_attention, feed_forward_output))?;

        self.feed_forward_norm.apply(with_feed_forward)
    }
}

impl Morphism<HiddenSequence, HiddenSequence> for MultiHeadTransformerBlock {
    fn name(&self) -> &'static str {
        "multi_head_transformer_block"
    }

    fn apply(&self, input: HiddenSequence) -> CtResult<HiddenSequence> {
        if input.model_dimension() != self.model_dimension {
            return Err(CtError::ShapeMismatch {
                op: "multi-head block",
                expected: format!("model dimension {}", self.model_dimension.value()),
                got: format!("model dimension {}", input.model_dimension().value()),
            });
        }

        let attention_outputs = self
            .heads
            .iter()
            .map(|head| apply_self_attention_head(&input, head))
            .collect::<CtResult<Vec<_>>>()?;
        let head_outputs = AttentionHeadOutputs::new(attention_outputs)?;
        let multi_head_output = ConcatenateHeads.apply(head_outputs)?;
        let projected_attention = self.output_projection.apply(multi_head_output)?;
        let with_attention = ResidualConnection.apply(Product::new(input, projected_attention))?;
        let normalized_attention = self.attention_norm.apply(with_attention)?;
        let feed_forward_output = self.feed_forward.apply(normalized_attention.clone())?;
        let with_feed_forward =
            ResidualConnection.apply(Product::new(normalized_attention, feed_forward_output))?;

        self.feed_forward_norm.apply(with_feed_forward)
    }
}
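
// A minimal sketch of the sublayer ordering shared by both block
// implementations above. Each sublayer is applied in post-norm order,
// norm(x + sublayer(x)): first with attention and `attention_norm`, then with
// the feed-forward layer and `feed_forward_norm`. The module, the `Rows` alias,
// and the functions are hypothetical. Closures stand in for the crate's
// morphisms.
#[allow(dead_code)]
mod roadmap_sketch_post_norm_block {
    pub type Rows = Vec<Vec<f32>>;

    /// Hypothetical helper: one "residual, then normalize" step.
    fn residual_then_norm<S, N>(x: Rows, sublayer: S, norm: N) -> Rows
    where
        S: Fn(&Rows) -> Rows,
        N: Fn(Rows) -> Rows,
    {
        let y = sublayer(&x);
        let summed: Rows = x
            .iter()
            .zip(&y)
            .map(|(a, b)| a.iter().zip(b).map(|(p, q)| p + q).collect::<Vec<f32>>())
            .collect();
        norm(summed)
    }

    /// Hypothetical helper: attention step followed by feed-forward step.
    pub fn post_norm_block<A, NA, F, NF>(
        x: Rows,
        attention: A,
        attention_norm: NA,
        feed_forward: F,
        feed_forward_norm: NF,
    ) -> Rows
    where
        A: Fn(&Rows) -> Rows,
        NA: Fn(Rows) -> Rows,
        F: Fn(&Rows) -> Rows,
        NF: Fn(Rows) -> Rows,
    {
        let after_attention = residual_then_norm(x, attention, attention_norm);
        residual_then_norm(after_attention, feed_forward, feed_forward_norm)
    }
}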

impl Morphism<Product<HiddenSequence, AttentionMask>, HiddenSequence>
    for MaskedMultiHeadTransformerBlock
{
    fn name(&self) -> &'static str {
        "masked_multi_head_transformer_block"
    }

    fn apply(&self, input: Product<HiddenSequence, AttentionMask>) -> CtResult<HiddenSequence> {
        let (hidden, mask) = input.into_parts();

        Ok(self.apply_with_training_cache(hidden, mask)?.output)
    }
}

impl Morphism<Product<HiddenSequence, AttentionMask>, SequenceLogits>
    for TinyTransformerParameters
{
    fn name(&self) -> &'static str {
        "tiny_transformer_parameters"
    }

    fn apply(&self, input: Product<HiddenSequence, AttentionMask>) -> CtResult<SequenceLogits> {
        let (hidden, mask) = input.into_parts();
        let positioned = self.positional_encoding.apply(hidden)?;
        let encoded = self.block.apply(Product::new(positioned, mask))?;

        self.readout.apply(encoded)
    }
}

impl Morphism<Product<HiddenSequence, AttentionMask>, SequenceLogits> for TransformerTrainingState {
    fn name(&self) -> &'static str {
        "transformer_training_state_forward"
    }

    fn apply(&self, input: Product<HiddenSequence, AttentionMask>) -> CtResult<SequenceLogits> {
        self.parameters.apply(input)
    }
}

/// Readout gradients summed over positions; `apply_to` divides by the position count.
#[derive(Debug, Clone, PartialEq)]
struct ReadoutGradients {
    weight: Vec<Vec<f32>>,
    bias: Vec<f32>,
}

impl ReadoutGradients {
    fn new(readout: &TransformerReadout) -> Self {
        Self {
            weight: vec![
                vec![0.0; readout.vocab_size().value()];
                readout.input_dimension().value()
            ],
            bias: vec![0.0; readout.vocab_size().value()],
        }
    }

    fn accumulate(
        &mut self,
        readout: &TransformerReadout,
        hidden_row: &[f32],
        dlogits: &[f32],
    ) -> Vec<f32> {
        let mut d_hidden = vec![0.0; readout.input_dimension().value()];

        for (vocab_id, dlogit) in dlogits.iter().copied().enumerate() {
            self.bias[vocab_id] += dlogit;

            for (feature, hidden_value) in hidden_row.iter().copied().enumerate() {
                self.weight[feature][vocab_id] += hidden_value * dlogit;
                d_hidden[feature] += readout.weight()[feature][vocab_id] * dlogit;
            }
        }

        d_hidden
    }

    fn apply_to(
        self,
        readout: &TransformerReadout,
        learning_rate: LearningRate,
        position_count: usize,
    ) -> CtResult<TransformerReadout> {
        if position_count == 0 {
            return Err(CtError::EmptyInput("readout gradient positions"));
        }

        let scale = learning_rate.value() / position_count as f32;
        let mut updated_weight = readout.weight().to_vec();
        let mut updated_bias = readout.bias().to_vec();

        for (row, grad_row) in updated_weight.iter_mut().zip(&self.weight) {
            for (weight, grad) in row.iter_mut().zip(grad_row) {
                *weight -= scale * grad;
            }
        }

        for (bias, grad) in updated_bias.iter_mut().zip(&self.bias) {
            *bias -= scale * grad;
        }

        TransformerReadout::new(updated_weight, updated_bias)
    }
}
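
// A minimal sketch of the rule `ReadoutGradients::accumulate` applies at one
// position. The module, struct, and function are hypothetical. For
// logits = h W + b with upstream gradient g:
//   dW[i][j] = h[i] * g[j],  db = g,  dh[i] = sum_j W[i][j] * g[j].
// The input gradient has to use W as it was before any update. The crate's
// split between `accumulate` and `apply_to` preserves that ordering.
#[allow(dead_code)]
mod roadmap_sketch_readout_gradients {
    pub struct Grads {
        pub weight: Vec<Vec<f32>>,
        pub bias: Vec<f32>,
        pub hidden: Vec<f32>,
    }

    /// Hypothetical helper: gradients of a linear readout for one hidden row.
    pub fn readout_backward(hidden: &[f32], weight: &[Vec<f32>], dlogits: &[f32]) -> Grads {
        // Outer product of the hidden row and the logit gradient.
        let weight_grad: Vec<Vec<f32>> = hidden
            .iter()
            .map(|h| dlogits.iter().map(|g| h * g).collect::<Vec<f32>>())
            .collect();
        // Each hidden feature collects its weight row dotted with the logit gradient.
        let hidden_grad: Vec<f32> = weight
            .iter()
            .map(|row| row.iter().zip(dlogits).map(|(w, g)| w * g).sum::<f32>())
            .collect();
        Grads { weight: weight_grad, bias: dlogits.to_vec(), hidden: hidden_grad }
    }
}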

/// Feed-forward gradients summed over rows; `apply_to` divides by the row count.
#[derive(Debug, Clone, PartialEq)]
struct FeedForwardGradients {
    first_weight: Vec<Vec<f32>>,
    first_bias: Vec<f32>,
    second_weight: Vec<Vec<f32>>,
    second_bias: Vec<f32>,
}

impl FeedForwardGradients {
    fn new(feed_forward: &PositionWiseFeedForward) -> Self {
        Self {
            first_weight: vec![
                vec![0.0; feed_forward.hidden_dimension().value()];
                feed_forward.input_dimension().value()
            ],
            first_bias: vec![0.0; feed_forward.hidden_dimension().value()],
            second_weight: vec![
                vec![0.0; feed_forward.output_dimension().value()];
                feed_forward.hidden_dimension().value()
            ],
            second_bias: vec![0.0; feed_forward.output_dimension().value()],
        }
    }

    fn accumulate(
        &mut self,
        feed_forward: &PositionWiseFeedForward,
        cache: &FeedForwardRowCache,
        d_output: &[f32],
    ) -> Vec<f32> {
        let mut d_activation = vec![0.0; feed_forward.hidden_dimension().value()];

        for (output_id, d_output_value) in d_output.iter().copied().enumerate() {
            self.second_bias[output_id] += d_output_value;

            for (hidden_id, activation_value) in cache.activation.iter().copied().enumerate() {
                self.second_weight[hidden_id][output_id] += activation_value * d_output_value;
                d_activation[hidden_id] +=
                    feed_forward.second_weight()[hidden_id][output_id] * d_output_value;
            }
        }

        let mut d_input = vec![0.0; feed_forward.input_dimension().value()];

        for (hidden_id, pre_activation_value) in cache.pre_activation.iter().copied().enumerate() {
            let d_pre_activation = if pre_activation_value > 0.0 {
                d_activation[hidden_id]
            } else {
                0.0
            };
            self.first_bias[hidden_id] += d_pre_activation;

            for (input_id, input_value) in cache.input.iter().copied().enumerate() {
                self.first_weight[input_id][hidden_id] += input_value * d_pre_activation;
                d_input[input_id] +=
                    feed_forward.first_weight()[input_id][hidden_id] * d_pre_activation;
            }
        }

        d_input
    }

    fn apply_to(
        self,
        feed_forward: &PositionWiseFeedForward,
        learning_rate: LearningRate,
        row_count: usize,
    ) -> CtResult<PositionWiseFeedForward> {
        if row_count == 0 {
            return Err(CtError::EmptyInput("feed-forward gradient rows"));
        }

        let scale = learning_rate.value() / row_count as f32;
        let mut updated_first_weight = feed_forward.first_weight().to_vec();
        let mut updated_first_bias = feed_forward.first_bias().to_vec();
        let mut updated_second_weight = feed_forward.second_weight().to_vec();
        let mut updated_second_bias = feed_forward.second_bias().to_vec();

        for (row, grad_row) in updated_first_weight.iter_mut().zip(&self.first_weight) {
            for (weight, grad) in row.iter_mut().zip(grad_row) {
                *weight -= scale * grad;
            }
        }

        for (bias, grad) in updated_first_bias.iter_mut().zip(&self.first_bias) {
            *bias -= scale * grad;
        }

        for (row, grad_row) in updated_second_weight.iter_mut().zip(&self.second_weight) {
            for (weight, grad) in row.iter_mut().zip(grad_row) {
                *weight -= scale * grad;
            }
        }

        for (bias, grad) in updated_second_bias.iter_mut().zip(&self.second_bias) {
            *bias -= scale * grad;
        }

        PositionWiseFeedForward::new(
            updated_first_weight,
            updated_first_bias,
            updated_second_weight,
            updated_second_bias,
        )
    }
}
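
// A minimal sketch of the ReLU step inside `FeedForwardGradients::accumulate`.
// The module and function are hypothetical. The gradient passes through where
// the cached pre-activation is strictly positive and is zeroed elsewhere. At
// exactly 0.0 this picks the subgradient 0, matching the `> 0.0` test in the
// crate code.
#[allow(dead_code)]
mod roadmap_sketch_relu_backward {
    /// Hypothetical helper: masks the activation gradient by the ReLU derivative.
    pub fn relu_backward(pre_activation: &[f32], d_activation: &[f32]) -> Vec<f32> {
        pre_activation
            .iter()
            .zip(d_activation)
            .map(|(&pre, &grad)| if pre > 0.0 { grad } else { 0.0 })
            .collect()
    }
}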

/// Attention output projection gradients summed over rows; `apply_to` divides by the row count.
#[derive(Debug, Clone, PartialEq)]
struct AttentionOutputProjectionGradients {
    weight: Vec<Vec<f32>>,
    bias: Vec<f32>,
}

impl AttentionOutputProjectionGradients {
    fn new(output_projection: &AttentionOutputProjection) -> Self {
        Self {
            weight: vec![
                vec![0.0; output_projection.output_dimension().value()];
                output_projection.input_dimension().value()
            ],
            bias: vec![0.0; output_projection.output_dimension().value()],
        }
    }

    fn accumulate(
        &mut self,
        output_projection: &AttentionOutputProjection,
        multi_head_row: &[f32],
        d_projected_attention: &[f32],
    ) -> Vec<f32> {
        let mut d_multi_head = vec![0.0; output_projection.input_dimension().value()];

        for (output_id, d_value) in d_projected_attention.iter().copied().enumerate() {
            self.bias[output_id] += d_value;

            for (input_id, input_value) in multi_head_row.iter().copied().enumerate() {
                self.weight[input_id][output_id] += input_value * d_value;
                d_multi_head[input_id] += output_projection.weight()[input_id][output_id] * d_value;
            }
        }

        d_multi_head
    }

    fn apply_to(
        self,
        output_projection: &AttentionOutputProjection,
        learning_rate: LearningRate,
        row_count: usize,
    ) -> CtResult<AttentionOutputProjection> {
        if row_count == 0 {
            return Err(CtError::EmptyInput(
                "attention output projection gradient rows",
            ));
        }

        let scale = learning_rate.value() / row_count as f32;
        let mut updated_weight = output_projection.weight().to_vec();
        let mut updated_bias = output_projection.bias().to_vec();

        for (row, grad_row) in updated_weight.iter_mut().zip(&self.weight) {
            for (weight, grad) in row.iter_mut().zip(grad_row) {
                *weight -= scale * grad;
            }
        }

        for (bias, grad) in updated_bias.iter_mut().zip(&self.bias) {
            *bias -= scale * grad;
        }

        AttentionOutputProjection::new(updated_weight, updated_bias)
    }
}

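/// Query, key, or value projection gradients summed over rows.
///
/// `accumulate` returns no input gradient, so these projections do not send
/// gradient back into the hidden sequence.
#[derive(Debug, Clone, PartialEq)]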
struct HiddenProjectionGradients {
    weight: Vec<Vec<f32>>,
    bias: Vec<f32>,
}

impl HiddenProjectionGradients {
    fn new(projection: &HiddenProjection) -> Self {
        Self {
            weight: vec![
                vec![0.0; projection.head_dimension().value()];
                projection.input_dimension().value()
            ],
            bias: vec![0.0; projection.head_dimension().value()],
        }
    }

    fn accumulate(
        &mut self,
        projection: &HiddenProjection,
        hidden_row: &[f32],
        d_output: &[f32],
    ) -> CtResult<()> {
        if hidden_row.len() != projection.input_dimension().value() {
            return Err(CtError::ShapeMismatch {
                op: projection.op,
                expected: format!("input dimension {}", projection.input_dimension().value()),
                got: format!("input dimension {}", hidden_row.len()),
            });
        }

        if d_output.len() != projection.head_dimension().value() {
            return Err(CtError::ShapeMismatch {
                op: projection.op,
                expected: format!("output dimension {}", projection.head_dimension().value()),
                got: format!("output dimension {}", d_output.len()),
            });
        }

        for (output_id, d_output_value) in d_output.iter().copied().enumerate() {
            self.bias[output_id] += d_output_value;

            for (input_id, input_value) in hidden_row.iter().copied().enumerate() {
                self.weight[input_id][output_id] += input_value * d_output_value;
            }
        }

        Ok(())
    }

    fn updated_parts(
        self,
        projection: &HiddenProjection,
        learning_rate: LearningRate,
        row_count: usize,
    ) -> CtResult<(Vec<Vec<f32>>, Vec<f32>)> {
        if row_count == 0 {
            return Err(CtError::EmptyInput("hidden projection gradient rows"));
        }

        let scale = learning_rate.value() / row_count as f32;
        let mut updated_weight = projection.weight().to_vec();
        let mut updated_bias = projection.bias().to_vec();

        for (row, grad_row) in updated_weight.iter_mut().zip(&self.weight) {
            for (weight, grad) in row.iter_mut().zip(grad_row) {
                *weight -= scale * grad;
            }
        }

        for (bias, grad) in updated_bias.iter_mut().zip(&self.bias) {
            *bias -= scale * grad;
        }

        Ok((updated_weight, updated_bias))
    }
}

/// Gradients for the query, key, and value projections of one attention head.
#[derive(Debug, Clone, PartialEq)]
struct AttentionHeadGradients {
    query: HiddenProjectionGradients,
    key: HiddenProjectionGradients,
    value: HiddenProjectionGradients,
}

impl AttentionHeadGradients {
    fn new(head: &SelfAttentionHead) -> Self {
        Self {
            query: HiddenProjectionGradients::new(&head.query_projection.projection),
            key: HiddenProjectionGradients::new(&head.key_projection.projection),
            value: HiddenProjectionGradients::new(&head.value_projection.projection),
        }
    }

    fn accumulate(
        &mut self,
        head: &SelfAttentionHead,
        input: &HiddenSequence,
        cache: &AttentionHeadTrainingCache,
        mask: &AttentionMask,
        d_output_rows: &[Vec<f32>],
    ) -> CtResult<()> {
        let sequence_len = input.sequence_len().value();
        let value_dimension = head.value_dimension().value();
        let query_key_dimension = head.query_key_dimension().value();

        if d_output_rows.len() != sequence_len {
            return Err(CtError::ShapeMismatch {
                op: "attention head gradients",
                expected: format!("{sequence_len} output rows"),
                got: format!("{} output rows", d_output_rows.len()),
            });
        }

        if mask.query_len().value() != sequence_len || mask.key_len().value() != sequence_len {
            return Err(CtError::ShapeMismatch {
                op: "attention head gradient mask",
                expected: format!("{sequence_len} query rows x {sequence_len} key columns"),
                got: format!(
                    "{} query rows x {} key columns",
                    mask.query_len().value(),
                    mask.key_len().value()
                ),
            });
        }

        let mut d_weights = vec![vec![0.0; sequence_len]; sequence_len];
        let mut d_values = vec![vec![0.0; value_dimension]; sequence_len];

        for (query_id, d_output) in d_output_rows.iter().enumerate() {
            if d_output.len() != value_dimension {
                return Err(CtError::ShapeMismatch {
                    op: "attention head output gradient",
                    expected: format!("value dimension {value_dimension}"),
                    got: format!("value dimension {}", d_output.len()),
                });
            }

            for key_id in 0..sequence_len {
                let value_row = cache.values.rows()[key_id].as_slice();
                let weight = cache.weights.rows()[query_id].as_slice()[key_id];

                for value_id in 0..value_dimension {
                    d_weights[query_id][key_id] += d_output[value_id] * value_row[value_id];
                    d_values[key_id][value_id] += weight * d_output[value_id];
                }
            }
        }

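        // Softmax backward per query row:
        // d_score[k] = w[k] * (d_weight[k] - sum_j d_weight[j] * w[j]).
        // Masked-out scores keep a zero gradient.
        let mut d_scores = vec![vec![0.0; sequence_len]; sequence_len];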

        for query_id in 0..sequence_len {
            let weight_row = cache.weights.rows()[query_id].as_slice();
            let row_dot = d_weights[query_id]
                .iter()
                .zip(weight_row)
                .map(|(grad, weight)| grad * weight)
                .sum::<f32>();

            for key_id in 0..sequence_len {
                if mask.rows()[query_id][key_id] {
                    d_scores[query_id][key_id] =
                        weight_row[key_id] * (d_weights[query_id][key_id] - row_dot);
                }
            }
        }

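        // The forward scores were divided by sqrt(d_k), so the chain rule divides
        // the score gradient by the same factor before reaching queries and keys.
        let score_scale = (query_key_dimension as f32).sqrt();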
        let mut d_queries = vec![vec![0.0; query_key_dimension]; sequence_len];
        let mut d_keys = vec![vec![0.0; query_key_dimension]; sequence_len];

        for query_id in 0..sequence_len {
            let query_row = cache.queries.rows()[query_id].as_slice();

            for key_id in 0..sequence_len {
                let score_gradient = d_scores[query_id][key_id] / score_scale;
                let key_row = cache.keys.rows()[key_id].as_slice();

                for feature in 0..query_key_dimension {
                    d_queries[query_id][feature] += score_gradient * key_row[feature];
                    d_keys[key_id][feature] += score_gradient * query_row[feature];
                }
            }
        }

        for position in 0..sequence_len {
            let hidden_row = input.rows()[position].as_slice();

            self.query.accumulate(
                &head.query_projection.projection,
                hidden_row,
                &d_queries[position],
            )?;
            self.key.accumulate(
                &head.key_projection.projection,
                hidden_row,
                &d_keys[position],
            )?;
            self.value.accumulate(
                &head.value_projection.projection,
                hidden_row,
                &d_values[position],
            )?;
        }

        Ok(())
    }

    fn apply_to(
        self,
        head: &SelfAttentionHead,
        learning_rate: LearningRate,
        row_count: usize,
    ) -> CtResult<SelfAttentionHead> {
        let (query_weight, query_bias) = self.query.updated_parts(
            &head.query_projection.projection,
            learning_rate,
            row_count,
        )?;
        let (key_weight, key_bias) =
            self.key
                .updated_parts(&head.key_projection.projection, learning_rate, row_count)?;
        let (value_weight, value_bias) = self.value.updated_parts(
            &head.value_projection.projection,
            learning_rate,
            row_count,
        )?;

        SelfAttentionHead::new(
            HiddenToQuery::new(query_weight, query_bias)?,
            HiddenToKey::new(key_weight, key_bias)?,
            HiddenToValue::new(value_weight, value_bias)?,
        )
    }
}
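
// A minimal sketch that checks the softmax backward rule used for `d_scores`
// in `AttentionHeadGradients::accumulate`. The module and functions are
// hypothetical. If w = softmax(s) and g = dL/dw, then
// dL/ds_k = w_k * (g_k - sum_j g_j w_j). `finite_difference` estimates the
// same derivative numerically for the scalar loss L = sum_j g_j w_j, using a
// central difference.
#[allow(dead_code)]
mod roadmap_sketch_softmax_backward {
    /// Hypothetical helper: numerically stable softmax of one score row.
    pub fn softmax(scores: &[f32]) -> Vec<f32> {
        let max = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
        let total: f32 = exps.iter().sum();
        exps.iter().map(|e| e / total).collect()
    }

    /// Hypothetical helper: analytic softmax backward, as in the crate code.
    pub fn softmax_backward(weights: &[f32], d_weights: &[f32]) -> Vec<f32> {
        let row_dot: f32 = d_weights.iter().zip(weights).map(|(g, w)| g * w).sum();
        weights.iter().zip(d_weights).map(|(w, g)| w * (g - row_dot)).collect()
    }

    /// Hypothetical helper: central-difference estimate of dL/ds_k.
    pub fn finite_difference(scores: &[f32], d_weights: &[f32], k: usize, eps: f32) -> f32 {
        let loss = |s: &[f32]| -> f32 {
            softmax(s).iter().zip(d_weights).map(|(w, g)| w * g).sum()
        };
        let mut plus = scores.to_vec();
        let mut minus = scores.to_vec();
        plus[k] += eps;
        minus[k] -= eps;
        (loss(&plus) - loss(&minus)) / (2.0 * eps)
    }
}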

/// Layer-norm scale and shift gradients summed over rows; `apply_to` divides by the row count.
#[derive(Debug, Clone, PartialEq)]
struct LayerNormGradients {
    scale: Vec<f32>,
    shift: Vec<f32>,
}

impl LayerNormGradients {
    fn new(parameters: &LayerNormParameters) -> Self {
        Self {
            scale: vec![0.0; parameters.model_dimension().value()],
            shift: vec![0.0; parameters.model_dimension().value()],
        }
    }

    fn accumulate(
        &mut self,
        d_output: &[f32],
        input: &[f32],
        parameters: &LayerNormParameters,
    ) -> Vec<f32> {
        let stats = layer_norm_stats(input, parameters);

        for (feature, (grad, normalized_value)) in
            d_output.iter().zip(&stats.normalized).enumerate()
        {
            self.scale[feature] += grad * normalized_value;
            self.shift[feature] += grad;
        }

        layer_norm_backward_from_stats(d_output, &stats, parameters)
    }

    fn apply_to(
        self,
        parameters: &LayerNormParameters,
        learning_rate: LearningRate,
        row_count: usize,
    ) -> CtResult<LayerNormParameters> {
        if row_count == 0 {
            return Err(CtError::EmptyInput("layer norm gradient rows"));
        }

        let scale = learning_rate.value() / row_count as f32;
        let mut updated_scale = parameters.scale().to_vec();
        let mut updated_shift = parameters.shift().to_vec();

        for (value, grad) in updated_scale.iter_mut().zip(&self.scale) {
            *value -= scale * grad;
        }

        for (value, grad) in updated_shift.iter_mut().zip(&self.shift) {
            *value -= scale * grad;
        }

        LayerNormParameters::new(updated_scale, updated_shift, parameters.epsilon())
    }
}
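
// A minimal sketch of the update that every `apply_to` and `updated_parts` in
// this file performs. The module and function are hypothetical, and a
// `String` error stands in for `CtError::EmptyInput`. Gradients are summed
// over rows during `accumulate` and divided by the row count here. The step is
// therefore plain SGD on the mean per-row gradient:
//   p <- p - (lr / n) * sum_rows(g).
#[allow(dead_code)]
mod roadmap_sketch_mean_sgd_step {
    /// Hypothetical helper: one averaged gradient-descent step.
    pub fn sgd_step(
        parameters: &[f32],
        summed_gradient: &[f32],
        learning_rate: f32,
        row_count: usize,
    ) -> Result<Vec<f32>, String> {
        if row_count == 0 {
            return Err("cannot average over zero rows".to_string());
        }
        let scale = learning_rate / row_count as f32;
        Ok(parameters
            .iter()
            .zip(summed_gradient)
            .map(|(p, g)| p - scale * g)
            .collect())
    }
}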

fn feed_forward_with_cache(
    feed_forward: &PositionWiseFeedForward,
    input: &HiddenSequence,
) -> CtResult<(HiddenSequence, Vec<FeedForwardRowCache>)> {
    if input.model_dimension() != feed_forward.input_dimension() {
        return Err(CtError::ShapeMismatch {
            op: "position-wise feed-forward",
            expected: format!("input dimension {}", feed_forward.input_dimension().value()),
            got: format!("input dimension {}", input.model_dimension().value()),
        });
    }

    let mut output_rows = Vec::with_capacity(input.rows().len());
    let mut cache_rows = Vec::with_capacity(input.rows().len());

    for row in input.rows() {
        let input_row = row.as_slice().to_vec();
        let pre_activation = project_row(
            &input_row,
            feed_forward.first_weight(),
            feed_forward.first_bias(),
        );
        let activation = pre_activation
            .iter()
            .map(|value| value.max(0.0))
            .collect::<Vec<_>>();
        let output = project_row(
            &activation,
            feed_forward.second_weight(),
            feed_forward.second_bias(),
        );

        output_rows.push(output.clone());
        cache_rows.push(FeedForwardRowCache {
            input: input_row,
            pre_activation,
            activation,
            output,
        });
    }

    Ok((HiddenSequence::new(output_rows)?, cache_rows))
}

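/// Gradient of softmax cross-entropy with respect to the logits:
/// `softmax(logits) - one_hot(target_index)`.
fn softmax_cross_entropy_logits_gradient(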
    logits: &Logits,
    target_index: usize,
    vocab_size: usize,
) -> CtResult<Vec<f32>> {
    if target_index >= vocab_size {
        return Err(CtError::OutOfRange {
            kind: "sequence target",
            index: target_index,
            limit: vocab_size,
        });
    }

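    let probabilities = Softmax.apply(logits.clone())?;
    let mut dlogits = probabilities.as_slice().to_vec();

    // Guard the index below: the logits must cover the whole vocabulary.
    if dlogits.len() != vocab_size {
        return Err(CtError::ShapeMismatch {
            op: "softmax cross-entropy gradient",
            expected: format!("{vocab_size} logits"),
            got: format!("{} logits", dlogits.len()),
        });
    }

    dlogits[target_index] -= 1.0;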

    Ok(dlogits)
}

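/// Per-row layer-norm statistics from the forward pass, reused by
/// `layer_norm_backward_from_stats`.
#[derive(Debug, Clone, PartialEq)]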
struct LayerNormStats {
    dimension: f32,
    inverse_std: f32,
    normalized: Vec<f32>,
}

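/// Computes a row's normalized values `(x - mean) / sqrt(variance + epsilon)`
/// (before scale and shift), using the biased (population) variance.
fn layer_norm_stats(input: &[f32], parameters: &LayerNormParameters) -> LayerNormStats {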
    let dimension = input.len() as f32;
    let mean = input.iter().sum::<f32>() / dimension;
    let variance = input
        .iter()
        .map(|value| {
            let centered = value - mean;
            centered * centered
        })
        .sum::<f32>()
        / dimension;
    let inverse_std = 1.0 / (variance + parameters.epsilon().value()).sqrt();
    let normalized = input
        .iter()
        .map(|value| (value - mean) * inverse_std)
        .collect::<Vec<_>>();

    LayerNormStats {
        dimension,
        inverse_std,
        normalized,
    }
}

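/// Backpropagates one row through layer norm to its input. With
/// `g = d_output * scale`, `n = dimension`, and `x_hat = normalized`:
/// `d_input = inv_std / n * (n * g - sum(g) - x_hat * sum(g * x_hat))`.
fn layer_norm_backward_from_stats(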
    d_output: &[f32],
    stats: &LayerNormStats,
    parameters: &LayerNormParameters,
) -> Vec<f32> {
    let d_normalized = d_output
        .iter()
        .zip(parameters.scale())
        .map(|(grad, scale)| grad * scale)
        .collect::<Vec<_>>();
    let sum_d_normalized = d_normalized.iter().sum::<f32>();
    let sum_d_normalized_times_normalized = d_normalized
        .iter()
        .zip(&stats.normalized)
        .map(|(grad, normalized_value)| grad * normalized_value)
        .sum::<f32>();

    d_normalized
        .iter()
        .zip(&stats.normalized)
        .map(|(grad, normalized_value)| {
            (stats.dimension * grad
                - sum_d_normalized
                - normalized_value * sum_d_normalized_times_normalized)
                * stats.inverse_std
                / stats.dimension
        })
        .collect()
}

fn validate_projection_input(
    op: &'static str,
    expected: ModelDimension,
    got: ModelDimension,
) -> CtResult<()> {
    if expected != got {
        return Err(CtError::ShapeMismatch {
            op,
            expected: format!("model dimension {}", expected.value()),
            got: format!("model dimension {}", got.value()),
        });
    }

    Ok(())
}

/// Element-wise sum of two rows. Callers must check that the lengths match,
/// because `zip` stops at the shorter row.
fn add_rows(left: &[f32], right: &[f32]) -> Vec<f32> {
    left.iter().zip(right).map(|(l, r)| l + r).collect()
}

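/// Checks that a linear layer stored as `weight[input][output]` plus
/// `bias[output]` is non-empty, rectangular, and finite, and returns its
/// `(input, output)` dimensions. Non-finite values are reported as `ShapeMismatch`.
fn validate_linear_parts(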
    op: &'static str,
    weight: &[Vec<f32>],
    bias: &[f32],
) -> CtResult<(ModelDimension, ModelDimension)> {
    if weight.is_empty() {
        return Err(CtError::EmptyInput("linear weight"));
    }

    if bias.is_empty() {
        return Err(CtError::EmptyInput("linear bias"));
    }

    let output_dimension = bias.len();

    if bias.iter().any(|value| !value.is_finite()) {
        return Err(CtError::ShapeMismatch {
            op,
            expected: "finite bias values".to_string(),
            got: "non-finite bias value".to_string(),
        });
    }

    for row in weight {
        if row.len() != output_dimension {
            return Err(CtError::ShapeMismatch {
                op,
                expected: format!("weight rows have {output_dimension} columns"),
                got: format!("weight row with {} columns", row.len()),
            });
        }

        if row.iter().any(|value| !value.is_finite()) {
            return Err(CtError::ShapeMismatch {
                op,
                expected: "finite weight values".to_string(),
                got: "non-finite weight value".to_string(),
            });
        }
    }

    Ok((
        ModelDimension::new(weight.len())?,
        ModelDimension::new(output_dimension)?,
    ))
}

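/// Layer-norm forward pass for one row:
/// `(x - mean) / sqrt(variance + epsilon) * scale + shift`.
fn normalize_row(input: &[f32], parameters: &LayerNormParameters) -> Vec<f32> {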
    let mean = input.iter().sum::<f32>() / input.len() as f32;
    let variance = input
        .iter()
        .map(|value| {
            let centered = value - mean;
            centered * centered
        })
        .sum::<f32>()
        / input.len() as f32;
    let denominator = (variance + parameters.epsilon.value()).sqrt();

    input
        .iter()
        .zip(parameters.scale.iter().zip(&parameters.shift))
        .map(|(value, (scale, shift))| ((value - mean) / denominator) * scale + shift)
        .collect()
}

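/// Computes `input * weight + bias` with `input` as a row vector and `weight`
/// laid out as `[input_dimension][output_dimension]`. Callers validate shapes
/// first; an input longer than `weight` panics on indexing.
fn project_row(input: &[f32], weight: &[Vec<f32>], bias: &[f32]) -> Vec<f32> {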
    let mut output = bias.to_vec();

    for (feature, input_value) in input.iter().enumerate() {
        for (output_value, weight_value) in output.iter_mut().zip(&weight[feature]) {
            *output_value += input_value * weight_value;
        }
    }

    output
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn scaled_dot_product_scores_build_query_by_key_rows() -> CtResult<()> {
        let queries = QuerySequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
        let keys = KeySequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]])?;

        let scores = ScaledDotProductScores.apply(Product::new(queries, keys))?;

        assert_eq!(scores.query_len().value(), 2);
        assert_eq!(scores.key_len().value(), 3);
        assert!(crate::domain::approx_eq(
            scores.rows()[0].as_slice()[0],
            std::f32::consts::FRAC_1_SQRT_2,
            1e-4
        ));
        assert!(crate::domain::approx_eq(
            scores.rows()[0].as_slice()[1],
            0.0,
            1e-4
        ));
        assert!(crate::domain::approx_eq(
            scores.rows()[1].as_slice()[2],
            std::f32::consts::FRAC_1_SQRT_2,
            1e-4
        ));

        Ok(())
    }

    #[test]
    fn weighted_value_mixing_builds_one_output_per_query() -> CtResult<()> {
        let queries = QuerySequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
        let keys = KeySequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]])?;
        let values = ValueSequence::new(vec![vec![1.0, 10.0], vec![2.0, 20.0], vec![3.0, 30.0]])?;

        let scores = ScaledDotProductScores.apply(Product::new(queries, keys))?;
        let weights = AttentionSoftmax.apply(scores)?;
        let output = WeightedValueMixing.apply(Product::new(weights, values))?;

        assert_eq!(output.sequence_len().value(), 2);
        assert_eq!(output.head_dimension().value(), 2);
        assert!(crate::domain::approx_eq(
            output.rows()[0].as_slice()[0],
            2.0,
            1e-4
        ));
        assert!(crate::domain::approx_eq(
            output.rows()[0].as_slice()[1],
            20.0,
            1e-4
        ));

        Ok(())
    }

    #[test]
    fn concatenate_heads_preserves_sequence_and_concatenates_features() -> CtResult<()> {
        let head_a = AttentionOutput::new(vec![vec![1.0, 10.0], vec![2.0, 20.0]])?;
        let head_b = AttentionOutput::new(vec![vec![3.0, 30.0], vec![4.0, 40.0]])?;

        let heads = AttentionHeadOutputs::new(vec![head_a, head_b])?;
        let output = ConcatenateHeads.apply(heads)?;

        assert_eq!(output.sequence_len().value(), 2);
        assert_eq!(output.head_count().value(), 2);
        assert_eq!(output.head_dimension().value(), 2);
        assert_eq!(output.model_dimension().value(), 4);
        assert_eq!(output.rows()[0].as_slice(), &[1.0, 10.0, 3.0, 30.0]);
        assert_eq!(output.rows()[1].as_slice(), &[2.0, 20.0, 4.0, 40.0]);

        Ok(())
    }

    #[test]
    fn attention_output_projection_maps_multi_head_rows() -> CtResult<()> {
        let head_a = AttentionOutput::new(vec![vec![1.0, 10.0], vec![2.0, 20.0]])?;
        let head_b = AttentionOutput::new(vec![vec![3.0, 30.0], vec![4.0, 40.0]])?;
        let multi_head =
            ConcatenateHeads.apply(AttentionHeadOutputs::new(vec![head_a, head_b])?)?;
        let projection = AttentionOutputProjection::new(
            vec![
                vec![1.0, 0.0],
                vec![0.0, 0.1],
                vec![0.5, 0.0],
                vec![0.0, 0.01],
            ],
            vec![0.0, 1.0],
        )?;

        let output = projection.apply(multi_head)?;

        assert_eq!(output.sequence_len().value(), 2);
        assert_eq!(output.model_dimension().value(), 2);
        assert!(crate::domain::approx_eq(
            output.rows()[0].as_slice()[0],
            2.5,
            1e-4
        ));
        assert!(crate::domain::approx_eq(
            output.rows()[0].as_slice()[1],
            2.3,
            1e-4
        ));
        assert!(crate::domain::approx_eq(
            output.rows()[1].as_slice()[0],
            4.0,
            1e-4
        ));
        assert!(crate::domain::approx_eq(
            output.rows()[1].as_slice()[1],
            3.4,
            1e-4
        ));

        Ok(())
    }

    #[test]
    fn residual_connection_adds_matching_sequences() -> CtResult<()> {
        let hidden = HiddenSequence::new(vec![vec![1.0, 2.0], vec![3.0, 4.0]])?;
        let sublayer_output = ProjectedAttentionOutput::new(vec![vec![0.5, 1.5], vec![2.5, 3.5]])?;

        let output = ResidualConnection.apply(Product::new(hidden, sublayer_output))?;

        assert_eq!(output.sequence_len().value(), 2);
        assert_eq!(output.model_dimension().value(), 2);
        assert_eq!(output.rows()[0].as_slice(), &[1.5, 3.5]);
        assert_eq!(output.rows()[1].as_slice(), &[5.5, 7.5]);

        Ok(())
    }

    #[test]
    fn layer_normalization_preserves_shape_and_normalizes_each_row() -> CtResult<()> {
        let hidden = HiddenSequence::new(vec![vec![1.0, 3.0], vec![2.0, 4.0]])?;
        let norm = LayerNormalization::new(LayerNormParameters::identity(ModelDimension::new(2)?));

        let output = norm.apply(hidden)?;

        assert_eq!(output.sequence_len().value(), 2);
        assert_eq!(output.model_dimension().value(), 2);
        assert!(crate::domain::approx_eq(
            output.rows()[0].as_slice()[0],
            -0.999995,
            1e-4
        ));
        assert!(crate::domain::approx_eq(
            output.rows()[0].as_slice()[1],
            0.999995,
            1e-4
        ));
        assert!(crate::domain::approx_eq(
            output.rows()[1].as_slice()[0],
            -0.999995,
            1e-4
        ));
        assert!(crate::domain::approx_eq(
            output.rows()[1].as_slice()[1],
            0.999995,
            1e-4
        ));

        Ok(())
    }

    #[test]
    fn layer_normalization_applies_scale_and_shift() -> CtResult<()> {
        let hidden = HiddenSequence::new(vec![vec![1.0, 3.0]])?;
        let params = LayerNormParameters::new(
            vec![2.0, 0.5],
            vec![1.0, -1.0],
            NormalizationEpsilon::new(1e-5)?,
        )?;
        let norm = LayerNormalization::new(params);

        let output = norm.apply(hidden)?;

        assert!(crate::domain::approx_eq(
            output.rows()[0].as_slice()[0],
            -0.99999,
            1e-4
        ));
        assert!(crate::domain::approx_eq(
            output.rows()[0].as_slice()[1],
            -0.5000025,
            1e-4
        ));

        Ok(())
    }
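
    // Finite-difference check for the layer-norm backward rule.
    //
    // A std-only sketch; the helpers below are hypothetical test utilities, not
    // part of this crate's API. They redo layer norm on plain `f64` slices (f64
    // keeps central differences accurate), apply the same closed form as
    // `layer_norm_backward_from_stats`,
    //     d_input = inv_std / n * (n * g - sum(g) - x_hat * sum(g * x_hat)),
    // with `g = d_output * scale`, and compare it against numerical derivatives
    // of the scalar loss L(x) = sum(d_output * layer_norm(x)).
    fn reference_layer_norm(x: &[f64], scale: &[f64], shift: &[f64], eps: f64) -> Vec<f64> {
        let n = x.len() as f64;
        let mean = x.iter().sum::<f64>() / n;
        let variance = x.iter().map(|v| (v - mean) * (v - mean)).sum::<f64>() / n;
        let inverse_std = 1.0 / (variance + eps).sqrt();
        x.iter()
            .zip(scale.iter().zip(shift))
            .map(|(v, (s, b))| (v - mean) * inverse_std * s + b)
            .collect()
    }

    fn reference_layer_norm_input_gradient(
        x: &[f64],
        scale: &[f64],
        d_output: &[f64],
        eps: f64,
    ) -> Vec<f64> {
        let n = x.len() as f64;
        let mean = x.iter().sum::<f64>() / n;
        let variance = x.iter().map(|v| (v - mean) * (v - mean)).sum::<f64>() / n;
        let inverse_std = 1.0 / (variance + eps).sqrt();
        let x_hat: Vec<f64> = x.iter().map(|v| (v - mean) * inverse_std).collect();
        let g: Vec<f64> = d_output.iter().zip(scale).map(|(d, s)| d * s).collect();
        let sum_g: f64 = g.iter().sum();
        let sum_g_x_hat: f64 = g.iter().zip(&x_hat).map(|(gi, h)| gi * h).sum();
        g.iter()
            .zip(&x_hat)
            .map(|(gi, h)| inverse_std / n * (n * gi - sum_g - h * sum_g_x_hat))
            .collect()
    }

    fn layer_norm_gradient_max_error(
        x: &[f64],
        scale: &[f64],
        shift: &[f64],
        d_output: &[f64],
    ) -> f64 {
        let eps = 1e-5;
        let analytic = reference_layer_norm_input_gradient(x, scale, d_output, eps);
        // Scalar loss whose gradient with respect to `x` is the input gradient.
        let loss = |input: &[f64]| -> f64 {
            reference_layer_norm(input, scale, shift, eps)
                .iter()
                .zip(d_output)
                .map(|(y, d)| y * d)
                .sum::<f64>()
        };
        let step = 1e-5;
        let mut worst = 0.0_f64;
        for i in 0..x.len() {
            let mut plus = x.to_vec();
            let mut minus = x.to_vec();
            plus[i] += step;
            minus[i] -= step;
            let numeric = (loss(plus.as_slice()) - loss(minus.as_slice())) / (2.0 * step);
            worst = worst.max((numeric - analytic[i]).abs());
        }
        worst
    }

    #[test]
    fn layer_norm_backward_rule_matches_finite_differences() {
        let error = layer_norm_gradient_max_error(
            &[0.5, -1.0, 2.0, 0.25],
            &[1.5, 0.5, -1.0, 2.0],
            &[0.1, 0.0, -0.2, 0.3],
            &[0.3, -0.7, 1.1, 0.2],
        );
        assert!(error < 1e-6, "max gradient error {error}");
    }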

    #[test]
    fn position_wise_feed_forward_maps_each_row_and_preserves_shape() -> CtResult<()> {
        let hidden = HiddenSequence::new(vec![vec![1.0, 2.0], vec![3.0, 4.0]])?;
        let feed_forward = PositionWiseFeedForward::new(
            vec![vec![1.0, -1.0, 0.5], vec![0.0, 1.0, 0.5]],
            vec![0.0, 0.0, 0.0],
            vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![0.5, 0.5]],
            vec![0.0, 0.0],
        )?;

        let output = feed_forward.apply(hidden)?;

        assert_eq!(output.sequence_len().value(), 2);
        assert_eq!(output.model_dimension().value(), 2);
        assert!(crate::domain::approx_eq(
            output.rows()[0].as_slice()[0],
            1.75,
            1e-4
        ));
        assert!(crate::domain::approx_eq(
            output.rows()[0].as_slice()[1],
            1.75,
            1e-4
        ));
        assert!(crate::domain::approx_eq(
            output.rows()[1].as_slice()[0],
            4.75,
            1e-4
        ));
        assert!(crate::domain::approx_eq(
            output.rows()[1].as_slice()[1],
            2.75,
            1e-4
        ));

        Ok(())
    }

    #[test]
    fn attention_mask_removes_disallowed_positions_before_softmax() -> CtResult<()> {
        let scores = AttentionScores::new(vec![vec![2.0, 1.0, 2.0]])?;
        let mask = AttentionMask::new(vec![vec![true, false, true]])?;
        let masked_scores = MaskedAttentionScores.apply(Product::new(scores, mask))?;
        let weights = AttentionSoftmax.apply(masked_scores)?;

        assert!(crate::domain::approx_eq(
            weights.rows()[0].as_slice()[0],
            0.5,
            1e-4
        ));
        assert!(crate::domain::approx_eq(
            weights.rows()[0].as_slice()[1],
            0.0,
            1e-4
        ));
        assert!(crate::domain::approx_eq(
            weights.rows()[0].as_slice()[2],
            0.5,
            1e-4
        ));

        Ok(())
    }
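
    // Masked softmax over one score row.
    //
    // A std-only sketch that follows the usual convention of treating disallowed
    // keys as negative infinity, so they receive exactly zero weight. It does not
    // claim to reproduce `MaskedAttentionScores` line for line, and
    // `masked_softmax_row` is a hypothetical helper. Subtracting the largest
    // allowed score before `exp` avoids overflow. A row with no allowed key has
    // no valid distribution, matching the `AttentionMask` rule that rejects such rows.
    fn masked_softmax_row(scores: &[f32], allowed: &[bool]) -> Option<Vec<f32>> {
        if scores.len() != allowed.len() {
            return None;
        }

        let max_allowed = scores
            .iter()
            .zip(allowed)
            .filter(|(_, keep)| **keep)
            .map(|(score, _)| *score)
            .fold(f32::NEG_INFINITY, f32::max);
        if max_allowed == f32::NEG_INFINITY {
            return None;
        }

        // Disallowed keys contribute exp(-inf) = 0 to the normalizer.
        let exps: Vec<f32> = scores
            .iter()
            .zip(allowed)
            .map(|(score, keep)| if *keep { (score - max_allowed).exp() } else { 0.0 })
            .collect();
        let total: f32 = exps.iter().sum();
        Some(exps.iter().map(|value| value / total).collect())
    }

    #[test]
    fn masked_softmax_row_sketch_zeroes_disallowed_keys() {
        assert_eq!(
            masked_softmax_row(&[2.0, 1.0, 2.0], &[true, false, true]),
            Some(vec![0.5, 0.0, 0.5])
        );
        assert_eq!(masked_softmax_row(&[1.0, 2.0], &[false, false]), None);
    }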

    #[test]
    fn attention_softmax_normalizes_each_query_row() -> CtResult<()> {
        let scores = AttentionScores::new(vec![vec![2.0, 1.0], vec![0.0, 3.0]])?;
        let weights = AttentionSoftmax.apply(scores)?;

        assert_eq!(weights.query_len().value(), 2);
        assert_eq!(weights.key_len().value(), 2);

        for row in weights.rows() {
            let sum: f32 = row.as_slice().iter().sum();
            assert!(crate::domain::approx_eq(sum, 1.0, 1e-4));
        }

        Ok(())
    }
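
    // Finite-difference check for the softmax cross-entropy logit gradient.
    //
    // `softmax_cross_entropy_logits_gradient` returns `softmax(z) - one_hot(target)`,
    // the derivative of L(z) = -ln(softmax(z)[target]) = logsumexp(z) - z[target].
    // This std-only sketch (hypothetical helpers, computed in `f64` for accurate
    // central differences) checks that rule numerically on a small logit vector.
    fn cross_entropy_loss(logits: &[f64], target: usize) -> f64 {
        let max = logits.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        let log_sum_exp = max + logits.iter().map(|z| (z - max).exp()).sum::<f64>().ln();
        log_sum_exp - logits[target]
    }

    fn cross_entropy_gradient_max_error(logits: &[f64], target: usize) -> f64 {
        let max = logits.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        let exps: Vec<f64> = logits.iter().map(|z| (z - max).exp()).collect();
        let total: f64 = exps.iter().sum();
        let step = 1e-5;
        let mut worst = 0.0_f64;

        for i in 0..logits.len() {
            // Analytic rule: probability minus the one-hot target.
            let one_hot = if i == target { 1.0 } else { 0.0 };
            let analytic = exps[i] / total - one_hot;

            let mut plus = logits.to_vec();
            let mut minus = logits.to_vec();
            plus[i] += step;
            minus[i] -= step;
            let numeric = (cross_entropy_loss(&plus, target)
                - cross_entropy_loss(&minus, target))
                / (2.0 * step);

            worst = worst.max((numeric - analytic).abs());
        }

        worst
    }

    #[test]
    fn cross_entropy_gradient_rule_matches_finite_differences() {
        let error = cross_entropy_gradient_max_error(&[2.0, -1.0, 0.5, 3.0], 2);
        assert!(error < 1e-6, "max gradient error {error}");
    }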

    #[test]
    fn attention_scores_reject_non_finite_values() {
        assert!(matches!(
            AttentionScores::new(vec![vec![1.0, f32::NAN]]),
            Err(CtError::ShapeMismatch {
                op: "attention scores",
                ..
            })
        ));
    }

    #[test]
    fn attention_scores_reject_ragged_rows() {
        assert!(matches!(
            AttentionScores::new(vec![vec![1.0, 2.0], vec![3.0]]),
            Err(CtError::ShapeMismatch {
                op: "attention scores",
                ..
            })
        ));
    }

    #[test]
    fn attention_mask_rejects_rows_with_no_allowed_keys() {
        assert!(matches!(
            AttentionMask::new(vec![vec![false, false]]),
            Err(CtError::EmptyInput("attention mask row allows no keys"))
        ));
    }

    #[test]
    fn masked_attention_scores_reject_shape_mismatch() -> CtResult<()> {
        let scores = AttentionScores::new(vec![vec![1.0, 2.0]])?;
        let mask = AttentionMask::new(vec![vec![true, true], vec![true, true]])?;

        assert!(matches!(
            MaskedAttentionScores.apply(Product::new(scores, mask)),
            Err(CtError::ShapeMismatch {
                op: "masked attention scores",
                ..
            })
        ));

        Ok(())
    }

    #[test]
    fn query_sequence_rejects_ragged_rows() {
        assert!(matches!(
            QuerySequence::new(vec![vec![1.0, 2.0], vec![3.0]]),
            Err(CtError::ShapeMismatch {
                op: "query sequence",
                ..
            })
        ));
    }

    #[test]
    fn value_sequence_rejects_empty_rows() {
        assert!(matches!(
            ValueSequence::new(vec![Vec::new()]),
            Err(CtError::EmptyInput("attention vector row"))
        ));
    }

    #[test]
    fn key_sequence_rejects_non_finite_values() {
        assert!(matches!(
            KeySequence::new(vec![vec![1.0, f32::NAN]]),
            Err(CtError::ShapeMismatch {
                op: "key sequence",
                ..
            })
        ));
    }

    #[test]
    fn scaled_dot_product_rejects_mismatched_head_dimensions() -> CtResult<()> {
        let queries = QuerySequence::new(vec![vec![1.0, 0.0]])?;
        let keys = KeySequence::new(vec![vec![1.0, 0.0, 1.0]])?;

        assert!(matches!(
            ScaledDotProductScores.apply(Product::new(queries, keys)),
            Err(CtError::ShapeMismatch {
                op: "scaled dot-product attention scores",
                ..
            })
        ));

        Ok(())
    }

    #[test]
    fn weighted_value_mixing_rejects_value_length_mismatch() -> CtResult<()> {
        let weights = AttentionWeights::new(vec![Distribution::new(vec![0.5, 0.5])?])?;
        let values = ValueSequence::new(vec![vec![1.0, 10.0]])?;

        assert!(matches!(
            WeightedValueMixing.apply(Product::new(weights, values)),
            Err(CtError::ShapeMismatch {
                op: "weighted value mixing",
                ..
            })
        ));

        Ok(())
    }

    #[test]
    fn sequence_and_head_dimensions_reject_zero() {
        assert!(matches!(
            SequenceLength::new(0),
            Err(CtError::EmptyInput("sequence length"))
        ));
        assert!(matches!(
            HeadDimension::new(0),
            Err(CtError::EmptyInput("head dimension"))
        ));
        assert!(matches!(
            HeadCount::new(0),
            Err(CtError::EmptyInput("head count"))
        ));
    }

    #[test]
    fn attention_head_outputs_reject_sequence_mismatch() -> CtResult<()> {
        let head_a = AttentionOutput::new(vec![vec![1.0, 10.0]])?;
        let head_b = AttentionOutput::new(vec![vec![2.0, 20.0], vec![3.0, 30.0]])?;

        assert!(matches!(
            AttentionHeadOutputs::new(vec![head_a, head_b]),
            Err(CtError::ShapeMismatch {
                op: "attention head outputs",
                ..
            })
        ));

        Ok(())
    }

    #[test]
    fn attention_head_outputs_reject_head_dimension_mismatch() -> CtResult<()> {
        let head_a = AttentionOutput::new(vec![vec![1.0, 10.0]])?;
        let head_b = AttentionOutput::new(vec![vec![2.0, 20.0, 200.0]])?;

        assert!(matches!(
            AttentionHeadOutputs::new(vec![head_a, head_b]),
            Err(CtError::ShapeMismatch {
                op: "attention head outputs",
                ..
            })
        ));

        Ok(())
    }

    #[test]
    fn attention_output_projection_rejects_input_dimension_mismatch() -> CtResult<()> {
        let head_a = AttentionOutput::new(vec![vec![1.0, 10.0]])?;
        let head_b = AttentionOutput::new(vec![vec![2.0, 20.0]])?;
        let multi_head =
            ConcatenateHeads.apply(AttentionHeadOutputs::new(vec![head_a, head_b])?)?;
        let projection =
            AttentionOutputProjection::new(vec![vec![1.0], vec![1.0], vec![1.0]], vec![0.0])?;

        assert!(matches!(
            projection.apply(multi_head),
            Err(CtError::ShapeMismatch {
                op: "attention output projection",
                ..
            })
        ));

        Ok(())
    }

    #[test]
    fn attention_output_projection_rejects_bad_weight_shapes() {
        assert!(matches!(
            AttentionOutputProjection::new(vec![vec![1.0, 0.0], vec![1.0]], vec![0.0, 0.0]),
            Err(CtError::ShapeMismatch {
                op: "attention output projection",
                ..
            })
        ));
    }

    #[test]
    fn attention_output_projection_rejects_non_finite_values() {
        assert!(matches!(
            AttentionOutputProjection::new(vec![vec![1.0]], vec![f32::NAN]),
            Err(CtError::ShapeMismatch {
                op: "attention output projection",
                ..
            })
        ));
        assert!(matches!(
            AttentionOutputProjection::new(vec![vec![f32::INFINITY]], vec![0.0]),
            Err(CtError::ShapeMismatch {
                op: "attention output projection",
                ..
            })
        ));
    }

    #[test]
    fn residual_connection_rejects_sequence_mismatch() -> CtResult<()> {
        let hidden = HiddenSequence::new(vec![vec![1.0, 2.0]])?;
        let sublayer_output = ProjectedAttentionOutput::new(vec![vec![0.5, 1.5], vec![2.5, 3.5]])?;

        assert!(matches!(
            ResidualConnection.apply(Product::new(hidden, sublayer_output)),
            Err(CtError::ShapeMismatch {
                op: "residual connection",
                ..
            })
        ));

        Ok(())
    }

    #[test]
    fn residual_connection_rejects_model_dimension_mismatch() -> CtResult<()> {
        let hidden = HiddenSequence::new(vec![vec![1.0, 2.0]])?;
        let sublayer_output = ProjectedAttentionOutput::new(vec![vec![0.5, 1.5, 2.5]])?;

        assert!(matches!(
            ResidualConnection.apply(Product::new(hidden, sublayer_output)),
            Err(CtError::ShapeMismatch {
                op: "residual connection",
                ..
            })
        ));

        Ok(())
    }

    #[test]
    fn layer_normalization_rejects_model_dimension_mismatch() -> CtResult<()> {
        let hidden = HiddenSequence::new(vec![vec![1.0, 2.0, 3.0]])?;
        let norm = LayerNormalization::new(LayerNormParameters::identity(ModelDimension::new(2)?));

        assert!(matches!(
            norm.apply(hidden),
            Err(CtError::ShapeMismatch {
                op: "layer normalization",
                ..
            })
        ));

        Ok(())
    }

    #[test]
    fn layer_norm_parameters_reject_bad_shapes_and_values() {
        assert!(matches!(
            LayerNormParameters::new(vec![1.0, 1.0], vec![0.0], NormalizationEpsilon(1e-5)),
            Err(CtError::ShapeMismatch {
                op: "layer norm parameters",
                ..
            })
        ));
        assert!(matches!(
            LayerNormParameters::new(vec![f32::NAN], vec![0.0], NormalizationEpsilon(1e-5)),
            Err(CtError::ShapeMismatch {
                op: "layer norm parameters",
                ..
            })
        ));
        assert!(matches!(
            LayerNormParameters::new(vec![1.0], vec![f32::INFINITY], NormalizationEpsilon(1e-5)),
            Err(CtError::ShapeMismatch {
                op: "layer norm parameters",
                ..
            })
        ));
        assert!(matches!(
            NormalizationEpsilon::new(0.0),
            Err(CtError::ShapeMismatch {
                op: "normalization epsilon",
                ..
            })
        ));
    }

    #[test]
    fn position_wise_feed_forward_rejects_input_dimension_mismatch() -> CtResult<()> {
        let hidden = HiddenSequence::new(vec![vec![1.0, 2.0, 3.0]])?;
        let feed_forward = PositionWiseFeedForward::new(
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![0.0, 0.0],
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![0.0, 0.0],
        )?;

        assert!(matches!(
            feed_forward.apply(hidden),
            Err(CtError::ShapeMismatch {
                op: "position-wise feed-forward",
                ..
            })
        ));

        Ok(())
    }

    #[test]
    fn position_wise_feed_forward_rejects_incompatible_layer_shapes() {
        assert!(matches!(
            PositionWiseFeedForward::new(
                vec![vec![1.0, 0.0]],
                vec![0.0, 0.0],
                vec![vec![1.0], vec![1.0], vec![1.0]],
                vec![0.0],
            ),
            Err(CtError::ShapeMismatch {
                op: "position-wise feed-forward",
                ..
            })
        ));
        assert!(matches!(
            PositionWiseFeedForward::new(
                vec![vec![1.0, 0.0]],
                vec![0.0, 0.0],
                vec![vec![1.0, 0.0], vec![1.0, 0.0]],
                vec![0.0, 0.0],
            ),
            Err(CtError::ShapeMismatch {
                op: "position-wise feed-forward",
                ..
            })
        ));
    }

    #[test]
    fn position_wise_feed_forward_rejects_non_finite_values() {
        assert!(matches!(
            PositionWiseFeedForward::new(
                vec![vec![f32::NAN]],
                vec![0.0],
                vec![vec![1.0]],
                vec![0.0],
            ),
            Err(CtError::ShapeMismatch {
                op: "position-wise feed-forward first layer",
                ..
            })
        ));
        assert!(matches!(
            PositionWiseFeedForward::new(
                vec![vec![1.0]],
                vec![0.0],
                vec![vec![1.0]],
                vec![f32::INFINITY],
            ),
            Err(CtError::ShapeMismatch {
                op: "position-wise feed-forward second layer",
                ..
            })
        ));
    }

    #[test]
    fn positional_encoding_adds_position_rows_and_preserves_shape() -> CtResult<()> {
        let hidden = HiddenSequence::new(vec![vec![1.0, 2.0], vec![3.0, 4.0]])?;
        let positions = PositionalEncoding::new(vec![vec![0.1, 0.2], vec![0.3, 0.4]])?;

        let output = positions.apply(hidden)?;

        assert_eq!(output.sequence_len().value(), 2);
        assert_eq!(output.model_dimension().value(), 2);
        assert!(crate::domain::approx_eq(
            output.rows()[0].as_slice()[0],
            1.1,
            1e-4
        ));
        assert!(crate::domain::approx_eq(
            output.rows()[1].as_slice()[1],
            4.4,
            1e-4
        ));

        Ok(())
    }

    #[test]
    fn positional_encoding_rejects_model_dimension_mismatch() -> CtResult<()> {
        let hidden = HiddenSequence::new(vec![vec![1.0, 2.0]])?;
        let positions = PositionalEncoding::new(vec![vec![0.1, 0.2, 0.3]])?;

        assert!(matches!(
            positions.apply(hidden),
            Err(CtError::ShapeMismatch {
                op: "positional encoding",
                ..
            })
        ));

        Ok(())
    }

    #[test]
    fn positional_encoding_rejects_sequence_too_long() -> CtResult<()> {
        let hidden = HiddenSequence::new(vec![vec![1.0], vec![2.0]])?;
        let positions = PositionalEncoding::new(vec![vec![0.1]])?;

        assert!(matches!(
            positions.apply(hidden),
            Err(CtError::ShapeMismatch {
                op: "positional encoding",
                ..
            })
        ));

        Ok(())
    }

    #[test]
    fn hidden_to_query_projects_hidden_rows() -> CtResult<()> {
        let hidden = HiddenSequence::new(vec![vec![1.0, 2.0], vec![3.0, 4.0]])?;
        let projection = HiddenToQuery::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.5, -0.5])?;

        let queries = projection.apply(hidden)?;

        assert_eq!(queries.sequence_len().value(), 2);
        assert_eq!(queries.head_dimension().value(), 2);
        assert!(crate::domain::approx_eq(
            queries.rows()[0].as_slice()[0],
            1.5,
            1e-4
        ));
        assert!(crate::domain::approx_eq(
            queries.rows()[1].as_slice()[1],
            3.5,
            1e-4
        ));

        Ok(())
    }
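
    // Worked example of the `[input][output]` weight layout used by `project_row`.
    //
    // `project_row` starts from `bias` and adds `input[i] * weight[i][j]` into
    // `output[j]`, i.e. it computes the row-vector product `x W + b`. The
    // hypothetical std-only helper below spells out the same sum with explicit
    // indices so the arithmetic can be checked by hand.
    fn row_times_matrix_plus_bias(input: &[f32], weight: &[Vec<f32>], bias: &[f32]) -> Vec<f32> {
        (0..bias.len())
            .map(|j| {
                bias[j]
                    + input
                        .iter()
                        .zip(weight)
                        .map(|(x, row)| x * row[j])
                        .sum::<f32>()
            })
            .collect()
    }

    #[test]
    fn row_projection_sketch_uses_input_by_output_layout() {
        // Two input features, three output features: 2 rows of 3 columns.
        let weight = vec![vec![1.0, 0.0, 2.0], vec![0.5, -1.0, 0.0]];
        let output = row_times_matrix_plus_bias(&[2.0, 4.0], &weight, &[0.0, 1.0, -1.0]);

        // [0 + 2*1 + 4*0.5, 1 + 2*0 + 4*(-1), -1 + 2*2 + 4*0]
        assert_eq!(output, vec![4.0, -3.0, 3.0]);
    }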

    #[test]
    fn hidden_projection_rejects_input_dimension_mismatch() -> CtResult<()> {
        let hidden = HiddenSequence::new(vec![vec![1.0, 2.0, 3.0]])?;
        let projection = HiddenToValue::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?;

        assert!(matches!(
            projection.apply(hidden),
            Err(CtError::ShapeMismatch {
                op: "hidden-to-value projection",
                ..
            })
        ));

        Ok(())
    }

    #[test]
    fn residual_connection_adds_hidden_sequences() -> CtResult<()> {
        let left = HiddenSequence::new(vec![vec![1.0, 2.0]])?;
        let right = HiddenSequence::new(vec![vec![3.0, 4.0]])?;

        let output = ResidualConnection.apply(Product::new(left, right))?;

        assert_eq!(output.rows()[0].as_slice(), &[4.0, 6.0]);

        Ok(())
    }

    #[test]
    fn single_head_transformer_block_preserves_hidden_sequence_shape() -> CtResult<()> {
        let block = tiny_single_head_block()?;
        let hidden = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;

        let output = block.apply(hidden)?;

        assert_eq!(output.sequence_len().value(), 2);
        assert_eq!(output.model_dimension().value(), 2);
        assert!(
            output
                .rows()
                .iter()
                .flat_map(|row| row.as_slice())
                .all(|value| value.is_finite())
        );

        Ok(())
    }

    #[test]
    fn single_head_transformer_block_rejects_constructor_dimension_mismatch() -> CtResult<()> {
        let query = HiddenToQuery::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?;
        let key = HiddenToKey::new(
            vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![0.0, 0.0]],
            vec![0.0, 0.0],
        )?;
        let value = HiddenToValue::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?;
        let output_projection =
            AttentionOutputProjection::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?;
        let model_dimension = ModelDimension::new(2)?;
        let attention_norm =
            LayerNormalization::new(LayerNormParameters::identity(model_dimension));
        let feed_forward = PositionWiseFeedForward::new(
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![0.0, 0.0],
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![0.0, 0.0],
        )?;
        let feed_forward_norm =
            LayerNormalization::new(LayerNormParameters::identity(model_dimension));

        assert!(matches!(
            SingleHeadTransformerBlock::new(
                query,
                key,
                value,
                output_projection,
                attention_norm,
                feed_forward,
                feed_forward_norm,
            ),
            Err(CtError::ShapeMismatch {
                op: "single-head block key projection",
                ..
            })
        ));

        Ok(())
    }

    #[test]
    fn single_head_transformer_block_rejects_apply_dimension_mismatch() -> CtResult<()> {
        let block = tiny_single_head_block()?;
        let hidden = HiddenSequence::new(vec![vec![1.0, 0.0, 0.0]])?;

        assert!(matches!(
            block.apply(hidden),
            Err(CtError::ShapeMismatch {
                op: "single-head block",
                ..
            })
        ));

        Ok(())
    }

    #[test]
    fn self_attention_head_rejects_query_key_head_mismatch() -> CtResult<()> {
        let query = HiddenToQuery::new(vec![vec![1.0], vec![0.0]], vec![0.0])?;
        let key = HiddenToKey::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?;
        let value = HiddenToValue::new(vec![vec![1.0], vec![0.0]], vec![0.0])?;

        assert!(matches!(
            SelfAttentionHead::new(query, key, value),
            Err(CtError::ShapeMismatch {
                op: "self-attention head",
                ..
            })
        ));

        Ok(())
    }

    #[test]
    fn multi_head_transformer_block_preserves_hidden_sequence_shape() -> CtResult<()> {
        let block = tiny_multi_head_block()?;
        let hidden = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;

        let output = block.apply(hidden)?;

        assert_eq!(block.head_count().value(), 2);
        assert_eq!(block.value_dimension().value(), 1);
        assert_eq!(output.sequence_len().value(), 2);
        assert_eq!(output.model_dimension().value(), 2);
        assert!(
            output
                .rows()
                .iter()
                .flat_map(|row| row.as_slice())
                .all(|value| value.is_finite())
        );

        Ok(())
    }

    #[test]
    fn multi_head_transformer_block_rejects_value_dimension_mismatch() -> CtResult<()> {
        let head_a = tiny_self_attention_head_first_feature()?;
        let head_b = SelfAttentionHead::new(
            HiddenToQuery::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
            HiddenToKey::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
            HiddenToValue::new(vec![vec![0.0, 1.0], vec![1.0, 0.0]], vec![0.0, 0.0])?,
        )?;
        let model_dimension = ModelDimension::new(2)?;

        assert!(matches!(
            MultiHeadTransformerBlock::new(
                vec![head_a, head_b],
                AttentionOutputProjection::new(
                    vec![vec![1.0, 0.0], vec![0.0, 1.0]],
                    vec![0.0, 0.0],
                )?,
                LayerNormalization::new(LayerNormParameters::identity(model_dimension)),
                identity_feed_forward()?,
                LayerNormalization::new(LayerNormParameters::identity(model_dimension)),
            ),
            Err(CtError::ShapeMismatch {
                op: "multi-head block",
                ..
            })
        ));

        Ok(())
    }

    #[test]
    fn multi_head_transformer_block_rejects_output_projection_input_mismatch() -> CtResult<()> {
        let model_dimension = ModelDimension::new(2)?;

        assert!(matches!(
            MultiHeadTransformerBlock::new(
                vec![
                    tiny_self_attention_head_first_feature()?,
                    tiny_self_attention_head_second_feature()?,
                ],
                AttentionOutputProjection::new(
                    vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![0.0, 0.0]],
                    vec![0.0, 0.0],
                )?,
                LayerNormalization::new(LayerNormParameters::identity(model_dimension)),
                identity_feed_forward()?,
                LayerNormalization::new(LayerNormParameters::identity(model_dimension)),
            ),
            Err(CtError::ShapeMismatch {
                op: "multi-head block output projection input",
                ..
            })
        ));

        Ok(())
    }

    #[test]
    fn multi_head_transformer_block_rejects_apply_dimension_mismatch() -> CtResult<()> {
        let block = tiny_multi_head_block()?;
        let hidden = HiddenSequence::new(vec![vec![1.0, 0.0, 0.0]])?;

        assert!(matches!(
            block.apply(hidden),
            Err(CtError::ShapeMismatch {
                op: "multi-head block",
                ..
            })
        ));

        Ok(())
    }

    #[test]
    fn masked_multi_head_transformer_block_preserves_hidden_sequence_shape() -> CtResult<()> {
        let block = tiny_masked_multi_head_block()?;
        let hidden = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
        let mask = AttentionMask::new(vec![vec![true, false], vec![true, true]])?;

        let output = block.apply(Product::new(hidden, mask))?;

        assert_eq!(block.head_count().value(), 2);
        assert_eq!(output.sequence_len().value(), 2);
        assert_eq!(output.model_dimension().value(), 2);
        assert!(
            output
                .rows()
                .iter()
                .flat_map(|row| row.as_slice())
                .all(|value| value.is_finite())
        );

        Ok(())
    }

    #[test]
    fn masked_multi_head_transformer_block_rejects_mask_shape_mismatch() -> CtResult<()> {
        let block = tiny_masked_multi_head_block()?;
        let hidden = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
        let mask = AttentionMask::new(vec![vec![true, true, true], vec![true, true, true]])?;

        assert!(matches!(
            block.apply(Product::new(hidden, mask)),
            Err(CtError::ShapeMismatch {
                op: "masked attention scores",
                ..
            })
        ));

        Ok(())
    }
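
    // Reference sketch (std-only; `reference_masked_softmax` is a hypothetical
    // helper, not crate API): a masked softmax showing the role the
    // AttentionMask plays in the masked-block tests above. It assumes `true`
    // means "this query may attend to that key", as the lower-triangular masks
    // [[true, false], [true, true]] in these tests suggest. The crate's own
    // masking code may differ in detail. Each row needs at least one visible key.
    fn reference_masked_softmax(scores: &[f32], mask: &[bool]) -> Vec<f32> {
        assert_eq!(scores.len(), mask.len(), "score row and mask row must align");
        // Subtract the largest visible score for numerical stability.
        let max_visible = scores
            .iter()
            .zip(mask)
            .filter(|(_, visible)| **visible)
            .map(|(score, _)| *score)
            .fold(f32::NEG_INFINITY, f32::max);
        // Hidden keys contribute exactly zero before normalisation.
        let exps: Vec<f32> = scores
            .iter()
            .zip(mask)
            .map(|(score, visible)| if *visible { (score - max_visible).exp() } else { 0.0 })
            .collect();
        let total: f32 = exps.iter().sum();
        exps.iter().map(|value| value / total).collect()
    }

    #[test]
    fn reference_masked_softmax_gives_zero_weight_to_hidden_keys() {
        // Row 0 of the causal mask sees only key 0, however large key 1 scores.
        let first = reference_masked_softmax(&[0.3, 5.0], &[true, false]);
        assert_eq!(first, vec![1.0, 0.0]);
        // Row 1 sees both keys, so its weights form a full probability vector.
        let second = reference_masked_softmax(&[0.0, 0.0], &[true, true]);
        assert!((second.iter().sum::<f32>() - 1.0).abs() < 1e-6);
        assert!((second[0] - 0.5).abs() < 1e-6);
    }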

    #[test]
    fn transformer_readout_maps_each_hidden_position_to_logits() -> CtResult<()> {
        let readout = TransformerReadout::new(
            vec![vec![1.0, 0.0, 0.5], vec![0.0, 1.0, -0.5]],
            vec![0.0, 0.1, -0.1],
        )?;
        let hidden = HiddenSequence::new(vec![vec![2.0, 3.0], vec![4.0, 5.0]])?;

        let logits = readout.apply(hidden)?;

        assert_eq!(logits.sequence_len().value(), 2);
        assert_eq!(logits.vocab_size().value(), 3);
        assert_eq!(logits.rows()[0].as_slice(), &[2.0, 3.1, -0.6]);
        assert_eq!(logits.rows()[1].as_slice(), &[4.0, 5.1, -0.6]);
        Ok(())
    }
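
    // Reference sketch (std-only; `reference_readout_row` is a hypothetical
    // helper, not crate API): the arithmetic behind the expected logits above.
    // It assumes the readout weight is stored with one row per hidden feature
    // and one column per vocabulary entry, which matches the test data:
    //   logits[j] = bias[j] + sum_i hidden[i] * weight[i][j]
    fn reference_readout_row(hidden: &[f32], weight: &[Vec<f32>], bias: &[f32]) -> Vec<f32> {
        // Start from the bias, then add each hidden feature's weighted row.
        let mut logits = bias.to_vec();
        for (feature, row) in hidden.iter().zip(weight) {
            for (logit, w) in logits.iter_mut().zip(row) {
                *logit += feature * w;
            }
        }
        logits
    }

    #[test]
    fn reference_readout_row_reproduces_expected_logits() {
        let weight = vec![vec![1.0, 0.0, 0.5], vec![0.0, 1.0, -0.5]];
        let bias: [f32; 3] = [0.0, 0.1, -0.1];
        // Hidden row [2, 3]: [2, 0, 1] + [0, 3, -1.5] + bias = [2.0, 3.1, -0.6].
        let logits = reference_readout_row(&[2.0, 3.0], &weight, &bias);
        let expected: [f32; 3] = [2.0, 3.1, -0.6];
        for (got, want) in logits.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-6);
        }
    }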

    #[test]
    fn tiny_transformer_parameters_forward_maps_hidden_and_mask_to_sequence_logits() -> CtResult<()>
    {
        let parameters = tiny_transformer_parameters()?;
        let hidden = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
        let mask = AttentionMask::new(vec![vec![true, false], vec![true, true]])?;

        let logits = parameters.apply(Product::new(hidden, mask))?;

        assert_eq!(parameters.model_dimension().value(), 2);
        assert_eq!(parameters.max_sequence_len().value(), 2);
        assert_eq!(logits.sequence_len().value(), 2);
        assert_eq!(logits.vocab_size().value(), 3);
        assert!(
            logits
                .rows()
                .iter()
                .flat_map(|row| row.as_slice())
                .all(|value| value.is_finite())
        );

        Ok(())
    }

    #[test]
    fn tiny_transformer_parameters_rejects_readout_dimension_mismatch() -> CtResult<()> {
        let positional_encoding = PositionalEncoding::new(vec![vec![0.1, 0.0], vec![0.0, 0.1]])?;
        let block = tiny_masked_multi_head_block()?;
        let readout = TransformerReadout::new(vec![vec![1.0], vec![0.0], vec![0.5]], vec![0.0])?;

        assert!(matches!(
            TinyTransformerParameters::new(positional_encoding, block, readout),
            Err(CtError::ShapeMismatch {
                op: "tiny transformer parameters readout",
                ..
            })
        ));

        Ok(())
    }

    #[test]
    fn transformer_training_state_records_updated_parameters_and_step_count() -> CtResult<()> {
        let initial_parameters = tiny_transformer_parameters()?;
        let updated_parameters = tiny_transformer_parameters()?;
        let state = TransformerTrainingState::new(initial_parameters, LearningRate::new(0.25)?);

        let next_state = state.record_updated_parameters(updated_parameters.clone());

        assert_eq!(next_state.parameters(), &updated_parameters);
        assert_eq!(next_state.learning_rate().value(), 0.25);
        assert_eq!(next_state.step_count().value(), 1);
        Ok(())
    }

    #[test]
    fn transformer_training_state_forward_uses_structured_parameters() -> CtResult<()> {
        let state =
            TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.1)?);
        let hidden = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
        let mask = AttentionMask::new(vec![vec![true, false], vec![true, true]])?;

        let logits = state.apply(Product::new(hidden, mask))?;

        assert_eq!(logits.sequence_len().value(), 2);
        assert_eq!(logits.vocab_size().value(), 3);
        assert_eq!(state.step_count().value(), 0);
        Ok(())
    }

    #[test]
    fn transformer_readout_training_example_rejects_target_length_mismatch() -> CtResult<()> {
        let hidden = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
        let mask = AttentionMask::new(vec![vec![true, false], vec![true, true]])?;
        let targets = TokenSequence::from_indices([0])?;

        assert!(matches!(
            TransformerReadoutTrainingExample::new(hidden, mask, targets),
            Err(CtError::ShapeMismatch {
                op: "transformer readout training targets",
                ..
            })
        ));

        Ok(())
    }

    #[test]
    fn transformer_readout_training_example_rejects_mask_shape_mismatch() -> CtResult<()> {
        let hidden = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
        let mask = AttentionMask::new(vec![vec![true, true, true], vec![true, true, true]])?;
        let targets = TokenSequence::from_indices([0, 1])?;

        assert!(matches!(
            TransformerReadoutTrainingExample::new(hidden, mask, targets),
            Err(CtError::ShapeMismatch {
                op: "transformer readout training mask",
                ..
            })
        ));

        Ok(())
    }

    #[test]
    fn transformer_readout_train_step_reduces_sequence_loss() -> CtResult<()> {
        let dataset = tiny_transformer_training_set()?;
        let state =
            TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.5)?);
        let before = transformer_readout_average_loss(&state, &dataset)?;
        let train_step = TransformerReadoutTrainStep::new(dataset.clone());

        let trained =
            crate::category::apply_endomorphism_n_times(&train_step, state, StepCount::new(40))?;
        let after = transformer_readout_average_loss(&trained, &dataset)?;

        assert!(after.value() < before.value());
        assert_eq!(trained.step_count().value(), 40);
        Ok(())
    }

    #[test]
    fn transformer_readout_train_step_rejects_target_outside_vocabulary() -> CtResult<()> {
        let example = TransformerReadoutTrainingExample::new(
            HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?,
            AttentionMask::new(vec![vec![true, false], vec![true, true]])?,
            TokenSequence::from_indices([0, 9])?,
        )?;
        let dataset = TransformerReadoutTrainingSet::new([example])?;
        let state =
            TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.1)?);
        let train_step = TransformerReadoutTrainStep::new(dataset);

        assert!(matches!(
            train_step.apply(state),
            Err(CtError::OutOfRange {
                kind: "sequence target",
                index: 9,
                limit: 3,
            })
        ));

        Ok(())
    }

    #[test]
    fn transformer_feed_forward_training_example_rejects_target_shape_mismatch() -> CtResult<()> {
        let input = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
        let target = HiddenSequence::new(vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]])?;

        assert!(matches!(
            TransformerFeedForwardTrainingExample::new(input, target),
            Err(CtError::ShapeMismatch {
                op: "transformer feed-forward training dimension",
                ..
            })
        ));

        Ok(())
    }

    #[test]
    fn transformer_feed_forward_training_set_rejects_empty_input() {
        assert!(matches!(
            TransformerFeedForwardTrainingSet::new([]),
            Err(CtError::EmptyInput("transformer feed-forward training set"))
        ));
    }

    #[test]
    fn transformer_feed_forward_train_step_reduces_local_hidden_loss() -> CtResult<()> {
        let dataset = tiny_feed_forward_training_set()?;
        let state =
            TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
        let before = transformer_feed_forward_average_loss(&state, &dataset)?;
        let train_step = TransformerFeedForwardTrainStep::new(dataset.clone());

        let trained =
            crate::category::apply_endomorphism_n_times(&train_step, state, StepCount::new(60))?;
        let after = transformer_feed_forward_average_loss(&trained, &dataset)?;

        assert!(after.value() < before.value());
        assert_eq!(trained.step_count().value(), 60);
        Ok(())
    }

    #[test]
    fn transformer_block_training_example_rejects_mask_shape_mismatch() -> CtResult<()> {
        let hidden = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
        let mask = AttentionMask::new(vec![vec![true, true, true], vec![true, true, true]])?;
        let targets = TokenSequence::from_indices([0, 1])?;

        assert!(matches!(
            TransformerBlockTrainingExample::new(hidden, mask, targets),
            Err(CtError::ShapeMismatch {
                op: "transformer block training mask",
                ..
            })
        ));

        Ok(())
    }

    #[test]
    fn transformer_block_training_set_rejects_empty_input() {
        assert!(matches!(
            TransformerBlockTrainingSet::new([]),
            Err(CtError::EmptyInput("transformer block training set"))
        ));
    }

    #[test]
    fn transformer_block_train_step_rejects_target_outside_vocabulary() -> CtResult<()> {
        let example = TransformerBlockTrainingExample::new(
            HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?,
            AttentionMask::new(vec![vec![true, false], vec![true, true]])?,
            TokenSequence::from_indices([0, 9])?,
        )?;
        let dataset = TransformerBlockTrainingSet::new([example])?;
        let state =
            TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.1)?);
        let train_step = TransformerBlockTrainStep::new(dataset);

        assert!(matches!(
            train_step.apply(state),
            Err(CtError::OutOfRange {
                kind: "sequence target",
                index: 9,
                limit: 3,
            })
        ));

        Ok(())
    }

    #[test]
    fn transformer_block_train_step_reduces_sequence_loss() -> CtResult<()> {
        let dataset = tiny_transformer_block_training_set()?;
        let state =
            TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
        let before = transformer_block_average_loss(&state, &dataset)?;
        let train_step = TransformerBlockTrainStep::new(dataset.clone());

        let trained =
            crate::category::apply_endomorphism_n_times(&train_step, state, StepCount::new(40))?;
        let after = transformer_block_average_loss(&trained, &dataset)?;

        assert!(after.value() < before.value());
        assert_eq!(trained.step_count().value(), 40);
        Ok(())
    }

    #[test]
    fn transformer_block_train_step_updates_attention_output_projection() -> CtResult<()> {
        let dataset = tiny_transformer_block_training_set()?;
        let state =
            TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
        let before = state.parameters().output_projection().weight().to_vec();
        let train_step = TransformerBlockTrainStep::new(dataset);

        let trained = train_step.apply(state)?;
        let after = trained.parameters().output_projection().weight().to_vec();

        assert_ne!(before, after);
        assert_eq!(trained.step_count().value(), 1);
        Ok(())
    }

    #[test]
    fn transformer_block_train_step_updates_layer_norm_parameters() -> CtResult<()> {
        let dataset = tiny_transformer_block_training_set()?;
        let state =
            TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
        let before_attention_scale = state
            .parameters()
            .attention_norm()
            .parameters()
            .scale()
            .to_vec();
        let before_feed_forward_shift = state
            .parameters()
            .feed_forward_norm()
            .parameters()
            .shift()
            .to_vec();
        let train_step = TransformerBlockTrainStep::new(dataset);

        let trained = train_step.apply(state)?;
        let after_attention_scale = trained
            .parameters()
            .attention_norm()
            .parameters()
            .scale()
            .to_vec();
        let after_feed_forward_shift = trained
            .parameters()
            .feed_forward_norm()
            .parameters()
            .shift()
            .to_vec();

        assert_ne!(before_attention_scale, after_attention_scale);
        assert_ne!(before_feed_forward_shift, after_feed_forward_shift);
        assert_eq!(trained.step_count().value(), 1);
        Ok(())
    }

    #[test]
    fn transformer_block_train_step_updates_query_key_value_projections() -> CtResult<()> {
        let dataset = tiny_transformer_block_training_set()?;
        let state =
            TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
        let before_query = state
            .parameters()
            .attention_heads()
            .iter()
            .map(|head| head.query_projection().weight().to_vec())
            .collect::<Vec<_>>();
        let before_key = state
            .parameters()
            .attention_heads()
            .iter()
            .map(|head| head.key_projection().weight().to_vec())
            .collect::<Vec<_>>();
        let before_value = state
            .parameters()
            .attention_heads()
            .iter()
            .map(|head| head.value_projection().weight().to_vec())
            .collect::<Vec<_>>();
        let train_step = TransformerBlockTrainStep::new(dataset);

        let trained = train_step.apply(state)?;
        let after_query = trained
            .parameters()
            .attention_heads()
            .iter()
            .map(|head| head.query_projection().weight().to_vec())
            .collect::<Vec<_>>();
        let after_key = trained
            .parameters()
            .attention_heads()
            .iter()
            .map(|head| head.key_projection().weight().to_vec())
            .collect::<Vec<_>>();
        let after_value = trained
            .parameters()
            .attention_heads()
            .iter()
            .map(|head| head.value_projection().weight().to_vec())
            .collect::<Vec<_>>();

        assert_ne!(before_query, after_query);
        assert_ne!(before_key, after_key);
        assert_ne!(before_value, after_value);
        assert_eq!(trained.step_count().value(), 1);
        Ok(())
    }

    #[test]
    fn transformer_block_train_step_matches_finite_difference_for_attention_projection()
    -> CtResult<()> {
        let dataset = tiny_transformer_block_training_set()?;
        let state =
            TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
        let train_step = TransformerBlockTrainStep::new(dataset.clone());
        let trained = train_step.apply(state.clone())?;
        let selection = largest_changed_attention_projection(&state, &trained)?;
        let before_value = attention_projection_weight(&state, selection)?;
        let after_value = attention_projection_weight(&trained, selection)?;
        assert_block_gradient_matches_finite_difference(
            &state,
            &dataset,
            before_value,
            after_value,
            |value| state_with_attention_projection_weight(&state, selection, value),
        )
    }

    #[test]
    fn transformer_block_train_step_matches_finite_difference_for_readout_weight() -> CtResult<()> {
        let dataset = tiny_transformer_block_training_set()?;
        let state =
            TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
        let train_step = TransformerBlockTrainStep::new(dataset.clone());
        let trained = train_step.apply(state.clone())?;
        let selection = largest_changed_readout_weight(&state, &trained)?;
        let before_value = readout_weight(&state, selection);
        let after_value = readout_weight(&trained, selection);

        assert_block_gradient_matches_finite_difference(
            &state,
            &dataset,
            before_value,
            after_value,
            |value| state_with_readout_weight(&state, selection, value),
        )
    }

    #[test]
    fn transformer_block_train_step_matches_finite_difference_for_feed_forward_weight()
    -> CtResult<()> {
        let dataset = tiny_transformer_block_training_set()?;
        let state =
            TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
        let train_step = TransformerBlockTrainStep::new(dataset.clone());
        let trained = train_step.apply(state.clone())?;
        let selection = largest_changed_feed_forward_weight(&state, &trained)?;
        let before_value = feed_forward_weight(&state, selection);
        let after_value = feed_forward_weight(&trained, selection);

        assert_block_gradient_matches_finite_difference(
            &state,
            &dataset,
            before_value,
            after_value,
            |value| state_with_feed_forward_weight(&state, selection, value),
        )
    }

    #[test]
    fn transformer_block_train_step_matches_finite_difference_for_layer_norm_parameter()
    -> CtResult<()> {
        let dataset = tiny_transformer_block_training_set()?;
        let state =
            TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
        let train_step = TransformerBlockTrainStep::new(dataset.clone());
        let trained = train_step.apply(state.clone())?;
        let selection = largest_changed_layer_norm_parameter(&state, &trained)?;
        let before_value = layer_norm_parameter_value(&state, selection);
        let after_value = layer_norm_parameter_value(&trained, selection);

        assert_block_gradient_matches_finite_difference(
            &state,
            &dataset,
            before_value,
            after_value,
            |value| state_with_layer_norm_parameter(&state, selection, value),
        )
    }

    #[test]
    fn transformer_block_train_step_matches_finite_difference_for_readout_bias() -> CtResult<()> {
        let dataset = tiny_transformer_block_training_set()?;
        let state =
            TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
        let train_step = TransformerBlockTrainStep::new(dataset.clone());
        let trained = train_step.apply(state.clone())?;
        let selection = largest_changed_vector_index(
            state.parameters().readout().bias(),
            trained.parameters().readout().bias(),
            "changed readout bias",
        )?;
        let before_value = state.parameters().readout().bias()[selection];
        let after_value = trained.parameters().readout().bias()[selection];

        assert_block_gradient_matches_finite_difference(
            &state,
            &dataset,
            before_value,
            after_value,
            |value| state_with_readout_bias(&state, selection, value),
        )
    }

    #[test]
    fn transformer_block_train_step_matches_finite_difference_for_feed_forward_bias() -> CtResult<()>
    {
        let dataset = tiny_transformer_block_training_set()?;
        let state =
            TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
        let train_step = TransformerBlockTrainStep::new(dataset.clone());
        let trained = train_step.apply(state.clone())?;
        let selection = largest_changed_feed_forward_bias(&state, &trained)?;
        let before_value = feed_forward_bias(&state, selection);
        let after_value = feed_forward_bias(&trained, selection);

        assert_block_gradient_matches_finite_difference(
            &state,
            &dataset,
            before_value,
            after_value,
            |value| state_with_feed_forward_bias(&state, selection, value),
        )
    }

    #[test]
    fn transformer_block_train_step_matches_finite_difference_for_output_projection_bias()
    -> CtResult<()> {
        let dataset = tiny_transformer_block_training_set()?;
        let state =
            TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
        let train_step = TransformerBlockTrainStep::new(dataset.clone());
        let trained = train_step.apply(state.clone())?;
        let selection = largest_changed_vector_index(
            state.parameters().output_projection().bias(),
            trained.parameters().output_projection().bias(),
            "changed attention output projection bias",
        )?;
        let before_value = state.parameters().output_projection().bias()[selection];
        let after_value = trained.parameters().output_projection().bias()[selection];

        assert_block_gradient_matches_finite_difference(
            &state,
            &dataset,
            before_value,
            after_value,
            |value| state_with_output_projection_bias(&state, selection, value),
        )
    }

    #[test]
    fn transformer_block_train_step_matches_finite_difference_for_attention_projection_bias()
    -> CtResult<()> {
        let dataset = tiny_transformer_block_training_set()?;
        let state =
            TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
        let train_step = TransformerBlockTrainStep::new(dataset.clone());
        let trained = train_step.apply(state.clone())?;
        let selection = largest_changed_attention_projection_bias(&state, &trained)?;
        let before_value = attention_projection_bias(&state, selection)?;
        let after_value = attention_projection_bias(&trained, selection)?;

        assert_block_gradient_matches_finite_difference(
            &state,
            &dataset,
            before_value,
            after_value,
            |value| state_with_attention_projection_bias(&state, selection, value),
        )
    }

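    /// Compares the gradient implied by one training step, `(before - after) / learning_rate`,
    /// with the central finite difference `(L(w + eps) - L(w - eps)) / (2 * eps)` of the
    /// average block loss. The inferred value equals dL/dw only if the step is a plain
    /// gradient-descent update `w - learning_rate * dL/dw`.
    fn assert_block_gradient_matches_finite_difference(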
        state: &TransformerTrainingState,
        dataset: &TransformerBlockTrainingSet,
        before_value: f32,
        after_value: f32,
        mut state_with_value: impl FnMut(f32) -> CtResult<TransformerTrainingState>,
    ) -> CtResult<()> {
        let inferred_gradient = (before_value - after_value) / state.learning_rate().value();
        let epsilon = 1e-3;
        let loss_plus =
            transformer_block_average_loss(&state_with_value(before_value + epsilon)?, dataset)?
                .value();
        let loss_minus =
            transformer_block_average_loss(&state_with_value(before_value - epsilon)?, dataset)?
                .value();
        let finite_difference = (loss_plus - loss_minus) / (2.0 * epsilon);

        assert!(
            (inferred_gradient - finite_difference).abs() < 1e-2,
            "inferred gradient {inferred_gradient} should match finite difference {finite_difference}"
        );
        Ok(())
    }
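
    // Reference sketch (std-only; `reference_central_difference` is a
    // hypothetical helper, not crate API): the inferred-gradient versus
    // central-difference comparison used by the block tests, run on the toy
    // loss L(w) = (w - 3)^2 whose exact gradient 2 * (w - 3) is known. At
    // w = 1 both sides should be close to -4, assuming the step is plain
    // gradient descent w_next = w - lr * dL/dw.
    fn reference_central_difference(loss: impl Fn(f32) -> f32, at: f32, epsilon: f32) -> f32 {
        (loss(at + epsilon) - loss(at - epsilon)) / (2.0 * epsilon)
    }

    #[test]
    fn reference_central_difference_matches_inferred_sgd_gradient() {
        let loss = |w: f32| (w - 3.0) * (w - 3.0);
        let learning_rate = 0.1_f32;
        let before = 1.0_f32;
        // One gradient-descent step using the exact gradient 2 * (w - 3).
        let after = before - learning_rate * 2.0 * (before - 3.0);
        // Read the gradient back from the parameter change, as the block tests do.
        let inferred_gradient = (before - after) / learning_rate;
        let finite_difference = reference_central_difference(loss, before, 1e-3);
        assert!((inferred_gradient - finite_difference).abs() < 1e-2);
        assert!((inferred_gradient + 4.0).abs() < 1e-3);
    }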

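    /// Returns the index whose value moved most between `before` and `after`, so the
    /// finite-difference checks probe a coordinate with a clearly nonzero update.
    /// Fails with `EmptyInput(label)` when no coordinate changed.
    fn largest_changed_vector_index(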
        before: &[f32],
        after: &[f32],
        label: &'static str,
    ) -> CtResult<usize> {
        let mut selected = None;
        let mut largest_delta = 0.0;

        for (index, (before_value, after_value)) in before.iter().zip(after).enumerate() {
            let delta = (before_value - after_value).abs();

            if delta > largest_delta {
                largest_delta = delta;
                selected = Some(index);
            }
        }

        selected.ok_or(CtError::EmptyInput(label))
    }

    /// Position of one weight entry: `input_index` selects the row, `output_index` the column.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    struct MatrixSelection {
        input_index: usize,
        output_index: usize,
    }

    fn largest_changed_matrix_weight(
        before: &[Vec<f32>],
        after: &[Vec<f32>],
        label: &'static str,
    ) -> CtResult<MatrixSelection> {
        let mut selected = None;
        let mut largest_delta = 0.0;

        for (input_index, (before_row, after_row)) in before.iter().zip(after).enumerate() {
            for (output_index, (before_value, after_value)) in
                before_row.iter().zip(after_row).enumerate()
            {
                let delta = (before_value - after_value).abs();

                if delta > largest_delta {
                    largest_delta = delta;
                    selected = Some(MatrixSelection {
                        input_index,
                        output_index,
                    });
                }
            }
        }

        selected.ok_or(CtError::EmptyInput(label))
    }

    fn largest_changed_readout_weight(
        before: &TransformerTrainingState,
        after: &TransformerTrainingState,
    ) -> CtResult<MatrixSelection> {
        largest_changed_matrix_weight(
            before.parameters().readout().weight(),
            after.parameters().readout().weight(),
            "changed readout weight",
        )
    }

    fn readout_weight(state: &TransformerTrainingState, selection: MatrixSelection) -> f32 {
        state.parameters().readout().weight()[selection.input_index][selection.output_index]
    }

    fn state_with_readout_weight(
        state: &TransformerTrainingState,
        selection: MatrixSelection,
        value: f32,
    ) -> CtResult<TransformerTrainingState> {
        let readout = state.parameters().readout();
        let mut weight = readout.weight().to_vec();
        weight[selection.input_index][selection.output_index] = value;
        let readout = TransformerReadout::new(weight, readout.bias().to_vec())?;
        let parameters = state.parameters().clone().with_readout(readout)?;

        Ok(TransformerTrainingState::from_parts(
            parameters,
            state.learning_rate(),
            state.step_count(),
        ))
    }

    fn state_with_readout_bias(
        state: &TransformerTrainingState,
        selection: usize,
        value: f32,
    ) -> CtResult<TransformerTrainingState> {
        let readout = state.parameters().readout();
        let mut bias = readout.bias().to_vec();
        bias[selection] = value;
        let readout = TransformerReadout::new(readout.weight().to_vec(), bias)?;
        let parameters = state.parameters().clone().with_readout(readout)?;

        Ok(TransformerTrainingState::from_parts(
            parameters,
            state.learning_rate(),
            state.step_count(),
        ))
    }

    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    enum FeedForwardWeightKind {
        First,
        Second,
    }

    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    struct FeedForwardWeightSelection {
        kind: FeedForwardWeightKind,
        matrix: MatrixSelection,
    }

    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    struct FeedForwardBiasSelection {
        kind: FeedForwardWeightKind,
        index: usize,
    }

    fn largest_changed_feed_forward_weight(
        before: &TransformerTrainingState,
        after: &TransformerTrainingState,
    ) -> CtResult<FeedForwardWeightSelection> {
        let first = largest_changed_matrix_weight(
            before.parameters().feed_forward().first_weight(),
            after.parameters().feed_forward().first_weight(),
            "changed first feed-forward weight",
        );
        let second = largest_changed_matrix_weight(
            before.parameters().feed_forward().second_weight(),
            after.parameters().feed_forward().second_weight(),
            "changed second feed-forward weight",
        );

        match (first, second) {
            (Ok(first), Ok(second)) => {
                let first_delta = feed_forward_weight_delta(
                    before,
                    after,
                    FeedForwardWeightSelection {
                        kind: FeedForwardWeightKind::First,
                        matrix: first,
                    },
                );
                let second_delta = feed_forward_weight_delta(
                    before,
                    after,
                    FeedForwardWeightSelection {
                        kind: FeedForwardWeightKind::Second,
                        matrix: second,
                    },
                );

                if first_delta >= second_delta {
                    Ok(FeedForwardWeightSelection {
                        kind: FeedForwardWeightKind::First,
                        matrix: first,
                    })
                } else {
                    Ok(FeedForwardWeightSelection {
                        kind: FeedForwardWeightKind::Second,
                        matrix: second,
                    })
                }
            }
            (Ok(first), Err(_)) => Ok(FeedForwardWeightSelection {
                kind: FeedForwardWeightKind::First,
                matrix: first,
            }),
            (Err(_), Ok(second)) => Ok(FeedForwardWeightSelection {
                kind: FeedForwardWeightKind::Second,
                matrix: second,
            }),
            (Err(_), Err(_)) => Err(CtError::EmptyInput("changed feed-forward weight")),
        }
    }

    fn feed_forward_weight_delta(
        before: &TransformerTrainingState,
        after: &TransformerTrainingState,
        selection: FeedForwardWeightSelection,
    ) -> f32 {
        (feed_forward_weight(before, selection) - feed_forward_weight(after, selection)).abs()
    }

    fn feed_forward_weight(
        state: &TransformerTrainingState,
        selection: FeedForwardWeightSelection,
    ) -> f32 {
        let feed_forward = state.parameters().feed_forward();

        match selection.kind {
            FeedForwardWeightKind::First => feed_forward.first_weight()
                [selection.matrix.input_index][selection.matrix.output_index],
            FeedForwardWeightKind::Second => feed_forward.second_weight()
                [selection.matrix.input_index][selection.matrix.output_index],
        }
    }

    fn state_with_feed_forward_weight(
        state: &TransformerTrainingState,
        selection: FeedForwardWeightSelection,
        value: f32,
    ) -> CtResult<TransformerTrainingState> {
        let feed_forward = state.parameters().feed_forward();
        let mut first_weight = feed_forward.first_weight().to_vec();
        let mut second_weight = feed_forward.second_weight().to_vec();

        match selection.kind {
            FeedForwardWeightKind::First => {
                first_weight[selection.matrix.input_index][selection.matrix.output_index] = value;
            }
            FeedForwardWeightKind::Second => {
                second_weight[selection.matrix.input_index][selection.matrix.output_index] = value;
            }
        }

        let feed_forward = PositionWiseFeedForward::new(
            first_weight,
            feed_forward.first_bias().to_vec(),
            second_weight,
            feed_forward.second_bias().to_vec(),
        )?;
        let parameters = state.parameters().clone().with_feed_forward(feed_forward)?;

        Ok(TransformerTrainingState::from_parts(
            parameters,
            state.learning_rate(),
            state.step_count(),
        ))
    }

    fn largest_changed_feed_forward_bias(
        before: &TransformerTrainingState,
        after: &TransformerTrainingState,
    ) -> CtResult<FeedForwardBiasSelection> {
        let first = largest_changed_vector_index(
            before.parameters().feed_forward().first_bias(),
            after.parameters().feed_forward().first_bias(),
            "changed first feed-forward bias",
        );
        let second = largest_changed_vector_index(
            before.parameters().feed_forward().second_bias(),
            after.parameters().feed_forward().second_bias(),
            "changed second feed-forward bias",
        );

        match (first, second) {
            (Ok(first), Ok(second)) => {
                let first_delta = feed_forward_bias_delta(
                    before,
                    after,
                    FeedForwardBiasSelection {
                        kind: FeedForwardWeightKind::First,
                        index: first,
                    },
                );
                let second_delta = feed_forward_bias_delta(
                    before,
                    after,
                    FeedForwardBiasSelection {
                        kind: FeedForwardWeightKind::Second,
                        index: second,
                    },
                );

                if first_delta >= second_delta {
                    Ok(FeedForwardBiasSelection {
                        kind: FeedForwardWeightKind::First,
                        index: first,
                    })
                } else {
                    Ok(FeedForwardBiasSelection {
                        kind: FeedForwardWeightKind::Second,
                        index: second,
                    })
                }
            }
            (Ok(first), Err(_)) => Ok(FeedForwardBiasSelection {
                kind: FeedForwardWeightKind::First,
                index: first,
            }),
            (Err(_), Ok(second)) => Ok(FeedForwardBiasSelection {
                kind: FeedForwardWeightKind::Second,
                index: second,
            }),
            (Err(_), Err(_)) => Err(CtError::EmptyInput("changed feed-forward bias")),
        }
    }

    fn feed_forward_bias_delta(
        before: &TransformerTrainingState,
        after: &TransformerTrainingState,
        selection: FeedForwardBiasSelection,
    ) -> f32 {
        (feed_forward_bias(before, selection) - feed_forward_bias(after, selection)).abs()
    }

    fn feed_forward_bias(
        state: &TransformerTrainingState,
        selection: FeedForwardBiasSelection,
    ) -> f32 {
        let feed_forward = state.parameters().feed_forward();

        match selection.kind {
            FeedForwardWeightKind::First => feed_forward.first_bias()[selection.index],
            FeedForwardWeightKind::Second => feed_forward.second_bias()[selection.index],
        }
    }

    fn state_with_feed_forward_bias(
        state: &TransformerTrainingState,
        selection: FeedForwardBiasSelection,
        value: f32,
    ) -> CtResult<TransformerTrainingState> {
        let feed_forward = state.parameters().feed_forward();
        let mut first_bias = feed_forward.first_bias().to_vec();
        let mut second_bias = feed_forward.second_bias().to_vec();

        match selection.kind {
            FeedForwardWeightKind::First => {
                first_bias[selection.index] = value;
            }
            FeedForwardWeightKind::Second => {
                second_bias[selection.index] = value;
            }
        }

        let feed_forward = PositionWiseFeedForward::new(
            feed_forward.first_weight().to_vec(),
            first_bias,
            feed_forward.second_weight().to_vec(),
            second_bias,
        )?;
        let parameters = state.parameters().clone().with_feed_forward(feed_forward)?;

        Ok(TransformerTrainingState::from_parts(
            parameters,
            state.learning_rate(),
            state.step_count(),
        ))
    }

    fn state_with_output_projection_bias(
        state: &TransformerTrainingState,
        selection: usize,
        value: f32,
    ) -> CtResult<TransformerTrainingState> {
        let output_projection = state.parameters().output_projection();
        let mut bias = output_projection.bias().to_vec();
        bias[selection] = value;
        let output_projection =
            AttentionOutputProjection::new(output_projection.weight().to_vec(), bias)?;
        let parameters = state
            .parameters()
            .clone()
            .with_output_projection(output_projection)?;

        Ok(TransformerTrainingState::from_parts(
            parameters,
            state.learning_rate(),
            state.step_count(),
        ))
    }

    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    enum LayerNormParameterKind {
        AttentionScale,
        AttentionShift,
        FeedForwardScale,
        FeedForwardShift,
    }

    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    struct LayerNormParameterSelection {
        kind: LayerNormParameterKind,
        feature_index: usize,
    }

    fn largest_changed_layer_norm_parameter(
        before: &TransformerTrainingState,
        after: &TransformerTrainingState,
    ) -> CtResult<LayerNormParameterSelection> {
        let mut selected = None;
        let mut largest_delta = 0.0;

        for kind in [
            LayerNormParameterKind::AttentionScale,
            LayerNormParameterKind::AttentionShift,
            LayerNormParameterKind::FeedForwardScale,
            LayerNormParameterKind::FeedForwardShift,
        ] {
            let before_values = layer_norm_parameter_values(before, kind);
            let after_values = layer_norm_parameter_values(after, kind);

            for (feature_index, (before_value, after_value)) in
                before_values.iter().zip(after_values).enumerate()
            {
                let delta = (before_value - after_value).abs();

                if delta > largest_delta {
                    largest_delta = delta;
                    selected = Some(LayerNormParameterSelection {
                        kind,
                        feature_index,
                    });
                }
            }
        }

        selected.ok_or(CtError::EmptyInput("changed layer norm parameter"))
    }

    fn layer_norm_parameter_values(
        state: &TransformerTrainingState,
        kind: LayerNormParameterKind,
    ) -> &[f32] {
        match kind {
            LayerNormParameterKind::AttentionScale => {
                state.parameters().attention_norm().parameters().scale()
            }
            LayerNormParameterKind::AttentionShift => {
                state.parameters().attention_norm().parameters().shift()
            }
            LayerNormParameterKind::FeedForwardScale => {
                state.parameters().feed_forward_norm().parameters().scale()
            }
            LayerNormParameterKind::FeedForwardShift => {
                state.parameters().feed_forward_norm().parameters().shift()
            }
        }
    }

    fn layer_norm_parameter_value(
        state: &TransformerTrainingState,
        selection: LayerNormParameterSelection,
    ) -> f32 {
        layer_norm_parameter_values(state, selection.kind)[selection.feature_index]
    }

    fn state_with_layer_norm_parameter(
        state: &TransformerTrainingState,
        selection: LayerNormParameterSelection,
        value: f32,
    ) -> CtResult<TransformerTrainingState> {
        let attention_parameters = state.parameters().attention_norm().parameters();
        let feed_forward_parameters = state.parameters().feed_forward_norm().parameters();
        let mut attention_scale = attention_parameters.scale().to_vec();
        let mut attention_shift = attention_parameters.shift().to_vec();
        let mut feed_forward_scale = feed_forward_parameters.scale().to_vec();
        let mut feed_forward_shift = feed_forward_parameters.shift().to_vec();

        match selection.kind {
            LayerNormParameterKind::AttentionScale => {
                attention_scale[selection.feature_index] = value;
            }
            LayerNormParameterKind::AttentionShift => {
                attention_shift[selection.feature_index] = value;
            }
            LayerNormParameterKind::FeedForwardScale => {
                feed_forward_scale[selection.feature_index] = value;
            }
            LayerNormParameterKind::FeedForwardShift => {
                feed_forward_shift[selection.feature_index] = value;
            }
        }

        let attention_norm = LayerNormalization::new(LayerNormParameters::new(
            attention_scale,
            attention_shift,
            attention_parameters.epsilon(),
        )?);
        let feed_forward_norm = LayerNormalization::new(LayerNormParameters::new(
            feed_forward_scale,
            feed_forward_shift,
            feed_forward_parameters.epsilon(),
        )?);
        let parameters = state
            .parameters()
            .clone()
            .with_layer_norms(attention_norm, feed_forward_norm)?;

        Ok(TransformerTrainingState::from_parts(
            parameters,
            state.learning_rate(),
            state.step_count(),
        ))
    }

    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    enum AttentionProjectionKind {
        Query,
        Key,
        Value,
    }

    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    struct AttentionProjectionSelection {
        head_index: usize,
        kind: AttentionProjectionKind,
        input_index: usize,
        output_index: usize,
    }

    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    struct AttentionProjectionBiasSelection {
        head_index: usize,
        kind: AttentionProjectionKind,
        output_index: usize,
    }

    fn largest_changed_attention_projection(
        before: &TransformerTrainingState,
        after: &TransformerTrainingState,
    ) -> CtResult<AttentionProjectionSelection> {
        let head_count = before.parameters().attention_heads().len();
        let mut selected = None;
        let mut largest_delta = 0.0;

        for head_index in 0..head_count {
            for kind in [
                AttentionProjectionKind::Query,
                AttentionProjectionKind::Key,
                AttentionProjectionKind::Value,
            ] {
                let before_weight = attention_projection_weight_matrix(before, head_index, kind)?;
                let after_weight = attention_projection_weight_matrix(after, head_index, kind)?;

                for (input_index, (before_row, after_row)) in
                    before_weight.iter().zip(after_weight).enumerate()
                {
                    for (output_index, (before_value, after_value)) in
                        before_row.iter().zip(after_row).enumerate()
                    {
                        let delta = (before_value - after_value).abs();

                        if delta > largest_delta {
                            largest_delta = delta;
                            selected = Some(AttentionProjectionSelection {
                                head_index,
                                kind,
                                input_index,
                                output_index,
                            });
                        }
                    }
                }
            }
        }

        selected.ok_or(CtError::EmptyInput("changed attention projection"))
    }

    fn attention_projection_weight(
        state: &TransformerTrainingState,
        selection: AttentionProjectionSelection,
    ) -> CtResult<f32> {
        let weight =
            attention_projection_weight_matrix(state, selection.head_index, selection.kind)?;

        Ok(weight[selection.input_index][selection.output_index])
    }

    fn attention_projection_weight_matrix(
        state: &TransformerTrainingState,
        head_index: usize,
        kind: AttentionProjectionKind,
    ) -> CtResult<&[Vec<f32>]> {
        let head =
            state
                .parameters()
                .attention_heads()
                .get(head_index)
                .ok_or(CtError::OutOfRange {
                    kind: "attention head",
                    index: head_index,
                    limit: state.parameters().attention_heads().len(),
                })?;

        Ok(match kind {
            AttentionProjectionKind::Query => head.query_projection().weight(),
            AttentionProjectionKind::Key => head.key_projection().weight(),
            AttentionProjectionKind::Value => head.value_projection().weight(),
        })
    }

    fn largest_changed_attention_projection_bias(
        before: &TransformerTrainingState,
        after: &TransformerTrainingState,
    ) -> CtResult<AttentionProjectionBiasSelection> {
        let head_count = before.parameters().attention_heads().len();
        let mut selected = None;
        let mut largest_delta = 0.0;

        for head_index in 0..head_count {
            for kind in [
                AttentionProjectionKind::Query,
                AttentionProjectionKind::Key,
                AttentionProjectionKind::Value,
            ] {
                let before_bias = attention_projection_bias_values(before, head_index, kind)?;
                let after_bias = attention_projection_bias_values(after, head_index, kind)?;

                for (output_index, (before_value, after_value)) in
                    before_bias.iter().zip(after_bias).enumerate()
                {
                    let delta = (before_value - after_value).abs();

                    if delta > largest_delta {
                        largest_delta = delta;
                        selected = Some(AttentionProjectionBiasSelection {
                            head_index,
                            kind,
                            output_index,
                        });
                    }
                }
            }
        }

        selected.ok_or(CtError::EmptyInput("changed attention projection bias"))
    }

    fn attention_projection_bias(
        state: &TransformerTrainingState,
        selection: AttentionProjectionBiasSelection,
    ) -> CtResult<f32> {
        let bias = attention_projection_bias_values(state, selection.head_index, selection.kind)?;

        Ok(bias[selection.output_index])
    }

    fn attention_projection_bias_values(
        state: &TransformerTrainingState,
        head_index: usize,
        kind: AttentionProjectionKind,
    ) -> CtResult<&[f32]> {
        let head =
            state
                .parameters()
                .attention_heads()
                .get(head_index)
                .ok_or(CtError::OutOfRange {
                    kind: "attention head",
                    index: head_index,
                    limit: state.parameters().attention_heads().len(),
                })?;

        Ok(match kind {
            AttentionProjectionKind::Query => head.query_projection().bias(),
            AttentionProjectionKind::Key => head.key_projection().bias(),
            AttentionProjectionKind::Value => head.value_projection().bias(),
        })
    }

    fn state_with_attention_projection_weight(
        state: &TransformerTrainingState,
        selection: AttentionProjectionSelection,
        value: f32,
    ) -> CtResult<TransformerTrainingState> {
        let mut heads = state.parameters().attention_heads().to_vec();
        let head = heads.get(selection.head_index).ok_or(CtError::OutOfRange {
            kind: "attention head",
            index: selection.head_index,
            limit: heads.len(),
        })?;
        let mut query_weight = head.query_projection().weight().to_vec();
        let mut key_weight = head.key_projection().weight().to_vec();
        let mut value_weight = head.value_projection().weight().to_vec();

        match selection.kind {
            AttentionProjectionKind::Query => {
                query_weight[selection.input_index][selection.output_index] = value;
            }
            AttentionProjectionKind::Key => {
                key_weight[selection.input_index][selection.output_index] = value;
            }
            AttentionProjectionKind::Value => {
                value_weight[selection.input_index][selection.output_index] = value;
            }
        }

        heads[selection.head_index] = SelfAttentionHead::new(
            HiddenToQuery::new(query_weight, head.query_projection().bias().to_vec())?,
            HiddenToKey::new(key_weight, head.key_projection().bias().to_vec())?,
            HiddenToValue::new(value_weight, head.value_projection().bias().to_vec())?,
        )?;
        let parameters = state.parameters().clone().with_attention_heads(heads)?;

        Ok(TransformerTrainingState::from_parts(
            parameters,
            state.learning_rate(),
            state.step_count(),
        ))
    }

    fn state_with_attention_projection_bias(
        state: &TransformerTrainingState,
        selection: AttentionProjectionBiasSelection,
        value: f32,
    ) -> CtResult<TransformerTrainingState> {
        let mut heads = state.parameters().attention_heads().to_vec();
        let head = heads.get(selection.head_index).ok_or(CtError::OutOfRange {
            kind: "attention head",
            index: selection.head_index,
            limit: heads.len(),
        })?;
        let mut query_bias = head.query_projection().bias().to_vec();
        let mut key_bias = head.key_projection().bias().to_vec();
        let mut value_bias = head.value_projection().bias().to_vec();

        match selection.kind {
            AttentionProjectionKind::Query => {
                query_bias[selection.output_index] = value;
            }
            AttentionProjectionKind::Key => {
                key_bias[selection.output_index] = value;
            }
            AttentionProjectionKind::Value => {
                value_bias[selection.output_index] = value;
            }
        }

        heads[selection.head_index] = SelfAttentionHead::new(
            HiddenToQuery::new(head.query_projection().weight().to_vec(), query_bias)?,
            HiddenToKey::new(head.key_projection().weight().to_vec(), key_bias)?,
            HiddenToValue::new(head.value_projection().weight().to_vec(), value_bias)?,
        )?;
        let parameters = state.parameters().clone().with_attention_heads(heads)?;

        Ok(TransformerTrainingState::from_parts(
            parameters,
            state.learning_rate(),
            state.step_count(),
        ))
    }

    fn tiny_single_head_block() -> CtResult<SingleHeadTransformerBlock> {
        let query = HiddenToQuery::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?;
        let key = HiddenToKey::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?;
        let value = HiddenToValue::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?;
        let output_projection =
            AttentionOutputProjection::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?;
        let model_dimension = ModelDimension::new(2)?;
        let attention_norm =
            LayerNormalization::new(LayerNormParameters::identity(model_dimension));
        let feed_forward = PositionWiseFeedForward::new(
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![0.0, 0.0],
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![0.0, 0.0],
        )?;
        let feed_forward_norm =
            LayerNormalization::new(LayerNormParameters::identity(model_dimension));

        SingleHeadTransformerBlock::new(
            query,
            key,
            value,
            output_projection,
            attention_norm,
            feed_forward,
            feed_forward_norm,
        )
    }

    fn tiny_multi_head_block() -> CtResult<MultiHeadTransformerBlock> {
        let model_dimension = ModelDimension::new(2)?;

        MultiHeadTransformerBlock::new(
            vec![
                tiny_self_attention_head_first_feature()?,
                tiny_self_attention_head_second_feature()?,
            ],
            AttentionOutputProjection::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?,
            LayerNormalization::new(LayerNormParameters::identity(model_dimension)),
            identity_feed_forward()?,
            LayerNormalization::new(LayerNormParameters::identity(model_dimension)),
        )
    }

    fn tiny_masked_multi_head_block() -> CtResult<MaskedMultiHeadTransformerBlock> {
        let model_dimension = ModelDimension::new(2)?;

        MaskedMultiHeadTransformerBlock::new(
            vec![
                tiny_self_attention_head_first_feature()?,
                tiny_self_attention_head_second_feature()?,
            ],
            AttentionOutputProjection::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?,
            LayerNormalization::new(LayerNormParameters::identity(model_dimension)),
            identity_feed_forward()?,
            LayerNormalization::new(LayerNormParameters::identity(model_dimension)),
        )
    }

    fn tiny_transformer_parameters() -> CtResult<TinyTransformerParameters> {
        TinyTransformerParameters::new(
            PositionalEncoding::new(vec![vec![0.1, 0.0], vec![0.0, 0.1]])?,
            tiny_masked_multi_head_block()?,
            TransformerReadout::new(
                vec![vec![1.0, 0.0, 0.5], vec![0.0, 1.0, -0.5]],
                vec![0.0, 0.0, 0.0],
            )?,
        )
    }

    fn tiny_transformer_training_set() -> CtResult<TransformerReadoutTrainingSet> {
        let example = TransformerReadoutTrainingExample::new(
            HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?,
            AttentionMask::new(vec![vec![true, false], vec![true, true]])?,
            TokenSequence::from_indices([0, 1])?,
        )?;

        TransformerReadoutTrainingSet::new([example])
    }

    fn tiny_feed_forward_training_set() -> CtResult<TransformerFeedForwardTrainingSet> {
        let example = TransformerFeedForwardTrainingExample::new(
            HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?,
            HiddenSequence::new(vec![vec![2.0, 0.0], vec![0.0, 2.0]])?,
        )?;

        TransformerFeedForwardTrainingSet::new([example])
    }

    fn tiny_transformer_block_training_set() -> CtResult<TransformerBlockTrainingSet> {
        let example = TransformerBlockTrainingExample::new(
            HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?,
            AttentionMask::new(vec![vec![true, false], vec![true, true]])?,
            TokenSequence::from_indices([0, 1])?,
        )?;

        TransformerBlockTrainingSet::new([example])
    }

    fn tiny_self_attention_head_first_feature() -> CtResult<SelfAttentionHead> {
        SelfAttentionHead::new(
            HiddenToQuery::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
            HiddenToKey::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
            HiddenToValue::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
        )
    }

    fn tiny_self_attention_head_second_feature() -> CtResult<SelfAttentionHead> {
        SelfAttentionHead::new(
            HiddenToQuery::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
            HiddenToKey::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
            HiddenToValue::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
        )
    }

    fn identity_feed_forward() -> CtResult<PositionWiseFeedForward> {
        PositionWiseFeedForward::new(
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![0.0, 0.0],
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![0.0, 0.0],
        )
    }
}
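The test helpers above repeat one pattern for each parameter group: readout bias, feed-forward weights and biases, layer-norm scale and shift, and per-head query, key, and value projections. Each helper does three things:

- It scans a before state and an after state for the single entry that moved the most.
- It records where that entry lives in a small `Copy` selection value.
- It rebuilds a state with only that entry overridden, going through the validating constructors.

The sketch below restates the pattern on plain `Vec<f32>` groups so the control flow is visible without the crate's types. `ParamGroup`, `Selection`, `largest_changed`, and `with_value` are hypothetical names for illustration. They are not part of `category_theory_transformer_rs`.

Like the layer-norm and attention-projection helpers, the sketch replaces its current pick only on a strictly larger delta, starting from `0.0`. Two identical states therefore yield `None`, which corresponds to the `CtError::EmptyInput` path.

```rust
/// Hypothetical labels for the parameter groups a selection can point into.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ParamGroup {
    Scale,
    Shift,
}

/// Where the largest change lives: which group, and which entry inside it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct Selection {
    group: ParamGroup,
    index: usize,
}

/// Scan every group and keep the entry with the strictly largest absolute change.
/// Returns `None` when nothing moved (the sketch's stand-in for an error).
fn largest_changed(
    before: &[(ParamGroup, Vec<f32>)],
    after: &[(ParamGroup, Vec<f32>)],
) -> Option<Selection> {
    let mut selected = None;
    let mut largest_delta = 0.0_f32;

    for ((group, before_values), (_, after_values)) in before.iter().zip(after) {
        for (index, (b, a)) in before_values.iter().zip(after_values).enumerate() {
            let delta = (b - a).abs();
            if delta > largest_delta {
                largest_delta = delta;
                selected = Some(Selection { group: *group, index });
            }
        }
    }

    selected
}

/// Return a copy of the groups with exactly one entry overridden.
fn with_value(
    groups: &[(ParamGroup, Vec<f32>)],
    selection: Selection,
    value: f32,
) -> Vec<(ParamGroup, Vec<f32>)> {
    let mut copy = groups.to_vec();
    for (group, values) in copy.iter_mut() {
        if *group == selection.group {
            values[selection.index] = value;
        }
    }
    copy
}

fn main() {
    let before = vec![
        (ParamGroup::Scale, vec![1.0, 1.0]),
        (ParamGroup::Shift, vec![0.0, 0.0]),
    ];
    let after = vec![
        (ParamGroup::Scale, vec![0.98, 1.0]),
        (ParamGroup::Shift, vec![0.0, 0.05]),
    ];

    // The shift entry at index 1 moved by 0.05, more than the scale entry's 0.02.
    let selection = largest_changed(&before, &after).expect("something changed");
    println!("{selection:?}");

    // Rebuild the "after" groups with only the selected entry put back to 0.0.
    let restored = with_value(&after, selection, 0.0);
    println!("{restored:?}");
}
```

The feed-forward helpers differ slightly. They pick a winner in each matrix (or bias vector) first, then compare the two winners with `>=`, so a tie goes to the first layer.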
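The fixtures above combine two one-feature heads with the mask `[[true, false], [true, true]]`:

- `tiny_self_attention_head_first_feature` projects query, key, and value onto feature 0 with weight `[[1.0], [0.0]]`.
- `tiny_self_attention_head_second_feature` does the same for feature 1.

The mask uses the polarity the companion example prints: `true` means allowed, `false` means blocked.

Here is a minimal plain-Rust sketch of what one such head computes on the hidden sequence `[[1.0, 0.0], [0.0, 1.0]]`. It assumes two things: blocked keys are dropped from the softmax by giving them a score of negative infinity (the usual convention), and scores are divided by the square root of the head dimension. The sketch ignores the positional encoding and the rest of the block. It is not the crate's `MaskedAttentionScores` or `AttentionSoftmax` implementation, and `masked_attention` is a hypothetical name.

```rust
/// Masked scaled dot-product attention for one head, with `true` = allowed.
/// Rows of `queries` are query positions; rows of `keys`/`values` are key positions.
/// Assumes every query row allows at least one key (an all-blocked row yields NaN).
fn masked_attention(
    queries: &[Vec<f32>],
    keys: &[Vec<f32>],
    values: &[Vec<f32>],
    mask: &[Vec<bool>],
) -> Vec<Vec<f32>> {
    let scale = (queries[0].len() as f32).sqrt();

    queries
        .iter()
        .zip(mask)
        .map(|(query, allowed)| {
            // Scaled dot-product scores; blocked keys get negative infinity.
            let scores: Vec<f32> = keys
                .iter()
                .zip(allowed)
                .map(|(key, &ok)| {
                    if ok {
                        query.iter().zip(key).map(|(q, k)| q * k).sum::<f32>() / scale
                    } else {
                        f32::NEG_INFINITY
                    }
                })
                .collect();

            // Stable softmax: subtracting the row max keeps exp() in range,
            // and exp(-inf) = 0 removes blocked keys entirely.
            let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
            let total: f32 = exps.iter().sum();

            // Mix value rows with the normalized attention weights.
            let mut output = vec![0.0; values[0].len()];
            for (weight, value) in exps.iter().zip(values) {
                for (out, v) in output.iter_mut().zip(value) {
                    *out += weight / total * v;
                }
            }
            output
        })
        .collect()
}

fn main() {
    let hidden = vec![vec![1.0_f32, 0.0], vec![0.0, 1.0]];
    let mask = vec![vec![true, false], vec![true, true]];

    // With identity one-feature projections, queries, keys, and values are the same column.
    let first_feature: Vec<Vec<f32>> = hidden.iter().map(|row| vec![row[0]]).collect();
    let output = masked_attention(&first_feature, &first_feature, &first_feature, &mask);
    println!("{output:?}");
}
```

For the feature-0 head, position 0 may only attend to itself, so its output is its own value `1.0`. Position 1 has a zero query, so both allowed keys score `0.0`, the weights split evenly, and the output is `0.5`.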

The full runnable companion is:

Source snapshot: examples/06_attention_scores.rs
use category_theory_transformer_rs::{
    AttentionHeadOutputs, AttentionMask, AttentionOutput, AttentionOutputProjection,
    AttentionSoftmax, ConcatenateHeads, CtResult, HiddenSequence, HiddenToKey, HiddenToQuery,
    HiddenToValue, KeySequence, LayerNormParameters, LayerNormalization, LearningRate,
    MaskedAttentionScores, MaskedMultiHeadTransformerBlock, Morphism, MultiHeadTransformerBlock,
    PositionWiseFeedForward, PositionalEncoding, Product, QuerySequence, ResidualConnection,
    ScaledDotProductScores, SelfAttentionHead, SingleHeadTransformerBlock,
    TinyTransformerParameters, TokenSequence, TransformerBlockTrainStep,
    TransformerBlockTrainingExample, TransformerBlockTrainingSet, TransformerFeedForwardTrainStep,
    TransformerFeedForwardTrainingExample, TransformerFeedForwardTrainingSet, TransformerReadout,
    TransformerReadoutTrainStep, TransformerReadoutTrainingExample, TransformerReadoutTrainingSet,
    TransformerTrainingState, ValueSequence, WeightedValueMixing, transformer_block_average_loss,
    transformer_feed_forward_average_loss, transformer_readout_average_loss,
};

fn main() -> CtResult<()> {
    let queries = QuerySequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
    let keys = KeySequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]])?;
    let values = ValueSequence::new(vec![vec![1.0, 10.0], vec![2.0, 20.0], vec![3.0, 30.0]])?;
    let mask = AttentionMask::new(vec![vec![true, false, true], vec![true, true, true]])?;

    println!("Q/K/V source diagnostic:");
    println!("query rows own score rows; key/value rows own score columns");
    println!(
        "self-attention shares the hidden source before projection; projected roles stay distinct"
    );
    println!("mask polarity here: true = allowed, false = blocked\n");

    let scores = ScaledDotProductScores.apply(Product::new(queries, keys))?;
    let masked_scores = MaskedAttentionScores.apply(Product::new(scores, mask))?;
    let weights = AttentionSoftmax.apply(masked_scores)?;

    println!(
        "attention shape: {} query positions x {} key positions",
        weights.query_len().value(),
        weights.key_len().value()
    );

    for (query_position, row) in weights.rows().iter().enumerate() {
        println!("query {query_position} attends with {:?}", row.as_slice());
    }

    let output = WeightedValueMixing.apply(Product::new(weights, values))?;

    for (query_position, row) in output.rows().iter().enumerate() {
        println!("query {query_position} output vector {:?}", row.as_slice());
    }

    let second_head = AttentionOutput::new(vec![vec![10.0, 1.0], vec![20.0, 2.0]])?;
    let head_outputs = AttentionHeadOutputs::new(vec![output, second_head])?;
    let multi_head = ConcatenateHeads.apply(head_outputs)?;

    println!(
        "multi-head shape: {} heads x {} features -> model dimension {}",
        multi_head.head_count().value(),
        multi_head.head_dimension().value(),
        multi_head.model_dimension().value()
    );

    for (query_position, row) in multi_head.rows().iter().enumerate() {
        println!("query {query_position} multi-head row {:?}", row.as_slice());
    }

    let output_projection = AttentionOutputProjection::new(
        vec![
            vec![1.0, 0.0],
            vec![0.0, 0.1],
            vec![0.5, 0.0],
            vec![0.0, 1.0],
        ],
        vec![0.0, 0.0],
    )?;
    let projected = output_projection.apply(multi_head)?;

    println!(
        "projected attention shape: {} positions x model dimension {}",
        projected.sequence_len().value(),
        projected.model_dimension().value()
    );

    for (query_position, row) in projected.rows().iter().enumerate() {
        println!(
            "query {query_position} projected attention row {:?}",
            row.as_slice()
        );
    }

    let hidden_input = HiddenSequence::new(vec![vec![0.5, 0.5], vec![1.0, 1.0]])?;
    let residual = ResidualConnection.apply(Product::new(hidden_input, projected))?;

    println!(
        "residual shape: {} positions x model dimension {}",
        residual.sequence_len().value(),
        residual.model_dimension().value()
    );

    for (query_position, row) in residual.rows().iter().enumerate() {
        println!("query {query_position} residual row {:?}", row.as_slice());
    }

    let layer_norm =
        LayerNormalization::new(LayerNormParameters::identity(residual.model_dimension()));
    let normalized = layer_norm.apply(residual)?;

    println!(
        "normalized shape: {} positions x model dimension {}",
        normalized.sequence_len().value(),
        normalized.model_dimension().value()
    );

    for (query_position, row) in normalized.rows().iter().enumerate() {
        println!("query {query_position} normalized row {:?}", row.as_slice());
    }

    let feed_forward = PositionWiseFeedForward::new(
        vec![vec![1.0, -1.0, 0.5], vec![0.0, 1.0, 0.5]],
        vec![0.0, 0.0, 0.0],
        vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![0.5, 0.5]],
        vec![0.0, 0.0],
    )?;
    let fed_forward = feed_forward.apply(normalized)?;

    println!(
        "feed-forward shape: {} positions x model dimension {}",
        fed_forward.sequence_len().value(),
        fed_forward.model_dimension().value()
    );

    for (query_position, row) in fed_forward.rows().iter().enumerate() {
        println!(
            "query {query_position} feed-forward row {:?}",
            row.as_slice()
        );
    }

    let positional_encoding = PositionalEncoding::new(vec![vec![0.1, 0.0], vec![0.0, 0.1]])?;
    let positioned_hidden =
        positional_encoding.apply(HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?)?;

    println!(
        "positioned hidden shape: {} positions x model dimension {}",
        positioned_hidden.sequence_len().value(),
        positioned_hidden.model_dimension().value()
    );

    let block = SingleHeadTransformerBlock::new(
        HiddenToQuery::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?,
        HiddenToKey::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?,
        HiddenToValue::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?,
        AttentionOutputProjection::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?,
        LayerNormalization::new(LayerNormParameters::identity(fed_forward.model_dimension())),
        PositionWiseFeedForward::new(
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![0.0, 0.0],
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![0.0, 0.0],
        )?,
        LayerNormalization::new(LayerNormParameters::identity(fed_forward.model_dimension())),
    )?;
    let block_output = block.apply(positioned_hidden.clone())?;

    println!(
        "single-head block shape: {} positions x model dimension {}",
        block_output.sequence_len().value(),
        block_output.model_dimension().value()
    );

    let multi_head_block = MultiHeadTransformerBlock::new(
        vec![
            SelfAttentionHead::new(
                HiddenToQuery::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
                HiddenToKey::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
                HiddenToValue::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
            )?,
            SelfAttentionHead::new(
                HiddenToQuery::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
                HiddenToKey::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
                HiddenToValue::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
            )?,
        ],
        AttentionOutputProjection::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?,
        LayerNormalization::new(LayerNormParameters::identity(
            block_output.model_dimension(),
        )),
        PositionWiseFeedForward::new(
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![0.0, 0.0],
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![0.0, 0.0],
        )?,
        LayerNormalization::new(LayerNormParameters::identity(
            block_output.model_dimension(),
        )),
    )?;
    let multi_head_output = multi_head_block.apply(positioned_hidden)?;

    println!(
        "multi-head block shape: {} positions x {} heads x value dimension {} -> model dimension {}",
        multi_head_output.sequence_len().value(),
        multi_head_block.head_count().value(),
        multi_head_block.value_dimension().value(),
        multi_head_output.model_dimension().value()
    );

    let masked_multi_head_block = MaskedMultiHeadTransformerBlock::new(
        vec![
            SelfAttentionHead::new(
                HiddenToQuery::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
                HiddenToKey::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
                HiddenToValue::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
            )?,
            SelfAttentionHead::new(
                HiddenToQuery::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
                HiddenToKey::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
                HiddenToValue::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
            )?,
        ],
        AttentionOutputProjection::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?,
        LayerNormalization::new(LayerNormParameters::identity(
            multi_head_output.model_dimension(),
        )),
        PositionWiseFeedForward::new(
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![0.0, 0.0],
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![0.0, 0.0],
        )?,
        LayerNormalization::new(LayerNormParameters::identity(
            multi_head_output.model_dimension(),
        )),
    )?;
    let masked_block_output = masked_multi_head_block.apply(Product::new(
        HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?,
        AttentionMask::new(vec![vec![true, false], vec![true, true]])?,
    ))?;

    println!(
        "masked multi-head block shape: {} positions x model dimension {}",
        masked_block_output.sequence_len().value(),
        masked_block_output.model_dimension().value()
    );

    let transformer_parameters = TinyTransformerParameters::new(
        PositionalEncoding::new(vec![vec![0.1, 0.0], vec![0.0, 0.1]])?,
        masked_multi_head_block,
        TransformerReadout::new(
            vec![vec![1.0, 0.0, 0.5], vec![0.0, 1.0, -0.5]],
            vec![0.0, 0.0, 0.0],
        )?,
    )?;
    let transformer_state =
        TransformerTrainingState::new(transformer_parameters, LearningRate::new(0.1)?);
    let training_set =
        TransformerReadoutTrainingSet::new([TransformerReadoutTrainingExample::new(
            HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?,
            AttentionMask::new(vec![vec![true, false], vec![true, true]])?,
            TokenSequence::from_indices([0, 1])?,
        )?])?;
    let sequence_logits = transformer_state.apply(Product::new(
        HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?,
        AttentionMask::new(vec![vec![true, false], vec![true, true]])?,
    ))?;
    let loss_before = transformer_readout_average_loss(&transformer_state, &training_set)?;
    let train_step = TransformerReadoutTrainStep::new(training_set.clone());
    let next_state = train_step.apply(transformer_state.clone())?;
    let loss_after = transformer_readout_average_loss(&next_state, &training_set)?;

    println!(
        "structured transformer logits shape: {} positions x vocabulary size {}",
        sequence_logits.sequence_len().value(),
        sequence_logits.vocab_size().value()
    );
    println!(
        "training state step: {} -> {}",
        transformer_state.step_count().value(),
        next_state.step_count().value()
    );
    println!(
        "readout loss after one update: {:.6} -> {:.6}",
        loss_before.value(),
        loss_after.value()
    );
    let feed_forward_training_set =
        TransformerFeedForwardTrainingSet::new([TransformerFeedForwardTrainingExample::new(
            HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?,
            HiddenSequence::new(vec![vec![2.0, 0.0], vec![0.0, 2.0]])?,
        )?])?;
    let feed_forward_loss_before =
        transformer_feed_forward_average_loss(&next_state, &feed_forward_training_set)?;
    let feed_forward_train_step =
        TransformerFeedForwardTrainStep::new(feed_forward_training_set.clone());
    let feed_forward_state = feed_forward_train_step.apply(next_state.clone())?;
    let feed_forward_loss_after =
        transformer_feed_forward_average_loss(&feed_forward_state, &feed_forward_training_set)?;

    println!(
        "feed-forward loss after one local update: {:.6} -> {:.6}",
        feed_forward_loss_before.value(),
        feed_forward_loss_after.value()
    );
    let block_training_set =
        TransformerBlockTrainingSet::new([TransformerBlockTrainingExample::new(
            HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?,
            AttentionMask::new(vec![vec![true, false], vec![true, true]])?,
            TokenSequence::from_indices([0, 1])?,
        )?])?;
    let block_loss_before =
        transformer_block_average_loss(&feed_forward_state, &block_training_set)?;
    let block_train_step = TransformerBlockTrainStep::new(block_training_set.clone());
    let block_trained_state = block_train_step.apply(feed_forward_state)?;
    let block_loss_after =
        transformer_block_average_loss(&block_trained_state, &block_training_set)?;

    println!(
        "block loss after one composed update: {:.6} -> {:.6}",
        block_loss_before.value(),
        block_loss_after.value()
    );

    println!();
    println!("Typed transformation:");
    println!("HiddenSequence -> QuerySequence");
    println!("HiddenSequence -> KeySequence");
    println!("HiddenSequence -> ValueSequence");
    println!("QuerySequence x KeySequence -> AttentionScores");
    println!("AttentionScores x AttentionMask -> AttentionScores");
    println!("AttentionScores -> AttentionWeights");
    println!("AttentionWeights x ValueSequence -> AttentionOutput");
    println!("AttentionHeadOutputs -> MultiHeadOutput");
    println!("MultiHeadOutput -> ProjectedAttentionOutput");
    println!("HiddenSequence x ProjectedAttentionOutput -> HiddenSequence");
    println!("LayerNormalization : HiddenSequence -> HiddenSequence");
    println!("PositionWiseFeedForward : HiddenSequence -> HiddenSequence");
    println!("PositionalEncoding : HiddenSequence -> HiddenSequence");
    println!("SingleHeadTransformerBlock : HiddenSequence -> HiddenSequence");
    println!("MultiHeadTransformerBlock : HiddenSequence -> HiddenSequence");
    println!("MaskedMultiHeadTransformerBlock : HiddenSequence x AttentionMask -> HiddenSequence");
    println!("TinyTransformerParameters : HiddenSequence x AttentionMask -> SequenceLogits");
    println!("TransformerTrainingState owns parameters, learning rate, and step count");
    println!("TransformerReadoutTrainStep : TransformerTrainingState -> TransformerTrainingState");
    println!(
        "TransformerFeedForwardTrainStep : TransformerTrainingState -> TransformerTrainingState"
    );
    println!("TransformerBlockTrainStep : TransformerTrainingState -> TransformerTrainingState");

    Ok(())
}

The smaller state-only companion is:

Source snapshot: examples/07_transformer_training_state.rs
use category_theory_transformer_rs::{
    AttentionMask, AttentionOutputProjection, CtResult, HiddenSequence, HiddenToKey, HiddenToQuery,
    HiddenToValue, LayerNormParameters, LayerNormalization, LearningRate,
    MaskedMultiHeadTransformerBlock, ModelDimension, Morphism, PositionWiseFeedForward,
    PositionalEncoding, Product, SelfAttentionHead, TinyTransformerParameters, TokenSequence,
    TransformerBlockTrainStep, TransformerBlockTrainingExample, TransformerBlockTrainingSet,
    TransformerFeedForwardTrainStep, TransformerFeedForwardTrainingExample,
    TransformerFeedForwardTrainingSet, TransformerReadout, TransformerReadoutTrainStep,
    TransformerReadoutTrainingExample, TransformerReadoutTrainingSet, TransformerTrainingState,
    transformer_block_average_loss, transformer_feed_forward_average_loss,
    transformer_readout_average_loss,
};

fn main() -> CtResult<()> {
    let hidden = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
    let mask = AttentionMask::new(vec![vec![true, false], vec![true, true]])?;
    let targets = TokenSequence::from_indices([0, 1])?;
    let initial_state = tiny_training_state()?;

    let logits = initial_state.apply(Product::new(hidden.clone(), mask.clone()))?;

    println!(
        "initial state: step={}, learning_rate={:.3}, model_dimension={}, vocab_size={}",
        initial_state.step_count().value(),
        initial_state.learning_rate().value(),
        initial_state.parameters().model_dimension().value(),
        initial_state.parameters().vocab_size().value()
    );
    println!(
        "forward shape: {} positions x vocabulary size {}",
        logits.sequence_len().value(),
        logits.vocab_size().value()
    );

    let readout_set =
        TransformerReadoutTrainingSet::new([TransformerReadoutTrainingExample::new(
            hidden.clone(),
            mask.clone(),
            targets.clone(),
        )?])?;
    let readout_loss_before = transformer_readout_average_loss(&initial_state, &readout_set)?;
    let readout_state =
        TransformerReadoutTrainStep::new(readout_set.clone()).apply(initial_state)?;
    let readout_loss_after = transformer_readout_average_loss(&readout_state, &readout_set)?;

    print_update(
        "readout update",
        0,
        &readout_state,
        readout_loss_before.value(),
        readout_loss_after.value(),
    );

    let feed_forward_set =
        TransformerFeedForwardTrainingSet::new([TransformerFeedForwardTrainingExample::new(
            hidden.clone(),
            HiddenSequence::new(vec![vec![2.0, 0.0], vec![0.0, 2.0]])?,
        )?])?;
    let feed_forward_loss_before =
        transformer_feed_forward_average_loss(&readout_state, &feed_forward_set)?;
    let feed_forward_state =
        TransformerFeedForwardTrainStep::new(feed_forward_set.clone()).apply(readout_state)?;
    let feed_forward_loss_after =
        transformer_feed_forward_average_loss(&feed_forward_state, &feed_forward_set)?;

    print_update(
        "feed-forward update",
        1,
        &feed_forward_state,
        feed_forward_loss_before.value(),
        feed_forward_loss_after.value(),
    );

    let block_set = TransformerBlockTrainingSet::new([TransformerBlockTrainingExample::new(
        hidden, mask, targets,
    )?])?;
    let block_loss_before = transformer_block_average_loss(&feed_forward_state, &block_set)?;
    let block_state =
        TransformerBlockTrainStep::new(block_set.clone()).apply(feed_forward_state)?;
    let block_loss_after = transformer_block_average_loss(&block_state, &block_set)?;

    print_update(
        "composed block update",
        2,
        &block_state,
        block_loss_before.value(),
        block_loss_after.value(),
    );

    println!();
    println!("Typed transformation:");
    println!("TinyTransformerParameters : HiddenSequence x AttentionMask -> SequenceLogits");
    println!("TransformerTrainingState owns parameters, learning rate, and step count");
    println!("TransformerReadoutTrainStep : TransformerTrainingState -> TransformerTrainingState");
    println!(
        "TransformerFeedForwardTrainStep : TransformerTrainingState -> TransformerTrainingState"
    );
    println!("TransformerBlockTrainStep : TransformerTrainingState -> TransformerTrainingState");
    println!("Every update returns a full training state, not loose changed weights.");

    Ok(())
}

fn print_update(
    label: &str,
    previous_step: usize,
    state: &TransformerTrainingState,
    loss_before: f32,
    loss_after: f32,
) {
    println!(
        "{label}: step {} -> {}, loss {:.6} -> {:.6}",
        previous_step,
        state.step_count().value(),
        loss_before,
        loss_after
    );
}

fn tiny_training_state() -> CtResult<TransformerTrainingState> {
    let model_dimension = ModelDimension::new(2)?;
    let block = MaskedMultiHeadTransformerBlock::new(
        vec![
            SelfAttentionHead::new(
                HiddenToQuery::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
                HiddenToKey::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
                HiddenToValue::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
            )?,
            SelfAttentionHead::new(
                HiddenToQuery::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
                HiddenToKey::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
                HiddenToValue::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
            )?,
        ],
        AttentionOutputProjection::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?,
        LayerNormalization::new(LayerNormParameters::identity(model_dimension)),
        PositionWiseFeedForward::new(
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![0.0, 0.0],
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![0.0, 0.0],
        )?,
        LayerNormalization::new(LayerNormParameters::identity(model_dimension)),
    )?;
    let parameters = TinyTransformerParameters::new(
        PositionalEncoding::new(vec![vec![0.1, 0.0], vec![0.0, 0.1]])?,
        block,
        TransformerReadout::new(
            vec![vec![1.0, 0.0, 0.5], vec![0.0, 1.0, -0.5]],
            vec![0.0, 0.0, 0.0],
        )?,
    )?;

    Ok(TransformerTrainingState::new(
        parameters,
        LearningRate::new(0.1)?,
    ))
}

Run the full attention example with:

cargo run --example 06_attention_scores

If you only want to inspect the training-state update shape, run the smaller companion example:

cargo run --example 07_transformer_training_state

You should see two query positions and three key positions. Query and key vectors first produce score rows. The mask removes one illegal score position. Softmax then normalizes each score row independently, and the resulting weights mix value vectors:

QuerySequence x KeySequence -> AttentionScores
AttentionScores x AttentionMask -> AttentionScores
AttentionScores -> AttentionWeights
AttentionWeights x ValueSequence -> AttentionOutput
AttentionHeadOutputs -> MultiHeadOutput
MultiHeadOutput -> ProjectedAttentionOutput
HiddenSequence x ProjectedAttentionOutput -> HiddenSequence
HiddenSequence -> HiddenSequence (normalization, feed-forward, positional encoding, and whole blocks)

That is the bridge from the current softmax chapter. The current Distribution answers:

which next token is likely?

Attention weights answer:

which source positions should this position read from?

Both are probability-like objects. They differ in what their support means.

ML Concept

Attention computes:

scores = QK^T / sqrt(d)
weights = softmax(scores)
output = weights V

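This is the same softmax, applied to each row of position-to-position interaction scores. Here d is the query/key dimension of one head; dividing by sqrt(d) keeps the score scale from growing with the head width.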
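The formula can be sketched in plain Rust over `Vec<Vec<f32>>` rows. This is a minimal illustration, not the crate's implementation. `softmax_row`, `dot`, and `scaled_dot_product_attention` are hypothetical helpers. The mask step is left out, and the inputs are assumed non-empty, with equal query/key widths and one value row per key row.

```rust
// Hypothetical sketch of unmasked scaled dot-product attention on plain rows.
// Not the crate's API. Assumes non-empty inputs and consistent widths.

/// Numerically stable softmax over one score row.
fn softmax_row(row: &[f32]) -> Vec<f32> {
    let max = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = row.iter().map(|&s| (s - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// scores = Q K^T / sqrt(d), weights = softmax(scores) per row, output = weights V.
fn scaled_dot_product_attention(
    queries: &[Vec<f32>],
    keys: &[Vec<f32>],
    values: &[Vec<f32>],
) -> Vec<Vec<f32>> {
    let scale = (queries[0].len() as f32).sqrt();
    let value_width = values[0].len();
    queries
        .iter()
        .map(|q| {
            // One score per key position for this query position.
            let scores: Vec<f32> = keys.iter().map(|k| dot(q, k) / scale).collect();
            let weights = softmax_row(&scores);
            // Mix the value rows with this query's weights.
            let mut mixed = vec![0.0_f32; value_width];
            for (w, v) in weights.iter().zip(values) {
                for (m, x) in mixed.iter_mut().zip(v) {
                    *m += w * x;
                }
            }
            mixed
        })
        .collect()
}

fn main() {
    // Two query positions reading from three key positions.
    let queries: Vec<Vec<f32>> = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
    let keys: Vec<Vec<f32>> = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];
    let values: Vec<Vec<f32>> = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![0.5, 0.5]];
    let output = scaled_dot_product_attention(&queries, &keys, &values);
    for (position, row) in output.iter().enumerate() {
        println!("query {position} attention output {row:?}");
    }
}
```

Each output row has the value width, not the key count. The key count only determines how many weights each query row spreads across.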

Category Theory Concept

The attention block is a composition of typed maps with a product input:

(Q, K, V) -> scores -> weights -> mixed values

Design contract:

Attention should have a positive test showing that valid shapes compose and a negative test showing that mismatched head dimensions, mask shapes, or value lengths are rejected at construction or composition time.

Step 4: Multi-Head Concatenation

The current problem:

One attention head sees one interaction pattern. Multiple heads let the model carry several patterns in parallel. Their outputs must be recombined without losing shape information.

Rust Syntax

The recombination boundary is:

AttentionHeadOutputs -> MultiHeadOutput

HeadCount rejects zero. AttentionHeadOutputs rejects an empty collection, sequence-length mismatches, and head-dimension mismatches. MultiHeadOutput records:

sequence length
head count
head dimension
model dimension

This boundary is not the whole block by itself. It is the place where separate head outputs become one wider object before the output projection.
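Here is a minimal sketch of the recombination boundary. It assumes each head output is stored as `sequence_len` rows of `head_dim` features. `ConcatError` and `concat_heads` are hypothetical names that mirror the checks listed above. They are not the crate's types.

```rust
// Hypothetical sketch of AttentionHeadOutputs -> MultiHeadOutput on plain rows.

#[derive(Debug, PartialEq)]
enum ConcatError {
    NoHeads,
    SequenceLengthMismatch { head: usize },
    HeadDimensionMismatch { head: usize },
}

/// Each head is `sequence_len` rows of `head_dim` features.
/// Returns rows of width head_count * head_dim.
fn concat_heads(heads: &[Vec<Vec<f32>>]) -> Result<Vec<Vec<f32>>, ConcatError> {
    let first = heads.first().ok_or(ConcatError::NoHeads)?;
    let seq_len = first.len();
    let head_dim = first.first().map_or(0, |row| row.len());
    // Validate every head before building any output row.
    for (h, head) in heads.iter().enumerate() {
        if head.len() != seq_len {
            return Err(ConcatError::SequenceLengthMismatch { head: h });
        }
        if head.iter().any(|row| row.len() != head_dim) {
            return Err(ConcatError::HeadDimensionMismatch { head: h });
        }
    }
    // Place each head's row side by side; nothing is averaged.
    Ok((0..seq_len)
        .map(|pos| {
            heads
                .iter()
                .flat_map(|head| head[pos].iter().copied())
                .collect::<Vec<f32>>()
        })
        .collect())
}

fn main() {
    let head_0: Vec<Vec<f32>> = vec![vec![2.0, 20.0], vec![2.2033, 22.0334]];
    let head_1: Vec<Vec<f32>> = vec![vec![10.0, 1.0], vec![20.0, 2.0]];
    let rows = concat_heads(&[head_0, head_1]).expect("head shapes agree");
    for (pos, row) in rows.iter().enumerate() {
        println!("query_{pos} multi-head row = {row:?}");
    }
}
```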

Worked Example: Concatenate Heads, Then Project

Multi-head attention adds one shape calculation that readers should be able to do without a framework:

model_dimension = head_count * head_dimension

In the runnable attention example, the first head has two output features per query position:

head_0 query_0 = [2.0, 20.0]
head_0 query_1 = [2.2033, 22.0334]

The example then adds a second head with the same sequence length and head dimension:

head_1 query_0 = [10.0, 1.0]
head_1 query_1 = [20.0, 2.0]

Concatenation does not average the heads. It places their feature rows side by side:

query_0 multi-head row = [2.0, 20.0, 10.0, 1.0]
query_1 multi-head row = [2.2033, 22.0334, 20.0, 2.0]

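The recorded MultiHeadOutput shape is:

2 heads x 2 features = model dimension 4

Here, model dimension means the concatenated width that MultiHeadOutput records. It is not yet the hidden-sequence width, which is two in this tiny example.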

The next boundary is the learned output projection:

MultiHeadOutput -> ProjectedAttentionOutput

For the first query row, the tiny projection in the example uses:

input row:
[2.0, 20.0, 10.0, 1.0]

projection rows:
[1.0, 0.0]
[0.0, 0.1]
[0.5, 0.0]
[0.0, 1.0]

The projected row is:

first output feature  = 2.0 * 1.0 + 20.0 * 0.0 + 10.0 * 0.5 + 1.0 * 0.0 = 7.0
second output feature = 2.0 * 0.0 + 20.0 * 0.1 + 10.0 * 0.0 + 1.0 * 1.0 = 3.0

projected query_0 = [7.0, 3.0]

This is the reason the repository keeps two separate boundaries:

AttentionHeadOutputs -> MultiHeadOutput
MultiHeadOutput -> ProjectedAttentionOutput

The first boundary proves that head outputs can be concatenated. The second boundary proves that the concatenated width matches the projection input width. If either relationship fails, the model should fail at the boundary, before a later residual connection receives the wrong shape.
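A few lines of Rust can check the projection arithmetic above. As in the hand calculation, this sketch assumes weight rows are indexed by input feature, so `output[j] = bias[j] + sum_i input[i] * weights[i][j]`. `project_row` is a hypothetical helper, not the crate's `AttentionOutputProjection`.

```rust
// Hypothetical helper reproducing the worked projection by hand.
// weights[i] is the row for input feature i; its length is the output width.
fn project_row(input: &[f32], weights: &[Vec<f32>], bias: &[f32]) -> Vec<f32> {
    let mut out = bias.to_vec();
    for (x, weight_row) in input.iter().zip(weights) {
        for (o, w) in out.iter_mut().zip(weight_row) {
            *o += x * w;
        }
    }
    out
}

fn main() {
    let weights: Vec<Vec<f32>> = vec![
        vec![1.0, 0.0],
        vec![0.0, 0.1],
        vec![0.5, 0.0],
        vec![0.0, 1.0],
    ];
    let bias = [0.0, 0.0];
    let query_0 = [2.0, 20.0, 10.0, 1.0];
    println!("projected query_0 = {:?}", project_row(&query_0, &weights, &bias));
}
```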

ML Concept

Each head performs attention separately.

The outputs are concatenated, then an output projection maps the combined row back into the hidden-state width. This repository now implements that projection as a separate typed boundary.

Category Theory Concept

This is parallel composition followed by recombination:

head_1 x head_2 x ... x head_n -> MultiHeadOutput

Design contract:

HeadCount, HeadDimension, and ModelDimension are not bare usize values. The arithmetic relationship between them is part of the architecture: if there are two heads of width two, the concatenated model dimension is four.
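One way to make that arithmetic part of the types is to derive the concatenated width rather than pass it separately. The structs below are hypothetical stand-ins that reuse the roadmap's names. The crate's real `HeadCount`, `HeadDimension`, and `ModelDimension` may have different fields and constructors. Rejecting a zero head dimension is an assumption here.

```rust
// Hypothetical stand-ins for the roadmap's dimension types.
// Field layout and constructors are assumptions, not the crate's API.

#[derive(Debug, Clone, Copy, PartialEq)]
struct HeadCount(usize);

#[derive(Debug, Clone, Copy, PartialEq)]
struct HeadDimension(usize);

#[derive(Debug, Clone, Copy, PartialEq)]
struct ModelDimension(usize);

impl HeadCount {
    fn new(count: usize) -> Result<Self, String> {
        if count == 0 {
            return Err("head count must be at least 1".to_string());
        }
        Ok(Self(count))
    }
}

impl HeadDimension {
    fn new(width: usize) -> Result<Self, String> {
        if width == 0 {
            return Err("head dimension must be at least 1".to_string());
        }
        Ok(Self(width))
    }
}

/// The concatenated width is computed from the other two values,
/// so a caller cannot supply a width that disagrees with them.
fn concatenated_dimension(heads: HeadCount, head_dim: HeadDimension) -> ModelDimension {
    ModelDimension(heads.0 * head_dim.0)
}

fn main() -> Result<(), String> {
    let width = concatenated_dimension(HeadCount::new(2)?, HeadDimension::new(2)?);
    println!("2 heads x 2 features = model dimension {}", width.0);
    Ok(())
}
```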

Step 5: Output Projection

The current problem:

Concatenated heads are wider than a single head. A later Transformer block expects a coherent hidden width again.

Rust Syntax

The current code models the projection as:

MultiHeadOutput -> ProjectedAttentionOutput

AttentionOutputProjection validates:

non-empty weight rows
non-empty bias
finite weight and bias values
every weight row having the output dimension (the bias length)
input dimension matching MultiHeadOutput model dimension

That shape follows the multi-head attention reference path: heads are concatenated, then another learned linear projection produces the output sequence.

ML Concept

The projection is a learned linear map after concatenation. It lets the model mix features across heads and return to the width expected by the surrounding block.

Category Theory Concept

This is another typed morphism:

MultiHeadOutput -> ProjectedAttentionOutput

Design contract:

The projection should fail before multiplication if the concatenated head width does not match the projection’s input dimension. That is a boundary invariant, not an indexing accident.
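Here is a minimal sketch of a projection that validates its parameters at construction and checks input width before it multiplies anything. `Projection` and `ProjectionError` are hypothetical. The checks follow the validation list above. As in the worked example, the weight layout is assumed to be one row per input feature and one column per output feature.

```rust
// Hypothetical validated output projection; not the crate's AttentionOutputProjection.

#[derive(Debug, PartialEq)]
enum ProjectionError {
    Empty,
    NonFinite,
    RowWidth { row: usize },
    InputWidth { expected: usize, found: usize },
}

struct Projection {
    weights: Vec<Vec<f32>>, // one row per input feature
    bias: Vec<f32>,         // one entry per output feature
}

impl Projection {
    fn new(weights: Vec<Vec<f32>>, bias: Vec<f32>) -> Result<Self, ProjectionError> {
        if weights.is_empty() || bias.is_empty() {
            return Err(ProjectionError::Empty);
        }
        if !weights.iter().flatten().chain(&bias).all(|v| v.is_finite()) {
            return Err(ProjectionError::NonFinite);
        }
        // Every weight row must produce one value per output feature.
        if let Some(row) = weights.iter().position(|w| w.len() != bias.len()) {
            return Err(ProjectionError::RowWidth { row });
        }
        Ok(Self { weights, bias })
    }

    fn apply(&self, rows: &[Vec<f32>]) -> Result<Vec<Vec<f32>>, ProjectionError> {
        // Reject a width mismatch before any multiplication happens.
        let expected = self.weights.len();
        if let Some(row) = rows.iter().find(|r| r.len() != expected) {
            return Err(ProjectionError::InputWidth { expected, found: row.len() });
        }
        Ok(rows
            .iter()
            .map(|row| {
                let mut out = self.bias.clone();
                for (x, weight_row) in row.iter().zip(&self.weights) {
                    for (o, w) in out.iter_mut().zip(weight_row) {
                        *o += x * w;
                    }
                }
                out
            })
            .collect())
    }
}

fn main() -> Result<(), ProjectionError> {
    let projection = Projection::new(
        vec![vec![1.0, 0.0], vec![0.0, 0.1], vec![0.5, 0.0], vec![0.0, 1.0]],
        vec![0.0, 0.0],
    )?;
    // A four-feature concatenated row is accepted.
    println!("{:?}", projection.apply(&[vec![2.0, 20.0, 10.0, 1.0]])?);
    // A two-feature row is rejected before multiplication.
    println!("{:?}", projection.apply(&[vec![0.5, 0.5]]));
    Ok(())
}
```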

Step 6: Residual Addition

The current problem:

Transformer sublayers need to add their output back to the hidden sequence they received.

Rust Syntax

The current code models residual addition as:

HiddenSequence x ProjectedAttentionOutput -> HiddenSequence

ResidualConnection rejects sequence-length mismatches and model-dimension mismatches before adding rows. This follows the Transformer requirement that a sublayer output must have the same dimension as its input for residual addition to be feasible.

Worked Example: Residual Addition Needs The Same Shape

The previous worked example ended with a projected attention row:

projected query_0 = [7.0, 3.0]

Residual addition can only happen because that projected row has the same model dimension as the hidden row it will be added to:

hidden query_0    = [0.5, 0.5]
projected query_0 = [7.0, 3.0]

residual query_0  = [7.5, 3.5]

The second row follows the same rule:

hidden query_1    = [1.0, 1.0]
projected query_1 = [12.2033, 4.2033]

residual query_1  = [13.2033, 5.2033]

The shape is preserved:

HiddenSequence:
2 positions x model dimension 2

ProjectedAttentionOutput:
2 positions x model dimension 2

Residual HiddenSequence:
2 positions x model dimension 2

This is why the output projection matters. Multi-head concatenation produced a four-feature row. Residual addition needs a two-feature row because the hidden sequence has model dimension two in this tiny example. The projection is the bridge that makes the residual boundary legal.

The invalid shortcut would be:

HiddenSequence x MultiHeadOutput -> HiddenSequence

For the example above, that would try to add a two-feature hidden row to a four-feature concatenated row. The repository avoids that by making the legal path explicit:

MultiHeadOutput -> ProjectedAttentionOutput
HiddenSequence x ProjectedAttentionOutput -> HiddenSequence
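The following sketch performs residual addition with the shape checks done first. `residual_add` and `ResidualError` are hypothetical names. Plain `Vec<Vec<f32>>` rows stand in for `HiddenSequence` and `ProjectedAttentionOutput`.

```rust
// Hypothetical residual boundary on plain rows.

#[derive(Debug, PartialEq)]
enum ResidualError {
    SequenceLength { hidden: usize, sublayer: usize },
    ModelDimension { hidden: usize, sublayer: usize },
}

fn residual_add(
    hidden: &[Vec<f32>],
    sublayer: &[Vec<f32>],
) -> Result<Vec<Vec<f32>>, ResidualError> {
    if hidden.len() != sublayer.len() {
        return Err(ResidualError::SequenceLength { hidden: hidden.len(), sublayer: sublayer.len() });
    }
    // Check every row before adding any of them.
    for (h, s) in hidden.iter().zip(sublayer) {
        if h.len() != s.len() {
            return Err(ResidualError::ModelDimension { hidden: h.len(), sublayer: s.len() });
        }
    }
    Ok(hidden
        .iter()
        .zip(sublayer)
        .map(|(h, s)| h.iter().zip(s).map(|(a, b)| a + b).collect::<Vec<f32>>())
        .collect())
}

fn main() {
    let hidden: Vec<Vec<f32>> = vec![vec![0.5, 0.5], vec![1.0, 1.0]];
    let projected: Vec<Vec<f32>> = vec![vec![7.0, 3.0], vec![12.2033, 4.2033]];
    println!("{:?}", residual_add(&hidden, &projected));

    // The invalid shortcut: adding the four-feature concatenated rows directly.
    let concatenated: Vec<Vec<f32>> =
        vec![vec![2.0, 20.0, 10.0, 1.0], vec![2.2033, 22.0334, 20.0, 2.0]];
    println!("{:?}", residual_add(&hidden, &concatenated));
}
```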

ML Concept

Residual addition preserves the hidden sequence shape while allowing a sublayer to contribute a learned change:

hidden + sublayer_output

The repository currently implements the addition boundary, the layer normalization boundary, the position-wise feed-forward boundary, and compact single-head and multi-head block boundaries.

Category Theory Concept

The residual boundary consumes a product object and returns the same hidden sequence object:

HiddenSequence x ProjectedAttentionOutput -> HiddenSequence

For a fixed block value, the enclosing unmasked block is still an endomorphism:

HiddenSequence -> HiddenSequence

Design contract:

Residual addition should fail before addition if either the sequence length or model dimension differs. Without that check, a later Transformer block would silently mix incompatible objects.

Step 7: Layer Normalization

The current problem:

After residual addition, each sequence position should be normalized across its feature dimension while preserving the hidden sequence shape.

Rust Syntax

The current code models layer normalization as:

HiddenSequence -> HiddenSequence

LayerNormParameters validates non-empty scale and shift vectors, matching parameter lengths, finite parameter values, and positive finite epsilon. LayerNormalization rejects hidden sequences whose model dimension does not match its parameter dimension.

ML Concept

Layer normalization recenters and rescales each hidden vector across its feature dimension. It is batch-size independent and preserves the sequence shape:

same positions
same model dimension
new normalized values

The Layer Normalization paper is useful here because it frames the operation around statistics inside a single training case rather than statistics collected across a batch. In this roadmap’s tiny Rust object, that means each row can be normalized while the public HiddenSequence boundary remains the same.
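Here is a sketch of row-wise layer normalization. Each row is centered by its own mean and divided by `sqrt(variance + eps)`, then scaled and shifted per feature. The function names are hypothetical, and the variance is the population variance (divided by the feature count). `eps = 1e-5` is an assumed value. With identity scale and zero shift, it gives values close to the normalized rows in the worked example later in this chapter.

```rust
// Hypothetical row-wise layer normalization on plain rows.

/// (x - mean) / sqrt(variance + eps) * scale + shift, computed within one row.
fn layer_norm_row(row: &[f32], scale: &[f32], shift: &[f32], eps: f32) -> Vec<f32> {
    let n = row.len() as f32;
    let mean = row.iter().sum::<f32>() / n;
    let variance = row.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / n;
    let inv_std = 1.0 / (variance + eps).sqrt();
    row.iter()
        .zip(scale)
        .zip(shift)
        .map(|((x, g), b)| (x - mean) * inv_std * g + b)
        .collect()
}

/// Normalizes each position independently; the shape is unchanged.
fn layer_norm(rows: &[Vec<f32>], scale: &[f32], shift: &[f32], eps: f32) -> Vec<Vec<f32>> {
    rows.iter().map(|row| layer_norm_row(row, scale, shift, eps)).collect()
}

fn main() {
    let residual: Vec<Vec<f32>> = vec![vec![7.5, 3.5], vec![13.2033, 5.2033]];
    // Identity parameters: scale 1, shift 0. eps = 1e-5 is an assumption.
    let normalized = layer_norm(&residual, &[1.0, 1.0], &[0.0, 0.0], 1e-5);
    for (position, row) in normalized.iter().enumerate() {
        println!("query {position} normalized row {row:?}");
    }
}
```

No statistic is shared across rows or across a batch, which is the property the paper emphasizes.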

Category Theory Concept

The normalization boundary is an endomorphism:

HiddenSequence -> HiddenSequence

Read that as a forward call for one fixed LayerNormalization value. Its scale, shift, and epsilon are already stored in the layer. If those values are being learned, the changing object is the larger training state, not the hidden sequence alone.

Design contract:

Normalization should not change the object type. If a later block expects a hidden sequence, the normalized result should still be a hidden sequence.

Step 8: Position-Wise Feed-Forward

The current problem:

After attention and normalization, each sequence position needs a learned non-linear transformation that preserves the hidden sequence shape.

Rust Syntax

The current code models this as:

HiddenSequence -> HiddenSequence

PositionWiseFeedForward validates two linear layers:

model dimension -> feed-forward hidden dimension -> model dimension

It also checks finite weights, finite biases, and compatible intermediate dimensions before any row is projected.

ML Concept

A position-wise feed-forward network applies the same two-layer non-linear map to each hidden vector independently:

hidden row -> expanded row -> activated row -> hidden row

It changes feature values, not the sequence length or public model dimension.

That is why the public boundary stays:

HiddenSequence -> HiddenSequence

Category Theory Concept

The feed-forward sublayer is another endomorphism:

HiddenSequence -> HiddenSequence

Read that the same way: for this fixed PositionWiseFeedForward value, the call receives a hidden sequence and returns a hidden sequence. The layer’s weights and biases are context already stored inside the Rust object. Training those weights belongs to a state update, not to this forward boundary.

It is not the whole Transformer block. It is the next shape-preserving sublayer that a later block can compose.

Design contract:

The second linear layer must return to the original model dimension. Otherwise the next residual or block boundary would receive the wrong object.
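The two-layer map can be sketched with plain rows. `linear` and `feed_forward` are hypothetical helpers. The weight layout (one row per input feature) and the ReLU activation match the tiny example in this chapter. Every dimension is checked before any row is projected.

```rust
// Hypothetical position-wise feed-forward sketch on plain rows.

/// out[j] = bias[j] + sum_i row[i] * weights[i][j]
fn linear(row: &[f32], weights: &[Vec<f32>], bias: &[f32]) -> Vec<f32> {
    let mut out = bias.to_vec();
    for (x, weight_row) in row.iter().zip(weights) {
        for (o, w) in out.iter_mut().zip(weight_row) {
            *o += x * w;
        }
    }
    out
}

fn feed_forward(
    rows: &[Vec<f32>],
    w1: &[Vec<f32>],
    b1: &[f32],
    w2: &[Vec<f32>],
    b2: &[f32],
) -> Result<Vec<Vec<f32>>, String> {
    let model_dim = w1.len();
    let hidden_dim = b1.len();
    // model dimension -> feed-forward hidden dimension -> model dimension, checked up front.
    let first_ok = w1.iter().all(|r| r.len() == hidden_dim);
    let second_ok = w2.len() == hidden_dim
        && w2.iter().all(|r| r.len() == model_dim)
        && b2.len() == model_dim;
    if !first_ok || !second_ok {
        return Err("incompatible feed-forward dimensions".to_string());
    }
    if rows.iter().any(|r| r.len() != model_dim) {
        return Err("hidden row width does not match model dimension".to_string());
    }
    Ok(rows
        .iter()
        .map(|row| {
            let expanded = linear(row, w1, b1);
            // ReLU clips negative features to zero.
            let activated: Vec<f32> = expanded.iter().map(|&v| v.max(0.0)).collect();
            linear(&activated, w2, b2)
        })
        .collect())
}

fn main() -> Result<(), String> {
    let normalized: Vec<Vec<f32>> =
        vec![vec![0.9999988, -0.9999988], vec![0.99999976, -0.99999976]];
    let w1: Vec<Vec<f32>> = vec![vec![1.0, -1.0, 0.5], vec![0.0, 1.0, 0.5]];
    let w2: Vec<Vec<f32>> = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![0.5, 0.5]];
    let out = feed_forward(&normalized, &w1, &[0.0; 3], &w2, &[0.0; 2])?;
    for (position, row) in out.iter().enumerate() {
        println!("query {position} feed-forward row {row:?}");
    }
    Ok(())
}
```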

Worked Example: Values Change, Shape Stays HiddenSequence

The residual example produced this hidden sequence:

residual query_0 = [7.5, 3.5]
residual query_1 = [13.2033, 5.2033]

Layer normalization changes the row values while keeping the public object the same:

normalized query_0 = [0.9999988, -0.9999988]
normalized query_1 = [0.99999976, -0.99999976]

The sequence still has two positions and model dimension two:

Residual HiddenSequence
2 positions x model dimension 2

LayerNormalization
HiddenSequence -> HiddenSequence

Normalized HiddenSequence
2 positions x model dimension 2

The feed-forward sublayer then applies the same two-layer map to each position. In this tiny example, the ReLU step clips the negative feature before the second linear layer returns to the public model dimension:

feed-forward query_0 = [0.9999988, 0.0]
feed-forward query_1 = [0.99999976, 0.0]

Again, the object has not changed:

HiddenSequence
2 positions x model dimension 2

PositionWiseFeedForward
HiddenSequence -> HiddenSequence

HiddenSequence
2 positions x model dimension 2

The values changed twice. The sequence length and model dimension did not.

That distinction matters because normalization and feed-forward computation are not new sequence objects in this roadmap. They are shape-preserving maps over hidden rows:

Residual HiddenSequence -> LayerNormalization -> HiddenSequence
HiddenSequence -> PositionWiseFeedForward -> HiddenSequence

The invalid mental shortcut is:

normalization creates a special normalized object
feed-forward creates a special feed-forward object

The useful engineering view is stricter:

both stages return HiddenSequence so the next block boundary can compose

Step 9: Positional Encoding

The current problem:

Self-attention sees a set of hidden rows. A sequence model also needs to know where each row sits in the sequence.

Rust Syntax

The current code models position as another shape-preserving morphism:

PositionalEncoding : HiddenSequence -> HiddenSequence

The encoding table validates non-empty finite rows and a fixed model dimension. Applying it rejects hidden sequences that are too long for the table or have the wrong model width.

ML Concept

Position information is added to hidden vectors before attention so identical tokens in different positions can become distinguishable to later transformations.

Category Theory Concept

For one fixed positional-encoding table, the public shape is still an endomorphism:

HiddenSequence -> HiddenSequence

Read that as a forward call with the encoding table already selected. If a future chapter learns, swaps, or rebuilds the position table, that changing context must be named separately instead of being hidden inside the HiddenSequence -> HiddenSequence arrow.

Design contract:

Adding position should change values, not the hidden sequence object. If the position table has the wrong width or not enough rows, composition should fail before attention starts.
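Here is a sketch of an additive position table with the checks described above. `PositionTable` and `PositionError` are hypothetical names. The table is assumed to be added row by row to the hidden sequence, using the same values as the example's `PositionalEncoding::new(vec![vec![0.1, 0.0], vec![0.0, 0.1]])`.

```rust
// Hypothetical additive position table on plain rows.

#[derive(Debug, PartialEq)]
enum PositionError {
    EmptyTable,
    InvalidTable,
    TooLong { table_rows: usize, sequence_len: usize },
    WidthMismatch { table_width: usize, row_width: usize },
}

struct PositionTable {
    rows: Vec<Vec<f32>>,
}

impl PositionTable {
    fn new(rows: Vec<Vec<f32>>) -> Result<Self, PositionError> {
        let width = rows.first().map_or(0, |r| r.len());
        if width == 0 {
            return Err(PositionError::EmptyTable);
        }
        // Every row must be finite and share one model dimension.
        let valid = rows
            .iter()
            .all(|r| r.len() == width && r.iter().all(|v| v.is_finite()));
        if !valid {
            return Err(PositionError::InvalidTable);
        }
        Ok(Self { rows })
    }

    fn apply(&self, hidden: &[Vec<f32>]) -> Result<Vec<Vec<f32>>, PositionError> {
        if hidden.len() > self.rows.len() {
            return Err(PositionError::TooLong { table_rows: self.rows.len(), sequence_len: hidden.len() });
        }
        let table_width = self.rows[0].len();
        if let Some(row) = hidden.iter().find(|r| r.len() != table_width) {
            return Err(PositionError::WidthMismatch { table_width, row_width: row.len() });
        }
        // Add position p's encoding to hidden row p; the shape is unchanged.
        Ok(hidden
            .iter()
            .zip(&self.rows)
            .map(|(h, p)| h.iter().zip(p).map(|(a, b)| a + b).collect::<Vec<f32>>())
            .collect())
    }
}

fn main() -> Result<(), PositionError> {
    let table = PositionTable::new(vec![vec![0.1, 0.0], vec![0.0, 0.1]])?;
    let hidden: Vec<Vec<f32>> = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
    println!("{:?}", table.apply(&hidden)?);
    Ok(())
}
```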

Step 10: Single-Head And Multi-Head Blocks

The current problem:

A block should compose attention, residual addition, normalization, and feed-forward computation while keeping the public shape simple.

Rust Syntax

Both current block sketches have shape:

HiddenSequence -> HiddenSequence

They are:

SingleHeadTransformerBlock
MultiHeadTransformerBlock

The single-head sketch proves the compact block boundary. The multi-head sketch extends that boundary by collecting several SelfAttentionHead values, concatenating their outputs, and applying the output projection.

ML Concept

Transformer blocks combine:

attention
residual connection
normalization
feed-forward network

The block output has the same shape as the input.

This is where the current training chapter becomes useful again. A block with shape HiddenSequence -> HiddenSequence can be stacked for the same reason a training step with shape Parameters -> Parameters can be repeated: output and input live in the same object.

Category Theory Concept

For one fixed single-head or multi-head block value, this is another endomorphism:

HiddenSequence -> HiddenSequence

Stacking layers is repeated endomorphism application.
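Repeated application can be sketched as a fold over fixed block values. The `HiddenEndomorphism` trait and the two toy blocks are hypothetical illustrations. A real block would run attention, residual addition, normalization, and the feed-forward sublayer inside `forward`:

```rust
/// Hypothetical alias: one hidden row per sequence position.
type HiddenSequence = Vec<Vec<f64>>;

/// A fixed block value exposing the forward boundary HiddenSequence -> HiddenSequence.
trait HiddenEndomorphism {
    fn forward(&self, hidden: HiddenSequence) -> HiddenSequence;
}

/// Toy block (not a Transformer block): adds a constant to every entry.
struct AddBias(f64);

impl HiddenEndomorphism for AddBias {
    fn forward(&self, hidden: HiddenSequence) -> HiddenSequence {
        hidden
            .into_iter()
            .map(|row| row.into_iter().map(|x| x + self.0).collect::<Vec<f64>>())
            .collect()
    }
}

/// Toy block: scales every entry.
struct Scale(f64);

impl HiddenEndomorphism for Scale {
    fn forward(&self, hidden: HiddenSequence) -> HiddenSequence {
        hidden
            .into_iter()
            .map(|row| row.into_iter().map(|x| x * self.0).collect::<Vec<f64>>())
            .collect()
    }
}

/// Stacking is a fold: each output lives in the input object, so the next
/// block can always accept it. An empty stack is the identity.
fn stack(blocks: &[Box<dyn HiddenEndomorphism>], input: HiddenSequence) -> HiddenSequence {
    blocks.iter().fold(input, |hidden, block| block.forward(hidden))
}
```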

If the block’s heads, projections, normalization parameters, or feed-forward weights are changing, the boundary is no longer only this forward call. The changing object is the larger training state:

TransformerTrainingState -> TransformerTrainingState

Design contract:

Internal complexity does not leak into every caller. The single-head and multi-head sketches contain several sublayers, but callers see one typed boundary.

For the multi-head sketch, the output-projection input dimension must equal:

head_count * value_head_dimension

That check is the difference between a typed block and a loose pile of matrix multiplications.
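A minimal sketch of that check, run before any block value exists. `MultiHeadShape`, `BlockShapeError`, and `validate_multi_head` are hypothetical names; the sketch also checks that the projection returns to the model width, which the residual path needs:

```rust
/// Hypothetical shape summary of a multi-head block; not the crate's type.
#[derive(Debug, Clone, Copy)]
pub struct MultiHeadShape {
    pub head_count: usize,
    pub value_head_dimension: usize,
    /// Rows of the output projection matrix.
    pub projection_input_dimension: usize,
    /// Columns of the output projection matrix.
    pub projection_output_dimension: usize,
}

#[derive(Debug, PartialEq)]
pub enum BlockShapeError {
    /// Concatenated head width does not match the projection input.
    ProjectionInput { expected: usize, found: usize },
    /// The projection does not return to the model width.
    ProjectionOutput { expected: usize, found: usize },
}

/// Reject the block if concatenated heads cannot feed the output projection,
/// or if the projection cannot rejoin the residual stream.
pub fn validate_multi_head(
    shape: MultiHeadShape,
    model_dimension: usize,
) -> Result<(), BlockShapeError> {
    let concatenated = shape.head_count * shape.value_head_dimension;
    if shape.projection_input_dimension != concatenated {
        return Err(BlockShapeError::ProjectionInput {
            expected: concatenated,
            found: shape.projection_input_dimension,
        });
    }
    if shape.projection_output_dimension != model_dimension {
        return Err(BlockShapeError::ProjectionOutput {
            expected: model_dimension,
            found: shape.projection_output_dimension,
        });
    }
    Ok(())
}
```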

Step 11: Masked Blocks

The current problem:

Some sequence positions should not attend to other positions. The block boundary needs a mask, not only the lower-level score operation.

Rust Syntax

The current code models the masked block as:

MaskedMultiHeadTransformerBlock : HiddenSequence x AttentionMask -> HiddenSequence

The mask is part of the input product. The block applies it after query-key scoring and before row-wise softmax for each head.

ML Concept

A mask controls which source positions each query position may use. The same shape-preserving block can now represent selective attention.

Category Theory Concept

This is a product-to-object morphism:

HiddenSequence x AttentionMask -> HiddenSequence

Design contract:

The mask must have the same query and key dimensions as the score table inside the block. If the shape does not match, the block fails before softmax.
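The order "mask after scoring, before softmax" can be sketched for one head as follows. This assumes a boolean mask where `true` means "this query may read this key"; adding negative infinity to masked scores is a common alternative encoding. `masked_softmax` and `MaskError` are hypothetical names:

```rust
#[derive(Debug, PartialEq)]
pub enum MaskError {
    /// Score table and mask disagree on (query rows, key columns), or one is ragged.
    ShapeMismatch { scores: Option<(usize, usize)>, mask: Option<(usize, usize)> },
    /// A query row allows no key, so softmax is undefined for it.
    FullyMaskedRow(usize),
}

/// (rows, columns) of a rectangular table; None if the rows are ragged.
fn rect_shape<T>(table: &[Vec<T>]) -> Option<(usize, usize)> {
    let columns = table.first().map_or(0, |row| row.len());
    if table.iter().all(|row| row.len() == columns) {
        Some((table.len(), columns))
    } else {
        None
    }
}

/// Fail on shape mismatch before any softmax; then apply a row-wise softmax
/// over allowed keys only.
pub fn masked_softmax(scores: &[Vec<f64>], mask: &[Vec<bool>]) -> Result<Vec<Vec<f64>>, MaskError> {
    let score_shape = rect_shape(scores);
    let mask_shape = rect_shape(mask);
    if score_shape.is_none() || score_shape != mask_shape {
        return Err(MaskError::ShapeMismatch { scores: score_shape, mask: mask_shape });
    }
    scores
        .iter()
        .zip(mask)
        .enumerate()
        .map(|(query, (row, allowed))| {
            // Largest allowed score, subtracted for a numerically stable exp().
            let max = row
                .iter()
                .zip(allowed)
                .filter(|pair| *pair.1)
                .map(|pair| *pair.0)
                .fold(f64::NEG_INFINITY, f64::max);
            if max == f64::NEG_INFINITY {
                return Err(MaskError::FullyMaskedRow(query));
            }
            // Masked keys contribute exactly zero weight.
            let exps: Vec<f64> = row
                .iter()
                .zip(allowed)
                .map(|(&s, &ok)| if ok { (s - max).exp() } else { 0.0 })
                .collect();
            let total: f64 = exps.iter().sum();
            Ok(exps.into_iter().map(|e| e / total).collect::<Vec<f64>>())
        })
        .collect()
}
```

With equal allowed scores and the middle key masked, the weight row is `[0.5, 0.0, 0.5]`, the same pattern as the `query 0 attends with` checkpoint line.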

Worked Example: Fixed Mask Versus Open Mask

The masked block is where stacking language can become imprecise. The unmasked multi-head block has the simple public shape:

MultiHeadTransformerBlock : HiddenSequence -> HiddenSequence

That boundary can compose directly with another boundary of the same shape. The masked block is different:

MaskedMultiHeadTransformerBlock : HiddenSequence x AttentionMask -> HiddenSequence

While the mask is still an open input, the boundary is not a unary endomorphism. It needs the hidden sequence and the mask. There are two precise ways to use it repeatedly:

| Use case | Boundary to name | Why this is precise |
|---|---|---|
| the caller supplies a mask for each block call | HiddenSequence x AttentionMask -> HiddenSequence | the mask remains visible as required context |
| one example fixes a specific mask before applying the block | HiddenSequence -> HiddenSequence under fixed mask context | the unary map is induced by a named fixed mask |

The second row is useful in a lesson or a single training example, but only if the fixed context is named. The mask did not disappear. It became part of the chosen environment for that run.

The invalid shortcut is:

MaskedMultiHeadTransformerBlock returns HiddenSequence, so it is automatically
an endomorphism.

The output type is not enough. Count the whole input object. The open masked boundary has product input. A fixed-mask view can be treated as a unary shape-preserving map only after the mask has been selected and kept stable for that call path.
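In Rust, the fixed-mask view can be sketched as a closure that captures one named mask. `ToyMaskedBlock` is a hypothetical stand-in that only averages the rows its mask allows; the real block also projects, applies softmax, normalizes, and runs the feed-forward sublayer:

```rust
/// Hypothetical aliases: mask[q][k] == true means query q may read key k.
type HiddenSequence = Vec<Vec<f64>>;
type AttentionMask = Vec<Vec<bool>>;

/// Toy stand-in for MaskedMultiHeadTransformerBlock: each output row is the
/// mean of the hidden rows its mask row allows. No projections, no softmax.
struct ToyMaskedBlock;

impl ToyMaskedBlock {
    /// Open boundary: HiddenSequence x AttentionMask -> HiddenSequence.
    fn forward(&self, hidden: &HiddenSequence, mask: &AttentionMask) -> HiddenSequence {
        let width = hidden.first().map_or(0, |row| row.len());
        mask.iter()
            .map(|allowed| {
                let mut out = vec![0.0_f64; width];
                let mut count = 0.0_f64;
                for (row, &ok) in hidden.iter().zip(allowed) {
                    if ok {
                        for (o, x) in out.iter_mut().zip(row) {
                            *o += x;
                        }
                        count += 1.0;
                    }
                }
                if count > 0.0 {
                    for o in out.iter_mut() {
                        *o /= count;
                    }
                }
                out
            })
            .collect()
    }

    /// Induced unary view: the caller selects and names one mask first.
    fn with_fixed_mask<'a>(
        &'a self,
        mask: &'a AttentionMask,
    ) -> impl Fn(&HiddenSequence) -> HiddenSequence + 'a {
        move |hidden: &HiddenSequence| self.forward(hidden, mask)
    }
}
```

The returned closure is the induced `HiddenSequence -> HiddenSequence` map. It can be applied repeatedly only because the mask was chosen first and is still named at the call site.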

Step 12: Structured State For Training And Evaluation

The current problem:

Once the model has attention parameters, evaluation and future training need one structured object instead of loose matrices passed through the code.

Rust Syntax

The earlier training chapter used:

Parameters

The roadmap code now adds the structured attention-side version:

TransformerReadout : HiddenSequence -> SequenceLogits
TinyTransformerParameters : HiddenSequence x AttentionMask -> SequenceLogits
TransformerTrainingState
TransformerReadoutTrainStep : TransformerTrainingState -> TransformerTrainingState
TransformerFeedForwardTrainStep : TransformerTrainingState -> TransformerTrainingState

TinyTransformerParameters owns:

positional encoding
masked multi-head block
sequence readout

TransformerTrainingState owns that parameter object plus LearningRate and StepCount. Its record_updated_parameters method records a new parameter object and increments the step count.
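A minimal sketch of that ownership, assuming `record_updated_parameters` consumes the old state and returns a new one. The by-value signature, the private fields, and the single readout field inside `TinyTransformerParameters` are assumptions of this sketch, not the crate's definitions:

```rust
/// Hypothetical stand-in; the real TinyTransformerParameters owns positional
/// encoding, a masked multi-head block, and a sequence readout.
#[derive(Debug, Clone, PartialEq)]
pub struct TinyTransformerParameters {
    pub readout_weights: Vec<Vec<f64>>,
}

#[derive(Debug, Clone, Copy, PartialEq)]
pub struct LearningRate(pub f64);

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StepCount(pub u64);

#[derive(Debug, Clone, PartialEq)]
pub struct TransformerTrainingState {
    parameters: TinyTransformerParameters,
    learning_rate: LearningRate,
    step: StepCount,
}

impl TransformerTrainingState {
    pub fn new(parameters: TinyTransformerParameters, learning_rate: LearningRate) -> Self {
        Self { parameters, learning_rate, step: StepCount(0) }
    }

    pub fn parameters(&self) -> &TinyTransformerParameters {
        &self.parameters
    }

    pub fn learning_rate(&self) -> LearningRate {
        self.learning_rate
    }

    pub fn step(&self) -> StepCount {
        self.step
    }

    /// Consumes the old state and returns a whole new one: new parameters,
    /// the same learning rate, and a step count advanced by one.
    pub fn record_updated_parameters(self, parameters: TinyTransformerParameters) -> Self {
        Self {
            parameters,
            learning_rate: self.learning_rate,
            step: StepCount(self.step.0 + 1),
        }
    }
}
```

Taking `self` by value means the caller cannot keep using the stale state after an update has been recorded.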

Counting the composed block step, the roadmap code has three supervised updates:

TransformerReadoutTrainStep : TransformerTrainingState -> TransformerTrainingState
TransformerFeedForwardTrainStep : TransformerTrainingState -> TransformerTrainingState
TransformerBlockTrainStep : TransformerTrainingState -> TransformerTrainingState

The readout step updates the sequence readout with a softmax cross-entropy gradient. The feed-forward step updates the position-wise feed-forward sublayer against hidden-sequence targets. The block step composes those ideas. It starts from sequence targets and computes readout gradients, carries the hidden gradient back through the final layer-normalization and residual boundary, and then through the attention-normalization and residual boundary. From that one supervised example it updates the readout, the feed-forward sublayer, the attention output projection, the query/key/value projections, and both layer-normalization scale/shift vectors. These are real gradient steps, but deliberately tiny ones.

ML Concept

A Transformer training loop still has the same outer structure:

predict
compute loss
backpropagate
update parameters

The internal model is richer, so the parameter object must be richer. A useful training state keeps three questions separate:

what parameters define the model?
what optimizer settings control the update?
which update step are we on?

The current code answers those questions structurally, then adds a small full-batch gradient through the current trainable block components. The readout update answers the first smaller question:

if the hidden sequence is fixed, can the vocabulary readout learn?

Yes. The update computes probabilities at each sequence position, subtracts one from the target class probability term, accumulates weight and bias gradients for the readout, applies the learning rate, and increments the step count.
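The readout update just described can be sketched in a few lines. This assumes plain `Vec<f64>` rows, a readout stored as `model_dimension x vocab_size` weights plus a bias, and a loss averaged over positions. `Readout` and `softmax` are hypothetical names, not the crate's `TransformerReadout`:

```rust
/// Max-shifted softmax over one logit row.
fn softmax(logits: &[f64]) -> Vec<f64> {
    let max = logits.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.iter().map(|z| (z - max).exp()).collect();
    let total: f64 = exps.iter().sum();
    exps.into_iter().map(|e| e / total).collect()
}

/// Hypothetical readout: weights[j][k] maps hidden feature j to token k.
pub struct Readout {
    pub weights: Vec<Vec<f64>>, // model_dimension x vocab_size
    pub bias: Vec<f64>,         // vocab_size
}

impl Readout {
    fn logits(&self, hidden_row: &[f64]) -> Vec<f64> {
        (0..self.bias.len())
            .map(|k| self.bias[k] + hidden_row.iter().zip(&self.weights).map(|(h, w)| h * w[k]).sum::<f64>())
            .collect()
    }

    /// Mean cross-entropy over sequence positions.
    pub fn loss(&self, hidden: &[Vec<f64>], targets: &[usize]) -> f64 {
        let total: f64 = hidden
            .iter()
            .zip(targets)
            .map(|(row, &t)| -softmax(&self.logits(row))[t].ln())
            .sum();
        total / hidden.len() as f64
    }

    /// One full-batch gradient step on weights and bias.
    pub fn train_step(&mut self, hidden: &[Vec<f64>], targets: &[usize], learning_rate: f64) {
        let n = hidden.len() as f64;
        let vocab = self.bias.len();
        let mut grad_w = vec![vec![0.0_f64; vocab]; self.weights.len()];
        let mut grad_b = vec![0.0_f64; vocab];
        for (row, &t) in hidden.iter().zip(targets) {
            // dLoss/dlogits = probabilities - one_hot(target)
            let mut delta = softmax(&self.logits(row));
            delta[t] -= 1.0;
            for k in 0..vocab {
                grad_b[k] += delta[k] / n;
                for (j, h) in row.iter().enumerate() {
                    grad_w[j][k] += h * delta[k] / n;
                }
            }
        }
        for (w_row, g_row) in self.weights.iter_mut().zip(&grad_w) {
            for (w, g) in w_row.iter_mut().zip(g_row) {
                *w -= learning_rate * g;
            }
        }
        for (b, g) in self.bias.iter_mut().zip(&grad_b) {
            *b -= learning_rate * g;
        }
    }
}
```

With all-zero parameters every token is equally likely, so the initial loss is `ln(vocab_size)`; one step with a small learning rate lowers it.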

The local feed-forward update asks a different small question:

if the attention output is treated as fixed, can the feed-forward sublayer
learn a hidden-sequence target?

It computes a squared-error gradient through the two feed-forward linear layers and the ReLU between them. That teaches a real block-internal update without pretending to use token targets.
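A minimal sketch of that local update for `y = W2 · relu(W1 · x + b1) + b2`, applied row by row. The `0.5 * squared error` averaged over rows, the zero gradient at the ReLU kink, and the names `FeedForward` and `descend` are assumptions of this sketch:

```rust
/// Hypothetical position-wise feed-forward sublayer.
pub struct FeedForward {
    pub w1: Vec<Vec<f64>>, // model_dimension x hidden_dimension
    pub b1: Vec<f64>,      // hidden_dimension
    pub w2: Vec<Vec<f64>>, // hidden_dimension x model_dimension
    pub b2: Vec<f64>,      // model_dimension
}

/// Subtract learning_rate * gradient from every matrix entry.
fn descend(params: &mut [Vec<f64>], grads: &[Vec<f64>], learning_rate: f64) {
    for (p_row, g_row) in params.iter_mut().zip(grads) {
        for (p, g) in p_row.iter_mut().zip(g_row) {
            *p -= learning_rate * g;
        }
    }
}

impl FeedForward {
    /// Returns (pre-activation z, activation a, output y) for one row.
    fn forward_row(&self, x: &[f64]) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
        let z: Vec<f64> = (0..self.b1.len())
            .map(|i| self.b1[i] + x.iter().zip(&self.w1).map(|(xj, w)| xj * w[i]).sum::<f64>())
            .collect();
        let a: Vec<f64> = z.iter().map(|v| v.max(0.0)).collect();
        let y: Vec<f64> = (0..self.b2.len())
            .map(|k| self.b2[k] + a.iter().zip(&self.w2).map(|(ai, w)| ai * w[k]).sum::<f64>())
            .collect();
        (z, a, y)
    }

    /// Mean over rows of 0.5 * ||y - target||^2.
    pub fn loss(&self, inputs: &[Vec<f64>], targets: &[Vec<f64>]) -> f64 {
        let total: f64 = inputs
            .iter()
            .zip(targets)
            .map(|(x, t)| {
                let (_, _, y) = self.forward_row(x);
                0.5 * y.iter().zip(t).map(|(yk, tk)| (yk - tk).powi(2)).sum::<f64>()
            })
            .sum();
        total / inputs.len() as f64
    }

    /// One full-batch gradient step through both linear layers and the ReLU.
    pub fn train_step(&mut self, inputs: &[Vec<f64>], targets: &[Vec<f64>], learning_rate: f64) {
        let n = inputs.len() as f64;
        let mut g_w1 = vec![vec![0.0_f64; self.b1.len()]; self.w1.len()];
        let mut g_b1 = vec![0.0_f64; self.b1.len()];
        let mut g_w2 = vec![vec![0.0_f64; self.b2.len()]; self.w2.len()];
        let mut g_b2 = vec![0.0_f64; self.b2.len()];
        for (x, t) in inputs.iter().zip(targets) {
            let (z, a, y) = self.forward_row(x);
            let dy: Vec<f64> = y.iter().zip(t).map(|(yk, tk)| (yk - tk) / n).collect();
            for (k, d) in dy.iter().enumerate() {
                g_b2[k] += d;
                for (i, ai) in a.iter().enumerate() {
                    g_w2[i][k] += ai * d;
                }
            }
            for (i, zi) in z.iter().enumerate() {
                // ReLU passes gradient only where z > 0 (zero at the kink).
                if *zi <= 0.0 {
                    continue;
                }
                let dz: f64 = dy.iter().enumerate().map(|(k, d)| self.w2[i][k] * d).sum();
                g_b1[i] += dz;
                for (j, xj) in x.iter().enumerate() {
                    g_w1[j][i] += xj * dz;
                }
            }
        }
        descend(&mut self.w1, &g_w1, learning_rate);
        descend(&mut self.w2, &g_w2, learning_rate);
        for (b, g) in self.b1.iter_mut().zip(&g_b1) {
            *b -= learning_rate * g;
        }
        for (b, g) in self.b2.iter_mut().zip(&g_b2) {
            *b -= learning_rate * g;
        }
    }
}
```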

The composed block update asks the next question:

if the model predicts target tokens, can the readout loss also update the
feed-forward sublayer through the final residual and normalization path?

Yes. The update follows the actual forward cache used by the block and computes the softmax cross-entropy gradient at the readout. It applies the standard layer-normalization backward pass at the final normalization boundary, splits the gradient along the residual path, and passes it through the feed-forward sublayer and the attention normalization boundary. It also backpropagates through value mixing, attention softmax, and scaled query-key scores to reach the query, key, and value projections. The readout, the feed-forward parameters, the attention output projection, the query/key/value projections, and both layer-normalization scale/shift vectors are then updated together.
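The layer-normalization backward pass is the least obvious piece of that path, because the mean and variance depend on every entry of the row. A one-row sketch of the standard formula follows, assuming population variance and a hypothetical epsilon of `1e-5`; it is not necessarily the crate's code:

```rust
/// Hypothetical epsilon; real layers choose their own value.
const EPS: f64 = 1e-5;

/// Mean and 1 / sqrt(variance + EPS) of one row (population variance).
fn row_stats(x: &[f64]) -> (f64, f64) {
    let n = x.len() as f64;
    let mean = x.iter().sum::<f64>() / n;
    let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / n;
    (mean, 1.0 / (var + EPS).sqrt())
}

/// Forward: y = gamma * (x - mean) / sqrt(var + EPS) + beta.
pub fn layer_norm(x: &[f64], gamma: &[f64], beta: &[f64]) -> Vec<f64> {
    let (mean, inv_std) = row_stats(x);
    x.iter()
        .zip(gamma.iter().zip(beta))
        .map(|(v, (g, b))| g * (v - mean) * inv_std + b)
        .collect()
}

/// Backward for one row given upstream dLoss/dy: returns (dx, dgamma, dbeta).
pub fn layer_norm_backward(x: &[f64], gamma: &[f64], dy: &[f64]) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
    let n = x.len() as f64;
    let (mean, inv_std) = row_stats(x);
    let x_hat: Vec<f64> = x.iter().map(|v| (v - mean) * inv_std).collect();
    // Gradient with respect to the normalized row.
    let d_hat: Vec<f64> = dy.iter().zip(gamma).map(|(d, g)| d * g).collect();
    let sum_d_hat: f64 = d_hat.iter().sum();
    let sum_d_hat_x_hat: f64 = d_hat.iter().zip(&x_hat).map(|(d, h)| d * h).sum();
    // The two sum terms account for the mean and variance depending on every x_i.
    let dx: Vec<f64> = d_hat
        .iter()
        .zip(&x_hat)
        .map(|(d, h)| inv_std / n * (n * d - sum_d_hat - h * sum_d_hat_x_hat))
        .collect();
    let dgamma: Vec<f64> = dy.iter().zip(&x_hat).map(|(d, h)| d * h).collect();
    (dx, dgamma, dy.to_vec())
}
```

A central finite difference on a probe loss `L = sum(dy_i * y_i)` is a cheap local check that `dx` and `dgamma` agree with the forward pass.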

Worked Example: Three Updates, One State Shape

The attention example prints one structured state transition first:

training state step: 0 -> 1

The smaller training-state example isolates the same contract:

initial state: step=0, learning_rate=0.100, model_dimension=2, vocab_size=3
readout update: step 0 -> 1
feed-forward update: step 1 -> 2
composed block update: step 2 -> 3

Those lines are small, but they carry the whole contract:

TransformerTrainingState
  owns TinyTransformerParameters
  owns LearningRate
  owns StepCount

An update is not allowed to return loose matrices. It must return another TransformerTrainingState, because the next update needs the same state shape.

The readout-only step asks the smallest supervised question:

fixed hidden sequence
  -> vocabulary logits
  -> sequence loss
  -> updated readout parameters

loss: 0.499085 -> 0.456495

The state shape is unchanged:

TransformerReadoutTrainStep
TransformerTrainingState -> TransformerTrainingState

The local feed-forward step asks a different question:

fixed hidden-sequence input
  -> feed-forward output
  -> squared-error hidden target
  -> updated feed-forward parameters

loss: 0.250000 -> 0.160633

The same outer shape holds:

TransformerFeedForwardTrainStep
TransformerTrainingState -> TransformerTrainingState

The composed block step asks the broader question:

hidden sequence and mask
  -> attention block
  -> readout logits
  -> sequence loss
  -> updated block and readout parameters

loss: 0.456495 -> 0.409737

Again, the outside of the system is stable:

TransformerBlockTrainStep
TransformerTrainingState -> TransformerTrainingState

The internal gradient path grows from readout-only, to local feed-forward, to a composed block update. The public training shape does not grow:

state_0 -> state_1 -> state_2 -> state_3

That is the engineering version of the earlier endomorphism idea. A training step may touch different fields, but it should return the same kind of state so the loop can keep running.
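The `state_0 -> state_3` chain can be sketched as a fold over boxed updates that share one signature. The two scalar fields standing in for parameter groups, the toy gradients, and the names `TrainStep`, `readout_step`, `feed_forward_step`, `block_step`, and `run` are all hypothetical:

```rust
/// Toy training state: two scalars stand in for parameter groups.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerTrainingState {
    pub readout: f64,
    pub feed_forward: f64,
    pub learning_rate: f64,
    pub step: u64,
}

/// Every update has the same outer shape, whatever fields it touches.
pub type TrainStep = Box<dyn Fn(TransformerTrainingState) -> TransformerTrainingState>;

/// Readout-only update. Toy gradient: d(0.5 * r^2)/dr = r.
pub fn readout_step() -> TrainStep {
    Box::new(|s: TransformerTrainingState| TransformerTrainingState {
        readout: s.readout - s.learning_rate * s.readout,
        step: s.step + 1,
        ..s
    })
}

/// Local feed-forward update with the same toy gradient.
pub fn feed_forward_step() -> TrainStep {
    Box::new(|s: TransformerTrainingState| TransformerTrainingState {
        feed_forward: s.feed_forward - s.learning_rate * s.feed_forward,
        step: s.step + 1,
        ..s
    })
}

/// Composed update: touches both groups, still returns the whole state.
pub fn block_step() -> TrainStep {
    Box::new(|s: TransformerTrainingState| TransformerTrainingState {
        readout: s.readout - s.learning_rate * s.readout,
        feed_forward: s.feed_forward - s.learning_rate * s.feed_forward,
        step: s.step + 1,
        ..s
    })
}

/// state_0 -> state_1 -> ... : only possible because every step is State -> State.
pub fn run(steps: &[TrainStep], initial: TransformerTrainingState) -> TransformerTrainingState {
    steps.iter().fold(initial, |state, step| step(state))
}
```

If any step returned only its changed fields, `run` could not be written as a single fold.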

The invalid shortcut would be:

readout update returns readout weights
feed-forward update returns feed-forward weights
block update returns a bag of changed matrices

That makes the next training step guess how to rebuild the model. The roadmap uses one structured state object instead, so every update must preserve the state boundary.

Gradient Evidence Ledger

The block training step has finite-difference tests. They are important, but they are not magic certificates. Read them as local evidence checks.

The test shape is:

one selected parameter
  -> perturb it by +epsilon and -epsilon
  -> measure two nearby losses
  -> compute a central finite difference

one training step
  -> compare before and after parameter values
  -> infer the gradient used by the update

Those two paths should agree for the selected parameter:

central finite difference of loss ~= inferred update gradient

CS231n uses gradient checking this way: compare a numerical gradient with an analytic gradient, preferably with a centered finite-difference formula and careful error interpretation. PyTorch’s gradcheck documentation gives the framework version of the same idea: finite differences are compared with analytical gradients, and the result depends on tolerance, precision, differentiability, and memory-layout assumptions.
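The two evidence paths can be sketched as three small functions. `central_difference`, `inferred_gradient`, and `relative_error` are hypothetical names, and `inferred_gradient` assumes the update is plain gradient descent (`after = before - lr * g`); a different optimizer would need a different inference:

```rust
/// Central finite difference of `loss` along coordinate `i` of `params`.
pub fn central_difference<F: Fn(&[f64]) -> f64>(loss: F, params: &[f64], i: usize, epsilon: f64) -> f64 {
    let mut plus = params.to_vec();
    let mut minus = params.to_vec();
    plus[i] += epsilon;
    minus[i] -= epsilon;
    (loss(&plus[..]) - loss(&minus[..])) / (2.0 * epsilon)
}

/// Gradient implied by one plain gradient-descent update on one coordinate.
pub fn inferred_gradient(before: f64, after: f64, learning_rate: f64) -> f64 {
    (before - after) / learning_rate
}

/// Relative error |a - b| / (|a| + |b|), guarded against 0 / 0.
pub fn relative_error(a: f64, b: f64) -> f64 {
    (a - b).abs() / (a.abs() + b.abs()).max(1e-12)
}
```

On a toy quadratic loss the two paths agree to many digits, while a sign error in the update produces a relative error near 1, which is the kind of mismatch the ledger below is meant to catch.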

The Rust roadmap keeps the claim smaller:

| Test family | Parameter path checked | What it can catch | What it cannot prove |
|---|---|---|---|
| transformer_block_train_step_matches_finite_difference_for_readout_weight | sequence readout weight | wrong sign, missing target-class gradient, wrong averaging scale | correctness of attention gradients |
| transformer_block_train_step_matches_finite_difference_for_feed_forward_weight | feed-forward weight | dropped ReLU or hidden-layer path | correctness of every feed-forward configuration |
| transformer_block_train_step_matches_finite_difference_for_layer_norm_parameter | normalization scale or shift | wrong layer-normalization backward path | correctness of all normalization behavior |
| transformer_block_train_step_matches_finite_difference_for_attention_projection | query, key, value, or output projection weight | dropped attention projection path | correctness of every attention variant |
| bias finite-difference tests | readout, feed-forward, output projection, or attention projection bias | missing bias gradient | correctness of all trainable fields |

The category-theory reading is also modest. These tests compare two local morphisms around one selected coordinate:

loss measurement around current state
parameter update inside TransformerTrainingState -> TransformerTrainingState

They support the implementation of the current state endomorphism. They do not prove that every possible dataset, learning rate, optimizer, mask, sequence length, or future Transformer block is correct.

Use this decision rule when reading a gradient-check result:

match    -> local evidence for this parameter path
mismatch -> inspect sign, scaling, dropped path, nonsmooth point, or tolerance

Do not respond to a mismatch by only loosening the tolerance. First ask which typed boundary or gradient path failed.

Category Theory Concept

The forward path is now a typed morphism:

HiddenSequence x AttentionMask -> SequenceLogits

The training updates all have the same endomorphism shape:

TransformerTrainingState -> TransformerTrainingState

The block step is more global than the readout-only and local feed-forward steps because the loss starts at vocabulary logits and reaches an internal sublayer, the attention output projection, query/key/value projections, and both layer-normalization parameter sets. The update still preserves the same outer endomorphism shape even as the internal gradient path becomes richer.

Design contract:

The parameter object separates substructures:

position information
attention block
language-model readout

That separation is pedagogical and architectural. A reader can point at one field and say which mathematical role it plays. A future optimizer can update the same object without erasing the roles.

Core Mental Model

The current course teaches the typed skeleton:

TokenId -> Vector -> Logits -> Distribution
Distribution x TokenId -> Loss
Parameters -> Parameters

A Transformer extension grows the middle:

TokenSequence
  -> HiddenSequence
  -> HiddenSequence with position
  -> QuerySequence x KeySequence x ValueSequence
  -> AttentionOutput
  -> HiddenSequence
  -> SingleHeadTransformerBlock
  -> MultiHeadTransformerBlock
  -> SequenceLogits
  -> sequence-level probabilities

The practical rule stays the same:

Make every intermediate object explicit, then compose only arrows whose types actually match.

Where This Leaves Us

The roadmap keeps the book honest. The current implementation is a tiny next-token system, not a production Transformer. Its value is that it gives the future system a typed foundation: tokens become vectors, vectors become logits, logits become probabilities, probabilities become loss, and training updates parameters through a repeatable endomorphism.

A future Transformer should extend that foundation with stronger optimizer checks, more realistic datasets, and clearer diagrams. Each new concept should enter the codebase the same way the current concepts did: as a named type, a validated boundary, a typed morphism, a compiled example, and a law or regression test where the concept has a law worth checking.

Roadmap Reference Path

Use References as a staged path. This roadmap comes first because it says what the current code has and does not have. The original Transformer paper then gives the architectural target. Dive into Deep Learning gives the practical sequence from attention scoring to full Transformer blocks. Implementation and visual tutorials help translate paper notation into code structure and diagrams.

After that, return to this repository and add one typed boundary at a time. A future contribution should not start by copying a full architecture into one large module. It should start by making one Transformer concept explicit enough to construct, compose, test, and explain.

The central question for every future contribution is:

What invalid Transformer state should this type make harder to express?

Terminal Output Checkpoint Map

The companion example prints many lines because it is acting as a shape lab. Before reading the typed transformation list, group the terminal output into checkpoints.

| Printed checkpoint | What changed | What stayed true | Boundary to protect |
|---|---|---|---|
| attention shape: 2 query positions x 3 key positions | query-key scoring produced one row per query and one column per key | score rows are still tied to query positions | QuerySequence x KeySequence -> AttentionScores |
| query 0 attends with [0.5, 0.0, 0.5] | masked scores became row-wise weights | the masked key position has zero contribution | AttentionScores x AttentionMask -> AttentionScores -> AttentionWeights |
| query 0 output vector [2.0, 20.0] | weights mixed value rows into one output row | one output row still belongs to one query position | AttentionWeights x ValueSequence -> AttentionOutput |
| multi-head shape: 2 heads x 2 features -> model dimension 4 | separate head outputs were concatenated | sequence length stayed two positions | AttentionHeadOutputs -> MultiHeadOutput |
| projected attention shape: 2 positions x model dimension 2 | concatenated width returned to model width | each row can now rejoin the residual stream | MultiHeadOutput -> ProjectedAttentionOutput |
| residual shape: 2 positions x model dimension 2 | projected attention was added to the input hidden rows | public object is still HiddenSequence | HiddenSequence x ProjectedAttentionOutput -> HiddenSequence |
| normalized shape and feed-forward shape | values changed inside each row | sequence length and model dimension stayed stable | HiddenSequence -> HiddenSequence |
| structured transformer logits shape | hidden rows became vocabulary scores per position | logits are not probabilities yet | HiddenSequence x AttentionMask -> SequenceLogits |
| training state step: 0 -> 1 | parameters and step count advanced | the training object stayed whole | TransformerTrainingState -> TransformerTrainingState |
| readout loss after one update | readout parameters learned from token targets | the update still returns full training state | readout endomorphism |
| feed-forward loss after one local update | the feed-forward sublayer learned a hidden target | the update still returns full training state | local feed-forward endomorphism |
| block loss after one composed update | readout, feed-forward, attention projections, and normalization parameters moved together | the outer state shape stayed stable | composed block endomorphism |

Use this map to avoid a common Transformer-reading mistake: treating every printed vector as “attention.” The output actually moves through three different ideas:

where to read
  -> what information to read
  -> how to return to the hidden-state stream

Then the training lines ask a separate question:

which parameters moved, and did the update preserve the state object?

Example Output Transfer Checklist

After running the companion example, read the printed transformation list as a boundary report. For each line, ask four questions:

What Rust object is being produced?
What ML role does it play?
What shape must remain true?
What shortcut would break the next composition?

| Example output line | Boundary to own | Shortcut to reject |
|---|---|---|
| HiddenSequence -> QuerySequence | Hidden rows become question-like vectors for scoring. | Reusing raw hidden rows as queries without a named projection. |
| HiddenSequence -> KeySequence | The same hidden rows become comparison vectors. | Treating keys and queries as the same role because they share a source. |
| HiddenSequence -> ValueSequence | The same hidden rows become information vectors to be mixed. | Mixing keys or queries as if they were values. |
| QuerySequence x KeySequence -> AttentionScores | Scores are unnormalized similarity numbers. | Reading scores as probabilities or final attention output. |
| AttentionScores x AttentionMask -> AttentionScores | The mask removes illegal positions before normalization. | Applying softmax first, then hiding positions after probability mass has already moved. |
| AttentionScores -> AttentionWeights | Row-wise softmax turns scores into weights. | Combining values before a normalized weight object exists. |
| AttentionWeights x ValueSequence -> AttentionOutput | Values are mixed only after weights exist. | Asking scores to carry both similarity and information. |
| AttentionHeadOutputs -> MultiHeadOutput | Head outputs concatenate, so the width is head_count * value_head_dimension. | Pretending the concatenated width is already the model dimension. |
| MultiHeadOutput -> ProjectedAttentionOutput | The output projection returns concatenated heads to model width. | Adding unprojected multi-head output directly to the residual stream. |
| HiddenSequence x ProjectedAttentionOutput -> HiddenSequence | Residual addition preserves sequence length and model dimension. | Adding two objects whose row widths do not match. |
| LayerNormalization : HiddenSequence -> HiddenSequence | Values change while the public hidden-sequence shape stays stable. | Treating normalization as a projection to a new domain object. |
| PositionWiseFeedForward : HiddenSequence -> HiddenSequence | Each row may expand internally, but the sublayer returns model width. | Letting the hidden expansion leak past the sublayer boundary. |
| TransformerTrainingState -> TransformerTrainingState | Training updates parameters, keeps the learning rate, and advances the step count inside one state object. | Returning only changed weights and forcing the next step to reconstruct state. |

This checklist is the transfer bridge from paper notation and framework documentation to this repository’s Rust style. The original Transformer uses query, key, value, masking, softmax, value mixing, multi-head concatenation, output projection, residual paths, normalization, and feed-forward sublayers. Diagrammatic attention research also treats attention as something that can be decomposed into recurring components before variants are compared. Framework APIs compress much of that into one function call. This chapter uncompresses the path so the reader can point at each intermediate Rust type and say what invalid connection it prevents.

Category Shape Diagnostic

The printed Transformer path uses several category-theory shapes that look similar if you only read the arrows. Before naming a boundary, ask two questions:

How many inputs does this boundary require?
Does it return the same public object, or a different object?

Those two questions prevent a common mistake: calling every shape-preserving line an endomorphism. A true endomorphism in this book has the form A -> A. A product-input boundary such as A x B -> A may return the same object as its left input, but it still needs extra information.

There is a second kind of extra information: learned parameters stored inside a layer object. When this diagnostic names LayerNormalization : HiddenSequence -> HiddenSequence, PositionWiseFeedForward : HiddenSequence -> HiddenSequence, or MultiHeadTransformerBlock : HiddenSequence -> HiddenSequence, read it as a forward call for one fixed layer or block instance. The scale, shift, weights, and biases are already inside that Rust value. If those parameters are being changed, the boundary has moved to training state:

TransformerTrainingState -> TransformerTrainingState

That distinction keeps the chapter honest. It allows a fixed module call to be shape-preserving without pretending parameter learning has disappeared.

They also prevent a second mistake: importing an advanced categorical name too early. Research on self-attention as a parametric endofunctor is useful for the linear portions of self-attention, especially query, key, value, positional, and layered structure. It does not make the whole pedagogical block a single endofunctor in this book. Softmax, masking, residual addition, normalization, feed-forward refinement, and training state each still need their own typed boundary.

Research on the anatomy of attention supports the opposite teaching move: decompose attention first, then compare variants. In this book, the decomposition is not a full diagrammatic formalism. It is a Rust teaching contract: every component must have a named type, a boundary shape, and a failure it prevents.

Categorical deep-learning research also separates architecture constraints from implementations. That distinction is useful here because the Rust code is an implementation witness for one small boundary at a time, not a proof that a future full Transformer satisfies every intended constraint. A good chapter claim should say which side it is on:

architecture constraint:
what should remain true?

implementation boundary:
which Rust type, constructor, example, or test currently enforces it?

Do not call the whole block an endofunctor when the explanation only checked one internal linear path. In this chapter, use the smaller safe name first: ordinary morphism, product-input morphism, shape-preserving endomorphism, state endomorphism, or illegal attempted composition.

The decision flow is:

flowchart TD
    B["Boundary shape"] --> T{"Does it type-check?"}
    T -->|"no"| I["Illegal attempted composition: name the missing conversion"]
    T -->|"yes"| C{"How many inputs are visible?"}
    C -->|"one input"| O{"Same whole source and target object?"}
    O -->|"yes"| E["Endomorphism: A -> A"]
    O -->|"no"| M["Ordinary morphism: A -> B"]
    C -->|"product input"| F{"Was one context fixed first?"}
    F -->|"yes"| U["Induced unary view: name the fixed context"]
    F -->|"no"| P["Product-input morphism: keep A x B visible"]

The same naming rule as a compact rendered math view:

\[
\begin{array}{rcl}
A \to B &:& \text{ordinary morphism} \\
A \to A &:& \text{endomorphism} \\
A \times B \to C &:& \text{product-input morphism} \\
A \times B \to A &:& \text{not automatically an endomorphism} \\
A \xrightarrow{f_b} A &:& \text{fixed-context induced endomorphism, after } b \text{ is fixed}
\end{array}
\]

How to read this diagram:

  • count the visible inputs before naming the category shape,
  • compare the whole source object with the whole target object,
  • fix context explicitly before using a unary view,
  • reject same-output shortcuts that ignore product inputs.

Read the diagram from top to bottom before naming an attention boundary. It is only a local naming aid, but it prevents three common shortcuts:

| Shortcut | Safer move |
|---|---|
| output shape matches, so the boundary is an endomorphism | compare the whole source object with the target object |
| the product can be read as one source, so the boundary is an endomorphism | check whether the target is the same product object |
| context was fixed in prose, so the original boundary had one input | name the open boundary first, then name the fixed context |
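The decision flow can be mirrored by a small naming function. This is a hypothetical sketch: `Boundary` and `BoundaryName` are illustrative names, and whether a boundary type-checks is supplied by the caller rather than computed. It only names boundaries; it does not verify them:

```rust
/// Safe names from the decision flow.
#[derive(Debug, PartialEq)]
pub enum BoundaryName {
    IllegalComposition,
    OrdinaryMorphism,
    Endomorphism,
    InducedUnaryView,
    ProductInputMorphism,
}

/// Hypothetical description of a boundary: the factors of its source, its
/// target, and whether one context factor was fixed before the call.
pub struct Boundary<'a> {
    pub sources: Vec<&'a str>,
    pub target: &'a str,
    pub type_checks: bool,
    pub context_fixed_first: bool,
}

pub fn classify(boundary: &Boundary<'_>) -> BoundaryName {
    if !boundary.type_checks {
        return BoundaryName::IllegalComposition;
    }
    match boundary.sources.as_slice() {
        // One visible input: compare the whole source with the whole target.
        [single] if *single == boundary.target => BoundaryName::Endomorphism,
        [_] => BoundaryName::OrdinaryMorphism,
        // Product input: only a named, fixed context yields a unary view.
        _ if boundary.context_fixed_first => BoundaryName::InducedUnaryView,
        _ => BoundaryName::ProductInputMorphism,
    }
}
```

Note that the output type alone never decides the name: the residual boundary returns `HiddenSequence` but still classifies as a product-input morphism.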

Two-Minute Classification Drill

Before reading the longer table, classify these boundaries yourself. Cover the right column, count the inputs, then decide whether the output returns to the same whole object.

| Boundary | Question to ask first | Safe classification |
|---|---|---|
| HiddenSequence -> QuerySequence | one input, different output object? | ordinary morphism |
| AttentionScores x AttentionMask -> AttentionScores | two inputs, returns the left object? | product-input morphism returning the score object |
| LayerNormalization : HiddenSequence -> HiddenSequence | one input, same output object? | shape-preserving endomorphism |
| HiddenSequence x ProjectedAttentionOutput -> HiddenSequence | two inputs, returns the left object? | product-input morphism returning hidden state |
| TransformerTrainingState -> TransformerTrainingState | one input, same whole training object? | state endomorphism |

The trap is the second and fourth rows. Returning the left-hand object is not the same as being an endomorphism. A boundary that still needs a mask, value sequence, projected sublayer output, dataset, or learning rate is not a pure A -> A story until that context is explicitly fixed.

Source-Target Audit Card

Use this card when a row still feels ambiguous. Do not start from the output type. Name the whole source object, then name the target object.

| Boundary | Whole source object | Target object | Context status | Safe conclusion |
|---|---|---|---|---|
| HiddenSequence -> QuerySequence | HiddenSequence | QuerySequence | no extra context in the boundary | ordinary morphism |
| AttentionScores x AttentionMask -> AttentionScores | AttentionScores x AttentionMask | AttentionScores | mask is open context | product-input morphism, not an endomorphism on scores |
| MaskedMultiHeadTransformerBlock[M] : HiddenSequence -> HiddenSequence | HiddenSequence | HiddenSequence | one mask M was fixed first | induced endomorphism for that named mask |
| HiddenSequence x ProjectedAttentionOutput -> HiddenSequence | HiddenSequence x ProjectedAttentionOutput | HiddenSequence | residual input is open context | product-input morphism returning hidden state |
| TransformerTrainingState -> TransformerTrainingState | TransformerTrainingState | TransformerTrainingState | update context is inside the state object | state endomorphism |

The second and fourth rows are unary only if you choose to regard the product as one source object, but they are still not endomorphisms. Their targets are not the same product object. The fixed-mask row is different because the mask has been selected before the remaining call.

Linear Scope Diagnostic

Use this when an external source gives a categorical reading of self-attention. First ask which part of the attention path the source actually classified.

| Attention part | Boundary in this roadmap | Safe reading here |
|---|---|---|
| query projection | HiddenSequence -> QuerySequence | linear role-producing morphism |
| key projection | HiddenSequence -> KeySequence | linear role-producing morphism |
| value projection | HiddenSequence -> ValueSequence | linear role-producing morphism |
| score construction | QuerySequence x KeySequence -> AttentionScores | product-input boundary |
| mask application | AttentionScores x AttentionMask -> AttentionScores | product-input boundary, not a pure score endomorphism |
| score normalization | AttentionScores -> AttentionWeights | nonlinear normalization boundary |
| value mixing | AttentionWeights x ValueSequence -> AttentionOutput | product-input boundary |
| residual addition | HiddenSequence x ProjectedAttentionOutput -> HiddenSequence | product-input boundary returning hidden state |
| layer normalization for a fixed layer instance | HiddenSequence -> HiddenSequence | shape-preserving but nonlinear endomorphism |
| parameter-changing training update | TransformerTrainingState -> TransformerTrainingState | state endomorphism over the whole training object |

Safe rule:

If a claim was checked for linear Q/K/V maps, do not carry it through softmax,
masking, residual addition, layer normalization, feed-forward refinement, or
training state without naming the next boundary.

This keeps the chapter usable for two readers at once. The category-theory reader sees where a stronger formal story might attach. The ML engineer sees which implemented boundary still needs its own shape, invariant, and test.

Worked Classification: Same Output, Different Shape

The most tempting mistake is to look only at the output type. Three boundaries below all end with HiddenSequence, but they do not have the same category shape.

| Boundary | Count the inputs | Classification | Why |
|---|---|---|---|
| LayerNormalization : HiddenSequence -> HiddenSequence | one input | endomorphism | the whole input object and output object are both HiddenSequence |
| HiddenSequence x ProjectedAttentionOutput -> HiddenSequence | two inputs | product-input morphism returning HiddenSequence | residual addition needs both the old stream and the projected sublayer output |
| HiddenSequence x MultiHeadOutput -> HiddenSequence | two inputs, wrong second object | illegal attempted boundary | residual addition needs projected model-width output, not raw concatenated heads |

The first boundary can safely be named an endomorphism in this book. The second cannot, even though it returns HiddenSequence, because the full input object is not HiddenSequence; it is a product of two objects. The third should not receive a category-theory name yet. It is missing the output projection that makes the residual path well typed.
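Distinct Rust types can turn the third row into a compile error instead of a naming question. A minimal sketch with hypothetical newtypes follows; only the row-count and width checks remain at runtime:

```rust
/// Hypothetical newtypes: the compiler keeps the roles apart even though all
/// three wrap rows of numbers.
#[derive(Debug, Clone, PartialEq)]
pub struct HiddenSequence(pub Vec<Vec<f64>>);
/// Width is head_count * value_head_dimension.
pub struct MultiHeadOutput(pub Vec<Vec<f64>>);
/// Width is back to the model dimension.
pub struct ProjectedAttentionOutput(pub Vec<Vec<f64>>);

/// MultiHeadOutput -> ProjectedAttentionOutput. `w_o` has one row per
/// concatenated feature and one column per model feature.
pub fn project(heads: &MultiHeadOutput, w_o: &[Vec<f64>]) -> Option<ProjectedAttentionOutput> {
    let model_dimension = w_o.first().map_or(0, |row| row.len());
    if heads.0.iter().any(|row| row.len() != w_o.len()) {
        return None; // concatenated width does not match the projection
    }
    let rows: Vec<Vec<f64>> = heads
        .0
        .iter()
        .map(|row| {
            (0..model_dimension)
                .map(|k| row.iter().zip(w_o).map(|(x, w)| x * w[k]).sum::<f64>())
                .collect::<Vec<f64>>()
        })
        .collect();
    Some(ProjectedAttentionOutput(rows))
}

/// HiddenSequence x ProjectedAttentionOutput -> HiddenSequence.
/// Passing a MultiHeadOutput here does not compile.
pub fn residual_add(hidden: &HiddenSequence, attention: &ProjectedAttentionOutput) -> Option<HiddenSequence> {
    if hidden.0.len() != attention.0.len() {
        return None;
    }
    let mut rows: Vec<Vec<f64>> = Vec::with_capacity(hidden.0.len());
    for (h, a) in hidden.0.iter().zip(&attention.0) {
        if h.len() != a.len() {
            return None;
        }
        rows.push(h.iter().zip(a).map(|(x, y)| x + y).collect());
    }
    Some(HiddenSequence(rows))
}
```

A call such as `residual_add(&hidden, &heads)` with a `MultiHeadOutput` is rejected by the type checker, which is the Rust form of "name the missing conversion first".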

This gives a short decision tree:

Does the boundary type-check?
  no  -> name the missing conversion first
  yes -> count the inputs
          one input  -> compare input object and output object
          two inputs -> keep the product-input boundary visible

Use this naming rule before reading the table:

1. Count the inputs.
2. If there is one input, compare the input object and output object.
3. If there is a product input, keep the product in the name.
4. If a required projection or conversion is missing, call it illegal before
   giving it a category-theory name.

That gives five cases, each with a safe name:

| Shape | Safe name | Example |
|---|---|---|
| A -> B | ordinary morphism | AttentionScores -> AttentionWeights |
| A -> A | endomorphism | LayerNormalization : HiddenSequence -> HiddenSequence |
| A x B -> C | product-input morphism | AttentionWeights x ValueSequence -> AttentionOutput |
| A x B -> A | product-input morphism returning A | HiddenSequence x ProjectedAttentionOutput -> HiddenSequence |
| A x B -> A with the wrong B | illegal attempted boundary | HiddenSequence x MultiHeadOutput -> HiddenSequence |

The fourth row is the trap. Returning the left object is not enough to make a boundary an endomorphism. The whole input must be one object, and the output must be that same object.
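The naming rule is mechanical enough to sketch as a function. This is a hypothetical helper, not part of the crate: objects are represented only by their names, and the caller supplies a `type_checks` flag because in real Rust code the compiler performs that check. Following the decision tree, the type check comes first; then the input count decides the name.

```rust
/// The five safe names from the table (hypothetical enum for this sketch).
#[derive(Debug, PartialEq)]
enum BoundaryName {
    OrdinaryMorphism,
    Endomorphism,
    ProductInputMorphism,
    ProductInputReturningFirst,
    IllegalAttempt,
}

/// Apply the naming rule to a boundary written as `sources -> target`.
fn classify(sources: &[&str], target: &str, type_checks: bool) -> BoundaryName {
    // Missing projection or conversion: illegal before any category name applies.
    if !type_checks {
        return BoundaryName::IllegalAttempt;
    }
    // Step 1: count the inputs.
    match sources {
        // This sketch treats a boundary with no source object as malformed.
        [] => BoundaryName::IllegalAttempt,
        // Step 2: one input, so compare the input object with the output object.
        [only] if *only == target => BoundaryName::Endomorphism,
        [_] => BoundaryName::OrdinaryMorphism,
        // Step 3: two or more inputs keep the product in the name,
        // even when the result has the same type as the first factor.
        [first, _, ..] if *first == target => BoundaryName::ProductInputReturningFirst,
        _ => BoundaryName::ProductInputMorphism,
    }
}

fn main() {
    let cases: [(&[&str], &str, bool); 5] = [
        (&["AttentionScores"], "AttentionWeights", true),
        (&["HiddenSequence"], "HiddenSequence", true),
        (&["AttentionWeights", "ValueSequence"], "AttentionOutput", true),
        (&["HiddenSequence", "ProjectedAttentionOutput"], "HiddenSequence", true),
        // Raw concatenated heads skip the output projection, so this does not type-check.
        (&["HiddenSequence", "MultiHeadOutput"], "HiddenSequence", false),
    ];
    for (sources, target, ok) in cases {
        println!("{} -> {}: {:?}", sources.join(" x "), target, classify(sources, target, ok));
    }
}
```

The fourth case is the one the trap paragraph warns about: the output matches the first factor, yet the helper still returns a product-input name because the source is not a single object.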

There is a second safe reading that is useful but different. You may choose to treat the product itself as one source object:

(A x B) -> A

That makes the arrow unary from the product object, but it still is not an endomorphism. The source object is A x B; the target object is A. An endomorphism on the product would have shape:

(A x B) -> (A x B)

This is why the roadmap keeps the phrase “product-input morphism returning A” instead of shortening it to “endomorphism on A.”
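In Rust, a tuple can play the role of the product object, so the difference between (A x B) -> A and (A x B) -> (A x B) shows up directly in function signatures. A minimal sketch with hypothetical toy types (width-2 rows, no real attention): `residual_from_product` is unary from the pair but returns only the first factor. `residual_endo_on_product` is an arrow invented only to illustrate the endomorphism shape on the product; it is not a claim that a Transformer uses such an arrow.

```rust
// Hypothetical toy objects: each row has model dimension 2.
#[derive(Clone, Debug, PartialEq)]
struct HiddenSequence(Vec<[f64; 2]>);

#[derive(Clone, Debug, PartialEq)]
struct ProjectedAttentionOutput(Vec<[f64; 2]>);

/// (HiddenSequence x ProjectedAttentionOutput) -> HiddenSequence.
/// Unary from the product object, but the target is only the first factor,
/// so this is not an endomorphism.
fn residual_from_product(pair: (HiddenSequence, ProjectedAttentionOutput)) -> HiddenSequence {
    let (hidden, projected) = pair;
    let rows = hidden
        .0
        .iter()
        .zip(&projected.0)
        .map(|(h, p)| [h[0] + p[0], h[1] + p[1]])
        .collect();
    HiddenSequence(rows)
}

/// (HiddenSequence x ProjectedAttentionOutput) -> (HiddenSequence x ProjectedAttentionOutput).
/// Illustrative only: an endomorphism on the product must return the whole pair.
fn residual_endo_on_product(
    pair: (HiddenSequence, ProjectedAttentionOutput),
) -> (HiddenSequence, ProjectedAttentionOutput) {
    let projected = pair.1.clone();
    (residual_from_product(pair), projected)
}

fn main() {
    // The compiler records both signatures; they differ in the target type.
    let _unary_from_product: fn((HiddenSequence, ProjectedAttentionOutput)) -> HiddenSequence =
        residual_from_product;
    let _endo_on_product: fn(
        (HiddenSequence, ProjectedAttentionOutput),
    ) -> (HiddenSequence, ProjectedAttentionOutput) = residual_endo_on_product;

    let pair = (
        HiddenSequence(vec![[1.0, 0.0], [0.0, 1.0]]),
        ProjectedAttentionOutput(vec![[0.5, 0.5], [0.25, 0.75]]),
    );
    println!("{:?}", residual_from_product(pair.clone()));
    println!("{:?}", residual_endo_on_product(pair));
}
```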

Terminal Output Audit: Shape Line Is Not Boundary Shape

The runnable example prints several lines with the same public dimensions:

cargo run --example 06_attention_scores

Those lines are useful evidence, but they are not category names by themselves. A printed shape tells you something about the target object. The typed transformation line tells you the whole source object and the target object.

| Printed output line | What the line proves | What it does not prove | Boundary to name |
|---|---|---|---|
| projected attention shape: 2 positions x model dimension 2 | raw head output has been projected back to model width | the residual connection has already happened | MultiHeadOutput -> ProjectedAttentionOutput |
| residual shape: 2 positions x model dimension 2 | the result has returned to hidden-sequence shape | residual addition was unary | HiddenSequence x ProjectedAttentionOutput -> HiddenSequence |
| masked multi-head block shape: 2 positions x model dimension 2 | the block output can feed the next hidden-sequence layer | the open masked block is a pure endomorphism | MaskedMultiHeadTransformerBlock : HiddenSequence x AttentionMask -> HiddenSequence |
| training state step: 0 -> 1 | the update returns a state that can be updated again | training is a loose Loss -> Parameters shortcut | TransformerTrainingState -> TransformerTrainingState |

Use this three-step audit whenever the terminal output seems to settle the category name too quickly:

printed shape line -> evidence about the target object
typed transformation line -> evidence about the source and target objects
category name -> only after both source and target are known

That is why residual shape and masked multi-head block shape can both show model-width hidden rows while still having different safe category readings. Identical printed dimensions do not mean identical boundaries.

Stackability With Context

Stacking means the output of one boundary can feed the next boundary without inventing missing inputs. A direct endomorphism can stack by itself. A product-input boundary can stack only if the extra context is carried along or fixed explicitly.

For learned sublayers, “direct” means the layer instance is fixed for the forward call. A different LayerNormalization or PositionWiseFeedForward value is a different morphism. Changing those parameters is training-state work, so the safe outer name is TransformerTrainingState -> TransformerTrainingState.
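A minimal sketch of the state endomorphism, assuming a hypothetical `TransformerTrainingState` that holds a step counter, a learning rate, a flat parameter vector, and a toy "dataset" of targets for the loss 0.5 * sum((w - target)^2). Because the targets live inside the state, the update needs no second input, so its signature really is `TransformerTrainingState -> TransformerTrainingState`. If the gradient were passed in as a separate argument, the boundary would become a product-input morphism instead.

```rust
/// Hypothetical training object: everything the next update needs.
#[derive(Clone, Debug, PartialEq)]
struct TransformerTrainingState {
    step: u64,
    learning_rate: f64,
    parameters: Vec<f64>,
    // Toy "dataset": the loss is 0.5 * sum((w - target)^2).
    targets: Vec<f64>,
}

/// TransformerTrainingState -> TransformerTrainingState.
/// One gradient-descent step on the toy squared-error loss,
/// whose gradient with respect to w is (w - target).
fn train_step(state: TransformerTrainingState) -> TransformerTrainingState {
    let parameters: Vec<f64> = state
        .parameters
        .iter()
        .zip(&state.targets)
        .map(|(w, t)| w - state.learning_rate * (w - t))
        .collect();
    TransformerTrainingState {
        step: state.step + 1,
        parameters,
        // learning_rate and targets are carried forward, not dropped.
        ..state
    }
}

fn main() {
    let start = TransformerTrainingState {
        step: 0,
        learning_rate: 0.5,
        parameters: vec![1.0, -1.0],
        targets: vec![0.0, 0.0],
    };
    let after_one = train_step(start.clone());
    println!("training state step: {} -> {}", start.step, after_one.step);
    // The output can be fed straight back in: updates compose.
    let after_two = train_step(after_one);
    println!("{:?}", after_two);
}
```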

| Boundary | Can it stack directly as HiddenSequence -> HiddenSequence? | Safe reading |
|---|---|---|
| LayerNormalization : HiddenSequence -> HiddenSequence | yes | direct shape-preserving endomorphism |
| MultiHeadTransformerBlock : HiddenSequence -> HiddenSequence | yes | direct block endomorphism |
| MaskedMultiHeadTransformerBlock : HiddenSequence x AttentionMask -> HiddenSequence | no, not while the mask is open | product-input morphism that still needs mask context |
| fixed-mask view of MaskedMultiHeadTransformerBlock | yes, for that named mask context | induced endomorphism after context is fixed |
| TransformerTrainingState -> TransformerTrainingState | not on HiddenSequence; it stacks on TransformerTrainingState instead | state endomorphism over the whole training object |

This is the same discipline as the rest of the chapter. Do not erase context to make a category name fit. If a mask, dataset, learning rate, or parameter object is part of the boundary, either keep it in the type shape or say exactly where it was fixed.
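The stacking rule can be sketched by making a stack nothing more than a list of `HiddenSequence -> HiddenSequence` functions. Everything here is a hypothetical toy: rows have width 2, layer normalization has no learned scale or shift, and the feed-forward layer has fixed weights with an internal ReLU width of 4. Those weights happen to double each row, because relu(x) - relu(-x) = x. The internal four-wide expansion never leaves `apply`, so the public object stays `HiddenSequence`. An open masked block would not fit in the list until its mask is fixed.

```rust
type Row = [f64; 2];

#[derive(Clone, Debug, PartialEq)]
struct HiddenSequence(Vec<Row>);

/// LayerNormalization : HiddenSequence -> HiddenSequence (toy: no learned scale/shift).
fn layer_norm(h: HiddenSequence) -> HiddenSequence {
    let eps = 1e-5;
    let rows = h
        .0
        .iter()
        .map(|r| {
            let mean = (r[0] + r[1]) / 2.0;
            let var = ((r[0] - mean).powi(2) + (r[1] - mean).powi(2)) / 2.0;
            let scale = (var + eps).sqrt();
            [(r[0] - mean) / scale, (r[1] - mean) / scale]
        })
        .collect();
    HiddenSequence(rows)
}

/// Hypothetical fixed-weight PositionWiseFeedForward: width 2 -> 4 (ReLU) -> 2.
struct PositionWiseFeedForward {
    w_in: [[f64; 2]; 4],
    w_out: [[f64; 4]; 2],
}

impl PositionWiseFeedForward {
    fn apply(&self, h: HiddenSequence) -> HiddenSequence {
        let rows = h
            .0
            .iter()
            .map(|r| {
                // The 4-wide expansion stays local to this call.
                let inner: [f64; 4] = std::array::from_fn(|k| {
                    (self.w_in[k][0] * r[0] + self.w_in[k][1] * r[1]).max(0.0)
                });
                let out: Row = std::array::from_fn(|d| {
                    (0..4).map(|k| self.w_out[d][k] * inner[k]).sum::<f64>()
                });
                out
            })
            .collect();
        HiddenSequence(rows)
    }
}

/// A stack only accepts HiddenSequence -> HiddenSequence boundaries.
type Layer = Box<dyn Fn(HiddenSequence) -> HiddenSequence>;

fn run_stack(stack: &[Layer], h: HiddenSequence) -> HiddenSequence {
    stack.iter().fold(h, |acc, layer| layer(acc))
}

fn main() {
    let ff = PositionWiseFeedForward {
        w_in: [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]],
        w_out: [[2.0, 0.0, -2.0, 0.0], [0.0, 2.0, 0.0, -2.0]],
    };
    // Each entry is one fixed layer instance, so each is one endomorphism.
    let mut stack: Vec<Layer> = Vec::new();
    stack.push(Box::new(layer_norm));
    stack.push(Box::new(move |h: HiddenSequence| ff.apply(h)));
    stack.push(Box::new(layer_norm));
    // An open masked block (HiddenSequence x AttentionMask -> HiddenSequence)
    // has the wrong shape for this Vec until its mask is fixed.
    let out = run_stack(&stack, HiddenSequence(vec![[3.0, 1.0], [0.0, 2.0]]));
    println!("{} positions x model dimension {}", out.0.len(), out.0[0].len());
}
```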

Context Fixing Drill

The open masked block and a fixed-mask view are related, but they are not the same boundary:

open boundary:
MaskedMultiHeadTransformerBlock : HiddenSequence x AttentionMask -> HiddenSequence

fixed-context boundary:
choose one mask M
MaskedMultiHeadTransformerBlock[M] : HiddenSequence -> HiddenSequence

The fixed-context boundary is a new named view after selecting M. It is not a claim that the original block had only one input. The source of the context must stay visible in the prose, the exercise, or the type that carries it.

| Case | What is fixed? | Safe category shape | Can it stack as HiddenSequence -> HiddenSequence? | Overclaim to avoid |
|---|---|---|---|---|
| open masked block | nothing | product-input morphism | no | “it returns HiddenSequence, so it is an endomorphism” |
| fixed-mask view | one named AttentionMask | induced endomorphism for that mask | yes, while the same mask context remains fixed | “the mask disappeared” |
| changing mask per call | the mask is supplied again each call | repeated product-input calls, or a larger state carrying the mask | only if the caller threads or fixes the context | “this is the same as a fixed-mask view” |
| residual addition | no input is fixed; both hidden stream and projected output are supplied | product-input morphism returning hidden state | no | “a binary operation is unary because the result is hidden state” |

The residual row is a negative contrast. It is not a context-fixing example. The hidden stream is still an input, and the projected sublayer output is still an input. Nothing has been selected in advance. The boundary therefore remains:

HiddenSequence x ProjectedAttentionOutput -> HiddenSequence

It returns HiddenSequence, but it does not become a unary HiddenSequence -> HiddenSequence boundary unless one input has actually been fixed. If the product object itself is named as the source, the arrow is unary from that product object:

(HiddenSequence x ProjectedAttentionOutput) -> HiddenSequence

That is still not an endomorphism, because the target is not the same product object. An endomorphism on the named product would have to return the whole product again:

(HiddenSequence x ProjectedAttentionOutput)
    -> (HiddenSequence x ProjectedAttentionOutput)

This is only a local teaching use of fixing context. It is not a proof that the whole attention block lives in a closed category, and it is not permission to hide arbitrary inputs. The practical rule stays simple:

name the open boundary first
name exactly what was fixed
then name the induced unary view

Rust already has a familiar mechanism for this idea: a closure can capture a value from the surrounding environment. The official Rust Book uses closures to teach how a callable value can remember its environment. In this roadmap, a fixed-mask view can be read the same way:

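// Fix one mask M and give the closure its own copy.
let fixed_mask = mask.clone();
// The remaining call receives only HiddenSequence; M travels inside the closure.
let fixed_mask_view = move |hidden: HiddenSequence| {
    masked_block.apply(Product::new(hidden, fixed_mask.clone()))
};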

That closure-shaped explanation is only an analogy for this local boundary. It does not change the original open type:

MaskedMultiHeadTransformerBlock : HiddenSequence x AttentionMask -> HiddenSequence

It says how one chosen AttentionMask can be captured before the remaining call receives HiddenSequence. If a reader cannot point to the captured fixed_mask, the text has hidden context instead of fixing it.
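The closure reading can be checked with a runnable toy. Everything below is a hypothetical stand-in rather than the crate's real types: `HiddenSequence` is a list of width-2 rows, `AttentionMask` is a boolean matrix where `true` means "may read", and the "masked block" uses uniform weights over the visible positions plus a residual add instead of learned queries, keys, and values. Unlike the analogy above, the open function here is a plain two-argument `fn` rather than a `Product`. What the sketch does show is the shape: the open function takes two inputs, and the closure captures one named mask and takes one.

```rust
type Row = [f64; 2];

#[derive(Clone, Debug, PartialEq)]
struct HiddenSequence(Vec<Row>);

/// `allowed[i][j] == true` means position i may read position j.
#[derive(Clone, Debug)]
struct AttentionMask {
    allowed: Vec<Vec<bool>>,
}

/// Open boundary: HiddenSequence x AttentionMask -> HiddenSequence.
/// Toy block: each row adds the uniform average of the rows it may read.
fn masked_block(hidden: HiddenSequence, mask: &AttentionMask) -> HiddenSequence {
    let rows = hidden
        .0
        .iter()
        .enumerate()
        .map(|(i, row)| {
            let visible: Vec<&Row> = hidden
                .0
                .iter()
                .zip(&mask.allowed[i])
                .filter(|(_, ok)| **ok)
                .map(|(r, _)| r)
                .collect();
            let n = visible.len().max(1) as f64; // avoid dividing by zero
            let mut out = *row;
            for r in &visible {
                out[0] += r[0] / n;
                out[1] += r[1] / n;
            }
            out
        })
        .collect();
    HiddenSequence(rows)
}

fn main() {
    let hidden = HiddenSequence(vec![[1.0, 0.0], [0.0, 1.0]]);
    // Causal mask M: position 0 reads only itself; position 1 reads both.
    let mask = AttentionMask { allowed: vec![vec![true, false], vec![true, true]] };

    // Name exactly what was fixed: one mask M, captured by the closure.
    let fixed_mask = mask.clone();
    let block_with_m = move |h: HiddenSequence| masked_block(h, &fixed_mask);

    // The induced view is HiddenSequence -> HiddenSequence, so it can be applied twice.
    let once = block_with_m(hidden.clone());
    let twice = block_with_m(once.clone());
    // With the same mask, the open call and the fixed view agree.
    assert_eq!(once, masked_block(hidden, &mask));
    println!("{:?}\n{:?}", once, twice);
}
```

The `assert_eq!` line is the local version of the drill: the fixed view is a new named view of the open boundary, not a different computation.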

| Boundary | Category shape to name | Why this is the right name | Common misread |
|---|---|---|---|
| QuerySequence x KeySequence -> AttentionScores | product-input morphism | scoring needs query rows and key rows | treating scores as a unary query transform |
| AttentionScores x AttentionMask -> AttentionScores | product-input morphism returning the score object | the mask is extra context applied before softmax | calling it a pure endomorphism on scores |
| AttentionScores -> AttentionWeights | ordinary morphism | raw scores become normalized rows | treating weights as the same object as scores |
| AttentionWeights x ValueSequence -> AttentionOutput | product-input morphism | weights decide how much of each value row to read | treating keys and values as interchangeable |
| MultiHeadOutput -> ProjectedAttentionOutput | ordinary morphism | concatenated head width returns to model width | adding multi-head output directly to residual state |
| HiddenSequence x ProjectedAttentionOutput -> HiddenSequence | product-input morphism returning hidden state | residual addition needs both the old stream and the sublayer output | calling the binary residual operation a unary endomorphism |
| LayerNormalization : HiddenSequence -> HiddenSequence | shape-preserving endomorphism | values change while the public hidden object stays the same | treating normalization as a new sequence domain |
| PositionWiseFeedForward : HiddenSequence -> HiddenSequence | shape-preserving endomorphism | internal width may expand, but the public object returns unchanged | leaking the internal expansion into the next block |
| TransformerTrainingState -> TransformerTrainingState | state endomorphism | one update returns a complete object that can be updated again | returning only changed weights or only loss |
| HiddenSequence x MultiHeadOutput -> HiddenSequence | not a legal composed boundary | residual addition needs projected model-width output | skipping the output projection |

This diagnostic is the category-theory version of shape checking. The ML question is:

what information must this stage receive?

The Rust question is:

which type should own that information before the next call?

The category-theory question is:

is this a unary morphism, a product-input morphism, an endomorphism, or not
composable yet?

If a boundary needs two objects, write both. If it returns to the same public object, say whether that return is unary or product-input. Precision here is what keeps the roadmap from turning attention into a single vague arrow.
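The first four rows of the diagnostic table can be written as four Rust functions whose signatures carry the same shapes. This is a minimal sketch under stated assumptions, not the crate's API: the newtypes are hypothetical, rows are plain `Vec<f64>`, scores are scaled dot products, and the mask is a boolean matrix applied by setting forbidden scores to negative infinity before a row-wise softmax. Because `apply_mask` takes and returns `AttentionScores`, the types force the mask to act before `softmax_rows`.

```rust
struct QuerySequence(Vec<Vec<f64>>);
struct KeySequence(Vec<Vec<f64>>);
struct ValueSequence(Vec<Vec<f64>>);
/// `true` means the query row may attend to that key position.
struct AttentionMask(Vec<Vec<bool>>);
struct AttentionScores(Vec<Vec<f64>>);
#[derive(Debug)]
struct AttentionWeights(Vec<Vec<f64>>);
#[derive(Debug)]
struct AttentionOutput(Vec<Vec<f64>>);

/// QuerySequence x KeySequence -> AttentionScores (scaled dot product).
fn score(q: &QuerySequence, k: &KeySequence) -> AttentionScores {
    let scale = (q.0[0].len() as f64).sqrt();
    AttentionScores(q.0.iter().map(|qi| {
        k.0.iter().map(|kj| qi.iter().zip(kj).map(|(a, b)| a * b).sum::<f64>() / scale).collect()
    }).collect())
}

/// AttentionScores x AttentionMask -> AttentionScores: forbid positions before softmax.
fn apply_mask(s: AttentionScores, m: &AttentionMask) -> AttentionScores {
    AttentionScores(s.0.into_iter().zip(&m.0).map(|(row, allowed)| {
        row.into_iter().zip(allowed).map(|(x, &ok)| if ok { x } else { f64::NEG_INFINITY }).collect()
    }).collect())
}

/// AttentionScores -> AttentionWeights: row-wise softmax.
/// Assumes every row allows at least one position; a fully masked row would give NaN.
fn softmax_rows(s: AttentionScores) -> AttentionWeights {
    AttentionWeights(s.0.into_iter().map(|row| {
        let max = row.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        let exps: Vec<f64> = row.iter().map(|x| (x - max).exp()).collect();
        let total: f64 = exps.iter().sum();
        exps.into_iter().map(|e| e / total).collect()
    }).collect())
}

/// AttentionWeights x ValueSequence -> AttentionOutput: mix value rows by weight.
fn read_values(w: &AttentionWeights, v: &ValueSequence) -> AttentionOutput {
    let width = v.0[0].len();
    AttentionOutput(w.0.iter().map(|wi| {
        (0..width).map(|d| wi.iter().zip(&v.0).map(|(a, vj)| a * vj[d]).sum::<f64>()).collect()
    }).collect())
}

fn main() {
    // Two positions, model dimension 2 (hypothetical numbers).
    let q = QuerySequence(vec![vec![1.0, 0.0], vec![0.0, 1.0]]);
    let k = KeySequence(vec![vec![1.0, 0.0], vec![0.0, 1.0]]);
    let v = ValueSequence(vec![vec![10.0, 0.0], vec![0.0, 10.0]]);
    // Causal mask: position 0 may not read position 1.
    let mask = AttentionMask(vec![vec![true, false], vec![true, true]]);

    let weights = softmax_rows(apply_mask(score(&q, &k), &mask));
    let output = read_values(&weights, &v);
    println!("{:?}\n{:?}", weights, output);
}
```

Swapping `ValueSequence` for `KeySequence` in `read_values` is a type error in this sketch, which is the Rust form of the "treating keys and values as interchangeable" misread.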

Reader Evidence Handoff

If this diagnostic becomes unclear, the most useful report is not “attention is confusing.” The useful report names the exact rule that failed.

Use this shape:

Command: cargo run --example 06_attention_scores
Page: Transformer Roadmap -> Category Shape Diagnostic
Evidence signal: one boundary row or printed output line
Last clear idea: the last boundary name that still made sense
First unclear rule: input count, fixed context, legal composition, source role,
target role, or linear-scope limit
Smallest useful fix: one sentence, table row, diagram, or exercise check

Good evidence signals are small:

AttentionScores x AttentionMask -> AttentionScores
MaskedMultiHeadTransformerBlock[M] : HiddenSequence -> HiddenSequence
HiddenSequence x MultiHeadOutput -> HiddenSequence
query 0 attends with [0.5, 0.0, 0.5]

A report like that gives the next rewrite a concrete target: which boundary, which rule, and which reader expectation failed.

Open the chapter clarity feedback form with those fields filled from your own run or reading.

Retrieval Practice

Run the attention example before answering:

cargo run --example 06_attention_scores

Recall

Recover the named objects and boundaries before explaining them.

  1. Which three role objects are produced from HiddenSequence before attention scores are computed?
  2. Which printed line is the first point where attention scores become row-wise normalized weights?
  3. Which boundaries in the example preserve the public shape HiddenSequence -> HiddenSequence?
  4. Which three training steps share the outer shape TransformerTrainingState -> TransformerTrainingState?
  5. Which printed line tells you that multi-head width must be projected before it can rejoin the residual stream?

Explain

Use the type boundary to explain the reason for the design.

  1. Why must the mask act before row-wise softmax?
  2. Why does multi-head attention need an output projection before residual addition?
  3. Why is returning a full TransformerTrainingState safer than returning only changed readout or feed-forward weights?

The next questions check the scope of the evidence. A local gradient check and a shape-preserving sublayer are useful only when their boundaries are named.

  1. Why is a finite-difference check useful for one selected parameter without proving that every future training loop is correct?
  2. Why does PositionWiseFeedForward : HiddenSequence -> HiddenSequence permit an internal hidden expansion but not an expanded public output?

Apply

Change the numbers and check whether the same typed rule still holds.

  1. A block has three heads and each head produces four features per position. What input width must the output projection accept?
  2. A feed-forward sublayer expands a model-dimension-two row to six hidden features, then returns five features. Which public boundary has been broken?
  3. A training step updates the feed-forward weights but drops the learning rate. Why can the next training step no longer compose safely?
  4. A two-token sequence has raw scores [1.0, 9.0] for the first row, but the second position is masked out. Which object should record the forbidden position before softmax?
  5. A learner sees the line AttentionWeights x ValueSequence -> AttentionOutput and wants to replace ValueSequence with KeySequence. Which ML role has been lost?
  6. A learner sees HiddenSequence x ProjectedAttentionOutput -> HiddenSequence and calls it an endomorphism because the output is HiddenSequence. Which step of the naming rule corrects that mistake?

Debug

For each invalid shortcut, name the missing boundary:

HiddenSequence x MultiHeadOutput -> HiddenSequence
AttentionScores -> AttentionOutput
readout update -> changed readout weights
softmax scores -> masked weights
feed-forward hidden expansion -> next block input
finite-difference agreement -> full optimizer proof

Good answers should point back to a concrete type or transformation in this chapter, not only to a phrase such as “shape mismatch” or “training update.”