@@ -182,8 +182,6 @@ struct SpatialTransformer {
182182
183183 std::vector<Transformer> transformers;
184184
185- struct ggml_tensor * attn_scale;
186-
187185// proj_out
188186struct ggml_tensor * proj_out_w;// [in_channels, in_channels, 1, 1]
189187struct ggml_tensor * proj_out_b;// [in_channels,]
@@ -202,7 +200,6 @@ struct SpatialTransformer {
202200 mem_size +=2 * in_channels *ggml_type_sizef (GGML_TYPE_F32);// norm_w/norm_b
203201 mem_size +=2 * in_channels * in_channels *1 *1 *ggml_type_sizef (GGML_TYPE_F16);// proj_in_w/proj_out_w
204202 mem_size +=2 * in_channels *ggml_type_sizef (GGML_TYPE_F32);// proj_in_b/proj_out_b
205- mem_size +=1 *ggml_type_sizef (GGML_TYPE_F32);// attn_scale
206203
207204// transformer
208205for (auto & transformer : transformers) {
@@ -226,11 +223,6 @@ struct SpatialTransformer {
226223 proj_out_w =ggml_new_tensor_4d (ctx, GGML_TYPE_F16,1 ,1 , in_channels, in_channels);
227224 proj_out_b =ggml_new_tensor_1d (ctx, GGML_TYPE_F32, in_channels);
228225
229- attn_scale =ggml_new_tensor_1d (ctx, GGML_TYPE_F32,1 );
230- ggml_allocr_alloc (alloc, attn_scale);
231- float scale =1 .0f /sqrt ((float )d_head);
232- ggml_backend_tensor_set (attn_scale, &scale,0 ,sizeof (scale));
233-
234226// transformer
235227for (auto & transformer : transformers) {
236228 transformer.norm1_w =ggml_new_tensor_1d (ctx, GGML_TYPE_F32, in_channels);
@@ -332,7 +324,7 @@ struct SpatialTransformer {
332324 x =ggml_reshape_2d (ctx, x, c, h * w * n);// [N * h * w, in_channels]
333325struct ggml_tensor * q =ggml_mul_mat (ctx, transformer.attn1_q_w , x);// [N * h * w, in_channels]
334326#if !defined(SD_USE_FLASH_ATTENTION) || defined(SD_USE_CUBLAS) || defined(SD_USE_METAL)
335- q =ggml_scale_inplace (ctx, q,attn_scale );
327+ q =ggml_scale_inplace (ctx, q,1 . 0f / sqrt (( float )d_head) );
336328#endif
337329 q =ggml_reshape_4d (ctx, q, d_head, n_head, h * w, n);// [N, h * w, n_head, d_head]
338330 q =ggml_cont (ctx,ggml_permute (ctx, q,0 ,2 ,1 ,3 ));// [N, n_head, h * w, d_head]
@@ -380,7 +372,7 @@ struct SpatialTransformer {
380372 context =ggml_reshape_2d (ctx, context, context->ne [0 ], context->ne [1 ] * context->ne [2 ]);// [N * max_position, hidden_size]
381373struct ggml_tensor * q =ggml_mul_mat (ctx, transformer.attn2_q_w , x);// [N * h * w, in_channels]
382374#if !defined(SD_USE_FLASH_ATTENTION) || defined(SD_USE_CUBLAS) || defined(SD_USE_METAL)
383- q =ggml_scale_inplace (ctx, q,attn_scale );
375+ q =ggml_scale_inplace (ctx, q,1 . 0f / sqrt (( float )d_head) );
384376#endif
385377 q =ggml_reshape_4d (ctx, q, d_head, n_head, h * w, n);// [N, h * w, n_head, d_head]
386378 q =ggml_cont (ctx,ggml_permute (ctx, q,0 ,2 ,1 ,3 ));// [N, n_head, h * w, d_head]