1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
#!/usr/bin/env python3
 
# 1-2 Oblivious Transfer algorithm using files
# SI 486D Spring 2016
 
import sys
import getpass
import Crypto.Random
import Crypto.PublicKey.RSA as RSA
 
def main():
    """Run one side of the oblivious transfer from the command line.

    Exactly one command-line argument is expected. If it starts with
    'a' (case-insensitive) the program plays Alice: it prompts for two
    messages, runs Alice's first step, waits for ENTER, then runs her
    second step. If it starts with 'b' it plays Bob: it prompts for a
    choice of 0 or 1, runs Bob's first step, waits for ENTER, then
    prints the message recovered by his second step.

    Any other invocation, or a choice that does not start with '0' or
    '1', prints an error and exits with status 1.
    """
    # Slice rather than index: an empty argument ('') then yields '' and
    # falls through to the usage error instead of raising IndexError.
    role = sys.argv[1][:1].lower() if len(sys.argv) == 2 else ''

    if role == 'a':
        m0 = getpass.getpass("First message: ")
        m1 = getpass.getpass("Second message: ")
        a = Alice(m0, m1)
        a.first()
        input("Initial data generated. Press ENTER after message is chosen.")
        a.second()
        print("Sending complete.")
    elif role == 'b':
        choicestr = getpass.getpass("Which message do you want (0 or 1)? ")
        # Only the first character of the answer is significant.
        first_char = choicestr[:1]
        if first_char == '0':
            choice = 0
        elif first_char == '1':
            choice = 1
        else:
            print("ERROR: you didn't choose 0 or 1")
            sys.exit(1)
        b = Bob(choice)
        b.first()
        input("Response sent. Press ENTER after the response is received.")
        m = b.second()
        print("Your message is below.")
        print(m)
    else:
        print("ERROR: Please enter 'a' or 'b' on the command line.")
        # sys.exit is always available; the bare exit() helper is injected
        # by the site module and is meant for interactive sessions.
        sys.exit(1)
 
def str2bytes(s, n):
    """Return s encoded as UTF-8 and right-padded with NUL bytes to n bytes.

    If the UTF-8 encoding of s is longer than n bytes, an error is
    printed and the program exits with status 1.
    """
    encoded = s.encode('utf8')
    shortfall = n - len(encoded)
    if shortfall < 0:
        print("ERROR: message is too long to encode in", n, "bytes")
        exit(1)
    # ljust appends exactly `shortfall` zero bytes (none when it is 0)
    return encoded.ljust(n, b'\0')
 
def bytes2str(b):
    """Decode UTF-8 bytes to a string and drop trailing NUL padding.

    Only NUL characters at the end of the decoded text are removed;
    NULs elsewhere in the string are kept.
    """
    text = b.decode('utf8')
    return text.rstrip('\0')
 
def xor(b1, b2):
    """Return the byte-wise XOR of two bytes objects.

    The inputs are paired with zip, so the result is as long as the
    shorter of the two.
    """
    mixed = bytearray()
    for left, right in zip(b1, b2):
        mixed.append(left ^ right)
    return bytes(mixed)
 
def load(fname):
    """Read the file named fname and return its contents as bytes.

    If the file cannot be opened or read for any operating-system
    reason (missing file, permission denied, path is a directory, ...),
    an error is printed and the program exits with status 1.
    """
    try:
        with open(fname, 'rb') as fin:
            return fin.read()
    except OSError:
        # FileNotFoundError is a subclass of OSError, so the missing-file
        # case is still handled here along with the other I/O failures.
        print("Couldn't read file", fname)
        sys.exit(1)
 
def store(fname, data):
    """Write the bytes in data to the file named fname.

    The file is opened in binary write mode, so an existing file of
    that name is replaced.
    """
    sink = open(fname, 'wb')
    with sink:
        sink.write(data)
 
class Alice:
    """Sender side of the file-based 1-out-of-2 oblivious transfer.

    Alice holds two messages. Her first step writes the files "x0",
    "x1" and "Enc"; her second step reads the file "u" and writes the
    files "v0" and "v1".
    """

    def __init__(self, m0, m1, modlen=1024):
        """Store the two messages, each padded to modlen // 8 bytes.

        m0, m1 -- the message strings.
        modlen -- RSA modulus length in bits.
        """
        self._modlen = modlen
        bytelen = self._modlen // 8
        self._m0b = str2bytes(m0, bytelen)
        self._m1b = str2bytes(m1, bytelen)

    def first(self):
        """Generate the key pair and the random blocks x0 and x1.

        Writes x0 and x1 (modlen // 8 random bytes each) and the
        DER-encoded public key to the files "x0", "x1" and "Enc".
        """
        rgen = Crypto.Random.new()

        # Generate key-pair.
        # The top 2 bits of the RSA modulus must both equal 1.
        # This loop executes expected O(1) times to find such a modulus.
        while True:
            self._keys = RSA.generate(self._modlen, rgen.read)
            if self._keys.n >> (self._modlen - 2) == 3:
                break

        bytelen = self._modlen // 8
        self._x0 = rgen.read(bytelen)
        self._x1 = rgen.read(bytelen)

        store("x0", self._x0)
        store("x1", self._x1)
        store("Enc", self._keys.publickey().exportKey('DER'))

    def _decrypt_block(self, block):
        """Apply the raw RSA private-key operation to a block of bytes.

        The block is read as a big-endian integer c, and c**d mod n is
        returned as exactly modlen // 8 big-endian bytes.

        The fixed width matters: a result whose leading byte is zero
        must keep that byte, otherwise the XOR with the padded message
        in second() would be shifted and cut short. Doing the modular
        exponentiation here also avoids relying on the key object's
        own decrypt() method.
        """
        bytelen = self._modlen // 8
        c = int.from_bytes(block, 'big')
        m = pow(c, self._keys.d, self._keys.n)
        return m.to_bytes(bytelen, 'big')

    def second(self):
        """Read Bob's response "u" and write the masked messages.

        For each i in {0, 1}, v_i = Dec(u XOR x_i) XOR m_i is written
        to the file "v0" or "v1".

        Raises ValueError if "u" is not exactly modlen // 8 bytes long.
        """
        u = load("u")

        bytelen = self._modlen // 8
        if len(u) != bytelen:
            # xor() would silently truncate to the shorter operand
            raise ValueError(
                "response 'u' has %d bytes, expected %d" % (len(u), bytelen))

        v0 = xor(self._decrypt_block(xor(u, self._x0)), self._m0b)
        v1 = xor(self._decrypt_block(xor(u, self._x1)), self._m1b)

        store("v0", v0)
        store("v1", v1)
 
class Bob:
    """Receiver side of the file-based 1-out-of-2 oblivious transfer.

    Bob holds a choice bit. His first step reads the files "x0", "x1"
    and "Enc" and writes the file "u"; his second step reads the files
    "v0" and "v1" and returns the chosen message.
    """

    def __init__(self, b):
        """Remember the choice bit b, which must be 0 or 1.

        Raises ValueError for any other value.
        """
        if b not in (0, 1):
            raise ValueError("choice bit must be 0 or 1")
        self._b = b

    @staticmethod
    def _encrypt_block(pubkey, block):
        """Apply the raw RSA public-key operation to a block of bytes.

        The block is read as a big-endian integer m, and m**e mod n is
        returned as big-endian bytes of the same length as the block,
        so leading zero bytes are kept and later XORs stay aligned.

        Raises ValueError if m is not less than the modulus n.
        """
        m = int.from_bytes(block, 'big')
        if m >= pubkey.n:
            raise ValueError("Plaintext too large")
        return pow(m, pubkey.e, pubkey.n).to_bytes(len(block), 'big')

    def first(self):
        """Pick a random nonce y and write the response "u".

        u = Enc(y) XOR x_b, where b is the choice bit. The nonce is
        kept in self._y for second().

        Raises ValueError if "x0" and "x1" differ in length or the
        modulus in "Enc" does not fit in that many bytes.
        """
        xx = [load(f) for f in ("x0", "x1")]
        Enc = load("Enc")
        rgen = Crypto.Random.new()

        bytelen = len(xx[0])
        pubkey = RSA.importKey(Enc)

        if len(xx[1]) != bytelen:
            raise ValueError("files 'x0' and 'x1' have different lengths")
        if (pubkey.n.bit_length() + 7) // 8 > bytelen:
            raise ValueError("modulus in 'Enc' is too large for the x0/x1 block size")

        chosen = xx[self._b]
        other = xx[1 - self._b]

        # Generate random nonce y and encrypted response u.
        # Have to iterate until both decryptions on the other end are less
        # than the modulus, expected O(1) times.
        while True:
            y = rgen.read(bytelen)
            try:
                u = xor(self._encrypt_block(pubkey, y), chosen)
            except ValueError:
                # y itself was not below the modulus; draw again
                continue
            # u XOR x_b equals Enc(y), which is below the modulus by
            # construction; only the other combination needs checking.
            if int.from_bytes(xor(u, other), 'big') >= pubkey.n:
                continue
            self._y = y
            break

        store('u', u)

    def second(self):
        """Read "v0" and "v1" and return the chosen message as a string.

        Both files are loaded; the one selected by the choice bit is
        XORed with the nonce from first() and decoded.
        """
        vv = [load(f) for f in ("v0", "v1")]
        return bytes2str(xor(self._y, vv[self._b]))
 
# Run the command-line driver only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()