_IntegerNative.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382
# ===================================================================
#
# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#
# 1. Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in
# the documentation and/or other materials provided with the
# distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
# ===================================================================
  30. from ._IntegerBase import IntegerBase
  31. from Crypto.Util.number import long_to_bytes, bytes_to_long, inverse, GCD
  32. class IntegerNative(IntegerBase):
  33. """A class to model a natural integer (including zero)"""
  34. def __init__(self, value):
  35. if isinstance(value, float):
  36. raise ValueError("A floating point type is not a natural number")
  37. try:
  38. self._value = value._value
  39. except AttributeError:
  40. self._value = value
  41. # Conversions
  42. def __int__(self):
  43. return self._value
  44. def __str__(self):
  45. return str(int(self))
  46. def __repr__(self):
  47. return "Integer(%s)" % str(self)
  48. # Only Python 2.x
  49. def __hex__(self):
  50. return hex(self._value)
  51. # Only Python 3.x
  52. def __index__(self):
  53. return int(self._value)
  54. def to_bytes(self, block_size=0, byteorder='big'):
  55. if self._value < 0:
  56. raise ValueError("Conversion only valid for non-negative numbers")
  57. result = long_to_bytes(self._value, block_size)
  58. if len(result) > block_size > 0:
  59. raise ValueError("Value too large to encode")
  60. if byteorder == 'big':
  61. pass
  62. elif byteorder == 'little':
  63. result = bytearray(result)
  64. result.reverse()
  65. result = bytes(result)
  66. else:
  67. raise ValueError("Incorrect byteorder")
  68. return result
  69. @classmethod
  70. def from_bytes(cls, byte_string, byteorder='big'):
  71. if byteorder == 'big':
  72. pass
  73. elif byteorder == 'little':
  74. byte_string = bytearray(byte_string)
  75. byte_string.reverse()
  76. else:
  77. raise ValueError("Incorrect byteorder")
  78. return cls(bytes_to_long(byte_string))
  79. # Relations
  80. def __eq__(self, term):
  81. if term is None:
  82. return False
  83. return self._value == int(term)
  84. def __ne__(self, term):
  85. return not self.__eq__(term)
  86. def __lt__(self, term):
  87. return self._value < int(term)
  88. def __le__(self, term):
  89. return self.__lt__(term) or self.__eq__(term)
  90. def __gt__(self, term):
  91. return not self.__le__(term)
  92. def __ge__(self, term):
  93. return not self.__lt__(term)
  94. def __nonzero__(self):
  95. return self._value != 0
  96. __bool__ = __nonzero__
  97. def is_negative(self):
  98. return self._value < 0
  99. # Arithmetic operations
  100. def __add__(self, term):
  101. try:
  102. return self.__class__(self._value + int(term))
  103. except (ValueError, AttributeError, TypeError):
  104. return NotImplemented
  105. def __sub__(self, term):
  106. try:
  107. return self.__class__(self._value - int(term))
  108. except (ValueError, AttributeError, TypeError):
  109. return NotImplemented
  110. def __mul__(self, factor):
  111. try:
  112. return self.__class__(self._value * int(factor))
  113. except (ValueError, AttributeError, TypeError):
  114. return NotImplemented
  115. def __floordiv__(self, divisor):
  116. return self.__class__(self._value // int(divisor))
  117. def __mod__(self, divisor):
  118. divisor_value = int(divisor)
  119. if divisor_value < 0:
  120. raise ValueError("Modulus must be positive")
  121. return self.__class__(self._value % divisor_value)
  122. def inplace_pow(self, exponent, modulus=None):
  123. exp_value = int(exponent)
  124. if exp_value < 0:
  125. raise ValueError("Exponent must not be negative")
  126. if modulus is not None:
  127. mod_value = int(modulus)
  128. if mod_value < 0:
  129. raise ValueError("Modulus must be positive")
  130. if mod_value == 0:
  131. raise ZeroDivisionError("Modulus cannot be zero")
  132. else:
  133. mod_value = None
  134. self._value = pow(self._value, exp_value, mod_value)
  135. return self
  136. def __pow__(self, exponent, modulus=None):
  137. result = self.__class__(self)
  138. return result.inplace_pow(exponent, modulus)
  139. def __abs__(self):
  140. return abs(self._value)
  141. def sqrt(self, modulus=None):
  142. value = self._value
  143. if modulus is None:
  144. if value < 0:
  145. raise ValueError("Square root of negative value")
  146. # http://stackoverflow.com/questions/15390807/integer-square-root-in-python
  147. x = value
  148. y = (x + 1) // 2
  149. while y < x:
  150. x = y
  151. y = (x + value // x) // 2
  152. result = x
  153. else:
  154. if modulus <= 0:
  155. raise ValueError("Modulus must be positive")
  156. result = self._tonelli_shanks(self % modulus, modulus)
  157. return self.__class__(result)
  158. def __iadd__(self, term):
  159. self._value += int(term)
  160. return self
  161. def __isub__(self, term):
  162. self._value -= int(term)
  163. return self
  164. def __imul__(self, term):
  165. self._value *= int(term)
  166. return self
  167. def __imod__(self, term):
  168. modulus = int(term)
  169. if modulus == 0:
  170. raise ZeroDivisionError("Division by zero")
  171. if modulus < 0:
  172. raise ValueError("Modulus must be positive")
  173. self._value %= modulus
  174. return self
  175. # Boolean/bit operations
  176. def __and__(self, term):
  177. return self.__class__(self._value & int(term))
  178. def __or__(self, term):
  179. return self.__class__(self._value | int(term))
  180. def __rshift__(self, pos):
  181. try:
  182. return self.__class__(self._value >> int(pos))
  183. except OverflowError:
  184. if self._value >= 0:
  185. return 0
  186. else:
  187. return -1
  188. def __irshift__(self, pos):
  189. try:
  190. self._value >>= int(pos)
  191. except OverflowError:
  192. if self._value >= 0:
  193. return 0
  194. else:
  195. return -1
  196. return self
  197. def __lshift__(self, pos):
  198. try:
  199. return self.__class__(self._value << int(pos))
  200. except OverflowError:
  201. raise ValueError("Incorrect shift count")
  202. def __ilshift__(self, pos):
  203. try:
  204. self._value <<= int(pos)
  205. except OverflowError:
  206. raise ValueError("Incorrect shift count")
  207. return self
  208. def get_bit(self, n):
  209. if self._value < 0:
  210. raise ValueError("no bit representation for negative values")
  211. try:
  212. try:
  213. result = (self._value >> n._value) & 1
  214. if n._value < 0:
  215. raise ValueError("negative bit count")
  216. except AttributeError:
  217. result = (self._value >> n) & 1
  218. if n < 0:
  219. raise ValueError("negative bit count")
  220. except OverflowError:
  221. result = 0
  222. return result
  223. # Extra
  224. def is_odd(self):
  225. return (self._value & 1) == 1
  226. def is_even(self):
  227. return (self._value & 1) == 0
  228. def size_in_bits(self):
  229. if self._value < 0:
  230. raise ValueError("Conversion only valid for non-negative numbers")
  231. if self._value == 0:
  232. return 1
  233. return self._value.bit_length()
  234. def size_in_bytes(self):
  235. return (self.size_in_bits() - 1) // 8 + 1
  236. def is_perfect_square(self):
  237. if self._value < 0:
  238. return False
  239. if self._value in (0, 1):
  240. return True
  241. x = self._value // 2
  242. square_x = x ** 2
  243. while square_x > self._value:
  244. x = (square_x + self._value) // (2 * x)
  245. square_x = x ** 2
  246. return self._value == x ** 2
  247. def fail_if_divisible_by(self, small_prime):
  248. if (self._value % int(small_prime)) == 0:
  249. raise ValueError("Value is composite")
  250. def multiply_accumulate(self, a, b):
  251. self._value += int(a) * int(b)
  252. return self
  253. def set(self, source):
  254. self._value = int(source)
  255. def inplace_inverse(self, modulus):
  256. self._value = inverse(self._value, int(modulus))
  257. return self
  258. def inverse(self, modulus):
  259. result = self.__class__(self)
  260. result.inplace_inverse(modulus)
  261. return result
  262. def gcd(self, term):
  263. return self.__class__(GCD(abs(self._value), abs(int(term))))
  264. def lcm(self, term):
  265. term = int(term)
  266. if self._value == 0 or term == 0:
  267. return self.__class__(0)
  268. return self.__class__(abs((self._value * term) // self.gcd(term)._value))
  269. @staticmethod
  270. def jacobi_symbol(a, n):
  271. a = int(a)
  272. n = int(n)
  273. if n <= 0:
  274. raise ValueError("n must be a positive integer")
  275. if (n & 1) == 0:
  276. raise ValueError("n must be odd for the Jacobi symbol")
  277. # Step 1
  278. a = a % n
  279. # Step 2
  280. if a == 1 or n == 1:
  281. return 1
  282. # Step 3
  283. if a == 0:
  284. return 0
  285. # Step 4
  286. e = 0
  287. a1 = a
  288. while (a1 & 1) == 0:
  289. a1 >>= 1
  290. e += 1
  291. # Step 5
  292. if (e & 1) == 0:
  293. s = 1
  294. elif n % 8 in (1, 7):
  295. s = 1
  296. else:
  297. s = -1
  298. # Step 6
  299. if n % 4 == 3 and a1 % 4 == 3:
  300. s = -s
  301. # Step 7
  302. n1 = n % a1
  303. # Step 8
  304. return s * IntegerNative.jacobi_symbol(n1, a1)
  305. @staticmethod
  306. def _mult_modulo_bytes(term1, term2, modulus):
  307. if modulus < 0:
  308. raise ValueError("Modulus must be positive")
  309. if modulus == 0:
  310. raise ZeroDivisionError("Modulus cannot be zero")
  311. if (modulus & 1) == 0:
  312. raise ValueError("Odd modulus is required")
  313. number_len = len(long_to_bytes(modulus))
  314. return long_to_bytes((term1 * term2) % modulus, number_len)