Transcript
Hi. My talk is about characterizing test-time compute on graph-structured problems. Most of my Scholars project has been spent thinking about the question of whether we can create models that continuously improve their outputs the more compute we give them at test time. This is something I'll call the test-time compute dream, and I think there's a lot of anthropomorphic motivation here: after all, as humans, when we're being evaluated our answers tend to become better the longer we're given to think.
Machine learning models, for the most part, don't exhibit this ability, which seems a little weird. I tend to bucket this test-time compute work into two general categories. The first is generalization improvement mechanisms, which deal with the question of how we can create models that use test-time compute to learn more general algorithms instead of simple statistical associations in the data; ideally we'd like these models to use the extra compute to resolve ambiguity and to correct and refine their own answers. The second side of this coin is the efficiency stuff, which deals with the question of how we can decouple the number of parameters a model has from the amount of time it takes to run the model at inference, the motivation being to construct models that are larger without incurring a larger computational cost for those extra parameters.

Okay, so how did we tackle this question?
The overwhelming majority of this project was actually spent on something I'll only spend a single slide on, in the interest of time, and that's the shortest path task. Shortest path is a sequence-to-sequence modeling task in which I give the model a pair of tokens representing a pair of U.S. cities, and I expect it to output a sequence of target tokens that represents the shortest path between those destinations. The stuff I'll mostly be presenting only really took shape in the past three or four weeks, and it involved investigating some of these test-time properties with graph neural networks operating over the game of Sudoku.
Like I said, most of my project was spent on the shortest path work, where we were trying to answer this question: if we control the total FLOP budget of our models, is there ever a point where the test-time performance of models like the one you see on the left, which use this sort of top-layer recurrence, begins to approach or match the performance of models that don't have this recurrence but are larger or have been trained for longer? The way we did this was by keeping the training compute budget, in terms of FLOPs, fixed for all the models, training the recurrent models with a fixed number of time steps with the loss evaluated at every single time step, and then at test time evaluating them with more steps of recurrence, to see whether the extra compute ever lets them, in some sense, catch up to the larger models that were trained without this recurrence.
Long story short, it largely doesn't seem to work: we never really see this sort of phase transition, and recurrence alone doesn't seem to be enough. To be clear, if you run a linear probe on the embedding space of these models, they actually seem to learn something like the locations of the cities, or something at least isometric to those locations, fairly quickly, which indicates that the problem isn't actually learning where the cities are.
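As a concrete illustration of what such a probe can look like (a minimal sketch with placeholder file names and a ridge regression, not necessarily the exact probe used in the project), you can fit a linear map from the frozen city-token embeddings to the known city coordinates and check how much variance it explains:

```python
import numpy as np
from sklearn.linear_model import Ridge
from sklearn.metrics import r2_score

# Placeholder inputs: hidden states for each city token taken from the frozen
# model, and the known 2-D coordinates of those cities.
embeddings = np.load("city_embeddings.npy")   # shape (num_cities, hidden_dim)
coords = np.load("city_coords.npy")           # shape (num_cities, 2)

probe = Ridge(alpha=1.0).fit(embeddings, coords)   # purely linear readout
print("probe R^2:", r2_score(coords, probe.predict(embeddings)))
# A high R^2 suggests the embeddings encode something (at least) isometric to
# the true city locations, even though the model is never trained to output them.
```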
It seems that even with the extra recurrence, the extra compute, learning a general shortest path algorithm is difficult. Recurrence alone doesn't seem to be enough; we need additional structure on top of it, which is where the graph neural network stuff comes in.
Graph neural networks are networks that operate on graph-structured data. There are a few main parts. The first is the input representation phase, where you pass in your graph-structured data: x here represents the nodes in your graph, which contain the features you care about; these could be the locations of U.S. cities or the values of cells on a Sudoku board. A represents the adjacencies, which encode some notion of the edges of the graph, in other words what relationships nodes have with each other. The GNN processes this graph by iteratively performing a learned message passing operation between the nodes, in which it attempts to refine its internal representation of those nodes. At the end of this refinement phase we can run classification tasks on the individual nodes, or, if we aggregate the nodes, on the entire graph.
A key feature of these GNNs is the graph refinement equation, which I'll come back to at least twice in this presentation. It looks wild in its general form, but it really has just three parts: it says that the hidden state for a node i is updated by a function that takes in the node embedding for that node together with all pairs formed with that node's neighbors, each passed through some function and then combined using your favorite aggregation function.
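Written out, a standard form of that update (my reconstruction from the description above, with psi as the pairwise message function, the circled plus as the aggregation operator, and phi as the node update) is:

```latex
h_i^{(t+1)} \;=\; \phi\!\left( h_i^{(t)},\; \bigoplus_{j \in \mathcal{N}(i)} \psi\!\left( h_i^{(t)},\, h_j^{(t)} \right) \right)
```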
Okay, so how do we do this for Sudoku? Every cell on the Sudoku board corresponds to a node on the graph, and the nodes on this graph refine their representations by passing messages to themselves and to their neighbors using that graph refinement equation we just saw.
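As a sketch of what that graph looks like (my own minimal construction, not necessarily the exact one used in the project), the 81 cells become nodes, and two cells are adjacent whenever they share a row, a column, or a 3x3 box:

```python
import numpy as np

# Build the 81x81 Sudoku constraint adjacency: cells are connected if they
# share a row, a column, or the same 3x3 box.
A = np.zeros((81, 81), dtype=np.float32)
for i in range(81):
    for j in range(81):
        if i == j:
            continue
        ri, ci, rj, cj = i // 9, i % 9, j // 9, j % 9
        same_row = ri == rj
        same_col = ci == cj
        same_box = (ri // 3 == rj // 3) and (ci // 3 == cj // 3)
        if same_row or same_col or same_box:
            A[i, j] = 1.0

print(A.sum(axis=1)[0])  # each cell has 20 neighbors: 8 in its row, 8 in its column, 4 more in its box
```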
Now, what's typically done is that you run this graph refinement phase a fixed number of times, say ten time steps, and then at the very end you run your linear projection and make a prediction. What we do a little differently here is that we make a prediction at every point along the graph refinement phase, and we evaluate the loss at every single point as well. This makes the model more robust to being evaluated with more graph refinement iterations at test time than it was trained with.
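Here is a minimal sketch of that per-step-loss idea, assuming a PyTorch-style message-passing update gnn_step and a shared linear readout; the names are placeholders rather than the project's actual code:

```python
import torch.nn.functional as F

def refine_and_predict(h0, A, gnn_step, readout, n_steps):
    """Run n_steps of message passing, emitting logits after every step."""
    h, logits_per_step = h0, []
    for _ in range(n_steps):
        h = gnn_step(h, A)                  # one graph refinement iteration
        logits_per_step.append(readout(h))  # shared linear projection each step
    return logits_per_step

def per_step_loss(logits_per_step, targets):
    # The loss is summed over every refinement step, not just the final one,
    # which is what lets the model tolerate extra iterations at test time.
    return sum(F.cross_entropy(l.reshape(-1, l.size(-1)), targets.reshape(-1))
               for l in logits_per_step)

# At test time, the same model can simply be run for more steps than it was
# trained with, e.g. refine_and_predict(h0, A, gnn_step, readout, n_steps=64).
```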
So how does this actually look in practice? Here's the model solving a Sudoku puzzle; this is real data, by the way. What's cool is that it appears to prioritize spending the extra compute on attending to and refining the cells it had assigned low probability, high uncertainty, in the previous time steps: the red cells become green, and the green cells stay green.

This is encouraging, because it's a sign that the test-time compute dream is at least in principle possible. If we look at this graph, which shows the GNN operating over two data sets, one normal and one hard, we see generalization in two different senses. First, as we increase the number of iterations, the test-time compute, the accuracy of the network improves in an almost monotonically increasing way; accuracy here is measured at the sequence level, meaning I only count a board if the model gets the entire board correct. Second, if we give the network problems that are harder than the ones it was trained on, it still performs well.
So if the argument here is that more test-time compute, more iterations, is good, what would happen if we could evaluate this model at infinite depth? In other words, could we do better? To answer this question we need to borrow the machinery of deep equilibrium models. I don't have a whole lot of time to go into the details, but I suggest you check out the paper by Shaojie Bai or the NeurIPS workshop from this past year. The gist is that deep equilibrium models are inspired by the observation that we can often rewrite a standard neural network as an implicit function: instead of specifying explicitly how to compute a layer's output as a function of its input, we specify the conditions we would like the layer's output to satisfy. After rewriting these layers as implicit functions, it turns out that most of them converge to a fixed point, which means that instead of keeping track of the intermediate steps of that graph refinement phase in our autograd library, we can use an arbitrary black-box root-finding algorithm to evaluate this convergence point. This is equivalent to running an infinite-depth, weight-tied feedforward network, but it has the notable advantage that we can analytically backpropagate through the equilibrium point using the implicit function theorem.
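Concretely, and as my own summary of the standard deep equilibrium formulation rather than anything specific to this project, the layer is defined by a fixed-point condition, and the implicit function theorem gives the gradient without storing any intermediate iterates:

```latex
z^{*} = f_{\theta}(z^{*}, x),
\qquad
\frac{\partial \ell}{\partial \theta}
  = \frac{\partial \ell}{\partial z^{*}}
    \left( I - \frac{\partial f_{\theta}}{\partial z}\Big|_{z = z^{*}} \right)^{-1}
    \frac{\partial f_{\theta}}{\partial \theta}
```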
How is this relevant to GNNs? Well, if you take a look at that graph refinement equation from earlier, it looks exactly like a fixed-point equation, which means we can apply the machinery of deep equilibrium networks here. If you try this out it actually works really well, with a big caveat, related to early stopping, that I'll get to in the next slide.
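For intuition, here's a minimal sketch of what evaluating the GNN at its equilibrium can look like. It uses naive fixed-point iteration for the forward pass and one differentiable application at the equilibrium as a cheap stand-in for the full implicit-function-theorem backward pass; all names are placeholders, and this is not the project's actual implementation:

```python
import torch

def deq_gnn_forward(h0, A, gnn_step, max_iter=100, tol=1e-4):
    """Find an approximate fixed point h* = gnn_step(h*, A) of the refinement map."""
    # Run the solver outside the autograd graph, so memory stays constant in depth.
    with torch.no_grad():
        h = h0
        for _ in range(max_iter):
            h_next = gnn_step(h, A)
            if torch.norm(h_next - h) < tol * (torch.norm(h) + 1e-8):
                h = h_next
                break
            h = h_next
    # One differentiable application at the equilibrium. The full DEQ machinery
    # would instead backpropagate through the fixed point with the implicit
    # function theorem; this one-step approximation keeps the sketch short.
    return gnn_step(h, A)
```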
These early training curves are preliminary but kind of dramatic: the deep equilibrium GNN trains a lot faster than the traditional GNN. Further, because we're using the deep equilibrium machinery to save us from keeping track of the intermediate steps of that graph refinement phase in our autograd library, the memory usage of the deep equilibrium model is smaller than the regular one as well.

So what's the caveat? Well, as far as I can tell, every single time I've run this I've run into a weird collapse: training starts out doing really great, and then it dies, and I haven't quite been able to figure out why. I suspect it has to do with the growth of the spectral norm of the operators inside the GNN as it's being evaluated by the fixed-point iterator, but it could also just be a bug in my code. Stopping training early when this degeneracy begins has proven to be fine, and I'm still investigating the problem, but I wanted to point it out for completeness.
Shifting gears a little bit: can we do better in still another way? GNNs are fine; as we've seen, they seem to do well on these relational-reasoning-style tasks. But one potential oddity is that we must be explicit about the structure of the graph, that is, we must explicitly tell the network which nodes are connected to which other nodes. For Sudoku, for instance, we must explicitly say that cells in the same row, cells in the same column, and cells in the same 3x3 box are connected. Could we instead learn the adjacencies from scratch, from the raw unstructured data? Here's the idea.
Transformers seem to be pretty good at learning how relevant pairs of tokens are to each other; GNNs, on the other hand, are good at operating over structured data. What if we could use the attention head from a standard transformer to extract an adjacency matrix, which we then feed into the GNN? Here's how it works. We first feed our input to a small transformer, with the modification that at the top layer we use the attention probability scores to categorically sample the k indices most relevant to each particular token. That extracts k neighborhoods for each token, which we can then feed into our GNN.
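A rough sketch of that extraction step, assuming we already have the top-layer attention probabilities for each token over every other token; the choice of torch.multinomial with replacement as the sampler is mine, not necessarily what the project used:

```python
import torch

def sample_adjacency(attn_probs, k):
    """attn_probs: (n, n) attention probabilities from the transformer's top layer.
    Returns a sampled 0/1 adjacency matrix and the log-probability of the sample."""
    n = attn_probs.size(0)
    # k draws per token, weighted by attention; duplicate draws simply collapse
    # into the same edge, so each token ends up with at most k neighbors.
    neighbors = torch.multinomial(attn_probs, num_samples=k, replacement=True)  # (n, k)
    A = torch.zeros(n, n, device=attn_probs.device)
    A.scatter_(1, neighbors, 1.0)
    # Log-probability of the sampled indices, needed by the surrogate loss below.
    log_prob = torch.log(torch.gather(attn_probs, 1, neighbors) + 1e-9).sum()
    return A, log_prob
```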
Now, sampling indices is a non-differentiable operation. We can compensate for this by using the surrogate loss formulation outlined below, taken from a paper by John Schulman and coauthors, which provides a general framework for gradient estimation through stochastic computation graphs. The formalism gives us a way to convert a stochastic computation graph into a deterministic one and, using standard backpropagation, evaluate a surrogate loss that provides an unbiased estimator of the gradient through the stochastic node.
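In code, the score-function form of that surrogate is only a couple of lines. Continuing the hypothetical names from the sampling sketch above, and with the caveat that this is a generic REINFORCE-style surrogate rather than the exact loss used here:

```python
import torch.nn.functional as F

# attn_probs comes from the transformer's top layer; gnn, node_features, and
# targets are placeholders for the downstream model and labels.
A, log_prob = sample_adjacency(attn_probs, k=8)
logits = gnn(node_features, A)

task_loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
surrogate = task_loss + log_prob * task_loss.detach()
surrogate.backward()
# The first term trains the GNN as usual; the gradient of the second term,
# grad(log_prob) * task_loss, is the score-function estimator that nudges the
# transformer toward sampling adjacencies with lower downstream loss.
# Subtracting a baseline from task_loss.detach() would reduce its (high) variance.
```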
If you try this out, it works, kind of. The reality is that it trains much slower than the standard GNN, vanilla policy gradients are high variance and kind of messy, and the performance is actually worse than the standard GNN. But it does show that, in principle, we can train a GNN from scratch that learns the adjacencies from scratch as well, which is kind of cool.
In conclusion: test-time compute mechanisms, I think, are largely underexplored but hold a lot of promise. They have the potential for improved generalization and improved sample efficiency, and I think recurrence plus message passing is a really interesting combination. If the methods in this presentation seem contrived, that's because they are; the specific methods are kind of crude, but I'm bullish on the idea of test-time compute in general, and I think that in the next few years we'll see critical breakthroughs built on ideas that have test-time compute at their core.

That's it. I'd like to thank my mentor, Will Guss, as well as the program organizers, my cohort, and all the people who gave me early feedback on this work. Thank you, and now I'll take questions. Let's see, let me stop sharing.
Okay, the first question is: how would you extend this GNN setting to sequence modeling, like the language modeling loss in GPT? You could imagine doing this in a couple of ways, for example an autoregressive setup where at each step you evaluate the state of the entire graph, output a sequence, and then feed that output sequence back into the beginning of the model and run it again. Doing this autoregressively is one way, and I'm sure there are other ways I'm just not familiar with.

Next: what types of problems do you expect test-time compute to really shine in? This is a great question.
Sorry, my dog gets really excited here. I think ultimately test-time compute will shine in problems with these relational-reasoning-style tasks, where we need to relate our previous outputs to the things we're currently processing, or problems where we need to condition the amount of compute we do on the complexity of our inputs.
Next: does the stochastic compute graph mean that GNNs can be applied to settings without inductive biases? That's the ultimate hope. This is very crude, early work that shows you could potentially also learn the adjacencies without hand-baking in the inductive bias, though part of the appeal of graph neural networks is that it's so easy to bake in inductive biases: you just feed in the graph as it is, and that is the inductive bias for your data. So there's definitely a trade-off here, and it's not super clear that doing this is always the right thing to do.

Okay, last question: what does it look like if you threshold the learned adjacency weights to produce a discrete graph structure? Is that roughly right, threshold the learned adjacency weights?
Right, that's a good point. These are already discrete; that's the whole point of the sampling step: the adjacencies are, for each token, the indices of the other tokens near it. This isn't like attention, where we take a softmax over the other output tokens; we use the transformer's attention probability scores as the weights for discrete sampling, if that makes sense. But yeah, cool. I think I'm over time, so I'll hand it back over to Francis, I guess.