7.9s
Sorry, the lecture slides for this lecture were not up earlier today, but they are up now. So, if you want to have them at your fingertips, you should be able to get them through the course website. If not, they're definitely up there on the
24.3s
GitHub. Today, we're going to talk about what I think of as more advanced architecture ideas. Last lecture was really about the basic transformer and how we might tweak different parts of the basic transformer to get to a modern language model.
41.5s
Now, I want to talk about things that I think are much more advanced, much more complex developments on top of the usual transformer. There are two things I'll talk about today. The first one is attention alternatives: ways of going to much longer contexts using architectural changes
60.1s
that, for the most part, allow for linear time dependence on the length of the sequence rather than quadratic. The second thing I want to talk about is the idea of mixture of experts, right? So, the first part is modifying the attention block. The
73.8s
second part is going to be modifying the MLP part, right? And mixture of experts is going to give us significantly better utilization of our hardware, in terms of having more parameters relative to the compute that we have to spend on these models. So, those are the two
89.2s
parts and the two ideas for today. So, I'm going to start with attention and set the stage. It's clear now that people want longer context lengths, right? You want to pack a lot of things into the context so that your model has more knowledge; maybe it's an agent that operates on lots of
106.2s
things. If you look at different models over time on the x-axis over here, and the context window size, which is in log scale, there's really a clear rush by a lot of the top LM vendors to provide larger and larger context sizes to support these kinds of more
123.5s
complex workloads. Now, how do we control these costs? Well... all right, oh, not yet. The right plot also shows the ratio between how much compute cost is spent in the feed-forward part of the network versus the attention part of the network as you increase the sequence
140.7s
lengths. And you see that the feed-forward cost is fairly large at the start but grows linearly. Attention is an all-to-all connection between all the different positions, so it's quadratic, and it quickly outpaces feed-forward as the sequence length grows, right? So, it used to be that for big models at smallish
158.9s
sequence lengths, feed-forward was the dominant cost. As you go to these longer sequence lengths, attention increasingly becomes more and more of a problem, right?
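The scaling argument above can be sketched with a rough FLOP count for one layer. This is only a back-of-the-envelope illustration: the dimensions (d_model = 4096, d_ff = 16384) are assumed values, not numbers from the lecture, and the count ignores the Q/K/V/output projections, softmax, and multi-head bookkeeping.

```python
def ffn_flops(n_tokens, d_model, d_ff):
    # Two matmuls (up and down projection), 2 FLOPs per multiply-add.
    return 2 * (2 * n_tokens * d_model * d_ff)


def attention_mixing_flops(n_tokens, d_model):
    # Q K^T and (weights) V each cost about 2 * N^2 * d_model FLOPs.
    return 2 * (2 * n_tokens * n_tokens * d_model)


# Assumed, illustrative model dimensions.
d_model, d_ff = 4096, 16384

for n in (1_024, 16_384, 262_144):
    ratio = attention_mixing_flops(n, d_model) / ffn_flops(n, d_model, d_ff)
    print(f"N={n:>7}: attention/FFN FLOP ratio = {ratio:.2f}")
```

Under these assumptions the ratio simplifies to N / d_ff, so the feed-forward part dominates while the sequence is shorter than d_ff and attention dominates once it is longer.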
169.4s
Now, what's the basic toolkit by which we can control these costs? We've already talked about these, or, well, we
174.9s
talked about one of these and we will talk about another one of these. We can basically take things like local attention and combine them in different kinds of hybrid ways to control the very high cost of global attention, right? If you're only doing global attention once every eight
192.7s
layers and all the other ones are these very local attentions, you've very much controlled the cost.
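As a rough sketch of why a local/global hybrid controls cost, the snippet below counts query-key pairs scored across a layer stack. The numbers (32 layers, a 4,096-token window, a 131,072-token context) are hypothetical and chosen only for illustration; causal masking and constant factors are ignored.

```python
def mixing_cost(n_tokens, window=None):
    # Number of query-key pairs scored in one layer.
    if window is None:
        return n_tokens * n_tokens            # global: all-to-all
    return n_tokens * min(window, n_tokens)   # local: fixed window per position


def stack_cost(n_tokens, n_layers, global_every, window):
    # Every `global_every`-th layer is global; the rest use local attention.
    total = 0
    for layer in range(n_layers):
        is_global = layer % global_every == global_every - 1
        total += mixing_cost(n_tokens, None if is_global else window)
    return total


n, layers = 131_072, 32                        # hypothetical sizes
full = stack_cost(n, layers, global_every=1, window=4_096)
hybrid = stack_cost(n, layers, global_every=8, window=4_096)
print(f"hybrid / full = {hybrid / full:.3f}")
```

With these assumed sizes the hybrid scores about 15% of the pairs that an all-global stack would. The remaining global layers are still quadratic, so the saving is a constant factor that cannot drop below roughly one eighth here.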
198.8s
The other thing you can do is systems engineering. I think one very underappreciated fact is that constant factors really, really matter, right? Part of the theme of this course is that you need to pay
210.0s
attention to the details, and it's very easy for people trained in a more classical, theory-oriented computer science tradition to be like, "Oh, yes, it's the big O that matters, right? Linear or quadratic." But I think one of the biggest things that has happened to contexts and attention cost is flash
228.2s
attention, right? Flash attention, we'll talk about in a lot more detail once we get to the systems lecture and as you do the systems assignments, but flash attention is really just a very clever way of rearranging the attention operation into a much more systems-friendly form to minimize memory transfer overhead.
247.6s
And by doing so, you can get truly, truly dramatic improvements in the performance of your attention. So, base PyTorch is in blue; if you look at some of the shorter sequence lengths, you're doing like 30 to 40 teraflops per second. And then
264.8s
you do flash attention, which is really a constant-factor systems improvement, and you're getting these big, big improvements, like factor-of-two improvements. And at some sizes, you can't even fit this in memory anymore, but with flash attention, where you're not materializing these big attention matrices, you can continue to run these, albeit somewhat
281.4s
slowly, right? This doesn't fix any of the quadratic cost issues, but constant factors are very, very powerful. Now, that said, if we're going to five or ten million tokens, these tricks might not be enough. We want much more radical, much larger gains in our ability to handle long contexts. So,
301.6s
this leads us to the question: are there ways for us to have linear time, or linear dependence on the length of the sequence? And what would those kinds of things look like? Okay, there we go. It turns out that a lot of
318.0s
people have tried a lot of different things over the years. I think there were a few false starts, but in recent years, I would say the last two-ish years, there's been an emergence of a set of recipes that are fairly effective for linear time attention and that are now kind of
334.4s
battle-tested at scale. And that's why this is finally the first year when I'm talking about linear time attention, because it's clear now that these things really work at scale and in production and so on. To really explain all of these different
350.6s
methods that work well, I think you really only need to understand one core idea to start with, and then we're going to elaborate and build upon that idea. And that idea is the associativity of multiplication. So, let's think
363.8s
about how attention works. We can write it in this compact notation where we say the attention operation is going to have Qs, Ks, and Vs, right, with these matrix dimensions up here. And I'm going to take my Qs and the Ks and I'm going to have an
377.6s
all-to-all interaction, right? All the N positions are going to interact with each other by matrix multiplication. Then I'll normalize them through a softmax, this rho over here, and then I will multiply it with my values V to get the attention outputs, right? This is a classic self-attention or whatever other
395.3s
attention operation. Now, for the moment, let's forget that attention has a softmax, and I'm just going to drop the rho, like forget it for the moment, right? Now, if we didn't have a rho, maybe we
409.6s
can actually change the complexity of this operation. So, if you look at attention, it's N^2 d_k, right? Because we're taking this QK^T multiplication. But if rho were the identity, I can just move the parentheses around a little bit and I can say, all right, instead of
424.5s
(QK^T)V, what I'm going to have is Q(K^T V), right? And if we do this reordering, well, we've changed which part is quadratic, right? Instead of having the N^2 terms, which were pretty terrible because these Ns are very big, they're millions,
440.4s
right? That's the context length. Instead, I can have dependence on N d_v d_k, right? So, now we're essentially multiplying the key and the value dimensions. These could be big in principle, right? I can have a big model, but these are usually on the order of thousands, tens of
456.0s
thousands, right? No one has like a million coordinates in their hidden dimensions. And so, this is a much more favorable term to be dependent on compared to this N^2 term when we were doing this original attention operation,
473.5s
right? So, this is the first thing that you should know: associativity of multiplication allows us to change the dependence on the sequence length.
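A minimal sketch of the reordering, in plain Python so it runs without extra libraries. The sizes (n = 64, d_k = 4, d_v = 3) and the helper names are assumptions made for illustration; the point is only that (QK^T)V and Q(K^T V) give the same result while building very different intermediate matrices.

```python
import random


def matmul(a, b):
    # Plain triple-loop matrix product for lists of lists.
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def transpose(a):
    return [list(row) for row in zip(*a)]


def rand_matrix(rows, cols, rng):
    return [[rng.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
n, d_k, d_v = 64, 4, 3                          # illustrative sizes
Q = rand_matrix(n, d_k, rng)
K = rand_matrix(n, d_k, rng)
V = rand_matrix(n, d_v, rng)

quadratic = matmul(matmul(Q, transpose(K)), V)  # (Q K^T) V: n x n intermediate
linear = matmul(Q, matmul(transpose(K), V))     # Q (K^T V): d_k x d_v intermediate

max_diff = max(abs(x - y)
               for r1, r2 in zip(quadratic, linear)
               for x, y in zip(r1, r2))

# Multiply-add counts for the two orderings.
cost_quadratic = n * n * d_k + n * n * d_v
cost_linear = d_k * n * d_v + n * d_k * d_v
print(max_diff, cost_quadratic, cost_linear)
```

The two outputs agree up to floating-point rounding, while the second ordering's cost grows linearly in n. This only works because the softmax (rho) has been dropped; with the softmax in place the parentheses cannot be moved.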
484.4s
Okay. Now, the other thing that we can do is, if we consider this reordering, this Q(K^T V),
494.5s
something that's even nicer about this, and I think this has led to a lot of the research that I'm going to mention in the next couple of slides, is that this matrix multiply on the right, in the self-attention case, looks just like an
507.9s
RNN, right? Instead of writing it in this dense form up here, I can write it incrementally: as I sweep left to right over my context, I can multiply my Ks and Vs, right? And I can update this thing S, which is a state, by adding that to my previous state. This is
523.7s
kind of like incrementally multiplying and accumulating my KVs, and I can multiply it with my Q to get my output, right? And this dense operation up here and this RNN form are equivalent.
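The equivalence can be checked numerically with a small sketch. It assumes the causal (left-to-right) case, so the dense form only scores pairs with s <= t, and it writes the recurrence as S_t = S_{t-1} + k_t v_t^T with output o_t = q_t^T S_t. The function names and sizes are made up for illustration.

```python
import random


def causal_linear_attention_parallel(Q, K, V):
    # Dense form: score every (t, s) pair with s <= t, no softmax.
    n, d_v = len(Q), len(V[0])
    out = []
    for t in range(n):
        row = [0.0] * d_v
        for s in range(t + 1):
            score = sum(q * k for q, k in zip(Q[t], K[s]))
            for j in range(d_v):
                row[j] += score * V[s][j]
        out.append(row)
    return out


def causal_linear_attention_recurrent(Q, K, V):
    # RNN form: S_t = S_{t-1} + k_t v_t^T, then o_t = q_t^T S_t.
    d_k, d_v = len(K[0]), len(V[0])
    S = [[0.0] * d_v for _ in range(d_k)]   # fixed-size state, independent of n
    out = []
    for q, k, v in zip(Q, K, V):
        for i in range(d_k):
            for j in range(d_v):
                S[i][j] += k[i] * v[j]
        out.append([sum(q[i] * S[i][j] for i in range(d_k)) for j in range(d_v)])
    return out


rng = random.Random(0)
n, d_k, d_v = 16, 4, 3                      # illustrative sizes
Q = [[rng.uniform(-1, 1) for _ in range(d_k)] for _ in range(n)]
K = [[rng.uniform(-1, 1) for _ in range(d_k)] for _ in range(n)]
V = [[rng.uniform(-1, 1) for _ in range(d_v)] for _ in range(n)]

dense = causal_linear_attention_parallel(Q, K, V)
recurrent = causal_linear_attention_recurrent(Q, K, V)
max_diff = max(abs(a - b) for r1, r2 in zip(dense, recurrent) for a, b in zip(r1, r2))
print(max_diff)
```

The recurrent version keeps only a d_k x d_v state no matter how long the sequence is. The dense version here is written with explicit loops for clarity; in practice it would be a batched, masked matrix multiply.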
535.0s
And so, what have we done? Well, we've started from linear attention on the top
541.3s
left and then we've somehow gotten something that looks very much like an RNN. And what's nice about RNNs? Well, RNNs are quite nice because at inference time I have this S, which is my state; it's fixed in size and I'm just carrying that forward and forward, right? So, RNNs are very nice
557.5s
for inference reasons; they are not nice for training reasons. But the nice thing about linear attention is you can use it either way, right? You can have it in this dense form, which will be great for training because it's parallel, or you can have it in this serial form like an RNN, which
573.0s
is great for inference, right? So, you can get the best of both worlds. Now, at least in terms of the dependence of all these objects, this is great, this is lovely. Unfortunately, this is linear, which is not very good. This is really sort of the
587.6s
starting point of our discussion. Any questions so far? This is really the basics. Okay. One cool thing is that there have been people that have basically used this linear attention as part of their attention mechanism. So,
607.2s
MiniMax M1, which is a fairly large-scale, fairly high-performance Chinese open model, uses a 7:1 hybrid: seven linear attention layers plus one full softmax attention layer. Performance is generally strong. You can compare it to a number of strong models like OpenAI's o3 or DeepSeek
629.2s
R1, and it's fairly competitive. And then, since it's linear attention, most of the dependence on context length is linear. It's not fully linear because you've got these softmax layers, but the dependence is much, much milder, right? So, this is one example where
646.1s
people have shown that even this basic linear attention, which is very, very simple to understand, can be used fairly effectively at scale, as long as you've got some sort of full quadratic connections involved. No one has thus far really proven out fully linear time attention mechanisms at scale.
665.4s
Everything that I'm going to talk about in the next couple of slides is a hybrid. Okay. So, now we've done linear attention, but you might say: linear attention's too naive and simple for me. This is not what I want. I want some more serious, complex
682.0s
neural network stuff to be happening. Well, what you can now do is... oh, sorry, went a little too far. Look at the RNN form of linear attention. I've put the linear attention equation right here just so you don't have to keep going
698.0s
back and forth. We can add a small elaboration to linear attention in order to get more expressive updates, and this is how you end up getting Mamba 2. Some of you may have heard of Mamba, Mamba 2, Mamba 3 now, I think. These are a family of state space models
716.7s
by Albert Gu and Tri Dao and friends. They're derived originally from this kind of state space theory and view, but if you look at the mechanics of what is being done, you can actually see it as a very simple elaboration of the linear attention
733.5s
mechanism. The idea is the following, right? We start with linear attention. And you might say, really, the main problem with linear attention is that I'm always passing my state forward, right? We know from the olden days of LSTMs that it's important to know when to pass information forward and
750.8s
when to just not pass information forward, right? To forget things and send them to zero. So, I'm going to be inspired by that, and I'm going to add a gate gamma, right? gamma_t. gamma_t is not stateful. gamma_t only depends on my
764.1s
current input x_t. And gamma_t is going to modulate how much of my state I'm going to carry forward into the future. Now, if you think about this, the state-dependent terms are all these S's. This gamma is not state-dependent at all. So, this means
781.3s
that this is very simple to compute, and there's still this duality. You can compute this Mamba 2 update either as a big dense matrix multiply or, at inference time, use it in this recurrent form to get nice inference improvements.
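A simplified sketch of the gated recurrence described here, assuming a scalar gate per step and the update S_t = gamma_t S_{t-1} + k_t v_t^T with output o_t = q_t^T S_t. This is not the actual Mamba 2 parameterization; the sigmoid gate, the weight vector `w_gate`, and all sizes are hypothetical choices made to show that an input-only gate keeps the parallel and recurrent forms equal.

```python
import math
import random


def gate(x, w):
    # Input-dependent scalar gate in (0, 1); it never looks at the state S.
    return 1.0 / (1.0 + math.exp(-sum(xi * wi for xi, wi in zip(x, w))))


def gated_recurrent(Q, K, V, gammas):
    # Serial form: decay the old state, then add the new key-value outer product.
    d_k, d_v = len(K[0]), len(V[0])
    S = [[0.0] * d_v for _ in range(d_k)]
    out = []
    for q, k, v, g in zip(Q, K, V, gammas):
        for i in range(d_k):
            for j in range(d_v):
                S[i][j] = g * S[i][j] + k[i] * v[j]
        out.append([sum(q[i] * S[i][j] for i in range(d_k)) for j in range(d_v)])
    return out


def gated_parallel(Q, K, V, gammas):
    # Dense form: pair (t, s) is weighted by the product of gates after step s.
    n, d_v = len(Q), len(V[0])
    out = []
    for t in range(n):
        row = [0.0] * d_v
        for s in range(t + 1):
            decay = math.prod(gammas[s + 1:t + 1])   # empty product is 1 when s == t
            score = decay * sum(a * b for a, b in zip(Q[t], K[s]))
            for j in range(d_v):
                row[j] += score * V[s][j]
        out.append(row)
    return out


rng = random.Random(0)
n, d_in, d_k, d_v = 12, 5, 4, 3             # illustrative sizes
X = [[rng.uniform(-1, 1) for _ in range(d_in)] for _ in range(n)]
w_gate = [rng.uniform(-1, 1) for _ in range(d_in)]   # hypothetical gate weights
gammas = [gate(x, w_gate) for x in X]
Q = [[rng.uniform(-1, 1) for _ in range(d_k)] for _ in range(n)]
K = [[rng.uniform(-1, 1) for _ in range(d_k)] for _ in range(n)]
V = [[rng.uniform(-1, 1) for _ in range(d_v)] for _ in range(n)]

max_diff = max(abs(a - b)
               for r1, r2 in zip(gated_recurrent(Q, K, V, gammas),
                                 gated_parallel(Q, K, V, gammas))
               for a, b in zip(r1, r2))
print(max_diff)
```

Because every gate is known from the inputs before any state is computed, the decay weights in the dense form can be precomputed for all pairs at once. If the gate depended on S, that precomputation would not be possible.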
797.5s
So, that's the core sort
799.1s
of conceit or view of this idea: you can use them either in parallel or in sequence, just like linear attention, and people have now used this at scale. So, Nemotron 3, which I mentioned briefly last lecture, uses this idea of combining Mamba 2 as their lightweight layer, and
817.8s
of course, you've got your big softmax attention every now and then, and they alternate these in various ways to manage their trade-off between inference cost and expressiveness. Nemotron 3 achieves pretty good performance compared to Qwen 3 Thinking and compared to
836.7s
GPT-OSS, and because of all these Mamba 2 layers, it has pretty good throughput at fairly large context lengths. So, you can see that it kind of works. At least, these are small frontier models, not really big frontier models, but small frontier models that are open
853.9s
source. Cool. So, now we've built on this gating idea, which is cool. You might say, though, can we just keep pushing this idea further, taking this recurrence and making it increasingly more complex?
871.5s
The rule of thumb, in some sense, for what you can do with your operations here is this: as long as you're gating the various terms in your RNN, so to speak, with only input-dependent terms, so no state dependence, then you'll still have this fairly nice duality between parallel
890.1s
operations, which you can use for training, and serial operations, which you are going to use for inference. Gated DeltaNet is probably among the most widely used state space models now, I would say. I think it's been tested in quite a few papers, and there's a very nice
911.1s
large-scale scale-up of these models in Qwen 3, sorry, Qwen 3.5. So let me talk you through what Gated DeltaNet is in comparison to Mamba 2. For Mamba 2, as I have at the top here, we take the linear recurrence, this state update. And
930.4s
what I'm going to do is gate it with my gamma, right? Oh, I forgot to explain this v_t; I apologize for that. I added it for completeness. Mamba 2 also has a modification where they add what is basically a residual layer to the y_t update. It's not really core to
945.4s
the state update or their arguments for why this is a good idea, but I put it in for completeness, just to make sure I'm actually giving the true Mamba 2 updates rather than a more abridged version that only gets the core ideas across. Okay, but you can ignore the v_t for
960.9s
now. This is really just an architectural improvement that's not core to the state updates. Gated DeltaNet takes this idea, in some sense, and adds a second gate, beta_t. So what does beta_t do? I think of beta_t as a no-input-operation gate, right? If beta_t is
980.0s
zero, that basically means: don't take any of my current information, don't add it into my state. I think most or all of you are probably familiar with LSTMs, right? This is very reminiscent of the ideas in the LSTM. You've got gates that control whether to forget and whether to put in
998.2s
information from the current time step. So they operate on a very similar set of principles, although, of course, they were derived, I think, originally in a very different way. Now, the other thing that's interesting about Gated DeltaNet, and this comes from DeltaNet rather than from the gated part, is this update
1015.8s
direction. Instead of just gating things (you could have made this 1 minus beta_t, for example), what they do is this: if you're going to do an update where you're putting in this beta, you're going to try to project out, in some sense, my current key
1029.1s
direction. The intuition here is that I'm going to be writing in information from my current key, k_t. And when doing that, not only do I want to put in new information, I also want to erase any previous keys that have gone into it, right? That's one way you can think about updates: not
1045.0s
just put stuff in; I also want to clear out what has already gone in for the current key. And that's why there's this blue term, this identity minus beta_t k_t k_t transpose, which is essentially acting as a projector that projects out things in the k_t direction. That's not exactly right, because you're not doing things like unit
1061.1s
normalization, but you get roughly the intuition here.
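To make the erase-then-write intuition concrete, here is a minimal sketch in dependency-free Python. It assumes the state S is a d_v-by-d_k matrix updated as S_t = gamma_t * S_{t-1} (I - beta_t k_t k_t^T) + beta_t v_t k_t^T and read out as S_t q_t. The function names and toy numbers are made up for illustration, and the layout on the slide may differ (for example, a transposed state).

```python
def gated_delta_step(S, k, v, gamma, beta):
    """Return gamma * S (I - beta k k^T) + beta * v k^T.

    S: d_v x d_k state (list of rows); k: key (length d_k); v: value (length d_v).
    gamma and beta are scalars in [0, 1]. In a real model they would be computed
    from the current input only, never from the state.
    """
    d_v, d_k = len(S), len(k)
    # What the state currently returns when queried with k.
    Sk = [sum(S[i][j] * k[j] for j in range(d_k)) for i in range(d_v)]
    # S (I - beta k k^T) equals S - beta (S k) k^T: erase along k, decay, then write.
    return [
        [gamma * (S[i][j] - beta * Sk[i] * k[j]) + beta * v[i] * k[j] for j in range(d_k)]
        for i in range(d_v)
    ]


def read(S, q):
    """Query the state: returns S q."""
    return [sum(s_ij * q_j for s_ij, q_j in zip(row, q)) for row in S]


def additive_step(S, k, v):
    """Plain linear-attention write for comparison: S + v k^T (no erase, no gates)."""
    return [[S[i][j] + v[i] * k[j] for j in range(len(k))] for i in range(len(S))]


zeros = [[0.0, 0.0], [0.0, 0.0]]
key = [1.0, 0.0]  # unit-norm key, so the erase term is an exact projection
old, new = [1.0, 2.0], [5.0, -1.0]

S = gated_delta_step(zeros, key, old, gamma=1.0, beta=1.0)
S = gated_delta_step(S, key, new, gamma=1.0, beta=1.0)
print(read(S, key))  # the old value was erased before the new one was written

A = additive_step(additive_step(zeros, key, old), key, new)
print(read(A, key))  # without the erase term, the two values pile up

S_skip = gated_delta_step(S, key, [9.0, 9.0], gamma=1.0, beta=0.0)
print(S_skip == S)  # beta = 0 leaves the state untouched (gamma is 1 here)
```

With a key that is not unit-norm, the erase is only approximate, which is the caveat about normalization mentioned above.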
1066.5s
The other thing I'll mention before I move on from Gated DeltaNet (oops, sorry, I have this blue box here) is that this has been reinvented in some
1075.4s
ways in various settings. This update, this projector, appears if you try to solve certain kinds of meta-learning least-squares problems. So it's been reinvented in other forms in fast weight programming or test-time training, where, essentially through very different design principles, researchers from those areas have basically ended up with
1096.8s
the exact same kinds of solutions. The thing I'll say here is that Qwen 3.5, and the Qwen Next models that were predecessors to it, are some of the best open-source models available today, and they use exactly this architecture. They use
1113.4s
a 3:1 Gated DeltaNet-to-attention hybrid, with very good, very reasonable performance and strong inference characteristics that you can see on the right there: Qwen Next has much higher decoding throughput relative to Qwen 3 as the context length goes up and up. The middle panel is a comparison of
1131.9s
performance versus various closed-source models and previous-generation models; this hybrid architecture doesn't seem to hurt performance very much at all. The last thing I will say is that there haven't been that many great controlled studies of how hybrid architectures perform. There's maybe
1153.8s
one by ByteDance Seed and UC Santa Cruz folks, where they compare things like Mamba 2, Gated DeltaNet, and other architectures as a function of how many hybrid layers you have. Some of the results, I would say, are kind of messy, but the thing that I
1169.0s
would point at is this dashed line over here on the right panel. This is full-attention performance, and as you go from left to right, you're increasing the number of non-full-attention layers, the RNN layers, let's call them. You're increasing them as you go
1184.8s
to the right, and you see performance degradation. But maybe the thing to see here is that for some of the best architectures, like the yellow, the orange, and the blue, which are the Gated DeltaNets and various other
1196.6s
variants, you see that at the low ratios there's basically no hit. Then, once you pass a certain point, you start to see more significant degradation in performance, in long-context performance in this case. At the end, as you go to full RNN, you have very
1213.3s
noticeable performance degradation in all these architectures. The same is true across various other kinds of evaluations. Some of these, I think, have been optimized for by the architectures themselves; single-key retrieval is a task that I think all of these long-context architectures explicitly optimize for. But if you look
1230.8s
at QA performance, we see the same story: as we increase the hybrid ratio, performance decreases fairly steadily and fairly clearly, ending up at this pure point.
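The hybrid ratio can be made concrete with a toy layer schedule. This is only a sketch: the function names are hypothetical, and placing the full-attention layer at the end of each repeating block is an assumption for illustration, not something taken from the models or the study discussed here.

```python
def hybrid_schedule(n_layers, linear_per_attention):
    """Repeat `linear_per_attention` linear/RNN layers followed by one full-attention layer."""
    period = linear_per_attention + 1
    return ["attention" if (i + 1) % period == 0 else "linear" for i in range(n_layers)]


def attention_fraction(schedule):
    """Fraction of layers that are full attention, i.e. that keep a growing KV cache."""
    return schedule.count("attention") / len(schedule)


layers = hybrid_schedule(8, 3)  # a 3:1 hybrid over 8 layers
print(layers)
print(attention_fraction(layers))

print(hybrid_schedule(4, 0))  # ratio 0 recovers a pure full-attention stack
```

Pushing `linear_per_attention` up moves toward the pure-RNN end of the plot, where the attention fraction, and with it the long-context cost, shrinks.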
1241.6s
Okay. Any questions on the linear-time attention stuff? I have one last efficient attention thing to talk
1247.6s
about, but yes? In the beginning, you talked about the parallel form of attention versus the linear form. Not at all. Shouldn't they be approximately equivalent, with a slight degradation in performance?
1270.1s
I meant RNN in general, any recurrent formulation. So Gated DeltaNet would be a recurrence. That is not equivalent to this top one. This is full softmax attention with the rho, right? The first step is where we drop the rho, and then we become linear. That's the very, very first step
1285.2s
of any of these. So this part is going to be lossy, and then after that, going from this linear form to this recurrent form, that equivalence is exact.
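The exactness of that last step can be checked numerically. The sketch below assumes the simplest version of linear attention (no nonlinearity, no normalizer, causal masking); the function names are illustrative. The first function is the all-pairs view you would parallelize for training, written with plain loops here, and the second carries a fixed-size state the way you would at inference.

```python
def linear_attention_all_pairs(Q, K, V):
    """Causal linear attention, training view: o_t = sum over s <= t of (q_t . k_s) v_s."""
    out = []
    for t, q in enumerate(Q):
        o = [0.0] * len(V[0])
        for s in range(t + 1):
            score = sum(a * b for a, b in zip(q, K[s]))
            o = [o_j + score * v_j for o_j, v_j in zip(o, V[s])]
        out.append(o)
    return out


def linear_attention_recurrent(Q, K, V):
    """Same outputs, inference view: S_t = S_{t-1} + k_t v_t^T, then o_t = S_t^T q_t."""
    d_k, d_v = len(K[0]), len(V[0])
    S = [[0.0] * d_v for _ in range(d_k)]
    out = []
    for q, k, v in zip(Q, K, V):
        for i in range(d_k):
            for j in range(d_v):
                S[i][j] += k[i] * v[j]
        out.append([sum(q[i] * S[i][j] for i in range(d_k)) for j in range(d_v)])
    return out


Q = [[1.0, 0.0], [0.5, -1.0], [2.0, 1.0]]
K = [[1.0, 2.0], [0.0, 1.0], [-1.0, 0.5]]
V = [[1.0, 0.0], [2.0, 1.0], [3.0, -1.0]]

a = linear_attention_all_pairs(Q, K, V)
b = linear_attention_recurrent(Q, K, V)
gap = max(abs(x - y) for row_a, row_b in zip(a, b) for x, y in zip(row_a, row_b))
print(gap)  # the two views agree up to floating-point rounding
```

The lossy part is only the earlier step of dropping the softmax nonlinearity; nothing in this sketch approximates softmax attention.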
1296.0s
Good. Yes? What's the ET, the actual ET, like the attention term below? Oh, yes, yes, yes.
1307.1s
Yeah, this term, right? So in the linear attention form, what we're really doing is taking these KVs and accumulating them with Qs, right? In the dual form, that's kind of how this operates,
1321.1s
but you might actually want your outputs to also have some value from your current time step. You might think: oh, the value for my current time step, for my token, is actually useful for various things downstream. And so this acts a little bit like a residual connection, where the value of my
1334.8s
current inputs directly passes through to the outputs of my attention. And D is another modulated gate, where you can control the amount of pass-through that happens.
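Here is a small sketch of a gated recurrence with that pass-through term, under assumed notation: S_t = gamma_t * S_{t-1} + k_t v_t^T and y_t = S_t^T q_t + D * v_t. The gammas are passed in as plain numbers, whereas a real model would compute them from the input, and D is a single scalar here for simplicity. It illustrates the idea and is not the exact Mamba 2 parameterization.

```python
def gated_recurrence_with_skip(Q, K, V, gammas, D):
    """S_t = gamma_t * S_{t-1} + k_t v_t^T ;  y_t = S_t^T q_t + D * v_t."""
    d_k, d_v = len(K[0]), len(V[0])
    S = [[0.0] * d_v for _ in range(d_k)]
    ys = []
    for q, k, v, gamma in zip(Q, K, V, gammas):
        # Decay the old state, then write the current key/value outer product.
        S = [[gamma * S[i][j] + k[i] * v[j] for j in range(d_v)] for i in range(d_k)]
        from_state = [sum(q[i] * S[i][j] for i in range(d_k)) for j in range(d_v)]
        # Pass-through: the current value reaches the output directly, scaled by D.
        ys.append([a + D * b for a, b in zip(from_state, v)])
    return ys


Q = [[1.0, 0.0], [1.0, 0.0]]
K = [[1.0, 0.0], [1.0, 0.0]]
V = [[1.0], [2.0]]

print(gated_recurrence_with_skip(Q, K, V, gammas=[1.0, 1.0], D=0.5))  # nothing forgotten
print(gated_recurrence_with_skip(Q, K, V, gammas=[0.0, 0.0], D=0.5))  # state wiped every step

no_query = [[0.0, 0.0], [0.0, 0.0]]
print(gated_recurrence_with_skip(no_query, K, V, gammas=[1.0, 1.0], D=0.5))  # only D * v_t survives
```

Setting D to zero recovers the plain gated recurrence, which is why the term can be ignored when reasoning about the state update itself.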
1347.8s
So that's linear-time attention. There's a lot of really interesting stuff. It's also kind of interesting that a lot of the methods that have
1354.8s
gotten tested out really have this nice, simple, linear-attention-style recurrent form. There are a few exceptions, but for the most part, a lot of this has converged to very LSTM-like objects in the end. Okay. What did I want to say? Right, okay.
1374.2s
So, there's another alternative. What I talked about was linear attention, which has this very nice complexity-theoretic flavor to it. I want to talk about another attention optimization, which I have not talked about yet, but which falls into this family of efficient attention things that are kind of important.
1393.0s
So DeepSeek, in DeepSeek V3.2, I think, had DSA (not "do sparse attention"; that's DeepSeek Sparse Attention), where instead of computing attention (oh, that was not an intentional pun), instead of computing attention over all the tokens, what you're going to do is first have a
1415.0s
lightweight indexer that's going to select a subset of tokens, right? So you're going to look at your very long context, and the indexer is going to pick out some subset of these, much smaller than the full sequence, and then you're going to do full attention on that smaller subset.
1430.2s
So the mechanics of this are pretty simple, at least in the forward pass. You have your normal Qs and Ks, but then you pass these through this indexer. What the indexer does is this: it takes the QK inner products, it applies a ReLU, and
1447.2s
then it applies these weights that are derived from the preceding tokens, and this gives you the activations for each of the positions. Then you pass these through a top-K, where some number of these top positions get
1465.0s
included in my attention computation, and I do my usual attention computation over just those positions.
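The forward pass just described can be sketched as follows. This is a toy version under stated assumptions: the indexer has a few small heads, each position's score is a weighted sum of ReLU'd inner products, and the head weights are passed in as plain numbers rather than computed by the model. All names are hypothetical, and nothing here reflects the real implementation's precision or kernel tricks.

```python
import math


def indexer_scores(q_index_heads, head_weights, K_index):
    """Score each position s as the sum over heads h of w_h * relu(q_h . k_s)."""
    scores = []
    for k in K_index:
        total = 0.0
        for q_h, w_h in zip(q_index_heads, head_weights):
            total += w_h * max(0.0, sum(a * b for a, b in zip(q_h, k)))
        scores.append(total)
    return scores


def top_k_positions(scores, k):
    """Indices of the k highest-scoring positions, returned in sequence order."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return sorted(order[:k])


def sparse_attention(q, K, V, positions):
    """Ordinary softmax attention, restricted to the selected positions."""
    scale = math.sqrt(len(q))
    logits = [sum(a * b for a, b in zip(q, K[p])) / scale for p in positions]
    m = max(logits)
    weights = [math.exp(x - m) for x in logits]
    z = sum(weights)
    return [sum(w * V[p][j] for w, p in zip(weights, positions)) / z for j in range(len(V[0]))]


# Toy context of four positions; the indexer works in its own small key space.
K_index = [[0.1, 0.0], [2.0, 0.0], [-3.0, 0.0], [1.0, 0.0]]
scores = indexer_scores([[1.0, 0.0]], [1.0], K_index)
selected = top_k_positions(scores, 2)
print(selected)

K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]]
V = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
print(sparse_attention([1.0, 1.0], K, V, selected))
```

The selection step is a hard top-K, so no gradient flows through the choice of positions itself; only the attention over the selected positions is an ordinary differentiable softmax.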
1471.6s
The nice thing about this is that it's very, very sparse, and if this indexer is lightweight, then this bottom term, where you're doing full attention, can be done on a very small subset. Another
1483.4s
thing that both DSA and, I think, GLM show is that you don't actually have to train the model with this from scratch, because training with this indexer is maybe pretty annoying and very complex. What you do is just train a normal transformer, and when you do your long-context extension, you actually
1501.7s
then drop in this lightning indexer, and you train the model to handle the indexer in an extension stage that's separate from pre-training. So the computational cost of doing this is itself pretty light. And I want to highlight, you know, in DeepSeek's V3.2, where they proposed DSA, you
1520.9s
know, this model is quite good. V3.2 matches all of the other frontier models at the time, like Claude 4.5 Sonnet or Gemini 3. And they also have this very nice plot showing that prefill, which is attending to the input, and also decoding,
1538.5s
where you're actually generating, both have very favorable scaling compared to the previous generation of DeepSeek models that did not use sparse attention. And that kind of makes sense, right? If you're only paying attention to a very small subset, you're going to dramatically reduce your cost of doing attention.
1556.2s
I think another nice, independent validation of this is GLM 5, which came out last year. I think GLM 5 is one of the best open models available, period. They also adopted this DSA approach, and they have fairly nice ablations in their
1577.1s
paper, if you're curious, where they compare things like: what happens if we do DSA with some warm-up, or just do DSA, or don't do DSA at all? The main thing they're showing is that if they do full DSA training,
1590.7s
you don't lose very much in performance relative to full attention, even at long-context retrieval tasks that are fairly difficult to do with RNN-style architectures. So this is a totally alternative view of how to reduce attention costs. Notice, and I'll go
1609.3s
back one slide, this is not linear time, right? The indexer still has to operate on full all-to-all attention, because you've got these QK inner products that are not going to be linear in cost. But you can make this indexer much, much cheaper, for example by making this lower-dimensional or making this W very small;
1628.1s
all of this can be made much cheaper in its cost, right? Okay, so that's a totally alternative view, and it works very well. I won't talk much more about this, but it turns out that this idea of top-K selection is going to be core to the next part of
1643.5s
this lecture. I'll stop here for any efficient attention questions before I move on to MoEs. Yes? So, the indexer, the time complexity for that, is it... It's quadratic. Quadratic, yeah. That's right. Because, I mean, in order to know what to select, it does
1662.8s
have to look at everything, right? Okay. And there's no clever state-transition stuff that happens here. It is really brute-force inner products. So if you're going through all of it, and then the step afterwards is also quadratic as well, how are we getting... I mean, it's kind of like a
1680.2s
systems-y trick, almost. You're going to make the indexer much more lightweight, and there are many ways to do that. You can do it in much lower precision, which they talk about. You can do it in lower-dimensional ways, by projecting the Qs and Ks further just for the indexing. So there
1694.7s
are lots of tricks that you can imagine doing here, but the indexer is going to be much more lightweight, so the constant factors are very good. And then for the second one, even though it's quadratic, it's quadratic on a shorter context length, because it's top-K;
1705.4s
we can control K. So this one is expensive but small, right? This is another one of these cases where you shouldn't get too hung up on quadratic versus not. Sometimes the constant factors are really, really important.
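A back-of-the-envelope count shows how the constant factors play out. Every number below is hypothetical (context length, head dimension, indexer dimension, and K are picked only for illustration), and the cost model counts multiply-adds for the score computation alone, ignoring causal masking, the value aggregation, lower-precision arithmetic, and memory traffic.

```python
def full_attention_cost(seq_len, dim):
    """Score every query against every key at full dimension: L * L * d multiply-adds."""
    return seq_len * seq_len * dim


def indexed_attention_cost(seq_len, dim, index_dim, k):
    """Quadratic but cheap indexer, then full-dimension attention over at most k keys per query."""
    indexer = seq_len * seq_len * index_dim
    attend = seq_len * min(k, seq_len) * dim
    return indexer + attend


L, d, d_index, k = 100_000, 128, 16, 2_000  # hypothetical sizes
full = full_attention_cost(L, d)
sparse = indexed_attention_cost(L, d, d_index, k)
print(full, sparse, full / sparse)

# At short context the indexer is pure overhead: top-K selects everything anyway.
short = 1_000
print(full_attention_cost(short, d), indexed_attention_cost(short, d, d_index, k))
```

Both terms in the indexed version still contain a quadratic piece, so the savings come entirely from the smaller per-pair cost of the indexer and from bounding the second stage by K, which is also why this fits naturally into a long-context stage rather than short-context pre-training.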
1718.5s
There's another question? Yeah.
1724.8s
That's right. That's right. Or it's part
1726.3s
of a continued-pre-training-style thing. Because with all these models, you don't train long context from scratch, for compute and various other reasons; you train a shortish-context model first, and then you've got these long-context extension stages. And so the nice idea here is that, since we're going to do the second phase
1742.0s
anyway, so why don't we bolt the long-context cost savings on at the same stage? And it's kind of surprising that it works, honestly, that you can bolt on this frankly scary-looking, non-differentiable top-k object. But as you'll see in the next part, maybe you shouldn't be so scared of top-k
1760.8s
selection. ...another post-training step? Yeah, the usual recipe that I'm familiar with from a lot of the open-source models goes: short-context pre-training, long-context extension, then post-training.
1783.2s
...the short context length? Yeah, that's right. You would pick values that are much closer to short-context performance, and you would bound it. So regardless of the input size, they would be bounded. Okay, actually, many questions. I appreciate it. Are there general tests for the removal of the softmax
1799.9s
for the context length? Training stability issues. That's a good question. I don't think so. I'm not aware of documented cases of removing the softmax causing lots of problems. If anything, softmaxes are usually more of the problem, for various reasons.
1817.0s
So I think this would actually improve the stability properties of the architecture. Could you briefly describe the next level of deep learning attention mechanism, because of the low communication cost and the low latency?
1839.2s
Yeah. Can you elaborate on what you mean? Like, what kind of... Like, the highlight in the future is a better attention mechanism. Yeah, I don't know if I have good predictions for what the future
1854.6s
attention architecture looks like. My somewhat cop-out answer is that I think a lot of this will look like throwing all of these tricks in, in the same way that architectures have gotten very complex by taking all of these successful recipes and throwing them in. I think there's also an
1871.6s
even broader, higher-level layer of basically using post-training to have the model manage its own context. Compaction, retrieval, all of these things are basically an added layer on top. I think a lot of the work, I imagine, will be
1887.0s
integration of these different layers, because we've seen that at least the linear-time attention work has converged a lot onto LSTM-like or linear-attention-like architectures, and I don't really see that changing too much in the near future. Cool. Oh, yes. Well, last question. You spoke about lower-precision attention, but I think...
1908.2s
is FP4 attention possible? I think the indexer is partially motivated by things like this. FP4 attention is definitely possible. I think the drawback is the loss of precision for things like softmaxes, where small underflows, overflows, these kinds of
1925.3s
things can really start to matter and can sometimes be pretty problematic. So I think there's this motivation of saying, okay, let's do selection with low precision and then do a full-precision computation afterwards, because then the value vectors can be more finely weighted and added together.
1941.5s
It's an approach that makes sense. I don't know if there are really good ways of doing full attention in FP4 or even lower precision of some kind. Yeah.
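The "select in low precision, then compute in full precision" idea can be sketched for a single query as below. This is a toy under stated assumptions, not any model's actual indexer: `quantize` is a hypothetical stand-in that snaps values to a coarse grid to mimic low precision, and the function and parameter names are invented for illustration.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def quantize(vec, step=0.5):
    """Crude stand-in for low precision: snap each value to a coarse grid."""
    return [round(x / step) * step for x in vec]


def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def select_then_attend(query, keys, values, top_k, step=0.5):
    # Stage 1: coarse scores are used only to decide which keys to keep.
    q_lo = quantize(query, step)
    coarse = [dot(q_lo, quantize(key, step)) for key in keys]
    keep = sorted(range(len(keys)), key=lambda i: coarse[i], reverse=True)[:top_k]
    keep.sort()
    # Stage 2: full-precision softmax attention over the selected positions only.
    scale = 1.0 / math.sqrt(len(query))
    weights = softmax([dot(query, keys[i]) * scale for i in keep])
    out = [sum(w * values[i][j] for w, i in zip(weights, keep))
           for j in range(len(values[0]))]
    return keep, out


query = [1.0, 0.0]
keys = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
values = [[1.0], [2.0], [3.0]]
keep_one, out_one = select_then_attend(query, keys, values, top_k=1)
keep_all, out_all = select_then_attend(query, keys, values, top_k=3)
```

The coarse scores never touch the output directly, so their rounding error only affects which positions are kept; the softmax and the weighted sum of value vectors run at full precision.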
1955.6s
Okay, fine. One last question and then we'll move on. So, compared to transformers, state space models: what's the downside?
1968.8s
Yeah, so that's a good question. There has to be some downside, otherwise everybody... Well, look, the downside, I think, is expressive power. The all-to-all connection in softmax attention is incredibly powerful. It's also very easy to train. It used to be, I
1984.3s
think, that compared to, say, an LSTM, attention had the very strong advantage of hardware efficiency: you could train in parallel. I think the reason these state space models have caught on, despite their similarity to LSTMs and similar drawbacks in terms of representational
2000.0s
power, is that now there's this well-understood duality, because of the linear attention idea I talked about, where you can go from the RNN form to the denser matrix-multiply form. And the matrix-multiply form allows for computational efficiency. So that trade-off has now been sort of
2017.8s
checked off, but you've still got the trade-off that if you have a finite state and you have to carry everything through it, you're going to be losing some information relative to just carrying everything.
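The duality mentioned here can be checked on a tiny example. The sketch below assumes the simplest variant, unnormalized causal linear attention with no feature map, which may differ from the exact form used earlier in the lecture. The recurrent form carries a fixed-size state; the parallel form scores each query against all earlier keys. Both give the same outputs.

```python
def recurrent_form(qs, ks, vs):
    """RNN view: a fixed-size state S accumulates k_t v_t^T; output is q_t^T S."""
    d_k, d_v = len(ks[0]), len(vs[0])
    state = [[0.0] * d_v for _ in range(d_k)]  # size is independent of sequence length
    outputs = []
    for q, k, v in zip(qs, ks, vs):
        for i in range(d_k):
            for j in range(d_v):
                state[i][j] += k[i] * v[j]
        outputs.append([sum(q[i] * state[i][j] for i in range(d_k)) for j in range(d_v)])
    return outputs


def parallel_form(qs, ks, vs):
    """Attention view: causally masked (Q K^T) V, with no softmax."""
    d_v = len(vs[0])
    outputs = []
    for t, q in enumerate(qs):
        scores = [sum(a * b for a, b in zip(q, ks[s])) for s in range(t + 1)]
        outputs.append([sum(scores[s] * vs[s][j] for s in range(t + 1)) for j in range(d_v)])
    return outputs


qs = [[1.0, 2.0], [0.0, 1.0], [3.0, -1.0]]
ks = [[1.0, 0.0], [2.0, 1.0], [0.0, -1.0]]
vs = [[1.0, 1.0], [0.0, 2.0], [4.0, -2.0]]
rec = recurrent_form(qs, ks, vs)
par = parallel_form(qs, ks, vs)
```

The state in `recurrent_form` has `d_k * d_v` entries no matter how long the sequence is, which is exactly where the information bottleneck discussed above comes from.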
2038.3s
Basically, think about the context length relative to the size of your state. If your state is the size of your context, then you're good to go, but then you're paying these very large costs. So really, the no-free-lunch point is that if you want a really tiny
2051.5s
state, it's really hard to compress all the information in a big context. It might be possible one day to not have any trade-offs, but thus far that's where the trade-offs are showing up. Cool. Okay, so now I'm going to talk about mixture of experts.
2066.1s
And I think mixture of experts is important to talk about. Conceptually, they don't really change the game. One way of thinking about mixture of experts is that they're just a more efficient MLP. You take your MLP, and somehow someone gives you a more efficient MLP; that is kind
2081.0s
of what mixture of experts is. But I think you need to understand what mixture of experts is, for two reasons. One is this slide: everyone's doing MoEs these days. Everyone's shipping MoEs these days. You should understand what they are if you want to
2096.8s
be people who deeply understand language models and how they work. So that's one reason. The other reason is that MoEs actually have some really interesting mechanical components to
2109.6s
them. If you look at how MoEs are built and what primitives they use, you'll quickly see MoE-like primitives in many places, and you'll realize, wow, I can do more interesting things with neural networks. So there's a broader reason why I think MoEs are interesting and cool and something you should learn about.
2126.2s
Okay. So, what is an MoE? MoEs are very simple things in some ways. They're just replacements for your MLP. On the left is your usual MLP for a transformer: you've got your attention, you've got some normalization, and then you've got a feed-forward layer right after that.
2142.1s
This is where a lot of the dense, big information processing happens. Now, what I can do instead is take my big FFN and cut it up into smaller FFNs. So now I have four feed-forward networks. And let's say,
2158.1s
actually, I don't cut them up into smaller pieces. Let's say they're each the same size as my original FFN. And somehow, by some magic, I have a system that tells me which FFN to pick for every input. Now what I've done is I have 4x the parameters, or at least for
2174.6s
my FFN: I have four FFNs' worth of parameters, but on any forward or backward pass I'm only going to pay one FFN's worth of cost. That's great. So in some sense the original motivations for MoEs, if you read the original papers, have this very parameter-centric view. You
2193.5s
know, it starts something like: well, let's say you just wanted more parameters because you believe more parameters are good. Okay, good. If we want more parameters and we don't want to pay the cost for more parameters, you need something like what's on the right. So this
2206.6s
is the mental model that you should have for MoEs: you want to increase parameters without affecting your FLOPs.
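The "more parameters, same FLOPs" mental model is simple arithmetic. The sketch below assumes a plain two-matrix FFN with no biases and ignores the router's own cost; the model sizes are arbitrary illustrative values.

```python
def ffn_params(d_model: int, d_ff: int) -> int:
    """Weights in one two-matrix FFN (up and down projection), biases ignored."""
    return 2 * d_model * d_ff


def moe_accounting(d_model: int, d_ff: int, num_experts: int, active_experts: int = 1):
    """Return (total parameters, active parameters, FLOPs per token) for the FFN block."""
    per_expert = ffn_params(d_model, d_ff)
    total_params = num_experts * per_expert
    active_params = active_experts * per_expert
    flops_per_token = 2 * active_params  # one multiply and one add per active weight
    return total_params, active_params, flops_per_token


# A dense FFN is the special case of a single expert that is always active.
dense = moe_accounting(1024, 4096, num_experts=1)
moe = moe_accounting(1024, 4096, num_experts=4)
print("dense:", dense)
print("4-expert MoE:", moe)
```

With four same-size experts and one active per token, the stored parameter count quadruples while the per-token FLOPs stay where the dense model had them.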
2214.6s
Now, why are they so popular? Every model that you get on Hugging Face or wherever these days is an MoE past a certain size. Past a certain size, everything that you get seems to
2225.1s
be an MoE. And the reason is that it seems to be the case that, for whatever reason, if you keep the total compute the same but increase the number of sparse parameters, the models generally get better. So this is evidence, maybe, in favor of the view that
2244.7s
more parameters are in fact generally good, even if only a subset of them is active at a time. I'm going to refer to some of the original Google papers extensively throughout this lecture. This one's from Fedus et al. in 2022, where they were doing, I think, the Switch Transformer for this
2261.1s
one. They're showing that as you increase the number of experts, so active parameters stay the same and the number of experts increases, your test loss for language modeling just keeps decreasing. That's lovely if all you care about is your forward-pass FLOPs.
2276.4s
On the right you see similar things: you train faster. As you increase your training compute on the x-axis, if you have more experts for the same amount of training compute, you just get better performance. So for both training and inference, MoEs just
2292.7s
give you a free win. This is pretty wonderful. It's pretty great. And this is really something that's been seen many times. This is the same plot from before, on the left, from the Fedus paper. Last year, I think,
2308.0s
AI2 folks put out this OLMoE paper, which was their open-source mixture-of-experts analysis and training paper, and they see the same thing as well. If you look at training loss, validation loss, or downstream benchmark performance, you see that, in training time, these
2326.7s
models are doing much better as MoEs compared to their dense counterparts: something like two times faster, training an MoE relative to a dense model. So you can really see the benefits of sparsity here.
2342.4s
And I think the proof is in the pudding. If you look at all the MoEs that have been released, they are much stronger than dense models in terms of activated parameters, which is what matters if you care about inference FLOPs, or really just, in general, the cost to train and serve
2358.6s
these systems, the all-in cost. So when DeepSeek V2 and the earlier DeepSeek MoEs came out, it was really a big shift from a lot of the dense models that everyone else was training. You saw that, wow, we have far fewer active parameters and just
2373.9s
as good, if not better, MMLU performance compared to everyone else. So there's some really good stuff happening with MoEs. Okay. So hopefully I've convinced you that MoEs are an interesting thing that you should spend your time thinking about and studying. Oh, right. And then one
2391.5s
last thing I'll say about why this is important is that MoEs give you another axis of parallelization. We'll talk about parallelization in systems in three more lectures, I think. But parallelization is actually one of the really important things in LLMs. LLMs are very big. They don't fit on single devices, either to train or
2411.7s
run inference. So you want many different ways of cutting up your model to efficiently serve or train it. Mixture of experts has this additional nice property where your experts within each FFN naturally come in different chunks. Each expert is a natural chunk. You can put them on different devices, and then
2428.9s
you can route the activations to different devices. We'll talk about expert parallelism later, which is what this is, but this gives you an additional systems knob that you can optimize, which is very helpful if you want to optimize your serving. Okay.
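As a rough picture of what routing activations to devices involves, the sketch below groups tokens by the device that hosts their assigned expert and counts how many activations have to leave their home device. The placement, the assignments, and all names here are hypothetical; a real system would do this with collective communication rather than Python dictionaries.

```python
from collections import defaultdict


def dispatch(token_ids, expert_of_token, device_of_expert):
    """Group tokens by the device hosting the expert each token was routed to."""
    per_device = defaultdict(list)
    for tok in token_ids:
        expert = expert_of_token[tok]
        per_device[device_of_expert[expert]].append((tok, expert))
    return dict(per_device)


def tokens_crossing_devices(token_ids, home_device, expert_of_token, device_of_expert):
    """Count activations that must be shipped to a different device."""
    return sum(1 for tok in token_ids
               if device_of_expert[expert_of_token[tok]] != home_device[tok])


# Hypothetical setup: four experts spread over two devices, six tokens.
device_of_expert = {0: 0, 1: 0, 2: 1, 3: 1}
expert_of_token = {0: 0, 1: 2, 2: 3, 3: 1, 4: 2, 5: 0}
home_device = {0: 0, 1: 0, 2: 0, 3: 1, 4: 1, 5: 1}
tokens = list(range(6))

batches = dispatch(tokens, expert_of_token, device_of_expert)
shipped = tokens_crossing_devices(tokens, home_device, expert_of_token, device_of_expert)
```

Each device only stores and runs its own experts, and the count of shipped activations is the communication you pay for that.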
2443.0s
In the West, I think there have been several open-source releases that are big, strong MoEs. On the big-model side, there was Llama 4 and GPT-OSS from OpenAI. Unfortunately, I guess open-source model releases in the West have stalled for the most part, so there aren't quite as many
2463.3s
being released now, but there are at least two good examples of this, both top-tier models in their own right. But I think a lot of the action for both MoE research and training has happened in China. Qwen and DeepSeek
2482.0s
and some others like MiniCPM did some of the earliest work training and popularizing MoEs, and they continue to do really great work on MoEs. If you look at some of the early Qwen MoE models, this was Qwen 1.5 MoE, they showed that their
2497.9s
small model with 2.7 billion active parameters was doing better than many of the 7B models at the time. And I think some of this early proof of work from DeepSeek and Qwen really convinced everyone else in the open-source community that this was the right way to go. And speaking of DeepSeek... oh, was there
2516.1s
a question? Yes. I guess one question I had about expert parallelism: isn't there a communication bottleneck that builds up because these experts are communicating with each other? Yes. So in some sense you pay the communication cost of shipping an activation over. And I will
2535.0s
not explain, but I will show you, one trick for reducing that. But in general you're trading certain things off. You're getting more aggregate FLOPs. You're reducing your memory use. In exchange, you're going to pay for comms. So it's highly dependent on your topology and all these other things whether this is
2551.0s
going to be a net win. Cool. Oh, yeah. I was going to say, during training though... During... okay, this is the important part. During training, they are also sparse. This is really the key thing that makes MoEs hard. I'm spoiling sort
2566.3s
of the key idea from a few slides down, but it would be very easy if we activated all the experts during training, because then we'd get to see which expert is good for my input, so you'd get to learn how to route. The hard thing about MoEs is that during training you only have, you
2581.9s
know, one or K experts active. So, you
2583.5s
don't know what happened with the rest
2584.8s
of them.
2585.6s
Right? You only know the ones that
2587.0s
activated. Despite this, you must
2588.8s
somehow learn to route. Right? So, it's
2590.5s
kind of got this RL bandit flavor to the
2592.8s
problem.
2594.0s
But we're not going to solve it with
2595.2s
either RL or bandit. So, we're going to
2596.5s
solve it with the power of heuristics
2598.1s
and deep learning magic, right? Okay,
2599.5s
yes.
2600.5s
What is the granularity of switching
2602.6s
between experts? Uh so, the granularity
2605.6s
as in like the scale at which we route
2607.3s
experts is at token level. If I don't
2609.0s
know if Okay, good. Good. That was the
2610.1s
question. Yes, at the token level. So,
2611.4s
every token gets an expert. And the
2613.1s
routers, you know, just to get your your
2614.8s
sort of like mental model correct, the
2616.2s
routers are super naive, right? They're
2617.8s
like a single matrix multiply between
2620.0s
your input and your whatever. And so,
2622.2s
you're not going to do anything
2623.1s
complicated. You don't know it's like a
2624.5s
medical question or not. You're like,
2626.4s
oh, this is a token that looks like it's
2627.9s
in, you know, I don't know, Japanese.
2629.0s
Let's route it to expert seven, right?
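[Editor's sketch] The "single matrix multiply" router described above can be illustrated roughly as follows. This is a minimal sketch, not code from the lecture: the function name `route_top1`, the two-dimensional token vectors, and the weights are all made up for illustration, and it picks only one expert per token.

```python
def route_top1(tokens, router_weights):
    """Return, for each token vector, the index of its highest-scoring expert.

    tokens:         list of token vectors, each of length d (hypothetical data)
    router_weights: one weight vector of length d per expert (hypothetical data)
    """
    assignments = []
    for x in tokens:
        # One inner product per expert: this is the whole router.
        logits = [sum(xi * wi for xi, wi in zip(x, w)) for w in router_weights]
        assignments.append(max(range(len(logits)), key=logits.__getitem__))
    return assignments


tokens = [[1.0, 0.0], [0.0, 1.0], [0.9, 0.1]]
router_weights = [[1.0, 0.0], [0.0, 1.0]]  # two experts
assignments = route_top1(tokens, router_weights)
print(assignments)
```

Each token is routed independently of its neighbours, which is what "token-level granularity" means here.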
2631.5s
Yes. Uh yeah, so is there any upper
2633.7s
limit to parallelizing
2636.1s
these uh MOE models
2639.0s
in in training or inference? Yeah, in
2641.6s
general, there's an upper limit to
2642.7s
parallelization because, you know, as
2644.0s
you shard over more and more devices,
2646.2s
the communication cost explodes, right?
2648.7s
You will Well, if we if we get the
2650.7s
assignment 2 design done correctly, you
2653.0s
will have to deal with some of these in
2654.4s
your assignment. We will we will ask you
2655.8s
to figure out the networking topology
2657.4s
and things like this that will
2659.1s
make sharding make sense. Or it's the
2660.4s
reverse. Given a networking topology,
2662.3s
you'll have to figure out how to shard
2664.2s
models and things like this. Yeah.
2666.5s
Good. Okay.
2668.5s
Um right.
2670.5s
DeepSeek, you know, now known for like
2672.6s
DeepSeek R1 and so on. But I think the
2674.8s
thing I really like about DeepSeek, and
2676.2s
originally this lecture was a, you know,
2678.0s
let's just go through a DeepSeek paper
2679.4s
lecture that very first iteration.
2681.6s
Um is that they've done actually really
2683.1s
good architecture and LLM science work
2686.3s
for a long time. And DeepSeek were one
2688.0s
of the earlier ones, you know, going
2689.8s
through and showing why MOEs are good.
2691.9s
And they have a lot of ton of good
2693.4s
ablation work on like, oh, what happens
2695.5s
if we use a dense layer? What if we use
2697.1s
hash routing? What if we use a switch
2698.7s
router? You know, all these different
2700.2s
kinds of design decisions you can see in
2702.2s
some of their earlier DeepSeek MOE
2703.7s
papers. So, if you're interested in a
2705.5s
lot of these questions like architecture
2707.0s
design and so on, I would suggest you
2708.8s
read their earlier DeepSeek papers.
2710.3s
There's a lot to learn from all of
2711.8s
these, right?
2713.2s
And of course, progress hasn't stopped.
2715.0s
You know, DeepSeek V3 and V3.2 and like
2718.2s
GLM and so on have built these like
2720.3s
wonderful very strong MOE models.
2723.1s
It's quite clear that this is kind of
2724.6s
where the future of these big models are
2726.8s
going to be for the foreseeable future.
2729.4s
At least the next couple years, I
2730.5s
imagine will be MOEs.
2733.2s
Okay. So, now the question is, why
2736.0s
haven't they they be
2737.9s
Why haven't they been more popular? In
2740.4s
many different ways, I feel like MOEs
2742.2s
caught on pretty slowly. You know, you
2744.0s
look at some of these papers like in
2745.4s
2022, right? Google was studying and
2748.1s
really trying to like push MOEs. Um
2751.3s
but, you know, it's only in like 2024
2753.8s
onwards that MOEs really started to
2755.4s
catch on. Like why? Also, you know, if
2757.6s
you're doing LLM research of various
2759.4s
kinds, you're probably mostly working
2761.2s
with dense models. Like why not MOE
2763.3s
models if they're so nice, right? We can
2764.9s
save on compute.
2766.3s
Um there's a lot of complexities that
2768.2s
come from MOEs. Like it's not easy to
2770.0s
train a MOE or to use a MOE to or to do
2772.4s
many things.
2774.0s
Um you know, the infrastructure is very
2776.3s
complex. I mean, there were a couple
2777.5s
questions about parallelization, but
2779.3s
this is like generally right. That it's
2780.8s
hard to parallelize, you know, experts
2783.3s
in ways that are like efficient in
2784.9s
utilization. MOEs have a lot of
2786.9s
parameters. It's hard to fit them on a
2788.4s
single device.
2790.2s
And if you're training them, the MOEs
2792.4s
can really blow up on you. And you'll
2794.0s
see why in a few slides why MOEs are
2796.5s
just really hard to deal with. But
2798.0s
they're just not easy objects to work
2800.3s
with, right? I mean, there's good rules
2802.1s
of thumb now, but it's not easy.
2805.1s
Okay.
2806.2s
Um I'll briefly mention with one slide
2808.4s
that there are people that also, you
2810.8s
know, kind of mixture of experts the
2812.3s
attention block in transformers. There
2815.1s
have been a couple papers,
2817.1s
but they are much less common. And I
2819.8s
think the, you know, things that I have
2821.9s
seen is that they haven't been quite as
2823.6s
easy to tame or to to get to work
2826.5s
compared to just replacing the the FFN
2829.5s
or the MLP layer, right? So, I'm only
2831.5s
going to talk about the left. This is
2832.6s
basically uh what all the big models do.
2835.2s
Um people are not, for the most part,
2837.5s
sharding or sorry, splitting up
2839.3s
attention heads into experts.
2843.4s
Okay. So, that was kind of the broad
2845.9s
industrial overview of like MOEs and
2848.2s
like what the state of things are. Um
2850.3s
but now let's like talk about how you
2851.6s
design an MOE, how you think about an
2853.4s
MOE. So,
2855.5s
um to do that, I'm going to talk about
2857.1s
three different kind of
2859.0s
axes of variation, right? Um so, what
2861.4s
kinds of things can change for MOEs?
2863.9s
Well, you can change the routing
2865.0s
function, right? We know that MOEs have
2866.9s
this core thing of like a router has to
2869.2s
send tokens to experts, right? So, we
2870.8s
can change that. We can change the sizes
2873.3s
of the different experts, right? For a
2874.6s
given budget, you might have more
2876.9s
experts with fewer parameters or fewer
2879.7s
experts with more parameters. Um and
2881.9s
then finally, we're going to talk about
2883.2s
training, right? Because training is
2884.5s
going to be really difficult. Like once
2885.9s
you start to realize how MOEs work, you
2888.8s
realize, wow, it is not easy to train
2891.5s
something that routes tokens to
2892.9s
different places.
2898.8s
activate all the experts, right? Like we
2901.1s
want to only have some subset of experts
2903.3s
active. Um and so, we are going to
2906.2s
always be choosing some top K. There is
2909.2s
different kinds of choices that you can
2910.9s
do.
2912.2s
You can have the token choose the
2914.2s
expert. You can have the expert choose
2916.4s
the token. So, each expert picks, you
2918.0s
know, their favorite tokens. Or you can
2919.3s
have each token uh pick their favorite
2921.4s
experts. Or you can somehow have this
2923.7s
like very complicated router that like
2925.8s
globally decides, oh yes, you token
2927.9s
should go to this expert and this other
2929.4s
token should go to that other experts to
2931.2s
try to like globally optimize
2932.8s
assignments.
2934.0s
Um
2935.4s
So,
2936.5s
basically almost all the MOEs do token
2939.6s
choice top K. And so, the token is going
2942.2s
to be the one to choose the experts. And
2944.3s
there's going to be K different experts
2945.8s
that get sort of selected to be in the
2948.0s
pool.
2949.7s
And you can sort of see, you know, this
2951.5s
is from the OLMoE paper,
2953.6s
differences between token choice and
2955.1s
expert choice.
2956.5s
You see that token choice is going to
2957.8s
get, you know, lower validation losses,
2959.8s
higher downstream benchmark scores
2962.0s
compared to expert choice. Um There have
2964.5s
been successful expert choice models. I
2966.0s
mean, this also trains fine, you know.
2968.5s
But token choice has generally been much
2970.8s
easier to get working. It has been kind
2973.1s
of the standard for all of the models
2974.8s
that kind of we see today.
2976.7s
I think one of the Llama 4 models was
2978.3s
maybe expert choice, but I don't think
2980.4s
that's necessarily a a strong vote of
2982.6s
confidence.
2984.0s
Uh
2985.2s
What?
2986.7s
I don't think I think that was one of
2987.7s
the unreleased Llama 4 models, you
2989.2s
know, in my in my defense.
2990.9s
Um okay, so I think there's many
2992.5s
different ways that you can do routing.
2995.3s
So one of the ways that you can do
2996.5s
routing is you can actually have a small
2999.3s
small sort of
3000.9s
ML like a linear projection that says,
3003.2s
Okay, I'm going to get my input. My
3004.9s
router is an inner inner product. Um
3007.4s
each expert has sort of a vector
3009.3s
direction and the sort of the closest in
3011.0s
inner product space will get selected.
3013.4s
This is by far the most common router.
3016.2s
It's used in the classic switch
3017.6s
transformer, the GShard, those two are
3019.8s
the early Google papers. It's used in
3021.9s
Grok, Mistral, Qwen, DBRX, DeepSeek,
3024.7s
and others, right? Um the number of K
3027.6s
that you choose varies depending on, you
3029.7s
know, which paper you're looking at.
3031.9s
Um but for the most part, you know,
3033.4s
you're you're just taking inner
3034.4s
products.
3035.8s
Um one thing that has always been a
3037.5s
little mysterious to me um is that many
3041.2s
many papers have shown that actually you
3043.3s
don't need to do any sort of like
3044.5s
learned routing. You can just take your
3046.0s
X's, you can hash them, and send them to
3048.0s
different uh feedforward networks. And
3050.3s
actually this is fine, too. This
3051.4s
sometimes or this gives gains. Not as
3053.5s
much as top K, but somehow hashing-based
3056.2s
uh expert routing is also, you know,
3057.9s
something that people do. Common
3059.3s
baseline in papers.
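[Editor's sketch] A hash router with no learned parameters might look roughly like this. It is an illustration only: it assumes the thing being hashed is an integer token id (the lecture just says "your X's"), and `hash_route` and the example ids are hypothetical.

```python
import zlib


def hash_route(token_id, num_experts):
    """Map a token id to an expert with a fixed hash; nothing here is learned."""
    # zlib.crc32 is deterministic across runs, unlike Python's built-in hash() on str.
    return zlib.crc32(str(token_id).encode("utf-8")) % num_experts


token_ids = [101, 2054, 2003, 101]  # hypothetical ids; 101 appears twice
experts = [hash_route(t, num_experts=4) for t in token_ids]
print(experts)
```

The same token id always lands on the same expert, so the assignment is fixed before training starts and the router never has to be learned.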
3060.9s
Um not used in in deployment for for
3062.8s
real reasons for many reasons.
3064.8s
Um you can do RL to learn the routes.
3067.4s
You can basically have a router that's a
3068.9s
inner product router.
3070.7s
Um and you can sort of use reinforcement
3073.1s
learning to try to treat this router as
3075.0s
a kind of policy and then learn it. Um
3078.4s
You know, if you're a machine learning
3079.8s
like classic, you know, learning theory
3082.2s
type oriented person, I think this
3084.0s
should be the natural way to think about
3085.3s
the problem, right? You're like, Okay,
3086.5s
what am I doing? It's kind of like a
3088.2s
bandit problem. I'm selecting one out of
3090.2s
K or K out of N.
3092.7s
Um and I don't observe all of them. This
3095.0s
is a bandit problem. I will use a bandit
3097.2s
or an RL algorithm, right? Um used in
3099.7s
some of the earliest work because this
3101.1s
is kind of the natural way to think
3102.4s
about the problem like Bengio in 2013.
3105.0s
Um but it is not a common approach. Um
3108.6s
And I think the reason why it's not
3109.8s
common is because, you know, the
3112.0s
approaches needed to do RL introduce
3114.3s
a lot of overhead in terms of both the
3116.4s
RL algorithm and the stochasticity.
3119.1s
Um
3120.4s
and what people have figured out is
3122.0s
there's a good set of heuristics like
3124.1s
these recipes that you can apply to this
3126.6s
very simple classic top K routing scheme
3129.6s
um that basically just has it working,
3132.4s
right? And so there's no reason to do
3134.2s
something much more complicated. Um I
3136.7s
will finally sort of mention, you know,
3139.0s
an idea that I think is really cool and
3141.0s
that I love as a person that likes
3142.6s
things that make sense. Um you can solve
3144.6s
a linear assignment problem where you
3146.6s
say, Okay, I'm going to globally
3147.9s
compute kind of the score that says
3149.5s
like, oh, it would be um
3151.5s
this good for this token to go to that
3153.1s
expert. Right? So you compute this all
3154.5s
pairwise scores of how good the
3156.0s
assignments are. And you can use linear
3158.0s
assignment to exactly solve for the for
3160.3s
the optimal assignment from experts to
3162.0s
tokens. Um it's been used. It's been
3164.2s
shown to to do good things in some
3166.3s
cases.
3167.4s
Um hasn't been seen at scale at all, I
3170.5s
would think because it's, you know,
3171.7s
extremely expensive relative to the
3173.4s
others
3174.6s
uh to get working.
3176.8s
Okay.
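[Editor's sketch] The linear-assignment idea can be illustrated on a toy problem by brute force. This is only a sketch under assumptions: the scores are invented, every expert is assumed to take exactly `capacity` tokens, and a real system would use a proper assignment solver instead of enumerating permutations.

```python
from itertools import permutations


def best_balanced_assignment(scores, capacity):
    """Brute-force the token-to-expert assignment with maximum total score,
    subject to every expert receiving exactly `capacity` tokens.

    scores[t][e] is the (hypothetical) affinity of token t for expert e.
    Only feasible for tiny inputs.
    """
    num_tokens, num_experts = len(scores), len(scores[0])
    assert num_tokens == num_experts * capacity
    # Each expert contributes `capacity` slots; an ordering of slots is an assignment.
    slots = [e for e in range(num_experts) for _ in range(capacity)]
    best_total, best = float("-inf"), None
    for perm in set(permutations(slots)):
        total = sum(scores[t][e] for t, e in enumerate(perm))
        if total > best_total:
            best_total, best = total, list(perm)
    return best, best_total


scores = [
    [0.9, 0.1],
    [0.8, 0.2],
    [0.7, 0.3],
    [0.6, 0.4],
]
# Token choice (top-1): every token independently picks its best expert.
greedy = [max(range(2), key=row.__getitem__) for row in scores]
# Global assignment: the best total score with two tokens per expert.
balanced, total = best_balanced_assignment(scores, capacity=2)
print(greedy, balanced)
```

In this toy case token choice sends all four tokens to expert 0, while the global solve moves the two tokens that lose the least over to expert 1. The enumeration cost grows factorially, which gives a feel for why this is expensive relative to top-K.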
3178.0s
Um so top K routing, as I said, is
3180.2s
basically the consensus routing
3181.6s
mechanism. Um how does this work? It's
3184.4s
very very simple. And if you, you know,
3186.3s
were were paying attention in the DSA uh
3189.4s
slide, this looks a lot like DSA, right?
3192.3s
Um you'll also notice this if you were
3194.5s
reading some other papers like H-Nets or
3196.5s
these other kinds of things. This will
3197.7s
appear in many different places. So this
3199.2s
is a good kind of pattern to be able to
3201.4s
to sort of recognize. Um so what you're
3203.7s
doing is you are going to
3206.6s
um
3207.6s
take in your inputs. Um and you're going
3209.7s
to take it through sort of a
3210.8s
feedforward. And then I'm going to
3212.4s
compute um some sort of scores. And the
3214.8s
scores that I'm going to compute is
3216.9s
going to be um
3218.6s
the softmax between the experts and sort
3221.2s
of my my inputs U. And then this S is
3224.1s
going to be the scores. And I'm going to
3225.7s
select the top K of those guys, and
3227.8s
that's going to be the gate that's going
3229.4s
to control each of my experts, right? So
3231.6s
so this is like input, this is the
3233.4s
residual term, this is the FFN. I've got
3235.5s
gates, and the gates are controlled top
3237.2s
K, right? This is kind of the equation
3238.7s
form of what you would imagine of a top
3241.6s
K router. And the only, you know, detail
3243.4s
that you're probably learning here is
3245.4s
how do I learn my gates? My gates are
3247.9s
just learned by taking inner product
3249.7s
between a weight for each expert and the
3252.1s
inputs that I have, right? Very very
3254.0s
lightweight, right? So there's nothing
3255.4s
complicated about the ways that I'm
3257.2s
going to select my gates.
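[Editor's sketch] The top-K layer just described (softmax scores from inner products, keep the top K as gates, add the gated expert outputs to the residual) can be written out roughly as below. The slide's equation is not in the transcript, so this is a reconstruction under assumptions: the gates are the raw softmax scores without renormalisation, and the toy "experts" simply scale their input. All names and numbers are hypothetical.

```python
import math


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    z = sum(exps)
    return [e / z for e in exps]


def moe_layer(u, expert_vectors, experts, k):
    """Token-choice top-K MoE layer for a single token vector u.

    expert_vectors: one routing vector e_i per expert
    experts:        one callable FFN_i per expert
    """
    # s_i = softmax_i(u . e_i)
    logits = [sum(a * b for a, b in zip(u, e)) for e in expert_vectors]
    s = softmax(logits)
    # Keep the K largest scores as gates; every other gate is zero.
    top = sorted(range(len(s)), key=s.__getitem__, reverse=True)[:k]
    out = list(u)  # residual term
    for i in top:
        y = experts[i](u)  # only the selected experts are evaluated
        out = [o + s[i] * yi for o, yi in zip(out, y)]
    return out, top


def make_scale_expert(c):
    # Stand-in for an FFN: multiplies its input by a constant.
    return lambda u: [c * x for x in u]


expert_vectors = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]]
experts = [make_scale_expert(c) for c in (1.0, 2.0, 3.0, 4.0)]
h, chosen = moe_layer([2.0, 1.0], expert_vectors, experts, k=2)
print(chosen, h)
```

Only the K selected experts are ever called, which is where the compute saving comes from. It is also why the router gets no signal about the experts it did not pick.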
3259.5s
Um
3260.7s
even within this simple thing, there's
3262.3s
been some innovation and variation from
3264.8s
uh DeepSeek
3266.2s
um and friends. Um DeepSeek MoE uh
3269.7s
pioneered this idea, which is now very
3271.6s
widely used, called shared experts.
3274.2s
Um
3274.8s
this is a very intuitive idea
3277.3s
um where, you know, in the classic top K
3279.4s
router you would have, you know, K
3281.2s
different Oh, sorry. In this case, N
3282.5s
different experts. Um and these experts,
3285.5s
you know, you would select amongst them.
3287.8s
But sometimes you might say, Well,
3289.6s
there's some processing some very common
3291.6s
processing that I want to apply to all
3293.0s
tokens, right? Like maybe you don't
3295.2s
want experts to be different every time.
3297.8s
You want some experts to always be
3299.9s
applied. And then some other experts to
3302.5s
be applied conditionally, right? So um
3305.1s
what DeepSeek MoE did was they took sort
3307.9s
of the uh original expert design. They
3310.7s
cut up the experts into smaller
3312.4s
finer-grained chunks. And some subset of
3314.8s
those fine-grained experts were were
3317.0s
sort of set as shared experts that are
3318.8s
always on, right? So these sort of
3320.4s
bypass the router. They always process
3322.7s
our inputs and they sort of go back on
3324.8s
the outputs. Oh, sorry.
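[Editor's sketch] One way to picture shared experts: they bypass the router and run on every token, while the routed experts are still gated by top-K. The code below is a hedged illustration rather than DeepSeek's implementation. The function name, the toy scaling experts, and the call counter are assumptions added to make the behaviour visible.

```python
import math


def shared_plus_routed(u, shared, routed, routing_vectors, k):
    """Residual + always-on shared experts + top-K gated routed experts."""
    out = list(u)  # residual term
    # Shared experts bypass the router and always run.
    for ffn in shared:
        out = [o + y for o, y in zip(out, ffn(u))]
    # Routed experts: softmax over inner products, keep the top K as gates.
    logits = [sum(a * b for a, b in zip(u, e)) for e in routing_vectors]
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    z = sum(exps)
    scores = [e / z for e in exps]
    top = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)[:k]
    for i in top:
        out = [o + scores[i] * y for o, y in zip(out, routed[i](u))]
    return out, top


calls = {"shared": 0}


def shared_ffn(u):
    calls["shared"] += 1  # count how often the shared expert runs
    return [0.5 * x for x in u]


routed = [lambda u, c=c: [c * x for x in u] for c in (1.0, 2.0, 3.0)]
routing_vectors = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]

tops = []
for u in ([3.0, 0.0], [0.0, 3.0]):
    _, top = shared_plus_routed(u, [shared_ffn], routed, routing_vectors, k=1)
    tops.append(top)
print(tops, calls["shared"])
```

The two tokens go to different routed experts, but the shared expert runs for both, which matches the "always on" behaviour described above.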
3326.4s
Um
3327.9s
and this is a nice idea because it
3330.0s
turned out that, you know, if you have
3331.7s
this classic expert design, you were
3333.5s
just kind of reusing a lot of these
3335.0s
weights to do common modeling. And so
3337.3s
you can offload this into the one shared
3338.8s
expert, letting these guys specialize
3341.1s
even more. Um and you know, as I was
3344.0s
kind of saying, DeepSeek is really nice
3346.5s
in that they do careful ablations in a
3348.2s
lot of their papers. Um and so, you
3350.5s
know, they have this very nice ablation
3352.2s
where they look at, you know, zero
3353.7s
shared experts, and then you have these
3355.8s
big experts, 16 routed experts. And this
3357.9s
is really the old GShard Google design.
3360.5s
Um compared to, you know, they have very
3363.1s
fine-grained segmentation of experts and
3365.4s
one shared expert amongst these. And you
3367.4s
know, as you make the experts smaller
3369.4s
and smaller um and you have one shared
3371.5s
expert, you see significant gains from
3374.0s
both of those interventions. And and for
3375.6s
some of these, you know, the shared
3376.6s
expert has very significant improvements
3378.7s
like TriviaQA, Natural Questions. This
3380.6s
blue to yellow gap, right, is shared
3382.6s
experts or not.
3384.3s
So more experts, shared experts all
3387.3s
generally seem to help. And if you want
3388.9s
sort of more validation of
3392.5s
uh these claims and ideas, uh you see
3394.4s
this from OLMoE as well. OLMoE was I think
3396.3s
the nice uh western sort of carefully
3399.1s
controlled uh MoE study. They do things
3402.0s
like, Oh, do shared experts help? Um
3404.0s
they conclude that shared experts don't
3405.6s
help very much. Um and you know, number
3408.4s
of experts they show, you know, doing
3409.8s
fine-grained many experts is helpful. So
3412.2s
they disagree a little bit on the shared expert point. But they also have very nice, carefully controlled ablations. Many of the recent MoEs do follow essentially the DeepSeek design. There were a couple of early MoEs up until this point. The first three, you know, I keep
3432.8s
giving Google praise for doing lots of interesting things. The first three were from Google. Mistral, DBRX, and Grok were some of the early MoE attempts in the West. DeepSeek V1 comes up with both the fine-grained and shared expert design. And then that catches on and
3447.3s
everyone else follows from DeepSeek. It's a similar phenomenon, I would say, to the Llama design for dense transformers being essentially the standard. DeepSeek MoE and DeepSeek V3 are kind of the standard design that people have copied for the MoE side of things.
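To make the shared-plus-fine-grained design concrete, here is a minimal sketch of one token passing through such a layer. It is an illustration under stated assumptions, not DeepSeek's implementation: the experts are hypothetical toy functions (`make_toy_expert`), the gate weights are taken from a softmax over all routed experts (one convention among several), and the names `moe_forward` and `router_w` are made up for this example.

```python
import math
import random


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def make_toy_expert(scale):
    # Hypothetical stand-in for a small feed-forward expert: it just rescales x.
    return lambda x: [scale * xi for xi in x]


def moe_forward(x, router_w, routed_experts, shared_experts, k):
    """Route one token through k of the routed experts plus every shared expert."""
    # One router logit per routed expert: inner product of its router row with x.
    logits = [sum(w * xi for w, xi in zip(row, x)) for row in router_w]
    probs = softmax(logits)
    # Sparse part: keep only the k highest-probability routed experts.
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    out = [0.0] * len(x)
    for i in top:
        y = routed_experts[i](x)
        out = [o + probs[i] * yi for o, yi in zip(out, y)]
    # Dense part: shared experts process every token, with no gate.
    for expert in shared_experts:
        y = expert(x)
        out = [o + yi for o, yi in zip(out, y)]
    return out, top


random.seed(0)
d_model, n_routed, k = 4, 16, 4
router_w = [[random.gauss(0.0, 1.0) for _ in range(d_model)] for _ in range(n_routed)]
routed = [make_toy_expert(0.1 * (i + 1)) for i in range(n_routed)]
shared = [make_toy_expert(1.0)]
output, chosen = moe_forward([1.0, -2.0, 0.5, 3.0], router_w, routed, shared, k)
print(chosen, output)
```

The point of the split is visible in the two loops: only `k` of the many small routed experts run for a given token, while the shared expert runs for every token regardless of what the router says.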
3464.4s
And if you look at Qwen 3.5 or GLM or any of these modern models, you continue to see both shared experts and fine-grained experts in widespread use. So that's been a very battle-tested design. Okay, any questions? Oh, yes, good. Yeah, if there is a
3483.2s
shared expert, or there are a few shared experts, how would parallelization work in that case? Yeah. So in the shared expert case, you wouldn't get any parallelization savings. Every activation has to route through those. You can copy the shared expert
3500.9s
to basically eat into the memory savings in order to reduce comms cost. All right, so that's the design space of routers. Now I want to talk about training MoEs. And this gets into really surprising and, I would say, very deep-learning things. I think initially when I was learning about
3522.0s
MoEs, I thought there's no way that we can train these things well, or even reasonably. But it turns out there are lots of tricks that, when combined together, just really work well and robustly, for some reason.
3537.0s
So as was raised in the previous question, we don't want to activate all the experts during training, because if we do that, we're going to pay the full FLOPs cost of all of our experts, and we have a lot of experts. So we need sparsity for efficiency at training time.
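As a rough back-of-the-envelope check on that cost argument, the sketch below counts feed-forward FLOPs per token when every expert runs versus when only the top k run. The sizes are hypothetical, each expert is assumed to be a plain two-matrix feed-forward block, and the (small) router cost is ignored.

```python
def ffn_flops_per_token(d_model, d_hidden):
    # Up-projection and down-projection, counting 2 FLOPs per multiply-add.
    return 2 * (2 * d_model * d_hidden)


def moe_flops_per_token(d_model, d_expert, n_experts, k=None):
    # k=None means every expert runs (dense activation); otherwise only k run.
    active = n_experts if k is None else k
    return active * ffn_flops_per_token(d_model, d_expert)


# Hypothetical sizes, chosen only to make the ratio easy to read.
dense = moe_flops_per_token(d_model=1024, d_expert=512, n_experts=64)
sparse = moe_flops_per_token(d_model=1024, d_expert=512, n_experts=64, k=8)
print(dense // sparse)  # 8
```

With 64 experts and k = 8, activating everything costs eight times the expert FLOPs of sparse routing, which is the saving that sparsity buys at training time.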
3550.0s
But if we have sparsity, then gating decisions are no longer differentiable. And we also don't see all the different counterfactual experts we could have picked. So these are both big problems. There are different solutions. We could use RL to optimize our gating policies. I already mentioned this, not
3565.5s
super popular. You could do things like stochastic perturbations. So if you're familiar with explore-exploit, bandit-style stuff, you could do some of that. Now, the third one is you could have a whole bunch of weird heuristics that deal with the fact that the gating decisions are sparse to
3582.6s
balance out which tokens go to which experts. Guess which one people use in practice. I've been kind of hinting at this. It's number three. It's a collection of really interesting heuristics that end up working. I'll talk about the other two
3598.0s
first. They're not widely used, I would open with that, but I think it's useful to know. So some people have done research, this is from Clark 2020, on things like, can you use RL to learn routers? It does work. You can
3612.9s
use RL. This is the green line over here. That's one of the baseline approaches. It works fine, but because of the gradient variance and complexity, it's not actually the optimal thing to do. And even some of the approaches that are proposed in this paper handily beat using
3629.0s
REINFORCE gradients for routing decisions. Now, something that's very close to the idea of RL, but starts to move more towards this heuristic world of write down objectives and try things, is this idea of a stochastic approximation. And it comes from the
3651.2s
earliest MoE paper from Noam Shazeer and others, where you don't really make a hard decision during training. What you do is make stochastic decisions. So you still have this sort
3664.2s
of routing setup. So this H is the inner product between your parameters and your inputs. This is your standard routing inner product. But then, instead of just doing this, I'm going to inject some noise whose scale depends on the input. And
3683.1s
then that's going to be my H. And now I'm going to keep the top K and softmax, just like I was explaining with top K. And that's going to be the output that goes into my top-K function. Now, if you look at this, what
3698.4s
is this doing? Well, it just means that if you have two experts that are closely tied, very close to each other, then stochastically you'll pick one of them and not the other. And as you backprop, you're going to backprop in a way that the experts that are helpful
3712.6s
will get high weights. So this perturbation allows you to tiebreak and explore a little bit more, and hopefully the gradient structure is such that if an expert is helpful, it's going to emphasize that expert more and it will get higher and higher weights in general. So
3726.7s
this is one reasonable thing that you can do. It leads to experts that are a bit more robust, and the softmax allows you to learn how to rank all these K experts. Rather than just doing a hard selection of top K, you're kind of weighting by these
3742.3s
elements. There's another scheme that's very closely related, done in Fedus et al. in 2022, which does a kind of uniform multiplicative perturbation. They do this for the reason of hardening the experts, making them robust.
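The noisy top-K gating described a moment ago can be sketched as follows. This is a minimal pure-Python sketch, assuming the form I understand from the Shazeer et al. paper: a clean routing inner product, plus standard Gaussian noise scaled by a softplus of a second learned projection, followed by keep-top-K and a softmax over the survivors. The names `w_gate`, `w_noise`, and `noisy_top_k_gate` are illustrative, and the weights here are random rather than learned.

```python
import math
import random


def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def softplus(z):
    # Numerically stable log(1 + exp(z)); always positive.
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def noisy_top_k_gate(x, w_gate, w_noise, k, rng, train=True):
    """Return one gate weight per expert; all but k of them are exactly zero."""
    h = []
    for g_row, n_row in zip(w_gate, w_noise):
        score = dot(g_row, x)  # standard routing inner product
        if train:
            # Gaussian noise whose scale is itself a learned function of x.
            score += rng.gauss(0.0, 1.0) * softplus(dot(n_row, x))
        h.append(score)
    # Keep the k largest noisy scores, then softmax over just those.
    top = sorted(range(len(h)), key=lambda i: h[i], reverse=True)[:k]
    m = max(h[i] for i in top)
    exps = {i: math.exp(h[i] - m) for i in top}
    total = sum(exps.values())
    gates = [0.0] * len(h)
    for i in top:
        gates[i] = exps[i] / total
    return gates


rng = random.Random(0)
x = [0.5, -1.0, 2.0]
w_gate = [[rng.gauss(0.0, 1.0) for _ in x] for _ in range(8)]
w_noise = [[rng.gauss(0.0, 1.0) for _ in x] for _ in range(8)]
gates = noisy_top_k_gate(x, w_gate, w_noise, k=2, rng=rng)
print(gates)
```

When two experts have nearly equal clean scores, the noise decides which one survives the top-K cut on a given step, which is the tiebreak-and-explore behavior described above. With `train=False` the noise is skipped and the routing is deterministic.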
3762.0s
This was a useful thing to do, but it was removed in later Google papers that do routing. And it's not really clear that this is necessary at all. You see later ablations showing that not doing any of these stochastic robustness tricks
3778.7s
actually helps with both stability and overall quality of the final trained MoE. So this suggests that you don't necessarily need these stochastic exploration terms. So, [on to the third option,] which is a series of heuristics that in the end is going to give us a working MoE.
3804.6s
And a lot of these have to do with essentially balancing. So what's the problem if you just do normal gradient descent, ignoring any of these exploration-exploitation concerns? Well, what's going to happen is you're going to first route to your top K experts, and then the strongest of those experts will get more signal.
3821.9s
Your backprop is going to say that expert was good, increase the weight of that expert, move that parameter closer to that. So you get this rich-get-richer effect where experts that are chosen get very strong weights. Strong weights mean that they're selected more often, and
3835.3s
they kind of run away, taking on everything. And so this expert collapse phenomenon, or expert starvation phenomenon, is a very real problem. In some sense, this is the core issue that you have to solve with heuristic training of MoEs. And
3853.0s
so what do you do? Well, you add a heuristic loss to the total modeling loss. And this loss is supposed to balance out the tokens that go to different experts. One approach, which is done in the Switch Transformer over here,
3869.3s
you can look at what this does. So you have different experts going from one to N, and you have a batch of tokens. And what you want to do is make sure that the inner product of these two vectors is small. Okay, so
3883.0s
what is happening here? F is the fraction of tokens that are dispatched to any single expert. And then P of I is the softer version. This is not just whether the token is dispatched to an expert; this is the total probability mass of the router that's allocated. And you
3899.4s
multiply these two terms. And this is, at least to me, not a thing you would derive from first principles. I think the easiest way to see what this thing is doing is to take the derivative of this object. So if you take the derivative of this
3915.5s
with respect to P of I, and P of I, remember, is the probability mass allocated to an expert, well, the derivative is going to be the fraction of tokens that are allocated. So in some sense, you can think of this as a penalty in gradient space, where the
3930.5s
more tokens you get, the stronger the downward gradient push you get. So this is trying to push down the probability mass on very popular experts, proportional to their fraction. And that's the action of this loss. That's the way to think about it. So the objective itself in
3948.2s
equation four might not be clear to start with, but once you reason about the action of the gradient, it becomes clear what this is trying to do, which is to push down popular experts.
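The balancing loss and its gradient can be checked numerically with a small sketch. It assumes top-1 dispatch and the form described above, a coefficient times the number of experts times the sum over experts of F times P. The coefficient `alpha` is a placeholder value, and `peaked` is a hypothetical helper that fabricates router outputs for the demonstration.

```python
def switch_balance_loss(router_probs, alpha=0.01):
    """router_probs[t][i] is the router's softmax probability of expert i for token t."""
    n_tokens = len(router_probs)
    n_experts = len(router_probs[0])
    f = [0.0] * n_experts  # fraction of tokens actually dispatched (hard, top-1)
    p = [0.0] * n_experts  # mean router probability (soft)
    for probs in router_probs:
        winner = max(range(n_experts), key=lambda i: probs[i])
        f[winner] += 1.0 / n_tokens
        for i in range(n_experts):
            p[i] += probs[i] / n_tokens
    loss = alpha * n_experts * sum(fi * pi for fi, pi in zip(f, p))
    # f comes from an argmax, so it is treated as a constant;
    # the derivative with respect to p[i] is then just alpha * N * f[i].
    grad_p = [alpha * n_experts * fi for fi in f]
    return loss, f, p, grad_p


def peaked(i, n=4, hi=0.7):
    # Hypothetical router output: most mass on expert i, the rest spread evenly.
    lo = (1.0 - hi) / (n - 1)
    return [hi if j == i else lo for j in range(n)]


balanced = switch_balance_loss([peaked(t % 4) for t in range(8)])
collapsed = switch_balance_loss([peaked(0) for _ in range(8)])
print(balanced[0], collapsed[0])
print(collapsed[3])
```

In the collapsed case only the overloaded expert has a nonzero entry in `grad_p`, so gradient descent pushes its router probability down and leaves the starved experts alone, which matches the "push down popular experts" reading.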
3960.1s
So then we can actually take these ideas and essentially reconstruct
3965.8s
DeepSeek V1 and V2's MoE system. So what do you do to build DeepSeek MoE? Well, you're going to do your normal losses, and you're just going to backprop straight through the experts, ignoring all these non-differentiability and exploration concerns. But you're going to add per-expert balancing that's identical to the Switch
3983.8s
Transformer, what I just showed you. So you're going to have this loss that multiplies F, which is the fraction, and P, which is the total mass. But the DeepSeek folks are very savvy with their systems design, so they don't just balance the experts. They also want to
3999.0s
balance by device. So if your experts are allocated to some devices, let's say my machine over here has four experts and this machine over here has another four experts, you want those two machines to be balanced, so both of them are running at full utilization. So you don't want
4013.2s
to just make the experts even, you also want to ensure that the devices are even. So this is a secondary objective. It acts in exactly the same way. If you look at the function here, instead of experts, you're just operating on each device's fraction. You apply the secondary device balancing loss just for utilization reasons.
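A device-level version of the same loss can be sketched by grouping experts by the device that hosts them. This is a sketch under assumptions, not DeepSeek's code: it takes the per-device load to be the average of the per-expert fractions on that device and the per-device mass to be the sum of the per-expert router masses, which is how I read the DeepSeekMoE formulation. The per-expert `f` is assumed to be normalized so that perfectly even routing gives 1.0, and `alpha` is a placeholder coefficient.

```python
def device_balance_loss(f, p, device_of_expert, alpha=0.05):
    """f[i], p[i]: per-expert load fraction and router mass; device_of_expert[i]: device id."""
    groups = {}
    for i, dev in enumerate(device_of_expert):
        groups.setdefault(dev, []).append(i)
    loss = 0.0
    for experts in groups.values():
        f_dev = sum(f[i] for i in experts) / len(experts)  # average load on this device
        p_dev = sum(p[i] for i in experts)                 # total router mass on this device
        loss += f_dev * p_dev
    return alpha * loss


# Two machines with four experts each, as in the example above.
placement = [0, 0, 0, 0, 1, 1, 1, 1]
even = device_balance_loss([1.0] * 8, [0.125] * 8, placement)
skewed = device_balance_loss([2.0] * 4 + [0.0] * 4, [0.2] * 4 + [0.05] * 4, placement)
print(even, skewed)
```

The skewed case, where all traffic lands on the first machine, produces the larger loss, so minimizing it pushes router mass toward the idle machine even if the experts within a machine are not individually balanced.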
4030.4s
In DeepSeek V3, they have a per-expert bias term that they use, and they use a kind of online learning trick to balance out experts, getting rid of some of these auxiliary losses,
4047.7s
but in the end they do have to add some of these auxiliary losses in order to ensure extreme imbalances don't happen. So I add this as kind of a side note: DeepSeek V3 and others have started to get rid of
4061.7s
some of these, I would say, uglier auxiliary losses, but there's no solution that fully gets rid of them so far.
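The per-expert bias trick can be sketched as two small functions. This is a hedged illustration of the idea as I understand it from the DeepSeek V3 report: the bias is added to the router scores only when choosing the top K, the gate weights still come from the unbiased scores, and after each batch the bias of an overloaded expert is nudged down and that of an underloaded expert is nudged up by a fixed step. The function names, the step size `gamma`, and the numbers are all made up for the example.

```python
def select_with_bias(scores, bias, k):
    """scores: positive router affinities for one token; bias: one value per expert."""
    # The bias only influences WHICH experts are picked...
    top = sorted(range(len(scores)), key=lambda i: scores[i] + bias[i], reverse=True)[:k]
    # ...while the gate weights are normalized from the original, unbiased scores.
    total = sum(scores[i] for i in top)
    return {i: scores[i] / total for i in top}


def update_bias(bias, tokens_per_expert, gamma=0.001):
    """Nudge overloaded experts down and underloaded experts up by a fixed step."""
    mean_load = sum(tokens_per_expert) / len(tokens_per_expert)
    new_bias = []
    for b, load in zip(bias, tokens_per_expert):
        if load > mean_load:
            new_bias.append(b - gamma)
        elif load < mean_load:
            new_bias.append(b + gamma)
        else:
            new_bias.append(b)
    return new_bias


scores = [0.6, 0.5, 0.45, 0.1]
bias = update_bias([0.0] * 4, tokens_per_expert=[10, 10, 0, 0], gamma=0.1)
gates = select_with_bias(scores, bias, k=2)
print(bias, gates)
```

In this toy run, experts 0 and 1 were overloaded, so after one update the token's second slot moves from expert 1 to expert 2. No gradient is involved in the bias update, which is why it reads as an online learning trick rather than a loss.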
4072.6s
The last thing I'll talk about with load balancing losses is, you might ask, do I need to deal with this? It's all a bunch of
4078.9s
complicated heuristics. I wish I could just get rid of the load balancing loss. But if you remove the load balancing loss, the effects are pretty catastrophic. These are both from the OLMoE paper, where they do this nice ablation of removing the load balancing loss entirely. If
4095.8s
you have the load balancing loss, you get this pink line at the top, very nice, normal curves. Remove the load balancing loss, and your losses significantly increase, and your training loss is also not doing so well. Maybe more tellingly, if you look at which experts are doing what,
4111.8s
with load balancing and without load balancing, that's these bottom panels over here, you see a really stark difference. Without load balancing, almost all the tokens are going to two experts, the yellow expert and the pinkish expert that you see here. Whereas if you have the load balancing
4128.4s
loss, even though we had this heuristic thing going on, all of these experts are being utilized across the tokens. So we have the nice even utilization. So we do see that the load balancing loss does what we'd expect. And without it, we've thrown away a ton of parameters. Those experts
4144.0s
are doing absolutely nothing for most of training.
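A simple way to spot this kind of collapse in your own runs is to log how routed tokens are spread across experts. The sketch below is one possible diagnostic, not something taken from the OLMoE paper: it reports per-expert fractions and a normalized entropy score. The function name and the two synthetic assignment lists are made up for the example.

```python
import math
from collections import Counter


def expert_utilization(assignments, n_experts):
    """assignments: one expert id per routed (token, slot) pair."""
    counts = Counter(assignments)
    total = len(assignments)
    fractions = [counts.get(i, 0) / total for i in range(n_experts)]
    entropy = -sum(q * math.log(q) for q in fractions if q > 0)
    # 1.0 means perfectly even use; values near 0 mean a few experts take everything.
    return fractions, entropy / math.log(n_experts)


even_fracs, even_score = expert_utilization([0, 1, 2, 3] * 25, n_experts=4)
collapsed_fracs, collapsed_score = expert_utilization([0] * 50 + [1] * 50, n_experts=4)
print(even_score, collapsed_score)
```

The second case mimics the picture described above, where two experts absorb all the tokens: two of the four fractions are zero and the score drops to about one half.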
4147.6s
Right. So this is how MoEs are trained. And just to go back, this is pretty surprising to me: you've got this thing that's kind of non-differentiable, you're doing this top-K selection, and it's
4160.0s
honestly a pretty complicated object. And all you really need to do is just add this balancing loss to even out the experts, and for the rest of it you act as if you can just pump gradients through the system, and the model trains very nicely.
4174.5s
And I think a big part of this is this dynamic that if an expert is useful, you're going to reinforce it, and that positive reinforcement cycle is nicely balanced out by evening things out. Those two dynamics cancel each other out.
4190.8s
Cool. Oh, and the final thing I'll say is that same trick is used in DSA. We already saw the top-K thing. It's used in things like H-Net, if you're familiar with that. That was an attempt to remove tokenizers. And so I think you'll see this trick
4206.4s
more often in the future, where this idea of top-K selection, and using load balancing or other kinds of auxiliary losses to make that non-differentiable selection work, will be an ingredient of future architecture design. Yes? Is there any way of relating mixture of experts to, like, systems optimization, but, like,
4226.2s
are they actually experts when you look at them? Are they actually doing things in different domains? Yeah, so some papers have a visualization of which tokens activate which experts. And because the routers are so simple, it's not like the experts are smart experts. They're not medical experts, legal experts, or
4243.8s
whatever. You do see that certain tokens are routed to different places. Punctuation might be routed to one expert, or other symbols might be routed to one expert. Non-English character sets might be routed to another expert. But it's not really something where you
4258.6s
can look at it and be like, "Oh, this is the Wall Street Journal expert," or something, right? Nothing like that. There are no semantics, is maybe one way of putting it. Yeah. Why is there a per-expert loss versus a per-device loss? That's what's interesting here. Yeah,
4272.3s
the per-expert loss, in principle, if it's perfectly enforced, will also enforce per-device balancing if your experts are evenly split, right? But at least my interpretation is that you don't want to crank up the per-expert loss so high that you get true full uniformity, because
4286.6s
that has deleterious effects on training dynamics. But per-device balance is important enough that you're willing to add a little bit of extra loss to encourage per-device balancing over the others. Let me talk a little bit about the systems side of things before we wrap up. So training MoEs has additional fun
4307.5s
systems dynamics that are introduced. When I talk about parallelism, I'm mostly going to talk about the standard, let's call them, forms of parallelism. You can split up your data, that is, you can take
4323.0s
your one dataset and split it up over many machines. You can do model parallelism, where you take your model and split it up over different machines. And you can combine those two in various ways. You
4335.6s
can of course do that. But I think the thing that we realized is that each of those parallelism techniques has some sort of limit, right? For data parallelism, you max out at your batch size, right? Once you hit your batch size, you can't parallelize by data anymore. For model parallelism, there are
4351.0s
natural cut points where you can cut your model. Once you exhaust those, you can't parallelize anymore. If you combine this with expert parallelism, you get another nice axis to parallelize on. And when you do this, a lot of the underlying implementations allow you to
4368.8s
take advantage of sparse matrix multiplies that have been built into GPUs. Another thing I'll mention is that if you have multiple experts, let's say on one GPU, one way of thinking about it is to say, "I'm going to have multiple small matrix multiplies." But this is
4385.3s
not nice, because you ideally want bigger matrix multiplies where you can reuse caches and do all sorts of things. So instead, what you can do is nicely leverage sparsity. Block diagonal, of course, is the most basic form, but you can use much more complicated forms of
4402.8s
structured sparsity that are natively supported in hardware, which allow you to multiply experts and inputs in very clean ways, very fast, right? So this is another reason why MoEs are nice. If you think about the computation patterns of MoEs, they almost correspond to these kinds of
4422.9s
structured matrix multiplications that are very easy and efficient to support in hardware, right? So there's this hardware-architecture co-design that's kind of happening with MoEs. And in terms of parallelism, one thing I'll mention, which is a very recent development, is from Nemotron 3. I forgot
4443.9s
who asked this question, but when we do this expert parallelism, we have to ship activations from device to device in order to say, "Oh, you belong in that expert. I need to ship you over." That can result in significant communication overhead.
4459.2s
So one of the things people have started to think about is to say: my shared expert, which I'm not going to communicate, can have big dimensions, but for my routed experts I need to communicate those activations, so I want those to be smaller, lower-dimensional vectors,
4474.8s
right? So you might actually take your residual stream and down-project it first, and then do the collective communication call that sends this activation out. That will significantly save on communication without fully incurring the drawbacks of a smaller hidden dimension, right?
4495.2s
So you can do these kinds of projection tricks to nicely control the trade-off between communication and having your MoEs parallelized.
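To make the projection trick concrete, here is a minimal NumPy sketch. It is only an illustration under stated assumptions: the sizes, the names `w_down`, `w_up`, and `dispatch_bytes`, and the 2-bytes-per-element figure are all hypothetical, and the actual all-to-all call is left as a comment rather than implemented.

```python
import numpy as np

def dispatch_bytes(num_tokens, top_k, dim, bytes_per_elem=2):
    """Bytes sent one way when each token is shipped to top_k experts."""
    return num_tokens * top_k * dim * bytes_per_elem

rng = np.random.default_rng(0)
# Hypothetical sizes, chosen only for illustration.
d_model, d_latent, num_tokens, top_k = 1024, 256, 8, 2

x = rng.standard_normal((num_tokens, d_model))
w_down = rng.standard_normal((d_model, d_latent)) / np.sqrt(d_model)
w_up = rng.standard_normal((d_latent, d_model)) / np.sqrt(d_latent)

# Down-project the residual stream BEFORE communicating.
z = x @ w_down
# (The all-to-all dispatch of z and the routed experts would run here,
#  operating in the smaller d_latent space.)
# Up-project back to the residual stream width after the results return.
y = z @ w_up

full_bytes = dispatch_bytes(num_tokens, top_k, d_model)
latent_bytes = dispatch_bytes(num_tokens, top_k, d_latent)
print(f"dispatch volume: {full_bytes} bytes at d_model, {latent_bytes} bytes at d_latent")
```

The communication volume scales linearly with the width of what you ship, so shrinking the dispatched vector by a factor of four cuts the dispatch traffic by the same factor, while the shared expert and the residual stream keep the full width.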
4510.7s
Another, let's call it, detail of MoEs that I think is fun, though it has been solved in recent years, is stochasticity. If you build MoE infrastructure naively, what you'll find is that unless your experts are perfectly balanced on your input distribution, certain experts might
4528.3s
be much more popular than others. And if some expert, let's say in this case expert zero, is just a really popular expert, you're going to start running into situations where this expert's queue of tokens just builds up and builds up, and eventually you'll get to a point where
4543.6s
you say, "My queue is so long, I have to start dropping tokens," right? In the earlier generation of MoE inference infrastructure and code, a lot of what happened was that you would just silently drop the token for this expert and proceed with your computation as if nothing had
4559.2s
happened. You just send a zero back and pretend that was fine. And this means you can have these weird stochasticities where, if other users are sending queries that hit the experts you're using, you could actually get a worse result
4576.1s
because they would kind of bump you out of the expert queue. These kinds of very strange things can happen. But these days there are dropless architectures that have gotten rid of a lot of these issues, so this is no longer really an issue. MegaBlocks and other very common open-source MoE
4592.2s
frameworks don't have this issue anymore. Okay. So the last thing that I'll talk about in terms of the design of MoEs is this issue of stability. We've already talked about stability issues in the context of architecture design last lecture. But if you remember what I
4609.9s
said then, I said exponentials are bad and divisions are bad, which means softmaxes are a danger zone for stability issues. Now, what have we done with mixture of experts? Well, we have introduced yet another softmax in the routing, right? Because we're going to do top-k, we're
4625.3s
going to do softmax, and that's going to be how we route to experts. Now, people have long noticed that the softmax operation in MoEs is potentially very, very dangerous. Barret Zoph and others in the
4640.4s
early days of Google's MoE design had an entire paper on MoE stability, and this was one of the things that they noticed. I'm going to go into some of the details. You'll see the solution fairly often for some of the more sensitive and tricky parts. You
4655.3s
know, you might end up using something like float32 for just the expert router. You might have a z-loss as well here. We've talked about z-loss before as a way to control the softmax stability issues that we saw. Z-loss, I think, was quite popular, actually,
4672.8s
for MoE router stability even in the early days. OLMo, which I've been talking about extensively, has done ablations on adding or removing z-loss. It's quite clear from these very spiky training loss curves that z-loss on the router can be quite
4688.9s
helpful. ... you can end up with. And I think you've probably all experienced this recently if you work with language models, which is that MoEs can be pretty annoying to fine-tune. They have so many parameters that if you're trying to fine-tune the
4707.9s
experts, you actually end up with very serious overfitting issues. So you can see that for your dense models, train and val are fairly close. They don't overfit crazily. But if you look at the sparse models, the train-val gap is extremely large as you
4724.0s
fine-tune these models for, in this case, a downstream task. I think it was one of the GLUE benchmark tasks. Now, you could have a case where you fine-tune the non-MoE feedforwards. You could also fine-tune just attention. These are pretty common
4743.1s
interventions that people do. Attention fine-tuning especially is something that I see quite often in recent MoE work. But if your model has non-MoE layers, then of course you can also just fine-tune those, and Barret Zoph and others were kind of arguing for that. And
4761.1s
of course there's also the kind of bitter-lesson version of this, which is that maybe you should just use 1.4 million examples instead of whatever number you have. If you have a lot of data, you can basically retrain the MoE entirely during your fine-tuning stage, and you
4774.5s
won't have quite as much of a generalization gap.
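The interventions above (fine-tune only attention, or only the non-MoE parts) amount to choosing which parameters receive updates. The following is a small sketch of that selection in plain Python. The parameter naming scheme (`.attn.`, `.experts.`, `.router.`) and the function `select_trainable` are hypothetical, not taken from any particular framework.

```python
def select_trainable(param_names, mode="attention"):
    """Return the parameter names to update during fine-tuning.

    The substring conventions used here are an assumed naming scheme
    for illustration only.
    """
    if mode == "attention":
        # Update attention weights only; experts and routers stay frozen.
        return [n for n in param_names if ".attn." in n]
    if mode == "non_moe":
        # Update everything except routed experts and their routers.
        return [n for n in param_names
                if ".experts." not in n and ".router." not in n]
    if mode == "all":
        # The "just use more data" option: retrain every parameter.
        return list(param_names)
    raise ValueError(f"unknown mode: {mode}")

names = [
    "layers.0.attn.wq", "layers.0.attn.wo", "layers.0.mlp.w1",
    "layers.1.attn.wq", "layers.1.router.w",
    "layers.1.experts.0.w1", "layers.1.experts.1.w1",
]
attention_only = select_trainable(names, "attention")
non_moe_only = select_trainable(names, "non_moe")
print(attention_only)
print(non_moe_only)
```

Freezing the experts removes most of the parameter count from the fine-tuning problem, which is the intuition for why these restricted variants overfit less on small datasets.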
4783.0s
Another thing that I'll mention is this idea of upcycling. This has actually become a lot less popular in the last year. I don't think I've seen a single upcycled model this year. But it's a cool idea, and there were some very nice models earlier on using this trick.
4798.4s
So I kind of want to mention this. Upcycling is this idea that if you want an MoE, maybe you can just take a dense model that you've trained and instantiate an MoE based on that dense model. And the idea is you just copy everything,
4813.9s
including the MLPs. You just make a whole bunch of copies of these MLPs. You have a router, which has a random initialization. And then you just train this model. And if you do this, then because of the stochasticity of which inputs go where among the MLPs, you're going to
4829.6s
get these experts that start to specialize. And in the end, you'll actually end up with an MoE. In some of the earliest papers that proposed upcycling, they showed that an upcycled model could get much better language modeling validation (in this case, accuracy)
4848.5s
performance by upcycling a dense model rather than continuing to train the same model, which is kind of a cool thing that you can show. People have validated this at scale in many ways. MiniCPM is another Chinese model that I quite like
4868.7s
because they do a lot of carefully controlled ablations. They had an MoE where they took their MiniCPM 2.4B model and upcycled it to a 13.4 billion parameter model, and they got a bunch of almost free wins from doing that. Qwen also, when they did their Qwen MoE, initialized from their 1.8B
4889.9s
Qwen model. They upcycled that to the Qwen1.5-MoE-A2.7B model, and they ended up with one of the first very high-performance models. I think it's one of the first larger-scale upcycling successes. Of course, these days, I don't think anyone's really
4909.8s
upcycling anymore, because you don't really train dense models and then convert them. You might as well just train your big hero run on an MoE to start with, right? So I don't think we see this quite as much, but I wanted to mention it because I think it's an important thing to know in the
4925.6s
action space of mixture-of-experts models. What I want to end with is to walk us through the DeepSeek models V1, V2, and V3. I think there's a lot to learn from the evolution of the DeepSeek models. And also, if you're interested just in general
4944.3s
about MoEs or architecture design, I would encourage you to read the DeepSeek papers. I think they're very well written, and they have a lot of details that are worth digging into. So the DeepSeek MoE architecture at V1 is already actually the prototype of a lot of different
4960.6s
modern MoEs. It's got the shared and fine-grained expert structure. It uses standard top-k routing with auxiliary-loss balancing. In some ways, I would say this is the prototypical, Platonic ideal of an MoE model. DeepSeek MoE V2 kind of just scales this up, right?
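The V1 recipe described above (shared experts, fine-grained routed experts, top-k softmax routing, an auxiliary balancing loss) can be sketched in a few lines of NumPy. This is a toy under loud assumptions, not DeepSeek's implementation: each expert is a single `tanh` layer standing in for a full feedforward block, the balancing loss is the simple Switch-Transformer-style form `num_experts * sum(f_i * P_i)`, and all names and sizes are hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def moe_forward(x, w_router, routed, shared, top_k):
    """x: (tokens, d). routed/shared: lists of (d, d) matrices (toy experts)."""
    num_experts = len(routed)
    probs = softmax(x @ w_router)                       # (tokens, num_experts)
    top_idx = np.argsort(-probs, axis=-1)[:, :top_k]    # chosen experts per token
    gates = np.take_along_axis(probs, top_idx, axis=-1)  # their router weights

    out = x.copy()                                      # residual connection
    for w in shared:                                    # shared experts see every token
        out += np.tanh(x @ w)
    for t in range(x.shape[0]):                         # routed experts: top_k per token
        for slot in range(top_k):
            e = top_idx[t, slot]
            out[t] += gates[t, slot] * np.tanh(x[t] @ routed[e])

    # Balancing loss: f_i is the fraction of routed slots that went to expert i,
    # P_i is the mean router probability for expert i.
    assign = np.zeros((x.shape[0], num_experts))
    np.put_along_axis(assign, top_idx, 1.0, axis=-1)
    f = assign.mean(axis=0) / top_k
    p = probs.mean(axis=0)
    aux_loss = num_experts * np.sum(f * p)
    return out, aux_loss, top_idx

rng = np.random.default_rng(0)
tokens, d, num_experts, top_k = 16, 8, 4, 2   # hypothetical sizes
x = rng.standard_normal((tokens, d))
w_router = rng.standard_normal((d, num_experts))
routed = [rng.standard_normal((d, d)) * 0.1 for _ in range(num_experts)]
shared = [rng.standard_normal((d, d)) * 0.1]
out, aux_loss, top_idx = moe_forward(x, w_router, routed, shared, top_k)
print(out.shape, float(aux_loss))
```

The per-token Python loop is only for readability; real implementations group tokens by expert and run batched or block-sparse matrix multiplies, as discussed in the systems part of the lecture.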
4978.4s
So you've got two shared experts, many more fine-grained experts, and you've got this device routing component and a communication balancing component. Both of these essentially add auxiliary losses to optimize for systems. I think it's important to realize that successful language model training is not just about deep learning, it's
4997.7s
also about really respecting your system. So DeepSeek V2 is really a reflection, I think, of this philosophy. And then in DeepSeek MoE V3, they've still got the shared and fine-grained experts design, but they've got different ways of doing balancing. They've got this
5014.6s
aux-loss-free way of balancing the different experts. And they've switched to using sigmoid plus softmax, which is a different way of weighting their experts. But it's really mostly similar. Finally, since we're already here, we might as well just walk through the rest
5033.4s
of DeepSeek MoE V3. They've got lots of other cool things. I'm a big fan of how they think about systems. They've got multi-head latent attention, where instead of producing Q, K, and V directly, you're going to represent them through this lower-dimensional latent.
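As a rough illustration of the latent idea, here is a single-head NumPy sketch in which only a low-dimensional latent `c` is cached and the keys and values are rebuilt from it. This is a simplification under stated assumptions: the sizes and the weight names (`w_dkv`, `w_uk`, `w_uv`, `w_q`) are hypothetical, the query is computed directly from the hidden state for brevity, and RoPE handling is omitted entirely.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical sizes for a single attention head.
seq_len, d_model, d_latent, d_head = 6, 64, 16, 64

h = rng.standard_normal((seq_len, d_model))                          # hidden inputs
w_dkv = rng.standard_normal((d_model, d_latent)) / np.sqrt(d_model)  # down-projection
w_uk = rng.standard_normal((d_latent, d_head)) / np.sqrt(d_latent)   # latent -> keys
w_uv = rng.standard_normal((d_latent, d_head)) / np.sqrt(d_latent)   # latent -> values
w_q = rng.standard_normal((d_model, d_head)) / np.sqrt(d_model)

c = h @ w_dkv          # this latent is the only thing stored in the cache
k = c @ w_uk           # keys rebuilt from the cached latent
v = c @ w_uv           # values rebuilt from the cached latent

q = h[-1:] @ w_q       # query for the most recent position
scores = q @ k.T / np.sqrt(d_head)
weights = np.exp(scores - scores.max())
weights /= weights.sum()
out = weights @ v

standard_cache = 2 * seq_len * d_head   # floats cached when storing K and V
latent_cache = seq_len * d_latent       # floats cached when storing only c
print(f"cache entries: {standard_cache} for K and V, {latent_cache} for the latent")
```

With these toy sizes the latent cache is one eighth the size of a cache that stores keys and values directly, which is the saving the latent representation is meant to buy.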
5050.9s
So you have your hidden input, and instead of directly producing your Qs, Ks, and Vs, you're going to first produce the C, and then you're going to produce your Qs, Ks, and Vs as a function of that. And that allows you to get significant savings. Instead of KV caching
5066.8s
all your Ks and Vs, I only need to store these Cs, right? And the Cs are hopefully lower dimensional, right? That's why it's called MLA, as in latent attention: these Cs, which are latent, are the KV cache quantities that you need to save. The only complexity that you might
5084.0s
need to worry about is that this is going to conflict with RoPE when you do KV caching. So you have to be a little bit careful about how you rotate different dimensions and so on. The trick is that you have non-latent dimensions that encode positions, but
5100.6s
I'm sort of not going to go into too
5102.2s
much more detail about that. Finally,
5105.5s
something that I thought was a very cool
5107.0s
idea in, as I said, DeepSeek V3,
5109.8s
but hasn't caught on very much, is the
5112.2s
idea of MTP or multi-token prediction.
5114.8s
Instead of predicting one future token,
5116.8s
you predict multiple tokens um all at
5119.2s
once. That has sort of uh statistical
5121.5s
arguments for why that's a good idea.
5122.9s
Maybe it lets you predict the future a
5124.3s
little bit better. But also, there's
5125.9s
this nice systems argument that you kind
5127.8s
of now have a speculative decoder
5129.8s
built in. Um, Percy will talk about that
5132.5s
when he talks about inference later. Um,
5134.7s
but it's a sort of a trick that you can
5136.8s
use to speed up your decoding from your
5139.1s
model.
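A small sketch of the two ideas just mentioned, under simplifying assumptions. `multi_token_targets` shows what the training targets look like when each position predicts several future tokens instead of one. `accepted_prefix` shows the greedy form of the speculative-decoding check: drafted tokens are kept up to the first one the main model disagrees with. Both function names are hypothetical, and real speculative decoding typically uses a probabilistic acceptance rule rather than exact matching.

```python
def multi_token_targets(tokens, n_future):
    """Pair each prefix with the next n_future tokens that follow it.

    Prefixes that do not have a full set of n_future tokens after them
    are dropped.
    """
    examples = []
    for i in range(len(tokens) - n_future):
        prefix = tokens[: i + 1]
        targets = tokens[i + 1 : i + 1 + n_future]
        examples.append((prefix, targets))
    return examples


def accepted_prefix(draft, verified):
    """Keep drafted tokens until the first disagreement with the verifier."""
    accepted = []
    for d, v in zip(draft, verified):
        if d != v:
            break
        accepted.append(d)
    return accepted


examples = multi_token_targets([5, 6, 7, 8, 9], n_future=2)
kept = accepted_prefix(draft=[8, 9, 4], verified=[8, 9, 7])
```

Here the extra prediction heads play the role of the draft model: whenever the drafted tokens match what the main model would have produced, several tokens are emitted for one verification pass.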
5141.0s
Okay, I think we're roughly at time. So,
5143.0s
to put everything together, right? MoEs
5145.3s
are this kind of very clever idea to
5147.7s
take advantage of sparsity so that you
5149.7s
can have more parameters than you're
5151.6s
sort of paying for, right? So, you don't
5153.4s
pay for the compute cost of all of
5155.5s
your parameters, but you are getting the
5157.4s
parameter benefits. Um,
5159.6s
you might initially think that this kind
5161.6s
of routing problem is very hard, but it
5163.6s
turns out that very simple things work
5165.7s
well even at scale. Um, and at this
5168.1s
point, it's kind of clear that MoEs are
5169.7s
here to stay. So, you should kind of
5171.0s
understand how they work and what they
5172.5s
are. Um,
5174.0s
thanks.
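To make the summary concrete, here is a minimal sketch of top-k routing and of the "more parameters than you pay for" arithmetic. It assumes one common variant, where the softmax is taken over only the selected experts' logits; other variants normalize before selecting. The function name `top_k_route` and all sizes are hypothetical.

```python
import numpy as np


def top_k_route(router_logits, k):
    """Pick the k highest-scoring experts per token and weight them.

    router_logits has shape (n_tokens, n_experts). Returns the chosen
    expert indices and softmax weights over just those k logits.
    """
    top_idx = np.argsort(router_logits, axis=-1)[:, -k:]
    top_logits = np.take_along_axis(router_logits, top_idx, axis=-1)
    shifted = top_logits - top_logits.max(axis=-1, keepdims=True)
    weights = np.exp(shifted)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return top_idx, weights


logits = np.array([[0.1, 2.0, -1.0, 3.0]])  # one token, four experts
idx, w = top_k_route(logits, k=2)

# Parameter accounting with illustrative sizes (assumptions).
n_experts, k = 8, 2
params_per_expert = 2 * 64 * 256  # up- and down-projection of a small MLP
total_expert_params = n_experts * params_per_expert
active_expert_params = k * params_per_expert  # what each token computes with
```

In this toy setup each token touches only a quarter of the expert parameters, which is the sense in which the compute cost and the parameter count are decoupled.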