How to update multiple tensors using a single value with tf.scan
30 Jun 2017 | tutorial
I assume that you have a set of tensors that you want to update iteratively with a sequence. For example, you have a neural network that you'd like to update with the point at time t in a sequence and the network's values from time t-1. If you want to see this in full-fledged use, look at my Jupyter notebook where I recreate the Variational Recurrent Neural Network!
This is the definition of scan:
tf.scan( fn, elems, initializer=None, parallel_iterations=10, back_prop=True, swap_memory=False, infer_shape=True, name=None )
fn should follow the form fn(parameter_that_changes, parameter_you_change_with). This means that you can assume that your input from elems will always go to parameter_you_change_with, and that what you return should be the new value of parameter_that_changes. Writing it like a function looks something like the following:
    def fn(x, elem):
        new_x = x + elem  # placeholder: any update built from x and elem
        return new_x
new_x will be x the next time fn is called. That took me some time to figure out.
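The calling convention described above can be sketched in plain Python, without TensorFlow. This is only an illustration of the feedback loop, not how tf.scan is implemented; scan_sketch and running_sum are hypothetical names made up for this example.

```python
def scan_sketch(fn, elems, initializer):
    """Plain-Python stand-in for the accumulate-and-collect loop of a scan."""
    acc = initializer
    outputs = []
    for elem in elems:
        acc = fn(acc, elem)   # the returned value becomes x on the next call
        outputs.append(acc)   # every intermediate value is kept
    return outputs


def running_sum(x, elem):
    new_x = x + elem
    return new_x


result = scan_sketch(running_sum, [1, 2, 3], initializer=0)
print(result)  # [1, 3, 6]
```

The first argument of fn always carries the previous return value (or the initializer on the first call), and the second argument is the current element of the sequence.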
    import tensorflow as tf
    import numpy as np

    # Super Simple scan example as per: https://stackoverflow.com/questions/43841782/scan-function-in-theano-and-tensorflow
    def f(x, ys):
        (y1, y2) = ys
        return x + y1 * y2

    a = tf.constant([1, 2, 3, 4, 5])
    b = tf.constant([2, 3, 2, 2, 1])
    c = tf.scan(f, (a, b), initializer=0)

    with tf.Session() as sess:
        print(sess.run(c))
[ 2 8 14 22 27]
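To see where those numbers come from, here is a minimal sketch of the same recurrence using only the standard library. It assumes Python 3.8+ (for the initial argument of itertools.accumulate) and uses plain lists in place of tensors, so it only mirrors the arithmetic, not TensorFlow itself.

```python
from itertools import accumulate

a = [1, 2, 3, 4, 5]
b = [2, 3, 2, 2, 1]


def f(x, ys):
    y1, y2 = ys          # one element from a, one from b
    return x + y1 * y2


# zip pairs the two sequences element by element, which mirrors how the
# tuple (a, b) is handed to f one slice at a time.
# accumulate yields the initial value first, so that entry is dropped.
steps = list(accumulate(zip(a, b), f, initial=0))[1:]
print(steps)  # [2, 8, 14, 22, 27]
```

Each entry is the previous entry plus the product of the current pair: 0 + 1*2 = 2, then 2 + 2*3 = 8, and so on.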
    # updating 3 tensors with a single sequence
    a1 = tf.Variable([0, 0])
    a2 = tf.Variable([1, 1])
    a3 = tf.Variable([2, 2])
    sequence = tf.Variable([1, 2, 3])

    # using tf.multiply instead of '*', e.g. tf.multiply(x,2) instead of 2*x was key to this compiling...
    def replace_one(old, x):
        a1, a2, a3 = old
        a1 = tf.add(a1, tf.multiply(x, 1))
        a2 = tf.add(a2, tf.multiply(x, 2))
        a3 = tf.add(a3, tf.multiply(x, 3))
        return [a1, a2, a3]

    # key things that worked: initializer needed to match output.
    # dumb mistake I can see tripping up many people
    update = tf.scan(replace_one, sequence, initializer=[a1, a2, a3])

    # assigning outside of scan builds fine; this op is never run below
    a1 = a1.assign(a2)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(update))
[array([[1, 1], [3, 3], [6, 6]], dtype=int32), array([[ 3, 3], [ 7, 7], [13, 13]], dtype=int32), array([[ 5, 5], [11, 11], [20, 20]], dtype=int32)]
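The structure of that output (one stacked array per accumulator) can be reproduced with a small plain-Python emulation. This is a sketch under the assumption that nested lists stand in for tensors; scan_multi is a hypothetical helper written for this example, not a TensorFlow function.

```python
def replace_one(old, x):
    a1, a2, a3 = old
    # element-wise updates on plain lists standing in for the tensors
    a1 = [v + x * 1 for v in a1]
    a2 = [v + x * 2 for v in a2]
    a3 = [v + x * 3 for v in a3]
    return [a1, a2, a3]          # same structure as the initializer


def scan_multi(fn, elems, initializer):
    state = initializer
    history = [[] for _ in initializer]   # one stacked output per accumulator
    for x in elems:
        state = fn(state, x)
        for stacked, value in zip(history, state):
            stacked.append(value)
    return history


out1, out2, out3 = scan_multi(replace_one, [1, 2, 3], [[0, 0], [1, 1], [2, 2]])
print(out1)  # [[1, 1], [3, 3], [6, 6]]
print(out2)  # [[3, 3], [7, 7], [13, 13]]
print(out3)  # [[5, 5], [11, 11], [20, 20]]
```

Because the function returns a list of three values, the initializer has to be a list of three values as well, and the result comes back as three histories rather than one.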
A few notes
So this was more difficult to implement than I expected. I had to get all the ingredients exactly right.
- While I can assign outside of scan, for some reason the tensors a1, a2, a3 couldn't be assigned inside of scan, i.e. a1.assign(tf.add(a1,tf.multiply(x,1))) did not work there.
- You might try keeping all your values inside a single tensor for the initializer and updating them via indexing. This also doesn't work, i.e. with T=tf.concat([a1,a2,a3]), you can't update individual entries of T by index.
- I spent a long time trying to manually concatenate the values so that I could track them in the future, only to learn that scan does this by default! E.g., for a1, the corresponding output vector is [a1+1, a1+1+2, a1+1+2+3], since the elements were 1, 2, 3.
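The stacked outputs follow a simple pattern: each row is the starting value plus its multiplier times the running sum of the sequence. Here is a small check of that pattern in plain Python. It is a sketch that tracks one component per variable (both components of each variable are equal in the example above), and the names starts and factors are made up for illustration.

```python
from itertools import accumulate

sequence = [1, 2, 3]
starts = [0, 1, 2]      # initial values of a1, a2, a3 (one component each)
factors = [1, 2, 3]     # multipliers used in replace_one

running = list(accumulate(sequence))          # [1, 3, 6]
predicted = [[start + k * s for s in running]
             for start, k in zip(starts, factors)]
print(predicted)  # [[1, 3, 6], [3, 7, 13], [5, 11, 20]]

# The state after the last step is simply the last entry of each history.
final_state = [row[-1] for row in predicted]
print(final_state)  # [6, 13, 20]
```

So if only the final values are needed, taking the last row of each stacked output is enough; no manual concatenation is required.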