require "openssl"
module Sandal
  module Enc
    module Alg

      # Base class for the RSA key encryption algorithms.
      #
      # Holds an RSA key together with a padding mode and uses them to wrap
      # (encrypt) and unwrap (decrypt) a content key.
      class RSA

        # The JWA name of the algorithm.
        attr_reader :name

        # Initialises a new instance.
        #
        # @param name [String] The JWA name of the algorithm.
        # @param rsa_key [OpenSSL::PKey::RSA or String] The RSA key to use for key encryption (public) or decryption
        #   (private). If the value is a String then it will be passed to the constructor of the RSA class. This must
        #   be at least 2048 bits to be compliant with the JWA specification.
        # @param padding [Integer] The OpenSSL padding mode handed to the key's encrypt/decrypt calls, e.g.
        #   OpenSSL::PKey::RSA::PKCS1_PADDING or OpenSSL::PKey::RSA::PKCS1_OAEP_PADDING.
        def initialize(name, rsa_key, padding)
          @name = name
          # A String is treated as serialised key material and parsed into a key object;
          # anything else is used as the key as-is.
          @rsa_key = rsa_key.is_a?(String) ? OpenSSL::PKey::RSA.new(rsa_key) : rsa_key
          @padding = padding
        end

        # Encrypts the content key.
        #
        # @param key [String] The content key.
        # @return [String] The encrypted content key.
        def encrypt_key(key)
          @rsa_key.public_encrypt(key, @padding)
        end

        # Decrypts the content key.
        #
        # @param encrypted_key [String] The encrypted content key.
        # @return [String] The pre-shared content key.
        # @raise [Sandal::InvalidTokenError] The content key can't be decrypted.
        def decrypt_key(encrypted_key)
          @rsa_key.private_decrypt(encrypted_key, @padding)
        rescue => e
          # Any failure while decrypting is reported as an invalid token, keeping the underlying message.
          raise Sandal::InvalidTokenError, "Cannot decrypt content key: #{e.message}"
        end

      end

      # The RSA1_5 key encryption algorithm.
      class RSA1_5 < RSA

        # The JWA name of the algorithm.
        NAME = "RSA1_5"

        # Initialises a new instance.
        #
        # @param rsa_key [OpenSSL::PKey::RSA or String] The RSA key to use for key encryption (public) or decryption
        #   (private). If the value is a String then it will be passed to the constructor of the RSA class. This must
        #   be at least 2048 bits to be compliant with the JWA specification.
        def initialize(rsa_key)
          super(NAME, rsa_key, OpenSSL::PKey::RSA::PKCS1_PADDING)
        end

      end

      # The RSA-OAEP key encryption algorithm.
      class RSA_OAEP < RSA

        # The JWA name of the algorithm.
        NAME = "RSA-OAEP"

        # Initialises a new instance.
        #
        # @param rsa_key [OpenSSL::PKey::RSA or String] The RSA key to use for key encryption (public) or decryption
        #   (private). If the value is a String then it will be passed to the constructor of the RSA class. This must
        #   be at least 2048 bits to be compliant with the JWA specification.
        def initialize(rsa_key)
          super(NAME, rsa_key, OpenSSL::PKey::RSA::PKCS1_OAEP_PADDING)
        end

      end

    end
  end
end