from kyber_py.kyber import Kyber512,Kyber768,Kyber1024
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives import padding
import binascii
import os
import sys
def go_encrypt(msg, method, mode):
    """Encrypt ``msg`` with the given cipher algorithm and mode.

    Args:
        msg: plaintext bytes. For block modes such as CBC the caller must
            already have padded it to a whole number of blocks.
        method: a ``cryptography`` algorithm object, e.g. ``algorithms.AES(key)``.
        mode: a ``cryptography`` mode object, e.g. ``modes.CBC(iv)``.

    Returns:
        The ciphertext bytes (output of ``update`` followed by ``finalize``).
    """
    encryptor = Cipher(method, mode).encryptor()
    body = encryptor.update(msg)
    tail = encryptor.finalize()
    return body + tail
def go_decrypt(ct, method, mode):
    """Decrypt ciphertext ``ct`` with the given cipher algorithm and mode.

    This is the counterpart of ``go_encrypt``. Any padding is left in place;
    the caller strips it afterwards if needed.

    Args:
        ct: ciphertext bytes.
        method: a ``cryptography`` algorithm object, e.g. ``algorithms.AES(key)``.
        mode: a ``cryptography`` mode object; it must match the one used to
            encrypt (e.g. the same CBC IV).

    Returns:
        The decrypted bytes.
    """
    decryptor = Cipher(method, mode).decryptor()
    recovered = decryptor.update(ct)
    recovered += decryptor.finalize()
    return recovered
def pad(data, size=128):
    """Return ``data`` with PKCS#7 padding applied.

    Args:
        data: bytes to pad.
        size: block size in bits, passed to ``padding.PKCS7``. The default
            of 128 matches the AES block size used elsewhere in this script.

    Returns:
        The padded bytes, whose length is a multiple of ``size // 8``.
    """
    padder = padding.PKCS7(size).padder()
    return b"".join((padder.update(data), padder.finalize()))
def unpad(data, size=128):
    """Strip PKCS#7 padding from ``data``; the inverse of ``pad``.

    Args:
        data: padded bytes.
        size: block size in bits, passed to ``padding.PKCS7``. It must match
            the value used when padding.

    Returns:
        The original, unpadded bytes.
    """
    unpadder = padding.PKCS7(size).unpadder()
    return b"".join((unpadder.update(data), unpadder.finalize()))
# Demo: derive a shared secret with a Kyber KEM, use it as an AES key in CBC
# mode to encrypt a message, then recover the secret by decapsulation and
# decrypt the message with it.
#
# Usage: python <script> [message] [kemtype]
#   message  plaintext to encrypt (default "Hello")
#   kemtype  0 = Kyber512 (default), 1 = Kyber768, 2 = Kyber1024

# A single table drives keygen, encaps, decaps and the printed name, so the
# variant used on the sending side can never diverge from the receiving side.
KEM_VARIANTS = {
    0: ("Kyber512", Kyber512),
    1: ("Kyber768", Kyber768),
    2: ("Kyber1024", Kyber1024),
}

msg = "Hello"
kemtype = 0
if len(sys.argv) > 1:
    msg = sys.argv[1]
if len(sys.argv) > 2:
    kemtype = int(sys.argv[2])

# Reject unknown variants explicitly. Previously they silently used Kyber512
# keys and skipped decapsulation, so the "recovered" key was never recovered.
if kemtype not in KEM_VARIANTS:
    sys.exit(f"Unknown kemtype {kemtype}: expected one of {sorted(KEM_VARIANTS)}")
kem_name, kem = KEM_VARIANTS[kemtype]

iv = os.urandom(16)  # 16 bytes = one AES block, as required by CBC
# Generate keys only for the selected variant (the old code always ran
# Kyber512 first and then threw the result away).
pk, sk = kem.keygen()
key, c = kem.encaps(pk)  # key: shared secret, c: encapsulation sent to holder of sk

padded_data = pad(msg.encode())
cipher = go_encrypt(padded_data, algorithms.AES(key), modes.CBC(iv))
print("Message: ", msg)
# Labels fixed: previously the key was printed as "IV" and vice versa.
print(f"Kyber key created: {key.hex()}")
print(f"IV: {iv.hex()}")
print("Cipher (encrypted with Kyber key): ", binascii.b2a_hex(cipher))

# Receiver side: recover the shared secret from the encapsulation.
key = kem.decaps(sk, c)
print(f"\n{kem_name}")
print(f"Kyber key recovered: {key.hex()}")
plain = go_decrypt(cipher, algorithms.AES(key), modes.CBC(iv))
data = unpad(plain)
print(f"Decrypted: {data.decode()}")