Stanford CS25: Transformers United V6 I From Representation Learning to World Modeling

Channel: Stanford Online Duration: 71 min Sentences: 207
For more information about Stanford’s graduate programs, visit: https://online.stanford.edu/graduate-education
April 9, 2026
This seminar covers:
• How world models are increasingly moving away from reconstruction and toward prediction in latent space
• Two recent JEPA-based approaches that illustrate this shift from complementary angles
Follow along with the seminar schedule. Visit: https://web.stanford.edu/class/cs25/
Guest Speakers: Hazel Nam & Lucas Maes (Brown University)
Instructors:
• Steven Feng, Stanford Computer Science PhD student and NSERC PGS-D scholar
• Karan P. Singh, Electrical Engineering PhD student and NSF Graduate Research Fellow in the Stanford Translational AI Lab
• Michael C. Frank, Benjamin Scott Crocker Professor of Human Biology, Director, Symbolic Systems Program
• Christopher Manning, Thomas M. Siebel Professor in Machine Learning, Professor of Linguistics and of Computer Science, Co-Founder and Senior Fellow of the Stanford Institute for Human-Centered Artificial Intelligence (HAI)

Full Transcript

6.8s Welcome to the second lecture of CS25 this quarter. Today we're very lucky to have two speakers with us. Hazel, or Heejeong Nam, is here in person today. She's a master's student at Brown University working on representation learning, causality, and self-supervised
27.8s learning. We also have Lucas Maes over Zoom, who will be speaking afterwards. He's a PhD student at Mila and the University of Montreal working on JEPA and planning. I'm sure they're going to give us a very insightful talk today. So without further ado, I'll hand it off to
45.9s Hazel. Yeah, thank you for the introduction. I'm Hazel, a first-year master's student at Brown University working with Professor Randall Balestriero. Our lab works on JEPA, self-supervised learning, and some theory as well. Today I'm a bit nervous. This is my first time giving a talk in English, but at the same
68.2s time I'm really, really excited to talk about JEPA world models and my recent work, the causal world model. I'm pretty sure Lucas is also very excited to share his very recent work, LeWorldModel. Okay, so the first part of this talk will be about the concepts of JEPA and world models, and after a brief introduction of
88.6s those two concepts, we will talk about the causal JEPA paper first, which asks how to make a model understand object interaction. Then I will hand over to Lucas, and he will talk about LeWorldModel, which is about end-to-end JEPA training without collapse.
112.0s You guys may have heard of world models, and some of you may be really familiar with what a world model is. To talk about the world model, we should start with this first. This is an autoregressive model, right? You get the previous state and you predict the next step. However, this is sometimes not
129.8s enough to describe the world. Why? Because our world has uncertainty. At the time that you're observing, someone might do something, right? Someone can throw something. There's inherent uncertainty. To handle this, you need an action, and this now is a world model. So the world model is basically a function that gets the previous state
153.5s and the action to predict the next state. In this sense, I perceive the term "world model" as a simulator. So as an AI researcher, what would you do to make this world model better? In my opinion, there are three design points.
172.4s The first one is having a good state representation. You are probably not going to give raw pixel images or raw pixel values to the predictor. You somehow have to have a good representation that reflects the world faithfully. The second component is a good transition model. For example,
195.8s there should be grounding rules underlying the environment. For example, gravity exists. We have to understand these rules to predict the world and to simulate the world. The last component is a good dynamics model. Now you're going to do some action, and
215.6s your model should react to your action in the appropriate way. Keep these three things in mind, because I'll revisit them at the end of the causal JEPA part to discuss how I address these three questions. Today we will not cover generative world models. There are so many companies in industry now doing world
238.0s modeling. For example, GAIA is an autonomous driving model that takes the action and the previous state to predict the next driving scenario. With Genie, they wanted to scale up the world model, so they learned something called a latent action, which is a proxy for the action, because online
257.4s data usually doesn't have per-frame action annotation. Sora is technically a video generation model, but the fidelity of its generated scenes is very good, so it got a lot of attention. And Marble, from Professor Fei-Fei Li at Stanford, builds 3D interactive environments
277.1s that we can interact with or explore. But today we will talk about the joint embedding predictive architecture, which is somewhat different from the generative world model. On the left side, what you see is a generative world model. As you can see, x is the current state, and let's say y is
299.4s the thing that we have to predict, which is a future step. You put the encoded current step into the predictor and directly compare the output to y, which means that your model should produce output in the same pixel space as your target. However, on the right side is the joint embedding predictive architecture.
322.2s You see that both x and y go to the encoder. So now we are comparing our prediction and the true target in the latent space. Why does it matter? We don't have a decoder in the JEPA architecture, but it is not merely about having no decoder. It is about uncertainty in the world.
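The contrast between comparing in pixel space and comparing in latent space can be sketched with a toy example. This is only an illustration under stated assumptions, not the architecture from the talk: `encode` and `predict` are hypothetical stand-ins, the observation is a 4-number vector whose first two coordinates are predictable, and the last two coordinates of the future are pure noise.

```python
import random

def encode(x):
    # Hypothetical toy encoder: keep the two predictable coordinates and
    # drop the rest (standing in for unpredictable pixel-level detail).
    return x[:2]

def predict(z):
    # Hypothetical toy predictor: the predictable part moves by +1 per step.
    return [v + 1.0 for v in z]

def mse(a, b):
    return sum((u - v) ** 2 for u, v in zip(a, b)) / len(a)

random.seed(0)
x = [0.0, 1.0, 0.3, -0.2]                           # current observation
noise = [random.gauss(0.0, 1.0) for _ in range(2)]  # unpredictable detail
y = [1.0, 2.0] + noise                              # future observation

# Generative-style objective: the prediction is compared with y in
# observation space, so the unpredictable coordinates are penalised too.
y_hat = predict(x[:2]) + [0.0, 0.0]   # best guess for the noise is its mean
pixel_loss = mse(y_hat, y)

# JEPA-style objective: both x and y go through the encoder, and the
# comparison happens in latent space.
latent_loss = mse(predict(encode(x)), encode(y))

print("pixel-space loss :", pixel_loss)
print("latent-space loss:", latent_loss)
```

In this toy setup the latent-space loss is zero because the encoder discards the noise, while the pixel-space loss stays positive however good the predictor is.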
347.7s So when a human is thinking about what is going to happen next, are you predicting something at the pixel level? You know, there's some kind of uncertainty that you cannot predict, and there's obviously your area of interest. So JEPA tries to
370.4s keep only predictive information in the latent space, so that the prediction becomes more meaningful and human-like. The framework can also be interpreted in a different way. First of all, generative world modeling is about likelihood: it learns a normalized likelihood over the future frames. But
390.7s JEPA can be understood as an energy-based model. In the very original paper in which Yann LeCun proposed it, he framed JEPA as an energy-based model. An energy-based model learns an energy score function that gives a high value if the two values x and y are
411.3s incompatible and a low score if they are compatible. Compatibility in the world model means that y is a plausible future of x. Unfortunately, the energy-based model can collapse, and there are two ways to prevent this collapse. The first way is contrastive learning and the second way is the regularization-based
436.4s method. I'm not going to talk about this too deeply, but regularization-based methods make a very rich and well-defined energy landscape. So that is what JEPA is doing.
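To make the energy view and the collapse problem concrete, here is a minimal sketch. All names are hypothetical and the encoders are hand-written rather than learned. The energy is assumed to be a squared distance between the predicted and the encoded target latent. The variance check at the end is one illustrative signal that a regularization-based method could use, not necessarily the regularizer used in any of the papers mentioned.

```python
def sq_dist(a, b):
    return sum((u - v) ** 2 for u, v in zip(a, b))

def energy(encoder, predictor, x, y):
    """Low when y is a plausible future of x, high otherwise."""
    return sq_dist(predictor(encoder(x)), encoder(y))

def healthy_encoder(x):      # hypothetical: keeps the information in x
    return list(x)

def collapsed_encoder(x):    # degenerate: maps every input to the same point
    return [0.0 for _ in x]

def step_predictor(z):       # hypothetical dynamics: each coordinate grows by 1
    return [v + 1.0 for v in z]

def identity_predictor(z):
    return list(z)

x = [0.0, 2.0]
y_plausible = [1.0, 3.0]     # consistent with the toy dynamics
y_implausible = [5.0, -4.0]  # not consistent with the toy dynamics

# A useful energy landscape separates the two futures.
good = (energy(healthy_encoder, step_predictor, x, y_plausible),
        energy(healthy_encoder, step_predictor, x, y_implausible))

# A collapsed model reaches zero energy everywhere: the landscape is flat.
flat = (energy(collapsed_encoder, identity_predictor, x, y_plausible),
        energy(collapsed_encoder, identity_predictor, x, y_implausible))

def embedding_variance(encoder, batch):
    """Mean per-dimension variance of the embeddings over a batch."""
    zs = [encoder(b) for b in batch]
    dims = len(zs[0])
    total = 0.0
    for d in range(dims):
        col = [z[d] for z in zs]
        mean = sum(col) / len(col)
        total += sum((c - mean) ** 2 for c in col) / len(col)
    return total / dims

batch = [x, y_plausible, y_implausible]
print("healthy energies  :", good)
print("collapsed energies:", flat)
print("variance healthy  :", embedding_variance(healthy_encoder, batch))
print("variance collapsed:", embedding_variance(collapsed_encoder, batch))
```

The collapsed encoder minimizes the prediction error perfectly yet carries no information, which is why some extra mechanism (contrastive pairs or a regularizer) is needed.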
455.3s [...] video, they take as input consecutive frames with spatiotemporal masking, put them through the encoder, get the representation, and try to predict what is happening in the unseen area. The target has all the information, and here there are some regularization tools to
479.4s prevent collapse. For example, they use an EMA encoder. EMA here means exponential moving average. This prevents really trivial collapses. You also do stop-gradient for the target encoder, and you apply the mask to give the model a more challenging task. V-JEPA 2 is basically the same architecture as V-JEPA 1, but they
501.4s wanted to scale up, and they also did some interesting post-training. One of them is action-conditioned control. So now this became more like a world model because, as you can see, the predictor gets the robot actions and poses. So this is one of the post-training stages in
523.1s V-JEPA 2. And this one is called the DINO world model, and it has the very same architecture as the action-conditioned post-training that you just saw. What the DINO world model does is use a frozen DINOv2 encoder. What they ask is: do we actually have to train the
546.7s JEPA encoder to get a meaningful abstraction for planning? They said no, a pre-trained DINO encoder can play that role as well. So they generate patch representations, and with auxiliary variables, for example action and proprioceptive signals, they predict the future state representation and compare. This predictor is just a simple
574.9s causal transformer that predicts the future autoregressively. But here I want you to think: is this patch representation really what humans do? Are we patchifying the image to predict the next step? Probably not. So today we're going to talk about causal JEPA: learning world models through
601.2s object-centric latent intervention. Before starting this talk, I would like to thank my collaborators Quentin, Lucas, Yann LeCun, and Randall Balestriero. [...] understanding object interaction and object dynamics. These are three datasets that I picked to show today. The first one is Push-T, a pretty famous
630.7s example in control experiments. The goal is that you act on the blue ball and you want to move the gray T block to perfectly overlap the green T. The second dataset is CLEVRER. This is actually a video paired with a question-answering dataset. The
652.0s question can be predictive, for example, what would happen in the next frame? It can be counterfactual, for example, what would happen if the blue cylinder didn't exist? And it can be explanatory. So it's basically a VQA benchmark. The third one is sphere. As you can see,
673.4s many objects are interacting with each other under grounding rules. For example, there's gravity, the mass matters here, etc.
690.2s What the current models are doing is like this: they patchify the image and try to predict what will happen in each patch. But to understand this mechanism, you want to understand things like this. You have each object, and you want to know how the objects influence each other.
722.6s If we just change the representation to an object-centric representation, then it's closer to humanlike thinking, right? But then we have to learn the object-centric representation.
739.3s [...] about object-centric learning, but let me go through it briefly. If you're doing computer vision, you might have heard about object-centric learning. There's a very foundational method called slot attention, from Francesco Locatello. This is basically an encoder-decoder framework, and between this encoder and
762.4s decoder you have the feature space, right? You bring buckets, like baskets, to put the features in, and each basket is for one object. So there's a mechanism called slot attention that puts each feature into a slot. So there's a binding problem of feature to slot, and they decode with
784.6s these baskets so that the model has well-aligned information in the slots. I brought this up because this is the Transformers United class and slot attention is about self-attention. So what I want to show is this: the way to bind the features to the slots is that
808.8s you do self-attention where the keys are the input features. For example, in DINOv2 there are patch embeddings, right? You allocate each patch embedding to slots; you're basically assigning them. The values are also the input features, and you update the slots using a GRU. So you iteratively
831.3s update the slots, and now the slots will have object-aligned representations. This is how models usually get object-centric representations. This is a very naive approach. There are many more advanced ways. For example, this is for images, and there are video slot attention models as well. So yeah, you
852.8s can give it a try. And how would a model behave if it perfectly understood the dynamics? Let's say the model learned a monkey eating a banana. If the model truly understands this eating mechanism,
877.8s [...] happening to the banana when we put an invisible cloth over the banana. Even though you're not seeing the banana, if you see the monkey moving its mouth constantly, you can imagine that the banana might be getting shorter, you know. And vice versa: if you make the monkey invisible and the banana is disappearing,
901.4s you can infer that a monkey is eating something, right? This is the core motivation and core explanation of causal JEPA. This is really relevant to what I'm doing right now. And if you think of the previous example, the predictor
919.8s doesn't have to be a causal transformer. As long as we are only seeing the history, we're okay to use a multiple-time-step history. So we're using a bidirectional transformer here. I would like to mention that. And let's see how the model actually implements this monkey-and-banana mechanism in the transformer.
945.1s Let's say we have a nicely aligned representation for each object, and we encode these object states into the representation. For example, let's say our look-back history is four frames. So you see four frames, and the next-frame prediction will happen. But because this is not a causal autoregressive transformer, this is a
970.1s bidirectional ViT-style transformer, we need placeholders for the future tokens as well. So we use a mask token for this. Now our goal is to predict this mask token correctly. And each row means the evolution of one object.
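The token layout described here (a grid of slots over time, with the future step filled by mask-token placeholders) can be sketched as follows. This is a minimal illustration, not the speaker's implementation: `NUM_SLOTS`, `DIM`, the zero-valued `MASK_TOKEN`, and the helper names are all assumptions. Only the four-frame look-back comes from the talk.

```python
HISTORY = 4                 # look-back frames, as in the talk's example
NUM_SLOTS = 3               # hypothetical number of object slots
DIM = 2                     # hypothetical slot dimension
MASK_TOKEN = [0.0] * DIM    # stand-in for a learnable mask-token vector

def build_tokens(history):
    """history[t][k] is the latent of slot k at time t.

    Returns (tokens, index, is_mask): the flattened (time, slot) grid with
    one extra future step whose entries are mask-token placeholders."""
    tokens, index, is_mask = [], [], []
    for t in range(HISTORY + 1):
        for k in range(NUM_SLOTS):
            future = (t == HISTORY)
            tokens.append(list(MASK_TOKEN) if future else list(history[t][k]))
            index.append((t, k))
            is_mask.append(future)
    return tokens, index, is_mask

def can_attend(t_query, t_key, causal):
    """Causal: a token only sees the same or earlier time steps.
    Bidirectional: every token sees every token, placeholders included."""
    return (t_key <= t_query) if causal else True

# Toy history: slot k at time t holds the vector [t, k].
history = [[[float(t), float(k)] for k in range(NUM_SLOTS)]
           for t in range(HISTORY)]
tokens, index, is_mask = build_tokens(history)

print(len(tokens), "tokens,", sum(is_mask), "of them are future placeholders")
```

Because attention is bidirectional, the future positions must exist as input tokens so that the transformer has somewhere to write its prediction.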
996.2s Now, as I told you before, I want to mask something here. Blue dots are observable slots and the yellow ones are masked slots. Let's imagine that we have this masked slot here. What should the model do? What is the easiest way for the model to reach a reasonably low loss and
1018.7s just predict the mask token? Maybe just average the previous and the next token, like doing interpolation. But that is not what we want. What we want is to learn the object interaction. So, like the previous monkey example, I just mask everything. This is a bit of an aggressive way of masking. But now the
1039.9s model doesn't have a shortcut. It needs to look at the other slots to correctly infer the current state, the masked state.
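The interpolation shortcut, and why masking an object's whole row removes it, can be shown with a toy one-dimensional track. This is an illustrative sketch only: the linear track and the `interpolation_guess` helper are assumptions made for the example.

```python
def interpolation_guess(track, visible, t):
    """Average the two temporal neighbours of step t if both are visible.

    Returns None when the shortcut is unavailable."""
    if 0 < t < len(track) - 1 and visible[t - 1] and visible[t + 1]:
        return (track[t - 1] + track[t + 1]) / 2.0
    return None

# One object's latent over five time steps (1-D for simplicity).
track = [0.0, 1.0, 2.0, 3.0, 4.0]

# Case 1: mask a single time step. Its neighbours are still visible, so
# plain interpolation recovers it without looking at any other object.
single_step_mask = [True, True, False, True, True]
guess_single = interpolation_guess(track, single_step_mask, 2)

# Case 2: mask the object's whole row. No neighbour of the same object is
# visible, so the only remaining evidence is in the other objects' slots.
whole_row_mask = [False] * len(track)
guess_whole = interpolation_guess(track, whole_row_mask, 2)

print("single-step mask :", guess_single)
print("whole-row mask   :", guess_whole)
```

With the whole row hidden, a low loss can only come from modelling how the visible objects relate to the hidden one.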
1055.0s [...] objects at the same time? This can happen because in slot attention you have to fix the maximum number of slots. The number of slots doesn't vary during training. So you give plenty of slots to give the model freedom: okay, what information should be defined as a slot?
1075.8s And sometimes, for example, we define eight slots but only have three objects in the scene. Then masking only one object, only one slot, might not be enough. And now we want to put this into the transformer, so we have to flatten it. So we have to do positional encoding before
1095.4s then. There is no problem with temporal positional encoding. But when we try to do positional encoding along the slot axis, there is a problem, because object-centric models do not define an order of the objects; rather, the object-centric
1117.6s models are permutation-equivariant
1127.0s with respect to the object order. So this is basically not a list but a set of objects, and if you mask multiple objects, the model might not know what to predict, because these slots do not carry object identity.
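The ordering problem described above can be illustrated with a minimal sketch. The toy layer `mix_slots`, the helper `add_slot_position`, and the 2-D slot values are all hypothetical and not from the talk; the point is only that a permutation-equivariant layer commutes with reordering the slots, while an index-based code on the slot axis ties the result to an arbitrary order.

```python
def mix_slots(slots):
    """Toy permutation-equivariant layer: add the mean of all slots to each slot."""
    n, dim = len(slots), len(slots[0])
    mean = [sum(s[d] for s in slots) / n for d in range(dim)]
    return [[s[d] + mean[d] for d in range(dim)] for s in slots]

def add_slot_position(slots):
    """Add an index-dependent code to each slot (a stand-in for slot-axis positional encoding)."""
    return [[value + index for value in slot] for index, slot in enumerate(slots)]

slots = [[1.0, 0.0], [0.0, 2.0], [3.0, 3.0]]
perm = [2, 0, 1]                      # an arbitrary reordering of the same objects
permuted = [slots[i] for i in perm]

# Without positional codes: permuting the input just permutes the output.
plain = mix_slots(slots)
equivariant = mix_slots(permuted) == [plain[i] for i in perm]

# With slot-axis codes: the same set of objects now gives a different result.
coded = mix_slots(add_slot_position(slots))
still_equivariant = mix_slots(add_slot_position(permuted)) == [coded[i] for i in perm]

print(equivariant, still_equivariant)
```

Under these assumptions the first check holds and the second fails, which is one way to see why slot order cannot serve as object identity across videos.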
1145.6s So video slot attention can keep temporal consistency
1151.0s inside the video itself. For example, video A keeps a consistent object order, but we cannot guarantee that video A and video B have the same order of objects. So what we do now is leave the very first time step unmasked, in this case time step t minus 3, and
1173.1s we use this information as the slot identity. We define each mask token as the identity token plus a learnable mask token, together with the positional encoding. You have an initial condition for each
1194.1s object, so it makes more sense when you predict the masked tokens. The next point is action conditioning. The world model should condition on the action properly. What the DINO world model does is concatenate the action embedding behind the patch embedding. For example, DINOv2 small
1220.2s has 384-dimensional features. Let's say we have a 10-dimensional action embedding. What the DINO world model does is duplicate it to the number of patches and put the action embedding after the patch representation. So now it's 394 per patch, because we added this action
1247.4s embedding after the patch representation. But I think this is not really an optimal way, because what we want to learn is something like this. Why don't we consider the action as another node of the graph? Causal-JEPA does not recover any true causal graph, but its motivation is grounded in the causal graph. So, for
1273.0s example, because we defined each object representation, those play the role of the nodes, and we also try to use the action as one of the nodes. So in this case the action is added like this, not as part of the feature representation.
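The two conditioning styles can be contrasted with a small shape-level sketch. The 196 patches, 384 feature dimensions, 10-dimensional action, and eight slots are numbers quoted in the talk; the slot width of 128, the function names, and the assumption that the action is projected to the slot width are hypothetical choices for illustration only.

```python
PATCHES, PATCH_DIM, ACTION_DIM = 196, 384, 10   # figures quoted in the talk
SLOTS, SLOT_DIM = 8, 128                        # SLOT_DIM is a hypothetical value

def concat_action(patch_tokens, action):
    """Concatenation style: copy the action embedding onto the end of every patch token."""
    return [token + action for token in patch_tokens]

def action_as_node(slot_tokens, action_token):
    """Node style: append the action as one extra token next to the object tokens."""
    return slot_tokens + [action_token]

patches = [[0.0] * PATCH_DIM for _ in range(PATCHES)]
action = [0.1] * ACTION_DIM
concatenated = concat_action(patches, action)

slots = [[0.0] * SLOT_DIM for _ in range(SLOTS)]
# Assumption: some embedding layer has already mapped the action to SLOT_DIM.
action_token = [0.1] * SLOT_DIM
sequence = action_as_node(slots, action_token)

print(len(concatenated), len(concatenated[0]))  # tokens x width, concatenation style
print(len(sequence), len(sequence[0]))          # tokens x width, node style
```

In the concatenation style every token grows to 384 + 10 = 394 dimensions, while in the node style the token width is unchanged and the sequence gains a single token that the predictor can attend to like any object.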
1302.1s This is the architecture of Causal-JEPA. To sum up, you have the history frames, and you put them into the object-centric encoder to get the object representations. Then you select some number of objects to mask, mask them, and put this into the predictor,
1321.0s which is a bidirectional transformer, along with the action, and then you predict every masked token.
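The masking step in this summary can be sketched as follows. This is a minimal illustration under stated assumptions, not the authors' implementation: `build_predictor_tokens`, the toy 2-D slot values, and the choice to add only a temporal code to visible slots are hypothetical. It follows the description that the first time step is never masked and that each masked token is the identity token plus a learnable mask token plus positional encoding.

```python
def vec_add(*vectors):
    """Element-wise sum of equally sized vectors."""
    return [sum(values) for values in zip(*vectors)]

def build_predictor_tokens(history, masked_slots, mask_token, time_codes):
    """Flatten history[t][k] (slot k at step t) into one token list.

    Step 0 is never masked; it supplies the identity token of each slot.
    """
    tokens = []
    for t, frame in enumerate(history):
        for k, slot in enumerate(frame):
            if t > 0 and k in masked_slots:
                # Masked token = identity (slot at step 0) + mask token + temporal code.
                tokens.append(vec_add(history[0][k], mask_token, time_codes[t]))
            else:
                # Visible token = slot embedding + temporal code.
                tokens.append(vec_add(slot, time_codes[t]))
    return tokens

history = [
    [[1.0, 1.0], [2.0, 2.0]],   # step 0: always visible
    [[1.5, 1.0], [2.0, 2.5]],
    [[2.0, 1.0], [2.0, 3.0]],
]
mask_token = [0.5, -0.5]                        # would be a learned parameter
time_codes = [[0.0, 0.0], [0.25, 0.25], [0.5, 0.5]]
tokens = build_predictor_tokens(history, {1}, mask_token, time_codes)
print(tokens[3], tokens[5])   # masked tokens for slot 1 at steps 1 and 2
```

Note that the masked tokens never see the true slot values at steps 1 and 2, so a predictor would have to rely on the other slots and the slot's own initial condition. The action token and the transformer itself are omitted here.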
1336.4s We did three experiments based on our goal. Our goal is understanding object dynamics: the first is reasoning on counterfactual questions, the second is planning and control, and the third is physical impossibility. On CLEVRER we also ran experiments against other existing models, but what I want to highlight is the
1360.7s model without masking. What I want to say here is that the performance does not come from using the object-centric representation; the core is masking, like in the banana example. If you look at the C-JEPA result, not only is the average accuracy better, there is also a clear gain
1383.5s in the counterfactual questions. The counterfactual questions resonate well with our original motivation, because a counterfactual asks something like: what if this object doesn't exist? What if this object exists? So you have to understand how the objects interact with each other.
1412.2s the agent tries to control the object to reach the goal state. Here I want to emphasize the efficiency. If you look at the DINO world model baseline, you have 196 patches, if you imagine you have 224 by 224 images, and each patch should have 384
1434.2s feature dimensions. If you use an object-centric representation, the number of tokens is significantly lower than that, and because we now have clear semantics for each token, the features don't have to be very large. If you think about it, there is not much to include in an object: for example, texture, color, shape,
1462.1s rotational state, and location. That's not that much. We don't need a really huge representation space.
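A rough back-of-the-envelope comparison makes the efficiency argument concrete. The 196 patches and 384 dimensions are the figures from the talk, and eight slots is the example slot count mentioned earlier; the slot width of 128 is a hypothetical value chosen only for illustration.

```python
def values_per_frame(num_tokens, dim):
    """Number of scalar features used to represent one frame."""
    return num_tokens * dim

def attention_pairs(num_tokens):
    """Token pairs a full self-attention layer compares within one frame."""
    return num_tokens * num_tokens

patch_values = values_per_frame(196, 384)   # patch-based representation
slot_values = values_per_frame(8, 128)      # object-centric, hypothetical width

print(patch_values, slot_values, patch_values / slot_values)
print(attention_pairs(196), attention_pairs(8))
```

With these assumed sizes the patch representation carries 75,264 values per frame against 1,024 for the slots, and full self-attention compares 38,416 token pairs against 64.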
1470.5s So after we only put the object-centric representation in place of the patch representation in the DINO world model, its performance actually drops a lot. This can happen because the DINO world
1488.4s model uses a causal transformer and we use a bidirectional transformer, right? And the object representation, if you think about it, doesn't necessarily encode velocity or acceleration, because you cannot define those properties by looking only at a single static image. So this drops the performance a lot, and
1513.9s after we change the action conditioning method to treat the action as a separate node, and change the transformer to a bidirectional transformer, the performance gain is significant: it gains 15 absolute percentage points. And after masking, compared to the OC (object-centric) DINO world model, you
1538.6s gain 28, which is pretty large. Compared to OC-JEPA, the only difference is object masking, and this object masking actually helps the model understand the dynamics.
1556.6s conditioning. As I told you before, the latent concatenation, the red line, denotes the action conditioning method based on the DINO world model. But after we treat the action as a separate node, it clearly shows the margin.
1579.4s PHYRE. Among the three datasets we have, PHYRE has the most complicated dynamics, and there are a lot of configurations, so you need to learn precisely what is happening in the scenario. When I compare OC-JEPA and C-JEPA, OC-JEPA often generates physically implausible scenes: you see the bar is floating
1606.8s below the fixed bar, which doesn't make sense. This can come from just learning correlation, like when two bars are close they just stay there, but that's not true; that's not how physics works. So with the object-masking training method, you keep asking the
1626.3s question to the model: what would happen if this doesn't exist? What should you consider to predict the masked token? This way it can learn the true dynamics. And this is the attention probing for the previous example. You can see the
1650.0s failure actually comes from attending to the wrong, irrelevant object. For example, OC-JEPA relies on the cup that contains the blue ball, and C-JEPA conditions on the right bar to predict its future state. And here, in Causal-JEPA, the term causal can
1678.3s stand for many things, but here we use causal to mean temporally directed predictive dependencies. That means that because we are predicting the future from the history, the edge is directed, and to predict the masked token you need to attend to the relevant object. So we call it temporally directed predictive
1701.4s dependencies. This is not the conventional, traditional way of defining causal, but recent causal machine learning uses this kind of definition as well. To go back to the role of object masking, we would say the predictor finds an influence neighborhood. The influence neighborhood is just a name we defined,
1728.2s but it's just a predictively sufficient set: a minimal set that the model needs to predict the masked token correctly. More formally, this holds under four assumptions. The first is that we assume there is no instantaneous relationship, and the second assumption is that every
1751.7s training instance shares the same mechanism. For example, gravity should not change across videos; there is a governing rule applied across the whole dataset. The third assumption is object-aligned representation. This is the trickiest one in practice, because it assumes that our object-centric representation is consistent throughout
1780.1s the video. Slots should not get swapped, an object should not be split into different slots, and our representation should reflect the scene faithfully. The fourth is history sufficiency. This is because we use a finite history. For example, we see four previous frames and predict the next frame, and these four
1802.6s history frames should be enough to predict the future. To make this causal machine learning practical, we do not assume a first-order Markov process; as I said before, object-centric representations usually do not follow a first-order Markov process. And we allow confounders. Confounders make things really tricky,
1828.7s because we usually cannot recover the true causal graph when there are confounders, but with object-centric representations in the real world this is kind of inevitable.
1847.0s answer these questions really quickly. What happens if the object-centric representation is not faithful? A little bit is fine, because masking object slots is still some kind of inductive bias even though our masking is not perfect. So a slightly wrong model can be okay, but if it's really bad, it doesn't work.
1881.8s causal graph? No. In our method there are confounders, so we cannot recover the true causal graph. Also, in many scenarios it is really hard to define what the true causal graph is. How do we select the number of objects to
1896.2s mask? This is a good question, because we have a fixed number of slots and the number of actual objects varies depending on which frame we are looking at. The ideal amount is just one foreground object, without the background slots. But sometimes we cannot control this really well. So I'll
1920.2s say: decide the number of masks based on the data statistics. Just guess your ideal number of masks, and then sweep a bit to find the best number. As for limitations, the largest limitation comes from the
1941.1s object-centric encoder. Object-centric representations do not work really well under occlusion, and in the middle of a video some objects can appear and disappear, but this kind of slot attention cannot handle that scenario really well. So this is a pain point of this model. And finally, we get back to the three
1966.2s components of a world model again. At the beginning of the talk I said there are three components that we should consider to make a good world model. For a good state representation, we use the object representation, and for a good transition model we did
1983.9s object masking, to let the model learn predictive sufficiency. And for the dynamics model, we tweaked the method of conditioning on the action, so we treat the action variables as separate nodes.
2007.5s part, and let me quickly hand it over to Lucas.
2019.5s a third-year PhD student at Mila, advised by Damien Scieur, a researcher at Samsung, but I work closely with Randall. The work I'm going to talk about today is work I did in collaboration with Quentin at NYU, with Yann LeCun, Damien, and Randall.
2047.9s So today what I'm going to talk about briefly is how to make this whole world model and JEPA stuff pretty simple to train. I'm not going to tell you again what a world model is; I think Hazel did it pretty well. So I'm going to
2067.4s go directly to what we did. Just before that, I would like to talk about a big problem. As Hazel said, JEPA aims to learn representations in an abstract space, so it is directly in opposition to generative modeling, where what you try to do is model your input space:
2089.1s basically, you try to learn a representation of your input space and do things in your input space. JEPA says that for most tasks this is not desirable. For instance, if you work on a self-driving car application, you most likely don't care to model the movement of the leaves of
2109.9s the trees along the road; you don't care to model that. If you do generative modeling, by definition you have to model that, because you want to model all the details of your input, so your loss is going to give signal for that. So what JEPA proposes is
2124.9s to first encode all your inputs into an abstract space, typically with an encoder, a neural net, and then try to model the dynamics in that latent space. It sounds pretty nice when you say it like that, but if you look at the
2144.1s image I put on the right, if you just do that you suffer from what people call collapse. So what is collapse? It is the failure mode where, if I do nothing and put no constraint on the distribution of my embeddings, the model on the right can
2163.2s simply learn to disregard the input and just produce a constant vector. So you can minimize the prediction loss in the latent space just by saying: okay, I'm going to encode everything as a constant vector, like zero, and then it's trivially easy to predict what the next
2183.1s state is going to be: it's just going to be zero again. So a big part of the research on JEPA is about collapse.
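The collapse failure mode described here can be reproduced in a few lines. This is a toy sketch with hypothetical names (`collapsed_encoder`, `identity_predictor`) and made-up observations: an encoder that ignores its input makes the latent prediction loss exactly zero, even though the embeddings carry no information.

```python
def mse(a, b):
    """Mean squared error between two equally sized vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)

def collapsed_encoder(observation):
    """Ignores the observation and always returns the same embedding."""
    return [0.0, 0.0, 0.0]

def identity_predictor(latent):
    """Predicts that the next latent equals the current one."""
    return list(latent)

observations = [[0.1, 0.9], [0.4, 0.2], [0.8, 0.5]]   # toy consecutive frames

total_loss = 0.0
for current, following in zip(observations, observations[1:]):
    predicted = identity_predictor(collapsed_encoder(current))
    target = collapsed_encoder(following)
    total_loss += mse(predicted, target)

print(total_loss)   # the prediction objective alone cannot tell this apart from a good model
```

This is why a latent prediction loss on its own is not enough, and why the methods discussed next each add some mechanism to rule out the constant solution.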
2193.8s As Hazel said before (and I should add Causal-JEPA to that previous JEPA recipe as well), you have V-JEPA, which
2203.7s tries to avoid that collapse with an exponential moving average. You have the DINO world model, which uses a pre-trained encoder, basically to avoid the collapse, because if you use a pre-trained encoder and freeze it, you can produce non-trivial embeddings and learn the dynamics directly in the embedding space, so it's essentially supervised
2223.8s learning of the world model. And then you have PLDM, which trains everything end to end: they train the encoder, the predictor, and everything end to end, and it's an action-conditioned model as well. But they use something called VICReg. What is VICReg? It's just an anti-collapse regularization term that tries to make the
2246.5s covariance matrix of your latent features the identity. So basically you want each dimension of your feature space to be decorrelated from the others, except on the diagonal of the covariance matrix, which is typically one, so there you have a positive correlation.
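The variance and covariance terms can be sketched as below, assuming the usual VICReg-style formulation (a hinge on the per-dimension standard deviation and a penalty on squared off-diagonal covariance entries). The helper names, the target `gamma = 1.0`, `eps`, and the toy batches are illustrative assumptions, not details given in the talk.

```python
import math

def covariance(batch):
    """Sample covariance matrix of a batch of embeddings (rows = samples)."""
    n, d = len(batch), len(batch[0])
    mean = [sum(row[j] for row in batch) / n for j in range(d)]
    centered = [[row[j] - mean[j] for j in range(d)] for row in batch]
    return [[sum(r[i] * r[j] for r in centered) / (n - 1) for j in range(d)]
            for i in range(d)]

def variance_term(batch, gamma=1.0, eps=1e-4):
    """Hinge loss that is positive when a dimension's std falls below gamma."""
    cov = covariance(batch)
    d = len(cov)
    return sum(max(0.0, gamma - math.sqrt(cov[j][j] + eps)) for j in range(d)) / d

def covariance_term(batch):
    """Sum of squared off-diagonal covariances, pushing dimensions to decorrelate."""
    cov = covariance(batch)
    d = len(cov)
    return sum(cov[i][j] ** 2 for i in range(d) for j in range(d) if i != j) / d

collapsed = [[0.0, 0.0] for _ in range(4)]                         # constant embeddings
spread = [[1.0, 1.0], [1.0, -1.0], [-1.0, 1.0], [-1.0, -1.0]]      # decorrelated
redundant = [[1.0, 1.0], [-1.0, -1.0], [2.0, 2.0], [-2.0, -2.0]]   # two copies of one feature

print(variance_term(collapsed), covariance_term(collapsed))
print(variance_term(spread), covariance_term(spread))
print(variance_term(redundant), covariance_term(redundant))
```

In this toy setting the collapsed batch is penalized by the variance term, the redundant batch by the covariance term, and the spread batch by neither. Each term typically comes with its own weight, which is where the tuning burden discussed next comes from.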
2272.9s The problem with PLDM is that, as you can see on the slide, you need to apply one term for the variance and one term for the covariance minimization, and you need to do that spatially and temporally, so you have four terms. They add an additional inverse dynamics loss to try to gain more information about
2291.5s the action, so that's five, plus you have the prediction of the future, the L2 minimization you see on the right: your predictor tries to predict the future in the latent space, and you compare that with what your encoder says the future should be. So that gives you six terms. So it's pretty
2310.6s difficult to tune, because you have six hyperparameters to tune, and that's why we propose LeWorldModel. So what is LeWorldModel? It's just a simple JEPA that doesn't use any tricks. So there is no exponential moving average, no masking, no stop-gradient, no pre-trained encoder, and also no unstable loss. Why? Because we have a single
2332.4s hyperparameter, and we made the model such that it's only 15 million parameters. So it means you can train it on a single GPU. It's fully end to end from raw pixels, so there is no auxiliary information such as proprioception or whatever. And we observe that it's
2351.4s 50 times faster than DINO-WM for doing planning. So basically, what is LeWorldModel? It's basically JEPA in its purest essence. So what you do is you take an observation o_t and an observation o_{t+1}, so the next observation. You process both of them through a shared encoder and you get a representation
2376.6s z_t and z_{t+1}. What you're going to do for the task is use z_t and the action at step t and try to learn a predictor to model the dynamics. So you're going to try to predict the future. So you have an estimate of the future called ẑ_{t+1}, and you
2393.4s compare that to what your encoder produces for the next state. You use mean squared error, and you use what we call SIGReg to avoid collapse. What is nice is that basically what I told you before about JEPA is literally what you see, and the code is exactly what I just described. So if you
2411.6s look at the pseudocode on the right, it's actually not that much code. It's literally the true code. You encode, you predict, you compute your future prediction error, and you use SIGReg to avoid collapse. Okay.
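The slide's pseudocode is not reproduced in this transcript, so the following is only a sketch of the structure described in words: encode both observations with the same encoder, predict the next latent from the current latent and the action, and return prediction error plus one weighted regularizer. Every name here (`jepa_loss`, `encode`, `predict`, `regularizer`, `lam`) is hypothetical, the regularizer is a simple stand-in rather than SIGReg, and applying it to `z` alone is an assumption. In practice this would be written in an autodiff framework so that gradients flow through both encoder branches.

```python
def mse(a, b):
    """Mean squared error between two batches of vectors (lists of lists)."""
    total = sum((x - y) ** 2 for u, v in zip(a, b) for x, y in zip(u, v))
    return total / (len(a) * len(a[0]))

def jepa_loss(obs, actions, next_obs, encode, predict, regularizer, lam):
    z = [encode(o) for o in obs]             # z_t
    z_next = [encode(o) for o in next_obs]   # z_{t+1}, same encoder
    z_pred = [predict(zi, ai) for zi, ai in zip(z, actions)]  # predicted z_{t+1}
    prediction_error = mse(z_pred, z_next)
    # Single weight `lam` on the anti-collapse term (stand-in for SIGReg).
    return prediction_error + lam * regularizer(z)

def unit_variance_penalty(z):
    """Stand-in regularizer: penalize per-dimension variance differing from 1."""
    n, d = len(z), len(z[0])
    penalty = 0.0
    for j in range(d):
        mean = sum(row[j] for row in z) / n
        var = sum((row[j] - mean) ** 2 for row in z) / n
        penalty += (var - 1.0) ** 2
    return penalty / d

# Toy usage: identity "encoder" and additive "predictor" on 2D observations.
obs = [[0.0, 0.0], [1.0, 1.0]]
actions = [[1.0, 0.0], [0.0, 1.0]]
next_obs = [[1.0, 0.0], [1.0, 2.0]]
encode = lambda o: list(o)
predict = lambda z, a: [zi + ai for zi, ai in zip(z, a)]
print(jepa_loss(obs, actions, next_obs, encode, predict, unit_variance_penalty, lam=0.1))
```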
2426.3s So you can see that at the bottom line, at the return, I have a single
2431.6s hyperparameter, lambda. So this is the only thing you need to tune, and because you have a single hyperparameter you can do a bisection, so it's log n to find the optimal lambda. So it's pretty nice, because you can really easily tune the model.
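One way to read the "bisection in log n" remark is a binary search over a sorted grid of candidate lambda values. The sketch below assumes that the validation score rises and then falls as lambda grows (unimodal); that assumption is ours, not something stated in the talk. `search_lambda` and the toy `score` are hypothetical; in a real setting `score` would train a model with that lambda and return a validation metric.

```python
import math

def search_lambda(candidates, score):
    """Return (best_candidate, number_of_evaluations) for a sorted candidate
    list, assuming score(candidate) is unimodal over the list."""
    cache = {}

    def cached_score(i):
        if i not in cache:
            cache[i] = score(candidates[i])
        return cache[i]

    lo, hi = 0, len(candidates) - 1
    while lo < hi:
        mid = (lo + hi) // 2
        if cached_score(mid) < cached_score(mid + 1):
            lo = mid + 1   # still climbing: the peak is to the right
        else:
            hi = mid       # the peak is at mid or to the left
    return candidates[lo], len(cache)

# Toy score that peaks at lambda = 0.1 on a log-spaced grid.
grid = [10.0 ** k for k in range(-4, 3)]
toy_score = lambda lam: -(math.log10(lam) + 1.0) ** 2
print(search_lambda(grid, toy_score))
```

Each halving step costs at most two evaluations, so the number of training runs grows logarithmically with the grid size rather than linearly.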
2446.4s And so what is SIGReg now? So this is the
2448.9s thing that keeps your representation from collapsing and makes sure that your z_t are informative about what is inside your observation. So basically you can see z_t as a learned abstraction of the world, or a learned state. You have an observation of the world, which is the image o_t, and you try to learn, you
2471.0s try to estimate the state, the true state. Okay. And so how do you prevent collapse? Basically we use a simple regularization called SIGReg, for Sketched Isotropic Gaussian Regularization, that Randall and Yann LeCun introduced last November, I think, if I'm correct. The idea is very simple. The maths are a bit
2494.2s tricky, but I invite you to read the paper. It's a very nice paper. But the idea is very simple. You take the distribution of your embeddings. Okay. So you take your batch, and you look at how your embeddings z_t are distributed. Okay. And you will try to optimize that distribution to be an isotropic Gaussian. So
2515.6s how do you do that? You could use generative modeling to do that, with the KL divergence and everything, such as in a variational autoencoder. We don't want that here. We don't want to use generative models. So what you do is you're going to try to use a statistical test. Okay. So in statistics you have
2534.6s tests that tell you, given an empirical distribution, how close it is to a Gaussian distribution. Okay. The problem is that our embedding space is very high dimensional, and basically the statistical tests suffer from the curse of dimensionality, so
2556.2s it's pretty difficult to directly use a test to optimize a high-dimensional distribution. And so the idea Randall had is basically to project: you sample a lot of random directions in your latent space and
2574.3s you're going to project all your embeddings onto one dimension. So as you can see, for instance, if you look at the red arrow on this image and you look on the right, you can see that once you project your embeddings onto that direction you get a
2593.0s unimodal, sorry, a univariate empirical distribution, and now, because it's just a 1D distribution, you can optimize to make it Gaussian. And if you do that in a lot of directions, there is a theorem called the Cramér-Wold theorem that says that if you optimize the marginals to be Gaussian
2615.2s then the joint is going to be Gaussian. So basically, if you do that for a lot of random directions, you can prove that your latent embedding distribution is going to be Gaussian as well. So it makes sure that we will have informative embeddings. So that's pretty nice.
2633.5s Okay. So now, let's imagine you train your world model, blah blah blah. How can you make sure that you learned a good world model? So there are two ways. I mean, I know two ways; if you know more, that's nice, we can talk
2646.1s about that. But the first way is online control: can you use your learned world model, which is action-conditioned, and in an online way optimize the actions to perform control? So this is the first thing, and the second thing, which we are going to talk about after, is to try to
2664.2s assess whether your world model understands intuitive physics. But let's focus on control for now. So how we perform control is basically that you're going to use your world model. Okay, the learned world model. You're going to sample random trajectories of actions and you're going to optimize
2683.1s by rolling out the future. You will optimize that sequence of actions to match a target goal. So for instance, let's say you have a goal frame o_g. Okay, you start from the current frame o_1. You encode both of them through your encoder. You get the initial latent state z_1. You have
2705.3s your target state, which is z_g. Okay. And then you sample an initial sequence of actions, and you use the first one with your first state inside the predictor. It gives you z_2. Then you use the second one, et cetera, blah blah blah. And after each state, each
2723.9s step, sorry, you compare how far you are in your latent space from the representation of the goal. Okay. It can be MSE, for instance. You use the MSE loss to estimate how far you are from your goal. And because your predictor is differentiable, you can, for instance, backpropagate all the way to the actions and optimize the
2745.5s sequence of actions to minimize the distance to your goal. Okay. So I can show you some results. You can see that we consider four tasks. So the first one, on the left, is called Two-Room. It's a simple 2D navigation task where you need to move from one side of
2764.4s a room to the other by passing through a door. Okay. Then you have Reacher. It comes from the DeepMind Control Suite, where, as you can see on the right, you have a two-joint arm and you need to reach a target location from a given goal. Okay. Then
2784.9s you have Push-T. So Push-T was already presented, but the goal of Push-T is basically to push the T to the green area. It's as simple as that, and you can only push, you cannot pull. And OGBench Cube: it's a 3D environment where the objective of that environment is to
2804.7s make the cube, it's a 3D manipulation task with a robotic arm, where you need to make the cube match a target location as well. And so I think the most interesting result is the first one, Push-T, because you can see that DINO-WM has been trained with proprioception,
2825.0s and you can see that LeWorldModel without proprioception beats DINO-WM with proprioception, with half the parameters. But what is most interesting is that if you remove the proprioception from DINO-WM, you see that the performance drops. Okay. And LeWorldModel, which doesn't have proprioception, beats
2844.9s by a lot the model without proprioception. And PLDM, which is the only other one that doesn't use proprioception and that is fully trained end to end, also has very bad performance on Push-T. So it was very nice to see this result. So for Reacher we beat PLDM and DINO-WM.
2868.0s PLDM, but we don't beat DINO-WM, and the most likely explanation for that is that we only train on trajectories coming from the dataset. DINO has been pre-trained, because it's a pre-trained encoder. It has been pre-trained on 124 million natural images, and consequently it has a better understanding of objects and 3D
2893.7s than we have, because it has been trained on a lot, lot, lot more data. And also, very funnily, you can see that for Two-Room, even if it's the simplest task, almost all the baselines destroy the task, but we don't. And this is actually a limitation
2913.0s of SIGReg that exists for now. Take the intrinsic dimensionality, and what I mean by that is the true dimensionality that you need to solve the problem. For instance, for Two-Room it's just two, because you just need to know the x and y of the agent: as soon as
2928.7s you know that and I give you a target location, you can solve the environment. If that is very much smaller than the embedding you use, there is no way that you can produce Gaussian embeddings. At least, your SIGReg, when you optimize, is going to need to create fake
2946.2s stuff to be able to be Gaussian, to use some fake information to make your latent space Gaussian. So it doesn't help you, basically. There is research going on to try to fix that, and actually we know that if you carefully tune the hyperparameters you can somewhat overcome
2964.0s this issue, but for the sake of fairness we didn't tune all the hyperparameters, and we keep the same hyperparameters across all the environments. But the most interesting part, why it's very nice, is when you look at the planning time. So for DINO-WM, and we
2981.4s heavily optimized DINO-WM to try to be as fast as possible, the fastest we could go for doing the full planning is 47 seconds, and the reason for that is because your predictor takes all the patches and you need to predict
3001.8s all the patches, so it's quite slow to predict the future because of the quadratic cost of the attention, right? As for us, with LeWorldModel, we have a single embedding for representing the latent state, because we just use the CLS token of the encoder, and it
3018.9s allows us, with some other tricks (we didn't optimize a lot for that; we could go a lot lower than this), to get to a full planning time under a second, which is very nice. It's almost 50 times faster. Another interesting thing is that if you
3036.8s fix the FLOPs, okay, so if you fix the FLOPs that you use for planning, and we fix the FLOPs at those of LeWorldModel, so if you reduce the planning time by tweaking the planning hyperparameters until DINO-WM plans in under a second, you can see that the success rate drops
3056.0s by a lot for Push-T and for OGBench Cube, which is expected, but it shows that at a similar budget we outperform DINO-WM, even on OGBench Cube.
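As a recap of the control procedure described earlier (sample action sequences, roll them out through the predictor, compare the resulting latent to the goal latent), here is a minimal random-shooting sketch. It is a stand-in for whatever optimizer the authors actually use: the talk also mentions backpropagating through the predictor to the actions. The additive `toy_predict` dynamics, the terminal-only cost, and all names and default values are assumptions made for illustration.

```python
import random

def plan_random_shooting(z_start, z_goal, predict, action_dim,
                         horizon=5, num_samples=512, seed=0):
    """Return (best_action_sequence, best_cost) among random candidates."""
    rng = random.Random(seed)
    best_cost, best_actions = float("inf"), None
    for _ in range(num_samples):
        actions = [[rng.uniform(-1.0, 1.0) for _ in range(action_dim)]
                   for _ in range(horizon)]
        z = list(z_start)
        for a in actions:            # roll the latent state forward
            z = predict(z, a)
        # Squared distance between the final latent and the goal latent.
        cost = sum((zi - gi) ** 2 for zi, gi in zip(z, z_goal))
        if cost < best_cost:
            best_cost, best_actions = cost, actions
    return best_actions, best_cost

def toy_predict(z, a):
    """Hypothetical latent dynamics: the action simply displaces the state."""
    return [zi + ai for zi, ai in zip(z, a)]

actions, cost = plan_random_shooting([0.0, 0.0], [2.0, -1.0], toy_predict, action_dim=2)
print(len(actions), cost)
```

With a single embedding per state, each rollout step is one small predictor call, which is the property the planning-time comparison above relies on.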
3072.3s So that's for control. For intuitive physics understanding, we did something pretty similar, pretty easy, sorry. The first thing is we probe the
3081.0s latent space. So we just took the encoder from LeWorldModel, we froze it, and we trained a linear probe or an MLP probe, a nonlinear probe, to try to predict the coefficients of the simulation for that state. Okay. So I will not go that much into the details. You can have
3099.8s a look at the paper for that. But what is pretty interesting is that, for instance, for OGBench Cube you can see that with the linear probe, LeWorldModel almost all the time has a lower mean squared error, whereas if you compare with the MLP probe on
3117.7s the right, it's DINO-WM that almost all the time has the lowest mean squared error. And so what it suggests is that our latent space is less entangled than the one of DINO-WM. So it's easier to recover the coefficients directly from our latent space than from DINO-WM's, and it does
3140.5s make sense, I think, because as we use SIGReg, it pushes the latent space in such a way that each dimension is somewhat meaningful. So it's not fully disentangled, but it's somewhat more so than DINO-WM, which is pretty nice.
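The linear-probe setup described above can be sketched as follows: the embeddings are treated as fixed inputs (the encoder is frozen) and only a linear map is trained to regress a physical quantity. The data here is synthetic and the names (`fit_linear_probe`, `probe_mse`) are hypothetical. Plain gradient descent is used for brevity, not because the paper is known to use it.

```python
import random

def fit_linear_probe(z, y, lr=0.1, steps=500):
    """Fit y ~ w . z + b by gradient descent on mean squared error.
    z: frozen embeddings (list of lists); y: scalar targets."""
    n, d = len(z), len(z[0])
    w, b = [0.0] * d, 0.0
    for _ in range(steps):
        errs = [sum(wj * zj for wj, zj in zip(w, zi)) + b - yi
                for zi, yi in zip(z, y)]
        grad_w = [2.0 * sum(e * zi[j] for e, zi in zip(errs, z)) / n
                  for j in range(d)]
        grad_b = 2.0 * sum(errs) / n
        w = [wj - lr * g for wj, g in zip(w, grad_w)]
        b -= lr * grad_b
    return w, b

def probe_mse(z, y, w, b):
    """Mean squared error of the fitted probe on (z, y)."""
    preds = [sum(wj * zj for wj, zj in zip(w, zi)) + b for zi in z]
    return sum((p - yi) ** 2 for p, yi in zip(preds, y)) / len(y)

# Synthetic "embeddings" in which the target is linearly decodable.
rng = random.Random(0)
z = [[rng.uniform(-1.0, 1.0), rng.uniform(-1.0, 1.0)] for _ in range(50)]
y = [3.0 * a - 2.0 * c + 1.0 for a, c in z]
w, b = fit_linear_probe(z, y)
print(w, b, probe_mse(z, y, w, b))
```

A low error from a linear probe means the quantity is readable from the latent space without a nonlinear decoder, which is the sense in which the talk calls the representation less entangled.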
3155.3s Another thing that you could do, which is pretty cool, is to try to see what
3159.9s happens when you violate the world model. So if a sudden change in the dynamics happens, does your world model predict that this is a violation? And so what we did for that is, as you can see on the left (sorry, the cube color and cube teleportation are inverted), but if
3177.4s you look on the left, you can see that we have a normal trajectory where your robotic arm picks up the cube and moves it to the right, for instance. Okay, nothing happens. And then we consider two perturbations. For instance, we change the color of the cube. So this is the one on
3194.2s the right, not in the middle, suddenly, at a given frame. Okay. And another transformation we did, this is the one in the middle, is that suddenly the cube teleports. Okay. And so if you look on the right, you can see that the x-axis is the time steps. Okay. And the y-axis is
3216.1s the prediction error. So this is the difference between what your predictor predicts is going to happen and what the actual embedding of the next state looks like. Okay. And so what we can see is that if you don't have any perturbation, so this is the gray line, it's fine. If you
3234.6s change the cube color, it's a bit higher, but it's negligible. So it means that basically your world model doesn't care much about the color of the cube, which is pretty cool, because you don't need that for the dynamics, right? You don't care about the color of the cube. But if suddenly the cube teleports,
3251.4s then the prediction error shoots up a lot, meaning that your world model didn't predict that.
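The violation test above comes down to plotting, per time step, the gap between the predicted next latent and the encoded next latent. A minimal sketch under assumptions: the latents are 2D toy vectors rather than encoder outputs, the dynamics are the hypothetical additive `toy_predict`, and the "teleport" is simulated by shifting every latent from step 5 onward.

```python
def surprise_curve(latents, actions, predict):
    """Squared error between predicted and observed next latent, per step."""
    errors = []
    for t in range(len(latents) - 1):
        predicted = predict(latents[t], actions[t])
        errors.append(sum((p - q) ** 2 for p, q in zip(predicted, latents[t + 1])))
    return errors

def toy_predict(z, a):
    """Hypothetical latent dynamics: the action displaces the state."""
    return [zi + ai for zi, ai in zip(z, a)]

# Build a normal trajectory, then teleport the state from step 5 onward.
actions = [[0.1, 0.0]] * 8
latents = [[0.0, 0.0]]
for a in actions:
    latents.append(toy_predict(latents[-1], a))
teleported = [z if t < 5 else [z[0] + 2.0, z[1]] for t, z in enumerate(latents)]

normal_errors = surprise_curve(latents, actions, toy_predict)
teleport_errors = surprise_curve(teleported, actions, toy_predict)
print(max(normal_errors))
print(teleport_errors.index(max(teleport_errors)), max(teleport_errors))
```

The error spikes only on the transition where the state jumps and returns to near zero afterwards, which is the shape of the curve described for the teleported cube.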
3257.7s Some people often say to me, yeah, but it's just out of distribution. And I would say it's true, but I think it's not very meaningful to say that, because as a human, when you violate your
3270.2s model, it's also very out of distribution. That's why, for instance, when someone does a magic trick for you where they make a coin disappear in front of you, you're suddenly very surprised, and it frustrates you a bit, you know. It's the same here. It's because the prediction error is very
3284.5s high. So we did two other small experiments as well. The first one is we try to embed the location of the agent and the T, and we make the agent location and the T location move. So you can see that the middle plot shows you what
3304.8s happened, the different locations, and we project the embedding space with a t-SNE plot to try to see if we can recover the exact relative distances between all the different locations in the original space. And what you can see is that, up to permutation of the axes and rotation or reflection, this
3326.8s is exactly what happened. So we recover uh basically the the relative distance in the original space which is pretty in the original space which is pretty nice. The last thing we did is we throw the uh war model and we train a decoder to try to interpret what is happening when uh
3329.0s uh basically the the relative distance
3331.3s in the original space which is pretty
3333.0s in the original space which is pretty nice.
3334.8s The last thing we did is we freeze the uh
3337.5s world model and we train a decoder to try
3340.6s to interpret what is happening when uh
3342.9s you make future prediction. And so for that you can see that um we give a context to our model. So this is the first three frames of the top row you see uh or the bottom the second row. Okay. So we give that to our model. The first row is really what happened in the
3344.8s that you can see that um we give a
3348.6s context to our model. So this is the
3350.3s first three frames of the top row you see
3352.9s uh or the bottom the second row. Okay.
3355.1s So we give that to our model. The
3357.4s first row is really what happened in the
3359.4s future on the right where you have open loop prediction. So this is actually what happened in the trajectory and then on the second row you have what your world model imagines when you give the same sequence of actions that has been taken for the original sequence and you can see that the world model uh
3361.4s open loop prediction. So this is
3363.0s actually what happened in the trajectory
3365.2s and then on the second row you have what
3367.4s your world model imagines when you give the
3369.8s same sequence of actions that has been
3371.9s taken for the original sequence
3373.8s and you can see that the world model uh
3376.3s predicts somewhat uh the reality when you give the sequence of actions. But what is very interesting is that if you look at the second uh rollout with the cube. Okay. So you can see that we predict very well what is going to happen with the cube. But if you are
3379.8s give the sequence of actions. But
3382.5s what is very interesting is that if you
3384.1s look at the second uh rollout
3387.8s with the cube. Okay. So you can see that
3390.6s we predict very well what is going to
3393.7s happen with the cube. But if you are
3395.4s very careful you can see that at frame 15 and 20 the angle of the gripper is opposite. And so basically you can see that the world model didn't learn uh the rotation of the gripper which was pretty interesting because it still was able to solve uh somewhat the environment.
3397.5s 15 and 20 the angle of the gripper
3402.0s is opposite. And so basically you can
3404.0s see that the world model didn't learn uh
3406.7s the rotation of the gripper which
3408.7s was pretty interesting because it still
3410.7s was able to solve uh somewhat the
3413.0s environment.
3415.6s Uh so there are many limitations for our model or as I like to call them research opportunities. So for now you are doomed to short-term uh planning horizon. So if you can unlock that that would be very nice. Another problem is that um you reason at a single temporal level. So we need hierarchies basically you need for
3417.4s model or as I like to call them research
3419.4s opportunities. So for now you are doomed
3422.6s to short-term uh planning horizon. So if
3425.6s you can unlock that that would be very
3427.6s nice. Another problem is that um you
3430.6s reason at a single temporal level. So we
3433.7s need hierarchies basically you need for
3435.7s instance when you think about oh I need to go to the airport you think at a different hierarchy right the airport is your goal you think okay I need to go to my car then I need to go to the airport and blah blah blah and then only at the end it translates into muscle movement right
3437.4s need to go to the airport you think at a
3440.5s different hierarchy right the airport is
3442.4s your goal you think okay I need to go to
3445.0s my car then I need to go to the airport and
3447.1s blah blah blah and then only at the end
3450.0s it translates into muscle movement right
3452.5s so you don't all the time think in terms of muscle movement so we need that as well to be able to predict further in the future. A very important thing as well is to um move from this toy environment like I fairly agree with this kind of criticism that um it's
3454.5s of muscle movement so we need that as
3456.4s well to be able to predict further in
3458.9s the future. A very important thing as
3461.4s well is to um move from this toy
3464.4s environment like I fairly agree with
3467.0s this kind of criticism that um it's
3469.6s very toyish experiments and so can we move to real world robotics or very stochastic and partially observable environments like Minecraft that would be very nice and also a big problem I think is how do you specify your goal so for now you can see as I explained before you need to provide a visual goal but
3472.0s move to real world robotics or very
3474.3s stochastic and partially observable
3476.3s environments like Minecraft that would be
3478.2s very nice and also a big problem I think
3480.9s is how do you specify your goal so for
3483.2s now you can see as I explained before
3485.0s you need to provide a visual goal but
3487.2s you don't all the time have access to that and it doesn't tell you anything about how you should solve the task. So for instance um if you have a plane that needs to land okay you don't want to just show a picture of the plane landed to do the planning you want to
3488.7s that and it doesn't tell you anything
3490.5s about how you should solve the task.
3491.8s So for instance um if you have a
3494.6s plane that needs to land okay you don't
3497.1s want to just show a picture of the plane
3499.0s landed to do the planning you want to
3501.0s specify how uh smooth should be the landing and everything. So we don't know how to do that um for now with this kind of approach. I would like also to take just two minutes to do a bit of advertisement for something we have been pushing with a lot of Randall's
3503.9s landing and everything. So we
3505.9s don't know how to do that um for now
3508.2s with this kind of approach. I would like
3510.8s also to take just two minutes to do a bit
3512.8s of advertisement for something we
3515.4s have been pushing with a lot of Randall's
3517.7s students and a lot of people for the past few months which is called stable world model and if you're interested in world model research and this kind of stuff you should definitely have a look at that. So it's a GitHub library fully open source that allows you to train models very easily. So you
3519.5s past few months which is called stable
3522.1s world model and if you're interested in
3524.3s world model research and this kind
3527.0s of stuff you should definitely have a
3528.3s look at that. So it's a GitHub
3531.0s library fully open source that allows
3533.0s you to train models very easily. So you
3535.3s have all the baselines I discussed and Hazel discussed um that are implemented there and heavily tested. You have all the solvers to do planning. You have many environments. We recently added um DeepMind Control and Minecraft and we are in discussion to add support for real robot data also very soon. And everything is very
3537.0s and Hazel discussed um that are
3539.9s implemented there and heavily tested.
3541.7s You have all the solvers to do planning. You
3544.1s have many environments. We recently
3546.1s added um DeepMind Control and
3548.6s Minecraft and we are in discussion to
3550.6s add support for real robot data also
3553.0s very soon. And everything is very
3555.7s heavily tested and there is documentation. Um so yeah feel free to give it a try and give feedback or contribute to the library that would be very cool. Thank you. Okay, great. Thank you so much. Uh let's give a round of applause to our speakers today.
3558.2s Um so yeah feel free to give it a try and
3560.4s give feedback or contribute to the library
3562.3s that would be very cool.
3564.8s Thank you.
3567.4s Okay, great. Thank you so much. Uh let's
3569.8s give a round of applause to our speakers
3572.2s today.
3579.7s questions. We have some online on the Slido um but I'll also be um looking for in-person questions in case anybody here um wants to ask anything as well. Um, so we'll kind of balance between both and um, I'll let you guys figure out like who should answer each question or uh, if you both have insights for
3581.4s Slido um but I'll also be um looking
3584.8s for in-person questions in case anybody
3586.7s here um wants to ask anything as well.
3589.0s Um, so we'll kind of balance between
3590.7s both and um, I'll let you guys figure
3593.8s out like who should answer each question
3595.8s or uh, if you both have insights for
3598.6s anything that's also great. Um, yeah. Does anybody here have any questions? background? I'm just curious how world models will transfer to physical AI space like how can we see them moving forward in robotic space? I would I would give it to Lucas probably. Lucas, can you hear us?
3602.1s yeah. Does anybody here have any
3603.8s questions?
3605.7s background? I'm just curious how
3608.2s world models will transfer to physical AI
3610.6s space like how can we see them moving
3612.6s forward in robotic space?
3615.8s I would I would give it to Lucas
3617.5s probably. Lucas, can you hear us?
3620.1s Yes. Yes. Can you repeat the question, please? Oh. Um, so how can we expect the world models uh to be um in the physical AI space because we already have a diffusion based um foundation models like GR or maybe uh pi and also size zero like using old models. Can we
3621.7s please?
3622.5s Oh. Um, so how can we expect the world
3625.1s models uh to be um in the physical AI
3628.3s space because we already have a
3629.4s diffusion based um foundation models
3631.6s like GR or maybe uh pi and also size
3634.6s zero like using old models. Can we
3637.0s expect some more applications in the robotic space? Okay. Okay. Okay. So first of all uh I will be very skeptical that the current VLA models have a good understanding of the world like there is no reason for that okay they are not trained to predict the consequence of their action okay and I would be very skeptical to
3639.1s robotic space?
3641.6s Okay. Okay. Okay. So first of all uh I
3644.7s will be very skeptical that the current
3647.0s VLA models have a good understanding of
3649.0s the world like there is no reason for
3650.9s that okay they are not trained to
3653.5s predict the consequence of their action
3655.8s okay and I would be very skeptical to
3658.2s see um to see something that doesn't know what is the outcome of its action be very reliable to do physical AI so as a human the reason why you're very good at what you do is because you can predict what is the consequence of your action in the real world and that's what world
3661.4s know what is the outcome of its action
3663.4s be very reliable to do physical AI so as a
3667.0s human the reason why you're very good at what
3669.7s you do is because you can predict what
3671.7s is the consequence of your action in the
3674.6s real world and that's what world
3676.6s models try to do the VLAs they don't do that so if you want to have physical AI basically you need world models you cannot bypass that I think um as Yann says often like VLAs and everything they can be very helpful for simple stuff where you don't need to predict
3678.9s that so if you want to have physical
3681.4s AI basically you need world models you
3683.2s cannot bypass that I think um
3686.5s as Yann says often like VLAs and everything
3689.8s they can be very helpful for simple
3691.8s stuff where you don't need to predict
3693.6s what is going to be the outcome of your action in the real world for instance like if you want to make um robots that I don't know uh do some dance or whatever. They only need to know their internal dynamics. Okay? They don't need to know how to interact with the real world. To
3695.4s your action in the real world for
3697.4s instance like if you want to make um
3700.1s robots that I don't know uh do some
3702.7s dance or whatever. They only need to
3704.4s know their internal dynamics.
3706.2s Okay? They don't need to know how to
3707.8s interact with the real world. To
3709.4s interact with the real world, you need to predict the outcome of your action. So, you need world models. I hope that answers your question. If not, happy to clarify. Sure. Sure. Thank you. I would like to add one more sentence. Um I think Professor Sher
3710.6s to predict the outcome of your action.
3712.2s So, you need world models. I hope that
3713.8s answers your question. If not, happy to
3715.9s clarify.
3716.7s Sure. Sure. Thank you.
3718.5s I would like to add one
3720.6s more sentence. Um I think Professor Sher
3723.1s Young at NYU, she has an opinion that like um a world model can be an evaluator for policies. So I think this can be also your Instagram.
3725.5s like um a world model can be an evaluator for
3729.0s policies. So I think this can be also
3731.0s your Instagram.
3738.3s Um someone's asking masking is shown to improve predictions but is it really necessary to learn the world model? Does it mean that prediction loss alone is insufficient? Uh the prediction loss. Okay, let me assume that you're still talking about the object-centric representation. Um if you only do the
3741.3s improve predictions but is it really
3743.3s necessary to learn the world model? Does
3745.4s it mean that prediction loss alone is
3747.4s insufficient?
3749.3s Uh the prediction loss. Okay, let me
3752.6s assume that you're
3754.2s still talking about the object-centric
3755.8s representation. Um if you only do the
3758.5s prediction loss with the object-centric representation the shortcut that the world model will learn it can be the self dynamics we cannot guarantee that the model learns interaction-based uh dynamics so I won't say that is insufficient but with the masking I think you can reinforce the model to learn the object dynamics.
3760.2s representation the shortcut that the
3763.1s world model will learn it can be the
3765.5s self dynamics we cannot guarantee that the
3768.0s model learns interaction-based uh
3770.0s dynamics so
3773.1s I won't say that is insufficient
3776.7s but with the masking I think you can
3779.6s reinforce the model to learn the object
3781.8s dynamics.
3788.5s someone's asking, Do you have any insights on how C-JEPA learns to plan ahead beyond just the next frame?
3790.1s insights on how C-JEPA learns to plan
3792.9s ahead beyond just the next frame?
3803.7s So for the uh is it all the question? Yeah. Okay. Um, we strictly follow the evaluation method I mean the planning of the DINO-WM model. So with the predicted um feature frame we do it autoregressively to reach the long horizon planning and we use the um the same parameters for planning as well.
3804.2s Okay. Um, we strictly follow the
3806.6s evaluation method I mean the planning of
3808.7s the DINO-WM model. So with the predicted um
3811.8s feature frame we do it
3813.7s autoregressively to reach the long horizon
3816.2s planning and we use the um the same
3820.4s parameters for planning as well.
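The autoregressive rollout described in this answer can be sketched in a few lines: each predicted latent is fed back in as the input for the next step, so a one-step predictor yields a multi-step open-loop trajectory. This is a minimal illustration only. `toy_predict` is a hypothetical stand-in for a learned latent predictor, not the C-JEPA or DINO-WM code.

```python
import numpy as np

def rollout(predict, z0, actions):
    """Open-loop rollout: feed each predicted latent back in as the next input."""
    latents = [z0]
    for action in actions:
        latents.append(predict(latents[-1], action))
    return np.stack(latents)

def toy_predict(z, a):
    """Hypothetical one-step predictor: the latent is a 2-D position,
    the action is a displacement scaled by a fixed step size."""
    return z + 0.1 * a

z0 = np.zeros(2)                 # latent of the current observation
actions = np.ones((5, 2))        # candidate action sequence, horizon 5
trajectory = rollout(toy_predict, z0, actions)
print(trajectory.shape)          # one latent per step, including the start
```

A planner then scores the end of such a trajectory against the goal embedding and searches over `actions`; prediction errors compound with the horizon, which is one reason long-horizon planning is listed as an open problem earlier in the talk.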
3832.0s because we have quite a lot. Um let me see. native agent architecture look like compared to current LLM based agents? Lucas, do you want to answer? Of course. Can you just repeat the question? It's pretty difficult to um someone asked, What would a JEPA-native agent architecture look like
3834.3s see.
3839.8s native agent architecture look like
3841.5s compared to current LLM based agents?
3845.7s Lucas, do you want to answer?
3849.1s Of course. Can you just repeat the
3850.6s question? It's pretty difficult to um
3853.3s someone asked, What would a JEPA-native
3854.8s agent architecture look like
3857.0s compared to current LLM based agents?
3866.6s is that you learn an action-conditioned model right so LLM agent you don't learn an action-conditioned model as far as I know I'm not expert in LLM as well but um you learn to use tooling and to make tool calls this is not what you do um when you do uh world models what you do is
3868.8s right so LLM agent you don't
3871.0s learn an action-conditioned model as far as I
3873.7s know I'm not expert in LLM as well but
3876.7s um you learn to use tooling and to make
3879.0s tool calls this is not what you do um
3882.0s when you do uh world models what you do is
3884.5s you learn okay I have an action I have a current state what is going to happen in the future if I take that action okay this is not how you train LLM so this is one step that differs and the second thing is that by construction okay you can just learn to model the future
3886.6s current state what is going to happen in
3888.0s the future if I take that action okay
3890.1s this is not how you train LLM so this is
3892.2s one step that differs and the second
3894.4s thing is that by construction okay you
3897.1s can just learn to model the future
3899.8s with world models and then use planning by using old concepts of control theory such as model predictive control or whatever to directly convert that into an agent you don't need post training or whatever you can do that zero shot so that's pretty neat um for LLM I'm not sure so I cannot reply
3902.5s planning by using old concepts of control
3905.1s theory such as model predictive control
3907.1s or whatever to directly convert that
3909.2s into an agent you don't need post
3910.7s training or whatever you can do that
3912.2s zero shot so that's pretty neat um for
3915.0s LLM I'm not sure so I cannot reply
3917.9s for that I think that makes sense yeah Um,
3919.7s I think that makes sense yeah Um,
3928.8s let me see. Yeah. Do you think JEPA models are less prone to hallucination than transformer-based models since they hopefully learn a more accurate representation of the world? Not sure if you guys might have any insights. Uh, I can answer first and Lucas, you
3931.0s models are less prone to hallucination
3933.2s than transformer-based models since they
3935.8s hopefully learn a more accurate
3937.4s representation of the world?
3940.5s Not sure if you guys might have any
3941.9s insights.
3944.5s Uh, I can answer first and Lucas, you
3947.5s can just add your opinion um on it. I'm not sure what kind of hallucination um the person asked about but in the predictive sense because JEPA model is as I mentioned in the earlier part of the talk JEPA is an energy-based model. So rather than just predicting the full pixel space what it evaluates is oh is
3951.5s I'm not sure what kind of hallucination
3954.2s um the person asked about but in the
3957.8s predictive sense because JEPA model is
3963.6s as I mentioned in the earlier part of
3965.7s the talk JEPA is an energy-based model.
3968.4s So rather than just predicting the full
3970.7s pixel space what it evaluates is oh is
3974.0s our predicted um future make sense like is it a possible situation or like is it impossible so I think in that sense in the sense of the representation the representation will be more um suitable for world model I think. Yeah. I think the original question was to compare JEPA and transformer based
3976.0s um future make sense like is it a possible
3979.0s situation or like is it impossible
3981.7s so I think in that sense in the sense of
3984.7s the representation the representation
3987.2s will be more
3989.3s um suitable for world model I think. Yeah.
3998.3s I think the original question was to
4000.5s compare JEPA and transformer based
4002.5s model. So I would like just to emphasize that JEPA is not a new architecture like for instance for LeWorldModel our encoder and predictor are two transformers. So JEPA is really a framework to try to learn world models. It's not a new architecture. Okay. And uh second I think you can have
4004.6s that JEPA is not a new
4007.3s architecture like for instance for LeWorldModel
4009.6s our encoder and predictor are two
4012.2s transformers. So JEPA is really a
4014.8s framework to try to learn world models.
4017.5s It's not a new architecture. Okay. And
4020.2s uh second I think you can have
4022.0s definitely hallucination if you don't learn a good model. Okay. But this is not at all the same hallucination as you have with LLMs because LLMs they have multiple sources of hallucination. One of them is that they are not grounded in the real world. Okay. With world models you can easily fix that. Uh additionally also
4023.8s learn a good model. Okay. But this is
4026.5s not at all the same hallucination
4028.9s as you have with LLMs
4032.2s because LLMs they have multiple sources
4034.3s of hallucination. One of them is
4036.6s that they are not grounded in the real
4038.2s world. Okay. With world models you can
4040.3s easily fix that. Uh additionally also
4043.7s you can say that with world models you can have another source of hallucination as well is that um when you optimize the action sequence that you want to do uh for doing planning you could have um actions that are not meaningful for the real world. For instance let's say that your action is
4045.4s have another source of
4047.2s hallucination as well is that um when
4049.4s you optimize the action sequence that
4051.4s you want to do uh for doing planning you
4054.2s could have um action that are not
4056.5s meaningful for the real world. For
4058.6s instance let's say that your action is
4060.6s between minus one and one but your optimizer says okay you should use minus two as an action. So you can have this kind of hallucination otherwise it's just about the quality of your learned model and this is just a matter of data and capacity. Hopefully that makes sense.
4063.2s your optimizer says okay you should use
4065.6s minus two as an action. So you can
4068.2s have this kind of hallucination
4070.2s otherwise it's just about the quality of
4072.0s your learned model and this is just a
4073.9s matter of data and capacity. Hopefully
4077.0s that makes sense.
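The out-of-range action example above (valid actions in [-1, 1], optimizer proposing -2) is commonly handled by constraining the planner's samples to the valid range. The sketch below shows this with a small cross-entropy-method (CEM) planner. It is an illustration under assumptions, not the speakers' solver: `toy_predict`, `cost`, the goal, and all hyperparameters are hypothetical choices.

```python
import numpy as np

def toy_predict(z, a):
    """Hypothetical one-step latent dynamics."""
    return z + 0.1 * a

def final_latent(z0, actions):
    """Roll the model forward over an action sequence and return the last latent."""
    z = z0
    for a in actions:
        z = toy_predict(z, a)
    return z

def cem_plan(cost, z0, horizon, action_dim, low=-1.0, high=1.0,
             iters=10, pop=200, n_elite=20, seed=0):
    """CEM over action sequences, with every sample clipped to [low, high]."""
    rng = np.random.default_rng(seed)
    mean = np.zeros((horizon, action_dim))
    std = np.ones((horizon, action_dim))
    for _ in range(iters):
        samples = mean + std * rng.standard_normal((pop, horizon, action_dim))
        samples = np.clip(samples, low, high)      # never score invalid actions
        costs = np.array([cost(final_latent(z0, seq)) for seq in samples])
        elite = samples[np.argsort(costs)[:n_elite]]
        mean, std = elite.mean(axis=0), elite.std(axis=0) + 1e-6
    return np.clip(mean, low, high)

goal = np.array([2.0, 0.0])      # deliberately unreachable within the horizon
cost = lambda z: float(np.sum((z - goal) ** 2))
z0 = np.zeros(2)

plan = cem_plan(cost, z0, horizon=5, action_dim=2)
print(plan.min(), plan.max(), cost(final_latent(z0, plan)))
```

Because the goal cannot be reached in five steps, an unconstrained optimizer would be rewarded for actions far outside [-1, 1]; clipping keeps the plan inside the range the model was trained on.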
4078.5s Yeah. No great I think that makes a lot of sense. Thanks. Um we have someone asking are world models better than the diffusion based models for robot control. previous one. the world model can still use the diffusion model. you could use diffusion model to train JEPA. Actually, some people try to do
4080.2s of sense. Thanks. Um we have someone
4083.4s asking are world models better than the
4085.3s diffusion based models for robot
4087.3s control.
4094.3s previous one. the world model can still
4097.0s use the diffusion model.
4103.3s you could use diffusion model to train
4104.9s JEPA. Actually, some people try to do
4106.6s that. Um, diffusion models have some advantages and some limitations as well. Um, so yeah. Um, any in-person questions? All right, we got one at the back. kind of a longer horizon question, but you mentioned that you might consider planning through this whole family of different algorithms developed for control like
4108.7s diffusion models have some
4110.6s advantages and some limitations as well.
4112.9s Um, so yeah.
4115.4s Um, any in-person questions?
4119.4s All right, we got one at the back.
4122.1s kind of a longer horizon question, but
4123.8s you mentioned that you might consider
4126.2s planning through this whole family of
4128.1s different
4129.6s algorithms developed for control like
4131.6s model predictive control etc. And you mentioned that it might be possible to use these things as kind of a grounding signal to train policies. Do you think that it's kind of a combining those approaches if it maybe that's what you already meant by training policies though of turning up a lot of compute on
4133.2s mentioned that it might be possible to
4135.1s use these things as kind of a grounding
4136.6s signal to train policies. Do you think
4138.4s that it's kind of a combining those
4140.2s approaches if it maybe that's what you
4141.4s already meant by training policies
4142.9s though of turning up a lot of compute on
4144.8s something like an autoative controller for a relatively long horizon and then distilling those back down into policies might be an efficient way at getting some longer horizon behavior for one of those problems and/or research directions you mentioned
4146.1s for a relatively long horizon and then
4147.7s distilling those back down into policies
4149.3s might be an efficient way at getting
4150.7s some longer horizon behavior for one of
4152.6s those problems and/or research directions
4154.4s you mentioned
4162.2s I can go okay yeah okay yeah okay um yeah you can definitely train a policy on that because as soon as you have a world model basically you can do either of two things you can use that directly and optimize uh your action sequence to have a zero shot policy or you can do the same as Dreamer for
4164.0s okay yeah okay yeah okay um yeah you can
4166.6s definitely train a policy on that because
4169.0s as soon as you have a world model basically
4170.6s you can do either of two things you can use
4172.5s that directly and optimize uh your
4175.1s action sequence to have a zero shot policy
4178.3s or you can do the same as Dreamer for
4180.3s instance and use your pre-trained world model and train reinforcement learning policy on that and periodically go in your environment and collect better data fine tune your model improve your policy and blah blah blah you can you can do that as well so um in the future I think you will have something similar to
4181.9s model and train reinforcement learning
4184.4s policy on that and periodically go in
4187.0s your environment and collect better data
4189.9s fine tune your model improve your policy
4192.6s and blah blah blah you can you can do
4193.8s that as well so um in the future I think
4196.5s you will have something similar to
4198.6s system one, system two same as humans do where you will have um a condensed like basically action reaction policy where you will learn to directly predict what is the action I should take to have that goal given this state. Okay. And then for difficult tasks that you are not very confident about
4200.9s do where you will have um a condensed
4203.9s like basically action reaction policy
4206.3s where you will learn to directly predict
4208.9s what is the action I should take to
4210.7s have that goal given this state. Okay.
4213.3s And then for difficult tasks that
4215.2s you are not very confident about
4217.5s planning and model predictive control to have more careful um plan and situation. For instance, uh when you want to learn to drive a car, maybe for the first 20 hours, you will use model predictive control to make sure you don't kill someone. And then after 20 hours, you think you're an expert. So um you can
4219.5s have more careful um plan and situation.
4222.5s For instance, uh when you want to learn
4224.3s to drive a car, maybe for the first
4226.7s 20 hours, you will use model predictive
4228.6s control to make sure you don't kill
4230.0s someone. And then after 20 hours, you
4232.2s think you're an expert. So um you can
4235.2s basically distill uh your prediction of model predictive control and um planning to a direct policy that outputs directly the action. Right. Great. Unfortunately, that's all the time we have. Um, Hazel will be staying after in case anybody has any questions or wants to talk to her. Um, so thanks again to our speakers for the
4237.9s model predictive control and um planning
4242.3s to a direct policy that outputs directly
4245.1s the action.
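The distillation idea in this answer (plan slowly through the world model, then train a fast policy to output the planner's action directly) can be sketched as plain supervised learning on planner outputs. Everything below is a toy assumption for illustration: a 1-D state, the hypothetical dynamics `toy_predict`, a one-step grid-search planner standing in for model predictive control, and a linear policy fitted by least squares.

```python
import numpy as np

CANDIDATES = np.linspace(-1.0, 1.0, 201)   # valid actions on a fine grid

def toy_predict(z, a):
    """Hypothetical one-step world model on a 1-D state."""
    return z + 0.1 * a

def plan_first_action(z, goal):
    """Slow 'system two' step: search candidate actions through the model."""
    outcomes = toy_predict(z, CANDIDATES)
    return CANDIDATES[np.argmin(np.abs(outcomes - goal))]

# 1) Collect (state, goal) -> action pairs from the planner.
rng = np.random.default_rng(0)
states = rng.uniform(-1.0, 1.0, size=500)
goals = states + rng.uniform(-0.1, 0.1, size=500)   # goals reachable in one step
expert_actions = np.array([plan_first_action(z, g)
                           for z, g in zip(states, goals)])

# 2) Distill: fit a direct policy a = w . [state, goal, 1] by least squares.
features = np.stack([states, goals, np.ones_like(states)], axis=1)
weights, *_ = np.linalg.lstsq(features, expert_actions, rcond=None)

def policy(z, goal):
    """Fast 'system one' policy: one dot product, no search."""
    return float(np.clip(np.array([z, goal, 1.0]) @ weights, -1.0, 1.0))

print(plan_first_action(0.2, 0.25), policy(0.2, 0.25))
```

In practice the planner would search over multi-step action sequences and the policy would be a neural network, but the loop is the same: the planner supplies labels, the policy learns to reproduce them, and planning stays available for states where the policy is not yet trusted.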
4245.9s Right. Great. Unfortunately, that's all
4247.3s the time we have. Um, Hazel will be
4248.9s staying after in case anybody has any
4251.5s questions or wants to talk to her. Um,
4253.6s so thanks again to our speakers for the
4255.9s very insightful talk. Thank you.