"""Servlet factory template."""
import os
import sys
import threading
from keyword import iskeyword
from MiscUtils import AbstractError
from Servlet import Servlet
debug = False  # when True, ServletFactory._importModuleFromDirectory prints every module it imports
class ServletFactory(object):
    """Servlet factory template.

    ServletFactory is an abstract class that defines the protocol for
    all servlet factories.

    Servlet factories are used by the Application to create servlets
    for transactions.

    A factory must inherit from this class and override uniqueness(),
    extensions() and either loadClass() or servletForTransaction().
    Do not invoke the base class methods as they all raise AbstractErrors.

    Each method is documented below.
    """

    def __init__(self, application):
        """Create servlet factory.

        Stores a reference to the application in self._app, because
        subclasses may or may not need to talk back to the application
        to do their work.
        """
        self._app = application
        # Import helper (provides find_module()/load_module()), shared with
        # the application.
        self._imp = self._app._imp
        self._cacheClasses = self._app.setting("CacheServletClasses", True)
        self._cacheInstances = self._app.setting("CacheServletInstances", True)
        self._reloadClasses = self._app.setting("ReloadServletClasses", True)
        # path -> {'mtime': file modification time, 'class': servlet class}
        self._classCache = {}
        # path -> list of idle, reusable but non-threadsafe servlet instances
        self._servletPool = {}
        # path -> the one shared instance of a reusable, threadsafe servlet
        self._threadsafeServletCache = {}
        # Serializes class loading and cache flushing.
        self._importLock = threading.RLock()

    def name(self):
        """Return the name of the factory.

        This is a convenience for the class name.
        """
        return self.__class__.__name__

    def uniqueness(self):
        """Return uniqueness type.

        Returns a string to indicate the uniqueness of the ServletFactory's
        servlets. The Application needs to know if the servlets are unique
        per file, per extension or per application.

        Return values are 'file', 'extension' and 'application'.

        NOTE: Application only supports 'file' uniqueness at this point in time.
        """
        raise AbstractError(self.__class__)

    def extensions(self):
        """Return a list of extensions that match this handler.

        Extensions should include the dot. An empty string indicates a file
        with no extension and is a valid value. The extension '.*' is a special
        case that is used when a URL's extension doesn't match anything.
        """
        raise AbstractError(self.__class__)

    def importAsPackage(self, transaction, serverSidePathToImport):
        """Import requested module.

        Imports the module at the given path in the proper package/subpackage
        for the current request. For example, if the transaction has the URL
        http://localhost/WebKit.cgi/MyContextDirectory/MySubdirectory/MyPage
        and path = 'some/random/path/MyModule.py' and the context is configured
        to have the name 'MyContext' then this function imports the module at
        that path as MyContext.MySubdirectory.MyModule . Note that the context
        name may differ from the name of the directory containing the context,
        even though they are usually the same by convention.

        Note that the module imported may have a different name from the
        servlet name specified in the URL. This is used in PSP.
        """
        request = transaction.request()
        path = request.serverSidePath()
        contextPath = request.serverSideContextPath()
        fullname = request.contextName()

        # No context, or the requested file lies outside the context
        # directory: import the module standalone under a name derived
        # from its whole path, so it cannot clash with other modules.
        if not fullname or not path.startswith(contextPath):
            fullname = serverSidePathToImport
            if os.sep != '/':
                fullname = fullname.replace(os.sep, '_')
            fullname = fullname.replace('/', '_').replace('.', '_')
            name = os.path.splitext(os.path.basename(
                serverSidePathToImport))[0]
            moduleDir = os.path.dirname(serverSidePathToImport)
            module = self._importModuleFromDirectory(fullname, name,
                moduleDir, forceReload=self._reloadClasses)
            return module

        # Import the context directory as a package named after the context.
        if os.sep != '/':
            fullname = fullname.replace(os.sep, '_')
        fullname = fullname.replace('/', '_')
        directory, contextDirName = os.path.split(contextPath)
        self._importModuleFromDirectory(fullname, contextDirName,
            directory, isPackageDir=True)
        directory = contextPath

        # Split the part of the request path below the context into
        # its directory components.
        remainder = path[len(contextPath)+1:]
        if os.sep != '/':
            remainder = remainder.replace(os.sep, '/')
        remainder = remainder.split('/')

        # Import every intermediate directory as a subpackage.
        for name in remainder[:-1]:
            fullname = '%s.%s' % (fullname, name)
            self._importModuleFromDirectory(fullname, name,
                directory, isPackageDir=True)
            directory = os.path.join(directory, name)

        # Finally import the module itself as a member of the innermost
        # package, even though its file may live in another directory.
        moduleFileName = os.path.basename(serverSidePathToImport)
        moduleDir = os.path.dirname(serverSidePathToImport)
        name = os.path.splitext(moduleFileName)[0]
        fullname = '%s.%s' % (fullname, name)
        module = self._importModuleFromDirectory(fullname, name,
            moduleDir, forceReload=self._reloadClasses)
        return module

    def _importModuleFromDirectory(self, fullModuleName, moduleName,
            directory, isPackageDir=False, forceReload=False):
        """Imports the given module from the given directory.

        fullModuleName should be the full dotted name that will be given
        to the module within Python. moduleName should be the name of the
        module in the filesystem, which may be different from the name
        given in fullModuleName. Returns the module object. If forceReload is
        True then this reloads the module even if it has already been imported.

        If isPackageDir is True, then this function creates an empty
        __init__.py if that file doesn't already exist.
        """
        if debug:
            # Single parenthesized argument: same output on Python 2 and 3.
            print('%s %s %s %s' % (
                __file__, fullModuleName, moduleName, directory))
        module = sys.modules.get(fullModuleName)
        if module is not None and not forceReload:
            return module
        if isPackageDir:
            # A directory only imports as a package if it has an __init__
            # file, which may exist in source or compiled form only.
            packageDir = os.path.join(directory, moduleName)
            initPy = os.path.join(packageDir, '__init__.py')
            for ext in ('', 'c', 'o'):
                if os.path.exists(initPy + ext):
                    break
            else:
                initFile = open(initPy, 'w')
                try:
                    initFile.write('#')
                finally:
                    initFile.close()
        fp, pathname, stuff = self._imp.find_module(moduleName, [directory])
        try:
            module = self._imp.load_module(fullModuleName, fp, pathname, stuff)
        finally:
            # find_module() may hand back an open file; the caller must close
            # it even if loading fails. Closing twice is harmless in case the
            # loader already closed it.
            if fp:
                fp.close()
        module.__donotreload__ = self._reloadClasses
        return module

    def loadClass(self, transaction, path):
        """Load the appropriate class.

        Given a transaction and a path, load the class for creating these
        servlets. Caching, pooling, and threadsafeness are all handled by
        servletForTransaction. This method is not expected to be threadsafe.
        """
        raise AbstractError(self.__class__)

    def servletForTransaction(self, transaction):
        """Return a new servlet that will handle the transaction.

        This method handles caching, and will call loadClass(trans, filepath)
        if no cache is found. Caching is generally controlled by servlets
        with the canBeReused() and canBeThreaded() methods.
        """
        request = transaction.request()
        path = request.serverSidePath()

        # (Re)load the class if it is not cached or the file changed on disk.
        mtime = os.path.getmtime(path)
        entry = self._classCache.get(path)
        if entry is not None and entry['mtime'] == mtime:
            theClass = entry['class']
        else:
            # The lock prevents several threads from importing the same
            # module simultaneously; loadClass() need not be threadsafe.
            self._importLock.acquire()
            try:
                # Re-check: another thread may have loaded it meanwhile.
                entry = self._classCache.get(path)
                if entry is not None and entry['mtime'] == mtime:
                    theClass = entry['class']
                else:
                    theClass = self.loadClass(transaction, path)
                    if self._cacheClasses:
                        self._classCache[path] = {
                            'mtime': mtime, 'class': theClass}
            finally:
                self._importLock.release()

        # Try to reuse a cached servlet, but only one of the current class:
        # outdated instances may linger after the class was re-imported.
        if path in self._threadsafeServletCache:
            servlet = self._threadsafeServletCache[path]
            if servlet.__class__ is theClass:
                return servlet
        else:
            while 1:
                try:
                    servlet = self._servletPool[path].pop()
                except (KeyError, IndexError):
                    break
                else:
                    if servlet.__class__ is theClass:
                        servlet.open()
                        return servlet

        # Nothing reusable: create a fresh servlet.
        servlet = theClass()
        servlet.setFactory(self)
        if servlet.canBeReused():
            if servlet.canBeThreaded():
                self._threadsafeServletCache[path] = servlet
            else:
                self._servletPool[path] = []
                servlet.open()
        return servlet

    def returnServlet(self, servlet):
        """Return servlet to the pool.

        Called by Servlet.close(), which returns the servlet
        to the servlet pool if necessary. Only reusable, non-threadsafe
        servlets are pooled, and only if CacheServletInstances is set.
        """
        if (servlet.canBeReused() and not servlet.canBeThreaded()
                and self._cacheInstances):
            path = servlet.serverSidePath()
            self._servletPool[path].append(servlet)

    def flushCache(self):
        """Flush the servlet cache and start fresh.

        Servlets that are currently in the wild may find their way back
        into the cache (this may be a problem).
        """
        self._importLock.acquire()
        try:
            self._classCache = {}
            # Keep the pool keys: returnServlet() appends to existing lists.
            for key in self._servletPool:
                self._servletPool[key] = []
            self._threadsafeServletCache = {}
        finally:
            self._importLock.release()
class PythonServletFactory(ServletFactory):
    """The factory for Python servlets.

    This is the factory for ordinary Python servlets whose extensions
    are empty or .py. The servlets are unique per file since the file
    itself defines the servlet.
    """

    def uniqueness(self):
        """Return 'file': every servlet file defines its own servlet."""
        return 'file'

    def extensions(self):
        """Return the extensions handled by this factory.

        Besides .py, the compiled forms .pyc and .pyo are listed,
        presumably so servlets shipped only in compiled form can be served.
        NOTE(review): confirm how the Application prefers a .py file over
        its compiled counterparts.
        """
        return ['.py', '.pyc', '.pyo']

    def loadClass(self, transaction, path):
        """Import the servlet module at path and return its servlet class.

        The class is expected to have the same name as the file (without
        extension). If no such attribute exists, '-' and ' ' are replaced
        by '_', and a trailing '_' is appended if the result is a Python
        keyword.

        Raises ValueError if no class with the expected name is found, and
        TypeError if the found attribute is not a Servlet subclass.
        """
        module = self.importAsPackage(transaction, path)
        # By convention the servlet class is named like its file.
        name = os.path.splitext(os.path.basename(path))[0]
        if not hasattr(module, name):
            # File names may contain characters that are not allowed
            # in identifiers, so try the sanitized name.
            name = name.replace('-', '_').replace(' ', '_')
            # Keywords cannot be class names; a trailing underscore is
            # the usual way to escape them.
            if iskeyword(name):
                name += '_'
            if not hasattr(module, name):
                raise ValueError('Cannot find expected servlet class %r in %r.'
                    % (name, path))
        theClass = getattr(module, name)
        # Validate explicitly instead of with assert, which -O strips out.
        try:
            isServlet = issubclass(theClass, Servlet)
        except TypeError:  # the attribute is not a class at all
            isServlet = False
        if not isServlet:
            raise TypeError('%r in %r is not a Servlet subclass.'
                % (name, path))
        return theClass