# Twisted, the Framework of Your Internet
# Copyright (C) 2001-2002 Matthew W. Lefkowitz
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of version 2.1 of the GNU Lesser General Public
# License as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
#

"""Utility methods."""

from twisted.internet import protocol, reactor, defer
from twisted.python import failure

import cStringIO


def _callProtocolWithDeferred(protocol, executable, args, env, path, reactor):
    """Spawn C{executable} under a protocol built around a fresh Deferred.

    @param protocol: a callable taking a Deferred and returning a
        ProcessProtocol instance (this parameter shadows the C{protocol}
        module inside this function).
    @param executable: the program to run; it is also passed as argv[0].
    @param args: further command-line arguments (argv[1:]).
    @param env: environment mapping handed to C{reactor.spawnProcess}.
    @param path: working directory handed to C{reactor.spawnProcess}.
    @param reactor: the reactor whose C{spawnProcess} is used.
    @return: the Deferred given to C{protocol}; firing it is the
        protocol's job.
    """
    d = defer.Deferred()
    p = protocol(d)
    reactor.spawnProcess(p, executable, (executable,) + tuple(args), env, path)
    return d


class _BackRelay(protocol.ProcessProtocol):
    """Accumulate a child's stdout; fail the Deferred on any stderr output."""

    def __init__(self, deferred):
        self.deferred = deferred
        self.s = cStringIO.StringIO()

    def errReceived(self, text):
        # stderr can arrive in several chunks.  Only the first may fire the
        # Deferred: afterwards self.deferred is None (and a Deferred cannot
        # be fired twice anyway).
        if self.deferred is None:
            return
        d, self.deferred = self.deferred, None
        d.errback(failure.Failure(IOError("got stderr")))
        self.transport.loseConnection()

    def outReceived(self, text):
        self.s.write(text)

    def processEnded(self, reason):
        # self.deferred is None if errReceived already errbacked it.
        if self.deferred is not None:
            self.deferred.callback(self.s.getvalue())


def getProcessOutput(executable, args=(), env={}, path='.', reactor=reactor):
    """Spawn a process and return its output as a deferred returning a string.

    @param executable: The file name to run and get the output of - the
                       full path should be used.

    @param args: the command line arguments to pass to the process; a
                 sequence of strings. The executable's name is prepended
                 automatically as argv[0], so it should not be included.

    @param env: the environment variables to pass to the process; a
                dictionary of strings.  (The default dict is only passed
                through, never modified here.)

    @param path: the path to run the subprocess in - defaults to the
                 current directory.

    @param reactor: the reactor to use - defaults to the default reactor

    @return: a Deferred that fires with everything the process wrote to
             stdout once it ends, or errbacks with a Failure wrapping
             C{IOError("got stderr")} if the process writes to stderr.
    """
    return _callProtocolWithDeferred(_BackRelay, executable, args, env, path,
                                     reactor)


class _ValueGetter(protocol.ProcessProtocol):
    """Fire the Deferred with the child's exit code when it ends."""

    def __init__(self, deferred):
        self.deferred = deferred

    def processEnded(self, reason):
        self.deferred.callback(reason.value.exitCode)


def getProcessValue(executable, args=(), env={}, path='.', reactor=reactor):
    """Spawn a process and return its exit code as a Deferred.

    Parameters are as for L{getProcessOutput}; the process's output is
    ignored.
    """
    return _callProtocolWithDeferred(_ValueGetter, executable, args, env, path,
                                     reactor)


# Imports used by the SRV connector below (kept at their original position
# in the module so import order is unchanged).
import random
from twisted.names import client
from twisted.internet import error, interfaces


class _SRVConnector_ClientFactoryWrapper:
    """Factory proxy that reports connection events to the SRVConnector.

    Events are re-targeted so that the wrapped factory (and the
    SRVConnector's failure/loss handling) see the SRVConnector rather than
    the per-server connector; every other attribute is delegated to the
    wrapped factory.
    """

    def __init__(self, connector, wrappedFactory):
        self.__connector = connector
        self.__wrappedFactory = wrappedFactory

    def startedConnecting(self, connector):
        self.__wrappedFactory.startedConnecting(self.__connector)

    def clientConnectionFailed(self, connector, reason):
        self.__connector.connectionFailed(reason)

    def clientConnectionLost(self, connector, reason):
        self.__connector.connectionLost(reason)

    def __getattr__(self, key):
        return getattr(self.__wrappedFactory, key)


class SRVConnector:
    """A connector that looks up DNS SRV records. See RFC2782."""

    __implements__ = interfaces.IConnector

    # Set by stopConnecting() while the DNS lookup is still pending, so that
    # _reallyConnect() skips the actual connection attempt.
    stopAfterDNS = 0

    def __init__(self, reactor, service, domain, factory,
                 protocol='tcp', connectFuncName='connectTCP',
                 connectFuncArgs=(),
                 connectFuncKwArgs={},
                 ):
        """
        @param reactor: reactor providing the C{connectFuncName} method.
        @param service: SRV service name (queried as _service._protocol.domain).
        @param domain: domain to query; C{None} makes connect() fail.
        @param factory: client factory for the eventual connection.
        @param protocol: SRV protocol label used in the query name.
        @param connectFuncName: name of the reactor method used to connect.
        @param connectFuncArgs: extra positional args for that method.
        @param connectFuncKwArgs: extra keyword args for that method (the
            default dict is only passed through, never modified here).
        """
        self.reactor = reactor
        self.service = service
        self.domain = domain
        self.factory = factory

        self.protocol = protocol
        self.connectFuncName = connectFuncName
        self.connectFuncArgs = connectFuncArgs
        self.connectFuncKwArgs = connectFuncKwArgs

        self.connector = None
        self.servers = None
        self.orderedServers = None  # list of servers already used in this round

    def connect(self):
        """Start connection to remote server."""
        self.factory.doStart()
        self.factory.startedConnecting(self)

        if not self.servers:
            if self.domain is None:
                self.connectionFailed(error.DNSLookupError("Domain is not defined."))
                return
            d = client.theResolver.lookupService('_%s._%s.%s' % (self.service,
                                                                 self.protocol,
                                                                 self.domain))
            d.addCallback(self._cbGotServers)
            d.addCallback(lambda x, self=self: self._reallyConnect())
            # Covers lookup errors, the "service not available" case raised
            # by _cbGotServers, and errors raised by _reallyConnect.
            d.addErrback(self.connectionFailed)
        elif self.connector is None:
            self._reallyConnect()
        else:
            self.connector.connect()

    def _cbGotServers(self, result):
        """Record the SRV answers as this round's candidate servers.

        @param result: the (answers, authority, additional) tuple produced
            by the resolver's lookupService.
        @raise error.DNSLookupError: if the single answer's target is '.',
            i.e. the domain explicitly declares the service unavailable.
        """
        answers, auth, add = result
        if len(answers) == 1 and answers[0].payload.target == '.':
            # decidedly not available
            raise error.DNSLookupError("Service %s not available for domain %s."
                                       % (repr(self.service), repr(self.domain)))

        self.servers = []
        self.orderedServers = []
        for a in answers:
            self.orderedServers.append((a.payload.priority, a.payload.weight,
                                        str(a.payload.target), a.payload.port))

    def _serverCmp(self, a, b):
        """Order server tuples by priority, then by weight (both ascending)."""
        if a[0] != b[0]:
            return cmp(a[0], b[0])
        else:
            return cmp(a[1], b[1])

    def pickServer(self):
        """Choose the next (host, port) to try, following RFC 2782.

        The candidates with the lowest priority value are considered; one of
        them is chosen at random, weighted by its weight field.  The chosen
        server moves to self.orderedServers so each server is tried once per
        round; when self.servers runs out a new round starts from
        self.orderedServers.

        If the lookup yielded no SRV records, fall back to
        (self.domain, self.service).
        """
        assert self.servers is not None
        assert self.orderedServers is not None

        if not self.servers and not self.orderedServers:
            # no SRV record, fall back..
            return self.domain, self.service

        if not self.servers and self.orderedServers:
            # start new round
            self.servers = self.orderedServers
            self.orderedServers = []

        assert self.servers

        # After sorting by (priority, weight) the best-priority group is at
        # the front, with its zero-weight entries first as RFC 2782 asks.
        self.servers.sort(self._serverCmp)
        minPriority = self.servers[0][0]

        # Weights of that leading group; thanks to the sort, weights[i]
        # belongs to self.servers[i].
        weights = [x[1] for x in self.servers if x[0] == minPriority]
        weightSum = reduce(lambda a, b: a + b, weights, 0)

        # Select the first server whose running weight sum reaches a uniform
        # random number in [0, weightSum] (RFC 2782 selection algorithm).
        rand = random.randint(0, weightSum)
        for index, weight in zip(xrange(len(weights)), weights):
            rand -= weight
            if rand <= 0:
                chosen = self.servers[index]
                del self.servers[index]
                self.orderedServers.append(chosen)

                p, w, host, port = chosen
                return host, port

        raise RuntimeError('Impossible %s pickServer result.' % self.__class__.__name__)

    def _reallyConnect(self):
        """Connect to the server chosen by pickServer(), unless stopped."""
        if self.stopAfterDNS:
            self.stopAfterDNS = 0
            return

        self.host, self.port = self.pickServer()
        assert self.host is not None, 'Must have a host to connect to.'
        assert self.port is not None, 'Must have a port to connect to.'

        connectFunc = getattr(self.reactor, self.connectFuncName)
        self.connector = connectFunc(
            self.host, self.port,
            _SRVConnector_ClientFactoryWrapper(self, self.factory),
            *self.connectFuncArgs, **self.connectFuncKwArgs)

    def stopConnecting(self):
        """Stop attempting to connect."""
        if self.connector:
            self.connector.stopConnecting()
        else:
            # Still resolving: make _reallyConnect() bail out instead.
            self.stopAfterDNS = 1

    def disconnect(self):
        """Disconnect, whatever our state is."""
        if self.connector is not None:
            self.connector.disconnect()
        else:
            self.stopConnecting()

    def getDestination(self):
        """Return the destination of the underlying per-server connector."""
        assert self.connector
        return self.connector.getDestination()

    def connectionFailed(self, reason):
        self.factory.clientConnectionFailed(self, reason)
        self.factory.doStop()

    def connectionLost(self, reason):
        self.factory.clientConnectionLost(self, reason)
        self.factory.doStop()