 from torch import nn, einsum
 import torchvision.transforms as T
 
-from einops import rearrange, repeat, reduce
+from einops import rearrange, repeat, reduce, pack, unpack
 from einops.layers.torch import Rearrange
-from einops_exts import rearrange_many, repeat_many, check_shape
-from einops_exts.torch import EinopsToAndFrom
 
 from kornia.filters import gaussian_blur2d
 import kornia.augmentation as K
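
Aside, not part of the patch: pack and unpack ship with recent einops releases and cover what the dropped einops_exts helpers did here. A minimal round-trip sketch with illustrative shapes:

    import torch
    from einops import rearrange, pack, unpack

    x = torch.randn(2, 8, 16, 16)             # (b, c, h, w)
    x = rearrange(x, 'b c ... -> b ... c')    # channels last, any spatial rank
    x, ps = pack([x], 'b * c')                # flatten spatial axes -> (2, 256, 8)
    x, = unpack(x, ps, 'b * c')               # restore -> (2, 16, 16, 8)
    x = rearrange(x, 'b ... c -> b c ...')    # back to (2, 8, 16, 16)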
@@ -669,6 +667,23 @@ def p2_reweigh_loss(self, loss, times):
             return loss
         return loss * extract(self.p2_loss_weight, times, loss.shape)
 
+# rearrange image to sequence
+
+class RearrangeToSequence(nn.Module):
+    def __init__(self, fn):
+        super().__init__()
+        self.fn = fn
+
+    def forward(self, x):
+        x = rearrange(x, 'b c ... -> b ... c')
+        x, ps = pack([x], 'b * c')
+
+        x = self.fn(x)
+
+        x, = unpack(x, ps, 'b * c')
+        x = rearrange(x, 'b ... c -> b c ...')
+        return x
+
 # diffusion prior
 
 class LayerNorm(nn.Module):
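
Aside, not part of the patch: RearrangeToSequence lets any 'b n c' sequence module run over a channels-first feature map of arbitrary spatial rank, which is what frees the attention blocks below from a hardcoded 'b c h w' pattern. A usage sketch with an illustrative stand-in module:

    import torch
    from torch import nn

    layer = RearrangeToSequence(nn.Identity())   # stand-in for any 'b n c' module

    images = torch.randn(2, 64, 16, 16)          # (b, c, h, w)
    videos = torch.randn(2, 64, 4, 16, 16)       # (b, c, f, h, w) works unchanged
    assert layer(images).shape == images.shape
    assert layer(videos).shape == videos.shape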
@@ -867,7 +882,7 @@ def forward(self, x, mask = None, attn_bias = None):
 
         # add null key / value for classifier free guidance in prior net
 
-        nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b 1 d', b = b)
+        nk, nv = map(lambda t: repeat(t, 'd -> b 1 d', b = b), self.null_kv.unbind(dim = -2))
         k = torch.cat((nk, k), dim = -2)
         v = torch.cat((nv, v), dim = -2)
 
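
Aside, not part of the patch: repeat_many and rearrange_many were thin wrappers that mapped einops.repeat / einops.rearrange over several tensors, so a plain map is a drop-in replacement, as the hunks below apply throughout. A sketch with illustrative shapes:

    import torch
    from einops import repeat

    b = 4
    null_kv = torch.randn(2, 64)   # stacked null key and value, dim 64 each
    nk, nv = map(lambda t: repeat(t, 'd -> b 1 d', b = b), null_kv.unbind(dim = -2))
    assert nk.shape == (4, 1, 64) and nv.shape == (4, 1, 64)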
@@ -1629,14 +1644,10 @@ def __init__(
         self.cross_attn = None
 
         if exists(cond_dim):
-            self.cross_attn = EinopsToAndFrom(
-                'b c h w',
-                'b (h w) c',
-                CrossAttention(
-                    dim = dim_out,
-                    context_dim = cond_dim,
-                    cosine_sim = cosine_sim_cross_attn
-                )
+            self.cross_attn = CrossAttention(
+                dim = dim_out,
+                context_dim = cond_dim,
+                cosine_sim = cosine_sim_cross_attn
             )
 
         self.block1 = Block(dim, dim_out, groups = groups, weight_standardization = weight_standardization)
@@ -1655,8 +1666,15 @@ def forward(self, x, time_emb = None, cond = None):
 
         if exists(self.cross_attn):
             assert exists(cond)
+
+            h = rearrange(h, 'b c ... -> b ... c')
+            h, ps = pack([h], 'b * c')
+
             h = self.cross_attn(h, context = cond) + h
 
+            h, = unpack(h, ps, 'b * c')
+            h = rearrange(h, 'b ... c -> b c ...')
+
         h = self.block2(h)
         return h + self.res_conv(x)
 
@@ -1702,11 +1720,11 @@ def forward(self, x, context, mask = None):
 
         q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
 
-        q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = self.heads)
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))
 
         # add null key / value for classifier free guidance in prior net
 
-        nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b h 1 d', h = self.heads, b = b)
+        nk, nv = map(lambda t: repeat(t, 'd -> b h 1 d', h = self.heads, b = b), self.null_kv.unbind(dim = -2))
 
         k = torch.cat((nk, k), dim = -2)
         v = torch.cat((nv, v), dim = -2)
@@ -1759,7 +1777,7 @@ def forward(self, fmap):
 
         fmap = self.norm(fmap)
         q, k, v = self.to_qkv(fmap).chunk(3, dim = 1)
-        q, k, v = rearrange_many((q, k, v), 'b (h c) x y -> (b h) (x y) c', h = h)
+        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h = h), (q, k, v))
 
         q = q.softmax(dim = -1)
         k = k.softmax(dim = -2)
@@ -1993,7 +2011,7 @@ def __init__(
 
         self_attn = cast_tuple(self_attn, num_stages)
 
-        create_self_attn = lambda dim: EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(dim, **attn_kwargs)))
+        create_self_attn = lambda dim: RearrangeToSequence(Residual(Attention(dim, **attn_kwargs)))
 
         # resnet block klass
 
@@ -3230,7 +3248,7 @@ def forward(
         learned_variance = self.learned_variance[unet_index]
         b, c, h, w, device, = *image.shape, image.device
 
-        check_shape(image, 'b c h w', c = self.channels)
+        assert image.shape[1] == self.channels
         assert h >= target_image_size and w >= target_image_size
 
         times = torch.randint(0, noise_scheduler.num_timesteps, (b,), device = device, dtype = torch.long)
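
Aside, not part of the patch: check_shape validated the full 'b c h w' pattern, while the replacement assert only pins the channel count; the rank and spatial sizes are constrained by the surrounding unpacking and assert. If the full check were wanted without einops_exts, a hypothetical one-line helper would do:

    def check_image_shape(image, channels):
        # illustrative stand-in for check_shape(image, 'b c h w', c = channels)
        assert image.ndim == 4 and image.shape[1] == channels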