#! /usr/bin/env python
# -*- mode:python -*-
# $Id: dfd_keeper.py 13107 2010-08-20 15:15:27Z vax $

"""
Dynamic Firewall Daemon (dfd)
python/pf implementation (a/k/a the bridge keeper)

The dynamic firewall daemon sets up and maintains your packet filter.

Note that using this means you don't have to remember what order the
pf rules must be in; it takes care of that for you.

For more information:
<URL:http://www.subspacefield.org/security/dfd_keeper/>

Copyright 2005-2009 travis+c-dfd@subspacefield.org

This work is licensed under the Creative Commons
Attribution-Noncommercial-No Derivative Works 3.0 United States
License:

http://creativecommons.org/licenses/by-nc-nd/3.0/us/
"""

_revision="$Revision: 13107 $"

import time # So we can expire rules.
import os # So we can invoke commands.
import sets # Another listy kind of thing.
import pprint
import string
import subprocess
import syslog
from UserDict import DictMixin
from exceptions import KeyError

## These are classes related to rule creation.

# Implemented as a pair of lists.
# Derive from a mixin that adds full dict methods based on __xxxitem__
class named_list(DictMixin):
    """An insertion-ordered mapping whose iteration yields *values*.

    Keys remember the order in which they were first added.  Unlike a
    normal dictionary, iterating over the object returns the values
    (the first key's value, then the second key's, and so on), which lets
    rule containers treat it as a sequence of children.  It is
    implemented with two parallel lists.
    """
    # NOTE: The storage lists are called _keys/_values so that they do not
    # shadow the keys()/values() methods which the dictionary protocol
    # (and DictMixin itself) needs to be able to call.

    def __init__(self, list_of_tuples):
        """Build the mapping from a sequence of (key, value) pairs."""
        self._keys = [pair[0] for pair in list_of_tuples]
        self._values = [pair[1] for pair in list_of_tuples]

    # This allows us to treat it as a sequence.
    def __iter__(self):
        """This returns an iter object over the values, not keys."""
        return iter(self._values)

    def __len__(self):
        """Return the number of stored entries."""
        return len(self._values)

    def __contains__(self, key):
        """Return True iff key is one of the stored keys."""
        return key in self._keys

    # TODO: Should raise IndexError for invalid numeric indices
    # so that for loops work properly.
    def __getitem__(self, key):
        """Return the value stored under key; raise KeyError if absent."""
        try:
            return self._values[self._keys.index(key)]
        except ValueError:
            raise KeyError(key)

    def __setitem__(self, key, value):
        """Replace the value of an existing key, or append a new pair."""
        try:
            self._values[self._keys.index(key)] = value
        except ValueError:
            self._keys.append(key)
            self._values.append(value)
        return value

    def __delitem__(self, key):
        """Remove key and its value; raise KeyError if absent."""
        try:
            i = self._keys.index(key)
        except ValueError:
            raise KeyError(key)
        del self._keys[i]
        del self._values[i]

    # These are optional but provide higher performance.
    def keys(self):
        """Return a copy of the keys, in insertion order."""
        return list(self._keys)

    def values(self):
        """Return a copy of the values, in key insertion order."""
        return list(self._values)

    def contains(self, key):
        """Named alias for the ``in`` operator."""
        return key in self._keys

    def iteritems(self):
        """Return (key, value) pairs in insertion order."""
        return zip(self._keys, self._values)

class apply_mixin:
    """
    Visitor-pattern mix-in for containers that (directly or indirectly)
    hold rule objects.  The hierarchy is walked depth-first and the chain
    of ancestors is rebuilt on the way down, so neither rules nor
    containers need to store a reference to their parent.
    """
    def apply(self, fun, *ancestors):
        """
        Call apply(fun, ...) on every child of this collection, passing
        this container followed by its own ancestors as the lineage.
        Returns the list of the children's results.
        """
        lineage = (self,) + ancestors
        results = []
        for child in self:
            results.append(child.apply(fun, *lineage))
        return results

class render_mixin:
    """
    This is a superclass for rule containers that may need to be rendered
    as a list of rules, with all variables expanded.

    The class using this mix-in must be iterable over its children and
    provide remove(); each child must provide expired(), valid() and
    render().
    """
    # TODO: Figure out a cleaner way of doing this.
    # TODO: Test it better.
    def valid(self):
        """
        Rule superclasses default to being valid.
        This determines if they appear in render output.
        """
        return True

    def expired(self):
        """Containers themselves never expire."""
        return False

    def render(self, namespace=None):
        """
        Concatenate the rendering of all elements of this sequence.

        Expired children are removed from the container first.  The
        namespace argument is accepted for interface compatibility but is
        not used: children are rendered by calling render() with no
        arguments.
        """
        # Iterate over a snapshot: removing from the container while
        # iterating over it directly skips the element after each removal.
        for child in list(self):
            if child.expired():
                self.remove(child)
        return self.flatten([child.render() for child in self
                             if child.valid()])

    def flatten(self, l):
        """
        Flatten one level deep.  A list of lists becomes a list.
        Since this is done at the level of every rule container, it
        eventually returns a list, without any nested lists.
        """
        flat = []
        for chunk in l:
            flat.extend(chunk)
        return flat

class rule_container(apply_mixin, render_mixin):
    """
    Common base for rule containers: combines apply_mixin and
    render_mixin and remembers the namespace given at construction.
    """
    def __init__(self, namespace):
        # Kept so containers can hand it on, e.g. rule_holder.make_rule
        # passes it to every rule it creates.
        self.namespace = namespace

# Needs to have only certain sections
# They have a predefined order when rendering for use by pfctl
class ruleset(named_list, rule_container):
    """
    The complete ruleset: one rule_holder per section, keyed by the
    section name and kept in the order given by ``sections``.
    """

    sections = ("macros", "tables", "options", "scrub",
                "queueing", "translation", "filter")

    def __init__(self, namespace):
        """Create an empty rule_holder for every section."""
        rule_container.__init__(self, namespace)
        holders = []
        for section in self.sections:
            holders.append((section, rule_holder(namespace)))
        # The base class takes (key, value) pairs.
        named_list.__init__(self, holders)

    def __str__(self):
        """Pretty-printed form of the concatenated section strings."""
        text = "".join(map(str, self))
        return pprint.PrettyPrinter(indent=4).pformat(text)

    def __repr__(self):
        """Show the repr of every section holder, in order."""
        return "ruleset(%s)" % ", ".join(map(repr, self))

class rule_holder(list, rule_container):
    """A list of rule objects that shares one namespace with them."""

    def __init__(self, namespace):
        """Start out empty, remembering the namespace for new rules."""
        rule_container.__init__(self, namespace)

    def make_rule(self, *args, **kw):
        """
        Construct a rule from the given arguments (using this holder's
        namespace), append it to the holder and return it.
        """
        new_rule = rule(self.namespace, *args, **kw)
        self.append(new_rule)
        return new_rule

    # Short alias for make_rule.
    m = make_rule

    def __str__(self):
        """One line per contained rule."""
        return "\n".join(map(str, self))

    def __repr__(self):
        """Show the repr of every contained rule."""
        return "rule_holder(%s)" % ", ".join(map(repr, self))

class rule:
    """This is a particular pf rule."""

    def __init__(self, namespace, rule, **keywords):
        """
        Create a rule from its text.

        Pass in an empty string if you want a no-op.  Recognised keywords:
        ``lifespan`` (seconds from now until the rule expires) or
        ``expires`` (absolute time, as returned by time.time()); the two
        are mutually exclusive and ParameterError is raised if both are
        given.  Every remaining keyword (including ``active``) is set as
        an attribute on the rule.
        """
        # Strip any leading or trailing whitespace.
        rule = rule.strip()
        # Just in case we need the text later, store it.
        self.rule = rule
        self.namespace = namespace
        self.tag = None
        # If the rule is a null rule, mark it inactive (unless the caller
        # says otherwise via the "active" keyword, applied below).
        self.active = not (rule == "" and "active" not in keywords)
        if "lifespan" in keywords:
            if "expires" in keywords:
                raise ParameterError("Cannot specify lifespan and expires")
            # Convert the relative lifespan into an absolute expiry time.
            keywords["expires"] = time.time() + keywords.pop("lifespan")
        for name, value in keywords.items():
            setattr(self, name, value)

    def expired(self):
        """Detect if this rule has expired or not."""
        try:
            expires = self.expires
        except AttributeError:
            # No expiry was ever set; the rule lives forever.
            return False
        return expires <= time.time()

    def timeout(self):
        """
        Return number of seconds until timeout (0 if already past), or
        None if the rule never expires.  Used by runtime.
        """
        try:
            expires = self.expires
        except AttributeError:
            return None
        seconds_left = expires - time.time()
        if seconds_left <= 0:
            return 0
        return seconds_left

    def apply(self, fun, *ancestors):
        """
        Call fun on this rule and its lineage (ancestors).
        This is the execution point for performing operations on rules
        and is typically called from the apply class's depth-first
        search.
        """
        return fun(self, *ancestors)

    def valid(self):
        """Returns True if this is a valid rule at this moment."""
        return self.active and not self.expired()

    def __repr__(self):
        """Show the rule text, flagged when the rule is inactive."""
        if self.active == False:
            suffix = ", inactive"
        else:
            suffix = ""
        return "rule(" + str(self) + suffix + ")"

    def __str__(self):
        """Show the text form of the rule."""
        return self.rule

    def render(self):
        """
        Render this rule, resolving variable references.

        Returns a one-element list holding the expanded text, or an empty
        list if the rule is not valid or a referenced value is empty.
        """
        if not self.valid():
            return []
        try:
            # Terminate all rules with newlines when rendering.
            return [template(self.rule).substitute(self.namespace) + "\n"]
        except EmptyValue:
            return []

class template(str):
    """Inspired by string.Template."""

    def __init__(self, data):
        """Remember the raw text that substitute() will expand."""
        self.template = data

    def substitute(self, ns):
        """
        Expand every ``$name`` reference using the mapping ns.

        Longer names are tried first so that ``$ab`` is not mistaken for
        ``$a`` followed by "b".  A value is expanded by calling its
        render() method if it has one, otherwise its __str__(); this
        allows nesting rules inside other rules.  Up to ten rounds of
        expansion are performed so that expanded values may themselves
        contain references.  Raises KeyError if a ``$`` is left over.
        """
        result = self.template
        # If there's any expansion to be done,
        if '$' in result:
            # sorted() copes with any iterable of keys, not just a list.
            names = sorted(ns.keys(), key=len, reverse=True)
            # Now allow ten levels of nested expansions
            for _ in range(10):
                changed = False
                for name in names:
                    marker = '$' + name
                    if marker in result:
                        value = ns[name]
                        renderer = getattr(value, "render", value.__str__)
                        # Replace the occurrences of the variable.
                        result = result.replace(marker, renderer())
                        changed = True
                # A round that replaced nothing means we are finished.
                if not changed:
                    break
        # Throw an exception if we still have variables
        # TODO: Do not throw exception if the variable was in a comment.
        if '$' in result:
            pp = pprint.PrettyPrinter(indent=4)
            raise KeyError(pp.pformat(ns) + "\n" + result)
        return result

class rule_queue(rule_holder):
    """A rule_holder that drops its oldest rules to stay within a
    fixed length (first in, first out)."""

    def __init__(self, namespace, length, **keywords):
        """Record the maximum length, then initialise the holder."""
        self.length = length
        rule_holder.__init__(self, namespace, **keywords)

    def make_rule(self, *args, **kw):
        """
        Make room for one more rule by deleting from the front of the
        queue if needed, then create and append the rule.
        """
        overflow = len(self) - self.length + 1
        if overflow > 0:
            del self[:overflow]
        return rule_holder.make_rule(self, *args, **kw)

    # Re-bind the alias so it refers to this class's make_rule.
    m = make_rule

# TODO: Add a LRU rule queue class that detects duplicates.
# Adding a rule again makes it "used", going to the head of the queue.

## These are classes related to logging, making them part of the
## runtime system.

# Logs to syslog
class syslogger:
    """Logging backend that hands every message to syslog."""

    def __init__(self, ident, facility=syslog.LOG_DAEMON):
        """Open syslog under the given ident (with LOG_PID set)."""
        self.ident = ident
        syslog.openlog(ident, syslog.LOG_PID, facility)

    def log(self, message, priority=syslog.LOG_INFO):
        """Send message to syslog at the given priority."""
        syslog.syslog(priority, message)

# Logs to stdout
class outlogger:
    """Logging backend that prints every message to standard output."""

    def __init__(self, ident):
        """Remember the ident that prefixes each printed line."""
        self.ident = ident

    def log(self, message, priority=None):
        """Print "<ident> <message>"; priority is accepted but ignored."""
        # Parenthesised single-argument form prints the same text under
        # both the print statement and the print function.
        print(self.ident + " " + message)

# Singleton class for logger
class logger:
    """
    Front end for logging.  The chosen backend is stored on the class
    itself (logger.my_log), so every instance shares it.
    """

    def __init__(self):
        pass

    def setlogger(self, syslog, ident='dfd'):
        """Select the syslog backend if syslog is true, else stdout."""
        backend = syslogger if syslog else outlogger
        logger.my_log = backend(ident)

    def log(self, message, priority=syslog.LOG_INFO):
        """Forward the message to the backend chosen by setlogger()."""
        logger.my_log.log(message, priority)

## These are exceptions, used mainly by the runtime system.
class ParameterError(StandardError):
    """Raised when a method is given an invalid parameter."""

class FlushError(StandardError):
    """Raised when pf reports an error while flushing state."""

class CommandNotFoundError(StandardError):
    """Raised when a requested command does not exist."""

## This is the main class of the runtime system.

class pf:
    """
    This represents the entire state of the bridge keeper.

    It holds the ruleset, the namespace from which pf variables get
    their contents, the operating mode, and the last rule listings that
    were synced or flushed (used to skip redundant work).  The packet
    filter itself is driven by running pfctl.
    """

    # You can invoke dfd in these modes.
    modes = ("daemon", "normal", "syntax", "test")

    # TODO: Consider using new interpolated strings to specify pf rules.

    def __init__(self, ruleset, namespace=None, mode="normal", syslog=True):
        """
        ruleset   -- all the rules
        namespace -- mapping where pf variables get their contents; a
                     fresh dict is created when omitted
        mode      -- one of pf.modes; ParameterError is raised otherwise
        syslog    -- true to log via syslog, false to log to stdout
        """
        # These are all the rules
        self.ruleset = ruleset
        # A default of dict() in the signature would be one dictionary
        # shared by every instance, so create it per instance instead.
        if namespace is None:
            namespace = {}
        self.namespace = namespace
        # Error-check and set mode.
        if mode not in pf.modes:
            raise ParameterError("Invalid mode to pf constructor.")
        self.mode = mode
        # Set our history up.
        self.last_sync = None
        self.last_flush = None
        logger().setlogger(syslog)

    def __str__(self):
        return str(self.ruleset)

    def __repr__(self):
        """Show the state attributes followed by the ruleset's repr."""
        fields = ("namespace", "mode", "last_sync", "last_flush")
        parts = ["%s=%r" % (name, getattr(self, name)) for name in fields]
        parts.append(repr(self.ruleset))
        return "keeper(" + ", ".join(parts) + ")"

    def _pfctl_command(self, *args):
        """Build the pfctl argument list; syntax mode inserts -n."""
        cmd = ["pfctl"]
        if self.mode == "syntax":
            cmd.append("-n")
        cmd.extend(args)
        return cmd

    def run_simple(self, *args):
        """
        Run a simple pfctl command and return the exit code.
        In test mode nothing is run and None is returned.
        """
        if self.mode == "test":
            return
        cmd = self._pfctl_command(*args)
        logger().log("Running " + " ".join(cmd), syslog.LOG_DEBUG)
        # TODO: Trap standard output and error; state flushing is chatty.
        return os.spawnvp(os.P_WAIT, cmd[0], cmd)

    # TODO: Allow for spaces in command line arguments.
    # TODO: Return the standard output and error of the command
    #       so that people can debug their rules.
    def run_stdin(self, user, input, *args):
        """
        Start a pfctl command through the reactor and feed it input on
        standard input.  input may be a string or a list of lines.  When
        the process reports back, send_results_to_user is called with
        user and the collected output.
        """
        # If we're in testing mode, emulate everything
        if self.mode == "test":
            return reactor.callLater(0, self.send_results_to_user, user,
                                     "", "", "")
        cmd = self._pfctl_command(*args)
        # If input is a list, stringify it.
        if isinstance(input, list):
            input = "\n".join(input) + "\n"
        inst = pfctl_process(input, lambda *a, **kw:
                             self.send_results_to_user(user, *a, **kw))
        logger().log("Running " + " ".join(cmd), syslog.LOG_DEBUG)
        reactor.spawnProcess(inst, cmd[0], cmd, os.environ)

    def send_results_to_user(self, user, output, stdout, stderr):
        """
        Deliver the output of a finished command: to the user if there is
        one, otherwise to the log (as an error) when it is non-empty.
        """
        if user:
            return user.command_done(output)
        elif len(output) > 0:
            logger().log("No user: " + output, syslog.LOG_ERR)

    def sync(self, user, force=False):
        """
        Synchronizes the firewall with the internal representation.

        Returns quietly if the rendered rules equal the last synced ones
        and force is false.  Otherwise the reload is started and
        asynchronous_command is raised to signal that output will follow.
        """
        new_rules = self.ruleset.render()
        # If the rules have not changed, and we are not being forced,
        # do nothing.
        if new_rules == self.last_sync and not force:
            return
        # The rules have changed.  Reload them.
        self.run_stdin(user, new_rules, "-f", "/dev/stdin")
        # TODO: Figure out how to detect failed sync.
        self.last_sync = new_rules
        raise asynchronous_command()

    def force_sync(self):
        """
        This forces an initial sync before the reactor starts.

        pfctl is run synchronously; anything it writes to standard error
        is logged line by line, and only a run with empty standard error
        is recorded as the last sync.
        """
        l = logger()
        l.log("Doing initial sync")
        if self.mode == "test":
            return
        cmd = self._pfctl_command("-f", "/dev/stdin")
        rules = self.ruleset.render()
        s = subprocess.Popen(cmd, stdin=subprocess.PIPE,
                             stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        (stdoutdata, stderrdata) = s.communicate("\n".join(rules))
        if stderrdata != "":
            l.log("Errors during initial sync:")
            for err in stderrdata.split("\n"):
                l.log(err)
        else:
            self.last_sync = rules

    def flush(self, src_host=None, dst_host=None):
        """
        This flushes the nat and filter state tables.

        With no src_host all state is flushed, but only if the rendered
        rules differ from those at the previous full flush.  With
        src_host (and optionally dst_host) only matching state is killed.
        Returns whatever run_simple returns.
        """
        # NOTE: This may be passed magic variables or strings,
        # depending on if it is an automatic flush due to list
        # addition/deletion or due to a actual flush command.
        # Running str() an extra time doesn't hurt.
        l = logger()

        if src_host is None:
            # Flush all the state.
            new_rules = self.ruleset.render()
            # If the rules have not changed, do nothing.
            if new_rules == self.last_flush:
                return
            # The rules have changed.  Reload them.
            self.last_flush = new_rules
            l.log("Flushing all state", syslog.LOG_INFO)
            return self.run_simple("-F", "state")

        # Convert the magic thingies into strings.
        src_host = str(src_host)
        # Only flush state rules for a particular host.
        if dst_host is None:
            l.log("Flushing state for " + src_host, syslog.LOG_INFO)
            return self.run_simple("-k", src_host)
        dst_host = str(dst_host)
        l.log("Flushing state from " + src_host
              + " to " + dst_host, syslog.LOG_INFO)
        return self.run_simple("-k", src_host, "-k", dst_host)

    def flush_state(self, *args, **kws):
        """
        Flush the state table.  This is done automatically.
        Raises FlushError if flush() returns a true (non-zero) status.
        """
        rv = self.flush(*args, **kws)
        if rv:
            raise FlushError("pfctl rejected the state flush with "
                             "exit code %s" % str(rv))

## These are python objects representing pf data structures.

class EmptyValue(StandardError):
    """Raised to signal that a rule should not be rendered."""

class macro(str):
    """A string value that refuses to render when empty.

    NOTE: This class is immutable.
    """
    def render(self):
        """Return the plain string; raise EmptyValue if it is empty."""
        if not len(self):
            raise EmptyValue
        return str.__str__(self)

# NOTE: This must come before the list class defined below.
class table(list):
    """A list that carries a table name and prints as ``<name>``."""

    # Class-wide counter used to hand out unique names to unnamed tables.
    numbered_tables = 0

    def __init__(self, name=None, *a, **kw):
        """
        Use the given name, or generate "tableNNN" from the class-wide
        counter.  Remaining arguments are passed to list.__init__.
        """
        if name is None:
            # Pick a unique name and advance the counter.
            name = "table%03d" % table.numbered_tables
            table.numbered_tables += 1
        self.name = name
        return list.__init__(self, *a, **kw)

    def __str__(self):
        """The table name wrapped in angle brackets."""
        return "<%s>" % self.name

class setlist(list):
    """
    A list with a set-flavoured add(): indexable like any list, while
    add() keeps elements unique.
    """
    def add(self, item):
        """Append item unless an equal element is already present."""
        if item in self:
            return
        self.append(item)

class pflist(setlist):
    """A list that renders in braces: ``{ a b c }``."""

    def add(self, item):
        """Append item unconditionally."""
        # TODO: Check for duplicates.
        self.append(item)

    def render(self):
        """
        Render every element as a macro and wrap the space-separated
        result in braces.  Raises EmptyValue if the list is empty.
        """
        if not len(self):
            raise EmptyValue
        # NOTE: Does not allow nested lists.
        rendered = [macro(entry).render() for entry in self]
        return '{ ' + " ".join(rendered) + ' }'

class anchor:
    """
    A pf anchor reference.

    type is one of "rdr", "nat", "binat" or "" (a plain anchor); name is
    the anchor's name.  ParameterError is raised for an unknown type or a
    missing name.
    """
    # TODO: Figure out how to do this typing automagically.
    def __init__(self, type="", name=None, *a, **kw):
        # The constructor previously read an undefined ``name``, so it
        # could never succeed; name is now an explicit argument.
        if name is None:
            raise ParameterError("anchor requires a name")
        if type not in ("rdr", "nat", "binat", ""):
            raise ParameterError("unknown type %s" % type)
        self.name = name
        self.type = type

    def render(self):
        """Return "anchor NAME" or "TYPE-anchor NAME"."""
        # Compare the stored type, not the ``type`` builtin.
        if self.type == "":
            return "anchor %s" % self.name
        else:
            return "%s-anchor %s" % (self.type, self.name)

class magic(setlist):
    """
    This is a mutable type that returns a macro or a list expansion,
    depending on how many elements it contains.
    """
    def __init__(self, a=()):
        """
        Start from a single string (ignored if empty) or from an iterable
        of items; duplicates are dropped by add().
        """
        list.__init__(self)
        # isinstance also recognises str subclasses such as macro, which
        # an exact type comparison would iterate character by character.
        if isinstance(a, str):
            if a != "":
                self.add(a)
        else:
            for e in a:
                self.add(e)

    def render(self):
        """
        Render as a bare macro when there is exactly one element,
        otherwise as a pflist (which raises EmptyValue when empty).
        """
        if len(self) == 1:
            return macro(self[0]).render()
        else:
            return pflist(self).render()

class magic_dict(DictMixin):
    """
    Dictionary wrapper that converts every value assigned through it
    into a magic object.  Reads, deletes and keys() go straight to the
    wrapped dictionary.
    """
    def __init__(self, d):
        # d is the underlying mapping; it is used as-is, not copied.
        self.dict = d
    def __getitem__(self, key):
        return self.dict[key]
    def __setitem__(self, key, value):
        # Wrap on assignment only; values already present in d are not
        # converted here.
        self.dict[key] = magic(value)
    def __delitem__(self, key):
        del self.dict[key]
    def keys(self):
        return self.dict.keys()

## Here begins the Twisted runtime components

import exceptions

import textwrap

import twisted.internet.protocol

import twisted.protocols.basic

from twisted.internet import reactor

class tcp_factory(twisted.internet.protocol.Factory):
    """
    Holds the persistent state shared by all connections.

    The factory is set up once and then waits.  For each incoming
    connection buildProtocol creates a protocol object and a commander
    object and links them to each other: the protocol passes lines of
    text to the commander, which executes commands for the user.

    It is tailored for TCP command takers.  You choose the commander
    factory and the protocol to instantiate, and may optionally set the
    banner, the responses for valid, unknown and invalid commands, and
    the prompt; use an empty string to disable any of them.  The names
    of the quit and stop commands are configurable as well.
    """

    # TODO: Implement a command to shut the whole system down.
    def __init__(self, commander_factory, protocol,
                 banner = "Your wish is my command.",
                 ok_response = "It is done.",
                 bad_response = "Command not found.",
                 invalid_response = "Invalid argument(s) to command.",
                 prompt = "dfd_keeper>",
                 quit_commands = ("quit", "exit"),
                 stop_commands = ("stop",)):
        """Store the configuration; messages get a trailing newline."""
        # What to build for each connection.
        self.commander_factory = commander_factory
        self.protocol = protocol
        # Text sent to the client.  The prompt is stored untouched.
        self.banner = _add_nl(banner)
        self.ok_response = _add_nl(ok_response)
        self.invalid_response = _add_nl(invalid_response)
        self.bad_response = _add_nl(bad_response)
        self.prompt = prompt
        # Command words that are handled before dispatching.
        self.quit_commands = quit_commands
        self.stop_commands = stop_commands
        logger().log("Starting")

    def buildProtocol(self, addr):
        """
        A user has connected: log it, build the protocol and its
        commander, cross-link them, and return the protocol.
        """
        logger().log("Connection from " + str(addr), syslog.LOG_INFO)
        # TODO: Implement address-based access control here?
        proto = twisted.internet.protocol.Factory.buildProtocol(self, addr)
        commander = self.commander_factory()
        proto.commander = commander
        proto.peer = addr
        commander.user = proto
        return proto

class pfctl_process(twisted.internet.protocol.ProcessProtocol):
    """
    Class for Twisted to interact with pfctl process.

    The given input is written to the process's standard input.  Data
    coming back is collected three ways: ``stdout`` and ``stderr`` hold
    each stream separately and ``output`` holds both in arrival order.
    When standard output closes, callback(output, stdout, stderr) is
    invoked.
    """

    def __init__(self, input, callback):
        self.input = input
        self.callback = callback
        self.output = self.stdout = self.stderr = ""

    def connectionMade(self):
        """Log the input, send it to the process and close its stdin."""
        l = logger()
        l.log(self.input, syslog.LOG_DEBUG)
        self.transport.write(self.input)
        self.transport.closeStdin()

    def outReceived(self, data):
        """Record (and log at debug level) data from standard output."""
        logger().log(data, syslog.LOG_DEBUG)
        self.stdout += data
        self.received(data)

    def errReceived(self, data):
        """Record (and log as a warning) data from standard error."""
        logger().log(data, syslog.LOG_WARNING)
        self.stderr += data
        self.received(data)

    def received(self, data):
        """Append data from either stream to the combined output."""
        self.output += data

    def outConnectionLost(self):
        """Standard output has closed; report what was collected."""
        self.callback(self.output, self.stdout, self.stderr)

class line_receiver(twisted.protocols.basic.LineOnlyReceiver):
    """
    This class builds on the LineOnlyReceiver because it doesn't need
    the CR at EOL.
    """

    delimiter = "\n" # See?  No \r.

    def __init__(self):
        pass

    def lineReceived(self, line):
        """Strip an optional trailing CR, log the line and hand it on."""
        # Strip the CR from CRLF EOL convention, if it's there.
        if line.endswith("\r"):
            line = line[:-1]
        l = logger()
        l.log("Command by " + str(self.peer) + ": " + line, syslog.LOG_INFO)
        return self.line_received(line)

    def line_received(self, line):
        """Override this to take lines of input from the socket."""
        raise NotImplementedError

    def send_line(self, line):
        """Be conservative in what you send; always use CRLF."""
        # The delimiter appended by sendLine is "\n", so add the "\r".
        return self.sendLine(line + "\r")

    def write_data(self, data):
        """Convert LF to CRLF, if it needs it, then write the data."""
        # Strings are immutable: the result of replace() must be kept.
        data = data.replace("\n", "\r\n")
        # Lines that already ended in CRLF now end in CR CR LF; undo that.
        data = data.replace("\r\r\n", "\r\n")
        return self.transport.write(data)

def _add_nl(s):
    """
    Return s terminated by exactly the newline it needs: None and the
    empty string give "", a string already ending in a newline is
    returned unchanged, and anything else gets one appended.
    """
    if s is None or len(s) == 0:
        return ""
    if s[-1] == "\n":
        return s
    return s + "\n"

class stateful_queue(list):
    """
    This is essentially a queue of items that will be passed as
    arguments to the dispatch function.  Once you instantiate it, you
    may append or extend it with items that will be passed to the
    dispatch function, one at a time and in order.  While the queue is
    paused, items accumulate and are dispatched after resume().
    """

    def __init__(self, dispatch, *a, **kw):
        """dispatch is called with each queued item; the queue starts
        empty and ready."""
        self.ready = True
        # Call this for each command.
        self.dispatch = dispatch

    def append(self, *a, **kw):
        """Queue one item and process it right away if not paused."""
        list.append(self, *a, **kw)
        if self.ready:
            self.process()

    def extend(self, *a, **kw):
        """Queue several items and process them if not paused."""
        list.extend(self, *a, **kw)
        if self.ready:
            self.process()

    # It can be paused and resumed.
    def pause(self):
        """Stop dispatching until resume() is called."""
        self.ready = False

    def resume(self):
        """Allow dispatching again and process anything that waited."""
        self.ready = True
        self.process()

    def process(self):
        """Dispatch queued items in FIFO order until empty or paused."""
        # Re-check ready each time round: the dispatch function itself
        # may call pause(), and the remaining items must then wait.
        while self.ready and len(self) != 0:
            command = self.pop(0)
            self.dispatch(command)

class asynchronous_command(Exception):
    """
    This is raised by a command to indicate that it will send back output
    asynchronously, so go ahead and return to the client and await
    notification.  ``output`` carries any output produced so far.
    """
    # Deriving from Exception is what lets instances be used with
    # raise/except on every Python version.
    def __init__(self, output=""):
        Exception.__init__(self, output)
        self.output = output

    def __str__(self):
        """The output so far; None is shown as the empty string."""
        if self.output is None:
            return ""
        return str(self.output)

class tcp_command_taker(line_receiver):
    """
    This class listens to a TCP port, queues incoming commands on self.queue,
    and then invokes commands in FIFO order via process_command.
    """
    def __init__(self, *a, **kw):
        self.queue = stateful_queue(self.process_command)
        line_receiver.__init__(self)

    def _n2n(self, o):
        """Maps None objects to empty (null) string; else returns str(o)."""
        if o is None:
            return ""
        return str(o)

    def connectionMade(self):
        """We have received a TCP connection; print a banner and prompt."""
        self.write_data(self.factory.banner + self.factory.prompt)

    def line_received(self, line):
        """
        We have received a line of input from a client.
        Split it into commands via the semicolon.
        """
        self.queue.extend(line.split(';'))

    def process_command(self, command):
        """
        Run one command and write the response followed by a prompt.

        The command is split into words on whitespace.  Quit and stop
        commands are handled here; everything else is passed to the
        commander's _dispatch.  If the command raises
        asynchronous_command, the output so far is written and the queue
        is paused until command_done is called.
        """
        vector = command.split()
        response = ""
        if len(vector) > 0:
            # Non-empty command.  Compare without surrounding whitespace
            # so that e.g. "foo; quit" is recognised.
            word = command.strip()
            # The command lists may be None to disable them.
            if word in (self.factory.quit_commands or ()):
                # Remove the circular reference
                del self.commander
                return self.transport.loseConnection()
            # TODO: Do we want a confirmation prompt?
            if word in (self.factory.stop_commands or ()):
                # TODO: Make this show up in help with this doc string:
                # "Perform an elegant shutdown of DFD."
                self.write_data("Shutting down cleanly...\n")
                return reactor.stop()
            try:
                response = self.commander._dispatch(*vector)
            # Specified a command which did not exist
            except CommandNotFoundError:
                response = self.factory.bad_response
            # One of the parameters was wrong
            except ParameterError as e:
                response = self._n2n(e) or self.factory.invalid_response
            # They specified the wrong number of arguments.
            except TypeError as e:
                response = self._n2n(e) or self.factory.invalid_response
            # We are expecting more results via command_done.
            # Print the output so far and pause.
            # Print a prompt for the user to know that they have
            # the ability to execute commands.
            except asynchronous_command as e:
                self.write_data(self._n2n(e) + self.factory.ok_response)
                self.queue.pause()
                return
            else:
                response = self._n2n(response) + self.factory.ok_response
        response = _add_nl(response)
        response += self.factory.prompt
        self.write_data(response)
        return

    def command_done(self, delayed_output=""):
        """
        This indicates that there is no more data from the previous command,
        so it is okay to write a prompt.
        """
        self.write_data(self._n2n(delayed_output) + self.factory.prompt)
        self.queue.resume()

class chronos:
    """
    This class handles the time-related events for the rules in ruleset.

    You pass it a ruleset instance in its constructor and it schedules
    callbacks at the appropriate times to expire those rules.
    """
    def __init__(self, rs):
        self.rs = rs
        self.sync()

    def sync(self):
        """Schedule callbacks to delete expired rules."""
        def remove_if_present(parent, r):
            # The rule may be gone by the time the timer fires (render()
            # also removes expired rules, and sync() may have scheduled
            # the same rule more than once); that is not an error.
            try:
                parent.remove(r)
            except ValueError:
                pass

        def schedule_deletion(r, parent, *ancestors):
            secs = r.timeout()
            # None means the rule never expires.
            if secs is not None:
                reactor.callLater(secs, remove_if_present, parent, r)

        self.rs.apply(schedule_deletion)

## The following classes are not used directly (except by the test code).
## Rather, they are for use by clients of this class.
## They are useful for implementing dynamic firewall changes.

# TODO: Compare this to the python and Perl ipv4 classes.
# TODO: Compare this to twisted.internet.interfaces.IAddress
class ipv4:
    """
    This represents IPv4-related objects
    """

    class address(list):
        """This represents an IPv4 address as a list of four integers."""

        def __init__(self, dotted_quad):
            """
            Constructor which accepts a dotted quad.

            Raises ParameterError unless the string consists of exactly
            four dot-separated integers, each in the range 0-255.
            """
            octets = dotted_quad.split(".", 4)
            if len(octets) != 4:
                raise ParameterError("More or less than four octets")
            for octet in octets:
                # NB: must convert to integer to test valid range
                try:
                    value = int(octet)
                except (TypeError, ValueError):
                    raise ParameterError("Octet non-integral")
                if value < 0 or value > 255:
                    raise ParameterError("Octet out of range")
                self.append(value)

        def __repr__(self):
            return "ipv4.address(" + str(self) + ")"

        def __str__(self):
            """The address in dotted-quad notation."""
            return ".".join([str(octet) for octet in self])

    class cidr_address_range:
        """This represents an IPv4 address range (address plus mask)."""

        def __init__(self, address, mask):
            """
            This is the most flexible method of specifying an IPv4 range.
            Both address and mask are given as dotted quads.
            """
            self.address = ipv4.address(address)
            self.mask = ipv4.address(mask)

        def __repr__(self):
            return "ipv4.cidr_address_range(" + str(self.address) \
                   + "," + str(self.mask) + ")"

        def __str__(self):
            return str(self.address) + " mask " + str(self.mask)

        def contains(self, ip):
            """Return True iff this range contains the given address."""
            # Compare octet by octet under the mask.
            for i in (0, 1, 2, 3):
                if ((ip[i] & self.mask[i])
                    != (self.mask[i] & self.address[i])):
                    return False
            return True

class SyncError(StandardError):
    """
    Signals that a sync did not succeed.

    sync_proxy.delegate_call treats this error as non-fatal: it appends
    the error text to the command's return value instead of propagating.
    """

class sync_proxy:
    """
    This is a class that performs syncing on another class.

    Attribute reads that the proxy cannot satisfy itself are forwarded to
    the delegate; callables obtained that way are wrapped so that calling
    them goes through delegate_call().  Attribute writes for names the
    proxy does not own are forwarded to the delegate as well.

    If you specify sync_every, that is how many method calls it will
    take to sync.  By leaving it equal to one you sync on every command.
    By setting it to zero, you never sync.  By setting it to 5 you sync
    every five proxied method calls.  You can sync manually at any time
    by calling sync(); this will reset the counter.

    NOTE(review): Under Python 2 this is a classic class, so lookups of
    special methods on an instance also pass through __getattr__; turning
    it into a new-style class would change that.
    """

    def __init__(self, delegate, sync_method, sync_every=1):
        """
        delegate is the proxied object, sync_method is a callable taking
        no arguments, and sync_every is the sync interval in calls.
        """
        # NOTE: These plain assignments stay on the proxy only because
        # __setattr__ does not start forwarding until "delegate" is in our
        # __dict__.  That is why delegate has to be assigned last.
        self._sync = sync_method
        self.sync_every = sync_every
        # This variable is how many proxied method calls before
        # the next call to self.sync().  Set it to 0 to disable.
        self.sync_in = sync_every
        self.delegate = delegate

    def delegate_call(self, method, *args, **keywords):
        """
        Perform a call on the delegate and then maybe sync.

        Returns whatever method returns.  If the sync raises SyncError
        the error text is appended to that return value (or replaces it
        when the method returned None).  If the sync raises
        asynchronous_command, a new asynchronous_command carrying the
        return value is raised in its place.
        """
        return_value = method(*args, **keywords)
        # Sync the changes in the delegate if appropriate.
        if self.sync_every > 0:
            if self.sync_in == 1:
                try:
                    self.sync()
                # Sync errors are non-fatal, but we do want to pass
                # them on to the client.
                except SyncError:
                    # sys.exc_info() is used to get at the exception so
                    # that this stays valid syntax on old and new Pythons.
                    import sys
                    message = str(sys.exc_info()[1])
                    # A command that returned nothing has no text to
                    # append to; "None += str" would raise TypeError.
                    if return_value is None:
                        return_value = message
                    else:
                        return_value += message
                except asynchronous_command:
                    raise asynchronous_command(return_value)
            else:
                self.sync_in -= 1
        return return_value

    def sync(self):
        """Call the sync method now and restart the countdown."""
        self._sync()
        # Only reached when the sync succeeded; after a failure sync_in is
        # left alone, so the next proxied call tries to sync again.
        self.sync_in = self.sync_every

    def __getattr__(self, attribute_name):
        """
        Look up attributes the proxy lacks on the delegate.

        Raises AttributeError if there is no delegate yet or the delegate
        does not have the attribute either.
        """
        # Fetch the delegate straight from __dict__.  Saying self.delegate
        # here would call __getattr__ again whenever the delegate has not
        # been assigned (e.g. on an instance made without __init__) and
        # recurse until the interpreter gives up.
        try:
            delegate = self.__dict__["delegate"]
        except KeyError:
            raise AttributeError(attribute_name)
        attrib = getattr(delegate, attribute_name)
        # NOTE: If it was a command, which first gets a method,
        # create a lambda to invoke it and maybe sync.
        if callable(attrib):
            return lambda *a, **kw: \
                   sync_proxy.delegate_call(self, attrib, *a, **kw)
        return attrib

    # NOTE: This is here so setting user on this object goes to the delegate.
    def __setattr__(self, attribute_name, attribute_value):
        """
        Store attributes the proxy already owns; forward all others.

        Until the delegate has been assigned, everything is stored on the
        proxy itself.
        """
        own = self.__dict__
        if attribute_name not in own and "delegate" in own:
            return setattr(self.delegate, attribute_name, attribute_value)
        own[attribute_name] = attribute_value

# TODO: Create a helper function which dumps metadata about state file.
# TODO: Create a helper function which dumps data about version information.
# TODO: Implement "save state" command.
# (of script and classes).
class helper_functions:
    """
    This class is here for the convenience of the client scripts.

    Client scripts subclass it and supply these attributes on the
    instance:

      prefix -- string prepended to a command name to find its method
                (treated as "" when absent)
      pf     -- the packet filter object that the commands operate on
      user   -- passed along to pf.sync() by sync() and _sync()

    NOTE: The doc strings of command methods are what the help command
    shows to clients, so the doc strings of the end-user commands below
    are written for the end user; notes for maintainers are kept in
    comments instead.
    """

    # This is called with user input to dispatch a function.

    def _dispatch(self, *args, **kws):
        """
        Dispatch to a function based on user input.

        args[0] is the command name without the prefix; the remaining
        positional arguments are handed to the command.  Returns the
        command's response after passing it through _add_nl().

        Raises CommandNotFoundError if no command name was given or no
        such method exists.  A ParameterError that carries a message is
        turned into the response text; one without a message is re-raised.
        """
        # NOTE(review): kws is accepted but not forwarded to the command --
        # confirm that is intended.
        import sys

        # An empty request names no command at all.
        if not args:
            raise CommandNotFoundError
        cmd_prefix = getattr(self, "prefix", "")
        try:
            cmd_method = getattr(self, cmd_prefix + args[0])
        except AttributeError:
            raise CommandNotFoundError
        # Actually call the method they indicated.
        try:
            response = cmd_method(*args[1:])
        # NOTE:  Allow asynchronous_command exceptions to propagate up.
        except ParameterError:
            # They supplied a bad argument to the function.
            # sys.exc_info() is used to get at the exception so that this
            # stays valid syntax on old and new Pythons.
            message = str(sys.exc_info()[1])
            if message == "":
                raise
            return _add_nl(message)
        return _add_nl(response)

    # Which methods can be used as end-user commands.
    end_user = ("version", "help", "show", "number", "variables", "sync", "flush")

    def version(self):
        """Show the revision of dfd_keeper."""
        return "dfd_keeper: %s\n" % _revision

    # The help command extracts help information from doc strings of functions
    # that are defined by client scripts.  See keeper_example for how to use
    # this.

    # TODO: Support the user specifying multiple commands.
    # TODO: Add help for protocol-level commands (stop, quit).
    def help(self, command=None):
        """Show help to the user.  A command may be provided as an argument."""

        prefix = getattr(self, "prefix", "")
        if command is not None:
            return self.help_single(prefix + command)
        # Show help for all commands.  Only names found in the instance's
        # own class are considered; they are sorted so that the output
        # does not depend on dictionary ordering.
        names = [name for name in sorted(self.__class__.__dict__)
                 if name.startswith(prefix) and callable(getattr(self, name))]
        return "".join([self.help_single(name) for name in names])

    # TODO: Don't indent body or put the name of the command unless
    # the user has not specified a command.
    # Simpler re-writes are quite welcome.
    def help_single(self, command):
        """
        Show help on a particular command to the user.

        command is the full method name, prefix included.  Returns the
        command name (prefix removed) followed by its tab-indented doc
        string.  Raises ParameterError if there is no such callable.
        """
        # Check to see if the command is the name of an instance method.
        method = getattr(self, command, None)
        if not callable(method):
            raise ParameterError("Invalid command.")
        # The doc string is None (or empty) for undocumented commands.
        doc_string = getattr(method, "__doc__", None) \
                     or "No help for this command."

        # This cleans up doc strings.
        def make_usage(doc_string):
            # Imported here since the module may not import it globally.
            import textwrap
            lines = doc_string.expandtabs().split("\n")
            # The first line usually follows the opening quotes directly
            # and so has no indentation of its own; dedent it separately
            # or the common indentation of the other lines is never found.
            first = lines[0].strip()
            rest = textwrap.dedent("\n".join(lines[1:]))
            # Strip leading/trailing newlines and trailing blanks.
            cleaned = (first + "\n" + rest).strip("\n").rstrip()
            # Split into paragraphs.
            paragraphs = cleaned.split("\n\n")
            # TODO: Make use of textwrap.fill() to do filling.
            # Insert leading tabs.
            tabbed = ["\t" + x.replace("\n", "\n\t") for x in paragraphs]
            # Join paragraphs back into a string.
            return "\n\n".join(tabbed)

        # Delete the prefix from the front of the method name and
        # prepend the name of the command to the help string.
        prefix = getattr(self, "prefix", "")
        return "%s:\n%s\n" % (command[len(prefix):], make_usage(doc_string))

    def toggle_command(self, flushing, tag):
        """
        This method toggles a group of rules tagged with a certain tag.

        Every rule whose tag equals tag has its active flag inverted.
        If flushing is true, pf.flush() is called afterwards.
        """
        for section in self.pf.ruleset:
            for r in section:
                if r.tag == tag:
                    r.active = not r.active
        if flushing:
            self.pf.flush()

    def toggle_factory(self, command, tag=None, help="", flushing=True):
        """
        This factory method creates commands that toggle some rules.

        command is the command name (without prefix), tag selects the
        rules, help becomes the command's help text and flushing is
        handed to toggle_command.
        """
        # By default, the tag is the same as the command name.
        if tag is None:
            tag = command
        # Make a reasonable default description (but allow override with None).
        if help == "":
            help = "Toggles %s rules on and off." % tag

        # Create the function that toggles the rules.
        def toggle(s):
            return s.toggle_command(flushing, tag)
        toggle.__doc__ = help
        # Note that this creates a user-defined function in the class.
        # This is not an instance method!  That is generated automagically.
        setattr(self.__class__, getattr(self, "prefix", "") + command, toggle)

    def switch_command(self, flushing, tag, status="on"):
        """
        This command switches on/off the active flag of matching rules.

        status must be "on" or "off", otherwise ParameterError is raised.
        If flushing is true, pf.flush() is called afterwards.
        """
        if status not in ("on", "off"):
            raise ParameterError("Argument must be on (default) or off.")
        active = (status == "on")
        for section in self.pf.ruleset:
            for r in section:
                if r.tag == tag:
                    r.active = active
        if flushing:
            self.pf.flush()

    def switch_factory(self, command, tag=None, help="", flushing=True):
        """
        This factory method creates commands that switch rules on or off.

        The arguments have the same meaning as for toggle_factory; the
        generated command takes an optional "on"/"off" argument.
        """
        if tag is None:
            tag = command
        if help == "":
            help = "Switches %s rules on or off (default = on)." % tag

        def switch(s, status="on"):
            return s.switch_command(flushing, tag, status)
        switch.__doc__ = help
        # Create a user-defined function in the class object.
        setattr(self.__class__, getattr(self, "prefix", "") + command, switch)

    def list_command(self, variable, operation="list", host=None, flush=None):
        """
        This manages pf list variables to change firewall rules.
        It can either add or delete a host to/from such a list, or
        return the pretty-printed list when operation is "list".
        The flushing policy can be set to add or delete or left unset;
        pf.flush(host) is called after an operation that matches it.

        Raises ParameterError for an unknown operation, for add/del
        without a host, when adding a host that is already present and
        when deleting one that is not.
        """
        # TODO: Sanity check host as being a valid domain name or IP.
        if operation not in ("add", "del", "list"):
            raise ParameterError("Argument must be add or del (or absent).")
        if operation == "list":
            return pprint.PrettyPrinter(indent=4).pformat(variable)
        # Without this check a missing host would put None in the list.
        if host is None:
            raise ParameterError("A host must be specified.")
        if operation == "add":
            if host in variable:
                raise ParameterError("They are already in the list.")
            # Otherwise, put them in the list.
            variable.add(host)
        else:
            try:
                variable.remove(host)
            # Set-like containers raise KeyError for a missing member,
            # list-like ones raise ValueError.
            except (KeyError, ValueError):
                raise ParameterError("They are not in the list.")
        if flush == operation:
            self.pf.flush(host)

    def list_factory(self, command, variable, flushing=None, help=None):
        """
        This factory method creates commands that manipulate lists.

        The generated command takes an operation and a host and passes
        them to list_command together with variable and flushing; with
        no arguments it lists the variable.
        """
        if help is None:
            help = "Command adds or deletes a host from a list."

        # The defaults mirror list_command, so that the "list" operation
        # can be requested by giving no arguments at all.
        def manage(s, op="list", host=None):
            return s.list_command(variable, op, host, flushing)
        manage.__doc__ = help
        setattr(self.__class__, getattr(self, "prefix", "") + command, manage)

    # TODO: Make something like list_factory but for single entities.
    # This is so that I can negate them (negated lists are verboten).

    # TODO: Allow user to specify a tag or maybe pf section.
    def show(self):
        """This command shows the active rules to the client."""
        return "\n".join(self.pf.ruleset.render()) + "\n"

    # TODO: Consider merging with show command.
    # (Do we need command-line switches?)
    def number(self):
        """This command enumerates the lines of the pf input for debugging."""
        # Join and split again since one rendered item may span lines.
        lines = "\n".join(self.pf.ruleset.render()).split("\n")
        numbered = ["%s %s" % (n + 1, line) for n, line in enumerate(lines)]
        return "\n".join(numbered) + "\n"

    def variables(self):
        """This command shows the current state variable namespace."""
        # TODO: Make this even prettier, since it's hard to read.
        return pprint.PrettyPrinter(indent=4).pformat(self.pf.namespace)

    def sync(self):
        """Synchronize the rules with pf.  This is done automatically."""
        return self.pf.sync(self.user, force=True)

    def _sync(self):
        """Used internally by sync_proxy."""
        return self.pf.sync(self.user)

    def flush(self, *args, **kws):
        """Flush the state table.  This is done automatically."""
        # NOTE(review): This calls pf.flush_state() whereas toggle_command,
        # switch_command and list_command call pf.flush() -- confirm that
        # the pf object provides both.
        return self.pf.flush_state(*args, **kws)

    # TODO: Add helper commands to query named sections of pf rules.

import atexit, signal, os, sys


def daemonize(close=False):
    """
    Detach a process from the controlling terminal and run it in the
    background as a daemon.

    This forks twice.  The first child becomes a session leader with
    os.setsid(); the second child is the daemon, which changes to the
    root directory and clears its umask.

    If close is true, every file descriptor below the RLIMIT_NOFILE hard
    limit (or below MAXFD when that limit is unlimited) is closed first.
    In either case descriptors 0, 1 and 2 are then attached to /dev/null.

    Returns 0 in the daemon process.  The calling process and the first
    child terminate through os._exit(0) and never return.  Raises
    Exception("<strerror> [<errno>]") if either fork fails.
    """

    # Default maximum for the number of available file descriptors.
    MAXFD = 1024

    # The standard I/O file descriptors are redirected to /dev/null by
    # default; fall back to the literal path when os has no devnull.
    REDIRECT_TO = getattr(os, "devnull", "/dev/null")

    def fork():
        """Fork, reporting failure as the plain Exception callers expect."""
        try:
            return os.fork()
        except OSError:
            # sys.exc_info() is used to get at the exception so that this
            # stays valid syntax on old and new Pythons.
            e = sys.exc_info()[1]
            raise Exception("%s [%d]" % (e.strerror, e.errno))

    # Fork a child process so the parent can exit.  This returns control to
    # the command-line or shell.  It also guarantees that the child will not
    # be a process group leader, since the child receives a new process ID
    # and inherits the parent's process group ID.  This step is required
    # to ensure that the next call to os.setsid is successful.
    if fork() != 0:
        # exit() or _exit()?
        # _exit is like exit(), but it doesn't call any functions registered
        # with atexit (and on_exit) or any registered signal handlers.  It
        # also closes any open file descriptors.  Using exit() may cause all
        # stdio streams to be flushed twice and any temporary files may be
        # unexpectedly removed.  It's therefore recommended that child
        # branches of a fork() and the parent branch(es) of a daemon use
        # _exit().
        os._exit(0)     # Exit parent of the first child.

    # The first child.
    # To become the session leader of this new session and the process group
    # leader of the new process group, we call os.setsid().  The process is
    # also guaranteed not to have a controlling terminal.
    os.setsid()

    # Is ignoring SIGHUP necessary?
    #
    # It's often suggested that the SIGHUP signal should be ignored before
    # the second fork to avoid premature termination of the process.  The
    # reason is that when the first child terminates, all processes, e.g.
    # the second child, in the orphaned group will be sent a SIGHUP.
    #
    # However, session management sends SIGHUP on the death of a process in
    # exactly two cases:
    #
    #   1) When the process that dies is the session leader of a session
    #      that is attached to a terminal device, SIGHUP is sent to all
    #      processes in the foreground process group of that terminal.
    #   2) When the death of a process causes a process group to become
    #      orphaned, and one or more processes in the orphaned group are
    #      stopped, then SIGHUP and SIGCONT are sent to all members of the
    #      orphaned group.
    #
    # The first case can be ignored since the child is guaranteed not to
    # have a controlling terminal.  In the second case only STOPPED
    # processes matter, and the second child is not stopped, so we can
    # safely forego ignoring SIGHUP.  There are no ill effects if it is
    # ignored anyway:
    #
    # signal.signal(signal.SIGHUP, signal.SIG_IGN)

    # Fork a second child and exit immediately to prevent zombies.  This
    # causes the second child process to be orphaned, making the init
    # process responsible for its cleanup.  And, since the first child is
    # a session leader without a controlling terminal, it's possible for
    # it to acquire one by opening a terminal in the future (System V-
    # based systems).  This second fork guarantees that the child is no
    # longer a session leader, preventing the daemon from ever acquiring
    # a controlling terminal.
    if fork() != 0:
        os._exit(0)     # Exit parent (the first child) of the second child.

    # The second child.
    # Since the current working directory may be a mounted filesystem, we
    # avoid the issue of not being able to unmount the filesystem at
    # shutdown time by changing it to the root directory.
    os.chdir("/")
    # We probably don't want the file mode creation mask inherited from
    # the parent, so we give the child complete control over permissions.
    os.umask(0)

    # If close is true, close all open file descriptors.  This prevents the
    # child from keeping open any file descriptors inherited from the
    # parent.  The upper bound could also be taken from
    # os.sysconf("SC_OPEN_MAX"); here the getrlimit method is used to
    # retrieve the maximum file descriptor number that can be opened by
    # this process.  If there is no limit on the resource, use the default.
    if close:
        import resource         # Resource usage information.
        maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
        if maxfd == resource.RLIM_INFINITY:
            maxfd = MAXFD

        # Iterate through and close all file descriptors.
        for fd in range(0, maxfd):
            try:
                os.close(fd)
            except OSError:     # fd wasn't open to begin with (ignored)
                pass

    # Redirect the standard I/O file descriptors to the specified file.
    # Since the daemon has no controlling terminal, most daemons redirect
    # stdin, stdout, and stderr to /dev/null.  This is done to prevent
    # side-effects from reads and writes to the standard I/O descriptors.
    #
    # The descriptor returned by open is 0 only if stdin was closed above.
    # When it is not, it is duplicated onto 0, 1 and 2 explicitly;
    # duplicating descriptor 0 instead would attach stdout and stderr to
    # whatever stdin was inherited rather than to /dev/null.
    null_fd = os.open(REDIRECT_TO, os.O_RDWR)
    for std_fd in (0, 1, 2):
        if null_fd != std_fd:
            os.dup2(null_fd, std_fd)
    # Do not leak the extra descriptor once the standard ones are set up.
    if null_fd > 2:
        os.close(null_fd)

    return 0


import shelve

def sigint_handler(s, *a, **kw):
    """
    Log that we were interrupted and stop the reactor.

    None of the arguments are used here; s is whatever the installer of
    the handler bound in, and a/kw take what the signal module passes.
    """
    logger().log("Interrupted", syslog.LOG_NOTICE)
    # This is the preferred way to exit.
    reactor.stop()

def save_persistent_data(s):
    """Log the save, close the store s and then close the syslog."""
    logger().log("Saving persistent data", syslog.LOG_NOTICE)
    # The log entry has to be written before syslog itself is closed.
    s.close()
    syslog.closelog()

class FileNotFoundError(StandardError):
    """Raised by set_persist_hooks when the state file cannot be opened."""

def set_persist_hooks(fn=None):
    """
    Open the persistent state and arrange for it to be saved on exit.

    If fn is None the state is kept in a plain dictionary and no hooks
    are installed.  Otherwise fn is opened with shelve, and the shelf is
    closed by save_persistent_data() at normal exit; SIGINT is routed to
    sigint_handler().

    Returns a magic_dict wrapping the state.  Raises FileNotFoundError
    (carrying fn) if the shelf cannot be opened.
    """
    if fn is None:
        return magic_dict(dict())
    # TODO: Enumerate possible exceptions and only handle those.
    # Tricky because shelve tries to open various database extensions,
    # and any given one may not exist; all also have their own exceptions
    # and they are not mapped into any shelve-specific error.
    # bsddb._db.DBNoSuchFileError
    try:
        try:
            s = shelve.open(fn, writeback=True)
        # NOTE: The built-in exception names are used directly; the
        # "exceptions" module itself is not bound to a name here.
        except (AttributeError, TypeError):
            # It didn't like the writeback parameter; try without it
            # rather than carrying on with no shelf at all.
            s = shelve.open(fn)
    except Exception:
        raise FileNotFoundError(fn)
    # Save state during normal exit.
    atexit.register(lambda: save_persistent_data(s))
    # Save state during a SIGINT.
    signal.signal(signal.SIGINT,
                  lambda *a, **kw: sigint_handler(s, *a, **kw))
    return magic_dict(s)

# All this is run only if invoked as a script directly.  It is all test code.
# TODO: Move all this to a test script, since it slows down use.
if __name__ == '__main__':

    # Work out from the command line what this run should do.
    import getopt, sys

    usage = ("Usage: %s [--help] [--print] [--loop] [--test] [--no] [--syntax]"
             % sys.argv[0])
    try:
        opts, args = getopt.getopt(
            sys.argv[1:], "hlnpst",
            ["help", "loop", "no", "print", "syntax", "test"])
    except getopt.GetoptError:
        # Unknown option: show the usage text and give up.
        raise SystemExit(usage)
    # Everything is off unless asked for.
    run = test_print = test = loop = False
    mode = "normal"
    # Each option matches exactly one branch; later options override
    # earlier ones where they set the same variable.
    for o, a in opts:
        if o in ("-l", "--loop"):
            loop = True
        elif o in ("-t", "--test"):
            test = True
        elif o in ("-n", "--no"):
            mode = "test"
        elif o in ("-s", "--syntax"):
            mode = "syntax"
        elif o in ("-p", "--print"):
            test_print = True
        elif o in ("-h", "--help"):
            print(usage)
            raise SystemExit

    if test:
        
        import unittest

        class test_named_list(unittest.TestCase):
            """Check keyed access and ordered iteration of named_list."""
            def setUp(self):
                pairs = (("index1", "a"), ("index2", "b"))
                self.el = named_list(pairs)
            def test_whatever(self):
                el = self.el
                # Values can be read and replaced through their keys.
                self.assertEqual(el["index1"], "a")
                el["index1"] = "aa"
                self.assertEqual(el["index1"], "aa")
                self.assertEqual(el["index2"], "b")
                # Iterating yields the values, in the order of the keys.
                expected = ("aa", "b")
                for position, value in enumerate(el):
                    self.assertEqual(value, expected[position])

        # Test ipv4.address class.
        class test_ipv4_address(unittest.TestCase):
            """Check construction, printing and comparison of ipv4.address."""
            def setUp(self):
                self.addy = ipv4.address("127.0.0.1")
            def test_constructor(self):
                # Too few octets, too many octets, an octet that is out of
                # range and an octet that is not a number.
                for address in ("127.0.0", "127.0.0.0.1", "256.0.0.0",
                                "127.0.0.x"):
                    self.assertRaises(ParameterError, ipv4.address, address)
            # NOTE: The unittest assertion methods are used below since a
            # bare assert statement is removed when running with -O.
            def test_str(self):
                # Test __str__ functionality.
                self.assertEqual(str(self.addy), "127.0.0.1")
            def test_eq(self):
                # Test equivalency.
                self.assertTrue(self.addy == ipv4.address("127.0.0.1"))
            def test_ned(self):
                # Test inequivalency.
                self.assertTrue(self.addy != ipv4.address("127.0.0.2"))

        # Test ipv4.cidr_address_range class.
        class test_ipv4_address_range(unittest.TestCase):
            """Check the membership test of ipv4.cidr_address_range."""
            def setUp(self):
                base, mask = "127.0.0.1", "255.0.0.0"
                self.address_range = ipv4.cidr_address_range(base, mask)
            def test_contains(self):
                # With this mask only the first octet has to agree.
                block = self.address_range
                inside = ipv4.address("127.127.0.10")
                outside = ipv4.address("196.168.1.1")
                assert block.contains(inside)
                assert not block.contains(outside)

        # Test that sync_proxy works as expected.
        class test_proxy:
            """Stand-in delegate used by the sync_proxy tests."""
            # NOTE: The count is kept on the class, not on the instance,
            # to avoid accesses of the instance.
            counter = 0
            def increase_counter(self):
                # Serves as the sync method; counts how often it is called.
                test_proxy.counter = test_proxy.counter + 1
            def proxied_method(self):
                # Serves as the proxied command.
                return True
        class test_sync_proxy(unittest.TestCase):
            """Check when sync_proxy calls its sync method."""
            def setUp(self):
                # The counter lives on the test_proxy class and is shared by
                # all tests; start each test from zero so that the expected
                # values do not depend on the order the tests run in.
                test_proxy.counter = 0
                self.delegate = test_proxy()
                # Syncs after every proxied call.
                self.proxy1 = sync_proxy(self.delegate,
                                         self.delegate.increase_counter)
                # Syncs after every second proxied call.
                self.proxy2 = sync_proxy(self.delegate,
                                         self.delegate.increase_counter, 2)
            def test_proxy1(self):
                proxy1 = self.proxy1
                delegate = self.delegate
                # Manually make a delegated call.
                self.assertEqual(
                    proxy1.delegate_call(delegate.proxied_method), True)
                # Check that it was automatically synced.
                self.assertEqual(delegate.counter, 1)
                # Check the automated proxying of method calls.
                self.assertEqual(proxy1.proxied_method(), True)
                # Assert that sync was called again.
                self.assertEqual(delegate.counter, 2)
            def test_proxy2(self):
                proxy2 = self.proxy2
                delegate = self.delegate
                # Test the delayed syncing mechanisms.
                # This call won't sync. (odd)
                self.assertEqual(proxy2.proxied_method(), True)
                self.assertEqual(delegate.counter, 0)
                # This call will sync. (even)
                self.assertEqual(proxy2.proxied_method(), True)
                self.assertEqual(delegate.counter, 1)
                # This call will not sync. (odd)
                self.assertTrue(proxy2.proxied_method())
                self.assertEqual(delegate.counter, 1)
                # Perform a manual sync which resets sync_in.
                proxy2.sync()
                self.assertEqual(delegate.counter, 2)
                # This should not sync. (odd)
                self.assertTrue(proxy2.proxied_method())
                self.assertEqual(delegate.counter, 2)

        # Test the pf-variables-as-python-variables conversions.
        class test_pfvars(unittest.TestCase):
            """Check the text produced by macro, pflist and table objects."""
            def setUp(self):
                self.sample_macro = macro("foo")
                self.sample_list = pflist((1,2,3))
            def test_macro(self):
                rendered = self.sample_macro.render()
                assert rendered == 'foo'
            def test_list(self):
                rendered = self.sample_list.render()
                assert rendered == '{ 1 2 3 }'
            def test_table(self):
                # Each table is checked before the next one is created.
                first = table()
                assert str(first) == "<table000>"
                second = table()
                assert str(second) == "<table001>"

        class test_magic(unittest.TestCase):
            """Check what magic() makes of a string and of a list."""
            def setUp(self):
                self.from_string = magic("foo")
                self.from_list = magic(["foo", "bar"])
            def test_magic_macro(self):
                rendered = self.from_string.render()
                assert rendered == "foo"
            def test_magic_list(self):
                rendered = self.from_list.render()
                assert rendered == "{ foo bar }"

        class test_magic_dict(unittest.TestCase):
            """Check that values read back from a magic_dict can render."""
            def setUp(self):
                self.d = dict()
                self.md = magic_dict(self.d)
                # One plain string and one list of strings.
                self.md["a"] = "a"
                self.md["bc"] = [ "b", "c" ]
            def test_magic_dict_macro(self):
                assert self.md["a"].render() == "a"
            def test_magic_dict_list(self):
                assert self.md["bc"].render() == "{ b c }"

        # TODO: add tests of invalid rule containers.

        # Collect one suite per test case class, in this order, and run
        # them all with the text runner.
        thing = unittest.TestSuite(map(unittest.makeSuite, (
            test_named_list,
            test_ipv4_address,
            test_ipv4_address_range,
            test_sync_proxy,
            test_pfvars,
            test_magic,
            test_magic_dict,
        )))
        unittest.TextTestRunner().run(thing)
            
    # Set up some pf rules (and test them in the process).
    # This section builds a small rule set step by step and checks it with
    # assertions.  The statements depend on each other (rule counts are
    # compared against earlier steps), so their order matters.
    if test or test_print:

        # Test the template class and its substitution method.
        d = { "a" : "knights", "cd" : "nee" }
        s = template("We are the $a who say $cd.").substitute(d)
        assert s == "We are the knights who say nee."

        # Test the ruleset class; d is handed to it as well as to the rule
        # and rule_queue constructors below.
        rs = ruleset(d)
        # Retrieve the filter section.
        filt = rs["filter"]
        assert(isinstance(filt, rule_holder))
        # Create a freestanding rule.
        rule_text = "pass in quick on lo0 all"
        # Test the rule class (make a freestanding rule).
        r = rule(d, rule_text)
        # Test str(rule) functionality.
        assert str(r) == rule_text
        # Test repr(rule) functionality.
        assert repr(r) == "rule(" + str(r) + ")"
        # Test that render works properly.
        assert r.render() == [ rule_text + "\n" ]
        # Test the filter.
        filt.append(r)
        # Test the make_rule functionality.
        marker = filt.make_rule("", tag="tag")
        marker.active = False # Don't execute this rule.
        assert marker != None
        # Test rendering again, this time for no-ops.
        assert marker.render() == []
        # Test inequality of different rule renderings.
        assert r.render() != marker.render()
        # Test inequality of str representations of different rules.
        assert str(r) != str(marker)
        filt.make_rule("block drop all")
        filt.make_rule("block return on xl1 all", tag="return")
        # Test rules that expire
        curtime = time.time()
        # Test timeout routine.
        to1 = filt.make_rule("pass quick on xl2 all", expires=curtime)
        assert 1 > to1.timeout() >= 0 # slop in fp arithmetic
        to2 = filt.make_rule("pass in quick on enc0 all", lifespan=10)
        assert 11 > to2.timeout() >= 9 # slop in fp arithmetic
        # Test some functionality of chains.
        # Six rules were added above: r, marker, the two block rules and
        # the two expiring rules.
        # TODO: change to 5 when expiration -> deletion
        assert len(filt) == 6
        # Test out apply
        number_of_rules = len(filt)
        def fun(*stuff):
            # This code runs at module level when the file is executed as a
            # script, so number_of_rules is a module global.
            global number_of_rules
            number_of_rules -= 1
        rs.apply(fun)
        # The counter reaches zero only if fun was called once for every
        # rule counted above.
        assert number_of_rules == 0
        # Test removal of rule.
        number_of_rules = len(filt)
        filt.remove(marker)
        assert len(filt) == (number_of_rules - 1)
        # Create another rule, then delete it during apply sequence.
        filt.make_rule("block in proto tcp from any to any port = www",
                       tag="www")
        number_of_rules = len(filt)
        def delete_www(r, parent, *ancestors):
            # Removes the rule from its parent while apply is under way.
            if r.tag == "www":
                parent.remove(r)
        rs.apply(delete_www)
        assert len(filt) == (number_of_rules - 1)
        # Test rule series.
        # NOTE(review): Nothing after this point uses number_of_rules in
        # the part of the file shown here -- confirm whether a rule series
        # test was meant to follow.
        number_of_rules = len(filt)

        # Test rule_queue.  The queue is created with 2 as second argument
        # and its length stays at two when a third rule is made.
        rq = rule_queue(d, 2)
        assert len(rq) == 0
        rq.make_rule("pass in quick proto tcp from trusted1.host.com port ssh")
        assert len(rq) == 1
        r1 = rq.make_rule("pass in quick from trusted2.host.com")
        assert len(rq) == 2
        rq.make_rule("pass in quick proto tcp from trusted3.host.com port ssh")
        assert len(rq) == 2
        # The second rule made is now the first one in the queue.
        assert rq[0] == r1


        if test_print:
            # Print a rule.
            # NOTE(review): This prints the rule class itself, not the
            # instance r created above -- confirm which was intended.
            print rule
            # Print the ruleset, as str and as repr.
            print rs
            print repr(rs)
            # Render it as pf rules
            print "As commands:"
            for c in rs.render():
                print c
            # Print the ruleset once more, after it has been rendered.
            print rs

    if loop:
        """Set up some pf rules to test the run-time.

        This is a good example of how to use the dfd_keeper module to
        configure firewall rules for use as one-shot configuration.
        Note how using python variables eliminates some of the
        repetitive elements of specifying rules.
        """
        # Create tables on which the test code operates.
        p = pf(mode=mode)
        # Retrieve the filter table.
        tbl_filter = p["filter"]
        # Retrieve the INPUT chain.
        tbl_filter.make_rule("pass in quick on lo0 all")
        # Make some input rules.
        tbl_filter.make_rule("pass in proto tcp from any to any port = 22")
        tbl_filter.make_rule("pass in proto udp from any to any port = 53",
                             tag="dns")
        # Make some output rules.
        tbl_filter.make_rule("pass out quick on lo0 all")
        # Create some rules that expire, for testing expiration code.
        import time
        curtime = time.time()
        tbl_filter.make_rule("pass proto udp from any to any port = 53",
                             expires=curtime)
        tbl_filter.make_rule("pass proto tcp from any to any port = 80",
                             lifespan=10) # seconds
        # Test a nested rule_series appended to the filter table.
        rs = rule_series()
        tbl_filter.append(rs)
        rs.make_rule("block in quick from www.yahoo.com to any")
        rs.make_rule("block in quick from www.apache.org to any")

        # This shows how to schedule callbacks to expire rules.
        cron = chronos(p)

        # This code illustrates how you might use this module in a program.
        # It listens to port 8007 and accepts CRLF-terminated commands.
        # To automatically send a command, try something like this:
        # $ printf "dns\r\nshow\r\nquit\r\n" | nc localhost 8007
        # Easy, huh?  Alternately, connect manually with telnet.

        class command_dispatch(helper_functions):
            """
            Command suite exposing a configured packet filter to clients.

            The constructor receives the configured set of rules.  Every
            method whose name starts with "cmd_" becomes a command that
            anyone can issue via tcp_command_taker; a command that returns
            a string has that string shown to the client.
            """

            def __init__(self, packet_filter):
                # packet_filter is the rule set the commands operate on.
                self.pf = packet_filter
                self.prefix = "cmd_"
                # Re-export every helper named in helper_functions.end_user
                # as a command by attaching it to this class under the
                # command prefix.
                cls = self.__class__
                for helper_name in helper_functions.end_user:
                    helper = getattr(helper_functions, helper_name)
                    setattr(cls, self.prefix + helper_name, helper)

            def cmd_dns(self):
                """Toggle the active flag of any rules tagged as dns."""
                def implement_dns(r, *ancestors):
                    # Leave every rule that is not tagged "dns" untouched.
                    if r.tag != "dns":
                        return
                    r.active = not r.active
                self.pf.apply(implement_dns)

            def cmd_render(self):
                """Print the rules on the terminal running the server."""
                rendered = self.pf.ruleset.render()
                for command in rendered:
                    print(command)

            def cmd_show(self):
                """
                Show the rules to the client.

                Note that to print to the client, just return a string!
                """
                # One newline-terminated line per rendered command.
                lines = [str(command) + "\n"
                         for command in self.pf.ruleset.render()]
                return "".join(lines)

            def cmd_block(self, ip):
                """Block an IP."""
                # NOTE(review): ip is presumably supplied by the connected
                # client and is spliced into the rule text unvalidated --
                # confirm that is acceptable outside this test harness.
                rule_text = "block quick from " + ip + " to any"
                # rs is the rule_series created earlier in the enclosing
                # scope, not an attribute of this object.
                rs.make_rule(rule_text)

            def cmd_sync(self):
                # NOTE(review): this calls sync_rules on the name pf (the
                # callable instantiated earlier as pf(mode=mode)) rather
                # than on self.pf, and self.user is not assigned anywhere
                # in this class -- confirm this call is what was intended.
                return pf.sync_rules(self.user, force=True)
            
        # Note that this is where we set the listening port.
        reactor.listenTCP(8007,
                          tcp_factory(lambda x=p: command_dispatch(x),
                                      tcp_command_taker))

        # Check how many rules are in filter.
        filtlen = len(tbl_filter)

        # At least one filter rule should be purged in the first second.
        def have_rules_been_purged():
            """Assert that the filter table holds fewer rules than before."""
            # filtlen is the rule count recorded before the reactor was
            # started; at least one expiring rule should be gone by now.
            remaining = len(tbl_filter)
            assert remaining < filtlen

        reactor.callLater(1, have_rules_been_purged)

        reactor.run()

