algo-numalgo-01
0.538
Challenge · difficulty 5/5
# Generalized discrete logarithm
Implement a file **`solution.py`** containing a function `discrete_log` that
solves the **discrete logarithm** problem for an **arbitrary** modulus:
```python
def discrete_log(a: int, b: int, m: int) -> int:
    """Return the smallest non-negative integer x such that
    a**x is congruent to b (mod m)
    or -1 if no such x exists."""
```
## Inputs
- `a`, `b` are integers with `a >= 0` and `b >= 0`. They may be larger than `m`;
reduce them modulo `m` before doing anything.
- `m` is the modulus with `m >= 1`.
## What you must return
The **smallest** integer `x >= 0` with `a**x ≡ b (mod m)`, using the convention
`a**0 == 1`. If no such `x` exists, return `-1`.
Because you must return the *smallest* solution, ties are impossible: the answer
is unique.
## The catch: `m` need not be prime, and `a` need not be coprime to `m`
This is the whole difficulty of the problem.
- When `gcd(a, m) == 1`, the powers `a**0, a**1, a**2, ...` cycle through a group
and the classic **baby-step giant-step** meet-in-the-middle idea applies.
- When `gcd(a, m) > 1`, the sequence of powers is **not** a clean cycle — it has a
"tail" (a pre-period) before it becomes periodic, and `a` is never invertible
modulo `m`. A solution may lie in the tail (small `x`) **or** deep inside the
periodic part (large `x`), and it may not exist at all. You must handle every
combination correctly, still returning the *smallest* `x`.
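
To make the tail-plus-cycle picture concrete, here is a minimal brute-force sketch for small moduli only. The helper name `tail_and_period` is hypothetical and not part of the task; it records where each residue first appears and reports the pre-period length and the period.

```python
def tail_and_period(a: int, m: int) -> tuple[int, int]:
    """Brute-force (tail length, period) of 1, a, a**2, ... modulo m."""
    first_seen = {}          # residue -> first exponent at which it appears
    cur = 1 % m
    x = 0
    while cur not in first_seen:
        first_seen[cur] = x
        cur = cur * a % m
        x += 1
    tail = first_seen[cur]   # exponent where the periodic part starts
    return tail, x - tail


print(tail_and_period(2, 10))    # (1, 4): 1 | 2, 4, 8, 6 repeating
print(tail_and_period(10, 100))  # (2, 1): 1, 10 | 0 repeating
print(tail_and_period(3, 17))    # (0, 16): coprime base, pure cycle
```

With a coprime base the tail is empty, which is why plain baby-step giant-step is enough in that case and not in the others.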
## Efficiency
A naive `O(m)` scan over exponents will **not** pass: some tests use a modulus of
size roughly `10**12` and cases with no solution, where scanning the whole cycle
would exceed the time limit. You are expected to use a sub-linear
(≈ `O(sqrt(m))`) approach such as an extended baby-step giant-step that first
factors out `gcd(a, m)`.
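
One possible shape of such an approach is sketched below. It is an illustration under stated assumptions, not the reference solution: the name `discrete_log_sketch` and the variables `k`, `add`, `baby`, and `step` are hypothetical. The first phase divides out `gcd(a, m)` one exponent at a time while checking the small exponents directly; the second phase runs baby-step giant-step on the reduced modulus, where `a` is invertible.

```python
import math


def discrete_log_sketch(a: int, b: int, m: int) -> int:
    """Smallest x >= 0 with a**x ≡ b (mod m), or -1 (illustrative sketch)."""
    if m == 1:
        return 0
    a %= m
    b %= m

    # Phase 1: strip common factors one exponent at a time.
    # k is a**add divided by the gcds removed so far, reduced mod the current m,
    # so a**add ≡ b (original) holds exactly when k == b (current).
    k, add = 1, 0
    while True:
        if k == b:
            return add              # hit at x = 0 or inside the tail
        g = math.gcd(a, m)
        if g == 1:
            break                   # a is invertible modulo the reduced m
        if b % g:
            return -1               # larger exponents need g to divide b
        b //= g
        m //= g
        k = k * (a // g) % m
        add += 1

    # Phase 2: smallest y >= 1 with k * a**y ≡ b (mod m), written y = n*p - q.
    n = math.isqrt(m) + 1
    baby = {}
    cur = b
    for q in range(n):
        baby[cur] = q               # later q overwrites: keep the largest q
        cur = cur * a % m
    step = pow(a, n, m)
    cur = k
    for p in range(1, n + 1):
        cur = cur * step % m
        if cur in baby:
            return n * p - baby[cur] + add
    return -1
```

Scanning `p` upward and keeping the largest `q` per residue is what makes the returned exponent the smallest one in this sketch: each `p` covers the disjoint block of exponents `n*(p-1)+1 .. n*p`, and a larger `q` means a smaller `n*p - q` inside that block.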
## Examples
```python
assert discrete_log(2, 8, 10) == 3 # 2**3 = 8
assert discrete_log(2, 1, 10) == 0 # 2**0 = 1
assert discrete_log(2, 6, 10) == 4 # powers of 2 mod 10: 1,2,4,8,6,... -> 6 at x=4
assert discrete_log(2, 3, 10) == -1 # 3 is never a power of 2 mod 10
assert discrete_log(3, 13, 17) == 4 # 3**4 = 81 = 13 (mod 17)
assert discrete_log(2, 0, 1024) == 10 # 2**10 = 1024 = 0 (mod 1024)
assert discrete_log(0, 0, 7) == 1 # 0**0 = 1, 0**1 = 0
assert discrete_log(0, 1, 7) == 0 # 0**0 = 1
assert discrete_log(5, 3, 5) == -1 # 5 ≡ 0 (mod 5): 0**x is 1 or 0, never 3
```
## Edge cases to respect
- `m == 1`: every integer is `≡ 0 (mod 1)`, so the answer is always `0`.
- `b ≡ 1 (mod m)`: the answer is `0` (since `a**0 == 1`), for any `a`.
- `a ≡ 0 (mod m)`: `0**0 == 1`, and `0**k == 0` for `k >= 1`.
- Inputs `a`, `b` may exceed `m` and must be reduced modulo `m` first.
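
The order of these checks matters, so a small sketch may help. The helper name `edge_case_answer` is hypothetical; it returns the answer when one of the edge cases above applies and `None` when the general algorithm is needed. Note that `m == 1` has to be tested before `b == 1`, because after reduction `1 % 1 == 0`.

```python
from typing import Optional


def edge_case_answer(a: int, b: int, m: int) -> Optional[int]:
    """Answer for the degenerate inputs, or None if none of them applies."""
    if m == 1:
        return 0                 # everything is congruent modulo 1
    a %= m
    b %= m
    if b == 1:
        return 0                 # a**0 == 1 for every base, including 0
    if a == 0:
        return 1 if b == 0 else -1   # powers of 0 are 1, 0, 0, ...
    return None                  # fall through to the general algorithm


print(edge_case_answer(3, 4, 1))     # 0
print(edge_case_answer(14, 3, 7))    # -1, since 14 ≡ 0 (mod 7)
print(edge_case_answer(12, 18, 10))  # None
```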
You may use `math.gcd`, `math.isqrt`, and Python's built-in modular inverse
`pow(x, -1, m)`. Do not import any discrete-log or number-theory library that
solves the problem for you.

tests/test_discrete_log.py
import random
import pytest

from solution import discrete_log


def brute(a, b, m, cap):
    """Smallest x in [0, cap) with a**x == b (mod m), else -1."""
    a %= m
    b %= m
    cur = 1 % m
    for x in range(cap):
        if cur == b:
            return x
        cur = (cur * a) % m
    return -1
# ---------------------------------------------------------------- basic anchors
def test_zero_exponent_when_b_is_one():
    assert discrete_log(2, 1, 10) == 0
    assert discrete_log(3, 1, 100) == 0
    # any base with b == 1 mod m -> x = 0
    assert discrete_log(5, 1, 7) == 0


def test_simple_coprime():
    assert discrete_log(2, 8, 10) == 3  # 2^3 = 8
    assert discrete_log(3, 13, 17) == 4  # 3^4 = 81 = 13 (mod 17)
    assert discrete_log(5, 3, 23) == 16


def test_no_solution_small():
    assert discrete_log(2, 3, 10) == -1  # {1,2,4,8,6} never hits 3
    assert discrete_log(4, 7, 13) == -1  # subgroup {1,4,3,12,9,10}
    assert discrete_log(6, 8, 10) == -1
# ------------------------------------------------ non-coprime base and modulus
def test_non_coprime_reachable_in_cycle():
    # powers of 2 mod 10: 1,2,4,8,6,2,4,8,6,... -> first 6 at x = 4
    assert discrete_log(2, 6, 10) == 4
    # powers of 10 mod 100: 1,10,0,0,...
    assert discrete_log(10, 10, 100) == 1
    assert discrete_log(10, 0, 100) == 2
    assert discrete_log(10, 50, 100) == -1


def test_reaches_zero_tail():
    assert discrete_log(2, 0, 1024) == 10  # 2^10 = 1024 = 0 (mod 1024)
    assert discrete_log(6, 0, 8) == 3  # 6^1=6, 6^2=36=4, 6^3=216=0
# ----------------------------------------------------------- degenerate inputs
def test_modulus_one():
    # everything is congruent to 0 mod 1, so x = 0 always
    for a in range(0, 5):
        for b in range(0, 5):
            assert discrete_log(a, b, 1) == 0


def test_base_zero():
    # 0^0 = 1, 0^k = 0 for k >= 1
    assert discrete_log(0, 1, 7) == 0
    assert discrete_log(0, 0, 7) == 1
    assert discrete_log(0, 3, 7) == -1


def test_inputs_reduced_mod_m():
    # a, b larger than m must be reduced first
    assert discrete_log(12, 18, 10) == discrete_log(2, 8, 10) == 3
    assert discrete_log(2 + 10 ** 6, 8, 10) == 3
# --------------------------------------------- exhaustive minimality guarantee
def test_exhaustive_matches_brute_force():
    fails = []
    for m in range(1, 70):
        cap = m + 5  # covers the whole pre-period + period
        for a in range(m):
            for b in range(m):
                got = discrete_log(a, b, m)
                exp = brute(a, b, m, cap)
                if got != exp:
                    fails.append((a, b, m, got, exp))
    assert not fails, fails[:10]
# --------------------------------------------------- large solvable (property)
def test_large_solvable_validity_and_minimal_bound():
    rng = random.Random(12345)
    for _ in range(300):
        m = rng.randint(2, 10 ** 9)
        a = rng.randint(0, m - 1)
        x = rng.randint(0, 2 * 10 ** 6)
        b = pow(a, x, m)
        r = discrete_log(a, b, m)
        assert r != -1
        assert pow(a, r, m) == b
        assert 0 <= r <= x  # returned x must be the *smallest*


def test_large_forces_subquadratic_prime():
    # ~1e12 prime; an O(m) scan would blow the time limit, an O(sqrt m) one won't
    m = 999999999989
    a = 7
    x = 654321
    b = pow(a, x, m)
    r = discrete_log(a, b, m)
    assert pow(a, r, m) == b
    assert 0 <= r <= x


def test_large_no_solution_full_cycle():
    # p prime; a is a quadratic residue (its subgroup = the QRs), b a non-residue,
    # so b is unreachable. A brute-force scan would traverse the whole (huge)
    # cycle before giving up -> must be handled by the fast algorithm.
    p = 1000000007
    a = pow(3, 2, p)  # a quadratic residue
    b = 5  # verified quadratic non-residue below
    assert pow(b, (p - 1) // 2, p) == p - 1
    assert discrete_log(a, b, p) == -1


def test_large_no_solution_composite():
    # under an even modulus, powers of 4 (or 2) can never equal an odd target
    m = 2 ** 20
    assert discrete_log(4, 3, m) == -1  # powers of 4 are 1 or even
    assert discrete_log(2, 3, m) == -1  # powers of 2: 1,2,4,...,0 never 3
System prompt
You are an expert programmer. Solve the task exactly as specified. Output your solution as fenced code blocks using the required file name(s) and the exact function/type signatures requested. Prefer correctness; do not include prose outside code unless asked.
Proposed solution
```python
import math

def discrete_log(a: int, b: int, m: int) -> int:
    """Return the smallest non-negative integer x such that a**x ≡ b (mod m), or -1 if no such x exists."""
    if m == 1:
        return 0
    a = a % m
    b = b % m
    if b == 1:
        return 0
    if a == 0:
        return 1 if b == 0 else -1

    g = math.gcd(a, m)
    if b % g != 0:
        return -1

    # Factor out g from m
    m_prime = m // g
    a_prime = a // g
    b_prime = b // g

    # Now solve a_prime^x ≡ b_prime (mod m_prime) with gcd(a_prime, m_prime) = 1
    # Using baby-step giant-step
    n = int(math.isqrt(m_prime)) + 1
    table = {}
    current = 1
    for j in range(n):
        if current not in table:
            table[current] = j
        current = (current * a_prime) % m_prime

    a_prime_inv = pow(a_prime, -1, m_prime)
    giant_step = pow(a_prime, n * (n - 1) // 2, m_prime)
    current = b_prime
    for i in range(n):
        if current in table:
            j = table[current]
            x = i * n + j
            # Verify the solution in the original modulus
            if pow(a, x, m) == b:
                return x
        current = (current * a_prime_inv) % m_prime

    return -1
```

Test output (stdout)
.F.FF..FFF... [100%]
=================================== FAILURES ===================================
_____________________________ test_simple_coprime ______________________________
def test_simple_coprime():
> assert discrete_log(2, 8, 10) == 3 # 2^3 = 8
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E assert -1 == 3
E + where -1 = discrete_log(2, 8, 10)
test_discrete_log.py:29: AssertionError
_____________________ test_non_coprime_reachable_in_cycle ______________________
def test_non_coprime_reachable_in_cycle():
# powers of 2 mod 10: 1,2,4,8,6,2,4,8,6,... -> first 6 at x = 4
> assert discrete_log(2, 6, 10) == 4
E assert -1 == 4
E + where -1 = discrete_log(2, 6, 10)
test_discrete_log.py:43: AssertionError
____________________________ test_reaches_zero_tail ____________________________
def test_reaches_zero_tail():
> assert discrete_log(2, 0, 1024) == 10 # 2^10 = 1024 = 0 (mod 1024)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E assert -1 == 10
E + where -1 = discrete_log(2, 0, 1024)
test_discrete_log.py:51: AssertionError
__________________________ test_inputs_reduced_mod_m ___________________________
def test_inputs_reduced_mod_m():
# a, b larger than m must be reduced first
> assert discrete_log(12, 18, 10) == discrete_log(2, 8, 10) == 3
E assert -1 == 3
E + where -1 = discrete_log(2, 8, 10)
test_discrete_log.py:72: AssertionError
_____________________ test_exhaustive_matches_brute_force ______________________
def test_exhaustive_matches_brute_force():
fails = []
for m in range(1, 70):
cap = m + 5 # covers the whole pre-period + period
for a in range(m):
for b in range(m):
got = discrete_log(a, b, m)
exp = brute(a, b, m, cap)
if got != exp:
fails.append((a, b, m, got, exp))
> assert not fails, fails[:10]
E AssertionError: [(2, 0, 4, -1, 2), (2, 2, 4, -1, 1), (2, 3, 5, 7, 3), (3, 2, 5, 7, 3), (2, 2, 6, -1, 1), (2, 4, 6, -1, 2), ...]
E assert not [(2, 0, 4, -1, 2), (2, 2, 4, -1, 1), (2, 3, 5, 7, 3), (3, 2, 5, 7, 3), (2, 2, 6, -1, 1), (2, 4, 6, -1, 2), ...]
test_discrete_log.py:87: AssertionError
________________ test_large_solvable_validity_and_minimal_bound ________________
def test_large_solvable_validity_and_minimal_bound():
rng = random.Random(12345)
for _ in range(300):
m = rng.randint(2, 10 ** 9)
a = rng.randint(0, m - 1)
x = rng.randint(0, 2 * 10 ** 6)
b = pow(a, x, m)
r = discrete_log(a, b, m)
> assert r != -1
E assert -1 != -1
test_discrete_log.py:99: AssertionError
=========================== short test summary info ============================
FAILED test_discrete_log.py::test_simple_coprime - assert -1 == 3
FAILED test_discrete_log.py::test_non_coprime_reachable_in_cycle - assert -1 ...
FAILED test_discrete_log.py::test_reaches_zero_tail - assert -1 == 10
FAILED test_discrete_log.py::test_inputs_reduced_mod_m - assert -1 == 3
FAILED test_discrete_log.py::test_exhaustive_matches_brute_force - AssertionE...
FAILED test_discrete_log.py::test_large_solvable_validity_and_minimal_bound
6 failed, 7 passed in 0.32s
Self-repair attempts
[
{
"total": 13,
"answer": "```python\nimport math\n\ndef discrete_log(a: int, b: int, m: int) -> int:\n \"\"\"Return the smallest non-negative integer x such that a**x ≡ b (mod m), or -1 if no such x exists.\"\"\"\n if m == 1:\n return 0\n a = a % m\n b = b % m\n if b == 1:\n return 0\n if a == 0:\n return 1 if b == 0 else -1\n\n g = math.gcd(a, m)\n if b % g != 0:\n return -1\n\n # Factor out g from m\n m_prime = m // g\n a_prime = a // g\n b_prime = b // g\n\n # Now solve a_prime^x ≡ b_prime (mod m_prime) with gcd(a_prime, m_prime) = 1\n # Using baby-step giant-step\n n = int(math.isqrt(m_prime)) + 1\n table = {}\n current = 1\n for j in range(n):\n if current not in table:\n table[current] = j\n current = (current * a_prime) % m_prime\n\n a_prime_inv = pow(a_prime, -1, m_prime)\n giant_step = pow(a_prime, n * (n - 1) // 2, m_prime)\n current = b_prime\n for i in range(n):\n if current in table:\n j = table[current]\n x = i * n + j\n # Verify the solution in the original modulus\n if pow(a, x, m) == b:\n return x\n current = (current * a_prime_inv) % m_prime\n\n return -1\n```",
"passed": 7,
"test_error": " test_reaches_zero_tail():\n> assert discrete_log(2, 0, 1024) == 10 # 2^10 = 1024 = 0 (mod 1024)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\nE assert -1 == 10\nE + where -1 = discrete_log(2, 0, 1024)\n\ntest_discrete_log.py:51: AssertionError\n__________________________ test_inputs_reduced_mod_m ___________________________\n\n def test_inputs_reduced_mod_m():\n # a, b larger than m must be reduced first\n> assert discrete_log(12, 18, 10) == discrete_log(2, 8, 10) == 3\nE assert -1 == 3\nE + where -1 = discrete_log(2, 8, 10)\n\ntest_discrete_log.py:72: AssertionError\n_____________________ test_exhaustive_matches_brute_force ______________________\n\n def test_exhaustive_matches_brute_force():\n fails = []\n for m in range(1, 70):\n cap = m + 5 # covers the whole pre-period + period\n for a in range(m):\n for b in range(m):\n got = discrete_log(a, b, m)\n exp = brute(a, b, m, cap)\n if got != exp:\n fails.append((a, b, m, got, exp))\n> assert not fails, fails[:10]\nE AssertionError: [(2, 0, 4, -1, 2), (2, 2, 4, -1, 1), (2, 3, 5, 7, 3), (3, 2, 5, 7, 3), (2, 2, 6, -1, 1), (2, 4, 6, -1, 2), ...]\nE assert not [(2, 0, 4, -1, 2), (2, 2, 4, -1, 1), (2, 3, 5, 7, 3), (3, 2, 5, 7, 3), (2, 2, 6, -1, 1), (2, 4, 6, -1, 2), ...]\n\ntest_discrete_log.py:87: AssertionError\n________________ test_large_solvable_validity_and_minimal_bound ________________\n\n def test_large_solvable_validity_and_minimal_bound():\n rng = random.Random(12345)\n for _ in range(300):\n m = rng.randint(2, 10 ** 9)\n a = rng.randint(0, m - 1)\n x = rng.randint(0, 2 * 10 ** 6)\n b = pow(a, x, m)\n r = discrete_log(a, b, m)\n> assert r != -1\nE assert -1 != -1\n\ntest_discrete_log.py:99: AssertionError\n=========================== short test summary info ============================\nFAILED test_discrete_log.py::test_simple_coprime - assert -1 == 3\nFAILED test_discrete_log.py::test_non_coprime_reachable_in_cycle - assert -1 ...\nFAILED test_discrete_log.py::test_reaches_zero_tail - assert -1 == 
10\nFAILED test_discrete_log.py::test_inputs_reduced_mod_m - assert -1 == 3\nFAILED test_discrete_log.py::test_exhaustive_matches_brute_force - AssertionE...\nFAILED test_discrete_log.py::test_large_solvable_validity_and_minimal_bound\n6 failed, 7 passed in 0.32s"
},
{
"total": 13,
"answer": "```python\nimport math\n\ndef discrete_log(a: int, b: int, m: int) -> int:\n \"\"\"Return the smallest non-negative integer x such that a**x ≡ b (mod m), or -1 if no such x exists.\"\"\"\n if m == 1:\n return 0\n a = a % m\n b = b % m\n if b == 1:\n return 0\n if a == 0:\n return 1 if b == 0 else -1\n\n g = math.gcd(a, m)\n if b % g != 0:\n return -1\n\n # Factor out g from m\n m_prime = m // g\n a_prime = a // g\n b_prime = b // g\n\n # Now solve a_prime^x ≡ b_prime (mod m_prime) with gcd(a_prime, m_prime) = 1\n # Using baby-step giant-step\n n = int(math.isqrt(m_prime)) + 1\n table = {}\n current = 1\n for j in range(n):\n if current not in table:\n table[current] = j\n current = (current * a_prime) % m_prime\n\n a_prime_inv = pow(a_prime, -1, m_prime)\n current = b_prime\n for i in range(n):\n if current in table:\n j = table[current]\n x = i * n + j\n # Verify the solution in the original modulus\n if pow(a, x, m) == b:\n return x\n current = (current * a_prime_inv) % m_prime\n\n # If no solution found in the periodic part, check the tail\n # The tail length is the smallest k such that a^k ≡ 0 (mod m)\n # We need to find the smallest x where a^x ≡ b (mod m)\n # Since a and m are not coprime, we can have a^x ≡ 0 (mod m) for x >= k\n # So we need to check x from 0 to k-1\n # k is the smallest integer such that a^k ≡ 0 (mod m)\n # This is equivalent to the smallest k where a^k is divisible by m\n # Since a and m share a common factor g, we can write a = g * a', m = g * m'\n # Then a^k = g^k * a'^k, and we need g^k * a'^k to be divisible by g * m'\n # So k must be at least 1, and a'^k must be divisible by m'\n # Since a' and m' are coprime, a'^k is never divisible by m' unless m' = 1\n # So if m' > 1, then a^x is never 0 mod m for any x\n # But if m' = 1, then m = g, and a = g * a' = m * a'\n # Then a^x = m^x * a'^x, which is 0 mod m for x >= 1\n # So the tail length is 1 if m' = 1, otherwise there is no tail (a^x is never 0 mod m)\n if m_prime == 1:\n 
if b == 0:\n return 1\n else:\n return -1\n else:\n return -1\n```",
"passed": 7,
"test_error": " test_reaches_zero_tail():\n> assert discrete_log(2, 0, 1024) == 10 # 2^10 = 1024 = 0 (mod 1024)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\nE assert -1 == 10\nE + where -1 = discrete_log(2, 0, 1024)\n\ntest_discrete_log.py:51: AssertionError\n__________________________ test_inputs_reduced_mod_m ___________________________\n\n def test_inputs_reduced_mod_m():\n # a, b larger than m must be reduced first\n> assert discrete_log(12, 18, 10) == discrete_log(2, 8, 10) == 3\nE assert -1 == 3\nE + where -1 = discrete_log(2, 8, 10)\n\ntest_discrete_log.py:72: AssertionError\n_____________________ test_exhaustive_matches_brute_force ______________________\n\n def test_exhaustive_matches_brute_force():\n fails = []\n for m in range(1, 70):\n cap = m + 5 # covers the whole pre-period + period\n for a in range(m):\n for b in range(m):\n got = discrete_log(a, b, m)\n exp = brute(a, b, m, cap)\n if got != exp:\n fails.append((a, b, m, got, exp))\n> assert not fails, fails[:10]\nE AssertionError: [(2, 0, 4, -1, 2), (2, 2, 4, -1, 1), (2, 3, 5, 7, 3), (3, 2, 5, 7, 3), (2, 2, 6, -1, 1), (2, 4, 6, -1, 2), ...]\nE assert not [(2, 0, 4, -1, 2), (2, 2, 4, -1, 1), (2, 3, 5, 7, 3), (3, 2, 5, 7, 3), (2, 2, 6, -1, 1), (2, 4, 6, -1, 2), ...]\n\ntest_discrete_log.py:87: AssertionError\n________________ test_large_solvable_validity_and_minimal_bound ________________\n\n def test_large_solvable_validity_and_minimal_bound():\n rng = random.Random(12345)\n for _ in range(300):\n m = rng.randint(2, 10 ** 9)\n a = rng.randint(0, m - 1)\n x = rng.randint(0, 2 * 10 ** 6)\n b = pow(a, x, m)\n r = discrete_log(a, b, m)\n> assert r != -1\nE assert -1 != -1\n\ntest_discrete_log.py:99: AssertionError\n=========================== short test summary info ============================\nFAILED test_discrete_log.py::test_simple_coprime - assert -1 == 3\nFAILED test_discrete_log.py::test_non_coprime_reachable_in_cycle - assert -1 ...\nFAILED test_discrete_log.py::test_reaches_zero_tail - assert -1 == 
10\nFAILED test_discrete_log.py::test_inputs_reduced_mod_m - assert -1 == 3\nFAILED test_discrete_log.py::test_exhaustive_matches_brute_force - AssertionE...\nFAILED test_discrete_log.py::test_large_solvable_validity_and_minimal_bound\n6 failed, 7 passed in 0.31s"
},
{
"total": 13,
"answer": "```python\nimport math\n\ndef discrete_log(a: int, b: int, m: int) -> int:\n \"\"\"Return the smallest non-negative integer x such that a**x ≡ b (mod m), or -1 if no such x exists.\"\"\"\n if m == 1:\n return 0\n a = a % m\n b = b % m\n if b == 1:\n return 0\n if a == 0:\n return 1 if b == 0 else -1\n\n g = math.gcd(a, m)\n if b % g != 0:\n return -1\n\n # Factor out g from m\n m_prime = m // g\n a_prime = a // g\n b_prime = b // g\n\n # Now solve a_prime^x ≡ b_prime (mod m_prime) with gcd(a_prime, m_prime) = 1\n # Using baby-step giant-step\n n = int(math.isqrt(m_prime)) + 1\n table = {}\n current = 1\n for j in range(n):\n if current not in table:\n table[current] = j\n current = (current * a_prime) % m_prime\n\n a_prime_inv = pow(a_prime, -1, m_prime)\n current = b_prime\n for i in range(n):\n if current in table:\n j = table[current]\n x = i * n + j\n # Verify the solution in the original modulus\n if pow(a, x, m) == b:\n return x\n current = (current * a_prime_inv) % m_prime\n\n # If no solution found in the periodic part, check the tail\n # The tail length is the smallest k such that a^k ≡ 0 (mod m)\n # We need to find the smallest x where a^x ≡ b (mod m)\n # Since a and m are not coprime, we can have a^x ≡ 0 (mod m) for x >= k\n # So we need to check x from 0 to k-1\n # k is the smallest integer such that a^k ≡ 0 (mod m)\n # This is equivalent to the smallest k where a^k is divisible by m\n # Since a and m share a common factor g, we can write a = g * a', m = g * m'\n # Then a^k = g^k * a'^k, and we need g^k * a'^k to be divisible by g * m'\n # So k must be at least 1, and a'^k must be divisible by m'\n # Since a' and m' are coprime, a'^k is never divisible by m' unless m' = 1\n # So if m' > 1, then a^x is never 0 mod m for any x\n # But if m' = 1, then m = g, and a = g * a' = m * a'\n # Then a^x = m^x * a'^x, which is 0 mod m for x >= 1\n # So the tail length is 1 if m' = 1, otherwise there is no tail (a^x is never 0 mod m)\n if m_prime == 1:\n 
if b == 0:\n return 1\n else:\n return -1\n else:\n return -1\n```",
"passed": 7,
"test_error": ""
}
]