#!/usr/bin/env python

from telnetlib2 import TelnetPopen4
from bitprocess import dataset, sorted_index, theta_ratios_detect, print_ratios
from bitprocess import remove_row_from_ratios, remove_col_from_ratios
from bitrotater import bitxormatrix
from time import sleep
from math import pi, degrees, radians
import sys
import os
from copy import copy
from threading import Thread

def int_to_str(i, max_len):
	"""	integer to 16-byte-length string.  padded with zeros.
	"""
	d = ''
	while i != 0:
		d += chr(i&0xff) 
		i >>= 8
	d += chr(0) * (max_len-len(d))
	return d

def str_to_int(d):
	i = 0L
	exp = 0
	for c in d:
		i += ord(c) << exp
		exp += 8
	return i

def bit_set(i, bit):
	return i | 1<<bit 

def bit_is_set(i, bit):
	return i & (1<<bit) != 0
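
# A quick illustration (not used by the tests below) of the helper
# conventions: int_to_str/str_to_int use little-endian byte order and
# round-trip each other, and bit numbering counts from the least
# significant bit.  The values here are made up for demonstration only.
def _example_helpers():
	i = str_to_int('ab')		# 'a' = 0x61 low byte, 'b' = 0x62 high byte
	assert i == 0x6261
	assert int_to_str(i, 4) == 'ab\x00\x00'
	assert bit_is_set(bit_set(0, 3), 3)
	assert not bit_is_set(0, 3)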

def dump_data(d):
	"""	print an ASCII rendering of d (non-alphanumeric bytes shown
		as '.') followed by its hex bytes.
	"""
	for ch in d:
		if not ch.isalnum():
			ch = '.'
		sys.stdout.write(ch)
	sys.stdout.write(' ')
	for ch in d:
		print "%02x" % ord(ch),
	print

def make_data(run_len, max_len):
	"""	build run_len test data blocks.  currently just the counter
		values 0..run_len-1; the random-block generator below is
		disabled.
	"""
	dt = []

	for d in range(run_len):
		dt.append(d)

	# disabled: max_len random bytes per block.
	#from random import Random
	#r = Random()
	#for d in range(run_len):
	#	data = ''
	#	for x in range(max_len):
	#		data += chr(r.randint(0, 255))
	#	dt.append(str_to_int(data))

	return dt

def pre_encrypt(cmd_name, run_len, data_bytes, key_bytes, dt, key):

	t = TelnetPopen4("./%s -k -e -i -o test -l %d" % (cmd_name, run_len))
	#t.set_debuglevel(1)

	baseline = []

	for d in range(run_len):

		print d
		t.write(int_to_str(dt[d], data_bytes))
		t.write(int_to_str(key, key_bytes))
		txt = t.read_until('read input\n')
		txt = t.read_until('encrypt\n')
		#print txt[max_len:]
		base = txt[:data_bytes]
		print "base: ",
		dump_data(base)
		baseline.append(str_to_int(base))

	txt = t.read_until('done\n')
	t.close()

	return baseline

def test_decrypt(cmd_name, run_len, data_bytes, key_bytes, k_prob, dt):

	t = TelnetPopen4("./%s -b -d -i -o test -l %d" % (cmd_name, run_len))
	#t.set_debuglevel(1)

	for p in k_prob:
		t.write("%g\n" % p)
		print t.read_lazy()

	baseline = []
	keys = []
	num_fails = []

	for d in range(run_len):

		print d
		t.write(int_to_str(dt[d], data_bytes))
		txt = t.read_until('key: ')
		#print txt[16:]
		txt = t.read_until('read input\n')
		key = txt[:key_bytes]
		print "key :",
		dump_data(key)
		txt = t.read_until('encrypt\n')
		#print txt[16:]
		base = txt[:data_bytes]
		print "base:",
		dump_data(base)
		txt = t.read_until('num_fails: ')
		txt = t.read_until('\n').strip()
		baseline.append(str_to_int(base))
		keys.append(str_to_int(key))
		num_fail = int(txt)
		num_fails.append(num_fail)
		print "num_fails:", num_fail

	txt = t.read_until('done\n')
	t.close()

	return baseline, keys, num_fails

def test_k(cmd_name, run_len, data_bytes, key_bytes, k_prob, dt, keys):

	t = TelnetPopen4("./%s -k -i -o test -l %d" % (cmd_name, run_len))
	#t.set_debuglevel(1)

	for p in k_prob:
		t.write("%g\n" % p)
		print t.read_lazy()

	for d in range(run_len):

		t.write(int_to_str(dt[d], data_bytes))
		t.write(int_to_str(keys[d], key_bytes))
		txt = t.read_until('read input\n')
		#txt = t.read_until('key: ')
		#txt = t.read_some()
		#key = txt[:16]
		#print "key :",
		#dump_data(key)
		print txt

	txt = t.read_until('done\n')
	t.close()

def transfer_analyse_input(hostname, key_block_pairs, key_bytes, data_bytes):

	t = TelnetPopen4("ssh -e none -T %s 'cat > src/crypto/key_as_data/%s.input.txt'" % \
			(hostname, hostname))
	#t.set_debuglevel(1)

	for (block, key) in key_block_pairs:
		t.write(int_to_str(key, key_bytes))
		t.write(int_to_str(block, data_bytes))

	t.close()

def transfer_analyse_vectors(hostname, vectors):

	t = open("/tmp/%s.vectors.txt" % hostname, "w")
	for (k1, d1, k2, d2) in vectors:
		t.write("%d\n%d\n%d\n%d\n" % (k1, d1, k2, d2))
	t.close()

	t = TelnetPopen4("scp -v /tmp/%s.vectors.txt %s:src/crypto/key_as_data/%s.vectors.txt" % \
			(hostname, hostname, hostname))
	#t.set_debuglevel(1)
	t.read_until("100%")
	t.close()

def open_analyse(hostname, key_block_pairs, key_bytes, data_bytes):

	#t = TelnetPopen4("ssh -e none -T %s src/crypto/key_as_data/analyse -v ~/src/crypto/key_as_data/%s.vectors.txt -l %d" % \
	#		(hostname, hostname, len(key_block_pairs)))
	t = TelnetPopen4("ssh -e none -T %s src/crypto/key_as_data/analyse -l %d" % \
			(hostname, len(key_block_pairs)))
	#t.set_debuglevel(1)

	txt = t.read_until('num keyblocks: %d\n' % len(key_block_pairs))

	for (block, key) in key_block_pairs:
		t.write(int_to_str(key, key_bytes))
		t.write(int_to_str(block, data_bytes))
		txt = t.read_until('read input\n')

	txt = t.read_until('read pairs\n')

	return t

def read_key_bytes(t, key_bytes):

	kp = {}
	for l in range(key_bytes * 8):
		txt = t.read_until("\n")
		kp[l] = int(txt.strip())
	return kp

def test_analyse_by_vectorfile(t, key_bytes, data_bytes, vectors):

	t.write("%d\n" % len(vectors))
	txt = t.read_until('num pairs: %d\n' % len(vectors))

	for (k1, d1, k2, d2) in vectors:
		txt = t.read_until('read bitpairs\n')

	return read_key_bytes(t, key_bytes)

def test_analyse(t, key_bytes, data_bytes, vectors):

	t.write("%d\n" % len(vectors))
	txt = t.read_until('num pairs: %d\n' % len(vectors))

	for (k1, d1, k2, d2) in vectors:
		t.write("%d\n%d\n%d\n%d\n" % (k1, d1, k2, d2))
	for (k1, d1, k2, d2) in vectors:
		txt = t.read_until('read bitpairs\n')

	return read_key_bytes(t, key_bytes)

def close_analyse(t):

	t.write("0\n");
	t.close()


def process_data(ds, row_num, col_num, data_block_size, key_block_size,
			rotate_x, rotate_y):

	d = ds.rotate(rotate_x, rotate_y)
	#d = ds.remap(rotate_x, rotate_y)
	d = d.average(data_block_size, key_block_size)
	#da = ds.sub_average(d)
	da = ds.std_dev()
	d = d.signed_std_dev()

	if col_num is not None:
		c0 = d.select_col(col_num)
		i0 = sorted_index(c0)
		d = d.reorder_rows(i0)
		#da = da.reorder_cols(i0)
	
	if row_num is not None:
		r0 = d.select_row(row_num)
		i0 = sorted_index(r0)
		d = d.reorder_cols(i0)
		#da = da.reorder_rows(i0)

	#xo = d.col_order
	#yo = d.row_order
	#yo = reverse_order(d.col_order)
	#xo = reverse_order(d.row_order)
	#yo = reverse_order(d.row_order)
	#xo = reverse_order(d.col_order)

	#print yo
	#print xo

	#sys.exit(0)

	allow_err = pi/2/90
	exclude = {}
	#if row_num is not None:
	#	exclude[0.0] = (None, row_num)
	#if col_num is not None:
	#	exclude[pi/2] = (col_num, None)

	r = d.z_ratio_detect(allow_err, exclude)
	r = theta_ratios_detect(r, allow_err)
	r = d.find_lines(r)

	return d, da, r

def reverse_order(order):
	"""	the col order and row order represent where the data has been
		moved to, not how to index it.  so we have to reverse the index
		oops.  well, it made life a bit easier to shuffle stuff around...
	"""

	keys = []
	data = {}
	for (k, d) in order:
		keys.append(k)
		data[k] = d
	kcp = copy(keys)
	kcp.sort()
	ro = {}
	for krev, ksorted in zip(kcp, keys):
		ro[ksorted] = data[krev]
	return ro
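
# A toy sketch of what reverse_order computes (illustrative only; in the
# real col_order/row_order the values are presumably sequences of bit
# indices, here simple strings stand in for them):
def _example_reverse_order():
	order = [(2, 'c'), (0, 'a'), (1, 'b')]
	ro = reverse_order(order)
	# the key at position i of the original order gets the data of the
	# i-th smallest key
	assert ro == {2: 'a', 0: 'b', 1: 'c'}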

def process_ratios(hostname, name, key_size, data_size,
	kweight, kk, da, d, r, dt, blkz, min_vector_len):

	key_bytes = key_size / 8
	data_bytes = data_size / 8

	#transfer_analyse_input(hostname, blkz, key_bytes, data_bytes)

	t = open_analyse(hostname, blkz, key_bytes, data_bytes)

	xo = reverse_order(d.col_order)
	yo = reverse_order(d.row_order)
	#xo = reverse_order(d.col_order)
	#yo = reverse_order(d.row_order)

	kp = {}
	for i in range(key_size):
		kp[i] = 0.0

	v = []
	for ((ang, theta), vectors) in r.items():
		if len(vectors) < min_vector_len:
			continue
		print vectors
		vector_len = len(vectors)-1
		for i in range(vector_len):
			x, y = vectors[i]
			x2, y2 = vectors[i+1]

			avgs_x = xo[x]
			avgs_y = yo[y]

			avgs_x2 = xo[x2]
			avgs_y2 = yo[y2]

			v = []

			for ix in range(len(avgs_x)):
				for iy in range(len(avgs_y)):
					ax = avgs_x[ix]
					ay = avgs_y[iy]
					ax2 = avgs_x2[ix]
					ay2 = avgs_y2[iy]

					#v.append((ax, ay, ax2, ay2))
					v.append((ay, ax, ay2, ax2))

			ka = test_analyse(t, key_bytes, data_bytes, v)
			for j in range(key_size):
				kp[j] += ka[j] * 1e-5

			# disabled: batched alternative that ships the vectors to the
			# remote host as a file instead of writing them inline.
			#if len(v) > 40000 or i == vector_len-1:
			#	transfer_analyse_vectors(hostname, v)
			#	ka = test_analyse_by_vectorfile(t, key_bytes, data_bytes, v)
			#
			#	for j in range(key_size):
			#		kp[j] += ka[j] * 1e-5
			#
			#	v = []

		print_k_prob(name, kk, kp)

	k = []
	for i in range(key_size):
		k.append(kp[i])

	close_analyse(t)

	return k

def process_ratios_old(name, key_size, data_size,
	kweight, kk, da, d, r, dt, blkz, min_vector_len):

	kp = {}
	for i in range(key_size):
		kp[i] = 0.0

	xo = reverse_order(d.col_order)
	yo = reverse_order(d.row_order)

	for ((ang, theta), vectors) in r.items():
		if len(vectors) < min_vector_len:
			continue
		print vectors
		for i in range(len(vectors)-1):
			x, y = vectors[i]
			x2, y2 = vectors[i+1]

			avgs_x = xo[x]
			avgs_y = yo[y]

			avgs_x2 = xo[x2]
			avgs_y2 = yo[y2]

			#weight = d[x, y]
			#weight2 = d[x2, y2]

			for ix in range(len(avgs_x)):
				for iy in range(len(avgs_y)):
					ax = avgs_x[ix]
					ay = avgs_y[iy]
					ax2 = avgs_x2[ix]
					ay2 = avgs_y2[iy]

					weight = da[ax, ay]
					weight2 = da[ax2, ay2]

					# every key/block pair contributes its own shift
					for (base, k) in blkz:
						if (bit_is_set(k, ay) ^ bit_is_set(base, ax)) == \
						   (bit_is_set(k, ay2) ^ bit_is_set(base, ax2)):

							k_shift = 1e-5

						else:

							k_shift = -1e-5

						kp[ay] += k_shift * weight
						kp[ay2] += k_shift * weight2

		print_k_prob(name, kk, kp)

	k = []
	for i in range(key_size):
		k.append(kp[i])
	return k

def print_k_prob(name, key, kp):

	f = open("%s.txt" % name, "w")
	tot = 0
	for n in range(len(kp)):
		kpn = kp[n]
		bk = bit_is_set(key, n)
		if bk:
			b = kpn > 0
		else:
			b = kpn < 0
		f.write("%d %d %d %.3g   " % (n, b, bk, kpn))
		if n % 4 == 3:
			f.write("\n")
		if b:
			tot += 1
	f.write("%d\n" % tot)
	f.close()
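
# Example of the scoring file print_k_prob writes (a standalone sketch;
# the key and per-bit scores are made up, and "example_kprob" is just a
# placeholder output name):
def _example_print_k_prob():
	key = 0xa			# true key bits 1 and 3 set
	kp = {0: -0.2, 1: 0.4, 2: 0.1, 3: -0.3}
	# writes one "bit guessed-correctly true-bit score" entry per key
	# bit to example_kprob.txt, followed by the count of correct guesses
	print_k_prob("example_kprob", key, kp)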

def print_last_kprob(name, key, kp):

	f = open("%s.txt" % name, "w")
	tot = 0
	for (n, kpn) in enumerate(kp):
		bk = bit_is_set(key, n)
		if bk:
			b = kpn > 0
		else:
			b = kpn < 0
		f.write("%d %f (%d %d)\n" % (n, kpn, b, bk))
		if b:
			tot += 1
	f.write("%d\n" % tot)
	f.close()

class fork_processor:

	def __init__(self, hosts):
		self.child_pids = []
		self.max_children = len(hosts)
		self.available_hosts = copy(hosts)
		self.in_use = {}

	def do_fork(self):
		"""	fork a child and hand it one of the available hosts.
			returns (pid, hostname); pid is 0 in the child.
		"""
		self.collect_children()
		hostname = self.available_hosts[0]
		pid = os.fork()
		if pid:
			self.child_pids.append(pid)
			self.in_use[pid] = hostname
			self.available_hosts.remove(hostname)
		return pid, hostname

	def fork_makegraph(self, name, key_size, data_size,
	            k_prob, key, da, d, r, dt, blkz, min_vector_len):
		"""	same as fork_pr, but without writing the final
			k-probability file in the child.
		"""
		pid, hostname = self.do_fork()
		if pid:
			return
		kp = process_ratios(hostname, name, key_size, data_size,
		             k_prob, key, da, d, r, dt, blkz, min_vector_len)
		print "child exiting", os.getpid()
		os._exit(0)

	def fork_pr(self, name, key_size, data_size,
	            k_prob, key, da, d, r, dt, blkz, min_vector_len):

		pid, hostname = self.do_fork()
		if pid:
			return
		kp = process_ratios(hostname, name, key_size, data_size,
		             k_prob, key, da, d, r, dt, blkz, min_vector_len)
		print_last_kprob(name, key, kp)
		print "child exiting", os.getpid()
		os._exit(0)

	def collect_children(self):

		print "entering collect_children pids", self.child_pids
		while len(self.child_pids) > 0:
			if len(self.child_pids) < self.max_children:
				options = os.WNOHANG
			else:
				# If the maximum number of children are already
				# running, block while waiting for a child to exit
				options = 0
			try:
				pid, status = os.waitpid(-1, options)
			except os.error:
				print "os.error!"
				pid = None
			if not pid: break
			print "child pid exiting", self.child_pids, pid, 
			if pid in self.child_pids:
				hostname = self.in_use.pop(pid)
				print hostname,
				self.child_pids.remove(pid)
				self.available_hosts.append(hostname)
			print

def test_make_doublegraphs():

	hosts = ['127.0.0.1', 'localhost']

	run_len = 512
	#num_blocks = 16
	#key_size = 56
	#data_size = 64
	#cmd_name = "des_xor"
	key_size = 128
	data_size = 128
	cmd_name = "rjd_xor"

	key_block_size = 16 # key_size / num_blocks
	data_block_size = 16 # data_size / num_blocks
	key_bytes = key_size / 8
	data_bytes = data_size / 8

	dt = make_data(run_len, data_bytes)

	# encrypt data with a pre-arranged key, which we're then going to
	# try to 'find'.
	key = str_to_int('0123456789abcdef'[:key_bytes])
	#key = str_to_int('                '[:key_bytes])
	#key = 0

	pf = fork_processor(hosts)

	# NOTE: the remaining fork_makegraph inputs (k_prob, da, d, r and the
	# baseline/key pairs blkz) still need to be produced here -- see
	# test_full_analyse for how they are built before forking the workers.
	for key_bit in range(key_size):
		for data_bit in range(data_size):
			# placeholder output name (assumed naming scheme)
			name = "results/key.%d.data.%d" % (key_bit, data_bit)
			#pf.fork_makegraph(name, key_size, data_size,
			#		k_prob, key, da, d, r, dt, blkz, 10)

			print "parent's child pids", pf.child_pids


def test_full_analyse():

	hosts = ['127.0.0.1', 'localhost', '192.168.16.77', 'cold',
	         '10.8.0.1', '10.8.0.1']

	run_len = 512
	#num_blocks = 16
	#key_size = 56
	#data_size = 64
	#cmd_name = "des_xor"
	key_size = 128
	data_size = 128
	cmd_name = "rjd_xor"

	key_block_size = 16 # key_size / num_blocks
	data_block_size = 16 # data_size / num_blocks
	key_bytes = key_size / 8
	data_bytes = data_size / 8

	dt = make_data(run_len, data_bytes)

	# encrypt data with a pre-arranged key, which we're then going to
	# try to 'find'.
	#key = str_to_int('0123456789abcdef'[:key_bytes])
	key = str_to_int('                '[:key_bytes])
	key = 0
	dt_e = pre_encrypt(cmd_name, run_len, data_bytes, key_bytes, dt, key)

	# start off with probabilities 0.5 for all key bits - i.e. we ain't
	# got a clue what we're doing :)
	k_prob = []
	#if i < 64:
	#	guess = 1.0
	#else:
	#	guess = 0.5
	guess = 0.5
	for i in range(key_size):
		if bit_is_set(key, i):
			k_prob.append(guess)
		else:
			k_prob.append(1.0 - guess)
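
	# e.g. with key = 0 and guess = 0.5 every entry ends up 0.5; with a
	# deliberately biased guess (say 0.9) the bits set in the real key
	# would start at 0.9 and the clear bits at 0.1.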

	# ok - hand the encrypted data, and the probabilities, to the
	# stats-analyser.  receive back each weighted-probability-generated
	# key and the decrypted data block.
	baseline, keys, num_fails = test_decrypt(cmd_name, run_len,
						data_bytes, key_bytes, k_prob, dt_e)

	blkz = zip(baseline, keys)

	#blkz = []
	#for base, k in zip(baseline, keys):
	#	blkz.append(bitxormatrix(k, base, 128, 128))

	# now perform analysis
	kp_sum = [0.0] * key_size
	ds = dataset()
	ds.read_file("graphs/test.%d.txt" % run_len)

	pf = fork_processor(hosts)

	#for rotate_x in [0,4,8,12]:
	#	for rotate_y in [0,4,8,12]:
	for rotate_x in [0]:
		for rotate_y in [0]:
			for i in [None] + range(0, data_size, data_block_size):
				for j in [None] + range(0, key_size, key_block_size):
					d, da, r = process_data(ds, j, i, data_block_size, key_block_size,
										rotate_x, rotate_y)
					#remove_row_from_ratios(r, i)
					pf.fork_pr("results/row.%s.col.%s.%dx%d" % \
							(str(i), str(j), rotate_x, rotate_y),
							key_size, data_size,
						   k_prob, key, da, d, r, dt, blkz, 10)
				print "parent's child pids", pf.child_pids

				#d, da, r = process_data(ds, i, None, data_block_size, key_block_size,
				#						rotate_x, rotate_y)
				#remove_col_from_ratios(r, i)
				#pf.fork_pr("col.%d.%dx%d" % (i, rotate_x, rotate_y),
				#			key_size, data_size,
				#		   k_prob, key, da, d, r, dt, blkz, 10)
				#print "parent's child pids", pf.child_pids

				##kp_sum = map(lambda (x,y): x+y, zip(kp_sum, rt.kp))
				##kp_sum = map(lambda (x,y): x+y, zip(kp_sum, ct.kp))

				##print_last_kprob("results", key, kp_sum)


def test_str_convert():
	# block width assumed to match the 128-bit setup used elsewhere
	data_bytes = 16
	d = ' ' * data_bytes
	i = str_to_int(d)
	print hex(i)
	i = bit_set(i, 55)
	print hex(i)
	d = int_to_str(i, data_bytes)
	print d

def test_analyser():

	key_size = 128
	data_size = 128
	key_bytes = key_size / 8
	data_bytes = data_size / 8

	dt = make_data(2, data_bytes)
	k = make_data(2, key_bytes)
	kd = zip(dt, k)
	v = [(0,0,1,1), (0,0,2,2)]
	t = open_analyse("localhost", kd, key_bytes, data_bytes)
	test_analyse(t, key_bytes, data_bytes, v)

if __name__ == '__main__':
	test_full_analyse()

