1
0
mirror of https://gitlab.com/MoonTestUse1/AdministrationItDepartmens.git synced 2025-08-14 00:25:46 +02:00

Initial commit

This commit is contained in:
MoonTestUse1
2024-12-23 19:27:44 +06:00
commit e81df4c87e
4952 changed files with 1705479 additions and 0 deletions

View File

@@ -0,0 +1,10 @@
__version__ = "3.3.0"
__author__ = "Michael Davis"
__license__ = "MIT"
__copyright__ = "Copyright 2016 Michael Davis"
from .exceptions import ExpiredSignatureError # noqa: F401
from .exceptions import JOSEError # noqa: F401
from .exceptions import JWSError # noqa: F401
from .exceptions import JWTError # noqa: F401

View File

@@ -0,0 +1,32 @@
try:
from jose.backends.cryptography_backend import get_random_bytes # noqa: F401
except ImportError:
try:
from jose.backends.pycrypto_backend import get_random_bytes # noqa: F401
except ImportError:
from jose.backends.native import get_random_bytes # noqa: F401
try:
from jose.backends.cryptography_backend import CryptographyRSAKey as RSAKey # noqa: F401
except ImportError:
try:
from jose.backends.rsa_backend import RSAKey # noqa: F401
except ImportError:
RSAKey = None
try:
from jose.backends.cryptography_backend import CryptographyECKey as ECKey # noqa: F401
except ImportError:
from jose.backends.ecdsa_backend import ECDSAECKey as ECKey # noqa: F401
try:
from jose.backends.cryptography_backend import CryptographyAESKey as AESKey # noqa: F401
except ImportError:
AESKey = None
try:
from jose.backends.cryptography_backend import CryptographyHMACKey as HMACKey # noqa: F401
except ImportError:
from jose.backends.native import HMACKey # noqa: F401
from .base import DIRKey # noqa: F401

View File

@@ -0,0 +1,83 @@
"""ASN1 encoding helpers for converting between PKCS1 and PKCS8.
Required by rsa_backend but not cryptography_backend.
"""
from pyasn1.codec.der import decoder, encoder
from pyasn1.type import namedtype, univ
RSA_ENCRYPTION_ASN1_OID = "1.2.840.113549.1.1.1"
class RsaAlgorithmIdentifier(univ.Sequence):
"""ASN1 structure for recording RSA PrivateKeyAlgorithm identifiers."""
componentType = namedtype.NamedTypes(
namedtype.NamedType("rsaEncryption", univ.ObjectIdentifier()), namedtype.NamedType("parameters", univ.Null())
)
class PKCS8PrivateKey(univ.Sequence):
"""ASN1 structure for recording PKCS8 private keys."""
componentType = namedtype.NamedTypes(
namedtype.NamedType("version", univ.Integer()),
namedtype.NamedType("privateKeyAlgorithm", RsaAlgorithmIdentifier()),
namedtype.NamedType("privateKey", univ.OctetString()),
)
class PublicKeyInfo(univ.Sequence):
"""ASN1 structure for recording PKCS8 public keys."""
componentType = namedtype.NamedTypes(
namedtype.NamedType("algorithm", RsaAlgorithmIdentifier()), namedtype.NamedType("publicKey", univ.BitString())
)
def rsa_private_key_pkcs8_to_pkcs1(pkcs8_key):
"""Convert a PKCS8-encoded RSA private key to PKCS1."""
decoded_values = decoder.decode(pkcs8_key, asn1Spec=PKCS8PrivateKey())
try:
decoded_key = decoded_values[0]
except IndexError:
raise ValueError("Invalid private key encoding")
return decoded_key["privateKey"]
def rsa_private_key_pkcs1_to_pkcs8(pkcs1_key):
"""Convert a PKCS1-encoded RSA private key to PKCS8."""
algorithm = RsaAlgorithmIdentifier()
algorithm["rsaEncryption"] = RSA_ENCRYPTION_ASN1_OID
pkcs8_key = PKCS8PrivateKey()
pkcs8_key["version"] = 0
pkcs8_key["privateKeyAlgorithm"] = algorithm
pkcs8_key["privateKey"] = pkcs1_key
return encoder.encode(pkcs8_key)
def rsa_public_key_pkcs1_to_pkcs8(pkcs1_key):
"""Convert a PKCS1-encoded RSA private key to PKCS8."""
algorithm = RsaAlgorithmIdentifier()
algorithm["rsaEncryption"] = RSA_ENCRYPTION_ASN1_OID
pkcs8_key = PublicKeyInfo()
pkcs8_key["algorithm"] = algorithm
pkcs8_key["publicKey"] = univ.BitString.fromOctetString(pkcs1_key)
return encoder.encode(pkcs8_key)
def rsa_public_key_pkcs8_to_pkcs1(pkcs8_key):
"""Convert a PKCS8-encoded RSA private key to PKCS1."""
decoded_values = decoder.decode(pkcs8_key, asn1Spec=PublicKeyInfo())
try:
decoded_key = decoded_values[0]
except IndexError:
raise ValueError("Invalid public key encoding.")
return decoded_key["publicKey"].asOctets()

View File

@@ -0,0 +1,89 @@
from ..utils import base64url_encode, ensure_binary
class Key:
"""
A simple interface for implementing JWK keys.
"""
def __init__(self, key, algorithm):
pass
def sign(self, msg):
raise NotImplementedError()
def verify(self, msg, sig):
raise NotImplementedError()
def public_key(self):
raise NotImplementedError()
def to_pem(self):
raise NotImplementedError()
def to_dict(self):
raise NotImplementedError()
def encrypt(self, plain_text, aad=None):
"""
Encrypt the plain text and generate an auth tag if appropriate
Args:
plain_text (bytes): Data to encrypt
aad (bytes, optional): Authenticated Additional Data if key's algorithm supports auth mode
Returns:
(bytes, bytes, bytes): IV, cipher text, and auth tag
"""
raise NotImplementedError()
def decrypt(self, cipher_text, iv=None, aad=None, tag=None):
"""
Decrypt the cipher text and validate the auth tag if present
Args:
cipher_text (bytes): Cipher text to decrypt
iv (bytes): IV if block mode
aad (bytes): Additional Authenticated Data to verify if auth mode
tag (bytes): Authentication tag if auth mode
Returns:
bytes: Decrypted value
"""
raise NotImplementedError()
def wrap_key(self, key_data):
"""
Wrap the the plain text key data
Args:
key_data (bytes): Key data to wrap
Returns:
bytes: Wrapped key
"""
raise NotImplementedError()
def unwrap_key(self, wrapped_key):
"""
Unwrap the the wrapped key data
Args:
wrapped_key (bytes): Wrapped key data to unwrap
Returns:
bytes: Unwrapped key
"""
raise NotImplementedError()
class DIRKey(Key):
def __init__(self, key_data, algorithm):
self._key = ensure_binary(key_data)
self._alg = algorithm
def to_dict(self):
return {
"alg": self._alg,
"kty": "oct",
"k": base64url_encode(self._key),
}

View File

@@ -0,0 +1,605 @@
import math
import warnings
from cryptography.exceptions import InvalidSignature, InvalidTag
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.bindings.openssl.binding import Binding
from cryptography.hazmat.primitives import hashes, hmac, serialization
from cryptography.hazmat.primitives.asymmetric import ec, padding, rsa
from cryptography.hazmat.primitives.asymmetric.utils import decode_dss_signature, encode_dss_signature
from cryptography.hazmat.primitives.ciphers import Cipher, aead, algorithms, modes
from cryptography.hazmat.primitives.keywrap import InvalidUnwrap, aes_key_unwrap, aes_key_wrap
from cryptography.hazmat.primitives.padding import PKCS7
from cryptography.hazmat.primitives.serialization import load_pem_private_key, load_pem_public_key
from cryptography.utils import int_to_bytes
from cryptography.x509 import load_pem_x509_certificate
from ..constants import ALGORITHMS
from ..exceptions import JWEError, JWKError
from ..utils import base64_to_long, base64url_decode, base64url_encode, ensure_binary, long_to_base64
from .base import Key
_binding = None
def get_random_bytes(num_bytes):
"""
Get random bytes
Currently, Cryptography returns OS random bytes. If you want OpenSSL
generated random bytes, you'll have to switch the RAND engine after
initializing the OpenSSL backend
Args:
num_bytes (int): Number of random bytes to generate and return
Returns:
bytes: Random bytes
"""
global _binding
if _binding is None:
_binding = Binding()
buf = _binding.ffi.new("char[]", num_bytes)
_binding.lib.RAND_bytes(buf, num_bytes)
rand_bytes = _binding.ffi.buffer(buf, num_bytes)[:]
return rand_bytes
class CryptographyECKey(Key):
SHA256 = hashes.SHA256
SHA384 = hashes.SHA384
SHA512 = hashes.SHA512
def __init__(self, key, algorithm, cryptography_backend=default_backend):
if algorithm not in ALGORITHMS.EC:
raise JWKError("hash_alg: %s is not a valid hash algorithm" % algorithm)
self.hash_alg = {
ALGORITHMS.ES256: self.SHA256,
ALGORITHMS.ES384: self.SHA384,
ALGORITHMS.ES512: self.SHA512,
}.get(algorithm)
self._algorithm = algorithm
self.cryptography_backend = cryptography_backend
if hasattr(key, "public_bytes") or hasattr(key, "private_bytes"):
self.prepared_key = key
return
if hasattr(key, "to_pem"):
# convert to PEM and let cryptography below load it as PEM
key = key.to_pem().decode("utf-8")
if isinstance(key, dict):
self.prepared_key = self._process_jwk(key)
return
if isinstance(key, str):
key = key.encode("utf-8")
if isinstance(key, bytes):
# Attempt to load key. We don't know if it's
# a Public Key or a Private Key, so we try
# the Public Key first.
try:
try:
key = load_pem_public_key(key, self.cryptography_backend())
except ValueError:
key = load_pem_private_key(key, password=None, backend=self.cryptography_backend())
except Exception as e:
raise JWKError(e)
self.prepared_key = key
return
raise JWKError("Unable to parse an ECKey from key: %s" % key)
def _process_jwk(self, jwk_dict):
if not jwk_dict.get("kty") == "EC":
raise JWKError("Incorrect key type. Expected: 'EC', Received: %s" % jwk_dict.get("kty"))
if not all(k in jwk_dict for k in ["x", "y", "crv"]):
raise JWKError("Mandatory parameters are missing")
x = base64_to_long(jwk_dict.get("x"))
y = base64_to_long(jwk_dict.get("y"))
curve = {
"P-256": ec.SECP256R1,
"P-384": ec.SECP384R1,
"P-521": ec.SECP521R1,
}[jwk_dict["crv"]]
public = ec.EllipticCurvePublicNumbers(x, y, curve())
if "d" in jwk_dict:
d = base64_to_long(jwk_dict.get("d"))
private = ec.EllipticCurvePrivateNumbers(d, public)
return private.private_key(self.cryptography_backend())
else:
return public.public_key(self.cryptography_backend())
def _sig_component_length(self):
"""Determine the correct serialization length for an encoded signature component.
This is the number of bytes required to encode the maximum key value.
"""
return int(math.ceil(self.prepared_key.key_size / 8.0))
def _der_to_raw(self, der_signature):
"""Convert signature from DER encoding to RAW encoding."""
r, s = decode_dss_signature(der_signature)
component_length = self._sig_component_length()
return int_to_bytes(r, component_length) + int_to_bytes(s, component_length)
def _raw_to_der(self, raw_signature):
"""Convert signature from RAW encoding to DER encoding."""
component_length = self._sig_component_length()
if len(raw_signature) != int(2 * component_length):
raise ValueError("Invalid signature")
r_bytes = raw_signature[:component_length]
s_bytes = raw_signature[component_length:]
r = int.from_bytes(r_bytes, "big")
s = int.from_bytes(s_bytes, "big")
return encode_dss_signature(r, s)
def sign(self, msg):
if self.hash_alg.digest_size * 8 > self.prepared_key.curve.key_size:
raise TypeError(
"this curve (%s) is too short "
"for your digest (%d)" % (self.prepared_key.curve.name, 8 * self.hash_alg.digest_size)
)
signature = self.prepared_key.sign(msg, ec.ECDSA(self.hash_alg()))
return self._der_to_raw(signature)
def verify(self, msg, sig):
try:
signature = self._raw_to_der(sig)
self.prepared_key.verify(signature, msg, ec.ECDSA(self.hash_alg()))
return True
except Exception:
return False
def is_public(self):
return hasattr(self.prepared_key, "public_bytes")
def public_key(self):
if self.is_public():
return self
return self.__class__(self.prepared_key.public_key(), self._algorithm)
def to_pem(self):
if self.is_public():
pem = self.prepared_key.public_bytes(
encoding=serialization.Encoding.PEM, format=serialization.PublicFormat.SubjectPublicKeyInfo
)
return pem
pem = self.prepared_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption(),
)
return pem
def to_dict(self):
if not self.is_public():
public_key = self.prepared_key.public_key()
else:
public_key = self.prepared_key
crv = {
"secp256r1": "P-256",
"secp384r1": "P-384",
"secp521r1": "P-521",
}[self.prepared_key.curve.name]
# Calculate the key size in bytes. Section 6.2.1.2 and 6.2.1.3 of
# RFC7518 prescribes that the 'x', 'y' and 'd' parameters of the curve
# points must be encoded as octed-strings of this length.
key_size = (self.prepared_key.curve.key_size + 7) // 8
data = {
"alg": self._algorithm,
"kty": "EC",
"crv": crv,
"x": long_to_base64(public_key.public_numbers().x, size=key_size).decode("ASCII"),
"y": long_to_base64(public_key.public_numbers().y, size=key_size).decode("ASCII"),
}
if not self.is_public():
private_value = self.prepared_key.private_numbers().private_value
data["d"] = long_to_base64(private_value, size=key_size).decode("ASCII")
return data
class CryptographyRSAKey(Key):
SHA256 = hashes.SHA256
SHA384 = hashes.SHA384
SHA512 = hashes.SHA512
RSA1_5 = padding.PKCS1v15()
RSA_OAEP = padding.OAEP(padding.MGF1(hashes.SHA1()), hashes.SHA1(), None)
RSA_OAEP_256 = padding.OAEP(padding.MGF1(hashes.SHA256()), hashes.SHA256(), None)
def __init__(self, key, algorithm, cryptography_backend=default_backend):
if algorithm not in ALGORITHMS.RSA:
raise JWKError("hash_alg: %s is not a valid hash algorithm" % algorithm)
self.hash_alg = {
ALGORITHMS.RS256: self.SHA256,
ALGORITHMS.RS384: self.SHA384,
ALGORITHMS.RS512: self.SHA512,
}.get(algorithm)
self._algorithm = algorithm
self.padding = {
ALGORITHMS.RSA1_5: self.RSA1_5,
ALGORITHMS.RSA_OAEP: self.RSA_OAEP,
ALGORITHMS.RSA_OAEP_256: self.RSA_OAEP_256,
}.get(algorithm)
self.cryptography_backend = cryptography_backend
# if it conforms to RSAPublicKey interface
if hasattr(key, "public_bytes") and hasattr(key, "public_numbers"):
self.prepared_key = key
return
if isinstance(key, dict):
self.prepared_key = self._process_jwk(key)
return
if isinstance(key, str):
key = key.encode("utf-8")
if isinstance(key, bytes):
try:
if key.startswith(b"-----BEGIN CERTIFICATE-----"):
self._process_cert(key)
return
try:
self.prepared_key = load_pem_public_key(key, self.cryptography_backend())
except ValueError:
self.prepared_key = load_pem_private_key(key, password=None, backend=self.cryptography_backend())
except Exception as e:
raise JWKError(e)
return
raise JWKError("Unable to parse an RSA_JWK from key: %s" % key)
def _process_jwk(self, jwk_dict):
if not jwk_dict.get("kty") == "RSA":
raise JWKError("Incorrect key type. Expected: 'RSA', Received: %s" % jwk_dict.get("kty"))
e = base64_to_long(jwk_dict.get("e", 256))
n = base64_to_long(jwk_dict.get("n"))
public = rsa.RSAPublicNumbers(e, n)
if "d" not in jwk_dict:
return public.public_key(self.cryptography_backend())
else:
# This is a private key.
d = base64_to_long(jwk_dict.get("d"))
extra_params = ["p", "q", "dp", "dq", "qi"]
if any(k in jwk_dict for k in extra_params):
# Precomputed private key parameters are available.
if not all(k in jwk_dict for k in extra_params):
# These values must be present when 'p' is according to
# Section 6.3.2 of RFC7518, so if they are not we raise
# an error.
raise JWKError("Precomputed private key parameters are incomplete.")
p = base64_to_long(jwk_dict["p"])
q = base64_to_long(jwk_dict["q"])
dp = base64_to_long(jwk_dict["dp"])
dq = base64_to_long(jwk_dict["dq"])
qi = base64_to_long(jwk_dict["qi"])
else:
# The precomputed private key parameters are not available,
# so we use cryptography's API to fill them in.
p, q = rsa.rsa_recover_prime_factors(n, e, d)
dp = rsa.rsa_crt_dmp1(d, p)
dq = rsa.rsa_crt_dmq1(d, q)
qi = rsa.rsa_crt_iqmp(p, q)
private = rsa.RSAPrivateNumbers(p, q, d, dp, dq, qi, public)
return private.private_key(self.cryptography_backend())
def _process_cert(self, key):
key = load_pem_x509_certificate(key, self.cryptography_backend())
self.prepared_key = key.public_key()
def sign(self, msg):
try:
signature = self.prepared_key.sign(msg, padding.PKCS1v15(), self.hash_alg())
except Exception as e:
raise JWKError(e)
return signature
def verify(self, msg, sig):
if not self.is_public():
warnings.warn("Attempting to verify a message with a private key. " "This is not recommended.")
try:
self.public_key().prepared_key.verify(sig, msg, padding.PKCS1v15(), self.hash_alg())
return True
except InvalidSignature:
return False
def is_public(self):
return hasattr(self.prepared_key, "public_bytes")
def public_key(self):
if self.is_public():
return self
return self.__class__(self.prepared_key.public_key(), self._algorithm)
def to_pem(self, pem_format="PKCS8"):
if self.is_public():
if pem_format == "PKCS8":
fmt = serialization.PublicFormat.SubjectPublicKeyInfo
elif pem_format == "PKCS1":
fmt = serialization.PublicFormat.PKCS1
else:
raise ValueError("Invalid format specified: %r" % pem_format)
pem = self.prepared_key.public_bytes(encoding=serialization.Encoding.PEM, format=fmt)
return pem
if pem_format == "PKCS8":
fmt = serialization.PrivateFormat.PKCS8
elif pem_format == "PKCS1":
fmt = serialization.PrivateFormat.TraditionalOpenSSL
else:
raise ValueError("Invalid format specified: %r" % pem_format)
return self.prepared_key.private_bytes(
encoding=serialization.Encoding.PEM, format=fmt, encryption_algorithm=serialization.NoEncryption()
)
def to_dict(self):
if not self.is_public():
public_key = self.prepared_key.public_key()
else:
public_key = self.prepared_key
data = {
"alg": self._algorithm,
"kty": "RSA",
"n": long_to_base64(public_key.public_numbers().n).decode("ASCII"),
"e": long_to_base64(public_key.public_numbers().e).decode("ASCII"),
}
if not self.is_public():
data.update(
{
"d": long_to_base64(self.prepared_key.private_numbers().d).decode("ASCII"),
"p": long_to_base64(self.prepared_key.private_numbers().p).decode("ASCII"),
"q": long_to_base64(self.prepared_key.private_numbers().q).decode("ASCII"),
"dp": long_to_base64(self.prepared_key.private_numbers().dmp1).decode("ASCII"),
"dq": long_to_base64(self.prepared_key.private_numbers().dmq1).decode("ASCII"),
"qi": long_to_base64(self.prepared_key.private_numbers().iqmp).decode("ASCII"),
}
)
return data
def wrap_key(self, key_data):
try:
wrapped_key = self.prepared_key.encrypt(key_data, self.padding)
except Exception as e:
raise JWEError(e)
return wrapped_key
def unwrap_key(self, wrapped_key):
try:
unwrapped_key = self.prepared_key.decrypt(wrapped_key, self.padding)
return unwrapped_key
except Exception as e:
raise JWEError(e)
class CryptographyAESKey(Key):
KEY_128 = (ALGORITHMS.A128GCM, ALGORITHMS.A128GCMKW, ALGORITHMS.A128KW, ALGORITHMS.A128CBC)
KEY_192 = (ALGORITHMS.A192GCM, ALGORITHMS.A192GCMKW, ALGORITHMS.A192KW, ALGORITHMS.A192CBC)
KEY_256 = (
ALGORITHMS.A256GCM,
ALGORITHMS.A256GCMKW,
ALGORITHMS.A256KW,
ALGORITHMS.A128CBC_HS256,
ALGORITHMS.A256CBC,
)
KEY_384 = (ALGORITHMS.A192CBC_HS384,)
KEY_512 = (ALGORITHMS.A256CBC_HS512,)
AES_KW_ALGS = (ALGORITHMS.A128KW, ALGORITHMS.A192KW, ALGORITHMS.A256KW)
MODES = {
ALGORITHMS.A128GCM: modes.GCM,
ALGORITHMS.A192GCM: modes.GCM,
ALGORITHMS.A256GCM: modes.GCM,
ALGORITHMS.A128CBC_HS256: modes.CBC,
ALGORITHMS.A192CBC_HS384: modes.CBC,
ALGORITHMS.A256CBC_HS512: modes.CBC,
ALGORITHMS.A128CBC: modes.CBC,
ALGORITHMS.A192CBC: modes.CBC,
ALGORITHMS.A256CBC: modes.CBC,
ALGORITHMS.A128GCMKW: modes.GCM,
ALGORITHMS.A192GCMKW: modes.GCM,
ALGORITHMS.A256GCMKW: modes.GCM,
ALGORITHMS.A128KW: None,
ALGORITHMS.A192KW: None,
ALGORITHMS.A256KW: None,
}
def __init__(self, key, algorithm):
if algorithm not in ALGORITHMS.AES:
raise JWKError("%s is not a valid AES algorithm" % algorithm)
if algorithm not in ALGORITHMS.SUPPORTED.union(ALGORITHMS.AES_PSEUDO):
raise JWKError("%s is not a supported algorithm" % algorithm)
self._algorithm = algorithm
self._mode = self.MODES.get(self._algorithm)
if algorithm in self.KEY_128 and len(key) != 16:
raise JWKError(f"Key must be 128 bit for alg {algorithm}")
elif algorithm in self.KEY_192 and len(key) != 24:
raise JWKError(f"Key must be 192 bit for alg {algorithm}")
elif algorithm in self.KEY_256 and len(key) != 32:
raise JWKError(f"Key must be 256 bit for alg {algorithm}")
elif algorithm in self.KEY_384 and len(key) != 48:
raise JWKError(f"Key must be 384 bit for alg {algorithm}")
elif algorithm in self.KEY_512 and len(key) != 64:
raise JWKError(f"Key must be 512 bit for alg {algorithm}")
self._key = key
def to_dict(self):
data = {"alg": self._algorithm, "kty": "oct", "k": base64url_encode(self._key)}
return data
def encrypt(self, plain_text, aad=None):
plain_text = ensure_binary(plain_text)
try:
iv = get_random_bytes(algorithms.AES.block_size // 8)
mode = self._mode(iv)
if mode.name == "GCM":
cipher = aead.AESGCM(self._key)
cipher_text_and_tag = cipher.encrypt(iv, plain_text, aad)
cipher_text = cipher_text_and_tag[: len(cipher_text_and_tag) - 16]
auth_tag = cipher_text_and_tag[-16:]
else:
cipher = Cipher(algorithms.AES(self._key), mode, backend=default_backend())
encryptor = cipher.encryptor()
padder = PKCS7(algorithms.AES.block_size).padder()
padded_data = padder.update(plain_text)
padded_data += padder.finalize()
cipher_text = encryptor.update(padded_data) + encryptor.finalize()
auth_tag = None
return iv, cipher_text, auth_tag
except Exception as e:
raise JWEError(e)
def decrypt(self, cipher_text, iv=None, aad=None, tag=None):
cipher_text = ensure_binary(cipher_text)
try:
iv = ensure_binary(iv)
mode = self._mode(iv)
if mode.name == "GCM":
if tag is None:
raise ValueError("tag cannot be None")
cipher = aead.AESGCM(self._key)
cipher_text_and_tag = cipher_text + tag
try:
plain_text = cipher.decrypt(iv, cipher_text_and_tag, aad)
except InvalidTag:
raise JWEError("Invalid JWE Auth Tag")
else:
cipher = Cipher(algorithms.AES(self._key), mode, backend=default_backend())
decryptor = cipher.decryptor()
padded_plain_text = decryptor.update(cipher_text)
padded_plain_text += decryptor.finalize()
unpadder = PKCS7(algorithms.AES.block_size).unpadder()
plain_text = unpadder.update(padded_plain_text)
plain_text += unpadder.finalize()
return plain_text
except Exception as e:
raise JWEError(e)
def wrap_key(self, key_data):
key_data = ensure_binary(key_data)
cipher_text = aes_key_wrap(self._key, key_data, default_backend())
return cipher_text # IV, cipher text, auth tag
def unwrap_key(self, wrapped_key):
wrapped_key = ensure_binary(wrapped_key)
try:
plain_text = aes_key_unwrap(self._key, wrapped_key, default_backend())
except InvalidUnwrap as cause:
raise JWEError(cause)
return plain_text
class CryptographyHMACKey(Key):
"""
Performs signing and verification operations using HMAC
and the specified hash function.
"""
ALG_MAP = {ALGORITHMS.HS256: hashes.SHA256(), ALGORITHMS.HS384: hashes.SHA384(), ALGORITHMS.HS512: hashes.SHA512()}
def __init__(self, key, algorithm):
if algorithm not in ALGORITHMS.HMAC:
raise JWKError("hash_alg: %s is not a valid hash algorithm" % algorithm)
self._algorithm = algorithm
self._hash_alg = self.ALG_MAP.get(algorithm)
if isinstance(key, dict):
self.prepared_key = self._process_jwk(key)
return
if not isinstance(key, str) and not isinstance(key, bytes):
raise JWKError("Expecting a string- or bytes-formatted key.")
if isinstance(key, str):
key = key.encode("utf-8")
invalid_strings = [
b"-----BEGIN PUBLIC KEY-----",
b"-----BEGIN RSA PUBLIC KEY-----",
b"-----BEGIN CERTIFICATE-----",
b"ssh-rsa",
]
if any(string_value in key for string_value in invalid_strings):
raise JWKError(
"The specified key is an asymmetric key or x509 certificate and"
" should not be used as an HMAC secret."
)
self.prepared_key = key
def _process_jwk(self, jwk_dict):
if not jwk_dict.get("kty") == "oct":
raise JWKError("Incorrect key type. Expected: 'oct', Received: %s" % jwk_dict.get("kty"))
k = jwk_dict.get("k")
k = k.encode("utf-8")
k = bytes(k)
k = base64url_decode(k)
return k
def to_dict(self):
return {
"alg": self._algorithm,
"kty": "oct",
"k": base64url_encode(self.prepared_key).decode("ASCII"),
}
def sign(self, msg):
msg = ensure_binary(msg)
h = hmac.HMAC(self.prepared_key, self._hash_alg, backend=default_backend())
h.update(msg)
signature = h.finalize()
return signature
def verify(self, msg, sig):
msg = ensure_binary(msg)
sig = ensure_binary(sig)
h = hmac.HMAC(self.prepared_key, self._hash_alg, backend=default_backend())
h.update(msg)
try:
h.verify(sig)
verified = True
except InvalidSignature:
verified = False
return verified

View File

@@ -0,0 +1,150 @@
import hashlib
import ecdsa
from jose.backends.base import Key
from jose.constants import ALGORITHMS
from jose.exceptions import JWKError
from jose.utils import base64_to_long, long_to_base64
class ECDSAECKey(Key):
"""
Performs signing and verification operations using
ECDSA and the specified hash function
This class requires the ecdsa package to be installed.
This is based off of the implementation in PyJWT 0.3.2
"""
SHA256 = hashlib.sha256
SHA384 = hashlib.sha384
SHA512 = hashlib.sha512
CURVE_MAP = {
SHA256: ecdsa.curves.NIST256p,
SHA384: ecdsa.curves.NIST384p,
SHA512: ecdsa.curves.NIST521p,
}
CURVE_NAMES = (
(ecdsa.curves.NIST256p, "P-256"),
(ecdsa.curves.NIST384p, "P-384"),
(ecdsa.curves.NIST521p, "P-521"),
)
def __init__(self, key, algorithm):
if algorithm not in ALGORITHMS.EC:
raise JWKError("hash_alg: %s is not a valid hash algorithm" % algorithm)
self.hash_alg = {
ALGORITHMS.ES256: self.SHA256,
ALGORITHMS.ES384: self.SHA384,
ALGORITHMS.ES512: self.SHA512,
}.get(algorithm)
self._algorithm = algorithm
self.curve = self.CURVE_MAP.get(self.hash_alg)
if isinstance(key, (ecdsa.SigningKey, ecdsa.VerifyingKey)):
self.prepared_key = key
return
if isinstance(key, dict):
self.prepared_key = self._process_jwk(key)
return
if isinstance(key, str):
key = key.encode("utf-8")
if isinstance(key, bytes):
# Attempt to load key. We don't know if it's
# a Signing Key or a Verifying Key, so we try
# the Verifying Key first.
try:
key = ecdsa.VerifyingKey.from_pem(key)
except ecdsa.der.UnexpectedDER:
key = ecdsa.SigningKey.from_pem(key)
except Exception as e:
raise JWKError(e)
self.prepared_key = key
return
raise JWKError("Unable to parse an ECKey from key: %s" % key)
def _process_jwk(self, jwk_dict):
if not jwk_dict.get("kty") == "EC":
raise JWKError("Incorrect key type. Expected: 'EC', Received: %s" % jwk_dict.get("kty"))
if not all(k in jwk_dict for k in ["x", "y", "crv"]):
raise JWKError("Mandatory parameters are missing")
if "d" in jwk_dict:
# We are dealing with a private key; the secret exponent is enough
# to create an ecdsa key.
d = base64_to_long(jwk_dict.get("d"))
return ecdsa.keys.SigningKey.from_secret_exponent(d, self.curve)
else:
x = base64_to_long(jwk_dict.get("x"))
y = base64_to_long(jwk_dict.get("y"))
if not ecdsa.ecdsa.point_is_valid(self.curve.generator, x, y):
raise JWKError(f"Point: {x}, {y} is not a valid point")
point = ecdsa.ellipticcurve.Point(self.curve.curve, x, y, self.curve.order)
return ecdsa.keys.VerifyingKey.from_public_point(point, self.curve)
def sign(self, msg):
return self.prepared_key.sign(
msg, hashfunc=self.hash_alg, sigencode=ecdsa.util.sigencode_string, allow_truncate=False
)
def verify(self, msg, sig):
try:
return self.prepared_key.verify(
sig, msg, hashfunc=self.hash_alg, sigdecode=ecdsa.util.sigdecode_string, allow_truncate=False
)
except Exception:
return False
def is_public(self):
return isinstance(self.prepared_key, ecdsa.VerifyingKey)
def public_key(self):
if self.is_public():
return self
return self.__class__(self.prepared_key.get_verifying_key(), self._algorithm)
def to_pem(self):
return self.prepared_key.to_pem()
def to_dict(self):
if not self.is_public():
public_key = self.prepared_key.get_verifying_key()
else:
public_key = self.prepared_key
crv = None
for key, value in self.CURVE_NAMES:
if key == self.prepared_key.curve:
crv = value
if not crv:
raise KeyError(f"Can't match {self.prepared_key.curve}")
# Calculate the key size in bytes. Section 6.2.1.2 and 6.2.1.3 of
# RFC7518 prescribes that the 'x', 'y' and 'd' parameters of the curve
# points must be encoded as octed-strings of this length.
key_size = self.prepared_key.curve.baselen
data = {
"alg": self._algorithm,
"kty": "EC",
"crv": crv,
"x": long_to_base64(public_key.pubkey.point.x(), size=key_size).decode("ASCII"),
"y": long_to_base64(public_key.pubkey.point.y(), size=key_size).decode("ASCII"),
}
if not self.is_public():
data["d"] = long_to_base64(self.prepared_key.privkey.secret_multiplier, size=key_size).decode("ASCII")
return data

View File

@@ -0,0 +1,76 @@
import hashlib
import hmac
import os
from jose.backends.base import Key
from jose.constants import ALGORITHMS
from jose.exceptions import JWKError
from jose.utils import base64url_decode, base64url_encode
def get_random_bytes(num_bytes):
return bytes(os.urandom(num_bytes))
class HMACKey(Key):
"""
Performs signing and verification operations using HMAC
and the specified hash function.
"""
HASHES = {ALGORITHMS.HS256: hashlib.sha256, ALGORITHMS.HS384: hashlib.sha384, ALGORITHMS.HS512: hashlib.sha512}
def __init__(self, key, algorithm):
if algorithm not in ALGORITHMS.HMAC:
raise JWKError("hash_alg: %s is not a valid hash algorithm" % algorithm)
self._algorithm = algorithm
self._hash_alg = self.HASHES.get(algorithm)
if isinstance(key, dict):
self.prepared_key = self._process_jwk(key)
return
if not isinstance(key, str) and not isinstance(key, bytes):
raise JWKError("Expecting a string- or bytes-formatted key.")
if isinstance(key, str):
key = key.encode("utf-8")
invalid_strings = [
b"-----BEGIN PUBLIC KEY-----",
b"-----BEGIN RSA PUBLIC KEY-----",
b"-----BEGIN CERTIFICATE-----",
b"ssh-rsa",
]
if any(string_value in key for string_value in invalid_strings):
raise JWKError(
"The specified key is an asymmetric key or x509 certificate and"
" should not be used as an HMAC secret."
)
self.prepared_key = key
def _process_jwk(self, jwk_dict):
if not jwk_dict.get("kty") == "oct":
raise JWKError("Incorrect key type. Expected: 'oct', Received: %s" % jwk_dict.get("kty"))
k = jwk_dict.get("k")
k = k.encode("utf-8")
k = bytes(k)
k = base64url_decode(k)
return k
def sign(self, msg):
return hmac.new(self.prepared_key, msg, self._hash_alg).digest()
def verify(self, msg, sig):
return hmac.compare_digest(sig, self.sign(msg))
def to_dict(self):
return {
"alg": self._algorithm,
"kty": "oct",
"k": base64url_encode(self.prepared_key).decode("ASCII"),
}

View File

@@ -0,0 +1,284 @@
import binascii
import warnings
import rsa as pyrsa
import rsa.pem as pyrsa_pem
from pyasn1.error import PyAsn1Error
from rsa import DecryptionError
from jose.backends._asn1 import (
rsa_private_key_pkcs1_to_pkcs8,
rsa_private_key_pkcs8_to_pkcs1,
rsa_public_key_pkcs1_to_pkcs8,
)
from jose.backends.base import Key
from jose.constants import ALGORITHMS
from jose.exceptions import JWEError, JWKError
from jose.utils import base64_to_long, long_to_base64
ALGORITHMS.SUPPORTED.remove(ALGORITHMS.RSA_OAEP) # RSA OAEP not supported
LEGACY_INVALID_PKCS8_RSA_HEADER = binascii.unhexlify(
"30" # sequence
"8204BD" # DER-encoded sequence contents length of 1213 bytes -- INCORRECT STATIC LENGTH
"020100" # integer: 0 -- Version
"30" # sequence
"0D" # DER-encoded sequence contents length of 13 bytes -- PrivateKeyAlgorithmIdentifier
"06092A864886F70D010101" # OID -- rsaEncryption
"0500" # NULL -- parameters
)
ASN1_SEQUENCE_ID = binascii.unhexlify("30")
RSA_ENCRYPTION_ASN1_OID = "1.2.840.113549.1.1.1"
# Functions gcd and rsa_recover_prime_factors were copied from cryptography 1.9
# to enable pure python rsa module to be in compliance with section 6.3.1 of RFC7518
# which requires only private exponent (d) for private key.
def _gcd(a, b):
"""Calculate the Greatest Common Divisor of a and b.
Unless b==0, the result will have the same sign as b (so that when
b is divided by it, the result comes out positive).
"""
while b:
a, b = b, (a % b)
return a
# Controls the number of iterations rsa_recover_prime_factors will perform
# to obtain the prime factors. Each iteration increments by 2 so the actual
# maximum attempts is half this number.
_MAX_RECOVERY_ATTEMPTS = 1000
def _rsa_recover_prime_factors(n, e, d):
"""
Compute factors p and q from the private exponent d. We assume that n has
no more than two factors. This function is adapted from code in PyCrypto.
"""
# See 8.2.2(i) in Handbook of Applied Cryptography.
ktot = d * e - 1
# The quantity d*e-1 is a multiple of phi(n), even,
# and can be represented as t*2^s.
t = ktot
while t % 2 == 0:
t = t // 2
# Cycle through all multiplicative inverses in Zn.
# The algorithm is non-deterministic, but there is a 50% chance
# any candidate a leads to successful factoring.
# See "Digitalized Signatures and Public Key Functions as Intractable
# as Factorization", M. Rabin, 1979
spotted = False
a = 2
while not spotted and a < _MAX_RECOVERY_ATTEMPTS:
k = t
# Cycle through all values a^{t*2^i}=a^k
while k < ktot:
cand = pow(a, k, n)
# Check if a^k is a non-trivial root of unity (mod n)
if cand != 1 and cand != (n - 1) and pow(cand, 2, n) == 1:
# We have found a number such that (cand-1)(cand+1)=0 (mod n).
# Either of the terms divides n.
p = _gcd(cand + 1, n)
spotted = True
break
k *= 2
# This value was not any good... let's try another!
a += 2
if not spotted:
raise ValueError("Unable to compute factors p and q from exponent d.")
# Found !
q, r = divmod(n, p)
assert r == 0
p, q = sorted((p, q), reverse=True)
return (p, q)
def pem_to_spki(pem, fmt="PKCS8"):
key = RSAKey(pem, ALGORITHMS.RS256)
return key.to_pem(fmt)
def _legacy_private_key_pkcs8_to_pkcs1(pkcs8_key):
"""Legacy RSA private key PKCS8-to-PKCS1 conversion.
.. warning::
This is incorrect parsing and only works because the legacy PKCS1-to-PKCS8
encoding was also incorrect.
"""
# Only allow this processing if the prefix matches
# AND the following byte indicates an ASN1 sequence,
# as we would expect with the legacy encoding.
if not pkcs8_key.startswith(LEGACY_INVALID_PKCS8_RSA_HEADER + ASN1_SEQUENCE_ID):
raise ValueError("Invalid private key encoding")
return pkcs8_key[len(LEGACY_INVALID_PKCS8_RSA_HEADER) :]
class RSAKey(Key):
SHA256 = "SHA-256"
SHA384 = "SHA-384"
SHA512 = "SHA-512"
def __init__(self, key, algorithm):
if algorithm not in ALGORITHMS.RSA:
raise JWKError("hash_alg: %s is not a valid hash algorithm" % algorithm)
if algorithm in ALGORITHMS.RSA_KW and algorithm != ALGORITHMS.RSA1_5:
raise JWKError("alg: %s is not supported by the RSA backend" % algorithm)
self.hash_alg = {
ALGORITHMS.RS256: self.SHA256,
ALGORITHMS.RS384: self.SHA384,
ALGORITHMS.RS512: self.SHA512,
}.get(algorithm)
self._algorithm = algorithm
if isinstance(key, dict):
self._prepared_key = self._process_jwk(key)
return
if isinstance(key, (pyrsa.PublicKey, pyrsa.PrivateKey)):
self._prepared_key = key
return
if isinstance(key, str):
key = key.encode("utf-8")
if isinstance(key, bytes):
try:
self._prepared_key = pyrsa.PublicKey.load_pkcs1(key)
except ValueError:
try:
self._prepared_key = pyrsa.PublicKey.load_pkcs1_openssl_pem(key)
except ValueError:
try:
self._prepared_key = pyrsa.PrivateKey.load_pkcs1(key)
except ValueError:
try:
der = pyrsa_pem.load_pem(key, b"PRIVATE KEY")
try:
pkcs1_key = rsa_private_key_pkcs8_to_pkcs1(der)
except PyAsn1Error:
# If the key was encoded using the old, invalid,
# encoding then pyasn1 will throw an error attempting
# to parse the key.
pkcs1_key = _legacy_private_key_pkcs8_to_pkcs1(der)
self._prepared_key = pyrsa.PrivateKey.load_pkcs1(pkcs1_key, format="DER")
except ValueError as e:
raise JWKError(e)
return
raise JWKError("Unable to parse an RSA_JWK from key: %s" % key)
def _process_jwk(self, jwk_dict):
if not jwk_dict.get("kty") == "RSA":
raise JWKError("Incorrect key type. Expected: 'RSA', Received: %s" % jwk_dict.get("kty"))
e = base64_to_long(jwk_dict.get("e"))
n = base64_to_long(jwk_dict.get("n"))
if "d" not in jwk_dict:
return pyrsa.PublicKey(e=e, n=n)
else:
d = base64_to_long(jwk_dict.get("d"))
extra_params = ["p", "q", "dp", "dq", "qi"]
if any(k in jwk_dict for k in extra_params):
# Precomputed private key parameters are available.
if not all(k in jwk_dict for k in extra_params):
# These values must be present when 'p' is according to
# Section 6.3.2 of RFC7518, so if they are not we raise
# an error.
raise JWKError("Precomputed private key parameters are incomplete.")
p = base64_to_long(jwk_dict["p"])
q = base64_to_long(jwk_dict["q"])
return pyrsa.PrivateKey(e=e, n=n, d=d, p=p, q=q)
else:
p, q = _rsa_recover_prime_factors(n, e, d)
return pyrsa.PrivateKey(n=n, e=e, d=d, p=p, q=q)
def sign(self, msg):
return pyrsa.sign(msg, self._prepared_key, self.hash_alg)
def verify(self, msg, sig):
if not self.is_public():
warnings.warn("Attempting to verify a message with a private key. " "This is not recommended.")
try:
pyrsa.verify(msg, sig, self._prepared_key)
return True
except pyrsa.pkcs1.VerificationError:
return False
def is_public(self):
return isinstance(self._prepared_key, pyrsa.PublicKey)
def public_key(self):
if isinstance(self._prepared_key, pyrsa.PublicKey):
return self
return self.__class__(pyrsa.PublicKey(n=self._prepared_key.n, e=self._prepared_key.e), self._algorithm)
def to_pem(self, pem_format="PKCS8"):
if isinstance(self._prepared_key, pyrsa.PrivateKey):
der = self._prepared_key.save_pkcs1(format="DER")
if pem_format == "PKCS8":
pkcs8_der = rsa_private_key_pkcs1_to_pkcs8(der)
pem = pyrsa_pem.save_pem(pkcs8_der, pem_marker="PRIVATE KEY")
elif pem_format == "PKCS1":
pem = pyrsa_pem.save_pem(der, pem_marker="RSA PRIVATE KEY")
else:
raise ValueError(f"Invalid pem format specified: {pem_format!r}")
else:
if pem_format == "PKCS8":
pkcs1_der = self._prepared_key.save_pkcs1(format="DER")
pkcs8_der = rsa_public_key_pkcs1_to_pkcs8(pkcs1_der)
pem = pyrsa_pem.save_pem(pkcs8_der, pem_marker="PUBLIC KEY")
elif pem_format == "PKCS1":
der = self._prepared_key.save_pkcs1(format="DER")
pem = pyrsa_pem.save_pem(der, pem_marker="RSA PUBLIC KEY")
else:
raise ValueError(f"Invalid pem format specified: {pem_format!r}")
return pem
def to_dict(self):
if not self.is_public():
public_key = self.public_key()._prepared_key
else:
public_key = self._prepared_key
data = {
"alg": self._algorithm,
"kty": "RSA",
"n": long_to_base64(public_key.n).decode("ASCII"),
"e": long_to_base64(public_key.e).decode("ASCII"),
}
if not self.is_public():
data.update(
{
"d": long_to_base64(self._prepared_key.d).decode("ASCII"),
"p": long_to_base64(self._prepared_key.p).decode("ASCII"),
"q": long_to_base64(self._prepared_key.q).decode("ASCII"),
"dp": long_to_base64(self._prepared_key.exp1).decode("ASCII"),
"dq": long_to_base64(self._prepared_key.exp2).decode("ASCII"),
"qi": long_to_base64(self._prepared_key.coef).decode("ASCII"),
}
)
return data
def wrap_key(self, key_data):
if not self.is_public():
warnings.warn("Attempting to encrypt a message with a private key." " This is not recommended.")
wrapped_key = pyrsa.encrypt(key_data, self._prepared_key)
return wrapped_key
def unwrap_key(self, wrapped_key):
try:
unwrapped_key = pyrsa.decrypt(wrapped_key, self._prepared_key)
except DecryptionError as e:
raise JWEError(e)
return unwrapped_key

View File

@@ -0,0 +1,98 @@
import hashlib
class Algorithms:
# DS Algorithms
NONE = "none"
HS256 = "HS256"
HS384 = "HS384"
HS512 = "HS512"
RS256 = "RS256"
RS384 = "RS384"
RS512 = "RS512"
ES256 = "ES256"
ES384 = "ES384"
ES512 = "ES512"
# Content Encryption Algorithms
A128CBC_HS256 = "A128CBC-HS256"
A192CBC_HS384 = "A192CBC-HS384"
A256CBC_HS512 = "A256CBC-HS512"
A128GCM = "A128GCM"
A192GCM = "A192GCM"
A256GCM = "A256GCM"
# Pseudo algorithm for encryption
A128CBC = "A128CBC"
A192CBC = "A192CBC"
A256CBC = "A256CBC"
# CEK Encryption Algorithms
DIR = "dir"
RSA1_5 = "RSA1_5"
RSA_OAEP = "RSA-OAEP"
RSA_OAEP_256 = "RSA-OAEP-256"
A128KW = "A128KW"
A192KW = "A192KW"
A256KW = "A256KW"
ECDH_ES = "ECDH-ES"
ECDH_ES_A128KW = "ECDH-ES+A128KW"
ECDH_ES_A192KW = "ECDH-ES+A192KW"
ECDH_ES_A256KW = "ECDH-ES+A256KW"
A128GCMKW = "A128GCMKW"
A192GCMKW = "A192GCMKW"
A256GCMKW = "A256GCMKW"
PBES2_HS256_A128KW = "PBES2-HS256+A128KW"
PBES2_HS384_A192KW = "PBES2-HS384+A192KW"
PBES2_HS512_A256KW = "PBES2-HS512+A256KW"
# Compression Algorithms
DEF = "DEF"
HMAC = {HS256, HS384, HS512}
RSA_DS = {RS256, RS384, RS512}
RSA_KW = {RSA1_5, RSA_OAEP, RSA_OAEP_256}
RSA = RSA_DS.union(RSA_KW)
EC_DS = {ES256, ES384, ES512}
EC_KW = {ECDH_ES, ECDH_ES_A128KW, ECDH_ES_A192KW, ECDH_ES_A256KW}
EC = EC_DS.union(EC_KW)
AES_PSEUDO = {A128CBC, A192CBC, A256CBC, A128GCM, A192GCM, A256GCM}
AES_JWE_ENC = {A128CBC_HS256, A192CBC_HS384, A256CBC_HS512, A128GCM, A192GCM, A256GCM}
AES_ENC = AES_JWE_ENC.union(AES_PSEUDO)
AES_KW = {A128KW, A192KW, A256KW}
AEC_GCM_KW = {A128GCMKW, A192GCMKW, A256GCMKW}
AES = AES_ENC.union(AES_KW)
PBES2_KW = {PBES2_HS256_A128KW, PBES2_HS384_A192KW, PBES2_HS512_A256KW}
HMAC_AUTH_TAG = {A128CBC_HS256, A192CBC_HS384, A256CBC_HS512}
GCM = {A128GCM, A192GCM, A256GCM}
SUPPORTED = HMAC.union(RSA_DS).union(EC_DS).union([DIR]).union(AES_JWE_ENC).union(RSA_KW).union(AES_KW)
ALL = SUPPORTED.union([NONE]).union(AEC_GCM_KW).union(EC_KW).union(PBES2_KW)
HASHES = {
HS256: hashlib.sha256,
HS384: hashlib.sha384,
HS512: hashlib.sha512,
RS256: hashlib.sha256,
RS384: hashlib.sha384,
RS512: hashlib.sha512,
ES256: hashlib.sha256,
ES384: hashlib.sha384,
ES512: hashlib.sha512,
}
KEYS = {}
ALGORITHMS = Algorithms()
class Zips:
DEF = "DEF"
NONE = None
SUPPORTED = {DEF, NONE}
ZIPS = Zips()

View File

@@ -0,0 +1,59 @@
class JOSEError(Exception):
pass
class JWSError(JOSEError):
pass
class JWSSignatureError(JWSError):
pass
class JWSAlgorithmError(JWSError):
pass
class JWTError(JOSEError):
pass
class JWTClaimsError(JWTError):
pass
class ExpiredSignatureError(JWTError):
pass
class JWKError(JOSEError):
pass
class JWEError(JOSEError):
"""Base error for all JWE errors"""
pass
class JWEParseError(JWEError):
"""Could not parse the JWE string provided"""
pass
class JWEInvalidAuth(JWEError):
"""
The authentication tag did not match the protected sections of the
JWE string provided
"""
pass
class JWEAlgorithmUnsupportedError(JWEError):
"""
The JWE algorithm is not supported by the backend
"""
pass

View File

@@ -0,0 +1,607 @@
import binascii
import json
import zlib
from collections.abc import Mapping
from struct import pack
from . import jwk
from .backends import get_random_bytes
from .constants import ALGORITHMS, ZIPS
from .exceptions import JWEError, JWEParseError
from .utils import base64url_decode, base64url_encode, ensure_binary
def encrypt(plaintext, key, encryption=ALGORITHMS.A256GCM, algorithm=ALGORITHMS.DIR, zip=None, cty=None, kid=None):
"""Encrypts plaintext and returns a JWE cmpact serialization string.
Args:
plaintext (bytes): A bytes object to encrypt
key (str or dict): The key(s) to use for encrypting the content. Can be
individual JWK or JWK set.
encryption (str, optional): The content encryption algorithm used to
perform authenticated encryption on the plaintext to produce the
ciphertext and the Authentication Tag. Defaults to A256GCM.
algorithm (str, optional): The cryptographic algorithm used
to encrypt or determine the value of the CEK. Defaults to dir.
zip (str, optional): The compression algorithm) applied to the
plaintext before encryption. Defaults to None.
cty (str, optional): The media type for the secured content.
See http://www.iana.org/assignments/media-types/media-types.xhtml
kid (str, optional): Key ID for the provided key
Returns:
bytes: The string representation of the header, encrypted key,
initialization vector, ciphertext, and authentication tag.
Raises:
JWEError: If there is an error signing the token.
Examples:
>>> from jose import jwe
>>> jwe.encrypt('Hello, World!', 'asecret128bitkey', algorithm='dir', encryption='A128GCM')
'eyJhbGciOiJkaXIiLCJlbmMiOiJBMTI4R0NNIn0..McILMB3dYsNJSuhcDzQshA.OfX9H_mcUpHDeRM4IA.CcnTWqaqxNsjT4eCaUABSg'
"""
plaintext = ensure_binary(plaintext) # Make sure it's bytes
if algorithm not in ALGORITHMS.SUPPORTED:
raise JWEError("Algorithm %s not supported." % algorithm)
if encryption not in ALGORITHMS.SUPPORTED:
raise JWEError("Algorithm %s not supported." % encryption)
key = jwk.construct(key, algorithm)
encoded_header = _encoded_header(algorithm, encryption, zip, cty, kid)
plaintext = _compress(zip, plaintext)
enc_cek, iv, cipher_text, auth_tag = _encrypt_and_auth(key, algorithm, encryption, zip, plaintext, encoded_header)
jwe_string = _jwe_compact_serialize(encoded_header, enc_cek, iv, cipher_text, auth_tag)
return jwe_string
def decrypt(jwe_str, key):
"""Decrypts a JWE compact serialized string and returns the plaintext.
Args:
jwe_str (str): A JWE to be decrypt.
key (str or dict): A key to attempt to decrypt the payload with. Can be
individual JWK or JWK set.
Returns:
bytes: The plaintext bytes, assuming the authentication tag is valid.
Raises:
JWEError: If there is an exception verifying the token.
Examples:
>>> from jose import jwe
>>> jwe.decrypt(jwe_string, 'asecret128bitkey')
'Hello, World!'
"""
header, encoded_header, encrypted_key, iv, cipher_text, auth_tag = _jwe_compact_deserialize(jwe_str)
# Verify that the implementation understands and can process all
# fields that it is required to support, whether required by this
# specification, by the algorithms being used, or by the "crit"
# Header Parameter value, and that the values of those parameters
# are also understood and supported.
try:
# Determine the Key Management Mode employed by the algorithm
# specified by the "alg" (algorithm) Header Parameter.
alg = header["alg"]
enc = header["enc"]
if alg not in ALGORITHMS.SUPPORTED:
raise JWEError("Algorithm %s not supported." % alg)
if enc not in ALGORITHMS.SUPPORTED:
raise JWEError("Algorithm %s not supported." % enc)
except KeyError:
raise JWEParseError("alg and enc headers are required!")
# Verify that the JWE uses a key known to the recipient.
key = jwk.construct(key, alg)
# When Direct Key Agreement or Key Agreement with Key Wrapping are
# employed, use the key agreement algorithm to compute the value
# of the agreed upon key. When Direct Key Agreement is employed,
# let the CEK be the agreed upon key. When Key Agreement with Key
# Wrapping is employed, the agreed upon key will be used to
# decrypt the JWE Encrypted Key.
#
# When Key Wrapping, Key Encryption, or Key Agreement with Key
# Wrapping are employed, decrypt the JWE Encrypted Key to produce
# the CEK. The CEK MUST have a length equal to that required for
# the content encryption algorithm. Note that when there are
# multiple recipients, each recipient will only be able to decrypt
# JWE Encrypted Key values that were encrypted to a key in that
# recipient's possession. It is therefore normal to only be able
# to decrypt one of the per-recipient JWE Encrypted Key values to
# obtain the CEK value. Also, see Section 11.5 for security
# considerations on mitigating timing attacks.
if alg == ALGORITHMS.DIR:
# When Direct Key Agreement or Direct Encryption are employed,
# verify that the JWE Encrypted Key value is an empty octet
# sequence.
# Record whether the CEK could be successfully determined for this
# recipient or not.
cek_valid = encrypted_key == b""
# When Direct Encryption is employed, let the CEK be the shared
# symmetric key.
cek_bytes = _get_key_bytes_from_key(key)
else:
try:
cek_bytes = key.unwrap_key(encrypted_key)
# Record whether the CEK could be successfully determined for this
# recipient or not.
cek_valid = True
except NotImplementedError:
raise JWEError(f"alg {alg} is not implemented")
except Exception:
# Record whether the CEK could be successfully determined for this
# recipient or not.
cek_valid = False
# To mitigate the attacks described in RFC 3218 [RFC3218], the
# recipient MUST NOT distinguish between format, padding, and length
# errors of encrypted keys. It is strongly recommended, in the event
# of receiving an improperly formatted key, that the recipient
# substitute a randomly generated CEK and proceed to the next step, to
# mitigate timing attacks.
cek_bytes = _get_random_cek_bytes_for_enc(enc)
# Compute the Encoded Protected Header value BASE64URL(UTF8(JWE
# Protected Header)). If the JWE Protected Header is not present
# (which can only happen when using the JWE JSON Serialization and
# no "protected" member is present), let this value be the empty
# string.
protected_header = encoded_header
# Let the Additional Authenticated Data encryption parameter be
# ASCII(Encoded Protected Header). However, if a JWE AAD value is
# present (which can only be the case when using the JWE JSON
# Serialization), instead let the Additional Authenticated Data
# encryption parameter be ASCII(Encoded Protected Header || '.' ||
# BASE64URL(JWE AAD)).
aad = protected_header
# Decrypt the JWE Ciphertext using the CEK, the JWE Initialization
# Vector, the Additional Authenticated Data value, and the JWE
# Authentication Tag (which is the Authentication Tag input to the
# calculation) using the specified content encryption algorithm,
# returning the decrypted plaintext and validating the JWE
# Authentication Tag in the manner specified for the algorithm,
# rejecting the input without emitting any decrypted output if the
# JWE Authentication Tag is incorrect.
try:
plain_text = _decrypt_and_auth(cek_bytes, enc, cipher_text, iv, aad, auth_tag)
except NotImplementedError:
raise JWEError(f"enc {enc} is not implemented")
except Exception as e:
raise JWEError(e)
# If a "zip" parameter was included, uncompress the decrypted
# plaintext using the specified compression algorithm.
if plain_text is not None:
plain_text = _decompress(header.get("zip"), plain_text)
return plain_text if cek_valid else None
def get_unverified_header(jwe_str):
"""Returns the decoded headers without verification of any kind.
Args:
jwe_str (str): A compact serialized JWE to decode the headers from.
Returns:
dict: The dict representation of the JWE headers.
Raises:
JWEError: If there is an exception decoding the JWE.
"""
header = _jwe_compact_deserialize(jwe_str)[0]
return header
def _decrypt_and_auth(cek_bytes, enc, cipher_text, iv, aad, auth_tag):
"""
Decrypt and verify the data
Args:
cek_bytes (bytes): cek to derive encryption and possible auth key to
verify the auth tag
cipher_text (bytes): Encrypted data
iv (bytes): Initialization vector (iv) used to encrypt data
aad (bytes): Additional Authenticated Data used to verify the data
auth_tag (bytes): Authentication ntag to verify the data
Returns:
(bytes): Decrypted data
"""
# Decrypt the JWE Ciphertext using the CEK, the JWE Initialization
# Vector, the Additional Authenticated Data value, and the JWE
# Authentication Tag (which is the Authentication Tag input to the
# calculation) using the specified content encryption algorithm,
# returning the decrypted plaintext
# and validating the JWE
# Authentication Tag in the manner specified for the algorithm,
if enc in ALGORITHMS.HMAC_AUTH_TAG:
encryption_key, mac_key, key_len = _get_encryption_key_mac_key_and_key_length_from_cek(cek_bytes, enc)
auth_tag_check = _auth_tag(cipher_text, iv, aad, mac_key, key_len)
elif enc in ALGORITHMS.GCM:
encryption_key = jwk.construct(cek_bytes, enc)
auth_tag_check = auth_tag # GCM check auth on decrypt
else:
raise NotImplementedError(f"enc {enc} is not implemented!")
plaintext = encryption_key.decrypt(cipher_text, iv, aad, auth_tag)
if auth_tag != auth_tag_check:
raise JWEError("Invalid JWE Auth Tag")
return plaintext
def _get_encryption_key_mac_key_and_key_length_from_cek(cek_bytes, enc):
derived_key_len = len(cek_bytes) // 2
mac_key_bytes = cek_bytes[0:derived_key_len]
mac_key = _get_hmac_key(enc, mac_key_bytes)
encryption_key_bytes = cek_bytes[-derived_key_len:]
encryption_alg, _ = enc.split("-")
encryption_key = jwk.construct(encryption_key_bytes, encryption_alg)
return encryption_key, mac_key, derived_key_len
def _jwe_compact_deserialize(jwe_bytes):
"""
Deserialize and verify the header and segments are appropriate.
Args:
jwe_bytes (bytes): The compact serialized JWE
Returns:
(dict, bytes, bytes, bytes, bytes, bytes)
"""
# Base64url decode the encoded representations of the JWE
# Protected Header, the JWE Encrypted Key, the JWE Initialization
# Vector, the JWE Ciphertext, the JWE Authentication Tag, and the
# JWE AAD, following the restriction that no line breaks,
# whitespace, or other additional characters have been used.
jwe_bytes = ensure_binary(jwe_bytes)
try:
header_segment, encrypted_key_segment, iv_segment, cipher_text_segment, auth_tag_segment = jwe_bytes.split(
b".", 4
)
header_data = base64url_decode(header_segment)
except ValueError:
raise JWEParseError("Not enough segments")
except (TypeError, binascii.Error):
raise JWEParseError("Invalid header")
# Verify that the octet sequence resulting from decoding the
# encoded JWE Protected Header is a UTF-8-encoded representation
# of a completely valid JSON object conforming to RFC 7159
# [RFC7159]; let the JWE Protected Header be this JSON object.
#
# If using the JWE Compact Serialization, let the JOSE Header be
# the JWE Protected Header. Otherwise, when using the JWE JSON
# Serialization, let the JOSE Header be the union of the members
# of the JWE Protected Header, the JWE Shared Unprotected Header
# and the corresponding JWE Per-Recipient Unprotected Header, all
# of which must be completely valid JSON objects. During this
# step, verify that the resulting JOSE Header does not contain
# duplicate Header Parameter names. When using the JWE JSON
# Serialization, this restriction includes that the same Header
# Parameter name also MUST NOT occur in distinct JSON object
# values that together comprise the JOSE Header.
try:
header = json.loads(header_data)
except ValueError as e:
raise JWEParseError(f"Invalid header string: {e}")
if not isinstance(header, Mapping):
raise JWEParseError("Invalid header string: must be a json object")
try:
encrypted_key = base64url_decode(encrypted_key_segment)
except (TypeError, binascii.Error):
raise JWEParseError("Invalid encrypted key")
try:
iv = base64url_decode(iv_segment)
except (TypeError, binascii.Error):
raise JWEParseError("Invalid IV")
try:
ciphertext = base64url_decode(cipher_text_segment)
except (TypeError, binascii.Error):
raise JWEParseError("Invalid cyphertext")
try:
auth_tag = base64url_decode(auth_tag_segment)
except (TypeError, binascii.Error):
raise JWEParseError("Invalid auth tag")
return header, header_segment, encrypted_key, iv, ciphertext, auth_tag
def _encoded_header(alg, enc, zip, cty, kid):
"""
Generate an appropriate JOSE header based on the values provided
Args:
alg (str): Key wrap/negotiation algorithm
enc (str): Encryption algorithm
zip (str): Compression method
cty (str): Content type of the encrypted data
kid (str): ID for the key used for the operation
Returns:
bytes: JSON object of header based on input
"""
header = {"alg": alg, "enc": enc}
if zip:
header["zip"] = zip
if cty:
header["cty"] = cty
if kid:
header["kid"] = kid
json_header = json.dumps(
header,
separators=(",", ":"),
sort_keys=True,
).encode("utf-8")
return base64url_encode(json_header)
def _big_endian(int_val):
return pack("!Q", int_val)
def _encrypt_and_auth(key, alg, enc, zip, plaintext, aad):
"""
Generate a content encryption key (cek) and initialization
vector (iv) based on enc and alg, compress the plaintext based on zip,
encrypt the compressed plaintext using the cek and iv based on enc
Args:
key (Key): The key provided for encryption
alg (str): The algorithm use for key wrap/negotiation
enc (str): The encryption algorithm with which to encrypt the plaintext
zip (str): The compression algorithm with which to compress the plaintext
plaintext (bytes): The data to encrypt
aad (str): Additional authentication data utilized for generating an
auth tag
Returns:
(bytes, bytes, bytes, bytes): A tuple of the following data
(key wrapped cek, iv, cipher text, auth tag)
"""
try:
cek_bytes, kw_cek = _get_cek(enc, alg, key)
except NotImplementedError:
raise JWEError(f"alg {alg} is not implemented")
if enc in ALGORITHMS.HMAC_AUTH_TAG:
encryption_key, mac_key, key_len = _get_encryption_key_mac_key_and_key_length_from_cek(cek_bytes, enc)
iv, ciphertext, tag = encryption_key.encrypt(plaintext, aad)
auth_tag = _auth_tag(ciphertext, iv, aad, mac_key, key_len)
elif enc in ALGORITHMS.GCM:
encryption_key = jwk.construct(cek_bytes, enc)
iv, ciphertext, auth_tag = encryption_key.encrypt(plaintext, aad)
else:
raise NotImplementedError(f"enc {enc} is not implemented!")
return kw_cek, iv, ciphertext, auth_tag
def _get_hmac_key(enc, mac_key_bytes):
"""
Get an HMACKey for the provided encryption algorithm and key bytes
Args:
enc (str): Encryption algorithm
mac_key_bytes (bytes): vytes for the HMAC key
Returns:
(HMACKey): The key to perform HMAC actions
"""
_, hash_alg = enc.split("-")
mac_key = jwk.construct(mac_key_bytes, hash_alg)
return mac_key
def _compress(zip, plaintext):
"""
Compress the plaintext based on the algorithm supplied
Args:
zip (str): Compression Algorithm
plaintext (bytes): plaintext to compress
Returns:
(bytes): Compressed plaintext
"""
if zip not in ZIPS.SUPPORTED:
raise NotImplementedError("ZIP {} is not supported!")
if zip is None:
compressed = plaintext
elif zip == ZIPS.DEF:
compressed = zlib.compress(plaintext)
else:
raise NotImplementedError("ZIP {} is not implemented!")
return compressed
def _decompress(zip, compressed):
"""
Decompress the plaintext based on the algorithm supplied
Args:
zip (str): Compression Algorithm
plaintext (bytes): plaintext to decompress
Returns:
(bytes): Compressed plaintext
"""
if zip not in ZIPS.SUPPORTED:
raise NotImplementedError("ZIP {} is not supported!")
if zip is None:
decompressed = compressed
elif zip == ZIPS.DEF:
decompressed = zlib.decompress(compressed)
else:
raise NotImplementedError("ZIP {} is not implemented!")
return decompressed
def _get_cek(enc, alg, key):
"""
Get the content encryption key
Args:
enc (str): Encryption algorithm
alg (str): kwy wrap/negotiation algorithm
key (Key): Key provided to encryption method
Return:
(bytes, bytes): Tuple of (cek bytes and wrapped cek)
"""
if alg == ALGORITHMS.DIR:
cek, wrapped_cek = _get_direct_key_wrap_cek(key)
else:
cek, wrapped_cek = _get_key_wrap_cek(enc, key)
return cek, wrapped_cek
def _get_direct_key_wrap_cek(key):
"""
Get the cek and wrapped cek from the encryption key direct
Args:
key (Key): Key provided to encryption method
Return:
(Key, bytes): Tuple of (cek Key object and wrapped cek)
"""
# Get the JWK data to determine how to derive the cek
jwk_data = key.to_dict()
if jwk_data["kty"] == "oct":
# Get the last half of an octal key as the cek
cek_bytes = _get_key_bytes_from_key(key)
wrapped_cek = b""
else:
raise NotImplementedError("JWK type {} not supported!".format(jwk_data["kty"]))
return cek_bytes, wrapped_cek
def _get_key_bytes_from_key(key):
"""
Get the raw key bytes from a Key object
Args:
key (Key): Key from which to extract the raw key bytes
Returns:
(bytes) key data
"""
jwk_data = key.to_dict()
encoded_key = jwk_data["k"]
cek_bytes = base64url_decode(encoded_key)
return cek_bytes
def _get_key_wrap_cek(enc, key):
"""_get_rsa_key_wrap_cek
Get the content encryption key for RSA key wrap
Args:
enc (str): Encryption algorithm
key (Key): Key provided to encryption method
Returns:
(Key, bytes): Tuple of (cek Key object and wrapped cek)
"""
cek_bytes = _get_random_cek_bytes_for_enc(enc)
wrapped_cek = key.wrap_key(cek_bytes)
return cek_bytes, wrapped_cek
def _get_random_cek_bytes_for_enc(enc):
"""
Get the random cek bytes based on the encryptionn algorithm
Args:
enc (str): Encryption algorithm
Returns:
(bytes) random bytes for cek key
"""
if enc == ALGORITHMS.A128GCM:
num_bits = 128
elif enc == ALGORITHMS.A192GCM:
num_bits = 192
elif enc in (ALGORITHMS.A128CBC_HS256, ALGORITHMS.A256GCM):
num_bits = 256
elif enc == ALGORITHMS.A192CBC_HS384:
num_bits = 384
elif enc == ALGORITHMS.A256CBC_HS512:
num_bits = 512
else:
raise NotImplementedError(f"{enc} not supported")
cek_bytes = get_random_bytes(num_bits // 8)
return cek_bytes
def _auth_tag(ciphertext, iv, aad, mac_key, tag_length):
"""
Get ann auth tag from the provided data
Args:
ciphertext (bytes): Encrypted value
iv (bytes): Initialization vector
aad (bytes): Additional Authenticated Data
mac_key (bytes): Key to use in generating the MAC
tag_length (int): How log the tag should be
Returns:
(bytes) Auth tag
"""
al = _big_endian(len(aad) * 8)
auth_tag_input = aad + iv + ciphertext + al
signature = mac_key.sign(auth_tag_input)
auth_tag = signature[0:tag_length]
return auth_tag
def _jwe_compact_serialize(encoded_header, encrypted_cek, iv, cipher_text, auth_tag):
"""
Generate a compact serialized JWE
Args:
encoded_header (bytes): Base64 URL Encoded JWE header JSON
encrypted_cek (bytes): Encrypted content encryption key (cek)
iv (bytes): Initialization vector (IV)
cipher_text (bytes): Cipher text
auth_tag (bytes): JWE Auth Tag
Returns:
(str): JWE compact serialized string
"""
cipher_text = ensure_binary(cipher_text)
encoded_encrypted_cek = base64url_encode(encrypted_cek)
encoded_iv = base64url_encode(iv)
encoded_cipher_text = base64url_encode(cipher_text)
encoded_auth_tag = base64url_encode(auth_tag)
return (
encoded_header
+ b"."
+ encoded_encrypted_cek
+ b"."
+ encoded_iv
+ b"."
+ encoded_cipher_text
+ b"."
+ encoded_auth_tag
)

View File

@@ -0,0 +1,79 @@
from jose.backends.base import Key
from jose.constants import ALGORITHMS
from jose.exceptions import JWKError
try:
from jose.backends import RSAKey # noqa: F401
except ImportError:
pass
try:
from jose.backends import ECKey # noqa: F401
except ImportError:
pass
try:
from jose.backends import AESKey # noqa: F401
except ImportError:
pass
try:
from jose.backends import DIRKey # noqa: F401
except ImportError:
pass
try:
from jose.backends import HMACKey # noqa: F401
except ImportError:
pass
def get_key(algorithm):
if algorithm in ALGORITHMS.KEYS:
return ALGORITHMS.KEYS[algorithm]
elif algorithm in ALGORITHMS.HMAC: # noqa: F811
return HMACKey
elif algorithm in ALGORITHMS.RSA:
from jose.backends import RSAKey # noqa: F811
return RSAKey
elif algorithm in ALGORITHMS.EC:
from jose.backends import ECKey # noqa: F811
return ECKey
elif algorithm in ALGORITHMS.AES:
from jose.backends import AESKey # noqa: F811
return AESKey
elif algorithm == ALGORITHMS.DIR:
from jose.backends import DIRKey # noqa: F811
return DIRKey
return None
def register_key(algorithm, key_class):
if not issubclass(key_class, Key):
raise TypeError("Key class is not a subclass of jwk.Key")
ALGORITHMS.KEYS[algorithm] = key_class
ALGORITHMS.SUPPORTED.add(algorithm)
return True
def construct(key_data, algorithm=None):
"""
Construct a Key object for the given algorithm with the given
key_data.
"""
# Allow for pulling the algorithm off of the passed in jwk.
if not algorithm and isinstance(key_data, dict):
algorithm = key_data.get("alg", None)
if not algorithm:
raise JWKError("Unable to find an algorithm for key: %s" % key_data)
key_class = get_key(algorithm)
if not key_class:
raise JWKError("Unable to find an algorithm for key: %s" % key_data)
return key_class(key_data, algorithm)

View File

@@ -0,0 +1,266 @@
import binascii
import json
from collections.abc import Iterable, Mapping
from jose import jwk
from jose.backends.base import Key
from jose.constants import ALGORITHMS
from jose.exceptions import JWSError, JWSSignatureError
from jose.utils import base64url_decode, base64url_encode
def sign(payload, key, headers=None, algorithm=ALGORITHMS.HS256):
"""Signs a claims set and returns a JWS string.
Args:
payload (str or dict): A string to sign
key (str or dict): The key to use for signing the claim set. Can be
individual JWK or JWK set.
headers (dict, optional): A set of headers that will be added to
the default headers. Any headers that are added as additional
headers will override the default headers.
algorithm (str, optional): The algorithm to use for signing the
the claims. Defaults to HS256.
Returns:
str: The string representation of the header, claims, and signature.
Raises:
JWSError: If there is an error signing the token.
Examples:
>>> jws.sign({'a': 'b'}, 'secret', algorithm='HS256')
'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhIjoiYiJ9.jiMyrsmD8AoHWeQgmxZ5yq8z0lXS67_QGs52AzC8Ru8'
"""
if algorithm not in ALGORITHMS.SUPPORTED:
raise JWSError("Algorithm %s not supported." % algorithm)
encoded_header = _encode_header(algorithm, additional_headers=headers)
encoded_payload = _encode_payload(payload)
signed_output = _sign_header_and_claims(encoded_header, encoded_payload, algorithm, key)
return signed_output
def verify(token, key, algorithms, verify=True):
"""Verifies a JWS string's signature.
Args:
token (str): A signed JWS to be verified.
key (str or dict): A key to attempt to verify the payload with. Can be
individual JWK or JWK set.
algorithms (str or list): Valid algorithms that should be used to verify the JWS.
Returns:
str: The str representation of the payload, assuming the signature is valid.
Raises:
JWSError: If there is an exception verifying a token.
Examples:
>>> token = 'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhIjoiYiJ9.jiMyrsmD8AoHWeQgmxZ5yq8z0lXS67_QGs52AzC8Ru8'
>>> jws.verify(token, 'secret', algorithms='HS256')
"""
header, payload, signing_input, signature = _load(token)
if verify:
_verify_signature(signing_input, header, signature, key, algorithms)
return payload
def get_unverified_header(token):
"""Returns the decoded headers without verification of any kind.
Args:
token (str): A signed JWS to decode the headers from.
Returns:
dict: The dict representation of the token headers.
Raises:
JWSError: If there is an exception decoding the token.
"""
header, claims, signing_input, signature = _load(token)
return header
def get_unverified_headers(token):
"""Returns the decoded headers without verification of any kind.
This is simply a wrapper of get_unverified_header() for backwards
compatibility.
Args:
token (str): A signed JWS to decode the headers from.
Returns:
dict: The dict representation of the token headers.
Raises:
JWSError: If there is an exception decoding the token.
"""
return get_unverified_header(token)
def get_unverified_claims(token):
"""Returns the decoded claims without verification of any kind.
Args:
token (str): A signed JWS to decode the headers from.
Returns:
str: The str representation of the token claims.
Raises:
JWSError: If there is an exception decoding the token.
"""
header, claims, signing_input, signature = _load(token)
return claims
def _encode_header(algorithm, additional_headers=None):
header = {"typ": "JWT", "alg": algorithm}
if additional_headers:
header.update(additional_headers)
json_header = json.dumps(
header,
separators=(",", ":"),
sort_keys=True,
).encode("utf-8")
return base64url_encode(json_header)
def _encode_payload(payload):
if isinstance(payload, Mapping):
try:
payload = json.dumps(
payload,
separators=(",", ":"),
).encode("utf-8")
except ValueError:
pass
return base64url_encode(payload)
def _sign_header_and_claims(encoded_header, encoded_claims, algorithm, key):
signing_input = b".".join([encoded_header, encoded_claims])
try:
if not isinstance(key, Key):
key = jwk.construct(key, algorithm)
signature = key.sign(signing_input)
except Exception as e:
raise JWSError(e)
encoded_signature = base64url_encode(signature)
encoded_string = b".".join([encoded_header, encoded_claims, encoded_signature])
return encoded_string.decode("utf-8")
def _load(jwt):
if isinstance(jwt, str):
jwt = jwt.encode("utf-8")
try:
signing_input, crypto_segment = jwt.rsplit(b".", 1)
header_segment, claims_segment = signing_input.split(b".", 1)
header_data = base64url_decode(header_segment)
except ValueError:
raise JWSError("Not enough segments")
except (TypeError, binascii.Error):
raise JWSError("Invalid header padding")
try:
header = json.loads(header_data.decode("utf-8"))
except ValueError as e:
raise JWSError("Invalid header string: %s" % e)
if not isinstance(header, Mapping):
raise JWSError("Invalid header string: must be a json object")
try:
payload = base64url_decode(claims_segment)
except (TypeError, binascii.Error):
raise JWSError("Invalid payload padding")
try:
signature = base64url_decode(crypto_segment)
except (TypeError, binascii.Error):
raise JWSError("Invalid crypto padding")
return (header, payload, signing_input, signature)
def _sig_matches_keys(keys, signing_input, signature, alg):
for key in keys:
if not isinstance(key, Key):
key = jwk.construct(key, alg)
try:
if key.verify(signing_input, signature):
return True
except Exception:
pass
return False
def _get_keys(key):
if isinstance(key, Key):
return (key,)
try:
key = json.loads(key, parse_int=str, parse_float=str)
except Exception:
pass
if isinstance(key, Mapping):
if "keys" in key:
# JWK Set per RFC 7517
return key["keys"]
elif "kty" in key:
# Individual JWK per RFC 7517
return (key,)
else:
# Some other mapping. Firebase uses just dict of kid, cert pairs
values = key.values()
if values:
return values
return (key,)
# Iterable but not text or mapping => list- or tuple-like
elif isinstance(key, Iterable) and not (isinstance(key, str) or isinstance(key, bytes)):
return key
# Scalar value, wrap in tuple.
else:
return (key,)
def _verify_signature(signing_input, header, signature, key="", algorithms=None):
alg = header.get("alg")
if not alg:
raise JWSError("No algorithm was specified in the JWS header.")
if algorithms is not None and alg not in algorithms:
raise JWSError("The specified alg value is not allowed")
keys = _get_keys(key)
try:
if not _sig_matches_keys(keys, signing_input, signature, alg):
raise JWSSignatureError()
except JWSSignatureError:
raise JWSError("Signature verification failed.")
except JWSError:
raise JWSError("Invalid or unsupported algorithm: %s" % alg)

View File

@@ -0,0 +1,496 @@
import json
from calendar import timegm
from collections.abc import Mapping
from datetime import datetime, timedelta
from jose import jws
from .constants import ALGORITHMS
from .exceptions import ExpiredSignatureError, JWSError, JWTClaimsError, JWTError
from .utils import calculate_at_hash, timedelta_total_seconds
def encode(claims, key, algorithm=ALGORITHMS.HS256, headers=None, access_token=None):
"""Encodes a claims set and returns a JWT string.
JWTs are JWS signed objects with a few reserved claims.
Args:
claims (dict): A claims set to sign
key (str or dict): The key to use for signing the claim set. Can be
individual JWK or JWK set.
algorithm (str, optional): The algorithm to use for signing the
the claims. Defaults to HS256.
headers (dict, optional): A set of headers that will be added to
the default headers. Any headers that are added as additional
headers will override the default headers.
access_token (str, optional): If present, the 'at_hash' claim will
be calculated and added to the claims present in the 'claims'
parameter.
Returns:
str: The string representation of the header, claims, and signature.
Raises:
JWTError: If there is an error encoding the claims.
Examples:
>>> jwt.encode({'a': 'b'}, 'secret', algorithm='HS256')
'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhIjoiYiJ9.jiMyrsmD8AoHWeQgmxZ5yq8z0lXS67_QGs52AzC8Ru8'
"""
for time_claim in ["exp", "iat", "nbf"]:
# Convert datetime to a intDate value in known time-format claims
if isinstance(claims.get(time_claim), datetime):
claims[time_claim] = timegm(claims[time_claim].utctimetuple())
if access_token:
claims["at_hash"] = calculate_at_hash(access_token, ALGORITHMS.HASHES[algorithm])
return jws.sign(claims, key, headers=headers, algorithm=algorithm)
def decode(token, key, algorithms=None, options=None, audience=None, issuer=None, subject=None, access_token=None):
"""Verifies a JWT string's signature and validates reserved claims.
Args:
token (str): A signed JWS to be verified.
key (str or dict): A key to attempt to verify the payload with. Can be
individual JWK or JWK set.
algorithms (str or list): Valid algorithms that should be used to verify the JWS.
audience (str): The intended audience of the token. If the "aud" claim is
included in the claim set, then the audience must be included and must equal
the provided claim.
issuer (str or iterable): Acceptable value(s) for the issuer of the token.
If the "iss" claim is included in the claim set, then the issuer must be
given and the claim in the token must be among the acceptable values.
subject (str): The subject of the token. If the "sub" claim is
included in the claim set, then the subject must be included and must equal
the provided claim.
access_token (str): An access token string. If the "at_hash" claim is included in the
claim set, then the access_token must be included, and it must match
the "at_hash" claim.
options (dict): A dictionary of options for skipping validation steps.
defaults = {
'verify_signature': True,
'verify_aud': True,
'verify_iat': True,
'verify_exp': True,
'verify_nbf': True,
'verify_iss': True,
'verify_sub': True,
'verify_jti': True,
'verify_at_hash': True,
'require_aud': False,
'require_iat': False,
'require_exp': False,
'require_nbf': False,
'require_iss': False,
'require_sub': False,
'require_jti': False,
'require_at_hash': False,
'leeway': 0,
}
Returns:
dict: The dict representation of the claims set, assuming the signature is valid
and all requested data validation passes.
Raises:
JWTError: If the signature is invalid in any way.
ExpiredSignatureError: If the signature has expired.
JWTClaimsError: If any claim is invalid in any way.
Examples:
>>> payload = 'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhIjoiYiJ9.jiMyrsmD8AoHWeQgmxZ5yq8z0lXS67_QGs52AzC8Ru8'
>>> jwt.decode(payload, 'secret', algorithms='HS256')
"""
defaults = {
"verify_signature": True,
"verify_aud": True,
"verify_iat": True,
"verify_exp": True,
"verify_nbf": True,
"verify_iss": True,
"verify_sub": True,
"verify_jti": True,
"verify_at_hash": True,
"require_aud": False,
"require_iat": False,
"require_exp": False,
"require_nbf": False,
"require_iss": False,
"require_sub": False,
"require_jti": False,
"require_at_hash": False,
"leeway": 0,
}
if options:
defaults.update(options)
verify_signature = defaults.get("verify_signature", True)
try:
payload = jws.verify(token, key, algorithms, verify=verify_signature)
except JWSError as e:
raise JWTError(e)
# Needed for at_hash verification
algorithm = jws.get_unverified_header(token)["alg"]
try:
claims = json.loads(payload.decode("utf-8"))
except ValueError as e:
raise JWTError("Invalid payload string: %s" % e)
if not isinstance(claims, Mapping):
raise JWTError("Invalid payload string: must be a json object")
_validate_claims(
claims,
audience=audience,
issuer=issuer,
subject=subject,
algorithm=algorithm,
access_token=access_token,
options=defaults,
)
return claims
def get_unverified_header(token):
"""Returns the decoded headers without verification of any kind.
Args:
token (str): A signed JWT to decode the headers from.
Returns:
dict: The dict representation of the token headers.
Raises:
JWTError: If there is an exception decoding the token.
"""
try:
headers = jws.get_unverified_headers(token)
except Exception:
raise JWTError("Error decoding token headers.")
return headers
def get_unverified_headers(token):
"""Returns the decoded headers without verification of any kind.
This is simply a wrapper of get_unverified_header() for backwards
compatibility.
Args:
token (str): A signed JWT to decode the headers from.
Returns:
dict: The dict representation of the token headers.
Raises:
JWTError: If there is an exception decoding the token.
"""
return get_unverified_header(token)
def get_unverified_claims(token):
"""Returns the decoded claims without verification of any kind.
Args:
token (str): A signed JWT to decode the headers from.
Returns:
dict: The dict representation of the token claims.
Raises:
JWTError: If there is an exception decoding the token.
"""
try:
claims = jws.get_unverified_claims(token)
except Exception:
raise JWTError("Error decoding token claims.")
try:
claims = json.loads(claims.decode("utf-8"))
except ValueError as e:
raise JWTError("Invalid claims string: %s" % e)
if not isinstance(claims, Mapping):
raise JWTError("Invalid claims string: must be a json object")
return claims
def _validate_iat(claims):
"""Validates that the 'iat' claim is valid.
The "iat" (issued at) claim identifies the time at which the JWT was
issued. This claim can be used to determine the age of the JWT. Its
value MUST be a number containing a NumericDate value. Use of this
claim is OPTIONAL.
Args:
claims (dict): The claims dictionary to validate.
"""
if "iat" not in claims:
return
try:
int(claims["iat"])
except ValueError:
raise JWTClaimsError("Issued At claim (iat) must be an integer.")
def _validate_nbf(claims, leeway=0):
"""Validates that the 'nbf' claim is valid.
The "nbf" (not before) claim identifies the time before which the JWT
MUST NOT be accepted for processing. The processing of the "nbf"
claim requires that the current date/time MUST be after or equal to
the not-before date/time listed in the "nbf" claim. Implementers MAY
provide for some small leeway, usually no more than a few minutes, to
account for clock skew. Its value MUST be a number containing a
NumericDate value. Use of this claim is OPTIONAL.
Args:
claims (dict): The claims dictionary to validate.
leeway (int): The number of seconds of skew that is allowed.
"""
if "nbf" not in claims:
return
try:
nbf = int(claims["nbf"])
except ValueError:
raise JWTClaimsError("Not Before claim (nbf) must be an integer.")
now = timegm(datetime.utcnow().utctimetuple())
if nbf > (now + leeway):
raise JWTClaimsError("The token is not yet valid (nbf)")
def _validate_exp(claims, leeway=0):
"""Validates that the 'exp' claim is valid.
The "exp" (expiration time) claim identifies the expiration time on
or after which the JWT MUST NOT be accepted for processing. The
processing of the "exp" claim requires that the current date/time
MUST be before the expiration date/time listed in the "exp" claim.
Implementers MAY provide for some small leeway, usually no more than
a few minutes, to account for clock skew. Its value MUST be a number
containing a NumericDate value. Use of this claim is OPTIONAL.
Args:
claims (dict): The claims dictionary to validate.
leeway (int): The number of seconds of skew that is allowed.
"""
if "exp" not in claims:
return
try:
exp = int(claims["exp"])
except ValueError:
raise JWTClaimsError("Expiration Time claim (exp) must be an integer.")
now = timegm(datetime.utcnow().utctimetuple())
if exp < (now - leeway):
raise ExpiredSignatureError("Signature has expired.")
def _validate_aud(claims, audience=None):
"""Validates that the 'aud' claim is valid.
The "aud" (audience) claim identifies the recipients that the JWT is
intended for. Each principal intended to process the JWT MUST
identify itself with a value in the audience claim. If the principal
processing the claim does not identify itself with a value in the
"aud" claim when this claim is present, then the JWT MUST be
rejected. In the general case, the "aud" value is an array of case-
sensitive strings, each containing a StringOrURI value. In the
special case when the JWT has one audience, the "aud" value MAY be a
single case-sensitive string containing a StringOrURI value. The
interpretation of audience values is generally application specific.
Use of this claim is OPTIONAL.
Args:
claims (dict): The claims dictionary to validate.
audience (str): The audience that is verifying the token.
"""
if "aud" not in claims:
# if audience:
# raise JWTError('Audience claim expected, but not in claims')
return
audience_claims = claims["aud"]
if isinstance(audience_claims, str):
audience_claims = [audience_claims]
if not isinstance(audience_claims, list):
raise JWTClaimsError("Invalid claim format in token")
if any(not isinstance(c, str) for c in audience_claims):
raise JWTClaimsError("Invalid claim format in token")
if audience not in audience_claims:
raise JWTClaimsError("Invalid audience")
def _validate_iss(claims, issuer=None):
"""Validates that the 'iss' claim is valid.
The "iss" (issuer) claim identifies the principal that issued the
JWT. The processing of this claim is generally application specific.
The "iss" value is a case-sensitive string containing a StringOrURI
value. Use of this claim is OPTIONAL.
Args:
claims (dict): The claims dictionary to validate.
issuer (str or iterable): Acceptable value(s) for the issuer that
signed the token.
"""
if issuer is not None:
if isinstance(issuer, str):
issuer = (issuer,)
if claims.get("iss") not in issuer:
raise JWTClaimsError("Invalid issuer")
def _validate_sub(claims, subject=None):
"""Validates that the 'sub' claim is valid.
The "sub" (subject) claim identifies the principal that is the
subject of the JWT. The claims in a JWT are normally statements
about the subject. The subject value MUST either be scoped to be
locally unique in the context of the issuer or be globally unique.
The processing of this claim is generally application specific. The
"sub" value is a case-sensitive string containing a StringOrURI
value. Use of this claim is OPTIONAL.
Args:
claims (dict): The claims dictionary to validate.
subject (str): The subject of the token.
"""
if "sub" not in claims:
return
if not isinstance(claims["sub"], str):
raise JWTClaimsError("Subject must be a string.")
if subject is not None:
if claims.get("sub") != subject:
raise JWTClaimsError("Invalid subject")
def _validate_jti(claims):
"""Validates that the 'jti' claim is valid.
The "jti" (JWT ID) claim provides a unique identifier for the JWT.
The identifier value MUST be assigned in a manner that ensures that
there is a negligible probability that the same value will be
accidentally assigned to a different data object; if the application
uses multiple issuers, collisions MUST be prevented among values
produced by different issuers as well. The "jti" claim can be used
to prevent the JWT from being replayed. The "jti" value is a case-
sensitive string. Use of this claim is OPTIONAL.
Args:
claims (dict): The claims dictionary to validate.
"""
if "jti" not in claims:
return
if not isinstance(claims["jti"], str):
raise JWTClaimsError("JWT ID must be a string.")
def _validate_at_hash(claims, access_token, algorithm):
"""
Validates that the 'at_hash' is valid.
Its value is the base64url encoding of the left-most half of the hash
of the octets of the ASCII representation of the access_token value,
where the hash algorithm used is the hash algorithm used in the alg
Header Parameter of the ID Token's JOSE Header. For instance, if the
alg is RS256, hash the access_token value with SHA-256, then take the
left-most 128 bits and base64url encode them. The at_hash value is a
case sensitive string. Use of this claim is OPTIONAL.
Args:
claims (dict): The claims dictionary to validate.
access_token (str): The access token returned by the OpenID Provider.
algorithm (str): The algorithm used to sign the JWT, as specified by
the token headers.
"""
if "at_hash" not in claims:
return
if not access_token:
msg = "No access_token provided to compare against at_hash claim."
raise JWTClaimsError(msg)
try:
expected_hash = calculate_at_hash(access_token, ALGORITHMS.HASHES[algorithm])
except (TypeError, ValueError):
msg = "Unable to calculate at_hash to verify against token claims."
raise JWTClaimsError(msg)
if claims["at_hash"] != expected_hash:
raise JWTClaimsError("at_hash claim does not match access_token.")
def _validate_claims(claims, audience=None, issuer=None, subject=None, algorithm=None, access_token=None, options=None):
leeway = options.get("leeway", 0)
if isinstance(leeway, timedelta):
leeway = timedelta_total_seconds(leeway)
required_claims = [e[len("require_") :] for e in options.keys() if e.startswith("require_") and options[e]]
for require_claim in required_claims:
if require_claim not in claims:
raise JWTError('missing required key "%s" among claims' % require_claim)
else:
options["verify_" + require_claim] = True # override verify when required
if not isinstance(audience, ((str,), type(None))):
raise JWTError("audience must be a string or None")
if options.get("verify_iat"):
_validate_iat(claims)
if options.get("verify_nbf"):
_validate_nbf(claims, leeway=leeway)
if options.get("verify_exp"):
_validate_exp(claims, leeway=leeway)
if options.get("verify_aud"):
_validate_aud(claims, audience=audience)
if options.get("verify_iss"):
_validate_iss(claims, issuer=issuer)
if options.get("verify_sub"):
_validate_sub(claims, subject=subject)
if options.get("verify_jti"):
_validate_jti(claims)
if options.get("verify_at_hash"):
_validate_at_hash(claims, access_token, algorithm)

View File

@@ -0,0 +1,108 @@
import base64
import struct
# Piggyback of the backends implementation of the function that converts a long
# to a bytes stream. Some plumbing is necessary to have the signatures match.
try:
from cryptography.utils import int_to_bytes as _long_to_bytes
def long_to_bytes(n, blocksize=0):
return _long_to_bytes(n, blocksize or None)
except ImportError:
from ecdsa.ecdsa import int_to_string as _long_to_bytes
def long_to_bytes(n, blocksize=0):
ret = _long_to_bytes(n)
if blocksize == 0:
return ret
else:
assert len(ret) <= blocksize
padding = blocksize - len(ret)
return b"\x00" * padding + ret
def long_to_base64(data, size=0):
return base64.urlsafe_b64encode(long_to_bytes(data, size)).strip(b"=")
def int_arr_to_long(arr):
return int("".join(["%02x" % byte for byte in arr]), 16)
def base64_to_long(data):
if isinstance(data, str):
data = data.encode("ascii")
# urlsafe_b64decode will happily convert b64encoded data
_d = base64.urlsafe_b64decode(bytes(data) + b"==")
return int_arr_to_long(struct.unpack("%sB" % len(_d), _d))
def calculate_at_hash(access_token, hash_alg):
"""Helper method for calculating an access token
hash, as described in http://openid.net/specs/openid-connect-core-1_0.html#CodeIDToken
Its value is the base64url encoding of the left-most half of the hash of the octets
of the ASCII representation of the access_token value, where the hash algorithm
used is the hash algorithm used in the alg Header Parameter of the ID Token's JOSE
Header. For instance, if the alg is RS256, hash the access_token value with SHA-256,
then take the left-most 128 bits and base64url encode them. The at_hash value is a
case sensitive string.
Args:
access_token (str): An access token string.
hash_alg (callable): A callable returning a hash object, e.g. hashlib.sha256
"""
hash_digest = hash_alg(access_token.encode("utf-8")).digest()
cut_at = int(len(hash_digest) / 2)
truncated = hash_digest[:cut_at]
at_hash = base64url_encode(truncated)
return at_hash.decode("utf-8")
def base64url_decode(input):
"""Helper method to base64url_decode a string.
Args:
input (str): A base64url_encoded string to decode.
"""
rem = len(input) % 4
if rem > 0:
input += b"=" * (4 - rem)
return base64.urlsafe_b64decode(input)
def base64url_encode(input):
"""Helper method to base64url_encode a string.
Args:
input (str): A base64url_encoded string to encode.
"""
return base64.urlsafe_b64encode(input).replace(b"=", b"")
def timedelta_total_seconds(delta):
"""Helper method to determine the total number of seconds
from a timedelta.
Args:
delta (timedelta): A timedelta to convert to seconds.
"""
return delta.days * 24 * 60 * 60 + delta.seconds
def ensure_binary(s):
"""Coerce **s** to bytes."""
if isinstance(s, bytes):
return s
if isinstance(s, str):
return s.encode("utf-8", "strict")
raise TypeError(f"not expecting type '{type(s)}'")