256 lines
6.6 KiB
Python
256 lines
6.6 KiB
Python
# -*- coding: utf-8 -*-
|
|
# type: ignore
|
|
from __future__ import annotations
|
|
|
|
import datetime
|
|
import typing
|
|
from unittest import mock
|
|
|
|
import freezegun
|
|
import jose
|
|
import pytest
|
|
|
|
from bthlabs_jsonrpc_core.ext import jwt
|
|
|
|
|
|
@pytest.fixture
def key_pair() -> jwt.KeyPair:
    """Provide a symmetric HS256 key pair for the codec tests.

    Both halves are separate ``HMACKey`` objects built from the same
    (deliberately weak) secret, so tokens signed with ``encode_key`` verify
    with ``decode_key``.
    """
    secret = 'thisisntsecure'
    hs256 = jose.constants.ALGORITHMS.HS256

    return jwt.KeyPair(
        encode_key=jose.jwk.HMACKey(secret, algorithm=hs256),
        decode_key=jose.jwk.HMACKey(secret, algorithm=hs256),
    )
|
|
|
|
|
|
@pytest.fixture
def response() -> dict:
    """Provide a minimal successful JSON-RPC 2.0 response used as encode input."""
    return dict(
        jsonrpc='2.0',
        id='test',
        result=['system.list_methods'],
    )
|
|
|
|
|
|
@pytest.fixture
def encoded_payload(single_call: dict,
                    iat: datetime.datetime,
                    key_pair: jwt.KeyPair) -> str:
    """Sign ``single_call`` into a JWT for the decode tests.

    The request goes under the ``jsonrpc`` claim, alongside a fixed issuer and
    the ``iat`` fixture. ``single_call`` and ``iat`` are fixtures defined
    outside this module (presumably in ``conftest.py`` — confirm).
    """
    return jose.jwt.encode(
        {
            'iss': 'bthlabs_jsonrpc_core_tests',
            'jsonrpc': single_call,
            'iat': iat,
        },
        key_pair.encode_key,
    )
|
|
|
|
|
|
@pytest.fixture
def time_claims(iat: datetime.datetime) -> jwt.TimeClaims:
    """Provide time claims that set only ``iat``; ``nbf`` and ``exp`` are ``None``."""
    claims = jwt.TimeClaims(iat=iat, nbf=None, exp=None)
    return claims
|
|
|
|
|
|
def test_init(key_pair: jwt.KeyPair):
    """A codec built from only a key pair uses HS256, no issuer, no TTL, nbf on."""
    # Given
    expected_algorithm = jose.constants.ALGORITHMS.HS256

    # When
    codec = jwt.JWTCodec(key_pair)

    # Then
    assert codec.key_pair == key_pair
    assert codec.algorithm == expected_algorithm
    assert codec.issuer is None
    assert codec.ttl is None
    assert codec.include_nbf is True
|
|
|
|
|
|
@pytest.mark.parametrize(
    ('kwarg', 'value'),
    [
        ('algorithm', jose.constants.ALGORITHMS.RS256),
        ('issuer', 'bthlabs_jsonrpc_core_tests'),
        ('ttl', datetime.timedelta(seconds=10)),
        ('include_nbf', False),
    ],
)
def test_init_with_kwargs(kwarg: str,
                          value: typing.Any,
                          key_pair: jwt.KeyPair):
    """Each optional constructor kwarg is exposed unchanged as an attribute."""
    # When
    codec = jwt.JWTCodec(key_pair, **{kwarg: value})

    # Then
    assert getattr(codec, kwarg) == value
|
|
|
|
|
|
def test_decode(key_pair: jwt.KeyPair,
                encoded_payload: str,
                single_call: dict):
    """Decoding the signed fixture token returns the embedded ``jsonrpc`` claim."""
    # Given
    codec = jwt.JWTCodec(key_pair)

    # When
    decoded = codec.decode(encoded_payload)

    # Then
    assert decoded == single_call
|
|
|
|
|
|
def test_decode_jwt_decode_single_call(key_pair: jwt.KeyPair,
                                       encoded_payload: str):
    """``decode`` calls the module's ``jwt.decode`` once, pinning the algorithm."""
    # Given
    codec = jwt.JWTCodec(key_pair)
    expected_algorithms = [codec.algorithm]

    with mock.patch.object(jwt.jwt, 'decode') as mock_jwt_decode:
        # When
        codec.decode(encoded_payload)

        # Then
        mock_jwt_decode.assert_called_once_with(
            encoded_payload,
            key_pair.decode_key,
            algorithms=expected_algorithms,
        )
|
|
|
|
|
|
def test_decode_with_decoder_kwargs(key_pair: jwt.KeyPair,
                                    encoded_payload: str):
    """Extra keyword arguments to ``decode`` are forwarded to ``jwt.decode``."""
    # Given
    codec = jwt.JWTCodec(key_pair)
    issuer = 'bthlabs_jsonrpc_core_tests'

    with mock.patch.object(jwt.jwt, 'decode') as mock_jwt_decode:
        # When
        codec.decode(encoded_payload, issuer=issuer)

        # Then
        mock_jwt_decode.assert_called_once_with(
            encoded_payload,
            key_pair.decode_key,
            algorithms=[codec.algorithm],
            issuer=issuer,
        )
|
|
|
|
|
|
@freezegun.freeze_time('2024-01-11 07:09:43')
def test_encode(key_pair: jwt.KeyPair, response: dict):
    """With default settings the token carries ``iat``, ``nbf`` and the payload."""
    # Given
    codec = jwt.JWTCodec(key_pair)
    # NOTE(review): 1704953383 is 2024-01-11 06:09:43 UTC, one hour *before*
    # the frozen wall-clock time above. The expectation therefore seems to
    # depend on how the codec derives "now" (e.g. via local time) and on the
    # host running at UTC+1. Confirm, or pin the TZ used by the test run.
    frozen_timestamp = 1704953383

    # When
    token = codec.encode(response)

    # Then
    assert isinstance(token, str) is True

    assert jose.jwt.decode(token, key_pair.decode_key) == {
        'iat': frozen_timestamp,
        'nbf': frozen_timestamp,
        'jsonrpc': response,
    }
|
|
|
|
|
|
def test_encode_with_issuer(key_pair: jwt.KeyPair, response: dict):
    """A configured issuer ends up in the token's ``iss`` claim."""
    # Given
    issuer = 'bthlabs_jsonrpc_core_tests'
    codec = jwt.JWTCodec(key_pair, issuer=issuer)

    # When
    token = codec.encode(response)

    # Then
    assert isinstance(token, str) is True

    claims = jose.jwt.decode(token, key_pair.decode_key)
    assert claims['iss'] == issuer
|
|
|
|
|
|
def test_encode_with_ttl(key_pair: jwt.KeyPair, response: dict):
    """A configured TTL produces an ``exp`` claim exactly TTL seconds after ``iat``."""
    # Given
    ttl_seconds = 60
    codec = jwt.JWTCodec(key_pair, ttl=datetime.timedelta(seconds=ttl_seconds))

    # When
    token = codec.encode(response)

    # Then
    assert isinstance(token, str) is True

    claims = jose.jwt.decode(token, key_pair.decode_key)
    assert 'exp' in claims
    assert claims['exp'] - claims['iat'] == ttl_seconds
|
|
|
|
|
|
def test_encode_without_nbf(key_pair: jwt.KeyPair, response: dict):
    """With ``include_nbf=False`` the token has no ``nbf`` claim."""
    # Given
    codec = jwt.JWTCodec(key_pair, include_nbf=False)

    # When
    token = codec.encode(response)

    # Then
    assert isinstance(token, str) is True

    claims = jose.jwt.decode(token, key_pair.decode_key)
    assert 'nbf' not in claims
|
|
|
|
|
|
def test_encode_jwt_encode_single_call(key_pair: jwt.KeyPair,
                                       time_claims: jwt.TimeClaims,
                                       response: dict):
    """``encode`` calls the module's ``jwt.encode`` once with the built claims.

    ``time_claims`` leaves ``nbf`` and ``exp`` as ``None``, so the expected
    claims hold only ``iat`` and the ``jsonrpc`` payload.
    """
    # Given
    codec = jwt.JWTCodec(key_pair)
    expected_claims = {
        'iat': time_claims.iat,
        'jsonrpc': response,
    }

    with mock.patch.object(codec, 'get_time_claims', return_value=time_claims), \
            mock.patch.object(jwt.jwt, 'encode') as mock_jwt_encode:
        # When
        codec.encode(response)

        # Then
        mock_jwt_encode.assert_called_once_with(
            expected_claims,
            key_pair.encode_key,
            algorithm=codec.algorithm,
        )
|
|
|
|
|
|
def test_encode_with_encoder_kwargs(key_pair: jwt.KeyPair,
                                    time_claims: jwt.TimeClaims,
                                    response: dict):
    """Extra keyword arguments to ``encode`` are forwarded to ``jwt.encode``."""
    # Given
    codec = jwt.JWTCodec(key_pair)
    headers = {'cty': 'JWT'}
    expected_claims = {
        'iat': time_claims.iat,
        'jsonrpc': response,
    }

    with mock.patch.object(codec, 'get_time_claims', return_value=time_claims), \
            mock.patch.object(jwt.jwt, 'encode') as mock_jwt_encode:
        # When
        codec.encode(response, headers=headers)

        # Then
        mock_jwt_encode.assert_called_once_with(
            expected_claims,
            key_pair.encode_key,
            algorithm=codec.algorithm,
            headers={'cty': 'JWT'},
        )
|
|
|
|
|
|
def test_get_content_type(key_pair: jwt.KeyPair):
    """The codec reports ``application/jwt`` as its content type.

    ``key_pair`` must be requested as a fixture argument. Without it, the name
    resolves to the module-level fixture *function* rather than a ``KeyPair``
    instance.
    """
    # Given
    codec = jwt.JWTCodec(key_pair)

    # When
    result = codec.get_content_type()

    # Then
    assert result == 'application/jwt'
|