@@ -172,7 +172,10 @@ def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default,
172172clear_if_default = False ):
173173if is_packed :
174174local_DecodeVarint = _DecodeVarint
175- def DecodePackedField (buffer ,pos ,end ,message ,field_dict ):
175+ def DecodePackedField (
176+ buffer ,pos ,end ,message ,field_dict ,current_depth = 0
177+ ):
178+ del current_depth # unused
176179value = field_dict .get (key )
177180if value is None :
178181value = field_dict .setdefault (key ,new_default (message ))
@@ -191,7 +194,10 @@ def DecodePackedField(buffer, pos, end, message, field_dict):
191194elif is_repeated :
192195tag_bytes = encoder .TagBytes (field_number ,wire_type )
193196tag_len = len (tag_bytes )
194- def DecodeRepeatedField (buffer ,pos ,end ,message ,field_dict ):
197+ def DecodeRepeatedField (
198+ buffer ,pos ,end ,message ,field_dict ,current_depth = 0
199+ ):
200+ del current_depth # unused
195201value = field_dict .get (key )
196202if value is None :
197203value = field_dict .setdefault (key ,new_default (message ))
@@ -208,7 +214,8 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
208214return new_pos
209215return DecodeRepeatedField
210216else :
211- def DecodeField (buffer ,pos ,end ,message ,field_dict ):
217+ def DecodeField (buffer ,pos ,end ,message ,field_dict ,current_depth = 0 ):
218+ del current_depth # unused
212219 (new_value ,pos )= decode_value (buffer ,pos )
213220if pos > end :
214221raise _DecodeError ('Truncated message.' )
@@ -352,7 +359,9 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default,
352359enum_type = key .enum_type
353360if is_packed :
354361local_DecodeVarint = _DecodeVarint
355- def DecodePackedField (buffer ,pos ,end ,message ,field_dict ):
362+ def DecodePackedField (
363+ buffer ,pos ,end ,message ,field_dict ,current_depth = 0
364+ ):
356365"""Decode serialized packed enum to its value and a new position.
357366
358367 Args:
@@ -365,6 +374,7 @@ def DecodePackedField(buffer, pos, end, message, field_dict):
365374 Returns:
366375 int, new position in serialized data.
367376 """
377+ del current_depth # unused
368378value = field_dict .get (key )
369379if value is None :
370380value = field_dict .setdefault (key ,new_default (message ))
@@ -405,7 +415,9 @@ def DecodePackedField(buffer, pos, end, message, field_dict):
405415elif is_repeated :
406416tag_bytes = encoder .TagBytes (field_number ,wire_format .WIRETYPE_VARINT )
407417tag_len = len (tag_bytes )
408- def DecodeRepeatedField (buffer ,pos ,end ,message ,field_dict ):
418+ def DecodeRepeatedField (
419+ buffer ,pos ,end ,message ,field_dict ,current_depth = 0
420+ ):
409421"""Decode serialized repeated enum to its value and a new position.
410422
411423 Args:
@@ -418,6 +430,7 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
418430 Returns:
419431 int, new position in serialized data.
420432 """
433+ del current_depth # unused
421434value = field_dict .get (key )
422435if value is None :
423436value = field_dict .setdefault (key ,new_default (message ))
@@ -446,7 +459,7 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
446459return new_pos
447460return DecodeRepeatedField
448461else :
449- def DecodeField (buffer ,pos ,end ,message ,field_dict ):
462+ def DecodeField (buffer ,pos ,end ,message ,field_dict , current_depth = 0 ):
450463"""Decode serialized repeated enum to its value and a new position.
451464
452465 Args:
@@ -459,6 +472,7 @@ def DecodeField(buffer, pos, end, message, field_dict):
459472 Returns:
460473 int, new position in serialized data.
461474 """
475+ del current_depth # unused
462476value_start_pos = pos
463477 (enum_value ,pos )= _DecodeSignedVarint32 (buffer ,pos )
464478if pos > end :
@@ -540,7 +554,10 @@ def _ConvertToUnicode(memview):
540554tag_bytes = encoder .TagBytes (field_number ,
541555wire_format .WIRETYPE_LENGTH_DELIMITED )
542556tag_len = len (tag_bytes )
543- def DecodeRepeatedField (buffer ,pos ,end ,message ,field_dict ):
557+ def DecodeRepeatedField (
558+ buffer ,pos ,end ,message ,field_dict ,current_depth = 0
559+ ):
560+ del current_depth # unused
544561value = field_dict .get (key )
545562if value is None :
546563value = field_dict .setdefault (key ,new_default (message ))
@@ -557,7 +574,8 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
557574return new_pos
558575return DecodeRepeatedField
559576else :
560- def DecodeField (buffer ,pos ,end ,message ,field_dict ):
577+ def DecodeField (buffer ,pos ,end ,message ,field_dict ,current_depth = 0 ):
578+ del current_depth # unused
561579 (size ,pos )= local_DecodeVarint (buffer ,pos )
562580new_pos = pos + size
563581if new_pos > end :
@@ -581,7 +599,10 @@ def BytesDecoder(field_number, is_repeated, is_packed, key, new_default,
581599tag_bytes = encoder .TagBytes (field_number ,
582600wire_format .WIRETYPE_LENGTH_DELIMITED )
583601tag_len = len (tag_bytes )
584- def DecodeRepeatedField (buffer ,pos ,end ,message ,field_dict ):
602+ def DecodeRepeatedField (
603+ buffer ,pos ,end ,message ,field_dict ,current_depth = 0
604+ ):
605+ del current_depth # unused
585606value = field_dict .get (key )
586607if value is None :
587608value = field_dict .setdefault (key ,new_default (message ))
@@ -598,7 +619,8 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
598619return new_pos
599620return DecodeRepeatedField
600621else :
601- def DecodeField (buffer ,pos ,end ,message ,field_dict ):
622+ def DecodeField (buffer ,pos ,end ,message ,field_dict ,current_depth = 0 ):
623+ del current_depth # unused
602624 (size ,pos )= local_DecodeVarint (buffer ,pos )
603625new_pos = pos + size
604626if new_pos > end :
@@ -623,7 +645,9 @@ def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
623645tag_bytes = encoder .TagBytes (field_number ,
624646wire_format .WIRETYPE_START_GROUP )
625647tag_len = len (tag_bytes )
626- def DecodeRepeatedField (buffer ,pos ,end ,message ,field_dict ):
648+ def DecodeRepeatedField (
649+ buffer ,pos ,end ,message ,field_dict ,current_depth = 0
650+ ):
627651value = field_dict .get (key )
628652if value is None :
629653value = field_dict .setdefault (key ,new_default (message ))
@@ -632,7 +656,13 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
632656if value is None :
633657value = field_dict .setdefault (key ,new_default (message ))
634658# Read sub-message.
635- pos = value .add ()._InternalParse (buffer ,pos ,end )
659+ current_depth += 1
660+ if current_depth > _recursion_limit :
661+ raise _DecodeError (
662+ 'Error parsing message: too many levels of nesting.'
663+ )
664+ pos = value .add ()._InternalParse (buffer ,pos ,end ,current_depth )
665+ current_depth -= 1
636666# Read end tag.
637667new_pos = pos + end_tag_len
638668if buffer [pos :new_pos ]!= end_tag_bytes or new_pos > end :
@@ -644,12 +674,16 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
644674return new_pos
645675return DecodeRepeatedField
646676else :
647- def DecodeField (buffer ,pos ,end ,message ,field_dict ):
677+ def DecodeField (buffer ,pos ,end ,message ,field_dict , current_depth = 0 ):
648678value = field_dict .get (key )
649679if value is None :
650680value = field_dict .setdefault (key ,new_default (message ))
651681# Read sub-message.
652- pos = value ._InternalParse (buffer ,pos ,end )
682+ current_depth += 1
683+ if current_depth > _recursion_limit :
684+ raise _DecodeError ('Error parsing message: too many levels of nesting.' )
685+ pos = value ._InternalParse (buffer ,pos ,end ,current_depth )
686+ current_depth -= 1
653687# Read end tag.
654688new_pos = pos + end_tag_len
655689if buffer [pos :new_pos ]!= end_tag_bytes or new_pos > end :
@@ -668,7 +702,9 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
668702tag_bytes = encoder .TagBytes (field_number ,
669703wire_format .WIRETYPE_LENGTH_DELIMITED )
670704tag_len = len (tag_bytes )
671- def DecodeRepeatedField (buffer ,pos ,end ,message ,field_dict ):
705+ def DecodeRepeatedField (
706+ buffer ,pos ,end ,message ,field_dict ,current_depth = 0
707+ ):
672708value = field_dict .get (key )
673709if value is None :
674710value = field_dict .setdefault (key ,new_default (message ))
@@ -679,18 +715,27 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
679715if new_pos > end :
680716raise _DecodeError ('Truncated message.' )
681717# Read sub-message.
682- if value .add ()._InternalParse (buffer ,pos ,new_pos )!= new_pos :
718+ current_depth += 1
719+ if current_depth > _recursion_limit :
720+ raise _DecodeError (
721+ 'Error parsing message: too many levels of nesting.'
722+ )
723+ if (
724+ value .add ()._InternalParse (buffer ,pos ,new_pos ,current_depth )
725+ != new_pos
726+ ):
683727# The only reason _InternalParse would return early is if it
684728# encountered an end-group tag.
685729raise _DecodeError ('Unexpected end-group tag.' )
686730# Predict that the next tag is another copy of the same repeated field.
731+ current_depth -= 1
687732pos = new_pos + tag_len
688733if buffer [new_pos :pos ]!= tag_bytes or new_pos == end :
689734# Prediction failed. Return.
690735return new_pos
691736return DecodeRepeatedField
692737else :
693- def DecodeField (buffer ,pos ,end ,message ,field_dict ):
738+ def DecodeField (buffer ,pos ,end ,message ,field_dict , current_depth = 0 ):
694739value = field_dict .get (key )
695740if value is None :
696741value = field_dict .setdefault (key ,new_default (message ))
@@ -699,11 +744,14 @@ def DecodeField(buffer, pos, end, message, field_dict):
699744new_pos = pos + size
700745if new_pos > end :
701746raise _DecodeError ('Truncated message.' )
702- # Read sub-message.
703- if value ._InternalParse (buffer ,pos ,new_pos )!= new_pos :
747+ current_depth += 1
748+ if current_depth > _recursion_limit :
749+ raise _DecodeError ('Error parsing message: too many levels of nesting.' )
750+ if value ._InternalParse (buffer ,pos ,new_pos ,current_depth )!= new_pos :
704751# The only reason _InternalParse would return early is if it encountered
705752# an end-group tag.
706753raise _DecodeError ('Unexpected end-group tag.' )
754+ current_depth -= 1
707755return new_pos
708756return DecodeField
709757
@@ -859,7 +907,8 @@ def MapDecoder(field_descriptor, new_default, is_message_map):
859907# Can't read _concrete_class yet; might not be initialized.
860908message_type = field_descriptor .message_type
861909
862- def DecodeMap (buffer ,pos ,end ,message ,field_dict ):
910+ def DecodeMap (buffer ,pos ,end ,message ,field_dict ,current_depth = 0 ):
911+ del current_depth # unused
863912submsg = message_type ._concrete_class ()
864913value = field_dict .get (key )
865914if value is None :
@@ -941,8 +990,16 @@ def _SkipGroup(buffer, pos, end):
941990return pos
942991pos = new_pos
943992
993+ DEFAULT_RECURSION_LIMIT = 100
994+ _recursion_limit = DEFAULT_RECURSION_LIMIT
995+
996+
997+ def SetRecursionLimit (new_limit ):
998+ global _recursion_limit
999+ _recursion_limit = new_limit
1000+
9441001
945- def _DecodeUnknownFieldSet (buffer ,pos ,end_pos = None ):
1002+ def _DecodeUnknownFieldSet (buffer ,pos ,end_pos = None , current_depth = 0 ):
9461003"""Decode UnknownFieldSet. Returns the UnknownFieldSet and new position."""
9471004
9481005unknown_field_set = containers .UnknownFieldSet ()
@@ -952,14 +1009,14 @@ def _DecodeUnknownFieldSet(buffer, pos, end_pos=None):
9521009field_number ,wire_type = wire_format .UnpackTag (tag )
9531010if wire_type == wire_format .WIRETYPE_END_GROUP :
9541011break
955- (data ,pos )= _DecodeUnknownField (buffer ,pos ,wire_type )
1012+ (data ,pos )= _DecodeUnknownField (buffer ,pos ,wire_type , current_depth )
9561013# pylint: disable=protected-access
9571014unknown_field_set ._add (field_number ,wire_type ,data )
9581015
9591016return (unknown_field_set ,pos )
9601017
9611018
962- def _DecodeUnknownField (buffer ,pos ,wire_type ):
1019+ def _DecodeUnknownField (buffer ,pos ,wire_type , current_depth = 0 ):
9631020"""Decode a unknown field. Returns the UnknownField and new position."""
9641021
9651022if wire_type == wire_format .WIRETYPE_VARINT :
@@ -973,7 +1030,11 @@ def _DecodeUnknownField(buffer, pos, wire_type):
9731030data = buffer [pos :pos + size ].tobytes ()
9741031pos += size
9751032elif wire_type == wire_format .WIRETYPE_START_GROUP :
976- (data ,pos )= _DecodeUnknownFieldSet (buffer ,pos )
1033+ current_depth += 1
1034+ if current_depth >= _recursion_limit :
1035+ raise _DecodeError ('Error parsing message: too many levels of nesting.' )
1036+ (data ,pos )= _DecodeUnknownFieldSet (buffer ,pos ,None ,current_depth )
1037+ current_depth -= 1
9771038elif wire_type == wire_format .WIRETYPE_END_GROUP :
9781039return (0 ,- 1 )
9791040else :