Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit29445be

Browse files
authored
Merge pull request#21880 from shaod2/py-25
Manually backport Pure Python recursion limit enforcement to 25.x
2 parents88a3b90 +cc13b69 commit29445be

File tree

6 files changed

+176
-27
lines changed

6 files changed

+176
-27
lines changed

‎python/build_targets.bzl‎

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,12 @@ def build_targets(name):
335335
data= ["//src/google/protobuf:testdata"],
336336
)
337337

338+
internal_py_test(
339+
name="decoder_test",
340+
srcs= ["google/protobuf/internal/decoder_test.py"],
341+
data= ["//src/google/protobuf:testdata"],
342+
)
343+
338344
internal_py_test(
339345
name="proto_builder_test",
340346
srcs= ["google/protobuf/internal/proto_builder_test.py"],

‎python/google/protobuf/internal/decoder.py‎

Lines changed: 85 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,10 @@ def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default,
172172
clear_if_default=False):
173173
ifis_packed:
174174
local_DecodeVarint=_DecodeVarint
175-
defDecodePackedField(buffer,pos,end,message,field_dict):
175+
defDecodePackedField(
176+
buffer,pos,end,message,field_dict,current_depth=0
177+
):
178+
delcurrent_depth# unused
176179
value=field_dict.get(key)
177180
ifvalueisNone:
178181
value=field_dict.setdefault(key,new_default(message))
@@ -191,7 +194,10 @@ def DecodePackedField(buffer, pos, end, message, field_dict):
191194
elifis_repeated:
192195
tag_bytes=encoder.TagBytes(field_number,wire_type)
193196
tag_len=len(tag_bytes)
194-
defDecodeRepeatedField(buffer,pos,end,message,field_dict):
197+
defDecodeRepeatedField(
198+
buffer,pos,end,message,field_dict,current_depth=0
199+
):
200+
delcurrent_depth# unused
195201
value=field_dict.get(key)
196202
ifvalueisNone:
197203
value=field_dict.setdefault(key,new_default(message))
@@ -208,7 +214,8 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
208214
returnnew_pos
209215
returnDecodeRepeatedField
210216
else:
211-
defDecodeField(buffer,pos,end,message,field_dict):
217+
defDecodeField(buffer,pos,end,message,field_dict,current_depth=0):
218+
delcurrent_depth# unused
212219
(new_value,pos)=decode_value(buffer,pos)
213220
ifpos>end:
214221
raise_DecodeError('Truncated message.')
@@ -352,7 +359,9 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default,
352359
enum_type=key.enum_type
353360
ifis_packed:
354361
local_DecodeVarint=_DecodeVarint
355-
defDecodePackedField(buffer,pos,end,message,field_dict):
362+
defDecodePackedField(
363+
buffer,pos,end,message,field_dict,current_depth=0
364+
):
356365
"""Decode serialized packed enum to its value and a new position.
357366
358367
Args:
@@ -365,6 +374,7 @@ def DecodePackedField(buffer, pos, end, message, field_dict):
365374
Returns:
366375
int, new position in serialized data.
367376
"""
377+
delcurrent_depth# unused
368378
value=field_dict.get(key)
369379
ifvalueisNone:
370380
value=field_dict.setdefault(key,new_default(message))
@@ -405,7 +415,9 @@ def DecodePackedField(buffer, pos, end, message, field_dict):
405415
elifis_repeated:
406416
tag_bytes=encoder.TagBytes(field_number,wire_format.WIRETYPE_VARINT)
407417
tag_len=len(tag_bytes)
408-
defDecodeRepeatedField(buffer,pos,end,message,field_dict):
418+
defDecodeRepeatedField(
419+
buffer,pos,end,message,field_dict,current_depth=0
420+
):
409421
"""Decode serialized repeated enum to its value and a new position.
410422
411423
Args:
@@ -418,6 +430,7 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
418430
Returns:
419431
int, new position in serialized data.
420432
"""
433+
delcurrent_depth# unused
421434
value=field_dict.get(key)
422435
ifvalueisNone:
423436
value=field_dict.setdefault(key,new_default(message))
@@ -446,7 +459,7 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
446459
returnnew_pos
447460
returnDecodeRepeatedField
448461
else:
449-
defDecodeField(buffer,pos,end,message,field_dict):
462+
defDecodeField(buffer,pos,end,message,field_dict,current_depth=0):
450463
"""Decode serialized repeated enum to its value and a new position.
451464
452465
Args:
@@ -459,6 +472,7 @@ def DecodeField(buffer, pos, end, message, field_dict):
459472
Returns:
460473
int, new position in serialized data.
461474
"""
475+
delcurrent_depth# unused
462476
value_start_pos=pos
463477
(enum_value,pos)=_DecodeSignedVarint32(buffer,pos)
464478
ifpos>end:
@@ -540,7 +554,10 @@ def _ConvertToUnicode(memview):
540554
tag_bytes=encoder.TagBytes(field_number,
541555
wire_format.WIRETYPE_LENGTH_DELIMITED)
542556
tag_len=len(tag_bytes)
543-
defDecodeRepeatedField(buffer,pos,end,message,field_dict):
557+
defDecodeRepeatedField(
558+
buffer,pos,end,message,field_dict,current_depth=0
559+
):
560+
delcurrent_depth# unused
544561
value=field_dict.get(key)
545562
ifvalueisNone:
546563
value=field_dict.setdefault(key,new_default(message))
@@ -557,7 +574,8 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
557574
returnnew_pos
558575
returnDecodeRepeatedField
559576
else:
560-
defDecodeField(buffer,pos,end,message,field_dict):
577+
defDecodeField(buffer,pos,end,message,field_dict,current_depth=0):
578+
delcurrent_depth# unused
561579
(size,pos)=local_DecodeVarint(buffer,pos)
562580
new_pos=pos+size
563581
ifnew_pos>end:
@@ -581,7 +599,10 @@ def BytesDecoder(field_number, is_repeated, is_packed, key, new_default,
581599
tag_bytes=encoder.TagBytes(field_number,
582600
wire_format.WIRETYPE_LENGTH_DELIMITED)
583601
tag_len=len(tag_bytes)
584-
defDecodeRepeatedField(buffer,pos,end,message,field_dict):
602+
defDecodeRepeatedField(
603+
buffer,pos,end,message,field_dict,current_depth=0
604+
):
605+
delcurrent_depth# unused
585606
value=field_dict.get(key)
586607
ifvalueisNone:
587608
value=field_dict.setdefault(key,new_default(message))
@@ -598,7 +619,8 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
598619
returnnew_pos
599620
returnDecodeRepeatedField
600621
else:
601-
defDecodeField(buffer,pos,end,message,field_dict):
622+
defDecodeField(buffer,pos,end,message,field_dict,current_depth=0):
623+
delcurrent_depth# unused
602624
(size,pos)=local_DecodeVarint(buffer,pos)
603625
new_pos=pos+size
604626
ifnew_pos>end:
@@ -623,7 +645,9 @@ def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
623645
tag_bytes=encoder.TagBytes(field_number,
624646
wire_format.WIRETYPE_START_GROUP)
625647
tag_len=len(tag_bytes)
626-
defDecodeRepeatedField(buffer,pos,end,message,field_dict):
648+
defDecodeRepeatedField(
649+
buffer,pos,end,message,field_dict,current_depth=0
650+
):
627651
value=field_dict.get(key)
628652
ifvalueisNone:
629653
value=field_dict.setdefault(key,new_default(message))
@@ -632,7 +656,13 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
632656
ifvalueisNone:
633657
value=field_dict.setdefault(key,new_default(message))
634658
# Read sub-message.
635-
pos=value.add()._InternalParse(buffer,pos,end)
659+
current_depth+=1
660+
ifcurrent_depth>_recursion_limit:
661+
raise_DecodeError(
662+
'Error parsing message: too many levels of nesting.'
663+
)
664+
pos=value.add()._InternalParse(buffer,pos,end,current_depth)
665+
current_depth-=1
636666
# Read end tag.
637667
new_pos=pos+end_tag_len
638668
ifbuffer[pos:new_pos]!=end_tag_bytesornew_pos>end:
@@ -644,12 +674,16 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
644674
returnnew_pos
645675
returnDecodeRepeatedField
646676
else:
647-
defDecodeField(buffer,pos,end,message,field_dict):
677+
defDecodeField(buffer,pos,end,message,field_dict,current_depth=0):
648678
value=field_dict.get(key)
649679
ifvalueisNone:
650680
value=field_dict.setdefault(key,new_default(message))
651681
# Read sub-message.
652-
pos=value._InternalParse(buffer,pos,end)
682+
current_depth+=1
683+
ifcurrent_depth>_recursion_limit:
684+
raise_DecodeError('Error parsing message: too many levels of nesting.')
685+
pos=value._InternalParse(buffer,pos,end,current_depth)
686+
current_depth-=1
653687
# Read end tag.
654688
new_pos=pos+end_tag_len
655689
ifbuffer[pos:new_pos]!=end_tag_bytesornew_pos>end:
@@ -668,7 +702,9 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
668702
tag_bytes=encoder.TagBytes(field_number,
669703
wire_format.WIRETYPE_LENGTH_DELIMITED)
670704
tag_len=len(tag_bytes)
671-
defDecodeRepeatedField(buffer,pos,end,message,field_dict):
705+
defDecodeRepeatedField(
706+
buffer,pos,end,message,field_dict,current_depth=0
707+
):
672708
value=field_dict.get(key)
673709
ifvalueisNone:
674710
value=field_dict.setdefault(key,new_default(message))
@@ -679,18 +715,27 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
679715
ifnew_pos>end:
680716
raise_DecodeError('Truncated message.')
681717
# Read sub-message.
682-
ifvalue.add()._InternalParse(buffer,pos,new_pos)!=new_pos:
718+
current_depth+=1
719+
ifcurrent_depth>_recursion_limit:
720+
raise_DecodeError(
721+
'Error parsing message: too many levels of nesting.'
722+
)
723+
if (
724+
value.add()._InternalParse(buffer,pos,new_pos,current_depth)
725+
!=new_pos
726+
):
683727
# The only reason _InternalParse would return early is if it
684728
# encountered an end-group tag.
685729
raise_DecodeError('Unexpected end-group tag.')
686730
# Predict that the next tag is another copy of the same repeated field.
731+
current_depth-=1
687732
pos=new_pos+tag_len
688733
ifbuffer[new_pos:pos]!=tag_bytesornew_pos==end:
689734
# Prediction failed. Return.
690735
returnnew_pos
691736
returnDecodeRepeatedField
692737
else:
693-
defDecodeField(buffer,pos,end,message,field_dict):
738+
defDecodeField(buffer,pos,end,message,field_dict,current_depth=0):
694739
value=field_dict.get(key)
695740
ifvalueisNone:
696741
value=field_dict.setdefault(key,new_default(message))
@@ -699,11 +744,14 @@ def DecodeField(buffer, pos, end, message, field_dict):
699744
new_pos=pos+size
700745
ifnew_pos>end:
701746
raise_DecodeError('Truncated message.')
702-
# Read sub-message.
703-
ifvalue._InternalParse(buffer,pos,new_pos)!=new_pos:
747+
current_depth+=1
748+
ifcurrent_depth>_recursion_limit:
749+
raise_DecodeError('Error parsing message: too many levels of nesting.')
750+
ifvalue._InternalParse(buffer,pos,new_pos,current_depth)!=new_pos:
704751
# The only reason _InternalParse would return early is if it encountered
705752
# an end-group tag.
706753
raise_DecodeError('Unexpected end-group tag.')
754+
current_depth-=1
707755
returnnew_pos
708756
returnDecodeField
709757

@@ -859,7 +907,8 @@ def MapDecoder(field_descriptor, new_default, is_message_map):
859907
# Can't read _concrete_class yet; might not be initialized.
860908
message_type=field_descriptor.message_type
861909

862-
defDecodeMap(buffer,pos,end,message,field_dict):
910+
defDecodeMap(buffer,pos,end,message,field_dict,current_depth=0):
911+
delcurrent_depth# unused
863912
submsg=message_type._concrete_class()
864913
value=field_dict.get(key)
865914
ifvalueisNone:
@@ -941,8 +990,16 @@ def _SkipGroup(buffer, pos, end):
941990
returnpos
942991
pos=new_pos
943992

993+
DEFAULT_RECURSION_LIMIT=100
994+
_recursion_limit=DEFAULT_RECURSION_LIMIT
995+
996+
997+
defSetRecursionLimit(new_limit):
998+
global_recursion_limit
999+
_recursion_limit=new_limit
1000+
9441001

945-
def_DecodeUnknownFieldSet(buffer,pos,end_pos=None):
1002+
def_DecodeUnknownFieldSet(buffer,pos,end_pos=None,current_depth=0):
9461003
"""Decode UnknownFieldSet. Returns the UnknownFieldSet and new position."""
9471004

9481005
unknown_field_set=containers.UnknownFieldSet()
@@ -952,14 +1009,14 @@ def _DecodeUnknownFieldSet(buffer, pos, end_pos=None):
9521009
field_number,wire_type=wire_format.UnpackTag(tag)
9531010
ifwire_type==wire_format.WIRETYPE_END_GROUP:
9541011
break
955-
(data,pos)=_DecodeUnknownField(buffer,pos,wire_type)
1012+
(data,pos)=_DecodeUnknownField(buffer,pos,wire_type,current_depth)
9561013
# pylint: disable=protected-access
9571014
unknown_field_set._add(field_number,wire_type,data)
9581015

9591016
return (unknown_field_set,pos)
9601017

9611018

962-
def_DecodeUnknownField(buffer,pos,wire_type):
1019+
def_DecodeUnknownField(buffer,pos,wire_type,current_depth=0):
9631020
"""Decode a unknown field. Returns the UnknownField and new position."""
9641021

9651022
ifwire_type==wire_format.WIRETYPE_VARINT:
@@ -973,7 +1030,11 @@ def _DecodeUnknownField(buffer, pos, wire_type):
9731030
data=buffer[pos:pos+size].tobytes()
9741031
pos+=size
9751032
elifwire_type==wire_format.WIRETYPE_START_GROUP:
976-
(data,pos)=_DecodeUnknownFieldSet(buffer,pos)
1033+
current_depth+=1
1034+
ifcurrent_depth>=_recursion_limit:
1035+
raise_DecodeError('Error parsing message: too many levels of nesting.')
1036+
(data,pos)=_DecodeUnknownFieldSet(buffer,pos,None,current_depth)
1037+
current_depth-=1
9771038
elifwire_type==wire_format.WIRETYPE_END_GROUP:
9781039
return (0,-1)
9791040
else:
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# -*- coding: utf-8 -*-
2+
# Protocol Buffers - Google's data interchange format
3+
# Copyright 2008 Google Inc. All rights reserved.
4+
#
5+
# Use of this source code is governed by a BSD-style
6+
# license that can be found in the LICENSE file or at
7+
# https://developers.google.com/open-source/licenses/bsd
8+
9+
"""Test decoder."""
10+
11+
importunittest
12+
13+
fromgoogle.protobufimportmessage
14+
fromgoogle.protobuf.internalimportdecoder
15+
fromgoogle.protobuf.internalimporttesting_refleaks
16+
fromgoogle.protobuf.internalimportwire_format
17+
18+
@testing_refleaks.TestCase
19+
classDecoderTest(unittest.TestCase):
20+
deftest_decode_unknown_group_field_too_many_levels(self):
21+
data=memoryview(b'\023'*5_000_000)
22+
self.assertRaisesRegex(
23+
message.DecodeError,
24+
'Error parsing message',
25+
decoder._DecodeUnknownField,
26+
data,
27+
1,
28+
wire_format.WIRETYPE_START_GROUP,
29+
1
30+
)
31+
32+
if__name__=='__main__':
33+
unittest.main()
34+

‎python/google/protobuf/internal/message_test.py‎

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,11 @@
3030

3131
fromgoogle.protobuf.internalimportapi_implementation# pylint: disable=g-import-not-at-top
3232
fromgoogle.protobuf.internalimportencoder
33+
fromgoogle.protobuf.internalimportdecoder
3334
fromgoogle.protobuf.internalimportmore_extensions_pb2
3435
fromgoogle.protobuf.internalimportmore_messages_pb2
3536
fromgoogle.protobuf.internalimportpacked_field_test_pb2
37+
fromgoogle.protobuf.internalimportself_recursive_pb2
3638
fromgoogle.protobuf.internalimporttest_proto3_optional_pb2
3739
fromgoogle.protobuf.internalimporttest_util
3840
fromgoogle.protobuf.internalimporttesting_refleaks
@@ -1261,6 +1263,35 @@ def __eq__(self, other):
12611263
self.assertNotEqual(ComparesWithFoo(),m)
12621264

12631265

1266+
@testing_refleaks.TestCase
1267+
classTestRecursiveGroup(unittest.TestCase):
1268+
1269+
def_MakeRecursiveGroupMessage(self,n):
1270+
msg=self_recursive_pb2.SelfRecursive.RecursiveGroup()
1271+
sub=msg
1272+
for_inrange(n):
1273+
sub=sub.sub_group
1274+
sub.i=1
1275+
returnmsg.SerializeToString()
1276+
1277+
deftestRecursiveGroups(self):
1278+
recurse_msg=self_recursive_pb2.SelfRecursive.RecursiveGroup()
1279+
data=self._MakeRecursiveGroupMessage(100)
1280+
recurse_msg.ParseFromString(data)
1281+
self.assertTrue(recurse_msg.HasField('sub_group'))
1282+
1283+
deftestRecursiveGroupsException(self):
1284+
ifapi_implementation.Type()!='python':
1285+
api_implementation._c_module.SetAllowOversizeProtos(False)
1286+
recurse_msg=self_recursive_pb2.SelfRecursive.RecursiveGroup()
1287+
data=self._MakeRecursiveGroupMessage(300)
1288+
withself.assertRaises(message.DecodeError)ascontext:
1289+
recurse_msg.ParseFromString(data)
1290+
self.assertIn('Error parsing message',str(context.exception))
1291+
ifapi_implementation.Type()=='python':
1292+
self.assertIn('too many levels of nesting',str(context.exception))
1293+
1294+
12641295
# Class to test proto2-only features (required, extensions, etc.)
12651296
@testing_refleaks.TestCase
12661297
classProto2Test(unittest.TestCase):

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp