# Copyright (C) 2010  Internet Systems Consortium.
#  
# Permission to use, copy, modify, and distribute this software for any  
# purpose with or without fee is hereby granted, provided that the above  
# copyright notice and this permission notice appear in all copies.  
#  
# THE SOFTWARE IS PROVIDED "AS IS" AND INTERNET SYSTEMS CONSORTIUM  
# DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL  
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL  
# INTERNET SYSTEMS CONSORTIUM BE LIABLE FOR ANY SPECIAL, DIRECT,  
# INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING  
# FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,  
# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION  
# WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.  
  
'''Tests for the XfroutSession and UnixSockServer classes '''  
  
  
import unittest  
import os  
from isc.testutils.tsigctx_mock import MockTSIGContext  
from isc.testutils.ccsession_mock import MockModuleCCSession  
from isc.cc.session import *  
import isc.config  
from isc.dns import *  
from isc.testutils.rrset_utils import *  
from xfrout import *  
import xfrout  
import isc.log  
import isc.acl.dns  
import isc.server_common.tsig_keyring  
  
# Directory holding test data, taken from the environment (None if unset).
TESTDATA_SRCDIR = os.environ.get("TESTDATASRCDIR")
# TSIG key used for signed requests/responses in these tests.
TSIG_KEY = TSIGKey("example.com:SFuWd/q99SzF8Yzd1QbB9g==")

#
# Commonly used (mostly constant) test parameters
#
TEST_ZONE_NAME_STR = "example.com."
TEST_ZONE_NAME = Name(TEST_ZONE_NAME_STR)
TEST_RRCLASS = RRClass.IN()
# IXFR request serials: the mock journal reader reports NO_SUCH_VERSION
# for IXFR_NG_VERSION and succeeds for other serials such as IXFR_OK_VERSION.
IXFR_OK_VERSION = 2011111802
IXFR_NG_VERSION = 2011111803
# Serial of the SOA handed out by the mock data source for ordinary zones.
SOA_CURRENT_VERSION = 2011112001
  
class MySocket():
    '''A fake socket that records everything "sent" through it.

    Data given to send() is appended to an in-memory buffer (sendqueue);
    tests pull it back one length-prefixed message at a time with
    readsent() or read_msg().
    '''
    def __init__(self, family, type):
        self.family, self.type = family, type
        self.sendqueue = bytearray()

    def connect(self, to):
        pass  # there is nothing to connect to

    def close(self):
        pass

    def send(self, data):
        '''Queue all of data and report it as fully sent.'''
        self.sendqueue.extend(data)
        return len(data)

    def readsent(self):
        '''Remove and return the first length-prefixed chunk of the queue.

        The returned bytearray includes the 2-byte big-endian length
        prefix.  If fewer than 2 bytes are queued, an empty bytearray is
        returned and the queue is left untouched.  If the queue holds less
        data than the prefix announces, whatever is available is returned.
        '''
        chunk_len = 0
        if len(self.sendqueue) >= 2:
            (body_len,) = struct.unpack("!H", self.sendqueue[:2])
            chunk_len = body_len + 2
        head, self.sendqueue = (self.sendqueue[:chunk_len],
                                self.sendqueue[chunk_len:])
        return head

    def read_msg(self, parse_options=Message.PARSE_DEFAULT, need_len=False):
        '''Pop the next sent chunk and parse it as a DNS message.

        When need_len is true, a (message, wire_length) tuple is returned,
        where wire_length excludes the 2-byte length prefix.
        '''
        wire = self.readsent()
        parsed = Message(Message.PARSE)
        parsed.from_wire(bytes(wire[2:]), parse_options)
        return (parsed, len(wire) - 2) if need_len else parsed

    def clear_send(self):
        '''Discard everything queued so far (in place).'''
        del self.sendqueue[:]
  
class MockDataSrcClient:
    '''Mock data source client used in place of the real ClientClass.

    A single instance plays several roles: the client itself, the zone
    finder returned by find_zone(), the zone iterator returned by
    get_iterator() and the journal reader returned by
    get_journal_reader().  Some special zone names make it emulate error
    conditions.
    '''
    def __init__(self, type, config):
        # The arguments are accepted only for interface compatibility.
        pass

    def find_zone(self, zone_name):
        '''Mock version of find_zone().

        It returns itself (subsequently acting as a mock ZoneFinder) for
        most zone names.  For 'notauth.example.com' it returns NOTFOUND to
        emulate the condition where the specified zone doesn't exist.

        '''
        self._zone_name = zone_name
        if zone_name == Name('notauth.example.com'):
            return (isc.datasrc.DataSourceClient.NOTFOUND, None)
        return (isc.datasrc.DataSourceClient.SUCCESS, self)

    def find(self, name, rrtype, options=ZoneFinder.FIND_DEFAULT):
        '''Mock ZoneFinder.find().

        (At the moment) this method only handles queries for type SOA.  By
        default it returns SUCCESS with the SOA RRset built by
        create_soa(SOA_CURRENT_VERSION); note that the RRset is not built
        from the query name.  It also emulates some unusual cases for
        special names:
        - nosoa.example.com (SOA query): NXDOMAIN with no RRset.
        - multisoa.example.com (SOA query): an SOA RRset whose first RDATA
          has been added a second time.
        - maxserial.example.com (any type): an SOA with serial 0xffffffff.

        Any other non-SOA query raises ValueError, as it indicates a bug
        in the calling test case.  'options' is ignored.

        '''
        if name == Name('nosoa.example.com') and rrtype == RRType.SOA():
            return (ZoneFinder.NXDOMAIN, None, 0)
        elif name == Name('multisoa.example.com') and rrtype == RRType.SOA():
            soa_rrset = create_soa(SOA_CURRENT_VERSION)
            soa_rrset.add_rdata(soa_rrset.get_rdata()[0])
            return (ZoneFinder.SUCCESS, soa_rrset, 0)
        elif name == Name('maxserial.example.com'):
            soa_rrset = create_soa(0xffffffff)
            return (ZoneFinder.SUCCESS, soa_rrset, 0)
        elif rrtype == RRType.SOA():
            return (ZoneFinder.SUCCESS, create_soa(SOA_CURRENT_VERSION), 0)
        raise ValueError('Unexpected input to mock finder: bug in test case?')

    def get_iterator(self, zone_name, adjust_ttl=False):
        '''Mock get_iterator(); returns itself as the zone iterator.

        Raises isc.datasrc.Error for 'notauth.example.com'.  The zone name
        is remembered so that get_soa() can emulate per-zone behavior.
        '''
        if zone_name == Name('notauth.example.com'):
            raise isc.datasrc.Error('no such zone')
        self._zone_name = zone_name
        return self

    def get_soa(self):  # emulate ZoneIterator.get_soa()
        '''Return the SOA of the zone last passed to find_zone/get_iterator.

        Returns None for 'nosoa.example.com', and an SOA RRset with a
        duplicated RDATA for 'multisoa.example.com'.
        '''
        if self._zone_name == Name('nosoa.example.com'):
            return None
        soa_rrset = create_soa(SOA_CURRENT_VERSION)
        if self._zone_name == Name('multisoa.example.com'):
            soa_rrset.add_rdata(soa_rrset.get_rdata()[0])
        return soa_rrset

    def get_journal_reader(self, zone_name, begin_serial, end_serial):
        '''Mock get_journal_reader(); returns a (result, reader) pair.

        - 'notauth2.example.com': (NO_SUCH_ZONE, None)
        - 'nojournal.example.com': raises isc.datasrc.NotImplemented
        - begin_serial == IXFR_NG_VERSION: (NO_SUCH_VERSION, None)
        - otherwise: (SUCCESS, self)
        '''
        if zone_name == Name('notauth2.example.com'):
            return isc.datasrc.ZoneJournalReader.NO_SUCH_ZONE, None
        if zone_name == Name('nojournal.example.com'):
            raise isc.datasrc.NotImplemented('journaling not supported')
        if begin_serial == IXFR_NG_VERSION:
            return isc.datasrc.ZoneJournalReader.NO_SUCH_VERSION, None
        return isc.datasrc.ZoneJournalReader.SUCCESS, self
  
class MyCCSession(isc.config.ConfigData):
    '''Minimal stand-in for a config session, backed by the xfrout spec.'''
    def __init__(self):
        module_spec = isc.config.module_spec_from_file(
            xfrout.SPECFILE_LOCATION)
        # Call the base initializer through the same qualified name used
        # in the class statement; the bare name ConfigData is not imported
        # by this module.
        isc.config.ConfigData.__init__(self, module_spec)

    def get_remote_config_value(self, module_name, identifier):
        '''Return a canned pair for a remote configuration item.

        Only Auth's "database_file" is known ("initdb.file"); anything
        else yields "unknown".  The second element is always False
        (presumably the "is default" flag of the real API; not verified
        here).
        '''
        if module_name == "Auth" and identifier == "database_file":
            return "initdb.file", False
        else:
            return "unknown", False
  
# This constant dictionary stores all default configuration parameters
# defined in the xfrout spec file.
# NOTE: it is evaluated once, at import time, so the spec file at
# xfrout.SPECFILE_LOCATION must be loadable when this module is imported.
DEFAULT_CONFIG = MyCCSession().get_full_config()
  
# The session class under test, with just a few methods overridden so
# that tests can drive it directly on top of the fake socket.
class MyXfroutSession(XfroutSession):
    '''XfroutSession whose _handle and _close_socket do nothing and whose
    _send_data simply loops over sock.send() until all data is written.'''
    def _handle(self):
        pass

    def _close_socket(self):
        pass

    def _send_data(self, sock, data):
        # send() may accept only part of the buffer; resend the remainder.
        sent = 0
        while sent < len(data):
            sent += sock.send(data[sent:])
  
class Dbserver:
    '''Minimal stand-in for the server object a session reports to.

    It mainly tracks the number of ongoing transfers.  The quota check
    always succeeds unless a test replaces increase_transfers_counter.
    '''
    def __init__(self):
        self._shutdown_event = threading.Event()
        self.transfer_counter = 0
        self._max_transfers_out = DEFAULT_CONFIG['transfers_out']

    def get_db_file(self):
        return 'test.sqlite3'

    def increase_transfers_counter(self):
        '''Record one more ongoing transfer; always returns True.'''
        self.transfer_counter = self.transfer_counter + 1
        return True

    def decrease_transfers_counter(self):
        '''Record that one transfer has finished.'''
        self.transfer_counter = self.transfer_counter - 1
  
class TestXfroutSessionBase(unittest.TestCase):
    '''Base class for tests related to xfrout sessions.

    It provides the common setUp/tearDown logic and helpers to build
    requests and inspect responses.  The actual test cases are defined in
    subclasses.

    '''
    def getmsg(self):
        '''Parse the current request wire data (self.mdata) as a Message.'''
        parsed = Message(Message.PARSE)
        parsed.from_wire(self.mdata)
        return parsed

    def create_mock_tsig_ctx(self, error):
        '''Return a MockTSIGContext for TSIG_KEY with its error preset.

        The given TSIG error is meant to be used as the result of the
        (normally faked) verification.
        '''
        ctx = MockTSIGContext(TSIG_KEY)
        ctx.error = error
        return ctx

    def message_has_tsig(self, msg):
        '''Tell whether msg carries a TSIG record.'''
        tsig_record = msg.get_tsig_record()
        return tsig_record is not None

    def create_request_data(self, with_question=True, with_tsig=False,
                            ixfr=None, qtype=None, zone_name=TEST_ZONE_NAME,
                            soa_class=TEST_RRCLASS, num_soa=1):
        '''Build the wire data of a commonly used XFR request.

        The request is an AXFR unless 'ixfr' is given.  In that case it is
        an IXFR, and its authority section carries an SOA whose serial is
        the value of 'ixfr'.

        The remaining parameters exist only to build malformed requests
        for testing:
        qtype: RR type of the question.  By default it is derived from
               'ixfr', but it can be set to an invalid type.
        zone_name: the query (zone) name.  For IXFR it is also the owner
                   name of the SOA in the authority section.
        soa_class: IXFR only.  RR class of the authority SOA.
        num_soa: IXFR only.  Number of SOA RDATAs in the authority section.
        '''
        if ixfr is None:
            req_type = RRType.AXFR()
        else:
            req_type = RRType.IXFR()

        msg = Message(Message.RENDER)
        msg.set_qid(0x1035)
        msg.set_opcode(Opcode.QUERY())
        msg.set_rcode(Rcode.NOERROR())
        if with_question:
            question_type = req_type if qtype is None else qtype
            msg.add_question(Question(zone_name, RRClass.IN(), question_type))
        if req_type == RRType.IXFR():
            authority_soa = RRset(zone_name, soa_class, RRType.SOA(),
                                  RRTTL(0))
            # Only the serial in the RDATA matters.
            soa_text = 'm r ' + str(ixfr) + ' 1 1 1 1'
            for _ in range(num_soa):
                authority_soa.add_rdata(Rdata(RRType.SOA(), soa_class,
                                              soa_text))
            msg.add_rrset(Message.SECTION_AUTHORITY, authority_soa)

        renderer = MessageRenderer()
        if with_tsig:
            msg.to_wire(renderer, MockTSIGContext(TSIG_KEY))
        else:
            msg.to_wire(renderer)
        return renderer.get_data()

    def set_request_type(self, type):
        '''Force the session's request type (and its text form).'''
        self.xfrsess._request_type = type
        self.xfrsess._request_typestr = \
            'AXFR' if type == RRType.AXFR() else 'IXFR'

    def setUp(self):
        self.sock = MySocket(socket.AF_INET, socket.SOCK_STREAM)
        remote = (socket.AF_INET, socket.SOCK_STREAM, ('127.0.0.1', 12345))
        # Unless a test is about ACLs, every request is simply accepted.
        accept_all = isc.acl.dns.REQUEST_LOADER.load([{"action": "ACCEPT"}])
        self.xfrsess = MyXfroutSession(self.sock, None, Dbserver(),
                                       TSIGKeyRing(), remote, accept_all, {})
        # AXFR is the default request type under test.
        self.set_request_type(RRType.AXFR())
        self.mdata = self.create_request_data()
        self.soa_rrset = create_soa(SOA_CURRENT_VERSION)
        # Some tests replace this module-wide function; remember the
        # original so tearDown can restore it for the other tests.
        self.orig_get_rrset_len = xfrout.get_rrset_len

    def tearDown(self):
        xfrout.get_rrset_len = self.orig_get_rrset_len
        # Whatever happened inside the XfroutSession object, the transfer
        # counter must always have been reset to zero.
        self.assertEqual(0, self.xfrsess._server.transfer_counter)
  
class TestXfroutSession(TestXfroutSessionBase):  
    def test_quota_error(self):  
        '''Emulating the server being too busy.  
  
        '''  
        self.xfrsess._request_data = self.mdata  
        self.xfrsess._server.increase_transfers_counter = lambda : False  
        XfroutSession._handle(self.xfrsess)  
        self.assertEqual(self.sock.read_msg().get_rcode(), Rcode.REFUSED())  
  
    def test_quota_ok(self):  
        '''The default case in terms of the xfrout quota.  
  
        '''  
        # set up a bogus request, which should result in FORMERR. (it only  
        # has to be something that is different from the previous case)  
        self.xfrsess._request_data = \  
            self.create_request_data(ixfr=IXFR_OK_VERSION, num_soa=2)  
        # Replace the data source client to avoid datasrc related exceptions  
        self.xfrsess.ClientClass = MockDataSrcClient  
        XfroutSession._handle(self.xfrsess)  
        self.assertEqual(self.sock.read_msg().get_rcode(), Rcode.FORMERR())  
  
    def test_exception_from_session(self):  
        '''Test the case where the main processing raises an exception.  
  
        We just check it doesn't any unexpected disruption and (in tearDown)  
        transfer_counter is correctly reset to 0.  
  
        '''  
        def dns_xfrout_start(fd, msg, quota):  
            raise ValueError('fake exception')  
        self.xfrsess.dns_xfrout_start = dns_xfrout_start  
        XfroutSession._handle(self.xfrsess)  
  
    def test_parse_query_message(self):  
        # Valid AXFR  
        [get_rcode, get_msg] = self.xfrsess._parse_query_message(self.mdata)  
        self.assertEqual(RRType.AXFR(), self.xfrsess._request_type)  
        self.assertEqual(get_rcode.to_text(), "NOERROR")  
  
        # Valid IXFR  
        request_data = self.create_request_data(ixfr=2011111801)  
        rcode, msg = self.xfrsess._parse_query_message(request_data)  
        self.assertEqual(RRType.IXFR(), self.xfrsess._request_type)  
        self.assertEqual(Rcode.NOERROR(), rcode)  
  
        # Broken request: no question  
        self.assertRaises(RuntimeError, self.xfrsess._parse_query_message,  
                          self.create_request_data(with_question=False))  
  
        # Broken request: invalid RR type (neither AXFR nor IXFR)  
        self.assertRaises(RuntimeError, self.xfrsess._parse_query_message,  
                          self.create_request_data(qtype=RRType.A()))  
  
        # NOERROR  
        request_data = self.create_request_data(ixfr=IXFR_OK_VERSION)  
        rcode, msg = self.xfrsess._parse_query_message(request_data)  
        self.assertEqual(rcode.to_text(), "NOERROR")  
  
        # tsig signed query message  
        request_data = self.create_request_data(with_tsig=True)  
        # BADKEY  
        [rcode, msg] = self.xfrsess._parse_query_message(request_data)  
        self.assertEqual(rcode.to_text(), "NOTAUTH")  
        self.assertTrue(self.xfrsess._tsig_ctx is not None)  
        # NOERROR  
        self.assertEqual(TSIGKeyRing.SUCCESS,  
                         self.xfrsess._tsig_key_ring.add(TSIG_KEY))  
        [rcode, msg] = self.xfrsess._parse_query_message(request_data)  
        self.assertEqual(rcode.to_text(), "NOERROR")  
        self.assertTrue(self.xfrsess._tsig_ctx is not None)  
  
    def check_transfer_acl(self, acl_setter):
        '''Run a fixed sequence of ACL/TSIG checks via _parse_query_message.

        acl_setter is a callable that installs the given ACL; each caller
        decides where it goes (the default ACL or a per-zone ACL).  The
        steps are order dependent: they progressively change the remote
        address, the TSIG key ring and the installed ACL.
        '''
        # ACL checks, put some ACL inside
        acl_setter(isc.acl.dns.REQUEST_LOADER.load([
            {
                "from": "127.0.0.1",
                "action": "ACCEPT"
            },
            {
                "from": "192.0.2.1",
                "action": "DROP"
            }
        ]))
        # Localhost (the default in this test) is accepted
        rcode, msg = self.xfrsess._parse_query_message(self.mdata)
        self.assertEqual(rcode.to_text(), "NOERROR")
        # This should be dropped completely, therefore returning None
        self.xfrsess._remote = (socket.AF_INET, socket.SOCK_STREAM,
                                ('192.0.2.1', 12345))
        rcode, msg = self.xfrsess._parse_query_message(self.mdata)
        self.assertEqual(None, rcode)
        # This should be refused, therefore REFUSED
        # (no rule in the ACL above matches 192.0.2.2)
        self.xfrsess._remote = (socket.AF_INET, socket.SOCK_STREAM,
                                ('192.0.2.2', 12345))
        rcode, msg = self.xfrsess._parse_query_message(self.mdata)
        self.assertEqual(rcode.to_text(), "REFUSED")

        # TSIG signed request
        request_data = self.create_request_data(with_tsig=True)

        # If the TSIG check fails, it should not check ACL
        # (If it checked ACL as well, it would just drop the request)
        self.xfrsess._remote = (socket.AF_INET, socket.SOCK_STREAM,
                                ('192.0.2.1', 12345))
        # An empty key ring makes the TSIG key unknown.
        self.xfrsess._tsig_key_ring = TSIGKeyRing()
        rcode, msg = self.xfrsess._parse_query_message(request_data)
        self.assertEqual(rcode.to_text(), "NOTAUTH")
        self.assertTrue(self.xfrsess._tsig_ctx is not None)

        # ACL using TSIG: successful case
        acl_setter(isc.acl.dns.REQUEST_LOADER.load([
            {"key": "example.com", "action": "ACCEPT"}, {"action": "REJECT"}
        ]))
        # Register the key so that TSIG verification itself succeeds; from
        # here on the key ring keeps TSIG_KEY for the remaining steps.
        self.assertEqual(TSIGKeyRing.SUCCESS,
                         self.xfrsess._tsig_key_ring.add(TSIG_KEY))
        [rcode, msg] = self.xfrsess._parse_query_message(request_data)
        self.assertEqual(rcode.to_text(), "NOERROR")

        # ACL using TSIG: key name doesn't match; should be rejected
        acl_setter(isc.acl.dns.REQUEST_LOADER.load([
            {"key": "example.org", "action": "ACCEPT"}, {"action": "REJECT"}
        ]))
        [rcode, msg] = self.xfrsess._parse_query_message(request_data)
        self.assertEqual(rcode.to_text(), "REFUSED")

        # ACL using TSIG: no TSIG; should be rejected
        # (the same example.org ACL as the previous step is reinstalled,
        # but this time the unsigned request self.mdata is used)
        acl_setter(isc.acl.dns.REQUEST_LOADER.load([
            {"key": "example.org", "action": "ACCEPT"}, {"action": "REJECT"}
        ]))
        [rcode, msg] = self.xfrsess._parse_query_message(self.mdata)
        self.assertEqual(rcode.to_text(), "REFUSED")

        #
        # ACL using IP + TSIG: both should match
        #
        acl_setter(isc.acl.dns.REQUEST_LOADER.load([
                {"ALL": [{"key": "example.com"}, {"from": "192.0.2.1"}],
                 "action": "ACCEPT"},
                {"action": "REJECT"}
        ]))
        # both matches
        self.xfrsess._remote = (socket.AF_INET, socket.SOCK_STREAM,
                                ('192.0.2.1', 12345))
        [rcode, msg] = self.xfrsess._parse_query_message(request_data)
        self.assertEqual(rcode.to_text(), "NOERROR")
        # TSIG matches, but address doesn't
        self.xfrsess._remote = (socket.AF_INET, socket.SOCK_STREAM,
                                ('192.0.2.2', 12345))
        [rcode, msg] = self.xfrsess._parse_query_message(request_data)
        self.assertEqual(rcode.to_text(), "REFUSED")
        # Address matches, but TSIG doesn't (not included)
        self.xfrsess._remote = (socket.AF_INET, socket.SOCK_STREAM,
                                ('192.0.2.1', 12345))
        [rcode, msg] = self.xfrsess._parse_query_message(self.mdata)
        self.assertEqual(rcode.to_text(), "REFUSED")
        # Neither address nor TSIG matches
        self.xfrsess._remote = (socket.AF_INET, socket.SOCK_STREAM,
                                ('192.0.2.2', 12345))
        [rcode, msg] = self.xfrsess._parse_query_message(self.mdata)
        self.assertEqual(rcode.to_text(), "REFUSED")
  
    def test_transfer_acl(self):  
        # ACL checks only with the default ACL  
        def acl_setter(acl):  
            self.xfrsess._acl = acl  
        self.check_transfer_acl(acl_setter)  
  
    def test_transfer_zoneacl(self):  
        # ACL check with a per zone ACL + default ACL.  The per zone ACL  
        # should match the queryied zone, so it should be used.  
        def acl_setter(acl):  
            zone_key = ('IN', 'example.com.')  
            self.xfrsess._zone_config[zone_key] = {}  
            self.xfrsess._zone_config[zone_key]['transfer_acl'] = acl  
            self.xfrsess._acl = isc.acl.dns.REQUEST_LOADER.load([  
                    {"from": "127.0.0.1", "action": "DROP"}])  
        self.check_transfer_acl(acl_setter)  
  
    def test_transfer_zoneacl_nomatch(self):  
        # similar to the previous one, but the per zone doesn't match the  
        # query.  The default should be used.  
        def acl_setter(acl):  
            zone_key = ('IN', 'example.org.')  
            self.xfrsess._zone_config[zone_key] = {}  
            self.xfrsess._zone_config[zone_key]['transfer_acl'] = \  
                isc.acl.dns.REQUEST_LOADER.load([  
                    {"from": "127.0.0.1", "action": "DROP"}])  
            self.xfrsess._acl = acl  
        self.check_transfer_acl(acl_setter)  
  
    def test_get_transfer_acl(self):  
        # set the default ACL.  If there's no specific zone ACL, this one  
        # should be used.  
        self.xfrsess._acl = isc.acl.dns.REQUEST_LOADER.load([  
                {"from": "127.0.0.1", "action": "ACCEPT"}])  
        acl = self.xfrsess._get_transfer_acl(Name('example.com'), RRClass.IN())  
        self.assertEqual(acl, self.xfrsess._acl)  
  
        # install a per zone config with transfer ACL for example.com.  Then  
        # that ACL will be used for example.com; for others the default ACL  
        # will still be used.  
        com_acl = isc.acl.dns.REQUEST_LOADER.load([  
                {"from": "127.0.0.1", "action": "REJECT"}])  
        self.xfrsess._zone_config[('IN', 'example.com.')] = {}  
        self.xfrsess._zone_config[('IN', 'example.com.')]['transfer_acl'] = \  
            com_acl  
        self.assertEqual(com_acl,  
                         self.xfrsess._get_transfer_acl(Name('example.com'),  
                                                        RRClass.IN()))  
        self.assertEqual(self.xfrsess._acl,  
                         self.xfrsess._get_transfer_acl(Name('example.org'),  
                                                        RRClass.IN()))  
  
        # Name matching should be case insensitive.  
        self.assertEqual(com_acl,  
                         self.xfrsess._get_transfer_acl(Name('EXAMPLE.COM'),  
                                                        RRClass.IN()))  
  
    def test_send_data(self):  
        self.xfrsess._send_data(self.sock, self.mdata)  
        senddata = self.sock.readsent()  
        self.assertEqual(senddata, self.mdata)  
  
    def test_reply_xfrout_query_with_error_rcode(self):  
        msg = self.getmsg()  
        self.xfrsess._reply_query_with_error_rcode(msg, self.sock, Rcode(3))  
        get_msg = self.sock.read_msg()  
        self.assertEqual(get_msg.get_rcode().to_text(), "NXDOMAIN")  
  
        # tsig signed message  
        msg = self.getmsg()  
        self.xfrsess._tsig_ctx = self.create_mock_tsig_ctx(TSIGError.NOERROR)  
        self.xfrsess._reply_query_with_error_rcode(msg, self.sock, Rcode(3))  
        get_msg = self.sock.read_msg()  
        self.assertEqual(get_msg.get_rcode().to_text(), "NXDOMAIN")  
        self.assertTrue(self.message_has_tsig(get_msg))  
  
    def test_send_message(self):  
        msg = self.getmsg()  
        msg.make_response()  
        # SOA record data with different cases  
        soa_rrset = RRset(Name('Example.com.'), RRClass.IN(), RRType.SOA(),  
                               RRTTL(3600))  
        soa_rrset.add_rdata(Rdata(RRType.SOA(), RRClass.IN(),  
                                  'master.Example.com. admin.exAmple.com. ' +  
                                  '2011112001 3600 1800 2419200 7200'))  
        msg.add_rrset(Message.SECTION_ANSWER, soa_rrset)  
        self.xfrsess._send_message(self.sock, msg)  
        send_out_data = self.sock.readsent()[2:]  
  
        # CASE_INSENSITIVE compression mode  
        render = MessageRenderer();  
        render.set_length_limit(XFROUT_MAX_MESSAGE_SIZE)  
        msg.to_wire(render)  
        self.assertNotEqual(render.get_data(), send_out_data)  
  
        # CASE_SENSITIVE compression mode  
        render.clear()  
        render.set_compress_mode(MessageRenderer.CASE_SENSITIVE)  
        render.set_length_limit(XFROUT_MAX_MESSAGE_SIZE)  
        msg.to_wire(render)  
        self.assertEqual(render.get_data(), send_out_data)  
  
    def test_clear_message(self):  
        msg = self.getmsg()  
        qid = msg.get_qid()  
        opcode = msg.get_opcode()  
        rcode = msg.get_rcode()  
  
        self.xfrsess._clear_message(msg)  
        self.assertEqual(msg.get_qid(), qid)  
        self.assertEqual(msg.get_opcode(), opcode)  
        self.assertEqual(msg.get_rcode(), rcode)  
        self.assertTrue(msg.get_header_flag(Message.HEADERFLAG_AA))  
  
    def test_send_message_with_last_soa(self):  
        msg = self.getmsg()  
        msg.make_response()  
  
        self.xfrsess._send_message_with_last_soa(msg, self.sock,  
                                                 self.soa_rrset, 0)  
        get_msg = self.sock.read_msg()  
        # tsig context does not exist  
        self.assertFalse(self.message_has_tsig(get_msg))  
  
        self.assertEqual(get_msg.get_rr_count(Message.SECTION_QUESTION), 1)  
        self.assertEqual(get_msg.get_rr_count(Message.SECTION_ANSWER), 1)  
        self.assertEqual(get_msg.get_rr_count(Message.SECTION_AUTHORITY), 0)  
  
        answer = get_msg.get_section(Message.SECTION_ANSWER)[0]  
        self.assertEqual(answer.get_name().to_text(), "example.com.")  
        self.assertEqual(answer.get_class(), RRClass("IN"))  
        self.assertEqual(answer.get_type().to_text(), "SOA")  
        rdata = answer.get_rdata()  
        self.assertEqual(rdata[0], self.soa_rrset.get_rdata()[0])  
  
        # Sending the message with last soa together  
        self.xfrsess._send_message_with_last_soa(msg, self.sock,  
                                                 self.soa_rrset, 0)  
        get_msg = self.sock.read_msg()  
        # tsig context does not exist  
        self.assertFalse(self.message_has_tsig(get_msg))  
  
    def test_send_message_with_last_soa_with_tsig(self):  
        # create tsig context  
        self.xfrsess._tsig_ctx = self.create_mock_tsig_ctx(TSIGError.NOERROR)  
  
        msg = self.getmsg()  
        msg.make_response()  
  
        # Sending the message with last soa together  
        self.xfrsess._send_message_with_last_soa(msg, self.sock,  
                                                 self.soa_rrset, 0)  
        get_msg = self.sock.read_msg()  
        self.assertTrue(self.message_has_tsig(get_msg))  
  
        self.assertEqual(get_msg.get_rr_count(Message.SECTION_QUESTION), 1)  
        self.assertEqual(get_msg.get_rr_count(Message.SECTION_ANSWER), 1)  
        self.assertEqual(get_msg.get_rr_count(Message.SECTION_AUTHORITY), 0)  
  
    def test_trigger_send_message_with_last_soa(self):  
        rrset_a = RRset(Name("example.com"), RRClass.IN(), RRType.A(), RRTTL(3600))  
        rrset_a.add_rdata(Rdata(RRType.A(), RRClass.IN(), "192.0.2.1"))  
  
        msg = self.getmsg()  
        msg.make_response()  
        msg.add_rrset(Message.SECTION_ANSWER, rrset_a)  
  
        # length larger than MAX-len(rrset)  
        length_need_split = xfrout.XFROUT_MAX_MESSAGE_SIZE - \  
            get_rrset_len(self.soa_rrset) + 1  
  
        # give the function a value that is larger than MAX-len(rrset)  
        # this should have triggered the sending of two messages  
        # (1 with the rrset we added manually, and 1 that triggered  
        # the sending in _with_last_soa)  
        self.xfrsess._send_message_with_last_soa(msg, self.sock,  
                                                 self.soa_rrset,  
                                                 length_need_split)  
        get_msg = self.sock.read_msg()  
        self.assertFalse(self.message_has_tsig(get_msg))  
        self.assertEqual(get_msg.get_rr_count(Message.SECTION_QUESTION), 1)  
        self.assertEqual(get_msg.get_rr_count(Message.SECTION_ANSWER), 1)  
        self.assertEqual(get_msg.get_rr_count(Message.SECTION_AUTHORITY), 0)  
  
        answer = get_msg.get_section(Message.SECTION_ANSWER)[0]  
        self.assertEqual(answer.get_name().to_text(), "example.com.")  
        self.assertEqual(answer.get_class(), RRClass("IN"))  
        self.assertEqual(answer.get_type().to_text(), "A")  
        rdata = answer.get_rdata()  
        self.assertEqual(rdata[0].to_text(), "192.0.2.1")  
  
        get_msg = self.sock.read_msg()  
        self.assertFalse(self.message_has_tsig(get_msg))  
        self.assertEqual(get_msg.get_rr_count(Message.SECTION_QUESTION), 0)  
        self.assertEqual(get_msg.get_rr_count(Message.SECTION_ANSWER), 1)  
        self.assertEqual(get_msg.get_rr_count(Message.SECTION_AUTHORITY), 0)  
  
        answer = get_msg.get_section(Message.SECTION_ANSWER)[0]  
        self.assertEqual(answer.get_name().to_text(), "example.com.")  
        self.assertEqual(answer.get_class(), RRClass("IN"))  
        self.assertEqual(answer.get_type().to_text(), "SOA")  
        rdata = answer.get_rdata()  
        self.assertEqual(rdata[0], self.soa_rrset.get_rdata()[0])  
  
        # and it should not have sent anything else  
        self.assertEqual(0, len(self.sock.sendqueue))  
  
    def test_trigger_send_message_with_last_soa_with_tsig(self):  
        self.xfrsess._tsig_ctx = self.create_mock_tsig_ctx(TSIGError.NOERROR)  
        msg = self.getmsg()  
        msg.make_response()  
        msg.add_rrset(Message.SECTION_ANSWER, self.soa_rrset)  
  
        # length larger than MAX-len(rrset)  
        length_need_split = xfrout.XFROUT_MAX_MESSAGE_SIZE - \  
            get_rrset_len(self.soa_rrset) + 1  
  
        # give the function a value that is larger than MAX-len(rrset)  
        # this should have triggered the sending of two messages  
        # (1 with the rrset we added manually, and 1 that triggered  
        # the sending in _with_last_soa)  
        self.xfrsess._send_message_with_last_soa(msg, self.sock,  
                                                 self.soa_rrset,  
                                                 length_need_split)  
        # Both messages should have TSIG RRs  
        get_msg = self.sock.read_msg()  
        self.assertTrue(self.message_has_tsig(get_msg))  
        get_msg = self.sock.read_msg()  
        self.assertTrue(self.message_has_tsig(get_msg))  
        # and it should not have sent anything else  
        self.assertEqual(0, len(self.sock.sendqueue))  
  
    def test_get_rrset_len(self):  
        self.assertEqual(82, get_rrset_len(self.soa_rrset))  
  
    def test_xfrout_axfr_setup(self):  
        self.xfrsess.ClientClass = MockDataSrcClient  
        # Successful case.  A zone iterator should be set up.  
        self.assertEqual(self.xfrsess._xfrout_setup(  
                self.getmsg(), TEST_ZONE_NAME, TEST_RRCLASS), Rcode.NOERROR())  
        self.assertNotEqual(None, self.xfrsess._iterator)  
  
        # Failure cases  
        self.assertEqual(self.xfrsess._xfrout_setup(  
                self.getmsg(), Name('notauth.example.com'), TEST_RRCLASS),  
                         Rcode.NOTAUTH())  
        self.assertEqual(self.xfrsess._xfrout_setup(  
                self.getmsg(), Name('nosoa.example.com'), TEST_RRCLASS),  
                         Rcode.SERVFAIL())  
        self.assertEqual(self.xfrsess._xfrout_setup(  
                self.getmsg(), Name('multisoa.example.com'), TEST_RRCLASS),  
                         Rcode.SERVFAIL())  
  
    def test_xfrout_ixfr_setup(self):
        '''Check _xfrout_setup() for IXFR requests with the mock data source.

        The cases below run in sequence and share state: each case rebuilds
        self.mdata (presumably the wire data that self.getmsg() parses), and
        some cases deliberately reuse zone_name from the preceding case.
        '''
        self.xfrsess.ClientClass = MockDataSrcClient
        self.set_request_type(RRType.IXFR())

        # Successful case of pure IXFR.  A zone journal reader should be set
        # up.
        self.mdata = self.create_request_data(ixfr=IXFR_OK_VERSION)
        self.assertEqual(self.xfrsess._xfrout_setup(
                self.getmsg(), TEST_ZONE_NAME, TEST_RRCLASS), Rcode.NOERROR())
        self.assertNotEqual(None, self.xfrsess._jnl_reader)

        # Successful case, but as a result of falling back to AXFR-style
        # IXFR.  A zone iterator should be set up instead of a journal reader.
        self.mdata = self.create_request_data(ixfr=IXFR_NG_VERSION)
        self.assertEqual(self.xfrsess._xfrout_setup(
                self.getmsg(), TEST_ZONE_NAME, TEST_RRCLASS), Rcode.NOERROR())
        self.assertNotEqual(None, self.xfrsess._iterator)
        self.assertEqual(None, self.xfrsess._jnl_reader)

        # Successful case, but the requested SOA serial is greater than that of
        # the local SOA.  Both iterator and jnl_reader should be None,
        # indicating that the response will contain just one SOA.
        self.mdata = self.create_request_data(ixfr=SOA_CURRENT_VERSION+1)
        self.assertEqual(self.xfrsess._xfrout_setup(
                self.getmsg(), TEST_ZONE_NAME, TEST_RRCLASS), Rcode.NOERROR())
        self.assertEqual(None, self.xfrsess._iterator)
        self.assertEqual(None, self.xfrsess._jnl_reader)

        # Similar to the previous case, but the requested serial is equal to
        # the local SOA.
        self.mdata = self.create_request_data(ixfr=SOA_CURRENT_VERSION)
        self.assertEqual(self.xfrsess._xfrout_setup(
                self.getmsg(), TEST_ZONE_NAME, TEST_RRCLASS), Rcode.NOERROR())
        self.assertEqual(None, self.xfrsess._iterator)
        self.assertEqual(None, self.xfrsess._jnl_reader)

        # Similar to the previous case, but the comparison should be done
        # based on serial number arithmetic, not as integers: a requested
        # serial of 1 is "after" 0xffffffff under serial number arithmetic,
        # even though 1 < 0xffffffff as plain integers.
        zone_name = Name('maxserial.example.com') # whose SOA is 0xffffffff
        self.mdata = self.create_request_data(ixfr=1, zone_name=zone_name)
        self.assertEqual(self.xfrsess._xfrout_setup(
                 self.getmsg(), zone_name, TEST_RRCLASS), Rcode.NOERROR())
        self.assertEqual(None, self.xfrsess._iterator)
        self.assertEqual(None, self.xfrsess._jnl_reader)

        # The data source doesn't support journaling.  Should fallback to AXFR.
        zone_name = Name('nojournal.example.com')
        self.mdata = self.create_request_data(ixfr=IXFR_OK_VERSION,
                                              zone_name=zone_name)
        self.assertEqual(self.xfrsess._xfrout_setup(
                self.getmsg(), zone_name, TEST_RRCLASS), Rcode.NOERROR())
        self.assertNotEqual(None, self.xfrsess._iterator)

        # Failure cases
        zone_name = Name('notauth.example.com')
        self.mdata = self.create_request_data(ixfr=IXFR_OK_VERSION,
                                              zone_name=zone_name)
        self.assertEqual(self.xfrsess._xfrout_setup(
                self.getmsg(), zone_name, TEST_RRCLASS), Rcode.NOTAUTH())
        # this is a strange case: zone's SOA will be found but the journal
        # reader won't be created due to 'no such zone'.
        zone_name = Name('notauth2.example.com')
        self.mdata = self.create_request_data(ixfr=IXFR_OK_VERSION,
                                              zone_name=zone_name)
        self.assertEqual(self.xfrsess._xfrout_setup(
                self.getmsg(), zone_name, TEST_RRCLASS), Rcode.NOTAUTH())
        zone_name = Name('nosoa.example.com')
        self.mdata = self.create_request_data(ixfr=IXFR_OK_VERSION,
                                              zone_name=zone_name)
        self.assertEqual(self.xfrsess._xfrout_setup(
                self.getmsg(), zone_name, TEST_RRCLASS), Rcode.SERVFAIL())
        zone_name = Name('multisoa.example.com')
        self.mdata = self.create_request_data(ixfr=IXFR_OK_VERSION,
                                              zone_name=zone_name)
        self.assertEqual(self.xfrsess._xfrout_setup(
                self.getmsg(), zone_name, TEST_RRCLASS), Rcode.SERVFAIL())

        # query name doesn't match the SOA's owner
        # NOTE: zone_name intentionally still holds 'multisoa.example.com'
        # from the previous case, while the request below is built without a
        # zone_name argument (presumably for TEST_ZONE_NAME), so the two
        # names differ.
        self.mdata = self.create_request_data(ixfr=IXFR_OK_VERSION)
        self.assertEqual(self.xfrsess._xfrout_setup(
                self.getmsg(), zone_name, TEST_RRCLASS), Rcode.FORMERR())

        # query's RR class doesn't match the SOA's class
        zone_name = TEST_ZONE_NAME # make sure the name matches this time
        self.mdata = self.create_request_data(ixfr=IXFR_OK_VERSION,
                                              soa_class=RRClass.CH())
        self.assertEqual(self.xfrsess._xfrout_setup(
                self.getmsg(), zone_name, TEST_RRCLASS), Rcode.FORMERR())

        # multiple SOA RRs
        self.mdata = self.create_request_data(ixfr=IXFR_OK_VERSION,
                                              num_soa=2)
        self.assertEqual(self.xfrsess._xfrout_setup(
                self.getmsg(), zone_name, TEST_RRCLASS), Rcode.FORMERR())
  
    def test_dns_xfrout_start_formerror(self):  
        # formerror  
        self.xfrsess.dns_xfrout_start(self.sock, b"\xd6=\x00\x00\x00\x01\x00")  
        sent_data = self.sock.readsent()  
        self.assertEqual(len(sent_data), 0)  
  
    def test_dns_xfrout_start_notauth(self):  
        def notauth(msg, name, rrclass):  
            return Rcode.NOTAUTH()  
        self.xfrsess._xfrout_setup = notauth  
        self.xfrsess.dns_xfrout_start(self.sock, self.mdata)  
        get_msg = self.sock.read_msg()  
        self.assertEqual(get_msg.get_rcode().to_text(), "NOTAUTH")  
  
    def test_dns_xfrout_start_datasrc_servfail(self):  
        def internal_raise(x, y):  
            raise isc.datasrc.Error('exception for the sake of test')  
        self.xfrsess.ClientClass = internal_raise  
        self.xfrsess.dns_xfrout_start(self.sock, self.mdata)  
        self.assertEqual(self.sock.read_msg().get_rcode(), Rcode.SERVFAIL())  
  
    def test_dns_xfrout_start_noerror(self):  
        def noerror(msg, name, rrclass):  
            return Rcode.NOERROR()  
        self.xfrsess._xfrout_setup = noerror  
  
        def myreply(msg, sock):  
            self.sock.send(b"success")  
  
        self.xfrsess._reply_xfrout_query = myreply  
        self.xfrsess.dns_xfrout_start(self.sock, self.mdata)  
        self.assertEqual(self.sock.readsent(), b"success")  
  
    def test_reply_xfrout_query_axfr(self):  
        self.xfrsess._soa = self.soa_rrset  
        self.xfrsess._iterator = [self.soa_rrset]  
        self.xfrsess._reply_xfrout_query(self.getmsg(), self.sock)  
        reply_msg = self.sock.read_msg()  
        self.assertEqual(reply_msg.get_rr_count(Message.SECTION_ANSWER), 2)  
  
    def test_reply_xfrout_query_axfr_with_tsig(self):  
        rrset = RRset(Name('a.example.com'), RRClass.IN(), RRType.A(),  
                      RRTTL(3600))  
        rrset.add_rdata(Rdata(RRType.A(), RRClass.IN(), '192.0.2.1'))  
        global xfrout  
  
        def get_rrset_len(rrset):  
            return 65520  
  
        self.xfrsess._soa = self.soa_rrset  
        self.xfrsess._iterator = [rrset for i in range(0, 100)]  
        xfrout.get_rrset_len = get_rrset_len  
  
        self.xfrsess._tsig_ctx = self.create_mock_tsig_ctx(TSIGError.NOERROR)  
        self.xfrsess._reply_xfrout_query(self.getmsg(), self.sock)  
  
        # All messages must have TSIG as we don't support the feature of  
        # skipping intermediate TSIG records (with bulk signing).  
        for i in range(0, 102): # 102 = all 100 RRs from iterator and 2 SOAs  
            reply_msg = self.sock.read_msg()  
            # With the hack of get_rrset_len() above, every message must have  
            # exactly one RR in the answer section.  
            self.assertEqual(reply_msg.get_rr_count(Message.SECTION_ANSWER), 1)  
            self.assertTrue(self.message_has_tsig(reply_msg))  
  
        # and it should not have sent anything else  
        self.assertEqual(0, len(self.sock.sendqueue))  
  
    def test_reply_xfrout_query_ixfr(self):  
        # Creating a pure (incremental) IXFR response.  Intermediate SOA  
        # RRs won't be skipped.  
        self.xfrsess._soa = create_soa(SOA_CURRENT_VERSION)  
        self.xfrsess._iterator = [create_soa(IXFR_OK_VERSION),  
                                  create_a(Name('a.example.com'), '192.0.2.2'),  
                                  create_soa(SOA_CURRENT_VERSION),  
                                  create_aaaa(Name('a.example.com'),  
                                              '2001:db8::1')]  
        self.xfrsess._jnl_reader = self.xfrsess._iterator  
        self.xfrsess._reply_xfrout_query(self.getmsg(), self.sock)  
        reply_msg = self.sock.read_msg(Message.PRESERVE_ORDER)  
        actual_records = reply_msg.get_section(Message.SECTION_ANSWER)  
  
        expected_records = self.xfrsess._iterator[:]  
        expected_records.insert(0, create_soa(SOA_CURRENT_VERSION))  
        expected_records.append(create_soa(SOA_CURRENT_VERSION))  
  
        self.assertEqual(len(expected_records), len(actual_records))  
        for (expected_rr, actual_rr) in zip(expected_records, actual_records):  
            self.assertTrue(rrsets_equal(expected_rr, actual_rr))  
  
    def test_reply_xfrout_query_axfr_maxlen(self):  
        # The test RR(set) has the length of 65535 - 12 (size of hdr) bytes:  
        # owner name = 1 (root), fixed fields (type,class,TTL,RDLEN) = 10  
        # RDATA = 65512 (= 65535 - 12 - 1 - 10)  
        self.xfrsess._soa = self.soa_rrset  
        test_rr = create_generic(Name('.'), 65512)  
        self.xfrsess._iterator = [self.soa_rrset, test_rr]  
        self.xfrsess._reply_xfrout_query(self.getmsg(), self.sock)  
        # The first message should contain the beginning SOA, and only that RR  
        r = self.sock.read_msg()  
        self.assertEqual(1, r.get_rr_count(Message.SECTION_ANSWER))  
        self.assertTrue(rrsets_equal(self.soa_rrset,  
                                     r.get_section(Message.SECTION_ANSWER)[0]))  
        # The second message should contain the beginning SOA, and only that RR  
        # The wire format data should have the possible maximum size.  
        r, rlen = self.sock.read_msg(need_len=True)  
        self.assertEqual(65535, rlen)  
        self.assertEqual(1, r.get_rr_count(Message.SECTION_ANSWER))  
        self.assertTrue(rrsets_equal(test_rr,  
                                     r.get_section(Message.SECTION_ANSWER)[0]))  
        # The third message should contain the ending SOA, and only that RR  
        r = self.sock.read_msg()  
        self.assertEqual(1, r.get_rr_count(Message.SECTION_ANSWER))  
        self.assertTrue(rrsets_equal(self.soa_rrset,  
                                     r.get_section(Message.SECTION_ANSWER)[0]))  
  
        # there should be no more message  
        self.assertEqual(0, len(self.sock.sendqueue))  
  
    def maxlen_test_common_setup(self, tsig=False):  
        '''Common initialization for some of the tests below  
  
        For those tests we use '.' for all owner names and names in RDATA  
        to avoid having unexpected results due to compression.  It returns  
        the created SOA for convenience.  
  
        If tsig is True, also setup TSIG (mock) context.  In our test cases  
        the size of the TSIG RR is 81 bytes (key name = example.com,  
        algorithm = hmac-md5)  
  
        '''  
        soa = RRset(Name('.'), RRClass.IN(), RRType.SOA(), RRTTL(3600))  
        soa.add_rdata(Rdata(RRType.SOA(), RRClass.IN(), '. . 0 0 0 0 0'))  
        self.mdata = self.create_request_data(zone_name=Name('.'))  
        self.xfrsess._soa = soa  
        if tsig:  
            self.xfrsess._tsig_ctx = \  
                self.create_mock_tsig_ctx(TSIGError.NOERROR)  
            self.xfrsess._tsig_len = 81  
        return soa  
  
    def maxlen_test_common_checks(self, soa_rr, test_rr, expected_n_rr):  
        '''A set of common assertion checks for some tests below.  
  
        In all cases two AXFR response messages should have been created.  
        expected_n_rr is a list of two elements, each specifies the expected  
        number of answer RRs for each message: expected_n_rr[0] is the expected  
        number of the first answer RRs; expected_n_rr[1] is the expected number  
        of the second answer RRs.  The message that contains two RRs should  
        have the maximum possible wire length (65535 bytes).  And, in all  
        cases, the resulting RRs should be in the order of SOA, another RR,  
        SOA.  
  
        '''  
        # Check the first message  
        r, rlen = self.sock.read_msg(need_len=True)  
        if expected_n_rr[0] == 2:  
            self.assertEqual(65535, rlen)  
        self.assertEqual(expected_n_rr[0],  
                         r.get_rr_count(Message.SECTION_ANSWER))  
        actual_rrs = r.get_section(Message.SECTION_ANSWER)[:]  
  
        # Check the second message  
        r, rlen = self.sock.read_msg(need_len=True)  
        if expected_n_rr[1] == 2:  
            self.assertEqual(65535, rlen)  
        self.assertEqual(expected_n_rr[1],  
                         r.get_rr_count(Message.SECTION_ANSWER))  
        actual_rrs.extend(r.get_section(Message.SECTION_ANSWER))  
        for (expected_rr, actual_rr) in zip([soa_rr, test_rr, soa_rr],  
                                            actual_rrs):  
            self.assertTrue(rrsets_equal(expected_rr, actual_rr))  
  
        # there should be no more message  
        self.assertEqual(0, len(self.sock.sendqueue))  
  
    def test_reply_xfrout_query_axfr_maxlen_with_soa(self):  
        # Similar to the 'maxlen' test, but the first message should be  
        # able to contain both SOA and the large RR.  
        soa = self.maxlen_test_common_setup()  
  
        # The first message will contain the question (5 bytes), so the  
        # test RDATA should allow a room for that.  
        test_rr = create_generic(Name('.'), 65512 - 5 - get_rrset_len(soa))  
        self.xfrsess._iterator = [soa, test_rr]  
        self.xfrsess._reply_xfrout_query(self.getmsg(), self.sock)  
        self.maxlen_test_common_checks(soa, test_rr, [2, 1])  
  
    def test_reply_xfrout_query_axfr_maxlen_with_soa_with_tsig(self):  
        # Similar to the previous case, but with TSIG (whose size is 81 bytes).  
        soa = self.maxlen_test_common_setup(True)  
        test_rr = create_generic(Name('.'), 65512 - 5 - 81 -  
                                 get_rrset_len(soa))  
        self.xfrsess._iterator = [soa, test_rr]  
        self.xfrsess._reply_xfrout_query(self.getmsg(), self.sock)  
        self.maxlen_test_common_checks(soa, test_rr, [2, 1])  
  
    def test_reply_xfrout_query_axfr_maxlen_with_endsoa(self):  
        # Similar to the max w/ soa test, but the first message cannot contain  
        # both SOA and the long RR due to the question section.  The second  
        # message should be able to contain both.  
        soa = self.maxlen_test_common_setup()  
        test_rr = create_generic(Name('.'), 65512 - get_rrset_len(soa))  
        self.xfrsess._iterator = [soa, test_rr]  
        self.xfrsess._reply_xfrout_query(self.getmsg(), self.sock)  
        self.maxlen_test_common_checks(soa, test_rr, [1, 2])  
  
    def test_reply_xfrout_query_axfr_maxlen_with_endsoa_with_tsig(self):  
        # Similar to the previous case, but with TSIG.  
        soa = self.maxlen_test_common_setup(True)  
        test_rr = create_generic(Name('.'), 65512 - 81 - get_rrset_len(soa))  
        self.xfrsess._iterator = [soa, test_rr]  
        self.xfrsess._reply_xfrout_query(self.getmsg(), self.sock)  
        self.maxlen_test_common_checks(soa, test_rr, [1, 2])  
  
    def test_reply_xfrout_query_axfr_toobigdata(self):  
        # Similar to the 'maxlen' test, but the RR doesn't even fit in a  
        # single message.  
        self.xfrsess._soa = self.soa_rrset  
        test_rr = create_generic(Name('.'), 65513) # 1 byte larger than 'max'  
        self.xfrsess._iterator = [self.soa_rrset, test_rr]  
        # the reply method should fail with exception  
        self.assertRaises(XfroutSessionError, self.xfrsess._reply_xfrout_query,  
                          self.getmsg(), self.sock)  
        # The first message should still have been sent and contain the  
        # beginning SOA, and only that RR  
        r = self.sock.read_msg()  
        self.assertEqual(1, r.get_rr_count(Message.SECTION_ANSWER))  
        self.assertTrue(rrsets_equal(self.soa_rrset,  
                                     r.get_section(Message.SECTION_ANSWER)[0]))  
        # And there should have been no other messages sent  
        self.assertEqual(0, len(self.sock.sendqueue))  
  
    def test_reply_xfrout_query_ixfr_soa_only(self):  
        # Creating an IXFR response that contains only one RR, which is the  
        # SOA of the current version.  
        self.xfrsess._soa = create_soa(SOA_CURRENT_VERSION)  
        self.xfrsess._iterator = None  
        self.xfrsess._jnl_reader = None  
        self.xfrsess._reply_xfrout_query(self.getmsg(), self.sock)  
        reply_msg = self.sock.read_msg(Message.PRESERVE_ORDER)  
        answer = reply_msg.get_section(Message.SECTION_ANSWER)  
        self.assertEqual(1, len(answer))  
        self.assertTrue(rrsets_equal(create_soa(SOA_CURRENT_VERSION),  
                                     answer[0]))  
  
class TestXfroutSessionWithSQLite3(TestXfroutSessionBase):
    '''Tests for XFR-out sessions using an SQLite3 DB.

    These are provided mainly to confirm the implementation actually works
    in an environment closer to actual operational environments.  So we
    only check a few common cases; other details are tested using mock
    data sources.

    '''
    def setUp(self):
        super().setUp()
        self.xfrsess._request_data = self.mdata
        # TESTDATA_SRCDIR is concatenated directly, so it presumably ends
        # with a path separator.
        self.xfrsess._server.get_db_file = lambda : TESTDATA_SRCDIR + \
            'test.sqlite3'
        self.ns_name = 'a.dns.example.com'

    def check_axfr_stream(self, response):
        '''Common checks for AXFR(-style) response for the test zone.
        '''
        # This zone contains two A RRs for the same name with different TTLs.
        # These TTLs should be preserved in the AXFR stream.
        actual_records = response.get_section(Message.SECTION_ANSWER)
        self.assertEqual(5, len(actual_records))
        # The first and last RR should be the expected SOA
        expected_soa = create_soa(SOA_CURRENT_VERSION)
        self.assertTrue(rrsets_equal(expected_soa, actual_records[0]))
        self.assertTrue(rrsets_equal(expected_soa, actual_records[-1]))

        # The ordering of the intermediate RRs can differ depending on the
        # internal details of the SQLite3 library, so we sort them by a simple
        # rule sufficient for the purpose here, and then compare them.
        expected_others = [create_ns(self.ns_name),
                           create_a(Name(self.ns_name), '192.0.2.1', 3600),
                           create_a(Name(self.ns_name), '192.0.2.2', 7200)]
        keyfn = lambda x: (x.get_type(), x.get_ttl())
        for (expected_rr, actual_rr) in zip(sorted(expected_others, key=keyfn),
                                            sorted(actual_records[1:4],
                                                   key=keyfn)):
            self.assertTrue(rrsets_equal(expected_rr, actual_rr))

    def test_axfr_normal_session(self):
        XfroutSession._handle(self.xfrsess)
        response = self.sock.read_msg(Message.PRESERVE_ORDER)
        self.assertEqual(Rcode.NOERROR(), response.get_rcode())
        self.check_axfr_stream(response)

    def test_ixfr_to_axfr(self):
        self.xfrsess._request_data = \
            self.create_request_data(ixfr=IXFR_NG_VERSION)
        XfroutSession._handle(self.xfrsess)
        response = self.sock.read_msg(Message.PRESERVE_ORDER)
        self.assertEqual(Rcode.NOERROR(), response.get_rcode())
        # This is an AXFR-style IXFR.  So the question section should indicate
        # that it's an IXFR response.
        self.assertEqual(RRType.IXFR(), response.get_question()[0].get_type())
        self.check_axfr_stream(response)

    def test_ixfr_normal_session(self):
        # See testdata/creatediff.py.  There are 8 changes between two
        # versions.  So the answer section should contain all of these and
        # two beginning and trailing SOAs.
        self.xfrsess._request_data = \
            self.create_request_data(ixfr=IXFR_OK_VERSION)
        XfroutSession._handle(self.xfrsess)
        response = self.sock.read_msg(Message.PRESERVE_ORDER)
        actual_records = response.get_section(Message.SECTION_ANSWER)
        expected_records = [create_soa(2011112001), create_soa(2011111802),
                            create_soa(2011111900),
                            create_a(Name(self.ns_name), '192.0.2.2', 7200),
                            create_soa(2011111900),
                            create_a(Name(self.ns_name), '192.0.2.53'),
                            create_aaaa(Name(self.ns_name), '2001:db8::1'),
                            create_soa(2011112001),
                            create_a(Name(self.ns_name), '192.0.2.1'),
                            create_soa(2011112001)]
        self.assertEqual(len(expected_records), len(actual_records))
        for (expected_rr, actual_rr) in zip(expected_records, actual_records):
            self.assertTrue(rrsets_equal(expected_rr, actual_rr))

    def ixfr_soa_only_common_checks(self, request_serial):
        '''Run an IXFR for request_serial and expect just the current SOA.'''
        self.xfrsess._request_data = \
            self.create_request_data(ixfr=request_serial)
        XfroutSession._handle(self.xfrsess)
        response = self.sock.read_msg(Message.PRESERVE_ORDER)
        answers = response.get_section(Message.SECTION_ANSWER)
        self.assertEqual(1, len(answers))
        self.assertTrue(rrsets_equal(create_soa(SOA_CURRENT_VERSION),
                                     answers[0]))

    def test_ixfr_soa_only(self):
        # The requested SOA serial is the latest one.  The response should
        # contain exactly one SOA of that serial.
        self.ixfr_soa_only_common_checks(SOA_CURRENT_VERSION)

    def test_ixfr_soa_only2(self):
        # Similar to the previous test, but the requested SOA is larger than
        # the current.
        self.ixfr_soa_only_common_checks(SOA_CURRENT_VERSION + 1)
  
class MyUnixSockServer(UnixSockServer):
    '''UnixSockServer variant used by the tests.

    Its constructor replaces UnixSockServer.__init__ (presumably to avoid
    setting up the real UNIX domain socket); it only creates the shutdown
    event, runs the common initialization, installs a MyCCSession and
    applies the full configuration reported by that session.
    '''
    def __init__(self):
        self._shutdown_event = threading.Event()
        self._common_init()
        self._cc = MyCCSession()
        # Load the configuration from the test CC session.
        self.update_config_data(self._cc.get_full_config())
  
class TestUnixSockServer(unittest.TestCase):
    '''Tests for UnixSockServer, using the MyUnixSockServer test variant.'''

    def setUp(self):
        self.write_sock, self.read_sock = socket.socketpair()
        self.unix = MyUnixSockServer()

    def tearDown(self):
        # Release the socket pair created in setUp so tests don't leak
        # file descriptors.
        self.write_sock.close()
        self.read_sock.close()

    def test_tsig_keyring(self):
        """
        Check we use the global keyring when starting a request.
        """
        try:
            # These are just so the keyring can be started
            self.unix._cc.add_remote_config_by_name = \
                lambda name, callback: None
            self.unix._cc.get_remote_config_value = \
                lambda module, name: ([], True)
            self.unix._cc.remove_remote_config = lambda name: None
            isc.server_common.tsig_keyring.init_keyring(self.unix._cc)
            # These are not really interesting for the test. These are just
            # handed over, so strings are OK.
            self.unix._guess_remote = lambda sock: "Address"
            self.unix._zone_config = "Zone config"
            self.unix._acl = "acl"
            # This would be the handler class, but we just check it is passed
            # the right parameters, so function is enough for that.
            keys = isc.server_common.tsig_keyring.get_keyring()
            def handler(sock, data, server, keyring, address, acl, config):
                self.assertEqual("sock", sock)
                self.assertEqual("data", data)
                self.assertEqual(self.unix, server)
                self.assertEqual(keys, keyring)
                self.assertEqual("Address", address)
                self.assertEqual("acl", acl)
                self.assertEqual("Zone config", config)
            self.unix.RequestHandlerClass = handler
            self.unix.finish_request("sock", "data")
        finally:
            isc.server_common.tsig_keyring.deinit_keyring()

    def _check_guess_remote(self, family, address, expected_address):
        '''Check _guess_remote() on a UDP socket connected to address.

        _guess_remote() must report (family, SOCK_STREAM, expected_address).
        The socket is closed even if the check fails.
        '''
        sock = socket.socket(family, socket.SOCK_DGRAM)
        try:
            sock.connect(address)
            self.assertEqual((family, socket.SOCK_STREAM, expected_address),
                             self.unix._guess_remote(sock.fileno()))
        finally:
            sock.close()

    def test_guess_remote(self):
        """Test we can guess the remote endpoint when we have only the
           file descriptor. This is needed, because we get only that one
           from auth."""
        # We test with UDP, as it can be "connected" without other
        # endpoint.  Note that in the current implementation _guess_remote()
        # unconditionally returns SOCK_STREAM.
        self._check_guess_remote(socket.AF_INET, ('127.0.0.1', 12345),
                                 ('127.0.0.1', 12345))
        if socket.has_ipv6:
            # Don't check IPv6 address on hosts not supporting them
            self._check_guess_remote(socket.AF_INET6, ('::1', 12345),
                                     ('::1', 12345, 0, 0))
            # Try when pretending there's no IPv6 support
            # (No need to pretend when there's really no IPv6)
            xfrout.socket.has_ipv6 = False
            try:
                self._check_guess_remote(socket.AF_INET,
                                         ('127.0.0.1', 12345),
                                         ('127.0.0.1', 12345))
            finally:
                # Return it back, even if the check failed, so that later
                # tests see the real value.
                xfrout.socket.has_ipv6 = True

    def test_receive_query_message(self):
        send_msg = b"\xd6=\x00\x00\x00\x01\x00"
        msg_len = struct.pack('H', socket.htons(len(send_msg)))
        self.write_sock.send(msg_len)
        self.write_sock.send(send_msg)
        recv_msg = self.unix._receive_query_message(self.read_sock)
        self.assertEqual(recv_msg, send_msg)

    def _udp_request_context(self, address):
        '''Build an ACL request context for a UDP request from address:1234.'''
        return isc.acl.dns.RequestContext(socket.getaddrinfo(
            address, 1234, 0, socket.SOCK_DGRAM, socket.IPPROTO_UDP,
            socket.AI_NUMERICHOST)[0][4])

    def check_default_ACL(self):
        context = self._udp_request_context("127.0.0.1")
        self.assertEqual(isc.acl.acl.ACCEPT, self.unix._acl.execute(context))

    def check_loaded_ACL(self, acl):
        context = self._udp_request_context("127.0.0.1")
        self.assertEqual(isc.acl.acl.ACCEPT, acl.execute(context))
        context = self._udp_request_context("192.0.2.1")
        self.assertEqual(isc.acl.acl.REJECT, acl.execute(context))

    def test_update_config_data(self):
        self.check_default_ACL()
        self.unix.update_config_data({'transfers_out':10 })
        self.assertEqual(self.unix._max_transfers_out, 10)
        self.check_default_ACL()

        self.unix.update_config_data({'transfers_out':9})
        self.assertEqual(self.unix._max_transfers_out, 9)

        # Load the ACL
        self.unix.update_config_data({'transfer_acl': [{'from': '127.0.0.1',
                                               'action': 'ACCEPT'}]})
        self.check_loaded_ACL(self.unix._acl)
        # Pass a wrong data there and check it does not replace the old one
        self.assertRaises(XfroutConfigError,
                          self.unix.update_config_data,
                          {'transfer_acl': ['Something bad']})
        self.check_loaded_ACL(self.unix._acl)

    def test_zone_config_data(self):
        # By default, there's no specific zone config
        self.assertEqual({}, self.unix._zone_config)

        # Adding config for a specific zone.  The config is empty unless
        # explicitly specified.
        self.unix.update_config_data({'zone_config':
                                          [{'origin': 'example.com',
                                            'class': 'IN'}]})
        self.assertEqual({}, self.unix._zone_config[('IN', 'example.com.')])

        # zone class can be omitted
        self.unix.update_config_data({'zone_config':
                                          [{'origin': 'example.com'}]})
        self.assertEqual({}, self.unix._zone_config[('IN', 'example.com.')])

        # zone class, name are stored in the "normalized" form.  class
        # strings are upper cased, names are down cased.
        self.unix.update_config_data({'zone_config':
                                          [{'origin': 'EXAMPLE.com'}]})
        self.assertEqual({}, self.unix._zone_config[('IN', 'example.com.')])

        # invalid zone class, name will result in exceptions
        self.assertRaises(EmptyLabel,
                          self.unix.update_config_data,
                          {'zone_config': [{'origin': 'bad..example'}]})
        self.assertRaises(InvalidRRClass,
                          self.unix.update_config_data,
                          {'zone_config': [{'origin': 'example.com',
                                            'class': 'badclass'}]})

        # Configuring a couple of more zones
        self.unix.update_config_data({'zone_config':
                                          [{'origin': 'example.com'},
                                           {'origin': 'example.com',
                                            'class': 'CH'},
                                           {'origin': 'example.org'}]})
        self.assertEqual({}, self.unix._zone_config[('IN', 'example.com.')])
        self.assertEqual({}, self.unix._zone_config[('CH', 'example.com.')])
        self.assertEqual({}, self.unix._zone_config[('IN', 'example.org.')])

        # Duplicate data: should be rejected with an exception
        self.assertRaises(XfroutConfigError,
                          self.unix.update_config_data,
                          {'zone_config': [{'origin': 'example.com'},
                                           {'origin': 'example.org'},
                                           {'origin': 'example.com'}]})

    def test_zone_config_data_with_acl(self):
        # Similar to the previous test, but with transfer_acl config
        self.unix.update_config_data({'zone_config':
                                          [{'origin': 'example.com',
                                            'transfer_acl':
                                                [{'from': '127.0.0.1',
                                                  'action': 'ACCEPT'}]}]})
        acl = self.unix._zone_config[('IN', 'example.com.')]['transfer_acl']
        self.check_loaded_ACL(acl)

        # invalid ACL syntax will be rejected with exception
        self.assertRaises(XfroutConfigError,
                          self.unix.update_config_data,
                          {'zone_config': [{'origin': 'example.com',
                                            'transfer_acl':
                                                [{'action': 'BADACTION'}]}]})

    def test_get_db_file(self):
        self.assertEqual(self.unix.get_db_file(), "initdb.file")

    def test_increase_transfers_counter(self):
        self.unix._max_transfers_out = 10
        count = self.unix._transfers_counter
        self.assertEqual(self.unix.increase_transfers_counter(), True)
        self.assertEqual(count + 1, self.unix._transfers_counter)

        self.unix._max_transfers_out = 0
        count = self.unix._transfers_counter
        self.assertEqual(self.unix.increase_transfers_counter(), False)
        self.assertEqual(count, self.unix._transfers_counter)

    def test_decrease_transfers_counter(self):
        count = self.unix._transfers_counter
        self.unix.decrease_transfers_counter()
        self.assertEqual(count - 1, self.unix._transfers_counter)

    def _remove_file(self, sock_file):
        '''Remove sock_file, ignoring errors such as it not existing.'''
        try:
            os.remove(sock_file)
        except OSError:
            pass

    def _call_silently(self, func, *args):
        '''Call func(*args) with sys.stdout redirected to os.devnull.

        sys.stdout is restored and the devnull file is closed even if func
        raises (including SystemExit).  Returns whatever func returns.
        '''
        devnull = open(os.devnull, 'w')
        old_stdout = sys.stdout
        sys.stdout = devnull
        try:
            return func(*args)
        finally:
            sys.stdout = old_stdout
            devnull.close()

    def test_sock_file_in_use_file_exist(self):
        # Note: the file is removed first, so this actually checks a
        # nonexistent file is not reported as in use and is not created.
        sock_file = 'temp.sock.file'
        self._remove_file(sock_file)
        self.assertFalse(self.unix._sock_file_in_use(sock_file))
        self.assertFalse(os.path.exists(sock_file))

    def test_sock_file_in_use_file_not_exist(self):
        self.assertFalse(self.unix._sock_file_in_use('temp.sock.file'))

    def _start_unix_sock_server(self, sock_file):
        '''Start a UNIX stream server on sock_file in a daemon thread.

        The thread is a daemon so it does not keep the process alive.
        Returns the server object.
        '''
        serv = ThreadingUnixStreamServer(sock_file, BaseRequestHandler)
        serv_thread = threading.Thread(target=serv.serve_forever)
        # 'daemon' attribute instead of the deprecated setDaemon()
        serv_thread.daemon = True
        serv_thread.start()
        return serv

    def test_sock_file_in_use(self):
        sock_file = 'temp.sock.file'
        self._remove_file(sock_file)
        self.assertFalse(self.unix._sock_file_in_use(sock_file))
        self._start_unix_sock_server(sock_file)
        # Don't leave the socket file behind for later tests.
        self.addCleanup(self._remove_file, sock_file)

        self.assertTrue(self._call_silently(self.unix._sock_file_in_use,
                                            sock_file))

    def test_remove_unused_sock_file_in_use(self):
        sock_file = 'temp.sock.file'
        self._remove_file(sock_file)
        self.assertFalse(self.unix._sock_file_in_use(sock_file))
        self._start_unix_sock_server(sock_file)
        self.addCleanup(self._remove_file, sock_file)
        # The file is in use by the server started above, so the call is
        # expected to exit instead of removing it.
        with self.assertRaises(SystemExit):
            self._call_silently(self.unix._remove_unused_sock_file, sock_file)

    def test_remove_unused_sock_file_dir(self):
        import tempfile
        dir_name = tempfile.mkdtemp()
        try:
            # A directory given as the socket file is expected to make the
            # call exit.
            with self.assertRaises(SystemExit):
                self._call_silently(self.unix._remove_unused_sock_file,
                                    dir_name)
        finally:
            os.rmdir(dir_name)
  
class TestInitialization(unittest.TestCase):  
    def setEnv(self, name, value):  
        if value is None:  
            if name in os.environ:  
                del os.environ[name]  
        else:  
            os.environ[name] = value  
  
    def setUp(self):  
        self._oldSocket = os.getenv("BIND10_XFROUT_SOCKET_FILE")  
        self._oldFromBuild = os.getenv("B10_FROM_BUILD")  
  
    def tearDown(self):  
        self.setEnv("B10_FROM_BUILD", self._oldFromBuild)  
        self.setEnv("BIND10_XFROUT_SOCKET_FILE", self._oldSocket)  
        # Make sure even the computed values are back  
        xfrout.init_paths()  
  
    def testNoEnv(self):  
        self.setEnv("B10_FROM_BUILD", None)  
        self.setEnv("BIND10_XFROUT_SOCKET_FILE", None)  
        xfrout.init_paths()  
        self.assertEqual(xfrout.UNIX_SOCKET_FILE,  
                         "@@LOCALSTATEDIR@@/bind10-devel/auth_xfrout_conn")  
  
    def testProvidedSocket(self):  
        self.setEnv("B10_FROM_BUILD", None)  
        self.setEnv("BIND10_XFROUT_SOCKET_FILE", "The/Socket/File")  
        xfrout.init_paths()  
        self.assertEqual(xfrout.UNIX_SOCKET_FILE, "The/Socket/File")  
  
class MyNotifier:
    '''Fake notifier that only records whether shutdown() was invoked.'''

    def __init__(self):
        # Flips to True on the first shutdown() call.
        self.shutdown_called = False

    def shutdown(self):
        '''Note that a shutdown was requested; performs no other action.'''
        self.shutdown_called = True
  
class MyXfroutServer(XfroutServer):
    '''XfroutServer test double built without running the real constructor.

    Its attributes are filled with mocks or inert placeholders. Presumably
    these are the ones XfroutServer.shutdown() touches — verify against
    XfroutServer if more are needed.
    '''
    def __init__(self):
        # XfroutServer.__init__() is intentionally not called; only fakes
        # are installed here.
        self._cc = MockModuleCCSession()
        self._shutdown_event = threading.Event()
        self._notifier = MyNotifier()
        self._unix_socket_server = None

        def _skip_thread_wait():
            return None

        # Replace the wait for worker threads with a no-op.
        self._wait_for_threads = _skip_thread_wait
  
class TestXfroutServer(unittest.TestCase):
    '''Exercise XfroutServer.shutdown() through the MyXfroutServer double.'''

    def setUp(self):
        self.xfrout_server = MyXfroutServer()

    def test_shutdown(self):
        server = self.xfrout_server
        server.shutdown()
        # shutdown() is expected to reach both the notifier and the CC session.
        self.assertTrue(server._notifier.shutdown_called)
        self.assertTrue(server._cc.stopped)
  
# Script entry point. The original 'exitif' was a syntax error that made
# the whole module unimportable.
if __name__ == "__main__":
    # Presumably puts isc.log into its unit-test logging configuration.
    isc.log.resetUnitTestRootLogger()
    unittest.main()
                
             |