@@ -533,7 +533,9 @@ int wc_KyberKey_EncapsulateWithRandom(KyberKey* key, unsigned char* ct,
533533 byte msg [2 * KYBER_SYM_SZ ];
534534 byte kr [2 * KYBER_SYM_SZ + 1 ];
535535 int ret = 0 ;
536+ #ifndef WOLFSSL_ML_KEM
536537 unsigned int ctSz = 0 ;
538+ #endif
537539
538540 /* Validate parameters. */
539541 if ((key == NULL ) || (ct == NULL ) || (ss == NULL ) || (rand == NULL )) {
@@ -543,6 +545,7 @@ int wc_KyberKey_EncapsulateWithRandom(KyberKey* key, unsigned char* ct,
543545 ret = BUFFER_E ;
544546 }
545547
548+ #ifndef WOLFSSL_ML_KEM
546549 if (ret == 0 ) {
547550 /* Establish parameters based on key type. */
548551 switch (key -> type ) {
@@ -567,6 +570,7 @@ int wc_KyberKey_EncapsulateWithRandom(KyberKey* key, unsigned char* ct,
567570 break ;
568571 }
569572 }
573+ #endif
570574
571575 /* If public hash (h) is not stored against key, calculate it. */
572576 if ((ret == 0 ) && ((key -> flags & KYBER_FLAG_H_SET ) == 0 )) {
@@ -596,8 +600,12 @@ int wc_KyberKey_EncapsulateWithRandom(KyberKey* key, unsigned char* ct,
596600 }
597601
598602 if (ret == 0 ) {
603+ #ifndef WOLFSSL_ML_KEM
599604 /* Hash random to anonymize as seed data. */
600605 ret = KYBER_HASH_H (rand , KYBER_SYM_SZ , msg );
606+ #else
607+ XMEMCPY (msg , rand , KYBER_SYM_SZ );
608+ #endif
601609 }
602610 if (ret == 0 ) {
603611 /* Copy the hash of the public key into msg. */
@@ -612,6 +620,7 @@ int wc_KyberKey_EncapsulateWithRandom(KyberKey* key, unsigned char* ct,
612620 ret = kyberkey_encapsulate (key , msg , kr + KYBER_SYM_SZ , ct );
613621 }
614622
623+ #ifndef WOLFSSL_ML_KEM
615624 if (ret == 0 ) {
616625 /* Hash the cipher text after the seed. */
617626 ret = KYBER_HASH_H (ct , ctSz , kr + KYBER_SYM_SZ );
@@ -620,6 +629,11 @@ int wc_KyberKey_EncapsulateWithRandom(KyberKey* key, unsigned char* ct,
620629 /* Derive the secret from the seed and hash of cipher text. */
621630 ret = KYBER_KDF (kr , 2 * KYBER_SYM_SZ , ss , KYBER_SS_SZ );
622631 }
632+ #else
633+ if (ret == 0 ) {
634+ XMEMCPY (ss , kr , KYBER_SS_SZ );
635+ }
636+ #endif
623637
624638 return ret ;
625639}
@@ -725,6 +739,39 @@ static KYBER_NOINLINE int kyberkey_decapsulate(KyberKey* key,
725739 return ret ;
726740}
727741
742+ #ifdef WOLFSSL_ML_KEM
743+ /* Derive the secret from z and cipher text.
744+ *
745+ * @param [in] z Implicit rejection value.
746+ * @param [in] ct Cipher text.
747+ * @param [in] ctSz Length of cipher text in bytes.
748+ * @param [out] ss Shared secret.
749+ * @return 0 on success.
750+ * @return MEMORY_E when dynamic memory allocation failed.
751+ * @return Other negative when a hash error occurred.
752+ */
753+ static int kyber_derive_secret (const byte * z , const byte * ct , word32 ctSz ,
754+ byte * ss )
755+ {
756+ int ret ;
757+ wc_Shake shake ;
758+
759+ ret = wc_InitShake256 (& shake , NULL , INVALID_DEVID );
760+ if (ret == 0 ) {
761+ ret = wc_Shake256_Update (& shake , z , KYBER_SYM_SZ );
762+ if (ret == 0 ) {
763+ ret = wc_Shake256_Update (& shake , ct , ctSz );
764+ }
765+ if (ret == 0 ) {
766+ ret = wc_Shake256_Final (& shake , ss , KYBER_SS_SZ );
767+ }
768+ wc_Shake256_Free (& shake );
769+ }
770+
771+ return ret ;
772+ }
773+ #endif
774+
728775/**
729776 * Decapsulate the cipher text to calculate the shared secret.
730777 *
@@ -818,6 +865,7 @@ int wc_KyberKey_Decapsulate(KyberKey* key, unsigned char* ss,
818865 /* Compare generated cipher text with that passed in. */
819866 fail = kyber_cmp (ct , cmp , ctSz );
820867
868+ #ifndef WOLFSSL_ML_KEM
821869 /* Hash the cipher text after the seed. */
822870 ret = KYBER_HASH_H (ct , ctSz , kr + KYBER_SYM_SZ );
823871 }
@@ -829,6 +877,15 @@ int wc_KyberKey_Decapsulate(KyberKey* key, unsigned char* ss,
829877
830878 /* Derive the secret from the seed and hash of cipher text. */
831879 ret = KYBER_KDF (kr , 2 * KYBER_SYM_SZ , ss , KYBER_SS_SZ );
880+ #else
881+ ret = kyber_derive_secret (key -> z , ct , ctSz , msg );
882+ }
883+ if (ret == 0 ) {
884+ /* Change seed to z on comparison failure. */
885+ for (i = 0 ; i < KYBER_SYM_SZ ; i ++ ) {
886+ ss [i ] = kr [i ] ^ ((kr [i ] ^ msg [i ]) & fail );
887+ }
888+ #endif
832889 }
833890
834891#ifndef USE_INTEL_SPEEDUP
@@ -854,13 +911,14 @@ int wc_KyberKey_Decapsulate(KyberKey* key, unsigned char* ss,
854911 * @return NOT_COMPILED_IN when key type is not supported.
855912 * @return BUFFER_E when len is not the correct size.
856913 */
857- int wc_KyberKey_DecodePrivateKey (KyberKey * key , unsigned char * in , word32 len )
914+ int wc_KyberKey_DecodePrivateKey (KyberKey * key , const unsigned char * in ,
915+ word32 len )
858916{
859917 int ret = 0 ;
860918 word32 privLen = 0 ;
861919 word32 pubLen = 0 ;
862920 unsigned int k = 0 ;
863- unsigned char * p = in ;
921+ const unsigned char * p = in ;
864922
865923 /* Validate parameters. */
866924 if ((key == NULL ) || (in == NULL )) {
@@ -938,12 +996,13 @@ int wc_KyberKey_DecodePrivateKey(KyberKey* key, unsigned char* in, word32 len)
938996 * @return NOT_COMPILED_IN when key type is not supported.
939997 * @return BUFFER_E when len is not the correct size.
940998 */
941- int wc_KyberKey_DecodePublicKey (KyberKey * key , unsigned char * in , word32 len )
999+ int wc_KyberKey_DecodePublicKey (KyberKey * key , const unsigned char * in ,
1000+ word32 len )
9421001{
9431002 int ret = 0 ;
9441003 word32 pubLen = 0 ;
9451004 unsigned int k = 0 ;
946- unsigned char * p = in ;
1005+ const unsigned char * p = in ;
9471006
9481007 if ((key == NULL ) || (in == NULL )) {
9491008 ret = BAD_FUNC_ARG ;
0 commit comments