# Twisted, the Framework of Your Internet # Copyright (C) 2001 Matthew W. Lefkowitz # # This library is free software; you can redistribute it and/or # modify it under the terms of version 2.1 of the GNU Lesser General Public # License as published by the Free Software Foundation. # # This library is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU # Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public # License along with this library; if not, write to the Free Software # Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA """Test methods in twisted.internet.threads and reactor thread APIs.""" from twisted.trial import unittest from twisted.internet import threads, reactor from twisted.python import threadable, failure import time # make sure thread pool is shutdown import atexit #atexit.register(reactor.suggestThreadPoolSize, 0) class Counter: index = 0 problem = 0 def sync_add(self): """A thread-safe method.""" self.add() def add(self): """A none thread-safe method.""" next = self.index + 1 # another thread could jump in here and increment self.index on us if next != self.index + 1: self.problem = 1 raise ValueError # or here, same issue but we wouldn't catch it. We'd overwrite their # results, and the index will have lost a count. If several threads # get in here, we will actually make the count go backwards when we # overwrite it. self.index = next synchronized = ["sync_add"] threadable.synchronize(Counter) class ThreadsTestCase(unittest.TestCase): """Test twisted.internet.threads.""" def testCallInThread(self): c = Counter() for i in range(1000): reactor.callInThread(c.sync_add) # those thousand calls should all be running "at the same time" (at # least up to the size of the thread pool, anyway). While they are # fighting over the right to add, watch carefully for any sign of # overlapping threads, which might be detected by the thread itself # (c.problem), or might appear as an index that doesn't go all the # way up to 1000 (a few missing counts and it could stop at 999). If # threadpools are really broken, it might never increment at all. # in practice, if the threads are running unsynchronized (say, using # c.add instead of c.sync_add), it takes about 10 repetitions of # this test case to expose a problem when = time.time() oldIndex = 0 while c.index < 1000: # watch for the count to go backwards assert oldIndex <= c.index # spend some extra time per loop making sure we eventually get # out of it self.failIf(c.problem, "threads reported overlap") if c.index > oldIndex: when = time.time() # reset each count else: if time.time() > when + 5: if c.index > 0: self.fail("threads lost a count") else: self.fail("threads never started") oldIndex = c.index # This check will never fail, because a missing count wouldn't make # it out of the 'while c.index < 1000' loop. But it makes it clear # what our expectation are. self.assertEquals(c.index, 1000, "threads lost a count") def testCallMultiple(self): c = Counter() # we call the non-thread safe method, but because they should # all be called in same thread, this is ok. commands = [(c.add, (), {})] * 1000 threads.callMultipleInThread(commands) when = time.time() oldIndex = 0 while c.index < 1000: assert oldIndex <= c.index self.failIf(c.problem, "threads reported overlap") if c.index > oldIndex: when = time.time() # reset each count else: if time.time() > when + 5: if c.index > 0: self.fail("threads lost a count") else: self.fail("threads never started") oldIndex = c.index self.assertEquals(c.index, 1000) def testSuggestThreadPoolSize(self): reactor.suggestThreadPoolSize(34) reactor.suggestThreadPoolSize(4) gotResult = 0 def testDeferredResult(self): d = threads.deferToThread(lambda x, y=5: x + y, 3, y=4) d.addCallback(self._resultCallback) while not self.gotResult: reactor.iterate() self.gotResult = 0 def _resultCallback(self, result): self.assertEquals(result, 7) self.gotResult = 1 def testDeferredFailure(self): def raiseError(): raise TypeError d = threads.deferToThread(raiseError) d.addErrback(self._resultErrback) while not self.gotResult: reactor.iterate() self.gotResult = 0 def _resultErrback(self, error): self.assert_( isinstance(error, failure.Failure) ) self.assertEquals(error.type, TypeError) self.gotResult = 1