From 45c043301ef2802e167c514785c6b72f76643db2 Mon Sep 17 00:00:00 2001 From: Caio Dottori Date: Thu, 7 Oct 2021 17:03:55 -0300 Subject: [PATCH] Refactor DER and binary handling structures for better readability and easier translations --- ellipticcurve/__init__.py | 7 +- ellipticcurve/curve.py | 20 +- ellipticcurve/ecdsa.py | 24 +- ellipticcurve/math.py | 68 ++--- ellipticcurve/privateKey.py | 97 +++---- ellipticcurve/publicKey.py | 114 ++++---- ellipticcurve/signature.py | 43 +-- ellipticcurve/utils/base.py | 12 - ellipticcurve/utils/binary.py | 87 +++--- ellipticcurve/utils/compatibility.py | 25 +- ellipticcurve/utils/der.py | 388 +++++++++++---------------- ellipticcurve/utils/oid.py | 35 +++ ellipticcurve/utils/pem.py | 14 + tests/testEcdsa.py | 3 +- tests/testOpenSSL.py | 20 +- tests/testPrivateKey.py | 7 +- tests/testPublicKey.py | 4 +- tests/testRandom.py | 23 ++ tests/testSignature.py | 9 +- tests/testSignatureWithRecoveryId.py | 9 +- 20 files changed, 456 insertions(+), 553 deletions(-) delete mode 100644 ellipticcurve/utils/base.py create mode 100644 ellipticcurve/utils/oid.py create mode 100644 ellipticcurve/utils/pem.py create mode 100644 tests/testRandom.py diff --git a/ellipticcurve/__init__.py b/ellipticcurve/__init__.py index 13ecac1..9757fdf 100644 --- a/ellipticcurve/__init__.py +++ b/ellipticcurve/__init__.py @@ -1 +1,6 @@ -from ellipticcurve.utils.compatibility import * \ No newline at end of file +from ellipticcurve.utils.compatibility import * +from ellipticcurve.privateKey import PrivateKey +from ellipticcurve.publicKey import PublicKey +from ellipticcurve.signature import Signature +from ellipticcurve.utils.file import File +from ellipticcurve.ecdsa import Ecdsa diff --git a/ellipticcurve/curve.py b/ellipticcurve/curve.py index b3d2667..6f9efa7 100644 --- a/ellipticcurve/curve.py +++ b/ellipticcurve/curve.py @@ -17,7 +17,7 @@ def __init__(self, A, B, P, N, Gx, Gy, name, oid, nistName=None): self.G = Point(Gx, Gy) self.name = name self.nistName = nistName - self.oid = oid + self.oid = oid # ASN.1 Object Identifier def contains(self, p): """ @@ -40,7 +40,7 @@ def length(self): N=0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141, Gx=0x79be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798, Gy=0x483ada7726a3c4655da4fbfc0e1108a8fd17b448a68554199c47d08ffb10d4b8, - oid=(1, 3, 132, 0, 10) + oid=[1, 3, 132, 0, 10] ) prime256v1 = CurveFp( @@ -52,8 +52,9 @@ def length(self): N=0xffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551, Gx=0x6b17d1f2e12c4247f8bce6e563a440f277037d812deb33a0f4a13945d898c296, Gy=0x4fe342e2fe1a7f9b8ee7eb4a7c0f9e162bce33576b315ececbb6406837bf51f5, - oid=(1, 2, 840, 10045, 3, 1, 7), + oid=[1, 2, 840, 10045, 3, 1, 7], ) + p256 = prime256v1 supportedCurves = [ @@ -61,4 +62,15 @@ def length(self): prime256v1, ] -curvesByOid = {curve.oid: curve for curve in supportedCurves} +_curvesByOid = {tuple(curve.oid): curve for curve in supportedCurves} + + +def getCurveByOid(oid): + if oid not in _curvesByOid: + raise Exception( + "Unknown curve with oid %s; The following are registered: %s" % ( + ".".join(oid), + ", ".join([curve.name for curve in supportedCurves]) + ) + ) + return _curvesByOid[oid] diff --git a/ellipticcurve/ecdsa.py b/ellipticcurve/ecdsa.py index a7876f8..24e44f7 100644 --- a/ellipticcurve/ecdsa.py +++ b/ellipticcurve/ecdsa.py @@ -1,8 +1,8 @@ from hashlib import sha256 from .signature import Signature from .math import Math -from .utils.binary import BinaryAscii from .utils.integer import RandomInteger +from .utils.binary import numberFromByteString from .utils.compatibility import * @@ -10,8 +10,8 @@ class Ecdsa: @classmethod def sign(cls, message, privateKey, hashfunc=sha256): - hashMessage = hashfunc(toBytes(message)).digest() - numberMessage = BinaryAscii.numberFromString(hashMessage) + byteMessage = hashfunc(toBytes(message)).digest() + numberMessage = numberFromByteString(byteMessage) curve = privateKey.curve r, s, randSignPoint = 0, 0, None @@ -28,14 +28,14 @@ def sign(cls, message, privateKey, hashfunc=sha256): @classmethod def verify(cls, message, signature, publicKey, hashfunc=sha256): - hashMessage = hashfunc(toBytes(message)).digest() - numberMessage = BinaryAscii.numberFromString(hashMessage) + byteMessage = hashfunc(toBytes(message)).digest() + numberMessage = numberFromByteString(byteMessage) curve = publicKey.curve - sigR = signature.r - sigS = signature.s - inv = Math.inv(sigS, curve.N) - u1 = Math.multiply(curve.G, n=(numberMessage * inv) % curve.N, A=curve.A, P=curve.P, N=curve.N) - u2 = Math.multiply(publicKey.point, n=(sigR * inv) % curve.N, A=curve.A, P=curve.P, N=curve.N) - add = Math.add(u1, u2, P=curve.P, A=curve.A) + r = signature.r + s = signature.s + inv = Math.inv(s, curve.N) + u1 = Math.multiply(curve.G, n=(numberMessage * inv) % curve.N, N=curve.N, A=curve.A, P=curve.P) + u2 = Math.multiply(publicKey.point, n=(r * inv) % curve.N, N=curve.N, A=curve.A, P=curve.P) + add = Math.add(u1, u2, A=curve.A, P=curve.P) modX = add.x % curve.N - return sigR == modX + return r == modX diff --git a/ellipticcurve/math.py b/ellipticcurve/math.py index 146c4d7..c779227 100644 --- a/ellipticcurve/math.py +++ b/ellipticcurve/math.py @@ -16,14 +16,7 @@ def multiply(cls, p, n, N, A, P): :return: Point that represents the sum of First and Second Point """ return cls._fromJacobian( - cls._jacobianMultiply( - cls._toJacobian(p), - n, - N, - A, - P, - ), - P, + cls._jacobianMultiply(cls._toJacobian(p), n, N, A, P), P ) @classmethod @@ -38,13 +31,7 @@ def add(cls, p, q, A, P): :return: Point that represents the sum of First and Second Point """ return cls._fromJacobian( - cls._jacobianAdd( - cls._toJacobian(p), - cls._toJacobian(q), - A, - P, - ), - P, + cls._jacobianAdd(cls._toJacobian(p), cls._toJacobian(q), A, P), P, ) @classmethod @@ -59,12 +46,19 @@ def inv(cls, x, n): if x == 0: return 0 - lm, hm = 1, 0 - low, high = x % n, n + lm = 1 + hm = 0 + low = x % n + high = n + while low > 1: r = high // low - nm, new = hm - lm * r, high - low * r - lm, low, hm, high = nm, new, lm, low + nm = hm - lm * r + nw = high - low * r + high = low + hm = lm + low = nw + lm = nm return lm % n @@ -88,11 +82,10 @@ def _fromJacobian(cls, p, P): :return: Point in default coordinates """ z = cls.inv(p.z, P) + x = (p.x * z ** 2) % P + y = (p.y * z ** 3) % P - return Point( - (p.x * z ** 2) % P, - (p.y * z ** 3) % P, - ) + return Point(x, y, 0) @classmethod def _jacobianDouble(cls, p, A, P): @@ -113,6 +106,7 @@ def _jacobianDouble(cls, p, A, P): nx = (M**2 - 2 * S) % P ny = (M * (S - nx) - 8 * ysq ** 2) % P nz = (2 * p.y * p.z) % P + return Point(nx, ny, nz) @classmethod @@ -126,9 +120,9 @@ def _jacobianAdd(cls, p, q, A, P): :param A: Coefficient of the first-order term of the equation Y^2 = X^3 + A*X + B (mod p) :return: Point that represents the sum of First and Second Point """ - if not p.y: return q + if not q.y: return p @@ -176,31 +170,9 @@ def _jacobianMultiply(cls, p, n, N, A, P): if (n % 2) == 0: return cls._jacobianDouble( - cls._jacobianMultiply( - p, - n // 2, - N, - A, - P - ), - A, - P, + cls._jacobianMultiply(p, n // 2, N, A, P), A, P ) - # (n % 2) == 1: return cls._jacobianAdd( - cls._jacobianDouble( - cls._jacobianMultiply( - p, - n // 2, - N, - A, - P, - ), - A, - P, - ), - p, - A, - P, + cls._jacobianDouble(cls._jacobianMultiply(p, n // 2, N, A, P), A, P), p, A, P ) diff --git a/ellipticcurve/privateKey.py b/ellipticcurve/privateKey.py index 3ada20e..2820357 100644 --- a/ellipticcurve/privateKey.py +++ b/ellipticcurve/privateKey.py @@ -1,12 +1,10 @@ +from .math import Math from .utils.integer import RandomInteger -from .utils.compatibility import * -from .utils.binary import BinaryAscii -from .utils.der import fromPem, removeSequence, removeInteger, removeObject, removeOctetString, removeConstructed, toPem, encodeSequence, encodeInteger, encodeBitString, encodeOid, encodeOctetString, encodeConstructed +from .utils.pem import getPemContent, createPem +from .utils.binary import hexFromByteString, byteStringFromHex, intFromHex, base64FromByteString, byteStringFromBase64 +from .utils.der import hexFromInt, parse, encodeConstructed, DerFieldType, encodePrimitive +from .curve import secp256k1, getCurveByOid from .publicKey import PublicKey -from .curve import secp256k1, curvesByOid, supportedCurves -from .math import Math - -hexAt = "\x00" class PrivateKey: @@ -27,69 +25,48 @@ def publicKey(self): return PublicKey(point=publicPoint, curve=curve) def toString(self): - return BinaryAscii.stringFromNumber(number=self.secret, length=self.curve.length()) + return hexFromInt(self.secret) def toDer(self): - encodedPublicKey = self.publicKey().toString(encoded=True) - - return encodeSequence( - encodeInteger(1), - encodeOctetString(self.toString()), - encodeConstructed(0, encodeOid(*self.curve.oid)), - encodeConstructed(1, encodeBitString(encodedPublicKey)), + publicKeyString = self.publicKey().toString(encoded=True) + hexadecimal = encodeConstructed( + encodePrimitive(DerFieldType.integer, 1), + encodePrimitive(DerFieldType.octetString, hexFromInt(self.secret)), + encodePrimitive(DerFieldType.oidContainer, encodePrimitive(DerFieldType.object, self.curve.oid)), + encodePrimitive(DerFieldType.publicKeyPointContainer, encodePrimitive(DerFieldType.bitString, publicKeyString)) ) + return byteStringFromHex(hexadecimal) def toPem(self): - return toPem(der=toBytes(self.toDer()), name="EC PRIVATE KEY") + der = self.toDer() + return createPem(content=base64FromByteString(der), template=_pemTemplate) @classmethod def fromPem(cls, string): - privateKeyPem = string[string.index("-----BEGIN EC PRIVATE KEY-----"):] - return cls.fromDer(fromPem(privateKeyPem)) + privateKeyPem = getPemContent(pem=string, template=_pemTemplate) + return cls.fromDer(byteStringFromBase64(privateKeyPem)) @classmethod - def fromDer(cls, string): - t, empty = removeSequence(string) - if len(empty) != 0: - raise Exception( - "trailing junk after DER private key: " + - BinaryAscii.hexFromBinary(empty) - ) - - one, t = removeInteger(t) - if one != 1: - raise Exception( - "expected '1' at start of DER private key, got %d" % one - ) - - privateKeyStr, t = removeOctetString(t) - tag, curveOidStr, t = removeConstructed(t) - if tag != 0: - raise Exception("expected tag 0 in DER private key, got %d" % tag) - - oidCurve, empty = removeObject(curveOidStr) - - if len(empty) != 0: - raise Exception( - "trailing junk after DER private key curve_oid: %s" % - BinaryAscii.hexFromBinary(empty) - ) - - if oidCurve not in curvesByOid: - raise Exception( - "unknown curve with oid %s; The following are registered: %s" % ( - oidCurve, - ", ".join([curve.name for curve in supportedCurves]) - ) - ) - - curve = curvesByOid[oidCurve] - - if len(privateKeyStr) < curve.length(): - privateKeyStr = hexAt * (curve.lenght() - len(privateKeyStr)) + privateKeyStr - - return cls.fromString(privateKeyStr, curve) + def fromDer(cls, binary): + hexadecimal = hexFromByteString(binary) + privateKeyFlag, secretHex, curveData, publicKeyString = parse(hexadecimal)[0] + if privateKeyFlag != 1: + raise Exception("Private keys should start with a '1' flag, but a '{flag}' was found instead".format( + flag=privateKeyFlag + )) + curve = getCurveByOid(curveData[0]) + privateKey = cls.fromString(string=secretHex, curve=curve) + if privateKey.publicKey().toString(encoded=True) != publicKeyString[0]: + raise Exception("The public key described inside the private key file doesn't match the actual public key of the pair") + return privateKey @classmethod def fromString(cls, string, curve=secp256k1): - return PrivateKey(secret=BinaryAscii.numberFromString(string), curve=curve) + return PrivateKey(secret=intFromHex(string), curve=curve) + + +_pemTemplate = """ +-----BEGIN EC PRIVATE KEY----- +{content} +-----END EC PRIVATE KEY----- +""" diff --git a/ellipticcurve/publicKey.py b/ellipticcurve/publicKey.py index 84814f2..a289510 100644 --- a/ellipticcurve/publicKey.py +++ b/ellipticcurve/publicKey.py @@ -1,8 +1,8 @@ -from .utils.compatibility import * -from .utils.der import fromPem, removeSequence, removeObject, removeBitString, toPem, encodeSequence, encodeOid, encodeBitString -from .utils.binary import BinaryAscii from .point import Point -from .curve import curvesByOid, supportedCurves, secp256k1 +from .utils.pem import getPemContent, createPem +from .utils.binary import hexFromByteString, byteStringFromHex, intFromHex, base64FromByteString, byteStringFromBase64 +from .utils.der import hexFromInt, parse, DerFieldType, encodeConstructed, encodePrimitive +from .curve import secp256k1, getCurveByOid class PublicKey: @@ -12,86 +12,72 @@ def __init__(self, point, curve): self.curve = curve def toString(self, encoded=False): - xString = BinaryAscii.stringFromNumber( - number=self.point.x, - length=self.curve.length(), - ) - yString = BinaryAscii.stringFromNumber( - number=self.point.y, - length=self.curve.length(), - ) - return "\x00\x04" + xString + yString if encoded else xString + yString + baseLength = 2 * self.curve.length() + xHex = hexFromInt(self.point.x).zfill(baseLength) + yHex = hexFromInt(self.point.y).zfill(baseLength) + string = xHex + yHex + if encoded: + return "0004" + string + return string def toDer(self): - oidEcPublicKey = (1, 2, 840, 10045, 2, 1) - encodeEcAndOid = encodeSequence( - encodeOid(*oidEcPublicKey), - encodeOid(*self.curve.oid), + hexadecimal = encodeConstructed( + encodeConstructed( + encodePrimitive(DerFieldType.object, _ecdsaPublicKeyOid), + encodePrimitive(DerFieldType.object, self.curve.oid), + ), + encodePrimitive(DerFieldType.bitString, self.toString(encoded=True)), ) - - return encodeSequence(encodeEcAndOid, encodeBitString(self.toString(encoded=True))) + return byteStringFromHex(hexadecimal) def toPem(self): - return toPem(der=toBytes(self.toDer()), name="PUBLIC KEY") + der = self.toDer() + return createPem(content=base64FromByteString(der), template=_pemTemplate) @classmethod def fromPem(cls, string): - return cls.fromDer(fromPem(string)) + publicKeyPem = getPemContent(pem=string, template=_pemTemplate) + return cls.fromDer(byteStringFromBase64(publicKeyPem)) @classmethod - def fromDer(cls, string): - s1, empty = removeSequence(string) - if len(empty) != 0: - raise Exception("trailing junk after DER public key: {}".format( - BinaryAscii.hexFromBinary(empty) - )) - - s2, pointBitString = removeSequence(s1) - - oidPk, rest = removeObject(s2) - - oidCurve, empty = removeObject(rest) - if len(empty) != 0: - raise Exception("trailing junk after DER public key objects: {}".format( - BinaryAscii.hexFromBinary(empty) + def fromDer(cls, binary): + hexadecimal = hexFromByteString(binary) + curveData, pointString = parse(hexadecimal)[0] + publicKeyOid, curveOid = curveData + if publicKeyOid != _ecdsaPublicKeyOid: + raise Exception("The Public Key Object Identifier (OID) should be {ecdsaPublicKeyOid}, but {actualOid} was found instead".format( + ecdsaPublicKeyOid=_ecdsaPublicKeyOid, + actualOid=publicKeyOid, )) - - if oidCurve not in curvesByOid: - raise Exception( - "Unknown curve with oid %s. Only the following are available: %s" % ( - oidCurve, - ", ".join([curve.name for curve in supportedCurves]) - ) - ) - - curve = curvesByOid[oidCurve] - - pointStr, empty = removeBitString(pointBitString) - if len(empty) != 0: - raise Exception( - "trailing junk after public key point-string: " + - BinaryAscii.hexFromBinary(empty) - ) - - return cls.fromString(pointStr[2:], curve) + curve = getCurveByOid(curveOid) + return cls.fromString(string=pointString, curve=curve) @classmethod def fromString(cls, string, curve=secp256k1, validatePoint=True): - baseLen = curve.length() + baseLength = 2 * curve.length() + if len(string) > 2 * baseLength and string[:4] == "0004": + string = string[4:] - xs = string[:baseLen] - ys = string[baseLen:] + xs = string[:baseLength] + ys = string[baseLength:] p = Point( - x=BinaryAscii.numberFromString(xs), - y=BinaryAscii.numberFromString(ys), + x=intFromHex(xs), + y=intFromHex(ys), ) - if validatePoint and not curve.contains(p): raise Exception( - "point ({x},{y}) is not valid for curve {name}".format( - x=p.x, y=p.y, name=curve.name - ) + "Point ({x},{y}) is not valid for curve {name}".format(x=p.x, y=p.y, name=curve.name) ) return PublicKey(point=p, curve=curve) + + +_ecdsaPublicKeyOid = (1, 2, 840, 10045, 2, 1) + + +_pemTemplate = """ +-----BEGIN PUBLIC KEY----- +{content} +-----END PUBLIC KEY----- +""" diff --git a/ellipticcurve/signature.py b/ellipticcurve/signature.py index 7e6c1fa..fa189ed 100644 --- a/ellipticcurve/signature.py +++ b/ellipticcurve/signature.py @@ -1,7 +1,6 @@ from .utils.compatibility import * -from .utils.base import Base64 -from .utils.binary import BinaryAscii -from .utils.der import encodeSequence, encodeInteger, removeSequence, removeInteger +from .utils.binary import hexFromByteString, byteStringFromHex, base64FromByteString, byteStringFromBase64 +from .utils.der import parse, encodeConstructed, encodePrimitive, DerFieldType class Signature: @@ -11,35 +10,39 @@ def __init__(self, r, s, recoveryId=None): self.s = s self.recoveryId = recoveryId + def toString(self): + return encodeConstructed( + encodePrimitive(DerFieldType.integer, self.r), + encodePrimitive(DerFieldType.integer, self.s), + ) + def toDer(self, withRecoveryId=False): - encodedSequence = encodeSequence(encodeInteger(self.r), encodeInteger(self.s)) + hexadecimal = self.toString() + encodedSequence = byteStringFromHex(hexadecimal) if not withRecoveryId: return encodedSequence - return chr(27 + self.recoveryId) + encodedSequence + return toBytes(chr(27 + self.recoveryId)) + encodedSequence def toBase64(self, withRecoveryId=False): - return toString(Base64.encode(toBytes(self.toDer(withRecoveryId=withRecoveryId)))) + return base64FromByteString(self.toDer(withRecoveryId)) + + @classmethod + def fromString(cls, string, recoveryId=None): + r, s = parse(string)[0] + return Signature(r=r, s=s, recoveryId=recoveryId) @classmethod - def fromDer(cls, string, recoveryByte=False): + def fromDer(cls, binary, recoveryByte=False): recoveryId = None if recoveryByte: - recoveryId = string[0] if isinstance(string[0], intTypes) else ord(string[0]) + recoveryId = binary[0] if isinstance(binary[0], intTypes) else ord(binary[0]) recoveryId -= 27 - string = string[1:] + binary = binary[1:] - rs, empty = removeSequence(string) - if len(empty) != 0: - raise Exception("trailing junk after DER signature: %s" % BinaryAscii.hexFromBinary(empty)) - - r, rest = removeInteger(rs) - s, empty = removeInteger(rest) - if len(empty) != 0: - raise Exception("trailing junk after DER numbers: %s" % BinaryAscii.hexFromBinary(empty)) - - return Signature(r=r, s=s, recoveryId=recoveryId) + hexadecimal = hexFromByteString(binary) + return cls.fromString(string=hexadecimal, recoveryId=recoveryId) @classmethod def fromBase64(cls, string, recoveryByte=False): - der = Base64.decode(string) + der = byteStringFromBase64(string) return cls.fromDer(der, recoveryByte) diff --git a/ellipticcurve/utils/base.py b/ellipticcurve/utils/base.py deleted file mode 100644 index ee6adcd..0000000 --- a/ellipticcurve/utils/base.py +++ /dev/null @@ -1,12 +0,0 @@ -from base64 import b64encode, b64decode - - -class Base64: - - @classmethod - def decode(cls, string): - return b64decode(string) - - @classmethod - def encode(cls, string): - return b64encode(string) diff --git a/ellipticcurve/utils/binary.py b/ellipticcurve/utils/binary.py index b6653a8..8231cad 100644 --- a/ellipticcurve/utils/binary.py +++ b/ellipticcurve/utils/binary.py @@ -1,49 +1,38 @@ -from .compatibility import * - - -class BinaryAscii: - - @classmethod - def hexFromBinary(cls, data): - """ - Return the hexadecimal representation of the binary data. Every byte of data is converted into the - corresponding 2-digit hex representation. The resulting string is therefore twice as long as the length of data. - - :param data: binary - :return: hexadecimal string - """ - return safeHexFromBinary(data) - - @classmethod - def binaryFromHex(cls, data): - """ - Return the binary data represented by the hexadecimal string hexstr. This function is the inverse of b2a_hex(). - hexstr must contain an even number of hexadecimal digits (which can be upper or lower case), otherwise a TypeError is raised. - - :param data: hexadecimal string - :return: binary - """ - return safeBinaryFromHex(data) - - @classmethod - def numberFromString(cls, string): - """ - Get a number representation of a string - - :param String to be converted in a number - :return: Number in hex from string - """ - return int(cls.hexFromBinary(string), 16) - - @classmethod - def stringFromNumber(cls, number, length): - """ - Get a string representation of a number - - :param number to be converted in a string - :param length max number of character for the string - :return: hexadecimal string - """ - - fmtStr = "%0" + str(2 * length) + "x" - return toString(cls.binaryFromHex((fmtStr % number).encode())) +from base64 import b64encode, b64decode +from ellipticcurve import toString +from ellipticcurve.utils.compatibility import safeHexFromBinary, safeBinaryFromHex + + +def hexFromInt(number): + hexadecimal = "{0:x}".format(number) + if len(hexadecimal) % 2 == 1: + hexadecimal = "0" + hexadecimal + return hexadecimal + + +def intFromHex(hexadecimal): + return int(hexadecimal, 16) + + +def hexFromByteString(byteString): + return safeHexFromBinary(byteString) + + +def byteStringFromHex(hexadecimal): + return safeBinaryFromHex(hexadecimal) + + +def numberFromByteString(byteString): + return intFromHex(hexFromByteString(byteString)) + + +def base64FromByteString(byteString): + return toString(b64encode(byteString)) + + +def byteStringFromBase64(base64String): + return b64decode(base64String) + + +def bitsFromHex(hexadecimal): + return format(intFromHex(hexadecimal), 'b').zfill(4 * len(hexadecimal)) diff --git a/ellipticcurve/utils/compatibility.py b/ellipticcurve/utils/compatibility.py index 48d4f5d..3b22dd3 100644 --- a/ellipticcurve/utils/compatibility.py +++ b/ellipticcurve/utils/compatibility.py @@ -5,35 +5,36 @@ if pyVersion.major == 3: # py3 constants and conversion functions - xrange = range stringTypes = (str,) intTypes = (int, float) - def toString(string): - return string.decode("latin-1") + def toString(string, encoding="utf-8"): + return string.decode(encoding) - def toBytes(string): - return string.encode("latin-1") + def toBytes(string, encoding="utf-8"): + return string.encode(encoding) - def safeBinaryFromHex(hexString): - return unhexlify(hexString) + def safeBinaryFromHex(hexadecimal): + if len(hexadecimal) % 2 == 1: + hexadecimal = "0" + hexadecimal + return unhexlify(hexadecimal) def safeHexFromBinary(byteString): - return hexlify(byteString) + return toString(hexlify(byteString)) else: # py2 constants and conversion functions stringTypes = (str, unicode) intTypes = (int, float, long) - def toString(string): + def toString(string, encoding="utf-8"): return string - def toBytes(string): + def toBytes(string, encoding="utf-8"): return string - def safeBinaryFromHex(hexString): - return unhexlify(hexString) + def safeBinaryFromHex(hexadecimal): + return unhexlify(hexadecimal) def safeHexFromBinary(byteString): return hexlify(byteString) diff --git a/ellipticcurve/utils/der.py b/ellipticcurve/utils/der.py index 1b50558..dd62dda 100644 --- a/ellipticcurve/utils/der.py +++ b/ellipticcurve/utils/der.py @@ -1,239 +1,159 @@ -from .base import Base64 -from .binary import BinaryAscii -from .compatibility import * +from datetime import datetime +from ellipticcurve.utils.oid import oidToHex, oidFromHex +from ellipticcurve.utils.binary import hexFromInt, intFromHex, byteStringFromHex, bitsFromHex -hexAt = "\x00" -hexB = "\x02" -hexC = "\x03" -hexD = "\x04" -hexF = "\x06" -hex0 = "\x30" +class DerFieldType: -hex31 = 0x1f -hex127 = 0x7f -hex129 = 0xa0 -hex160 = 0x80 -hex224 = 0xe0 + integer = "integer" + bitString = "bitString" + octetString = "octetString" + null = "null" + object = "object" + printableString = "printableString" + utcTime = "utcTime" + sequence = "sequence" + set = "set" + oidContainer = "oidContainer" + publicKeyPointContainer = "publicKeyPointContainer" -bytesHex0 = toBytes(hex0) -bytesHexB = toBytes(hexB) -bytesHexC = toBytes(hexC) -bytesHexD = toBytes(hexD) -bytesHexF = toBytes(hexF) - -def encodeSequence(*encodedPieces): - totalLengthLen = sum([len(p) for p in encodedPieces]) - return hex0 + _encodeLength(totalLengthLen) + "".join(encodedPieces) - - -def encodeInteger(x): - assert x >= 0 - t = ("%x" % x).encode() - - if len(t) % 2: - t = toBytes("0") + t - - x = BinaryAscii.binaryFromHex(t) - num = x[0] if isinstance(x[0], intTypes) else ord(x[0]) - - if num <= hex127: - return hexB + chr(len(x)) + toString(x) - return hexB + chr(len(x) + 1) + hexAt + toString(x) - - -def encodeOid(first, second, *pieces): - assert first <= 2 - assert second <= 39 - - encodedPieces = [chr(40 * first + second)] + [_encodeNumber(p) for p in pieces] - body = "".join(encodedPieces) - - return hexF + _encodeLength(len(body)) + body - - -def encodeBitString(t): - return hexC + _encodeLength(len(t)) + t - - -def encodeOctetString(t): - return hexD + _encodeLength(len(t)) + t - - -def encodeConstructed(tag, value): - return chr(hex129 + tag) + _encodeLength(len(value)) + value - - -def removeSequence(string): - _checkSequenceError(string=string, start=bytesHex0, expected="30") - - length, lengthLen = _readLength(string[1:]) - endSeq = 1 + lengthLen + length - - return string[1 + lengthLen: endSeq], string[endSeq:] - - -def removeInteger(string): - _checkSequenceError(string=string, start=bytesHexB, expected="02") - - length, lengthLen = _readLength(string[1:]) - numberBytes = string[1 + lengthLen:1 + lengthLen + length] - rest = string[1 + lengthLen + length:] - nBytes = numberBytes[0] if isinstance( - numberBytes[0], intTypes - ) else ord(numberBytes[0]) - - assert nBytes < hex160 - - return int(BinaryAscii.hexFromBinary(numberBytes), 16), rest - - -def removeObject(string): - _checkSequenceError(string=string, start=bytesHexF, expected="06") - - length, lengthLen = _readLength(string[1:]) - body = string[1 + lengthLen:1 + lengthLen + length] - rest = string[1 + lengthLen + length:] - numbers = [] - - while body: - n, lengthLength = _readNumber(body) - numbers.append(n) - body = body[lengthLength:] - - n0 = numbers.pop(0) - first = n0 // 40 - second = n0 - (40 * first) - numbers.insert(0, first) - numbers.insert(1, second) - - return tuple(numbers), rest - - -def removeBitString(string): - _checkSequenceError(string=string, start=bytesHexC, expected="03") - - length, lengthLen = _readLength(string[1:]) - body = string[1 + lengthLen:1 + lengthLen + length] - rest = string[1 + lengthLen + length:] - - return body, rest - - -def removeOctetString(string): - _checkSequenceError(string=string, start=bytesHexD, expected="04") - - length, lengthLen = _readLength(string[1:]) - body = string[1 + lengthLen:1 + lengthLen + length] - rest = string[1 + lengthLen + length:] - - return body, rest - - -def removeConstructed(string): - s0 = _extractFirstInt(string) - if (s0 & hex224) != hex129: - raise Exception("wanted constructed tag (0xa0-0xbf), got 0x%02x" % s0) - - tag = s0 & hex31 - length, lengthLen = _readLength(string[1:]) - body = string[1 + lengthLen:1 + lengthLen + length] - rest = string[1 + lengthLen + length:] - - return tag, body, rest - - -def fromPem(pem): - t = "".join([ - l.strip() for l in pem.splitlines() - if l and not l.startswith("-----") - ]) - return Base64.decode(t) - - -def toPem(der, name): - b64 = toString(Base64.encode(der)) - lines = ["-----BEGIN " + name + "-----\n"] - lines.extend([ - b64[start:start + 64] + '\n' - for start in xrange(0, len(b64), 64) - ]) - lines.append("-----END " + name + "-----\n") - - return "".join(lines) - - -def _encodeLength(length): - assert length >= 0 - - if length < hex160: - return chr(length) - - s = ("%x" % length).encode() - if len(s) % 2: - s = "0" + s - - s = BinaryAscii.binaryFromHex(s) - lengthLen = len(s) - - return chr(hex160 | lengthLen) + str(s) - - -def _encodeNumber(n): - b128Digits = [] - while n: - b128Digits.insert(0, (n & hex127) | hex160) - n >>= 7 - - if not b128Digits: - b128Digits.append(0) - - b128Digits[-1] &= hex127 - - return "".join([chr(d) for d in b128Digits]) - - -def _readLength(string): - num = _extractFirstInt(string) - if not (num & hex160): - return (num & hex127), 1 - - lengthLen = num & hex127 - - if lengthLen > len(string) - 1: - raise Exception("ran out of length bytes") - - return int(BinaryAscii.hexFromBinary(string[1:1 + lengthLen]), 16), 1 + lengthLen - - -def _readNumber(string): - number = 0 - lengthLen = 0 - while True: - if lengthLen > len(string): - raise Exception("ran out of length bytes") - - number <<= 7 - d = string[lengthLen] - if not isinstance(d, intTypes): - d = ord(d) - - number += (d & hex127) - lengthLen += 1 - if not d & hex160: - break - - return number, lengthLen - - -def _checkSequenceError(string, start, expected): - if not string.startswith(start): - raise Exception( - "wanted sequence (0x%s), got 0x%02x" % - (expected, _extractFirstInt(string)) - ) - - -def _extractFirstInt(string): - return string[0] if isinstance(string[0], intTypes) else ord(string[0]) +_hexTagToType = { + "02": DerFieldType.integer, + "03": DerFieldType.bitString, + "04": DerFieldType.octetString, + "05": DerFieldType.null, + "06": DerFieldType.object, + "13": DerFieldType.printableString, + "17": DerFieldType.utcTime, + "30": DerFieldType.sequence, + "31": DerFieldType.set, + "a0": DerFieldType.oidContainer, + "a1": DerFieldType.publicKeyPointContainer, +} +_typeToHexTag = {v: k for k, v in _hexTagToType.items()} + + +def encodeConstructed(*encodedValues): + return encodePrimitive(DerFieldType.sequence, "".join(encodedValues)) + + +def encodePrimitive(tagType, value): + if tagType == DerFieldType.integer: + value = _encodeInteger(value) + if tagType == DerFieldType.object: + value = oidToHex(value) + return "{tag}{size}{value}".format(tag=_typeToHexTag[tagType], size=_generateLengthBytes(value), value=value) + + +def parse(hexadecimal): + if not hexadecimal: + return [] + typeByte, hexadecimal = hexadecimal[:2], hexadecimal[2:] + length, lengthBytes = _readLengthBytes(hexadecimal) + content, hexadecimal = hexadecimal[lengthBytes: lengthBytes + length], hexadecimal[lengthBytes + length:] + if len(content) < length: + raise Exception("missing bytes in DER parse") + + tagData = _getTagData(typeByte) + if tagData["isConstructed"]: + content = parse(content) + + valueParser = { + DerFieldType.null: _parseNull, + DerFieldType.object: _parseOid, + DerFieldType.utcTime: _parseTime, + DerFieldType.integer: _parseInteger, + DerFieldType.printableString: _parseString, + }.get(tagData["type"], _parseAny) + return [valueParser(content)] + parse(hexadecimal) + + +def _parseAny(hexadecimal): + return hexadecimal + + +def _parseOid(hexadecimal): + return tuple(oidFromHex(hexadecimal)) + + +def _parseTime(hexadecimal): + string = _parseString(hexadecimal) + return datetime.strptime(string, "%y%m%d%H%M%SZ") + + +def _parseString(hexadecimal): + return byteStringFromHex(hexadecimal).decode() + + +def _parseNull(_content): + return None + + +def _parseInteger(hexadecimal): + integer = intFromHex(hexadecimal) + bits = bitsFromHex(hexadecimal[0]) + if bits[0] == "0": # negative numbers are encoded using two's complement + return integer + bitCount = 4 * len(hexadecimal) + return integer - (2 ** bitCount) + + +def _encodeInteger(number): + hexadecimal = hexFromInt(abs(number)) + if number < 0: + bitCount = 4 * len(hexadecimal) + twosComplement = (2 ** bitCount) + number + return hexFromInt(twosComplement) + bits = bitsFromHex(hexadecimal[0]) + if bits[0] == "1": # if first bit was left as 1, number would be parsed as a negative integer with two's complement + hexadecimal = "00" + hexadecimal + return hexadecimal + + +def _readLengthBytes(hexadecimal): + lengthBytes = 2 + lengthIndicator = intFromHex(hexadecimal[0:lengthBytes]) + isShortForm = lengthIndicator < 128 # checks if first bit of byte is 1 (a.k.a. short-form) + if isShortForm: + length = lengthIndicator * 2 + return length, lengthBytes + + lengthLength = lengthIndicator - 128 # nullifies first bit of byte (only used as long-form flag) + if lengthLength == 0: + raise Exception("indefinite length encoding located in DER") + lengthBytes += 2 * lengthLength + length = intFromHex(hexadecimal[2:lengthBytes]) * 2 + return length, lengthBytes + + +def _generateLengthBytes(hexadecimal): + size = len(hexadecimal) // 2 + length = hexFromInt(size) + if size < 128: # checks if first bit of byte should be 0 (a.k.a. short-form flag) + return length.zfill(2) + lengthLength = 128 + len(length) // 2 # +128 sets the first bit of the byte as 1 (a.k.a. long-form flag) + return hexFromInt(lengthLength) + length + + +def _getTagData(tag): + bits = bitsFromHex(tag) + bit8, bit7, bit6 = bits[:3] + + tagClass = { + "0": { + "0": "universal", + "1": "application", + }, + "1": { + "0": "context-specific", + "1": "private", + }, + }[bit8][bit7] + isConstructed = bit6 == "1" + + return { + "class": tagClass, + "isConstructed": isConstructed, + "type": _hexTagToType.get(tag), + } diff --git a/ellipticcurve/utils/oid.py b/ellipticcurve/utils/oid.py new file mode 100644 index 0000000..de1e01e --- /dev/null +++ b/ellipticcurve/utils/oid.py @@ -0,0 +1,35 @@ +from ellipticcurve.utils.binary import intFromHex, hexFromInt + + +def oidFromHex(hexadecimal): + firstByte, remainingBytes = hexadecimal[:2], hexadecimal[2:] + firstByteInt = intFromHex(firstByte) + oid = [firstByteInt // 40, firstByteInt % 40] + oidInt = 0 + while len(remainingBytes) > 0: + byte, remainingBytes = remainingBytes[0:2], remainingBytes[2:] + byteInt = intFromHex(byte) + if byteInt >= 128: + oidInt = byteInt - 128 + continue + oidInt = oidInt * 128 + byteInt + oid.append(oidInt) + oidInt = 0 + return oid + + +def oidToHex(oid): + hexadecimal = hexFromInt(40 * oid[0] + oid[1]) + byteArray = [] + for oidInt in oid[2:]: + endDelta = 0 + while True: + byteInt = oidInt % 128 + endDelta + oidInt = oidInt // 128 + endDelta = 128 + byteArray.append(byteInt) + if oidInt == 0: + break + hexadecimal += "".join(hexFromInt(byteInt).zfill(2) for byteInt in reversed(byteArray)) + byteArray = [] + return hexadecimal diff --git a/ellipticcurve/utils/pem.py b/ellipticcurve/utils/pem.py new file mode 100644 index 0000000..1e58b40 --- /dev/null +++ b/ellipticcurve/utils/pem.py @@ -0,0 +1,14 @@ +from re import search + + +def getPemContent(pem, template): + pattern = template.format(content="(.*)") + return search("".join(pattern.splitlines()), "".join(pem.splitlines())).group(1) + + +def createPem(content, template): + lines = [ + content[start:start + 64] + for start in range(0, len(content), 64) + ] + return template.format(content="\n".join(lines)) diff --git a/tests/testEcdsa.py b/tests/testEcdsa.py index 3999682..94bb914 100644 --- a/tests/testEcdsa.py +++ b/tests/testEcdsa.py @@ -1,6 +1,5 @@ from unittest.case import TestCase -from ellipticcurve.ecdsa import Ecdsa -from ellipticcurve.privateKey import PrivateKey +from ellipticcurve import Ecdsa, PrivateKey class EcdsaTest(TestCase): diff --git a/tests/testOpenSSL.py b/tests/testOpenSSL.py index 09f31e2..1c5f9ba 100644 --- a/tests/testOpenSSL.py +++ b/tests/testOpenSSL.py @@ -1,22 +1,16 @@ -# coding=utf-8 - from unittest.case import TestCase -from ellipticcurve.ecdsa import Ecdsa -from ellipticcurve.privateKey import PrivateKey -from ellipticcurve.publicKey import PublicKey -from ellipticcurve.signature import Signature -from ellipticcurve.utils.file import File +from ellipticcurve import Ecdsa, PrivateKey, PublicKey, Signature, File class OpensslTest(TestCase): def testAssign(self): # Generated by: openssl ecparam -name secp256k1 -genkey -out privateKey.pem - privateKeyPem = File.read("tests/privateKey.pem") + privateKeyPem = File.read("privateKey.pem") privateKey = PrivateKey.fromPem(privateKeyPem) - message = File.read("tests/message.txt") + message = File.read("message.txt") signature = Ecdsa.sign(message=message, privateKey=privateKey) @@ -27,15 +21,15 @@ def testAssign(self): def testVerifySignature(self): # openssl ec -in privateKey.pem -pubout -out publicKey.pem - publicKeyPem = File.read("tests/publicKey.pem") + publicKeyPem = File.read("publicKey.pem") # openssl dgst -sha256 -sign privateKey.pem -out signature.binary message.txt - signatureDer = File.read("tests/signatureDer.txt", "rb") + signatureDer = File.read("signatureDer.txt", "rb") - message = File.read("tests/message.txt") + message = File.read("message.txt") publicKey = PublicKey.fromPem(publicKeyPem) - signature = Signature.fromDer(string=signatureDer) + signature = Signature.fromDer(binary=signatureDer) self.assertTrue(Ecdsa.verify(message=message, signature=signature, publicKey=publicKey)) diff --git a/tests/testPrivateKey.py b/tests/testPrivateKey.py index 4157eda..bdd98cd 100644 --- a/tests/testPrivateKey.py +++ b/tests/testPrivateKey.py @@ -1,8 +1,5 @@ -# coding=utf-8 - from unittest.case import TestCase from ellipticcurve.privateKey import PrivateKey -from ellipticcurve.utils.compatibility import * class PrivateKeyTest(TestCase): @@ -17,13 +14,13 @@ def testPemConversion(self): def testDerConversion(self): privateKey1 = PrivateKey() der = privateKey1.toDer() - privateKey2 = PrivateKey.fromDer(toBytes(der)) + privateKey2 = PrivateKey.fromDer(der) self.assertEqual(privateKey1.secret, privateKey2.secret) self.assertEqual(privateKey1.curve, privateKey2.curve) def testStringConversion(self): privateKey1 = PrivateKey() string = privateKey1.toString() - privateKey2 = PrivateKey.fromString(toBytes(string)) + privateKey2 = PrivateKey.fromString(string) self.assertEqual(privateKey1.secret, privateKey2.secret) self.assertEqual(privateKey1.curve, privateKey2.curve) diff --git a/tests/testPublicKey.py b/tests/testPublicKey.py index d95ebbe..404941e 100644 --- a/tests/testPublicKey.py +++ b/tests/testPublicKey.py @@ -1,5 +1,3 @@ -# coding=utf-8 - from unittest.case import TestCase from ellipticcurve.privateKey import PrivateKey from ellipticcurve.publicKey import PublicKey @@ -21,7 +19,7 @@ def testDerConversion(self): privateKey = PrivateKey() publicKey1 = privateKey.publicKey() der = publicKey1.toDer() - publicKey2 = PublicKey.fromDer(toBytes(der)) + publicKey2 = PublicKey.fromDer(der) self.assertEqual(publicKey1.point.x, publicKey2.point.x) self.assertEqual(publicKey1.point.y, publicKey2.point.y) self.assertEqual(publicKey1.curve, publicKey2.curve) diff --git a/tests/testRandom.py b/tests/testRandom.py new file mode 100644 index 0000000..6422109 --- /dev/null +++ b/tests/testRandom.py @@ -0,0 +1,23 @@ +from unittest.case import TestCase +from ellipticcurve import Ecdsa, Signature, PublicKey, PrivateKey + + +class RandomTest(TestCase): + + def testMany(self): + for _ in range(1000): + privateKey1 = PrivateKey() + publicKey1 = privateKey1.publicKey() + + privateKeyPem = privateKey1.toPem() + publicKeyPem = publicKey1.toPem() + + privateKey2 = PrivateKey.fromPem(privateKeyPem) + publicKey2 = PublicKey.fromPem(publicKeyPem) + + message = "test" + + signatureBase64 = Ecdsa.sign(message=message, privateKey=privateKey2).toBase64() + signature = Signature.fromBase64(signatureBase64) + + self.assertTrue(Ecdsa.verify(message=message, signature=signature, publicKey=publicKey2)) diff --git a/tests/testSignature.py b/tests/testSignature.py index 3870354..0b9638e 100644 --- a/tests/testSignature.py +++ b/tests/testSignature.py @@ -1,10 +1,5 @@ -# coding=utf-8 - from unittest.case import TestCase -from ellipticcurve.ecdsa import Ecdsa -from ellipticcurve.privateKey import PrivateKey -from ellipticcurve.signature import Signature -from ellipticcurve.utils.compatibility import * +from ellipticcurve import Ecdsa, PrivateKey, Signature class SignatureTest(TestCase): @@ -16,7 +11,7 @@ def testDerConversion(self): signature1 = Ecdsa.sign(message, privateKey) der = signature1.toDer() - signature2 = Signature.fromDer(toBytes(der)) + signature2 = Signature.fromDer(der) self.assertEqual(signature1.r, signature2.r) self.assertEqual(signature1.s, signature2.s) diff --git a/tests/testSignatureWithRecoveryId.py b/tests/testSignatureWithRecoveryId.py index 5369cf0..4911d9a 100644 --- a/tests/testSignatureWithRecoveryId.py +++ b/tests/testSignatureWithRecoveryId.py @@ -1,10 +1,5 @@ -# coding=utf-8 - from unittest.case import TestCase -from ellipticcurve.ecdsa import Ecdsa -from ellipticcurve.privateKey import PrivateKey -from ellipticcurve.signature import Signature -from ellipticcurve.utils.compatibility import * +from ellipticcurve import Ecdsa, PrivateKey, Signature class SignatureTest(TestCase): @@ -16,7 +11,7 @@ def testDerConversion(self): signature1 = Ecdsa.sign(message, privateKey) der = signature1.toDer(withRecoveryId=True) - signature2 = Signature.fromDer(toBytes(der), recoveryByte=True) + signature2 = Signature.fromDer(der, recoveryByte=True) self.assertEqual(signature1.r, signature2.r) self.assertEqual(signature1.s, signature2.s)