製作 while 迴圈

從 theano 0.9 開始,while 迴圈可以通過 theano.scan_module.scan_utils.until 完成。要使用,你應該在 scanfn 中返回 until 物件。

在下面的示例中,我們構建了一個函式,用於檢查複數是否在 Mandelbrot 集內。如果系列 z_{n+1} = z_{n}^2 + z_0 不收斂,則複數 z_0 在 mandelbrot 集內。

MAX_ITER = 256
BAILOUT = 2.
s_z0 = th.cscalar()
def iterate(s_i_, s_z_, s_z0_):
    return [s_z_*s_z_+s_z0_,s_i_+1], {}, until(T.abs_(s_z_)>BAILOUT)
(_1, s_niter), _2 = theano.scan(
    fn = iterate,
    outputs_info = [0, s_z0],
    non_sequences = [s_z0],
    n_steps = MAX_ITER
)
fn_mandelbrot_iters = theano.function([s_z0], s_niter)
def is_in_mandelbrot(z_):
    return fn_mandelbrot_iters(z_)>=MAX_ITER

is_in_mandelbrot(0.24+0.j) # True
is_in_mandelbrot(1.j) # True
is_in_mandelbrot(0.26+0.j) # False