Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit0069857

Browse files
committed
remove einops exts for better pytorch 2.0 compile compatibility
1 parent580274b commit0069857

File tree

4 files changed

+39
-23
lines changed

4 files changed

+39
-23
lines changed

‎dalle2_pytorch/dalle2_pytorch.py‎

Lines changed: 35 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,8 @@
1212
fromtorchimportnn,einsum
1313
importtorchvision.transformsasT
1414

15-
fromeinopsimportrearrange,repeat,reduce
15+
fromeinopsimportrearrange,repeat,reduce,pack,unpack
1616
fromeinops.layers.torchimportRearrange
17-
fromeinops_extsimportrearrange_many,repeat_many,check_shape
18-
fromeinops_exts.torchimportEinopsToAndFrom
1917

2018
fromkornia.filtersimportgaussian_blur2d
2119
importkornia.augmentationasK
@@ -669,6 +667,23 @@ def p2_reweigh_loss(self, loss, times):
669667
returnloss
670668
returnloss*extract(self.p2_loss_weight,times,loss.shape)
671669

670+
# rearrange image to sequence
671+
672+
classRearrangeToSequence(nn.Module):
673+
def__init__(self,fn):
674+
super().__init__()
675+
self.fn=fn
676+
677+
defforward(self,x):
678+
x=rearrange(x,'b c ... -> b ... c')
679+
x,ps=pack([x],'b * c')
680+
681+
x=self.fn(x)
682+
683+
x,=unpack(x,ps,'b * c')
684+
x=rearrange(x,'b ... c -> b c ...')
685+
returnx
686+
672687
# diffusion prior
673688

674689
classLayerNorm(nn.Module):
@@ -867,7 +882,7 @@ def forward(self, x, mask = None, attn_bias = None):
867882

868883
# add null key / value for classifier free guidance in prior net
869884

870-
nk,nv=repeat_many(self.null_kv.unbind(dim=-2),'d -> b 1 d',b=b)
885+
nk,nv=map(lambdat:repeat(t,'d -> b 1 d',b=b),self.null_kv.unbind(dim=-2))
871886
k=torch.cat((nk,k),dim=-2)
872887
v=torch.cat((nv,v),dim=-2)
873888

@@ -1629,14 +1644,10 @@ def __init__(
16291644
self.cross_attn=None
16301645

16311646
ifexists(cond_dim):
1632-
self.cross_attn=EinopsToAndFrom(
1633-
'b c h w',
1634-
'b (h w) c',
1635-
CrossAttention(
1636-
dim=dim_out,
1637-
context_dim=cond_dim,
1638-
cosine_sim=cosine_sim_cross_attn
1639-
)
1647+
self.cross_attn=CrossAttention(
1648+
dim=dim_out,
1649+
context_dim=cond_dim,
1650+
cosine_sim=cosine_sim_cross_attn
16401651
)
16411652

16421653
self.block1=Block(dim,dim_out,groups=groups,weight_standardization=weight_standardization)
@@ -1655,8 +1666,15 @@ def forward(self, x, time_emb = None, cond = None):
16551666

16561667
ifexists(self.cross_attn):
16571668
assertexists(cond)
1669+
1670+
h=rearrange(h,'b c ... -> b ... c')
1671+
h,ps=pack([h],'b * c')
1672+
16581673
h=self.cross_attn(h,context=cond)+h
16591674

1675+
h,=unpack(h,ps,'b * c')
1676+
h=rearrange(h,'b ... c -> b c ...')
1677+
16601678
h=self.block2(h)
16611679
returnh+self.res_conv(x)
16621680

@@ -1702,11 +1720,11 @@ def forward(self, x, context, mask = None):
17021720

17031721
q,k,v= (self.to_q(x),*self.to_kv(context).chunk(2,dim=-1))
17041722

1705-
q,k,v=rearrange_many((q,k,v),'b n (h d) -> b h n d',h=self.heads)
1723+
q,k,v=map(lambdat:rearrange(t,'b n (h d) -> b h n d',h=self.heads), (q,k,v))
17061724

17071725
# add null key / value for classifier free guidance in prior net
17081726

1709-
nk,nv=repeat_many(self.null_kv.unbind(dim=-2),'d -> b h 1 d',h=self.heads,b=b)
1727+
nk,nv=map(lambdat:repeat(t,'d -> b h 1 d',h=self.heads,b=b),self.null_kv.unbind(dim=-2))
17101728

17111729
k=torch.cat((nk,k),dim=-2)
17121730
v=torch.cat((nv,v),dim=-2)
@@ -1759,7 +1777,7 @@ def forward(self, fmap):
17591777

17601778
fmap=self.norm(fmap)
17611779
q,k,v=self.to_qkv(fmap).chunk(3,dim=1)
1762-
q,k,v=rearrange_many((q,k,v),'b (h c) x y -> (b h) (x y) c',h=h)
1780+
q,k,v=map(lambdat:rearrange(t,'b (h c) x y -> (b h) (x y) c',h=h), (q,k,v))
17631781

17641782
q=q.softmax(dim=-1)
17651783
k=k.softmax(dim=-2)
@@ -1993,7 +2011,7 @@ def __init__(
19932011

19942012
self_attn=cast_tuple(self_attn,num_stages)
19952013

1996-
create_self_attn=lambdadim:EinopsToAndFrom('b c h w','b (h w) c',Residual(Attention(dim,**attn_kwargs)))
2014+
create_self_attn=lambdadim:RearrangeToSequence(Residual(Attention(dim,**attn_kwargs)))
19972015

19982016
# resnet block klass
19992017

@@ -3230,7 +3248,7 @@ def forward(
32303248
learned_variance=self.learned_variance[unet_index]
32313249
b,c,h,w,device,=*image.shape,image.device
32323250

3233-
check_shape(image,'b c h w',c=self.channels)
3251+
assertimage.shape[1]==self.channels
32343252
asserth>=target_image_sizeandw>=target_image_size
32353253

32363254
times=torch.randint(0,noise_scheduler.num_timesteps, (b,),device=device,dtype=torch.long)

‎dalle2_pytorch/version.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__='1.12.4'
1+
__version__='1.14.0'

‎dalle2_pytorch/vqgan_vae.py‎

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@
1111
fromtorch.autogradimportgradastorch_grad
1212
importtorchvision
1313

14-
fromeinopsimportrearrange,reduce,repeat
15-
fromeinops_extsimportrearrange_many
14+
fromeinopsimportrearrange,reduce,repeat,pack,unpack
1615
fromeinops.layers.torchimportRearrange
1716

1817
# constants
@@ -408,7 +407,7 @@ def forward(self, x):
408407
x=self.norm(x)
409408

410409
q,k,v=self.to_qkv(x).chunk(3,dim=-1)
411-
q,k,v=rearrange_many((q,k,v),'b n (h d) -> b h n d',h=h)
410+
q,k,v=map(lambdat:rearrange(t,'b n (h d) -> b h n d',h=h), (q,k,v))
412411

413412
q=q*self.scale
414413
sim=einsum('b h i d, b h j d -> b h i j',q,k)

‎setup.py‎

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,7 @@
3030
'clip-anytorch>=2.5.2',
3131
'coca-pytorch>=0.0.5',
3232
'ema-pytorch>=0.0.7',
33-
'einops>=0.4',
34-
'einops-exts>=0.0.3',
33+
'einops>=0.6',
3534
'embedding-reader',
3635
'kornia>=0.5.4',
3736
'numpy',

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp