Source code for flextls.protocol

"""
The SSL/TLS Protocol
"""

from flextls.exception import NotEnoughData


class Protocol(object):
    """
    Base Class to decode protocols.

    A protocol consists of an ordered list of field objects (``fields``)
    followed by an optional payload. Every field object is expected to offer
    ``name``, ``value``, ``assemble()`` (returning the encoded bytes) and
    ``dissect(data)`` (returning the data left over after the field).

    Field values can be read and written like normal attributes of the
    protocol instance (see ``__getattr__`` and ``__setattr__``).

    Subclasses configure the payload handling through these attributes:

    * ``payload_identifier_field``: name of the field whose value selects the
      payload class registered with :meth:`add_payload_type`.
    * ``payload_length_field``: name of the field holding the payload length.
    * ``payload_fragment_length_field`` / ``payload_fragment_offset_field``:
      names of the fields describing a fragment (DTLS).
    """

    # Maps identifier values to payload classes. It stays None until
    # add_payload_type() is called for the first time.
    payload_list = None

    def __init__(self, connection=None):
        """
        Create an empty protocol object.

        :param connection: Optional connection object; it is only stored and
            passed on to the payload classes while decoding.
        """
        # "fields" has to be assigned first because __setattr__ consults it.
        self.fields = []
        self._connection = connection
        self.payload = None
        self.payload_identifier_field = None
        self.payload_length_field = None
        # DTLS
        self.payload_fragment_length_field = None
        self.payload_fragment_offset_field = None

    def __add__(self, payload):
        """Attach *payload* (``protocol + payload``) and return ``self``."""
        self.set_payload(payload)
        return self

    def __getattr__(self, name):
        """
        Return the value of the field called *name*.

        Python only calls this hook when the normal attribute lookup failed.
        Names which do not match any field yield ``None``.

        :raises AttributeError: For ``fields`` itself and for special
            (``__dunder__``) names.
        """
        # A missing "fields" attribute means the instance was never
        # initialised (for example it was created with __new__ by the copy
        # module). Looking for fields in that state would call this hook
        # again and recurse endlessly. Special method probes have to fail
        # with AttributeError as well, otherwise callers get None where they
        # expect a callable.
        if name == "fields" or (name.startswith("__") and name.endswith("__")):
            raise AttributeError(name)
        return self.get_field_value(name)

    def __setattr__(self, name, value):
        """
        Set a field value or, if no field is called *name*, a normal
        attribute.
        """
        if name != "fields":
            try:
                fields = self.fields
            except AttributeError:
                # Instance is not initialised yet, so there are no fields.
                fields = ()

            for field in fields:
                if field.name == name:
                    field.value = value
                    return

        object.__setattr__(self, name, value)

    @classmethod
    def add_payload_type(cls, pattern, payload_class):
        """
        Register *payload_class* for the identifier value *pattern*.

        The registry is created on the class the method is called on when
        that class does not see one yet.
        """
        if cls.payload_list is None:
            cls.payload_list = {}
        cls.payload_list[pattern] = payload_class

    def _update_payload_identifier(self, payload):
        """Set the identifier field to the pattern registered for *payload*."""
        if self.payload_identifier_field is None or not self.payload_list:
            return

        for pay_pattern, pay_class in self.payload_list.items():
            if isinstance(payload, pay_class):
                self.set_field_value(
                    self.payload_identifier_field,
                    pay_pattern
                )
                break

    def assemble(self):
        """
        Encode all fields followed by the payload.

        The payload length field is updated before the fields are encoded.
        The identifier field is updated as well when the payload is a
        :class:`Protocol` object of a registered class.

        :return: The encoded data
        """
        payload = b""
        if isinstance(self.payload, Protocol):
            payload = self.payload.encode()
            # The lookup has to use the payload object. The encoded bytes
            # are never an instance of a payload class.
            self._update_payload_identifier(self.payload)
        elif self.payload is not None:
            payload = self.payload

        if self.payload_length_field is not None and payload is not None:
            self.set_field_value(
                self.payload_length_field,
                len(payload)
            )

        header = b"".join(field.assemble() for field in self.fields)
        return header + payload

    @classmethod
    def decode(cls, data, connection=None, payload_auto_decode=True):
        """
        Create a new object of this class from *data*.

        :param data: The raw data
        :param connection: Optional connection object for the new object
        :param payload_auto_decode: Passed on to :meth:`dissect`
        :return: Tuple ``(obj, data)`` with the new object and the data which
            was not consumed
        """
        obj = cls(
            connection=connection
        )
        data = obj.dissect(
            data,
            payload_auto_decode=payload_auto_decode
        )
        return (obj, data)

    def decode_payload(self, data=None, payload_auto_decode=True):
        """
        Extract the payload from *data* and store it in ``self.payload``.

        Fragments are stored undecoded. Without an identifier field nothing
        is consumed and the payload is left untouched.

        :param data: The raw data; defaults to the current payload
        :param payload_auto_decode: Decode the payload with the registered
            payload class. If False, or if no class is registered for the
            identifier value, the raw payload data is stored.
        :return: The data which was not consumed, or ``False`` if there is no
            data at all
        :raises NotEnoughData: If *data* is shorter than the value of the
            payload length field
        """
        if data is None:
            data = self.payload
        if data is None:
            return False

        # A fragment can't be decoded on its own. Keep it as it is and report
        # all data as consumed.
        if self.is_fragment():
            self.payload = data
            return data[:0]

        if self.payload_identifier_field is None:
            return data

        if self.payload_length_field is None:
            # Without a length field everything belongs to the payload.
            payload_data = data
            data = data[:0]
        else:
            payload_length = self.get_field_value(self.payload_length_field)
            if len(data) < payload_length:
                raise NotEnoughData(
                    "Not enough data to decode payload"
                )
            payload_data = data[:payload_length]
            data = data[payload_length:]

        payload_class = None
        if self.payload_list is not None:
            payload_class = self.payload_list.get(
                self.get_field_value(self.payload_identifier_field),
                None
            )

        if payload_class is None or payload_auto_decode is False:
            self.payload = payload_data
        else:
            # Data left over by the payload class is discarded here.
            (obj, payload_data) = payload_class.decode(
                payload_data,
                connection=self._connection,
                payload_auto_decode=payload_auto_decode
            )
            self.payload = obj

        return data

    @classmethod
    def decode_raw_payload(cls, payload_type, payload_data, payload_auto_decode=False, connection=None):
        """
        Decode *payload_data* with the class registered for *payload_type*.

        :return: The ``(obj, data)`` tuple returned by the ``decode`` method
            of the payload class, or ``None`` if no class is registered for
            *payload_type*
        """
        payload_cls = None
        if cls.payload_list is not None:
            payload_cls = cls.payload_list.get(payload_type)

        if payload_cls is None:
            # ToDo: raise Exception
            return None

        return payload_cls.decode(
            payload_data,
            connection=connection,
            payload_auto_decode=payload_auto_decode
        )

    def dissect(self, data, connection=None, payload_auto_decode=True):
        """
        Fill the fields and the payload from *data*.

        :param data: The raw data
        :param connection: Replaces the stored connection if it is not None
        :param payload_auto_decode: Passed on to :meth:`decode_payload`
        :return: The result of :meth:`decode_payload` (the data which was
            not consumed)
        """
        if connection is not None:
            self._connection = connection

        for field in self.fields:
            data = field.dissect(data)

        data = self.decode_payload(
            data,
            payload_auto_decode=payload_auto_decode
        )
        return data

    def encode(self):
        """Return the encoded data; same as :meth:`assemble`."""
        return self.assemble()

    def get_field(self, name):
        """
        Return the field object called *name*.

        :raises AttributeError: If there is no such field
        """
        for field in self.fields:
            if field.name == name:
                return field
        raise AttributeError(name)

    def get_field_value(self, name):
        """Return the value of the field *name* or None if it is unknown."""
        for field in self.fields:
            if field.name == name:
                return field.value
        return None

    def get_payload_pattern(self, payload_cls):
        """
        Return the pattern registered for a payload.

        :param payload_cls: A payload class or an object of a payload class
        :return: The pattern of the first matching registration
        :raises Exception: If no registration matches
        """
        # issubclass() raises a TypeError for anything which is not a class,
        # so classes and objects need different checks.
        is_class = isinstance(payload_cls, type)
        for pay_pattern, pay_class in (self.payload_list or {}).items():
            if is_class:
                if issubclass(payload_cls, pay_class):
                    return pay_pattern
            elif isinstance(payload_cls, pay_class):
                return pay_pattern

        # ToDo: Change exception type?
        raise Exception("Payload pattern not found")

    def is_fragment(self):
        """
        Check if the payload is only a fragment.

        :return: ``None`` if the fragment fields are not configured,
            ``True`` if the fragment offset is not 0 or the fragment length
            differs from the payload length, otherwise ``False``
        """
        # Both fields are required, the same condition decode_payload uses.
        if self.payload_fragment_length_field is None or self.payload_fragment_offset_field is None:
            return None

        fragment_length = self.get_field_value(self.payload_fragment_length_field)
        fragment_offset = self.get_field_value(self.payload_fragment_offset_field)
        payload_length = self.get_field_value(self.payload_length_field)

        return fragment_offset != 0 or fragment_length != payload_length

    def set_field_value(self, name, value):
        """Set the value of every field called *name*; unknown names are ignored."""
        for field in self.fields:
            if field.name == name:
                field.value = value

    def set_field_values(self, values):
        """Set several field values from a ``name -> value`` mapping."""
        for n, v in values.items():
            self.set_field_value(n, v)

    def set_payload(self, payload):
        """
        Store *payload* and update the identifier field if the class of the
        payload is registered.
        """
        self._update_payload_identifier(payload)
        self.payload = payload