DH.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. from Crypto.Util.number import long_to_bytes
  2. from Crypto.PublicKey.ECC import EccKey
  3. def _compute_ecdh(key_priv, key_pub):
  4. # See Section 5.7.1.2 in NIST SP 800-56Ar3
  5. pointP = key_pub.pointQ * key_priv.d
  6. if pointP.is_point_at_infinity():
  7. raise ValueError("Invalid ECDH point")
  8. z = long_to_bytes(pointP.x, pointP.size_in_bytes())
  9. return z
  10. def key_agreement(**kwargs):
  11. """Perform a Diffie-Hellman key agreement.
  12. Keywords:
  13. kdf (callable):
  14. A key derivation function that accepts ``bytes`` as input and returns
  15. ``bytes``.
  16. static_priv (EccKey):
  17. The local static private key. Optional.
  18. static_pub (EccKey):
  19. The static public key that belongs to the peer. Optional.
  20. eph_priv (EccKey):
  21. The local ephemeral private key, generated for this session. Optional.
  22. eph_pub (EccKey):
  23. The ephemeral public key, received from the peer for this session. Optional.
  24. At least two keys must be passed, of which one is a private key and one
  25. a public key.
  26. Returns (bytes):
  27. The derived secret key material.
  28. """
  29. static_priv = kwargs.get('static_priv', None)
  30. static_pub = kwargs.get('static_pub', None)
  31. eph_priv = kwargs.get('eph_priv', None)
  32. eph_pub = kwargs.get('eph_pub', None)
  33. kdf = kwargs.get('kdf', None)
  34. if kdf is None:
  35. raise ValueError("'kdf' is mandatory")
  36. count_priv = 0
  37. count_pub = 0
  38. curve = None
  39. def check_curve(curve, key, name, private):
  40. if not isinstance(key, EccKey):
  41. raise TypeError("'%s' must be an ECC key" % name)
  42. if private and not key.has_private():
  43. raise TypeError("'%s' must be a private ECC key" % name)
  44. if curve is None:
  45. curve = key.curve
  46. elif curve != key.curve:
  47. raise TypeError("'%s' is defined on an incompatible curve" % name)
  48. return curve
  49. if static_priv is not None:
  50. curve = check_curve(curve, static_priv, 'static_priv', True)
  51. count_priv += 1
  52. if static_pub is not None:
  53. curve = check_curve(curve, static_pub, 'static_pub', False)
  54. count_pub += 1
  55. if eph_priv is not None:
  56. curve = check_curve(curve, eph_priv, 'eph_priv', True)
  57. count_priv += 1
  58. if eph_pub is not None:
  59. curve = check_curve(curve, eph_pub, 'eph_pub', False)
  60. count_pub += 1
  61. if (count_priv + count_pub) < 2 or count_priv == 0 or count_pub == 0:
  62. raise ValueError("Too few keys for the ECDH key agreement")
  63. Zs = b''
  64. Ze = b''
  65. if static_priv and static_pub:
  66. # C(*, 2s)
  67. Zs = _compute_ecdh(static_priv, static_pub)
  68. if eph_priv and eph_pub:
  69. # C(2e, 0s) or C(2e, 2s)
  70. if bool(static_priv) != bool(static_pub):
  71. raise ValueError("DH mode C(2e, 1s) is not supported")
  72. Ze = _compute_ecdh(eph_priv, eph_pub)
  73. elif eph_priv and static_pub:
  74. # C(1e, 2s) or C(1e, 1s)
  75. Ze = _compute_ecdh(eph_priv, static_pub)
  76. elif eph_pub and static_priv:
  77. # C(1e, 2s) or C(1e, 1s)
  78. Ze = _compute_ecdh(static_priv, eph_pub)
  79. Z = Ze + Zs
  80. return kdf(Z)