Cover
Category Theory for Tiny ML in Rust
A practical bridge between compositional mathematics, Rust types, and tiny machine-learning systems
Working Draft · Public Feedback Edition
Coauthored by
Hamze Ghalebi
Farzad Jafarranmani
About This Book
Category Theory for Tiny ML in Rust is a working draft that develops a small, explicit machine-learning system through the lens of category theory and Rust.
The book is designed for readers who want to understand machine learning not only as numerical computation, but as a structured pipeline of objects, transformations, composition, and constraints.
Rather than treating category theory as decorative abstraction, this book uses it as an engineering tool:
- domain objects become Rust types,
- morphisms become typed transformations,
- composition becomes executable program structure,
- training becomes repeated transformation of model state,
- and tiny ML systems become a way to make mathematical structure concrete.
This is not a finalized edition. Chapters, examples, terminology, diagrams, code, and references may evolve as the work continues.
The current public edition is still worth publishing: readers can use the existing chapters, run the Rust examples, and send evidence-shaped feedback while later completion passes continue.
The public source repository is available at github.com/hghalebi/category_theory_transformer_rs.
Public Workshop
The first public workshop for this book and Rust lab is hosted through AI Reading Club. It introduces the tiny ML pipeline as typed Rust structure and uses the working draft as the shared study material.
Coauthors
Hamze Ghalebi
Hamze Ghalebi is a Paris-based AI architect, CTO, and software builder associated with Remo Lab. His work focuses on production GenAI, regulated AI systems, auditable AI products, Rust systems, and the transition from AI prototypes to reliable production architectures.
His background includes advanced study at Institut Polytechnique de Paris across statistics, optimization, machine learning, artificial intelligence, distributed systems, cloud computing, and data science.
Hamze brings the engineering and product perspective of the book: how to turn mathematical and machine-learning ideas into understandable, typed, maintainable systems. His current work is especially concerned with AI systems that can be evaluated, monitored, audited, and kept under human accountability in real operational environments.
In this book, his role is to connect tiny ML, Rust implementation, and production-minded software architecture — because apparently making category theory executable was not ambitious enough already.
Farzad Jafarranmani
Farzad Jafarranmani is a researcher and engineer in the Paris area, associated with Huawei and the Lagrange Mathematics and Computing Research Center. His work sits at the intersection of mathematics, computer science, logic, semantics, proof theory, and category theory.
He holds a PhD in Mathematics and Computer Science from Université Paris Cité, where his doctoral work focused on fixpoints of types in linear logic from a Curry–Howard–Lambek perspective. He also studied Mathematics and Computer Science at ENS Paris-Saclay, with work including induction in fibred multicategories and denotational semantics of linear logic with least and greatest fixpoints.
His previous research experience includes postdoctoral work at LIP6, Laboratoire d’Informatique de Sorbonne Université / CNRS, as well as a visiting research position at the University of Cambridge.
Farzad brings the mathematical and theoretical foundation of the book: category theory, denotational semantics, proof theory, type-theoretic structure, and the discipline required to keep abstractions precise instead of merely fashionable.
Public Feedback
Public feedback is welcome while the book is still growing.
Useful feedback includes unclear explanations, broken examples, missing references, awkward terminology, incorrect or overloaded mathematical language, Rust examples that could be clearer or more idiomatic, and places where the connection between Rust, machine learning, and category theory should be made more explicit.
Feedback is easiest to act on when it is opened in the GitHub repository with a specific chapter, command, or source file.
Use the public review path if you want a short route for reviewing the book. If you want the shareable public reviewer call, use Reviewers Needed. If several readers are reviewing together, use the public review sprint to split reports across Rust, ML, category-theory, educator, and beginner perspectives. Use the chapter clarity form when you can name the first unclear sentence, output line, table row, code block, or exercise.
If you are reading the online book and cannot clone the repository right now,
you can still review one public page. Read Welcome, Course Map, Domain Objects,
or Morphism and Composition, then report the first visible sentence, heading,
diagram, table row, code block, or exercise prompt that becomes unclear. Put
"public book path" in the command or page field and include one evidence signal
from the page. A no-clone report is useful only when it names a public page and
one visible signal; broad praise or broad confusion is not enough.
This edition is intentionally public before it is final.
Citation, Reuse, And Support
Short version:
- The public book will always remain open access at hghalebi.github.io/category_theory_transformer_rs.
- The source repository is available at github.com/hghalebi/category_theory_transformer_rs.
- Cite both the public book URL and the source repository URL where reuse is allowed.
- Personal and individual study are allowed with clear citation.
- One reader may study, cite, link, clone, and run the project for personal learning.
- Commercial or organizational group reuse involving more than one person requires written permission before reproducing, adapting, distributing, or teaching material from the book or repository beyond short quotation, linking, review, and individual-study allowances. This includes company workshops, internal team workshops, classes, cohorts, courses, and training programs.
- Company workshops, internal team workshops, paid workshops, commercial training programs, course packs, adapted slide decks, handouts, labs, and workshop packets require written permission when they reuse substantial material from this project.
- When Kindle or hard copy editions are available, buying the Kindle version or a hard copy supports continued public work. Paid editions are support editions, not access gates.
These are custom citation-and-permission terms. Open access means the public book remains free to read online; it does not mean unrestricted commercial redistribution, commercial training use, company workshop use, or organizational group reuse.
The source code is published so readers can inspect, run, test, and contribute to the examples. Substantial reproduced code, prose, exercises, diagrams, or adapted teaching material used in a commercial or organizational group setting follows the same written-permission rule.
Original material from this book may be linked, quoted in short form, reviewed, or discussed for personal, academic, or noncommercial educational use when the source is clearly referenced and both coauthors are credited.
Suggested reference format. Use both the public book URL and the source repository URL:
Ghalebi, H., & Jafarranmani, F.
Category Theory for Tiny ML in Rust.
Open-access working draft.
Book: https://hghalebi.github.io/category_theory_transformer_rs/
Source: https://github.com/hghalebi/category_theory_transformer_rs
Use this citation in permitted reuse contexts such as papers, posts, slides, course notes, workshop pages, repositories, and public references.
Keep both URLs in public references. The book page is the open-access reading surface; the source repository is the executable Rust source for the examples.
Reproducing, adapting, distributing, or teaching material from this book or repository in a commercial or organizational setting that involves more than one person requires written permission from the project owners, except for short quotation, linking, review, and individual-study allowances. This includes company workshops, company-sponsored workshops, internal team workshops, company reading groups based on copied or adapted material, paid training material, course packs, adapted slide decks, handouts, labs, and workshop packets. If the material is reused by or for a company, team, class, workshop, cohort, course, or training program with more than one person, request written permission first. Citation is required where reuse is allowed, but citation alone is not permission and does not replace written permission for commercial or organizational group reuse. See the repository license and reuse terms.
Company workshops, internal team workshops, paid workshops, and workshop packets count as commercial or organizational group reuse when they reproduce, adapt, distribute, or teach substantial material from the book or repository.
Permission requests should start through the source repository, for example by opening an issue or contacting the maintainers from the repository page.
Opening an issue or sending a request does not itself grant permission. Only an explicit written approval from a project owner or maintainer grants permission for the requested commercial or organizational group use.
When Kindle or hard copy editions are available, buying the Kindle version or a hard copy is a way to support continued work on the project. Paid editions are support editions, not access gates. Paid editions will not remove free public access to the online book.
External works cited or referenced by this book remain under their own licenses, terms, and attribution requirements.
Category Theory for Tiny ML in Rust
This is a public working draft. The current edition is published so readers can learn from it now, run the examples, and send precise feedback. It is not the completed textbook yet; later passes will keep revising the chapters from source review, exercises, and direct reader reports.
First Win
From the repository root, run:
cargo run --example 01_token_sequence
That command turns one sentence into typed training material:
Text
-> TokenSequence
-> TrainingPairs
The point of starting there is practical. Before the book asks you to care about category theory, it lets you run a small program and inspect the shape it prints.
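The same two steps can be sketched in a few lines of plain Rust. This is a minimal illustration, not the repository's code. It assumes a whitespace tokenizer over a fixed toy vocabulary. The names tokenize and adjacent_pairs are hypothetical, and plain tuples stand in for the project's Product<TokenId, TokenId>.

```rust
/// Illustrative token identifier (hypothetical; the repository defines its own).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TokenId(pub usize);

/// Text -> TokenSequence: look each whitespace-separated word up in a toy vocabulary.
/// An unknown word is reported as an error instead of silently becoming a number.
pub fn tokenize(text: &str, vocab: &[&str]) -> Result<Vec<TokenId>, String> {
    text.split_whitespace()
        .map(|word| {
            vocab
                .iter()
                .position(|entry| *entry == word)
                .map(TokenId)
                .ok_or_else(|| format!("unknown word: {word}"))
        })
        .collect()
}

/// TokenSequence -> TrainingPairs: each token is paired with the token that follows it.
pub fn adjacent_pairs(tokens: &[TokenId]) -> Vec<(TokenId, TokenId)> {
    tokens.windows(2).map(|w| (w[0], w[1])).collect()
}
```

The guided demo uses the vocabulary ["<pad>", "I", "love", "Rust", "."]. With that vocabulary, "I love Rust ." becomes four token ids and three next-token pairs, and an unknown word is rejected at the boundary.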
First Output Transfer Checklist
Use the first run as a reading test. Do not treat the output as a demo banner. Treat each printed block as evidence for a boundary.
| Printed output | Rust reading | ML reading | Category-theory reading |
|---|---|---|---|
| Raw input | an ordinary &str enters the program | text before tokenization | source object before the first transformation |
| TokenSequence | a validated domain object built from TokenId values | tokenized data the tiny system can inspect | object with named structure |
| TrainingPairs | adjacent Product<TokenId, TokenId> values | input-target examples for next-token learning | product-shaped training examples |
| Typed transformation | the command names the path the code took | text became examples for learning | a short chain of morphisms |
The first win is not that the tokenizer is impressive. It is deliberately tiny. The win is that you can point at the output and say:
this raw value became this domain object,
then this domain object became training examples,
and the transformation path is visible.
That is the reading habit the rest of the book repeats at larger scales.
Source-Backed Reading Contract
This welcome chapter uses sources to keep the first session practical. Each source supports one local rule for how the reader should move from the first command to the rest of the book.
| Source | What the source supports | Local rule in this chapter | Repository evidence |
|---|---|---|---|
| How People Learn II | New learning works better when it connects to prior knowledge, learner context, and transfer. | Start from what many readers already know: a function has an input, an output, and a visible transformation. | fn token_to_position(token_id: usize) -> usize, ## What You Already Know, ## Self-check |
| Rust By Example | Small runnable programs make syntax inspectable before the explanation gets abstract. | Make the first proof a command the reader can run before reading theory. | cargo run --example 01_token_sequence, examples/01_token_sequence.rs |
| Seven Sketches | Applied category theory becomes learnable through concrete compositional examples. | Name objects, morphisms, products, composition, and laws only after the reader can point to a tiny typed pipeline. | Text -> TokenSequence -> TrainingPairs, Distribution x TokenId -> Loss, Parameters -> Parameters |
The transfer pattern is:
run one small example -> name the visible boundary -> reuse the reading habit
For this chapter, the first command is evidence for a small claim:
Text becomes TokenSequence.
TokenSequence becomes TrainingPairs.
The path is visible in terminal output.
It is not evidence that the whole book is easy for every reader yet. That is why the public review path asks for exact evidence signals when a sentence, output line, table row, code block, or exercise breaks the learning path.
After that first example, run the guided walkthrough:
cargo run --bin category_ml
That command walks through the larger pipeline: domain objects, morphisms, composition, prediction, loss, repeated training, functors, monoids, and the small chain-rule example.
The repository is public at github.com/hghalebi/category_theory_transformer_rs. Use it for source files, runnable examples, issues, and contribution work.
Help Improve This Book
If you want to help as a reader, use the public review path. If you want the shareable public call for the five reviewer perspectives, use Reviewers Needed. If you are reviewing with a group, use the public review sprint to collect one report from each reader perspective.
The most useful report is small:
Command or page tried:
Evidence signal:
Last clear idea:
First unclear sentence, output line, table row, code block, or exercise:
What would have helped:
If you are reading the online book and cannot clone the repository right now,
use the same shape. Put "public book path" in "Command or page tried", name the
page you read, and quote or describe one visible evidence signal: a sentence,
heading, diagram, table row, code block, or exercise prompt.
Open the chapter clarity form when the learning path breaks. One exact blocked step is more useful than a broad review.
What This Book Is About
Most machine-learning education starts with frameworks.
Frameworks are useful. They let us train real models quickly. But they also hide the small structure underneath: the types of values moving through the system, the transformations between those values, the loss that measures error, and the update step that changes the model.
This book takes the opposite path.
It builds a tiny learning system slowly enough that every important shape can be read in Rust:
Text
-> TokenSequence
-> TrainingSet
-> Prediction
-> Loss
-> Updated Parameters
The goal is to make the hidden structure easier to see.
Executable structure, not AI magic.
The Central Thesis
This book is built around one claim:
A useful ML system is a chain of typed transformations.
Rust gives those transformations compile-checked boundaries. Category theory gives names to recurring shapes such as objects, morphisms, products, composition, endomorphisms, functors, and monoids. Tiny ML keeps the system small enough to inspect completely.
The book uses all three, but in this order:
intuition
-> small Rust example
-> ML meaning
-> category-theory name
-> runnable exercise
The category-theory words should arrive after the reader has seen the shape in code.
What You Already Know
If you have written a Rust function, you already know the first shape. A function has an input type, an output type, and a body that explains how to move from one to the other.
Worked Example: Naming One Raw Value
Start with a deliberately under-typed version:
#![allow(unused)]
fn main() {
    fn token_to_position(token_id: usize) -> usize {
        token_id + 100
    }
    assert_eq!(token_to_position(3), 103);
}
This is a transformation from one type to another:
usize -> usize
The problem is that both sides are too vague. A raw usize might mean a token
index, a vocabulary size, a vector dimension, a training step, or a row number.
Those are different concepts, even if the machine representation is the same.
The book’s first move is to give those concepts names.
pub struct TokenId(usize);
pub struct VocabSize(usize);
pub struct ModelDimension(usize);
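The sketch below shows what the names buy at a call site. It assumes simple tuple newtypes; the repository's definitions may add derives, constructors, or validation, and is_in_vocab is a hypothetical helper. The function accepts only a TokenId followed by a VocabSize. The compiler rejects swapped arguments and rejects a bare usize.

```rust
/// Hypothetical newtypes mirroring the names above.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TokenId(pub usize);

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct VocabSize(pub usize);

/// A token id is valid for a vocabulary when it is strictly smaller than the size.
pub fn is_in_vocab(token: TokenId, vocab: VocabSize) -> bool {
    token.0 < vocab.0
}

// is_in_vocab(VocabSize(5), TokenId(3)) would not compile: the arguments are swapped.
// is_in_vocab(3, 5) would not compile either: a bare usize is not a TokenId.
```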
Now the reader can ask better questions:
Can this token be embedded?
Does this vector have the expected dimension?
Is this probability distribution valid?
Can this loss be accumulated?
Can this training update be repeated?
That is where the Rust type system starts to become part of the explanation.
Self-check
Before continuing, explain what changed when token_id: usize became
TokenId. Did the machine representation change, or did the program gain a
clearer boundary?
The Three Readings
Every important idea in the book is read three ways.
Rust Reading
The Rust reading asks:
What type is this?
What function or trait connects it to another type?
What invariant does the constructor protect?
What error can happen at the boundary?
For example:
pub struct TokenSequence(Vec<TokenId>);
This is not only “a struct containing a vector.” In the real source, it is a
controlled domain object. Other code can use a TokenSequence, but it cannot
freely reach inside and mutate the raw representation.
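The sketch below shows a controlled domain object. Its one invariant, that a sequence must not be empty, is an assumption chosen only for illustration; the repository's real constructor and rules may differ. The field is private to the module, so outside code can read the tokens but cannot replace or mutate the inner Vec.

```rust
mod domain {
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub struct TokenId(pub usize);

    /// The tuple field is private: outside this module the Vec can only be
    /// created through `new` and read through `as_slice`.
    #[derive(Debug, Clone, PartialEq, Eq)]
    pub struct TokenSequence(Vec<TokenId>);

    impl TokenSequence {
        /// Hypothetical invariant for this sketch: at least one token.
        pub fn new(tokens: Vec<TokenId>) -> Result<Self, String> {
            if tokens.is_empty() {
                Err("a token sequence needs at least one token".to_string())
            } else {
                Ok(Self(tokens))
            }
        }

        /// Read-only view; callers cannot push, pop, or overwrite tokens.
        pub fn as_slice(&self) -> &[TokenId] {
            &self.0
        }
    }
}

use domain::{TokenId, TokenSequence};
```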
ML Reading
The ML reading asks:
What stage of the learning pipeline is this?
Is it data, prediction, loss, or an update?
What would a larger framework usually hide here?
A token sequence is not the model yet. It is data after tokenization and before training pairs. A distribution is not just a vector of floats. It is a vector of non-negative probabilities that should sum to one.
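The sketch below shows that validation. It assumes an f32 representation and a summation tolerance of 1e-5, picked for illustration. The repository's Distribution may use different checks or a typed error instead of String.

```rust
/// A probability distribution that can only be built through validation.
#[derive(Debug, Clone, PartialEq)]
pub struct Distribution(Vec<f32>);

impl Distribution {
    pub fn new(probabilities: Vec<f32>) -> Result<Self, String> {
        if probabilities.is_empty() {
            return Err("distribution must not be empty".to_string());
        }
        // Every entry must be a finite, non-negative number.
        if probabilities.iter().any(|p| !p.is_finite() || *p < 0.0) {
            return Err("probabilities must be finite and non-negative".to_string());
        }
        // The entries must sum to one, up to an assumed tolerance of 1e-5.
        let total: f32 = probabilities.iter().sum();
        if (total - 1.0).abs() > 1e-5 {
            return Err(format!("probabilities sum to {total}, not 1"));
        }
        Ok(Self(probabilities))
    }

    pub fn probabilities(&self) -> &[f32] {
        &self.0
    }
}
```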
Category-Theory Reading
The category-theory reading asks:
What object is this?
What morphism starts here or ends here?
Can two transformations compose?
Is this update an endomorphism?
Which law is the code trying to make visible?
The point is not to make the code sound more abstract. The point is to name the same shape that the Rust and ML readings already revealed.
The Main Picture
The tiny model is organized around this chain:
TokenSequence -> TrainingSet
TokenId -> Vector
Vector -> Logits
Logits -> Distribution
Distribution x TokenId -> Loss
Parameters -> Parameters
Read it left to right.
The first line prepares examples.
The middle lines make a prediction and measure error.
The last line updates the model.
The Rust reading is:
types + constructors + traits + errors + tests
The ML reading is:
data + scores + probabilities + loss + training
The category-theory reading is:
objects + morphisms + products + composition + laws
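The sketch below covers two arrows of this chain, Logits -> Distribution and Distribution x TokenId -> Loss, using plain slices. It is an illustration, not the repository's Softmax or CrossEntropy. It uses f32 and subtracts the maximum logit before exponentiating for numerical stability. To stay short it takes the target as a bare index, where the typed version would take a TokenId.

```rust
/// Logits -> Distribution: softmax, shifted by the max logit so exp cannot overflow.
pub fn softmax(logits: &[f32]) -> Vec<f32> {
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&x| (x - max).exp()).collect();
    let total: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / total).collect()
}

/// Distribution x target -> Loss: negative log of the probability given to the target.
/// Returns None when the target index is outside the distribution.
pub fn cross_entropy(distribution: &[f32], target: usize) -> Option<f32> {
    distribution.get(target).map(|&p| -p.ln())
}
```

For example, two equal logits give the distribution [0.5, 0.5]. The cross-entropy of that distribution against either target is ln 2, about 0.693.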
Learning Contract
Use the same loop in every chapter. Start with the practical problem, read the smallest example, and then inspect the relevant source snapshot. Translate the Rust type or function into plain English before connecting it to the ML pipeline. Only after the code is concrete should the chapter name the category-theory shape. Then run the example and answer the retrieval questions without looking back.
The chapters are deliberately repetitive in structure. That repetition is part of the learning design. The pattern should become familiar:
raw representation
-> validated domain object
-> typed transformation
-> composed pipeline
-> checked law
What This Book Is Not
This is not a production ML framework.
This is not a performance-first Rust implementation.
This is not category theory as decoration.
This is not a promise that every advanced mathematical idea has been fully formalized in the code.
The examples are intentionally small. They are designed to make structure visible before speed, scale, or completeness enter the conversation.
Reading Path
Read the chapters in order on the first pass.
The Course Map gives the whole pipeline shape.
Domain Objects names the typed nouns.
Morphism and Composition names the typed arrows between them.
The Tiny ML Pipeline turns those arrows into prediction and loss.
Training as an Endomorphism shows why one optimizer step has the repeatable shape:
Parameters -> Parameters
After the core pipeline, Functors, Naturality, Monoids, and Chain Rule introduces reusable structure, and Seven Sketches Through Rust widens the method to applied category theory.
Use the Exercises for practice, the Glossary for terms, the References for chapter-specific sources, and the Transformer Roadmap for the path toward attention.
Live Study
The first public workshop for the project is available through Luma registration.
The workshop is a guided study path through the same tiny pipeline. It is useful if you want to see the code, diagrams, and vocabulary connected live.
The public session plan is available in the repository: First online workshop curriculum.
What To Remember
The central discipline is:
Do not let raw values travel farther than they should.
A raw usize becomes TokenId.
A raw Vec<TokenId> becomes TokenSequence.
A raw Vec<f32> becomes Distribution only after probability validation.
A raw optimizer update becomes TrainStep, a typed endomorphism:
Parameters -> Parameters
The result is a small codebase where every concept has a name, every boundary has a type, and every composition has to make sense before Rust lets it run.
Where This Leaves Us
This welcome chapter sets the reading contract. You will see the same idea through Rust syntax, tiny ML behavior, and category-theory shape. The next chapter, Course Map, gives the full map before the book starts reading individual source files.
Practice After This Chapter
Do one small check before moving on: run cargo run --example 01_token_sequence
and explain one output line using the three-lens shape from this chapter. If
you want a written prompt, use the first-output transfer checklist above and
Beginner Exercise 3 in Exercises.
Retrieval Practice
Recall
What is the central pipeline shape this book keeps returning to?
Explain
Why does the book connect every concept to Rust syntax, ML meaning, and category-theory shape?
Apply
Pick one raw value from the pipeline, such as a token index or probability vector. Give it a domain-type name and explain what confusion the name prevents.
Course Map
The problem this chapter solves is:
Before reading individual source files, you need one map that connects the tiny ML pipeline, the Rust modules, and the category-theory vocabulary.
The repository is intentionally small, but it still has layers. One layer names the values. Another layer names transformations between values. A third layer uses those transformations to make predictions, measure loss, and update model parameters.
This chapter gives you the whole map before the book zooms in.
Chapter Outcomes
By the end of this chapter, you should be able to:
- place each printed line from cargo run --bin category_ml into domain value, typed transformation, or training update,
- explain how src/domain.rs, src/category.rs, src/ml.rs, and src/training.rs divide responsibility,
- translate the book’s first pipeline into objects, morphisms, product input, loss, and endomorphism language.
Choose Your Path
Use the book-first path if you want the concepts introduced in order:
Welcome
-> Course Map
-> Domain Objects
-> Morphism and Composition
-> Tiny ML Pipeline
-> Training as an Endomorphism
Use the code-first path if you learn faster by running something first:
cargo run --example 01_token_sequence
cargo run --bin category_ml
Then come back to this map and place each printed line in one of three locations:
domain value
typed transformation
training update
Both paths are valid. The book-first path reduces surprise. The code-first path reduces abstraction anxiety. The important thing is not to open every file at once. Start with one path, run one command, and attach each new word to one visible Rust shape.
What You Already Know
If you read a program from top to bottom, you already know how to follow a flow. If you read a Rust function signature, you already know that a step has an input type and an output type. If you have seen any ML pipeline, you already know that raw data eventually becomes predictions, loss, and updates.
The map in this chapter puts those familiar habits together:
value
-> transformation
-> composed transformations
-> measured error
-> repeated update
The category-theory vocabulary is not a separate layer pasted on top. It names shapes that are already present in the Rust and ML readings.
Worked Example: From One Function To A Pipeline
Start with one ordinary function:
#![allow(unused)]
fn main() {
    fn token_to_vector_id(token_id: usize) -> usize {
        token_id + 100
    }
    assert_eq!(token_to_vector_id(7), 107);
}
This has the shape:
usize -> usize
That is a transformation, but it is not yet a good teaching boundary. Both sides use the same raw type, so the signature does not tell us whether the number is a token, a vector row, a dimension, or something else.
The book replaces that vague movement with named stages:
TokenId -> Vector
Then it composes more stages:
TokenId -> Vector -> Logits -> Distribution
That is the basic move for the whole book. Start with a familiar function, give the meaningful values names, then ask which typed transformations can compose safely.
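The first named stage, TokenId -> Vector, can be sketched as an embedding lookup. The sketch assumes a table with one Vector row per token id. Embedding, Vector, and lookup here are illustrative and may not match src/ml.rs. An out-of-range id becomes an error instead of a panic.

```rust
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TokenId(pub usize);

#[derive(Debug, Clone, PartialEq)]
pub struct Vector(pub Vec<f32>);

/// Hypothetical embedding table: row i is the vector for TokenId(i).
pub struct Embedding {
    rows: Vec<Vector>,
}

impl Embedding {
    pub fn new(rows: Vec<Vector>) -> Self {
        Self { rows }
    }

    /// TokenId -> Vector, returning an error instead of panicking when the id is out of range.
    pub fn lookup(&self, token: TokenId) -> Result<Vector, String> {
        self.rows
            .get(token.0)
            .cloned()
            .ok_or_else(|| format!("token {} is outside the embedding table", token.0))
    }
}
```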
Self-check
Before continuing, explain why TokenId -> Vector carries more information
than usize -> Vec<f32>. A strong answer should mention both reader clarity
and compiler-checked boundaries.
The Whole Pipeline
The first mental model is:
Text -> Tokens -> TrainingPairs -> ModelState -> Prediction -> Loss -> Updated ModelState
Read it as one question:
What object do we have now, and what typed transformation moves us to the next object?
The same diagram with the first concrete Rust names is:
Text
|
| tokenize
v
TokenSequence
|
| adjacent pairs
v
TrainingSet
|
| train with current Parameters
v
Parameters
|
| predict
v
Distribution
|
| compare with target token
v
Loss
|
| optimizer step
v
Parameters
The public names and Rust names are close, but not identical:
| Reader-facing name | Rust name in this project | Why the distinction matters |
|---|---|---|
| Tokens | TokenSequence | the code preserves order, not only a bag of token IDs |
| TrainingPairs | TrainingSet of Product<TokenId, TokenId> | each example has an input token and the next-token target |
| ModelState | Parameters | this tiny model’s trainable state is its embedding and projection parameters |
| Updated ModelState | updated Parameters | training is a state update, not a new kind of object |
The central book pipeline is:
Text
-> TokenSequence
-> TrainingSet
-> Prediction
-> Loss
-> Updated Parameters
The concrete Rust shape is slightly more detailed:
TokenSequence -> TrainingSet
TokenId -> Vector
Vector -> Logits
Logits -> Distribution
Distribution x TokenId -> Loss
Parameters -> Parameters
The same map can be drawn as a learner-facing flow:
raw text
|
v
TokenSequence --DatasetWindowing--> TrainingSet
|
v
TokenId --Embedding--> Vector --LinearToLogits--> Logits --Softmax--> Distribution
|
v
Product<Distribution, TokenId>
|
v
CrossEntropy -> Loss
Parameters --TrainStep--> Updated Parameters
The same course map as a compact rendered math view:
\[
\begin{array}{ccccccccc}
\mathrm{Text} & \to & \mathrm{TokenSequence} & \xrightarrow{\mathrm{DatasetWindowing}} & \mathrm{TrainingSet} & \leadsto & \mathrm{Product}\langle\mathrm{Distribution},\mathrm{TokenId}\rangle & \xrightarrow{\mathrm{CrossEntropy}} & \mathrm{Loss} \\
&&&&&& \uparrow\, \mathrm{Softmax} \circ \mathrm{LinearToLogits} \circ \mathrm{Embedding} && \\
&&&&&& \mathrm{TokenId} && \\
\mathrm{Parameters} & \xrightarrow{\mathrm{TrainStep}} & \mathrm{UpdatedParameters} &&&&&&
\end{array}
\]
Read the top path as prediction and evaluation. Read the bottom path as
training state. The two meet because TrainStep uses the training set, current
parameters, prediction path, and loss to produce updated parameters.
If the text diagram is easier to read first, use it first. If the rendered view is easier to track, redraw it and label the Rust object behind every mathematical name. Both views are teaching aids; the proof that the map is real is still the code and the commands.
Read that map in three ways.
The Rust reading is about named types, trait implementations, constructors, fallible boundaries, and tests. The ML reading is about data preparation, embeddings, scores, probabilities, error measurement, and parameter updates. The category-theory reading is about objects, morphisms, products, composition, endomorphisms, and laws.
These are not three different books. They are three readings of the same small program.
Module Map
The source tree follows the learning path. Each file owns one part of the conceptual load, so the reader does not have to learn every abstraction at the same time.
| File | What it teaches | Main shape |
|---|---|---|
| src/domain.rs | Meaningful values | TokenId, Vector, Distribution, Parameters |
| src/category.rs | Typed arrows | Morphism<Input, Output> |
| src/ml.rs | Concrete ML transformations | TokenId -> Vector -> Logits -> Distribution |
| src/training.rs | Repeated updates | Parameters -> Parameters |
| src/structure.rs | Reusable structure | functor, natural transformation, monoid |
| src/calculus.rs | Local derivative flow | chain rule for z = x * y |
| src/sketches.rs | Applied category-theory sketches | typed models plus law checks |
| src/demo.rs | Full guided walkthrough | executable course outline |
This table is not something to memorize. Use it as a navigation tool. When a later chapter names a concept, you should be able to place it in one file and one row of the pipeline.
The Library Surface
The library root collects the modules and re-exports the teaching types. That is why the examples can import clear names instead of reaching through deep paths.
The important design is:
domain nouns
category arrows
ML arrows
training update
structure patterns
calculus rule
applied sketches
demo
That order is also the reading path. The book first asks “what are the values?” Then it asks “what transformations are allowed?” Only after those two questions does it build prediction, loss, and training.
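The sketch below fits the re-export pattern in a single file, with hypothetical module contents. Each module owns its part of the vocabulary, and pub use at the root lets callers write TokenId rather than domain::TokenId. The real library root declares its modules as separate files and re-exports its own set of names.

```rust
pub mod domain {
    /// A meaningful noun lives in the domain module.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub struct TokenId(pub usize);
}

pub mod ml {
    use super::domain::TokenId;

    /// A hypothetical arrow module imports the nouns it transforms.
    pub fn next_token_guess(token: TokenId) -> TokenId {
        TokenId(token.0 + 1)
    }
}

// Root-level re-exports: short, stable names for examples and readers.
pub use domain::TokenId;
pub use ml::next_token_guess;
```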
How The Files Fit Together
src/domain.rs defines the nouns. A TokenId is not a VocabSize, a
Distribution is not a raw vector, and a LearningRate is not any other
floating-point number. This file protects meaning at the boundary where raw
machine values enter the tutorial.
src/category.rs defines the arrows. The central trait is:
pub trait Morphism<Input, Output> {
    fn name(&self) -> &'static str;
    fn apply(&self, input: Input) -> CtResult<Output>;
}
That trait says: a morphism is something that transforms an Input into an
Output, and the transformation may fail with a typed course error.
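The sketch below puts that trait to use. CtResult is replaced by a stand-in alias over String, because the crate's error type is not shown here. Then is a hypothetical composition combinator, not the repository's Compose. PhantomData records the intermediate type, so the compiler only accepts a first arrow A -> Middle followed by a second arrow Middle -> C. Double and ToLabel are toy arrows over integers.

```rust
use std::marker::PhantomData;

/// Stand-in for the crate's alias; the real CtResult carries a typed course error.
pub type CtResult<T> = Result<T, String>;

/// Same shape as the trait in src/category.rs.
pub trait Morphism<Input, Output> {
    fn name(&self) -> &'static str;
    fn apply(&self, input: Input) -> CtResult<Output>;
}

/// Hypothetical composition: run `first`, then feed its output into `second`.
pub struct Then<F, G, Middle> {
    first: F,
    second: G,
    middle: PhantomData<Middle>,
}

/// Build a composite arrow; `Middle` is the type the two arrows must share.
pub fn then<F, G, Middle>(first: F, second: G) -> Then<F, G, Middle> {
    Then { first, second, middle: PhantomData }
}

impl<A, B, C, F, G> Morphism<A, C> for Then<F, G, B>
where
    F: Morphism<A, B>,
    G: Morphism<B, C>,
{
    fn name(&self) -> &'static str {
        "then"
    }

    fn apply(&self, input: A) -> CtResult<C> {
        // `?` stops the pipeline if the first arrow fails; otherwise its output feeds the second.
        let middle = self.first.apply(input)?;
        self.second.apply(middle)
    }
}

/// Toy arrow i32 -> i32.
pub struct Double;

impl Morphism<i32, i32> for Double {
    fn name(&self) -> &'static str {
        "double"
    }

    fn apply(&self, input: i32) -> CtResult<i32> {
        Ok(input * 2)
    }
}

/// Toy arrow i32 -> String that rejects negative input.
pub struct ToLabel;

impl Morphism<i32, String> for ToLabel {
    fn name(&self) -> &'static str {
        "to_label"
    }

    fn apply(&self, input: i32) -> CtResult<String> {
        if input < 0 {
            Err(format!("negative input: {input}"))
        } else {
            Ok(format!("value {input}"))
        }
    }
}
```

Here then::<_, _, i32>(Double, ToLabel) maps 3 to "value 6". An input of -1 doubles to -2, which ToLabel rejects, so the composite returns that error.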
src/ml.rs makes the arrows concrete. It implements dataset windowing,
embedding lookup, linear projection, softmax, and cross entropy. This is where
the abstract phrase “typed transformation” becomes a tiny learning pipeline.
src/training.rs defines the update step:
TrainStep : Parameters -> Parameters
Because the output type is the same as the input type, the update can be repeated:
Parameters0 -> Parameters1 -> Parameters2 -> ... -> ParametersN
That is why the training chapter teaches one optimizer step as an endomorphism. It is a transformation from a type back to itself.
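The sketch below shows the endomorphism shape with a one-weight Parameters and the toy loss (w - 3)^2, whose gradient is 2(w - 3). The repository's parameters, its loss, and its apply_endomorphism_n_times helper are richer and may have different signatures. train_step returns the same type it takes, so repeat_step can feed each output straight back in.

```rust
/// Toy parameters with a single weight (hypothetical; the real Parameters
/// holds embedding and projection values).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Parameters {
    pub weight: f32,
}

/// One gradient-descent step on the toy loss (weight - 3)^2.
/// Input and output are both Parameters: an endomorphism.
pub fn train_step(params: Parameters, learning_rate: f32) -> Parameters {
    let gradient = 2.0 * (params.weight - 3.0);
    Parameters { weight: params.weight - learning_rate * gradient }
}

/// Parameters0 -> Parameters1 -> ... -> ParametersN by repeating the same arrow.
pub fn repeat_step(mut params: Parameters, learning_rate: f32, steps: usize) -> Parameters {
    for _ in 0..steps {
        params = train_step(params, learning_rate);
    }
    params
}
```

With learning rate 0.25, each step halves the distance to 3. Starting from weight 0.0, the sequence is 1.5, 2.25, 2.625, and so on.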
src/structure.rs gives names to reusable patterns that appear after the
pipeline works: mapping inside a container, converting one wrapper shape to
another, and combining traces with an identity value. These ideas are useful
because real systems accumulate logs, batches, optional results, gradients, and
workflow traces.
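The two structure patterns that are easiest to test are sketched below with hypothetical names. The function first converts Vec<A> into Option<A>, turning one wrapper shape into another. Trace combines step names by concatenation, with an empty trace as the identity. This is a simplified miniature of what src/structure.rs names OptionFunctor, PipelineTrace, and the law-check helpers, not a copy of that file.

```rust
// Natural transformation Vec<A> -> Option<A>: keep the first element, if any.
fn first<A>(xs: Vec<A>) -> Option<A> {
    xs.into_iter().next()
}

// Naturality square: mapping inside Vec and then taking the first element
// must agree with taking the first element and then mapping inside Option.
fn naturality_holds<A: Clone, B: PartialEq>(xs: Vec<A>, f: impl Fn(A) -> B) -> bool {
    let map_then_first = first(xs.clone().into_iter().map(&f).collect::<Vec<B>>());
    let first_then_map = first(xs).map(&f);
    map_then_first == first_then_map
}

// Trace monoid: combine concatenates step names; the empty trace is the identity.
#[derive(Debug, Clone, PartialEq)]
struct Trace(Vec<&'static str>);

impl Trace {
    fn empty() -> Self {
        Trace(Vec::new())
    }

    fn combine(&self, other: &Trace) -> Trace {
        let mut steps = self.0.clone();
        steps.extend(other.0.iter().copied());
        Trace(steps)
    }
}

// Checks associativity and both identity laws for three sample traces.
fn monoid_laws_hold(a: &Trace, b: &Trace, c: &Trace) -> bool {
    let associative = a.combine(b).combine(c) == a.combine(&b.combine(c));
    let left_identity = Trace::empty().combine(a) == *a;
    let right_identity = a.combine(&Trace::empty()) == *a;
    associative && left_identity && right_identity
}

fn main() {
    let square = |x: i32| x * x;
    println!("naturality on [1, 2, 3]: {}", naturality_holds(vec![1, 2, 3], square));
    println!("naturality on []: {}", naturality_holds(Vec::new(), square));

    let (a, b, c) = (Trace(vec!["embedding"]), Trace(vec!["linear"]), Trace(vec!["softmax"]));
    println!("trace = {:?}", a.combine(&b).combine(&c).0);
    println!("monoid laws hold: {}", monoid_laws_hold(&a, &b, &c));
}
```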
src/calculus.rs keeps backpropagation deliberately small. It shows the local
rule for:
z = x * y
dL/dx = dL/dz * y
dL/dy = dL/dz * x
This is not a full automatic-differentiation engine. It is the smallest local chain-rule shape the later training story can point at.
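The local rule can be checked in a few lines. This sketch uses plain f32 in place of the crate's Scalar and LocalGradient wrappers and drops their fallible constructors. MulOp here is a hypothetical mirror, not the source file. With x = 2, y = 3, and upstream gradient 1, the rule above gives backward = (3.0, 2.0).

```rust
// Hypothetical mirror of MulOp with plain f32 instead of Scalar/LocalGradient.
struct MulOp;

impl MulOp {
    // Forward pass: z = x * y.
    fn forward(&self, x: f32, y: f32) -> f32 {
        x * y
    }

    // Backward pass: dz/dx = y and dz/dy = x, so each local derivative
    // is multiplied by the upstream gradient dL/dz (the chain rule).
    fn backward(&self, x: f32, y: f32, upstream: f32) -> (f32, f32) {
        (upstream * y, upstream * x)
    }
}

fn main() {
    let mul = MulOp;
    let (x, y, upstream) = (2.0, 3.0, 1.0);
    let z = mul.forward(x, y);
    let (dl_dx, dl_dy) = mul.backward(x, y, upstream);
    println!("z = {z}, dL/dx = {dl_dx}, dL/dy = {dl_dy}");

    // Central finite difference as an independent check of dz/dx.
    let h = 1e-2;
    let numeric_dx = (mul.forward(x + h, y) - mul.forward(x - h, y)) / (2.0 * h);
    println!("finite-difference dz/dx = {numeric_dx}");
}
```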
src/sketches.rs connects the tutorial to applied category theory beyond the
tiny ML pipeline. It models orders, resources, databases, co-design, signal
flow, circuits, and behavior logic as typed Rust values with law-checking
tests.
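The sketch below gives a flavor of a law-checking test without reproducing src/sketches.rs. It uses a hypothetical Resources type ordered componentwise and checks the two preorder laws, reflexivity and transitivity, by brute force over a small sample. The second relation is deliberately not transitive, so the same checker should reject it.

```rust
// Hypothetical resource bundle; not a type from src/sketches.rs.
#[derive(Debug, Clone, Copy, PartialEq)]
struct Resources {
    cpu: u32,
    memory: u32,
}

// Componentwise order: a <= b when a needs no more of either resource than b.
fn fits_within(a: Resources, b: Resources) -> bool {
    a.cpu <= b.cpu && a.memory <= b.memory
}

// Checks the preorder laws (reflexivity, transitivity) over a finite sample.
fn preorder_laws_hold<R>(sample: &[Resources], rel: R) -> bool
where
    R: Fn(Resources, Resources) -> bool,
{
    let reflexive = sample.iter().all(|&a| rel(a, a));
    let transitive = sample.iter().all(|&a| {
        sample.iter().all(|&b| {
            sample.iter().all(|&c| !(rel(a, b) && rel(b, c)) || rel(a, c))
        })
    });
    reflexive && transitive
}

fn main() {
    let sample = [
        Resources { cpu: 0, memory: 1 },
        Resources { cpu: 1, memory: 1 },
        Resources { cpu: 2, memory: 4 },
    ];
    println!("fits_within is a preorder: {}", preorder_laws_hold(&sample, fits_within));

    // Reflexive but not transitive (0 ~ 1 and 1 ~ 2, but not 0 ~ 2).
    let close = |a: Resources, b: Resources| a.cpu.abs_diff(b.cpu) <= 1;
    println!("close is a preorder: {}", preorder_laws_hold(&sample, close));
}
```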
Guided Walkthrough Snapshot
The terminal demo is the spine of the book. It gives a learner one command that uses every major idea in a concrete order.
Source snapshot: src/demo.rs
use crate::calculus::{LocalGradient, MulOp, Scalar};
use crate::category::{Compose, StepCount, apply_endomorphism_n_times};
use crate::domain::{
LearningRate, Logits, ModelDimension, Parameters, Product, TokenId, TokenSequence, Vector,
VocabSize,
};
use crate::error::CtResult;
use crate::ml::{
CrossEntropy, DatasetWindowing, Embedding, LinearToLogits, Softmax, average_loss,
composed_prediction_matches_direct_prediction,
};
use crate::structure::{
Functor, Monoid, OptionFunctor, PipelineTrace, TraceStep, VecFunctor,
monoid_laws_hold_for_pipeline_trace, naturality_square_holds_for_first_option,
};
use crate::training::TrainStep;
use crate::{Identity, Morphism};
/// Run the full terminal walkthrough used by `cargo run --bin category_ml`.
pub fn run_demo() -> CtResult<()> {
println!("Category theory concepts implemented in Rust 2024");
println!("=================================================\n");
let vocab = ["<pad>", "I", "love", "Rust", "."];
let raw_text = TokenSequence::from_indices([1, 2, 3, 4, 1, 2, 3, 4])?;
println!("1. Object examples");
println!(" TokenId(1) means {:?}\n", vocab[1]);
println!("2. Dataset morphism: TokenSequence -> TrainingSet");
let dataset = DatasetWindowing.apply(raw_text)?;
for example in dataset.examples() {
println!(
" {:?} -> {:?}",
vocab[example.first().index()],
vocab[example.second().index()]
);
}
println!();
println!("3. Identity morphism: id_Vector : Vector -> Vector");
let v = Vector::new(vec![1.0, 2.0, 3.0]);
let same_v = Identity::<Vector>::new().apply(v.clone())?;
println!(" input = {:?}", v);
println!(" output = {:?}\n", same_v);
println!("4. Composition: Softmax after Linear after Embedding");
let params = Parameters::init(VocabSize::new(vocab.len())?, ModelDimension::new(4)?);
let embedding = Embedding::from_parameters(&params);
let linear = LinearToLogits::from_parameters(&params);
let token_to_logits = Compose::<_, _, Vector>::new(embedding, linear);
let token_to_distribution = Compose::<_, _, Logits>::new(token_to_logits, Softmax);
let distribution = token_to_distribution.apply(TokenId::new(1))?;
println!(" P(next token | 'I') = {:?}\n", distribution.as_slice());
println!("5. Product object: Prediction x Target -> Loss");
let loss = CrossEntropy.apply(Product::new(distribution, TokenId::new(2)))?;
println!(" loss for target 'love' = {:.6}\n", loss.value());
println!("6. Endomorphism: TrainStep : Parameters -> Parameters");
let before = average_loss(&params, &dataset)?;
let train_step = TrainStep::new(dataset.clone(), LearningRate::new(1.0)?);
let trained_params =
apply_endomorphism_n_times(&train_step, params.clone(), StepCount::new(80))?;
let after = average_loss(&trained_params, &dataset)?;
println!(" average loss before training = {:.6}", before.value());
println!(" average loss after training = {:.6}\n", after.value());
println!("7. Functor: fmap over Vec and Option");
let xs = vec![1, 2, 3];
let ys = VecFunctor::fmap(xs, |x| x * x);
let maybe = OptionFunctor::fmap(Some(7), |x| x + 1);
println!(" VecFunctor fmap square: {:?}", ys);
println!(" OptionFunctor fmap +1: {:?}\n", maybe);
println!("8. Natural transformation: Vec<A> -> Option<A>");
println!(
" naturality square holds: {}\n",
naturality_square_holds_for_first_option()
);
println!("9. Monoid: pipeline traces compose associatively with identity");
let trace = PipelineTrace::from_steps(vec![TraceStep::new("embedding")])
.combine(&PipelineTrace::from_steps(vec![TraceStep::new("linear")]))
.combine(&PipelineTrace::from_steps(vec![TraceStep::new("softmax")]));
println!(" trace = {:?}", trace.names());
println!(
" monoid laws hold for this trace type: {}\n",
monoid_laws_hold_for_pipeline_trace()
);
println!("10. Commutative diagram check");
println!(
" composed prediction == direct prediction: {}\n",
composed_prediction_matches_direct_prediction(&params)?
);
println!("11. Chain rule / local derivative morphism");
let mul = MulOp;
let x = Scalar::new(2.0)?;
let y = Scalar::new(3.0)?;
let z = mul.forward(x, y)?;
let upstream = LocalGradient::new(1.0)?;
let (dl_dx, dl_dy) = mul.backward(x, y, upstream)?;
println!(" z = x * y = {}", z.value());
println!(
" if dL/dz = {}, then dL/dx = {}, dL/dy = {}\n",
upstream.value(),
dl_dx.value(),
dl_dy.value()
);
println!("Compressed categorical training view:");
println!(" Dataset x Parameters -> Prediction -> Loss -> Gradients -> Updated Parameters");
println!(" TrainStep is repeated as Parameters0 -> Parameters1 -> ... -> ParametersN");
Ok(())
}
How To Read The Demo
The demo output is a miniature course outline.
It starts with an object:
TokenId(1)
Then it applies a data-preparation morphism:
TokenSequence -> TrainingSet
Then it shows identity and composition:
Vector -> Vector
TokenId -> Vector -> Logits -> Distribution
Then it uses a product object to measure loss:
Distribution x TokenId -> Loss
Then it repeats an endomorphism:
Parameters -> Parameters
The later demo sections add functors, naturality, monoids, a commutative diagram check, and a local chain-rule example. By the time you finish the demo, you have seen each major term at least once in executable form.
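The prediction and loss steps of that outline can be sketched with plain slices. The sketch assumes the standard definitions: softmax with max-subtraction for numerical stability, and cross entropy as the negative log-probability of the target. It uses hypothetical free functions in place of the Softmax and CrossEntropy morphisms in src/ml.rs, whose internal details (for example, any clamping before the logarithm) may differ.

```rust
// Hypothetical plain-slice versions of the Softmax and CrossEntropy arrows.

// Softmax: logits -> probabilities. Subtracting the max logit before exp()
// avoids overflow without changing the result.
fn softmax(logits: &[f32]) -> Vec<f32> {
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&z| (z - max).exp()).collect();
    let total: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / total).collect()
}

// Cross entropy for one example: -ln p(target).
// Returns None when the target index is outside the distribution.
fn cross_entropy(distribution: &[f32], target: usize) -> Option<f32> {
    distribution.get(target).map(|&p| -p.ln())
}

fn main() {
    let logits: [f32; 5] = [0.0; 5]; // an untrained model that scores every token equally
    let distribution = softmax(&logits);
    let loss = cross_entropy(&distribution, 2).expect("target inside the vocabulary");
    println!("distribution = {distribution:?}");
    println!("loss for target 2 = {loss:.6}");
}
```

With five equal logits, every token gets probability 0.2, so the loss is -ln(0.2) = ln 5, roughly 1.609. This is the loss of a model that has not learned anything yet.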
Demo Output Wayfinding Checklist
After running cargo run --bin category_ml, use the numbered output as a map
instead of reading it as one long printout.
| Demo section | Source file to inspect next | Rust reading | ML reading | Category-theory reading |
|---|---|---|---|---|
| 1. Object examples | src/domain.rs | TokenId gives a raw index a domain name | tokens are data, not model state | object |
| 2. Dataset morphism | src/ml.rs | DatasetWindowing turns a sequence into pairs | text becomes supervised examples | morphism |
| 3. Identity morphism | src/category.rs | Identity<Vector> returns the same value | a neutral transformation should not change features | identity law |
| 4. Composition | src/ml.rs and src/category.rs | Compose connects matching output and input types | embedding, logits, and softmax form prediction | composition |
| 5. Product object | src/domain.rs and src/ml.rs | Product<Distribution, TokenId> pairs prediction with target | loss needs both prediction and correct next token | product object |
| 6. Endomorphism | src/training.rs | TrainStep returns Parameters | training updates model state | endomorphism |
| 7-9. Structure patterns | src/structure.rs | traits and tests name reusable operations | batches, options, and traces recur in ML systems | functor, naturality, monoid |
| 10. Commutative diagram check | src/ml.rs | two code paths are compared | direct and composed prediction should agree | commutative diagram |
| 11. Chain rule | src/calculus.rs | MulOp::backward returns local gradients | backprop starts from local derivative rules | chain rule |
This table gives you a safe next action. If the output line is clear, continue reading. If it is not clear, open the source file in the second column and look for the type or function named by the line. The goal is not to memorize the demo. The goal is to use it as a routing table from terminal output to chapter, source file, ML role, and category-theory shape.
Source-Backed Wayfinding Rules
This chapter uses sources to keep the opening map practical. Each source supports one local rule for how a first session should move from command output to source files and then to vocabulary.
| Source | What the source supports | Local rule in this chapter | Repository evidence |
|---|---|---|---|
| How People Learn II | Learning should connect new ideas to prior knowledge and learner context. | Start from a familiar function, then attach Rust, ML, and category-theory names to one visible pipeline. | ## Worked Example: From One Function To A Pipeline, ## The Three Readings |
| Rust Book: Packages, Crates, and Modules | Rust packages organize code into crates and modules with separate responsibilities. | Treat the source tree as the learning map: nouns, arrows, ML arrows, training, structure, calculus, sketches, and demo. | src/domain.rs, src/category.rs, src/ml.rs, src/training.rs, src/demo.rs |
| Rust By Example | Small runnable examples make syntax inspectable before a larger explanation. | Run one command, inspect its output, then route the output line to the matching source file. | cargo run --example 01_token_sequence, cargo run --bin category_ml |
| Seven Sketches | Applied category theory is taught through compositional examples and recurring shapes. | Introduce object, morphism, product, composition, endomorphism, and law as names for shapes already visible in the tiny pipeline. | TokenId -> Vector -> Logits -> Distribution, Distribution x TokenId -> Loss, Parameters -> Parameters |
| Category Theory for Programming | Programming examples can make category-theory vocabulary less detached from code. | Translate from Rust file and function evidence to category vocabulary only after the typed path is visible. | ## Demo Output Wayfinding Checklist, Morphism<Input, Output>, Compose<F, G, Middle> |
The transfer pattern is:
source rule -> route through files -> command/output evidence
For this chapter, that means using cargo run --bin category_ml as more than a
demo. Treat it as a table of contents whose output routes you to
src/domain.rs, src/category.rs, src/ml.rs, src/training.rs,
src/structure.rs, src/calculus.rs, and src/sketches.rs.
The table is not evidence that the book has solved every learner’s route through the material. It is evidence that the first-session map is grounded in named sources, real files, and executable output.
Binary Entrypoint
The binary entrypoint is deliberately tiny:
Source snapshot: src/bin/category_ml.rs
fn main() -> category_theory_transformer_rs::CtResult<()> {
category_theory_transformer_rs::run_demo()
}
The whole file delegates to the library walkthrough in run_demo.
The binary returns CtResult, so fallible work can propagate through Rust’s
ordinary Result path. The binary stays short because this book is teaching
the typed pipeline, not command-line interface design.
First Run
Start with the smallest visible pipeline:
cargo run --example 01_token_sequence
That command turns text into token IDs and next-token training pairs before any model weights appear.
Then run the full guided demo:
cargo run --bin category_ml
The exact floating-point values are less important than the shape. You should see a loss before training, a lower loss after repeated training, and the same typed pipeline used throughout the walkthrough.
Core Mental Model
Every chapter after this one zooms into one row of the map.
An object is a typed thing the program can talk about precisely.
A morphism is a typed transformation from one object to another.
Composition is a legal connection between transformations, where the output type of one step matches the input type of the next.
An endomorphism is a transformation from a type back to itself.
A law is a property the code checks so the reader can trust the shape, not only the example output.
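Composition and identity fit in one self-contained sketch. The block makes several simplifying assumptions: CtResult uses String errors, the trait omits name() for brevity, and ToFloat and Describe are hypothetical arrows. Compose<F, G, Middle> mirrors the shape of the type used in the demo. Middle is recorded as a PhantomData marker so the compiler can check that the first arrow's output matches the second arrow's input.

```rust
use std::marker::PhantomData;

// Hypothetical stand-in for the crate's error type.
type CtResult<T> = Result<T, String>;

// Same shape as the book's Morphism trait, minus name() for brevity.
trait Morphism<Input, Output> {
    fn apply(&self, input: Input) -> CtResult<Output>;
}

// Runs `first: A -> Middle`, then `second: Middle -> C`.
// PhantomData records Middle without storing a value of that type.
struct Compose<F, G, Middle> {
    first: F,
    second: G,
    _middle: PhantomData<Middle>,
}

impl<F, G, Middle> Compose<F, G, Middle> {
    fn new(first: F, second: G) -> Self {
        Self { first, second, _middle: PhantomData }
    }
}

// Composition only exists when the output of F is the input of G.
impl<A, Middle, C, F, G> Morphism<A, C> for Compose<F, G, Middle>
where
    F: Morphism<A, Middle>,
    G: Morphism<Middle, C>,
{
    fn apply(&self, input: A) -> CtResult<C> {
        let middle = self.first.apply(input)?; // stop at the first failure
        self.second.apply(middle)
    }
}

// Identity arrow on any type.
struct Identity;

impl<T> Morphism<T, T> for Identity {
    fn apply(&self, input: T) -> CtResult<T> {
        Ok(input)
    }
}

// Two hypothetical arrows: usize -> f32 and f32 -> String.
struct ToFloat;

impl Morphism<usize, f32> for ToFloat {
    fn apply(&self, input: usize) -> CtResult<f32> {
        Ok(input as f32 * 0.5)
    }
}

struct Describe;

impl Morphism<f32, String> for Describe {
    fn apply(&self, input: f32) -> CtResult<String> {
        Ok(format!("value={input}"))
    }
}

fn main() -> CtResult<()> {
    let pipeline = Compose::<_, _, f32>::new(ToFloat, Describe);
    let out: String = pipeline.apply(3)?;
    println!("{out}");

    // Identity law: Identity followed by ToFloat behaves like ToFloat alone.
    let with_identity = Compose::<_, _, usize>::new(Identity, ToFloat);
    let left: f32 = with_identity.apply(3)?;
    let right: f32 = ToFloat.apply(3)?;
    assert_eq!(left, right);
    Ok(())
}
```

Writing Compose::<_, _, String>::new(ToFloat, Describe) would still construct a value, but calling apply on it fails to compile, because ToFloat does not produce a String. The type mismatch surfaces at compile time rather than as a runtime surprise.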
Checkpoint
Explain this line in your own words:
TokenId -> Vector -> Logits -> Distribution
A strong answer should mention token lookup, the embedding vector, vocabulary scores, the probability distribution, and the fact that the whole path is a composition of typed morphisms.
Where This Leaves Us
This chapter gave the whole shape before the details. You now know the names of the source files, the major pipeline objects, and the difference between objects, morphisms, composition, endomorphisms, and laws.
The next chapter, Domain Objects, slows down and studies the objects themselves. Before a pipeline can compose arrows safely, it needs values whose meanings are clear enough for arrows to start and end at them.
Further Reading
Do not leave this chapter with only a list of links. Use the next sources to practice the map.
Start from this local evidence:
cargo run --example 01_token_sequence
cargo run --bin category_ml
src/domain.rs
src/category.rs
src/ml.rs
src/training.rs
Then read the sources in this order:
| Source | What to transfer back into this chapter | Local evidence to inspect |
|---|---|---|
| How People Learn II | New ideas should connect to prior knowledge, context, and visible learner activity. | ## What You Already Know, ## Demo Output Wayfinding Checklist |
| Rust Book: Packages, Crates, and Modules | A Rust package can expose a library crate, binary crates, and modules with separate responsibilities. | src/bin/category_ml.rs, src/lib.rs, src/domain.rs, src/category.rs |
| Rust By Example | Small runnable examples make syntax inspectable before a larger explanation. | examples/01_token_sequence.rs, examples/02_morphism_composition.rs |
| Seven Sketches | Applied category theory can be introduced through concrete examples before abstraction. | TokenId -> Vector -> Logits -> Distribution |
| Category Theory for Programming | Programming-shaped examples can keep category vocabulary attached to code. | Morphism<Input, Output>, Compose<F, G, Middle> |
After reading one external source, ask four questions:
- Which command output line did it make easier to place?
- Which source file should you inspect next?
- Which category word did it clarify?
- Which later chapter should you read after this map?
For this chapter, the commands are:
cargo run --example 01_token_sequence
cargo run --bin category_ml
For terminology recovery, use the Glossary entries for object, morphism, composition, and endomorphism. For source depth, use References and the Seven Sketches Through Rust companion chapter after you can already route the first demo output.
If an external source does not help you connect one terminal line to one source file and one category-theory word, it has not transferred back into the map yet.
Practice After This Chapter
Use Exercise 2 to change the demo input and Exercise 8 to connect one source file back to the course map. Use the demo-output wayfinding checklist above to decide which file to inspect. Those exercises check whether the map is active knowledge rather than only a diagram you read once.
Retrieval Practice
Recall
Name the three readings used throughout the book.
Explain
Why does the book start with a whole-pipeline map before reading individual source files?
Apply
Write a one-line diagram for a pipeline you already know, then label the input object, arrow, and output object.
Domain Objects
The problem this chapter solves is:
A machine-learning pipeline should not pass raw numbers around and hope everyone remembers what each number means.
Before this code talks about arrows, composition, loss, or training, it defines the objects those arrows will connect.
In this course, a domain object means:
raw representation
+ a meaningful name
+ optional validation
+ controlled access
For example:
usize
could mean:
- a token index
- a vocabulary size
- a model dimension
- a matrix row count
- a training step count
Those are different concepts.
So the code creates different types.
Reader orientation: In this chapter, focus on why a type exists before focusing on its syntax. A tuple struct, private field, constructor, or accessor is not decoration. It is a small boundary that tells the rest of the pipeline which states it may trust.
Chapter Outcomes
By the end of this chapter, you should be able to:
- explain why TokenId, VocabSize, ModelDimension, and StepCount should not all be raw usize values at the teaching boundary,
- separate semantic wrappers from validated objects,
- name one invalid ML state that each constructor prevents before prediction, loss, or training sees it.
What You Already Know
If you have used a Rust struct, you already know that a value can carry a name instead of floating around as raw data. If you have used an ML pipeline, you already know that a token index, a vector, a probability distribution, and a loss value play different roles. This chapter turns that familiar separation into explicit domain types.
The important move is not “wrap everything because wrappers are nice.” The important move is to ask what the rest of the pipeline is allowed to trust. Some types only separate meanings. Other types also reject invalid values before they can enter prediction, loss, or training.
Worked Example: Naming One Number
The smallest version of the pattern looks like this:
#![allow(unused)]
fn main() {
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct TokenId(usize);
impl TokenId {
fn new(index: usize) -> Self {
Self(index)
}
fn index(self) -> usize {
self.0
}
}
assert_eq!(TokenId::new(3).index(), 3);
}
The real source file repeats that pattern with stronger validation where the value has an invariant, such as “a distribution must contain probabilities that sum to one.”
Self-Check
Before reading the full source snapshot, explain why TokenId(3) communicates
more than the raw number 3.
Two Kinds Of Domain Objects
Read the file with this distinction in mind.
| Kind | Example | What the type gives the pipeline |
|---|---|---|
| Semantic wrapper | TokenId, Vector, Logits | A name that prevents one raw representation from being confused with another |
| Validated object | TokenSequence, Distribution, Loss, VocabSize, LearningRate | A constructor that rejects states later code should not have to handle |
Both kinds matter. A TokenId is useful even though any usize can become a
token ID at this layer, because it prevents accidental mixing with dimensions
or row counts. A Distribution needs a stronger boundary, because not every
Vec<f32> is a valid probability distribution.
This is the Rust API idea behind the chapter: put meaning and validation near construction, then expose small accessors for the raw representation when lower-level code really needs it.
The domain-boundary diagram is:
\[
\begin{array}{ccccc}
\mathrm{usize} & \xrightarrow{\mathrm{TokenId::new}} & \mathrm{TokenId} & \xrightarrow{\mathrm{Embedding}} & \mathrm{Vector} \\
\mathrm{Vec}\langle\mathrm{TokenId}\rangle & \xrightarrow{\mathrm{TokenSequence::new}} & \mathrm{TokenSequence} & \xrightarrow{\mathrm{DatasetWindowing}} & \mathrm{TrainingSet} \\
\mathrm{Vec}\langle f32\rangle & \xrightarrow{\mathrm{Distribution::new}} & \mathrm{Distribution} & \xrightarrow{\mathrm{Product(-, target)}} & \mathrm{Product}\langle\mathrm{Distribution},\mathrm{TokenId}\rangle
\end{array}
\]
How to read this diagram:
- the left column is raw representation,
- the first arrow is the constructor or naming boundary,
- the middle object is what downstream code is allowed to trust,
- the last arrow is the first later stage that benefits from the boundary,
- redrawing the diagram should tell you which rows are semantic wrappers and which rows validate an invariant.
The diagram is deliberately modest. It does not claim that TokenId::new
checks membership in a real tokenizer vocabulary. It does claim that once code
asks for a TokenId, a reader no longer has to wonder whether the value is a
model dimension, loop index, or training step count.
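Since TokenId::new does not check vocabulary membership, some later boundary has to. The minimal sketch below assumes the embedding table is a Vec<Vec<f32>> with one row per token. The hypothetical embed function turns an out-of-range index into a typed error instead of a panic from direct indexing. LookupError is illustrative and is not the crate's error type.

```rust
// Hypothetical mirror of the TokenId newtype: a role label, no range check.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct TokenId(usize);

// Hypothetical error for this sketch; not the crate's CtError.
#[derive(Debug, PartialEq)]
enum LookupError {
    OutOfVocabulary { index: usize, vocab_size: usize },
}

// The lookup is the boundary that checks membership: one table row per token.
fn embed(table: &[Vec<f32>], token: TokenId) -> Result<&[f32], LookupError> {
    table
        .get(token.0) // None instead of a panic when the index is out of range
        .map(Vec::as_slice)
        .ok_or(LookupError::OutOfVocabulary {
            index: token.0,
            vocab_size: table.len(),
        })
}

fn main() {
    let table: Vec<Vec<f32>> = vec![vec![0.1, 0.2], vec![0.3, 0.4], vec![0.5, 0.6]];
    println!("{:?}", embed(&table, TokenId(1)));
    println!("{:?}", embed(&table, TokenId(7)));
}
```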
Mistakes These Types Prevent
Before reading the whole file, scan the reason each type exists. The point is not to wrap values for style. The point is to make common pipeline mistakes harder to express.
| Domain type | Raw representation it replaces | Concrete mistake it prevents |
|---|---|---|
| TokenId | usize | passing a vocabulary index where a model dimension or row count was expected |
| TokenSequence | Vec<TokenId> | training on an empty sequence or mutating a validated sequence after construction |
| Vector | Vec<f32> | treating hidden features as if they were vocabulary scores |
| Logits | Vec<f32> | treating raw scores as if they were probabilities |
| Distribution | Vec<f32> | computing loss from negative, non-finite, empty, or non-normalized probabilities |
| Loss | f32 | accumulating a negative or non-finite objective value |
| VocabSize | usize | constructing parameters for a zero-token vocabulary |
| ModelDimension | usize | constructing embedding rows with zero width |
| LearningRate | f32 | applying an optimizer step with zero, negative, or non-finite step size |
| TrainingSet | Vec<TrainingExample> | running training on no examples |
| Parameters | loose matrices and bias vectors | scattering model state across unrelated arrays without one named owner |
Use this table as the chapter’s review checklist. When a later section shows syntax, ask which mistake the syntax blocks.
Source-Backed Precision Rules
This chapter uses Rust sources to keep the “domain object” claim precise. Each source supports one local teaching rule, and each rule is tied to a concrete constructor, accessor, example, or test. The chapter does not claim that every wrapper is fully validated. Some types only separate meanings; other types reject invalid states at construction.
| Source | What the source supports | Local rule in this chapter | Rust evidence |
|---|---|---|---|
| Rust Book: Structs | Structs and tuple structs give data a named type, even when the stored representation is small. | Use TokenId, Vector, and Logits to separate meanings that would otherwise share usize or Vec<f32>. | TokenId(usize), Vector(Vec<f32>), Logits(Vec<f32>) |
| Rust By Example: New Type Idiom | A wrapper type can make the compiler require the intended semantic role before a value enters a function. | Treat TokenId, VocabSize, and ModelDimension as compile-time role labels before adding heavier validation. | TokenId, VocabSize, ModelDimension |
| Rust Book: Result | Result<T, E> represents an operation that may either return a success value or an error value. | Use fallible constructors when raw input may violate an invariant. | TokenSequence::new, Distribution::new, Loss::new, LearningRate::new |
| Rust API Guidelines: Type Safety | Newtypes provide static distinctions and arguments should convey meaning through custom types. | Do not let usize, f32, or Vec<f32> cross teaching boundaries when they mean different ML concepts. | VocabSize, ModelDimension, LearningRate, Product<Distribution, TokenId> |
| Rust API Guidelines: Dependability | Functions should validate their arguments when invalid values would break later assumptions. | Validate once at construction, then let downstream morphisms trust the object. | distribution_rejects_non_normalized_values, token_sequence_rejects_empty_input |
| Rust API Guidelines: Future Proofing | Private fields and encapsulated newtypes protect invariants and implementation details. | Expose small accessors such as as_slice, value, and index instead of public mutable fields. | TokenSequence(Vec<TokenId>), Distribution(Vec<f32>), Parameters accessors |
The transfer pattern is:
source rule -> local domain type -> constructor, accessor, or test evidence
For this chapter, that means reading cargo run --example 01_domain_objects
and cargo test domain::tests as evidence for the small boundary claims:
TokenSequence is non-empty
Distribution is non-empty, finite, non-negative, and normalized
shape and training configuration values are not interchangeable
It is not evidence that every future ML value has already been modeled. It is evidence that the chapter’s first layer of objects has explicit names, construction boundaries, and validation where the later pipeline depends on an invariant.
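The Distribution rejection conditions can be exercised in a table-driven way. The sketch re-declares a local copy of the constructor logic from the src/domain.rs snapshot below, with a hypothetical DistError in place of CtError. Each bad input can then be matched to the specific check that rejects it.

```rust
// Hypothetical error type for this sketch; the crate uses CtError.
#[derive(Debug, PartialEq)]
enum DistError {
    Empty,
    InvalidProbability,
}

#[derive(Debug, PartialEq)]
struct Distribution(Vec<f32>);

impl Distribution {
    // Same checks as the snapshot: non-empty, every value finite and
    // non-negative, and the sum within 1e-4 of one.
    fn new(probabilities: Vec<f32>) -> Result<Self, DistError> {
        if probabilities.is_empty() {
            return Err(DistError::Empty);
        }
        let sum: f32 = probabilities.iter().sum();
        let all_valid = probabilities.iter().all(|p| p.is_finite() && *p >= 0.0);
        if !all_valid || (sum - 1.0).abs() > 1e-4 {
            return Err(DistError::InvalidProbability);
        }
        Ok(Self(probabilities))
    }
}

fn main() {
    let cases: [(&str, Vec<f32>); 5] = [
        ("valid", vec![0.25, 0.75]),
        ("empty", vec![]),
        ("not normalized", vec![0.4, 0.4]),
        ("negative", vec![1.5, -0.5]), // sums to one, but -0.5 is not a probability
        ("non-finite", vec![f32::NAN, 1.0]),
    ];
    for (label, values) in cases {
        println!("{label:>15}: {:?}", Distribution::new(values));
    }
}
```

The "negative" row shows why the sum check alone is not enough: 1.5 and -0.5 sum to exactly one, yet the vector is not a probability distribution.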
Primitive-To-Domain Responsibility Ledger
Use this ledger whenever a raw value crosses into the tiny ML pipeline. The question is not only “what type wraps this value?” The question is “which boundary now owns the meaning, and what is downstream code allowed to trust?”
| Raw value | Domain object | Constructor or boundary | Invariant owned here | Downstream code may trust | Unsafe shortcut rejected | Source-backed limit | Validation command |
|---|---|---|---|---|---|---|---|
| usize | TokenId | TokenId::new(index) | semantic role label only; vocabulary membership is checked later by lookup code | this value is being used as a token index, not a dimension or step count | passing bare usize through morphism boundaries | a newtype name does not prove the index exists in a specific vocabulary | cargo test domain::tests --lib |
| Vec<TokenId> | TokenSequence | TokenSequence::new(tokens) | sequence is non-empty | dataset windowing can ask for adjacent pairs without handling an empty sequence as a valid training path | accepting any raw vector as sequence data | non-empty does not prove the sequence is long enough for every downstream task; each later boundary still owns its own check | cargo test domain::tests::token_sequence_rejects_empty_input --lib |
| Vec<f32> | Distribution | Distribution::new(probabilities) | values are finite, non-negative, non-empty, and sum to one within the local tolerance | CrossEntropy can read a probability assigned to the target token | treating logits or arbitrary floats as probabilities | this proves a local normalized vector, not calibration, statistical quality, or framework equivalence | cargo test domain::tests::distribution_rejects_non_normalized_values --lib |
| usize, usize | Parameters | Parameters::init(VocabSize, ModelDimension) | vocabulary size and model dimension have already rejected zero | model state has one owner for embedding rows, output head, and bias | constructing loose matrices with swapped or zero shape inputs | deterministic teaching initialization is not production initialization | cargo run --example 01_domain_objects |
The first row is intentionally different from the third row. TokenId::new
only gives a number a role. Distribution::new rejects invalid probability
mass. Both are domain boundaries, but they own different kinds of
responsibility.
This distinction protects the rest of the book from two common mistakes:
mistake 1: "Every wrapper validates everything."
mistake 2: "If a type stores a primitive, it is only decoration."
The right reading is narrower:
semantic wrapper:
prevents role confusion at typed boundaries
validated object:
prevents a specific invalid state before later code can trust the value
Source Snapshot
This is the domain layer used by the whole tutorial.
Source snapshot: src/domain.rs
use crate::error::{CtError, CtResult};
/// A vocabulary index. It is intentionally not a raw `usize` in public APIs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TokenId(usize);
impl TokenId {
pub fn new(index: usize) -> Self {
Self(index)
}
pub fn index(&self) -> usize {
self.0
}
}
impl From<usize> for TokenId {
fn from(value: usize) -> Self {
Self::new(value)
}
}
/// A sequence of tokens before it has been converted into training pairs.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TokenSequence(Vec<TokenId>);
impl TokenSequence {
pub fn new(tokens: impl IntoIterator<Item = TokenId>) -> CtResult<Self> {
let tokens = tokens.into_iter().collect::<Vec<_>>();
if tokens.is_empty() {
return Err(CtError::EmptyInput("token sequence"));
}
Ok(Self(tokens))
}
pub fn from_indices(indices: impl IntoIterator<Item = usize>) -> CtResult<Self> {
Self::new(indices.into_iter().map(TokenId::new))
}
pub fn as_slice(&self) -> &[TokenId] {
&self.0
}
}
/// A dense feature vector.
#[derive(Debug, Clone, PartialEq)]
pub struct Vector(Vec<f32>);
impl Vector {
pub fn new(values: Vec<f32>) -> Self {
Self(values)
}
pub fn as_slice(&self) -> &[f32] {
&self.0
}
}
/// Unnormalized model scores.
#[derive(Debug, Clone, PartialEq)]
pub struct Logits(Vec<f32>);
impl Logits {
pub fn new(values: Vec<f32>) -> Self {
Self(values)
}
pub fn as_slice(&self) -> &[f32] {
&self.0
}
}
/// A validated probability distribution.
#[derive(Debug, Clone, PartialEq)]
pub struct Distribution(Vec<f32>);
impl Distribution {
pub fn new(probabilities: Vec<f32>) -> CtResult<Self> {
if probabilities.is_empty() {
return Err(CtError::EmptyInput("distribution"));
}
let sum: f32 = probabilities.iter().sum();
let all_valid = probabilities
.iter()
.all(|probability| probability.is_finite() && *probability >= 0.0);
if !all_valid || !approx_eq(sum, 1.0, 1e-4) {
return Err(CtError::InvalidProbability("distribution constructor"));
}
Ok(Self(probabilities))
}
pub fn as_slice(&self) -> &[f32] {
&self.0
}
}
/// A non-negative scalar objective value.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Loss(f32);
impl Loss {
pub fn new(value: f32) -> CtResult<Self> {
if !value.is_finite() || value < 0.0 {
return Err(CtError::InvalidLoss(value));
}
Ok(Self(value))
}
pub fn value(&self) -> f32 {
self.0
}
}
/// Number of vocabulary entries.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct VocabSize(usize);
impl VocabSize {
pub fn new(value: usize) -> CtResult<Self> {
if value == 0 {
return Err(CtError::EmptyInput("vocabulary"));
}
Ok(Self(value))
}
pub fn value(&self) -> usize {
self.0
}
}
/// Width of each embedding vector.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ModelDimension(usize);
impl ModelDimension {
pub fn new(value: usize) -> CtResult<Self> {
if value == 0 {
return Err(CtError::EmptyInput("model dimension"));
}
Ok(Self(value))
}
pub fn value(&self) -> usize {
self.0
}
}
/// Positive optimizer step size.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct LearningRate(f32);
impl LearningRate {
pub fn new(value: f32) -> CtResult<Self> {
if !value.is_finite() || value <= 0.0 {
return Err(CtError::InvalidLearningRate(value));
}
Ok(Self(value))
}
pub fn value(&self) -> f32 {
self.0
}
}
/// Categorical product object: `A x B`.
#[derive(Debug, Clone, PartialEq)]
pub struct Product<A, B> {
first: A,
second: B,
}
impl<A, B> Product<A, B> {
pub fn new(first: A, second: B) -> Self {
Self { first, second }
}
pub fn first(&self) -> &A {
&self.first
}
pub fn second(&self) -> &B {
&self.second
}
pub fn into_parts(self) -> (A, B) {
(self.first, self.second)
}
}
pub type TrainingExample = Product<TokenId, TokenId>;
/// Non-empty next-token training pairs.
#[derive(Debug, Clone, PartialEq)]
pub struct TrainingSet(Vec<TrainingExample>);
impl TrainingSet {
pub fn new(examples: impl IntoIterator<Item = TrainingExample>) -> CtResult<Self> {
let examples = examples.into_iter().collect::<Vec<_>>();
if examples.is_empty() {
return Err(CtError::EmptyInput("training set"));
}
Ok(Self(examples))
}
pub fn examples(&self) -> &[TrainingExample] {
&self.0
}
pub fn len(&self) -> usize {
self.0.len()
}
pub fn is_empty(&self) -> bool {
self.0.is_empty()
}
}
/// Tiny model parameters for an embedding plus language-model head.
#[derive(Debug, Clone, PartialEq)]
pub struct Parameters {
pub(crate) embedding: Vec<Vec<f32>>,
pub(crate) lm_head: Vec<Vec<f32>>,
pub(crate) bias: Vec<f32>,
}
impl Parameters {
pub fn init(vocab_size: VocabSize, d_model: ModelDimension) -> Self {
let vocab_size = vocab_size.value();
let d_model = d_model.value();
Self {
embedding: init_matrix(vocab_size, d_model, 0.2, 1),
lm_head: init_matrix(d_model, vocab_size, 0.2, 2),
bias: vec![0.0; vocab_size],
}
}
pub fn vocab_size(&self) -> usize {
self.bias.len()
}
pub fn d_model(&self) -> usize {
self.embedding.first().map_or(0, Vec::len)
}
pub fn embedding_table(&self) -> &[Vec<f32>] {
&self.embedding
}
pub fn lm_head(&self) -> &[Vec<f32>] {
&self.lm_head
}
pub fn bias(&self) -> &[f32] {
&self.bias
}
}
pub(crate) fn init_matrix(rows: usize, cols: usize, scale: f32, seed: usize) -> Vec<Vec<f32>> {
let mut out = vec![vec![0.0; cols]; rows];
for (row_index, row) in out.iter_mut().enumerate() {
for (col_index, value) in row.iter_mut().enumerate() {
let raw = ((row_index * cols + col_index) * 37 + seed * 101) % 1000;
let unit = raw as f32 / 1000.0;
*value = (unit - 0.5) * scale;
}
}
out
}
pub(crate) fn approx_eq(a: f32, b: f32, eps: f32) -> bool {
(a - b).abs() <= eps
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn distribution_rejects_non_normalized_values() {
let result = Distribution::new(vec![0.4, 0.4]);
assert!(matches!(result, Err(CtError::InvalidProbability(_))));
}
#[test]
fn token_sequence_rejects_empty_input() {
let result = TokenSequence::new(vec![]);
assert!(matches!(result, Err(CtError::EmptyInput("token sequence"))));
}
}
The Whole File
src/domain.rs defines the nouns in the tiny ML system:
TokenId
TokenSequence
Vector
Logits
Distribution
Loss
VocabSize
ModelDimension
LearningRate
Product
TrainingExample
TrainingSet
Parameters
The ML pipeline needs all of them:
TokenSequence -> TrainingSet
TokenId -> Vector
Vector -> Logits
Logits -> Distribution
Distribution x TokenId -> Loss
Parameters -> Parameters
The category-theory reading is:
These are the objects that morphisms start from and end at.
The Rust reading is:
These are wrapper types that keep raw representations from leaking through the whole program.
Each major block below is meant to be read through three lenses:
Rust syntax:
what does the code literally declare?
ML concept:
why does the training pipeline need this value?
Category theory concept:
what object, product, list, distribution, or morphism endpoint does it model?
The chapter follows the same order as the model pipeline. First it names token data. Then it names hidden representations and probabilities. Then it names loss, configuration, paired inputs, training data, and model state.
TokenId
The problem this block solves is:
A token index should not be confused with any other usize.
The block:
/// A vocabulary index. It is intentionally not a raw `usize` in public APIs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TokenId(usize);
impl TokenId {
pub fn new(index: usize) -> Self {
Self(index)
}
pub fn index(&self) -> usize {
self.0
}
}
impl From<usize> for TokenId {
fn from(value: usize) -> Self {
Self::new(value)
}
}
Rust Syntax
TokenId is a tuple struct with one private field:
pub struct TokenId(usize);
The struct is public, but the field is private.
That means other modules can name TokenId, pass it around, and call its methods, but they cannot directly reach inside and mutate the raw usize.
Why new Cannot Fail
pub fn new(index: usize) -> Self
Every usize is a valid token index at this layer.
The code does not know yet whether the token is inside a particular vocabulary. That check happens later when a morphism tries to look up an embedding row.
So TokenId::new is infallible.
Why index Exists
pub fn index(&self) -> usize {
self.0
}
This accessor gives read-only access to the raw index when low-level code needs it.
The type still prevents accidental mixing at the API boundary.
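The following is a minimal sketch of that boundary. It restates TokenId and adds two hypothetical items that are not in the crate: a RowCount newtype and an is_padding function. Both wrappers hold a usize, but only a TokenId can be passed where a token position is expected.

```rust
/// Vocabulary index, mirroring the chapter's `TokenId` shape.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TokenId(usize);

impl TokenId {
    pub fn new(index: usize) -> Self {
        Self(index)
    }

    pub fn index(&self) -> usize {
        self.0
    }
}

/// Hypothetical second wrapper over `usize`, used only for contrast.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RowCount(usize);

/// Hypothetical check that only accepts a token position.
pub fn is_padding(token: TokenId) -> bool {
    // Index 0 is `<pad>` in the chapter's toy vocabulary.
    token.index() == 0
}

// is_padding(RowCount(0)); // rejected: expected `TokenId`, found `RowCount`
```

If you uncomment the last line, the compiler should report a mismatched-types error. That is the intended effect: the mistake is caught before any runtime check runs.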
ML Concept
In ML terms, TokenId is a vocabulary position.
If the vocabulary is:
0 = <pad>
1 = I
2 = love
3 = Rust
4 = .
then:
TokenId::new(3)
means the token Rust.
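The same reading can be expressed as code, assuming the vocabulary is a plain array of strings. The decode helper below is hypothetical, not a crate API. It returns Option because, as noted earlier, a TokenId is not checked against any particular vocabulary when it is created.

```rust
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TokenId(usize);

impl TokenId {
    pub fn new(index: usize) -> Self {
        Self(index)
    }

    pub fn index(&self) -> usize {
        self.0
    }
}

/// The toy vocabulary from the text, indexed by position.
pub const VOCAB: [&str; 5] = ["<pad>", "I", "love", "Rust", "."];

/// Hypothetical decoder: vocabulary position -> surface string.
/// Returns `None` when the index falls outside this vocabulary.
pub fn decode<'v>(vocab: &[&'v str], token: TokenId) -> Option<&'v str> {
    vocab.get(token.index()).copied()
}
```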
Category Theory Concept
TokenId is one object in the category of this program’s typed values.
Arrows such as Embedding start from this object:
TokenId -> Vector
TokenSequence
The problem this block solves is:
A language model does not train directly on raw text. First, text becomes a sequence of token IDs. Then that sequence becomes input-target training pairs.
This block represents the middle stage:
raw text
-> tokens
-> token sequence
-> training examples
-> model training
The block:
/// A sequence of tokens before it has been converted into training pairs.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TokenSequence(Vec<TokenId>);
impl TokenSequence {
pub fn new(tokens: impl IntoIterator<Item = TokenId>) -> CtResult<Self> {
let tokens = tokens.into_iter().collect::<Vec<_>>();
if tokens.is_empty() {
return Err(CtError::EmptyInput("token sequence"));
}
Ok(Self(tokens))
}
pub fn from_indices(indices: impl IntoIterator<Item = usize>) -> CtResult<Self> {
Self::new(indices.into_iter().map(TokenId::new))
}
pub fn as_slice(&self) -> &[TokenId] {
&self.0
}
}
Rust Syntax: Documentation Comment
/// A sequence of tokens before it has been converted into training pairs.
This tells you the pipeline stage.
TokenSequence is not raw text.
It is also not yet training data.
It is the ordered token stream before adjacent pairs are created.
Example:
[TokenId(1), TokenId(2), TokenId(3)]
can later become:
TokenId(1) -> TokenId(2)
TokenId(2) -> TokenId(3)
Rust Syntax: Derived Traits
#[derive(Debug, Clone, PartialEq, Eq)]
Debug allows test and debugging output.
Clone allows an explicit copy of the sequence.
PartialEq allows equality checks.
Eq promises that equality is total: every value equals itself.
Order matters. These are not equal:
[TokenId(1), TokenId(2)]
[TokenId(2), TokenId(1)]
Rust Syntax: Private Vector
pub struct TokenSequence(Vec<TokenId>);
This wraps:
Vec<TokenId>
but does not expose the vector directly.
That is important because the type’s invariant is:
TokenSequence is non-empty.
If the field were public, a caller could construct:
TokenSequence(vec![])
and bypass validation.
The private field forces construction through TokenSequence::new or TokenSequence::from_indices.
Rust Syntax: Constructor
pub fn new(tokens: impl IntoIterator<Item = TokenId>) -> CtResult<Self>
This accepts any input that can produce TokenId values:
- a vector
- an array
- a mapped iterator
The return type is:
CtResult<TokenSequence>
So construction can succeed or fail.
Rust Syntax: Collection
let tokens = tokens.into_iter().collect::<Vec<_>>();
This turns the flexible input into the concrete representation stored inside the struct.
The _ means Rust infers the element type as TokenId.
Rust Syntax: Empty Check
if tokens.is_empty() {
return Err(CtError::EmptyInput("token sequence"));
}
This is the invariant boundary.
An empty token stream cannot carry useful sequence information.
The error happens immediately, before invalid data enters the rest of the pipeline.
Rust Syntax: Successful Construction
Ok(Self(tokens))
Inside the impl, Self means TokenSequence.
So this is equivalent to:
Ok(TokenSequence(tokens))
The vector has already been validated, so the object is safe for later code to trust.
Rust Syntax: Convenience Constructor
pub fn from_indices(indices: impl IntoIterator<Item = usize>) -> CtResult<Self> {
Self::new(indices.into_iter().map(TokenId::new))
}
This accepts raw indices and converts each one into TokenId.
The important design choice is delegation:
from_indices -> new
Validation is not duplicated.
All construction still passes through the same non-empty check.
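The following is a self-contained sketch of this smart-constructor pattern. It assumes a simplified SketchError enum in place of the crate's CtError and CtResult, whose full definitions are not shown in this chapter. Both entry points reach the same empty check because from_indices delegates to new.

```rust
/// Simplified stand-in for the crate's error type (assumption).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SketchError {
    EmptyInput(&'static str),
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TokenId(usize);

/// Private field: code outside this module cannot build an empty value directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TokenSequence(Vec<TokenId>);

impl TokenSequence {
    /// The single place where the non-empty invariant is checked.
    pub fn new(tokens: impl IntoIterator<Item = TokenId>) -> Result<Self, SketchError> {
        let tokens: Vec<TokenId> = tokens.into_iter().collect();
        if tokens.is_empty() {
            return Err(SketchError::EmptyInput("token sequence"));
        }
        Ok(Self(tokens))
    }

    /// Convenience entry point: wrap raw indices, then delegate to `new`.
    pub fn from_indices(indices: impl IntoIterator<Item = usize>) -> Result<Self, SketchError> {
        Self::new(indices.into_iter().map(TokenId))
    }

    /// Read-only view; callers cannot push, clear, or replace the vector.
    pub fn as_slice(&self) -> &[TokenId] {
        &self.0
    }
}
```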
Rust Syntax: Read-Only Access
pub fn as_slice(&self) -> &[TokenId] {
&self.0
}
This returns a borrowed slice.
Callers can inspect the sequence, but they cannot clear it, push to it, or replace the internal vector.
That preserves the invariant after construction.
ML Concept
TokenSequence is tokenized text before next-token examples are created.
A sequence of length n can produce n - 1 adjacent prediction pairs.
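To make the n - 1 count concrete, the sketch below pairs adjacent tokens with the standard slice::windows(2) iterator. The adjacent_pairs helper is hypothetical. For brevity it works on raw usize values, and it is not the crate's DatasetWindowing morphism.

Note the edge case. A one-token sequence is a valid non-empty TokenSequence, yet it yields zero pairs, and an empty collection of examples is exactly what TrainingSet::new rejects.

```rust
/// Hypothetical helper: build (input, target) pairs from adjacent tokens.
/// A sequence of length n yields n - 1 pairs, and 0 pairs when n < 2.
pub fn adjacent_pairs(tokens: &[usize]) -> Vec<(usize, usize)> {
    tokens
        .windows(2) // overlapping slices of length 2
        .map(|window| (window[0], window[1])) // (input token, target token)
        .collect()
}
```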
Category Theory Concept
TokenSequence behaves like:
List+ TokenId
where List+ means a non-empty finite list.
Its constructor is not:
List TokenId -> TokenSequence
because the empty list is invalid.
It is:
List TokenId -> Result TokenSequence CtError
Rust turns the partial construction into a total function by using Result.
Vector and Logits
The problem these blocks solve is:
A dense hidden vector and raw vocabulary scores are both Vec<f32>, but they do not mean the same thing.
The blocks:
#[derive(Debug, Clone, PartialEq)]
pub struct Vector(Vec<f32>);
impl Vector {
pub fn new(values: Vec<f32>) -> Self {
Self(values)
}
pub fn as_slice(&self) -> &[f32] {
&self.0
}
}
#[derive(Debug, Clone, PartialEq)]
pub struct Logits(Vec<f32>);
impl Logits {
pub fn new(values: Vec<f32>) -> Self {
Self(values)
}
pub fn as_slice(&self) -> &[f32] {
&self.0
}
}
Rust Syntax
Vector means hidden features.
Logits means unnormalized scores.
Both wrap Vec<f32>.
The distinction matters because only this arrow should produce logits:
Vector -> Logits
and only this arrow should normalize logits:
Logits -> Distribution
If both were plain Vec<f32>, the compiler could not help keep those stages separate.
These types derive PartialEq, but not Eq, because they contain f32.
Floating-point values do not have total equality because NaN is not equal to itself.
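The following short check shows why Eq is off the table for f32 wrappers. NaN compares unequal to itself, so reflexivity fails. The Logits type here mirrors the chapter's shape. The is_reflexive helper is hypothetical.

```rust
/// Mirrors the chapter's `Logits` newtype: `PartialEq` only, no `Eq`.
#[derive(Debug, Clone, PartialEq)]
pub struct Logits(Vec<f32>);

impl Logits {
    pub fn new(values: Vec<f32>) -> Self {
        Self(values)
    }
}

/// Hypothetical helper: does this value compare equal to a copy of itself?
pub fn is_reflexive(logits: &Logits) -> bool {
    logits.clone() == *logits
}
```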
ML Concept
A Vector is the dense representation used after embedding lookup.
Example:
TokenId(3) -> [0.12, -0.44, 0.88, 0.03]
Logits are raw vocabulary scores.
Example:
[3.0, 1.0, -2.0]
They can be negative, larger than one, and do not need to sum to one.
The pipeline is:
TokenId -> Vector -> Logits -> Distribution
Category Theory Concept
If the model dimension is d, a vector lives in a vector-space-like object:
R^d
If the vocabulary size is V, logits live in:
R^V
The output projection is an arrow:
R^d -> R^V
and softmax maps:
R^V -> probability distributions over TokenId
Distribution
The problem this block solves is:
Probabilities are not just floats. A probability distribution must be non-empty, its entries must be finite and non-negative, and they must sum to one.
The core block:
#[derive(Debug, Clone, PartialEq)]
pub struct Distribution(Vec<f32>);
impl Distribution {
pub fn new(probabilities: Vec<f32>) -> CtResult<Self> {
if probabilities.is_empty() {
return Err(CtError::EmptyInput("distribution"));
}
let sum: f32 = probabilities.iter().sum();
let all_valid = probabilities
.iter()
.all(|probability| probability.is_finite() && *probability >= 0.0);
if !all_valid || !approx_eq(sum, 1.0, 1e-4) {
return Err(CtError::InvalidProbability("distribution constructor"));
}
Ok(Self(probabilities))
}
}
Rust Syntax: Why Construction Can Fail
This is invalid:
[]
because a distribution needs at least one outcome.
This is invalid:
[0.4, 0.4]
because it sums to 0.8, not 1.0.
This is invalid:
[1.2, -0.2]
because probabilities cannot be negative, even though this list does sum to 1.0.
So Distribution::new returns CtResult<Self>.
Rust Syntax: The Sum Check
let sum: f32 = probabilities.iter().sum();
This computes the total probability mass.
The code uses approximate equality:
approx_eq(sum, 1.0, 1e-4)
because floating-point arithmetic is not exact.
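All three rejection cases can be exercised with a standalone sketch of the same checks. The is_valid_distribution function is hypothetical and returns bool instead of the crate's CtResult. The tolerance 1e-4 matches the constructor above.

```rust
/// Same tolerance check as the chapter's `approx_eq`.
fn approx_eq(a: f32, b: f32, eps: f32) -> bool {
    (a - b).abs() <= eps
}

/// Hypothetical predicate mirroring the rules in `Distribution::new`.
pub fn is_valid_distribution(probabilities: &[f32]) -> bool {
    if probabilities.is_empty() {
        return false; // no outcomes at all
    }
    let sum: f32 = probabilities.iter().sum();
    // Reject NaN, infinity, and negative entries.
    let all_valid = probabilities.iter().all(|p| p.is_finite() && *p >= 0.0);
    all_valid && approx_eq(sum, 1.0, 1e-4)
}
```

Ten copies of 0.1 pass. Their f32 sum may not be bit-for-bit 1.0, but it lands well inside the tolerance.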
ML Concept
This is the output of softmax:
Logits -> Distribution
The rest of the model can treat a Distribution as real probabilities because the constructor checked the rule.
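The following is a minimal softmax sketch that takes a plain slice instead of the crate's Logits type. The softmax function here is hypothetical, and the crate's Softmax morphism may differ in detail.

Subtracting the largest logit before exponentiating is a common numerical-stability step. It does not change the result, because softmax is unchanged when the same constant is added to every logit.

```rust
/// Hypothetical softmax over raw scores. Returns `None` for empty input.
pub fn softmax(logits: &[f32]) -> Option<Vec<f32>> {
    if logits.is_empty() {
        return None;
    }
    // Largest logit; `fold` is used because f32 is not `Ord`.
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&z| (z - max).exp()).collect();
    let total: f32 = exps.iter().sum();
    // Normalize so the outputs sum to one.
    Some(exps.iter().map(|&e| e / total).collect())
}
```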
Category Theory Concept
Distribution is an object with a stronger invariant than Vec<f32>.
The softmax morphism lands in this object only if it can produce valid probability mass.
Loss
The problem this block solves is:
A loss value must be a finite, non-negative scalar.
The block:
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Loss(f32);
impl Loss {
pub fn new(value: f32) -> CtResult<Self> {
if !value.is_finite() || value < 0.0 {
return Err(CtError::InvalidLoss(value));
}
Ok(Self(value))
}
pub fn value(&self) -> f32 {
self.0
}
}
Rust Syntax
Loss::new rejects:
- infinity
- not-a-number values
- negative values
Cross entropy cannot produce a negative loss from a valid distribution, because -log(p) >= 0 whenever 0 < p <= 1. If a negative or non-finite value appears, something has gone wrong before or during loss calculation.
Loss derives Copy because it wraps one small scalar.
Calling value() returns the raw f32 for printing, comparison, or averaging.
ML Concept
Loss is the training signal.
For next-token prediction:
loss = -log(probability assigned to the correct token)
Lower loss means the model assigned more probability to the correct answer.
Training tries to reduce this value.
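The following is a worked sketch of that formula. It assumes probabilities arrive as a plain slice and the target as a raw index; the crate's CrossEntropy morphism consumes Product<Distribution, TokenId> instead. The cross_entropy function is hypothetical.

If the correct token gets probability 0.5, the loss is -ln(0.5) ≈ 0.693. At probability 1.0 the loss is 0. At probability 0 it is infinite, which Loss::new rejects.

```rust
/// Hypothetical single-example cross entropy: -ln(p[target]).
/// Returns `None` when the target index is outside the distribution.
pub fn cross_entropy(probabilities: &[f32], target: usize) -> Option<f32> {
    let p = *probabilities.get(target)?;
    Some(-p.ln()) // natural logarithm
}
```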
Category Theory Concept
Loss is the codomain of an evaluation morphism:
Distribution x TokenId -> Loss
It maps prediction plus truth into a non-negative scalar objective.
Shape and Training Hyperparameter Types
The problem these blocks solve is:
Dimensions and learning rates need boundary checks before they are used to allocate matrices or update parameters.
The types are:
VocabSize
ModelDimension
LearningRate
Rust Syntax
VocabSize::new(0) fails because a vocabulary with zero entries is unusable.
ModelDimension::new(0) fails because an embedding vector with zero width cannot carry features.
LearningRate::new(value) fails when the value is not finite or is not positive.
These checks keep bad configuration from becoming strange matrix behavior later.
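The constructors for these three types live in src/domain.rs but are not reproduced in this section. The sketch below is an assumption about their shape, built only from the rules just stated. A simplified ConfigError enum stands in for CtError.

```rust
/// Simplified stand-in for the crate's error type (assumption).
#[derive(Debug, Clone, PartialEq)]
pub enum ConfigError {
    Zero(&'static str),
    InvalidLearningRate(f32),
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct VocabSize(usize);

impl VocabSize {
    pub fn new(value: usize) -> Result<Self, ConfigError> {
        if value == 0 {
            return Err(ConfigError::Zero("vocab size"));
        }
        Ok(Self(value))
    }

    pub fn value(&self) -> usize {
        self.0
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ModelDimension(usize);

impl ModelDimension {
    pub fn new(value: usize) -> Result<Self, ConfigError> {
        if value == 0 {
            return Err(ConfigError::Zero("model dimension"));
        }
        Ok(Self(value))
    }

    pub fn value(&self) -> usize {
        self.0
    }
}

#[derive(Debug, Clone, Copy, PartialEq)]
pub struct LearningRate(f32);

impl LearningRate {
    pub fn new(value: f32) -> Result<Self, ConfigError> {
        // `is_finite` rejects NaN and infinities; `<= 0.0` rejects zero and negatives.
        if !value.is_finite() || value <= 0.0 {
            return Err(ConfigError::InvalidLearningRate(value));
        }
        Ok(Self(value))
    }

    pub fn value(&self) -> f32 {
        self.0
    }
}
```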
Worked Example: Configuration Values Are Not Interchangeable
The raw representation for all three values is small:
VocabSize -> usize
ModelDimension -> usize
LearningRate -> f32
That can make them look like ordinary numbers. They are not ordinary once they cross the model boundary.
Parameters::init in src/domain.rs makes the distinction concrete:
let parameters = Parameters::init(
VocabSize::new(5)?,
ModelDimension::new(2)?,
);
The first argument chooses how many vocabulary rows and output scores exist.
The second argument chooses how wide each hidden vector is. Swapping those meanings would create a different model shape, even though both values are stored as usize underneath.
The same rule applies to LearningRate. It is not a loss value, probability, or model dimension. It controls how far one update moves the parameters:
new parameter = old parameter - learning_rate * gradient
If the learning rate were zero, negative, infinite, or not-a-number, the update would stop being the small controlled movement the training chapter needs. That is why construction fails early.
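The update rule can be written as a short sketch over a flat parameter vector. It assumes the gradients have already been computed and match the parameters one-to-one. The sgd_step name is hypothetical. Because the function consumes and returns a Vec<f32>, it has the same-type-in, same-type-out shape that later chapters call an endomorphism.

```rust
/// Hypothetical gradient-descent step: p <- p - learning_rate * g, element-wise.
pub fn sgd_step(mut parameters: Vec<f32>, gradients: &[f32], learning_rate: f32) -> Vec<f32> {
    // The sketch assumes exactly one gradient per parameter.
    assert_eq!(parameters.len(), gradients.len(), "one gradient per parameter");
    for (p, g) in parameters.iter_mut().zip(gradients) {
        *p -= learning_rate * g;
    }
    parameters
}
```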
ML reading:
VocabSize -> how many token classes the model can score
ModelDimension -> how much hidden capacity each token receives
LearningRate -> how large each optimizer step is
Category-theory reading:
VocabSize helps choose the finite token object, ModelDimension helps choose the intermediate representation object, and LearningRate selects one update from a family of possible training endomorphisms. The values are configuration for different parts of the typed system, not interchangeable numbers.
Checkpoint question:
If you see the raw value 5, what extra information tells you whether it is a vocabulary size, model dimension, token id, or step count?
ML Concept
VocabSize controls:
embedding rows
logit length
distribution length
bias length
ModelDimension controls embedding width:
R^d
LearningRate controls optimizer step size:
parameter = parameter - learning_rate * gradient
Category Theory Concept
VocabSize describes the cardinality of the finite token object.
ModelDimension chooses the intermediate vector-space-like object.
LearningRate chooses one update morphism from a family of training endomorphisms.
Product<A, B>
The problem this block solves is:
Some ML operations need two inputs that belong together.
The block:
#[derive(Debug, Clone, PartialEq)]
pub struct Product<A, B> {
first: A,
second: B,
}
This is a generic pair.
It is used in two important places:
pub type TrainingExample = Product<TokenId, TokenId>;
and:
Product<Distribution, TokenId> -> Loss
Rust Syntax: Why Not A Tuple Everywhere?
Rust tuples like (A, B) would work mechanically.
Product<A, B> makes the category-theory idea visible:
A x B
It also gives named methods:
first()
second()
into_parts()
Those methods make call sites easier to read in the examples that follow.
ML Concept
Product<TokenId, TokenId> is one supervised next-token example:
input token x target token
Product<Distribution, TokenId> is the input to cross entropy:
prediction x target
Category Theory Concept
Product<A, B> is the course’s named version of:
A x B
The accessors are projection-like operations:
first ~ pi_1
second ~ pi_2
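The projection reading can be checked in a few lines. The sketch restates Product and adds a hypothetical pair helper, not in the crate, that builds X -> A x B from f : X -> A and g : X -> B. Projecting the result with first recovers f, and projecting with second recovers g. These are the two projection equations from the definition of a categorical product; uniqueness is not tested here.

```rust
#[derive(Debug, Clone, PartialEq)]
pub struct Product<A, B> {
    first: A,
    second: B,
}

impl<A, B> Product<A, B> {
    pub fn new(first: A, second: B) -> Self {
        Self { first, second }
    }

    pub fn first(&self) -> &A {
        &self.first // pi_1
    }

    pub fn second(&self) -> &B {
        &self.second // pi_2
    }
}

/// Hypothetical pairing <f, g> : X -> A x B.
pub fn pair<X, A, B>(
    f: impl Fn(&X) -> A,
    g: impl Fn(&X) -> B,
) -> impl Fn(&X) -> Product<A, B> {
    move |x: &X| Product::new(f(x), g(x))
}
```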
TrainingSet
The problem this block solves is:
Training should not run on an empty collection of examples.
The block:
#[derive(Debug, Clone, PartialEq)]
pub struct TrainingSet(Vec<TrainingExample>);
impl TrainingSet {
pub fn new(examples: impl IntoIterator<Item = TrainingExample>) -> CtResult<Self> {
let examples = examples.into_iter().collect::<Vec<_>>();
if examples.is_empty() {
return Err(CtError::EmptyInput("training set"));
}
Ok(Self(examples))
}
}
This mirrors TokenSequence.
The internal vector is private.
Construction validates non-emptiness.
Callers get read-only access through:
pub fn examples(&self) -> &[TrainingExample]
Rust Syntax: Why is_empty Exists If Empty Is Impossible
TrainingSet includes:
pub fn is_empty(&self) -> bool {
self.0.is_empty()
}
For values constructed through TrainingSet::new, this always returns false.
The method exists because collection-like types conventionally expose is_empty alongside len (Clippy's len_without_is_empty lint flags a public len without it), and tests or generic code may call it.
The invariant is still protected by private storage and the constructor.
ML Concept
A TrainingSet is a non-empty list of next-token examples.
For:
[10, 25, 31, 7]
the training set is:
(10, 25)
(25, 31)
(31, 7)
Category Theory Concept
The shape is:
non-empty list of (TokenId x TokenId)
or:
List+ (TokenId x TokenId)
Parameters
The problem this block solves is:
Training needs one object that owns all trainable model state.
The block:
#[derive(Debug, Clone, PartialEq)]
pub struct Parameters {
pub(crate) embedding: Vec<Vec<f32>>,
pub(crate) lm_head: Vec<Vec<f32>>,
pub(crate) bias: Vec<f32>,
}
The model has three pieces:
embedding table
lm head matrix
bias vector
The fields are pub(crate), not fully public.
That means code inside this crate can update parameters during training, but external callers use accessors.
Rust Syntax: Initialization
pub fn init(vocab_size: VocabSize, d_model: ModelDimension) -> Self
This takes validated domain values, not raw usize.
That means matrix allocation starts from:
non-empty vocabulary
positive model dimension
The initialized shapes are:
embedding: vocab_size x d_model
lm_head: d_model x vocab_size
bias: vocab_size
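The following sketch shows the two lookups those shapes support. It assumes plain nested Vec<f32> matrices laid out exactly as above: embedding is vocab_size x d_model, and lm_head is d_model x vocab_size. The embed and project names are hypothetical, and the crate's own Embedding and LinearToLogits morphisms may differ in detail.

```rust
/// Hypothetical embedding lookup: row `token` of the vocab_size x d_model table.
/// Returns `None` for a token outside the vocabulary (the deferred range check).
pub fn embed(embedding: &[Vec<f32>], token: usize) -> Option<Vec<f32>> {
    embedding.get(token).cloned()
}

/// Hypothetical projection: logits[j] = bias[j] + sum_i hidden[i] * lm_head[i][j].
pub fn project(hidden: &[f32], lm_head: &[Vec<f32>], bias: &[f32]) -> Vec<f32> {
    let mut logits = bias.to_vec(); // start from the bias, length vocab_size
    for (h, row) in hidden.iter().zip(lm_head) {
        for (logit, weight) in logits.iter_mut().zip(row) {
            *logit += h * weight;
        }
    }
    logits
}
```

In the usage below, vocab_size is 3 and d_model is 2, so token 2 maps to a width-2 vector and then to three logits.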
ML Concept
Parameters is the trainable state.
Prediction reads it.
Training maps it back to a new Parameters value:
Parameters -> Parameters
Category Theory Concept
Parameters is the object of the training endomorphism.
The important point is not that the numbers change.
The important point is that the type remains the same.
Utility Functions
The file ends with:
pub(crate) fn init_matrix(...)
pub(crate) fn approx_eq(...)
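init_matrix is a small deterministic initializer for the teaching model: it uses no random-number generator, so the same shape, scale, and seed always produce the same weights.
approx_eq is a small floating-point helper used by probability checks and composition tests.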
Both are crate-internal implementation details, not learner-facing domain objects.
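To see what deterministic means here, the sketch below copies the init_matrix formula from the src/domain.rs snapshot, made pub so it can stand alone. Each entry depends only on its flat position, the seed, and the scale. Two calls with the same arguments therefore produce identical matrices, and every value lies in [-scale/2, scale/2).

For seed 1 and scale 0.2, the first entry works out as follows:
- raw = 101
- unit = 0.101
- value = (0.101 - 0.5) * 0.2 = -0.0798

```rust
/// Copy of the chapter's initializer: position-based, no RNG.
pub fn init_matrix(rows: usize, cols: usize, scale: f32, seed: usize) -> Vec<Vec<f32>> {
    let mut out = vec![vec![0.0; cols]; rows];
    for (row_index, row) in out.iter_mut().enumerate() {
        for (col_index, value) in row.iter_mut().enumerate() {
            // Flat position and seed mixed into a number in 0..1000.
            let raw = ((row_index * cols + col_index) * 37 + seed * 101) % 1000;
            let unit = raw as f32 / 1000.0; // in [0.0, 0.999]
            *value = (unit - 0.5) * scale; // centered around zero
        }
    }
    out
}
```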
Runnable Example
The domain example shows token IDs becoming training pairs:
Source snapshot: examples/01_domain_objects.rs
use category_theory_transformer_rs::{
CtResult, DatasetWindowing, Morphism, TokenId, TokenSequence,
};
fn main() -> CtResult<()> {
let tokens = TokenSequence::from_indices([1, 2, 3, 4])?;
let dataset = DatasetWindowing.apply(tokens.clone())?;
println!("TokenSequence:");
println!("{}", format_token_sequence(tokens.as_slice()));
println!();
println!("TrainingSet:");
for example in dataset.examples() {
println!(
"({} -> {})",
format_token_id(example.first()),
format_token_id(example.second())
);
}
println!();
println!("Typed boundaries:");
println!("usize -> TokenId");
println!("Vec<TokenId> -> TokenSequence");
println!("TokenSequence -> TrainingSet");
println!("TrainingExample = Product<TokenId, TokenId>");
Ok(())
}
fn format_token_sequence(tokens: &[TokenId]) -> String {
let formatted = tokens
.iter()
.map(format_token_id)
.collect::<Vec<_>>()
.join(", ");
format!("[{formatted}]")
}
fn format_token_id(token: &TokenId) -> String {
format!("TokenId({})", token.index())
}
Run:
cargo run --example 01_domain_objects
Expected shape:
TokenSequence:
[TokenId(1), TokenId(2), TokenId(3), TokenId(4)]
TrainingSet:
(TokenId(1) -> TokenId(2))
(TokenId(2) -> TokenId(3))
(TokenId(3) -> TokenId(4))
Typed boundaries:
usize -> TokenId
Vec<TokenId> -> TokenSequence
TokenSequence -> TrainingSet
TrainingExample = Product<TokenId, TokenId>
Example Output Transfer Checklist
Use the example output to test whether the chapter’s boundary idea is working.
| Example output | Rust reading | ML reading | Category-theory reading | Shortcut to reject |
|---|---|---|---|---|
| TokenSequence | a private Vec<TokenId> has passed the non-empty constructor | tokenized data is ready for adjacent-pair creation | non-empty list-like object | treating any Vec<TokenId> as valid sequence data |
| TrainingSet | DatasetWindowing returned validated examples | adjacent input-target pairs are ready for training | list of product objects | training directly on a raw token list |
| usize -> TokenId | a raw index receives a domain name | a number becomes a vocabulary position | raw representation enters a typed object | passing row counts, dimensions, and token IDs as the same usize |
| Vec<TokenId> -> TokenSequence | a collection crosses a constructor boundary | tokenized text becomes an ordered sequence stage | partial construction into Result<TokenSequence, CtError> | allowing an empty sequence downstream |
| TrainingExample = Product<TokenId, TokenId> | a pair has a named product shape | input token and target token travel together | TokenId x TokenId | using an unlabelled tuple and forgetting which side is the target |
This is why the example prints TokenId(...) instead of only 1 -> 2.
Display can use raw numbers at the edge, but the teaching output should remind you that the program is moving through named objects.
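One way to keep that naming at the display edge is a Display impl on the newtype. The example program uses a free function, format_token_id, instead. The impl below is a hypothetical alternative, not what the crate does.

```rust
use std::fmt;

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TokenId(usize);

impl TokenId {
    pub fn new(index: usize) -> Self {
        Self(index)
    }

    pub fn index(&self) -> usize {
        self.0
    }
}

/// Hypothetical: render the domain name whenever a TokenId is printed.
impl fmt::Display for TokenId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "TokenId({})", self.index())
    }
}
```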
Why This Matters
The main design rule is:
Use raw primitives only at the edge where they are created or displayed.
After that, use domain types.
This prevents mistakes like:
passing a model dimension where a token ID was expected
passing logits where probabilities were expected
training on an empty dataset
using a negative learning rate
Types do not prove that a model is good, that optimization will always converge, or that the tiny implementation is production-ready. They do something narrower and very useful: they make the wrong wiring harder to write. That is the first step toward a pipeline that can be explained, tested, and extended.
Core Mental Model
src/domain.rs turns raw storage into trustworthy objects.
In Rust terms:
private fields + smart constructors + accessors
In ML terms:
tokens, vectors, probabilities, loss, and model weights
In category-theory terms:
objects that morphisms can safely connect
Checkpoint
Pick one type from this file and explain:
- What raw representation it wraps.
- What invalid state it prevents.
- Which morphism consumes or produces it.
Example:
Distribution wraps Vec<f32>, rejects invalid probability mass, and is produced by Softmax before CrossEntropy consumes it.
Where This Leaves Us
This chapter gave names to the values in the system. A token id is not a model dimension, logits are not probabilities, and a training set is not just any vector of pairs. Each type marks a stage where raw storage becomes a meaningful object.
The next chapter, Morphism and Composition, adds arrows between those objects. Once the arrows exist, the book can talk about identity, composition, and repeated transformations without falling back to loose wiring conventions.
Further Reading
Do not read these sources as generic Rust advice. Read them as a way to answer one question:
what is the pipeline allowed to trust after this value is constructed?
Start from the local Rust evidence:
TokenId::new(index) -> TokenId
TokenSequence::new(tokens) -> Result<TokenSequence, CtError>
Distribution::new(probabilities) -> Result<Distribution, CtError>
LearningRate::new(value) -> Result<LearningRate, CtError>
Parameters::init(vocab, dim) -> Parameters
Then read the sources in this order:
| Source | What to transfer back into this chapter | Local evidence to inspect |
|---|---|---|
| Rust Book: Structs | A named struct or tuple struct can make two identical raw representations mean different things. | TokenId(usize), VocabSize(usize), ModelDimension(usize) |
| Rust By Example: New Type Idiom | A small wrapper can make the compiler reject the wrong semantic role before runtime logic runs. | TokenId, VocabSize, ModelDimension |
| Rust Book: Result | A constructor can return either a trusted value or a typed error. | TokenSequence::new, Distribution::new, Loss::new, LearningRate::new |
| Rust API Guidelines: Type Safety | Newtypes provide static distinctions when raw arguments would hide meaning. | Product<Distribution, TokenId>, LearningRate, ModelDimension |
| Rust API Guidelines: Dependability | Invalid arguments should be rejected at the boundary that owns the invariant. | distribution_rejects_non_normalized_values, token_sequence_rejects_empty_input |
| Rust API Guidelines: Future Proofing | Private fields and small accessors keep later code from bypassing the boundary. | as_slice, value, index, Parameters accessors |
After reading one external source, ask four questions:
- Which domain type did it clarify?
- Does that type only separate meaning, or does it also validate an invariant?
- Which downstream morphism is allowed to trust the value?
- Which command would you run to see the evidence?
For this chapter, the commands are:
cargo run --example 01_domain_objects
cargo test domain::tests --lib
For terminology recovery, use the Glossary entries for object, product object, invariant, and smart constructor. For source depth, use References and follow the Rust struct, error-handling, and API design entries.
If a source does not help you explain why Distribution::new rejects invalid probability mass before CrossEntropy sees it, it has not transferred back into the chapter yet.
Practice After This Chapter
Use Exercise 1 to explain one domain type and Exercise 7 to explain one constructor boundary. Together they test the chapter’s main distinction: some types separate meaning, while others also reject invalid states.
Retrieval Practice
Recall
What is a domain object in this book?
Explain
Why does Distribution::new validate probability mass at construction time instead of leaving that check to CrossEntropy?
Apply
Design a one-field newtype for a future SequenceLength. State one invariant its constructor should protect.
Morphism and Composition
The problem this chapter solves is:
Once the system has typed objects, it needs typed transformations between them.
In the previous chapter, the code created objects such as:
TokenId
Vector
Logits
Distribution
Loss
Parameters
This chapter explains the arrows that connect them.
The central category-theory sentence is:
A morphism is a typed transformation from one object to another.
The central Rust sentence is:
A morphism is a trait implementation with an input type, output type, and typed error result.
Reader orientation: The previous chapter defined the objects of the tiny ML system. This chapter explains how values move between those objects. That movement is the bridge between ordinary Rust functions and the categorical idea of morphisms.
Chapter Outcomes
By the end of this chapter, you should be able to:
- read Morphism<Input, Output> as a typed transformation contract,
- explain why Compose<F, G, Middle> requires the first target object to match the second source object,
- diagnose why Embedding followed directly by Softmax is illegal without weakening either stage.
What You Already Know
If you know Rust functions, you already know that computation moves from an input type to an output type. If you know ML pipelines, you already know that a prediction path is built from stages. This chapter gives that familiar movement a shared interface: Morphism<Input, Output>.
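For comparison, here is the ordinary-function version of the same idea, before any trait is involved. The compose helper and the describe function are hypothetical, not part of the crate. compose only type-checks when the output type of f is the input type of g.

```rust
/// Hypothetical helper: `g after f`. The shared type `B` plays the middle object.
pub fn compose<A, B, C>(f: impl Fn(A) -> B, g: impl Fn(B) -> C) -> impl Fn(A) -> C {
    move |x| g(f(x))
}

pub fn add_one(input: i32) -> i32 {
    input + 1
}

/// Hypothetical second stage: i32 -> String.
pub fn describe(input: i32) -> String {
    format!("value = {input}")
}
```

Swapping the arguments, as in compose(describe, add_one), is rejected by the compiler because describe produces a String and add_one expects an i32.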
Category Terms As Rust Shapes
Before reading the generic source file, pin each category-theory word to a Rust shape and one tiny ML example.
| Category term | Rust shape in this repository | Tiny ML example |
|---|---|---|
| Object | A named type that can appear as an input or output | TokenId, Vector, Logits, Distribution |
| Morphism | impl Morphism<Input, Output> for SomeStage | impl Morphism<TokenId, Vector> for Embedding |
| Source object | The Input type parameter | TokenId in Morphism<TokenId, Vector> |
| Target object | The Output type parameter | Vector in Morphism<TokenId, Vector> |
| Identity morphism | Identity<T> implementing Morphism<T, T> | Identity::<Vector>::new() |
| Composition | Compose<F, G, Middle> | Embedding followed by LinearToLogits |
| Middle object | The type produced by F and consumed by G | Vector between embedding and projection |
| Endomorphism | Endomorphism<T> where input and output are the same type | TrainStep : Parameters -> Parameters |
| Repeated endomorphism | apply_endomorphism_n_times | repeated training updates |
Use the table as a translation layer. When a formal word appears later, ask which Rust trait, type parameter, or implementation makes it concrete. If no Rust shape is nearby, the explanation is probably moving too fast.
Source-Backed Precision Rules
This chapter uses external sources to keep the word morphism small enough to teach. Each source supports a limited claim, and each claim is tied to one local Rust boundary. The chapter does not claim that this crate implements a general category-theory library; it models typed transformations, identity, composition, and one repeated endomorphism helper for the tiny ML system.
| Source | What the source supports | Local rule in this chapter | Rust evidence |
|---|---|---|---|
| Rust Book: Generics | Generic type parameters let one definition describe many concrete types while preserving type relationships. | Read Input, Middle, and Output as type-level objects, not runtime values. | Morphism<Input, Output>, Compose<F, G, Middle> |
| Rust Book: Traits | Traits name shared behavior and make a contract that concrete types implement. | Treat a morphism as a trait contract: a named, fallible, typed transformation. | trait Morphism<Input, Output>, impl Morphism<TokenId, Vector> for Embedding |
| Stanford Encyclopedia of Philosophy: Category Theory | A category has morphisms between objects, identity morphisms, composition, and identity/associativity axioms. | Keep the local Rust claim narrow: the chapter models source type, target type, identity, and composition for this teaching crate. | Identity<T>, Compose<F, G, Middle>, identity_composes_without_changing_behavior, composition_applies_first_then_second |
| Seven Sketches | Category theory introduces objects, arrows, identity, and composition through concrete applied examples. | Use category words only when they point to a visible Rust object, arrow, identity, or composition boundary. | Identity<T>, Compose<F, G, Middle>, identity_composes_without_changing_behavior |
| Category Theory for Programming | Programming-shaped category-theory notes connect categorical vocabulary to datatypes, functions, and typed structure. | Explain the ordinary typed-function shape before using the word morphism. | fn add_one(input: i32) -> i32, Morphism<Input, Output> |
The transfer pattern is:
source idea -> local typed boundary -> compiler, output, or test evidence
For this chapter, that means reading the output of cargo run --example 02_morphism_composition, the results of cargo test category::tests, and the failed-shape diagnostic as evidence for a small claim:
two arrows compose only when the first target object matches the second source object
It is not evidence that every categorical law has been formalized. It is evidence that this tiny Rust interface makes the relevant middle object hard to ignore.
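Before the real source, here is a stripped-down sketch of the same shape. It uses String errors instead of the crate's CtError, drops the name method, and uses two hypothetical stages, Encode and IsPadding. It shows the claim in miniature:
- Compose is only a Morphism<Input, Output> when the first stage's output type matches the second stage's input type.
- An error from the first stage stops the chain before the second stage runs.

```rust
use std::marker::PhantomData;

/// Simplified arrow trait (assumption: `String` errors, no `name` method).
pub trait Morphism<Input, Output> {
    fn apply(&self, input: Input) -> Result<Output, String>;
}

/// `second after first`, with the middle object recorded as a type parameter.
pub struct Compose<F, G, Middle> {
    first: F,
    second: G,
    _middle: PhantomData<Middle>,
}

impl<F, G, Middle> Compose<F, G, Middle> {
    pub fn new(first: F, second: G) -> Self {
        Self { first, second, _middle: PhantomData }
    }
}

impl<Input, Middle, Output, F, G> Morphism<Input, Output> for Compose<F, G, Middle>
where
    F: Morphism<Input, Middle>,
    G: Morphism<Middle, Output>,
{
    fn apply(&self, input: Input) -> Result<Output, String> {
        // `?` returns early, so `second` never runs if `first` fails.
        let middle = self.first.apply(input)?;
        self.second.apply(middle)
    }
}

/// Hypothetical stage: word -> vocabulary index, failing on unknown words.
pub struct Encode;

impl Morphism<&'static str, usize> for Encode {
    fn apply(&self, word: &'static str) -> Result<usize, String> {
        ["<pad>", "I", "love", "Rust", "."]
            .iter()
            .position(|candidate| *candidate == word)
            .ok_or_else(|| format!("unknown word: {word}"))
    }
}

/// Hypothetical stage: vocabulary index -> "is this the padding token?"
pub struct IsPadding;

impl Morphism<usize, bool> for IsPadding {
    fn apply(&self, index: usize) -> Result<bool, String> {
        Ok(index == 0)
    }
}
```

Reversing the stages still builds the struct, because Compose::new has no bounds. Calling apply on it, however, does not compile: IsPadding produces bool, while Encode expects &'static str, so no Morphism impl exists.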
Source Snapshot
This file defines the typed arrow interface and the composition adapter.
Source snapshot: src/category.rs
use std::marker::PhantomData;
use crate::error::CtResult;
/// A typed category-theory arrow: `Input -> Output`.
pub trait Morphism<Input, Output> {
fn name(&self) -> &'static str;
fn apply(&self, input: Input) -> CtResult<Output>;
}
/// Identity morphism: `id_A : A -> A`.
#[derive(Debug, Clone, Copy)]
pub struct Identity<T> {
_marker: PhantomData<T>,
}
impl<T> Identity<T> {
pub fn new() -> Self {
Self {
_marker: PhantomData,
}
}
}
impl<T> Default for Identity<T> {
fn default() -> Self {
Self::new()
}
}
impl<T> Morphism<T, T> for Identity<T> {
fn name(&self) -> &'static str {
"identity"
}
fn apply(&self, input: T) -> CtResult<T> {
Ok(input)
}
}
/// Composition of two morphisms: if `f : A -> B` and `g : B -> C`, this is
/// `g after f : A -> C`.
#[derive(Debug, Clone)]
pub struct Compose<F, G, Middle> {
first: F,
second: G,
_middle: PhantomData<Middle>,
}
impl<F, G, Middle> Compose<F, G, Middle> {
pub fn new(first: F, second: G) -> Self {
Self {
first,
second,
_middle: PhantomData,
}
}
}
impl<Input, Middle, Output, F, G> Morphism<Input, Output> for Compose<F, G, Middle>
where
F: Morphism<Input, Middle>,
G: Morphism<Middle, Output>,
{
fn name(&self) -> &'static str {
"composition"
}
fn apply(&self, input: Input) -> CtResult<Output> {
let middle = self.first.apply(input)?;
self.second.apply(middle)
}
}
/// Endomorphism: a morphism from a type back to itself.
pub trait Endomorphism<T>: Morphism<T, T> {}
impl<T, M> Endomorphism<T> for M where M: Morphism<T, T> {}
/// How many times to repeat an endomorphism.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StepCount(usize);
impl StepCount {
pub fn new(value: usize) -> Self {
Self(value)
}
pub fn value(&self) -> usize {
self.0
}
}
/// Repeatedly apply an endomorphism: `A0 -> A1 -> ... -> An`.
pub fn apply_endomorphism_n_times<T, E>(endo: &E, mut value: T, count: StepCount) -> CtResult<T>
where
E: Endomorphism<T>,
{
for _ in 0..count.value() {
value = endo.apply(value)?;
}
Ok(value)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::error::CtError;
#[derive(Debug, Clone, Copy)]
struct AddOne;
impl Morphism<i32, i32> for AddOne {
fn name(&self) -> &'static str {
"add_one"
}
fn apply(&self, input: i32) -> CtResult<i32> {
Ok(input + 1)
}
}
#[derive(Debug, Clone, Copy)]
struct Double;
impl Morphism<i32, i32> for Double {
fn name(&self) -> &'static str {
"double"
}
fn apply(&self, input: i32) -> CtResult<i32> {
Ok(input * 2)
}
}
#[derive(Debug, Clone, Copy)]
struct Fail;
impl Morphism<i32, i32> for Fail {
fn name(&self) -> &'static str {
"fail"
}
fn apply(&self, _input: i32) -> CtResult<i32> {
Err(CtError::InvalidQuantity {
kind: "test morphism",
value: -1,
})
}
}
#[test]
fn identity_returns_the_same_value() -> CtResult<()> {
let value = String::from("same");
assert_eq!(Identity::<String>::new().apply(value.clone())?, value);
Ok(())
}
#[test]
fn identity_composes_without_changing_behavior() -> CtResult<()> {
let left_identity = Compose::<_, _, i32>::new(Identity::<i32>::new(), AddOne);
let right_identity = Compose::<_, _, i32>::new(AddOne, Identity::<i32>::new());
assert_eq!(left_identity.apply(41)?, AddOne.apply(41)?);
assert_eq!(right_identity.apply(41)?, AddOne.apply(41)?);
Ok(())
}
#[test]
fn composition_applies_first_then_second() -> CtResult<()> {
let add_then_double = Compose::<_, _, i32>::new(AddOne, Double);
assert_eq!(add_then_double.apply(4)?, 10);
Ok(())
}
#[test]
fn composition_returns_the_first_error() {
let composed = Compose::<_, _, i32>::new(Fail, AddOne);
assert!(matches!(
composed.apply(4),
Err(CtError::InvalidQuantity {
kind: "test morphism",
value: -1,
})
));
}
}
The Whole File
src/category.rs defines:
Morphism<Input, Output>
Identity<T>
Compose<F, G, Middle>
Endomorphism<T>
StepCount
apply_endomorphism_n_times
These are the abstract shapes used by the ML code.
Without this file, prediction could still be written as ordinary functions.
With this file, the course can name and test the structure:
identity
composition
endomorphism
repeated application
Read each block through the same three lenses:
Rust syntax:
what trait, struct, generic parameter, or bound is declared?
ML concept:
which model pipeline behavior does the shape support?
Category theory concept:
which arrow, identity, composition, or endomorphism idea is being modeled?
Worked Example: A Function As An Arrow
Before reading the generic trait, start with an ordinary Rust function:
#![allow(unused)]
fn main() {
fn add_one(input: i32) -> i32 {
input + 1
}
assert_eq!(add_one(41), 42);
}
That function already has an arrow shape:
i32 -> i32
The real Morphism<Input, Output> trait makes that shape explicit, gives the
arrow a name, and lets the arrow fail with a typed error when the input cannot
be transformed safely.
Self-Check
Before reading the trait, explain why i32 -> i32 and TokenId -> Vector have
the same arrow shape even though they mean very different things.
Worked Example: Where Composition Breaks
Now look at a tiny ML path:
TokenId -> Vector -> Logits -> Distribution
Each arrow has a job:
Embedding : TokenId -> Vector
LinearToLogits : Vector -> Logits
Softmax : Logits -> Distribution
A legal composition connects the target of one arrow to the source of the next arrow:
TokenId --Embedding--> Vector --LinearToLogits--> Logits --Softmax--> Distribution
The middle object is not decoration. It is the reason the pipeline is legal.
Embedding produces a Vector, and LinearToLogits consumes a Vector.
LinearToLogits produces Logits, and Softmax consumes Logits.
Now remove the middle step:
TokenId --Embedding--> Vector --Softmax--> Distribution
This looks tempting if you only think in English: “turn the token into
probabilities.” But the types say something more precise. Softmax does not
consume a Vector. It consumes Logits. The missing arrow is:
Vector -> Logits
That is why composition is not just “run functions in order.” Composition means:
the previous output type equals the next input type
In this chapter, the word “morphism” gives that rule a handle. A morphism has a source object and a target object. Two morphisms compose only when the first target object is the second source object.
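The same rule is visible before any trait appears. The sketch below uses plain functions over hypothetical toy newtypes (stand-ins, not the crate's real `TokenId`, `Vector`, `Logits`, and `Distribution`, and with made-up numbers); the commented-out line inside `predict` is the shortcut the compiler rejects.

```rust
// Toy newtypes standing in for the crate's domain types (assumed shapes).
#[derive(Debug, Clone, Copy)]
struct TokenId(usize);
#[derive(Debug, Clone)]
struct Vector(Vec<f32>);
#[derive(Debug, Clone)]
struct Logits(Vec<f32>);
#[derive(Debug, Clone)]
struct Distribution(Vec<f32>);

// TokenId -> Vector: a made-up 2-dimensional feature row per token.
fn embedding(token: TokenId) -> Vector {
    Vector(vec![token.0 as f32, 1.0])
}

// Vector -> Logits: a made-up projection onto a 3-item vocabulary.
fn linear_to_logits(v: Vector) -> Logits {
    Logits(vec![v.0[0], v.0[1], v.0[0] - v.0[1]])
}

// Logits -> Distribution: softmax with max-subtraction for stability.
fn softmax(logits: Logits) -> Distribution {
    let max = logits.0.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.0.iter().map(|x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    Distribution(exps.into_iter().map(|e| e / sum).collect())
}

fn predict(token: TokenId) -> Distribution {
    // Legal: each function's output type is the next function's input type.
    softmax(linear_to_logits(embedding(token)))
    // Illegal shortcut: `softmax(embedding(token))` does not compile,
    // because `softmax` expects `Logits` and `embedding` returns `Vector`.
}

fn main() {
    let dist = predict(TokenId(1));
    let total: f32 = dist.0.iter().sum();
    println!("{:?} sums to {total}", dist.0);
}
```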
Runnable Example: The Middle Type Is The Contract
Run the example for this chapter:
cargo run --example 02_morphism_composition
The example builds the legal path in two composed steps:
let token_to_logits = Compose::<_, _, Vector>::new(embedding.clone(), linear.clone());
let token_to_distribution = Compose::<_, _, Logits>::new(token_to_logits, Softmax);
Read the third type argument as the middle object being checked:
Embedding then LinearToLogits:
TokenId -> Vector -> Logits
middle object: Vector
token_to_logits then Softmax:
TokenId -> Logits -> Distribution
middle object: Logits
The example output ends with the composition rule:
first target must equal second source
Embedding then LinearToLogits is legal because Vector == Vector
Embedding then Softmax is illegal because Vector != Logits
This is the concrete Rust reason a morphism is more than a metaphor. The type signature tells you which object an arrow produces and which object the next arrow expects. If those do not match, the composed pipeline is not a valid pipeline.
Composition Debugging Checklist
When a composition fails, do not start by changing type signatures. Name the three objects first:
first source -> first target
second source -> second target
Then ask whether the middle objects match:
first target == second source ?
For the legal path:
| Stage | Source object | Target object |
|---|---|---|
| Embedding | TokenId | Vector |
| LinearToLogits | Vector | Logits |
| Softmax | Logits | Distribution |
The legal middle objects are:
Vector
Logits
The same legal path as a rendered math view:
\[ \mathrm{TokenId} \xrightarrow{\mathrm{Embedding}} \mathrm{Vector} \xrightarrow{\mathrm{LinearToLogits}} \mathrm{Logits} \xrightarrow{\mathrm{Softmax}} \mathrm{Distribution} \]
How to read this diagram:
- the objects are the Rust domain types,
- the arrows are morphism implementations,
- composition is legal only when the target object of one arrow is the source object of the next arrow,
- the diagram is a reading aid, not a claim that Rust proves every category law.
For the broken shortcut:
| Attempted composition | First target | Second source | Result |
|---|---|---|---|
| Embedding then Softmax | Vector | Logits | illegal composition |
The fix is not to make Softmax accept Vector. That would erase the model
stage that turns hidden features into vocabulary scores. The fix is to restore
the missing morphism:
Vector -> Logits
The broken shortcut is useful to draw because it exposes the missing middle object:
\[
\begin{array}{ccccc}
\mathrm{TokenId} & \xrightarrow{\mathrm{Embedding}} & \mathrm{Vector} & \not\!\xrightarrow{\mathrm{Softmax}} & \mathrm{Distribution} \\
&& \downarrow \mathrm{LinearToLogits} && \\
&& \mathrm{Logits} & \xrightarrow{\mathrm{Softmax}} & \mathrm{Distribution}
\end{array}
\]
Reconstruct this diagram by hand when a composition error appears. Label the first target, the second source, and the repair arrow before changing code.
This checklist is useful beyond this chapter. Many pipeline bugs can be read as one of three failures:
| Failure | Diagnostic question | Repair |
|---|---|---|
| missing stage | Which middle object should exist but does not? | restore the morphism that produces it |
| wrong stage order | Which target object arrives too early or too late? | reorder the arrows so targets meet sources |
| wrong object name | Which two values have the same raw representation but different roles? | introduce or restore the domain type |
The category-theory word “composition” is doing practical engineering work here. It tells you to debug the boundary, not the individual matrix multiplication, softmax formula, or display output first.
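The third failure, wrong object name, is the one a newtype repairs. A minimal sketch, assuming hypothetical `Logits` and `Distribution` wrappers over `Vec<f32>` and a hypothetical `Distribution::from_logits` constructor (not the crate's API): both wrappers share the same raw representation, but only a `Distribution` can reach the loss.

```rust
#[derive(Debug, Clone)]
struct Logits(Vec<f32>);

#[derive(Debug, Clone)]
struct Distribution(Vec<f32>);

impl Distribution {
    // Hypothetical constructor: the only route from scores to probabilities
    // is normalization (softmax with max-subtraction).
    fn from_logits(logits: &Logits) -> Distribution {
        let max = logits.0.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = logits.0.iter().map(|x| (x - max).exp()).collect();
        let sum: f32 = exps.iter().sum();
        Distribution(exps.iter().map(|e| e / sum).collect())
    }
}

// Negative log-likelihood of the target index. Taking `&Distribution`
// instead of `&[f32]` means raw logits cannot be passed here by mistake.
fn cross_entropy(dist: &Distribution, target: usize) -> f32 {
    -dist.0[target].ln()
}

fn main() {
    let logits = Logits(vec![0.0, 0.0]);
    // cross_entropy(&logits, 0); // rejected: expected `&Distribution`, found `&Logits`
    let dist = Distribution::from_logits(&logits);
    println!("loss = {}", cross_entropy(&dist, 0));
}
```

The repair for a wrong object name is therefore to give the two roles two types, not to add a runtime check that the floats "look like" probabilities.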
Source-Target-Middle Repair Ledger
When a composition breaks, write a small ledger before changing code. The ledger forces the abstract word “composition” back into source object, target object, middle object, and repair.
| Composition attempt | First arrow | Second arrow | Claimed middle | Actual mismatch | Repair | Unsafe shortcut rejected | Validation evidence |
|---|---|---|---|---|---|---|---|
| legal embedding then projection | Embedding : TokenId -> Vector | LinearToLogits : Vector -> Logits | Vector | none | keep the order | skipping vocabulary scoring | Embedding then LinearToLogits is legal because Vector == Vector |
| illegal embedding then softmax | Embedding : TokenId -> Vector | Softmax : Logits -> Distribution | Vector | Softmax needs Logits, not Vector | restore LinearToLogits : Vector -> Logits | making Softmax accept hidden features | Embedding then Softmax is illegal because Vector != Logits |
| legal projection then softmax | LinearToLogits : Vector -> Logits | Softmax : Logits -> Distribution | Logits | none | keep the order | treating logits as optional decoration | Compose::<_, _, Logits> |
Use this audit card when the compiler, a diagram, or a reader’s intuition says two stages should connect:
composition attempt:
first arrow:
second arrow:
claimed middle object:
actual first target:
actual second source:
repair:
unsafe shortcut rejected:
validation command or output:
Worked audit:
composition attempt: Embedding then Softmax
first arrow: Embedding : TokenId -> Vector
second arrow: Softmax : Logits -> Distribution
claimed middle object: Vector
actual first target: Vector
actual second source: Logits
repair: insert LinearToLogits : Vector -> Logits
unsafe shortcut rejected: changing Softmax to accept Vector
validation command or output:
cargo run --example 02_morphism_composition
Embedding then Softmax is illegal because Vector != Logits
The source-backed limit is important. The Rust compiler is not proving every theorem about categories. It is checking the local trait bounds that make this pipeline composition legal or illegal.
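The audit card can also be kept as data. A hypothetical sketch (`ArrowShape` and `check_composition` are illustration names, not part of the repository) that records each arrow's source and target by name and prints the same sentences as the example. It only compares labels at runtime; the real guarantee still comes from the trait bounds.

```rust
// Hypothetical record of an arrow's name and endpoint object names.
#[derive(Debug, Clone, Copy)]
struct ArrowShape {
    name: &'static str,
    source: &'static str,
    target: &'static str,
}

// Compare the first target with the second source, as the audit card does.
fn check_composition(first: ArrowShape, second: ArrowShape) -> Result<String, String> {
    if first.target == second.source {
        Ok(format!(
            "{} then {} is legal because {} == {}",
            first.name, second.name, first.target, second.source
        ))
    } else {
        Err(format!(
            "{} then {} is illegal because {} != {}",
            first.name, second.name, first.target, second.source
        ))
    }
}

const EMBEDDING: ArrowShape = ArrowShape { name: "Embedding", source: "TokenId", target: "Vector" };
const LINEAR: ArrowShape = ArrowShape { name: "LinearToLogits", source: "Vector", target: "Logits" };
const SOFTMAX: ArrowShape = ArrowShape { name: "Softmax", source: "Logits", target: "Distribution" };

fn main() {
    for (first, second) in [(EMBEDDING, LINEAR), (EMBEDDING, SOFTMAX), (LINEAR, SOFTMAX)] {
        match check_composition(first, second) {
            Ok(line) | Err(line) => println!("{line}"),
        }
    }
}
```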
Compiler Error As Evidence
The example does not include a broken composition because examples in this repository are expected to run. But the failed shape is still worth naming.
If you try to compose Embedding directly with Softmax, the intended shape
would be:
Embedding : TokenId -> Vector
Softmax : Logits -> Distribution
For Compose<F, G, Middle> to implement Morphism<Input, Output>, Rust needs
these two facts:
F: Morphism<Input, Middle>
G: Morphism<Middle, Output>
With Embedding followed by Softmax, choosing Middle = Vector asks Rust
for:
Embedding : Morphism<TokenId, Vector>
Softmax : Morphism<Vector, Distribution>
The first fact is true. The second fact is false. Softmax is implemented for
Logits -> Distribution, not Vector -> Distribution.
That failed trait bound is not noise. It says the missing middle object is:
Logits
and the missing morphism is:
LinearToLogits : Vector -> Logits
So the repair is not to make Softmax accept Vector. The repair is to
restore the stage that turns hidden features into vocabulary scores.
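The same bound check can be reproduced without running any model code. A minimal sketch with marker types standing in for the crate's objects, stand-in `CtError`/`CtResult`, and a hypothetical `assert_composable` helper whose where clause copies the one on `Compose`. The legal calls compile; uncommenting the last call fails because it requires `Softmax: Morphism<Vector, Distribution>`.

```rust
// Stand-in error and result types (assumed shapes).
#[derive(Debug)]
struct CtError;
type CtResult<T> = Result<T, CtError>;

trait Morphism<Input, Output> {
    fn name(&self) -> &'static str;
    fn apply(&self, input: Input) -> CtResult<Output>;
}

// Marker objects: only their types matter in this sketch.
struct TokenId;
struct Vector;
struct Logits;
struct Distribution;

struct Embedding;
struct LinearToLogits;
struct Softmax;

impl Morphism<TokenId, Vector> for Embedding {
    fn name(&self) -> &'static str { "embedding" }
    fn apply(&self, _: TokenId) -> CtResult<Vector> { Ok(Vector) }
}
impl Morphism<Vector, Logits> for LinearToLogits {
    fn name(&self) -> &'static str { "linear_to_logits" }
    fn apply(&self, _: Vector) -> CtResult<Logits> { Ok(Logits) }
}
impl Morphism<Logits, Distribution> for Softmax {
    fn name(&self) -> &'static str { "softmax" }
    fn apply(&self, _: Logits) -> CtResult<Distribution> { Ok(Distribution) }
}

// Hypothetical helper with the same bounds as
// `impl Morphism<Input, Output> for Compose<F, G, Middle>`.
fn assert_composable<F, G, Input, Middle, Output>()
where
    F: Morphism<Input, Middle>,
    G: Morphism<Middle, Output>,
{
}

fn main() {
    assert_composable::<Embedding, LinearToLogits, TokenId, Vector, Logits>();
    assert_composable::<LinearToLogits, Softmax, Vector, Logits, Distribution>();
    // Does not compile: requires `Softmax: Morphism<Vector, Distribution>`.
    // assert_composable::<Embedding, Softmax, TokenId, Vector, Distribution>();
    println!("legal compositions type-check");
}
```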
From Function To Morphism
An ordinary Rust function already has the outline:
Input -> Output
The course’s Morphism<Input, Output> trait adds three things to that outline.
It gives the transformation a stable name, makes failure explicit with
CtResult<Output>, and lets different transformation structs share one
composition API.
That is why this chapter uses the word “morphism” carefully. In this codebase, read it first as:
a named, fallible, typed transformation
Only after that concrete reading should you attach the category-theory word.
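Here is the earlier `add_one` function rewritten in that reading. A sketch assuming stand-in `CtError`/`CtResult` types with a hypothetical `Overflow` variant (the crate's real error enum may differ): the struct carries a stable name, and `i32::MAX` becomes a typed error instead of an overflow panic or wraparound.

```rust
#[derive(Debug, PartialEq)]
enum CtError {
    // Hypothetical variant used only for this illustration.
    Overflow { kind: &'static str, value: i32 },
}
type CtResult<T> = Result<T, CtError>;

trait Morphism<Input, Output> {
    fn name(&self) -> &'static str;
    fn apply(&self, input: Input) -> CtResult<Output>;
}

// The plain function: typed, but unnamed and infallible in its signature.
fn add_one(input: i32) -> i32 {
    input + 1
}

// The same arrow as a named, fallible, typed transformation.
struct AddOne;

impl Morphism<i32, i32> for AddOne {
    fn name(&self) -> &'static str {
        "add_one"
    }
    fn apply(&self, input: i32) -> CtResult<i32> {
        input
            .checked_add(1)
            .ok_or(CtError::Overflow { kind: "add_one input", value: input })
    }
}

fn main() {
    assert_eq!(add_one(41), 42);
    println!("{}: {:?}", AddOne.name(), AddOne.apply(41));
    println!("{}: {:?}", AddOne.name(), AddOne.apply(i32::MAX));
}
```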
Morphism<Input, Output>
The problem this block solves is:
The code needs one shared contract for typed transformations.
The block:
/// A typed category-theory arrow: `Input -> Output`.
pub trait Morphism<Input, Output> {
fn name(&self) -> &'static str;
fn apply(&self, input: Input) -> CtResult<Output>;
}
Rust Syntax: Documentation Comment
/// A typed category-theory arrow: `Input -> Output`.
This tells you how to read the trait.
For example:
Embedding : TokenId -> Vector
means:
impl Morphism<TokenId, Vector> for Embedding
Rust Syntax: Trait Definition
pub trait Morphism<Input, Output>
Input and Output are type parameters.
They are not values.
They describe the type-level shape of the arrow.
This allows the same trait to model:
TokenSequence -> TrainingSet
TokenId -> Vector
Vector -> Logits
Logits -> Distribution
Distribution x TokenId -> Loss
Parameters -> Parameters
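The `Distribution x TokenId -> Loss` line works because a product object can be spelled as a Rust tuple. A minimal sketch with stand-in newtypes and a hypothetical `CrossEntropy` arrow and `TargetOutOfRange` variant (the crate's real loss type and errors may differ):

```rust
#[derive(Debug, PartialEq)]
enum CtError {
    // Hypothetical variant for a target outside the vocabulary.
    TargetOutOfRange { target: usize, vocab: usize },
}
type CtResult<T> = Result<T, CtError>;

trait Morphism<Input, Output> {
    fn name(&self) -> &'static str;
    fn apply(&self, input: Input) -> CtResult<Output>;
}

#[derive(Debug, Clone, Copy)]
struct TokenId(usize);
#[derive(Debug, Clone)]
struct Distribution(Vec<f32>);
#[derive(Debug, Clone, Copy, PartialEq)]
struct Loss(f32);

struct CrossEntropy;

// Source object: the pair (Distribution, TokenId). Target object: Loss.
impl Morphism<(Distribution, TokenId), Loss> for CrossEntropy {
    fn name(&self) -> &'static str {
        "cross_entropy"
    }
    fn apply(&self, (dist, target): (Distribution, TokenId)) -> CtResult<Loss> {
        let p = dist.0.get(target.0).copied().ok_or(CtError::TargetOutOfRange {
            target: target.0,
            vocab: dist.0.len(),
        })?;
        Ok(Loss(-p.ln()))
    }
}

fn main() {
    let dist = Distribution(vec![0.25, 0.5, 0.25]);
    println!("{:?}", CrossEntropy.apply((dist.clone(), TokenId(1))));
    println!("{:?}", CrossEntropy.apply((dist, TokenId(7))));
}
```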
Rust Syntax: name
fn name(&self) -> &'static str;
This gives a stable human-readable name.
It is useful for demonstrations, diagnostics, and teaching.
The return type &'static str means the string is known for the whole program
lifetime. Names such as "softmax" and "embedding" are static literals.
Rust Syntax: apply
fn apply(&self, input: Input) -> CtResult<Output>;
This is the actual transformation.
It consumes an Input and returns either:
Ok(Output)
or:
Err(CtError)
This is important because many arrows can fail. Embedding can receive an out-of-range token, softmax can receive empty logits, cross entropy can receive an invalid target, and training can receive malformed parameters. The shared return type keeps those failures explicit instead of hiding them behind a panic.
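To see one of those failures concretely, here is a stand-in softmax arrow that rejects empty logits with a typed error. The `EmptyLogits` variant and the newtypes are assumptions for illustration; the crate's own `Softmax` and `CtError` may be shaped differently.

```rust
#[derive(Debug, PartialEq)]
enum CtError {
    EmptyLogits, // hypothetical variant
}
type CtResult<T> = Result<T, CtError>;

trait Morphism<Input, Output> {
    fn name(&self) -> &'static str;
    fn apply(&self, input: Input) -> CtResult<Output>;
}

#[derive(Debug, Clone)]
struct Logits(Vec<f32>);
#[derive(Debug, Clone, PartialEq)]
struct Distribution(Vec<f32>);

struct Softmax;

impl Morphism<Logits, Distribution> for Softmax {
    fn name(&self) -> &'static str {
        "softmax"
    }
    fn apply(&self, logits: Logits) -> CtResult<Distribution> {
        // An empty score vector has no valid distribution: fail explicitly.
        if logits.0.is_empty() {
            return Err(CtError::EmptyLogits);
        }
        // Subtract the max before exponentiating for numerical stability.
        let max = logits.0.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = logits.0.iter().map(|x| (x - max).exp()).collect();
        let sum: f32 = exps.iter().sum();
        Ok(Distribution(exps.into_iter().map(|e| e / sum).collect()))
    }
}

fn main() {
    println!("{:?}", Softmax.apply(Logits(vec![2.0, 1.0, 0.1])));
    println!("{:?}", Softmax.apply(Logits(vec![])));
}
```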
Read One Concrete Implementation
The abstract trait becomes concrete when a stage implements it. In src/ml.rs,
the embedding stage has this shape:
impl Morphism<TokenId, Vector> for Embedding {
fn name(&self) -> &'static str {
"embedding"
}
fn apply(&self, token: TokenId) -> CtResult<Vector> {
// lookup and validation happen here
}
}
Read the first line slowly:
Embedding is a morphism from TokenId to Vector.
The impl line, read together with the apply signature, gives the reader all four pieces requested by the public starter issue:
| Piece | In the code | Meaning |
|---|---|---|
| typed input | TokenId | a vocabulary position, not a raw integer |
| typed output | Vector | hidden features for that token |
| transformation | Embedding | table lookup from token to feature row |
| explicit failure | CtResult<Vector> | out-of-range tokens return an error |
The implementation does not say that every usize can become features. It says
that a validated TokenId can be applied to this embedding table, and the
result is either a Vector or a typed error.
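The elided body above is where that out-of-range check lives. One way to fill it in, as a sketch with a stand-in table layout (one row per vocabulary item) and a hypothetical `TokenOutOfRange` variant; this is not the code in src/ml.rs.

```rust
#[derive(Debug, PartialEq)]
enum CtError {
    // Hypothetical variant for a token index outside the table.
    TokenOutOfRange { index: usize, vocab: usize },
}
type CtResult<T> = Result<T, CtError>;

trait Morphism<Input, Output> {
    fn name(&self) -> &'static str;
    fn apply(&self, input: Input) -> CtResult<Output>;
}

#[derive(Debug, Clone, Copy)]
struct TokenId(usize);
#[derive(Debug, Clone, PartialEq)]
struct Vector(Vec<f32>);

// Assumed layout: one feature row per vocabulary item.
struct Embedding {
    table: Vec<Vec<f32>>,
}

impl Morphism<TokenId, Vector> for Embedding {
    fn name(&self) -> &'static str {
        "embedding"
    }
    fn apply(&self, token: TokenId) -> CtResult<Vector> {
        // `get` returns None for an out-of-range index instead of panicking.
        self.table
            .get(token.0)
            .map(|row| Vector(row.clone()))
            .ok_or(CtError::TokenOutOfRange { index: token.0, vocab: self.table.len() })
    }
}

fn main() {
    let embedding = Embedding { table: vec![vec![0.1, 0.2], vec![0.3, 0.4]] };
    println!("{:?}", embedding.apply(TokenId(1)));
    println!("{:?}", embedding.apply(TokenId(5)));
}
```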
ML Concept
Every ML stage becomes an implementation of the same contract.
That makes the pipeline inspectable as arrows, not just function calls.
Category Theory Concept
This trait is the course’s concrete model of a morphism.
It is not trying to implement all category theory. It gives enough structure to talk about typed arrows and composition in ordinary Rust.
Identity<T>
The problem this block solves is:
Every object should have an arrow that returns the object unchanged.
The block:
/// Identity morphism: `id_A : A -> A`.
#[derive(Debug, Clone, Copy)]
pub struct Identity<T> {
_marker: PhantomData<T>,
}
Rust Syntax: Why The Struct Has No Real Data
Identity<T> does not need to store a T.
It only needs to remember the type T.
That is why it stores:
_marker: PhantomData<T>
PhantomData<T> tells Rust:
This struct is logically connected to T, even though it does not own a real T value.
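One consequence is easy to check: `PhantomData<T>` is zero-sized, so an `Identity<T>` occupies no memory whatever `T` is. A small sketch that copies the `Identity<T>` shape from the snapshot:

```rust
use std::marker::PhantomData;
use std::mem::size_of;

// Same shape as the snapshot's Identity<T>.
#[derive(Debug, Clone, Copy)]
struct Identity<T> {
    _marker: PhantomData<T>,
}

impl<T> Identity<T> {
    fn new() -> Self {
        Self { _marker: PhantomData }
    }
}

fn main() {
    // Both sizes are zero: the struct remembers T only at the type level.
    println!("{}", size_of::<Identity<String>>());
    println!("{}", size_of::<Identity<[f32; 1024]>>());
    let _id_string: Identity<String> = Identity::new();
}
```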
Rust Syntax: Constructor
pub fn new() -> Self {
Self {
_marker: PhantomData,
}
}
This creates the identity arrow for a type.
Example:
Identity::<Vector>::new()
means:
id_Vector : Vector -> Vector
Rust Syntax: Default
impl<T> Default for Identity<T> {
fn default() -> Self {
Self::new()
}
}
This follows a Rust convention (Clippy's new_without_default lint checks it): a type with a public no-argument new constructor should usually also implement Default.
Rust Syntax: Morphism Implementation
impl<T> Morphism<T, T> for Identity<T> {
fn name(&self) -> &'static str {
"identity"
}
fn apply(&self, input: T) -> CtResult<T> {
Ok(input)
}
}
This is the key:
T -> T
The input and output type are the same.
The implementation simply returns the input.
ML Concept
Identity is a no-op transformation.
In a model pipeline, no-op stages are useful for tests and for understanding what it means for composition to have a neutral element.
Category Theory Concept
Identity matters because composition has laws:
id after f = f
f after id = f
This code does not prove those laws generally, but it gives the object you need to talk about them in Rust.
The tests in src/category.rs check the executable version of this idea:
composing identity on either side of a simple morphism leaves the behavior
unchanged.
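The repository test checks a single input (41). A sketch that checks both identity laws over a range of inputs, using trimmed copies of `Identity` and `Compose` (no `name` method) and a stand-in `CtResult`. It is still a finite sample, not a proof.

```rust
use std::marker::PhantomData;

// Stand-in error/result types (assumed shapes).
#[derive(Debug, PartialEq)]
struct CtError;
type CtResult<T> = Result<T, CtError>;

// Trimmed trait: `name` is omitted because only `apply` matters here.
trait Morphism<Input, Output> {
    fn apply(&self, input: Input) -> CtResult<Output>;
}

struct Identity<T>(PhantomData<T>);

impl<T> Morphism<T, T> for Identity<T> {
    fn apply(&self, input: T) -> CtResult<T> {
        Ok(input)
    }
}

struct Compose<F, G, Middle> {
    first: F,
    second: G,
    _middle: PhantomData<Middle>,
}

impl<Input, Middle, Output, F, G> Morphism<Input, Output> for Compose<F, G, Middle>
where
    F: Morphism<Input, Middle>,
    G: Morphism<Middle, Output>,
{
    fn apply(&self, input: Input) -> CtResult<Output> {
        let middle = self.first.apply(input)?;
        self.second.apply(middle)
    }
}

struct Double;

impl Morphism<i32, i32> for Double {
    fn apply(&self, input: i32) -> CtResult<i32> {
        Ok(input * 2)
    }
}

// Checks `id after f = f` and `f after id = f` on every sampled input.
fn identity_laws_hold(samples: std::ops::RangeInclusive<i32>) -> bool {
    let id_after_f = Compose { first: Double, second: Identity::<i32>(PhantomData), _middle: PhantomData::<i32> };
    let f_after_id = Compose { first: Identity::<i32>(PhantomData), second: Double, _middle: PhantomData::<i32> };
    samples.into_iter().all(|x| {
        let expected = Double.apply(x);
        id_after_f.apply(x) == expected && f_after_id.apply(x) == expected
    })
}

fn main() {
    println!("identity laws hold on -1000..=1000: {}", identity_laws_hold(-1000..=1000));
}
```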
Compose<F, G, Middle>
The problem this block solves is:
If one morphism produces the type another morphism consumes, the code should be able to build a larger morphism.
The block:
/// Composition of two morphisms: if `f : A -> B` and `g : B -> C`, this is
/// `g after f : A -> C`.
#[derive(Debug, Clone)]
pub struct Compose<F, G, Middle> {
first: F,
second: G,
_middle: PhantomData<Middle>,
}
Rust Syntax: The Shape
The category-theory shape is:
f : A -> B
g : B -> C
g after f : A -> C
The Rust type is:
Compose<F, G, Middle>
where:
- F is the first morphism,
- G is the second morphism,
- Middle is the bridge type.
The middle type is explicit because Rust needs to know what connects the two arrows.
This is the most important learner habit in the chapter: when composition feels abstract, look for the middle type.
Rust Syntax: Fields
first: F,
second: G,
_middle: PhantomData<Middle>,
first stores the first arrow.
second stores the second arrow.
_middle records the bridge type without storing a value of that type.
Rust Syntax: Constructor
pub fn new(first: F, second: G) -> Self
This builds the composed morphism.
It does not run the morphisms yet.
It only stores them.
Rust Syntax: Morphism Implementation
impl<Input, Middle, Output, F, G> Morphism<Input, Output>
for Compose<F, G, Middle>
where
F: Morphism<Input, Middle>,
G: Morphism<Middle, Output>,
{
fn name(&self) -> &'static str {
"composition"
}
fn apply(&self, input: Input) -> CtResult<Output> {
let middle = self.first.apply(input)?;
self.second.apply(middle)
}
}
This is the most important block in the chapter.
The where clause says:
F must be Input -> Middle
G must be Middle -> Output
Only then can Compose<F, G, Middle> be:
Input -> Output
Rust Syntax: The ? Operator
let middle = self.first.apply(input)?;
This applies the first arrow.
If it fails, the error returns immediately.
If it succeeds, the successful value is bound to middle.
Then the second arrow runs:
self.second.apply(middle)
So composition preserves failure.
It does not hide invalid states.
The category tests also check this behavior directly. A composed morphism that fails in its first step returns that error immediately instead of pretending the second step ran.
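A sketch of that short-circuit, with a hypothetical `Counted` arrow that records how often it runs through a `Cell<u32>`. `Compose` is a trimmed copy of the snapshot's adapter, `Fail` mirrors the test morphism, and `CtError` is a stand-in.

```rust
use std::cell::Cell;
use std::marker::PhantomData;

// Stand-in error/result types (assumed shapes).
#[derive(Debug, PartialEq)]
struct CtError(&'static str);
type CtResult<T> = Result<T, CtError>;

trait Morphism<Input, Output> {
    fn apply(&self, input: Input) -> CtResult<Output>;
}

struct Compose<F, G, Middle> {
    first: F,
    second: G,
    _middle: PhantomData<Middle>,
}

impl<Input, Middle, Output, F, G> Morphism<Input, Output> for Compose<F, G, Middle>
where
    F: Morphism<Input, Middle>,
    G: Morphism<Middle, Output>,
{
    fn apply(&self, input: Input) -> CtResult<Output> {
        // `?` returns the first arrow's error before the second arrow is called.
        let middle = self.first.apply(input)?;
        self.second.apply(middle)
    }
}

struct Fail;

impl Morphism<i32, i32> for Fail {
    fn apply(&self, _input: i32) -> CtResult<i32> {
        Err(CtError("test morphism"))
    }
}

// Hypothetical arrow: adds one and records how many times it was called.
struct Counted {
    calls: Cell<u32>,
}

impl Morphism<i32, i32> for Counted {
    fn apply(&self, input: i32) -> CtResult<i32> {
        self.calls.set(self.calls.get() + 1);
        Ok(input + 1)
    }
}

// Runs `Counted after Fail` once; returns the result and Counted's call count.
fn run_fail_then_counted() -> (CtResult<i32>, u32) {
    let composed = Compose {
        first: Fail,
        second: Counted { calls: Cell::new(0) },
        _middle: PhantomData::<i32>,
    };
    let result = composed.apply(4);
    (result, composed.second.calls.get())
}

fn main() {
    let (result, calls) = run_fail_then_counted();
    println!("result = {result:?}; second arrow ran {calls} time(s)");
}
```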
ML Concept
Prediction uses composition:
TokenId -> Vector -> Logits -> Distribution
The code builds that in two steps:
let token_to_logits = Compose::<_, _, Vector>::new(embedding, linear);
let token_to_distribution = Compose::<_, _, Logits>::new(token_to_logits, Softmax);
The bridge types are:
Vector
Logits
The legal diagram is:
TokenId
|
| Embedding
v
Vector
|
| LinearToLogits
v
Logits
|
| Softmax
v
Distribution
The important detail is not the vertical layout. The important detail is that every arrow’s output object is exactly the next arrow’s input object.
If you try to compose Embedding directly with Softmax, the middle type does
not match:
Embedding : TokenId -> Vector
Softmax : Logits -> Distribution
Vector is not Logits, so Rust rejects the composition.
This is the practical win. A diagram that skips LinearToLogits is not only
conceptually wrong; it has the wrong type boundary.
Category Theory Concept
Compose is function composition with types made explicit.
It is the course’s main example of:
small legal arrows -> larger legal arrow
The code is deliberately modest. It models enough composition to make the pipeline inspectable and testable; it is not claiming to encode every categorical law in Rust’s type system.
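Associativity is one of the laws the types do not encode, but it can be spot-checked. A sketch that uses plain closures and a hypothetical `compose` helper (not the crate's `Compose` struct) to compare `(h after g) after f` with `h after (g after f)` on sample inputs, including one input that makes the middle arrow fail. Agreement on samples is evidence, not proof.

```rust
// Stand-in error/result types (assumed shapes).
#[derive(Debug, PartialEq)]
struct CtError;
type CtResult<T> = Result<T, CtError>;

// Hypothetical helper: `g after f`, returning the first error it meets.
fn compose<A, B, C>(
    f: impl Fn(A) -> CtResult<B>,
    g: impl Fn(B) -> CtResult<C>,
) -> impl Fn(A) -> CtResult<C> {
    move |a| g(f(a)?)
}

fn associativity_holds(samples: &[i32]) -> bool {
    // Non-capturing closures are Copy, so each can be used in both groupings.
    let f = |x: i32| -> CtResult<i64> { Ok(i64::from(x) + 1) };
    let g = |x: i64| -> CtResult<i64> { if x == 0 { Err(CtError) } else { Ok(x * 2) } };
    let h = |x: i64| -> CtResult<String> { Ok(format!("value={x}")) };

    let left = compose(compose(f, g), h); // (h after g) after f
    let right = compose(f, compose(g, h)); // h after (g after f)
    samples.iter().all(|&x| left(x) == right(x))
}

fn main() {
    println!("associativity on samples: {}", associativity_holds(&[-5, -1, 0, 1, 42]));
}
```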
Endomorphism<T>
The problem this block solves is:
Some arrows start and end at the same type, and those arrows can be repeated.
The block:
/// Endomorphism: a morphism from a type back to itself.
pub trait Endomorphism<T>: Morphism<T, T> {}
impl<T, M> Endomorphism<T> for M where M: Morphism<T, T> {}
An endomorphism has shape:
T -> T
The trait has no methods of its own.
It is a marker trait:
if something implements Morphism<T, T>, it is an Endomorphism<T>
The blanket implementation says exactly that:
impl<T, M> Endomorphism<T> for M where M: Morphism<T, T> {}
ML Concept
Training has this shape:
Parameters -> Parameters
One training step consumes parameters and returns updated parameters.
The model changes, but the type stays the same.
Category Theory Concept
Endomorphisms are important because they can be iterated:
A -> A -> A -> A
That is the categorical shape of repeated training.
StepCount
The problem this block solves is:
Repetition count should have a semantic name instead of being a bare usize at the call site.
The block:
/// How many times to repeat an endomorphism.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StepCount(usize);
This wraps a raw usize.
It means:
number of repeated endomorphism applications
StepCount::new(80) reads better than a bare 80 because it names the role of
the number.
Rust Syntax
StepCount is a newtype around usize.
It has a constructor and a value() accessor.
ML Concept
It controls how many optimizer steps are applied.
Category Theory Concept
It controls how many times an endomorphism is iterated.
apply_endomorphism_n_times
The problem this block solves is:
Given an endomorphism, repeatedly apply it in a type-safe loop.
The block:
pub fn apply_endomorphism_n_times<T, E>(
endo: &E,
mut value: T,
count: StepCount,
) -> CtResult<T>
where
E: Endomorphism<T>,
{
for _ in 0..count.value() {
value = endo.apply(value)?;
}
Ok(value)
}
Rust Syntax: Type Parameters
T is the object being updated.
E is the endomorphism type.
The bound:
E: Endomorphism<T>
means:
E must be a T -> T arrow
Rust Syntax: Mutable Value
mut value: T
The function owns the current value.
Each loop iteration replaces it with the next value:
value = endo.apply(value)?;
This is not mutation of shared global state.
It is ownership passing through a repeated transformation.
Rust Syntax: Failure Behavior
If any application fails, the whole repeated process fails immediately.
This is the correct behavior for training too: if a step discovers invalid parameters or an out-of-range token, the loop should not pretend everything is fine.
ML Concept
For training:
T = Parameters
E = TrainStep
The function becomes:
repeat TrainStep on Parameters
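A toy instance of that shape, as a sketch: `Parameters` holds one weight, and the hypothetical `TrainStep` takes one gradient-descent step on the made-up loss (w - 3)^2 with learning rate 0.1. The real TrainStep updates model weights against cross entropy; only the `T -> T` shape and the loop carry over. `Endomorphism` and `apply_endomorphism_n_times` follow the snapshot; `StepCount` is simplified to a public tuple field, and `CtError` is a stand-in.

```rust
#[derive(Debug)]
enum CtError {
    NonFinite, // hypothetical variant for a diverged update
}
type CtResult<T> = Result<T, CtError>;

trait Morphism<Input, Output> {
    fn name(&self) -> &'static str;
    fn apply(&self, input: Input) -> CtResult<Output>;
}
trait Endomorphism<T>: Morphism<T, T> {}
impl<T, M> Endomorphism<T> for M where M: Morphism<T, T> {}

// Simplified: the snapshot keeps the field private behind new()/value().
#[derive(Debug, Clone, Copy)]
struct StepCount(usize);

fn apply_endomorphism_n_times<T, E>(endo: &E, mut value: T, count: StepCount) -> CtResult<T>
where
    E: Endomorphism<T>,
{
    for _ in 0..count.0 {
        value = endo.apply(value)?;
    }
    Ok(value)
}

#[derive(Debug, Clone, Copy, PartialEq)]
struct Parameters {
    w: f32,
}

struct TrainStep {
    learning_rate: f32,
}

impl Morphism<Parameters, Parameters> for TrainStep {
    fn name(&self) -> &'static str {
        "train_step"
    }
    fn apply(&self, params: Parameters) -> CtResult<Parameters> {
        // d/dw (w - 3)^2 = 2 (w - 3)
        let grad = 2.0 * (params.w - 3.0);
        let w = params.w - self.learning_rate * grad;
        if w.is_finite() { Ok(Parameters { w }) } else { Err(CtError::NonFinite) }
    }
}

fn main() -> CtResult<()> {
    let step = TrainStep { learning_rate: 0.1 };
    let trained = apply_endomorphism_n_times(&step, Parameters { w: 0.0 }, StepCount(80))?;
    println!("w after 80 steps = {:.6}", trained.w);
    Ok(())
}
```

Each step shrinks the distance to 3 by a factor of 0.8, so 80 steps bring w very close to 3; the type stays Parameters throughout.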
Category Theory Concept
This is iteration of an endomorphism:
value0
-> value1
-> value2
-> ...
-> valueN
Runnable Example
The composition example builds:
TokenId -> Vector -> Logits -> Distribution
Source snapshot: examples/02_morphism_composition.rs
use category_theory_transformer_rs::{
Compose, CtResult, Distribution, Embedding, LinearToLogits, Logits, ModelDimension, Morphism,
Parameters, Softmax, TokenId, Vector, VocabSize,
};
fn main() -> CtResult<()> {
let params = Parameters::init(VocabSize::new(5)?, ModelDimension::new(4)?);
let token = TokenId::new(1);
let embedding = Embedding::from_parameters(&params);
let linear = LinearToLogits::from_parameters(&params);
let token_to_logits = Compose::<_, _, Vector>::new(embedding.clone(), linear.clone());
let token_to_distribution = Compose::<_, _, Logits>::new(token_to_logits, Softmax);
let vector = embedding.apply(token)?;
let logits = linear.apply(vector.clone())?;
let distribution = Softmax.apply(logits.clone())?;
let composed_distribution = token_to_distribution.apply(token)?;
println!("Input object:");
println!("TokenId({})", token.index());
println!();
println!("Stage outputs:");
println!("Embedding : TokenId -> Vector");
println!("{}", format_vector(&vector));
println!("LinearToLogits : Vector -> Logits");
println!("{}", format_logits(&logits));
println!("Softmax : Logits -> Distribution");
println!("{}", format_distribution(&distribution));
println!();
println!("Composed morphism:");
println!("TokenId -> Distribution");
println!(
"next-token probabilities: {}",
format_values(composed_distribution.as_slice())
);
println!();
println!("Middle objects kept visible:");
println!("Vector");
println!("Logits");
println!();
println!("Composition rule:");
println!("first target must equal second source");
println!("Embedding then LinearToLogits is legal because Vector == Vector");
println!("Embedding then Softmax is illegal because Vector != Logits");
Ok(())
}
fn format_vector(vector: &Vector) -> String {
format!(
"Vector(dim={}, values={})",
vector.as_slice().len(),
format_values(vector.as_slice())
)
}
fn format_logits(logits: &Logits) -> String {
format!(
"Logits(vocab={}, values={})",
logits.as_slice().len(),
format_values(logits.as_slice())
)
}
fn format_distribution(distribution: &Distribution) -> String {
format!(
"Distribution(vocab={}, sum={:.6}, values={})",
distribution.as_slice().len(),
distribution.as_slice().iter().sum::<f32>(),
format_values(distribution.as_slice())
)
}
fn format_values(values: &[f32]) -> String {
let formatted = values
.iter()
.map(|value| format!("{value:.6}"))
.collect::<Vec<_>>()
.join(", ");
format!("[{formatted}]")
}
Run:
cargo run --example 02_morphism_composition
Expected shape:
Input object:
TokenId(1)
Stage outputs:
Embedding : TokenId -> Vector
Vector(dim=4, values=[...])
LinearToLogits : Vector -> Logits
Logits(vocab=5, values=[...])
Softmax : Logits -> Distribution
Distribution(vocab=5, sum=1.000000, values=[...])
Composed morphism:
TokenId -> Distribution
next-token probabilities: [...]
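Middle objects kept visible:
Vector
Logits
Composition rule:
first target must equal second source
Embedding then LinearToLogits is legal because Vector == Vector
Embedding then Softmax is illegal because Vector != Logits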
Example Output Transfer Checklist
The example prints stage outputs and then prints the composed arrow. Read that output as a composition report, not only as a numeric demo.
| Example output or code evidence | Rust reading | ML reading | Category-theory reading | Shortcut to reject |
|---|---|---|---|---|
| TokenId(1) | the input is a named object, not a bare index | choose one context token | source object | passing an unnamed row number through the pipeline |
| Embedding : TokenId -> Vector | Embedding implements Morphism<TokenId, Vector> | look up the token’s hidden feature row | arrow from source object to middle object | passing a token directly to projection |
| Vector(dim=4, values=[...]) | a Vector value exists before projection | hidden representation, not vocabulary scores | first middle object | treating features as logits |
| LinearToLogits : Vector -> Logits | LinearToLogits implements Morphism<Vector, Logits> | project hidden features into vocabulary scores | arrow between middle objects | sending a vector directly to Softmax |
| Logits(vocab=5, values=[...]) | unnormalized scores have their own type | one score per vocabulary item | second middle object | treating scores as probabilities |
| Softmax : Logits -> Distribution | Softmax implements Morphism<Logits, Distribution> | normalize scores into probabilities | arrow into the target object | computing loss before a probability object exists |
| Distribution(vocab=5, sum=1.000000, values=[...]) | constructor validation produced a distribution | next-token probabilities sum to one | target object | treating arbitrary floats as a probability distribution |
| Compose::<_, _, Vector> | Vector is the first bridge type | embedding must happen before projection | legal composition through a middle object | hiding the bridge type and guessing that stages fit |
| Compose::<_, _, Logits> | Logits is the second bridge type | projection must happen before softmax | legal composition through a middle object | forgetting that Softmax needs logits |
| TokenId -> Distribution | the composed value is a larger morphism | the prediction path is now one callable stage | composite arrow | thinking composition erases intermediate obligations |
This is the chapter’s most important transfer move. The user-facing output is compact:
next-token probabilities: [...]
The typed explanation is larger:
TokenId -> Vector -> Logits -> Distribution
A strong reader can connect both views. The numeric output tells you what the pipeline produced. The typed path tells you why the pipeline was legal.
The stage outputs also explain the ML meaning of the middle objects:
Vector = hidden features
Logits = vocabulary scores
Distribution = normalized next-token probabilities
The category-theory discipline is to keep those middle objects visible. A
composite arrow can be named TokenId -> Distribution, but the legal route is
still built from the two bridge objects Vector and Logits.
Why This API Is Good Design
The code does not make composition a loose runtime convention.
It puts composition into the type system.
That means the compiler checks the bridge type:
F output == G input
This is the core practical value of the category-theory framing in this repo.
It turns:
remember to wire the stages correctly
into:
make invalid wiring fail to compile
Core Mental Model
In Rust terms:
Morphism<Input, Output> = fallible typed transformation
Compose<F, G, Middle> = legal connection of two transformations
Endomorphism<T> = repeatable T -> T transformation
In ML terms:
small prediction stages compose into a model path
training is a repeatable update step
In category-theory terms:
objects are connected by arrows, arrows compose when their endpoints match
Checkpoint
Why does this composition compile:
TokenId -> Vector -> Logits
but this one does not:
TokenId -> Vector -> Distribution
A strong answer should mention that Softmax expects Logits, not Vector.
Where This Leaves Us
This chapter turned ordinary transformations into named arrows. Identity<T>
leaves a value unchanged, Compose<F, G, Middle> connects compatible arrows,
and Endomorphism<T> names the special case where the input and output object
are the same.
The next chapter, The Tiny ML Pipeline, fills those arrow shapes with concrete ML behavior: token windowing, embedding lookup, linear projection, softmax, and cross entropy.
Further Reading
Do not use these sources to make the word “morphism” sound larger. Use them to debug one concrete question:
what is the source object, target object, and middle object?
Start from the local Rust evidence:
Morphism<Input, Output>
Compose<F, G, Middle>
F: Morphism<Input, Middle>
G: Morphism<Middle, Output>
Embedding : TokenId -> Vector
LinearToLogits : Vector -> Logits
Softmax : Logits -> Distribution
Then read the sources in this order:
| Source | What to transfer back into this chapter | Local evidence to inspect |
|---|---|---|
| Rust Book: Generics | Generic parameters preserve relationships between input, middle, and output types. | Compose<F, G, Middle> |
| Rust Book: Traits | A trait defines the method signatures each implementation must provide. | trait Morphism<Input, Output> |
| Stanford Encyclopedia of Philosophy: Category Theory | A category needs objects, morphisms, identity morphisms, and composition, subject to the associativity and identity laws. | Identity<T>, Compose<F, G, Middle>, composition_applies_first_then_second |
| Seven Sketches | Objects, arrows, identity, and composition can be introduced through concrete applied examples. | Identity<T>, Compose<F, G, Middle> |
| Category Theory for Programming | Category-theory vocabulary can be connected to typed programming structure. | fn add_one(input: i32) -> i32, Morphism<Input, Output> |
After reading one external source, ask four questions:
- Which local boundary did it clarify?
- Which type relationship did it help protect?
- Which illegal composition does it help reject?
- Which command would you run to see the evidence?
For this chapter, the commands are:
cargo run --example 02_morphism_composition
cargo test category::tests --lib
cargo test ml::tests::composed_and_direct_prediction_match --lib
Use Glossary when a term becomes slippery. Use References when you want the full source list.
If a source does not help you explain why Embedding can compose with
LinearToLogits but not directly with Softmax, it has not transferred back
into the chapter yet.
Practice After This Chapter
Use Exercise 4 to intentionally break a composition and explain the missing middle type. This is the chapter’s most important transfer check: a type error should become evidence about the pipeline boundary.
Retrieval Practice
Recall
Recover the shape of the API before explaining the pipeline.
- What two methods must Morphism<Input, Output> provide?
- Which type in Compose<F, G, Middle> records the bridge between two arrows?
- What shape makes a morphism an endomorphism?
Explain
Use the middle object to explain why composition is legal or illegal.
- Why does Compose<F, G, Middle> require F: Morphism<Input, Middle> and G: Morphism<Middle, Output>?
- Why is TokenId -> Vector -> Distribution not a legal version of the prediction path?
- Why does composition return the first error instead of trying to run the second arrow?
Apply
Use the output from cargo run --example 02_morphism_composition as the
working path.
- Write the legal path from TokenId to Distribution, naming both middle objects.
- If you insert Identity<Vector> between Embedding and LinearToLogits, why should the behavior stay the same?
- If you try to repeat Embedding with apply_endomorphism_n_times, which shape rule blocks the attempt?
Debug
For each invalid shortcut, name the missing or mismatched middle type:
Embedding followed directly by Softmax
Embedding followed by Identity<TokenId>
repeating Embedding as an endomorphism
A strong answer should identify the source and target object of each arrow, then state which object fails to line up. Do not answer only with “the compiler rejects it”; explain the typed boundary the compiler is protecting.
The Tiny ML Pipeline
The problem this chapter solves is:
The abstract Morphism trait needs concrete machine-learning arrows that turn token data into predictions and loss.
The whole prediction-and-loss path is:
TokenSequence -> TrainingSet
TokenId -> Vector
Vector -> Logits
Logits -> Distribution
Distribution x TokenId -> Loss
In ordinary ML language, the path turns a token stream into adjacent training pairs, looks up an embedding vector for the current token, uses a linear layer to score every possible next token, normalizes those scores with softmax, and then measures surprise with cross entropy.
In category-theory language:
Each stage is a morphism, and the legal stages compose.
Reader orientation: This is the first chapter where all three subjects meet at once. When the code feels dense, follow the pipeline order: data preparation first, prediction second, loss third.
First Mental Model
The public shorthand for the whole project is:
Text -> Tokens -> TrainingPairs -> ModelState -> Prediction -> Loss -> Updated ModelState
This chapter zooms into the middle of that path. It explains how token pairs, current parameters, predictions, and loss become separate typed boundaries.
flowchart LR
A["Text"] --> B["Tokens / TokenSequence"]
B --> C["TrainingPairs / TrainingSet"]
C --> F["Loss"]
M["ModelState / Parameters"] --> P["Prediction / Distribution"]
P --> F
M --> U["Updated ModelState / Updated Parameters"]
F --> U
Read the diagram as orientation, then use the Rust types for precision. The loss boundary needs both current model state and training data:
Parameters x TrainingSet -> Loss
The update boundary returns a complete next model state:
Parameters -> Updated Parameters
The same orientation as a compact rendered math view:
\[
\begin{array}{ccccccc}
\mathrm{Text} & \to & \mathrm{TokenSequence} & \to & \mathrm{TrainingSet} & \to & \mathrm{Loss} \\
&&&& \uparrow && \downarrow \\
&&&& \mathrm{Parameters} & \to & \mathrm{UpdatedParameters}
\end{array}
\]
How to read this diagram:
- the top row is the data path from text into measured loss,
- Parameters enters the prediction and loss boundary as model state,
- the returned object is a complete updated parameter object,
- the diagram is a map of responsibilities; the later sections name the exact Rust functions that own each arrow.
Chapter Outcomes
By the end of this chapter, you should be able to:
- trace TokenId -> Vector -> Logits -> Distribution through the concrete Rust morphisms,
- explain why cross entropy consumes both a prediction and the target token,
- distinguish the production shortcut CrossEntropyLoss(logits, target) from this book’s explicit Logits -> Distribution -> Product<Distribution, TokenId> -> Loss teaching path.
What You Already Know
If you know ML, you already know the rough path: prepare data, make a prediction, and measure the error. If you know Rust, you already know that each step can have a concrete input and output type. This chapter combines those two habits by making each ML step implement the same morphism interface.
Prediction Trace Before Source
Before reading src/ml.rs, keep this trace in view. It separates raw scores,
probabilities, the target token, and the final loss.
| Stage | Rust type | Plain meaning | What to check |
|---|---|---|---|
| Input token | TokenId | the current token position | Is this a vocabulary index, not a dimension? |
| Embedding | Vector | dense hidden features for that token | Has the token become numeric features? |
| Scores | Logits | unnormalized next-token scores | Are these still raw scores, not probabilities? |
| Normalization | Distribution | probabilities over next tokens | Do probabilities sum to one? |
| Target pairing | Product<Distribution, TokenId> | prediction plus correct next token | Which index is the target token? |
| Loss | Loss | surprise assigned to the target token | Did cross entropy use the target probability? |
The target index is the key detail. With a class-index target, cross entropy does not look at every probability in the distribution. It first selects the probability assigned to the correct next token, then computes:
loss = -ln(probability assigned to target)
So this chapter’s core mental model is:
Logits
-> Distribution
-> target probability
-> Loss
That order matters. If a reader treats logits as probabilities, or forgets that loss uses the target token index, the pipeline becomes hard to debug.
Source Snapshot
This file owns the concrete ML arrows.
Source snapshot: src/ml.rs
use crate::category::{Compose, Morphism};
use crate::domain::{
Distribution, Logits, Loss, Parameters, Product, TokenId, TokenSequence, TrainingSet, Vector,
approx_eq,
};
use crate::error::{CtError, CtResult};
/// Turns adjacent tokens into next-token training examples.
#[derive(Debug, Clone)]
pub struct DatasetWindowing;
impl Morphism<TokenSequence, TrainingSet> for DatasetWindowing {
fn name(&self) -> &'static str {
"dataset_windowing"
}
fn apply(&self, tokens: TokenSequence) -> CtResult<TrainingSet> {
if tokens.as_slice().len() < 2 {
return Err(CtError::EmptyInput(
"dataset windowing requires at least 2 tokens",
));
}
TrainingSet::new(
tokens
.as_slice()
.windows(2)
.map(|pair| Product::new(pair[0], pair[1])),
)
}
}
/// Morphism from token id to embedding vector.
#[derive(Debug, Clone)]
pub struct Embedding {
table: Vec<Vec<f32>>,
}
impl Embedding {
pub fn from_parameters(params: &Parameters) -> Self {
Self {
table: params.embedding_table().to_vec(),
}
}
}
impl Morphism<TokenId, Vector> for Embedding {
fn name(&self) -> &'static str {
"embedding"
}
fn apply(&self, token: TokenId) -> CtResult<Vector> {
let Some(row) = self.table.get(token.index()) else {
return Err(CtError::OutOfRange {
kind: "token",
index: token.index(),
limit: self.table.len(),
});
};
Ok(Vector::new(row.clone()))
}
}
/// Linear projection from hidden vector to vocabulary logits.
#[derive(Debug, Clone)]
pub struct LinearToLogits {
weight: Vec<Vec<f32>>,
bias: Vec<f32>,
}
impl LinearToLogits {
pub fn from_parameters(params: &Parameters) -> Self {
Self {
weight: params.lm_head().to_vec(),
bias: params.bias().to_vec(),
}
}
pub(crate) fn from_parts(weight: Vec<Vec<f32>>, bias: Vec<f32>) -> Self {
Self { weight, bias }
}
}
impl Morphism<Vector, Logits> for LinearToLogits {
fn name(&self) -> &'static str {
"linear_to_logits"
}
fn apply(&self, input: Vector) -> CtResult<Logits> {
let d_model = input.as_slice().len();
let vocab_size = self.bias.len();
if self.weight.len() != d_model {
return Err(CtError::ShapeMismatch {
op: "linear layer",
expected: format!("weight rows == input dim {d_model}"),
got: format!("weight rows {}", self.weight.len()),
});
}
let mut out = self.bias.clone();
for (feature, input_value) in input.as_slice().iter().enumerate() {
if self.weight[feature].len() != vocab_size {
return Err(CtError::ShapeMismatch {
op: "linear layer",
expected: format!("weight cols == vocab size {vocab_size}"),
got: format!("weight cols {}", self.weight[feature].len()),
});
}
for (vocab_id, output_value) in out.iter_mut().enumerate() {
*output_value += input_value * self.weight[feature][vocab_id];
}
}
Ok(Logits::new(out))
}
}
/// Converts logits to a probability distribution.
#[derive(Debug, Clone)]
pub struct Softmax;
impl Morphism<Logits, Distribution> for Softmax {
fn name(&self) -> &'static str {
"softmax"
}
fn apply(&self, logits: Logits) -> CtResult<Distribution> {
if logits.as_slice().is_empty() {
return Err(CtError::EmptyInput("softmax"));
}
let max_value = logits
.as_slice()
.iter()
.copied()
.fold(f32::NEG_INFINITY, f32::max);
let mut exps = Vec::with_capacity(logits.as_slice().len());
let mut sum = 0.0;
for value in logits.as_slice() {
let exp = (*value - max_value).exp();
exps.push(exp);
sum += exp;
}
if sum <= 0.0 || !sum.is_finite() {
return Err(CtError::InvalidProbability("softmax"));
}
Distribution::new(exps.into_iter().map(|value| value / sum).collect())
}
}
/// Negative log likelihood for `(distribution, target_token)`.
#[derive(Debug, Clone)]
pub struct CrossEntropy;
impl Morphism<Product<Distribution, TokenId>, Loss> for CrossEntropy {
fn name(&self) -> &'static str {
"cross_entropy"
}
fn apply(&self, input: Product<Distribution, TokenId>) -> CtResult<Loss> {
let (distribution, target) = input.into_parts();
let Some(probability) = distribution.as_slice().get(target.index()).copied() else {
return Err(CtError::OutOfRange {
kind: "target",
index: target.index(),
limit: distribution.as_slice().len(),
});
};
Loss::new(-probability.max(1e-9).ln())
}
}
/// Direct path used for a commutative-diagram check.
#[derive(Debug, Clone)]
pub struct DirectPredict {
params: Parameters,
}
impl DirectPredict {
pub fn new(params: Parameters) -> Self {
Self { params }
}
}
impl Morphism<TokenId, Distribution> for DirectPredict {
fn name(&self) -> &'static str {
"direct_predict"
}
fn apply(&self, token: TokenId) -> CtResult<Distribution> {
let embedding = Embedding::from_parameters(&self.params).apply(token)?;
let logits = LinearToLogits::from_parameters(&self.params).apply(embedding)?;
Softmax.apply(logits)
}
}
/// Average cross-entropy over a training set.
pub fn average_loss(params: &Parameters, dataset: &TrainingSet) -> CtResult<Loss> {
let embedding = Embedding::from_parameters(params);
let linear = LinearToLogits::from_parameters(params);
let predict = Compose::<_, _, Vector>::new(embedding, linear);
let predict = Compose::<_, _, Logits>::new(predict, Softmax);
let loss_fn = CrossEntropy;
let mut total = 0.0;
for example in dataset.examples() {
let distribution = predict.apply(*example.first())?;
let loss = loss_fn.apply(Product::new(distribution, *example.second()))?;
total += loss.value();
}
Loss::new(total / dataset.len() as f32)
}
/// Verifies that the composed path and direct path produce the same result.
pub fn composed_prediction_matches_direct_prediction(params: &Parameters) -> CtResult<bool> {
let token = TokenId::new(1);
let composed = Compose::<_, _, Vector>::new(
Embedding::from_parameters(params),
LinearToLogits::from_parameters(params),
);
let composed = Compose::<_, _, Logits>::new(composed, Softmax);
let direct = DirectPredict::new(params.clone());
let left_path = composed.apply(token)?;
let right_path = direct.apply(token)?;
Ok(left_path
.as_slice()
.iter()
.zip(right_path.as_slice().iter())
.all(|(a, b)| approx_eq(*a, *b, 1e-6)))
}
#[cfg(test)]
mod tests {
use super::*;
use crate::domain::{ModelDimension, VocabSize};
#[test]
fn dataset_windowing_builds_adjacent_pairs() -> CtResult<()> {
let tokens = TokenSequence::from_indices([1, 2, 3])?;
let dataset = DatasetWindowing.apply(tokens)?;
assert_eq!(dataset.len(), 2);
assert_eq!(dataset.examples()[0].first().index(), 1);
assert_eq!(dataset.examples()[0].second().index(), 2);
Ok(())
}
#[test]
fn composed_and_direct_prediction_match() -> CtResult<()> {
let params = Parameters::init(VocabSize::new(5)?, ModelDimension::new(4)?);
assert!(composed_prediction_matches_direct_prediction(&params)?);
Ok(())
}
#[test]
fn softmax_normalizes_logits_into_distribution() -> CtResult<()> {
let distribution = Softmax.apply(Logits::new(vec![1.0, 2.0, 3.0]))?;
let probabilities = distribution.as_slice();
let sum: f32 = probabilities.iter().sum();
assert!(approx_eq(sum, 1.0, 1e-6));
assert!(probabilities[2] > probabilities[1]);
assert!(probabilities[1] > probabilities[0]);
Ok(())
}
#[test]
fn cross_entropy_is_lower_for_more_confident_target_probability() -> CtResult<()> {
let confident = Distribution::new(vec![0.9, 0.1])?;
let surprised = Distribution::new(vec![0.1, 0.9])?;
let confident_loss = CrossEntropy.apply(Product::new(confident, TokenId::new(0)))?;
let surprised_loss = CrossEntropy.apply(Product::new(surprised, TokenId::new(0)))?;
assert!(confident_loss.value() < surprised_loss.value());
Ok(())
}
}
The Whole File
src/ml.rs defines:
DatasetWindowing
Embedding
LinearToLogits
Softmax
CrossEntropy
DirectPredict
average_loss
composed_prediction_matches_direct_prediction
The chapter reads them in pipeline order.
Read each block through the same three lenses:
Rust syntax:
what struct, trait implementation, loop, or error branch does the code use?
ML concept:
which prediction, loss, or data-preparation step does the block implement?
Category theory concept:
which object, product, morphism, composition, or commutative check appears?
Worked Example: Normalizing Scores
The smallest first-principles version of “normalize scores into probabilities” does not need a model yet:
#![allow(unused)]
fn main() {
let scores = [1.0_f32, 2.0, 3.0];
let total: f32 = scores.iter().sum();
let probabilities: Vec<f32> = scores.iter().map(|score| score / total).collect();
let probability_sum: f32 = probabilities.iter().sum();
assert!((probability_sum - 1.0).abs() < 1e-6);
}
The real Softmax implementation is more careful than this toy normalization:
it uses exponentials, subtracts the maximum score for numerical stability, and
validates the result through Distribution::new.
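The gap between the toy version and the real one shows up as soon as a score is negative. Below is a standalone sketch on plain f32 slices, not the crate's Logits and Distribution types, comparing the toy "divide by the sum" rule with a max-shifted softmax:

```rust
// Toy "divide by the sum" normalization: only safe for positive scores.
fn naive_normalize(scores: &[f32]) -> Vec<f32> {
    let total: f32 = scores.iter().sum();
    scores.iter().map(|s| s / total).collect()
}

// Max-shifted softmax, following the same steps as the book's Softmax::apply
// but without the Distribution validation.
fn softmax(scores: &[f32]) -> Vec<f32> {
    let max = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

fn main() {
    let scores = [-1.0_f32, 0.0, 1.0];
    // The scores sum to 0.0, so the toy rule divides by zero: -inf, NaN, +inf.
    println!("naive:   {:?}", naive_normalize(&scores));
    println!("softmax: {:?}", softmax(&scores));
}
```

Exponentiation makes every term positive before dividing, so the sum can never be zero or negative for finite scores.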
Self-Check
Why is it useful for the probability-validation boundary to live in
Distribution::new instead of in every caller that uses probabilities?
Scores, Probabilities, And Loss
This chapter becomes easier if you keep three numbers separate.
Logits are raw scores. They can be negative, larger than one, and they do not
need to sum to one. A logit says “how strongly the model scores this token
before normalization.”
Distribution values are probabilities. They must be finite, non-negative, and
sum to one. A distribution says “how much probability the model assigns to each
possible next token.”
That is still a local model probability, not a promise that the model’s
confidence is calibrated in the outside world. Calibration asks whether events
predicted with about 0.90 confidence really happen about ninety percent of
the time over a population of predictions. This tiny chapter only builds and
validates the normalized probability object.
Loss is a scalar penalty. Cross entropy makes the penalty small when the
model assigns high probability to the correct token and large when it assigns
low probability to the correct token.
The concrete path is:
raw scores
-> probabilities
-> surprise about the target
Here is one small numeric trace:
target token index: 0
probability assigned to target: 0.90
loss = -ln(0.90) = 0.105
target token index: 0
probability assigned to target: 0.10
loss = -ln(0.10) = 2.303
Nothing mysterious happened. The loss only looked at the probability assigned to the correct target token. A confident correct prediction receives a small penalty. A surprised prediction receives a larger penalty.
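The same trace can be reproduced in a few lines. The helper below is a hypothetical standalone function over a plain slice; it mirrors the clamp-then-log step of the chapter's CrossEntropy but skips the Product and Loss wrappers:

```rust
// Hypothetical helper: -ln(probability of the target), clamped like CrossEntropy
// so that a zero probability cannot produce an infinite loss.
fn cross_entropy(probabilities: &[f32], target: usize) -> Option<f32> {
    probabilities
        .get(target)
        .copied()
        .map(|p| -p.max(1e-9).ln())
}

fn main() {
    let confident = [0.90_f32, 0.10];
    let surprised = [0.10_f32, 0.90];
    println!("{:?}", cross_entropy(&confident, 0)); // about 0.105
    println!("{:?}", cross_entropy(&surprised, 0)); // about 2.303
    println!("{:?}", cross_entropy(&confident, 5)); // None: target out of range
}
```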
Worked Example: Do Not Use The Largest Probability
A common mistake is to compute loss from the largest probability in the distribution.
That is wrong.
Cross entropy uses the probability assigned to the actual target token, even when the model assigned a larger probability to some other token.
Consider this prediction:
probabilities over next tokens:
index 0: 0.60
index 1: 0.30
index 2: 0.10
target token index: 1
The largest probability is 0.60, but it belongs to token index 0.
The target probability is 0.30, because the correct next token is index 1.
So the loss is:
loss = -ln(0.30) = 1.204
The incorrect shortcut would be:
loss = -ln(0.60) = 0.511
That shortcut would make the prediction look better than it is. It rewards the model for being confident about the wrong token.
The Rust code prevents that confusion by pairing the distribution with the target:
Product<Distribution, TokenId>
Then CrossEntropy indexes into the distribution with target.index(). The
target decides which probability becomes the loss.
The Rust path is:
Logits -> Distribution
Distribution x TokenId -> Loss
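A short sketch makes the difference between the two computations concrete. Both helpers are hypothetical and operate on plain slices; only target_loss matches what CrossEntropy does:

```rust
// Correct: the target index chooses which probability is scored.
fn target_loss(probabilities: &[f32], target: usize) -> Option<f32> {
    probabilities.get(target).copied().map(|p| -p.max(1e-9).ln())
}

// Incorrect shortcut, shown only for comparison: it ignores the target entirely.
fn largest_probability_loss(probabilities: &[f32]) -> Option<f32> {
    probabilities
        .iter()
        .copied()
        .reduce(f32::max)
        .map(|p| -p.max(1e-9).ln())
}

fn main() {
    let probabilities = [0.60_f32, 0.30, 0.10];
    let target = 1;
    println!("target loss:   {:?}", target_loss(&probabilities, target)); // about 1.204
    println!("shortcut loss: {:?}", largest_probability_loss(&probabilities)); // about 0.511
}
```

The shortcut's signature has no target parameter at all, which is the clearest sign that it cannot be a supervised loss.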
Framework Shortcut, Teaching Boundary
PyTorch’s CrossEntropyLoss accepts unnormalized logits and a target class
index or target probabilities. That production API is efficient and ergonomic:
the framework can combine normalization, target selection, reduction, and
gradient behavior behind one call.
This book splits the same idea into smaller objects:
Logits -> Distribution -> Product<Distribution, TokenId> -> Loss
That split is not a claim that production frameworks are wrong. It is a teaching boundary. It makes two questions visible before the code becomes compact:
which boundary turns scores into probabilities?
which target index selects the probability used by loss?
| Production API habit | Tiny Rust teaching boundary |
|---|---|
| CrossEntropyLoss(logits, target_index) | Logits -> Distribution, then Distribution x TokenId -> Loss |
| logits and target passed together | probability invariant and target selection are separate |
| optimized fused behavior may hide the intermediate probability object | reader can inspect the Distribution constructor and the target probability |
When moving back to frameworks, remember that the compact API still owns both roles: score normalization and target-conditioned loss.
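The compact form can compute -log_softmax(logits)[target] as logsumexp(logits) - logits[target] without ever building a distribution. The following standalone sketch on plain f32 slices (not PyTorch code and not the crate's API) checks numerically that this agrees with the book's expanded path:

```rust
// Expanded path: Logits -> Distribution -> target probability -> Loss.
// Indexes directly, so an out-of-range target panics; CrossEntropy returns an error instead.
fn expanded_loss(logits: &[f32], target: usize) -> f32 {
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|z| (z - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    let probability = exps[target] / sum;
    -probability.max(1e-9).ln()
}

// Compact path: logsumexp(logits) - logits[target], using the max-shifted logsumexp.
// No probability vector is materialized, so there is nothing to inspect in between.
fn fused_loss(logits: &[f32], target: usize) -> f32 {
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let log_sum_exp = max + logits.iter().map(|z| (z - max).exp()).sum::<f32>().ln();
    log_sum_exp - logits[target]
}

fn main() {
    let logits = [1.0_f32, 2.0, 3.0];
    for target in 0..logits.len() {
        println!(
            "target {target}: expanded {:.6}, fused {:.6}",
            expanded_loss(&logits, target),
            fused_loss(&logits, target)
        );
    }
}
```

The two agree up to floating-point rounding; the expanded version trades a little work for an inspectable intermediate probability.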
Target-Probability Responsibility Ledger
This chapter’s most important debugging habit is to keep responsibility in the right place. Each boundary owns one job.
| Pipeline cue | Rust handle | ML responsibility | Category boundary | Unsafe shortcut rejected | Source-backed limit |
|---|---|---|---|---|---|
| raw vocabulary scores | LinearToLogits : Vector -> Logits | produce one unnormalized score per token | Vector -> Logits | treating logits as probabilities | this is a tiny linear projection, not a full classifier stack |
| normalized probabilities | Softmax : Logits -> Distribution and Distribution::new | exponentiate, normalize, and validate a probability vector | Logits -> Distribution | skipping the probability invariant | normalized probability is not calibrated confidence |
| target probability | target.index() and distribution.as_slice().get(...) | select the probability assigned to the correct next token | part of Distribution x TokenId -> Loss | using the largest probability | this checks supervised class-index loss, not every target encoding |
| scalar surprise | Loss::new(-probability.max(1e-9).ln()) | turn the target probability into a non-negative penalty | Distribution x TokenId -> Loss | hiding target selection inside a vague loss word | this is the expanded teaching path, not a fused production kernel |
Use this audit card whenever the loss boundary feels slippery:
pipeline cue:
Rust handle:
ML responsibility:
category boundary:
unsafe shortcut rejected:
source-backed limit:
validation command:
Worked audit:
pipeline cue: target probability
Rust handle: distribution.as_slice().get(target.index())
ML responsibility: select the probability assigned to the correct next token
category boundary: CrossEntropy : Distribution x TokenId -> Loss
unsafe shortcut rejected: using the largest probability
source-backed limit: this checks one local supervised classification boundary,
not calibration and not full framework equivalence
validation command:
cargo test cross_entropy_is_lower_for_more_confident_target_probability --lib
The phrase “probability assigned to the target” should now point to one line of Rust, one ML responsibility, and one category-shaped boundary.
Source-Backed Precision Rules
This chapter uses external sources to keep the tiny prediction-and-loss path honest. Each source supports a limited claim; these citations are not proof that this crate is a production classifier, a calibrated probability model, or a framework replacement.
| Source | What the source supports | Local rule in this chapter | Rust evidence |
|---|---|---|---|
| D2L Softmax Regression | A classifier needs one output per class; softmax turns raw outputs into non-negative probabilities that sum to one. | Logits are raw scores; Softmax is the only boundary that creates a Distribution. | LinearToLogits : Vector -> Logits, Softmax : Logits -> Distribution |
| D2L Softmax From Scratch | Implementing softmax explicitly makes normalization and probability sums visible, and cross entropy selects the probability assigned to the true label. | The local teaching path exposes Distribution before loss so readers can inspect normalization and target selection separately. | Distribution::new, CrossEntropy, target.index() |
| Accurate Computation of the Log-Sum-Exp and Softmax Functions | Softmax and log-sum-exp evaluation can overflow or underflow, and shifted formulas are used to improve floating-point behavior. | Subtract the maximum logit before exponentiation, but keep the local claim to numerical stability of this boundary, not full production numerical analysis. | let max_value = ..., let exp = (*value - max_value).exp() |
| On Calibration of Modern Neural Networks | Confidence calibration asks whether predicted probabilities match empirical correctness frequencies. | A Distribution is a normalized local model output; it is not a guarantee of calibrated confidence. | Distribution::new, softmax_normalizes_logits_into_distribution |
| PyTorch CrossEntropyLoss | The production API accepts unnormalized logits and target class indices, and internally corresponds to log-softmax plus negative log likelihood. | The book deliberately expands that compact API into Logits -> Distribution -> Product<Distribution, TokenId> -> Loss. | Product<Distribution, TokenId>, CrossEntropy.apply |
| CS231n Linear Classification | Softmax treats class scores as unnormalized log probabilities and cross entropy penalizes the probability assigned to the correct class. | Do not compute loss from the largest probability; compute it from the target token’s probability. | cross_entropy_is_lower_for_more_confident_target_probability |
The transfer pattern is:
source claim -> local typed boundary -> validation command or test
For this chapter, that means reading cargo test ml::tests and the
src/ml.rs morphisms as evidence for the tiny
Logits -> Distribution -> Product<Distribution, TokenId> -> Loss boundary,
not as evidence for every production classification stack.
The tests in src/ml.rs protect those claims: softmax normalizes logits into a
distribution, and cross entropy is lower when the target token receives higher
probability.
Here is the chapter’s full data-preparation and prediction diagram:
TokenSequence
    |
    | DatasetWindowing
    v
TrainingSet = [
    Product<TokenId, TokenId>,
    Product<TokenId, TokenId>,
    ...
]
For each TrainingExample:
input TokenId --------------------------+
    |                                   |
    | Embedding                         |
    v                                   |
Vector                           target TokenId
    |                                   |
    | LinearToLogits                    |
    v                                   |
Logits                                  |
    |                                   |
    | Softmax                           |
    v                                   |
Distribution ---------- Product --------+
                           |
                           | CrossEntropy
                           v
                          Loss
The left side is the prediction path. The right side carries the target token.
CrossEntropy is the first stage that needs both, so the chapter uses
Product<Distribution, TokenId> at that boundary.
The loss boundary as a rendered math view:
\[
\begin{array}{ccccccc}
\mathrm{TokenId} & \xrightarrow{\mathrm{Embedding}} & \mathrm{Vector} & \xrightarrow{\mathrm{LinearToLogits}} & \mathrm{Logits} & \xrightarrow{\mathrm{Softmax}} & \mathrm{Distribution} \\
&&&&&& \downarrow \mathrm{Product}(-,\ \mathrm{target}) \\
&&&&&& \mathrm{Product}\langle \mathrm{Distribution}, \mathrm{TokenId}\rangle \xrightarrow{\mathrm{CrossEntropy}} \mathrm{Loss}
\end{array}
\]
How to read this diagram:
- the prediction path produces a Distribution,
- the target token does not become a prediction; it selects which probability becomes the loss,
- CrossEntropy is the first arrow that needs the product input,
- redrawing the diagram should make the target side visible, not hidden inside the word “loss”.
DatasetWindowing
The problem this block solves is:
A token sequence must become input-target pairs before supervised next-token training can happen.
The block:
/// Turns adjacent tokens into next-token training examples.
#[derive(Debug, Clone)]
pub struct DatasetWindowing;
impl Morphism<TokenSequence, TrainingSet> for DatasetWindowing {
fn name(&self) -> &'static str {
"dataset_windowing"
}
fn apply(&self, tokens: TokenSequence) -> CtResult<TrainingSet> {
if tokens.as_slice().len() < 2 {
return Err(CtError::EmptyInput(
"dataset windowing requires at least 2 tokens",
));
}
TrainingSet::new(
tokens
.as_slice()
.windows(2)
.map(|pair| Product::new(pair[0], pair[1])),
)
}
}
Rust Syntax: Unit Struct
pub struct DatasetWindowing;
This is a unit struct.
It stores no fields because the operation has no configuration.
The value itself represents the transformation.
Rust Syntax: Morphism Shape
impl Morphism<TokenSequence, TrainingSet> for DatasetWindowing
This says:
DatasetWindowing : TokenSequence -> TrainingSet
So it consumes the raw sequence stage and produces the training-example stage.
Rust Syntax: Why It Requires At Least Two Tokens
if tokens.as_slice().len() < 2 {
return Err(CtError::EmptyInput(
"dataset windowing requires at least 2 tokens",
));
}
TokenSequence only guarantees at least one token.
But next-token training requires at least one adjacent pair.
One token:
[7]
produces zero pairs.
Two tokens:
[7, 8]
produce one pair:
7 -> 8
So this morphism owns the stronger validation rule.
Rust Syntax: windows(2)
tokens.as_slice().windows(2)
This walks adjacent pairs:
[1, 2, 3, 4]
becomes:
[1, 2]
[2, 3]
[3, 4]
Each pair becomes:
Product::new(pair[0], pair[1])
That is a TrainingExample.
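The windowing rule can be exercised without the crate. A minimal sketch, assuming a hypothetical Pair struct in place of Product<TokenId, TokenId> and a String error in place of CtError:

```rust
// Hypothetical stand-in for Product<TokenId, TokenId>: (current token, next token).
#[derive(Debug, Clone, Copy, PartialEq)]
struct Pair {
    input: usize,
    target: usize,
}

fn window_pairs(tokens: &[usize]) -> Result<Vec<Pair>, String> {
    // Same rule as DatasetWindowing: a single token yields no adjacent pair.
    if tokens.len() < 2 {
        return Err("dataset windowing requires at least 2 tokens".to_string());
    }
    Ok(tokens
        .windows(2)
        .map(|pair| Pair { input: pair[0], target: pair[1] })
        .collect())
}

fn main() {
    println!("{:?}", window_pairs(&[1, 2, 3, 4]));
    println!("{:?}", window_pairs(&[7]));
}
```

A sequence of n tokens yields n - 1 pairs, which is why the length check has to reject n = 1 before windowing.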
ML Concept
This is the data-preparation step for next-token prediction.
Category Theory Concept
This is a morphism between two structured objects:
non-empty token list -> non-empty product list
The output examples are product objects:
TokenId x TokenId
Embedding
The problem this block solves is:
A discrete token ID needs to become a dense vector before the model can use linear algebra.
The core block:
#[derive(Debug, Clone)]
pub struct Embedding {
table: Vec<Vec<f32>>,
}
impl Embedding {
pub fn from_parameters(params: &Parameters) -> Self {
Self {
table: params.embedding_table().to_vec(),
}
}
}
impl Morphism<TokenId, Vector> for Embedding {
fn name(&self) -> &'static str {
"embedding"
}
fn apply(&self, token: TokenId) -> CtResult<Vector> {
let Some(row) = self.table.get(token.index()) else {
return Err(CtError::OutOfRange {
kind: "token",
index: token.index(),
limit: self.table.len(),
});
};
Ok(Vector::new(row.clone()))
}
}
Rust Syntax: Stored Table
table: Vec<Vec<f32>>
The embedding table has shape:
vocab_size x model_dimension
Each row is the vector for one token.
Rust Syntax: Constructor From Parameters
pub fn from_parameters(params: &Parameters) -> Self
The embedding morphism is built from model parameters.
It copies the table out of Parameters:
params.embedding_table().to_vec()
Copying gives the morphism its own data. That keeps ownership simple for the tiny tutorial, at the cost of duplicating the table each time an Embedding is built.
Rust Syntax: Morphism Shape
impl Morphism<TokenId, Vector> for Embedding
This says:
Embedding : TokenId -> Vector
Rust Syntax: Bounds Check
let Some(row) = self.table.get(token.index()) else {
return Err(CtError::OutOfRange { ... });
};
The code does not assume every TokenId is valid for every embedding table.
It checks the row lookup at the boundary where the table is used.
Rust Syntax: Why Clone The Row
Ok(Vector::new(row.clone()))
The morphism returns an owned Vector.
The row inside the table is borrowed, so the code clones it into the output object.
This is a deliberate ownership boundary.
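The bounds check and the clone can be seen together in a standalone sketch. EmbeddingTable and LookupError are hypothetical names standing in for Embedding and CtError::OutOfRange:

```rust
// Hypothetical error type standing in for CtError::OutOfRange.
#[derive(Debug, PartialEq)]
enum LookupError {
    OutOfRange { index: usize, limit: usize },
}

// Hypothetical table with shape vocab_size x model_dimension.
struct EmbeddingTable {
    rows: Vec<Vec<f32>>,
}

impl EmbeddingTable {
    fn lookup(&self, token: usize) -> Result<Vec<f32>, LookupError> {
        // `get` returns Option<&Vec<f32>>: a borrowed row, or None if out of range.
        let Some(row) = self.rows.get(token) else {
            return Err(LookupError::OutOfRange { index: token, limit: self.rows.len() });
        };
        // Clone the borrowed row so the caller receives an owned vector.
        Ok(row.clone())
    }
}

fn main() {
    let table = EmbeddingTable { rows: vec![vec![0.1, 0.2], vec![0.3, 0.4], vec![0.5, 0.6]] };
    println!("{:?}", table.lookup(1));
    println!("{:?}", table.lookup(3));
}
```

Because the returned vector is owned, a caller can modify it without touching the table.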
ML Concept
An embedding converts a symbolic token into numerical features.
Category Theory Concept
It is an arrow:
TokenId -> Vector
LinearToLogits
The problem this block solves is:
A hidden vector must be projected into one raw score per vocabulary item.
The shape is:
Vector -> Logits
The core implementation stores:
pub struct LinearToLogits {
weight: Vec<Vec<f32>>,
bias: Vec<f32>,
}
The dimensions are:
weight: d_model x vocab_size
bias: vocab_size
input: d_model
output: vocab_size
Rust Syntax: Shape Validation
Inside apply, the code checks:
if self.weight.len() != d_model {
return Err(CtError::ShapeMismatch { ... });
}
This catches a matrix whose row count does not match the input vector length.
Then each row checks:
if self.weight[feature].len() != vocab_size {
return Err(CtError::ShapeMismatch { ... });
}
This catches rows whose column count does not match the output vocabulary size.
Rust Syntax: Linear Computation
The output begins as the bias:
let mut out = self.bias.clone();
Then each input feature contributes to every vocabulary score:
for (feature, input_value) in input.as_slice().iter().enumerate() {
for (vocab_id, output_value) in out.iter_mut().enumerate() {
*output_value += input_value * self.weight[feature][vocab_id];
}
}
Mathematically:
logits = input x weight + bias
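Written per output entry, this is logits[j] = bias[j] + sum over i of input[i] * weight[i][j]. A small standalone sketch with hand-checkable numbers, assuming plain slices instead of Vector and Logits and String errors instead of CtError::ShapeMismatch:

```rust
// weight has shape d_model x vocab_size, matching the layout used by LinearToLogits.
fn linear_to_logits(input: &[f32], weight: &[Vec<f32>], bias: &[f32]) -> Result<Vec<f32>, String> {
    if weight.len() != input.len() {
        return Err(format!("weight rows {} != input dim {}", weight.len(), input.len()));
    }
    // Start from the bias, then let each input feature add its weighted row.
    let mut out = bias.to_vec();
    for (row, x) in weight.iter().zip(input) {
        if row.len() != out.len() {
            return Err(format!("weight cols {} != vocab size {}", row.len(), out.len()));
        }
        for (o, w) in out.iter_mut().zip(row) {
            *o += x * w;
        }
    }
    Ok(out)
}

fn main() {
    let input = [1.0_f32, 2.0];
    let weight: Vec<Vec<f32>> = vec![vec![0.5, 0.0, -1.0], vec![0.25, 1.0, 0.0]];
    let bias = [0.0_f32, 1.0, -1.0];
    println!("{:?}", linear_to_logits(&input, &weight, &bias)); // Ok([1.0, 3.0, -2.0])
}
```

The outputs 3.0 and -2.0 illustrate why logits are raw scores rather than probabilities.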
ML Concept
This is the language-model head.
It scores each possible next token.
Category Theory Concept
It is a morphism:
Vector -> Logits
It can compose after Embedding because Embedding returns Vector.
Softmax
The problem this block solves is:
Raw scores are not probabilities. They must be normalized into a valid distribution.
The block:
#[derive(Debug, Clone)]
pub struct Softmax;
impl Morphism<Logits, Distribution> for Softmax {
fn name(&self) -> &'static str {
"softmax"
}
fn apply(&self, logits: Logits) -> CtResult<Distribution> {
if logits.as_slice().is_empty() {
return Err(CtError::EmptyInput("softmax"));
}
let max_value = logits
.as_slice()
.iter()
.copied()
.fold(f32::NEG_INFINITY, f32::max);
let mut exps = Vec::with_capacity(logits.as_slice().len());
let mut sum = 0.0;
for value in logits.as_slice() {
let exp = (*value - max_value).exp();
exps.push(exp);
sum += exp;
}
if sum <= 0.0 || !sum.is_finite() {
return Err(CtError::InvalidProbability("softmax"));
}
Distribution::new(exps.into_iter().map(|value| value / sum).collect())
}
}
Rust Syntax: Unit Struct
Softmax stores no state.
It is the operation itself.
Rust Syntax: Morphism Shape
impl Morphism<Logits, Distribution> for Softmax
This says:
Softmax : Logits -> Distribution
Rust Syntax: Empty Check
Softmax over no scores is meaningless.
So the code rejects empty logits.
Rust Syntax: Numerical Stability
let max_value = ...
let exp = (*value - max_value).exp();
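For finite logits, subtracting the maximum makes every shifted logit at most 0. Every exponential is therefore at most 1 and cannot overflow. The largest term is exactly exp(0) = 1, so the sum is at least 1 and cannot underflow to zero.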
It does not change the final probabilities because softmax is invariant under adding or subtracting the same constant from every logit.
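The overflow that the shift prevents is easy to reproduce. The sketch below compares a naive softmax with the shifted form used in the chapter. Both functions are hypothetical free functions over f32 slices, not the crate's Softmax morphism, and both skip the empty-input and sum checks.

```rust
// Naive softmax: exponentiate the raw logits directly.
fn softmax_naive(logits: &[f32]) -> Vec<f32> {
    let exps: Vec<f32> = logits.iter().map(|z| z.exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

// Shifted softmax: subtract the maximum logit first, as the chapter's code does.
fn softmax_shifted(logits: &[f32]) -> Vec<f32> {
    let max_value = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|z| (z - max_value).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

fn main() {
    let large = [1000.0_f32, 1000.0];
    // exp(1000.0) overflows f32 to infinity, and infinity / infinity is NaN.
    println!("naive:   {:?}", softmax_naive(&large));
    // After the shift both exponents are exp(0.0) = 1.0.
    println!("shifted: {:?}", softmax_shifted(&large));
}
```

On [1000.0, 1000.0] the naive version returns NaN entries, and the shifted version returns [0.5, 0.5]. Passing [101.0, 102.0, 103.0] to the shifted version gives the same probabilities as [1.0, 2.0, 3.0]. That is the shift invariance stated above.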
Rust Syntax: Normalization
Distribution::new(exps.into_iter().map(|value| value / sum).collect())
The raw exponentials are divided by their sum.
Then the Distribution constructor validates the probability invariant.
This is good boundary design: softmax computes, and Distribution::new
enforces the distribution contract.
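The split between computing and validating can be sketched with a validating newtype. This Distribution is a hypothetical stand-in. Its checks are assumptions about what a probability contract needs: a non-empty vector of finite, non-negative entries whose sum lies within an assumed tolerance of 1e-4 of 1. It is not a copy of the crate's Distribution::new.

```rust
// Hypothetical stand-in for the book's Distribution type. The real
// Distribution::new may use different checks or a different tolerance.
#[derive(Debug, Clone, PartialEq)]
struct Distribution(Vec<f32>);

impl Distribution {
    // Assumed tolerance for "sums to 1" under floating-point rounding.
    const TOLERANCE: f32 = 1e-4;

    fn new(values: Vec<f32>) -> Result<Self, String> {
        if values.is_empty() {
            return Err("a distribution needs at least one entry".to_string());
        }
        if values.iter().any(|p| !p.is_finite() || *p < 0.0) {
            return Err("probabilities must be finite and non-negative".to_string());
        }
        let sum: f32 = values.iter().sum();
        if (sum - 1.0).abs() > Self::TOLERANCE {
            return Err(format!("probabilities sum to {sum}, not 1"));
        }
        Ok(Self(values))
    }

    fn as_slice(&self) -> &[f32] {
        &self.0
    }
}

fn main() {
    // Softmax-style output passes the boundary.
    println!("{:?}", Distribution::new(vec![0.6, 0.3, 0.1]).map(|d| d.as_slice().to_vec()));
    // Raw scores do not.
    println!("{:?}", Distribution::new(vec![2.0, -1.0, 0.5]));
}
```

A normalized vector such as [0.6, 0.3, 0.1] passes. Raw scores such as [2.0, -1.0, 0.5] are rejected before any loss can be computed from them.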
ML Concept
Softmax turns raw model scores into probabilities. In softmax regression and classification models, this is the step that makes one score per class interpretable as a probability distribution.
High logits become high probabilities.
Low logits become low probabilities.
The output can be interpreted as:
P(next token | current token)
Category Theory Concept
Softmax is a morphism-like transformation:
Logits -> Distribution
It changes the object from an unconstrained score vector into a point on the probability simplex, whose entries are non-negative and sum to 1.
CrossEntropy
The problem this block solves is:
A model prediction must be compared to the actual target token to produce a scalar loss.
The block:
#[derive(Debug, Clone)]
pub struct CrossEntropy;
impl Morphism<Product<Distribution, TokenId>, Loss> for CrossEntropy {
fn name(&self) -> &'static str {
"cross_entropy"
}
fn apply(&self, input: Product<Distribution, TokenId>) -> CtResult<Loss> {
let (distribution, target) = input.into_parts();
let Some(probability) = distribution.as_slice().get(target.index()).copied() else {
return Err(CtError::OutOfRange {
kind: "target",
index: target.index(),
limit: distribution.as_slice().len(),
});
};
Loss::new(-probability.max(1e-9).ln())
}
}
Rust Syntax: Input Type
Product<Distribution, TokenId>
Cross entropy needs both:
- the predicted distribution
- the correct target token
That pair is a product object.
Rust Syntax: Splitting The Product
let (distribution, target) = input.into_parts();
This consumes the product and extracts both values.
Rust Syntax: Target Bounds Check
distribution.as_slice().get(target.index())
The target token must be inside the probability vector.
If the distribution has 5 entries, target index 7 is invalid.
This error belongs here because this is the first place the target is used as an index into the predicted distribution.
Rust Syntax: Negative Log Likelihood
Loss::new(-probability.max(1e-9).ln())
The loss is:
-ln(probability assigned to the correct token)
The max(1e-9) avoids taking the log of zero.
Then Loss::new validates the loss scalar.
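The bounds check and the clamp can be exercised together. The sketch below is minimal. A plain slice stands in for Distribution, usize stands in for TokenId, and String stands in for CtError::OutOfRange. The let-else shape follows the chapter's code.

```rust
// Sketch of the CrossEntropy boundary. Assumptions: a slice replaces
// Distribution, usize replaces TokenId, and String replaces CtError.
fn cross_entropy(distribution: &[f32], target: usize) -> Result<f32, String> {
    // The target is first used as an index here, so the bounds check lives here.
    let Some(probability) = distribution.get(target).copied() else {
        return Err(format!(
            "target {target} out of range for {} entries",
            distribution.len()
        ));
    };
    // Clamp before ln so a zero probability gives a large, finite loss.
    Ok(-probability.max(1e-9).ln())
}

fn main() {
    let distribution = [0.6_f32, 0.3, 0.1];
    for target in [0, 1, 2, 7] {
        println!("target {target}: {:?}", cross_entropy(&distribution, target));
    }
}
```

For [0.6, 0.3, 0.1], target 0 costs about 0.51 and target 1 costs about 1.20, so a lower probability on the true token gives a higher loss. Target 7 is rejected. Because of the clamp, the largest possible loss is -ln(1e-9), which is about 20.72.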
ML Concept
Cross entropy measures how surprised the model was by the true target.
If the model assigns high probability to the target, the loss is small.
If the model assigns low probability to the target, the loss is large.
This is why the chapter says loss is a training signal. It turns a probability assigned to the correct token into a number the optimizer can try to reduce.
Category Theory Concept
Cross entropy consumes a product object:
Distribution x TokenId
and maps it into:
Loss
So its shape is:
Product<Distribution, TokenId> -> Loss
DirectPredict
The problem this block solves is:
The course needs a direct implementation to compare against the composed prediction path.
DirectPredict stores parameters and implements:
TokenId -> Distribution
Internally, it still performs:
Embedding
LinearToLogits
Softmax
but it writes the steps directly.
This allows the code to test:
composed path == direct path
That is the program’s tiny commutative diagram check.
Rust Syntax
DirectPredict is a struct that owns Parameters.
Its apply method calls the prediction steps directly instead of using
Compose.
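The difference between the two paths can be sketched with a simplified Morphism trait. This sketch makes several assumptions. The trait here has no name method and returns plain values instead of CtResult. Lookup, Total, and DirectTotal are hypothetical toy stages. Compose takes the middle object as its third type parameter, which mirrors the Compose::<_, _, Vector> call shape used in average_loss.

```rust
use std::marker::PhantomData;

// Hypothetical, simplified Morphism trait: no name() and no CtResult.
trait Morphism<A, B> {
    fn apply(&self, input: A) -> B;
}

// Compose<F, G, B> runs F : A -> B and then G : B -> C.
// B is the middle object, named explicitly as in Compose::<_, _, Vector>.
struct Compose<F, G, B> {
    first: F,
    second: G,
    _middle: PhantomData<B>,
}

impl<F, G, B> Compose<F, G, B> {
    fn new(first: F, second: G) -> Self {
        Self { first, second, _middle: PhantomData }
    }
}

impl<A, B, C, F, G> Morphism<A, C> for Compose<F, G, B>
where
    F: Morphism<A, B>,
    G: Morphism<B, C>,
{
    fn apply(&self, input: A) -> C {
        self.second.apply(self.first.apply(input))
    }
}

// Stage 1: token id -> vector (a table lookup).
struct Lookup {
    table: Vec<Vec<f32>>,
}
impl Morphism<usize, Vec<f32>> for Lookup {
    fn apply(&self, id: usize) -> Vec<f32> {
        self.table[id].clone()
    }
}

// Stage 2: vector -> score.
struct Total;
impl Morphism<Vec<f32>, f32> for Total {
    fn apply(&self, v: Vec<f32>) -> f32 {
        v.iter().sum()
    }
}

// Direct path: token id -> score in one hand-written step.
struct DirectTotal {
    table: Vec<Vec<f32>>,
}
impl Morphism<usize, f32> for DirectTotal {
    fn apply(&self, id: usize) -> f32 {
        self.table[id].iter().sum()
    }
}

// Run both paths on the same input and return (composed, direct).
fn both_paths(table: Vec<Vec<f32>>, id: usize) -> (f32, f32) {
    let composed = Compose::<_, _, Vec<f32>>::new(Lookup { table: table.clone() }, Total);
    let direct = DirectTotal { table };
    let composed_value: f32 = composed.apply(id);
    let direct_value: f32 = direct.apply(id);
    (composed_value, direct_value)
}

fn main() {
    let table = vec![vec![0.5_f32, 1.5], vec![-1.0_f32, 0.25]];
    println!("{:?}", both_paths(table, 1));
}
```

Both paths return the same score for each token id. The composed path reaches it through an explicit middle object, and the direct path reaches it through one hand-written function.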
ML Concept
This is the direct prediction implementation.
It exists so the composed path can be checked against a straightforward path.
Category Theory Concept
It provides the second path in a commutative diagram:
composed path
direct path
average_loss
The problem this function solves is:
Training needs one scalar loss over the whole training set.
The function builds the composed prediction path:
let embedding = Embedding::from_parameters(params);
let linear = LinearToLogits::from_parameters(params);
let predict = Compose::<_, _, Vector>::new(embedding, linear);
let predict = Compose::<_, _, Logits>::new(predict, Softmax);
The resulting shape is:
TokenId -> Distribution
Then each training example is evaluated:
let distribution = predict.apply(*example.first())?;
let loss = loss_fn.apply(Product::new(distribution, *example.second()))?;
Finally, the average is wrapped in Loss::new.
The function does not return a raw f32.
It returns a validated Loss.
Rust Syntax
The function takes borrowed parameters and a borrowed dataset:
pub fn average_loss(params: &Parameters, dataset: &TrainingSet) -> CtResult<Loss>
It does not consume either one.
The function loops through examples, accumulates scalar losses, and divides by the dataset length.
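The loop-and-divide structure can be sketched without the crate's types. The sketch below is minimal. A closure stands in for the composed TokenId -> Distribution path, and (input, target) tuples stand in for TrainingSet examples. A final finiteness check stands in for Loss::new.

```rust
// Sketch of average_loss. Assumptions: a closure replaces the composed
// prediction path, and tuples replace TrainingSet examples.
fn average_loss<F>(predict: F, dataset: &[(usize, usize)]) -> Result<f32, String>
where
    F: Fn(usize) -> Vec<f32>,
{
    if dataset.is_empty() {
        return Err("cannot average loss over an empty dataset".to_string());
    }
    let mut total = 0.0_f32;
    for &(input, target) in dataset {
        let distribution = predict(input);
        // Same cross-entropy rule as before: bounds check, clamp, negative log.
        let probability = distribution
            .get(target)
            .copied()
            .ok_or_else(|| format!("target {target} out of range"))?;
        total += -probability.max(1e-9).ln();
    }
    let mean = total / dataset.len() as f32;
    // Validate the scalar before returning it, in the spirit of Loss::new.
    if !mean.is_finite() || mean < 0.0 {
        return Err(format!("invalid loss {mean}"));
    }
    Ok(mean)
}

fn main() {
    // A uniform predictor over 4 tokens: every example costs ln(4).
    let uniform = |_input: usize| vec![0.25_f32; 4];
    println!("{:?}", average_loss(uniform, &[(0, 1), (1, 2), (2, 3)]));
}
```

A uniform predictor over four tokens assigns 0.25 to every target. For any in-range targets, the average loss is therefore ln(4), which is about 1.386. An empty dataset is rejected rather than divided by zero.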
ML Concept
Average loss summarizes model performance over the full training set.
Category Theory Concept
It folds the per-example results of the loss morphism into one scalar Loss, their mean.
composed_prediction_matches_direct_prediction
The problem this function solves is:
The code should check that the composed prediction pipeline and the direct implementation agree.
The composed path is:
TokenId -> Vector -> Logits -> Distribution
The direct path is:
TokenId -> Distribution
The function runs both on the same token and compares every probability with a small floating-point tolerance.
Category-theoretically, this is a commutative diagram test:
composed: TokenId --Embedding--> Vector --LinearToLogits--> Logits --Softmax--> Distribution
direct:   TokenId --DirectPredict--> Distribution
The exact drawing is less important than the idea:
Two paths through the system should produce the same meaning.
Rust Syntax
The function builds one composed path with Compose and one direct path with
DirectPredict.
It compares probabilities pairwise with approx_eq.
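The chapter names approx_eq but does not show its signature here. The sketch below uses a hypothetical version with an explicit tolerance, plus a pairwise comparison. A tolerance is needed because the two paths may perform floating-point operations in a different order, and f32 addition is not associative.

```rust
// Hypothetical helpers for comparing two probability vectors entry by entry.
fn approx_eq(a: f32, b: f32, tolerance: f32) -> bool {
    (a - b).abs() <= tolerance
}

fn distributions_match(left: &[f32], right: &[f32], tolerance: f32) -> bool {
    // zip() stops at the shorter slice, so check lengths explicitly first.
    left.len() == right.len()
        && left
            .iter()
            .zip(right)
            .all(|(a, b)| approx_eq(*a, *b, tolerance))
}

fn main() {
    let composed = [0.2_f32, 0.5, 0.3];
    let direct = [0.2000001_f32, 0.4999999, 0.3];
    println!("{}", distributions_match(&composed, &direct, 1e-5));
}
```

The explicit length check matters. Because zip stops at the shorter slice, a three-entry distribution would otherwise "match" its own two-entry prefix.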
ML Concept
This verifies that refactoring the prediction path into smaller stages did not change the predicted probabilities.
Category Theory Concept
This is a commutative-diagram check in code.
Run The Demo
Run:
cargo run --bin category_ml
Look at sections 2 through 5 in the output.
You should see:
TokenSequence -> TrainingSet
prediction probabilities
loss for a target token
Demo Output Transfer Checklist
Sections 2 through 5 of the demo are the smallest complete ML story in the book. Read them as a boundary report.
| Demo output | Boundary to own | Shortcut to reject |
|---|---|---|
| Dataset morphism: TokenSequence -> TrainingSet | DatasetWindowing turns a token stream into adjacent input-target pairs. | Treating a raw token sequence as if it were already supervised data. |
| "I" -> "love" | Each pair is Product<TokenId, TokenId>. | Forgetting which token is the input and which token is the target. |
| Composition: Softmax after Linear after Embedding | Prediction is TokenId -> Vector -> Logits -> Distribution. | Skipping Logits and pretending vectors are probabilities. |
| `P(next token \| 'I') = […]` | The printed vector is a validated Distribution. | |
| Product object: Prediction x Target -> Loss | Loss needs both the prediction and the correct next token. | Calling loss on Distribution alone. |
| loss for target 'love' = ... | Cross entropy uses the probability at the target token index. | Using the largest probability instead of the target probability. |
This checklist compresses the chapter into one reader habit:
visible output -> typed boundary -> invalid shortcut rejected
The ML idea is that a training example is not just an input. It is an input paired with the answer the model should have predicted. The category-theory idea is that the answer enters through a product boundary:
Distribution x TokenId -> Loss
The Rust idea is that the boundary is not only prose. It appears as a concrete type:
Product<Distribution, TokenId>
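The first row of the checklist, TokenSequence -> TrainingSet, can be sketched with slice::windows. The sketch makes three assumptions: usize stands in for TokenId, tuples stand in for Product<TokenId, TokenId>, and a sequence shorter than two tokens is rejected. window_pairs is a hypothetical name, not the crate's DatasetWindowing.

```rust
// Sketch of the TokenSequence -> TrainingSet shape. Assumptions: usize replaces
// TokenId, and an (input, target) tuple replaces Product<TokenId, TokenId>.
fn window_pairs(tokens: &[usize]) -> Result<Vec<(usize, usize)>, String> {
    if tokens.len() < 2 {
        return Err("need at least two tokens to form one (input, target) pair".to_string());
    }
    // windows(2) yields every adjacent pair: [t0, t1], [t1, t2], ...
    Ok(tokens.windows(2).map(|pair| (pair[0], pair[1])).collect())
}

fn main() {
    println!("{:?}", window_pairs(&[1, 2, 3, 4]));
}
```

A sequence of n tokens yields n - 1 pairs. Each pair keeps the input first and the target second.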
Why This Matters
This chapter is where the course stops being abstract.
The code implements a real, tiny version of the common language-model training path:
context token -> hidden vector -> next-token probabilities -> loss
The implementation is small, but the boundaries are real. Invalid token lookup
returns OutOfRange, invalid matrix shape returns ShapeMismatch, empty
logits return EmptyInput, invalid probabilities return InvalidProbability,
and invalid loss returns InvalidLoss.
Errors are caught where the invalid data first becomes meaningful.
Core Mental Model
In Rust terms:
each ML operation implements Morphism<Input, Output>
In ML terms:
prediction is embedding, then linear projection, then softmax
loss is negative log probability of the target
In category-theory terms:
prediction is composition of arrows
loss consumes a product object
the direct and composed paths should commute
Checkpoint
Where should an out-of-range target token be caught?
Correct answer:
Inside CrossEntropy, because that is where the target is used to index the predicted distribution.
Where This Leaves Us
This chapter assembled the first complete tiny ML path. A token sequence becomes training examples, a token becomes a vector, a vector becomes logits, logits become probabilities, and a probability distribution plus a target token becomes loss.
The next chapter, Training as an Endomorphism, changes the question from “how do we evaluate one prediction?” to “how do repeated updates change the model state?”
Further Reading
The problem this section solves is transfer. If you only read the tiny Rust implementation, larger framework APIs may still look unrelated. If you only read a framework reference, the explicit typed boundaries in this chapter may feel unnecessarily small. Use the references to connect the two views without collapsing them.
Start from the local Rust evidence:
DatasetWindowing.apply : TokenSequence -> TrainingSet
Embedding.apply : TokenId -> Vector
LinearToLogits.apply : Vector -> Logits
Softmax.apply : Logits -> Distribution
CrossEntropy.apply : Distribution x TokenId -> Loss
average_loss : Parameters x TrainingSet -> Loss
Then read the sources in this order:
| Source | What to transfer back into this chapter | Local evidence to inspect |
|---|---|---|
| D2L Softmax Regression | Multiclass classification uses raw scores, softmax probabilities, and cross entropy as one connected prediction-and-loss story. | LinearToLogits.apply, Softmax.apply, CrossEntropy.apply |
| D2L Softmax From Scratch | Implementing the pieces from scratch reveals the roles hidden by concise framework calls. | src/ml.rs, average_loss, cargo test ml::tests --lib |
| Accurate Computation of the Log-Sum-Exp and Softmax Functions | Floating-point softmax implementations use shifted formulas to reduce overflow and harmful underflow. | let max_value = ..., (*value - max_value).exp() |
| On Calibration of Modern Neural Networks | A normalized probability vector is not automatically a calibrated confidence estimate over future predictions. | Distribution::new, softmax_normalizes_logits_into_distribution |
| PyTorch CrossEntropyLoss | A production API can accept unnormalized logits and target class indices while internally combining log-softmax and negative log likelihood. | Logits -> Distribution, Product<Distribution, TokenId>, CrossEntropy.apply |
| CS231n Linear Classification | Scores, classifiers, and losses should be kept conceptually separate before optimization is discussed. | Vector -> Logits, Distribution x TokenId -> Loss |
The ML bridge is:
framework call
-> raw scores plus target index
-> probability assigned to the target
-> loss
The category-theory bridge is:
Logits -> Distribution
Distribution x TokenId -> Loss
The first arrow is an ordinary morphism. The second is a product-input morphism because loss needs both the model’s prediction and the correct target token.
After reading one source, answer four questions:
- Which local boundary did it clarify?
- Which value is raw score, probability, target, or loss?
- Which shortcut did the source use that the tiny Rust path expands?
- Which command or test shows the local evidence?
For this chapter, the commands are:
cargo run --bin category_ml
cargo test ml::tests --lib
Checkpoint:
When reading an external loss API, can you name which part corresponds to Logits -> Distribution and which part corresponds to Distribution x TokenId -> Loss?
For terminology recovery, use:
- Glossary: logits, softmax, probability distribution, cross entropy
- References: softmax regression and linear classifiers
If a source does not help you point to one local boundary and one output or test signal, it has not transferred back into this chapter yet.
Practice After This Chapter
Use Exercise 3 to trace adjacent training pairs and Exercise 9 to connect this tiny implementation to a larger ML reference. The pair checks both local code understanding and source-backed transfer.
Retrieval Practice
Recall
Recover the path before explaining the calculations.
- What morphism turns a TokenSequence into a TrainingSet?
- Which three arrows turn a TokenId into a Distribution?
- Which two objects are paired before CrossEntropy can produce Loss?
Explain
Use the target token to explain why the loss boundary needs a product object.
- Why are Logits not the same object as Distribution?
- Why does CrossEntropy use the probability at target.index() instead of the largest probability in the distribution?
- Why does an out-of-range target error belong inside CrossEntropy?
Apply
Use the demo output and the numeric examples in this chapter.
- Given TokenId -> Vector -> Logits -> Distribution, write the Rust type that must appear between Embedding and Softmax.
- A distribution is [0.70, 0.20, 0.10] and the target token index is 1. Which probability should cross entropy use?
- A token sequence is [4, 9, 2]. Which adjacent training pairs should DatasetWindowing produce?
Debug
For each invalid shortcut, name the missing boundary or wrong object:
Logits -> Loss
Distribution -> Loss
using the maximum probability instead of the target probability
A strong answer should mention the exact typed path:
Logits -> Distribution
Distribution x TokenId -> Loss
The point is not to memorize the formula. The point is to know which object owns the probability invariant and which object selects the target probability.
Training as an Endomorphism
The problem this chapter solves is:
A model is not only used for prediction. It must also be updated by training, and one update should produce the same kind of object it consumed.
The key shape is:
Parameters -> Parameters
This is an endomorphism.
In ordinary ML terms:
old parameters
-> compute predictions
-> compute loss gradients
-> subtract learning-rate-scaled gradients
-> new parameters
In category-theory terms:
A -> A
Because the input and output type are the same, the step can be repeated.
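The repetition argument can be sketched with a generic helper that accepts any fallible A -> A function. This apply_n_times is a hypothetical analogue of the crate's apply_endomorphism_n_times. It takes a plain function and a usize count instead of a Morphism and a StepCount, and Vec<f32> stands in for Parameters.

```rust
// Hypothetical analogue of apply_endomorphism_n_times: repeat a fallible
// endomorphism A -> A a fixed number of times.
fn apply_n_times<A, F>(step: F, mut state: A, steps: usize) -> Result<A, String>
where
    F: Fn(A) -> Result<A, String>,
{
    for _ in 0..steps {
        // Legal only because the output type equals the input type.
        state = step(state)?;
    }
    Ok(state)
}

// A toy endomorphism on "parameters": one gradient step on sum(p^2) / 2 with
// learning rate 0.5, which halves every value.
fn halve_step(params: Vec<f32>) -> Result<Vec<f32>, String> {
    if params.is_empty() {
        return Err("empty parameters".to_string());
    }
    Ok(params.into_iter().map(|p| p - 0.5 * p).collect())
}

fn main() {
    println!("{:?}", apply_n_times(halve_step, vec![4.0, -2.0], 3));
}
```

Three halving steps take [4.0, -2.0] to [0.5, -0.25]. Zero steps return the input unchanged, which acts as the identity arrow on the same object. A function that returned a loss value instead of parameters could not be passed to this helper at all, because its signature must be A -> A.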
Reader orientation: Do not read this chapter as a full backpropagation engine. It is a small, explicit training step whose purpose is to make the shape Parameters -> Parameters visible and runnable.
Chapter Outcomes
By the end of this chapter, you should be able to:
- explain why one training step is modeled as Parameters -> Parameters,
- separate loss measurement from parameter update,
- compare the tiny TrainStep(dataset, learning_rate) boundary with a production optimizer loop that calls zero_grad, backward, and step.
What You Already Know
If you have seen gradient descent, you already know the informal movement:
parameters are adjusted and then used again. If you know Rust, you already know
that a function can return the same type it receives. This chapter names that
shape precisely: a training step is an endomorphism on Parameters.
Update Trace Before Source
Before reading src/training.rs, keep this one-step trace in view. It separates
loss measurement, gradient accumulation, the parameter update, and repetition.
| Stage | Rust shape | Plain meaning | What to check |
|---|---|---|---|
| Current state | Parameters | embeddings, output weights, and bias before the step | What object is being updated? |
| Training data | TrainingSet | adjacent input-target examples | Is the update using examples, not one prediction alone? |
| Forward pass | TokenId -> Vector -> Logits -> Distribution | predict with the current parameters | Are predictions computed before gradients are accumulated? |
| Error signal | dlogits[target_id] -= 1.0 | probability minus target indicator | Which target index changes the gradient? |
| Gradient buffers | grad_embedding, grad_lm_head, grad_bias | accumulated directions for each parameter group | Which buffer matches which parameter group? |
| Average step | batch_scale and LearningRate | scale gradients before subtracting them | Is this one full-batch update? |
| New state | Parameters | updated model state with the same shape | Did the output remain reusable as model state? |
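The table's error-signal row says the logit gradient is the probability vector minus a one-hot target indicator. The minimal sketch below checks that claim numerically with central finite differences. It uses f64 for precision. softmax, loss, analytic_dlogits, and numeric_dlogits are hypothetical helpers, not functions from src/training.rs.

```rust
// Independent numeric check of dlogits = probs - onehot(target); not crate code.
fn softmax(logits: &[f64]) -> Vec<f64> {
    let max = logits.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.iter().map(|z| (z - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

// Cross entropy of softmax(logits) against the target index.
fn loss(logits: &[f64], target: usize) -> f64 {
    -softmax(logits)[target].ln()
}

// Analytic gradient: probabilities minus a one-hot target indicator.
fn analytic_dlogits(logits: &[f64], target: usize) -> Vec<f64> {
    let mut d = softmax(logits);
    d[target] -= 1.0;
    d
}

// Central difference (L(z + h e_i) - L(z - h e_i)) / (2h) for each logit i.
fn numeric_dlogits(logits: &[f64], target: usize, h: f64) -> Vec<f64> {
    (0..logits.len())
        .map(|i| {
            let mut plus = logits.to_vec();
            let mut minus = logits.to_vec();
            plus[i] += h;
            minus[i] -= h;
            (loss(&plus, target) - loss(&minus, target)) / (2.0 * h)
        })
        .collect()
}

fn main() {
    let logits = [0.5, -1.0, 2.0];
    println!("analytic: {:?}", analytic_dlogits(&logits, 1));
    println!("numeric:  {:?}", numeric_dlogits(&logits, 1, 1e-5));
}
```

The two vectors agree to well within 1e-6. The analytic vector sums to zero because the probabilities and the indicator each sum to 1. Only the target entry is negative, which is why gradient descent raises the target logit.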
One optimizer update has this shape:
Parameters
-> predictions on TrainingSet
-> gradients
-> Parameters
Repeated optimization is not a different kind of arrow. It is the same arrow used again:
Parameters0 -> Parameters1 -> Parameters2 -> ... -> ParametersN
That is the chapter’s main separation. Parameters -> Loss measures the model.
Parameters -> Parameters updates the model. The first is diagnostic. The
second is repeatable training.
The local update rule in this chapter is the same first-order shape used in standard gradient descent:
parameter = parameter - learning_rate * average_gradient
In the Rust source, that appears as:
*value -= learning_rate * grad * batch_scale;
The chapter uses a full-batch step, so one call to TrainStep::apply reads all
examples in the TrainingSet, averages their gradients with batch_scale, and
returns a new Parameters value. The tests repeat that one endomorphism with
apply_endomorphism_n_times.
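Multiplying by batch_scale turns the accumulated sum into an average. The sketch below checks that on a two-parameter toy. It assumes Vec<f32> stands in for Parameters and takes per-example gradients as given rather than computing them. full_batch_update is a hypothetical name.

```rust
// Sketch of one full-batch update. Assumptions: Vec<f32> replaces Parameters,
// and per-example gradients are supplied instead of computed.
// Panics if there are no examples (batch_scale would be infinite).
fn full_batch_update(params: &[f32], per_example_grads: &[Vec<f32>], learning_rate: f32) -> Vec<f32> {
    assert!(!per_example_grads.is_empty(), "full-batch update needs at least one example");
    // Sum gradients across examples, like the chapter's gradient buffers.
    let mut summed = vec![0.0_f32; params.len()];
    for grad in per_example_grads {
        for (acc, g) in summed.iter_mut().zip(grad) {
            *acc += *g;
        }
    }
    // batch_scale turns the sum into an average.
    let batch_scale = 1.0 / per_example_grads.len() as f32;
    params
        .iter()
        .zip(&summed)
        .map(|(p, g)| *p - learning_rate * *g * batch_scale)
        .collect()
}

fn main() {
    let params = [1.0_f32, 2.0];
    let grads = vec![vec![2.0_f32, 0.0], vec![0.0_f32, 4.0]];
    // Average gradient is [1.0, 2.0], so the update is [1.0 - 0.5, 2.0 - 1.0].
    println!("{:?}", full_batch_update(&params, &grads, 0.5));
}
```

The two example gradients average to [1.0, 2.0]. With learning rate 0.5, the parameters [1.0, 2.0] therefore become [0.5, 1.0].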
Training Debugging Checklist
When training output looks surprising, separate four questions before changing the update rule:
| Question | Safe answer in this chapter | Common mistake |
|---|---|---|
| What object is updated? | Parameters | treating Loss as the updated object |
| What object measures quality? | Loss from Parameters x TrainingSet | returning loss instead of new parameters |
| What repeats? | the same TrainStep : Parameters -> Parameters | inventing a new arrow for every step count |
| What controls update size? | LearningRate and the averaged gradient | assuming more steps always means better behavior |
The example prints both roles:
TrainStep : Parameters -> Parameters
Parameters x TrainingSet -> Loss
Those lines are deliberately different. Loss tells you how the current
parameters perform on the dataset. It is evidence, not the next model state.
TrainStep returns the next model state. That is what makes repetition legal:
Parameters0 -> Parameters1 -> ... -> Parameters80
Use this diagnostic when changing StepCount:
| If you change | You are testing | You are not proving |
|---|---|---|
| StepCount::new(1) | one update preserves state shape | that one update is enough training |
| StepCount::new(10) | repeated updates can reduce loss on the tiny dataset | that all datasets behave the same |
| StepCount::new(200) | the same endomorphism can be iterated many times | that more steps can never overshoot or plateau |
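The last row warns that more steps can overshoot or plateau. The effect is easy to see on a one-parameter quadratic, f(x) = x^2 / 2, which is an illustrative assumption and not the chapter's model. Its gradient is x, so each step multiplies x by (1 - learning_rate).

```rust
// Illustrative model only: f(x) = x^2 / 2, whose gradient is x.
fn quadratic_loss(x: f64) -> f64 {
    0.5 * x * x
}

fn run_gradient_descent(x0: f64, learning_rate: f64, steps: usize) -> f64 {
    let mut x = x0;
    for _ in 0..steps {
        let grad = x; // derivative of x^2 / 2
        x -= learning_rate * grad;
    }
    x
}

fn main() {
    for learning_rate in [0.5, 2.0, 2.5] {
        let x = run_gradient_descent(1.0, learning_rate, 10);
        println!("lr {learning_rate}: loss after 10 steps = {}", quadratic_loss(x));
    }
}
```

With learning rate 0.5 the loss shrinks. At 2.0 the value flips sign every step, so the loss never moves. At 2.5 each step multiplies |x| by 1.5, so the loss grows. The update keeps the same shape in every case, and only the numbers change.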
The category-theory lesson is stable even when the numeric loss changes in different ways:
the update remains Parameters -> Parameters
the measurement remains Parameters x TrainingSet -> Loss
Source Snapshot
This file implements one full-batch optimizer update.
Source snapshot: src/training.rs
use crate::category::Morphism;
use crate::domain::{LearningRate, Parameters, TrainingSet, Vector};
use crate::error::{CtError, CtResult};
use crate::ml::{LinearToLogits, Softmax};
/// One full-batch optimizer update.
///
/// Categorically, this is an endomorphism:
///
/// `Parameters -> Parameters`
#[derive(Debug, Clone)]
pub struct TrainStep {
dataset: TrainingSet,
learning_rate: LearningRate,
}
impl TrainStep {
pub fn new(dataset: TrainingSet, learning_rate: LearningRate) -> Self {
Self {
dataset,
learning_rate,
}
}
}
impl Morphism<Parameters, Parameters> for TrainStep {
fn name(&self) -> &'static str {
"train_step_endomorphism"
}
fn apply(&self, params: Parameters) -> CtResult<Parameters> {
let vocab_size = params.vocab_size();
let d_model = params.d_model();
if vocab_size == 0 || d_model == 0 {
return Err(CtError::EmptyInput("parameters"));
}
let mut grad_embedding = vec![vec![0.0; d_model]; params.embedding.len()];
let mut grad_lm_head = vec![vec![0.0; vocab_size]; d_model];
let mut grad_bias = vec![0.0; vocab_size];
for example in self.dataset.examples() {
let input_id = example.first().index();
let target_id = example.second().index();
if input_id >= params.embedding.len() {
return Err(CtError::OutOfRange {
kind: "input token",
index: input_id,
limit: params.embedding.len(),
});
}
if target_id >= vocab_size {
return Err(CtError::OutOfRange {
kind: "target token",
index: target_id,
limit: vocab_size,
});
}
let x = &params.embedding[input_id];
let logits = LinearToLogits::from_parts(params.lm_head.clone(), params.bias.clone())
.apply(Vector::new(x.clone()))?;
let probs = Softmax.apply(logits)?;
let mut dlogits = probs.as_slice().to_vec();
dlogits[target_id] -= 1.0;
for (vocab_id, dlogit) in dlogits.iter().copied().enumerate() {
grad_bias[vocab_id] += dlogit;
for (feature, x_feature) in x.iter().copied().enumerate() {
grad_lm_head[feature][vocab_id] += x_feature * dlogit;
}
}
for (feature, grad_feature) in grad_embedding[input_id].iter_mut().enumerate() {
let dx = params.lm_head[feature]
.iter()
.zip(dlogits.iter())
.map(|(weight, dlogit)| weight * dlogit)
.sum::<f32>();
*grad_feature += dx;
}
}
let batch_scale = 1.0 / self.dataset.len() as f32;
let learning_rate = self.learning_rate.value();
let mut updated = params.clone();
for (row, grad_row) in updated.embedding.iter_mut().zip(grad_embedding.iter()) {
for (value, grad) in row.iter_mut().zip(grad_row.iter()) {
*value -= learning_rate * grad * batch_scale;
}
}
for (row, grad_row) in updated.lm_head.iter_mut().zip(grad_lm_head.iter()) {
for (value, grad) in row.iter_mut().zip(grad_row.iter()) {
*value -= learning_rate * grad * batch_scale;
}
}
for (bias, grad) in updated.bias.iter_mut().zip(grad_bias.iter()) {
*bias -= learning_rate * grad * batch_scale;
}
Ok(updated)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::category::{StepCount, apply_endomorphism_n_times};
use crate::domain::{ModelDimension, Product, TokenId, TokenSequence, TrainingSet, VocabSize};
use crate::ml::{DatasetWindowing, average_loss};
#[test]
fn repeated_training_step_reduces_loss() -> CtResult<()> {
let tokens = TokenSequence::from_indices([1, 2, 3, 4, 1, 2, 3, 4])?;
let dataset = DatasetWindowing.apply(tokens)?;
let params = Parameters::init(VocabSize::new(5)?, ModelDimension::new(4)?);
let before = average_loss(&params, &dataset)?;
let train_step = TrainStep::new(dataset.clone(), LearningRate::new(1.0)?);
let trained = apply_endomorphism_n_times(&train_step, params, StepCount::new(80))?;
let after = average_loss(&trained, &dataset)?;
assert!(after.value() < before.value());
Ok(())
}
#[test]
fn one_training_step_preserves_parameter_shape() -> CtResult<()> {
let tokens = TokenSequence::from_indices([1, 2, 3, 4])?;
let dataset = DatasetWindowing.apply(tokens)?;
let params = Parameters::init(VocabSize::new(5)?, ModelDimension::new(4)?);
let train_step = TrainStep::new(dataset, LearningRate::new(0.1)?);
let trained = train_step.apply(params.clone())?;
assert_eq!(trained.vocab_size(), params.vocab_size());
assert_eq!(trained.d_model(), params.d_model());
Ok(())
}
#[test]
fn training_rejects_target_outside_vocabulary() -> CtResult<()> {
let dataset = TrainingSet::new([Product::new(TokenId::new(0), TokenId::new(9))])?;
let params = Parameters::init(VocabSize::new(2)?, ModelDimension::new(2)?);
let train_step = TrainStep::new(dataset, LearningRate::new(0.1)?);
assert!(matches!(
train_step.apply(params),
Err(CtError::OutOfRange {
kind: "target token",
index: 9,
limit: 2,
})
));
Ok(())
}
}
The Whole File
src/training.rs defines:
TrainStep
TrainStep::new
impl Morphism<Parameters, Parameters> for TrainStep
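unit tests checking that repeated training reduces loss, that one step preserves the parameter shape, and that a target outside the vocabulary is rejected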
The whole file is about one idea:
training is a repeatable typed transformation of model state
Source Reading Bridge: One Step Has Four Responsibilities
The short list above names the file’s pieces, but it does not yet tell you how
to read the main function. The central method is TrainStep::apply in
src/training.rs. Read it as four responsibilities in order:
validate the current Parameters
run the current model on each training example
accumulate gradients for embedding, output weights, and bias
subtract a learning-rate-scaled average gradient to create new Parameters
The ML intuition is gradient descent. A loss signal does not replace the model. It tells each parameter which direction would reduce the current error on the training set. The code makes that visible by separating the diagnostic value from the state update:
average_loss(&params, &dataset) -> Loss
TrainStep::apply(params) -> Parameters
That difference matters. If TrainStep::apply returned Loss, it could tell
you how bad the current model is, but it could not be composed with itself for
the next update.
The category-theory connection is the same boundary in a shorter form:
TrainStep(dataset, learning_rate) : Parameters -> Parameters
The dataset and learning rate configure which update arrow you have. The gradient buffers are internal machinery used while building the output object; they are not the object being returned by the morphism.
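For one example, the chain rule gives three formulas: dL/dW[f][v] = x[f] * dlogits[v], dL/db[v] = dlogits[v], and dL/dx[f] = sum over v of W[f][v] * dlogits[v]. These are what the grad_lm_head, grad_bias, and grad_embedding loops accumulate. The minimal sketch below checks the weight and input formulas against central finite differences. It assumes plain Vec<f64> parameters and a single example, and every function name in it is hypothetical.

```rust
// Sanity check of hand-derived gradients for x -> x W + b -> softmax -> CE.
fn logits_of(x: &[f64], w: &[Vec<f64>], b: &[f64]) -> Vec<f64> {
    let mut out = b.to_vec();
    for (f, xf) in x.iter().enumerate() {
        for (v, o) in out.iter_mut().enumerate() {
            *o += xf * w[f][v];
        }
    }
    out
}

// -ln softmax(z)[t] = logsumexp(z) - z[t], computed with the max shift.
fn loss(x: &[f64], w: &[Vec<f64>], b: &[f64], t: usize) -> f64 {
    let z = logits_of(x, w, b);
    let max = z.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let log_sum = z.iter().map(|zi| (zi - max).exp()).sum::<f64>().ln() + max;
    log_sum - z[t]
}

// Hand-derived gradients, mirroring grad_lm_head and grad_embedding.
fn analytic(x: &[f64], w: &[Vec<f64>], b: &[f64], t: usize) -> (Vec<Vec<f64>>, Vec<f64>) {
    let z = logits_of(x, w, b);
    let max = z.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = z.iter().map(|zi| (zi - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    let mut dlogits: Vec<f64> = exps.iter().map(|e| e / sum).collect();
    dlogits[t] -= 1.0;
    // grad_w[f][v] = x[f] * dlogits[v]
    let grad_w: Vec<Vec<f64>> = x
        .iter()
        .map(|xf| dlogits.iter().map(|d| xf * d).collect::<Vec<f64>>())
        .collect();
    // grad_x[f] = sum over v of w[f][v] * dlogits[v]
    let grad_x: Vec<f64> = w
        .iter()
        .map(|row| row.iter().zip(&dlogits).map(|(wv, d)| wv * d).sum::<f64>())
        .collect();
    (grad_w, grad_x)
}

// Largest absolute gap between analytic and central-difference gradients.
fn max_gradient_error(x: &[f64], w: &[Vec<f64>], b: &[f64], t: usize) -> f64 {
    let h = 1e-5;
    let (grad_w, grad_x) = analytic(x, w, b, t);
    let mut worst: f64 = 0.0;
    for f in 0..w.len() {
        for v in 0..w[f].len() {
            let (mut wp, mut wm) = (w.to_vec(), w.to_vec());
            wp[f][v] += h;
            wm[f][v] -= h;
            let numeric = (loss(x, &wp, b, t) - loss(x, &wm, b, t)) / (2.0 * h);
            worst = worst.max((numeric - grad_w[f][v]).abs());
        }
    }
    for f in 0..x.len() {
        let (mut xp, mut xm) = (x.to_vec(), x.to_vec());
        xp[f] += h;
        xm[f] -= h;
        let numeric = (loss(&xp, w, b, t) - loss(&xm, w, b, t)) / (2.0 * h);
        worst = worst.max((numeric - grad_x[f]).abs());
    }
    worst
}

fn main() {
    let x = [0.3, -0.7];
    let w = vec![vec![0.1, -0.2, 0.4], vec![0.5, 0.0, -0.3]];
    let b = [0.0, 0.1, -0.1];
    println!("max gradient error: {:e}", max_gradient_error(&x, &w, &b, 2));
}
```

For each target in this toy model, the largest gap stays below 1e-6. This is a sanity check on the hand-derived formulas, not a general autograd test.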
Checkpoint:
If `TrainStep::apply` returned `Loss` instead of `Parameters`, what ability
would `apply_endomorphism_n_times` lose?
Production Optimizer Boundary
Production frameworks usually split the training loop across model parameters,
stored gradients, an optimizer object, and an optimizer step. PyTorch’s
torch.optim documentation describes optimizers as objects that hold current
state and update parameters from computed gradients. The common loop shape is:
optimizer.zero_grad()
output = model(input)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
This book compresses the same teaching shape into one explicit Rust morphism:
TrainStep(dataset, learning_rate) : Parameters -> Parameters
The same training boundary as a rendered math view:
\[
\begin{array}{ccccc}
\mathrm{Parameters}_t & \xrightarrow{\mathrm{average\_loss}(-,\,\mathrm{TrainingSet})} & \mathrm{Loss}_t & \xrightarrow{\mathrm{local\ gradients}} & \nabla_t \\
&&&& \downarrow \mathrm{apply\ learning\ rate} \\
\mathrm{Parameters}_{t+1} & \xleftarrow{\mathrm{TrainStep}(\mathrm{dataset},\,\mathrm{learning\_rate})} & \mathrm{Parameters}_t &&
\end{array}
\]
How to read this diagram:
- the upper path measures how wrong the current parameters are,
- the gradient path explains what should change,
- the bottom arrow is the typed update that returns the next full Parameters object,
- only the bottom arrow has the endomorphism shape Parameters -> Parameters.
The tiny Rust boundary is smaller than a production optimizer. It does not model momentum, parameter groups, optimizer state dictionaries, closures, schedulers, mixed precision, or distributed training. It keeps one full-batch gradient update inspectable.
| Production training responsibility | Tiny Rust teaching boundary |
|---|---|
| optimizer owns parameter groups and update state | TrainStep owns TrainingSet and LearningRate |
| loss.backward() computes gradients | TrainStep::apply accumulates local gradients directly |
| optimizer.step() updates parameters | TrainStep::apply returns a new Parameters value |
| zero_grad() manages stored gradient buffers | gradient buffers are local variables inside one update |
| schedulers may change learning rates across epochs | one LearningRate configures one repeated endomorphism |
When you return to a framework, the useful transfer question is:
which object owns the update state, and which call turns current parameters
into next parameters?
Framework-To-Rust Responsibility Ledger
If you already know the framework loop, use this ledger before reading
TrainStep::apply. It prevents two common mistakes: treating the tiny Rust code
as a hidden framework clone, or treating framework calls as unrelated magic.
| Framework cue | Production responsibility | Tiny Rust handle | Category boundary | Safe non-claim |
|---|---|---|---|---|
| optimizer.zero_grad() | clear accumulated gradient buffers before the next backward pass | grad_embedding, grad_lm_head, and grad_bias start as local zeroed buffers inside TrainStep::apply | preparation inside one update arrow | no persistent gradient field is stored on Parameters |
| loss.backward() | compute gradients from the current loss through the recorded graph | dlogits[target_id] -= 1.0 and local gradient accumulation for the tiny softmax-linear path | measurement informs the update | not a general autograd tape |
| optimizer.step() | update parameters using gradients and optimizer state | *value -= learning_rate * grad * batch_scale; and returned Parameters | Parameters -> Parameters | not Adam, momentum, scheduler, mixed precision, or distributed training |
| optimizer.state_dict() | persist optimizer state and parameter-group metadata | no corresponding field in TrainStep; only TrainingSet and LearningRate configure the teaching update | larger state would need a larger object | the tiny step does not serialize optimizer state |
The useful habit is to translate a framework call into a responsibility, then ask where that responsibility appears in the local Rust code. If no local handle exists, say so explicitly.
Framework-to-Rust audit card:
framework cue:
responsibility:
local Rust handle:
returned object:
category boundary:
safe non-claim:
Example:
framework cue: optimizer.step()
responsibility: apply gradients to parameters
local Rust handle: *value -= learning_rate * grad * batch_scale;
returned object: Parameters
category boundary: TrainStep(dataset, learning_rate) : Parameters -> Parameters
safe non-claim: this is one full-batch teaching update, not a production optimizer
Source-Backed Precision Rules
This chapter uses external sources to keep the tiny update honest. Each source supports a limited claim; these citations are not proof that this crate is a production optimizer or a full automatic-differentiation engine.
| Source | What the source supports | Local rule in this chapter | Rust evidence |
|---|---|---|---|
| D2L Gradient Descent | First-order gradient descent updates a value by moving against the gradient, and the learning rate controls whether the step is useful or unstable. | The local update is parameter = parameter - learning_rate * average_gradient; do not claim every step count or learning rate must improve every dataset. | *value -= learning_rate * grad * batch_scale;, LearningRate, StepCount |
| D2L Backpropagation and Computational Graphs | Backpropagation computes gradients through intermediate variables using the chain rule in reverse order. | This chapter hand-computes the local softmax-linear gradients for one tiny model; it is not a general autograd tape. | dlogits[target_id] -= 1.0, grad_lm_head, grad_embedding |
| Automatic differentiation in machine learning: a survey | Automatic differentiation is broader than backpropagation and distinct from symbolic differentiation and finite differences. | Do not call this chapter’s hand-written gradient buffers an AD engine; they are one visible gradient path for one tiny model. | TrainStep::apply, grad_embedding, grad_lm_head, grad_bias |
| PyTorch torch.optim | A production optimizer owns update state and updates parameters after gradients have been computed. | TrainStep compresses zero_grad, backward, and step into one inspectable full-batch teaching boundary. | TrainStep(dataset, learning_rate) : Parameters -> Parameters |
| Backprop as Functor | Parameter-update rules can be studied compositionally under stated assumptions. | The categorical claim here is narrower: one fixed training step is an endomorphism on Parameters; the chapter does not prove a monoidal-functor result. | impl Morphism<Parameters, Parameters> for TrainStep, apply_endomorphism_n_times |
The transfer pattern is:
source claim -> local typed boundary -> validation command or test
For this chapter, that means reading cargo run --example 03_training_endomorphism and the src/training.rs tests as evidence for the
tiny Parameters -> Parameters boundary, not as evidence for every production
training system.
Worked Example: Repeating One Update
The smallest first-principles version of a repeated update is a number being moved a little at a time:
#![allow(unused)]
fn main() {
fn step_toward_zero(value: f32, learning_rate: f32) -> f32 {
value - learning_rate * value
}
let once = step_toward_zero(10.0, 0.1);
let twice = step_toward_zero(once, 0.1);
assert!(twice < once);
}
The real training code applies the same repeatable-update idea to Parameters,
not to one scalar. The output stays the same kind of object as the input, so the
update can be run again.
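Here is a minimal sketch of that repeatability, using a hypothetical repeat_update helper that is not part of the crate. Because step_toward_zero maps f32 to f32, its output can be fed straight back in. With learning rate 0.1, each step multiplies the value by 0.9, so three steps from 10.0 land near 7.29.

```rust
/// One scalar update with the shape f32 -> f32.
fn step_toward_zero(value: f32, learning_rate: f32) -> f32 {
    value - learning_rate * value
}

/// Hypothetical helper: apply the same f32 -> f32 update `steps` times.
/// The loop only works because each output has the input's type.
fn repeat_update(mut value: f32, learning_rate: f32, steps: usize) -> f32 {
    for _ in 0..steps {
        value = step_toward_zero(value, learning_rate);
    }
    value
}

fn main() {
    // Each step multiplies the value by (1 - 0.1) = 0.9.
    let after_three = repeat_update(10.0, 0.1, 3);
    println!("after three steps: {after_three:.4}");
}
```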
Self-Check
Before reading the full training step, explain why Parameters -> Parameters
is repeatable but Parameters -> Loss is not.
One Step Before Many Steps
Training becomes easier to reason about if you separate two ideas.
One training step has the shape:
Parameters -> Parameters
It reads the dataset, computes predictions, accumulates gradients, subtracts a learning-rate-scaled average gradient, and returns updated model state.
Repeated training is just iteration of that same shape:
Parameters0 -> Parameters1 -> Parameters2 -> ... -> ParametersN
The chapter’s category-theory word for the one-step shape is endomorphism. The ML word for the update rule is gradient descent. The Rust evidence is the trait implementation:
impl Morphism<Parameters, Parameters> for TrainStep
The tests in src/training.rs protect the learner-visible claims: one training
step preserves the parameter shape, out-of-range targets fail with a typed
error, and repeated steps reduce loss on the tiny dataset.
TrainStep
The problem this block solves is:
A training update needs a dataset and a learning rate, and those values should travel together as one configured operation.
The block:
/// One full-batch optimizer update.
///
/// Categorically, this is an endomorphism:
///
/// `Parameters -> Parameters`
#[derive(Debug, Clone)]
pub struct TrainStep {
dataset: TrainingSet,
learning_rate: LearningRate,
}
Rust Syntax
This is a named-field struct.
It stores:
dataset: TrainingSet
learning_rate: LearningRate
Both fields are private.
That means callers cannot directly replace the dataset or learning rate after construction.
The derived traits mean:
Debug -> can be printed for debugging
Clone -> can be explicitly duplicated
TrainingSet is already non-empty.
LearningRate is already finite and positive.
So TrainStep stores validated inputs.
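The chapter does not show how LearningRate enforces its invariant. The following is a minimal sketch of one way a validated newtype could do it. ConfigError is a hypothetical stand-in for CtError. Outside its module, the tuple field is private, so callers must go through new, which rejects zero, negative, NaN, and infinite values.

```rust
/// Hypothetical stand-in for the crate's error type; the real CtError differs.
#[derive(Debug, PartialEq)]
enum ConfigError {
    InvalidLearningRate(f32),
}

/// Sketch of a validated learning-rate newtype.
#[derive(Debug, Clone, Copy, PartialEq)]
struct LearningRate(f32);

impl LearningRate {
    /// Accept only finite, strictly positive step sizes.
    fn new(value: f32) -> Result<Self, ConfigError> {
        if value.is_finite() && value > 0.0 {
            Ok(Self(value))
        } else {
            Err(ConfigError::InvalidLearningRate(value))
        }
    }

    fn value(&self) -> f32 {
        self.0
    }
}

fn main() {
    println!("{:?}", LearningRate::new(0.1));
    println!("{:?}", LearningRate::new(-1.0));
}
```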
ML Concept
A training step needs:
- examples to learn from
- a step size for parameter updates
The dataset gives the input-target pairs.
The learning rate controls how far the update moves.
Category-Theory Concept
TrainStep is the value that will implement:
Parameters -> Parameters
That makes it an endomorphism on the object Parameters.
TrainStep::new
The problem this block solves is:
Construct a configured training step from already validated pieces.
The block:
impl TrainStep {
pub fn new(dataset: TrainingSet, learning_rate: LearningRate) -> Self {
Self {
dataset,
learning_rate,
}
}
}
Rust Syntax
impl TrainStep defines methods for TrainStep.
The constructor takes ownership of:
dataset
learning_rate
and stores them.
It returns Self, not CtResult<Self>, because the inputs are already
validated domain objects.
No extra validation is needed here.
ML Concept
This is like configuring an optimizer step:
use this dataset
use this learning rate
The actual update happens later in apply.
Category-Theory Concept
The constructor chooses one specific endomorphism from a family:
TrainStep(dataset, learning_rate) : Parameters -> Parameters
Different datasets or learning rates create different update morphisms.
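A scalar stand-in makes the family idea concrete. The hypothetical make_step function plays the role of TrainStep::new. Each learning rate yields a different f32 -> f32 closure, but every member of the family has the same input and output type, so any of them can be repeated.

```rust
/// Hypothetical scalar analogue of TrainStep::new: choosing a learning
/// rate picks one specific f32 -> f32 update out of a family.
fn make_step(learning_rate: f32) -> impl Fn(f32) -> f32 {
    move |value| value - learning_rate * value
}

fn main() {
    let small_step = make_step(0.1);
    let large_step = make_step(0.5);
    // Same input, same type signature, different morphisms.
    println!("{} {}", small_step(10.0), large_step(10.0));
}
```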
Morphism Implementation
The problem this block solves is:
Make TrainStep a real typed arrow from model parameters back to model parameters.
The header:
impl Morphism<Parameters, Parameters> for TrainStep {
Rust Syntax
This says:
TrainStep implements Morphism<Parameters, Parameters>: the input type is Parameters and the output type is Parameters.
So the apply method must have this effective shape:
Parameters -> CtResult<Parameters>
The name method:
fn name(&self) -> &'static str {
"train_step_endomorphism"
}
returns a static label for the transformation.
ML Concept
The input Parameters are the current model weights.
The output Parameters are the updated weights after one full-batch step.
Category-Theory Concept
Because the input and output object are the same, TrainStep is an
endomorphism.
That is what lets this work:
Parameters0 -> Parameters1 -> Parameters2 -> ... -> ParametersN
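The chapter shows only the trait header and the name method. The following sketch uses a simplified, hypothetical Morphism trait with a String error, a toy Halve endomorphism on f32, and an apply_n_times helper loosely modeled on apply_endomorphism_n_times. The important part is the bound M: Morphism<A, A>. The loop compiles only because each output can be passed back in as the next input.

```rust
/// Hypothetical, simplified Morphism trait; the crate's real trait and
/// error type differ. Errors are plain Strings here.
trait Morphism<A, B> {
    fn name(&self) -> &'static str;
    fn apply(&self, input: A) -> Result<B, String>;
}

/// Toy endomorphism on f32, standing in for TrainStep on Parameters.
struct Halve;

impl Morphism<f32, f32> for Halve {
    fn name(&self) -> &'static str {
        "halve_endomorphism"
    }

    fn apply(&self, input: f32) -> Result<f32, String> {
        // Reject invalid state with an error instead of panicking.
        if !input.is_finite() {
            return Err("halve_endomorphism: non-finite input".to_string());
        }
        Ok(input / 2.0)
    }
}

/// Repeat one endomorphism. The bound `M: Morphism<A, A>` is what lets
/// each output be passed back in as the next input.
fn apply_n_times<M, A>(morphism: &M, mut state: A, steps: usize) -> Result<A, String>
where
    M: Morphism<A, A>,
{
    for _ in 0..steps {
        state = morphism.apply(state)?;
    }
    Ok(state)
}

fn main() -> Result<(), String> {
    let result = apply_n_times(&Halve, 80.0_f32, 3)?;
    println!("80 halved three times: {result}");
    Ok(())
}
```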
apply: Shape Checks
The problem this block solves is:
Before computing gradients, verify that the parameter object has usable dimensions.
The block:
let vocab_size = params.vocab_size();
let d_model = params.d_model();
if vocab_size == 0 || d_model == 0 {
return Err(CtError::EmptyInput("parameters"));
}
Rust Syntax
The code asks the parameter object for two dimensions.
Then it rejects zero-sized parameters.
This uses an explicit error instead of panicking.
ML Concept
Training cannot run if:
- there are zero possible vocabulary outputs
- hidden vectors have zero width
Those shapes would make the gradient arrays meaningless.
Category-Theory Concept
The endomorphism is only defined on valid Parameters.
Invalid parameter state is rejected before the morphism performs the update.
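Here is a minimal sketch of the same guard. It uses a hypothetical Params struct that stores only an output head (d_model rows of vocab_size columns) and a hypothetical TrainError enum in place of CtError. The chapter does not show how the real Parameters type derives vocab_size and d_model, so the two accessor methods are assumptions.

```rust
/// Hypothetical error type standing in for CtError.
#[derive(Debug, PartialEq)]
enum TrainError {
    EmptyInput(&'static str),
}

/// Hypothetical minimal parameters: lm_head is d_model x vocab_size.
struct Params {
    lm_head: Vec<Vec<f32>>,
}

impl Params {
    /// Assumed: vocabulary width is the column count of the output head.
    fn vocab_size(&self) -> usize {
        self.lm_head.first().map_or(0, |row| row.len())
    }

    /// Assumed: model width is the row count of the output head.
    fn d_model(&self) -> usize {
        self.lm_head.len()
    }
}

/// Reject zero-sized parameters before any gradient buffer is allocated.
fn check_shapes(params: &Params) -> Result<(usize, usize), TrainError> {
    let vocab_size = params.vocab_size();
    let d_model = params.d_model();
    if vocab_size == 0 || d_model == 0 {
        return Err(TrainError::EmptyInput("parameters"));
    }
    Ok((vocab_size, d_model))
}

fn main() {
    let params = Params { lm_head: vec![vec![0.0; 5]; 4] };
    println!("{:?}", check_shapes(&params));
}
```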
Gradient Buffers
The problem this block solves is:
Accumulate gradients for every trainable parameter before applying the update.
The block:
let mut grad_embedding = vec![vec![0.0; d_model]; params.embedding.len()];
let mut grad_lm_head = vec![vec![0.0; vocab_size]; d_model];
let mut grad_bias = vec![0.0; vocab_size];
Rust Syntax
These are mutable matrices and vectors initialized to zero.
Their shapes mirror the trainable parameters:
grad_embedding: same row count as embedding, d_model columns
grad_lm_head: d_model x vocab_size
grad_bias: vocab_size
ML Concept
Gradients accumulate how each parameter should change to reduce loss.
The code uses full-batch training: it processes every example, accumulates all gradients, averages them, then updates once.
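Here is a minimal sketch of full-batch accumulation for one vector-shaped gradient. It uses a hypothetical average_gradient helper that receives per-example gradients that have already been computed. The helper sums into a zero buffer first and averages once at the end. It assumes a non-empty batch, which the chapter's TrainingSet already guarantees; an empty slice would make batch_scale infinite.

```rust
/// Sum per-example gradients into a zeroed buffer, then average once.
/// Assumes `per_example` is non-empty.
fn average_gradient(per_example: &[Vec<f32>], width: usize) -> Vec<f32> {
    // The buffer starts at zero, like grad_bias in the chapter.
    let mut grad = vec![0.0_f32; width];
    for example_grad in per_example {
        for (slot, g) in grad.iter_mut().zip(example_grad) {
            *slot += g;
        }
    }
    // Average after every example has contributed.
    let batch_scale = 1.0 / per_example.len() as f32;
    grad.iter().map(|g| g * batch_scale).collect()
}

fn main() {
    let grads = vec![vec![0.7_f32, -0.8, 0.1], vec![-0.3, 0.2, 0.1]];
    println!("{:?}", average_gradient(&grads, 3));
}
```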
Category-Theory Concept
The gradient buffers are not the endomorphism itself.
They are internal machinery used to construct the output object in:
Parameters -> Parameters
Example Loop
The problem this block solves is:
For each training example, compute the local contribution to the parameter gradients.
The loop begins:
for example in self.dataset.examples() {
let input_id = example.first().index();
let target_id = example.second().index();
...
}
Rust Syntax
self.dataset.examples() returns a slice of TrainingExample.
Each example is a Product<TokenId, TokenId>.
So:
example.first()
is the input token.
example.second()
is the target token.
The code extracts raw indices because matrix indexing needs usize.
ML Concept
Each example says:
given input token, predict target token
The training loop calculates how wrong the current model is for that example.
Category-Theory Concept
The example is an element of:
TokenId x TokenId
The training morphism consumes many such product values while building the parameter update.
Token Bounds Checks
The problem this block solves is:
Training examples must refer to tokens that exist in the current parameter shapes.
The checks:
if input_id >= params.embedding.len() {
return Err(CtError::OutOfRange { ... });
}
if target_id >= vocab_size {
return Err(CtError::OutOfRange { ... });
}
Rust Syntax
These are ordinary bounds checks with typed errors.
They prevent invalid indexing into:
- the embedding table
- the vocabulary-sized output vector
ML Concept
An input token must have an embedding row.
A target token must be one of the possible prediction classes.
If either token is outside the model vocabulary, training cannot continue.
Category-Theory Concept
The example must belong to the finite token object that the parameters are currently modeling.
This check keeps the training morphism inside the intended domain.
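The two guards can be sketched as one standalone function. The chapter elides the fields of CtError::OutOfRange, so the TrainError::OutOfRange variant and its what, index, and len fields below are hypothetical.

```rust
/// Hypothetical error type; the real CtError::OutOfRange fields are not shown.
#[derive(Debug, PartialEq)]
enum TrainError {
    OutOfRange { what: &'static str, index: usize, len: usize },
}

/// Check one (input, target) example against the parameter shapes
/// before any of its indices are used.
fn check_example(
    input_id: usize,
    target_id: usize,
    embedding_rows: usize,
    vocab_size: usize,
) -> Result<(), TrainError> {
    if input_id >= embedding_rows {
        return Err(TrainError::OutOfRange { what: "input token", index: input_id, len: embedding_rows });
    }
    if target_id >= vocab_size {
        return Err(TrainError::OutOfRange { what: "target token", index: target_id, len: vocab_size });
    }
    Ok(())
}

fn main() {
    println!("{:?}", check_example(1, 2, 5, 5));
    println!("{:?}", check_example(0, 9, 5, 5));
}
```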
Forward Pass Inside Training
The problem this block solves is:
To compute gradients, the training step first needs the current prediction.
The block:
let x = &params.embedding[input_id];
let logits = LinearToLogits::from_parts(params.lm_head.clone(), params.bias.clone())
.apply(Vector::new(x.clone()))?;
let probs = Softmax.apply(logits)?;
Rust Syntax
x borrows the embedding row for the input token.
LinearToLogits::from_parts(...) builds a linear projection from the current
weights.
Vector::new(x.clone()) wraps the embedding row as a Vector.
Then:
Vector -> Logits -> Distribution
runs through the same morphism interface as prediction.
ML Concept
This computes the model’s current predicted distribution for one input token.
The gradient depends on the difference between that distribution and the true target.
Category-Theory Concept
Even inside training, prediction is still a composed path:
TokenId -> Vector -> Logits -> Distribution
Training uses that path as part of a larger endomorphism:
Parameters -> Parameters
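The forward path can be sketched with plain slices: an affine map logits = xW + b followed by softmax. The function names and the max-subtraction step in softmax are this sketch's own choices. Subtracting the largest logit leaves the probabilities unchanged but keeps exp from overflowing. The crate's LinearToLogits and Softmax may be implemented differently.

```rust
/// logits = xW + b, with w laid out as d_model x vocab_size.
fn logits(x: &[f32], w: &[Vec<f32>], b: &[f32]) -> Vec<f32> {
    let mut out = b.to_vec();
    for (x_feature, row) in x.iter().zip(w) {
        for (slot, weight) in out.iter_mut().zip(row) {
            *slot += x_feature * weight;
        }
    }
    out
}

/// Numerically stable softmax over logits.
fn softmax(logits: &[f32]) -> Vec<f32> {
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|z| (z - max).exp()).collect();
    let total: f32 = exps.iter().sum();
    exps.iter().map(|e| e / total).collect()
}

fn main() {
    let x = [1.0_f32, 2.0];
    let w = vec![vec![0.5_f32, -0.5], vec![1.0, 0.0]];
    let b = [0.0_f32, 0.1];
    let probs = softmax(&logits(&x, &w, &b));
    println!("{probs:?}");
}
```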
Logit Gradient
The problem this block solves is:
For softmax plus cross entropy, the gradient with respect to logits is predicted probability minus one-hot target.
The block:
let mut dlogits = probs.as_slice().to_vec();
dlogits[target_id] -= 1.0;
Rust Syntax
The probabilities are copied into a mutable vector.
Then the target class is adjusted by subtracting 1.0.
If:
probs = [0.70, 0.20, 0.10]
target = 1
then:
dlogits = [0.70, -0.80, 0.10]
ML Concept
This is the standard simplified gradient for softmax cross entropy.
It says:
- decrease each non-target score in proportion to its predicted probability
- increase the target score in proportion to how far its probability falls short of 1
Category-Theory Concept
This is local derivative information for one part of the composed prediction path.
The next loops compose that local derivative back into parameter gradients.
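The probability-minus-one-hot rule fits in a few lines. This is a sketch with a hypothetical logit_gradient function. Like the chapter's code, it indexes with target_id directly, so it relies on the earlier bounds check to keep that index valid. Because the probabilities sum to 1, the resulting gradient sums to approximately 0.

```rust
/// Gradient of softmax cross-entropy with respect to the logits:
/// predicted probabilities minus the one-hot target.
/// Assumes target_id has already been bounds-checked.
fn logit_gradient(probs: &[f32], target_id: usize) -> Vec<f32> {
    let mut dlogits = probs.to_vec();
    dlogits[target_id] -= 1.0;
    dlogits
}

fn main() {
    println!("{:?}", logit_gradient(&[0.70, 0.20, 0.10], 1));
}
```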
Worked Example: Why Subtracting A Negative Gradient Increases The Target
The update rule can feel backwards the first time you see it. The code subtracts gradients:
parameter = parameter - learning_rate * gradient
So how can training increase the target score?
Use the same three-class example:
probs = [0.70, 0.20, 0.10]
target = 1
After the target correction:
dlogits = [0.70, -0.80, 0.10]
Now look only at the bias update with learning rate 0.1 and one example:
bias[0] = 0.0 - 0.1 * 0.70 = -0.07
bias[1] = 0.0 - 0.1 * -0.80 = 0.08
bias[2] = 0.0 - 0.1 * 0.10 = -0.01
The non-target classes had positive gradients, so subtracting them lowers their biases. The target class had a negative gradient, so subtracting it raises the target bias.
That is the local version of gradient descent: move parameters in the direction that lowers loss. In this tiny classifier, the direction says “make the target logit larger and make the overconfident non-target logits smaller.”
The Rust path is:
dlogits
-> grad_bias
-> bias -= learning_rate * grad * batch_scale
For output weights, the same sign passes through x_feature * dlogit. For the
embedding row, the sign passes backward through the output weights. The full
training step is bigger, but the sign logic starts here.
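The three bias lines can be reproduced with a hypothetical update_bias helper that applies bias -= learning_rate * grad * batch_scale entry by entry. With learning rate 0.1, batch scale 1.0, and the dlogits above, it yields approximately [-0.07, 0.08, -0.01], and only the target bias rises.

```rust
/// One gradient-descent update of the bias, entry by entry:
/// bias -= learning_rate * grad * batch_scale.
fn update_bias(bias: &mut [f32], grad_bias: &[f32], learning_rate: f32, batch_scale: f32) {
    for (value, grad) in bias.iter_mut().zip(grad_bias) {
        *value -= learning_rate * grad * batch_scale;
    }
}

fn main() {
    let mut bias = vec![0.0_f32; 3];
    update_bias(&mut bias, &[0.70, -0.80, 0.10], 0.1, 1.0);
    println!("{bias:?}");
}
```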
Output-Head And Bias Gradients
The problem this block solves is:
Convert the logit gradient into gradients for the output matrix and bias.
The core loop:
for (vocab_id, dlogit) in dlogits.iter().copied().enumerate() {
grad_bias[vocab_id] += dlogit;
for (feature, x_feature) in x.iter().copied().enumerate() {
grad_lm_head[feature][vocab_id] += x_feature * dlogit;
}
}
Rust Syntax
The outer loop visits every vocabulary output.
The inner loop visits every feature of the input vector.
The bias gradient is just the logit gradient.
The weight gradient is:
input feature * output gradient
ML Concept
For a linear layer:
logits = xW + b
the gradient of a weight is:
input activation * output gradient
This is the same pattern used in larger neural networks.
Category-Theory Concept
This is the local backward map for the affine projection stage.
It translates changes needed at the output object Logits into changes in the
parameter object.
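Here is a standalone version of the core loop. It uses plain slices instead of the crate's types and a hypothetical accumulate_head_grads name. The buffer layout follows the chapter: grad_lm_head is indexed [feature][vocab_id], and grad_bias has one entry per vocabulary output.

```rust
/// Accumulate one example's gradients for logits = xW + b.
/// grad_lm_head is d_model x vocab_size; grad_bias has vocab_size entries.
fn accumulate_head_grads(
    x: &[f32],
    dlogits: &[f32],
    grad_lm_head: &mut [Vec<f32>],
    grad_bias: &mut [f32],
) {
    for (vocab_id, dlogit) in dlogits.iter().copied().enumerate() {
        // d(logit)/d(bias) = 1, so the bias gradient is the logit gradient.
        grad_bias[vocab_id] += dlogit;
        for (feature, x_feature) in x.iter().copied().enumerate() {
            // Weight gradient: input activation times output gradient.
            grad_lm_head[feature][vocab_id] += x_feature * dlogit;
        }
    }
}

fn main() {
    let mut grad_lm_head = vec![vec![0.0_f32; 3]; 2];
    let mut grad_bias = vec![0.0_f32; 3];
    accumulate_head_grads(&[1.0, 2.0], &[0.7, -0.8, 0.1], &mut grad_lm_head, &mut grad_bias);
    println!("{grad_lm_head:?} {grad_bias:?}");
}
```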
Embedding Gradient
The problem this block solves is:
Move the output error backward through the language-model head to the input embedding row.
The block:
for (feature, grad_feature) in grad_embedding[input_id].iter_mut().enumerate() {
let dx = params.lm_head[feature]
.iter()
.zip(dlogits.iter())
.map(|(weight, dlogit)| weight * dlogit)
.sum::<f32>();
*grad_feature += dx;
}
Rust Syntax
The loop mutates the gradient row for the input token.
For each feature, it pairs:
weights from that feature to every vocab output
dlogits for every vocab output
Then it sums:
weight * dlogit
ML Concept
This is backpropagation through the linear head.
It tells the embedding row how it should change so the future logits improve.
Only the row for the current input token receives an embedding gradient.
Category-Theory Concept
This is another local backward map.
The training endomorphism is built by composing local derivative information from output back toward parameters.
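The same computation can be written as a matrix-vector product. Each embedding feature collects the sum of its outgoing head weights multiplied by the logit gradients. The function name embedding_row_grad and the slice-based signature are illustrative assumptions, not the crate's API.

```rust
/// Backpropagate the logit gradient through the output head:
/// dx[feature] = sum over vocab of lm_head[feature][vocab] * dlogits[vocab].
fn embedding_row_grad(lm_head: &[Vec<f32>], dlogits: &[f32]) -> Vec<f32> {
    lm_head
        .iter()
        .map(|row| {
            row.iter()
                .zip(dlogits)
                .map(|(weight, dlogit)| weight * dlogit)
                .sum::<f32>()
        })
        .collect()
}

fn main() {
    let lm_head = vec![vec![1.0_f32, 0.0, 0.0], vec![0.0, 1.0, 2.0]];
    println!("{:?}", embedding_row_grad(&lm_head, &[0.7, -0.8, 0.1]));
}
```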
Parameter Update
The problem this block solves is:
Turn accumulated gradients into new parameters.
The update can be read as a loop around the same object:
Parameters_t
|
| prediction on TrainingSet
v
Average Loss
|
| local gradients
v
Gradient Accumulators
|
| subtract learning_rate * average_gradient
v
Parameters_{t+1}
The diagram has one important boundary: the first and last objects are both
Parameters. Everything in the middle explains how one state becomes the next
state.
The code computes:
let batch_scale = 1.0 / self.dataset.len() as f32;
let learning_rate = self.learning_rate.value();
let mut updated = params.clone();
Then it subtracts scaled gradients from every parameter.
Rust Syntax
batch_scale averages the accumulated gradients.
learning_rate extracts the raw scalar.
updated = params.clone() creates the output parameter object.
The following loops mutate updated, not the original params.
Finally:
Ok(updated)
returns the new model state.
ML Concept
The update rule is:
parameter_new = parameter_old - learning_rate * average_gradient
This is gradient descent.
Category-Theory Concept
The final result has the same object type as the input:
Parameters -> Parameters
That completes the endomorphism.
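Here is a sketch of the clone-then-subtract pattern for one matrix-shaped parameter, using a hypothetical descend function over nested Vec<f32>. The function borrows its input and leaves it unchanged, then returns a new value. That is what keeps the step usable as state -> state.

```rust
/// Clone the current state, then move every entry against its
/// averaged gradient. The borrowed input is not modified.
fn descend(
    params: &[Vec<f32>],
    grads: &[Vec<f32>],
    learning_rate: f32,
    batch_scale: f32,
) -> Vec<Vec<f32>> {
    let mut updated = params.to_vec();
    for (row, grad_row) in updated.iter_mut().zip(grads) {
        for (value, grad) in row.iter_mut().zip(grad_row) {
            *value -= learning_rate * grad * batch_scale;
        }
    }
    updated
}

fn main() {
    let params = vec![vec![1.0_f32, 1.0]];
    let grads = vec![vec![4.0_f32, -4.0]];
    println!("{:?}", descend(&params, &grads, 0.5, 0.25));
}
```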
Regression Test
The problem this block solves is:
Check the learner-visible promise that repeated training reduces loss on the tiny dataset.
The test:
#[test]
fn repeated_training_step_reduces_loss() -> CtResult<()> {
let tokens = TokenSequence::from_indices([1, 2, 3, 4, 1, 2, 3, 4])?;
let dataset = DatasetWindowing.apply(tokens)?;
let params = Parameters::init(VocabSize::new(5)?, ModelDimension::new(4)?);
let before = average_loss(&params, &dataset)?;
let train_step = TrainStep::new(dataset.clone(), LearningRate::new(1.0)?);
let trained = apply_endomorphism_n_times(&train_step, params, StepCount::new(80))?;
let after = average_loss(&trained, &dataset)?;
assert!(after.value() < before.value());
Ok(())
}
Rust Syntax
The test returns CtResult<()>, so it can use ?.
It builds a token sequence, turns it into a training set, initializes parameters, and configures a training step.
Then it applies the endomorphism 80 times and checks the loss decreased.
ML Concept
This is not a benchmark.
It is a sanity check:
training should make the tiny model better on the tiny data
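The following self-contained toy shows the same measure-versus-update split end to end. It is not the crate's model but a hypothetical bias-only classifier whose logits ignore the input, so its parameters are a single Vec<f32>. average_loss plays the role of Parameters x TrainingSet -> Loss, and train_step plays the role of Parameters -> Parameters. The test checks only that 80 full-batch steps lower the loss on this toy data.

```rust
/// Numerically stable softmax over logits.
fn softmax(logits: &[f32]) -> Vec<f32> {
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|z| (z - max).exp()).collect();
    let total: f32 = exps.iter().sum();
    exps.iter().map(|e| e / total).collect()
}

/// Measurement: average cross-entropy of the targets under softmax(bias).
fn average_loss(bias: &[f32], targets: &[usize]) -> f32 {
    let probs = softmax(bias);
    targets.iter().map(|&t| -probs[t].ln()).sum::<f32>() / targets.len() as f32
}

/// Update: one full-batch gradient-descent step on the bias.
fn train_step(bias: &[f32], targets: &[usize], learning_rate: f32) -> Vec<f32> {
    let probs = softmax(bias);
    let mut grad = vec![0.0_f32; bias.len()];
    for &t in targets {
        // Per-example logit gradient: probabilities minus one-hot target.
        for (g, p) in grad.iter_mut().zip(&probs) {
            *g += p;
        }
        grad[t] -= 1.0;
    }
    let batch_scale = 1.0 / targets.len() as f32;
    bias.iter()
        .zip(&grad)
        .map(|(b, g)| b - learning_rate * g * batch_scale)
        .collect()
}

fn main() {
    let targets = [1, 2, 1, 1];
    let mut bias = vec![0.0_f32; 3];
    let before = average_loss(&bias, &targets);
    for _ in 0..80 {
        bias = train_step(&bias, &targets, 1.0);
    }
    let after = average_loss(&bias, &targets);
    println!("loss before: {before:.6}");
    println!("loss after: {after:.6}");
}
```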
Category-Theory Concept
The test exercises repeated endomorphism application:
Parameters0 -> Parameters1 -> ... -> Parameters80
The companion tests check the one-step contract too. One update keeps the same vocabulary size and model dimension, and invalid targets are rejected before an unsafe index can enter gradient accumulation.
Run The Example
Source snapshot: examples/03_training_endomorphism.rs
use category_theory_transformer_rs::{
CtResult, DatasetWindowing, LearningRate, ModelDimension, Morphism, Parameters, StepCount,
TokenSequence, TrainStep, VocabSize, apply_endomorphism_n_times, average_loss,
};
fn main() -> CtResult<()> {
let tokens = TokenSequence::from_indices([1, 2, 3, 4, 1, 2, 3, 4])?;
let dataset = DatasetWindowing.apply(tokens)?;
let params = Parameters::init(VocabSize::new(5)?, ModelDimension::new(4)?);
let before = average_loss(&params, &dataset)?;
let train_step = TrainStep::new(dataset.clone(), LearningRate::new(1.0)?);
let trained = apply_endomorphism_n_times(&train_step, params, StepCount::new(80))?;
let after = average_loss(&trained, &dataset)?;
println!("loss before: {:.6}", before.value());
println!("loss after: {:.6}", after.value());
println!();
println!("Typed transformation:");
println!("TrainStep : Parameters -> Parameters");
println!("Repeated endomorphism:");
println!("Parameters0 -> Parameters1 -> ... -> Parameters80");
println!("Measurement:");
println!("Parameters x TrainingSet -> Loss");
Ok(())
}
Run:
cargo run --example 03_training_endomorphism
Expected pattern:
loss before: ...
loss after: ...
Typed transformation:
TrainStep : Parameters -> Parameters
Repeated endomorphism:
Parameters0 -> Parameters1 -> ... -> Parameters80
Measurement:
Parameters x TrainingSet -> Loss
The second number should be smaller.
Example Output Transfer Checklist
The example output is deliberately small. It gives you two measurements and then names the update shape that produced the second measurement.
Use the printed lines this way:
| Example output | Boundary to own | Shortcut to reject |
|---|---|---|
| loss before: ... | measure the initial state with Parameters x TrainingSet -> Loss | treating the loss measurement as the training update |
| loss after: ... | measure the state after repeated updates | assuming one lower loss proves a full optimizer is correct |
| TrainStep : Parameters -> Parameters | one configured step consumes model state and returns model state | returning Loss, loose gradients, or one raw matrix from apply |
| Parameters0 -> Parameters1 -> ... -> Parameters80 | the same endomorphism can be applied again | repeating Parameters -> Loss as if it were training |
| Parameters x TrainingSet -> Loss | evaluation needs both model state and examples | judging the loop from one prediction alone |
This is the same separation used in standard gradient-descent explanations: compute a loss and its gradient, then update the parameters in the negative gradient direction. The measurement tells you whether the model improved. The endomorphism is the repeatable state transition that makes training possible.
If you only remember one distinction from this chapter, remember this:
Parameters -> Loss measures
Parameters -> Parameters trains
Core Mental Model
In Rust terms:
TrainStep implements Morphism<Parameters, Parameters>
In ML terms:
one full-batch gradient descent update
In category-theory terms:
an endomorphism that can be iterated
Checkpoint
Why is it useful that training returns Parameters instead of a raw matrix?
A strong answer:
Because the output can immediately be used as the input to the next
TrainStep, preserving the Parameters -> Parameters endomorphism shape.
Where This Leaves Us
This chapter turned training into a repeatable typed transformation. The model
state enters as Parameters, the training step computes gradients from the tiny
dataset, and the updated model state leaves as Parameters again.
The next chapter, Functors, Naturality, Monoids, and Chain Rule, steps back from the training loop and names reusable structures that appear across the whole course: mapping inside wrappers, changing wrapper shapes consistently, combining traces, and composing local derivative rules.
Further Reading
The problem this section solves is transfer. A framework training loop compresses several responsibilities into familiar calls. This chapter expands those responsibilities so the reader can see which object is measured, which object is updated, and why the update can repeat.
Start from the local Rust evidence:
average_loss(&params, &dataset) -> Loss
TrainStep::apply(params) -> Parameters
apply_endomorphism_n_times -> Parameters
Then compare that with a framework loop:
optimizer.zero_grad()
loss = loss_fn(model(input), target)
loss.backward()
optimizer.step()
The framework loop is compact because the model, gradient buffers, optimizer state, parameter groups, and update rule live behind framework objects. The teaching path is expanded because the reader needs to separate four ideas:
| Framework responsibility | Tiny Rust question |
|---|---|
| clear old gradients | Which temporary gradient accumulators start empty inside TrainStep::apply? |
| compute current loss | Which call has shape Parameters x TrainingSet -> Loss? |
| compute gradients | Which local derivative changes Distribution into a logit gradient? |
| update parameters | Which call returns the next full Parameters object? |
Read the sources in this order:
- D2L Gradient Descent: use it for the update direction and learning-rate intuition.
- D2L Backpropagation and Computational Graphs: use it for the forward-then-reverse gradient story.
- Automatic differentiation in machine learning: a survey: use it to keep “automatic differentiation”, “backpropagation”, “symbolic differentiation”, and “finite differences” separate.
- PyTorch torch.optim: use it to recognize zero_grad, backward, and step as production boundaries.
- PyTorch Autograd mechanics: use it to contrast graph-recording autograd with this chapter’s hand-written gradient path.
- Backprop as Functor: use it only as advanced context for compositional update rules.
The transfer bridge is:
production loop
-> measure current model
-> compute gradients
-> update optimizer/model state
-> repeat
The category-theory bridge is smaller and stricter:
Parameters x TrainingSet -> Loss
TrainStep(dataset, learning_rate) : Parameters -> Parameters
The first boundary measures. The second boundary updates. Only the second one
is the endomorphism that can be repeated by apply_endomorphism_n_times.
Draw the distinction like this:
\[
\begin{array}{rcl}
\mathrm{measure} &:& \mathrm{Parameters} \times \mathrm{TrainingSet} \to \mathrm{Loss} \\
\mathrm{update} &:& \mathrm{Parameters} \to \mathrm{Parameters}
\end{array}
\]
If a diagram makes the measurement arrow return Parameters, or makes the
update arrow return only Loss, the training story has changed meaning.
Checkpoint:
When reading an external optimizer or autograd reference, can you name which
part corresponds to Parameters x TrainingSet -> Loss and which part
corresponds to TrainStep(dataset, learning_rate) : Parameters -> Parameters?
These pages connect the tiny update to the surrounding vocabulary and source material:
- Glossary: endomorphism, parameters, learning rate, gradient
- References: gradient descent, computational graphs, backpropagation, and compositional learning
Practice After This Chapter
Use Exercise 5
to change the number of repeated training steps. The goal is not to tune a real
model. The goal is to see why a Parameters -> Parameters update can be
applied again and again.
Retrieval Practice
Recall
Recover the update shape before explaining the gradient.
- What makes TrainStep an endomorphism?
- Which line changes the probability vector into the logit gradient for the target class?
- Which helper repeats the same Parameters -> Parameters step many times?
Explain
Separate measurement from update.
- Why is Parameters -> Loss useful for evaluation but not itself a training endomorphism?
- Why does the training code validate input and target token bounds before accumulating gradients?
- Why does subtracting a negative target gradient increase the target bias or target weight?
Apply
Use the sign trace from this chapter.
- Suppose:
  probs = [0.65, 0.25, 0.10]
  target = 2
  learning_rate = 0.1
  batch_scale = 1.0
  bias starts at [0.0, 0.0, 0.0]
  What is dlogits, and what is the updated bias?
- If you changed StepCount::new(80) to StepCount::new(1), what would you expect to happen to the loss, and why?
- If the dataset has four examples, why does the code multiply each accumulated gradient by batch_scale = 0.25 before updating parameters?
Debug
For each invalid shortcut, name the broken shape or missing state:
returning Loss from TrainStep::apply
updating only lm_head and discarding embedding and bias
repeating Parameters -> Loss as if it were Parameters -> Parameters
skipping token bounds checks before indexing gradient buffers
A strong answer should mention the outer loop shape:
Parameters_t -> Parameters_{t+1}
The loss and gradients explain how the update is computed. They are not the object that must be returned from the training step.
Functors, Naturality, Monoids, and Chain Rule
The problem this chapter solves is:
After seeing individual ML arrows, you need names for common structures that appear across many systems: mapping inside containers, converting containers consistently, combining traces, and composing local derivatives.
The previous chapters were mostly about one pipeline. This chapter zooms out from that pipeline and asks which shapes keep appearing even when the concrete data changes. Once you can see those repeated shapes, the category-theory vocabulary stops feeling like a separate subject. It becomes a set of names for ordinary engineering moves.
This chapter covers four patterns:
Functor
NaturalTransformation
Monoid
Chain rule
They are not separate from the ML pipeline.
They explain patterns you already saw:
- mapping over many examples
- converting one wrapper shape to another
- combining pipeline traces
- composing gradients through layers
Reader orientation: This chapter is more abstract than the previous ones. Read each section in this order: first the Rust mechanism, then the ML use, then the category-theory name. The names are not decoration; they are compression for patterns that appear repeatedly in real model code.
Chapter Outcomes
By the end of this chapter, you should be able to:
- trace a functor, natural transformation, monoid, and chain-rule example through concrete Rust code,
- explain which two paths must agree in the naturality and monoid examples,
- distinguish the tiny MulOp::backward local derivative boundary from a full automatic-differentiation engine.
What You Already Know
If you have mapped over a Vec, handled an Option, appended logs, or applied
the chain rule in calculus, you already know informal versions of this chapter.
The new work is to name those repeated shapes and connect them to the same
typed pipeline discipline used earlier.
Trace Both Paths Before The Names
Before reading src/structure.rs or src/calculus.rs, trace the paths first.
The formal names in this chapter should arrive after the reader can see what
has to agree.
The first two-path check is about converting a vector into an optional first item.
Path 1:
Vec<A> --map f--> Vec<B> --first--> Option<B>
Path 2:
Vec<A> --first--> Option<A> --map f--> Option<B>
Both paths should return the same Option<B>. The code names that agreement
with naturality_square_holds_for_first_option.
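Both paths can be checked with the standard library alone, before any trait is introduced. The two function names are illustrative. Path 1 uses Iterator::map followed by next, and Path 2 uses next followed by Option::map. The test also covers the empty vector, where both paths return None.

```rust
/// Path 1: map inside the Vec, then take the first element.
fn map_then_first<A, B>(xs: Vec<A>, f: impl Fn(A) -> B) -> Option<B> {
    xs.into_iter().map(f).next()
}

/// Path 2: take the first element, then map inside the Option.
fn first_then_map<A, B>(xs: Vec<A>, f: impl Fn(A) -> B) -> Option<B> {
    xs.into_iter().next().map(f)
}

fn main() {
    let times_ten = |x: i32| x * 10;
    println!("{:?}", map_then_first(vec![1, 2, 3], times_ten));
    println!("{:?}", first_then_map(vec![1, 2, 3], times_ten));
}
```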
The second two-path check is about grouping a trace.
Path 1:
(embedding <> linear) <> softmax
Path 2:
embedding <> (linear <> softmax)
Both paths should produce the same PipelineTrace. The empty trace should also
leave any real trace unchanged.
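Here is a sketch of the same two checks using plain slices of step names and hypothetical combine and combine_all helpers. Concatenation is associative, and the empty trace is an identity. So a whole list of traces can also be reduced with a fold that starts from the empty trace, and the grouping the fold happens to use does not affect the result.

```rust
/// Combine two traces by concatenation. The empty slice is the identity.
fn combine(a: &[&'static str], b: &[&'static str]) -> Vec<&'static str> {
    [a, b].concat()
}

/// Reduce many traces by folding from the empty trace. This relies on
/// associativity and identity, so the fold's grouping does not matter.
fn combine_all(traces: &[Vec<&'static str>]) -> Vec<&'static str> {
    traces.iter().fold(Vec::new(), |acc, trace| combine(&acc, trace))
}

fn main() {
    let traces = vec![vec!["embedding"], vec!["linear"], vec!["softmax"]];
    println!("{:?}", combine_all(&traces));
}
```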
The chain-rule check has a different shape: one path goes forward through the computation, and one path carries derivative information backward.
Forward:
x, y -> z = x * y -> L
Backward:
dL/dz -> dL/dx and dL/dy
For z = x * y, the local derivatives are:
dz/dx = y
dz/dy = x
So the backward path scales those local derivatives by the upstream gradient:
dL/dx = dL/dz * y
dL/dy = dL/dz * x
That is the useful mental model before any abstraction:
same result by two structural paths
or
same gradient signal carried through local paths
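To see two local rules compose, here is a small sketch that picks L = z^2 as an illustrative loss. This is an assumption, since the chapter leaves L unspecified. The square rule gives dL/dz = 2z, and the multiply rule then scales that upstream gradient by y and x. With x = 2 and y = 3, z = 6, so L = 36, dL/dx = 36, and dL/dy = 24. A central finite difference in the test independently checks dL/dx.

```rust
/// Forward and backward pass for L = z^2 with z = x * y.
/// Returns (L, dL/dx, dL/dy).
fn forward_backward(x: f32, y: f32) -> (f32, f32, f32) {
    // Forward: x, y -> z -> L.
    let z = x * y;
    let loss = z * z;
    // Local rule for the square: dL/dz = 2z.
    let dl_dz = 2.0 * z;
    // Local rule for the product: scale dL/dz by dz/dx = y and dz/dy = x.
    let dl_dx = dl_dz * y;
    let dl_dy = dl_dz * x;
    (loss, dl_dx, dl_dy)
}

fn main() {
    let (loss, dl_dx, dl_dy) = forward_backward(2.0, 3.0);
    println!("L = {loss}, dL/dx = {dl_dx}, dL/dy = {dl_dy}");
}
```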
Now the names can be useful. A functor preserves wrapper shape while mapping inside it. A natural transformation makes the two wrapper-conversion paths agree. A monoid makes the grouping of trace combinations irrelevant and provides an empty trace that changes nothing. The chain rule composes local derivative information.
Source Snapshots
src/structure.rs covers functors, natural transformations, and monoids.
Source snapshot: src/structure.rs
/// A minimal functor interface for this tutorial.
pub trait Functor<A, B> {
type WrappedA;
type WrappedB;
fn fmap<F>(wrapped: Self::WrappedA, f: F) -> Self::WrappedB
where
F: Fn(A) -> B;
}
pub struct VecFunctor;
impl<A, B> Functor<A, B> for VecFunctor {
type WrappedA = Vec<A>;
type WrappedB = Vec<B>;
fn fmap<F>(wrapped: Vec<A>, f: F) -> Vec<B>
where
F: Fn(A) -> B,
{
wrapped.into_iter().map(f).collect()
}
}
pub struct OptionFunctor;
impl<A, B> Functor<A, B> for OptionFunctor {
type WrappedA = Option<A>;
type WrappedB = Option<B>;
fn fmap<F>(wrapped: Option<A>, f: F) -> Option<B>
where
F: Fn(A) -> B,
{
wrapped.map(f)
}
}
/// A structure-preserving conversion between wrappers.
pub trait NaturalTransformation<A> {
type From;
type To;
fn transform(from: Self::From) -> Self::To;
}
/// Natural transformation from `Vec<A>` to `Option<A>` by taking the first item.
pub struct VecToFirstOption;
impl<A> NaturalTransformation<A> for VecToFirstOption {
type From = Vec<A>;
type To = Option<A>;
fn transform(from: Vec<A>) -> Option<A> {
from.into_iter().next()
}
}
pub fn naturality_square_holds_for_first_option() -> bool {
let xs = vec![1, 2, 3];
let f = |x| x * 10;
let path_top_then_right = VecToFirstOption::transform(VecFunctor::fmap(xs.clone(), f));
let path_left_then_bottom = OptionFunctor::fmap(VecToFirstOption::transform(xs), f);
path_top_then_right == path_left_then_bottom
}
pub trait Monoid: Sized {
fn empty() -> Self;
fn combine(&self, other: &Self) -> Self;
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TraceStep(&'static str);
impl TraceStep {
pub fn new(name: &'static str) -> Self {
Self(name)
}
pub fn name(&self) -> &'static str {
self.0
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PipelineTrace(Vec<TraceStep>);
impl PipelineTrace {
pub fn from_steps(steps: impl IntoIterator<Item = TraceStep>) -> Self {
Self(steps.into_iter().collect())
}
pub fn names(&self) -> Vec<&'static str> {
self.0.iter().map(TraceStep::name).collect()
}
}
impl Monoid for PipelineTrace {
fn empty() -> Self {
PipelineTrace(vec![])
}
fn combine(&self, other: &Self) -> Self {
let mut combined = self.0.clone();
combined.extend_from_slice(&other.0);
PipelineTrace(combined)
}
}
pub fn monoid_laws_hold_for_pipeline_trace() -> bool {
let a = PipelineTrace::from_steps(vec![TraceStep::new("embedding")]);
let b = PipelineTrace::from_steps(vec![TraceStep::new("linear")]);
let c = PipelineTrace::from_steps(vec![TraceStep::new("softmax")]);
let identity = PipelineTrace::empty();
let left_identity = identity.combine(&a) == a;
let right_identity = a.combine(&identity) == a;
let associativity = a.combine(&b).combine(&c) == a.combine(&b.combine(&c));
left_identity && right_identity && associativity
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn vec_functor_preserves_identity_for_values() {
let values = vec![1, 2, 3];
assert_eq!(VecFunctor::fmap(values.clone(), |value| value), values);
}
#[test]
fn vec_functor_preserves_composition_for_values() {
let values = vec![1, 2, 3];
let add_one = |value| value + 1;
let double = |value| value * 2;
let map_then_map = VecFunctor::fmap(VecFunctor::fmap(values.clone(), add_one), double);
let map_composed = VecFunctor::fmap(values, |value| double(add_one(value)));
assert_eq!(map_then_map, map_composed);
}
#[test]
fn option_functor_preserves_absence() {
let missing = OptionFunctor::fmap(None::<i32>, |value| value * 10);
assert_eq!(missing, None);
}
#[test]
fn naturality_square_commutes() {
assert!(naturality_square_holds_for_first_option());
}
#[test]
fn pipeline_trace_obeys_monoid_laws() {
assert!(monoid_laws_hold_for_pipeline_trace());
}
}
src/calculus.rs keeps the chain-rule example deliberately small.
Source snapshot: src/calculus.rs
use crate::error::{CtError, CtResult};
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Scalar(f32);
impl Scalar {
pub fn new(value: f32) -> CtResult<Self> {
if !value.is_finite() {
return Err(CtError::InvalidLoss(value));
}
Ok(Self(value))
}
pub fn value(&self) -> f32 {
self.0
}
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct LocalGradient(f32);
impl LocalGradient {
pub fn new(value: f32) -> CtResult<Self> {
if !value.is_finite() {
return Err(CtError::InvalidLoss(value));
}
Ok(Self(value))
}
pub fn value(&self) -> f32 {
self.0
}
}
#[derive(Debug, Clone, Copy)]
pub struct MulOp;
impl MulOp {
pub fn forward(&self, x: Scalar, y: Scalar) -> CtResult<Scalar> {
Scalar::new(x.value() * y.value())
}
/// Given upstream gradient dL/dz, return `(dL/dx, dL/dy)` for `z = x * y`.
pub fn backward(
&self,
x: Scalar,
y: Scalar,
upstream: LocalGradient,
) -> CtResult<(LocalGradient, LocalGradient)> {
let dz_dx = y.value();
let dz_dy = x.value();
Ok((
LocalGradient::new(upstream.value() * dz_dx)?,
LocalGradient::new(upstream.value() * dz_dy)?,
))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn multiply_backward_returns_local_chain_rule_gradients() -> CtResult<()> {
let mul = MulOp;
let x = Scalar::new(2.0)?;
let y = Scalar::new(3.0)?;
let upstream = LocalGradient::new(1.0)?;
let (dl_dx, dl_dy) = mul.backward(x, y, upstream)?;
assert_eq!(dl_dx.value(), 3.0);
assert_eq!(dl_dy.value(), 2.0);
Ok(())
}
#[test]
fn multiply_backward_scales_with_upstream_gradient() -> CtResult<()> {
let mul = MulOp;
let x = Scalar::new(2.0)?;
let y = Scalar::new(3.0)?;
let upstream = LocalGradient::new(4.0)?;
let (dl_dx, dl_dy) = mul.backward(x, y, upstream)?;
assert_eq!(dl_dx.value(), 12.0);
assert_eq!(dl_dy.value(), 8.0);
Ok(())
}
}
The Whole Structure File
src/structure.rs defines:
Functor<A, B>
VecFunctor
OptionFunctor
NaturalTransformation<A>
VecToFirstOption
Monoid
TraceStep
PipelineTrace
Each block gives a Rust handle to one abstract pattern.
Worked Example: Mapping Over A Vector
Before reading the traits, start with the plain Rust operation that motivates them. Mapping over a vector means taking each item out, applying a function, and collecting the new values into another vector:
#![allow(unused)]
fn main() {
let values = vec![1, 2, 3];
let doubled: Vec<i32> = values.into_iter().map(|x| x * 2).collect();
assert_eq!(doubled, vec![2, 4, 6]);
}
There is no category theory hidden in that snippet. It is just ordinary Rust.
The category-theory word functor appears when we notice the reusable shape:
the code changes the contents while preserving the surrounding container.
Self-Check
Before reading the Functor trait, explain what stayed the same and what
changed in the Vec mapping example.
Four Patterns, Four Questions
This chapter is easier if you do not try to memorize four category-theory words at once. Treat each word as an answer to one engineering question.
| Pattern | Engineering question | Course example |
|---|---|---|
| Functor | Can I transform values inside a wrapper without changing the wrapper shape? | VecFunctor::fmap, OptionFunctor::fmap |
| Natural transformation | Can I convert one wrapper shape to another consistently? | Vec<A> -> Option<A> |
| Monoid | Can I combine many values with an empty value that changes nothing? | PipelineTrace |
| Chain rule | Can I compose local derivative signals through a computation? | MulOp::backward |
The chapter’s tests mirror that progression. They check small functor-law examples, a naturality square, monoid laws for traces, and the local chain-rule gradient for multiplication. The tests are not a full mathematical proof of all possible cases. They are executable anchors for the patterns the book is teaching.
The law and boundary map is:
| Pattern | Law or boundary | What the code checks |
|---|---|---|
| Functor | identity and composition should be preserved | VecFunctor examples map identity and composed functions |
| Natural transformation | mapping before conversion should match conversion before mapping | naturality_square_commutes |
| Monoid | empty value and grouping should not change the combined trace | pipeline_trace_obeys_monoid_laws |
| Chain rule | upstream gradient should scale local derivatives | multiply_backward_scales_with_upstream_gradient |
Use this table as the chapter’s visual index. If a later section feels abstract, return to the row and ask which Rust test makes the law visible.
What Is A Law, A Test, Or An Analogy?
This chapter uses mathematical names, executable tests, and engineering analogies. Those are related, but they are not the same kind of evidence.
| Claim in this chapter | Evidence in this repository | How to read it |
|---|---|---|
| Functor identity and composition are laws | VecFunctor tests with concrete values | The tests are examples of the laws, not a proof for every possible type |
| Vec<A> -> Option<A> is structure-preserving | naturality_square_commutes | The test checks one concrete naturality square for the first-item conversion |
| PipelineTrace behaves like a monoid | pipeline_trace_obeys_monoid_laws | The code checks empty-trace and grouping behavior for this trace type |
| MulOp::backward follows the chain rule | multiply_backward_scales_with_upstream_gradient | The test checks one local derivative rule for multiplication |
| Larger ML systems can use the same patterns | chapter prose and exercises | This is a transfer analogy until a larger typed implementation exists |
The rule for this book is conservative: a law word should point to a concrete Rust test, and an analogy should be named as an analogy. That keeps the chapter useful without pretending that a few tests prove all of category theory.
Source-Backed Precision Rules
This chapter uses external sources to keep the structure vocabulary precise. Each source supports a limited claim; these citations are not proof that this crate implements a full category-theory library or a production automatic-differentiation engine.
| Source | What the source supports | Local rule in this chapter | Rust evidence |
|---|---|---|---|
| Categories for the Working Mathematician | Functors, natural transformations, and monoids are formal category-theory structures with laws, not just programming metaphors. | Use formal vocabulary only where the local Rust boundary names a concrete structure and the text states what the tests do not prove. | Functor<A, B>, naturality_square_commutes, pipeline_trace_obeys_monoid_laws |
| Category Theory for Programming | Category-theory topics can be introduced through programming-shaped examples and functional-language structure. | Use Functor, naturality, and monoid as names for checked local patterns, not as a claim that the crate models all categorical laws. | Functor<A, B>, VecFunctor, OptionFunctor, pipeline_trace_obeys_monoid_laws |
| Seven Sketches | Applied category theory can be taught through concrete examples before formal generality. | The chapter introduces laws through inspectable Rust examples before broad abstraction. | naturality_square_holds_for_first_option, PipelineTrace |
| D2L Backpropagation and Computational Graphs | Forward propagation stores intermediate values, and backpropagation computes gradients through the graph using the chain rule. | The local calculus example keeps only one operation and one upstream-gradient boundary. | MulOp::forward, MulOp::backward, LocalGradient |
| Automatic differentiation in machine learning: a survey | Automatic differentiation evaluates derivatives of programs and is broader than one hand-written backpropagation example. | Keep MulOp::backward framed as one local derivative boundary, not as a general AD implementation. | MulOp::backward, LocalGradient |
| PyTorch Autograd Mechanics | Production autograd records a graph, saves needed tensors, and traverses the graph backward with the chain rule. | MulOp::backward is a microscope for one local derivative rule, not a replacement for dynamic autograd. | multiply_backward_returns_local_chain_rule_gradients, multiply_backward_scales_with_upstream_gradient |
| Backprop as Functor | Backpropagation and parameter-update rules can be studied compositionally under stated assumptions. | The chapter uses this as advanced context only; the local claim is a pair of explicit derivative tests, not a monoidal-functor proof. | MulOp::backward, cargo test calculus::tests |
The transfer pattern is:
source claim -> local typed boundary -> validation command or test
For this chapter, that means reading cargo run --example 04_structure_and_calculus, cargo test structure::tests, and cargo test calculus::tests as evidence for these small law-shaped examples, not as
evidence that every functor, natural transformation, monoid, or differentiable
program has been modeled.
Functor<A, B>
The problem this block solves is:
The code needs a name for “apply a function inside a wrapper while keeping the wrapper shape.”
First principle: a trait is a contract. It says, “any type that implements this
trait must provide these associated types and this method.” Here the contract is
small on purpose. It does not try to model every possible functor in
mathematics; it gives this tutorial one precise place to talk about fmap.
The block:
/// A minimal functor interface for this tutorial.
pub trait Functor<A, B> {
type WrappedA;
type WrappedB;
fn fmap<F>(wrapped: Self::WrappedA, f: F) -> Self::WrappedB
where
F: Fn(A) -> B;
}
Rust Syntax
Functor<A, B> is a trait. A trait is Rust’s way to name behavior that many
types can implement.
Here, the behavior is:
map a function through some wrapper shape
A and B are generic type parameters:
A = input item type
B = output item type
Generic means the trait is not tied to one concrete type like i32 or
String. The same trait can describe Vec<i32> -> Vec<String>,
Option<TokenId> -> Option<Embedding>, or any other pair of item types.
It has associated types:
type WrappedA;
type WrappedB;
Associated types are type names chosen by each implementation of the trait. They let the trait say:
every implementer must tell us what wrapped input and wrapped output mean
For VecFunctor, those associated types become Vec<A> and Vec<B>.
For OptionFunctor, they become Option<A> and Option<B>.
The method:
fn fmap<F>(wrapped: Self::WrappedA, f: F) -> Self::WrappedB
where
F: Fn(A) -> B;
means:
Given a wrapped A and a function A -> B, produce a wrapped B.
The where clause is a readable place to put a bound. The bound:
F: Fn(A) -> B
means F must be callable like a function that consumes an A and returns a
B.
Here is the real crate API in the smallest useful form:
use category_theory_transformer_rs::{Functor, VecFunctor};
let lengths = VecFunctor::fmap(vec!["cat", "rust"], |word| word.len());
assert_eq!(lengths, vec![3, 4]);
What to notice: the call names the structure once, as VecFunctor::fmap. The closure only describes the item-level operation, &str -> usize. The vector shape is handled by the functor implementation.
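The bound F: Fn(A) -> B accepts both named functions and closures that capture surrounding values. This sketch copies the Functor trait and the VecFunctor implementation from src/structure.rs so that it compiles without the crate. The names token_length, lengths_with_offset and offset are illustrative and are not part of the crate.

```rust
/// Copied from src/structure.rs so this snippet compiles on its own.
pub trait Functor<A, B> {
    type WrappedA;
    type WrappedB;
    fn fmap<F>(wrapped: Self::WrappedA, f: F) -> Self::WrappedB
    where
        F: Fn(A) -> B;
}

pub struct VecFunctor;

impl<A, B> Functor<A, B> for VecFunctor {
    type WrappedA = Vec<A>;
    type WrappedB = Vec<B>;
    fn fmap<F>(wrapped: Vec<A>, f: F) -> Vec<B>
    where
        F: Fn(A) -> B,
    {
        wrapped.into_iter().map(f).collect()
    }
}

/// Hypothetical helper: a named function also satisfies `F: Fn(A) -> B`.
fn token_length(word: &str) -> usize {
    word.len()
}

/// Maps with a named function, then with a closure that captures `offset`.
pub fn lengths_with_offset(words: Vec<&str>, offset: usize) -> Vec<usize> {
    let lengths = VecFunctor::fmap(words, token_length);
    // The closure only reads `offset`, so it still implements Fn.
    VecFunctor::fmap(lengths, |len: usize| len + offset)
}

fn main() {
    assert_eq!(lengths_with_offset(vec!["cat", "rust"], 10), vec![13, 14]);
}
```

In both calls, the wrapper logic lives in VecFunctor. The caller supplies only the item-level function.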
ML Concept
In ML, you often apply the same transformation across a structure:
map preprocessing over a batch
map token conversion over a sequence
map loss computation over examples
The wrapper might be:
Vec
Option
Result
batch tensor
The idea is:
transform the contents, preserve the surrounding structure
Category-Theory Concept
A functor maps:
objects -> objects
morphisms -> morphisms
while preserving identity and composition.
This tutorial’s trait is deliberately small. It focuses on the practical
fmap operation.
The tests in src/structure.rs check the two law-shaped habits this chapter
uses: mapping identity leaves values unchanged, and mapping two functions in
sequence matches mapping their composition.
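The same two laws can be checked over more than one input, including the empty vector and a function that changes the item type. This sketch calls Iterator::map directly, which is what VecFunctor::fmap delegates to, so it compiles on its own. The helper names fmap_vec and vec_functor_laws_hold are hypothetical. A passing run is still evidence for these inputs only, not a proof.

```rust
/// Same body as VecFunctor::fmap in src/structure.rs.
fn fmap_vec<A, B>(xs: Vec<A>, f: impl Fn(A) -> B) -> Vec<B> {
    xs.into_iter().map(f).collect()
}

/// Checks the identity and composition laws for one input and two functions.
fn vec_functor_laws_hold<A, B, C, F, G>(xs: Vec<A>, f: F, g: G) -> bool
where
    A: Clone + PartialEq,
    C: PartialEq,
    F: Fn(A) -> B,
    G: Fn(B) -> C,
{
    // Identity law: mapping |x| x leaves the vector unchanged.
    let identity = fmap_vec(xs.clone(), |x| x) == xs;
    // Composition law: map f then map g equals map (g after f).
    let composition = fmap_vec(fmap_vec(xs.clone(), &f), &g) == fmap_vec(xs, |x| g(f(x)));
    identity && composition
}

fn main() {
    assert!(vec_functor_laws_hold(vec![1, 2, 3], |x: i32| x + 1, |x: i32| x * 2));
    assert!(vec_functor_laws_hold(Vec::<i32>::new(), |x: i32| x + 1, |x: i32| x * 2));
    assert!(vec_functor_laws_hold(vec!["cat", "rust"], |w: &str| w.len(), |n: usize| n.to_string()));
}
```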
VecFunctor
The problem this block solves is:
Demonstrate fmap for lists.
The block:
pub struct VecFunctor;
impl<A, B> Functor<A, B> for VecFunctor {
type WrappedA = Vec<A>;
type WrappedB = Vec<B>;
fn fmap<F>(wrapped: Vec<A>, f: F) -> Vec<B>
where
F: Fn(A) -> B,
{
wrapped.into_iter().map(f).collect()
}
}
Rust Syntax
VecFunctor is a unit struct. It stores no state.
The impl block is a trait implementation:
impl<A, B> Functor<A, B> for VecFunctor
Read it as:
for every A and B, VecFunctor knows how to behave as Functor<A, B>
The implementation chooses the associated types:
WrappedA = Vec<A>
WrappedB = Vec<B>
The method consumes the vector:
wrapped.into_iter()
maps the function over every item:
.map(f)
and collects the result:
.collect()
The runnable companion example uses the same real crate API:
use category_theory_transformer_rs::{Functor, VecFunctor};
let token_ids = vec![1, 2, 3];
let shifted = VecFunctor::fmap(token_ids, |id| id + 100);
assert_eq!(shifted, vec![101, 102, 103]);
ML Concept
If you have a batch of token IDs:
[TokenId(1), TokenId(2), TokenId(3)]
and a function:
TokenId -> Vector
mapping over the batch gives:
[Vector, Vector, Vector]
That is the same shape as VecFunctor.
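Here is a minimal sketch of that batch mapping. It assumes a hypothetical TokenId newtype and a tiny hard-coded embedding table, and neither is the crate's own embedding code. The mapping produces exactly one vector per token, in the same order.

```rust
#[derive(Debug, Clone, Copy, PartialEq)]
struct TokenId(usize);

/// Hypothetical 2-dimensional embedding table: row i is the vector for TokenId(i).
const TABLE: [[f32; 2]; 4] = [[0.0, 0.0], [0.1, 0.2], [0.3, 0.4], [0.5, 0.6]];

/// Maps TokenId -> Vec<f32> over a whole batch, the same shape as VecFunctor::fmap.
/// Panics on an id outside the table; a real lookup would return a Result.
fn embed_batch(batch: Vec<TokenId>) -> Vec<Vec<f32>> {
    batch.into_iter().map(|id| TABLE[id.0].to_vec()).collect()
}

fn main() {
    let vectors = embed_batch(vec![TokenId(1), TokenId(2), TokenId(3)]);
    // One output vector per input token, in the same order.
    assert_eq!(vectors.len(), 3);
    assert_eq!(vectors[0], vec![0.1, 0.2]);
    assert_eq!(vectors[2], vec![0.5, 0.6]);
}
```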
Category-Theory Concept
Vec is list-like structure.
Mapping preserves the list shape:
List A -> List B
Mapping keeps both the length and the order: item i of the output comes from item i of the input.
OptionFunctor
The problem this block solves is:
Demonstrate the same functor idea for optional values.
The block:
pub struct OptionFunctor;
impl<A, B> Functor<A, B> for OptionFunctor {
type WrappedA = Option<A>;
type WrappedB = Option<B>;
fn fmap<F>(wrapped: Option<A>, f: F) -> Option<B>
where
F: Fn(A) -> B,
{
wrapped.map(f)
}
}
Rust Syntax
The wrapper types are:
Option<A>
Option<B>
The implementation delegates to Rust’s built-in:
wrapped.map(f)
If the value is Some(a), it becomes Some(f(a)).
If the value is None, it stays None.
The important beginner point is that Option makes absence explicit in the
type. You cannot accidentally treat a missing value as a real value without
handling the None case.
use category_theory_transformer_rs::{Functor, OptionFunctor};
let present = OptionFunctor::fmap(Some(7), |value| value * 2);
let missing = OptionFunctor::fmap(None::<i32>, |value| value * 2);
assert_eq!(present, Some(14));
assert_eq!(missing, None);
ML Concept
Optional values appear when data may be absent:
maybe first token
maybe cached embedding
maybe resolved department
Mapping over Option lets you transform present values without inventing a
fake value for missing ones.
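A cache lookup with the standard HashMap::get already returns an Option. Mapping over it transforms a hit and leaves a miss as None, with no placeholder value. The cache contents and the cached_squared_norm helper below are hypothetical.

```rust
use std::collections::HashMap;

/// Returns the squared L2 norm of a cached embedding, or None on a cache miss.
fn cached_squared_norm(cache: &HashMap<u32, Vec<f32>>, token: u32) -> Option<f32> {
    // HashMap::get returns Option<&Vec<f32>>; the closure only runs on Some.
    cache
        .get(&token)
        .map(|vector| vector.iter().map(|x| x * x).sum::<f32>())
}

fn main() {
    let mut cache = HashMap::new();
    cache.insert(7, vec![3.0, 4.0]);
    assert_eq!(cached_squared_norm(&cache, 7), Some(25.0));
    assert_eq!(cached_squared_norm(&cache, 8), None);
}
```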
Category-Theory Concept
Option is a context representing possible absence.
fmap lifts:
A -> B
to:
Option<A> -> Option<B>
Conceptual Extension: Distribution<T>::map
The problem this block solves is:
A probabilistic value may contain many possible outcomes. Sometimes you want to transform every possible outcome while keeping its probability attached.
The current crate’s Distribution in src/domain.rs is a concrete validated
probability vector:
pub struct Distribution(Vec<f32>);
The block below is a conceptual generic version that explains the functor idea for probabilistic outcomes:
#![allow(unused)]
fn main() {
pub struct Probability(f32);
pub struct Distribution<T> {
outcomes: Vec<(T, Probability)>,
}
impl<T> Distribution<T> {
pub fn map<U>(
self,
f: impl Fn(T) -> U,
) -> Distribution<U> {
let outcomes = self
.outcomes
.into_iter()
.map(|(value, probability)| {
(f(value), probability)
})
.collect();
Distribution { outcomes }
}
}
}
The core idea is:
map changes the values inside the distribution,
but keeps the probabilities attached to them.
Rust Syntax
Start with the generic struct:
#![allow(unused)]
fn main() {
pub struct Probability(f32);
pub struct Distribution<T> {
outcomes: Vec<(T, Probability)>,
}
}
This means:
Distribution<T> = many possible T values, each paired with a probability
If T is TokenId, then the type is:
Distribution<TokenId>
If T is String, then the type is:
Distribution<String>
The method introduces a second generic type:
pub fn map<U>(...)
T is the old outcome type.
U is the new outcome type.
So the method has this shape:
Distribution<T> -> Distribution<U>
The first parameter is:
self
That means the method consumes the old distribution.
After calling:
let text_dist = token_dist.map(decode);
the old token_dist has been moved and cannot be used again.
That is why the implementation can call:
self.outcomes.into_iter()
into_iter() consumes the vector and yields owned pairs:
(T, Probability)
The function parameter is:
f: impl Fn(T) -> U
This means:
give this method a function or closure that takes T and returns U
For example, a decoder might have this shape:
TokenId -> String
Then:
Distribution<TokenId> -> Distribution<String>
The inner mapping line is:
.map(|(value, probability)| {
(f(value), probability)
})
For every pair:
(value, probability)
the code applies f to the value and leaves the probability unchanged.
So:
(TokenId(2), Probability(0.70))
can become:
("Rust", Probability(0.70))
Finally:
.collect()
collects the transformed pairs back into a vector, and:
Distribution { outcomes }
wraps them in the new distribution.
ML Concept
Imagine a model returns possible next tokens:
TokenId(2) -> 0.70
TokenId(4) -> 0.20
TokenId(3) -> 0.10
Those token IDs are useful to the model, but a learner or UI might need text:
TokenId(2) -> "Rust"
TokenId(4) -> "Python"
TokenId(3) -> "."
map changes the representation:
Distribution<TokenId> -> Distribution<String>
The values change:
TokenId(2) becomes "Rust"
TokenId(4) becomes "Python"
TokenId(3) becomes "."
The probabilities do not change:
0.70 stays 0.70
0.20 stays 0.20
0.10 stays 0.10
So map is for changing the meaning or representation of each possible
outcome, not for changing the probability mass.
Category-Theory Concept
This is the functor pattern for probabilistic context.
Given a normal deterministic function:
f : T -> U
map lifts it into the distribution context:
Distribution<T> -> Distribution<U>
In functional-programming notation:
fmap : (T -> U) -> Distribution<T> -> Distribution<U>
The outer structure is preserved:
same number of possible outcomes
same probabilities
same probabilistic context
Only the inner values are transformed.
That is the same pattern as:
Option<T> -> Option<U>
Vec<T> -> Vec<U>
Distribution<T> -> Distribution<U>
Different context, same functor idea.
Concrete Example
Here is a complete conceptual example:
#![allow(unused)]
fn main() {
#[derive(Debug, Clone, Copy)]
pub struct TokenId(pub usize);
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Probability(pub f32);
#[derive(Debug, Clone)]
pub struct Distribution<T> {
outcomes: Vec<(T, Probability)>,
}
impl<T> Distribution<T> {
pub fn new(outcomes: Vec<(T, Probability)>) -> Self {
Self { outcomes }
}
pub fn map<U, F>(self, mut f: F) -> Distribution<U>
where
F: FnMut(T) -> U,
{
let outcomes = self
.outcomes
.into_iter()
.map(|(value, probability)| {
(f(value), probability)
})
.collect();
Distribution { outcomes }
}
}
let vocab = ["I", "love", "Rust", "."];
let token_dist = Distribution::new(vec![
(TokenId(2), Probability(0.70)),
(TokenId(3), Probability(0.30)),
]);
let text_dist = token_dist.map(|token| {
vocab[token.0].to_string()
});
assert_eq!(
text_dist.outcomes,
vec![
("Rust".to_string(), Probability(0.70)),
(".".to_string(), Probability(0.30)),
],
);
}
Conceptually, the result is:
"Rust" -> 0.70
"." -> 0.30
Why Fn(T) -> U
The signature:
f: impl Fn(T) -> U
accepts functions and closures.
It is more flexible than:
fn(T) -> U
because closures can capture values from the surrounding scope:
let vocab = ["I", "love", "Rust", "."];
let text_dist = token_dist.map(|token| {
vocab[token.0].to_string()
});
The closure uses vocab from outside the closure body.
Why A Library Might Use FnMut
The pedagogical signature:
f: impl Fn(T) -> U
is easy to read.
A more flexible library signature is often:
pub fn map<U, F>(self, mut f: F) -> Distribution<U>
where
F: FnMut(T) -> U,
FnMut allows the closure to mutate captured state.
For example:
let mut counter = 0;
let numbered = token_dist.map(|token| {
counter += 1;
(counter, token)
});
The key ownership rule is unchanged:
the method consumes the old distribution and moves each T into f
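The counter snippet compiles only with the FnMut version of map from the concrete example. Here it is as a complete sketch. The TokenId, Probability and Distribution definitions are repeated so the block stands alone. The number_outcomes helper is hypothetical.

```rust
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct TokenId(pub usize);

#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Probability(pub f32);

#[derive(Debug, Clone)]
pub struct Distribution<T> {
    pub outcomes: Vec<(T, Probability)>,
}

impl<T> Distribution<T> {
    /// FnMut lets the closure update captured state between calls.
    pub fn map<U, F>(self, mut f: F) -> Distribution<U>
    where
        F: FnMut(T) -> U,
    {
        let outcomes = self
            .outcomes
            .into_iter()
            .map(|(value, probability)| (f(value), probability))
            .collect();
        Distribution { outcomes }
    }
}

/// Numbers each outcome in order; returns the mapped distribution and the final count.
pub fn number_outcomes(dist: Distribution<TokenId>) -> (Distribution<(u32, TokenId)>, u32) {
    let mut counter = 0;
    let numbered = dist.map(|token| {
        counter += 1; // mutating captured state requires FnMut
        (counter, token)
    });
    (numbered, counter)
}

fn main() {
    let dist = Distribution {
        outcomes: vec![(TokenId(2), Probability(0.7)), (TokenId(3), Probability(0.3))],
    };
    let (numbered, count) = number_outcomes(dist);
    assert_eq!(count, 2);
    assert_eq!(numbered.outcomes[1], ((2, TokenId(3)), Probability(0.3)));
}
```

The probabilities pass through unchanged. Only the values gain a position number.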
map Versus flat_map
Use map when each possible value becomes one transformed value:
T -> U
So:
Distribution<T> -> Distribution<U>
Example:
TokenId -> String
Use flat_map when each possible value produces another distribution:
T -> Distribution<U>
Without flattening, the result would be:
Distribution<Distribution<U>>
The simple distinction is:
map:
one possible value becomes one transformed value
flat_map:
one possible value becomes many possible future values
In language modeling:
map decodes possible tokens into text
flat_map chains uncertainty across another prediction step
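To make the flat_map case concrete, here is a minimal sketch. The flat_map method, the total_mass helper and the two-step example are hypothetical and not part of the crate. Flattening weights each inner probability by the outer one, P(t) * P(u | t), so the result stays one flat Distribution<U> rather than Distribution<Distribution<U>>. Duplicate outcomes are not merged in this sketch.

```rust
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Probability(pub f32);

#[derive(Debug, Clone, PartialEq)]
pub struct Distribution<T> {
    pub outcomes: Vec<(T, Probability)>,
}

impl<T> Distribution<T> {
    pub fn new(outcomes: Vec<(T, Probability)>) -> Self {
        Self { outcomes }
    }

    /// Chains a second uncertain step and flattens the result.
    pub fn flat_map<U, F>(self, mut f: F) -> Distribution<U>
    where
        F: FnMut(T) -> Distribution<U>,
    {
        let mut outcomes = Vec::new();
        for (value, Probability(p_outer)) in self.outcomes {
            for (next, Probability(p_inner)) in f(value).outcomes {
                // P(t) * P(u | t)
                outcomes.push((next, Probability(p_outer * p_inner)));
            }
        }
        Distribution { outcomes }
    }

    /// Sum of all probabilities; stays 1.0 when every step is normalized.
    pub fn total_mass(&self) -> f32 {
        self.outcomes.iter().map(|(_, p)| p.0).sum()
    }
}

/// Hypothetical two-step generation: pick a first word, then a continuation.
pub fn two_step_example() -> Distribution<&'static str> {
    let first = Distribution::new(vec![("Rust", Probability(0.5)), ("Python", Probability(0.5))]);
    first.flat_map(|word| match word {
        "Rust" => Distribution::new(vec![
            ("Rust is fast", Probability(0.75)),
            ("Rust.", Probability(0.25)),
        ]),
        _ => Distribution::new(vec![("Python.", Probability(1.0))]),
    })
}

fn main() {
    let result = two_step_example();
    assert_eq!(result.outcomes.len(), 3);
    assert_eq!(result.outcomes[0], ("Rust is fast", Probability(0.375)));
    assert_eq!(result.total_mass(), 1.0);
}
```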
Algebra Version
If:
D<T> = [(t1, p1), (t2, p2), ..., (tn, pn)]
and:
f : T -> U
then:
map(f, D<T>) = [(f(t1), p1), (f(t2), p2), ..., (f(tn), pn)]
The probabilities are untouched.
Only the values move through f.
Core Mental Model
In Rust terms:
consume self, move each T into f, preserve each Probability, collect into
Distribution<U>
In ML terms:
decode or transform every possible outcome without changing model confidence
In category-theory terms:
lift a deterministic morphism T -> U into the probabilistic context
Distribution<T> -> Distribution<U>
NaturalTransformation<A>
The problem this block solves is:
Sometimes you need to convert one wrapper shape into another without caring about the specific item type.
The beginner trap is to think a natural transformation is just any conversion. It is more disciplined than that. The conversion must be compatible with mapping. If you transform the wrapper first and then map the item, you should get the same result as mapping first and then transforming the wrapper.
The block:
/// A structure-preserving conversion between wrappers.
pub trait NaturalTransformation<A> {
type From;
type To;
fn transform(from: Self::From) -> Self::To;
}
Rust Syntax
The associated types are:
From
To
The method:
fn transform(from: Self::From) -> Self::To;
converts the wrapper.
The type parameter A represents the item type inside the wrapper.
This trait has the same shape as Functor: the abstract contract is public,
but each implementation chooses the concrete wrapper types through associated
types.
That is why the implementation can be generic over A without knowing whether
A is a token, a sentence, a loss value, or a trace step.
ML Concept
Data pipelines often convert shapes:
list of candidates -> optional selected candidate
batch -> first example
many diagnostics -> maybe first failure
The conversion should not depend on whether the item is a token, vector, or loss.
Category-Theory Concept
A natural transformation converts one functor into another in a way that is uniform over the inner type.
The important word is uniform.
It should not inspect special details of A.
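Uniformity is easiest to see through a counterexample. The hypothetical first_even below is not uniform: it only works for i32 because it inspects the values. The sketch runs the same two paths as a naturality check with f = x + 1, and first_even makes them disagree. The names first, first_even and paths_agree are illustrative and not part of the crate.

```rust
/// Uniform: never looks at the item, works for every A (like VecToFirstOption).
fn first<A>(xs: Vec<A>) -> Option<A> {
    xs.into_iter().next()
}

/// Not uniform: inspects the values, so it only makes sense for i32.
fn first_even(xs: Vec<i32>) -> Option<i32> {
    xs.into_iter().find(|x| x % 2 == 0)
}

/// Compares "map then convert" with "convert then map".
fn paths_agree(convert: fn(Vec<i32>) -> Option<i32>, xs: Vec<i32>, f: fn(i32) -> i32) -> bool {
    let map_then_convert = convert(xs.iter().copied().map(f).collect());
    let convert_then_map = convert(xs).map(f);
    map_then_convert == convert_then_map
}

fn main() {
    let add_one: fn(i32) -> i32 = |x| x + 1;
    // first: both paths give Some(2).
    assert!(paths_agree(first::<i32>, vec![1, 2, 3], add_one));
    // first_even: map first gives [2, 3, 4] -> Some(2),
    // but select first gives Some(2) -> Some(3).
    assert!(!paths_agree(first_even, vec![1, 2, 3], add_one));
}
```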
VecToFirstOption
The problem this block solves is:
Convert a list into an optional first item in a type-uniform way.
The block:
/// Natural transformation from `Vec<A>` to `Option<A>` by taking the first item.
pub struct VecToFirstOption;
impl<A> NaturalTransformation<A> for VecToFirstOption {
type From = Vec<A>;
type To = Option<A>;
fn transform(from: Vec<A>) -> Option<A> {
from.into_iter().next()
}
}
Rust Syntax
The implementation works for every A.
That is the role of:
impl<A> NaturalTransformation<A> for VecToFirstOption
The implementation does not ask for any bound on A. There is no A: Clone,
no A: Debug, and no A: PartialEq. It does not need those capabilities
because it never looks inside the item. It only changes the outer shape.
It consumes a vector:
from.into_iter()
then takes the first item:
.next()
If the vector is empty, the result is None.
If it has at least one item, the result is Some(first_item).
use category_theory_transformer_rs::{NaturalTransformation, VecToFirstOption};
let first = VecToFirstOption::transform(vec!["embed", "linear", "softmax"]);
let empty = VecToFirstOption::transform(Vec::<&str>::new());
assert_eq!(first, Some("embed"));
assert_eq!(empty, None);
ML Concept
This is like selecting the first candidate from a batch while preserving the possibility that the batch was empty.
It does not care what the candidate type is.
Category-Theory Concept
This is the example transformation:
Vec<A> -> Option<A>
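It is natural because it is uniform over A. It never inspects the items, so converting before or after mapping a function gives the same result.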
Naturality Check
The problem this block solves is:
Show that mapping first, then converting gives the same result as converting first, then mapping.
The function:
pub fn naturality_square_holds_for_first_option() -> bool {
let xs = vec![1, 2, 3];
let f = |x| x * 10;
let path_top_then_right = VecToFirstOption::transform(VecFunctor::fmap(xs.clone(), f));
let path_left_then_bottom = OptionFunctor::fmap(VecToFirstOption::transform(xs), f);
path_top_then_right == path_left_then_bottom
}
Rust Syntax
There are two paths.
Path one:
Vec<i32> -> Vec<i32> -> Option<i32>
Path two:
Vec<i32> -> Option<i32> -> Option<i32>
Both should produce the same value.
The crate exposes this as a small law check:
use category_theory_transformer_rs::naturality_square_holds_for_first_option;
assert!(naturality_square_holds_for_first_option());
ML Concept
This is a consistency check for pipeline shape conversions.
It says:
transform values, then select
matches:
select, then transform the selected value
for this operation.
Category-Theory Concept
This is the naturality square:
Vec<A>    ----fmap f---->  Vec<B>
   |                        |
   | transform              | transform
   v                        v
Option<A> ----fmap f---->  Option<B>
The square commutes when both paths agree.
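The crate's check uses one vector and one function. The same square can be run for any item types, including an empty vector and a function that changes the item type. The helper first_option_naturality_holds is hypothetical. It inlines the bodies of VecFunctor::fmap, OptionFunctor::fmap and VecToFirstOption::transform so the sketch compiles without the crate.

```rust
/// Same body as VecToFirstOption::transform.
fn first_option<A>(xs: Vec<A>) -> Option<A> {
    xs.into_iter().next()
}

/// Checks the naturality square for one input vector and one function f: A -> B.
fn first_option_naturality_holds<A, B, F>(xs: Vec<A>, f: F) -> bool
where
    A: Clone,
    B: PartialEq,
    F: Fn(A) -> B,
{
    // Top then right: map inside the Vec, then take the first item.
    let top_then_right = first_option(xs.clone().into_iter().map(&f).collect::<Vec<B>>());
    // Left then bottom: take the first item, then map inside the Option.
    let left_then_bottom = first_option(xs).map(&f);
    top_then_right == left_then_bottom
}

fn main() {
    assert!(first_option_naturality_holds(vec![1, 2, 3], |x: i32| x * 10));
    assert!(first_option_naturality_holds(Vec::<i32>::new(), |x: i32| x * 10));
    assert!(first_option_naturality_holds(vec![7, 8], |x: i32| x.to_string()));
}
```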
The same square as a data-flow diagram:
flowchart LR
VA["Vec<i32>"] -->|VecFunctor::fmap x10| VB["Vec<i32>"]
VA -->|VecToFirstOption::transform| OA["Option<i32>"]
VB -->|VecToFirstOption::transform| OB["Option<i32>"]
OA -->|OptionFunctor::fmap x10| OB
The same square as a rendered math view:
\[
\begin{array}{ccc}
\mathrm{Vec}\langle A\rangle & \xrightarrow{\mathrm{VecFunctor::fmap}(f)} & \mathrm{Vec}\langle B\rangle \\
\downarrow \mathrm{VecToFirstOption} & & \downarrow \mathrm{VecToFirstOption} \\
\mathrm{Option}\langle A\rangle & \xrightarrow{\mathrm{OptionFunctor::fmap}(f)} & \mathrm{Option}\langle B\rangle
\end{array}
\]
How to read this diagram:
- the top path maps inside the vector, then selects the first item,
- the left-bottom path selects the first item, then maps inside the option,
- the square commutes when both paths produce the same Option<B>,
- the Rust handle is naturality_square_holds_for_first_option.
If you had to redraw this by hand, that is a useful learning signal. Redrawing forces you to decide which objects sit at the corners and which arrows are responsible for each conversion.
Read it as two executable paths:
top then right:
Vec<i32> -> Vec<i32> -> Option<i32>
left then bottom:
Vec<i32> -> Option<i32> -> Option<i32>
The test passes only when both paths produce the same optional value.
Monoid
The problem this block solves is:
Some values can be combined repeatedly, and there should be an empty value that changes nothing.
The first-principles version is string concatenation. The empty string changes nothing, and grouping does not change the final text:
#![allow(unused)]
fn main() {
let empty = String::new();
let a = String::from("embed");
let b = String::from(" -> linear");
let c = String::from(" -> softmax");
assert_eq!(format!("{empty}{a}"), "embed");
assert_eq!(format!("{}{}", format!("{a}{b}"), c), format!("{a}{}", format!("{b}{c}")));
}
PipelineTrace uses the same idea, but with named pipeline steps instead of
raw text.
The trait:
pub trait Monoid: Sized {
fn empty() -> Self;
fn combine(&self, other: &Self) -> Self;
}
Rust Syntax
Sized means values of this type have a known size at compile time. Both methods return Self by value, so the trait requires Sized up front. The standard library makes the same choice for Clone and Default:
fn empty() -> Self;
Self means “the type implementing this trait.”
The trait requires:
empty() -> Self
combine(&self, other: &Self) -> Self
So a monoid can produce an identity value and combine two values into one.
The combine method borrows both values:
fn combine(&self, other: &Self) -> Self;
&self and &Self are references. They let the method read the two existing
values without taking ownership of them. The method returns a new combined
value.
ML Concept
Common monoid-like values in ML systems include logs, traces, metrics, batches, and accumulated gradients.
You often need to combine many small values into one larger value.
Category-Theory Concept
A monoid has:
identity element
associative binary operation
The laws are:
empty combine a = a
a combine empty = a
(a combine b) combine c = a combine (b combine c)
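The identity element is what makes a general fold possible: start from empty() and combine every value in order. This minimal sketch assumes a hypothetical StepCount type, which is not in the crate, and copies the crate's Monoid trait so the snippet stands alone. It uses integer counts on purpose. f32 addition is not exactly associative because of rounding, so floating-point sums such as accumulated gradients are better described as monoid-like.

```rust
/// Same trait as src/structure.rs.
pub trait Monoid: Sized {
    fn empty() -> Self;
    fn combine(&self, other: &Self) -> Self;
}

/// Hypothetical monoid: number of pipeline steps executed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StepCount(u64);

impl Monoid for StepCount {
    fn empty() -> Self {
        StepCount(0)
    }
    fn combine(&self, other: &Self) -> Self {
        StepCount(self.0 + other.0)
    }
}

/// Folds a slice of monoid values; an empty slice returns `M::empty()`.
pub fn concat_all<M: Monoid>(values: &[M]) -> M {
    values.iter().fold(M::empty(), |acc, value| acc.combine(value))
}

fn main() {
    let counts = [StepCount(2), StepCount(3), StepCount(4)];
    assert_eq!(concat_all(&counts), StepCount(9));
    assert_eq!(concat_all::<StepCount>(&[]), StepCount(0));
}
```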
TraceStep and PipelineTrace
The problem this block solves is:
Pipeline execution steps should be combinable into a larger trace.
The key types:
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TraceStep(&'static str);
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PipelineTrace(Vec<TraceStep>);
Rust Syntax
TraceStep wraps one static string.
PipelineTrace wraps a vector of steps.
PipelineTrace::from_steps collects any iterable of steps.
names() returns the raw names for display:
self.0.iter().map(TraceStep::name).collect()
These wrappers matter because two values may both be strings but mean different
things. A TraceStep is not a model name, a token, or a user-facing sentence.
The type keeps that meaning attached to the value.
ML Concept
A trace can record:
embedding
linear
softmax
cross_entropy
This is useful for understanding which stages ran.
Category-Theory Concept
A pipeline trace is a sequence-like monoid.
The empty trace is the identity.
Combining traces is concatenation.
PipelineTrace as Monoid
The problem this block solves is:
Make traces obey the monoid interface.
The implementation:
impl Monoid for PipelineTrace {
fn empty() -> Self {
PipelineTrace(vec![])
}
fn combine(&self, other: &Self) -> Self {
let mut combined = self.0.clone();
combined.extend_from_slice(&other.0);
PipelineTrace(combined)
}
}
Rust Syntax
empty returns an empty vector.
combine clones the first trace, appends the second trace, and wraps the result
again as PipelineTrace.
The clone is intentional here: combine borrows both inputs, so it cannot move
steps out of either trace. Cloning the small TraceStep values lets the method
produce a fresh trace while leaving the inputs usable.
use category_theory_transformer_rs::{Monoid, PipelineTrace, TraceStep};
let encoder = PipelineTrace::from_steps([TraceStep::new("embedding")]);
let head = PipelineTrace::from_steps([TraceStep::new("softmax")]);
let trace = encoder.combine(&head);
assert_eq!(trace.names(), vec!["embedding", "softmax"]);
ML Concept
This is how many execution logs work:
trace_a + trace_b = longer trace
Category-Theory Concept
This is the list monoid specialized to trace steps.
Monoid Law Check
The problem this block solves is:
Verify the identity and associativity laws for the trace type.
The function:
pub fn monoid_laws_hold_for_pipeline_trace() -> bool {
let a = PipelineTrace::from_steps(vec![TraceStep::new("embedding")]);
let b = PipelineTrace::from_steps(vec![TraceStep::new("linear")]);
let c = PipelineTrace::from_steps(vec![TraceStep::new("softmax")]);
let identity = PipelineTrace::empty();
let left_identity = identity.combine(&a) == a;
let right_identity = a.combine(&identity) == a;
let associativity = a.combine(&b).combine(&c) == a.combine(&b.combine(&c));
left_identity && right_identity && associativity
}
Rust Syntax
The function constructs three traces and checks three booleans.
It returns true only if all monoid laws hold for those examples.
use category_theory_transformer_rs::monoid_laws_hold_for_pipeline_trace;
assert!(monoid_laws_hold_for_pipeline_trace());
ML Concept
Grouping trace combination should not change the final trace.
This matters when systems combine logs from nested pipelines.
Category-Theory Concept
This directly checks the monoid laws:
identity
associativity
The associativity check can be read as a grouping diagram:
flowchart LR
A["embedding"] --> AB["embedding + linear"]
B["linear"] --> AB
AB --> ABC1["(embedding + linear) + softmax"]
C["softmax"] --> ABC1
B --> BC["linear + softmax"]
C --> BC
A --> ABC2["embedding + (linear + softmax)"]
BC --> ABC2
The same law as a rendered math view:
\[
(\mathrm{embedding} \diamond \mathrm{linear}) \diamond \mathrm{softmax} = \mathrm{embedding} \diamond (\mathrm{linear} \diamond \mathrm{softmax})
\]
Here \(\diamond\) means PipelineTrace::combine. The equality is not about
string formatting. It says the final trace meaning should not depend on where
the parentheses were placed.
How to read this diagram:
- the objects are trace values,
- the arrow-like operation is combine,
- the identity object is PipelineTrace::empty(),
- the Rust handle is monoid_laws_hold_for_pipeline_trace.
The two final traces should contain the same step names in the same order. The law is not about performance or formatting. It says grouping nested trace combinations should not change what the trace means.
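A law check is useful only if it can fail. The sketch below deliberately breaks one. StampedTrace is hypothetical, and its empty() records a visible start marker. Associativity still holds because concatenation is associative. Both identity laws fail because combining with empty() adds a visible step.

```rust
#[derive(Debug, Clone, PartialEq)]
pub struct StampedTrace(Vec<&'static str>);

impl StampedTrace {
    /// Broken "identity": it is not empty, it carries a visible marker.
    pub fn empty() -> Self {
        StampedTrace(vec!["start"])
    }

    pub fn of(name: &'static str) -> Self {
        StampedTrace(vec![name])
    }

    /// Plain concatenation, which is associative.
    pub fn combine(&self, other: &Self) -> Self {
        let mut steps = self.0.clone();
        steps.extend_from_slice(&other.0);
        StampedTrace(steps)
    }
}

/// Returns (left_identity, right_identity, associativity).
pub fn check_laws() -> (bool, bool, bool) {
    let a = StampedTrace::of("embedding");
    let b = StampedTrace::of("linear");
    let c = StampedTrace::of("softmax");
    let e = StampedTrace::empty();
    (
        e.combine(&a) == a,
        a.combine(&e) == a,
        a.combine(&b).combine(&c) == a.combine(&b.combine(&c)),
    )
}
```

Here check_laws() should report false for both identity laws and true for associativity. Passing associativity alone is not enough to call the type a monoid.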
The Calculus File
The problem src/calculus.rs solves is:
Backpropagation is easier to understand if you first see one local derivative rule in isolation.
What to notice: Backpropagation does not require every operation to know the whole model. Each operation only needs to know how its own output changes when its own inputs change. Composition does the rest.
The file defines:
Scalar
LocalGradient
MulOp
Scalar
The block:
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Scalar(f32);
Rust Syntax
Scalar wraps one f32.
Scalar::new rejects non-finite values.
value() returns the raw float.
This repeats a pattern from the earlier domain-object chapter:
private field
validating constructor
accessor
The raw f32 is private, so callers cannot construct Scalar(f32::NAN)
directly. They must use Scalar::new, which returns a Result.
use category_theory_transformer_rs::Scalar;
let scalar = Scalar::new(2.5)?;
assert_eq!(scalar.value(), 2.5);
Ok::<(), category_theory_transformer_rs::CtError>(())
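The private-field pattern is easiest to see with a rejected value. The sketch below is minimal and self-contained. It uses a hypothetical NumericError type in place of the crate's CtError, whose variants are not shown in this chapter.

```rust
/// Hypothetical error type standing in for the crate's `CtError`.
#[derive(Debug, Clone, PartialEq)]
pub enum NumericError {
    /// The value was NaN or infinite.
    NonFinite(f32),
}

/// Private field: code outside this module must go through `Scalar::new`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Scalar(f32);

impl Scalar {
    pub fn new(value: f32) -> Result<Self, NumericError> {
        if value.is_finite() {
            Ok(Scalar(value))
        } else {
            Err(NumericError::NonFinite(value))
        }
    }

    pub fn value(&self) -> f32 {
        self.0
    }
}
```

With this shape, Scalar::new(2.5) succeeds, while Scalar::new(f32::NAN) and Scalar::new(f32::INFINITY) return errors before any arithmetic can use them.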
ML Concept
This is a single numeric value in a computation graph.
Examples:
activation
loss component
weight
Category-Theory Concept
It is the simple numeric object used in the local derivative example.
LocalGradient
The block:
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct LocalGradient(f32);
Rust Syntax
This is another f32 wrapper.
It has the same finite-value validation as Scalar.
The semantic difference is important:
Scalar = forward value
LocalGradient = derivative signal
This is the same newtype move used throughout the crate. Both wrappers store
an f32, but the types prevent accidental mixing at function boundaries.
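The sketch below shows what "prevent accidental mixing" means at a function boundary. apply_gradient_step is a hypothetical gradient-descent update that takes a weight as a Scalar and a derivative as a LocalGradient. Passing two Scalars, or swapping the arguments, becomes a compile error instead of a silent bug. The wrappers here skip validation and expose their fields only to keep the sketch short.

```rust
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Scalar(pub f32);

#[derive(Debug, Clone, Copy, PartialEq)]
pub struct LocalGradient(pub f32);

/// Hypothetical gradient-descent step: w_new = w - learning_rate * dL/dw.
pub fn apply_gradient_step(weight: Scalar, gradient: LocalGradient, learning_rate: f32) -> Scalar {
    Scalar(weight.0 - learning_rate * gradient.0)
}

// The following line would not compile, because the argument types are swapped:
// apply_gradient_step(LocalGradient(0.5), Scalar(1.0), 0.25);
```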
ML Concept
A gradient tells how a loss changes when an intermediate value changes.
For example:
dL/dz
Category-Theory Concept
This is information flowing backward through a composed computation.
MulOp
The problem this block solves is:
Show a forward operation and its local backward rule.
The important methods:
pub fn forward(&self, x: Scalar, y: Scalar) -> CtResult<Scalar> {
Scalar::new(x.value() * y.value())
}
pub fn backward(
&self,
x: Scalar,
y: Scalar,
upstream: LocalGradient,
) -> CtResult<(LocalGradient, LocalGradient)> {
let dz_dx = y.value();
let dz_dy = x.value();
Ok((
LocalGradient::new(upstream.value() * dz_dx)?,
LocalGradient::new(upstream.value() * dz_dy)?,
))
}
Rust Syntax
forward multiplies two scalars and validates the result.
backward takes:
x
y
upstream gradient dL/dz
and returns:
(dL/dx, dL/dy)
The return type is a Rust tuple.
The ? operator appears twice in backward:
LocalGradient::new(upstream.value() * dz_dx)?
It means:
if construction succeeded, keep the value
if construction failed, return the error from backward immediately
That keeps invalid numeric states at the boundary where they are created.
use category_theory_transformer_rs::{LocalGradient, MulOp, Scalar};
let mul = MulOp;
let x = Scalar::new(2.0)?;
let y = Scalar::new(3.0)?;
let z = mul.forward(x, y)?;
let (dl_dx, dl_dy) = mul.backward(x, y, LocalGradient::new(1.0)?)?;
assert_eq!(z.value(), 6.0);
assert_eq!(dl_dx.value(), 3.0);
assert_eq!(dl_dy.value(), 2.0);
Ok::<(), category_theory_transformer_rs::CtError>(())
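The early return matters because a finite upstream gradient times a finite local derivative can still overflow f32 to infinity. The sketch below reproduces the backward shape with stand-in Scalar and LocalGradient types. It uses a String error in place of CtError, and finite and SketchResult are illustrative helpers. The ? turns the overflow into an Err at the point where the bad value would be created.

```rust
type SketchResult<T> = Result<T, String>;

fn finite(value: f32, what: &str) -> SketchResult<f32> {
    if value.is_finite() {
        Ok(value)
    } else {
        Err(format!("non-finite {what}: {value}"))
    }
}

#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Scalar(f32);

#[derive(Debug, Clone, Copy, PartialEq)]
pub struct LocalGradient(f32);

impl Scalar {
    pub fn new(v: f32) -> SketchResult<Self> { finite(v, "scalar").map(Scalar) }
    pub fn value(&self) -> f32 { self.0 }
}

impl LocalGradient {
    pub fn new(v: f32) -> SketchResult<Self> { finite(v, "gradient").map(LocalGradient) }
    pub fn value(&self) -> f32 { self.0 }
}

pub struct MulOp;

impl MulOp {
    pub fn backward(
        &self,
        x: Scalar,
        y: Scalar,
        upstream: LocalGradient,
    ) -> SketchResult<(LocalGradient, LocalGradient)> {
        // Each `?` returns immediately if the product is not finite.
        Ok((
            LocalGradient::new(upstream.value() * y.value())?,
            LocalGradient::new(upstream.value() * x.value())?,
        ))
    }
}
```

With x = y = f32::MAX and an upstream gradient of 2.0, the first product overflows and backward returns an error instead of a gradient containing infinity.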
ML Concept
For:
z = x * y
the local derivatives are:
dz/dx = y
dz/dy = x
By the chain rule:
dL/dx = dL/dz * dz/dx = dL/dz * y
dL/dy = dL/dz * dz/dy = dL/dz * x
The companion tests use two upstream gradients. With dL/dz = 1, the gradients
are 3 and 2. With dL/dz = 4, the same local rule scales them to 12 and
8.
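The chain-rule numbers can also be checked numerically. The sketch below assumes the loss is L = 4 * z, so that dL/dz = 4. This is an illustrative choice, not the crate's loss. The sketch compares central finite differences with the analytic gradients dL/dx = 4 * y = 12 and dL/dy = 4 * x = 8.

```rust
/// Illustrative loss: L(x, y) = 4 * (x * y), so dL/dz = 4.
fn loss(x: f64, y: f64) -> f64 {
    4.0 * (x * y)
}

/// Central-difference approximations of dL/dx and dL/dy.
pub fn numeric_grads(x: f64, y: f64, h: f64) -> (f64, f64) {
    let dl_dx = (loss(x + h, y) - loss(x - h, y)) / (2.0 * h);
    let dl_dy = (loss(x, y + h) - loss(x, y - h)) / (2.0 * h);
    (dl_dx, dl_dy)
}

/// Analytic chain rule: dL/dx = dL/dz * y, dL/dy = dL/dz * x.
pub fn analytic_grads(x: f64, y: f64, upstream: f64) -> (f64, f64) {
    (upstream * y, upstream * x)
}
```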
Category-Theory Concept
The chain rule is composition of local derivative maps.
A big neural network is many small maps composed forward, then many local gradient rules composed backward.
The local multiplication example as a rendered math view:
\[
\begin{array}{ccccc}
x, y & \xrightarrow{\mathrm{MulOp::forward}} & z = x \cdot y & \xrightarrow{\mathrm{loss}} & L \\
 & & \uparrow \mathrm{d}L/\mathrm{d}z & & \\
\mathrm{d}L/\mathrm{d}x = (\mathrm{d}L/\mathrm{d}z)\,y & & \mathrm{d}L/\mathrm{d}y = (\mathrm{d}L/\mathrm{d}z)\,x & &
\end{array}
\]
How to read this diagram:
- the top row is the forward computation,
- the bottom row names the local backward results,
- the upstream gradient dL/dz is the signal being carried backward,
- the Rust handle is MulOp::backward.
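The sketch below composes two local rules backward instead of one. It assumes plain f64 values and a hypothetical two-multiplication graph, z = x * y followed by w = z * u. Each backward call knows only its own inputs. Chaining the calls in reverse order gives the gradient of w with respect to x, y, and u.

```rust
/// Local rule for out = a * b: returns (dL/da, dL/db) given dL/dout.
fn mul_backward(a: f64, b: f64, upstream: f64) -> (f64, f64) {
    (upstream * b, upstream * a)
}

/// Forward, then backward, through w = (x * y) * u, seeded with dL/dw = upstream.
/// Returns (w, dL/dx, dL/dy, dL/du).
pub fn two_step_chain(x: f64, y: f64, u: f64, upstream: f64) -> (f64, f64, f64, f64) {
    // Forward pass: keep the intermediate z, because the second rule needs it.
    let z = x * y;
    let w = z * u;
    // Backward pass, in reverse order of the forward pass.
    let (dl_dz, dl_du) = mul_backward(z, u, upstream);
    let (dl_dx, dl_dy) = mul_backward(x, y, dl_dz);
    (w, dl_dx, dl_dy, dl_du)
}
```

For x = 2, y = 3, u = 5, and dL/dw = 1, the results are w = 30, dL/dx = y * u = 15, dL/dy = x * u = 10, and dL/du = x * y = 6. Each value matches differentiating w = x * y * u directly.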
Production Autograd Boundary
Production frameworks do not ask the user to manually call one backward
method for every operation. PyTorch’s autograd documentation describes a
reverse automatic-differentiation system: the forward pass records the
operations that produced tensors, and the backward pass traces that recorded
graph with the chain rule.
The tiny Rust example keeps only one local rule:
MulOp::forward : Scalar x Scalar -> Scalar
MulOp::backward : Scalar x Scalar x LocalGradient -> (LocalGradient, LocalGradient)
That is not a replacement for autograd. It is a microscope for one boundary. It answers three questions before the full graph machinery appears:
which forward value must be remembered?
which local derivative is used?
which upstream gradient is being carried backward?
| Production autograd responsibility | Tiny Rust teaching boundary |
|---|---|
| record a dynamic graph during the forward pass | inspect one explicit MulOp::forward call |
| save intermediate tensors needed for backward rules | pass x and y back into MulOp::backward |
| traverse the graph backward using the chain rule | compute dL/dx and dL/dy from dL/dz |
| control when operations are tracked | make the local derivative boundary explicit in the type signature |
When you return to a framework, the useful question is not “where did the chain rule go?” It is:
which graph edges and saved values let the framework compose the same local
rules automatically?
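One possible answer fits in a small tape. The sketch below assumes a hypothetical Tape type for multiplication only. It is not PyTorch's implementation or API. The forward pass appends each product to the tape with the indices of its inputs, which serve as the saved values and graph edges. backward then walks the tape in reverse and applies the same local rule as MulOp::backward at every node. Gradients are summed when a value feeds more than one operation.

```rust
/// One recorded node: an input leaf, or the product of two earlier nodes.
#[derive(Debug, Clone, Copy)]
enum Node {
    Leaf,
    Mul(usize, usize),
}

/// Hypothetical reverse-mode tape for multiplication only.
#[derive(Debug, Default)]
pub struct Tape {
    values: Vec<f64>,
    nodes: Vec<Node>,
}

impl Tape {
    pub fn leaf(&mut self, value: f64) -> usize {
        self.values.push(value);
        self.nodes.push(Node::Leaf);
        self.values.len() - 1
    }

    /// Forward pass: compute the product and record the graph edge.
    pub fn mul(&mut self, a: usize, b: usize) -> usize {
        let product = self.values[a] * self.values[b];
        self.values.push(product);
        self.nodes.push(Node::Mul(a, b));
        self.values.len() - 1
    }

    pub fn value(&self, id: usize) -> f64 {
        self.values[id]
    }

    /// Backward pass: seed d(output)/d(output) = 1, traverse the tape in reverse.
    pub fn backward(&self, output: usize) -> Vec<f64> {
        let mut grads = vec![0.0; self.values.len()];
        grads[output] = 1.0;
        for id in (0..=output).rev() {
            if let Node::Mul(a, b) = self.nodes[id] {
                let upstream = grads[id];
                // Same local rule as MulOp::backward, accumulated with +=.
                grads[a] += upstream * self.values[b];
                grads[b] += upstream * self.values[a];
            }
        }
        grads
    }
}
```

For example, recording z = x * y and then w = z * x with x = 2 and y = 3 gives w = 12. backward(w) should return dw/dx = 2 * x * y = 12 and dw/dy = x * x = 4. The x entry collects contributions from both places where x is used.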
Run The Example
Source snapshot: examples/04_structure_and_calculus.rs
use category_theory_transformer_rs::{
CtResult, Functor, LocalGradient, Monoid, MulOp, OptionFunctor, PipelineTrace, Scalar,
TraceStep, VecFunctor, monoid_laws_hold_for_pipeline_trace,
naturality_square_holds_for_first_option,
};
fn main() -> CtResult<()> {
let squared = VecFunctor::fmap(vec![1, 2, 3], |x| x * x);
let shifted = OptionFunctor::fmap(Some(7), |x| x + 1);
println!("Vec fmap square: {squared:?}");
println!("Option fmap +1: {shifted:?}");
println!(
"naturality square holds: {}",
naturality_square_holds_for_first_option()
);
let trace = PipelineTrace::from_steps(vec![TraceStep::new("embedding")])
.combine(&PipelineTrace::from_steps(vec![TraceStep::new("linear")]))
.combine(&PipelineTrace::from_steps(vec![TraceStep::new("softmax")]));
println!("trace: {:?}", trace.names());
println!(
"monoid laws hold: {}",
monoid_laws_hold_for_pipeline_trace()
);
let mul = MulOp;
let x = Scalar::new(2.0)?;
let y = Scalar::new(3.0)?;
let upstream = LocalGradient::new(1.0)?;
let (dl_dx, dl_dy) = mul.backward(x, y, upstream)?;
println!("dL/dx: {}", dl_dx.value());
println!("dL/dy: {}", dl_dy.value());
println!();
println!("Typed transformation:");
println!("VecFunctor::fmap : Vec<A> x (A -> B) -> Vec<B>");
println!("OptionFunctor::fmap : Option<A> x (A -> B) -> Option<B>");
println!("Naturality square:");
println!("Vec<A> -> Vec<B> -> Option<B>");
println!("Vec<A> -> Option<A> -> Option<B>");
println!("Monoid:");
println!("PipelineTrace x PipelineTrace -> PipelineTrace");
println!("Chain rule:");
println!("Scalar x Scalar -> Scalar");
println!("dL/dz -> (dL/dx, dL/dy)");
Ok(())
}
Run:
cargo run --example 04_structure_and_calculus
You should see mapping over Vec, mapping over Option, a naturality check, a
combined trace, a monoid law check, and local gradients for multiplication.
The example also prints the typed boundaries behind those values:
Typed transformation:
VecFunctor::fmap : Vec<A> x (A -> B) -> Vec<B>
OptionFunctor::fmap : Option<A> x (A -> B) -> Option<B>
Naturality square:
Vec<A> -> Vec<B> -> Option<B>
Vec<A> -> Option<A> -> Option<B>
Monoid:
PipelineTrace x PipelineTrace -> PipelineTrace
Chain rule:
Scalar x Scalar -> Scalar
dL/dz -> (dL/dx, dL/dy)
Example Output Transfer Checklist
This example is a compact law lab. Each printed line should make one consistency condition visible.
Use the output this way:
| Example output | Boundary to own | Shortcut to reject |
|---|---|---|
| Vec fmap square: [1, 4, 9] | a function is mapped over each item while vector order and length stay meaningful | treating fmap as an arbitrary rewrite of the wrapper |
| Option fmap +1: Some(8) | a present value can be transformed without changing the optional context | inventing a value when the option is None |
| naturality square holds: true | both paths from Vec<A> to Option<B> agree | calling any wrapper conversion natural without checking mapping compatibility |
| trace: ["embedding", "linear", "softmax"] | traces combine into another trace | mixing raw strings with typed trace steps at the boundary |
| monoid laws hold: true | empty trace and grouping do not change the trace meaning | claiming a combine operation is monoidal when empty adds a visible step |
| dL/dx: 3 and dL/dy: 2 | one local derivative rule sends the upstream gradient backward | treating backpropagation as one giant derivative with no local rules |
| VecFunctor::fmap : Vec<A> x (A -> B) -> Vec<B> | the item function is lifted into the vector context | confusing item-level A -> B with wrapper-level Vec<A> -> Vec<B> |
| Vec<A> -> Vec<B> -> Option<B> and Vec<A> -> Option<A> -> Option<B> | the naturality square has two executable paths | proving only one side of a square |
| PipelineTrace x PipelineTrace -> PipelineTrace | the combine operation stays inside the same trace type | returning raw lists, strings, or unrelated diagnostics |
| dL/dz -> (dL/dx, dL/dy) | upstream gradient is distributed through local partial derivatives | dropping the upstream gradient when computing local gradients |
The four pattern names are useful only if they protect these boundaries. A functor protects wrapper-preserving mapping. A natural transformation protects agreement between two paths. A monoid protects repeated combination. The chain rule protects local-to-global gradient composition.
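The functor and naturality rows above point to handles defined earlier in the book: VecFunctor, OptionFunctor, and naturality_square_holds_for_first_option. The sketch below makes those boundaries concrete in one place. It uses plain generic functions instead of the crate's Functor trait, whose exact signature may differ. The names vec_fmap, option_fmap, first, and square_commutes are illustrative.

```rust
/// Lift an item function into the Vec context: order and length are kept.
pub fn vec_fmap<A, B>(values: Vec<A>, f: impl Fn(A) -> B) -> Vec<B> {
    values.into_iter().map(f).collect()
}

/// Lift an item function into the Option context: None stays None.
pub fn option_fmap<A, B>(value: Option<A>, f: impl Fn(A) -> B) -> Option<B> {
    value.map(f)
}

/// Vec<A> -> Option<A>: keep the first item, if there is one. Works for any A.
pub fn first<A>(values: Vec<A>) -> Option<A> {
    values.into_iter().next()
}

/// Compares the two paths around the naturality square.
pub fn square_commutes<A: Clone, B: PartialEq, F: Fn(A) -> B>(xs: Vec<A>, f: F) -> bool {
    let map_then_first = first(vec_fmap(xs.clone(), &f));
    let first_then_map = option_fmap(first(xs), &f);
    map_then_first == first_then_map
}
```

For instance, vec_fmap(vec![1, 2, 3], |x| x * x) should give [1, 4, 9], and option_fmap(Some(7), |x| x + 1) should give Some(8). square_commutes returns true for both a non-empty and an empty vector. As with the crate's test, each run checks one concrete square, not every natural transformation.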
Output-To-Law Audit
When the example prints a result, do not stop at “it worked.” Turn the output line into a law-shaped claim and then narrow the claim back to the local Rust evidence.
Use this audit card:
output line:
Rust handle:
law or boundary:
source support:
safe non-claim:
validation command:
Worked audit:
output line: naturality square holds: true
Rust handle: naturality_square_holds_for_first_option
law or boundary: mapping before first-or-none matches first-or-none before mapping
source support: formal naturality vocabulary; programming-shaped wrapper conversion
safe non-claim: this checks one concrete square, not every natural transformation
validation command: cargo test structure::tests::naturality_square_commutes --lib
Second worked audit:
output line: dL/dx: 3 and dL/dy: 2
Rust handle: MulOp::backward
law or boundary: upstream gradient is multiplied by the local derivatives of x * y
source support: chain rule and reverse traversal through a computation graph
safe non-claim: this is one local derivative rule, not a production autograd engine
validation command: cargo test calculus::tests::multiply_backward_returns_local_chain_rule_gradients --lib
The pattern is the same as the rest of the book:
visible output -> Rust handle -> law-shaped claim -> source-backed limit
That last step matters. A passing law-shaped test is good evidence for the teaching example. It is not permission to claim the repository proves all functor laws, all naturality squares, every monoid, or every differentiable program.
Core Mental Model
In Rust terms:
traits name reusable operation shapes
unit structs demonstrate stateless operations
tests check laws
In ML terms:
map over structures, combine traces, compose local gradients
In category-theory terms:
functors preserve structure
natural transformations commute with mapping
monoids combine associatively with identity
chain rule composes derivative information
Checkpoint
Why is “local rule plus composition” the core idea behind backpropagation?
A strong answer:
Each operation only needs its local derivative; the chain rule composes those local derivatives into the gradient for the whole computation.
Where This Leaves Us
This chapter named the repeated structures that sit underneath the small ML pipeline. A functor explains mapping inside a wrapper. A natural transformation explains changing wrappers consistently. A monoid explains safe accumulation. A local gradient explains why a large training computation can be assembled from small derivative rules.
The next chapter, Seven Sketches Through Rust, uses the same engineering habit on a wider set of ideas from applied category theory. Instead of adding a larger ML model, it shows how the same typed-Rust style can model orders, resources, database instances, design relations, signal flow, circuits, and local-to-global behavior.
Further Reading
Do not read these links as a bibliography to admire. Read them as a transfer path from the small laws in this chapter to larger systems.
Start from the local Rust evidence:
VecFunctor::fmap : Vec<A> x (A -> B) -> Vec<B>
naturality_square_holds_for_first_option() -> bool
PipelineTrace x PipelineTrace -> PipelineTrace
MulOp::backward : Scalar x Scalar x LocalGradient -> (LocalGradient, LocalGradient)
Then read the sources in this order:
| Source | What to transfer back into this chapter | Local evidence to inspect |
|---|---|---|
| Categories for the Working Mathematician | Functor, natural transformation, and monoid are formal structures with laws; the local tests are examples, not universal proofs. | Functor<A, B>, naturality_square_commutes, pipeline_trace_obeys_monoid_laws |
| Category Theory for Programming | Functor and monoid names are useful only when they point to operation shapes and laws. | Functor<A, B>, VecFunctor, OptionFunctor, monoid_laws_hold_for_pipeline_trace |
| Seven Sketches | Applied category theory can begin with concrete examples before general formalism. | naturality_square_holds_for_first_option, PipelineTrace |
| D2L Backpropagation and Computational Graphs | Backpropagation reverses the forward dependency path and uses the chain rule over stored intermediates. | MulOp::forward, MulOp::backward |
| The Matrix Calculus You Need For Deep Learning | Matrix-calculus notation is support for understanding local derivative rules, not a prerequisite for running the tiny example. | LocalGradient, Scalar |
| Automatic differentiation in machine learning: a survey | Automatic differentiation is a general program-derivative technique; the local example is one visible derivative boundary. | MulOp::backward, LocalGradient |
| PyTorch Autograd Mechanics | Production systems record a graph and traverse it backward; this chapter keeps one local backward rule visible. | multiply_backward_returns_local_chain_rule_gradients, multiply_backward_scales_with_upstream_gradient |
| Backprop as Functor | Backpropagation can be studied compositionally under stated assumptions. | Advanced context only; do not promote this chapter’s tests into a proof of the paper’s theorem. |
Use Glossary when a word becomes slippery. Use References when you want the full source list.
After reading one external source, ask four questions:
- Which exact Rust type or function did it clarify?
- Which law, local derivative, or path agreement did it support?
- Which claim did it not license this chapter to make?
- Which command would you run to inspect the local evidence?
For this chapter, the commands are:
cargo run --example 04_structure_and_calculus
cargo test structure::tests --lib
cargo test calculus::tests --lib
If you can answer those questions, the external sources have transferred back into the code.
Practice After This Chapter
Use Exercise 14 to trace the naturality square and monoid laws back to the exact tests. This is the chapter’s main transfer check: a term such as “natural transformation” or “monoid” should point to a runnable law check, not only to a definition.
Retrieval Practice
Recall
Recover the four reusable structures before using their names.
First, state what a functor does to values inside a wrapper.
Then name the two paths that must agree in the Vec<A> -> Option<A>
naturality square.
Next, name the two laws that make PipelineTrace a monoid-like trace type in
this chapter.
Finally, for MulOp::backward, name the local derivatives used for
z = x * y.
Explain
Separate the law, the test, and the analogy.
Explain why OptionFunctor::fmap(None::<i32>, |value| value * 10) is expected
to return None.
Explain why VecToFirstOption::transform does not need a bound such as
A: Clone or A: Debug.
Explain why a pipeline trace is a good example of a monoid.
Then explain why the functor, naturality, and monoid tests are executable anchors rather than full mathematical proofs.
Apply
Use the runnable example and the law checks.
- For xs = vec![1, 2, 3] and f = |x| x * 10, compute both naturality paths:
  Vec<i32> --fmap f--> Vec<i32> --first--> Option<i32>
  Vec<i32> --first--> Option<i32> --fmap f--> Option<i32>
- For trace steps embedding, linear, and softmax, write both groupings checked by associativity:
  (embedding <> linear) <> softmax
  embedding <> (linear <> softmax)
  What list of names should both produce?
- For x = 2, y = 3, and upstream gradient dL/dz = 4, what should MulOp::backward return for dL/dx and dL/dy?
- Write a small example of a value in your own codebase that has an empty value and a combine operation. State the identity law it should satisfy.
Debug
For each broken explanation, name the missing law or wrong boundary:
mapping a Vec and changing its length without saying why
transforming Vec<A> to Option<A> by inspecting special details of A
combining traces where empty <> trace adds a visible step
claiming backpropagation is one giant derivative instead of composed local rules
A strong answer should point back to the concrete Rust checks:
VecFunctor identity and composition tests
naturality_square_commutes
pipeline_trace_obeys_monoid_laws
multiply_backward_scales_with_upstream_gradient
The goal is not to recite category-theory vocabulary. The goal is to recognize which consistency condition the code is protecting.
Seven Sketches Through Rust
The problem this chapter solves is:
Applied category theory can feel too large to connect to code. This chapter turns the seven major themes of Seven Sketches in Compositionality into small Rust blocks with tests.
The previous chapter introduced reusable structures such as functors, natural transformations, monoids, and local derivative rules. This chapter keeps the same reading style but changes scale. Instead of following the tiny ML pipeline, it takes seven broader ideas and asks how each one can be made concrete enough to inspect in Rust.
This chapter does not reproduce the book. It gives executable handles for the ideas.
The repeated pattern is:
mathematical structure
-> Rust type
-> constructor or method
-> law check
The Rust lesson is:
newtypes + private fields + validation + explicit composition
The category-theory lesson is:
objects + relationships + composition + laws
Reader orientation: Treat each sketch as a small modeling exercise. You are not expected to know the full mathematical theory before reading the Rust. Start from the type, then read the constructor, then read the law check.
Chapter Outcomes
By the end of this chapter, you should be able to:
- name the one law, relation, or boundary each Rust sketch preserves,
- explain which part of the source text each sketch deliberately does not implement,
- transfer one sketch to a small software or ML design problem without overclaiming the mathematics.
What You Already Know
If you have modeled business rules, resource limits, database relationships, or state machines in Rust, you already know the software version of this chapter. The new idea is that applied category theory gives names to those compositional structures and asks which laws make them trustworthy.
Source Snapshots
The main module:
Source snapshot: src/sketches.rs
//! Rust models for the seven applied-category-theory sketches.
//!
//! This module is a study companion for *Seven Sketches in Compositionality*.
//! It does not try to encode all of category theory. Instead, each section
//! gives one small Rust model for the main structure of a sketch: orders,
//! resources, databases, co-design, signal flow, circuits, and logic of
//! behavior.
use crate::error::{CtError, CtResult};
/// A finite order used for the first sketch: observations can be refined into
/// features, scores, and decisions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum InformationLevel {
Observation,
Feature,
Score,
Decision,
}
impl InformationLevel {
/// Returns true when information at this level can flow into `target`.
pub fn can_flow_to(self, target: Self) -> bool {
self <= target
}
/// Least upper bound in this small total order.
pub fn join(self, other: Self) -> Self {
self.max(other)
}
}
const INFORMATION_LEVELS: [InformationLevel; 4] = [
InformationLevel::Observation,
InformationLevel::Feature,
InformationLevel::Score,
InformationLevel::Decision,
];
/// Checks reflexivity and transitivity for the finite information order.
pub fn information_order_obeys_preorder_laws() -> bool {
for level in INFORMATION_LEVELS {
if !level.can_flow_to(level) {
return false;
}
}
for first in INFORMATION_LEVELS {
for second in INFORMATION_LEVELS {
for third in INFORMATION_LEVELS {
let premise = first.can_flow_to(second) && second.can_flow_to(third);
if premise && !first.can_flow_to(third) {
return false;
}
}
}
}
true
}
/// Number of concrete features in a tiny model.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct FeatureCount(usize);
impl FeatureCount {
pub fn new(value: usize) -> CtResult<Self> {
if value == 0 {
return Err(CtError::EmptyInput("feature count"));
}
Ok(Self(value))
}
pub fn value(&self) -> usize {
self.0
}
}
/// Number of abstract layers used to summarize feature capacity.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct LayerBudget(usize);
impl LayerBudget {
pub fn new(value: usize) -> CtResult<Self> {
if value == 0 {
return Err(CtError::EmptyInput("layer budget"));
}
Ok(Self(value))
}
pub fn value(&self) -> usize {
self.0
}
}
const FEATURES_PER_LAYER: usize = 4;
/// Abstracts concrete features to the minimum layer budget that can hold them.
pub fn abstract_to_layer_budget(features: FeatureCount) -> CtResult<LayerBudget> {
LayerBudget::new(features.value().div_ceil(FEATURES_PER_LAYER))
}
/// Concretizes an abstract layer budget back to feature capacity.
pub fn concretize_layer_budget(layers: LayerBudget) -> FeatureCount {
FeatureCount(layers.value() * FEATURES_PER_LAYER)
}
/// Checks the Galois-connection law for the feature/layer abstraction.
pub fn feature_layer_galois_law_holds(
features: FeatureCount,
layers: LayerBudget,
) -> CtResult<bool> {
let abstracted_fits = abstract_to_layer_budget(features)?.value() <= layers.value();
let concrete_fits = features.value() <= concretize_layer_budget(layers).value();
Ok(abstracted_fits == concrete_fits)
}
/// A non-negative amount in a resource bundle.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct ResourceAmount(usize);
impl ResourceAmount {
pub fn new(value: usize) -> Self {
Self(value)
}
pub fn value(&self) -> usize {
self.0
}
}
/// Compute and memory resources composed by component-wise addition.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ResourceBundle {
compute: ResourceAmount,
memory: ResourceAmount,
}
impl ResourceBundle {
pub fn new(compute: ResourceAmount, memory: ResourceAmount) -> Self {
Self { compute, memory }
}
pub fn compute(&self) -> ResourceAmount {
self.compute
}
pub fn memory(&self) -> ResourceAmount {
self.memory
}
/// Monoidal product for independent resources.
pub fn tensor(&self, other: &Self) -> Self {
Self {
compute: ResourceAmount::new(self.compute.value() + other.compute.value()),
memory: ResourceAmount::new(self.memory.value() + other.memory.value()),
}
}
/// Resource preorder: supply can cover demand when every component is large
/// enough.
pub fn can_supply(&self, demand: &Self) -> bool {
self.compute >= demand.compute && self.memory >= demand.memory
}
}
/// Demonstrates monotonicity of the monoidal resource operation.
pub fn resource_tensor_is_monotone() -> bool {
let small_supply = ResourceBundle::new(ResourceAmount::new(2), ResourceAmount::new(8));
let large_supply = ResourceBundle::new(ResourceAmount::new(4), ResourceAmount::new(16));
let fixed = ResourceBundle::new(ResourceAmount::new(1), ResourceAmount::new(2));
large_supply.can_supply(&small_supply)
&& large_supply
.tensor(&fixed)
.can_supply(&small_supply.tensor(&fixed))
}
/// Identifier for a department row.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DepartmentId(usize);
impl DepartmentId {
pub fn new(value: usize) -> Self {
Self(value)
}
pub fn value(&self) -> usize {
self.0
}
}
/// Identifier for an employee row.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct EmployeeId(usize);
impl EmployeeId {
pub fn new(value: usize) -> Self {
Self(value)
}
pub fn value(&self) -> usize {
self.0
}
}
/// A row in the employee table. The department field is the schema arrow
/// `Employee -> Department`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EmployeeRecord {
id: EmployeeId,
department: DepartmentId,
}
impl EmployeeRecord {
pub fn new(id: EmployeeId, department: DepartmentId) -> Self {
Self { id, department }
}
pub fn id(&self) -> EmployeeId {
self.id
}
pub fn department(&self) -> DepartmentId {
self.department
}
}
/// A tiny database instance: schema objects become sets of ids, and schema
/// arrows become functions between those sets.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CompanyInstance {
departments: Vec<DepartmentId>,
employees: Vec<EmployeeRecord>,
}
impl CompanyInstance {
pub fn new(
departments: impl IntoIterator<Item = DepartmentId>,
employees: impl IntoIterator<Item = EmployeeRecord>,
) -> CtResult<Self> {
let departments = departments.into_iter().collect::<Vec<_>>();
let employees = employees.into_iter().collect::<Vec<_>>();
for employee in &employees {
if !departments.contains(&employee.department()) {
return Err(CtError::ShapeMismatch {
op: "database instance",
expected: String::from("employee departments exist in Department"),
got: format!("missing department {}", employee.department().value()),
});
}
}
Ok(Self {
departments,
employees,
})
}
pub fn departments(&self) -> &[DepartmentId] {
&self.departments
}
pub fn employees(&self) -> &[EmployeeRecord] {
&self.employees
}
pub fn department_of(&self, employee_id: EmployeeId) -> Option<DepartmentId> {
self.employees
.iter()
.find(|employee| employee.id() == employee_id)
.map(EmployeeRecord::department)
}
pub fn foreign_keys_resolve(&self) -> bool {
self.employees
.iter()
.all(|employee| self.departments.contains(&employee.department()))
}
}
/// Requests per second needed by a design.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct Throughput(usize);
impl Throughput {
pub fn new(value: usize) -> CtResult<Self> {
if value == 0 {
return Err(CtError::EmptyInput("throughput"));
}
Ok(Self(value))
}
pub fn value(&self) -> usize {
self.0
}
}
/// Millisecond latency boundary.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct LatencyMs(usize);
impl LatencyMs {
pub fn new(value: usize) -> CtResult<Self> {
if value == 0 {
return Err(CtError::EmptyInput("latency"));
}
Ok(Self(value))
}
pub fn value(&self) -> usize {
self.0
}
}
/// Functional need in a co-design problem.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DesignRequirement {
minimum_throughput: Throughput,
maximum_latency: LatencyMs,
}
impl DesignRequirement {
pub fn new(minimum_throughput: Throughput, maximum_latency: LatencyMs) -> Self {
Self {
minimum_throughput,
maximum_latency,
}
}
}
/// Implementation offered by a candidate component.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ImplementationOffer {
throughput: Throughput,
latency: LatencyMs,
}
impl ImplementationOffer {
pub fn new(throughput: Throughput, latency: LatencyMs) -> Self {
Self {
throughput,
latency,
}
}
}
/// A Bool-valued feasibility relation between requirements and offers.
#[derive(Debug, Clone, Copy)]
pub struct FeasibilityRelation;
impl FeasibilityRelation {
pub fn relates(requirement: DesignRequirement, offer: ImplementationOffer) -> bool {
offer.throughput >= requirement.minimum_throughput
&& offer.latency <= requirement.maximum_latency
}
}
/// A scalar coefficient in a signal-flow matrix.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SignalCoefficient(i32);
impl SignalCoefficient {
pub fn new(value: i32) -> Self {
Self(value)
}
pub fn zero() -> Self {
Self(0)
}
pub fn value(&self) -> i32 {
self.0
}
fn add(self, other: Self) -> Self {
Self(self.value() + other.value())
}
fn multiply(self, other: Self) -> Self {
Self(self.value() * other.value())
}
}
/// Number of rows in a signal-flow matrix.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MatrixRows(usize);
impl MatrixRows {
pub fn new(value: usize) -> CtResult<Self> {
if value == 0 {
return Err(CtError::EmptyInput("matrix rows"));
}
Ok(Self(value))
}
pub fn value(&self) -> usize {
self.0
}
}
/// Number of columns in a signal-flow matrix.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MatrixCols(usize);
impl MatrixCols {
pub fn new(value: usize) -> CtResult<Self> {
if value == 0 {
return Err(CtError::EmptyInput("matrix columns"));
}
Ok(Self(value))
}
pub fn value(&self) -> usize {
self.0
}
}
/// Matrix semantics for a signal-flow graph.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SignalMatrix {
rows: MatrixRows,
cols: MatrixCols,
coefficients: Vec<Vec<SignalCoefficient>>,
}
impl SignalMatrix {
pub fn new(
rows: MatrixRows,
cols: MatrixCols,
coefficients: Vec<Vec<SignalCoefficient>>,
) -> CtResult<Self> {
if coefficients.len() != rows.value() {
return Err(CtError::ShapeMismatch {
op: "signal matrix",
expected: format!("{} rows", rows.value()),
got: format!("{} rows", coefficients.len()),
});
}
for row in &coefficients {
if row.len() != cols.value() {
return Err(CtError::ShapeMismatch {
op: "signal matrix",
expected: format!("{} columns", cols.value()),
got: format!("{} columns", row.len()),
});
}
}
Ok(Self {
rows,
cols,
coefficients,
})
}
pub fn rows(&self) -> MatrixRows {
self.rows
}
pub fn cols(&self) -> MatrixCols {
self.cols
}
pub fn coefficients(&self) -> &[Vec<SignalCoefficient>] {
&self.coefficients
}
/// Matrix composition. If `previous` is `A -> B` and `self` is `B -> C`,
/// the result is `A -> C`.
pub fn compose_after(&self, previous: &Self) -> CtResult<Self> {
if previous.rows.value() != self.cols.value() {
return Err(CtError::ShapeMismatch {
op: "signal matrix composition",
expected: format!("middle dimension {}", self.cols.value()),
got: format!("middle dimension {}", previous.rows.value()),
});
}
let mut coefficients =
vec![vec![SignalCoefficient::zero(); previous.cols.value()]; self.rows.value()];
for (output_row, output_coefficients) in coefficients.iter_mut().enumerate() {
for (input_col, coefficient) in output_coefficients.iter_mut().enumerate() {
let mut total = SignalCoefficient::zero();
for middle in 0..self.cols.value() {
total = total.add(
self.coefficients[output_row][middle]
.multiply(previous.coefficients[middle][input_col]),
);
}
*coefficient = total;
}
}
Self::new(self.rows, previous.cols, coefficients)
}
}
/// A named boundary port in an open circuit.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PortName(&'static str);
impl PortName {
pub fn new(value: &'static str) -> CtResult<Self> {
if value.trim().is_empty() {
return Err(CtError::EmptyInput("port name"));
}
Ok(Self(value))
}
pub fn as_str(&self) -> &'static str {
self.0
}
}
/// Positive resistance value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ResistanceOhms(usize);
impl ResistanceOhms {
pub fn new(value: usize) -> CtResult<Self> {
if value == 0 {
return Err(CtError::InvalidQuantity {
kind: "resistance",
value: value as i64,
});
}
Ok(Self(value))
}
pub fn value(&self) -> usize {
self.0
}
}
/// A simple resistor component between two ports.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CircuitComponent {
from: PortName,
to: PortName,
resistance: ResistanceOhms,
}
impl CircuitComponent {
pub fn resistor(from: PortName, to: PortName, resistance: ResistanceOhms) -> Self {
Self {
from,
to,
resistance,
}
}
pub fn from(&self) -> PortName {
self.from
}
pub fn to(&self) -> PortName {
self.to
}
pub fn resistance(&self) -> ResistanceOhms {
self.resistance
}
}
/// An open circuit with input ports, output ports, and internal components.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OpenCircuit {
inputs: Vec<PortName>,
outputs: Vec<PortName>,
components: Vec<CircuitComponent>,
}
impl OpenCircuit {
pub fn new(
inputs: impl IntoIterator<Item = PortName>,
outputs: impl IntoIterator<Item = PortName>,
components: impl IntoIterator<Item = CircuitComponent>,
) -> CtResult<Self> {
let inputs = inputs.into_iter().collect::<Vec<_>>();
let outputs = outputs.into_iter().collect::<Vec<_>>();
let components = components.into_iter().collect::<Vec<_>>();
if inputs.is_empty() {
return Err(CtError::EmptyInput("circuit inputs"));
}
if outputs.is_empty() {
return Err(CtError::EmptyInput("circuit outputs"));
}
Ok(Self {
inputs,
outputs,
components,
})
}
pub fn input_count(&self) -> usize {
self.inputs.len()
}
pub fn output_count(&self) -> usize {
self.outputs.len()
}
pub fn component_count(&self) -> usize {
self.components.len()
}
/// Serial composition wires this circuit's outputs into the next circuit's
/// inputs. The check compares port counts only; port names are not compared.
pub fn then(&self, next: &Self) -> CtResult<Self> {
if self.output_count() != next.input_count() {
return Err(CtError::ShapeMismatch {
op: "open circuit serial composition",
expected: format!("{} next inputs", self.output_count()),
got: format!("{} next inputs", next.input_count()),
});
}
let mut components = self.components.clone();
components.extend_from_slice(&next.components);
Self::new(self.inputs.clone(), next.outputs.clone(), components)
}
/// Parallel composition keeps the two open interfaces side by side.
pub fn parallel(&self, other: &Self) -> CtResult<Self> {
let mut inputs = self.inputs.clone();
inputs.extend_from_slice(&other.inputs);
let mut outputs = self.outputs.clone();
outputs.extend_from_slice(&other.outputs);
let mut components = self.components.clone();
components.extend_from_slice(&other.components);
Self::new(inputs, outputs, components)
}
}
/// Truth values used for behavior classification.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TruthValue {
False,
True,
}
impl TruthValue {
pub fn and(self, other: Self) -> Self {
match (self, other) {
(TruthValue::True, TruthValue::True) => TruthValue::True,
_ => TruthValue::False,
}
}
pub fn implies(self, other: Self) -> Self {
match (self, other) {
(TruthValue::True, TruthValue::False) => TruthValue::False,
_ => TruthValue::True,
}
}
}
/// Discrete time point in a behavior trace.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct TimeTick(usize);
impl TimeTick {
pub fn new(value: usize) -> Self {
Self(value)
}
pub fn value(&self) -> usize {
self.0
}
}
/// Closed interval of time ticks.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TimeInterval {
start: TimeTick,
end: TimeTick,
}
impl TimeInterval {
pub fn new(start: TimeTick, end: TimeTick) -> CtResult<Self> {
if start > end {
return Err(CtError::InvalidInterval {
start: start.value(),
end: end.value(),
});
}
Ok(Self { start, end })
}
pub fn start(&self) -> TimeTick {
self.start
}
pub fn end(&self) -> TimeTick {
self.end
}
}
/// Local safety result on one interval.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LocalSafetyCheck {
interval: TimeInterval,
truth: TruthValue,
}
impl LocalSafetyCheck {
pub fn new(interval: TimeInterval, truth: TruthValue) -> Self {
Self { interval, truth }
}
pub fn interval(&self) -> TimeInterval {
self.interval
}
pub fn truth(&self) -> TruthValue {
self.truth
}
}
/// A cover of local behavior checks. The global result is true only when all
/// local checks are true.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SafetyCover(Vec<LocalSafetyCheck>);
impl SafetyCover {
pub fn new(checks: impl IntoIterator<Item = LocalSafetyCheck>) -> CtResult<Self> {
let checks = checks.into_iter().collect::<Vec<_>>();
if checks.is_empty() {
return Err(CtError::EmptyInput("safety cover"));
}
Ok(Self(checks))
}
pub fn global_truth(&self) -> TruthValue {
self.0
.iter()
.fold(TruthValue::True, |truth, check| truth.and(check.truth()))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn information_order_has_preorder_laws() {
assert!(information_order_obeys_preorder_laws());
assert_eq!(
InformationLevel::Feature.join(InformationLevel::Decision),
InformationLevel::Decision
);
}
#[test]
fn feature_layer_abstraction_obeys_galois_law() -> CtResult<()> {
let features = FeatureCount::new(9)?;
let layers = LayerBudget::new(3)?;
assert!(feature_layer_galois_law_holds(features, layers)?);
assert_eq!(abstract_to_layer_budget(features)?.value(), 3);
assert_eq!(concretize_layer_budget(layers).value(), 12);
Ok(())
}
#[test]
fn resource_tensor_is_order_preserving() {
assert!(resource_tensor_is_monotone());
}
#[test]
fn database_instance_resolves_schema_arrow() -> CtResult<()> {
let research = DepartmentId::new(1);
let platform = DepartmentId::new(2);
let ada = EmployeeId::new(7);
let instance =
CompanyInstance::new([research, platform], [EmployeeRecord::new(ada, research)])?;
assert!(instance.foreign_keys_resolve());
assert_eq!(instance.department_of(ada), Some(research));
Ok(())
}
#[test]
fn database_instance_rejects_missing_department_reference() {
let research = DepartmentId::new(1);
let missing_department = DepartmentId::new(99);
let ada = EmployeeId::new(7);
assert!(matches!(
CompanyInstance::new([research], [EmployeeRecord::new(ada, missing_department)]),
Err(CtError::ShapeMismatch {
op: "database instance",
..
})
));
}
#[test]
fn feasibility_relation_matches_requirement_to_offer() -> CtResult<()> {
let requirement = DesignRequirement::new(Throughput::new(100)?, LatencyMs::new(80)?);
let offer = ImplementationOffer::new(Throughput::new(120)?, LatencyMs::new(50)?);
assert!(FeasibilityRelation::relates(requirement, offer));
Ok(())
}
#[test]
fn signal_matrices_compose_like_flow_graph_semantics() -> CtResult<()> {
let duplicate = SignalMatrix::new(
MatrixRows::new(2)?,
MatrixCols::new(1)?,
vec![
vec![SignalCoefficient::new(1)],
vec![SignalCoefficient::new(1)],
],
)?;
let add_weighted = SignalMatrix::new(
MatrixRows::new(1)?,
MatrixCols::new(2)?,
vec![vec![SignalCoefficient::new(2), SignalCoefficient::new(3)]],
)?;
let composed = add_weighted.compose_after(&duplicate)?;
assert_eq!(composed.coefficients(), &[vec![SignalCoefficient::new(5)]]);
Ok(())
}
#[test]
fn signal_matrix_composition_rejects_mismatched_middle_dimension() -> CtResult<()> {
let previous = SignalMatrix::new(
MatrixRows::new(2)?,
MatrixCols::new(1)?,
vec![
vec![SignalCoefficient::new(1)],
vec![SignalCoefficient::new(1)],
],
)?;
let next = SignalMatrix::new(
MatrixRows::new(1)?,
MatrixCols::new(3)?,
vec![vec![
SignalCoefficient::new(1),
SignalCoefficient::new(1),
SignalCoefficient::new(1),
]],
)?;
assert!(matches!(
next.compose_after(&previous),
Err(CtError::ShapeMismatch {
op: "signal matrix composition",
..
})
));
Ok(())
}
#[test]
fn open_circuits_compose_in_series_and_parallel() -> CtResult<()> {
let input = PortName::new("input")?;
let middle = PortName::new("middle")?;
let output = PortName::new("output")?;
let first = OpenCircuit::new(
[input],
[middle],
[CircuitComponent::resistor(
input,
middle,
ResistanceOhms::new(10)?,
)],
)?;
let second = OpenCircuit::new(
[middle],
[output],
[CircuitComponent::resistor(
middle,
output,
ResistanceOhms::new(20)?,
)],
)?;
let serial = first.then(&second)?;
let parallel = first.parallel(&second)?;
assert_eq!(serial.input_count(), 1);
assert_eq!(serial.output_count(), 1);
assert_eq!(serial.component_count(), 2);
assert_eq!(parallel.input_count(), 2);
assert_eq!(parallel.output_count(), 2);
Ok(())
}
#[test]
fn open_circuit_serial_composition_rejects_boundary_mismatch() -> CtResult<()> {
let input = PortName::new("input")?;
let output = PortName::new("output")?;
let left = OpenCircuit::new([input], [output], [])?;
let right = OpenCircuit::new([input, output], [output], [])?;
assert!(matches!(
left.then(&right),
Err(CtError::ShapeMismatch {
op: "open circuit serial composition",
..
})
));
Ok(())
}
#[test]
fn local_behavior_truth_glues_to_global_truth() -> CtResult<()> {
let first = LocalSafetyCheck::new(
TimeInterval::new(TimeTick::new(0), TimeTick::new(5))?,
TruthValue::True,
);
let second = LocalSafetyCheck::new(
TimeInterval::new(TimeTick::new(5), TimeTick::new(10))?,
TruthValue::True,
);
let cover = SafetyCover::new([first, second])?;
assert_eq!(cover.global_truth(), TruthValue::True);
assert_eq!(
TruthValue::True.implies(TruthValue::False),
TruthValue::False
);
Ok(())
}
}
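The tests above check one composite and one mismatch. They do not check the category laws that make compose_after a composition: identity matrices act as identities, and three compatible matrices give the same composite however the pair is grouped. The following is a minimal sketch of those two law checks. It assumes plain Vec<Vec<i32>> matrices in place of SignalMatrix; the names compose_after and identity here are local stand-ins, not the crate's API.

```rust
// Plain-integer stand-in for SignalMatrix, used only to check the category laws.
type Matrix = Vec<Vec<i32>>;

/// Computes `next * previous` ("next after previous"), or `None`
/// when the middle dimensions disagree.
fn compose_after(next: &Matrix, previous: &Matrix) -> Option<Matrix> {
    let middle = previous.len(); // rows of previous = size of the middle object
    if next.iter().any(|row| row.len() != middle) {
        return None;
    }
    let cols = previous.first().map_or(0, |row| row.len());
    let mut out = vec![vec![0; cols]; next.len()];
    for (i, out_row) in out.iter_mut().enumerate() {
        for (j, cell) in out_row.iter_mut().enumerate() {
            *cell = (0..middle).map(|k| next[i][k] * previous[k][j]).sum();
        }
    }
    Some(out)
}

/// Identity matrix on an `n`-dimensional object.
fn identity(n: usize) -> Matrix {
    (0..n)
        .map(|i| (0..n).map(|j| if i == j { 1 } else { 0 }).collect::<Vec<i32>>())
        .collect()
}

fn main() {
    let duplicate: Matrix = vec![vec![1], vec![1]]; // 1 -> 2
    let add_weighted: Matrix = vec![vec![2, 3]]; // 2 -> 1
    let scale: Matrix = vec![vec![4]]; // 1 -> 1

    // Identity laws: id after f = f = f after id.
    assert_eq!(compose_after(&identity(2), &duplicate), Some(duplicate.clone()));
    assert_eq!(compose_after(&duplicate, &identity(1)), Some(duplicate.clone()));

    // Associativity: (scale after add_weighted) after duplicate
    //              = scale after (add_weighted after duplicate).
    let left = compose_after(&compose_after(&scale, &add_weighted).unwrap(), &duplicate);
    let right = compose_after(&scale, &compose_after(&add_weighted, &duplicate).unwrap());
    assert_eq!(left, right);
    assert_eq!(left, Some(vec![vec![20]]));

    // A mismatched middle dimension is rejected.
    assert_eq!(compose_after(&add_weighted, &add_weighted), None);
}
```

Returning Option keeps the sketch short; a named error such as the crate's CtError::ShapeMismatch is the better choice when the failure has to explain itself.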
The runnable companion:
Source snapshot: examples/05_seven_sketches.rs
use category_theory_transformer_rs::{
CircuitComponent, CompanyInstance, CtResult, DepartmentId, DesignRequirement, EmployeeId,
EmployeeRecord, FeasibilityRelation, FeatureCount, ImplementationOffer, InformationLevel,
LatencyMs, LayerBudget, LocalSafetyCheck, MatrixCols, MatrixRows, OpenCircuit, PortName,
ResistanceOhms, ResourceAmount, ResourceBundle, SafetyCover, SignalCoefficient, SignalMatrix,
Throughput, TimeInterval, TimeTick, TruthValue, abstract_to_layer_budget,
concretize_layer_budget, feature_layer_galois_law_holds, information_order_obeys_preorder_laws,
resource_tensor_is_monotone,
};
fn main() -> CtResult<()> {
println!(
"orders obey preorder laws: {}",
information_order_obeys_preorder_laws()
);
println!(
"join(feature, decision): {:?}",
InformationLevel::Feature.join(InformationLevel::Decision)
);
let features = FeatureCount::new(9)?;
let layers = LayerBudget::new(3)?;
println!(
"feature/layer Galois law: {}",
feature_layer_galois_law_holds(features, layers)?
);
println!(
"abstract 9 features to {} layers; concretize 3 layers to {} features",
abstract_to_layer_budget(features)?.value(),
concretize_layer_budget(layers).value()
);
let encoder = ResourceBundle::new(ResourceAmount::new(2), ResourceAmount::new(8));
let decoder = ResourceBundle::new(ResourceAmount::new(3), ResourceAmount::new(10));
println!("combined resource bundle: {:?}", encoder.tensor(&decoder));
println!(
"resource tensor monotone: {}",
resource_tensor_is_monotone()
);
let research = DepartmentId::new(1);
let platform = DepartmentId::new(2);
let ada = EmployeeId::new(7);
let instance =
CompanyInstance::new([research, platform], [EmployeeRecord::new(ada, research)])?;
println!(
"employee {:?} belongs to department {:?}",
ada,
instance.department_of(ada)
);
let requirement = DesignRequirement::new(Throughput::new(100)?, LatencyMs::new(80)?);
let offer = ImplementationOffer::new(Throughput::new(120)?, LatencyMs::new(50)?);
println!(
"co-design offer feasible: {}",
FeasibilityRelation::relates(requirement, offer)
);
let duplicate = SignalMatrix::new(
MatrixRows::new(2)?,
MatrixCols::new(1)?,
vec![
vec![SignalCoefficient::new(1)],
vec![SignalCoefficient::new(1)],
],
)?;
let add_weighted = SignalMatrix::new(
MatrixRows::new(1)?,
MatrixCols::new(2)?,
vec![vec![SignalCoefficient::new(2), SignalCoefficient::new(3)]],
)?;
let composed = add_weighted.compose_after(&duplicate)?;
println!(
"signal-flow matrix semantics: {:?}",
composed.coefficients()
);
let input = PortName::new("input")?;
let middle = PortName::new("middle")?;
let output = PortName::new("output")?;
let first_circuit = OpenCircuit::new(
[input],
[middle],
[CircuitComponent::resistor(
input,
middle,
ResistanceOhms::new(10)?,
)],
)?;
let second_circuit = OpenCircuit::new(
[middle],
[output],
[CircuitComponent::resistor(
middle,
output,
ResistanceOhms::new(20)?,
)],
)?;
println!(
"serial circuit component count: {}",
first_circuit.then(&second_circuit)?.component_count()
);
let safety = SafetyCover::new([
LocalSafetyCheck::new(
TimeInterval::new(TimeTick::new(0), TimeTick::new(5))?,
TruthValue::True,
),
LocalSafetyCheck::new(
TimeInterval::new(TimeTick::new(5), TimeTick::new(10))?,
TruthValue::True,
),
])?;
println!("global behavior truth: {:?}", safety.global_truth());
println!();
println!("Typed transformation:");
println!("InformationLevel <= InformationLevel checks preorder");
println!("FeatureCount <-> LayerBudget checks Galois law");
println!("ResourceBundle x ResourceBundle -> ResourceBundle");
println!("EmployeeRecord -> DepartmentId must resolve in CompanyInstance");
println!("DesignRequirement x ImplementationOffer -> bool");
println!("SignalMatrix x SignalMatrix -> SignalMatrix when dimensions match");
println!("OpenCircuit x OpenCircuit -> OpenCircuit when ports match");
println!("SafetyCover -> TruthValue");
Ok(())
}
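In the listing above, OpenCircuit::then rejects a boundary only when the number of outputs differs from the number of next inputs, so two circuits with one port each compose even when the ports carry different names. If name agreement matters, a stricter check can compare ports position by position. This sketch is hypothetical: Boundary, WiringError, and then_by_name are illustrative names that keep only the port lists of an open circuit.

```rust
/// Hypothetical stand-in for OpenCircuit that keeps only the boundary ports.
#[derive(Debug, Clone, PartialEq, Eq)]
struct Boundary {
    inputs: Vec<&'static str>,
    outputs: Vec<&'static str>,
}

#[derive(Debug, PartialEq, Eq)]
enum WiringError {
    CountMismatch { outputs: usize, inputs: usize },
    NameMismatch { output: &'static str, input: &'static str },
}

impl Boundary {
    /// Serial composition that checks port counts first, then port names
    /// position by position.
    fn then_by_name(&self, next: &Self) -> Result<Self, WiringError> {
        if self.outputs.len() != next.inputs.len() {
            return Err(WiringError::CountMismatch {
                outputs: self.outputs.len(),
                inputs: next.inputs.len(),
            });
        }
        for (&output, &input) in self.outputs.iter().zip(&next.inputs) {
            if output != input {
                return Err(WiringError::NameMismatch { output, input });
            }
        }
        Ok(Self {
            inputs: self.inputs.clone(),
            outputs: next.outputs.clone(),
        })
    }
}

fn main() {
    let first = Boundary { inputs: vec!["input"], outputs: vec!["middle"] };
    let second = Boundary { inputs: vec!["middle"], outputs: vec!["output"] };
    let stray = Boundary { inputs: vec!["other"], outputs: vec!["output"] };

    let serial = first.then_by_name(&second).expect("names agree");
    assert_eq!(serial.inputs, vec!["input"]);
    assert_eq!(serial.outputs, vec!["output"]);

    // One port on each side, so a count-only check accepts this; the name check does not.
    assert_eq!(
        first.then_by_name(&stray),
        Err(WiringError::NameMismatch { output: "middle", input: "other" })
    );
}
```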
Paper Map To Rust
Use this table to navigate from each area of the source paper to the Rust types that model it.
| Paper area | Main content | Rust companion |
|---|---|---|
| Generative effects | preorders, monotone maps, Galois connections | InformationLevel, FeatureCount, LayerBudget |
| Resources | monoidal preorders, resource composition, enrichment | ResourceBundle, ResourceAmount |
| Databases | schemas as categories, instances as functors | CompanyInstance, EmployeeRecord, DepartmentId |
| Co-design | feasibility relations and profunctor-like reasoning | DesignRequirement, ImplementationOffer, FeasibilityRelation |
| Signal flow | syntax, semantics, matrices, composition | SignalMatrix, SignalCoefficient |
| Circuits | open systems, ports, serial and parallel composition | OpenCircuit, CircuitComponent, PortName |
| Logic of behavior | truth values, intervals, local-to-global checks | TruthValue, TimeInterval, SafetyCover |
Each section below uses the same three-part lens:
- Rust syntax
- ML or software concept
- Category theory concept
Source Scope Contract
This chapter is a study companion, not a replacement for the source text. Each Rust model preserves one inspectable idea and deliberately leaves the larger mathematical development outside the tiny example.
| Paper area | What the Rust sketch preserves | What it does not claim |
|---|---|---|
| Generative effects | order laws and one Galois-style capacity law | a complete treatment of generative effects |
| Resources | componentwise resource composition plus monotonicity | full enriched category theory |
| Databases | schema-like reference integrity through typed IDs | a general database semantics framework |
| Co-design | feasibility as a relation between requirements and offers | full profunctor theory |
| Signal flow | matrix composition and middle-dimension checks | a complete syntax-and-semantics account for signal-flow graphs |
| Circuits | open interfaces and serial boundary matching | a full circuit algebra |
| Logic of behavior | local interval truth combined into a global claim | sheaf theory or a full temporal logic |
Use the table as a precision guard. When the Rust code checks one law, say which law it checks. When the source paper develops a larger theory, do not pretend the small Rust model has implemented all of it.
PDF-To-Rust Reading Contract
The arXiv record describes Seven Sketches in Compositionality as a long invitation to applied category theory, built around concrete examples and seven major sketches. That matters for how to use this chapter. The goal is not to compress every page into a smaller page. The goal is to give every major sketch a Rust handle that a reader can run, inspect, and test.
Complete coverage in this companion chapter means:
- every major sketch area has a Rust handle;
- every Rust handle names one protected law, relation, or boundary;
- every protected claim says what the source text still develops beyond the code;
- every reader can run one command before arguing about the abstraction.
Use this ledger while reading the PDF beside the Rust:
| When the source text discusses | Ask in this chapter | Local evidence |
|---|---|---|
| an order, relation, or refinement | Which enum, newtype, or method names the ordered world? | InformationLevel::can_flow_to, FeatureCount, LayerBudget |
| a resource or compositional quantity | Which operation combines independent pieces? | ResourceBundle::tensor |
| a schema, instance, or reference | Which constructor rejects invalid references? | CompanyInstance::new |
| a feasibility relation | Which requirement-offer pair is accepted or rejected? | FeasibilityRelation::relates |
| a signal-flow or matrix composition | Which middle dimension must match? | SignalMatrix::compose_after |
| an open system or circuit interface | Which boundary ports must agree? | OpenCircuit::then |
| a local-to-global behavior claim | Which local checks combine into a global result? | SafetyCover::global_truth |
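For the resource row, the protected law is monotonicity of the combining operation: if a fits within a2 and b fits within b2, then a tensor b fits within a2 tensor b2. The sketch below assumes that a bundle holds integer compute and memory amounts, that tensor adds them componentwise, and that bundles are ordered componentwise. Bundle, fits_within, and tensor_is_monotone_up_to are hypothetical stand-ins for ResourceBundle and resource_tensor_is_monotone, whose exact definitions are not shown here. The check is exhaustive over a small grid, so it is evidence for that grid rather than a proof.

```rust
/// Hypothetical stand-in for ResourceBundle: two integer resource amounts.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct Bundle {
    compute: u32,
    memory: u32,
}

impl Bundle {
    /// Assumed tensor: componentwise addition of two independent pieces.
    fn tensor(self, other: Self) -> Self {
        Bundle {
            compute: self.compute + other.compute,
            memory: self.memory + other.memory,
        }
    }

    /// Componentwise order: `self` fits within `other` when both amounts do.
    fn fits_within(self, other: Self) -> bool {
        self.compute <= other.compute && self.memory <= other.memory
    }
}

/// Checks monotonicity of `tensor` for every bundle with amounts in 0..=limit:
/// a <= a2 and b <= b2 must imply a.tensor(b) <= a2.tensor(b2).
fn tensor_is_monotone_up_to(limit: u32) -> bool {
    let bundles: Vec<Bundle> = (0..=limit)
        .flat_map(|compute| (0..=limit).map(move |memory| Bundle { compute, memory }))
        .collect();
    for &a in &bundles {
        for &a2 in bundles.iter().filter(|&&x| a.fits_within(x)) {
            for &b in &bundles {
                for &b2 in bundles.iter().filter(|&&x| b.fits_within(x)) {
                    if !a.tensor(b).fits_within(a2.tensor(b2)) {
                        return false;
                    }
                }
            }
        }
    }
    true
}

fn main() {
    // The encoder and decoder amounts used in 05_seven_sketches.
    let encoder = Bundle { compute: 2, memory: 8 };
    let decoder = Bundle { compute: 3, memory: 10 };
    assert_eq!(encoder.tensor(decoder), Bundle { compute: 5, memory: 18 });
    assert!(tensor_is_monotone_up_to(3));
}
```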
Page-To-Rust Decision Ladder
Use this ladder when you are reading the PDF page by page and do not yet know what to implement.
| Source paragraph shape | First Rust move | Evidence to seek | Safe non-claim |
|---|---|---|---|
| definition or named object | create a newtype, enum, or struct with private fields | constructor rejects an impossible value | this is one typed domain object, not the whole theory |
| relation, order, or feasibility statement | write a method that returns bool or CtResult<T> | one passing case and one rejected case | this is one relation, not a complete semantics |
| composition rule | write a method that takes two typed inputs | mismatched middle object, dimension, or port returns Err(...) | this is one composition boundary, not a general algebra |
| theorem, law, or proof step | write the smallest law test | the test names the exact law being checked | one passing test is local evidence, not a proof of the source text |
| worked example or application story | build a tiny fixture and print one output line | cargo run --example 05_seven_sketches shows the protected boundary | the fixture is an analogy, not a production model |
| richer machinery beyond the local handle | write a larger-claim-not-implemented sentence | the non-claim names the missing theory explicitly | the chapter cites the source; it does not compress it |
The ladder keeps the reading order concrete:
source paragraph shape
-> first Rust move
-> local evidence
-> safe non-claim
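Applied to the co-design row, the relation step of the ladder becomes one method returning bool plus one passing and one rejected case. This is a minimal sketch that assumes feasibility means the offer delivers at least the required throughput within the latency bound. Requirement, Offer, and relates are hypothetical stand-ins for DesignRequirement, ImplementationOffer, and FeasibilityRelation::relates.

```rust
/// Hypothetical stand-in for DesignRequirement.
#[derive(Debug, Clone, Copy)]
struct Requirement {
    min_throughput: u32,
    max_latency_ms: u32,
}

/// Hypothetical stand-in for ImplementationOffer.
#[derive(Debug, Clone, Copy)]
struct Offer {
    throughput: u32,
    latency_ms: u32,
}

/// Feasibility as a relation Requirement x Offer -> bool (assumed rule):
/// enough throughput and no more latency than allowed.
fn relates(req: Requirement, offer: Offer) -> bool {
    offer.throughput >= req.min_throughput && offer.latency_ms <= req.max_latency_ms
}

fn main() {
    let requirement = Requirement { min_throughput: 100, max_latency_ms: 80 };
    // Passing case: the requirement and offer numbers used in 05_seven_sketches.
    assert!(relates(requirement, Offer { throughput: 120, latency_ms: 50 }));
    // Rejected case: enough throughput, but the latency bound is exceeded.
    assert!(!relates(requirement, Offer { throughput: 150, latency_ms: 95 }));
}
```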
Example decisions:
| If the source section is about… | Do not start with… | Start with… |
|---|---|---|
| schemas and instances | a trait called Category | CompanyInstance::new rejecting a dangling department |
| signal-flow semantics | a full graph language | SignalMatrix::compose_after checking the middle dimension |
| open-system composition | a general operad library | OpenCircuit::then checking that output and input port counts agree |
This is how the chapter can cover the whole source at the level promised by the book: not every theorem becomes a library feature, but every major reading shape has a disciplined path toward a Rust handle, a test, and a non-claim.
This is also the limit of the chapter. If a PDF section develops a richer construction than the Rust handle, record the richer construction as context, not as something the code has proved. The safe sentence is:
The source develops a larger theory here.
This Rust handle checks one executable boundary from that theory.
Source-Backed Precision Rules
This chapter uses external sources as scope guards. Each source supports a
limited teaching claim, and each claim is tied to one local Rust boundary or
test. The chapter does not claim that src/sketches.rs implements the full
source text, a general categorical semantics library, or a production ML
architecture theory.
The Bridge Back To Tiny ML section below is part of this source contract. Its check sentence is:
This sketch helps me reject this ML shortcut.
| Source | What the source supports | Local rule in this chapter | Rust evidence |
|---|---|---|---|
| Seven Sketches | Applied category theory can be introduced through concrete examples such as databases, circuits, dynamical systems, and other real-world structures. | Treat every sketch as one executable handle for one source idea, not as a replacement for the full mathematical development. | InformationLevel, ResourceBundle, CompanyInstance, OpenCircuit, SafetyCover |
| MIT Applied Category Theory OCW | The seven topic areas can be studied as a course sequence: orders, resources, databases, co-design, signal flow, circuits, and logic of behavior. | Keep the chapter order and Paper Map To Rust aligned with that applied-category sequence. | cargo run --example 05_seven_sketches |
| Category Theory for Programming | Category-theory vocabulary can be taught through programming-shaped structures. | Explain the programming boundary before naming the category-theory pattern. | information_order_obeys_preorder_laws, feature_layer_galois_law_holds, resource_tensor_is_monotone |
| Compositional Deep Learning | Categorical schemas, functorial structure, and composition invariants can appear in neural-network settings under stated assumptions. | Use the database and co-design sketches as ML transfer analogies only; do not claim the crate learns functors or implements the thesis. | CompanyInstance, FeasibilityRelation, database_instance_rejects_missing_department_reference |
| Categorical Deep Learning | Architecture discussions can separate constraints a model should satisfy from implementations that realize them. | Use co-design as a tiny Requirement x Offer -> Bool boundary, not as a theory of all neural architectures. | DesignRequirement, ImplementationOffer, FeasibilityRelation::relates |
| Rust Book: Enums | Enums encode values that must be one variant from a finite set. | Use enums for finite state-like domains before adding laws around them. | InformationLevel, TruthValue |
| Rust Book: Traits | Traits name shared behavior and make contracts explicit. | Treat laws and methods as contracts that tests must witness, not as prose-only claims. | SignalMatrix::compose_after, OpenCircuit::then, SafetyCover::global_truth |
The transfer pattern is:
source idea -> local Rust model -> law or boundary check
For this chapter, that means reading cargo run --example 05_seven_sketches and cargo test sketches::tests as evidence for the tiny
models above, not as evidence that the full seven-sketch source text,
categorical deep-learning literature, or every applied-category construction
has been implemented.
Choose A Sketch Without Losing The Tiny ML Thread
The source paper deliberately tours many application areas. This companion chapter keeps that breadth, but the book still has one main learning path: small typed systems that make ML structure inspectable.
Use this table to decide how deeply to read each sketch on a first pass.
| Sketch | Read deeply when you need | Tiny ML transfer | Safe first-pass treatment |
|---|---|---|---|
| Information order | staged representations or approval states | raw text, tokens, features, scores, and decisions form ordered levels of processed information | Core transfer |
| Feature/layer planning | two views of model capacity | feature counts and layer budgets are different descriptions of model size | Optional but useful |
| Resources | deployment or training constraints | compute and memory limits constrain model choices | Optional but practical |
| Database instance | structured training data | bad references in source data should fail before training | Core transfer |
| Co-design feasibility | requirements versus implementations | a model may satisfy some accuracy, latency, or memory requirements and fail others | Core transfer |
| Signal matrices | linear maps and shape compatibility | composed linear stages need matching middle dimensions | Core transfer |
| Open circuits | component interfaces | typed ML components compose only when output and input boundaries match | Core transfer |
| Logic of behavior | local checks and global claims | every batch or interval must satisfy the invariant before the global claim is trusted | Optional but useful |
On a first reading, focus on the rows marked Core transfer. They connect most directly to the tiny ML pipeline. The optional rows are not less important; they are simply farther from the first runnable model.
The chapter is successful if you can leave each sketch with one sentence:
This Rust model prevents this invalid composition.
Worked Example: Ordering Information Levels
The smallest first-principles version of this chapter is an ordered enum. When Rust derives PartialOrd and Ord for an enum, variants compare in declaration order, so earlier variants are smaller. That gives the code a concrete way to ask whether one information level can safely flow into another:
#![allow(unused)]
fn main() {
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
enum Level {
Observation,
Feature,
Decision,
}
assert!(Level::Observation <= Level::Decision);
assert!(Level::Feature <= Level::Feature);
}
The real src/sketches.rs module uses the same idea with named domain types,
validation, and law checks.
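A minimal sketch of that next step adds a can_flow_to method, a join, and an exhaustive preorder law check. It assumes four levels (observation, feature, score, decision), following the refinement chain described in Sketch 1. Level, can_flow_to, join, and obeys_preorder_laws are hypothetical stand-ins, not the module's definitions. Because the enum is finite, checking every pair and triple is a complete check for this enum.

```rust
/// Hypothetical stand-in for InformationLevel; declaration order is refinement order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
enum Level {
    Observation,
    Feature,
    Score,
    Decision,
}

impl Level {
    /// Every variant, so the laws can be checked exhaustively.
    const ALL: [Level; 4] = [Level::Observation, Level::Feature, Level::Score, Level::Decision];

    /// Information may flow to a level that is at least as refined.
    fn can_flow_to(self, target: Self) -> bool {
        self <= target
    }

    /// Least upper bound; in a total order this is the larger level.
    fn join(self, other: Self) -> Self {
        self.max(other)
    }
}

/// Reflexivity and transitivity, checked over all pairs and triples.
fn obeys_preorder_laws() -> bool {
    let all = Level::ALL;
    let reflexive = all.iter().all(|&a| a.can_flow_to(a));
    let transitive = all.iter().all(|&a| {
        all.iter().all(|&b| {
            all.iter()
                .all(|&c| !(a.can_flow_to(b) && b.can_flow_to(c)) || a.can_flow_to(c))
        })
    });
    reflexive && transitive
}

fn main() {
    assert!(obeys_preorder_laws());
    assert_eq!(Level::Feature.join(Level::Decision), Level::Decision);
    // A decision cannot be passed off as a score.
    assert!(!Level::Decision.can_flow_to(Level::Score));
}
```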
Self-Check
Before reading the first sketch, explain why Observation <= Decision is a
modeling rule, not just a comparison between enum variants.
One Method Across Seven Sketches
Do not read the chapter as seven unrelated theory notes. Each sketch follows the same modeling method:
engineering problem
-> named Rust values
-> validated construction
-> composition operation or relation
-> law or boundary test
That method is the same one used in the tiny ML pipeline. The difference is the domain. Instead of tokens, logits, and parameters, this chapter uses resources, database rows, signal matrices, open circuits, and local behavior checks.
When a sketch feels abstract, ask one question:
What invalid composition should this model prevent?
The answer usually points to the real engineering value of the categorical shape.
The chapter’s applied models can be scanned by the structure they protect:
| Sketch | Rust handle | Valid structure | Rejected or checked boundary |
|---|---|---|---|
| Information order | InformationLevel | information flows upward through refinement levels | preorder laws check reflexivity and transitivity |
| Feature/layer planning | FeatureCount, LayerBudget | concrete and abstract capacity views agree | Galois law checks both directions of fit |
| Resources | ResourceBundle | compute and memory combine componentwise | monotonicity check: enlarging either input bundle never shrinks the combined bundle |
| Databases | CompanyInstance | employee records refer to known departments | missing department references return Err(...) |
| Co-design | FeasibilityRelation | offers satisfy throughput and latency requirements | infeasible offers are not related |
| Signal flow | SignalMatrix | matrices compose when middle dimensions match | mismatched dimensions return Err(...) |
| Open circuits | OpenCircuit | serial composition connects boundaries with matching port counts | boundary mismatch returns Err(...) |
| Behavior logic | SafetyCover | local interval checks combine into global truth | one false local check makes the global truth false |
This table is the chapter’s law-and-boundary index. The point is not to memorize eight rows. The point is to see that each sketch earns its abstraction by protecting one concrete relationship.
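The feature/layer row can be made concrete with a Galois law check. Abstraction sends a feature count to the smallest layer budget that holds it, concretization sends a layer budget to the largest feature count it holds, and abstract(f) <= l holds exactly when f <= concretize(l). The sketch below assumes four features per layer, which reproduces the 9 -> 3 and 3 -> 12 numbers in the module's test; that constant and the function names are hypothetical, not the crate's rule.

```rust
/// Assumed planning rule (hypothetical): one layer of budget holds 4 features.
const FEATURES_PER_LAYER: u32 = 4;

/// Abstraction: the smallest layer budget that holds `features` (ceiling division).
fn abstract_to_layers(features: u32) -> u32 {
    (features + FEATURES_PER_LAYER - 1) / FEATURES_PER_LAYER
}

/// Concretization: the largest feature count that `layers` can hold.
fn concretize_layers(layers: u32) -> u32 {
    layers * FEATURES_PER_LAYER
}

/// Galois law: abstract(f) <= l exactly when f <= concretize(l).
fn galois_law_holds(features: u32, layers: u32) -> bool {
    (abstract_to_layers(features) <= layers) == (features <= concretize_layers(layers))
}

fn main() {
    assert_eq!(abstract_to_layers(9), 3);
    assert_eq!(concretize_layers(3), 12);
    // Exhaustive over a small grid: evidence for these values.
    assert!((1..=40).all(|f| (1..=10).all(|l| galois_law_holds(f, l))));
    // Not inverses: going up and back down overshoots the original count.
    assert_eq!(concretize_layers(abstract_to_layers(9)), 12);
}
```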
Bridge Back To Tiny ML
The source text and the MIT course both use applied examples to make category theory portable across domains. This chapter uses the same move in a smaller way: each sketch should transfer back to one tiny ML design pressure.
Use the bridge below when the chapter starts to feel like a separate category theory tour.
| Sketch | Tiny ML pressure | Rust handle to inspect | Bad shortcut the sketch helps reject | Safe non-claim |
|---|---|---|---|---|
| Information order | raw observations, features, scores, and decisions should not be interchangeable | InformationLevel::can_flow_to | treating a score as if it were already a decision | the enum is a teaching model, not a calibrated decision theory |
| Feature/layer planning | concrete feature count and abstract model capacity are different views | FeatureCount, LayerBudget | treating abstraction and concretization as inverse functions | the Galois-style law is a tiny planning check, not a full architecture search method |
| Resources | training and inference choices depend on compute and memory together | ResourceBundle::tensor | collapsing all resource constraints into one raw number | the bundle is a two-resource sketch, not a deployment cost model |
| Database instance | training rows should not carry dangling references into feature extraction | CompanyInstance::new | letting malformed structured data reach the model | this validates one schema arrow, not a full data platform |
| Co-design feasibility | model choice is often a relation between requirements and candidate implementations | FeasibilityRelation::relates | forcing every design question into a single function | one feasible offer is not proof that every implementation satisfies the constraint |
| Signal matrices | linear stages compose only when dimensions line up | SignalMatrix::compose_after | multiplying stages before checking the middle dimension | this is matrix composition, not a full autodiff or neural-network framework |
| Open circuits | components need explicit input and output boundaries | OpenCircuit::then | wiring pieces by name while ignoring boundary shape | the circuit model is an interface analogy, not a full circuit algebra |
| Logic of behavior | local checks must support global claims | SafetyCover::global_truth | claiming global safety while one local interval failed | conjunction over intervals is a tiny behavior check, not full sheaf theory |
Read each row as a transfer sentence:
This sketch helps me reject this ML shortcut.
That sentence is the practical reason to keep the sketch in the book. The formal category-theory vocabulary matters only after the engineering shortcut is visible.
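The behavior-logic row can be sketched in a few lines: local verdicts fold together with and, so one failed interval makes the global claim false. The sketch assumes each local check is a closed tick interval with a two-valued verdict, mirroring LocalSafetyCheck and SafetyCover::global_truth; Truth, LocalCheck, global_truth, and failing_intervals are hypothetical stand-ins. It also shows why SafetyCover::new rejects an empty cover: folding no checks, starting from True, returns True with no evidence behind it.

```rust
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Truth {
    False,
    True,
}

impl Truth {
    /// Two-valued conjunction, as in TruthValue::and.
    fn and(self, other: Self) -> Self {
        match (self, other) {
            (Truth::True, Truth::True) => Truth::True,
            _ => Truth::False,
        }
    }
}

/// Hypothetical stand-in for LocalSafetyCheck: closed tick interval plus verdict.
#[derive(Debug, Clone, Copy)]
struct LocalCheck {
    start: u32,
    end: u32,
    truth: Truth,
}

/// Folds local verdicts with `and`, starting from True.
fn global_truth(checks: &[LocalCheck]) -> Truth {
    checks.iter().fold(Truth::True, |acc, check| acc.and(check.truth))
}

/// Lists the intervals whose local verdict is False.
fn failing_intervals(checks: &[LocalCheck]) -> Vec<(u32, u32)> {
    checks
        .iter()
        .filter(|check| check.truth == Truth::False)
        .map(|check| (check.start, check.end))
        .collect()
}

fn main() {
    let checks = [
        LocalCheck { start: 0, end: 5, truth: Truth::True },
        LocalCheck { start: 5, end: 10, truth: Truth::False },
    ];
    assert_eq!(global_truth(&checks), Truth::False);
    assert_eq!(failing_intervals(&checks), vec![(5, 10)]);
    // No checks at all: the fold returns True vacuously.
    assert_eq!(global_truth(&[]), Truth::True);
}
```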
Transfer Triage Card
Use this card when a source idea feels too large to turn into code. The goal is not to shrink the source. The goal is to choose one local boundary that can be inspected.
| If the transfer feels like… | Do this first | Ready when you can write… |
|---|---|---|
| a broad theory claim | shrink to one law, relation, or boundary | Source claim -> local Rust handle |
| a vocabulary list | choose one constructor, method, example line, or test | Rust handle -> protected relationship |
| a passing example only | name the invalid case it rejects or the law it checks | protected relationship -> rejected shortcut |
| an ML analogy | name the ML object, constraint, or composition it maps to | tiny ML transfer -> safe non-claim |
| a research source | state what the source does not license this chapter to claim | non-claim -> evidence command |
The completed transfer card has seven fields:
source idea:
local Rust handle:
protected law, relation, or boundary:
invalid shortcut rejected:
tiny ML transfer:
larger claim not implemented:
local evidence command or test:
Example:
source idea: schemas and instances
local Rust handle: CompanyInstance::new
protected law, relation, or boundary: EmployeeRecord -> DepartmentId must resolve
invalid shortcut rejected: letting a missing department reach feature extraction
tiny ML transfer: validate structured training rows before training
larger claim not implemented: a general categorical database semantics
local evidence command or test: cargo test sketches::tests --lib
Second example:
source idea: open systems have boundaries
local Rust handle: OpenCircuit::then
protected law, relation, or boundary: previous outputs must match next inputs
invalid shortcut rejected: wiring components by label while ignoring shape
tiny ML transfer: Tokenizer -> Embedder is legal only when the output object
matches the next input object
larger claim not implemented: decorated cospans, hypergraph categories, and
operads for general circuit composition
local evidence command or test: cargo test sketches::tests --lib
This follows the source discipline used above. Seven Sketches gives a broad tour through concrete examples. MIT’s course frames category theory as a way to organize formal systems and transfer knowledge between them. Categorical deep-learning work distinguishes architecture constraints from implementations. The local job here is smaller: choose one implementable handle, name one protected relationship, and state the non-claim.
Sketch 1: Information Order
The problem this block solves is:
Some concepts are ordered by refinement. An observation can be refined into a feature, a feature into a score, and a score into a decision.
The block begins:
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum InformationLevel {
Observation,
Feature,
Score,
Decision,
}
Rust Syntax
This is an enum.
The variants are ordered because the enum derives:
PartialOrd, Ord
That means Rust can compare:
InformationLevel::Observation <= InformationLevel::Decision
The methods:
pub fn can_flow_to(self, target: Self) -> bool {
self <= target
}
pub fn join(self, other: Self) -> Self {
self.max(other)
}
reuse that ordering.
can_flow_to checks whether information can move upward.
join returns the more informative of two levels in this total order.
ML Or Software Concept
ML systems often move through levels of processed information:
raw observation
-> extracted feature
-> model score
-> final decision
The order prevents treating a low-level observation as if it were already a decision.
Category Theory Concept
This is a preorder. Because any two levels can be compared, it is in fact a total order.
A preorder needs:
reflexivity: a <= a
transitivity: if a <= b and b <= c, then a <= c
The law check:
information_order_obeys_preorder_laws()
iterates over the finite set and verifies those rules.
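The chapter names information_order_obeys_preorder_laws() but does not show its body. Here is a minimal sketch of how such an exhaustive check could be written. It assumes a hypothetical ALL constant that lists every variant, which may not exist in src/sketches.rs.

```rust
/// Ordered refinement levels. The derived `Ord` follows declaration order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum InformationLevel {
    Observation,
    Feature,
    Score,
    Decision,
}

impl InformationLevel {
    /// Hypothetical helper: every level, so law checks can iterate the finite set.
    pub const ALL: [InformationLevel; 4] = [
        InformationLevel::Observation,
        InformationLevel::Feature,
        InformationLevel::Score,
        InformationLevel::Decision,
    ];

    /// Information may move upward or stay where it is.
    pub fn can_flow_to(self, target: Self) -> bool {
        self <= target
    }

    /// The more informative of two levels.
    pub fn join(self, other: Self) -> Self {
        self.max(other)
    }
}

/// Checks reflexivity and transitivity over every pair and triple of levels.
pub fn information_order_obeys_preorder_laws() -> bool {
    let all = InformationLevel::ALL;
    let reflexive = all.iter().all(|&a| a.can_flow_to(a));
    let transitive = all.iter().all(|&a| {
        all.iter().all(|&b| {
            all.iter()
                .all(|&c| !(a.can_flow_to(b) && b.can_flow_to(c)) || a.can_flow_to(c))
        })
    });
    reflexive && transitive
}
```

Because the set has only four elements, checking every triple is cheap, so the law check can be exhaustive rather than sampled.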
Transfer Task: Ordered States
Model a workflow from your own codebase as an ordered enum, such as:
enum ReviewState {
Draft,
Reviewed,
Published,
}
Then write the Rust sentence you would want to be true:
Draft can flow to Published
Published cannot flow to Draft
state can always flow to itself
The transfer is complete when you can name the invalid flow your order prevents.
Sketch 1 Continued: Feature And Layer Galois Law
The problem this block solves is:
A concrete feature count and an abstract layer budget are different worlds, but they can be coordinated by a law.
The key types:
pub struct FeatureCount(usize);
pub struct LayerBudget(usize);
The conversion functions:
pub fn abstract_to_layer_budget(features: FeatureCount) -> CtResult<LayerBudget>
pub fn concretize_layer_budget(layers: LayerBudget) -> FeatureCount
Rust Syntax
Both FeatureCount and LayerBudget are newtypes around usize.
Their constructors reject zero because neither a zero feature count nor a zero layer budget is useful in this model.
abstract_to_layer_budget divides feature count by FEATURES_PER_LAYER and
rounds up:
features.value().div_ceil(FEATURES_PER_LAYER)
concretize_layer_budget multiplies layers by features per layer.
ML Or Software Concept
This models capacity planning.
Concrete features might be:
9 measured feature channels
The abstract layer budget might be:
3 layers
The law says the two planning views agree about what fits.
Category Theory Concept
The checked law is:
abstract(features) <= layers
if and only if
features <= concretize(layers)
That is the shape of a Galois connection.
The two directions are not inverses.
They are coordinated by an order law.
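The two-way law can be checked over a finite range. This sketch assumes FEATURES_PER_LAYER is 3, which makes the text's example of 9 feature channels and 3 layers line up; the companion's real constant may differ. It also uses String errors instead of the companion's CtResult. The galois_law_holds_up_to helper is hypothetical.

```rust
/// Assumption for this sketch only: one layer holds three features.
pub const FEATURES_PER_LAYER: usize = 3;

#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct FeatureCount(usize);

#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct LayerBudget(usize);

impl FeatureCount {
    /// Rejects zero, mirroring the constructor rule described in the text.
    pub fn new(value: usize) -> Result<Self, String> {
        if value == 0 {
            Err("feature count must be positive".to_string())
        } else {
            Ok(Self(value))
        }
    }

    pub fn value(self) -> usize {
        self.0
    }
}

impl LayerBudget {
    pub fn new(value: usize) -> Result<Self, String> {
        if value == 0 {
            Err("layer budget must be positive".to_string())
        } else {
            Ok(Self(value))
        }
    }

    pub fn value(self) -> usize {
        self.0
    }
}

/// Abstraction: the smallest layer budget that holds the features (rounds up).
pub fn abstract_to_layer_budget(features: FeatureCount) -> Result<LayerBudget, String> {
    LayerBudget::new(features.value().div_ceil(FEATURES_PER_LAYER))
}

/// Concretization: the largest feature count a layer budget can hold.
pub fn concretize_layer_budget(layers: LayerBudget) -> FeatureCount {
    // layers >= 1, so the product is never zero.
    FeatureCount(layers.value() * FEATURES_PER_LAYER)
}

/// Checks abstract(f) <= l  iff  f <= concretize(l) for all f, l in 1..=limit.
pub fn galois_law_holds_up_to(limit: usize) -> Result<bool, String> {
    for f in 1..=limit {
        for l in 1..=limit {
            let features = FeatureCount::new(f)?;
            let layers = LayerBudget::new(l)?;
            let abstract_fits = abstract_to_layer_budget(features)? <= layers;
            let concrete_fits = features <= concretize_layer_budget(layers);
            if abstract_fits != concrete_fits {
                return Ok(false);
            }
        }
    }
    Ok(true)
}
```

Under this assumption, 10 features abstract to 4 layers, and 4 layers concretize to 12 features, not 10. The round trip does not return the starting value, which is why the two maps are coordinated by a law instead of being inverses.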
Transfer Task: Concrete And Abstract Capacity
Pick one concrete measure and one abstract budget from software work:
concrete: batch size, feature count, request rate
abstract: GPU count, layer budget, service tier
Write the two functions:
abstract(concrete) -> abstract budget
concretize(abstract budget) -> concrete capacity
The transfer is complete when you can state the law in both directions:
abstract(concrete) fits budget
if and only if
concrete fits concretize(budget)
Sketch 2: Resources
The problem this block solves is:
Independent resources need a way to combine, and resource supply must respect demand ordering.
The key block:
pub struct ResourceBundle {
compute: ResourceAmount,
memory: ResourceAmount,
}
Rust Syntax
ResourceAmount wraps a usize.
ResourceBundle stores two resource dimensions:
compute
memory
The monoidal operation is:
pub fn tensor(&self, other: &Self) -> Self
It adds compute to compute and memory to memory.
The preorder check is:
pub fn can_supply(&self, demand: &Self) -> bool
It returns true only when every resource component is large enough.
ML Or Software Concept
This is the shape of deployment capacity:
encoder resources + decoder resources = combined resources
A machine can satisfy a demand only when it has enough compute and memory.
Category Theory Concept
This is a monoidal preorder.
The preorder is can_supply.
The monoidal product is tensor.
The law check:
resource_tensor_is_monotone()
checks that adding the same fixed resource bundle to both sides of a supply relation preserves that relation.
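A minimal sketch of the monoidal preorder and a grid-based monotonicity check. The unvalidated ResourceBundle::new(compute, memory) constructor and the 4x4 grid of sample bundles are assumptions of this sketch; the companion's constructor may validate its inputs and its law check may use different samples.

```rust
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct ResourceAmount(usize);

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ResourceBundle {
    compute: ResourceAmount,
    memory: ResourceAmount,
}

impl ResourceBundle {
    /// Hypothetical constructor for this sketch; no validation is performed.
    pub fn new(compute: usize, memory: usize) -> Self {
        Self {
            compute: ResourceAmount(compute),
            memory: ResourceAmount(memory),
        }
    }

    /// Monoidal product: add each dimension separately.
    pub fn tensor(&self, other: &Self) -> Self {
        Self::new(
            self.compute.0 + other.compute.0,
            self.memory.0 + other.memory.0,
        )
    }

    /// Preorder: supply covers demand only if every component is large enough.
    pub fn can_supply(&self, demand: &Self) -> bool {
        self.compute >= demand.compute && self.memory >= demand.memory
    }
}

/// If supply covers demand, adding the same extra bundle to both keeps it covered.
pub fn resource_tensor_is_monotone() -> bool {
    let grid: Vec<ResourceBundle> = (0..4)
        .flat_map(|c| (0..4).map(move |m| ResourceBundle::new(c, m)))
        .collect();
    grid.iter().all(|supply| {
        grid.iter().all(|demand| {
            grid.iter().all(|extra| {
                !supply.can_supply(demand)
                    || supply.tensor(extra).can_supply(&demand.tensor(extra))
            })
        })
    })
}
```

Keeping compute and memory separate matters. A machine with 10 compute and 4 memory has more total units than a combined demand of 5 and 5, yet it cannot supply it because memory is short.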
Transfer Task: Resource Bundle
Extend the idea to three resource dimensions:
compute
memory
disk
Name the constructor boundary and the law check:
ResourceBundle::new(compute, memory, disk)
resource_tensor_is_monotone
The transfer is complete when you can explain why componentwise addition should not turn an adequate supply into an inadequate supply.
Sketch 3: Database Instance
The problem this block solves is:
A database row with a foreign key should not point at a missing row.
The key types:
pub struct DepartmentId(usize);
pub struct EmployeeId(usize);
pub struct EmployeeRecord {
id: EmployeeId,
department: DepartmentId,
}
pub struct CompanyInstance {
departments: Vec<DepartmentId>,
employees: Vec<EmployeeRecord>,
}
Rust Syntax
DepartmentId and EmployeeId are distinct newtypes.
That prevents mixing department IDs and employee IDs.
EmployeeRecord contains:
employee id
department id
CompanyInstance::new collects departments and employees, then checks every
employee department exists:
if !departments.contains(&employee.department()) {
return Err(CtError::ShapeMismatch { ... });
}
ML Or Software Concept
Many ML systems depend on structured data.
If a row references missing data, downstream feature extraction or training can fail later in a confusing place.
This code rejects invalid relational structure at construction time.
Category Theory Concept
The schema can be read as:
Employee -> Department
An instance assigns sets of rows to schema objects.
The foreign key is a function from employees to departments.
CompanyInstance::new checks that the function is defined for every employee.
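The construction-time check above elides its error payload. Here is a self-contained sketch of the same boundary. InstanceError and department_of are hypothetical names standing in for the companion's CtError and accessor methods.

```rust
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DepartmentId(usize);

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EmployeeId(usize);

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EmployeeRecord {
    id: EmployeeId,
    department: DepartmentId,
}

impl EmployeeRecord {
    pub fn new(id: EmployeeId, department: DepartmentId) -> Self {
        Self { id, department }
    }

    pub fn department(&self) -> DepartmentId {
        self.department
    }
}

/// Hypothetical error type standing in for the companion's `CtError`.
#[derive(Debug, PartialEq, Eq)]
pub enum InstanceError {
    MissingDepartment { employee: EmployeeId, department: DepartmentId },
}

#[derive(Debug)]
pub struct CompanyInstance {
    departments: Vec<DepartmentId>,
    employees: Vec<EmployeeRecord>,
}

impl CompanyInstance {
    /// Accepts the instance only if the foreign-key arrow Employee -> Department
    /// is defined for every employee row.
    pub fn new(
        departments: Vec<DepartmentId>,
        employees: Vec<EmployeeRecord>,
    ) -> Result<Self, InstanceError> {
        for employee in &employees {
            if !departments.contains(&employee.department()) {
                return Err(InstanceError::MissingDepartment {
                    employee: employee.id,
                    department: employee.department(),
                });
            }
        }
        Ok(Self { departments, employees })
    }

    pub fn departments(&self) -> &[DepartmentId] {
        &self.departments
    }

    /// Follows the foreign key for one employee, if that employee exists.
    pub fn department_of(&self, id: EmployeeId) -> Option<DepartmentId> {
        self.employees
            .iter()
            .find(|record| record.id == id)
            .map(|record| record.department)
    }
}
```

Because the check runs in the constructor, any CompanyInstance value that reaches feature extraction already has a total foreign-key function.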
Transfer Task: Foreign Key Boundary
Translate the pattern to another schema arrow:
Task -> Project
Order -> Customer
Comment -> Post
Name two newtypes and one record:
ProjectId
TaskId
TaskRecord { id: TaskId, project: ProjectId }
The transfer is complete when you can write the constructor failure in plain English: “reject a task whose project id is missing from the project table.”
Sketch 4: Co-Design Feasibility
The problem this block solves is:
Some relationships are not functions. A requirement and an implementation offer are related only when constraints are satisfied.
The key types:
pub struct DesignRequirement {
minimum_throughput: Throughput,
maximum_latency: LatencyMs,
}
pub struct ImplementationOffer {
throughput: Throughput,
latency: LatencyMs,
}
pub struct FeasibilityRelation;
Rust Syntax
Throughput and LatencyMs are validated newtypes.
DesignRequirement stores the minimum acceptable throughput and maximum
acceptable latency.
ImplementationOffer stores what an implementation actually provides.
The relation is:
pub fn relates(requirement: DesignRequirement, offer: ImplementationOffer) -> bool {
offer.throughput >= requirement.minimum_throughput
&& offer.latency <= requirement.maximum_latency
}
ML Or Software Concept
This models feasibility:
Can this model/service/deployment satisfy this requirement?
For example:
required throughput: at least 100 requests/sec
required latency: at most 80 ms
offer: 120 requests/sec and 50 ms
The offer is feasible.
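The worked numbers can be checked directly. This sketch reuses the struct and method names from the text. The plain u32 fields without validation, and the requests-per-second unit for Throughput, are assumptions of the sketch rather than the companion's definitions.

```rust
/// Assumed unit: requests per second. Not validated in this sketch.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct Throughput(u32);

/// Latency in milliseconds. Not validated in this sketch.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct LatencyMs(u32);

#[derive(Debug, Clone, Copy)]
pub struct DesignRequirement {
    pub minimum_throughput: Throughput,
    pub maximum_latency: LatencyMs,
}

#[derive(Debug, Clone, Copy)]
pub struct ImplementationOffer {
    pub throughput: Throughput,
    pub latency: LatencyMs,
}

pub struct FeasibilityRelation;

impl FeasibilityRelation {
    /// True when the offer is fast enough and responsive enough.
    pub fn relates(requirement: DesignRequirement, offer: ImplementationOffer) -> bool {
        offer.throughput >= requirement.minimum_throughput
            && offer.latency <= requirement.maximum_latency
    }
}
```

With a requirement of at least 100 requests/sec and at most 80 ms, an offer of 120 requests/sec at 50 ms is related, while the same throughput at 95 ms is not.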
Category Theory Concept
This is relation-shaped rather than function-shaped.
It is the small Bool-valued version of profunctor-like reasoning:
Requirement x Offer -> Bool
Not every design problem should be forced into:
A -> B
Some should be modeled as constraints or relations.
This also gives a small handle for reading newer categorical deep-learning work. At architecture scale, a model can be discussed in terms of constraints it should satisfy and implementations that realize those constraints. This Rust sketch keeps only the first tiny version of that idea:
DesignRequirement x ImplementationOffer -> bool
The important distinction is:
| Question | Co-design sketch answer |
|---|---|
| What should be true? | throughput is high enough and latency is low enough |
| What implementation evidence do we have? | one ImplementationOffer value with checked throughput and latency |
| What does the relation decide? | whether that offer satisfies that requirement |
That is not a full theory of neural architectures. It is the learner-sized boundary that prevents a common overclaim: one implementation example is not the same thing as the whole constraint space.
Transfer Task: Feasibility Relation
Model a relation that is not a function:
Requirement x Offer -> Bool
For example, use:
minimum accuracy
maximum memory
maximum latency
Then write the architecture version:
ArchitectureConstraint x CandidateImplementation -> Bool
The transfer is complete when you can explain why many offers may satisfy one requirement, one offer may satisfy many requirements, and one passing offer is not proof that every future implementation satisfies the architecture constraint.
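One way to see the many-to-many shape is to tabulate the relation as a list of feasible pairs. This is a hypothetical transfer-task sketch: the type names, accuracy in basis points (10_000 = 100%), memory in MiB, and latency in milliseconds are all invented for illustration.

```rust
/// Hypothetical requirement for the transfer task.
#[derive(Debug, Clone, Copy)]
pub struct Requirement {
    pub min_accuracy_bp: u32,
    pub max_memory_mib: u32,
    pub max_latency_ms: u32,
}

/// What one candidate implementation delivers, in the same units.
#[derive(Debug, Clone, Copy)]
pub struct Offer {
    pub accuracy_bp: u32,
    pub memory_mib: u32,
    pub latency_ms: u32,
}

/// The relation Requirement x Offer -> bool.
pub fn satisfies(requirement: &Requirement, offer: &Offer) -> bool {
    offer.accuracy_bp >= requirement.min_accuracy_bp
        && offer.memory_mib <= requirement.max_memory_mib
        && offer.latency_ms <= requirement.max_latency_ms
}

/// Tabulates the relation as (requirement index, offer index) pairs.
pub fn feasible_pairs(requirements: &[Requirement], offers: &[Offer]) -> Vec<(usize, usize)> {
    let mut pairs = Vec::new();
    for (r, requirement) in requirements.iter().enumerate() {
        for (o, offer) in offers.iter().enumerate() {
            if satisfies(requirement, offer) {
                pairs.push((r, o));
            }
        }
    }
    pairs
}
```

With one strict and one loose requirement, an offer that passes the strict one also passes the loose one, and the loose one may accept several offers. The output is a set of pairs, not a function from requirements to offers. It also says nothing about offers that have not been listed yet.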
Sketch 5: Signal Matrices
The problem this block solves is:
Signal-flow diagrams need executable semantics. In this companion, matrices provide that meaning.
The key types:
pub struct SignalCoefficient(i32);
pub struct MatrixRows(usize);
pub struct MatrixCols(usize);
pub struct SignalMatrix {
rows: MatrixRows,
cols: MatrixCols,
coefficients: Vec<Vec<SignalCoefficient>>,
}
Rust Syntax
MatrixRows and MatrixCols reject zero.
SignalMatrix::new validates that the coefficient matrix has the promised
shape:
number of rows matches MatrixRows
number of columns in every row matches MatrixCols
The composition method:
pub fn compose_after(&self, previous: &Self) -> CtResult<Self>
requires compatible middle dimensions.
Then it performs matrix multiplication using:
add
multiply
sum over the middle dimension
ML Or Software Concept
This is the same shape as composing linear layers or signal-processing stages.
If one stage maps:
A -> B
and another maps:
B -> C
then the composite maps:
A -> C
The dimensions must line up.
Category Theory Concept
Signal-flow syntax gets matrix semantics.
The important principle is functorial semantics:
meaning(composed diagram)
=
composition of meanings
The code enforces the same middle-dimension law that ordinary morphism composition enforces.
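A minimal sketch of shape-checked composition. It assumes a convention not stated in the text: a matrix with rows outputs and cols inputs, so `next.compose_after(&previous)` needs `next.cols == previous.rows`. It uses plain i32 coefficients and a hypothetical MatrixError in place of SignalCoefficient and CtError.

```rust
/// Hypothetical error type for this sketch.
#[derive(Debug, PartialEq, Eq)]
pub enum MatrixError {
    EmptyShape,
    RaggedRow,
    MiddleDimensionMismatch { left_cols: usize, right_rows: usize },
}

/// Assumed convention: `rows` outputs, `cols` inputs.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SignalMatrix {
    rows: usize,
    cols: usize,
    coefficients: Vec<Vec<i32>>,
}

impl SignalMatrix {
    /// Rejects zero dimensions and rows of unequal length.
    pub fn new(coefficients: Vec<Vec<i32>>) -> Result<Self, MatrixError> {
        let rows = coefficients.len();
        let cols = coefficients.first().map_or(0, |row| row.len());
        if rows == 0 || cols == 0 {
            return Err(MatrixError::EmptyShape);
        }
        if coefficients.iter().any(|row| row.len() != cols) {
            return Err(MatrixError::RaggedRow);
        }
        Ok(Self { rows, cols, coefficients })
    }

    /// Runs `previous` first, then `self`. The middle dimension must match.
    pub fn compose_after(&self, previous: &Self) -> Result<Self, MatrixError> {
        if self.cols != previous.rows {
            return Err(MatrixError::MiddleDimensionMismatch {
                left_cols: self.cols,
                right_rows: previous.rows,
            });
        }
        // Entry (i, j) multiplies row i of self by column j of previous and sums
        // over the shared middle dimension k.
        let coefficients: Vec<Vec<i32>> = (0..self.rows)
            .map(|i| {
                (0..previous.cols)
                    .map(|j| {
                        (0..self.cols)
                            .map(|k| self.coefficients[i][k] * previous.coefficients[k][j])
                            .sum::<i32>()
                    })
                    .collect::<Vec<i32>>()
            })
            .collect();
        Ok(Self { rows: self.rows, cols: previous.cols, coefficients })
    }

    pub fn coefficients(&self) -> &[Vec<i32>] {
        &self.coefficients
    }
}
```

With previous = [[1], [2]] (one input, two outputs) and next = [[1, 2]] (two inputs, one output), the composite is [[5]]. Replacing previous with a three-output stage fails before any multiplication runs.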
Transfer Task: Shape-Safe Composition
Write the shape of two stages before writing any coefficients:
A -> B
B -> C
Then write one invalid pair:
A -> B
D -> C
The transfer is complete when you can name the exact middle dimension that must match for matrix composition to make sense.
Sketch 6: Open Circuits
The problem this block solves is:
A circuit is not only internal components. It also has a boundary where it can connect to other circuits.
The key types:
pub struct PortName(&'static str);
pub struct ResistanceOhms(usize);
pub struct CircuitComponent {
from: PortName,
to: PortName,
resistance: ResistanceOhms,
}
pub struct OpenCircuit {
inputs: Vec<PortName>,
outputs: Vec<PortName>,
components: Vec<CircuitComponent>,
}
Rust Syntax
PortName::new rejects empty names.
ResistanceOhms::new rejects zero resistance.
OpenCircuit::new rejects circuits with no inputs or no outputs.
Serial composition:
pub fn then(&self, next: &Self) -> CtResult<Self>
checks:
self output count == next input count
Parallel composition:
pub fn parallel(&self, other: &Self) -> CtResult<Self>
puts the two boundaries side by side.
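A minimal sketch of serial and parallel composition that checks only boundary counts, as described above. It simplifies PortName and ResistanceOhms to plain values without their constructor checks, uses a hypothetical CircuitError, and does not unify port names across the glued boundary.

```rust
/// A resistor between two named ports (simplified, unvalidated).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CircuitComponent {
    pub from: &'static str,
    pub to: &'static str,
    pub resistance_ohms: usize,
}

/// Hypothetical error type for this sketch.
#[derive(Debug, PartialEq, Eq)]
pub enum CircuitError {
    EmptyBoundary,
    BoundaryMismatch { outputs: usize, inputs: usize },
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OpenCircuit {
    inputs: Vec<&'static str>,
    outputs: Vec<&'static str>,
    components: Vec<CircuitComponent>,
}

impl OpenCircuit {
    /// Rejects circuits with no inputs or no outputs.
    pub fn new(
        inputs: Vec<&'static str>,
        outputs: Vec<&'static str>,
        components: Vec<CircuitComponent>,
    ) -> Result<Self, CircuitError> {
        if inputs.is_empty() || outputs.is_empty() {
            return Err(CircuitError::EmptyBoundary);
        }
        Ok(Self { inputs, outputs, components })
    }

    /// Serial composition: self's outputs are glued to next's inputs.
    pub fn then(&self, next: &Self) -> Result<Self, CircuitError> {
        if self.outputs.len() != next.inputs.len() {
            return Err(CircuitError::BoundaryMismatch {
                outputs: self.outputs.len(),
                inputs: next.inputs.len(),
            });
        }
        let mut components = self.components.clone();
        components.extend(next.components.iter().cloned());
        Ok(Self {
            inputs: self.inputs.clone(),
            outputs: next.outputs.clone(),
            components,
        })
    }

    /// Parallel composition: both boundaries sit side by side.
    pub fn parallel(&self, other: &Self) -> Result<Self, CircuitError> {
        let mut inputs = self.inputs.clone();
        inputs.extend(other.inputs.iter().copied());
        let mut outputs = self.outputs.clone();
        outputs.extend(other.outputs.iter().copied());
        let mut components = self.components.clone();
        components.extend(other.components.iter().cloned());
        Ok(Self { inputs, outputs, components })
    }

    pub fn component_count(&self) -> usize {
        self.components.len()
    }

    pub fn input_count(&self) -> usize {
        self.inputs.len()
    }
}
```

Two one-resistor circuits compose serially into a circuit with two components. Placing them in parallel produces a two-input boundary, which can no longer follow a one-output circuit.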
ML Or Software Concept
This looks like component architecture:
input interface
internal implementation
output interface
Composition should fail when interfaces do not match.
That rule applies to services, data pipelines, neural layers, and circuit-like systems.
Category Theory Concept
This is the open-system idea.
The boundary is part of the object.
Composition is controlled by boundary compatibility.
Seven Sketches develops this idea with cospans, hypergraph categories, decorated cospans, and operads. The Rust code gives a small typed analogue.
Transfer Task: Interface Boundary
Model two components by boundary only:
Tokenizer: Text -> TokenSequence
Embedder: TokenSequence -> HiddenSequence
Then write one invalid serial composition:
Tokenizer: Text -> TokenSequence
Classifier: Logits -> Label
The transfer is complete when you can say which output boundary failed to match which input boundary.
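In Rust, the boundary check for this transfer task can be pushed into the type system. The following is a sketch in which every type and function is hypothetical, and the tokenizer and embedder bodies are toys. A generic then only type-checks when the first component's output type equals the second component's input type.

```rust
/// Hypothetical boundary objects, each a distinct type.
pub struct Text(pub String);
pub struct TokenSequence(pub Vec<u32>);
pub struct HiddenSequence(pub Vec<Vec<f32>>);
pub struct Logits(pub Vec<f32>);
pub struct Label(pub usize);

/// Toy tokenizer: one token per whitespace-separated word (its length).
pub fn tokenizer(text: Text) -> TokenSequence {
    TokenSequence(text.0.split_whitespace().map(|w| w.len() as u32).collect())
}

/// Toy embedder: each token becomes a 2-dimensional vector.
pub fn embedder(tokens: TokenSequence) -> HiddenSequence {
    HiddenSequence(tokens.0.iter().map(|&t| vec![t as f32, 1.0]).collect())
}

/// Toy classifier: index of the largest logit.
pub fn classifier(logits: Logits) -> Label {
    let mut best = 0;
    for (i, &value) in logits.0.iter().enumerate() {
        if value > logits.0[best] {
            best = i;
        }
    }
    Label(best)
}

/// Serial composition: legal only when f's output type is g's input type.
pub fn then<A, B, C>(f: impl Fn(A) -> B, g: impl Fn(B) -> C) -> impl Fn(A) -> C {
    move |a| g(f(a))
}

// then(tokenizer, classifier) does not compile: the output boundary
// TokenSequence does not match the input boundary Logits.
```

Here the mismatch is reported by the compiler, not at run time. This is stricter than OpenCircuit::then, which compares boundary counts when the program runs.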
Sketch 7: Logic Of Behavior
The problem this block solves is:
A system may be safe on local time intervals. The code needs a way to combine local safety checks into one global result.
The key types:
pub enum TruthValue {
False,
True,
}
pub struct TimeTick(usize);
pub struct TimeInterval {
start: TimeTick,
end: TimeTick,
}
pub struct LocalSafetyCheck {
interval: TimeInterval,
truth: TruthValue,
}
pub struct SafetyCover(Vec<LocalSafetyCheck>);
Rust Syntax
TruthValue implements Boolean-style operations:
and
implies
TimeInterval::new rejects intervals where start is after end.
SafetyCover::new rejects an empty list of checks.
global_truth folds all local truths with and:
self.0
.iter()
.fold(TruthValue::True, |truth, check| truth.and(check.truth()))
ML Or Software Concept
This models safety or behavior validation over time:
check interval 0..5
check interval 5..10
combine into global result
If every local check is true, global truth is true.
If any local check is false, global truth is false.
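A minimal sketch of the local-to-global fold, with the constructor checks described above. It replaces TimeTick with plain usize and uses String errors. Both choices are simplifications for the sketch, not the companion's definitions.

```rust
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TruthValue {
    False,
    True,
}

impl TruthValue {
    pub fn and(self, other: Self) -> Self {
        match (self, other) {
            (TruthValue::True, TruthValue::True) => TruthValue::True,
            _ => TruthValue::False,
        }
    }

    /// Material implication: only True -> False is False.
    pub fn implies(self, other: Self) -> Self {
        match (self, other) {
            (TruthValue::True, TruthValue::False) => TruthValue::False,
            _ => TruthValue::True,
        }
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TimeInterval {
    start: usize,
    end: usize,
}

impl TimeInterval {
    /// Rejects intervals whose start comes after their end.
    pub fn new(start: usize, end: usize) -> Result<Self, String> {
        if start > end {
            Err(format!("interval start {start} is after end {end}"))
        } else {
            Ok(Self { start, end })
        }
    }
}

#[derive(Debug, Clone, Copy)]
pub struct LocalSafetyCheck {
    pub interval: TimeInterval,
    pub truth: TruthValue,
}

pub struct SafetyCover(Vec<LocalSafetyCheck>);

impl SafetyCover {
    /// Rejects an empty cover: no local evidence means no global claim.
    pub fn new(checks: Vec<LocalSafetyCheck>) -> Result<Self, String> {
        if checks.is_empty() {
            Err("a cover needs at least one local check".to_string())
        } else {
            Ok(Self(checks))
        }
    }

    /// Conjunction of every local truth value.
    pub fn global_truth(&self) -> TruthValue {
        self.0
            .iter()
            .fold(TruthValue::True, |truth, check| truth.and(check.truth))
    }
}
```

Since the fold starts at True and uses and, a single False anywhere in the cover makes the global result False.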
Category Theory Concept
This is a small analogue of local-to-global reasoning.
The sheaf-like idea is:
local facts can determine a global fact when they glue coherently
The code uses a simple conjunction model, not full sheaf theory.
The important lesson is that proof-like information becomes explicit data, not an informal comment.
Transfer Task: Local Checks To Global Claim
Pick one invariant over time:
loss is finite
latency stays below the budget
no batch has an empty token sequence
Split it into local intervals and assign a truth value to each interval.
The transfer is complete when you can explain why one false local check must make the global claim false.
Tests As Exercise Solutions
The problem this block solves is:
The laws should be runnable, not only described in prose.
The test module checks preorder laws, the feature/layer Galois law, resource tensor monotonicity, database foreign-key resolution, feasibility relation behavior, signal matrix composition, open circuit serial and parallel composition, and local-to-global truth.
It also includes negative boundary tests. A database instance with a missing department reference is rejected. Signal matrices with incompatible middle dimensions cannot compose. Open circuits with mismatched serial boundaries do not wire together. Those failures are part of the teaching point: the model is useful because it rejects incoherent structure near the boundary.
Rust Syntax
Every law is a normal Rust test marked with:
#[test]
Tests that may fail through constructors return:
CtResult<()>
so they can use ?.
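Rust's test harness accepts test functions that return a Result whose error type implements Debug. This sketch shows that pattern next to a negative boundary test. SketchError, SketchResult, and this LayerBudget are hypothetical stand-ins for the companion's CtError, CtResult, and types.

```rust
/// Hypothetical stand-ins for the companion's error and result types.
#[derive(Debug, PartialEq, Eq)]
pub enum SketchError {
    ZeroValue(&'static str),
}

pub type SketchResult<T> = Result<T, SketchError>;

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LayerBudget(usize);

impl LayerBudget {
    /// Rejects zero at the constructor boundary.
    pub fn new(value: usize) -> SketchResult<Self> {
        if value == 0 {
            Err(SketchError::ZeroValue("layer budget"))
        } else {
            Ok(Self(value))
        }
    }

    pub fn value(self) -> usize {
        self.0
    }
}

#[cfg(test)]
mod layer_budget_tests {
    use super::*;

    // Returning a Result lets `?` turn a constructor error into a test failure.
    #[test]
    fn positive_budget_is_accepted() -> SketchResult<()> {
        let budget = LayerBudget::new(3)?;
        assert_eq!(budget.value(), 3);
        Ok(())
    }

    // Negative boundary test: the constructor must refuse zero.
    #[test]
    fn zero_budget_is_rejected() {
        assert_eq!(LayerBudget::new(0), Err(SketchError::ZeroValue("layer budget")));
    }
}
```

The positive test fails if the constructor ever starts rejecting valid input. The negative test fails if it ever starts accepting the invalid state.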
ML Or Software Concept
The tests act as executable learning checks.
If a future change breaks a law, the project should fail quickly.
Category Theory Concept
The tests are small law checks.
They are not formal proofs, but they keep the implementation aligned with the claimed structure.
Run The Companion
Run:
cargo run --example 05_seven_sketches
The output gives one executable handle per sketch:
orders obey preorder laws: true
feature/layer Galois law: true
resource tensor monotone: true
employee EmployeeId(7) belongs to department Some(DepartmentId(1))
co-design offer feasible: true
signal-flow matrix semantics: [[SignalCoefficient(5)]]
serial circuit component count: 2
global behavior truth: True
Typed transformation:
InformationLevel <= InformationLevel checks preorder
FeatureCount <-> LayerBudget checks Galois law
ResourceBundle x ResourceBundle -> ResourceBundle
EmployeeRecord -> DepartmentId must resolve in CompanyInstance
DesignRequirement x ImplementationOffer -> bool
SignalMatrix x SignalMatrix -> SignalMatrix when dimensions match
OpenCircuit x OpenCircuit -> OpenCircuit when ports match
SafetyCover -> TruthValue
Example Output Transfer Checklist
Use the companion output as a boundary map. Each line should tell you which structure is being protected.
| Example output | Rust handle | Protected structure | Shortcut to reject |
|---|---|---|---|
| orders obey preorder laws: true | InformationLevel | information can flow upward through an ordered refinement path | treating an enum order as arbitrary display order |
| feature/layer Galois law: true | FeatureCount, LayerBudget | concrete and abstract capacity views agree by a two-way fit law | treating abstraction and concretization as inverse functions |
| resource tensor monotone: true | ResourceBundle::tensor | adding resources componentwise preserves supply ordering | combining compute and memory as one raw number |
| employee ... belongs to department ... | CompanyInstance::new | every employee department reference resolves | letting a missing foreign key reach later feature extraction |
| co-design offer feasible: true | FeasibilityRelation::relates | requirements and offers form a relation, not a single function | forcing every design question into A -> B |
| signal-flow matrix semantics: ... | SignalMatrix::compose_after | matrix meanings compose only when middle dimensions match | multiplying stages before checking shape |
| serial circuit component count: 2 | OpenCircuit::then | serial composition requires matching boundaries | wiring components by name alone |
| global behavior truth: True | SafetyCover::global_truth | local checks combine into a global claim | claiming global safety while one local interval is false |
The typed lines at the bottom of the output are not extra decoration. They name the form each sketch protects: order, Galois law, monoidal resource composition, database instance, feasibility relation, matrix composition, open system composition, and local-to-global truth.
For the full validation gate:
cargo test --all-targets --all-features
Core Mental Model
In Rust terms:
each sketch becomes concrete types, constructors, methods, and tests
In ML or software terms:
orders, resources, schemas, feasibility, signal flow, interfaces, and safety
are all engineering structures
In category-theory terms:
the useful part is compositionality: make structure visible, then make
composition obey laws
What To Remember
The seven sketches are not seven disconnected topics.
They repeat one engineering discipline:
name the objects
name the relationships
control construction
define composition
check the law
That is also the discipline used by the tiny ML pipeline in the rest of the course.
Where This Leaves Us
This chapter showed that the book’s main discipline is not limited to language modeling. Orders, resources, database instances, feasibility relations, signal matrices, open circuits, and local behavior checks can all be read in the same way: name the values, constrain construction, define composition, and test the law that makes the composition trustworthy.
The remaining practice material in Exercises asks you to use that reading method yourself. The exercises are not meant to test memorized definitions. They are meant to train the habit of translating one Rust block into its software role and its categorical shape.
Further Reading
Do not treat these sources as a separate theory shelf. Use them to improve one local Rust sentence at a time.
Start from this local evidence:
cargo run --example 05_seven_sketches
cargo test sketches::tests --lib
src/sketches.rs
examples/05_seven_sketches.rs
Then read the sources in this order:
| Source | What to transfer back into this chapter | Local evidence to inspect |
|---|---|---|
| Seven Sketches | Each sketch is one compositional structure: orders, resources, schemas, co-design, signal flow, open systems, or local-to-global behavior. | Paper Map To Rust, Source Scope Contract, Example Output Transfer Checklist |
| MIT Applied Category Theory OCW | The seven topic areas form a study sequence, not a bag of unrelated examples. | cargo run --example 05_seven_sketches output order |
| Category Theory for Programming | Programming examples should carry the vocabulary before the formal name is emphasized. | InformationLevel, ResourceBundle, SignalMatrix, OpenCircuit |
| Compositional Deep Learning | Categorical schemas and composition invariants can appear in neural-network settings under explicit assumptions. | CompanyInstance, FeasibilityRelation, database_instance_rejects_missing_department_reference |
| Categorical Deep Learning | Architecture-level claims should separate constraints from implementations that realize them. | DesignRequirement x ImplementationOffer -> bool |
| Rust Book: Enums | Finite domains are often clearer as enums than as unstructured numbers. | InformationLevel, TruthValue |
| Rust Book: Traits | Shared behavior should become an explicit contract before laws are tested around it. | SignalMatrix::compose_after, OpenCircuit::then, SafetyCover::global_truth |
After reading one source, answer four questions:
- Which Rust handle did it clarify?
- Which law, relation, or boundary did it support?
- Which larger claim did the source not license?
- Which command or test shows the local evidence?
For this chapter, the commands are:
cargo run --example 05_seven_sketches
cargo test sketches::tests --lib
For terminology recovery, use:
- References: paper links and supporting Rust/materials
- Glossary: terms used by the course
- Repository Source Snapshots: complete source files
If a source does not help you name a Rust handle, a law or boundary, a non-claim, and an evidence command, it has not transferred back into this chapter yet.
Practice After This Chapter
Use Exercise 10 to test one sketch law or inspect one negative test. The goal is to choose a structure, name the invalid state it rejects, and explain the result through Rust syntax, software or ML meaning, and category-theory shape.
Retrieval Practice
Recall
Recover the chapter map before choosing a transfer example.
Name the sketch that models ordered refinement from observation to decision.
Name the sketch that rejects an employee row whose department is missing.
Name the sketch that rejects matrix composition when the middle dimensions do not match.
Name the sketch that rejects serial component composition when output and input boundaries do not match.
Explain
Explain the protected boundary.
For InformationLevel, explain why Observation can flow to Decision but a
decision should not silently flow backward to an observation.
For CompanyInstance, explain why missing department references should be
rejected at construction time instead of during later feature extraction.
For SignalMatrix, explain why the middle dimension must match before matrix
composition can run.
For OpenCircuit, explain why a boundary mismatch is an interface error, not a
numeric error.
Apply
Use the companion output as evidence.
Among other lines, the example prints:
orders obey preorder laws: true
feature/layer Galois law: true
resource tensor monotone: true
signal-flow matrix semantics: [[SignalCoefficient(5)]]
serial circuit component count: 2
global behavior truth: True
Choose two printed lines. For each one, write:
Rust value or function:
software meaning:
category-theory shape:
invalid structure prevented or law checked:
Then pick one engineering concept from your own work, such as permissions, queues, schema references, resource budgets, or service interfaces. Describe the objects, relationships, and one law or boundary test you would want the code to check.
Debug
For each broken model, name the missing structure:
using raw usize for both EmployeeId and DepartmentId
letting a missing department reference enter the training data
composing signal matrices without checking the middle dimension
serially wiring components by name without checking boundary counts
claiming global safety when one local interval is false
A strong answer should use the chapter’s repeated method:
name the objects
name the relationships
control construction
define composition
check the law
The goal is not to memorize every sketch. The goal is to recognize what each model refuses to let pass as valid structure.
Exercises
The problem this chapter solves is:
Reading detailed explanations is not enough. You need practice explaining the same code through its Rust syntax, its ML concept, and its category-theory concept.
The exercises are deliberately small. A strong answer is not a long essay; it is a precise explanation that connects a line of Rust to the value it protects, the ML step it supports, and the categorical shape it names. When an exercise asks you to edit code, make the smallest change, run the command, and then explain what changed.
For every exercise, use this answer shape:
Rust syntax:
...
ML concept:
...
Category theory concept:
...
The point is not to write long answers.
The point is to connect the same block of code across all three meanings.
The exercise method is:
read one small idea
run the matching command
break one boundary on purpose
explain the failure
restore the working version
This matters because the Rust compiler and the test suite are part of the lesson. The official Rust testing material treats tests as executable checks for expected behavior. This book uses the same habit for learning: a failed test, rejected constructor, or compiler error is not only a problem to remove. It is evidence about which boundary the code protects.
Source-Backed Practice Contract
This chapter uses sources to keep practice cumulative, testable, and transfer-oriented. Each source supports one local exercise rule and one kind of repository evidence.
| Source | What the source supports | Local rule in this chapter | Repository evidence |
|---|---|---|---|
| How People Learn II | Learners need practice that connects prior knowledge to new transfer situations. | Move from one small Rust boundary to a new but related boundary. | TokenId to Distribution, then TrainStep, then attention shapes |
| Test-Enhanced Learning | Retrieval practice can improve retention rather than only measure it. | Ask Recall, Explain, and Apply questions before the answer key. | ## Checkpoint Quiz, ## Retrieval Practice, exercises/ANSWER_KEY.md |
| Structuring the Transition From Example Study to Problem Solving | Learners benefit from moving from worked examples toward independent problem solving. | Use a worked example, then a partially completed example, then an open transfer exercise. | ## Worked Example, ## Partially Completed Example, ## Transfer Exercise |
| Rust Book: Writing Automated Tests | Tests check expected behavior that the type system alone cannot prove. | Treat tests, constructor errors, and compiler errors as learning evidence. | cargo test --all-targets --all-features, domain::tests, category::tests, ml::tests |
| Rust By Example: Tests | Small test commands and targeted test names make feedback inspectable. | Prefer one named command and one visible signal per exercise attempt. | cargo test structure::tests --lib, cargo test cross_entropy_is_lower_for_more_confident_target_probability --lib |
| CS231n Optimization and PyTorch gradcheck | Numerical gradient checks are useful local debugging signals with limits. | Use finite differences to compare one local update path, then state what the check does not prove. | Advanced Exercise 5, TransformerBlockTrainStep finite-difference tests |
The transfer pattern is:
worked example -> partial example -> independent attempt -> evidence signal
For this chapter, evidence means one of:
command output
constructor error
compiler error
named test result
answer-key mismatch
None of this shows yet that every exercise works for every reader. Direct exercise-attempt reports are still needed before the exercise ladder can be called fully validated.
Before starting, make sure the basic Rust feedback loop works:
cargo test --all-targets --all-features
That command is part of the learning method. It shows that the examples in the book are not only explanatory text; they are tied to code that the compiler can check.
After attempting an exercise, compare your reasoning with the public answer key
in exercises/ANSWER_KEY.md. Use it to check the shape of the explanation, not
to memorize wording.
Exercise Ladder
Use the exercises in this order:
| Stage | File or chapter | What you practice |
|---|---|---|
| Beginner | exercises/beginner/README.md | Change inputs, observe output, name one invariant |
| Core | this chapter | Explain each concept through Rust, ML, and category theory |
| Intermediate | exercises/intermediate/README.md | Add one morphism and explain one composition failure |
| Advanced | exercises/advanced/README.md | Extend a chapter, diagram, law, or sketch test |
Do not skip the small exercises. How People Learn II emphasizes that learners
need to retrieve and use knowledge in new situations. In this course, transfer
means taking the same explanation method from TokenId to Distribution, then
from Distribution to training, and then from training to the seven applied
sketches.
Core Chapter Practice Map
Use this map when you finish a chapter and want the matching practice task.
| Chapter | Practice target | Best exercise |
|---|---|---|
| Welcome | Explain the three-lens reading contract | Beginner Exercise 3 |
| Course Map | Connect terminal output to pipeline stages | Exercise 2 and Exercise 8 |
| Domain Objects | Explain wrappers, invariants, and typed objects | Exercise 1 and Exercise 7 |
| Morphism and Composition | Explain legal and illegal composition | Exercise 4 and Exercise 17 |
| The Tiny ML Pipeline | Trace adjacent pairs, prediction, and loss | Exercise 3, Exercise 9, Exercise 13, and Exercise 17 |
| Training as an Endomorphism | Explain repeated Parameters -> Parameters updates | Exercise 5 and Exercise 17 |
| Functors, Naturality, Monoids, and Chain Rule | Explain mapping, laws, traces, and local gradients | Exercise 6, Exercise 14, and Exercise 17 |
| Seven Sketches Through Rust | Identify the law or boundary a structure protects | Exercise 10 |
| Challenges | Turn one compiler-fix or paper-to-code task into evidence | Challenge completion report |
| Transformer Roadmap | Trace attention shapes, classify category shapes, and explain finite-difference checks for structured training state | Exercise 12, Exercise 16, Exercise 17, and Advanced Exercise 5 |
The map is not a separate syllabus. It is a repair tool. If a chapter feels clear while reading but vague one hour later, use the matching exercise to make the idea active again.
Chapter Mastery Gates
Use these gates before moving from a chapter into later material. A gate is not a grade. It is a quick test of whether the idea is active enough to reuse.
| Chapter | Run evidence | Explain evidence | Transfer evidence |
|---|---|---|---|
| Welcome | cargo run --example 01_token_sequence | state the three-lens reading contract without looking back | explain one output line through Rust, ML, and category theory |
| Course Map | cargo run --bin category_ml | name the file or module behind three printed sections | choose the next chapter and matching exercise from the output |
| Domain Objects | cargo run --example 01_domain_objects | explain one constructor invariant and the bad state it rejects | replace one raw value in an explanation with its domain type |
| Morphism and Composition | cargo run --example 02_morphism_composition | name every middle object in TokenId -> Vector -> Logits -> Distribution | explain one illegal skipped stage and the missing object |
| The Tiny ML Pipeline | cargo test ml::tests --lib | separate logits, probabilities, target token, and loss | compute which prediction should have lower cross-entropy |
| Training as an Endomorphism | cargo run --example 03_training_endomorphism | explain why one update has shape Parameters -> Parameters | predict what breaks if an update returns only a loose changed field |
| Functors, Naturality, Monoids, and Chain Rule | cargo run --example 04_structure_and_calculus | explain one law by tracing both sides of the example | classify a new trace, option, vector, or derivative example |
| Seven Sketches Through Rust | cargo run --example 05_seven_sketches | identify the relation, order, schema, circuit, or cover being protected | model one analogous boundary in a small software system |
| Transformer Roadmap | cargo run --example 06_attention_scores and cargo run --example 07_transformer_training_state | classify attention boundaries by input count and output object | reject one illegal shortcut such as HiddenSequence x MultiHeadOutput -> HiddenSequence |
If a gate fails, do not reread the whole chapter first. Start with the matching exercise, inspect the failure signal, and compare your reasoning with the answer-key rubric. The smallest useful repair is usually one missing object, one missing command, or one missing distinction.
Checkpoint Quiz
Use this after the mastery gates. Answer from memory first, then check the answer key. The goal is not vocabulary recall alone. The goal is to notice whether you can connect a Rust boundary, an ML role, and a category-theory shape without the chapter open.
Questions
Write one or two sentences for each question.
- A value has type TokenId. What mistake becomes harder to make than if the same value crossed the boundary as a raw usize?
- The path TokenId -> Vector -> Logits -> Distribution fails if the middle Logits stage is skipped. What Rust evidence and ML evidence explain the failure?
- A model gives the target token probability 0.9 in one case and 0.1 in another. Which case should have lower cross-entropy, and why?
- A training update changes weights but returns only the changed readout matrix. Which composition shape has been broken?
- VecFunctor::fmap maps every element and OptionFunctor::fmap maps only when a value is present. What does that preserve?
- A naturality square has two paths from Vec<A> to Option<B>. What should be true if the square commutes?
- AttentionScores x AttentionMask -> AttentionScores returns the score object. Why is this still not a unary endomorphism?
- Why must the attention mask act before row-wise softmax?
- HiddenSequence x MultiHeadOutput -> HiddenSequence looks tempting after concatenating heads. Which missing boundary makes it illegal?
- A finite-difference test agrees with the inferred gradient for one parameter. What has it checked, and what has it not checked?
Coverage Map
| Question | Chapter or section | Main objective |
|---|---|---|
| 1 | Domain Objects | explain why a wrapper protects a domain role |
| 2 | Morphism and Composition | identify a missing middle object |
| 3 | Tiny ML Pipeline | connect target probability to loss |
| 4 | Training as an Endomorphism | preserve state-update composition |
| 5 | Structure and Laws | explain structure-preserving mapping |
| 6 | Structure and Laws | trace both paths through a naturality square |
| 7 | Transformer Roadmap | count inputs before naming an endomorphism |
| 8 | Transformer Roadmap | separate masked scores from weights |
| 9 | Transformer Roadmap | identify a missing projection boundary |
| 10 | Exercises and Transformer Roadmap | state the scope of a local gradient check |
Score the quiz by evidence, not points. A strong answer names the object or boundary, explains the ML or software role, and rejects one invalid shortcut. If an answer only repeats a term, return to the matching exercise.
Failure Signals
A good exercise often fails before it works. Use the failure signal as part of the answer.
| Signal | Usually means | What to explain |
|---|---|---|
| Compiler type error | two stages do not connect | the missing middle object |
| Constructor returns Err(...) | a value violates an invariant | the bad state rejected at the boundary |
| Test assertion fails | the behavior no longer matches the law | which example stopped preserving the intended structure |
| Command output changes | the data path changed | which typed value moved differently through the pipeline |
When an exercise asks you to break something, do it in a small local edit and then restore the working version. The final repository should still pass the validation commands.
Exercise Evidence Map
Use this table before checking the answer key. It tells you what kind of evidence should exist when an exercise is complete.
| Exercise | Progress evidence | Failure or output to inspect |
|---|---|---|
| Exercise 1 | written three-lens explanation | raw representation, invariant, and pipeline stage are all named |
| Exercise 2 | cargo run --bin category_ml | terminal output includes the new adjacent transition |
| Exercise 3 | handwritten adjacent pairs | three overlapping TokenId pairs are present |
| Exercise 4 | temporary broken composition | compiler reports a missing trait bound or middle object |
| Exercise 5 | cargo run --example 03_training_endomorphism | loss output changes as StepCount changes |
| Exercise 6 | rewritten output distribution | probabilities stay attached to transformed outcomes |
| Exercise 7 | constructor boundary explanation | Err(...) is connected to the invalid value |
| Exercise 8 | five-sentence file summary | one command is named as the proof that the file still works |
| Exercise 9 | source-role comparison | one external resource is connected to one local source file, one owned boundary, and one unsupported claim |
| Exercise 10 | cargo run --example 05_seven_sketches or a negative test | one law still holds, or one invalid structure is rejected |
| Exercise 11 | block explanation | a beginner-facing Rust explanation and a shape name are both present |
| Exercise 12 | cargo run --example 06_attention_scores | first output line and category shape for each attention boundary are recorded |
| Exercise 13 | cargo test cross_entropy_is_lower_for_more_confident_target_probability --lib | lower loss is assigned to the higher target probability |
| Exercise 14 | cargo test structure::tests --lib | naturality paths and monoid laws are both named |
| Exercise 15 | mixed boundary diagnosis | each failure is classified as an invariant, composition, endomorphism, shape, or local-to-global boundary |
| Exercise 16 | cargo run --example 07_transformer_training_state | three different updates preserve TransformerTrainingState -> TransformerTrainingState |
| Exercise 17 | diagram reconstruction sheet | objects, arrows, paths, Rust handles, and safe non-claims are all labeled |
This is not extra bureaucracy. Rustlings-style practice works because the learner gets a concrete feedback signal. This course uses the same idea: command output, a constructor error, a compiler error, or a named test should tell you whether the concept is becoming executable.
Worked Example: Mixed Boundary Diagnosis
Before solving Exercise 15, study one complete diagnosis. The case is:
CrossEntropy receives Logits instead of Product<Distribution, TokenId>.
A weak answer says:
The types are wrong.
That is true, but it is not precise enough. A useful diagnosis names the Rust boundary, the ML mistake, and the category-theory shape.
Boundary type:
composition boundary plus product-input boundary
Rust syntax:
CrossEntropy implements Morphism<Product<Distribution, TokenId>, Loss>. The
input must therefore be a product containing a validated Distribution and the
target TokenId. Logits alone have the wrong type.
ML concept:
Logits are unnormalized vocabulary scores. Cross-entropy needs the probability
assigned to the correct target token. The missing work is Softmax followed by
pairing the resulting Distribution with the target TokenId.
Category theory concept:
The legal route is Logits -> Distribution and then
Distribution x TokenId -> Loss. Skipping the product object hides the supervised
part of the loss calculation.
Smallest useful fix:
Run Softmax first, then call CrossEntropy on
Product::new(distribution, target_token).
Use this as the standard for Exercise 15. Do not stop at “wrong type.” Explain which object was missing, which morphism should have produced it, and which shortcut the boundary rejected.
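The diagnosis above can be made concrete with a small sketch. The types and functions here (TokenId, Logits, Distribution, Loss, Product, softmax, cross_entropy) are simplified local stand-ins, not the crate's definitions. In the crate, CrossEntropy is a Morphism implementation. The sketch only shows that the legal route type-checks: calling cross_entropy on logits directly is rejected by the compiler, and the loss is the negative log of the target probability.

```rust
// Hypothetical local stand-ins, not the crate's real definitions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct TokenId(usize);

#[derive(Debug, Clone)]
struct Logits(Vec<f32>);

#[derive(Debug, Clone)]
struct Distribution(Vec<f32>);

#[derive(Debug, Clone, Copy, PartialEq)]
struct Loss(f32);

// A product object: the distribution and the target travel together.
struct Product<A, B> {
    left: A,
    right: B,
}

impl<A, B> Product<A, B> {
    fn new(left: A, right: B) -> Self {
        Self { left, right }
    }
}

// Logits -> Distribution (numerically stable softmax).
fn softmax(logits: &Logits) -> Distribution {
    let max = logits.0.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.0.iter().map(|x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    Distribution(exps.iter().map(|e| e / sum).collect())
}

// Distribution x TokenId -> Loss: negative log of the target probability.
fn cross_entropy(input: &Product<Distribution, TokenId>) -> Loss {
    let target_probability = input.left.0[input.right.0];
    Loss(-target_probability.ln())
}

fn main() {
    let logits = Logits(vec![2.0, 0.5, -1.0]);
    let target = TokenId(0);
    // cross_entropy(&logits) would not compile: Logits is not a Product.
    let distribution = softmax(&logits);
    let loss = cross_entropy(&Product::new(distribution, target));
    println!("loss = {:.4}", loss.0);
}
```

Because the loss is -ln p(target), a target probability of 0.9 gives a lower loss than 0.1. That is the comparison Exercise 13 checks.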
Exercise Attempt Record
When an exercise feels unclear, record the attempt in this shape before opening an issue or comparing with the answer key:
Exercise:
Chapter:
Command run:
First failure signal:
Line or concept that caused confusion:
What I expected:
What happened instead:
Answer-key mismatch:
Suggested rewrite:
This report is useful because it ties reader feedback to a concrete exercise, command, failure signal, and chapter location. It also keeps feedback public and impersonal: do not include private data, local secrets, or personal background details that are not needed to improve the exercise.
Use the answer key after the attempt record. If the answer key explains the concept but not the failure you saw, that is evidence that the exercise needs a better hint, pass condition, or worked example.
Open an exercise clarity report after you have one concrete attempt record. The link fills the route, not the evidence; the evidence signal should come from what you personally read, ran, or attempted.
Worked Example
First study a complete answer. The exercise is:
Explain why TokenId is not a raw usize.
A strong answer:
Rust syntax:
TokenId is a tuple struct around usize. The field is private, so callers use
TokenId::new and index() instead of reaching into the raw value directly.
ML concept:
The number represents a vocabulary position, not an arbitrary count or shape.
Category theory concept:
TokenId is one object in the small category of typed pipeline values. Morphisms
such as Embedding can start from it.
Notice the order: name the syntax, connect it to the ML role, then name only the categorical shape the code supports.
Worked Example: Gradient Checking
This worked example supports Advanced Exercise 5 in
exercises/advanced/README.md.
The exercise asks why a finite-difference test compares:
inferred gradient from one training update
central finite difference of average loss
The reason is that these are two independent ways to ask the same local question:
If I nudge this parameter, how does the loss move?
CS231n presents this as the difference between numerical gradients and analytic gradients: the numerical version is slower and approximate, but useful for checking whether the analytic implementation is correct. Dive into Deep Learning explains the matching training shape from the other direction: backpropagation walks the computation in reverse order, stores intermediate values, and computes gradients for parameters. This project makes that idea small enough to inspect in Rust.
PyTorch’s gradcheck documentation gives the same engineering warning in
framework form: the check compares small finite differences against analytical
gradients and accepts agreement only within tolerance. It also calls out
practical caveats such as precision, non-differentiable points, and overlapping
memory. Translate that into this Rust lab as:
finite-difference match = useful local debugging signal
finite-difference match != proof of every gradient path
The code-level test has two paths.
The first path performs one training step:
before parameter
-> TransformerBlockTrainStep
-> after parameter
From that update, the test infers the gradient:
inferred_gradient = (before_value - after_value) / learning_rate
That matches gradient descent:
parameter <- parameter - learning_rate * gradient
The second path does not trust the training step. It clones the same state, changes one parameter in two directions, and measures the loss:
loss_plus = loss(parameter + epsilon)
loss_minus = loss(parameter - epsilon)
Then it estimates the local slope:
finite_difference = (loss_plus - loss_minus) / (2 * epsilon)
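The two paths can be rehearsed on a one-parameter toy problem before reading the transformer test. This is a hypothetical sketch, not the crate's TransformerBlockTrainStep or transformer_block_average_loss. The model is prediction = w * x, scored by average squared error over three made-up points, and the learning rate and epsilon are arbitrary choices. Path one infers the gradient from a single update. Path two measures the central finite difference without ever calling the training step.

```rust
// Hypothetical toy model: prediction = w * x, average squared-error loss.
// This stands in for the book's transformer_block_average_loss.
const DATA: [(f64, f64); 3] = [(1.0, 2.0), (2.0, 3.5), (3.0, 6.5)];

fn average_loss(w: f64) -> f64 {
    let total: f64 = DATA.iter().map(|(x, y)| (w * x - y).powi(2)).sum();
    total / DATA.len() as f64
}

// Analytic gradient of average_loss, as a training step would compute it.
fn analytic_gradient(w: f64) -> f64 {
    let total: f64 = DATA.iter().map(|(x, y)| 2.0 * (w * x - y) * x).sum();
    total / DATA.len() as f64
}

// One gradient-descent update: parameter <- parameter - learning_rate * gradient.
fn train_step(w: f64, learning_rate: f64) -> f64 {
    w - learning_rate * analytic_gradient(w)
}

// Path 1: infer the gradient from the before/after values of one update.
fn inferred_gradient(before: f64, learning_rate: f64) -> f64 {
    let after = train_step(before, learning_rate);
    (before - after) / learning_rate
}

// Path 2: central finite difference of the loss, independent of train_step.
fn finite_difference(w: f64, epsilon: f64) -> f64 {
    (average_loss(w + epsilon) - average_loss(w - epsilon)) / (2.0 * epsilon)
}

fn main() {
    let w = 0.5;
    let inferred = inferred_gradient(w, 0.01);
    let numeric = finite_difference(w, 1e-4);
    println!("inferred = {inferred:.6}, finite difference = {numeric:.6}");
    assert!((inferred - numeric).abs() < 1e-6);
}
```

Suppose train_step added the gradient instead of subtracting it. The inferred gradient would then flip sign and the final assertion would fail, which is the reversed-update-sign failure named later in this example. This toy loss is quadratic, so the central difference matches up to rounding. A real network will only agree within a tolerance.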
A strong answer for one parameter family looks like this:
Rust syntax:
The test selects one feed-forward bias entry, clones the training state twice,
adds epsilon to the entry in one clone, subtracts epsilon in the other clone,
and calls transformer_block_average_loss on both states.
ML concept:
The bias is a trainable parameter. The central finite difference estimates how
the average loss changes around the current bias value. The one-step update
infers the gradient that backpropagation used. If both slopes match, the
implemented update has the right local sign and scale for that parameter.
Category theory concept:
The training step is an endomorphism on TransformerTrainingState. The check
asks whether this state update agrees locally with the loss morphism that it is
supposed to reduce.
What failure would this test catch?
It would catch a missing bias gradient, a reversed update sign, a dropped path
through the feed-forward block, or a mismatch between averaged loss and summed
gradients.
The important habit is not the formula by itself. The habit is triangulation:
implementation path
numerical measurement
conceptual explanation
When all three agree, the code becomes easier to trust and easier to teach. If the two numbers disagree, do not immediately change the test tolerance. Ask which boundary failed first:
wrong sign?
missing path?
wrong averaging scale?
non-smooth point?
parameter aliasing or shared storage?
That is why the exercise asks for a specific parameter family. A focused finite-difference check is a microscope, not a certificate for the whole training system.
Partially Completed Example
Complete the missing lines for Distribution:
Rust syntax:
Distribution wraps ________ and construction can return ________.
ML concept:
It represents probabilities over possible next tokens, so the values must be
non-negative and sum to ________.
Category theory concept:
It is an object produced by ________ and consumed with a target token by
________.
Expected completion:
Vec<f32>
CtResult<Self>
one
Softmax
CrossEntropy
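Here is a minimal sketch of the completed constructor. It assumes CtResult is an alias for Result with a small local error enum, and it assumes a tolerance of 1e-4 for the sum check. The crate's actual error type, tolerance, and storage may differ.

```rust
// Hypothetical error type and result alias; the crate's real ones may differ.
#[derive(Debug, PartialEq)]
enum CtError {
    Empty,
    InvalidProbability,
    NotNormalized,
}

type CtResult<T> = Result<T, CtError>;

#[derive(Debug, Clone, PartialEq)]
struct Distribution(Vec<f32>);

impl Distribution {
    // Accept only non-empty, non-negative values that sum to one.
    fn new(probabilities: Vec<f32>) -> CtResult<Self> {
        if probabilities.is_empty() {
            return Err(CtError::Empty);
        }
        // `!(p >= 0.0)` rejects negative values and also NaN.
        if probabilities.iter().any(|p| !(*p >= 0.0)) {
            return Err(CtError::InvalidProbability);
        }
        let total: f32 = probabilities.iter().sum();
        if (total - 1.0).abs() > 1e-4 {
            return Err(CtError::NotNormalized);
        }
        Ok(Self(probabilities))
    }
}

fn main() {
    assert!(Distribution::new(vec![0.7, 0.3]).is_ok());
    assert_eq!(Distribution::new(vec![0.7, 0.7]), Err(CtError::NotNormalized));
    assert_eq!(Distribution::new(vec![1.2, -0.2]), Err(CtError::InvalidProbability));
}
```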
Your Turn
Now solve the same kind of exercise without the filled answer. Pick Loss,
TrainingSet, or LearningRate and explain it through the same three lenses.
Transfer Exercise
Design a wrapper type for the attention roadmap or a future Transformer chapter,
such as SequenceLength, HeadCount, or AttentionScores. State the raw
representation, the invariant, and one function that should consume or produce
it.
Expected failure to consider:
What should the constructor reject?
If the answer is “nothing,” the type may be only a semantic wrapper. If the answer is “zero heads,” “empty sequence,” or “probability outside the allowed range,” the type needs a validating constructor.
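For the "zero heads" case, one possible design wraps std::num::NonZeroUsize. The constructor can then reject zero, and downstream code can divide by the head count without re-checking. HeadCount, ZeroHeads, and head_dimension are hypothetical names for this sketch, not roadmap types confirmed by the crate.

```rust
use std::num::NonZeroUsize;

// Hypothetical roadmap type: the number of attention heads.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct HeadCount(NonZeroUsize);

#[derive(Debug, PartialEq, Eq)]
struct ZeroHeads;

impl HeadCount {
    // Validating constructor: zero heads is rejected at the boundary.
    fn new(count: usize) -> Result<Self, ZeroHeads> {
        NonZeroUsize::new(count).map(Self).ok_or(ZeroHeads)
    }

    fn get(self) -> usize {
        self.0.get()
    }

    // Downstream code may divide without re-checking for zero.
    // Even divisibility is a separate invariant this sketch does not check.
    fn head_dimension(self, model_dimension: usize) -> usize {
        model_dimension / self.get()
    }
}

fn main() {
    assert_eq!(HeadCount::new(0), Err(ZeroHeads));
    let heads = HeadCount::new(2).expect("two heads is valid");
    assert_eq!(heads.head_dimension(8), 4);
}
```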
Exercise 1: Explain One Domain Type
Use Domain Objects.
Pick one type:
- Vector
- Logits
- Distribution
- Loss
- TrainingSet
- Parameters
Write:
The problem this solves:
Rust syntax:
ML concept:
Category theory concept:
Pass condition:
- You name the raw representation.
- You name the invariant or semantic distinction.
- You name the pipeline stage where the type appears.
- You distinguish a semantic wrapper from a validated object when that distinction matters.
Primitive-to-domain audit option:
Use the chapter’s Primitive-To-Domain Responsibility Ledger. Fill this card:
raw value:
domain object:
constructor or boundary:
invariant owned here:
downstream code allowed to trust:
unsafe shortcut rejected:
source-backed limit:
validation command:
Pass condition:
- Your audit names the constructor or boundary that owns the conversion.
- It distinguishes semantic role labeling from invariant validation.
- It names what downstream code is allowed to trust after construction.
- It rejects one raw-primitive shortcut without overclaiming what the type proves.
First-principles hint:
#![allow(unused)]
fn main() {
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct LocalTokenId(usize);
impl LocalTokenId {
fn new(index: usize) -> Self {
Self(index)
}
fn index(self) -> usize {
self.0
}
}
assert_eq!(LocalTokenId::new(7).index(), 7);
}
That snippet is intentionally smaller than the real crate. It shows the raw
idea: a named wrapper can make one usize mean “token id” instead of “any
number.”
Exercise 2: Add A Token
Use the src/demo.rs snapshot in Course Map.
Add one new vocabulary item and extend the token sequence.
Run:
cargo run --bin category_ml
Pass condition:
- the demo still runs
- the dataset windowing output includes your new transition
- you can explain why a longer TokenSequence creates more training examples
Debugging hint:
If the output does not include the new transition, check whether you changed both the vocabulary and the token sequence. A vocabulary entry alone does not create a training pair. The pair appears only when two token ids are adjacent in the sequence.
Exercise 3: Trace DatasetWindowing
Use The Tiny ML Pipeline.
For this input:
[TokenId(4), TokenId(8), TokenId(15), TokenId(16)]
write the training examples produced by windows(2).
Then explain:
Rust syntax:
what does `.windows(2)` do?
ML concept:
why does next-token training need adjacent pairs?
Category theory concept:
why is each example a product object?
Check yourself before reading onward:
(TokenId(4), TokenId(8))
(TokenId(8), TokenId(15))
(TokenId(15), TokenId(16))
The syntax creates overlapping adjacent windows. The ML idea is next-token supervision: each input token is paired with the token that follows it. The category-theory shape is a product object because each training example carries two typed values together.
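The windowing step can be checked in a few lines. TokenId and adjacent_pairs are local stand-ins for the crate's DatasetWindowing. The sketch only shows how slice::windows(2) turns a sequence into overlapping (input, target) pairs.

```rust
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct TokenId(usize);

// Each training example pairs an input token with the token that follows it.
fn adjacent_pairs(sequence: &[TokenId]) -> Vec<(TokenId, TokenId)> {
    sequence
        .windows(2)
        .map(|window| (window[0], window[1]))
        .collect()
}

fn main() {
    let sequence = [TokenId(4), TokenId(8), TokenId(15), TokenId(16)];
    for (input, target) in adjacent_pairs(&sequence) {
        println!("({input:?}, {target:?})");
    }
}
```

A sequence of n tokens yields n - 1 pairs. Extending the sequence by one token therefore adds exactly one training example. A token that never appears in the sequence adds none.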
Exercise 4: Break A Composition
Use the examples/02_morphism_composition.rs snapshot in
Morphism and Composition.
Try to compose Embedding directly with Softmax.
Expected failure shape:
the trait bound ... is not satisfied
Then restore the working version.
Explain:
Rust syntax:
which type did the compiler reject?
Composition diagnostic:
first source:
first target:
second source:
second target:
which middle object should connect the stages?
ML concept:
which prediction stage was skipped?
Category theory concept:
which middle object failed to match?
Source-target-middle repair audit option:
Use the chapter’s Source-Target-Middle Repair Ledger. Fill this card:
composition attempt:
first arrow:
second arrow:
claimed middle object:
actual first target:
actual second source:
repair:
unsafe shortcut rejected:
validation command or output:
Pass condition:
- You name Embedding : TokenId -> Vector and Softmax : Logits -> Distribution.
- You identify Vector versus Logits as the failed middle-object match.
- You restore LinearToLogits : Vector -> Logits instead of weakening Softmax.
- You explain why the skipped ML stage is vocabulary scoring.
- Your repair audit names the attempted composition, the actual first target, the actual second source, the missing repair arrow, the unsafe shortcut, and one validation command or output line.
Debugging hint:
Do not fix this by changing the type signatures. Restore the missing stage instead. The intended path is:
TokenId -> Vector -> Logits -> Distribution
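The composition rule can be sketched with plain Rust functions. embed, to_logits, softmax, and compose are hypothetical stand-ins for Embedding, LinearToLogits, Softmax, and the crate's composition helper, and the toy weights are arbitrary. The generic compose only accepts g when its input type equals f's output type, so the middle object must match.

```rust
#[derive(Debug, Clone, Copy)]
struct TokenId(usize);

#[derive(Debug, Clone)]
struct Vector(Vec<f32>);

#[derive(Debug, Clone)]
struct Logits(Vec<f32>);

#[derive(Debug, Clone)]
struct Distribution(Vec<f32>);

// Embedding : TokenId -> Vector (toy two-dimensional embedding).
fn embed(token: TokenId) -> Vector {
    Vector(vec![token.0 as f32, 1.0])
}

// LinearToLogits : Vector -> Logits (fixed 3 x 2 readout, one row per vocabulary item).
fn to_logits(vector: Vector) -> Logits {
    let weights: [[f32; 2]; 3] = [[0.5, 0.0], [-0.5, 1.0], [0.0, -1.0]];
    Logits(
        weights
            .iter()
            .map(|row| row[0] * vector.0[0] + row[1] * vector.0[1])
            .collect(),
    )
}

// Softmax : Logits -> Distribution.
fn softmax(logits: Logits) -> Distribution {
    let max = logits.0.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.0.iter().map(|x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    Distribution(exps.iter().map(|e| e / sum).collect())
}

// Composition exists only when f's target type B is g's source type B.
fn compose<A, B, C>(f: impl Fn(A) -> B, g: impl Fn(B) -> C) -> impl Fn(A) -> C {
    move |input| g(f(input))
}

fn main() {
    let predict = compose(compose(embed, to_logits), softmax);
    // compose(embed, softmax) would not compile: Vector is not Logits.
    let distribution = predict(TokenId(2));
    println!("{:?}", distribution.0);
}
```

Writing compose(embed, softmax) fails here too. The compiler reports a type mismatch in function arguments rather than an unsatisfied trait bound, because this sketch uses closures instead of a Morphism trait. The repair is the same: insert to_logits, not a looser softmax.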
Exercise 5: Change The Training Repetition Count
Use the examples/03_training_endomorphism.rs snapshot in
Training as an Endomorphism.
Change:
StepCount::new(80)
to:
StepCount::new(1)
StepCount::new(10)
StepCount::new(200)
Run:
cargo run --example 03_training_endomorphism
Explain the result:
Rust syntax:
where is the count used?
Training diagnostic:
what object is updated?
what object measures quality?
what repeats?
what controls update size?
ML concept:
what happens when training repeats more times?
Category theory concept:
why can the update be repeated?
Framework-to-Rust audit option:
Use the Framework-To-Rust Responsibility Ledger in
Training as an Endomorphism. Pick one framework
cue:
optimizer.zero_grad()
loss.backward()
optimizer.step()
optimizer state_dict
Fill this card:
framework cue:
responsibility:
local Rust handle:
returned object:
category boundary:
safe non-claim:
Expected observation:
One step should preserve the shape of the parameters but may not reduce loss much. More steps usually make the tiny example improve until the hand-written training rule reaches its limit. The important category-theory point is not “more is always better”; it is that the same update has the shape:
Parameters -> Parameters
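The repetition can be sketched on a two-parameter toy model. Parameters, average_loss, train_step, and train are hypothetical simplifications, not the crate's TrainStep, and the data, learning rate, and step counts are made up. The training set and learning rate are held fixed, so each call to train_step is a Parameters -> Parameters arrow. Repeating it is a fold over the step count.

```rust
// Hypothetical tiny model: prediction = weight * x + bias.
#[derive(Debug, Clone, Copy, PartialEq)]
struct Parameters {
    weight: f32,
    bias: f32,
}

// Parameters x TrainingSet -> Loss: a measurement, not an update.
fn average_loss(params: Parameters, data: &[(f32, f32)]) -> f32 {
    let total: f32 = data
        .iter()
        .map(|(x, y)| (params.weight * x + params.bias - y).powi(2))
        .sum();
    total / data.len() as f32
}

// TrainStep : Parameters -> Parameters (data and learning rate are fixed context).
fn train_step(params: Parameters, data: &[(f32, f32)], learning_rate: f32) -> Parameters {
    let n = data.len() as f32;
    let (mut grad_w, mut grad_b) = (0.0_f32, 0.0_f32);
    for (x, y) in data {
        let error = params.weight * x + params.bias - y;
        grad_w += 2.0 * error * x / n;
        grad_b += 2.0 * error / n;
    }
    Parameters {
        weight: params.weight - learning_rate * grad_w,
        bias: params.bias - learning_rate * grad_b,
    }
}

// Repeating the same endomorphism is a fold over the step count.
fn train(params: Parameters, data: &[(f32, f32)], learning_rate: f32, steps: usize) -> Parameters {
    (0..steps).fold(params, |current, _| train_step(current, data, learning_rate))
}

fn main() {
    let data: [(f32, f32); 3] = [(0.0, 1.0), (1.0, 3.0), (2.0, 5.0)];
    let start = Parameters { weight: 0.0, bias: 0.0 };
    for steps in [1, 10, 200] {
        let trained = train(start, &data, 0.1, steps);
        println!("steps = {steps:>3}, loss = {:.5}", average_loss(trained, &data));
    }
}
```

On this convex toy problem with a small learning rate, the loss falls as steps increase. That is a property of this example, not a guarantee. The loss is only ever read from the trained parameters and is never stored as model state.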
Pass condition:
- You distinguish TrainStep : Parameters -> Parameters from Parameters x TrainingSet -> Loss.
- You explain that loss is a measurement, not the updated model state.
- You identify StepCount as repetition of the same update shape.
- Your framework-to-Rust audit distinguishes preparation, gradient computation, parameter update, and optimizer-state scope.
- You name the returned object and avoid calling the tiny step a framework optimizer or autograd engine.
- You avoid claiming that more steps always means better behavior.
Exercise 6: Explain Distribution<T>::map
Use Functors, Naturality, Monoids, and Chain Rule.
Explain the conceptual Distribution<T>::map example.
Use this input distribution:
TokenId(2) -> 0.70
TokenId(3) -> 0.30
and this function:
TokenId -> String
where:
TokenId(2) -> "Rust"
TokenId(3) -> "."
Write the output distribution.
Then explain:
Rust syntax:
why does `self` plus `into_iter()` move the old outcomes?
ML concept:
why do the probabilities stay the same?
Category theory concept:
what does it mean to lift `T -> U` into `Distribution<T> -> Distribution<U>`?
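Here is a minimal sketch of the lifting. It assumes a hypothetical Distribution<T> stored as (outcome, probability) pairs, and the crate's conceptual example may store outcomes differently. token_text is also hypothetical.

```rust
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct TokenId(usize);

// Hypothetical generic distribution: a list of (outcome, probability) pairs.
#[derive(Debug, Clone, PartialEq)]
struct Distribution<T> {
    outcomes: Vec<(T, f32)>,
}

impl<T> Distribution<T> {
    // Lift f : T -> U to Distribution<T> -> Distribution<U>.
    // Taking `self` and calling `into_iter()` moves each old outcome into `f`.
    fn map<U>(self, f: impl Fn(T) -> U) -> Distribution<U> {
        Distribution {
            outcomes: self
                .outcomes
                .into_iter()
                .map(|(outcome, probability)| (f(outcome), probability))
                .collect(),
        }
    }
}

fn token_text(token: TokenId) -> String {
    match token.0 {
        2 => "Rust".to_string(),
        3 => ".".to_string(),
        _ => "<unknown>".to_string(),
    }
}

fn main() {
    let tokens = Distribution { outcomes: vec![(TokenId(2), 0.70), (TokenId(3), 0.30)] };
    let words = tokens.map(token_text);
    println!("{:?}", words.outcomes);
}
```

Mapping with the identity closure returns an equal distribution, which is the functor identity law in miniature. This sketch does not merge two outcomes that map to the same value. A full probability functor would add their probabilities.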
Exercise 7: Explain One Validation Boundary
Pick one constructor:
- Distribution::new
- Loss::new
- LearningRate::new
- TrainingSet::new
- SignalMatrix::new
- OpenCircuit::new
Write:
The problem this solves:
Rust syntax:
which condition returns `Err(...)`?
ML or software concept:
what bad runtime behavior does this prevent?
Category theory concept:
what intended object or relationship is being protected?
Exercise 8: Trace A Full Source File
Use Repository Source Snapshots.
Pick one complete source file and write a five-sentence summary:
- What problem does the file solve?
- What are the main Rust types or traits?
- What ML or software concept does it model?
- What category-theory concept does it teach?
- Which command proves the file still works?
Exercise 9: Connect One External Reference
Use References.
Pick one external resource and connect it to one source file in this course. First classify the source using the source-role table in the references chapter.
Answer:
External resource:
Source role:
Owned boundary:
Source file:
Rust syntax connection:
ML or software concept connection:
Category theory concept connection:
What this source can support:
What this source cannot support:
One difference between the full treatment and this tiny implementation:
Pass condition:
- You classify the source as official documentation, academic paper, open textbook or university material, implementation bridge, or learner-friction signal.
- You name the boundary the source owns.
- You connect it to one concrete source file, type, function, test, or example.
- You state one claim the source does not license this book to make.
Exercise 10: Test One Sketch Law
Use Seven Sketches Through Rust.
Pick one law from src/sketches.rs:
- preorder laws
- feature/layer Galois law
- resource monotonicity
- foreign-key resolution
- co-design feasibility relation
- signal-flow matrix composition
- local-to-global safety truth
Change one input in examples/05_seven_sketches.rs, then run:
cargo run --example 05_seven_sketches
Pass condition:
- you can explain which law still holds
- you can explain which constructor or method prevents invalid structure
- your explanation uses Rust syntax, ML or software concept, and category theory concept
Negative test option:
Instead of changing the runnable example, inspect one of the negative tests in
src/sketches.rs:
- missing database reference
- mismatched signal-matrix middle dimension
- open-circuit serial boundary mismatch
Explain what invalid structure the test rejects. This is often the fastest way to understand what a law or constructor is protecting.
PDF-to-Rust contract option:
Use the chapter’s PDF-To-Rust Reading Contract. Pick one source idea from
the Seven Sketches chapter and fill this row:
source idea from the PDF:
Rust handle:
protected law, relation, or boundary:
larger source claim not implemented by this code:
local evidence command or test:
If the source idea still feels too large, fill the chapter’s transfer triage card before writing the final answer:
source idea:
local Rust handle:
protected law, relation, or boundary:
invalid shortcut rejected:
tiny ML transfer:
larger claim not implemented:
local evidence command or test:
Pass condition:
- your Rust handle is one concrete type, method, constructor, example output line, or test from src/sketches.rs
- your protected claim is smaller than the full source text
- your evidence can be checked with cargo run --example 05_seven_sketches or cargo test sketches::tests --lib
- your transfer card names one invalid shortcut and one non-claim
Page-to-Rust decision-ladder option:
Use the chapter’s Page-To-Rust Decision Ladder. Pick one paragraph shape from
the source text:
definition or named object
relation, order, or feasibility statement
composition rule
theorem, law, or proof step
worked example or application story
richer machinery beyond the local handle
Then fill:
source paragraph shape:
first Rust move:
invalid state or shortcut to reject:
local evidence command or test:
safe non-claim:
Pass condition:
- your first Rust move is concrete: newtype, enum, struct, constructor, method, fixture, output line, or named test
- your evidence names a command or test that exists in this repository
- your safe non-claim prevents turning one local handle into a claim about the whole source text
- you do not start by inventing a broad trait or framework when a small typed boundary would expose the issue
Bridge-back-to-tiny-ML option:
Use the chapter’s Bridge Back To Tiny ML table. Pick one row and fill:
sketch:
tiny ML pressure:
Rust handle:
bad shortcut rejected:
safe non-claim:
evidence command or test:
one-sentence transfer:
The one-sentence transfer must use this shape:
This sketch helps me reject this ML shortcut: ...
Pass condition:
- your row matches one actual bridge row in the Seven Sketches chapter
- your bad shortcut is something a tiny ML system could plausibly get wrong
- your safe non-claim prevents overclaiming the larger category-theory source
- your evidence points to cargo run --example 05_seven_sketches or cargo test sketches::tests --lib
Co-design option:
Use DesignRequirement, ImplementationOffer, and FeasibilityRelation.
Write the relation as:
DesignRequirement x ImplementationOffer -> bool
Then translate it to:
ArchitectureConstraint x CandidateImplementation -> Bool
Pass condition:
- you explain why this is a relation rather than a function
- you give one passing offer and one failing offer
- you say why one passing implementation does not prove the whole constraint space
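The sketch below shows why this is a relation. Each offer gets a latency and a memory cost, and a pair counts as feasible when both costs fit the requirement. The field names and numbers are hypothetical, not the fields of the crate's DesignRequirement or ImplementationOffer.

```rust
// Hypothetical fields; the crate's DesignRequirement and ImplementationOffer may differ.
#[derive(Debug, Clone, Copy)]
struct DesignRequirement {
    max_latency_ms: u32,
    max_memory_mb: u32,
}

#[derive(Debug, Clone, Copy)]
struct ImplementationOffer {
    latency_ms: u32,
    memory_mb: u32,
}

// DesignRequirement x ImplementationOffer -> bool
fn feasible(requirement: DesignRequirement, offer: ImplementationOffer) -> bool {
    offer.latency_ms <= requirement.max_latency_ms
        && offer.memory_mb <= requirement.max_memory_mb
}

fn main() {
    let requirement = DesignRequirement { max_latency_ms: 20, max_memory_mb: 64 };
    let offers = [
        ImplementationOffer { latency_ms: 15, memory_mb: 48 },
        ImplementationOffer { latency_ms: 12, memory_mb: 128 },
    ];
    for offer in offers {
        println!("{offer:?} feasible: {}", feasible(requirement, offer));
    }
}
```

The requirement is related to the first offer and not to the second. It could also be related to many other offers at once, so there is no single implementation that a function would return. Finding one feasible offer says nothing about offers that were never checked.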
Exercise 11: Write A New Block Explanation
Choose any block from the source snapshots that the chapter did not explain in enough detail for you.
Write a block explanation using this structure:
The problem this block solves:
The whole block:
Rust syntax:
ML or software concept:
Category theory concept:
Core mental model:
Pass condition:
- A beginner can understand the Rust syntax.
- An ML learner can understand why the block exists.
- A category-theory learner can name the shape.
Exercise 12: Trace Attention Shape Flow
Use Transformer Roadmap, src/attention.rs, and
examples/06_attention_scores.rs.
Run:
cargo run --example 06_attention_scores
First copy the four-line Q/K/V diagnostic printed before the attention weights:
Q/K/V source diagnostic:
query rows own score rows; key/value rows own score columns
self-attention shares the hidden source before projection; projected roles stay distinct
mask polarity here: true = allowed, false = blocked
Then write one sentence for each line:
query rows:
key/value rows:
self-attention source:
mask polarity:
Write down the first time the output mentions each shape:
AttentionScores:
AttentionWeights:
AttentionOutput:
MultiHeadOutput:
ProjectedAttentionOutput:
HiddenSequence after residual:
HiddenSequence after normalization:
HiddenSequence after feed-forward:
Then explain:
Rust syntax:
which named type or boundary protects each shape?
ML concept:
what changes between scores, weights, value mixing, projection, residual,
normalization, and feed-forward?
Category theory concept:
where does the path use a product input, and where does it return to the same
HiddenSequence object?
Then classify these boundaries:
QuerySequence x KeySequence -> AttentionScores:
AttentionScores x AttentionMask -> AttentionScores:
AttentionScores -> AttentionWeights:
AttentionWeights x ValueSequence -> AttentionOutput:
LayerNormalization : HiddenSequence -> HiddenSequence:
TransformerTrainingState -> TransformerTrainingState:
HiddenSequence x MultiHeadOutput -> HiddenSequence:
Then repeat the quick roadmap classification drill without looking at the answer table. For each boundary, count the inputs first, then name the safest category shape:
HiddenSequence -> QuerySequence:
AttentionScores x AttentionMask -> AttentionScores:
LayerNormalization : HiddenSequence -> HiddenSequence:
HiddenSequence x ProjectedAttentionOutput -> HiddenSequence:
TransformerTrainingState -> TransformerTrainingState:
Then trace three boundaries through the roadmap decision flow:
AttentionScores x AttentionMask -> AttentionScores:
MaskedMultiHeadTransformerBlock[M] : HiddenSequence -> HiddenSequence:
HiddenSequence x MultiHeadOutput -> HiddenSequence:
For each one, answer:
does it type-check?
how many inputs are visible?
was one context fixed first?
safe local name:
Finally, explain this trap in one sentence:
A product input that returns its left-hand object is not automatically an
endomorphism.
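The trap can be sketched with two toy wrappers. HiddenSequence, ProjectedAttentionOutput, residual_add, and fix_context are simplified hypothetical stand-ins that use flat Vec<f32> values and no shape checks. residual_add has two visible inputs. Only after fix_context captures one ProjectedAttentionOutput does the result have the unary shape HiddenSequence -> HiddenSequence.

```rust
#[derive(Debug, Clone, PartialEq)]
struct HiddenSequence(Vec<f32>);

#[derive(Debug, Clone, PartialEq)]
struct ProjectedAttentionOutput(Vec<f32>);

// HiddenSequence x ProjectedAttentionOutput -> HiddenSequence: two visible inputs.
// Note: zip silently truncates on a length mismatch; a real boundary would check shapes.
fn residual_add(hidden: HiddenSequence, projected: &ProjectedAttentionOutput) -> HiddenSequence {
    HiddenSequence(hidden.0.iter().zip(&projected.0).map(|(h, p)| h + p).collect())
}

// Fixing the projected context first leaves a unary map HiddenSequence -> HiddenSequence.
fn fix_context(projected: ProjectedAttentionOutput) -> impl Fn(HiddenSequence) -> HiddenSequence {
    move |hidden: HiddenSequence| residual_add(hidden, &projected)
}

fn main() {
    let hidden = HiddenSequence(vec![1.0, 2.0]);
    let projected = ProjectedAttentionOutput(vec![0.5, -1.0]);
    // Two inputs are visible here: this is not an endomorphism yet.
    let once = residual_add(hidden.clone(), &projected);
    // One context fixed first: the closure can now be repeated like an endomorphism.
    let add_projected = fix_context(projected);
    let twice = add_projected(add_projected(hidden));
    println!("{once:?} {twice:?}");
}
```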
Then use the same-output classification rule from the roadmap. These three
lines all end with HiddenSequence; explain why they do not have the same
category shape:
LayerNormalization : HiddenSequence -> HiddenSequence:
HiddenSequence x ProjectedAttentionOutput -> HiddenSequence:
HiddenSequence x MultiHeadOutput -> HiddenSequence:
Then answer the terminal-output audit. For each printed line, write what the line proves and what category overclaim it does not prove:
projected attention shape: 2 positions x model dimension 2
residual shape: 2 positions x model dimension 2
masked multi-head block shape: 2 positions x model dimension 2
training state step: 0 -> 1
Use this rule:
printed shape line -> target evidence
typed transformation line -> source and target evidence
category name -> only after both are known
Then answer the source-ownership diagnostic:
Self-attention:
which sequence owns the query side?
which sequence owns the key side?
which sequence owns the value side?
Cross-attention:
which sequence owns the query side?
which sequence owns the key side?
which sequence owns the value side?
Shape check:
which length counts score rows?
which length counts score columns?
Then fill the shape ledger:
target length:
source length:
attention mask:
attention output:
For each row, write:
framework cue -> Rust roadmap meaning -> category-shape consequence
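A small score computation can check the row and column lengths. This is a minimal sketch. It uses plain `Vec<Vec<f64>>` rows instead of the repository's QuerySequence and KeySequence. It assumes the scaled dot product from Attention Is All You Need; the repository's exact scaling is not shown here. Queries with L rows and keys with S rows produce an L x S score matrix. Each row belongs to one target position, and each column belongs to one source position.

```rust
/// Scaled dot-product scores: queries (L x d) against keys (S x d) give L x S.
/// Returns None if the queries are empty or any row disagrees on the head dimension d.
pub fn attention_scores(queries: &[Vec<f64>], keys: &[Vec<f64>]) -> Option<Vec<Vec<f64>>> {
    let head_dim = queries.first()?.len();
    if head_dim == 0 || queries.iter().chain(keys.iter()).any(|row| row.len() != head_dim) {
        return None;
    }
    let scale = (head_dim as f64).sqrt();
    Some(
        queries
            .iter()
            // One output row per query (target) position.
            .map(|q| {
                keys.iter()
                    // One column per key (source) position.
                    .map(|k| q.iter().zip(k).map(|(a, b)| a * b).sum::<f64>() / scale)
                    .collect()
            })
            .collect(),
    )
}
```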
Then answer the mask-role ledger:
What does an attention-mask cell select?
Why is the mask not a shorter token sequence?
Why does the mask not directly produce AttentionWeights?
Which block-level boundary keeps the mask visible instead of hidden?
In a fixed-mask view, what context was selected first?
What does true mean in this repository's AttentionMask?
Why can a framework mask with the same shape still need boolean inversion?
Write the three-step rule:
mask cells ...
softmax ...
weights ...
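The three-step rule can be traced on one score row. This is a minimal sketch with a hypothetical function name. It takes a plain `&[f64]` row and a `&[bool]` mask in which `true` means allowed, which is the polarity this chapter states for the repository's AttentionMask. Blocked cells are set to negative infinity before softmax, so they receive exactly zero weight after normalization.

```rust
/// Masks one row of raw scores, then normalizes it into attention weights.
/// `true` in `allowed` means the source position may be attended to.
/// Returns None if lengths differ or if no position is allowed.
pub fn masked_softmax_row(scores: &[f64], allowed: &[bool]) -> Option<Vec<f64>> {
    if scores.len() != allowed.len() || !allowed.iter().any(|&a| a) {
        return None;
    }
    // Step 1: mask cells. Blocked score cells become -inf.
    let masked: Vec<f64> = scores
        .iter()
        .zip(allowed)
        .map(|(&s, &ok)| if ok { s } else { f64::NEG_INFINITY })
        .collect();
    // Step 2: softmax over the row (subtracting the max for numerical stability).
    let max = masked.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = masked.iter().map(|&s| (s - max).exp()).collect();
    let total: f64 = exps.iter().sum();
    // Step 3: weights. Blocked cells get 0.0, and the row sums to one.
    Some(exps.iter().map(|e| e / total).collect())
}
```

If a framework mask uses `true` for blocked positions, invert it before calling a function like this, for example with `blocked.iter().map(|b| !b).collect::<Vec<bool>>()`. The shape stays the same; only the polarity changes.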
Then answer the linear-scope diagnostic:
Which listed boundaries are the linear Q/K/V projections?
Which boundary turns scores into nonlinear normalized weights?
Which product-input boundaries must not be collapsed into one unary map?
Which state endomorphism belongs to training rather than forward attention?
Then answer the source-scope diagnostic:
Which source supports decomposing attention into recurring components?
Which source supports comparing the linear Q/K/V part with advanced category theory?
What does neither source license you to claim about the whole Rust roadmap block?
What is the local Rust contract for every component in this book?
Then answer the architecture-constraint diagnostic:
What is one architecture constraint in the roadmap?
Which Rust type, constructor, example, or test is implementation evidence for it?
Why is that implementation evidence not the same as proving the whole future
Transformer architecture satisfies every intended constraint?
Then answer the stackability diagnostic:
Which listed boundaries can stack directly as HiddenSequence -> HiddenSequence?
Why is MaskedMultiHeadTransformerBlock not an endomorphism while the mask is
still an open input?
What are the two precise ways to repeat a masked block?
When is a fixed-mask view allowed to be named HiddenSequence -> HiddenSequence?
When is LayerNormalization allowed to be named HiddenSequence -> HiddenSequence?
When is PositionalEncoding allowed to be named HiddenSequence -> HiddenSequence?
When is MultiHeadTransformerBlock allowed to be named HiddenSequence -> HiddenSequence?
If the layer's scale and shift are being learned, which larger boundary owns
that change?
Then answer the context-fixing drill:
Open masked block:
what is the whole input object?
what is the safe category shape?
can it stack unaided as HiddenSequence -> HiddenSequence?
Fixed-mask view:
what was selected first?
what is the induced boundary?
what promise must remain true while stacking?
Changing mask per call:
what must the caller supply or carry?
why is this not the same as a fixed-mask view?
Residual addition:
which two inputs remain visible?
why is this not a unary endomorphism?
if you name the whole product as the source object, why is
(HiddenSequence x ProjectedAttentionOutput) -> HiddenSequence
still not an endomorphism?
Rust closure bridge:
what value would a closure capture to create a fixed-mask view?
which argument would remain when the closure is called?
why does the closure analogy still not change the open block boundary?
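The closure bridge can be sketched directly. This is a minimal sketch with hypothetical stand-in types. `masked_block` is a toy that zeroes blocked positions; it is not a real attention block. `masked_block` keeps both inputs visible. `fix_mask` captures one mask by `move` and returns a value that needs only a `HiddenSequence` when it is called.

```rust
// Hypothetical stand-ins, not the repository's types.
#[derive(Clone, Debug, PartialEq)]
pub struct HiddenSequence(pub Vec<f64>);

#[derive(Clone, Debug, PartialEq)]
pub struct AttentionMask(pub Vec<bool>);

/// Open boundary: HiddenSequence x AttentionMask -> HiddenSequence.
/// Toy body: keep allowed positions and zero out blocked ones.
pub fn masked_block(h: &HiddenSequence, mask: &AttentionMask) -> HiddenSequence {
    HiddenSequence(
        h.0.iter()
            .zip(&mask.0)
            .map(|(&x, &allowed)| if allowed { x } else { 0.0 })
            .collect(),
    )
}

/// Fixed-mask view: the closure owns one selected mask, so each call
/// supplies only a HiddenSequence and returns a HiddenSequence.
pub fn fix_mask(mask: AttentionMask) -> impl Fn(HiddenSequence) -> HiddenSequence {
    move |h| masked_block(&h, &mask)
}
```

Capturing the mask changes what the caller must supply at each call. It does not change the open two-input boundary of `masked_block` itself.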
Then answer the add-norm order drill:
Which order does the current Rust block implement around the attention sublayer?
Which order does it implement around the feed-forward sublayer?
Which two local boundaries show the order?
Why can post-norm and pre-norm blocks both have shape
HiddenSequence -> HiddenSequence
while still being different morphisms?
If a future pre-norm variant is added, what must be named separately?
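The order difference is visible in a small numerical example. This is a minimal sketch that assumes a single position vector, a layer norm with unit scale, zero shift, and epsilon 1e-5, and a sublayer supplied by the caller. None of these names come from src/attention.rs. Both functions map a slice to a vector of the same length, yet they return different values for the same input.

```rust
/// Layer norm of one vector with unit scale and zero shift (a fixed module value).
pub fn layer_norm(x: &[f64]) -> Vec<f64> {
    let n = x.len() as f64;
    let mean = x.iter().sum::<f64>() / n;
    let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / n;
    let denom = (var + 1e-5).sqrt();
    x.iter().map(|v| (v - mean) / denom).collect()
}

/// Elementwise residual addition.
fn add(a: &[f64], b: &[f64]) -> Vec<f64> {
    a.iter().zip(b).map(|(x, y)| x + y).collect()
}

/// Post-norm: add the residual first, then normalize the sum.
pub fn post_norm(x: &[f64], sublayer: impl Fn(&[f64]) -> Vec<f64>) -> Vec<f64> {
    layer_norm(&add(x, &sublayer(x)))
}

/// Pre-norm: normalize the sublayer input, then add the untouched residual.
pub fn pre_norm(x: &[f64], sublayer: impl Fn(&[f64]) -> Vec<f64>) -> Vec<f64> {
    let normed = layer_norm(x);
    add(x, &sublayer(normed.as_slice()))
}
```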
Before naming each boundary, write the answer to the first diagnostic question:
How many inputs does this boundary require?
Then write the source-target audit card for at least three boundaries:
boundary:
whole source object:
target object:
context status:
safe conclusion:
Use at least one product-input boundary and one fixed-context boundary.
Pass condition:
- You name at least four concrete Rust types from src/attention.rs.
- You distinguish raw attention scores from normalized attention weights.
- You explain why residual addition must return to HiddenSequence.
- You connect one terminal output line to one typed boundary.
- You explain why self-attention shares a source before projection without collapsing query, key, and value into one role.
- You map target length, source length, attention mask, and attention output from framework notation to the Rust roadmap shape ledger.
- You explain that mask cells select legal score cells before softmax, not token rows after probability mass has been assigned.
- You state that this repository’s AttentionMask uses true for an allowed source position, while some framework masks use true for a blocked or padding position.
- You explain the four-line Q/K/V diagnostic before using later attention weights or shape lines as evidence.
- You keep claims about linear Q/K/V projections separate from softmax, masking, residual addition, normalization, and training state.
- You classify the quick roadmap drill by counting inputs before naming endomorphisms.
- You use the roadmap decision flow before the same-output and source-target audit cards.
- You explain that anatomy-of-attention research supports decomposition, while parametric-endofunctor research supports a narrower linear self-attention comparison.
- You separate architecture constraints from implementation boundaries.
- You do not claim the tiny Rust roadmap implements either full formalism.
- You distinguish an open masked-block product input from a fixed-mask induced endomorphism.
- You name what context was fixed before using the fixed-mask HiddenSequence -> HiddenSequence view.
- You state that a shape-preserving layer is an endomorphism only for a fixed module instance, while parameter changes belong to TransformerTrainingState -> TransformerTrainingState.
- You state that positional encodings and Transformer blocks follow the same fixed-value rule: the table or block value must already be selected before the forward call is named HiddenSequence -> HiddenSequence.
- You state that residual-normalization order is part of the morphism, so post-norm and pre-norm blocks can share source and target while remaining different implementations.
- You classify at least one product-input morphism, one endomorphism, and one illegal boundary.
- You do not call a product-input boundary an endomorphism only because its output matches the left input object.
- You write the whole source object and target object before deciding whether a row is an endomorphism.
- You explain that naming the whole product as the source object gives a unary morphism out of the product, not an endomorphism unless the same product object is also returned.
- You explain why two lines that return HiddenSequence can still have different category shapes.
Exercise 13: Compute Cross-Entropy From Target Probability
Use The Tiny ML Pipeline and src/ml.rs.
The CrossEntropy morphism uses:
loss = -ln(probability assigned to the target token)
For target token TokenId(0), compare these two distributions:
confident = [0.90, 0.10]
surprised = [0.10, 0.90]
Compute:
confident loss:
surprised loss:
which one is lower:
Then run:
cargo test cross_entropy_is_lower_for_more_confident_target_probability --lib
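A few lines of Rust can check the arithmetic. This is a minimal sketch. It deliberately uses a bare slice and a `usize` index, which the book's Distribution and TokenId types exist to replace. It checks the formula only; it is not a model of src/ml.rs.

```rust
/// Cross-entropy for one target position: loss = -ln(p[target]).
/// Returns None if the index is out of range or the probability is not in (0, 1].
pub fn cross_entropy(probabilities: &[f64], target: usize) -> Option<f64> {
    let p = *probabilities.get(target)?;
    if p > 0.0 && p <= 1.0 {
        Some(-p.ln())
    } else {
        None
    }
}
```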
Explain:
Rust syntax:
which code reads the target probability, and which constructor validates the
loss?
ML concept:
why does the same target token produce different losses under the two
distributions?
Category theory concept:
why is CrossEntropy a morphism from Distribution x TokenId to Loss?
Target-probability responsibility audit option:
Use the chapter’s Target-Probability Responsibility Ledger. Fill this card:
pipeline cue:
Rust handle:
ML responsibility:
category boundary:
unsafe shortcut rejected:
source-backed limit:
validation command:
Pass condition:
- You compute approximate losses for 0.90 and 0.10.
- You explain why the target index is 0 in both cases.
- You connect the test name to the learning claim.
- Your target-probability audit identifies target.index(), rejects the largest-probability shortcut, separates Logits -> Distribution from Distribution x TokenId -> Loss, and states that normalized probability is not calibrated confidence or full framework equivalence.
Exercise 14: Trace Naturality And Monoid Laws
Use Functors, Naturality, Monoids, and Chain Rule
and src/structure.rs.
Run:
cargo test structure::tests --lib
For the naturality square, write the two paths:
top then right:
left then bottom:
why they should match:
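You can write the two paths as two functions and compare their results. This is a minimal sketch that assumes `first_or_none` keeps the first element of a vector. The function names are illustrative and may not match src/structure.rs. Which path you draw as "top then right" depends on how the chapter lays out the square.

```rust
/// Component of a natural transformation Vec -> Option: keep the first element.
pub fn first_or_none<A: Clone>(values: &[A]) -> Option<A> {
    values.first().cloned()
}

/// Path through Vec<B>: map every element, then take the first.
pub fn map_then_first<A, B: Clone>(values: &[A], f: impl Fn(&A) -> B) -> Option<B> {
    let mapped: Vec<B> = values.iter().map(f).collect();
    first_or_none(&mapped)
}

/// Path through Option<A>: take the first element, then map inside the Option.
pub fn first_then_map<A: Clone, B>(values: &[A], f: impl Fn(&A) -> B) -> Option<B> {
    first_or_none(values).map(|a| f(&a))
}
```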
For the monoid law check, write the three laws:
left identity:
right identity:
associativity:
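You can check the three laws the same way on a toy trace. This is a minimal sketch. It assumes a trace is an ordered list of step names, combined by concatenation, with the empty trace as the identity. `Trace` is a hypothetical stand-in for the repository's PipelineTrace.

```rust
/// A toy pipeline trace: an ordered list of step names.
#[derive(Clone, Debug, PartialEq)]
pub struct Trace(pub Vec<String>);

impl Trace {
    /// Identity element: the empty trace.
    pub fn empty() -> Self {
        Trace(Vec::new())
    }

    /// A trace with a single named step.
    pub fn step(name: &str) -> Self {
        Trace(vec![name.to_string()])
    }

    /// Monoid operation: concatenate the steps, keeping their order.
    pub fn combine(&self, other: &Trace) -> Trace {
        let mut steps = self.0.clone();
        steps.extend(other.0.iter().cloned());
        Trace(steps)
    }
}
```

Associativity lets you regroup, not reorder: `a.combine(&b)` and `b.combine(&a)` are different traces.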
Output-to-law audit option:
Use the Output-To-Law Audit section in
Functors, Naturality, Monoids, and Chain Rule.
Pick one output line from:
cargo run --example 04_structure_and_calculus
Fill:
output line:
Rust handle:
law or boundary:
source support:
safe non-claim:
validation command:
Then explain:
Rust syntax:
which functions or methods implement each path or law?
ML or software concept:
why do consistent wrapper conversion and trace grouping matter in a pipeline?
Category theory concept:
what does commutativity mean for the square, and what does associativity mean
for the trace monoid?
Pass condition:
- You name naturality_square_holds_for_first_option.
- You name monoid_laws_hold_for_pipeline_trace.
- You explain why both naturality paths return the same Option value.
- You explain why changing parentheses in trace combination should not change the final trace.
- Your output-to-law audit connects one printed line to one Rust handle, one law-shaped claim, one source-backed limit, and one validation command.
Exercise 15: Mixed Boundary Diagnosis
Use this exercise after finishing the core chapters. The goal is interleaved transfer: diagnose which kind of boundary is being protected without being told which chapter the failure came from.
For each case, classify the boundary:
invariant boundary
composition boundary
endomorphism boundary
shape boundary
local-to-global boundary
Then answer with the usual three lenses.
Cases
1. A raw usize is used where the code expects TokenId.
2. Embedding is followed directly by Softmax.
3. CrossEntropy receives Logits instead of Product<Distribution, TokenId>.
4. A training step returns Loss instead of Parameters.
5. SignalMatrix::compose_after sees mismatched middle dimensions.
6. SafetyCover reports a global claim even though one interval is false.
7. A residual connection tries to add rows with different model dimensions.
For each case, write:
Boundary type:
Rust syntax:
ML or software concept:
Category theory concept:
Smallest useful fix:
Pass condition:
- You classify all seven cases.
- You name at least five concrete Rust types or functions.
- You explain the smallest useful fix without weakening the type boundary.
- You identify which cases are about invalid values, which are about invalid composition, and which are about invalid global claims.
Debugging hint:
Do not answer every case with “the compiler rejects it.” Some failures are
constructor errors, some are returned CtError::ShapeMismatch, some are
conceptual category-shape failures, and some are law-check failures. The skill
is choosing the right explanation for the right boundary.
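Cases 5 and 7 are runtime shape failures, not compile-time type failures. This is a minimal sketch of a shape-checked residual addition. It uses plain row vectors and a hypothetical `ShapeMismatch` error. The repository's own error is CtError::ShapeMismatch, and its fields are not shown here. Both sequences have the same Rust type, so only a runtime check can reject mismatched model dimensions.

```rust
/// Hypothetical error: the shapes of both sides, or None for ragged rows.
#[derive(Debug, PartialEq)]
pub struct ShapeMismatch {
    pub left: Option<(usize, usize)>,
    pub right: Option<(usize, usize)>,
}

/// (positions, model dimension), or None if rows have different lengths.
fn shape(rows: &[Vec<f64>]) -> Option<(usize, usize)> {
    let cols = rows.first().map_or(0, |r| r.len());
    rows.iter().all(|r| r.len() == cols).then_some((rows.len(), cols))
}

/// Residual addition that returns an error instead of silently truncating.
pub fn residual_add(
    hidden: &[Vec<f64>],
    update: &[Vec<f64>],
) -> Result<Vec<Vec<f64>>, ShapeMismatch> {
    let (left, right) = (shape(hidden), shape(update));
    match (left, right) {
        (Some(l), Some(r)) if l == r => Ok(hidden
            .iter()
            .zip(update)
            .map(|(h, u)| h.iter().zip(u).map(|(a, b)| a + b).collect::<Vec<f64>>())
            .collect()),
        _ => Err(ShapeMismatch { left, right }),
    }
}
```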
Exercise 16: Trace Transformer Training State
Use Transformer Roadmap, src/attention.rs, and
examples/07_transformer_training_state.rs.
Run:
cargo run --example 07_transformer_training_state
Write down the output lines for:
initial state:
forward shape:
readout update:
feed-forward update:
composed block update:
Then classify each update:
TransformerReadoutTrainStep:
TransformerFeedForwardTrainStep:
TransformerBlockTrainStep:
For each update, answer with the three lenses:
Rust syntax:
which named type performs the update, and what state does it return?
ML concept:
which parameters or sublayer does this update train?
Category theory concept:
why is the outside shape an endomorphism?
Finally, explain why this shortcut would be weaker:
readout update returns readout weights
feed-forward update returns feed-forward weights
block update returns a bag of changed matrices
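A toy state shows the difference between a state endomorphism and loose weights. This is a minimal sketch with a hypothetical `TrainingState`. It assumes the toy objective 0.5 * sum of squared parameters, whose gradient is the parameter vector itself. That assumption lets the step compute its gradient internally. The repository's steps compute gradients from training examples instead.

```rust
/// Hypothetical miniature: parameters plus the context the next update needs.
#[derive(Clone, Debug, PartialEq)]
pub struct TrainingState {
    pub parameters: Vec<f64>,
    pub learning_rate: f64,
    pub step: u64,
}

/// One gradient-descent step on the toy objective 0.5 * sum(p_i^2).
/// The gradient is `p`, so no extra input is needed and the step has the
/// shape TrainingState -> TrainingState.
pub fn train_step(state: TrainingState) -> TrainingState {
    let parameters = state
        .parameters
        .iter()
        .map(|p| p - state.learning_rate * p)
        .collect();
    TrainingState {
        parameters,
        learning_rate: state.learning_rate,
        step: state.step + 1,
    }
}
```

Because the step returns the whole state, `train_step(train_step(s0))` type-checks directly. A version that returned only `Vec<f64>` would force the caller to re-attach the learning rate and step count before each call.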
Pass condition:
- You name TransformerTrainingState, TinyTransformerParameters, and all three training-step types.
- You explain that the state owns parameters, learning rate, and step count.
- You distinguish readout-only, local feed-forward, and composed block updates.
- You connect each printed step increment to TransformerTrainingState -> TransformerTrainingState.
- You explain why returning loose weights would make the next update rebuild context by hand.
Exercise 17: Reconstruct A Diagram By Hand
Use this exercise whenever a chapter diagram feels dense. The goal is not to make a prettier copy. The goal is to prove that you can recover the objects, arrows, paths, and safe claim without relying on the book’s layout.
Choose one diagram from:
Course Map:
Text -> TokenSequence -> TrainingSet -> Loss, with Parameters -> Parameters
Domain Objects:
raw representation -> domain object -> trusted downstream boundary
Morphism and Composition:
TokenId -> Vector -> Logits -> Distribution
Tiny ML Pipeline:
Distribution x TokenId -> Loss
Training as an Endomorphism:
Parameters -> Parameters
Structure and Laws:
Vec<A> -> Option<B> naturality square
Transformer Roadmap:
AttentionWeights x ValueSequence -> AttentionOutput
Then fill this reconstruction sheet:
chapter:
diagram chosen:
objects:
arrows:
two paths or state transition:
Rust handle:
command or test:
what would break if one arrow was skipped:
safe non-claim:
For the structures chapter, use:
cargo run --example 04_structure_and_calculus
cargo test structure::tests --lib
For the roadmap attention path, use:
cargo run --example 06_attention_scores
Pass condition:
- You redraw the diagram without copying the original layout.
- You label every object and arrow.
- You say whether the diagram is a pipeline, a constructor boundary, a product input, a law square, or a state update.
- You name at least one Rust type, function, example, or test connected to the diagram.
- You explain one thing the diagram does not prove.
- You explain what would break if a key arrow, product input, or state object was removed.
Retrieval Practice
Close the source file before answering these prompts.
Recall
Name three kinds of feedback this course uses:
compiler error
constructor error
test failure
Explain
Explain why a failed composition is useful evidence, not only an obstacle.
Apply
Pick one exercise you solved and rewrite it for a different type or module. Keep the same answer shape:
Rust syntax:
ML or software concept:
Category theory concept:
Where This Leaves Us
If you can complete these exercises, you can read the project without treating category theory, Rust, and ML as three disconnected subjects. You can start from a line of code, name the syntax, identify the software or ML role, and then describe the categorical shape only as far as the code justifies it.
Challenges
The problem this chapter solves is:
After a chapter explains an idea, how can a reader practice the same boundary through compiler feedback and a tiny paper-to-code translation?
The book explains the ideas. The challenge tracks make them public practice.
There are two tracks:
- Typed AI Rustlings: learn AI by fixing compiler errors.
- Paper-To-Rust: stop summarizing papers. Compile one idea.
Chapter Outcomes
By the end of this chapter, you should be able to:
- explain how a TokenId or Distribution challenge uses a compiler or test signal to expose one AI boundary,
- translate one paper claim into a Rust boundary, invariant, and evidence command without summarizing the whole paper,
- distinguish challenge completion evidence from accepted textbook reader feedback.
Source-Backed Challenge Contract
The challenge track uses external sources as design constraints, not as decoration.
| Source | What the source supports | Local rule in this chapter | Repository evidence |
|---|---|---|---|
| Rustlings Usage | Short exercises often ask learners to fix compile errors or pass tests. | A Typed AI Rustlings exercise should make one type boundary fail visibly before the fix. | challenges/typed-ai-rustlings/exercises/token_id_not_usize.rs |
| Rustlings Community Exercises | Focused exercise packs can target one topic. | This project targets AI boundary mistakes rather than general Rust syntax. | challenges/typed-ai-rustlings/metadata.toml |
| Adam: A Method for Stochastic Optimization | Adam is based on adaptive first-moment and second-moment estimates. | The Adam challenge must carry optimizer memory with the update boundary. | src/challenges/papers/adam.rs |
| PyTorch Adam | Production Adam exposes optimizer state, moment estimates, state_dict, and step(). | The tiny Rust challenge is a teaching boundary, not a replacement for a framework optimizer. | examples/challenge_adam.rs and tests/paper_to_rust_adam.rs |
| Rust Book: Writing Automated Tests | Tests can verify expected behavior and catch regressions. | Every public challenge needs an evidence command or test signal. | cargo test --test challenge_typed_ai and cargo test --test paper_to_rust_adam |
The transfer pattern is:
source claim -> local challenge -> compiler, output, or test evidence
For the full source map, use References.
Rustlings is the model for the first track: a small exercise should force one
syntax or type boundary to become visible before the learner moves on. The
official Rustlings usage guide
describes a command-line loop where learners repair exercises that often fail
to compile or have tests that need to pass. The
community-exercise guide
explains how focused exercise packs can target one topic. This project keeps
that spirit but changes the domain: the failure should teach an AI boundary
such as usize versus TokenId or Logits versus Distribution.
Typed AI Rustlings
The first seed exercises live under challenges/typed-ai-rustlings/.
The exercise files are intentionally broken. They are not part of the normal
Cargo build. The reference solutions are compiled by cargo test --test challenge_typed_ai.
Start with:
token_id_not_usize
logits_are_not_probabilities
The point is not to memorize types. The point is to feel the compiler reject a bad AI boundary:
usize is not TokenId
Logits are not Distribution
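Newtypes and smart constructors reproduce both rejections. This is a minimal sketch with hypothetical miniatures of TokenId, Logits, and Distribution; they are not the definitions in src/domain.rs or src/ml.rs. It uses `bool::then_some`, which has been available since Rust 1.62.

```rust
/// A token index that has been checked against the vocabulary size.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct TokenId(usize);

impl TokenId {
    pub fn new(raw: usize, vocab_size: usize) -> Option<Self> {
        (raw < vocab_size).then_some(TokenId(raw))
    }
    pub fn index(self) -> usize {
        self.0
    }
}

/// Raw, unnormalized scores.
pub struct Logits(pub Vec<f64>);

/// Non-negative values that sum to one; only `new` can build one.
#[derive(Debug, PartialEq)]
pub struct Distribution(Vec<f64>);

impl Distribution {
    pub fn new(values: Vec<f64>) -> Option<Self> {
        let sum: f64 = values.iter().sum();
        let valid = !values.is_empty()
            && values.iter().all(|v| *v >= 0.0)
            && (sum - 1.0).abs() < 1e-9;
        valid.then_some(Distribution(values))
    }
    pub fn probabilities(&self) -> &[f64] {
        &self.0
    }
}

/// The only route from Logits to Distribution is an explicit softmax.
pub fn softmax(logits: &Logits) -> Option<Distribution> {
    let max = logits.0.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.0.iter().map(|x| (x - max).exp()).collect();
    let total: f64 = exps.iter().sum();
    Distribution::new(exps.iter().map(|e| e / total).collect())
}
```

With these types, a function that takes `Distribution` will not accept `Logits`, and a function that takes `TokenId` will not accept a raw `usize`. Both mistakes become compiler errors. The logits `[2.0, -1.0]` sum to one, yet `Distribution::new` still rejects them because of the negative entry.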
Paper-To-Rust
The first Paper-To-Rust challenge compiles one idea from Adam:
optimizer memory is part of optimizer state
The source paper is Adam: A Method for Stochastic Optimization. The challenge does not claim to reproduce the whole optimizer paper. It uses one narrow teaching claim from the paper: Adam carries adaptive estimates of lower-order moments, so the optimizer update must move state as well as parameters.
Run:
cargo run --example challenge_adam
cargo test --test paper_to_rust_adam
The typed shape is:
AdamModelState -> AdamModelState
That shape matters because a real Adam update carries first moment, second moment, and step count forward with the parameters.
Worked Paper-To-Rust Ledger: Adam
Use the Adam challenge as the model for future paper-to-code exercises.
The source-backed claim is narrow:
Adam uses adaptive estimates of first and second gradient moments.
Therefore the optimizer boundary must carry optimizer memory forward.
That claim is supported by two sources. The Adam paper describes adaptive estimates of lower-order moments. The official PyTorch Adam documentation presents the first moment, second moment, step count, and optimizer state as part of the optimizer update and state dictionary.
The challenge does not ask you to build a production optimizer. It asks you to compile one boundary:
source claim -> Rust boundary -> invariant -> test signal
| Source idea | Rust boundary | Invariant | Evidence |
|---|---|---|---|
| Adam keeps a first-moment estimate of gradients. | AdamFirstMoment inside AdamOptimizerState | same dimension as parameters | adam_step_preserves_complete_optimizer_state |
| Adam keeps a second-moment estimate of squared gradients. | AdamSecondMoment inside AdamOptimizerState | same dimension as parameters | adam_first_step_matches_bias_corrected_update |
| Bias correction depends on the update count. | AdamStepCount | step count advances with every update | step count: 1 in cargo run --example challenge_adam |
| The optimizer update consumes a gradient and returns complete state. | AdamTrainStep : AdamModelState -> AdamModelState | gradient dimension matches model state dimension | adam_rejects_gradient_dimension_mismatch |
Read the tiny update in Rust as four movements:
previous first moment + gradient -> next first moment
previous second moment + squared gradient -> next second moment
next moments + step count -> bias-corrected moments
parameters + corrected moments -> updated parameters
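The four movements map line by line onto the standard Adam update. This is a minimal sketch with hypothetical `AdamState`, `AdamConfig`, and `DimensionMismatch` types, not the repository's AdamModelState and AdamTrainStep. The sketch takes the gradient as a separate argument, so its own shape is state x gradient -> state. Fixing how the gradient is supplied is what yields the AdamModelState -> AdamModelState view the chapter names.

```rust
/// Parameters plus Adam's optimizer memory.
#[derive(Clone, Debug, PartialEq)]
pub struct AdamState {
    pub parameters: Vec<f64>,
    pub first_moment: Vec<f64>,
    pub second_moment: Vec<f64>,
    pub step: u32,
}

impl AdamState {
    /// Zero moments and step 0, as in the Adam paper's initialization.
    pub fn new(parameters: Vec<f64>) -> Self {
        let n = parameters.len();
        AdamState { parameters, first_moment: vec![0.0; n], second_moment: vec![0.0; n], step: 0 }
    }
}

/// The paper suggests 0.001, 0.9, 0.999, and 1e-8 as defaults.
pub struct AdamConfig {
    pub learning_rate: f64,
    pub beta1: f64,
    pub beta2: f64,
    pub epsilon: f64,
}

#[derive(Debug, PartialEq)]
pub struct DimensionMismatch {
    pub expected: usize,
    pub found: usize,
}

pub fn adam_step(
    state: &AdamState,
    gradient: &[f64],
    config: &AdamConfig,
) -> Result<AdamState, DimensionMismatch> {
    let n = state.parameters.len();
    for len in [gradient.len(), state.first_moment.len(), state.second_moment.len()] {
        if len != n {
            return Err(DimensionMismatch { expected: n, found: len });
        }
    }
    let step = state.step + 1;
    let bias1 = 1.0 - config.beta1.powi(step as i32);
    let bias2 = 1.0 - config.beta2.powi(step as i32);
    let mut next = AdamState {
        parameters: Vec::with_capacity(n),
        first_moment: Vec::with_capacity(n),
        second_moment: Vec::with_capacity(n),
        step,
    };
    for i in 0..n {
        let g = gradient[i];
        // previous first moment + gradient -> next first moment
        let m = config.beta1 * state.first_moment[i] + (1.0 - config.beta1) * g;
        // previous second moment + squared gradient -> next second moment
        let v = config.beta2 * state.second_moment[i] + (1.0 - config.beta2) * g * g;
        // next moments + step count -> bias-corrected moments
        let (m_hat, v_hat) = (m / bias1, v / bias2);
        // parameters + corrected moments -> updated parameters
        next.parameters
            .push(state.parameters[i] - config.learning_rate * m_hat / (v_hat.sqrt() + config.epsilon));
        next.first_moment.push(m);
        next.second_moment.push(v);
    }
    Ok(next)
}
```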
The category-theory shape is deliberately modest. The challenge uses
AdamTrainStep as an endomorphism on AdamModelState because the public
input and output object are the same complete state. It does not claim that
this small file proves convergence, covers every Adam variant, or replaces the
paper’s analysis.
Challenge Evidence And Textbook Feedback
Challenge completions are useful because they show whether the practice loop works. They are not automatically accepted reader reports for the textbook.
A completion becomes useful feedback when it includes:
challenge tried
command, output line, compiler error, or test name
AI boundary that became clearer
first unclear point, or none
smallest useful fix for the next reader
If the first unclear point belongs to a chapter, example, table, exercise, or command, open the closest reader report from the public review guide too. The book improves fastest from a precise blocked learning step, not from broad approval that a challenge was interesting.
Contribute A Challenge
For Typed AI Rustlings, contribute one small compiler-fix exercise.
For Paper-To-Rust, choose one paper and compile one idea. Do not submit a whole-paper summary.
Use this shape:
paper claim -> Rust type -> invariant -> test -> executable example
Before adding a challenge, name the source that owns the claim and the Rust file that owns the executable boundary. The public solution should be small enough that a reader can inspect it completely.
Where This Leaves Us
The challenge track turns the book’s recurring method into public practice:
name the boundary
make the failure visible
repair the smallest Rust shape
record the evidence
The next reference tool is the Glossary. Use it when a challenge exposes a term you can run but cannot yet explain. For example:
TokenId
Distribution
endomorphism
optimizer state
evidence signal
If the term is still unclear after the glossary, report the exact command, output line, compiler error, or table row that exposed the confusion.
Further Reading
Use these sources only after you have run at least one challenge command:
| Source | Use it to clarify | Bring it back to this evidence |
|---|---|---|
| Rustlings Usage | why a small exercise can be built around a compiler error or failing test | one Typed AI Rustlings exercise file |
| Rustlings Community Exercises | how a focused exercise pack can target one topic | challenges/typed-ai-rustlings/metadata.toml |
| Adam | why Adam carries first-moment and second-moment estimates | AdamFirstMoment, AdamSecondMoment, and AdamStepCount |
| PyTorch Adam | how a production optimizer exposes state and step() | AdamModelState -> AdamModelState |
| Rust Book: Writing Automated Tests | why tests are part of the learning artifact | cargo test --test challenge_typed_ai and cargo test --test paper_to_rust_adam |
The safe reading rule is:
read one source -> improve one challenge boundary -> run one command
Do not use a source link as proof that the challenge is correct. Use the source to refine the local Rust claim, then use a command or test as local evidence.
Practice After This Chapter
Run one command from each track:
cargo test --test challenge_typed_ai
cargo run --example challenge_adam
Then fill this challenge evidence card:
challenge tried:
command:
visible evidence signal:
AI boundary that became clearer:
first unclear point, or none:
smallest useful fix for the next reader:
For a Typed AI Rustlings challenge, the evidence signal should be a compiler error, type mismatch, test name, or solution test.
For a Paper-To-Rust challenge, the evidence signal should be a source claim, Rust type, invariant, and passing test or output line.
Retrieval Practice
Recall
Name the two challenge tracks without looking back.
Name one boundary that Typed AI Rustlings should make visible.
Name the complete Adam challenge shape.
Explain
Explain why Logits should not be accepted where a Distribution is required.
Explain why Adam’s first moment, second moment, and step count belong to the optimizer state instead of being loose helper values.
Explain why challenge completion evidence is not automatically accepted textbook reader evidence.
Apply
Choose one paper, tutorial, or framework documentation page and write only this much:
source claim:
Rust boundary:
invariant:
test or output evidence:
larger claim not implemented:
The answer is strong only if the Rust boundary is small enough for another reader to inspect completely.
Debug
For each weak challenge design, name the missing piece:
1. The challenge links to a paper but names no Rust type.
2. The exercise fails, but the failure does not teach an AI boundary.
3. The Adam challenge updates parameters but drops moment state.
4. The completion report says "I liked it" but gives no output line.
5. The challenge claims to implement a whole paper from one small test.
A useful answer should say whether the problem is a missing source claim, missing Rust boundary, missing invariant, missing evidence signal, or overclaim.
Glossary
The problem this chapter solves is:
Abstract terms are easier to remember when each term is tied to a Rust type, an ML role, and a category-theory shape.
Use this glossary as a lookup table while reading the source snapshots.
Do not read it as a separate dictionary. Each entry is deliberately anchored to the codebase. If a definition sounds abstract, jump from the term to the Rust syntax and then back to the chapter where the type or trait appears.
Reader orientation: The glossary uses compact entries, but the entries still follow the book’s main discipline: first the Rust handle, then the ML or software role, then the categorical shape.
How To Use This Glossary
Use each entry as a bridge, not as a final definition.
term -> Rust handle -> ML or software role -> category-theory shape
If a term has no Rust handle in this repository, it is not a core term for this book yet. The goal is not to collect impressive vocabulary. The goal is to make the vocabulary already used by the chapters easier to retrieve and transfer.
When a term appears in a chapter, ask:
What value, function, trait, constructor, method, test, or command makes this
term concrete?
That question keeps the glossary grounded.
Source-Backed Recovery Rules
Use this section when a term feels impressive but not usable yet. The glossary is strongest when a definition can be recovered through four anchors:
term -> source anchor -> Rust evidence -> learner evidence signal
The outside source gives the term a trustworthy boundary. The repository evidence shows the smaller claim this book actually makes. The learner evidence signal tells you what to run, inspect, or explain before moving on.
| If this term family is unclear | Source anchor | Local evidence | Learner evidence signal |
|---|---|---|---|
| domain object, invariant, smart constructor | Rust structs, Rust enums, Rust API Guidelines, recoverable Result | TokenId, TokenSequence::new, Distribution::new, Loss::new, and LearningRate::new in src/domain.rs | cargo run --example 01_domain_objects; cargo test domain::tests; explain which invalid state a constructor rejects |
| morphism, identity, composition | Rust traits, Rust generics, Seven Sketches, Category Theory for Programming | Morphism<Input, Output>, Identity<T>, and Compose<F, G, Middle> in src/category.rs | cargo run --example 02_morphism_composition; cargo test category::tests; name the middle object that makes composition legal |
| logits, distribution, cross entropy, loss | Dive into Deep Learning: Softmax Regression, CS231n Linear Classification, PyTorch CrossEntropyLoss, On Calibration of Modern Neural Networks | Logits -> Distribution -> Product<Distribution, TokenId> -> Loss in src/ml.rs | cargo run --bin category_ml; cargo test ml::tests; point to the line where target token and prediction meet, then explain why normalized probability is not automatically calibrated confidence |
| training step, parameters, endomorphism | Backprop as Functor, D2L Backpropagation, PyTorch optimizers | TrainStep : Parameters -> Parameters in src/training.rs and TransformerTrainingState -> TransformerTrainingState in src/attention.rs | cargo run --example 03_training_endomorphism; cargo run --example 07_transformer_training_state; separate measurement from update |
| functor, naturality, monoid, chain rule | Categories for the Working Mathematician, Seven Sketches, Category Theory for Programming, D2L Computational Graphs | VecFunctor, OptionFunctor, first_or_none_naturality_square, PipelineTrace, and MulOp::backward in src/structure.rs and src/calculus.rs | cargo run --example 04_structure_and_calculus; cargo test structure::tests --lib; cargo test calculus::tests --lib; explain which small law or local derivative the output checks, and which formal claim the local tests do not prove |
| query, key, value, mask, attention weights | Attention Is All You Need, PyTorch MultiheadAttention, PyTorch Transformer, PyTorch scaled dot product attention, Hugging Face Transformer course | QuerySequence, KeySequence, ValueSequence, AttentionMask, AttentionScores, and AttentionWeights in src/attention.rs | cargo run --example 06_attention_scores; cargo test attention::tests; explain why the mask is applied before softmax and why this book’s true -> allowed polarity must not be confused with APIs where true -> blocked |
| fixed module instance, parameter context, training-state update | D2L Parameter Management, PyTorch optimizers, Rust Book closures | LayerNormalization, PositionWiseFeedForward, and TransformerTrainingState in src/attention.rs | cargo run --example 07_transformer_training_state; explain why a forward sublayer is an endomorphism only for a fixed module value, while parameter changes belong to training state |
| finite difference, gradient check, local update evidence | CS231n numerical gradients, CS231n Neural Networks Part 3, PyTorch gradcheck | finite-difference tests for transformer readout, feed-forward, layer norm, attention projection, and block updates in src/attention.rs | run cargo test attention::tests::transformer_block_train_step_matches_finite_difference_for_readout_weight; state that this is local evidence, not a proof of all training |
| challenge completion, evidence signal, Paper-To-Rust, optimizer state | Rustlings Usage, Rustlings Community Exercises, Adam, PyTorch Adam, Rust Book tests | challenges/typed-ai-rustlings/, src/challenges/papers/adam.rs, examples/challenge_adam.rs, tests/challenge_typed_ai.rs, and tests/paper_to_rust_adam.rs | cargo test --test challenge_typed_ai; cargo run --example challenge_adam; explain the source claim, Rust boundary, invariant, and visible compiler, output, or test signal |
| retrieval, transfer, and misconception repair | How People Learn II, Test-Enhanced Learning, worked-example transition | the worked examples, partial examples, common misreadings, and exercise evidence map in this book | recover one term by writing the Rust handle, the protected ML role, and the exact command or test that checks it |
These source anchors do not make the glossary a substitute for the chapters. They protect the smaller local claim:
If a term matters here, the reader should be able to point to code, run a command, inspect a failure signal, or explain a checked boundary.
If you cannot name the Rust handle or evidence signal, treat the term as unrecovered and return to the chapter or source file where it first appears.
Core Term Alignment
Some ideas have a public phrase, a Rust type, and a category-theory reading. Use this table to keep them separate.
| Public phrase | Rust handle | Use this wording when precision matters |
|---|---|---|
| training pairs | TrainingExample values inside TrainingSet | “adjacent input-target pairs” for the examples, TrainingSet for the validated Rust object |
| model state | Parameters | “parameters” when naming the Rust object, “model state” when explaining the ML role |
| probabilities | Distribution | “probabilities” for intuition, Distribution when the constructor invariant matters |
| query sequence | QuerySequence | “queries” for intuition, QuerySequence when the attention role matters |
| key sequence | KeySequence | “keys” for intuition, KeySequence when score construction needs the matching head dimension |
| value sequence | ValueSequence | “values” for intuition, ValueSequence when attention weights need source rows to mix |
| target sequence length | QuerySequence row count | “target length” for intuition, L when contrasting target positions with source positions |
| source sequence length | KeySequence and ValueSequence row count | “source length” for intuition, S when the positions being read may differ from target positions |
| attention score rows | AttentionScores | “scores” for intuition, AttentionScores when the row shape must be validated |
| attention mask | AttentionMask | “allowed positions” for intuition, AttentionMask when illegal score positions must be removed before softmax |
| mask polarity | AttentionMask | “true means allowed in this Rust type” when comparing with framework APIs whose boolean masks may use the opposite convention |
| attention weights | AttentionWeights | “weights” for intuition, AttentionWeights when each query row must sum to one |
| attention output | AttentionOutput | “mixed values” for intuition, AttentionOutput when one output row per query matters |
| head count | HeadCount | “number of heads” for intuition, HeadCount when zero heads must be rejected |
| head outputs | AttentionHeadOutputs | “outputs from several heads” for intuition, AttentionHeadOutputs when all heads must share sequence length and width |
| multi-head output | MultiHeadOutput | “concatenated heads” for intuition, MultiHeadOutput when the combined model dimension matters |
| attention output projection | AttentionOutputProjection | “projection after head concatenation” for intuition, AttentionOutputProjection when matrix shape must be validated |
| projected attention output | ProjectedAttentionOutput | “projected attention sequence” for intuition, ProjectedAttentionOutput when the post-projection width matters |
| hidden sequence | HiddenSequence | “sequence of hidden vectors” for intuition, HiddenSequence when residual shape must be protected |
| hidden-to-query projection | HiddenToQuery | “make query vectors from hidden rows” for intuition, HiddenToQuery when projection shape must be validated |
| hidden-to-key projection | HiddenToKey | “make key vectors from hidden rows” for intuition, HiddenToKey when projection shape must be validated |
| hidden-to-value projection | HiddenToValue | “make value vectors from hidden rows” for intuition, HiddenToValue when projection shape must be validated |
| residual connection | ResidualConnection | “add the sublayer output back” for intuition, ResidualConnection when sequence length and width must match |
| layer normalization | LayerNormalization | “normalize each hidden vector” for intuition, LayerNormalization when feature-wise normalization must preserve shape |
| layer norm parameters | LayerNormParameters | “scale, shift, epsilon” for intuition, LayerNormParameters when parameter dimensions must be validated |
| position-wise feed-forward | PositionWiseFeedForward | “same non-linear map at each sequence position” for intuition, PositionWiseFeedForward when two-layer shape checks must preserve hidden width |
| positional encoding | PositionalEncoding | “add position rows” for intuition, PositionalEncoding when sequence length and model width must be checked |
| self-attention | SelfAttentionHead, MultiHeadTransformerBlock | “same hidden sequence supplies query, key, and value roles” for intuition, self-attention when source ownership matters |
| cross-attention | QuerySequence, KeySequence, ValueSequence | “target sequence reads a separate source sequence” for intuition; the repository names the boundary but does not implement a full cross-attention block yet |
| single-head block | SingleHeadTransformerBlock | “one block-shaped sketch” for intuition, SingleHeadTransformerBlock when the whole boundary should preserve hidden sequence shape |
| self-attention head | SelfAttentionHead | “one query/key/value projection triple” for intuition, SelfAttentionHead when one head’s role dimensions must be validated |
| multi-head block | MultiHeadTransformerBlock | “several heads as one block” for intuition, MultiHeadTransformerBlock when head count and output-projection shape must be validated |
| masked multi-head block | MaskedMultiHeadTransformerBlock | “block with allowed attention positions” for intuition, MaskedMultiHeadTransformerBlock when the mask joins hidden state at the block boundary |
| fixed mask context | AttentionMask selected before a block call | “same mask reused for this run” for intuition, fixed context when an open masked block is viewed as HiddenSequence -> HiddenSequence |
| fixed module instance | LayerNormalization, PositionWiseFeedForward, or a block value with stored parameters | “this specific layer value” for intuition, fixed module instance when a forward call is named HiddenSequence -> HiddenSequence |
| parameter-changing update | TransformerTrainingState | “learning changed the stored parameters” for intuition, training-state endomorphism when scale, shift, weights, biases, learning rate, or step count must stay together |
| sequence logits | SequenceLogits | “vocabulary scores at each sequence position” for intuition, SequenceLogits when sequence length and vocabulary width must be explicit |
| Transformer readout | TransformerReadout | “sequence language-model head” for intuition, TransformerReadout when hidden width and vocabulary width must be validated |
| tiny Transformer parameters | TinyTransformerParameters | “position plus block plus readout” for intuition, TinyTransformerParameters when named model roles should move together |
| Transformer training state | TransformerTrainingState | “parameters plus optimizer metadata” for intuition, TransformerTrainingState when step count and learning rate matter |
| Transformer readout training example | TransformerReadoutTrainingExample | “one fixed hidden sequence with target tokens” for intuition, TransformerReadoutTrainingExample when hidden, mask, and target lengths must match |
| Transformer readout train step | TransformerReadoutTrainStep | “readout-only update” for intuition, TransformerReadoutTrainStep when the state endomorphism matters |
| Transformer feed-forward training example | TransformerFeedForwardTrainingExample | “one hidden-sequence input and target” for intuition, TransformerFeedForwardTrainingExample when feed-forward local training shape must match |
| Transformer feed-forward train step | TransformerFeedForwardTrainStep | “local feed-forward update” for intuition, TransformerFeedForwardTrainStep when the state endomorphism matters |
| Transformer block training example | TransformerBlockTrainingExample | “one sequence-to-token supervised example” for intuition, TransformerBlockTrainingExample when hidden, mask, and target lengths must match |
| Transformer block train step | TransformerBlockTrainStep | “composed readout-plus-feed-forward update” for intuition, TransformerBlockTrainStep when a sequence loss updates more than one parameter group |
| evidence signal | command output, compiler error, test name, constructor result, or table row | “visible evidence” when reporting what happened, evidence signal when the report must point to something inspectable |
| challenge completion evidence | challenge issue fields and challenge commands | “I completed this practice loop” for challenge progress, not “accepted textbook reader feedback” unless it names the first unclear point and smallest useful fix |
| source claim | a narrow statement from a source link | “the outside claim being translated” before Rust code, source claim when a challenge must name what the paper or documentation actually supports |
| Rust boundary | a named type, function, trait, constructor, test, or command | “what the repository actually implements” when separating the local exercise from the larger source |
| optimizer state | AdamOptimizerState, AdamModelState, or TransformerTrainingState | “memory carried between updates” when explaining Adam-style moment estimates, step count, parameters, and learning metadata |
| Typed AI Rustlings | challenges/typed-ai-rustlings/ and tests/challenge_typed_ai.rs | “compiler-fix AI exercise” before abstraction, Typed AI Rustlings when one type mistake is meant to fail visibly |
| Paper-To-Rust | src/challenges/papers/adam.rs, examples/challenge_adam.rs, and tests/paper_to_rust_adam.rs | “compile one paper idea” before abstraction, Paper-To-Rust when a source claim becomes a Rust boundary, invariant, and test signal |
| larger claim not implemented | limitation notes in a chapter or challenge | “what this tiny example does not prove” when keeping a source-backed claim modest |
| typed transformation | Morphism<Input, Output> | “typed transformation” before abstraction, “morphism” once the Rust trait is in view |
| product-input morphism | Product<A, B> at the input boundary | “needs two named inputs” before abstraction, product-input morphism when the arrow shape is A x B -> C |
| update step | TrainStep | “training step” for ML behavior, “endomorphism” for the Parameters -> Parameters shape |
This alignment prevents two common confusions. First, not every prose phrase is a Rust type. Second, not every Rust type is a new mathematical concept. The book uses plain phrases for intuition, Rust names for exact code, and category-theory words only when the shape is visible.
Common Misreadings Index
Use this as a small contrast drill. Each row starts with a sentence that sounds plausible, then puts the corrected boundary next to it. The point is not to memorize the table. The point is to notice which Rust object, ML role, or category-theory shape the misreading erased.
| Plausible misreading | Corrected boundary | Rust evidence | What to say instead |
|---|---|---|---|
| TokenId is just a usize. | TokenId is a domain object for vocabulary positions. | TokenId is a named type consumed by token and embedding stages. | The raw number is local machinery; the boundary value says “vocabulary item.” |
| Logits are probabilities. | Logits -> Distribution is a required stage. | Softmax consumes Logits and produces Distribution. | Scores become probabilities only after row or vocabulary normalization. |
| A normalized softmax probability is calibrated confidence. | Calibration is an empirical reliability claim, not just a Distribution constructor invariant. | Distribution::new validates a local probability vector; calibration needs population-level evidence outside this tiny example. | Say “normalized model probability” unless you have checked empirical calibration. |
| Loss only needs the prediction. | Distribution x TokenId -> Loss is a product-input boundary. | CrossEntropy consumes prediction and target together. | The target token tells the loss which probability to judge. |
| A training step can return changed weights only. | Parameters -> Parameters or TransformerTrainingState -> TransformerTrainingState preserves the next update shape. | TrainStep and Transformer train steps return complete state objects. | The updated object must be ready for the next step without reconstruction. |
| fmap means any function call. | fmap changes inside values while preserving wrapper shape. | VecFunctor::fmap returns Vec<B> and OptionFunctor::fmap returns Option<B>. | The operation maps the contents and keeps the outer structure. |
| Returning the left object makes a boundary an endomorphism. | Count inputs first: A x B -> A is still product-input. If the product is named as one source object, (A x B) -> A is unary from the product but still not an endomorphism. | HiddenSequence x ProjectedAttentionOutput -> HiddenSequence needs two inputs. | A unary endomorphism has shape A -> A; an endomorphism on the product would have shape (A x B) -> (A x B). |
| Self-attention makes Q, K, and V the same role. | Self-attention shares source ownership before projection. | HiddenToQuery, HiddenToKey, and HiddenToValue produce separate role objects. | The same hidden sequence may feed all three projections, but the roles remain distinct. |
| Masking after softmax is equivalent. | AttentionScores x AttentionMask -> AttentionScores -> AttentionWeights. | The mask is applied before AttentionSoftmax. | Illegal positions should not receive probability mass. |
| A masked block is automatically an endomorphism because it returns HiddenSequence. | MaskedMultiHeadTransformerBlock : HiddenSequence x AttentionMask -> HiddenSequence while the mask is open. | The block consumes AttentionMask at the boundary. | Keep the mask visible, or explicitly say a fixed mask induces a HiddenSequence -> HiddenSequence view for that run. |
| A layer endomorphism means the parameters are not part of the story. | LayerNormalization : HiddenSequence -> HiddenSequence is a forward call for one fixed layer value; parameter learning is TransformerTrainingState -> TransformerTrainingState. | LayerNormalization stores scale and shift; train steps return full TransformerTrainingState. | Fixed module context makes a forward endomorphism; changing parameters moves the boundary to training state. |
| MultiHeadOutput can be added directly to HiddenSequence. | MultiHeadOutput -> ProjectedAttentionOutput must happen first. | ResidualConnection expects projected model-width rows. | Concatenated heads must return to model width before residual addition. |
| One finite-difference match proves training is correct. | A finite-difference check is local evidence for one selected parameter path. | Tests compare one inferred update gradient with one numerical slope. | The check supports the local implementation; it does not prove every parameter, dataset, or optimizer. |
| Challenge completion means the textbook section is clear. | Challenge completion is practice evidence; textbook feedback needs the first unclear point or an explicit “none.” | Challenge completion issues ask for evidence, lesson learned, first unclear point, and smallest useful fix. | Say “the challenge ran” for completion; say “this section became clearer because…” for reader feedback. |
| Paper-To-Rust means reimplement the whole paper. | Paper-To-Rust compiles one source claim into one Rust boundary, invariant, and test signal. | The Adam challenge uses AdamModelState -> AdamModelState for optimizer memory. | Keep the source claim narrow, then name the larger claim not implemented. |
When one of these misreadings appears in your own answer, repair it with three questions:
Which object did I erase?
Which ML or software role did that object protect?
Which category-theory shape did I name too early or too loosely?
Category-Theory Terms
Object
Rust syntax:
TokenId
Vector
Logits
Distribution
Loss
Parameters
ML concept:
An object is one kind of value in the pipeline, such as a token, vector, probability distribution, loss, or model state.
Category theory concept:
An object is something a morphism can start from or end at.
First-principles reading:
An object is the kind of thing an arrow is allowed to receive or return. In this book, TokenId and Vector are different objects because the pipeline should not confuse a vocabulary index with a dense numeric representation.
Morphism
Rust syntax:
pub trait Morphism<Input, Output>
ML concept:
A morphism is one transformation stage, such as embedding lookup or softmax.
Category theory concept:
A morphism is a typed arrow:
Input -> Output
First-principles reading:
In this book, “morphism” usually means “a named transformation with an input type, an output type, and a possible typed error.” The abstract name is useful only because the Rust code makes the boundary inspectable.
Identity Morphism
Rust syntax:
Identity<T>
ML concept:
Identity is a stage that leaves a value unchanged. It is useful for checking the identity laws: composing a stage with Identity on either side should not change that stage's result.
Category theory concept:
Every object has an identity arrow:
id_A : A -> A
Composition
Rust syntax:
Compose<F, G, Middle>
ML concept:
Composition connects stages:
Embedding then LinearToLogits then Softmax
Category theory concept:
If:
f : A -> B
g : B -> C
then:
g after f : A -> C
First-principles reading:
Composition is the reason the middle type matters. If the first stage produces Vector, the next stage must accept Vector. A compiler error at this point is useful evidence: the pipeline is missing or misordering a stage.
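A minimal sketch of how Morphism, Identity, and Compose could fit together. It assumes a String error type instead of the repository's CtError, and an apply method name that may differ from the real trait. ParseIndex and Double are hypothetical stages. The Middle type parameter forces the output of the first stage to match the input of the second.

```rust
use std::marker::PhantomData;

/// A typed arrow Input -> Output that may fail with an error.
/// Assumption: String stands in for the repository's CtError.
pub trait Morphism<Input, Output> {
    fn apply(&self, input: Input) -> Result<Output, String>;
}

/// id_A : A -> A returns its input unchanged.
pub struct Identity<T>(PhantomData<T>);

impl<T> Identity<T> {
    pub fn new() -> Self {
        Identity(PhantomData)
    }
}

impl<T> Morphism<T, T> for Identity<T> {
    fn apply(&self, input: T) -> Result<T, String> {
        Ok(input)
    }
}

/// "second after first": the Middle type must be produced by first and accepted by second.
pub struct Compose<F, G, Middle> {
    first: F,
    second: G,
    _middle: PhantomData<Middle>,
}

impl<F, G, Middle> Compose<F, G, Middle> {
    pub fn new(first: F, second: G) -> Self {
        Compose { first, second, _middle: PhantomData }
    }
}

impl<A, B, C, F, G> Morphism<A, C> for Compose<F, G, B>
where
    F: Morphism<A, B>,
    G: Morphism<B, C>,
{
    fn apply(&self, input: A) -> Result<C, String> {
        let middle = self.first.apply(input)?; // an error stops the pipeline here
        self.second.apply(middle)
    }
}

/// Hypothetical stage: read a non-negative index from text.
pub struct ParseIndex;

impl Morphism<String, usize> for ParseIndex {
    fn apply(&self, input: String) -> Result<usize, String> {
        input.trim().parse::<usize>().map_err(|e| e.to_string())
    }
}

/// Hypothetical stage: double an index.
pub struct Double;

impl Morphism<usize, usize> for Double {
    fn apply(&self, input: usize) -> Result<usize, String> {
        Ok(input * 2)
    }
}
```

Writing Compose::<Double, ParseIndex, usize> instead would not compile as a morphism: Double produces usize, while ParseIndex expects String.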
Product Object
Rust syntax:
Product<A, B>
ML concept:
A product stores paired values, such as:
input token x target token
prediction distribution x target token
Category theory concept:
The product object is written:
A x B
Its projections correspond to first() and second().
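A small sketch of a product object and its two projections. It assumes first() and second() return references, and the repository's exact signatures may differ. With this definition, TrainingExample is simply a named product of two TokenId values.

```rust
/// A x B, stored as a pair with named projections.
#[derive(Debug, Clone, PartialEq)]
pub struct Product<A, B> {
    first: A,
    second: B,
}

impl<A, B> Product<A, B> {
    pub fn new(first: A, second: B) -> Self {
        Product { first, second }
    }

    /// Projection onto A.
    pub fn first(&self) -> &A {
        &self.first
    }

    /// Projection onto B.
    pub fn second(&self) -> &B {
        &self.second
    }
}

#[derive(Debug, Clone, Copy, PartialEq)]
pub struct TokenId(pub usize);

/// input token x target token
pub type TrainingExample = Product<TokenId, TokenId>;
```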
Product-Input Morphism
Rust syntax:
Product<A, B> -> C
ScaledDotProductScores : QuerySequence x KeySequence -> AttentionScores
WeightedValueMixing : AttentionWeights x ValueSequence -> AttentionOutput
ML or software concept:
Some transformations need two meaningful inputs at the boundary. Attention scoring needs target-side queries and source-side keys. Value mixing needs attention weights and the source values being mixed.
Category theory concept:
A product-input morphism has a product object as its input:
A x B -> C
First-principles reading:
Do not erase the product just because the output has a familiar type. The product names the fact that two inputs must agree before the transformation is legal.
Law
Rust syntax:
assert_eq!(...)
information_order_obeys_preorder_laws()
pipeline_trace_obeys_monoid_laws()
ML or software concept:
A law is expected behavior that should keep working after implementation details change.
Category theory concept:
A law states the structure a model must preserve, such as identity, associativity, reflexivity, transitivity, or composition preservation.
First-principles reading:
A law is not decoration. In this repository, a law should have a nearby test or check. Otherwise the reader has no executable reason to trust the word.
Endomorphism
Rust syntax:
Endomorphism<T>
TrainStep : Parameters -> Parameters
ML concept:
A training step updates parameters and returns parameters again.
Category theory concept:
An endomorphism is an arrow from an object back to itself:
A -> A
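A hedged sketch of the endomorphism shape. It uses a hypothetical one-weight TrainingState and the toy loss L(w) = (w - target)^2 rather than the repository's Parameters. The step returns the full state, including a step counter, so it can be repeated with a plain loop and no reconstruction between steps.

```rust
/// A -> A: the output has the input's type, so it can be fed straight back in.
pub trait Endomorphism<T> {
    fn apply(&self, value: T) -> T;
}

/// Hypothetical training state: one weight plus a step counter.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct TrainingState {
    pub weight: f32,
    pub step: u32,
}

/// Hypothetical train step for the toy loss L(w) = (w - target)^2.
pub struct TrainStep {
    pub learning_rate: f32,
    pub target: f32,
}

impl Endomorphism<TrainingState> for TrainStep {
    fn apply(&self, state: TrainingState) -> TrainingState {
        let grad = 2.0 * (state.weight - self.target); // dL/dw
        TrainingState {
            weight: state.weight - self.learning_rate * grad,
            step: state.step + 1, // metadata travels with the parameters
        }
    }
}

/// Repeating an endomorphism needs no glue code between steps.
pub fn iterate<T, E: Endomorphism<T>>(e: &E, mut value: T, times: usize) -> T {
    for _ in 0..times {
        value = e.apply(value);
    }
    value
}
```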
Functor
Rust syntax:
Functor<A, B>
VecFunctor
OptionFunctor
ML concept:
Apply a transformation inside a wrapper such as a batch or optional value.
Category theory concept:
A functor maps objects and arrows while preserving structure.
First-principles reading:
For this book, the simplest functor intuition is map: apply a function inside a context without destroying the context. VecFunctor preserves the list shape. OptionFunctor preserves the difference between Some and None.
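A sketch of the two functors as plain associated functions. It assumes a simpler shape than the repository's Functor<A, B> trait. vec_functor_laws_hold is a hypothetical helper that checks the two functor laws on one sample input: mapping the identity changes nothing, and mapping f and then g equals mapping their composite.

```rust
pub struct VecFunctor;

impl VecFunctor {
    /// Apply f to every element; the result has the same length.
    pub fn fmap<A, B>(values: Vec<A>, f: impl Fn(A) -> B) -> Vec<B> {
        values.into_iter().map(f).collect()
    }
}

pub struct OptionFunctor;

impl OptionFunctor {
    /// Apply f inside Some; None stays None.
    pub fn fmap<A, B>(value: Option<A>, f: impl Fn(A) -> B) -> Option<B> {
        value.map(f)
    }
}

/// Checks the identity and composition laws for VecFunctor on one input.
pub fn vec_functor_laws_hold(xs: Vec<i32>) -> bool {
    let f = |x: i32| x + 1;
    let g = |x: i32| x * 10;
    let identity_law = VecFunctor::fmap(xs.clone(), |x| x) == xs;
    let composition_law = VecFunctor::fmap(VecFunctor::fmap(xs.clone(), f), g)
        == VecFunctor::fmap(xs, |x| g(f(x)));
    identity_law && composition_law
}
```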
Functor Map
Rust syntax:
fn map<U>(self, f: impl Fn(T) -> U) -> Distribution<U>
ML concept:
For a probabilistic output, map transforms every possible outcome while leaving the attached probabilities unchanged.
Category theory concept:
map lifts a deterministic function:
T -> U
into a context-aware transformation:
Distribution<T> -> Distribution<U>
Natural Transformation
Rust syntax:
VecToFirstOption : Vec<A> -> Option<A>
ML concept:
Convert one container shape into another consistently, such as many candidates to maybe one selected candidate.
Category theory concept:
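A natural transformation converts one functor shape into another and commutes with mapping: for every f : A -> B, converting and then mapping gives the same result as mapping and then converting.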
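One way to check the naturality square is to compare two Rust expressions on a chosen input: select the first element and then map, versus map and then select the first element. This sketch assumes a hypothetical free function vec_to_first_option in place of the repository's VecToFirstOption type. A passing check for one f and one input is local evidence, not a proof for every function.

```rust
/// Vec<A> -> Option<A>: keep the first element if there is one.
pub fn vec_to_first_option<A>(values: Vec<A>) -> Option<A> {
    values.into_iter().next()
}

/// Compares both routes around the square for one f and one input.
pub fn naturality_square_holds<A: Clone, B: PartialEq>(xs: Vec<A>, f: impl Fn(A) -> B) -> bool {
    // Route 1: map inside Vec, then select the first element.
    let map_then_select: Option<B> = vec_to_first_option(xs.iter().cloned().map(&f).collect());
    // Route 2: select the first element, then map inside Option.
    let select_then_map: Option<B> = vec_to_first_option(xs).map(&f);
    map_then_select == select_then_map
}
```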
Monoid
Rust syntax:
PipelineTrace
Monoid::empty()
Monoid::combine()
ML concept:
Traces, logs, batches, and metric accumulators often need an empty value and a combine operation.
Category theory concept:
A monoid has an identity element and an associative binary operation.
First-principles reading:
A monoid is the structure behind “start empty, then combine many pieces.” That is why traces, logs, resource bundles, and accumulated updates are good software examples.
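A minimal sketch of the monoid shape for a trace. It assumes PipelineTrace wraps a list of stage names, and the repository's field layout may differ. combine_all expresses "start empty, then combine many pieces" as a fold, which is well defined in any order of grouping because combine is associative.

```rust
pub trait Monoid {
    fn empty() -> Self;
    fn combine(self, other: Self) -> Self;
}

/// Hypothetical layout: an ordered list of stage names.
#[derive(Debug, Clone, PartialEq)]
pub struct PipelineTrace(pub Vec<String>);

impl Monoid for PipelineTrace {
    /// The identity element: a trace with no stages.
    fn empty() -> Self {
        PipelineTrace(Vec::new())
    }

    /// Concatenation: associative, with empty() on either side changing nothing.
    fn combine(mut self, other: Self) -> Self {
        self.0.extend(other.0);
        self
    }
}

/// Start from empty() and combine every piece left to right.
pub fn combine_all<M: Monoid>(pieces: Vec<M>) -> M {
    pieces.into_iter().fold(M::empty(), M::combine)
}
```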
Preorder
Rust syntax:
InformationLevel::can_flow_to
ML or software concept:
Information can flow from observation to feature to score to decision.
Category theory concept:
A preorder is reflexive and transitive.
First-principles reading:
In code, a preorder often appears as a “can flow to,” “can supply,” or “is no more than” relation. The important part is not sorting. The important part is that repeated comparisons remain coherent.
Galois Connection
Rust syntax:
abstract_to_layer_budget
concretize_layer_budget
ML or software concept:
Concrete feature counts and abstract layer budgets can be coordinated.
Category theory concept:
Two order-preserving views are connected by a law:
abstract(x) <= y iff x <= concretize(y)
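A hypothetical instance of the law. It assumes each abstract layer holds a fixed number of features, and FEATURES_PER_LAYER = 4 is an illustrative constant, not a value from the repository. abstract_to_layer_budget rounds up to the smallest sufficient layer count, concretize_layer_budget returns the feature capacity of a layer budget, and galois_law_holds checks the "iff" for one pair.

```rust
/// Illustrative constant: features one abstract layer can hold.
pub const FEATURES_PER_LAYER: u32 = 4;

/// Concrete feature count -> smallest layer budget that holds it (ceiling division).
/// Sketch only: intended for small counts, not values near u32::MAX.
pub fn abstract_to_layer_budget(features: u32) -> u32 {
    (features + FEATURES_PER_LAYER - 1) / FEATURES_PER_LAYER
}

/// Layer budget -> largest feature count it can hold.
pub fn concretize_layer_budget(layers: u32) -> u32 {
    layers * FEATURES_PER_LAYER
}

/// abstract(x) <= y iff x <= concretize(y), checked for one pair.
pub fn galois_law_holds(features: u32, layers: u32) -> bool {
    (abstract_to_layer_budget(features) <= layers) == (features <= concretize_layer_budget(layers))
}
```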
Monoidal Preorder
Rust syntax:
ResourceBundle::tensor
ResourceBundle::can_supply
ML or software concept:
Independent compute and memory resources can be combined.
Category theory concept:
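A preorder with a tensor operation and a unit element, where tensor is associative and unital and preserves order: if x1 <= y1 and x2 <= y2, then tensor(x1, x2) <= tensor(y1, y2).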
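A sketch that assumes ResourceBundle has two components, compute and memory, combined by componentwise addition. The repository's fields may differ. Here can_supply plays the role of the order, and the monotonicity condition reads: if a supplies b and c supplies d, then a tensor c supplies b tensor d.

```rust
/// Hypothetical two-component resource bundle.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ResourceBundle {
    pub compute: u32,
    pub memory: u32,
}

impl ResourceBundle {
    /// Unit element: combining with it changes nothing.
    pub fn unit() -> Self {
        ResourceBundle { compute: 0, memory: 0 }
    }

    /// Combine independent resources by adding each component.
    pub fn tensor(self, other: Self) -> Self {
        ResourceBundle {
            compute: self.compute + other.compute,
            memory: self.memory + other.memory,
        }
    }

    /// self can supply need if it is at least as large in every component.
    pub fn can_supply(self, need: Self) -> bool {
        self.compute >= need.compute && self.memory >= need.memory
    }
}
```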
Profunctor
Rust syntax:
FeasibilityRelation::relates(requirement, offer)
ML or software concept:
A requirement and implementation offer are related if constraints are satisfied.
Category theory concept:
A profunctor generalizes a relation between two categories. This book uses a small Bool-valued relation as the practical handle: each requirement-offer pair is either feasible or not.
Functorial Semantics
Rust syntax:
SignalMatrix::compose_after
ML or software concept:
Composed signal-flow stages should have the same meaning as composing their matrix interpretations.
Category theory concept:
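Interpretation preserves composition and identities: interpreting g after f gives the same result as composing the interpretations of g and f.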
Open System
Rust syntax:
OpenCircuit
OpenCircuit::then
OpenCircuit::parallel
ML or software concept:
A component has an external interface plus internal implementation details.
Category theory concept:
An open system composes through typed boundaries.
Commutative Diagram
Rust syntax:
composed_and_direct_prediction_match()
naturality_square_commutes()
ML or software concept:
Two different implementation paths should produce the same result.
Category theory concept:
A commutative diagram says that following one route through a diagram has the same meaning as following another route with the same start and end.
First-principles reading:
In this book, do not imagine a diagram first. Imagine two Rust expressions that should agree. The diagram is the picture of that agreement.
Sheaf-Style Locality
Rust syntax:
SafetyCover::global_truth
ML or software concept:
Local safety checks over time intervals combine into a global safety result.
Category theory concept:
Local facts can determine a global fact when they glue coherently.
Boundary
Rust syntax:
Distribution::new
TrainingSet::new
SignalMatrix::compose_after
OpenCircuit::then
ML or software concept:
A boundary is where invalid structure should be rejected before it spreads through the pipeline.
Category theory concept:
A boundary protects the intended object, morphism, relation, or composition from accepting values outside its domain.
First-principles reading:
Many exercises ask what a type or method prevents. That is a boundary question. Good boundaries make wrong connections hard to express.
Rust Terms
Newtype
Rust syntax:
pub struct TokenId(usize);
ML concept:
The same raw number type can represent different concepts. Newtypes prevent accidental mixing.
Category theory concept:
A newtype names a specific object instead of treating all raw representations as the same object.
First-principles reading:
A newtype is the smallest move from “just data” to “data with a role.” The runtime representation can stay cheap, but the type checker now knows that a token id, vocabulary size, and model dimension are not the same concept.
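A minimal newtype sketch. VocabSize, ModelDim, and token_in_vocab are hypothetical names for illustration. #[repr(transparent)] guarantees that each wrapper has the same layout as the usize it wraps, so the named role costs nothing at runtime.

```rust
/// A vocabulary position.
#[repr(transparent)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TokenId(usize);

/// Hypothetical: number of entries in the vocabulary.
#[repr(transparent)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct VocabSize(usize);

/// Hypothetical: width of hidden vectors.
#[repr(transparent)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ModelDim(usize);

impl TokenId {
    pub fn new(index: usize) -> Self { TokenId(index) }
    pub fn get(self) -> usize { self.0 }
}

impl VocabSize {
    pub fn new(size: usize) -> Self { VocabSize(size) }
    pub fn get(self) -> usize { self.0 }
}

impl ModelDim {
    pub fn new(width: usize) -> Self { ModelDim(width) }
    pub fn get(self) -> usize { self.0 }
}

/// Only a TokenId and a VocabSize are accepted here.
/// token_in_vocab(ModelDim::new(8), VocabSize::new(5)) is a compile error.
pub fn token_in_vocab(token: TokenId, vocab: VocabSize) -> bool {
    token.get() < vocab.get()
}
```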
Smart Constructor
Rust syntax:
pub fn new(value: Raw) -> CtResult<Self>
ML concept:
Invalid training inputs, probabilities, dimensions, or hyperparameters should be rejected early.
Category theory concept:
A smart constructor maps raw data into a validated subobject, using Result when the mapping can fail.
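A hedged sketch of a smart constructor that returns CtResult. It assumes a hypothetical InvalidLearningRate variant on CtError, since the repository's actual error variants are not shown here. The only way to obtain a LearningRate is through new, so every LearningRate value is finite and positive.

```rust
/// Hypothetical error enum; the repository's CtError has its own variants.
#[derive(Debug, Clone, PartialEq)]
pub enum CtError {
    InvalidLearningRate(f32),
}

pub type CtResult<T> = Result<T, CtError>;

/// The field is private: construction must go through new().
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct LearningRate(f32);

impl LearningRate {
    /// Accept only finite, strictly positive step sizes.
    pub fn new(value: f32) -> CtResult<Self> {
        if value.is_finite() && value > 0.0 {
            Ok(LearningRate(value))
        } else {
            Err(CtError::InvalidLearningRate(value))
        }
    }

    pub fn value(self) -> f32 {
        self.0
    }
}
```

A negative test for this boundary would read assert!(matches!(LearningRate::new(f32::NAN), Err(CtError::InvalidLearningRate(_)))).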
Invariant
Rust syntax:
Distribution must be non-empty, finite, non-negative, and sum to one.
ML concept:
The model can trust a value only if the type protects the rule that makes it meaningful.
Category theory concept:
An invariant describes the subset or structure the object is meant to inhabit.
Typed Error
Rust syntax:
CtError
CtResult<T>
ML concept:
Bad data should fail with a meaningful cause, not with a vague panic later.
Category theory concept:
Result turns a partial construction or morphism into a total error-aware mapping.
Negative Test
Rust syntax:
assert!(matches!(..., Err(...)))
ML or software concept:
A negative test checks that one specific invalid input or invalid connection is rejected.
Category theory concept:
It checks that a proposed object, relation, or composition is not admitted when the required structure is missing.
First-principles reading:
Positive tests show what works. Negative tests show what the boundary protects. Both are needed when a chapter claims that types make structure explicit.
Machine-Learning Terms
Token
Rust syntax:
TokenId
ML concept:
A token is a discrete symbol from a vocabulary.
Category theory concept:
TokenId is an object whose values range over a finite, discrete vocabulary. Individual tokens are values of that object, not separate objects.
Training Example
Rust syntax:
pub type TrainingExample = Product<TokenId, TokenId>;
ML concept:
A training example pairs an input token with the target token that follows it.
Category theory concept:
It is a product object:
TokenId x TokenId
First-principles reading:
The product matters because the loss function needs both parts: the prediction derived from the first token and the target represented by the second token.
Training Set
Rust syntax:
TrainingSet
DatasetWindowing : TokenSequence -> TrainingSet
ML concept:
A training set is a non-empty collection of adjacent next-token examples.
Category theory concept:
It is an object produced by a data-preparation morphism and consumed by the training update.
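A sketch of DatasetWindowing as a plain function. It assumes tuples instead of Product<TokenId, TokenId> and a hypothetical EmptyTrainingSet error. slice::windows(2) yields each adjacent pair exactly once. A sequence with fewer than two tokens is rejected because it would produce an empty training set.

```rust
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TokenId(pub usize);

/// Hypothetical error: no adjacent pair could be formed.
#[derive(Debug, PartialEq, Eq)]
pub struct EmptyTrainingSet;

/// Non-empty list of (input, target) pairs; the field is private.
#[derive(Debug, PartialEq, Eq)]
pub struct TrainingSet(Vec<(TokenId, TokenId)>);

impl TrainingSet {
    pub fn examples(&self) -> &[(TokenId, TokenId)] {
        &self.0
    }
}

/// TokenSequence -> TrainingSet: each (t[i], t[i + 1]) becomes one example.
pub fn dataset_windowing(tokens: &[TokenId]) -> Result<TrainingSet, EmptyTrainingSet> {
    let pairs: Vec<(TokenId, TokenId)> = tokens.windows(2).map(|w| (w[0], w[1])).collect();
    if pairs.is_empty() {
        Err(EmptyTrainingSet)
    } else {
        Ok(TrainingSet(pairs))
    }
}
```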
Embedding
Rust syntax:
Embedding : TokenId -> Vector
ML concept:
An embedding maps a discrete token to a dense numerical representation.
Category theory concept:
It is a morphism from a finite token object into a vector-space-like object.
Logits
Rust syntax:
Logits(Vec<f32>)
ML concept:
Logits are raw scores before softmax.
Category theory concept:
They live in a vector-space-like object:
R^vocab_size
Softmax
Rust syntax:
Softmax : Logits -> Distribution
ML concept:
Softmax turns raw scores into probabilities.
Category theory concept:
It maps from a score vector into the probability simplex.
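A numerically stable softmax sketch over a plain slice. It assumes None as the failure signal rather than the repository's error type. Subtracting the largest logit before exponentiating does not change the result, because the common factor cancels during normalization. It does keep exp from overflowing for large scores.

```rust
/// Logits -> probabilities. Returns None for empty or non-finite input.
pub fn softmax(logits: &[f32]) -> Option<Vec<f32>> {
    if logits.is_empty() || logits.iter().any(|x| !x.is_finite()) {
        return None;
    }
    // Shift by the maximum so the largest exponent is exp(0) = 1.
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    // Normalize so the entries sum to one.
    Some(exps.iter().map(|e| e / sum).collect())
}
```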
Distribution
Rust syntax:
Distribution
Distribution::new
ML concept:
A distribution is a probability vector over possible next tokens. Its values must be finite, non-negative, non-empty, and sum to one.
Category theory concept:
It is the object produced by softmax and consumed with a target token to produce loss.
First-principles reading:
A raw vector can contain any numbers. A Distribution is a vector that has earned the right to be read as probabilities.
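A sketch of the invariant checks behind Distribution::new. It assumes a hypothetical DistributionError enum and a 1e-5 tolerance on the sum to absorb floating-point rounding. The repository's tolerance and error type may differ.

```rust
/// Hypothetical error enum, one variant per broken invariant.
#[derive(Debug, Clone, PartialEq)]
pub enum DistributionError {
    Empty,
    NonFinite,
    Negative,
    BadSum(f32),
}

#[derive(Debug, Clone, PartialEq)]
pub struct Distribution(Vec<f32>);

impl Distribution {
    /// Hypothetical tolerance for rounding error in the sum.
    const SUM_TOLERANCE: f32 = 1e-5;

    pub fn new(probs: Vec<f32>) -> Result<Self, DistributionError> {
        if probs.is_empty() {
            return Err(DistributionError::Empty);
        }
        if probs.iter().any(|p| !p.is_finite()) {
            return Err(DistributionError::NonFinite);
        }
        if probs.iter().any(|&p| p < 0.0) {
            return Err(DistributionError::Negative);
        }
        let sum: f32 = probs.iter().sum();
        if (sum - 1.0).abs() > Self::SUM_TOLERANCE {
            return Err(DistributionError::BadSum(sum));
        }
        Ok(Distribution(probs))
    }

    pub fn probs(&self) -> &[f32] {
        &self.0
    }
}
```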
Cross Entropy
Rust syntax:
CrossEntropy : Product<Distribution, TokenId> -> Loss
ML concept:
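Cross entropy measures how much probability the model assigned to the correct target. With a single target token, the loss is -ln(p[target]), so a higher target probability gives a lower loss.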
Category theory concept:
It is a morphism from prediction-target product into non-negative scalar loss.
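A sketch of cross entropy for one example. It assumes probs is already a validated probability vector and target is a raw index, whereas the book uses Distribution and TokenId. The 1e-12 floor is a hypothetical guard against ln(0). Both inputs are needed at the same time, which is exactly the product-input shape.

```rust
/// Distribution x TokenId -> Loss for one example: -ln(p[target]).
pub fn cross_entropy(probs: &[f32], target: usize) -> Option<f32> {
    // get() returns None when the target index is outside the vocabulary.
    let p = *probs.get(target)?;
    // Hypothetical floor: avoids ln(0) = -infinity for a zero probability.
    let loss = -(p.max(1e-12).ln());
    Some(loss)
}
```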
Loss
Rust syntax:
Loss
Loss::new
ML concept:
Loss is a scalar penalty. Lower loss means the model assigned more probability to the correct target in this tiny pipeline.
Category theory concept:
Loss is the output object of the evaluation morphism:
Distribution x TokenId -> Loss
Parameters
Rust syntax:
Parameters
ML concept:
The trainable state of the model: embedding table, output head, and bias.
Category theory concept:
The object transformed by the training endomorphism.
First-principles reading:
The word “state” can be vague. In this book, the model state is concrete: embedding table, output head, and bias. Training means returning a new value of the same Parameters type.
Gradient
Rust syntax:
LocalGradient
grad_embedding
grad_lm_head
grad_bias
ML concept:
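A gradient gives the local rate of change of the loss with respect to each parameter. It points in the direction of steepest local increase, so a small step in the opposite direction reduces the loss.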
Category theory concept:
Gradient flow composes local derivative information backward through the stages of a composite computation.
Learning Rate
Rust syntax:
LearningRate
ML concept:
The scalar step size in gradient descent.
Category theory concept:
It chooses a specific update morphism from a family of parameter endomorphisms.
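A sketch of how a learning rate selects one update from a family. It assumes a toy gradient for L(w) = sum_i (w_i - 1)^2 instead of the repository's LocalGradient, and train_step_with is a hypothetical name. Each call returns a closure of shape Parameters -> Parameters, and different learning rates give different endomorphisms.

```rust
#[derive(Debug, Clone, PartialEq)]
pub struct Parameters {
    pub weights: Vec<f32>,
}

/// Hypothetical gradient of the toy loss L(w) = sum_i (w_i - 1)^2.
fn toy_gradient(params: &Parameters) -> Vec<f32> {
    params.weights.iter().map(|w| 2.0 * (w - 1.0)).collect()
}

/// Fixing the learning rate selects one endomorphism Parameters -> Parameters.
pub fn train_step_with(learning_rate: f32) -> impl Fn(Parameters) -> Parameters {
    move |params: Parameters| {
        let grads = toy_gradient(&params);
        // Gradient descent: w_i <- w_i - learning_rate * g_i
        let weights = params
            .weights
            .iter()
            .zip(&grads)
            .map(|(w, g)| w - learning_rate * g)
            .collect();
        Parameters { weights }
    }
}
```

At w = 3.0 the toy gradient is 4.0, so a learning rate of 0.1 moves the weight to 2.6, while 0.25 moves it to 2.0.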
End-To-End Pipeline
Rust syntax:
TokenSequence -> TrainingSet
TokenId -> Vector -> Logits -> Distribution
Distribution x TokenId -> Loss
Parameters -> Parameters
ML concept:
The full tiny system turns text into training examples, predicts a next-token distribution, evaluates loss, and updates parameters.
Category theory concept:
The full pipeline is a collection of composable typed transformations, with training represented as a repeatable endomorphism on model state.
Chain Rule
Rust syntax:
MulOp::backward
ML concept:
The chain rule lets local derivatives combine into gradients for a larger computation.
Category theory concept:
It is composition of local derivative maps.
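A hedged sketch of a multiply node and the chain rule. It assumes MulOp exposes forward and backward as plain associated functions, and the repository's signatures may differ. For L = z^2 with z = x * y, the upstream gradient dL/dz = 2z is multiplied by the local derivatives dz/dx = y and dz/dy = x.

```rust
/// Multiply node: forward z = x * y.
pub struct MulOp;

impl MulOp {
    pub fn forward(x: f32, y: f32) -> f32 {
        x * y
    }

    /// Given dL/dz, return (dL/dx, dL/dy) using dz/dx = y and dz/dy = x.
    pub fn backward(x: f32, y: f32, upstream: f32) -> (f32, f32) {
        (upstream * y, upstream * x)
    }
}

/// Hypothetical composite: L = (x * y)^2.
pub fn loss(x: f32, y: f32) -> f32 {
    let z = MulOp::forward(x, y);
    z * z
}

/// Chain rule: the local derivative of the square feeds MulOp::backward.
pub fn grad(x: f32, y: f32) -> (f32, f32) {
    let z = MulOp::forward(x, y);
    let upstream = 2.0 * z; // dL/dz for L = z^2
    MulOp::backward(x, y, upstream)
}
```

At x = 3 and y = 4, grad returns (96, 72). A central finite difference of loss in x at the same point agrees, which is local evidence for this one path only.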
Target And Source Sequence Length
Rust syntax:
QuerySequence
KeySequence
ValueSequence
AttentionScores
AttentionMask
ML concept:
The target sequence length is the number of query positions that ask for information. The source sequence length is the number of key-value positions that can be read. In self-attention they are often the same sequence. In cross-attention they can come from different sequences.
Category theory concept:
The attention boundary keeps two roles visible:
Target positions x Source positions -> attention weights
First-principles reading:
This is why the book uses role-specific names instead of one generic matrix name. A mask of shape L x S answers a concrete question: for each target position, which source positions may be read?
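A sketch of one row of a masked softmax, using this book's polarity (true means allowed) and a hypothetical masked_softmax_row function. A full L x S mask applies one such row per target position. Blocked positions are removed before exponentiation, so they receive exactly zero probability. A row with no allowed positions is rejected because it cannot be normalized.

```rust
/// One target row: S scores and S mask flags (true = allowed in this book's convention).
pub fn masked_softmax_row(scores: &[f32], allowed: &[bool]) -> Option<Vec<f32>> {
    // Shapes must agree, and at least one source position must be readable.
    if scores.len() != allowed.len() || !allowed.iter().any(|&a| a) {
        return None;
    }
    // Largest allowed score, used only for numerical stability.
    let max = scores
        .iter()
        .zip(allowed)
        .filter(|pair| *pair.1)
        .map(|(s, _)| *s)
        .fold(f32::NEG_INFINITY, f32::max);
    // Mask before softmax: blocked positions contribute exactly zero.
    let exps: Vec<f32> = scores
        .iter()
        .zip(allowed)
        .map(|(&s, &a)| if a { (s - max).exp() } else { 0.0 })
        .collect();
    let sum: f32 = exps.iter().sum();
    Some(exps.iter().map(|e| e / sum).collect())
}
```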
Attention Scores
Rust syntax:
QuerySequence
KeySequence
ScaledDotProductScores : QuerySequence x KeySequence -> AttentionScores
AttentionScores
ML concept:
Attention scores are query-by-key compatibility values before softmax. The scaled dot-product boundary computes one score for each query and key pair.
Category theory concept:
ScaledDotProductScores is a morphism from a product of role-specific sequence objects into a score table. AttentionScores is an object whose rows can be transformed into probability-like attention weights.
First-principles reading:
The shape matters. Query and key sequences may have different lengths, but they must share the same head dimension before dot products make sense. A score table must have at least one query row, at least one key column, and the same number of key columns in every row.
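A sketch of the scaled dot-product boundary over plain nested vectors. It assumes a hypothetical ScoreError enum in place of the repository's QuerySequence, KeySequence, and AttentionScores types. It checks the shape rules stated above before computing score[i][j] = (q_i · k_j) / sqrt(d_k), which yields an L x S table.

```rust
/// Hypothetical shape errors for this sketch.
#[derive(Debug, PartialEq)]
pub enum ScoreError {
    EmptyQueries,
    EmptyKeys,
    InvalidHeadDim,
}

/// queries: L rows, keys: S rows, every row of the same width d_k > 0.
pub fn scaled_dot_product_scores(
    queries: &[Vec<f32>],
    keys: &[Vec<f32>],
) -> Result<Vec<Vec<f32>>, ScoreError> {
    if queries.is_empty() {
        return Err(ScoreError::EmptyQueries);
    }
    if keys.is_empty() {
        return Err(ScoreError::EmptyKeys);
    }
    // Queries and keys may differ in length but must share the head dimension.
    let d_k = queries[0].len();
    if d_k == 0 || queries.iter().chain(keys).any(|row| row.len() != d_k) {
        return Err(ScoreError::InvalidHeadDim);
    }
    let scale = (d_k as f32).sqrt();
    // One row per query position, one column per key position.
    Ok(queries
        .iter()
        .map(|q| {
            keys.iter()
                .map(|k| q.iter().zip(k).map(|(a, b)| a * b).sum::<f32>() / scale)
                .collect::<Vec<f32>>()
        })
        .collect())
}
```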
Hidden-To-Role Projections
Rust syntax:
HiddenToQuery : HiddenSequence -> QuerySequence
HiddenToKey : HiddenSequence -> KeySequence
HiddenToValue : HiddenSequence -> ValueSequence
ML concept:
Self-attention begins by projecting hidden states into query, key, and value roles. All three are numeric rows, but the roles are not interchangeable.
Category theory concept:
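These are three morphisms sharing one source object. Together they determine a single morphism into the product QuerySequence x KeySequence x ValueSequence: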
HiddenSequence -> QuerySequence
HiddenSequence -> KeySequence
HiddenSequence -> ValueSequence
First-principles reading:
The projection constructors validate matrix shape and finite values. The application step checks that the hidden sequence width matches the projection input width before producing role-specific sequence objects.
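The role split can be sketched with newtypes. The sketch assumes a hypothetical Projection type whose weights[i][j] maps input feature i to output feature j. HiddenToQuery and HiddenToKey wrap the same projection code but return different types, so a KeySequence cannot be passed where a QuerySequence is expected. ValueSequence would follow the same pattern.

```rust
/// Role-specific sequence objects: same representation, different types.
struct HiddenSequence(Vec<Vec<f64>>);
struct QuerySequence(Vec<Vec<f64>>);
struct KeySequence(Vec<Vec<f64>>);

/// Hypothetical projection: `weights[i][j]` maps input feature i to output feature j.
struct Projection {
    weights: Vec<Vec<f64>>,
}

impl Projection {
    /// Validates a non-empty rectangular matrix of finite values.
    fn new(weights: Vec<Vec<f64>>) -> Result<Self, String> {
        let out_width = weights.first().map_or(0, Vec::len);
        if out_width == 0 || weights.iter().any(|row| row.len() != out_width) {
            return Err("projection matrix must be non-empty and rectangular".to_string());
        }
        if weights.iter().flatten().any(|w| !w.is_finite()) {
            return Err("projection weights must be finite".to_string());
        }
        Ok(Projection { weights })
    }

    /// Checks the hidden width before multiplying any row.
    fn apply(&self, hidden: &HiddenSequence) -> Result<Vec<Vec<f64>>, String> {
        let in_width = self.weights.len();
        if hidden.0.iter().any(|row| row.len() != in_width) {
            return Err(format!("hidden width must be {in_width}"));
        }
        Ok(hidden
            .0
            .iter()
            .map(|row| {
                let mut out = vec![0.0; self.weights[0].len()];
                for (x, w_row) in row.iter().zip(&self.weights) {
                    for (o, w) in out.iter_mut().zip(w_row) {
                        *o += x * w;
                    }
                }
                out
            })
            .collect())
    }
}

/// HiddenSequence -> QuerySequence and HiddenSequence -> KeySequence as distinct types.
struct HiddenToQuery(Projection);
struct HiddenToKey(Projection);

impl HiddenToQuery {
    fn apply(&self, hidden: &HiddenSequence) -> Result<QuerySequence, String> {
        self.0.apply(hidden).map(QuerySequence)
    }
}

impl HiddenToKey {
    fn apply(&self, hidden: &HiddenSequence) -> Result<KeySequence, String> {
        self.0.apply(hidden).map(KeySequence)
    }
}
```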
Self-Attention
Rust syntax:
SelfAttentionHead
MultiHeadTransformerBlock : HiddenSequence -> HiddenSequence
ML concept:
Self-attention means the query, key, and value roles all come from the same hidden sequence. The roles are still distinct after projection, but their source ownership is shared.
Category theory concept:
The internal attention path still contains product-input boundaries:
QuerySequence x KeySequence -> AttentionScores
AttentionWeights x ValueSequence -> AttentionOutput
The surrounding block can have endomorphism shape only after the internal composition returns to the same public object:
HiddenSequence -> HiddenSequence
First-principles reading:
Self-attention is not permission to call every internal step an endomorphism. It is the case where one source hidden sequence is projected into the query, key, and value roles before scoring and mixing.
Cross-Attention
Rust syntax:
QuerySequence
KeySequence
ValueSequence
ML concept:
Cross-attention means the target-side query sequence reads from a separate source-side key-value sequence. The current repository names this boundary for precision, but it does not yet implement a full cross-attention block.
Category theory concept:
The source split makes the product input impossible to hide:
TargetHiddenSequence -> QuerySequence
SourceHiddenSequence -> KeySequence
SourceHiddenSequence -> ValueSequence
QuerySequence x KeySequence -> AttentionScores
AttentionWeights x ValueSequence -> AttentionOutput
First-principles reading:
When the target sequence and source sequence are not the same object, the attention map has target rows and source columns. That is the shape reason to keep L and S separate in explanations, masks, and tests.
Attention Mask
Rust syntax:
AttentionMask
MaskedAttentionScores : AttentionScores x AttentionMask -> AttentionScores
ML concept:
An attention mask marks which key positions each query is allowed to attend to. Disallowed score positions are replaced by a large negative value before softmax, so their probability becomes negligible.
Read the mask as a permission table, not as a shorter token sequence. A mask cell answers:
may this query row read this source column?
It selects legal score cells before probability normalization. It does not directly produce AttentionWeights; softmax still turns the remaining score row into weights.
Category theory concept:
MaskedAttentionScores is a typed morphism from a product object back to the score object:
AttentionScores x AttentionMask -> AttentionScores
First-principles reading:
Every mask row must allow at least one key. Otherwise softmax would be asked to choose among no legal positions.
Recovery rule:
mask cells select legal score cells
softmax turns remaining score rows into weights
weights read value rows
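Here is a sketch of the mask as a permission table. It assumes true means "may read" and uses -1.0e9 as the large negative score. The constructor enforces the rectangular L x S shape and the at-least-one-legal-key rule. The masking function only rewrites score cells and leaves softmax for the next boundary.

```rust
/// Large negative stand-in for "not allowed"; the exact value is an assumption.
const MASKED_SCORE: f64 = -1.0e9;

/// Permission table: `allowed[query][key]` is true when the query may read the key.
struct AttentionMask {
    allowed: Vec<Vec<bool>>,
}

impl AttentionMask {
    fn new(allowed: Vec<Vec<bool>>) -> Result<Self, String> {
        let key_len = allowed.first().map_or(0, Vec::len);
        if key_len == 0 || allowed.iter().any(|row| row.len() != key_len) {
            return Err("mask must be a non-empty L x S table".to_string());
        }
        if allowed.iter().any(|row| !row.contains(&true)) {
            return Err("every query row must allow at least one key".to_string());
        }
        Ok(AttentionMask { allowed })
    }

    /// Causal mask: query i may read keys 0..=i.
    fn causal(len: usize) -> Result<Self, String> {
        let allowed: Vec<Vec<bool>> = (0..len)
            .map(|i| (0..len).map(|j| j <= i).collect::<Vec<bool>>())
            .collect();
        AttentionMask::new(allowed)
    }
}

/// AttentionScores x AttentionMask -> AttentionScores.
fn masked_attention_scores(
    scores: &[Vec<f64>],
    mask: &AttentionMask,
) -> Result<Vec<Vec<f64>>, String> {
    let same_shape = scores.len() == mask.allowed.len()
        && scores.iter().zip(&mask.allowed).all(|(s, m)| s.len() == m.len());
    if !same_shape {
        return Err("score table and mask must share the L x S shape".to_string());
    }
    Ok(scores
        .iter()
        .zip(&mask.allowed)
        .map(|(s_row, m_row)| {
            s_row
                .iter()
                .zip(m_row)
                .map(|(&s, &ok)| if ok { s } else { MASKED_SCORE })
                .collect()
        })
        .collect())
}
```

Mask polarity is a convention. Some framework APIs use true for blocked positions instead, so the meaning of true belongs in the type's documentation.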
Attention Weights
Rust syntax:
AttentionWeights
AttentionSoftmax : AttentionScores -> AttentionWeights
ML concept:
Attention weights are row-wise probabilities over key positions. Each query position receives its own distribution over the positions it can attend to.
Category theory concept:
AttentionSoftmax is a typed morphism from raw score rows to validated probability rows.
First-principles reading:
This is one Transformer-roadmap boundary made executable in the crate. It validates the probability-like score-to-weight step after query-key scoring and masking have produced legal score rows.
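Here is a sketch of the row-wise score-to-weight step; it is not the crate's AttentionSoftmax. It subtracts each row's maximum before exponentiation, a standard stability trick. Non-finite scores are rejected, so masked cells passed to this sketch must use a large finite negative value rather than negative infinity.

```rust
/// AttentionScores -> AttentionWeights: a numerically shifted softmax per row.
fn attention_softmax(scores: &[Vec<f64>]) -> Result<Vec<Vec<f64>>, String> {
    if scores.is_empty() {
        return Err("need at least one score row".to_string());
    }
    scores
        .iter()
        .map(|row| {
            if row.is_empty() || row.iter().any(|s| !s.is_finite()) {
                return Err("each score row must be non-empty and finite".to_string());
            }
            // Subtract the row maximum so exp() cannot overflow.
            let max = row.iter().copied().fold(f64::NEG_INFINITY, f64::max);
            let exps: Vec<f64> = row.iter().map(|s| (s - max).exp()).collect();
            let total: f64 = exps.iter().sum();
            Ok(exps.iter().map(|e| e / total).collect::<Vec<f64>>())
        })
        .collect()
}
```

A row such as [1000.0, 1000.0] still yields [0.5, 0.5]. A naive exp(1000.0) would overflow to infinity.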
Value Mixing
Rust syntax:
ValueSequence
WeightedValueMixing : AttentionWeights x ValueSequence -> AttentionOutput
AttentionOutput
ML concept:
Value mixing uses each query row of attention weights to compute a weighted sum of value vectors. The result has one output vector per query position.
Category theory concept:
WeightedValueMixing is a morphism from a product object to an output object:
AttentionWeights x ValueSequence -> AttentionOutput
First-principles reading:
The key length of the weights must match the number of value rows. If a query has weights over three source positions, the value sequence must provide three source vectors to mix.
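Here is a sketch of weighted value mixing over plain rows, assuming weights[q][s] is the weight query q places on source position s. The length check is the one described above: each weight row must have exactly one entry per value row.

```rust
/// AttentionWeights x ValueSequence -> AttentionOutput.
fn weighted_value_mixing(
    weights: &[Vec<f64>],
    values: &[Vec<f64>],
) -> Result<Vec<Vec<f64>>, String> {
    let value_width = values
        .first()
        .map(Vec::len)
        .ok_or_else(|| "need at least one value row".to_string())?;
    if values.iter().any(|v| v.len() != value_width) {
        return Err("value rows must share one width".to_string());
    }
    if weights.iter().any(|w| w.len() != values.len()) {
        return Err(format!("each weight row must cover {} value rows", values.len()));
    }
    Ok(weights
        .iter()
        .map(|w_row| {
            // One output vector per query row: sum over sources of weight * value.
            let mut out = vec![0.0; value_width];
            for (w, v_row) in w_row.iter().zip(values) {
                for (o, v) in out.iter_mut().zip(v_row) {
                    *o += w * v;
                }
            }
            out
        })
        .collect())
}
```

Weights [0.25, 0.75] over value rows [4, 0] and [0, 8] give the output row [1, 6].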
Multi-Head Concatenation
Rust syntax:
HeadCount
AttentionHeadOutputs
ConcatenateHeads : AttentionHeadOutputs -> MultiHeadOutput
MultiHeadOutput
ML concept:
Several attention heads can produce one output sequence each. Concatenation combines the feature vectors at each sequence position so later layers can read all head outputs together.
Category theory concept:
ConcatenateHeads is a recombination morphism:
AttentionHeadOutputs -> MultiHeadOutput
First-principles reading:
The constructor checks that every head has the same sequence length and head dimension before concatenation. The resulting model dimension is the head count multiplied by the head dimension. This is the typed boundary where separate head outputs become one combined object.
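Here is a sketch of head concatenation, assuming heads[h][position] holds head h's output vector at one sequence position. After the consistency check, each output row is the concatenation of that position's rows across heads. Its width is therefore the head count times the head dimension.

```rust
/// AttentionHeadOutputs -> MultiHeadOutput.
/// `heads[h][position]` is head h's output vector at one sequence position.
fn concatenate_heads(heads: &[Vec<Vec<f64>>]) -> Result<Vec<Vec<f64>>, String> {
    let first = heads.first().ok_or_else(|| "need at least one head".to_string())?;
    let seq_len = first.len();
    let head_dim = first.first().map_or(0, Vec::len);
    if seq_len == 0 || head_dim == 0 {
        return Err("heads must have positive length and dimension".to_string());
    }
    let consistent = heads
        .iter()
        .all(|head| head.len() == seq_len && head.iter().all(|row| row.len() == head_dim));
    if !consistent {
        return Err("every head must share sequence length and head dimension".to_string());
    }
    // Each combined row has width heads.len() * head_dim.
    Ok((0..seq_len)
        .map(|pos| heads.iter().flat_map(|head| head[pos].iter().copied()).collect())
        .collect())
}
```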
Attention Output Projection
Rust syntax:
AttentionOutputProjection
AttentionOutputProjection : MultiHeadOutput -> ProjectedAttentionOutput
ProjectedAttentionOutput
ML concept:
After head outputs are concatenated, a learned linear projection mixes features across heads and returns the sequence to the width expected by the surrounding model block.
Category theory concept:
AttentionOutputProjection is a morphism:
MultiHeadOutput -> ProjectedAttentionOutput
First-principles reading:
The projection validates its matrix and bias before use. It also checks that the MultiHeadOutput width matches the projection input width. This keeps the post-concatenation linear map from becoming an untyped matrix multiply hidden inside the example.
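Here is a sketch of the post-concatenation affine map, with validation up front. It assumes weights[i][j] maps concatenated feature i to output feature j, and that bias has one entry per output feature. These layout choices are illustrative, not the crate's.

```rust
/// MultiHeadOutput -> ProjectedAttentionOutput as an affine map.
/// `weights[i][j]` maps concatenated feature i to output feature j.
struct AttentionOutputProjection {
    weights: Vec<Vec<f64>>,
    bias: Vec<f64>,
}

impl AttentionOutputProjection {
    fn new(weights: Vec<Vec<f64>>, bias: Vec<f64>) -> Result<Self, String> {
        let out_width = bias.len();
        if weights.is_empty() || out_width == 0 || weights.iter().any(|r| r.len() != out_width) {
            return Err("weights must be input_width x bias.len()".to_string());
        }
        if weights.iter().flatten().chain(&bias).any(|x| !x.is_finite()) {
            return Err("weights and bias must be finite".to_string());
        }
        Ok(AttentionOutputProjection { weights, bias })
    }

    /// Checks each row's width against the projection input width, then projects.
    fn apply(&self, multi_head_output: &[Vec<f64>]) -> Result<Vec<Vec<f64>>, String> {
        multi_head_output
            .iter()
            .map(|row| {
                if row.len() != self.weights.len() {
                    return Err(format!("expected width {}, got {}", self.weights.len(), row.len()));
                }
                let mut out = self.bias.clone();
                for (x, w_row) in row.iter().zip(&self.weights) {
                    for (o, w) in out.iter_mut().zip(w_row) {
                        *o += x * w;
                    }
                }
                Ok(out)
            })
            .collect()
    }
}
```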
Residual Connection
Rust syntax:
HiddenSequence
ResidualConnection : HiddenSequence x ProjectedAttentionOutput -> HiddenSequence
ML concept:
A residual connection adds a sublayer output back to the hidden sequence it came from. The addition is only meaningful when every sequence position has the same hidden width on both sides.
Category theory concept:
ResidualConnection is a product-to-object morphism:
HiddenSequence x ProjectedAttentionOutput -> HiddenSequence
The larger Transformer block can still have endomorphism shape:
HiddenSequence -> HiddenSequence
First-principles reading:
Residual addition is not just vector arithmetic. It is a shape contract. The sequence length and model dimension must match before addition can preserve the hidden sequence object.
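Here is a sketch of residual addition as a shape contract. Both the sequence length and the per-position width are checked before any element is added. Plain Vec<Vec<f64>> rows stand in for the crate's typed sequences.

```rust
/// HiddenSequence x ProjectedAttentionOutput -> HiddenSequence.
fn residual_connection(
    hidden: &[Vec<f64>],
    sublayer_output: &[Vec<f64>],
) -> Result<Vec<Vec<f64>>, String> {
    if hidden.len() != sublayer_output.len() {
        return Err("sequence lengths must match".to_string());
    }
    hidden
        .iter()
        .zip(sublayer_output)
        .map(|(h, s)| {
            if h.len() != s.len() {
                return Err("model dimensions must match at every position".to_string());
            }
            Ok(h.iter().zip(s).map(|(a, b)| a + b).collect::<Vec<f64>>())
        })
        .collect()
}
```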
Layer Normalization
Rust syntax:
LayerNormParameters
LayerNormalization : HiddenSequence -> HiddenSequence
ML concept:
Layer normalization normalizes each hidden vector across its feature dimension. It keeps the sequence length and model dimension unchanged.
Category theory concept:
LayerNormalization is an endomorphism:
HiddenSequence -> HiddenSequence
First-principles reading:
The operation changes values, not the object type. The parameter object protects the scale, shift, and epsilon invariants before a hidden sequence can be normalized.
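Here is a sketch of layer normalization with validated parameters. It assumes the usual formula scale * (x - mean) / sqrt(variance + epsilon) + shift. The formula is applied to each row over its features, using the population variance. The output has the same number of rows and the same width as the input.

```rust
/// Per-feature scale and shift plus a positive epsilon, validated once.
struct LayerNormParameters {
    scale: Vec<f64>,
    shift: Vec<f64>,
    epsilon: f64,
}

impl LayerNormParameters {
    fn new(scale: Vec<f64>, shift: Vec<f64>, epsilon: f64) -> Result<Self, String> {
        if scale.is_empty() || scale.len() != shift.len() {
            return Err("scale and shift must be non-empty and equally long".to_string());
        }
        if !(epsilon > 0.0 && epsilon.is_finite()) {
            return Err("epsilon must be positive and finite".to_string());
        }
        if scale.iter().chain(&shift).any(|x| !x.is_finite()) {
            return Err("scale and shift must be finite".to_string());
        }
        Ok(LayerNormParameters { scale, shift, epsilon })
    }

    /// HiddenSequence -> HiddenSequence: normalize each row across its features.
    fn normalize(&self, hidden: &[Vec<f64>]) -> Result<Vec<Vec<f64>>, String> {
        hidden
            .iter()
            .map(|row| {
                if row.len() != self.scale.len() {
                    return Err("row width must match the parameter width".to_string());
                }
                let n = row.len() as f64;
                let mean = row.iter().sum::<f64>() / n;
                let variance = row.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
                let inv_std = 1.0 / (variance + self.epsilon).sqrt();
                Ok(row
                    .iter()
                    .zip(&self.scale)
                    .zip(&self.shift)
                    .map(|((x, g), b)| g * (x - mean) * inv_std + b)
                    .collect::<Vec<f64>>())
            })
            .collect()
    }
}
```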
Position-Wise Feed-Forward
Rust syntax:
PositionWiseFeedForward : HiddenSequence -> HiddenSequence
ML concept:
A position-wise feed-forward network applies the same two-layer non-linear map to every hidden vector in the sequence. It can expand the feature dimension internally, apply an activation, then project back to the original model dimension.
Category theory concept:
PositionWiseFeedForward is an endomorphism:
HiddenSequence -> HiddenSequence
First-principles reading:
The internal feed-forward width is allowed to differ from the model dimension, but the public output must return to the same hidden sequence shape. The type protects that shape before later blocks try to compose with it.
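Here is a sketch of the position-wise sublayer. It assumes a hypothetical Dense layer type and uses ReLU as the activation, the gate the book names for its feed-forward training step. The constructor checks the model_dim -> ff_dim -> model_dim chain once, so apply can only return rows of the original width.

```rust
/// Hypothetical dense layer: `weights[i][j]` maps input feature i to output feature j.
struct Dense {
    weights: Vec<Vec<f64>>,
    bias: Vec<f64>,
}

impl Dense {
    fn apply(&self, x: &[f64]) -> Vec<f64> {
        let mut out = self.bias.clone();
        for (xi, w_row) in x.iter().zip(&self.weights) {
            for (o, w) in out.iter_mut().zip(w_row) {
                *o += xi * w;
            }
        }
        out
    }
}

/// HiddenSequence -> HiddenSequence: expand, ReLU, project back, at every position.
struct PositionWiseFeedForward {
    expand: Dense,
    project: Dense,
}

impl PositionWiseFeedForward {
    fn new(expand: Dense, project: Dense) -> Result<Self, String> {
        let model_dim = expand.weights.len();
        let ff_dim = expand.bias.len();
        let expand_ok =
            model_dim > 0 && ff_dim > 0 && expand.weights.iter().all(|r| r.len() == ff_dim);
        // The second layer must return to the model dimension.
        let project_ok = project.weights.len() == ff_dim
            && project.bias.len() == model_dim
            && project.weights.iter().all(|r| r.len() == model_dim);
        if !(expand_ok && project_ok) {
            return Err("layers must map model_dim -> ff_dim -> model_dim".to_string());
        }
        Ok(PositionWiseFeedForward { expand, project })
    }

    fn apply(&self, hidden: &[Vec<f64>]) -> Result<Vec<Vec<f64>>, String> {
        let model_dim = self.expand.weights.len();
        hidden
            .iter()
            .map(|row| {
                if row.len() != model_dim {
                    return Err(format!("row width must be {model_dim}"));
                }
                let gated: Vec<f64> =
                    self.expand.apply(row).into_iter().map(|v| v.max(0.0)).collect();
                Ok(self.project.apply(&gated))
            })
            .collect()
    }
}
```

With expansion weights [1, -1] and projection weights [1, 1], the sublayer computes relu(x) + relu(-x) = |x| at every position, while the width stays 1.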
Positional Encoding
Rust syntax:
PositionalEncoding : HiddenSequence -> HiddenSequence
ML concept:
Position information lets a sequence model distinguish the first token from the second token even when their content vectors are otherwise similar.
Category theory concept:
PositionalEncoding is an endomorphism:
HiddenSequence -> HiddenSequence
First-principles reading:
The encoding table must have enough rows for the hidden sequence and the same model width. Adding position changes the values at each row, not the public shape of the hidden sequence.
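Here is a sketch of the additive encoding. The table builder uses the sinusoidal encoding from Attention Is All You Need as one common choice; this sketch does not assume the crate uses that table rather than a learned one. The addition checks both rules above: enough rows, and matching width.

```rust
/// One common choice of table: the sinusoidal encoding from
/// "Attention Is All You Need". The crate may use a different table.
fn sinusoidal_table(max_len: usize, model_dim: usize) -> Vec<Vec<f64>> {
    (0..max_len)
        .map(|pos| {
            (0..model_dim)
                .map(|j| {
                    // Features 2i and 2i + 1 share the frequency 1 / 10000^(2i / model_dim).
                    let pair = (j / 2) as f64;
                    let angle = pos as f64 / 10000f64.powf(2.0 * pair / model_dim as f64);
                    if j % 2 == 0 { angle.sin() } else { angle.cos() }
                })
                .collect()
        })
        .collect()
}

/// PositionalEncoding : HiddenSequence -> HiddenSequence.
fn add_positional_encoding(
    hidden: &[Vec<f64>],
    table: &[Vec<f64>],
) -> Result<Vec<Vec<f64>>, String> {
    if table.len() < hidden.len() {
        return Err("position table needs one row per sequence position".to_string());
    }
    hidden
        .iter()
        .zip(table)
        .map(|(h, p)| {
            if h.len() != p.len() {
                return Err("position rows must match the model width".to_string());
            }
            Ok(h.iter().zip(p).map(|(x, e)| x + e).collect::<Vec<f64>>())
        })
        .collect()
}
```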
Single-Head Transformer Block
Rust syntax:
SingleHeadTransformerBlock : HiddenSequence -> HiddenSequence
ML concept:
The single-head block sketch composes hidden-to-role projections, attention, output projection, residual addition, normalization, and a feed-forward sublayer. It is intentionally small: one head and no production training machinery.
Category theory concept:
SingleHeadTransformerBlock is an endomorphism:
HiddenSequence -> HiddenSequence
First-principles reading:
The block is useful because it hides internal steps without hiding shape contracts. The caller sees one sequence-preserving transformation; the constructor still checks the dimensions that make the internal composition legal.
Multi-Head Transformer Block
Rust syntax:
SelfAttentionHead
MultiHeadTransformerBlock : HiddenSequence -> HiddenSequence
ML concept:
A multi-head block applies several self-attention heads in parallel, concatenates their outputs, projects back to the model dimension, then applies the same residual, normalization, and feed-forward shape-preserving pattern.
Category theory concept:
MultiHeadTransformerBlock is an endomorphism:
HiddenSequence -> HiddenSequence
First-principles reading:
The block checks that every head accepts the same hidden width, every value head has the same output width, and the output projection expects exactly:
head_count * value_head_dimension
Those checks keep multi-head attention as explicit structure rather than an unlabeled matrix pile.
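The three checks can be sketched as one function over a hypothetical HeadShape summary. It returns the concatenated width head_count * value_head_dimension only when every head and the output projection agree.

```rust
/// Hypothetical per-head shape summary used only for this check.
struct HeadShape {
    input_width: usize,
    value_head_dimension: usize,
}

/// Returns `head_count * value_head_dimension` when every head and the
/// output projection agree on their widths.
fn check_multi_head_shapes(
    heads: &[HeadShape],
    model_dim: usize,
    projection_input_width: usize,
) -> Result<usize, String> {
    let first = heads.first().ok_or_else(|| "need at least one head".to_string())?;
    if heads.iter().any(|h| h.input_width != model_dim) {
        return Err(format!("every head must accept hidden width {model_dim}"));
    }
    if heads.iter().any(|h| h.value_head_dimension != first.value_head_dimension) {
        return Err("every value head must have the same output width".to_string());
    }
    let expected = heads.len() * first.value_head_dimension;
    if projection_input_width != expected {
        return Err(format!(
            "output projection expects {projection_input_width}, heads produce {expected}"
        ));
    }
    Ok(expected)
}
```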
Masked Multi-Head Transformer Block
Rust syntax:
MaskedMultiHeadTransformerBlock : HiddenSequence x AttentionMask -> HiddenSequence
ML concept:
A masked block runs the same multi-head path while preventing disallowed query-key positions from receiving attention probability.
Category theory concept:
MaskedMultiHeadTransformerBlock consumes a product object:
HiddenSequence x AttentionMask -> HiddenSequence
First-principles reading:
The mask is not a side channel. It is an explicit input to the block. The mask shape must match the query-by-key score table produced inside each head.
Fixed Mask View
Rust syntax:
AttentionMask
MaskedMultiHeadTransformerBlock : HiddenSequence x AttentionMask -> HiddenSequence
ML concept:
A fixed mask view means a particular mask has already been chosen for this run. For example, one training example may reuse the same allowed-position pattern every time the block is applied to its hidden sequence.
Category theory concept:
The open boundary is product-input:
HiddenSequence x AttentionMask -> HiddenSequence
After choosing one concrete mask as context, that specific run can induce a unary map:
HiddenSequence -> HiddenSequence
First-principles reading:
Do not erase the mask to get a cleaner category name. Either keep the open product-input boundary visible, or say exactly which AttentionMask was fixed before calling the result a HiddenSequence -> HiddenSequence view.
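Rust closures give a direct way to make the fixed-context view explicit. In this sketch, masked_block is a toy stand-in for the masked block: a uniform average over allowed rows, not real attention. fix_mask captures one concrete AttentionMask and returns a value that accepts only a HiddenSequence. The mask has not disappeared; it has moved into the closure's environment.

```rust
type HiddenSequence = Vec<Vec<f64>>;

/// `allowed[query][key]`: true when the query row may read the key row.
#[derive(Clone, Debug)]
struct AttentionMask(Vec<Vec<bool>>);

/// Toy stand-in for HiddenSequence x AttentionMask -> HiddenSequence:
/// each output row is the uniform average of the hidden rows its mask row allows.
/// Assumes every mask row allows at least one key.
fn masked_block(hidden: &HiddenSequence, mask: &AttentionMask) -> HiddenSequence {
    let width = hidden.first().map_or(0, Vec::len);
    mask.0
        .iter()
        .map(|allowed| {
            let chosen: Vec<&Vec<f64>> = hidden
                .iter()
                .zip(allowed)
                .filter_map(|(row, &ok)| if ok { Some(row) } else { None })
                .collect();
            let n = chosen.len() as f64;
            (0..width).map(|d| chosen.iter().map(|row| row[d]).sum::<f64>() / n).collect()
        })
        .collect()
}

/// Fix one concrete mask as context. Only after this choice does the
/// remaining call have the unary shape HiddenSequence -> HiddenSequence.
fn fix_mask(mask: AttentionMask) -> impl Fn(&HiddenSequence) -> HiddenSequence {
    move |hidden: &HiddenSequence| masked_block(hidden, &mask)
}
```

Because the fixed view has HiddenSequence -> HiddenSequence shape, its output can be fed straight back in. view(&view(&hidden)) type-checks, while masked_block alone would still demand a mask.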
Sequence Logits
Rust syntax:
SequenceLogits
ML concept:
Sequence logits are unnormalized vocabulary scores for each position in a hidden sequence.
Category theory concept:
They are the output object of a sequence-level readout morphism:
HiddenSequence -> SequenceLogits
First-principles reading:
The object keeps sequence length and vocabulary size explicit. That prevents a sequence readout from becoming an unlabeled table of floats.
Transformer Readout
Rust syntax:
TransformerReadout : HiddenSequence -> SequenceLogits
ML concept:
A readout maps each final hidden vector to vocabulary scores. It is the sequence-level version of the earlier Vector -> Logits language-model head.
Category theory concept:
TransformerReadout is a morphism from the hidden sequence object to the sequence logit object:
HiddenSequence -> SequenceLogits
First-principles reading:
The readout validates the input model dimension and vocabulary width before projecting rows. The model should fail at the boundary, not inside an indexing loop.
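Here is a sketch of the readout boundary, assuming weights[i][v] maps hidden feature i to vocabulary token v. Every row's width is checked before the projection loop runs. The result records the vocabulary size alongside the rows.

```rust
/// Vocabulary scores with sequence length and vocabulary size kept explicit.
#[derive(Debug, PartialEq)]
struct SequenceLogits {
    vocab_size: usize,
    rows: Vec<Vec<f64>>,
}

/// HiddenSequence -> SequenceLogits. `weights[i][v]` maps feature i to token v.
struct TransformerReadout {
    weights: Vec<Vec<f64>>,
    bias: Vec<f64>,
}

impl TransformerReadout {
    fn new(weights: Vec<Vec<f64>>, bias: Vec<f64>) -> Result<Self, String> {
        let vocab_size = bias.len();
        if weights.is_empty() || vocab_size == 0 || weights.iter().any(|r| r.len() != vocab_size) {
            return Err("weights must be model_dim x vocab_size".to_string());
        }
        Ok(TransformerReadout { weights, bias })
    }

    fn apply(&self, hidden: &[Vec<f64>]) -> Result<SequenceLogits, String> {
        let model_dim = self.weights.len();
        if hidden.is_empty() {
            return Err("hidden sequence must be non-empty".to_string());
        }
        // Fail at the boundary, before any row is projected.
        if let Some(bad) = hidden.iter().position(|row| row.len() != model_dim) {
            return Err(format!("row {bad} has width {}, expected {model_dim}", hidden[bad].len()));
        }
        let rows = hidden
            .iter()
            .map(|h| {
                let mut logits = self.bias.clone();
                for (x, w_row) in h.iter().zip(&self.weights) {
                    for (l, w) in logits.iter_mut().zip(w_row) {
                        *l += x * w;
                    }
                }
                logits
            })
            .collect();
        Ok(SequenceLogits { vocab_size: self.bias.len(), rows })
    }
}
```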
Tiny Transformer Parameters
Rust syntax:
TinyTransformerParameters : HiddenSequence x AttentionMask -> SequenceLogits
ML concept:
The parameter object owns the position table, masked block, and sequence readout needed for the tiny Transformer forward path.
Category theory concept:
Its forward path is a product-to-object morphism:
HiddenSequence x AttentionMask -> SequenceLogits
First-principles reading:
The object groups named roles. The point is not to claim a production Transformer; the point is to stop passing unrelated matrices as loose arguments.
Transformer Training State
Rust syntax:
TransformerTrainingState
ML concept:
A training state owns parameters, a learning rate, and a step count. The current code can run the forward path through the structured state and use the step count to record that a new parameter object belongs to the next step.
Category theory concept:
The forward path has shape:
HiddenSequence x AttentionMask -> SequenceLogits
The optimizer updates the state with endomorphism shape:
TransformerTrainingState -> TransformerTrainingState
First-principles reading:
This is honest scaffolding. It models the state boundary the current tiny optimizer updates, without pretending that the teaching implementation is a production Transformer trainer.
Transformer Readout Training Example
Rust syntax:
TransformerReadoutTrainingExample
ML concept:
A readout training example pairs one hidden sequence and attention mask with a target token at every sequence position.
Category theory concept:
It is a validated product-like learning object that feeds a training endomorphism:
TransformerTrainingState -> TransformerTrainingState
First-principles reading:
The hidden sequence length, mask shape, and target-token count must agree before a training step can compute a meaningful loss.
Transformer Readout Train Step
Rust syntax:
TransformerReadoutTrainStep : TransformerTrainingState -> TransformerTrainingState
ML concept:
This step updates only the sequence readout. It keeps the position table and attention block fixed, computes softmax cross-entropy gradients at each sequence position, updates the readout weights and bias, and increments the step count.
Category theory concept:
The update is an endomorphism:
TransformerTrainingState -> TransformerTrainingState
First-principles reading:
This is a real update with a narrow scope. It teaches how the structured state can change without claiming that gradients already flow through every Transformer block parameter.
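Here is a sketch of a readout-only update on a hypothetical ReadoutState holding weights, bias, learning rate, and step count. It uses the standard softmax cross-entropy gradient: p_v minus 1 for the target token and p_v otherwise, averaged over positions. The update consumes and returns the state, mirroring the TransformerTrainingState -> TransformerTrainingState shape. Attention and position parameters are simply absent here, which simplifies the scoped update described above.

```rust
/// Hypothetical readout-only slice of a training state.
#[derive(Debug, Clone, PartialEq)]
struct ReadoutState {
    weights: Vec<Vec<f64>>, // model_dim x vocab_size
    bias: Vec<f64>,         // vocab_size
    learning_rate: f64,
    step: u64,
}

fn softmax(z: &[f64]) -> Vec<f64> {
    let max = z.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = z.iter().map(|v| (v - max).exp()).collect();
    let total: f64 = exps.iter().sum();
    exps.iter().map(|e| e / total).collect()
}

impl ReadoutState {
    fn logits(&self, h: &[f64]) -> Vec<f64> {
        let mut out = self.bias.clone();
        for (x, w_row) in h.iter().zip(&self.weights) {
            for (o, w) in out.iter_mut().zip(w_row) {
                *o += x * w;
            }
        }
        out
    }

    /// Mean cross-entropy over positions: -ln p(target).
    fn loss(&self, hidden: &[Vec<f64>], targets: &[usize]) -> f64 {
        let total: f64 = hidden
            .iter()
            .zip(targets)
            .map(|(h, &t)| -softmax(&self.logits(h))[t].ln())
            .sum();
        total / hidden.len() as f64
    }

    /// State -> State: one gradient step on the readout only.
    /// At each position, d(loss)/d(logit_v) = p_v - [v == target], averaged.
    fn train_step(mut self, hidden: &[Vec<f64>], targets: &[usize]) -> ReadoutState {
        let n = hidden.len() as f64;
        let mut grad_w = vec![vec![0.0; self.bias.len()]; self.weights.len()];
        let mut grad_b = vec![0.0; self.bias.len()];
        for (h, &t) in hidden.iter().zip(targets) {
            let p = softmax(&self.logits(h));
            for (v, p_v) in p.iter().enumerate() {
                let target = if v == t { 1.0 } else { 0.0 };
                let g = (p_v - target) / n;
                grad_b[v] += g;
                for (i, x) in h.iter().enumerate() {
                    grad_w[i][v] += x * g;
                }
            }
        }
        let lr = self.learning_rate;
        for (w_row, g_row) in self.weights.iter_mut().zip(&grad_w) {
            for (w, g) in w_row.iter_mut().zip(g_row) {
                *w -= lr * g;
            }
        }
        for (b, g) in self.bias.iter_mut().zip(&grad_b) {
            *b -= lr * g;
        }
        self.step += 1;
        self
    }
}
```

Start from zero weights, one hidden row [1.0], and target token 1. One step with learning rate 0.5 moves the weights to [-0.25, 0.25]. That raises the target probability and lowers the loss from ln 2.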
Transformer Feed-Forward Training Example
Rust syntax:
TransformerFeedForwardTrainingExample
ML concept:
A local feed-forward training example pairs a hidden-sequence input with a hidden-sequence target. It trains the feed-forward sublayer as a small supervised map before the book attempts full block gradients.
Category theory concept:
It is a validated training object for an endomorphism:
TransformerTrainingState -> TransformerTrainingState
First-principles reading:
The input and target must have the same sequence length and model dimension. Otherwise the squared-error training signal would compare incompatible hidden objects.
Transformer Feed-Forward Train Step
Rust syntax:
TransformerFeedForwardTrainStep : TransformerTrainingState -> TransformerTrainingState
ML concept:
This step updates the position-wise feed-forward sublayer. It computes a local squared-error gradient through the second linear layer, the ReLU gate, and the first linear layer. It leaves attention and readout parameters fixed.
Category theory concept:
The update is another endomorphism:
TransformerTrainingState -> TransformerTrainingState
First-principles reading:
This is one layer deeper than readout-only training, but it is still not full Transformer backpropagation. It is a deliberately scoped way to show that a structured state can update an internal block component without erasing the roles of the other components.
Transformer Block Training Example
Rust syntax:
TransformerBlockTrainingExample
ML concept:
A block training example pairs an input hidden sequence and attention mask with target tokens. The loss starts at sequence logits, not at a hand-written hidden target.
Category theory concept:
It is a supervised object for a state endomorphism:
TransformerTrainingState -> TransformerTrainingState
First-principles reading:
The hidden sequence length, mask shape, and target-token count must agree because the training signal is position-wise. Every row in the hidden sequence produces one vocabulary score row and expects one target token.
Transformer Block Train Step
Rust syntax:
TransformerBlockTrainStep : TransformerTrainingState -> TransformerTrainingState
ML concept:
This step updates the sequence readout, position-wise feed-forward sublayer, and attention output projection from the same token-level loss. It computes the softmax cross-entropy gradient at the readout, backpropagates through the final layer-normalization boundary and residual addition, updates the feed-forward layers through the ReLU gate, then carries the signal through the attention normalization and residual boundary to the attention output projection.
Category theory concept:
The update is a composed endomorphism:
TransformerTrainingState -> TransformerTrainingState
First-principles reading:
This is the first update in the repository where a token prediction loss reaches inside the Transformer block and updates the attention output projection, query/key/value projections, and both layer-normalization scale/shift parameter sets. It still keeps position encodings fixed. That boundary is deliberate: the implemented step is real, but it is not pretending to be a production training algorithm.
Where This Leaves Us
The glossary is not a substitute for the chapters. It is the index of the book’s repeated translation habit. When a term feels unfamiliar, connect it back to three things: the Rust syntax that names it, the ML or software role that motivates it, and the categorical shape that explains how it composes.
References
The problem this chapter solves is:
The course uses small Rust examples. These references point to the larger Rust, ML, category-theory, Transformer, and learning-science treatments behind those examples.
Use references as a source map, not as decoration. When you read a chapter, these links show where the larger Rust, ML, category-theory, Transformer, and learning-design ideas come from. A useful reference should help answer at least one of the book’s three recurring questions:
Rust syntax:
which source file in this course uses the idea?
ML concept:
which model, training, or learning behavior does the source explain?
Category theory concept:
which object, morphism, composition, product, endomorphism, functor, monoid, or law does it deepen?
How To Read Source Roles
Not every reference has the same job. Use this order when deciding what a source can support:
| Source role | Examples in this chapter | Use it for | Do not use it for |
|---|---|---|---|
| repository code and tests | src/domain.rs, src/ml.rs, src/attention.rs, examples/ | the executable claim this book actually makes | replacing the larger math or framework source |
| official documentation | Rust Book, PyTorch docs, TensorFlow/Keras docs, Hugging Face docs | language behavior, API shape, framework boundary checks | proving a category-theory law by itself |
| academic papers | Seven Sketches, Attention Is All You Need, Layer Normalization, Backprop as Functor | original claims, formal scope, research vocabulary | claiming the tiny Rust code implements the whole paper |
| open textbooks and university material | Dive into Deep Learning, CS231n, MIT Applied Category Theory | pedagogy, intuition, course sequence, worked explanations | overriding repository code or official API docs |
| implementation bridges | The Annotated Transformer, The Illustrated Transformer, developer discussions | connecting notation to code and finding likely reader confusions | serving as final authority for definitions or laws |
| learner-friction signals | review reports, public questions, workshop notes | deciding what to explain more slowly | proving technical correctness |
When two sources disagree in vocabulary, prefer the source that owns the boundary. Rust documentation owns Rust syntax. Framework documentation owns framework API shape. Academic papers own the formal claim they introduce. This book’s code owns only the smaller executable teaching claim that the chapter states and tests.
Chapter Reference Map
Rust
- Category Theory for Tiny ML in Rust GitHub repository is the public source for this book, including Rust modules, examples, exercises, and issue templates.
- Category Theory for Tiny ML in Rust public workshop is the first public workshop for discussing the draft and the tiny ML pipeline.
- The Rust Programming Language: Packages, Crates, and Modules explains how Rust packages are organized into library and binary crates. Use it with src/lib.rs, src/bin/category_ml.rs, and the examples/ files.
- The Rust Programming Language: Defining and Instantiating Structs supports the domain-object chapter’s use of named Rust structs.
- The Rust Programming Language: Defining an Enum supports enum-based modeling in src/sketches.rs and future Transformer state modeling.
- The Rust Programming Language: Generic Data Types supports the generic shapes in Product<A, B>, Compose<F, G, Middle>, and the functor examples.
- The Rust Programming Language: Defining Shared Behavior with Traits explains the trait contract behind Morphism<Input, Output>, Functor<A, B>, and Monoid.
- The Rust Programming Language: Recoverable Errors with Result explains the error pattern behind CtResult<T> and constructors such as Distribution::new.
- The Rust Programming Language: Writing Automated Tests supports the exercise design where tests act as executable feedback.
- The Rust Programming Language: Closures supports the fixed-context roadmap analogy: a callable value can capture a mask from its environment before the remaining call receives HiddenSequence.
- Rust By Example is useful when a chapter needs a smaller runnable Rust example before the real crate code.
- Rust By Example: New Type Idiom supports the idea that a wrapper type can make the compiler reject values with the wrong semantic role, even when the underlying representation is the same.
- Rust By Example: Tests gives a compact view of unit and integration test organization for readers turning exercises into checks.
- Rustlings Usage supports compiler-feedback practice where learners fix small Rust exercises from the command line.
- Rustlings Community Exercises supports the challenge-track idea that a public project can define focused exercises around one domain-specific topic.
- The rustdoc book: How to write documentation explains the documentation comments used above public types and methods.
- Rust API Guidelines Checklist is a practical review checklist for naming, documentation, type conversions, and error design.
Category Theory
- Seven Sketches in Compositionality: An Invitation to Applied Category Theory is the larger applied-category-theory text behind the companion chapter. Use it with src/sketches.rs.
- Seven Sketches in Compositionality PDF is the direct paper file for offline reading and page-by-page study.
- MIT OpenCourseWare: Applied Category Theory is a university course built around applied category theory and the Seven Sketches text. Use it when a chapter needs more examples before a formal definition.
- Categories for the Working Mathematician is the classic formal reference for category, functor, natural transformation, duality, adjunctions, limits, monoids, and related structures. Use it as precision support, not as prerequisite reading.
- Stanford Encyclopedia of Philosophy: Category Theory gives a concise academic account of objects, morphisms, identities, composition, associativity, and examples. Use it to keep the local Morphism<Input, Output> and Compose<F, G, Middle> language aligned with the formal category definition.
- Category Theory for Programming is a programming-oriented academic reference for connecting category-theory ideas to datatype and functional-programming structure.
- Category Theory for Programmers PDF source repository is a programmer-friendly bridge for readers who want a longer informal route from programming to category theory.
Machine Learning
- Dive into Deep Learning: Softmax Regression explains multiclass classification, logits, softmax, and cross entropy. Use it with src/ml.rs.
- Dive into Deep Learning: Softmax Regression Implementation from Scratch shows the implementation path behind this course’s smaller Rust version.
- Accurate Computation of the Log-Sum-Exp and Softmax Functions by Blanchard, Higham, and Higham supports the shifted softmax implementation that subtracts the maximum logit before exponentiation to improve floating-point behavior.
- On Calibration of Modern Neural Networks by Guo, Pleiss, Sun, and Weinberger supports the distinction between softmax probabilities and calibrated confidence. Use it as a modesty boundary: Distribution means normalized model probabilities in this tiny example, not a guarantee that confidence matches empirical correctness.
- Dive into Deep Learning: Gradient Descent gives the optimization background for TrainStep.
- Dive into Deep Learning: Forward Propagation, Backward Propagation, and Computational Graphs supports the chain-rule and training chapters.
- Automatic differentiation in machine learning: a survey by Baydin, Pearlmutter, Radul, and Siskind separates automatic differentiation, backpropagation, symbolic differentiation, and numerical finite differences. Use it when the book needs to distinguish a tiny hand-written gradient path from a general AD system.
- PyTorch torch.optim is official framework documentation for optimizer objects, gradient clearing, backward passes, and optimizer steps. Use it to contrast production training loops with the book’s tiny TrainStep(dataset, learning_rate) : Parameters -> Parameters boundary.
- Adam: A Method for Stochastic Optimization by Kingma and Ba introduces Adam as an adaptive stochastic optimizer based on first-moment and second-moment estimates. Use it for the Paper-To-Rust challenge claim that optimizer state must move with parameters.
- PyTorch Adam is official framework documentation for Adam’s public optimizer API, moment estimates, bias correction, state_dict, and step() boundary. Use it as a production API sanity check for the smaller AdamModelState -> AdamModelState challenge.
- Dive into Deep Learning: Numerical Stability and Initialization is useful when explaining broader gradient-scale and initialization stability issues beyond the tiny first softmax example.
- PyTorch Autograd mechanics is official framework documentation for dynamic graph recording, saved tensors, and backward traversal with the chain rule. Use it to contrast production automatic differentiation with the book’s tiny MulOp::backward boundary.
- Stanford CS231n: Optimization explains finite differences, numerical gradients, analytic gradients, and gradient checks. Use it with the finite-difference exercise and the TransformerBlockTrainStep tests.
- Stanford CS231n: Neural Networks Part 3 explains gradient-checking cautions, learning-rate checks, and small-data sanity checks. Use it when exercises ask readers to interpret a failed training or gradient-check signal.
- PyTorch gradcheck is official framework documentation for checking small finite differences against analytical gradients with tolerance, precision, and differentiability caveats. Use it to keep the finite-difference exercise honest about what a local gradient check can and cannot prove.
- PyTorch CrossEntropyLoss is official framework documentation for the common production interface where the input is unnormalized logits and the target is a class index or class probability. Use it as an API-shape sanity check for the book’s smaller Logits -> Distribution -> Product<Distribution, TokenId> -> Loss path, not as the implementation target.
- Stanford CS231n: Linear Classification explains linear classifiers, scores, losses, and the softmax classifier from a widely used university course.
- Deep Learning by Goodfellow, Bengio, and Courville is a standard textbook reference for the broader ML vocabulary behind the tiny examples.
- The Matrix Calculus You Need For Deep Learning gives a compact bridge from scalar calculus to the matrix shapes behind neural-network training. Use it as an advanced support reference for the chain-rule and gradient-check sections, not as a prerequisite.
Category Theory And Learning Systems
- Backprop as Functor: A compositional perspective on supervised learning connects supervised learning, parameter updates, gradient descent, and compositional structure. Use it carefully: the book’s TrainStep is a tiny executable analogy, not a full implementation of the paper.
- Compositional Deep Learning is a research reference for neural-network composition and categorical schemas. Use it as advanced context, not as prerequisite reading.
- Category Theory in Machine Learning surveys category-theory applications across gradient-based learning, probability, and equivariant learning. Use it to decide whether a new chapter claim belongs to a recognized research theme or should stay a local teaching analogy.
- Learners’ Languages develops the learner/update perspective around backpropagation, simple lenses, polynomial functors, and dynamical systems. Use it as advanced support for keeping TransformerTrainingState -> TransformerTrainingState modestly framed as a state-update teaching shape.
- Generalized Gradient Descent is a Hypergraph Functor treats generalized gradient descent as a functor from compositional optimization problems to open dynamical systems. Use it as advanced context for composite objectives and distributed updates, not as a prerequisite for the tiny training loop.
- Learning Functors using Gradient Descent studies category-shaped learning problems where functorial structure and composition invariants are learned with gradient descent. Use it as an advanced bridge from Seven Sketches-style schemas to learning systems.
- Categorical Deep Learning is an ICML 2024 position paper about using category theory to connect architecture constraints with implementations. Use it as advanced context for roadmap warnings that a typed implementation boundary and a mathematical architecture constraint are related but not identical.
Transformers
- Attention Is All You Need on arXiv is the original Transformer paper.
- Attention Is All You Need on the NeurIPS proceedings site is the archival conference listing.
- Dive into Deep Learning: Attention Mechanisms and Transformers is a practical bridge from softmax and vector operations to attention and Transformer blocks. Use it with src/attention.rs for the query-key scoring, mask, score-to-weight, value-mixing, head-concatenation, output-projection, residual, normalization, and feed-forward boundaries.
- Dive into Deep Learning: Queries, Keys, and Values supports the role distinction between queries, keys, and values before the code names QuerySequence, KeySequence, and ValueSequence.
- Dive into Deep Learning: Attention Scoring Functions supports the scaled dot-product, masked-softmax, and value-mixing path used by ScaledDotProductScores, MaskedAttentionScores, WeightedValueMixing, and MaskedMultiHeadTransformerBlock.
- Dive into Deep Learning: Multi-Head Attention supports the roadmap distinction between separate attention heads, concatenated head outputs, the output projection, and the MultiHeadTransformerBlock shape.
- PyTorch MultiheadAttention is a framework documentation reference for query, key, and value as separate forward inputs, separate source and target sequence shapes, total embedding dimension split across attention heads, and the convention that boolean attention and key-padding masks mark blocked or ignored positions. Use it as an API-shape sanity check for the book’s typed role split, multi-head shape arithmetic, and mask-polarity warnings.
- PyTorch Transformer is an official framework reference for encoder/decoder mask arguments where boolean masks mark positions that are not allowed to participate in attention. Use it to keep the roadmap honest that mask polarity is API-specific.
- TensorFlow Keras MultiHeadAttention is a second official framework reference for the same target/query versus source/key-value distinction: query length T, value/key length S, attention masks over (B, T, S), and an allow-mask convention where 1 means attention is allowed. Use it to keep the roadmap’s product-input boundary and mask-polarity rule from looking like a PyTorch-only convention.
- PyTorch scaled_dot_product_attention is a framework documentation reference for the implementation order: score, apply mask or bias, row-wise softmax, dropout if used, then value mixing. It is also a useful polarity warning: its boolean attn_mask uses True for participation, while some higher-level PyTorch masks use True for blocking or padding. Use it as an implementation sanity check, not as the book’s primary API target.
- PyTorch Transformer building blocks tutorial is official tutorial material on composing low-level Transformer pieces such as nested tensors, scaled_dot_product_attention, torch.compile, and FlexAttention. Use it when the roadmap needs production context for variable sequence lengths, padding, masks, fully masked rows, and the distinction between pedagogical boundaries and optimized framework blocks.
- PyTorch TransformerEncoderLayer is an official framework reference for the original Transformer encoder layer shape and the norm_first switch. Use it to keep the roadmap’s teaching boundary honest: the book can model foundational components while still being explicit that production libraries expose broader and faster variants.
- PyTorch Developer Mailing List: Understanding Multi-Head Attention for ML Framework Developers is a developer-facing implementation bridge for Q/K/V source ownership, q_len versus kv_len, target/source sequence naming, masks, and the data-flow shape behind PyTorch attention APIs.
- Dive into Deep Learning: Self-Attention and Positional Encoding supports the need for position information before sequence attention and the PositionalEncoding boundary.
- Dive into Deep Learning: Transformer Architecture supports the residual-connection, layer-normalization, position-wise feed-forward, block, decoder masking, readout, and training-loop shape requirements used by ResidualConnection, LayerNormalization, PositionWiseFeedForward, SingleHeadTransformerBlock, MultiHeadTransformerBlock, MaskedMultiHeadTransformerBlock, TransformerReadout, and TransformerTrainingState.
- Dive into Deep Learning: Parameter Management supports the idea that model parameters should be managed as explicit named components rather than scattered unnamed arrays. Use it with TinyTransformerParameters and TransformerTrainingState.
- Dive into Deep Learning: Softmax Regression Implementation from Scratch supports the readout-only gradient step used by TransformerReadoutTrainStep.
- Dive into Deep Learning: Backpropagation and Computational Graphs supports the forward-cache and reverse-computation order used by TransformerBlockTrainStep.
- Dive into Deep Learning: Gradient Descent supports the learning-rate update shape used by TransformerReadoutTrainStep, TransformerFeedForwardTrainStep, and TransformerBlockTrainStep.
- CS231n: Neural Networks Part 3 supports the roadmap’s gradient-evidence ledger: centered finite differences, relative-error reasoning, and the warning that gradient checks are local implementation checks.
- PyTorch gradcheck is official framework documentation for comparing finite differences with analytical gradients under tolerance, precision, differentiability, and memory-layout caveats. Use it to keep the roadmap’s finite-difference tests scoped as local evidence.
- Hugging Face Course: How do Transformers work? is a practitioner-facing course reference for architecture families, attention layers, masks, and the distinction between architecture, checkpoint, and model. Use it when the roadmap needs to explain why this repository builds tiny architecture pieces rather than loading pretrained checkpoints.
- Hugging Face Transformers: Model outputs is official framework documentation for returned hidden states, attentions, and output structures. Use it as an API-shape sanity check for the roadmap’s HiddenSequence, AttentionWeights, and SequenceLogits boundaries.
- PyTorch Design Philosophy is an official engineering note about PyTorch’s design trade-offs. Use it only as production-context background when the roadmap contrasts inspectable tiny Rust examples with full framework ergonomics.
- PyTorch Numerical Accuracy is an official engineering note about numerical behavior, precision, and reproducibility limits. Use it as a boundary reminder when the book moves from tiny deterministic examples toward production-scale floating-point systems.
- Hugging Face Transformers: Performance and Scalability is official engineering documentation for training and inference constraints in large Transformer systems. Use it as deployment-context background, not as a prerequisite for the tiny first-principles path.
- Layer Normalization by Ba, Kiros, and Hinton supports the layer-normalization boundary and the per-example mean-and-variance normalization used by the roadmap code.
- On Layer Normalization in the Transformer Architecture supports the roadmap warning that Post-LN and Pre-LN Transformer variants can share a public HiddenSequence -> HiddenSequence shape while differing in internal order and training behavior.
- On the Anatomy of Attention is an advanced research reference for using category-theoretic diagrams to decompose attention mechanisms, compare variants, and identify recurring attention components. Use it as support for the roadmap’s component-by-component boundary map, not as a claim that the tiny Rust code implements the paper’s full formalism.
- Self-Attention as a Parametric Endofunctor is an advanced research reference for categorical structure in the linear query, key, and value portions of self-attention. Use it as precision support when discussing linear attention structure, iterated layers, positional encodings, and the limit of the book’s claims around softmax and layer normalization.
- The Annotated Transformer is useful when the roadmap needs an implementation-oriented bridge from paper notation to code.
- The Illustrated Transformer is useful when the roadmap needs visual explanation of attention, encoder/decoder structure, and token-to-vector flow.
Learning Design
- How People Learn II: Learners, Contexts, and Cultures supports the book’s learning design: prior knowledge activation, worked examples, practice, retrieval, and attention to learner context.
- Improving Students’ Learning With Effective Learning Techniques by Dunlosky, Rawson, Marsh, Nathan, and Willingham is useful when deciding whether a chapter asks readers to practice durable techniques instead of only rereading.
- Test-Enhanced Learning: Taking Memory Tests Improves Long-Term Retention by Roediger and Karpicke supports retrieval-practice prompts that ask readers to recall, explain, and apply without looking back first.
- Structuring the Transition From Example Study to Problem Solving in Cognitive Skill Acquisition by Renkl and Atkinson supports the book’s progression from worked examples to partially completed examples and then transfer exercises.
- Self-Explanations: How Students Study and Use Examples in Learning to Solve Problems by Chi, Bassok, Lewis, Reimann, and Glaser supports self-check prompts that ask readers to explain why a worked example has the shape it has.
- Counteracting detrimental effects of misconceptions on learning and metacomprehension accuracy supports short contrast prompts that place a plausible misconception next to the corrected boundary before asking for transfer.
Use this section when checking why the chapters use worked examples, retrieval prompts, contrastive mistakes, and transfer exercises. The goal is not to cite learning science on every page. The goal is to make each chapter easier to enter, practice, remember, and transfer.
Transformer Roadmap
The problem this chapter solves is:
The repository name points toward Transformers, but the current code is a foundation course. This chapter explains exactly how the current objects and morphisms point toward a future attention-based model.
The current code is not a full Transformer.
It teaches the typed pieces you need first:
tokens
vectors
logits
probabilities
loss
training updates
composition
This distinction matters. A roadmap should not pretend the current crate is a production Transformer or a full sequence model. It should show how the current typed skeleton can grow without losing the discipline that made the small examples understandable.
Reader orientation: Read this chapter as an engineering migration plan, not as a promise that the current code already contains every Transformer component.
The source path for this roadmap is:
current Rust pipeline
-> original Transformer architecture
-> implementation-oriented attention tutorials
-> future typed Rust milestones
The original Transformer paper introduced an architecture based on attention instead of recurrence or convolution for sequence transduction. Dive into Deep Learning gives a practical learning path through queries, keys, values, multi-head attention, self-attention, positional encoding, and the full Transformer architecture. Implementation tutorials such as The Annotated Transformer and visual explainers such as The Illustrated Transformer are useful bridges from paper notation to code and diagrams.
Framework documentation such as PyTorch’s attention and Transformer layer APIs is useful for one narrower purpose: checking the public shapes that production tools expose. This chapter does not copy those APIs. It uses them as a sanity check while keeping the teaching path smaller and typed.
The Hugging Face course also gives a useful distinction for this roadmap:
architecture, checkpoint, and model are not the same idea. This repository is
working on architecture pieces: named states, typed boundaries, and update
rules. It is not loading a pretrained checkpoint, and it is not wrapping a
large framework model output. When this chapter uses words such as
HiddenSequence, AttentionWeights, or SequenceLogits, read them as tiny
Rust-owned teaching objects that make the same roles inspectable.
There is also advanced category-theory work that studies attention more directly. One recent source introduces a category-theoretic diagrammatic formalism for decomposing attention mechanisms into anatomical components and comparing attention variants. Another treats the linear query, key, and value maps through a parametric categorical lens and studies how layered self-attention structure can be composed. These sources are useful precision support, but they are not a license to call every part of a Transformer the same categorical object. The parametric-endofunctor paper itself separates its linear focus from nonlinear pieces such as softmax and layer normalization. This roadmap follows the same caution: name the linear maps, product-input boundaries, shape-preserving endomorphisms, and state updates separately.
A broader categorical deep-learning source makes the same warning at the architecture level: a model can be described by constraints it should satisfy and by the implementation that realizes those constraints. This roadmap uses that distinction as a practical rule. Do not treat a compiled Rust boundary as proof that the whole architecture satisfies a mathematical constraint. Do not treat an architecture diagram as a substitute for a concrete type, constructor, example, and test.
This chapter keeps those sources in view, but it does not import their full complexity all at once. The rule is: add one typed concept only when the tiny Rust version can explain its boundary.
Chapter Outcomes
By the end of this chapter, you should be able to:
- trace the attention example from query/key scoring through masking, softmax, value mixing, projection, residual addition, normalization, and feed-forward refinement,
- classify Transformer boundaries by counting inputs before naming morphisms, product-input morphisms, endomorphisms, or illegal attempted compositions,
- separate architecture constraints from implementation evidence in the tiny Rust roadmap.
What You Already Know
If you understand the current prediction path, you already know the skeleton a Transformer will extend. Tokens become vectors, vectors move through typed transformations, and probabilities feed a loss. The future work is to replace the one-token middle with sequence-aware structure.
Transformer Role Ownership Map
Before reading the implementation status table, separate the roles. A Transformer explanation becomes hard when query, key, value, score, weight, mask, and hidden-state roles all look like raw vectors or matrices. This roadmap assigns each role to a named Rust type or boundary.
| Transformer role | Rust owner | Boundary shape | Confusion prevented |
|---|---|---|---|
| hidden state sequence | HiddenSequence | model-width rows over sequence positions | treating one token vector as a full sequence |
| query role | QuerySequence and HiddenToQuery | HiddenSequence -> QuerySequence | passing values where queries are expected |
| key role | KeySequence and HiddenToKey | HiddenSequence -> KeySequence | comparing against value vectors instead of keys |
| value role | ValueSequence and HiddenToValue | HiddenSequence -> ValueSequence | mixing scores directly instead of value vectors |
| raw attention scores | AttentionScores | QuerySequence x KeySequence -> AttentionScores | treating unnormalized scores as probabilities |
| mask | AttentionMask | AttentionScores x AttentionMask -> AttentionScores | allowing illegal positions into softmax |
| normalized attention weights | AttentionWeights | AttentionScores -> AttentionWeights | forgetting that each query row is a distribution over source positions |
| value mixing | AttentionOutput | AttentionWeights x ValueSequence -> AttentionOutput | multiplying weights without saying what information is read |
| multiple heads | AttentionHeadOutputs and MultiHeadOutput | head outputs -> concatenated model-width rows | losing head count and head dimension |
| output projection | ProjectedAttentionOutput | MultiHeadOutput -> ProjectedAttentionOutput | leaving concatenated heads at the wrong width |
| residual boundary | ResidualConnection | HiddenSequence x ProjectedAttentionOutput -> HiddenSequence | adding tensors that cannot return to the block input shape |
| layer normalization | LayerNormalization | HiddenSequence -> HiddenSequence | changing values while accidentally changing the public object |
| feed-forward sublayer | PositionWiseFeedForward | HiddenSequence -> HiddenSequence | forgetting that the sublayer is position-wise and shape-preserving |
| block mask boundary | MaskedMultiHeadTransformerBlock | HiddenSequence x AttentionMask -> HiddenSequence | hiding the mask inside loose optional state |
| sequence readout | TransformerReadout and SequenceLogits | HiddenSequence -> SequenceLogits | confusing hidden states with vocabulary scores |
| training state | TransformerTrainingState | state plus learning rate plus step count | passing loose parameters without optimizer context |
This table is the chapter’s first debugging tool. If a later attention formula feels vague, point to the row that owns the role. The typed roadmap should make the question concrete:
Which object owns this role?
Which boundary produces it?
Which invalid connection should fail?
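The third question can be made concrete with a small sketch over plain Vec<Vec<f64>> rows. Each role gets its own newtype, so a scoring function that asks for a query side and a key side cannot silently receive a value table. QuerySequence, KeySequence, and ValueSequence follow the table above. The from_rows constructor, RoleError, and raw_scores are hypothetical stand-ins, not the crate's actual API.

```rust
/// Hypothetical error type for malformed role tables (not the crate's API).
#[derive(Debug, PartialEq)]
pub enum RoleError {
    Empty,
    Ragged { row: usize, expected: usize, found: usize },
    WidthMismatch { query: usize, key: usize },
}

/// Shared constructor check: at least one non-empty row, all rows equally wide.
fn check_rows(rows: &[Vec<f64>]) -> Result<usize, RoleError> {
    let width = rows.first().map(|r| r.len()).ok_or(RoleError::Empty)?;
    if width == 0 {
        return Err(RoleError::Empty);
    }
    for (i, r) in rows.iter().enumerate() {
        if r.len() != width {
            return Err(RoleError::Ragged { row: i, expected: width, found: r.len() });
        }
    }
    Ok(width)
}

/// One newtype per role: the numbers may look alike, the types do not.
#[derive(Debug, Clone, PartialEq)]
pub struct QuerySequence(Vec<Vec<f64>>);
#[derive(Debug, Clone, PartialEq)]
pub struct KeySequence(Vec<Vec<f64>>);
#[derive(Debug, Clone, PartialEq)]
pub struct ValueSequence(Vec<Vec<f64>>);

impl QuerySequence {
    pub fn from_rows(rows: Vec<Vec<f64>>) -> Result<Self, RoleError> {
        check_rows(&rows)?;
        Ok(Self(rows))
    }
}
impl KeySequence {
    pub fn from_rows(rows: Vec<Vec<f64>>) -> Result<Self, RoleError> {
        check_rows(&rows)?;
        Ok(Self(rows))
    }
}
impl ValueSequence {
    pub fn from_rows(rows: Vec<Vec<f64>>) -> Result<Self, RoleError> {
        check_rows(&rows)?;
        Ok(Self(rows))
    }
}

/// Unscaled QuerySequence x KeySequence scores: one row per query, one column per key.
/// Passing a ValueSequence to either slot is a compile-time type error.
pub fn raw_scores(q: &QuerySequence, k: &KeySequence) -> Result<Vec<Vec<f64>>, RoleError> {
    // Constructors guarantee at least one row, so indexing [0] is safe here.
    let (qw, kw) = (q.0[0].len(), k.0[0].len());
    if qw != kw {
        return Err(RoleError::WidthMismatch { query: qw, key: kw });
    }
    Ok(q.0
        .iter()
        .map(|qr| {
            k.0.iter()
                .map(|kr| qr.iter().zip(kr).map(|(a, b)| a * b).sum::<f64>())
                .collect::<Vec<f64>>()
        })
        .collect())
}
```

A call such as raw_scores(&values, &keys) with a ValueSequence in the query slot is rejected by the compiler with a mismatched-types error. That is the invalid connection the table asks about.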
Category Naming Contract
Before this chapter calls an attention boundary an endomorphism, count its inputs. The original Transformer architecture, the query-key-value teaching path, and framework attention APIs all expose the same warning: attention is not one vague arrow from a sequence to itself. Some stages need a query side and a source side. Some stages need a mask. Some stages need the previous hidden stream and a sublayer output.
Use this contract while reading the roadmap:
| If the boundary has shape | Name it as | Example | Do not call it |
|---|---|---|---|
| A -> B | ordinary morphism | AttentionScores -> AttentionWeights | an endomorphism |
| A -> A | endomorphism | LayerNormalization : HiddenSequence -> HiddenSequence | a product boundary |
| A x B -> C | product-input morphism | AttentionWeights x ValueSequence -> AttentionOutput | a unary transform |
| A x B -> A | product-input morphism returning A | HiddenSequence x ProjectedAttentionOutput -> HiddenSequence | an endomorphism unless the whole input object and output object are identical |
| missing projection or wrong role | illegal attempted composition | HiddenSequence x MultiHeadOutput -> HiddenSequence | a clever shortcut |
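The product-input row that returns A can be sketched with a hypothetical residual helper. It takes two arguments, so its shape is HiddenSequence x ProjectedAttentionOutput -> HiddenSequence. It rejects a mismatched sequence length or model width instead of guessing. The simplified HiddenSequence, ProjectedAttentionOutput, and ShapeError types are illustrative, not the crate's definitions.

```rust
/// Simplified stand-ins for the roadmap's types (hypothetical layouts).
#[derive(Debug, Clone, PartialEq)]
pub struct HiddenSequence(pub Vec<Vec<f64>>);
#[derive(Debug, Clone, PartialEq)]
pub struct ProjectedAttentionOutput(pub Vec<Vec<f64>>);

/// Reports (row count, first-row width) for each side of a failed residual.
#[derive(Debug, PartialEq)]
pub struct ShapeError {
    pub hidden: (usize, usize),
    pub sublayer: (usize, usize),
}

fn shape(rows: &[Vec<f64>]) -> (usize, usize) {
    (rows.len(), rows.first().map_or(0, |r| r.len()))
}

/// Product-input boundary: HiddenSequence x ProjectedAttentionOutput -> HiddenSequence.
/// Element-wise addition is only defined when every row pair has the same width.
pub fn residual(
    hidden: &HiddenSequence,
    sublayer: &ProjectedAttentionOutput,
) -> Result<HiddenSequence, ShapeError> {
    let same_shape = hidden.0.len() == sublayer.0.len()
        && hidden.0.iter().zip(&sublayer.0).all(|(h, s)| h.len() == s.len());
    if !same_shape {
        return Err(ShapeError { hidden: shape(&hidden.0), sublayer: shape(&sublayer.0) });
    }
    let rows = hidden
        .0
        .iter()
        .zip(&sublayer.0)
        .map(|(h, s)| h.iter().zip(s).map(|(a, b)| a + b).collect::<Vec<f64>>())
        .collect();
    Ok(HiddenSequence(rows))
}
```

The return type matches the first input, but the function still needs its second argument. That is why the contract table keeps this boundary out of the endomorphism row.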
One more context rule matters for learned layers. When this roadmap writes
LayerNormalization : HiddenSequence -> HiddenSequence or
PositionWiseFeedForward : HiddenSequence -> HiddenSequence, it means:
for this fixed layer instance, with its current parameters already stored
inside the Rust object
If the parameters themselves are allowed to vary, name the larger boundary
instead. For example, a parameter-learning story belongs to
TransformerTrainingState -> TransformerTrainingState, not to a hidden
sequence endomorphism that silently changes weights.
This rule keeps the category-theory vocabulary proportional to the code. The linear query, key, value, positional, and layered pieces can be compared with advanced categorical work on self-attention. Masking, softmax, residual addition, normalization, feed-forward refinement, and training updates still need their own typed boundaries in this teaching project.
Fixed-Value Endomorphism Ledger
Use this ledger whenever a roadmap boundary looks like
HiddenSequence -> HiddenSequence. The shape is not enough by itself; the
stored context must also be stable for the forward call.
| Boundary | Fixed value that makes the unary view valid | If that value changes |
|---|---|---|
| PositionalEncoding : HiddenSequence -> HiddenSequence | one position table with fixed row count and model dimension | name the table update or rebuild path separately |
| LayerNormalization : HiddenSequence -> HiddenSequence | one layer-normalization value with fixed scale, shift, and epsilon | move to TransformerTrainingState -> TransformerTrainingState |
| PositionWiseFeedForward : HiddenSequence -> HiddenSequence | one feed-forward value with fixed weights, biases, and activation rule | move to TransformerTrainingState -> TransformerTrainingState |
| MultiHeadTransformerBlock : HiddenSequence -> HiddenSequence | one block value with fixed heads, projections, residual path, normalization, and feed-forward layers | move to TransformerTrainingState -> TransformerTrainingState |
| fixed-mask view of MaskedMultiHeadTransformerBlock | one named AttentionMask selected before the hidden-sequence call | return to HiddenSequence x AttentionMask -> HiddenSequence, or name a larger state carrying the changing mask |
This table is backed by the same source roles as the precision rules below: parameter-management references explain why model components own parameters, optimizer references explain why changing parameters belongs to the training loop, and Rust closure references explain the local analogy for fixing a mask before the remaining call.
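The closure analogy can be sketched in a few lines. masked_average is a deliberately crude, hypothetical stand-in for a masked block: each output row is the plain average of the hidden rows its mask row allows, with no projections or softmax. The point is the shape. masked_average is HiddenSequence x AttentionMask -> HiddenSequence, and fix_mask returns a move closure that captures one mask and exposes the unary view.

```rust
/// Hypothetical product-input boundary: HiddenSequence x AttentionMask -> HiddenSequence.
/// Output row i is the uniform average of the hidden rows allowed by mask row i.
/// Assumes non-empty `hidden` and at least one allowed position per mask row.
pub fn masked_average(hidden: &[Vec<f64>], mask: &[Vec<bool>]) -> Vec<Vec<f64>> {
    let width = hidden[0].len();
    mask.iter()
        .map(|allowed| {
            let mut sum = vec![0.0_f64; width];
            let mut count = 0.0_f64;
            for (row, &ok) in hidden.iter().zip(allowed) {
                if ok {
                    for (s, x) in sum.iter_mut().zip(row) {
                        *s += x;
                    }
                    count += 1.0;
                }
            }
            sum.into_iter().map(|s| s / count).collect::<Vec<f64>>()
        })
        .collect()
}

/// Fix one mask and return the unary view HiddenSequence -> HiddenSequence.
/// The `move` closure owns the captured mask; a different mask needs a new closure.
pub fn fix_mask(mask: Vec<Vec<bool>>) -> impl Fn(&[Vec<f64>]) -> Vec<Vec<f64>> {
    move |hidden: &[Vec<f64>]| masked_average(hidden, &mask)
}
```

Changing the mask means building a new closure. This mirrors the ledger's rule that a changing mask needs its own named boundary.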
Add-Norm Order Ledger
Residual addition and layer normalization are not only two names that happen in the same neighborhood. Their order is part of the block boundary.
The original Transformer uses residual addition around each sublayer followed
by layer normalization. Dive into Deep Learning teaches the same AddNorm
shape as residual addition followed by layer normalization. PyTorch exposes the
order as a configurable boundary: TransformerEncoderLayer has norm_first,
where layer normalization can happen before attention and feed-forward
operations instead of after them. Research on layer-normalization placement
also treats the difference between Post-LN and Pre-LN as an optimization
question, not a cosmetic rewrite.
This repository currently teaches the post-add normalization path:
attention sublayer output
-> ResidualConnection
-> attention_norm
feed-forward sublayer output
-> ResidualConnection
-> feed_forward_norm
Use this ledger when reading or extending the block:
| Order question | Source signal | Current Rust reading | Safe category statement |
|---|---|---|---|
| original post-norm shape | original Transformer and D2L AddNorm place normalization after residual addition | ResidualConnection runs before attention_norm and feed_forward_norm | fixed block is still HiddenSequence -> HiddenSequence |
| configurable framework shape | PyTorch norm_first can move normalization before attention and feed-forward operations | no pre-norm block is implemented here yet | a future pre-norm block needs a named constructor or type |
| optimization meaning | layer-normalization placement affects gradient behavior in Transformer training | current tests validate the local post-add path only | same source and target object does not imply same morphism |
| teaching boundary | order is visible in MultiHeadTransformerBlock::apply and MaskedMultiHeadTransformerBlock::apply_with_cache | residual output is normalized before feed-forward runs | do not erase order when explaining composition |
The important category-theory lesson is modest:
post-norm block : HiddenSequence -> HiddenSequence
pre-norm block : HiddenSequence -> HiddenSequence
Those two arrows can have the same source and target while being different morphisms. Shape compatibility permits composition. It does not say the two implementations are interchangeable.
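A small numeric sketch makes the point. It assumes a parameter-free layer normalization and a toy sublayer that doubles each value; both are hypothetical simplifications. Both functions map a row to a row of the same width. Post-norm computes norm(x + f(x)), pre-norm computes x + f(norm(x)), and the two return different numbers.

```rust
/// Per-row layer normalization without learned scale or shift (hypothetical simplification).
fn layer_norm(row: &[f64], eps: f64) -> Vec<f64> {
    let n = row.len() as f64;
    let mean = row.iter().sum::<f64>() / n;
    let var = row.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
    let denom = (var + eps).sqrt();
    row.iter().map(|x| (x - mean) / denom).collect()
}

/// Element-wise residual addition.
fn add(a: &[f64], b: &[f64]) -> Vec<f64> {
    a.iter().zip(b).map(|(x, y)| x + y).collect()
}

/// Post-norm order: residual addition first, then normalization.
pub fn post_norm(x: &[f64], f: impl Fn(&[f64]) -> Vec<f64>, eps: f64) -> Vec<f64> {
    layer_norm(&add(x, &f(x)), eps)
}

/// Pre-norm order (norm_first style): normalize the sublayer input, then add the residual.
pub fn pre_norm(x: &[f64], f: impl Fn(&[f64]) -> Vec<f64>, eps: f64) -> Vec<f64> {
    let normed = layer_norm(x, eps);
    add(x, &f(normed.as_slice()))
}
```

For x = [1.0, 2.0, 3.0], the post-norm path returns a zero-mean row. The pre-norm path leaves the residual stream un-normalized, so only the ordering distinguishes the two morphisms.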
Source-Backed Precision Rules
Use this table as a citation-to-claim guard while reading the rest of the roadmap. Each source supports a local teaching rule. None of them should be used as a shortcut around the typed Rust boundary.
| Source signal | Local rule in this roadmap | Rust evidence to inspect |
|---|---|---|
| Attention Is All You Need introduces the Transformer around attention instead of recurrence or convolution | treat attention as the architecture target, not as proof that the current crate is a full Transformer | examples/06_attention_scores.rs is a shape lab, not a production model |
| Dive into Deep Learning: Scaled Dot Product Attention writes attention with n queries and m key-value pairs | keep query-side length and source-side length visible before naming the morphism | QuerySequence x KeySequence -> AttentionScores and AttentionWeights x ValueSequence -> AttentionOutput |
| PyTorch MultiheadAttention exposes separate query, key, and value inputs with target length L and source length S | do not collapse self-attention and cross-attention into one vague HiddenSequence -> HiddenSequence arrow | TargetHiddenSequence -> QuerySequence and SourceHiddenSequence -> KeySequence, ValueSequence are the future cross-attention shape |
| PyTorch scaled dot product attention says the attention mask must broadcast to the attention-weight shape, a boolean True means the element participates in attention, and a float mask is added to attention scores | a mask modifies the score table before probability normalization; it is not a token sequence and not attention weights | AttentionScores x AttentionMask -> AttentionScores runs before AttentionScores -> AttentionWeights |
| PyTorch Transformer and PyTorch MultiheadAttention expose mask arguments where boolean True can mean “not allowed” or “ignore this key” | mask shape and mask polarity are separate ideas; translate polarity before comparing APIs with this Rust roadmap | AttentionMask::new(vec![vec![true, false, true], ...]) uses true for “this source position is allowed” |
| TensorFlow Keras MultiHeadAttention uses query shape (B, T, dim), value/key shape (B, S, dim), mask shape (B, T, S), and a boolean attention mask where 1 means attention is allowed | treat target/query length, source/key-value length, and allow-mask polarity as framework-neutral shape evidence | AttentionMask answers which source positions each target position may read |
| PyTorch Transformer building blocks separates dense tensors, nested tensors, masks, scaled dot-product attention, and cross-attention concerns | production masking and variable-length behavior are framework boundary choices; the tiny Rust mask is deliberately stricter | AttentionMask::new rejects a row with no legal keys |
| PyTorch TransformerEncoderLayer exposes norm_first and the original encoder-layer reference shape | residual-normalization order is a named architecture choice, not a detail to hide behind HiddenSequence -> HiddenSequence | MultiHeadTransformerBlock::apply uses post-add normalization today; a future pre-norm variant needs a named boundary |
| On Layer Normalization in the Transformer Architecture distinguishes Post-LN and Pre-LN Transformer variants and studies their training behavior | same source and target object can still mean different morphisms when the internal order changes | local tests validate the current post-add path, not every normalization-order variant |
| Dive into Deep Learning: Parameter Management treats parameters as named model components that can be accessed and updated | a forward sublayer may be an endomorphism only for a fixed layer instance; parameter-changing claims belong to the training-state boundary | LayerNormalization stores scale and shift parameters; TransformerTrainingState owns mutable training context |
| CS231n Neural Networks Part 3 and PyTorch gradcheck compare numerical finite differences with analytical gradients under tolerance and precision caveats | a finite-difference match is local evidence for one selected parameter path, not proof of every gradient, dataset, optimizer, or future training loop | transformer_block_train_step_matches_finite_difference_for_readout_weight, feed-forward, layer-normalization, output-projection, and attention-projection tests |
| Rust Book: Closures explains closures as callable values that can capture values from their surrounding environment | use closure capture as the Rust analogy for fixing a mask context before applying a unary view | `move |
| On the Anatomy of Attention studies attention by decomposing variants into components | decompose attention first, then compare variants | the roadmap names scores, masks, weights, values, heads, projection, residuals, normalization, and feed-forward separately |
| Self-Attention as a Parametric Endofunctor focuses on linear self-attention structure and explicitly separates nonlinear pieces | use “endofunctor” language only after naming the linear scope; do not carry it through softmax, masking, residuals, normalization, or training state without a new argument | HiddenSequence -> QuerySequence is a linear role-producing morphism; AttentionScores x AttentionMask -> AttentionScores is still product-input context |
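The polarity rows in the precision table can be turned into a small conversion sketch. It assumes a block-style boolean mask where true means “not allowed” (the PyTorch Transformer and MultiheadAttention convention described in the table). It converts that mask to the allow-style mask this roadmap uses. The validation step mirrors the stricter rule that AttentionMask::new rejects a row with no legal keys. block_to_allow, validate_allow_mask, and EmptyMaskRow are hypothetical names.

```rust
/// Hypothetical error: the query row at this index has no legal source position.
#[derive(Debug, PartialEq)]
pub struct EmptyMaskRow(pub usize);

/// Convert a block-style mask (true = "not allowed" / "ignore") into an
/// allow-style mask (true = "this source position may be read").
pub fn block_to_allow(block: &[Vec<bool>]) -> Vec<Vec<bool>> {
    block
        .iter()
        .map(|row| row.iter().map(|&blocked| !blocked).collect::<Vec<bool>>())
        .collect()
}

/// Reject any allow-mask row where every position is blocked, so softmax
/// never has to normalize a row with no legal keys.
pub fn validate_allow_mask(allow: Vec<Vec<bool>>) -> Result<Vec<Vec<bool>>, EmptyMaskRow> {
    match allow.iter().position(|row| !row.iter().any(|&ok| ok)) {
        Some(i) => Err(EmptyMaskRow(i)),
        None => Ok(allow),
    }
}
```

Keras attention masks and the boolean attn_mask of PyTorch scaled_dot_product_attention already use true for allowed positions, so they would skip the negation step and go straight to validation.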
Attention Mental Model Repair Table
Use this table before the first attention example. Each row repairs a tempting mental model with one source-backed rule and one local Rust checkpoint.
| Tempting mental model | Safer model | Rust checkpoint |
|---|---|---|
| query turns into key, then key turns into value | query, key, and value are roles; in self-attention they are parallel projections from the same hidden source, and in cross-attention the query side and key-value side may come from different sequence objects | HiddenSequence -> QuerySequence, HiddenSequence -> KeySequence, and HiddenSequence -> ValueSequence are siblings, not a pipeline |
| raw scores are already attention probabilities | scores become probabilities only after mask handling and row-wise softmax; value mixing happens after weights exist | AttentionScores x AttentionMask -> AttentionScores -> AttentionWeights comes before AttentionWeights x ValueSequence -> AttentionOutput |
| same output shape means endomorphism | count the whole source object first; A x B -> A returns A, but it is still a product-input boundary while B is open | HiddenSequence x ProjectedAttentionOutput -> HiddenSequence is not the same shape as HiddenSequence -> HiddenSequence |
| fixing a mask means the mask disappeared | a fixed-context view is a new named view after one mask has been chosen; the open boundary remains product-input | MaskedMultiHeadTransformerBlock[M] : HiddenSequence -> HiddenSequence is valid only after naming the fixed AttentionMask M |
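The first repair row can be sketched with plain matrices. In the self-attention function, the three roles are parallel projections of the same hidden rows. In the cross-attention function, the query side reads a target sequence while keys and values read a source sequence. The helper names and the row-times-matrix layout are assumptions made for illustration.

```rust
pub type Rows = Vec<Vec<f64>>;

/// Row-by-matrix projection: each row (width d_in) times a d_in x d_out matrix.
/// Assumes `w` is non-empty and rectangular, with one row per input column.
fn project(rows: &[Vec<f64>], w: &[Vec<f64>]) -> Rows {
    let d_out = w[0].len();
    rows.iter()
        .map(|row| {
            (0..d_out)
                .map(|j| row.iter().zip(w).map(|(x, wr)| x * wr[j]).sum::<f64>())
                .collect::<Vec<f64>>()
        })
        .collect()
}

/// Self-attention roles: three sibling projections of the same hidden rows.
pub fn self_attention_roles(
    hidden: &[Vec<f64>],
    wq: &[Vec<f64>],
    wk: &[Vec<f64>],
    wv: &[Vec<f64>],
) -> (Rows, Rows, Rows) {
    (project(hidden, wq), project(hidden, wk), project(hidden, wv))
}

/// Cross-attention roles: queries from the target, keys and values from the source.
pub fn cross_attention_roles(
    target: &[Vec<f64>],
    source: &[Vec<f64>],
    wq: &[Vec<f64>],
    wk: &[Vec<f64>],
    wv: &[Vec<f64>],
) -> (Rows, Rows, Rows) {
    (project(target, wq), project(source, wk), project(source, wv))
}
```

Neither function feeds a query into the key projection or a key into the value projection. This is the sibling structure the table describes.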
The repair pattern is:
bad shortcut -> source-backed role or shape rule -> local Rust boundary
Do this before using a category-theory label. The label should describe the boundary the reader can inspect, not the shortcut the reader is trying to remember.
If a future chapter cites a stronger categorical result, it should add the same three pieces:
source claim
local typed boundary
validation command or test
That keeps the roadmap useful for both readers: the ML reader can see which shape is being implemented, and the category-theory reader can see which formal claim is being used and where it stops.
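To illustrate the third piece, a validation test can be as small as a centered finite-difference check for one scalar parameter. The toy loss, step size, and tolerance below are hypothetical. The repository's own gradient tests target real parameter paths such as the readout weight. The same caveat applies to both: a passing check is local evidence, not proof of every gradient.

```rust
/// Toy squared-error loss for a single weight (hypothetical, not a crate loss).
pub fn loss(w: f64, x: f64, y: f64) -> f64 {
    0.5 * (w * x - y).powi(2)
}

/// Hand-derived gradient: d/dw 0.5 * (w x - y)^2 = (w x - y) * x.
pub fn analytic_grad(w: f64, x: f64, y: f64) -> f64 {
    (w * x - y) * x
}

/// Centered finite difference: (L(w + h) - L(w - h)) / (2 h).
pub fn numeric_grad(w: f64, x: f64, y: f64, h: f64) -> f64 {
    (loss(w + h, x, y) - loss(w - h, x, y)) / (2.0 * h)
}

/// Relative error |a - b| / max(|a|, |b|), floored to avoid dividing by zero.
pub fn relative_error(a: f64, b: f64) -> f64 {
    (a - b).abs() / a.abs().max(b.abs()).max(1e-12)
}
```

A deliberately wrong gradient, such as one off by a factor of two, fails the same relative-error check. That failure is what makes the test useful evidence.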
Worked Example Priority
The roadmap now has many typed attention boundaries. A reader does not need all of them expanded at the same depth on a first pass. Use this priority table to decide which sections deserve worked examples before more implementation is added.
| Priority | Boundary | Why this comes first | Evidence to ask from a reader |
|---|---|---|---|
| 1 | AttentionScores x AttentionMask -> AttentionScores -> AttentionWeights | readers often confuse raw scores, masked scores, and probabilities | Can the reader explain which positions were removed before softmax? |
| 2 | HiddenSequence -> QuerySequence, KeySequence, ValueSequence | query, key, and value are numerically similar but semantically different roles | Can the reader say which role asks, which role is compared, and which role is mixed? |
| 3 | AttentionWeights x ValueSequence -> AttentionOutput | attention becomes useful only when weights read values | Can the reader trace one output row as a weighted sum of value rows? |
| 4 | AttentionHeadOutputs -> MultiHeadOutput -> ProjectedAttentionOutput | multi-head attention adds shape arithmetic that can hide mistakes | Can the reader compute head_count * head_dimension and name the projection input width? |
| 5 | HiddenSequence x ProjectedAttentionOutput -> HiddenSequence | residual addition explains why many sublayers return to the same object | Can the reader explain why mismatched sequence length or model dimension must fail? |
| 6 | HiddenSequence -> HiddenSequence for normalization and feed-forward | these are shape-preserving sublayers, not new sequence objects | Can the reader name what changes and what stays invariant? |
| 7 | TransformerTrainingState -> TransformerTrainingState | training is important, but it should come after forward shape ownership is clear | Can the reader separate readout-only, local feed-forward, and composed block updates? |
This table is not a ranking of importance. It is a ranking of teaching risk. The first three rows protect the core attention story:
roles -> scores -> masked weights -> mixed values
If a reader cannot trace that path, the later block and training sections will feel like a list of names. If the reader can trace it, residuals, normalization, feed-forward layers, and training state have a stable place to attach.
Worked Example: Mask Before Softmax
The original Transformer formula and implementation-oriented attention references put softmax after query-key scoring. Practical implementations add the attention mask to the score table before softmax. The reason is simple: only legal positions should compete for probability mass.
The runnable attention example starts with two query positions and three key positions:
let queries = QuerySequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
let keys = KeySequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]])?;
let values = ValueSequence::new(vec![vec![1.0, 10.0], vec![2.0, 20.0], vec![3.0, 30.0]])?;
let mask = AttentionMask::new(vec![vec![true, false, true], vec![true, true, true]])?;
For the first query row, scaled dot product produces:
raw scores:
[0.7071, 0.0000, 0.7071]
mask:
[true, false, true]
masked scores:
[0.7071, very negative, 0.7071]
row-wise softmax:
[0.5, 0.0, 0.5]
The middle key position has a real raw score, but the mask says this query position is not allowed to read it. The mask must therefore act before softmax. After softmax, the illegal position would already have received probability mass.
The value-mixing step then reads only the allowed value rows:
0.5 * [1.0, 10.0]
+ 0.0 * [2.0, 20.0]
+ 0.5 * [3.0, 30.0]
= [2.0, 20.0]
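The whole worked example can be reproduced in plain Rust. This is a standalone sketch over `Vec<f32>` rows, not the repository's typed API. The names `masked_attention` and `softmax_row` are hypothetical. `MASKED_SCORE` copies the finite constant used in src/attention.rs. The order of operations is the point: score, mask, softmax, then mix.

```rust
/// Finite stand-in for "impossible", mirroring the roadmap's MASKED_SCORE.
const MASKED_SCORE: f32 = -1_000_000.0;

fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// Numerically stable softmax over one score row.
fn softmax_row(row: &[f32]) -> Vec<f32> {
    let max = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = row.iter().map(|s| (s - max).exp()).collect();
    let total: f32 = exps.iter().sum();
    exps.iter().map(|e| e / total).collect()
}

/// Returns (weights, outputs), one row per query position.
fn masked_attention(
    queries: &[Vec<f32>],
    keys: &[Vec<f32>],
    values: &[Vec<f32>],
    mask: &[Vec<bool>], // true = this query may read this key position
) -> (Vec<Vec<f32>>, Vec<Vec<f32>>) {
    let scale = (queries[0].len() as f32).sqrt();
    let mut all_weights = Vec::new();
    let mut outputs = Vec::new();
    for (q, mask_row) in queries.iter().zip(mask) {
        // 1. scaled query-key scores, 2. mask while values are still scores
        let masked: Vec<f32> = keys
            .iter()
            .zip(mask_row)
            .map(|(k, &allowed)| if allowed { dot(q, k) / scale } else { MASKED_SCORE })
            .collect();
        // 3. row-wise softmax turns the remaining scores into weights
        let weights = softmax_row(&masked);
        // 4. weights read value rows
        let mut out = vec![0.0; values[0].len()];
        for (w, v) in weights.iter().zip(values) {
            for (o, x) in out.iter_mut().zip(v) {
                *o += w * x;
            }
        }
        all_weights.push(weights);
        outputs.push(out);
    }
    (all_weights, outputs)
}

fn main() {
    let queries = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
    let keys = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];
    let values = vec![vec![1.0, 10.0], vec![2.0, 20.0], vec![3.0, 30.0]];
    let mask = vec![vec![true, false, true], vec![true, true, true]];
    let (weights, outputs) = masked_attention(&queries, &keys, &values, &mask);
    println!("query 0 weights: {:?}", weights[0]);
    println!("query 0 output:  {:?}", outputs[0]);
}
```

Swapping steps 2 and 3 in this sketch would let the middle key take probability mass before the mask is applied, which is exactly the failure the prose describes.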
That is why the typed path is:
AttentionScores x AttentionMask -> AttentionScores
AttentionScores -> AttentionWeights
AttentionWeights x ValueSequence -> AttentionOutput
The mask boundary is not a cosmetic option. It protects the meaning of the
probability row. AttentionWeights should answer:
among the positions this query may read, how much should each one contribute?
In this repository, masked-out scores become a very negative finite value instead of a non-finite value so that the pedagogical constructors can keep the “all scores are finite” invariant. The teaching meaning is the same as the standard attention implementation pattern: make disallowed positions effectively impossible before row-wise softmax.
Mask Polarity Ledger
Mask shape answers:
which query row and source column is this mask cell about?
Mask polarity answers:
does true mean allowed, or does true mean blocked?
Those are different questions. This repository chooses the smaller teaching polarity:
true -> this query may read this source position
false -> this query may not read this source position
That choice matches the boolean mask meaning used by PyTorch’s
scaled_dot_product_attention, where True means the element participates in
attention. It also matches the Keras MultiHeadAttention attention-mask rule
where 1 marks a query-key pair that may attend. It does not match every
PyTorch attention API. In MultiheadAttention padding masks, and in the
boolean masks described by torch.nn.Transformer, True can mean the
position is blocked or ignored.
So translate a framework mask by answering three questions:
| Question | Rust roadmap answer | Framework caution |
|---|---|---|
| What is the shape? | one cell per query-source score position | L x S, (B, T, S), and key-padding masks such as PyTorch's (N, S) index different axes |
| What is the polarity? | true means allowed | some APIs use true to mean blocked or padding |
| When is it applied? | before softmax, while values are still scores | after-softmax masking would change the meaning of the probability row |
The safe translation rule is:
first match the mask cells to score cells,
then translate boolean polarity,
then apply the mask before softmax
Do not carry a raw boolean mask from a framework into this Rust roadmap without stating its polarity. Two masks can have the same shape and opposite meaning.
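One way to make polarity impossible to forget is to give each polarity its own type, so that translating between them must be written down. `AllowedMask` and `BlockedMask` are hypothetical names for this sketch. The repository's AttentionMask uses the allowed polarity. The shape check from the first question is assumed to have happened already; this conversion only flips meaning.

```rust
/// Polarity used by this roadmap: true = this query may read this source position.
#[derive(Debug, Clone, PartialEq)]
struct AllowedMask(Vec<Vec<bool>>);

/// Opposite polarity, as some framework APIs use it: true = blocked or ignored.
#[derive(Debug, Clone, PartialEq)]
struct BlockedMask(Vec<Vec<bool>>);

impl From<BlockedMask> for AllowedMask {
    /// Translating polarity flips every cell and keeps the shape unchanged.
    fn from(blocked: BlockedMask) -> Self {
        AllowedMask(
            blocked
                .0
                .into_iter()
                .map(|row| row.into_iter().map(|cell| !cell).collect::<Vec<bool>>())
                .collect(),
        )
    }
}

fn main() {
    // Same permissions as the worked example, written with blocked polarity.
    let from_framework = BlockedMask(vec![vec![false, true, false], vec![false, false, false]]);
    let local: AllowedMask = from_framework.into();
    println!("{:?}", local);
    // AllowedMask([[true, false, true], [true, true, true]])
}
```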
Production Masking Caveat
The tiny AttentionMask in src/attention.rs is stricter than a production
framework boundary. For example, this constructor call is rejected:
AttentionMask::new(vec![vec![false, false]])
The reason is pedagogical. In this book, every attention-weight row should mean:
among at least one legal source position, how much should each one contribute?
If a row allows no source positions, there is no probability support for that row. The constructor therefore returns:
Err(CtError::EmptyInput("attention mask row allows no keys"))
That boundary is intentionally less general than production Transformer
libraries. PyTorch’s Transformer building-blocks tutorial discusses nested
tensors, variable sequence lengths, padding masks, and the production problem
of fully masked rows. It notes that softmax over an empty set is undefined and
that newer scaled_dot_product_attention behavior returns zero output for
fully masked rows.
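The fully-masked-row problem is easy to reproduce with a plain stable softmax. This standalone sketch uses a hypothetical `softmax_row` helper, not repository code. It contrasts the roadmap's finite masked score with negative infinity. In f32, exp of a value near -1,000,000 underflows to exactly 0.0, so one masked cell is harmless. When every cell is masked, neither convention yields a meaningful row, which is why a framework must choose an output rule and the tiny constructor refuses such rows.

```rust
/// The finite "impossible" score used by the roadmap.
const MASKED_SCORE: f32 = -1_000_000.0;

/// Numerically stable softmax: subtract the row maximum before exp.
fn softmax_row(row: &[f32]) -> Vec<f32> {
    let max = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = row.iter().map(|s| (s - max).exp()).collect();
    let total: f32 = exps.iter().sum();
    exps.iter().map(|e| e / total).collect()
}

fn main() {
    // One legal position and one masked position: exp underflows to 0.0.
    println!("{:?}", softmax_row(&[0.7071, MASKED_SCORE])); // [1.0, 0.0]

    // Every position masked with the finite constant: the row silently
    // turns uniform, inventing probability support that does not exist.
    println!("{:?}", softmax_row(&[MASKED_SCORE, MASKED_SCORE])); // [0.5, 0.5]

    // Every position masked with -inf: -inf - (-inf) is NaN.
    println!("{:?}", softmax_row(&[f32::NEG_INFINITY, f32::NEG_INFINITY])); // [NaN, NaN]
}
```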
The contrast is useful:
| Concern | Production framework boundary | Tiny teaching boundary |
|---|---|---|
| variable sequence lengths | ragged batches, padding, nested tensors, and mask ergonomics | each example uses one explicit rectangular mask |
| fully masked query row | framework must decide a stable output convention | constructor rejects the row before softmax |
| performance | fused kernels, compilation, and memory-aware representations | small values the reader can inspect by hand |
This is not a disagreement with framework behavior. It is a scope decision. It
preserves the invariant that AttentionWeights is a row-wise
distribution over at least one source position. A future production-oriented
chapter can relax that boundary only if it also names the new output convention
for rows with no legal source positions.
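The constructor rule described above can be sketched as a validating newtype. `ValidatedMask` and `MaskError` are hypothetical stand-ins for the repository's AttentionMask and CtError. The sketch only shows the shape of the checks: non-empty, rectangular, and every row allowing at least one source position.

```rust
/// Hypothetical stand-in for CtError, limited to mask failures.
#[derive(Debug, PartialEq)]
enum MaskError {
    Empty,
    RaggedRow { expected: usize, got: usize },
    RowAllowsNoKeys { row: usize },
}

/// Hypothetical validated mask (true = allowed). Every query row must be
/// able to read at least one source position.
#[derive(Debug)]
struct ValidatedMask(Vec<Vec<bool>>);

impl ValidatedMask {
    fn new(rows: Vec<Vec<bool>>) -> Result<Self, MaskError> {
        let width = rows.first().map_or(0, |row| row.len());
        if width == 0 {
            return Err(MaskError::Empty);
        }
        for (index, row) in rows.iter().enumerate() {
            if row.len() != width {
                return Err(MaskError::RaggedRow { expected: width, got: row.len() });
            }
            // A row with no allowed position has no probability support.
            if !row.iter().any(|&allowed| allowed) {
                return Err(MaskError::RowAllowsNoKeys { row: index });
            }
        }
        Ok(Self(rows))
    }

    fn rows(&self) -> &[Vec<bool>] {
        &self.0
    }
}

fn main() {
    println!("{:?}", ValidatedMask::new(vec![vec![false, false]]));
    // Err(RowAllowsNoKeys { row: 0 })
    let mask = ValidatedMask::new(vec![vec![true, false, true], vec![true, true, true]]).unwrap();
    println!("{} legal rows", mask.rows().len());
}
```

A production-oriented variant would replace the `RowAllowsNoKeys` rejection with an explicit output convention, as the paragraph above requires.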
What Exists Now
The current model has this prediction path:
TokenId -> Vector -> Logits -> Distribution
The implementation status is:
| Concept | Current status | Reason |
|---|---|---|
| Token ids | implemented | TokenId and TokenSequence already exist |
| Vectors | implemented | Vector is the current hidden representation |
| Logits and probabilities | implemented | LinearToLogits and Softmax are executable |
| Loss | implemented | CrossEntropy evaluates prediction against target |
| Parameter update | implemented | TrainStep updates Parameters |
| Query-key score boundary | implemented as a tiny roadmap sketch | QuerySequence x KeySequence -> AttentionScores is executable |
| Attention mask boundary | implemented as a tiny roadmap sketch | AttentionScores x AttentionMask -> AttentionScores is executable |
| Attention score-to-weight boundary | implemented as a tiny roadmap sketch | AttentionScores -> AttentionWeights is executable |
| Value-mixing boundary | implemented as a tiny roadmap sketch | AttentionWeights x ValueSequence -> AttentionOutput is executable |
| Multi-head concatenation boundary | implemented as a tiny roadmap sketch | AttentionHeadOutputs -> MultiHeadOutput is executable |
| Attention output projection boundary | implemented as a tiny roadmap sketch | MultiHeadOutput -> ProjectedAttentionOutput is executable |
| Sequence hidden states | implemented as a tiny roadmap sketch | HiddenSequence is executable for residual addition |
| Residual addition boundary | implemented as a tiny roadmap sketch | HiddenSequence x ProjectedAttentionOutput -> HiddenSequence is executable |
| Layer normalization boundary | implemented as a tiny roadmap sketch | HiddenSequence -> HiddenSequence is executable through LayerNormalization |
| Position-wise feed-forward boundary | implemented as a tiny roadmap sketch | HiddenSequence -> HiddenSequence is executable through PositionWiseFeedForward |
| Hidden-to-query/key/value projections | implemented as a tiny roadmap sketch | HiddenSequence -> QuerySequence, HiddenSequence -> KeySequence, and HiddenSequence -> ValueSequence are executable |
| Single-head block boundary | implemented as a tiny roadmap sketch | SingleHeadTransformerBlock : HiddenSequence -> HiddenSequence composes the current boundaries |
| Multi-head block boundary | implemented as a tiny roadmap sketch | MultiHeadTransformerBlock : HiddenSequence -> HiddenSequence composes several SelfAttentionHead values |
| Positional encoding | implemented as a tiny roadmap sketch | PositionalEncoding : HiddenSequence -> HiddenSequence adds position rows while preserving shape |
| Masked block variants | implemented as a tiny roadmap sketch | MaskedMultiHeadTransformerBlock : HiddenSequence x AttentionMask -> HiddenSequence accepts a block-level mask |
| Sequence logits and readout | implemented as a tiny roadmap sketch | TransformerReadout : HiddenSequence -> SequenceLogits produces vocabulary scores at each sequence position |
| Structured Transformer parameter object | implemented as a tiny roadmap sketch | TinyTransformerParameters : HiddenSequence x AttentionMask -> SequenceLogits owns position, masked block, and readout pieces |
| Structured Transformer training state | implemented as a tiny roadmap sketch | TransformerTrainingState owns parameters, learning rate, and step count |
| Readout-only training step | implemented as a tiny roadmap sketch | TransformerReadoutTrainStep : TransformerTrainingState -> TransformerTrainingState updates only the sequence readout |
| Local feed-forward training step | implemented as a tiny roadmap sketch | TransformerFeedForwardTrainStep : TransformerTrainingState -> TransformerTrainingState updates only the position-wise feed-forward sublayer against hidden targets |
| Composed block training step | implemented as a tiny roadmap sketch | TransformerBlockTrainStep : TransformerTrainingState -> TransformerTrainingState updates readout, feed-forward, attention-output-projection, query/key/value, and layer-normalization parameters from sequence targets through residual, normalization, and attention paths |
This table is a guardrail. When extending the project, do not present planned items as implemented content. Add the type, example, test, chapter prose, and reference link together.
Rust Syntax
The path is implemented with:
Embedding
LinearToLogits
Softmax
Compose
The main domain objects are:
TokenId
Vector
Logits
Distribution
Parameters
The training update is:
TrainStep : Parameters -> Parameters
ML Concept
This is a tiny next-token model.
It predicts from one token at a time.
The main training example is still that small. The roadmap module now sketches attention blocks and structured Transformer state, but it does not yet train a production Transformer.
Still, it already teaches the core path:
discrete token
-> dense representation
-> vocabulary scores
-> next-token probabilities
Category Theory Concept
The current system teaches composition:
TokenId -> Vector -> Logits -> Distribution
and endomorphism:
Parameters -> Parameters
Those two shapes remain central in Transformers.
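In plain Rust, the two shapes can be sketched with closures. The types below are toy stand-ins with hypothetical values. The repository uses its own Compose and TrainStep types, whose signatures are not reproduced here. `compose` builds one morphism from two. `iterate` applies an endomorphism repeatedly without ever leaving its object.

```rust
#[derive(Debug, Clone, Copy)]
struct TokenId(usize);

#[derive(Debug)]
struct Vector(Vec<f32>);

#[derive(Debug, PartialEq)]
struct Logits(Vec<f32>);

#[derive(Debug, Clone, Copy, PartialEq)]
struct Parameters {
    weight: f32,
}

/// Composition: (A -> B) followed by (B -> C) is one morphism A -> C.
fn compose<A, B, C>(f: impl Fn(A) -> B, g: impl Fn(B) -> C) -> impl Fn(A) -> C {
    move |a| g(f(a))
}

/// Repeating an endomorphism T -> T never leaves the object T.
fn iterate<T>(step: impl Fn(T) -> T, times: usize, start: T) -> T {
    (0..times).fold(start, |state, _| step(state))
}

fn main() {
    // Toy embedding table: one row per token id.
    let table = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
    let embed = move |token: TokenId| Vector(table[token.0].clone());
    // Toy readout: double every coordinate.
    let to_logits = |v: Vector| Logits(v.0.iter().map(|&x| 2.0 * x).collect());
    let predict = compose(embed, to_logits);
    println!("{:?}", predict(TokenId(1))); // Logits([0.0, 2.0])

    // Toy "train step": move the weight halfway toward 1.0.
    let step = |p: Parameters| Parameters { weight: p.weight + 0.5 * (1.0 - p.weight) };
    println!("{:?}", iterate(step, 3, Parameters { weight: 0.0 }));
    // Parameters { weight: 0.875 }
}
```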
Step 1: Sequences As First-Class Objects
The future problem:
Attention does not operate on one token alone. It operates on a sequence of hidden states.
The current code already has TokenSequence, but that is a sequence of token
ids. Attention needs a sequence of hidden vectors, usually with position and
mask information attached. That is a different object with different
invariants.
Worked Example: Validating Sequence Length
The first-principles Rust move is the same one used throughout the book: do not let a meaningful value travel as a raw primitive once it crosses a conceptual boundary. The roadmap module now starts with a small validating type:
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SequenceLength(usize);
impl SequenceLength {
pub fn new(value: usize) -> CtResult<Self> {
if value == 0 {
return Err(CtError::EmptyInput("sequence length"));
}
Ok(Self(value))
}
pub fn value(&self) -> usize {
self.0
}
}
Self-Check
Before reading the roadmap steps, explain why SequenceLength should not be
passed around as a bare usize, even though it wraps one.
Rust Syntax
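The sequence-level boundaries use small named types. The sketches below show the intent; the current src/attention.rs definitions of SequenceLength, HiddenSequence, and AttentionMask add validating constructors, and HiddenSequence also records its length and model dimension: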
pub struct Position(usize);
pub struct SequenceLength(usize);
pub struct HiddenSequence(Vec<Vector>);
pub struct AttentionMask(/* validated mask representation */);
The important rule is the same one this course follows throughout:
do not pass raw vectors across architectural boundaries
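A hypothetical sketch of how Position could depend on SequenceLength: the position constructor checks the index against a length that has already been validated. The `String` error stands in for the repository's CtError, and this Position type is an illustration, not existing code.

```rust
/// Number of positions in a sequence; zero is rejected.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct SequenceLength(usize);

/// Hypothetical index into a sequence, valid only relative to a length.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct Position(usize);

impl SequenceLength {
    fn new(value: usize) -> Result<Self, String> {
        if value == 0 {
            return Err("sequence length must be positive".to_string());
        }
        Ok(Self(value))
    }
}

impl Position {
    /// Valid positions are 0..len, so the constructor needs the length.
    fn new(index: usize, len: SequenceLength) -> Result<Self, String> {
        if index >= len.0 {
            return Err(format!("position {index} out of range for length {}", len.0));
        }
        Ok(Self(index))
    }
}

fn main() {
    let len = SequenceLength::new(3).unwrap();
    println!("{:?}", Position::new(2, len)); // Ok(Position(2))
    println!("{:?}", Position::new(3, len)); // Err("position 3 out of range for length 3")
}
```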
ML Concept
Attention needs a representation like:
[hidden_0, hidden_1, hidden_2, ...]
plus position and mask information.
Category Theory Concept
The object changes from:
Vector
to:
Sequence(Vector)
The next morphisms operate on structured sequences.
Design contract:
TokenSequence -> HiddenSequence
should not be represented as:
Vec<usize> -> Vec<Vec<f32>>
The second shape hides every domain distinction the course has worked to make visible.
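A minimal sketch of the typed version of that contract, assuming an embedding table as the morphism. `EmbeddingTable` and the simplified `TokenSequence` and `HiddenSequence` wrappers are hypothetical; the repository's HiddenSequence validates more. With raw `Vec<usize>` input, an out-of-vocabulary id would panic on indexing. Here it becomes a typed error.

```rust
#[derive(Debug, Clone, Copy)]
struct TokenId(usize);

/// Token ids, not hidden vectors.
struct TokenSequence(Vec<TokenId>);

/// One hidden row per token position.
#[derive(Debug, PartialEq)]
struct HiddenSequence(Vec<Vec<f32>>);

/// Hypothetical embedding table: one row per vocabulary entry.
struct EmbeddingTable {
    rows: Vec<Vec<f32>>,
}

impl EmbeddingTable {
    /// TokenSequence -> HiddenSequence. An out-of-vocabulary id becomes an
    /// error instead of a panic on a raw index.
    fn embed(&self, tokens: &TokenSequence) -> Result<HiddenSequence, String> {
        tokens
            .0
            .iter()
            .map(|token| {
                self.rows.get(token.0).cloned().ok_or_else(|| {
                    format!("token id {} outside vocabulary of {}", token.0, self.rows.len())
                })
            })
            .collect::<Result<Vec<_>, _>>()
            .map(HiddenSequence)
    }
}

fn main() {
    let table = EmbeddingTable {
        rows: vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]],
    };
    let tokens = TokenSequence(vec![TokenId(2), TokenId(0)]);
    println!("{:?}", table.embed(&tokens));
    // Ok(HiddenSequence([[5.0, 6.0], [1.0, 2.0]]))
    println!("{:?}", table.embed(&TokenSequence(vec![TokenId(7)])));
    // Err("token id 7 outside vocabulary of 3")
}
```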
Step 2: Query, Key, And Value Projections
The current problem:
Attention compares tokens by projecting hidden states into query, key, and value spaces.
The important design move is not only three matrices. It is three roles. A query vector, key vector, and value vector may share the same numeric representation, but they should not share the same Rust type once they cross a module boundary.
Rust Syntax
The current projection morphisms have shapes:
HiddenSequence -> QuerySequence
HiddenSequence -> KeySequence
HiddenSequence -> ValueSequence
Each output type should be distinct.
The current roadmap code models both the role objects and the hidden-state projection morphisms:
HiddenToQuery
HiddenToKey
HiddenToValue
QuerySequence
KeySequence
ValueSequence
Queries, keys, and values are all vectors underneath, but they have different roles.
Worked Example: Same Hidden Row, Three Roles
The query-key-value split is not about three mysterious kinds of vector. It is about three uses of a hidden state.
Start with two hidden rows:
hidden_0 = [1.0, 2.0]
hidden_1 = [3.0, 4.0]
A tiny set of projections can send the same hidden rows into three role-specific objects:
let hidden = HiddenSequence::new(vec![vec![1.0, 2.0], vec![3.0, 4.0]])?;
let to_query = HiddenToQuery::new(
vec![vec![1.0, 0.0], vec![0.0, 1.0]],
vec![0.0, 0.0],
)?;
let to_key = HiddenToKey::new(
vec![vec![0.0, 1.0], vec![1.0, 0.0]],
vec![0.0, 0.0],
)?;
let to_value = HiddenToValue::new(
vec![vec![10.0, 0.0], vec![0.0, 10.0]],
vec![0.0, 0.0],
)?;
For hidden_0, those projections produce:
query_0 = [1.0, 2.0]
key_0 = [2.0, 1.0]
value_0 = [10.0, 20.0]
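The arithmetic behind those three rows can be checked with a plain affine projection. The weight layout (input_dim rows of output_dim columns, output = row · W + bias) is an assumption of this sketch, not a statement about HiddenToQuery's internals. Because all three example matrices are symmetric, the numbers come out the same under either layout convention.

```rust
/// Hypothetical affine projection: output[j] = bias[j] + sum_i row[i] * weights[i][j].
/// Weights are stored as input_dim rows of output_dim columns (an assumption).
fn project(row: &[f32], weights: &[Vec<f32>], bias: &[f32]) -> Vec<f32> {
    bias.iter()
        .enumerate()
        .map(|(j, b)| b + row.iter().zip(weights).map(|(x, w_row)| x * w_row[j]).sum::<f32>())
        .collect()
}

fn main() {
    let hidden = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
    let w_query = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
    let w_key = vec![vec![0.0, 1.0], vec![1.0, 0.0]];
    let w_value = vec![vec![10.0, 0.0], vec![0.0, 10.0]];
    let zero_bias = vec![0.0, 0.0];
    // The same hidden row feeds all three projections.
    for (i, row) in hidden.iter().enumerate() {
        println!("query_{i} = {:?}", project(row, &w_query, &zero_bias));
        println!("key_{i}   = {:?}", project(row, &w_key, &zero_bias));
        println!("value_{i} = {:?}", project(row, &w_value, &zero_bias));
    }
}
```

For hidden_1 the same projections give query_1 = [3.0, 4.0], key_1 = [4.0, 3.0], and value_1 = [30.0, 40.0].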
The numbers are deliberately simple. The important lesson is the role separation:
| Role | Question it answers | Used for |
|---|---|---|
| query | what is this position looking for? | compared with keys |
| key | what can this source position be matched by? | compared with queries |
| value | what information can this source position contribute? | mixed after weights exist |
If all three values were passed around as Vec<Vec<f32>>, the compiler could
not help a reader notice a role mistake. ValueSequence could accidentally be
fed into query-key scoring. KeySequence could accidentally be mixed as if it
were content. The typed split makes that confusion harder to express.
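A minimal sketch of that compiler help, using hypothetical one-field wrappers rather than the repository's sequence types. `score` only accepts a query and a key, and `mix` only accepts value rows. Passing a `Value` where a `Key` is expected is rejected at compile time, even though both hold a `Vec<f32>`.

```rust
/// Role wrappers: same numeric representation, different meaning.
struct Query(Vec<f32>);
struct Key(Vec<f32>);
struct Value(Vec<f32>);

/// Scoring accepts only a query and a key.
fn score(query: &Query, key: &Key) -> f32 {
    let scale = (query.0.len() as f32).sqrt();
    query.0.iter().zip(&key.0).map(|(q, k)| q * k).sum::<f32>() / scale
}

/// Mixing accepts only weights and value rows.
fn mix(weights: &[f32], values: &[Value]) -> Vec<f32> {
    let mut out = vec![0.0; values[0].0.len()];
    for (w, value) in weights.iter().zip(values) {
        for (o, x) in out.iter_mut().zip(&value.0) {
            *o += w * x;
        }
    }
    out
}

fn main() {
    let query = Query(vec![1.0, 0.0]);
    let key = Key(vec![1.0, 1.0]);
    let values = [Value(vec![1.0, 10.0]), Value(vec![3.0, 30.0])];
    println!("score = {}", score(&query, &key));
    println!("mixed = {:?}", mix(&[0.5, 0.5], &values)); // mixed = [2.0, 20.0]
    // score(&query, &values[0]); // does not compile: expected `&Key`, found `&Value`
}
```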
This also explains why the attention path has two phases:
QuerySequence x KeySequence -> AttentionScores
AttentionWeights x ValueSequence -> AttentionOutput
Queries and keys decide where to look. Values provide what gets read.
Self-Attention And Cross-Attention Boundary
The current roadmap example is self-attention: query, key, and value roles all
come from the same HiddenSequence before they are projected into separate
role objects.
That is only one attention case.
Official framework documentation exposes a more general boundary. PyTorch’s
multi-head attention API accepts query, key, and value as separate
inputs. Its shape language distinguishes target sequence length L for
queries from source sequence length S for keys and values. Dive into Deep
Learning makes the same teaching distinction when it writes attention over
n queries and m key-value pairs.
TensorFlow/Keras exposes the same split with different letters: query has
target length T, value and key have source length S, and the attention mask
has shape (B, T, S). That cross-framework agreement is useful because it
keeps the rule from sounding like a PyTorch naming quirk. A target/query row
asks a question. A source/key-value column is something that can be read.
This matters for the book because it prevents a subtle category mistake. The attention scoring boundary is not automatically:
HiddenSequence -> HiddenSequence
The more honest shape is:
Target positions x Source positions -> attention weights
or, in the current Rust vocabulary:
QuerySequence x KeySequence -> AttentionScores
AttentionWeights x ValueSequence -> AttentionOutput
Self-attention is the special case where the target positions and source positions come from the same hidden sequence:
HiddenSequence -> QuerySequence
HiddenSequence -> KeySequence
HiddenSequence -> ValueSequence
These are parallel projections, not a pipeline where queries turn into keys and keys turn into values. The shared source is what makes the case “self-attention”; the role split is still real after projection.
Cross-attention is the case where the query side and the key-value side come from different sequence objects:
TargetHiddenSequence -> QuerySequence
SourceHiddenSequence -> KeySequence
SourceHiddenSequence -> ValueSequence
The tiny repository does not implement a full cross-attention module yet. But the naming rule should already be clear:
same source for Q, K, V -> self-attention case
separate query and key-value sources -> cross-attention case
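The naming rule shows up directly in the score-table shape. This hypothetical sketch uses rows that are assumed to be already projected. In the cross-attention call, queries come from a 2-position target and keys from a 3-position source, so the table is 2 x 3. In the toy self-attention call, both sides come from one 3-position sequence, so the table is square.

```rust
fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// Scaled scores: one row per query (target) position and one column
/// per key (source) position.
fn scores(queries: &[Vec<f32>], keys: &[Vec<f32>]) -> Vec<Vec<f32>> {
    let scale = (queries[0].len() as f32).sqrt();
    queries
        .iter()
        .map(|q| keys.iter().map(|k| dot(q, k) / scale).collect())
        .collect()
}

fn main() {
    // Cross-attention: target and source sequences differ in length.
    let target_queries = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
    let source_keys = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];
    let cross = scores(&target_queries, &source_keys);
    println!("cross: {} x {}", cross.len(), cross[0].len()); // cross: 2 x 3

    // Toy self-attention case: both sides come from one sequence.
    let self_rows = source_keys.clone();
    let self_scores = scores(&self_rows, &source_keys);
    println!("self:  {} x {}", self_scores.len(), self_scores[0].len()); // self:  3 x 3
}
```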
Use this Q/K/V source diagnostic before reading a framework call:
| Question | Self-attention answer | Cross-attention answer |
|---|---|---|
| Which sequence owns the query side? | the same hidden sequence | the target hidden sequence |
| Which sequence owns the key side? | the same hidden sequence | the source hidden sequence |
| Which sequence owns the value side? | the same hidden sequence | the source hidden sequence |
| Which length counts score rows? | target/query length | target/query length |
| Which length counts score columns? | source/key-value length, equal to target length in the simple self-attention case | source/key-value length, possibly different from target length |
This table prevents a common framework-reading mistake. Passing the same hidden sequence into Q, K, and V means the source object is shared. It does not mean the projected query, key, and value roles have become the same role.
When you run the attention example, the first lines now anchor that diagnostic before any probabilities appear:
Q/K/V source diagnostic:
query rows own score rows; key/value rows own score columns
self-attention shares the hidden source before projection; projected roles stay distinct
mask polarity here: true = allowed, false = blocked
Read those four lines before interpreting the reported attention shape (2 query positions x 3 key positions). The terminal output gives the learner one inspectable signal
for the source-backed rule above: score rows come from the query side, score
columns come from the key-value side, and the local mask polarity must be
translated before comparing the Rust example with a framework API.
PyTorch and TensorFlow/Keras use different names but expose the same shape split:
| Framework cue | Query side | Key-value side | Mask cue |
|---|---|---|---|
| PyTorch | target length L | source length S | attention weights and masks use L x S |
| TensorFlow/Keras | target length T | source length S | mask shape is (B, T, S) |
| Rust roadmap | QuerySequence | KeySequence and ValueSequence | AttentionMask says which source positions each query may read |
Use the same ledger when reading the Rust types:
| Ledger item | Meaning in framework docs | Meaning in this roadmap | Category-shape consequence |
|---|---|---|---|
| target length | PyTorch L, Keras T | number of QuerySequence rows | score rows belong to the query-side object |
| source length | PyTorch/Keras S | number of KeySequence and ValueSequence rows | score columns belong to the key-value source object |
| attention mask | PyTorch L x S, Keras (B, T, S) | one permission table from query rows to source rows | the mask is context over a product boundary |
| attention output | target-side output rows | one AttentionOutput row for each query row | value mixing returns information to the query side |
The shape ledger gives a quick sanity check:
score table rows == query positions
score table columns == key-value positions
mask cells == query-position/source-position permissions
output rows == query positions after reading values
If those four statements are not true, the explanation has probably collapsed source ownership, role ownership, or mask context too early.
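The four ledger statements can be written as a shape check. `AttentionShapes` and `ledger_violations` are hypothetical names; the sketch only compares counts and returns every statement that fails, so one broken explanation can be diagnosed at once.

```rust
/// Row and column counts gathered from one attention call.
struct AttentionShapes {
    query_rows: usize,
    key_rows: usize,
    value_rows: usize,
    score_rows: usize,
    score_cols: usize,
    mask_rows: usize,
    mask_cols: usize,
    output_rows: usize,
}

/// Returns every ledger statement that does not hold.
fn ledger_violations(s: &AttentionShapes) -> Vec<&'static str> {
    let mut failures = Vec::new();
    if s.score_rows != s.query_rows {
        failures.push("score table rows != query positions");
    }
    if s.score_cols != s.key_rows || s.key_rows != s.value_rows {
        failures.push("score table columns != key-value positions");
    }
    if (s.mask_rows, s.mask_cols) != (s.score_rows, s.score_cols) {
        failures.push("mask cells != query-position/source-position permissions");
    }
    if s.output_rows != s.query_rows {
        failures.push("output rows != query positions");
    }
    failures
}

fn main() {
    // The worked example: 2 query positions, 3 key-value positions.
    let worked = AttentionShapes {
        query_rows: 2,
        key_rows: 3,
        value_rows: 3,
        score_rows: 2,
        score_cols: 3,
        mask_rows: 2,
        mask_cols: 3,
        output_rows: 2,
    };
    println!("{:?}", ledger_violations(&worked)); // []

    // A collapsed explanation that returns one output row per key.
    let collapsed = AttentionShapes { output_rows: 3, ..worked };
    println!("{:?}", ledger_violations(&collapsed)); // ["output rows != query positions"]
}
```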
Mask Role Ledger: Permissions, Not Tokens
Framework APIs make the mask shape look like another tensor argument, but the teaching question is more specific:
Which query rows may read which source columns before softmax?
That is why the roadmap names the mask separately from the token sequence, score table, and attention weights. The mask is permission context over the score table.
| Mask misreading | Correct local boundary | What to inspect |
|---|---|---|
| the mask is a shorter token sequence | AttentionScores x AttentionMask -> AttentionScores | the score table keeps query rows and source columns |
| the mask directly produces probabilities | AttentionScores -> AttentionWeights still happens after masking | query 0 attends with [0.5, 0.0, 0.5] |
| the mask is hidden global state | MaskedMultiHeadTransformerBlock : HiddenSequence x AttentionMask -> HiddenSequence | the block boundary keeps the mask visible |
| a fixed mask means no mask exists | MaskedMultiHeadTransformerBlock[M] : HiddenSequence -> HiddenSequence | the chosen mask M is fixed context for that view |
The local rule is:
mask cells select legal score cells;
softmax turns remaining score rows into weights;
weights read value rows.
Do not say “the mask removes tokens” unless you also say which score cells were removed from probability competition. The source sequence still owns the value rows. The mask only says which of those rows each query is allowed to read.
The category-theory reading follows the input count. Self-attention can be wrapped inside a shape-preserving block because projection, masking, value mixing, output projection, residual addition, and normalization together return to the hidden stream. The core scoring and mixing steps are still product-input morphisms. Cross-attention makes that product input impossible to ignore, because the target side and source side may have different sequence lengths.
When a framework call reports an attention mask of shape L x S, read it as a
typed reminder:
for each target position, which source positions may be read?
That is why this roadmap names QuerySequence, KeySequence, ValueSequence,
AttentionScores, AttentionMask, AttentionWeights, and AttentionOutput
separately. The names keep target-side questions, source-side comparison, and
source-side information from collapsing into one raw tensor.
ML Concept
Queries ask:
what am I looking for?
Keys answer:
what do I contain?
Values provide:
what information should be mixed?
Category Theory Concept
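These are three morphisms out of the same object. They are not parallel morphisms in the technical sense, which requires a shared codomain as well as a shared domain; here each arrow lands in a different role object. By the universal property of the product, the three arrows together determine one morphism HiddenSequence -> QuerySequence x KeySequence x ValueSequence. Its components are: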
HiddenSequence -> QuerySequence
HiddenSequence -> KeySequence
HiddenSequence -> ValueSequence
The current attention example combines query and key roles to produce scores, then uses value roles to produce output vectors.
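The pairing can be sketched directly. `pair3` is a hypothetical helper that builds the induced morphism into a product from three morphisms out of a common domain. The toy projections mirror the hidden_0 example: identity for queries, swapped coordinates for keys, and scaling by 10 for values.

```rust
#[derive(Debug, PartialEq)]
struct Query(Vec<f32>);
#[derive(Debug, PartialEq)]
struct Key(Vec<f32>);
#[derive(Debug, PartialEq)]
struct Value(Vec<f32>);

/// Universal property of the product: three morphisms out of A
/// determine one morphism A -> B x C x D, often written <f, g, h>.
fn pair3<A, B, C, D>(
    f: impl Fn(&A) -> B,
    g: impl Fn(&A) -> C,
    h: impl Fn(&A) -> D,
) -> impl Fn(&A) -> (B, C, D) {
    move |a: &A| (f(a), g(a), h(a))
}

fn main() {
    // Toy projections of one hidden row into three roles.
    let to_qkv = pair3(
        |hidden: &Vec<f32>| Query(hidden.clone()),
        |hidden: &Vec<f32>| Key(vec![hidden[1], hidden[0]]),
        |hidden: &Vec<f32>| Value(hidden.iter().map(|&x| 10.0 * x).collect()),
    );
    let (query, key, value) = to_qkv(&vec![1.0, 2.0]);
    println!("{:?} {:?} {:?}", query, key, value);
    // Query([1.0, 2.0]) Key([2.0, 1.0]) Value([10.0, 20.0])
}
```

Projecting the tuple back onto any one component recovers the original arrow, which is why the product view and the three-arrow view carry the same information.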
Design contract:
HiddenSequence -> QuerySequence
HiddenSequence -> KeySequence
HiddenSequence -> ValueSequence
should be three explicit morphisms. A single untyped vector list would make it too easy to pass values into the wrong part of the attention computation.
Step 3: Scaled Dot-Product Attention
The future problem:
Convert query-key similarity into a probability distribution over positions, then use it to mix values.
Rust Syntax
A typed shape could be:
QuerySequence x KeySequence -> AttentionScores
AttentionScores x AttentionMask -> AttentionScores
AttentionScores -> AttentionWeights
AttentionWeights x ValueSequence -> AttentionOutput
AttentionHeadOutputs -> MultiHeadOutput
MultiHeadOutput -> ProjectedAttentionOutput
HiddenSequence x ProjectedAttentionOutput -> HiddenSequence
HiddenSequence -> HiddenSequence
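The multi-head concatenation step is where head_count * head_dimension appears. This hypothetical `concat_heads` works on plain nested vectors rather than the repository's AttentionHeadOutputs type. It checks that every head agrees on sequence length and head dimension, then joins the heads position by position. The resulting width is the input width the output projection must accept.

```rust
/// Concatenate per-head outputs position by position.
/// Every head must share the sequence length and head dimension, so each
/// output row has head_count * head_dimension entries.
fn concat_heads(heads: &[Vec<Vec<f32>>]) -> Result<Vec<Vec<f32>>, String> {
    let first = heads.first().ok_or("no attention heads")?;
    let seq_len = first.len();
    let head_dim = first.first().map_or(0, |row| row.len());
    for (index, head) in heads.iter().enumerate() {
        if head.len() != seq_len || head.iter().any(|row| row.len() != head_dim) {
            return Err(format!("head {index} is not {seq_len} x {head_dim}"));
        }
    }
    Ok((0..seq_len)
        .map(|position| {
            heads
                .iter()
                .flat_map(|head| head[position].iter().copied())
                .collect::<Vec<f32>>()
        })
        .collect())
}

fn main() {
    let heads = vec![
        vec![vec![1.0, 2.0], vec![3.0, 4.0]], // head 0: 2 positions x 2
        vec![vec![5.0, 6.0], vec![7.0, 8.0]], // head 1: 2 positions x 2
    ];
    let joined = concat_heads(&heads).unwrap();
    // 2 heads x head dimension 2 = width 4 for the output projection input.
    println!("{:?}", joined); // [[1.0, 2.0, 5.0, 6.0], [3.0, 4.0, 7.0, 8.0]]
}
```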
Read the current roadmap code through this shape trace:
flowchart LR
H["HiddenSequence"] --> Q["QuerySequence"]
H --> K["KeySequence"]
H --> V["ValueSequence"]
Q --> S["AttentionScores"]
K --> S
S --> M["Masked Scores"]
Mask["AttentionMask"] --> M
M --> W["AttentionWeights"]
W --> O["AttentionOutput"]
V --> O
O --> MH["MultiHeadOutput"]
MH --> P["ProjectedAttentionOutput"]
H --> R["Residual HiddenSequence"]
P --> R
R --> N["Normalized HiddenSequence"]
N --> FF["FeedForward HiddenSequence"]
The same attention core as a compact rendered math view:
\[
\begin{array}{rcl}
\mathrm{QuerySequence} \times \mathrm{KeySequence} & \to & \mathrm{AttentionScores} \\
\mathrm{AttentionScores} \times \mathrm{AttentionMask} & \to & \mathrm{MaskedScores} \\
\mathrm{MaskedScores} & \to & \mathrm{AttentionWeights} \\
\mathrm{AttentionWeights} \times \mathrm{ValueSequence} & \to & \mathrm{AttentionOutput} \\
\mathrm{AttentionOutput} & \to & \mathrm{ProjectedAttentionOutput} \\
\mathrm{HiddenSequence} \times \mathrm{ProjectedAttentionOutput} & \to & \mathrm{HiddenSequence}
\end{array}
\]
How to read this diagram:
- every product input means two roles must stay visible,
- masking happens before row-wise softmax produces weights,
- value mixing is separate from score calculation,
- the residual step is the first row here that explicitly returns to
HiddenSequence.
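The residual step is the point where the product input collapses back to HiddenSequence. A hypothetical standalone sketch of that boundary: elementwise addition that refuses mismatched sequence length or model dimension. Plain `zip` would silently truncate instead of failing.

```rust
/// HiddenSequence x ProjectedAttentionOutput -> HiddenSequence.
/// Mismatched shapes become errors rather than silent zip truncation.
fn residual_add(hidden: &[Vec<f32>], projected: &[Vec<f32>]) -> Result<Vec<Vec<f32>>, String> {
    if hidden.len() != projected.len() {
        return Err(format!("sequence length {} vs {}", hidden.len(), projected.len()));
    }
    hidden
        .iter()
        .zip(projected)
        .map(|(h, p)| {
            if h.len() != p.len() {
                return Err(format!("model dimension {} vs {}", h.len(), p.len()));
            }
            Ok(h.iter().zip(p).map(|(a, b)| a + b).collect::<Vec<f32>>())
        })
        .collect()
}

fn main() {
    let hidden = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
    let projected = vec![vec![0.5, 0.5], vec![1.0, 1.0]];
    println!("{:?}", residual_add(&hidden, &projected));
    // Ok([[1.5, 2.5], [4.0, 5.0]])
    println!("{:?}", residual_add(&hidden, &projected[..1]));
    // Err("sequence length 2 vs 1")
}
```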
What to notice:
Rust reading:
each box is a named type or a named typed boundary in src/attention.rs
ML reading:
scores choose positions, weights mix values, projection and residual return to
the hidden-state width
Category-theory reading:
the middle of attention is a composition with product inputs, and the enclosing
block keeps returning to HiddenSequence
AttentionWeights should be validated like Distribution, but over sequence
positions instead of vocabulary tokens.
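A hypothetical sketch of that validation: each row must be finite, non-negative, and sum to 1 within a tolerance. The tolerance value `1e-5` is an assumption of this sketch, and `AttentionWeightsSketch` is not the repository's AttentionWeights type.

```rust
/// Hypothetical validated weights: every row is a distribution over
/// source positions, as Distribution is over vocabulary tokens.
#[derive(Debug)]
struct AttentionWeightsSketch {
    rows: Vec<Vec<f32>>,
}

impl AttentionWeightsSketch {
    /// Allowed rounding error in a row sum (an assumed value).
    const TOLERANCE: f32 = 1e-5;

    fn new(rows: Vec<Vec<f32>>) -> Result<Self, String> {
        let width = rows.first().map_or(0, |row| row.len());
        if width == 0 {
            return Err("attention weights need at least one row and column".to_string());
        }
        for (i, row) in rows.iter().enumerate() {
            if row.len() != width {
                return Err(format!("row {i} has {} columns, expected {width}", row.len()));
            }
            if row.iter().any(|w| !w.is_finite() || *w < 0.0) {
                return Err(format!("row {i} has a negative or non-finite weight"));
            }
            let total: f32 = row.iter().sum();
            if (total - 1.0).abs() > Self::TOLERANCE {
                return Err(format!("row {i} sums to {total}, not 1"));
            }
        }
        Ok(Self { rows })
    }

    fn row(&self, query: usize) -> &[f32] {
        &self.rows[query]
    }
}

fn main() {
    let weights = AttentionWeightsSketch::new(vec![vec![0.5, 0.0, 0.5]]).unwrap();
    println!("query 0: {:?}", weights.row(0)); // query 0: [0.5, 0.0, 0.5]
    println!("{}", AttentionWeightsSketch::new(vec![vec![0.5, 0.5, 0.5]]).is_err()); // true
}
```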
The current roadmap code implements the query-key score boundary, the mask boundary, the score-to-weight boundary, the value-mixing boundary, the multi-head concatenation boundary, the output projection boundary, and the residual addition and normalization boundaries:
Source snapshot: src/attention.rs
//! Tiny typed attention boundary for the Transformer roadmap.
//!
//! This module does not implement a full Transformer. It makes the first small
//! attention-specific shapes explicit:
//!
//! - projected queries and keys become query-by-key scores,
//! - masks turn illegal score positions into negligible softmax inputs,
//! - query-by-key scores become row-wise attention probabilities,
//! - attention probabilities mix value vectors into output vectors,
//! - multiple head outputs concatenate into a multi-head output,
//! - the concatenated heads project back into a hidden sequence width,
//! - residual addition preserves the hidden sequence boundary,
//! - layer normalization preserves the hidden sequence boundary,
//! - a position-wise feed-forward map preserves the hidden sequence boundary,
//! - positional encoding adds position information while preserving shape,
//! - a single-head block sketch composes those boundaries end to end,
//! - a masked multi-head block accepts attention masks at the block boundary,
//! - a structured parameter object owns position, block, and readout pieces,
//! - a training-state object owns parameters, learning rate, and step count,
//! - a composed block train step updates readout, feed-forward,
//! normalization, and attention projection parameters from sequence targets.
use crate::category::{Morphism, StepCount};
use crate::domain::{
Distribution, LearningRate, Logits, Loss, ModelDimension, Product, TokenSequence, Vector,
VocabSize,
};
use crate::error::{CtError, CtResult};
use crate::ml::Softmax;
const MASKED_SCORE: f32 = -1_000_000.0;
/// Number of positions in a sequence.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SequenceLength(usize);
impl SequenceLength {
pub fn new(value: usize) -> CtResult<Self> {
if value == 0 {
return Err(CtError::EmptyInput("sequence length"));
}
Ok(Self(value))
}
pub fn value(&self) -> usize {
self.0
}
}
/// Number of parallel attention heads.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HeadCount(usize);
impl HeadCount {
pub fn new(value: usize) -> CtResult<Self> {
if value == 0 {
return Err(CtError::EmptyInput("head count"));
}
Ok(Self(value))
}
pub fn value(&self) -> usize {
self.0
}
}
/// Width of one attention head.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HeadDimension(usize);
impl HeadDimension {
pub fn new(value: usize) -> CtResult<Self> {
if value == 0 {
return Err(CtError::EmptyInput("head dimension"));
}
Ok(Self(value))
}
pub fn value(&self) -> usize {
self.0
}
}
/// Positive stabilizer used in layer normalization.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct NormalizationEpsilon(f32);
impl NormalizationEpsilon {
pub fn new(value: f32) -> CtResult<Self> {
if !value.is_finite() || value <= 0.0 {
return Err(CtError::ShapeMismatch {
op: "normalization epsilon",
expected: "positive finite epsilon".to_string(),
got: format!("epsilon {value}"),
});
}
Ok(Self(value))
}
pub fn value(&self) -> f32 {
self.0
}
}
/// Projected query vectors for one attention head.
#[derive(Debug, Clone, PartialEq)]
pub struct QuerySequence {
sequence_len: SequenceLength,
head_dimension: HeadDimension,
rows: Vec<Vector>,
}
impl QuerySequence {
pub fn new(rows: Vec<Vec<f32>>) -> CtResult<Self> {
let matrix = AttentionVectorRows::new("query sequence", rows)?;
Ok(Self {
sequence_len: matrix.sequence_len,
head_dimension: matrix.head_dimension,
rows: matrix.rows,
})
}
pub fn sequence_len(&self) -> SequenceLength {
self.sequence_len
}
pub fn head_dimension(&self) -> HeadDimension {
self.head_dimension
}
pub fn rows(&self) -> &[Vector] {
&self.rows
}
}
/// Projected key vectors for one attention head.
#[derive(Debug, Clone, PartialEq)]
pub struct KeySequence {
sequence_len: SequenceLength,
head_dimension: HeadDimension,
rows: Vec<Vector>,
}
impl KeySequence {
pub fn new(rows: Vec<Vec<f32>>) -> CtResult<Self> {
let matrix = AttentionVectorRows::new("key sequence", rows)?;
Ok(Self {
sequence_len: matrix.sequence_len,
head_dimension: matrix.head_dimension,
rows: matrix.rows,
})
}
pub fn sequence_len(&self) -> SequenceLength {
self.sequence_len
}
pub fn head_dimension(&self) -> HeadDimension {
self.head_dimension
}
pub fn rows(&self) -> &[Vector] {
&self.rows
}
}
/// Projected value vectors for one attention head.
#[derive(Debug, Clone, PartialEq)]
pub struct ValueSequence {
sequence_len: SequenceLength,
head_dimension: HeadDimension,
rows: Vec<Vector>,
}
impl ValueSequence {
pub fn new(rows: Vec<Vec<f32>>) -> CtResult<Self> {
let matrix = AttentionVectorRows::new("value sequence", rows)?;
Ok(Self {
sequence_len: matrix.sequence_len,
head_dimension: matrix.head_dimension,
rows: matrix.rows,
})
}
pub fn sequence_len(&self) -> SequenceLength {
self.sequence_len
}
pub fn head_dimension(&self) -> HeadDimension {
self.head_dimension
}
pub fn rows(&self) -> &[Vector] {
&self.rows
}
}
/// Hidden vectors over sequence positions.
#[derive(Debug, Clone, PartialEq)]
pub struct HiddenSequence {
sequence_len: SequenceLength,
model_dimension: ModelDimension,
rows: Vec<Vector>,
}
impl HiddenSequence {
pub fn new(rows: Vec<Vec<f32>>) -> CtResult<Self> {
let matrix = AttentionVectorRows::new("hidden sequence", rows)?;
Ok(Self {
sequence_len: matrix.sequence_len,
model_dimension: ModelDimension::new(matrix.head_dimension.value())?,
rows: matrix.rows,
})
}
pub fn sequence_len(&self) -> SequenceLength {
self.sequence_len
}
pub fn model_dimension(&self) -> ModelDimension {
self.model_dimension
}
pub fn rows(&self) -> &[Vector] {
&self.rows
}
}
/// A finite table of position vectors added to hidden states.
#[derive(Debug, Clone, PartialEq)]
pub struct PositionalEncoding {
    max_sequence_len: SequenceLength,
    model_dimension: ModelDimension,
    rows: Vec<Vector>,
}
impl PositionalEncoding {
    pub fn new(rows: Vec<Vec<f32>>) -> CtResult<Self> {
        let matrix = AttentionVectorRows::new("positional encoding", rows)?;
        Ok(Self {
            max_sequence_len: matrix.sequence_len,
            model_dimension: ModelDimension::new(matrix.head_dimension.value())?,
            rows: matrix.rows,
        })
    }
    pub fn max_sequence_len(&self) -> SequenceLength {
        self.max_sequence_len
    }
    pub fn model_dimension(&self) -> ModelDimension {
        self.model_dimension
    }
}
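// Sketch, not part of the listed crate: one common way to combine a
// positional-encoding table with hidden states. It assumes the `apply` step
// (defined elsewhere) adds row `i` of the table to hidden row `i` and rejects
// sequences longer than `max_sequence_len`. The module and function names are
// hypothetical.
mod positional_encoding_sketch {
    /// Adds `table[i]` to `hidden[i]` element-wise for every position `i`.
    pub fn add_positions(
        hidden: &[Vec<f32>],
        table: &[Vec<f32>],
    ) -> Result<Vec<Vec<f32>>, String> {
        if hidden.len() > table.len() {
            return Err(format!(
                "sequence length {} exceeds max_sequence_len {}",
                hidden.len(),
                table.len()
            ));
        }
        hidden
            .iter()
            .zip(table)
            .map(|(row, position)| {
                if row.len() != position.len() {
                    return Err(format!(
                        "model dimension {} does not match encoding width {}",
                        row.len(),
                        position.len()
                    ));
                }
                Ok(row.iter().zip(position).map(|(h, p)| h + p).collect::<Vec<f32>>())
            })
            .collect()
    }
}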
/// Shared validator behind the row-matrix sequence types. It rejects empty
/// input, ragged rows, and non-finite values. `head_dimension` is simply the
/// common row width. Callers such as `HiddenSequence` reinterpret it as a
/// model dimension.
#[derive(Debug, Clone, PartialEq)]
struct AttentionVectorRows {
    sequence_len: SequenceLength,
    head_dimension: HeadDimension,
    rows: Vec<Vector>,
}
impl AttentionVectorRows {
    fn new(kind: &'static str, rows: Vec<Vec<f32>>) -> CtResult<Self> {
        if rows.is_empty() {
            return Err(CtError::EmptyInput(kind));
        }
        let head_dimension = rows[0].len();
        if head_dimension == 0 {
            return Err(CtError::EmptyInput("attention vector row"));
        }
        for row in &rows {
            if row.len() != head_dimension {
                return Err(CtError::ShapeMismatch {
                    op: kind,
                    expected: format!("all rows have {head_dimension} columns"),
                    got: format!("row with {} columns", row.len()),
                });
            }
            if row.iter().any(|value| !value.is_finite()) {
                return Err(CtError::ShapeMismatch {
                    op: kind,
                    expected: "all vector values are finite".to_string(),
                    got: "non-finite vector value".to_string(),
                });
            }
        }
        Ok(Self {
            sequence_len: SequenceLength::new(rows.len())?,
            head_dimension: HeadDimension::new(head_dimension)?,
            rows: rows.into_iter().map(Vector::new).collect(),
        })
    }
}
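// Sketch, not part of the listed crate: the three checks that
// `AttentionVectorRows::new` performs, written against plain vectors so they
// can run in isolation. `check_rows` is a hypothetical name. It returns
// `(row_count, row_width)` in place of the crate's `SequenceLength` and
// `HeadDimension` newtypes, and `String` errors in place of `CtError`.
mod attention_rows_sketch {
    pub fn check_rows(rows: &[Vec<f32>]) -> Result<(usize, usize), String> {
        // 1. At least one row, and the first row has at least one column.
        let width = rows.first().map(Vec::len).ok_or("no rows")?;
        if width == 0 {
            return Err("empty row".to_string());
        }
        for row in rows {
            // 2. Every row has the same width, so the matrix is rectangular.
            if row.len() != width {
                return Err(format!("expected {width} columns, got {}", row.len()));
            }
            // 3. No NaN or infinite entries.
            if row.iter().any(|value| !value.is_finite()) {
                return Err("non-finite value".to_string());
            }
        }
        Ok((rows.len(), width))
    }
}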
/// Query-by-key scores before row-wise softmax.
#[derive(Debug, Clone, PartialEq)]
pub struct AttentionScores {
    query_len: SequenceLength,
    key_len: SequenceLength,
    rows: Vec<Logits>,
}
impl AttentionScores {
    pub fn new(rows: Vec<Vec<f32>>) -> CtResult<Self> {
        if rows.is_empty() {
            return Err(CtError::EmptyInput("attention scores"));
        }
        let key_len = rows[0].len();
        if key_len == 0 {
            return Err(CtError::EmptyInput("attention score row"));
        }
        for row in &rows {
            if row.len() != key_len {
                return Err(CtError::ShapeMismatch {
                    op: "attention scores",
                    expected: format!("all rows have {key_len} columns"),
                    got: format!("row with {} columns", row.len()),
                });
            }
            if row.iter().any(|value| !value.is_finite()) {
                return Err(CtError::ShapeMismatch {
                    op: "attention scores",
                    expected: "all score values are finite".to_string(),
                    got: "non-finite score value".to_string(),
                });
            }
        }
        Ok(Self {
            query_len: SequenceLength::new(rows.len())?,
            key_len: SequenceLength::new(key_len)?,
            rows: rows.into_iter().map(Logits::new).collect(),
        })
    }
    pub fn query_len(&self) -> SequenceLength {
        self.query_len
    }
    pub fn key_len(&self) -> SequenceLength {
        self.key_len
    }
    pub fn rows(&self) -> &[Logits] {
        &self.rows
    }
}
/// Allowed query-by-key positions before attention softmax.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AttentionMask {
    query_len: SequenceLength,
    key_len: SequenceLength,
    rows: Vec<Vec<bool>>,
}
impl AttentionMask {
    pub fn new(rows: Vec<Vec<bool>>) -> CtResult<Self> {
        if rows.is_empty() {
            return Err(CtError::EmptyInput("attention mask"));
        }
        let key_len = rows[0].len();
        if key_len == 0 {
            return Err(CtError::EmptyInput("attention mask row"));
        }
        for row in &rows {
            if row.len() != key_len {
                return Err(CtError::ShapeMismatch {
                    op: "attention mask",
                    expected: format!("all rows have {key_len} columns"),
                    got: format!("row with {} columns", row.len()),
                });
            }
            // Reject a row that allows no key. After masking, that query would
            // have no permitted position to attend to.
            if !row.iter().any(|allowed| *allowed) {
                return Err(CtError::EmptyInput("attention mask row allows no keys"));
            }
        }
        Ok(Self {
            query_len: SequenceLength::new(rows.len())?,
            key_len: SequenceLength::new(key_len)?,
            rows,
        })
    }
    pub fn query_len(&self) -> SequenceLength {
        self.query_len
    }
    pub fn key_len(&self) -> SequenceLength {
        self.key_len
    }
    pub fn rows(&self) -> &[Vec<bool>] {
        &self.rows
    }
}
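// Sketch, not part of the listed crate: a causal (decoder-style) mask in the
// row layout `AttentionMask::new` expects, with one row per query and one
// column per key. Query `i` may attend to keys `0..=i`. Every row therefore
// allows at least its own position and passes the "row allows no keys" check.
// `causal_mask` is a hypothetical helper name.
mod causal_mask_sketch {
    pub fn causal_mask(sequence_len: usize) -> Vec<Vec<bool>> {
        (0..sequence_len)
            .map(|query| (0..sequence_len).map(|key| key <= query).collect::<Vec<bool>>())
            .collect()
    }
}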
/// Computes scaled query-key dot-product scores.
#[derive(Debug, Clone)]
pub struct ScaledDotProductScores;
impl Morphism<Product<QuerySequence, KeySequence>, AttentionScores> for ScaledDotProductScores {
    fn name(&self) -> &'static str {
        "scaled_dot_product_scores"
    }
    fn apply(&self, input: Product<QuerySequence, KeySequence>) -> CtResult<AttentionScores> {
        let (queries, keys) = input.into_parts();
        let query_dimension = queries.head_dimension();
        let key_dimension = keys.head_dimension();
        if query_dimension != key_dimension {
            return Err(CtError::ShapeMismatch {
                op: "scaled dot-product attention scores",
                expected: format!("query head dimension {}", query_dimension.value()),
                got: format!("key head dimension {}", key_dimension.value()),
            });
        }
        // Dividing by sqrt(d) keeps dot products from growing with the head
        // dimension, so the softmax that follows saturates less easily.
        let scale = (query_dimension.value() as f32).sqrt();
        let rows = queries
            .rows()
            .iter()
            .map(|query| {
                keys.rows()
                    .iter()
                    .map(|key| dot_product(query.as_slice(), key.as_slice()) / scale)
                    .collect::<Vec<_>>()
            })
            .collect::<Vec<_>>();
        AttentionScores::new(rows)
    }
}
// Callers check dimensions first. `zip` stops at the shorter slice, so a
// length mismatch would otherwise be truncated silently.
fn dot_product(left: &[f32], right: &[f32]) -> f32 {
    left.iter()
        .zip(right.iter())
        .map(|(left, right)| left * right)
        .sum()
}
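// Sketch, not part of the listed crate: the arithmetic behind
// `ScaledDotProductScores`, score[i][j] = (q_i . k_j) / sqrt(d), where d is
// the shared query/key head dimension. Plain nested vectors stand in for
// `QuerySequence`, `KeySequence`, and `AttentionScores`, and the dimension
// check the real morphism performs is omitted.
mod scaled_scores_sketch {
    pub fn scaled_scores(queries: &[Vec<f32>], keys: &[Vec<f32>]) -> Vec<Vec<f32>> {
        // Fall back to 1 for an empty query list so the scale is never zero.
        let dimension = queries.first().map_or(1, Vec::len);
        let scale = (dimension as f32).sqrt();
        queries
            .iter()
            .map(|query| {
                keys.iter()
                    .map(|key| query.iter().zip(key).map(|(q, k)| q * k).sum::<f32>() / scale)
                    .collect::<Vec<f32>>()
            })
            .collect()
    }
}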
/// Applies a boolean attention mask to score rows before softmax.
#[derive(Debug, Clone)]
pub struct MaskedAttentionScores;
impl Morphism<Product<AttentionScores, AttentionMask>, AttentionScores> for MaskedAttentionScores {
    fn name(&self) -> &'static str {
        "masked_attention_scores"
    }
    fn apply(&self, input: Product<AttentionScores, AttentionMask>) -> CtResult<AttentionScores> {
        let (scores, mask) = input.into_parts();
        if scores.query_len() != mask.query_len() || scores.key_len() != mask.key_len() {
            return Err(CtError::ShapeMismatch {
                op: "masked attention scores",
                expected: format!(
                    "{} query rows x {} key columns",
                    scores.query_len().value(),
                    scores.key_len().value()
                ),
                got: format!(
                    "{} query rows x {} key columns",
                    mask.query_len().value(),
                    mask.key_len().value()
                ),
            });
        }
        // `MASKED_SCORE` is defined elsewhere in the crate. `AttentionScores::new`
        // rejects non-finite values, so it must be a large finite negative
        // number rather than `f32::NEG_INFINITY`.
        let rows = scores
            .rows()
            .iter()
            .zip(mask.rows())
            .map(|(score_row, mask_row)| {
                score_row
                    .as_slice()
                    .iter()
                    .zip(mask_row)
                    .map(|(score, allowed)| if *allowed { *score } else { MASKED_SCORE })
                    .collect::<Vec<_>>()
            })
            .collect::<Vec<_>>();
        AttentionScores::new(rows)
    }
}
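// Sketch, not part of the listed crate: why masking happens before the
// row-wise softmax. `MASKED_SCORE_SKETCH` is a hypothetical stand-in for the
// `MASKED_SCORE` constant used above. Its value of -1e9 is an assumption,
// chosen only to be large, negative, and finite. After the max-subtracted
// softmax, masked keys receive a weight of (effectively) zero.
mod masked_softmax_sketch {
    pub const MASKED_SCORE_SKETCH: f32 = -1.0e9;

    /// Replaces disallowed scores, then applies a numerically stable softmax.
    /// Assumes at least one position is allowed, as `AttentionMask` enforces.
    pub fn masked_softmax(scores: &[f32], allowed: &[bool]) -> Vec<f32> {
        let masked: Vec<f32> = scores
            .iter()
            .zip(allowed)
            .map(|(score, ok)| if *ok { *score } else { MASKED_SCORE_SKETCH })
            .collect();
        // Subtracting the row maximum keeps `exp` from overflowing.
        let max = masked.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = masked.iter().map(|value| (value - max).exp()).collect();
        let total: f32 = exps.iter().sum();
        exps.iter().map(|value| value / total).collect()
    }
}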
/// Row-wise attention probabilities over key positions.
#[derive(Debug, Clone, PartialEq)]
pub struct AttentionWeights {
    query_len: SequenceLength,
    key_len: SequenceLength,
    rows: Vec<Distribution>,
}
impl AttentionWeights {
    pub fn new(rows: Vec<Distribution>) -> CtResult<Self> {
        if rows.is_empty() {
            return Err(CtError::EmptyInput("attention weights"));
        }
        let key_len = rows[0].as_slice().len();
        if key_len == 0 {
            return Err(CtError::EmptyInput("attention weight row"));
        }
        for row in &rows {
            if row.as_slice().len() != key_len {
                return Err(CtError::ShapeMismatch {
                    op: "attention weights",
                    expected: format!("all rows have {key_len} columns"),
                    got: format!("row with {} columns", row.as_slice().len()),
                });
            }
        }
        Ok(Self {
            query_len: SequenceLength::new(rows.len())?,
            key_len: SequenceLength::new(key_len)?,
            rows,
        })
    }
    pub fn query_len(&self) -> SequenceLength {
        self.query_len
    }
    pub fn key_len(&self) -> SequenceLength {
        self.key_len
    }
    pub fn rows(&self) -> &[Distribution] {
        &self.rows
    }
}
/// Weighted value vectors, one output row per query position.
#[derive(Debug, Clone, PartialEq)]
pub struct AttentionOutput {
    sequence_len: SequenceLength,
    head_dimension: HeadDimension,
    rows: Vec<Vector>,
}
impl AttentionOutput {
    pub fn new(rows: Vec<Vec<f32>>) -> CtResult<Self> {
        let matrix = AttentionVectorRows::new("attention output", rows)?;
        Ok(Self {
            sequence_len: matrix.sequence_len,
            head_dimension: matrix.head_dimension,
            rows: matrix.rows,
        })
    }
    pub fn sequence_len(&self) -> SequenceLength {
        self.sequence_len
    }
    pub fn head_dimension(&self) -> HeadDimension {
        self.head_dimension
    }
    pub fn rows(&self) -> &[Vector] {
        &self.rows
    }
}
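// Sketch, not part of the listed crate: the step that turns attention weights
// and value vectors into an attention output, output[i] = sum_j weights[i][j]
// * values[j]. There is one output row per query position, and each row has
// the value head dimension. The function name is hypothetical.
mod weighted_values_sketch {
    pub fn weighted_values(weights: &[Vec<f32>], values: &[Vec<f32>]) -> Vec<Vec<f32>> {
        let width = values.first().map_or(0, Vec::len);
        weights
            .iter()
            .map(|weight_row| {
                let mut output = vec![0.0; width];
                for (weight, value_row) in weight_row.iter().zip(values) {
                    for (slot, value) in output.iter_mut().zip(value_row) {
                        *slot += weight * value;
                    }
                }
                output
            })
            .collect()
    }
}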
/// Validated collection of single-head attention outputs.
#[derive(Debug, Clone, PartialEq)]
pub struct AttentionHeadOutputs {
    head_count: HeadCount,
    sequence_len: SequenceLength,
    head_dimension: HeadDimension,
    heads: Vec<AttentionOutput>,
}
impl AttentionHeadOutputs {
    pub fn new(heads: Vec<AttentionOutput>) -> CtResult<Self> {
        if heads.is_empty() {
            return Err(CtError::EmptyInput("attention head outputs"));
        }
        let sequence_len = heads[0].sequence_len();
        let head_dimension = heads[0].head_dimension();
        for head in &heads {
            if head.sequence_len() != sequence_len {
                return Err(CtError::ShapeMismatch {
                    op: "attention head outputs",
                    expected: format!("all heads have {} sequence rows", sequence_len.value()),
                    got: format!("head with {} sequence rows", head.sequence_len().value()),
                });
            }
            if head.head_dimension() != head_dimension {
                return Err(CtError::ShapeMismatch {
                    op: "attention head outputs",
                    expected: format!("all heads have dimension {}", head_dimension.value()),
                    got: format!("head dimension {}", head.head_dimension().value()),
                });
            }
        }
        Ok(Self {
            head_count: HeadCount::new(heads.len())?,
            sequence_len,
            head_dimension,
            heads,
        })
    }
    pub fn head_count(&self) -> HeadCount {
        self.head_count
    }
    pub fn sequence_len(&self) -> SequenceLength {
        self.sequence_len
    }
    pub fn head_dimension(&self) -> HeadDimension {
        self.head_dimension
    }
    pub fn heads(&self) -> &[AttentionOutput] {
        &self.heads
    }
}
/// Concatenated output of several attention heads.
#[derive(Debug, Clone, PartialEq)]
pub struct MultiHeadOutput {
    sequence_len: SequenceLength,
    head_count: HeadCount,
    head_dimension: HeadDimension,
    model_dimension: ModelDimension,
    rows: Vec<Vector>,
}
impl MultiHeadOutput {
    fn new(
        rows: Vec<Vec<f32>>,
        head_count: HeadCount,
        head_dimension: HeadDimension,
    ) -> CtResult<Self> {
        let matrix = AttentionVectorRows::new("multi-head output", rows)?;
        let expected_dimension = head_count.value() * head_dimension.value();
        if matrix.head_dimension.value() != expected_dimension {
            return Err(CtError::ShapeMismatch {
                op: "multi-head output",
                expected: format!("row dimension {expected_dimension}"),
                got: format!("row dimension {}", matrix.head_dimension.value()),
            });
        }
        Ok(Self {
            sequence_len: matrix.sequence_len,
            head_count,
            head_dimension,
            model_dimension: ModelDimension::new(expected_dimension)?,
            rows: matrix.rows,
        })
    }
    pub fn sequence_len(&self) -> SequenceLength {
        self.sequence_len
    }
    pub fn head_count(&self) -> HeadCount {
        self.head_count
    }
    pub fn head_dimension(&self) -> HeadDimension {
        self.head_dimension
    }
    pub fn model_dimension(&self) -> ModelDimension {
        self.model_dimension
    }
    pub fn rows(&self) -> &[Vector] {
        &self.rows
    }
}
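// Sketch, not part of the listed crate: head concatenation with the shape
// that `MultiHeadOutput::new` validates. It assumes heads are laid side by
// side in head order. Row i of the result is therefore row i of head 0, then
// row i of head 1, and so on, and has width head_count * head_dimension. It
// also assumes all heads share a sequence length, as `AttentionHeadOutputs`
// enforces. The function name is hypothetical.
mod concat_heads_sketch {
    pub fn concatenate_heads(heads: &[Vec<Vec<f32>>]) -> Vec<Vec<f32>> {
        let sequence_len = heads.first().map_or(0, Vec::len);
        (0..sequence_len)
            .map(|position| {
                heads
                    .iter()
                    .flat_map(|head| head[position].iter().copied())
                    .collect::<Vec<f32>>()
            })
            .collect()
    }
}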
/// Output sequence after the multi-head output projection.
#[derive(Debug, Clone, PartialEq)]
pub struct ProjectedAttentionOutput {
    sequence_len: SequenceLength,
    model_dimension: ModelDimension,
    rows: Vec<Vector>,
}
impl ProjectedAttentionOutput {
    pub fn new(rows: Vec<Vec<f32>>) -> CtResult<Self> {
        let matrix = AttentionVectorRows::new("projected attention output", rows)?;
        Ok(Self {
            sequence_len: matrix.sequence_len,
            model_dimension: ModelDimension::new(matrix.head_dimension.value())?,
            rows: matrix.rows,
        })
    }
    pub fn sequence_len(&self) -> SequenceLength {
        self.sequence_len
    }
    pub fn model_dimension(&self) -> ModelDimension {
        self.model_dimension
    }
    pub fn rows(&self) -> &[Vector] {
        &self.rows
    }
}
/// Vocabulary logits for every position in a hidden sequence.
#[derive(Debug, Clone, PartialEq)]
pub struct SequenceLogits {
    sequence_len: SequenceLength,
    vocab_size: VocabSize,
    rows: Vec<Logits>,
}
impl SequenceLogits {
    pub fn new(rows: Vec<Vec<f32>>) -> CtResult<Self> {
        if rows.is_empty() {
            return Err(CtError::EmptyInput("sequence logits"));
        }
        let vocab_size = rows[0].len();
        if vocab_size == 0 {
            return Err(CtError::EmptyInput("sequence logits row"));
        }
        for row in &rows {
            if row.len() != vocab_size {
                return Err(CtError::ShapeMismatch {
                    op: "sequence logits",
                    expected: format!("all rows have {vocab_size} columns"),
                    got: format!("row with {} columns", row.len()),
                });
            }
            if row.iter().any(|value| !value.is_finite()) {
                return Err(CtError::ShapeMismatch {
                    op: "sequence logits",
                    expected: "all logit values are finite".to_string(),
                    got: "non-finite logit value".to_string(),
                });
            }
        }
        Ok(Self {
            sequence_len: SequenceLength::new(rows.len())?,
            vocab_size: VocabSize::new(vocab_size)?,
            rows: rows.into_iter().map(Logits::new).collect(),
        })
    }
    pub fn sequence_len(&self) -> SequenceLength {
        self.sequence_len
    }
    pub fn vocab_size(&self) -> VocabSize {
        self.vocab_size
    }
    pub fn rows(&self) -> &[Logits] {
        &self.rows
    }
}
/// Learned language-model readout applied to each hidden position.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerReadout {
    input_dimension: ModelDimension,
    vocab_size: VocabSize,
    weight: Vec<Vec<f32>>,
    bias: Vec<f32>,
}
impl TransformerReadout {
    pub fn new(weight: Vec<Vec<f32>>, bias: Vec<f32>) -> CtResult<Self> {
        let (input_dimension, output_dimension) =
            validate_linear_parts("transformer readout", &weight, &bias)?;
        Ok(Self {
            input_dimension,
            vocab_size: VocabSize::new(output_dimension.value())?,
            weight,
            bias,
        })
    }
    pub fn input_dimension(&self) -> ModelDimension {
        self.input_dimension
    }
    pub fn vocab_size(&self) -> VocabSize {
        self.vocab_size
    }
    pub fn weight(&self) -> &[Vec<f32>] {
        &self.weight
    }
    pub fn bias(&self) -> &[f32] {
        &self.bias
    }
}
/// Learned output projection after head concatenation.
#[derive(Debug, Clone, PartialEq)]
pub struct AttentionOutputProjection {
    input_dimension: ModelDimension,
    output_dimension: ModelDimension,
    weight: Vec<Vec<f32>>,
    bias: Vec<f32>,
}
impl AttentionOutputProjection {
    pub fn new(weight: Vec<Vec<f32>>, bias: Vec<f32>) -> CtResult<Self> {
        if weight.is_empty() {
            return Err(CtError::EmptyInput("attention output projection weight"));
        }
        if bias.is_empty() {
            return Err(CtError::EmptyInput("attention output projection bias"));
        }
        let output_dimension = bias.len();
        if bias.iter().any(|value| !value.is_finite()) {
            return Err(CtError::ShapeMismatch {
                op: "attention output projection",
                expected: "finite bias values".to_string(),
                got: "non-finite bias value".to_string(),
            });
        }
        for row in &weight {
            if row.len() != output_dimension {
                return Err(CtError::ShapeMismatch {
                    op: "attention output projection",
                    expected: format!("weight rows have {output_dimension} columns"),
                    got: format!("weight row with {} columns", row.len()),
                });
            }
            if row.iter().any(|value| !value.is_finite()) {
                return Err(CtError::ShapeMismatch {
                    op: "attention output projection",
                    expected: "finite weight values".to_string(),
                    got: "non-finite weight value".to_string(),
                });
            }
        }
        Ok(Self {
            input_dimension: ModelDimension::new(weight.len())?,
            output_dimension: ModelDimension::new(output_dimension)?,
            weight,
            bias,
        })
    }
    pub fn input_dimension(&self) -> ModelDimension {
        self.input_dimension
    }
    pub fn output_dimension(&self) -> ModelDimension {
        self.output_dimension
    }
    pub fn weight(&self) -> &[Vec<f32>] {
        &self.weight
    }
    pub fn bias(&self) -> &[f32] {
        &self.bias
    }
}
/// Scale, shift, and epsilon parameters for layer normalization.
#[derive(Debug, Clone, PartialEq)]
pub struct LayerNormParameters {
    model_dimension: ModelDimension,
    scale: Vec<f32>,
    shift: Vec<f32>,
    epsilon: NormalizationEpsilon,
}
impl LayerNormParameters {
    pub fn new(scale: Vec<f32>, shift: Vec<f32>, epsilon: NormalizationEpsilon) -> CtResult<Self> {
        if scale.is_empty() {
            return Err(CtError::EmptyInput("layer norm scale"));
        }
        if shift.is_empty() {
            return Err(CtError::EmptyInput("layer norm shift"));
        }
        if scale.len() != shift.len() {
            return Err(CtError::ShapeMismatch {
                op: "layer norm parameters",
                expected: format!("scale and shift length {}", scale.len()),
                got: format!("shift length {}", shift.len()),
            });
        }
        if scale.iter().any(|value| !value.is_finite()) {
            return Err(CtError::ShapeMismatch {
                op: "layer norm parameters",
                expected: "finite scale values".to_string(),
                got: "non-finite scale value".to_string(),
            });
        }
        if shift.iter().any(|value| !value.is_finite()) {
            return Err(CtError::ShapeMismatch {
                op: "layer norm parameters",
                expected: "finite shift values".to_string(),
                got: "non-finite shift value".to_string(),
            });
        }
        Ok(Self {
            model_dimension: ModelDimension::new(scale.len())?,
            scale,
            shift,
            epsilon,
        })
    }
    pub fn identity(model_dimension: ModelDimension) -> Self {
        Self {
            model_dimension,
            scale: vec![1.0; model_dimension.value()],
            shift: vec![0.0; model_dimension.value()],
            epsilon: NormalizationEpsilon(1e-5),
        }
    }
    pub fn model_dimension(&self) -> ModelDimension {
        self.model_dimension
    }
    pub fn scale(&self) -> &[f32] {
        &self.scale
    }
    pub fn shift(&self) -> &[f32] {
        &self.shift
    }
    pub fn epsilon(&self) -> NormalizationEpsilon {
        self.epsilon
    }
}
/// Layer normalization over each hidden vector independently.
#[derive(Debug, Clone, PartialEq)]
pub struct LayerNormalization {
    parameters: LayerNormParameters,
}
impl LayerNormalization {
    pub fn new(parameters: LayerNormParameters) -> Self {
        Self { parameters }
    }
    pub fn model_dimension(&self) -> ModelDimension {
        self.parameters.model_dimension()
    }
    pub fn parameters(&self) -> &LayerNormParameters {
        &self.parameters
    }
}
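// Sketch, not part of the listed crate: textbook layer normalization of one
// hidden row, y_k = scale_k * (x_k - mean) / sqrt(variance + epsilon) + shift_k,
// using the population (biased) variance. The crate's `apply` for
// `LayerNormalization` is not shown here and may differ in detail. The
// function name is hypothetical.
mod layer_norm_sketch {
    pub fn layer_norm_row(row: &[f32], scale: &[f32], shift: &[f32], epsilon: f32) -> Vec<f32> {
        let n = row.len() as f32;
        let mean = row.iter().sum::<f32>() / n;
        let variance = row.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / n;
        let denominator = (variance + epsilon).sqrt();
        row.iter()
            .zip(scale.iter().zip(shift))
            .map(|(x, (gamma, beta))| gamma * (x - mean) / denominator + beta)
            .collect()
    }
}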
/// Intermediate values for one row of the feed-forward sublayer, kept from
/// the forward pass for use during training.
#[derive(Debug, Clone, PartialEq)]
struct FeedForwardRowCache {
    input: Vec<f32>,
    pre_activation: Vec<f32>,
    activation: Vec<f32>,
    output: Vec<f32>,
}
/// Per-head intermediates (projections, weights, output) recorded for training.
#[derive(Debug, Clone, PartialEq)]
struct AttentionHeadTrainingCache {
    queries: QuerySequence,
    keys: KeySequence,
    values: ValueSequence,
    weights: AttentionWeights,
    output: AttentionOutput,
}
/// Everything `apply_with_training_cache` records for one masked block pass.
#[derive(Debug, Clone, PartialEq)]
struct MaskedBlockTrainingCache {
    output: HiddenSequence,
    with_feed_forward: HiddenSequence,
    with_attention: HiddenSequence,
    multi_head_output: MultiHeadOutput,
    attention_heads: Vec<AttentionHeadTrainingCache>,
    feed_forward_rows: Vec<FeedForwardRowCache>,
}
/// Position-wise two-layer feed-forward sublayer.
#[derive(Debug, Clone, PartialEq)]
pub struct PositionWiseFeedForward {
    input_dimension: ModelDimension,
    hidden_dimension: ModelDimension,
    output_dimension: ModelDimension,
    first_weight: Vec<Vec<f32>>,
    first_bias: Vec<f32>,
    second_weight: Vec<Vec<f32>>,
    second_bias: Vec<f32>,
}
impl PositionWiseFeedForward {
    pub fn new(
        first_weight: Vec<Vec<f32>>,
        first_bias: Vec<f32>,
        second_weight: Vec<Vec<f32>>,
        second_bias: Vec<f32>,
    ) -> CtResult<Self> {
        let (input_dimension, hidden_dimension) = validate_linear_parts(
            "position-wise feed-forward first layer",
            &first_weight,
            &first_bias,
        )?;
        let (second_input_dimension, output_dimension) = validate_linear_parts(
            "position-wise feed-forward second layer",
            &second_weight,
            &second_bias,
        )?;
        if second_input_dimension != hidden_dimension {
            return Err(CtError::ShapeMismatch {
                op: "position-wise feed-forward",
                expected: format!("second input dimension {}", hidden_dimension.value()),
                got: format!("second input dimension {}", second_input_dimension.value()),
            });
        }
        if output_dimension != input_dimension {
            return Err(CtError::ShapeMismatch {
                op: "position-wise feed-forward",
                expected: format!("output dimension {}", input_dimension.value()),
                got: format!("output dimension {}", output_dimension.value()),
            });
        }
        Ok(Self {
            input_dimension,
            hidden_dimension,
            output_dimension,
            first_weight,
            first_bias,
            second_weight,
            second_bias,
        })
    }
    pub fn input_dimension(&self) -> ModelDimension {
        self.input_dimension
    }
    pub fn hidden_dimension(&self) -> ModelDimension {
        self.hidden_dimension
    }
    pub fn output_dimension(&self) -> ModelDimension {
        self.output_dimension
    }
    pub fn first_weight(&self) -> &[Vec<f32>] {
        &self.first_weight
    }
    pub fn first_bias(&self) -> &[f32] {
        &self.first_bias
    }
    pub fn second_weight(&self) -> &[Vec<f32>] {
        &self.second_weight
    }
    pub fn second_bias(&self) -> &[f32] {
        &self.second_bias
    }
}
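// Sketch, not part of the listed crate: the position-wise feed-forward map on
// a single row, FFN(x) = act(x W1 + b1) W2 + b2. This sketch makes two
// assumptions. First, weights are stored input-major, with one row per input
// feature and one column per output feature, matching the checks in
// `AttentionOutputProjection::new`. Second, the activation is ReLU. The
// crate's actual activation is not shown in this listing. Function names are
// hypothetical.
mod feed_forward_sketch {
    /// y_j = b_j + sum_i x_i * W[i][j]
    pub fn linear(x: &[f32], weight: &[Vec<f32>], bias: &[f32]) -> Vec<f32> {
        let mut y = bias.to_vec();
        for (x_i, weight_row) in x.iter().zip(weight) {
            for (y_j, w_ij) in y.iter_mut().zip(weight_row) {
                *y_j += x_i * w_ij;
            }
        }
        y
    }

    pub fn feed_forward(
        x: &[f32],
        first_weight: &[Vec<f32>],
        first_bias: &[f32],
        second_weight: &[Vec<f32>],
        second_bias: &[f32],
    ) -> Vec<f32> {
        // These two intermediates mirror `FeedForwardRowCache`'s
        // `pre_activation` and `activation` fields.
        let pre_activation = linear(x, first_weight, first_bias);
        let activation: Vec<f32> = pre_activation.iter().map(|v| v.max(0.0)).collect();
        linear(&activation, second_weight, second_bias)
    }
}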
#[derive(Debug, Clone, PartialEq)]
struct HiddenProjection {
    op: &'static str,
    input_dimension: ModelDimension,
    head_dimension: HeadDimension,
    weight: Vec<Vec<f32>>,
    bias: Vec<f32>,
}
impl HiddenProjection {
    fn new(op: &'static str, weight: Vec<Vec<f32>>, bias: Vec<f32>) -> CtResult<Self> {
        let (input_dimension, output_dimension) = validate_linear_parts(op, &weight, &bias)?;
        Ok(Self {
            op,
            input_dimension,
            head_dimension: HeadDimension::new(output_dimension.value())?,
            weight,
            bias,
        })
    }
    fn project(&self, input: &HiddenSequence) -> CtResult<Vec<Vec<f32>>> {
        if input.model_dimension() != self.input_dimension {
            return Err(CtError::ShapeMismatch {
                op: self.op,
                expected: format!("input dimension {}", self.input_dimension.value()),
                got: format!("input dimension {}", input.model_dimension().value()),
            });
        }
        Ok(input
            .rows()
            .iter()
            .map(|row| project_row(row.as_slice(), &self.weight, &self.bias))
            .collect::<Vec<_>>())
    }
    fn input_dimension(&self) -> ModelDimension {
        self.input_dimension
    }
    fn head_dimension(&self) -> HeadDimension {
        self.head_dimension
    }
    fn weight(&self) -> &[Vec<f32>] {
        &self.weight
    }
    fn bias(&self) -> &[f32] {
        &self.bias
    }
}
/// Learned projection from hidden states to query vectors.
#[derive(Debug, Clone, PartialEq)]
pub struct HiddenToQuery {
    projection: HiddenProjection,
}
impl HiddenToQuery {
    pub fn new(weight: Vec<Vec<f32>>, bias: Vec<f32>) -> CtResult<Self> {
        Ok(Self {
            projection: HiddenProjection::new("hidden-to-query projection", weight, bias)?,
        })
    }
    pub fn input_dimension(&self) -> ModelDimension {
        self.projection.input_dimension()
    }
    pub fn head_dimension(&self) -> HeadDimension {
        self.projection.head_dimension()
    }
    pub fn weight(&self) -> &[Vec<f32>] {
        self.projection.weight()
    }
    pub fn bias(&self) -> &[f32] {
        self.projection.bias()
    }
}
/// Learned projection from hidden states to key vectors.
#[derive(Debug, Clone, PartialEq)]
pub struct HiddenToKey {
    projection: HiddenProjection,
}
impl HiddenToKey {
    pub fn new(weight: Vec<Vec<f32>>, bias: Vec<f32>) -> CtResult<Self> {
        Ok(Self {
            projection: HiddenProjection::new("hidden-to-key projection", weight, bias)?,
        })
    }
    pub fn input_dimension(&self) -> ModelDimension {
        self.projection.input_dimension()
    }
    pub fn head_dimension(&self) -> HeadDimension {
        self.projection.head_dimension()
    }
    pub fn weight(&self) -> &[Vec<f32>] {
        self.projection.weight()
    }
    pub fn bias(&self) -> &[f32] {
        self.projection.bias()
    }
}
/// Learned projection from hidden states to value vectors.
#[derive(Debug, Clone, PartialEq)]
pub struct HiddenToValue {
    projection: HiddenProjection,
}
impl HiddenToValue {
    pub fn new(weight: Vec<Vec<f32>>, bias: Vec<f32>) -> CtResult<Self> {
        Ok(Self {
            projection: HiddenProjection::new("hidden-to-value projection", weight, bias)?,
        })
    }
    pub fn input_dimension(&self) -> ModelDimension {
        self.projection.input_dimension()
    }
    pub fn head_dimension(&self) -> HeadDimension {
        self.projection.head_dimension()
    }
    pub fn weight(&self) -> &[Vec<f32>] {
        self.projection.weight()
    }
    pub fn bias(&self) -> &[f32] {
        self.projection.bias()
    }
}
/// A tiny single-head Transformer block sketch.
#[derive(Debug, Clone, PartialEq)]
pub struct SingleHeadTransformerBlock {
    model_dimension: ModelDimension,
    query_projection: HiddenToQuery,
    key_projection: HiddenToKey,
    value_projection: HiddenToValue,
    output_projection: AttentionOutputProjection,
    attention_norm: LayerNormalization,
    feed_forward: PositionWiseFeedForward,
    feed_forward_norm: LayerNormalization,
}
impl SingleHeadTransformerBlock {
    pub fn new(
        query_projection: HiddenToQuery,
        key_projection: HiddenToKey,
        value_projection: HiddenToValue,
        output_projection: AttentionOutputProjection,
        attention_norm: LayerNormalization,
        feed_forward: PositionWiseFeedForward,
        feed_forward_norm: LayerNormalization,
    ) -> CtResult<Self> {
        let model_dimension = query_projection.input_dimension();
        validate_projection_input(
            "single-head block key projection",
            model_dimension,
            key_projection.input_dimension(),
        )?;
        validate_projection_input(
            "single-head block value projection",
            model_dimension,
            value_projection.input_dimension(),
        )?;
        if query_projection.head_dimension() != key_projection.head_dimension() {
            return Err(CtError::ShapeMismatch {
                op: "single-head block",
                expected: format!(
                    "query/key head dimension {}",
                    query_projection.head_dimension().value()
                ),
                got: format!(
                    "key head dimension {}",
                    key_projection.head_dimension().value()
                ),
            });
        }
        if output_projection.input_dimension().value() != value_projection.head_dimension().value()
        {
            return Err(CtError::ShapeMismatch {
                op: "single-head block",
                expected: format!(
                    "output projection input dimension {}",
                    value_projection.head_dimension().value()
                ),
                got: format!(
                    "output projection input dimension {}",
                    output_projection.input_dimension().value()
                ),
            });
        }
        validate_projection_input(
            "single-head block output projection",
            model_dimension,
            output_projection.output_dimension(),
        )?;
        validate_projection_input(
            "single-head block attention normalization",
            model_dimension,
            attention_norm.model_dimension(),
        )?;
        validate_projection_input(
            "single-head block feed-forward",
            model_dimension,
            feed_forward.input_dimension(),
        )?;
        validate_projection_input(
            "single-head block feed-forward normalization",
            model_dimension,
            feed_forward_norm.model_dimension(),
        )?;
        Ok(Self {
            model_dimension,
            query_projection,
            key_projection,
            value_projection,
            output_projection,
            attention_norm,
            feed_forward,
            feed_forward_norm,
        })
    }
    pub fn model_dimension(&self) -> ModelDimension {
        self.model_dimension
    }
}
/// Learned query, key, and value projections for one self-attention head.
#[derive(Debug, Clone, PartialEq)]
pub struct SelfAttentionHead {
    query_projection: HiddenToQuery,
    key_projection: HiddenToKey,
    value_projection: HiddenToValue,
}
impl SelfAttentionHead {
    pub fn new(
        query_projection: HiddenToQuery,
        key_projection: HiddenToKey,
        value_projection: HiddenToValue,
    ) -> CtResult<Self> {
        let input_dimension = query_projection.input_dimension();
        validate_projection_input(
            "self-attention head key projection",
            input_dimension,
            key_projection.input_dimension(),
        )?;
        validate_projection_input(
            "self-attention head value projection",
            input_dimension,
            value_projection.input_dimension(),
        )?;
        if query_projection.head_dimension() != key_projection.head_dimension() {
            return Err(CtError::ShapeMismatch {
                op: "self-attention head",
                expected: format!(
                    "query/key head dimension {}",
                    query_projection.head_dimension().value()
                ),
                got: format!(
                    "key head dimension {}",
                    key_projection.head_dimension().value()
                ),
            });
        }
        Ok(Self {
            query_projection,
            key_projection,
            value_projection,
        })
    }
    pub fn input_dimension(&self) -> ModelDimension {
        self.query_projection.input_dimension()
    }
    pub fn query_key_dimension(&self) -> HeadDimension {
        self.query_projection.head_dimension()
    }
    pub fn value_dimension(&self) -> HeadDimension {
        self.value_projection.head_dimension()
    }
    pub fn query_projection(&self) -> &HiddenToQuery {
        &self.query_projection
    }
    pub fn key_projection(&self) -> &HiddenToKey {
        &self.key_projection
    }
    pub fn value_projection(&self) -> &HiddenToValue {
        &self.value_projection
    }
}
/// A tiny multi-head Transformer block sketch.
#[derive(Debug, Clone, PartialEq)]
pub struct MultiHeadTransformerBlock {
    model_dimension: ModelDimension,
    head_count: HeadCount,
    value_dimension: HeadDimension,
    heads: Vec<SelfAttentionHead>,
    output_projection: AttentionOutputProjection,
    attention_norm: LayerNormalization,
    feed_forward: PositionWiseFeedForward,
    feed_forward_norm: LayerNormalization,
}
impl MultiHeadTransformerBlock {
    pub fn new(
        heads: Vec<SelfAttentionHead>,
        output_projection: AttentionOutputProjection,
        attention_norm: LayerNormalization,
        feed_forward: PositionWiseFeedForward,
        feed_forward_norm: LayerNormalization,
    ) -> CtResult<Self> {
        if heads.is_empty() {
            return Err(CtError::EmptyInput("multi-head block heads"));
        }
        let model_dimension = heads[0].input_dimension();
        let value_dimension = heads[0].value_dimension();
        for head in &heads {
            validate_projection_input(
                "multi-head block head projection",
                model_dimension,
                head.input_dimension(),
            )?;
            if head.value_dimension() != value_dimension {
                return Err(CtError::ShapeMismatch {
                    op: "multi-head block",
                    expected: format!("value head dimension {}", value_dimension.value()),
                    got: format!("value head dimension {}", head.value_dimension().value()),
                });
            }
        }
        let head_count = HeadCount::new(heads.len())?;
        let concatenated_dimension =
            ModelDimension::new(head_count.value() * value_dimension.value())?;
        validate_projection_input(
            "multi-head block output projection input",
            concatenated_dimension,
            output_projection.input_dimension(),
        )?;
        validate_projection_input(
            "multi-head block output projection",
            model_dimension,
            output_projection.output_dimension(),
        )?;
        validate_projection_input(
            "multi-head block attention normalization",
            model_dimension,
            attention_norm.model_dimension(),
        )?;
        validate_projection_input(
            "multi-head block feed-forward",
            model_dimension,
            feed_forward.input_dimension(),
        )?;
        validate_projection_input(
            "multi-head block feed-forward normalization",
            model_dimension,
            feed_forward_norm.model_dimension(),
        )?;
        Ok(Self {
            model_dimension,
            head_count,
            value_dimension,
            heads,
            output_projection,
            attention_norm,
            feed_forward,
            feed_forward_norm,
        })
    }
    pub fn model_dimension(&self) -> ModelDimension {
        self.model_dimension
    }
    pub fn head_count(&self) -> HeadCount {
        self.head_count
    }
    pub fn value_dimension(&self) -> HeadDimension {
        self.value_dimension
    }
    pub fn heads(&self) -> &[SelfAttentionHead] {
        &self.heads
    }
    fn with_heads(self, heads: Vec<SelfAttentionHead>) -> CtResult<Self> {
        Self::new(
            heads,
            self.output_projection,
            self.attention_norm,
            self.feed_forward,
            self.feed_forward_norm,
        )
    }
    fn with_feed_forward(self, feed_forward: PositionWiseFeedForward) -> CtResult<Self> {
        Self::new(
            self.heads,
            self.output_projection,
            self.attention_norm,
            feed_forward,
            self.feed_forward_norm,
        )
    }
    fn with_output_projection(
        self,
        output_projection: AttentionOutputProjection,
    ) -> CtResult<Self> {
        Self::new(
            self.heads,
            output_projection,
            self.attention_norm,
            self.feed_forward,
            self.feed_forward_norm,
        )
    }
    fn with_layer_norms(
        self,
        attention_norm: LayerNormalization,
        feed_forward_norm: LayerNormalization,
    ) -> CtResult<Self> {
        Self::new(
            self.heads,
            self.output_projection,
            attention_norm,
            self.feed_forward,
            feed_forward_norm,
        )
    }
}
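// Sketch, not part of the listed crate: the dimension contract that
// `MultiHeadTransformerBlock::new` enforces, reduced to plain numbers so it
// can be checked in isolation. `BlockShape` and `check_block_shape` are
// hypothetical names. The real constructor reports `CtError::ShapeMismatch`
// and also checks both normalization layers.
mod block_shape_sketch {
    pub struct BlockShape {
        pub model_dimension: usize,
        pub head_count: usize,
        pub value_dimension: usize,
        pub output_projection_input: usize,
        pub output_projection_output: usize,
        pub feed_forward_input: usize,
    }

    pub fn check_block_shape(shape: &BlockShape) -> Result<(), String> {
        // Concatenating the heads yields head_count * value_dimension columns,
        // which is what the output projection must accept.
        let concatenated = shape.head_count * shape.value_dimension;
        if shape.output_projection_input != concatenated {
            return Err(format!(
                "output projection takes {} inputs, concatenated heads give {concatenated}",
                shape.output_projection_input
            ));
        }
        // The projection must map back to the model dimension so the residual
        // connection can add it to the block input.
        if shape.output_projection_output != shape.model_dimension {
            return Err(format!(
                "output projection returns {}, model dimension is {}",
                shape.output_projection_output, shape.model_dimension
            ));
        }
        // The feed-forward sublayer reads model-dimension rows.
        if shape.feed_forward_input != shape.model_dimension {
            return Err(format!(
                "feed-forward takes {}, model dimension is {}",
                shape.feed_forward_input, shape.model_dimension
            ));
        }
        Ok(())
    }
}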
/// A tiny masked multi-head Transformer block sketch.
#[derive(Debug, Clone, PartialEq)]
pub struct MaskedMultiHeadTransformerBlock {
    block: MultiHeadTransformerBlock,
}
impl MaskedMultiHeadTransformerBlock {
    pub fn new(
        heads: Vec<SelfAttentionHead>,
        output_projection: AttentionOutputProjection,
        attention_norm: LayerNormalization,
        feed_forward: PositionWiseFeedForward,
        feed_forward_norm: LayerNormalization,
    ) -> CtResult<Self> {
        Ok(Self {
            block: MultiHeadTransformerBlock::new(
                heads,
                output_projection,
                attention_norm,
                feed_forward,
                feed_forward_norm,
            )?,
        })
    }
    pub fn model_dimension(&self) -> ModelDimension {
        self.block.model_dimension()
    }
    pub fn head_count(&self) -> HeadCount {
        self.block.head_count()
    }
    pub fn value_dimension(&self) -> HeadDimension {
        self.block.value_dimension()
    }
    pub fn heads(&self) -> &[SelfAttentionHead] {
        self.block.heads()
    }
    pub fn feed_forward(&self) -> &PositionWiseFeedForward {
        &self.block.feed_forward
    }
    fn with_feed_forward(self, feed_forward: PositionWiseFeedForward) -> CtResult<Self> {
        Ok(Self {
            block: self.block.with_feed_forward(feed_forward)?,
        })
    }
    fn with_heads(self, heads: Vec<SelfAttentionHead>) -> CtResult<Self> {
        Ok(Self {
            block: self.block.with_heads(heads)?,
        })
    }
    fn with_output_projection(
        self,
        output_projection: AttentionOutputProjection,
    ) -> CtResult<Self> {
        Ok(Self {
            block: self.block.with_output_projection(output_projection)?,
        })
    }
    fn with_layer_norms(
        self,
        attention_norm: LayerNormalization,
        feed_forward_norm: LayerNormalization,
    ) -> CtResult<Self> {
        Ok(Self {
            block: self
                .block
                .with_layer_norms(attention_norm, feed_forward_norm)?,
        })
    }
    fn apply_with_training_cache(
        &self,
        hidden: HiddenSequence,
        mask: AttentionMask,
    ) -> CtResult<MaskedBlockTrainingCache> {
        if hidden.model_dimension() != self.block.model_dimension {
            return Err(CtError::ShapeMismatch {
                op: "masked multi-head block",
                expected: format!("model dimension {}", self.block.model_dimension.value()),
                got: format!("model dimension {}", hidden.model_dimension().value()),
            });
        }
        let head_caches = self
            .block
            .heads
            .iter()
            .map(|head| apply_self_attention_head_with_mask_cache(&hidden, head, Some(&mask)))
            .collect::<CtResult<Vec<_>>>()?;
        let attention_outputs = head_caches
            .iter()
            .map(|cache| cache.output.clone())
            .collect::<Vec<_>>();
        let head_outputs = AttentionHeadOutputs::new(attention_outputs)?;
        let multi_head_output = ConcatenateHeads.apply(head_outputs)?;
        let projected_attention = self
            .block
            .output_projection
            .apply(multi_head_output.clone())?;
        let with_attention = ResidualConnection.apply(Product::new(hidden, projected_attention))?;
        let normalized_attention = self.block.attention_norm.apply(with_attention.clone())?;
        let (feed_forward_output, feed_forward_rows) =
            feed_forward_with_cache(&self.block.feed_forward, &normalized_attention)?;
        let with_feed_forward =
            ResidualConnection.apply(Product::new(normalized_attention, feed_forward_output))?;
        let output = self
            .block
            .feed_forward_norm
            .apply(with_feed_forward.clone())?;
        Ok(MaskedBlockTrainingCache {
            output,
            with_feed_forward,
            with_attention,
            multi_head_output,
            attention_heads: head_caches,
            feed_forward_rows,
        })
    }
}
/// Parameters of the tiny Transformer: a positional encoding, one masked multi-head block, and a vocabulary readout.
#[derive(Debug, Clone, PartialEq)]
pub struct TinyTransformerParameters {
positional_encoding: PositionalEncoding,
block: MaskedMultiHeadTransformerBlock,
readout: TransformerReadout,
}
impl TinyTransformerParameters {
pub fn new(
positional_encoding: PositionalEncoding,
block: MaskedMultiHeadTransformerBlock,
readout: TransformerReadout,
) -> CtResult<Self> {
let model_dimension = positional_encoding.model_dimension();
validate_projection_input(
"tiny transformer parameters block",
model_dimension,
block.model_dimension(),
)?;
validate_projection_input(
"tiny transformer parameters readout",
model_dimension,
readout.input_dimension(),
)?;
Ok(Self {
positional_encoding,
block,
readout,
})
}
pub fn model_dimension(&self) -> ModelDimension {
self.positional_encoding.model_dimension()
}
pub fn max_sequence_len(&self) -> SequenceLength {
self.positional_encoding.max_sequence_len()
}
pub fn vocab_size(&self) -> VocabSize {
self.readout.vocab_size()
}
pub fn encode(&self, hidden: HiddenSequence, mask: AttentionMask) -> CtResult<HiddenSequence> {
let positioned = self.positional_encoding.apply(hidden)?;
self.block.apply(Product::new(positioned, mask))
}
pub fn readout(&self) -> &TransformerReadout {
&self.readout
}
pub fn feed_forward(&self) -> &PositionWiseFeedForward {
self.block.feed_forward()
}
pub fn output_projection(&self) -> &AttentionOutputProjection {
&self.block.block.output_projection
}
pub fn attention_heads(&self) -> &[SelfAttentionHead] {
self.block.heads()
}
pub fn attention_norm(&self) -> &LayerNormalization {
&self.block.block.attention_norm
}
pub fn feed_forward_norm(&self) -> &LayerNormalization {
&self.block.block.feed_forward_norm
}
fn with_readout(self, readout: TransformerReadout) -> CtResult<Self> {
Self::new(self.positional_encoding, self.block, readout)
}
fn with_feed_forward(self, feed_forward: PositionWiseFeedForward) -> CtResult<Self> {
Self::new(
self.positional_encoding,
self.block.with_feed_forward(feed_forward)?,
self.readout,
)
}
fn with_attention_heads(self, heads: Vec<SelfAttentionHead>) -> CtResult<Self> {
Self::new(
self.positional_encoding,
self.block.with_heads(heads)?,
self.readout,
)
}
fn with_output_projection(
self,
output_projection: AttentionOutputProjection,
) -> CtResult<Self> {
Self::new(
self.positional_encoding,
self.block.with_output_projection(output_projection)?,
self.readout,
)
}
fn with_layer_norms(
self,
attention_norm: LayerNormalization,
feed_forward_norm: LayerNormalization,
) -> CtResult<Self> {
Self::new(
self.positional_encoding,
self.block
.with_layer_norms(attention_norm, feed_forward_norm)?,
self.readout,
)
}
}
/// Structured training state (parameters, learning rate, step count) threaded through the Transformer train steps.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerTrainingState {
parameters: TinyTransformerParameters,
learning_rate: LearningRate,
step_count: StepCount,
}
impl TransformerTrainingState {
pub fn new(parameters: TinyTransformerParameters, learning_rate: LearningRate) -> Self {
Self::from_parts(parameters, learning_rate, StepCount::new(0))
}
pub fn from_parts(
parameters: TinyTransformerParameters,
learning_rate: LearningRate,
step_count: StepCount,
) -> Self {
Self {
parameters,
learning_rate,
step_count,
}
}
pub fn parameters(&self) -> &TinyTransformerParameters {
&self.parameters
}
pub fn learning_rate(&self) -> LearningRate {
self.learning_rate
}
pub fn step_count(&self) -> StepCount {
self.step_count
}
pub fn record_updated_parameters(self, parameters: TinyTransformerParameters) -> Self {
Self {
parameters,
learning_rate: self.learning_rate,
step_count: StepCount::new(self.step_count.value() + 1),
}
}
}
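// Illustrative sketch (hypothetical names, not part of this crate): each train
// step below is a morphism `TransformerTrainingState -> TransformerTrainingState`,
// and `record_updated_parameters` increments the step count. Running n steps is
// then the n-fold composite of one endomorphism, which a fold expresses directly.
// The one-parameter `State` and the loss (p - 3)^2 are assumptions chosen so the
// effect is easy to check by hand.
#[allow(dead_code)]
mod training_fold_sketch {
    /// Toy stand-in for the training state: one parameter, a learning rate, a step counter.
    #[derive(Debug, Clone, PartialEq)]
    pub struct State {
        pub parameter: f32,
        pub learning_rate: f32,
        pub step_count: u32,
    }

    /// One gradient step on loss(p) = (p - 3)^2; returns a new state instead of mutating.
    pub fn step(state: State) -> State {
        let gradient = 2.0 * (state.parameter - 3.0);
        State {
            parameter: state.parameter - state.learning_rate * gradient,
            step_count: state.step_count + 1,
            ..state
        }
    }

    /// Applies `step` exactly `steps` times, threading the state through each call.
    pub fn train(state: State, steps: u32) -> State {
        (0..steps).fold(state, |current, _| step(current))
    }
}
// With learning_rate = 0.25 each step halves the distance to 3.0, so ten steps
// leave an error of roughly 3 * 0.5^10, about 0.003.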
/// One supervised sequence example for a readout-only Transformer update.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerReadoutTrainingExample {
hidden: HiddenSequence,
mask: AttentionMask,
targets: TokenSequence,
}
impl TransformerReadoutTrainingExample {
pub fn new(
hidden: HiddenSequence,
mask: AttentionMask,
targets: TokenSequence,
) -> CtResult<Self> {
let sequence_len = hidden.sequence_len();
if targets.as_slice().len() != sequence_len.value() {
return Err(CtError::ShapeMismatch {
op: "transformer readout training targets",
expected: format!("{} target tokens", sequence_len.value()),
got: format!("{} target tokens", targets.as_slice().len()),
});
}
if mask.query_len() != sequence_len || mask.key_len() != sequence_len {
return Err(CtError::ShapeMismatch {
op: "transformer readout training mask",
expected: format!(
"{} query rows x {} key columns",
sequence_len.value(),
sequence_len.value()
),
got: format!(
"{} query rows x {} key columns",
mask.query_len().value(),
mask.key_len().value()
),
});
}
Ok(Self {
hidden,
mask,
targets,
})
}
pub fn hidden(&self) -> &HiddenSequence {
&self.hidden
}
pub fn mask(&self) -> &AttentionMask {
&self.mask
}
pub fn targets(&self) -> &TokenSequence {
&self.targets
}
}
/// Non-empty set of supervised readout examples.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerReadoutTrainingSet(Vec<TransformerReadoutTrainingExample>);
impl TransformerReadoutTrainingSet {
pub fn new(
examples: impl IntoIterator<Item = TransformerReadoutTrainingExample>,
) -> CtResult<Self> {
let examples = examples.into_iter().collect::<Vec<_>>();
if examples.is_empty() {
return Err(CtError::EmptyInput("transformer readout training set"));
}
Ok(Self(examples))
}
pub fn examples(&self) -> &[TransformerReadoutTrainingExample] {
&self.0
}
}
/// One full-batch update of the sequence readout parameters.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerReadoutTrainStep {
dataset: TransformerReadoutTrainingSet,
}
impl TransformerReadoutTrainStep {
pub fn new(dataset: TransformerReadoutTrainingSet) -> Self {
Self { dataset }
}
}
impl Morphism<TransformerTrainingState, TransformerTrainingState> for TransformerReadoutTrainStep {
fn name(&self) -> &'static str {
"transformer_readout_train_step"
}
fn apply(&self, state: TransformerTrainingState) -> CtResult<TransformerTrainingState> {
let input_dimension = state.parameters.readout().input_dimension().value();
let vocab_size = state.parameters.vocab_size().value();
let mut grad_weight = vec![vec![0.0; vocab_size]; input_dimension];
let mut grad_bias = vec![0.0; vocab_size];
let mut position_count = 0usize;
for example in self.dataset.examples() {
let encoded = state
.parameters
.encode(example.hidden().clone(), example.mask().clone())?;
let logits = state.parameters.readout().apply(encoded.clone())?;
for ((hidden_row, logit_row), target) in encoded
.rows()
.iter()
.zip(logits.rows())
.zip(example.targets().as_slice())
{
let target_index = target.index();
if target_index >= vocab_size {
return Err(CtError::OutOfRange {
kind: "sequence target",
index: target_index,
limit: vocab_size,
});
}
let probabilities = Softmax.apply(logit_row.clone())?;
let mut dlogits = probabilities.as_slice().to_vec();
dlogits[target_index] -= 1.0;
for (vocab_id, dlogit) in dlogits.iter().copied().enumerate() {
grad_bias[vocab_id] += dlogit;
for (feature, hidden_value) in hidden_row.as_slice().iter().copied().enumerate()
{
grad_weight[feature][vocab_id] += hidden_value * dlogit;
}
}
position_count += 1;
}
}
let scale = state.learning_rate().value() / position_count as f32;
let mut updated_weight = state.parameters.readout().weight().to_vec();
let mut updated_bias = state.parameters.readout().bias().to_vec();
for (row, grad_row) in updated_weight.iter_mut().zip(&grad_weight) {
for (weight, grad) in row.iter_mut().zip(grad_row) {
*weight -= scale * grad;
}
}
for (bias, grad) in updated_bias.iter_mut().zip(&grad_bias) {
*bias -= scale * grad;
}
let updated_readout = TransformerReadout::new(updated_weight, updated_bias)?;
let updated_parameters = state.parameters.clone().with_readout(updated_readout)?;
Ok(state.record_updated_parameters(updated_parameters))
}
}
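// Illustrative sketch (hypothetical module, plain `Vec<f32>` data): the gradient
// used by `TransformerReadoutTrainStep`. For one position with hidden row h,
// logits z = h W + b and target t, cross-entropy gives dL/dz = softmax(z) - onehot(t),
// hence dL/dW[f][v] = h[f] * dz[v] and dL/db[v] = dz[v]. Unlike the step above,
// which sums over every position and divides by `position_count`, this sketch
// updates after a single position (a batch of one).
#[allow(dead_code)]
mod readout_step_sketch {
    /// Numerically stable softmax of one logit row.
    pub fn softmax(logits: &[f32]) -> Vec<f32> {
        let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = logits.iter().map(|z| (z - max).exp()).collect();
        let sum: f32 = exps.iter().sum();
        exps.iter().map(|e| e / sum).collect()
    }

    /// Affine readout; `weight` is laid out `[feature][vocab]` like the loop above.
    pub fn logits(weight: &[Vec<f32>], bias: &[f32], hidden: &[f32]) -> Vec<f32> {
        (0..bias.len())
            .map(|v| bias[v] + hidden.iter().zip(weight).map(|(h, row)| h * row[v]).sum::<f32>())
            .collect()
    }

    /// One gradient-descent step on a single (hidden row, target) pair.
    pub fn step(weight: &mut [Vec<f32>], bias: &mut [f32], hidden: &[f32], target: usize, learning_rate: f32) {
        let mut dlogits = softmax(&logits(weight, bias, hidden));
        dlogits[target] -= 1.0; // softmax minus the one-hot target
        for (v, d) in dlogits.iter().enumerate() {
            bias[v] -= learning_rate * d;
            for (f, h) in hidden.iter().enumerate() {
                weight[f][v] -= learning_rate * h * d;
            }
        }
    }
}
// After one step the probability assigned to the target token should rise.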
/// One supervised sequence example for local feed-forward sublayer training.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerFeedForwardTrainingExample {
input: HiddenSequence,
target: HiddenSequence,
}
impl TransformerFeedForwardTrainingExample {
pub fn new(input: HiddenSequence, target: HiddenSequence) -> CtResult<Self> {
if input.sequence_len() != target.sequence_len() {
return Err(CtError::ShapeMismatch {
op: "transformer feed-forward training sequence",
expected: format!("{} target rows", input.sequence_len().value()),
got: format!("{} target rows", target.sequence_len().value()),
});
}
if input.model_dimension() != target.model_dimension() {
return Err(CtError::ShapeMismatch {
op: "transformer feed-forward training dimension",
expected: format!("target dimension {}", input.model_dimension().value()),
got: format!("target dimension {}", target.model_dimension().value()),
});
}
Ok(Self { input, target })
}
pub fn input(&self) -> &HiddenSequence {
&self.input
}
pub fn target(&self) -> &HiddenSequence {
&self.target
}
}
/// Non-empty set of local feed-forward training examples.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerFeedForwardTrainingSet(Vec<TransformerFeedForwardTrainingExample>);
impl TransformerFeedForwardTrainingSet {
pub fn new(
examples: impl IntoIterator<Item = TransformerFeedForwardTrainingExample>,
) -> CtResult<Self> {
let examples = examples.into_iter().collect::<Vec<_>>();
if examples.is_empty() {
return Err(CtError::EmptyInput("transformer feed-forward training set"));
}
Ok(Self(examples))
}
pub fn examples(&self) -> &[TransformerFeedForwardTrainingExample] {
&self.0
}
}
/// One full-batch update of the position-wise feed-forward sublayer.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerFeedForwardTrainStep {
dataset: TransformerFeedForwardTrainingSet,
}
impl TransformerFeedForwardTrainStep {
pub fn new(dataset: TransformerFeedForwardTrainingSet) -> Self {
Self { dataset }
}
}
impl Morphism<TransformerTrainingState, TransformerTrainingState>
for TransformerFeedForwardTrainStep
{
fn name(&self) -> &'static str {
"transformer_feed_forward_train_step"
}
fn apply(&self, state: TransformerTrainingState) -> CtResult<TransformerTrainingState> {
let feed_forward = state.parameters.feed_forward();
let mut gradients = FeedForwardGradients::new(feed_forward);
let mut row_count = 0usize;
for example in self.dataset.examples() {
if example.input().model_dimension() != feed_forward.input_dimension() {
return Err(CtError::ShapeMismatch {
op: "transformer feed-forward train step",
expected: format!("input dimension {}", feed_forward.input_dimension().value()),
got: format!(
"input dimension {}",
example.input().model_dimension().value()
),
});
}
let (_output, cache_rows) = feed_forward_with_cache(feed_forward, example.input())?;
for (cache, target_row) in cache_rows.iter().zip(example.target().rows()) {
let d_output = cache
.output
.iter()
.zip(target_row.as_slice())
.map(|(output_value, target_value)| output_value - target_value)
.collect::<Vec<_>>();
gradients.accumulate(feed_forward, cache, &d_output);
row_count += 1;
}
}
let updated_feed_forward =
gradients.apply_to(feed_forward, state.learning_rate(), row_count)?;
let updated_parameters = state
.parameters
.clone()
.with_feed_forward(updated_feed_forward)?;
Ok(state.record_updated_parameters(updated_parameters))
}
}
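// Illustrative sketch (hypothetical module): why `d_output = output - target` in
// the feed-forward step. For the per-value loss 0.5 * (o - t)^2 the derivative
// with respect to o is exactly (o - t); a central finite difference confirms it.
// Scaling note, hedged because `FeedForwardGradients::apply_to` is not shown in
// this section: if it divides by `row_count`, the update follows the per-row sum
// of these gradients, while `transformer_feed_forward_average_loss` divides by
// every value, so the two differ by a factor equal to the model dimension.
#[allow(dead_code)]
mod half_squared_error_sketch {
    /// Sum of 0.5 * (o - t)^2 over one row.
    pub fn loss(output: &[f32], target: &[f32]) -> f32 {
        output.iter().zip(target).map(|(o, t)| 0.5 * (o - t) * (o - t)).sum()
    }

    /// Analytic gradient of `loss` with respect to `output`.
    pub fn gradient(output: &[f32], target: &[f32]) -> Vec<f32> {
        output.iter().zip(target).map(|(o, t)| o - t).collect()
    }
}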
/// One supervised sequence example for a composed block-level update.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerBlockTrainingExample {
hidden: HiddenSequence,
mask: AttentionMask,
targets: TokenSequence,
}
impl TransformerBlockTrainingExample {
pub fn new(
hidden: HiddenSequence,
mask: AttentionMask,
targets: TokenSequence,
) -> CtResult<Self> {
let sequence_len = hidden.sequence_len();
if targets.as_slice().len() != sequence_len.value() {
return Err(CtError::ShapeMismatch {
op: "transformer block training targets",
expected: format!("{} target tokens", sequence_len.value()),
got: format!("{} target tokens", targets.as_slice().len()),
});
}
if mask.query_len() != sequence_len || mask.key_len() != sequence_len {
return Err(CtError::ShapeMismatch {
op: "transformer block training mask",
expected: format!(
"{} query rows x {} key columns",
sequence_len.value(),
sequence_len.value()
),
got: format!(
"{} query rows x {} key columns",
mask.query_len().value(),
mask.key_len().value()
),
});
}
Ok(Self {
hidden,
mask,
targets,
})
}
pub fn hidden(&self) -> &HiddenSequence {
&self.hidden
}
pub fn mask(&self) -> &AttentionMask {
&self.mask
}
pub fn targets(&self) -> &TokenSequence {
&self.targets
}
}
/// Non-empty set of supervised block-level training examples.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerBlockTrainingSet(Vec<TransformerBlockTrainingExample>);
impl TransformerBlockTrainingSet {
pub fn new(
examples: impl IntoIterator<Item = TransformerBlockTrainingExample>,
) -> CtResult<Self> {
let examples = examples.into_iter().collect::<Vec<_>>();
if examples.is_empty() {
return Err(CtError::EmptyInput("transformer block training set"));
}
Ok(Self(examples))
}
pub fn examples(&self) -> &[TransformerBlockTrainingExample] {
&self.0
}
}
/// One full-batch update through the readout, block sublayers, and attention heads.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerBlockTrainStep {
dataset: TransformerBlockTrainingSet,
}
impl TransformerBlockTrainStep {
pub fn new(dataset: TransformerBlockTrainingSet) -> Self {
Self { dataset }
}
}
impl Morphism<TransformerTrainingState, TransformerTrainingState> for TransformerBlockTrainStep {
fn name(&self) -> &'static str {
"transformer_block_train_step"
}
fn apply(&self, state: TransformerTrainingState) -> CtResult<TransformerTrainingState> {
let readout = state.parameters.readout();
let feed_forward = state.parameters.feed_forward();
let output_projection = state.parameters.output_projection();
let attention_norm = state.parameters.attention_norm();
let feed_forward_norm = state.parameters.feed_forward_norm();
let mut readout_gradients = ReadoutGradients::new(readout);
let mut feed_forward_gradients = FeedForwardGradients::new(feed_forward);
let mut output_projection_gradients =
AttentionOutputProjectionGradients::new(output_projection);
let mut attention_norm_gradients = LayerNormGradients::new(attention_norm.parameters());
let mut feed_forward_norm_gradients =
LayerNormGradients::new(feed_forward_norm.parameters());
let mut attention_head_gradients = state
.parameters
.attention_heads()
.iter()
.map(AttentionHeadGradients::new)
.collect::<Vec<_>>();
let mut position_count = 0usize;
for example in self.dataset.examples() {
let positioned = state
.parameters
.positional_encoding
.apply(example.hidden().clone())?;
let block_cache = state
.parameters
.block
.apply_with_training_cache(positioned.clone(), example.mask().clone())?;
let logits = readout.apply(block_cache.output.clone())?;
let vocab_size = logits.vocab_size().value();
let mut d_multi_head_rows = vec![
vec![0.0; output_projection.input_dimension().value()];
block_cache.multi_head_output.sequence_len().value()
];
for (
position,
(
(((encoded_row, logit_row), with_feed_forward_row), with_attention_row),
((feed_forward_cache, multi_head_row), target),
),
) in block_cache
.output
.rows()
.iter()
.zip(logits.rows())
.zip(block_cache.with_feed_forward.rows())
.zip(block_cache.with_attention.rows())
.zip(
block_cache
.feed_forward_rows
.iter()
.zip(block_cache.multi_head_output.rows())
.zip(example.targets().as_slice()),
)
.enumerate()
{
let dlogits =
softmax_cross_entropy_logits_gradient(logit_row, target.index(), vocab_size)?;
let d_encoded =
readout_gradients.accumulate(readout, encoded_row.as_slice(), &dlogits);
let d_with_feed_forward = feed_forward_norm_gradients.accumulate(
&d_encoded,
with_feed_forward_row.as_slice(),
feed_forward_norm.parameters(),
);
let d_feed_forward_input = feed_forward_gradients.accumulate(
feed_forward,
feed_forward_cache,
&d_with_feed_forward,
);
let d_normalized_attention = add_rows(&d_with_feed_forward, &d_feed_forward_input);
let d_with_attention = attention_norm_gradients.accumulate(
&d_normalized_attention,
with_attention_row.as_slice(),
attention_norm.parameters(),
);
d_multi_head_rows[position] = output_projection_gradients.accumulate(
output_projection,
multi_head_row.as_slice(),
&d_with_attention,
);
position_count += 1;
}
let value_dimension = block_cache.multi_head_output.head_dimension().value();
for (head_index, ((head_gradient, head), head_cache)) in attention_head_gradients
.iter_mut()
.zip(state.parameters.attention_heads())
.zip(&block_cache.attention_heads)
.enumerate()
{
let start = head_index * value_dimension;
let end = start + value_dimension;
let d_head_output_rows = d_multi_head_rows
.iter()
.map(|row| row[start..end].to_vec())
.collect::<Vec<_>>();
head_gradient.accumulate(
head,
&positioned,
head_cache,
example.mask(),
&d_head_output_rows,
)?;
}
}
let updated_readout =
readout_gradients.apply_to(readout, state.learning_rate(), position_count)?;
let updated_feed_forward =
feed_forward_gradients.apply_to(feed_forward, state.learning_rate(), position_count)?;
let updated_output_projection = output_projection_gradients.apply_to(
output_projection,
state.learning_rate(),
position_count,
)?;
let updated_attention_norm = LayerNormalization::new(attention_norm_gradients.apply_to(
attention_norm.parameters(),
state.learning_rate(),
position_count,
)?);
let updated_feed_forward_norm =
LayerNormalization::new(feed_forward_norm_gradients.apply_to(
feed_forward_norm.parameters(),
state.learning_rate(),
position_count,
)?);
let updated_heads = state
.parameters
.attention_heads()
.iter()
.zip(attention_head_gradients)
.map(|(head, gradients)| {
gradients.apply_to(head, state.learning_rate(), position_count)
})
.collect::<CtResult<Vec<_>>>()?;
let updated_parameters = state
.parameters
.clone()
.with_readout(updated_readout)?
.with_feed_forward(updated_feed_forward)?
.with_output_projection(updated_output_projection)?
.with_layer_norms(updated_attention_norm, updated_feed_forward_norm)?
.with_attention_heads(updated_heads)?;
Ok(state.record_updated_parameters(updated_parameters))
}
}
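// Illustrative sketch (hypothetical module): the residual rule behind
// `add_rows(&d_with_feed_forward, &d_feed_forward_input)` in the block step.
// For y = x + f(x) the gradient reaching x is the skip-path gradient dy plus the
// sublayer-path gradient J_f^T dy. Here f is assumed to be a square linear map W,
// so J_f^T dy = W^T dy; the real feed-forward sublayer is nonlinear.
#[allow(dead_code)]
mod residual_backward_sketch {
    /// Forward: y = x + W x, with `w` laid out `[output][input]`.
    pub fn forward(w: &[Vec<f32>], x: &[f32]) -> Vec<f32> {
        w.iter()
            .zip(x)
            .map(|(row, xi)| xi + row.iter().zip(x).map(|(a, b)| a * b).sum::<f32>())
            .collect()
    }

    /// Backward: dL/dx = dy (skip path) + W^T dy (sublayer path).
    pub fn backward(w: &[Vec<f32>], dy: &[f32]) -> Vec<f32> {
        (0..dy.len())
            .map(|i| dy[i] + w.iter().zip(dy).map(|(row, d)| row[i] * d).sum::<f32>())
            .collect()
    }
}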
/// Average sequence cross-entropy for the structured Transformer readout.
pub fn transformer_readout_average_loss(
state: &TransformerTrainingState,
dataset: &TransformerReadoutTrainingSet,
) -> CtResult<Loss> {
let mut total = 0.0;
let mut position_count = 0usize;
for example in dataset.examples() {
let logits = state.apply(Product::new(
example.hidden().clone(),
example.mask().clone(),
))?;
let vocab_size = logits.vocab_size().value();
for (logit_row, target) in logits.rows().iter().zip(example.targets().as_slice()) {
let target_index = target.index();
if target_index >= vocab_size {
return Err(CtError::OutOfRange {
kind: "sequence target",
index: target_index,
limit: vocab_size,
});
}
let probabilities = Softmax.apply(logit_row.clone())?;
let probability = probabilities.as_slice()[target_index].max(1e-9);
total += -probability.ln();
position_count += 1;
}
}
Loss::new(total / position_count as f32)
}
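// Illustrative sketch (hypothetical module, rows given as (logits, target) pairs):
// the averaging performed by `transformer_readout_average_loss`, and identically
// by `transformer_block_average_loss`. Each position contributes -ln p(target),
// with p clamped at 1e-9 so a vanishing probability cannot produce an infinite
// loss; the sum is divided by the number of positions. With uniform logits over
// four tokens every position costs ln 4.
#[allow(dead_code)]
mod average_cross_entropy_sketch {
    pub fn average_loss(rows: &[(Vec<f32>, usize)]) -> f32 {
        let mut total = 0.0_f32;
        for (logits, target) in rows {
            // Stable softmax probability of the target token.
            let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let sum: f32 = logits.iter().map(|z| (z - max).exp()).sum();
            let p = ((logits[*target] - max).exp() / sum).max(1e-9);
            total += -p.ln();
        }
        total / rows.len() as f32
    }
}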
/// Average half squared error, `0.5 * (output - target)^2` per feature value, for local feed-forward sublayer training.
pub fn transformer_feed_forward_average_loss(
state: &TransformerTrainingState,
dataset: &TransformerFeedForwardTrainingSet,
) -> CtResult<Loss> {
let feed_forward = state.parameters.feed_forward();
let mut total = 0.0;
let mut value_count = 0usize;
for example in dataset.examples() {
let output = feed_forward.apply(example.input().clone())?;
for (output_row, target_row) in output.rows().iter().zip(example.target().rows()) {
for (output_value, target_value) in
output_row.as_slice().iter().zip(target_row.as_slice())
{
let error = output_value - target_value;
total += 0.5 * error * error;
value_count += 1;
}
}
}
Loss::new(total / value_count as f32)
}
/// Average sequence cross-entropy for the composed block-level training set.
pub fn transformer_block_average_loss(
state: &TransformerTrainingState,
dataset: &TransformerBlockTrainingSet,
) -> CtResult<Loss> {
let mut total = 0.0;
let mut position_count = 0usize;
for example in dataset.examples() {
let logits = state.apply(Product::new(
example.hidden().clone(),
example.mask().clone(),
))?;
let vocab_size = logits.vocab_size().value();
for (logit_row, target) in logits.rows().iter().zip(example.targets().as_slice()) {
let target_index = target.index();
if target_index >= vocab_size {
return Err(CtError::OutOfRange {
kind: "sequence target",
index: target_index,
limit: vocab_size,
});
}
let probabilities = Softmax.apply(logit_row.clone())?;
let probability = probabilities.as_slice()[target_index].max(1e-9);
total += -probability.ln();
position_count += 1;
}
}
Loss::new(total / position_count as f32)
}
/// Applies softmax independently to each query row.
#[derive(Debug, Clone)]
pub struct AttentionSoftmax;
impl Morphism<AttentionScores, AttentionWeights> for AttentionSoftmax {
fn name(&self) -> &'static str {
"attention_softmax"
}
fn apply(&self, scores: AttentionScores) -> CtResult<AttentionWeights> {
let rows = scores
.rows
.into_iter()
.map(|row| Softmax.apply(row))
.collect::<CtResult<Vec<_>>>()?;
AttentionWeights::new(rows)
}
}
/// Mixes value vectors with row-wise attention weights.
#[derive(Debug, Clone)]
pub struct WeightedValueMixing;
impl Morphism<Product<AttentionWeights, ValueSequence>, AttentionOutput> for WeightedValueMixing {
fn name(&self) -> &'static str {
"weighted_value_mixing"
}
fn apply(&self, input: Product<AttentionWeights, ValueSequence>) -> CtResult<AttentionOutput> {
let (weights, values) = input.into_parts();
if weights.key_len() != values.sequence_len() {
return Err(CtError::ShapeMismatch {
op: "weighted value mixing",
expected: format!("{} value rows", weights.key_len().value()),
got: format!("{} value rows", values.sequence_len().value()),
});
}
let value_width = values.head_dimension().value();
let rows = weights
.rows()
.iter()
.map(|weight_row| weighted_sum(weight_row.as_slice(), values.rows(), value_width))
.collect::<Vec<_>>();
AttentionOutput::new(rows)
}
}
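/// Computes `sum_k weights[k] * values[k]` into a vector of length `value_width`; with softmax weights this is a convex combination of the value rows.
fn weighted_sum(weights: &[f32], values: &[Vector], value_width: usize) -> Vec<f32> {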
let mut output = vec![0.0; value_width];
for (weight, value) in weights.iter().zip(values.iter()) {
for (output_value, value_component) in output.iter_mut().zip(value.as_slice()) {
*output_value += weight * value_component;
}
}
output
}
/// Concatenates single-head outputs along the feature dimension.
#[derive(Debug, Clone)]
pub struct ConcatenateHeads;
impl Morphism<AttentionHeadOutputs, MultiHeadOutput> for ConcatenateHeads {
fn name(&self) -> &'static str {
"concatenate_heads"
}
fn apply(&self, heads: AttentionHeadOutputs) -> CtResult<MultiHeadOutput> {
let rows = (0..heads.sequence_len().value())
.map(|position| {
heads
.heads()
.iter()
.flat_map(|head| head.rows()[position].as_slice().iter().copied())
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();
MultiHeadOutput::new(rows, heads.head_count(), heads.head_dimension())
}
}
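// Illustrative sketch (hypothetical module, nested `Vec`s instead of the crate's
// typed sequences): `ConcatenateHeads` and its inverse. Concatenation places head h
// in columns h*width..(h+1)*width of each position's row; the block train step
// undoes this when it slices `d_multi_head_rows[start..end]` per head, so
// splitting a concatenation returns the original heads.
#[allow(dead_code)]
mod concat_heads_sketch {
    /// `heads[h][position]` is one head's output row; returns one concatenated row per position.
    pub fn concatenate(heads: &[Vec<Vec<f32>>]) -> Vec<Vec<f32>> {
        let sequence_len = heads.first().map_or(0, |head| head.len());
        (0..sequence_len)
            .map(|position| {
                heads
                    .iter()
                    .flat_map(|head| head[position].iter().copied())
                    .collect::<Vec<f32>>()
            })
            .collect()
    }

    /// Inverse of `concatenate` for heads of equal width.
    pub fn split(rows: &[Vec<f32>], head_count: usize, width: usize) -> Vec<Vec<Vec<f32>>> {
        (0..head_count)
            .map(|h| {
                rows.iter()
                    .map(|row| row[h * width..(h + 1) * width].to_vec())
                    .collect::<Vec<Vec<f32>>>()
            })
            .collect()
    }
}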
/// Projects each concatenated multi-head row with the output weight and bias.
impl Morphism<MultiHeadOutput, ProjectedAttentionOutput> for AttentionOutputProjection {
fn name(&self) -> &'static str {
"attention_output_projection"
}
fn apply(&self, input: MultiHeadOutput) -> CtResult<ProjectedAttentionOutput> {
if input.model_dimension() != self.input_dimension {
return Err(CtError::ShapeMismatch {
op: "attention output projection",
expected: format!("input dimension {}", self.input_dimension.value()),
got: format!("input dimension {}", input.model_dimension().value()),
});
}
let rows = input
.rows()
.iter()
.map(|row| project_row(row.as_slice(), &self.weight, &self.bias))
.collect::<Vec<_>>();
ProjectedAttentionOutput::new(rows)
}
}
/// Adds a same-shaped sublayer output back to the hidden sequence.
#[derive(Debug, Clone)]
pub struct ResidualConnection;
impl Morphism<Product<HiddenSequence, ProjectedAttentionOutput>, HiddenSequence>
for ResidualConnection
{
fn name(&self) -> &'static str {
"residual_connection"
}
fn apply(
&self,
input: Product<HiddenSequence, ProjectedAttentionOutput>,
) -> CtResult<HiddenSequence> {
let (hidden, sublayer_output) = input.into_parts();
if hidden.sequence_len() != sublayer_output.sequence_len() {
return Err(CtError::ShapeMismatch {
op: "residual connection",
expected: format!("{} sequence rows", hidden.sequence_len().value()),
got: format!("{} sequence rows", sublayer_output.sequence_len().value()),
});
}
if hidden.model_dimension() != sublayer_output.model_dimension() {
return Err(CtError::ShapeMismatch {
op: "residual connection",
expected: format!("model dimension {}", hidden.model_dimension().value()),
got: format!(
"model dimension {}",
sublayer_output.model_dimension().value()
),
});
}
let rows = hidden
.rows()
.iter()
.zip(sublayer_output.rows())
.map(|(left, right)| add_rows(left.as_slice(), right.as_slice()))
.collect::<Vec<_>>();
HiddenSequence::new(rows)
}
}
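/// Normalizes each position's feature vector independently with the learned layer-norm parameters.
impl Morphism<HiddenSequence, HiddenSequence> for LayerNormalization {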
fn name(&self) -> &'static str {
"layer_normalization"
}
fn apply(&self, input: HiddenSequence) -> CtResult<HiddenSequence> {
if input.model_dimension() != self.parameters.model_dimension() {
return Err(CtError::ShapeMismatch {
op: "layer normalization",
expected: format!(
"model dimension {}",
self.parameters.model_dimension().value()
),
got: format!("model dimension {}", input.model_dimension().value()),
});
}
let rows = input
.rows()
.iter()
.map(|row| normalize_row(row.as_slice(), &self.parameters))
.collect::<Vec<_>>();
HiddenSequence::new(rows)
}
}
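// Illustrative sketch (hypothetical module): what a row-wise layer norm such as
// `normalize_row` typically computes. `normalize_row` and the layer-norm parameter
// type are not shown in this section, so the per-feature `gamma`/`beta` layout
// and the `epsilon` argument are assumptions. With gamma = 1 and beta = 0 the
// output row has mean 0 and variance close to 1.
#[allow(dead_code)]
mod layer_norm_sketch {
    pub fn normalize_row(row: &[f32], gamma: &[f32], beta: &[f32], epsilon: f32) -> Vec<f32> {
        let n = row.len() as f32;
        let mean = row.iter().sum::<f32>() / n;
        let variance = row.iter().map(|x| (x - mean) * (x - mean)).sum::<f32>() / n;
        let inv_std = 1.0 / (variance + epsilon).sqrt(); // epsilon guards zero variance
        row.iter()
            .zip(gamma)
            .zip(beta)
            .map(|((x, g), b)| g * (x - mean) * inv_std + b)
            .collect()
    }
}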
/// Applies the same feed-forward network to every position, discarding the training cache.
impl Morphism<HiddenSequence, HiddenSequence> for PositionWiseFeedForward {
fn name(&self) -> &'static str {
"position_wise_feed_forward"
}
fn apply(&self, input: HiddenSequence) -> CtResult<HiddenSequence> {
let (output, _cache) = feed_forward_with_cache(self, &input)?;
Ok(output)
}
}
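/// Adds the stored position vector to each row; sequences longer than `max_sequence_len` are rejected.
impl Morphism<HiddenSequence, HiddenSequence> for PositionalEncoding {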
fn name(&self) -> &'static str {
"positional_encoding"
}
fn apply(&self, input: HiddenSequence) -> CtResult<HiddenSequence> {
if input.sequence_len().value() > self.max_sequence_len.value() {
return Err(CtError::ShapeMismatch {
op: "positional encoding",
expected: format!("at most {} sequence rows", self.max_sequence_len.value()),
got: format!("{} sequence rows", input.sequence_len().value()),
});
}
if input.model_dimension() != self.model_dimension {
return Err(CtError::ShapeMismatch {
op: "positional encoding",
expected: format!("model dimension {}", self.model_dimension.value()),
got: format!("model dimension {}", input.model_dimension().value()),
});
}
let rows = input
.rows()
.iter()
.zip(&self.rows)
.map(|(hidden_row, position_row)| {
add_rows(hidden_row.as_slice(), position_row.as_slice())
})
.collect::<Vec<_>>();
HiddenSequence::new(rows)
}
}
/// Maps each hidden row to vocabulary logits with one affine projection.
impl Morphism<HiddenSequence, SequenceLogits> for TransformerReadout {
fn name(&self) -> &'static str {
"transformer_readout"
}
fn apply(&self, input: HiddenSequence) -> CtResult<SequenceLogits> {
if input.model_dimension() != self.input_dimension {
return Err(CtError::ShapeMismatch {
op: "transformer readout",
expected: format!("input dimension {}", self.input_dimension.value()),
got: format!("input dimension {}", input.model_dimension().value()),
});
}
let rows = input
.rows()
.iter()
.map(|row| project_row(row.as_slice(), &self.weight, &self.bias))
.collect::<Vec<_>>();
SequenceLogits::new(rows)
}
}
impl Morphism<HiddenSequence, QuerySequence> for HiddenToQuery {
fn name(&self) -> &'static str {
"hidden_to_query"
}
fn apply(&self, input: HiddenSequence) -> CtResult<QuerySequence> {
QuerySequence::new(self.projection.project(&input)?)
}
}
impl Morphism<HiddenSequence, KeySequence> for HiddenToKey {
fn name(&self) -> &'static str {
"hidden_to_key"
}
fn apply(&self, input: HiddenSequence) -> CtResult<KeySequence> {
KeySequence::new(self.projection.project(&input)?)
}
}
impl Morphism<HiddenSequence, ValueSequence> for HiddenToValue {
fn name(&self) -> &'static str {
"hidden_to_value"
}
fn apply(&self, input: HiddenSequence) -> CtResult<ValueSequence> {
ValueSequence::new(self.projection.project(&input)?)
}
}
fn apply_self_attention_head(
input: &HiddenSequence,
head: &SelfAttentionHead,
) -> CtResult<AttentionOutput> {
apply_self_attention_head_with_mask(input, head, None)
}
fn apply_self_attention_head_with_mask(
input: &HiddenSequence,
head: &SelfAttentionHead,
mask: Option<&AttentionMask>,
) -> CtResult<AttentionOutput> {
Ok(apply_self_attention_head_with_mask_cache(input, head, mask)?.output)
}
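/// Runs one attention head (query/key/value projections, scaled dot-product scores, optional mask, row softmax, value mixing) and keeps the intermediates needed by the backward pass.
fn apply_self_attention_head_with_mask_cache(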
input: &HiddenSequence,
head: &SelfAttentionHead,
mask: Option<&AttentionMask>,
) -> CtResult<AttentionHeadTrainingCache> {
let queries = head.query_projection.apply(input.clone())?;
let keys = head.key_projection.apply(input.clone())?;
let values = head.value_projection.apply(input.clone())?;
let scores = ScaledDotProductScores.apply(Product::new(queries.clone(), keys.clone()))?;
let scores = if let Some(mask) = mask {
MaskedAttentionScores.apply(Product::new(scores, mask.clone()))?
} else {
scores
};
let weights = AttentionSoftmax.apply(scores)?;
let output = WeightedValueMixing.apply(Product::new(weights.clone(), values.clone()))?;
Ok(AttentionHeadTrainingCache {
queries,
keys,
values,
weights,
output,
})
}
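// Illustrative sketch (hypothetical module, already-projected rows as nested
// `Vec`s): the pipeline of `apply_self_attention_head_with_mask_cache` with a
// causal mask. Scores are q.k / sqrt(d); the mask is assumed to send future keys
// to negative infinity before the row softmax (the crate's `AttentionMask` and
// `MaskedAttentionScores` may encode this differently); values are then mixed
// with the resulting weights. Inputs are assumed non-empty.
#[allow(dead_code)]
mod masked_attention_sketch {
    pub fn causal_attention(queries: &[Vec<f32>], keys: &[Vec<f32>], values: &[Vec<f32>]) -> Vec<Vec<f32>> {
        let scale = 1.0 / (keys[0].len() as f32).sqrt();
        queries
            .iter()
            .enumerate()
            .map(|(i, q)| {
                // Query i may attend to keys 0..=i only.
                let scores: Vec<f32> = keys
                    .iter()
                    .enumerate()
                    .map(|(j, k)| {
                        if j > i {
                            f32::NEG_INFINITY
                        } else {
                            q.iter().zip(k).map(|(a, b)| a * b).sum::<f32>() * scale
                        }
                    })
                    .collect();
                // Row softmax; exp(-inf) = 0 removes masked keys.
                let max = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
                let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
                let sum: f32 = exps.iter().sum();
                // Weighted value mixing.
                let mut out = vec![0.0_f32; values[0].len()];
                for (w, v) in exps.iter().zip(values) {
                    for (o, x) in out.iter_mut().zip(v) {
                        *o += w / sum * x;
                    }
                }
                out
            })
            .collect()
    }
}
// The first position can only see itself, so its output equals the first value row.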
impl Morphism<Product<HiddenSequence, HiddenSequence>, HiddenSequence> for ResidualConnection {
fn name(&self) -> &'static str {
"hidden_residual_connection"
}
fn apply(&self, input: Product<HiddenSequence, HiddenSequence>) -> CtResult<HiddenSequence> {
let (left, right) = input.into_parts();
if left.sequence_len() != right.sequence_len() {
return Err(CtError::ShapeMismatch {
op: "hidden residual connection",
expected: format!("{} sequence rows", left.sequence_len().value()),
got: format!("{} sequence rows", right.sequence_len().value()),
});
}
if left.model_dimension() != right.model_dimension() {
return Err(CtError::ShapeMismatch {
op: "hidden residual connection",
expected: format!("model dimension {}", left.model_dimension().value()),
got: format!("model dimension {}", right.model_dimension().value()),
});
}
let rows = left
.rows()
.iter()
.zip(right.rows())
.map(|(left, right)| add_rows(left.as_slice(), right.as_slice()))
.collect::<Vec<_>>();
HiddenSequence::new(rows)
}
}
impl Morphism<HiddenSequence, HiddenSequence> for SingleHeadTransformerBlock {
fn name(&self) -> &'static str {
"single_head_transformer_block"
}
fn apply(&self, input: HiddenSequence) -> CtResult<HiddenSequence> {
if input.model_dimension() != self.model_dimension {
return Err(CtError::ShapeMismatch {
op: "single-head block",
expected: format!("model dimension {}", self.model_dimension.value()),
got: format!("model dimension {}", input.model_dimension().value()),
});
}
let head = SelfAttentionHead::new(
self.query_projection.clone(),
self.key_projection.clone(),
self.value_projection.clone(),
)?;
let attention_output = apply_self_attention_head(&input, &head)?;
let head_outputs = AttentionHeadOutputs::new(vec![attention_output])?;
let multi_head_output = ConcatenateHeads.apply(head_outputs)?;
let projected_attention = self.output_projection.apply(multi_head_output)?;
let with_attention = ResidualConnection.apply(Product::new(input, projected_attention))?;
let normalized_attention = self.attention_norm.apply(with_attention)?;
let feed_forward_output = self.feed_forward.apply(normalized_attention.clone())?;
let with_feed_forward =
ResidualConnection.apply(Product::new(normalized_attention, feed_forward_output))?;
self.feed_forward_norm.apply(with_feed_forward)
}
}
impl Morphism<HiddenSequence, HiddenSequence> for MultiHeadTransformerBlock {
fn name(&self) -> &'static str {
"multi_head_transformer_block"
}
fn apply(&self, input: HiddenSequence) -> CtResult<HiddenSequence> {
if input.model_dimension() != self.model_dimension {
return Err(CtError::ShapeMismatch {
op: "multi-head block",
expected: format!("model dimension {}", self.model_dimension.value()),
got: format!("model dimension {}", input.model_dimension().value()),
});
}
let attention_outputs = self
.heads
.iter()
.map(|head| apply_self_attention_head(&input, head))
.collect::<CtResult<Vec<_>>>()?;
let head_outputs = AttentionHeadOutputs::new(attention_outputs)?;
let multi_head_output = ConcatenateHeads.apply(head_outputs)?;
let projected_attention = self.output_projection.apply(multi_head_output)?;
let with_attention = ResidualConnection.apply(Product::new(input, projected_attention))?;
let normalized_attention = self.attention_norm.apply(with_attention)?;
let feed_forward_output = self.feed_forward.apply(normalized_attention.clone())?;
let with_feed_forward =
ResidualConnection.apply(Product::new(normalized_attention, feed_forward_output))?;
self.feed_forward_norm.apply(with_feed_forward)
}
}
impl Morphism<Product<HiddenSequence, AttentionMask>, HiddenSequence>
for MaskedMultiHeadTransformerBlock
{
fn name(&self) -> &'static str {
"masked_multi_head_transformer_block"
}
fn apply(&self, input: Product<HiddenSequence, AttentionMask>) -> CtResult<HiddenSequence> {
let (hidden, mask) = input.into_parts();
Ok(self.apply_with_training_cache(hidden, mask)?.output)
}
}
impl Morphism<Product<HiddenSequence, AttentionMask>, SequenceLogits>
for TinyTransformerParameters
{
fn name(&self) -> &'static str {
"tiny_transformer_parameters"
}
fn apply(&self, input: Product<HiddenSequence, AttentionMask>) -> CtResult<SequenceLogits> {
let (hidden, mask) = input.into_parts();
let positioned = self.positional_encoding.apply(hidden)?;
let encoded = self.block.apply(Product::new(positioned, mask))?;
self.readout.apply(encoded)
}
}
impl Morphism<Product<HiddenSequence, AttentionMask>, SequenceLogits> for TransformerTrainingState {
fn name(&self) -> &'static str {
"transformer_training_state_forward"
}
fn apply(&self, input: Product<HiddenSequence, AttentionMask>) -> CtResult<SequenceLogits> {
self.parameters.apply(input)
}
}
#[derive(Debug, Clone, PartialEq)]
struct ReadoutGradients {
weight: Vec<Vec<f32>>,
bias: Vec<f32>,
}
impl ReadoutGradients {
fn new(readout: &TransformerReadout) -> Self {
Self {
weight: vec![
vec![0.0; readout.vocab_size().value()];
readout.input_dimension().value()
],
bias: vec![0.0; readout.vocab_size().value()],
}
}
fn accumulate(
&mut self,
readout: &TransformerReadout,
hidden_row: &[f32],
dlogits: &[f32],
) -> Vec<f32> {
let mut d_hidden = vec![0.0; readout.input_dimension().value()];
for (vocab_id, dlogit) in dlogits.iter().copied().enumerate() {
self.bias[vocab_id] += dlogit;
for (feature, hidden_value) in hidden_row.iter().copied().enumerate() {
self.weight[feature][vocab_id] += hidden_value * dlogit;
d_hidden[feature] += readout.weight()[feature][vocab_id] * dlogit;
}
}
d_hidden
}
fn apply_to(
self,
readout: &TransformerReadout,
learning_rate: LearningRate,
position_count: usize,
) -> CtResult<TransformerReadout> {
if position_count == 0 {
return Err(CtError::EmptyInput("readout gradient positions"));
}
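// Gradients were summed over positions; dividing the learning rate by position_count
// makes this one gradient-descent step on the mean gradient.
let scale = learning_rate.value() / position_count as f32;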
let mut updated_weight = readout.weight().to_vec();
let mut updated_bias = readout.bias().to_vec();
for (row, grad_row) in updated_weight.iter_mut().zip(&self.weight) {
for (weight, grad) in row.iter_mut().zip(grad_row) {
*weight -= scale * grad;
}
}
for (bias, grad) in updated_bias.iter_mut().zip(&self.bias) {
*bias -= scale * grad;
}
TransformerReadout::new(updated_weight, updated_bias)
}
}
#[derive(Debug, Clone, PartialEq)]
struct FeedForwardGradients {
first_weight: Vec<Vec<f32>>,
first_bias: Vec<f32>,
second_weight: Vec<Vec<f32>>,
second_bias: Vec<f32>,
}
impl FeedForwardGradients {
fn new(feed_forward: &PositionWiseFeedForward) -> Self {
Self {
first_weight: vec![
vec![0.0; feed_forward.hidden_dimension().value()];
feed_forward.input_dimension().value()
],
first_bias: vec![0.0; feed_forward.hidden_dimension().value()],
second_weight: vec![
vec![0.0; feed_forward.output_dimension().value()];
feed_forward.hidden_dimension().value()
],
second_bias: vec![0.0; feed_forward.output_dimension().value()],
}
}
fn accumulate(
&mut self,
feed_forward: &PositionWiseFeedForward,
cache: &FeedForwardRowCache,
d_output: &[f32],
) -> Vec<f32> {
let mut d_activation = vec![0.0; feed_forward.hidden_dimension().value()];
for (output_id, d_output_value) in d_output.iter().copied().enumerate() {
self.second_bias[output_id] += d_output_value;
for (hidden_id, activation_value) in cache.activation.iter().copied().enumerate() {
self.second_weight[hidden_id][output_id] += activation_value * d_output_value;
d_activation[hidden_id] +=
feed_forward.second_weight()[hidden_id][output_id] * d_output_value;
}
}
let mut d_input = vec![0.0; feed_forward.input_dimension().value()];
// ReLU backward: the gradient passes only where the pre-activation was positive.
for (hidden_id, pre_activation_value) in cache.pre_activation.iter().copied().enumerate() {
let d_pre_activation = if pre_activation_value > 0.0 {
d_activation[hidden_id]
} else {
0.0
};
self.first_bias[hidden_id] += d_pre_activation;
for (input_id, input_value) in cache.input.iter().copied().enumerate() {
self.first_weight[input_id][hidden_id] += input_value * d_pre_activation;
d_input[input_id] +=
feed_forward.first_weight()[input_id][hidden_id] * d_pre_activation;
}
}
d_input
}
fn apply_to(
self,
feed_forward: &PositionWiseFeedForward,
learning_rate: LearningRate,
row_count: usize,
) -> CtResult<PositionWiseFeedForward> {
if row_count == 0 {
return Err(CtError::EmptyInput("feed-forward gradient rows"));
}
let scale = learning_rate.value() / row_count as f32;
let mut updated_first_weight = feed_forward.first_weight().to_vec();
let mut updated_first_bias = feed_forward.first_bias().to_vec();
let mut updated_second_weight = feed_forward.second_weight().to_vec();
let mut updated_second_bias = feed_forward.second_bias().to_vec();
for (row, grad_row) in updated_first_weight.iter_mut().zip(&self.first_weight) {
for (weight, grad) in row.iter_mut().zip(grad_row) {
*weight -= scale * grad;
}
}
for (bias, grad) in updated_first_bias.iter_mut().zip(&self.first_bias) {
*bias -= scale * grad;
}
for (row, grad_row) in updated_second_weight.iter_mut().zip(&self.second_weight) {
for (weight, grad) in row.iter_mut().zip(grad_row) {
*weight -= scale * grad;
}
}
for (bias, grad) in updated_second_bias.iter_mut().zip(&self.second_bias) {
*bias -= scale * grad;
}
PositionWiseFeedForward::new(
updated_first_weight,
updated_first_bias,
updated_second_weight,
updated_second_bias,
)
}
}
#[derive(Debug, Clone, PartialEq)]
struct AttentionOutputProjectionGradients {
weight: Vec<Vec<f32>>,
bias: Vec<f32>,
}
impl AttentionOutputProjectionGradients {
fn new(output_projection: &AttentionOutputProjection) -> Self {
Self {
weight: vec![
vec![0.0; output_projection.output_dimension().value()];
output_projection.input_dimension().value()
],
bias: vec![0.0; output_projection.output_dimension().value()],
}
}
fn accumulate(
&mut self,
output_projection: &AttentionOutputProjection,
multi_head_row: &[f32],
d_projected_attention: &[f32],
) -> Vec<f32> {
let mut d_multi_head = vec![0.0; output_projection.input_dimension().value()];
for (output_id, d_value) in d_projected_attention.iter().copied().enumerate() {
self.bias[output_id] += d_value;
for (input_id, input_value) in multi_head_row.iter().copied().enumerate() {
self.weight[input_id][output_id] += input_value * d_value;
d_multi_head[input_id] += output_projection.weight()[input_id][output_id] * d_value;
}
}
d_multi_head
}
fn apply_to(
self,
output_projection: &AttentionOutputProjection,
learning_rate: LearningRate,
row_count: usize,
) -> CtResult<AttentionOutputProjection> {
if row_count == 0 {
return Err(CtError::EmptyInput(
"attention output projection gradient rows",
));
}
let scale = learning_rate.value() / row_count as f32;
let mut updated_weight = output_projection.weight().to_vec();
let mut updated_bias = output_projection.bias().to_vec();
for (row, grad_row) in updated_weight.iter_mut().zip(&self.weight) {
for (weight, grad) in row.iter_mut().zip(grad_row) {
*weight -= scale * grad;
}
}
for (bias, grad) in updated_bias.iter_mut().zip(&self.bias) {
*bias -= scale * grad;
}
AttentionOutputProjection::new(updated_weight, updated_bias)
}
}
#[derive(Debug, Clone, PartialEq)]
struct HiddenProjectionGradients {
weight: Vec<Vec<f32>>,
bias: Vec<f32>,
}
impl HiddenProjectionGradients {
fn new(projection: &HiddenProjection) -> Self {
Self {
weight: vec![
vec![0.0; projection.head_dimension().value()];
projection.input_dimension().value()
],
bias: vec![0.0; projection.head_dimension().value()],
}
}
fn accumulate(
&mut self,
projection: &HiddenProjection,
hidden_row: &[f32],
d_output: &[f32],
) -> CtResult<()> {
if hidden_row.len() != projection.input_dimension().value() {
return Err(CtError::ShapeMismatch {
op: projection.op,
expected: format!("input dimension {}", projection.input_dimension().value()),
got: format!("input dimension {}", hidden_row.len()),
});
}
if d_output.len() != projection.head_dimension().value() {
return Err(CtError::ShapeMismatch {
op: projection.op,
expected: format!("output dimension {}", projection.head_dimension().value()),
got: format!("output dimension {}", d_output.len()),
});
}
for (output_id, d_output_value) in d_output.iter().copied().enumerate() {
self.bias[output_id] += d_output_value;
for (input_id, input_value) in hidden_row.iter().copied().enumerate() {
self.weight[input_id][output_id] += input_value * d_output_value;
}
}
Ok(())
}
fn updated_parts(
self,
projection: &HiddenProjection,
learning_rate: LearningRate,
row_count: usize,
) -> CtResult<(Vec<Vec<f32>>, Vec<f32>)> {
if row_count == 0 {
return Err(CtError::EmptyInput("hidden projection gradient rows"));
}
let scale = learning_rate.value() / row_count as f32;
let mut updated_weight = projection.weight().to_vec();
let mut updated_bias = projection.bias().to_vec();
for (row, grad_row) in updated_weight.iter_mut().zip(&self.weight) {
for (weight, grad) in row.iter_mut().zip(grad_row) {
*weight -= scale * grad;
}
}
for (bias, grad) in updated_bias.iter_mut().zip(&self.bias) {
*bias -= scale * grad;
}
Ok((updated_weight, updated_bias))
}
}
#[derive(Debug, Clone, PartialEq)]
struct AttentionHeadGradients {
query: HiddenProjectionGradients,
key: HiddenProjectionGradients,
value: HiddenProjectionGradients,
}
impl AttentionHeadGradients {
fn new(head: &SelfAttentionHead) -> Self {
Self {
query: HiddenProjectionGradients::new(&head.query_projection.projection),
key: HiddenProjectionGradients::new(&head.key_projection.projection),
value: HiddenProjectionGradients::new(&head.value_projection.projection),
}
}
fn accumulate(
&mut self,
head: &SelfAttentionHead,
input: &HiddenSequence,
cache: &AttentionHeadTrainingCache,
mask: &AttentionMask,
d_output_rows: &[Vec<f32>],
) -> CtResult<()> {
let sequence_len = input.sequence_len().value();
let value_dimension = head.value_dimension().value();
let query_key_dimension = head.query_key_dimension().value();
if d_output_rows.len() != sequence_len {
return Err(CtError::ShapeMismatch {
op: "attention head gradients",
expected: format!("{sequence_len} output rows"),
got: format!("{} output rows", d_output_rows.len()),
});
}
if mask.query_len().value() != sequence_len || mask.key_len().value() != sequence_len {
return Err(CtError::ShapeMismatch {
op: "attention head gradient mask",
expected: format!("{sequence_len} query rows x {sequence_len} key columns"),
got: format!(
"{} query rows x {} key columns",
mask.query_len().value(),
mask.key_len().value()
),
});
}
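// Each output row is sum_k weight[q][k] * value[k], so d_weight[q][k] = d_output[q] . value[k]
// and d_value[k] accumulates weight[q][k] * d_output[q].
let mut d_weights = vec![vec![0.0; sequence_len]; sequence_len];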
let mut d_values = vec![vec![0.0; value_dimension]; sequence_len];
for (query_id, d_output) in d_output_rows.iter().enumerate() {
if d_output.len() != value_dimension {
return Err(CtError::ShapeMismatch {
op: "attention head output gradient",
expected: format!("value dimension {value_dimension}"),
got: format!("value dimension {}", d_output.len()),
});
}
for key_id in 0..sequence_len {
let value_row = cache.values.rows()[key_id].as_slice();
let weight = cache.weights.rows()[query_id].as_slice()[key_id];
for value_id in 0..value_dimension {
d_weights[query_id][key_id] += d_output[value_id] * value_row[value_id];
d_values[key_id][value_id] += weight * d_output[value_id];
}
}
}
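// Softmax backward for each query row: d_score = w * (d_weight - sum_k d_weight_k * w_k).
// Masked keys have zero weight and keep a zero score gradient.
let mut d_scores = vec![vec![0.0; sequence_len]; sequence_len];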
for query_id in 0..sequence_len {
let weight_row = cache.weights.rows()[query_id].as_slice();
let row_dot = d_weights[query_id]
.iter()
.zip(weight_row)
.map(|(grad, weight)| grad * weight)
.sum::<f32>();
for key_id in 0..sequence_len {
if mask.rows()[query_id][key_id] {
d_scores[query_id][key_id] =
weight_row[key_id] * (d_weights[query_id][key_id] - row_dot);
}
}
}
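// The forward pass divided scores by sqrt(query_key_dimension), so score gradients
// are divided by the same factor before reaching the query and key rows.
let score_scale = (query_key_dimension as f32).sqrt();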
let mut d_queries = vec![vec![0.0; query_key_dimension]; sequence_len];
let mut d_keys = vec![vec![0.0; query_key_dimension]; sequence_len];
for query_id in 0..sequence_len {
let query_row = cache.queries.rows()[query_id].as_slice();
for key_id in 0..sequence_len {
let score_gradient = d_scores[query_id][key_id] / score_scale;
let key_row = cache.keys.rows()[key_id].as_slice();
for feature in 0..query_key_dimension {
d_queries[query_id][feature] += score_gradient * key_row[feature];
d_keys[key_id][feature] += score_gradient * query_row[feature];
}
}
}
for position in 0..sequence_len {
let hidden_row = input.rows()[position].as_slice();
self.query.accumulate(
&head.query_projection.projection,
hidden_row,
&d_queries[position],
)?;
self.key.accumulate(
&head.key_projection.projection,
hidden_row,
&d_keys[position],
)?;
self.value.accumulate(
&head.value_projection.projection,
hidden_row,
&d_values[position],
)?;
}
Ok(())
}
fn apply_to(
self,
head: &SelfAttentionHead,
learning_rate: LearningRate,
row_count: usize,
) -> CtResult<SelfAttentionHead> {
let (query_weight, query_bias) = self.query.updated_parts(
&head.query_projection.projection,
learning_rate,
row_count,
)?;
let (key_weight, key_bias) =
self.key
.updated_parts(&head.key_projection.projection, learning_rate, row_count)?;
let (value_weight, value_bias) = self.value.updated_parts(
&head.value_projection.projection,
learning_rate,
row_count,
)?;
SelfAttentionHead::new(
HiddenToQuery::new(query_weight, query_bias)?,
HiddenToKey::new(key_weight, key_bias)?,
HiddenToValue::new(value_weight, value_bias)?,
)
}
}
#[derive(Debug, Clone, PartialEq)]
struct LayerNormGradients {
scale: Vec<f32>,
shift: Vec<f32>,
}
impl LayerNormGradients {
fn new(parameters: &LayerNormParameters) -> Self {
Self {
scale: vec![0.0; parameters.model_dimension().value()],
shift: vec![0.0; parameters.model_dimension().value()],
}
}
fn accumulate(
&mut self,
d_output: &[f32],
input: &[f32],
parameters: &LayerNormParameters,
) -> Vec<f32> {
let stats = layer_norm_stats(input, parameters);
for (feature, (grad, normalized_value)) in
d_output.iter().zip(&stats.normalized).enumerate()
{
self.scale[feature] += grad * normalized_value;
self.shift[feature] += grad;
}
layer_norm_backward_from_stats(d_output, &stats, parameters)
}
fn apply_to(
self,
parameters: &LayerNormParameters,
learning_rate: LearningRate,
row_count: usize,
) -> CtResult<LayerNormParameters> {
if row_count == 0 {
return Err(CtError::EmptyInput("layer norm gradient rows"));
}
let scale = learning_rate.value() / row_count as f32;
let mut updated_scale = parameters.scale().to_vec();
let mut updated_shift = parameters.shift().to_vec();
for (value, grad) in updated_scale.iter_mut().zip(&self.scale) {
*value -= scale * grad;
}
for (value, grad) in updated_shift.iter_mut().zip(&self.shift) {
*value -= scale * grad;
}
LayerNormParameters::new(updated_scale, updated_shift, parameters.epsilon())
}
}
fn feed_forward_with_cache(
feed_forward: &PositionWiseFeedForward,
input: &HiddenSequence,
) -> CtResult<(HiddenSequence, Vec<FeedForwardRowCache>)> {
if input.model_dimension() != feed_forward.input_dimension() {
return Err(CtError::ShapeMismatch {
op: "position-wise feed-forward",
expected: format!("input dimension {}", feed_forward.input_dimension().value()),
got: format!("input dimension {}", input.model_dimension().value()),
});
}
let mut output_rows = Vec::with_capacity(input.rows().len());
let mut cache_rows = Vec::with_capacity(input.rows().len());
for row in input.rows() {
let input_row = row.as_slice().to_vec();
let pre_activation = project_row(
&input_row,
feed_forward.first_weight(),
feed_forward.first_bias(),
);
let activation = pre_activation
.iter()
.map(|value| value.max(0.0))
.collect::<Vec<_>>();
let output = project_row(
&activation,
feed_forward.second_weight(),
feed_forward.second_bias(),
);
output_rows.push(output.clone());
cache_rows.push(FeedForwardRowCache {
input: input_row,
pre_activation,
activation,
output,
});
}
Ok((HiddenSequence::new(output_rows)?, cache_rows))
}
fn softmax_cross_entropy_logits_gradient(
logits: &Logits,
target_index: usize,
vocab_size: usize,
) -> CtResult<Vec<f32>> {
if target_index >= vocab_size {
return Err(CtError::OutOfRange {
kind: "sequence target",
index: target_index,
limit: vocab_size,
});
}
// For L = -log softmax(z)[target], dL/dz = softmax(z) - one_hot(target).
let probabilities = Softmax.apply(logits.clone())?;
let mut dlogits = probabilities.as_slice().to_vec();
dlogits[target_index] -= 1.0;
Ok(dlogits)
}
#[derive(Debug, Clone, PartialEq)]
struct LayerNormStats {
dimension: f32,
inverse_std: f32,
normalized: Vec<f32>,
}
fn layer_norm_stats(input: &[f32], parameters: &LayerNormParameters) -> LayerNormStats {
let dimension = input.len() as f32;
let mean = input.iter().sum::<f32>() / dimension;
let variance = input
.iter()
.map(|value| {
let centered = value - mean;
centered * centered
})
.sum::<f32>()
/ dimension;
let inverse_std = 1.0 / (variance + parameters.epsilon().value()).sqrt();
let normalized = input
.iter()
.map(|value| (value - mean) * inverse_std)
.collect::<Vec<_>>();
LayerNormStats {
dimension,
inverse_std,
normalized,
}
}
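// Input gradient of layer normalization, given d_output and the cached forward statistics:
// dx = inverse_std / n * (n * dx_hat - sum(dx_hat) - x_hat * sum(dx_hat * x_hat)),
// where dx_hat = d_output * scale and x_hat is the normalized input.
fn layer_norm_backward_from_stats(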
d_output: &[f32],
stats: &LayerNormStats,
parameters: &LayerNormParameters,
) -> Vec<f32> {
let d_normalized = d_output
.iter()
.zip(parameters.scale())
.map(|(grad, scale)| grad * scale)
.collect::<Vec<_>>();
let sum_d_normalized = d_normalized.iter().sum::<f32>();
let sum_d_normalized_times_normalized = d_normalized
.iter()
.zip(&stats.normalized)
.map(|(grad, normalized_value)| grad * normalized_value)
.sum::<f32>();
d_normalized
.iter()
.zip(&stats.normalized)
.map(|(grad, normalized_value)| {
(stats.dimension * grad
- sum_d_normalized
- normalized_value * sum_d_normalized_times_normalized)
* stats.inverse_std
/ stats.dimension
})
.collect()
}
fn validate_projection_input(
op: &'static str,
expected: ModelDimension,
got: ModelDimension,
) -> CtResult<()> {
if expected != got {
return Err(CtError::ShapeMismatch {
op,
expected: format!("model dimension {}", expected.value()),
got: format!("model dimension {}", got.value()),
});
}
Ok(())
}
fn add_rows(left: &[f32], right: &[f32]) -> Vec<f32> {
left.iter()
.zip(right.iter())
.map(|(left, right)| left + right)
.collect()
}
fn validate_linear_parts(
op: &'static str,
weight: &[Vec<f32>],
bias: &[f32],
) -> CtResult<(ModelDimension, ModelDimension)> {
if weight.is_empty() {
return Err(CtError::EmptyInput("linear weight"));
}
if bias.is_empty() {
return Err(CtError::EmptyInput("linear bias"));
}
let output_dimension = bias.len();
if bias.iter().any(|value| !value.is_finite()) {
return Err(CtError::ShapeMismatch {
op,
expected: "finite bias values".to_string(),
got: "non-finite bias value".to_string(),
});
}
for row in weight {
if row.len() != output_dimension {
return Err(CtError::ShapeMismatch {
op,
expected: format!("weight rows have {output_dimension} columns"),
got: format!("weight row with {} columns", row.len()),
});
}
if row.iter().any(|value| !value.is_finite()) {
return Err(CtError::ShapeMismatch {
op,
expected: "finite weight values".to_string(),
got: "non-finite weight value".to_string(),
});
}
}
Ok((
ModelDimension::new(weight.len())?,
ModelDimension::new(output_dimension)?,
))
}
fn normalize_row(input: &[f32], parameters: &LayerNormParameters) -> Vec<f32> {
let mean = input.iter().sum::<f32>() / input.len() as f32;
let variance = input
.iter()
.map(|value| {
let centered = value - mean;
centered * centered
})
.sum::<f32>()
/ input.len() as f32;
let denominator = (variance + parameters.epsilon.value()).sqrt();
input
.iter()
.zip(parameters.scale.iter().zip(&parameters.shift))
.map(|(value, (scale, shift))| ((value - mean) / denominator) * scale + shift)
.collect()
}
fn project_row(input: &[f32], weight: &[Vec<f32>], bias: &[f32]) -> Vec<f32> {
let mut output = bias.to_vec();
for (feature, input_value) in input.iter().enumerate() {
for (output_value, weight_value) in output.iter_mut().zip(&weight[feature]) {
*output_value += input_value * weight_value;
}
}
output
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn scaled_dot_product_scores_build_query_by_key_rows() -> CtResult<()> {
let queries = QuerySequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
let keys = KeySequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]])?;
let scores = ScaledDotProductScores.apply(Product::new(queries, keys))?;
assert_eq!(scores.query_len().value(), 2);
assert_eq!(scores.key_len().value(), 3);
assert!(crate::domain::approx_eq(
scores.rows()[0].as_slice()[0],
std::f32::consts::FRAC_1_SQRT_2,
1e-4
));
assert!(crate::domain::approx_eq(
scores.rows()[0].as_slice()[1],
0.0,
1e-4
));
assert!(crate::domain::approx_eq(
scores.rows()[1].as_slice()[2],
std::f32::consts::FRAC_1_SQRT_2,
1e-4
));
Ok(())
}
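    // Gradient-check sketch for the score step in `AttentionHeadGradients::accumulate`.
    // This is illustrative and not part of the original listing. For one query row the
    // scores are s_k = (q . k_k) / sqrt(d), so an upstream gradient ds_k adds
    // ds_k * k_k / sqrt(d) to dq and ds_k * q / sqrt(d) to dk_k. The `reference_*`
    // helpers are hypothetical plain-slice restatements, not crate API.
    fn reference_scaled_scores(query: &[f32], keys: &[Vec<f32>]) -> Vec<f32> {
        let scale = (query.len() as f32).sqrt();
        keys.iter()
            .map(|key| query.iter().zip(key).map(|(q, k)| q * k).sum::<f32>() / scale)
            .collect()
    }

    fn reference_scaled_scores_backward(
        query: &[f32],
        keys: &[Vec<f32>],
        d_scores: &[f32],
    ) -> (Vec<f32>, Vec<Vec<f32>>) {
        let scale = (query.len() as f32).sqrt();
        let mut d_query = vec![0.0; query.len()];
        let mut d_keys = vec![vec![0.0; query.len()]; keys.len()];
        for (key_id, key) in keys.iter().enumerate() {
            // Same division by sqrt(d) as `score_gradient` in the listing.
            let score_gradient = d_scores[key_id] / scale;
            for feature in 0..query.len() {
                d_query[feature] += score_gradient * key[feature];
                d_keys[key_id][feature] += score_gradient * query[feature];
            }
        }
        (d_query, d_keys)
    }

    #[test]
    fn scaled_score_backward_matches_finite_differences() {
        let query = vec![0.5_f32, -1.0];
        let keys = vec![vec![1.0_f32, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];
        let upstream = [0.2_f32, -0.4, 1.0];
        // Scalar loss L(q) = sum_k upstream_k * s_k(q).
        let loss = |q: &[f32]| -> f32 {
            reference_scaled_scores(q, &keys)
                .iter()
                .zip(&upstream)
                .map(|(s, g)| s * g)
                .sum()
        };
        let (d_query, _) = reference_scaled_scores_backward(&query, &keys, &upstream);
        let h = 1e-2_f32;
        for feature in 0..query.len() {
            let mut plus = query.clone();
            plus[feature] += h;
            let mut minus = query.clone();
            minus[feature] -= h;
            let numeric = (loss(&plus[..]) - loss(&minus[..])) / (2.0 * h);
            assert!((numeric - d_query[feature]).abs() < 1e-3);
        }
    }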
#[test]
fn weighted_value_mixing_builds_one_output_per_query() -> CtResult<()> {
let queries = QuerySequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
let keys = KeySequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]])?;
let values = ValueSequence::new(vec![vec![1.0, 10.0], vec![2.0, 20.0], vec![3.0, 30.0]])?;
let scores = ScaledDotProductScores.apply(Product::new(queries, keys))?;
let weights = AttentionSoftmax.apply(scores)?;
let output = WeightedValueMixing.apply(Product::new(weights, values))?;
assert_eq!(output.sequence_len().value(), 2);
assert_eq!(output.head_dimension().value(), 2);
assert!(crate::domain::approx_eq(
output.rows()[0].as_slice()[0],
2.0,
1e-4
));
assert!(crate::domain::approx_eq(
output.rows()[0].as_slice()[1],
20.0,
1e-4
));
Ok(())
}
#[test]
fn concatenate_heads_preserves_sequence_and_concatenates_features() -> CtResult<()> {
let head_a = AttentionOutput::new(vec![vec![1.0, 10.0], vec![2.0, 20.0]])?;
let head_b = AttentionOutput::new(vec![vec![3.0, 30.0], vec![4.0, 40.0]])?;
let heads = AttentionHeadOutputs::new(vec![head_a, head_b])?;
let output = ConcatenateHeads.apply(heads)?;
assert_eq!(output.sequence_len().value(), 2);
assert_eq!(output.head_count().value(), 2);
assert_eq!(output.head_dimension().value(), 2);
assert_eq!(output.model_dimension().value(), 4);
assert_eq!(output.rows()[0].as_slice(), &[1.0, 10.0, 3.0, 30.0]);
assert_eq!(output.rows()[1].as_slice(), &[2.0, 20.0, 4.0, 40.0]);
Ok(())
}
#[test]
fn attention_output_projection_maps_multi_head_rows() -> CtResult<()> {
let head_a = AttentionOutput::new(vec![vec![1.0, 10.0], vec![2.0, 20.0]])?;
let head_b = AttentionOutput::new(vec![vec![3.0, 30.0], vec![4.0, 40.0]])?;
let multi_head =
ConcatenateHeads.apply(AttentionHeadOutputs::new(vec![head_a, head_b])?)?;
let projection = AttentionOutputProjection::new(
vec![
vec![1.0, 0.0],
vec![0.0, 0.1],
vec![0.5, 0.0],
vec![0.0, 0.01],
],
vec![0.0, 1.0],
)?;
let output = projection.apply(multi_head)?;
assert_eq!(output.sequence_len().value(), 2);
assert_eq!(output.model_dimension().value(), 2);
assert!(crate::domain::approx_eq(
output.rows()[0].as_slice()[0],
2.5,
1e-4
));
assert!(crate::domain::approx_eq(
output.rows()[0].as_slice()[1],
2.3,
1e-4
));
assert!(crate::domain::approx_eq(
output.rows()[1].as_slice()[0],
4.0,
1e-4
));
assert!(crate::domain::approx_eq(
output.rows()[1].as_slice()[1],
3.4,
1e-4
));
Ok(())
}
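    // Gradient-check sketch for the affine backward rule. This is illustrative and not
    // part of the original listing. The weight and bias rule is shared by
    // `ReadoutGradients`, `AttentionOutputProjectionGradients`, and
    // `HiddenProjectionGradients`; the first two also return dx. Weights are stored
    // input-major (weight[input][output]) as in `project_row`, so y = b + x W and
    // db = dy, dW[i][o] = x[i] * dy[o], dx[i] = sum_o W[i][o] * dy[o].
    // Hypothetical plain-slice helpers, not crate API.
    fn reference_affine_forward(input: &[f32], weight: &[Vec<f32>], bias: &[f32]) -> Vec<f32> {
        let mut output = bias.to_vec();
        for (input_value, weight_row) in input.iter().zip(weight) {
            for (output_value, weight_value) in output.iter_mut().zip(weight_row) {
                *output_value += input_value * weight_value;
            }
        }
        output
    }

    fn reference_affine_backward(
        input: &[f32],
        weight: &[Vec<f32>],
        d_output: &[f32],
    ) -> (Vec<Vec<f32>>, Vec<f32>, Vec<f32>) {
        // Outer product of the input row and the upstream gradient.
        let d_weight: Vec<Vec<f32>> = input
            .iter()
            .map(|x| d_output.iter().map(|dy| x * dy).collect::<Vec<f32>>())
            .collect();
        let d_bias = d_output.to_vec();
        // Each input feature collects its weight row dotted with dy.
        let d_input: Vec<f32> = weight
            .iter()
            .map(|row| row.iter().zip(d_output).map(|(w, dy)| w * dy).sum::<f32>())
            .collect();
        (d_weight, d_bias, d_input)
    }

    #[test]
    fn affine_backward_matches_finite_differences() {
        let input = vec![1.0_f32, -2.0, 0.5];
        let weight = vec![vec![0.3_f32, -0.1], vec![0.2, 0.4], vec![-0.5, 1.0]];
        let bias = vec![0.1_f32, -0.2];
        let upstream = [1.5_f32, -0.5];
        let loss = |x: &[f32]| -> f32 {
            reference_affine_forward(x, &weight, &bias)
                .iter()
                .zip(&upstream)
                .map(|(y, g)| y * g)
                .sum()
        };
        let (d_weight, d_bias, d_input) = reference_affine_backward(&input, &weight, &upstream);
        let h = 1e-2_f32;
        for i in 0..input.len() {
            let mut plus = input.clone();
            plus[i] += h;
            let mut minus = input.clone();
            minus[i] -= h;
            let numeric = (loss(&plus[..]) - loss(&minus[..])) / (2.0 * h);
            assert!((numeric - d_input[i]).abs() < 1e-3);
        }
        assert!((d_weight[1][0] - input[1] * upstream[0]).abs() < 1e-6);
        assert_eq!(d_bias, upstream.to_vec());
    }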
#[test]
fn residual_connection_adds_matching_sequences() -> CtResult<()> {
let hidden = HiddenSequence::new(vec![vec![1.0, 2.0], vec![3.0, 4.0]])?;
let sublayer_output = ProjectedAttentionOutput::new(vec![vec![0.5, 1.5], vec![2.5, 3.5]])?;
let output = ResidualConnection.apply(Product::new(hidden, sublayer_output))?;
assert_eq!(output.sequence_len().value(), 2);
assert_eq!(output.model_dimension().value(), 2);
assert_eq!(output.rows()[0].as_slice(), &[1.5, 3.5]);
assert_eq!(output.rows()[1].as_slice(), &[5.5, 7.5]);
Ok(())
}
#[test]
fn layer_normalization_preserves_shape_and_normalizes_each_row() -> CtResult<()> {
let hidden = HiddenSequence::new(vec![vec![1.0, 3.0], vec![2.0, 4.0]])?;
let norm = LayerNormalization::new(LayerNormParameters::identity(ModelDimension::new(2)?));
let output = norm.apply(hidden)?;
assert_eq!(output.sequence_len().value(), 2);
assert_eq!(output.model_dimension().value(), 2);
assert!(crate::domain::approx_eq(
output.rows()[0].as_slice()[0],
-0.999995,
1e-4
));
assert!(crate::domain::approx_eq(
output.rows()[0].as_slice()[1],
0.999995,
1e-4
));
assert!(crate::domain::approx_eq(
output.rows()[1].as_slice()[0],
-0.999995,
1e-4
));
assert!(crate::domain::approx_eq(
output.rows()[1].as_slice()[1],
0.999995,
1e-4
));
Ok(())
}
#[test]
fn layer_normalization_applies_scale_and_shift() -> CtResult<()> {
let hidden = HiddenSequence::new(vec![vec![1.0, 3.0]])?;
let params = LayerNormParameters::new(
vec![2.0, 0.5],
vec![1.0, -1.0],
NormalizationEpsilon::new(1e-5)?,
)?;
let norm = LayerNormalization::new(params);
let output = norm.apply(hidden)?;
assert!(crate::domain::approx_eq(
output.rows()[0].as_slice()[0],
-0.99999,
1e-4
));
assert!(crate::domain::approx_eq(
output.rows()[0].as_slice()[1],
-0.5000025,
1e-4
));
Ok(())
}
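    // Gradient-check sketch for the formula in `layer_norm_backward_from_stats`.
    // This is illustrative and not part of the original listing. With
    // x_hat = (x - mean) * inv_std, y = scale * x_hat + shift and dx_hat = dy * scale,
    // the input gradient is
    // dx = inv_std / n * (n * dx_hat - sum(dx_hat) - x_hat * sum(dx_hat * x_hat)).
    // The helpers are hypothetical plain-slice restatements, not crate API.
    fn reference_layer_norm_parts(input: &[f32], epsilon: f32) -> (Vec<f32>, f32) {
        let n = input.len() as f32;
        let mean = input.iter().sum::<f32>() / n;
        let variance = input.iter().map(|v| (v - mean) * (v - mean)).sum::<f32>() / n;
        let inverse_std = 1.0 / (variance + epsilon).sqrt();
        let normalized: Vec<f32> = input.iter().map(|v| (v - mean) * inverse_std).collect();
        (normalized, inverse_std)
    }

    fn reference_layer_norm(input: &[f32], scale: &[f32], shift: &[f32], epsilon: f32) -> Vec<f32> {
        let (normalized, _) = reference_layer_norm_parts(input, epsilon);
        normalized
            .iter()
            .zip(scale.iter().zip(shift))
            .map(|(x_hat, (s, b))| x_hat * s + b)
            .collect()
    }

    fn reference_layer_norm_backward(
        input: &[f32],
        scale: &[f32],
        d_output: &[f32],
        epsilon: f32,
    ) -> Vec<f32> {
        let n = input.len() as f32;
        let (normalized, inverse_std) = reference_layer_norm_parts(input, epsilon);
        let d_normalized: Vec<f32> = d_output.iter().zip(scale).map(|(g, s)| g * s).collect();
        let sum_d: f32 = d_normalized.iter().sum();
        let sum_d_times_x_hat: f32 = d_normalized
            .iter()
            .zip(&normalized)
            .map(|(d, x_hat)| d * x_hat)
            .sum();
        d_normalized
            .iter()
            .zip(&normalized)
            .map(|(d, x_hat)| (n * d - sum_d - x_hat * sum_d_times_x_hat) * inverse_std / n)
            .collect()
    }

    #[test]
    fn layer_norm_backward_formula_matches_finite_differences() {
        let input = [0.3_f32, -1.2, 2.0, 0.7];
        let scale = [1.5_f32, 0.5, -1.0, 2.0];
        let shift = [0.1_f32, 0.0, -0.2, 0.3];
        let upstream = [0.4_f32, -0.3, 1.0, 0.2];
        let epsilon = 1e-5_f32;
        // Scalar loss L(x) = sum_i upstream_i * y_i(x).
        let loss = |x: &[f32]| -> f32 {
            reference_layer_norm(x, &scale, &shift, epsilon)
                .iter()
                .zip(&upstream)
                .map(|(y, g)| y * g)
                .sum()
        };
        let analytic = reference_layer_norm_backward(&input, &scale, &upstream, epsilon);
        let h = 1e-2_f32;
        for i in 0..input.len() {
            let mut plus = input;
            plus[i] += h;
            let mut minus = input;
            minus[i] -= h;
            let numeric = (loss(&plus[..]) - loss(&minus[..])) / (2.0 * h);
            assert!((numeric - analytic[i]).abs() < 1e-2);
        }
    }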
#[test]
fn position_wise_feed_forward_maps_each_row_and_preserves_shape() -> CtResult<()> {
let hidden = HiddenSequence::new(vec![vec![1.0, 2.0], vec![3.0, 4.0]])?;
let feed_forward = PositionWiseFeedForward::new(
vec![vec![1.0, -1.0, 0.5], vec![0.0, 1.0, 0.5]],
vec![0.0, 0.0, 0.0],
vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![0.5, 0.5]],
vec![0.0, 0.0],
)?;
let output = feed_forward.apply(hidden)?;
assert_eq!(output.sequence_len().value(), 2);
assert_eq!(output.model_dimension().value(), 2);
assert!(crate::domain::approx_eq(
output.rows()[0].as_slice()[0],
1.75,
1e-4
));
assert!(crate::domain::approx_eq(
output.rows()[0].as_slice()[1],
1.75,
1e-4
));
assert!(crate::domain::approx_eq(
output.rows()[1].as_slice()[0],
4.75,
1e-4
));
assert!(crate::domain::approx_eq(
output.rows()[1].as_slice()[1],
2.75,
1e-4
));
Ok(())
}
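    // Gradient-check sketch for `FeedForwardGradients::accumulate`. This is illustrative
    // and not part of the original listing. Forward: a = relu(x W1 + b1), y = a W2 + b2,
    // with weights stored input-major as in `project_row`. Backward: dy flows through W2,
    // is zeroed where the pre-activation was <= 0 (the ReLU derivative used in the
    // listing), then flows through W1. The helpers are hypothetical. The weights are the
    // ones from `position_wise_feed_forward_maps_each_row_and_preserves_shape`, and
    // x = [2, 1] switches the middle unit off while keeping every pre-activation away
    // from zero, so finite differences never cross the ReLU kink.
    fn reference_ff_affine(input: &[f32], weight: &[Vec<f32>], bias: &[f32]) -> Vec<f32> {
        let mut output = bias.to_vec();
        for (input_value, weight_row) in input.iter().zip(weight) {
            for (output_value, weight_value) in output.iter_mut().zip(weight_row) {
                *output_value += input_value * weight_value;
            }
        }
        output
    }

    fn reference_ff_forward(
        x: &[f32],
        w1: &[Vec<f32>],
        b1: &[f32],
        w2: &[Vec<f32>],
        b2: &[f32],
    ) -> Vec<f32> {
        let activation: Vec<f32> =
            reference_ff_affine(x, w1, b1).iter().map(|v| v.max(0.0)).collect();
        reference_ff_affine(&activation, w2, b2)
    }

    fn reference_ff_input_gradient(
        x: &[f32],
        w1: &[Vec<f32>],
        b1: &[f32],
        w2: &[Vec<f32>],
        d_output: &[f32],
    ) -> Vec<f32> {
        let pre_activation = reference_ff_affine(x, w1, b1);
        // d_pre[h] = relu'(pre[h]) * sum_o W2[h][o] * dy[o]
        let d_pre: Vec<f32> = w2
            .iter()
            .zip(&pre_activation)
            .map(|(row, pre)| {
                if *pre > 0.0 {
                    row.iter().zip(d_output).map(|(w, g)| w * g).sum::<f32>()
                } else {
                    0.0
                }
            })
            .collect();
        // d_x[i] = sum_h W1[i][h] * d_pre[h]
        w1.iter()
            .map(|row| row.iter().zip(&d_pre).map(|(w, d)| w * d).sum::<f32>())
            .collect()
    }

    #[test]
    fn feed_forward_input_gradient_matches_finite_differences() {
        let w1 = vec![vec![1.0_f32, -1.0, 0.5], vec![0.0, 1.0, 0.5]];
        let b1 = vec![0.0_f32; 3];
        let w2 = vec![vec![1.0_f32, 0.0], vec![0.0, 1.0], vec![0.5, 0.5]];
        let b2 = vec![0.0_f32; 2];
        let x = vec![2.0_f32, 1.0];
        let upstream = [0.5_f32, -1.0];
        let loss = |input: &[f32]| -> f32 {
            reference_ff_forward(input, &w1, &b1, &w2, &b2)
                .iter()
                .zip(&upstream)
                .map(|(y, g)| y * g)
                .sum()
        };
        let analytic = reference_ff_input_gradient(&x, &w1, &b1, &w2, &upstream);
        let h = 1e-2_f32;
        for i in 0..x.len() {
            let mut plus = x.clone();
            plus[i] += h;
            let mut minus = x.clone();
            minus[i] -= h;
            let numeric = (loss(&plus[..]) - loss(&minus[..])) / (2.0 * h);
            assert!((numeric - analytic[i]).abs() < 1e-3);
        }
    }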
#[test]
fn attention_mask_removes_disallowed_positions_before_softmax() -> CtResult<()> {
let scores = AttentionScores::new(vec![vec![2.0, 1.0, 2.0]])?;
let mask = AttentionMask::new(vec![vec![true, false, true]])?;
let masked_scores = MaskedAttentionScores.apply(Product::new(scores, mask))?;
let weights = AttentionSoftmax.apply(masked_scores)?;
assert!(crate::domain::approx_eq(
weights.rows()[0].as_slice()[0],
0.5,
1e-4
));
assert!(crate::domain::approx_eq(
weights.rows()[0].as_slice()[1],
0.0,
1e-4
));
assert!(crate::domain::approx_eq(
weights.rows()[0].as_slice()[2],
0.5,
1e-4
));
Ok(())
}
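    // Gradient-check sketch for the softmax step in `AttentionHeadGradients::accumulate`.
    // This is illustrative and not part of the original listing. For one query row with
    // weights w and upstream weight gradient g, d_score_k = w_k * (g_k - sum_j g_j * w_j).
    // A masked key has w_k = 0, so its score gradient is zero. This matches the listing,
    // which leaves masked entries of `d_scores` at 0.0. Hypothetical plain-slice helpers,
    // not crate API.
    fn reference_softmax_row(scores: &[f32]) -> Vec<f32> {
        // Subtract the row maximum before exponentiating for numerical stability.
        let max = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
        let total: f32 = exps.iter().sum();
        exps.iter().map(|e| e / total).collect()
    }

    fn reference_softmax_row_backward(weights: &[f32], d_weights: &[f32]) -> Vec<f32> {
        let row_dot: f32 = d_weights.iter().zip(weights).map(|(g, w)| g * w).sum();
        weights.iter().zip(d_weights).map(|(w, g)| w * (g - row_dot)).collect()
    }

    #[test]
    fn softmax_row_backward_matches_finite_differences() {
        let scores = [2.0_f32, 1.0, -0.5];
        let upstream = [0.3_f32, -1.0, 0.8];
        let loss = |s: &[f32]| -> f32 {
            reference_softmax_row(s).iter().zip(&upstream).map(|(w, g)| w * g).sum()
        };
        let analytic = reference_softmax_row_backward(&reference_softmax_row(&scores), &upstream);
        let h = 1e-2_f32;
        for k in 0..scores.len() {
            let mut plus = scores;
            plus[k] += h;
            let mut minus = scores;
            minus[k] -= h;
            let numeric = (loss(&plus[..]) - loss(&minus[..])) / (2.0 * h);
            assert!((numeric - analytic[k]).abs() < 1e-3);
        }
        // Softmax gradients along a row sum to zero because the weights sum to one.
        assert!(analytic.iter().sum::<f32>().abs() < 1e-5);
    }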
#[test]
fn attention_softmax_normalizes_each_query_row() -> CtResult<()> {
let scores = AttentionScores::new(vec![vec![2.0, 1.0], vec![0.0, 3.0]])?;
let weights = AttentionSoftmax.apply(scores)?;
assert_eq!(weights.query_len().value(), 2);
assert_eq!(weights.key_len().value(), 2);
for row in weights.rows() {
let sum: f32 = row.as_slice().iter().sum();
assert!(crate::domain::approx_eq(sum, 1.0, 1e-4));
}
Ok(())
}
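    // Gradient-check sketch for `softmax_cross_entropy_logits_gradient`. This is
    // illustrative and not part of the original listing. For L(z) = -log softmax(z)_t,
    // the gradient is softmax(z) - one_hot(t). These are hypothetical plain-slice
    // helpers, not crate API. They use a max-shifted log-sum-exp for numerical
    // stability.
    fn reference_cross_entropy(logits: &[f32], target: usize) -> f32 {
        let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let log_normalizer = max + logits.iter().map(|z| (z - max).exp()).sum::<f32>().ln();
        log_normalizer - logits[target]
    }

    fn reference_cross_entropy_gradient(logits: &[f32], target: usize) -> Vec<f32> {
        let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = logits.iter().map(|z| (z - max).exp()).collect();
        let total: f32 = exps.iter().sum();
        let mut gradient: Vec<f32> = exps.iter().map(|e| e / total).collect();
        // Subtract the one-hot target, as the listing does with `dlogits[target_index]`.
        gradient[target] -= 1.0;
        gradient
    }

    #[test]
    fn cross_entropy_gradient_matches_finite_differences() {
        let logits = [1.0_f32, 2.0, 0.5];
        let target = 1;
        let analytic = reference_cross_entropy_gradient(&logits, target);
        let h = 1e-2_f32;
        for k in 0..logits.len() {
            let mut plus = logits;
            plus[k] += h;
            let mut minus = logits;
            minus[k] -= h;
            let numeric = (reference_cross_entropy(&plus, target)
                - reference_cross_entropy(&minus, target))
                / (2.0 * h);
            assert!((numeric - analytic[k]).abs() < 1e-3);
        }
        assert!(analytic.iter().sum::<f32>().abs() < 1e-5);
    }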
#[test]
fn attention_scores_reject_non_finite_values() {
assert!(matches!(
AttentionScores::new(vec![vec![1.0, f32::NAN]]),
Err(CtError::ShapeMismatch {
op: "attention scores",
..
})
));
}
#[test]
fn attention_scores_reject_ragged_rows() {
assert!(matches!(
AttentionScores::new(vec![vec![1.0, 2.0], vec![3.0]]),
Err(CtError::ShapeMismatch {
op: "attention scores",
..
})
));
}
#[test]
fn attention_mask_rejects_rows_with_no_allowed_keys() {
assert!(matches!(
AttentionMask::new(vec![vec![false, false]]),
Err(CtError::EmptyInput("attention mask row allows no keys"))
));
}
#[test]
fn masked_attention_scores_reject_shape_mismatch() -> CtResult<()> {
let scores = AttentionScores::new(vec![vec![1.0, 2.0]])?;
let mask = AttentionMask::new(vec![vec![true, true], vec![true, true]])?;
assert!(matches!(
MaskedAttentionScores.apply(Product::new(scores, mask)),
Err(CtError::ShapeMismatch {
op: "masked attention scores",
..
})
));
Ok(())
}
#[test]
fn query_sequence_rejects_ragged_rows() {
assert!(matches!(
QuerySequence::new(vec![vec![1.0, 2.0], vec![3.0]]),
Err(CtError::ShapeMismatch {
op: "query sequence",
..
})
));
}
#[test]
fn value_sequence_rejects_empty_rows() {
assert!(matches!(
ValueSequence::new(vec![Vec::new()]),
Err(CtError::EmptyInput("attention vector row"))
));
}
#[test]
fn key_sequence_rejects_non_finite_values() {
assert!(matches!(
KeySequence::new(vec![vec![1.0, f32::NAN]]),
Err(CtError::ShapeMismatch {
op: "key sequence",
..
})
));
}
#[test]
fn scaled_dot_product_rejects_mismatched_head_dimensions() -> CtResult<()> {
let queries = QuerySequence::new(vec![vec![1.0, 0.0]])?;
let keys = KeySequence::new(vec![vec![1.0, 0.0, 1.0]])?;
assert!(matches!(
ScaledDotProductScores.apply(Product::new(queries, keys)),
Err(CtError::ShapeMismatch {
op: "scaled dot-product attention scores",
..
})
));
Ok(())
}
#[test]
fn weighted_value_mixing_rejects_value_length_mismatch() -> CtResult<()> {
let weights = AttentionWeights::new(vec![Distribution::new(vec![0.5, 0.5])?])?;
let values = ValueSequence::new(vec![vec![1.0, 10.0]])?;
assert!(matches!(
WeightedValueMixing.apply(Product::new(weights, values)),
Err(CtError::ShapeMismatch {
op: "weighted value mixing",
..
})
));
Ok(())
}
#[test]
fn sequence_and_head_dimensions_reject_zero() {
assert!(matches!(
SequenceLength::new(0),
Err(CtError::EmptyInput("sequence length"))
));
assert!(matches!(
HeadDimension::new(0),
Err(CtError::EmptyInput("head dimension"))
));
assert!(matches!(
HeadCount::new(0),
Err(CtError::EmptyInput("head count"))
));
}
#[test]
fn attention_head_outputs_reject_sequence_mismatch() -> CtResult<()> {
let head_a = AttentionOutput::new(vec![vec![1.0, 10.0]])?;
let head_b = AttentionOutput::new(vec![vec![2.0, 20.0], vec![3.0, 30.0]])?;
assert!(matches!(
AttentionHeadOutputs::new(vec![head_a, head_b]),
Err(CtError::ShapeMismatch {
op: "attention head outputs",
..
})
));
Ok(())
}
#[test]
fn attention_head_outputs_reject_head_dimension_mismatch() -> CtResult<()> {
let head_a = AttentionOutput::new(vec![vec![1.0, 10.0]])?;
let head_b = AttentionOutput::new(vec![vec![2.0, 20.0, 200.0]])?;
assert!(matches!(
AttentionHeadOutputs::new(vec![head_a, head_b]),
Err(CtError::ShapeMismatch {
op: "attention head outputs",
..
})
));
Ok(())
}
#[test]
fn attention_output_projection_rejects_input_dimension_mismatch() -> CtResult<()> {
let head_a = AttentionOutput::new(vec![vec![1.0, 10.0]])?;
let head_b = AttentionOutput::new(vec![vec![2.0, 20.0]])?;
let multi_head =
ConcatenateHeads.apply(AttentionHeadOutputs::new(vec![head_a, head_b])?)?;
let projection =
AttentionOutputProjection::new(vec![vec![1.0], vec![1.0], vec![1.0]], vec![0.0])?;
assert!(matches!(
projection.apply(multi_head),
Err(CtError::ShapeMismatch {
op: "attention output projection",
..
})
));
Ok(())
}
#[test]
fn attention_output_projection_rejects_bad_weight_shapes() {
assert!(matches!(
AttentionOutputProjection::new(vec![vec![1.0, 0.0], vec![1.0]], vec![0.0, 0.0]),
Err(CtError::ShapeMismatch {
op: "attention output projection",
..
})
));
}
#[test]
fn attention_output_projection_rejects_non_finite_values() {
assert!(matches!(
AttentionOutputProjection::new(vec![vec![1.0]], vec![f32::NAN]),
Err(CtError::ShapeMismatch {
op: "attention output projection",
..
})
));
assert!(matches!(
AttentionOutputProjection::new(vec![vec![f32::INFINITY]], vec![0.0]),
Err(CtError::ShapeMismatch {
op: "attention output projection",
..
})
));
}
#[test]
fn residual_connection_rejects_sequence_mismatch() -> CtResult<()> {
let hidden = HiddenSequence::new(vec![vec![1.0, 2.0]])?;
let sublayer_output = ProjectedAttentionOutput::new(vec![vec![0.5, 1.5], vec![2.5, 3.5]])?;
assert!(matches!(
ResidualConnection.apply(Product::new(hidden, sublayer_output)),
Err(CtError::ShapeMismatch {
op: "residual connection",
..
})
));
Ok(())
}
#[test]
fn residual_connection_rejects_model_dimension_mismatch() -> CtResult<()> {
let hidden = HiddenSequence::new(vec![vec![1.0, 2.0]])?;
let sublayer_output = ProjectedAttentionOutput::new(vec![vec![0.5, 1.5, 2.5]])?;
assert!(matches!(
ResidualConnection.apply(Product::new(hidden, sublayer_output)),
Err(CtError::ShapeMismatch {
op: "residual connection",
..
})
));
Ok(())
}
#[test]
fn layer_normalization_rejects_model_dimension_mismatch() -> CtResult<()> {
let hidden = HiddenSequence::new(vec![vec![1.0, 2.0, 3.0]])?;
let norm = LayerNormalization::new(LayerNormParameters::identity(ModelDimension::new(2)?));
assert!(matches!(
norm.apply(hidden),
Err(CtError::ShapeMismatch {
op: "layer normalization",
..
})
));
Ok(())
}
#[test]
fn layer_norm_parameters_reject_bad_shapes_and_values() {
assert!(matches!(
LayerNormParameters::new(vec![1.0, 1.0], vec![0.0], NormalizationEpsilon(1e-5)),
Err(CtError::ShapeMismatch {
op: "layer norm parameters",
..
})
));
assert!(matches!(
LayerNormParameters::new(vec![f32::NAN], vec![0.0], NormalizationEpsilon(1e-5)),
Err(CtError::ShapeMismatch {
op: "layer norm parameters",
..
})
));
assert!(matches!(
LayerNormParameters::new(vec![1.0], vec![f32::INFINITY], NormalizationEpsilon(1e-5)),
Err(CtError::ShapeMismatch {
op: "layer norm parameters",
..
})
));
assert!(matches!(
NormalizationEpsilon::new(0.0),
Err(CtError::ShapeMismatch {
op: "normalization epsilon",
..
})
));
}
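    // Sketch (hypothetical helper, not crate API): a standalone picture of what a
    // layer-normalization row transform is assumed to compute, so the parameter
    // checks above have a concrete referent. Assumptions not taken from the
    // crate: the mean and the biased (population) variance are taken over the
    // row, epsilon is added to the variance before the square root, and
    // `scale`/`shift` are applied per feature.
    fn sketch_layer_norm_row(
        row: &[f32],
        scale: &[f32],
        shift: &[f32],
        epsilon: f32,
    ) -> Option<Vec<f32>> {
        // Mirror some of the checks the parameter tests above make: mismatched
        // lengths and a non-positive (or NaN) epsilon, plus an empty row.
        if row.is_empty() || scale.len() != row.len() || shift.len() != row.len() || !(epsilon > 0.0) {
            return None;
        }
        let n = row.len() as f32;
        let mean = row.iter().sum::<f32>() / n;
        let variance = row.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / n;
        let inv_std = 1.0 / (variance + epsilon).sqrt();
        // Normalize each feature, then apply the learned scale and shift.
        Some(
            row.iter()
                .zip(scale)
                .zip(shift)
                .map(|((x, g), b)| (x - mean) * inv_std * g + b)
                .collect(),
        )
    }
    #[test]
    fn sketch_layer_norm_row_centers_and_scales() {
        let out = sketch_layer_norm_row(&[1.0, 2.0, 3.0], &[1.0; 3], &[0.0; 3], 1e-5).unwrap();
        // Mean 2 and variance 2/3, so the outputs are roughly -1.2247, 0, 1.2247.
        assert!(out[1].abs() < 1e-6);
        assert!((out[2] - 1.2247).abs() < 1e-3);
        assert!(sketch_layer_norm_row(&[1.0], &[1.0, 1.0], &[0.0], 1e-5).is_none());
    }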
#[test]
fn position_wise_feed_forward_rejects_input_dimension_mismatch() -> CtResult<()> {
let hidden = HiddenSequence::new(vec![vec![1.0, 2.0, 3.0]])?;
let feed_forward = PositionWiseFeedForward::new(
vec![vec![1.0, 0.0], vec![0.0, 1.0]],
vec![0.0, 0.0],
vec![vec![1.0, 0.0], vec![0.0, 1.0]],
vec![0.0, 0.0],
)?;
assert!(matches!(
feed_forward.apply(hidden),
Err(CtError::ShapeMismatch {
op: "position-wise feed-forward",
..
})
));
Ok(())
}
#[test]
fn position_wise_feed_forward_rejects_incompatible_layer_shapes() {
assert!(matches!(
PositionWiseFeedForward::new(
vec![vec![1.0, 0.0]],
vec![0.0, 0.0],
vec![vec![1.0], vec![1.0], vec![1.0]],
vec![0.0],
),
Err(CtError::ShapeMismatch {
op: "position-wise feed-forward",
..
})
));
assert!(matches!(
PositionWiseFeedForward::new(
vec![vec![1.0, 0.0]],
vec![0.0, 0.0],
vec![vec![1.0, 0.0], vec![1.0, 0.0]],
vec![0.0, 0.0],
),
Err(CtError::ShapeMismatch {
op: "position-wise feed-forward",
..
})
));
}
#[test]
fn position_wise_feed_forward_rejects_non_finite_values() {
assert!(matches!(
PositionWiseFeedForward::new(
vec![vec![f32::NAN]],
vec![0.0],
vec![vec![1.0]],
vec![0.0],
),
Err(CtError::ShapeMismatch {
op: "position-wise feed-forward first layer",
..
})
));
assert!(matches!(
PositionWiseFeedForward::new(
vec![vec![1.0]],
vec![0.0],
vec![vec![1.0]],
vec![f32::INFINITY],
),
Err(CtError::ShapeMismatch {
op: "position-wise feed-forward second layer",
..
})
));
}
#[test]
fn positional_encoding_adds_position_rows_and_preserves_shape() -> CtResult<()> {
let hidden = HiddenSequence::new(vec![vec![1.0, 2.0], vec![3.0, 4.0]])?;
let positions = PositionalEncoding::new(vec![vec![0.1, 0.2], vec![0.3, 0.4]])?;
let output = positions.apply(hidden)?;
assert_eq!(output.sequence_len().value(), 2);
assert_eq!(output.model_dimension().value(), 2);
assert!(crate::domain::approx_eq(
output.rows()[0].as_slice()[0],
1.1,
1e-4
));
assert!(crate::domain::approx_eq(
output.rows()[1].as_slice()[1],
4.4,
1e-4
));
Ok(())
}
#[test]
fn positional_encoding_rejects_model_dimension_mismatch() -> CtResult<()> {
let hidden = HiddenSequence::new(vec![vec![1.0, 2.0]])?;
let positions = PositionalEncoding::new(vec![vec![0.1, 0.2, 0.3]])?;
assert!(matches!(
positions.apply(hidden),
Err(CtError::ShapeMismatch {
op: "positional encoding",
..
})
));
Ok(())
}
#[test]
fn positional_encoding_rejects_sequence_too_long() -> CtResult<()> {
let hidden = HiddenSequence::new(vec![vec![1.0], vec![2.0]])?;
let positions = PositionalEncoding::new(vec![vec![0.1]])?;
assert!(matches!(
positions.apply(hidden),
Err(CtError::ShapeMismatch {
op: "positional encoding",
..
})
));
Ok(())
}
#[test]
fn hidden_to_query_projects_hidden_rows() -> CtResult<()> {
let hidden = HiddenSequence::new(vec![vec![1.0, 2.0], vec![3.0, 4.0]])?;
let projection = HiddenToQuery::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.5, -0.5])?;
let queries = projection.apply(hidden)?;
assert_eq!(queries.sequence_len().value(), 2);
assert_eq!(queries.head_dimension().value(), 2);
assert!(crate::domain::approx_eq(
queries.rows()[0].as_slice()[0],
1.5,
1e-4
));
assert!(crate::domain::approx_eq(
queries.rows()[1].as_slice()[1],
3.5,
1e-4
));
Ok(())
}
#[test]
fn hidden_projection_rejects_input_dimension_mismatch() -> CtResult<()> {
let hidden = HiddenSequence::new(vec![vec![1.0, 2.0, 3.0]])?;
let projection = HiddenToValue::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?;
assert!(matches!(
projection.apply(hidden),
Err(CtError::ShapeMismatch {
op: "hidden-to-value projection",
..
})
));
Ok(())
}
#[test]
fn residual_connection_adds_hidden_sequences() -> CtResult<()> {
let left = HiddenSequence::new(vec![vec![1.0, 2.0]])?;
let right = HiddenSequence::new(vec![vec![3.0, 4.0]])?;
let output = ResidualConnection.apply(Product::new(left, right))?;
assert_eq!(output.rows()[0].as_slice(), &[4.0, 6.0]);
Ok(())
}
#[test]
fn single_head_transformer_block_preserves_hidden_sequence_shape() -> CtResult<()> {
let block = tiny_single_head_block()?;
let hidden = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
let output = block.apply(hidden)?;
assert_eq!(output.sequence_len().value(), 2);
assert_eq!(output.model_dimension().value(), 2);
assert!(
output
.rows()
.iter()
.flat_map(|row| row.as_slice())
.all(|value| value.is_finite())
);
Ok(())
}
#[test]
fn single_head_transformer_block_rejects_constructor_dimension_mismatch() -> CtResult<()> {
let query = HiddenToQuery::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?;
let key = HiddenToKey::new(
vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![0.0, 0.0]],
vec![0.0, 0.0],
)?;
let value = HiddenToValue::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?;
let output_projection =
AttentionOutputProjection::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?;
let model_dimension = ModelDimension::new(2)?;
let attention_norm =
LayerNormalization::new(LayerNormParameters::identity(model_dimension));
let feed_forward = PositionWiseFeedForward::new(
vec![vec![1.0, 0.0], vec![0.0, 1.0]],
vec![0.0, 0.0],
vec![vec![1.0, 0.0], vec![0.0, 1.0]],
vec![0.0, 0.0],
)?;
let feed_forward_norm =
LayerNormalization::new(LayerNormParameters::identity(model_dimension));
assert!(matches!(
SingleHeadTransformerBlock::new(
query,
key,
value,
output_projection,
attention_norm,
feed_forward,
feed_forward_norm,
),
Err(CtError::ShapeMismatch {
op: "single-head block key projection",
..
})
));
Ok(())
}
#[test]
fn single_head_transformer_block_rejects_apply_dimension_mismatch() -> CtResult<()> {
let block = tiny_single_head_block()?;
let hidden = HiddenSequence::new(vec![vec![1.0, 0.0, 0.0]])?;
assert!(matches!(
block.apply(hidden),
Err(CtError::ShapeMismatch {
op: "single-head block",
..
})
));
Ok(())
}
#[test]
fn self_attention_head_rejects_query_key_head_mismatch() -> CtResult<()> {
let query = HiddenToQuery::new(vec![vec![1.0], vec![0.0]], vec![0.0])?;
let key = HiddenToKey::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?;
let value = HiddenToValue::new(vec![vec![1.0], vec![0.0]], vec![0.0])?;
assert!(matches!(
SelfAttentionHead::new(query, key, value),
Err(CtError::ShapeMismatch {
op: "self-attention head",
..
})
));
Ok(())
}
#[test]
fn multi_head_transformer_block_preserves_hidden_sequence_shape() -> CtResult<()> {
let block = tiny_multi_head_block()?;
let hidden = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
let output = block.apply(hidden)?;
assert_eq!(block.head_count().value(), 2);
assert_eq!(block.value_dimension().value(), 1);
assert_eq!(output.sequence_len().value(), 2);
assert_eq!(output.model_dimension().value(), 2);
assert!(
output
.rows()
.iter()
.flat_map(|row| row.as_slice())
.all(|value| value.is_finite())
);
Ok(())
}
#[test]
fn multi_head_transformer_block_rejects_value_dimension_mismatch() -> CtResult<()> {
let head_a = tiny_self_attention_head_first_feature()?;
let head_b = SelfAttentionHead::new(
HiddenToQuery::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
HiddenToKey::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
HiddenToValue::new(vec![vec![0.0, 1.0], vec![1.0, 0.0]], vec![0.0, 0.0])?,
)?;
let model_dimension = ModelDimension::new(2)?;
assert!(matches!(
MultiHeadTransformerBlock::new(
vec![head_a, head_b],
AttentionOutputProjection::new(
vec![vec![1.0, 0.0], vec![0.0, 1.0]],
vec![0.0, 0.0],
)?,
LayerNormalization::new(LayerNormParameters::identity(model_dimension)),
identity_feed_forward()?,
LayerNormalization::new(LayerNormParameters::identity(model_dimension)),
),
Err(CtError::ShapeMismatch {
op: "multi-head block",
..
})
));
Ok(())
}
#[test]
fn multi_head_transformer_block_rejects_output_projection_input_mismatch() -> CtResult<()> {
let model_dimension = ModelDimension::new(2)?;
assert!(matches!(
MultiHeadTransformerBlock::new(
vec![
tiny_self_attention_head_first_feature()?,
tiny_self_attention_head_second_feature()?,
],
AttentionOutputProjection::new(
vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![0.0, 0.0]],
vec![0.0, 0.0],
)?,
LayerNormalization::new(LayerNormParameters::identity(model_dimension)),
identity_feed_forward()?,
LayerNormalization::new(LayerNormParameters::identity(model_dimension)),
),
Err(CtError::ShapeMismatch {
op: "multi-head block output projection input",
..
})
));
Ok(())
}
#[test]
fn multi_head_transformer_block_rejects_apply_dimension_mismatch() -> CtResult<()> {
let block = tiny_multi_head_block()?;
let hidden = HiddenSequence::new(vec![vec![1.0, 0.0, 0.0]])?;
assert!(matches!(
block.apply(hidden),
Err(CtError::ShapeMismatch {
op: "multi-head block",
..
})
));
Ok(())
}
#[test]
fn masked_multi_head_transformer_block_preserves_hidden_sequence_shape() -> CtResult<()> {
let block = tiny_masked_multi_head_block()?;
let hidden = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
let mask = AttentionMask::new(vec![vec![true, false], vec![true, true]])?;
let output = block.apply(Product::new(hidden, mask))?;
assert_eq!(block.head_count().value(), 2);
assert_eq!(output.sequence_len().value(), 2);
assert_eq!(output.model_dimension().value(), 2);
assert!(
output
.rows()
.iter()
.flat_map(|row| row.as_slice())
.all(|value| value.is_finite())
);
Ok(())
}
#[test]
fn masked_multi_head_transformer_block_rejects_mask_shape_mismatch() -> CtResult<()> {
let block = tiny_masked_multi_head_block()?;
let hidden = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
let mask = AttentionMask::new(vec![vec![true, true, true], vec![true, true, true]])?;
assert!(matches!(
block.apply(Product::new(hidden, mask)),
Err(CtError::ShapeMismatch {
op: "masked attention scores",
..
})
));
Ok(())
}
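    // Sketch (hypothetical helper, not the crate's masked-attention code): what
    // the causal mask [[true, false], [true, true]] used in these tests is
    // assumed to do to one row of attention scores. Masked positions are
    // excluded before the softmax, so they receive exactly zero weight and the
    // visible weights still sum to one.
    fn sketch_masked_softmax(scores: &[f32], allowed: &[bool]) -> Option<Vec<f32>> {
        // A row needs one flag per score and at least one visible position.
        if scores.len() != allowed.len() || !allowed.iter().any(|&a| a) {
            return None;
        }
        // Subtract the largest visible score for numerical stability.
        let max = scores
            .iter()
            .zip(allowed)
            .filter(|&(_, &a)| a)
            .map(|(&s, _)| s)
            .fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = scores
            .iter()
            .zip(allowed)
            .map(|(&s, &a)| if a { (s - max).exp() } else { 0.0 })
            .collect();
        let total: f32 = exps.iter().sum();
        Some(exps.iter().map(|e| e / total).collect())
    }
    #[test]
    fn sketch_masked_softmax_zeroes_future_positions() {
        // First query row of a causal mask: only position 0 is visible.
        assert_eq!(sketch_masked_softmax(&[2.0, 5.0], &[true, false]), Some(vec![1.0, 0.0]));
        // Second row sees both positions; equal scores split the weight evenly.
        assert_eq!(sketch_masked_softmax(&[1.0, 1.0], &[true, true]), Some(vec![0.5, 0.5]));
        // A mask row of the wrong length is rejected, like the shape-mismatch test above.
        assert_eq!(sketch_masked_softmax(&[1.0, 1.0], &[true, true, true]), None);
    }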
#[test]
fn transformer_readout_maps_each_hidden_position_to_logits() -> CtResult<()> {
let readout = TransformerReadout::new(
vec![vec![1.0, 0.0, 0.5], vec![0.0, 1.0, -0.5]],
vec![0.0, 0.1, -0.1],
)?;
let hidden = HiddenSequence::new(vec![vec![2.0, 3.0], vec![4.0, 5.0]])?;
let logits = readout.apply(hidden)?;
assert_eq!(logits.sequence_len().value(), 2);
assert_eq!(logits.vocab_size().value(), 3);
assert_eq!(logits.rows()[0].as_slice(), &[2.0, 3.1, -0.6]);
assert_eq!(logits.rows()[1].as_slice(), &[4.0, 5.1, -0.6]);
Ok(())
}
#[test]
fn tiny_transformer_parameters_forward_maps_hidden_and_mask_to_sequence_logits() -> CtResult<()>
{
let parameters = tiny_transformer_parameters()?;
let hidden = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
let mask = AttentionMask::new(vec![vec![true, false], vec![true, true]])?;
let logits = parameters.apply(Product::new(hidden, mask))?;
assert_eq!(parameters.model_dimension().value(), 2);
assert_eq!(parameters.max_sequence_len().value(), 2);
assert_eq!(logits.sequence_len().value(), 2);
assert_eq!(logits.vocab_size().value(), 3);
assert!(
logits
.rows()
.iter()
.flat_map(|row| row.as_slice())
.all(|value| value.is_finite())
);
Ok(())
}
#[test]
fn tiny_transformer_parameters_rejects_readout_dimension_mismatch() -> CtResult<()> {
let positional_encoding = PositionalEncoding::new(vec![vec![0.1, 0.0], vec![0.0, 0.1]])?;
let block = tiny_masked_multi_head_block()?;
let readout = TransformerReadout::new(vec![vec![1.0], vec![0.0], vec![0.5]], vec![0.0])?;
assert!(matches!(
TinyTransformerParameters::new(positional_encoding, block, readout),
Err(CtError::ShapeMismatch {
op: "tiny transformer parameters readout",
..
})
));
Ok(())
}
#[test]
fn transformer_training_state_records_updated_parameters_and_step_count() -> CtResult<()> {
let initial_parameters = tiny_transformer_parameters()?;
let updated_parameters = tiny_transformer_parameters()?;
let state = TransformerTrainingState::new(initial_parameters, LearningRate::new(0.25)?);
let next_state = state.record_updated_parameters(updated_parameters.clone());
assert_eq!(next_state.parameters(), &updated_parameters);
assert_eq!(next_state.learning_rate().value(), 0.25);
assert_eq!(next_state.step_count().value(), 1);
Ok(())
}
#[test]
fn transformer_training_state_forward_uses_structured_parameters() -> CtResult<()> {
let state =
TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.1)?);
let hidden = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
let mask = AttentionMask::new(vec![vec![true, false], vec![true, true]])?;
let logits = state.apply(Product::new(hidden, mask))?;
assert_eq!(logits.sequence_len().value(), 2);
assert_eq!(logits.vocab_size().value(), 3);
assert_eq!(state.step_count().value(), 0);
Ok(())
}
#[test]
fn transformer_readout_training_example_rejects_target_length_mismatch() -> CtResult<()> {
let hidden = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
let mask = AttentionMask::new(vec![vec![true, false], vec![true, true]])?;
let targets = TokenSequence::from_indices([0])?;
assert!(matches!(
TransformerReadoutTrainingExample::new(hidden, mask, targets),
Err(CtError::ShapeMismatch {
op: "transformer readout training targets",
..
})
));
Ok(())
}
#[test]
fn transformer_readout_training_example_rejects_mask_shape_mismatch() -> CtResult<()> {
let hidden = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
let mask = AttentionMask::new(vec![vec![true, true, true], vec![true, true, true]])?;
let targets = TokenSequence::from_indices([0, 1])?;
assert!(matches!(
TransformerReadoutTrainingExample::new(hidden, mask, targets),
Err(CtError::ShapeMismatch {
op: "transformer readout training mask",
..
})
));
Ok(())
}
#[test]
fn transformer_readout_train_step_reduces_sequence_loss() -> CtResult<()> {
let dataset = tiny_transformer_training_set()?;
let state =
TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.5)?);
let before = transformer_readout_average_loss(&state, &dataset)?;
let train_step = TransformerReadoutTrainStep::new(dataset.clone());
let trained =
crate::category::apply_endomorphism_n_times(&train_step, state, StepCount::new(40))?;
let after = transformer_readout_average_loss(&trained, &dataset)?;
assert!(after.value() < before.value());
assert_eq!(trained.step_count().value(), 40);
Ok(())
}
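    // Sketch (hypothetical stand-in, not the crate's `apply_endomorphism_n_times`):
    // the training tests iterate a fallible train step, an endomorphism
    // State -> State, a fixed number of times. This minimal version threads the
    // state through `steps` applications with `try_fold` and stops at the first
    // error. The example loss (w - 3)^2 is an illustration chosen so the loss
    // decrease is easy to check by hand.
    fn sketch_apply_n_times<S, E>(
        step: impl Fn(S) -> Result<S, E>,
        initial: S,
        steps: usize,
    ) -> Result<S, E> {
        (0..steps).try_fold(initial, |state, _| step(state))
    }
    #[test]
    fn sketch_apply_n_times_reduces_quadratic_loss() {
        let loss = |w: f32| (w - 3.0).powi(2);
        let learning_rate = 0.25;
        // One gradient-descent step: w <- w - lr * d/dw (w - 3)^2 = w - lr * 2(w - 3).
        let step = |w: f32| Ok::<f32, String>(w - learning_rate * 2.0 * (w - 3.0));
        let trained = sketch_apply_n_times(step, 0.0, 40).unwrap();
        assert!(loss(trained) < loss(0.0));
        // Zero applications leave the state unchanged: the identity morphism.
        assert_eq!(sketch_apply_n_times(step, 1.5, 0), Ok(1.5));
    }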
#[test]
fn transformer_readout_train_step_rejects_target_outside_vocabulary() -> CtResult<()> {
let example = TransformerReadoutTrainingExample::new(
HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?,
AttentionMask::new(vec![vec![true, false], vec![true, true]])?,
TokenSequence::from_indices([0, 9])?,
)?;
let dataset = TransformerReadoutTrainingSet::new([example])?;
let state =
TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.1)?);
let train_step = TransformerReadoutTrainStep::new(dataset);
assert!(matches!(
train_step.apply(state),
Err(CtError::OutOfRange {
kind: "sequence target",
index: 9,
limit: 3,
})
));
Ok(())
}
#[test]
fn transformer_feed_forward_training_example_rejects_target_shape_mismatch() -> CtResult<()> {
let input = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
let target = HiddenSequence::new(vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]])?;
assert!(matches!(
TransformerFeedForwardTrainingExample::new(input, target),
Err(CtError::ShapeMismatch {
op: "transformer feed-forward training dimension",
..
})
));
Ok(())
}
#[test]
fn transformer_feed_forward_training_set_rejects_empty_input() {
assert!(matches!(
TransformerFeedForwardTrainingSet::new([]),
Err(CtError::EmptyInput("transformer feed-forward training set"))
));
}
#[test]
fn transformer_feed_forward_train_step_reduces_local_hidden_loss() -> CtResult<()> {
let dataset = tiny_feed_forward_training_set()?;
let state =
TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
let before = transformer_feed_forward_average_loss(&state, &dataset)?;
let train_step = TransformerFeedForwardTrainStep::new(dataset.clone());
let trained =
crate::category::apply_endomorphism_n_times(&train_step, state, StepCount::new(60))?;
let after = transformer_feed_forward_average_loss(&trained, &dataset)?;
assert!(after.value() < before.value());
assert_eq!(trained.step_count().value(), 60);
Ok(())
}
#[test]
fn transformer_block_training_example_rejects_mask_shape_mismatch() -> CtResult<()> {
let hidden = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
let mask = AttentionMask::new(vec![vec![true, true, true], vec![true, true, true]])?;
let targets = TokenSequence::from_indices([0, 1])?;
assert!(matches!(
TransformerBlockTrainingExample::new(hidden, mask, targets),
Err(CtError::ShapeMismatch {
op: "transformer block training mask",
..
})
));
Ok(())
}
#[test]
fn transformer_block_training_set_rejects_empty_input() {
assert!(matches!(
TransformerBlockTrainingSet::new([]),
Err(CtError::EmptyInput("transformer block training set"))
));
}
#[test]
fn transformer_block_train_step_rejects_target_outside_vocabulary() -> CtResult<()> {
let example = TransformerBlockTrainingExample::new(
HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?,
AttentionMask::new(vec![vec![true, false], vec![true, true]])?,
TokenSequence::from_indices([0, 9])?,
)?;
let dataset = TransformerBlockTrainingSet::new([example])?;
let state =
TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.1)?);
let train_step = TransformerBlockTrainStep::new(dataset);
assert!(matches!(
train_step.apply(state),
Err(CtError::OutOfRange {
kind: "sequence target",
index: 9,
limit: 3,
})
));
Ok(())
}
#[test]
fn transformer_block_train_step_reduces_sequence_loss() -> CtResult<()> {
let dataset = tiny_transformer_block_training_set()?;
let state =
TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
let before = transformer_block_average_loss(&state, &dataset)?;
let train_step = TransformerBlockTrainStep::new(dataset.clone());
let trained =
crate::category::apply_endomorphism_n_times(&train_step, state, StepCount::new(40))?;
let after = transformer_block_average_loss(&trained, &dataset)?;
assert!(after.value() < before.value());
assert_eq!(trained.step_count().value(), 40);
Ok(())
}
#[test]
fn transformer_block_train_step_updates_attention_output_projection() -> CtResult<()> {
let dataset = tiny_transformer_block_training_set()?;
let state =
TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
let before = state.parameters().output_projection().weight().to_vec();
let train_step = TransformerBlockTrainStep::new(dataset);
let trained = train_step.apply(state)?;
let after = trained.parameters().output_projection().weight().to_vec();
assert_ne!(before, after);
assert_eq!(trained.step_count().value(), 1);
Ok(())
}
#[test]
fn transformer_block_train_step_updates_layer_norm_parameters() -> CtResult<()> {
let dataset = tiny_transformer_block_training_set()?;
let state =
TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
let before_attention_scale = state
.parameters()
.attention_norm()
.parameters()
.scale()
.to_vec();
let before_feed_forward_shift = state
.parameters()
.feed_forward_norm()
.parameters()
.shift()
.to_vec();
let train_step = TransformerBlockTrainStep::new(dataset);
let trained = train_step.apply(state)?;
let after_attention_scale = trained
.parameters()
.attention_norm()
.parameters()
.scale()
.to_vec();
let after_feed_forward_shift = trained
.parameters()
.feed_forward_norm()
.parameters()
.shift()
.to_vec();
assert_ne!(before_attention_scale, after_attention_scale);
assert_ne!(before_feed_forward_shift, after_feed_forward_shift);
assert_eq!(trained.step_count().value(), 1);
Ok(())
}
#[test]
fn transformer_block_train_step_updates_query_key_value_projections() -> CtResult<()> {
let dataset = tiny_transformer_block_training_set()?;
let state =
TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
let before_query = state
.parameters()
.attention_heads()
.iter()
.map(|head| head.query_projection().weight().to_vec())
.collect::<Vec<_>>();
let before_key = state
.parameters()
.attention_heads()
.iter()
.map(|head| head.key_projection().weight().to_vec())
.collect::<Vec<_>>();
let before_value = state
.parameters()
.attention_heads()
.iter()
.map(|head| head.value_projection().weight().to_vec())
.collect::<Vec<_>>();
let train_step = TransformerBlockTrainStep::new(dataset);
let trained = train_step.apply(state)?;
let after_query = trained
.parameters()
.attention_heads()
.iter()
.map(|head| head.query_projection().weight().to_vec())
.collect::<Vec<_>>();
let after_key = trained
.parameters()
.attention_heads()
.iter()
.map(|head| head.key_projection().weight().to_vec())
.collect::<Vec<_>>();
let after_value = trained
.parameters()
.attention_heads()
.iter()
.map(|head| head.value_projection().weight().to_vec())
.collect::<Vec<_>>();
assert_ne!(before_query, after_query);
assert_ne!(before_key, after_key);
assert_ne!(before_value, after_value);
assert_eq!(trained.step_count().value(), 1);
Ok(())
}
#[test]
fn transformer_block_train_step_matches_finite_difference_for_attention_projection()
-> CtResult<()> {
let dataset = tiny_transformer_block_training_set()?;
let state =
TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
let train_step = TransformerBlockTrainStep::new(dataset.clone());
let trained = train_step.apply(state.clone())?;
let selection = largest_changed_attention_projection(&state, &trained)?;
let before_value = attention_projection_weight(&state, selection)?;
let after_value = attention_projection_weight(&trained, selection)?;
assert_block_gradient_matches_finite_difference(
&state,
&dataset,
before_value,
after_value,
|value| state_with_attention_projection_weight(&state, selection, value),
)
}
#[test]
fn transformer_block_train_step_matches_finite_difference_for_readout_weight() -> CtResult<()> {
let dataset = tiny_transformer_block_training_set()?;
let state =
TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
let train_step = TransformerBlockTrainStep::new(dataset.clone());
let trained = train_step.apply(state.clone())?;
let selection = largest_changed_readout_weight(&state, &trained)?;
let before_value = readout_weight(&state, selection);
let after_value = readout_weight(&trained, selection);
assert_block_gradient_matches_finite_difference(
&state,
&dataset,
before_value,
after_value,
|value| state_with_readout_weight(&state, selection, value),
)
}
#[test]
fn transformer_block_train_step_matches_finite_difference_for_feed_forward_weight()
-> CtResult<()> {
let dataset = tiny_transformer_block_training_set()?;
let state =
TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
let train_step = TransformerBlockTrainStep::new(dataset.clone());
let trained = train_step.apply(state.clone())?;
let selection = largest_changed_feed_forward_weight(&state, &trained)?;
let before_value = feed_forward_weight(&state, selection);
let after_value = feed_forward_weight(&trained, selection);
assert_block_gradient_matches_finite_difference(
&state,
&dataset,
before_value,
after_value,
|value| state_with_feed_forward_weight(&state, selection, value),
)
}
#[test]
fn transformer_block_train_step_matches_finite_difference_for_layer_norm_parameter()
-> CtResult<()> {
let dataset = tiny_transformer_block_training_set()?;
let state =
TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
let train_step = TransformerBlockTrainStep::new(dataset.clone());
let trained = train_step.apply(state.clone())?;
let selection = largest_changed_layer_norm_parameter(&state, &trained)?;
let before_value = layer_norm_parameter_value(&state, selection);
let after_value = layer_norm_parameter_value(&trained, selection);
assert_block_gradient_matches_finite_difference(
&state,
&dataset,
before_value,
after_value,
|value| state_with_layer_norm_parameter(&state, selection, value),
)
}
#[test]
fn transformer_block_train_step_matches_finite_difference_for_readout_bias() -> CtResult<()> {
let dataset = tiny_transformer_block_training_set()?;
let state =
TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
let train_step = TransformerBlockTrainStep::new(dataset.clone());
let trained = train_step.apply(state.clone())?;
let selection = largest_changed_vector_index(
state.parameters().readout().bias(),
trained.parameters().readout().bias(),
"changed readout bias",
)?;
let before_value = state.parameters().readout().bias()[selection];
let after_value = trained.parameters().readout().bias()[selection];
assert_block_gradient_matches_finite_difference(
&state,
&dataset,
before_value,
after_value,
|value| state_with_readout_bias(&state, selection, value),
)
}
#[test]
fn transformer_block_train_step_matches_finite_difference_for_feed_forward_bias() -> CtResult<()>
{
let dataset = tiny_transformer_block_training_set()?;
let state =
TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
let train_step = TransformerBlockTrainStep::new(dataset.clone());
let trained = train_step.apply(state.clone())?;
let selection = largest_changed_feed_forward_bias(&state, &trained)?;
let before_value = feed_forward_bias(&state, selection);
let after_value = feed_forward_bias(&trained, selection);
assert_block_gradient_matches_finite_difference(
&state,
&dataset,
before_value,
after_value,
|value| state_with_feed_forward_bias(&state, selection, value),
)
}
#[test]
fn transformer_block_train_step_matches_finite_difference_for_output_projection_bias()
-> CtResult<()> {
let dataset = tiny_transformer_block_training_set()?;
let state =
TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
let train_step = TransformerBlockTrainStep::new(dataset.clone());
let trained = train_step.apply(state.clone())?;
let selection = largest_changed_vector_index(
state.parameters().output_projection().bias(),
trained.parameters().output_projection().bias(),
"changed attention output projection bias",
)?;
let before_value = state.parameters().output_projection().bias()[selection];
let after_value = trained.parameters().output_projection().bias()[selection];
assert_block_gradient_matches_finite_difference(
&state,
&dataset,
before_value,
after_value,
|value| state_with_output_projection_bias(&state, selection, value),
)
}
#[test]
fn transformer_block_train_step_matches_finite_difference_for_attention_projection_bias()
-> CtResult<()> {
let dataset = tiny_transformer_block_training_set()?;
let state =
TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
let train_step = TransformerBlockTrainStep::new(dataset.clone());
let trained = train_step.apply(state.clone())?;
let selection = largest_changed_attention_projection_bias(&state, &trained)?;
let before_value = attention_projection_bias(&state, selection)?;
let after_value = attention_projection_bias(&trained, selection)?;
assert_block_gradient_matches_finite_difference(
&state,
&dataset,
before_value,
after_value,
|value| state_with_attention_projection_bias(&state, selection, value),
)
}
fn assert_block_gradient_matches_finite_difference(
state: &TransformerTrainingState,
dataset: &TransformerBlockTrainingSet,
before_value: f32,
after_value: f32,
mut state_with_value: impl FnMut(f32) -> CtResult<TransformerTrainingState>,
) -> CtResult<()> {
let inferred_gradient = (before_value - after_value) / state.learning_rate().value();
let epsilon = 1e-3;
let loss_plus =
transformer_block_average_loss(&state_with_value(before_value + epsilon)?, dataset)?
.value();
let loss_minus =
transformer_block_average_loss(&state_with_value(before_value - epsilon)?, dataset)?
.value();
let finite_difference = (loss_plus - loss_minus) / (2.0 * epsilon);
assert!(
(inferred_gradient - finite_difference).abs() < 1e-2,
"inferred gradient {inferred_gradient} should match finite difference {finite_difference}"
);
Ok(())
}
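    // Sketch (hypothetical helpers, not crate API): why the finite-difference
    // tests compare the two numbers they do. If one train step is plain gradient
    // descent, w_after = w_before - lr * g, so g can be recovered as
    // (w_before - w_after) / lr. A central difference
    // (L(w + eps) - L(w - eps)) / (2 eps) approximates the same derivative with
    // O(eps^2) truncation error. The cubic loss below is an illustration only,
    // not the transformer loss.
    fn sketch_inferred_gradient(before: f32, after: f32, learning_rate: f32) -> f32 {
        (before - after) / learning_rate
    }
    fn sketch_central_difference(loss: impl Fn(f32) -> f32, at: f32, epsilon: f32) -> f32 {
        (loss(at + epsilon) - loss(at - epsilon)) / (2.0 * epsilon)
    }
    #[test]
    fn sketch_gradient_check_agrees_on_a_cubic_loss() {
        let loss = |w: f32| w * w * w - 2.0 * w;
        let learning_rate = 0.1;
        let before = 1.5_f32;
        // The analytic gradient 3w^2 - 2 drives a single gradient-descent step.
        let after = before - learning_rate * (3.0 * before * before - 2.0);
        let inferred = sketch_inferred_gradient(before, after, learning_rate);
        let numeric = sketch_central_difference(loss, before, 1e-3);
        // Same tolerance as `assert_block_gradient_matches_finite_difference`.
        assert!((inferred - numeric).abs() < 1e-2);
    }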
fn largest_changed_vector_index(
before: &[f32],
after: &[f32],
label: &'static str,
) -> CtResult<usize> {
let mut selected = None;
let mut largest_delta = 0.0;
for (index, (before_value, after_value)) in before.iter().zip(after).enumerate() {
let delta = (before_value - after_value).abs();
if delta > largest_delta {
largest_delta = delta;
selected = Some(index);
}
}
selected.ok_or(CtError::EmptyInput(label))
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct MatrixSelection {
input_index: usize,
output_index: usize,
}
fn largest_changed_matrix_weight(
before: &[Vec<f32>],
after: &[Vec<f32>],
label: &'static str,
) -> CtResult<MatrixSelection> {
let mut selected = None;
let mut largest_delta = 0.0;
for (input_index, (before_row, after_row)) in before.iter().zip(after).enumerate() {
for (output_index, (before_value, after_value)) in
before_row.iter().zip(after_row).enumerate()
{
let delta = (before_value - after_value).abs();
if delta > largest_delta {
largest_delta = delta;
selected = Some(MatrixSelection {
input_index,
output_index,
});
}
}
}
selected.ok_or(CtError::EmptyInput(label))
}
fn largest_changed_readout_weight(
before: &TransformerTrainingState,
after: &TransformerTrainingState,
) -> CtResult<MatrixSelection> {
largest_changed_matrix_weight(
before.parameters().readout().weight(),
after.parameters().readout().weight(),
"changed readout weight",
)
}
fn readout_weight(state: &TransformerTrainingState, selection: MatrixSelection) -> f32 {
state.parameters().readout().weight()[selection.input_index][selection.output_index]
}
fn state_with_readout_weight(
state: &TransformerTrainingState,
selection: MatrixSelection,
value: f32,
) -> CtResult<TransformerTrainingState> {
let readout = state.parameters().readout();
let mut weight = readout.weight().to_vec();
weight[selection.input_index][selection.output_index] = value;
let readout = TransformerReadout::new(weight, readout.bias().to_vec())?;
let parameters = state.parameters().clone().with_readout(readout)?;
Ok(TransformerTrainingState::from_parts(
parameters,
state.learning_rate(),
state.step_count(),
))
}
fn state_with_readout_bias(
state: &TransformerTrainingState,
selection: usize,
value: f32,
) -> CtResult<TransformerTrainingState> {
let readout = state.parameters().readout();
let mut bias = readout.bias().to_vec();
bias[selection] = value;
let readout = TransformerReadout::new(readout.weight().to_vec(), bias)?;
let parameters = state.parameters().clone().with_readout(readout)?;
Ok(TransformerTrainingState::from_parts(
parameters,
state.learning_rate(),
state.step_count(),
))
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum FeedForwardWeightKind {
First,
Second,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct FeedForwardWeightSelection {
kind: FeedForwardWeightKind,
matrix: MatrixSelection,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct FeedForwardBiasSelection {
kind: FeedForwardWeightKind,
index: usize,
}
fn largest_changed_feed_forward_weight(
before: &TransformerTrainingState,
after: &TransformerTrainingState,
) -> CtResult<FeedForwardWeightSelection> {
let first = largest_changed_matrix_weight(
before.parameters().feed_forward().first_weight(),
after.parameters().feed_forward().first_weight(),
"changed first feed-forward weight",
);
let second = largest_changed_matrix_weight(
before.parameters().feed_forward().second_weight(),
after.parameters().feed_forward().second_weight(),
"changed second feed-forward weight",
);
match (first, second) {
(Ok(first), Ok(second)) => {
let first_delta = feed_forward_weight_delta(
before,
after,
FeedForwardWeightSelection {
kind: FeedForwardWeightKind::First,
matrix: first,
},
);
let second_delta = feed_forward_weight_delta(
before,
after,
FeedForwardWeightSelection {
kind: FeedForwardWeightKind::Second,
matrix: second,
},
);
if first_delta >= second_delta {
Ok(FeedForwardWeightSelection {
kind: FeedForwardWeightKind::First,
matrix: first,
})
} else {
Ok(FeedForwardWeightSelection {
kind: FeedForwardWeightKind::Second,
matrix: second,
})
}
}
(Ok(first), Err(_)) => Ok(FeedForwardWeightSelection {
kind: FeedForwardWeightKind::First,
matrix: first,
}),
(Err(_), Ok(second)) => Ok(FeedForwardWeightSelection {
kind: FeedForwardWeightKind::Second,
matrix: second,
}),
(Err(_), Err(_)) => Err(CtError::EmptyInput("changed feed-forward weight")),
}
}
fn feed_forward_weight_delta(
before: &TransformerTrainingState,
after: &TransformerTrainingState,
selection: FeedForwardWeightSelection,
) -> f32 {
(feed_forward_weight(before, selection) - feed_forward_weight(after, selection)).abs()
}
fn feed_forward_weight(
state: &TransformerTrainingState,
selection: FeedForwardWeightSelection,
) -> f32 {
let feed_forward = state.parameters().feed_forward();
match selection.kind {
FeedForwardWeightKind::First => feed_forward.first_weight()
[selection.matrix.input_index][selection.matrix.output_index],
FeedForwardWeightKind::Second => feed_forward.second_weight()
[selection.matrix.input_index][selection.matrix.output_index],
}
}
fn state_with_feed_forward_weight(
state: &TransformerTrainingState,
selection: FeedForwardWeightSelection,
value: f32,
) -> CtResult<TransformerTrainingState> {
let feed_forward = state.parameters().feed_forward();
let mut first_weight = feed_forward.first_weight().to_vec();
let mut second_weight = feed_forward.second_weight().to_vec();
match selection.kind {
FeedForwardWeightKind::First => {
first_weight[selection.matrix.input_index][selection.matrix.output_index] = value;
}
FeedForwardWeightKind::Second => {
second_weight[selection.matrix.input_index][selection.matrix.output_index] = value;
}
}
let feed_forward = PositionWiseFeedForward::new(
first_weight,
feed_forward.first_bias().to_vec(),
second_weight,
feed_forward.second_bias().to_vec(),
)?;
let parameters = state.parameters().clone().with_feed_forward(feed_forward)?;
Ok(TransformerTrainingState::from_parts(
parameters,
state.learning_rate(),
state.step_count(),
))
}
fn largest_changed_feed_forward_bias(
before: &TransformerTrainingState,
after: &TransformerTrainingState,
) -> CtResult<FeedForwardBiasSelection> {
let first = largest_changed_vector_index(
before.parameters().feed_forward().first_bias(),
after.parameters().feed_forward().first_bias(),
"changed first feed-forward bias",
);
let second = largest_changed_vector_index(
before.parameters().feed_forward().second_bias(),
after.parameters().feed_forward().second_bias(),
"changed second feed-forward bias",
);
match (first, second) {
(Ok(first), Ok(second)) => {
let first_delta = feed_forward_bias_delta(
before,
after,
FeedForwardBiasSelection {
kind: FeedForwardWeightKind::First,
index: first,
},
);
let second_delta = feed_forward_bias_delta(
before,
after,
FeedForwardBiasSelection {
kind: FeedForwardWeightKind::Second,
index: second,
},
);
if first_delta >= second_delta {
Ok(FeedForwardBiasSelection {
kind: FeedForwardWeightKind::First,
index: first,
})
} else {
Ok(FeedForwardBiasSelection {
kind: FeedForwardWeightKind::Second,
index: second,
})
}
}
(Ok(first), Err(_)) => Ok(FeedForwardBiasSelection {
kind: FeedForwardWeightKind::First,
index: first,
}),
(Err(_), Ok(second)) => Ok(FeedForwardBiasSelection {
kind: FeedForwardWeightKind::Second,
index: second,
}),
(Err(_), Err(_)) => Err(CtError::EmptyInput("changed feed-forward bias")),
}
}
fn feed_forward_bias_delta(
before: &TransformerTrainingState,
after: &TransformerTrainingState,
selection: FeedForwardBiasSelection,
) -> f32 {
(feed_forward_bias(before, selection) - feed_forward_bias(after, selection)).abs()
}
fn feed_forward_bias(
state: &TransformerTrainingState,
selection: FeedForwardBiasSelection,
) -> f32 {
let feed_forward = state.parameters().feed_forward();
match selection.kind {
FeedForwardWeightKind::First => feed_forward.first_bias()[selection.index],
FeedForwardWeightKind::Second => feed_forward.second_bias()[selection.index],
}
}
fn state_with_feed_forward_bias(
state: &TransformerTrainingState,
selection: FeedForwardBiasSelection,
value: f32,
) -> CtResult<TransformerTrainingState> {
let feed_forward = state.parameters().feed_forward();
let mut first_bias = feed_forward.first_bias().to_vec();
let mut second_bias = feed_forward.second_bias().to_vec();
match selection.kind {
FeedForwardWeightKind::First => {
first_bias[selection.index] = value;
}
FeedForwardWeightKind::Second => {
second_bias[selection.index] = value;
}
}
let feed_forward = PositionWiseFeedForward::new(
feed_forward.first_weight().to_vec(),
first_bias,
feed_forward.second_weight().to_vec(),
second_bias,
)?;
let parameters = state.parameters().clone().with_feed_forward(feed_forward)?;
Ok(TransformerTrainingState::from_parts(
parameters,
state.learning_rate(),
state.step_count(),
))
}
fn state_with_output_projection_bias(
state: &TransformerTrainingState,
selection: usize,
value: f32,
) -> CtResult<TransformerTrainingState> {
let output_projection = state.parameters().output_projection();
let mut bias = output_projection.bias().to_vec();
bias[selection] = value;
let output_projection =
AttentionOutputProjection::new(output_projection.weight().to_vec(), bias)?;
let parameters = state
.parameters()
.clone()
.with_output_projection(output_projection)?;
Ok(TransformerTrainingState::from_parts(
parameters,
state.learning_rate(),
state.step_count(),
))
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum LayerNormParameterKind {
AttentionScale,
AttentionShift,
FeedForwardScale,
FeedForwardShift,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct LayerNormParameterSelection {
kind: LayerNormParameterKind,
feature_index: usize,
}
fn largest_changed_layer_norm_parameter(
before: &TransformerTrainingState,
after: &TransformerTrainingState,
) -> CtResult<LayerNormParameterSelection> {
let mut selected = None;
let mut largest_delta = 0.0;
for kind in [
LayerNormParameterKind::AttentionScale,
LayerNormParameterKind::AttentionShift,
LayerNormParameterKind::FeedForwardScale,
LayerNormParameterKind::FeedForwardShift,
] {
let before_values = layer_norm_parameter_values(before, kind);
let after_values = layer_norm_parameter_values(after, kind);
for (feature_index, (before_value, after_value)) in
before_values.iter().zip(after_values).enumerate()
{
let delta = (before_value - after_value).abs();
if delta > largest_delta {
largest_delta = delta;
selected = Some(LayerNormParameterSelection {
kind,
feature_index,
});
}
}
}
selected.ok_or(CtError::EmptyInput("changed layer norm parameter"))
}
fn layer_norm_parameter_values(
state: &TransformerTrainingState,
kind: LayerNormParameterKind,
) -> &[f32] {
match kind {
LayerNormParameterKind::AttentionScale => {
state.parameters().attention_norm().parameters().scale()
}
LayerNormParameterKind::AttentionShift => {
state.parameters().attention_norm().parameters().shift()
}
LayerNormParameterKind::FeedForwardScale => {
state.parameters().feed_forward_norm().parameters().scale()
}
LayerNormParameterKind::FeedForwardShift => {
state.parameters().feed_forward_norm().parameters().shift()
}
}
}
fn layer_norm_parameter_value(
state: &TransformerTrainingState,
selection: LayerNormParameterSelection,
) -> f32 {
layer_norm_parameter_values(state, selection.kind)[selection.feature_index]
}
fn state_with_layer_norm_parameter(
state: &TransformerTrainingState,
selection: LayerNormParameterSelection,
value: f32,
) -> CtResult<TransformerTrainingState> {
let attention_parameters = state.parameters().attention_norm().parameters();
let feed_forward_parameters = state.parameters().feed_forward_norm().parameters();
let mut attention_scale = attention_parameters.scale().to_vec();
let mut attention_shift = attention_parameters.shift().to_vec();
let mut feed_forward_scale = feed_forward_parameters.scale().to_vec();
let mut feed_forward_shift = feed_forward_parameters.shift().to_vec();
match selection.kind {
LayerNormParameterKind::AttentionScale => {
attention_scale[selection.feature_index] = value;
}
LayerNormParameterKind::AttentionShift => {
attention_shift[selection.feature_index] = value;
}
LayerNormParameterKind::FeedForwardScale => {
feed_forward_scale[selection.feature_index] = value;
}
LayerNormParameterKind::FeedForwardShift => {
feed_forward_shift[selection.feature_index] = value;
}
}
let attention_norm = LayerNormalization::new(LayerNormParameters::new(
attention_scale,
attention_shift,
attention_parameters.epsilon(),
)?);
let feed_forward_norm = LayerNormalization::new(LayerNormParameters::new(
feed_forward_scale,
feed_forward_shift,
feed_forward_parameters.epsilon(),
)?);
let parameters = state
.parameters()
.clone()
.with_layer_norms(attention_norm, feed_forward_norm)?;
Ok(TransformerTrainingState::from_parts(
parameters,
state.learning_rate(),
state.step_count(),
))
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum AttentionProjectionKind {
Query,
Key,
Value,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct AttentionProjectionSelection {
head_index: usize,
kind: AttentionProjectionKind,
input_index: usize,
output_index: usize,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct AttentionProjectionBiasSelection {
head_index: usize,
kind: AttentionProjectionKind,
output_index: usize,
}
fn largest_changed_attention_projection(
before: &TransformerTrainingState,
after: &TransformerTrainingState,
) -> CtResult<AttentionProjectionSelection> {
let head_count = before.parameters().attention_heads().len();
let mut selected = None;
let mut largest_delta = 0.0;
for head_index in 0..head_count {
for kind in [
AttentionProjectionKind::Query,
AttentionProjectionKind::Key,
AttentionProjectionKind::Value,
] {
let before_weight = attention_projection_weight_matrix(before, head_index, kind)?;
let after_weight = attention_projection_weight_matrix(after, head_index, kind)?;
for (input_index, (before_row, after_row)) in
before_weight.iter().zip(after_weight).enumerate()
{
for (output_index, (before_value, after_value)) in
before_row.iter().zip(after_row).enumerate()
{
let delta = (before_value - after_value).abs();
if delta > largest_delta {
largest_delta = delta;
selected = Some(AttentionProjectionSelection {
head_index,
kind,
input_index,
output_index,
});
}
}
}
}
}
selected.ok_or(CtError::EmptyInput("changed attention projection"))
}
fn attention_projection_weight(
state: &TransformerTrainingState,
selection: AttentionProjectionSelection,
) -> CtResult<f32> {
let weight =
attention_projection_weight_matrix(state, selection.head_index, selection.kind)?;
Ok(weight[selection.input_index][selection.output_index])
}
fn attention_projection_weight_matrix(
state: &TransformerTrainingState,
head_index: usize,
kind: AttentionProjectionKind,
) -> CtResult<&[Vec<f32>]> {
let head =
state
.parameters()
.attention_heads()
.get(head_index)
.ok_or(CtError::OutOfRange {
kind: "attention head",
index: head_index,
limit: state.parameters().attention_heads().len(),
})?;
Ok(match kind {
AttentionProjectionKind::Query => head.query_projection().weight(),
AttentionProjectionKind::Key => head.key_projection().weight(),
AttentionProjectionKind::Value => head.value_projection().weight(),
})
}
fn largest_changed_attention_projection_bias(
before: &TransformerTrainingState,
after: &TransformerTrainingState,
) -> CtResult<AttentionProjectionBiasSelection> {
let head_count = before.parameters().attention_heads().len();
let mut selected = None;
let mut largest_delta = 0.0;
for head_index in 0..head_count {
for kind in [
AttentionProjectionKind::Query,
AttentionProjectionKind::Key,
AttentionProjectionKind::Value,
] {
let before_bias = attention_projection_bias_values(before, head_index, kind)?;
let after_bias = attention_projection_bias_values(after, head_index, kind)?;
for (output_index, (before_value, after_value)) in
before_bias.iter().zip(after_bias).enumerate()
{
let delta = (before_value - after_value).abs();
if delta > largest_delta {
largest_delta = delta;
selected = Some(AttentionProjectionBiasSelection {
head_index,
kind,
output_index,
});
}
}
}
}
selected.ok_or(CtError::EmptyInput("changed attention projection bias"))
}
fn attention_projection_bias(
state: &TransformerTrainingState,
selection: AttentionProjectionBiasSelection,
) -> CtResult<f32> {
let bias = attention_projection_bias_values(state, selection.head_index, selection.kind)?;
Ok(bias[selection.output_index])
}
fn attention_projection_bias_values(
state: &TransformerTrainingState,
head_index: usize,
kind: AttentionProjectionKind,
) -> CtResult<&[f32]> {
let head =
state
.parameters()
.attention_heads()
.get(head_index)
.ok_or(CtError::OutOfRange {
kind: "attention head",
index: head_index,
limit: state.parameters().attention_heads().len(),
})?;
Ok(match kind {
AttentionProjectionKind::Query => head.query_projection().bias(),
AttentionProjectionKind::Key => head.key_projection().bias(),
AttentionProjectionKind::Value => head.value_projection().bias(),
})
}
fn state_with_attention_projection_weight(
state: &TransformerTrainingState,
selection: AttentionProjectionSelection,
value: f32,
) -> CtResult<TransformerTrainingState> {
let mut heads = state.parameters().attention_heads().to_vec();
let head = heads.get(selection.head_index).ok_or(CtError::OutOfRange {
kind: "attention head",
index: selection.head_index,
limit: heads.len(),
})?;
let mut query_weight = head.query_projection().weight().to_vec();
let mut key_weight = head.key_projection().weight().to_vec();
let mut value_weight = head.value_projection().weight().to_vec();
match selection.kind {
AttentionProjectionKind::Query => {
query_weight[selection.input_index][selection.output_index] = value;
}
AttentionProjectionKind::Key => {
key_weight[selection.input_index][selection.output_index] = value;
}
AttentionProjectionKind::Value => {
value_weight[selection.input_index][selection.output_index] = value;
}
}
heads[selection.head_index] = SelfAttentionHead::new(
HiddenToQuery::new(query_weight, head.query_projection().bias().to_vec())?,
HiddenToKey::new(key_weight, head.key_projection().bias().to_vec())?,
HiddenToValue::new(value_weight, head.value_projection().bias().to_vec())?,
)?;
let parameters = state.parameters().clone().with_attention_heads(heads)?;
Ok(TransformerTrainingState::from_parts(
parameters,
state.learning_rate(),
state.step_count(),
))
}
fn state_with_attention_projection_bias(
state: &TransformerTrainingState,
selection: AttentionProjectionBiasSelection,
value: f32,
) -> CtResult<TransformerTrainingState> {
let mut heads = state.parameters().attention_heads().to_vec();
let head = heads.get(selection.head_index).ok_or(CtError::OutOfRange {
kind: "attention head",
index: selection.head_index,
limit: heads.len(),
})?;
let mut query_bias = head.query_projection().bias().to_vec();
let mut key_bias = head.key_projection().bias().to_vec();
let mut value_bias = head.value_projection().bias().to_vec();
match selection.kind {
AttentionProjectionKind::Query => {
query_bias[selection.output_index] = value;
}
AttentionProjectionKind::Key => {
key_bias[selection.output_index] = value;
}
AttentionProjectionKind::Value => {
value_bias[selection.output_index] = value;
}
}
heads[selection.head_index] = SelfAttentionHead::new(
HiddenToQuery::new(head.query_projection().weight().to_vec(), query_bias)?,
HiddenToKey::new(head.key_projection().weight().to_vec(), key_bias)?,
HiddenToValue::new(head.value_projection().weight().to_vec(), value_bias)?,
)?;
let parameters = state.parameters().clone().with_attention_heads(heads)?;
Ok(TransformerTrainingState::from_parts(
parameters,
state.learning_rate(),
state.step_count(),
))
}
fn tiny_single_head_block() -> CtResult<SingleHeadTransformerBlock> {
let query = HiddenToQuery::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?;
let key = HiddenToKey::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?;
let value = HiddenToValue::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?;
let output_projection =
AttentionOutputProjection::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?;
let model_dimension = ModelDimension::new(2)?;
let attention_norm =
LayerNormalization::new(LayerNormParameters::identity(model_dimension));
let feed_forward = PositionWiseFeedForward::new(
vec![vec![1.0, 0.0], vec![0.0, 1.0]],
vec![0.0, 0.0],
vec![vec![1.0, 0.0], vec![0.0, 1.0]],
vec![0.0, 0.0],
)?;
let feed_forward_norm =
LayerNormalization::new(LayerNormParameters::identity(model_dimension));
SingleHeadTransformerBlock::new(
query,
key,
value,
output_projection,
attention_norm,
feed_forward,
feed_forward_norm,
)
}
fn tiny_multi_head_block() -> CtResult<MultiHeadTransformerBlock> {
let model_dimension = ModelDimension::new(2)?;
MultiHeadTransformerBlock::new(
vec![
tiny_self_attention_head_first_feature()?,
tiny_self_attention_head_second_feature()?,
],
AttentionOutputProjection::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?,
LayerNormalization::new(LayerNormParameters::identity(model_dimension)),
identity_feed_forward()?,
LayerNormalization::new(LayerNormParameters::identity(model_dimension)),
)
}
fn tiny_masked_multi_head_block() -> CtResult<MaskedMultiHeadTransformerBlock> {
let model_dimension = ModelDimension::new(2)?;
MaskedMultiHeadTransformerBlock::new(
vec![
tiny_self_attention_head_first_feature()?,
tiny_self_attention_head_second_feature()?,
],
AttentionOutputProjection::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?,
LayerNormalization::new(LayerNormParameters::identity(model_dimension)),
identity_feed_forward()?,
LayerNormalization::new(LayerNormParameters::identity(model_dimension)),
)
}
fn tiny_transformer_parameters() -> CtResult<TinyTransformerParameters> {
TinyTransformerParameters::new(
PositionalEncoding::new(vec![vec![0.1, 0.0], vec![0.0, 0.1]])?,
tiny_masked_multi_head_block()?,
TransformerReadout::new(
vec![vec![1.0, 0.0, 0.5], vec![0.0, 1.0, -0.5]],
vec![0.0, 0.0, 0.0],
)?,
)
}
fn tiny_transformer_training_set() -> CtResult<TransformerReadoutTrainingSet> {
let example = TransformerReadoutTrainingExample::new(
HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?,
AttentionMask::new(vec![vec![true, false], vec![true, true]])?,
TokenSequence::from_indices([0, 1])?,
)?;
TransformerReadoutTrainingSet::new([example])
}
fn tiny_feed_forward_training_set() -> CtResult<TransformerFeedForwardTrainingSet> {
let example = TransformerFeedForwardTrainingExample::new(
HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?,
HiddenSequence::new(vec![vec![2.0, 0.0], vec![0.0, 2.0]])?,
)?;
TransformerFeedForwardTrainingSet::new([example])
}
fn tiny_transformer_block_training_set() -> CtResult<TransformerBlockTrainingSet> {
let example = TransformerBlockTrainingExample::new(
HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?,
AttentionMask::new(vec![vec![true, false], vec![true, true]])?,
TokenSequence::from_indices([0, 1])?,
)?;
TransformerBlockTrainingSet::new([example])
}
fn tiny_self_attention_head_first_feature() -> CtResult<SelfAttentionHead> {
SelfAttentionHead::new(
HiddenToQuery::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
HiddenToKey::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
HiddenToValue::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
)
}
fn tiny_self_attention_head_second_feature() -> CtResult<SelfAttentionHead> {
SelfAttentionHead::new(
HiddenToQuery::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
HiddenToKey::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
HiddenToValue::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
)
}
fn identity_feed_forward() -> CtResult<PositionWiseFeedForward> {
PositionWiseFeedForward::new(
vec![vec![1.0, 0.0], vec![0.0, 1.0]],
vec![0.0, 0.0],
vec![vec![1.0, 0.0], vec![0.0, 1.0]],
vec![0.0, 0.0],
)
}
}
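The helpers above repeat one pattern for every parameter family: readout, feed-forward, layer norm, output projection, and attention projections. Each helper scans a before/after pair for the coordinate whose absolute change is largest. Its companion helper then rebuilds a state with that single coordinate overwritten. The sketch below isolates the pattern on plain Rust matrices.

It is a hypothetical stand-alone version, not the crate's API. It uses `Vec<Vec<f32>>` in place of `TransformerTrainingState` and returns `Option` instead of `CtResult`, so `None` plays the role of `CtError::EmptyInput`. The "after" values are made up for illustration.

```rust
// Hypothetical stand-alone sketch of the "largest changed coordinate" pattern.
// Assumption: plain Vec<Vec<f32>> matrices stand in for the crate's parameter
// types, and Option stands in for CtResult (None ~ CtError::EmptyInput).

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct MatrixSelection {
    input_index: usize,
    output_index: usize,
}

/// Find the coordinate whose absolute change between `before` and `after` is largest.
/// The strict `>` keeps the first coordinate (row-major) on ties and
/// returns `None` when no coordinate changed at all.
fn largest_changed(before: &[Vec<f32>], after: &[Vec<f32>]) -> Option<MatrixSelection> {
    let mut selected = None;
    let mut largest_delta = 0.0_f32;
    for (input_index, (before_row, after_row)) in before.iter().zip(after).enumerate() {
        for (output_index, (b, a)) in before_row.iter().zip(after_row).enumerate() {
            let delta = (b - a).abs();
            if delta > largest_delta {
                largest_delta = delta;
                selected = Some(MatrixSelection { input_index, output_index });
            }
        }
    }
    selected
}

/// Return a copy of `matrix` with one coordinate overwritten; the input is not mutated.
fn with_weight(matrix: &[Vec<f32>], selection: MatrixSelection, value: f32) -> Vec<Vec<f32>> {
    let mut copy = matrix.to_vec();
    copy[selection.input_index][selection.output_index] = value;
    copy
}

fn main() {
    // Same 2 x 3 shape as the tiny readout weight used in the tests.
    let before: Vec<Vec<f32>> = vec![vec![1.0, 0.0, 0.5], vec![0.0, 1.0, -0.5]];
    // Made-up "after one step" values, chosen only for illustration.
    let after: Vec<Vec<f32>> = vec![vec![0.9, 0.0, 0.5], vec![0.0, 1.3, -0.5]];

    let selection = largest_changed(&before, &after).expect("some weight changed");
    println!("largest change at {selection:?}");

    // Restore the selected coordinate; the next-largest change becomes the winner.
    let original = before[selection.input_index][selection.output_index];
    let restored = with_weight(&after, selection, original);
    println!("after restoring: {:?}", largest_changed(&before, &restored));

    // Identical matrices have no changed coordinate.
    assert_eq!(largest_changed(&before, &before), None);
}
```

Because the comparison is strict, an unchanged pair yields `None` rather than an arbitrary index. Writing into a copy mirrors the `state_with_*` helpers. Those helpers take the state by reference and return a freshly built `TransformerTrainingState`, leaving the original untouched.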
The full runnable companion is:
Source snapshot: examples/06_attention_scores.rs
use category_theory_transformer_rs::{
AttentionHeadOutputs, AttentionMask, AttentionOutput, AttentionOutputProjection,
AttentionSoftmax, ConcatenateHeads, CtResult, HiddenSequence, HiddenToKey, HiddenToQuery,
HiddenToValue, KeySequence, LayerNormParameters, LayerNormalization, LearningRate,
MaskedAttentionScores, MaskedMultiHeadTransformerBlock, Morphism, MultiHeadTransformerBlock,
PositionWiseFeedForward, PositionalEncoding, Product, QuerySequence, ResidualConnection,
ScaledDotProductScores, SelfAttentionHead, SingleHeadTransformerBlock,
TinyTransformerParameters, TokenSequence, TransformerBlockTrainStep,
TransformerBlockTrainingExample, TransformerBlockTrainingSet, TransformerFeedForwardTrainStep,
TransformerFeedForwardTrainingExample, TransformerFeedForwardTrainingSet, TransformerReadout,
TransformerReadoutTrainStep, TransformerReadoutTrainingExample, TransformerReadoutTrainingSet,
TransformerTrainingState, ValueSequence, WeightedValueMixing, transformer_block_average_loss,
transformer_feed_forward_average_loss, transformer_readout_average_loss,
};
fn main() -> CtResult<()> {
let queries = QuerySequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
let keys = KeySequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]])?;
let values = ValueSequence::new(vec![vec![1.0, 10.0], vec![2.0, 20.0], vec![3.0, 30.0]])?;
let mask = AttentionMask::new(vec![vec![true, false, true], vec![true, true, true]])?;
println!("Q/K/V source diagnostic:");
println!("query rows own score rows; key/value rows own score columns");
println!(
"self-attention shares the hidden source before projection; projected roles stay distinct"
);
println!("mask polarity here: true = allowed, false = blocked\n");
let scores = ScaledDotProductScores.apply(Product::new(queries, keys))?;
let masked_scores = MaskedAttentionScores.apply(Product::new(scores, mask))?;
let weights = AttentionSoftmax.apply(masked_scores)?;
println!(
"attention shape: {} query positions x {} key positions",
weights.query_len().value(),
weights.key_len().value()
);
for (query_position, row) in weights.rows().iter().enumerate() {
println!("query {query_position} attends with {:?}", row.as_slice());
}
let output = WeightedValueMixing.apply(Product::new(weights, values))?;
for (query_position, row) in output.rows().iter().enumerate() {
println!("query {query_position} output vector {:?}", row.as_slice());
}
let second_head = AttentionOutput::new(vec![vec![10.0, 1.0], vec![20.0, 2.0]])?;
let head_outputs = AttentionHeadOutputs::new(vec![output, second_head])?;
let multi_head = ConcatenateHeads.apply(head_outputs)?;
println!(
"multi-head shape: {} heads x {} features -> model dimension {}",
multi_head.head_count().value(),
multi_head.head_dimension().value(),
multi_head.model_dimension().value()
);
for (query_position, row) in multi_head.rows().iter().enumerate() {
println!("query {query_position} multi-head row {:?}", row.as_slice());
}
let output_projection = AttentionOutputProjection::new(
vec![
vec![1.0, 0.0],
vec![0.0, 0.1],
vec![0.5, 0.0],
vec![0.0, 1.0],
],
vec![0.0, 0.0],
)?;
let projected = output_projection.apply(multi_head)?;
println!(
"projected attention shape: {} positions x model dimension {}",
projected.sequence_len().value(),
projected.model_dimension().value()
);
for (query_position, row) in projected.rows().iter().enumerate() {
println!(
"query {query_position} projected attention row {:?}",
row.as_slice()
);
}
let hidden_input = HiddenSequence::new(vec![vec![0.5, 0.5], vec![1.0, 1.0]])?;
let residual = ResidualConnection.apply(Product::new(hidden_input, projected))?;
println!(
"residual shape: {} positions x model dimension {}",
residual.sequence_len().value(),
residual.model_dimension().value()
);
for (query_position, row) in residual.rows().iter().enumerate() {
println!("query {query_position} residual row {:?}", row.as_slice());
}
let layer_norm =
LayerNormalization::new(LayerNormParameters::identity(residual.model_dimension()));
let normalized = layer_norm.apply(residual)?;
println!(
"normalized shape: {} positions x model dimension {}",
normalized.sequence_len().value(),
normalized.model_dimension().value()
);
for (query_position, row) in normalized.rows().iter().enumerate() {
println!("query {query_position} normalized row {:?}", row.as_slice());
}
let feed_forward = PositionWiseFeedForward::new(
vec![vec![1.0, -1.0, 0.5], vec![0.0, 1.0, 0.5]],
vec![0.0, 0.0, 0.0],
vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![0.5, 0.5]],
vec![0.0, 0.0],
)?;
let fed_forward = feed_forward.apply(normalized)?;
println!(
"feed-forward shape: {} positions x model dimension {}",
fed_forward.sequence_len().value(),
fed_forward.model_dimension().value()
);
for (query_position, row) in fed_forward.rows().iter().enumerate() {
println!(
"query {query_position} feed-forward row {:?}",
row.as_slice()
);
}
let positional_encoding = PositionalEncoding::new(vec![vec![0.1, 0.0], vec![0.0, 0.1]])?;
let positioned_hidden =
positional_encoding.apply(HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?)?;
println!(
"positioned hidden shape: {} positions x model dimension {}",
positioned_hidden.sequence_len().value(),
positioned_hidden.model_dimension().value()
);
let block = SingleHeadTransformerBlock::new(
HiddenToQuery::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?,
HiddenToKey::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?,
HiddenToValue::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?,
AttentionOutputProjection::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?,
LayerNormalization::new(LayerNormParameters::identity(fed_forward.model_dimension())),
PositionWiseFeedForward::new(
vec![vec![1.0, 0.0], vec![0.0, 1.0]],
vec![0.0, 0.0],
vec![vec![1.0, 0.0], vec![0.0, 1.0]],
vec![0.0, 0.0],
)?,
LayerNormalization::new(LayerNormParameters::identity(fed_forward.model_dimension())),
)?;
let block_output = block.apply(positioned_hidden.clone())?;
println!(
"single-head block shape: {} positions x model dimension {}",
block_output.sequence_len().value(),
block_output.model_dimension().value()
);
let multi_head_block = MultiHeadTransformerBlock::new(
vec![
SelfAttentionHead::new(
HiddenToQuery::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
HiddenToKey::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
HiddenToValue::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
)?,
SelfAttentionHead::new(
HiddenToQuery::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
HiddenToKey::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
HiddenToValue::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
)?,
],
AttentionOutputProjection::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?,
LayerNormalization::new(LayerNormParameters::identity(
block_output.model_dimension(),
)),
PositionWiseFeedForward::new(
vec![vec![1.0, 0.0], vec![0.0, 1.0]],
vec![0.0, 0.0],
vec![vec![1.0, 0.0], vec![0.0, 1.0]],
vec![0.0, 0.0],
)?,
LayerNormalization::new(LayerNormParameters::identity(
block_output.model_dimension(),
)),
)?;
let multi_head_output = multi_head_block.apply(positioned_hidden)?;
println!(
"multi-head block shape: {} positions x {} heads x value dimension {} -> model dimension {}",
multi_head_output.sequence_len().value(),
multi_head_block.head_count().value(),
multi_head_block.value_dimension().value(),
multi_head_output.model_dimension().value()
);
let masked_multi_head_block = MaskedMultiHeadTransformerBlock::new(
vec![
SelfAttentionHead::new(
HiddenToQuery::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
HiddenToKey::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
HiddenToValue::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
)?,
SelfAttentionHead::new(
HiddenToQuery::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
HiddenToKey::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
HiddenToValue::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
)?,
],
AttentionOutputProjection::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?,
LayerNormalization::new(LayerNormParameters::identity(
multi_head_output.model_dimension(),
)),
PositionWiseFeedForward::new(
vec![vec![1.0, 0.0], vec![0.0, 1.0]],
vec![0.0, 0.0],
vec![vec![1.0, 0.0], vec![0.0, 1.0]],
vec![0.0, 0.0],
)?,
LayerNormalization::new(LayerNormParameters::identity(
multi_head_output.model_dimension(),
)),
)?;
let masked_block_output = masked_multi_head_block.apply(Product::new(
HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?,
AttentionMask::new(vec![vec![true, false], vec![true, true]])?,
))?;
println!(
"masked multi-head block shape: {} positions x model dimension {}",
masked_block_output.sequence_len().value(),
masked_block_output.model_dimension().value()
);
let transformer_parameters = TinyTransformerParameters::new(
PositionalEncoding::new(vec![vec![0.1, 0.0], vec![0.0, 0.1]])?,
masked_multi_head_block,
TransformerReadout::new(
vec![vec![1.0, 0.0, 0.5], vec![0.0, 1.0, -0.5]],
vec![0.0, 0.0, 0.0],
)?,
)?;
let transformer_state =
TransformerTrainingState::new(transformer_parameters, LearningRate::new(0.1)?);
let training_set =
TransformerReadoutTrainingSet::new([TransformerReadoutTrainingExample::new(
HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?,
AttentionMask::new(vec![vec![true, false], vec![true, true]])?,
TokenSequence::from_indices([0, 1])?,
)?])?;
let sequence_logits = transformer_state.apply(Product::new(
HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?,
AttentionMask::new(vec![vec![true, false], vec![true, true]])?,
))?;
let loss_before = transformer_readout_average_loss(&transformer_state, &training_set)?;
let train_step = TransformerReadoutTrainStep::new(training_set.clone());
let next_state = train_step.apply(transformer_state.clone())?;
let loss_after = transformer_readout_average_loss(&next_state, &training_set)?;
println!(
"structured transformer logits shape: {} positions x vocabulary size {}",
sequence_logits.sequence_len().value(),
sequence_logits.vocab_size().value()
);
println!(
"training state step: {} -> {}",
transformer_state.step_count().value(),
next_state.step_count().value()
);
println!(
"readout loss after one update: {:.6} -> {:.6}",
loss_before.value(),
loss_after.value()
);
let feed_forward_training_set =
TransformerFeedForwardTrainingSet::new([TransformerFeedForwardTrainingExample::new(
HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?,
HiddenSequence::new(vec![vec![2.0, 0.0], vec![0.0, 2.0]])?,
)?])?;
let feed_forward_loss_before =
transformer_feed_forward_average_loss(&next_state, &feed_forward_training_set)?;
let feed_forward_train_step =
TransformerFeedForwardTrainStep::new(feed_forward_training_set.clone());
let feed_forward_state = feed_forward_train_step.apply(next_state.clone())?;
let feed_forward_loss_after =
transformer_feed_forward_average_loss(&feed_forward_state, &feed_forward_training_set)?;
println!(
"feed-forward loss after one local update: {:.6} -> {:.6}",
feed_forward_loss_before.value(),
feed_forward_loss_after.value()
);
let block_training_set =
TransformerBlockTrainingSet::new([TransformerBlockTrainingExample::new(
HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?,
AttentionMask::new(vec![vec![true, false], vec![true, true]])?,
TokenSequence::from_indices([0, 1])?,
)?])?;
let block_loss_before =
transformer_block_average_loss(&feed_forward_state, &block_training_set)?;
let block_train_step = TransformerBlockTrainStep::new(block_training_set.clone());
let block_trained_state = block_train_step.apply(feed_forward_state)?;
let block_loss_after =
transformer_block_average_loss(&block_trained_state, &block_training_set)?;
println!(
"block loss after one composed update: {:.6} -> {:.6}",
block_loss_before.value(),
block_loss_after.value()
);
println!();
println!("Typed transformation:");
println!("HiddenSequence -> QuerySequence");
println!("HiddenSequence -> KeySequence");
println!("HiddenSequence -> ValueSequence");
println!("QuerySequence x KeySequence -> AttentionScores");
println!("AttentionScores x AttentionMask -> AttentionScores");
println!("AttentionScores -> AttentionWeights");
println!("AttentionWeights x ValueSequence -> AttentionOutput");
println!("AttentionHeadOutputs -> MultiHeadOutput");
println!("MultiHeadOutput -> ProjectedAttentionOutput");
println!("HiddenSequence x ProjectedAttentionOutput -> HiddenSequence");
println!("LayerNormalization : HiddenSequence -> HiddenSequence");
println!("PositionWiseFeedForward : HiddenSequence -> HiddenSequence");
println!("PositionalEncoding : HiddenSequence -> HiddenSequence");
println!("SingleHeadTransformerBlock : HiddenSequence -> HiddenSequence");
println!("MultiHeadTransformerBlock : HiddenSequence -> HiddenSequence");
println!("MaskedMultiHeadTransformerBlock : HiddenSequence x AttentionMask -> HiddenSequence");
println!("TinyTransformerParameters : HiddenSequence x AttentionMask -> SequenceLogits");
println!("TransformerTrainingState owns parameters, learning rate, and step count");
println!("TransformerReadoutTrainStep : TransformerTrainingState -> TransformerTrainingState");
println!(
"TransformerFeedForwardTrainStep : TransformerTrainingState -> TransformerTrainingState"
);
println!("TransformerBlockTrainStep : TransformerTrainingState -> TransformerTrainingState");
Ok(())
}
The smaller state-only companion is:
Source snapshot: examples/07_transformer_training_state.rs
use category_theory_transformer_rs::{
AttentionMask, AttentionOutputProjection, CtResult, HiddenSequence, HiddenToKey, HiddenToQuery,
HiddenToValue, LayerNormParameters, LayerNormalization, LearningRate,
MaskedMultiHeadTransformerBlock, ModelDimension, Morphism, PositionWiseFeedForward,
PositionalEncoding, Product, SelfAttentionHead, TinyTransformerParameters, TokenSequence,
TransformerBlockTrainStep, TransformerBlockTrainingExample, TransformerBlockTrainingSet,
TransformerFeedForwardTrainStep, TransformerFeedForwardTrainingExample,
TransformerFeedForwardTrainingSet, TransformerReadout, TransformerReadoutTrainStep,
TransformerReadoutTrainingExample, TransformerReadoutTrainingSet, TransformerTrainingState,
transformer_block_average_loss, transformer_feed_forward_average_loss,
transformer_readout_average_loss,
};
fn main() -> CtResult<()> {
let hidden = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
let mask = AttentionMask::new(vec![vec![true, false], vec![true, true]])?;
let targets = TokenSequence::from_indices([0, 1])?;
let initial_state = tiny_training_state()?;
let logits = initial_state.apply(Product::new(hidden.clone(), mask.clone()))?;
println!(
"initial state: step={}, learning_rate={:.3}, model_dimension={}, vocab_size={}",
initial_state.step_count().value(),
initial_state.learning_rate().value(),
initial_state.parameters().model_dimension().value(),
initial_state.parameters().vocab_size().value()
);
println!(
"forward shape: {} positions x vocabulary size {}",
logits.sequence_len().value(),
logits.vocab_size().value()
);
let readout_set =
TransformerReadoutTrainingSet::new([TransformerReadoutTrainingExample::new(
hidden.clone(),
mask.clone(),
targets.clone(),
)?])?;
let readout_loss_before = transformer_readout_average_loss(&initial_state, &readout_set)?;
let readout_state =
TransformerReadoutTrainStep::new(readout_set.clone()).apply(initial_state)?;
let readout_loss_after = transformer_readout_average_loss(&readout_state, &readout_set)?;
print_update(
"readout update",
0,
&readout_state,
readout_loss_before.value(),
readout_loss_after.value(),
);
let feed_forward_set =
TransformerFeedForwardTrainingSet::new([TransformerFeedForwardTrainingExample::new(
hidden.clone(),
HiddenSequence::new(vec![vec![2.0, 0.0], vec![0.0, 2.0]])?,
)?])?;
let feed_forward_loss_before =
transformer_feed_forward_average_loss(&readout_state, &feed_forward_set)?;
let feed_forward_state =
TransformerFeedForwardTrainStep::new(feed_forward_set.clone()).apply(readout_state)?;
let feed_forward_loss_after =
transformer_feed_forward_average_loss(&feed_forward_state, &feed_forward_set)?;
print_update(
"feed-forward update",
1,
&feed_forward_state,
feed_forward_loss_before.value(),
feed_forward_loss_after.value(),
);
let block_set = TransformerBlockTrainingSet::new([TransformerBlockTrainingExample::new(
hidden, mask, targets,
)?])?;
let block_loss_before = transformer_block_average_loss(&feed_forward_state, &block_set)?;
let block_state =
TransformerBlockTrainStep::new(block_set.clone()).apply(feed_forward_state)?;
let block_loss_after = transformer_block_average_loss(&block_state, &block_set)?;
print_update(
"composed block update",
2,
&block_state,
block_loss_before.value(),
block_loss_after.value(),
);
println!();
println!("Typed transformation:");
println!("TinyTransformerParameters : HiddenSequence x AttentionMask -> SequenceLogits");
println!("TransformerTrainingState owns parameters, learning rate, and step count");
println!("TransformerReadoutTrainStep : TransformerTrainingState -> TransformerTrainingState");
println!(
"TransformerFeedForwardTrainStep : TransformerTrainingState -> TransformerTrainingState"
);
println!("TransformerBlockTrainStep : TransformerTrainingState -> TransformerTrainingState");
println!("Every update returns a full training state, not loose changed weights.");
Ok(())
}
fn print_update(
label: &str,
previous_step: usize,
state: &TransformerTrainingState,
loss_before: f32,
loss_after: f32,
) {
println!(
"{label}: step {} -> {}, loss {:.6} -> {:.6}",
previous_step,
state.step_count().value(),
loss_before,
loss_after
);
}
fn tiny_training_state() -> CtResult<TransformerTrainingState> {
let model_dimension = ModelDimension::new(2)?;
let block = MaskedMultiHeadTransformerBlock::new(
vec![
SelfAttentionHead::new(
HiddenToQuery::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
HiddenToKey::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
HiddenToValue::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
)?,
SelfAttentionHead::new(
HiddenToQuery::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
HiddenToKey::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
HiddenToValue::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
)?,
],
AttentionOutputProjection::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?,
LayerNormalization::new(LayerNormParameters::identity(model_dimension)),
PositionWiseFeedForward::new(
vec![vec![1.0, 0.0], vec![0.0, 1.0]],
vec![0.0, 0.0],
vec![vec![1.0, 0.0], vec![0.0, 1.0]],
vec![0.0, 0.0],
)?,
LayerNormalization::new(LayerNormParameters::identity(model_dimension)),
)?;
let parameters = TinyTransformerParameters::new(
PositionalEncoding::new(vec![vec![0.1, 0.0], vec![0.0, 0.1]])?,
block,
TransformerReadout::new(
vec![vec![1.0, 0.0, 0.5], vec![0.0, 1.0, -0.5]],
vec![0.0, 0.0, 0.0],
)?,
)?;
Ok(TransformerTrainingState::new(
parameters,
LearningRate::new(0.1)?,
))
}
Run the full attention example with:
cargo run --example 06_attention_scores
If you only want to inspect the training-state update shape, run the smaller companion example:
cargo run --example 07_transformer_training_state
You should see two query positions and three key positions. Query and key vectors first produce score rows. The mask removes one illegal score position. Then each row is normalized independently, and the weights mix value vectors:
QuerySequence x KeySequence -> AttentionScores
AttentionScores x AttentionMask -> AttentionScores
AttentionScores -> AttentionWeights
AttentionWeights x ValueSequence -> AttentionOutput
AttentionHeadOutputs -> MultiHeadOutput
MultiHeadOutput -> ProjectedAttentionOutput
HiddenSequence x ProjectedAttentionOutput -> HiddenSequence
HiddenSequence -> HiddenSequence
That is the bridge from the current softmax chapter. The current Distribution
answers:
which next token is likely?
Attention weights answer:
which source positions should this position read from?
Both are probability-like objects. They differ in what their support means.
ML Concept
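Attention computes:
scores = QK^T / sqrt(d)
weights = softmax(scores)
output = weights V
Here d is the query/key head dimension. Softmax runs independently on each score row, after masking when a mask is present. This is softmax again, but applied to token-to-token interaction scores.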
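The three formulas can be written without a framework. The following is a minimal sketch under stated assumptions. Plain `Vec<Vec<f32>>` rows stand in for the typed sequences, and `masked_attention` and `softmax_row` are hypothetical helpers, not the repository's API. The sketch checks shapes first. It then sets masked scores to negative infinity so they receive zero weight, and finally mixes the value rows.

```rust
// Sketch of masked scaled dot-product attention over plain rows.
// `masked_attention` and `softmax_row` are hypothetical, not the repository API.

/// Row-wise softmax; entries set to -inf (masked) receive weight 0.
fn softmax_row(scores: &[f32]) -> Vec<f32> {
    let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

fn masked_attention(
    queries: &[Vec<f32>],
    keys: &[Vec<f32>],
    values: &[Vec<f32>],
    mask: &[Vec<bool>],
) -> Result<Vec<Vec<f32>>, String> {
    let d = queries.first().ok_or("empty query sequence")?.len();
    let value_dim = values.first().ok_or("empty value sequence")?.len();
    // Reject bad shapes before any score is computed.
    if queries.iter().chain(keys).any(|row| row.len() != d) {
        return Err("query/key head dimension mismatch".into());
    }
    if keys.len() != values.len() || values.iter().any(|row| row.len() != value_dim) {
        return Err("value sequence does not match keys".into());
    }
    if mask.len() != queries.len() || mask.iter().any(|row| row.len() != keys.len()) {
        return Err("mask shape must be queries x keys".into());
    }
    let scale = (d as f32).sqrt();
    let mut output = Vec::with_capacity(queries.len());
    for (q, mask_row) in queries.iter().zip(mask) {
        if !mask_row.contains(&true) {
            return Err("a query position has no visible key".into());
        }
        // scores = q . k / sqrt(d), with illegal positions set to -inf.
        let scores: Vec<f32> = keys
            .iter()
            .zip(mask_row)
            .map(|(k, &allowed)| {
                if allowed {
                    q.iter().zip(k).map(|(a, b)| a * b).sum::<f32>() / scale
                } else {
                    f32::NEG_INFINITY
                }
            })
            .collect();
        // Output row = weights V: a weighted mix of the value rows.
        let mut mixed = vec![0.0f32; value_dim];
        for (w, v) in softmax_row(&scores).iter().zip(values) {
            for (m, x) in mixed.iter_mut().zip(v) {
                *m += w * x;
            }
        }
        output.push(mixed);
    }
    Ok(output)
}

fn main() {
    let q = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
    let v = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
    let causal = vec![vec![true, false], vec![true, true]];
    // The same rows serve as queries and keys in this toy run.
    let out = masked_attention(&q, &q, &v, &causal).unwrap();
    // Query 0 may only read key 0, so it copies value row 0.
    assert_eq!(out[0], vec![1.0, 2.0]);
    // A mask with the wrong shape is rejected before softmax.
    assert!(masked_attention(&q, &q, &v, &[vec![true]]).is_err());
    println!("{:?}", out);
}
```

The sketch follows the design contract below: every shape check runs before any score is computed. A mismatched mask therefore fails outright instead of silently pairing the wrong key.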
Category Theory Concept
The attention block is a composition of typed maps with a product input:
(Q, K, V) -> scores -> weights -> mixed values
Design contract:
Attention should have a positive test showing that valid shapes compose and a negative test showing that mismatched head dimensions, mask shapes, or value lengths are rejected at construction or composition time.
Step 4: Multi-Head Concatenation
The current problem:
One attention head sees one interaction pattern. Multiple heads let the model carry several patterns in parallel. Their outputs must be recombined without losing shape information.
Rust Syntax
The recombination boundary is:
AttentionHeadOutputs -> MultiHeadOutput
HeadCount rejects zero. AttentionHeadOutputs rejects an empty collection,
sequence-length mismatches, and head-dimension mismatches. MultiHeadOutput
records:
sequence length
head count
head dimension
model dimension
This boundary is not the whole block by itself. It is the place where separate head outputs become one wider object before the output projection.
Worked Example: Concatenate Heads, Then Project
Multi-head attention adds one shape calculation that readers should be able to do without a framework:
model_dimension = head_count * head_dimension
In the runnable attention example, the first head has two output features per query position:
head_0 query_0 = [2.0, 20.0]
head_0 query_1 = [2.2033, 22.0334]
The example then adds a second head with the same sequence length and head dimension:
head_1 query_0 = [10.0, 1.0]
head_1 query_1 = [20.0, 2.0]
Concatenation does not average the heads. It places their feature rows side by side:
query_0 multi-head row = [2.0, 20.0, 10.0, 1.0]
query_1 multi-head row = [2.2033, 22.0334, 20.0, 2.0]
The recorded shape is:
2 heads x 2 features = model dimension 4
The next boundary is the learned output projection:
MultiHeadOutput -> ProjectedAttentionOutput
For the first query row, the tiny projection in the example uses:
input row:
[2.0, 20.0, 10.0, 1.0]
projection rows:
[1.0, 0.0]
[0.0, 0.1]
[0.5, 0.0]
[0.0, 1.0]
The projected row is:
first output feature = 2.0 * 1.0 + 20.0 * 0.0 + 10.0 * 0.5 + 1.0 * 0.0 = 7.0
second output feature = 2.0 * 0.0 + 20.0 * 0.1 + 10.0 * 0.0 + 1.0 * 1.0 = 3.0
projected query_0 = [7.0, 3.0]
This is the reason the repository keeps two separate boundaries:
AttentionHeadOutputs -> MultiHeadOutput
MultiHeadOutput -> ProjectedAttentionOutput
The first boundary proves that head outputs can be concatenated. The second boundary proves that the concatenated width matches the projection input width. If either relationship fails, the model should fail at the boundary, before a later residual connection receives the wrong shape.
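The two boundaries can be written as two separate functions over plain rows. This is a minimal sketch that assumes the weight layout used in the worked example: one projection row per concatenated input feature. `concat_heads` and `project` are hypothetical helpers, not the repository's `AttentionHeadOutputs` or `AttentionOutputProjection`. The sketch reproduces the hand-computed first row.

```rust
// Sketch of concat-then-project over plain rows. Hypothetical helpers,
// not the repository's AttentionHeadOutputs / AttentionOutputProjection.

fn concat_heads(heads: &[Vec<Vec<f32>>]) -> Result<Vec<Vec<f32>>, String> {
    let first = heads.first().ok_or("no heads")?;
    let head_dim = first.first().map_or(0, |row| row.len());
    for head in heads {
        if head.len() != first.len() {
            return Err("sequence-length mismatch".into());
        }
        if head.iter().any(|row| row.len() != head_dim) {
            return Err("head-dimension mismatch".into());
        }
    }
    // Place head rows side by side: width = head_count * head_dimension.
    Ok((0..first.len())
        .map(|pos| heads.iter().flat_map(|h| h[pos].iter().copied()).collect::<Vec<f32>>())
        .collect())
}

// `weights` has one row per input feature; each row holds one entry per output feature.
fn project(rows: &[Vec<f32>], weights: &[Vec<f32>], bias: &[f32]) -> Result<Vec<Vec<f32>>, String> {
    if weights.iter().any(|w| w.len() != bias.len()) {
        return Err("weight rows must match the output dimension".into());
    }
    if rows.iter().any(|r| r.len() != weights.len()) {
        return Err("concatenated width does not match projection input".into());
    }
    Ok(rows
        .iter()
        .map(|r| {
            (0..bias.len())
                .map(|j| bias[j] + r.iter().zip(weights).map(|(x, w)| x * w[j]).sum::<f32>())
                .collect::<Vec<f32>>()
        })
        .collect())
}

fn main() {
    let head_0 = vec![vec![2.0, 20.0], vec![2.2033, 22.0334]];
    let head_1 = vec![vec![10.0, 1.0], vec![20.0, 2.0]];
    let multi = concat_heads(&[head_0, head_1]).unwrap();
    assert_eq!(multi[0], vec![2.0, 20.0, 10.0, 1.0]);
    let w = vec![vec![1.0, 0.0], vec![0.0, 0.1], vec![0.5, 0.0], vec![0.0, 1.0]];
    let projected = project(&multi, &w, &[0.0, 0.0]).unwrap();
    // First row is approximately [7.0, 3.0], as in the hand calculation.
    assert!((projected[0][0] - 7.0).abs() < 1e-5 && (projected[0][1] - 3.0).abs() < 1e-5);
    // Projecting with a two-row weight table fails at the boundary.
    assert!(project(&multi, &w[..2], &[0.0, 0.0]).is_err());
    println!("{:?}", projected);
}
```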
ML Concept
Each head performs attention separately.
The outputs are concatenated, then an output projection maps the combined row back into the hidden-state width. This repository now implements that projection as a separate typed boundary.
Category Theory Concept
This is parallel composition followed by recombination:
head_1 x head_2 x ... x head_n -> MultiHeadOutput
Design contract:
HeadCount, HeadDimension, and ModelDimension are not bare usize values.
The arithmetic relationship between them is part of the architecture: if there
are two heads of width two, the concatenated model dimension is four.
Step 5: Output Projection
The current problem:
Concatenated heads are wider than a single head. The residual connection and any later Transformer block expect the hidden-sequence width again.
Rust Syntax
The current code models the projection as:
MultiHeadOutput -> ProjectedAttentionOutput
AttentionOutputProjection validates:
non-empty weight rows
non-empty bias
finite weight and bias values
weight rows matching output dimension
input dimension matching MultiHeadOutput model dimension
That shape follows the multi-head attention reference path: heads are concatenated, then another learned linear projection produces the output sequence.
ML Concept
The projection is a learned linear map after concatenation. It lets the model mix features across heads and return to the width expected by the surrounding block.
Category Theory Concept
This is another typed morphism:
MultiHeadOutput -> ProjectedAttentionOutput
Design contract:
The projection should fail before multiplication if the concatenated head width does not match the projection’s input dimension. That is a boundary invariant, not an indexing accident.
Step 6: Residual Addition
The current problem:
Transformer sublayers need to add their output back to the hidden sequence they received.
Rust Syntax
The current code models residual addition as:
HiddenSequence x ProjectedAttentionOutput -> HiddenSequence
ResidualConnection rejects sequence-length mismatches and model-dimension
mismatches before adding rows. This follows the Transformer requirement that a
sublayer output must have the same dimension as its input for residual addition
to be feasible.
Worked Example: Residual Addition Needs The Same Shape
The previous worked example ended with a projected attention row:
projected query_0 = [7.0, 3.0]
Residual addition can only happen because that projected row has the same model dimension as the hidden row it will be added to:
hidden query_0 = [0.5, 0.5]
projected query_0 = [7.0, 3.0]
residual query_0 = [7.5, 3.5]
The second row follows the same rule:
hidden query_1 = [1.0, 1.0]
projected query_1 = [12.2033, 4.2033]
residual query_1 = [13.2033, 5.2033]
The shape is preserved:
HiddenSequence:
2 positions x model dimension 2
ProjectedAttentionOutput:
2 positions x model dimension 2
Residual HiddenSequence:
2 positions x model dimension 2
This is why the output projection matters. Multi-head concatenation produced a four-feature row. Residual addition needs a two-feature row because the hidden sequence has model dimension two in this tiny example. The projection is the bridge that makes the residual boundary legal.
The invalid shortcut would be:
HiddenSequence x MultiHeadOutput -> HiddenSequence
For the example above, that would try to add a two-feature hidden row to a four-feature concatenated row. The repository avoids that by making the legal path explicit:
MultiHeadOutput -> ProjectedAttentionOutput
HiddenSequence x ProjectedAttentionOutput -> HiddenSequence
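A shape-checked residual boundary takes only a few lines. `residual_add` is a hypothetical helper over plain rows, not the repository's `ResidualConnection`. The sketch reuses the numbers from the worked example and shows that the invalid shortcut fails before any addition.

```rust
// Sketch of a shape-checked residual boundary. `residual_add` is hypothetical,
// not the repository's ResidualConnection.

fn residual_add(hidden: &[Vec<f32>], sublayer: &[Vec<f32>]) -> Result<Vec<Vec<f32>>, String> {
    if hidden.len() != sublayer.len() {
        return Err(format!("sequence length {} != {}", hidden.len(), sublayer.len()));
    }
    let mut out: Vec<Vec<f32>> = Vec::with_capacity(hidden.len());
    for (h, s) in hidden.iter().zip(sublayer) {
        // Check the model dimension of each row before adding it.
        if h.len() != s.len() {
            return Err(format!("model dimension {} != {}", h.len(), s.len()));
        }
        out.push(h.iter().zip(s).map(|(a, b)| a + b).collect());
    }
    Ok(out)
}

fn main() {
    let hidden = vec![vec![0.5, 0.5], vec![1.0, 1.0]];
    let projected = vec![vec![7.0, 3.0], vec![12.2033, 4.2033]];
    let residual = residual_add(&hidden, &projected).unwrap();
    assert_eq!(residual[0], vec![7.5, 3.5]);
    // The invalid shortcut: adding the four-feature concatenated rows directly.
    let concatenated = vec![vec![2.0, 20.0, 10.0, 1.0], vec![2.2033, 22.0334, 20.0, 2.0]];
    assert!(residual_add(&hidden, &concatenated).is_err());
    println!("{:?}", residual);
}
```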
ML Concept
Residual addition preserves the hidden sequence shape while allowing a sublayer to contribute a learned change:
hidden + sublayer_output
The repository currently implements the addition boundary, the layer normalization boundary, the position-wise feed-forward boundary, and compact single-head and multi-head block boundaries.
Category Theory Concept
The residual boundary consumes a product object and returns the same hidden sequence object:
HiddenSequence x ProjectedAttentionOutput -> HiddenSequence
For a fixed block value, the enclosing unmasked block is still an endomorphism:
HiddenSequence -> HiddenSequence
Design contract:
Residual addition should fail before addition if either the sequence length or model dimension differs. Without that check, a later Transformer block would be silently mixing incompatible objects.
Step 7: Layer Normalization
The current problem:
After residual addition, each sequence position should be normalized across its feature dimension while preserving the hidden sequence shape.
Rust Syntax
The current code models layer normalization as:
HiddenSequence -> HiddenSequence
LayerNormParameters validates non-empty scale and shift vectors, matching
parameter lengths, finite parameter values, and positive finite epsilon.
LayerNormalization rejects hidden sequences whose model dimension does not
match its parameter dimension.
ML Concept
Layer normalization recenters and rescales each hidden vector across its feature dimension. It is batch-size independent and preserves the sequence shape:
same positions
same model dimension
new normalized values
The Layer Normalization paper is useful here because it frames the operation
around statistics inside a single training case rather than statistics
collected across a batch. In this roadmap’s tiny Rust object, that means each
row can be normalized while the public HiddenSequence boundary remains the
same.
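Row-wise normalization can be sketched as follows. The sketch assumes an epsilon of `1e-5` and identity scale and shift; the repository's actual epsilon is not stated here. `layer_norm` is a hypothetical helper, not the repository's `LayerNormalization`. Applied to the residual rows from the earlier worked example, each row comes out close to `[1.0, -1.0]`.

```rust
// Sketch of row-wise layer normalization. `layer_norm` is hypothetical, not the
// repository's LayerNormalization; epsilon = 1e-5 is an assumption.

fn layer_norm(rows: &[Vec<f32>], scale: &[f32], shift: &[f32], eps: f32) -> Result<Vec<Vec<f32>>, String> {
    if scale.is_empty() || scale.len() != shift.len() || !(eps.is_finite() && eps > 0.0) {
        return Err("invalid layer-norm parameters".into());
    }
    rows.iter()
        .map(|row| {
            if row.len() != scale.len() {
                return Err("model dimension mismatch".to_string());
            }
            let n = row.len() as f32;
            // Mean and variance come from this row alone, never from a batch.
            let mean = row.iter().sum::<f32>() / n;
            let var = row.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / n;
            let inv_std = 1.0 / (var + eps).sqrt();
            Ok(row
                .iter()
                .zip(scale)
                .zip(shift)
                .map(|((x, g), b)| g * (x - mean) * inv_std + b)
                .collect::<Vec<f32>>())
        })
        .collect()
}

fn main() {
    let residual = vec![vec![7.5, 3.5], vec![13.2033, 5.2033]];
    let normalized = layer_norm(&residual, &[1.0, 1.0], &[0.0, 0.0], 1e-5).unwrap();
    // Each row comes out close to [1.0, -1.0]; the shape is still 2 x 2.
    for row in &normalized {
        assert!((row[0] - 1.0).abs() < 1e-4 && (row[1] + 1.0).abs() < 1e-4);
    }
    assert!(layer_norm(&residual, &[1.0], &[0.0], 1e-5).is_err());
    println!("{:?}", normalized);
}
```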
Category Theory Concept
The normalization boundary is an endomorphism:
HiddenSequence -> HiddenSequence
Read that as a forward call for one fixed LayerNormalization value. Its
scale, shift, and epsilon are already stored in the layer. If those values are
being learned, the changing object is the larger training state, not the
hidden sequence alone.
Design contract:
Normalization should not change the object type. If a later block expects a hidden sequence, the normalized result should still be a hidden sequence.
Step 8: Position-Wise Feed-Forward
The current problem:
After attention and normalization, each sequence position needs a learned non-linear transformation that preserves the hidden sequence shape.
Rust Syntax
The current code models this as:
HiddenSequence -> HiddenSequence
PositionWiseFeedForward validates two linear layers:
model dimension -> feed-forward hidden dimension -> model dimension
It also checks finite weights, finite biases, and compatible intermediate dimensions before any row is projected.
ML Concept
A position-wise feed-forward network applies the same two-layer non-linear map to each hidden vector independently:
hidden row -> expanded row -> activated row -> hidden row
It changes feature values, not the sequence length or public model dimension.
That is why the public boundary stays:
HiddenSequence -> HiddenSequence
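The row-by-row structure can be sketched as two linear maps with a ReLU between them. `Linear` and `feed_forward` are hypothetical names, and the weight layout (one row per input feature) is an assumption. This is not the repository's `PositionWiseFeedForward`. With identity weights, the sketch reproduces the ReLU clipping described in this step's worked example.

```rust
// Sketch of a position-wise feed-forward map: Linear -> ReLU -> Linear per row.
// `Linear` and `feed_forward` are hypothetical, not the repository's types.

struct Linear {
    weights: Vec<Vec<f32>>, // one row per input feature
    bias: Vec<f32>,         // one entry per output feature
}

impl Linear {
    fn input_dim(&self) -> usize { self.weights.len() }
    fn output_dim(&self) -> usize { self.bias.len() }
    fn rows_ok(&self) -> bool { self.weights.iter().all(|w| w.len() == self.output_dim()) }
    fn apply(&self, row: &[f32]) -> Vec<f32> {
        (0..self.output_dim())
            .map(|j| self.bias[j] + row.iter().zip(&self.weights).map(|(x, w)| x * w[j]).sum::<f32>())
            .collect()
    }
}

fn feed_forward(rows: &[Vec<f32>], first: &Linear, second: &Linear) -> Result<Vec<Vec<f32>>, String> {
    let model_dim = first.input_dim();
    // Validate model -> hidden -> model before any row is projected.
    if !first.rows_ok()
        || !second.rows_ok()
        || first.output_dim() != second.input_dim()
        || second.output_dim() != model_dim
    {
        return Err("feed-forward dimensions do not compose".into());
    }
    if rows.iter().any(|r| r.len() != model_dim) {
        return Err("hidden row has the wrong model dimension".into());
    }
    Ok(rows
        .iter()
        .map(|r| {
            // ReLU between the two linear layers.
            let activated: Vec<f32> = first.apply(r).into_iter().map(|h| h.max(0.0)).collect();
            second.apply(&activated)
        })
        .collect())
}

fn main() {
    let identity = || Linear { weights: vec![vec![1.0, 0.0], vec![0.0, 1.0]], bias: vec![0.0, 0.0] };
    let normalized = vec![vec![0.9999988, -0.9999988], vec![0.99999976, -0.99999976]];
    let out = feed_forward(&normalized, &identity(), &identity()).unwrap();
    // ReLU clips the negative feature; the output is still 2 positions x 2 features.
    assert_eq!(out[0], vec![0.9999988, 0.0]);
    assert!(feed_forward(&[vec![1.0, 2.0, 3.0]], &identity(), &identity()).is_err());
    println!("{:?}", out);
}
```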
Category Theory Concept
The feed-forward sublayer is another endomorphism:
HiddenSequence -> HiddenSequence
Read that the same way: for this fixed PositionWiseFeedForward value, the
call receives a hidden sequence and returns a hidden sequence. The layer’s
weights and biases are context already stored inside the Rust object. Training
those weights belongs to a state update, not to this forward boundary.
It is not the whole Transformer block. It is the next shape-preserving sublayer that a later block can compose.
Design contract:
The second linear layer must return to the original model dimension. Otherwise the next residual or block boundary would receive the wrong object.
Worked Example: Values Change, Shape Stays HiddenSequence
The residual example produced this hidden sequence:
residual query_0 = [7.5, 3.5]
residual query_1 = [13.2033, 5.2033]
Layer normalization changes the row values while keeping the public object the same:
normalized query_0 = [0.9999988, -0.9999988]
normalized query_1 = [0.99999976, -0.99999976]
The sequence still has two positions and model dimension two:
Residual HiddenSequence
2 positions x model dimension 2
LayerNormalization
HiddenSequence -> HiddenSequence
Normalized HiddenSequence
2 positions x model dimension 2
The feed-forward sublayer then applies the same two-layer map to each position. In this tiny example, the ReLU step clips the negative feature before the second linear layer returns to the public model dimension:
feed-forward query_0 = [0.9999988, 0.0]
feed-forward query_1 = [0.99999976, 0.0]
Again, the object has not changed:
HiddenSequence
2 positions x model dimension 2
PositionWiseFeedForward
HiddenSequence -> HiddenSequence
HiddenSequence
2 positions x model dimension 2
The values changed twice. The sequence length and model dimension did not.
That distinction matters because normalization and feed-forward computation are not new sequence objects in this roadmap. They are shape-preserving maps over hidden rows:
Residual HiddenSequence -> LayerNormalization -> HiddenSequence
HiddenSequence -> PositionWiseFeedForward -> HiddenSequence
The invalid mental shortcut is:
normalization creates a special normalized object
feed-forward creates a special feed-forward object
The useful engineering view is stricter:
both stages return HiddenSequence so the next block boundary can compose
Step 9: Positional Encoding
The current problem:
Self-attention sees a set of hidden rows. A sequence model also needs to know where each row sits in the sequence.
Rust Syntax
The current code models position as another shape-preserving morphism:
PositionalEncoding : HiddenSequence -> HiddenSequence
The encoding table validates non-empty finite rows and a fixed model dimension. Applying it rejects hidden sequences that are too long for the table or have the wrong model width.
ML Concept
Position information is added to hidden vectors before attention so identical tokens in different positions can become distinguishable to later transformations.
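Adding a fixed table can be sketched as follows. `add_positions` is a hypothetical helper, not the repository's `PositionalEncoding`. The table values match those used in the example program. Two identical hidden rows become different after encoding, and a sequence longer than the table is rejected.

```rust
// Sketch of adding a fixed positional table to hidden rows. `add_positions`
// is hypothetical, not the repository's PositionalEncoding.

fn add_positions(hidden: &[Vec<f32>], table: &[Vec<f32>]) -> Result<Vec<Vec<f32>>, String> {
    let width = table.first().ok_or("empty position table")?.len();
    if hidden.len() > table.len() {
        return Err(format!("sequence length {} exceeds table length {}", hidden.len(), table.len()));
    }
    if table.iter().chain(hidden).any(|row| row.len() != width) {
        return Err("model width mismatch".into());
    }
    // Row i gets position row i added; values change, the shape does not.
    Ok(hidden
        .iter()
        .zip(table)
        .map(|(h, p)| h.iter().zip(p).map(|(a, b)| a + b).collect::<Vec<f32>>())
        .collect())
}

fn main() {
    let table = vec![vec![0.1, 0.0], vec![0.0, 0.1]];
    // Two identical token rows at different positions.
    let hidden = vec![vec![1.0, 0.0], vec![1.0, 0.0]];
    let encoded = add_positions(&hidden, &table).unwrap();
    assert_ne!(encoded[0], encoded[1]);
    // A sequence longer than the table is rejected before attention.
    assert!(add_positions(&vec![vec![1.0, 0.0]; 3], &table).is_err());
    println!("{:?}", encoded);
}
```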
Category Theory Concept
For one fixed positional-encoding table, the public shape is still an endomorphism:
HiddenSequence -> HiddenSequence
Read that as a forward call with the encoding table already selected. If a
future chapter learns, swaps, or rebuilds the position table, that changing
context must be named separately instead of being hidden inside the
HiddenSequence -> HiddenSequence arrow.
Design contract:
Adding position should change values, not the hidden sequence object. If the position table has the wrong width or not enough rows, composition should fail before attention starts.
Step 10: Single-Head And Multi-Head Blocks
The current problem:
A block should compose attention, residual addition, normalization, and feed-forward computation while keeping the public shape simple.
Rust Syntax
The current single-head and multi-head sketches share one public shape:
HiddenSequence -> HiddenSequence
They are implemented as:
SingleHeadTransformerBlock
MultiHeadTransformerBlock
The single-head sketch proves the compact block boundary. The multi-head sketch
extends that boundary by collecting several SelfAttentionHead values,
concatenating their outputs, and applying the output projection.
ML Concept
Transformer blocks combine:
attention
residual connection
normalization
feed-forward network
The block output has the same shape as the input.
This is where the current training chapter becomes useful again. A block with
shape HiddenSequence -> HiddenSequence can be stacked for the same reason a
training step with shape Parameters -> Parameters can be repeated: output and
input live in the same object.
Category Theory Concept
For one fixed single-head or multi-head block value, this is another endomorphism:
HiddenSequence -> HiddenSequence
Stacking layers is repeated endomorphism application.
If the block’s heads, projections, normalization parameters, or feed-forward weights are changing, the boundary is no longer only this forward call. The changing object is the larger training state:
TransformerTrainingState -> TransformerTrainingState
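Repeated endomorphism application can be expressed as a fold over a list of blocks. `Hidden`, `Block`, `map_values`, and `stack` are hypothetical stand-ins, and the toy blocks only rescale values. The point is that each output type equals the next input type, so `try_fold` can chain the blocks and stop at the first error.

```rust
// Sketch: stacking shape-preserving blocks is repeated endomorphism application.
// All names are hypothetical stand-ins, not the repository's types.

type Hidden = Vec<Vec<f32>>;
type Block = Box<dyn Fn(Hidden) -> Result<Hidden, String>>;

// Build a toy block that applies `f` to every feature value.
fn map_values(f: fn(f32) -> f32) -> Block {
    Box::new(move |h: Hidden| -> Result<Hidden, String> {
        Ok(h.into_iter()
            .map(|row| row.into_iter().map(f).collect::<Vec<f32>>())
            .collect())
    })
}

// Each block's output type equals the next block's input type, so a fold works.
fn stack(blocks: &[Block], input: Hidden) -> Result<Hidden, String> {
    blocks.iter().try_fold(input, |hidden, block| block(hidden))
}

fn main() {
    let blocks = vec![map_values(|x| 2.0 * x), map_values(|x| x + 1.0)];
    let out = stack(&blocks, vec![vec![1.0, 0.0], vec![0.0, 1.0]]).unwrap();
    assert_eq!(out, vec![vec![3.0, 1.0], vec![1.0, 3.0]]);
    println!("{:?}", out);
}
```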
Design contract:
Internal complexity does not leak into every caller. The single-head and multi-head sketches contain several sublayers, but callers see one typed boundary.
For the multi-head sketch, the output-projection input dimension must equal:
head_count * value_head_dimension
That check is the difference between a typed block and a loose pile of matrix multiplications.
Step 11: Masked Blocks
The current problem:
Some sequence positions should not attend to other positions. The block boundary needs a mask, not only the lower-level score operation.
Rust Syntax
The current code models the masked block as:
MaskedMultiHeadTransformerBlock : HiddenSequence x AttentionMask -> HiddenSequence
The mask is part of the input product. The block applies it after query-key scoring and before row-wise softmax for each head.
ML Concept
A mask controls which source positions each query position may use. The same shape-preserving block can now represent selective attention.
Category Theory Concept
This is a product-to-object morphism:
HiddenSequence x AttentionMask -> HiddenSequence
Design contract:
The mask must have the same query and key dimensions as the score table inside the block. If the shape does not match, the block fails before softmax.
Worked Example: Fixed Mask Versus Open Mask
The masked block is where stacking language can become imprecise. The unmasked multi-head block has the simple public shape:
MultiHeadTransformerBlock : HiddenSequence -> HiddenSequence
That boundary can compose directly with another boundary of the same shape. The masked block is different:
MaskedMultiHeadTransformerBlock : HiddenSequence x AttentionMask -> HiddenSequence
While the mask is still an open input, the boundary is not a unary endomorphism. It needs the hidden sequence and the mask. There are two precise ways to use it repeatedly:
| Use case | Boundary to name | Why this is precise |
|---|---|---|
| the caller supplies a mask for each block call | HiddenSequence x AttentionMask -> HiddenSequence | the mask remains visible as required context |
| one example fixes a specific mask before applying the block | HiddenSequence -> HiddenSequence under fixed mask context | the unary map is induced by a named fixed mask |
The second row is useful in a lesson or a single training example, but only if the fixed context is named. The mask did not disappear. It became part of the chosen environment for that run.
The invalid shortcut is:
MaskedMultiHeadTransformerBlock returns HiddenSequence, so it is automatically
an endomorphism.
The output type is not enough. Count the whole input object. The open masked boundary has product input. A fixed-mask view can be treated as a unary shape-preserving map only after the mask has been selected and kept stable for that call path.
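The difference between the open boundary and a fixed-mask view can be modeled with a closure that captures a named mask. `masked_block` is a toy stand-in, not the repository's `MaskedMultiHeadTransformerBlock`: each position simply averages the rows its mask row allows. Only `with_fixed_mask` (also hypothetical) returns a unary map, and it does so by keeping the mask as captured context.

```rust
// Sketch of the open masked boundary versus a fixed-mask view.
// `masked_block` is a toy stand-in: each position averages the rows its
// mask row allows. All names here are hypothetical.

type Hidden = Vec<Vec<f32>>;
type Mask = Vec<Vec<bool>>;

// Open boundary: Hidden x Mask -> Hidden. The mask is a required input.
fn masked_block(hidden: &Hidden, mask: &Mask) -> Result<Hidden, String> {
    let n = hidden.len();
    if mask.len() != n || mask.iter().any(|row| row.len() != n) {
        return Err("mask shape must be positions x positions".into());
    }
    let width = hidden.first().map_or(0, |row| row.len());
    let mut out = Vec::with_capacity(n);
    for mask_row in mask {
        let mut sum = vec![0.0f32; width];
        let mut count = 0;
        for (h, &allowed) in hidden.iter().zip(mask_row) {
            if allowed {
                count += 1;
                for (s, x) in sum.iter_mut().zip(h) {
                    *s += x;
                }
            }
        }
        if count == 0 {
            return Err("a position can see no other position".into());
        }
        out.push(sum.into_iter().map(|s| s / count as f32).collect::<Vec<f32>>());
    }
    Ok(out)
}

// Fixed-mask view: a unary Hidden -> Hidden map induced by one named mask.
fn with_fixed_mask(mask: Mask) -> impl Fn(&Hidden) -> Result<Hidden, String> {
    move |hidden: &Hidden| masked_block(hidden, &mask)
}

fn main() {
    let hidden = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
    let causal = vec![vec![true, false], vec![true, true]];
    let open = masked_block(&hidden, &causal).unwrap();
    assert_eq!(open[1], vec![0.5, 0.5]);
    let causal_view = with_fixed_mask(causal.clone());
    assert_eq!(causal_view(&hidden).unwrap(), open);
    // Only the fixed view is unary, so only it chains like an endomorphism.
    let twice = causal_view(&causal_view(&hidden).unwrap()).unwrap();
    println!("{:?}", twice);
}
```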
Step 12: Structured State For Training And Evaluation
The current problem:
Once the model has attention parameters, evaluation and future training need one structured object instead of loose matrices passed through the code.
Rust Syntax
The earlier training chapter used:
Parameters
The roadmap code now adds the structured attention-side version:
TransformerReadout : HiddenSequence -> SequenceLogits
TinyTransformerParameters : HiddenSequence x AttentionMask -> SequenceLogits
TransformerTrainingState
TransformerReadoutTrainStep : TransformerTrainingState -> TransformerTrainingState
TransformerFeedForwardTrainStep : TransformerTrainingState -> TransformerTrainingState
TinyTransformerParameters owns:
positional encoding
masked multi-head block
sequence readout
TransformerTrainingState owns that parameter object plus LearningRate and
StepCount. Its record_updated_parameters method records a new parameter
object and increments the step count.
The roadmap code also adds three supervised updates:
TransformerReadoutTrainStep : TransformerTrainingState -> TransformerTrainingState
TransformerFeedForwardTrainStep : TransformerTrainingState -> TransformerTrainingState
TransformerBlockTrainStep : TransformerTrainingState -> TransformerTrainingState
The readout step updates the sequence readout with a softmax cross-entropy gradient. The feed-forward step updates the position-wise feed-forward sublayer against hidden-sequence targets. The block step composes those ideas. It starts from sequence targets and computes readout gradients. It then carries the hidden gradient back through the final layer-normalization and residual boundary, through the feed-forward sublayer, and through the attention-normalization and residual boundary. From that same supervised example, it updates the readout, the feed-forward sublayer, the attention output projection, the query/key/value projections, and both layer-normalization scale/shift vectors. These are real gradient steps, but deliberately tiny ones.
ML Concept
A Transformer training loop still has the same outer structure:
predict
compute loss
backpropagate
update parameters
The internal model is richer, so the parameter object must be richer. A useful training state keeps three questions separate:
what parameters define the model?
what optimizer settings control the update?
which update step are we on?
The current code answers those questions structurally, then adds a small full-batch gradient step over the trainable block components. The readout update answers the first, smaller question:
if the hidden sequence is fixed, can the vocabulary readout learn?
Yes. The update computes probabilities at each sequence position, subtracts one from the target class's probability (which gives the softmax cross-entropy gradient with respect to the logits), accumulates weight and bias gradients for the readout, applies the learning rate, and increments the step count.
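The readout gradient can be isolated from the rest of the model. This is a minimal sketch under stated assumptions: logits are `b + h . W` with `W` stored as one row per hidden feature, the loss is averaged over positions, and the update is plain SGD. The helpers are hypothetical, not the repository's `TransformerReadoutTrainStep`. The sketch feeds hidden rows straight into the readout, with no positional encoding or block, so its loss values will not match the numbers printed by the repository examples.

```rust
// Sketch of one full-batch readout update with softmax cross-entropy.
// Hypothetical helpers; logits = b + h . W (W: one row per hidden feature).

fn logits(h: &[f32], w: &[Vec<f32>], b: &[f32]) -> Vec<f32> {
    (0..b.len())
        .map(|j| b[j] + h.iter().zip(w).map(|(x, row)| x * row[j]).sum::<f32>())
        .collect()
}

fn softmax(z: &[f32]) -> Vec<f32> {
    let max = z.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = z.iter().map(|v| (v - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.into_iter().map(|e| e / sum).collect()
}

fn average_loss(hidden: &[Vec<f32>], targets: &[usize], w: &[Vec<f32>], b: &[f32]) -> f32 {
    let total: f32 = hidden.iter().zip(targets).map(|(h, &t)| -softmax(&logits(h, w, b))[t].ln()).sum();
    total / hidden.len() as f32
}

fn readout_step(hidden: &[Vec<f32>], targets: &[usize], w: &mut [Vec<f32>], b: &mut [f32], lr: f32) {
    let n = hidden.len() as f32;
    let mut grad_w = vec![vec![0.0f32; b.len()]; w.len()];
    let mut grad_b = vec![0.0f32; b.len()];
    for (h, &t) in hidden.iter().zip(targets) {
        // dLoss/dlogits = probabilities - one_hot(target).
        let mut delta = softmax(&logits(h, w, b));
        delta[t] -= 1.0;
        for (j, d) in delta.iter().enumerate() {
            grad_b[j] += d / n;
            for (i, x) in h.iter().enumerate() {
                grad_w[i][j] += x * d / n;
            }
        }
    }
    // Apply the learning rate to every accumulated gradient.
    for (row, g_row) in w.iter_mut().zip(&grad_w) {
        for (p, g) in row.iter_mut().zip(g_row) {
            *p -= lr * g;
        }
    }
    for (p, g) in b.iter_mut().zip(&grad_b) {
        *p -= lr * g;
    }
}

fn main() {
    let hidden = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
    let targets = [0, 1];
    let mut w = vec![vec![1.0, 0.0, 0.5], vec![0.0, 1.0, -0.5]];
    let mut b = vec![0.0, 0.0, 0.0];
    let before = average_loss(&hidden, &targets, &w, &b);
    readout_step(&hidden, &targets, &mut w, &mut b, 0.1);
    let after = average_loss(&hidden, &targets, &w, &b);
    assert!(after < before);
    println!("toy readout loss: {before:.6} -> {after:.6}");
}
```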
The local feed-forward update asks a different small question:
if the attention output is treated as fixed, can the feed-forward sublayer
learn a hidden-sequence target?
It computes a squared-error gradient through the two feed-forward linear layers and the ReLU between them. That teaches a real block-internal update without pretending to use token targets.
The composed block update asks the next question:
if the model predicts target tokens, can the readout loss also update the
feed-forward sublayer through the final residual and normalization path?
Yes. The update follows the actual forward cache used by the block, computes the softmax cross-entropy gradient at the readout, applies the standard layer normalization backward pass for the final normalization boundary, splits the residual path, passes through the feed-forward sublayer and attention normalization boundary, then updates the readout, feed-forward parameters, attention output projection, and both layer-normalization scale/shift vectors together. It also backpropagates through value mixing, attention softmax, and scaled query-key scores to update the query, key, and value projections.
Worked Example: Three Updates, One State Shape
The attention example prints one structured state transition first:
training state step: 0 -> 1
The smaller training-state example isolates the same contract:
initial state: step=0, learning_rate=0.100, model_dimension=2, vocab_size=3
readout update: step 0 -> 1
feed-forward update: step 1 -> 2
composed block update: step 2 -> 3
Those lines are small, but they carry the whole contract:
TransformerTrainingState
owns TinyTransformerParameters
owns LearningRate
owns StepCount
An update is not allowed to return loose matrices. It must return another
TransformerTrainingState, because the next update needs the same state shape.
The readout-only step asks the smallest supervised question:
fixed hidden sequence
-> vocabulary logits
-> sequence loss
-> updated readout parameters
loss: 0.499085 -> 0.456495
The state shape is unchanged:
TransformerReadoutTrainStep
TransformerTrainingState -> TransformerTrainingState
The local feed-forward step asks a different question:
fixed hidden-sequence input
-> feed-forward output
-> squared-error hidden target
-> updated feed-forward parameters
loss: 0.250000 -> 0.160633
The same outer shape holds:
TransformerFeedForwardTrainStep
TransformerTrainingState -> TransformerTrainingState
The composed block step asks the broader question:
hidden sequence and mask
-> attention block
-> readout logits
-> sequence loss
-> updated block and readout parameters
loss: 0.456495 -> 0.409737
Again, the outside of the system is stable:
TransformerBlockTrainStep
TransformerTrainingState -> TransformerTrainingState
The internal gradient path grows from readout-only, to local feed-forward, to a composed block update. The public training shape does not grow:
state_0 -> state_1 -> state_2 -> state_3
That is the engineering version of the earlier endomorphism idea. A training step may touch different fields, but it should return the same kind of state so the loop can keep running.
The invalid shortcut would be:
readout update returns readout weights
feed-forward update returns feed-forward weights
block update returns a bag of changed matrices
That makes the next training step guess how to rebuild the model. The roadmap uses one structured state object instead, so every update must preserve the state boundary.
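The following is a minimal sketch of the state-in, state-out contract. It assumes simplified stand-in fields: the repository's TinyTransformerParameters holds matrices, not flat vectors. The three update functions are hypothetical placeholders that nudge one group of weights and advance the step count. What matters is the shared signature, not the arithmetic. The initial sizes mirror the printed `model_dimension=2, vocab_size=3, learning_rate=0.100`.

```rust
// Hypothetical, simplified stand-ins for the repository's types.
#[derive(Clone, Debug, PartialEq)]
struct TinyTransformerParameters {
    position: Vec<f64>, // position information
    block: Vec<f64>,    // attention block parameters
    readout: Vec<f64>,  // language-model readout parameters
}

#[derive(Clone, Debug, PartialEq)]
struct TransformerTrainingState {
    parameters: TinyTransformerParameters,
    learning_rate: f64,
    step: u64,
}

// Every update has the same endomorphism shape: state -> state.
type TrainStep = fn(TransformerTrainingState) -> TransformerTrainingState;

// Placeholder "gradient step": subtract the learning rate from each touched weight.
fn descend(weights: &mut [f64], learning_rate: f64) {
    for w in weights.iter_mut() {
        *w -= learning_rate;
    }
}

fn readout_step(mut s: TransformerTrainingState) -> TransformerTrainingState {
    descend(&mut s.parameters.readout, s.learning_rate);
    s.step += 1;
    s
}

fn feed_forward_step(mut s: TransformerTrainingState) -> TransformerTrainingState {
    descend(&mut s.parameters.block, s.learning_rate);
    s.step += 1;
    s
}

fn block_step(mut s: TransformerTrainingState) -> TransformerTrainingState {
    descend(&mut s.parameters.block, s.learning_rate);
    descend(&mut s.parameters.readout, s.learning_rate);
    s.step += 1;
    s
}

fn initial_state() -> TransformerTrainingState {
    TransformerTrainingState {
        parameters: TinyTransformerParameters {
            position: vec![0.0; 2], // model dimension 2
            block: vec![1.0; 2],
            readout: vec![1.0; 3], // vocabulary size 3
        },
        learning_rate: 0.1,
        step: 0,
    }
}

// state_0 -> state_1 -> state_2 -> state_3: composing endomorphisms is a fold.
fn run_three_updates() -> TransformerTrainingState {
    let steps: [TrainStep; 3] = [readout_step, feed_forward_step, block_step];
    steps.iter().fold(initial_state(), |state, step| step(state))
}

fn main() {
    let final_state = run_three_updates();
    println!("training state step: 0 -> {}", final_state.step);
}
```

Because every step has the type `TrainStep`, the fold needs no knowledge of which fields each update touched. A step that returned only `Vec<f64>` would not fit into the array at all.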
Gradient Evidence Ledger
The block training step has finite-difference tests. They are important, but they are not magic certificates. Read them as local evidence checks.
The test shape is:
one selected parameter
-> perturb it by +epsilon and -epsilon
-> measure two nearby losses
-> compute a central finite difference
one training step
-> compare before and after parameter values
-> infer the gradient used by the update
Those two paths should agree for the selected parameter:
central finite difference of loss ~= inferred update gradient
CS231n uses gradient checking this way: compare a numerical gradient with an
analytic gradient, preferably with a centered finite-difference formula and
careful error interpretation. PyTorch’s gradcheck documentation gives the
framework version of the same idea: finite differences are compared with
analytical gradients, and the result depends on tolerance, precision,
differentiability, and memory-layout assumptions.
The Rust roadmap keeps the claim smaller:
| Test family | Parameter path checked | What it can catch | What it cannot prove |
|---|---|---|---|
| transformer_block_train_step_matches_finite_difference_for_readout_weight | sequence readout weight | wrong sign, missing target-class gradient, wrong averaging scale | correctness of attention gradients |
| transformer_block_train_step_matches_finite_difference_for_feed_forward_weight | feed-forward weight | dropped ReLU or hidden-layer path | correctness of every feed-forward configuration |
| transformer_block_train_step_matches_finite_difference_for_layer_norm_parameter | normalization scale or shift | wrong layer-normalization backward path | correctness of all normalization behavior |
| transformer_block_train_step_matches_finite_difference_for_attention_projection | query, key, value, or output projection weight | dropped attention projection path | correctness of every attention variant |
| bias finite-difference tests | readout, feed-forward, output projection, or attention projection bias | missing bias gradient | correctness of all trainable fields |
The category-theory reading is also modest. These tests compare two local morphisms around one selected coordinate:
loss measurement around current state
parameter update inside TransformerTrainingState -> TransformerTrainingState
They support the implementation of the current state endomorphism. They do not prove that every possible dataset, learning rate, optimizer, mask, sequence length, or future Transformer block is correct.
Use this decision rule when reading a gradient-check result:
match -> local evidence for this parameter path
mismatch -> inspect sign, scaling, dropped path, nonsmooth point, or tolerance
Do not respond to a mismatch by only loosening the tolerance. First ask which typed boundary or gradient path failed.
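Here is a scalar sketch of the two-path check. It assumes a hypothetical one-parameter loss L(w) = 0.5 (w x − y)², and plain gradient descent stands in for the repository's block step. The update's gradient is recovered as (w_before − w_after) / learning_rate. It is then compared with the centered difference using the relative error |a − b| / max(|a|, |b|), the form CS231n recommends.

```rust
// Hypothetical one-parameter loss: L(w) = 0.5 * (w * x - y)^2.
fn loss(w: f64, x: f64, y: f64) -> f64 {
    let residual = w * x - y;
    0.5 * residual * residual
}

// One gradient-descent update using the analytic gradient dL/dw = (w * x - y) * x.
fn train_step(w: f64, x: f64, y: f64, learning_rate: f64) -> f64 {
    let grad = (w * x - y) * x;
    w - learning_rate * grad
}

// Central finite difference: (L(w + eps) - L(w - eps)) / (2 * eps).
fn central_difference(w: f64, x: f64, y: f64, eps: f64) -> f64 {
    (loss(w + eps, x, y) - loss(w - eps, x, y)) / (2.0 * eps)
}

// Infer the gradient the update used from before/after parameter values.
fn inferred_gradient(before: f64, after: f64, learning_rate: f64) -> f64 {
    (before - after) / learning_rate
}

// Relative error |a - b| / max(|a|, |b|), defined as 0 when both are 0.
fn relative_error(a: f64, b: f64) -> f64 {
    let scale = a.abs().max(b.abs());
    if scale == 0.0 { 0.0 } else { (a - b).abs() / scale }
}

fn main() {
    let (w, x, y, lr, eps) = (0.7, 2.0, 3.0, 0.1, 1e-5);
    let numeric = central_difference(w, x, y, eps);
    let inferred = inferred_gradient(w, train_step(w, x, y, lr), lr);
    println!("numeric = {numeric:.6}, inferred = {inferred:.6}");
    println!("relative error = {:.2e}", relative_error(numeric, inferred));
}
```

A sign error in the update would flip `inferred` and drive the relative error to 2. That is the "inspect sign" branch of the decision rule, not a tolerance problem.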
Category Theory Concept
The forward path is now a typed morphism:
HiddenSequence x AttentionMask -> SequenceLogits
The training updates all have the same endomorphism shape:
TransformerTrainingState -> TransformerTrainingState
The block step is more global than the readout-only and local feed-forward steps because the loss starts at vocabulary logits and reaches an internal sublayer, the attention output projection, query/key/value projections, and both layer-normalization parameter sets. The update still preserves the same outer endomorphism shape even as the internal gradient path becomes richer.
Design contract:
The parameter object separates substructures:
position information
attention block
language-model readout
That separation is pedagogical and architectural. A reader can point at one field and say which mathematical role it plays. A future optimizer can update the same object without erasing the roles.
Core Mental Model
The current course teaches the typed skeleton:
TokenId -> Vector -> Logits -> Distribution
Distribution x TokenId -> Loss
Parameters -> Parameters
A Transformer extension grows the middle:
TokenSequence
-> HiddenSequence
-> HiddenSequence with position
-> QuerySequence x KeySequence x ValueSequence
-> AttentionOutput
-> HiddenSequence
-> SingleHeadTransformerBlock
-> MultiHeadTransformerBlock
-> SequenceLogits
-> sequence-level probabilities
The practical rule stays the same:
Make every intermediate object explicit, then compose only arrows whose types actually match.
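A compact sketch of the skeleton follows: TokenId -> Vector -> Logits -> Distribution, then Distribution x TokenId -> Loss. It uses hypothetical newtypes and a toy embedding and readout, not the repository's definitions. Each stage has its own type, so passing raw Logits to `loss` is a compile error rather than a silent bug.

```rust
// Hypothetical newtypes: one per intermediate object in the skeleton.
#[derive(Clone, Copy, Debug)]
struct TokenId(usize);
#[derive(Debug)]
struct Vector(Vec<f64>);
#[derive(Debug)]
struct Logits(Vec<f64>);
#[derive(Debug)]
struct Distribution(Vec<f64>);
#[derive(Debug)]
struct Loss(f64);

// TokenId -> Vector: look up one row of a toy embedding table.
fn embed(token: TokenId, table: &[Vec<f64>]) -> Vector {
    Vector(table[token.0].clone())
}

// Vector -> Logits: toy readout, one weight row per vocabulary entry.
fn readout(v: &Vector, weights: &[Vec<f64>]) -> Logits {
    Logits(
        weights
            .iter()
            .map(|row| row.iter().zip(&v.0).map(|(w, x)| w * x).sum::<f64>())
            .collect(),
    )
}

// Logits -> Distribution: numerically stable softmax.
fn softmax(logits: &Logits) -> Distribution {
    let max = logits.0.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.0.iter().map(|z| (z - max).exp()).collect();
    let total: f64 = exps.iter().sum();
    Distribution(exps.iter().map(|e| e / total).collect())
}

// Distribution x TokenId -> Loss: cross-entropy for the target token.
fn loss(dist: &Distribution, target: TokenId) -> Loss {
    Loss(-dist.0[target.0].ln())
}

fn main() {
    let embedding = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];
    let weights = vec![vec![0.5, -0.5], vec![0.1, 0.2], vec![-0.3, 0.4]];
    let dist = softmax(&readout(&embed(TokenId(0), &embedding), &weights));
    println!("{:?}", loss(&dist, TokenId(1)));
    // loss(&Logits(vec![0.0]), TokenId(1)); // rejected: expected Distribution, found Logits
}
```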
Where This Leaves Us
The roadmap keeps the book honest. The current implementation is a tiny next-token system, not a production Transformer. Its value is that it gives the future system a typed foundation: tokens become vectors, vectors become logits, logits become probabilities, probabilities become loss, and training updates parameters through a repeatable endomorphism.
A future Transformer should extend that foundation with stronger optimizer checks, more realistic datasets, and clearer diagrams. Each new concept should enter the codebase the same way the current concepts did: as a named type, a validated boundary, a typed morphism, a compiled example, and a law or regression test where the concept has a law worth checking.
Roadmap Reference Path
Use References as a staged path. This roadmap comes first because it says what the current code has and does not have. The original Transformer paper then gives the architectural target. Dive into Deep Learning gives the practical sequence from attention scoring to full Transformer blocks. Implementation and visual tutorials help translate paper notation into code structure and diagrams.
After that, return to this repository and add one typed boundary at a time. A future contribution should not start by copying a full architecture into one large module. It should start by making one Transformer concept explicit enough to construct, compose, test, and explain.
The central question for every future contribution is:
What invalid Transformer state should this type make harder to express?
Terminal Output Checkpoint Map
The companion example prints many lines because it is acting as a shape lab. Before reading the typed transformation list, group the terminal output into checkpoints.
| Printed checkpoint | What changed | What stayed true | Boundary to protect |
|---|---|---|---|
| attention shape: 2 query positions x 3 key positions | query-key scoring produced one row per query and one column per key | score rows are still tied to query positions | QuerySequence x KeySequence -> AttentionScores |
| query 0 attends with [0.5, 0.0, 0.5] | masked scores became row-wise weights | the masked key position has zero contribution | AttentionScores x AttentionMask -> AttentionScores -> AttentionWeights |
| query 0 output vector [2.0, 20.0] | weights mixed value rows into one output row | one output row still belongs to one query position | AttentionWeights x ValueSequence -> AttentionOutput |
| multi-head shape: 2 heads x 2 features -> model dimension 4 | separate head outputs were concatenated | sequence length stayed two positions | AttentionHeadOutputs -> MultiHeadOutput |
| projected attention shape: 2 positions x model dimension 2 | concatenated width returned to model width | each row can now rejoin the residual stream | MultiHeadOutput -> ProjectedAttentionOutput |
| residual shape: 2 positions x model dimension 2 | projected attention was added to the input hidden rows | public object is still HiddenSequence | HiddenSequence x ProjectedAttentionOutput -> HiddenSequence |
| normalized shape and feed-forward shape | values changed inside each row | sequence length and model dimension stayed stable | HiddenSequence -> HiddenSequence |
| structured transformer logits shape | hidden rows became vocabulary scores per position | logits are not probabilities yet | HiddenSequence x AttentionMask -> SequenceLogits |
| training state step: 0 -> 1 | parameters and step count advanced | the training object stayed whole | TransformerTrainingState -> TransformerTrainingState |
| readout loss after one update | readout parameters learned from token targets | the update still returns full training state | readout endomorphism |
| feed-forward loss after one local update | the feed-forward sublayer learned a hidden target | the update still returns full training state | local feed-forward endomorphism |
| block loss after one composed update | readout, feed-forward, attention projections, and normalization parameters moved together | the outer state shape stayed stable | composed block endomorphism |
Use this map to avoid a common Transformer-reading mistake: treating every printed vector as “attention.” The output actually moves through three different ideas:
where to read
-> what information to read
-> how to return to the hidden-state stream
Then the training lines ask a separate question:
which parameters moved, and did the update preserve the state object?
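The next sketch follows one query through the three ideas in order: mask the scores, normalize them with a row-wise softmax, then mix value rows. The scores and value rows are hypothetical numbers. They were chosen so that masking key 1 reproduces the printed weights [0.5, 0.0, 0.5] and output [2.0, 20.0]. The companion example's actual inputs may differ.

```rust
// Mask before softmax: illegal keys get -infinity, so exp() gives exactly 0.
fn apply_mask(scores: &[f64], allowed: &[bool]) -> Vec<f64> {
    scores
        .iter()
        .zip(allowed)
        .map(|(&s, &ok)| if ok { s } else { f64::NEG_INFINITY })
        .collect()
}

// Row-wise softmax: AttentionScores -> AttentionWeights.
// Assumes at least one key is allowed; an all-masked row would divide by zero.
fn softmax_row(scores: &[f64]) -> Vec<f64> {
    let max = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = scores.iter().map(|s| (s - max).exp()).collect();
    let total: f64 = exps.iter().sum();
    exps.iter().map(|e| e / total).collect()
}

// Value mixing: AttentionWeights x ValueSequence -> one output row.
fn mix_values(weights: &[f64], values: &[Vec<f64>]) -> Vec<f64> {
    let mut out = vec![0.0; values[0].len()];
    for (w, row) in weights.iter().zip(values) {
        for (o, v) in out.iter_mut().zip(row) {
            *o += w * v;
        }
    }
    out
}

fn main() {
    let scores = [1.0, 5.0, 1.0]; // key 1 scores highest...
    let allowed = [true, false, true]; // ...but is masked out for query 0
    let values = vec![vec![1.0, 10.0], vec![100.0, 100.0], vec![3.0, 30.0]];
    let weights = softmax_row(&apply_mask(&scores, &allowed));
    let output = mix_values(&weights, &values);
    println!("query 0 attends with {weights:?}");
    println!("query 0 output vector {output:?}");
}
```

Reversing the order is the rejected shortcut. Taking the softmax of [1.0, 5.0, 1.0] first gives key 1 about 96% of the mass. Zeroing it afterwards leaves weights that no longer sum to one.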
Example Output Transfer Checklist
After running the companion example, read the printed transformation list as a boundary report. For each line, ask four questions:
What Rust object is being produced?
What ML role does it play?
What shape must remain true?
What shortcut would break the next composition?
| Example output line | Boundary to own | Shortcut to reject |
|---|---|---|
| HiddenSequence -> QuerySequence | Hidden rows become question-like vectors for scoring. | Reusing raw hidden rows as queries without a named projection. |
| HiddenSequence -> KeySequence | The same hidden rows become comparison vectors. | Treating keys and queries as the same role because they share a source. |
| HiddenSequence -> ValueSequence | The same hidden rows become information vectors to be mixed. | Mixing keys or queries as if they were values. |
| QuerySequence x KeySequence -> AttentionScores | Scores are unnormalized similarity numbers. | Reading scores as probabilities or final attention output. |
| AttentionScores x AttentionMask -> AttentionScores | The mask removes illegal positions before normalization. | Applying softmax first, then hiding positions after probability mass has already moved. |
| AttentionScores -> AttentionWeights | Row-wise softmax turns scores into weights. | Combining values before a normalized weight object exists. |
| AttentionWeights x ValueSequence -> AttentionOutput | Values are mixed only after weights exist. | Asking scores to carry both similarity and information. |
| AttentionHeadOutputs -> MultiHeadOutput | Head outputs concatenate, so width is head_count * head_dimension. | Pretending the concatenated width is already the model dimension. |
| MultiHeadOutput -> ProjectedAttentionOutput | The output projection returns concatenated heads to model width. | Adding unprojected multi-head output directly to the residual stream. |
| HiddenSequence x ProjectedAttentionOutput -> HiddenSequence | Residual addition preserves sequence length and model dimension. | Adding two objects whose row widths do not match. |
| LayerNormalization : HiddenSequence -> HiddenSequence | Values change while the public hidden-sequence shape stays stable. | Treating normalization as a projection to a new domain object. |
| PositionWiseFeedForward : HiddenSequence -> HiddenSequence | Each row may expand internally, but the sublayer returns model width. | Letting the hidden expansion leak past the sublayer boundary. |
| TransformerTrainingState -> TransformerTrainingState | Training updates parameters while preserving learning rate and step count. | Returning only changed weights and forcing the next step to reconstruct state. |
This checklist is the transfer bridge from paper notation and framework documentation to this repository’s Rust style. The original Transformer uses query, key, value, masking, softmax, value mixing, multi-head concatenation, output projection, residual paths, normalization, and feed-forward sublayers. Diagrammatic attention research also treats attention as something that can be decomposed into recurring components before variants are compared. Framework APIs compress much of that into one function call. This chapter uncompresses the path so the reader can point at each intermediate Rust type and say what invalid connection it prevents.
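The sketch below shows why the output projection earns its own type. The newtypes wrapping row-major matrices are hypothetical, and the projection weights are illustrative, not the repository's. Concatenating heads yields width head_count * head_dimension. The residual addition accepts only ProjectedAttentionOutput. The shortcut HiddenSequence x MultiHeadOutput -> HiddenSequence is therefore rejected by the compiler instead of failing on a width mismatch at run time.

```rust
#[derive(Clone, Debug, PartialEq)]
struct HiddenSequence(Vec<Vec<f64>>);
#[derive(Debug)]
struct MultiHeadOutput(Vec<Vec<f64>>);
#[derive(Debug)]
struct ProjectedAttentionOutput(Vec<Vec<f64>>);

// AttentionHeadOutputs -> MultiHeadOutput: concatenate each position's head rows.
fn concat_heads(heads: &[Vec<Vec<f64>>]) -> MultiHeadOutput {
    let positions = heads[0].len();
    MultiHeadOutput(
        (0..positions)
            .map(|p| heads.iter().flat_map(|h| h[p].iter().copied()).collect())
            .collect(),
    )
}

// MultiHeadOutput -> ProjectedAttentionOutput: multiply by W_o (concat_width x model_dim).
fn project(x: &MultiHeadOutput, w_o: &[Vec<f64>]) -> ProjectedAttentionOutput {
    let model_dim = w_o[0].len();
    ProjectedAttentionOutput(
        x.0.iter()
            .map(|row| {
                (0..model_dim)
                    .map(|j| row.iter().zip(w_o).map(|(a, w_row)| a * w_row[j]).sum::<f64>())
                    .collect()
            })
            .collect(),
    )
}

// HiddenSequence x ProjectedAttentionOutput -> HiddenSequence.
// Passing a MultiHeadOutput as the second argument is a type error.
fn residual_add(hidden: &HiddenSequence, sublayer: &ProjectedAttentionOutput) -> HiddenSequence {
    HiddenSequence(
        hidden.0.iter()
            .zip(&sublayer.0)
            .map(|(h, s)| h.iter().zip(s).map(|(a, b)| a + b).collect())
            .collect(),
    )
}

fn main() {
    // 2 heads, each with 2 positions x head_dimension 2.
    let heads = vec![
        vec![vec![1.0, 0.0], vec![0.0, 1.0]],
        vec![vec![2.0, 0.0], vec![0.0, 2.0]],
    ];
    let concat = concat_heads(&heads); // width 2 * 2 = 4
    // W_o maps width 4 back to model dimension 2.
    let w_o = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0], vec![0.0, 1.0]];
    let projected = project(&concat, &w_o);
    let hidden = HiddenSequence(vec![vec![0.5, 0.5], vec![0.5, 0.5]]);
    let next = residual_add(&hidden, &projected);
    println!("concat width {}, residual rows {:?}", concat.0[0].len(), next.0);
    // residual_add(&hidden, &concat); // rejected: expected ProjectedAttentionOutput
}
```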
Category Shape Diagnostic
The printed Transformer path uses several category-theory shapes that look similar if you only read the arrows. Before naming a boundary, ask two questions:
How many inputs does this boundary require?
Does it return the same public object, or a different object?
Those two questions prevent a common mistake: calling every shape-preserving
line an endomorphism. A true endomorphism in this book has the form
A -> A. A product-input boundary such as A x B -> A may return the same
object as its left input, but it still needs extra information.
There is a second kind of extra information: learned parameters stored inside
a layer object. When this diagnostic names
LayerNormalization : HiddenSequence -> HiddenSequence,
PositionWiseFeedForward : HiddenSequence -> HiddenSequence, or
MultiHeadTransformerBlock : HiddenSequence -> HiddenSequence, read it as a
forward call for one fixed layer or block instance. The scale, shift, weights,
and biases are already inside that Rust value. If those parameters are being
changed, the boundary has moved to training state:
TransformerTrainingState -> TransformerTrainingState
That distinction keeps the chapter honest. It allows a fixed module call to be shape-preserving without pretending parameter learning has disappeared.
The two diagnostic questions also prevent a second mistake: importing an advanced categorical name too early. Research on self-attention as a parametric endofunctor is useful for the linear portions of self-attention, especially query, key, value, positional, and layered structure. It does not make the whole pedagogical block a single endofunctor in this book. Softmax, masking, residual addition, normalization, feed-forward refinement, and training state each still need their own typed boundary.
Research on the anatomy of attention supports the opposite teaching move: decompose attention first, then compare variants. In this book, the decomposition is not a full diagrammatic formalism. It is a Rust teaching contract: every component must have a named type, a boundary shape, and a failure it prevents.
Categorical deep-learning research also separates architecture constraints from implementations. That distinction is useful here because the Rust code is an implementation witness for one small boundary at a time, not a proof that a future full Transformer satisfies every intended constraint. A good chapter claim should say which side it is on:
architecture constraint:
what should remain true?
implementation boundary:
which Rust type, constructor, example, or test currently enforces it?
Do not call the whole block an endofunctor when the explanation only checked one internal linear path. In this chapter, use the smaller safe name first: ordinary morphism, product-input morphism, shape-preserving endomorphism, state endomorphism, or illegal attempted composition.
The decision flow is:
```mermaid
flowchart TD
  B["Boundary shape"] --> T{"Does it type-check?"}
  T -->|"no"| I["Illegal attempted composition: name the missing conversion"]
  T -->|"yes"| C{"How many inputs are visible?"}
  C -->|"one input"| O{"Same whole source and target object?"}
  O -->|"yes"| E["Endomorphism: A -> A"]
  O -->|"no"| M["Ordinary morphism: A -> B"]
  C -->|"product input"| F{"Was one context fixed first?"}
  F -->|"yes"| U["Induced unary view: name the fixed context"]
  F -->|"no"| P["Product-input morphism: keep A x B visible"]
```
The same naming rule as a compact rendered math view:
$$
\begin{array}{rcl}
A \to B &:& \text{ordinary morphism} \\
A \to A &:& \text{endomorphism} \\
A \times B \to C &:& \text{product-input morphism} \\
A \times B \to A &:& \text{not automatically an endomorphism} \\
A \xrightarrow{f_b} A &:& \text{fixed-context induced endomorphism, after } b \text{ is fixed}
\end{array}
$$
How to read this diagram:
- count the visible inputs before naming the category shape,
- compare the whole source object with the whole target object,
- fix context explicitly before using a unary view,
- reject same-output shortcuts that ignore product inputs.
Read the diagram from top to bottom before naming an attention boundary. It is only a local naming aid, but it prevents three common shortcuts:
| Shortcut | Safer move |
|---|---|
| output shape matches, so the boundary is an endomorphism | compare the whole source object with the target object |
| the product can be read as one source, so the boundary is an endomorphism | check whether the target is the same product object |
| context was fixed in prose, so the original boundary had one input | name the open boundary first, then name the fixed context |
Two-Minute Classification Drill
Before reading the longer table, classify these boundaries yourself. Cover the right column, count the inputs, then decide whether the output returns to the same whole object.
| Boundary | Question to ask first | Safe classification |
|---|---|---|
| HiddenSequence -> QuerySequence | one input, different output object? | ordinary morphism |
| AttentionScores x AttentionMask -> AttentionScores | two inputs, returns the left object? | product-input morphism returning the score object |
| LayerNormalization : HiddenSequence -> HiddenSequence | one input, same output object? | shape-preserving endomorphism |
| HiddenSequence x ProjectedAttentionOutput -> HiddenSequence | two inputs, returns the left object? | product-input morphism returning hidden state |
| TransformerTrainingState -> TransformerTrainingState | one input, same whole training object? | state endomorphism |
The trap is the second and fourth rows. Returning the left-hand object is not
the same as being an endomorphism. A boundary that still needs a mask, value
sequence, projected sublayer output, dataset, or learning rate is not a pure
A -> A story until that context is explicitly fixed.
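The flowchart's decision flow can also be written as a small Rust function, which makes each branch testable against the drill rows. The enum, struct, and function names are hypothetical teaching helpers, not part of the repository. Their inputs are the answers a reader gives to the flowchart's questions.

```rust
#[derive(Debug, PartialEq)]
enum BoundaryShape {
    IllegalComposition,
    Endomorphism,
    OrdinaryMorphism,
    InducedUnaryView,
    ProductInputMorphism,
}

// Answers to the flowchart's questions for one boundary.
struct BoundaryFacts {
    type_checks: bool,
    visible_inputs: usize,
    same_source_and_target: bool, // whole source object vs whole target object
    context_fixed_first: bool,
}

fn classify(f: &BoundaryFacts) -> BoundaryShape {
    if !f.type_checks {
        return BoundaryShape::IllegalComposition;
    }
    if f.visible_inputs == 1 {
        if f.same_source_and_target {
            BoundaryShape::Endomorphism
        } else {
            BoundaryShape::OrdinaryMorphism
        }
    } else if f.context_fixed_first {
        BoundaryShape::InducedUnaryView
    } else {
        BoundaryShape::ProductInputMorphism
    }
}

fn main() {
    // HiddenSequence x ProjectedAttentionOutput -> HiddenSequence:
    // it returns the left object, but two inputs are still visible.
    let residual = BoundaryFacts {
        type_checks: true,
        visible_inputs: 2,
        same_source_and_target: false,
        context_fixed_first: false,
    };
    println!("{:?}", classify(&residual)); // ProductInputMorphism
}
```

The function never inspects the output type alone. The input count is checked before any "same object" question is asked, which is the point of the drill's trap rows.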
Source-Target Audit Card
Use this card when a row still feels ambiguous. Do not start from the output type. Name the whole source object, then name the target object.
| Boundary | Whole source object | Target object | Context status | Safe conclusion |
|---|---|---|---|---|
| HiddenSequence -> QuerySequence | HiddenSequence | QuerySequence | no extra context in the boundary | ordinary morphism |
| AttentionScores x AttentionMask -> AttentionScores | AttentionScores x AttentionMask | AttentionScores | mask is open context | product-input morphism, not an endomorphism on scores |
| MaskedMultiHeadTransformerBlock[M] : HiddenSequence -> HiddenSequence | HiddenSequence | HiddenSequence | one mask M was fixed first | induced endomorphism for that named mask |
| HiddenSequence x ProjectedAttentionOutput -> HiddenSequence | HiddenSequence x ProjectedAttentionOutput | HiddenSequence | residual input is open context | product-input morphism returning hidden state |
| TransformerTrainingState -> TransformerTrainingState | TransformerTrainingState | TransformerTrainingState | update context is inside the state object | state endomorphism |
The second and fourth rows are unary only if you choose to regard the product as one source object, but they are still not endomorphisms. Their targets are not the same product object. The fixed-mask row is different because the mask has been selected before the remaining call.
Linear Scope Diagnostic
Use this when an external source gives a categorical reading of self-attention. First ask which part of the attention path the source actually classified.
| Attention part | Boundary in this roadmap | Safe reading here |
|---|---|---|
| query projection | HiddenSequence -> QuerySequence | linear role-producing morphism |
| key projection | HiddenSequence -> KeySequence | linear role-producing morphism |
| value projection | HiddenSequence -> ValueSequence | linear role-producing morphism |
| score construction | QuerySequence x KeySequence -> AttentionScores | product-input boundary |
| mask application | AttentionScores x AttentionMask -> AttentionScores | product-input boundary, not a pure score endomorphism |
| score normalization | AttentionScores -> AttentionWeights | nonlinear normalization boundary |
| value mixing | AttentionWeights x ValueSequence -> AttentionOutput | product-input boundary |
| residual addition | HiddenSequence x ProjectedAttentionOutput -> HiddenSequence | product-input boundary returning hidden state |
| layer normalization for a fixed layer instance | HiddenSequence -> HiddenSequence | shape-preserving but nonlinear endomorphism |
| parameter-changing training update | TransformerTrainingState -> TransformerTrainingState | state endomorphism over the whole training object |
Safe rule:
If a claim was checked for linear Q/K/V maps, do not carry it through softmax,
masking, residual addition, layer normalization, feed-forward refinement, or
training state without naming the next boundary.
This keeps the chapter usable for two readers at once. The category-theory reader sees where a stronger formal story might attach. The ML engineer sees which implemented boundary still needs its own shape, invariant, and test.
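The following numeric sketch shows why a claim checked for the Q/K/V maps does not automatically pass through softmax. A projection x ↦ xW satisfies f(a·x + b·y) = a·f(x) + b·f(y). Softmax fails the same check because its outputs always sum to one. The matrix, vectors, and coefficients are arbitrary illustrative values.

```rust
// Row vector times matrix: the shape of a query, key, or value projection.
fn project(x: &[f64], w: &[Vec<f64>]) -> Vec<f64> {
    (0..w[0].len())
        .map(|j| x.iter().zip(w).map(|(xi, row)| xi * row[j]).sum::<f64>())
        .collect()
}

fn softmax(x: &[f64]) -> Vec<f64> {
    let max = x.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = x.iter().map(|v| (v - max).exp()).collect();
    let total: f64 = exps.iter().sum();
    exps.iter().map(|e| e / total).collect()
}

fn scale(v: &[f64], a: f64) -> Vec<f64> {
    v.iter().map(|x| a * x).collect()
}

fn add(u: &[f64], v: &[f64]) -> Vec<f64> {
    u.iter().zip(v).map(|(a, b)| a + b).collect()
}

fn close(u: &[f64], v: &[f64]) -> bool {
    u.iter().zip(v).all(|(a, b)| (a - b).abs() < 1e-9)
}

// Numerically checks f(a*x + b*y) == a*f(x) + b*f(y) for one choice of inputs.
fn is_linear_on(f: &dyn Fn(&[f64]) -> Vec<f64>, x: &[f64], y: &[f64], a: f64, b: f64) -> bool {
    let lhs = f(&add(&scale(x, a), &scale(y, b)));
    let rhs = add(&scale(&f(x), a), &scale(&f(y), b));
    close(&lhs, &rhs)
}

fn main() {
    let w = vec![vec![1.0, 2.0], vec![0.5, -1.0]];
    let (x, y) = ([1.0, 2.0], [3.0, -1.0]);
    let query_projection = |v: &[f64]| project(v, &w);
    println!("projection linear: {}", is_linear_on(&query_projection, &x, &y, 2.0, -0.5));
    println!("softmax linear:    {}", is_linear_on(&softmax, &x, &y, 2.0, -0.5));
}
```

A single passing check is only evidence for these inputs. A single failing check, however, is enough to show that softmax cannot inherit a linearity-based argument.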
Worked Classification: Same Output, Different Shape
The most tempting mistake is to look only at the output type. Three boundaries
below all end with HiddenSequence, but they do not have the same category
shape.
| Boundary | Count the inputs | Classification | Why |
|---|---|---|---|
| LayerNormalization : HiddenSequence -> HiddenSequence | one input | endomorphism | the whole input object and output object are both HiddenSequence |
| HiddenSequence x ProjectedAttentionOutput -> HiddenSequence | two inputs | product-input morphism returning HiddenSequence | residual addition needs both the old stream and the projected sublayer output |
| HiddenSequence x MultiHeadOutput -> HiddenSequence | two inputs, wrong second object | illegal attempted boundary | residual addition needs projected model-width output, not raw concatenated heads |
The first boundary can safely be named an endomorphism in this book. The second
cannot, even though it returns HiddenSequence, because the full input object
is not HiddenSequence; it is a product of two objects. The third should not
receive a category-theory name yet. It is missing the output projection that
makes the residual path well typed.
This gives a short decision tree:
Does the boundary type-check?
no -> name the missing conversion first
yes -> count the inputs
one input -> compare input object and output object
two inputs -> keep the product-input boundary visible
Use this naming rule before reading the table:
1. Count the inputs.
2. If there is one input, compare the input object and output object.
3. If there is a product input, keep the product in the name.
4. If a required projection or conversion is missing, call it illegal before
giving it a category-theory name.
That gives five safe cases:
| Shape | Safe name | Example |
|---|---|---|
| A -> B | ordinary morphism | AttentionScores -> AttentionWeights |
| A -> A | endomorphism | LayerNormalization : HiddenSequence -> HiddenSequence |
| A x B -> C | product-input morphism | AttentionWeights x ValueSequence -> AttentionOutput |
| A x B -> A | product-input morphism returning A | HiddenSequence x ProjectedAttentionOutput -> HiddenSequence |
| A x B -> A with the wrong B | illegal attempted boundary | HiddenSequence x MultiHeadOutput -> HiddenSequence |
The fourth row is the trap. Returning the left object is not enough to make a boundary an endomorphism. The whole input must be one object, and the output must be that same object.
There is a second safe reading that is useful but different. You may choose to treat the product itself as one source object:
(A x B) -> A
That makes the arrow unary from the product object, but it still is not an
endomorphism. The source object is A x B; the target object is A. An
endomorphism on the product would have shape:
(A x B) -> (A x B)
This is why the roadmap keeps the phrase “product-input morphism returning
A” instead of shortening it to “endomorphism on A.”
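The Rust type checker draws the same line. In this sketch, `iterate` is a hypothetical helper that only accepts an endomorphism T -> T, and the one-row types are toy stand-ins. A residual function from the product to HiddenSequence cannot be passed to `iterate`. A function that returns the whole product can. Here that function simply carries the sublayer output along unchanged, one illustrative way to thread context, not a model of a real block.

```rust
#[derive(Clone, Debug, PartialEq)]
struct HiddenSequence(Vec<f64>);
#[derive(Clone, Debug, PartialEq)]
struct ProjectedAttentionOutput(Vec<f64>);

// Applies an endomorphism n times; the signature demands T -> T.
fn iterate<T>(f: impl Fn(T) -> T, mut x: T, n: usize) -> T {
    for _ in 0..n {
        x = f(x);
    }
    x
}

// (A x B) -> A: unary from the product object, but not an endomorphism.
fn residual((h, p): (HiddenSequence, ProjectedAttentionOutput)) -> HiddenSequence {
    HiddenSequence(h.0.iter().zip(&p.0).map(|(a, b)| a + b).collect())
}

// (A x B) -> (A x B): an endomorphism on the product that keeps the second component.
fn residual_keep_context(
    (h, p): (HiddenSequence, ProjectedAttentionOutput),
) -> (HiddenSequence, ProjectedAttentionOutput) {
    let next = residual((h, p.clone()));
    (next, p)
}

fn main() {
    let start = (HiddenSequence(vec![0.0, 1.0]), ProjectedAttentionOutput(vec![1.0, 1.0]));
    let (h, _) = iterate(residual_keep_context, start, 3);
    println!("{:?}", h); // HiddenSequence([3.0, 4.0])
    // iterate(residual, start, 3); // rejected: output type differs from input type
}
```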
Terminal Output Audit: Shape Line Is Not Boundary Shape
The runnable example prints several lines with the same public dimensions:
cargo run --example 06_attention_scores
Those lines are useful evidence, but they are not category names by themselves. A printed shape tells you something about the target object. The typed transformation line tells you the whole source object and the target object.
| Printed output line | What the line proves | What it does not prove | Boundary to name |
|---|---|---|---|
| projected attention shape: 2 positions x model dimension 2 | raw head output has been projected back to model width | the residual connection has already happened | MultiHeadOutput -> ProjectedAttentionOutput |
| residual shape: 2 positions x model dimension 2 | the result has returned to hidden-sequence shape | residual addition was unary | HiddenSequence x ProjectedAttentionOutput -> HiddenSequence |
| masked multi-head block shape: 2 positions x model dimension 2 | the block output can feed the next hidden-sequence layer | the open masked block is a pure endomorphism | MaskedMultiHeadTransformerBlock : HiddenSequence x AttentionMask -> HiddenSequence |
| training state step: 0 -> 1 | the update returns a state that can be updated again | training is a loose Loss -> Parameters shortcut | TransformerTrainingState -> TransformerTrainingState |
Use this three-step audit whenever the terminal output seems to settle the category name too quickly:
printed shape line -> evidence about the target object
typed transformation line -> evidence about the source and target objects
category name -> only after both source and target are known
That is why residual shape and masked multi-head block shape can both show
model-width hidden rows while still having different safe category readings.
Same printed dimensions are not the same boundary.
Stackability With Context
Stacking means the output of one boundary can feed the next boundary without inventing missing inputs. A direct endomorphism can stack by itself. A product-input boundary can stack only if the extra context is carried along or fixed explicitly.
For learned sublayers, “direct” means the layer instance is fixed for the
forward call. A different LayerNormalization or PositionWiseFeedForward
value is a different morphism. Changing those parameters is training-state
work, so the safe outer name is TransformerTrainingState -> TransformerTrainingState.
| Boundary | Can it stack directly as HiddenSequence -> HiddenSequence? | Safe reading |
|---|---|---|
| LayerNormalization : HiddenSequence -> HiddenSequence | yes | direct shape-preserving endomorphism |
| MultiHeadTransformerBlock : HiddenSequence -> HiddenSequence | yes | direct block endomorphism |
| MaskedMultiHeadTransformerBlock : HiddenSequence x AttentionMask -> HiddenSequence | no, not while the mask is open | product-input morphism that still needs mask context |
| fixed-mask view of MaskedMultiHeadTransformerBlock | yes, for that named mask context | induced endomorphism after context is fixed |
| TransformerTrainingState -> TransformerTrainingState | yes | state endomorphism over the whole training object |
This is the same discipline as the rest of the chapter. Do not erase context to make a category name fit. If a mask, dataset, learning rate, or parameter object is part of the boundary, either keep it in the type shape or say exactly where it was fixed.
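Here is a minimal sketch of a fixed layer instance. It assumes a hypothetical LayerNormalization struct that owns its scale and shift vectors. The forward call takes `&self` and a HiddenSequence, and it returns a HiddenSequence with the same row count and width. Because nothing in the call can change scale or shift, the layer stacks directly. For the same reason, parameter learning has to happen at the TransformerTrainingState boundary.

```rust
#[derive(Clone, Debug, PartialEq)]
struct HiddenSequence(Vec<Vec<f64>>);

// Hypothetical fixed layer: learned scale and shift live inside the value.
struct LayerNormalization {
    scale: Vec<f64>,
    shift: Vec<f64>,
    epsilon: f64,
}

impl LayerNormalization {
    // HiddenSequence -> HiddenSequence for this one fixed instance.
    // Takes &self, so a forward call cannot mutate scale or shift.
    fn forward(&self, input: &HiddenSequence) -> HiddenSequence {
        HiddenSequence(
            input.0.iter()
                .map(|row| {
                    let n = row.len() as f64;
                    let mean = row.iter().sum::<f64>() / n;
                    let var = row.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
                    let inv_std = 1.0 / (var + self.epsilon).sqrt();
                    row.iter()
                        .enumerate()
                        .map(|(i, x)| self.scale[i] * (x - mean) * inv_std + self.shift[i])
                        .collect()
                })
                .collect(),
        )
    }
}

fn main() {
    let norm = LayerNormalization { scale: vec![1.0, 1.0], shift: vec![0.0, 0.0], epsilon: 1e-5 };
    let hidden = HiddenSequence(vec![vec![1.0, 3.0], vec![10.0, 10.0]]);
    // Shape-preserving, so the same fixed instance can be stacked directly.
    let stacked = norm.forward(&norm.forward(&hidden));
    println!("{:?}", stacked);
}
```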
Context Fixing Drill
The open masked block and a fixed-mask view are related, but they are not the same boundary:
open boundary:
MaskedMultiHeadTransformerBlock : HiddenSequence x AttentionMask -> HiddenSequence
fixed-context boundary:
choose one mask M
MaskedMultiHeadTransformerBlock[M] : HiddenSequence -> HiddenSequence
The fixed-context boundary is a new named view after selecting M.
It is not a claim that the original block had only one input. The source of
the context must stay visible in the prose, the exercise, or the type that
carries it.
| Case | What is fixed? | Safe category shape | Can it stack as HiddenSequence -> HiddenSequence? | Overclaim to avoid |
|---|---|---|---|---|
| open masked block | nothing | product-input morphism | no | “it returns HiddenSequence, so it is an endomorphism” |
| fixed-mask view | one named AttentionMask | induced endomorphism for that mask | yes, while the same mask context remains fixed | “the mask disappeared” |
| changing mask per call | the mask is supplied again each call | repeated product-input calls, or a larger state carrying the mask | only if the caller threads or fixes the context | “this is the same as a fixed-mask view” |
| residual addition | no input is fixed; both hidden stream and projected output are supplied | product-input morphism returning hidden state | no | “a binary operation is unary because the result is hidden state” |
The residual row is a negative contrast. It is not a context-fixing example. The hidden stream is still an input, and the projected sublayer output is still an input. Nothing has been selected in advance. The boundary therefore remains:
HiddenSequence x ProjectedAttentionOutput -> HiddenSequence
It returns HiddenSequence, but it does not become a unary
HiddenSequence -> HiddenSequence boundary unless one input has actually been
fixed. If the product object itself is named as the source, the arrow is unary
from that product object:
(HiddenSequence x ProjectedAttentionOutput) -> HiddenSequence
That is still not an endomorphism, because the target is not the same product object. An endomorphism on the named product would have to return the whole product again:
(HiddenSequence x ProjectedAttentionOutput)
-> (HiddenSequence x ProjectedAttentionOutput)
This is only a local teaching use of fixing context. It is not a proof that the whole attention block lives in a closed category, and it is not permission to hide arbitrary inputs. The practical rule stays simple:
name the open boundary first
name exactly what was fixed
then name the induced unary view
Rust already has a familiar mechanism for this idea: a closure can capture a value from the surrounding environment. The official Rust Book uses closures to teach how a callable value can remember its environment. In this roadmap, a fixed-mask view can be read the same way:
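```rust
// `mask`, `masked_block`, `Product`, and `HiddenSequence` are assumed in scope.
let fixed_mask = mask.clone();
// `move` transfers fixed_mask (and masked_block) into the closure;
// cloning per call keeps the closure reusable.
let fixed_mask_view = move |hidden: HiddenSequence| {
    masked_block.apply(Product::new(hidden, fixed_mask.clone()))
};
```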
That closure-shaped explanation is only an analogy for this local boundary. It does not change the original open type:
MaskedMultiHeadTransformerBlock : HiddenSequence x AttentionMask -> HiddenSequence
It says how one chosen AttentionMask can be captured before the remaining
call receives HiddenSequence. If a reader cannot point to the captured
fixed_mask, the text has hidden context instead of fixing it.
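The closure analogy can be made runnable with toy types. Everything here is a hypothetical stand-in. `masked_block` is a deliberately tiny function of shape HiddenSequence x AttentionMask -> HiddenSequence, in which each output row is the mean of the input rows it may attend to. It is not the repository's MaskedMultiHeadTransformerBlock. The open function still takes two arguments. The closure captures one named mask, so it can be stacked as HiddenSequence -> HiddenSequence.

```rust
#[derive(Clone, Debug, PartialEq)]
struct HiddenSequence(Vec<Vec<f64>>);

// mask.0[q][k] == true means query position q may read key position k.
#[derive(Clone, Debug)]
struct AttentionMask(Vec<Vec<bool>>);

// Open boundary: HiddenSequence x AttentionMask -> HiddenSequence.
// Toy rule: each output row is the mean of the allowed input rows
// (assumes every mask row allows at least one position).
fn masked_block(hidden: HiddenSequence, mask: &AttentionMask) -> HiddenSequence {
    let width = hidden.0[0].len();
    let rows = mask.0.iter().map(|allowed| {
        let picked: Vec<&Vec<f64>> = hidden.0.iter()
            .zip(allowed)
            .filter(|&(_, &ok)| ok)
            .map(|(row, _)| row)
            .collect();
        let n = picked.len() as f64;
        (0..width)
            .map(|j| picked.iter().map(|r| r[j]).sum::<f64>() / n)
            .collect::<Vec<f64>>()
    });
    HiddenSequence(rows.collect())
}

fn main() {
    // Choose one causal mask M for two positions, then capture it.
    let mask = AttentionMask(vec![vec![true, false], vec![true, true]]);
    let fixed_mask = mask.clone();
    // Induced unary view: HiddenSequence -> HiddenSequence for this M only.
    let fixed_mask_view = move |hidden: HiddenSequence| masked_block(hidden, &fixed_mask);

    let input = HiddenSequence(vec![vec![2.0, 0.0], vec![4.0, 8.0]]);
    // Because the view is unary, it stacks like any endomorphism.
    let output = fixed_mask_view(fixed_mask_view(input));
    println!("{:?}", output);
}
```

The captured `fixed_mask` is the named context. Supplying a different mask would mean building a different view, not reusing this one.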
| Boundary | Category shape to name | Why this is the right name | Common misread |
|---|---|---|---|
| QuerySequence x KeySequence -> AttentionScores | product-input morphism | scoring needs query rows and key rows | treating scores as a unary query transform |
| AttentionScores x AttentionMask -> AttentionScores | product-input morphism returning the score object | the mask is extra evidence used before softmax | calling it a pure endomorphism on scores |
| AttentionScores -> AttentionWeights | ordinary morphism | raw scores become normalized rows | treating weights as the same object as scores |
| AttentionWeights x ValueSequence -> AttentionOutput | product-input morphism | weights decide which value rows to read | treating keys and values as interchangeable |
| MultiHeadOutput -> ProjectedAttentionOutput | ordinary morphism | concatenated head width returns to model width | adding multi-head output directly to residual state |
| HiddenSequence x ProjectedAttentionOutput -> HiddenSequence | product-input morphism returning hidden state | residual addition needs both the old stream and the sublayer output | calling the binary residual operation a unary endomorphism |
| LayerNormalization : HiddenSequence -> HiddenSequence | shape-preserving endomorphism | values change while the public hidden object stays the same | treating normalization as a new sequence domain |
| PositionWiseFeedForward : HiddenSequence -> HiddenSequence | shape-preserving endomorphism | internal width may expand, but the public object returns unchanged | leaking the internal expansion into the next block |
| TransformerTrainingState -> TransformerTrainingState | state endomorphism | one update returns a complete object that can be updated again | returning only changed weights or only loss |
| HiddenSequence x MultiHeadOutput -> HiddenSequence | not a legal composed boundary | residual addition needs projected model-width output | skipping the output projection |
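The second and third rows can be checked on a single score row. This is a minimal sketch, assuming plain `f32` slices and a hypothetical `bool` keep-mask in place of the crate's `AttentionScores`, `AttentionMask`, and `AttentionSoftmax` types. Forbidden positions become negative infinity before the row-wise softmax, so they receive exactly zero weight and the remaining weights still sum to one. Zeroing the same slot after softmax leaves a row that no longer sums to one.

```rust
/// (scores x mask) -> scores: forbidden positions become -inf.
fn apply_mask(scores: &[f32], keep: &[bool]) -> Result<Vec<f32>, String> {
    if scores.len() != keep.len() {
        return Err("mask and score row differ in length".to_string());
    }
    Ok(scores
        .iter()
        .zip(keep)
        .map(|(s, k)| if *k { *s } else { f32::NEG_INFINITY })
        .collect())
}

/// scores -> weights: numerically stable softmax over one row.
fn row_softmax(scores: &[f32]) -> Result<Vec<f32>, String> {
    let max = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    if !max.is_finite() {
        // Every position is masked (or the row is empty): no valid distribution.
        return Err("no readable position in this row".to_string());
    }
    let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    Ok(exps.iter().map(|e| e / sum).collect())
}

fn main() -> Result<(), String> {
    let scores = [2.0_f32, 0.5, 3.0];
    let keep = [true, true, false];

    // Mask first, then normalize: the forbidden position gets exactly zero.
    let weights = row_softmax(&apply_mask(&scores, &keep)?)?;
    assert_eq!(weights[2], 0.0);
    assert!((weights.iter().sum::<f32>() - 1.0).abs() < 1e-6);

    // Normalize first, then zero the forbidden slot: the row is no longer a distribution.
    let late: Vec<f32> = row_softmax(&scores)?
        .iter()
        .zip(&keep)
        .map(|(w, k)| if *k { *w } else { 0.0 })
        .collect();
    assert!((late.iter().sum::<f32>() - 1.0).abs() > 0.1);
    Ok(())
}
```

The mask is a second input to the score stage, which is why the row is named as a product-input morphism rather than an endomorphism on scores.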
This diagnostic is the category-theory version of shape checking. The ML question is:
what information must this stage receive?
The Rust question is:
which type should own that information before the next call?
The category-theory question is:
is this a unary morphism, a product-input morphism, an endomorphism, or not
composable yet?
If a boundary needs two objects, write both. If it returns to the same public object, say whether that return is unary or product-input. Precision here is what keeps the roadmap from turning attention into a single vague arrow.
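A unary return to the same public object can still hide a wider internal stage. This is a minimal sketch of that shape, assuming a hypothetical `FeedForward` with row-major weight matrices and a ReLU activation; the crate's `PositionWiseFeedForward` may use different parameters and a different activation. The widened vector exists only inside `apply`. The returned row must have the model width again, and a mismatch is reported as an error rather than leaked into the next block.

```rust
/// Hypothetical model-width row for one position.
#[derive(Debug, Clone, PartialEq)]
struct HiddenRow(Vec<f32>);

/// Hypothetical feed-forward weights (assumed rectangular).
/// `w_in` is d_model x d_ff and `w_out` is d_ff x d_out.
struct FeedForward {
    w_in: Vec<Vec<f32>>,
    w_out: Vec<Vec<f32>>,
}

/// Row vector times matrix: x (len = rows) * w (rows x cols) -> len = cols.
fn vec_mat(x: &[f32], w: &[Vec<f32>]) -> Result<Vec<f32>, String> {
    if x.len() != w.len() {
        return Err(format!("input width {} != weight rows {}", x.len(), w.len()));
    }
    let cols = w.first().map_or(0, Vec::len);
    let mut out = vec![0.0; cols];
    for (xi, row) in x.iter().zip(w) {
        for (o, wij) in out.iter_mut().zip(row) {
            *o += xi * wij;
        }
    }
    Ok(out)
}

impl FeedForward {
    /// HiddenRow -> HiddenRow: expand, apply ReLU, contract back to model width.
    fn apply(&self, row: HiddenRow) -> Result<HiddenRow, String> {
        let d_model = row.0.len();
        let expanded: Vec<f32> = vec_mat(&row.0, &self.w_in)?
            .into_iter()
            .map(|v| v.max(0.0))
            .collect();
        let out = vec_mat(&expanded, &self.w_out)?;
        if out.len() != d_model {
            return Err(format!("public boundary broken: width {} != model width {d_model}", out.len()));
        }
        Ok(HiddenRow(out))
    }
}

fn main() -> Result<(), String> {
    // d_model = 2, d_ff = 4: the expansion stays internal.
    let ffn = FeedForward {
        w_in: vec![vec![1.0, 0.0, -1.0, 0.5], vec![0.0, 1.0, 1.0, 0.5]],
        w_out: vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0], vec![0.0, 0.0]],
    };
    assert_eq!(ffn.apply(HiddenRow(vec![1.0, 2.0]))?, HiddenRow(vec![2.0, 3.0]));

    // A contraction to the wrong width is rejected instead of reaching the next block.
    let broken = FeedForward { w_in: ffn.w_in.clone(), w_out: vec![vec![0.0; 3]; 4] };
    assert!(broken.apply(HiddenRow(vec![1.0, 2.0])).is_err());
    Ok(())
}
```

Compared with the residual step, this boundary has one input and returns the same public object, so it is the unary, shape-preserving case.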
Reader Evidence Handoff
If this diagnostic becomes unclear, the most useful report is not “attention is confusing.” The useful report names the exact rule that failed.
Use this shape:
Command: cargo run --example 06_attention_scores
Page: Transformer Roadmap -> Category Shape Diagnostic
Evidence signal: one boundary row or printed output line
Last clear idea: the last boundary name that still made sense
First unclear rule: input count, fixed context, legal composition, source role,
target role, or linear-scope limit
Smallest useful fix: one sentence, table row, diagram, or exercise check
Good evidence signals are small:
AttentionScores x AttentionMask -> AttentionScores
MaskedMultiHeadTransformerBlock[M] : HiddenSequence -> HiddenSequence
HiddenSequence x MultiHeadOutput -> HiddenSequence
query 0 attends with [0.5, 0.0, 0.5]
A report like that gives the next rewrite a concrete target: which boundary, which rule, and which reader expectation failed.
Open the chapter clarity feedback form with those fields filled from your own run or reading.
Retrieval Practice
Run the attention example before answering:
cargo run --example 06_attention_scores
Recall
Recover the named objects and boundaries before explaining them.
- Which three role objects are produced from HiddenSequence before attention scores are computed?
- Which printed line is the first point where attention scores become row-wise normalized weights?
- Which boundaries in the example preserve the public shape HiddenSequence -> HiddenSequence?
- Which three training steps share the outer shape TransformerTrainingState -> TransformerTrainingState?
- Which printed line tells you that multi-head width must be projected before it can rejoin the residual stream?
Explain
Use the type boundary to explain the reason for the design.
- Why must the mask act before row-wise softmax?
- Why does multi-head attention need an output projection before residual addition?
- Why is returning a full TransformerTrainingState safer than returning only changed readout or feed-forward weights?
The next questions check the scope of the evidence. A local gradient check and a shape-preserving sublayer are useful only when their boundaries are named.
- Why is a finite-difference check useful for one selected parameter without proving that every future training loop is correct?
- Why does PositionWiseFeedForward : HiddenSequence -> HiddenSequence permit an internal hidden expansion but not an expanded public output?
Apply
Change the numbers and check whether the same typed rule still holds.
- A block has three heads and each head produces four features per position. What input width must the output projection accept?
- A feed-forward sublayer expands a model-dimension-two row to six hidden features, then returns five features. Which public boundary has been broken?
- A training step updates the feed-forward weights but drops the learning rate. Why can the next training step no longer compose safely?
- A two-token sequence has raw scores [1.0, 9.0] for the first row, but the second position is masked out. Which object should record the forbidden position before softmax?
- A learner sees the line AttentionWeights x ValueSequence -> AttentionOutput and wants to replace ValueSequence with KeySequence. Which ML role has been lost?
- A learner sees HiddenSequence x ProjectedAttentionOutput -> HiddenSequence and calls it an endomorphism because the output is HiddenSequence. Which step of the naming rule corrects that mistake?
Debug
For each invalid shortcut, name the missing boundary:
HiddenSequence x MultiHeadOutput -> HiddenSequence
AttentionScores -> AttentionOutput
readout update -> changed readout weights
softmax scores -> masked weights
feed-forward hidden expansion -> next block input
finite-difference agreement -> full optimizer proof
Good answers should point back to a concrete type or transformation in this chapter, not only to a phrase such as “shape mismatch” or “training update.”
Repository Source Snapshots
This appendix collects the learner-facing source files used throughout the
course. The Rust snapshots are complete files, so they are marked ignore rather
than compiled as standalone snippets. The real source files are still validated
by the repository checks.
The problem this appendix solves is:
The explanatory chapters should never drift away from the real code.
Use this appendix as the exact-code lookup layer.
When you inspect any block here, use the same reading pattern as the chapters:
Rust syntax:
what does the code literally declare or execute?
ML or software concept:
why does this block exist in the pipeline or system model?
Category theory concept:
what object, morphism, product, composition, endomorphism, or law does it
represent?
The appendix is intentionally less narrative than the chapters. It keeps the full source available in one place so you can verify every explanation against the code itself.
How To Navigate The Snapshots
Use the snapshots in two directions.
When reading chapter-first, start with the explanation, then open the matching snapshot to verify the exact code. When reading code-first, start with the file that interests you, then return to the chapter that teaches its role.
| If you want to inspect | Start with | Then read |
|---|---|---|
| public crate surface | src/lib.rs | Course Map |
| typed values and invariants | src/domain.rs | Domain Objects |
| arrows and composition | src/category.rs | Morphism and Composition |
| prediction pipeline | src/ml.rs | The Tiny ML Pipeline |
| parameter updates | src/training.rs | Training as an Endomorphism |
| reusable structure | src/structure.rs | Functors, Naturality, Monoids, and Chain Rule |
| local derivative flow | src/calculus.rs | Functors, Naturality, Monoids, and Chain Rule |
| typed masked attention, value-mixing, head-concatenation, output-projection, residual, normalization, feed-forward, positional, hidden-projection, single-head block, multi-head block, masked-block, readout, parameter-object, training-state, readout-training, local feed-forward training, and composed block-training boundaries with query/key/value gradients | src/attention.rs | Transformer Roadmap |
| applied category-theory sketches | src/sketches.rs | Seven Sketches Through Rust |
| public challenge reference behavior | src/challenges/ and examples/challenge_adam.rs | Challenges |
| runnable end-to-end walkthrough | src/demo.rs | Course Map |
| command-line entrypoint | src/bin/category_ml.rs | Course Map |
The useful question is:
Which explanation in the book would become false if this source file changed?
That question is why the source snapshots exist. They make drift visible.
Reading Order By Goal
For a five-minute run, inspect examples/01_token_sequence.rs, then compare it
with the beginning of Course Map.
For the core book path, read src/domain.rs, src/category.rs, src/ml.rs,
and src/training.rs in that order.
For the structure path, read src/structure.rs, src/calculus.rs,
src/attention.rs, and src/sketches.rs after the tiny ML pipeline is clear.
For contribution work, read the chapter first, then the source file, then the tests in the same module. The tests often explain the contract more clearly than the implementation alone.
Rust Library Surface
src/lib.rs
//! A small, modular Rust tutorial for category-theory ideas in tiny ML.
//!
//! The crate is intentionally split into small learning modules:
//! - [`domain`] defines the nouns: tokens, vectors, probabilities, losses, and parameters.
//! - [`category`] defines the arrows: morphisms, identity, composition, and endomorphisms.
//! - [`ml`] implements concrete ML morphisms.
//! - [`training`] turns one optimizer step into an endomorphism on parameters.
//! - [`structure`] covers functors, natural transformations, and monoids.
//! - [`calculus`] shows the chain rule as a local backward pass.
//! - [`attention`] sketches typed attention boundaries and Transformer state for the roadmap.
//! - [`sketches`] gives Rust models for the seven applied-category-theory sketches.
//! - [`challenges`] backs the public Typed AI Rustlings and Paper-To-Rust tracks.
//! - [`demo`] connects the pieces into the terminal walkthrough.
pub mod attention;
pub mod calculus;
pub mod category;
pub mod challenges;
pub mod demo;
pub mod domain;
pub mod error;
pub mod ml;
pub mod sketches;
pub mod structure;
pub mod training;
pub use attention::{
AttentionHeadOutputs, AttentionMask, AttentionOutput, AttentionOutputProjection,
AttentionScores, AttentionSoftmax, AttentionWeights, ConcatenateHeads, HeadCount,
HeadDimension, HiddenSequence, HiddenToKey, HiddenToQuery, HiddenToValue, KeySequence,
LayerNormParameters, LayerNormalization, MaskedAttentionScores,
MaskedMultiHeadTransformerBlock, MultiHeadOutput, MultiHeadTransformerBlock,
NormalizationEpsilon, PositionWiseFeedForward, PositionalEncoding, ProjectedAttentionOutput,
QuerySequence, ResidualConnection, ScaledDotProductScores, SelfAttentionHead, SequenceLength,
SequenceLogits, SingleHeadTransformerBlock, TinyTransformerParameters,
TransformerBlockTrainStep, TransformerBlockTrainingExample, TransformerBlockTrainingSet,
TransformerFeedForwardTrainStep, TransformerFeedForwardTrainingExample,
TransformerFeedForwardTrainingSet, TransformerReadout, TransformerReadoutTrainStep,
TransformerReadoutTrainingExample, TransformerReadoutTrainingSet, TransformerTrainingState,
ValueSequence, WeightedValueMixing, transformer_block_average_loss,
transformer_feed_forward_average_loss, transformer_readout_average_loss,
};
pub use calculus::{LocalGradient, MulOp, Scalar};
pub use category::{
Compose, Endomorphism, Identity, Morphism, StepCount, apply_endomorphism_n_times,
};
pub use challenges::papers::adam::{
AdamConfig, AdamDecayRate, AdamEpsilon, AdamFirstMoment, AdamGradientVector, AdamModelState,
AdamOptimizerState, AdamParameterVector, AdamSecondMoment, AdamStepCount, AdamTrainStep,
AdamVectorDimension,
};
pub use challenges::typed_ai::{
loss_from_logits, require_target_in_distribution, token_index, uniform_distribution,
};
pub use demo::run_demo;
pub use domain::{
Distribution, LearningRate, Logits, Loss, ModelDimension, Parameters, Product, TokenId,
TokenSequence, TrainingExample, TrainingSet, Vector, VocabSize,
};
pub use error::{CtError, CtResult};
pub use ml::{
CrossEntropy, DatasetWindowing, DirectPredict, Embedding, LinearToLogits, Softmax,
average_loss, composed_prediction_matches_direct_prediction,
};
pub use sketches::{
CircuitComponent, CompanyInstance, DepartmentId, DesignRequirement, EmployeeId, EmployeeRecord,
FeasibilityRelation, FeatureCount, ImplementationOffer, InformationLevel, LatencyMs,
LayerBudget, LocalSafetyCheck, MatrixCols, MatrixRows, OpenCircuit, PortName, ResistanceOhms,
ResourceAmount, ResourceBundle, SafetyCover, SignalCoefficient, SignalMatrix, Throughput,
TimeInterval, TimeTick, TruthValue, abstract_to_layer_budget, concretize_layer_budget,
feature_layer_galois_law_holds, information_order_obeys_preorder_laws,
resource_tensor_is_monotone,
};
pub use structure::{
Functor, Monoid, NaturalTransformation, OptionFunctor, PipelineTrace, TraceStep, VecFunctor,
VecToFirstOption, monoid_laws_hold_for_pipeline_trace,
naturality_square_holds_for_first_option,
};
pub use training::TrainStep;
src/error.rs
use std::fmt::{self, Display};
/// Shared error type for the tutorial crate.
#[derive(Debug, Clone, PartialEq)]
pub enum CtError {
EmptyInput(&'static str),
OutOfRange {
kind: &'static str,
index: usize,
limit: usize,
},
ShapeMismatch {
op: &'static str,
expected: String,
got: String,
},
InvalidProbability(&'static str),
InvalidLoss(f32),
InvalidLearningRate(f32),
InvalidScalar {
kind: &'static str,
value: f32,
},
InvalidQuantity {
kind: &'static str,
value: i64,
},
InvalidInterval {
start: usize,
end: usize,
},
}
impl Display for CtError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
CtError::EmptyInput(context) => write!(f, "empty input in {context}"),
CtError::OutOfRange { kind, index, limit } => {
write!(f, "{kind} index {index} out of range; limit is {limit}")
}
CtError::ShapeMismatch { op, expected, got } => {
write!(f, "shape mismatch in {op}: expected {expected}, got {got}")
}
CtError::InvalidProbability(context) => {
write!(f, "invalid probability distribution in {context}")
}
CtError::InvalidLoss(value) => write!(f, "invalid loss value {value}"),
CtError::InvalidLearningRate(value) => write!(f, "invalid learning rate {value}"),
CtError::InvalidScalar { kind, value } => {
write!(f, "invalid {kind} scalar {value}")
}
CtError::InvalidQuantity { kind, value } => {
write!(f, "invalid {kind} quantity {value}")
}
CtError::InvalidInterval { start, end } => {
write!(f, "invalid interval: start {start} is after end {end}")
}
}
}
}
impl std::error::Error for CtError {}
pub type CtResult<T> = Result<T, CtError>;
src/domain.rs
use crate::error::{CtError, CtResult};
/// A vocabulary index. It is intentionally not a raw `usize` in public APIs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TokenId(usize);
impl TokenId {
pub fn new(index: usize) -> Self {
Self(index)
}
pub fn index(&self) -> usize {
self.0
}
}
impl From<usize> for TokenId {
fn from(value: usize) -> Self {
Self::new(value)
}
}
/// A sequence of tokens before it has been converted into training pairs.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TokenSequence(Vec<TokenId>);
impl TokenSequence {
pub fn new(tokens: impl IntoIterator<Item = TokenId>) -> CtResult<Self> {
let tokens = tokens.into_iter().collect::<Vec<_>>();
if tokens.is_empty() {
return Err(CtError::EmptyInput("token sequence"));
}
Ok(Self(tokens))
}
pub fn from_indices(indices: impl IntoIterator<Item = usize>) -> CtResult<Self> {
Self::new(indices.into_iter().map(TokenId::new))
}
pub fn as_slice(&self) -> &[TokenId] {
&self.0
}
}
/// A dense feature vector.
#[derive(Debug, Clone, PartialEq)]
pub struct Vector(Vec<f32>);
impl Vector {
pub fn new(values: Vec<f32>) -> Self {
Self(values)
}
pub fn as_slice(&self) -> &[f32] {
&self.0
}
}
/// Unnormalized model scores.
#[derive(Debug, Clone, PartialEq)]
pub struct Logits(Vec<f32>);
impl Logits {
pub fn new(values: Vec<f32>) -> Self {
Self(values)
}
pub fn as_slice(&self) -> &[f32] {
&self.0
}
}
/// A validated probability distribution.
#[derive(Debug, Clone, PartialEq)]
pub struct Distribution(Vec<f32>);
impl Distribution {
pub fn new(probabilities: Vec<f32>) -> CtResult<Self> {
if probabilities.is_empty() {
return Err(CtError::EmptyInput("distribution"));
}
let sum: f32 = probabilities.iter().sum();
let all_valid = probabilities
.iter()
.all(|probability| probability.is_finite() && *probability >= 0.0);
if !all_valid || !approx_eq(sum, 1.0, 1e-4) {
return Err(CtError::InvalidProbability("distribution constructor"));
}
Ok(Self(probabilities))
}
pub fn as_slice(&self) -> &[f32] {
&self.0
}
}
/// A non-negative scalar objective value.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Loss(f32);
impl Loss {
pub fn new(value: f32) -> CtResult<Self> {
if !value.is_finite() || value < 0.0 {
return Err(CtError::InvalidLoss(value));
}
Ok(Self(value))
}
pub fn value(&self) -> f32 {
self.0
}
}
/// Number of vocabulary entries.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct VocabSize(usize);
impl VocabSize {
pub fn new(value: usize) -> CtResult<Self> {
if value == 0 {
return Err(CtError::EmptyInput("vocabulary"));
}
Ok(Self(value))
}
pub fn value(&self) -> usize {
self.0
}
}
/// Width of each embedding vector.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ModelDimension(usize);
impl ModelDimension {
pub fn new(value: usize) -> CtResult<Self> {
if value == 0 {
return Err(CtError::EmptyInput("model dimension"));
}
Ok(Self(value))
}
pub fn value(&self) -> usize {
self.0
}
}
/// Positive optimizer step size.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct LearningRate(f32);
impl LearningRate {
pub fn new(value: f32) -> CtResult<Self> {
if !value.is_finite() || value <= 0.0 {
return Err(CtError::InvalidLearningRate(value));
}
Ok(Self(value))
}
pub fn value(&self) -> f32 {
self.0
}
}
/// Categorical product object: `A x B`.
#[derive(Debug, Clone, PartialEq)]
pub struct Product<A, B> {
first: A,
second: B,
}
impl<A, B> Product<A, B> {
pub fn new(first: A, second: B) -> Self {
Self { first, second }
}
pub fn first(&self) -> &A {
&self.first
}
pub fn second(&self) -> &B {
&self.second
}
pub fn into_parts(self) -> (A, B) {
(self.first, self.second)
}
}
pub type TrainingExample = Product<TokenId, TokenId>;
/// Non-empty next-token training pairs.
#[derive(Debug, Clone, PartialEq)]
pub struct TrainingSet(Vec<TrainingExample>);
impl TrainingSet {
pub fn new(examples: impl IntoIterator<Item = TrainingExample>) -> CtResult<Self> {
let examples = examples.into_iter().collect::<Vec<_>>();
if examples.is_empty() {
return Err(CtError::EmptyInput("training set"));
}
Ok(Self(examples))
}
pub fn examples(&self) -> &[TrainingExample] {
&self.0
}
pub fn len(&self) -> usize {
self.0.len()
}
pub fn is_empty(&self) -> bool {
self.0.is_empty()
}
}
/// Tiny model parameters for an embedding plus language-model head.
#[derive(Debug, Clone, PartialEq)]
pub struct Parameters {
pub(crate) embedding: Vec<Vec<f32>>,
pub(crate) lm_head: Vec<Vec<f32>>,
pub(crate) bias: Vec<f32>,
}
impl Parameters {
pub fn init(vocab_size: VocabSize, d_model: ModelDimension) -> Self {
let vocab_size = vocab_size.value();
let d_model = d_model.value();
Self {
embedding: init_matrix(vocab_size, d_model, 0.2, 1),
lm_head: init_matrix(d_model, vocab_size, 0.2, 2),
bias: vec![0.0; vocab_size],
}
}
pub fn vocab_size(&self) -> usize {
self.bias.len()
}
pub fn d_model(&self) -> usize {
self.embedding.first().map_or(0, Vec::len)
}
pub fn embedding_table(&self) -> &[Vec<f32>] {
&self.embedding
}
pub fn lm_head(&self) -> &[Vec<f32>] {
&self.lm_head
}
pub fn bias(&self) -> &[f32] {
&self.bias
}
}
pub(crate) fn init_matrix(rows: usize, cols: usize, scale: f32, seed: usize) -> Vec<Vec<f32>> {
let mut out = vec![vec![0.0; cols]; rows];
for (row_index, row) in out.iter_mut().enumerate() {
for (col_index, value) in row.iter_mut().enumerate() {
let raw = ((row_index * cols + col_index) * 37 + seed * 101) % 1000;
let unit = raw as f32 / 1000.0;
*value = (unit - 0.5) * scale;
}
}
out
}
pub(crate) fn approx_eq(a: f32, b: f32, eps: f32) -> bool {
(a - b).abs() <= eps
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn distribution_rejects_non_normalized_values() {
let result = Distribution::new(vec![0.4, 0.4]);
assert!(matches!(result, Err(CtError::InvalidProbability(_))));
}
#[test]
fn token_sequence_rejects_empty_input() {
let result = TokenSequence::new(vec![]);
assert!(matches!(result, Err(CtError::EmptyInput("token sequence"))));
}
}
src/category.rs
use std::marker::PhantomData;
use crate::error::CtResult;
/// A typed category-theory arrow: `Input -> Output`.
pub trait Morphism<Input, Output> {
fn name(&self) -> &'static str;
fn apply(&self, input: Input) -> CtResult<Output>;
}
/// Identity morphism: `id_A : A -> A`.
#[derive(Debug, Clone, Copy)]
pub struct Identity<T> {
_marker: PhantomData<T>,
}
impl<T> Identity<T> {
pub fn new() -> Self {
Self {
_marker: PhantomData,
}
}
}
impl<T> Default for Identity<T> {
fn default() -> Self {
Self::new()
}
}
impl<T> Morphism<T, T> for Identity<T> {
fn name(&self) -> &'static str {
"identity"
}
fn apply(&self, input: T) -> CtResult<T> {
Ok(input)
}
}
/// Composition of two morphisms: if `f : A -> B` and `g : B -> C`, this is
/// `g after f : A -> C`.
#[derive(Debug, Clone)]
pub struct Compose<F, G, Middle> {
first: F,
second: G,
_middle: PhantomData<Middle>,
}
impl<F, G, Middle> Compose<F, G, Middle> {
pub fn new(first: F, second: G) -> Self {
Self {
first,
second,
_middle: PhantomData,
}
}
}
impl<Input, Middle, Output, F, G> Morphism<Input, Output> for Compose<F, G, Middle>
where
F: Morphism<Input, Middle>,
G: Morphism<Middle, Output>,
{
fn name(&self) -> &'static str {
"composition"
}
fn apply(&self, input: Input) -> CtResult<Output> {
let middle = self.first.apply(input)?;
self.second.apply(middle)
}
}
/// Endomorphism: a morphism from a type back to itself.
pub trait Endomorphism<T>: Morphism<T, T> {}
impl<T, M> Endomorphism<T> for M where M: Morphism<T, T> {}
/// How many times to repeat an endomorphism.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StepCount(usize);
impl StepCount {
pub fn new(value: usize) -> Self {
Self(value)
}
pub fn value(&self) -> usize {
self.0
}
}
/// Repeatedly apply an endomorphism: `A0 -> A1 -> ... -> An`.
pub fn apply_endomorphism_n_times<T, E>(endo: &E, mut value: T, count: StepCount) -> CtResult<T>
where
E: Endomorphism<T>,
{
for _ in 0..count.value() {
value = endo.apply(value)?;
}
Ok(value)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::error::CtError;
#[derive(Debug, Clone, Copy)]
struct AddOne;
impl Morphism<i32, i32> for AddOne {
fn name(&self) -> &'static str {
"add_one"
}
fn apply(&self, input: i32) -> CtResult<i32> {
Ok(input + 1)
}
}
#[derive(Debug, Clone, Copy)]
struct Double;
impl Morphism<i32, i32> for Double {
fn name(&self) -> &'static str {
"double"
}
fn apply(&self, input: i32) -> CtResult<i32> {
Ok(input * 2)
}
}
#[derive(Debug, Clone, Copy)]
struct Fail;
impl Morphism<i32, i32> for Fail {
fn name(&self) -> &'static str {
"fail"
}
fn apply(&self, _input: i32) -> CtResult<i32> {
Err(CtError::InvalidQuantity {
kind: "test morphism",
value: -1,
})
}
}
#[test]
fn identity_returns_the_same_value() -> CtResult<()> {
let value = String::from("same");
assert_eq!(Identity::<String>::new().apply(value.clone())?, value);
Ok(())
}
#[test]
fn identity_composes_without_changing_behavior() -> CtResult<()> {
let left_identity = Compose::<_, _, i32>::new(Identity::<i32>::new(), AddOne);
let right_identity = Compose::<_, _, i32>::new(AddOne, Identity::<i32>::new());
assert_eq!(left_identity.apply(41)?, AddOne.apply(41)?);
assert_eq!(right_identity.apply(41)?, AddOne.apply(41)?);
Ok(())
}
#[test]
fn composition_applies_first_then_second() -> CtResult<()> {
let add_then_double = Compose::<_, _, i32>::new(AddOne, Double);
assert_eq!(add_then_double.apply(4)?, 10);
Ok(())
}
#[test]
fn composition_returns_the_first_error() {
let composed = Compose::<_, _, i32>::new(Fail, AddOne);
assert!(matches!(
composed.apply(4),
Err(CtError::InvalidQuantity {
kind: "test morphism",
value: -1,
})
));
}
}
src/ml.rs
use crate::category::{Compose, Morphism};
use crate::domain::{
Distribution, Logits, Loss, Parameters, Product, TokenId, TokenSequence, TrainingSet, Vector,
approx_eq,
};
use crate::error::{CtError, CtResult};
/// Turns adjacent tokens into next-token training examples.
#[derive(Debug, Clone)]
pub struct DatasetWindowing;
impl Morphism<TokenSequence, TrainingSet> for DatasetWindowing {
fn name(&self) -> &'static str {
"dataset_windowing"
}
fn apply(&self, tokens: TokenSequence) -> CtResult<TrainingSet> {
if tokens.as_slice().len() < 2 {
return Err(CtError::EmptyInput(
"dataset windowing requires at least 2 tokens",
));
}
TrainingSet::new(
tokens
.as_slice()
.windows(2)
.map(|pair| Product::new(pair[0], pair[1])),
)
}
}
/// Morphism from token id to embedding vector.
#[derive(Debug, Clone)]
pub struct Embedding {
table: Vec<Vec<f32>>,
}
impl Embedding {
pub fn from_parameters(params: &Parameters) -> Self {
Self {
table: params.embedding_table().to_vec(),
}
}
}
impl Morphism<TokenId, Vector> for Embedding {
fn name(&self) -> &'static str {
"embedding"
}
fn apply(&self, token: TokenId) -> CtResult<Vector> {
let Some(row) = self.table.get(token.index()) else {
return Err(CtError::OutOfRange {
kind: "token",
index: token.index(),
limit: self.table.len(),
});
};
Ok(Vector::new(row.clone()))
}
}
/// Linear projection from hidden vector to vocabulary logits.
#[derive(Debug, Clone)]
pub struct LinearToLogits {
weight: Vec<Vec<f32>>,
bias: Vec<f32>,
}
impl LinearToLogits {
pub fn from_parameters(params: &Parameters) -> Self {
Self {
weight: params.lm_head().to_vec(),
bias: params.bias().to_vec(),
}
}
pub(crate) fn from_parts(weight: Vec<Vec<f32>>, bias: Vec<f32>) -> Self {
Self { weight, bias }
}
}
impl Morphism<Vector, Logits> for LinearToLogits {
fn name(&self) -> &'static str {
"linear_to_logits"
}
fn apply(&self, input: Vector) -> CtResult<Logits> {
let d_model = input.as_slice().len();
let vocab_size = self.bias.len();
if self.weight.len() != d_model {
return Err(CtError::ShapeMismatch {
op: "linear layer",
expected: format!("weight rows == input dim {d_model}"),
got: format!("weight rows {}", self.weight.len()),
});
}
let mut out = self.bias.clone();
for (feature, input_value) in input.as_slice().iter().enumerate() {
if self.weight[feature].len() != vocab_size {
return Err(CtError::ShapeMismatch {
op: "linear layer",
expected: format!("weight cols == vocab size {vocab_size}"),
got: format!("weight cols {}", self.weight[feature].len()),
});
}
for (vocab_id, output_value) in out.iter_mut().enumerate() {
*output_value += input_value * self.weight[feature][vocab_id];
}
}
Ok(Logits::new(out))
}
}
/// Converts logits to a probability distribution.
#[derive(Debug, Clone)]
pub struct Softmax;
impl Morphism<Logits, Distribution> for Softmax {
fn name(&self) -> &'static str {
"softmax"
}
fn apply(&self, logits: Logits) -> CtResult<Distribution> {
if logits.as_slice().is_empty() {
return Err(CtError::EmptyInput("softmax"));
}
let max_value = logits
.as_slice()
.iter()
.copied()
.fold(f32::NEG_INFINITY, f32::max);
let mut exps = Vec::with_capacity(logits.as_slice().len());
let mut sum = 0.0;
for value in logits.as_slice() {
let exp = (*value - max_value).exp();
exps.push(exp);
sum += exp;
}
if sum <= 0.0 || !sum.is_finite() {
return Err(CtError::InvalidProbability("softmax"));
}
Distribution::new(exps.into_iter().map(|value| value / sum).collect())
}
}
/// Negative log likelihood for `(distribution, target_token)`.
#[derive(Debug, Clone)]
pub struct CrossEntropy;
impl Morphism<Product<Distribution, TokenId>, Loss> for CrossEntropy {
fn name(&self) -> &'static str {
"cross_entropy"
}
fn apply(&self, input: Product<Distribution, TokenId>) -> CtResult<Loss> {
let (distribution, target) = input.into_parts();
let Some(probability) = distribution.as_slice().get(target.index()).copied() else {
return Err(CtError::OutOfRange {
kind: "target",
index: target.index(),
limit: distribution.as_slice().len(),
});
};
Loss::new(-probability.max(1e-9).ln())
}
}
/// Direct path used for a commutative-diagram check.
#[derive(Debug, Clone)]
pub struct DirectPredict {
params: Parameters,
}
impl DirectPredict {
pub fn new(params: Parameters) -> Self {
Self { params }
}
}
impl Morphism<TokenId, Distribution> for DirectPredict {
fn name(&self) -> &'static str {
"direct_predict"
}
fn apply(&self, token: TokenId) -> CtResult<Distribution> {
let embedding = Embedding::from_parameters(&self.params).apply(token)?;
let logits = LinearToLogits::from_parameters(&self.params).apply(embedding)?;
Softmax.apply(logits)
}
}
/// Average cross-entropy over a training set.
pub fn average_loss(params: &Parameters, dataset: &TrainingSet) -> CtResult<Loss> {
let embedding = Embedding::from_parameters(params);
let linear = LinearToLogits::from_parameters(params);
let predict = Compose::<_, _, Vector>::new(embedding, linear);
let predict = Compose::<_, _, Logits>::new(predict, Softmax);
let loss_fn = CrossEntropy;
let mut total = 0.0;
for example in dataset.examples() {
let distribution = predict.apply(*example.first())?;
let loss = loss_fn.apply(Product::new(distribution, *example.second()))?;
total += loss.value();
}
Loss::new(total / dataset.len() as f32)
}
/// Verifies that the composed path and direct path produce the same result.
pub fn composed_prediction_matches_direct_prediction(params: &Parameters) -> CtResult<bool> {
let token = TokenId::new(1);
let composed = Compose::<_, _, Vector>::new(
Embedding::from_parameters(params),
LinearToLogits::from_parameters(params),
);
let composed = Compose::<_, _, Logits>::new(composed, Softmax);
let direct = DirectPredict::new(params.clone());
let left_path = composed.apply(token)?;
let right_path = direct.apply(token)?;
Ok(left_path
.as_slice()
.iter()
.zip(right_path.as_slice().iter())
.all(|(a, b)| approx_eq(*a, *b, 1e-6)))
}
#[cfg(test)]
mod tests {
use super::*;
use crate::domain::{ModelDimension, VocabSize};
#[test]
fn dataset_windowing_builds_adjacent_pairs() -> CtResult<()> {
let tokens = TokenSequence::from_indices([1, 2, 3])?;
let dataset = DatasetWindowing.apply(tokens)?;
assert_eq!(dataset.len(), 2);
assert_eq!(dataset.examples()[0].first().index(), 1);
assert_eq!(dataset.examples()[0].second().index(), 2);
Ok(())
}
#[test]
fn composed_and_direct_prediction_match() -> CtResult<()> {
let params = Parameters::init(VocabSize::new(5)?, ModelDimension::new(4)?);
assert!(composed_prediction_matches_direct_prediction(&params)?);
Ok(())
}
#[test]
fn softmax_normalizes_logits_into_distribution() -> CtResult<()> {
let distribution = Softmax.apply(Logits::new(vec![1.0, 2.0, 3.0]))?;
let probabilities = distribution.as_slice();
let sum: f32 = probabilities.iter().sum();
assert!(approx_eq(sum, 1.0, 1e-6));
assert!(probabilities[2] > probabilities[1]);
assert!(probabilities[1] > probabilities[0]);
Ok(())
}
#[test]
fn cross_entropy_is_lower_for_more_confident_target_probability() -> CtResult<()> {
let confident = Distribution::new(vec![0.9, 0.1])?;
let surprised = Distribution::new(vec![0.1, 0.9])?;
let confident_loss = CrossEntropy.apply(Product::new(confident, TokenId::new(0)))?;
let surprised_loss = CrossEntropy.apply(Product::new(surprised, TokenId::new(0)))?;
assert!(confident_loss.value() < surprised_loss.value());
Ok(())
}
}
src/training.rs
use crate::category::Morphism;
use crate::domain::{LearningRate, Parameters, TrainingSet, Vector};
use crate::error::{CtError, CtResult};
use crate::ml::{LinearToLogits, Softmax};
/// One full-batch optimizer update.
///
/// Categorically, this is an endomorphism:
///
/// `Parameters -> Parameters`
#[derive(Debug, Clone)]
pub struct TrainStep {
dataset: TrainingSet,
learning_rate: LearningRate,
}
impl TrainStep {
pub fn new(dataset: TrainingSet, learning_rate: LearningRate) -> Self {
Self {
dataset,
learning_rate,
}
}
}
impl Morphism<Parameters, Parameters> for TrainStep {
fn name(&self) -> &'static str {
"train_step_endomorphism"
}
fn apply(&self, params: Parameters) -> CtResult<Parameters> {
let vocab_size = params.vocab_size();
let d_model = params.d_model();
if vocab_size == 0 || d_model == 0 {
return Err(CtError::EmptyInput("parameters"));
}
let mut grad_embedding = vec![vec![0.0; d_model]; params.embedding.len()];
let mut grad_lm_head = vec![vec![0.0; vocab_size]; d_model];
let mut grad_bias = vec![0.0; vocab_size];
for example in self.dataset.examples() {
let input_id = example.first().index();
let target_id = example.second().index();
if input_id >= params.embedding.len() {
return Err(CtError::OutOfRange {
kind: "input token",
index: input_id,
limit: params.embedding.len(),
});
}
if target_id >= vocab_size {
return Err(CtError::OutOfRange {
kind: "target token",
index: target_id,
limit: vocab_size,
});
}
let x = &params.embedding[input_id];
let logits = LinearToLogits::from_parts(params.lm_head.clone(), params.bias.clone())
.apply(Vector::new(x.clone()))?;
let probs = Softmax.apply(logits)?;
let mut dlogits = probs.as_slice().to_vec();
dlogits[target_id] -= 1.0;
for (vocab_id, dlogit) in dlogits.iter().copied().enumerate() {
grad_bias[vocab_id] += dlogit;
for (feature, x_feature) in x.iter().copied().enumerate() {
grad_lm_head[feature][vocab_id] += x_feature * dlogit;
}
}
for (feature, grad_feature) in grad_embedding[input_id].iter_mut().enumerate() {
let dx = params.lm_head[feature]
.iter()
.zip(dlogits.iter())
.map(|(weight, dlogit)| weight * dlogit)
.sum::<f32>();
*grad_feature += dx;
}
}
let batch_scale = 1.0 / self.dataset.len() as f32;
let learning_rate = self.learning_rate.value();
let mut updated = params.clone();
for (row, grad_row) in updated.embedding.iter_mut().zip(grad_embedding.iter()) {
for (value, grad) in row.iter_mut().zip(grad_row.iter()) {
*value -= learning_rate * grad * batch_scale;
}
}
for (row, grad_row) in updated.lm_head.iter_mut().zip(grad_lm_head.iter()) {
for (value, grad) in row.iter_mut().zip(grad_row.iter()) {
*value -= learning_rate * grad * batch_scale;
}
}
for (bias, grad) in updated.bias.iter_mut().zip(grad_bias.iter()) {
*bias -= learning_rate * grad * batch_scale;
}
Ok(updated)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::category::{StepCount, apply_endomorphism_n_times};
use crate::domain::{ModelDimension, Product, TokenId, TokenSequence, TrainingSet, VocabSize};
use crate::ml::{DatasetWindowing, average_loss};
#[test]
fn repeated_training_step_reduces_loss() -> CtResult<()> {
let tokens = TokenSequence::from_indices([1, 2, 3, 4, 1, 2, 3, 4])?;
let dataset = DatasetWindowing.apply(tokens)?;
let params = Parameters::init(VocabSize::new(5)?, ModelDimension::new(4)?);
let before = average_loss(&params, &dataset)?;
let train_step = TrainStep::new(dataset.clone(), LearningRate::new(1.0)?);
let trained = apply_endomorphism_n_times(&train_step, params, StepCount::new(80))?;
let after = average_loss(&trained, &dataset)?;
assert!(after.value() < before.value());
Ok(())
}
#[test]
fn one_training_step_preserves_parameter_shape() -> CtResult<()> {
let tokens = TokenSequence::from_indices([1, 2, 3, 4])?;
let dataset = DatasetWindowing.apply(tokens)?;
let params = Parameters::init(VocabSize::new(5)?, ModelDimension::new(4)?);
let train_step = TrainStep::new(dataset, LearningRate::new(0.1)?);
let trained = train_step.apply(params.clone())?;
assert_eq!(trained.vocab_size(), params.vocab_size());
assert_eq!(trained.d_model(), params.d_model());
Ok(())
}
#[test]
fn training_rejects_target_outside_vocabulary() -> CtResult<()> {
let dataset = TrainingSet::new([Product::new(TokenId::new(0), TokenId::new(9))])?;
let params = Parameters::init(VocabSize::new(2)?, ModelDimension::new(2)?);
let train_step = TrainStep::new(dataset, LearningRate::new(0.1)?);
assert!(matches!(
train_step.apply(params),
Err(CtError::OutOfRange {
kind: "target token",
index: 9,
limit: 2,
})
));
Ok(())
}
}
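The line `dlogits[target_id] -= 1.0` in `TrainStep` encodes a standard identity. Let p = softmax(z) and L = -ln p[t]. Then dL/dz_k = p_k - [k = t], which is the probabilities minus a one-hot target. `TrainStep` then pushes this vector back through `lm_head` to obtain the embedding-row gradient. Below is a standalone sketch that checks the identity against central finite differences. It uses `f64` for accuracy, and every function name in it is illustrative rather than part of the crate.

```rust
fn softmax(logits: &[f64]) -> Vec<f64> {
    let max = logits.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.iter().map(|&z| (z - max).exp()).collect();
    let total: f64 = exps.iter().sum();
    exps.iter().map(|&e| e / total).collect()
}

/// Cross-entropy loss of the logits against one target index.
fn loss(logits: &[f64], target: usize) -> f64 {
    -softmax(logits)[target].ln()
}

/// Analytic gradient, as in `TrainStep`: probabilities minus a one-hot target.
fn analytic_grad(logits: &[f64], target: usize) -> Vec<f64> {
    let mut grad = softmax(logits);
    grad[target] -= 1.0;
    grad
}

/// Central differences: (L(z + h e_k) - L(z - h e_k)) / (2h) for each k.
fn numeric_grad(logits: &[f64], target: usize, h: f64) -> Vec<f64> {
    (0..logits.len())
        .map(|k| {
            let mut plus = logits.to_vec();
            let mut minus = logits.to_vec();
            plus[k] += h;
            minus[k] -= h;
            (loss(&plus, target) - loss(&minus, target)) / (2.0 * h)
        })
        .collect()
}

fn main() {
    let logits = [0.5, -1.0, 2.0, 0.1];
    let target = 1;
    let analytic = analytic_grad(&logits, target);
    let numeric = numeric_grad(&logits, target, 1e-5);
    for (a, n) in analytic.iter().zip(&numeric) {
        assert!((a - n).abs() < 1e-6, "analytic {a} vs numeric {n}");
    }
    // The entries sum to zero because the probabilities sum to one.
    assert!(analytic.iter().sum::<f64>().abs() < 1e-12);
    println!("analytic = {analytic:?}");
}
```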
src/structure.rs
/// A minimal functor interface for this tutorial.
pub trait Functor<A, B> {
type WrappedA;
type WrappedB;
fn fmap<F>(wrapped: Self::WrappedA, f: F) -> Self::WrappedB
where
F: Fn(A) -> B;
}
pub struct VecFunctor;
impl<A, B> Functor<A, B> for VecFunctor {
type WrappedA = Vec<A>;
type WrappedB = Vec<B>;
fn fmap<F>(wrapped: Vec<A>, f: F) -> Vec<B>
where
F: Fn(A) -> B,
{
wrapped.into_iter().map(f).collect()
}
}
pub struct OptionFunctor;
impl<A, B> Functor<A, B> for OptionFunctor {
type WrappedA = Option<A>;
type WrappedB = Option<B>;
fn fmap<F>(wrapped: Option<A>, f: F) -> Option<B>
where
F: Fn(A) -> B,
{
wrapped.map(f)
}
}
/// A structure-preserving conversion between wrappers.
pub trait NaturalTransformation<A> {
type From;
type To;
fn transform(from: Self::From) -> Self::To;
}
/// Natural transformation from `Vec<A>` to `Option<A>` by taking the first item.
pub struct VecToFirstOption;
impl<A> NaturalTransformation<A> for VecToFirstOption {
type From = Vec<A>;
type To = Option<A>;
fn transform(from: Vec<A>) -> Option<A> {
from.into_iter().next()
}
}
pub fn naturality_square_holds_for_first_option() -> bool {
let xs = vec![1, 2, 3];
let f = |x| x * 10;
let path_top_then_right = VecToFirstOption::transform(VecFunctor::fmap(xs.clone(), f));
let path_left_then_bottom = OptionFunctor::fmap(VecToFirstOption::transform(xs), f);
path_top_then_right == path_left_then_bottom
}
pub trait Monoid: Sized {
fn empty() -> Self;
fn combine(&self, other: &Self) -> Self;
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TraceStep(&'static str);
impl TraceStep {
pub fn new(name: &'static str) -> Self {
Self(name)
}
pub fn name(&self) -> &'static str {
self.0
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PipelineTrace(Vec<TraceStep>);
impl PipelineTrace {
pub fn from_steps(steps: impl IntoIterator<Item = TraceStep>) -> Self {
Self(steps.into_iter().collect())
}
pub fn names(&self) -> Vec<&'static str> {
self.0.iter().map(TraceStep::name).collect()
}
}
impl Monoid for PipelineTrace {
fn empty() -> Self {
PipelineTrace(vec![])
}
fn combine(&self, other: &Self) -> Self {
let mut combined = self.0.clone();
combined.extend_from_slice(&other.0);
PipelineTrace(combined)
}
}
pub fn monoid_laws_hold_for_pipeline_trace() -> bool {
let a = PipelineTrace::from_steps(vec![TraceStep::new("embedding")]);
let b = PipelineTrace::from_steps(vec![TraceStep::new("linear")]);
let c = PipelineTrace::from_steps(vec![TraceStep::new("softmax")]);
let identity = PipelineTrace::empty();
let left_identity = identity.combine(&a) == a;
let right_identity = a.combine(&identity) == a;
let associativity = a.combine(&b).combine(&c) == a.combine(&b.combine(&c));
left_identity && right_identity && associativity
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn vec_functor_preserves_identity_for_values() {
let values = vec![1, 2, 3];
assert_eq!(VecFunctor::fmap(values.clone(), |value| value), values);
}
#[test]
fn vec_functor_preserves_composition_for_values() {
let values = vec![1, 2, 3];
let add_one = |value| value + 1;
let double = |value| value * 2;
let map_then_map = VecFunctor::fmap(VecFunctor::fmap(values.clone(), add_one), double);
let map_composed = VecFunctor::fmap(values, |value| double(add_one(value)));
assert_eq!(map_then_map, map_composed);
}
#[test]
fn option_functor_preserves_absence() {
let missing = OptionFunctor::fmap(None::<i32>, |value| value * 10);
assert_eq!(missing, None);
}
#[test]
fn naturality_square_commutes() {
assert!(naturality_square_holds_for_first_option());
}
#[test]
fn pipeline_trace_obeys_monoid_laws() {
assert!(monoid_laws_hold_for_pipeline_trace());
}
}
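`VecToFirstOption` is natural because taking the first element commutes with mapping. For every function f, first(map f xs) equals map f (first xs), and this holds for the empty vector too. The crate checks one instance of the square with `naturality_square_holds_for_first_option`. Below is a standalone sketch that checks the square over several inputs, uses an f that changes the element type, and adds a second, hypothetical "keep the last item" transformation. Neither `first`, `last`, nor `square_commutes` comes from `src/structure.rs`.

```rust
/// Vec => Option component that keeps the first item (mirrors `VecToFirstOption`).
fn first<A>(xs: Vec<A>) -> Option<A> {
    xs.into_iter().next()
}

/// Hypothetical second Vec => Option transformation that keeps the last item.
fn last<A>(xs: Vec<A>) -> Option<A> {
    xs.into_iter().last()
}

/// True when eta_B(map f xs) == map f (eta_A xs), i.e. the naturality square commutes.
fn square_commutes<A: Clone, B: PartialEq>(
    eta_a: fn(Vec<A>) -> Option<A>,
    eta_b: fn(Vec<B>) -> Option<B>,
    xs: Vec<A>,
    f: impl Fn(A) -> B,
) -> bool {
    // Top then right: map over the Vec, then transform.
    let top_then_right = eta_b(xs.clone().into_iter().map(&f).collect());
    // Left then bottom: transform first, then map over the Option.
    let left_then_bottom = eta_a(xs).map(&f);
    top_then_right == left_then_bottom
}

fn main() {
    // f changes the element type from i32 to String.
    let to_label = |x: i32| format!("token-{x}");
    for xs in [vec![], vec![7], vec![1, 2, 3]] {
        assert!(square_commutes(first::<i32>, first::<String>, xs.clone(), to_label));
        assert!(square_commutes(last::<i32>, last::<String>, xs, to_label));
    }
    println!("naturality squares commute for first and last");
}
```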
src/calculus.rs
use crate::error::{CtError, CtResult};
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Scalar(f32);
impl Scalar {
pub fn new(value: f32) -> CtResult<Self> {
if !value.is_finite() {
return Err(CtError::InvalidLoss(value));
}
Ok(Self(value))
}
pub fn value(&self) -> f32 {
self.0
}
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct LocalGradient(f32);
impl LocalGradient {
pub fn new(value: f32) -> CtResult<Self> {
if !value.is_finite() {
return Err(CtError::InvalidLoss(value));
}
Ok(Self(value))
}
pub fn value(&self) -> f32 {
self.0
}
}
#[derive(Debug, Clone, Copy)]
pub struct MulOp;
impl MulOp {
pub fn forward(&self, x: Scalar, y: Scalar) -> CtResult<Scalar> {
Scalar::new(x.value() * y.value())
}
/// Given upstream gradient dL/dz, return `(dL/dx, dL/dy)` for `z = x * y`.
pub fn backward(
&self,
x: Scalar,
y: Scalar,
upstream: LocalGradient,
) -> CtResult<(LocalGradient, LocalGradient)> {
let dz_dx = y.value();
let dz_dy = x.value();
Ok((
LocalGradient::new(upstream.value() * dz_dx)?,
LocalGradient::new(upstream.value() * dz_dy)?,
))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn multiply_backward_returns_local_chain_rule_gradients() -> CtResult<()> {
let mul = MulOp;
let x = Scalar::new(2.0)?;
let y = Scalar::new(3.0)?;
let upstream = LocalGradient::new(1.0)?;
let (dl_dx, dl_dy) = mul.backward(x, y, upstream)?;
assert_eq!(dl_dx.value(), 3.0);
assert_eq!(dl_dy.value(), 2.0);
Ok(())
}
#[test]
fn multiply_backward_scales_with_upstream_gradient() -> CtResult<()> {
let mul = MulOp;
let x = Scalar::new(2.0)?;
let y = Scalar::new(3.0)?;
let upstream = LocalGradient::new(4.0)?;
let (dl_dx, dl_dy) = mul.backward(x, y, upstream)?;
assert_eq!(dl_dx.value(), 12.0);
assert_eq!(dl_dy.value(), 8.0);
Ok(())
}
}
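The `src/attention.rs` listing below gives each single-head attention stage its own typed boundary:

1. scaled scores q·k / sqrt(d),
2. masking with a large negative constant,
3. a row-wise softmax,
4. a probability-weighted sum of value vectors.

Below is a minimal sketch of the same arithmetic on plain `Vec<Vec<f32>>` values, assuming well-shaped inputs and skipping all validation. The `attend` helper and the causal mask built in `main` are illustrative assumptions, not crate API. `MASKED_SCORE` reuses the constant value from the listing.

```rust
const MASKED_SCORE: f32 = -1_000_000.0;

fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

fn softmax(row: &[f32]) -> Vec<f32> {
    let max = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = row.iter().map(|&z| (z - max).exp()).collect();
    let total: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / total).collect()
}

/// Single-head attention; returns (weights, outputs), one row per query.
/// Assumes non-empty, consistently shaped inputs (no validation here).
fn attend(
    queries: &[Vec<f32>],
    keys: &[Vec<f32>],
    values: &[Vec<f32>],
    mask: &[Vec<bool>],
) -> (Vec<Vec<f32>>, Vec<Vec<f32>>) {
    let scale = (queries[0].len() as f32).sqrt();
    // Scale, then mask, then normalize each query row.
    let weights: Vec<Vec<f32>> = queries
        .iter()
        .zip(mask)
        .map(|(q, allowed_row)| {
            let scores: Vec<f32> = keys
                .iter()
                .zip(allowed_row)
                .map(|(k, &allowed)| if allowed { dot(q, k) / scale } else { MASKED_SCORE })
                .collect();
            softmax(&scores)
        })
        .collect();
    // Mix value vectors with each row of attention probabilities.
    let value_dim = values[0].len();
    let outputs = weights
        .iter()
        .map(|w| {
            let mut out = vec![0.0; value_dim];
            for (p, v) in w.iter().zip(values) {
                for (o, x) in out.iter_mut().zip(v) {
                    *o += p * x;
                }
            }
            out
        })
        .collect();
    (weights, outputs)
}

fn main() {
    let q = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];
    let k = q.clone();
    let v = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![0.5, 0.5]];
    // Causal mask: query i may attend only to keys 0..=i.
    let mask: Vec<Vec<bool>> = (0..3).map(|i| (0..3).map(|j| j <= i).collect()).collect();
    let (weights, outputs) = attend(&q, &k, &v, &mask);
    // The first query sees only key 0, so its output is exactly value 0.
    assert!((weights[0][0] - 1.0).abs() < 1e-6);
    assert!((outputs[0][0] - 1.0).abs() < 1e-6 && outputs[0][1].abs() < 1e-6);
    for row in &weights {
        assert!((row.iter().sum::<f32>() - 1.0).abs() < 1e-5);
    }
    println!("weights = {weights:?}\noutputs = {outputs:?}");
}
```

Masking after scaling follows the order of `ScaledDotProductScores` then `MaskedAttentionScores` in the listing. Because the replacement score is very negative, its softmax weight underflows to zero in `f32`.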
src/attention.rs
//! Tiny typed attention boundary for the Transformer roadmap.
//!
//! This module does not implement a full Transformer. It makes the first small
//! attention-specific shapes explicit:
//!
//! - projected queries and keys become query-by-key scores,
//! - masks replace disallowed score positions with a large negative score, so they get negligible softmax weight,
//! - query-by-key scores become row-wise attention probabilities,
//! - attention probabilities mix value vectors into output vectors,
//! - multiple head outputs concatenate into a multi-head output,
//! - the concatenated heads project back into a hidden sequence width,
//! - residual addition preserves the hidden sequence boundary,
//! - layer normalization preserves the hidden sequence boundary,
//! - a position-wise feed-forward map preserves the hidden sequence boundary,
//! - positional encoding adds position information while preserving shape,
//! - a single-head block sketch composes those boundaries end to end,
//! - a masked multi-head block accepts attention masks at the block boundary,
//! - a structured parameter object owns position, block, and readout pieces,
//! - a training-state object owns parameters, learning rate, and step count,
//! - a composed block train step updates readout, feed-forward,
//! normalization, and attention projection parameters from sequence targets.
use crate::category::{Morphism, StepCount};
use crate::domain::{
Distribution, LearningRate, Logits, Loss, ModelDimension, Product, TokenSequence, Vector,
VocabSize,
};
use crate::error::{CtError, CtResult};
use crate::ml::Softmax;
const MASKED_SCORE: f32 = -1_000_000.0;
/// Number of positions in a sequence.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SequenceLength(usize);
impl SequenceLength {
pub fn new(value: usize) -> CtResult<Self> {
if value == 0 {
return Err(CtError::EmptyInput("sequence length"));
}
Ok(Self(value))
}
pub fn value(&self) -> usize {
self.0
}
}
/// Number of parallel attention heads.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HeadCount(usize);
impl HeadCount {
pub fn new(value: usize) -> CtResult<Self> {
if value == 0 {
return Err(CtError::EmptyInput("head count"));
}
Ok(Self(value))
}
pub fn value(&self) -> usize {
self.0
}
}
/// Width of one attention head.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HeadDimension(usize);
impl HeadDimension {
pub fn new(value: usize) -> CtResult<Self> {
if value == 0 {
return Err(CtError::EmptyInput("head dimension"));
}
Ok(Self(value))
}
pub fn value(&self) -> usize {
self.0
}
}
/// Positive stabilizer used in layer normalization.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct NormalizationEpsilon(f32);
impl NormalizationEpsilon {
pub fn new(value: f32) -> CtResult<Self> {
if !value.is_finite() || value <= 0.0 {
return Err(CtError::ShapeMismatch {
op: "normalization epsilon",
expected: "positive finite epsilon".to_string(),
got: format!("epsilon {value}"),
});
}
Ok(Self(value))
}
pub fn value(&self) -> f32 {
self.0
}
}
/// Projected query vectors for one attention head.
#[derive(Debug, Clone, PartialEq)]
pub struct QuerySequence {
sequence_len: SequenceLength,
head_dimension: HeadDimension,
rows: Vec<Vector>,
}
impl QuerySequence {
pub fn new(rows: Vec<Vec<f32>>) -> CtResult<Self> {
let matrix = AttentionVectorRows::new("query sequence", rows)?;
Ok(Self {
sequence_len: matrix.sequence_len,
head_dimension: matrix.head_dimension,
rows: matrix.rows,
})
}
pub fn sequence_len(&self) -> SequenceLength {
self.sequence_len
}
pub fn head_dimension(&self) -> HeadDimension {
self.head_dimension
}
pub fn rows(&self) -> &[Vector] {
&self.rows
}
}
/// Projected key vectors for one attention head.
#[derive(Debug, Clone, PartialEq)]
pub struct KeySequence {
sequence_len: SequenceLength,
head_dimension: HeadDimension,
rows: Vec<Vector>,
}
impl KeySequence {
pub fn new(rows: Vec<Vec<f32>>) -> CtResult<Self> {
let matrix = AttentionVectorRows::new("key sequence", rows)?;
Ok(Self {
sequence_len: matrix.sequence_len,
head_dimension: matrix.head_dimension,
rows: matrix.rows,
})
}
pub fn sequence_len(&self) -> SequenceLength {
self.sequence_len
}
pub fn head_dimension(&self) -> HeadDimension {
self.head_dimension
}
pub fn rows(&self) -> &[Vector] {
&self.rows
}
}
/// Projected value vectors for one attention head.
#[derive(Debug, Clone, PartialEq)]
pub struct ValueSequence {
sequence_len: SequenceLength,
head_dimension: HeadDimension,
rows: Vec<Vector>,
}
impl ValueSequence {
pub fn new(rows: Vec<Vec<f32>>) -> CtResult<Self> {
let matrix = AttentionVectorRows::new("value sequence", rows)?;
Ok(Self {
sequence_len: matrix.sequence_len,
head_dimension: matrix.head_dimension,
rows: matrix.rows,
})
}
pub fn sequence_len(&self) -> SequenceLength {
self.sequence_len
}
pub fn head_dimension(&self) -> HeadDimension {
self.head_dimension
}
pub fn rows(&self) -> &[Vector] {
&self.rows
}
}
/// Hidden vectors over sequence positions.
#[derive(Debug, Clone, PartialEq)]
pub struct HiddenSequence {
sequence_len: SequenceLength,
model_dimension: ModelDimension,
rows: Vec<Vector>,
}
impl HiddenSequence {
pub fn new(rows: Vec<Vec<f32>>) -> CtResult<Self> {
let matrix = AttentionVectorRows::new("hidden sequence", rows)?;
Ok(Self {
sequence_len: matrix.sequence_len,
model_dimension: ModelDimension::new(matrix.head_dimension.value())?,
rows: matrix.rows,
})
}
pub fn sequence_len(&self) -> SequenceLength {
self.sequence_len
}
pub fn model_dimension(&self) -> ModelDimension {
self.model_dimension
}
pub fn rows(&self) -> &[Vector] {
&self.rows
}
}
/// A finite table of position vectors added to hidden states.
#[derive(Debug, Clone, PartialEq)]
pub struct PositionalEncoding {
max_sequence_len: SequenceLength,
model_dimension: ModelDimension,
rows: Vec<Vector>,
}
impl PositionalEncoding {
pub fn new(rows: Vec<Vec<f32>>) -> CtResult<Self> {
let matrix = AttentionVectorRows::new("positional encoding", rows)?;
Ok(Self {
max_sequence_len: matrix.sequence_len,
model_dimension: ModelDimension::new(matrix.head_dimension.value())?,
rows: matrix.rows,
})
}
pub fn max_sequence_len(&self) -> SequenceLength {
self.max_sequence_len
}
pub fn model_dimension(&self) -> ModelDimension {
self.model_dimension
}
}
#[derive(Debug, Clone, PartialEq)]
struct AttentionVectorRows {
sequence_len: SequenceLength,
head_dimension: HeadDimension,
rows: Vec<Vector>,
}
impl AttentionVectorRows {
fn new(kind: &'static str, rows: Vec<Vec<f32>>) -> CtResult<Self> {
if rows.is_empty() {
return Err(CtError::EmptyInput(kind));
}
let head_dimension = rows[0].len();
if head_dimension == 0 {
return Err(CtError::EmptyInput("attention vector row"));
}
for row in &rows {
if row.len() != head_dimension {
return Err(CtError::ShapeMismatch {
op: kind,
expected: format!("all rows have {head_dimension} columns"),
got: format!("row with {} columns", row.len()),
});
}
if row.iter().any(|value| !value.is_finite()) {
return Err(CtError::ShapeMismatch {
op: kind,
expected: "all vector values are finite".to_string(),
got: "non-finite vector value".to_string(),
});
}
}
Ok(Self {
sequence_len: SequenceLength::new(rows.len())?,
head_dimension: HeadDimension::new(head_dimension)?,
rows: rows.into_iter().map(Vector::new).collect(),
})
}
}
/// Query-by-key scores before row-wise softmax.
#[derive(Debug, Clone, PartialEq)]
pub struct AttentionScores {
query_len: SequenceLength,
key_len: SequenceLength,
rows: Vec<Logits>,
}
impl AttentionScores {
pub fn new(rows: Vec<Vec<f32>>) -> CtResult<Self> {
if rows.is_empty() {
return Err(CtError::EmptyInput("attention scores"));
}
let key_len = rows[0].len();
if key_len == 0 {
return Err(CtError::EmptyInput("attention score row"));
}
for row in &rows {
if row.len() != key_len {
return Err(CtError::ShapeMismatch {
op: "attention scores",
expected: format!("all rows have {key_len} columns"),
got: format!("row with {} columns", row.len()),
});
}
if row.iter().any(|value| !value.is_finite()) {
return Err(CtError::ShapeMismatch {
op: "attention scores",
expected: "all score values are finite".to_string(),
got: "non-finite score value".to_string(),
});
}
}
Ok(Self {
query_len: SequenceLength::new(rows.len())?,
key_len: SequenceLength::new(key_len)?,
rows: rows.into_iter().map(Logits::new).collect(),
})
}
pub fn query_len(&self) -> SequenceLength {
self.query_len
}
pub fn key_len(&self) -> SequenceLength {
self.key_len
}
pub fn rows(&self) -> &[Logits] {
&self.rows
}
}
/// Allowed query-by-key positions before attention softmax.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AttentionMask {
query_len: SequenceLength,
key_len: SequenceLength,
rows: Vec<Vec<bool>>,
}
impl AttentionMask {
pub fn new(rows: Vec<Vec<bool>>) -> CtResult<Self> {
if rows.is_empty() {
return Err(CtError::EmptyInput("attention mask"));
}
let key_len = rows[0].len();
if key_len == 0 {
return Err(CtError::EmptyInput("attention mask row"));
}
for row in &rows {
if row.len() != key_len {
return Err(CtError::ShapeMismatch {
op: "attention mask",
expected: format!("all rows have {key_len} columns"),
got: format!("row with {} columns", row.len()),
});
}
if !row.iter().any(|allowed| *allowed) {
return Err(CtError::EmptyInput("attention mask row allows no keys"));
}
}
Ok(Self {
query_len: SequenceLength::new(rows.len())?,
key_len: SequenceLength::new(key_len)?,
rows,
})
}
pub fn query_len(&self) -> SequenceLength {
self.query_len
}
pub fn key_len(&self) -> SequenceLength {
self.key_len
}
pub fn rows(&self) -> &[Vec<bool>] {
&self.rows
}
}
/// Computes scaled query-key dot-product scores.
#[derive(Debug, Clone)]
pub struct ScaledDotProductScores;
impl Morphism<Product<QuerySequence, KeySequence>, AttentionScores> for ScaledDotProductScores {
fn name(&self) -> &'static str {
"scaled_dot_product_scores"
}
fn apply(&self, input: Product<QuerySequence, KeySequence>) -> CtResult<AttentionScores> {
let (queries, keys) = input.into_parts();
let query_dimension = queries.head_dimension();
let key_dimension = keys.head_dimension();
if query_dimension != key_dimension {
return Err(CtError::ShapeMismatch {
op: "scaled dot-product attention scores",
expected: format!("query head dimension {}", query_dimension.value()),
got: format!("key head dimension {}", key_dimension.value()),
});
}
let scale = (query_dimension.value() as f32).sqrt();
let rows = queries
.rows()
.iter()
.map(|query| {
keys.rows()
.iter()
.map(|key| dot_product(query.as_slice(), key.as_slice()) / scale)
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();
AttentionScores::new(rows)
}
}
fn dot_product(left: &[f32], right: &[f32]) -> f32 {
left.iter()
.zip(right.iter())
.map(|(left, right)| left * right)
.sum()
}
/// Applies a boolean attention mask to score rows before softmax.
#[derive(Debug, Clone)]
pub struct MaskedAttentionScores;
impl Morphism<Product<AttentionScores, AttentionMask>, AttentionScores> for MaskedAttentionScores {
fn name(&self) -> &'static str {
"masked_attention_scores"
}
fn apply(&self, input: Product<AttentionScores, AttentionMask>) -> CtResult<AttentionScores> {
let (scores, mask) = input.into_parts();
if scores.query_len() != mask.query_len() || scores.key_len() != mask.key_len() {
return Err(CtError::ShapeMismatch {
op: "masked attention scores",
expected: format!(
"{} query rows x {} key columns",
scores.query_len().value(),
scores.key_len().value()
),
got: format!(
"{} query rows x {} key columns",
mask.query_len().value(),
mask.key_len().value()
),
});
}
let rows = scores
.rows()
.iter()
.zip(mask.rows())
.map(|(score_row, mask_row)| {
score_row
.as_slice()
.iter()
.zip(mask_row)
.map(|(score, allowed)| if *allowed { *score } else { MASKED_SCORE })
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();
AttentionScores::new(rows)
}
}
/// Row-wise attention probabilities over key positions.
#[derive(Debug, Clone, PartialEq)]
pub struct AttentionWeights {
query_len: SequenceLength,
key_len: SequenceLength,
rows: Vec<Distribution>,
}
impl AttentionWeights {
pub fn new(rows: Vec<Distribution>) -> CtResult<Self> {
if rows.is_empty() {
return Err(CtError::EmptyInput("attention weights"));
}
let key_len = rows[0].as_slice().len();
if key_len == 0 {
return Err(CtError::EmptyInput("attention weight row"));
}
for row in &rows {
if row.as_slice().len() != key_len {
return Err(CtError::ShapeMismatch {
op: "attention weights",
expected: format!("all rows have {key_len} columns"),
got: format!("row with {} columns", row.as_slice().len()),
});
}
}
Ok(Self {
query_len: SequenceLength::new(rows.len())?,
key_len: SequenceLength::new(key_len)?,
rows,
})
}
pub fn query_len(&self) -> SequenceLength {
self.query_len
}
pub fn key_len(&self) -> SequenceLength {
self.key_len
}
pub fn rows(&self) -> &[Distribution] {
&self.rows
}
}
/// Weighted value vectors, one output row per query position.
#[derive(Debug, Clone, PartialEq)]
pub struct AttentionOutput {
sequence_len: SequenceLength,
head_dimension: HeadDimension,
rows: Vec<Vector>,
}
impl AttentionOutput {
pub fn new(rows: Vec<Vec<f32>>) -> CtResult<Self> {
let matrix = AttentionVectorRows::new("attention output", rows)?;
Ok(Self {
sequence_len: matrix.sequence_len,
head_dimension: matrix.head_dimension,
rows: matrix.rows,
})
}
pub fn sequence_len(&self) -> SequenceLength {
self.sequence_len
}
pub fn head_dimension(&self) -> HeadDimension {
self.head_dimension
}
pub fn rows(&self) -> &[Vector] {
&self.rows
}
}
/// Validated collection of single-head attention outputs.
#[derive(Debug, Clone, PartialEq)]
pub struct AttentionHeadOutputs {
head_count: HeadCount,
sequence_len: SequenceLength,
head_dimension: HeadDimension,
heads: Vec<AttentionOutput>,
}
impl AttentionHeadOutputs {
pub fn new(heads: Vec<AttentionOutput>) -> CtResult<Self> {
if heads.is_empty() {
return Err(CtError::EmptyInput("attention head outputs"));
}
let sequence_len = heads[0].sequence_len();
let head_dimension = heads[0].head_dimension();
for head in &heads {
if head.sequence_len() != sequence_len {
return Err(CtError::ShapeMismatch {
op: "attention head outputs",
expected: format!("all heads have {} sequence rows", sequence_len.value()),
got: format!("head with {} sequence rows", head.sequence_len().value()),
});
}
if head.head_dimension() != head_dimension {
return Err(CtError::ShapeMismatch {
op: "attention head outputs",
expected: format!("all heads have dimension {}", head_dimension.value()),
got: format!("head dimension {}", head.head_dimension().value()),
});
}
}
Ok(Self {
head_count: HeadCount::new(heads.len())?,
sequence_len,
head_dimension,
heads,
})
}
pub fn head_count(&self) -> HeadCount {
self.head_count
}
pub fn sequence_len(&self) -> SequenceLength {
self.sequence_len
}
pub fn head_dimension(&self) -> HeadDimension {
self.head_dimension
}
pub fn heads(&self) -> &[AttentionOutput] {
&self.heads
}
}
/// Concatenated output of several attention heads.
#[derive(Debug, Clone, PartialEq)]
pub struct MultiHeadOutput {
sequence_len: SequenceLength,
head_count: HeadCount,
head_dimension: HeadDimension,
model_dimension: ModelDimension,
rows: Vec<Vector>,
}
impl MultiHeadOutput {
fn new(
rows: Vec<Vec<f32>>,
head_count: HeadCount,
head_dimension: HeadDimension,
) -> CtResult<Self> {
let matrix = AttentionVectorRows::new("multi-head output", rows)?;
let expected_dimension = head_count.value() * head_dimension.value();
if matrix.head_dimension.value() != expected_dimension {
return Err(CtError::ShapeMismatch {
op: "multi-head output",
expected: format!("row dimension {expected_dimension}"),
got: format!("row dimension {}", matrix.head_dimension.value()),
});
}
Ok(Self {
sequence_len: matrix.sequence_len,
head_count,
head_dimension,
model_dimension: ModelDimension::new(expected_dimension)?,
rows: matrix.rows,
})
}
pub fn sequence_len(&self) -> SequenceLength {
self.sequence_len
}
pub fn head_count(&self) -> HeadCount {
self.head_count
}
pub fn head_dimension(&self) -> HeadDimension {
self.head_dimension
}
pub fn model_dimension(&self) -> ModelDimension {
self.model_dimension
}
pub fn rows(&self) -> &[Vector] {
&self.rows
}
}
/// Output sequence after the multi-head output projection.
#[derive(Debug, Clone, PartialEq)]
pub struct ProjectedAttentionOutput {
sequence_len: SequenceLength,
model_dimension: ModelDimension,
rows: Vec<Vector>,
}
impl ProjectedAttentionOutput {
pub fn new(rows: Vec<Vec<f32>>) -> CtResult<Self> {
let matrix = AttentionVectorRows::new("projected attention output", rows)?;
Ok(Self {
sequence_len: matrix.sequence_len,
model_dimension: ModelDimension::new(matrix.head_dimension.value())?,
rows: matrix.rows,
})
}
pub fn sequence_len(&self) -> SequenceLength {
self.sequence_len
}
pub fn model_dimension(&self) -> ModelDimension {
self.model_dimension
}
pub fn rows(&self) -> &[Vector] {
&self.rows
}
}
/// Vocabulary logits for every position in a hidden sequence.
#[derive(Debug, Clone, PartialEq)]
pub struct SequenceLogits {
sequence_len: SequenceLength,
vocab_size: VocabSize,
rows: Vec<Logits>,
}
impl SequenceLogits {
pub fn new(rows: Vec<Vec<f32>>) -> CtResult<Self> {
if rows.is_empty() {
return Err(CtError::EmptyInput("sequence logits"));
}
let vocab_size = rows[0].len();
if vocab_size == 0 {
return Err(CtError::EmptyInput("sequence logits row"));
}
for row in &rows {
if row.len() != vocab_size {
return Err(CtError::ShapeMismatch {
op: "sequence logits",
expected: format!("all rows have {vocab_size} columns"),
got: format!("row with {} columns", row.len()),
});
}
if row.iter().any(|value| !value.is_finite()) {
return Err(CtError::ShapeMismatch {
op: "sequence logits",
expected: "all logit values are finite".to_string(),
got: "non-finite logit value".to_string(),
});
}
}
Ok(Self {
sequence_len: SequenceLength::new(rows.len())?,
vocab_size: VocabSize::new(vocab_size)?,
rows: rows.into_iter().map(Logits::new).collect(),
})
}
pub fn sequence_len(&self) -> SequenceLength {
self.sequence_len
}
pub fn vocab_size(&self) -> VocabSize {
self.vocab_size
}
pub fn rows(&self) -> &[Logits] {
&self.rows
}
}
/// Learned language-model readout applied to each hidden position.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerReadout {
input_dimension: ModelDimension,
vocab_size: VocabSize,
weight: Vec<Vec<f32>>,
bias: Vec<f32>,
}
impl TransformerReadout {
pub fn new(weight: Vec<Vec<f32>>, bias: Vec<f32>) -> CtResult<Self> {
let (input_dimension, output_dimension) =
validate_linear_parts("transformer readout", &weight, &bias)?;
Ok(Self {
input_dimension,
vocab_size: VocabSize::new(output_dimension.value())?,
weight,
bias,
})
}
pub fn input_dimension(&self) -> ModelDimension {
self.input_dimension
}
pub fn vocab_size(&self) -> VocabSize {
self.vocab_size
}
pub fn weight(&self) -> &[Vec<f32>] {
&self.weight
}
pub fn bias(&self) -> &[f32] {
&self.bias
}
}
/// Learned output projection after head concatenation.
#[derive(Debug, Clone, PartialEq)]
pub struct AttentionOutputProjection {
input_dimension: ModelDimension,
output_dimension: ModelDimension,
weight: Vec<Vec<f32>>,
bias: Vec<f32>,
}
impl AttentionOutputProjection {
pub fn new(weight: Vec<Vec<f32>>, bias: Vec<f32>) -> CtResult<Self> {
if weight.is_empty() {
return Err(CtError::EmptyInput("attention output projection weight"));
}
if bias.is_empty() {
return Err(CtError::EmptyInput("attention output projection bias"));
}
let output_dimension = bias.len();
if bias.iter().any(|value| !value.is_finite()) {
return Err(CtError::ShapeMismatch {
op: "attention output projection",
expected: "finite bias values".to_string(),
got: "non-finite bias value".to_string(),
});
}
for row in &weight {
if row.len() != output_dimension {
return Err(CtError::ShapeMismatch {
op: "attention output projection",
expected: format!("weight rows have {output_dimension} columns"),
got: format!("weight row with {} columns", row.len()),
});
}
if row.iter().any(|value| !value.is_finite()) {
return Err(CtError::ShapeMismatch {
op: "attention output projection",
expected: "finite weight values".to_string(),
got: "non-finite weight value".to_string(),
});
}
}
Ok(Self {
input_dimension: ModelDimension::new(weight.len())?,
output_dimension: ModelDimension::new(output_dimension)?,
weight,
bias,
})
}
pub fn input_dimension(&self) -> ModelDimension {
self.input_dimension
}
pub fn output_dimension(&self) -> ModelDimension {
self.output_dimension
}
pub fn weight(&self) -> &[Vec<f32>] {
&self.weight
}
pub fn bias(&self) -> &[f32] {
&self.bias
}
}
/// Scale, shift, and epsilon parameters for layer normalization.
#[derive(Debug, Clone, PartialEq)]
pub struct LayerNormParameters {
model_dimension: ModelDimension,
scale: Vec<f32>,
shift: Vec<f32>,
epsilon: NormalizationEpsilon,
}
impl LayerNormParameters {
pub fn new(scale: Vec<f32>, shift: Vec<f32>, epsilon: NormalizationEpsilon) -> CtResult<Self> {
if scale.is_empty() {
return Err(CtError::EmptyInput("layer norm scale"));
}
if shift.is_empty() {
return Err(CtError::EmptyInput("layer norm shift"));
}
if scale.len() != shift.len() {
return Err(CtError::ShapeMismatch {
op: "layer norm parameters",
expected: format!("scale and shift length {}", scale.len()),
got: format!("shift length {}", shift.len()),
});
}
if scale.iter().any(|value| !value.is_finite()) {
return Err(CtError::ShapeMismatch {
op: "layer norm parameters",
expected: "finite scale values".to_string(),
got: "non-finite scale value".to_string(),
});
}
if shift.iter().any(|value| !value.is_finite()) {
return Err(CtError::ShapeMismatch {
op: "layer norm parameters",
expected: "finite shift values".to_string(),
got: "non-finite shift value".to_string(),
});
}
Ok(Self {
model_dimension: ModelDimension::new(scale.len())?,
scale,
shift,
epsilon,
})
}
pub fn identity(model_dimension: ModelDimension) -> Self {
Self {
model_dimension,
scale: vec![1.0; model_dimension.value()],
shift: vec![0.0; model_dimension.value()],
epsilon: NormalizationEpsilon(1e-5),
}
}
pub fn model_dimension(&self) -> ModelDimension {
self.model_dimension
}
pub fn scale(&self) -> &[f32] {
&self.scale
}
pub fn shift(&self) -> &[f32] {
&self.shift
}
pub fn epsilon(&self) -> NormalizationEpsilon {
self.epsilon
}
}
/// Layer normalization over each hidden vector independently.
#[derive(Debug, Clone, PartialEq)]
pub struct LayerNormalization {
parameters: LayerNormParameters,
}
impl LayerNormalization {
pub fn new(parameters: LayerNormParameters) -> Self {
Self { parameters }
}
pub fn model_dimension(&self) -> ModelDimension {
self.parameters.model_dimension()
}
pub fn parameters(&self) -> &LayerNormParameters {
&self.parameters
}
}
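// Sketch (hypothetical helper, not part of this crate's API): the per-row
// arithmetic a `LayerNormalization` is expected to perform, assuming the
// standard definition with the biased (population) variance:
//     y_i = scale_i * (x_i - mean) / sqrt(variance + epsilon) + shift_i
// Plain `f32` slices stand in for a `HiddenSequence` row and for the
// `LayerNormParameters` fields; the crate's own `apply` is not shown here.
// With `LayerNormParameters::identity`, scale is all ones and shift all zeros,
// so the output is just the standardized row.
#[allow(dead_code)]
fn sketch_layer_norm_row(row: &[f32], scale: &[f32], shift: &[f32], epsilon: f32) -> Vec<f32> {
    let n = row.len() as f32;
    let mean = row.iter().sum::<f32>() / n;
    // Biased variance: divide by n, not n - 1.
    let variance = row.iter().map(|x| (x - mean) * (x - mean)).sum::<f32>() / n;
    let inv_std = 1.0 / (variance + epsilon).sqrt();
    row.iter()
        .zip(scale.iter().zip(shift))
        .map(|(x, (gamma, beta))| gamma * (x - mean) * inv_std + beta)
        .collect()
}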
/// Per-row values cached by the feed-forward forward pass for backpropagation.
#[derive(Debug, Clone, PartialEq)]
struct FeedForwardRowCache {
input: Vec<f32>,
pre_activation: Vec<f32>,
activation: Vec<f32>,
output: Vec<f32>,
}
/// Intermediate tensors from one attention head's forward pass, kept for its backward pass.
#[derive(Debug, Clone, PartialEq)]
struct AttentionHeadTrainingCache {
queries: QuerySequence,
keys: KeySequence,
values: ValueSequence,
weights: AttentionWeights,
output: AttentionOutput,
}
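/// Forward-pass values of one masked multi-head block, kept for the block-level backward pass.
#[derive(Debug, Clone, PartialEq)]
struct MaskedBlockTrainingCache {
/// Final block output, after the feed-forward layer norm.
output: HiddenSequence,
/// Residual sum fed into the feed-forward layer norm.
with_feed_forward: HiddenSequence,
/// Residual sum fed into the attention layer norm.
with_attention: HiddenSequence,
/// Concatenated head outputs, before the output projection.
multi_head_output: MultiHeadOutput,
/// One cache per attention head, in head order.
attention_heads: Vec<AttentionHeadTrainingCache>,
/// One cache per sequence position for the feed-forward sublayer.
feed_forward_rows: Vec<FeedForwardRowCache>,
}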
/// Position-wise two-layer feed-forward sublayer.
#[derive(Debug, Clone, PartialEq)]
pub struct PositionWiseFeedForward {
input_dimension: ModelDimension,
hidden_dimension: ModelDimension,
output_dimension: ModelDimension,
first_weight: Vec<Vec<f32>>,
first_bias: Vec<f32>,
second_weight: Vec<Vec<f32>>,
second_bias: Vec<f32>,
}
impl PositionWiseFeedForward {
pub fn new(
first_weight: Vec<Vec<f32>>,
first_bias: Vec<f32>,
second_weight: Vec<Vec<f32>>,
second_bias: Vec<f32>,
) -> CtResult<Self> {
let (input_dimension, hidden_dimension) = validate_linear_parts(
"position-wise feed-forward first layer",
&first_weight,
&first_bias,
)?;
let (second_input_dimension, output_dimension) = validate_linear_parts(
"position-wise feed-forward second layer",
&second_weight,
&second_bias,
)?;
if second_input_dimension != hidden_dimension {
return Err(CtError::ShapeMismatch {
op: "position-wise feed-forward",
expected: format!("second input dimension {}", hidden_dimension.value()),
got: format!("second input dimension {}", second_input_dimension.value()),
});
}
if output_dimension != input_dimension {
return Err(CtError::ShapeMismatch {
op: "position-wise feed-forward",
expected: format!("output dimension {}", input_dimension.value()),
got: format!("output dimension {}", output_dimension.value()),
});
}
Ok(Self {
input_dimension,
hidden_dimension,
output_dimension,
first_weight,
first_bias,
second_weight,
second_bias,
})
}
pub fn input_dimension(&self) -> ModelDimension {
self.input_dimension
}
pub fn hidden_dimension(&self) -> ModelDimension {
self.hidden_dimension
}
pub fn output_dimension(&self) -> ModelDimension {
self.output_dimension
}
pub fn first_weight(&self) -> &[Vec<f32>] {
&self.first_weight
}
pub fn first_bias(&self) -> &[f32] {
&self.first_bias
}
pub fn second_weight(&self) -> &[Vec<f32>] {
&self.second_weight
}
pub fn second_bias(&self) -> &[f32] {
&self.second_bias
}
}
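// Sketch (hypothetical helpers, not part of this crate's API): one row through
// a position-wise feed-forward sublayer. Two assumptions are made here:
// weights are stored input-major (`weight[input][output]`, the layout the
// readout update in this file indexes as `grad_weight[feature][vocab_id]`),
// and the activation is ReLU. The crate's actual activation is not shown in
// this section.
#[allow(dead_code)]
fn sketch_affine_row(row: &[f32], weight: &[Vec<f32>], bias: &[f32]) -> Vec<f32> {
    // out_j = bias_j + sum_i row_i * weight[i][j]
    let mut out = bias.to_vec();
    for (x, weight_row) in row.iter().zip(weight) {
        for (o, w) in out.iter_mut().zip(weight_row) {
            *o += x * w;
        }
    }
    out
}

#[allow(dead_code)]
fn sketch_feed_forward_row(
    row: &[f32],
    first_weight: &[Vec<f32>],
    first_bias: &[f32],
    second_weight: &[Vec<f32>],
    second_bias: &[f32],
) -> Vec<f32> {
    // pre_activation = row * W1 + b1, then activation = ReLU(pre_activation).
    let activation: Vec<f32> = sketch_affine_row(row, first_weight, first_bias)
        .into_iter()
        .map(|z| z.max(0.0))
        .collect();
    // output = activation * W2 + b2, back in the model dimension.
    sketch_affine_row(&activation, second_weight, second_bias)
}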
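/// Affine map from hidden rows to head-dimension rows, shared by the query,
/// key, and value projections; `op` names the owner in error messages.
#[derive(Debug, Clone, PartialEq)]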
struct HiddenProjection {
op: &'static str,
input_dimension: ModelDimension,
head_dimension: HeadDimension,
weight: Vec<Vec<f32>>,
bias: Vec<f32>,
}
impl HiddenProjection {
fn new(op: &'static str, weight: Vec<Vec<f32>>, bias: Vec<f32>) -> CtResult<Self> {
let (input_dimension, output_dimension) = validate_linear_parts(op, &weight, &bias)?;
Ok(Self {
op,
input_dimension,
head_dimension: HeadDimension::new(output_dimension.value())?,
weight,
bias,
})
}
fn project(&self, input: &HiddenSequence) -> CtResult<Vec<Vec<f32>>> {
if input.model_dimension() != self.input_dimension {
return Err(CtError::ShapeMismatch {
op: self.op,
expected: format!("input dimension {}", self.input_dimension.value()),
got: format!("input dimension {}", input.model_dimension().value()),
});
}
Ok(input
.rows()
.iter()
.map(|row| project_row(row.as_slice(), &self.weight, &self.bias))
.collect::<Vec<_>>())
}
fn input_dimension(&self) -> ModelDimension {
self.input_dimension
}
fn head_dimension(&self) -> HeadDimension {
self.head_dimension
}
fn weight(&self) -> &[Vec<f32>] {
&self.weight
}
fn bias(&self) -> &[f32] {
&self.bias
}
}
/// Learned projection from hidden states to query vectors.
#[derive(Debug, Clone, PartialEq)]
pub struct HiddenToQuery {
projection: HiddenProjection,
}
impl HiddenToQuery {
pub fn new(weight: Vec<Vec<f32>>, bias: Vec<f32>) -> CtResult<Self> {
Ok(Self {
projection: HiddenProjection::new("hidden-to-query projection", weight, bias)?,
})
}
pub fn input_dimension(&self) -> ModelDimension {
self.projection.input_dimension()
}
pub fn head_dimension(&self) -> HeadDimension {
self.projection.head_dimension()
}
pub fn weight(&self) -> &[Vec<f32>] {
self.projection.weight()
}
pub fn bias(&self) -> &[f32] {
self.projection.bias()
}
}
/// Learned projection from hidden states to key vectors.
#[derive(Debug, Clone, PartialEq)]
pub struct HiddenToKey {
projection: HiddenProjection,
}
impl HiddenToKey {
pub fn new(weight: Vec<Vec<f32>>, bias: Vec<f32>) -> CtResult<Self> {
Ok(Self {
projection: HiddenProjection::new("hidden-to-key projection", weight, bias)?,
})
}
pub fn input_dimension(&self) -> ModelDimension {
self.projection.input_dimension()
}
pub fn head_dimension(&self) -> HeadDimension {
self.projection.head_dimension()
}
pub fn weight(&self) -> &[Vec<f32>] {
self.projection.weight()
}
pub fn bias(&self) -> &[f32] {
self.projection.bias()
}
}
/// Learned projection from hidden states to value vectors.
#[derive(Debug, Clone, PartialEq)]
pub struct HiddenToValue {
projection: HiddenProjection,
}
impl HiddenToValue {
pub fn new(weight: Vec<Vec<f32>>, bias: Vec<f32>) -> CtResult<Self> {
Ok(Self {
projection: HiddenProjection::new("hidden-to-value projection", weight, bias)?,
})
}
pub fn input_dimension(&self) -> ModelDimension {
self.projection.input_dimension()
}
pub fn head_dimension(&self) -> HeadDimension {
self.projection.head_dimension()
}
pub fn weight(&self) -> &[Vec<f32>] {
self.projection.weight()
}
pub fn bias(&self) -> &[f32] {
self.projection.bias()
}
}
/// A tiny single-head Transformer block sketch.
#[derive(Debug, Clone, PartialEq)]
pub struct SingleHeadTransformerBlock {
model_dimension: ModelDimension,
query_projection: HiddenToQuery,
key_projection: HiddenToKey,
value_projection: HiddenToValue,
output_projection: AttentionOutputProjection,
attention_norm: LayerNormalization,
feed_forward: PositionWiseFeedForward,
feed_forward_norm: LayerNormalization,
}
impl SingleHeadTransformerBlock {
pub fn new(
query_projection: HiddenToQuery,
key_projection: HiddenToKey,
value_projection: HiddenToValue,
output_projection: AttentionOutputProjection,
attention_norm: LayerNormalization,
feed_forward: PositionWiseFeedForward,
feed_forward_norm: LayerNormalization,
) -> CtResult<Self> {
let model_dimension = query_projection.input_dimension();
validate_projection_input(
"single-head block key projection",
model_dimension,
key_projection.input_dimension(),
)?;
validate_projection_input(
"single-head block value projection",
model_dimension,
value_projection.input_dimension(),
)?;
if query_projection.head_dimension() != key_projection.head_dimension() {
return Err(CtError::ShapeMismatch {
op: "single-head block",
expected: format!(
"query/key head dimension {}",
query_projection.head_dimension().value()
),
got: format!(
"key head dimension {}",
key_projection.head_dimension().value()
),
});
}
if output_projection.input_dimension().value() != value_projection.head_dimension().value()
{
return Err(CtError::ShapeMismatch {
op: "single-head block",
expected: format!(
"output projection input dimension {}",
value_projection.head_dimension().value()
),
got: format!(
"output projection input dimension {}",
output_projection.input_dimension().value()
),
});
}
validate_projection_input(
"single-head block output projection",
model_dimension,
output_projection.output_dimension(),
)?;
validate_projection_input(
"single-head block attention normalization",
model_dimension,
attention_norm.model_dimension(),
)?;
validate_projection_input(
"single-head block feed-forward",
model_dimension,
feed_forward.input_dimension(),
)?;
validate_projection_input(
"single-head block feed-forward normalization",
model_dimension,
feed_forward_norm.model_dimension(),
)?;
Ok(Self {
model_dimension,
query_projection,
key_projection,
value_projection,
output_projection,
attention_norm,
feed_forward,
feed_forward_norm,
})
}
pub fn model_dimension(&self) -> ModelDimension {
self.model_dimension
}
}
/// Learned query, key, and value projections for one self-attention head.
#[derive(Debug, Clone, PartialEq)]
pub struct SelfAttentionHead {
query_projection: HiddenToQuery,
key_projection: HiddenToKey,
value_projection: HiddenToValue,
}
impl SelfAttentionHead {
pub fn new(
query_projection: HiddenToQuery,
key_projection: HiddenToKey,
value_projection: HiddenToValue,
) -> CtResult<Self> {
let input_dimension = query_projection.input_dimension();
validate_projection_input(
"self-attention head key projection",
input_dimension,
key_projection.input_dimension(),
)?;
validate_projection_input(
"self-attention head value projection",
input_dimension,
value_projection.input_dimension(),
)?;
if query_projection.head_dimension() != key_projection.head_dimension() {
return Err(CtError::ShapeMismatch {
op: "self-attention head",
expected: format!(
"query/key head dimension {}",
query_projection.head_dimension().value()
),
got: format!(
"key head dimension {}",
key_projection.head_dimension().value()
),
});
}
Ok(Self {
query_projection,
key_projection,
value_projection,
})
}
pub fn input_dimension(&self) -> ModelDimension {
self.query_projection.input_dimension()
}
pub fn query_key_dimension(&self) -> HeadDimension {
self.query_projection.head_dimension()
}
pub fn value_dimension(&self) -> HeadDimension {
self.value_projection.head_dimension()
}
pub fn query_projection(&self) -> &HiddenToQuery {
&self.query_projection
}
pub fn key_projection(&self) -> &HiddenToKey {
&self.key_projection
}
pub fn value_projection(&self) -> &HiddenToValue {
&self.value_projection
}
}
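// Sketch (hypothetical helper, not part of this crate's API): roughly what one
// `SelfAttentionHead` computes once its projections have produced query, key,
// and value rows. It assumes the standard scaled dot-product form with
// 1 / sqrt(d_k) scaling and a causal mask (position i attends only to
// positions j <= i). The crate's `AttentionMask` may encode other patterns,
// and its scaling convention is not shown in this section. Inputs are assumed
// non-empty and of equal sequence length.
#[allow(dead_code)]
fn sketch_causal_attention(
    queries: &[Vec<f32>],
    keys: &[Vec<f32>],
    values: &[Vec<f32>],
) -> Vec<Vec<f32>> {
    let scale = 1.0 / (keys[0].len() as f32).sqrt();
    queries
        .iter()
        .enumerate()
        .map(|(i, query)| {
            // Scores against keys 0..=i only; later keys are masked out.
            let scores: Vec<f32> = keys[..=i]
                .iter()
                .map(|key| scale * query.iter().zip(key).map(|(q, k)| q * k).sum::<f32>())
                .collect();
            // Numerically stable softmax over the unmasked scores.
            let max = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
            let total: f32 = exps.iter().sum();
            // Weighted sum of the visible value rows (zip stops at i + 1 rows).
            let mut out = vec![0.0; values[0].len()];
            for (weight, value) in exps.iter().zip(values) {
                for (o, v) in out.iter_mut().zip(value) {
                    *o += (weight / total) * v;
                }
            }
            out
        })
        .collect()
}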
/// A tiny multi-head Transformer block sketch.
#[derive(Debug, Clone, PartialEq)]
pub struct MultiHeadTransformerBlock {
model_dimension: ModelDimension,
head_count: HeadCount,
value_dimension: HeadDimension,
heads: Vec<SelfAttentionHead>,
output_projection: AttentionOutputProjection,
attention_norm: LayerNormalization,
feed_forward: PositionWiseFeedForward,
feed_forward_norm: LayerNormalization,
}
impl MultiHeadTransformerBlock {
pub fn new(
heads: Vec<SelfAttentionHead>,
output_projection: AttentionOutputProjection,
attention_norm: LayerNormalization,
feed_forward: PositionWiseFeedForward,
feed_forward_norm: LayerNormalization,
) -> CtResult<Self> {
if heads.is_empty() {
return Err(CtError::EmptyInput("multi-head block heads"));
}
let model_dimension = heads[0].input_dimension();
let value_dimension = heads[0].value_dimension();
for head in &heads {
validate_projection_input(
"multi-head block head projection",
model_dimension,
head.input_dimension(),
)?;
if head.value_dimension() != value_dimension {
return Err(CtError::ShapeMismatch {
op: "multi-head block",
expected: format!("value head dimension {}", value_dimension.value()),
got: format!("value head dimension {}", head.value_dimension().value()),
});
}
}
let head_count = HeadCount::new(heads.len())?;
let concatenated_dimension =
ModelDimension::new(head_count.value() * value_dimension.value())?;
validate_projection_input(
"multi-head block output projection input",
concatenated_dimension,
output_projection.input_dimension(),
)?;
validate_projection_input(
"multi-head block output projection",
model_dimension,
output_projection.output_dimension(),
)?;
validate_projection_input(
"multi-head block attention normalization",
model_dimension,
attention_norm.model_dimension(),
)?;
validate_projection_input(
"multi-head block feed-forward",
model_dimension,
feed_forward.input_dimension(),
)?;
validate_projection_input(
"multi-head block feed-forward normalization",
model_dimension,
feed_forward_norm.model_dimension(),
)?;
Ok(Self {
model_dimension,
head_count,
value_dimension,
heads,
output_projection,
attention_norm,
feed_forward,
feed_forward_norm,
})
}
pub fn model_dimension(&self) -> ModelDimension {
self.model_dimension
}
pub fn head_count(&self) -> HeadCount {
self.head_count
}
pub fn value_dimension(&self) -> HeadDimension {
self.value_dimension
}
pub fn heads(&self) -> &[SelfAttentionHead] {
&self.heads
}
fn with_heads(self, heads: Vec<SelfAttentionHead>) -> CtResult<Self> {
Self::new(
heads,
self.output_projection,
self.attention_norm,
self.feed_forward,
self.feed_forward_norm,
)
}
fn with_feed_forward(self, feed_forward: PositionWiseFeedForward) -> CtResult<Self> {
Self::new(
self.heads,
self.output_projection,
self.attention_norm,
feed_forward,
self.feed_forward_norm,
)
}
fn with_output_projection(
self,
output_projection: AttentionOutputProjection,
) -> CtResult<Self> {
Self::new(
self.heads,
output_projection,
self.attention_norm,
self.feed_forward,
self.feed_forward_norm,
)
}
fn with_layer_norms(
self,
attention_norm: LayerNormalization,
feed_forward_norm: LayerNormalization,
) -> CtResult<Self> {
Self::new(
self.heads,
self.output_projection,
attention_norm,
self.feed_forward,
feed_forward_norm,
)
}
}
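// Sketch (hypothetical helpers, not part of this crate's API): the width
// bookkeeping behind the constructor's check that the output projection's
// input dimension equals `head_count * value_dimension`. It assumes heads are
// concatenated in head order within each row, which is what the slicing
// `start = head_index * value_dimension` in the block-level training step
// relies on. Concatenating and then splitting are inverse operations.
#[allow(dead_code)]
fn sketch_concatenate_heads(head_rows: &[Vec<Vec<f32>>]) -> Vec<Vec<f32>> {
    // head_rows[head][position] -> concatenated[position]
    let sequence_len = head_rows[0].len();
    (0..sequence_len)
        .map(|position| {
            head_rows
                .iter()
                .flat_map(|head| head[position].iter().copied())
                .collect()
        })
        .collect()
}

#[allow(dead_code)]
fn sketch_split_heads(
    rows: &[Vec<f32>],
    head_count: usize,
    value_dimension: usize,
) -> Vec<Vec<Vec<f32>>> {
    (0..head_count)
        .map(|head_index| {
            let start = head_index * value_dimension;
            let end = start + value_dimension;
            rows.iter().map(|row| row[start..end].to_vec()).collect()
        })
        .collect()
}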
/// A tiny masked multi-head Transformer block sketch.
#[derive(Debug, Clone, PartialEq)]
pub struct MaskedMultiHeadTransformerBlock {
block: MultiHeadTransformerBlock,
}
impl MaskedMultiHeadTransformerBlock {
pub fn new(
heads: Vec<SelfAttentionHead>,
output_projection: AttentionOutputProjection,
attention_norm: LayerNormalization,
feed_forward: PositionWiseFeedForward,
feed_forward_norm: LayerNormalization,
) -> CtResult<Self> {
Ok(Self {
block: MultiHeadTransformerBlock::new(
heads,
output_projection,
attention_norm,
feed_forward,
feed_forward_norm,
)?,
})
}
pub fn model_dimension(&self) -> ModelDimension {
self.block.model_dimension()
}
pub fn head_count(&self) -> HeadCount {
self.block.head_count()
}
pub fn value_dimension(&self) -> HeadDimension {
self.block.value_dimension()
}
pub fn heads(&self) -> &[SelfAttentionHead] {
self.block.heads()
}
pub fn feed_forward(&self) -> &PositionWiseFeedForward {
&self.block.feed_forward
}
fn with_feed_forward(self, feed_forward: PositionWiseFeedForward) -> CtResult<Self> {
Ok(Self {
block: self.block.with_feed_forward(feed_forward)?,
})
}
fn with_heads(self, heads: Vec<SelfAttentionHead>) -> CtResult<Self> {
Ok(Self {
block: self.block.with_heads(heads)?,
})
}
fn with_output_projection(
self,
output_projection: AttentionOutputProjection,
) -> CtResult<Self> {
Ok(Self {
block: self.block.with_output_projection(output_projection)?,
})
}
fn with_layer_norms(
self,
attention_norm: LayerNormalization,
feed_forward_norm: LayerNormalization,
) -> CtResult<Self> {
Ok(Self {
block: self
.block
.with_layer_norms(attention_norm, feed_forward_norm)?,
})
}
fn apply_with_training_cache(
&self,
hidden: HiddenSequence,
mask: AttentionMask,
) -> CtResult<MaskedBlockTrainingCache> {
if hidden.model_dimension() != self.block.model_dimension {
return Err(CtError::ShapeMismatch {
op: "masked multi-head block",
expected: format!("model dimension {}", self.block.model_dimension.value()),
got: format!("model dimension {}", hidden.model_dimension().value()),
});
}
let head_caches = self
.block
.heads
.iter()
.map(|head| apply_self_attention_head_with_mask_cache(&hidden, head, Some(&mask)))
.collect::<CtResult<Vec<_>>>()?;
let attention_outputs = head_caches
.iter()
.map(|cache| cache.output.clone())
.collect::<Vec<_>>();
let head_outputs = AttentionHeadOutputs::new(attention_outputs)?;
let multi_head_output = ConcatenateHeads.apply(head_outputs)?;
let projected_attention = self
.block
.output_projection
.apply(multi_head_output.clone())?;
let with_attention = ResidualConnection.apply(Product::new(hidden, projected_attention))?;
let normalized_attention = self.block.attention_norm.apply(with_attention.clone())?;
let (feed_forward_output, feed_forward_rows) =
feed_forward_with_cache(&self.block.feed_forward, &normalized_attention)?;
let with_feed_forward =
ResidualConnection.apply(Product::new(normalized_attention, feed_forward_output))?;
let output = self
.block
.feed_forward_norm
.apply(with_feed_forward.clone())?;
Ok(MaskedBlockTrainingCache {
output,
with_feed_forward,
with_attention,
multi_head_output,
attention_heads: head_caches,
feed_forward_rows,
})
}
}
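// Sketch (hypothetical helper, not part of this crate's API): the post-norm
// residual order that `apply_with_training_cache` follows, with closures
// standing in for the crate's attention, normalization, and feed-forward
// morphisms:
//     with_attention    = hidden + attention(hidden)
//     normalized        = attention_norm(with_attention)
//     with_feed_forward = normalized + feed_forward(normalized)
//     output            = feed_forward_norm(with_feed_forward)
// Attention sees the whole sequence; the norms and feed-forward act per row.
#[allow(dead_code)]
fn sketch_post_norm_block(
    hidden: &[Vec<f32>],
    attention: impl Fn(&[Vec<f32>]) -> Vec<Vec<f32>>,
    attention_norm: impl Fn(&[f32]) -> Vec<f32>,
    feed_forward: impl Fn(&[f32]) -> Vec<f32>,
    feed_forward_norm: impl Fn(&[f32]) -> Vec<f32>,
) -> Vec<Vec<f32>> {
    // Element-wise sum used by both residual connections.
    fn add(a: &[f32], b: &[f32]) -> Vec<f32> {
        a.iter().zip(b).map(|(x, y)| x + y).collect()
    }
    let attended = attention(hidden);
    hidden
        .iter()
        .zip(&attended)
        .map(|(row, attended_row)| {
            let with_attention = add(row, attended_row);
            let normalized = attention_norm(&with_attention[..]);
            let with_feed_forward = add(&normalized, &feed_forward(&normalized[..]));
            feed_forward_norm(&with_feed_forward[..])
        })
        .collect()
}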
/// Parameters of the tiny Transformer: positional encoding, one masked multi-head block, and a sequence readout.
#[derive(Debug, Clone, PartialEq)]
pub struct TinyTransformerParameters {
positional_encoding: PositionalEncoding,
block: MaskedMultiHeadTransformerBlock,
readout: TransformerReadout,
}
impl TinyTransformerParameters {
pub fn new(
positional_encoding: PositionalEncoding,
block: MaskedMultiHeadTransformerBlock,
readout: TransformerReadout,
) -> CtResult<Self> {
let model_dimension = positional_encoding.model_dimension();
validate_projection_input(
"tiny transformer parameters block",
model_dimension,
block.model_dimension(),
)?;
validate_projection_input(
"tiny transformer parameters readout",
model_dimension,
readout.input_dimension(),
)?;
Ok(Self {
positional_encoding,
block,
readout,
})
}
pub fn model_dimension(&self) -> ModelDimension {
self.positional_encoding.model_dimension()
}
pub fn max_sequence_len(&self) -> SequenceLength {
self.positional_encoding.max_sequence_len()
}
pub fn vocab_size(&self) -> VocabSize {
self.readout.vocab_size()
}
pub fn encode(&self, hidden: HiddenSequence, mask: AttentionMask) -> CtResult<HiddenSequence> {
let positioned = self.positional_encoding.apply(hidden)?;
self.block.apply(Product::new(positioned, mask))
}
pub fn readout(&self) -> &TransformerReadout {
&self.readout
}
pub fn feed_forward(&self) -> &PositionWiseFeedForward {
self.block.feed_forward()
}
pub fn output_projection(&self) -> &AttentionOutputProjection {
&self.block.block.output_projection
}
pub fn attention_heads(&self) -> &[SelfAttentionHead] {
self.block.heads()
}
pub fn attention_norm(&self) -> &LayerNormalization {
&self.block.block.attention_norm
}
pub fn feed_forward_norm(&self) -> &LayerNormalization {
&self.block.block.feed_forward_norm
}
fn with_readout(self, readout: TransformerReadout) -> CtResult<Self> {
Self::new(self.positional_encoding, self.block, readout)
}
fn with_feed_forward(self, feed_forward: PositionWiseFeedForward) -> CtResult<Self> {
Self::new(
self.positional_encoding,
self.block.with_feed_forward(feed_forward)?,
self.readout,
)
}
fn with_attention_heads(self, heads: Vec<SelfAttentionHead>) -> CtResult<Self> {
Self::new(
self.positional_encoding,
self.block.with_heads(heads)?,
self.readout,
)
}
fn with_output_projection(
self,
output_projection: AttentionOutputProjection,
) -> CtResult<Self> {
Self::new(
self.positional_encoding,
self.block.with_output_projection(output_projection)?,
self.readout,
)
}
fn with_layer_norms(
self,
attention_norm: LayerNormalization,
feed_forward_norm: LayerNormalization,
) -> CtResult<Self> {
Self::new(
self.positional_encoding,
self.block
.with_layer_norms(attention_norm, feed_forward_norm)?,
self.readout,
)
}
}
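// Sketch (hypothetical type, not part of this crate's API): the "update by
// rebuilding" pattern the `with_*` methods above follow. Each replacement
// part goes back through `new`, so an update cannot produce parameters whose
// dimensions disagree; a mismatch surfaces as an `Err` instead. Plain
// `usize` dimensions and a `String` error stand in for the crate's types.
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq)]
struct SketchParameters {
    model_dimension: usize,
    readout_input: usize,
}

#[allow(dead_code)]
impl SketchParameters {
    fn new(model_dimension: usize, readout_input: usize) -> Result<Self, String> {
        if readout_input != model_dimension {
            return Err(format!(
                "readout input {readout_input} does not match model dimension {model_dimension}"
            ));
        }
        Ok(Self { model_dimension, readout_input })
    }

    // Consumes `self` and revalidates, in the style of `with_readout`.
    fn with_readout_input(self, readout_input: usize) -> Result<Self, String> {
        Self::new(self.model_dimension, readout_input)
    }
}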
/// Training state threaded through the train-step morphisms: parameters, learning rate, and step count.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerTrainingState {
parameters: TinyTransformerParameters,
learning_rate: LearningRate,
step_count: StepCount,
}
impl TransformerTrainingState {
pub fn new(parameters: TinyTransformerParameters, learning_rate: LearningRate) -> Self {
Self::from_parts(parameters, learning_rate, StepCount::new(0))
}
pub fn from_parts(
parameters: TinyTransformerParameters,
learning_rate: LearningRate,
step_count: StepCount,
) -> Self {
Self {
parameters,
learning_rate,
step_count,
}
}
pub fn parameters(&self) -> &TinyTransformerParameters {
&self.parameters
}
pub fn learning_rate(&self) -> LearningRate {
self.learning_rate
}
pub fn step_count(&self) -> StepCount {
self.step_count
}
pub fn record_updated_parameters(self, parameters: TinyTransformerParameters) -> Self {
Self {
parameters,
learning_rate: self.learning_rate,
step_count: StepCount::new(self.step_count.value() + 1),
}
}
}
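// Sketch (hypothetical helper, not part of this crate's API): a training loop
// as repeated application of a state-to-state step, the shape that
// `TransformerTrainingState` and the `Morphism` train steps share. Each step
// consumes the old state and returns the next one; the first error stops the
// loop. A plain closure stands in for the crate's `Morphism::apply`.
#[allow(dead_code)]
fn sketch_run_steps<S, E>(
    mut state: S,
    steps: usize,
    step: impl Fn(S) -> Result<S, E>,
) -> Result<S, E> {
    for _ in 0..steps {
        state = step(state)?;
    }
    Ok(state)
}

// Example state: a single parameter `w` and a step counter. Each step does
// gradient descent on 0.5 * (w - 3)^2 with learning rate 0.5, then bumps the
// counter the way `record_updated_parameters` does.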
/// One supervised sequence example for a readout-only Transformer update.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerReadoutTrainingExample {
hidden: HiddenSequence,
mask: AttentionMask,
targets: TokenSequence,
}
impl TransformerReadoutTrainingExample {
pub fn new(
hidden: HiddenSequence,
mask: AttentionMask,
targets: TokenSequence,
) -> CtResult<Self> {
let sequence_len = hidden.sequence_len();
if targets.as_slice().len() != sequence_len.value() {
return Err(CtError::ShapeMismatch {
op: "transformer readout training targets",
expected: format!("{} target tokens", sequence_len.value()),
got: format!("{} target tokens", targets.as_slice().len()),
});
}
if mask.query_len() != sequence_len || mask.key_len() != sequence_len {
return Err(CtError::ShapeMismatch {
op: "transformer readout training mask",
expected: format!(
"{} query rows x {} key columns",
sequence_len.value(),
sequence_len.value()
),
got: format!(
"{} query rows x {} key columns",
mask.query_len().value(),
mask.key_len().value()
),
});
}
Ok(Self {
hidden,
mask,
targets,
})
}
pub fn hidden(&self) -> &HiddenSequence {
&self.hidden
}
pub fn mask(&self) -> &AttentionMask {
&self.mask
}
pub fn targets(&self) -> &TokenSequence {
&self.targets
}
}
/// Non-empty set of supervised readout examples.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerReadoutTrainingSet(Vec<TransformerReadoutTrainingExample>);
impl TransformerReadoutTrainingSet {
pub fn new(
examples: impl IntoIterator<Item = TransformerReadoutTrainingExample>,
) -> CtResult<Self> {
let examples = examples.into_iter().collect::<Vec<_>>();
if examples.is_empty() {
return Err(CtError::EmptyInput("transformer readout training set"));
}
Ok(Self(examples))
}
pub fn examples(&self) -> &[TransformerReadoutTrainingExample] {
&self.0
}
}
/// One full-batch gradient step on the readout weights and bias; the positional encoding and block stay fixed.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerReadoutTrainStep {
dataset: TransformerReadoutTrainingSet,
}
impl TransformerReadoutTrainStep {
pub fn new(dataset: TransformerReadoutTrainingSet) -> Self {
Self { dataset }
}
}
impl Morphism<TransformerTrainingState, TransformerTrainingState> for TransformerReadoutTrainStep {
fn name(&self) -> &'static str {
"transformer_readout_train_step"
}
fn apply(&self, state: TransformerTrainingState) -> CtResult<TransformerTrainingState> {
let input_dimension = state.parameters.readout().input_dimension().value();
let vocab_size = state.parameters.vocab_size().value();
let mut grad_weight = vec![vec![0.0; vocab_size]; input_dimension];
let mut grad_bias = vec![0.0; vocab_size];
let mut position_count = 0usize;
for example in self.dataset.examples() {
let encoded = state
.parameters
.encode(example.hidden().clone(), example.mask().clone())?;
let logits = state.parameters.readout().apply(encoded.clone())?;
for ((hidden_row, logit_row), target) in encoded
.rows()
.iter()
.zip(logits.rows())
.zip(example.targets().as_slice())
{
let target_index = target.index();
if target_index >= vocab_size {
return Err(CtError::OutOfRange {
kind: "sequence target",
index: target_index,
limit: vocab_size,
});
}
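// Softmax cross-entropy gradient with respect to the logits:
// probabilities minus the one-hot target.
let probabilities = Softmax.apply(logit_row.clone())?;
let mut dlogits = probabilities.as_slice().to_vec();
dlogits[target_index] -= 1.0;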
for (vocab_id, dlogit) in dlogits.iter().copied().enumerate() {
grad_bias[vocab_id] += dlogit;
for (feature, hidden_value) in hidden_row.as_slice().iter().copied().enumerate()
{
grad_weight[feature][vocab_id] += hidden_value * dlogit;
}
}
position_count += 1;
}
}
// Average the accumulated gradients over every supervised position.
let scale = state.learning_rate().value() / position_count as f32;
let mut updated_weight = state.parameters.readout().weight().to_vec();
let mut updated_bias = state.parameters.readout().bias().to_vec();
for (row, grad_row) in updated_weight.iter_mut().zip(&grad_weight) {
for (weight, grad) in row.iter_mut().zip(grad_row) {
*weight -= scale * grad;
}
}
for (bias, grad) in updated_bias.iter_mut().zip(&grad_bias) {
*bias -= scale * grad;
}
let updated_readout = TransformerReadout::new(updated_weight, updated_bias)?;
let updated_parameters = state.parameters.clone().with_readout(updated_readout)?;
Ok(state.record_updated_parameters(updated_parameters))
}
}
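// Sketch (hypothetical helper, not part of this crate's API): the arithmetic
// inside `TransformerReadoutTrainStep` for a single position, with plain
// vectors in place of the crate's types. For softmax cross-entropy the
// gradient with respect to the logits is softmax(logits) - one_hot(target);
// the weight gradient is the outer product of the hidden row with that
// vector. The real step averages these gradients over all positions before
// updating; this sketch updates immediately from one position.
#[allow(dead_code)]
fn sketch_readout_update(
    hidden_row: &[f32],
    weight: &mut [Vec<f32>], // weight[feature][vocab_id]
    bias: &mut [f32],
    target: usize,
    learning_rate: f32,
) {
    // logits_v = bias_v + sum_i hidden_i * weight[i][v]
    let mut logits = bias.to_vec();
    for (h, row) in hidden_row.iter().zip(weight.iter()) {
        for (logit, w) in logits.iter_mut().zip(row) {
            *logit += h * w;
        }
    }
    // Stable softmax, then subtract 1 at the target index.
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|z| (z - max).exp()).collect();
    let total: f32 = exps.iter().sum();
    let mut dlogits: Vec<f32> = exps.iter().map(|e| e / total).collect();
    dlogits[target] -= 1.0;
    // Gradient descent on bias and weights.
    for (b, d) in bias.iter_mut().zip(&dlogits) {
        *b -= learning_rate * d;
    }
    for (h, row) in hidden_row.iter().zip(weight.iter_mut()) {
        for (w, d) in row.iter_mut().zip(&dlogits) {
            *w -= learning_rate * h * d;
        }
    }
}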
/// One supervised sequence example for local feed-forward sublayer training.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerFeedForwardTrainingExample {
input: HiddenSequence,
target: HiddenSequence,
}
impl TransformerFeedForwardTrainingExample {
pub fn new(input: HiddenSequence, target: HiddenSequence) -> CtResult<Self> {
if input.sequence_len() != target.sequence_len() {
return Err(CtError::ShapeMismatch {
op: "transformer feed-forward training sequence",
expected: format!("{} target rows", input.sequence_len().value()),
got: format!("{} target rows", target.sequence_len().value()),
});
}
if input.model_dimension() != target.model_dimension() {
return Err(CtError::ShapeMismatch {
op: "transformer feed-forward training dimension",
expected: format!("target dimension {}", input.model_dimension().value()),
got: format!("target dimension {}", target.model_dimension().value()),
});
}
Ok(Self { input, target })
}
pub fn input(&self) -> &HiddenSequence {
&self.input
}
pub fn target(&self) -> &HiddenSequence {
&self.target
}
}
/// Non-empty set of local feed-forward training examples.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerFeedForwardTrainingSet(Vec<TransformerFeedForwardTrainingExample>);
impl TransformerFeedForwardTrainingSet {
pub fn new(
examples: impl IntoIterator<Item = TransformerFeedForwardTrainingExample>,
) -> CtResult<Self> {
let examples = examples.into_iter().collect::<Vec<_>>();
if examples.is_empty() {
return Err(CtError::EmptyInput("transformer feed-forward training set"));
}
Ok(Self(examples))
}
pub fn examples(&self) -> &[TransformerFeedForwardTrainingExample] {
&self.0
}
}
/// One full-batch update of the position-wise feed-forward sublayer in isolation, fitting its output rows to target rows with a squared-error signal.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerFeedForwardTrainStep {
dataset: TransformerFeedForwardTrainingSet,
}
impl TransformerFeedForwardTrainStep {
pub fn new(dataset: TransformerFeedForwardTrainingSet) -> Self {
Self { dataset }
}
}
impl Morphism<TransformerTrainingState, TransformerTrainingState>
for TransformerFeedForwardTrainStep
{
fn name(&self) -> &'static str {
"transformer_feed_forward_train_step"
}
fn apply(&self, state: TransformerTrainingState) -> CtResult<TransformerTrainingState> {
let feed_forward = state.parameters.feed_forward();
let mut gradients = FeedForwardGradients::new(feed_forward);
let mut row_count = 0usize;
for example in self.dataset.examples() {
if example.input().model_dimension() != feed_forward.input_dimension() {
return Err(CtError::ShapeMismatch {
op: "transformer feed-forward train step",
expected: format!("input dimension {}", feed_forward.input_dimension().value()),
got: format!(
"input dimension {}",
example.input().model_dimension().value()
),
});
}
let (_output, cache_rows) = feed_forward_with_cache(feed_forward, example.input())?;
for (cache, target_row) in cache_rows.iter().zip(example.target().rows()) {
// Gradient of the half squared error 0.5 * ||output - target||^2.
let d_output = cache
.output
.iter()
.zip(target_row.as_slice())
.map(|(output_value, target_value)| output_value - target_value)
.collect::<Vec<_>>();
gradients.accumulate(feed_forward, cache, &d_output);
row_count += 1;
}
}
let updated_feed_forward =
gradients.apply_to(feed_forward, state.learning_rate(), row_count)?;
let updated_parameters = state
.parameters
.clone()
.with_feed_forward(updated_feed_forward)?;
Ok(state.record_updated_parameters(updated_parameters))
}
}
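// Sketch (hypothetical helper, not part of this crate's API): the local
// objective that `TransformerFeedForwardTrainStep` appears to minimize for
// each row. Its error signal `output - target` is exactly the gradient of the
// half squared error L = 0.5 * sum_i (output_i - target_i)^2, which a
// central finite difference can confirm numerically.
#[allow(dead_code)]
fn sketch_half_squared_error(output: &[f32], target: &[f32]) -> (f32, Vec<f32>) {
    // d_output is both the residual and dL/d(output).
    let d_output: Vec<f32> = output.iter().zip(target).map(|(y, t)| y - t).collect();
    let loss = 0.5 * d_output.iter().map(|d| d * d).sum::<f32>();
    (loss, d_output)
}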
/// One supervised sequence example for a composed block-level update.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerBlockTrainingExample {
hidden: HiddenSequence,
mask: AttentionMask,
targets: TokenSequence,
}
impl TransformerBlockTrainingExample {
pub fn new(
hidden: HiddenSequence,
mask: AttentionMask,
targets: TokenSequence,
) -> CtResult<Self> {
let sequence_len = hidden.sequence_len();
if targets.as_slice().len() != sequence_len.value() {
return Err(CtError::ShapeMismatch {
op: "transformer block training targets",
expected: format!("{} target tokens", sequence_len.value()),
got: format!("{} target tokens", targets.as_slice().len()),
});
}
if mask.query_len() != sequence_len || mask.key_len() != sequence_len {
return Err(CtError::ShapeMismatch {
op: "transformer block training mask",
expected: format!(
"{} query rows x {} key columns",
sequence_len.value(),
sequence_len.value()
),
got: format!(
"{} query rows x {} key columns",
mask.query_len().value(),
mask.key_len().value()
),
});
}
Ok(Self {
hidden,
mask,
targets,
})
}
pub fn hidden(&self) -> &HiddenSequence {
&self.hidden
}
pub fn mask(&self) -> &AttentionMask {
&self.mask
}
pub fn targets(&self) -> &TokenSequence {
&self.targets
}
}
/// Non-empty set of supervised block-level training examples.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerBlockTrainingSet(Vec<TransformerBlockTrainingExample>);
impl TransformerBlockTrainingSet {
pub fn new(
examples: impl IntoIterator<Item = TransformerBlockTrainingExample>,
) -> CtResult<Self> {
let examples = examples.into_iter().collect::<Vec<_>>();
if examples.is_empty() {
return Err(CtError::EmptyInput("transformer block training set"));
}
Ok(Self(examples))
}
pub fn examples(&self) -> &[TransformerBlockTrainingExample] {
&self.0
}
}
/// One full-batch update through the readout, block sublayers, and attention heads.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerBlockTrainStep {
dataset: TransformerBlockTrainingSet,
}
impl TransformerBlockTrainStep {
pub fn new(dataset: TransformerBlockTrainingSet) -> Self {
Self { dataset }
}
}
impl Morphism<TransformerTrainingState, TransformerTrainingState> for TransformerBlockTrainStep {
fn name(&self) -> &'static str {
"transformer_block_train_step"
}
fn apply(&self, state: TransformerTrainingState) -> CtResult<TransformerTrainingState> {
let readout = state.parameters.readout();
let feed_forward = state.parameters.feed_forward();
let output_projection = state.parameters.output_projection();
let attention_norm = state.parameters.attention_norm();
let feed_forward_norm = state.parameters.feed_forward_norm();
let mut readout_gradients = ReadoutGradients::new(readout);
let mut feed_forward_gradients = FeedForwardGradients::new(feed_forward);
let mut output_projection_gradients =
AttentionOutputProjectionGradients::new(output_projection);
let mut attention_norm_gradients = LayerNormGradients::new(attention_norm.parameters());
let mut feed_forward_norm_gradients =
LayerNormGradients::new(feed_forward_norm.parameters());
let mut attention_head_gradients = state
.parameters
.attention_heads()
.iter()
.map(AttentionHeadGradients::new)
.collect::<Vec<_>>();
let mut position_count = 0usize;
for example in self.dataset.examples() {
let positioned = state
.parameters
.positional_encoding
.apply(example.hidden().clone())?;
let block_cache = state
.parameters
.block
.apply_with_training_cache(positioned.clone(), example.mask().clone())?;
let logits = readout.apply(block_cache.output.clone())?;
let vocab_size = logits.vocab_size().value();
let mut d_multi_head_rows = vec![
vec![0.0; output_projection.input_dimension().value()];
block_cache.multi_head_output.sequence_len().value()
];
for (
position,
(
(((encoded_row, logit_row), with_feed_forward_row), with_attention_row),
((feed_forward_cache, multi_head_row), target),
),
) in block_cache
.output
.rows()
.iter()
.zip(logits.rows())
.zip(block_cache.with_feed_forward.rows())
.zip(block_cache.with_attention.rows())
.zip(
block_cache
.feed_forward_rows
.iter()
.zip(block_cache.multi_head_output.rows())
.zip(example.targets().as_slice()),
)
.enumerate()
{
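// Backward pass for one position, in reverse order of the forward pass.
// Cross-entropy gradient with respect to the logits.
let dlogits =
softmax_cross_entropy_logits_gradient(logit_row, target.index(), vocab_size)?;
// Readout: accumulate its gradients and return the gradient of the block output.
let d_encoded =
readout_gradients.accumulate(readout, encoded_row.as_slice(), &dlogits);
// Feed-forward layer norm.
let d_with_feed_forward = feed_forward_norm_gradients.accumulate(
&d_encoded,
with_feed_forward_row.as_slice(),
feed_forward_norm.parameters(),
);
// Feed-forward sublayer, whose output is the residual branch.
let d_feed_forward_input = feed_forward_gradients.accumulate(
feed_forward,
feed_forward_cache,
&d_with_feed_forward,
);
// Residual connection: the skip path and the feed-forward path both start at
// the normalized attention output, so their gradients add.
let d_normalized_attention = add_rows(&d_with_feed_forward, &d_feed_forward_input);
// Attention layer norm.
let d_with_attention = attention_norm_gradients.accumulate(
&d_normalized_attention,
with_attention_row.as_slice(),
attention_norm.parameters(),
);
// Output projection: gradient with respect to the concatenated head outputs.
d_multi_head_rows[position] = output_projection_gradients.accumulate(
output_projection,
multi_head_row.as_slice(),
&d_with_attention,
);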
position_count += 1;
}
let value_dimension = block_cache.multi_head_output.head_dimension().value();
for (head_index, ((head_gradient, head), head_cache)) in attention_head_gradients
.iter_mut()
.zip(state.parameters.attention_heads())
.zip(&block_cache.attention_heads)
.enumerate()
{
let start = head_index * value_dimension;
let end = start + value_dimension;
let d_head_output_rows = d_multi_head_rows
.iter()
.map(|row| row[start..end].to_vec())
.collect::<Vec<_>>();
head_gradient.accumulate(
head,
&positioned,
head_cache,
example.mask(),
&d_head_output_rows,
)?;
}
}
let updated_readout =
readout_gradients.apply_to(readout, state.learning_rate(), position_count)?;
let updated_feed_forward =
feed_forward_gradients.apply_to(feed_forward, state.learning_rate(), position_count)?;
let updated_output_projection = output_projection_gradients.apply_to(
output_projection,
state.learning_rate(),
position_count,
)?;
let updated_attention_norm = LayerNormalization::new(attention_norm_gradients.apply_to(
attention_norm.parameters(),
state.learning_rate(),
position_count,
)?);
let updated_feed_forward_norm =
LayerNormalization::new(feed_forward_norm_gradients.apply_to(
feed_forward_norm.parameters(),
state.learning_rate(),
position_count,
)?);
let updated_heads = state
.parameters
.attention_heads()
.iter()
.zip(attention_head_gradients)
.map(|(head, gradients)| {
gradients.apply_to(head, state.learning_rate(), position_count)
})
.collect::<CtResult<Vec<_>>>()?;
let updated_parameters = state
.parameters
.clone()
.with_readout(updated_readout)?
.with_feed_forward(updated_feed_forward)?
.with_output_projection(updated_output_projection)?
.with_layer_norms(updated_attention_norm, updated_feed_forward_norm)?
.with_attention_heads(updated_heads)?;
Ok(state.record_updated_parameters(updated_parameters))
}
}
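// Sketch (hypothetical helper, not part of the crate): every `apply_to` method in this
// listing performs plain gradient descent on the *mean* gradient. The accumulators hold
// gradients summed over all positions, so each parameter moves by
// -(learning_rate / position_count) * summed_gradient. Plain `f32` slices stand in for
// the crate's parameter types, and `None` plays the role of `CtError::EmptyInput`.
#[allow(dead_code)]
mod sketch_mean_gradient_step {
    /// Returns the updated parameters, or `None` when no positions were seen.
    pub fn step(
        parameters: &[f32],
        summed_gradients: &[f32],
        learning_rate: f32,
        position_count: usize,
    ) -> Option<Vec<f32>> {
        if position_count == 0 {
            return None;
        }
        let scale = learning_rate / position_count as f32;
        Some(
            parameters
                .iter()
                .zip(summed_gradients)
                .map(|(parameter, gradient)| parameter - scale * gradient)
                .collect(),
        )
    }
}
// Example: summed gradients [4.0, -2.0] over 4 positions with learning rate 0.1 give a
// scale of 0.025, so parameters [1.0, -2.0] become approximately [0.9, -1.95].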
/// Average sequence cross-entropy for the structured Transformer readout.
pub fn transformer_readout_average_loss(
state: &TransformerTrainingState,
dataset: &TransformerReadoutTrainingSet,
) -> CtResult<Loss> {
let mut total = 0.0;
let mut position_count = 0usize;
for example in dataset.examples() {
let logits = state.apply(Product::new(
example.hidden().clone(),
example.mask().clone(),
))?;
let vocab_size = logits.vocab_size().value();
for (logit_row, target) in logits.rows().iter().zip(example.targets().as_slice()) {
let target_index = target.index();
if target_index >= vocab_size {
return Err(CtError::OutOfRange {
kind: "sequence target",
index: target_index,
limit: vocab_size,
});
}
let probabilities = Softmax.apply(logit_row.clone())?;
let probability = probabilities.as_slice()[target_index].max(1e-9);
total += -probability.ln();
position_count += 1;
}
}
if position_count == 0 {
return Err(CtError::EmptyInput("readout loss positions"));
}
Loss::new(total / position_count as f32)
}
/// Average squared error for local feed-forward sublayer training.
pub fn transformer_feed_forward_average_loss(
state: &TransformerTrainingState,
dataset: &TransformerFeedForwardTrainingSet,
) -> CtResult<Loss> {
let feed_forward = state.parameters.feed_forward();
let mut total = 0.0;
let mut value_count = 0usize;
for example in dataset.examples() {
let output = feed_forward.apply(example.input().clone())?;
for (output_row, target_row) in output.rows().iter().zip(example.target().rows()) {
for (output_value, target_value) in
output_row.as_slice().iter().zip(target_row.as_slice())
{
let error = output_value - target_value;
total += 0.5 * error * error;
value_count += 1;
}
}
}
if value_count == 0 {
return Err(CtError::EmptyInput("feed-forward loss values"));
}
Loss::new(total / value_count as f32)
}
/// Average sequence cross-entropy for the composed block-level training set.
pub fn transformer_block_average_loss(
state: &TransformerTrainingState,
dataset: &TransformerBlockTrainingSet,
) -> CtResult<Loss> {
let mut total = 0.0;
let mut position_count = 0usize;
for example in dataset.examples() {
let logits = state.apply(Product::new(
example.hidden().clone(),
example.mask().clone(),
))?;
let vocab_size = logits.vocab_size().value();
for (logit_row, target) in logits.rows().iter().zip(example.targets().as_slice()) {
let target_index = target.index();
if target_index >= vocab_size {
return Err(CtError::OutOfRange {
kind: "sequence target",
index: target_index,
limit: vocab_size,
});
}
let probabilities = Softmax.apply(logit_row.clone())?;
let probability = probabilities.as_slice()[target_index].max(1e-9);
total += -probability.ln();
position_count += 1;
}
}
if position_count == 0 {
return Err(CtError::EmptyInput("block loss positions"));
}
Loss::new(total / position_count as f32)
}
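// Sketch (hypothetical, not part of the crate): the averaging performed by
// `transformer_readout_average_loss` and `transformer_block_average_loss`, written on
// plain rows of logits. Each position contributes -ln(max(p_target, 1e-9)); the clamp
// keeps a vanishing probability from producing an infinite loss. A max-subtracted
// softmax is assumed here because the crate's `Softmax` body is not shown in this
// listing. Targets are assumed to be in range (the listing returns
// `CtError::OutOfRange` otherwise).
#[allow(dead_code)]
mod sketch_average_sequence_cross_entropy {
    /// Mean cross-entropy over positions, or `None` when there are no positions.
    pub fn average_loss(logit_rows: &[Vec<f32>], targets: &[usize]) -> Option<f32> {
        let mut total = 0.0_f32;
        let mut count = 0usize;
        for (row, &target) in logit_rows.iter().zip(targets) {
            let max = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let normalizer: f32 = row.iter().map(|logit| (logit - max).exp()).sum();
            let probability = ((row[target] - max).exp() / normalizer).max(1e-9);
            total += -probability.ln();
            count += 1;
        }
        (count > 0).then(|| total / count as f32)
    }
}
// With uniform logits over a vocabulary of 4, every position contributes ln 4 (about 1.386).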
/// Applies softmax independently to each query row.
#[derive(Debug, Clone)]
pub struct AttentionSoftmax;
impl Morphism<AttentionScores, AttentionWeights> for AttentionSoftmax {
fn name(&self) -> &'static str {
"attention_softmax"
}
fn apply(&self, scores: AttentionScores) -> CtResult<AttentionWeights> {
let rows = scores
.rows
.into_iter()
.map(|row| Softmax.apply(row))
.collect::<CtResult<Vec<_>>>()?;
AttentionWeights::new(rows)
}
}
/// Mixes value vectors with row-wise attention weights.
#[derive(Debug, Clone)]
pub struct WeightedValueMixing;
impl Morphism<Product<AttentionWeights, ValueSequence>, AttentionOutput> for WeightedValueMixing {
fn name(&self) -> &'static str {
"weighted_value_mixing"
}
fn apply(&self, input: Product<AttentionWeights, ValueSequence>) -> CtResult<AttentionOutput> {
let (weights, values) = input.into_parts();
if weights.key_len() != values.sequence_len() {
return Err(CtError::ShapeMismatch {
op: "weighted value mixing",
expected: format!("{} value rows", weights.key_len().value()),
got: format!("{} value rows", values.sequence_len().value()),
});
}
let value_width = values.head_dimension().value();
let rows = weights
.rows()
.iter()
.map(|weight_row| weighted_sum(weight_row.as_slice(), values.rows(), value_width))
.collect::<Vec<_>>();
AttentionOutput::new(rows)
}
}
/// Sums `values` weighted by one row of attention weights into a row of length `value_width`.
fn weighted_sum(weights: &[f32], values: &[Vector], value_width: usize) -> Vec<f32> {
let mut output = vec![0.0; value_width];
for (weight, value) in weights.iter().zip(values.iter()) {
for (output_value, value_component) in output.iter_mut().zip(value.as_slice()) {
*output_value += weight * value_component;
}
}
output
}
/// Concatenates single-head outputs along the feature dimension.
#[derive(Debug, Clone)]
pub struct ConcatenateHeads;
impl Morphism<AttentionHeadOutputs, MultiHeadOutput> for ConcatenateHeads {
fn name(&self) -> &'static str {
"concatenate_heads"
}
fn apply(&self, heads: AttentionHeadOutputs) -> CtResult<MultiHeadOutput> {
let rows = (0..heads.sequence_len().value())
.map(|position| {
heads
.heads()
.iter()
.flat_map(|head| head.rows()[position].as_slice().iter().copied())
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();
MultiHeadOutput::new(rows, heads.head_count(), heads.head_dimension())
}
}
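// Sketch (hypothetical, not part of the crate): `ConcatenateHeads` places head outputs
// side by side along the feature dimension, and the training step routes gradients back
// to head `h` by slicing columns `h * value_dimension .. (h + 1) * value_dimension`.
// With equal head widths, slicing a concatenated row recovers that head's row exactly.
// Nested `Vec<f32>` values stand in for `AttentionHeadOutputs` and `MultiHeadOutput`.
#[allow(dead_code)]
mod sketch_concatenate_heads {
    /// `heads[h][position]` is one row of head `h`; returns one concatenated row per position.
    pub fn concatenate(heads: &[Vec<Vec<f32>>]) -> Vec<Vec<f32>> {
        let sequence_len = heads.first().map_or(0, |head| head.len());
        (0..sequence_len)
            .map(|position| {
                heads
                    .iter()
                    .flat_map(|head| head[position].iter().copied())
                    .collect::<Vec<f32>>()
            })
            .collect()
    }

    /// Recovers head `head_index` from concatenated rows, as the training step does.
    pub fn slice_head(rows: &[Vec<f32>], head_index: usize, head_dimension: usize) -> Vec<Vec<f32>> {
        let start = head_index * head_dimension;
        let end = start + head_dimension;
        rows.iter().map(|row| row[start..end].to_vec()).collect()
    }
}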
impl Morphism<MultiHeadOutput, ProjectedAttentionOutput> for AttentionOutputProjection {
fn name(&self) -> &'static str {
"attention_output_projection"
}
fn apply(&self, input: MultiHeadOutput) -> CtResult<ProjectedAttentionOutput> {
if input.model_dimension() != self.input_dimension {
return Err(CtError::ShapeMismatch {
op: "attention output projection",
expected: format!("input dimension {}", self.input_dimension.value()),
got: format!("input dimension {}", input.model_dimension().value()),
});
}
let rows = input
.rows()
.iter()
.map(|row| project_row(row.as_slice(), &self.weight, &self.bias))
.collect::<Vec<_>>();
ProjectedAttentionOutput::new(rows)
}
}
/// Adds a same-shaped sublayer output back to the hidden sequence.
#[derive(Debug, Clone)]
pub struct ResidualConnection;
impl Morphism<Product<HiddenSequence, ProjectedAttentionOutput>, HiddenSequence>
for ResidualConnection
{
fn name(&self) -> &'static str {
"residual_connection"
}
fn apply(
&self,
input: Product<HiddenSequence, ProjectedAttentionOutput>,
) -> CtResult<HiddenSequence> {
let (hidden, sublayer_output) = input.into_parts();
if hidden.sequence_len() != sublayer_output.sequence_len() {
return Err(CtError::ShapeMismatch {
op: "residual connection",
expected: format!("{} sequence rows", hidden.sequence_len().value()),
got: format!("{} sequence rows", sublayer_output.sequence_len().value()),
});
}
if hidden.model_dimension() != sublayer_output.model_dimension() {
return Err(CtError::ShapeMismatch {
op: "residual connection",
expected: format!("model dimension {}", hidden.model_dimension().value()),
got: format!(
"model dimension {}",
sublayer_output.model_dimension().value()
),
});
}
let rows = hidden
.rows()
.iter()
.zip(sublayer_output.rows())
.map(|(left, right)| add_rows(left.as_slice(), right.as_slice()))
.collect::<Vec<_>>();
HiddenSequence::new(rows)
}
}
impl Morphism<HiddenSequence, HiddenSequence> for LayerNormalization {
fn name(&self) -> &'static str {
"layer_normalization"
}
fn apply(&self, input: HiddenSequence) -> CtResult<HiddenSequence> {
if input.model_dimension() != self.parameters.model_dimension() {
return Err(CtError::ShapeMismatch {
op: "layer normalization",
expected: format!(
"model dimension {}",
self.parameters.model_dimension().value()
),
got: format!("model dimension {}", input.model_dimension().value()),
});
}
let rows = input
.rows()
.iter()
.map(|row| normalize_row(row.as_slice(), &self.parameters))
.collect::<Vec<_>>();
HiddenSequence::new(rows)
}
}
impl Morphism<HiddenSequence, HiddenSequence> for PositionWiseFeedForward {
fn name(&self) -> &'static str {
"position_wise_feed_forward"
}
fn apply(&self, input: HiddenSequence) -> CtResult<HiddenSequence> {
let (output, _cache) = feed_forward_with_cache(self, &input)?;
Ok(output)
}
}
impl Morphism<HiddenSequence, HiddenSequence> for PositionalEncoding {
fn name(&self) -> &'static str {
"positional_encoding"
}
fn apply(&self, input: HiddenSequence) -> CtResult<HiddenSequence> {
if input.sequence_len().value() > self.max_sequence_len.value() {
return Err(CtError::ShapeMismatch {
op: "positional encoding",
expected: format!("at most {} sequence rows", self.max_sequence_len.value()),
got: format!("{} sequence rows", input.sequence_len().value()),
});
}
if input.model_dimension() != self.model_dimension {
return Err(CtError::ShapeMismatch {
op: "positional encoding",
expected: format!("model dimension {}", self.model_dimension.value()),
got: format!("model dimension {}", input.model_dimension().value()),
});
}
let rows = input
.rows()
.iter()
.zip(&self.rows)
.map(|(hidden_row, position_row)| {
add_rows(hidden_row.as_slice(), position_row.as_slice())
})
.collect::<Vec<_>>();
HiddenSequence::new(rows)
}
}
impl Morphism<HiddenSequence, SequenceLogits> for TransformerReadout {
fn name(&self) -> &'static str {
"transformer_readout"
}
fn apply(&self, input: HiddenSequence) -> CtResult<SequenceLogits> {
if input.model_dimension() != self.input_dimension {
return Err(CtError::ShapeMismatch {
op: "transformer readout",
expected: format!("input dimension {}", self.input_dimension.value()),
got: format!("input dimension {}", input.model_dimension().value()),
});
}
let rows = input
.rows()
.iter()
.map(|row| project_row(row.as_slice(), &self.weight, &self.bias))
.collect::<Vec<_>>();
SequenceLogits::new(rows)
}
}
impl Morphism<HiddenSequence, QuerySequence> for HiddenToQuery {
fn name(&self) -> &'static str {
"hidden_to_query"
}
fn apply(&self, input: HiddenSequence) -> CtResult<QuerySequence> {
QuerySequence::new(self.projection.project(&input)?)
}
}
impl Morphism<HiddenSequence, KeySequence> for HiddenToKey {
fn name(&self) -> &'static str {
"hidden_to_key"
}
fn apply(&self, input: HiddenSequence) -> CtResult<KeySequence> {
KeySequence::new(self.projection.project(&input)?)
}
}
impl Morphism<HiddenSequence, ValueSequence> for HiddenToValue {
fn name(&self) -> &'static str {
"hidden_to_value"
}
fn apply(&self, input: HiddenSequence) -> CtResult<ValueSequence> {
ValueSequence::new(self.projection.project(&input)?)
}
}
/// Runs one self-attention head over the whole sequence without a mask.
fn apply_self_attention_head(
input: &HiddenSequence,
head: &SelfAttentionHead,
) -> CtResult<AttentionOutput> {
apply_self_attention_head_with_mask(input, head, None)
}
/// Runs one self-attention head, masking the scores when a mask is given.
fn apply_self_attention_head_with_mask(
input: &HiddenSequence,
head: &SelfAttentionHead,
mask: Option<&AttentionMask>,
) -> CtResult<AttentionOutput> {
Ok(apply_self_attention_head_with_mask_cache(input, head, mask)?.output)
}
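/// Runs one self-attention head and keeps the queries, keys, values, weights, and output
/// needed for backpropagation.
fn apply_self_attention_head_with_mask_cache(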
input: &HiddenSequence,
head: &SelfAttentionHead,
mask: Option<&AttentionMask>,
) -> CtResult<AttentionHeadTrainingCache> {
let queries = head.query_projection.apply(input.clone())?;
let keys = head.key_projection.apply(input.clone())?;
let values = head.value_projection.apply(input.clone())?;
let scores = ScaledDotProductScores.apply(Product::new(queries.clone(), keys.clone()))?;
let scores = if let Some(mask) = mask {
MaskedAttentionScores.apply(Product::new(scores, mask.clone()))?
} else {
scores
};
let weights = AttentionSoftmax.apply(scores)?;
let output = WeightedValueMixing.apply(Product::new(weights.clone(), values.clone()))?;
Ok(AttentionHeadTrainingCache {
queries,
keys,
values,
weights,
output,
})
}
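// Sketch (hypothetical, not part of the crate): the forward path composed by
// `apply_self_attention_head_with_mask_cache`, written on plain rows and assuming the
// query, key, and value projections have already been applied. Scores are
// q . k / sqrt(d_k); keys with `mask[query][key] == false` are left out of the softmax,
// which matches a score of negative infinity (the crate's `MaskedAttentionScores` may
// realize this differently); each output row is the weighted sum of value rows, as in
// `weighted_sum`. Every query row is assumed to allow at least one key, as a causal
// mask does.
#[allow(dead_code)]
mod sketch_masked_attention_head {
    pub fn attend(
        queries: &[Vec<f32>],
        keys: &[Vec<f32>],
        values: &[Vec<f32>],
        mask: &[Vec<bool>],
    ) -> Vec<Vec<f32>> {
        let scale = (keys[0].len() as f32).sqrt();
        let value_width = values[0].len();
        queries
            .iter()
            .zip(mask)
            .map(|(query, mask_row)| {
                // Scaled dot-product score for each allowed key; `None` marks a masked key.
                let scores: Vec<Option<f32>> = keys
                    .iter()
                    .zip(mask_row)
                    .map(|(key, &allowed)| {
                        let dot: f32 = query.iter().zip(key).map(|(q, k)| q * k).sum();
                        allowed.then_some(dot / scale)
                    })
                    .collect();
                // Max-subtracted softmax over the allowed keys only.
                let max = scores.iter().flatten().copied().fold(f32::NEG_INFINITY, f32::max);
                let exps: Vec<f32> = scores
                    .iter()
                    .map(|&score| score.map_or(0.0, |s| (s - max).exp()))
                    .collect();
                let total: f32 = exps.iter().sum();
                // Weighted value mixing.
                let mut output = vec![0.0_f32; value_width];
                for (exp, value) in exps.iter().zip(values) {
                    for (out, component) in output.iter_mut().zip(value) {
                        *out += exp / total * component;
                    }
                }
                output
            })
            .collect()
    }
}
// Under a causal mask the first query can only see the first key, so its output row is
// exactly the first value row.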
impl Morphism<Product<HiddenSequence, HiddenSequence>, HiddenSequence> for ResidualConnection {
fn name(&self) -> &'static str {
"hidden_residual_connection"
}
fn apply(&self, input: Product<HiddenSequence, HiddenSequence>) -> CtResult<HiddenSequence> {
let (left, right) = input.into_parts();
if left.sequence_len() != right.sequence_len() {
return Err(CtError::ShapeMismatch {
op: "hidden residual connection",
expected: format!("{} sequence rows", left.sequence_len().value()),
got: format!("{} sequence rows", right.sequence_len().value()),
});
}
if left.model_dimension() != right.model_dimension() {
return Err(CtError::ShapeMismatch {
op: "hidden residual connection",
expected: format!("model dimension {}", left.model_dimension().value()),
got: format!("model dimension {}", right.model_dimension().value()),
});
}
let rows = left
.rows()
.iter()
.zip(right.rows())
.map(|(left, right)| add_rows(left.as_slice(), right.as_slice()))
.collect::<Vec<_>>();
HiddenSequence::new(rows)
}
}
impl Morphism<HiddenSequence, HiddenSequence> for SingleHeadTransformerBlock {
fn name(&self) -> &'static str {
"single_head_transformer_block"
}
fn apply(&self, input: HiddenSequence) -> CtResult<HiddenSequence> {
if input.model_dimension() != self.model_dimension {
return Err(CtError::ShapeMismatch {
op: "single-head block",
expected: format!("model dimension {}", self.model_dimension.value()),
got: format!("model dimension {}", input.model_dimension().value()),
});
}
let head = SelfAttentionHead::new(
self.query_projection.clone(),
self.key_projection.clone(),
self.value_projection.clone(),
)?;
let attention_output = apply_self_attention_head(&input, &head)?;
let head_outputs = AttentionHeadOutputs::new(vec![attention_output])?;
let multi_head_output = ConcatenateHeads.apply(head_outputs)?;
let projected_attention = self.output_projection.apply(multi_head_output)?;
let with_attention = ResidualConnection.apply(Product::new(input, projected_attention))?;
let normalized_attention = self.attention_norm.apply(with_attention)?;
let feed_forward_output = self.feed_forward.apply(normalized_attention.clone())?;
let with_feed_forward =
ResidualConnection.apply(Product::new(normalized_attention, feed_forward_output))?;
self.feed_forward_norm.apply(with_feed_forward)
}
}
impl Morphism<HiddenSequence, HiddenSequence> for MultiHeadTransformerBlock {
fn name(&self) -> &'static str {
"multi_head_transformer_block"
}
fn apply(&self, input: HiddenSequence) -> CtResult<HiddenSequence> {
if input.model_dimension() != self.model_dimension {
return Err(CtError::ShapeMismatch {
op: "multi-head block",
expected: format!("model dimension {}", self.model_dimension.value()),
got: format!("model dimension {}", input.model_dimension().value()),
});
}
let attention_outputs = self
.heads
.iter()
.map(|head| apply_self_attention_head(&input, head))
.collect::<CtResult<Vec<_>>>()?;
let head_outputs = AttentionHeadOutputs::new(attention_outputs)?;
let multi_head_output = ConcatenateHeads.apply(head_outputs)?;
let projected_attention = self.output_projection.apply(multi_head_output)?;
let with_attention = ResidualConnection.apply(Product::new(input, projected_attention))?;
let normalized_attention = self.attention_norm.apply(with_attention)?;
let feed_forward_output = self.feed_forward.apply(normalized_attention.clone())?;
let with_feed_forward =
ResidualConnection.apply(Product::new(normalized_attention, feed_forward_output))?;
self.feed_forward_norm.apply(with_feed_forward)
}
}
impl Morphism<Product<HiddenSequence, AttentionMask>, HiddenSequence>
for MaskedMultiHeadTransformerBlock
{
fn name(&self) -> &'static str {
"masked_multi_head_transformer_block"
}
fn apply(&self, input: Product<HiddenSequence, AttentionMask>) -> CtResult<HiddenSequence> {
let (hidden, mask) = input.into_parts();
Ok(self.apply_with_training_cache(hidden, mask)?.output)
}
}
impl Morphism<Product<HiddenSequence, AttentionMask>, SequenceLogits>
for TinyTransformerParameters
{
fn name(&self) -> &'static str {
"tiny_transformer_parameters"
}
fn apply(&self, input: Product<HiddenSequence, AttentionMask>) -> CtResult<SequenceLogits> {
let (hidden, mask) = input.into_parts();
let positioned = self.positional_encoding.apply(hidden)?;
let encoded = self.block.apply(Product::new(positioned, mask))?;
self.readout.apply(encoded)
}
}
impl Morphism<Product<HiddenSequence, AttentionMask>, SequenceLogits> for TransformerTrainingState {
fn name(&self) -> &'static str {
"transformer_training_state_forward"
}
fn apply(&self, input: Product<HiddenSequence, AttentionMask>) -> CtResult<SequenceLogits> {
self.parameters.apply(input)
}
}
/// Readout weight and bias gradients, summed over sequence positions.
#[derive(Debug, Clone, PartialEq)]
struct ReadoutGradients {
weight: Vec<Vec<f32>>,
bias: Vec<f32>,
}
impl ReadoutGradients {
fn new(readout: &TransformerReadout) -> Self {
Self {
weight: vec![
vec![0.0; readout.vocab_size().value()];
readout.input_dimension().value()
],
bias: vec![0.0; readout.vocab_size().value()],
}
}
fn accumulate(
&mut self,
readout: &TransformerReadout,
hidden_row: &[f32],
dlogits: &[f32],
) -> Vec<f32> {
let mut d_hidden = vec![0.0; readout.input_dimension().value()];
for (vocab_id, dlogit) in dlogits.iter().copied().enumerate() {
self.bias[vocab_id] += dlogit;
for (feature, hidden_value) in hidden_row.iter().copied().enumerate() {
self.weight[feature][vocab_id] += hidden_value * dlogit;
d_hidden[feature] += readout.weight()[feature][vocab_id] * dlogit;
}
}
d_hidden
}
fn apply_to(
self,
readout: &TransformerReadout,
learning_rate: LearningRate,
position_count: usize,
) -> CtResult<TransformerReadout> {
if position_count == 0 {
return Err(CtError::EmptyInput("readout gradient positions"));
}
let scale = learning_rate.value() / position_count as f32;
let mut updated_weight = readout.weight().to_vec();
let mut updated_bias = readout.bias().to_vec();
for (row, grad_row) in updated_weight.iter_mut().zip(&self.weight) {
for (weight, grad) in row.iter_mut().zip(grad_row) {
*weight -= scale * grad;
}
}
for (bias, grad) in updated_bias.iter_mut().zip(&self.bias) {
*bias -= scale * grad;
}
TransformerReadout::new(updated_weight, updated_bias)
}
}
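// Sketch (hypothetical, not part of the crate): the affine backward rule shared by
// `ReadoutGradients`, `AttentionOutputProjectionGradients`, and
// `HiddenProjectionGradients`. With the `weight[input][output]` layout used in this
// listing, y[o] = b[o] + sum_i x[i] * W[i][o]; for an upstream gradient dy,
// dW[i][o] = x[i] * dy[o], db[o] = dy[o], and dx[i] = sum_o W[i][o] * dy[o].
// `affine` is a stand-in for `project_row`, whose body is not part of this listing.
// The check uses L = 0.5 * |y|^2, so dy = y, and compares against finite differences.
#[allow(dead_code)]
mod sketch_affine_backward {
    pub fn affine(x: &[f64], weight: &[Vec<f64>], bias: &[f64]) -> Vec<f64> {
        let mut y = bias.to_vec();
        for (x_i, row) in x.iter().zip(weight) {
            for (y_o, w) in y.iter_mut().zip(row) {
                *y_o += x_i * w;
            }
        }
        y
    }

    pub fn loss(x: &[f64], weight: &[Vec<f64>], bias: &[f64]) -> f64 {
        affine(x, weight, bias).iter().map(|y| 0.5 * y * y).sum()
    }

    /// Returns (dW, db, dx) for L = 0.5 * |y|^2.
    pub fn backward(
        x: &[f64],
        weight: &[Vec<f64>],
        bias: &[f64],
    ) -> (Vec<Vec<f64>>, Vec<f64>, Vec<f64>) {
        let d_output = affine(x, weight, bias);
        let d_weight: Vec<Vec<f64>> = x
            .iter()
            .map(|x_i| d_output.iter().map(|d| x_i * d).collect::<Vec<f64>>())
            .collect();
        let d_input: Vec<f64> = weight
            .iter()
            .map(|row| row.iter().zip(&d_output).map(|(w, d)| w * d).sum::<f64>())
            .collect();
        (d_weight, d_output, d_input)
    }
}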
/// Gradients for both affine layers of the position-wise feed-forward sublayer, summed over rows.
#[derive(Debug, Clone, PartialEq)]
struct FeedForwardGradients {
first_weight: Vec<Vec<f32>>,
first_bias: Vec<f32>,
second_weight: Vec<Vec<f32>>,
second_bias: Vec<f32>,
}
impl FeedForwardGradients {
fn new(feed_forward: &PositionWiseFeedForward) -> Self {
Self {
first_weight: vec![
vec![0.0; feed_forward.hidden_dimension().value()];
feed_forward.input_dimension().value()
],
first_bias: vec![0.0; feed_forward.hidden_dimension().value()],
second_weight: vec![
vec![0.0; feed_forward.output_dimension().value()];
feed_forward.hidden_dimension().value()
],
second_bias: vec![0.0; feed_forward.output_dimension().value()],
}
}
fn accumulate(
&mut self,
feed_forward: &PositionWiseFeedForward,
cache: &FeedForwardRowCache,
d_output: &[f32],
) -> Vec<f32> {
let mut d_activation = vec![0.0; feed_forward.hidden_dimension().value()];
for (output_id, d_output_value) in d_output.iter().copied().enumerate() {
self.second_bias[output_id] += d_output_value;
for (hidden_id, activation_value) in cache.activation.iter().copied().enumerate() {
self.second_weight[hidden_id][output_id] += activation_value * d_output_value;
d_activation[hidden_id] +=
feed_forward.second_weight()[hidden_id][output_id] * d_output_value;
}
}
let mut d_input = vec![0.0; feed_forward.input_dimension().value()];
for (hidden_id, pre_activation_value) in cache.pre_activation.iter().copied().enumerate() {
let d_pre_activation = if pre_activation_value > 0.0 {
d_activation[hidden_id]
} else {
0.0
};
self.first_bias[hidden_id] += d_pre_activation;
for (input_id, input_value) in cache.input.iter().copied().enumerate() {
self.first_weight[input_id][hidden_id] += input_value * d_pre_activation;
d_input[input_id] +=
feed_forward.first_weight()[input_id][hidden_id] * d_pre_activation;
}
}
d_input
}
fn apply_to(
self,
feed_forward: &PositionWiseFeedForward,
learning_rate: LearningRate,
row_count: usize,
) -> CtResult<PositionWiseFeedForward> {
if row_count == 0 {
return Err(CtError::EmptyInput("feed-forward gradient rows"));
}
let scale = learning_rate.value() / row_count as f32;
let mut updated_first_weight = feed_forward.first_weight().to_vec();
let mut updated_first_bias = feed_forward.first_bias().to_vec();
let mut updated_second_weight = feed_forward.second_weight().to_vec();
let mut updated_second_bias = feed_forward.second_bias().to_vec();
for (row, grad_row) in updated_first_weight.iter_mut().zip(&self.first_weight) {
for (weight, grad) in row.iter_mut().zip(grad_row) {
*weight -= scale * grad;
}
}
for (bias, grad) in updated_first_bias.iter_mut().zip(&self.first_bias) {
*bias -= scale * grad;
}
for (row, grad_row) in updated_second_weight.iter_mut().zip(&self.second_weight) {
for (weight, grad) in row.iter_mut().zip(grad_row) {
*weight -= scale * grad;
}
}
for (bias, grad) in updated_second_bias.iter_mut().zip(&self.second_bias) {
*bias -= scale * grad;
}
PositionWiseFeedForward::new(
updated_first_weight,
updated_first_bias,
updated_second_weight,
updated_second_bias,
)
}
}
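// Sketch (hypothetical, not part of the crate): the input-gradient path of
// `FeedForwardGradients::accumulate` for y = affine2(relu(affine1(x))), with weights in
// the `weight[input][output]` layout. The ReLU passes the upstream gradient only where
// the pre-activation is positive, which is why the cache keeps `pre_activation`. The
// check uses L = 0.5 * |y|^2 (so dy = y) and compares dL/dx with central finite
// differences at an input where no pre-activation sits at zero.
#[allow(dead_code)]
mod sketch_feed_forward_backward {
    pub struct FeedForward {
        pub first_weight: Vec<Vec<f64>>,
        pub first_bias: Vec<f64>,
        pub second_weight: Vec<Vec<f64>>,
        pub second_bias: Vec<f64>,
    }

    fn affine(x: &[f64], weight: &[Vec<f64>], bias: &[f64]) -> Vec<f64> {
        let mut y = bias.to_vec();
        for (x_i, row) in x.iter().zip(weight) {
            for (y_o, w) in y.iter_mut().zip(row) {
                *y_o += x_i * w;
            }
        }
        y
    }

    /// Returns (pre_activation, output).
    pub fn forward(ff: &FeedForward, x: &[f64]) -> (Vec<f64>, Vec<f64>) {
        let pre_activation = affine(x, &ff.first_weight, &ff.first_bias);
        let activation: Vec<f64> = pre_activation.iter().map(|v| v.max(0.0)).collect();
        let output = affine(&activation, &ff.second_weight, &ff.second_bias);
        (pre_activation, output)
    }

    pub fn loss(ff: &FeedForward, x: &[f64]) -> f64 {
        forward(ff, x).1.iter().map(|y| 0.5 * y * y).sum()
    }

    /// dL/dx for L = 0.5 * |y|^2.
    pub fn input_gradient(ff: &FeedForward, x: &[f64]) -> Vec<f64> {
        let (pre_activation, d_output) = forward(ff, x);
        // Back through the second affine map: d_activation[h] = sum_o W2[h][o] * dy[o].
        let d_activation: Vec<f64> = ff
            .second_weight
            .iter()
            .map(|row| row.iter().zip(&d_output).map(|(w, d)| w * d).sum::<f64>())
            .collect();
        // ReLU gate: zero gradient where the unit was inactive.
        let d_pre: Vec<f64> = pre_activation
            .iter()
            .zip(&d_activation)
            .map(|(p, d)| if *p > 0.0 { *d } else { 0.0 })
            .collect();
        // Back through the first affine map.
        ff.first_weight
            .iter()
            .map(|row| row.iter().zip(&d_pre).map(|(w, d)| w * d).sum::<f64>())
            .collect()
    }
}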
/// Attention output projection weight and bias gradients, summed over rows.
#[derive(Debug, Clone, PartialEq)]
struct AttentionOutputProjectionGradients {
weight: Vec<Vec<f32>>,
bias: Vec<f32>,
}
impl AttentionOutputProjectionGradients {
fn new(output_projection: &AttentionOutputProjection) -> Self {
Self {
weight: vec![
vec![0.0; output_projection.output_dimension().value()];
output_projection.input_dimension().value()
],
bias: vec![0.0; output_projection.output_dimension().value()],
}
}
fn accumulate(
&mut self,
output_projection: &AttentionOutputProjection,
multi_head_row: &[f32],
d_projected_attention: &[f32],
) -> Vec<f32> {
let mut d_multi_head = vec![0.0; output_projection.input_dimension().value()];
for (output_id, d_value) in d_projected_attention.iter().copied().enumerate() {
self.bias[output_id] += d_value;
for (input_id, input_value) in multi_head_row.iter().copied().enumerate() {
self.weight[input_id][output_id] += input_value * d_value;
d_multi_head[input_id] += output_projection.weight()[input_id][output_id] * d_value;
}
}
d_multi_head
}
fn apply_to(
self,
output_projection: &AttentionOutputProjection,
learning_rate: LearningRate,
row_count: usize,
) -> CtResult<AttentionOutputProjection> {
if row_count == 0 {
return Err(CtError::EmptyInput(
"attention output projection gradient rows",
));
}
let scale = learning_rate.value() / row_count as f32;
let mut updated_weight = output_projection.weight().to_vec();
let mut updated_bias = output_projection.bias().to_vec();
for (row, grad_row) in updated_weight.iter_mut().zip(&self.weight) {
for (weight, grad) in row.iter_mut().zip(grad_row) {
*weight -= scale * grad;
}
}
for (bias, grad) in updated_bias.iter_mut().zip(&self.bias) {
*bias -= scale * grad;
}
AttentionOutputProjection::new(updated_weight, updated_bias)
}
}
/// Gradients for one query, key, or value projection, summed over rows.
#[derive(Debug, Clone, PartialEq)]
struct HiddenProjectionGradients {
weight: Vec<Vec<f32>>,
bias: Vec<f32>,
}
impl HiddenProjectionGradients {
fn new(projection: &HiddenProjection) -> Self {
Self {
weight: vec![
vec![0.0; projection.head_dimension().value()];
projection.input_dimension().value()
],
bias: vec![0.0; projection.head_dimension().value()],
}
}
fn accumulate(
&mut self,
projection: &HiddenProjection,
hidden_row: &[f32],
d_output: &[f32],
) -> CtResult<()> {
if hidden_row.len() != projection.input_dimension().value() {
return Err(CtError::ShapeMismatch {
op: projection.op,
expected: format!("input dimension {}", projection.input_dimension().value()),
got: format!("input dimension {}", hidden_row.len()),
});
}
if d_output.len() != projection.head_dimension().value() {
return Err(CtError::ShapeMismatch {
op: projection.op,
expected: format!("output dimension {}", projection.head_dimension().value()),
got: format!("output dimension {}", d_output.len()),
});
}
for (output_id, d_output_value) in d_output.iter().copied().enumerate() {
self.bias[output_id] += d_output_value;
for (input_id, input_value) in hidden_row.iter().copied().enumerate() {
self.weight[input_id][output_id] += input_value * d_output_value;
}
}
Ok(())
}
fn updated_parts(
self,
projection: &HiddenProjection,
learning_rate: LearningRate,
row_count: usize,
) -> CtResult<(Vec<Vec<f32>>, Vec<f32>)> {
if row_count == 0 {
return Err(CtError::EmptyInput("hidden projection gradient rows"));
}
let scale = learning_rate.value() / row_count as f32;
let mut updated_weight = projection.weight().to_vec();
let mut updated_bias = projection.bias().to_vec();
for (row, grad_row) in updated_weight.iter_mut().zip(&self.weight) {
for (weight, grad) in row.iter_mut().zip(grad_row) {
*weight -= scale * grad;
}
}
for (bias, grad) in updated_bias.iter_mut().zip(&self.bias) {
*bias -= scale * grad;
}
Ok((updated_weight, updated_bias))
}
}
/// Query, key, and value projection gradients for one attention head.
#[derive(Debug, Clone, PartialEq)]
struct AttentionHeadGradients {
query: HiddenProjectionGradients,
key: HiddenProjectionGradients,
value: HiddenProjectionGradients,
}
impl AttentionHeadGradients {
fn new(head: &SelfAttentionHead) -> Self {
Self {
query: HiddenProjectionGradients::new(&head.query_projection.projection),
key: HiddenProjectionGradients::new(&head.key_projection.projection),
value: HiddenProjectionGradients::new(&head.value_projection.projection),
}
}
fn accumulate(
&mut self,
head: &SelfAttentionHead,
input: &HiddenSequence,
cache: &AttentionHeadTrainingCache,
mask: &AttentionMask,
d_output_rows: &[Vec<f32>],
) -> CtResult<()> {
let sequence_len = input.sequence_len().value();
let value_dimension = head.value_dimension().value();
let query_key_dimension = head.query_key_dimension().value();
if d_output_rows.len() != sequence_len {
return Err(CtError::ShapeMismatch {
op: "attention head gradients",
expected: format!("{sequence_len} output rows"),
got: format!("{} output rows", d_output_rows.len()),
});
}
if mask.query_len().value() != sequence_len || mask.key_len().value() != sequence_len {
return Err(CtError::ShapeMismatch {
op: "attention head gradient mask",
expected: format!("{sequence_len} query rows x {sequence_len} key columns"),
got: format!(
"{} query rows x {} key columns",
mask.query_len().value(),
mask.key_len().value()
),
});
}
let mut d_weights = vec![vec![0.0; sequence_len]; sequence_len];
let mut d_values = vec![vec![0.0; value_dimension]; sequence_len];
for (query_id, d_output) in d_output_rows.iter().enumerate() {
if d_output.len() != value_dimension {
return Err(CtError::ShapeMismatch {
op: "attention head output gradient",
expected: format!("value dimension {value_dimension}"),
got: format!("value dimension {}", d_output.len()),
});
}
for key_id in 0..sequence_len {
let value_row = cache.values.rows()[key_id].as_slice();
let weight = cache.weights.rows()[query_id].as_slice()[key_id];
for value_id in 0..value_dimension {
d_weights[query_id][key_id] += d_output[value_id] * value_row[value_id];
d_values[key_id][value_id] += weight * d_output[value_id];
}
}
}
let mut d_scores = vec![vec![0.0; sequence_len]; sequence_len];
for query_id in 0..sequence_len {
let weight_row = cache.weights.rows()[query_id].as_slice();
let row_dot = d_weights[query_id]
.iter()
.zip(weight_row)
.map(|(grad, weight)| grad * weight)
.sum::<f32>();
for key_id in 0..sequence_len {
if mask.rows()[query_id][key_id] {
d_scores[query_id][key_id] =
weight_row[key_id] * (d_weights[query_id][key_id] - row_dot);
}
}
}
let score_scale = (query_key_dimension as f32).sqrt();
let mut d_queries = vec![vec![0.0; query_key_dimension]; sequence_len];
let mut d_keys = vec![vec![0.0; query_key_dimension]; sequence_len];
for query_id in 0..sequence_len {
let query_row = cache.queries.rows()[query_id].as_slice();
for key_id in 0..sequence_len {
let score_gradient = d_scores[query_id][key_id] / score_scale;
let key_row = cache.keys.rows()[key_id].as_slice();
for feature in 0..query_key_dimension {
d_queries[query_id][feature] += score_gradient * key_row[feature];
d_keys[key_id][feature] += score_gradient * query_row[feature];
}
}
}
for position in 0..sequence_len {
let hidden_row = input.rows()[position].as_slice();
self.query.accumulate(
&head.query_projection.projection,
hidden_row,
&d_queries[position],
)?;
self.key.accumulate(
&head.key_projection.projection,
hidden_row,
&d_keys[position],
)?;
self.value.accumulate(
&head.value_projection.projection,
hidden_row,
&d_values[position],
)?;
}
Ok(())
}
fn apply_to(
self,
head: &SelfAttentionHead,
learning_rate: LearningRate,
row_count: usize,
) -> CtResult<SelfAttentionHead> {
let (query_weight, query_bias) = self.query.updated_parts(
&head.query_projection.projection,
learning_rate,
row_count,
)?;
let (key_weight, key_bias) =
self.key
.updated_parts(&head.key_projection.projection, learning_rate, row_count)?;
let (value_weight, value_bias) = self.value.updated_parts(
&head.value_projection.projection,
learning_rate,
row_count,
)?;
SelfAttentionHead::new(
HiddenToQuery::new(query_weight, query_bias)?,
HiddenToKey::new(key_weight, key_bias)?,
HiddenToValue::new(value_weight, value_bias)?,
)
}
}
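// Sketch (hypothetical, not part of the crate): the row-wise softmax backward rule used
// in `AttentionHeadGradients::accumulate`. If w = softmax(s) over the unmasked keys and
// the upstream gradient with respect to w is g, then
// dL/ds[k] = w[k] * (g[k] - sum_j g[j] * w[j]), and masked keys receive zero gradient.
// The check compares this rule with central finite differences of L = sum_k g[k] * w[k].
#[allow(dead_code)]
mod sketch_attention_softmax_backward {
    /// Softmax restricted to keys where `mask[k]` is true; masked keys get weight 0.
    pub fn masked_softmax(scores: &[f64], mask: &[bool]) -> Vec<f64> {
        let max = scores
            .iter()
            .zip(mask)
            .filter(|(_, allowed)| **allowed)
            .map(|(s, _)| *s)
            .fold(f64::NEG_INFINITY, f64::max);
        let exps: Vec<f64> = scores
            .iter()
            .zip(mask)
            .map(|(s, allowed)| if *allowed { (s - max).exp() } else { 0.0 })
            .collect();
        let sum: f64 = exps.iter().sum();
        exps.iter().map(|e| e / sum).collect()
    }

    /// Analytic backward rule, mirroring the `row_dot` computation in the listing.
    pub fn score_gradient(weights: &[f64], d_weights: &[f64], mask: &[bool]) -> Vec<f64> {
        let row_dot: f64 = d_weights.iter().zip(weights).map(|(g, w)| g * w).sum();
        weights
            .iter()
            .zip(d_weights)
            .zip(mask)
            .map(|((w, g), allowed)| if *allowed { w * (g - row_dot) } else { 0.0 })
            .collect()
    }

    /// Central finite differences of L(s) = sum_k g[k] * masked_softmax(s)[k].
    pub fn numeric_score_gradient(scores: &[f64], d_weights: &[f64], mask: &[bool]) -> Vec<f64> {
        let loss = |s: &[f64]| -> f64 {
            masked_softmax(s, mask).iter().zip(d_weights).map(|(w, g)| w * g).sum()
        };
        let step = 1e-6;
        (0..scores.len())
            .map(|k| {
                let mut plus = scores.to_vec();
                let mut minus = scores.to_vec();
                plus[k] += step;
                minus[k] -= step;
                (loss(&plus[..]) - loss(&minus[..])) / (2.0 * step)
            })
            .collect()
    }
}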
/// Layer-norm scale and shift gradients, summed over rows.
#[derive(Debug, Clone, PartialEq)]
struct LayerNormGradients {
scale: Vec<f32>,
shift: Vec<f32>,
}
impl LayerNormGradients {
fn new(parameters: &LayerNormParameters) -> Self {
Self {
scale: vec![0.0; parameters.model_dimension().value()],
shift: vec![0.0; parameters.model_dimension().value()],
}
}
fn accumulate(
&mut self,
d_output: &[f32],
input: &[f32],
parameters: &LayerNormParameters,
) -> Vec<f32> {
let stats = layer_norm_stats(input, parameters);
for (feature, (grad, normalized_value)) in
d_output.iter().zip(&stats.normalized).enumerate()
{
self.scale[feature] += grad * normalized_value;
self.shift[feature] += grad;
}
layer_norm_backward_from_stats(d_output, &stats, parameters)
}
fn apply_to(
self,
parameters: &LayerNormParameters,
learning_rate: LearningRate,
row_count: usize,
) -> CtResult<LayerNormParameters> {
if row_count == 0 {
return Err(CtError::EmptyInput("layer norm gradient rows"));
}
let scale = learning_rate.value() / row_count as f32;
let mut updated_scale = parameters.scale().to_vec();
let mut updated_shift = parameters.shift().to_vec();
for (value, grad) in updated_scale.iter_mut().zip(&self.scale) {
*value -= scale * grad;
}
for (value, grad) in updated_shift.iter_mut().zip(&self.shift) {
*value -= scale * grad;
}
LayerNormParameters::new(updated_scale, updated_shift, parameters.epsilon())
}
}
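/// Feed-forward pass that also records each row's input, pre-activation, activation,
/// and output for backpropagation.
fn feed_forward_with_cache(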
feed_forward: &PositionWiseFeedForward,
input: &HiddenSequence,
) -> CtResult<(HiddenSequence, Vec<FeedForwardRowCache>)> {
if input.model_dimension() != feed_forward.input_dimension() {
return Err(CtError::ShapeMismatch {
op: "position-wise feed-forward",
expected: format!("input dimension {}", feed_forward.input_dimension().value()),
got: format!("input dimension {}", input.model_dimension().value()),
});
}
let mut output_rows = Vec::with_capacity(input.rows().len());
let mut cache_rows = Vec::with_capacity(input.rows().len());
for row in input.rows() {
let input_row = row.as_slice().to_vec();
let pre_activation = project_row(
&input_row,
feed_forward.first_weight(),
feed_forward.first_bias(),
);
let activation = pre_activation
.iter()
.map(|value| value.max(0.0))
.collect::<Vec<_>>();
let output = project_row(
&activation,
feed_forward.second_weight(),
feed_forward.second_bias(),
);
output_rows.push(output.clone());
cache_rows.push(FeedForwardRowCache {
input: input_row,
pre_activation,
activation,
output,
});
}
Ok((HiddenSequence::new(output_rows)?, cache_rows))
}
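`feed_forward_with_cache` records each row's `pre_activation` alongside the ReLU `activation` because the backward pass needs it: gradient reaches a hidden unit only where that unit's pre-activation was positive. The listing's own backward code is not part of this excerpt, so the following is only a sketch of the gating step. `relu` and `relu_backward` are hypothetical names, and the sketch assumes the common convention that the ReLU derivative at exactly 0 is 0.

```rust
/// Hypothetical helper: forward ReLU, matching `value.max(0.0)` in the listing.
fn relu(pre_activation: &[f32]) -> Vec<f32> {
    pre_activation.iter().map(|&v| v.max(0.0)).collect()
}

/// Hypothetical helper: backpropagate through ReLU using the cached pre-activation.
/// Assumed convention: the derivative at exactly 0.0 is taken as 0.
fn relu_backward(d_activation: &[f32], pre_activation: &[f32]) -> Vec<f32> {
    d_activation
        .iter()
        .zip(pre_activation)
        .map(|(grad, pre)| if *pre > 0.0 { *grad } else { 0.0 })
        .collect()
}

fn main() {
    let pre_activation = [1.0_f32, -2.0, 0.5, 0.0];
    let d_activation = [0.3_f32, 0.7, -1.0, 2.0];
    println!("activation = {:?}", relu(&pre_activation));
    // Only units with a positive pre-activation pass their gradient through.
    println!("d_pre = {:?}", relu_backward(&d_activation, &pre_activation));
}
```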
fn softmax_cross_entropy_logits_gradient(
logits: &Logits,
target_index: usize,
vocab_size: usize,
) -> CtResult<Vec<f32>> {
if target_index >= vocab_size {
return Err(CtError::OutOfRange {
kind: "sequence target",
index: target_index,
limit: vocab_size,
});
}
let probabilities = Softmax.apply(logits.clone())?;
let mut dlogits = probabilities.as_slice().to_vec();
dlogits[target_index] -= 1.0;
Ok(dlogits)
}
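For the loss L = -ln p_t with p = softmax(z), the gradient with respect to the logits is p_i - 1 at the target index t and p_i elsewhere, which is what `softmax_cross_entropy_logits_gradient` returns. The sketch below assumes plain `f32` slices in place of the crate's `Logits` and `Softmax` types, recomputes that gradient, and compares it with a central finite difference of the loss. `softmax`, `cross_entropy`, `ce_gradient`, and `numeric_gradient` are hypothetical helpers.

```rust
/// Numerically stable softmax: subtract the largest logit before exponentiating.
fn softmax(logits: &[f32]) -> Vec<f32> {
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|z| (z - max).exp()).collect();
    let total: f32 = exps.iter().sum();
    exps.iter().map(|e| e / total).collect()
}

/// Cross-entropy loss for one position: -ln p_target.
fn cross_entropy(logits: &[f32], target: usize) -> f32 {
    -softmax(logits)[target].ln()
}

/// Closed-form gradient with respect to the logits: probabilities minus one-hot target.
fn ce_gradient(logits: &[f32], target: usize) -> Vec<f32> {
    let mut grad = softmax(logits);
    grad[target] -= 1.0;
    grad
}

/// Central finite difference of the loss along each logit.
fn numeric_gradient(logits: &[f32], target: usize, h: f32) -> Vec<f32> {
    (0..logits.len())
        .map(|i| {
            let mut plus = logits.to_vec();
            let mut minus = logits.to_vec();
            plus[i] += h;
            minus[i] -= h;
            (cross_entropy(&plus, target) - cross_entropy(&minus, target)) / (2.0 * h)
        })
        .collect()
}

fn main() {
    // Equal logits give p = [0.5, 0.5], so the gradient for target 0 is [-0.5, 0.5].
    println!("{:?}", ce_gradient(&[0.0, 0.0], 0));
    let logits = [1.0_f32, -0.5, 2.0];
    println!("analytic {:?}", ce_gradient(&logits, 1));
    println!("numeric  {:?}", numeric_gradient(&logits, 1, 1e-2));
}
```

Because the probabilities sum to 1, the entries of this gradient sum to 0. A finite-difference comparison like this one is a cheap way to catch sign or indexing mistakes in a hand-written backward pass.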
#[derive(Debug, Clone, PartialEq)]
struct LayerNormStats {
dimension: f32,
inverse_std: f32,
normalized: Vec<f32>,
}
fn layer_norm_stats(input: &[f32], parameters: &LayerNormParameters) -> LayerNormStats {
let dimension = input.len() as f32;
let mean = input.iter().sum::<f32>() / dimension;
let variance = input
.iter()
.map(|value| {
let centered = value - mean;
centered * centered
})
.sum::<f32>()
/ dimension;
let inverse_std = 1.0 / (variance + parameters.epsilon().value()).sqrt();
let normalized = input
.iter()
.map(|value| (value - mean) * inverse_std)
.collect::<Vec<_>>();
LayerNormStats {
dimension,
inverse_std,
normalized,
}
}
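In symbols, for one row x of length D with per-feature gain γ (`scale`) and bias β (`shift`), `layer_norm_stats` computes the quantities below. It keeps D (`dimension`), r (`inverse_std`) and x̂ (`normalized`) because those are exactly what the backward pass needs. With upstream gradient g_i = ∂L/∂y_i and ĝ_i = γ_i g_i, the input gradient is the formula implemented by `layer_norm_backward_from_stats`, and the gain and bias gradients are the per-row sums that `LayerNormGradients::accumulate` adds up:

$$
\begin{aligned}
\mu &= \frac{1}{D}\sum_{i=1}^{D} x_i, \qquad
\sigma^2 = \frac{1}{D}\sum_{i=1}^{D} (x_i - \mu)^2, \qquad
r = \frac{1}{\sqrt{\sigma^2 + \epsilon}}, \\
\hat{x}_i &= (x_i - \mu)\, r, \qquad
y_i = \gamma_i \hat{x}_i + \beta_i, \\
\frac{\partial L}{\partial x_i} &= \frac{r}{D}\Big(D\,\hat{g}_i - \sum_{j=1}^{D}\hat{g}_j - \hat{x}_i \sum_{j=1}^{D}\hat{g}_j \hat{x}_j\Big), \\
\frac{\partial L}{\partial \gamma_i} &= \sum_{\text{rows}} g_i \hat{x}_i, \qquad
\frac{\partial L}{\partial \beta_i} = \sum_{\text{rows}} g_i .
\end{aligned}
$$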
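// Input gradient of layer normalization for one row. With g = d_output * scale
// (elementwise), D = stats.dimension, r = stats.inverse_std and xhat = stats.normalized:
// dx_i = r / D * (D * g_i - sum_j g_j - xhat_i * sum_j g_j * xhat_j).
fn layer_norm_backward_from_stats(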
d_output: &[f32],
stats: &LayerNormStats,
parameters: &LayerNormParameters,
) -> Vec<f32> {
let d_normalized = d_output
.iter()
.zip(parameters.scale())
.map(|(grad, scale)| grad * scale)
.collect::<Vec<_>>();
let sum_d_normalized = d_normalized.iter().sum::<f32>();
let sum_d_normalized_times_normalized = d_normalized
.iter()
.zip(&stats.normalized)
.map(|(grad, normalized_value)| grad * normalized_value)
.sum::<f32>();
d_normalized
.iter()
.zip(&stats.normalized)
.map(|(grad, normalized_value)| {
(stats.dimension * grad
- sum_d_normalized
- normalized_value * sum_d_normalized_times_normalized)
* stats.inverse_std
/ stats.dimension
})
.collect()
}
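One way to check `layer_norm_backward_from_stats` numerically is to treat L(x) = Σ_i g_i y_i as a scalar loss with a fixed upstream gradient g; the closed-form input gradient should then match a central finite difference. The sketch below re-implements the forward pass and the same backward formula on `f64` slices (the listing uses `f32`; `f64` keeps the finite-difference error small). `stats`, `layer_norm`, `layer_norm_backward`, and `numeric_backward` are hypothetical helpers, not crate functions.

```rust
/// Reciprocal standard deviation and normalized row, as in `layer_norm_stats`.
fn stats(x: &[f64], eps: f64) -> (f64, Vec<f64>) {
    let d = x.len() as f64;
    let mean = x.iter().sum::<f64>() / d;
    let var = x.iter().map(|v| (v - mean) * (v - mean)).sum::<f64>() / d;
    let inverse_std = 1.0 / (var + eps).sqrt();
    let normalized: Vec<f64> = x.iter().map(|v| (v - mean) * inverse_std).collect();
    (inverse_std, normalized)
}

/// Forward pass: y_i = gain_i * xhat_i + shift_i.
fn layer_norm(x: &[f64], gain: &[f64], shift: &[f64], eps: f64) -> Vec<f64> {
    let (_, x_hat) = stats(x, eps);
    x_hat
        .iter()
        .zip(gain.iter().zip(shift))
        .map(|(xh, (g, b))| g * xh + b)
        .collect()
}

/// Closed-form input gradient, mirroring `layer_norm_backward_from_stats`.
fn layer_norm_backward(d_out: &[f64], x: &[f64], gain: &[f64], eps: f64) -> Vec<f64> {
    let d = x.len() as f64;
    let (inverse_std, x_hat) = stats(x, eps);
    let g_hat: Vec<f64> = d_out.iter().zip(gain).map(|(g, s)| g * s).collect();
    let sum_g: f64 = g_hat.iter().sum();
    let sum_gx: f64 = g_hat.iter().zip(&x_hat).map(|(g, xh)| g * xh).sum();
    g_hat
        .iter()
        .zip(&x_hat)
        .map(|(g, xh)| (d * g - sum_g - xh * sum_gx) * inverse_std / d)
        .collect()
}

/// Central finite difference of the scalar loss L(x) = sum_i d_out_i * y_i.
fn numeric_backward(
    d_out: &[f64],
    x: &[f64],
    gain: &[f64],
    shift: &[f64],
    eps: f64,
    h: f64,
) -> Vec<f64> {
    let loss = |row: &[f64]| -> f64 {
        layer_norm(row, gain, shift, eps)
            .iter()
            .zip(d_out)
            .map(|(y, w)| y * w)
            .sum()
    };
    (0..x.len())
        .map(|i| {
            let mut plus = x.to_vec();
            let mut minus = x.to_vec();
            plus[i] += h;
            minus[i] -= h;
            (loss(&plus[..]) - loss(&minus[..])) / (2.0 * h)
        })
        .collect()
}

fn main() {
    let x = [1.0, 3.0, -2.0];
    let gain = [2.0, 0.5, 1.0];
    let shift = [1.0, -1.0, 0.0];
    let d_out = [0.3, -1.2, 0.7];
    let analytic = layer_norm_backward(&d_out, &x, &gain, 1e-5);
    let numeric = numeric_backward(&d_out, &x, &gain, &shift, 1e-5, 1e-6);
    println!("analytic {analytic:?}");
    println!("numeric  {numeric:?}");
}
```

The input gradient should also sum to approximately 0: adding the same constant to every feature of a row leaves the normalized row unchanged, so the loss has zero derivative in that direction.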
fn validate_projection_input(
op: &'static str,
expected: ModelDimension,
got: ModelDimension,
) -> CtResult<()> {
if expected != got {
return Err(CtError::ShapeMismatch {
op,
expected: format!("model dimension {}", expected.value()),
got: format!("model dimension {}", got.value()),
});
}
Ok(())
}
fn add_rows(left: &[f32], right: &[f32]) -> Vec<f32> {
left.iter()
.zip(right.iter())
.map(|(left, right)| left + right)
.collect()
}
fn validate_linear_parts(
op: &'static str,
weight: &[Vec<f32>],
bias: &[f32],
) -> CtResult<(ModelDimension, ModelDimension)> {
if weight.is_empty() {
return Err(CtError::EmptyInput("linear weight"));
}
if bias.is_empty() {
return Err(CtError::EmptyInput("linear bias"));
}
let output_dimension = bias.len();
if bias.iter().any(|value| !value.is_finite()) {
return Err(CtError::ShapeMismatch {
op,
expected: "finite bias values".to_string(),
got: "non-finite bias value".to_string(),
});
}
for row in weight {
if row.len() != output_dimension {
return Err(CtError::ShapeMismatch {
op,
expected: format!("weight rows have {output_dimension} columns"),
got: format!("weight row with {} columns", row.len()),
});
}
if row.iter().any(|value| !value.is_finite()) {
return Err(CtError::ShapeMismatch {
op,
expected: "finite weight values".to_string(),
got: "non-finite weight value".to_string(),
});
}
}
Ok((
ModelDimension::new(weight.len())?,
ModelDimension::new(output_dimension)?,
))
}
fn normalize_row(input: &[f32], parameters: &LayerNormParameters) -> Vec<f32> {
let mean = input.iter().sum::<f32>() / input.len() as f32;
let variance = input
.iter()
.map(|value| {
let centered = value - mean;
centered * centered
})
.sum::<f32>()
/ input.len() as f32;
let denominator = (variance + parameters.epsilon.value()).sqrt();
input
.iter()
.zip(parameters.scale.iter().zip(&parameters.shift))
.map(|(value, (scale, shift))| ((value - mean) / denominator) * scale + shift)
.collect()
}
fn project_row(input: &[f32], weight: &[Vec<f32>], bias: &[f32]) -> Vec<f32> {
let mut output = bias.to_vec();
for (feature, input_value) in input.iter().enumerate() {
for (output_value, weight_value) in output.iter_mut().zip(&weight[feature]) {
*output_value += input_value * weight_value;
}
}
output
}
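`project_row` stores the weight with one row per input feature, the layout that `validate_linear_parts` checks (`weight.len()` is the input dimension and each row's length is the output dimension). It therefore computes y_j = b_j + Σ_i x_i W[i][j], that is y = b + xW for a row vector x. The sketch below copies the loop into a standalone function and runs it on the numbers from the `attention_output_projection_maps_multi_head_rows` test, which expects approximately [2.5, 2.3] for the first row.

```rust
/// Same loop as `project_row` in the listing: `weight[i]` holds the outgoing
/// weights of input feature `i`, so the result is `bias + input * weight`.
fn project_row(input: &[f32], weight: &[Vec<f32>], bias: &[f32]) -> Vec<f32> {
    let mut output = bias.to_vec();
    for (feature, input_value) in input.iter().enumerate() {
        for (output_value, weight_value) in output.iter_mut().zip(&weight[feature]) {
            *output_value += input_value * weight_value;
        }
    }
    output
}

fn main() {
    // One concatenated two-head row (4 features) projected down to 2 features.
    let input = [1.0_f32, 10.0, 3.0, 30.0];
    let weight: Vec<Vec<f32>> = vec![
        vec![1.0, 0.0],
        vec![0.0, 0.1],
        vec![0.5, 0.0],
        vec![0.0, 0.01],
    ];
    let bias = [0.0_f32, 1.0];
    // y_0 = 1*1 + 3*0.5 = 2.5; y_1 = 1 + 10*0.1 + 30*0.01 = 2.3 (up to f32 rounding).
    println!("{:?}", project_row(&input, &weight, &bias));
}
```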
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn scaled_dot_product_scores_build_query_by_key_rows() -> CtResult<()> {
let queries = QuerySequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
let keys = KeySequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]])?;
let scores = ScaledDotProductScores.apply(Product::new(queries, keys))?;
assert_eq!(scores.query_len().value(), 2);
assert_eq!(scores.key_len().value(), 3);
assert!(crate::domain::approx_eq(
scores.rows()[0].as_slice()[0],
std::f32::consts::FRAC_1_SQRT_2,
1e-4
));
assert!(crate::domain::approx_eq(
scores.rows()[0].as_slice()[1],
0.0,
1e-4
));
assert!(crate::domain::approx_eq(
scores.rows()[1].as_slice()[2],
std::f32::consts::FRAC_1_SQRT_2,
1e-4
));
Ok(())
}
#[test]
fn weighted_value_mixing_builds_one_output_per_query() -> CtResult<()> {
let queries = QuerySequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
let keys = KeySequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]])?;
let values = ValueSequence::new(vec![vec![1.0, 10.0], vec![2.0, 20.0], vec![3.0, 30.0]])?;
let scores = ScaledDotProductScores.apply(Product::new(queries, keys))?;
let weights = AttentionSoftmax.apply(scores)?;
let output = WeightedValueMixing.apply(Product::new(weights, values))?;
assert_eq!(output.sequence_len().value(), 2);
assert_eq!(output.head_dimension().value(), 2);
assert!(crate::domain::approx_eq(
output.rows()[0].as_slice()[0],
2.0,
1e-4
));
assert!(crate::domain::approx_eq(
output.rows()[0].as_slice()[1],
20.0,
1e-4
));
Ok(())
}
#[test]
fn concatenate_heads_preserves_sequence_and_concatenates_features() -> CtResult<()> {
let head_a = AttentionOutput::new(vec![vec![1.0, 10.0], vec![2.0, 20.0]])?;
let head_b = AttentionOutput::new(vec![vec![3.0, 30.0], vec![4.0, 40.0]])?;
let heads = AttentionHeadOutputs::new(vec![head_a, head_b])?;
let output = ConcatenateHeads.apply(heads)?;
assert_eq!(output.sequence_len().value(), 2);
assert_eq!(output.head_count().value(), 2);
assert_eq!(output.head_dimension().value(), 2);
assert_eq!(output.model_dimension().value(), 4);
assert_eq!(output.rows()[0].as_slice(), &[1.0, 10.0, 3.0, 30.0]);
assert_eq!(output.rows()[1].as_slice(), &[2.0, 20.0, 4.0, 40.0]);
Ok(())
}
#[test]
fn attention_output_projection_maps_multi_head_rows() -> CtResult<()> {
let head_a = AttentionOutput::new(vec![vec![1.0, 10.0], vec![2.0, 20.0]])?;
let head_b = AttentionOutput::new(vec![vec![3.0, 30.0], vec![4.0, 40.0]])?;
let multi_head =
ConcatenateHeads.apply(AttentionHeadOutputs::new(vec![head_a, head_b])?)?;
let projection = AttentionOutputProjection::new(
vec![
vec![1.0, 0.0],
vec![0.0, 0.1],
vec![0.5, 0.0],
vec![0.0, 0.01],
],
vec![0.0, 1.0],
)?;
let output = projection.apply(multi_head)?;
assert_eq!(output.sequence_len().value(), 2);
assert_eq!(output.model_dimension().value(), 2);
assert!(crate::domain::approx_eq(
output.rows()[0].as_slice()[0],
2.5,
1e-4
));
assert!(crate::domain::approx_eq(
output.rows()[0].as_slice()[1],
2.3,
1e-4
));
assert!(crate::domain::approx_eq(
output.rows()[1].as_slice()[0],
4.0,
1e-4
));
assert!(crate::domain::approx_eq(
output.rows()[1].as_slice()[1],
3.4,
1e-4
));
Ok(())
}
#[test]
fn residual_connection_adds_matching_sequences() -> CtResult<()> {
let hidden = HiddenSequence::new(vec![vec![1.0, 2.0], vec![3.0, 4.0]])?;
let sublayer_output = ProjectedAttentionOutput::new(vec![vec![0.5, 1.5], vec![2.5, 3.5]])?;
let output = ResidualConnection.apply(Product::new(hidden, sublayer_output))?;
assert_eq!(output.sequence_len().value(), 2);
assert_eq!(output.model_dimension().value(), 2);
assert_eq!(output.rows()[0].as_slice(), &[1.5, 3.5]);
assert_eq!(output.rows()[1].as_slice(), &[5.5, 7.5]);
Ok(())
}
#[test]
fn layer_normalization_preserves_shape_and_normalizes_each_row() -> CtResult<()> {
let hidden = HiddenSequence::new(vec![vec![1.0, 3.0], vec![2.0, 4.0]])?;
let norm = LayerNormalization::new(LayerNormParameters::identity(ModelDimension::new(2)?));
let output = norm.apply(hidden)?;
assert_eq!(output.sequence_len().value(), 2);
assert_eq!(output.model_dimension().value(), 2);
assert!(crate::domain::approx_eq(
output.rows()[0].as_slice()[0],
-0.999995,
1e-4
));
assert!(crate::domain::approx_eq(
output.rows()[0].as_slice()[1],
0.999995,
1e-4
));
assert!(crate::domain::approx_eq(
output.rows()[1].as_slice()[0],
-0.999995,
1e-4
));
assert!(crate::domain::approx_eq(
output.rows()[1].as_slice()[1],
0.999995,
1e-4
));
Ok(())
}
#[test]
fn layer_normalization_applies_scale_and_shift() -> CtResult<()> {
let hidden = HiddenSequence::new(vec![vec![1.0, 3.0]])?;
let params = LayerNormParameters::new(
vec![2.0, 0.5],
vec![1.0, -1.0],
NormalizationEpsilon::new(1e-5)?,
)?;
let norm = LayerNormalization::new(params);
let output = norm.apply(hidden)?;
assert!(crate::domain::approx_eq(
output.rows()[0].as_slice()[0],
-0.99999,
1e-4
));
assert!(crate::domain::approx_eq(
output.rows()[0].as_slice()[1],
-0.5000025,
1e-4
));
Ok(())
}
#[test]
fn position_wise_feed_forward_maps_each_row_and_preserves_shape() -> CtResult<()> {
let hidden = HiddenSequence::new(vec![vec![1.0, 2.0], vec![3.0, 4.0]])?;
let feed_forward = PositionWiseFeedForward::new(
vec![vec![1.0, -1.0, 0.5], vec![0.0, 1.0, 0.5]],
vec![0.0, 0.0, 0.0],
vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![0.5, 0.5]],
vec![0.0, 0.0],
)?;
let output = feed_forward.apply(hidden)?;
assert_eq!(output.sequence_len().value(), 2);
assert_eq!(output.model_dimension().value(), 2);
assert!(crate::domain::approx_eq(
output.rows()[0].as_slice()[0],
1.75,
1e-4
));
assert!(crate::domain::approx_eq(
output.rows()[0].as_slice()[1],
1.75,
1e-4
));
assert!(crate::domain::approx_eq(
output.rows()[1].as_slice()[0],
4.75,
1e-4
));
assert!(crate::domain::approx_eq(
output.rows()[1].as_slice()[1],
2.75,
1e-4
));
Ok(())
}
#[test]
fn attention_mask_removes_disallowed_positions_before_softmax() -> CtResult<()> {
let scores = AttentionScores::new(vec![vec![2.0, 1.0, 2.0]])?;
let mask = AttentionMask::new(vec![vec![true, false, true]])?;
let masked_scores = MaskedAttentionScores.apply(Product::new(scores, mask))?;
let weights = AttentionSoftmax.apply(masked_scores)?;
assert!(crate::domain::approx_eq(
weights.rows()[0].as_slice()[0],
0.5,
1e-4
));
assert!(crate::domain::approx_eq(
weights.rows()[0].as_slice()[1],
0.0,
1e-4
));
assert!(crate::domain::approx_eq(
weights.rows()[0].as_slice()[2],
0.5,
1e-4
));
Ok(())
}
#[test]
fn attention_softmax_normalizes_each_query_row() -> CtResult<()> {
let scores = AttentionScores::new(vec![vec![2.0, 1.0], vec![0.0, 3.0]])?;
let weights = AttentionSoftmax.apply(scores)?;
assert_eq!(weights.query_len().value(), 2);
assert_eq!(weights.key_len().value(), 2);
for row in weights.rows() {
let sum: f32 = row.as_slice().iter().sum();
assert!(crate::domain::approx_eq(sum, 1.0, 1e-4));
}
Ok(())
}
#[test]
fn attention_scores_reject_non_finite_values() {
assert!(matches!(
AttentionScores::new(vec![vec![1.0, f32::NAN]]),
Err(CtError::ShapeMismatch {
op: "attention scores",
..
})
));
}
#[test]
fn attention_scores_reject_ragged_rows() {
assert!(matches!(
AttentionScores::new(vec![vec![1.0, 2.0], vec![3.0]]),
Err(CtError::ShapeMismatch {
op: "attention scores",
..
})
));
}
#[test]
fn attention_mask_rejects_rows_with_no_allowed_keys() {
assert!(matches!(
AttentionMask::new(vec![vec![false, false]]),
Err(CtError::EmptyInput("attention mask row allows no keys"))
));
}
#[test]
fn masked_attention_scores_reject_shape_mismatch() -> CtResult<()> {
let scores = AttentionScores::new(vec![vec![1.0, 2.0]])?;
let mask = AttentionMask::new(vec![vec![true, true], vec![true, true]])?;
assert!(matches!(
MaskedAttentionScores.apply(Product::new(scores, mask)),
Err(CtError::ShapeMismatch {
op: "masked attention scores",
..
})
));
Ok(())
}
#[test]
fn query_sequence_rejects_ragged_rows() {
assert!(matches!(
QuerySequence::new(vec![vec![1.0, 2.0], vec![3.0]]),
Err(CtError::ShapeMismatch {
op: "query sequence",
..
})
));
}
#[test]
fn value_sequence_rejects_empty_rows() {
assert!(matches!(
ValueSequence::new(vec![Vec::new()]),
Err(CtError::EmptyInput("attention vector row"))
));
}
#[test]
fn key_sequence_rejects_non_finite_values() {
assert!(matches!(
KeySequence::new(vec![vec![1.0, f32::NAN]]),
Err(CtError::ShapeMismatch {
op: "key sequence",
..
})
));
}
#[test]
fn scaled_dot_product_rejects_mismatched_head_dimensions() -> CtResult<()> {
let queries = QuerySequence::new(vec![vec![1.0, 0.0]])?;
let keys = KeySequence::new(vec![vec![1.0, 0.0, 1.0]])?;
assert!(matches!(
ScaledDotProductScores.apply(Product::new(queries, keys)),
Err(CtError::ShapeMismatch {
op: "scaled dot-product attention scores",
..
})
));
Ok(())
}
#[test]
fn weighted_value_mixing_rejects_value_length_mismatch() -> CtResult<()> {
let weights = AttentionWeights::new(vec![Distribution::new(vec![0.5, 0.5])?])?;
let values = ValueSequence::new(vec![vec![1.0, 10.0]])?;
assert!(matches!(
WeightedValueMixing.apply(Product::new(weights, values)),
Err(CtError::ShapeMismatch {
op: "weighted value mixing",
..
})
));
Ok(())
}
#[test]
fn sequence_and_head_dimensions_reject_zero() {
assert!(matches!(
SequenceLength::new(0),
Err(CtError::EmptyInput("sequence length"))
));
assert!(matches!(
HeadDimension::new(0),
Err(CtError::EmptyInput("head dimension"))
));
assert!(matches!(
HeadCount::new(0),
Err(CtError::EmptyInput("head count"))
));
}
#[test]
fn attention_head_outputs_reject_sequence_mismatch() -> CtResult<()> {
let head_a = AttentionOutput::new(vec![vec![1.0, 10.0]])?;
let head_b = AttentionOutput::new(vec![vec![2.0, 20.0], vec![3.0, 30.0]])?;
assert!(matches!(
AttentionHeadOutputs::new(vec![head_a, head_b]),
Err(CtError::ShapeMismatch {
op: "attention head outputs",
..
})
));
Ok(())
}
#[test]
fn attention_head_outputs_reject_head_dimension_mismatch() -> CtResult<()> {
let head_a = AttentionOutput::new(vec![vec![1.0, 10.0]])?;
let head_b = AttentionOutput::new(vec![vec![2.0, 20.0, 200.0]])?;
assert!(matches!(
AttentionHeadOutputs::new(vec![head_a, head_b]),
Err(CtError::ShapeMismatch {
op: "attention head outputs",
..
})
));
Ok(())
}
#[test]
fn attention_output_projection_rejects_input_dimension_mismatch() -> CtResult<()> {
let head_a = AttentionOutput::new(vec![vec![1.0, 10.0]])?;
let head_b = AttentionOutput::new(vec![vec![2.0, 20.0]])?;
let multi_head =
ConcatenateHeads.apply(AttentionHeadOutputs::new(vec![head_a, head_b])?)?;
let projection =
AttentionOutputProjection::new(vec![vec![1.0], vec![1.0], vec![1.0]], vec![0.0])?;
assert!(matches!(
projection.apply(multi_head),
Err(CtError::ShapeMismatch {
op: "attention output projection",
..
})
));
Ok(())
}
#[test]
fn attention_output_projection_rejects_bad_weight_shapes() {
assert!(matches!(
AttentionOutputProjection::new(vec![vec![1.0, 0.0], vec![1.0]], vec![0.0, 0.0]),
Err(CtError::ShapeMismatch {
op: "attention output projection",
..
})
));
}
#[test]
fn attention_output_projection_rejects_non_finite_values() {
assert!(matches!(
AttentionOutputProjection::new(vec![vec![1.0]], vec![f32::NAN]),
Err(CtError::ShapeMismatch {
op: "attention output projection",
..
})
));
assert!(matches!(
AttentionOutputProjection::new(vec![vec![f32::INFINITY]], vec![0.0]),
Err(CtError::ShapeMismatch {
op: "attention output projection",
..
})
));
}
#[test]
fn residual_connection_rejects_sequence_mismatch() -> CtResult<()> {
let hidden = HiddenSequence::new(vec![vec![1.0, 2.0]])?;
let sublayer_output = ProjectedAttentionOutput::new(vec![vec![0.5, 1.5], vec![2.5, 3.5]])?;
assert!(matches!(
ResidualConnection.apply(Product::new(hidden, sublayer_output)),
Err(CtError::ShapeMismatch {
op: "residual connection",
..
})
));
Ok(())
}
#[test]
fn residual_connection_rejects_model_dimension_mismatch() -> CtResult<()> {
let hidden = HiddenSequence::new(vec![vec![1.0, 2.0]])?;
let sublayer_output = ProjectedAttentionOutput::new(vec![vec![0.5, 1.5, 2.5]])?;
assert!(matches!(
ResidualConnection.apply(Product::new(hidden, sublayer_output)),
Err(CtError::ShapeMismatch {
op: "residual connection",
..
})
));
Ok(())
}
#[test]
fn layer_normalization_rejects_model_dimension_mismatch() -> CtResult<()> {
let hidden = HiddenSequence::new(vec![vec![1.0, 2.0, 3.0]])?;
let norm = LayerNormalization::new(LayerNormParameters::identity(ModelDimension::new(2)?));
assert!(matches!(
norm.apply(hidden),
Err(CtError::ShapeMismatch {
op: "layer normalization",
..
})
));
Ok(())
}
#[test]
fn layer_norm_parameters_reject_bad_shapes_and_values() {
assert!(matches!(
LayerNormParameters::new(vec![1.0, 1.0], vec![0.0], NormalizationEpsilon(1e-5)),
Err(CtError::ShapeMismatch {
op: "layer norm parameters",
..
})
));
assert!(matches!(
LayerNormParameters::new(vec![f32::NAN], vec![0.0], NormalizationEpsilon(1e-5)),
Err(CtError::ShapeMismatch {
op: "layer norm parameters",
..
})
));
assert!(matches!(
LayerNormParameters::new(vec![1.0], vec![f32::INFINITY], NormalizationEpsilon(1e-5)),
Err(CtError::ShapeMismatch {
op: "layer norm parameters",
..
})
));
assert!(matches!(
NormalizationEpsilon::new(0.0),
Err(CtError::ShapeMismatch {
op: "normalization epsilon",
..
})
));
}
#[test]
fn position_wise_feed_forward_rejects_input_dimension_mismatch() -> CtResult<()> {
let hidden = HiddenSequence::new(vec![vec![1.0, 2.0, 3.0]])?;
let feed_forward = PositionWiseFeedForward::new(
vec![vec![1.0, 0.0], vec![0.0, 1.0]],
vec![0.0, 0.0],
vec![vec![1.0, 0.0], vec![0.0, 1.0]],
vec![0.0, 0.0],
)?;
assert!(matches!(
feed_forward.apply(hidden),
Err(CtError::ShapeMismatch {
op: "position-wise feed-forward",
..
})
));
Ok(())
}
#[test]
fn position_wise_feed_forward_rejects_incompatible_layer_shapes() {
assert!(matches!(
PositionWiseFeedForward::new(
vec![vec![1.0, 0.0]],
vec![0.0, 0.0],
vec![vec![1.0], vec![1.0], vec![1.0]],
vec![0.0],
),
Err(CtError::ShapeMismatch {
op: "position-wise feed-forward",
..
})
));
assert!(matches!(
PositionWiseFeedForward::new(
vec![vec![1.0, 0.0]],
vec![0.0, 0.0],
vec![vec![1.0, 0.0], vec![1.0, 0.0]],
vec![0.0, 0.0],
),
Err(CtError::ShapeMismatch {
op: "position-wise feed-forward",
..
})
));
}
#[test]
fn position_wise_feed_forward_rejects_non_finite_values() {
assert!(matches!(
PositionWiseFeedForward::new(
vec![vec![f32::NAN]],
vec![0.0],
vec![vec![1.0]],
vec![0.0],
),
Err(CtError::ShapeMismatch {
op: "position-wise feed-forward first layer",
..
})
));
assert!(matches!(
PositionWiseFeedForward::new(
vec![vec![1.0]],
vec![0.0],
vec![vec![1.0]],
vec![f32::INFINITY],
),
Err(CtError::ShapeMismatch {
op: "position-wise feed-forward second layer",
..
})
));
}
#[test]
fn positional_encoding_adds_position_rows_and_preserves_shape() -> CtResult<()> {
let hidden = HiddenSequence::new(vec![vec![1.0, 2.0], vec![3.0, 4.0]])?;
let positions = PositionalEncoding::new(vec![vec![0.1, 0.2], vec![0.3, 0.4]])?;
let output = positions.apply(hidden)?;
assert_eq!(output.sequence_len().value(), 2);
assert_eq!(output.model_dimension().value(), 2);
assert!(crate::domain::approx_eq(
output.rows()[0].as_slice()[0],
1.1,
1e-4
));
assert!(crate::domain::approx_eq(
output.rows()[1].as_slice()[1],
4.4,
1e-4
));
Ok(())
}
#[test]
fn positional_encoding_rejects_model_dimension_mismatch() -> CtResult<()> {
let hidden = HiddenSequence::new(vec![vec![1.0, 2.0]])?;
let positions = PositionalEncoding::new(vec![vec![0.1, 0.2, 0.3]])?;
assert!(matches!(
positions.apply(hidden),
Err(CtError::ShapeMismatch {
op: "positional encoding",
..
})
));
Ok(())
}
#[test]
fn positional_encoding_rejects_sequence_too_long() -> CtResult<()> {
let hidden = HiddenSequence::new(vec![vec![1.0], vec![2.0]])?;
let positions = PositionalEncoding::new(vec![vec![0.1]])?;
assert!(matches!(
positions.apply(hidden),
Err(CtError::ShapeMismatch {
op: "positional encoding",
..
})
));
Ok(())
}
#[test]
fn hidden_to_query_projects_hidden_rows() -> CtResult<()> {
let hidden = HiddenSequence::new(vec![vec![1.0, 2.0], vec![3.0, 4.0]])?;
let projection = HiddenToQuery::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.5, -0.5])?;
let queries = projection.apply(hidden)?;
assert_eq!(queries.sequence_len().value(), 2);
assert_eq!(queries.head_dimension().value(), 2);
assert!(crate::domain::approx_eq(
queries.rows()[0].as_slice()[0],
1.5,
1e-4
));
assert!(crate::domain::approx_eq(
queries.rows()[1].as_slice()[1],
3.5,
1e-4
));
Ok(())
}
#[test]
fn hidden_projection_rejects_input_dimension_mismatch() -> CtResult<()> {
let hidden = HiddenSequence::new(vec![vec![1.0, 2.0, 3.0]])?;
let projection = HiddenToValue::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?;
assert!(matches!(
projection.apply(hidden),
Err(CtError::ShapeMismatch {
op: "hidden-to-value projection",
..
})
));
Ok(())
}
#[test]
fn residual_connection_adds_hidden_sequences() -> CtResult<()> {
let left = HiddenSequence::new(vec![vec![1.0, 2.0]])?;
let right = HiddenSequence::new(vec![vec![3.0, 4.0]])?;
let output = ResidualConnection.apply(Product::new(left, right))?;
assert_eq!(output.rows()[0].as_slice(), &[4.0, 6.0]);
Ok(())
}
#[test]
fn single_head_transformer_block_preserves_hidden_sequence_shape() -> CtResult<()> {
let block = tiny_single_head_block()?;
let hidden = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
let output = block.apply(hidden)?;
assert_eq!(output.sequence_len().value(), 2);
assert_eq!(output.model_dimension().value(), 2);
assert!(
output
.rows()
.iter()
.flat_map(|row| row.as_slice())
.all(|value| value.is_finite())
);
Ok(())
}
#[test]
fn single_head_transformer_block_rejects_constructor_dimension_mismatch() -> CtResult<()> {
let query = HiddenToQuery::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?;
let key = HiddenToKey::new(
vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![0.0, 0.0]],
vec![0.0, 0.0],
)?;
let value = HiddenToValue::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?;
let output_projection =
AttentionOutputProjection::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?;
let model_dimension = ModelDimension::new(2)?;
let attention_norm =
LayerNormalization::new(LayerNormParameters::identity(model_dimension));
let feed_forward = PositionWiseFeedForward::new(
vec![vec![1.0, 0.0], vec![0.0, 1.0]],
vec![0.0, 0.0],
vec![vec![1.0, 0.0], vec![0.0, 1.0]],
vec![0.0, 0.0],
)?;
let feed_forward_norm =
LayerNormalization::new(LayerNormParameters::identity(model_dimension));
assert!(matches!(
SingleHeadTransformerBlock::new(
query,
key,
value,
output_projection,
attention_norm,
feed_forward,
feed_forward_norm,
),
Err(CtError::ShapeMismatch {
op: "single-head block key projection",
..
})
));
Ok(())
}
#[test]
fn single_head_transformer_block_rejects_apply_dimension_mismatch() -> CtResult<()> {
let block = tiny_single_head_block()?;
let hidden = HiddenSequence::new(vec![vec![1.0, 0.0, 0.0]])?;
assert!(matches!(
block.apply(hidden),
Err(CtError::ShapeMismatch {
op: "single-head block",
..
})
));
Ok(())
}
#[test]
fn self_attention_head_rejects_query_key_head_mismatch() -> CtResult<()> {
let query = HiddenToQuery::new(vec![vec![1.0], vec![0.0]], vec![0.0])?;
let key = HiddenToKey::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?;
let value = HiddenToValue::new(vec![vec![1.0], vec![0.0]], vec![0.0])?;
assert!(matches!(
SelfAttentionHead::new(query, key, value),
Err(CtError::ShapeMismatch {
op: "self-attention head",
..
})
));
Ok(())
}
#[test]
fn multi_head_transformer_block_preserves_hidden_sequence_shape() -> CtResult<()> {
let block = tiny_multi_head_block()?;
let hidden = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
let output = block.apply(hidden)?;
assert_eq!(block.head_count().value(), 2);
assert_eq!(block.value_dimension().value(), 1);
assert_eq!(output.sequence_len().value(), 2);
assert_eq!(output.model_dimension().value(), 2);
assert!(
output
.rows()
.iter()
.flat_map(|row| row.as_slice())
.all(|value| value.is_finite())
);
Ok(())
}
#[test]
fn multi_head_transformer_block_rejects_value_dimension_mismatch() -> CtResult<()> {
let head_a = tiny_self_attention_head_first_feature()?;
let head_b = SelfAttentionHead::new(
HiddenToQuery::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
HiddenToKey::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
HiddenToValue::new(vec![vec![0.0, 1.0], vec![1.0, 0.0]], vec![0.0, 0.0])?,
)?;
let model_dimension = ModelDimension::new(2)?;
assert!(matches!(
MultiHeadTransformerBlock::new(
vec![head_a, head_b],
AttentionOutputProjection::new(
vec![vec![1.0, 0.0], vec![0.0, 1.0]],
vec![0.0, 0.0],
)?,
LayerNormalization::new(LayerNormParameters::identity(model_dimension)),
identity_feed_forward()?,
LayerNormalization::new(LayerNormParameters::identity(model_dimension)),
),
Err(CtError::ShapeMismatch {
op: "multi-head block",
..
})
));
Ok(())
}
#[test]
fn multi_head_transformer_block_rejects_output_projection_input_mismatch() -> CtResult<()> {
let model_dimension = ModelDimension::new(2)?;
assert!(matches!(
MultiHeadTransformerBlock::new(
vec![
tiny_self_attention_head_first_feature()?,
tiny_self_attention_head_second_feature()?,
],
AttentionOutputProjection::new(
vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![0.0, 0.0]],
vec![0.0, 0.0],
)?,
LayerNormalization::new(LayerNormParameters::identity(model_dimension)),
identity_feed_forward()?,
LayerNormalization::new(LayerNormParameters::identity(model_dimension)),
),
Err(CtError::ShapeMismatch {
op: "multi-head block output projection input",
..
})
));
Ok(())
}
#[test]
fn multi_head_transformer_block_rejects_apply_dimension_mismatch() -> CtResult<()> {
let block = tiny_multi_head_block()?;
let hidden = HiddenSequence::new(vec![vec![1.0, 0.0, 0.0]])?;
assert!(matches!(
block.apply(hidden),
Err(CtError::ShapeMismatch {
op: "multi-head block",
..
})
));
Ok(())
}
#[test]
fn masked_multi_head_transformer_block_preserves_hidden_sequence_shape() -> CtResult<()> {
let block = tiny_masked_multi_head_block()?;
let hidden = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
let mask = AttentionMask::new(vec![vec![true, false], vec![true, true]])?;
let output = block.apply(Product::new(hidden, mask))?;
assert_eq!(block.head_count().value(), 2);
assert_eq!(output.sequence_len().value(), 2);
assert_eq!(output.model_dimension().value(), 2);
assert!(
output
.rows()
.iter()
.flat_map(|row| row.as_slice())
.all(|value| value.is_finite())
);
Ok(())
}
#[test]
fn masked_multi_head_transformer_block_rejects_mask_shape_mismatch() -> CtResult<()> {
let block = tiny_masked_multi_head_block()?;
let hidden = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
let mask = AttentionMask::new(vec![vec![true, true, true], vec![true, true, true]])?;
assert!(matches!(
block.apply(Product::new(hidden, mask)),
Err(CtError::ShapeMismatch {
op: "masked attention scores",
..
})
));
Ok(())
}
#[test]
fn transformer_readout_maps_each_hidden_position_to_logits() -> CtResult<()> {
let readout = TransformerReadout::new(
vec![vec![1.0, 0.0, 0.5], vec![0.0, 1.0, -0.5]],
vec![0.0, 0.1, -0.1],
)?;
let hidden = HiddenSequence::new(vec![vec![2.0, 3.0], vec![4.0, 5.0]])?;
let logits = readout.apply(hidden)?;
assert_eq!(logits.sequence_len().value(), 2);
assert_eq!(logits.vocab_size().value(), 3);
assert_eq!(logits.rows()[0].as_slice(), &[2.0, 3.1, -0.6]);
assert_eq!(logits.rows()[1].as_slice(), &[4.0, 5.1, -0.6]);
Ok(())
}
#[test]
fn tiny_transformer_parameters_forward_maps_hidden_and_mask_to_sequence_logits() -> CtResult<()>
{
let parameters = tiny_transformer_parameters()?;
let hidden = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
let mask = AttentionMask::new(vec![vec![true, false], vec![true, true]])?;
let logits = parameters.apply(Product::new(hidden, mask))?;
assert_eq!(parameters.model_dimension().value(), 2);
assert_eq!(parameters.max_sequence_len().value(), 2);
assert_eq!(logits.sequence_len().value(), 2);
assert_eq!(logits.vocab_size().value(), 3);
assert!(
logits
.rows()
.iter()
.flat_map(|row| row.as_slice())
.all(|value| value.is_finite())
);
Ok(())
}
#[test]
fn tiny_transformer_parameters_rejects_readout_dimension_mismatch() -> CtResult<()> {
let positional_encoding = PositionalEncoding::new(vec![vec![0.1, 0.0], vec![0.0, 0.1]])?;
let block = tiny_masked_multi_head_block()?;
let readout = TransformerReadout::new(vec![vec![1.0], vec![0.0], vec![0.5]], vec![0.0])?;
assert!(matches!(
TinyTransformerParameters::new(positional_encoding, block, readout),
Err(CtError::ShapeMismatch {
op: "tiny transformer parameters readout",
..
})
));
Ok(())
}
#[test]
fn transformer_training_state_records_updated_parameters_and_step_count() -> CtResult<()> {
let initial_parameters = tiny_transformer_parameters()?;
let updated_parameters = tiny_transformer_parameters()?;
let state = TransformerTrainingState::new(initial_parameters, LearningRate::new(0.25)?);
let next_state = state.record_updated_parameters(updated_parameters.clone());
assert_eq!(next_state.parameters(), &updated_parameters);
assert_eq!(next_state.learning_rate().value(), 0.25);
assert_eq!(next_state.step_count().value(), 1);
Ok(())
}
#[test]
fn transformer_training_state_forward_uses_structured_parameters() -> CtResult<()> {
let state =
TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.1)?);
let hidden = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
let mask = AttentionMask::new(vec![vec![true, false], vec![true, true]])?;
let logits = state.apply(Product::new(hidden, mask))?;
assert_eq!(logits.sequence_len().value(), 2);
assert_eq!(logits.vocab_size().value(), 3);
assert_eq!(state.step_count().value(), 0);
Ok(())
}
#[test]
fn transformer_readout_training_example_rejects_target_length_mismatch() -> CtResult<()> {
let hidden = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
let mask = AttentionMask::new(vec![vec![true, false], vec![true, true]])?;
let targets = TokenSequence::from_indices([0])?;
assert!(matches!(
TransformerReadoutTrainingExample::new(hidden, mask, targets),
Err(CtError::ShapeMismatch {
op: "transformer readout training targets",
..
})
));
Ok(())
}
#[test]
fn transformer_readout_training_example_rejects_mask_shape_mismatch() -> CtResult<()> {
let hidden = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
let mask = AttentionMask::new(vec![vec![true, true, true], vec![true, true, true]])?;
let targets = TokenSequence::from_indices([0, 1])?;
assert!(matches!(
TransformerReadoutTrainingExample::new(hidden, mask, targets),
Err(CtError::ShapeMismatch {
op: "transformer readout training mask",
..
})
));
Ok(())
}
#[test]
fn transformer_readout_train_step_reduces_sequence_loss() -> CtResult<()> {
let dataset = tiny_transformer_training_set()?;
let state =
TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.5)?);
let before = transformer_readout_average_loss(&state, &dataset)?;
let train_step = TransformerReadoutTrainStep::new(dataset.clone());
let trained =
crate::category::apply_endomorphism_n_times(&train_step, state, StepCount::new(40))?;
let after = transformer_readout_average_loss(&trained, &dataset)?;
assert!(after.value() < before.value());
assert_eq!(trained.step_count().value(), 40);
Ok(())
}
#[test]
fn transformer_readout_train_step_rejects_target_outside_vocabulary() -> CtResult<()> {
let example = TransformerReadoutTrainingExample::new(
HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?,
AttentionMask::new(vec![vec![true, false], vec![true, true]])?,
TokenSequence::from_indices([0, 9])?,
)?;
let dataset = TransformerReadoutTrainingSet::new([example])?;
let state =
TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.1)?);
let train_step = TransformerReadoutTrainStep::new(dataset);
assert!(matches!(
train_step.apply(state),
Err(CtError::OutOfRange {
kind: "sequence target",
index: 9,
limit: 3,
})
));
Ok(())
}
#[test]
fn transformer_feed_forward_training_example_rejects_target_shape_mismatch() -> CtResult<()> {
let input = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
let target = HiddenSequence::new(vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]])?;
assert!(matches!(
TransformerFeedForwardTrainingExample::new(input, target),
Err(CtError::ShapeMismatch {
op: "transformer feed-forward training dimension",
..
})
));
Ok(())
}
#[test]
fn transformer_feed_forward_training_set_rejects_empty_input() {
assert!(matches!(
TransformerFeedForwardTrainingSet::new([]),
Err(CtError::EmptyInput("transformer feed-forward training set"))
));
}
#[test]
fn transformer_feed_forward_train_step_reduces_local_hidden_loss() -> CtResult<()> {
let dataset = tiny_feed_forward_training_set()?;
let state =
TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
let before = transformer_feed_forward_average_loss(&state, &dataset)?;
let train_step = TransformerFeedForwardTrainStep::new(dataset.clone());
let trained =
crate::category::apply_endomorphism_n_times(&train_step, state, StepCount::new(60))?;
let after = transformer_feed_forward_average_loss(&trained, &dataset)?;
assert!(after.value() < before.value());
assert_eq!(trained.step_count().value(), 60);
Ok(())
}
#[test]
fn transformer_block_training_example_rejects_mask_shape_mismatch() -> CtResult<()> {
let hidden = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
let mask = AttentionMask::new(vec![vec![true, true, true], vec![true, true, true]])?;
let targets = TokenSequence::from_indices([0, 1])?;
assert!(matches!(
TransformerBlockTrainingExample::new(hidden, mask, targets),
Err(CtError::ShapeMismatch {
op: "transformer block training mask",
..
})
));
Ok(())
}
#[test]
fn transformer_block_training_set_rejects_empty_input() {
assert!(matches!(
TransformerBlockTrainingSet::new([]),
Err(CtError::EmptyInput("transformer block training set"))
));
}
#[test]
fn transformer_block_train_step_rejects_target_outside_vocabulary() -> CtResult<()> {
let example = TransformerBlockTrainingExample::new(
HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?,
AttentionMask::new(vec![vec![true, false], vec![true, true]])?,
TokenSequence::from_indices([0, 9])?,
)?;
let dataset = TransformerBlockTrainingSet::new([example])?;
let state =
TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.1)?);
let train_step = TransformerBlockTrainStep::new(dataset);
assert!(matches!(
train_step.apply(state),
Err(CtError::OutOfRange {
kind: "sequence target",
index: 9,
limit: 3,
})
));
Ok(())
}
#[test]
fn transformer_block_train_step_reduces_sequence_loss() -> CtResult<()> {
let dataset = tiny_transformer_block_training_set()?;
let state =
TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
let before = transformer_block_average_loss(&state, &dataset)?;
let train_step = TransformerBlockTrainStep::new(dataset.clone());
let trained =
crate::category::apply_endomorphism_n_times(&train_step, state, StepCount::new(40))?;
let after = transformer_block_average_loss(&trained, &dataset)?;
assert!(after.value() < before.value());
assert_eq!(trained.step_count().value(), 40);
Ok(())
}
#[test]
fn transformer_block_train_step_updates_attention_output_projection() -> CtResult<()> {
let dataset = tiny_transformer_block_training_set()?;
let state =
TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
let before = state.parameters().output_projection().weight().to_vec();
let train_step = TransformerBlockTrainStep::new(dataset);
let trained = train_step.apply(state)?;
let after = trained.parameters().output_projection().weight().to_vec();
assert_ne!(before, after);
assert_eq!(trained.step_count().value(), 1);
Ok(())
}
#[test]
fn transformer_block_train_step_updates_layer_norm_parameters() -> CtResult<()> {
let dataset = tiny_transformer_block_training_set()?;
let state =
TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
let before_attention_scale = state
.parameters()
.attention_norm()
.parameters()
.scale()
.to_vec();
let before_feed_forward_shift = state
.parameters()
.feed_forward_norm()
.parameters()
.shift()
.to_vec();
let train_step = TransformerBlockTrainStep::new(dataset);
let trained = train_step.apply(state)?;
let after_attention_scale = trained
.parameters()
.attention_norm()
.parameters()
.scale()
.to_vec();
let after_feed_forward_shift = trained
.parameters()
.feed_forward_norm()
.parameters()
.shift()
.to_vec();
assert_ne!(before_attention_scale, after_attention_scale);
assert_ne!(before_feed_forward_shift, after_feed_forward_shift);
assert_eq!(trained.step_count().value(), 1);
Ok(())
}
#[test]
fn transformer_block_train_step_updates_query_key_value_projections() -> CtResult<()> {
let dataset = tiny_transformer_block_training_set()?;
let state =
TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
let before_query = state
.parameters()
.attention_heads()
.iter()
.map(|head| head.query_projection().weight().to_vec())
.collect::<Vec<_>>();
let before_key = state
.parameters()
.attention_heads()
.iter()
.map(|head| head.key_projection().weight().to_vec())
.collect::<Vec<_>>();
let before_value = state
.parameters()
.attention_heads()
.iter()
.map(|head| head.value_projection().weight().to_vec())
.collect::<Vec<_>>();
let train_step = TransformerBlockTrainStep::new(dataset);
let trained = train_step.apply(state)?;
let after_query = trained
.parameters()
.attention_heads()
.iter()
.map(|head| head.query_projection().weight().to_vec())
.collect::<Vec<_>>();
let after_key = trained
.parameters()
.attention_heads()
.iter()
.map(|head| head.key_projection().weight().to_vec())
.collect::<Vec<_>>();
let after_value = trained
.parameters()
.attention_heads()
.iter()
.map(|head| head.value_projection().weight().to_vec())
.collect::<Vec<_>>();
assert_ne!(before_query, after_query);
assert_ne!(before_key, after_key);
assert_ne!(before_value, after_value);
assert_eq!(trained.step_count().value(), 1);
Ok(())
}
#[test]
fn transformer_block_train_step_matches_finite_difference_for_attention_projection()
-> CtResult<()> {
let dataset = tiny_transformer_block_training_set()?;
let state =
TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
let train_step = TransformerBlockTrainStep::new(dataset.clone());
let trained = train_step.apply(state.clone())?;
let selection = largest_changed_attention_projection(&state, &trained)?;
let before_value = attention_projection_weight(&state, selection)?;
let after_value = attention_projection_weight(&trained, selection)?;
assert_block_gradient_matches_finite_difference(
&state,
&dataset,
before_value,
after_value,
|value| state_with_attention_projection_weight(&state, selection, value),
)
}
#[test]
fn transformer_block_train_step_matches_finite_difference_for_readout_weight() -> CtResult<()> {
let dataset = tiny_transformer_block_training_set()?;
let state =
TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
let train_step = TransformerBlockTrainStep::new(dataset.clone());
let trained = train_step.apply(state.clone())?;
let selection = largest_changed_readout_weight(&state, &trained)?;
let before_value = readout_weight(&state, selection);
let after_value = readout_weight(&trained, selection);
assert_block_gradient_matches_finite_difference(
&state,
&dataset,
before_value,
after_value,
|value| state_with_readout_weight(&state, selection, value),
)
}
#[test]
fn transformer_block_train_step_matches_finite_difference_for_feed_forward_weight()
-> CtResult<()> {
let dataset = tiny_transformer_block_training_set()?;
let state =
TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
let train_step = TransformerBlockTrainStep::new(dataset.clone());
let trained = train_step.apply(state.clone())?;
let selection = largest_changed_feed_forward_weight(&state, &trained)?;
let before_value = feed_forward_weight(&state, selection);
let after_value = feed_forward_weight(&trained, selection);
assert_block_gradient_matches_finite_difference(
&state,
&dataset,
before_value,
after_value,
|value| state_with_feed_forward_weight(&state, selection, value),
)
}
#[test]
fn transformer_block_train_step_matches_finite_difference_for_layer_norm_parameter()
-> CtResult<()> {
let dataset = tiny_transformer_block_training_set()?;
let state =
TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
let train_step = TransformerBlockTrainStep::new(dataset.clone());
let trained = train_step.apply(state.clone())?;
let selection = largest_changed_layer_norm_parameter(&state, &trained)?;
let before_value = layer_norm_parameter_value(&state, selection);
let after_value = layer_norm_parameter_value(&trained, selection);
assert_block_gradient_matches_finite_difference(
&state,
&dataset,
before_value,
after_value,
|value| state_with_layer_norm_parameter(&state, selection, value),
)
}
#[test]
fn transformer_block_train_step_matches_finite_difference_for_readout_bias() -> CtResult<()> {
let dataset = tiny_transformer_block_training_set()?;
let state =
TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
let train_step = TransformerBlockTrainStep::new(dataset.clone());
let trained = train_step.apply(state.clone())?;
let selection = largest_changed_vector_index(
state.parameters().readout().bias(),
trained.parameters().readout().bias(),
"changed readout bias",
)?;
let before_value = state.parameters().readout().bias()[selection];
let after_value = trained.parameters().readout().bias()[selection];
assert_block_gradient_matches_finite_difference(
&state,
&dataset,
before_value,
after_value,
|value| state_with_readout_bias(&state, selection, value),
)
}
#[test]
fn transformer_block_train_step_matches_finite_difference_for_feed_forward_bias() -> CtResult<()>
{
let dataset = tiny_transformer_block_training_set()?;
let state =
TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
let train_step = TransformerBlockTrainStep::new(dataset.clone());
let trained = train_step.apply(state.clone())?;
let selection = largest_changed_feed_forward_bias(&state, &trained)?;
let before_value = feed_forward_bias(&state, selection);
let after_value = feed_forward_bias(&trained, selection);
assert_block_gradient_matches_finite_difference(
&state,
&dataset,
before_value,
after_value,
|value| state_with_feed_forward_bias(&state, selection, value),
)
}
#[test]
fn transformer_block_train_step_matches_finite_difference_for_output_projection_bias()
-> CtResult<()> {
let dataset = tiny_transformer_block_training_set()?;
let state =
TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
let train_step = TransformerBlockTrainStep::new(dataset.clone());
let trained = train_step.apply(state.clone())?;
let selection = largest_changed_vector_index(
state.parameters().output_projection().bias(),
trained.parameters().output_projection().bias(),
"changed attention output projection bias",
)?;
let before_value = state.parameters().output_projection().bias()[selection];
let after_value = trained.parameters().output_projection().bias()[selection];
assert_block_gradient_matches_finite_difference(
&state,
&dataset,
before_value,
after_value,
|value| state_with_output_projection_bias(&state, selection, value),
)
}
#[test]
fn transformer_block_train_step_matches_finite_difference_for_attention_projection_bias()
-> CtResult<()> {
let dataset = tiny_transformer_block_training_set()?;
let state =
TransformerTrainingState::new(tiny_transformer_parameters()?, LearningRate::new(0.2)?);
let train_step = TransformerBlockTrainStep::new(dataset.clone());
let trained = train_step.apply(state.clone())?;
let selection = largest_changed_attention_projection_bias(&state, &trained)?;
let before_value = attention_projection_bias(&state, selection)?;
let after_value = attention_projection_bias(&trained, selection)?;
assert_block_gradient_matches_finite_difference(
&state,
&dataset,
before_value,
after_value,
|value| state_with_attention_projection_bias(&state, selection, value),
)
}
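/// Compares one component of the gradient applied by a single train step with a
/// central finite difference of the average block loss. The applied gradient is
/// recovered as (before - after) / learning_rate, which assumes the step is plain
/// gradient descent (theta' = theta - learning_rate * grad).
fn assert_block_gradient_matches_finite_difference(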
state: &TransformerTrainingState,
dataset: &TransformerBlockTrainingSet,
before_value: f32,
after_value: f32,
mut state_with_value: impl FnMut(f32) -> CtResult<TransformerTrainingState>,
) -> CtResult<()> {
let inferred_gradient = (before_value - after_value) / state.learning_rate().value();
let epsilon = 1e-3;
let loss_plus =
transformer_block_average_loss(&state_with_value(before_value + epsilon)?, dataset)?
.value();
let loss_minus =
transformer_block_average_loss(&state_with_value(before_value - epsilon)?, dataset)?
.value();
let finite_difference = (loss_plus - loss_minus) / (2.0 * epsilon);
assert!(
(inferred_gradient - finite_difference).abs() < 1e-2,
"inferred gradient {inferred_gradient} should match finite difference {finite_difference}"
);
Ok(())
}
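// Sketch (hypothetical, std-only): the gradient check above, run on a loss whose
// derivative is known. Like the helper, it assumes one train step is plain
// gradient descent, `w' = w - lr * dL/dw`, so `(w - w') / lr` recovers the applied
// gradient; the central difference `(L(w + eps) - L(w - eps)) / (2 * eps)`
// estimates the same derivative with O(eps^2) truncation error.
fn sketch_gradient_check(w: f32, learning_rate: f32, epsilon: f32) -> (f32, f32) {
    // Toy loss L(x) = (x - 3)^2 with exact derivative 2 * (x - 3).
    let loss = |x: f32| (x - 3.0) * (x - 3.0);
    let gradient = 2.0 * (w - 3.0);
    // One gradient-descent step, then recover the gradient from the parameter change.
    let stepped = w - learning_rate * gradient;
    let inferred = (w - stepped) / learning_rate;
    // Central finite difference around the original parameter value.
    let finite_difference = (loss(w + epsilon) - loss(w - epsilon)) / (2.0 * epsilon);
    (inferred, finite_difference)
}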
fn largest_changed_vector_index(
before: &[f32],
after: &[f32],
label: &'static str,
) -> CtResult<usize> {
let mut selected = None;
let mut largest_delta = 0.0;
for (index, (before_value, after_value)) in before.iter().zip(after).enumerate() {
let delta = (before_value - after_value).abs();
if delta > largest_delta {
largest_delta = delta;
selected = Some(index);
}
}
selected.ok_or(CtError::EmptyInput(label))
}
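// Sketch (hypothetical, std-only): the same selection as
// `largest_changed_vector_index` above, written as a single iterator fold. Strict
// `>` keeps the first index on ties, and a slice with no changed entry yields
// `None` where the helper above returns `CtError::EmptyInput`.
fn sketch_largest_changed_index(before: &[f32], after: &[f32]) -> Option<usize> {
    before
        .iter()
        .zip(after)
        .map(|(b, a)| (b - a).abs())
        .enumerate()
        .fold(None, |best: Option<(usize, f32)>, (index, delta)| match best {
            // A delta that does not beat the current best is ignored (first index wins ties).
            Some((_, largest)) if delta <= largest => best,
            // Otherwise any strictly positive change becomes the new best.
            _ if delta > 0.0 => Some((index, delta)),
            // Zero (or NaN) deltas never select an index.
            _ => best,
        })
        .map(|(index, _)| index)
}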
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct MatrixSelection {
input_index: usize,
output_index: usize,
}
fn largest_changed_matrix_weight(
before: &[Vec<f32>],
after: &[Vec<f32>],
label: &'static str,
) -> CtResult<MatrixSelection> {
let mut selected = None;
let mut largest_delta = 0.0;
for (input_index, (before_row, after_row)) in before.iter().zip(after).enumerate() {
for (output_index, (before_value, after_value)) in
before_row.iter().zip(after_row).enumerate()
{
let delta = (before_value - after_value).abs();
if delta > largest_delta {
largest_delta = delta;
selected = Some(MatrixSelection {
input_index,
output_index,
});
}
}
}
selected.ok_or(CtError::EmptyInput(label))
}
fn largest_changed_readout_weight(
before: &TransformerTrainingState,
after: &TransformerTrainingState,
) -> CtResult<MatrixSelection> {
largest_changed_matrix_weight(
before.parameters().readout().weight(),
after.parameters().readout().weight(),
"changed readout weight",
)
}
fn readout_weight(state: &TransformerTrainingState, selection: MatrixSelection) -> f32 {
state.parameters().readout().weight()[selection.input_index][selection.output_index]
}
fn state_with_readout_weight(
state: &TransformerTrainingState,
selection: MatrixSelection,
value: f32,
) -> CtResult<TransformerTrainingState> {
let readout = state.parameters().readout();
let mut weight = readout.weight().to_vec();
weight[selection.input_index][selection.output_index] = value;
let readout = TransformerReadout::new(weight, readout.bias().to_vec())?;
let parameters = state.parameters().clone().with_readout(readout)?;
Ok(TransformerTrainingState::from_parts(
parameters,
state.learning_rate(),
state.step_count(),
))
}
fn state_with_readout_bias(
state: &TransformerTrainingState,
selection: usize,
value: f32,
) -> CtResult<TransformerTrainingState> {
let readout = state.parameters().readout();
let mut bias = readout.bias().to_vec();
bias[selection] = value;
let readout = TransformerReadout::new(readout.weight().to_vec(), bias)?;
let parameters = state.parameters().clone().with_readout(readout)?;
Ok(TransformerTrainingState::from_parts(
parameters,
state.learning_rate(),
state.step_count(),
))
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum FeedForwardWeightKind {
First,
Second,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct FeedForwardWeightSelection {
kind: FeedForwardWeightKind,
matrix: MatrixSelection,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct FeedForwardBiasSelection {
kind: FeedForwardWeightKind,
index: usize,
}
fn largest_changed_feed_forward_weight(
before: &TransformerTrainingState,
after: &TransformerTrainingState,
) -> CtResult<FeedForwardWeightSelection> {
let first = largest_changed_matrix_weight(
before.parameters().feed_forward().first_weight(),
after.parameters().feed_forward().first_weight(),
"changed first feed-forward weight",
);
let second = largest_changed_matrix_weight(
before.parameters().feed_forward().second_weight(),
after.parameters().feed_forward().second_weight(),
"changed second feed-forward weight",
);
match (first, second) {
(Ok(first), Ok(second)) => {
let first_delta = feed_forward_weight_delta(
before,
after,
FeedForwardWeightSelection {
kind: FeedForwardWeightKind::First,
matrix: first,
},
);
let second_delta = feed_forward_weight_delta(
before,
after,
FeedForwardWeightSelection {
kind: FeedForwardWeightKind::Second,
matrix: second,
},
);
if first_delta >= second_delta {
Ok(FeedForwardWeightSelection {
kind: FeedForwardWeightKind::First,
matrix: first,
})
} else {
Ok(FeedForwardWeightSelection {
kind: FeedForwardWeightKind::Second,
matrix: second,
})
}
}
(Ok(first), Err(_)) => Ok(FeedForwardWeightSelection {
kind: FeedForwardWeightKind::First,
matrix: first,
}),
(Err(_), Ok(second)) => Ok(FeedForwardWeightSelection {
kind: FeedForwardWeightKind::Second,
matrix: second,
}),
(Err(_), Err(_)) => Err(CtError::EmptyInput("changed feed-forward weight")),
}
}
fn feed_forward_weight_delta(
before: &TransformerTrainingState,
after: &TransformerTrainingState,
selection: FeedForwardWeightSelection,
) -> f32 {
(feed_forward_weight(before, selection) - feed_forward_weight(after, selection)).abs()
}
fn feed_forward_weight(
state: &TransformerTrainingState,
selection: FeedForwardWeightSelection,
) -> f32 {
let feed_forward = state.parameters().feed_forward();
match selection.kind {
FeedForwardWeightKind::First => feed_forward.first_weight()
[selection.matrix.input_index][selection.matrix.output_index],
FeedForwardWeightKind::Second => feed_forward.second_weight()
[selection.matrix.input_index][selection.matrix.output_index],
}
}
fn state_with_feed_forward_weight(
state: &TransformerTrainingState,
selection: FeedForwardWeightSelection,
value: f32,
) -> CtResult<TransformerTrainingState> {
let feed_forward = state.parameters().feed_forward();
let mut first_weight = feed_forward.first_weight().to_vec();
let mut second_weight = feed_forward.second_weight().to_vec();
match selection.kind {
FeedForwardWeightKind::First => {
first_weight[selection.matrix.input_index][selection.matrix.output_index] = value;
}
FeedForwardWeightKind::Second => {
second_weight[selection.matrix.input_index][selection.matrix.output_index] = value;
}
}
let feed_forward = PositionWiseFeedForward::new(
first_weight,
feed_forward.first_bias().to_vec(),
second_weight,
feed_forward.second_bias().to_vec(),
)?;
let parameters = state.parameters().clone().with_feed_forward(feed_forward)?;
Ok(TransformerTrainingState::from_parts(
parameters,
state.learning_rate(),
state.step_count(),
))
}
fn largest_changed_feed_forward_bias(
before: &TransformerTrainingState,
after: &TransformerTrainingState,
) -> CtResult<FeedForwardBiasSelection> {
let first = largest_changed_vector_index(
before.parameters().feed_forward().first_bias(),
after.parameters().feed_forward().first_bias(),
"changed first feed-forward bias",
);
let second = largest_changed_vector_index(
before.parameters().feed_forward().second_bias(),
after.parameters().feed_forward().second_bias(),
"changed second feed-forward bias",
);
match (first, second) {
(Ok(first), Ok(second)) => {
let first_delta = feed_forward_bias_delta(
before,
after,
FeedForwardBiasSelection {
kind: FeedForwardWeightKind::First,
index: first,
},
);
let second_delta = feed_forward_bias_delta(
before,
after,
FeedForwardBiasSelection {
kind: FeedForwardWeightKind::Second,
index: second,
},
);
if first_delta >= second_delta {
Ok(FeedForwardBiasSelection {
kind: FeedForwardWeightKind::First,
index: first,
})
} else {
Ok(FeedForwardBiasSelection {
kind: FeedForwardWeightKind::Second,
index: second,
})
}
}
(Ok(first), Err(_)) => Ok(FeedForwardBiasSelection {
kind: FeedForwardWeightKind::First,
index: first,
}),
(Err(_), Ok(second)) => Ok(FeedForwardBiasSelection {
kind: FeedForwardWeightKind::Second,
index: second,
}),
(Err(_), Err(_)) => Err(CtError::EmptyInput("changed feed-forward bias")),
}
}
fn feed_forward_bias_delta(
before: &TransformerTrainingState,
after: &TransformerTrainingState,
selection: FeedForwardBiasSelection,
) -> f32 {
(feed_forward_bias(before, selection) - feed_forward_bias(after, selection)).abs()
}
fn feed_forward_bias(
state: &TransformerTrainingState,
selection: FeedForwardBiasSelection,
) -> f32 {
let feed_forward = state.parameters().feed_forward();
match selection.kind {
FeedForwardWeightKind::First => feed_forward.first_bias()[selection.index],
FeedForwardWeightKind::Second => feed_forward.second_bias()[selection.index],
}
}
fn state_with_feed_forward_bias(
state: &TransformerTrainingState,
selection: FeedForwardBiasSelection,
value: f32,
) -> CtResult<TransformerTrainingState> {
let feed_forward = state.parameters().feed_forward();
let mut first_bias = feed_forward.first_bias().to_vec();
let mut second_bias = feed_forward.second_bias().to_vec();
match selection.kind {
FeedForwardWeightKind::First => {
first_bias[selection.index] = value;
}
FeedForwardWeightKind::Second => {
second_bias[selection.index] = value;
}
}
let feed_forward = PositionWiseFeedForward::new(
feed_forward.first_weight().to_vec(),
first_bias,
feed_forward.second_weight().to_vec(),
second_bias,
)?;
let parameters = state.parameters().clone().with_feed_forward(feed_forward)?;
Ok(TransformerTrainingState::from_parts(
parameters,
state.learning_rate(),
state.step_count(),
))
}
fn state_with_output_projection_bias(
state: &TransformerTrainingState,
selection: usize,
value: f32,
) -> CtResult<TransformerTrainingState> {
let output_projection = state.parameters().output_projection();
let mut bias = output_projection.bias().to_vec();
bias[selection] = value;
let output_projection =
AttentionOutputProjection::new(output_projection.weight().to_vec(), bias)?;
let parameters = state
.parameters()
.clone()
.with_output_projection(output_projection)?;
Ok(TransformerTrainingState::from_parts(
parameters,
state.learning_rate(),
state.step_count(),
))
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum LayerNormParameterKind {
AttentionScale,
AttentionShift,
FeedForwardScale,
FeedForwardShift,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct LayerNormParameterSelection {
kind: LayerNormParameterKind,
feature_index: usize,
}
fn largest_changed_layer_norm_parameter(
before: &TransformerTrainingState,
after: &TransformerTrainingState,
) -> CtResult<LayerNormParameterSelection> {
let mut selected = None;
let mut largest_delta = 0.0;
for kind in [
LayerNormParameterKind::AttentionScale,
LayerNormParameterKind::AttentionShift,
LayerNormParameterKind::FeedForwardScale,
LayerNormParameterKind::FeedForwardShift,
] {
let before_values = layer_norm_parameter_values(before, kind);
let after_values = layer_norm_parameter_values(after, kind);
for (feature_index, (before_value, after_value)) in
before_values.iter().zip(after_values).enumerate()
{
let delta = (before_value - after_value).abs();
if delta > largest_delta {
largest_delta = delta;
selected = Some(LayerNormParameterSelection {
kind,
feature_index,
});
}
}
}
selected.ok_or(CtError::EmptyInput("changed layer norm parameter"))
}
fn layer_norm_parameter_values(
state: &TransformerTrainingState,
kind: LayerNormParameterKind,
) -> &[f32] {
match kind {
LayerNormParameterKind::AttentionScale => {
state.parameters().attention_norm().parameters().scale()
}
LayerNormParameterKind::AttentionShift => {
state.parameters().attention_norm().parameters().shift()
}
LayerNormParameterKind::FeedForwardScale => {
state.parameters().feed_forward_norm().parameters().scale()
}
LayerNormParameterKind::FeedForwardShift => {
state.parameters().feed_forward_norm().parameters().shift()
}
}
}
fn layer_norm_parameter_value(
state: &TransformerTrainingState,
selection: LayerNormParameterSelection,
) -> f32 {
layer_norm_parameter_values(state, selection.kind)[selection.feature_index]
}
fn state_with_layer_norm_parameter(
state: &TransformerTrainingState,
selection: LayerNormParameterSelection,
value: f32,
) -> CtResult<TransformerTrainingState> {
let attention_parameters = state.parameters().attention_norm().parameters();
let feed_forward_parameters = state.parameters().feed_forward_norm().parameters();
let mut attention_scale = attention_parameters.scale().to_vec();
let mut attention_shift = attention_parameters.shift().to_vec();
let mut feed_forward_scale = feed_forward_parameters.scale().to_vec();
let mut feed_forward_shift = feed_forward_parameters.shift().to_vec();
match selection.kind {
LayerNormParameterKind::AttentionScale => {
attention_scale[selection.feature_index] = value;
}
LayerNormParameterKind::AttentionShift => {
attention_shift[selection.feature_index] = value;
}
LayerNormParameterKind::FeedForwardScale => {
feed_forward_scale[selection.feature_index] = value;
}
LayerNormParameterKind::FeedForwardShift => {
feed_forward_shift[selection.feature_index] = value;
}
}
let attention_norm = LayerNormalization::new(LayerNormParameters::new(
attention_scale,
attention_shift,
attention_parameters.epsilon(),
)?);
let feed_forward_norm = LayerNormalization::new(LayerNormParameters::new(
feed_forward_scale,
feed_forward_shift,
feed_forward_parameters.epsilon(),
)?);
let parameters = state
.parameters()
.clone()
.with_layer_norms(attention_norm, feed_forward_norm)?;
Ok(TransformerTrainingState::from_parts(
parameters,
state.learning_rate(),
state.step_count(),
))
}
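// Sketch (hypothetical, std-only): why the layer-norm tests can expect both `scale`
// and `shift` to move after one step. It assumes the textbook per-position formula
// y_i = scale_i * (x_i - mean) / sqrt(var + epsilon) + shift_i with the biased
// (divide-by-n) variance; the crate's exact convention is not shown in this
// section. `shift` enters linearly and `scale` multiplies the normalized value, so
// both generally receive gradient whenever the loss depends on y.
fn sketch_layer_norm(x: &[f32], scale: &[f32], shift: &[f32], epsilon: f32) -> Vec<f32> {
    let n = x.len() as f32;
    let mean = x.iter().sum::<f32>() / n;
    let variance = x.iter().map(|v| (v - mean) * (v - mean)).sum::<f32>() / n;
    let inv_std = 1.0 / (variance + epsilon).sqrt();
    // Normalize each feature, then apply the learned per-feature scale and shift.
    x.iter()
        .zip(scale.iter().zip(shift))
        .map(|(v, (g, b))| g * (v - mean) * inv_std + b)
        .collect()
}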
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum AttentionProjectionKind {
Query,
Key,
Value,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct AttentionProjectionSelection {
head_index: usize,
kind: AttentionProjectionKind,
input_index: usize,
output_index: usize,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct AttentionProjectionBiasSelection {
head_index: usize,
kind: AttentionProjectionKind,
output_index: usize,
}
fn largest_changed_attention_projection(
before: &TransformerTrainingState,
after: &TransformerTrainingState,
) -> CtResult<AttentionProjectionSelection> {
let head_count = before.parameters().attention_heads().len();
let mut selected = None;
let mut largest_delta = 0.0;
for head_index in 0..head_count {
for kind in [
AttentionProjectionKind::Query,
AttentionProjectionKind::Key,
AttentionProjectionKind::Value,
] {
let before_weight = attention_projection_weight_matrix(before, head_index, kind)?;
let after_weight = attention_projection_weight_matrix(after, head_index, kind)?;
for (input_index, (before_row, after_row)) in
before_weight.iter().zip(after_weight).enumerate()
{
for (output_index, (before_value, after_value)) in
before_row.iter().zip(after_row).enumerate()
{
let delta = (before_value - after_value).abs();
if delta > largest_delta {
largest_delta = delta;
selected = Some(AttentionProjectionSelection {
head_index,
kind,
input_index,
output_index,
});
}
}
}
}
}
selected.ok_or(CtError::EmptyInput("changed attention projection"))
}
fn attention_projection_weight(
state: &TransformerTrainingState,
selection: AttentionProjectionSelection,
) -> CtResult<f32> {
let weight =
attention_projection_weight_matrix(state, selection.head_index, selection.kind)?;
Ok(weight[selection.input_index][selection.output_index])
}
fn attention_projection_weight_matrix(
state: &TransformerTrainingState,
head_index: usize,
kind: AttentionProjectionKind,
) -> CtResult<&[Vec<f32>]> {
let head =
state
.parameters()
.attention_heads()
.get(head_index)
.ok_or(CtError::OutOfRange {
kind: "attention head",
index: head_index,
limit: state.parameters().attention_heads().len(),
})?;
Ok(match kind {
AttentionProjectionKind::Query => head.query_projection().weight(),
AttentionProjectionKind::Key => head.key_projection().weight(),
AttentionProjectionKind::Value => head.value_projection().weight(),
})
}
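// Sketch (hypothetical, std-only): the weight layout the selection helpers above
// assume. Weights are indexed `weight[input_index][output_index]`, which suggests a
// projection computes out[j] = bias[j] + sum_i x[i] * weight[i][j]; that reading is
// an assumption inferred from the index names, not taken from the crate.
fn sketch_project(x: &[f32], weight: &[Vec<f32>], bias: &[f32]) -> Vec<f32> {
    bias.iter()
        .enumerate()
        // Each output j sums input i times the weight stored at row i, column j.
        .map(|(j, b)| b + x.iter().zip(weight).map(|(xi, row)| xi * row[j]).sum::<f32>())
        .collect()
}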
fn largest_changed_attention_projection_bias(
before: &TransformerTrainingState,
after: &TransformerTrainingState,
) -> CtResult<AttentionProjectionBiasSelection> {
let head_count = before.parameters().attention_heads().len();
let mut selected = None;
let mut largest_delta = 0.0;
for head_index in 0..head_count {
for kind in [
AttentionProjectionKind::Query,
AttentionProjectionKind::Key,
AttentionProjectionKind::Value,
] {
let before_bias = attention_projection_bias_values(before, head_index, kind)?;
let after_bias = attention_projection_bias_values(after, head_index, kind)?;
for (output_index, (before_value, after_value)) in
before_bias.iter().zip(after_bias).enumerate()
{
let delta = (before_value - after_value).abs();
if delta > largest_delta {
largest_delta = delta;
selected = Some(AttentionProjectionBiasSelection {
head_index,
kind,
output_index,
});
}
}
}
}
selected.ok_or(CtError::EmptyInput("changed attention projection bias"))
}
fn attention_projection_bias(
state: &TransformerTrainingState,
selection: AttentionProjectionBiasSelection,
) -> CtResult<f32> {
let bias = attention_projection_bias_values(state, selection.head_index, selection.kind)?;
Ok(bias[selection.output_index])
}
fn attention_projection_bias_values(
state: &TransformerTrainingState,
head_index: usize,
kind: AttentionProjectionKind,
) -> CtResult<&[f32]> {
let head =
state
.parameters()
.attention_heads()
.get(head_index)
.ok_or(CtError::OutOfRange {
kind: "attention head",
index: head_index,
limit: state.parameters().attention_heads().len(),
})?;
Ok(match kind {
AttentionProjectionKind::Query => head.query_projection().bias(),
AttentionProjectionKind::Key => head.key_projection().bias(),
AttentionProjectionKind::Value => head.value_projection().bias(),
})
}
fn state_with_attention_projection_weight(
state: &TransformerTrainingState,
selection: AttentionProjectionSelection,
value: f32,
) -> CtResult<TransformerTrainingState> {
let mut heads = state.parameters().attention_heads().to_vec();
let head = heads.get(selection.head_index).ok_or(CtError::OutOfRange {
kind: "attention head",
index: selection.head_index,
limit: heads.len(),
})?;
let mut query_weight = head.query_projection().weight().to_vec();
let mut key_weight = head.key_projection().weight().to_vec();
let mut value_weight = head.value_projection().weight().to_vec();
match selection.kind {
AttentionProjectionKind::Query => {
query_weight[selection.input_index][selection.output_index] = value;
}
AttentionProjectionKind::Key => {
key_weight[selection.input_index][selection.output_index] = value;
}
AttentionProjectionKind::Value => {
value_weight[selection.input_index][selection.output_index] = value;
}
}
heads[selection.head_index] = SelfAttentionHead::new(
HiddenToQuery::new(query_weight, head.query_projection().bias().to_vec())?,
HiddenToKey::new(key_weight, head.key_projection().bias().to_vec())?,
HiddenToValue::new(value_weight, head.value_projection().bias().to_vec())?,
)?;
let parameters = state.parameters().clone().with_attention_heads(heads)?;
Ok(TransformerTrainingState::from_parts(
parameters,
state.learning_rate(),
state.step_count(),
))
}
fn state_with_attention_projection_bias(
state: &TransformerTrainingState,
selection: AttentionProjectionBiasSelection,
value: f32,
) -> CtResult<TransformerTrainingState> {
let mut heads = state.parameters().attention_heads().to_vec();
let head = heads.get(selection.head_index).ok_or(CtError::OutOfRange {
kind: "attention head",
index: selection.head_index,
limit: heads.len(),
})?;
let mut query_bias = head.query_projection().bias().to_vec();
let mut key_bias = head.key_projection().bias().to_vec();
let mut value_bias = head.value_projection().bias().to_vec();
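// Pick the bias to edit, then bounds-check the output index so a bad
// selection returns `CtError::OutOfRange` instead of panicking.
let bias = match selection.kind {
AttentionProjectionKind::Query => &mut query_bias,
AttentionProjectionKind::Key => &mut key_bias,
AttentionProjectionKind::Value => &mut value_bias,
};
let output_limit = bias.len();
*bias.get_mut(selection.output_index).ok_or(CtError::OutOfRange {
kind: "attention projection bias",
index: selection.output_index,
limit: output_limit,
})? = value;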
heads[selection.head_index] = SelfAttentionHead::new(
HiddenToQuery::new(head.query_projection().weight().to_vec(), query_bias)?,
HiddenToKey::new(head.key_projection().weight().to_vec(), key_bias)?,
HiddenToValue::new(head.value_projection().weight().to_vec(), value_bias)?,
)?;
let parameters = state.parameters().clone().with_attention_heads(heads)?;
Ok(TransformerTrainingState::from_parts(
parameters,
state.learning_rate(),
state.step_count(),
))
}
fn tiny_single_head_block() -> CtResult<SingleHeadTransformerBlock> {
let query = HiddenToQuery::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?;
let key = HiddenToKey::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?;
let value = HiddenToValue::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?;
let output_projection =
AttentionOutputProjection::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?;
let model_dimension = ModelDimension::new(2)?;
let attention_norm =
LayerNormalization::new(LayerNormParameters::identity(model_dimension));
let feed_forward = identity_feed_forward()?;
let feed_forward_norm =
LayerNormalization::new(LayerNormParameters::identity(model_dimension));
SingleHeadTransformerBlock::new(
query,
key,
value,
output_projection,
attention_norm,
feed_forward,
feed_forward_norm,
)
}
fn tiny_multi_head_block() -> CtResult<MultiHeadTransformerBlock> {
let model_dimension = ModelDimension::new(2)?;
MultiHeadTransformerBlock::new(
vec![
tiny_self_attention_head_first_feature()?,
tiny_self_attention_head_second_feature()?,
],
AttentionOutputProjection::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?,
LayerNormalization::new(LayerNormParameters::identity(model_dimension)),
identity_feed_forward()?,
LayerNormalization::new(LayerNormParameters::identity(model_dimension)),
)
}
fn tiny_masked_multi_head_block() -> CtResult<MaskedMultiHeadTransformerBlock> {
let model_dimension = ModelDimension::new(2)?;
MaskedMultiHeadTransformerBlock::new(
vec![
tiny_self_attention_head_first_feature()?,
tiny_self_attention_head_second_feature()?,
],
AttentionOutputProjection::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?,
LayerNormalization::new(LayerNormParameters::identity(model_dimension)),
identity_feed_forward()?,
LayerNormalization::new(LayerNormParameters::identity(model_dimension)),
)
}
fn tiny_transformer_parameters() -> CtResult<TinyTransformerParameters> {
TinyTransformerParameters::new(
PositionalEncoding::new(vec![vec![0.1, 0.0], vec![0.0, 0.1]])?,
tiny_masked_multi_head_block()?,
TransformerReadout::new(
vec![vec![1.0, 0.0, 0.5], vec![0.0, 1.0, -0.5]],
vec![0.0, 0.0, 0.0],
)?,
)
}
fn tiny_transformer_training_set() -> CtResult<TransformerReadoutTrainingSet> {
let example = TransformerReadoutTrainingExample::new(
HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?,
AttentionMask::new(vec![vec![true, false], vec![true, true]])?,
TokenSequence::from_indices([0, 1])?,
)?;
TransformerReadoutTrainingSet::new([example])
}
fn tiny_feed_forward_training_set() -> CtResult<TransformerFeedForwardTrainingSet> {
let example = TransformerFeedForwardTrainingExample::new(
HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?,
HiddenSequence::new(vec![vec![2.0, 0.0], vec![0.0, 2.0]])?,
)?;
TransformerFeedForwardTrainingSet::new([example])
}
fn tiny_transformer_block_training_set() -> CtResult<TransformerBlockTrainingSet> {
let example = TransformerBlockTrainingExample::new(
HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?,
AttentionMask::new(vec![vec![true, false], vec![true, true]])?,
TokenSequence::from_indices([0, 1])?,
)?;
TransformerBlockTrainingSet::new([example])
}
fn tiny_self_attention_head_first_feature() -> CtResult<SelfAttentionHead> {
SelfAttentionHead::new(
HiddenToQuery::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
HiddenToKey::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
HiddenToValue::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
)
}
fn tiny_self_attention_head_second_feature() -> CtResult<SelfAttentionHead> {
SelfAttentionHead::new(
HiddenToQuery::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
HiddenToKey::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
HiddenToValue::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
)
}
fn identity_feed_forward() -> CtResult<PositionWiseFeedForward> {
PositionWiseFeedForward::new(
vec![vec![1.0, 0.0], vec![0.0, 1.0]],
vec![0.0, 0.0],
vec![vec![1.0, 0.0], vec![0.0, 1.0]],
vec![0.0, 0.0],
)
}
}
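The two `state_with_attention_projection_*` helpers above follow a functional-update pattern: copy the heads, change one coordinate, and rebuild a fresh `TransformerTrainingState` instead of mutating the old one. The sketch below isolates that pattern with a hypothetical, simplified `Projection` type (not the crate's `HiddenToQuery`, `HiddenToKey`, or `HiddenToValue`) and bounds-checks every index with `get_mut`. It also shows one common use for such helpers, a central finite-difference estimate of a loss derivative; this excerpt does not show whether the crate's own tests use them that way, and `with_updated_entry` and `finite_difference` are illustrative names only.

```rust
/// Hypothetical, simplified stand-in for one attention projection.
/// Weights are indexed as `weight[input_index][output_index]`.
#[derive(Debug, Clone, PartialEq)]
struct Projection {
    weight: Vec<Vec<f32>>,
}

#[derive(Debug, Clone, PartialEq)]
enum UpdateError {
    OutOfRange { kind: &'static str, index: usize, limit: usize },
}

/// Returns a copy of `projections` with one weight entry replaced by
/// `update(old_value)`. The input slice is never mutated.
fn with_updated_entry(
    projections: &[Projection],
    head_index: usize,
    input_index: usize,
    output_index: usize,
    update: impl FnOnce(f32) -> f32,
) -> Result<Vec<Projection>, UpdateError> {
    let mut projections = projections.to_vec();
    let head_limit = projections.len();
    let head = projections.get_mut(head_index).ok_or(UpdateError::OutOfRange {
        kind: "head",
        index: head_index,
        limit: head_limit,
    })?;
    let input_limit = head.weight.len();
    let row = head.weight.get_mut(input_index).ok_or(UpdateError::OutOfRange {
        kind: "input",
        index: input_index,
        limit: input_limit,
    })?;
    let output_limit = row.len();
    let entry = row.get_mut(output_index).ok_or(UpdateError::OutOfRange {
        kind: "output",
        index: output_index,
        limit: output_limit,
    })?;
    *entry = update(*entry);
    Ok(projections)
}

/// Central finite difference `(L(w + eps) - L(w - eps)) / (2 * eps)` for one
/// weight entry `w`.
fn finite_difference(
    projections: &[Projection],
    head_index: usize,
    input_index: usize,
    output_index: usize,
    eps: f32,
    loss: impl Fn(&[Projection]) -> f32,
) -> Result<f32, UpdateError> {
    let plus = with_updated_entry(projections, head_index, input_index, output_index, |w| w + eps)?;
    let minus = with_updated_entry(projections, head_index, input_index, output_index, |w| w - eps)?;
    Ok((loss(plus.as_slice()) - loss(minus.as_slice())) / (2.0 * eps))
}
```

Because `with_updated_entry` never mutates its input, the `+eps` and `-eps` probes both start from the same original state, which is what the central difference requires.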
src/sketches.rs
//! Rust models for the seven applied-category-theory sketches.
//!
//! This module is a study companion for *Seven Sketches in Compositionality*.
//! It does not try to encode all of category theory. Instead, each section
//! gives one small Rust model for the main structure of a sketch: orders,
//! resources, databases, co-design, signal flow, circuits, and logic of
//! behavior.
use crate::error::{CtError, CtResult};
/// A finite order used for the first sketch: observations can be refined into
/// features, scores, and decisions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum InformationLevel {
Observation,
Feature,
Score,
Decision,
}
impl InformationLevel {
/// Returns true when information at this level can flow into `target`.
pub fn can_flow_to(self, target: Self) -> bool {
self <= target
}
/// Least upper bound in this small total order.
pub fn join(self, other: Self) -> Self {
self.max(other)
}
}
const INFORMATION_LEVELS: [InformationLevel; 4] = [
InformationLevel::Observation,
InformationLevel::Feature,
InformationLevel::Score,
InformationLevel::Decision,
];
/// Checks reflexivity and transitivity for the finite information order.
pub fn information_order_obeys_preorder_laws() -> bool {
for level in INFORMATION_LEVELS {
if !level.can_flow_to(level) {
return false;
}
}
for first in INFORMATION_LEVELS {
for second in INFORMATION_LEVELS {
for third in INFORMATION_LEVELS {
let premise = first.can_flow_to(second) && second.can_flow_to(third);
if premise && !first.can_flow_to(third) {
return false;
}
}
}
}
true
}
/// Number of concrete features in a tiny model.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct FeatureCount(usize);
impl FeatureCount {
pub fn new(value: usize) -> CtResult<Self> {
if value == 0 {
return Err(CtError::EmptyInput("feature count"));
}
Ok(Self(value))
}
pub fn value(&self) -> usize {
self.0
}
}
/// Number of abstract layers used to summarize feature capacity.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct LayerBudget(usize);
impl LayerBudget {
pub fn new(value: usize) -> CtResult<Self> {
if value == 0 {
return Err(CtError::EmptyInput("layer budget"));
}
Ok(Self(value))
}
pub fn value(&self) -> usize {
self.0
}
}
const FEATURES_PER_LAYER: usize = 4;
/// Abstracts concrete features to the minimum layer budget that can hold them.
pub fn abstract_to_layer_budget(features: FeatureCount) -> CtResult<LayerBudget> {
LayerBudget::new(features.value().div_ceil(FEATURES_PER_LAYER))
}
/// Concretizes a layer budget to the largest feature count it can hold.
pub fn concretize_layer_budget(layers: LayerBudget) -> FeatureCount {
FeatureCount(layers.value() * FEATURES_PER_LAYER)
}
/// Checks the Galois-connection law for the feature/layer abstraction:
/// `abstract(features) <= layers` holds exactly when
/// `features <= concretize(layers)`.
pub fn feature_layer_galois_law_holds(
features: FeatureCount,
layers: LayerBudget,
) -> CtResult<bool> {
let abstracted_fits = abstract_to_layer_budget(features)?.value() <= layers.value();
let concrete_fits = features.value() <= concretize_layer_budget(layers).value();
Ok(abstracted_fits == concrete_fits)
}
/// A non-negative amount in a resource bundle.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct ResourceAmount(usize);
impl ResourceAmount {
pub fn new(value: usize) -> Self {
Self(value)
}
pub fn value(&self) -> usize {
self.0
}
}
/// Compute and memory resources composed by component-wise addition.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ResourceBundle {
compute: ResourceAmount,
memory: ResourceAmount,
}
impl ResourceBundle {
pub fn new(compute: ResourceAmount, memory: ResourceAmount) -> Self {
Self { compute, memory }
}
pub fn compute(&self) -> ResourceAmount {
self.compute
}
pub fn memory(&self) -> ResourceAmount {
self.memory
}
/// Monoidal product for independent resources.
pub fn tensor(&self, other: &Self) -> Self {
Self {
compute: ResourceAmount::new(self.compute.value() + other.compute.value()),
memory: ResourceAmount::new(self.memory.value() + other.memory.value()),
}
}
/// Resource preorder: supply can cover demand when every component is large
/// enough.
pub fn can_supply(&self, demand: &Self) -> bool {
self.compute >= demand.compute && self.memory >= demand.memory
}
}
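/// Checks monotonicity of the monoidal product on one concrete instance: if
/// `large_supply` covers `small_supply`, then tensoring both with the same
/// `fixed` bundle preserves that ordering.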
pub fn resource_tensor_is_monotone() -> bool {
let small_supply = ResourceBundle::new(ResourceAmount::new(2), ResourceAmount::new(8));
let large_supply = ResourceBundle::new(ResourceAmount::new(4), ResourceAmount::new(16));
let fixed = ResourceBundle::new(ResourceAmount::new(1), ResourceAmount::new(2));
large_supply.can_supply(&small_supply)
&& large_supply
.tensor(&fixed)
.can_supply(&small_supply.tensor(&fixed))
}
/// Identifier for a department row.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DepartmentId(usize);
impl DepartmentId {
pub fn new(value: usize) -> Self {
Self(value)
}
pub fn value(&self) -> usize {
self.0
}
}
/// Identifier for an employee row.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct EmployeeId(usize);
impl EmployeeId {
pub fn new(value: usize) -> Self {
Self(value)
}
pub fn value(&self) -> usize {
self.0
}
}
/// A row in the employee table. The department field is the schema arrow
/// `Employee -> Department`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EmployeeRecord {
id: EmployeeId,
department: DepartmentId,
}
impl EmployeeRecord {
pub fn new(id: EmployeeId, department: DepartmentId) -> Self {
Self { id, department }
}
pub fn id(&self) -> EmployeeId {
self.id
}
pub fn department(&self) -> DepartmentId {
self.department
}
}
/// A tiny database instance: schema objects become sets of ids, and schema
/// arrows become functions between those sets.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CompanyInstance {
departments: Vec<DepartmentId>,
employees: Vec<EmployeeRecord>,
}
impl CompanyInstance {
pub fn new(
departments: impl IntoIterator<Item = DepartmentId>,
employees: impl IntoIterator<Item = EmployeeRecord>,
) -> CtResult<Self> {
let departments = departments.into_iter().collect::<Vec<_>>();
let employees = employees.into_iter().collect::<Vec<_>>();
for employee in &employees {
if !departments.contains(&employee.department()) {
return Err(CtError::ShapeMismatch {
op: "database instance",
expected: String::from("employee departments exist in Department"),
got: format!("missing department {}", employee.department().value()),
});
}
}
Ok(Self {
departments,
employees,
})
}
pub fn departments(&self) -> &[DepartmentId] {
&self.departments
}
pub fn employees(&self) -> &[EmployeeRecord] {
&self.employees
}
pub fn department_of(&self, employee_id: EmployeeId) -> Option<DepartmentId> {
self.employees
.iter()
.find(|employee| employee.id() == employee_id)
.map(EmployeeRecord::department)
}
pub fn foreign_keys_resolve(&self) -> bool {
self.employees
.iter()
.all(|employee| self.departments.contains(&employee.department()))
}
}
/// Throughput in requests per second, used for both required and offered capacity.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct Throughput(usize);
impl Throughput {
pub fn new(value: usize) -> CtResult<Self> {
if value == 0 {
return Err(CtError::EmptyInput("throughput"));
}
Ok(Self(value))
}
pub fn value(&self) -> usize {
self.0
}
}
/// Latency in milliseconds, used for both latency bounds and offered latency.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct LatencyMs(usize);
impl LatencyMs {
pub fn new(value: usize) -> CtResult<Self> {
if value == 0 {
return Err(CtError::EmptyInput("latency"));
}
Ok(Self(value))
}
pub fn value(&self) -> usize {
self.0
}
}
/// Functionality required in a co-design problem: a minimum throughput and a
/// maximum latency.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DesignRequirement {
minimum_throughput: Throughput,
maximum_latency: LatencyMs,
}
impl DesignRequirement {
pub fn new(minimum_throughput: Throughput, maximum_latency: LatencyMs) -> Self {
Self {
minimum_throughput,
maximum_latency,
}
}
}
/// Implementation offered by a candidate component.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ImplementationOffer {
throughput: Throughput,
latency: LatencyMs,
}
impl ImplementationOffer {
pub fn new(throughput: Throughput, latency: LatencyMs) -> Self {
Self {
throughput,
latency,
}
}
}
/// A Bool-valued feasibility relation between requirements and offers: an
/// offer is feasible when it meets the minimum throughput within the maximum
/// latency.
#[derive(Debug, Clone, Copy)]
pub struct FeasibilityRelation;
impl FeasibilityRelation {
pub fn relates(requirement: DesignRequirement, offer: ImplementationOffer) -> bool {
offer.throughput >= requirement.minimum_throughput
&& offer.latency <= requirement.maximum_latency
}
}
/// A scalar coefficient in a signal-flow matrix.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SignalCoefficient(i32);
impl SignalCoefficient {
pub fn new(value: i32) -> Self {
Self(value)
}
pub fn zero() -> Self {
Self(0)
}
pub fn value(&self) -> i32 {
self.0
}
fn add(self, other: Self) -> Self {
Self(self.value() + other.value())
}
fn multiply(self, other: Self) -> Self {
Self(self.value() * other.value())
}
}
/// Number of rows in a signal-flow matrix.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MatrixRows(usize);
impl MatrixRows {
pub fn new(value: usize) -> CtResult<Self> {
if value == 0 {
return Err(CtError::EmptyInput("matrix rows"));
}
Ok(Self(value))
}
pub fn value(&self) -> usize {
self.0
}
}
/// Number of columns in a signal-flow matrix.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MatrixCols(usize);
impl MatrixCols {
pub fn new(value: usize) -> CtResult<Self> {
if value == 0 {
return Err(CtError::EmptyInput("matrix columns"));
}
Ok(Self(value))
}
pub fn value(&self) -> usize {
self.0
}
}
/// Matrix semantics for a signal-flow graph.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SignalMatrix {
rows: MatrixRows,
cols: MatrixCols,
coefficients: Vec<Vec<SignalCoefficient>>,
}
impl SignalMatrix {
pub fn new(
rows: MatrixRows,
cols: MatrixCols,
coefficients: Vec<Vec<SignalCoefficient>>,
) -> CtResult<Self> {
if coefficients.len() != rows.value() {
return Err(CtError::ShapeMismatch {
op: "signal matrix",
expected: format!("{} rows", rows.value()),
got: format!("{} rows", coefficients.len()),
});
}
for row in &coefficients {
if row.len() != cols.value() {
return Err(CtError::ShapeMismatch {
op: "signal matrix",
expected: format!("{} columns", cols.value()),
got: format!("{} columns", row.len()),
});
}
}
Ok(Self {
rows,
cols,
coefficients,
})
}
pub fn rows(&self) -> MatrixRows {
self.rows
}
pub fn cols(&self) -> MatrixCols {
self.cols
}
pub fn coefficients(&self) -> &[Vec<SignalCoefficient>] {
&self.coefficients
}
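/// Matrix composition. If `previous` is `A -> B` and `self` is `B -> C`,
/// the result is `A -> C`, computed as the matrix product `self * previous`.
/// A matrix for `X -> Y` has one row per `Y` wire and one column per `X` wire.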
pub fn compose_after(&self, previous: &Self) -> CtResult<Self> {
if previous.rows.value() != self.cols.value() {
return Err(CtError::ShapeMismatch {
op: "signal matrix composition",
expected: format!("middle dimension {}", self.cols.value()),
got: format!("middle dimension {}", previous.rows.value()),
});
}
let mut coefficients =
vec![vec![SignalCoefficient::zero(); previous.cols.value()]; self.rows.value()];
for (output_row, output_coefficients) in coefficients.iter_mut().enumerate() {
for (input_col, coefficient) in output_coefficients.iter_mut().enumerate() {
let mut total = SignalCoefficient::zero();
for middle in 0..self.cols.value() {
total = total.add(
self.coefficients[output_row][middle]
.multiply(previous.coefficients[middle][input_col]),
);
}
*coefficient = total;
}
}
Self::new(self.rows, previous.cols, coefficients)
}
}
/// A named boundary port in an open circuit.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PortName(&'static str);
impl PortName {
pub fn new(value: &'static str) -> CtResult<Self> {
if value.trim().is_empty() {
return Err(CtError::EmptyInput("port name"));
}
Ok(Self(value))
}
pub fn as_str(&self) -> &'static str {
self.0
}
}
/// Positive resistance value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ResistanceOhms(usize);
impl ResistanceOhms {
pub fn new(value: usize) -> CtResult<Self> {
if value == 0 {
return Err(CtError::InvalidQuantity {
kind: "resistance",
value: value as i64,
});
}
Ok(Self(value))
}
pub fn value(&self) -> usize {
self.0
}
}
/// A simple resistor component between two ports.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CircuitComponent {
from: PortName,
to: PortName,
resistance: ResistanceOhms,
}
impl CircuitComponent {
pub fn resistor(from: PortName, to: PortName, resistance: ResistanceOhms) -> Self {
Self {
from,
to,
resistance,
}
}
pub fn from(&self) -> PortName {
self.from
}
pub fn to(&self) -> PortName {
self.to
}
pub fn resistance(&self) -> ResistanceOhms {
self.resistance
}
}
/// An open circuit with input ports, output ports, and internal components.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OpenCircuit {
inputs: Vec<PortName>,
outputs: Vec<PortName>,
components: Vec<CircuitComponent>,
}
impl OpenCircuit {
pub fn new(
inputs: impl IntoIterator<Item = PortName>,
outputs: impl IntoIterator<Item = PortName>,
components: impl IntoIterator<Item = CircuitComponent>,
) -> CtResult<Self> {
let inputs = inputs.into_iter().collect::<Vec<_>>();
let outputs = outputs.into_iter().collect::<Vec<_>>();
let components = components.into_iter().collect::<Vec<_>>();
if inputs.is_empty() {
return Err(CtError::EmptyInput("circuit inputs"));
}
if outputs.is_empty() {
return Err(CtError::EmptyInput("circuit outputs"));
}
Ok(Self {
inputs,
outputs,
components,
})
}
pub fn input_count(&self) -> usize {
self.inputs.len()
}
pub fn output_count(&self) -> usize {
self.outputs.len()
}
pub fn component_count(&self) -> usize {
self.components.len()
}
/// Serial composition wires this circuit's outputs into the next circuit's
/// inputs. Only the port counts must match; port names are not compared.
pub fn then(&self, next: &Self) -> CtResult<Self> {
if self.output_count() != next.input_count() {
return Err(CtError::ShapeMismatch {
op: "open circuit serial composition",
expected: format!("{} next inputs", self.output_count()),
got: format!("{} next inputs", next.input_count()),
});
}
let mut components = self.components.clone();
components.extend_from_slice(&next.components);
Self::new(self.inputs.clone(), next.outputs.clone(), components)
}
/// Parallel composition keeps the two open interfaces side by side.
pub fn parallel(&self, other: &Self) -> CtResult<Self> {
let mut inputs = self.inputs.clone();
inputs.extend_from_slice(&other.inputs);
let mut outputs = self.outputs.clone();
outputs.extend_from_slice(&other.outputs);
let mut components = self.components.clone();
components.extend_from_slice(&other.components);
Self::new(inputs, outputs, components)
}
}
/// Truth values used for behavior classification.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TruthValue {
False,
True,
}
impl TruthValue {
pub fn and(self, other: Self) -> Self {
match (self, other) {
(TruthValue::True, TruthValue::True) => TruthValue::True,
_ => TruthValue::False,
}
}
pub fn implies(self, other: Self) -> Self {
match (self, other) {
(TruthValue::True, TruthValue::False) => TruthValue::False,
_ => TruthValue::True,
}
}
}
/// Discrete time point in a behavior trace.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct TimeTick(usize);
impl TimeTick {
pub fn new(value: usize) -> Self {
Self(value)
}
pub fn value(&self) -> usize {
self.0
}
}
/// Closed interval of time ticks.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TimeInterval {
start: TimeTick,
end: TimeTick,
}
impl TimeInterval {
pub fn new(start: TimeTick, end: TimeTick) -> CtResult<Self> {
if start > end {
return Err(CtError::InvalidInterval {
start: start.value(),
end: end.value(),
});
}
Ok(Self { start, end })
}
pub fn start(&self) -> TimeTick {
self.start
}
pub fn end(&self) -> TimeTick {
self.end
}
}
/// Local safety result on one interval.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LocalSafetyCheck {
interval: TimeInterval,
truth: TruthValue,
}
impl LocalSafetyCheck {
pub fn new(interval: TimeInterval, truth: TruthValue) -> Self {
Self { interval, truth }
}
pub fn interval(&self) -> TimeInterval {
self.interval
}
pub fn truth(&self) -> TruthValue {
self.truth
}
}
/// A cover of local behavior checks. The global result is true only when all
/// local checks are true. The constructor does not verify that the intervals
/// overlap or cover any particular time range.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SafetyCover(Vec<LocalSafetyCheck>);
impl SafetyCover {
pub fn new(checks: impl IntoIterator<Item = LocalSafetyCheck>) -> CtResult<Self> {
let checks = checks.into_iter().collect::<Vec<_>>();
if checks.is_empty() {
return Err(CtError::EmptyInput("safety cover"));
}
Ok(Self(checks))
}
pub fn global_truth(&self) -> TruthValue {
self.0
.iter()
.fold(TruthValue::True, |truth, check| truth.and(check.truth()))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn information_order_has_preorder_laws() {
assert!(information_order_obeys_preorder_laws());
assert_eq!(
InformationLevel::Feature.join(InformationLevel::Decision),
InformationLevel::Decision
);
}
#[test]
fn feature_layer_abstraction_obeys_galois_law() -> CtResult<()> {
let features = FeatureCount::new(9)?;
let layers = LayerBudget::new(3)?;
assert!(feature_layer_galois_law_holds(features, layers)?);
assert_eq!(abstract_to_layer_budget(features)?.value(), 3);
assert_eq!(concretize_layer_budget(layers).value(), 12);
Ok(())
}
#[test]
fn resource_tensor_is_order_preserving() {
assert!(resource_tensor_is_monotone());
}
#[test]
fn database_instance_resolves_schema_arrow() -> CtResult<()> {
let research = DepartmentId::new(1);
let platform = DepartmentId::new(2);
let ada = EmployeeId::new(7);
let instance =
CompanyInstance::new([research, platform], [EmployeeRecord::new(ada, research)])?;
assert!(instance.foreign_keys_resolve());
assert_eq!(instance.department_of(ada), Some(research));
Ok(())
}
#[test]
fn database_instance_rejects_missing_department_reference() {
let research = DepartmentId::new(1);
let missing_department = DepartmentId::new(99);
let ada = EmployeeId::new(7);
assert!(matches!(
CompanyInstance::new([research], [EmployeeRecord::new(ada, missing_department)]),
Err(CtError::ShapeMismatch {
op: "database instance",
..
})
));
}
#[test]
fn feasibility_relation_matches_requirement_to_offer() -> CtResult<()> {
let requirement = DesignRequirement::new(Throughput::new(100)?, LatencyMs::new(80)?);
let offer = ImplementationOffer::new(Throughput::new(120)?, LatencyMs::new(50)?);
assert!(FeasibilityRelation::relates(requirement, offer));
Ok(())
}
#[test]
fn signal_matrices_compose_like_flow_graph_semantics() -> CtResult<()> {
let duplicate = SignalMatrix::new(
MatrixRows::new(2)?,
MatrixCols::new(1)?,
vec![
vec![SignalCoefficient::new(1)],
vec![SignalCoefficient::new(1)],
],
)?;
let add_weighted = SignalMatrix::new(
MatrixRows::new(1)?,
MatrixCols::new(2)?,
vec![vec![SignalCoefficient::new(2), SignalCoefficient::new(3)]],
)?;
let composed = add_weighted.compose_after(&duplicate)?;
assert_eq!(composed.coefficients(), &[vec![SignalCoefficient::new(5)]]);
Ok(())
}
#[test]
fn signal_matrix_composition_rejects_mismatched_middle_dimension() -> CtResult<()> {
let previous = SignalMatrix::new(
MatrixRows::new(2)?,
MatrixCols::new(1)?,
vec![
vec![SignalCoefficient::new(1)],
vec![SignalCoefficient::new(1)],
],
)?;
let next = SignalMatrix::new(
MatrixRows::new(1)?,
MatrixCols::new(3)?,
vec![vec![
SignalCoefficient::new(1),
SignalCoefficient::new(1),
SignalCoefficient::new(1),
]],
)?;
assert!(matches!(
next.compose_after(&previous),
Err(CtError::ShapeMismatch {
op: "signal matrix composition",
..
})
));
Ok(())
}
#[test]
fn open_circuits_compose_in_series_and_parallel() -> CtResult<()> {
let input = PortName::new("input")?;
let middle = PortName::new("middle")?;
let output = PortName::new("output")?;
let first = OpenCircuit::new(
[input],
[middle],
[CircuitComponent::resistor(
input,
middle,
ResistanceOhms::new(10)?,
)],
)?;
let second = OpenCircuit::new(
[middle],
[output],
[CircuitComponent::resistor(
middle,
output,
ResistanceOhms::new(20)?,
)],
)?;
let serial = first.then(&second)?;
let parallel = first.parallel(&second)?;
assert_eq!(serial.input_count(), 1);
assert_eq!(serial.output_count(), 1);
assert_eq!(serial.component_count(), 2);
assert_eq!(parallel.input_count(), 2);
assert_eq!(parallel.output_count(), 2);
Ok(())
}
#[test]
fn open_circuit_serial_composition_rejects_boundary_mismatch() -> CtResult<()> {
let input = PortName::new("input")?;
let output = PortName::new("output")?;
let left = OpenCircuit::new([input], [output], [])?;
let right = OpenCircuit::new([input, output], [output], [])?;
assert!(matches!(
left.then(&right),
Err(CtError::ShapeMismatch {
op: "open circuit serial composition",
..
})
));
Ok(())
}
#[test]
fn local_behavior_truth_glues_to_global_truth() -> CtResult<()> {
let first = LocalSafetyCheck::new(
TimeInterval::new(TimeTick::new(0), TimeTick::new(5))?,
TruthValue::True,
);
let second = LocalSafetyCheck::new(
TimeInterval::new(TimeTick::new(5), TimeTick::new(10))?,
TruthValue::True,
);
let cover = SafetyCover::new([first, second])?;
assert_eq!(cover.global_truth(), TruthValue::True);
assert_eq!(
TruthValue::True.implies(TruthValue::False),
TruthValue::False
);
Ok(())
}
}
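The `feature_layer_abstraction_obeys_galois_law` test checks the law at a single pair. One way to gain more confidence, sketched below with plain `usize` functions that mirror `abstract_to_layer_budget` and `concretize_layer_budget` (the names `ceil_to_layers`, `floor_to_layers`, `layers_to_features`, and `first_counterexample` are hypothetical), is to search a finite grid for a counterexample. The same search shows why the abstraction must round up: with floor division, 5 features would "fit" in one layer even though one layer holds only 4.

```rust
/// Features held by one abstract layer (mirrors `FEATURES_PER_LAYER`).
const FEATURES_PER_LAYER: usize = 4;

/// Lower adjoint: the smallest layer count that holds `features` (rounds up).
fn ceil_to_layers(features: usize) -> usize {
    features.div_ceil(FEATURES_PER_LAYER)
}

/// A broken alternative that rounds down, used only as a contrast.
fn floor_to_layers(features: usize) -> usize {
    features / FEATURES_PER_LAYER
}

/// Upper adjoint: the largest feature count that `layers` can hold.
fn layers_to_features(layers: usize) -> usize {
    layers * FEATURES_PER_LAYER
}

/// Searches every pair in `1..=max_features` by `1..=max_layers` for one where
/// `abstraction(f) <= l` and `f <= layers_to_features(l)` disagree.
/// `None` means the Galois-connection law holds on the whole grid.
fn first_counterexample(
    abstraction: fn(usize) -> usize,
    max_features: usize,
    max_layers: usize,
) -> Option<(usize, usize)> {
    (1..=max_features)
        .flat_map(|f| (1..=max_layers).map(move |l| (f, l)))
        .find(|&(f, l)| (abstraction(f) <= l) != (f <= layers_to_features(l)))
}
```

For example, `first_counterexample(ceil_to_layers, 64, 32)` finds nothing, while `first_counterexample(floor_to_layers, 64, 32)` returns `Some((5, 1))`.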
src/challenges/mod.rs
//! Public challenge implementations that back the learner-facing challenge files.
//!
//! The intentionally broken exercise templates live under `challenges/` and are
//! not part of the normal Cargo build. This module holds the validated
//! reference implementations, with unit tests, so each public challenge has a
//! working answer that builds and is tested with the crate.
pub mod papers;
pub mod typed_ai;
src/challenges/typed_ai.rs
//! Reference helpers for the Typed AI Rustlings challenge.
use crate::category::Morphism;
use crate::domain::{Distribution, Logits, Loss, Product, TokenId, VocabSize};
use crate::error::{CtError, CtResult};
use crate::ml::{CrossEntropy, Softmax};
/// Exposes the integer index only after the value has crossed the `TokenId`
/// boundary.
pub fn token_index(token: TokenId) -> usize {
token.index()
}
/// Builds a uniform distribution for a non-empty vocabulary.
pub fn uniform_distribution(vocab_size: VocabSize) -> CtResult<Distribution> {
let probability = 1.0 / vocab_size.value() as f32;
Distribution::new(vec![probability; vocab_size.value()])
}
/// Computes target loss from raw logits by crossing the softmax boundary first.
pub fn loss_from_logits(logits: Logits, target: TokenId) -> CtResult<Loss> {
let distribution = Softmax.apply(logits)?;
CrossEntropy.apply(Product::new(distribution, target))
}
/// Validates that a target token can address a probability distribution.
pub fn require_target_in_distribution(
distribution: &Distribution,
target: TokenId,
) -> CtResult<()> {
if target.index() >= distribution.as_slice().len() {
return Err(CtError::OutOfRange {
kind: "target token",
index: target.index(),
limit: distribution.as_slice().len(),
});
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn token_index_requires_a_token_id_boundary() {
assert_eq!(token_index(TokenId::new(7)), 7);
}
#[test]
fn uniform_distribution_normalizes_vocab_size() -> CtResult<()> {
let distribution = uniform_distribution(VocabSize::new(4)?)?;
assert_eq!(distribution.as_slice(), &[0.25, 0.25, 0.25, 0.25]);
Ok(())
}
#[test]
fn loss_from_logits_crosses_softmax_before_cross_entropy() -> CtResult<()> {
let loss = loss_from_logits(Logits::new(vec![0.0, 2.0]), TokenId::new(1))?;
assert!(loss.value() < 0.2);
Ok(())
}
#[test]
fn target_validation_rejects_out_of_range_token() -> CtResult<()> {
let distribution = uniform_distribution(VocabSize::new(2)?)?;
assert!(matches!(
require_target_in_distribution(&distribution, TokenId::new(9)),
Err(CtError::OutOfRange {
kind: "target token",
index: 9,
limit: 2,
})
));
Ok(())
}
}
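`loss_from_logits` composes `Softmax` and then `CrossEntropy`. This excerpt does not show how the crate's `Softmax` guards against overflow, so as a point of comparison here is a minimal standalone sketch over plain `f32` slices (the function name `cross_entropy_from_logits` is hypothetical) that fuses the two steps with the log-sum-exp identity `-ln softmax(z)[t] = max(z) + ln(sum_i exp(z_i - max(z))) - z_t`. Subtracting the maximum keeps every `exp` argument at or below zero, so large logits do not overflow.

```rust
/// Cross-entropy of `target` under `softmax(logits)`, computed without
/// forming the probabilities. Returns `None` when `target` is out of range,
/// which also covers an empty `logits` slice.
fn cross_entropy_from_logits(logits: &[f32], target: usize) -> Option<f32> {
    let target_logit = *logits.get(target)?;
    // Shift by the maximum so every exponent is <= 0 and cannot overflow.
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let sum_exp: f32 = logits.iter().map(|&z| (z - max).exp()).sum();
    Some(max + sum_exp.ln() - target_logit)
}
```

With logits `[0.0, 2.0]` and target `1` this gives about `0.1269`, consistent with the `< 0.2` bound in `loss_from_logits_crosses_softmax_before_cross_entropy`; shifting both logits by `1000.0` gives the same loss instead of overflowing. A uniform set of logits over a vocabulary of size `V` gives `ln V`.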
src/challenges/papers/mod.rs
//! Paper-To-Rust challenge seeds.
pub mod adam;
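The `adam` module declared here gives Adam's optimizer memory explicit types. For reference while reading those types, here is a minimal standalone sketch of the standard Adam update with bias correction from Kingma and Ba's paper, written as a pure `state -> state` function over plain `f32` vectors. All names (`AdamSketchConfig`, `AdamSketchState`, `adam_step`) are hypothetical; this is not the crate's `AdamTrainStep`.

```rust
/// Hypothetical hyperparameters for one Adam step.
#[derive(Debug, Clone, Copy)]
struct AdamSketchConfig {
    learning_rate: f32,
    beta1: f32,
    beta2: f32,
    epsilon: f32,
}

/// Parameters plus the optimizer memory that must travel with them.
#[derive(Debug, Clone, PartialEq)]
struct AdamSketchState {
    params: Vec<f32>,
    first_moment: Vec<f32>,  // m: moving average of gradients
    second_moment: Vec<f32>, // v: moving average of squared gradients
    step_count: u32,         // t: updates already applied
}

impl AdamSketchState {
    fn from_params(params: Vec<f32>) -> Self {
        let n = params.len();
        Self { params, first_moment: vec![0.0; n], second_moment: vec![0.0; n], step_count: 0 }
    }
}

/// One Adam transition. Returns `None` when the gradient dimension does not
/// match the parameters; the input state is never mutated.
fn adam_step(state: &AdamSketchState, grads: &[f32], config: AdamSketchConfig) -> Option<AdamSketchState> {
    if grads.len() != state.params.len() {
        return None;
    }
    let mut next = state.clone();
    next.step_count += 1;
    let t = next.step_count as i32;
    // Bias corrections undo the pull toward zero from starting m and v at 0.
    let correction1 = 1.0 - config.beta1.powi(t);
    let correction2 = 1.0 - config.beta2.powi(t);
    for (i, &g) in grads.iter().enumerate() {
        let m = config.beta1 * state.first_moment[i] + (1.0 - config.beta1) * g;
        let v = config.beta2 * state.second_moment[i] + (1.0 - config.beta2) * g * g;
        let m_hat = m / correction1;
        let v_hat = v / correction2;
        next.params[i] = state.params[i] - config.learning_rate * m_hat / (v_hat.sqrt() + config.epsilon);
        next.first_moment[i] = m;
        next.second_moment[i] = v;
    }
    Some(next)
}
```

On the first step, bias correction makes `m_hat = g` and `v_hat = g * g`, so each coordinate moves by almost exactly `learning_rate` against the sign of its gradient, regardless of the gradient's magnitude.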
src/challenges/papers/adam.rs
//! A small Paper-To-Rust translation of Adam optimizer state.
//!
//! The challenge translates one idea from the Adam paper into types: an
//! optimizer step is not just a parameter-vector update. It also carries
//! first-moment, second-moment, and step-count state forward.
use crate::category::Morphism;
use crate::domain::LearningRate;
use crate::error::{CtError, CtResult};
/// Number of coordinates in an Adam parameter or gradient vector.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AdamVectorDimension(usize);
impl AdamVectorDimension {
pub fn new(value: usize) -> CtResult<Self> {
if value == 0 {
return Err(CtError::EmptyInput("Adam vector dimension"));
}
Ok(Self(value))
}
pub fn value(&self) -> usize {
self.0
}
}
/// Model parameters owned by the Adam challenge.
#[derive(Debug, Clone, PartialEq)]
pub struct AdamParameterVector(Vec<f32>);
impl AdamParameterVector {
pub fn new(values: Vec<f32>) -> CtResult<Self> {
validate_finite_non_empty("Adam parameter vector", &values)?;
Ok(Self(values))
}
pub fn dimension(&self) -> AdamVectorDimension {
AdamVectorDimension(self.0.len())
}
pub fn as_slice(&self) -> &[f32] {
&self.0
}
}
/// Gradient vector for one Adam update.
#[derive(Debug, Clone, PartialEq)]
pub struct AdamGradientVector(Vec<f32>);
impl AdamGradientVector {
pub fn new(values: Vec<f32>) -> CtResult<Self> {
validate_finite_non_empty("Adam gradient vector", &values)?;
Ok(Self(values))
}
pub fn dimension(&self) -> AdamVectorDimension {
AdamVectorDimension(self.0.len())
}
pub fn as_slice(&self) -> &[f32] {
&self.0
}
}
/// Exponential moving average of gradients.
#[derive(Debug, Clone, PartialEq)]
pub struct AdamFirstMoment(Vec<f32>);
impl AdamFirstMoment {
pub fn zeros(dimension: AdamVectorDimension) -> Self {
Self(vec![0.0; dimension.value()])
}
pub fn as_slice(&self) -> &[f32] {
&self.0
}
}
/// Exponential moving average of squared gradients.
#[derive(Debug, Clone, PartialEq)]
pub struct AdamSecondMoment(Vec<f32>);
impl AdamSecondMoment {
pub fn zeros(dimension: AdamVectorDimension) -> Self {
Self(vec![0.0; dimension.value()])
}
pub fn as_slice(&self) -> &[f32] {
&self.0
}
}
/// A decay coefficient in the half-open range `[0, 1)`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct AdamDecayRate(f32);
impl AdamDecayRate {
pub fn new(value: f32) -> CtResult<Self> {
if !value.is_finite() || !(0.0..1.0).contains(&value) {
return Err(CtError::InvalidScalar {
kind: "Adam decay rate",
value,
});
}
Ok(Self(value))
}
pub fn value(&self) -> f32 {
self.0
}
}
/// Small positive stabilizer in the Adam denominator.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct AdamEpsilon(f32);
impl AdamEpsilon {
pub fn new(value: f32) -> CtResult<Self> {
if !value.is_finite() || value <= 0.0 {
return Err(CtError::InvalidScalar {
kind: "Adam epsilon",
value,
});
}
Ok(Self(value))
}
pub fn value(&self) -> f32 {
self.0
}
}
/// Count of Adam updates already applied.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AdamStepCount(usize);
impl AdamStepCount {
pub fn zero() -> Self {
Self(0)
}
pub fn value(&self) -> usize {
self.0
}
fn next(self) -> Self {
Self(self.0 + 1)
}
}
/// Adam hyperparameters that affect one optimizer transition.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct AdamConfig {
learning_rate: LearningRate,
beta1: AdamDecayRate,
beta2: AdamDecayRate,
epsilon: AdamEpsilon,
}
impl AdamConfig {
pub fn new(
learning_rate: LearningRate,
beta1: AdamDecayRate,
beta2: AdamDecayRate,
epsilon: AdamEpsilon,
) -> Self {
Self {
learning_rate,
beta1,
beta2,
epsilon,
}
}
pub fn learning_rate(&self) -> LearningRate {
self.learning_rate
}
pub fn beta1(&self) -> AdamDecayRate {
self.beta1
}
pub fn beta2(&self) -> AdamDecayRate {
self.beta2
}
pub fn epsilon(&self) -> AdamEpsilon {
self.epsilon
}
}
/// Adam optimizer memory that must move with the parameters.
#[derive(Debug, Clone, PartialEq)]
pub struct AdamOptimizerState {
first_moment: AdamFirstMoment,
second_moment: AdamSecondMoment,
step_count: AdamStepCount,
}
impl AdamOptimizerState {
pub fn zeros(dimension: AdamVectorDimension) -> Self {
Self {
first_moment: AdamFirstMoment::zeros(dimension),
second_moment: AdamSecondMoment::zeros(dimension),
step_count: AdamStepCount::zero(),
}
}
pub fn first_moment(&self) -> &AdamFirstMoment {
&self.first_moment
}
pub fn second_moment(&self) -> &AdamSecondMoment {
&self.second_moment
}
pub fn step_count(&self) -> AdamStepCount {
self.step_count
}
}
/// Complete Adam state at an optimizer boundary.
#[derive(Debug, Clone, PartialEq)]
pub struct AdamModelState {
parameters: AdamParameterVector,
optimizer: AdamOptimizerState,
}
impl AdamModelState {
pub fn new(parameters: AdamParameterVector, optimizer: AdamOptimizerState) -> CtResult<Self> {
validate_matching_dimension(
"Adam model state",
parameters.dimension(),
optimizer.first_moment.as_slice().len(),
)?;
validate_matching_dimension(
"Adam model state",
parameters.dimension(),
optimizer.second_moment.as_slice().len(),
)?;
Ok(Self {
parameters,
optimizer,
})
}
pub fn from_parameters(parameters: AdamParameterVector) -> Self {
let optimizer = AdamOptimizerState::zeros(parameters.dimension());
Self {
parameters,
optimizer,
}
}
pub fn parameters(&self) -> &AdamParameterVector {
&self.parameters
}
pub fn optimizer(&self) -> &AdamOptimizerState {
&self.optimizer
}
}
/// One Adam optimizer step as an endomorphism on `AdamModelState`.
#[derive(Debug, Clone, PartialEq)]
pub struct AdamTrainStep {
gradient: AdamGradientVector,
config: AdamConfig,
}
impl AdamTrainStep {
pub fn new(gradient: AdamGradientVector, config: AdamConfig) -> Self {
Self { gradient, config }
}
}
impl Morphism<AdamModelState, AdamModelState> for AdamTrainStep {
fn name(&self) -> &'static str {
"adam_train_step"
}
fn apply(&self, state: AdamModelState) -> CtResult<AdamModelState> {
let dimension = state.parameters.dimension();
validate_matching_dimension("Adam gradient", dimension, self.gradient.as_slice().len())?;
validate_matching_dimension(
"Adam first moment",
dimension,
state.optimizer.first_moment.as_slice().len(),
)?;
validate_matching_dimension(
"Adam second moment",
dimension,
state.optimizer.second_moment.as_slice().len(),
)?;
let next_step = state.optimizer.step_count.next();
let beta1 = self.config.beta1.value();
let beta2 = self.config.beta2.value();
let learning_rate = self.config.learning_rate.value();
let epsilon = self.config.epsilon.value();
let bias_correction_1 = 1.0 - beta1.powf(next_step.value() as f32);
let bias_correction_2 = 1.0 - beta2.powf(next_step.value() as f32);
let mut next_parameters = Vec::with_capacity(dimension.value());
let mut next_first_moment = Vec::with_capacity(dimension.value());
let mut next_second_moment = Vec::with_capacity(dimension.value());
for (((parameter, gradient), first_moment), second_moment) in state
.parameters
.as_slice()
.iter()
.copied()
.zip(self.gradient.as_slice().iter().copied())
.zip(state.optimizer.first_moment.as_slice().iter().copied())
.zip(state.optimizer.second_moment.as_slice().iter().copied())
{
let updated_first_moment = beta1 * first_moment + (1.0 - beta1) * gradient;
let updated_second_moment = beta2 * second_moment + (1.0 - beta2) * gradient * gradient;
let corrected_first_moment = updated_first_moment / bias_correction_1;
let corrected_second_moment = updated_second_moment / bias_correction_2;
let updated_parameter = parameter
- learning_rate * corrected_first_moment
/ (corrected_second_moment.sqrt() + epsilon);
next_parameters.push(updated_parameter);
next_first_moment.push(updated_first_moment);
next_second_moment.push(updated_second_moment);
}
AdamModelState::new(
AdamParameterVector::new(next_parameters)?,
AdamOptimizerState {
first_moment: AdamFirstMoment(next_first_moment),
second_moment: AdamSecondMoment(next_second_moment),
step_count: next_step,
},
)
}
}
fn validate_finite_non_empty(kind: &'static str, values: &[f32]) -> CtResult<()> {
if values.is_empty() {
return Err(CtError::EmptyInput(kind));
}
if let Some(value) = values.iter().copied().find(|value| !value.is_finite()) {
return Err(CtError::InvalidScalar { kind, value });
}
Ok(())
}
fn validate_matching_dimension(
op: &'static str,
expected: AdamVectorDimension,
got: usize,
) -> CtResult<()> {
if expected.value() != got {
return Err(CtError::ShapeMismatch {
op,
expected: format!("dimension {}", expected.value()),
got: format!("dimension {got}"),
});
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
fn standard_config() -> CtResult<AdamConfig> {
Ok(AdamConfig::new(
LearningRate::new(0.1)?,
AdamDecayRate::new(0.9)?,
AdamDecayRate::new(0.999)?,
AdamEpsilon::new(1e-8)?,
))
}
#[test]
fn adam_first_step_matches_bias_corrected_update() -> CtResult<()> {
let state = AdamModelState::from_parameters(AdamParameterVector::new(vec![1.0, -1.0])?);
let step = AdamTrainStep::new(
AdamGradientVector::new(vec![0.5, -0.25])?,
standard_config()?,
);
let updated = step.apply(state)?;
assert_eq!(updated.optimizer().step_count().value(), 1);
assert!((updated.parameters().as_slice()[0] - 0.9).abs() < 1e-5);
assert!((updated.parameters().as_slice()[1] + 0.9).abs() < 1e-5);
assert!((updated.optimizer().first_moment().as_slice()[0] - 0.05).abs() < 1e-6);
assert!((updated.optimizer().second_moment().as_slice()[0] - 0.00025).abs() < 1e-7);
Ok(())
}
#[test]
fn adam_rejects_bad_decay_rate() {
assert!(matches!(
AdamDecayRate::new(1.0),
Err(CtError::InvalidScalar {
kind: "Adam decay rate",
value: 1.0,
})
));
}
#[test]
fn adam_rejects_gradient_dimension_mismatch() -> CtResult<()> {
let state = AdamModelState::from_parameters(AdamParameterVector::new(vec![1.0, -1.0])?);
let step = AdamTrainStep::new(AdamGradientVector::new(vec![0.5])?, standard_config()?);
assert!(matches!(
step.apply(state),
Err(CtError::ShapeMismatch {
op: "Adam gradient",
..
})
));
Ok(())
}
#[test]
fn adam_step_preserves_complete_optimizer_state() -> CtResult<()> {
let state = AdamModelState::from_parameters(AdamParameterVector::new(vec![1.0])?);
let step = AdamTrainStep::new(AdamGradientVector::new(vec![0.5])?, standard_config()?);
let once = step.apply(state)?;
let twice = step.apply(once)?;
assert_eq!(twice.optimizer().step_count().value(), 2);
assert_eq!(twice.parameters().dimension().value(), 1);
assert_eq!(twice.optimizer().first_moment().as_slice().len(), 1);
assert_eq!(twice.optimizer().second_moment().as_slice().len(), 1);
Ok(())
}
}
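`AdamTrainStep::apply` implements the usual bias-corrected Adam update for each coordinate. The standalone sketch below shows the same arithmetic on a single coordinate. It uses plain `f32` values instead of the crate's newtypes, and the names `Hyper`, `ScalarAdam`, and `adam_scalar_step` are hypothetical. It also uses `powi` where the listing uses `powf`. Its first step reproduces the coordinate-0 values checked in `adam_first_step_matches_bias_corrected_update` (parameter 0.9, first moment 0.05, second moment 0.00025). It also shows why bias correction matters: with a constant gradient g, the corrected moments equal g and g², so each step moves the parameter by roughly the learning rate.

```rust
/// Hypothetical hyperparameter bundle (same roles as the listing's AdamConfig).
#[derive(Debug, Clone, Copy)]
struct Hyper {
    learning_rate: f32,
    beta1: f32,
    beta2: f32,
    epsilon: f32,
}

/// Hypothetical one-coordinate Adam state: parameter, both moments, steps taken.
#[derive(Debug, Clone, Copy)]
struct ScalarAdam {
    parameter: f32,
    first_moment: f32,
    second_moment: f32,
    step_count: i32,
}

/// One bias-corrected Adam transition on a single coordinate.
fn adam_scalar_step(state: ScalarAdam, gradient: f32, h: Hyper) -> ScalarAdam {
    let step = state.step_count + 1;
    let m = h.beta1 * state.first_moment + (1.0 - h.beta1) * gradient;
    let v = h.beta2 * state.second_moment + (1.0 - h.beta2) * gradient * gradient;
    // Bias correction compensates for initialising both moments at zero.
    let m_hat = m / (1.0 - h.beta1.powi(step));
    let v_hat = v / (1.0 - h.beta2.powi(step));
    ScalarAdam {
        parameter: state.parameter - h.learning_rate * m_hat / (v_hat.sqrt() + h.epsilon),
        first_moment: m,
        second_moment: v,
        step_count: step,
    }
}

fn main() {
    let h = Hyper { learning_rate: 0.1, beta1: 0.9, beta2: 0.999, epsilon: 1e-8 };
    let mut state = ScalarAdam { parameter: 1.0, first_moment: 0.0, second_moment: 0.0, step_count: 0 };
    // Constant gradient: after correction m_hat = g and v_hat = g^2,
    // so every step moves the parameter by roughly learning_rate.
    for _ in 0..3 {
        state = adam_scalar_step(state, 0.5, h);
        println!("{state:?}");
    }
}
```

Keeping the moments and the step count inside the returned state mirrors the listing's design choice. The optimizer memory travels with the parameters, so the step stays an endomorphism on the complete state.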
src/demo.rs
use crate::calculus::{LocalGradient, MulOp, Scalar};
use crate::category::{Compose, StepCount, apply_endomorphism_n_times};
use crate::domain::{
LearningRate, Logits, ModelDimension, Parameters, Product, TokenId, TokenSequence, Vector,
VocabSize,
};
use crate::error::CtResult;
use crate::ml::{
CrossEntropy, DatasetWindowing, Embedding, LinearToLogits, Softmax, average_loss,
composed_prediction_matches_direct_prediction,
};
use crate::structure::{
Functor, Monoid, OptionFunctor, PipelineTrace, TraceStep, VecFunctor,
monoid_laws_hold_for_pipeline_trace, naturality_square_holds_for_first_option,
};
use crate::training::TrainStep;
use crate::{Identity, Morphism};
/// Run the full terminal walkthrough used by `cargo run --bin category_ml`.
pub fn run_demo() -> CtResult<()> {
println!("Category theory concepts implemented in Rust 2024");
println!("=================================================\n");
let vocab = ["<pad>", "I", "love", "Rust", "."];
let raw_text = TokenSequence::from_indices([1, 2, 3, 4, 1, 2, 3, 4])?;
println!("1. Object examples");
println!(" TokenId(1) means {:?}\n", vocab[1]);
println!("2. Dataset morphism: TokenSequence -> TrainingSet");
let dataset = DatasetWindowing.apply(raw_text)?;
for example in dataset.examples() {
println!(
" {:?} -> {:?}",
vocab[example.first().index()],
vocab[example.second().index()]
);
}
println!();
println!("3. Identity morphism: id_Vector : Vector -> Vector");
let v = Vector::new(vec![1.0, 2.0, 3.0]);
let same_v = Identity::<Vector>::new().apply(v.clone())?;
println!(" input = {:?}", v);
println!(" output = {:?}\n", same_v);
println!("4. Composition: Softmax after Linear after Embedding");
let params = Parameters::init(VocabSize::new(vocab.len())?, ModelDimension::new(4)?);
let embedding = Embedding::from_parameters(&params);
let linear = LinearToLogits::from_parameters(&params);
let token_to_logits = Compose::<_, _, Vector>::new(embedding, linear);
let token_to_distribution = Compose::<_, _, Logits>::new(token_to_logits, Softmax);
let distribution = token_to_distribution.apply(TokenId::new(1))?;
println!(" P(next token | 'I') = {:?}\n", distribution.as_slice());
println!("5. Product object: Prediction x Target -> Loss");
let loss = CrossEntropy.apply(Product::new(distribution, TokenId::new(2)))?;
println!(" loss for target 'love' = {:.6}\n", loss.value());
println!("6. Endomorphism: TrainStep : Parameters -> Parameters");
let before = average_loss(&params, &dataset)?;
let train_step = TrainStep::new(dataset.clone(), LearningRate::new(1.0)?);
let trained_params =
apply_endomorphism_n_times(&train_step, params.clone(), StepCount::new(80))?;
let after = average_loss(&trained_params, &dataset)?;
println!(" average loss before training = {:.6}", before.value());
println!(" average loss after training = {:.6}\n", after.value());
println!("7. Functor: fmap over Vec and Option");
let xs = vec![1, 2, 3];
let ys = VecFunctor::fmap(xs, |x| x * x);
let maybe = OptionFunctor::fmap(Some(7), |x| x + 1);
println!(" VecFunctor fmap square: {:?}", ys);
println!(" OptionFunctor fmap +1: {:?}\n", maybe);
println!("8. Natural transformation: Vec<A> -> Option<A>");
println!(
" naturality square holds: {}\n",
naturality_square_holds_for_first_option()
);
println!("9. Monoid: pipeline traces compose associatively with identity");
let trace = PipelineTrace::from_steps(vec![TraceStep::new("embedding")])
.combine(&PipelineTrace::from_steps(vec![TraceStep::new("linear")]))
.combine(&PipelineTrace::from_steps(vec![TraceStep::new("softmax")]));
println!(" trace = {:?}", trace.names());
println!(
" monoid laws hold for this trace type: {}\n",
monoid_laws_hold_for_pipeline_trace()
);
println!("10. Commutative diagram check");
println!(
" composed prediction == direct prediction: {}\n",
composed_prediction_matches_direct_prediction(&params)?
);
println!("11. Chain rule / local derivative morphism");
let mul = MulOp;
let x = Scalar::new(2.0)?;
let y = Scalar::new(3.0)?;
let z = mul.forward(x, y)?;
let upstream = LocalGradient::new(1.0)?;
let (dl_dx, dl_dy) = mul.backward(x, y, upstream)?;
println!(" z = x * y = {}", z.value());
println!(
" if dL/dz = {}, then dL/dx = {}, dL/dy = {}\n",
upstream.value(),
dl_dx.value(),
dl_dy.value()
);
println!("Compressed categorical training view:");
println!(" Dataset x Parameters -> Prediction -> Loss -> Gradients -> Updated Parameters");
println!(" TrainStep is repeated as Parameters0 -> Parameters1 -> ... -> ParametersN");
Ok(())
}
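Step 11 of the walkthrough computes the local derivatives of z = x · y. The standalone sketch below shows what a multiplication node's backward pass computes. It uses plain `f64` values, and `mul_forward` and `mul_backward` are hypothetical names, not the crate's `MulOp` API. Each input receives the upstream gradient multiplied by the other input. With x = 2, y = 3, and dL/dz = 1, that gives dL/dx = 3 and dL/dy = 2.

```rust
/// Forward pass of a multiplication node: z = x * y.
fn mul_forward(x: f64, y: f64) -> f64 {
    x * y
}

/// Backward pass: dz/dx = y and dz/dy = x, each scaled by the upstream dL/dz.
fn mul_backward(x: f64, y: f64, upstream: f64) -> (f64, f64) {
    (upstream * y, upstream * x)
}

fn main() {
    let (x, y) = (2.0, 3.0);
    let z = mul_forward(x, y);
    let (dl_dx, dl_dy) = mul_backward(x, y, 1.0);
    println!("z = {z}, dL/dx = {dl_dx}, dL/dy = {dl_dy}");
    // Chain rule: if L = 2z, then dL/dz = 2 and both local gradients double.
    let (dl_dx2, dl_dy2) = mul_backward(x, y, 2.0);
    println!("with dL/dz = 2: dL/dx = {dl_dx2}, dL/dy = {dl_dy2}");
}
```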
src/bin/category_ml.rs
fn main() -> category_theory_transformer_rs::CtResult<()> {
category_theory_transformer_rs::run_demo()
}
Runnable Examples
examples/01_token_sequence.rs
use category_theory_transformer_rs::{CtResult, Product, TokenId, TokenSequence, TrainingExample};
fn main() -> CtResult<()> {
let raw_input = "rust makes ai structure visible";
let token_ids = tokenize_visible_structure(raw_input);
let sequence = TokenSequence::new(token_ids)?;
let training_pairs = training_pairs(sequence.as_slice())?;
println!("Raw input:");
println!("\"{raw_input}\"");
println!();
println!("TokenSequence:");
println!("{}", format_token_sequence(sequence.as_slice()));
println!();
println!("TrainingPairs:");
for pair in &training_pairs {
println!("{}", format_training_pair(pair));
}
println!();
println!("Typed transformation:");
println!("Text -> TokenSequence -> TrainingPairs");
println!();
println!("No framework magic.");
println!("Just explicit structure.");
Ok(())
}
fn tokenize_visible_structure(input: &str) -> Vec<TokenId> {
input
.split_whitespace()
.map(|word| match word {
"rust" => TokenId::new(12),
"makes" => TokenId::new(44),
"ai" => TokenId::new(7),
"structure" => TokenId::new(19),
"visible" => TokenId::new(91),
_ => TokenId::new(0),
})
.collect()
}
fn training_pairs(tokens: &[TokenId]) -> CtResult<Vec<TrainingExample>> {
TokenSequence::new(tokens.iter().copied())?;
Ok(tokens
.windows(2)
.map(|pair| Product::new(pair[0], pair[1]))
.collect())
}
fn format_token_sequence(tokens: &[TokenId]) -> String {
let formatted = tokens
.iter()
.map(format_token_id)
.collect::<Vec<_>>()
.join(", ");
format!("[{formatted}]")
}
fn format_training_pair(pair: &TrainingExample) -> String {
format!(
"({} -> {})",
format_token_id(pair.first()),
format_token_id(pair.second())
)
}
fn format_token_id(token: &TokenId) -> String {
format!("TokenId({})", token.index())
}
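`training_pairs` relies on `slice::windows(2)`, which yields every overlapping pair of adjacent tokens. A sequence of n tokens therefore produces n - 1 (input, target) pairs, and a sequence with fewer than two tokens produces none. The minimal sketch below uses plain `u32` ids in place of `TokenId`. It reuses the toy ids from `tokenize_visible_structure`, and `next_token_pairs` is a hypothetical helper.

```rust
/// Build (current, next) pairs from a sequence of token ids.
/// Fewer than two tokens yields an empty Vec, because windows(2) produces nothing.
fn next_token_pairs(tokens: &[u32]) -> Vec<(u32, u32)> {
    tokens.windows(2).map(|w| (w[0], w[1])).collect()
}

fn main() {
    // "rust makes ai structure visible" under the toy id mapping above.
    let tokens = [12, 44, 7, 19, 91];
    let pairs = next_token_pairs(&tokens);
    for (input, target) in &pairs {
        println!("(TokenId({input}) -> TokenId({target}))");
    }
    println!("{} tokens -> {} pairs", tokens.len(), pairs.len());
}
```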
examples/01_domain_objects.rs
use category_theory_transformer_rs::{
CtResult, DatasetWindowing, Morphism, TokenId, TokenSequence,
};
fn main() -> CtResult<()> {
let tokens = TokenSequence::from_indices([1, 2, 3, 4])?;
let dataset = DatasetWindowing.apply(tokens.clone())?;
println!("TokenSequence:");
println!("{}", format_token_sequence(tokens.as_slice()));
println!();
println!("TrainingSet:");
for example in dataset.examples() {
println!(
"({} -> {})",
format_token_id(example.first()),
format_token_id(example.second())
);
}
println!();
println!("Typed boundaries:");
println!("usize -> TokenId");
println!("Vec<TokenId> -> TokenSequence");
println!("TokenSequence -> TrainingSet");
println!("TrainingExample = Product<TokenId, TokenId>");
Ok(())
}
fn format_token_sequence(tokens: &[TokenId]) -> String {
let formatted = tokens
.iter()
.map(format_token_id)
.collect::<Vec<_>>()
.join(", ");
format!("[{formatted}]")
}
fn format_token_id(token: &TokenId) -> String {
format!("TokenId({})", token.index())
}
examples/02_morphism_composition.rs
use category_theory_transformer_rs::{
Compose, CtResult, Distribution, Embedding, LinearToLogits, Logits, ModelDimension, Morphism,
Parameters, Softmax, TokenId, Vector, VocabSize,
};
fn main() -> CtResult<()> {
let params = Parameters::init(VocabSize::new(5)?, ModelDimension::new(4)?);
let token = TokenId::new(1);
let embedding = Embedding::from_parameters(&params);
let linear = LinearToLogits::from_parameters(&params);
let token_to_logits = Compose::<_, _, Vector>::new(embedding.clone(), linear.clone());
let token_to_distribution = Compose::<_, _, Logits>::new(token_to_logits, Softmax);
let vector = embedding.apply(token)?;
let logits = linear.apply(vector.clone())?;
let distribution = Softmax.apply(logits.clone())?;
let composed_distribution = token_to_distribution.apply(token)?;
println!("Input object:");
println!("TokenId({})", token.index());
println!();
println!("Stage outputs:");
println!("Embedding : TokenId -> Vector");
println!("{}", format_vector(&vector));
println!("LinearToLogits : Vector -> Logits");
println!("{}", format_logits(&logits));
println!("Softmax : Logits -> Distribution");
println!("{}", format_distribution(&distribution));
println!();
println!("Composed morphism:");
println!("TokenId -> Distribution");
println!(
"next-token probabilities: {}",
format_values(composed_distribution.as_slice())
);
println!();
println!("Middle objects kept visible:");
println!("Vector");
println!("Logits");
println!();
println!("Composition rule:");
println!("first target must equal second source");
println!("Embedding then LinearToLogits is legal because Vector == Vector");
println!("Embedding then Softmax is illegal because Vector != Logits");
Ok(())
}
fn format_vector(vector: &Vector) -> String {
format!(
"Vector(dim={}, values={})",
vector.as_slice().len(),
format_values(vector.as_slice())
)
}
fn format_logits(logits: &Logits) -> String {
format!(
"Logits(vocab={}, values={})",
logits.as_slice().len(),
format_values(logits.as_slice())
)
}
fn format_distribution(distribution: &Distribution) -> String {
format!(
"Distribution(vocab={}, sum={:.6}, values={})",
distribution.as_slice().len(),
distribution.as_slice().iter().sum::<f32>(),
format_values(distribution.as_slice())
)
}
fn format_values(values: &[f32]) -> String {
let formatted = values
.iter()
.map(|value| format!("{value:.6}"))
.collect::<Vec<_>>()
.join(", ");
format!("[{formatted}]")
}
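The example closes by stating the composition rule: the first morphism's target must equal the second morphism's source. The minimal generic sketch below shows how Rust can enforce that rule at compile time. The `compose` function, the toy object types, and the stage functions `embed`, `linear`, and `softmax` are hypothetical stand-ins with toy arithmetic. They are not the crate's `Compose`, `Embedding`, `LinearToLogits`, or `Softmax`.

```rust
// Toy object types standing in for the crate's domain objects.
#[derive(Debug, Clone, Copy)]
struct TokenId(usize);
#[derive(Debug, Clone)]
struct Vector(Vec<f32>);
#[derive(Debug, Clone)]
struct Logits(Vec<f32>);
#[derive(Debug, Clone)]
struct Distribution(Vec<f32>);

/// g after f is only well-typed when f's target B is g's source B.
fn compose<A, B, C>(f: impl Fn(A) -> B, g: impl Fn(B) -> C) -> impl Fn(A) -> C {
    move |a| g(f(a))
}

fn embed(token: TokenId) -> Vector {
    // Toy one-hot embedding of dimension 3.
    let mut v = vec![0.0; 3];
    v[token.0 % 3] = 1.0;
    Vector(v)
}

fn linear(v: Vector) -> Logits {
    // Toy "linear layer": scale every coordinate by 2.
    Logits(v.0.iter().map(|x| 2.0 * x).collect())
}

fn softmax(logits: Logits) -> Distribution {
    // Subtract the max for numerical stability before exponentiating.
    let max = logits.0.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.0.iter().map(|x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    Distribution(exps.iter().map(|e| e / sum).collect())
}

fn main() {
    let token_to_distribution = compose(compose(embed, linear), softmax);
    println!("{:?}", token_to_distribution(TokenId(1)));
    // compose(embed, softmax) is rejected by the compiler:
    // softmax expects `Logits` but embed produces `Vector`.
}
```

Because the middle type `B` is a generic parameter shared by both arguments, a mismatched pipeline fails at compile time instead of at run time.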
examples/03_training_endomorphism.rs
use category_theory_transformer_rs::{
CtResult, DatasetWindowing, LearningRate, ModelDimension, Morphism, Parameters, StepCount,
TokenSequence, TrainStep, VocabSize, apply_endomorphism_n_times, average_loss,
};
fn main() -> CtResult<()> {
let tokens = TokenSequence::from_indices([1, 2, 3, 4, 1, 2, 3, 4])?;
let dataset = DatasetWindowing.apply(tokens)?;
let params = Parameters::init(VocabSize::new(5)?, ModelDimension::new(4)?);
let before = average_loss(&params, &dataset)?;
let train_step = TrainStep::new(dataset.clone(), LearningRate::new(1.0)?);
let trained = apply_endomorphism_n_times(&train_step, params, StepCount::new(80))?;
let after = average_loss(&trained, &dataset)?;
println!("loss before: {:.6}", before.value());
println!("loss after: {:.6}", after.value());
println!();
println!("Typed transformation:");
println!("TrainStep : Parameters -> Parameters");
println!("Repeated endomorphism:");
println!("Parameters0 -> Parameters1 -> ... -> Parameters80");
println!("Measurement:");
println!("Parameters x TrainingSet -> Loss");
Ok(())
}
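`apply_endomorphism_n_times` repeats a `Parameters -> Parameters` morphism. The generic sketch below illustrates the idea under stated assumptions. The `Endo` trait, `iterate_n`, and the toy gradient-descent step on loss(w) = (w - 3)² are all hypothetical, and the crate's own signature may differ (for example, it threads `CtResult` rather than `Result<_, String>`). Each toy step multiplies the distance to the target by 1 - 2 · 0.1 = 0.8, so 80 steps drive the loss very close to zero.

```rust
/// Hypothetical fallible endomorphism: a map from a type back to itself.
trait Endo<T> {
    fn apply(&self, x: T) -> Result<T, String>;
}

/// Toy gradient-descent step for loss(w) = (w - target)^2.
struct GradientStep {
    target: f64,
    learning_rate: f64,
}

impl Endo<f64> for GradientStep {
    fn apply(&self, w: f64) -> Result<f64, String> {
        let gradient = 2.0 * (w - self.target);
        let next = w - self.learning_rate * gradient;
        if next.is_finite() {
            Ok(next)
        } else {
            Err("parameter diverged".to_string())
        }
    }
}

/// Apply the same endomorphism n times: x0 -> x1 -> ... -> xn.
fn iterate_n<T, E: Endo<T>>(step: &E, mut x: T, n: usize) -> Result<T, String> {
    for _ in 0..n {
        x = step.apply(x)?;
    }
    Ok(x)
}

fn main() -> Result<(), String> {
    let step = GradientStep { target: 3.0, learning_rate: 0.1 };
    let loss = |w: f64| (w - 3.0).powi(2);
    let w0 = 0.0;
    let w80 = iterate_n(&step, w0, 80)?;
    println!("loss before: {:.6}", loss(w0));
    println!("loss after:  {:.6}", loss(w80));
    Ok(())
}
```

Iterating 30 times and then 50 more times gives exactly the same result as iterating 80 times. That is the sense in which repeated training steps compose.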
examples/04_structure_and_calculus.rs
use category_theory_transformer_rs::{
CtResult, Functor, LocalGradient, Monoid, MulOp, OptionFunctor, PipelineTrace, Scalar,
TraceStep, VecFunctor, monoid_laws_hold_for_pipeline_trace,
naturality_square_holds_for_first_option,
};
fn main() -> CtResult<()> {
let squared = VecFunctor::fmap(vec![1, 2, 3], |x| x * x);
let shifted = OptionFunctor::fmap(Some(7), |x| x + 1);
println!("Vec fmap square: {squared:?}");
println!("Option fmap +1: {shifted:?}");
println!(
"naturality square holds: {}",
naturality_square_holds_for_first_option()
);
let trace = PipelineTrace::from_steps(vec![TraceStep::new("embedding")])
.combine(&PipelineTrace::from_steps(vec![TraceStep::new("linear")]))
.combine(&PipelineTrace::from_steps(vec![TraceStep::new("softmax")]));
println!("trace: {:?}", trace.names());
println!(
"monoid laws hold: {}",
monoid_laws_hold_for_pipeline_trace()
);
let mul = MulOp;
let x = Scalar::new(2.0)?;
let y = Scalar::new(3.0)?;
let upstream = LocalGradient::new(1.0)?;
let (dl_dx, dl_dy) = mul.backward(x, y, upstream)?;
println!("dL/dx: {}", dl_dx.value());
println!("dL/dy: {}", dl_dy.value());
println!();
println!("Typed transformation:");
println!("VecFunctor::fmap : Vec<A> x (A -> B) -> Vec<B>");
println!("OptionFunctor::fmap : Option<A> x (A -> B) -> Option<B>");
println!("Naturality square:");
println!("Vec<A> -> Vec<B> -> Option<B>");
println!("Vec<A> -> Option<A> -> Option<B>");
println!("Monoid:");
println!("PipelineTrace x PipelineTrace -> PipelineTrace");
println!("Chain rule:");
println!("Scalar x Scalar -> Scalar");
println!("dL/dz -> (dL/dx, dL/dy)");
Ok(())
}
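The naturality check in this example compares two paths around a square. Consider the "first element" transformation Vec<A> -> Option<A>. Mapping a function over the Vec and then taking the first element must give the same result as taking the first element and then mapping over the Option. The standalone sketch below runs that check for a given function and input. `first` and `naturality_holds` are hypothetical helpers, and the crate's `naturality_square_holds_for_first_option` may test different inputs.

```rust
/// Component of the natural transformation at type A: Vec<A> -> Option<A>.
fn first<A: Clone>(xs: &[A]) -> Option<A> {
    xs.first().cloned()
}

/// Check first(fmap_vec(f, xs)) == fmap_option(f, first(xs)).
fn naturality_holds<A, B, F>(xs: &[A], f: F) -> bool
where
    A: Clone,
    B: Clone + PartialEq,
    F: Fn(A) -> B,
{
    let mapped: Vec<B> = xs.iter().cloned().map(&f).collect();
    let map_then_first = first(&mapped);
    let first_then_map = first(xs).map(&f);
    map_then_first == first_then_map
}

fn main() {
    let square = |x: i32| x * x;
    let empty: [i32; 0] = [];
    println!("non-empty: {}", naturality_holds(&[1, 2, 3], square));
    println!("empty: {}", naturality_holds(&empty, square));
    println!("strings: {}", naturality_holds(&["a", "bb"], |s: &str| s.len()));
}
```

The empty input matters here. Both paths yield `None`, so the square commutes even when there is no element to map.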
examples/05_seven_sketches.rs
use category_theory_transformer_rs::{
CircuitComponent, CompanyInstance, CtResult, DepartmentId, DesignRequirement, EmployeeId,
EmployeeRecord, FeasibilityRelation, FeatureCount, ImplementationOffer, InformationLevel,
LatencyMs, LayerBudget, LocalSafetyCheck, MatrixCols, MatrixRows, OpenCircuit, PortName,
ResistanceOhms, ResourceAmount, ResourceBundle, SafetyCover, SignalCoefficient, SignalMatrix,
Throughput, TimeInterval, TimeTick, TruthValue, abstract_to_layer_budget,
concretize_layer_budget, feature_layer_galois_law_holds, information_order_obeys_preorder_laws,
resource_tensor_is_monotone,
};
fn main() -> CtResult<()> {
println!(
"orders obey preorder laws: {}",
information_order_obeys_preorder_laws()
);
println!(
"join(feature, decision): {:?}",
InformationLevel::Feature.join(InformationLevel::Decision)
);
let features = FeatureCount::new(9)?;
let layers = LayerBudget::new(3)?;
println!(
"feature/layer Galois law: {}",
feature_layer_galois_law_holds(features, layers)?
);
println!(
"abstract 9 features to {} layers; concretize 3 layers to {} features",
abstract_to_layer_budget(features)?.value(),
concretize_layer_budget(layers).value()
);
let encoder = ResourceBundle::new(ResourceAmount::new(2), ResourceAmount::new(8));
let decoder = ResourceBundle::new(ResourceAmount::new(3), ResourceAmount::new(10));
println!("combined resource bundle: {:?}", encoder.tensor(&decoder));
println!(
"resource tensor monotone: {}",
resource_tensor_is_monotone()
);
let research = DepartmentId::new(1);
let platform = DepartmentId::new(2);
let ada = EmployeeId::new(7);
let instance =
CompanyInstance::new([research, platform], [EmployeeRecord::new(ada, research)])?;
println!(
"employee {:?} belongs to department {:?}",
ada,
instance.department_of(ada)
);
let requirement = DesignRequirement::new(Throughput::new(100)?, LatencyMs::new(80)?);
let offer = ImplementationOffer::new(Throughput::new(120)?, LatencyMs::new(50)?);
println!(
"co-design offer feasible: {}",
FeasibilityRelation::relates(requirement, offer)
);
let duplicate = SignalMatrix::new(
MatrixRows::new(2)?,
MatrixCols::new(1)?,
vec![
vec![SignalCoefficient::new(1)],
vec![SignalCoefficient::new(1)],
],
)?;
let add_weighted = SignalMatrix::new(
MatrixRows::new(1)?,
MatrixCols::new(2)?,
vec![vec![SignalCoefficient::new(2), SignalCoefficient::new(3)]],
)?;
let composed = add_weighted.compose_after(&duplicate)?;
println!(
"signal-flow matrix semantics: {:?}",
composed.coefficients()
);
let input = PortName::new("input")?;
let middle = PortName::new("middle")?;
let output = PortName::new("output")?;
let first_circuit = OpenCircuit::new(
[input],
[middle],
[CircuitComponent::resistor(
input,
middle,
ResistanceOhms::new(10)?,
)],
)?;
let second_circuit = OpenCircuit::new(
[middle],
[output],
[CircuitComponent::resistor(
middle,
output,
ResistanceOhms::new(20)?,
)],
)?;
println!(
"serial circuit component count: {}",
first_circuit.then(&second_circuit)?.component_count()
);
let safety = SafetyCover::new([
LocalSafetyCheck::new(
TimeInterval::new(TimeTick::new(0), TimeTick::new(5))?,
TruthValue::True,
),
LocalSafetyCheck::new(
TimeInterval::new(TimeTick::new(5), TimeTick::new(10))?,
TruthValue::True,
),
])?;
println!("global behavior truth: {:?}", safety.global_truth());
println!();
println!("Typed transformation:");
println!("InformationLevel <= InformationLevel checks preorder");
println!("FeatureCount <-> LayerBudget checks Galois law");
println!("ResourceBundle x ResourceBundle -> ResourceBundle");
println!("EmployeeRecord -> DepartmentId must resolve in CompanyInstance");
println!("DesignRequirement x ImplementationOffer -> bool");
println!("SignalMatrix x SignalMatrix -> SignalMatrix when dimensions match");
println!("OpenCircuit x OpenCircuit -> OpenCircuit when ports match");
println!("SafetyCover -> TruthValue");
Ok(())
}
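The `examples/06_attention_scores.rs` listing feeds small hand-picked Q/K/V matrices through scores, masking, softmax, and value mixing. The standalone sketch below runs the same pipeline on plain `Vec<Vec<f32>>` under three assumptions. Scores are divided by the square root of the key dimension. Blocked positions are set to negative infinity before a row-wise softmax, so they receive exactly zero weight. `true` means allowed. The crate's `ScaledDotProductScores`, `MaskedAttentionScores`, `AttentionSoftmax`, and `WeightedValueMixing` may differ in these details, and all function names here are hypothetical. For the first query row, the middle key is blocked and the two remaining scores tie. The output is therefore the average of value rows 0 and 2, namely [2, 20].

```rust
type Matrix = Vec<Vec<f32>>;

fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// Scaled dot-product scores: S[i][j] = q_i . k_j / sqrt(d_k).
fn scaled_scores(queries: &Matrix, keys: &Matrix) -> Matrix {
    let scale = (keys[0].len() as f32).sqrt();
    queries
        .iter()
        .map(|q| keys.iter().map(|k| dot(q, k) / scale).collect::<Vec<f32>>())
        .collect()
}

/// Blocked entries (false) become -inf so that softmax gives them zero weight.
fn apply_mask(scores: &Matrix, mask: &[Vec<bool>]) -> Matrix {
    scores
        .iter()
        .zip(mask)
        .map(|(row, allowed)| {
            row.iter()
                .zip(allowed)
                .map(|(&s, &ok)| if ok { s } else { f32::NEG_INFINITY })
                .collect::<Vec<f32>>()
        })
        .collect()
}

/// Row-wise softmax; assumes every row has at least one allowed entry.
fn softmax_rows(scores: &Matrix) -> Matrix {
    scores
        .iter()
        .map(|row| {
            let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let exps: Vec<f32> = row.iter().map(|s| (s - max).exp()).collect();
            let sum: f32 = exps.iter().sum();
            exps.iter().map(|e| e / sum).collect::<Vec<f32>>()
        })
        .collect()
}

/// Each output row is the attention-weighted mix of the value rows.
fn mix_values(weights: &Matrix, values: &Matrix) -> Matrix {
    weights
        .iter()
        .map(|w| {
            (0..values[0].len())
                .map(|c| w.iter().zip(values).map(|(wi, v)| wi * v[c]).sum::<f32>())
                .collect::<Vec<f32>>()
        })
        .collect()
}

fn main() {
    let queries = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
    let keys = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];
    let values = vec![vec![1.0, 10.0], vec![2.0, 20.0], vec![3.0, 30.0]];
    let mask = vec![vec![true, false, true], vec![true, true, true]];

    let weights = softmax_rows(&apply_mask(&scaled_scores(&queries, &keys), &mask));
    let output = mix_values(&weights, &values);
    for (i, (w, o)) in weights.iter().zip(&output).enumerate() {
        println!("query {i}: weights {w:?}, output {o:?}");
    }
}
```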
examples/06_attention_scores.rs
use category_theory_transformer_rs::{
AttentionHeadOutputs, AttentionMask, AttentionOutput, AttentionOutputProjection,
AttentionSoftmax, ConcatenateHeads, CtResult, HiddenSequence, HiddenToKey, HiddenToQuery,
HiddenToValue, KeySequence, LayerNormParameters, LayerNormalization, LearningRate,
MaskedAttentionScores, MaskedMultiHeadTransformerBlock, Morphism, MultiHeadTransformerBlock,
PositionWiseFeedForward, PositionalEncoding, Product, QuerySequence, ResidualConnection,
ScaledDotProductScores, SelfAttentionHead, SingleHeadTransformerBlock,
TinyTransformerParameters, TokenSequence, TransformerBlockTrainStep,
TransformerBlockTrainingExample, TransformerBlockTrainingSet, TransformerFeedForwardTrainStep,
TransformerFeedForwardTrainingExample, TransformerFeedForwardTrainingSet, TransformerReadout,
TransformerReadoutTrainStep, TransformerReadoutTrainingExample, TransformerReadoutTrainingSet,
TransformerTrainingState, ValueSequence, WeightedValueMixing, transformer_block_average_loss,
transformer_feed_forward_average_loss, transformer_readout_average_loss,
};
fn main() -> CtResult<()> {
let queries = QuerySequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
let keys = KeySequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]])?;
let values = ValueSequence::new(vec![vec![1.0, 10.0], vec![2.0, 20.0], vec![3.0, 30.0]])?;
let mask = AttentionMask::new(vec![vec![true, false, true], vec![true, true, true]])?;
println!("Q/K/V source diagnostic:");
println!("query rows own score rows; key/value rows own score columns");
println!(
"self-attention shares the hidden source before projection; projected roles stay distinct"
);
println!("mask polarity here: true = allowed, false = blocked\n");
let scores = ScaledDotProductScores.apply(Product::new(queries, keys))?;
let masked_scores = MaskedAttentionScores.apply(Product::new(scores, mask))?;
let weights = AttentionSoftmax.apply(masked_scores)?;
println!(
"attention shape: {} query positions x {} key positions",
weights.query_len().value(),
weights.key_len().value()
);
for (query_position, row) in weights.rows().iter().enumerate() {
println!("query {query_position} attends with {:?}", row.as_slice());
}
let output = WeightedValueMixing.apply(Product::new(weights, values))?;
for (query_position, row) in output.rows().iter().enumerate() {
println!("query {query_position} output vector {:?}", row.as_slice());
}
let second_head = AttentionOutput::new(vec![vec![10.0, 1.0], vec![20.0, 2.0]])?;
let head_outputs = AttentionHeadOutputs::new(vec![output, second_head])?;
let multi_head = ConcatenateHeads.apply(head_outputs)?;
println!(
"multi-head shape: {} heads x {} features -> model dimension {}",
multi_head.head_count().value(),
multi_head.head_dimension().value(),
multi_head.model_dimension().value()
);
for (query_position, row) in multi_head.rows().iter().enumerate() {
println!("query {query_position} multi-head row {:?}", row.as_slice());
}
let output_projection = AttentionOutputProjection::new(
vec![
vec![1.0, 0.0],
vec![0.0, 0.1],
vec![0.5, 0.0],
vec![0.0, 1.0],
],
vec![0.0, 0.0],
)?;
let projected = output_projection.apply(multi_head)?;
println!(
"projected attention shape: {} positions x model dimension {}",
projected.sequence_len().value(),
projected.model_dimension().value()
);
for (query_position, row) in projected.rows().iter().enumerate() {
println!(
"query {query_position} projected attention row {:?}",
row.as_slice()
);
}
let hidden_input = HiddenSequence::new(vec![vec![0.5, 0.5], vec![1.0, 1.0]])?;
let residual = ResidualConnection.apply(Product::new(hidden_input, projected))?;
println!(
"residual shape: {} positions x model dimension {}",
residual.sequence_len().value(),
residual.model_dimension().value()
);
for (query_position, row) in residual.rows().iter().enumerate() {
println!("query {query_position} residual row {:?}", row.as_slice());
}
let layer_norm =
LayerNormalization::new(LayerNormParameters::identity(residual.model_dimension()));
let normalized = layer_norm.apply(residual)?;
println!(
"normalized shape: {} positions x model dimension {}",
normalized.sequence_len().value(),
normalized.model_dimension().value()
);
for (query_position, row) in normalized.rows().iter().enumerate() {
println!("query {query_position} normalized row {:?}", row.as_slice());
}
let feed_forward = PositionWiseFeedForward::new(
vec![vec![1.0, -1.0, 0.5], vec![0.0, 1.0, 0.5]],
vec![0.0, 0.0, 0.0],
vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![0.5, 0.5]],
vec![0.0, 0.0],
)?;
let fed_forward = feed_forward.apply(normalized)?;
println!(
"feed-forward shape: {} positions x model dimension {}",
fed_forward.sequence_len().value(),
fed_forward.model_dimension().value()
);
for (query_position, row) in fed_forward.rows().iter().enumerate() {
println!(
"query {query_position} feed-forward row {:?}",
row.as_slice()
);
}
let positional_encoding = PositionalEncoding::new(vec![vec![0.1, 0.0], vec![0.0, 0.1]])?;
let positioned_hidden =
positional_encoding.apply(HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?)?;
println!(
"positioned hidden shape: {} positions x model dimension {}",
positioned_hidden.sequence_len().value(),
positioned_hidden.model_dimension().value()
);
let block = SingleHeadTransformerBlock::new(
HiddenToQuery::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?,
HiddenToKey::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?,
HiddenToValue::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?,
AttentionOutputProjection::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?,
LayerNormalization::new(LayerNormParameters::identity(fed_forward.model_dimension())),
PositionWiseFeedForward::new(
vec![vec![1.0, 0.0], vec![0.0, 1.0]],
vec![0.0, 0.0],
vec![vec![1.0, 0.0], vec![0.0, 1.0]],
vec![0.0, 0.0],
)?,
LayerNormalization::new(LayerNormParameters::identity(fed_forward.model_dimension())),
)?;
let block_output = block.apply(positioned_hidden.clone())?;
println!(
"single-head block shape: {} positions x model dimension {}",
block_output.sequence_len().value(),
block_output.model_dimension().value()
);
let multi_head_block = MultiHeadTransformerBlock::new(
vec![
SelfAttentionHead::new(
HiddenToQuery::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
HiddenToKey::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
HiddenToValue::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
)?,
SelfAttentionHead::new(
HiddenToQuery::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
HiddenToKey::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
HiddenToValue::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
)?,
],
AttentionOutputProjection::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?,
LayerNormalization::new(LayerNormParameters::identity(
block_output.model_dimension(),
)),
PositionWiseFeedForward::new(
vec![vec![1.0, 0.0], vec![0.0, 1.0]],
vec![0.0, 0.0],
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![0.0, 0.0],
        )?,
        LayerNormalization::new(LayerNormParameters::identity(
            block_output.model_dimension(),
        )),
    )?;
    let multi_head_output = multi_head_block.apply(positioned_hidden)?;
    println!(
        "multi-head block shape: {} positions x {} heads x value dimension {} -> model dimension {}",
        multi_head_output.sequence_len().value(),
        multi_head_block.head_count().value(),
        multi_head_block.value_dimension().value(),
        multi_head_output.model_dimension().value()
    );
    let masked_multi_head_block = MaskedMultiHeadTransformerBlock::new(
        vec![
            SelfAttentionHead::new(
                HiddenToQuery::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
                HiddenToKey::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
                HiddenToValue::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
            )?,
            SelfAttentionHead::new(
                HiddenToQuery::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
                HiddenToKey::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
                HiddenToValue::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
            )?,
        ],
        AttentionOutputProjection::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?,
        LayerNormalization::new(LayerNormParameters::identity(
            multi_head_output.model_dimension(),
        )),
        PositionWiseFeedForward::new(
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![0.0, 0.0],
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![0.0, 0.0],
        )?,
        LayerNormalization::new(LayerNormParameters::identity(
            multi_head_output.model_dimension(),
        )),
    )?;
    let masked_block_output = masked_multi_head_block.apply(Product::new(
        HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?,
        AttentionMask::new(vec![vec![true, false], vec![true, true]])?,
    ))?;
    println!(
        "masked multi-head block shape: {} positions x model dimension {}",
        masked_block_output.sequence_len().value(),
        masked_block_output.model_dimension().value()
    );
    let transformer_parameters = TinyTransformerParameters::new(
        PositionalEncoding::new(vec![vec![0.1, 0.0], vec![0.0, 0.1]])?,
        masked_multi_head_block,
        TransformerReadout::new(
            vec![vec![1.0, 0.0, 0.5], vec![0.0, 1.0, -0.5]],
            vec![0.0, 0.0, 0.0],
        )?,
    )?;
    let transformer_state =
        TransformerTrainingState::new(transformer_parameters, LearningRate::new(0.1)?);
    let training_set =
        TransformerReadoutTrainingSet::new([TransformerReadoutTrainingExample::new(
            HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?,
            AttentionMask::new(vec![vec![true, false], vec![true, true]])?,
            TokenSequence::from_indices([0, 1])?,
        )?])?;
    let sequence_logits = transformer_state.apply(Product::new(
        HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?,
        AttentionMask::new(vec![vec![true, false], vec![true, true]])?,
    ))?;
    let loss_before = transformer_readout_average_loss(&transformer_state, &training_set)?;
    let train_step = TransformerReadoutTrainStep::new(training_set.clone());
    let next_state = train_step.apply(transformer_state.clone())?;
    let loss_after = transformer_readout_average_loss(&next_state, &training_set)?;
    println!(
        "structured transformer logits shape: {} positions x vocabulary size {}",
        sequence_logits.sequence_len().value(),
        sequence_logits.vocab_size().value()
    );
    println!(
        "training state step: {} -> {}",
        transformer_state.step_count().value(),
        next_state.step_count().value()
    );
    println!(
        "readout loss after one update: {:.6} -> {:.6}",
        loss_before.value(),
        loss_after.value()
    );
    let feed_forward_training_set =
        TransformerFeedForwardTrainingSet::new([TransformerFeedForwardTrainingExample::new(
            HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?,
            HiddenSequence::new(vec![vec![2.0, 0.0], vec![0.0, 2.0]])?,
        )?])?;
    let feed_forward_loss_before =
        transformer_feed_forward_average_loss(&next_state, &feed_forward_training_set)?;
    let feed_forward_train_step =
        TransformerFeedForwardTrainStep::new(feed_forward_training_set.clone());
    let feed_forward_state = feed_forward_train_step.apply(next_state.clone())?;
    let feed_forward_loss_after =
        transformer_feed_forward_average_loss(&feed_forward_state, &feed_forward_training_set)?;
    println!(
        "feed-forward loss after one local update: {:.6} -> {:.6}",
        feed_forward_loss_before.value(),
        feed_forward_loss_after.value()
    );
    let block_training_set =
        TransformerBlockTrainingSet::new([TransformerBlockTrainingExample::new(
            HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?,
            AttentionMask::new(vec![vec![true, false], vec![true, true]])?,
            TokenSequence::from_indices([0, 1])?,
        )?])?;
    let block_loss_before =
        transformer_block_average_loss(&feed_forward_state, &block_training_set)?;
    let block_train_step = TransformerBlockTrainStep::new(block_training_set.clone());
    let block_trained_state = block_train_step.apply(feed_forward_state)?;
    let block_loss_after =
        transformer_block_average_loss(&block_trained_state, &block_training_set)?;
    println!(
        "block loss after one composed update: {:.6} -> {:.6}",
        block_loss_before.value(),
        block_loss_after.value()
    );
    println!();
    println!("Typed transformation:");
    println!("HiddenSequence -> QuerySequence");
    println!("HiddenSequence -> KeySequence");
    println!("HiddenSequence -> ValueSequence");
    println!("QuerySequence x KeySequence -> AttentionScores");
    println!("AttentionScores x AttentionMask -> AttentionScores");
    println!("AttentionScores -> AttentionWeights");
    println!("AttentionWeights x ValueSequence -> AttentionOutput");
    println!("AttentionHeadOutputs -> MultiHeadOutput");
    println!("MultiHeadOutput -> ProjectedAttentionOutput");
    println!("HiddenSequence x ProjectedAttentionOutput -> HiddenSequence");
    println!("LayerNormalization : HiddenSequence -> HiddenSequence");
    println!("PositionWiseFeedForward : HiddenSequence -> HiddenSequence");
    println!("PositionalEncoding : HiddenSequence -> HiddenSequence");
    println!("SingleHeadTransformerBlock : HiddenSequence -> HiddenSequence");
    println!("MultiHeadTransformerBlock : HiddenSequence -> HiddenSequence");
    println!("MaskedMultiHeadTransformerBlock : HiddenSequence x AttentionMask -> HiddenSequence");
    println!("TinyTransformerParameters : HiddenSequence x AttentionMask -> SequenceLogits");
    println!("TransformerTrainingState owns parameters, learning rate, and step count");
    println!("TransformerReadoutTrainStep : TransformerTrainingState -> TransformerTrainingState");
    println!(
        "TransformerFeedForwardTrainStep : TransformerTrainingState -> TransformerTrainingState"
    );
    println!("TransformerBlockTrainStep : TransformerTrainingState -> TransformerTrainingState");
    Ok(())
}
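The typed-transformation list printed above contains two attention arrows, `AttentionScores x AttentionMask -> AttentionScores` and `AttentionScores -> AttentionWeights`. A minimal sketch of both on plain `Vec<Vec<f32>>` values follows. It assumes that a `true` mask entry means "this position may be attended to", which matches the causal mask `[[true, false], [true, true]]` used in the example, and that masked scores become negative infinity before a row-wise softmax. `apply_mask` and `row_softmax` are hypothetical helpers, not the crate's API.

```rust
/// Hypothetical helper for `AttentionScores x AttentionMask -> AttentionScores`.
/// Assumes `mask[i][j] == true` means position `i` may attend to position `j`.
fn apply_mask(scores: &[Vec<f32>], mask: &[Vec<bool>]) -> Result<Vec<Vec<f32>>, String> {
    if scores.len() != mask.len() {
        return Err(format!(
            "row mismatch: {} score rows vs {} mask rows",
            scores.len(),
            mask.len()
        ));
    }
    scores
        .iter()
        .zip(mask)
        .map(|(score_row, mask_row)| {
            if score_row.len() != mask_row.len() {
                return Err("column mismatch between scores and mask".to_string());
            }
            // Disallowed positions get -inf so that softmax gives them zero weight.
            Ok(score_row
                .iter()
                .zip(mask_row)
                .map(|(&s, &allowed)| if allowed { s } else { f32::NEG_INFINITY })
                .collect::<Vec<f32>>())
        })
        .collect()
}

/// Hypothetical helper for `AttentionScores -> AttentionWeights`: row-wise softmax.
fn row_softmax(scores: &[Vec<f32>]) -> Result<Vec<Vec<f32>>, String> {
    scores
        .iter()
        .map(|row| {
            // Subtract the row maximum for numerical stability.
            let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            if max == f32::NEG_INFINITY {
                return Err("every position in a row is masked".to_string());
            }
            let exps: Vec<f32> = row.iter().map(|&s| (s - max).exp()).collect();
            let total: f32 = exps.iter().sum();
            Ok(exps.iter().map(|e| e / total).collect::<Vec<f32>>())
        })
        .collect()
}

fn main() -> Result<(), String> {
    let scores = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
    let mask = vec![vec![true, false], vec![true, true]];
    let weights = row_softmax(&apply_mask(&scores, &mask)?)?;
    println!("attention weights: {weights:?}");
    Ok(())
}
```

With the causal mask, the first position can only attend to itself, so its weight row is exactly `[1.0, 0.0]`; the second row spreads its weight over both positions and still sums to one.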
examples/07_transformer_training_state.rs
use category_theory_transformer_rs::{
    AttentionMask, AttentionOutputProjection, CtResult, HiddenSequence, HiddenToKey, HiddenToQuery,
    HiddenToValue, LayerNormParameters, LayerNormalization, LearningRate,
    MaskedMultiHeadTransformerBlock, ModelDimension, Morphism, PositionWiseFeedForward,
    PositionalEncoding, Product, SelfAttentionHead, TinyTransformerParameters, TokenSequence,
    TransformerBlockTrainStep, TransformerBlockTrainingExample, TransformerBlockTrainingSet,
    TransformerFeedForwardTrainStep, TransformerFeedForwardTrainingExample,
    TransformerFeedForwardTrainingSet, TransformerReadout, TransformerReadoutTrainStep,
    TransformerReadoutTrainingExample, TransformerReadoutTrainingSet, TransformerTrainingState,
    transformer_block_average_loss, transformer_feed_forward_average_loss,
    transformer_readout_average_loss,
};

fn main() -> CtResult<()> {
    let hidden = HiddenSequence::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]])?;
    let mask = AttentionMask::new(vec![vec![true, false], vec![true, true]])?;
    let targets = TokenSequence::from_indices([0, 1])?;
    let initial_state = tiny_training_state()?;
    let logits = initial_state.apply(Product::new(hidden.clone(), mask.clone()))?;
    println!(
        "initial state: step={}, learning_rate={:.3}, model_dimension={}, vocab_size={}",
        initial_state.step_count().value(),
        initial_state.learning_rate().value(),
        initial_state.parameters().model_dimension().value(),
        initial_state.parameters().vocab_size().value()
    );
    println!(
        "forward shape: {} positions x vocabulary size {}",
        logits.sequence_len().value(),
        logits.vocab_size().value()
    );

    let readout_set =
        TransformerReadoutTrainingSet::new([TransformerReadoutTrainingExample::new(
            hidden.clone(),
            mask.clone(),
            targets.clone(),
        )?])?;
    let readout_loss_before = transformer_readout_average_loss(&initial_state, &readout_set)?;
    let readout_state =
        TransformerReadoutTrainStep::new(readout_set.clone()).apply(initial_state)?;
    let readout_loss_after = transformer_readout_average_loss(&readout_state, &readout_set)?;
    print_update(
        "readout update",
        0,
        &readout_state,
        readout_loss_before.value(),
        readout_loss_after.value(),
    );

    let feed_forward_set =
        TransformerFeedForwardTrainingSet::new([TransformerFeedForwardTrainingExample::new(
            hidden.clone(),
            HiddenSequence::new(vec![vec![2.0, 0.0], vec![0.0, 2.0]])?,
        )?])?;
    let feed_forward_loss_before =
        transformer_feed_forward_average_loss(&readout_state, &feed_forward_set)?;
    let feed_forward_state =
        TransformerFeedForwardTrainStep::new(feed_forward_set.clone()).apply(readout_state)?;
    let feed_forward_loss_after =
        transformer_feed_forward_average_loss(&feed_forward_state, &feed_forward_set)?;
    print_update(
        "feed-forward update",
        1,
        &feed_forward_state,
        feed_forward_loss_before.value(),
        feed_forward_loss_after.value(),
    );

    let block_set = TransformerBlockTrainingSet::new([TransformerBlockTrainingExample::new(
        hidden, mask, targets,
    )?])?;
    let block_loss_before = transformer_block_average_loss(&feed_forward_state, &block_set)?;
    let block_state =
        TransformerBlockTrainStep::new(block_set.clone()).apply(feed_forward_state)?;
    let block_loss_after = transformer_block_average_loss(&block_state, &block_set)?;
    print_update(
        "composed block update",
        2,
        &block_state,
        block_loss_before.value(),
        block_loss_after.value(),
    );

    println!();
    println!("Typed transformation:");
    println!("TinyTransformerParameters : HiddenSequence x AttentionMask -> SequenceLogits");
    println!("TransformerTrainingState owns parameters, learning rate, and step count");
    println!("TransformerReadoutTrainStep : TransformerTrainingState -> TransformerTrainingState");
    println!(
        "TransformerFeedForwardTrainStep : TransformerTrainingState -> TransformerTrainingState"
    );
    println!("TransformerBlockTrainStep : TransformerTrainingState -> TransformerTrainingState");
    println!("Every update returns a full training state, not loose changed weights.");
    Ok(())
}

fn print_update(
    label: &str,
    previous_step: usize,
    state: &TransformerTrainingState,
    loss_before: f32,
    loss_after: f32,
) {
    println!(
        "{label}: step {} -> {}, loss {:.6} -> {:.6}",
        previous_step,
        state.step_count().value(),
        loss_before,
        loss_after
    );
}

fn tiny_training_state() -> CtResult<TransformerTrainingState> {
    let model_dimension = ModelDimension::new(2)?;
    let block = MaskedMultiHeadTransformerBlock::new(
        vec![
            SelfAttentionHead::new(
                HiddenToQuery::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
                HiddenToKey::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
                HiddenToValue::new(vec![vec![1.0], vec![0.0]], vec![0.0])?,
            )?,
            SelfAttentionHead::new(
                HiddenToQuery::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
                HiddenToKey::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
                HiddenToValue::new(vec![vec![0.0], vec![1.0]], vec![0.0])?,
            )?,
        ],
        AttentionOutputProjection::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]], vec![0.0, 0.0])?,
        LayerNormalization::new(LayerNormParameters::identity(model_dimension)),
        PositionWiseFeedForward::new(
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![0.0, 0.0],
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![0.0, 0.0],
        )?,
        LayerNormalization::new(LayerNormParameters::identity(model_dimension)),
    )?;
    let parameters = TinyTransformerParameters::new(
        PositionalEncoding::new(vec![vec![0.1, 0.0], vec![0.0, 0.1]])?,
        block,
        TransformerReadout::new(
            vec![vec![1.0, 0.0, 0.5], vec![0.0, 1.0, -0.5]],
            vec![0.0, 0.0, 0.0],
        )?,
    )?;
    Ok(TransformerTrainingState::new(
        parameters,
        LearningRate::new(0.1)?,
    ))
}
examples/challenge_adam.rs
use category_theory_transformer_rs::{
    AdamConfig, AdamDecayRate, AdamEpsilon, AdamGradientVector, AdamModelState,
    AdamParameterVector, AdamTrainStep, CtResult, LearningRate, Morphism,
};

fn main() -> CtResult<()> {
    let state = AdamModelState::from_parameters(AdamParameterVector::new(vec![1.0, -1.0])?);
    let config = AdamConfig::new(
        LearningRate::new(0.1)?,
        AdamDecayRate::new(0.9)?,
        AdamDecayRate::new(0.999)?,
        AdamEpsilon::new(1e-8)?,
    );
    let step = AdamTrainStep::new(AdamGradientVector::new(vec![0.5, -0.25])?, config);
    let updated = step.apply(state)?;
    println!("Paper-To-Rust: Adam");
    println!("paper idea: optimizer state is part of the update boundary");
    println!("typed shape: AdamModelState -> AdamModelState");
    println!(
        "step count: {}, first moment: {:?}, second moment: {:?}",
        updated.optimizer().step_count().value(),
        updated.optimizer().first_moment().as_slice(),
        updated.optimizer().second_moment().as_slice()
    );
    println!("updated parameters: {:?}", updated.parameters().as_slice());
    println!("share line: Stop summarizing Adam. Compile optimizer state.");
    Ok(())
}
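The Adam example keeps the moment estimates and the step count inside `AdamModelState`, which is why the update is typed `AdamModelState -> AdamModelState`. For orientation, here is a sketch of the standard Adam update with bias correction on plain `f64` slices, using the example's values (learning rate 0.1, decay rates 0.9 and 0.999, epsilon 1e-8, gradient `[0.5, -0.25]`). Whether the crate's `AdamTrainStep` applies bias correction exactly this way is an assumption; `AdamState` and `adam_step` are hypothetical names.

```rust
/// Hypothetical optimizer state: step count plus first and second moments.
struct AdamState {
    step: u32,
    first_moment: Vec<f64>,
    second_moment: Vec<f64>,
}

/// One standard Adam update with bias correction, applied in place.
/// Returns an error (without touching any state) if lengths disagree.
fn adam_step(
    params: &mut [f64],
    grads: &[f64],
    state: &mut AdamState,
    lr: f64,
    beta1: f64,
    beta2: f64,
    eps: f64,
) -> Result<(), String> {
    if params.len() != grads.len()
        || state.first_moment.len() != params.len()
        || state.second_moment.len() != params.len()
    {
        return Err("parameter, gradient, and moment lengths must match".to_string());
    }
    state.step += 1;
    let t = state.step as i32;
    for i in 0..params.len() {
        let g = grads[i];
        // Exponential moving averages of the gradient and of the squared gradient.
        state.first_moment[i] = beta1 * state.first_moment[i] + (1.0 - beta1) * g;
        state.second_moment[i] = beta2 * state.second_moment[i] + (1.0 - beta2) * g * g;
        // Bias correction compensates for initialising both moments at zero.
        let m_hat = state.first_moment[i] / (1.0 - beta1.powi(t));
        let v_hat = state.second_moment[i] / (1.0 - beta2.powi(t));
        params[i] -= lr * m_hat / (v_hat.sqrt() + eps);
    }
    Ok(())
}

fn main() -> Result<(), String> {
    let mut params = vec![1.0, -1.0];
    let mut state = AdamState {
        step: 0,
        first_moment: vec![0.0; 2],
        second_moment: vec![0.0; 2],
    };
    adam_step(&mut params, &[0.5, -0.25], &mut state, 0.1, 0.9, 0.999, 1e-8)?;
    println!("step {}, parameters {:?}", state.step, params);
    Ok(())
}
```

On the first step, `m_hat / sqrt(v_hat)` reduces to the sign of each gradient, so both parameters move by about the learning rate: `[1.0, -1.0]` becomes roughly `[0.9, -0.9]`. Later steps depend on the accumulated moments, which is exactly the state the example keeps inside the type.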
Project Configuration
Cargo.toml
[package]
name = "category-theory-transformer-rs"
version = "0.1.0"
edition = "2024"
description = "A Rust book and lab for understanding tiny ML systems through category-theory structure."
repository = "https://github.com/hghalebi/category_theory_transformer_rs"
documentation = "https://hghalebi.github.io/category_theory_transformer_rs/"
license-file = "LICENSE.md"
[dependencies]
Companion Lesson Notes
These are the shorter markdown notes kept under lessons/.
lessons/README.md
# Category Theory for Tiny ML: Compact Lesson Path
This folder is the compact reading path through the codebase.
For the complete self-contained course, use the chapters in `book/src/`. They
include source snapshots, runnable examples, exercises, a glossary, references,
and the full source appendix.
These compact lessons are kept for quick review. Each lesson is intentionally
short and points to a real Rust file that `cargo` checks.
## The Learning Loop
For each lesson:
1. Read the mental model.
2. Open the named Rust module.
3. Run the named example.
4. Answer the checkpoint before moving on.
5. If the lesson names a boundary, find the test that proves the boundary.
## Lessons
1. [Map of the Course](00-map.md)
2. [Domain Objects](01-domain-objects.md)
3. [Morphism and Composition](02-morphisms-composition.md)
4. [The Tiny ML Pipeline](03-ml-pipeline.md)
5. [Training as an Endomorphism](04-training-endomorphism.md)
6. [Functors, Naturality, Monoids, and Chain Rule](05-structure-and-calculus.md)
7. [Seven Sketches Through Rust](06-seven-sketches.md)
## Validation
Run the full check:
```bash
cargo test --all-targets --all-features
```
Run one lesson example:
```bash
cargo run --example 03_training_endomorphism
```
lessons/00-map.md
# 00 - Map of the Course
## Goal
Learn one idea:
> Category theory is a language for typed transformations.
In this repo, the transformations are tiny ML operations:
- token to vector
- vector to logits
- logits to probabilities
- prediction plus target to loss
- parameters to better parameters
## The Code Map
- `src/domain.rs`: nouns, also called objects
- `src/category.rs`: arrows, identity, composition, endomorphisms
- `src/ml.rs`: ML arrows
- `src/training.rs`: one training step
- `src/structure.rs`: functor, natural transformation, monoid
- `src/calculus.rs`: local derivative example
- `src/demo.rs`: one guided terminal walkthrough
## First Run
```bash
cargo run --bin category_ml
```
You should see a tiny language-model pipeline and the loss decreasing after
training.
## Checkpoint
Before moving on, say this in your own words:
> A morphism is a typed function. Composition lets small typed functions become
> one larger typed pipeline.
lessons/01-domain-objects.md
# 01 - Domain Objects
## Mental Model
Objects are the nouns of the system.
In normal Rust code, it is tempting to pass `usize`, `Vec<f32>`, and `f32`
everywhere. That is easy at first but confusing later, because the compiler
cannot tell a token index from any other `usize`. This repo wraps those
values in small domain types.
## Read This File
Open `src/domain.rs`.
Focus only on these types first:
- `TokenId`
- `TokenSequence`
- `Vector`
- `Logits`
- `Distribution`
- `Loss`
- `TrainingSet`
- `Parameters`
## Run the Example
```bash
cargo run --example 01_domain_objects
```
Expected shape:
```text
training pairs:
1 -> 2
2 -> 3
3 -> 4
```
## Why This Matters
`TokenId` and `Loss` are not interchangeable, even if both could be represented
by numbers. The Rust type system keeps those ideas separate.
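A minimal sketch of that separation with tuple-struct newtypes. These `TokenId` and `Loss` definitions, including the vocabulary-size check and the non-negative-loss check, are illustrative assumptions rather than copies of `src/domain.rs`.

```rust
/// Illustrative newtype for a vocabulary index (not the repo's definition).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct TokenId(usize);

/// Illustrative newtype for a finite, non-negative loss value.
#[derive(Debug, Clone, Copy, PartialEq)]
struct Loss(f32);

impl TokenId {
    /// Assumed policy: reject indices outside the vocabulary at construction.
    fn new(index: usize, vocab_size: usize) -> Result<Self, String> {
        if index < vocab_size {
            Ok(TokenId(index))
        } else {
            Err(format!("token {index} is outside a vocabulary of size {vocab_size}"))
        }
    }
}

impl Loss {
    /// Assumed policy: a loss must be finite and non-negative.
    fn new(value: f32) -> Result<Self, String> {
        if value.is_finite() && value >= 0.0 {
            Ok(Loss(value))
        } else {
            Err(format!("loss must be finite and non-negative, got {value}"))
        }
    }
}

/// Accepts only a `TokenId`; a `Loss` or a raw `usize` is a compile error.
fn describe_target(target: TokenId) -> String {
    format!("target token #{}", target.0)
}

fn main() -> Result<(), String> {
    let target = TokenId::new(2, 5)?;
    let loss = Loss::new(0.25)?;
    println!("{} with loss {}", describe_target(target), loss.0);
    // describe_target(loss); // error: expected `TokenId`, found `Loss`
    Ok(())
}
```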
## Checkpoint
What bug becomes harder when `TokenId` is not just a raw `usize`?
lessons/02-morphisms-composition.md
# 02 - Morphism and Composition
## Mental Model
A morphism is a typed arrow:
```text
Input -> Output
```
In Rust, this repo models that with `Morphism<Input, Output>`.
## Read This File
Open `src/category.rs`.
Read in this order:
1. `Morphism`
2. `Identity`
3. `Compose`
4. `Endomorphism`
## Run the Example
```bash
cargo run --example 02_morphism_composition
```
## What to Notice
The pipeline is built from small arrows:
```text
TokenId -> Vector -> Logits -> Distribution
```
The example prints the middle objects:
```text
Vector = hidden features
Logits = vocabulary scores
Distribution = normalized probabilities
```
`Compose` is the glue. If the middle types do not match, Rust rejects the
program before it runs.
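A minimal sketch of how a `Morphism` trait and a `Compose` struct can turn "the middle types must match" into a compile-time requirement. The trait shape here, `apply(&self, input: A) -> Result<B, String>`, is a simplified assumption modeled on the examples' `apply(...)?` calls; `Embed`, `ToLogits`, and the toy weights are hypothetical.

```rust
use std::marker::PhantomData;

/// Simplified stand-in for the repo's `Morphism` trait (assumed shape).
trait Morphism<A, B> {
    fn apply(&self, input: A) -> Result<B, String>;
}

/// Runs `first`, then `second`. `B` names the shared middle type.
struct Compose<F, G, B> {
    first: F,
    second: G,
    _middle: PhantomData<B>,
}

impl<F, G, B> Compose<F, G, B> {
    fn new(first: F, second: G) -> Self {
        Compose { first, second, _middle: PhantomData }
    }
}

// A composite exists only when `first: A -> B` and `second: B -> C` share `B`.
impl<A, B, C, F, G> Morphism<A, C> for Compose<F, G, B>
where
    F: Morphism<A, B>,
    G: Morphism<B, C>,
{
    fn apply(&self, input: A) -> Result<C, String> {
        let middle = self.first.apply(input)?;
        self.second.apply(middle)
    }
}

struct TokenId(usize);
struct Vector(Vec<f32>);
struct Logits(Vec<f32>);

/// Hypothetical `TokenId -> Vector`: look up one embedding row.
struct Embed {
    table: Vec<Vec<f32>>,
}

impl Morphism<TokenId, Vector> for Embed {
    fn apply(&self, token: TokenId) -> Result<Vector, String> {
        self.table
            .get(token.0)
            .cloned()
            .map(Vector)
            .ok_or_else(|| format!("token {} has no embedding row", token.0))
    }
}

/// Hypothetical `Vector -> Logits`: one dot product per vocabulary entry.
struct ToLogits {
    rows: Vec<Vec<f32>>,
}

impl Morphism<Vector, Logits> for ToLogits {
    fn apply(&self, hidden: Vector) -> Result<Logits, String> {
        let mut scores = Vec::with_capacity(self.rows.len());
        for row in &self.rows {
            if row.len() != hidden.0.len() {
                return Err(format!("row has {} weights, vector has {}", row.len(), hidden.0.len()));
            }
            scores.push(row.iter().zip(&hidden.0).map(|(w, x)| w * x).sum::<f32>());
        }
        Ok(Logits(scores))
    }
}

fn main() -> Result<(), String> {
    let embed = Embed { table: vec![vec![1.0, 0.0], vec![0.0, 1.0]] };
    let to_logits = ToLogits { rows: vec![vec![2.0, 0.0], vec![0.0, 3.0], vec![1.0, 1.0]] };
    // Type-checks because `Embed` ends at `Vector` and `ToLogits` starts there.
    let pipeline: Compose<Embed, ToLogits, Vector> = Compose::new(embed, to_logits);
    let logits: Logits = pipeline.apply(TokenId(1))?;
    println!("logits: {:?}", logits.0);
    // A `Compose<ToLogits, Embed, Logits>` has no `Morphism` impl:
    // `ToLogits` outputs `Logits`, but `Embed` expects `TokenId`.
    Ok(())
}
```

The `PhantomData<B>` field is there because Rust requires every type parameter of an impl to appear in the implementing type or the trait; naming the middle type in `Compose` satisfies that rule.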
## Checkpoint
Why is `TokenId -> Vector -> Logits` easier to debug than one giant
`predict(...)` function?
lessons/03-ml-pipeline.md
# 03 - The Tiny ML Pipeline
## Mental Model
The ML pipeline is a sequence of typed transformations:
```text
TokenSequence -> TrainingSet
TokenId -> Vector
Vector -> Logits
Logits -> Distribution
Distribution x TokenId -> Loss
```
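The first arrow, `TokenSequence -> TrainingSet`, can be read as pairing each token with the token that follows it, which is consistent with the `1 -> 2`, `2 -> 3`, `3 -> 4` pairs printed by the domain-objects example. Here is a sketch on plain `usize` tokens; `window_pairs` is a hypothetical name, and rejecting sequences shorter than two tokens is an assumed policy, not necessarily what `DatasetWindowing` does.

```rust
/// Hypothetical windowing arrow: TokenSequence -> TrainingSet.
/// Each pair is (input token, next token). Assumed policy: fewer than two
/// tokens is an error, since no (input, target) pair can be formed.
fn window_pairs(tokens: &[usize]) -> Result<Vec<(usize, usize)>, String> {
    if tokens.len() < 2 {
        return Err(format!("need at least 2 tokens, got {}", tokens.len()));
    }
    Ok(tokens.windows(2).map(|w| (w[0], w[1])).collect())
}

fn main() -> Result<(), String> {
    println!("training pairs:");
    for (input, target) in window_pairs(&[1, 2, 3, 4])? {
        println!("{input} -> {target}");
    }
    Ok(())
}
```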
## Read This File
Open `src/ml.rs`.
Read only these structs first:
- `DatasetWindowing`
- `Embedding`
- `LinearToLogits`
- `Softmax`
- `CrossEntropy`
## Run the Full Demo
```bash
cargo run --bin category_ml
```
Look at sections 2 through 5 in the output.
## The Key Detail
`Softmax` validates that its output is a real probability distribution.
`CrossEntropy` validates that the target token is inside the distribution.
Each error is raised at the boundary where the bad data can first be detected.
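A sketch of those two boundaries on plain `f32` slices: `softmax` refuses to return anything that is not a probability distribution, and `cross_entropy` rejects an out-of-range target before computing a loss. The function names, the `1e-4` sum tolerance, and the `1e-12` clamp are assumptions for illustration, not the signatures in `src/ml.rs`.

```rust
/// Hypothetical `Logits -> Distribution`, validating its own output.
fn softmax(logits: &[f32]) -> Result<Vec<f32>, String> {
    if logits.is_empty() {
        return Err("cannot normalise empty logits".to_string());
    }
    // Subtract the maximum so that `exp` cannot overflow for large logits.
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&z| (z - max).exp()).collect();
    let total: f32 = exps.iter().sum();
    let probs: Vec<f32> = exps.iter().map(|e| e / total).collect();
    let sum: f32 = probs.iter().sum();
    // Assumed tolerance: probabilities must be finite, non-negative, and sum to ~1.
    if probs.iter().all(|p| p.is_finite() && *p >= 0.0) && (sum - 1.0).abs() < 1e-4 {
        Ok(probs)
    } else {
        Err(format!("softmax produced an invalid distribution (sum = {sum})"))
    }
}

/// Hypothetical `Distribution x TokenId -> Loss`, rejecting bad targets here.
fn cross_entropy(probs: &[f32], target: usize) -> Result<f32, String> {
    let p = probs.get(target).ok_or_else(|| {
        format!("target {target} is outside a distribution of size {}", probs.len())
    })?;
    // Clamp so that a zero probability gives a large finite loss, not infinity.
    Ok(-p.max(1e-12).ln())
}

fn main() -> Result<(), String> {
    let probs = softmax(&[2.0, 1.0, 0.1])?;
    println!("probabilities: {probs:?}");
    println!("loss for target 0: {}", cross_entropy(&probs, 0)?);
    // Caught at the boundary that knows the distribution's size.
    println!("target 7: {:?}", cross_entropy(&probs, 7));
    Ok(())
}
```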
## Checkpoint
Where should an out-of-range target token be caught: inside `CrossEntropy`, or
later after loss calculation?
lessons/04-training-endomorphism.md
# 04 - Training as an Endomorphism
## Mental Model
An endomorphism maps a thing back to the same kind of thing:
```text
A -> A
```
Training does exactly that:
```text
Parameters -> Parameters
```
The model changes, but it is still a model.
## Read This File
Open `src/training.rs`.
Read in this order:
1. `TrainStep`
2. `impl Morphism<Parameters, Parameters> for TrainStep`
3. The unit test at the bottom
## Run the Example
```bash
cargo run --example 03_training_endomorphism
```
Expected pattern:
```text
loss before: ...
loss after: ...
```
The second number should be smaller.
## Why This Is Category-Theoretic
`TrainStep` can be repeated because its input type and output type are both
`Parameters`.
That makes this loop type-correct:
```text
Parameters0 -> Parameters1 -> Parameters2 -> ... -> ParametersN
```
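A minimal sketch of why the shared input and output type matters: once a step is `Parameters -> Parameters`, running N steps is an ordinary fold. The one-weight model, the squared-error loss, and plain gradient descent are hypothetical stand-ins, not the repo's `TrainStep`.

```rust
/// Hypothetical one-weight model parameters.
#[derive(Debug, Clone, Copy, PartialEq)]
struct Parameters {
    weight: f32,
}

/// Hypothetical training step: an endomorphism `Parameters -> Parameters`.
struct TrainStep {
    learning_rate: f32,
    data: Vec<(f32, f32)>, // (input, target) pairs
}

impl TrainStep {
    /// Mean squared error of the prediction `weight * x` against `y`.
    fn loss(&self, p: Parameters) -> f32 {
        let total: f32 = self.data.iter().map(|&(x, y)| (p.weight * x - y).powi(2)).sum();
        total / self.data.len() as f32
    }

    /// One gradient-descent update; the output has the same type as the input.
    fn apply(&self, p: Parameters) -> Parameters {
        let n = self.data.len() as f32;
        // d/dw of mean (w*x - y)^2 is mean 2*(w*x - y)*x.
        let grad = self.data.iter().map(|&(x, y)| 2.0 * (p.weight * x - y) * x).sum::<f32>() / n;
        Parameters { weight: p.weight - self.learning_rate * grad }
    }
}

/// Parameters0 -> Parameters1 -> ... -> ParametersN, type-correct because A -> A.
fn train(step: &TrainStep, start: Parameters, steps: usize) -> Parameters {
    (0..steps).fold(start, |p, _| step.apply(p))
}

fn main() {
    let step = TrainStep { learning_rate: 0.1, data: vec![(1.0, 2.0), (2.0, 4.0)] };
    let start = Parameters { weight: 0.0 };
    let end = train(&step, start, 20);
    println!("loss before: {:.6}", step.loss(start));
    println!("loss after: {:.6}", step.loss(end));
}
```

Here the data satisfy `y = 2x`, so each step halves the distance between the weight and 2.0, and the loss falls from 10.0 toward zero.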
## Checkpoint
Why would training be harder to compose if it returned raw vectors instead of
`Parameters`?
lessons/05-structure-and-calculus.md
# 05 - Functors, Naturality, Monoids, and Chain Rule
## Mental Model
This lesson gives names to patterns you already use:
- `Functor`: map inside a wrapper without changing the wrapper shape
- `NaturalTransformation`: convert one wrapper shape to another so that converting then mapping gives the same result as mapping then converting
- `Monoid`: an empty value plus an associative combine operation
- Chain rule: local gradients compose into larger gradients
## Read These Files
Open:
- `src/structure.rs`
- `src/calculus.rs`
## Run the Example
```bash
cargo run --example 04_structure_and_calculus
```
## What to Notice
The examples are tiny on purpose.
`VecFunctor` and `OptionFunctor` are not trying to replace real Rust APIs. They
show the shape of the idea:
```text
keep the container, transform the inside
```
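A sketch of "keep the container, transform the inside" using the standard library's `Iterator::map` and `Option::map`, together with a check of the composition law: mapping `f` and then `g` gives the same result as mapping the composed function once. `map_vec` and `map_option` are hypothetical helpers, not the repo's `VecFunctor` or `OptionFunctor`.

```rust
/// Hypothetical helper: map inside a Vec; the length (the shape) is preserved.
fn map_vec<A, B>(values: Vec<A>, f: impl Fn(A) -> B) -> Vec<B> {
    values.into_iter().map(f).collect()
}

/// Hypothetical helper: map inside an Option; `None` stays `None`.
fn map_option<A, B>(value: Option<A>, f: impl Fn(A) -> B) -> Option<B> {
    value.map(f)
}

fn main() {
    let add_one = |x: i32| x + 1;
    let double = |x: i32| x * 2;
    let xs = vec![1, 2, 3];

    // Composition law: mapping twice equals mapping the composed function once.
    let step_by_step = map_vec(map_vec(xs.clone(), add_one), double);
    let fused = map_vec(xs, |x| double(add_one(x)));
    println!("{step_by_step:?} == {fused:?}");

    println!("{:?} {:?}", map_option(Some(4), add_one), map_option(None, add_one));
}
```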
`PipelineTrace` is a monoid because:
- there is an empty trace
- traces can be combined
- grouping does not change the final trace
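Those three bullets are the monoid laws. Here is a sketch with a hypothetical `Trace` that records stage names in order: `empty` is the identity, and `combine` is concatenation, which is associative. The stage names are illustrative, and the real `PipelineTrace` may store more than names.

```rust
/// Hypothetical pipeline trace: the ordered list of stage names visited.
#[derive(Debug, Clone, PartialEq)]
struct Trace(Vec<String>);

impl Trace {
    /// Identity element: a trace with no steps.
    fn empty() -> Self {
        Trace(Vec::new())
    }

    fn single(step: &str) -> Self {
        Trace(vec![step.to_string()])
    }

    /// Associative combine: concatenate the two traces, keeping order.
    fn combine(&self, other: &Trace) -> Trace {
        let mut steps = self.0.clone();
        steps.extend(other.0.iter().cloned());
        Trace(steps)
    }
}

fn main() {
    let a = Trace::single("embed");
    let b = Trace::single("linear");
    let c = Trace::single("softmax");

    // Identity law: combining with the empty trace on either side changes nothing.
    assert_eq!(Trace::empty().combine(&a), a);
    assert_eq!(a.combine(&Trace::empty()), a);

    // Associativity: grouping does not change the final trace.
    let left = a.combine(&b).combine(&c);
    let right = a.combine(&b.combine(&c));
    assert_eq!(left, right);
    println!("{:?}", left.0);
}
```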
`MulOp` shows the smallest useful backward pass:
```text
z = x * y
dL/dx = dL/dz * y
dL/dy = dL/dz * x
```
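A sketch of that local rule as code, extended with one downstream operation so the chain rule is visible. It assumes a loss `L = z^2`, so `dL/dz = 2z`, and checks the analytic gradients against central finite differences. `MulOp` here is a hypothetical stand-in for the one in `src/calculus.rs`.

```rust
/// Hypothetical multiply node that remembers its inputs for the backward pass.
struct MulOp {
    x: f64,
    y: f64,
}

impl MulOp {
    fn forward(x: f64, y: f64) -> (MulOp, f64) {
        (MulOp { x, y }, x * y)
    }

    /// Local rule: given dL/dz, return (dL/dx, dL/dy) = (dL/dz * y, dL/dz * x).
    fn backward(&self, grad_z: f64) -> (f64, f64) {
        (grad_z * self.y, grad_z * self.x)
    }
}

/// Assumed downstream loss L = z^2, whose local derivative is dL/dz = 2z.
fn loss(z: f64) -> f64 {
    z * z
}

fn main() {
    let (x, y) = (3.0, -2.0);
    let (node, z) = MulOp::forward(x, y);
    // Chain rule: feed the loss's local gradient into the multiply's local rule.
    let (grad_x, grad_y) = node.backward(2.0 * z);

    // Central finite differences as an independent check.
    let h = 1e-5;
    let numeric_x = (loss((x + h) * y) - loss((x - h) * y)) / (2.0 * h);
    let numeric_y = (loss(x * (y + h)) - loss(x * (y - h))) / (2.0 * h);
    println!("dL/dx: analytic {grad_x}, numeric {numeric_x:.6}");
    println!("dL/dy: analytic {grad_y}, numeric {numeric_y:.6}");
}
```

For `x = 3` and `y = -2`, `z = -6` and `dL/dz = -12`, so the rule gives `dL/dx = 24` and `dL/dy = -36`; the finite differences agree up to rounding.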
## Checkpoint
Why is "local rule plus composition" the core idea behind backpropagation?
lessons/06-seven-sketches.md
# 06 - Seven Sketches Through Rust
## Mental Model
Applied category theory is useful when it helps you model real structure:
orders, resources, database references, feasibility relations, signal flow,
open systems, and local-to-global checks.
This lesson uses one method:
```text
engineering problem
-> named Rust values
-> validated construction
-> relation or composition
-> law or boundary test
```
## Read This File
Open `src/sketches.rs`.
Read the tests first. They show the contract each small model is supposed to
protect.
Focus on these examples:
- `InformationLevel::can_flow_to`
- `ResourceBundle::tensor`
- `CompanyInstance::department_for`
- `SignalMatrix::compose_after`
- `OpenCircuit::then`
- `SafetyCover::global_truth`
## Run the Example
```bash
cargo run --example 05_seven_sketches
```
Expected shape:
```text
orders obey preorder laws: true
feature/layer Galois law: true
resource tensor monotone: true
co-design offer feasible: true
global behavior truth: True
```
The exact debug formatting is less important than the contracts. Each output
line is a small executable handle for a larger applied-category idea.
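The first expected line refers to the two preorder laws, reflexivity and transitivity. Here is a sketch of how such a check can be written, using a hypothetical three-level `Level` enum in which information may flow only to an equally or more restricted level. The level names and that flow direction are assumptions, not the definition of `InformationLevel` in `src/sketches.rs`.

```rust
/// Hypothetical information levels, from least to most restricted.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Level {
    Public,
    Internal,
    Secret,
}

const ALL: [Level; 3] = [Level::Public, Level::Internal, Level::Secret];

impl Level {
    fn rank(self) -> u8 {
        match self {
            Level::Public => 0,
            Level::Internal => 1,
            Level::Secret => 2,
        }
    }

    /// Assumed rule: data may flow only to an equally or more restricted level.
    fn can_flow_to(self, other: Level) -> bool {
        self.rank() <= other.rank()
    }
}

/// Exhaustively checks reflexivity and transitivity over every level.
fn obeys_preorder_laws() -> bool {
    let reflexive = ALL.iter().all(|&a| a.can_flow_to(a));
    let transitive = ALL.iter().all(|&a| {
        ALL.iter().all(|&b| {
            ALL.iter()
                .all(|&c| !(a.can_flow_to(b) && b.can_flow_to(c)) || a.can_flow_to(c))
        })
    });
    reflexive && transitive
}

fn main() {
    println!("orders obey preorder laws: {}", obeys_preorder_laws());
    println!("secret -> public allowed: {}", Level::Secret.can_flow_to(Level::Public));
}
```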
## What to Notice
The negative tests are as important as the positive examples.
They show that the code rejects invalid structure:
- a missing database reference,
- mismatched signal-matrix dimensions,
- an open-circuit boundary mismatch.
That is the same discipline as the tiny ML pipeline. A type or constructor is
valuable when it prevents a wrong connection from becoming ordinary runtime
state.
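Here is a sketch of the mismatched-dimension rejection for signal matrices: `second.compose_after(&first)` is defined only when `first` produces as many signals as `second` consumes, so the method returns an error instead of a malformed matrix. `Matrix`, its row-major layout, and the error messages are hypothetical stand-ins for `SignalMatrix`.

```rust
/// Hypothetical matrix: `rows` outputs, `cols` inputs, row-major entries.
#[derive(Debug, Clone, PartialEq)]
struct Matrix {
    rows: usize,
    cols: usize,
    entries: Vec<f64>,
}

impl Matrix {
    /// Validated construction: the entry count must match the shape.
    fn new(rows: usize, cols: usize, entries: Vec<f64>) -> Result<Self, String> {
        if entries.len() != rows * cols {
            return Err(format!("expected {} entries, got {}", rows * cols, entries.len()));
        }
        Ok(Matrix { rows, cols, entries })
    }

    fn get(&self, r: usize, c: usize) -> f64 {
        self.entries[r * self.cols + c]
    }

    /// Run `first`, then `self`. Requires first's outputs to equal self's inputs.
    fn compose_after(&self, first: &Matrix) -> Result<Matrix, String> {
        if first.rows != self.cols {
            return Err(format!(
                "cannot compose: first produces {} signals, next expects {}",
                first.rows, self.cols
            ));
        }
        let mut entries: Vec<f64> = Vec::with_capacity(self.rows * first.cols);
        for r in 0..self.rows {
            for c in 0..first.cols {
                entries.push((0..self.cols).map(|k| self.get(r, k) * first.get(k, c)).sum::<f64>());
            }
        }
        Matrix::new(self.rows, first.cols, entries)
    }
}

fn main() -> Result<(), String> {
    let fan_out = Matrix::new(2, 1, vec![1.0, 1.0])?; // 1 input -> 2 outputs
    let average = Matrix::new(1, 2, vec![0.5, 0.5])?; // 2 inputs -> 1 output
    println!("{:?}", average.compose_after(&fan_out)?);
    // Rejected: `fan_out` produces 2 signals but accepts only 1.
    println!("{:?}", fan_out.compose_after(&fan_out));
    Ok(())
}
```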
## Checkpoint
Pick one sketch and answer:
```text
What invalid composition or relationship does this model prevent?
```
A strong answer names the Rust type, the software problem, and the
category-theory shape.