@@ -172,6 +172,49 @@ get_zstd_state(PyObject *module)
172
172
return (_zstd_state * )state ;
173
173
}
174
174
175
+ static Py_ssize_t
176
+ calculate_samples_stats (PyBytesObject * samples_bytes ,PyObject * samples_sizes ,
177
+ size_t * * chunk_sizes )
178
+ {
179
+ Py_ssize_t chunks_number ;
180
+ Py_ssize_t sizes_sum ;
181
+ Py_ssize_t i ;
182
+
183
+ chunks_number = Py_SIZE (samples_sizes );
184
+ if ((size_t )chunks_number > UINT32_MAX ) {
185
+ PyErr_Format (PyExc_ValueError ,
186
+ "The number of samples should be <= %u." ,UINT32_MAX );
187
+ return -1 ;
188
+ }
189
+
190
+ /* Prepare chunk_sizes */
191
+ * chunk_sizes = PyMem_New (size_t ,chunks_number );
192
+ if (* chunk_sizes == NULL ) {
193
+ PyErr_NoMemory ();
194
+ return -1 ;
195
+ }
196
+
197
+ sizes_sum = 0 ;
198
+ for (i = 0 ;i < chunks_number ;i ++ ) {
199
+ PyObject * size = PyTuple_GetItem (samples_sizes ,i );
200
+ (* chunk_sizes )[i ]= PyLong_AsSize_t (size );
201
+ if ((* chunk_sizes )[i ]== (size_t )-1 && PyErr_Occurred ()) {
202
+ PyErr_Format (PyExc_ValueError ,
203
+ "Items in samples_sizes should be an int "
204
+ "object, with a value between 0 and %u." ,SIZE_MAX );
205
+ return -1 ;
206
+ }
207
+ sizes_sum += (* chunk_sizes )[i ];
208
+ }
209
+
210
+ if (sizes_sum != Py_SIZE (samples_bytes )) {
211
+ PyErr_SetString (PyExc_ValueError ,
212
+ "The samples size tuple doesn't match the concatenation's size." );
213
+ return -1 ;
214
+ }
215
+ return chunks_number ;
216
+ }
217
+
175
218
176
219
/*[clinic input]
177
220
_zstd.train_dict
@@ -192,54 +235,25 @@ _zstd_train_dict_impl(PyObject *module, PyBytesObject *samples_bytes,
192
235
PyObject * samples_sizes ,Py_ssize_t dict_size )
193
236
/*[clinic end generated code: output=8e87fe43935e8f77 input=d20dedb21c72cb62]*/
194
237
{
195
- // TODO(emmatyping): The preamble and suffix to this function and _finalize_dict
196
- // are pretty similar. We should see if we can refactor them to share that code.
197
- Py_ssize_t chunks_number ;
198
- size_t * chunk_sizes = NULL ;
199
238
PyObject * dst_dict_bytes = NULL ;
239
+ size_t * chunk_sizes = NULL ;
240
+ Py_ssize_t chunks_number ;
200
241
size_t zstd_ret ;
201
- Py_ssize_t sizes_sum ;
202
- Py_ssize_t i ;
203
242
204
243
/* Check arguments */
205
244
if (dict_size <=0 ) {
206
245
PyErr_SetString (PyExc_ValueError ,"dict_size argument should be positive number." );
207
246
return NULL ;
208
247
}
209
248
210
- chunks_number = Py_SIZE (samples_sizes );
211
- if ((size_t )chunks_number > UINT32_MAX ) {
212
- PyErr_Format (PyExc_ValueError ,
213
- "The number of samples should be <= %u." ,UINT32_MAX );
249
+ /* Check that the samples are valid and get their sizes */
250
+ chunks_number = calculate_samples_stats (samples_bytes ,samples_sizes ,
251
+ & chunk_sizes );
252
+ if (chunks_number < 0 )
253
+ {
214
254
return NULL ;
215
255
}
216
256
217
- /* Prepare chunk_sizes */
218
- chunk_sizes = PyMem_New (size_t ,chunks_number );
219
- if (chunk_sizes == NULL ) {
220
- PyErr_NoMemory ();
221
- gotoerror ;
222
- }
223
-
224
- sizes_sum = 0 ;
225
- for (i = 0 ;i < chunks_number ;i ++ ) {
226
- PyObject * size = PyTuple_GetItem (samples_sizes ,i );
227
- chunk_sizes [i ]= PyLong_AsSize_t (size );
228
- if (chunk_sizes [i ]== (size_t )-1 && PyErr_Occurred ()) {
229
- PyErr_Format (PyExc_ValueError ,
230
- "Items in samples_sizes should be an int "
231
- "object, with a value between 0 and %u." ,SIZE_MAX );
232
- gotoerror ;
233
- }
234
- sizes_sum += chunk_sizes [i ];
235
- }
236
-
237
- if (sizes_sum != Py_SIZE (samples_bytes )) {
238
- PyErr_SetString (PyExc_ValueError ,
239
- "The samples size tuple doesn't match the concatenation's size." );
240
- gotoerror ;
241
- }
242
-
243
257
/* Allocate dict buffer */
244
258
dst_dict_bytes = PyBytes_FromStringAndSize (NULL ,dict_size );
245
259
if (dst_dict_bytes == NULL ) {
@@ -307,48 +321,21 @@ _zstd_finalize_dict_impl(PyObject *module, PyBytesObject *custom_dict_bytes,
307
321
PyObject * dst_dict_bytes = NULL ;
308
322
size_t zstd_ret ;
309
323
ZDICT_params_t params ;
310
- Py_ssize_t sizes_sum ;
311
- Py_ssize_t i ;
312
324
313
325
/* Check arguments */
314
326
if (dict_size <=0 ) {
315
327
PyErr_SetString (PyExc_ValueError ,"dict_size argument should be positive number." );
316
328
return NULL ;
317
329
}
318
330
319
- chunks_number = Py_SIZE (samples_sizes );
320
- if ((size_t )chunks_number > UINT32_MAX ) {
321
- PyErr_Format (PyExc_ValueError ,
322
- "The number of samples should be <= %u." ,UINT32_MAX );
331
+ /* Check that the samples are valid and get their sizes */
332
+ chunks_number = calculate_samples_stats (samples_bytes ,samples_sizes ,
333
+ & chunk_sizes );
334
+ if (chunks_number < 0 )
335
+ {
323
336
return NULL ;
324
337
}
325
338
326
- /* Prepare chunk_sizes */
327
- chunk_sizes = PyMem_New (size_t ,chunks_number );
328
- if (chunk_sizes == NULL ) {
329
- PyErr_NoMemory ();
330
- gotoerror ;
331
- }
332
-
333
- sizes_sum = 0 ;
334
- for (i = 0 ;i < chunks_number ;i ++ ) {
335
- PyObject * size = PyTuple_GetItem (samples_sizes ,i );
336
- chunk_sizes [i ]= PyLong_AsSize_t (size );
337
- if (chunk_sizes [i ]== (size_t )-1 && PyErr_Occurred ()) {
338
- PyErr_Format (PyExc_ValueError ,
339
- "Items in samples_sizes should be an int "
340
- "object, with a value between 0 and %u." ,SIZE_MAX );
341
- gotoerror ;
342
- }
343
- sizes_sum += chunk_sizes [i ];
344
- }
345
-
346
- if (sizes_sum != Py_SIZE (samples_bytes )) {
347
- PyErr_SetString (PyExc_ValueError ,
348
- "The samples size tuple doesn't match the concatenation's size." );
349
- gotoerror ;
350
- }
351
-
352
339
/* Allocate dict buffer */
353
340
dst_dict_bytes = PyBytes_FromStringAndSize (NULL ,dict_size );
354
341
if (dst_dict_bytes == NULL ) {