# -*- Mode: Python -*-

def egcd (a, b):
    """Extended Euclidean algorithm.

    Return (g, x, y) with a*x + b*y == g, where g is a greatest common
    divisor of a and b (it can come out negative for negative inputs).
    """
    # Loop invariants, in terms of the original a0, b0:
    #   b == a0*xb + b0*yb   and   a == a0*xa + b0*ya
    xb, yb = 0, 1
    xa, ya = 1, 0
    while a != 0:
        q, r = divmod (b, a)
        b, a = a, r
        xb, xa = xa, xb - q * xa
        yb, ya = ya, yb - q * ya
    return b, xb, yb

class NoInverse (Exception):
    """Raised by modinv when its argument has no inverse modulo p."""

def modinv (a, p):
    """Return the multiplicative inverse of a modulo p.

    For a >= 0 (and p > 0) the result is x % p, i.e. in [0, p).  A negative
    a is handled as p - modinv(-a, p).  Raises NoInverse when the gcd found
    by egcd is not 1; for negative a the exception carries -a, because the
    check runs on the negated value.
    """
    if a >= 0:
        g, x, _ = egcd (a, p)
        if g == 1:
            return x % p
        raise NoInverse (a)
    return p - modinv (-a, p)

def mod (n, m):
    """Return n reduced modulo m as a non-negative residue in [0, |m|).

    Python's % is already non-negative for m > 0.  It can only be negative
    when m < 0, and then shifting by |m| (subtracting m) lands in range.
    Raises ZeroDivisionError for m == 0.
    """
    r = n % m
    if r < 0:
        # Only reachable for m < 0, where r lies in (m, 0).  The old code
        # returned n - r here, which is not a residue (mod(5, -3) gave 6).
        return r - m
    return r

class Monty:
    """Montgomery-arithmetic context for modulus N and Montgomery radix R.

    Precomputes the constants used by REDC and prints them as diagnostics.
    Calling an instance with an int yields a MontyInt in Montgomery form.
    modinv raises NoInverse if R has no inverse modulo N.
    """

    def __init__ (self, N, R):
        self.N = N
        self.R = R
        # R' = R^-1 mod N
        self.R1 = modinv (R, N)
        # N' chosen so that R*R' - N*N' == 1 (for N > 1), hence
        # N*N' == -1 (mod R), the multiplier REDC needs
        self.N1 = (self.R1 * R) // N
        # R^2 and R^3 reduced mod N; tm() uses R^2 to enter Montgomery form
        self.R2N = (R * R) % N
        self.R3N = (R * self.R2N) % N
        diagnostics = (
            "N = %r R = %r" % (N, R),
            "N' = %r R' = %r" % (self.N1, self.R1),
            "RR' - NN' = %r" % (self.R * self.R1 - self.N * self.N1),
            "R^2N = %r" % (self.R2N,),
            "R^3N = %r" % (self.R3N,),
        )
        for line in diagnostics:
            print (line)

    def redc (self, T):
        """Montgomery reduction: return T * R^-1 mod N.

        The single conditional subtraction guarantees a result in [0, N)
        when 0 <= T < N*R.
        """
        N, R = self.N, self.R
        m = mod (mod (T, R) * self.N1, R)
        # T + m*N is an exact multiple of R because N*N1 == -1 (mod R)
        t = (T + m * N) // R
        return t - N if t >= N else t

    def tm (self, a):
        """Convert integer a to Montgomery form, a*R mod N.

        Computed as redc((a mod N) * (R^2 mod N)); equivalent to
        mod(a * self.R, self.N).
        """
        return self.redc (mod (a, self.N) * self.R2N)

    def fm (self, a):
        """Convert a Montgomery-form value back to an ordinary residue."""
        return self.redc (a)

    def __call__ (self, n):
        """Wrap n as a MontyInt of this context; call *only* with an int."""
        return MontyInt (self, self.tm (n))

class MontyInt:
    """A residue modulo M.N held in Montgomery form within Monty context M.

    `v` is the Montgomery representative (a*R mod N for the integer a it
    stands for, as produced by Monty.tm).  The arithmetic below keeps `v` in
    [0, N), which is what makes `==` (a direct comparison of representatives)
    meaningful.  Operands from a different Monty context raise ValueError.
    """

    # do not call this unless you understand what you are doing.
    # create MontyInts by calling a Monty instance.
    def __init__ (self, M, v):
        self.M = M  # owning Monty context (provides N, R, redc, fm, __call__)
        self.v = v  # raw Montgomery representative, expected in [0, M.N)

    def __repr__ (self):
        # modulus, radix, raw representative, and the ordinary value it encodes
        return f'<M({self.M.N},{self.M.R}) {self.v} = {self.M.fm(self.v)}>'

    def __check__ (self, other):
        """Raise ValueError unless `other` is a MontyInt of the same context."""
        if not (isinstance (other, MontyInt) and other.M is self.M):
            raise ValueError (other)

    def __add__ (self, other):
        """Modular addition; the representative stays in [0, N)."""
        self.__check__ (other)
        r = self.v + other.v
        # Montgomery form is linear, so adding representatives adds values.
        # The raw sum can reach 2N - 2; one conditional subtraction restores
        # the [0, N) range that __sub__, __eq__ and redc's bound rely on.
        if r >= self.M.N:
            r = r - self.M.N
        return MontyInt (self.M, r)

    def __sub__ (self, other):
        """Modular subtraction; the representative stays in [0, N)."""
        self.__check__ (other)
        r = self.v - other.v
        if r < 0:
            r = r + self.M.N
        return MontyInt (self.M, r)

    def __mul__ (self, other):
        """Montgomery multiplication: redc(aR * bR) == abR (mod N)."""
        self.__check__ (other)
        return MontyInt (self.M, self.M.redc (self.v * other.v))

    def __eq__ (self, other):
        # Compares representatives.  Raises ValueError (rather than returning
        # False) for anything that is not a MontyInt of the same context.
        self.__check__ (other)
        return self.v == other.v

    def redc (self):
        """Return a MontyInt whose representative is M.redc(self.v)."""
        return MontyInt (self.M, self.M.redc (self.v))

    def inv (self):
        """Multiplicative inverse, in Montgomery form.

        With v = aR: one redc gives a, a second gives a*R^-1, and inverting
        that modulo N gives a^-1 * R, the Montgomery form of a^-1.  Raises
        NoInverse (from modinv) when a is not invertible modulo N.
        """
        return MontyInt (self.M, modinv (self.redc().redc().v, self.M.N))

    def __truediv__ (self, other):
        """Return self * other^-1 (NoInverse if other is not invertible)."""
        self.__check__ (other)
        return self * other.inv()

    def pow (self, n):
        """Raise to the int power n by square-and-multiply.

        A negative n raises the inverse to -n.  Previously the loop never ran
        for negative n, so 1 was returned silently.
        """
        assert isinstance (n, int)
        if n < 0:
            return self.inv().pow (-n)
        z = self.M(1)
        x = self
        while n > 0:
            if n & 1:
                z = z * x
            x = x * x
            n >>= 1
        return z

    # Python 2 name for "/".  The old body computed redc(other.v), discarded
    # it and returned None; make it a real alias of true division instead.
    __div__ = __truediv__



# --------------------------------------------------------------------------------
# https://stackoverflow.com/questions/32871539/integer-factorization-in-python
from math import gcd

def factorization(n):
    """Return the prime factorization of n as an ascending list (with multiplicity).

    Returns [] for n < 2.  Primes up to 37 are removed by trial division, and
    Miller-Rabin decides primality.  Pollard's rho, with the doubling cycle
    length of the Stack Overflow answer this came from, splits the remaining
    composites.  A rho run can fail by returning n itself (with x*x + 1,
    n = 25 does).  Such runs are retried with a different additive constant,
    so a composite is never reported as a single "factor".
    """
    factors = []
    small_primes = (2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37)

    def is_prime(m):
        # Miller-Rabin with the first 12 primes as bases: deterministic for
        # m < 3.18e23 (covers all 64-bit values), probabilistic above that.
        if m < 2:
            return False
        for p in small_primes:
            if m % p == 0:
                return m == p
        d, s = m - 1, 0
        while d % 2 == 0:
            d //= 2
            s += 1
        for a in small_primes:
            x = pow(a, d, m)
            if x == 1 or x == m - 1:
                continue
            for _ in range(s - 1):
                x = x * x % m
                if x == m - 1:
                    break
            else:
                return False
        return True

    def get_factor(m, c):
        # Pollard's rho on f(x) = x*x + c.  Returns a divisor > 1 of m,
        # which is m itself when this run fails.
        x_fixed = x = 2
        cycle_size = 2
        factor = 1
        while factor == 1:
            for _ in range(cycle_size):
                x = (x * x + c) % m
                factor = gcd(x - x_fixed, m)
                if factor > 1:
                    break
            cycle_size *= 2
            x_fixed = x
        return factor

    def split(m):
        # Append the prime factors of m (m > 1) to `factors`.
        if is_prime(m):
            factors.append(m)
            return
        c = 1
        d = get_factor(m, c)
        while d == m:  # rho failed; try another polynomial
            c += 1
            d = get_factor(m, c)
        split(d)
        split(m // d)

    if n < 2:
        return factors

    for p in small_primes:
        while n % p == 0:
            factors.append(p)
            n //= p
    if n > 1:
        split(n)

    factors.sort()
    return factors
# --------------------------------------------------------------------------------


def near_2_20():
    """Return the ten largest odd numbers below 1 << 20 (1048576) that
    factorization() reports as prime, largest first.

    Walks down the odd numbers from (1 << 20) - 1 and keeps each one whose
    factorization has exactly one entry.
    """
    primes = []
    for candidate in range ((1 << 20) - 1, 1, -2):
        factors = factorization (candidate)
        if len (factors) == 1:
            primes.append (factors[0])
            if len (primes) == 10:
                break
    return primes


# Small demo context (modulus 59, radix 64) and three sample residues in it.
M = Monty (59, 64)
a, b, c = M(3), M(5), M(12)


# P = Monty (1021, 1024)
# N = Monty (1009, 1021)

# P = Monty (1021, 1024)
# N = Monty (997, 1021)

# Earlier (outer P, inner N) configuration.  Both names are rebound a few
# lines below, so these constructions only contribute their printed
# diagnostics.
P = Monty (646273, 1048576)
N = Monty (360551, 646273)

# Active configuration: p0 = 1 << 20 is the radix of the outer context.
# The outer P works modulo p1 with radix p0; the inner N works modulo p2
# with radix p1.
p0 = 1048576
p1 = 1048573
#p2 = 1048571
p2 = 1048559

P = Monty (p1, p0)
N = Monty (p2, p1)

# P = Monty (1890230747, 2147483648)
# N = Monty (1033497221, 1890230747)

# p0 = 147573952589676412928 # 1<<67
# p1 = 92651585035625003099
# p2 = 81324886740482816237
# P = Monty (p1, p0)
# N = Monty (p2, p1)


# secp256r1
# 1<<256
# p0 = 115792089237316195423570985008687907853269984665640564039457584007913129639936
# p1 = 0xffffffff00000001000000000000000000000000ffffffffffffffffffffffff
# p2 = 0xffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551
# P = Monty (p1, p0)
# N = Monty (p2, p1)

import random
from pprint import pprint as pp

def find_off_by_ones(trials=100_000, outer=None, inner=None):
    """Sample random pairs where nested Montgomery addition disagrees with
    ordinary modular addition.

    x and y are drawn from [1, inner.N] and put into the inner context's
    Montgomery form.  Those representatives are wrapped in the outer
    context, added there, and brought back through outer.fm then inner.fm.
    Every pair whose result differs from (x + y) % inner.N is recorded as
    (x, y, delta), with delta = ((x + y) % inner.N) - recovered value.

    trials: number of random pairs (default 100_000, the previous fixed count).
    outer, inner: Monty contexts; default to the module-level P and N.
    """
    if outer is None:
        outer = P
    if inner is None:
        inner = N
    mismatches = []
    for _ in range (trials):
        x = random.randint (1, inner.N)
        y = random.randint (1, inner.N)
        xN = inner(x)
        yN = inner(y)
        # lift each inner Montgomery representative into the outer context
        xNP = outer(xN.v)
        yNP = outer(yN.v)
        zNP = xNP + yNP
        z0 = inner.fm (outer.fm (zNP.v))
        delta = ((x + y) % inner.N) - z0
        if delta:
            mismatches.append ((x, y, delta))
    return mismatches

def draw_off_by_ones (p0, p1, r, trials=1_000_000, size=1024):
    """Plot where nested Montgomery addition disagrees with modular addition.

    Builds an outer context P = Monty(p1, r) and an inner context
    N = Monty(p0, p1).  It samples `trials` random pairs (x, y) from
    [1, p0] and paints pixel (x*size//r, y*size//r) black whenever the round
    trip N -> P -> add -> P.fm -> N.fm differs from (x + y) % p0.  The
    bitmap is saved to /tmp/off_{p0}_{p1}.pbm.

    Requires Pillow.  Points fall inside the image only when p0 < r.
    several() calls this with p0 < p1 < r.
    trials, size: previously fixed at 1_000_000 and 1024.
    """
    from PIL import Image, ImageDraw

    P = Monty (p1, r)   # outer context: modulus p1, radix r
    N = Monty (p0, p1)  # inner context: modulus p0, radix p1

    def txfm (p):
        # scale a value in [0, r) to a pixel coordinate in [0, size)
        return p * size // P.R

    img = Image.new ('1', (size, size), 'white')
    drw = ImageDraw.Draw (img)
    for _ in range (trials):
        x = random.randint (1, N.N)
        y = random.randint (1, N.N)
        xN = N(x)
        yN = N(y)
        # lift each inner Montgomery representative into the outer context
        xNP = P(xN.v)
        yNP = P(yN.v)
        zNP = xNP + yNP
        z0 = N.fm (P.fm (zNP.v))
        delta = ((x + y) % N.N) - z0
        if delta:
            drw.point ((txfm(x), txfm(y)), 'black')
    img.save (f'/tmp/off_{p0}_{p1}.pbm')

from itertools import combinations

def several():
    """Draw an off-by-one map for every pair from a fixed list of primes.

    All the primes are below 1 << 20, which is used as the outer radix.  The
    list is in descending order, so combinations() yields (larger, smaller).
    The smaller prime becomes p0 (the inner modulus) and the larger becomes
    p1.
    """
    radix = 1 << 20
    candidates = [1048573, 1048571, 1048559, 1048549, 1048517, 1048507, 1048447, 1048433, 1048423, 1048391]
    for larger, smaller in combinations (candidates, 2):
        draw_off_by_ones (smaller, larger, radix)
