]> xmof Git - DeDRM.git/commitdiff
Create ion.py
authortomthumb1997 <37314994+tomthumb1997@users.noreply.github.com>
Tue, 13 Mar 2018 00:35:28 +0000 (20:35 -0400)
committerGitHub <noreply@github.com>
Tue, 13 Mar 2018 00:35:28 +0000 (20:35 -0400)
DeDRM_calibre_plugin/DeDRM_plugin/ion.py [new file with mode: 0644]

diff --git a/DeDRM_calibre_plugin/DeDRM_plugin/ion.py b/DeDRM_calibre_plugin/DeDRM_plugin/ion.py
new file mode 100644 (file)
index 0000000..c100191
--- /dev/null
@@ -0,0 +1,981 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Pascal implementation by lulzkabulz. Python translation by apprenticenaomi. DeDRM integration by anon.
+# BinaryIon.pas + DrmIon.pas + IonSymbols.pas
+
+from __future__ import with_statement
+
+import collections
+import hashlib
+import hmac
+import os
+import os.path
+import struct
+
+try:
+    from cStringIO import StringIO
+except ImportError:
+    from StringIO import StringIO
+
+from Crypto.Cipher import AES
+from Crypto.Util.py3compat import bchr, bord
+
+try:
+    # lzma library from calibre 2.35.0 or later
+    import lzma.lzma1 as calibre_lzma
+except:
+    calibre_lzma = None
+    try:
+        import lzma
+    except:
+        # Need pip backports.lzma on Python <3.3
+        from backports import lzma
+
+
+TID_NULL = 0
+TID_BOOLEAN = 1
+TID_POSINT = 2
+TID_NEGINT = 3
+TID_FLOAT = 4
+TID_DECIMAL = 5
+TID_TIMESTAMP = 6
+TID_SYMBOL = 7
+TID_STRING = 8
+TID_CLOB = 9
+TID_BLOB = 0xA
+TID_LIST = 0xB
+TID_SEXP = 0xC
+TID_STRUCT = 0xD
+TID_TYPEDECL = 0xE
+TID_UNUSED = 0xF
+
+
+SID_UNKNOWN = -1
+SID_ION = 1
+SID_ION_1_0 = 2
+SID_ION_SYMBOL_TABLE = 3
+SID_NAME = 4
+SID_VERSION = 5
+SID_IMPORTS = 6
+SID_SYMBOLS = 7
+SID_MAX_ID = 8
+SID_ION_SHARED_SYMBOL_TABLE = 9
+SID_ION_1_0_MAX = 10
+
+
+LEN_IS_VAR_LEN = 0xE
+LEN_IS_NULL = 0xF
+
+
+VERSION_MARKER = b"\x01\x00\xEA"
+
+
+# asserts must always raise exceptions for proper functioning
+def _assert(test, msg="Exception"):
+    if not test:
+        raise Exception(msg)
+
+
+class SystemSymbols(object):
+    ION = '$ion'
+    ION_1_0 = '$ion_1_0'
+    ION_SYMBOL_TABLE = '$ion_symbol_table'
+    NAME = 'name'
+    VERSION = 'version'
+    IMPORTS = 'imports'
+    SYMBOLS = 'symbols'
+    MAX_ID = 'max_id'
+    ION_SHARED_SYMBOL_TABLE = '$ion_shared_symbol_table'
+
+
+class IonCatalogItem(object):
+    name = ""
+    version = 0
+    symnames = []
+
+    def __init__(self, name, version, symnames):
+        self.name = name
+        self.version = version
+        self.symnames = symnames
+
+
+class SymbolToken(object):
+    text = ""
+    sid = 0
+
+    def __init__(self, text, sid):
+        if text == "" and sid == 0:
+            raise ValueError("Symbol token must have Text or SID")
+
+        self.text = text
+        self.sid = sid
+
+
+class SymbolTable(object):
+    table = None
+
+    def __init__(self):
+        self.table = [None] * SID_ION_1_0_MAX
+        self.table[SID_ION] = SystemSymbols.ION
+        self.table[SID_ION_1_0] = SystemSymbols.ION_1_0
+        self.table[SID_ION_SYMBOL_TABLE] = SystemSymbols.ION_SYMBOL_TABLE
+        self.table[SID_NAME] = SystemSymbols.NAME
+        self.table[SID_VERSION] = SystemSymbols.VERSION
+        self.table[SID_IMPORTS] = SystemSymbols.IMPORTS
+        self.table[SID_SYMBOLS] = SystemSymbols.SYMBOLS
+        self.table[SID_MAX_ID] = SystemSymbols.MAX_ID
+        self.table[SID_ION_SHARED_SYMBOL_TABLE] = SystemSymbols.ION_SHARED_SYMBOL_TABLE
+
+    def findbyid(self, sid):
+        if sid < 1:
+            raise ValueError("Invalid symbol id")
+
+        if sid < len(self.table):
+            return self.table[sid]
+        else:
+            return ""
+
+    def import_(self, table, maxid):
+        for i in range(maxid):
+            self.table.append(table.symnames[i])
+
+    def importunknown(self, name, maxid):
+        for i in range(maxid):
+            self.table.append("%s#%d" % (name, i + 1))
+
+
+class ParserState:
+    Invalid,BeforeField,BeforeTID,BeforeValue,AfterValue,EOF = 1,2,3,4,5,6
+
+ContainerRec = collections.namedtuple("ContainerRec", "nextpos, tid, remaining")
+
+
+class BinaryIonParser(object):
+    eof = False
+    state = None
+    localremaining = 0
+    needhasnext = False
+    isinstruct = False
+    valuetid = 0
+    valuefieldid = 0
+    parenttid = 0
+    valuelen = 0
+    valueisnull = False
+    valueistrue = False
+    value = None
+    didimports = False
+
+    def __init__(self, stream):
+        self.annotations = []
+        self.catalog = []
+
+        self.stream = stream
+        self.initpos = stream.tell()
+        self.reset()
+        self.symbols = SymbolTable()
+
+    def reset(self):
+        self.state = ParserState.BeforeTID
+        self.needhasnext = True
+        self.localremaining = -1
+        self.eof = False
+        self.isinstruct = False
+        self.containerstack = []
+        self.stream.seek(self.initpos)
+
+    def addtocatalog(self, name, version, symbols):
+        self.catalog.append(IonCatalogItem(name, version, symbols))
+
+    def hasnext(self):
+        while self.needhasnext and not self.eof:
+            self.hasnextraw()
+            if len(self.containerstack) == 0 and not self.valueisnull:
+                if self.valuetid == TID_SYMBOL:
+                    if self.value == SID_ION_1_0:
+                        self.needhasnext = True
+                elif self.valuetid == TID_STRUCT:
+                    for a in self.annotations:
+                        if a == SID_ION_SYMBOL_TABLE:
+                            self.parsesymboltable()
+                            self.needhasnext = True
+                            break
+        return not self.eof
+
+    def hasnextraw(self):
+        self.clearvalue()
+        while self.valuetid == -1 and not self.eof:
+            self.needhasnext = False
+            if self.state == ParserState.BeforeField:
+                _assert(self.valuefieldid == SID_UNKNOWN)
+
+                self.valuefieldid = self.readfieldid()
+                if self.valuefieldid != SID_UNKNOWN:
+                    self.state = ParserState.BeforeTID
+                else:
+                    self.eof = True
+
+            elif self.state == ParserState.BeforeTID:
+                self.state = ParserState.BeforeValue
+                self.valuetid = self.readtypeid()
+                if self.valuetid == -1:
+                    self.state = ParserState.EOF
+                    self.eof = True
+                    break
+
+                if self.valuetid == TID_TYPEDECL:
+                    if self.valuelen == 0:
+                        self.checkversionmarker()
+                    else:
+                        self.loadannotations()
+
+            elif self.state == ParserState.BeforeValue:
+                self.skip(self.valuelen)
+                self.state = ParserState.AfterValue
+
+            elif self.state == ParserState.AfterValue:
+                if self.isinstruct:
+                    self.state = ParserState.BeforeField
+                else:
+                    self.state = ParserState.BeforeTID
+
+            else:
+                _assert(self.state == ParserState.EOF)
+
+    def next(self):
+        if self.hasnext():
+            self.needhasnext = True
+            return self.valuetid
+        else:
+            return -1
+
+    def push(self, typeid, nextposition, nextremaining):
+        self.containerstack.append(ContainerRec(nextpos=nextposition, tid=typeid, remaining=nextremaining))
+
+    def stepin(self):
+        _assert(self.valuetid in [TID_STRUCT, TID_LIST, TID_SEXP] and not self.eof,
+                "valuetid=%s eof=%s" % (self.valuetid, self.eof))
+        _assert((not self.valueisnull or self.state == ParserState.AfterValue) and
+               (self.valueisnull or self.state == ParserState.BeforeValue))
+
+        nextrem = self.localremaining
+        if nextrem != -1:
+            nextrem -= self.valuelen
+            if nextrem < 0:
+                nextrem = 0
+        self.push(self.parenttid, self.stream.tell() + self.valuelen, nextrem)
+
+        self.isinstruct = (self.valuetid == TID_STRUCT)
+        if self.isinstruct:
+            self.state = ParserState.BeforeField
+        else:
+            self.state = ParserState.BeforeTID
+
+        self.localremaining = self.valuelen
+        self.parenttid = self.valuetid
+        self.clearvalue()
+        self.needhasnext = True
+
+    def stepout(self):
+        rec = self.containerstack.pop()
+
+        self.eof = False
+        self.parenttid = rec.tid
+        if self.parenttid == TID_STRUCT:
+            self.isinstruct = True
+            self.state = ParserState.BeforeField
+        else:
+            self.isinstruct = False
+            self.state = ParserState.BeforeTID
+        self.needhasnext = True
+
+        self.clearvalue()
+        curpos = self.stream.tell()
+        if rec.nextpos > curpos:
+            self.skip(rec.nextpos - curpos)
+        else:
+            _assert(rec.nextpos == curpos)
+
+        self.localremaining = rec.remaining
+
+    def read(self, count=1):
+        if self.localremaining != -1:
+            self.localremaining -= count
+            _assert(self.localremaining >= 0)
+
+        result = self.stream.read(count)
+        if len(result) == 0:
+            raise EOFError()
+        return result
+
+    def readfieldid(self):
+        if self.localremaining != -1 and self.localremaining < 1:
+            return -1
+
+        try:
+            return self.readvaruint()
+        except EOFError:
+            return -1
+
+    def readtypeid(self):
+        if self.localremaining != -1:
+            if self.localremaining < 1:
+                return -1
+            self.localremaining -= 1
+
+        b = self.stream.read(1)
+        if len(b) < 1:
+            return -1
+        b = bord(b)
+        result = b >> 4
+        ln = b & 0xF
+
+        if ln == LEN_IS_VAR_LEN:
+            ln = self.readvaruint()
+        elif ln == LEN_IS_NULL:
+            ln = 0
+            self.state = ParserState.AfterValue
+        elif result == TID_NULL:
+            # Must have LEN_IS_NULL
+            _assert(False)
+        elif result == TID_BOOLEAN:
+            _assert(ln <= 1)
+            self.valueistrue = (ln == 1)
+            ln = 0
+            self.state = ParserState.AfterValue
+        elif result == TID_STRUCT:
+            if ln == 1:
+                ln = self.readvaruint()
+
+        self.valuelen = ln
+        return result
+
+    def readvarint(self):
+        b = bord(self.read())
+        negative = ((b & 0x40) != 0)
+        result = (b & 0x3F)
+
+        i = 0
+        while (b & 0x80) == 0 and i < 4:
+            b = bord(self.read())
+            result = (result << 7) | (b & 0x7F)
+            i += 1
+
+        _assert(i < 4 or (b & 0x80) != 0, "int overflow")
+
+        if negative:
+            return -result
+        return result
+
+    def readvaruint(self):
+        b = bord(self.read())
+        result = (b & 0x7F)
+
+        i = 0
+        while (b & 0x80) == 0 and i < 4:
+            b = bord(self.read())
+            result = (result << 7) | (b & 0x7F)
+            i += 1
+
+        _assert(i < 4 or (b & 0x80) != 0, "int overflow")
+
+        return result
+
+    def readdecimal(self):
+        if self.valuelen == 0:
+            return 0.
+
+        rem = self.localremaining - self.valuelen
+        self.localremaining = self.valuelen
+        exponent = self.readvarint()
+
+        _assert(self.localremaining > 0, "Only exponent in ReadDecimal")
+        _assert(self.localremaining <= 8, "Decimal overflow")
+
+        signed = False
+        b = [bord(x) for x in self.read(self.localremaining)]
+        if (b[0] & 0x80) != 0:
+            b[0] = b[0] & 0x7F
+            signed = True
+
+        # Convert variably sized network order integer into 64-bit little endian
+        j = 0
+        vb = [0] * 8
+        for i in range(len(b), -1, -1):
+            vb[i] = b[j]
+            j += 1
+
+        v = struct.unpack("<Q", b"".join(bchr(x) for x in vb))[0]
+
+        result = v * (10 ** exponent)
+        if signed:
+            result = -result
+
+        self.localremaining = rem
+        return result
+
+    def skip(self, count):
+        if self.localremaining != -1:
+            self.localremaining -= count
+            if self.localremaining < 0:
+                raise EOFError()
+
+        self.stream.seek(count, os.SEEK_CUR)
+
+    def parsesymboltable(self):
+        self.next() # shouldn't do anything?
+
+        _assert(self.valuetid == TID_STRUCT)
+
+        if self.didimports:
+            return
+
+        self.stepin()
+
+        fieldtype = self.next()
+        while fieldtype != -1:
+            if not self.valueisnull:
+                _assert(self.valuefieldid == SID_IMPORTS, "Unsupported symbol table field id")
+
+                if fieldtype == TID_LIST:
+                    self.gatherimports()
+
+            fieldtype = self.next()
+
+        self.stepout()
+        self.didimports = True
+
+    def gatherimports(self):
+        self.stepin()
+
+        t = self.next()
+        while t != -1:
+            if not self.valueisnull and t == TID_STRUCT:
+                self.readimport()
+
+            t = self.next()
+
+        self.stepout()
+
+    def readimport(self):
+        version = -1
+        maxid = -1
+        name = ""
+
+        self.stepin()
+
+        t = self.next()
+        while t != -1:
+            if not self.valueisnull and self.valuefieldid != SID_UNKNOWN:
+                if self.valuefieldid == SID_NAME:
+                    name = self.stringvalue()
+                elif self.valuefieldid == SID_VERSION:
+                    version = self.intvalue()
+                elif self.valuefieldid == SID_MAX_ID:
+                    maxid = self.intvalue()
+
+            t = self.next()
+
+        self.stepout()
+
+        if name == "" or name == SystemSymbols.ION:
+            return
+
+        if version < 1:
+            version = 1
+
+        table = self.findcatalogitem(name)
+        if maxid < 0:
+            _assert(table is not None and version == table.version, "Import %s lacks maxid" % name)
+            maxid = len(table.symnames)
+
+        if table is not None:
+            self.symbols.import_(table, min(maxid, len(table.symnames)))
+        else:
+            self.symbols.importunknown(name, maxid)
+
+    def intvalue(self):
+        _assert(self.valuetid in [TID_POSINT, TID_NEGINT], "Not an int")
+
+        self.preparevalue()
+        return self.value
+
+    def stringvalue(self):
+        _assert(self.valuetid == TID_STRING, "Not a string")
+
+        if self.valueisnull:
+            return ""
+
+        self.preparevalue()
+        return self.value
+
+    def symbolvalue(self):
+        _assert(self.valuetid == TID_SYMBOL, "Not a symbol")
+
+        self.preparevalue()
+        result = self.symbols.findbyid(self.value)
+        if result == "":
+            result = "SYMBOL#%d" % self.value
+        return result
+
+    def lobvalue(self):
+        _assert(self.valuetid in [TID_CLOB, TID_BLOB], "Not a LOB type: %s" % self.getfieldname())
+
+        if self.valueisnull:
+            return None
+
+        result = self.read(self.valuelen)
+        self.state = ParserState.AfterValue
+        return result
+
+    def decimalvalue(self):
+        _assert(self.valuetid == TID_DECIMAL, "Not a decimal")
+
+        self.preparevalue()
+        return self.value
+
+    def preparevalue(self):
+        if self.value is None:
+            self.loadscalarvalue()
+
+    def loadscalarvalue(self):
+        if self.valuetid not in [TID_NULL, TID_BOOLEAN, TID_POSINT, TID_NEGINT,
+                                 TID_FLOAT, TID_DECIMAL, TID_TIMESTAMP,
+                                 TID_SYMBOL, TID_STRING]:
+            return
+
+        if self.valueisnull:
+            self.value = None
+            return
+
+        if self.valuetid == TID_STRING:
+            self.value = self.read(self.valuelen).decode("UTF-8")
+
+        elif self.valuetid in (TID_POSINT, TID_NEGINT, TID_SYMBOL):
+            if self.valuelen == 0:
+                self.value = 0
+            else:
+                _assert(self.valuelen <= 4, "int too long: %d" % self.valuelen)
+                v = 0
+                for i in range(self.valuelen - 1, -1, -1):
+                    v = (v | (bord(self.read()) << (i * 8)))
+
+                if self.valuetid == TID_NEGINT:
+                    self.value = -v
+                else:
+                    self.value = v
+
+        elif self.valuetid == TID_DECIMAL:
+            self.value = self.readdecimal()
+
+        #else:
+        #    _assert(False, "Unhandled scalar type %d" % self.valuetid)
+
+        self.state = ParserState.AfterValue
+
+    def clearvalue(self):
+        self.valuetid = -1
+        self.value = None
+        self.valueisnull = False
+        self.valuefieldid = SID_UNKNOWN
+        self.annotations = []
+
+    def loadannotations(self):
+        ln = self.readvaruint()
+        maxpos = self.stream.tell() + ln
+        while self.stream.tell() < maxpos:
+            self.annotations.append(self.readvaruint())
+        self.valuetid = self.readtypeid()
+
+    def checkversionmarker(self):
+        for i in VERSION_MARKER:
+            _assert(self.read() == i, "Unknown version marker")
+
+        self.valuelen = 0
+        self.valuetid = TID_SYMBOL
+        self.value = SID_ION_1_0
+        self.valueisnull = False
+        self.valuefieldid = SID_UNKNOWN
+        self.state = ParserState.AfterValue
+
+    def findcatalogitem(self, name):
+        for result in self.catalog:
+            if result.name == name:
+                return result
+
+    def forceimport(self, symbols):
+        item = IonCatalogItem("Forced", 1, symbols)
+        self.symbols.import_(item, len(symbols))
+
+    def getfieldname(self):
+        if self.valuefieldid == SID_UNKNOWN:
+            return ""
+        return self.symbols.findbyid(self.valuefieldid)
+
+    def getfieldnamesymbol(self):
+        return SymbolToken(self.getfieldname(), self.valuefieldid)
+
+    def gettypename(self):
+        if len(self.annotations) == 0:
+            return ""
+
+        return self.symbols.findbyid(self.annotations[0])
+
+    @staticmethod
+    def printlob(b):
+        if b is None:
+            return "null"
+
+        result = ""
+        for i in b:
+            result += ("%02x " % bord(i))
+
+        if len(result) > 0:
+            result = result[:-1]
+        return result
+
+    def ionwalk(self, supert, indent, lst):
+        while self.hasnext():
+            if supert == TID_STRUCT:
+                L = self.getfieldname() + ":"
+            else:
+                L = ""
+
+            t = self.next()
+            if t in [TID_STRUCT, TID_LIST]:
+                if L != "":
+                    lst.append(indent + L)
+                L = self.gettypename()
+                if L != "":
+                    lst.append(indent + L + "::")
+                if t == TID_STRUCT:
+                    lst.append(indent + "{")
+                else:
+                    lst.append(indent + "[")
+
+                self.stepin()
+                self.ionwalk(t, indent + "  ", lst)
+                self.stepout()
+
+                if t == TID_STRUCT:
+                    lst.append(indent + "}")
+                else:
+                    lst.append(indent + "]")
+
+            else:
+                if t == TID_STRING:
+                    L += ('"%s"' % self.stringvalue())
+                elif t in [TID_CLOB, TID_BLOB]:
+                    L += ("{%s}" % self.printlob(self.lobvalue()))
+                elif t == TID_POSINT:
+                    L += str(self.intvalue())
+                elif t == TID_SYMBOL:
+                    tn = self.gettypename()
+                    if tn != "":
+                        tn += "::"
+                    L += tn + self.symbolvalue()
+                elif t == TID_DECIMAL:
+                    L += str(self.decimalvalue())
+                else:
+                    L += ("TID %d" % t)
+                lst.append(indent + L)
+
+    def print_(self, lst):
+        self.reset()
+        self.ionwalk(-1, "", lst)
+
+
+SYM_NAMES = [ 'com.amazon.drm.Envelope@1.0',
+              'com.amazon.drm.EnvelopeMetadata@1.0', 'size', 'page_size',
+              'encryption_key', 'encryption_transformation',
+              'encryption_voucher', 'signing_key', 'signing_algorithm',
+              'signing_voucher', 'com.amazon.drm.EncryptedPage@1.0',
+              'cipher_text', 'cipher_iv', 'com.amazon.drm.Signature@1.0',
+              'data', 'com.amazon.drm.EnvelopeIndexTable@1.0', 'length',
+              'offset', 'algorithm', 'encoded', 'encryption_algorithm',
+              'hashing_algorithm', 'expires', 'format', 'id',
+              'lock_parameters', 'strategy', 'com.amazon.drm.Key@1.0',
+              'com.amazon.drm.KeySet@1.0', 'com.amazon.drm.PIDv3@1.0',
+              'com.amazon.drm.PlainTextPage@1.0',
+              'com.amazon.drm.PlainText@1.0', 'com.amazon.drm.PrivateKey@1.0',
+              'com.amazon.drm.PublicKey@1.0', 'com.amazon.drm.SecretKey@1.0',
+              'com.amazon.drm.Voucher@1.0', 'public_key', 'private_key',
+              'com.amazon.drm.KeyPair@1.0', 'com.amazon.drm.ProtectedData@1.0',
+              'doctype', 'com.amazon.drm.EnvelopeIndexTableOffset@1.0',
+              'enddoc', 'license_type', 'license', 'watermark', 'key', 'value',
+              'com.amazon.drm.License@1.0', 'category', 'metadata',
+              'categorized_metadata', 'com.amazon.drm.CategorizedMetadata@1.0',
+              'com.amazon.drm.VoucherEnvelope@1.0', 'mac', 'voucher',
+              'com.amazon.drm.ProtectedData@2.0',
+              'com.amazon.drm.Envelope@2.0',
+              'com.amazon.drm.EnvelopeMetadata@2.0',
+              'com.amazon.drm.EncryptedPage@2.0',
+              'com.amazon.drm.PlainText@2.0', 'compression_algorithm',
+              'com.amazon.drm.Compressed@1.0', 'priority', 'refines']
+
+def addprottable(ion):
+    ion.addtocatalog("ProtectedData", 1, SYM_NAMES)
+
+
+def pkcs7pad(msg, blocklen):
+    paddinglen = blocklen - len(msg) % blocklen
+    padding = bchr(paddinglen) * paddinglen
+    return msg + padding
+
+
+def pkcs7unpad(msg, blocklen):
+    _assert(len(msg) % blocklen == 0)
+
+    paddinglen = bord(msg[-1])
+    _assert(paddinglen > 0 and paddinglen <= blocklen, "Incorrect padding - Wrong key")
+    _assert(msg[-paddinglen:] == bchr(paddinglen) * paddinglen, "Incorrect padding - Wrong key")
+
+    return msg[:-paddinglen]
+
+
+class DrmIonVoucher(object):
+    envelope = None
+    voucher = None
+    drmkey = None
+    license_type = "Unknown"
+
+    encalgorithm = ""
+    enctransformation = ""
+    hashalgorithm = ""
+
+    lockparams = None
+
+    ciphertext = b""
+    cipheriv = b""
+    secretkey = b""
+
+    def __init__(self, voucherenv, dsn, secret):
+        self.dsn,self.secret = dsn,secret
+
+        self.lockparams = []
+
+        self.envelope = BinaryIonParser(voucherenv)
+        addprottable(self.envelope)
+
+    def decryptvoucher(self):
+        shared = "PIDv3" + self.encalgorithm + self.enctransformation + self.hashalgorithm
+
+        self.lockparams.sort()
+        for param in self.lockparams:
+            if param == "ACCOUNT_SECRET":
+                shared += param + self.secret
+            elif param == "CLIENT_ID":
+                shared += param + self.dsn
+            else:
+                _assert(False, "Unknown lock parameter: %s" % param)
+
+        sharedsecret = shared.encode("UTF-8")
+
+        key = hmac.new(sharedsecret, sharedsecret[:5], digestmod=hashlib.sha256).digest()
+        aes = AES.new(key[:32], AES.MODE_CBC, self.cipheriv[:16])
+        b = aes.decrypt(self.ciphertext)
+        b = pkcs7unpad(b, 16)
+
+        self.drmkey = BinaryIonParser(StringIO(b))
+        addprottable(self.drmkey)
+
+        _assert(self.drmkey.hasnext() and self.drmkey.next() == TID_LIST and self.drmkey.gettypename() == "com.amazon.drm.KeySet@1.0",
+                "Expected KeySet, got %s" % self.drmkey.gettypename())
+
+        self.drmkey.stepin()
+        while self.drmkey.hasnext():
+            self.drmkey.next()
+            if self.drmkey.gettypename() != "com.amazon.drm.SecretKey@1.0":
+                continue
+
+            self.drmkey.stepin()
+            while self.drmkey.hasnext():
+                self.drmkey.next()
+                if self.drmkey.getfieldname() == "algorithm":
+                    _assert(self.drmkey.stringvalue() == "AES", "Unknown cipher algorithm: %s" % self.drmkey.stringvalue())
+                elif self.drmkey.getfieldname() == "format":
+                    _assert(self.drmkey.stringvalue() == "RAW", "Unknown key format: %s" % self.drmkey.stringvalue())
+                elif self.drmkey.getfieldname() == "encoded":
+                    self.secretkey = self.drmkey.lobvalue()
+
+            self.drmkey.stepout()
+            break
+
+        self.drmkey.stepout()
+
+    def parse(self):
+        self.envelope.reset()
+        _assert(self.envelope.hasnext(), "Envelope is empty")
+        _assert(self.envelope.next() == TID_STRUCT and self.envelope.gettypename() == "com.amazon.drm.VoucherEnvelope@1.0",
+                "Unknown type encountered in envelope, expected VoucherEnvelope")
+
+        self.envelope.stepin()
+        while self.envelope.hasnext():
+            self.envelope.next()
+            field = self.envelope.getfieldname()
+            if field == "voucher":
+                self.voucher = BinaryIonParser(StringIO(self.envelope.lobvalue()))
+                addprottable(self.voucher)
+                continue
+            elif field != "strategy":
+                continue
+
+            _assert(self.envelope.gettypename() == "com.amazon.drm.PIDv3@1.0", "Unknown strategy: %s" % self.envelope.gettypename())
+
+            self.envelope.stepin()
+            while self.envelope.hasnext():
+                self.envelope.next()
+                field = self.envelope.getfieldname()
+                if field == "encryption_algorithm":
+                    self.encalgorithm = self.envelope.stringvalue()
+                elif field == "encryption_transformation":
+                    self.enctransformation = self.envelope.stringvalue()
+                elif field == "hashing_algorithm":
+                    self.hashalgorithm = self.envelope.stringvalue()
+                elif field == "lock_parameters":
+                    self.envelope.stepin()
+                    while self.envelope.hasnext():
+                        _assert(self.envelope.next() == TID_STRING, "Expected string list for lock_parameters")
+                        self.lockparams.append(self.envelope.stringvalue())
+                    self.envelope.stepout()
+
+            self.envelope.stepout()
+
+        self.parsevoucher()
+
+    def parsevoucher(self):
+        _assert(self.voucher.hasnext(), "Voucher is empty")
+        _assert(self.voucher.next() == TID_STRUCT and self.voucher.gettypename() == "com.amazon.drm.Voucher@1.0",
+                "Unknown type, expected Voucher")
+
+        self.voucher.stepin()
+        while self.voucher.hasnext():
+            self.voucher.next()
+
+            if self.voucher.getfieldname() == "cipher_iv":
+                self.cipheriv = self.voucher.lobvalue()
+            elif self.voucher.getfieldname() == "cipher_text":
+                self.ciphertext = self.voucher.lobvalue()
+            elif self.voucher.getfieldname() == "license":
+                _assert(self.voucher.gettypename() == "com.amazon.drm.License@1.0",
+                        "Unknown license: %s" % self.voucher.gettypename())
+                self.voucher.stepin()
+                while self.voucher.hasnext():
+                    self.voucher.next()
+                    if self.voucher.getfieldname() == "license_type":
+                        self.license_type = self.voucher.stringvalue()
+                self.voucher.stepout()
+
+    def printenvelope(self, lst):
+        self.envelope.print_(lst)
+
+    def printkey(self, lst):
+        if self.voucher is None:
+            self.parse()
+        if self.drmkey is None:
+            self.decryptvoucher()
+
+        self.drmkey.print_(lst)
+
+    def printvoucher(self, lst):
+        if self.voucher is None:
+            self.parse()
+
+        self.voucher.print_(lst)
+
+    def getlicensetype(self):
+        return self.license_type
+
+
+class DrmIon(object):
+    ion = None
+    voucher = None
+    vouchername = ""
+    key = b""
+    onvoucherrequired = None
+
+    def __init__(self, ionstream, onvoucherrequired):
+        self.ion = BinaryIonParser(ionstream)
+        addprottable(self.ion)
+        self.onvoucherrequired = onvoucherrequired
+
+    def parse(self, outpages):
+        self.ion.reset()
+
+        _assert(self.ion.hasnext(), "DRMION envelope is empty")
+        _assert(self.ion.next() == TID_SYMBOL and self.ion.gettypename() == "doctype", "Expected doctype symbol")
+        _assert(self.ion.next() == TID_LIST and self.ion.gettypename() in ["com.amazon.drm.Envelope@1.0", "com.amazon.drm.Envelope@2.0"],
+                "Unknown type encountered in DRMION envelope, expected Envelope, got %s" % self.ion.gettypename())
+
+        while True:
+            if self.ion.gettypename() == "enddoc":
+                break
+
+            self.ion.stepin()
+            while self.ion.hasnext():
+                self.ion.next()
+
+                if self.ion.gettypename() in ["com.amazon.drm.EnvelopeMetadata@1.0", "com.amazon.drm.EnvelopeMetadata@2.0"]:
+                    self.ion.stepin()
+                    while self.ion.hasnext():
+                        self.ion.next()
+                        if self.ion.getfieldname() != "encryption_voucher":
+                            continue
+
+                        if self.vouchername == "":
+                            self.vouchername = self.ion.stringvalue()
+                            self.voucher = self.onvoucherrequired(self.vouchername)
+                            self.key = self.voucher.secretkey
+                            _assert(self.key is not None, "Unable to obtain secret key from voucher")
+                        else:
+                            _assert(self.vouchername == self.ion.stringvalue(),
+                                    "Unexpected: Different vouchers required for same file?")
+
+                    self.ion.stepout()
+
+                elif self.ion.gettypename() in ["com.amazon.drm.EncryptedPage@1.0", "com.amazon.drm.EncryptedPage@2.0"]:
+                    decompress = False
+                    ct = None
+                    civ = None
+                    self.ion.stepin()
+                    while self.ion.hasnext():
+                        self.ion.next()
+                        if self.ion.gettypename() == "com.amazon.drm.Compressed@1.0":
+                            decompress = True
+                        if self.ion.getfieldname() == "cipher_text":
+                            ct = self.ion.lobvalue()
+                        elif self.ion.getfieldname() == "cipher_iv":
+                            civ = self.ion.lobvalue()
+
+                    if ct is not None and civ is not None:
+                        self.processpage(ct, civ, outpages, decompress)
+                    self.ion.stepout()
+
+            self.ion.stepout()
+            if not self.ion.hasnext():
+                break
+            self.ion.next()
+
+    def print_(self, lst):
+        self.ion.print_(lst)
+
+    def processpage(self, ct, civ, outpages, decompress):
+        aes = AES.new(self.key[:16], AES.MODE_CBC, civ[:16])
+        msg = pkcs7unpad(aes.decrypt(ct), 16)
+
+        if not decompress:
+            outpages.write(msg)
+            return
+
+        _assert(msg[0] == b"\x00", "LZMA UseFilter not supported")
+
+        if calibre_lzma is not None:
+            with calibre_lzma.decompress(msg[1:], bufsize=0x1000000) as f:
+                f.seek(0)
+                outpages.write(f.read())
+            return
+
+        decomp = lzma.LZMADecompressor(format=lzma.FORMAT_ALONE)
+        while not decomp.eof:
+            segment = decomp.decompress(msg[1:])
+            msg = b"" # Contents were internally buffered after the first call
+            outpages.write(segment)