#!/usr/bin/python
"""
$Id: utils.py,v 1.41 2007/05/29 13:22:42 dom Exp $

Some utility classes to simplify other interfaces.
Imports xml.sax
        uripath http://dev.w3.org/cvsweb/2000/10/swap/uripath.html
        httplib2 http://bitworking.org/projects/httplib2/
License
-------
Copyright (c) 2006 World Wide Web Consortium, (Massachusetts
Institute of Technology, European Research Consortium for Informatics
and Mathematics, Keio University). All Rights Reserved. This work is
distributed under the W3C Software License [1] in the hope that it
will be useful, but WITHOUT ANY WARRANTY; without even the implied
warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.

[1] http://www.w3.org/Consortium/Legal/copyright-software

"""

from  httplib2 import HttpLib2Error
class HTTPRedirection(HttpLib2Error):
    """httplib2-derived exception type; nothing in this module raises it."""

class MaxLinksNumber(Exception):
    """Raised when a document references more resources than the tool checks."""

class HttpError(Exception):
    """Generic HTTP failure; nothing in this module raises it."""

# We use a class with a static method
# so that we can avoid doing an HTTP request more than once

from threading import Thread, Lock


class HTTPRequest:
    """HTTP client whose responses are cached in memory at class level.

    The cache is keyed by URI and then by HTTP method, and is shared by
    every instance, so a given URI/method pair is normally fetched only
    once per process.
    """
    # uri -> {method: (headers, content)}
    _cache = {}
    # uri -> Lock serialising the network requests made for that URI
    _urisLocks = {}
    # protects the two dictionaries above
    _cacheLock = Lock()

    # @@@ needs to deal with
    # * non HTTP URIs
    # * HTTP Error codes
    # * connection problem (socket.error ?)
    def http_request(cls,uri,method="GET",req_head=None,credentials=None):
        """Fetch uri with the given method and return (headers, content).

        uri         -- the URI to request
        method      -- HTTP method name; upper-cased before use
        req_head    -- optional dictionary of request headers overriding
                       the defaults; when given, the cached responses for
                       uri are discarded and a fresh request is made
        credentials -- optional dictionary with "login" and "password"
                       keys, handed to httplib2 for authentication

        A successful GET also populates the HEAD entry of the cache, with
        the same headers and None as content.  Whatever httplib2 raises is
        propagated to the caller; the locks are released in that case.
        Note that the process-wide socket timeout is set to 20 seconds as
        a side effect.
        """
        # @@@ Get from DDC? as param?
        request_headers = {
            "user-agent":"W3C-mobileOK/DDC-1.0 (see http://www.w3.org/2006/07/mobileok-ddc)",
            "accept-charset": "UTF-8",
            "accept":"application/xhtml+xml,text/html;q=0.1,application/vnd.wap.xhtml+xml;q=0.1,text/css,image/jpeg,image/gif",
            "cache-control":"max-age=0",
            }
        # if the parameter req_head is set, we update the existing headers
        # with that data
        if req_head:
            request_headers.update(req_head)
        import socket
        socket.setdefaulttimeout(20.0)
        method=method.upper()

        HTTPRequest._cacheLock.acquire()
        try:
            # if specific request parameters were given, we can't rely on
            # the in-memory caching
            if uri not in cls._cache or req_head:
                cls._cache[uri] = {}
            # The lock of a URI is created once and never replaced, so that
            # a thread always releases the very lock it acquired.
            if uri not in HTTPRequest._urisLocks:
                HTTPRequest._urisLocks[uri] = Lock()
            uri_lock = HTTPRequest._urisLocks[uri]
            if method in cls._cache[uri]:
                return cls._cache[uri][method]
        finally:
            HTTPRequest._cacheLock.release()

        import httplib2
        # @@@ the path needs to come from a config option
        h = httplib2.Http("/tmp/httplib2-cache")

        # setting up authentication if needed
        if credentials:
            h.add_credentials(credentials["login"],credentials["password"])
        # when doing a HEAD, we don't care about getting 304
        # we really want the headers of the resource
        if method=="HEAD":
            request_headers['cache-control']= request_headers['cache-control'] + ',no-cache'

        uri_lock.acquire()
        try:
            headers,content = h.request(uri,method,None,request_headers)
        finally:
            # released even when the request raises, otherwise every later
            # request for this URI would block forever
            uri_lock.release()

        if headers:
            # NOTE(review): for any status other than 200/304 the requested
            # URI is stored under "Location" -- intent to be confirmed
            if headers.status not in [200,304]:
                headers["Location"]=uri

        result = (headers,content)
        HTTPRequest._cacheLock.acquire()
        try:
            cls._cache[uri][method] = result
            if method=="GET":
                cls._cache[uri]["HEAD"]= (headers,None)
        finally:
            HTTPRequest._cacheLock.release()
        # return our own result rather than re-reading the cache, which
        # another thread may have reset in the meantime
        return result

    # To make this a static method (with a cache shared across instances)
    http_request = classmethod(http_request)


# A class that implements threaded requests
# so that we don't have to wait for each link to be dereferenced before
# proceeding
class ThreadedRequest:
    """Runs HTTP requests in background threads, one thread per URL.

    Both dictionaries below live on the class, so the requests and their
    results are shared by every user of ThreadedRequest in the process.
    """
    # url -> worker thread
    _threads = {}
    # url -> (headers, content), filled in by getResults()
    _results = {}

    # internal class that inherits from threading
    class threadit(Thread):
        """Worker thread performing a single request through HTTPRequest."""
        def __init__(self,url,method):
            Thread.__init__(self)
            self.url = url
            self.method = method
            self.headers = None
            self.content = None
            # exceptions raised by the request are collected here rather
            # than propagated
            self.exceptions = []
            self._request = HTTPRequest()

        def run(self):
            """Perform the request; non-HTTP(S) URLs yield (None, None)."""
            # URI schemes are case-insensitive
            scheme = self.url.split(":")[0].lower()
            if scheme not in ["http","https"]:
                self.headers,self.content = None,None
                return
            try:
                self.headers,self.content = self._request.http_request(self.url,self.method)
            except Exception:
                # sys.exc_info() rather than "except Exception, e", whose
                # syntax is not accepted by every Python version
                import sys
                self.exceptions.append(sys.exc_info()[1])

    def addRequest(cls,url,method):
        """Start a request for url unless one was already started.

        Only the first method requested for a given url is used.
        """
        if url not in cls._threads:
            worker = cls.threadit(url,method)
            cls._threads[url] = worker
            worker.start()

    # returns a hash by url of headers,content pairs
    def getResults(cls):
        """Wait for all the started requests and return their results.

        Returns the shared dictionary mapping each url to its
        (headers, content) pair; both are None when the request failed or
        was not attempted.
        """
        for url,worker in cls._threads.items():
            worker.join()
            cls._results[url] = worker.headers,worker.content
        return cls._results

    # To make these static methods
    addRequest = classmethod(addRequest)
    getResults = classmethod(getResults)


        
from xml.sax.handler  import ContentHandler

class SaxParser(ContentHandler):
    """SAX content handler that fetches a document over HTTP and parses it.

    The underlying parser is created with validation and the loading of
    external general and parameter entities switched off.  Subclasses
    override the ContentHandler callbacks to process the document.
    """

    def __init__(self):
        ContentHandler.__init__(self)
        self._encoding = ""
        self._parser = self._getSaxParser()
        self.locator = None
        # set by parse()
        self.uri = None
        self._content = None
        self._headers = None

    def _http_request(self,uri,method="GET"):
        """Return the (headers, content) pair obtained for uri."""
        #if uri.split(':')[0]=="file":
        #    f = open(uri[7:])
        #    return None,f.read()
        h = HTTPRequest()
        return h.http_request(uri,method)

    def parse(self,uri):
        """Fetch uri with a GET request and feed its content to the parser.

        The response is kept in self._headers and self._content; this
        object receives the SAX events.
        """
        # @@@ only accepts HTTP URIS; need to raise proper exception in other cases
        self.uri = uri
        self._headers,self._content = self._http_request(uri,"GET")
        from  cStringIO import StringIO
        fp = StringIO(self._content)
        self._parser.setContentHandler(self)
        self._parser.parse(fp)

    def getEncoding(self):
        """Return the encoding name recorded for the parser."""
        return self._encoding

    def _getXMLDeclaredEncoding(self,content=None):
        """Attempt to detect the character encoding of an XML document.

        content is the raw (undecoded) document; the content fetched by
        the last call to parse() is used when it is empty or missing.

        The return value can be:
        - if detection of the BOM succeeds, the codec name of the
        corresponding unicode charset

        - if BOM detection fails, the value of the encoding attribute of
        the xml declaration, which has to start at the very first
        character of the document and is searched in its first 2 KB

        - None if both fail or if there is no content at all. According
        to xml 1.0 it should be utf_8 then, but it wasn't detected by
        the means offered here.

        Adapted from
        http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/363841
        """
        if not content:
            content=self._content
        if not content:
            return None

        ### detection using BOM
        ## the BOMs we know, by their pattern
        bomDict={ # bytepattern : name
            (0x00, 0x00, 0xFE, 0xFF) : "utf_32_be",
            (0xFF, 0xFE, 0x00, 0x00) : "utf_32_le",
            (0xFE, 0xFF, None, None) : "utf_16_be",
            (0xFF, 0xFE, None, None) : "utf_16_le",
            (0xEF, 0xBB, 0xBF, None) : "utf_8",
            }

        ## the first 4 bytes, padded with None so that documents shorter
        ## than 4 bytes can be looked up as well
        head = [ord(c) for c in content[:4]]
        head = head + [None] * (4 - len(head))

        ## try bom detection using 4 bytes, 3 bytes, or 2 bytes
        for significant in (4, 3, 2):
            pattern = tuple(head[:significant] + [None] * (4 - significant))
            bomDetection = bomDict.get(pattern)
            ## if BOM detected, we're done :-)
            if bomDetection:
                return bomDetection

        ## still here? BOM detection failed.
        ##  now that BOM detection has failed we assume one byte character
        ##  encoding behaving ASCII - of course one could think of nice
        ##  algorithms further investigating on that matter, but I won't for now.
        ### search xml declaration for encoding attribute
        import re
        ## assume xml declaration fits into the first 2 KB (*cough*)
        buf = content[:2048]

        ## set up regular expression
        xmlDeclPattern = r"""
        ^<\?xml             # w/o BOM, xmldecl starts with <?xml at the first byte
        .+?                 # some chars (version info), matched minimal
        encoding=           # encoding attribute begins
        ["']                # attribute start delimiter
        (?P<encstr>         # what is matched in the brackets will be named encstr
        [^"']+              # every character not delimiter (not overly exact!)
        )                   # closes the brackets pair for the named group
        ["']                # attribute end delimiter
        .*?                 # some chars optionally (standalone decl or whitespace)
        \?>                 # xmldecl end
        """

        xmlDeclRE = re.compile(xmlDeclPattern, re.VERBOSE)

        ## search and extract encoding string
        match = xmlDeclRE.search(buf)
        if match:
            return match.group("encstr")
        return None

    def _getSaxParser(self):
        """Create the SAX parser and record "utf-8" as the encoding."""
        import xml.sax
        parser = xml.sax.make_parser()
        parser.setFeature(xml.sax.handler.feature_validation,0)
        parser.setFeature(xml.sax.handler.feature_external_ges,0)
        parser.setFeature(xml.sax.handler.feature_external_pes,0)
        # @@@ parser.getProperty(xml.sax.handler.property_encoding) not supported!!!
        self._encoding = "utf-8"
        return parser

    def setDocumentLocator(self, locator):
        """Keep the locator handed over by the parser in self.locator."""
        #If the parser supports location info it invokes this event
        #before any other methods in the DocumentHandler interface
        self.locator = locator


# Values of LinkTarget.nature, as assigned by LinksParser:
# resource embedded in the page (style sheet, image, script, frame, object)
LINK_TYPE_EMBED = 1
# hyperlink (a/area) whose target is outside the parsed document
LINK_TYPE_EXT = 2
# hyperlink whose target is the parsed document itself
LINK_TYPE_INT = 3

# LinksParser stops dereferencing links once its counter reaches this value
MAX_LINKS_NUMBER = 300

class Conneg:
    """Minimal content negotiation based on the value of an Accept header."""
    def __init__(self,accept_headers):
        # accept_headers is the raw header value, e.g.
        # "application/xhtml+xml,text/html;q=0.5"
        self._acceptHeaders = accept_headers

    def _parseItem(item):
        """Split one item of the header into (media type, parameters dict)."""
        try:
            from cgi import parse_header
        except ImportError:
            # the cgi module is gone from recent Python versions; the email
            # package parses the same "value; name=value" syntax
            from email.message import Message
            msg = Message()
            msg["content-type"] = item
            parsed = msg.get_params()
            params = {}
            for key,value in parsed[1:]:
                params[key.lower()] = value
            return parsed[0][0],params
        return parse_header(item)
    _parseItem = staticmethod(_parseItem)

    def _quality(params):
        """Return the q-value found in params as a float.

        A missing, malformed or out of range (not in [0,1]) q-value counts
        as 1.
        """
        try:
            q_level = float(params["q"])
        except (KeyError,ValueError):
            return 1.0
        if q_level >= 0 and q_level <= 1:
            return q_level
        return 1.0
    _quality = staticmethod(_quality)

    def preferred(self,mime_list):
        """Return the media type of mime_list preferred by the Accept header.

        An empty mime_list gives None and a single-item one gives that item
        without looking at the header.  Otherwise the result is the listed
        type with the highest q-value; types with a q-value of 0 are never
        selected, and None is returned when nothing is acceptable.
        """
        if len(mime_list)==0:
            return None
        if len(mime_list)==1:
            return mime_list[0]
        best_mime = None
        best_q = 0.0
        for item in self._acceptHeaders.split(","):
            mime,params = self._parseItem(item)
            if mime not in mime_list:
                continue
            # q-values are compared as numbers, not as strings
            q_level = self._quality(params)
            # note that the strictly superior matches the rule that the order
            # in accept_header matches the preferred order
            if q_level > best_q:
                best_mime,best_q = mime,q_level
        return best_mime

    def __repr__(self):
        return "Conneg(%s)" % (self._acceptHeaders,)

        
class LinkTarget:
    """ A link target describes the content type, the text and nature (embedded or not) of a link

    target   -- URI the link points to
    nature   -- one of the LINK_TYPE_* values, or None
    mime     -- media type of the target, or None when not known
    text     -- text content of the link
    location -- where the link was found in the document

    encodingHint and protocolEncoding start as None, are set by the code
    that builds the link, and do not take part in comparisons.
    """
    def __init__(self,target,nature=None,mime=None,text=None,location=None):
        self.target = target
        self.mime=mime
        self.text=text
        self.nature=nature
        self.location=location
        self.encodingHint=None
        self.protocolEncoding = None

    def __eq__(self,link):
        # anything that is not a LinkTarget is left to the default comparison
        if not isinstance(link,LinkTarget):
            return NotImplemented
        return (self.target==link.target
                and self.mime==link.mime
                and self.text==link.text
                and self.nature==link.nature
                and self.location==link.location)

    def __ne__(self,link):
        equal = self.__eq__(link)
        if equal is NotImplemented:
            return equal
        return not equal

    def __repr__(self):
        # %s rather than %d for nature, which is None by default
        return "LinkTarget('%s',%s,'%s','%s',%s)" % (self.target,self.nature,self.mime,self.text,self.location)

import UserList
class LinksList(UserList.UserList):
    """ A list of LinkTargets indexed by absolute hashless URIs

    Besides the list itself, an index from the target URI of each link
    (without its fragment) to the links having that target is maintained
    for the links given to the constructor or added with append().  The
    "in" operator tests such a URI, not a link object.
    """
    def __init__(self,links=None):
        UserList.UserList.__init__(self)
        # hashless URI -> list of the links pointing to it
        self._hash = {}
        # go through append() so that the initial links get indexed too
        if links is not None:
            for link in links:
                self.append(link)

    def append(self,link):
        """Add link at the end of the list and to the URI index."""
        self.data.append(link)
        uri = link.target.split("#")[0]
        if uri not in self._hash:
            self._hash[uri] = []
        self._hash[uri].append(link)

    def getLinksByUri(self,uri):
        """Return the links whose target, fragment aside, is uri."""
        return self._hash.get(uri.split("#")[0],[])

    def __contains__(self,uri):
        uri = uri.split("#")[0]
        return uri in self._hash


class LinksParser(SaxParser):
    """Collects the external hyperlinks and embedded resources of a document.

    While parsing, each linked resource is requested in the background
    through ThreadedRequest; at the end of the document the media type and
    charset announced in the responses are copied onto the collected
    LinkTarget objects, available from getLinks().
    """
    def startDocument(self):
        self._links=LinksList()
        self._base=self.uri
        self._currentLink = None
        self._linksCounter = 0
        self._requester = ThreadedRequest()
        import re
        # from http://www.w3.org/TR/html401/types.html#type-media-descriptors
        self._media_parser = re.compile("^([-a-zA-Z0-9]*)[^-a-zA-Z0-9]?")


    def _absolutize(self,uri,base=None):
        """Resolve uri against base, or against the document base if None."""
        from swap import uripath
        # making uri absolute
        if base is None:
            base=self._base
        return uripath.join(base,uri)

    def _getMimeType(self,headers,returnsEncoding=False):
        """Extract the media type from the content-type of headers.

        Returns the media type, or a (media type, charset) pair when
        returnsEncoding is true; missing pieces are None.
        """
        from cgi import parse_header
        encoding= None
        mime = None
        if headers and "content-type" in headers:
            # keeping only the MIME type itself, not the associated parameters
            mime,params = parse_header(headers["content-type"])
            encoding = params.get("charset")
        if returnsEncoding:
            return mime,encoding
        else:
            return mime

    def _location(self):
        """Return the current parsing position as a LineColumnLocation."""
        from testcase import LineColumnLocation
        return LineColumnLocation(self.uri,"utf-8",self.locator.getLineNumber(),self.locator.getColumnNumber())

    def _mediaApplies(self,mediastring):
        """Tell whether a media attribute covers "all" or "handheld"."""
        if mediastring.strip()=="":
            # we divert here from http://www.w3.org/TR/html401/present/styles.html#adef-media which
            # says to default to "screen"
            # but that sounds illogical in XHTML Basic context
            mediastring="all"
        # parsing per http://www.w3.org/TR/html401/types.html#type-media-descriptors
        medias = [self._media_parser.findall(m.strip())[0] for m in mediastring.split(",")]
        return "all" in medias or "handheld" in medias

    def _startHyperlink(self,href,attrs):
        """Handle the start of an a/area element having an href."""
        from swap import uripath
        target = self._absolutize(href)
        # checking whether the link is internal to the page or external
        relativeLink = uripath.refTo(self.uri,target)
        location = self._location()
        if len(relativeLink)==0 or relativeLink[0]=='#':
            # internal links are neither dereferenced nor recorded
            return
        if self._linksCounter < MAX_LINKS_NUMBER:
            self._requester.addRequest(target,"HEAD")
            # the link is only stored at the end of the element, once its
            # text has been collected
            self._currentLink = LinkTarget(target,LINK_TYPE_EXT,None,'',location)
            charset = attrs.get("charset")
            if charset is not None:
                self._currentLink.encodingHint = charset

    def _addEmbedded(self,target,encodingHint=None):
        """Record and request an embedded resource, and count it."""
        location = self._location()
        if self._linksCounter < MAX_LINKS_NUMBER:
            # we'll need to download the content later on, so we may as
            # well GET rather than HEAD
            self._requester.addRequest(target,"GET")
            link = LinkTarget(target,LINK_TYPE_EMBED,None,'',location)
            if encodingHint is not None:
                link.encodingHint = encodingHint
            self._links.append(link)
        self._linksCounter = self._linksCounter + 1

    def startElement(self,name,attrs):
        if name=="base":
            href = attrs.get("href")
            if href is not None:
                self._base = href
        # looking for hyperlinks
        elif name in ["a","area"]:
            href = attrs.get("href")
            if href is not None:
                self._startHyperlink(href,attrs)
        # looking for stylesheets
        elif name=="link":
            href = attrs.get("href")
            if href is not None and "stylesheet" in attrs.get("rel","").split(" "):
                media = attrs.get("media")
                if media is not None and not self._mediaApplies(media):
                    return
                # we assume the type is text/css
                # since that's the only style sheet format for XHTML Basic @@@
                self._addEmbedded(self._absolutize(href),attrs.get("charset"))
        # we look for all the other possible embedding elements
        # even the ones that don't fit in XHTML Basic
        # if the module gets used in a broader context
        elif name in ["img","input","script","iframe","frame"]:
            src = attrs.get("src")
            if src is not None:
                self._addEmbedded(self._absolutize(src))
        elif name == "object":
            data = attrs.get("data")
            if data is not None:
                codebase = attrs.get("codebase")
                if codebase is None:
                    base = self._base
                else:
                    # codebase may itself be relative to the document base
                    base = self._absolutize(codebase)
                self._addEmbedded(self._absolutize(data,base))

    def characters(self,content):
        # text is only collected while inside an external hyperlink
        if isinstance(self._currentLink,LinkTarget):
            self._currentLink.text = self._currentLink.text + content


    def endElement(self,name):
        if name in ["a","area"] and self._currentLink:
            if self._currentLink.nature == LINK_TYPE_EXT and self._linksCounter < MAX_LINKS_NUMBER:
                self._links.append(self._currentLink)
            self._linksCounter = self._linksCounter + 1
            self._currentLink = None


    def endDocument(self):
        """Complete the links with the HTTP results and enforce the limit.

        Raises MaxLinksNumber when the document references more than
        MAX_LINKS_NUMBER resources.
        """
        # let's collect the results of the threaded HEAD requests
        results = self._requester.getResults()
        for url,result in results.items():
            links = self._links.getLinksByUri(url)
            if not links:
                continue
            mime,encoding = self._getMimeType(result[0],True)
            for l in links:
                l.mime = mime
                if encoding:
                    l.protocolEncoding = encoding

        if self._linksCounter > MAX_LINKS_NUMBER:
            raise MaxLinksNumber("The document contains links to %d external resources, and this tool checks at most %d external resources." % (self._linksCounter,MAX_LINKS_NUMBER))

    def getLinks(self):
        """Return the collected links, or an empty list if there are none."""
        # also covers the case where no document has been parsed yet
        return getattr(self,"_links",None) or []



import unittest

class Tests(unittest.TestCase):
    """Unit tests of the module; testLinkParser needs network access."""

    def testLinkParser(self):
        """The links of the online test page are found with their types."""
        from testcase import LineColumnLocation
        testfile = "http://dev.w3.org/cvsweb/~checkout~/2006/mwbp-validator/tests/links-parser.html"

        def at(line,column):
            return LineColumnLocation(testfile,'utf-8',line,column)

        expected = [
            LinkTarget('http://www.w3.org/2006/04/bp-tests.css',LINK_TYPE_EMBED,'text/css','',at(7,0)),
            LinkTarget('http://www.w3.org/Mobile/handheld.css',LINK_TYPE_EMBED,'text/css','',at(13,0)),
            LinkTarget('http://www.w3.org/Icons/w3c_home.gif',LINK_TYPE_EMBED,'image/gif','',at(18,0)),
            LinkTarget('http://dev.w3.org/cvsweb/2006/mwbp-validator/tests/',LINK_TYPE_EXT,'text/html','Tests\nlist',at(20,18)),
            LinkTarget('http://www.w3.org/Talks/Tools/Slidy/w3c-logo-blue.svg',LINK_TYPE_EMBED,'image/svg+xml','',at(23,25)),
            LinkTarget('http://www.w3.org/',LINK_TYPE_EXT,'application/xhtml+xml','',at(24,17)),
            LinkTarget('http://www.w3.org/People/Dom/',LINK_TYPE_EXT,'text/html','Dominique\nHazael-Massieux',at(26,22)),
            ]
        parser = LinksParser()
        parser.parse(testfile)
        self.assertEqual(parser.getLinks(),expected)

    def testConneg(self):
        """Conneg.preferred picks the expected type for each Accept value."""
        offered = ["text/html","application/xhtml+xml"]
        cases = [
            ("application/xhtml+xml;q=0.0, text/html;q=0.5","text/html"),
            ("application/xhtml+xml, text/html","application/xhtml+xml"),
            ("text/html;application/xhtml+xml;q=1.0","text/html"),
            ("application/xhtml+xml;q=0.5,text/html;q=0.4","application/xhtml+xml"),
            ]
        for accept,expected in cases:
            self.assertEqual(Conneg(accept).preferred(offered),expected)

class Profile:
    """An XML profile file loaded as an object tree with gnosis objectify."""
    def __init__(self,file):
        from gnosis.xml.objectify import XML_Objectify, EXPAT
        self._file = file
        objectifier = XML_Objectify(file,EXPAT)
        self.data = objectifier.make_instance()

    def getList(self,path,group):
        """Return the PCDATA of each item of data.<path>.<group>."""
        section = getattr(self.data,path)
        return [item.PCDATA for item in getattr(section,group)]


def _test():
    """Run the unit tests of this module through unittest's command line runner."""
    unittest.main()

# Run the unit tests when the module is executed as a script.
if __name__ == '__main__':
    _test()


