# Source code for linuxnet.iptables.matches.tcpmatch

# Copyright (c) 2021, 2022, 2023, Panagiotis Tsirigotis

# This file is part of linuxnet-iptables.
#
# linuxnet-iptables is free software: you can redistribute it and/or
# modify it under the terms of version 3 of the GNU Affero General Public
# License as published by the Free Software Foundation.
#
# linuxnet-iptables is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
# or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General Public
# License for more details.
#
# You should have received a copy of the GNU Affero General
# Public License along with linuxnet-iptables. If not, see
# <https://www.gnu.org/licenses/>.

"""
This module supports matching against TCP packets
"""

from enum import IntFlag
from typing import List, Optional, Set, Tuple

from ..exceptions import IptablesError, IptablesParsingError
from ..deps import get_logger

from .match import Match, Criterion, MatchParser

# Module-level logger, obtained through the package's get_logger() helper
_logger = get_logger('linuxnet.iptables.matches.tcpmatch')


class TcpFlag(IntFlag):
    """TCP flag names, each mapped to a distinct single-bit value.

    Being an :class:`enum.IntFlag`, members can be OR-ed together and
    tested against an integer mask with ``&``.
    """
    FIN = 1 << 0
    SYN = 1 << 1
    RST = 1 << 2
    PSH = 1 << 3
    ACK = 1 << 4
    URG = 1 << 5


class TcpFlagsCriterion(Criterion):
    """A criterion for comparing against packets with an arbitrary
    set of TCP flags set, or for comparing against SYN packets.
    This is an either-or use, determined at the time of object
    instantiation.

    The value is the tuple (flags-checked, flags-set). For a flags
    comparison both members are (frozen) sets of :class:`TcpFlag`
    members; for a SYN-only criterion (and for a flags criterion
    that has not been set yet) both members are ``None``.
    """

    def __init__(self, match: Match, syn_only=False):
        """
        :param match: the :class:`Match` object that owns this object
        :param syn_only: optional boolean value indicating a check
            only against the **SYN** flag
        """
        super().__init__(match)
        # If syn_only is True, flags_checked/flags_set remain None;
        # otherwise they are filled in by equals()
        self.__syn_only = syn_only
        self.__flags_checked = None
        self.__flags_set = None

    def __eq__(self, other):
        """Two criteria are equal if they are of the same kind
        (SYN-only vs. flags), pass the ``_may_be_equal()`` check
        inherited from the parent class, and have the same value.
        """
        if not isinstance(other, TcpFlagsCriterion):
            return False
        if self.is_syn_only() != other.is_syn_only():
            return False
        if not self._may_be_equal(other):
            return False
        return self.get_value() == other.get_value()

    def get_value(self) -> Tuple[Optional[Set[TcpFlag]],
                                 Optional[Set[TcpFlag]]]:
        """Returns the value that the criterion is comparing against
        as the tuple (flags-checked, flags-set); both members are
        ``None`` if no flags have been specified.
        """
        return (self.__flags_checked, self.__flags_set)

    def is_syn_only(self):
        """Returns ``True`` if the criterion is only meant to check
        for the SYN flag (but note that it may not be set yet)
        """
        return self.__syn_only

    def bit_set(self) -> Match:
        """This method can be used if this criterion implements a SYN-only
        comparison to check if the packet flags include only the SYN bit.

        :raises IptablesError: if this is not a SYN-only criterion
        """
        if not self.__syn_only:
            raise IptablesError('not a syn-only criterion')
        return self.equals()

    def bit_not_set(self) -> Match:
        """This method can be used if this criterion implements a SYN-only
        comparison to check for the non-existence of the SYN bit

        :raises IptablesError: if this is not a SYN-only criterion
        """
        # Same guard as bit_set(): report misuse on a flags criterion
        # explicitly instead of leaving it to not_equals()
        if not self.__syn_only:
            raise IptablesError('not a syn-only criterion')
        return self.not_equals()

    def equals(self,            # pylint: disable=arguments-differ
                flags_checked: Optional[Set[TcpFlag]] =None,
                flags_set: Optional[Set[TcpFlag]] =None) -> Match:
        """Perform flags comparison

        :param flags_checked: the TCP flags to examine; must be
            omitted for a SYN-only criterion, required otherwise
        :param flags_set: the subset of examined flags that must be
            set; must be omitted for a SYN-only criterion, required
            otherwise
        :raises IptablesError: if flags are given to a SYN-only
            criterion, or are missing from a flags criterion

        Both arguments are converted with ``frozenset()``, so any
        iterable of :class:`TcpFlag` members is accepted.
        """
        if self.__syn_only:
            if not (flags_checked is None and flags_set is None):
                raise IptablesError("cannot set flags in SYN criterion")
            return self._set_polarity(True)
        if flags_checked is None:
            raise IptablesError("need to specify flags to check")
        if flags_set is None:
            raise IptablesError("need to specify flags that are set")
        self.__flags_checked = frozenset(flags_checked)
        self.__flags_set = frozenset(flags_set)
        return self._set_polarity(True)

    def _crit_iptables_args(self) -> List[str]:
        """Returns **iptables(8)** arguments for the specified TCP flags:
        ``--syn`` for a SYN-only criterion, otherwise ``--tcp-flags``
        followed by the two comma-separated lists of flag names.
        """
        if self.__syn_only:
            return ['--syn']
        return ['--tcp-flags',
                ','.join([f.name for f in self.__flags_checked]),
                ','.join([f.name for f in self.__flags_set])]
class _PortCriterion(Criterion):
    """Common implementation for criteria that compare against a
    source/destination port or a port range. The **iptables(8)**
    option to emit is supplied at construction time.
    """

    def __init__(self, match: Match, iptables_option: str):
        """
        :param match: the :class:`Match` object that owns this object
        :param iptables_option: the option string placed in front of
            the port specification
        """
        super().__init__(match)
        self.__option = iptables_option
        self.__port = self.__last_port = None

    def get_value(self) -> Tuple[int, Optional[int]]:
        """Returns the tuple (port, last_port) that the criterion is
        comparing against
        """
        return (self.__port, self.__last_port)

    def equals(self,            # pylint: disable=arguments-differ
                port: int, last_port: Optional[int] =None) -> Match:
        """Compare with a port (or check for inclusion in a port-range
        if ``last_port`` is present)
        """
        self.__port, self.__last_port = port, last_port
        return self._set_polarity(True)

    def _crit_iptables_args(self) -> List[str]:
        """Returns **iptables(8)** arguments for the specified port(s):
        the option followed by ``port`` or ``port:last_port``.
        """
        first, last = self.__port, self.__last_port
        spec = str(first) if last is None else str(first) + f':{last}'
        return [self.__option, spec]
class SourcePortCriterion(_PortCriterion):
    """Compare with a source port or check for inclusion in a port-range.

    The value is the tuple (port, last_port) where last_port may be
    ``None``; the criterion is rendered with the ``--sport`` option.
    """

    def __init__(self, match: Match):
        """
        :param match: the :class:`Match` object that owns this object
        """
        super().__init__(match, iptables_option='--sport')
class DestPortCriterion(_PortCriterion):
    """Compare against a destination port or check for inclusion in
    a port-range.

    The value is the tuple (port, last_port) where last_port may be
    ``None``; the criterion is rendered with the ``--dport`` option.
    """

    def __init__(self, match: Match):
        """
        :param match: the :class:`Match` object that owns this object
        """
        super().__init__(match, iptables_option='--dport')
class _PortParser:      # pylint: disable=too-few-public-methods
    """Helper class used to parse TCP/UDP port criteria
    """

    # Prefixes of the port-criterion tokens recognized by parse();
    # kept as tuples so they can be handed directly to str.startswith()
    SOURCE_PORT_PREFIX = ('spt:', 'spts:')
    DEST_PORT_PREFIX = ('dpt:', 'dpts:')
    PORT_PREFIX = SOURCE_PORT_PREFIX + DEST_PORT_PREFIX

    @classmethod
    def parse(cls, port_match_str: str, match: Match):
        """Add the proper criterion to 'match'

        :param port_match_str: a token starting with one of the
            prefixes in ``PORT_PREFIX``; anything that is not a source
            port token is treated as a destination port token
        :param match: the match object providing the ``source_port()``
            and ``dest_port()`` criteria
        """
        is_source = port_match_str.startswith(cls.SOURCE_PORT_PREFIX)
        criterion = match.source_port() if is_source else match.dest_port()
        # Everything after the prefix's ':' is the (possibly negated)
        # port specification
        raw_spec = port_match_str.split(':', 1)[1]
        is_equal, port_spec = MatchParser.parse_value(raw_spec)
        # A single port yields one number, a 'first:last' range yields two
        bounds = [int(part) for part in port_spec.split(':', 1)]
        criterion.compare(is_equal, *bounds)
class TcpMatch(Match):
    """Match against the fields of the TCP header
    """

    def __init__(self):
        # Criteria are created lazily by syn()/tcp_flags(),
        # source_port() and dest_port()
        self.__flags_crit = None
        self.__src_port_crit = None
        self.__dest_port_crit = None

    def __criteria_for_comparison(self) -> tuple:
        """Returns the (flags, source-port, dest-port) criteria to be
        used by ``__eq__``. A criterion that has not been created yet
        is represented by a throw-away instance, so that comparing
        two matches never modifies either of them.
        """
        # NOTE(review): assumes that constructing a criterion has no
        # side effect on the owning match -- confirm against Criterion
        flags_crit = self.__flags_crit
        if flags_crit is None:
            flags_crit = TcpFlagsCriterion(self)
        src_port_crit = self.__src_port_crit
        if src_port_crit is None:
            src_port_crit = SourcePortCriterion(self)
        dest_port_crit = self.__dest_port_crit
        if dest_port_crit is None:
            dest_port_crit = DestPortCriterion(self)
        return (flags_crit, src_port_crit, dest_port_crit)

    def __eq__(self, other):
        """Two TCP matches are equal if their flags, source port and
        destination port criteria are pairwise equal.
        """
        if not isinstance(other, TcpMatch):
            return False
        # Do not go through tcp_flags()/source_port()/dest_port() here:
        # those accessors create and store criteria, which would e.g.
        # make a later syn() call return a non-SYN-only criterion.
        mine = self.__criteria_for_comparison()
        # pylint: disable=protected-access
        theirs = other.__criteria_for_comparison()
        return all(crit == other_crit
                   for crit, other_crit in zip(mine, theirs))

    def syn(self) -> TcpFlagsCriterion:
        """Criterion for matching against a SYN packet

        The flags criterion is shared with :meth:`tcp_flags`; whichever
        of the two methods is invoked first determines its kind.
        """
        if self.__flags_crit is None:
            self.__flags_crit = TcpFlagsCriterion(self, syn_only=True)
        return self.__flags_crit

    def tcp_flags(self) -> TcpFlagsCriterion:
        """Compare with TCP flags

        The flags criterion is shared with :meth:`syn`; whichever
        of the two methods is invoked first determines its kind.
        """
        if self.__flags_crit is None:
            self.__flags_crit = TcpFlagsCriterion(self)
        return self.__flags_crit

    def source_port(self) -> SourcePortCriterion:
        """Matching against the source port
        """
        if self.__src_port_crit is None:
            self.__src_port_crit = SourcePortCriterion(self)
        return self.__src_port_crit

    def dest_port(self) -> DestPortCriterion:
        """Match against the destination port
        """
        if self.__dest_port_crit is None:
            self.__dest_port_crit = DestPortCriterion(self)
        return self.__dest_port_crit

    def to_iptables_args(self) -> List[str]:
        """Returns **iptables(8)** arguments for this match
        """
        criteria = (self.__flags_crit, self.__src_port_crit,
                    self.__dest_port_crit)
        return self.build_iptables_args('tcp', criteria)

    @classmethod
    def __parse_tcp_flags_num(cls, numstr: str) -> Set[TcpFlag]:
        """Parse a hex-value numstr (e.g. 0x11) into a set of TCP flags.

        :param numstr: string holding a hexadecimal flag mask
        :raises IptablesParsingError: if ``numstr`` is not a valid
            hexadecimal number
        """
        try:
            flag_mask = int(numstr, 16)
        except ValueError as valerr:
            raise IptablesParsingError(
                        "Bad TCP flag mask: " + numstr) from valerr
        return {flag for flag in TcpFlag if flag_mask & flag}

    @classmethod
    def parse(cls, parser: MatchParser) -> Match:
        """Parse the TCP criteria

        Tokens are consumed from the parser's iterator for as long as
        they are TCP flag or port criteria; the first unrecognized
        token is put back and parsing stops.

        :param parser: the :class:`MatchParser` providing the tokens
        :raises IptablesParsingError: if a flags token has no ``/`` or
            contains an invalid hexadecimal mask
        """
        criteria_iter = parser.get_iter()
        match = TcpMatch()
        for val in criteria_iter:
            if val.startswith('flags:'):
                flag_spec = val.split(':', 1)[1]
                is_equal, flag_spec = parser.parse_value(flag_spec)
                if '/' not in flag_spec:
                    raise IptablesParsingError(
                                f"no '/' in TCP flags: {flag_spec}")
                mask, comp = flag_spec.split('/', 1)
                flags_checked = cls.__parse_tcp_flags_num(mask)
                flags_set = cls.__parse_tcp_flags_num(comp)
                # Checking FIN,SYN,RST,ACK with only SYN set is
                # represented as a SYN-only criterion
                if (flags_set == {TcpFlag.SYN} and
                        flags_checked == {TcpFlag.FIN, TcpFlag.SYN,
                                          TcpFlag.RST, TcpFlag.ACK}):
                    match.syn().compare(is_equal)
                else:
                    match.tcp_flags().compare(is_equal,
                                              flags_checked, flags_set)
            elif val.startswith(_PortParser.PORT_PREFIX):
                _PortParser.parse(val, match)
            else:
                # Not a TCP criterion: hand it back for the caller
                criteria_iter.put_back(val)
                break
        return match
# Register TcpMatch with MatchParser under the match name 'tcp'
MatchParser.register_match('tcp', TcpMatch)