# The safegcd implementation in libsecp256k1 explained

This document explains the modular inverse and Jacobi symbol implementations in the `src/modinv*.h` files.
It is based on the paper
["Fast constant-time gcd computation and modular inversion"](https://gcd.cr.yp.to/papers.html#safegcd)
by Daniel J. Bernstein and Bo-Yin Yang. The references below are for the Date: 2019.04.13 version.

The actual implementation is in C of course, but for demonstration purposes Python3 is used here.
Most implementation aspects and optimizations are explained, except those that depend on the specific
number representation used in the C code.

## 1. Computing the Greatest Common Divisor (GCD) using divsteps

The algorithm from the paper (section 11), at a very high level, is this:

```python
def gcd(f, g):
    """Compute the GCD of an odd integer f and another integer g."""
    assert f & 1  # require f to be odd
    delta = 1     # additional state variable
    while g != 0:
        assert f & 1  # f will be odd in every iteration
        if delta > 0 and g & 1:
            delta, f, g = 1 - delta, g, (g - f) // 2
        elif g & 1:
            delta, f, g = 1 + delta, f, (g + f) // 2
        else:
            delta, f, g = 1 + delta, f, (g    ) // 2
    return abs(f)
```

It computes the greatest common divisor of an odd integer *f* and any integer *g*. Its inner loop
keeps rewriting the variables *f* and *g* alongside a state variable *δ* that starts at *1*, until
*g=0* is reached. At that point, *|f|* gives the GCD. Each of the transitions in the loop is called a
"division step" (referred to as divstep in what follows).

For example, *gcd(21, 14)* would be computed as:
- Start with *δ=1 f=21 g=14*
- Take the third branch: *δ=2 f=21 g=7*
- Take the first branch: *δ=-1 f=7 g=-7*
- Take the second branch: *δ=0 f=7 g=0*
- The answer *|f| = 7*.
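
The trace above can be checked mechanically. The following small test harness (not part of the implementation) repeats the `gcd` function from this section for self-containment and compares it against Python's `math.gcd` on random inputs:

```python
import math
import random

def gcd(f, g):
    """Compute the GCD of an odd integer f and another integer g (divstep version)."""
    assert f & 1
    delta = 1
    while g != 0:
        if delta > 0 and g & 1:
            delta, f, g = 1 - delta, g, (g - f) // 2
        elif g & 1:
            delta, f, g = 1 + delta, f, (g + f) // 2
        else:
            delta, f, g = 1 + delta, f, g // 2
    return abs(f)

# Reproduce the worked example, then cross-check against math.gcd.
assert gcd(21, 14) == 7
random.seed(1)
for _ in range(100):
    f = 2 * random.randrange(1 << 16) + 1  # random odd f
    g = random.randrange(1 << 16)
    assert gcd(f, g) == math.gcd(f, g)
```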

Why it works:
- Divsteps can be decomposed into two steps (see paragraph 8.2 in the paper):
  - (a) If *g* is odd, replace *(f,g)* with *(g,g-f)* or *(f,g+f)*, resulting in an even *g*.
  - (b) Replace *(f,g)* with *(f,g/2)* (where *g* is guaranteed to be even).
- Neither of those two operations change the GCD:
  - For (a), assume *gcd(f,g)=c*, then it must be the case that *f=a c* and *g=b c* for some integers *a*
    and *b*. As *(g,g-f)=(b c,(b-a)c)* and *(f,f+g)=(a c,(a+b)c)*, the result clearly still has
    common factor *c*. Reasoning in the other direction shows that no common factor can be added by
    doing so either.
  - For (b), we know that *f* is odd, so *gcd(f,g)* clearly has no factor *2*, and we can remove
    it from *g*.
- The algorithm will eventually converge to *g=0*. This is proven in the paper (see theorem G.3).
- It follows that eventually we find a final value *f'* for which *gcd(f,g) = gcd(f',0)*. As the
  gcd of *f'* and *0* is *|f'|* by definition, that is our answer.

Compared to more [traditional GCD algorithms](https://en.wikipedia.org/wiki/Euclidean_algorithm), this one has the property of only ever looking at
the low-order bits of the variables to decide the next steps, and being easy to make
constant-time (in more low-level languages than Python). The *δ* parameter is necessary to
guide the algorithm towards shrinking the numbers' magnitudes without explicitly needing to look
at high order bits.

Properties that will become important later:
- Performing more divsteps than needed is not a problem, as *f* does not change anymore after *g=0*.
- Only even numbers are divided by *2*. This means that when reasoning about it algebraically we
  do not need to worry about rounding.
- At every point during the algorithm's execution the next *N* steps only depend on the bottom *N*
  bits of *f* and *g*, and on *δ*.


## 2. From GCDs to modular inverses

We want an algorithm to compute the inverse *a* of *x* modulo *M*, i.e. the number *a* such that *a x=1
mod M*. This inverse only exists if the GCD of *x* and *M* is *1*, but that is always the case if *M* is
prime and *0 < x < M*. In what follows, assume that the modular inverse exists.
It turns out this inverse can be computed as a side effect of computing the GCD by keeping track
of how the internal variables can be written as linear combinations of the inputs at every step
(see the [extended Euclidean algorithm](https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm)).
Since the GCD is *1*, such an algorithm will compute numbers *a* and *b* such that *a x + b M = 1*.
Taking that expression *mod M* gives *a x mod M = 1*, and we see that *a* is the modular inverse of *x
mod M*.

A similar approach can be used to calculate modular inverses using the divsteps-based GCD
algorithm shown above, if the modulus *M* is odd. To do so, compute *gcd(f=M,g=x)*, while keeping
track of extra variables *d* and *e*, for which at every step *d = f/x (mod M)* and *e = g/x (mod M)*.
*f/x* here means the number which multiplied with *x* gives *f mod M*. As *f* and *g* are initialized to *M*
and *x* respectively, *d* and *e* just start off being *0* (*M/x mod M = 0/x mod M = 0*) and *1* (*x/x mod M
= 1*).

```python
def div2(M, x):
    """Helper routine to compute x/2 mod M (where M is odd)."""
    assert M & 1
    if x & 1:  # If x is odd, make it even by adding M.
        x += M
    # x must be even now, so a clean division by 2 is possible.
    return x // 2

def modinv(M, x):
    """Compute the inverse of x mod M (given that it exists, and M is odd)."""
    assert M & 1
    delta, f, g, d, e = 1, M, x, 0, 1
    while g != 0:
        # Note that while division by two for f and g is only ever done on even inputs, this is
        # not true for d and e, so we need the div2 helper function.
        if delta > 0 and g & 1:
            delta, f, g, d, e = 1 - delta, g, (g - f) // 2, e, div2(M, e - d)
        elif g & 1:
            delta, f, g, d, e = 1 + delta, f, (g + f) // 2, d, div2(M, e + d)
        else:
            delta, f, g, d, e = 1 + delta, f, (g    ) // 2, d, div2(M, e    )
        # Verify that the invariants d=f/x mod M, e=g/x mod M are maintained.
        assert f % M == (d * x) % M
        assert g % M == (e * x) % M
    assert f == 1 or f == -1  # |f| is the GCD, it must be 1
    # Because of invariant d = f/x (mod M), 1/x = d/f (mod M). As |f|=1, d/f = d*f.
    return (d * f) % M
```

Also note that this approach to track *d* and *e* throughout the computation to determine the inverse
is different from the paper. There (see paragraph 12.1 in the paper) a transition matrix for the
entire computation is determined (see section 3 below) and the inverse is computed from that.
The approach here avoids the need for 2x2 matrix multiplications of various sizes, and appears to
be faster at the level of optimization we're able to do in C.


## 3. Batching multiple divsteps

Every divstep can be expressed as a matrix multiplication, applying a transition matrix *(1/2 t)*
to both vectors *[f, g]* and *[d, e]* (see paragraph 8.1 in the paper):

```
t = [ u,  v ]
    [ q,  r ]

[ out_f ] = (1/2 * t) * [ in_f ]
[ out_g ]               [ in_g ]

[ out_d ] = (1/2 * t) * [ in_d ]  (mod M)
[ out_e ]               [ in_e ]
```

where *(u, v, q, r)* is *(0, 2, -1, 1)*, *(2, 0, 1, 1)*, or *(2, 0, 0, 1)*, depending on which branch is
taken. As above, the resulting *f* and *g* are always integers.

Performing multiple divsteps corresponds to a multiplication with the product of all the
individual divsteps' transition matrices. As each transition matrix consists of integers
divided by *2*, the product of these matrices will consist of integers divided by *2<sup>N</sup>* (see also
theorem 9.2 in the paper). These divisions are expensive when updating *d* and *e*, so we delay
them: we compute the integer coefficients of the combined transition matrix scaled by *2<sup>N</sup>*, and
do one division by *2<sup>N</sup>* as a final step:

```python
def divsteps_n_matrix(delta, f, g):
    """Compute delta and transition matrix t after N divsteps (multiplied by 2^N)."""
    u, v, q, r = 1, 0, 0, 1  # start with identity matrix
    for _ in range(N):
        if delta > 0 and g & 1:
            delta, f, g, u, v, q, r = 1 - delta, g, (g - f) // 2, 2*q, 2*r, q-u, r-v
        elif g & 1:
            delta, f, g, u, v, q, r = 1 + delta, f, (g + f) // 2, 2*u, 2*v, q+u, r+v
        else:
            delta, f, g, u, v, q, r = 1 + delta, f, (g    ) // 2, 2*u, 2*v, q  , r
    return delta, (u, v, q, r)
```

As the branches in the divsteps are completely determined by the bottom *N* bits of *f* and *g*, this
function to compute the transition matrix only needs to see those bottom bits. Furthermore all
intermediate results and outputs fit in *(N+1)*-bit numbers (unsigned for *f* and *g*; signed for *u*, *v*,
*q*, and *r*) (see also paragraph 8.3 in the paper). This means that an implementation using 64-bit
integers could set *N=62* and compute the full transition matrix for 62 steps at once without any
big integer arithmetic at all. This is the reason why this algorithm is efficient: it only needs
to update the full-size *f*, *g*, *d*, and *e* numbers once every *N* steps.

We still need functions to compute:

```
[ out_f ] = (1/2^N * [ u,  v ]) * [ in_f ]
[ out_g ]   (        [ q,  r ])   [ in_g ]

[ out_d ] = (1/2^N * [ u,  v ]) * [ in_d ]  (mod M)
[ out_e ]   (        [ q,  r ])   [ in_e ]
```

Because the divsteps transformation only ever divides even numbers by two, the result of *t [f,g]* is always even.
When *t* is a composition of *N* divsteps, it follows that the resulting *f*
and *g* will be multiples of *2<sup>N</sup>*, and division by *2<sup>N</sup>* is simply shifting them down:

```python
def update_fg(f, g, t):
    """Multiply matrix t/2^N with [f, g]."""
    u, v, q, r = t
    cf, cg = u*f + v*g, q*f + r*g
    # (t / 2^N) should cleanly apply to [f,g] so the result of t*[f,g] should have N zero
    # bottom bits.
    assert cf % 2**N == 0
    assert cg % 2**N == 0
    return cf >> N, cg >> N
```

The same is not true for *d* and *e*, and we need an equivalent of the `div2` function for division by *2<sup>N</sup> mod M*.
This is easy if we have precomputed *1/M mod 2<sup>N</sup>* (which always exists for odd *M*):

```python
def div2n(M, Mi, x):
    """Compute x/2^N mod M, given Mi = 1/M mod 2^N."""
    assert (M * Mi) % 2**N == 1
    # Find a factor m such that m*M has the same bottom N bits as x. We want:
    #     (m * M) mod 2^N = x mod 2^N
    # <=> m mod 2^N = (x / M) mod 2^N
    # <=> m mod 2^N = (x * Mi) mod 2^N
    m = (Mi * x) % 2**N
    # Subtract that multiple from x, cancelling its bottom N bits.
    x -= m * M
    # Now a clean division by 2^N is possible.
    assert x % 2**N == 0
    return (x >> N) % M

def update_de(d, e, t, M, Mi):
    """Multiply matrix t/2^N with [d, e], modulo M."""
    u, v, q, r = t
    cd, ce = u*d + v*e, q*d + r*e
    return div2n(M, Mi, cd), div2n(M, Mi, ce)
```

With all of those, we can write a version of `modinv` that performs *N* divsteps at once:

```python
def modinv(M, Mi, x):
    """Compute the modular inverse of x mod M, given Mi=1/M mod 2^N."""
    assert M & 1
    delta, f, g, d, e = 1, M, x, 0, 1
    while g != 0:
        # Compute the delta and transition matrix t for the next N divsteps (this only needs
        # (N+1)-bit signed integer arithmetic).
        delta, t = divsteps_n_matrix(delta, f % 2**N, g % 2**N)
        # Apply the transition matrix t to [f, g]:
        f, g = update_fg(f, g, t)
        # Apply the transition matrix t to [d, e]:
        d, e = update_de(d, e, t, M, Mi)
    return (d * f) % M
```

This means that in practice we'll always perform a multiple of *N* divsteps. This is not a problem
because once *g=0*, further divsteps do not affect *f*, *g*, *d*, or *e* anymore (only *δ* keeps
increasing). For variable time code such excess iterations will be mostly optimized away in later
sections.


## 4. Avoiding modulus operations

So far, there are two places where we compute a remainder of big numbers modulo *M*: at the end of
`div2n` in every `update_de`, and at the very end of `modinv` after potentially negating *d* due to the
sign of *f*. These are relatively expensive operations when done generically.

To deal with the modulus operation in `div2n`, we simply stop requiring *d* and *e* to be in range
*[0,M)* all the time. Let's start by inlining `div2n` into `update_de`, and dropping the modulus
operation at the end:

```python
def update_de(d, e, t, M, Mi):
    """Multiply matrix t/2^N with [d, e] mod M, given Mi=1/M mod 2^N."""
    u, v, q, r = t
    cd, ce = u*d + v*e, q*d + r*e
    # Cancel out bottom N bits of cd and ce.
    md = -((Mi * cd) % 2**N)
    me = -((Mi * ce) % 2**N)
    cd += md * M
    ce += me * M
    # And cleanly divide by 2**N.
    return cd >> N, ce >> N
```

Let's look at bounds on the ranges of these numbers. It can be shown that *|u|+|v|* and *|q|+|r|*
never exceed *2<sup>N</sup>* (see paragraph 8.3 in the paper), and thus a multiplication with *t* will have
outputs whose absolute values are at most *2<sup>N</sup>* times the maximum absolute input value.
In case the
inputs *d* and *e* are in *(-M,M)*, which is certainly true for the initial values *d=0* and *e=1* assuming
*M > 1*, the multiplication results in numbers in range *(-2<sup>N</sup>M,2<sup>N</sup>M)*. Subtracting less than *2<sup>N</sup>*
times *M* to cancel out *N* bits brings that up to *(-2<sup>N+1</sup>M,2<sup>N</sup>M)*, and
dividing by *2<sup>N</sup>* at the end takes it to *(-2M,M)*. Another application of `update_de` would take that
to *(-3M,2M)*, and so forth. This progressive expansion of the variables' ranges can be
counteracted by incrementing *d* and *e* by *M* whenever they're negative:

```python
    ...
    if d < 0:
        d += M
    if e < 0:
        e += M
    cd, ce = u*d + v*e, q*d + r*e
    # Cancel out bottom N bits of cd and ce.
    ...
```

With inputs in *(-2M,M)*, they will first be shifted into range *(-M,M)*, which means that the
output will again be in *(-2M,M)*, and this remains the case regardless of how many `update_de`
invocations there are. In what follows, we will try to make this more efficient.

Note that increasing *d* by *M* is equal to incrementing *cd* by *u M* and *ce* by *q M*. Similarly,
increasing *e* by *M* is equal to incrementing *cd* by *v M* and *ce* by *r M*. So we could instead write:

```python
    ...
    cd, ce = u*d + v*e, q*d + r*e
    # Perform the equivalent of incrementing d, e by M when they're negative.
    if d < 0:
        cd += u*M
        ce += q*M
    if e < 0:
        cd += v*M
        ce += r*M
    # Cancel out bottom N bits of cd and ce.
    md = -((Mi * cd) % 2**N)
    me = -((Mi * ce) % 2**N)
    cd += md * M
    ce += me * M
    ...
```

Now note that we have two steps of corrections to *cd* and *ce* that add multiples of *M*: this
increment, and the decrement that cancels out bottom bits.
The second one depends on the first
one, but they can still be efficiently combined by only computing the bottom bits of *cd* and *ce*
at first, and using that to compute the final *md*, *me* values:

```python
def update_de(d, e, t, M, Mi):
    """Multiply matrix t/2^N with [d, e], modulo M."""
    u, v, q, r = t
    md, me = 0, 0
    # Compute what multiples of M to add to cd and ce.
    if d < 0:
        md += u
        me += q
    if e < 0:
        md += v
        me += r
    # Compute bottom N bits of t*[d,e] + M*[md,me].
    cd, ce = (u*d + v*e + md*M) % 2**N, (q*d + r*e + me*M) % 2**N
    # Correct md and me such that the bottom N bits of t*[d,e] + M*[md,me] are zero.
    md -= (Mi * cd) % 2**N
    me -= (Mi * ce) % 2**N
    # Do the full computation.
    cd, ce = u*d + v*e + md*M, q*d + r*e + me*M
    # And cleanly divide by 2**N.
    return cd >> N, ce >> N
```

One last optimization: we can avoid the *md M* and *me M* multiplications in the bottom bits of *cd*
and *ce* by moving them to the *md* and *me* correction:

```python
    ...
    # Compute bottom N bits of t*[d,e].
    cd, ce = (u*d + v*e) % 2**N, (q*d + r*e) % 2**N
    # Correct md and me such that the bottom N bits of t*[d,e]+M*[md,me] are zero.
    # Note that this is not the same as {md = (-Mi * cd) % 2**N} etc. That would also result in N
    # zero bottom bits, but isn't guaranteed to be a reduction of [0,2^N) compared to the
    # previous md and me values, and thus would violate our bounds analysis.
    md -= (Mi*cd + md) % 2**N
    me -= (Mi*ce + me) % 2**N
    ...
```

The resulting function takes *d* and *e* in range *(-2M,M)* as inputs, and outputs values in the same
range. That also means that the *d* value at the end of `modinv` will be in that range, while we want
a result in *[0,M)*. To do that, we need a normalization function.
It's easy to integrate the
conditional negation of *d* (based on the sign of *f*) into it as well:

```python
def normalize(sign, v, M):
    """Compute sign*v mod M, where v is in range (-2*M,M); output in [0,M)."""
    assert sign == 1 or sign == -1
    # v in (-2*M,M)
    if v < 0:
        v += M
    # v in (-M,M). Now multiply v with sign (which can only be 1 or -1).
    if sign == -1:
        v = -v
    # v in (-M,M)
    if v < 0:
        v += M
    # v in [0,M)
    return v
```

And calling it in `modinv` is simply:

```python
    ...
    return normalize(f, d, M)
```


## 5. Constant-time operation

The primary selling point of the algorithm is fast constant-time operation. What code flow still
depends on the input data so far?

- the number of iterations of the while *g ≠ 0* loop in `modinv`
- the branches inside `divsteps_n_matrix`
- the sign checks in `update_de`
- the sign checks in `normalize`

To make the while loop in `modinv` constant time it can be replaced with a constant number of
iterations. The paper proves (Theorem 11.2) that *741* divsteps are sufficient for any *256*-bit
inputs, and [safegcd-bounds](https://github.com/sipa/safegcd-bounds) shows that the slightly better bound *724* is
sufficient even. Given that every loop iteration performs *N* divsteps, it will run a total of
*⌈724/N⌉* times.

To deal with the branches in `divsteps_n_matrix` we will replace them with constant-time bitwise
operations (and hope the C compiler isn't smart enough to turn them back into branches; see
`ctime_tests.c` for automated tests that this isn't the case). To do so, observe that a
divstep can be written instead as (compare to the inner loop of `gcd` in section 1):

```python
x = -f if delta > 0 else f  # set x equal to (input) -f or f
if g & 1:
    g += x  # set g to (input) g-f or g+f
    if delta > 0:
        delta = -delta
        f += g  # set f to (input) g (note that g was set to g-f before)
delta += 1
g >>= 1
```

To convert the above to bitwise operations, we rely on a trick to negate conditionally: per the
definition of negative numbers in two's complement, (*-v == ~v + 1*) holds for every number *v*. As
*-1* in two's complement is all *1* bits, bitflipping can be expressed as xor with *-1*. It follows
that *-v == (v ^ -1) - (-1)*. Thus, if we have a variable *c* that takes on values *0* or *-1*, then
*(v ^ c) - c* is *v* if *c=0* and *-v* if *c=-1*.

Using this we can write:

```python
x = -f if delta > 0 else f
```

in constant-time form as:

```python
c1 = (-delta) >> 63
# Conditionally negate f based on c1:
x = (f ^ c1) - c1
```

To use that trick, we need a helper mask variable *c1* that resolves the condition *δ>0* to *-1*
(if true) or *0* (if false). We compute *c1* using right shifting, which is equivalent to dividing by
the specified power of *2* and rounding down (in Python, and also in C under the assumption of a typical two's complement system; see
`assumptions.h` for tests that this is the case). Right shifting by *63* thus maps all
numbers in range *[-2<sup>63</sup>,0)* to *-1*, and numbers in range *[0,2<sup>63</sup>)* to *0*.

Using the facts that *x&0=0* and *x&(-1)=x* (on two's complement systems again), we can write:

```python
if g & 1:
    g += x
```

as:

```python
# Compute c2=0 if g is even and c2=-1 if g is odd.
c2 = -(g & 1)
# This masks out x if g is even, and leaves x be if g is odd.
g += x & c2
```

Using the conditional negation trick again we can write:

```python
if g & 1:
    if delta > 0:
        delta = -delta
```

as:

```python
# Compute c3=-1 if g is odd and delta>0, and 0 otherwise.
c3 = c1 & c2
# Conditionally negate delta based on c3:
delta = (delta ^ c3) - c3
```

Finally:

```python
if g & 1:
    if delta > 0:
        f += g
```

becomes:

```python
f += g & c3
```

It turns out that this can be implemented more efficiently by applying the substitution
*η=-δ*. In this representation, negating *δ* corresponds to negating *η*, and incrementing
*δ* corresponds to decrementing *η*. This allows us to remove the negation in the *c1*
computation:

```python
# Compute a mask c1 for eta < 0, and compute the conditional negation x of f:
c1 = eta >> 63
x = (f ^ c1) - c1
# Compute a mask c2 for odd g, and conditionally add x to g:
c2 = -(g & 1)
g += x & c2
# Compute a mask c3 for (eta < 0) and odd (input) g, and use it to conditionally negate eta,
# and add g to f:
c3 = c1 & c2
eta = (eta ^ c3) - c3
f += g & c3
# Incrementing delta corresponds to decrementing eta.
eta -= 1
g >>= 1
```

A variant of divsteps with better worst-case performance can be used instead: starting *δ* at
*1/2* instead of *1*. This reduces the worst case number of iterations to *590* for *256*-bit inputs
(which can be shown using convex hull analysis). In this case, the substitution *ζ=-(δ+1/2)*
is used instead to keep the variable integral. Incrementing *δ* by *1* still translates to
decrementing *ζ* by *1*, but negating *δ* now corresponds to going from *ζ* to *-(ζ+1)*, or
*~ζ*. Doing that conditionally based on *c3* is simply:

```python
...
c3 = c1 & c2
zeta ^= c3
...
```

By replacing the loop in `divsteps_n_matrix` with a variant of the divstep code above (extended to
also apply all *f* operations to *u*, *v* and all *g* operations to *q*, *r*), a constant-time version of
`divsteps_n_matrix` is obtained. The full code will be in section 7.

These bit fiddling tricks can also be used to make the conditional negations and additions in
`update_de` and `normalize` constant-time.


## 6. Variable-time optimizations

In section 5, we modified the `divsteps_n_matrix` function (and a few others) to be constant time.
Constant time operations are only necessary when computing modular inverses of secret data. In
other cases, it slows down calculations unnecessarily. In this section, we will construct a
faster non-constant time `divsteps_n_matrix` function.

To do so, first consider yet another way of writing the inner loop of divstep operations in
`gcd` from section 1. This decomposition is also explained in the paper in section 8.2. We use
the original version with initial *δ=1* and *η=-δ* here.

```python
for _ in range(N):
    if g & 1 and eta < 0:
        eta, f, g = -eta, g, -f
    if g & 1:
        g += f
    eta -= 1
    g >>= 1
```

Whenever *g* is even, the loop only shifts *g* down and decreases *η*. When *g* ends in multiple zero
bits, these iterations can be consolidated into one step. This requires counting the bottom zero
bits efficiently, which is possible on most platforms; it is abstracted here as the function
`count_trailing_zeros`.

```python
def count_trailing_zeros(v):
    """
    When v is zero, consider all N zero bits as "trailing".
    For a non-zero value v, find z such that v=(d<<z) for some odd d.
    """
    if v == 0:
        return N
    else:
        return (v & -v).bit_length() - 1

i = N  # divsteps left to do
while True:
    # Get rid of all bottom zeros at once. In the first iteration, g may be odd and the following
    # lines have no effect (until "if eta < 0").
    zeros = min(i, count_trailing_zeros(g))
    eta -= zeros
    g >>= zeros
    i -= zeros
    if i == 0:
        break
    # We know g is odd now
    if eta < 0:
        eta, f, g = -eta, g, -f
    g += f
    # g is even now, and the eta decrement and g shift will happen in the next loop.
```

We can now remove multiple bottom *0* bits from *g* at once, but still need a full iteration whenever
there is a bottom *1* bit. In what follows, we will get rid of multiple *1* bits simultaneously as
well.

Observe that as long as *η ≥ 0*, the loop does not modify *f*. Instead, it cancels out bottom
bits of *g* and shifts them out, and decreases *η* and *i* accordingly - interrupting only when *η*
becomes negative, or when *i* reaches *0*. Combined, this is equivalent to adding a multiple of *f* to
*g* to cancel out multiple bottom bits, and then shifting them out.

It is easy to find what that multiple is: we want a number *w* such that *g+w f* has a few bottom
zero bits. If that number of bits is *L*, we want *g+w f mod 2<sup>L</sup> = 0*, or *w = -g/f mod 2<sup>L</sup>*. Since *f*
is odd, such a *w* exists for any *L*. *L* cannot be more than *i* steps (as we'd finish the loop before
doing more) or more than *η+1* steps (as we'd run `eta, f, g = -eta, g, -f` at that point), but
apart from that, we're only limited by the complexity of computing *w*.
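
The existence claim for *w* can be checked directly. The following standalone sanity check (not part of the implementation; it uses the three-argument `pow` for modular inverses, available since Python 3.8) computes *w = -g/f mod 2<sup>L</sup>* and verifies that it cancels at least *L* bottom bits:

```python
import random

random.seed(42)
for _ in range(100):
    f = 2 * random.randrange(1 << 62) + 1  # f is always odd
    g = 2 * random.randrange(1 << 62) + 1  # the interesting case: odd g
    for L in range(1, 8):
        # w = -g/f mod 2^L exists because odd f is invertible modulo any power of two.
        w = (-g * pow(f, -1, 2**L)) % 2**L
        assert 0 <= w < 2**L
        # Adding w*f to g cancels (at least) the bottom L bits of g.
        assert (g + w * f) % 2**L == 0
```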

This code demonstrates how to cancel up to 4 bits per step:

```python
NEGINV16 = [15, 5, 3, 9, 7, 13, 11, 1]  # NEGINV16[n//2] = (-n)^-1 mod 16, for odd n
i = N
while True:
    zeros = min(i, count_trailing_zeros(g))
    eta -= zeros
    g >>= zeros
    i -= zeros
    if i == 0:
        break
    # We know g is odd now
    if eta < 0:
        eta, f, g = -eta, g, -f
    # Compute limit on number of bits to cancel
    limit = min(min(eta + 1, i), 4)
    # Compute w = -g/f mod 2**limit, using the table value for -1/f mod 2**4. Note that f is
    # always odd, so its inverse modulo a power of two always exists.
    w = (g * NEGINV16[(f & 15) // 2]) % (2**limit)
    # As w = -g/f mod (2**limit), g+w*f mod 2**limit = 0 mod 2**limit.
    g += w * f
    assert g % (2**limit) == 0
    # The next iteration will now shift out at least limit bottom zero bits from g.
```

By using a bigger table more bits can be cancelled at once. The table can also be implemented
as a formula. Several formulas are known for computing modular inverses modulo powers of two;
some can be found in Hacker's Delight second edition by Henry S. Warren, Jr. pages 245-247.
Here we need the negated modular inverse, which is a simple transformation of those:

- Instead of a 3-bit table:
  - *-f* or *f ^ 6*
- Instead of a 4-bit table:
  - *1 - f(f + 1)*
  - *-(f + (((f + 1) & 4) << 1))*
- For larger tables the following technique can be used: if *w=-1/f mod 2<sup>L</sup>*, then *w(w f+2)* is
  *-1/f mod 2<sup>2L</sup>*. This allows extending the previous formulas (or tables). In particular we
  have this 6-bit function (based on the 3-bit function above):
  - *f(f<sup>2</sup> - 2)*

This loop, again extended to also handle *u*, *v*, *q*, and *r* alongside *f* and *g*, placed in
`divsteps_n_matrix`, gives a significantly faster, but non-constant time version.


## 7. Final Python version

All together we need the following functions:

- A way to compute the transition matrix in constant time, using the `divsteps_n_matrix` function
  from section 3, but with its loop replaced by a variant of the constant-time divstep from
  section 5, extended to handle *u*, *v*, *q*, *r*:

```python
def divsteps_n_matrix(zeta, f, g):
    """Compute zeta and transition matrix t after N divsteps (multiplied by 2^N)."""
    u, v, q, r = 1, 0, 0, 1  # start with identity matrix
    for _ in range(N):
        c1 = zeta >> 63
        # Compute x, y, z as conditionally-negated versions of f, u, v.
        x, y, z = (f ^ c1) - c1, (u ^ c1) - c1, (v ^ c1) - c1
        c2 = -(g & 1)
        # Conditionally add x, y, z to g, q, r.
        g, q, r = g + (x & c2), q + (y & c2), r + (z & c2)
        c1 &= c2                # reusing c1 here for the earlier c3 variable
        zeta = (zeta ^ c1) - 1  # inlining the unconditional zeta decrement here
        # Conditionally add g, q, r to f, u, v.
        f, u, v = f + (g & c1), u + (q & c1), v + (r & c1)
        # When shifting g down, don't shift q, r, as we construct a transition matrix multiplied
        # by 2^N. Instead, shift f's coefficients u and v up.
        g, u, v = g >> 1, u << 1, v << 1
    return zeta, (u, v, q, r)
```

- The functions to update *f* and *g*, and *d* and *e*, from section 3 and section 4, with the constant-time
  changes to `update_de` from section 5:

```python
def update_fg(f, g, t):
    """Multiply matrix t/2^N with [f, g]."""
    u, v, q, r = t
    cf, cg = u*f + v*g, q*f + r*g
    return cf >> N, cg >> N

def update_de(d, e, t, M, Mi):
    """Multiply matrix t/2^N with [d, e], modulo M."""
    u, v, q, r = t
    # Sign masks: -1 if negative, 0 otherwise (for a 256-bit M, |d| and |e| are below 2^257).
    d_sign, e_sign = d >> 257, e >> 257
    md, me = (u & d_sign) + (v & e_sign), (q & d_sign) + (r & e_sign)
    cd, ce = (u*d + v*e) % 2**N, (q*d + r*e) % 2**N
    md -= (Mi*cd + md) % 2**N
    me -= (Mi*ce + me) % 2**N
    cd, ce = u*d + v*e + M*md, q*d + r*e + M*me
    return cd >> N, ce >> N
```

- The `normalize` function from section 4, made constant time as well:

```python
def normalize(sign, v, M):
    """Compute sign*v mod M, where v in (-2*M,M); output in [0,M)."""
    v_sign = v >> 257
    # Conditionally add M to v.
    v += M & v_sign
    c = (sign - 1) >> 1
    # Conditionally negate v.
    v = (v ^ c) - c
    v_sign = v >> 257
    # Conditionally add M to v again.
    v += M & v_sign
    return v
```

- And finally the `modinv` function too, adapted to use *ζ* instead of *δ*, and using the fixed
  iteration count from section 5:

```python
def modinv(M, Mi, x):
    """Compute the modular inverse of x mod M, given Mi=1/M mod 2^N."""
    zeta, f, g, d, e = -1, M, x, 0, 1
    for _ in range((590 + N - 1) // N):
        zeta, t = divsteps_n_matrix(zeta, f % 2**N, g % 2**N)
        f, g = update_fg(f, g, t)
        d, e = update_de(d, e, t, M, Mi)
    return normalize(f, d, M)
```

- To get a variable time version, replace the `divsteps_n_matrix` function with one that uses the
  divsteps loop from section 6, and a `modinv` version that calls it without the fixed iteration
  count:

```python
NEGINV16 = [15, 5, 3, 9, 7, 13, 11, 1]  # NEGINV16[n//2] = (-n)^-1 mod 16, for odd n
def divsteps_n_matrix_var(eta, f, g):
    """Compute eta and transition matrix t after N divsteps (multiplied by 2^N)."""
    u, v, q, r = 1, 0, 0, 1
    i = N
    while True:
        zeros = min(i, count_trailing_zeros(g))
        eta, i = eta - zeros, i - zeros
        g, u, v = g >> zeros, u << zeros, v << zeros
        if i == 0:
            break
        if eta < 0:
            eta, f, u, v, g, q, r = -eta, g, q, r, -f, -u, -v
        limit = min(min(eta + 1, i), 4)
        w = (g * NEGINV16[(f & 15) // 2]) % (2**limit)
        g, q, r = g + w*f, q + w*u, r + w*v
    return eta, (u, v, q, r)

def modinv_var(M, Mi, x):
    """Compute the modular inverse of x mod M, given Mi = 1/M mod 2^N."""
    eta, f, g, d, e = -1, M, x, 0, 1
    while g != 0:
        eta, t = divsteps_n_matrix_var(eta, f % 2**N, g % 2**N)
        f, g = update_fg(f, g, t)
        d, e = update_de(d, e, t, M, Mi)
    return normalize(f, d, M)
```

## 8. From GCDs to Jacobi symbol

We can also use a similar approach to calculate the Jacobi symbol *(x | M)* by keeping track of an
extra variable *j*, for which at every step *(x | M) = j (g | f)*.
As we update *f* and *g*, we
make corresponding updates to *j* using
[properties of the Jacobi symbol](https://en.wikipedia.org/wiki/Jacobi_symbol#Properties):
* *((g/2) | f)* is either *(g | f)* or *-(g | f)*, depending on the value of *f mod 8* (negating if it's *3* or *5*).
* *(f | g)* is either *(g | f)* or *-(g | f)*, depending on *f mod 4* and *g mod 4* (negating if both are *3*).

These updates depend only on the values of *f* and *g* modulo *4* or *8*, and can thus be applied
very quickly, as long as we keep track of a few additional bits of *f* and *g*. Overall, this
calculation is slightly simpler than the one for the modular inverse because we no longer need to
keep track of *d* and *e*.

However, one difficulty of this approach is that the Jacobi symbol *(a | n)* is only defined for
positive odd integers *n*, whereas in the original safegcd algorithm, *f, g* can take negative
values. We resolve this by using the following modified steps:

```python
# Before
if delta > 0 and g & 1:
    delta, f, g = 1 - delta, g, (g - f) // 2

# After
if delta > 0 and g & 1:
    delta, f, g = 1 - delta, g, (g + f) // 2
```

The algorithm is still correct, since the changed divstep, called a "posdivstep" (see section 8.4
and E.5 in the paper) preserves *gcd(f, g)*. However, there's no proof that the modified algorithm
will converge. The justification for posdivsteps is completely empirical: in practice, it appears
that the vast majority of nonzero inputs converge to *f=g=gcd(f<sub>0</sub>, g<sub>0</sub>)* in a
number of steps proportional to their logarithm.

Note that:
- We require inputs to satisfy *gcd(x, M) = 1*, as otherwise *f=1* is not reached.
- We require inputs *x ≠ 0*, because applying posdivstep with *g=0* has no effect.
- We need to update the termination condition from *g=0* to *f=1*.
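
One possible way to combine the posdivsteps with the Jacobi updates listed above is sketched below. This is an illustrative sketch only, not the C implementation (which applies the batching and variable-time techniques from earlier sections); the function name and the iteration bound `max_iter` are ours. Each posdivstep applies the reciprocity negation on a swap and the *f mod 8* negation on a halving:

```python
def jacobi_posdivsteps(x, M, max_iter=10000):
    """Compute the Jacobi symbol (x | M) via posdivsteps, for odd M > 0 and
    0 < x < M with gcd(x, M) = 1. Returns None if the iteration bound is hit
    (convergence is only empirical)."""
    assert M & 1 and M > 0 and 0 < x < M
    delta, f, g, j = 1, M, x, 1  # invariant: (x | M) = j * (g | f)
    for _ in range(max_iter):
        if f == 1:
            return j  # (g | 1) = 1, so (x | M) = j.
        if delta > 0 and g & 1:
            # Swap f and g: reciprocity negates the symbol iff f = g = 3 (mod 4).
            if f & 3 == 3 and g & 3 == 3:
                j = -j
            delta, f, g = 1 - delta, g, g + f  # posdivstep: g+f instead of g-f
        elif g & 1:
            delta, g = 1 + delta, g + f
        else:
            delta = 1 + delta
        # g is even now; halving it negates the symbol iff f = 3 or 5 (mod 8).
        if f & 7 in (3, 5):
            j = -j
        g >>= 1
    return None

# For prime M the Jacobi symbol equals the Legendre symbol, which Euler's
# criterion gives as x^((M-1)/2) mod M.
M = 1009  # a prime
for x in range(1, M):
    expect = 1 if pow(x, (M - 1) // 2, M) == 1 else -1
    assert jacobi_posdivsteps(x, M) == expect
```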

We account for the possibility of nonconvergence by only performing a bounded number of
posdivsteps, and then falling back to square-root based Jacobi calculation if a solution has not
yet been found.

The optimizations in sections 3-7 above are described in the context of the original divsteps, but
in the C implementation we also adapt most of them (not including "avoiding modulus operations",
since it's not necessary to track *d, e*, and "constant-time operation", since we never calculate
Jacobi symbols for secret data) to the posdivsteps version.
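
To tie everything together, the section 7 functions can be exercised end-to-end. The harness below (a self-contained check, not part of the document's code) assumes *N=62*, uses the 256-bit secp256k1 field prime as an example modulus (the names `P` and `Pi` are ours), and relies on Python 3.8+'s three-argument `pow` to precompute *1/M mod 2<sup>N</sup>*:

```python
import random

N = 62
P = 2**256 - 2**32 - 977  # the secp256k1 field prime, an example 256-bit odd modulus
Pi = pow(P, -1, 2**N)     # 1/P mod 2^N (Python 3.8+)

def divsteps_n_matrix(zeta, f, g):
    """Constant-time-style transition matrix for N divsteps (times 2^N)."""
    u, v, q, r = 1, 0, 0, 1
    for _ in range(N):
        c1 = zeta >> 63
        x, y, z = (f ^ c1) - c1, (u ^ c1) - c1, (v ^ c1) - c1
        c2 = -(g & 1)
        g, q, r = g + (x & c2), q + (y & c2), r + (z & c2)
        c1 &= c2
        zeta = (zeta ^ c1) - 1
        f, u, v = f + (g & c1), u + (q & c1), v + (r & c1)
        g, u, v = g >> 1, u << 1, v << 1
    return zeta, (u, v, q, r)

def update_fg(f, g, t):
    u, v, q, r = t
    return (u*f + v*g) >> N, (q*f + r*g) >> N

def update_de(d, e, t, M, Mi):
    u, v, q, r = t
    d_sign, e_sign = d >> 257, e >> 257  # -1 if negative, 0 otherwise
    md, me = (u & d_sign) + (v & e_sign), (q & d_sign) + (r & e_sign)
    cd, ce = (u*d + v*e) % 2**N, (q*d + r*e) % 2**N
    md -= (Mi*cd + md) % 2**N
    me -= (Mi*ce + me) % 2**N
    return (u*d + v*e + M*md) >> N, (q*d + r*e + M*me) >> N

def normalize(sign, v, M):
    v += M & (v >> 257)
    c = (sign - 1) >> 1
    v = (v ^ c) - c
    v += M & (v >> 257)
    return v

def modinv(M, Mi, x):
    zeta, f, g, d, e = -1, M, x, 0, 1
    for _ in range((590 + N - 1) // N):
        zeta, t = divsteps_n_matrix(zeta, f % 2**N, g % 2**N)
        f, g = update_fg(f, g, t)
        d, e = update_de(d, e, t, M, Mi)
    return normalize(f, d, M)

# Check against the defining property of the modular inverse.
random.seed(0)
for _ in range(5):
    x = random.randrange(1, P)
    assert (modinv(P, Pi, x) * x) % P == 1
```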