I want to understand you

but I feel frustrated
because I don't know how to

but it's okay
because since you're a brain
I just need to feed you your data
and you'll understand yourself

and since you're within me
if I listen closely, I will as well

## How to update multiple tensors using a single value with tf.scan

Corresponding Jupyter Notebook

I assume you have a set of tensors that you want to update iteratively from a sequence. For example, you have a neural network that you'd like to update with the point at time t in a sequence and the network's values from time t-1. If you want to see this in full-fledged use, look at my Jupyter notebook where I recreate the Variational Recurrent Neural Network!

This is the definition of scan:

tf.scan(
    fn,
    elems,
    initializer=None,
    parallel_iterations=10,
    back_prop=True,
    swap_memory=False,
    infer_shape=True,
    name=None
)


fn should follow the form fn(parameter_that_changes, parameter_you_change_with). This means the current item from elems always arrives as parameter_you_change_with, and whatever you return becomes the next parameter_that_changes. Written as a function, it looks something like this:

def fn(x, elem):
    new_x = ...  # compute the updated value from x and elem
    return new_x


where new_x will be x the next time fn is called. That took me some time to figure out.
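To make the calling convention concrete, here is a minimal plain-Python sketch of the loop that scan conceptually performs. `scan_sketch` is a hypothetical helper written for illustration, not TensorFlow's implementation; it ignores tensors, graphs, and options such as parallel_iterations.

```python
def scan_sketch(fn, elems, initializer):
    """Plain-Python model of scan: thread an accumulator through elems.

    Hypothetical helper for illustration only; not part of TensorFlow.
    """
    outputs = []
    x = initializer          # plays the role of parameter_that_changes
    for elem in elems:       # elem plays the role of parameter_you_change_with
        x = fn(x, elem)      # the return value becomes x on the next call
        outputs.append(x)    # every intermediate value is kept
    return outputs


def running_sum(x, elem):
    return x + elem


history = scan_sketch(running_sum, [1, 2, 3], initializer=0)
print(history)  # [1, 3, 6]
```

The point to notice is that fn never sees the whole sequence, only the previous result and one element at a time.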

import tensorflow as tf
import numpy as np

# Super simple scan example, as per: https://stackoverflow.com/questions/43841782/scan-function-in-theano-and-tensorflow
def f(x, ys):
    # x is the running value; ys is the current (a[i], b[i]) pair
    (y1, y2) = ys
    return x + y1 * y2

a = tf.constant([1, 2, 3, 4, 5])
b = tf.constant([2, 3, 2, 2, 1])
c = tf.scan(f, (a, b), initializer=0)
with tf.Session() as sess:
    print(sess.run(c))

[ 2  8 14 22 27]
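As a sanity check on those numbers, the same recurrence can be reproduced without TensorFlow. This sketch uses `itertools.accumulate` from the standard library (the `initial` argument needs Python 3.8+) as a stand-in for scan. It is an analogy, not how tf.scan is implemented. Unlike scan, accumulate also emits the initial value, so that first item is dropped.

```python
from itertools import accumulate

a = [1, 2, 3, 4, 5]
b = [2, 3, 2, 2, 1]


def f(x, ys):
    # x is the running value; ys is the current (a[i], b[i]) pair
    y1, y2 = ys
    return x + y1 * y2


# accumulate yields the initial value first; scan does not, so skip it
running = list(accumulate(zip(a, b), f, initial=0))[1:]
print(running)  # [2, 8, 14, 22, 27]
```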

# updating 3 tensors with a single sequence
a1 = tf.Variable([0,0])
a2 = tf.Variable([1,1])
a3 = tf.Variable([2,2])

sequence = tf.Variable([1,2,3])

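# Using tf.multiply instead of '*', e.g. tf.multiply(x, 2) instead of 2*x, was key to this compiling...
def replace_one(old, x):
    a1, a2, a3 = old
    # Each tensor grows by a different multiple of the current element x
    a1 = tf.add(a1, tf.multiply(x, 1))
    a2 = tf.add(a2, tf.multiply(x, 2))
    a3 = tf.add(a3, tf.multiply(x, 3))
    return [a1, a2, a3]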

# Key thing that worked: the initializer needs to match the structure of what fn returns.
# An easy mistake that I can see tripping up many people.
update = tf.scan(replace_one, sequence, initializer=[a1, a2, a3])

# Assigning outside of scan works; this op is never run, so it does not affect the output
a1 = a1.assign(a2)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(update))


[array([[1, 1],
       [3, 3],
       [6, 6]], dtype=int32), array([[ 3,  3],
       [ 7,  7],
       [13, 13]], dtype=int32), array([[ 5,  5],
       [11, 11],
       [20, 20]], dtype=int32)]


# A few notes

So this was more difficult to implement than I expected. I had to get all the ingredients exactly right.

• While I can assign outside of scan, the tensors a1, a2, a3 couldn't be assigned inside of scan, i.e. a1.assign(tf.add(a1, tf.multiply(x, 1))) did not work there. Most likely this is because fn receives plain tensors rather than tf.Variable objects.
• You might think you could keep all your values inside a single tensor for the initializer and update them via indexing. This also doesn't work: with T = tf.concat([a1, a2, a3], axis=0), you can't do T[0] = x, because tensors don't support item assignment.
• I spent a long time trying to manually concatenate the values so that I could track them later, only to learn that scan does this by default! E.g., for a1, the corresponding output is [a1+1, a1+1+2, a1+1+2+3], since the elements were [1, 2, 3].
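The stacking behaviour from the last note can be modelled in a few lines of plain Python. This is a sketch under assumptions: `scan_multi` is a hypothetical helper, the update rule (each accumulator grows by 1, 2, or 3 times the element) is inferred from the printed output above, and scalars stand in for the length-2 tensors.

```python
def scan_multi(fn, elems, initializer):
    """Model of scan with a list-valued accumulator (hypothetical helper).

    Returns one history list per accumulator, mirroring how scan stacks
    each component of the state along a new leading axis.
    """
    state = list(initializer)
    histories = [[] for _ in state]
    for elem in elems:
        state = fn(state, elem)
        # record the new value of every accumulator after this step
        for history, value in zip(histories, state):
            history.append(value)
    return histories


def step(old, x):
    a1, a2, a3 = old
    # assumed update rule: the k-th accumulator grows by k * x
    return [a1 + 1 * x, a2 + 2 * x, a3 + 3 * x]


h1, h2, h3 = scan_multi(step, [1, 2, 3], initializer=[0, 1, 2])
print(h1)  # [1, 3, 6]
print(h2)  # [3, 7, 13]
print(h3)  # [5, 11, 20]
```

Because the full history comes back for free, the final state is just the last entry of each list, so there is no need to track or concatenate anything by hand.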