1
0
Fork 0
bthlabs-jsonrpc/packages/bthlabs-jsonrpc-core/bthlabs_jsonrpc_core/ext/jwt.py

153 lines
4.0 KiB
Python

# -*- coding: utf-8 -*-
# bthlabs-jsonrpc-core | (c) 2022-present Tomek Wójcik | MIT License
from __future__ import annotations
import datetime
import dataclasses
import typing
from jose import jwt
from jose.constants import ALGORITHMS
from jose.jwk import ( # noqa: F401
ECKey,
HMACKey,
RSAKey,
)
import pytz
from bthlabs_jsonrpc_core.codecs import Codec
@dataclasses.dataclass()
class KeyPair:
    """
    Keys used to sign and verify JWTs.

    *decode_key* is used when verifying the signature of an incoming token;
    *encode_key* is used when signing an outgoing token.

    For HMAC algorithms, both fields should be `HMACKey` instances wrapping
    the respective secrets.

    For RSA and ECDSA algorithms, *decode_key* must be a public key (for
    signature verification) and *encode_key* must be a private key (for
    signing).
    """

    # Verification key.
    decode_key: ECKey | HMACKey | RSAKey

    # Signing key.
    encode_key: ECKey | HMACKey | RSAKey
@dataclasses.dataclass
class TimeClaims:
    """Container for the time-related claims of a JWT."""

    # Value for the ``iat`` claim; always included by ``as_claims``.
    iat: datetime.datetime
    # Value for the ``nbf`` claim; left out of ``as_claims`` when ``None``.
    nbf: datetime.datetime | None
    # Value for the ``exp`` claim; left out of ``as_claims`` when ``None``.
    exp: datetime.datetime | None

    def as_claims(self) -> dict:
        """
        Build a dict of these claims, suitable for including in a JWT.

        ``iat`` is always present. ``nbf`` and ``exp`` are added, in that
        order, only when they are not ``None``.
        """
        claims: dict = {'iat': self.iat}
        optional_claims = (('nbf', self.nbf), ('exp', self.exp))
        claims.update(
            (name, value)
            for name, value in optional_claims
            if value is not None
        )
        return claims
class JWTCodec(Codec):
    """
    JWT codec. Uses keys specified in *key_pair* when decoding and encoding
    tokens.

    *algorithm* specifies the signature algorithm to use. Defaults to
    :py:attr:`ALGORITHMS.HS256`.

    *issuer* specifies the ``iss`` claim. Defaults to ``None`` for no issuer.

    *ttl* specifies the token's TTL. It'll be used to generate the ``exp``
    claim. Defaults to ``None`` for non-expiring token.

    *include_nbf* specifies if the ``nbf`` claim should be added to the token.
    Defaults to ``True``.
    """

    def __init__(self,
                 key_pair: KeyPair,
                 *,
                 algorithm: str = ALGORITHMS.HS256,
                 issuer: str | None = None,
                 ttl: datetime.timedelta | None = None,
                 include_nbf: bool = True):
        super().__init__()
        self.key_pair = key_pair
        self.algorithm = algorithm
        self.issuer = issuer
        self.ttl = ttl
        self.include_nbf = include_nbf

    # pragma mark - Private interface

    def get_time_claims(self) -> TimeClaims:
        """
        Get time claims for a token issued right now.

        ``iat`` is the current UTC time. ``nbf`` equals ``iat`` when
        *include_nbf* is ``True``, otherwise it is ``None``. ``exp`` is
        ``iat + ttl`` when *ttl* is set, otherwise it is ``None``.

        :meta private:
        """
        # Request an aware UTC timestamp directly rather than taking naive
        # local time and converting it afterwards.
        now = datetime.datetime.now(tz=pytz.utc)

        exp: datetime.datetime | None = None
        if self.ttl is not None:
            exp = now + self.ttl

        return TimeClaims(
            iat=now,
            nbf=now if self.include_nbf is True else None,
            exp=exp,
        )

    # pragma mark - Public interface

    def decode(self, payload: str | bytes, **decoder_kwargs) -> typing.Any:
        """
        Decode payload using :py:func:`jose.jwt.decode`. *decoder_kwargs* will
        be passed verbatim to the decode function.

        The signature is verified with ``key_pair.decode_key``, and only the
        configured *algorithm* is accepted. Returns the value of the
        ``jsonrpc`` claim; raises ``KeyError`` if a token without that claim
        passes verification.

        Consult *python-jose* documentation for more information.
        """
        # Restrict verification to the single configured algorithm.
        # NOTE(review): the ``iss`` claim is not checked here unless the
        # caller passes ``issuer=...`` in *decoder_kwargs*.
        decoded_payload = jwt.decode(
            payload,
            self.key_pair.decode_key,
            algorithms=[self.algorithm],
            **decoder_kwargs,
        )

        return decoded_payload['jsonrpc']

    def encode(self, payload: typing.Any, **encoder_kwargs) -> str:
        """
        Encode payload using :py:func:`jose.jwt.encode`. *encoder_kwargs* will
        be passed verbatim to the encode function.

        *payload* is stored in the ``jsonrpc`` claim, next to the time claims
        from :py:meth:`get_time_claims` and, when *issuer* is set, the ``iss``
        claim. The token is signed with ``key_pair.encode_key`` using
        *algorithm*.

        Consult *python-jose* documentation for more information.
        """
        claims: dict = {
            **self.get_time_claims().as_claims(),
            'jsonrpc': payload,
        }

        if self.issuer is not None:
            claims['iss'] = self.issuer

        return jwt.encode(
            claims,
            self.key_pair.encode_key,
            algorithm=self.algorithm,
            **encoder_kwargs,
        )

    def get_content_type(self) -> str:
        """Returns ``application/jwt``."""
        return 'application/jwt'