# Twisted, the Framework of Your Internet
# Copyright (C) 2001 Matthew W. Lefkowitz
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of version 2.1 of the GNU Lesser General Public
# License as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

"""
Test cases for the twisted.protocols.smtp module.
"""

from twisted.trial import unittest
import twisted.internet.protocol
import twisted.protocols.smtp
from twisted import protocols
from twisted import internet
from twisted.protocols import basic, loopback, smtp
from twisted.internet import defer, protocol
from twisted.test.test_protocols import StringIOWithoutClosing
import string
import re
from cStringIO import StringIO


class DummyMessage:
    """Message sink that files delivered text into a DummyDomain mailbox.

    Lines are buffered as they arrive; on end-of-message the joined text is
    appended to C{domain.messages[user.dest.local]}.
    """

    # The server prepends a Received: header for the "HELO yyy.com" client;
    # its bracketed part varies between runs, so it is dropped before the
    # stored message is compared against a fixed expectation.
    receivedHeader = re.compile(r'Received: From yyy\.com \(\[.*\]\) by localhost;')

    def __init__(self, domain, user):
        self.domain = domain
        self.user = user
        self.buffer = []

    def lineReceived(self, line):
        """Buffer one message line, skipping the generated Received: header."""
        if not self.receivedHeader.match(line):
            self.buffer.append(line)

    def eomReceived(self):
        """Store the buffered message and return an already-fired Deferred."""
        message = '\n'.join(self.buffer) + '\n'
        self.domain.messages[self.user.dest.local].append(message)
        deferred = defer.Deferred()
        deferred.callback("saved")
        return deferred


class DummyDomain:
    """In-memory mail domain accepting only a fixed set of local names.

    C{messages} maps each accepted local name to the list of messages
    delivered to it.
    """

    def __init__(self, names):
        self.messages = {}
        for name in names:
            self.messages[name] = []

    def exists(self, user, success, failure):
        """Call C{success(user)} if the local part is known, else C{failure(user)}."""
        if user.dest.local in self.messages:
            success(user)
        else:
            failure(user)

    def startMessage(self, user):
        """Return a DummyMessage that will deliver into this domain."""
        return DummyMessage(self, user)


class SMTPTestCase(unittest.TestCase):
    """Drive a DomainSMTP server line by line and check the stored mailboxes."""

    # (sender, recipients, raw DATA payload with CRLF line endings)
    messages = [('foo@bar.com', ['foo@baz.com', 'qux@baz.com'], '''\
Subject: urgent\015
\015
Someone set up us the bomb!\015
''')]

    # Only 'foo' exists in baz.com, so the mail for qux@baz.com is rejected
    # and just one copy (with LF line endings) is expected to be stored.
    mbox = {'foo': ['Subject: urgent\n\nSomeone set up us the bomb!\n']}

    def setUp(self):
        self.factory = smtp.SMTPFactory()
        self.factory.domains = {}
        self.factory.domains['baz.com'] = DummyDomain(['foo'])
        self.output = StringIOWithoutClosing()
        self.transport = internet.protocol.FileWrapper(self.output)

    def testMessages(self):
        # Aliased so the module-level `protocols` / `protocol` names stay
        # unshadowed inside this method.
        from twisted.mail import protocols as mailProtocols
        proto = mailProtocols.DomainSMTP()
        proto.service = self.factory
        proto.factory = self.factory
        proto.makeConnection(self.transport)
        try:
            proto.lineReceived('HELO yyy.com')
            for sender, recipients, body in self.messages:
                proto.lineReceived('MAIL FROM:<%s>' % sender)
                for target in recipients:
                    proto.lineReceived('RCPT TO:<%s>' % target)
                proto.lineReceived('DATA')
                proto.dataReceived(body)
                proto.lineReceived('.')
            proto.lineReceived('QUIT')
            self.assertEquals(self.factory.domains['baz.com'].messages,
                              self.mbox)
        finally:
            # Always cancel the server's timeout, even when an assertion
            # fails, so no delayed call is left pending.
            proto.timeoutID.cancel()


mail = '''\
Subject: hello

Goodbye
'''


class MySMTPClient(protocols.smtp.SMTPClient):
    """SMTP client that sends the module-level C{mail} once.

    Both the envelope sender and the only recipient are moshez@foo.bar.
    """

    def __init__(self):
        protocols.smtp.SMTPClient.__init__(self, 'foo.baz')
        self.mail = 'moshez@foo.bar', ['moshez@foo.bar'], mail

    def lineReceived(self, line):
        protocols.smtp.SMTPClient.lineReceived(self, line)

    def getMailFrom(self):
        return self.mail[0]

    def getMailTo(self):
        return self.mail[1]

    def getMailData(self):
        return StringIO(self.mail[2])

    def sentMail(self, addresses):
        # Clear the pending mail so it is not sent a second time.
        self.mail = None, None, None


class LoopbackSMTPTestCase(unittest.TestCase):
    """Run MySMTPClient against a DomainSMTP server over a loopback TCP pair."""

    def loopback(self, server, client):
        loopback.loopbackTCP(server, client)

    def testMessages(self):
        factory = smtp.SMTPFactory()
        factory.domains = {}
        factory.domains['foo.bar'] = DummyDomain(['moshez'])
        from twisted.mail.protocols import DomainSMTP
        proto = DomainSMTP()
        proto.service = factory
        proto.factory = factory
        clientProtocol = MySMTPClient()
        try:
            # NOTE(review): this test only checks that the exchange completes
            # without raising; it makes no assertion on the delivered mail.
            self.loopback(proto, clientProtocol)
        finally:
            proto.timeoutID.cancel()


class FakeSMTPServer(basic.LineReceiver):
    """Canned SMTP server that records every line the client sends.

    All greeting and command replies in C{clientData} are written on
    connect. '.' gets a 250 reply; QUIT gets 221 followed by a disconnect.
    """

    clientData = '''\
220 hello
250 nice to meet you
250 great
250 great
354 go on, lad
'''

    def connectionMade(self):
        self.buffer = ''
        for line in self.clientData.split('\n'):
            self.transport.write(line + '\r\n')

    def lineReceived(self, line):
        self.buffer = self.buffer + line + '\r\n'
        if line == "QUIT":
            self.transport.write("221 see ya around\r\n")
            self.transport.loseConnection()
        if line == ".":
            self.transport.write("250 gotcha\r\n")


class SMTPClientTestCase(unittest.TestCase):
    """Check the exact command stream produced by MySMTPClient."""

    # NOTE(review): the MAIL FROM / RCPT TO lines carry no address here,
    # although MySMTPClient sends moshez@foo.bar; verify this expectation
    # before re-enabling the test.
    expected_output = '''\
HELO foo.baz\r
MAIL FROM:\r
RCPT TO:\r
DATA\r
Subject: hello\r
\r
Goodbye\r
.\r
QUIT\r
'''

    def xxxtestMessages(self):
        # this test is disabled temporarily
        client = MySMTPClient()
        server = FakeSMTPServer()
        loopback.loopbackTCP(server, client)
        self.assertEquals(server.buffer, self.expected_output)


class DummySMTPMessage:
    """Message sink that records the finished message on its DummySMTP protocol.

    On end-of-message it sets C{protocol.message} to
    C{(helo, origin, recipients, text)}.
    """

    # Generated header for the "HELO foo.com" client; its bracketed part
    # varies between runs, so it is excluded from the recorded text.
    receivedHeader = re.compile(r'Received: From foo\.com \(\[.*\]\) by foo\.com;')

    def __init__(self, protocol, users):
        self.protocol = protocol
        self.users = users
        self.buffer = []

    def lineReceived(self, line):
        """Buffer one message line, skipping the generated Received: header."""
        if not self.receivedHeader.match(line):
            self.buffer.append(line)

    def eomReceived(self):
        """Publish the message tuple and return an already-fired Deferred."""
        message = '\n'.join(self.buffer) + '\n'
        helo, origin = self.users[0].helo[0], str(self.users[0].orig)
        recipients = [str(user) for user in self.users]
        self.protocol.message = (helo, origin, recipients, message)
        deferred = defer.Deferred()
        deferred.callback("saved")
        return deferred


class DummySMTP(smtp.SMTP):
    """SMTP server whose last delivered message is exposed as C{self.message}."""

    def connectionMade(self):
        smtp.SMTP.connectionMade(self)
        self.message = None

    def startMessage(self, users):
        return [DummySMTPMessage(self, users)]


def _buildDialogue(messages):
    """Build the (send, expect, msg, msgexpect) script for AnotherSMTPTestCase.

    Each entry in C{messages} is (helo, from, to-list, expected-from,
    expected-to-list, body). C{expect} and the first element of C{msgexpect}
    are regexes matched against the server output with C{re.match}.
    """
    dialogue = [
        ('', '220.*\r\n$', None, None),
        ('HELO foo.com\r\n', '250.*\r\n$', None, None),
        ('RSET\r\n', '250.*\r\n$', None, None),
    ]
    for helo, sender, recipients, realFrom, realTo, msg in messages:
        dialogue.append(('MAIL FROM:<%s>\r\n' % sender, '250.*\r\n', None, None))
        for rcpt in recipients:
            dialogue.append(('RCPT TO:<%s>\r\n' % rcpt, '250.*\r\n', None, None))
        dialogue.append(('DATA\r\n', '354.*\r\n', msg,
                         ('250.*\r\n', (helo, realFrom, realTo, msg))))
    return dialogue


class AnotherSMTPTestCase(unittest.TestCase):
    """Feed raw bytes to DummySMTP and check replies and delivered messages."""

    messages = [
        ('foo.com', 'moshez@foo.com', ['moshez@bar.com'],
         'moshez@foo.com', ['moshez@bar.com'], '''\
From: Moshe
To: Moshe

Hi,
how are you?
'''),
        ('foo.com', 'tttt@rrr.com', ['uuu@ooo', 'yyy@eee'],
         'tttt@rrr.com', ['uuu@ooo', 'yyy@eee'], '''\
Subject: pass

..rrrr..
'''),
        # Source routes in the envelope addresses are expected to be dropped.
        ('foo.com', '@this,@is,@ignored:foo@bar.com',
         ['@ignore,@this,@too:bar@foo.com'],
         'foo@bar.com', ['bar@foo.com'], '''\
Subject: apa
To: foo

123
.
456
'''),
    ]

    data = _buildDialogue(messages)

    def testBuffer(self):
        output = StringIOWithoutClosing()
        a = DummySMTP()

        class fooFactory:
            domain = 'foo.com'

        a.factory = fooFactory()
        a.makeConnection(protocol.FileWrapper(output))
        try:
            for (send, expect, msg, msgexpect) in self.data:
                if send:
                    a.dataReceived(send)
                data = output.getvalue()
                output.truncate(0)
                if not re.match(expect, data):
                    raise AssertionError((send, expect, data))
                if data[:3] == '354':
                    for line in msg.splitlines():
                        # Dot-stuff lines that start with '.' so they are
                        # not taken as the end-of-data marker.
                        if line and line[0] == '.':
                            line = '.' + line
                        a.dataReceived(line + '\r\n')
                    a.dataReceived('.\r\n')
                    # Special case for DATA. Now we want a 250, and then
                    # we compare the messages.
                    data = output.getvalue()
                    # truncate(0), not truncate(): with no argument the
                    # buffer is cut at the current (end) position and the
                    # 250 reply would leak into the next step's output.
                    output.truncate(0)
                    resp, msgdata = msgexpect
                    if not re.match(resp, data):
                        raise AssertionError((resp, data))
                    self.assertEquals(a.message, msgdata)
        finally:
            a.timeoutID.cancel()