# Copyright (c) 2021, 2022, 2023, Panagiotis Tsirigotis
# This file is part of linuxnet-iptables.
#
# linuxnet-iptables is free software: you can redistribute it and/or
# modify it under the terms of version 3 of the GNU Affero General Public
# License as published by the Free Software Foundation.
#
# linuxnet-iptables is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
# or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General Public
# License for more details.
#
# You should have received a copy of the GNU Affero General
# Public License along with linuxnet-iptables. If not, see
# <https://www.gnu.org/licenses/>.
"""This module provides the Chain class
"""
import traceback
from typing import Callable, Iterator, List, Optional
from .exceptions import (IptablesError, IptablesParsingError)
from .rule import ChainRule
from .targets import Target, Targets, ChainTarget
from .deps import get_logger
_logger = get_logger('linuxnet.iptables.chain')
# pylint: disable=too-many-instance-attributes,too-many-public-methods
class Chain:
    """This class is used to represent an iptables chain.

    A chain contains a list of rules which can be referenced by
    number (rule numbers start with 1).

    A :class:`Chain` instance is iterable, returning the chain's rules.

    The :class:`Chain` class supports the standard :func:`len` function
    returning the number of rules in the chain.

    The :class:`Chain` class supports integer-based indexing
    (slices are not supported). Positive integers are interpreted
    as rule numbers, i.e. indexing starts at ``1``.
    Index ``0`` will raise an :exc:`IndexError`.
    Negative index values are supported with ``-1`` identifying the last rule,
    ``-2`` identifying the penultimate rule, etc.
    """

    def __init__(self, chain_name: str):
        """
        :param chain_name: real chain name
        """
        self.__real_chain_name = chain_name
        # Adjusted via _setref()/_incref()/_decref()
        self.__reference_count = 0
        # Accumulated via _update_stats()
        self.__packet_count = 0
        self.__byte_count = 0
        self.__rule_list = []
        # Owning packet filter table; stays None until _set_pft() is called
        self.__pft = None

    def __str__(self):
        return f'Chain({self.__real_chain_name})'

    def __len__(self):
        return len(self.__rule_list)

    def __getitem__(self, rulenum):
        """Return the rule with the specified rule number.

        Positive values are 1-based rule numbers, negative values count
        from the end of the chain, ``0`` is rejected.

        Raises :exc:`TypeError` for non-integer indexes and
        :exc:`IndexError` for ``0`` or out-of-range values.
        """
        if not isinstance(rulenum, int):
            raise TypeError("only integer-based indexing supported")
        if rulenum > 0:
            return self.__rule_list[rulenum-1]
        if rulenum == 0:
            raise IndexError(f"bad rule number: {rulenum}")
        return self.__rule_list[rulenum]

    def __iter__(self):
        """Iterator for the chain's rules
        """
        return iter(self.__rule_list)

    def is_builtin(self) -> bool:   # pylint: disable=no-self-use
        """Returns ``True`` if this is a built-in chain (e.g. ``INPUT``)
        """
        return False

    def has_rules(self) -> bool:
        """Returns ``True`` if the chain contains any rules (note that
        a :class:`Chain` instance can also be used directly in a boolean
        context; if empty, it evaluates to ``False``).
        """
        return bool(self.__rule_list)

    def get_reference_count(self) -> int:
        """Returns the reference count of a (non-builtin) chain; returns 0
        for builtin chains
        """
        return self.__reference_count

    def get_packet_count(self) -> int:
        """Returns the packet count of the chain
        """
        return self.__packet_count

    def get_byte_count(self) -> int:
        """Returns the byte count of the chain
        """
        return self.__byte_count

    def get_real_name(self) -> str:
        """Returns the real chain name
        """
        return self.__real_chain_name

    def get_logical_name(self) -> str:
        """Returns the logical chain name

        If the chain is not associated with a packet filter table,
        the real chain name is returned.
        """
        if self.__pft is None:
            return self.__real_chain_name
        return self.__pft.rcn2lcn(self.__real_chain_name)

    def has_unparsed_rules(self) -> bool:
        """Returns ``True`` if the chain contains unparsed rules
        """
        return any(rule.parsing_failed() for rule in self.__rule_list)

    def get_unparsed_rule_count(self) -> int:
        """Returns the number of unparsed rules
        """
        return sum(1 for rule in self.__rule_list if rule.parsing_failed())

    def get_rule_count(self) -> int:
        """Returns the number of rules in the chain (note that the
        standard :func:`len` function is also supported)
        """
        return len(self.__rule_list)

    def get_rules(self) -> List[ChainRule]:
        """Returns a list that contains the chain rules.

        The returned list is a copy; modifying it does not affect the chain.
        """
        return self.__rule_list[:]

    def iter_rules(self, *, chain_target=False, uses_goto=False,
                   match_count: Optional[int] = None,
                   match: Optional['Match'] = None) -> Iterator[ChainRule]:
        """Returns an iterator for the chain rules. The rules returned
        by the iterator depend on the arguments:

        :param chain_target: if ``True``, return rules where the target
            is a chain
        :param uses_goto: if ``True``, return rules that use goto
        :param match_count: if not ``None``, return rules that have that
            number of matches
        :param match: if not ``None``, return rules that have a matching
            :class:`Match` in their match list; if the ``match`` has
            no criteria set, it will match any :class:`Match`
            instance of the **same** class
        """
        if not (chain_target or uses_goto or match_count is not None or
                match is not None):
            # No filtering requested
            return iter(self.__rule_list)
        # match_klass is the class to compare rule matches against when
        # the caller's match has no criteria set; NoneType acts as a
        # "never equal" placeholder when a value comparison is needed.
        match_klass = type(None)
        if match is not None:
            # Perform a match class comparison
            match_klass = type(match)
            for crit in match.get_criteria():
                if crit is not None and crit.is_set():
                    # Perform a match value comparison
                    match_klass = type(None)
                    break

        # Define the filter function
        def rule_filter(    # pylint: disable=too-many-return-statements
                rule: ChainRule) -> bool:
            """Returns True/False based on whether the specified rule
            satisfies the specified criteria.
            """
            if chain_target or uses_goto:
                target_chain = rule.get_target_chain()
                if target_chain is None:
                    return False
            if uses_goto and not rule.uses_goto():
                return False
            if (match_count is not None and
                    rule.get_match_count() != match_count):
                return False
            if match is None:
                return True
            for rule_match in rule:
                if match_klass is type(rule_match):
                    return True
                if rule_match == match:
                    return True
            return False

        return filter(rule_filter, self.__rule_list)

    def _set_rule_list(self, rule_list: List[ChainRule]) -> None:
        """Set the rule list.

        This method is only used by the parsing code, so it does not
        update any chain reference counts.
        """
        for rulenum, rule in enumerate(rule_list, 1):
            rule._set_chain(self, rulenum)  # pylint: disable=protected-access
        self.__rule_list = rule_list

    def _propagate_rule_stats(self, log_stat_failures: bool) -> None:
        """Propagate the packet/byte counts of each rule to
        the rule's target.

        This method is only used by the parsing code.

        :param log_stat_failures: if ``True``, log a warning (with the
            call stack) for every rule whose target chain is unresolved
        """
        for rule in self.__rule_list:
            target = rule.get_target()
            if not isinstance(target, ChainTarget):
                continue
            target_chain = target.get_chain()
            if target_chain is None:
                if log_stat_failures:
                    # Report the unresolved target, not only the chain
                    # that owns the rule.
                    _logger.warning(
                        "%s: chain %s: unknown target chain: %s",
                        self._propagate_rule_stats.__qualname__,
                        self.get_real_name(),
                        target.get_target_name())
                    _logger.warning(
                        "Call stack:\n%s",
                        ''.join(traceback.extract_stack().format()[:-1]))
                continue
            target_chain._update_stats(     # pylint: disable=protected-access
                packet_count=rule.get_packet_count(),
                byte_count=rule.get_byte_count())

    def _update_stats(self, packet_count: int, byte_count: int) -> None:
        """Update the packet/byte counts of this chain

        The specified counts are added to the current counts.
        """
        self.__packet_count += packet_count
        self.__byte_count += byte_count

    def find_rule_by_target_lcn(self,
                                logical_chain_name: str) -> List[ChainRule]:
        """Return a list of rules with the specified chain as a target

        :param logical_chain_name: identifies the chain targeted by the rule
        """
        pft = self.__pft
        # A builtin chain may reference multiple peers, each with a
        # different prefix. We need to ignore ones not handled by our pft.
        # This check is only possible when a pft has been set.
        check_handler = self.is_builtin() and pft is not None
        rule_list = []
        for rule in self.__rule_list:
            target_chain = rule.get_target_chain()
            if target_chain is None:
                continue
            if (check_handler and
                    not pft.is_handler_of(target_chain.get_real_name())):
                continue
            if target_chain.get_logical_name() == logical_chain_name:
                rule_list.append(rule)
        return rule_list

    def find_rule_by(self, *, match: Optional['Match'] = None,
                     is_only_match=True,
                     target: Optional[Target] = None) -> List[ChainRule]:
        """Return a list of :class:`ChainRule` objects where the rule
        contains the specified ``match`` object or has the specified ``target``
        (target comparison is by name).

        If both ``match`` and ``target`` are specified, returned rules
        must satisfy both criteria.

        If no ``match`` or ``target`` is present, an empty list is returned.

        :param match: :class:`Match` object to compare against;
            if ``match`` is ``None``, do not perform any match comparisons;
            if ``match`` is a :class:`MatchNone` object, this will
            match a rule that has no matches
        :param is_only_match: if ``True`` the specified ``match`` must
            be the only match used in the rule
        :param target: :class:`Target` object to compare against;
            if ``target`` is ``None``, do not perform any target comparisons;
            if ``target`` is a :class:`TargetNone` object, this will
            match a rule that has no target
        """
        if match is None and target is None:
            return []
        return [rule for rule in self.__rule_list
                if (match is None or rule.has_match(match, is_only_match)) and
                (target is None or rule.has_target(target))]

    def get_pft(self) -> 'IptablesPacketFilterTable':
        """Returns the :class:`IptablesPacketFilterTable` where this
        chain belongs
        """
        return self.__pft

    def _set_pft(self, pft) -> None:
        """Set the :class:`IptablesPacketFilterTable` where this
        :class:`Chain` belongs.

        :param pft: an :class:`IptablesPacketFilterTable` object
        """
        self.__pft = pft

    def _clear_pft(self) -> None:
        """Reset the :class:`IptablesPacketFilterTable` where this
        :class:`Chain` belongs.
        """
        self.__pft = None

    def flush(self) -> None:
        """Delete all rules from this chain

        Any exception raised while running iptables is logged and
        re-raised; in that case the rule list is left untouched.
        """
        try:
            args = ['-F', self.__real_chain_name]
            _ = self.__pft.iptables_run(args, check=True)
        except Exception:
            _logger.exception("Failed to flush chain %s",
                              self.__real_chain_name)
            raise
        # The flush succeeded; bookkeeping failures for individual rules
        # are logged but must not prevent clearing the rule list.
        for rule in self.__rule_list:
            try:
                self.__dec_target_refcount(rule.get_target())
                rule._deleted()         # pylint: disable=protected-access
            except Exception:           # pylint: disable=broad-except
                _logger.exception("%s: unexpected exception",
                                  self.flush.__qualname__)
        self.__rule_list.clear()

    def _setref(self, count: int) -> None:
        """Set the chain reference count
        """
        self.__reference_count = count

    def _incref(self) -> None:
        """Increase the chain reference count
        """
        self.__reference_count += 1

    def _decref(self) -> None:
        """Decrease the chain reference count
        """
        self.__reference_count -= 1
        if self.__reference_count < 0:
            # This shouldn't happen
            _logger.warning("Negative refcount for chain %s",
                            self.__real_chain_name)

    def __inc_target_refcount(self, target: Target) -> None:
        """If target is a :class:`ChainTarget`, increase the refcount of the
        corresponding chain.

        :param target: a :class:`Target` object
        """
        if not isinstance(target, ChainTarget):
            return
        chain = target.resolve_chain(self.__pft)
        if chain is not None:
            chain._incref()         # pylint: disable=protected-access
        else:
            _logger.warning("Missed refcount increase for chain %s",
                            target.get_target_name())

    def __dec_target_refcount(self, target: Target) -> None:
        """If target is a :class:`ChainTarget`, decrease the refcount of the
        corresponding chain.

        :param target: a :class:`Target` object
        """
        if not isinstance(target, ChainTarget):
            return
        chain = target.resolve_chain(self.__pft)
        if chain is not None:
            chain._decref()         # pylint: disable=protected-access
        else:
            _logger.warning("Missed refcount decrease for chain %s",
                            target.get_target_name())

    def __added_rule(self, rule: ChainRule, rulenum: int):
        """Perform the bookkeeping for a rule that was just added to the
        rule list as rule number ``rulenum``.
        """
        # pylint: disable=protected-access
        rule._set_chain(self, rulenum)
        self.__inc_target_refcount(rule.get_target())
        # The new rule is at index rulenum-1; every rule after it
        # has moved down by one position.
        for later_rule in self.__rule_list[rulenum:]:
            later_rule._inc_rulenum()
        # pylint: enable=protected-access

    def __verify_no_owner(self, rule: ChainRule) -> None:
        """Raise an exception if this rule has an owner
        """
        owner_chain = rule.get_chain()
        if owner_chain is None:
            return
        if owner_chain is self:
            raise IptablesError('rule already in this chain')
        _logger.error("attempt to insert rule of chain '%s' to '%s'",
                      owner_chain, self)
        raise IptablesError('rule belongs to another chain')

    def append_rule(self, rule: ChainRule) -> None:
        """Append the new rule at the end of the chain

        Raises an :exc:`IptablesError` if the rule is already part of a chain

        :param rule: the :class:`ChainRule` to append
        """
        self.__verify_no_owner(rule)
        rule_args = rule.to_iptables_args()
        if not rule_args:
            _logger.warning("%s: rule has no args: %s",
                            self.append_rule.__qualname__, rule)
            return
        args = ['-A', self.__real_chain_name] + rule_args
        try:
            _ = self.__pft.iptables_run(args, check=True)
        except Exception:
            _logger.exception("Failed to append rule to chain %s",
                              self.__real_chain_name)
            raise
        self.__rule_list.append(rule)
        self.__added_rule(rule, rulenum=len(self.__rule_list))

    def insert_rule(self, rule: ChainRule, rulenum: int = 0) -> None:
        """Insert the new rule at the beginning of the chain (by
        default) or as rule number ``rulenum``.

        Raises an :exc:`IptablesError` if the rule is already part of a
        chain, or if ``rulenum`` is negative or more than one past the
        current number of rules.

        :param rule: the :class:`ChainRule` to insert
        :param rulenum: rule number (starting with 1) for the inserted
            rule
        """
        self.__verify_no_owner(rule)
        rule_args = rule.to_iptables_args()
        if not rule_args:
            _logger.warning("%s: rule has no args: %s",
                            self.insert_rule.__qualname__, rule)
            return
        if rulenum < 0:
            raise IptablesError(f'invalid rule number: {rulenum}')
        rule_index = rulenum-1 if rulenum > 0 else 0
        # list.insert() never raises IndexError (it clamps the index),
        # so the range must be checked explicitly.
        if rule_index > len(self.__rule_list):
            raise IptablesError(f'rule number out-of-range: {rulenum}')
        args = ['-I', self.__real_chain_name, str(rule_index+1)] + rule_args
        try:
            _ = self.__pft.iptables_run(args, check=True)
        except Exception:
            _logger.exception("Failed to insert rule to chain %s",
                              self.__real_chain_name)
            raise
        # Only update our view of the chain once iptables has succeeded
        self.__rule_list.insert(rule_index, rule)
        self.__added_rule(rule, rulenum=rule_index+1)

    def delete_rule(self, rule: ChainRule) -> None:
        """Delete the specified ``rule``.

        Raises an :exc:`IptablesError` if the rule is not part of this chain.
        """
        owner_chain = rule.get_chain()
        if owner_chain is not self:
            _logger.error(
                "attempt to delete rule of chain '%s' from chain '%s'",
                owner_chain, self)
            raise IptablesError('attempt to delete rule from wrong chain')
        rule_index = rule.get_rulenum() - 1
        # Guard against an index outside the list (including negative
        # values, which would otherwise wrap around).
        if (not 0 <= rule_index < len(self.__rule_list) or
                self.__rule_list[rule_index] is not rule):
            _logger.error("%s: wrong rule index '%d'; ChainRule: %s",
                          self.delete_rule.__qualname__, rule_index, rule)
            raise IptablesError('internal rule list error')
        self.__delete_rule_at(rule_index)

    def delete_rulenum(self, rulenum: int) -> None:
        """Delete the rule with the specified rule number

        Raises an :exc:`IptablesError` if the number is invalid

        :param rulenum: rule number (numbering starts from 1)
        """
        rule_index = rulenum - 1
        if 0 <= rule_index < len(self.__rule_list):
            self.__delete_rule_at(rule_index)
        else:
            raise IptablesError(f'bad rule number: {rulenum}')

    def __delete_rule_at(self, rule_index: int) -> None:
        """Delete the rule at index ``rule_index`` in the rule_list.

        This is the method that actually performs the deletion.
        If iptables fails, the exception is logged and re-raised and
        the rule list is left unchanged.
        """
        # Look the rule up first so that a bad index fails before
        # iptables is invoked.
        rule = self.__rule_list[rule_index]
        # iptables enumerates rules starting from 1
        rulenum = rule_index + 1
        cmd = ['-D', self.__real_chain_name, f'{rulenum}']
        try:
            _ = self.__pft.iptables_run(cmd, check=True)
        except Exception:
            _logger.exception("unable to delete rule %d from chain %s",
                              rulenum, self.__real_chain_name)
            raise
        del self.__rule_list[rule_index]
        # pylint: disable=protected-access
        rule._deleted()
        self.__dec_target_refcount(rule.get_target())
        #
        # Renumber rules after the deleted rule
        #
        for later_rule in self.__rule_list[rule_index:]:
            later_rule._dec_rulenum()
        # pylint: enable=protected-access

    def __delete_rules(self, rule_list: List[ChainRule]) -> int:
        """Delete a number of rules

        :param rule_list: rules of this chain to delete; the list itself
            is not modified
        :rtype: number of deleted rules
        """
        if not rule_list:
            return 0
        # We can delete rules in any order.
        # The reason for sorting in reverse rule number order is that it
        # makes debugging easier as the rule numbers of the rules being
        # deleted do not change.
        ordered = sorted(rule_list, key=lambda r: r.get_rulenum(),
                         reverse=True)
        for rule in ordered:
            self.__delete_rule_at(rule.get_rulenum()-1)
        return len(ordered)

    def delete_rule_by_pred(self, pred: Callable[[ChainRule], bool]) -> int:
        """Delete all rules for which ``pred`` returns ``True``.

        :param pred: a ``Callable`` object
        :rtype: number of deleted rules
        """
        deletion_list = [rule for rule in self.__rule_list if pred(rule)]
        return self.__delete_rules(deletion_list)

    def delete_rule_if(self, *, match: Optional['Match'] = None,
                       target: Optional[Target] = None) -> int:
        """Delete all rules with the specified ``match`` and/or ``target``.

        If no ``match`` or ``target`` is present, this is a no-op.

        :param match: :class:`Match` object to compare against;
            the comparison will be successful if this is the **only** match
            used in the rule;
            if ``match`` is ``None``, do not perform any match comparisons;
            if ``match`` is a :class:`MatchNone` object, this will
            match a rule that has no matches
        :param target: :class:`Target` object to compare against;
            if ``target`` is ``None``, do not perform any target comparisons;
            if ``target`` is a :class:`TargetNone` object, this will
            match a rule that has no target
        :rtype: number of deleted rules
        """
        if match is None and target is None:
            return 0
        deletion_list = self.find_rule_by(match=match, target=target)
        return self.__delete_rules(deletion_list)

    def delete_rule_by_target_chain(self, chain: 'Chain') -> int:
        """Delete all rules that jump/goto the specified chain.

        :param chain: a :class:`Chain` object
        :rtype: number of deleted rules
        """
        return self.delete_rule_by_pred(pred=lambda r: r.targets_chain(chain))

    def zero_counters(self) -> None:
        """Zero the packet and byte counters of this chain in the kernel.

        Raises an :exc:`IptablesError` if the chain is not associated
        with a packet filter table.
        """
        if self.__pft is None:
            raise IptablesError('chain not in kernel')
        self.__pft.zero_counters(chain=self)

    @classmethod
    def __parse_chain_line(cls, line: str) -> 'Chain':
        """Parse a line which has one of the following 2 forms:

            Chain INPUT (policy ACCEPT 9108340 packets, 10054611039 bytes)
            Chain host_origin (1 references)

        Returns a Chain object (a :class:`BuiltinChain` for the first form).

        It raises an IptablesParsingError in case of a parsing error.
        """
        fields = line.split(' ', 2)
        n_fields = len(fields)
        if n_fields != 3:
            raise IptablesParsingError(
                f'line has {n_fields} field(s) instead of 3', line=line)
        if fields[0] != 'Chain':
            raise IptablesParsingError("line does not start with 'Chain'",
                                       line=line)
        real_chain_name = fields[1]
        packet_count = 0
        byte_count = 0
        policy = None
        reference_count = 0
        try:
            # Drop the first and last character (the enclosing
            # parentheses in the forms shown above)
            param_fields = fields[2][1:-1].split()
            if param_fields[0] == 'policy':
                target_name = param_fields[1]
                policy = Targets.get_special(target_name)
                if policy is None:
                    raise IptablesParsingError(
                        f"unknown policy target for chain {real_chain_name}",
                        line=line)
                if param_fields[3].startswith('packets'):
                    packet_count = int(param_fields[2])
                if param_fields[5].startswith('bytes'):
                    byte_count = int(param_fields[4])
            elif param_fields[1] == 'references':
                reference_count = int(param_fields[0])
            else:
                # Neither of the two known forms; fail explicitly instead
                # of handing None back to the caller.
                _logger.warning("unable to parse line: %s", line)
                raise IptablesParsingError(
                    'unable to parse chain parameters', line=line)
            if policy is not None:
                return BuiltinChain(real_chain_name, policy,
                                    packet_count, byte_count)
            chain = Chain(real_chain_name)
            chain._setref(reference_count)
            return chain
        except IndexError as idxerr:
            raise IptablesParsingError(
                'insufficient number of fields', line=line) from idxerr
        except ValueError as valerr:
            raise IptablesParsingError(
                'bad field value', line=line) from valerr

    @classmethod
    def create_from_existing(cls, line_list: List[str],
                             pft: 'IptablesPacketFilterTable',
                             log_parsing_failures=True) -> 'Chain':
        """Parse a set of lines from the output of ``iptables -xnv``
        into a :class:`Chain` object.

        The first line describes the chain, the second line contains the
        column headers, and each remaining non-empty line describes a rule.
        A rule line that cannot be parsed results in an unparsed rule
        rather than a failure.

        It raises an :exc:`IptablesParsingError` if there is a parsing error.

        :param line_list: list of **iptables(8)** output lines
        :param pft: an :class:`IptablesPacketFilterTable` object
        :param log_parsing_failures: if ``True``, log any parsing failures
        :rtype: a :class:`Chain` object
        """
        line_iter = iter(line_list)
        try:
            chain_line = next(line_iter)
        except StopIteration as stopit:
            raise IptablesParsingError(
                'chain output lines missing chain line') from stopit
        chain = cls.__parse_chain_line(chain_line.rstrip())
        # The next line contains the headers - skip it
        try:
            _ = next(line_iter)
        except StopIteration as stopit:
            raise IptablesParsingError(
                'chain output lines missing headers') from stopit
        rule_list = []
        for line in line_iter:
            line = line.rstrip()
            if not line:
                continue
            try:
                rule = ChainRule.create_from_existing(line, pft)
            except IptablesParsingError:
                if log_parsing_failures:
                    _logger.exception(
                        "%s: chain=%s: error parsing rules; "
                        "will create unparsed rule",
                        cls.create_from_existing.__qualname__,
                        chain.get_real_name())
                rule = ChainRule._create_unparsed_rule(line)  # pylint: disable=protected-access
            rule_list.append(rule)
        chain._set_rule_list(rule_list)     # pylint: disable=protected-access
        return chain
class BuiltinChain(Chain):
    """Representation of an iptables built-in chain.

    In addition to what a :class:`Chain` offers, a built-in chain carries
    a policy target together with the packet/byte counts attributed
    to that policy.

    Users are not expected to instantiate this class directly; instances
    are produced while processing the output of **iptables(8)**
    """

    def __init__(self, chain_name, policy: Target,
                 packet_count: int, byte_count: int):
        """
        :param chain_name: builtin chain name
        :param policy: chain policy target
        :param packet_count: number of packets that were processed according
            to the policy
        :param byte_count: number of bytes for packets that were processed
            according to the policy
        """
        super().__init__(chain_name)
        self.__policy_byte_count = byte_count
        self.__policy_packet_count = packet_count
        self.__policy = policy

    def __str__(self):
        return f'BuiltinChain({self.get_real_name()})'

    @staticmethod
    def is_builtin() -> bool:
        """
        :rtype: always returns ``True``
        """
        return True

    def get_policy(self) -> Target:
        """Returns the current policy target of the chain
        """
        return self.__policy

    def get_policy_packet_count(self) -> int:
        """Returns the count of packets that were handled according to
        the chain policy
        """
        return self.__policy_packet_count

    def get_policy_byte_count(self) -> int:
        """Returns the count of bytes that were handled according to
        the chain policy
        """
        return self.__policy_byte_count

    def _set_stats(self, packet_count: int, byte_count: int) -> None:
        """Feed the chain statistics: the specified counts, increased by
        the policy packet/byte counts, are passed to ``_update_stats``.
        """
        self._update_stats(packet_count + self.__policy_packet_count,
                           byte_count + self.__policy_byte_count)

    def set_policy(self, policy: Target) -> None:
        """Change the policy target of this builtin chain

        The new policy is recorded only after iptables has accepted it;
        a failure is logged and the exception is re-raised.

        :param policy: the new policy target
        """
        chain_name = self.get_real_name()
        iptables_args = ['-P', chain_name, policy.get_target_name()]
        try:
            self.get_pft().iptables_run(iptables_args, check=True)
        except Exception:
            _logger.exception("Failed to set policy of builtin chain %s",
                              chain_name)
            raise
        self.__policy = policy