# Copyright (c) 2021, 2022, 2023, Panagiotis Tsirigotis
# This file is part of linuxnet-iptables.
#
# linuxnet-iptables is free software: you can redistribute it and/or
# modify it under the terms of version 3 of the GNU Affero General Public
# License as published by the Free Software Foundation.
#
# linuxnet-iptables is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
# or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General Public
# License for more details.
#
# You should have received a copy of the GNU Affero General
# Public License along with linuxnet-iptables. If not, see
# <https://www.gnu.org/licenses/>.
"""
This module supports matching against TCP packets
"""
from enum import IntFlag
from typing import List, Optional, Set, Tuple
from ..exceptions import IptablesError, IptablesParsingError
from ..deps import get_logger
from .match import Match, Criterion, MatchParser
# Module-level logger for this module, obtained through the package's
# get_logger() helper (imported from ..deps).
_logger = get_logger('linuxnet.iptables.matches.tcpmatch')
class TcpFlag(IntFlag):
    """Names and bit values of the TCP header flags.

    Every member occupies a single bit, so members can be combined with
    ``|`` and tested with ``&``; :class:`TcpMatch` relies on this to decode
    the hexadecimal flag masks found in **iptables(8)** output.
    """
    FIN = 1 << 0
    SYN = 1 << 1
    RST = 1 << 2
    PSH = 1 << 3
    ACK = 1 << 4
    URG = 1 << 5
class TcpFlagsCriterion(Criterion):
    """A criterion for comparing against packets with an arbitrary set of
    TCP flags set, or for comparing against SYN packets. This is
    an either-or use, determined at the time of object instantiation.

    The value is the tuple (flags-checked, flags-set); each element is a
    frozenset of :class:`TcpFlag` members, or ``None`` when no flags have
    been specified (which is always the case for a SYN-only criterion).
    In the generated **iptables(8)** arguments each set is rendered as a
    comma-separated list of TCP flag names.
    """
    def __init__(self, match: Match, syn_only=False):
        """
        :param match: the :class:`Match` object that owns this object
        :param syn_only: optional boolean value indicating a check only
            against the **SYN** flag
        """
        super().__init__(match)
        # If syn_only is True, then flags_checked/flags_set will be None
        self.__syn_only = syn_only
        self.__flags_checked = None
        self.__flags_set = None

    def __eq__(self, other):
        if not isinstance(other, TcpFlagsCriterion):
            return False
        # A SYN-only criterion never equals a general flags criterion
        if self.is_syn_only() ^ other.is_syn_only():
            return False
        if not self._may_be_equal(other):
            return False
        return self.get_value() == other.get_value()

    def get_value(self) -> Tuple[Optional[Set[TcpFlag]],
                                 Optional[Set[TcpFlag]]]:
        """Returns the value that the criterion is comparing against:
        the tuple (flags-checked, flags-set)
        """
        return (self.__flags_checked, self.__flags_set)

    def is_syn_only(self):
        """Returns ``True`` if the criterion is only meant to check
        for the SYN flag (but note that it may not be set yet)
        """
        return self.__syn_only

    def bit_set(self) -> Match:
        """For a SYN-only criterion, match SYN packets (emitted as
        ``--syn``, i.e. SYN set with ACK, RST and FIN clear).

        :raises IptablesError: if this is not a SYN-only criterion
        """
        if not self.__syn_only:
            raise IptablesError('not a syn-only criterion')
        return self.equals()

    def bit_not_set(self) -> Match:
        """For a SYN-only criterion, match packets that are not SYN
        packets (the negation of :meth:`bit_set`).

        :raises IptablesError: if this is not a SYN-only criterion
        """
        # Same guard as bit_set(), so misuse on a general flags criterion
        # is reported with a clear message.
        if not self.__syn_only:
            raise IptablesError('not a syn-only criterion')
        return self.not_equals()

    def equals(self,    # pylint: disable=arguments-differ
               flags_checked: Optional[Set[TcpFlag]] =None,
               flags_set: Optional[Set[TcpFlag]] =None) -> Match:
        """Perform flags comparison

        :param flags_checked: the flags to examine (any iterable of
            :class:`TcpFlag`); must be omitted for a SYN-only criterion
        :param flags_set: the flags, among those examined, that must be
            set (any iterable of :class:`TcpFlag`); must be omitted for a
            SYN-only criterion
        :raises IptablesError: if flags are given for a SYN-only
            criterion, or are missing for a general flags criterion
        """
        if self.__syn_only:
            if not (flags_checked is None and flags_set is None):
                raise IptablesError("cannot set flags in SYN criterion")
            return self._set_polarity(True)
        if flags_checked is None:
            raise IptablesError("need to specify flags to check")
        if flags_set is None:
            raise IptablesError("need to specify flags that are set")
        self.__flags_checked = frozenset(flags_checked)
        self.__flags_set = frozenset(flags_set)
        return self._set_polarity(True)

    @staticmethod
    def __flag_names(flags) -> str:
        """Returns ``flags`` as a comma-separated list of flag names.

        Set iteration order is not guaranteed (it can depend on the order
        in which members were inserted), so the flags are sorted by bit
        value to make equal sets always produce identical arguments.
        An empty set is rendered as ``NONE``, the **iptables(8)** keyword
        for "no flags".
        """
        names = [flag.name for flag in sorted(flags)]
        return ','.join(names) if names else 'NONE'

    def _crit_iptables_args(self) -> List[str]:
        """Returns **iptables(8)** arguments for the specified TCP flags
        """
        if self.__syn_only:
            return ['--syn']
        return ['--tcp-flags',
                self.__flag_names(self.__flags_checked),
                self.__flag_names(self.__flags_set)]
class _PortCriterion(Criterion):
    """Shared base of the source/destination port criteria: compares
    against a single port or a port range, and renders the comparison
    with the **iptables(8)** option chosen by the subclass
    """
    def __init__(self, match: Match, iptables_option: str):
        """
        :param match: the :class:`Match` object that owns this object
        :param iptables_option: the option emitted in front of the port
            specification (e.g. ``--sport``)
        """
        super().__init__(match)
        self.__option = iptables_option
        # Nothing to compare against until equals() is called; the last
        # port stays None for a single-port comparison.
        self.__port = self.__last_port = None

    def get_value(self) -> Tuple[int, Optional[int]]:
        """Returns the (port, last_port) tuple this criterion compares
        against; last_port is ``None`` unless a range was given
        """
        value = (self.__port, self.__last_port)
        return value

    def equals(self,    # pylint: disable=arguments-differ
               port: int, last_port: Optional[int] =None) -> Match:
        """Compare against ``port``, or check for inclusion in the port
        range ``port:last_port`` when ``last_port`` is given
        """
        self.__port, self.__last_port = port, last_port
        return self._set_polarity(True)

    def _crit_iptables_args(self) -> List[str]:
        """Returns the **iptables(8)** option followed by the port
        specification (``port`` or ``port:last_port``)
        """
        if self.__last_port is None:
            spec = str(self.__port)
        else:
            spec = str(self.__port) + f':{self.__last_port}'
        return [self.__option, spec]
class SourcePortCriterion(_PortCriterion):
    """Compare against a source port, or check for inclusion in a source
    port range (rendered with the ``--sport`` option).

    The value is the tuple (port, last_port), where last_port may be
    ``None``.
    """
    def __init__(self, match: Match):
        super().__init__(match, iptables_option='--sport')
class DestPortCriterion(_PortCriterion):
    """Compare against a destination port, or check for inclusion in a
    destination port range (rendered with the ``--dport`` option).

    The value is the tuple (port, last_port), where last_port may be
    ``None``.
    """
    def __init__(self, match: Match):
        super().__init__(match, iptables_option='--dport')
class _PortParser:     # pylint: disable=too-few-public-methods
    """Helper class used to parse TCP/UDP port criteria
    """
    # Field prefixes recognized in iptables output; presumably the plural
    # forms ('spts:'/'dpts:') are the ones printed for port ranges.
    SOURCE_PORT_PREFIX = ('spt:', 'spts:')
    DEST_PORT_PREFIX = ('dpt:', 'dpts:')
    PORT_PREFIX = SOURCE_PORT_PREFIX + DEST_PORT_PREFIX

    @classmethod
    def parse(cls, port_match_str: str, match: Match):
        """Add the proper criterion to 'match'

        :param port_match_str: a field starting with one of the prefixes in
            :attr:`PORT_PREFIX`, followed by ``port`` or ``first:last``
        :param match: the match whose ``source_port()``/``dest_port()``
            criterion is set
        :raises IptablesParsingError: if a port is not a decimal number
        """
        if port_match_str.startswith(cls.SOURCE_PORT_PREFIX):
            port_crit = match.source_port()
        else:
            port_crit = match.dest_port()
        port_spec = port_match_str.split(':', 1)[1]
        # parse_value() returns the comparison polarity and the remaining
        # value text (presumably with any negation marker removed)
        is_equal, port_spec = MatchParser.parse_value(port_spec)
        # One element for a single port, two for a range
        try:
            ports = [int(port) for port in port_spec.split(':', 1)]
        except ValueError as valerr:
            raise IptablesParsingError(
                f"Bad port specification: {port_match_str}") from valerr
        port_crit.compare(is_equal, *ports)
class TcpMatch(Match):
    """Match against the fields of the TCP header
    """
    def __init__(self):
        # Criteria are created on demand by the accessor methods below.
        # syn() and tcp_flags() share the same slot, so whichever of them
        # is called first decides whether the flags criterion is SYN-only.
        self.__flags_crit = None
        self.__src_port_crit = None
        self.__dest_port_crit = None

    def __criteria_for_comparison(self) -> Tuple[TcpFlagsCriterion,
                                                  SourcePortCriterion,
                                                  DestPortCriterion]:
        """Returns the (flags, source-port, dest-port) criteria for use by
        :meth:`__eq__`. A criterion that does not exist yet is replaced by
        a fresh, unset instance of the kind the corresponding accessor
        (tcp_flags(), source_port(), dest_port()) would create; the
        stand-in is NOT stored, so comparing leaves this match unmodified.
        """
        flags_crit = self.__flags_crit
        if flags_crit is None:
            flags_crit = TcpFlagsCriterion(self)
        src_port_crit = self.__src_port_crit
        if src_port_crit is None:
            src_port_crit = SourcePortCriterion(self)
        dest_port_crit = self.__dest_port_crit
        if dest_port_crit is None:
            dest_port_crit = DestPortCriterion(self)
        return (flags_crit, src_port_crit, dest_port_crit)

    def __eq__(self, other):
        # Do not use the accessors here: they store the criteria they
        # create, and a stored non-SYN flags criterion would make a later
        # syn() call return it instead of a SYN-only criterion.
        if not isinstance(other, TcpMatch):
            return False
        return (self.__criteria_for_comparison() ==
                other.__criteria_for_comparison())

    def syn(self) -> TcpFlagsCriterion:
        """Criterion for matching against a SYN packet

        The flags criterion is shared with :meth:`tcp_flags`; if that
        method was called first, the criterion returned here is not
        SYN-only.
        """
        if self.__flags_crit is None:
            self.__flags_crit = TcpFlagsCriterion(self, syn_only=True)
        return self.__flags_crit

    def tcp_flags(self) -> TcpFlagsCriterion:
        """Compare with TCP flags

        The flags criterion is shared with :meth:`syn`; if that method
        was called first, the criterion returned here is SYN-only.
        """
        if self.__flags_crit is None:
            self.__flags_crit = TcpFlagsCriterion(self)
        return self.__flags_crit

    def source_port(self) -> SourcePortCriterion:
        """Matching against the source port
        """
        if self.__src_port_crit is None:
            self.__src_port_crit = SourcePortCriterion(self)
        return self.__src_port_crit

    def dest_port(self) -> DestPortCriterion:
        """Match against the destination port
        """
        if self.__dest_port_crit is None:
            self.__dest_port_crit = DestPortCriterion(self)
        return self.__dest_port_crit

    def to_iptables_args(self) -> List[str]:
        """Returns **iptables(8)** arguments for this match
        """
        criteria = (self.__flags_crit, self.__src_port_crit,
                    self.__dest_port_crit)
        return self.build_iptables_args('tcp', criteria)

    @classmethod
    def __parse_tcp_flags_num(cls, numstr: str) -> Set[TcpFlag]:
        """Parse a hex-value numstr (e.g. 0x11) into a set of TCP flags.

        :raises IptablesParsingError: if ``numstr`` is not a hex number
        """
        try:
            flag_mask = int(numstr, 16)
        except ValueError as valerr:
            raise IptablesParsingError(
                "Bad TCP flag mask: " + numstr) from valerr
        # Bits that do not correspond to a TcpFlag member are ignored
        return {flag for flag in TcpFlag if flag_mask & flag}

    @classmethod
    def parse(cls, parser: MatchParser) -> Match:
        """Parse the TCP criteria

        Consumes ``flags:`` and port fields from the parser's criteria
        iterator; the first unrecognized field is put back into the
        iterator and parsing stops.

        :raises IptablesParsingError: if a flags or port field is malformed
        """
        criteria_iter = parser.get_iter()
        match = TcpMatch()
        for val in criteria_iter:
            if val.startswith('flags:'):
                flag_spec = val.split(':', 1)[1]
                is_equal, flag_spec = parser.parse_value(flag_spec)
                if '/' not in flag_spec:
                    raise IptablesParsingError(
                        f"no '/' in TCP flags: {flag_spec}")
                # Field format: <flags-checked>/<flags-set>, both in hex
                mask, comp = flag_spec.split('/', 1)
                flags_checked = cls.__parse_tcp_flags_num(mask)
                flags_set = cls.__parse_tcp_flags_num(comp)
                # SYN set with FIN, RST and ACK clear is the --syn check,
                # so report it as a SYN-only criterion
                if (flags_set == {TcpFlag.SYN} and
                        flags_checked == {TcpFlag.FIN, TcpFlag.SYN,
                                          TcpFlag.RST, TcpFlag.ACK}):
                    match.syn().compare(is_equal)
                else:
                    match.tcp_flags().compare(is_equal,
                                              flags_checked, flags_set)
            elif val.startswith(_PortParser.PORT_PREFIX):
                _PortParser.parse(val, match)
            else:
                criteria_iter.put_back(val)
                break
        return match
# Register TcpMatch with MatchParser under the name 'tcp'; presumably this
# is how the parser dispatches a 'tcp' match to TcpMatch.parse() — confirm
# in MatchParser.register_match.
MatchParser.register_match('tcp', TcpMatch)