Repository Source Snapshots

This appendix collects the learner-facing source files used throughout the course. The Rust snapshots are complete files rather than standalone snippets, so their code blocks are marked ignore and are not compiled on their own. The real source files are still validated by the repository checks.

The problem this appendix solves is:

The explanatory chapters should never drift away from the real code.

Use this appendix as the exact-code lookup layer.

When you inspect any block here, use the same reading pattern as the chapters:

Rust syntax:
what does the code literally declare or execute?

ML or software concept:
why does this block exist in the pipeline or system model?

Category theory concept:
what object, morphism, product, composition, endomorphism, or law does it
represent?
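
As a hedged illustration of the three questions, the following standalone sketch annotates one small function with all three readings. It is not taken from the repository; `Features`, `halve`, and `halve_twice` are hypothetical names chosen only for this example.

```rust
// Hypothetical example; none of these names appear in the course crate.

/// A thin wrapper around a feature vector.
#[derive(Debug, Clone, PartialEq)]
pub struct Features(pub Vec<f32>);

// Rust syntax:
//   declares a function that takes ownership of a `Features` value and
//   returns a new `Features` value.
// ML or software concept:
//   a fixed rescaling step, such as one stage of a preprocessing pipeline.
// Category theory concept:
//   an endomorphism `Features -> Features`; input and output types match,
//   so the arrow can be composed with itself.
pub fn halve(input: Features) -> Features {
    Features(input.0.into_iter().map(|value| value * 0.5).collect())
}

/// Composition of the endomorphism with itself: `halve after halve`.
pub fn halve_twice(input: Features) -> Features {
    halve(halve(input))
}
```

Calling `halve_twice(Features(vec![4.0]))` applies the same arrow twice, which is the pattern the training chapter relies on when it repeats an optimizer step.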

The appendix is intentionally less narrative than the chapters. It keeps the full source available in one place so you can verify every explanation against the code itself.

How To Navigate The Snapshots

Use the snapshots in two directions.

When reading chapter-first, start with the explanation, then open the matching snapshot to verify the exact code. When reading code-first, start with the file that interests you, then return to the chapter that teaches its role.

If you want to inspect | Start with | Then read
public crate surface | src/lib.rs | Course Map
typed values and invariants | src/domain.rs | Domain Objects
arrows and composition | src/category.rs | Morphism and Composition
prediction pipeline | src/ml.rs | The Tiny ML Pipeline
parameter updates | src/training.rs | Training as an Endomorphism
reusable structure | src/structure.rs | Functors, Naturality, Monoids, and Chain Rule
local derivative flow | src/calculus.rs | Functors, Naturality, Monoids, and Chain Rule
typed masked attention, value-mixing, head-concatenation, output-projection, residual, normalization, feed-forward, positional, hidden-projection, single-head block, multi-head block, masked-block, readout, parameter-object, training-state, readout-training, local feed-forward training, and composed block-training boundaries with query/key/value gradients | src/attention.rs | Transformer Roadmap
applied category-theory sketches | src/sketches.rs | Seven Sketches Through Rust
public challenge reference behavior | src/challenges/ and examples/challenge_adam.rs | Challenges
runnable end-to-end walkthrough | src/demo.rs | Course Map
command-line entrypoint | src/bin/category_ml.rs | Course Map

The useful question is:

Which explanation in the book would become false if this source file changed?

That question is why the source snapshots exist. They make drift visible.

Reading Order By Goal

For a five-minute run, inspect examples/01_token_sequence.rs, then compare it with the beginning of Course Map.

For the core book path, read src/domain.rs, src/category.rs, src/ml.rs, and src/training.rs in that order.

For the structure path, read src/structure.rs, src/calculus.rs, src/attention.rs, and src/sketches.rs after the tiny ML pipeline is clear.

For contribution work, read the chapter first, then the source file, then the tests in the same module. The tests often explain the contract more clearly than the implementation alone.
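
To make the point about tests and contracts concrete, here is a minimal sketch, assuming a validated newtype similar in spirit to the positive step size in src/domain.rs. `PositiveRate` and its `String` error are hypothetical stand-ins, not the crate's types. Each test name states one clause of the constructor's contract.

```rust
/// Hypothetical validated newtype; it is not part of the course crate.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct PositiveRate(f32);

impl PositiveRate {
    /// Accepts only finite values strictly greater than zero.
    pub fn new(value: f32) -> Result<Self, String> {
        if !value.is_finite() || value <= 0.0 {
            return Err(format!("invalid rate {value}"));
        }

        Ok(Self(value))
    }

    pub fn value(&self) -> f32 {
        self.0
    }
}

// The tests read as a specification: one accepted case, then every
// rejected class of input.
#[test]
fn accepts_small_positive_values() {
    assert!(PositiveRate::new(1e-3).is_ok());
}

#[test]
fn rejects_zero_negative_and_non_finite_values() {
    for bad in [0.0, -0.5, f32::NAN, f32::INFINITY] {
        assert!(PositiveRate::new(bad).is_err());
    }
}
```

Reading the two tests first tells a contributor which inputs are legal before the constructor body is opened, which is the reading order suggested above.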

Rust Library Surface

src/lib.rs

//! A small, modular Rust tutorial for category-theory ideas in tiny ML.
//!
//! The crate is intentionally split into small learning modules:
//! - [`domain`] defines the nouns: tokens, vectors, probabilities, losses, and parameters.
//! - [`category`] defines the arrows: morphisms, identity, composition, and endomorphisms.
//! - [`ml`] implements concrete ML morphisms.
//! - [`training`] turns one optimizer step into an endomorphism on parameters.
//! - [`structure`] covers functors, natural transformations, and monoids.
//! - [`calculus`] shows the chain rule as a local backward pass.
//! - [`attention`] sketches typed attention boundaries and Transformer state for the roadmap.
//! - [`sketches`] gives Rust models for the seven applied-category-theory sketches.
//! - [`challenges`] backs the public Typed AI Rustlings and Paper-To-Rust tracks.
//! - [`demo`] connects the pieces into the terminal walkthrough.

pub mod attention;
pub mod calculus;
pub mod category;
pub mod challenges;
pub mod demo;
pub mod domain;
pub mod error;
pub mod ml;
pub mod sketches;
pub mod structure;
pub mod training;

pub use attention::{
    AttentionHeadOutputs, AttentionMask, AttentionOutput, AttentionOutputProjection,
    AttentionScores, AttentionSoftmax, AttentionWeights, ConcatenateHeads, HeadCount,
    HeadDimension, HiddenSequence, HiddenToKey, HiddenToQuery, HiddenToValue, KeySequence,
    LayerNormParameters, LayerNormalization, MaskedAttentionScores,
    MaskedMultiHeadTransformerBlock, MultiHeadOutput, MultiHeadTransformerBlock,
    NormalizationEpsilon, PositionWiseFeedForward, PositionalEncoding, ProjectedAttentionOutput,
    QuerySequence, ResidualConnection, ScaledDotProductScores, SelfAttentionHead, SequenceLength,
    SequenceLogits, SingleHeadTransformerBlock, TinyTransformerParameters,
    TransformerBlockTrainStep, TransformerBlockTrainingExample, TransformerBlockTrainingSet,
    TransformerFeedForwardTrainStep, TransformerFeedForwardTrainingExample,
    TransformerFeedForwardTrainingSet, TransformerReadout, TransformerReadoutTrainStep,
    TransformerReadoutTrainingExample, TransformerReadoutTrainingSet, TransformerTrainingState,
    ValueSequence, WeightedValueMixing, transformer_block_average_loss,
    transformer_feed_forward_average_loss, transformer_readout_average_loss,
};
pub use calculus::{LocalGradient, MulOp, Scalar};
pub use category::{
    Compose, Endomorphism, Identity, Morphism, StepCount, apply_endomorphism_n_times,
};
pub use challenges::papers::adam::{
    AdamConfig, AdamDecayRate, AdamEpsilon, AdamFirstMoment, AdamGradientVector, AdamModelState,
    AdamOptimizerState, AdamParameterVector, AdamSecondMoment, AdamStepCount, AdamTrainStep,
    AdamVectorDimension,
};
pub use challenges::typed_ai::{
    loss_from_logits, require_target_in_distribution, token_index, uniform_distribution,
};
pub use demo::run_demo;
pub use domain::{
    Distribution, LearningRate, Logits, Loss, ModelDimension, Parameters, Product, TokenId,
    TokenSequence, TrainingExample, TrainingSet, Vector, VocabSize,
};
pub use error::{CtError, CtResult};
pub use ml::{
    CrossEntropy, DatasetWindowing, DirectPredict, Embedding, LinearToLogits, Softmax,
    average_loss, composed_prediction_matches_direct_prediction,
};
pub use sketches::{
    CircuitComponent, CompanyInstance, DepartmentId, DesignRequirement, EmployeeId, EmployeeRecord,
    FeasibilityRelation, FeatureCount, ImplementationOffer, InformationLevel, LatencyMs,
    LayerBudget, LocalSafetyCheck, MatrixCols, MatrixRows, OpenCircuit, PortName, ResistanceOhms,
    ResourceAmount, ResourceBundle, SafetyCover, SignalCoefficient, SignalMatrix, Throughput,
    TimeInterval, TimeTick, TruthValue, abstract_to_layer_budget, concretize_layer_budget,
    feature_layer_galois_law_holds, information_order_obeys_preorder_laws,
    resource_tensor_is_monotone,
};
pub use structure::{
    Functor, Monoid, NaturalTransformation, OptionFunctor, PipelineTrace, TraceStep, VecFunctor,
    VecToFirstOption, monoid_laws_hold_for_pipeline_trace,
    naturality_square_holds_for_first_option,
};
pub use training::TrainStep;

src/error.rs

use std::fmt::{self, Display};

/// Shared error type for the tutorial crate.
#[derive(Debug, Clone, PartialEq)]
pub enum CtError {
    EmptyInput(&'static str),
    OutOfRange {
        kind: &'static str,
        index: usize,
        limit: usize,
    },
    ShapeMismatch {
        op: &'static str,
        expected: String,
        got: String,
    },
    InvalidProbability(&'static str),
    InvalidLoss(f32),
    InvalidLearningRate(f32),
    InvalidScalar {
        kind: &'static str,
        value: f32,
    },
    InvalidQuantity {
        kind: &'static str,
        value: i64,
    },
    InvalidInterval {
        start: usize,
        end: usize,
    },
}

impl Display for CtError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            CtError::EmptyInput(context) => write!(f, "empty input in {context}"),
            CtError::OutOfRange { kind, index, limit } => {
                write!(f, "{kind} index {index} out of range; limit is {limit}")
            }
            CtError::ShapeMismatch { op, expected, got } => {
                write!(f, "shape mismatch in {op}: expected {expected}, got {got}")
            }
            CtError::InvalidProbability(context) => {
                write!(f, "invalid probability distribution in {context}")
            }
            CtError::InvalidLoss(value) => write!(f, "invalid loss value {value}"),
            CtError::InvalidLearningRate(value) => write!(f, "invalid learning rate {value}"),
            CtError::InvalidScalar { kind, value } => {
                write!(f, "invalid {kind} scalar {value}")
            }
            CtError::InvalidQuantity { kind, value } => {
                write!(f, "invalid {kind} quantity {value}")
            }
            CtError::InvalidInterval { start, end } => {
                write!(f, "invalid interval: start {start} is after end {end}")
            }
        }
    }
}

impl std::error::Error for CtError {}

pub type CtResult<T> = Result<T, CtError>;
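
The snapshot above pairs a `Display` implementation with a `Result` alias. As a hedged sketch of how such a pair is typically consumed, the standalone example below uses a trimmed-down, hypothetical `MiniError` with one variant copied from the snapshot; `lookup` and `sum_of_two` are illustrative helpers and do not exist in the repository.

```rust
use std::fmt;

// Hypothetical, trimmed-down stand-in for the crate's error type.
#[derive(Debug, Clone, PartialEq)]
pub enum MiniError {
    OutOfRange {
        kind: &'static str,
        index: usize,
        limit: usize,
    },
}

impl fmt::Display for MiniError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            MiniError::OutOfRange { kind, index, limit } => {
                write!(f, "{kind} index {index} out of range; limit is {limit}")
            }
        }
    }
}

impl std::error::Error for MiniError {}

pub type MiniResult<T> = Result<T, MiniError>;

/// Looks up one value, reporting the failing index and the table size.
pub fn lookup(table: &[f32], index: usize) -> MiniResult<f32> {
    table.get(index).copied().ok_or(MiniError::OutOfRange {
        kind: "token",
        index,
        limit: table.len(),
    })
}

/// Adds two looked-up values; `?` returns the first error unchanged.
pub fn sum_of_two(table: &[f32], a: usize, b: usize) -> MiniResult<f32> {
    Ok(lookup(table, a)? + lookup(table, b)?)
}
```

With a three-element table, `sum_of_two(&table, 7, 1)` fails on the first lookup, and formatting the error with `to_string()` yields `token index 7 out of range; limit is 3`.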

src/domain.rs

use crate::error::{CtError, CtResult};

/// A vocabulary index. It is intentionally not a raw `usize` in public APIs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TokenId(usize);

impl TokenId {
    pub fn new(index: usize) -> Self {
        Self(index)
    }

    pub fn index(&self) -> usize {
        self.0
    }
}

impl From<usize> for TokenId {
    fn from(value: usize) -> Self {
        Self::new(value)
    }
}

/// A sequence of tokens before it has been converted into training pairs.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TokenSequence(Vec<TokenId>);

impl TokenSequence {
    pub fn new(tokens: impl IntoIterator<Item = TokenId>) -> CtResult<Self> {
        let tokens = tokens.into_iter().collect::<Vec<_>>();

        if tokens.is_empty() {
            return Err(CtError::EmptyInput("token sequence"));
        }

        Ok(Self(tokens))
    }

    pub fn from_indices(indices: impl IntoIterator<Item = usize>) -> CtResult<Self> {
        Self::new(indices.into_iter().map(TokenId::new))
    }

    pub fn as_slice(&self) -> &[TokenId] {
        &self.0
    }
}

/// A dense feature vector.
#[derive(Debug, Clone, PartialEq)]
pub struct Vector(Vec<f32>);

impl Vector {
    pub fn new(values: Vec<f32>) -> Self {
        Self(values)
    }

    pub fn as_slice(&self) -> &[f32] {
        &self.0
    }
}

/// Unnormalized model scores.
#[derive(Debug, Clone, PartialEq)]
pub struct Logits(Vec<f32>);

impl Logits {
    pub fn new(values: Vec<f32>) -> Self {
        Self(values)
    }

    pub fn as_slice(&self) -> &[f32] {
        &self.0
    }
}

/// A validated probability distribution.
#[derive(Debug, Clone, PartialEq)]
pub struct Distribution(Vec<f32>);

impl Distribution {
    pub fn new(probabilities: Vec<f32>) -> CtResult<Self> {
        if probabilities.is_empty() {
            return Err(CtError::EmptyInput("distribution"));
        }

        let sum: f32 = probabilities.iter().sum();
        let all_valid = probabilities
            .iter()
            .all(|probability| probability.is_finite() && *probability >= 0.0);

        if !all_valid || !approx_eq(sum, 1.0, 1e-4) {
            return Err(CtError::InvalidProbability("distribution constructor"));
        }

        Ok(Self(probabilities))
    }

    pub fn as_slice(&self) -> &[f32] {
        &self.0
    }
}

/// A non-negative scalar objective value.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Loss(f32);

impl Loss {
    pub fn new(value: f32) -> CtResult<Self> {
        if !value.is_finite() || value < 0.0 {
            return Err(CtError::InvalidLoss(value));
        }

        Ok(Self(value))
    }

    pub fn value(&self) -> f32 {
        self.0
    }
}

/// Number of vocabulary entries.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct VocabSize(usize);

impl VocabSize {
    pub fn new(value: usize) -> CtResult<Self> {
        if value == 0 {
            return Err(CtError::EmptyInput("vocabulary"));
        }

        Ok(Self(value))
    }

    pub fn value(&self) -> usize {
        self.0
    }
}

/// Width of each embedding vector.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ModelDimension(usize);

impl ModelDimension {
    pub fn new(value: usize) -> CtResult<Self> {
        if value == 0 {
            return Err(CtError::EmptyInput("model dimension"));
        }

        Ok(Self(value))
    }

    pub fn value(&self) -> usize {
        self.0
    }
}

/// Positive optimizer step size.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct LearningRate(f32);

impl LearningRate {
    pub fn new(value: f32) -> CtResult<Self> {
        if !value.is_finite() || value <= 0.0 {
            return Err(CtError::InvalidLearningRate(value));
        }

        Ok(Self(value))
    }

    pub fn value(&self) -> f32 {
        self.0
    }
}

/// Categorical product object: `A x B`.
#[derive(Debug, Clone, PartialEq)]
pub struct Product<A, B> {
    first: A,
    second: B,
}

impl<A, B> Product<A, B> {
    pub fn new(first: A, second: B) -> Self {
        Self { first, second }
    }

    pub fn first(&self) -> &A {
        &self.first
    }

    pub fn second(&self) -> &B {
        &self.second
    }

    pub fn into_parts(self) -> (A, B) {
        (self.first, self.second)
    }
}

pub type TrainingExample = Product<TokenId, TokenId>;

/// Non-empty next-token training pairs.
#[derive(Debug, Clone, PartialEq)]
pub struct TrainingSet(Vec<TrainingExample>);

impl TrainingSet {
    pub fn new(examples: impl IntoIterator<Item = TrainingExample>) -> CtResult<Self> {
        let examples = examples.into_iter().collect::<Vec<_>>();

        if examples.is_empty() {
            return Err(CtError::EmptyInput("training set"));
        }

        Ok(Self(examples))
    }

    pub fn examples(&self) -> &[TrainingExample] {
        &self.0
    }

    pub fn len(&self) -> usize {
        self.0.len()
    }

    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }
}

/// Tiny model parameters for an embedding plus language-model head.
#[derive(Debug, Clone, PartialEq)]
pub struct Parameters {
    pub(crate) embedding: Vec<Vec<f32>>,
    pub(crate) lm_head: Vec<Vec<f32>>,
    pub(crate) bias: Vec<f32>,
}

impl Parameters {
    pub fn init(vocab_size: VocabSize, d_model: ModelDimension) -> Self {
        let vocab_size = vocab_size.value();
        let d_model = d_model.value();

        Self {
            embedding: init_matrix(vocab_size, d_model, 0.2, 1),
            lm_head: init_matrix(d_model, vocab_size, 0.2, 2),
            bias: vec![0.0; vocab_size],
        }
    }

    pub fn vocab_size(&self) -> usize {
        self.bias.len()
    }

    pub fn d_model(&self) -> usize {
        self.embedding.first().map_or(0, Vec::len)
    }

    pub fn embedding_table(&self) -> &[Vec<f32>] {
        &self.embedding
    }

    pub fn lm_head(&self) -> &[Vec<f32>] {
        &self.lm_head
    }

    pub fn bias(&self) -> &[f32] {
        &self.bias
    }
}

pub(crate) fn init_matrix(rows: usize, cols: usize, scale: f32, seed: usize) -> Vec<Vec<f32>> {
    let mut out = vec![vec![0.0; cols]; rows];

    for (row_index, row) in out.iter_mut().enumerate() {
        for (col_index, value) in row.iter_mut().enumerate() {
            let raw = ((row_index * cols + col_index) * 37 + seed * 101) % 1000;
            let unit = raw as f32 / 1000.0;
            *value = (unit - 0.5) * scale;
        }
    }

    out
}

pub(crate) fn approx_eq(a: f32, b: f32, eps: f32) -> bool {
    (a - b).abs() <= eps
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn distribution_rejects_non_normalized_values() {
        let result = Distribution::new(vec![0.4, 0.4]);

        assert!(matches!(result, Err(CtError::InvalidProbability(_))));
    }

    #[test]
    fn token_sequence_rejects_empty_input() {
        let result = TokenSequence::new(vec![]);

        assert!(matches!(result, Err(CtError::EmptyInput("token sequence"))));
    }
}
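
The `init_matrix` helper in the snapshot above fills each entry with a deterministic formula rather than a random draw. As a worked example, the sketch below isolates that per-entry formula in a hypothetical `init_entry` function (the name and signature are assumptions for illustration) so the arithmetic can be checked by hand.

```rust
/// Hypothetical helper mirroring the per-entry formula of `init_matrix`.
pub fn init_entry(row: usize, col: usize, cols: usize, scale: f32, seed: usize) -> f32 {
    // Flatten (row, col) to one index, mix in the seed, keep three digits.
    let raw = ((row * cols + col) * 37 + seed * 101) % 1000;
    // Map 0..=999 to [0.0, 0.999], center around zero, then scale.
    let unit = raw as f32 / 1000.0;
    (unit - 0.5) * scale
}

// Worked arithmetic for cols = 2, scale = 0.2, seed = 1:
//   (0, 0): raw = (0 * 37 + 101) % 1000 = 101
//           value = (0.101 - 0.5) * 0.2, approximately -0.0798
//   (1, 1): raw = (3 * 37 + 101) % 1000 = 212
//           value = (0.212 - 0.5) * 0.2, approximately -0.0576
```

Because `raw` is always in `0..=999`, every entry lies between `-0.5 * scale` and `0.499 * scale`, and the same arguments always produce the same value, which keeps the tests in the snapshots reproducible.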

src/category.rs

use std::marker::PhantomData;

use crate::error::CtResult;

/// A typed category-theory arrow: `Input -> Output`.
pub trait Morphism<Input, Output> {
    fn name(&self) -> &'static str;
    fn apply(&self, input: Input) -> CtResult<Output>;
}

/// Identity morphism: `id_A : A -> A`.
#[derive(Debug, Clone, Copy)]
pub struct Identity<T> {
    _marker: PhantomData<T>,
}

impl<T> Identity<T> {
    pub fn new() -> Self {
        Self {
            _marker: PhantomData,
        }
    }
}

impl<T> Default for Identity<T> {
    fn default() -> Self {
        Self::new()
    }
}

impl<T> Morphism<T, T> for Identity<T> {
    fn name(&self) -> &'static str {
        "identity"
    }

    fn apply(&self, input: T) -> CtResult<T> {
        Ok(input)
    }
}

/// Composition of two morphisms: if `f : A -> B` and `g : B -> C`, this is
/// `g after f : A -> C`.
#[derive(Debug, Clone)]
pub struct Compose<F, G, Middle> {
    first: F,
    second: G,
    _middle: PhantomData<Middle>,
}

impl<F, G, Middle> Compose<F, G, Middle> {
    pub fn new(first: F, second: G) -> Self {
        Self {
            first,
            second,
            _middle: PhantomData,
        }
    }
}

impl<Input, Middle, Output, F, G> Morphism<Input, Output> for Compose<F, G, Middle>
where
    F: Morphism<Input, Middle>,
    G: Morphism<Middle, Output>,
{
    fn name(&self) -> &'static str {
        "composition"
    }

    fn apply(&self, input: Input) -> CtResult<Output> {
        let middle = self.first.apply(input)?;
        self.second.apply(middle)
    }
}

/// Endomorphism: a morphism from a type back to itself.
pub trait Endomorphism<T>: Morphism<T, T> {}

impl<T, M> Endomorphism<T> for M where M: Morphism<T, T> {}

/// How many times to repeat an endomorphism.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StepCount(usize);

impl StepCount {
    pub fn new(value: usize) -> Self {
        Self(value)
    }

    pub fn value(&self) -> usize {
        self.0
    }
}

/// Repeatedly apply an endomorphism: `A0 -> A1 -> ... -> An`.
pub fn apply_endomorphism_n_times<T, E>(endo: &E, mut value: T, count: StepCount) -> CtResult<T>
where
    E: Endomorphism<T>,
{
    for _ in 0..count.value() {
        value = endo.apply(value)?;
    }

    Ok(value)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::CtError;

    #[derive(Debug, Clone, Copy)]
    struct AddOne;

    impl Morphism<i32, i32> for AddOne {
        fn name(&self) -> &'static str {
            "add_one"
        }

        fn apply(&self, input: i32) -> CtResult<i32> {
            Ok(input + 1)
        }
    }

    #[derive(Debug, Clone, Copy)]
    struct Double;

    impl Morphism<i32, i32> for Double {
        fn name(&self) -> &'static str {
            "double"
        }

        fn apply(&self, input: i32) -> CtResult<i32> {
            Ok(input * 2)
        }
    }

    #[derive(Debug, Clone, Copy)]
    struct Fail;

    impl Morphism<i32, i32> for Fail {
        fn name(&self) -> &'static str {
            "fail"
        }

        fn apply(&self, _input: i32) -> CtResult<i32> {
            Err(CtError::InvalidQuantity {
                kind: "test morphism",
                value: -1,
            })
        }
    }

    #[test]
    fn identity_returns_the_same_value() -> CtResult<()> {
        let value = String::from("same");

        assert_eq!(Identity::<String>::new().apply(value.clone())?, value);
        Ok(())
    }

    #[test]
    fn identity_composes_without_changing_behavior() -> CtResult<()> {
        let left_identity = Compose::<_, _, i32>::new(Identity::<i32>::new(), AddOne);
        let right_identity = Compose::<_, _, i32>::new(AddOne, Identity::<i32>::new());

        assert_eq!(left_identity.apply(41)?, AddOne.apply(41)?);
        assert_eq!(right_identity.apply(41)?, AddOne.apply(41)?);
        Ok(())
    }

    #[test]
    fn composition_applies_first_then_second() -> CtResult<()> {
        let add_then_double = Compose::<_, _, i32>::new(AddOne, Double);

        assert_eq!(add_then_double.apply(4)?, 10);
        Ok(())
    }

    #[test]
    fn composition_returns_the_first_error() {
        let composed = Compose::<_, _, i32>::new(Fail, AddOne);

        assert!(matches!(
            composed.apply(4),
            Err(CtError::InvalidQuantity {
                kind: "test morphism",
                value: -1,
            })
        ));
    }
}
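
Because `apply` in the snapshot above returns a result, `Compose` is composition of fallible functions: the second arrow runs only if the first succeeds. This is sometimes described as Kleisli composition for `Result`. The standalone sketch below restates the idea with plain functions and checks associativity on sample inputs. The `compose` function, the `String` error type, and the three sample arrows are hypothetical and are not the crate's API.

```rust
/// Composes two fallible functions: run `first`, then feed its output to `second`.
pub fn compose<A, B, C, E>(
    first: impl Fn(A) -> Result<B, E>,
    second: impl Fn(B) -> Result<C, E>,
) -> impl Fn(A) -> Result<C, E> {
    move |input: A| -> Result<C, E> {
        // `?` stops here if the first arrow fails, so `second` never runs.
        let middle = first(input)?;
        second(middle)
    }
}

pub fn add_one(x: i32) -> Result<i32, String> {
    Ok(x + 1)
}

pub fn double(x: i32) -> Result<i32, String> {
    Ok(x * 2)
}

pub fn reject_negative(x: i32) -> Result<i32, String> {
    if x < 0 {
        Err(format!("negative input {x}"))
    } else {
        Ok(x)
    }
}

/// Checks `(h after g) after f == h after (g after f)` at one input.
pub fn associativity_holds_at(x: i32) -> bool {
    let left = compose(compose(add_one, double), reject_negative);
    let right = compose(add_one, compose(double, reject_negative));
    left(x) == right(x)
}
```

A check at a few inputs is evidence rather than proof of the law, but it exercises both the success path and the error path, which is the same style of check the snapshot's tests use for the identity laws.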

src/ml.rs

use crate::category::{Compose, Morphism};
use crate::domain::{
    Distribution, Logits, Loss, Parameters, Product, TokenId, TokenSequence, TrainingSet, Vector,
    approx_eq,
};
use crate::error::{CtError, CtResult};

/// Turns adjacent tokens into next-token training examples.
#[derive(Debug, Clone)]
pub struct DatasetWindowing;

impl Morphism<TokenSequence, TrainingSet> for DatasetWindowing {
    fn name(&self) -> &'static str {
        "dataset_windowing"
    }

    fn apply(&self, tokens: TokenSequence) -> CtResult<TrainingSet> {
        if tokens.as_slice().len() < 2 {
            return Err(CtError::EmptyInput(
                "dataset windowing requires at least 2 tokens",
            ));
        }

        TrainingSet::new(
            tokens
                .as_slice()
                .windows(2)
                .map(|pair| Product::new(pair[0], pair[1])),
        )
    }
}

/// Morphism from token id to embedding vector.
#[derive(Debug, Clone)]
pub struct Embedding {
    table: Vec<Vec<f32>>,
}

impl Embedding {
    pub fn from_parameters(params: &Parameters) -> Self {
        Self {
            table: params.embedding_table().to_vec(),
        }
    }
}

impl Morphism<TokenId, Vector> for Embedding {
    fn name(&self) -> &'static str {
        "embedding"
    }

    fn apply(&self, token: TokenId) -> CtResult<Vector> {
        let Some(row) = self.table.get(token.index()) else {
            return Err(CtError::OutOfRange {
                kind: "token",
                index: token.index(),
                limit: self.table.len(),
            });
        };

        Ok(Vector::new(row.clone()))
    }
}

/// Linear projection from hidden vector to vocabulary logits.
#[derive(Debug, Clone)]
pub struct LinearToLogits {
    weight: Vec<Vec<f32>>,
    bias: Vec<f32>,
}

impl LinearToLogits {
    pub fn from_parameters(params: &Parameters) -> Self {
        Self {
            weight: params.lm_head().to_vec(),
            bias: params.bias().to_vec(),
        }
    }

    pub(crate) fn from_parts(weight: Vec<Vec<f32>>, bias: Vec<f32>) -> Self {
        Self { weight, bias }
    }
}

impl Morphism<Vector, Logits> for LinearToLogits {
    fn name(&self) -> &'static str {
        "linear_to_logits"
    }

    fn apply(&self, input: Vector) -> CtResult<Logits> {
        let d_model = input.as_slice().len();
        let vocab_size = self.bias.len();

        if self.weight.len() != d_model {
            return Err(CtError::ShapeMismatch {
                op: "linear layer",
                expected: format!("weight rows == input dim {d_model}"),
                got: format!("weight rows {}", self.weight.len()),
            });
        }

        let mut out = self.bias.clone();

        for (feature, input_value) in input.as_slice().iter().enumerate() {
            if self.weight[feature].len() != vocab_size {
                return Err(CtError::ShapeMismatch {
                    op: "linear layer",
                    expected: format!("weight cols == vocab size {vocab_size}"),
                    got: format!("weight cols {}", self.weight[feature].len()),
                });
            }

            for (vocab_id, output_value) in out.iter_mut().enumerate() {
                *output_value += input_value * self.weight[feature][vocab_id];
            }
        }

        Ok(Logits::new(out))
    }
}

/// Converts logits to a probability distribution.
#[derive(Debug, Clone)]
pub struct Softmax;

impl Morphism<Logits, Distribution> for Softmax {
    fn name(&self) -> &'static str {
        "softmax"
    }

    fn apply(&self, logits: Logits) -> CtResult<Distribution> {
        if logits.as_slice().is_empty() {
            return Err(CtError::EmptyInput("softmax"));
        }

        let max_value = logits
            .as_slice()
            .iter()
            .copied()
            .fold(f32::NEG_INFINITY, f32::max);
        let mut exps = Vec::with_capacity(logits.as_slice().len());
        let mut sum = 0.0;

        for value in logits.as_slice() {
            let exp = (*value - max_value).exp();
            exps.push(exp);
            sum += exp;
        }

        if sum <= 0.0 || !sum.is_finite() {
            return Err(CtError::InvalidProbability("softmax"));
        }

        Distribution::new(exps.into_iter().map(|value| value / sum).collect())
    }
}

/// Negative log likelihood for `(distribution, target_token)`.
#[derive(Debug, Clone)]
pub struct CrossEntropy;

impl Morphism<Product<Distribution, TokenId>, Loss> for CrossEntropy {
    fn name(&self) -> &'static str {
        "cross_entropy"
    }

    fn apply(&self, input: Product<Distribution, TokenId>) -> CtResult<Loss> {
        let (distribution, target) = input.into_parts();

        let Some(probability) = distribution.as_slice().get(target.index()).copied() else {
            return Err(CtError::OutOfRange {
                kind: "target",
                index: target.index(),
                limit: distribution.as_slice().len(),
            });
        };

        Loss::new(-probability.max(1e-9).ln())
    }
}

/// Direct path used for a commutative-diagram check.
#[derive(Debug, Clone)]
pub struct DirectPredict {
    params: Parameters,
}

impl DirectPredict {
    pub fn new(params: Parameters) -> Self {
        Self { params }
    }
}

impl Morphism<TokenId, Distribution> for DirectPredict {
    fn name(&self) -> &'static str {
        "direct_predict"
    }

    fn apply(&self, token: TokenId) -> CtResult<Distribution> {
        let embedding = Embedding::from_parameters(&self.params).apply(token)?;
        let logits = LinearToLogits::from_parameters(&self.params).apply(embedding)?;
        Softmax.apply(logits)
    }
}

/// Average cross-entropy over a training set.
pub fn average_loss(params: &Parameters, dataset: &TrainingSet) -> CtResult<Loss> {
    let embedding = Embedding::from_parameters(params);
    let linear = LinearToLogits::from_parameters(params);
    let predict = Compose::<_, _, Vector>::new(embedding, linear);
    let predict = Compose::<_, _, Logits>::new(predict, Softmax);
    let loss_fn = CrossEntropy;

    let mut total = 0.0;

    for example in dataset.examples() {
        let distribution = predict.apply(*example.first())?;
        let loss = loss_fn.apply(Product::new(distribution, *example.second()))?;
        total += loss.value();
    }

    Loss::new(total / dataset.len() as f32)
}

/// Verifies that the composed path and direct path produce the same result.
pub fn composed_prediction_matches_direct_prediction(params: &Parameters) -> CtResult<bool> {
    let token = TokenId::new(1);

    let composed = Compose::<_, _, Vector>::new(
        Embedding::from_parameters(params),
        LinearToLogits::from_parameters(params),
    );
    let composed = Compose::<_, _, Logits>::new(composed, Softmax);
    let direct = DirectPredict::new(params.clone());

    let left_path = composed.apply(token)?;
    let right_path = direct.apply(token)?;

    Ok(left_path
        .as_slice()
        .iter()
        .zip(right_path.as_slice().iter())
        .all(|(a, b)| approx_eq(*a, *b, 1e-6)))
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::domain::{ModelDimension, VocabSize};

    #[test]
    fn dataset_windowing_builds_adjacent_pairs() -> CtResult<()> {
        let tokens = TokenSequence::from_indices([1, 2, 3])?;
        let dataset = DatasetWindowing.apply(tokens)?;

        assert_eq!(dataset.len(), 2);
        assert_eq!(dataset.examples()[0].first().index(), 1);
        assert_eq!(dataset.examples()[0].second().index(), 2);
        Ok(())
    }

    #[test]
    fn composed_and_direct_prediction_match() -> CtResult<()> {
        let params = Parameters::init(VocabSize::new(5)?, ModelDimension::new(4)?);

        assert!(composed_prediction_matches_direct_prediction(&params)?);
        Ok(())
    }

    #[test]
    fn softmax_normalizes_logits_into_distribution() -> CtResult<()> {
        let distribution = Softmax.apply(Logits::new(vec![1.0, 2.0, 3.0]))?;
        let probabilities = distribution.as_slice();
        let sum: f32 = probabilities.iter().sum();

        assert!(approx_eq(sum, 1.0, 1e-6));
        assert!(probabilities[2] > probabilities[1]);
        assert!(probabilities[1] > probabilities[0]);
        Ok(())
    }

    #[test]
    fn cross_entropy_is_lower_for_more_confident_target_probability() -> CtResult<()> {
        let confident = Distribution::new(vec![0.9, 0.1])?;
        let surprised = Distribution::new(vec![0.1, 0.9])?;

        let confident_loss = CrossEntropy.apply(Product::new(confident, TokenId::new(0)))?;
        let surprised_loss = CrossEntropy.apply(Product::new(surprised, TokenId::new(0)))?;

        assert!(confident_loss.value() < surprised_loss.value());
        Ok(())
    }
}
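
The `Softmax` morphism in the snapshot above subtracts the largest logit before exponentiating. The following minimal sketch shows why, by comparing that shifted form with an unshifted one on large inputs. All three functions are hypothetical slice-based stand-ins that skip the crate's validation (no empty-input or range checks).

```rust
/// Softmax with the maximum subtracted first, as in the snapshot above.
pub fn stable_softmax(logits: &[f32]) -> Vec<f32> {
    let max_value = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits
        .iter()
        .map(|value| (value - max_value).exp())
        .collect();
    let sum: f32 = exps.iter().sum();
    exps.into_iter().map(|value| value / sum).collect()
}

/// Softmax without the shift; `exp` overflows to infinity for large logits.
pub fn naive_softmax(logits: &[f32]) -> Vec<f32> {
    let exps: Vec<f32> = logits.iter().map(|value| value.exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.into_iter().map(|value| value / sum).collect()
}

/// Negative log likelihood of the target, with the same 1e-9 floor as the
/// snapshot. Panics if `target` is out of bounds (the crate returns an error).
pub fn cross_entropy(probabilities: &[f32], target: usize) -> f32 {
    -probabilities[target].max(1e-9).ln()
}
```

For logits `[1000.0, 1001.0, 1002.0]`, the shifted version exponentiates `[-2.0, -1.0, 0.0]` and returns the same probabilities as for `[0.0, 1.0, 2.0]`, while the unshifted version divides infinity by infinity and produces NaN. The floor in `cross_entropy` plays a similar protective role by keeping the loss finite when a probability is zero.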

src/training.rs

use crate::category::Morphism;
use crate::domain::{LearningRate, Parameters, TrainingSet, Vector};
use crate::error::{CtError, CtResult};
use crate::ml::{LinearToLogits, Softmax};

/// One full-batch optimizer update.
///
/// Categorically, this is an endomorphism:
///
/// `Parameters -> Parameters`
#[derive(Debug, Clone)]
pub struct TrainStep {
    dataset: TrainingSet,
    learning_rate: LearningRate,
}

impl TrainStep {
    pub fn new(dataset: TrainingSet, learning_rate: LearningRate) -> Self {
        Self {
            dataset,
            learning_rate,
        }
    }
}

impl Morphism<Parameters, Parameters> for TrainStep {
    fn name(&self) -> &'static str {
        "train_step_endomorphism"
    }

    fn apply(&self, params: Parameters) -> CtResult<Parameters> {
        let vocab_size = params.vocab_size();
        let d_model = params.d_model();

        if vocab_size == 0 || d_model == 0 {
            return Err(CtError::EmptyInput("parameters"));
        }

        let mut grad_embedding = vec![vec![0.0; d_model]; params.embedding.len()];
        let mut grad_lm_head = vec![vec![0.0; vocab_size]; d_model];
        let mut grad_bias = vec![0.0; vocab_size];

        for example in self.dataset.examples() {
            let input_id = example.first().index();
            let target_id = example.second().index();

            if input_id >= params.embedding.len() {
                return Err(CtError::OutOfRange {
                    kind: "input token",
                    index: input_id,
                    limit: params.embedding.len(),
                });
            }

            if target_id >= vocab_size {
                return Err(CtError::OutOfRange {
                    kind: "target token",
                    index: target_id,
                    limit: vocab_size,
                });
            }

            let x = &params.embedding[input_id];
            let logits = LinearToLogits::from_parts(params.lm_head.clone(), params.bias.clone())
                .apply(Vector::new(x.clone()))?;
            let probs = Softmax.apply(logits)?;

            let mut dlogits = probs.as_slice().to_vec();
            dlogits[target_id] -= 1.0;

            for (vocab_id, dlogit) in dlogits.iter().copied().enumerate() {
                grad_bias[vocab_id] += dlogit;

                for (feature, x_feature) in x.iter().copied().enumerate() {
                    grad_lm_head[feature][vocab_id] += x_feature * dlogit;
                }
            }

            for (feature, grad_feature) in grad_embedding[input_id].iter_mut().enumerate() {
                let dx = params.lm_head[feature]
                    .iter()
                    .zip(dlogits.iter())
                    .map(|(weight, dlogit)| weight * dlogit)
                    .sum::<f32>();

                *grad_feature += dx;
            }
        }

        let batch_scale = 1.0 / self.dataset.len() as f32;
        let learning_rate = self.learning_rate.value();
        let mut updated = params.clone();

        for (row, grad_row) in updated.embedding.iter_mut().zip(grad_embedding.iter()) {
            for (value, grad) in row.iter_mut().zip(grad_row.iter()) {
                *value -= learning_rate * grad * batch_scale;
            }
        }

        for (row, grad_row) in updated.lm_head.iter_mut().zip(grad_lm_head.iter()) {
            for (value, grad) in row.iter_mut().zip(grad_row.iter()) {
                *value -= learning_rate * grad * batch_scale;
            }
        }

        for (bias, grad) in updated.bias.iter_mut().zip(grad_bias.iter()) {
            *bias -= learning_rate * grad * batch_scale;
        }

        Ok(updated)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::category::{StepCount, apply_endomorphism_n_times};
    use crate::domain::{ModelDimension, Product, TokenId, TokenSequence, TrainingSet, VocabSize};
    use crate::ml::{DatasetWindowing, average_loss};

    #[test]
    fn repeated_training_step_reduces_loss() -> CtResult<()> {
        let tokens = TokenSequence::from_indices([1, 2, 3, 4, 1, 2, 3, 4])?;
        let dataset = DatasetWindowing.apply(tokens)?;
        let params = Parameters::init(VocabSize::new(5)?, ModelDimension::new(4)?);
        let before = average_loss(&params, &dataset)?;
        let train_step = TrainStep::new(dataset.clone(), LearningRate::new(1.0)?);
        let trained = apply_endomorphism_n_times(&train_step, params, StepCount::new(80))?;
        let after = average_loss(&trained, &dataset)?;

        assert!(after.value() < before.value());
        Ok(())
    }

    #[test]
    fn one_training_step_preserves_parameter_shape() -> CtResult<()> {
        let tokens = TokenSequence::from_indices([1, 2, 3, 4])?;
        let dataset = DatasetWindowing.apply(tokens)?;
        let params = Parameters::init(VocabSize::new(5)?, ModelDimension::new(4)?);
        let train_step = TrainStep::new(dataset, LearningRate::new(0.1)?);

        let trained = train_step.apply(params.clone())?;

        assert_eq!(trained.vocab_size(), params.vocab_size());
        assert_eq!(trained.d_model(), params.d_model());
        Ok(())
    }

    #[test]
    fn training_rejects_target_outside_vocabulary() -> CtResult<()> {
        let dataset = TrainingSet::new([Product::new(TokenId::new(0), TokenId::new(9))])?;
        let params = Parameters::init(VocabSize::new(2)?, ModelDimension::new(2)?);
        let train_step = TrainStep::new(dataset, LearningRate::new(0.1)?);

        assert!(matches!(
            train_step.apply(params),
            Err(CtError::OutOfRange {
                kind: "target token",
                index: 9,
                limit: 2,
            })
        ));

        Ok(())
    }
}
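
The `apply` method in the snapshot above builds `dlogits` by copying the softmax probabilities and subtracting `1.0` at the target index. The block below is a minimal standalone sketch of that identity, checked against a finite-difference estimate. It uses plain `f32` slices rather than the crate's `Logits` and `Distribution` types, and the helper names `softmax`, `cross_entropy`, `dlogits`, and `numeric_dlogits` are hypothetical, not part of the repository.

```rust
/// Numerically stable softmax over a slice of logits.
fn softmax(logits: &[f32]) -> Vec<f32> {
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|z| (z - max).exp()).collect();
    let total: f32 = exps.iter().sum();
    exps.iter().map(|e| e / total).collect()
}

/// Cross-entropy loss `-ln p[target]` for a single example.
fn cross_entropy(logits: &[f32], target: usize) -> f32 {
    -softmax(logits)[target].ln()
}

/// Analytic gradient of the loss with respect to the logits:
/// `softmax(z) - one_hot(target)`.
fn dlogits(logits: &[f32], target: usize) -> Vec<f32> {
    let mut grad = softmax(logits);
    grad[target] -= 1.0;
    grad
}

/// Central finite-difference estimate of the same gradient.
fn numeric_dlogits(logits: &[f32], target: usize, eps: f32) -> Vec<f32> {
    (0..logits.len())
        .map(|i| {
            let mut plus = logits.to_vec();
            let mut minus = logits.to_vec();
            plus[i] += eps;
            minus[i] -= eps;
            (cross_entropy(&plus, target) - cross_entropy(&minus, target)) / (2.0 * eps)
        })
        .collect()
}
```

Calling `dlogits(&[0.5, -1.0, 2.0], 2)` and `numeric_dlogits(&[0.5, -1.0, 2.0], 2, 1e-2)` should give vectors that agree to within a loose tolerance. The analytic entries sum to approximately zero because the probabilities sum to one.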

src/structure.rs

/// A minimal functor interface for this tutorial.
pub trait Functor<A, B> {
    type WrappedA;
    type WrappedB;

    fn fmap<F>(wrapped: Self::WrappedA, f: F) -> Self::WrappedB
    where
        F: Fn(A) -> B;
}

pub struct VecFunctor;

impl<A, B> Functor<A, B> for VecFunctor {
    type WrappedA = Vec<A>;
    type WrappedB = Vec<B>;

    fn fmap<F>(wrapped: Vec<A>, f: F) -> Vec<B>
    where
        F: Fn(A) -> B,
    {
        wrapped.into_iter().map(f).collect()
    }
}

pub struct OptionFunctor;

impl<A, B> Functor<A, B> for OptionFunctor {
    type WrappedA = Option<A>;
    type WrappedB = Option<B>;

    fn fmap<F>(wrapped: Option<A>, f: F) -> Option<B>
    where
        F: Fn(A) -> B,
    {
        wrapped.map(f)
    }
}

/// A structure-preserving conversion between wrappers.
pub trait NaturalTransformation<A> {
    type From;
    type To;

    fn transform(from: Self::From) -> Self::To;
}

/// Natural transformation from `Vec<A>` to `Option<A>` by taking the first item.
pub struct VecToFirstOption;

impl<A> NaturalTransformation<A> for VecToFirstOption {
    type From = Vec<A>;
    type To = Option<A>;

    fn transform(from: Vec<A>) -> Option<A> {
        from.into_iter().next()
    }
}

pub fn naturality_square_holds_for_first_option() -> bool {
    let xs = vec![1, 2, 3];
    let f = |x| x * 10;

    let path_top_then_right = VecToFirstOption::transform(VecFunctor::fmap(xs.clone(), f));
    let path_left_then_bottom = OptionFunctor::fmap(VecToFirstOption::transform(xs), f);

    path_top_then_right == path_left_then_bottom
}

pub trait Monoid: Sized {
    fn empty() -> Self;
    fn combine(&self, other: &Self) -> Self;
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TraceStep(&'static str);

impl TraceStep {
    pub fn new(name: &'static str) -> Self {
        Self(name)
    }

    pub fn name(&self) -> &'static str {
        self.0
    }
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PipelineTrace(Vec<TraceStep>);

impl PipelineTrace {
    pub fn from_steps(steps: impl IntoIterator<Item = TraceStep>) -> Self {
        Self(steps.into_iter().collect())
    }

    pub fn names(&self) -> Vec<&'static str> {
        self.0.iter().map(TraceStep::name).collect()
    }
}

impl Monoid for PipelineTrace {
    fn empty() -> Self {
        PipelineTrace(vec![])
    }

    fn combine(&self, other: &Self) -> Self {
        let mut combined = self.0.clone();
        combined.extend_from_slice(&other.0);
        PipelineTrace(combined)
    }
}

pub fn monoid_laws_hold_for_pipeline_trace() -> bool {
    let a = PipelineTrace::from_steps(vec![TraceStep::new("embedding")]);
    let b = PipelineTrace::from_steps(vec![TraceStep::new("linear")]);
    let c = PipelineTrace::from_steps(vec![TraceStep::new("softmax")]);
    let identity = PipelineTrace::empty();

    let left_identity = identity.combine(&a) == a;
    let right_identity = a.combine(&identity) == a;
    let associativity = a.combine(&b).combine(&c) == a.combine(&b.combine(&c));

    left_identity && right_identity && associativity
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn vec_functor_preserves_identity_for_values() {
        let values = vec![1, 2, 3];

        assert_eq!(VecFunctor::fmap(values.clone(), |value| value), values);
    }

    #[test]
    fn vec_functor_preserves_composition_for_values() {
        let values = vec![1, 2, 3];
        let add_one = |value| value + 1;
        let double = |value| value * 2;

        let map_then_map = VecFunctor::fmap(VecFunctor::fmap(values.clone(), add_one), double);
        let map_composed = VecFunctor::fmap(values, |value| double(add_one(value)));

        assert_eq!(map_then_map, map_composed);
    }

    #[test]
    fn option_functor_preserves_absence() {
        let missing = OptionFunctor::fmap(None::<i32>, |value| value * 10);

        assert_eq!(missing, None);
    }

    #[test]
    fn naturality_square_commutes() {
        assert!(naturality_square_holds_for_first_option());
    }

    #[test]
    fn pipeline_trace_obeys_monoid_laws() {
        assert!(monoid_laws_hold_for_pipeline_trace());
    }
}
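
Because `combine` is associative and `empty` is its identity, any number of traces can be folded into one without caring how the calls are grouped. The block below is a minimal sketch of that idea. It re-declares a trait with the same shape as the snapshot's `Monoid` so that it compiles on its own. `Trace` is a simplified stand-in for `PipelineTrace`, and `combine_all` is a hypothetical helper that does not appear in the repository.

```rust
/// Same shape as the snapshot's `Monoid` trait, re-declared so this block
/// compiles without the crate.
trait Monoid: Sized {
    fn empty() -> Self;
    fn combine(&self, other: &Self) -> Self;
}

/// Hypothetical stand-in for `PipelineTrace`: an ordered list of step names.
#[derive(Debug, Clone, PartialEq, Eq)]
struct Trace(Vec<&'static str>);

impl Monoid for Trace {
    fn empty() -> Self {
        Trace(Vec::new())
    }

    fn combine(&self, other: &Self) -> Self {
        let mut steps = self.0.clone();
        steps.extend_from_slice(&other.0);
        Trace(steps)
    }
}

/// Folds a slice of monoid values left to right, starting from the identity.
/// An empty slice yields `M::empty()`.
fn combine_all<M: Monoid>(items: &[M]) -> M {
    items
        .iter()
        .fold(M::empty(), |acc, item| acc.combine(item))
}
```

Folding `[Trace(vec!["embedding"]), Trace(vec!["linear"]), Trace(vec!["softmax"])]` with `combine_all` yields a single trace listing the three steps in order.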

src/calculus.rs

use crate::error::{CtError, CtResult};

#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Scalar(f32);

impl Scalar {
    pub fn new(value: f32) -> CtResult<Self> {
        if !value.is_finite() {
            return Err(CtError::InvalidLoss(value));
        }

        Ok(Self(value))
    }

    pub fn value(&self) -> f32 {
        self.0
    }
}

#[derive(Debug, Clone, Copy, PartialEq)]
pub struct LocalGradient(f32);

impl LocalGradient {
    pub fn new(value: f32) -> CtResult<Self> {
        if !value.is_finite() {
            return Err(CtError::InvalidLoss(value));
        }

        Ok(Self(value))
    }

    pub fn value(&self) -> f32 {
        self.0
    }
}

#[derive(Debug, Clone, Copy)]
pub struct MulOp;

impl MulOp {
    pub fn forward(&self, x: Scalar, y: Scalar) -> CtResult<Scalar> {
        Scalar::new(x.value() * y.value())
    }

    /// Given upstream gradient dL/dz, return `(dL/dx, dL/dy)` for `z = x * y`.
    pub fn backward(
        &self,
        x: Scalar,
        y: Scalar,
        upstream: LocalGradient,
    ) -> CtResult<(LocalGradient, LocalGradient)> {
        let dz_dx = y.value();
        let dz_dy = x.value();

        Ok((
            LocalGradient::new(upstream.value() * dz_dx)?,
            LocalGradient::new(upstream.value() * dz_dy)?,
        ))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn multiply_backward_returns_local_chain_rule_gradients() -> CtResult<()> {
        let mul = MulOp;
        let x = Scalar::new(2.0)?;
        let y = Scalar::new(3.0)?;
        let upstream = LocalGradient::new(1.0)?;
        let (dl_dx, dl_dy) = mul.backward(x, y, upstream)?;

        assert_eq!(dl_dx.value(), 3.0);
        assert_eq!(dl_dy.value(), 2.0);
        Ok(())
    }

    #[test]
    fn multiply_backward_scales_with_upstream_gradient() -> CtResult<()> {
        let mul = MulOp;
        let x = Scalar::new(2.0)?;
        let y = Scalar::new(3.0)?;
        let upstream = LocalGradient::new(4.0)?;
        let (dl_dx, dl_dy) = mul.backward(x, y, upstream)?;

        assert_eq!(dl_dx.value(), 12.0);
        assert_eq!(dl_dy.value(), 8.0);
        Ok(())
    }
}
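
`MulOp::backward` handles one multiplication node. The chain rule appears when the gradient returned by one node is passed as the upstream gradient of the node before it. The block below is a minimal sketch for `l = (x * y) * w`. It uses bare `f32` values instead of the crate's `Scalar` and `LocalGradient` wrappers, and the function names `mul_backward`, `forward`, and `backward` are hypothetical.

```rust
/// Local backward rule for `z = x * y`: returns `(dL/dx, dL/dy)` given `dL/dz`.
fn mul_backward(x: f32, y: f32, upstream: f32) -> (f32, f32) {
    (upstream * y, upstream * x)
}

/// Forward pass for `l = (x * y) * w`.
fn forward(x: f32, y: f32, w: f32) -> f32 {
    (x * y) * w
}

/// Backward pass for `l = (x * y) * w`, returning `(dl/dx, dl/dy, dl/dw)`.
fn backward(x: f32, y: f32, w: f32) -> (f32, f32, f32) {
    let z = x * y;
    // Outer node l = z * w, seeded with dl/dl = 1.
    let (dl_dz, dl_dw) = mul_backward(z, w, 1.0);
    // Inner node z = x * y receives dl/dz as its upstream gradient.
    let (dl_dx, dl_dy) = mul_backward(x, y, dl_dz);
    (dl_dx, dl_dy, dl_dw)
}
```

With `x = 2.0`, `y = 3.0`, and `w = 4.0`, the inner node receives an upstream gradient of `4.0`, which is the same situation as the snapshot's `multiply_backward_scales_with_upstream_gradient` test.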

src/attention.rs
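
Before the full file, here is a compact sketch of the first four boundaries its module comment lists: scaled query-key scores, masking, row-wise softmax, and value mixing. It works on plain `Vec<f32>` rows and skips all of the shape validation that the typed wrappers in the file perform. The helper names `dot`, `softmax`, and `masked_attention` are hypothetical, and the sketch assumes that all shapes agree and that every mask row allows at least one key.

```rust
/// Stand-in for the file's `MASKED_SCORE`: large, negative, and still finite.
const MASKED_SCORE: f32 = -1_000_000.0;

/// Dot product of two equal-length slices.
fn dot(left: &[f32], right: &[f32]) -> f32 {
    left.iter().zip(right).map(|(l, r)| l * r).sum()
}

/// Numerically stable softmax over one score row.
fn softmax(row: &[f32]) -> Vec<f32> {
    let max = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = row.iter().map(|s| (s - max).exp()).collect();
    let total: f32 = exps.iter().sum();
    exps.iter().map(|e| e / total).collect()
}

/// Single-head masked attention over plain rows.
/// `mask[q][k] == true` means query `q` may attend to key `k`.
fn masked_attention(
    queries: &[Vec<f32>],
    keys: &[Vec<f32>],
    values: &[Vec<f32>],
    mask: &[Vec<bool>],
) -> Vec<Vec<f32>> {
    // Scores are divided by the square root of the head dimension.
    let scale = (queries[0].len() as f32).sqrt();

    queries
        .iter()
        .zip(mask)
        .map(|(query, mask_row)| {
            // Scaled scores, with disallowed positions replaced by MASKED_SCORE.
            let scores: Vec<f32> = keys
                .iter()
                .zip(mask_row)
                .map(|(key, allowed)| {
                    if *allowed {
                        dot(query, key) / scale
                    } else {
                        MASKED_SCORE
                    }
                })
                .collect();
            let weights = softmax(&scores);

            // Mix the value rows using the attention weights.
            let mut output = vec![0.0; values[0].len()];
            for (weight, value) in weights.iter().zip(values) {
                for (out, v) in output.iter_mut().zip(value) {
                    *out += weight * v;
                }
            }
            output
        })
        .collect()
}
```

With a causal mask `[[true, false], [true, true]]`, the first query can only see the first key, so its output row equals the first value row. The second output row is a weighted blend of both value rows.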

//! Tiny typed attention boundary for the Transformer roadmap.
//!
//! This module does not implement a full Transformer. It makes the first small
//! attention-specific shapes explicit:
//!
//! - projected queries and keys become query-by-key scores,
//! - masks turn illegal score positions into negligible softmax inputs,
//! - query-by-key scores become row-wise attention probabilities,
//! - attention probabilities mix value vectors into output vectors,
//! - multiple head outputs concatenate into a multi-head output,
//! - the concatenated heads project back into a hidden sequence width,
//! - residual addition preserves the hidden sequence boundary,
//! - layer normalization preserves the hidden sequence boundary,
//! - a position-wise feed-forward map preserves the hidden sequence boundary,
//! - positional encoding adds position information while preserving shape,
//! - a single-head block sketch composes those boundaries end to end,
//! - a masked multi-head block accepts attention masks at the block boundary,
//! - a structured parameter object owns position, block, and readout pieces,
//! - a training-state object owns parameters, learning rate, and step count,
//! - a composed block train step updates readout, feed-forward,
//!   normalization, and attention projection parameters from sequence targets.

use crate::category::{Morphism, StepCount};
use crate::domain::{
    Distribution, LearningRate, Logits, Loss, ModelDimension, Product, TokenSequence, Vector,
    VocabSize,
};
use crate::error::{CtError, CtResult};
use crate::ml::Softmax;

const MASKED_SCORE: f32 = -1_000_000.0;

/// Number of positions in a sequence.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SequenceLength(usize);

impl SequenceLength {
    pub fn new(value: usize) -> CtResult<Self> {
        if value == 0 {
            return Err(CtError::EmptyInput("sequence length"));
        }

        Ok(Self(value))
    }

    pub fn value(&self) -> usize {
        self.0
    }
}

/// Number of parallel attention heads.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HeadCount(usize);

impl HeadCount {
    pub fn new(value: usize) -> CtResult<Self> {
        if value == 0 {
            return Err(CtError::EmptyInput("head count"));
        }

        Ok(Self(value))
    }

    pub fn value(&self) -> usize {
        self.0
    }
}

/// Width of one attention head.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HeadDimension(usize);

impl HeadDimension {
    pub fn new(value: usize) -> CtResult<Self> {
        if value == 0 {
            return Err(CtError::EmptyInput("head dimension"));
        }

        Ok(Self(value))
    }

    pub fn value(&self) -> usize {
        self.0
    }
}

/// Positive stabilizer used in layer normalization.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct NormalizationEpsilon(f32);

impl NormalizationEpsilon {
    pub fn new(value: f32) -> CtResult<Self> {
        if !value.is_finite() || value <= 0.0 {
            return Err(CtError::ShapeMismatch {
                op: "normalization epsilon",
                expected: "positive finite epsilon".to_string(),
                got: format!("epsilon {value}"),
            });
        }

        Ok(Self(value))
    }

    pub fn value(&self) -> f32 {
        self.0
    }
}

/// Projected query vectors for one attention head.
#[derive(Debug, Clone, PartialEq)]
pub struct QuerySequence {
    sequence_len: SequenceLength,
    head_dimension: HeadDimension,
    rows: Vec<Vector>,
}

impl QuerySequence {
    pub fn new(rows: Vec<Vec<f32>>) -> CtResult<Self> {
        let matrix = AttentionVectorRows::new("query sequence", rows)?;

        Ok(Self {
            sequence_len: matrix.sequence_len,
            head_dimension: matrix.head_dimension,
            rows: matrix.rows,
        })
    }

    pub fn sequence_len(&self) -> SequenceLength {
        self.sequence_len
    }

    pub fn head_dimension(&self) -> HeadDimension {
        self.head_dimension
    }

    pub fn rows(&self) -> &[Vector] {
        &self.rows
    }
}

/// Projected key vectors for one attention head.
#[derive(Debug, Clone, PartialEq)]
pub struct KeySequence {
    sequence_len: SequenceLength,
    head_dimension: HeadDimension,
    rows: Vec<Vector>,
}

impl KeySequence {
    pub fn new(rows: Vec<Vec<f32>>) -> CtResult<Self> {
        let matrix = AttentionVectorRows::new("key sequence", rows)?;

        Ok(Self {
            sequence_len: matrix.sequence_len,
            head_dimension: matrix.head_dimension,
            rows: matrix.rows,
        })
    }

    pub fn sequence_len(&self) -> SequenceLength {
        self.sequence_len
    }

    pub fn head_dimension(&self) -> HeadDimension {
        self.head_dimension
    }

    pub fn rows(&self) -> &[Vector] {
        &self.rows
    }
}

/// Projected value vectors for one attention head.
#[derive(Debug, Clone, PartialEq)]
pub struct ValueSequence {
    sequence_len: SequenceLength,
    head_dimension: HeadDimension,
    rows: Vec<Vector>,
}

impl ValueSequence {
    pub fn new(rows: Vec<Vec<f32>>) -> CtResult<Self> {
        let matrix = AttentionVectorRows::new("value sequence", rows)?;

        Ok(Self {
            sequence_len: matrix.sequence_len,
            head_dimension: matrix.head_dimension,
            rows: matrix.rows,
        })
    }

    pub fn sequence_len(&self) -> SequenceLength {
        self.sequence_len
    }

    pub fn head_dimension(&self) -> HeadDimension {
        self.head_dimension
    }

    pub fn rows(&self) -> &[Vector] {
        &self.rows
    }
}

/// Hidden vectors over sequence positions.
#[derive(Debug, Clone, PartialEq)]
pub struct HiddenSequence {
    sequence_len: SequenceLength,
    model_dimension: ModelDimension,
    rows: Vec<Vector>,
}

impl HiddenSequence {
    pub fn new(rows: Vec<Vec<f32>>) -> CtResult<Self> {
        let matrix = AttentionVectorRows::new("hidden sequence", rows)?;

        Ok(Self {
            sequence_len: matrix.sequence_len,
            model_dimension: ModelDimension::new(matrix.head_dimension.value())?,
            rows: matrix.rows,
        })
    }

    pub fn sequence_len(&self) -> SequenceLength {
        self.sequence_len
    }

    pub fn model_dimension(&self) -> ModelDimension {
        self.model_dimension
    }

    pub fn rows(&self) -> &[Vector] {
        &self.rows
    }
}

/// A finite table of position vectors added to hidden states.
#[derive(Debug, Clone, PartialEq)]
pub struct PositionalEncoding {
    max_sequence_len: SequenceLength,
    model_dimension: ModelDimension,
    rows: Vec<Vector>,
}

impl PositionalEncoding {
    pub fn new(rows: Vec<Vec<f32>>) -> CtResult<Self> {
        let matrix = AttentionVectorRows::new("positional encoding", rows)?;

        Ok(Self {
            max_sequence_len: matrix.sequence_len,
            model_dimension: ModelDimension::new(matrix.head_dimension.value())?,
            rows: matrix.rows,
        })
    }

    pub fn max_sequence_len(&self) -> SequenceLength {
        self.max_sequence_len
    }

    pub fn model_dimension(&self) -> ModelDimension {
        self.model_dimension
    }
}

#[derive(Debug, Clone, PartialEq)]
struct AttentionVectorRows {
    sequence_len: SequenceLength,
    head_dimension: HeadDimension,
    rows: Vec<Vector>,
}

impl AttentionVectorRows {
    fn new(kind: &'static str, rows: Vec<Vec<f32>>) -> CtResult<Self> {
        if rows.is_empty() {
            return Err(CtError::EmptyInput(kind));
        }

        let head_dimension = rows[0].len();

        if head_dimension == 0 {
            return Err(CtError::EmptyInput("attention vector row"));
        }

        for row in &rows {
            if row.len() != head_dimension {
                return Err(CtError::ShapeMismatch {
                    op: kind,
                    expected: format!("all rows have {head_dimension} columns"),
                    got: format!("row with {} columns", row.len()),
                });
            }

            if row.iter().any(|value| !value.is_finite()) {
                return Err(CtError::ShapeMismatch {
                    op: kind,
                    expected: "all vector values are finite".to_string(),
                    got: "non-finite vector value".to_string(),
                });
            }
        }

        Ok(Self {
            sequence_len: SequenceLength::new(rows.len())?,
            head_dimension: HeadDimension::new(head_dimension)?,
            rows: rows.into_iter().map(Vector::new).collect(),
        })
    }
}

/// Query-by-key scores before row-wise softmax.
#[derive(Debug, Clone, PartialEq)]
pub struct AttentionScores {
    query_len: SequenceLength,
    key_len: SequenceLength,
    rows: Vec<Logits>,
}

impl AttentionScores {
    pub fn new(rows: Vec<Vec<f32>>) -> CtResult<Self> {
        if rows.is_empty() {
            return Err(CtError::EmptyInput("attention scores"));
        }

        let key_len = rows[0].len();

        if key_len == 0 {
            return Err(CtError::EmptyInput("attention score row"));
        }

        for row in &rows {
            if row.len() != key_len {
                return Err(CtError::ShapeMismatch {
                    op: "attention scores",
                    expected: format!("all rows have {key_len} columns"),
                    got: format!("row with {} columns", row.len()),
                });
            }

            if row.iter().any(|value| !value.is_finite()) {
                return Err(CtError::ShapeMismatch {
                    op: "attention scores",
                    expected: "all score values are finite".to_string(),
                    got: "non-finite score value".to_string(),
                });
            }
        }

        Ok(Self {
            query_len: SequenceLength::new(rows.len())?,
            key_len: SequenceLength::new(key_len)?,
            rows: rows.into_iter().map(Logits::new).collect(),
        })
    }

    pub fn query_len(&self) -> SequenceLength {
        self.query_len
    }

    pub fn key_len(&self) -> SequenceLength {
        self.key_len
    }

    pub fn rows(&self) -> &[Logits] {
        &self.rows
    }
}

/// Allowed query-by-key positions before attention softmax.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AttentionMask {
    query_len: SequenceLength,
    key_len: SequenceLength,
    rows: Vec<Vec<bool>>,
}

impl AttentionMask {
    pub fn new(rows: Vec<Vec<bool>>) -> CtResult<Self> {
        if rows.is_empty() {
            return Err(CtError::EmptyInput("attention mask"));
        }

        let key_len = rows[0].len();

        if key_len == 0 {
            return Err(CtError::EmptyInput("attention mask row"));
        }

        for row in &rows {
            if row.len() != key_len {
                return Err(CtError::ShapeMismatch {
                    op: "attention mask",
                    expected: format!("all rows have {key_len} columns"),
                    got: format!("row with {} columns", row.len()),
                });
            }

            if !row.iter().any(|allowed| *allowed) {
                return Err(CtError::EmptyInput("attention mask row allows no keys"));
            }
        }

        Ok(Self {
            query_len: SequenceLength::new(rows.len())?,
            key_len: SequenceLength::new(key_len)?,
            rows,
        })
    }

    pub fn query_len(&self) -> SequenceLength {
        self.query_len
    }

    pub fn key_len(&self) -> SequenceLength {
        self.key_len
    }

    pub fn rows(&self) -> &[Vec<bool>] {
        &self.rows
    }
}

/// Computes scaled query-key dot-product scores.
#[derive(Debug, Clone)]
pub struct ScaledDotProductScores;

impl Morphism<Product<QuerySequence, KeySequence>, AttentionScores> for ScaledDotProductScores {
    fn name(&self) -> &'static str {
        "scaled_dot_product_scores"
    }

    fn apply(&self, input: Product<QuerySequence, KeySequence>) -> CtResult<AttentionScores> {
        let (queries, keys) = input.into_parts();
        let query_dimension = queries.head_dimension();
        let key_dimension = keys.head_dimension();

        if query_dimension != key_dimension {
            return Err(CtError::ShapeMismatch {
                op: "scaled dot-product attention scores",
                expected: format!("query head dimension {}", query_dimension.value()),
                got: format!("key head dimension {}", key_dimension.value()),
            });
        }

        let scale = (query_dimension.value() as f32).sqrt();
        let rows = queries
            .rows()
            .iter()
            .map(|query| {
                keys.rows()
                    .iter()
                    .map(|key| dot_product(query.as_slice(), key.as_slice()) / scale)
                    .collect::<Vec<_>>()
            })
            .collect::<Vec<_>>();

        AttentionScores::new(rows)
    }
}

fn dot_product(left: &[f32], right: &[f32]) -> f32 {
    left.iter()
        .zip(right.iter())
        .map(|(left, right)| left * right)
        .sum()
}

/// Applies a boolean attention mask to score rows before softmax.
#[derive(Debug, Clone)]
pub struct MaskedAttentionScores;

impl Morphism<Product<AttentionScores, AttentionMask>, AttentionScores> for MaskedAttentionScores {
    fn name(&self) -> &'static str {
        "masked_attention_scores"
    }

    fn apply(&self, input: Product<AttentionScores, AttentionMask>) -> CtResult<AttentionScores> {
        let (scores, mask) = input.into_parts();

        if scores.query_len() != mask.query_len() || scores.key_len() != mask.key_len() {
            return Err(CtError::ShapeMismatch {
                op: "masked attention scores",
                expected: format!(
                    "{} query rows x {} key columns",
                    scores.query_len().value(),
                    scores.key_len().value()
                ),
                got: format!(
                    "{} query rows x {} key columns",
                    mask.query_len().value(),
                    mask.key_len().value()
                ),
            });
        }

        let rows = scores
            .rows()
            .iter()
            .zip(mask.rows())
            .map(|(score_row, mask_row)| {
                score_row
                    .as_slice()
                    .iter()
                    .zip(mask_row)
                    .map(|(score, allowed)| if *allowed { *score } else { MASKED_SCORE })
                    .collect::<Vec<_>>()
            })
            .collect::<Vec<_>>();

        AttentionScores::new(rows)
    }
}

/// Row-wise attention probabilities over key positions.
#[derive(Debug, Clone, PartialEq)]
pub struct AttentionWeights {
    query_len: SequenceLength,
    key_len: SequenceLength,
    rows: Vec<Distribution>,
}

impl AttentionWeights {
    pub fn new(rows: Vec<Distribution>) -> CtResult<Self> {
        if rows.is_empty() {
            return Err(CtError::EmptyInput("attention weights"));
        }

        let key_len = rows[0].as_slice().len();

        if key_len == 0 {
            return Err(CtError::EmptyInput("attention weight row"));
        }

        for row in &rows {
            if row.as_slice().len() != key_len {
                return Err(CtError::ShapeMismatch {
                    op: "attention weights",
                    expected: format!("all rows have {key_len} columns"),
                    got: format!("row with {} columns", row.as_slice().len()),
                });
            }
        }

        Ok(Self {
            query_len: SequenceLength::new(rows.len())?,
            key_len: SequenceLength::new(key_len)?,
            rows,
        })
    }

    pub fn query_len(&self) -> SequenceLength {
        self.query_len
    }

    pub fn key_len(&self) -> SequenceLength {
        self.key_len
    }

    pub fn rows(&self) -> &[Distribution] {
        &self.rows
    }
}

/// Weighted value vectors, one output row per query position.
#[derive(Debug, Clone, PartialEq)]
pub struct AttentionOutput {
    sequence_len: SequenceLength,
    head_dimension: HeadDimension,
    rows: Vec<Vector>,
}

impl AttentionOutput {
    pub fn new(rows: Vec<Vec<f32>>) -> CtResult<Self> {
        let matrix = AttentionVectorRows::new("attention output", rows)?;

        Ok(Self {
            sequence_len: matrix.sequence_len,
            head_dimension: matrix.head_dimension,
            rows: matrix.rows,
        })
    }

    pub fn sequence_len(&self) -> SequenceLength {
        self.sequence_len
    }

    pub fn head_dimension(&self) -> HeadDimension {
        self.head_dimension
    }

    pub fn rows(&self) -> &[Vector] {
        &self.rows
    }
}

/// Validated collection of single-head attention outputs.
#[derive(Debug, Clone, PartialEq)]
pub struct AttentionHeadOutputs {
    head_count: HeadCount,
    sequence_len: SequenceLength,
    head_dimension: HeadDimension,
    heads: Vec<AttentionOutput>,
}

impl AttentionHeadOutputs {
    pub fn new(heads: Vec<AttentionOutput>) -> CtResult<Self> {
        if heads.is_empty() {
            return Err(CtError::EmptyInput("attention head outputs"));
        }

        let sequence_len = heads[0].sequence_len();
        let head_dimension = heads[0].head_dimension();

        for head in &heads {
            if head.sequence_len() != sequence_len {
                return Err(CtError::ShapeMismatch {
                    op: "attention head outputs",
                    expected: format!("all heads have {} sequence rows", sequence_len.value()),
                    got: format!("head with {} sequence rows", head.sequence_len().value()),
                });
            }

            if head.head_dimension() != head_dimension {
                return Err(CtError::ShapeMismatch {
                    op: "attention head outputs",
                    expected: format!("all heads have dimension {}", head_dimension.value()),
                    got: format!("head dimension {}", head.head_dimension().value()),
                });
            }
        }

        Ok(Self {
            head_count: HeadCount::new(heads.len())?,
            sequence_len,
            head_dimension,
            heads,
        })
    }

    pub fn head_count(&self) -> HeadCount {
        self.head_count
    }

    pub fn sequence_len(&self) -> SequenceLength {
        self.sequence_len
    }

    pub fn head_dimension(&self) -> HeadDimension {
        self.head_dimension
    }

    pub fn heads(&self) -> &[AttentionOutput] {
        &self.heads
    }
}

/// Concatenated output of several attention heads.
#[derive(Debug, Clone, PartialEq)]
pub struct MultiHeadOutput {
    sequence_len: SequenceLength,
    head_count: HeadCount,
    head_dimension: HeadDimension,
    model_dimension: ModelDimension,
    rows: Vec<Vector>,
}

impl MultiHeadOutput {
    fn new(
        rows: Vec<Vec<f32>>,
        head_count: HeadCount,
        head_dimension: HeadDimension,
    ) -> CtResult<Self> {
        let matrix = AttentionVectorRows::new("multi-head output", rows)?;
        let expected_dimension = head_count.value() * head_dimension.value();

        if matrix.head_dimension.value() != expected_dimension {
            return Err(CtError::ShapeMismatch {
                op: "multi-head output",
                expected: format!("row dimension {expected_dimension}"),
                got: format!("row dimension {}", matrix.head_dimension.value()),
            });
        }

        Ok(Self {
            sequence_len: matrix.sequence_len,
            head_count,
            head_dimension,
            model_dimension: ModelDimension::new(expected_dimension)?,
            rows: matrix.rows,
        })
    }

    pub fn sequence_len(&self) -> SequenceLength {
        self.sequence_len
    }

    pub fn head_count(&self) -> HeadCount {
        self.head_count
    }

    pub fn head_dimension(&self) -> HeadDimension {
        self.head_dimension
    }

    pub fn model_dimension(&self) -> ModelDimension {
        self.model_dimension
    }

    pub fn rows(&self) -> &[Vector] {
        &self.rows
    }
}

/// Output sequence after the multi-head output projection.
#[derive(Debug, Clone, PartialEq)]
pub struct ProjectedAttentionOutput {
    sequence_len: SequenceLength,
    model_dimension: ModelDimension,
    rows: Vec<Vector>,
}

impl ProjectedAttentionOutput {
    pub fn new(rows: Vec<Vec<f32>>) -> CtResult<Self> {
        let matrix = AttentionVectorRows::new("projected attention output", rows)?;

        Ok(Self {
            sequence_len: matrix.sequence_len,
            model_dimension: ModelDimension::new(matrix.head_dimension.value())?,
            rows: matrix.rows,
        })
    }

    pub fn sequence_len(&self) -> SequenceLength {
        self.sequence_len
    }

    pub fn model_dimension(&self) -> ModelDimension {
        self.model_dimension
    }

    pub fn rows(&self) -> &[Vector] {
        &self.rows
    }
}

/// Vocabulary logits for every position in a hidden sequence.
#[derive(Debug, Clone, PartialEq)]
pub struct SequenceLogits {
    sequence_len: SequenceLength,
    vocab_size: VocabSize,
    rows: Vec<Logits>,
}

impl SequenceLogits {
    pub fn new(rows: Vec<Vec<f32>>) -> CtResult<Self> {
        if rows.is_empty() {
            return Err(CtError::EmptyInput("sequence logits"));
        }

        let vocab_size = rows[0].len();

        if vocab_size == 0 {
            return Err(CtError::EmptyInput("sequence logits row"));
        }

        for row in &rows {
            if row.len() != vocab_size {
                return Err(CtError::ShapeMismatch {
                    op: "sequence logits",
                    expected: format!("all rows have {vocab_size} columns"),
                    got: format!("row with {} columns", row.len()),
                });
            }

            if row.iter().any(|value| !value.is_finite()) {
                return Err(CtError::ShapeMismatch {
                    op: "sequence logits",
                    expected: "all logit values are finite".to_string(),
                    got: "non-finite logit value".to_string(),
                });
            }
        }

        Ok(Self {
            sequence_len: SequenceLength::new(rows.len())?,
            vocab_size: VocabSize::new(vocab_size)?,
            rows: rows.into_iter().map(Logits::new).collect(),
        })
    }

    pub fn sequence_len(&self) -> SequenceLength {
        self.sequence_len
    }

    pub fn vocab_size(&self) -> VocabSize {
        self.vocab_size
    }

    pub fn rows(&self) -> &[Logits] {
        &self.rows
    }
}

/// Learned language-model readout applied to each hidden position.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerReadout {
    input_dimension: ModelDimension,
    vocab_size: VocabSize,
    weight: Vec<Vec<f32>>,
    bias: Vec<f32>,
}

impl TransformerReadout {
    pub fn new(weight: Vec<Vec<f32>>, bias: Vec<f32>) -> CtResult<Self> {
        let (input_dimension, output_dimension) =
            validate_linear_parts("transformer readout", &weight, &bias)?;

        Ok(Self {
            input_dimension,
            vocab_size: VocabSize::new(output_dimension.value())?,
            weight,
            bias,
        })
    }

    pub fn input_dimension(&self) -> ModelDimension {
        self.input_dimension
    }

    pub fn vocab_size(&self) -> VocabSize {
        self.vocab_size
    }

    pub fn weight(&self) -> &[Vec<f32>] {
        &self.weight
    }

    pub fn bias(&self) -> &[f32] {
        &self.bias
    }
}

/// Learned output projection after head concatenation.
#[derive(Debug, Clone, PartialEq)]
pub struct AttentionOutputProjection {
    input_dimension: ModelDimension,
    output_dimension: ModelDimension,
    weight: Vec<Vec<f32>>,
    bias: Vec<f32>,
}

impl AttentionOutputProjection {
    pub fn new(weight: Vec<Vec<f32>>, bias: Vec<f32>) -> CtResult<Self> {
        if weight.is_empty() {
            return Err(CtError::EmptyInput("attention output projection weight"));
        }

        if bias.is_empty() {
            return Err(CtError::EmptyInput("attention output projection bias"));
        }

        let output_dimension = bias.len();

        if bias.iter().any(|value| !value.is_finite()) {
            return Err(CtError::ShapeMismatch {
                op: "attention output projection",
                expected: "finite bias values".to_string(),
                got: "non-finite bias value".to_string(),
            });
        }

        for row in &weight {
            if row.len() != output_dimension {
                return Err(CtError::ShapeMismatch {
                    op: "attention output projection",
                    expected: format!("weight rows have {output_dimension} columns"),
                    got: format!("weight row with {} columns", row.len()),
                });
            }

            if row.iter().any(|value| !value.is_finite()) {
                return Err(CtError::ShapeMismatch {
                    op: "attention output projection",
                    expected: "finite weight values".to_string(),
                    got: "non-finite weight value".to_string(),
                });
            }
        }

        Ok(Self {
            input_dimension: ModelDimension::new(weight.len())?,
            output_dimension: ModelDimension::new(output_dimension)?,
            weight,
            bias,
        })
    }

    pub fn input_dimension(&self) -> ModelDimension {
        self.input_dimension
    }

    pub fn output_dimension(&self) -> ModelDimension {
        self.output_dimension
    }

    pub fn weight(&self) -> &[Vec<f32>] {
        &self.weight
    }

    pub fn bias(&self) -> &[f32] {
        &self.bias
    }
}

/// Scale, shift, and epsilon parameters for layer normalization.
#[derive(Debug, Clone, PartialEq)]
pub struct LayerNormParameters {
    model_dimension: ModelDimension,
    scale: Vec<f32>,
    shift: Vec<f32>,
    epsilon: NormalizationEpsilon,
}

impl LayerNormParameters {
    pub fn new(scale: Vec<f32>, shift: Vec<f32>, epsilon: NormalizationEpsilon) -> CtResult<Self> {
        if scale.is_empty() {
            return Err(CtError::EmptyInput("layer norm scale"));
        }

        if shift.is_empty() {
            return Err(CtError::EmptyInput("layer norm shift"));
        }

        if scale.len() != shift.len() {
            return Err(CtError::ShapeMismatch {
                op: "layer norm parameters",
                expected: format!("scale and shift length {}", scale.len()),
                got: format!("shift length {}", shift.len()),
            });
        }

        if scale.iter().any(|value| !value.is_finite()) {
            return Err(CtError::ShapeMismatch {
                op: "layer norm parameters",
                expected: "finite scale values".to_string(),
                got: "non-finite scale value".to_string(),
            });
        }

        if shift.iter().any(|value| !value.is_finite()) {
            return Err(CtError::ShapeMismatch {
                op: "layer norm parameters",
                expected: "finite shift values".to_string(),
                got: "non-finite shift value".to_string(),
            });
        }

        Ok(Self {
            model_dimension: ModelDimension::new(scale.len())?,
            scale,
            shift,
            epsilon,
        })
    }

    /// Unit scale, zero shift, and an epsilon of 1e-5.
    pub fn identity(model_dimension: ModelDimension) -> Self {
        Self {
            model_dimension,
            scale: vec![1.0; model_dimension.value()],
            shift: vec![0.0; model_dimension.value()],
            epsilon: NormalizationEpsilon(1e-5),
        }
    }

    pub fn model_dimension(&self) -> ModelDimension {
        self.model_dimension
    }

    pub fn scale(&self) -> &[f32] {
        &self.scale
    }

    pub fn shift(&self) -> &[f32] {
        &self.shift
    }

    pub fn epsilon(&self) -> NormalizationEpsilon {
        self.epsilon
    }
}

/// Layer normalization over each hidden vector independently.
#[derive(Debug, Clone, PartialEq)]
pub struct LayerNormalization {
    parameters: LayerNormParameters,
}

impl LayerNormalization {
    pub fn new(parameters: LayerNormParameters) -> Self {
        Self { parameters }
    }

    pub fn model_dimension(&self) -> ModelDimension {
        self.parameters.model_dimension()
    }

    pub fn parameters(&self) -> &LayerNormParameters {
        &self.parameters
    }
}

/// Intermediate values recorded for one row of the feed-forward sublayer.
#[derive(Debug, Clone, PartialEq)]
struct FeedForwardRowCache {
    input: Vec<f32>,
    pre_activation: Vec<f32>,
    activation: Vec<f32>,
    output: Vec<f32>,
}

/// Queries, keys, values, attention weights, and output recorded for one head.
#[derive(Debug, Clone, PartialEq)]
struct AttentionHeadTrainingCache {
    queries: QuerySequence,
    keys: KeySequence,
    values: ValueSequence,
    weights: AttentionWeights,
    output: AttentionOutput,
}

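/// Intermediate values recorded by one masked multi-head block application.
#[derive(Debug, Clone, PartialEq)]
struct MaskedBlockTrainingCache {
    /// Block output after the feed-forward layer normalization.
    output: HiddenSequence,
    /// Feed-forward residual sum, before the final layer normalization.
    with_feed_forward: HiddenSequence,
    /// Attention residual sum, before the attention layer normalization.
    with_attention: HiddenSequence,
    /// Concatenated head outputs, before the output projection.
    multi_head_output: MultiHeadOutput,
    /// One cache per attention head, in head order.
    attention_heads: Vec<AttentionHeadTrainingCache>,
    /// Per-row caches returned by `feed_forward_with_cache`.
    feed_forward_rows: Vec<FeedForwardRowCache>,
}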

/// Position-wise two-layer feed-forward sublayer.
#[derive(Debug, Clone, PartialEq)]
pub struct PositionWiseFeedForward {
    input_dimension: ModelDimension,
    hidden_dimension: ModelDimension,
    output_dimension: ModelDimension,
    first_weight: Vec<Vec<f32>>,
    first_bias: Vec<f32>,
    second_weight: Vec<Vec<f32>>,
    second_bias: Vec<f32>,
}

impl PositionWiseFeedForward {
    pub fn new(
        first_weight: Vec<Vec<f32>>,
        first_bias: Vec<f32>,
        second_weight: Vec<Vec<f32>>,
        second_bias: Vec<f32>,
    ) -> CtResult<Self> {
        let (input_dimension, hidden_dimension) = validate_linear_parts(
            "position-wise feed-forward first layer",
            &first_weight,
            &first_bias,
        )?;
        let (second_input_dimension, output_dimension) = validate_linear_parts(
            "position-wise feed-forward second layer",
            &second_weight,
            &second_bias,
        )?;

        if second_input_dimension != hidden_dimension {
            return Err(CtError::ShapeMismatch {
                op: "position-wise feed-forward",
                expected: format!("second input dimension {}", hidden_dimension.value()),
                got: format!("second input dimension {}", second_input_dimension.value()),
            });
        }

        if output_dimension != input_dimension {
            return Err(CtError::ShapeMismatch {
                op: "position-wise feed-forward",
                expected: format!("output dimension {}", input_dimension.value()),
                got: format!("output dimension {}", output_dimension.value()),
            });
        }

        Ok(Self {
            input_dimension,
            hidden_dimension,
            output_dimension,
            first_weight,
            first_bias,
            second_weight,
            second_bias,
        })
    }

    pub fn input_dimension(&self) -> ModelDimension {
        self.input_dimension
    }

    pub fn hidden_dimension(&self) -> ModelDimension {
        self.hidden_dimension
    }

    pub fn output_dimension(&self) -> ModelDimension {
        self.output_dimension
    }

    pub fn first_weight(&self) -> &[Vec<f32>] {
        &self.first_weight
    }

    pub fn first_bias(&self) -> &[f32] {
        &self.first_bias
    }

    pub fn second_weight(&self) -> &[Vec<f32>] {
        &self.second_weight
    }

    pub fn second_bias(&self) -> &[f32] {
        &self.second_bias
    }
}

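/// Shared weight-and-bias projection wrapped by `HiddenToQuery`, `HiddenToKey`,
/// and `HiddenToValue`.
#[derive(Debug, Clone, PartialEq)]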
struct HiddenProjection {
    op: &'static str,
    input_dimension: ModelDimension,
    head_dimension: HeadDimension,
    weight: Vec<Vec<f32>>,
    bias: Vec<f32>,
}

impl HiddenProjection {
    fn new(op: &'static str, weight: Vec<Vec<f32>>, bias: Vec<f32>) -> CtResult<Self> {
        let (input_dimension, output_dimension) = validate_linear_parts(op, &weight, &bias)?;

        Ok(Self {
            op,
            input_dimension,
            head_dimension: HeadDimension::new(output_dimension.value())?,
            weight,
            bias,
        })
    }

    fn project(&self, input: &HiddenSequence) -> CtResult<Vec<Vec<f32>>> {
        if input.model_dimension() != self.input_dimension {
            return Err(CtError::ShapeMismatch {
                op: self.op,
                expected: format!("input dimension {}", self.input_dimension.value()),
                got: format!("input dimension {}", input.model_dimension().value()),
            });
        }

        Ok(input
            .rows()
            .iter()
            .map(|row| project_row(row.as_slice(), &self.weight, &self.bias))
            .collect::<Vec<_>>())
    }

    fn input_dimension(&self) -> ModelDimension {
        self.input_dimension
    }

    fn head_dimension(&self) -> HeadDimension {
        self.head_dimension
    }

    fn weight(&self) -> &[Vec<f32>] {
        &self.weight
    }

    fn bias(&self) -> &[f32] {
        &self.bias
    }
}

/// Learned projection from hidden states to query vectors.
#[derive(Debug, Clone, PartialEq)]
pub struct HiddenToQuery {
    projection: HiddenProjection,
}

impl HiddenToQuery {
    pub fn new(weight: Vec<Vec<f32>>, bias: Vec<f32>) -> CtResult<Self> {
        Ok(Self {
            projection: HiddenProjection::new("hidden-to-query projection", weight, bias)?,
        })
    }

    pub fn input_dimension(&self) -> ModelDimension {
        self.projection.input_dimension()
    }

    pub fn head_dimension(&self) -> HeadDimension {
        self.projection.head_dimension()
    }

    pub fn weight(&self) -> &[Vec<f32>] {
        self.projection.weight()
    }

    pub fn bias(&self) -> &[f32] {
        self.projection.bias()
    }
}

/// Learned projection from hidden states to key vectors.
#[derive(Debug, Clone, PartialEq)]
pub struct HiddenToKey {
    projection: HiddenProjection,
}

impl HiddenToKey {
    pub fn new(weight: Vec<Vec<f32>>, bias: Vec<f32>) -> CtResult<Self> {
        Ok(Self {
            projection: HiddenProjection::new("hidden-to-key projection", weight, bias)?,
        })
    }

    pub fn input_dimension(&self) -> ModelDimension {
        self.projection.input_dimension()
    }

    pub fn head_dimension(&self) -> HeadDimension {
        self.projection.head_dimension()
    }

    pub fn weight(&self) -> &[Vec<f32>] {
        self.projection.weight()
    }

    pub fn bias(&self) -> &[f32] {
        self.projection.bias()
    }
}

/// Learned projection from hidden states to value vectors.
#[derive(Debug, Clone, PartialEq)]
pub struct HiddenToValue {
    projection: HiddenProjection,
}

impl HiddenToValue {
    pub fn new(weight: Vec<Vec<f32>>, bias: Vec<f32>) -> CtResult<Self> {
        Ok(Self {
            projection: HiddenProjection::new("hidden-to-value projection", weight, bias)?,
        })
    }

    pub fn input_dimension(&self) -> ModelDimension {
        self.projection.input_dimension()
    }

    pub fn head_dimension(&self) -> HeadDimension {
        self.projection.head_dimension()
    }

    pub fn weight(&self) -> &[Vec<f32>] {
        self.projection.weight()
    }

    pub fn bias(&self) -> &[f32] {
        self.projection.bias()
    }
}

/// A tiny single-head Transformer block sketch.
#[derive(Debug, Clone, PartialEq)]
pub struct SingleHeadTransformerBlock {
    model_dimension: ModelDimension,
    query_projection: HiddenToQuery,
    key_projection: HiddenToKey,
    value_projection: HiddenToValue,
    output_projection: AttentionOutputProjection,
    attention_norm: LayerNormalization,
    feed_forward: PositionWiseFeedForward,
    feed_forward_norm: LayerNormalization,
}

impl SingleHeadTransformerBlock {
    pub fn new(
        query_projection: HiddenToQuery,
        key_projection: HiddenToKey,
        value_projection: HiddenToValue,
        output_projection: AttentionOutputProjection,
        attention_norm: LayerNormalization,
        feed_forward: PositionWiseFeedForward,
        feed_forward_norm: LayerNormalization,
    ) -> CtResult<Self> {
        let model_dimension = query_projection.input_dimension();

        validate_projection_input(
            "single-head block key projection",
            model_dimension,
            key_projection.input_dimension(),
        )?;
        validate_projection_input(
            "single-head block value projection",
            model_dimension,
            value_projection.input_dimension(),
        )?;

        if query_projection.head_dimension() != key_projection.head_dimension() {
            return Err(CtError::ShapeMismatch {
                op: "single-head block",
                expected: format!(
                    "query/key head dimension {}",
                    query_projection.head_dimension().value()
                ),
                got: format!(
                    "key head dimension {}",
                    key_projection.head_dimension().value()
                ),
            });
        }

        if output_projection.input_dimension().value() != value_projection.head_dimension().value()
        {
            return Err(CtError::ShapeMismatch {
                op: "single-head block",
                expected: format!(
                    "output projection input dimension {}",
                    value_projection.head_dimension().value()
                ),
                got: format!(
                    "output projection input dimension {}",
                    output_projection.input_dimension().value()
                ),
            });
        }

        validate_projection_input(
            "single-head block output projection",
            model_dimension,
            output_projection.output_dimension(),
        )?;
        validate_projection_input(
            "single-head block attention normalization",
            model_dimension,
            attention_norm.model_dimension(),
        )?;
        validate_projection_input(
            "single-head block feed-forward",
            model_dimension,
            feed_forward.input_dimension(),
        )?;
        validate_projection_input(
            "single-head block feed-forward normalization",
            model_dimension,
            feed_forward_norm.model_dimension(),
        )?;

        Ok(Self {
            model_dimension,
            query_projection,
            key_projection,
            value_projection,
            output_projection,
            attention_norm,
            feed_forward,
            feed_forward_norm,
        })
    }

    pub fn model_dimension(&self) -> ModelDimension {
        self.model_dimension
    }
}

/// Learned query, key, and value projections for one self-attention head.
#[derive(Debug, Clone, PartialEq)]
pub struct SelfAttentionHead {
    query_projection: HiddenToQuery,
    key_projection: HiddenToKey,
    value_projection: HiddenToValue,
}

impl SelfAttentionHead {
    pub fn new(
        query_projection: HiddenToQuery,
        key_projection: HiddenToKey,
        value_projection: HiddenToValue,
    ) -> CtResult<Self> {
        let input_dimension = query_projection.input_dimension();

        validate_projection_input(
            "self-attention head key projection",
            input_dimension,
            key_projection.input_dimension(),
        )?;
        validate_projection_input(
            "self-attention head value projection",
            input_dimension,
            value_projection.input_dimension(),
        )?;

        if query_projection.head_dimension() != key_projection.head_dimension() {
            return Err(CtError::ShapeMismatch {
                op: "self-attention head",
                expected: format!(
                    "query/key head dimension {}",
                    query_projection.head_dimension().value()
                ),
                got: format!(
                    "key head dimension {}",
                    key_projection.head_dimension().value()
                ),
            });
        }

        Ok(Self {
            query_projection,
            key_projection,
            value_projection,
        })
    }

    pub fn input_dimension(&self) -> ModelDimension {
        self.query_projection.input_dimension()
    }

    pub fn query_key_dimension(&self) -> HeadDimension {
        self.query_projection.head_dimension()
    }

    pub fn value_dimension(&self) -> HeadDimension {
        self.value_projection.head_dimension()
    }

    pub fn query_projection(&self) -> &HiddenToQuery {
        &self.query_projection
    }

    pub fn key_projection(&self) -> &HiddenToKey {
        &self.key_projection
    }

    pub fn value_projection(&self) -> &HiddenToValue {
        &self.value_projection
    }
}

/// A tiny multi-head Transformer block sketch.
#[derive(Debug, Clone, PartialEq)]
pub struct MultiHeadTransformerBlock {
    model_dimension: ModelDimension,
    head_count: HeadCount,
    value_dimension: HeadDimension,
    heads: Vec<SelfAttentionHead>,
    output_projection: AttentionOutputProjection,
    attention_norm: LayerNormalization,
    feed_forward: PositionWiseFeedForward,
    feed_forward_norm: LayerNormalization,
}

impl MultiHeadTransformerBlock {
    pub fn new(
        heads: Vec<SelfAttentionHead>,
        output_projection: AttentionOutputProjection,
        attention_norm: LayerNormalization,
        feed_forward: PositionWiseFeedForward,
        feed_forward_norm: LayerNormalization,
    ) -> CtResult<Self> {
        if heads.is_empty() {
            return Err(CtError::EmptyInput("multi-head block heads"));
        }

        let model_dimension = heads[0].input_dimension();
        let value_dimension = heads[0].value_dimension();

        for head in &heads {
            validate_projection_input(
                "multi-head block head projection",
                model_dimension,
                head.input_dimension(),
            )?;

            if head.value_dimension() != value_dimension {
                return Err(CtError::ShapeMismatch {
                    op: "multi-head block",
                    expected: format!("value head dimension {}", value_dimension.value()),
                    got: format!("value head dimension {}", head.value_dimension().value()),
                });
            }
        }

        let head_count = HeadCount::new(heads.len())?;
        let concatenated_dimension =
            ModelDimension::new(head_count.value() * value_dimension.value())?;

        validate_projection_input(
            "multi-head block output projection input",
            concatenated_dimension,
            output_projection.input_dimension(),
        )?;
        validate_projection_input(
            "multi-head block output projection",
            model_dimension,
            output_projection.output_dimension(),
        )?;
        validate_projection_input(
            "multi-head block attention normalization",
            model_dimension,
            attention_norm.model_dimension(),
        )?;
        validate_projection_input(
            "multi-head block feed-forward",
            model_dimension,
            feed_forward.input_dimension(),
        )?;
        validate_projection_input(
            "multi-head block feed-forward normalization",
            model_dimension,
            feed_forward_norm.model_dimension(),
        )?;

        Ok(Self {
            model_dimension,
            head_count,
            value_dimension,
            heads,
            output_projection,
            attention_norm,
            feed_forward,
            feed_forward_norm,
        })
    }

    pub fn model_dimension(&self) -> ModelDimension {
        self.model_dimension
    }

    pub fn head_count(&self) -> HeadCount {
        self.head_count
    }

    pub fn value_dimension(&self) -> HeadDimension {
        self.value_dimension
    }

    pub fn heads(&self) -> &[SelfAttentionHead] {
        &self.heads
    }

    fn with_heads(self, heads: Vec<SelfAttentionHead>) -> CtResult<Self> {
        Self::new(
            heads,
            self.output_projection,
            self.attention_norm,
            self.feed_forward,
            self.feed_forward_norm,
        )
    }

    fn with_feed_forward(self, feed_forward: PositionWiseFeedForward) -> CtResult<Self> {
        Self::new(
            self.heads,
            self.output_projection,
            self.attention_norm,
            feed_forward,
            self.feed_forward_norm,
        )
    }

    fn with_output_projection(
        self,
        output_projection: AttentionOutputProjection,
    ) -> CtResult<Self> {
        Self::new(
            self.heads,
            output_projection,
            self.attention_norm,
            self.feed_forward,
            self.feed_forward_norm,
        )
    }

    fn with_layer_norms(
        self,
        attention_norm: LayerNormalization,
        feed_forward_norm: LayerNormalization,
    ) -> CtResult<Self> {
        Self::new(
            self.heads,
            self.output_projection,
            attention_norm,
            self.feed_forward,
            feed_forward_norm,
        )
    }
}

/// A tiny masked multi-head Transformer block sketch.
#[derive(Debug, Clone, PartialEq)]
pub struct MaskedMultiHeadTransformerBlock {
    block: MultiHeadTransformerBlock,
}

impl MaskedMultiHeadTransformerBlock {
    pub fn new(
        heads: Vec<SelfAttentionHead>,
        output_projection: AttentionOutputProjection,
        attention_norm: LayerNormalization,
        feed_forward: PositionWiseFeedForward,
        feed_forward_norm: LayerNormalization,
    ) -> CtResult<Self> {
        Ok(Self {
            block: MultiHeadTransformerBlock::new(
                heads,
                output_projection,
                attention_norm,
                feed_forward,
                feed_forward_norm,
            )?,
        })
    }

    pub fn model_dimension(&self) -> ModelDimension {
        self.block.model_dimension()
    }

    pub fn head_count(&self) -> HeadCount {
        self.block.head_count()
    }

    pub fn value_dimension(&self) -> HeadDimension {
        self.block.value_dimension()
    }

    pub fn heads(&self) -> &[SelfAttentionHead] {
        self.block.heads()
    }

    pub fn feed_forward(&self) -> &PositionWiseFeedForward {
        &self.block.feed_forward
    }

    fn with_feed_forward(self, feed_forward: PositionWiseFeedForward) -> CtResult<Self> {
        Ok(Self {
            block: self.block.with_feed_forward(feed_forward)?,
        })
    }

    fn with_heads(self, heads: Vec<SelfAttentionHead>) -> CtResult<Self> {
        Ok(Self {
            block: self.block.with_heads(heads)?,
        })
    }

    fn with_output_projection(
        self,
        output_projection: AttentionOutputProjection,
    ) -> CtResult<Self> {
        Ok(Self {
            block: self.block.with_output_projection(output_projection)?,
        })
    }

    fn with_layer_norms(
        self,
        attention_norm: LayerNormalization,
        feed_forward_norm: LayerNormalization,
    ) -> CtResult<Self> {
        Ok(Self {
            block: self
                .block
                .with_layer_norms(attention_norm, feed_forward_norm)?,
        })
    }

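    /// Runs the masked block forward pass and keeps the intermediates listed in
    /// `MaskedBlockTrainingCache`.
    fn apply_with_training_cache(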
        &self,
        hidden: HiddenSequence,
        mask: AttentionMask,
    ) -> CtResult<MaskedBlockTrainingCache> {
        if hidden.model_dimension() != self.block.model_dimension {
            return Err(CtError::ShapeMismatch {
                op: "masked multi-head block",
                expected: format!("model dimension {}", self.block.model_dimension.value()),
                got: format!("model dimension {}", hidden.model_dimension().value()),
            });
        }

        let head_caches = self
            .block
            .heads
            .iter()
            .map(|head| apply_self_attention_head_with_mask_cache(&hidden, head, Some(&mask)))
            .collect::<CtResult<Vec<_>>>()?;
        let attention_outputs = head_caches
            .iter()
            .map(|cache| cache.output.clone())
            .collect::<Vec<_>>();
        let head_outputs = AttentionHeadOutputs::new(attention_outputs)?;
        let multi_head_output = ConcatenateHeads.apply(head_outputs)?;
        let projected_attention = self
            .block
            .output_projection
            .apply(multi_head_output.clone())?;
        let with_attention = ResidualConnection.apply(Product::new(hidden, projected_attention))?;
        let normalized_attention = self.block.attention_norm.apply(with_attention.clone())?;
        let (feed_forward_output, feed_forward_rows) =
            feed_forward_with_cache(&self.block.feed_forward, &normalized_attention)?;
        let with_feed_forward =
            ResidualConnection.apply(Product::new(normalized_attention, feed_forward_output))?;
        let output = self
            .block
            .feed_forward_norm
            .apply(with_feed_forward.clone())?;

        Ok(MaskedBlockTrainingCache {
            output,
            with_feed_forward,
            with_attention,
            multi_head_output,
            attention_heads: head_caches,
            feed_forward_rows,
        })
    }
}

/// Tiny structured Transformer parameter object for the roadmap.
#[derive(Debug, Clone, PartialEq)]
pub struct TinyTransformerParameters {
    positional_encoding: PositionalEncoding,
    block: MaskedMultiHeadTransformerBlock,
    readout: TransformerReadout,
}

impl TinyTransformerParameters {
    pub fn new(
        positional_encoding: PositionalEncoding,
        block: MaskedMultiHeadTransformerBlock,
        readout: TransformerReadout,
    ) -> CtResult<Self> {
        let model_dimension = positional_encoding.model_dimension();

        validate_projection_input(
            "tiny transformer parameters block",
            model_dimension,
            block.model_dimension(),
        )?;
        validate_projection_input(
            "tiny transformer parameters readout",
            model_dimension,
            readout.input_dimension(),
        )?;

        Ok(Self {
            positional_encoding,
            block,
            readout,
        })
    }

    pub fn model_dimension(&self) -> ModelDimension {
        self.positional_encoding.model_dimension()
    }

    pub fn max_sequence_len(&self) -> SequenceLength {
        self.positional_encoding.max_sequence_len()
    }

    pub fn vocab_size(&self) -> VocabSize {
        self.readout.vocab_size()
    }

    pub fn encode(&self, hidden: HiddenSequence, mask: AttentionMask) -> CtResult<HiddenSequence> {
        let positioned = self.positional_encoding.apply(hidden)?;

        self.block.apply(Product::new(positioned, mask))
    }

    pub fn readout(&self) -> &TransformerReadout {
        &self.readout
    }

    pub fn feed_forward(&self) -> &PositionWiseFeedForward {
        self.block.feed_forward()
    }

    pub fn output_projection(&self) -> &AttentionOutputProjection {
        &self.block.block.output_projection
    }

    pub fn attention_heads(&self) -> &[SelfAttentionHead] {
        self.block.heads()
    }

    pub fn attention_norm(&self) -> &LayerNormalization {
        &self.block.block.attention_norm
    }

    pub fn feed_forward_norm(&self) -> &LayerNormalization {
        &self.block.block.feed_forward_norm
    }

    fn with_readout(self, readout: TransformerReadout) -> CtResult<Self> {
        Self::new(self.positional_encoding, self.block, readout)
    }

    fn with_feed_forward(self, feed_forward: PositionWiseFeedForward) -> CtResult<Self> {
        Self::new(
            self.positional_encoding,
            self.block.with_feed_forward(feed_forward)?,
            self.readout,
        )
    }

    fn with_attention_heads(self, heads: Vec<SelfAttentionHead>) -> CtResult<Self> {
        Self::new(
            self.positional_encoding,
            self.block.with_heads(heads)?,
            self.readout,
        )
    }

    fn with_output_projection(
        self,
        output_projection: AttentionOutputProjection,
    ) -> CtResult<Self> {
        Self::new(
            self.positional_encoding,
            self.block.with_output_projection(output_projection)?,
            self.readout,
        )
    }

    fn with_layer_norms(
        self,
        attention_norm: LayerNormalization,
        feed_forward_norm: LayerNormalization,
    ) -> CtResult<Self> {
        Self::new(
            self.positional_encoding,
            self.block
                .with_layer_norms(attention_norm, feed_forward_norm)?,
            self.readout,
        )
    }
}

/// Structured state owned by a future Transformer training loop.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerTrainingState {
    parameters: TinyTransformerParameters,
    learning_rate: LearningRate,
    step_count: StepCount,
}

impl TransformerTrainingState {
    pub fn new(parameters: TinyTransformerParameters, learning_rate: LearningRate) -> Self {
        Self::from_parts(parameters, learning_rate, StepCount::new(0))
    }

    pub fn from_parts(
        parameters: TinyTransformerParameters,
        learning_rate: LearningRate,
        step_count: StepCount,
    ) -> Self {
        Self {
            parameters,
            learning_rate,
            step_count,
        }
    }

    pub fn parameters(&self) -> &TinyTransformerParameters {
        &self.parameters
    }

    pub fn learning_rate(&self) -> LearningRate {
        self.learning_rate
    }

    pub fn step_count(&self) -> StepCount {
        self.step_count
    }

    /// Replaces the parameters, keeps the learning rate, and increments the step count.
    pub fn record_updated_parameters(self, parameters: TinyTransformerParameters) -> Self {
        Self {
            parameters,
            learning_rate: self.learning_rate,
            step_count: StepCount::new(self.step_count.value() + 1),
        }
    }
}

/// One supervised sequence example for a readout-only Transformer update.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerReadoutTrainingExample {
    hidden: HiddenSequence,
    mask: AttentionMask,
    targets: TokenSequence,
}

impl TransformerReadoutTrainingExample {
    pub fn new(
        hidden: HiddenSequence,
        mask: AttentionMask,
        targets: TokenSequence,
    ) -> CtResult<Self> {
        let sequence_len = hidden.sequence_len();

        if targets.as_slice().len() != sequence_len.value() {
            return Err(CtError::ShapeMismatch {
                op: "transformer readout training targets",
                expected: format!("{} target tokens", sequence_len.value()),
                got: format!("{} target tokens", targets.as_slice().len()),
            });
        }

        if mask.query_len() != sequence_len || mask.key_len() != sequence_len {
            return Err(CtError::ShapeMismatch {
                op: "transformer readout training mask",
                expected: format!(
                    "{} query rows x {} key columns",
                    sequence_len.value(),
                    sequence_len.value()
                ),
                got: format!(
                    "{} query rows x {} key columns",
                    mask.query_len().value(),
                    mask.key_len().value()
                ),
            });
        }

        Ok(Self {
            hidden,
            mask,
            targets,
        })
    }

    pub fn hidden(&self) -> &HiddenSequence {
        &self.hidden
    }

    pub fn mask(&self) -> &AttentionMask {
        &self.mask
    }

    pub fn targets(&self) -> &TokenSequence {
        &self.targets
    }
}

/// Non-empty set of supervised readout examples.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerReadoutTrainingSet(Vec<TransformerReadoutTrainingExample>);

impl TransformerReadoutTrainingSet {
    pub fn new(
        examples: impl IntoIterator<Item = TransformerReadoutTrainingExample>,
    ) -> CtResult<Self> {
        let examples = examples.into_iter().collect::<Vec<_>>();

        if examples.is_empty() {
            return Err(CtError::EmptyInput("transformer readout training set"));
        }

        Ok(Self(examples))
    }

    pub fn examples(&self) -> &[TransformerReadoutTrainingExample] {
        &self.0
    }
}

/// One full-batch update of the sequence readout parameters.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerReadoutTrainStep {
    dataset: TransformerReadoutTrainingSet,
}

impl TransformerReadoutTrainStep {
    pub fn new(dataset: TransformerReadoutTrainingSet) -> Self {
        Self { dataset }
    }
}

impl Morphism<TransformerTrainingState, TransformerTrainingState> for TransformerReadoutTrainStep {
    fn name(&self) -> &'static str {
        "transformer_readout_train_step"
    }

    fn apply(&self, state: TransformerTrainingState) -> CtResult<TransformerTrainingState> {
        let input_dimension = state.parameters.readout().input_dimension().value();
        let vocab_size = state.parameters.vocab_size().value();
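        // Accumulators are indexed as `grad_weight[feature][vocab_id]`:
        // `input_dimension` rows by `vocab_size` columns.
        let mut grad_weight = vec![vec![0.0; vocab_size]; input_dimension];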
        let mut grad_bias = vec![0.0; vocab_size];
        let mut position_count = 0usize;

        for example in self.dataset.examples() {
            let encoded = state
                .parameters
                .encode(example.hidden().clone(), example.mask().clone())?;
            let logits = state.parameters.readout().apply(encoded.clone())?;

            for ((hidden_row, logit_row), target) in encoded
                .rows()
                .iter()
                .zip(logits.rows())
                .zip(example.targets().as_slice())
            {
                let target_index = target.index();

                if target_index >= vocab_size {
                    return Err(CtError::OutOfRange {
                        kind: "sequence target",
                        index: target_index,
                        limit: vocab_size,
                    });
                }

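                // Cross-entropy gradient with respect to the logits:
                // softmax probabilities minus the one-hot target.
                let probabilities = Softmax.apply(logit_row.clone())?;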
                let mut dlogits = probabilities.as_slice().to_vec();
                dlogits[target_index] -= 1.0;

                for (vocab_id, dlogit) in dlogits.iter().copied().enumerate() {
                    grad_bias[vocab_id] += dlogit;

                    for (feature, hidden_value) in hidden_row.as_slice().iter().copied().enumerate()
                    {
                        grad_weight[feature][vocab_id] += hidden_value * dlogit;
                    }
                }

                position_count += 1;
            }
        }

        // Average the summed gradients over every target position and fold in the learning rate.
        let scale = state.learning_rate().value() / position_count as f32;
        let mut updated_weight = state.parameters.readout().weight().to_vec();
        let mut updated_bias = state.parameters.readout().bias().to_vec();

        for (row, grad_row) in updated_weight.iter_mut().zip(&grad_weight) {
            for (weight, grad) in row.iter_mut().zip(grad_row) {
                *weight -= scale * grad;
            }
        }

        for (bias, grad) in updated_bias.iter_mut().zip(&grad_bias) {
            *bias -= scale * grad;
        }

        let updated_readout = TransformerReadout::new(updated_weight, updated_bias)?;
        let updated_parameters = state.parameters.clone().with_readout(updated_readout)?;

        Ok(state.record_updated_parameters(updated_parameters))
    }
}

/// One supervised sequence example for local feed-forward sublayer training.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerFeedForwardTrainingExample {
    input: HiddenSequence,
    target: HiddenSequence,
}

impl TransformerFeedForwardTrainingExample {
    pub fn new(input: HiddenSequence, target: HiddenSequence) -> CtResult<Self> {
        if input.sequence_len() != target.sequence_len() {
            return Err(CtError::ShapeMismatch {
                op: "transformer feed-forward training sequence",
                expected: format!("{} target rows", input.sequence_len().value()),
                got: format!("{} target rows", target.sequence_len().value()),
            });
        }

        if input.model_dimension() != target.model_dimension() {
            return Err(CtError::ShapeMismatch {
                op: "transformer feed-forward training dimension",
                expected: format!("target dimension {}", input.model_dimension().value()),
                got: format!("target dimension {}", target.model_dimension().value()),
            });
        }

        Ok(Self { input, target })
    }

    pub fn input(&self) -> &HiddenSequence {
        &self.input
    }

    pub fn target(&self) -> &HiddenSequence {
        &self.target
    }
}

/// Non-empty set of local feed-forward training examples.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerFeedForwardTrainingSet(Vec<TransformerFeedForwardTrainingExample>);

impl TransformerFeedForwardTrainingSet {
    pub fn new(
        examples: impl IntoIterator<Item = TransformerFeedForwardTrainingExample>,
    ) -> CtResult<Self> {
        let examples = examples.into_iter().collect::<Vec<_>>();

        if examples.is_empty() {
            return Err(CtError::EmptyInput("transformer feed-forward training set"));
        }

        Ok(Self(examples))
    }

    pub fn examples(&self) -> &[TransformerFeedForwardTrainingExample] {
        &self.0
    }
}

/// One full-batch update of the position-wise feed-forward sublayer.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerFeedForwardTrainStep {
    dataset: TransformerFeedForwardTrainingSet,
}

impl TransformerFeedForwardTrainStep {
    pub fn new(dataset: TransformerFeedForwardTrainingSet) -> Self {
        Self { dataset }
    }
}

impl Morphism<TransformerTrainingState, TransformerTrainingState>
    for TransformerFeedForwardTrainStep
{
    fn name(&self) -> &'static str {
        "transformer_feed_forward_train_step"
    }

    fn apply(&self, state: TransformerTrainingState) -> CtResult<TransformerTrainingState> {
        let feed_forward = state.parameters.feed_forward();
        let mut gradients = FeedForwardGradients::new(feed_forward);
        let mut row_count = 0usize;

        for example in self.dataset.examples() {
            if example.input().model_dimension() != feed_forward.input_dimension() {
                return Err(CtError::ShapeMismatch {
                    op: "transformer feed-forward train step",
                    expected: format!("input dimension {}", feed_forward.input_dimension().value()),
                    got: format!(
                        "input dimension {}",
                        example.input().model_dimension().value()
                    ),
                });
            }

            let (_output, cache_rows) = feed_forward_with_cache(feed_forward, example.input())?;

            for (cache, target_row) in cache_rows.iter().zip(example.target().rows()) {
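                // Gradient of the squared error `0.5 * (output - target)^2`
                // with respect to the output: `output - target`.
                let d_output = cache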
                    .output
                    .iter()
                    .zip(target_row.as_slice())
                    .map(|(output_value, target_value)| output_value - target_value)
                    .collect::<Vec<_>>();
                gradients.accumulate(feed_forward, cache, &d_output);

                row_count += 1;
            }
        }

        let updated_feed_forward =
            gradients.apply_to(feed_forward, state.learning_rate(), row_count)?;
        let updated_parameters = state
            .parameters
            .clone()
            .with_feed_forward(updated_feed_forward)?;

        Ok(state.record_updated_parameters(updated_parameters))
    }
}

/// One supervised sequence example for a composed block-level update.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerBlockTrainingExample {
    hidden: HiddenSequence,
    mask: AttentionMask,
    targets: TokenSequence,
}

impl TransformerBlockTrainingExample {
    pub fn new(
        hidden: HiddenSequence,
        mask: AttentionMask,
        targets: TokenSequence,
    ) -> CtResult<Self> {
        let sequence_len = hidden.sequence_len();

        if targets.as_slice().len() != sequence_len.value() {
            return Err(CtError::ShapeMismatch {
                op: "transformer block training targets",
                expected: format!("{} target tokens", sequence_len.value()),
                got: format!("{} target tokens", targets.as_slice().len()),
            });
        }

        if mask.query_len() != sequence_len || mask.key_len() != sequence_len {
            return Err(CtError::ShapeMismatch {
                op: "transformer block training mask",
                expected: format!(
                    "{} query rows x {} key columns",
                    sequence_len.value(),
                    sequence_len.value()
                ),
                got: format!(
                    "{} query rows x {} key columns",
                    mask.query_len().value(),
                    mask.key_len().value()
                ),
            });
        }

        Ok(Self {
            hidden,
            mask,
            targets,
        })
    }

    pub fn hidden(&self) -> &HiddenSequence {
        &self.hidden
    }

    pub fn mask(&self) -> &AttentionMask {
        &self.mask
    }

    pub fn targets(&self) -> &TokenSequence {
        &self.targets
    }
}

/// Non-empty set of supervised block-level training examples.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerBlockTrainingSet(Vec<TransformerBlockTrainingExample>);

impl TransformerBlockTrainingSet {
    pub fn new(
        examples: impl IntoIterator<Item = TransformerBlockTrainingExample>,
    ) -> CtResult<Self> {
        let examples = examples.into_iter().collect::<Vec<_>>();

        if examples.is_empty() {
            return Err(CtError::EmptyInput("transformer block training set"));
        }

        Ok(Self(examples))
    }

    pub fn examples(&self) -> &[TransformerBlockTrainingExample] {
        &self.0
    }
}

/// One full-batch update through the readout, block sublayers, and attention heads.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerBlockTrainStep {
    dataset: TransformerBlockTrainingSet,
}

impl TransformerBlockTrainStep {
    pub fn new(dataset: TransformerBlockTrainingSet) -> Self {
        Self { dataset }
    }
}

impl Morphism<TransformerTrainingState, TransformerTrainingState> for TransformerBlockTrainStep {
    fn name(&self) -> &'static str {
        "transformer_block_train_step"
    }

    fn apply(&self, state: TransformerTrainingState) -> CtResult<TransformerTrainingState> {
        let readout = state.parameters.readout();
        let feed_forward = state.parameters.feed_forward();
        let output_projection = state.parameters.output_projection();
        let attention_norm = state.parameters.attention_norm();
        let feed_forward_norm = state.parameters.feed_forward_norm();
        let mut readout_gradients = ReadoutGradients::new(readout);
        let mut feed_forward_gradients = FeedForwardGradients::new(feed_forward);
        let mut output_projection_gradients =
            AttentionOutputProjectionGradients::new(output_projection);
        let mut attention_norm_gradients = LayerNormGradients::new(attention_norm.parameters());
        let mut feed_forward_norm_gradients =
            LayerNormGradients::new(feed_forward_norm.parameters());
        let mut attention_head_gradients = state
            .parameters
            .attention_heads()
            .iter()
            .map(AttentionHeadGradients::new)
            .collect::<Vec<_>>();
        let mut position_count = 0usize;

        for example in self.dataset.examples() {
            let positioned = state
                .parameters
                .positional_encoding
                .apply(example.hidden().clone())?;
            let block_cache = state
                .parameters
                .block
                .apply_with_training_cache(positioned.clone(), example.mask().clone())?;
            let logits = readout.apply(block_cache.output.clone())?;
            let vocab_size = logits.vocab_size().value();
            let mut d_multi_head_rows = vec![
                vec![0.0; output_projection.input_dimension().value()];
                block_cache.multi_head_output.sequence_len().value()
            ];

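            // Walk the cached forward values position by position. The nested tuple
            // pattern mirrors the nesting of the `zip` calls below.
            for (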
                position,
                (
                    (((encoded_row, logit_row), with_feed_forward_row), with_attention_row),
                    ((feed_forward_cache, multi_head_row), target),
                ),
            ) in block_cache
                .output
                .rows()
                .iter()
                .zip(logits.rows())
                .zip(block_cache.with_feed_forward.rows())
                .zip(block_cache.with_attention.rows())
                .zip(
                    block_cache
                        .feed_forward_rows
                        .iter()
                        .zip(block_cache.multi_head_output.rows())
                        .zip(example.targets().as_slice()),
                )
                .enumerate()
            {
                let dlogits =
                    softmax_cross_entropy_logits_gradient(logit_row, target.index(), vocab_size)?;
                let d_encoded =
                    readout_gradients.accumulate(readout, encoded_row.as_slice(), &dlogits);
                let d_with_feed_forward = feed_forward_norm_gradients.accumulate(
                    &d_encoded,
                    with_feed_forward_row.as_slice(),
                    feed_forward_norm.parameters(),
                );

                let d_feed_forward_input = feed_forward_gradients.accumulate(
                    feed_forward,
                    feed_forward_cache,
                    &d_with_feed_forward,
                );
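                // The residual around the feed-forward sublayer sums two gradient paths:
                // the skip connection and the path through the sublayer input.
                let d_normalized_attention = add_rows(&d_with_feed_forward, &d_feed_forward_input);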
                let d_with_attention = attention_norm_gradients.accumulate(
                    &d_normalized_attention,
                    with_attention_row.as_slice(),
                    attention_norm.parameters(),
                );
                d_multi_head_rows[position] = output_projection_gradients.accumulate(
                    output_projection,
                    multi_head_row.as_slice(),
                    &d_with_attention,
                );
                position_count += 1;
            }

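            // Each head owns a contiguous `value_dimension`-wide column slice of the
            // concatenated output, so it receives the matching slice of the gradient.
            let value_dimension = block_cache.multi_head_output.head_dimension().value();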
            for (head_index, ((head_gradient, head), head_cache)) in attention_head_gradients
                .iter_mut()
                .zip(state.parameters.attention_heads())
                .zip(&block_cache.attention_heads)
                .enumerate()
            {
                let start = head_index * value_dimension;
                let end = start + value_dimension;
                let d_head_output_rows = d_multi_head_rows
                    .iter()
                    .map(|row| row[start..end].to_vec())
                    .collect::<Vec<_>>();

                head_gradient.accumulate(
                    head,
                    &positioned,
                    head_cache,
                    example.mask(),
                    &d_head_output_rows,
                )?;
            }
        }

        let updated_readout =
            readout_gradients.apply_to(readout, state.learning_rate(), position_count)?;
        let updated_feed_forward =
            feed_forward_gradients.apply_to(feed_forward, state.learning_rate(), position_count)?;
        let updated_output_projection = output_projection_gradients.apply_to(
            output_projection,
            state.learning_rate(),
            position_count,
        )?;
        let updated_attention_norm = LayerNormalization::new(attention_norm_gradients.apply_to(
            attention_norm.parameters(),
            state.learning_rate(),
            position_count,
        )?);
        let updated_feed_forward_norm =
            LayerNormalization::new(feed_forward_norm_gradients.apply_to(
                feed_forward_norm.parameters(),
                state.learning_rate(),
                position_count,
            )?);
        let updated_heads = state
            .parameters
            .attention_heads()
            .iter()
            .zip(attention_head_gradients)
            .map(|(head, gradients)| {
                gradients.apply_to(head, state.learning_rate(), position_count)
            })
            .collect::<CtResult<Vec<_>>>()?;
        let updated_parameters = state
            .parameters
            .clone()
            .with_readout(updated_readout)?
            .with_feed_forward(updated_feed_forward)?
            .with_output_projection(updated_output_projection)?
            .with_layer_norms(updated_attention_norm, updated_feed_forward_norm)?
            .with_attention_heads(updated_heads)?;

        Ok(state.record_updated_parameters(updated_parameters))
    }
}

/// Average sequence cross-entropy for the structured Transformer readout.
pub fn transformer_readout_average_loss(
    state: &TransformerTrainingState,
    dataset: &TransformerReadoutTrainingSet,
) -> CtResult<Loss> {
    let mut total = 0.0;
    let mut position_count = 0usize;

    for example in dataset.examples() {
        let logits = state.apply(Product::new(
            example.hidden().clone(),
            example.mask().clone(),
        ))?;
        let vocab_size = logits.vocab_size().value();

        for (logit_row, target) in logits.rows().iter().zip(example.targets().as_slice()) {
            let target_index = target.index();

            if target_index >= vocab_size {
                return Err(CtError::OutOfRange {
                    kind: "sequence target",
                    index: target_index,
                    limit: vocab_size,
                });
            }

            let probabilities = Softmax.apply(logit_row.clone())?;
            // Clamp away from zero so `ln` stays finite.
            let probability = probabilities.as_slice()[target_index].max(1e-9);
            total += -probability.ln();
            position_count += 1;
        }
    }

    Loss::new(total / position_count as f32)
}

/// Average squared error for local feed-forward sublayer training.
pub fn transformer_feed_forward_average_loss(
    state: &TransformerTrainingState,
    dataset: &TransformerFeedForwardTrainingSet,
) -> CtResult<Loss> {
    let feed_forward = state.parameters.feed_forward();
    let mut total = 0.0;
    let mut value_count = 0usize;

    for example in dataset.examples() {
        let output = feed_forward.apply(example.input().clone())?;

        for (output_row, target_row) in output.rows().iter().zip(example.target().rows()) {
            for (output_value, target_value) in
                output_row.as_slice().iter().zip(target_row.as_slice())
            {
                let error = output_value - target_value;
                total += 0.5 * error * error;
                value_count += 1;
            }
        }
    }

    Loss::new(total / value_count as f32)
}

/// Average sequence cross-entropy for the composed block-level training set.
pub fn transformer_block_average_loss(
    state: &TransformerTrainingState,
    dataset: &TransformerBlockTrainingSet,
) -> CtResult<Loss> {
    let mut total = 0.0;
    let mut position_count = 0usize;

    for example in dataset.examples() {
        let logits = state.apply(Product::new(
            example.hidden().clone(),
            example.mask().clone(),
        ))?;
        let vocab_size = logits.vocab_size().value();

        for (logit_row, target) in logits.rows().iter().zip(example.targets().as_slice()) {
            let target_index = target.index();

            if target_index >= vocab_size {
                return Err(CtError::OutOfRange {
                    kind: "sequence target",
                    index: target_index,
                    limit: vocab_size,
                });
            }

            let probabilities = Softmax.apply(logit_row.clone())?;
            // Clamp away from zero so `ln` stays finite.
            let probability = probabilities.as_slice()[target_index].max(1e-9);
            total += -probability.ln();
            position_count += 1;
        }
    }

    Loss::new(total / position_count as f32)
}

/// Applies softmax independently to each query row.
#[derive(Debug, Clone)]
pub struct AttentionSoftmax;

impl Morphism<AttentionScores, AttentionWeights> for AttentionSoftmax {
    fn name(&self) -> &'static str {
        "attention_softmax"
    }

    fn apply(&self, scores: AttentionScores) -> CtResult<AttentionWeights> {
        let rows = scores
            .rows
            .into_iter()
            .map(|row| Softmax.apply(row))
            .collect::<CtResult<Vec<_>>>()?;

        AttentionWeights::new(rows)
    }
}

/// Mixes value vectors with row-wise attention weights.
#[derive(Debug, Clone)]
pub struct WeightedValueMixing;

impl Morphism<Product<AttentionWeights, ValueSequence>, AttentionOutput> for WeightedValueMixing {
    fn name(&self) -> &'static str {
        "weighted_value_mixing"
    }

    fn apply(&self, input: Product<AttentionWeights, ValueSequence>) -> CtResult<AttentionOutput> {
        let (weights, values) = input.into_parts();

        if weights.key_len() != values.sequence_len() {
            return Err(CtError::ShapeMismatch {
                op: "weighted value mixing",
                expected: format!("{} value rows", weights.key_len().value()),
                got: format!("{} value rows", values.sequence_len().value()),
            });
        }

        let value_width = values.head_dimension().value();
        let rows = weights
            .rows()
            .iter()
            .map(|weight_row| weighted_sum(weight_row.as_slice(), values.rows(), value_width))
            .collect::<Vec<_>>();

        AttentionOutput::new(rows)
    }
}

/// Computes `output[j] = sum over i of weights[i] * values[i][j]` for one query row.
fn weighted_sum(weights: &[f32], values: &[Vector], value_width: usize) -> Vec<f32> {
    let mut output = vec![0.0; value_width];

    for (weight, value) in weights.iter().zip(values.iter()) {
        for (output_value, value_component) in output.iter_mut().zip(value.as_slice()) {
            *output_value += weight * value_component;
        }
    }

    output
}

/// Concatenates single-head outputs along the feature dimension.
#[derive(Debug, Clone)]
pub struct ConcatenateHeads;

impl Morphism<AttentionHeadOutputs, MultiHeadOutput> for ConcatenateHeads {
    fn name(&self) -> &'static str {
        "concatenate_heads"
    }

    fn apply(&self, heads: AttentionHeadOutputs) -> CtResult<MultiHeadOutput> {
        let rows = (0..heads.sequence_len().value())
            .map(|position| {
                heads
                    .heads()
                    .iter()
                    .flat_map(|head| head.rows()[position].as_slice().iter().copied())
                    .collect::<Vec<_>>()
            })
            .collect::<Vec<_>>();

        MultiHeadOutput::new(rows, heads.head_count(), heads.head_dimension())
    }
}

// Applies the output projection (weight and bias) to each concatenated multi-head row.
impl Morphism<MultiHeadOutput, ProjectedAttentionOutput> for AttentionOutputProjection {
    fn name(&self) -> &'static str {
        "attention_output_projection"
    }

    fn apply(&self, input: MultiHeadOutput) -> CtResult<ProjectedAttentionOutput> {
        if input.model_dimension() != self.input_dimension {
            return Err(CtError::ShapeMismatch {
                op: "attention output projection",
                expected: format!("input dimension {}", self.input_dimension.value()),
                got: format!("input dimension {}", input.model_dimension().value()),
            });
        }

        let rows = input
            .rows()
            .iter()
            .map(|row| project_row(row.as_slice(), &self.weight, &self.bias))
            .collect::<Vec<_>>();

        ProjectedAttentionOutput::new(rows)
    }
}

/// Adds a same-shaped sublayer output back to the hidden sequence.
#[derive(Debug, Clone)]
pub struct ResidualConnection;

impl Morphism<Product<HiddenSequence, ProjectedAttentionOutput>, HiddenSequence>
    for ResidualConnection
{
    fn name(&self) -> &'static str {
        "residual_connection"
    }

    fn apply(
        &self,
        input: Product<HiddenSequence, ProjectedAttentionOutput>,
    ) -> CtResult<HiddenSequence> {
        let (hidden, sublayer_output) = input.into_parts();

        if hidden.sequence_len() != sublayer_output.sequence_len() {
            return Err(CtError::ShapeMismatch {
                op: "residual connection",
                expected: format!("{} sequence rows", hidden.sequence_len().value()),
                got: format!("{} sequence rows", sublayer_output.sequence_len().value()),
            });
        }

        if hidden.model_dimension() != sublayer_output.model_dimension() {
            return Err(CtError::ShapeMismatch {
                op: "residual connection",
                expected: format!("model dimension {}", hidden.model_dimension().value()),
                got: format!(
                    "model dimension {}",
                    sublayer_output.model_dimension().value()
                ),
            });
        }

        let rows = hidden
            .rows()
            .iter()
            .zip(sublayer_output.rows())
            .map(|(left, right)| add_rows(left.as_slice(), right.as_slice()))
            .collect::<Vec<_>>();

        HiddenSequence::new(rows)
    }
}

impl Morphism<HiddenSequence, HiddenSequence> for LayerNormalization {
    fn name(&self) -> &'static str {
        "layer_normalization"
    }

    fn apply(&self, input: HiddenSequence) -> CtResult<HiddenSequence> {
        if input.model_dimension() != self.parameters.model_dimension() {
            return Err(CtError::ShapeMismatch {
                op: "layer normalization",
                expected: format!(
                    "model dimension {}",
                    self.parameters.model_dimension().value()
                ),
                got: format!("model dimension {}", input.model_dimension().value()),
            });
        }

        let rows = input
            .rows()
            .iter()
            .map(|row| normalize_row(row.as_slice(), &self.parameters))
            .collect::<Vec<_>>();

        HiddenSequence::new(rows)
    }
}

impl Morphism<HiddenSequence, HiddenSequence> for PositionWiseFeedForward {
    fn name(&self) -> &'static str {
        "position_wise_feed_forward"
    }

    fn apply(&self, input: HiddenSequence) -> CtResult<HiddenSequence> {
        let (output, _cache) = feed_forward_with_cache(self, &input)?;

        Ok(output)
    }
}

impl Morphism<HiddenSequence, HiddenSequence> for PositionalEncoding {
    fn name(&self) -> &'static str {
        "positional_encoding"
    }

    fn apply(&self, input: HiddenSequence) -> CtResult<HiddenSequence> {
        if input.sequence_len().value() > self.max_sequence_len.value() {
            return Err(CtError::ShapeMismatch {
                op: "positional encoding",
                expected: format!("at most {} sequence rows", self.max_sequence_len.value()),
                got: format!("{} sequence rows", input.sequence_len().value()),
            });
        }

        if input.model_dimension() != self.model_dimension {
            return Err(CtError::ShapeMismatch {
                op: "positional encoding",
                expected: format!("model dimension {}", self.model_dimension.value()),
                got: format!("model dimension {}", input.model_dimension().value()),
            });
        }

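        // `zip` stops at the shorter side, so a sequence shorter than
        // `max_sequence_len` uses only the leading position rows.
        let rows = input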
            .rows()
            .iter()
            .zip(&self.rows)
            .map(|(hidden_row, position_row)| {
                add_rows(hidden_row.as_slice(), position_row.as_slice())
            })
            .collect::<Vec<_>>();

        HiddenSequence::new(rows)
    }
}

impl Morphism<HiddenSequence, SequenceLogits> for TransformerReadout {
    fn name(&self) -> &'static str {
        "transformer_readout"
    }

    fn apply(&self, input: HiddenSequence) -> CtResult<SequenceLogits> {
        if input.model_dimension() != self.input_dimension {
            return Err(CtError::ShapeMismatch {
                op: "transformer readout",
                expected: format!("input dimension {}", self.input_dimension.value()),
                got: format!("input dimension {}", input.model_dimension().value()),
            });
        }

        let rows = input
            .rows()
            .iter()
            .map(|row| project_row(row.as_slice(), &self.weight, &self.bias))
            .collect::<Vec<_>>();

        SequenceLogits::new(rows)
    }
}

impl Morphism<HiddenSequence, QuerySequence> for HiddenToQuery {
    fn name(&self) -> &'static str {
        "hidden_to_query"
    }

    fn apply(&self, input: HiddenSequence) -> CtResult<QuerySequence> {
        QuerySequence::new(self.projection.project(&input)?)
    }
}

impl Morphism<HiddenSequence, KeySequence> for HiddenToKey {
    fn name(&self) -> &'static str {
        "hidden_to_key"
    }

    fn apply(&self, input: HiddenSequence) -> CtResult<KeySequence> {
        KeySequence::new(self.projection.project(&input)?)
    }
}

impl Morphism<HiddenSequence, ValueSequence> for HiddenToValue {
    fn name(&self) -> &'static str {
        "hidden_to_value"
    }

    fn apply(&self, input: HiddenSequence) -> CtResult<ValueSequence> {
        ValueSequence::new(self.projection.project(&input)?)
    }
}

fn apply_self_attention_head(
    input: &HiddenSequence,
    head: &SelfAttentionHead,
) -> CtResult<AttentionOutput> {
    apply_self_attention_head_with_mask(input, head, None)
}

fn apply_self_attention_head_with_mask(
    input: &HiddenSequence,
    head: &SelfAttentionHead,
    mask: Option<&AttentionMask>,
) -> CtResult<AttentionOutput> {
    Ok(apply_self_attention_head_with_mask_cache(input, head, mask)?.output)
}

fn apply_self_attention_head_with_mask_cache(
    input: &HiddenSequence,
    head: &SelfAttentionHead,
    mask: Option<&AttentionMask>,
) -> CtResult<AttentionHeadTrainingCache> {
    let queries = head.query_projection.apply(input.clone())?;
    let keys = head.key_projection.apply(input.clone())?;
    let values = head.value_projection.apply(input.clone())?;
    let scores = ScaledDotProductScores.apply(Product::new(queries.clone(), keys.clone()))?;
    let scores = if let Some(mask) = mask {
        MaskedAttentionScores.apply(Product::new(scores, mask.clone()))?
    } else {
        scores
    };
    let weights = AttentionSoftmax.apply(scores)?;
    let output = WeightedValueMixing.apply(Product::new(weights.clone(), values.clone()))?;

    Ok(AttentionHeadTrainingCache {
        queries,
        keys,
        values,
        weights,
        output,
    })
}

impl Morphism<Product<HiddenSequence, HiddenSequence>, HiddenSequence> for ResidualConnection {
    fn name(&self) -> &'static str {
        "hidden_residual_connection"
    }

    fn apply(&self, input: Product<HiddenSequence, HiddenSequence>) -> CtResult<HiddenSequence> {
        let (left, right) = input.into_parts();

        if left.sequence_len() != right.sequence_len() {
            return Err(CtError::ShapeMismatch {
                op: "hidden residual connection",
                expected: format!("{} sequence rows", left.sequence_len().value()),
                got: format!("{} sequence rows", right.sequence_len().value()),
            });
        }

        if left.model_dimension() != right.model_dimension() {
            return Err(CtError::ShapeMismatch {
                op: "hidden residual connection",
                expected: format!("model dimension {}", left.model_dimension().value()),
                got: format!("model dimension {}", right.model_dimension().value()),
            });
        }

        let rows = left
            .rows()
            .iter()
            .zip(right.rows())
            .map(|(left, right)| add_rows(left.as_slice(), right.as_slice()))
            .collect::<Vec<_>>();

        HiddenSequence::new(rows)
    }
}

impl Morphism<HiddenSequence, HiddenSequence> for SingleHeadTransformerBlock {
    fn name(&self) -> &'static str {
        "single_head_transformer_block"
    }

    fn apply(&self, input: HiddenSequence) -> CtResult<HiddenSequence> {
        if input.model_dimension() != self.model_dimension {
            return Err(CtError::ShapeMismatch {
                op: "single-head block",
                expected: format!("model dimension {}", self.model_dimension.value()),
                got: format!("model dimension {}", input.model_dimension().value()),
            });
        }

        let head = SelfAttentionHead::new(
            self.query_projection.clone(),
            self.key_projection.clone(),
            self.value_projection.clone(),
        )?;
        let attention_output = apply_self_attention_head(&input, &head)?;
        let head_outputs = AttentionHeadOutputs::new(vec![attention_output])?;
        let multi_head_output = ConcatenateHeads.apply(head_outputs)?;
        let projected_attention = self.output_projection.apply(multi_head_output)?;
        let with_attention = ResidualConnection.apply(Product::new(input, projected_attention))?;
        let normalized_attention = self.attention_norm.apply(with_attention)?;
        let feed_forward_output = self.feed_forward.apply(normalized_attention.clone())?;
        let with_feed_forward =
            ResidualConnection.apply(Product::new(normalized_attention, feed_forward_output))?;

        self.feed_forward_norm.apply(with_feed_forward)
    }
}

impl Morphism<HiddenSequence, HiddenSequence> for MultiHeadTransformerBlock {
    fn name(&self) -> &'static str {
        "multi_head_transformer_block"
    }

    fn apply(&self, input: HiddenSequence) -> CtResult<HiddenSequence> {
        if input.model_dimension() != self.model_dimension {
            return Err(CtError::ShapeMismatch {
                op: "multi-head block",
                expected: format!("model dimension {}", self.model_dimension.value()),
                got: format!("model dimension {}", input.model_dimension().value()),
            });
        }

        let attention_outputs = self
            .heads
            .iter()
            .map(|head| apply_self_attention_head(&input, head))
            .collect::<CtResult<Vec<_>>>()?;
        let head_outputs = AttentionHeadOutputs::new(attention_outputs)?;
        let multi_head_output = ConcatenateHeads.apply(head_outputs)?;
        let projected_attention = self.output_projection.apply(multi_head_output)?;
        let with_attention = ResidualConnection.apply(Product::new(input, projected_attention))?;
        let normalized_attention = self.attention_norm.apply(with_attention)?;
        let feed_forward_output = self.feed_forward.apply(normalized_attention.clone())?;
        let with_feed_forward =
            ResidualConnection.apply(Product::new(normalized_attention, feed_forward_output))?;

        self.feed_forward_norm.apply(with_feed_forward)
    }
}

impl Morphism<Product<HiddenSequence, AttentionMask>, HiddenSequence>
    for MaskedMultiHeadTransformerBlock
{
    fn name(&self) -> &'static str {
        "masked_multi_head_transformer_block"
    }

    fn apply(&self, input: Product<HiddenSequence, AttentionMask>) -> CtResult<HiddenSequence> {
        let (hidden, mask) = input.into_parts();

        Ok(self.apply_with_training_cache(hidden, mask)?.output)
    }
}

impl Morphism<Product<HiddenSequence, AttentionMask>, SequenceLogits>
    for TinyTransformerParameters
{
    fn name(&self) -> &'static str {
        "tiny_transformer_parameters"
    }

    fn apply(&self, input: Product<HiddenSequence, AttentionMask>) -> CtResult<SequenceLogits> {
        let (hidden, mask) = input.into_parts();
        let positioned = self.positional_encoding.apply(hidden)?;
        let encoded = self.block.apply(Product::new(positioned, mask))?;

        self.readout.apply(encoded)
    }
}

impl Morphism<Product<HiddenSequence, AttentionMask>, SequenceLogits> for TransformerTrainingState {
    fn name(&self) -> &'static str {
        "transformer_training_state_forward"
    }

    fn apply(&self, input: Product<HiddenSequence, AttentionMask>) -> CtResult<SequenceLogits> {
        self.parameters.apply(input)
    }
}

#[derive(Debug, Clone, PartialEq)]
struct ReadoutGradients {
    weight: Vec<Vec<f32>>,
    bias: Vec<f32>,
}

impl ReadoutGradients {
    fn new(readout: &TransformerReadout) -> Self {
        Self {
            weight: vec![
                vec![0.0; readout.vocab_size().value()];
                readout.input_dimension().value()
            ],
            bias: vec![0.0; readout.vocab_size().value()],
        }
    }

    fn accumulate(
        &mut self,
        readout: &TransformerReadout,
        hidden_row: &[f32],
        dlogits: &[f32],
    ) -> Vec<f32> {
        let mut d_hidden = vec![0.0; readout.input_dimension().value()];

        for (vocab_id, dlogit) in dlogits.iter().copied().enumerate() {
            self.bias[vocab_id] += dlogit;

            for (feature, hidden_value) in hidden_row.iter().copied().enumerate() {
                self.weight[feature][vocab_id] += hidden_value * dlogit;
                d_hidden[feature] += readout.weight()[feature][vocab_id] * dlogit;
            }
        }

        d_hidden
    }

    fn apply_to(
        self,
        readout: &TransformerReadout,
        learning_rate: LearningRate,
        position_count: usize,
    ) -> CtResult<TransformerReadout> {
        if position_count == 0 {
            return Err(CtError::EmptyInput("readout gradient positions"));
        }

        let scale = learning_rate.value() / position_count as f32;
        let mut updated_weight = readout.weight().to_vec();
        let mut updated_bias = readout.bias().to_vec();

        for (row, grad_row) in updated_weight.iter_mut().zip(&self.weight) {
            for (weight, grad) in row.iter_mut().zip(grad_row) {
                *weight -= scale * grad;
            }
        }

        for (bias, grad) in updated_bias.iter_mut().zip(&self.bias) {
            *bias -= scale * grad;
        }

        TransformerReadout::new(updated_weight, updated_bias)
    }
}

#[derive(Debug, Clone, PartialEq)]
struct FeedForwardGradients {
    first_weight: Vec<Vec<f32>>,
    first_bias: Vec<f32>,
    second_weight: Vec<Vec<f32>>,
    second_bias: Vec<f32>,
}

impl FeedForwardGradients {
    fn new(feed_forward: &PositionWiseFeedForward) -> Self {
        Self {
            first_weight: vec![
                vec![0.0; feed_forward.hidden_dimension().value()];
                feed_forward.input_dimension().value()
            ],
            first_bias: vec![0.0; feed_forward.hidden_dimension().value()],
            second_weight: vec![
                vec![0.0; feed_forward.output_dimension().value()];
                feed_forward.hidden_dimension().value()
            ],
            second_bias: vec![0.0; feed_forward.output_dimension().value()],
        }
    }

    fn accumulate(
        &mut self,
        feed_forward: &PositionWiseFeedForward,
        cache: &FeedForwardRowCache,
        d_output: &[f32],
    ) -> Vec<f32> {
        let mut d_activation = vec![0.0; feed_forward.hidden_dimension().value()];

        for (output_id, d_output_value) in d_output.iter().copied().enumerate() {
            self.second_bias[output_id] += d_output_value;

            for (hidden_id, activation_value) in cache.activation.iter().copied().enumerate() {
                self.second_weight[hidden_id][output_id] += activation_value * d_output_value;
                d_activation[hidden_id] +=
                    feed_forward.second_weight()[hidden_id][output_id] * d_output_value;
            }
        }

        let mut d_input = vec![0.0; feed_forward.input_dimension().value()];

        for (hidden_id, pre_activation_value) in cache.pre_activation.iter().copied().enumerate() {
            let d_pre_activation = if pre_activation_value > 0.0 {
                d_activation[hidden_id]
            } else {
                0.0
            };
            self.first_bias[hidden_id] += d_pre_activation;

            for (input_id, input_value) in cache.input.iter().copied().enumerate() {
                self.first_weight[input_id][hidden_id] += input_value * d_pre_activation;
                d_input[input_id] +=
                    feed_forward.first_weight()[input_id][hidden_id] * d_pre_activation;
            }
        }

        d_input
    }

    fn apply_to(
        self,
        feed_forward: &PositionWiseFeedForward,
        learning_rate: LearningRate,
        row_count: usize,
    ) -> CtResult<PositionWiseFeedForward> {
        if row_count == 0 {
            return Err(CtError::EmptyInput("feed-forward gradient rows"));
        }

        let scale = learning_rate.value() / row_count as f32;
        let mut updated_first_weight = feed_forward.first_weight().to_vec();
        let mut updated_first_bias = feed_forward.first_bias().to_vec();
        let mut updated_second_weight = feed_forward.second_weight().to_vec();
        let mut updated_second_bias = feed_forward.second_bias().to_vec();

        for (row, grad_row) in updated_first_weight.iter_mut().zip(&self.first_weight) {
            for (weight, grad) in row.iter_mut().zip(grad_row) {
                *weight -= scale * grad;
            }
        }

        for (bias, grad) in updated_first_bias.iter_mut().zip(&self.first_bias) {
            *bias -= scale * grad;
        }

        for (row, grad_row) in updated_second_weight.iter_mut().zip(&self.second_weight) {
            for (weight, grad) in row.iter_mut().zip(grad_row) {
                *weight -= scale * grad;
            }
        }

        for (bias, grad) in updated_second_bias.iter_mut().zip(&self.second_bias) {
            *bias -= scale * grad;
        }

        PositionWiseFeedForward::new(
            updated_first_weight,
            updated_first_bias,
            updated_second_weight,
            updated_second_bias,
        )
    }
}

#[derive(Debug, Clone, PartialEq)]
struct AttentionOutputProjectionGradients {
    weight: Vec<Vec<f32>>,
    bias: Vec<f32>,
}

impl AttentionOutputProjectionGradients {
    fn new(output_projection: &AttentionOutputProjection) -> Self {
        Self {
            weight: vec![
                vec![0.0; output_projection.output_dimension().value()];
                output_projection.input_dimension().value()
            ],
            bias: vec![0.0; output_projection.output_dimension().value()],
        }
    }

    fn accumulate(
        &mut self,
        output_projection: &AttentionOutputProjection,
        multi_head_row: &[f32],
        d_projected_attention: &[f32],
    ) -> Vec<f32> {
        let mut d_multi_head = vec![0.0; output_projection.input_dimension().value()];

        for (output_id, d_value) in d_projected_attention.iter().copied().enumerate() {
            self.bias[output_id] += d_value;

            for (input_id, input_value) in multi_head_row.iter().copied().enumerate() {
                self.weight[input_id][output_id] += input_value * d_value;
                d_multi_head[input_id] += output_projection.weight()[input_id][output_id] * d_value;
            }
        }

        d_multi_head
    }

    fn apply_to(
        self,
        output_projection: &AttentionOutputProjection,
        learning_rate: LearningRate,
        row_count: usize,
    ) -> CtResult<AttentionOutputProjection> {
        if row_count == 0 {
            return Err(CtError::EmptyInput(
                "attention output projection gradient rows",
            ));
        }

        let scale = learning_rate.value() / row_count as f32;
        let mut updated_weight = output_projection.weight().to_vec();
        let mut updated_bias = output_projection.bias().to_vec();

        for (row, grad_row) in updated_weight.iter_mut().zip(&self.weight) {
            for (weight, grad) in row.iter_mut().zip(grad_row) {
                *weight -= scale * grad;
            }
        }

        for (bias, grad) in updated_bias.iter_mut().zip(&self.bias) {
            *bias -= scale * grad;
        }

        AttentionOutputProjection::new(updated_weight, updated_bias)
    }
}

#[derive(Debug, Clone, PartialEq)]
struct HiddenProjectionGradients {
    weight: Vec<Vec<f32>>,
    bias: Vec<f32>,
}

impl HiddenProjectionGradients {
    fn new(projection: &HiddenProjection) -> Self {
        Self {
            weight: vec![
                vec![0.0; projection.head_dimension().value()];
                projection.input_dimension().value()
            ],
            bias: vec![0.0; projection.head_dimension().value()],
        }
    }

    fn accumulate(
        &mut self,
        projection: &HiddenProjection,
        hidden_row: &[f32],
        d_output: &[f32],
    ) -> CtResult<()> {
        if hidden_row.len() != projection.input_dimension().value() {
            return Err(CtError::ShapeMismatch {
                op: projection.op,
                expected: format!("input dimension {}", projection.input_dimension().value()),
                got: format!("input dimension {}", hidden_row.len()),
            });
        }

        if d_output.len() != projection.head_dimension().value() {
            return Err(CtError::ShapeMismatch {
                op: projection.op,
                expected: format!("output dimension {}", projection.head_dimension().value()),
                got: format!("output dimension {}", d_output.len()),
            });
        }

        for (output_id, d_output_value) in d_output.iter().copied().enumerate() {
            self.bias[output_id] += d_output_value;

            for (input_id, input_value) in hidden_row.iter().copied().enumerate() {
                self.weight[input_id][output_id] += input_value * d_output_value;
            }
        }

        Ok(())
    }

    fn updated_parts(
        self,
        projection: &HiddenProjection,
        learning_rate: LearningRate,
        row_count: usize,
    ) -> CtResult<(Vec<Vec<f32>>, Vec<f32>)> {
        if row_count == 0 {
            return Err(CtError::EmptyInput("hidden projection gradient rows"));
        }

        let scale = learning_rate.value() / row_count as f32;
        let mut updated_weight = projection.weight().to_vec();
        let mut updated_bias = projection.bias().to_vec();

        for (row, grad_row) in updated_weight.iter_mut().zip(&self.weight) {
            for (weight, grad) in row.iter_mut().zip(grad_row) {
                *weight -= scale * grad;
            }
        }

        for (bias, grad) in updated_bias.iter_mut().zip(&self.bias) {
            *bias -= scale * grad;
        }

        Ok((updated_weight, updated_bias))
    }
}

#[derive(Debug, Clone, PartialEq)]
struct AttentionHeadGradients {
    query: HiddenProjectionGradients,
    key: HiddenProjectionGradients,
    value: HiddenProjectionGradients,
}

impl AttentionHeadGradients {
    fn new(head: &SelfAttentionHead) -> Self {
        Self {
            query: HiddenProjectionGradients::new(&head.query_projection.projection),
            key: HiddenProjectionGradients::new(&head.key_projection.projection),
            value: HiddenProjectionGradients::new(&head.value_projection.projection),
        }
    }

    fn accumulate(
        &mut self,
        head: &SelfAttentionHead,
        input: &HiddenSequence,
        cache: &AttentionHeadTrainingCache,
        mask: &AttentionMask,
        d_output_rows: &[Vec<f32>],
    ) -> CtResult<()> {
        let sequence_len = input.sequence_len().value();
        let value_dimension = head.value_dimension().value();
        let query_key_dimension = head.query_key_dimension().value();

        if d_output_rows.len() != sequence_len {
            return Err(CtError::ShapeMismatch {
                op: "attention head gradients",
                expected: format!("{sequence_len} output rows"),
                got: format!("{} output rows", d_output_rows.len()),
            });
        }

        if mask.query_len().value() != sequence_len || mask.key_len().value() != sequence_len {
            return Err(CtError::ShapeMismatch {
                op: "attention head gradient mask",
                expected: format!("{sequence_len} query rows x {sequence_len} key columns"),
                got: format!(
                    "{} query rows x {} key columns",
                    mask.query_len().value(),
                    mask.key_len().value()
                ),
            });
        }

        let mut d_weights = vec![vec![0.0; sequence_len]; sequence_len];
        let mut d_values = vec![vec![0.0; value_dimension]; sequence_len];

        for (query_id, d_output) in d_output_rows.iter().enumerate() {
            if d_output.len() != value_dimension {
                return Err(CtError::ShapeMismatch {
                    op: "attention head output gradient",
                    expected: format!("value dimension {value_dimension}"),
                    got: format!("value dimension {}", d_output.len()),
                });
            }

            for key_id in 0..sequence_len {
                let value_row = cache.values.rows()[key_id].as_slice();
                let weight = cache.weights.rows()[query_id].as_slice()[key_id];

                for value_id in 0..value_dimension {
                    d_weights[query_id][key_id] += d_output[value_id] * value_row[value_id];
                    d_values[key_id][value_id] += weight * d_output[value_id];
                }
            }
        }

        let mut d_scores = vec![vec![0.0; sequence_len]; sequence_len];

        for query_id in 0..sequence_len {
            let weight_row = cache.weights.rows()[query_id].as_slice();
            let row_dot = d_weights[query_id]
                .iter()
                .zip(weight_row)
                .map(|(grad, weight)| grad * weight)
                .sum::<f32>();

            for key_id in 0..sequence_len {
                if mask.rows()[query_id][key_id] {
                    d_scores[query_id][key_id] =
                        weight_row[key_id] * (d_weights[query_id][key_id] - row_dot);
                }
            }
        }

        let score_scale = (query_key_dimension as f32).sqrt();
        let mut d_queries = vec![vec![0.0; query_key_dimension]; sequence_len];
        let mut d_keys = vec![vec![0.0; query_key_dimension]; sequence_len];

        for query_id in 0..sequence_len {
            let query_row = cache.queries.rows()[query_id].as_slice();

            for key_id in 0..sequence_len {
                let score_gradient = d_scores[query_id][key_id] / score_scale;
                let key_row = cache.keys.rows()[key_id].as_slice();

                for feature in 0..query_key_dimension {
                    d_queries[query_id][feature] += score_gradient * key_row[feature];
                    d_keys[key_id][feature] += score_gradient * query_row[feature];
                }
            }
        }

        for position in 0..sequence_len {
            let hidden_row = input.rows()[position].as_slice();

            self.query.accumulate(
                &head.query_projection.projection,
                hidden_row,
                &d_queries[position],
            )?;
            self.key.accumulate(
                &head.key_projection.projection,
                hidden_row,
                &d_keys[position],
            )?;
            self.value.accumulate(
                &head.value_projection.projection,
                hidden_row,
                &d_values[position],
            )?;
        }

        Ok(())
    }

    fn apply_to(
        self,
        head: &SelfAttentionHead,
        learning_rate: LearningRate,
        row_count: usize,
    ) -> CtResult<SelfAttentionHead> {
        let (query_weight, query_bias) = self.query.updated_parts(
            &head.query_projection.projection,
            learning_rate,
            row_count,
        )?;
        let (key_weight, key_bias) =
            self.key
                .updated_parts(&head.key_projection.projection, learning_rate, row_count)?;
        let (value_weight, value_bias) = self.value.updated_parts(
            &head.value_projection.projection,
            learning_rate,
            row_count,
        )?;

        SelfAttentionHead::new(
            HiddenToQuery::new(query_weight, query_bias)?,
            HiddenToKey::new(key_weight, key_bias)?,
            HiddenToValue::new(value_weight, value_bias)?,
        )
    }
}

#[derive(Debug, Clone, PartialEq)]
struct LayerNormGradients {
    scale: Vec<f32>,
    shift: Vec<f32>,
}

impl LayerNormGradients {
    fn new(parameters: &LayerNormParameters) -> Self {
        Self {
            scale: vec![0.0; parameters.model_dimension().value()],
            shift: vec![0.0; parameters.model_dimension().value()],
        }
    }

    fn accumulate(
        &mut self,
        d_output: &[f32],
        input: &[f32],
        parameters: &LayerNormParameters,
    ) -> Vec<f32> {
        let stats = layer_norm_stats(input, parameters);

        for (feature, (grad, normalized_value)) in
            d_output.iter().zip(&stats.normalized).enumerate()
        {
            self.scale[feature] += grad * normalized_value;
            self.shift[feature] += grad;
        }

        layer_norm_backward_from_stats(d_output, &stats, parameters)
    }

    fn apply_to(
        self,
        parameters: &LayerNormParameters,
        learning_rate: LearningRate,
        row_count: usize,
    ) -> CtResult<LayerNormParameters> {
        if row_count == 0 {
            return Err(CtError::EmptyInput("layer norm gradient rows"));
        }

        let scale = learning_rate.value() / row_count as f32;
        let mut updated_scale = parameters.scale().to_vec();
        let mut updated_shift = parameters.shift().to_vec();

        for (value, grad) in updated_scale.iter_mut().zip(&self.scale) {
            *value -= scale * grad;
        }

        for (value, grad) in updated_shift.iter_mut().zip(&self.shift) {
            *value -= scale * grad;
        }

        LayerNormParameters::new(updated_scale, updated_shift, parameters.epsilon())
    }
}

fn feed_forward_with_cache(
    feed_forward: &PositionWiseFeedForward,
    input: &HiddenSequence,
) -> CtResult<(HiddenSequence, Vec<FeedForwardRowCache>)> {
    if input.model_dimension() != feed_forward.input_dimension() {
        return Err(CtError::ShapeMismatch {
            op: "position-wise feed-forward",
            expected: format!("input dimension {}", feed_forward.input_dimension().value()),
            got: format!("input dimension {}", input.model_dimension().value()),
        });
    }

    let mut output_rows = Vec::with_capacity(input.rows().len());
    let mut cache_rows = Vec::with_capacity(input.rows().len());

    for row in input.rows() {
        let input_row = row.as_slice().to_vec();
        let pre_activation = project_row(
            &input_row,
            feed_forward.first_weight(),
            feed_forward.first_bias(),
        );
        let activation = pre_activation
            .iter()
            .map(|value| value.max(0.0))
            .collect::<Vec<_>>();
        let output = project_row(
            &activation,
            feed_forward.second_weight(),
            feed_forward.second_bias(),
        );

        output_rows.push(output.clone());
        cache_rows.push(FeedForwardRowCache {
            input: input_row,
            pre_activation,
            activation,
            output,
        });
    }

    Ok((HiddenSequence::new(output_rows)?, cache_rows))
}

fn softmax_cross_entropy_logits_gradient(
    logits: &Logits,
    target_index: usize,
    vocab_size: usize,
) -> CtResult<Vec<f32>> {
    if target_index >= vocab_size {
        return Err(CtError::OutOfRange {
            kind: "sequence target",
            index: target_index,
            limit: vocab_size,
        });
    }

    let probabilities = Softmax.apply(logits.clone())?;
    let mut dlogits = probabilities.as_slice().to_vec();
    dlogits[target_index] -= 1.0;

    Ok(dlogits)
}

#[derive(Debug, Clone, PartialEq)]
struct LayerNormStats {
    dimension: f32,
    inverse_std: f32,
    normalized: Vec<f32>,
}

fn layer_norm_stats(input: &[f32], parameters: &LayerNormParameters) -> LayerNormStats {
    let dimension = input.len() as f32;
    let mean = input.iter().sum::<f32>() / dimension;
    let variance = input
        .iter()
        .map(|value| {
            let centered = value - mean;
            centered * centered
        })
        .sum::<f32>()
        / dimension;
    let inverse_std = 1.0 / (variance + parameters.epsilon().value()).sqrt();
    let normalized = input
        .iter()
        .map(|value| (value - mean) * inverse_std)
        .collect::<Vec<_>>();

    LayerNormStats {
        dimension,
        inverse_std,
        normalized,
    }
}

fn layer_norm_backward_from_stats(
    d_output: &[f32],
    stats: &LayerNormStats,
    parameters: &LayerNormParameters,
) -> Vec<f32> {
    let d_normalized = d_output
        .iter()
        .zip(parameters.scale())
        .map(|(grad, scale)| grad * scale)
        .collect::<Vec<_>>();
    let sum_d_normalized = d_normalized.iter().sum::<f32>();
    let sum_d_normalized_times_normalized = d_normalized
        .iter()
        .zip(&stats.normalized)
        .map(|(grad, normalized_value)| grad * normalized_value)
        .sum::<f32>();

    d_normalized
        .iter()
        .zip(&stats.normalized)
        .map(|(grad, normalized_value)| {
            (stats.dimension * grad
                - sum_d_normalized
                - normalized_value * sum_d_normalized_times_normalized)
                * stats.inverse_std
                / stats.dimension
        })
        .collect()
}

fn validate_projection_input(
    op: &'static str,
    expected: ModelDimension,
    got: ModelDimension,
) -> CtResult<()> {
    if expected != got {
        return Err(CtError::ShapeMismatch {
            op,
            expected: format!("model dimension {}", expected.value()),
            got: format!("model dimension {}", got.value()),
        });
    }

    Ok(())
}

fn add_rows(left: &[f32], right: &[f32]) -> Vec<f32> {
    left.iter()
        .zip(right.iter())
        .map(|(left, right)| left + right)
        .collect()
}

fn validate_linear_parts(
    op: &'static str,
    weight: &[Vec<f32>],
    bias: &[f32],
) -> CtResult<(ModelDimension, ModelDimension)> {
    if weight.is_empty() {
        return Err(CtError::EmptyInput("linear weight"));
    }

    if bias.is_empty() {
        return Err(CtError::EmptyInput("linear bias"));
    }

    let output_dimension = bias.len();

    if bias.iter().any(|value| !value.is_finite()) {
        return Err(CtError::ShapeMismatch {
            op,
            expected: "finite bias values".to_string(),
            got: "non-finite bias value".to_string(),
        });
    }

    for row in weight {
        if row.len() != output_dimension {
            return Err(CtError::ShapeMismatch {
                op,
                expected: format!("weight rows have {output_dimension} columns"),
                got: format!("weight row with {} columns", row.len()),
            });
        }

        if row.iter().any(|value| !value.is_finite()) {
            return Err(CtError::ShapeMismatch {
                op,
                expected: "finite weight values".to_string(),
                got: "non-finite weight value".to_string(),
            });
        }
    }

    Ok((
        ModelDimension::new(weight.len())?,
        ModelDimension::new(output_dimension)?,
    ))
}

fn normalize_row(input: &[f32], parameters: &LayerNormParameters) -> Vec<f32> {
    let mean = input.iter().sum::<f32>() / input.len() as f32;
    let variance = input
        .iter()
        .map(|value| {
            let centered = value - mean;
            centered * centered
        })
        .sum::<f32>()
        / input.len() as f32;
    let denominator = (variance + parameters.epsilon.value()).sqrt();

    input
        .iter()
        .zip(parameters.scale.iter().zip(&parameters.shift))
        .map(|(value, (scale, shift))| ((value - mean) / denominator) * scale + shift)
        .collect()
}

fn project_row(input: &[f32], weight: &[Vec<f32>], bias: &[f32]) -> Vec<f32> {
    let mut output = bias.to_vec();

    for (feature, input_value) in input.iter().enumerate() {
        for (output_value, weight_value) in output.iter_mut().zip(&weight[feature]) {
            *output_value += input_value * weight_value;
        }
    }

    output
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn scaled_dot_product_scores_build_query_by_key_rows() -> CtResult<()> {
        let queries = QuerySequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
        let keys = KeySequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]])?;

        let scores = ScaledDotProductScores.apply(Product::new(queries, keys))?;

        assert_eq!(scores.query_len().value(), 2);
        assert_eq!(scores.key_len().value(), 3);
        assert!(crate::domain::approx_eq(
            scores.rows()[0].as_slice()[0],
            std::f32::consts::FRAC_1_SQRT_2,
            1e-4
        ));
        assert!(crate::domain::approx_eq(
            scores.rows()[0].as_slice()[1],
            0.0,
            1e-4
        ));
        assert!(crate::domain::approx_eq(
            scores.rows()[1].as_slice()[2],
            std::f32::consts::FRAC_1_SQRT_2,
            1e-4
        ));

        Ok(())
    }

    #[test]
    fn weighted_value_mixing_builds_one_output_per_query() -> CtResult<()> {
        let queries = QuerySequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
        let keys = KeySequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]])?;
        let values = ValueSequence::new(vec![vec![1.0, 10.0], vec![2.0, 20.0], vec![3.0, 30.0]])?;

        let scores = ScaledDotProductScores.apply(Product::new(queries, keys))?;
        let weights = AttentionSoftmax.apply(scores)?;
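        // Query 0 scores keys 0 and 2 equally, so their softmax weights match.
        // The values [1, 10], [2, 20], [3, 30] are evenly spaced, so output
        // row 0 is exactly the middle value [2, 20].
        let output = WeightedValueMixing.apply(Product::new(weights, values))?;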

        assert_eq!(output.sequence_len().value(), 2);
        assert_eq!(output.head_dimension().value(), 2);
        assert!(crate::domain::approx_eq(
            output.rows()[0].as_slice()[0],
            2.0,
            1e-4
        ));
        assert!(crate::domain::approx_eq(
            output.rows()[0].as_slice()[1],
            20.0,
            1e-4
        ));

        Ok(())
    }

    #[test]
    fn concatenate_heads_preserves_sequence_and_concatenates_features() -> CtResult<()> {
        let head_a = AttentionOutput::new(vec![vec![1.0, 10.0], vec![2.0, 20.0]])?;
        let head_b = AttentionOutput::new(vec![vec![3.0, 30.0], vec![4.0, 40.0]])?;

        let heads = AttentionHeadOutputs::new(vec![head_a, head_b])?;
        let output = ConcatenateHeads.apply(heads)?;

        assert_eq!(output.sequence_len().value(), 2);
        assert_eq!(output.head_count().value(), 2);
        assert_eq!(output.head_dimension().value(), 2);
        assert_eq!(output.model_dimension().value(), 4);
        assert_eq!(output.rows()[0].as_slice(), &[1.0, 10.0, 3.0, 30.0]);
        assert_eq!(output.rows()[1].as_slice(), &[2.0, 20.0, 4.0, 40.0]);

        Ok(())
    }

    #[test]
    fn attention_output_projection_maps_multi_head_rows() -> CtResult<()> {
        let head_a = AttentionOutput::new(vec![vec![1.0, 10.0], vec![2.0, 20.0]])?;
        let head_b = AttentionOutput::new(vec![vec![3.0, 30.0], vec![4.0, 40.0]])?;
        let multi_head =
            ConcatenateHeads.apply(AttentionHeadOutputs::new(vec![head_a, head_b])?)?;
        let projection = AttentionOutputProjection::new(
            vec![
                vec![1.0, 0.0],
                vec![0.0, 0.1],
                vec![0.5, 0.0],
                vec![0.0, 0.01],
            ],
            vec![0.0, 1.0],
        )?;

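        // Row 0 of the concatenated heads is [1, 10, 3, 30]. Including the bias:
        //   output 0 = 1*1.0 + 3*0.5 + 0.0 = 2.5
        //   output 1 = 10*0.1 + 30*0.01 + 1.0 = 2.3
        let output = projection.apply(multi_head)?;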

        assert_eq!(output.sequence_len().value(), 2);
        assert_eq!(output.model_dimension().value(), 2);
        assert!(crate::domain::approx_eq(
            output.rows()[0].as_slice()[0],
            2.5,
            1e-4
        ));
        assert!(crate::domain::approx_eq(
            output.rows()[0].as_slice()[1],
            2.3,
            1e-4
        ));
        assert!(crate::domain::approx_eq(
            output.rows()[1].as_slice()[0],
            4.0,
            1e-4
        ));
        assert!(crate::domain::approx_eq(
            output.rows()[1].as_slice()[1],
            3.4,
            1e-4
        ));

        Ok(())
    }

    #[test]
    fn residual_connection_adds_matching_sequences() -> CtResult<()> {
        let hidden = HiddenSequence::new(vec![vec![1.0, 2.0], vec![3.0, 4.0]])?;
        let sublayer_output = ProjectedAttentionOutput::new(vec![vec![0.5, 1.5], vec![2.5, 3.5]])?;

        let output = ResidualConnection.apply(Product::new(hidden, sublayer_output))?;

        assert_eq!(output.sequence_len().value(), 2);
        assert_eq!(output.model_dimension().value(), 2);
        assert_eq!(output.rows()[0].as_slice(), &[1.5, 3.5]);
        assert_eq!(output.rows()[1].as_slice(), &[5.5, 7.5]);

        Ok(())
    }

    #[test]
    fn layer_normalization_preserves_shape_and_normalizes_each_row() -> CtResult<()> {
        let hidden = HiddenSequence::new(vec![vec![1.0, 3.0], vec![2.0, 4.0]])?;
        let norm = LayerNormalization::new(LayerNormParameters::identity(ModelDimension::new(2)?));

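        // Row [1, 3] has mean 2 and variance 1, so each entry becomes
        // -1 or +1 divided by sqrt(1 + epsilon). The expected magnitude
        // 0.999995 is what epsilon = 1e-5 would give.
        let output = norm.apply(hidden)?;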

        assert_eq!(output.sequence_len().value(), 2);
        assert_eq!(output.model_dimension().value(), 2);
        assert!(crate::domain::approx_eq(
            output.rows()[0].as_slice()[0],
            -0.999995,
            1e-4
        ));
        assert!(crate::domain::approx_eq(
            output.rows()[0].as_slice()[1],
            0.999995,
            1e-4
        ));
        assert!(crate::domain::approx_eq(
            output.rows()[1].as_slice()[0],
            -0.999995,
            1e-4
        ));
        assert!(crate::domain::approx_eq(
            output.rows()[1].as_slice()[1],
            0.999995,
            1e-4
        ));

        Ok(())
    }

    #[test]
    fn layer_normalization_applies_scale_and_shift() -> CtResult<()> {
        let hidden = HiddenSequence::new(vec![vec![1.0, 3.0]])?;
        let params = LayerNormParameters::new(
            vec![2.0, 0.5],
            vec![1.0, -1.0],
            NormalizationEpsilon::new(1e-5)?,
        )?;
        let norm = LayerNormalization::new(params);

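        // The normalized row is about [-0.999995, 0.999995]. Scale [2.0, 0.5]
        // and shift [1.0, -1.0] then give about [-0.99999, -0.5000025].
        let output = norm.apply(hidden)?;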

        assert!(crate::domain::approx_eq(
            output.rows()[0].as_slice()[0],
            -0.99999,
            1e-4
        ));
        assert!(crate::domain::approx_eq(
            output.rows()[0].as_slice()[1],
            -0.5000025,
            1e-4
        ));

        Ok(())
    }

    #[test]
    fn position_wise_feed_forward_maps_each_row_and_preserves_shape() -> CtResult<()> {
        let hidden = HiddenSequence::new(vec![vec![1.0, 2.0], vec![3.0, 4.0]])?;
        let feed_forward = PositionWiseFeedForward::new(
            vec![vec![1.0, -1.0, 0.5], vec![0.0, 1.0, 0.5]],
            vec![0.0, 0.0, 0.0],
            vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![0.5, 0.5]],
            vec![0.0, 0.0],
        )?;

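        // Row [1, 2]: the first layer gives [1.0, 1.0, 1.5]. Every entry is
        // positive, and the expected outputs are consistent with the activation
        // leaving them unchanged, so the second layer gives
        // [1.0 + 0.75, 1.0 + 0.75] = [1.75, 1.75].
        let output = feed_forward.apply(hidden)?;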

        assert_eq!(output.sequence_len().value(), 2);
        assert_eq!(output.model_dimension().value(), 2);
        assert!(crate::domain::approx_eq(
            output.rows()[0].as_slice()[0],
            1.75,
            1e-4
        ));
        assert!(crate::domain::approx_eq(
            output.rows()[0].as_slice()[1],
            1.75,
            1e-4
        ));
        assert!(crate::domain::approx_eq(
            output.rows()[1].as_slice()[0],
            4.75,
            1e-4
        ));
        assert!(crate::domain::approx_eq(
            output.rows()[1].as_slice()[1],
            2.75,
            1e-4
        ));

        Ok(())
    }

    #[test]
    fn attention_mask_removes_disallowed_positions_before_softmax() -> CtResult<()> {
        let scores = AttentionScores::new(vec![vec![2.0, 1.0, 2.0]])?;
        let mask = AttentionMask::new(vec![vec![true, false, true]])?;
        let masked_scores = MaskedAttentionScores.apply(Product::new(scores, mask))?;
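        // Key 1 is masked out, and the two allowed keys have equal scores (2.0),
        // so the softmax splits the weight evenly: [0.5, 0.0, 0.5].
        let weights = AttentionSoftmax.apply(masked_scores)?;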

        assert!(crate::domain::approx_eq(
            weights.rows()[0].as_slice()[0],
            0.5,
            1e-4
        ));
        assert!(crate::domain::approx_eq(
            weights.rows()[0].as_slice()[1],
            0.0,
            1e-4
        ));
        assert!(crate::domain::approx_eq(
            weights.rows()[0].as_slice()[2],
            0.5,
            1e-4
        ));

        Ok(())
    }

    #[test]
    fn attention_softmax_normalizes_each_query_row() -> CtResult<()> {
        let scores = AttentionScores::new(vec![vec![2.0, 1.0], vec![0.0, 3.0]])?;
        let weights = AttentionSoftmax.apply(scores)?;

        assert_eq!(weights.query_len().value(), 2);
        assert_eq!(weights.key_len().value(), 2);

        for row in weights.rows() {
            let sum: f32 = row.as_slice().iter().sum();
            assert!(crate::domain::approx_eq(sum, 1.0, 1e-4));
        }

        Ok(())
    }

    #[test]
    fn attention_scores_reject_non_finite_values() {
        assert!(matches!(
            AttentionScores::new(vec![vec![1.0, f32::NAN]]),
            Err(CtError::ShapeMismatch {
                op: "attention scores",
                ..
            })
        ));
    }

    #[test]
    fn attention_scores_reject_ragged_rows() {
        assert!(matches!(
            AttentionScores::new(vec![vec![1.0, 2.0], vec![3.0]]),
            Err(CtError::ShapeMismatch {
                op: "attention scores",
                ..
            })
        ));
    }

    #[test]
    fn attention_mask_rejects_rows_with_no_allowed_keys() {
        assert!(matches!(
            AttentionMask::new(vec![vec![false, false]]),
            Err(CtError::EmptyInput("attention mask row allows no keys"))
        ));
    }

    #[test]
    fn masked_attention_scores_reject_shape_mismatch() -> CtResult<()> {
        let scores = AttentionScores::new(vec![vec![1.0, 2.0]])?;
        let mask = AttentionMask::new(vec![vec![true, true], vec![true, true]])?;

        assert!(matches!(
            MaskedAttentionScores.apply(Product::new(scores, mask)),
            Err(CtError::ShapeMismatch {
                op: "masked attention scores",
                ..
            })
        ));

        Ok(())
    }

    #[test]
    fn query_sequence_rejects_ragged_rows() {
        assert!(matches!(
            QuerySequence::new(vec![vec![1.0, 2.0], vec![3.0]]),
            Err(CtError::ShapeMismatch {
                op: "query sequence",
                ..
            })
        ));
    }

    #[test]
    fn value_sequence_rejects_empty_rows() {
        assert!(matches!(
            ValueSequence::new(vec![Vec::new()]),
            Err(CtError::EmptyInput("attention vector row"))
        ));
    }

    #[test]
    fn key_sequence_rejects_non_finite_values() {
        assert!(matches!(
            KeySequence::new(vec![vec![1.0, f32::NAN]]),
            Err(CtError::ShapeMismatch {
                op: "key sequence",
                ..
            })
        ));
    }

    #[test]
    fn scaled_dot_product_rejects_mismatched_head_dimensions() -> CtResult<()> {
        let queries = QuerySequence::new(vec![vec![1.0, 0.0]])?;
        let keys = KeySequence::new(vec![vec![1.0, 0.0, 1.0]])?;

        assert!(matches!(
            ScaledDotProductScores.apply(Product::new(queries, keys)),
            Err(CtError::ShapeMismatch {
                op: "scaled dot-product attention scores",
                ..
            })
        ));

        Ok(())
    }

    #[test]
    fn weighted_value_mixing_rejects_value_length_mismatch() -> CtResult<()> {
        let weights = AttentionWeights::new(vec![Distribution::new(vec![0.5, 0.5])?])?;
        let values = ValueSequence::new(vec![vec![1.0, 10.0]])?;

        assert!(matches!(
            WeightedValueMixing.apply(Product::new(weights, values)),
            Err(CtError::ShapeMismatch {
                op: "weighted value mixing",
                ..
            })
        ));

        Ok(())
    }

    #[test]
    fn sequence_and_head_dimensions_reject_zero() {
        assert!(matches!(
            SequenceLength::new(0),
            Err(CtError::EmptyInput("sequence length"))
        ));
        assert!(matches!(
            HeadDimension::new(0),
            Err(CtError::EmptyInput("head dimension"))
        ));
        assert!(matches!(
            HeadCount::new(0),
            Err(CtError::EmptyInput("head count"))
        ));
    }

    #[test]
    fn attention_head_outputs_reject_sequence_mismatch() -> CtResult<()> {
        let head_a = AttentionOutput::new(vec![vec![1.0, 10.0]])?;
        let head_b = AttentionOutput::new(vec![vec![2.0, 20.0], vec![3.0, 30.0]])?;

        assert!(matches!(
            AttentionHeadOutputs::new(vec![head_a, head_b]),
            Err(CtError::ShapeMismatch {
                op: "attention head outputs",
                ..
            })
        ));

        Ok(())
    }

    #[test]
    fn attention_head_outputs_reject_head_dimension_mismatch() -> CtResult<()> {
        let head_a = AttentionOutput::new(vec![vec![1.0, 10.0]])?;
        let head_b = AttentionOutput::new(vec![vec![2.0, 20.0, 200.0]])?;

        assert!(matches!(
            AttentionHeadOutputs::new(vec![head_a, head_b]),
            Err(CtError::ShapeMismatch {
                op: "attention head outputs",
                ..
            })
        ));

        Ok(())
    }

    #[test]
    fn attention_output_projection_rejects_input_dimension_mismatch() -> CtResult<()> {
        let head_a = AttentionOutput::new(vec![vec![1.0, 10.0]])?;
        let head_b = AttentionOutput::new(vec![vec![2.0, 20.0]])?;
        let multi_head =
            ConcatenateHeads.apply(AttentionHeadOutputs::new(vec![head_a, head_b])?)?;
        let projection =
            AttentionOutputProjection::new(vec![vec![1.0], vec![1.0], vec![1.0]], vec![0.0])?;

        assert!(matches!(
            projection.apply(multi_head),
            Err(CtError::ShapeMismatch {
                op: "attention output projection",
                ..
            })
        ));

        Ok(())
    }

    #[test]
    fn attention_output_projection_rejects_bad_weight_shapes() {
        assert!(matches!(
            AttentionOutputProjection::new(vec![vec![1.0, 0.0], vec![1.0]], vec![0.0, 0.0]),
            Err(CtError::ShapeMismatch {
                op: "attention output projection",
                ..
            })
        ));
    }

    #[test]
    fn attention_output_projection_rejects_non_finite_values() {
        assert!(matches!(
            AttentionOutputProjection::new(vec![vec![1.0]], vec![f32::NAN]),
            Err(CtError::ShapeMismatch {
                op: "attention output projection",
                ..
            })
        ));
        assert!(matches!(
            AttentionOutputProjection::new(vec![vec![f32::INFINITY]], vec![0.0]),
            Err(CtError::ShapeMismatch {
                op: "attention output projection",
                ..
            })
        ));
    }

    #[test]
    fn residual_connection_rejects_sequence_mismatch() -> CtResult<()> {
        let hidden = HiddenSequence::new(vec![vec![1.0, 2.0]])?;
        let sublayer_output = ProjectedAttentionOutput::new(vec![vec![0.5, 1.5], vec![2.5, 3.5]])?;

        assert!(matches!(
            ResidualConnection.apply(Product::new(hidden, sublayer_output)),
            Err(CtError::ShapeMismatch {
                op: "residual connection",
                ..
            })
        ));

        Ok(())
    }

    #[test]
    fn residual_connection_rejects_model_dimension_mismatch() -> CtResult<()> {
        let hidden = HiddenSequence::new(vec![vec![1.0, 2.0]])?;
        let sublayer_output = ProjectedAttentionOutput::new(vec![vec![0.5, 1.5, 2.5]])?;

        assert!(matches!(
            ResidualConnection.apply(Product::new(hidden, sublayer_output)),
            Err(CtError::ShapeMismatch {
                op: "residual connection",
                ..
            })
        ));

        Ok(())
    }

    #[test]
    fn layer_normalization_rejects_model_dimension_mismatch() -> CtResult<()> {
        let hidden = HiddenSequence::new(vec![vec![1.0, 2.0, 3.0]])?;
        let norm = LayerNormalization::new(LayerNormParameters::identity(ModelDimension::new(2)?));

        assert!(matches!(
            norm.apply(hidden),
            Err(CtError::ShapeMismatch {
                op: "layer normalization",
                ..
            })
        ));

        Ok(())
    }

    #[test]
    fn layer_norm_parameters_reject_bad_shapes_and_values() {
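        // The first three cases build `NormalizationEpsilon(1e-5)` directly from
        // the tuple field and vary only `scale` and `shift`; the last case
        // exercises the validating `NormalizationEpsilon::new` constructor.
        assert!(matches!(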
            LayerNormParameters::new(vec![1.0, 1.0], vec![0.0], NormalizationEpsilon(1e-5)),
            Err(CtError::ShapeMismatch {
                op: "layer norm parameters",
                ..
            })
        ));
        assert!(matches!(
            LayerNormParameters::new(vec![f32::NAN], vec![0.0], NormalizationEpsilon(1e-5)),
            Err(CtError::ShapeMismatch {
                op: "layer norm parameters",
                ..
            })
        ));
        assert!(matches!(
            LayerNormParameters::new(vec![1.0], vec![f32::INFINITY], NormalizationEpsilon(1e-5)),
            Err(CtError::ShapeMismatch {
                op: "layer norm parameters",
                ..
            })
        ));
        assert!(matches!(
            NormalizationEpsilon::new(0.0),
            Err(CtError::ShapeMismatch {
                op: "normalization epsilon",
                ..
            })
        ));
    }

    #[test]
    fn position_wise_feed_forward_rejects_input_dimension_mismatch() -> CtResult<()> {
        let hidden = HiddenSequence::new(vec![vec![1.0, 2.0, 3.0]])?;
        let feed_forward = PositionWiseFeedForward::new(
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![0.0, 0.0],
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![0.0, 0.0],
        )?;

        assert!(matches!(
            feed_forward.apply(hidden),
            Err(CtError::ShapeMismatch {
                op: "position-wise feed-forward",
                ..
            })
        ));

        Ok(())
    }

    #[test]
    fn position_wise_feed_forward_rejects_incompatible_layer_shapes() {
        assert!(matches!(
            PositionWiseFeedForward::new(
                vec![vec![1.0, 0.0]],
                vec![0.0, 0.0],
                vec![vec![1.0], vec![1.0], vec![1.0]],
                vec![0.0],
            ),
            Err(CtError::ShapeMismatch {
                op: "position-wise feed-forward",
                ..
            })
        ));
        assert!(matches!(
            PositionWiseFeedForward::new(
                vec![vec![1.0, 0.0]],
                vec![0.0, 0.0],
                vec![vec![1.0, 0.0], vec![1.0, 0.0]],
                vec![0.0, 0.0],
            ),
            Err(CtError::ShapeMismatch {
                op: "position-wise feed-forward",
                ..
            })
        ));
    }

    #[test]
    fn position_wise_feed_forward_rejects_non_finite_values() {
        assert!(matches!(
            PositionWiseFeedForward::new(
                vec![vec![f32::NAN]],
                vec![0.0],
                vec![vec![1.0]],
                vec![0.0],
            ),
            Err(CtError::ShapeMismatch {
                op: "position-wise feed-forward first layer",
                ..
            })
        ));
        assert!(matches!(
            PositionWiseFeedForward::new(
                vec![vec![1.0]],
                vec![0.0],
                vec![vec![1.0]],
                vec![f32::INFINITY],
            ),
            Err(CtError::ShapeMismatch {
                op: "position-wise feed-forward second layer",
                ..
            })
        ));
    }

    #[test]
    fn positional_encoding_adds_position_rows_and_preserves_shape() -> CtResult<()> {
        let hidden = HiddenSequence::new(vec![vec![1.0, 2.0], vec![3.0, 4.0]])?;
        let positions = PositionalEncoding::new(vec![vec![0.1, 0.2], vec![0.3, 0.4]])?;

        let output = positions.apply(hidden)?;

        assert_eq!(output.sequence_len().value(), 2);
        assert_eq!(output.model_dimension().value(), 2);
        assert!(crate::domain::approx_eq(
            output.rows()[0].as_slice()[0],
            1.1,
            1e-4
        ));
        assert!(crate::domain::approx_eq(
            output.rows()[1].as_slice()[1],
            4.4,
            1e-4
        ));

        Ok(())
    }

    #[test]
    fn positional_encoding_rejects_model_dimension_mismatch() -> CtResult<()> {
        let hidden = HiddenSequence::new(vec![vec![1.0, 2.0]])?;
        let positions = PositionalEncoding::new(vec![vec![0.1, 0.2, 0.3]])?;

        assert!(matches!(
            positions.apply(hidden),
            Err(CtError::ShapeMismatch {
                op: "positional encoding",
                ..
            })
        ));

        Ok(())
    }

    #[test]
    fn positional_encoding_rejects_sequence_too_long() -> CtResult<()> {
        let hidden = HiddenSequence::new(vec![vec![1.0], vec![2.0]])?;
        let positions = PositionalEncoding::new(vec![vec![0.1]])?;

        assert!(matches!(
            positions.apply(hidden),
            Err(CtError::ShapeMismatch {
                op: "positional encoding",
                ..
            })
        ));

        Ok(())
    }

    #[test]
    fn hidden_to_query_projects_hidden_rows() -> CtResult<()> {
        let hidden = HiddenSequence::new(vec![vec![1.0, 2.0], vec![3.0, 4.0]])?;
        let projection = HiddenToQuery::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.5, -0.5])?;

        let queries = projection.apply(hidden)?;

        assert_eq!(queries.sequence_len().value(), 2);
        assert_eq!(queries.head_dimension().value(), 2);
        assert!(crate::domain::approx_eq(
            queries.rows()[0].as_slice()[0],
            1.5,
            1e-4
        ));
        assert!(crate::domain::approx_eq(
            queries.rows()[1].as_slice()[1],
            3.5,
            1e-4
        ));

        Ok(())
    }

    #[test]
    fn hidden_projection_rejects_input_dimension_mismatch() -> CtResult<()> {
        let hidden = HiddenSequence::new(vec![vec![1.0, 2.0, 3.0]])?;
        let projection = HiddenToValue::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?;

        assert!(matches!(
            projection.apply(hidden),
            Err(CtError::ShapeMismatch {
                op: "hidden-to-value projection",
                ..
            })
        ));

        Ok(())
    }

    #[test]
    fn residual_connection_adds_hidden_sequences() -> CtResult<()> {
        let left = HiddenSequence::new(vec![vec![1.0, 2.0]])?;
        let right = HiddenSequence::new(vec![vec![3.0, 4.0]])?;

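        // Unlike `residual_connection_adds_matching_sequences`, both operands
        // here are `HiddenSequence` values.
        let output = ResidualConnection.apply(Product::new(left, right))?;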

        assert_eq!(output.rows()[0].as_slice(), &[4.0, 6.0]);

        Ok(())
    }

    #[test]
    fn single_head_transformer_block_preserves_hidden_sequence_shape() -> CtResult<()> {
        let block = tiny_single_head_block()?;
        let hidden = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;

        let output = block.apply(hidden)?;

        assert_eq!(output.sequence_len().value(), 2);
        assert_eq!(output.model_dimension().value(), 2);
        assert!(
            output
                .rows()
                .iter()
                .flat_map(|row| row.as_slice())
                .all(|value| value.is_finite())
        );

        Ok(())
    }

    #[test]
    fn single_head_transformer_block_rejects_constructor_dimension_mismatch() -> CtResult<()> {
        let query = HiddenToQuery::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?;
        let key = HiddenToKey::new(
            vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![0.0, 0.0]],
            vec![0.0, 0.0],
        )?;
        let value = HiddenToValue::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?;
        let output_projection =
            AttentionOutputProjection::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?;
        let model_dimension = ModelDimension::new(2)?;
        let attention_norm =
            LayerNormalization::new(LayerNormParameters::identity(model_dimension));
        let feed_forward = PositionWiseFeedForward::new(
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![0.0, 0.0],
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![0.0, 0.0],
        )?;
        let feed_forward_norm =
            LayerNormalization::new(LayerNormParameters::identity(model_dimension));

        assert!(matches!(
            SingleHeadTransformerBlock::new(
                query,
                key,
                value,
                output_projection,
                attention_norm,
                feed_forward,
                feed_forward_norm,
            ),
            Err(CtError::ShapeMismatch {
                op: "single-head block key projection",
                ..
            })
        ));

        Ok(())
    }

    #[test]
    fn single_head_transformer_block_rejects_apply_dimension_mismatch() -> CtResult<()> {
        let block = tiny_single_head_block()?;
        let hidden = HiddenSequence::new(vec![vec![1.0, 0.0, 0.0]])?;

        assert!(matches!(
            block.apply(hidden),
            Err(CtError::ShapeMismatch {
                op: "single-head block",
                ..
            })
        ));

        Ok(())
    }

    #[test]
    fn self_attention_head_rejects_query_key_head_mismatch() -> CtResult<()> {
        let query = HiddenToQuery::new(vec![vec![1.0], vec![0.0]], vec![0.0])?;
        let key = HiddenToKey::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?;
        let value = HiddenToValue::new(vec![vec![1.0], vec![0.0]], vec![0.0])?;

        assert!(matches!(
            SelfAttentionHead::new(query, key, value),
            Err(CtError::ShapeMismatch {
                op: "self-attention head",
                ..
            })
        ));

        Ok(())
    }

    #[test]
    fn multi_head_transformer_block_preserves_hidden_sequence_shape() -> CtResult<()> {
        let block = tiny_multi_head_block()?;
        let hidden = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;

        let output = block.apply(hidden)?;

        assert_eq!(block.head_count().value(), 2);
        assert_eq!(block.value_dimension().value(), 1);
        assert_eq!(output.sequence_len().value(), 2);
        assert_eq!(output.model_dimension().value(), 2);
        assert!(
            output
                .rows()
                .iter()
                .flat_map(|row| row.as_slice())
                .all(|value| value.is_finite())
        );

        Ok(())
    }

    #[test]
    fn multi_head_transformer_block_rejects_value_dimension_mismatch() -> CtResult<()> {
        let head_a = tiny_self_attention_head_first_feature()?;
        let head_b = SelfAttentionHead::new(
            HiddenToQuery::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
            HiddenToKey::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
            HiddenToValue::new(vec![vec![0.0, 1.0], vec![1.0, 0.0]], vec![0.0, 0.0])?,
        )?;
        let model_dimension = ModelDimension::new(2)?;

        assert!(matches!(
            MultiHeadTransformerBlock::new(
                vec![head_a, head_b],
                AttentionOutputProjection::new(
                    vec![vec![1.0, 0.0], vec![0.0, 1.0]],
                    vec![0.0, 0.0],
                )?,
                LayerNormalization::new(LayerNormParameters::identity(model_dimension)),
                identity_feed_forward()?,
                LayerNormalization::new(LayerNormParameters::identity(model_dimension)),
            ),
            Err(CtError::ShapeMismatch {
                op: "multi-head block",
                ..
            })
        ));

        Ok(())
    }

    #[test]
    fn multi_head_transformer_block_rejects_output_projection_input_mismatch() -> CtResult<()> {
        let model_dimension = ModelDimension::new(2)?;

        assert!(matches!(
            MultiHeadTransformerBlock::new(
                vec![
                    tiny_self_attention_head_first_feature()?,
                    tiny_self_attention_head_second_feature()?,
                ],
                AttentionOutputProjection::new(
                    vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![0.0, 0.0]],
                    vec![0.0, 0.0],
                )?,
                LayerNormalization::new(LayerNormParameters::identity(model_dimension)),
                identity_feed_forward()?,
                LayerNormalization::new(LayerNormParameters::identity(model_dimension)),
            ),
            Err(CtError::ShapeMismatch {
                op: "multi-head block output projection input",
                ..
            })
        ));

        Ok(())
    }

    #[test]
    fn multi_head_transformer_block_rejects_apply_dimension_mismatch() -> CtResult<()> {
        let block = tiny_multi_head_block()?;
        let hidden = HiddenSequence::new(vec![vec![1.0, 0.0, 0.0]])?;

        assert!(matches!(
            block.apply(hidden),
            Err(CtError::ShapeMismatch {
                op: "multi-head block",
                ..
            })
        ));

        Ok(())
    }

    #[test]
    fn masked_multi_head_transformer_block_preserves_hidden_sequence_shape() -> CtResult<()> {
        let block = tiny_masked_multi_head_block()?;
        let hidden = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
        let mask = AttentionMask::new(vec![vec![true, false], vec![true, true]])?;

        let output = block.apply(Product::new(hidden, mask))?;

        assert_eq!(block.head_count().value(), 2);
        assert_eq!(output.sequence_len().value(), 2);
        assert_eq!(output.model_dimension().value(), 2);
        assert!(
            output
                .rows()
                .iter()
                .flat_map(|row| row.as_slice())
                .all(|value| value.is_finite())
        );

        Ok(())
    }

    #[test]
    fn masked_multi_head_transformer_block_rejects_mask_shape_mismatch() -> CtResult<()> {
        let block = tiny_masked_multi_head_block()?;
        let hidden = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
        let mask = AttentionMask::new(vec![vec![true, true, true], vec![true, true, true]])?;

        assert!(matches!(
            block.apply(Product::new(hidden, mask)),
            Err(CtError::ShapeMismatch {
                op: "masked attention scores",
                ..
            })
        ));

        Ok(())
    }

    #[test]
    fn transformer_readout_maps_each_hidden_position_to_logits() -> CtResult<()> {
        let readout = TransformerReadout::new(
            vec![vec![1.0, 0.0, 0.5], vec![0.0, 1.0, -0.5]],
            vec![0.0, 0.1, -0.1],
        )?;
        let hidden = HiddenSequence::new(vec![vec![2.0, 3.0], vec![4.0, 5.0]])?;

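        // Row [2, 3] with the bias [0.0, 0.1, -0.1]:
        //   [2*1.0 + 3*0.0 + 0.0, 2*0.0 + 3*1.0 + 0.1, 2*0.5 + 3*(-0.5) - 0.1]
        //   = [2.0, 3.1, -0.6]
        let logits = readout.apply(hidden)?;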

        assert_eq!(logits.sequence_len().value(), 2);
        assert_eq!(logits.vocab_size().value(), 3);
        assert_eq!(logits.rows()[0].as_slice(), &[2.0, 3.1, -0.6]);
        assert_eq!(logits.rows()[1].as_slice(), &[4.0, 5.1, -0.6]);
        Ok(())
    }

    #[test]
    fn tiny_transformer_parameters_forward_maps_hidden_and_mask_to_sequence_logits() -> CtResult<()>
    {
        let parameters = tiny_transformer_parameters()?;
        let hidden = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
        let mask = AttentionMask::new(vec![vec![true, false], vec![true, true]])?;

        let logits = parameters.apply(Product::new(hidden, mask))?;

        assert_eq!(parameters.model_dimension().value(), 2);
        assert_eq!(parameters.max_sequence_len().value(), 2);
        assert_eq!(logits.sequence_len().value(), 2);
        assert_eq!(logits.vocab_size().value(), 3);
        assert!(
            logits
                .rows()
                .iter()
                .flat_map(|row| row.as_slice())
                .all(|value| value.is_finite())
        );

        Ok(())
    }

    #[test]
    fn tiny_transformer_parameters_rejects_readout_dimension_mismatch() -> CtResult<()> {
        let positional_encoding = PositionalEncoding::new(vec![vec![0.1, 0.0], vec![0.0, 0.1]])?;
        let block = tiny_masked_multi_head_block()?;
        let readout = TransformerReadout::new(vec![vec![1.0], vec![0.0], vec![0.5]], vec![0.0])?;

        assert!(matches!(
            TinyTransformerParameters::new(positional_encoding, block, readout),
            Err(CtError::ShapeMismatch {
                op: "tiny transformer parameters readout",
                ..
            })
        ));

        Ok(())
    }

    #[test]
    fn transformer_training_state_records_updated_parameters_and_step_count() -> CtResult<()> {
        let initial_parameters = tiny_transformer_parameters()?;
        let updated_parameters = tiny_transformer_parameters()?;
        let state = TransformerTrainingState::new(initial_parameters, LearningRate::new(0.25)?);

        let next_state = state.record_updated_parameters(updated_parameters.clone());

        assert_eq!(next_state.parameters(), &updated_parameters);
        assert_eq!(next_state.learning_rate().value(), 0.25);
        assert_eq!(next_state.step_count().value(), 1);
        Ok(())
    }

    #[test]
    fn transformer_training_state_forward_uses_structured_parameters() -> CtResult<()> {
        let state =
            TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.1)?);
        let hidden = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
        let mask = AttentionMask::new(vec![vec![true, false], vec![true, true]])?;

        let logits = state.apply(Product::new(hidden, mask))?;

        assert_eq!(logits.sequence_len().value(), 2);
        assert_eq!(logits.vocab_size().value(), 3);
        assert_eq!(state.step_count().value(), 0);
        Ok(())
    }

    #[test]
    fn transformer_readout_training_example_rejects_target_length_mismatch() -> CtResult<()> {
        let hidden = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
        let mask = AttentionMask::new(vec![vec![true, false], vec![true, true]])?;
        let targets = TokenSequence::from_indices([0])?;

        assert!(matches!(
            TransformerReadoutTrainingExample::new(hidden, mask, targets),
            Err(CtError::ShapeMismatch {
                op: "transformer readout training targets",
                ..
            })
        ));

        Ok(())
    }

    #[test]
    fn transformer_readout_training_example_rejects_mask_shape_mismatch() -> CtResult<()> {
        let hidden = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
        let mask = AttentionMask::new(vec![vec![true, true, true], vec![true, true, true]])?;
        let targets = TokenSequence::from_indices([0, 1])?;

        assert!(matches!(
            TransformerReadoutTrainingExample::new(hidden, mask, targets),
            Err(CtError::ShapeMismatch {
                op: "transformer readout training mask",
                ..
            })
        ));

        Ok(())
    }

    #[test]
    fn transformer_readout_train_step_reduces_sequence_loss() -> CtResult<()> {
        let dataset = tiny_transformer_training_set()?;
        let state =
            TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.5)?);
        let before = transformer_readout_average_loss(&state, &dataset)?;
        let train_step = TransformerReadoutTrainStep::new(dataset.clone());

        let trained =
            crate::category::apply_endomorphism_n_times(&train_step, state, StepCount::new(40))?;
        let after = transformer_readout_average_loss(&trained, &dataset)?;

        assert!(after.value() < before.value());
        assert_eq!(trained.step_count().value(), 40);
        Ok(())
    }

    #[test]
    fn transformer_readout_train_step_rejects_target_outside_vocabulary() -> CtResult<()> {
        let example = TransformerReadoutTrainingExample::new(
            HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?,
            AttentionMask::new(vec![vec![true, false], vec![true, true]])?,
            TokenSequence::from_indices([0, 9])?,
        )?;
        let dataset = TransformerReadoutTrainingSet::new([example])?;
        let state =
            TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.1)?);
        let train_step = TransformerReadoutTrainStep::new(dataset);

        assert!(matches!(
            train_step.apply(state),
            Err(CtError::OutOfRange {
                kind: "sequence target",
                index: 9,
                limit: 3,
            })
        ));

        Ok(())
    }

    #[test]
    fn transformer_feed_forward_training_example_rejects_target_shape_mismatch() -> CtResult<()> {
        let input = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
        let target = HiddenSequence::new(vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]])?;

        assert!(matches!(
            TransformerFeedForwardTrainingExample::new(input, target),
            Err(CtError::ShapeMismatch {
                op: "transformer feed-forward training dimension",
                ..
            })
        ));

        Ok(())
    }

    #[test]
    fn transformer_feed_forward_training_set_rejects_empty_input() {
        assert!(matches!(
            TransformerFeedForwardTrainingSet::new([]),
            Err(CtError::EmptyInput("transformer feed-forward training set"))
        ));
    }

    #[test]
    fn transformer_feed_forward_train_step_reduces_local_hidden_loss() -> CtResult<()> {
        let dataset = tiny_feed_forward_training_set()?;
        let state =
            TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
        let before = transformer_feed_forward_average_loss(&state, &dataset)?;
        let train_step = TransformerFeedForwardTrainStep::new(dataset.clone());

        let trained =
            crate::category::apply_endomorphism_n_times(&train_step, state, StepCount::new(60))?;
        let after = transformer_feed_forward_average_loss(&trained, &dataset)?;

        assert!(after.value() < before.value());
        assert_eq!(trained.step_count().value(), 60);
        Ok(())
    }

    #[test]
    fn transformer_block_training_example_rejects_mask_shape_mismatch() -> CtResult<()> {
        let hidden = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
        let mask = AttentionMask::new(vec![vec![true, true, true], vec![true, true, true]])?;
        let targets = TokenSequence::from_indices([0, 1])?;

        assert!(matches!(
            TransformerBlockTrainingExample::new(hidden, mask, targets),
            Err(CtError::ShapeMismatch {
                op: "transformer block training mask",
                ..
            })
        ));

        Ok(())
    }

    #[test]
    fn transformer_block_training_set_rejects_empty_input() {
        assert!(matches!(
            TransformerBlockTrainingSet::new([]),
            Err(CtError::EmptyInput("transformer block training set"))
        ));
    }

    #[test]
    fn transformer_block_train_step_rejects_target_outside_vocabulary() -> CtResult<()> {
        let example = TransformerBlockTrainingExample::new(
            HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?,
            AttentionMask::new(vec![vec![true, false], vec![true, true]])?,
            TokenSequence::from_indices([0, 9])?,
        )?;
        let dataset = TransformerBlockTrainingSet::new([example])?;
        let state =
            TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.1)?);
        let train_step = TransformerBlockTrainStep::new(dataset);

        assert!(matches!(
            train_step.apply(state),
            Err(CtError::OutOfRange {
                kind: "sequence target",
                index: 9,
                limit: 3,
            })
        ));

        Ok(())
    }

    #[test]
    fn transformer_block_train_step_reduces_sequence_loss() -> CtResult<()> {
        let dataset = tiny_transformer_block_training_set()?;
        let state =
            TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
        let before = transformer_block_average_loss(&state, &dataset)?;
        let train_step = TransformerBlockTrainStep::new(dataset.clone());

        let trained =
            crate::category::apply_endomorphism_n_times(&train_step, state, StepCount::new(40))?;
        let after = transformer_block_average_loss(&trained, &dataset)?;

        assert!(after.value() < before.value());
        assert_eq!(trained.step_count().value(), 40);
        Ok(())
    }

    #[test]
    fn transformer_block_train_step_updates_attention_output_projection() -> CtResult<()> {
        let dataset = tiny_transformer_block_training_set()?;
        let state =
            TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
        let before = state.parameters().output_projection().weight().to_vec();
        let train_step = TransformerBlockTrainStep::new(dataset);

        let trained = train_step.apply(state)?;
        let after = trained.parameters().output_projection().weight().to_vec();

        assert_ne!(before, after);
        assert_eq!(trained.step_count().value(), 1);
        Ok(())
    }

    #[test]
    fn transformer_block_train_step_updates_layer_norm_parameters() -> CtResult<()> {
        let dataset = tiny_transformer_block_training_set()?;
        let state =
            TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
        let before_attention_scale = state
            .parameters()
            .attention_norm()
            .parameters()
            .scale()
            .to_vec();
        let before_feed_forward_shift = state
            .parameters()
            .feed_forward_norm()
            .parameters()
            .shift()
            .to_vec();
        let train_step = TransformerBlockTrainStep::new(dataset);

        let trained = train_step.apply(state)?;
        let after_attention_scale = trained
            .parameters()
            .attention_norm()
            .parameters()
            .scale()
            .to_vec();
        let after_feed_forward_shift = trained
            .parameters()
            .feed_forward_norm()
            .parameters()
            .shift()
            .to_vec();

        assert_ne!(before_attention_scale, after_attention_scale);
        assert_ne!(before_feed_forward_shift, after_feed_forward_shift);
        assert_eq!(trained.step_count().value(), 1);
        Ok(())
    }

    #[test]
    fn transformer_block_train_step_updates_query_key_value_projections() -> CtResult<()> {
        let dataset = tiny_transformer_block_training_set()?;
        let state =
            TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
        let before_query = state
            .parameters()
            .attention_heads()
            .iter()
            .map(|head| head.query_projection().weight().to_vec())
            .collect::<Vec<_>>();
        let before_key = state
            .parameters()
            .attention_heads()
            .iter()
            .map(|head| head.key_projection().weight().to_vec())
            .collect::<Vec<_>>();
        let before_value = state
            .parameters()
            .attention_heads()
            .iter()
            .map(|head| head.value_projection().weight().to_vec())
            .collect::<Vec<_>>();
        let train_step = TransformerBlockTrainStep::new(dataset);

        let trained = train_step.apply(state)?;
        let after_query = trained
            .parameters()
            .attention_heads()
            .iter()
            .map(|head| head.query_projection().weight().to_vec())
            .collect::<Vec<_>>();
        let after_key = trained
            .parameters()
            .attention_heads()
            .iter()
            .map(|head| head.key_projection().weight().to_vec())
            .collect::<Vec<_>>();
        let after_value = trained
            .parameters()
            .attention_heads()
            .iter()
            .map(|head| head.value_projection().weight().to_vec())
            .collect::<Vec<_>>();

        assert_ne!(before_query, after_query);
        assert_ne!(before_key, after_key);
        assert_ne!(before_value, after_value);
        assert_eq!(trained.step_count().value(), 1);
        Ok(())
    }

    #[test]
    fn transformer_block_train_step_matches_finite_difference_for_attention_projection()
    -> CtResult<()> {
        let dataset = tiny_transformer_block_training_set()?;
        let state =
            TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
        let train_step = TransformerBlockTrainStep::new(dataset.clone());
        let trained = train_step.apply(state.clone())?;
        let selection = largest_changed_attention_projection(&state, &trained)?;
        let before_value = attention_projection_weight(&state, selection)?;
        let after_value = attention_projection_weight(&trained, selection)?;
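        // Recover the gradient from one update, assuming the step has the plain
        // gradient-descent form `after = before - learning_rate * gradient`, then
        // compare it with a central finite difference of the average block loss.
        let inferred_gradient = (before_value - after_value) / state.learning_rate().value();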
        let epsilon = 1e-3;
        let loss_plus = transformer_block_average_loss(
            &state_with_attention_projection_weight(&state, selection, before_value + epsilon)?,
            &dataset,
        )?
        .value();
        let loss_minus = transformer_block_average_loss(
            &state_with_attention_projection_weight(&state, selection, before_value - epsilon)?,
            &dataset,
        )?
        .value();
        let finite_difference = (loss_plus - loss_minus) / (2.0 * epsilon);

        assert!(
            (inferred_gradient - finite_difference).abs() < 1e-2,
            "inferred gradient {inferred_gradient} should match finite difference {finite_difference}"
        );
        Ok(())
    }

    #[test]
    fn transformer_block_train_step_matches_finite_difference_for_readout_weight() -> CtResult<()> {
        let dataset = tiny_transformer_block_training_set()?;
        let state =
            TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
        let train_step = TransformerBlockTrainStep::new(dataset.clone());
        let trained = train_step.apply(state.clone())?;
        let selection = largest_changed_readout_weight(&state, &trained)?;
        let before_value = readout_weight(&state, selection);
        let after_value = readout_weight(&trained, selection);

        assert_block_gradient_matches_finite_difference(
            &state,
            &dataset,
            before_value,
            after_value,
            |value| state_with_readout_weight(&state, selection, value),
        )
    }

    #[test]
    fn transformer_block_train_step_matches_finite_difference_for_feed_forward_weight()
    -> CtResult<()> {
        let dataset = tiny_transformer_block_training_set()?;
        let state =
            TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
        let train_step = TransformerBlockTrainStep::new(dataset.clone());
        let trained = train_step.apply(state.clone())?;
        let selection = largest_changed_feed_forward_weight(&state, &trained)?;
        let before_value = feed_forward_weight(&state, selection);
        let after_value = feed_forward_weight(&trained, selection);

        assert_block_gradient_matches_finite_difference(
            &state,
            &dataset,
            before_value,
            after_value,
            |value| state_with_feed_forward_weight(&state, selection, value),
        )
    }

    #[test]
    fn transformer_block_train_step_matches_finite_difference_for_layer_norm_parameter()
    -> CtResult<()> {
        let dataset = tiny_transformer_block_training_set()?;
        let state =
            TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
        let train_step = TransformerBlockTrainStep::new(dataset.clone());
        let trained = train_step.apply(state.clone())?;
        let selection = largest_changed_layer_norm_parameter(&state, &trained)?;
        let before_value = layer_norm_parameter_value(&state, selection);
        let after_value = layer_norm_parameter_value(&trained, selection);

        assert_block_gradient_matches_finite_difference(
            &state,
            &dataset,
            before_value,
            after_value,
            |value| state_with_layer_norm_parameter(&state, selection, value),
        )
    }

    #[test]
    fn transformer_block_train_step_matches_finite_difference_for_readout_bias() -> CtResult<()> {
        let dataset = tiny_transformer_block_training_set()?;
        let state =
            TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
        let train_step = TransformerBlockTrainStep::new(dataset.clone());
        let trained = train_step.apply(state.clone())?;
        let selection = largest_changed_vector_index(
            state.parameters().readout().bias(),
            trained.parameters().readout().bias(),
            "changed readout bias",
        )?;
        let before_value = state.parameters().readout().bias()[selection];
        let after_value = trained.parameters().readout().bias()[selection];

        assert_block_gradient_matches_finite_difference(
            &state,
            &dataset,
            before_value,
            after_value,
            |value| state_with_readout_bias(&state, selection, value),
        )
    }

    #[test]
    fn transformer_block_train_step_matches_finite_difference_for_feed_forward_bias() -> CtResult<()>
    {
        let dataset = tiny_transformer_block_training_set()?;
        let state =
            TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
        let train_step = TransformerBlockTrainStep::new(dataset.clone());
        let trained = train_step.apply(state.clone())?;
        let selection = largest_changed_feed_forward_bias(&state, &trained)?;
        let before_value = feed_forward_bias(&state, selection);
        let after_value = feed_forward_bias(&trained, selection);

        assert_block_gradient_matches_finite_difference(
            &state,
            &dataset,
            before_value,
            after_value,
            |value| state_with_feed_forward_bias(&state, selection, value),
        )
    }

    #[test]
    fn transformer_block_train_step_matches_finite_difference_for_output_projection_bias()
    -> CtResult<()> {
        let dataset = tiny_transformer_block_training_set()?;
        let state =
            TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
        let train_step = TransformerBlockTrainStep::new(dataset.clone());
        let trained = train_step.apply(state.clone())?;
        let selection = largest_changed_vector_index(
            state.parameters().output_projection().bias(),
            trained.parameters().output_projection().bias(),
            "changed attention output projection bias",
        )?;
        let before_value = state.parameters().output_projection().bias()[selection];
        let after_value = trained.parameters().output_projection().bias()[selection];

        assert_block_gradient_matches_finite_difference(
            &state,
            &dataset,
            before_value,
            after_value,
            |value| state_with_output_projection_bias(&state, selection, value),
        )
    }

    #[test]
    fn transformer_block_train_step_matches_finite_difference_for_attention_projection_bias()
    -> CtResult<()> {
        let dataset = tiny_transformer_block_training_set()?;
        let state =
            TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
        let train_step = TransformerBlockTrainStep::new(dataset.clone());
        let trained = train_step.apply(state.clone())?;
        let selection = largest_changed_attention_projection_bias(&state, &trained)?;
        let before_value = attention_projection_bias(&state, selection)?;
        let after_value = attention_projection_bias(&trained, selection)?;

        assert_block_gradient_matches_finite_difference(
            &state,
            &dataset,
            before_value,
            after_value,
            |value| state_with_attention_projection_bias(&state, selection, value),
        )
    }

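    /// Checks one scalar parameter's gradient against a central finite difference.
    ///
    /// The gradient is inferred from a single training step as
    /// `(before_value - after_value) / learning_rate`, which assumes the update has
    /// the plain gradient-descent form `after = before - learning_rate * gradient`.
    /// `state_with_value` must return a copy of `state` in which only the selected
    /// parameter is replaced by the given value.
    fn assert_block_gradient_matches_finite_difference(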
        state: &TransformerTrainingState,
        dataset: &TransformerBlockTrainingSet,
        before_value: f32,
        after_value: f32,
        mut state_with_value: impl FnMut(f32) -> CtResult<TransformerTrainingState>,
    ) -> CtResult<()> {
        let inferred_gradient = (before_value - after_value) / state.learning_rate().value();
        let epsilon = 1e-3;
        let loss_plus =
            transformer_block_average_loss(&state_with_value(before_value + epsilon)?, dataset)?
                .value();
        let loss_minus =
            transformer_block_average_loss(&state_with_value(before_value - epsilon)?, dataset)?
                .value();
        let finite_difference = (loss_plus - loss_minus) / (2.0 * epsilon);

        assert!(
            (inferred_gradient - finite_difference).abs() < 1e-2,
            "inferred gradient {inferred_gradient} should match finite difference {finite_difference}"
        );
        Ok(())
    }

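    /// Returns the index whose value changed most between `before` and `after`.
    ///
    /// Only strictly positive deltas are selected, so the result is
    /// `CtError::EmptyInput(label)` when no entry changed or the slices are empty.
    fn largest_changed_vector_index(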
        before: &[f32],
        after: &[f32],
        label: &'static str,
    ) -> CtResult<usize> {
        let mut selected = None;
        let mut largest_delta = 0.0;

        for (index, (before_value, after_value)) in before.iter().zip(after).enumerate() {
            let delta = (before_value - after_value).abs();

            if delta > largest_delta {
                largest_delta = delta;
                selected = Some(index);
            }
        }

        selected.ok_or(CtError::EmptyInput(label))
    }

    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    struct MatrixSelection {
        input_index: usize,
        output_index: usize,
    }

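    /// Matrix counterpart of `largest_changed_vector_index`.
    ///
    /// The outer position is reported as `input_index` and the inner position as
    /// `output_index`; an unchanged or empty matrix yields `CtError::EmptyInput(label)`.
    fn largest_changed_matrix_weight(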
        before: &[Vec<f32>],
        after: &[Vec<f32>],
        label: &'static str,
    ) -> CtResult<MatrixSelection> {
        let mut selected = None;
        let mut largest_delta = 0.0;

        for (input_index, (before_row, after_row)) in before.iter().zip(after).enumerate() {
            for (output_index, (before_value, after_value)) in
                before_row.iter().zip(after_row).enumerate()
            {
                let delta = (before_value - after_value).abs();

                if delta > largest_delta {
                    largest_delta = delta;
                    selected = Some(MatrixSelection {
                        input_index,
                        output_index,
                    });
                }
            }
        }

        selected.ok_or(CtError::EmptyInput(label))
    }

    fn largest_changed_readout_weight(
        before: &TransformerTrainingState,
        after: &TransformerTrainingState,
    ) -> CtResult<MatrixSelection> {
        largest_changed_matrix_weight(
            before.parameters().readout().weight(),
            after.parameters().readout().weight(),
            "changed readout weight",
        )
    }

    fn readout_weight(state: &TransformerTrainingState, selection: MatrixSelection) -> f32 {
        state.parameters().readout().weight()[selection.input_index][selection.output_index]
    }

    fn state_with_readout_weight(
        state: &TransformerTrainingState,
        selection: MatrixSelection,
        value: f32,
    ) -> CtResult<TransformerTrainingState> {
        let readout = state.parameters().readout();
        let mut weight = readout.weight().to_vec();
        weight[selection.input_index][selection.output_index] = value;
        let readout = TransformerReadout::new(weight, readout.bias().to_vec())?;
        let parameters = state.parameters().clone().with_readout(readout)?;

        Ok(TransformerTrainingState::from_parts(
            parameters,
            state.learning_rate(),
            state.step_count(),
        ))
    }

    fn state_with_readout_bias(
        state: &TransformerTrainingState,
        selection: usize,
        value: f32,
    ) -> CtResult<TransformerTrainingState> {
        let readout = state.parameters().readout();
        let mut bias = readout.bias().to_vec();
        bias[selection] = value;
        let readout = TransformerReadout::new(readout.weight().to_vec(), bias)?;
        let parameters = state.parameters().clone().with_readout(readout)?;

        Ok(TransformerTrainingState::from_parts(
            parameters,
            state.learning_rate(),
            state.step_count(),
        ))
    }

    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    enum FeedForwardWeightKind {
        First,
        Second,
    }

    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    struct FeedForwardWeightSelection {
        kind: FeedForwardWeightKind,
        matrix: MatrixSelection,
    }

    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    struct FeedForwardBiasSelection {
        kind: FeedForwardWeightKind,
        index: usize,
    }

    fn largest_changed_feed_forward_weight(
        before: &TransformerTrainingState,
        after: &TransformerTrainingState,
    ) -> CtResult<FeedForwardWeightSelection> {
        let first = largest_changed_matrix_weight(
            before.parameters().feed_forward().first_weight(),
            after.parameters().feed_forward().first_weight(),
            "changed first feed-forward weight",
        );
        let second = largest_changed_matrix_weight(
            before.parameters().feed_forward().second_weight(),
            after.parameters().feed_forward().second_weight(),
            "changed second feed-forward weight",
        );

        match (first, second) {
            (Ok(first), Ok(second)) => {
                let first_delta = feed_forward_weight_delta(
                    before,
                    after,
                    FeedForwardWeightSelection {
                        kind: FeedForwardWeightKind::First,
                        matrix: first,
                    },
                );
                let second_delta = feed_forward_weight_delta(
                    before,
                    after,
                    FeedForwardWeightSelection {
                        kind: FeedForwardWeightKind::Second,
                        matrix: second,
                    },
                );

                // Ties go to the first feed-forward layer.
                if first_delta >= second_delta {
                    Ok(FeedForwardWeightSelection {
                        kind: FeedForwardWeightKind::First,
                        matrix: first,
                    })
                } else {
                    Ok(FeedForwardWeightSelection {
                        kind: FeedForwardWeightKind::Second,
                        matrix: second,
                    })
                }
            }
            (Ok(first), Err(_)) => Ok(FeedForwardWeightSelection {
                kind: FeedForwardWeightKind::First,
                matrix: first,
            }),
            (Err(_), Ok(second)) => Ok(FeedForwardWeightSelection {
                kind: FeedForwardWeightKind::Second,
                matrix: second,
            }),
            (Err(_), Err(_)) => Err(CtError::EmptyInput("changed feed-forward weight")),
        }
    }

    fn feed_forward_weight_delta(
        before: &TransformerTrainingState,
        after: &TransformerTrainingState,
        selection: FeedForwardWeightSelection,
    ) -> f32 {
        (feed_forward_weight(before, selection) - feed_forward_weight(after, selection)).abs()
    }

    fn feed_forward_weight(
        state: &TransformerTrainingState,
        selection: FeedForwardWeightSelection,
    ) -> f32 {
        let feed_forward = state.parameters().feed_forward();

        match selection.kind {
            FeedForwardWeightKind::First => feed_forward.first_weight()
                [selection.matrix.input_index][selection.matrix.output_index],
            FeedForwardWeightKind::Second => feed_forward.second_weight()
                [selection.matrix.input_index][selection.matrix.output_index],
        }
    }

    fn state_with_feed_forward_weight(
        state: &TransformerTrainingState,
        selection: FeedForwardWeightSelection,
        value: f32,
    ) -> CtResult<TransformerTrainingState> {
        let feed_forward = state.parameters().feed_forward();
        let mut first_weight = feed_forward.first_weight().to_vec();
        let mut second_weight = feed_forward.second_weight().to_vec();

        match selection.kind {
            FeedForwardWeightKind::First => {
                first_weight[selection.matrix.input_index][selection.matrix.output_index] = value;
            }
            FeedForwardWeightKind::Second => {
                second_weight[selection.matrix.input_index][selection.matrix.output_index] = value;
            }
        }

        let feed_forward = PositionWiseFeedForward::new(
            first_weight,
            feed_forward.first_bias().to_vec(),
            second_weight,
            feed_forward.second_bias().to_vec(),
        )?;
        let parameters = state.parameters().clone().with_feed_forward(feed_forward)?;

        Ok(TransformerTrainingState::from_parts(
            parameters,
            state.learning_rate(),
            state.step_count(),
        ))
    }

    fn largest_changed_feed_forward_bias(
        before: &TransformerTrainingState,
        after: &TransformerTrainingState,
    ) -> CtResult<FeedForwardBiasSelection> {
        let first = largest_changed_vector_index(
            before.parameters().feed_forward().first_bias(),
            after.parameters().feed_forward().first_bias(),
            "changed first feed-forward bias",
        );
        let second = largest_changed_vector_index(
            before.parameters().feed_forward().second_bias(),
            after.parameters().feed_forward().second_bias(),
            "changed second feed-forward bias",
        );

        match (first, second) {
            (Ok(first), Ok(second)) => {
                let first_delta = feed_forward_bias_delta(
                    before,
                    after,
                    FeedForwardBiasSelection {
                        kind: FeedForwardWeightKind::First,
                        index: first,
                    },
                );
                let second_delta = feed_forward_bias_delta(
                    before,
                    after,
                    FeedForwardBiasSelection {
                        kind: FeedForwardWeightKind::Second,
                        index: second,
                    },
                );

                // Ties go to the first feed-forward layer.
                if first_delta >= second_delta {
                    Ok(FeedForwardBiasSelection {
                        kind: FeedForwardWeightKind::First,
                        index: first,
                    })
                } else {
                    Ok(FeedForwardBiasSelection {
                        kind: FeedForwardWeightKind::Second,
                        index: second,
                    })
                }
            }
            (Ok(first), Err(_)) => Ok(FeedForwardBiasSelection {
                kind: FeedForwardWeightKind::First,
                index: first,
            }),
            (Err(_), Ok(second)) => Ok(FeedForwardBiasSelection {
                kind: FeedForwardWeightKind::Second,
                index: second,
            }),
            (Err(_), Err(_)) => Err(CtError::EmptyInput("changed feed-forward bias")),
        }
    }

    fn feed_forward_bias_delta(
        before: &TransformerTrainingState,
        after: &TransformerTrainingState,
        selection: FeedForwardBiasSelection,
    ) -> f32 {
        (feed_forward_bias(before, selection) - feed_forward_bias(after, selection)).abs()
    }

    fn feed_forward_bias(
        state: &TransformerTrainingState,
        selection: FeedForwardBiasSelection,
    ) -> f32 {
        let feed_forward = state.parameters().feed_forward();

        match selection.kind {
            FeedForwardWeightKind::First => feed_forward.first_bias()[selection.index],
            FeedForwardWeightKind::Second => feed_forward.second_bias()[selection.index],
        }
    }

    fn state_with_feed_forward_bias(
        state: &TransformerTrainingState,
        selection: FeedForwardBiasSelection,
        value: f32,
    ) -> CtResult<TransformerTrainingState> {
        let feed_forward = state.parameters().feed_forward();
        let mut first_bias = feed_forward.first_bias().to_vec();
        let mut second_bias = feed_forward.second_bias().to_vec();

        match selection.kind {
            FeedForwardWeightKind::First => {
                first_bias[selection.index] = value;
            }
            FeedForwardWeightKind::Second => {
                second_bias[selection.index] = value;
            }
        }

        let feed_forward = PositionWiseFeedForward::new(
            feed_forward.first_weight().to_vec(),
            first_bias,
            feed_forward.second_weight().to_vec(),
            second_bias,
        )?;
        let parameters = state.parameters().clone().with_feed_forward(feed_forward)?;

        Ok(TransformerTrainingState::from_parts(
            parameters,
            state.learning_rate(),
            state.step_count(),
        ))
    }

    fn state_with_output_projection_bias(
        state: &TransformerTrainingState,
        selection: usize,
        value: f32,
    ) -> CtResult<TransformerTrainingState> {
        let output_projection = state.parameters().output_projection();
        let mut bias = output_projection.bias().to_vec();
        bias[selection] = value;
        let output_projection =
            AttentionOutputProjection::new(output_projection.weight().to_vec(), bias)?;
        let parameters = state
            .parameters()
            .clone()
            .with_output_projection(output_projection)?;

        Ok(TransformerTrainingState::from_parts(
            parameters,
            state.learning_rate(),
            state.step_count(),
        ))
    }

    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    enum LayerNormParameterKind {
        AttentionScale,
        AttentionShift,
        FeedForwardScale,
        FeedForwardShift,
    }

    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    struct LayerNormParameterSelection {
        kind: LayerNormParameterKind,
        feature_index: usize,
    }

    fn largest_changed_layer_norm_parameter(
        before: &TransformerTrainingState,
        after: &TransformerTrainingState,
    ) -> CtResult<LayerNormParameterSelection> {
        let mut selected = None;
        let mut largest_delta = 0.0;

        for kind in [
            LayerNormParameterKind::AttentionScale,
            LayerNormParameterKind::AttentionShift,
            LayerNormParameterKind::FeedForwardScale,
            LayerNormParameterKind::FeedForwardShift,
        ] {
            let before_values = layer_norm_parameter_values(before, kind);
            let after_values = layer_norm_parameter_values(after, kind);

            for (feature_index, (before_value, after_value)) in
                before_values.iter().zip(after_values).enumerate()
            {
                let delta = (before_value - after_value).abs();

                if delta > largest_delta {
                    largest_delta = delta;
                    selected = Some(LayerNormParameterSelection {
                        kind,
                        feature_index,
                    });
                }
            }
        }

        selected.ok_or(CtError::EmptyInput("changed layer norm parameter"))
    }

    fn layer_norm_parameter_values(
        state: &TransformerTrainingState,
        kind: LayerNormParameterKind,
    ) -> &[f32] {
        match kind {
            LayerNormParameterKind::AttentionScale => {
                state.parameters().attention_norm().parameters().scale()
            }
            LayerNormParameterKind::AttentionShift => {
                state.parameters().attention_norm().parameters().shift()
            }
            LayerNormParameterKind::FeedForwardScale => {
                state.parameters().feed_forward_norm().parameters().scale()
            }
            LayerNormParameterKind::FeedForwardShift => {
                state.parameters().feed_forward_norm().parameters().shift()
            }
        }
    }

    fn layer_norm_parameter_value(
        state: &TransformerTrainingState,
        selection: LayerNormParameterSelection,
    ) -> f32 {
        layer_norm_parameter_values(state, selection.kind)[selection.feature_index]
    }

    fn state_with_layer_norm_parameter(
        state: &TransformerTrainingState,
        selection: LayerNormParameterSelection,
        value: f32,
    ) -> CtResult<TransformerTrainingState> {
        let attention_parameters = state.parameters().attention_norm().parameters();
        let feed_forward_parameters = state.parameters().feed_forward_norm().parameters();
        let mut attention_scale = attention_parameters.scale().to_vec();
        let mut attention_shift = attention_parameters.shift().to_vec();
        let mut feed_forward_scale = feed_forward_parameters.scale().to_vec();
        let mut feed_forward_shift = feed_forward_parameters.shift().to_vec();

        match selection.kind {
            LayerNormParameterKind::AttentionScale => {
                attention_scale[selection.feature_index] = value;
            }
            LayerNormParameterKind::AttentionShift => {
                attention_shift[selection.feature_index] = value;
            }
            LayerNormParameterKind::FeedForwardScale => {
                feed_forward_scale[selection.feature_index] = value;
            }
            LayerNormParameterKind::FeedForwardShift => {
                feed_forward_shift[selection.feature_index] = value;
            }
        }

        let attention_norm = LayerNormalization::new(LayerNormParameters::new(
            attention_scale,
            attention_shift,
            attention_parameters.epsilon(),
        )?);
        let feed_forward_norm = LayerNormalization::new(LayerNormParameters::new(
            feed_forward_scale,
            feed_forward_shift,
            feed_forward_parameters.epsilon(),
        )?);
        let parameters = state
            .parameters()
            .clone()
            .with_layer_norms(attention_norm, feed_forward_norm)?;

        Ok(TransformerTrainingState::from_parts(
            parameters,
            state.learning_rate(),
            state.step_count(),
        ))
    }

    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    enum AttentionProjectionKind {
        Query,
        Key,
        Value,
    }

    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    struct AttentionProjectionSelection {
        head_index: usize,
        kind: AttentionProjectionKind,
        input_index: usize,
        output_index: usize,
    }

    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    struct AttentionProjectionBiasSelection {
        head_index: usize,
        kind: AttentionProjectionKind,
        output_index: usize,
    }

    fn largest_changed_attention_projection(
        before: &TransformerTrainingState,
        after: &TransformerTrainingState,
    ) -> CtResult<AttentionProjectionSelection> {
        let head_count = before.parameters().attention_heads().len();
        let mut selected = None;
        let mut largest_delta = 0.0;

        for head_index in 0..head_count {
            for kind in [
                AttentionProjectionKind::Query,
                AttentionProjectionKind::Key,
                AttentionProjectionKind::Value,
            ] {
                let before_weight = attention_projection_weight_matrix(before, head_index, kind)?;
                let after_weight = attention_projection_weight_matrix(after, head_index, kind)?;

                for (input_index, (before_row, after_row)) in
                    before_weight.iter().zip(after_weight).enumerate()
                {
                    for (output_index, (before_value, after_value)) in
                        before_row.iter().zip(after_row).enumerate()
                    {
                        let delta = (before_value - after_value).abs();

                        if delta > largest_delta {
                            largest_delta = delta;
                            selected = Some(AttentionProjectionSelection {
                                head_index,
                                kind,
                                input_index,
                                output_index,
                            });
                        }
                    }
                }
            }
        }

        selected.ok_or(CtError::EmptyInput("changed attention projection"))
    }

    fn attention_projection_weight(
        state: &TransformerTrainingState,
        selection: AttentionProjectionSelection,
    ) -> CtResult<f32> {
        let weight =
            attention_projection_weight_matrix(state, selection.head_index, selection.kind)?;

        Ok(weight[selection.input_index][selection.output_index])
    }

    fn attention_projection_weight_matrix(
        state: &TransformerTrainingState,
        head_index: usize,
        kind: AttentionProjectionKind,
    ) -> CtResult<&[Vec<f32>]> {
        let head =
            state
                .parameters()
                .attention_heads()
                .get(head_index)
                .ok_or(CtError::OutOfRange {
                    kind: "attention head",
                    index: head_index,
                    limit: state.parameters().attention_heads().len(),
                })?;

        Ok(match kind {
            AttentionProjectionKind::Query => head.query_projection().weight(),
            AttentionProjectionKind::Key => head.key_projection().weight(),
            AttentionProjectionKind::Value => head.value_projection().weight(),
        })
    }

    fn largest_changed_attention_projection_bias(
        before: &TransformerTrainingState,
        after: &TransformerTrainingState,
    ) -> CtResult<AttentionProjectionBiasSelection> {
        let head_count = before.parameters().attention_heads().len();
        let mut selected = None;
        let mut largest_delta = 0.0;

        for head_index in 0..head_count {
            for kind in [
                AttentionProjectionKind::Query,
                AttentionProjectionKind::Key,
                AttentionProjectionKind::Value,
            ] {
                let before_bias = attention_projection_bias_values(before, head_index, kind)?;
                let after_bias = attention_projection_bias_values(after, head_index, kind)?;

                for (output_index, (before_value, after_value)) in
                    before_bias.iter().zip(after_bias).enumerate()
                {
                    let delta = (before_value - after_value).abs();

                    if delta > largest_delta {
                        largest_delta = delta;
                        selected = Some(AttentionProjectionBiasSelection {
                            head_index,
                            kind,
                            output_index,
                        });
                    }
                }
            }
        }

        selected.ok_or(CtError::EmptyInput("changed attention projection bias"))
    }

    fn attention_projection_bias(
        state: &TransformerTrainingState,
        selection: AttentionProjectionBiasSelection,
    ) -> CtResult<f32> {
        let bias = attention_projection_bias_values(state, selection.head_index, selection.kind)?;

        Ok(bias[selection.output_index])
    }

    fn attention_projection_bias_values(
        state: &TransformerTrainingState,
        head_index: usize,
        kind: AttentionProjectionKind,
    ) -> CtResult<&[f32]> {
        let head =
            state
                .parameters()
                .attention_heads()
                .get(head_index)
                .ok_or(CtError::OutOfRange {
                    kind: "attention head",
                    index: head_index,
                    limit: state.parameters().attention_heads().len(),
                })?;

        Ok(match kind {
            AttentionProjectionKind::Query => head.query_projection().bias(),
            AttentionProjectionKind::Key => head.key_projection().bias(),
            AttentionProjectionKind::Value => head.value_projection().bias(),
        })
    }

    fn state_with_attention_projection_weight(
        state: &TransformerTrainingState,
        selection: AttentionProjectionSelection,
        value: f32,
    ) -> CtResult<TransformerTrainingState> {
        let mut heads = state.parameters().attention_heads().to_vec();
        let head = heads.get(selection.head_index).ok_or(CtError::OutOfRange {
            kind: "attention head",
            index: selection.head_index,
            limit: heads.len(),
        })?;
        let mut query_weight = head.query_projection().weight().to_vec();
        let mut key_weight = head.key_projection().weight().to_vec();
        let mut value_weight = head.value_projection().weight().to_vec();

        match selection.kind {
            AttentionProjectionKind::Query => {
                query_weight[selection.input_index][selection.output_index] = value;
            }
            AttentionProjectionKind::Key => {
                key_weight[selection.input_index][selection.output_index] = value;
            }
            AttentionProjectionKind::Value => {
                value_weight[selection.input_index][selection.output_index] = value;
            }
        }

        heads[selection.head_index] = SelfAttentionHead::new(
            HiddenToQuery::new(query_weight, head.query_projection().bias().to_vec())?,
            HiddenToKey::new(key_weight, head.key_projection().bias().to_vec())?,
            HiddenToValue::new(value_weight, head.value_projection().bias().to_vec())?,
        )?;
        let parameters = state.parameters().clone().with_attention_heads(heads)?;

        Ok(TransformerTrainingState::from_parts(
            parameters,
            state.learning_rate(),
            state.step_count(),
        ))
    }

    fn state_with_attention_projection_bias(
        state: &TransformerTrainingState,
        selection: AttentionProjectionBiasSelection,
        value: f32,
    ) -> CtResult<TransformerTrainingState> {
        let mut heads = state.parameters().attention_heads().to_vec();
        let head = heads.get(selection.head_index).ok_or(CtError::OutOfRange {
            kind: "attention head",
            index: selection.head_index,
            limit: heads.len(),
        })?;
        let mut query_bias = head.query_projection().bias().to_vec();
        let mut key_bias = head.key_projection().bias().to_vec();
        let mut value_bias = head.value_projection().bias().to_vec();

        match selection.kind {
            AttentionProjectionKind::Query => {
                query_bias[selection.output_index] = value;
            }
            AttentionProjectionKind::Key => {
                key_bias[selection.output_index] = value;
            }
            AttentionProjectionKind::Value => {
                value_bias[selection.output_index] = value;
            }
        }

        heads[selection.head_index] = SelfAttentionHead::new(
            HiddenToQuery::new(head.query_projection().weight().to_vec(), query_bias)?,
            HiddenToKey::new(head.key_projection().weight().to_vec(), key_bias)?,
            HiddenToValue::new(head.value_projection().weight().to_vec(), value_bias)?,
        )?;
        let parameters = state.parameters().clone().with_attention_heads(heads)?;

        Ok(TransformerTrainingState::from_parts(
            parameters,
            state.learning_rate(),
            state.step_count(),
        ))
    }

    fn tiny_single_head_block() -> CtResult<SingleHeadTransformerBlock> {
        let query = HiddenToQuery::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?;
        let key = HiddenToKey::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?;
        let value = HiddenToValue::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?;
        let output_projection =
            AttentionOutputProjection::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?;
        let model_dimension = ModelDimension::new(2)?;
        let attention_norm =
            LayerNormalization::new(LayerNormParameters::identity(model_dimension));
        let feed_forward = PositionWiseFeedForward::new(
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![0.0, 0.0],
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![0.0, 0.0],
        )?;
        let feed_forward_norm =
            LayerNormalization::new(LayerNormParameters::identity(model_dimension));

        SingleHeadTransformerBlock::new(
            query,
            key,
            value,
            output_projection,
            attention_norm,
            feed_forward,
            feed_forward_norm,
        )
    }

    fn tiny_multi_head_block() -> CtResult<MultiHeadTransformerBlock> {
        let model_dimension = ModelDimension::new(2)?;

        MultiHeadTransformerBlock::new(
            vec![
                tiny_self_attention_head_first_feature()?,
                tiny_self_attention_head_second_feature()?,
            ],
            AttentionOutputProjection::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?,
            LayerNormalization::new(LayerNormParameters::identity(model_dimension)),
            identity_feed_forward()?,
            LayerNormalization::new(LayerNormParameters::identity(model_dimension)),
        )
    }

    fn tiny_masked_multi_head_block() -> CtResult<MaskedMultiHeadTransformerBlock> {
        let model_dimension = ModelDimension::new(2)?;

        MaskedMultiHeadTransformerBlock::new(
            vec![
                tiny_self_attention_head_first_feature()?,
                tiny_self_attention_head_second_feature()?,
            ],
            AttentionOutputProjection::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?,
            LayerNormalization::new(LayerNormParameters::identity(model_dimension)),
            identity_feed_forward()?,
            LayerNormalization::new(LayerNormParameters::identity(model_dimension)),
        )
    }

    fn tiny_transformer_parameters() -> CtResult<TinyTransformerParameters> {
        TinyTransformerParameters::new(
            PositionalEncoding::new(vec![vec![0.1, 0.0], vec![0.0, 0.1]])?,
            tiny_masked_multi_head_block()?,
            TransformerReadout::new(
                vec![vec![1.0, 0.0, 0.5], vec![0.0, 1.0, -0.5]],
                vec![0.0, 0.0, 0.0],
            )?,
        )
    }

    fn tiny_transformer_training_set() -> CtResult<TransformerReadoutTrainingSet> {
        let example = TransformerReadoutTrainingExample::new(
            HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?,
            AttentionMask::new(vec![vec![true, false], vec![true, true]])?,
            TokenSequence::from_indices([0, 1])?,
        )?;

        TransformerReadoutTrainingSet::new([example])
    }

    fn tiny_feed_forward_training_set() -> CtResult<TransformerFeedForwardTrainingSet> {
        let example = TransformerFeedForwardTrainingExample::new(
            HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?,
            HiddenSequence::new(vec![vec![2.0, 0.0], vec![0.0, 2.0]])?,
        )?;

        TransformerFeedForwardTrainingSet::new([example])
    }

    fn tiny_transformer_block_training_set() -> CtResult<TransformerBlockTrainingSet> {
        let example = TransformerBlockTrainingExample::new(
            HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?,
            AttentionMask::new(vec![vec![true, false], vec![true, true]])?,
            TokenSequence::from_indices([0, 1])?,
        )?;

        TransformerBlockTrainingSet::new([example])
    }

    fn tiny_self_attention_head_first_feature() -> CtResult<SelfAttentionHead> {
        SelfAttentionHead::new(
            HiddenToQuery::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
            HiddenToKey::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
            HiddenToValue::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
        )
    }

    fn tiny_self_attention_head_second_feature() -> CtResult<SelfAttentionHead> {
        SelfAttentionHead::new(
            HiddenToQuery::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
            HiddenToKey::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
            HiddenToValue::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
        )
    }

    fn identity_feed_forward() -> CtResult<PositionWiseFeedForward> {
        PositionWiseFeedForward::new(
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![0.0, 0.0],
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![0.0, 0.0],
        )
    }
}
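
The test helpers that close this snapshot repeat one pattern for every parameter family: scan two states for the entry with the largest absolute change, then rebuild a state with only that entry replaced. The sketch below is a minimal standalone version of that pattern on plain `f32` slices. The names `Selection`, `largest_changed`, and `with_value` are hypothetical stand-ins and are not part of `src/attention.rs`; the real helpers work on `TransformerTrainingState` and return `CtResult`.

```rust
/// Hypothetical stand-in for the selection structs in the snapshot:
/// it records only which entry was picked.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct Selection {
    index: usize,
}

/// Returns the entry whose value changed the most between `before` and
/// `after`, or `None` when no entry changed at all.
fn largest_changed(before: &[f32], after: &[f32]) -> Option<Selection> {
    let mut selected = None;
    let mut largest_delta = 0.0_f32;

    for (index, (before_value, after_value)) in before.iter().zip(after).enumerate() {
        let delta = (before_value - after_value).abs();

        // The strict `>` means identical slices select nothing, and ties
        // keep the earliest index.
        if delta > largest_delta {
            largest_delta = delta;
            selected = Some(Selection { index });
        }
    }

    selected
}

/// Copies `values` and overwrites only the selected entry, leaving the
/// input untouched (the snapshot helpers rebuild a whole state the same way).
fn with_value(values: &[f32], selection: Selection, value: f32) -> Vec<f32> {
    let mut copy = values.to_vec();
    copy[selection.index] = value;
    copy
}
```

For `before = [0.5, 1.0, -2.0]` and `after = [0.5, 1.25, -2.5]`, `largest_changed` selects index 2, and `with_value(&before, selection, 9.0)` yields `[0.5, 1.0, 9.0]`. Comparing a slice with itself selects nothing, which corresponds to the `CtError::EmptyInput` branch in the snapshot helpers.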

src/sketches.rs
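
One law in this file is easier to trust after a worked check: the feature/layer Galois connection, which says that the abstracted budget fits in `layers` exactly when the features fit in the concretized capacity. The sketch below is a simplified illustration, not the module's API. It uses plain `usize` in place of the `FeatureCount` and `LayerBudget` newtypes, and the function names are hypothetical. Only the constant `FEATURES_PER_LAYER = 4` and the `div_ceil` rounding are taken from the snapshot.

```rust
const FEATURES_PER_LAYER: usize = 4;

/// Left adjoint: the smallest layer budget that can hold `features`.
fn abstract_to_layers(features: usize) -> usize {
    features.div_ceil(FEATURES_PER_LAYER)
}

/// Right adjoint: the feature capacity offered by `layers`.
fn concretize_layers(layers: usize) -> usize {
    layers * FEATURES_PER_LAYER
}

/// The Galois law: abstract(f) <= l  if and only if  f <= concretize(l).
fn galois_law_holds(features: usize, layers: usize) -> bool {
    let abstracted_fits = abstract_to_layers(features) <= layers;
    let concrete_fits = features <= concretize_layers(layers);

    abstracted_fits == concrete_fits
}

/// Exhaustively checks the law for all positive pairs up to `limit`.
/// Zero is skipped because the snapshot's newtypes reject it.
fn galois_law_holds_up_to(limit: usize) -> bool {
    (1..=limit).all(|features| (1..=limit).all(|layers| galois_law_holds(features, layers)))
}
```

With five features and a budget of two layers, both sides are true (`ceil(5 / 4) = 2 <= 2` and `5 <= 8`). With nine features and the same budget, both sides are false (`3 <= 2` and `9 <= 8`). The law only asks that the two sides agree, so it holds in both cases.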

//! Rust models for the seven applied-category-theory sketches.
//!
//! This module is a study companion for *Seven Sketches in Compositionality*.
//! It does not try to encode all of category theory. Instead, each section
//! gives one small Rust model for the main structure of a sketch: orders,
//! resources, databases, co-design, signal flow, circuits, and logic of
//! behavior.

use crate::error::{CtError, CtResult};

/// A finite order used for the first sketch: observations can be refined into
/// features, scores, and decisions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum InformationLevel {
    Observation,
    Feature,
    Score,
    Decision,
}

impl InformationLevel {
    /// Returns true when information at this level can flow into `target`.
    pub fn can_flow_to(self, target: Self) -> bool {
        self <= target
    }

    /// Least upper bound in this small total order.
    pub fn join(self, other: Self) -> Self {
        self.max(other)
    }
}

const INFORMATION_LEVELS: [InformationLevel; 4] = [
    InformationLevel::Observation,
    InformationLevel::Feature,
    InformationLevel::Score,
    InformationLevel::Decision,
];

/// Checks reflexivity and transitivity for the finite information order.
pub fn information_order_obeys_preorder_laws() -> bool {
    for level in INFORMATION_LEVELS {
        if !level.can_flow_to(level) {
            return false;
        }
    }

    for first in INFORMATION_LEVELS {
        for second in INFORMATION_LEVELS {
            for third in INFORMATION_LEVELS {
                let premise = first.can_flow_to(second) && second.can_flow_to(third);
                if premise && !first.can_flow_to(third) {
                    return false;
                }
            }
        }
    }

    true
}

/// Number of concrete features in a tiny model.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct FeatureCount(usize);

impl FeatureCount {
    pub fn new(value: usize) -> CtResult<Self> {
        if value == 0 {
            return Err(CtError::EmptyInput("feature count"));
        }

        Ok(Self(value))
    }

    pub fn value(&self) -> usize {
        self.0
    }
}

/// Number of abstract layers used to summarize feature capacity.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct LayerBudget(usize);

impl LayerBudget {
    pub fn new(value: usize) -> CtResult<Self> {
        if value == 0 {
            return Err(CtError::EmptyInput("layer budget"));
        }

        Ok(Self(value))
    }

    pub fn value(&self) -> usize {
        self.0
    }
}

const FEATURES_PER_LAYER: usize = 4;

/// Abstracts concrete features to the minimum layer budget that can hold them.
pub fn abstract_to_layer_budget(features: FeatureCount) -> CtResult<LayerBudget> {
    LayerBudget::new(features.value().div_ceil(FEATURES_PER_LAYER))
}

/// Concretizes an abstract layer budget back to feature capacity.
pub fn concretize_layer_budget(layers: LayerBudget) -> FeatureCount {
    FeatureCount(layers.value() * FEATURES_PER_LAYER)
}

/// Checks the Galois-connection law for the feature/layer abstraction.
pub fn feature_layer_galois_law_holds(
    features: FeatureCount,
    layers: LayerBudget,
) -> CtResult<bool> {
    let abstracted_fits = abstract_to_layer_budget(features)?.value() <= layers.value();
    let concrete_fits = features.value() <= concretize_layer_budget(layers).value();

    Ok(abstracted_fits == concrete_fits)
}

/// A non-negative amount in a resource bundle.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct ResourceAmount(usize);

impl ResourceAmount {
    pub fn new(value: usize) -> Self {
        Self(value)
    }

    pub fn value(&self) -> usize {
        self.0
    }
}

/// Compute and memory resources composed by component-wise addition.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ResourceBundle {
    compute: ResourceAmount,
    memory: ResourceAmount,
}

impl ResourceBundle {
    pub fn new(compute: ResourceAmount, memory: ResourceAmount) -> Self {
        Self { compute, memory }
    }

    pub fn compute(&self) -> ResourceAmount {
        self.compute
    }

    pub fn memory(&self) -> ResourceAmount {
        self.memory
    }

    /// Monoidal product for independent resources.
    pub fn tensor(&self, other: &Self) -> Self {
        Self {
            compute: ResourceAmount::new(self.compute.value() + other.compute.value()),
            memory: ResourceAmount::new(self.memory.value() + other.memory.value()),
        }
    }

    /// Resource preorder: supply can cover demand when every component is large
    /// enough.
    pub fn can_supply(&self, demand: &Self) -> bool {
        self.compute >= demand.compute && self.memory >= demand.memory
    }
}

/// Demonstrates monotonicity of the monoidal resource operation.
pub fn resource_tensor_is_monotone() -> bool {
    let small_supply = ResourceBundle::new(ResourceAmount::new(2), ResourceAmount::new(8));
    let large_supply = ResourceBundle::new(ResourceAmount::new(4), ResourceAmount::new(16));
    let fixed = ResourceBundle::new(ResourceAmount::new(1), ResourceAmount::new(2));

    large_supply.can_supply(&small_supply)
        && large_supply
            .tensor(&fixed)
            .can_supply(&small_supply.tensor(&fixed))
}

/// Identifier for a department row.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DepartmentId(usize);

impl DepartmentId {
    pub fn new(value: usize) -> Self {
        Self(value)
    }

    pub fn value(&self) -> usize {
        self.0
    }
}

/// Identifier for an employee row.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct EmployeeId(usize);

impl EmployeeId {
    pub fn new(value: usize) -> Self {
        Self(value)
    }

    pub fn value(&self) -> usize {
        self.0
    }
}

/// A row in the employee table. The department field is the schema arrow
/// `Employee -> Department`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EmployeeRecord {
    id: EmployeeId,
    department: DepartmentId,
}

impl EmployeeRecord {
    pub fn new(id: EmployeeId, department: DepartmentId) -> Self {
        Self { id, department }
    }

    pub fn id(&self) -> EmployeeId {
        self.id
    }

    pub fn department(&self) -> DepartmentId {
        self.department
    }
}

/// A tiny database instance: schema objects become sets of ids, and schema
/// arrows become functions between those sets.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CompanyInstance {
    departments: Vec<DepartmentId>,
    employees: Vec<EmployeeRecord>,
}

impl CompanyInstance {
    pub fn new(
        departments: impl IntoIterator<Item = DepartmentId>,
        employees: impl IntoIterator<Item = EmployeeRecord>,
    ) -> CtResult<Self> {
        let departments = departments.into_iter().collect::<Vec<_>>();
        let employees = employees.into_iter().collect::<Vec<_>>();

        for employee in &employees {
            if !departments.contains(&employee.department()) {
                return Err(CtError::ShapeMismatch {
                    op: "database instance",
                    expected: String::from("employee departments exist in Department"),
                    got: format!("missing department {}", employee.department().value()),
                });
            }
        }

        Ok(Self {
            departments,
            employees,
        })
    }

    pub fn departments(&self) -> &[DepartmentId] {
        &self.departments
    }

    pub fn employees(&self) -> &[EmployeeRecord] {
        &self.employees
    }

    pub fn department_of(&self, employee_id: EmployeeId) -> Option<DepartmentId> {
        self.employees
            .iter()
            .find(|employee| employee.id() == employee_id)
            .map(EmployeeRecord::department)
    }

    pub fn foreign_keys_resolve(&self) -> bool {
        self.employees
            .iter()
            .all(|employee| self.departments.contains(&employee.department()))
    }
}

/// Requests per second needed by a design.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct Throughput(usize);

impl Throughput {
    pub fn new(value: usize) -> CtResult<Self> {
        if value == 0 {
            return Err(CtError::EmptyInput("throughput"));
        }

        Ok(Self(value))
    }

    pub fn value(&self) -> usize {
        self.0
    }
}

/// Millisecond latency boundary.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct LatencyMs(usize);

impl LatencyMs {
    pub fn new(value: usize) -> CtResult<Self> {
        if value == 0 {
            return Err(CtError::EmptyInput("latency"));
        }

        Ok(Self(value))
    }

    pub fn value(&self) -> usize {
        self.0
    }
}

/// Functional need in a co-design problem.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DesignRequirement {
    minimum_throughput: Throughput,
    maximum_latency: LatencyMs,
}

impl DesignRequirement {
    pub fn new(minimum_throughput: Throughput, maximum_latency: LatencyMs) -> Self {
        Self {
            minimum_throughput,
            maximum_latency,
        }
    }
}

/// Implementation offered by a candidate component.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ImplementationOffer {
    throughput: Throughput,
    latency: LatencyMs,
}

impl ImplementationOffer {
    pub fn new(throughput: Throughput, latency: LatencyMs) -> Self {
        Self {
            throughput,
            latency,
        }
    }
}

/// A Bool-valued feasibility relation between requirements and offers.
#[derive(Debug, Clone, Copy)]
pub struct FeasibilityRelation;

impl FeasibilityRelation {
    pub fn relates(requirement: DesignRequirement, offer: ImplementationOffer) -> bool {
        offer.throughput >= requirement.minimum_throughput
            && offer.latency <= requirement.maximum_latency
    }
}

/// A scalar coefficient in a signal-flow matrix.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SignalCoefficient(i32);

impl SignalCoefficient {
    pub fn new(value: i32) -> Self {
        Self(value)
    }

    pub fn zero() -> Self {
        Self(0)
    }

    pub fn value(&self) -> i32 {
        self.0
    }

    fn add(self, other: Self) -> Self {
        Self(self.value() + other.value())
    }

    fn multiply(self, other: Self) -> Self {
        Self(self.value() * other.value())
    }
}

/// Number of rows in a signal-flow matrix.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MatrixRows(usize);

impl MatrixRows {
    pub fn new(value: usize) -> CtResult<Self> {
        if value == 0 {
            return Err(CtError::EmptyInput("matrix rows"));
        }

        Ok(Self(value))
    }

    pub fn value(&self) -> usize {
        self.0
    }
}

/// Number of columns in a signal-flow matrix.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MatrixCols(usize);

impl MatrixCols {
    pub fn new(value: usize) -> CtResult<Self> {
        if value == 0 {
            return Err(CtError::EmptyInput("matrix columns"));
        }

        Ok(Self(value))
    }

    pub fn value(&self) -> usize {
        self.0
    }
}

/// Matrix semantics for a signal-flow graph.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SignalMatrix {
    rows: MatrixRows,
    cols: MatrixCols,
    coefficients: Vec<Vec<SignalCoefficient>>,
}

impl SignalMatrix {
    pub fn new(
        rows: MatrixRows,
        cols: MatrixCols,
        coefficients: Vec<Vec<SignalCoefficient>>,
    ) -> CtResult<Self> {
        if coefficients.len() != rows.value() {
            return Err(CtError::ShapeMismatch {
                op: "signal matrix",
                expected: format!("{} rows", rows.value()),
                got: format!("{} rows", coefficients.len()),
            });
        }

        for row in &coefficients {
            if row.len() != cols.value() {
                return Err(CtError::ShapeMismatch {
                    op: "signal matrix",
                    expected: format!("{} columns", cols.value()),
                    got: format!("{} columns", row.len()),
                });
            }
        }

        Ok(Self {
            rows,
            cols,
            coefficients,
        })
    }

    pub fn rows(&self) -> MatrixRows {
        self.rows
    }

    pub fn cols(&self) -> MatrixCols {
        self.cols
    }

    pub fn coefficients(&self) -> &[Vec<SignalCoefficient>] {
        &self.coefficients
    }

    /// Matrix composition. If `previous` is `A -> B` and `self` is `B -> C`,
    /// the result is `A -> C`. Rows index outputs and columns index inputs, so
    /// `previous.rows` must equal `self.cols`.
    pub fn compose_after(&self, previous: &Self) -> CtResult<Self> {
        if previous.rows.value() != self.cols.value() {
            return Err(CtError::ShapeMismatch {
                op: "signal matrix composition",
                expected: format!("middle dimension {}", self.cols.value()),
                got: format!("middle dimension {}", previous.rows.value()),
            });
        }

        let mut coefficients =
            vec![vec![SignalCoefficient::zero(); previous.cols.value()]; self.rows.value()];

        for (output_row, output_coefficients) in coefficients.iter_mut().enumerate() {
            for (input_col, coefficient) in output_coefficients.iter_mut().enumerate() {
                let mut total = SignalCoefficient::zero();

                for middle in 0..self.cols.value() {
                    total = total.add(
                        self.coefficients[output_row][middle]
                            .multiply(previous.coefficients[middle][input_col]),
                    );
                }

                *coefficient = total;
            }
        }

        Self::new(self.rows, previous.cols, coefficients)
    }
}

/// A named boundary port in an open circuit.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PortName(&'static str);

impl PortName {
    pub fn new(value: &'static str) -> CtResult<Self> {
        if value.trim().is_empty() {
            return Err(CtError::EmptyInput("port name"));
        }

        Ok(Self(value))
    }

    pub fn as_str(&self) -> &'static str {
        self.0
    }
}

/// Positive resistance value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ResistanceOhms(usize);

impl ResistanceOhms {
    pub fn new(value: usize) -> CtResult<Self> {
        if value == 0 {
            return Err(CtError::InvalidQuantity {
                kind: "resistance",
                value: value as i64,
            });
        }

        Ok(Self(value))
    }

    pub fn value(&self) -> usize {
        self.0
    }
}

/// A simple resistor component between two ports.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CircuitComponent {
    from: PortName,
    to: PortName,
    resistance: ResistanceOhms,
}

impl CircuitComponent {
    pub fn resistor(from: PortName, to: PortName, resistance: ResistanceOhms) -> Self {
        Self {
            from,
            to,
            resistance,
        }
    }

    pub fn from(&self) -> PortName {
        self.from
    }

    pub fn to(&self) -> PortName {
        self.to
    }

    pub fn resistance(&self) -> ResistanceOhms {
        self.resistance
    }
}

/// An open circuit with input ports, output ports, and internal components.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OpenCircuit {
    inputs: Vec<PortName>,
    outputs: Vec<PortName>,
    components: Vec<CircuitComponent>,
}

impl OpenCircuit {
    pub fn new(
        inputs: impl IntoIterator<Item = PortName>,
        outputs: impl IntoIterator<Item = PortName>,
        components: impl IntoIterator<Item = CircuitComponent>,
    ) -> CtResult<Self> {
        let inputs = inputs.into_iter().collect::<Vec<_>>();
        let outputs = outputs.into_iter().collect::<Vec<_>>();
        let components = components.into_iter().collect::<Vec<_>>();

        if inputs.is_empty() {
            return Err(CtError::EmptyInput("circuit inputs"));
        }

        if outputs.is_empty() {
            return Err(CtError::EmptyInput("circuit outputs"));
        }

        Ok(Self {
            inputs,
            outputs,
            components,
        })
    }

    pub fn input_count(&self) -> usize {
        self.inputs.len()
    }

    pub fn output_count(&self) -> usize {
        self.outputs.len()
    }

    pub fn component_count(&self) -> usize {
        self.components.len()
    }

    /// Serial composition wires this circuit's outputs into the next circuit's
    /// inputs. Only the port counts are checked; port names are not matched.
    pub fn then(&self, next: &Self) -> CtResult<Self> {
        if self.output_count() != next.input_count() {
            return Err(CtError::ShapeMismatch {
                op: "open circuit serial composition",
                expected: format!("{} next inputs", self.output_count()),
                got: format!("{} next inputs", next.input_count()),
            });
        }

        let mut components = self.components.clone();
        components.extend_from_slice(&next.components);

        Self::new(self.inputs.clone(), next.outputs.clone(), components)
    }

    /// Parallel composition keeps the two open interfaces side by side.
    pub fn parallel(&self, other: &Self) -> CtResult<Self> {
        let mut inputs = self.inputs.clone();
        inputs.extend_from_slice(&other.inputs);

        let mut outputs = self.outputs.clone();
        outputs.extend_from_slice(&other.outputs);

        let mut components = self.components.clone();
        components.extend_from_slice(&other.components);

        Self::new(inputs, outputs, components)
    }
}

/// Truth values used for behavior classification.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TruthValue {
    False,
    True,
}

impl TruthValue {
    pub fn and(self, other: Self) -> Self {
        match (self, other) {
            (TruthValue::True, TruthValue::True) => TruthValue::True,
            _ => TruthValue::False,
        }
    }

    pub fn implies(self, other: Self) -> Self {
        match (self, other) {
            (TruthValue::True, TruthValue::False) => TruthValue::False,
            _ => TruthValue::True,
        }
    }
}

/// Discrete time point in a behavior trace.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct TimeTick(usize);

impl TimeTick {
    pub fn new(value: usize) -> Self {
        Self(value)
    }

    pub fn value(&self) -> usize {
        self.0
    }
}

/// Closed interval of time ticks.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TimeInterval {
    start: TimeTick,
    end: TimeTick,
}

impl TimeInterval {
    pub fn new(start: TimeTick, end: TimeTick) -> CtResult<Self> {
        if start > end {
            return Err(CtError::InvalidInterval {
                start: start.value(),
                end: end.value(),
            });
        }

        Ok(Self { start, end })
    }

    pub fn start(&self) -> TimeTick {
        self.start
    }

    pub fn end(&self) -> TimeTick {
        self.end
    }
}

/// Local safety result on one interval.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LocalSafetyCheck {
    interval: TimeInterval,
    truth: TruthValue,
}

impl LocalSafetyCheck {
    pub fn new(interval: TimeInterval, truth: TruthValue) -> Self {
        Self { interval, truth }
    }

    pub fn interval(&self) -> TimeInterval {
        self.interval
    }

    pub fn truth(&self) -> TruthValue {
        self.truth
    }
}

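/// A cover of local behavior checks. The global result is true only when all
/// local checks are true. The constructor rejects an empty list but does not
/// verify that the intervals actually cover a larger interval.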
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SafetyCover(Vec<LocalSafetyCheck>);

impl SafetyCover {
    pub fn new(checks: impl IntoIterator<Item = LocalSafetyCheck>) -> CtResult<Self> {
        let checks = checks.into_iter().collect::<Vec<_>>();

        if checks.is_empty() {
            return Err(CtError::EmptyInput("safety cover"));
        }

        Ok(Self(checks))
    }

    pub fn global_truth(&self) -> TruthValue {
        self.0
            .iter()
            .fold(TruthValue::True, |truth, check| truth.and(check.truth()))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn information_order_has_preorder_laws() {
        assert!(information_order_obeys_preorder_laws());
        assert_eq!(
            InformationLevel::Feature.join(InformationLevel::Decision),
            InformationLevel::Decision
        );
    }

    #[test]
    fn feature_layer_abstraction_obeys_galois_law() -> CtResult<()> {
        let features = FeatureCount::new(9)?;
        let layers = LayerBudget::new(3)?;

        assert!(feature_layer_galois_law_holds(features, layers)?);
        assert_eq!(abstract_to_layer_budget(features)?.value(), 3);
        assert_eq!(concretize_layer_budget(layers).value(), 12);
        Ok(())
    }

    #[test]
    fn resource_tensor_is_order_preserving() {
        assert!(resource_tensor_is_monotone());
    }

    #[test]
    fn database_instance_resolves_schema_arrow() -> CtResult<()> {
        let research = DepartmentId::new(1);
        let platform = DepartmentId::new(2);
        let ada = EmployeeId::new(7);
        let instance =
            CompanyInstance::new([research, platform], [EmployeeRecord::new(ada, research)])?;

        assert!(instance.foreign_keys_resolve());
        assert_eq!(instance.department_of(ada), Some(research));
        Ok(())
    }

    #[test]
    fn database_instance_rejects_missing_department_reference() {
        let research = DepartmentId::new(1);
        let missing_department = DepartmentId::new(99);
        let ada = EmployeeId::new(7);

        assert!(matches!(
            CompanyInstance::new([research], [EmployeeRecord::new(ada, missing_department)]),
            Err(CtError::ShapeMismatch {
                op: "database instance",
                ..
            })
        ));
    }

    #[test]
    fn feasibility_relation_matches_requirement_to_offer() -> CtResult<()> {
        let requirement = DesignRequirement::new(Throughput::new(100)?, LatencyMs::new(80)?);
        let offer = ImplementationOffer::new(Throughput::new(120)?, LatencyMs::new(50)?);

        assert!(FeasibilityRelation::relates(requirement, offer));
        Ok(())
    }

    #[test]
    fn signal_matrices_compose_like_flow_graph_semantics() -> CtResult<()> {
        let duplicate = SignalMatrix::new(
            MatrixRows::new(2)?,
            MatrixCols::new(1)?,
            vec![
                vec![SignalCoefficient::new(1)],
                vec![SignalCoefficient::new(1)],
            ],
        )?;
        let add_weighted = SignalMatrix::new(
            MatrixRows::new(1)?,
            MatrixCols::new(2)?,
            vec![vec![SignalCoefficient::new(2), SignalCoefficient::new(3)]],
        )?;

        let composed = add_weighted.compose_after(&duplicate)?;

        assert_eq!(composed.coefficients(), &[vec![SignalCoefficient::new(5)]]);
        Ok(())
    }

    #[test]
    fn signal_matrix_composition_rejects_mismatched_middle_dimension() -> CtResult<()> {
        let previous = SignalMatrix::new(
            MatrixRows::new(2)?,
            MatrixCols::new(1)?,
            vec![
                vec![SignalCoefficient::new(1)],
                vec![SignalCoefficient::new(1)],
            ],
        )?;
        let next = SignalMatrix::new(
            MatrixRows::new(1)?,
            MatrixCols::new(3)?,
            vec![vec![
                SignalCoefficient::new(1),
                SignalCoefficient::new(1),
                SignalCoefficient::new(1),
            ]],
        )?;

        assert!(matches!(
            next.compose_after(&previous),
            Err(CtError::ShapeMismatch {
                op: "signal matrix composition",
                ..
            })
        ));
        Ok(())
    }

    #[test]
    fn open_circuits_compose_in_series_and_parallel() -> CtResult<()> {
        let input = PortName::new("input")?;
        let middle = PortName::new("middle")?;
        let output = PortName::new("output")?;
        let first = OpenCircuit::new(
            [input],
            [middle],
            [CircuitComponent::resistor(
                input,
                middle,
                ResistanceOhms::new(10)?,
            )],
        )?;
        let second = OpenCircuit::new(
            [middle],
            [output],
            [CircuitComponent::resistor(
                middle,
                output,
                ResistanceOhms::new(20)?,
            )],
        )?;

        let serial = first.then(&second)?;
        let parallel = first.parallel(&second)?;

        assert_eq!(serial.input_count(), 1);
        assert_eq!(serial.output_count(), 1);
        assert_eq!(serial.component_count(), 2);
        assert_eq!(parallel.input_count(), 2);
        assert_eq!(parallel.output_count(), 2);
        Ok(())
    }

    #[test]
    fn open_circuit_serial_composition_rejects_boundary_mismatch() -> CtResult<()> {
        let input = PortName::new("input")?;
        let output = PortName::new("output")?;
        let left = OpenCircuit::new([input], [output], [])?;
        let right = OpenCircuit::new([input, output], [output], [])?;

        assert!(matches!(
            left.then(&right),
            Err(CtError::ShapeMismatch {
                op: "open circuit serial composition",
                ..
            })
        ));
        Ok(())
    }

    #[test]
    fn local_behavior_truth_glues_to_global_truth() -> CtResult<()> {
        let first = LocalSafetyCheck::new(
            TimeInterval::new(TimeTick::new(0), TimeTick::new(5))?,
            TruthValue::True,
        );
        let second = LocalSafetyCheck::new(
            TimeInterval::new(TimeTick::new(5), TimeTick::new(10))?,
            TruthValue::True,
        );
        let cover = SafetyCover::new([first, second])?;

        assert_eq!(cover.global_truth(), TruthValue::True);
        assert_eq!(
            TruthValue::True.implies(TruthValue::False),
            TruthValue::False
        );
        Ok(())
    }
}
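
As a cross-check on the `compose_after` convention in this file, the following standalone sketch reproduces the arithmetic of the `signal_matrices_compose_like_flow_graph_semantics` test with plain `Vec<Vec<i32>>` values. It is not part of the repository: the free function `compose_after` and the bare integer matrices are illustrative assumptions that drop the `SignalCoefficient`, `MatrixRows`, and `MatrixCols` wrappers, and shape errors are reported with `Option` rather than `CtError`.

```rust
/// Hypothetical stand-in for `SignalMatrix::compose_after`, using bare integers.
/// Rows index outputs and columns index inputs, so `next` (B -> C) is applied
/// after `previous` (A -> B) and the result has one row per row of `next` and
/// one column per column of `previous`.
fn compose_after(next: &[Vec<i32>], previous: &[Vec<i32>]) -> Option<Vec<Vec<i32>>> {
    let middle = previous.len();
    let input_cols = previous.first()?.len();

    // Every row of `next` must have one column per row of `previous`.
    if next.is_empty() || next.iter().any(|row| row.len() != middle) {
        return None;
    }
    // `previous` must be rectangular.
    if previous.iter().any(|row| row.len() != input_cols) {
        return None;
    }

    let composed: Vec<Vec<i32>> = next
        .iter()
        .map(|next_row| {
            (0..input_cols)
                .map(|col| {
                    // Dot product of one row of `next` with one column of `previous`.
                    next_row
                        .iter()
                        .zip(previous)
                        .map(|(weight, previous_row)| weight * previous_row[col])
                        .sum::<i32>()
                })
                .collect::<Vec<i32>>()
        })
        .collect();

    Some(composed)
}

fn main() {
    // 2 x 1: copy one input signal onto two wires.
    let duplicate = vec![vec![1], vec![1]];
    // 1 x 2: weight the two wires by 2 and 3, then add them.
    let add_weighted = vec![vec![2, 3]];

    // 2 * 1 + 3 * 1 = 5, so the composite is the 1 x 1 matrix [[5]].
    let composed = compose_after(&add_weighted, &duplicate);
    println!("{composed:?}");
}
```

Swapping the arguments would ask for a 2 x 2 result instead, which is why the argument order in `compose_after` matters even when both shapes are compatible in each direction.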

src/challenges/mod.rs

//! Public challenge implementations that back the learner-facing challenge files.
//!
//! The intentionally broken exercise templates live under `challenges/` and are
//! not part of the normal Cargo build. This module contains the validated
//! reference behavior that keeps the public challenge honest.

pub mod papers;
pub mod typed_ai;

src/challenges/typed_ai.rs

//! Reference helpers for the Typed AI Rustlings challenge.

use crate::category::Morphism;
use crate::domain::{Distribution, Logits, Loss, Product, TokenId, VocabSize};
use crate::error::{CtError, CtResult};
use crate::ml::{CrossEntropy, Softmax};

/// Exposes the integer index only after the value has crossed the `TokenId`
/// boundary.
pub fn token_index(token: TokenId) -> usize {
    token.index()
}

/// Builds a uniform distribution for a non-empty vocabulary.
pub fn uniform_distribution(vocab_size: VocabSize) -> CtResult<Distribution> {
    let probability = 1.0 / vocab_size.value() as f32;
    Distribution::new(vec![probability; vocab_size.value()])
}

/// Computes target loss from raw logits by crossing the softmax boundary first.
pub fn loss_from_logits(logits: Logits, target: TokenId) -> CtResult<Loss> {
    let distribution = Softmax.apply(logits)?;
    CrossEntropy.apply(Product::new(distribution, target))
}

/// Validates that a target token can address a probability distribution.
pub fn require_target_in_distribution(
    distribution: &Distribution,
    target: TokenId,
) -> CtResult<()> {
    if target.index() >= distribution.as_slice().len() {
        return Err(CtError::OutOfRange {
            kind: "target token",
            index: target.index(),
            limit: distribution.as_slice().len(),
        });
    }

    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn token_index_requires_a_token_id_boundary() {
        assert_eq!(token_index(TokenId::new(7)), 7);
    }

    #[test]
    fn uniform_distribution_normalizes_vocab_size() -> CtResult<()> {
        let distribution = uniform_distribution(VocabSize::new(4)?)?;

        assert_eq!(distribution.as_slice(), &[0.25, 0.25, 0.25, 0.25]);
        Ok(())
    }

    #[test]
    fn loss_from_logits_crosses_softmax_before_cross_entropy() -> CtResult<()> {
        let loss = loss_from_logits(Logits::new(vec![0.0, 2.0]), TokenId::new(1))?;

        assert!(loss.value() < 0.2);
        Ok(())
    }

    #[test]
    fn target_validation_rejects_out_of_range_token() -> CtResult<()> {
        let distribution = uniform_distribution(VocabSize::new(2)?)?;

        assert!(matches!(
            require_target_in_distribution(&distribution, TokenId::new(9)),
            Err(CtError::OutOfRange {
                kind: "target token",
                index: 9,
                limit: 2,
            })
        ));

        Ok(())
    }
}
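
The bound `loss.value() < 0.2` in `loss_from_logits_crosses_softmax_before_cross_entropy` can be checked by hand. The sketch below assumes that `Softmax` and `CrossEntropy` from `src/ml.rs` compute the textbook definitions (exponentiate and normalize, then take the negative log of the target probability). Their bodies are not shown in this file, so `softmax` and `cross_entropy` here are hypothetical stand-ins on bare slices rather than the crate's API.

```rust
/// Textbook softmax with max-subtraction for numerical stability.
/// Hypothetical stand-in; assumes non-empty, finite logits.
fn softmax(logits: &[f32]) -> Vec<f32> {
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|logit| (logit - max).exp()).collect();
    let total: f32 = exps.iter().sum();
    exps.iter().map(|value| value / total).collect()
}

/// Negative log-probability of the target index, or `None` when the index
/// does not address the distribution.
fn cross_entropy(distribution: &[f32], target: usize) -> Option<f32> {
    distribution.get(target).map(|probability| -probability.ln())
}

fn main() {
    let distribution = softmax(&[0.0, 2.0]);

    // P(target = 1) = e^2 / (1 + e^2), roughly 0.88.
    let loss = cross_entropy(&distribution, 1).expect("target 1 is in range");

    // loss = ln(1 + e^-2), roughly 0.127, which is below the 0.2 bound.
    println!("distribution = {distribution:?}, loss = {loss:.4}");

    // An index past the end has no probability to score.
    assert!(cross_entropy(&distribution, 9).is_none());
}
```

The out-of-range case mirrors what `require_target_in_distribution` guards against: a target index is only meaningful relative to the length of the distribution it addresses.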

src/challenges/papers/mod.rs

//! Paper-To-Rust challenge seeds.

pub mod adam;

src/challenges/papers/adam.rs

//! A small Paper-To-Rust translation of Adam optimizer state.
//!
//! The challenge compiles one idea from the Adam paper: an optimizer step is not
//! just a parameter vector update. It carries first-moment, second-moment, and
//! step-count state forward.

use crate::category::Morphism;
use crate::domain::LearningRate;
use crate::error::{CtError, CtResult};

/// Number of coordinates in an Adam parameter or gradient vector.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AdamVectorDimension(usize);

impl AdamVectorDimension {
    pub fn new(value: usize) -> CtResult<Self> {
        if value == 0 {
            return Err(CtError::EmptyInput("Adam vector dimension"));
        }

        Ok(Self(value))
    }

    pub fn value(&self) -> usize {
        self.0
    }
}

/// Model parameters owned by the Adam challenge.
#[derive(Debug, Clone, PartialEq)]
pub struct AdamParameterVector(Vec<f32>);

impl AdamParameterVector {
    pub fn new(values: Vec<f32>) -> CtResult<Self> {
        validate_finite_non_empty("Adam parameter vector", &values)?;
        Ok(Self(values))
    }

    pub fn dimension(&self) -> AdamVectorDimension {
        AdamVectorDimension(self.0.len())
    }

    pub fn as_slice(&self) -> &[f32] {
        &self.0
    }
}

/// Gradient vector for one Adam update.
#[derive(Debug, Clone, PartialEq)]
pub struct AdamGradientVector(Vec<f32>);

impl AdamGradientVector {
    pub fn new(values: Vec<f32>) -> CtResult<Self> {
        validate_finite_non_empty("Adam gradient vector", &values)?;
        Ok(Self(values))
    }

    pub fn dimension(&self) -> AdamVectorDimension {
        AdamVectorDimension(self.0.len())
    }

    pub fn as_slice(&self) -> &[f32] {
        &self.0
    }
}

/// Exponential moving average of gradients.
#[derive(Debug, Clone, PartialEq)]
pub struct AdamFirstMoment(Vec<f32>);

impl AdamFirstMoment {
    pub fn zeros(dimension: AdamVectorDimension) -> Self {
        Self(vec![0.0; dimension.value()])
    }

    pub fn as_slice(&self) -> &[f32] {
        &self.0
    }
}

/// Exponential moving average of squared gradients.
#[derive(Debug, Clone, PartialEq)]
pub struct AdamSecondMoment(Vec<f32>);

impl AdamSecondMoment {
    pub fn zeros(dimension: AdamVectorDimension) -> Self {
        Self(vec![0.0; dimension.value()])
    }

    pub fn as_slice(&self) -> &[f32] {
        &self.0
    }
}

/// A decay coefficient in the half-open range `[0, 1)`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct AdamDecayRate(f32);

impl AdamDecayRate {
    pub fn new(value: f32) -> CtResult<Self> {
        if !value.is_finite() || !(0.0..1.0).contains(&value) {
            return Err(CtError::InvalidScalar {
                kind: "Adam decay rate",
                value,
            });
        }

        Ok(Self(value))
    }

    pub fn value(&self) -> f32 {
        self.0
    }
}

/// Small positive stabilizer in the Adam denominator.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct AdamEpsilon(f32);

impl AdamEpsilon {
    pub fn new(value: f32) -> CtResult<Self> {
        if !value.is_finite() || value <= 0.0 {
            return Err(CtError::InvalidScalar {
                kind: "Adam epsilon",
                value,
            });
        }

        Ok(Self(value))
    }

    pub fn value(&self) -> f32 {
        self.0
    }
}

/// Count of Adam updates already applied.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AdamStepCount(usize);

impl AdamStepCount {
    pub fn zero() -> Self {
        Self(0)
    }

    pub fn value(&self) -> usize {
        self.0
    }

    fn next(self) -> Self {
        Self(self.0 + 1)
    }
}

/// Adam hyperparameters that affect one optimizer transition.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct AdamConfig {
    learning_rate: LearningRate,
    beta1: AdamDecayRate,
    beta2: AdamDecayRate,
    epsilon: AdamEpsilon,
}

impl AdamConfig {
    pub fn new(
        learning_rate: LearningRate,
        beta1: AdamDecayRate,
        beta2: AdamDecayRate,
        epsilon: AdamEpsilon,
    ) -> Self {
        Self {
            learning_rate,
            beta1,
            beta2,
            epsilon,
        }
    }

    pub fn learning_rate(&self) -> LearningRate {
        self.learning_rate
    }

    pub fn beta1(&self) -> AdamDecayRate {
        self.beta1
    }

    pub fn beta2(&self) -> AdamDecayRate {
        self.beta2
    }

    pub fn epsilon(&self) -> AdamEpsilon {
        self.epsilon
    }
}

/// Adam optimizer memory that must move with the parameters.
#[derive(Debug, Clone, PartialEq)]
pub struct AdamOptimizerState {
    first_moment: AdamFirstMoment,
    second_moment: AdamSecondMoment,
    step_count: AdamStepCount,
}

impl AdamOptimizerState {
    pub fn zeros(dimension: AdamVectorDimension) -> Self {
        Self {
            first_moment: AdamFirstMoment::zeros(dimension),
            second_moment: AdamSecondMoment::zeros(dimension),
            step_count: AdamStepCount::zero(),
        }
    }

    pub fn first_moment(&self) -> &AdamFirstMoment {
        &self.first_moment
    }

    pub fn second_moment(&self) -> &AdamSecondMoment {
        &self.second_moment
    }

    pub fn step_count(&self) -> AdamStepCount {
        self.step_count
    }
}

/// Complete Adam state at an optimizer boundary.
#[derive(Debug, Clone, PartialEq)]
pub struct AdamModelState {
    parameters: AdamParameterVector,
    optimizer: AdamOptimizerState,
}

impl AdamModelState {
    pub fn new(parameters: AdamParameterVector, optimizer: AdamOptimizerState) -> CtResult<Self> {
        validate_matching_dimension(
            "Adam model state",
            parameters.dimension(),
            optimizer.first_moment.as_slice().len(),
        )?;
        validate_matching_dimension(
            "Adam model state",
            parameters.dimension(),
            optimizer.second_moment.as_slice().len(),
        )?;

        Ok(Self {
            parameters,
            optimizer,
        })
    }

    pub fn from_parameters(parameters: AdamParameterVector) -> Self {
        let optimizer = AdamOptimizerState::zeros(parameters.dimension());
        Self {
            parameters,
            optimizer,
        }
    }

    pub fn parameters(&self) -> &AdamParameterVector {
        &self.parameters
    }

    pub fn optimizer(&self) -> &AdamOptimizerState {
        &self.optimizer
    }
}

/// One Adam optimizer step as an endomorphism on `AdamModelState`.
#[derive(Debug, Clone, PartialEq)]
pub struct AdamTrainStep {
    gradient: AdamGradientVector,
    config: AdamConfig,
}

impl AdamTrainStep {
    pub fn new(gradient: AdamGradientVector, config: AdamConfig) -> Self {
        Self { gradient, config }
    }
}

impl Morphism<AdamModelState, AdamModelState> for AdamTrainStep {
    fn name(&self) -> &'static str {
        "adam_train_step"
    }

    fn apply(&self, state: AdamModelState) -> CtResult<AdamModelState> {
        let dimension = state.parameters.dimension();
        validate_matching_dimension("Adam gradient", dimension, self.gradient.as_slice().len())?;
        validate_matching_dimension(
            "Adam first moment",
            dimension,
            state.optimizer.first_moment.as_slice().len(),
        )?;
        validate_matching_dimension(
            "Adam second moment",
            dimension,
            state.optimizer.second_moment.as_slice().len(),
        )?;

        let next_step = state.optimizer.step_count.next();
        let beta1 = self.config.beta1.value();
        let beta2 = self.config.beta2.value();
        let learning_rate = self.config.learning_rate.value();
        let epsilon = self.config.epsilon.value();
        let bias_correction_1 = 1.0 - beta1.powf(next_step.value() as f32);
        let bias_correction_2 = 1.0 - beta2.powf(next_step.value() as f32);

        let mut next_parameters = Vec::with_capacity(dimension.value());
        let mut next_first_moment = Vec::with_capacity(dimension.value());
        let mut next_second_moment = Vec::with_capacity(dimension.value());

        for (((parameter, gradient), first_moment), second_moment) in state
            .parameters
            .as_slice()
            .iter()
            .copied()
            .zip(self.gradient.as_slice().iter().copied())
            .zip(state.optimizer.first_moment.as_slice().iter().copied())
            .zip(state.optimizer.second_moment.as_slice().iter().copied())
        {
            let updated_first_moment = beta1 * first_moment + (1.0 - beta1) * gradient;
            let updated_second_moment = beta2 * second_moment + (1.0 - beta2) * gradient * gradient;
            let corrected_first_moment = updated_first_moment / bias_correction_1;
            let corrected_second_moment = updated_second_moment / bias_correction_2;
            let updated_parameter = parameter
                - learning_rate * corrected_first_moment
                    / (corrected_second_moment.sqrt() + epsilon);

            next_parameters.push(updated_parameter);
            next_first_moment.push(updated_first_moment);
            next_second_moment.push(updated_second_moment);
        }

        AdamModelState::new(
            AdamParameterVector::new(next_parameters)?,
            AdamOptimizerState {
                first_moment: AdamFirstMoment(next_first_moment),
                second_moment: AdamSecondMoment(next_second_moment),
                step_count: next_step,
            },
        )
    }
}

fn validate_finite_non_empty(kind: &'static str, values: &[f32]) -> CtResult<()> {
    if values.is_empty() {
        return Err(CtError::EmptyInput(kind));
    }

    if let Some(value) = values.iter().copied().find(|value| !value.is_finite()) {
        return Err(CtError::InvalidScalar { kind, value });
    }

    Ok(())
}

fn validate_matching_dimension(
    op: &'static str,
    expected: AdamVectorDimension,
    got: usize,
) -> CtResult<()> {
    if expected.value() != got {
        return Err(CtError::ShapeMismatch {
            op,
            expected: format!("dimension {}", expected.value()),
            got: format!("dimension {got}"),
        });
    }

    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;

    fn standard_config() -> CtResult<AdamConfig> {
        Ok(AdamConfig::new(
            LearningRate::new(0.1)?,
            AdamDecayRate::new(0.9)?,
            AdamDecayRate::new(0.999)?,
            AdamEpsilon::new(1e-8)?,
        ))
    }

    #[test]
    fn adam_first_step_matches_bias_corrected_update() -> CtResult<()> {
        let state = AdamModelState::from_parameters(AdamParameterVector::new(vec![1.0, -1.0])?);
        let step = AdamTrainStep::new(
            AdamGradientVector::new(vec![0.5, -0.25])?,
            standard_config()?,
        );
        let updated = step.apply(state)?;

        assert_eq!(updated.optimizer().step_count().value(), 1);
        assert!((updated.parameters().as_slice()[0] - 0.9).abs() < 1e-5);
        assert!((updated.parameters().as_slice()[1] + 0.9).abs() < 1e-5);
        assert!((updated.optimizer().first_moment().as_slice()[0] - 0.05).abs() < 1e-6);
        assert!((updated.optimizer().second_moment().as_slice()[0] - 0.00025).abs() < 1e-7);

        Ok(())
    }

    #[test]
    fn adam_rejects_bad_decay_rate() {
        assert!(matches!(
            AdamDecayRate::new(1.0),
            Err(CtError::InvalidScalar {
                kind: "Adam decay rate",
                value: 1.0,
            })
        ));
    }

    #[test]
    fn adam_rejects_gradient_dimension_mismatch() -> CtResult<()> {
        let state = AdamModelState::from_parameters(AdamParameterVector::new(vec![1.0, -1.0])?);
        let step = AdamTrainStep::new(AdamGradientVector::new(vec![0.5])?, standard_config()?);

        assert!(matches!(
            step.apply(state),
            Err(CtError::ShapeMismatch {
                op: "Adam gradient",
                ..
            })
        ));

        Ok(())
    }

    #[test]
    fn adam_step_preserves_complete_optimizer_state() -> CtResult<()> {
        let state = AdamModelState::from_parameters(AdamParameterVector::new(vec![1.0])?);
        let step = AdamTrainStep::new(AdamGradientVector::new(vec![0.5])?, standard_config()?);
        let once = step.apply(state)?;
        let twice = step.apply(once)?;

        assert_eq!(twice.optimizer().step_count().value(), 2);
        assert_eq!(twice.parameters().dimension().value(), 1);
        assert_eq!(twice.optimizer().first_moment().as_slice().len(), 1);
        assert_eq!(twice.optimizer().second_moment().as_slice().len(), 1);

        Ok(())
    }
}
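
The expected values in `adam_first_step_matches_bias_corrected_update` follow from the update rule in `AdamTrainStep::apply`. As a rough standalone sketch, the function below repeats that arithmetic for a single coordinate using bare `f32` values. The names `Config`, `Coordinate`, and `adam_step_scalar` are illustrative assumptions, not part of the repository, and `powi` is used here for brevity where the snapshot uses `powf`.

```rust
/// Hypothetical bundle of the four Adam hyperparameters.
#[derive(Clone, Copy)]
struct Config {
    learning_rate: f32,
    beta1: f32,
    beta2: f32,
    epsilon: f32,
}

/// Hypothetical single-coordinate view of parameter plus optimizer memory.
#[derive(Clone, Copy, Debug)]
struct Coordinate {
    parameter: f32,
    first_moment: f32,
    second_moment: f32,
    steps: i32,
}

/// One Adam update for a single coordinate.
fn adam_step_scalar(state: Coordinate, gradient: f32, config: Config) -> Coordinate {
    let steps = state.steps + 1;

    // Exponential moving averages of the gradient and the squared gradient.
    let first_moment = config.beta1 * state.first_moment + (1.0 - config.beta1) * gradient;
    let second_moment =
        config.beta2 * state.second_moment + (1.0 - config.beta2) * gradient * gradient;

    // Bias correction compensates for the zero-initialized moments.
    let corrected_first = first_moment / (1.0 - config.beta1.powi(steps));
    let corrected_second = second_moment / (1.0 - config.beta2.powi(steps));

    let parameter = state.parameter
        - config.learning_rate * corrected_first / (corrected_second.sqrt() + config.epsilon);

    Coordinate {
        parameter,
        first_moment,
        second_moment,
        steps,
    }
}

fn main() {
    let config = Config {
        learning_rate: 0.1,
        beta1: 0.9,
        beta2: 0.999,
        epsilon: 1e-8,
    };
    let start = Coordinate {
        parameter: 1.0,
        first_moment: 0.0,
        second_moment: 0.0,
        steps: 0,
    };

    // m = 0.1 * 0.5 = 0.05 and v = 0.001 * 0.25 = 0.00025.
    // m_hat = 0.05 / 0.1 = 0.5 and v_hat = 0.00025 / 0.001 = 0.25.
    // parameter = 1.0 - 0.1 * 0.5 / (0.5 + 1e-8), roughly 0.9.
    let next = adam_step_scalar(start, 0.5, config);
    println!("{next:?}");
}
```

On the first step the corrected ratio reduces to the gradient divided by its absolute value (up to `epsilon`), so each coordinate moves by roughly the learning rate regardless of gradient magnitude. That is why the test expects `1.0` to become about `0.9` and `-1.0` to become about `-0.9` even though the two gradients differ in size.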

src/demo.rs

use crate::calculus::{LocalGradient, MulOp, Scalar};
use crate::category::{Compose, StepCount, apply_endomorphism_n_times};
use crate::domain::{
    LearningRate, Logits, ModelDimension, Parameters, Product, TokenId, TokenSequence, Vector,
    VocabSize,
};
use crate::error::CtResult;
use crate::ml::{
    CrossEntropy, DatasetWindowing, Embedding, LinearToLogits, Softmax, average_loss,
    composed_prediction_matches_direct_prediction,
};
use crate::structure::{
    Functor, Monoid, OptionFunctor, PipelineTrace, TraceStep, VecFunctor,
    monoid_laws_hold_for_pipeline_trace, naturality_square_holds_for_first_option,
};
use crate::training::TrainStep;
use crate::{Identity, Morphism};

/// Run the full terminal walkthrough used by `cargo run --bin category_ml`.
pub fn run_demo() -> CtResult<()> {
    println!("Category theory concepts implemented in Rust 2024");
    println!("=================================================\n");

    let vocab = ["<pad>", "I", "love", "Rust", "."];
    let raw_text = TokenSequence::from_indices([1, 2, 3, 4, 1, 2, 3, 4])?;

    println!("1. Object examples");
    println!("   TokenId(1) means {:?}\n", vocab[1]);

    println!("2. Dataset morphism: TokenSequence -> TrainingSet");
    let dataset = DatasetWindowing.apply(raw_text)?;
    for example in dataset.examples() {
        println!(
            "   {:?} -> {:?}",
            vocab[example.first().index()],
            vocab[example.second().index()]
        );
    }
    println!();

    println!("3. Identity morphism: id_Vector : Vector -> Vector");
    let v = Vector::new(vec![1.0, 2.0, 3.0]);
    let same_v = Identity::<Vector>::new().apply(v.clone())?;
    println!("   input  = {:?}", v);
    println!("   output = {:?}\n", same_v);

    println!("4. Composition: Softmax after Linear after Embedding");
    let params = Parameters::init(VocabSize::new(vocab.len())?, ModelDimension::new(4)?);
    let embedding = Embedding::from_parameters(&params);
    let linear = LinearToLogits::from_parameters(&params);
    let token_to_logits = Compose::<_, _, Vector>::new(embedding, linear);
    let token_to_distribution = Compose::<_, _, Logits>::new(token_to_logits, Softmax);
    let distribution = token_to_distribution.apply(TokenId::new(1))?;
    println!("   P(next token | 'I') = {:?}\n", distribution.as_slice());

    println!("5. Product object: Prediction x Target -> Loss");
    let loss = CrossEntropy.apply(Product::new(distribution, TokenId::new(2)))?;
    println!("   loss for target 'love' = {:.6}\n", loss.value());

    println!("6. Endomorphism: TrainStep : Parameters -> Parameters");
    let before = average_loss(&params, &dataset)?;
    let train_step = TrainStep::new(dataset.clone(), LearningRate::new(1.0)?);
    let trained_params =
        apply_endomorphism_n_times(&train_step, params.clone(), StepCount::new(80))?;
    let after = average_loss(&trained_params, &dataset)?;
    println!("   average loss before training = {:.6}", before.value());
    println!("   average loss after  training = {:.6}\n", after.value());

    println!("7. Functor: fmap over Vec and Option");
    let xs = vec![1, 2, 3];
    let ys = VecFunctor::fmap(xs, |x| x * x);
    let maybe = OptionFunctor::fmap(Some(7), |x| x + 1);
    println!("   VecFunctor fmap square: {:?}", ys);
    println!("   OptionFunctor fmap +1: {:?}\n", maybe);

    println!("8. Natural transformation: Vec<A> -> Option<A>");
    println!(
        "   naturality square holds: {}\n",
        naturality_square_holds_for_first_option()
    );

    println!("9. Monoid: pipeline traces compose associatively with identity");
    let trace = PipelineTrace::from_steps(vec![TraceStep::new("embedding")])
        .combine(&PipelineTrace::from_steps(vec![TraceStep::new("linear")]))
        .combine(&PipelineTrace::from_steps(vec![TraceStep::new("softmax")]));
    println!("   trace = {:?}", trace.names());
    println!(
        "   monoid laws hold for this trace type: {}\n",
        monoid_laws_hold_for_pipeline_trace()
    );

    println!("10. Commutative diagram check");
    println!(
        "   composed prediction == direct prediction: {}\n",
        composed_prediction_matches_direct_prediction(&params)?
    );

    println!("11. Chain rule / local derivative morphism");
    let mul = MulOp;
    let x = Scalar::new(2.0)?;
    let y = Scalar::new(3.0)?;
    let z = mul.forward(x, y)?;
    let upstream = LocalGradient::new(1.0)?;
    let (dl_dx, dl_dy) = mul.backward(x, y, upstream)?;
    println!("   z = x * y = {}", z.value());
    println!(
        "   if dL/dz = {}, then dL/dx = {}, dL/dy = {}\n",
        upstream.value(),
        dl_dx.value(),
        dl_dy.value()
    );

    println!("Compressed categorical training view:");
    println!("   Dataset x Parameters -> Prediction -> Loss -> Gradients -> Updated Parameters");
    println!("   TrainStep is repeated as Parameters0 -> Parameters1 -> ... -> ParametersN");

    Ok(())
}

src/bin/category_ml.rs

fn main() -> category_theory_transformer_rs::CtResult<()> {
    category_theory_transformer_rs::run_demo()
}

Runnable Examples

examples/01_token_sequence.rs

use category_theory_transformer_rs::{CtResult, Product, TokenId, TokenSequence, TrainingExample};

fn main() -> CtResult<()> {
    let raw_input = "rust makes ai structure visible";
    let token_ids = tokenize_visible_structure(raw_input);
    let sequence = TokenSequence::new(token_ids)?;
    let training_pairs = training_pairs(sequence.as_slice())?;

    println!("Raw input:");
    println!("\"{raw_input}\"");
    println!();
    println!("TokenSequence:");
    println!("{}", format_token_sequence(sequence.as_slice()));
    println!();
    println!("TrainingPairs:");
    for pair in &training_pairs {
        println!("{}", format_training_pair(pair));
    }
    println!();
    println!("Typed transformation:");
    println!("Text -> TokenSequence -> TrainingPairs");
    println!();
    println!("No framework magic.");
    println!("Just explicit structure.");

    Ok(())
}

fn tokenize_visible_structure(input: &str) -> Vec<TokenId> {
    input
        .split_whitespace()
        .map(|word| match word {
            "rust" => TokenId::new(12),
            "makes" => TokenId::new(44),
            "ai" => TokenId::new(7),
            "structure" => TokenId::new(19),
            "visible" => TokenId::new(91),
            _ => TokenId::new(0),
        })
        .collect()
}

fn training_pairs(tokens: &[TokenId]) -> CtResult<Vec<TrainingExample>> {
    TokenSequence::new(tokens.iter().copied())?;

    Ok(tokens
        .windows(2)
        .map(|pair| Product::new(pair[0], pair[1]))
        .collect())
}

fn format_token_sequence(tokens: &[TokenId]) -> String {
    let formatted = tokens
        .iter()
        .map(format_token_id)
        .collect::<Vec<_>>()
        .join(", ");

    format!("[{formatted}]")
}

fn format_training_pair(pair: &TrainingExample) -> String {
    format!(
        "({} -> {})",
        format_token_id(pair.first()),
        format_token_id(pair.second())
    )
}

fn format_token_id(token: &TokenId) -> String {
    format!("TokenId({})", token.index())
}

examples/01_domain_objects.rs

use category_theory_transformer_rs::{
    CtResult, DatasetWindowing, Morphism, TokenId, TokenSequence,
};

fn main() -> CtResult<()> {
    let tokens = TokenSequence::from_indices([1, 2, 3, 4])?;
    let dataset = DatasetWindowing.apply(tokens.clone())?;

    println!("TokenSequence:");
    println!("{}", format_token_sequence(tokens.as_slice()));
    println!();
    println!("TrainingSet:");
    for example in dataset.examples() {
        println!(
            "({} -> {})",
            format_token_id(example.first()),
            format_token_id(example.second())
        );
    }
    println!();
    println!("Typed boundaries:");
    println!("usize -> TokenId");
    println!("Vec<TokenId> -> TokenSequence");
    println!("TokenSequence -> TrainingSet");
    println!("TrainingExample = Product<TokenId, TokenId>");

    Ok(())
}

fn format_token_sequence(tokens: &[TokenId]) -> String {
    let formatted = tokens
        .iter()
        .map(format_token_id)
        .collect::<Vec<_>>()
        .join(", ");

    format!("[{formatted}]")
}

fn format_token_id(token: &TokenId) -> String {
    format!("TokenId({})", token.index())
}

examples/02_morphism_composition.rs

use category_theory_transformer_rs::{
    Compose, CtResult, Distribution, Embedding, LinearToLogits, Logits, ModelDimension, Morphism,
    Parameters, Softmax, TokenId, Vector, VocabSize,
};

fn main() -> CtResult<()> {
    let params = Parameters::init(VocabSize::new(5)?, ModelDimension::new(4)?);
    let token = TokenId::new(1);
    let embedding = Embedding::from_parameters(&params);
    let linear = LinearToLogits::from_parameters(&params);

    let token_to_logits = Compose::<_, _, Vector>::new(embedding.clone(), linear.clone());
    let token_to_distribution = Compose::<_, _, Logits>::new(token_to_logits, Softmax);

    let vector = embedding.apply(token)?;
    let logits = linear.apply(vector.clone())?;
    let distribution = Softmax.apply(logits.clone())?;
    let composed_distribution = token_to_distribution.apply(token)?;

    println!("Input object:");
    println!("TokenId({})", token.index());
    println!();
    println!("Stage outputs:");
    println!("Embedding : TokenId -> Vector");
    println!("{}", format_vector(&vector));
    println!("LinearToLogits : Vector -> Logits");
    println!("{}", format_logits(&logits));
    println!("Softmax : Logits -> Distribution");
    println!("{}", format_distribution(&distribution));
    println!();
    println!("Composed morphism:");
    println!("TokenId -> Distribution");
    println!(
        "next-token probabilities: {}",
        format_values(composed_distribution.as_slice())
    );
    println!();
    println!("Middle objects kept visible:");
    println!("Vector");
    println!("Logits");
    println!();
    println!("Composition rule:");
    println!("first target must equal second source");
    println!("Embedding then LinearToLogits is legal because Vector == Vector");
    println!("Embedding then Softmax is illegal because Vector != Logits");

    Ok(())
}

fn format_vector(vector: &Vector) -> String {
    format!(
        "Vector(dim={}, values={})",
        vector.as_slice().len(),
        format_values(vector.as_slice())
    )
}

fn format_logits(logits: &Logits) -> String {
    format!(
        "Logits(vocab={}, values={})",
        logits.as_slice().len(),
        format_values(logits.as_slice())
    )
}

fn format_distribution(distribution: &Distribution) -> String {
    format!(
        "Distribution(vocab={}, sum={:.6}, values={})",
        distribution.as_slice().len(),
        distribution.as_slice().iter().sum::<f32>(),
        format_values(distribution.as_slice())
    )
}

fn format_values(values: &[f32]) -> String {
    let formatted = values
        .iter()
        .map(|value| format!("{value:.6}"))
        .collect::<Vec<_>>()
        .join(", ");

    format!("[{formatted}]")
}
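
The composition rule printed by this example ("first target must equal second source") can be sketched without the crate. The block below is a minimal standalone illustration, not the repository's implementation: `Arrow`, `Then`, `Embed`, `ToLogits`, `SoftmaxArrow`, and `predict` are hypothetical names, and the toy weights are assumptions chosen only to make the types line up.

```rust
// Hypothetical stand-ins for illustration; none of these names come from the crate.
trait Arrow {
    type Source;
    type Target;
    fn apply(&self, input: Self::Source) -> Self::Target;
}

/// `first` followed by `second`. It is only an `Arrow` when the middle types agree.
struct Then<F, G> {
    first: F,
    second: G,
}

impl<F, G> Arrow for Then<F, G>
where
    F: Arrow,
    G: Arrow<Source = F::Target>, // first target must equal second source
{
    type Source = F::Source;
    type Target = G::Target;

    fn apply(&self, input: F::Source) -> G::Target {
        self.second.apply(self.first.apply(input))
    }
}

struct Vector(Vec<f32>);
struct Logits(Vec<f32>);
struct Distribution(Vec<f32>);

/// Toy embedding: token index -> one-hot vector of width 3 (an assumption).
struct Embed;
impl Arrow for Embed {
    type Source = usize;
    type Target = Vector;
    fn apply(&self, token: usize) -> Vector {
        let mut values = vec![0.0; 3];
        values[token % 3] = 1.0;
        Vector(values)
    }
}

/// Toy linear layer: doubles every coordinate.
struct ToLogits;
impl Arrow for ToLogits {
    type Source = Vector;
    type Target = Logits;
    fn apply(&self, vector: Vector) -> Logits {
        Logits(vector.0.iter().map(|x| 2.0 * x).collect())
    }
}

/// Softmax with the usual max-subtraction for numerical stability.
struct SoftmaxArrow;
impl Arrow for SoftmaxArrow {
    type Source = Logits;
    type Target = Distribution;
    fn apply(&self, logits: Logits) -> Distribution {
        let max = logits.0.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = logits.0.iter().map(|x| (x - max).exp()).collect();
        let total: f32 = exps.iter().sum();
        Distribution(exps.iter().map(|e| e / total).collect())
    }
}

/// usize -> Vector -> Logits -> Distribution, composed once and applied once.
fn predict(token: usize) -> Vec<f32> {
    let token_to_logits = Then { first: Embed, second: ToLogits };
    let token_to_distribution = Then { first: token_to_logits, second: SoftmaxArrow };
    token_to_distribution.apply(token).0
}

fn main() {
    println!("{:?}", predict(1));
    // Calling `.apply` on `Then { first: Embed, second: SoftmaxArrow }` does not
    // compile, because `Vector` is not `Logits`.
}
```

In this sketch the mismatch is rejected by the trait bound on the `impl`, so an illegal pipeline fails at compile time rather than at run time.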

examples/03_training_endomorphism.rs

use category_theory_transformer_rs::{
    CtResult, DatasetWindowing, LearningRate, ModelDimension, Morphism, Parameters, StepCount,
    TokenSequence, TrainStep, VocabSize, apply_endomorphism_n_times, average_loss,
};

fn main() -> CtResult<()> {
    let tokens = TokenSequence::from_indices([1, 2, 3, 4, 1, 2, 3, 4])?;
    let dataset = DatasetWindowing.apply(tokens)?;
    let params = Parameters::init(VocabSize::new(5)?, ModelDimension::new(4)?);

    let before = average_loss(&params, &dataset)?;
    let train_step = TrainStep::new(dataset.clone(), LearningRate::new(1.0)?);
    let trained = apply_endomorphism_n_times(&train_step, params, StepCount::new(80))?;
    let after = average_loss(&trained, &dataset)?;

    println!("loss before: {:.6}", before.value());
    println!("loss after:  {:.6}", after.value());
    println!();
    println!("Typed transformation:");
    println!("TrainStep : Parameters -> Parameters");
    println!("Repeated endomorphism:");
    println!("Parameters0 -> Parameters1 -> ... -> Parameters80");
    println!("Measurement:");
    println!("Parameters x TrainingSet -> Loss");

    Ok(())
}
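
The "repeated endomorphism" line above describes applying one `Parameters -> Parameters` step many times. A minimal standalone sketch of that idea follows. It assumes a single `f64` parameter and a hand-written quadratic loss instead of the crate's `Parameters`, `TrainStep`, and `average_loss`; `apply_n_times`, `loss`, and `train_step` are hypothetical helpers, and unlike the crate's `apply_endomorphism_n_times` this version has no error channel.

```rust
/// Applies an endomorphism `step : T -> T` exactly `n` times.
/// With `n == 0` the start value is returned unchanged (the identity case).
fn apply_n_times<T>(step: impl Fn(T) -> T, start: T, n: usize) -> T {
    (0..n).fold(start, |state, _| step(state))
}

/// Toy loss with its minimum at w = 3 (an assumption for illustration).
fn loss(w: f64) -> f64 {
    (w - 3.0).powi(2)
}

/// One gradient-descent update: w <- w - learning_rate * dL/dw.
fn train_step(w: f64) -> f64 {
    let learning_rate = 0.1;
    let gradient = 2.0 * (w - 3.0);
    w - learning_rate * gradient
}

fn main() {
    let before = 0.0;
    let after = apply_n_times(train_step, before, 80);
    println!("loss before: {:.6}", loss(before));
    println!("loss after:  {:.6}", loss(after));
}
```

Because `train_step` has the same source and target type, it can be folded over any step count without changing the surrounding code.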

examples/04_structure_and_calculus.rs

use category_theory_transformer_rs::{
    CtResult, Functor, LocalGradient, Monoid, MulOp, OptionFunctor, PipelineTrace, Scalar,
    TraceStep, VecFunctor, monoid_laws_hold_for_pipeline_trace,
    naturality_square_holds_for_first_option,
};

fn main() -> CtResult<()> {
    let squared = VecFunctor::fmap(vec![1, 2, 3], |x| x * x);
    let shifted = OptionFunctor::fmap(Some(7), |x| x + 1);

    println!("Vec fmap square: {squared:?}");
    println!("Option fmap +1: {shifted:?}");
    println!(
        "naturality square holds: {}",
        naturality_square_holds_for_first_option()
    );

    let trace = PipelineTrace::from_steps(vec![TraceStep::new("embedding")])
        .combine(&PipelineTrace::from_steps(vec![TraceStep::new("linear")]))
        .combine(&PipelineTrace::from_steps(vec![TraceStep::new("softmax")]));

    println!("trace: {:?}", trace.names());
    println!(
        "monoid laws hold: {}",
        monoid_laws_hold_for_pipeline_trace()
    );

    let mul = MulOp;
    let x = Scalar::new(2.0)?;
    let y = Scalar::new(3.0)?;
    let upstream = LocalGradient::new(1.0)?;
    let (dl_dx, dl_dy) = mul.backward(x, y, upstream)?;

    println!("dL/dx: {}", dl_dx.value());
    println!("dL/dy: {}", dl_dy.value());
    println!();
    println!("Typed transformation:");
    println!("VecFunctor::fmap : Vec<A> x (A -> B) -> Vec<B>");
    println!("OptionFunctor::fmap : Option<A> x (A -> B) -> Option<B>");
    println!("Naturality square:");
    println!("Vec<A> -> Vec<B> -> Option<B>");
    println!("Vec<A> -> Option<A> -> Option<B>");
    println!("Monoid:");
    println!("PipelineTrace x PipelineTrace -> PipelineTrace");
    println!("Chain rule:");
    println!("Scalar x Scalar -> Scalar");
    println!("dL/dz -> (dL/dx, dL/dy)");

    Ok(())
}
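
The two paths of the naturality square and the chain-rule shape printed above can be checked by hand. The sketch below is standalone and does not use the crate: `first_option`, `naturality_holds`, `mul_forward`, and `mul_backward` are hypothetical helpers working on plain `i32` and `f64` values instead of the crate's typed `Scalar` and `LocalGradient`.

```rust
// Hypothetical helpers for illustration; the crate's types are not reproduced here.

/// Component of a natural transformation Vec<A> -> Option<A>: take the first element.
fn first_option<A: Clone>(xs: &[A]) -> Option<A> {
    xs.first().cloned()
}

/// Checks that both paths around the naturality square agree for one input.
fn naturality_holds(xs: Vec<i32>, f: impl Fn(i32) -> i32) -> bool {
    // Path 1: Vec<A> -> Vec<B> -> Option<B>
    let mapped: Vec<i32> = xs.iter().map(|x| f(*x)).collect();
    let map_then_first = first_option(&mapped);
    // Path 2: Vec<A> -> Option<A> -> Option<B>
    let first_then_map = first_option(&xs).map(f);
    map_then_first == first_then_map
}

/// Forward pass of multiplication: z = x * y.
fn mul_forward(x: f64, y: f64) -> f64 {
    x * y
}

/// Backward pass: dL/dx = dL/dz * y and dL/dy = dL/dz * x.
fn mul_backward(x: f64, y: f64, upstream: f64) -> (f64, f64) {
    (upstream * y, upstream * x)
}

fn main() {
    println!("square holds: {}", naturality_holds(vec![1, 2, 3], |x| x * x));
    println!("square holds on empty input: {}", naturality_holds(vec![], |x| x + 1));

    let (dl_dx, dl_dy) = mul_backward(2.0, 3.0, 1.0);
    println!("z = {}", mul_forward(2.0, 3.0));
    println!("dL/dx = {dl_dx}, dL/dy = {dl_dy}");
}
```

For x = 2, y = 3, and an upstream gradient of 1, the local derivatives are dL/dx = 3 and dL/dy = 2, since each input's gradient is the upstream value times the other input.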

examples/05_seven_sketches.rs

use category_theory_transformer_rs::{
    CircuitComponent, CompanyInstance, CtResult, DepartmentId, DesignRequirement, EmployeeId,
    EmployeeRecord, FeasibilityRelation, FeatureCount, ImplementationOffer, InformationLevel,
    LatencyMs, LayerBudget, LocalSafetyCheck, MatrixCols, MatrixRows, OpenCircuit, PortName,
    ResistanceOhms, ResourceAmount, ResourceBundle, SafetyCover, SignalCoefficient, SignalMatrix,
    Throughput, TimeInterval, TimeTick, TruthValue, abstract_to_layer_budget,
    concretize_layer_budget, feature_layer_galois_law_holds, information_order_obeys_preorder_laws,
    resource_tensor_is_monotone,
};

fn main() -> CtResult<()> {
    println!(
        "orders obey preorder laws: {}",
        information_order_obeys_preorder_laws()
    );
    println!(
        "join(feature, decision): {:?}",
        InformationLevel::Feature.join(InformationLevel::Decision)
    );

    let features = FeatureCount::new(9)?;
    let layers = LayerBudget::new(3)?;
    println!(
        "feature/layer Galois law: {}",
        feature_layer_galois_law_holds(features, layers)?
    );
    println!(
        "abstract 9 features to {} layers; concretize 3 layers to {} features",
        abstract_to_layer_budget(features)?.value(),
        concretize_layer_budget(layers).value()
    );

    let encoder = ResourceBundle::new(ResourceAmount::new(2), ResourceAmount::new(8));
    let decoder = ResourceBundle::new(ResourceAmount::new(3), ResourceAmount::new(10));
    println!("combined resource bundle: {:?}", encoder.tensor(&decoder));
    println!(
        "resource tensor monotone: {}",
        resource_tensor_is_monotone()
    );

    let research = DepartmentId::new(1);
    let platform = DepartmentId::new(2);
    let ada = EmployeeId::new(7);
    let instance =
        CompanyInstance::new([research, platform], [EmployeeRecord::new(ada, research)])?;
    println!(
        "employee {:?} belongs to department {:?}",
        ada,
        instance.department_of(ada)
    );

    let requirement = DesignRequirement::new(Throughput::new(100)?, LatencyMs::new(80)?);
    let offer = ImplementationOffer::new(Throughput::new(120)?, LatencyMs::new(50)?);
    println!(
        "co-design offer feasible: {}",
        FeasibilityRelation::relates(requirement, offer)
    );

    let duplicate = SignalMatrix::new(
        MatrixRows::new(2)?,
        MatrixCols::new(1)?,
        vec![
            vec![SignalCoefficient::new(1)],
            vec![SignalCoefficient::new(1)],
        ],
    )?;
    let add_weighted = SignalMatrix::new(
        MatrixRows::new(1)?,
        MatrixCols::new(2)?,
        vec![vec![SignalCoefficient::new(2), SignalCoefficient::new(3)]],
    )?;
    let composed = add_weighted.compose_after(&duplicate)?;
    println!(
        "signal-flow matrix semantics: {:?}",
        composed.coefficients()
    );

    let input = PortName::new("input")?;
    let middle = PortName::new("middle")?;
    let output = PortName::new("output")?;
    let first_circuit = OpenCircuit::new(
        [input],
        [middle],
        [CircuitComponent::resistor(
            input,
            middle,
            ResistanceOhms::new(10)?,
        )],
    )?;
    let second_circuit = OpenCircuit::new(
        [middle],
        [output],
        [CircuitComponent::resistor(
            middle,
            output,
            ResistanceOhms::new(20)?,
        )],
    )?;
    println!(
        "serial circuit component count: {}",
        first_circuit.then(&second_circuit)?.component_count()
    );

    let safety = SafetyCover::new([
        LocalSafetyCheck::new(
            TimeInterval::new(TimeTick::new(0), TimeTick::new(5))?,
            TruthValue::True,
        ),
        LocalSafetyCheck::new(
            TimeInterval::new(TimeTick::new(5), TimeTick::new(10))?,
            TruthValue::True,
        ),
    ])?;
    println!("global behavior truth: {:?}", safety.global_truth());
    println!();
    println!("Typed transformation:");
    println!("InformationLevel <= InformationLevel checks preorder");
    println!("FeatureCount <-> LayerBudget checks Galois law");
    println!("ResourceBundle x ResourceBundle -> ResourceBundle");
    println!("EmployeeRecord -> DepartmentId must resolve in CompanyInstance");
    println!("DesignRequirement x ImplementationOffer -> bool");
    println!("SignalMatrix x SignalMatrix -> SignalMatrix when dimensions match");
    println!("OpenCircuit x OpenCircuit -> OpenCircuit when ports match");
    println!("SafetyCover -> TruthValue");

    Ok(())
}

examples/06_attention_scores.rs

use category_theory_transformer_rs::{
    AttentionHeadOutputs, AttentionMask, AttentionOutput, AttentionOutputProjection,
    AttentionSoftmax, ConcatenateHeads, CtResult, HiddenSequence, HiddenToKey, HiddenToQuery,
    HiddenToValue, KeySequence, LayerNormParameters, LayerNormalization, LearningRate,
    MaskedAttentionScores, MaskedMultiHeadTransformerBlock, Morphism, MultiHeadTransformerBlock,
    PositionWiseFeedForward, PositionalEncoding, Product, QuerySequence, ResidualConnection,
    ScaledDotProductScores, SelfAttentionHead, SingleHeadTransformerBlock,
    TinyTransformerParameters, TokenSequence, TransformerBlockTrainStep,
    TransformerBlockTrainingExample, TransformerBlockTrainingSet, TransformerFeedForwardTrainStep,
    TransformerFeedForwardTrainingExample, TransformerFeedForwardTrainingSet, TransformerReadout,
    TransformerReadoutTrainStep, TransformerReadoutTrainingExample, TransformerReadoutTrainingSet,
    TransformerTrainingState, ValueSequence, WeightedValueMixing, transformer_block_average_loss,
    transformer_feed_forward_average_loss, transformer_readout_average_loss,
};

fn main() -> CtResult<()> {
    let queries = QuerySequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
    let keys = KeySequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]])?;
    let values = ValueSequence::new(vec![vec![1.0, 10.0], vec![2.0, 20.0], vec![3.0, 30.0]])?;
    let mask = AttentionMask::new(vec![vec![true, false, true], vec![true, true, true]])?;

    println!("Q/K/V source diagnostic:");
    println!("query rows own score rows; key/value rows own score columns");
    println!(
        "self-attention shares the hidden source before projection; projected roles stay distinct"
    );
    println!("mask polarity here: true = allowed, false = blocked\n");

    let scores = ScaledDotProductScores.apply(Product::new(queries, keys))?;
    let masked_scores = MaskedAttentionScores.apply(Product::new(scores, mask))?;
    let weights = AttentionSoftmax.apply(masked_scores)?;

    println!(
        "attention shape: {} query positions x {} key positions",
        weights.query_len().value(),
        weights.key_len().value()
    );

    for (query_position, row) in weights.rows().iter().enumerate() {
        println!("query {query_position} attends with {:?}", row.as_slice());
    }

    let output = WeightedValueMixing.apply(Product::new(weights, values))?;

    for (query_position, row) in output.rows().iter().enumerate() {
        println!("query {query_position} output vector {:?}", row.as_slice());
    }

    let second_head = AttentionOutput::new(vec![vec![10.0, 1.0], vec![20.0, 2.0]])?;
    let head_outputs = AttentionHeadOutputs::new(vec![output, second_head])?;
    let multi_head = ConcatenateHeads.apply(head_outputs)?;

    println!(
        "multi-head shape: {} heads x {} features -> model dimension {}",
        multi_head.head_count().value(),
        multi_head.head_dimension().value(),
        multi_head.model_dimension().value()
    );

    for (query_position, row) in multi_head.rows().iter().enumerate() {
        println!("query {query_position} multi-head row {:?}", row.as_slice());
    }

    let output_projection = AttentionOutputProjection::new(
        vec![
            vec![1.0, 0.0],
            vec![0.0, 0.1],
            vec![0.5, 0.0],
            vec![0.0, 1.0],
        ],
        vec![0.0, 0.0],
    )?;
    let projected = output_projection.apply(multi_head)?;

    println!(
        "projected attention shape: {} positions x model dimension {}",
        projected.sequence_len().value(),
        projected.model_dimension().value()
    );

    for (query_position, row) in projected.rows().iter().enumerate() {
        println!(
            "query {query_position} projected attention row {:?}",
            row.as_slice()
        );
    }

    let hidden_input = HiddenSequence::new(vec![vec![0.5, 0.5], vec![1.0, 1.0]])?;
    let residual = ResidualConnection.apply(Product::new(hidden_input, projected))?;

    println!(
        "residual shape: {} positions x model dimension {}",
        residual.sequence_len().value(),
        residual.model_dimension().value()
    );

    for (query_position, row) in residual.rows().iter().enumerate() {
        println!("query {query_position} residual row {:?}", row.as_slice());
    }

    let layer_norm =
        LayerNormalization::new(LayerNormParameters::identity(residual.model_dimension()));
    let normalized = layer_norm.apply(residual)?;

    println!(
        "normalized shape: {} positions x model dimension {}",
        normalized.sequence_len().value(),
        normalized.model_dimension().value()
    );

    for (query_position, row) in normalized.rows().iter().enumerate() {
        println!("query {query_position} normalized row {:?}", row.as_slice());
    }

    let feed_forward = PositionWiseFeedForward::new(
        vec![vec![1.0, -1.0, 0.5], vec![0.0, 1.0, 0.5]],
        vec![0.0, 0.0, 0.0],
        vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![0.5, 0.5]],
        vec![0.0, 0.0],
    )?;
    let fed_forward = feed_forward.apply(normalized)?;

    println!(
        "feed-forward shape: {} positions x model dimension {}",
        fed_forward.sequence_len().value(),
        fed_forward.model_dimension().value()
    );

    for (query_position, row) in fed_forward.rows().iter().enumerate() {
        println!(
            "query {query_position} feed-forward row {:?}",
            row.as_slice()
        );
    }

    let positional_encoding = PositionalEncoding::new(vec![vec![0.1, 0.0], vec![0.0, 0.1]])?;
    let positioned_hidden =
        positional_encoding.apply(HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?)?;

    println!(
        "positioned hidden shape: {} positions x model dimension {}",
        positioned_hidden.sequence_len().value(),
        positioned_hidden.model_dimension().value()
    );

    let block = SingleHeadTransformerBlock::new(
        HiddenToQuery::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?,
        HiddenToKey::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?,
        HiddenToValue::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?,
        AttentionOutputProjection::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?,
        LayerNormalization::new(LayerNormParameters::identity(fed_forward.model_dimension())),
        PositionWiseFeedForward::new(
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![0.0, 0.0],
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![0.0, 0.0],
        )?,
        LayerNormalization::new(LayerNormParameters::identity(fed_forward.model_dimension())),
    )?;
    let block_output = block.apply(positioned_hidden.clone())?;

    println!(
        "single-head block shape: {} positions x model dimension {}",
        block_output.sequence_len().value(),
        block_output.model_dimension().value()
    );

    let multi_head_block = MultiHeadTransformerBlock::new(
        vec![
            SelfAttentionHead::new(
                HiddenToQuery::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
                HiddenToKey::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
                HiddenToValue::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
            )?,
            SelfAttentionHead::new(
                HiddenToQuery::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
                HiddenToKey::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
                HiddenToValue::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
            )?,
        ],
        AttentionOutputProjection::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?,
        LayerNormalization::new(LayerNormParameters::identity(
            block_output.model_dimension(),
        )),
        PositionWiseFeedForward::new(
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![0.0, 0.0],
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![0.0, 0.0],
        )?,
        LayerNormalization::new(LayerNormParameters::identity(
            block_output.model_dimension(),
        )),
    )?;
    let multi_head_output = multi_head_block.apply(positioned_hidden)?;

    println!(
        "multi-head block shape: {} positions x {} heads x value dimension {} -> model dimension {}",
        multi_head_output.sequence_len().value(),
        multi_head_block.head_count().value(),
        multi_head_block.value_dimension().value(),
        multi_head_output.model_dimension().value()
    );

    let masked_multi_head_block = MaskedMultiHeadTransformerBlock::new(
        vec![
            SelfAttentionHead::new(
                HiddenToQuery::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
                HiddenToKey::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
                HiddenToValue::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
            )?,
            SelfAttentionHead::new(
                HiddenToQuery::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
                HiddenToKey::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
                HiddenToValue::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
            )?,
        ],
        AttentionOutputProjection::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?,
        LayerNormalization::new(LayerNormParameters::identity(
            multi_head_output.model_dimension(),
        )),
        PositionWiseFeedForward::new(
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![0.0, 0.0],
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![0.0, 0.0],
        )?,
        LayerNormalization::new(LayerNormParameters::identity(
            multi_head_output.model_dimension(),
        )),
    )?;
    let masked_block_output = masked_multi_head_block.apply(Product::new(
        HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?,
        AttentionMask::new(vec![vec![true, false], vec![true, true]])?,
    ))?;

    println!(
        "masked multi-head block shape: {} positions x model dimension {}",
        masked_block_output.sequence_len().value(),
        masked_block_output.model_dimension().value()
    );

    let transformer_parameters = TinyTransformerParameters::new(
        PositionalEncoding::new(vec![vec![0.1, 0.0], vec![0.0, 0.1]])?,
        masked_multi_head_block,
        TransformerReadout::new(
            vec![vec![1.0, 0.0, 0.5], vec![0.0, 1.0, -0.5]],
            vec![0.0, 0.0, 0.0],
        )?,
    )?;
    let transformer_state =
        TransformerTrainingState::new(transformer_parameters, LearningRate::new(0.1)?);
    let training_set =
        TransformerReadoutTrainingSet::new([TransformerReadoutTrainingExample::new(
            HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?,
            AttentionMask::new(vec![vec![true, false], vec![true, true]])?,
            TokenSequence::from_indices([0, 1])?,
        )?])?;
    let sequence_logits = transformer_state.apply(Product::new(
        HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?,
        AttentionMask::new(vec![vec![true, false], vec![true, true]])?,
    ))?;
    let loss_before = transformer_readout_average_loss(&transformer_state, &training_set)?;
    let train_step = TransformerReadoutTrainStep::new(training_set.clone());
    let next_state = train_step.apply(transformer_state.clone())?;
    let loss_after = transformer_readout_average_loss(&next_state, &training_set)?;

    println!(
        "structured transformer logits shape: {} positions x vocabulary size {}",
        sequence_logits.sequence_len().value(),
        sequence_logits.vocab_size().value()
    );
    println!(
        "training state step: {} -> {}",
        transformer_state.step_count().value(),
        next_state.step_count().value()
    );
    println!(
        "readout loss after one update: {:.6} -> {:.6}",
        loss_before.value(),
        loss_after.value()
    );
    let feed_forward_training_set =
        TransformerFeedForwardTrainingSet::new([TransformerFeedForwardTrainingExample::new(
            HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?,
            HiddenSequence::new(vec![vec![2.0, 0.0], vec![0.0, 2.0]])?,
        )?])?;
    let feed_forward_loss_before =
        transformer_feed_forward_average_loss(&next_state, &feed_forward_training_set)?;
    let feed_forward_train_step =
        TransformerFeedForwardTrainStep::new(feed_forward_training_set.clone());
    let feed_forward_state = feed_forward_train_step.apply(next_state.clone())?;
    let feed_forward_loss_after =
        transformer_feed_forward_average_loss(&feed_forward_state, &feed_forward_training_set)?;

    println!(
        "feed-forward loss after one local update: {:.6} -> {:.6}",
        feed_forward_loss_before.value(),
        feed_forward_loss_after.value()
    );
    let block_training_set =
        TransformerBlockTrainingSet::new([TransformerBlockTrainingExample::new(
            HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?,
            AttentionMask::new(vec![vec![true, false], vec![true, true]])?,
            TokenSequence::from_indices([0, 1])?,
        )?])?;
    let block_loss_before =
        transformer_block_average_loss(&feed_forward_state, &block_training_set)?;
    let block_train_step = TransformerBlockTrainStep::new(block_training_set.clone());
    let block_trained_state = block_train_step.apply(feed_forward_state)?;
    let block_loss_after =
        transformer_block_average_loss(&block_trained_state, &block_training_set)?;

    println!(
        "block loss after one composed update: {:.6} -> {:.6}",
        block_loss_before.value(),
        block_loss_after.value()
    );

    println!();
    println!("Typed transformation:");
    println!("HiddenSequence -> QuerySequence");
    println!("HiddenSequence -> KeySequence");
    println!("HiddenSequence -> ValueSequence");
    println!("QuerySequence x KeySequence -> AttentionScores");
    println!("AttentionScores x AttentionMask -> AttentionScores");
    println!("AttentionScores -> AttentionWeights");
    println!("AttentionWeights x ValueSequence -> AttentionOutput");
    println!("AttentionHeadOutputs -> MultiHeadOutput");
    println!("MultiHeadOutput -> ProjectedAttentionOutput");
    println!("HiddenSequence x ProjectedAttentionOutput -> HiddenSequence");
    println!("LayerNormalization : HiddenSequence -> HiddenSequence");
    println!("PositionWiseFeedForward : HiddenSequence -> HiddenSequence");
    println!("PositionalEncoding : HiddenSequence -> HiddenSequence");
    println!("SingleHeadTransformerBlock : HiddenSequence -> HiddenSequence");
    println!("MultiHeadTransformerBlock : HiddenSequence -> HiddenSequence");
    println!("MaskedMultiHeadTransformerBlock : HiddenSequence x AttentionMask -> HiddenSequence");
    println!("TinyTransformerParameters : HiddenSequence x AttentionMask -> SequenceLogits");
    println!("TransformerTrainingState owns parameters, learning rate, and step count");
    println!("TransformerReadoutTrainStep : TransformerTrainingState -> TransformerTrainingState");
    println!(
        "TransformerFeedForwardTrainStep : TransformerTrainingState -> TransformerTrainingState"
    );
    println!("TransformerBlockTrainStep : TransformerTrainingState -> TransformerTrainingState");

    Ok(())
}

examples/07_transformer_training_state.rs

use category_theory_transformer_rs::{
    AttentionMask, AttentionOutputProjection, CtResult, HiddenSequence, HiddenToKey, HiddenToQuery,
    HiddenToValue, LayerNormParameters, LayerNormalization, LearningRate,
    MaskedMultiHeadTransformerBlock, ModelDimension, Morphism, PositionWiseFeedForward,
    PositionalEncoding, Product, SelfAttentionHead, TinyTransformerParameters, TokenSequence,
    TransformerBlockTrainStep, TransformerBlockTrainingExample, TransformerBlockTrainingSet,
    TransformerFeedForwardTrainStep, TransformerFeedForwardTrainingExample,
    TransformerFeedForwardTrainingSet, TransformerReadout, TransformerReadoutTrainStep,
    TransformerReadoutTrainingExample, TransformerReadoutTrainingSet, TransformerTrainingState,
    transformer_block_average_loss, transformer_feed_forward_average_loss,
    transformer_readout_average_loss,
};

fn main() -> CtResult<()> {
    let hidden = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
    let mask = AttentionMask::new(vec![vec![true, false], vec![true, true]])?;
    let targets = TokenSequence::from_indices([0, 1])?;
    let initial_state = tiny_training_state()?;

    let logits = initial_state.apply(Product::new(hidden.clone(), mask.clone()))?;

    println!(
        "initial state: step={}, learning_rate={:.3}, model_dimension={}, vocab_size={}",
        initial_state.step_count().value(),
        initial_state.learning_rate().value(),
        initial_state.parameters().model_dimension().value(),
        initial_state.parameters().vocab_size().value()
    );
    println!(
        "forward shape: {} positions x vocabulary size {}",
        logits.sequence_len().value(),
        logits.vocab_size().value()
    );

    let readout_set =
        TransformerReadoutTrainingSet::new([TransformerReadoutTrainingExample::new(
            hidden.clone(),
            mask.clone(),
            targets.clone(),
        )?])?;
    let readout_loss_before = transformer_readout_average_loss(&initial_state, &readout_set)?;
    let readout_state =
        TransformerReadoutTrainStep::new(readout_set.clone()).apply(initial_state)?;
    let readout_loss_after = transformer_readout_average_loss(&readout_state, &readout_set)?;

    print_update(
        "readout update",
        0,
        &readout_state,
        readout_loss_before.value(),
        readout_loss_after.value(),
    );

    let feed_forward_set =
        TransformerFeedForwardTrainingSet::new([TransformerFeedForwardTrainingExample::new(
            hidden.clone(),
            HiddenSequence::new(vec![vec![2.0, 0.0], vec![0.0, 2.0]])?,
        )?])?;
    let feed_forward_loss_before =
        transformer_feed_forward_average_loss(&readout_state, &feed_forward_set)?;
    let feed_forward_state =
        TransformerFeedForwardTrainStep::new(feed_forward_set.clone()).apply(readout_state)?;
    let feed_forward_loss_after =
        transformer_feed_forward_average_loss(&feed_forward_state, &feed_forward_set)?;

    print_update(
        "feed-forward update",
        1,
        &feed_forward_state,
        feed_forward_loss_before.value(),
        feed_forward_loss_after.value(),
    );

    let block_set = TransformerBlockTrainingSet::new([TransformerBlockTrainingExample::new(
        hidden, mask, targets,
    )?])?;
    let block_loss_before = transformer_block_average_loss(&feed_forward_state, &block_set)?;
    let block_state =
        TransformerBlockTrainStep::new(block_set.clone()).apply(feed_forward_state)?;
    let block_loss_after = transformer_block_average_loss(&block_state, &block_set)?;

    print_update(
        "composed block update",
        2,
        &block_state,
        block_loss_before.value(),
        block_loss_after.value(),
    );

    println!();
    println!("Typed transformation:");
    println!("TinyTransformerParameters : HiddenSequence x AttentionMask -> SequenceLogits");
    println!("TransformerTrainingState owns parameters, learning rate, and step count");
    println!("TransformerReadoutTrainStep : TransformerTrainingState -> TransformerTrainingState");
    println!(
        "TransformerFeedForwardTrainStep : TransformerTrainingState -> TransformerTrainingState"
    );
    println!("TransformerBlockTrainStep : TransformerTrainingState -> TransformerTrainingState");
    println!("Every update returns a full training state, not loose changed weights.");

    Ok(())
}

fn print_update(
    label: &str,
    previous_step: usize,
    state: &TransformerTrainingState,
    loss_before: f32,
    loss_after: f32,
) {
    println!(
        "{label}: step {} -> {}, loss {:.6} -> {:.6}",
        previous_step,
        state.step_count().value(),
        loss_before,
        loss_after
    );
}

fn tiny_training_state() -> CtResult<TransformerTrainingState> {
    let model_dimension = ModelDimension::new(2)?;
    let block = MaskedMultiHeadTransformerBlock::new(
        vec![
            SelfAttentionHead::new(
                HiddenToQuery::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
                HiddenToKey::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
                HiddenToValue::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
            )?,
            SelfAttentionHead::new(
                HiddenToQuery::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
                HiddenToKey::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
                HiddenToValue::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
            )?,
        ],
        AttentionOutputProjection::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?,
        LayerNormalization::new(LayerNormParameters::identity(model_dimension)),
        PositionWiseFeedForward::new(
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![0.0, 0.0],
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![0.0, 0.0],
        )?,
        LayerNormalization::new(LayerNormParameters::identity(model_dimension)),
    )?;
    let parameters = TinyTransformerParameters::new(
        PositionalEncoding::new(vec![vec![0.1, 0.0], vec![0.0, 0.1]])?,
        block,
        TransformerReadout::new(
            vec![vec![1.0, 0.0, 0.5], vec![0.0, 1.0, -0.5]],
            vec![0.0, 0.0, 0.0],
        )?,
    )?;

    Ok(TransformerTrainingState::new(
        parameters,
        LearningRate::new(0.1)?,
    ))
}

examples/challenge_adam.rs

use category_theory_transformer_rs::{
    AdamConfig, AdamDecayRate, AdamEpsilon, AdamGradientVector, AdamModelState,
    AdamParameterVector, AdamTrainStep, CtResult, LearningRate, Morphism,
};

fn main() -> CtResult<()> {
    let state = AdamModelState::from_parameters(AdamParameterVector::new(vec![1.0, -1.0])?);
    let config = AdamConfig::new(
        LearningRate::new(0.1)?,
        AdamDecayRate::new(0.9)?,
        AdamDecayRate::new(0.999)?,
        AdamEpsilon::new(1e-8)?,
    );
    let step = AdamTrainStep::new(AdamGradientVector::new(vec![0.5, -0.25])?, config);
    let updated = step.apply(state)?;

    println!("Paper-To-Rust: Adam");
    println!("paper idea: optimizer state is part of the update boundary");
    println!("typed shape: AdamModelState -> AdamModelState");
    println!(
        "step count: {}, first moment: {:?}, second moment: {:?}",
        updated.optimizer().step_count().value(),
        updated.optimizer().first_moment().as_slice(),
        updated.optimizer().second_moment().as_slice()
    );
    println!("updated parameters: {:?}", updated.parameters().as_slice());
    println!("share line: Stop summarizing Adam. Compile optimizer state.");

    Ok(())
}

Project Configuration

Cargo.toml

[package]
name = "category-theory-transformer-rs"
version = "0.1.0"
edition = "2024"
description = "A Rust book and lab for understanding tiny ML systems through category-theory structure."
repository = "https://github.com/hghalebi/category_theory_transformer_rs"
documentation = "https://hghalebi.github.io/category_theory_transformer_rs/"
license-file = "LICENSE.md"

[dependencies]

Companion Lesson Notes

These are the shorter markdown notes kept under lessons/.

lessons/README.md

# Category Theory for Tiny ML: Compact Lesson Path

This folder is the compact reading path through the codebase.

For the complete self-contained course, use the chapters in `book/src/`. They
include source snapshots, runnable examples, exercises, a glossary, references,
and the full source appendix.

These compact lessons are kept for quick review. Each lesson is intentionally
short and points to a real Rust file that `cargo` checks.

## The Learning Loop

For each lesson:

1. Read the mental model.
2. Open the named Rust module.
3. Run the named example.
4. Answer the checkpoint before moving on.
5. If the lesson names a boundary, find the test that exercises that boundary.

## Lessons

1. [Map of the Course](00-map.md)
2. [Domain Objects](01-domain-objects.md)
3. [Morphism and Composition](02-morphisms-composition.md)
4. [The Tiny ML Pipeline](03-ml-pipeline.md)
5. [Training as an Endomorphism](04-training-endomorphism.md)
6. [Functors, Naturality, Monoids, and Chain Rule](05-structure-and-calculus.md)
7. [Seven Sketches Through Rust](06-seven-sketches.md)

## Validation

Run the full check:

```bash
cargo test --all-targets --all-features
```

Run one lesson example:

```bash
cargo run --example 03_training_endomorphism
```

lessons/00-map.md

# 00 - Map of the Course

## Goal

Learn one idea:

> Category theory is a language for typed transformations.

In this repo, the transformations are tiny ML operations:

- token to vector
- vector to logits
- logits to probabilities
- prediction plus target to loss
- parameters to better parameters

## The Code Map

- `src/domain.rs`: nouns, also called objects
- `src/category.rs`: arrows, identity, composition, endomorphisms
- `src/ml.rs`: ML arrows
- `src/training.rs`: one training step
- `src/structure.rs`: functor, natural transformation, monoid
- `src/calculus.rs`: local derivative example
- `src/demo.rs`: one guided terminal walkthrough

## First Run

```bash
cargo run --bin category_ml
```

You should see a tiny language-model pipeline and the loss decreasing after
training.

## Checkpoint

Before moving on, say this in your own words:

> A morphism is a typed function. Composition lets small typed functions become
> one larger typed pipeline.

lessons/01-domain-objects.md

# 01 - Domain Objects

## Mental Model

Objects are the nouns of the system.

In normal Rust code, it is tempting to pass `usize`, `Vec<f32>`, and `f32`
everywhere. That is easy at first, but it becomes confusing once values with
different meanings share the same raw type. This repo wraps those values in
small domain types.

## Read This File

Open `src/domain.rs`.

Focus only on these types first:

- `TokenId`
- `TokenSequence`
- `Vector`
- `Logits`
- `Distribution`
- `Loss`
- `TrainingSet`
- `Parameters`

## Run the Example

```bash
cargo run --example 01_domain_objects
```

Expected shape:

```text
training pairs:
1 -> 2
2 -> 3
3 -> 4
```

## Why This Matters

`TokenId` and `Loss` are not interchangeable, even if both could be represented
by numbers. The Rust type system keeps those ideas separate.
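
A minimal sketch of that separation, assuming simplified stand-ins for `TokenId`
and `Loss`. The real definitions live in `src/domain.rs` and may differ; the
validation rule in `Loss::new` and the `lookup_row` helper are hypothetical.

```rust
/// Simplified stand-in for a token identifier (assumption: not the crate's definition).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TokenId(usize);

impl TokenId {
    pub fn new(index: usize) -> Self {
        TokenId(index)
    }

    pub fn index(self) -> usize {
        self.0
    }
}

/// Simplified stand-in for a loss value.
/// Assumption: a loss must be finite and non-negative.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Loss(f32);

impl Loss {
    pub fn new(value: f32) -> Result<Self, String> {
        if value.is_finite() && value >= 0.0 {
            Ok(Loss(value))
        } else {
            Err(format!("loss must be finite and non-negative, got {value}"))
        }
    }

    pub fn value(self) -> f32 {
        self.0
    }
}

/// Hypothetical helper: only a `TokenId` can select a row of the table.
pub fn lookup_row(table: &[Vec<f32>], token: TokenId) -> Option<&Vec<f32>> {
    table.get(token.index())
}
```

Passing a `Loss` where `lookup_row` expects a `TokenId` does not compile, even
though both wrap a single number.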

## Checkpoint

What bug becomes harder when `TokenId` is not just a raw `usize`?

lessons/02-morphisms-composition.md

# 02 - Morphism and Composition

## Mental Model

A morphism is a typed arrow:

```text
Input -> Output
```

In Rust, this repo models that with `Morphism<Input, Output>`.

## Read This File

Open `src/category.rs`.

Read in this order:

1. `Morphism`
2. `Identity`
3. `Compose`
4. `Endomorphism`

## Run the Example

```bash
cargo run --example 02_morphism_composition
```

## What to Notice

The pipeline is built from small arrows:

```text
TokenId -> Vector -> Logits -> Distribution
```

The example prints the middle objects:

```text
Vector = hidden features
Logits = vocabulary scores
Distribution = normalized probabilities
```

`Compose` is the glue. If the middle types do not match, Rust rejects the
program before it runs.
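
The glue can be sketched as follows. This is a simplified stand-in, not the
definitions in `src/category.rs`: the trait here returns plain values instead of
a result type, and `Embed`, `Score`, and the three wrapper types are toy
assumptions.

```rust
use std::marker::PhantomData;

/// Simplified stand-in for a typed arrow `Input -> Output`.
pub trait Morphism<Input, Output> {
    fn apply(&self, input: Input) -> Output;
}

/// Runs `first`, then feeds its output into `second`.
/// `Mid` records the middle type that both arrows must agree on.
pub struct Compose<F, G, Mid> {
    first: F,
    second: G,
    _mid: PhantomData<Mid>,
}

impl<F, G, Mid> Compose<F, G, Mid> {
    pub fn new(first: F, second: G) -> Self {
        Compose { first, second, _mid: PhantomData }
    }
}

impl<A, Mid, C, F, G> Morphism<A, C> for Compose<F, G, Mid>
where
    F: Morphism<A, Mid>,
    G: Morphism<Mid, C>,
{
    fn apply(&self, input: A) -> C {
        self.second.apply(self.first.apply(input))
    }
}

// Toy objects and arrows (assumptions for illustration only).
pub struct TokenId(pub usize);
pub struct Vector(pub Vec<f32>);
pub struct Logits(pub Vec<f32>);

/// Toy arrow `TokenId -> Vector`.
pub struct Embed;
impl Morphism<TokenId, Vector> for Embed {
    fn apply(&self, token: TokenId) -> Vector {
        Vector(vec![token.0 as f32, 1.0])
    }
}

/// Toy arrow `Vector -> Logits` that doubles every feature.
pub struct Score;
impl Morphism<Vector, Logits> for Score {
    fn apply(&self, vector: Vector) -> Logits {
        Logits(vector.0.iter().map(|x| x * 2.0).collect())
    }
}
```

`Compose::<Embed, Score, Vector>::new(Embed, Score)` behaves as one arrow
`TokenId -> Logits`. With the two arrows swapped, the `where` clause cannot be
satisfied, so the call to `apply` fails to compile.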

## Checkpoint

Why is `TokenId -> Vector -> Logits` easier to debug than one giant
`predict(...)` function?

lessons/03-ml-pipeline.md

# 03 - The Tiny ML Pipeline

## Mental Model

The ML pipeline is a sequence of typed transformations:

```text
TokenSequence -> TrainingSet
TokenId -> Vector
Vector -> Logits
Logits -> Distribution
Distribution x TokenId -> Loss
```

## Read This File

Open `src/ml.rs`.

Read only these structs first:

- `DatasetWindowing`
- `Embedding`
- `LinearToLogits`
- `Softmax`
- `CrossEntropy`

## Run the Full Demo

```bash
cargo run --bin category_ml
```

Look at sections 2 through 5 in the output.

## The Key Detail

`Softmax` validates that its output is a real probability distribution.
`CrossEntropy` validates that the target token is inside the distribution.

Errors happen at the boundary where the bad data is first understood.
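
A minimal sketch of both boundary checks, assuming a simplified `Distribution`
type, a `String` error, and free functions. The crate's `Softmax` and
`CrossEntropy` are structs in `src/ml.rs` and may validate differently; the
tolerance `1e-4` is an arbitrary choice for this example.

```rust
/// Simplified stand-in: probabilities that passed validation.
#[derive(Debug, Clone, PartialEq)]
pub struct Distribution(Vec<f32>);

impl Distribution {
    /// Accepts only non-empty, finite, non-negative values that sum to about 1.
    pub fn new(probabilities: Vec<f32>) -> Result<Self, String> {
        let sum: f32 = probabilities.iter().sum();
        let all_valid = probabilities.iter().all(|p| p.is_finite() && *p >= 0.0);
        if probabilities.is_empty() || !all_valid || (sum - 1.0).abs() > 1e-4 {
            return Err(format!("not a probability distribution (sum = {sum})"));
        }
        Ok(Distribution(probabilities))
    }
}

/// Numerically stable softmax; the result is validated before it is returned.
pub fn softmax(logits: &[f32]) -> Result<Distribution, String> {
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|x| (x - max).exp()).collect();
    let total: f32 = exps.iter().sum();
    Distribution::new(exps.iter().map(|e| e / total).collect())
}

/// Negative log-probability of the target; rejects an out-of-range target
/// before any loss value is computed.
pub fn cross_entropy(distribution: &Distribution, target: usize) -> Result<f32, String> {
    let p = distribution.0.get(target).ok_or_else(|| {
        format!(
            "target {target} is outside a distribution of size {}",
            distribution.0.len()
        )
    })?;
    // Clamp away from zero so the logarithm stays finite.
    Ok(-p.max(f32::MIN_POSITIVE).ln())
}
```

With two equal logits, `softmax` yields `[0.5, 0.5]`, so the loss for either
valid target is `ln 2`, while target `2` is rejected with an error.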

## Checkpoint

Where should an out-of-range target token be caught: inside `CrossEntropy`, or
later after loss calculation?

lessons/04-training-endomorphism.md

# 04 - Training as an Endomorphism

## Mental Model

An endomorphism is a morphism whose input and output have the same type:

```text
A -> A
```

Training does exactly that:

```text
Parameters -> Parameters
```

The model changes, but it is still a model.

## Read This File

Open `src/training.rs`.

Read in this order:

1. `TrainStep`
2. `impl Morphism<Parameters, Parameters> for TrainStep`
3. The unit test at the bottom

## Run the Example

```bash
cargo run --example 03_training_endomorphism
```

Expected pattern:

```text
loss before: ...
loss after:  ...
```

The second number should be smaller.

## Why This Is Category-Theoretic

`TrainStep` can be repeated because its input type and output type are both
`Parameters`.

That makes this loop type-correct:

```text
Parameters0 -> Parameters1 -> Parameters2 -> ... -> ParametersN
```
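
The loop above can be sketched with a toy model. The `Parameters` and
`TrainStep` types here are simplified assumptions (one weight, one training
example, squared error), not the definitions in `src/training.rs`, and `train`
is a hypothetical helper.

```rust
/// Simplified stand-in: a single weight `w` for the model `y = w * x`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Parameters {
    pub weight: f32,
}

/// One gradient-descent step on squared error for a single example `(x, y)`.
pub struct TrainStep {
    pub x: f32,
    pub y: f32,
    pub learning_rate: f32,
}

impl TrainStep {
    /// Squared error of the prediction `w * x` against `y`.
    pub fn loss(&self, parameters: Parameters) -> f32 {
        let error = parameters.weight * self.x - self.y;
        error * error
    }

    /// `Parameters -> Parameters`: the output can be fed straight back in.
    pub fn apply(&self, parameters: Parameters) -> Parameters {
        let gradient = 2.0 * (parameters.weight * self.x - self.y) * self.x;
        Parameters {
            weight: parameters.weight - self.learning_rate * gradient,
        }
    }
}

/// Applies the same endomorphism `steps` times.
pub fn train(step: &TrainStep, initial: Parameters, steps: usize) -> Parameters {
    (0..steps).fold(initial, |parameters, _| step.apply(parameters))
}
```

Because `apply` returns the same type it consumes, `train` needs no conversion
between iterations, and zero steps is simply the identity.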

## Checkpoint

Why would training be harder to compose if it returned raw vectors instead of
`Parameters`?

lessons/05-structure-and-calculus.md

# 05 - Functors, Naturality, Monoids, and Chain Rule

## Mental Model

This lesson gives names to patterns you already use:

- `Functor`: map inside a wrapper without changing the wrapper shape
- `NaturalTransformation`: convert one wrapper shape to another in a consistent way
- `Monoid`: an associative combine operation plus an empty value that acts as its identity
- Chain rule: local gradients compose into larger gradients

## Read These Files

Open:

- `src/structure.rs`
- `src/calculus.rs`

## Run the Example

```bash
cargo run --example 04_structure_and_calculus
```

## What to Notice

The examples are tiny on purpose.

`VecFunctor` and `OptionFunctor` are not trying to replace real Rust APIs. They
show the shape of the idea:

```text
keep the container, transform the inside
```
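
As a rough sketch of that shape, here are free functions rather than the crate's
`VecFunctor` and `OptionFunctor` types. All three function names are
hypothetical; `first` is included as a small example of a consistent conversion
from one wrapper to another.

```rust
/// Hypothetical helper: map over a `Vec`, preserving length and order.
pub fn vec_fmap<A, B>(items: Vec<A>, f: impl Fn(A) -> B) -> Vec<B> {
    items.into_iter().map(f).collect()
}

/// Hypothetical helper: map over an `Option`, preserving `Some` or `None`.
pub fn option_fmap<A, B>(item: Option<A>, f: impl Fn(A) -> B) -> Option<B> {
    item.map(f)
}

/// Hypothetical conversion `Vec<A> -> Option<A>` that keeps the first element.
/// It never inspects the elements, so it commutes with mapping.
pub fn first<A>(items: Vec<A>) -> Option<A> {
    items.into_iter().next()
}
```

Two checks are worth trying by hand: mapping the identity function returns the
input unchanged, and `first(vec_fmap(xs, f))` equals `option_fmap(first(xs), f)`
for any `f`. The second equality is the naturality condition in miniature.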

`PipelineTrace` is a monoid because:

- there is an empty trace
- combining with the empty trace leaves the other trace unchanged
- traces can be combined
- grouping does not change the final trace
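
These properties can be sketched with a toy trace type. `Trace`, the `Monoid`
trait shape, and `combine_all` are assumptions for illustration; the crate's
`PipelineTrace` in `src/structure.rs` may be defined differently.

```rust
/// Simplified stand-in for a monoid: an identity value plus an associative combine.
pub trait Monoid: Sized {
    fn empty() -> Self;
    fn combine(self, other: Self) -> Self;
}

/// Hypothetical trace: an ordered list of step names.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Trace(pub Vec<String>);

impl Monoid for Trace {
    fn empty() -> Self {
        Trace(Vec::new())
    }

    /// Concatenation is associative, and the empty list is its identity.
    fn combine(mut self, other: Self) -> Self {
        self.0.extend(other.0);
        self
    }
}

/// Folds any number of values into one, starting from the empty value.
pub fn combine_all<M: Monoid>(items: impl IntoIterator<Item = M>) -> M {
    items.into_iter().fold(M::empty(), M::combine)
}
```

Associativity is what makes `combine_all` safe to split: combining the first
half and the second half separately, then combining the two results, gives the
same trace.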

`MulOp` shows the smallest useful backward pass:

```text
z = x * y
dL/dx = dL/dz * y
dL/dy = dL/dz * x
```
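
The same rule as a minimal sketch. `MulForward`, `MulGradients`, `mul_forward`,
and `mul_backward` are hypothetical names, not the crate's `MulOp` API in
`src/calculus.rs`.

```rust
/// Values cached by the forward pass `z = x * y`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct MulForward {
    pub x: f32,
    pub y: f32,
    pub z: f32,
}

/// Gradients of the loss with respect to both inputs.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct MulGradients {
    pub dx: f32,
    pub dy: f32,
}

/// Forward pass: compute the product and remember the inputs.
pub fn mul_forward(x: f32, y: f32) -> MulForward {
    MulForward { x, y, z: x * y }
}

/// Backward pass: `upstream` is dL/dz, received from whatever consumes `z`.
pub fn mul_backward(cache: &MulForward, upstream: f32) -> MulGradients {
    MulGradients {
        dx: upstream * cache.y,
        dy: upstream * cache.x,
    }
}
```

Chaining two such operations shows the composition. For `L = z * z` with
`z = x * y`, the outer backward pass gives `dL/dz = 2z`, and passing that value
as `upstream` to the inner backward pass gives `dL/dx = 2z * y` and
`dL/dy = 2z * x`.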

## Checkpoint

Why is "local rule plus composition" the core idea behind backpropagation?

lessons/06-seven-sketches.md

# 06 - Seven Sketches Through Rust

## Mental Model

Applied category theory is useful when it helps you model real structure:
orders, resources, database references, feasibility relations, signal flow,
open systems, and local-to-global checks.

This lesson uses one method:

```text
engineering problem
  -> named Rust values
  -> validated construction
  -> relation or composition
  -> law or boundary test
```

## Read This File

Open `src/sketches.rs`.

Read the tests first. They show the contract each small model is supposed to
protect.

Focus on these examples:

- `InformationLevel::can_flow_to`
- `ResourceBundle::tensor`
- `CompanyInstance::department_for`
- `SignalMatrix::compose_after`
- `OpenCircuit::then`
- `SafetyCover::global_truth`

## Run the Example

```bash
cargo run --example 05_seven_sketches
```

Expected shape:

```text
orders obey preorder laws: true
feature/layer Galois law: true
resource tensor monotone: true
co-design offer feasible: true
global behavior truth: True
```

The exact debug formatting matters less than the contracts. Each output line is
a small executable check of a larger applied category theory idea.

## What to Notice

The negative tests are as important as the positive examples.

They show that the code rejects invalid structure:

- a missing database reference,
- mismatched signal-matrix dimensions,
- an open-circuit boundary mismatch.

That is the same discipline as the tiny ML pipeline. A type or constructor is
valuable when it prevents a wrong connection from becoming ordinary runtime
state.
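
One of these rejections, the dimension mismatch, can be sketched with a toy
matrix type. `ToyMatrix` and its methods are assumptions for illustration; the
crate's `SignalMatrix::compose_after` in `src/sketches.rs` may use different
types and error values.

```rust
/// Hypothetical matrix stored as rows; `rows.len()` outputs, `rows[0].len()` inputs.
#[derive(Debug, Clone, PartialEq)]
pub struct ToyMatrix {
    rows: Vec<Vec<f32>>,
}

impl ToyMatrix {
    /// Validated construction: non-empty and rectangular.
    pub fn new(rows: Vec<Vec<f32>>) -> Result<Self, String> {
        let width = rows.first().map(Vec::len).unwrap_or(0);
        if width == 0 || rows.iter().any(|row| row.len() != width) {
            return Err("matrix must be non-empty and rectangular".to_string());
        }
        Ok(ToyMatrix { rows })
    }

    pub fn input_len(&self) -> usize {
        self.rows[0].len()
    }

    pub fn output_len(&self) -> usize {
        self.rows.len()
    }

    /// Returns the matrix for "apply `first`, then `self`" (the product `self * first`).
    /// Fails when the outputs of `first` do not match the inputs of `self`.
    pub fn compose_after(&self, first: &ToyMatrix) -> Result<ToyMatrix, String> {
        if first.output_len() != self.input_len() {
            return Err(format!(
                "cannot compose: first produces {} signals, second expects {}",
                first.output_len(),
                self.input_len()
            ));
        }
        let rows: Vec<Vec<f32>> = self
            .rows
            .iter()
            .map(|row| {
                (0..first.input_len())
                    .map(|col| {
                        row.iter()
                            .zip(&first.rows)
                            .map(|(a, first_row)| a * first_row[col])
                            .sum::<f32>()
                    })
                    .collect()
            })
            .collect();
        Ok(ToyMatrix { rows })
    }
}
```

The mismatch is reported when the composite is built, so no half-valid matrix
ever exists for later code to trip over.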

## Checkpoint

Pick one sketch and answer:

```text
What invalid composition or relationship does this model prevent?
```

A strong answer names the Rust type, the software problem, and the
category-theory shape.