#!/usr/bin/env python3

# Based on the Commander Keen compression algorithm (a variant of LZSS)

import sys
import argparse

def get_symbol(data, window):
	"""Encode the symbol at the start of ``data`` against ``window``.

	The longest prefix of ``data`` that occurs somewhere in ``window`` is
	searched for by growing a needle one byte at a time.

	Args:
		data: iterable of byte values still to be encoded.
		window: bytearray of recently encoded bytes (must support ``find``).

	Returns:
		A ``(consumed, symbol)`` pair:
		* ``(None, None)`` when ``data`` is empty;
		* ``(1, (False, byte))`` for a literal (no match, or a match of one byte);
		* ``(n, (True, matchpos, n - 1))`` for a match of ``n >= 2`` bytes, where
		  ``matchpos`` is the distance of the match start from the last byte
		  of ``window``.
	"""
	it = iter(data)
	try:
		needle = bytearray([next(it)])
	except StopIteration:
		# Only "no more input" means end of stream; other errors must surface.
		return (None, None)

	matchlen = 0
	matchstart = -1
	while True:
		found = window.find(needle)
		if found < 0:
			break
		# The whole needle matches; remember it before trying to extend it.
		matchlen = len(needle)
		matchstart = found
		try:
			needle.append(next(it))
		except StopIteration:
			# Input exhausted while still matching: keep the full match
			# instead of dropping its last byte.
			break

	if matchlen <= 1:
		# A back-reference costs two bytes, so short matches stay literals.
		return (1, (False, needle[0]))
	matchpos = len(window) - 1 - matchstart
	return (matchlen, (True, matchpos, matchlen - 1))

def iter_symbols(data, windowsize):
	"""Yield the sequence of symbols that encodes ``data``.

	Args:
		data: bytes-like object (``bytes``, ``bytearray``, ...) or iterable of
			byte values to encode. It is copied, never modified.
		windowsize: maximum number of already-encoded bytes kept available
			for back-references.

	Yields:
		Symbol tuples as returned by ``get_symbol``.
	"""
	# Take a private *mutable* copy: a plain slice of ``bytes`` would stay
	# immutable and the front-deletion below would fail.
	data = bytearray(data)
	window = bytearray()

	while True:
		consumed, symbol = get_symbol(data, window)
		if consumed is None:
			return
		yield symbol
		# Slide the window: drop enough old bytes that it holds at most
		# ``windowsize`` bytes once the consumed bytes are appended.
		del window[:max(0, len(window) + consumed - windowsize)]
		window.extend(data[:consumed])
		del data[:consumed]

def bitmap(bits):
	"""Pack an iterable of truthy/falsy flags into an integer.

	The flag at position ``i`` controls bit ``i`` of the result, so the
	first flag ends up in the least significant bit.
	"""
	return sum(1 << position for position, flag in enumerate(bits) if flag)

def iter_bytes(data, windowsize):
	"""Yield the byte values of the compressed form of ``data``.

	Symbols are emitted in groups of up to eight. Each group starts with
	one flag byte (bit ``i`` is the first field of the group's ``i``-th
	symbol) followed by the remaining fields of every symbol in order.
	The final group may hold fewer than eight symbols.
	"""
	def emit(group):
		# Flag byte first, then each symbol's payload fields.
		yield bitmap(entry[0] for entry in group)
		for entry in group:
			yield from entry[1:]

	group = []
	for symbol in iter_symbols(data, windowsize):
		group.append(symbol)
		if len(group) == 8:
			yield from emit(group)
			group = []
	# Flush whatever is left over once the symbol stream ends.
	if group:
		yield from emit(group)

def compress(data, windowsize):
	"""Return the compressed form of ``data`` as a ``bytearray``.

	``windowsize`` is passed through to ``iter_bytes`` unchanged.
	"""
	packed = bytearray()
	packed.extend(iter_bytes(data, windowsize))
	return packed

def decompress(cdata, windowsize):
	"""Decode a compressed byte stream back into the original bytes.

	The stream is a sequence of groups: one flag byte followed by up to
	eight symbols. Flags are consumed from the least significant bit
	upward. A set flag means a back-reference (offset byte, then
	length-minus-one byte); a clear flag means a single literal byte.

	Args:
		cdata: iterable of compressed byte values.
		windowsize: number of most recent output bytes kept available for
			back-references.

	Returns:
		A ``bytearray`` with the decompressed data.

	Raises:
		ValueError: if the stream ends in the middle of a back-reference, or
			a back-reference points before the start of the window.
	"""
	out = bytearray()
	window = bytearray()
	it = iter(cdata)
	flags = 0
	remaining = 0  # symbols still governed by the current flag byte
	while True:
		if remaining == 0:
			try:
				flags = next(it)
			except StopIteration:
				break
			remaining = 8
		remaining -= 1

		if flags & 1:
			try:
				offset = next(it)
				count = next(it) + 1
			except StopIteration:
				# A set flag bit always has two payload bytes; do not leak a
				# bare StopIteration out of a regular function.
				raise ValueError("truncated back-reference in compressed data") from None
			start = len(window) - 1 - offset
			if start < 0:
				# A negative index would wrap around and copy the wrong bytes.
				raise ValueError("back-reference points before the start of the window")
			chunk = window[start:start + count]
			out.extend(chunk)
			window.extend(chunk)
		else:
			try:
				lit = next(it)
			except StopIteration:
				# Unused flag bits of the last group read as literals; running
				# out of input here is the normal end of the stream.
				break
			window.append(lit)
			out.append(lit)

		# Keep only the newest ``windowsize`` bytes.
		del window[:max(0, len(window) - windowsize)]
		flags >>= 1

	return out


if __name__ == "__main__":
	# Command-line front end: read INPUT, compress or decompress it, write
	# OUTPUT. "-" selects standard input / standard output.
	parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter)
	parser.add_argument("input", help="Input file name")
	parser.add_argument("output", help="Output file name")
	parser.add_argument("-d", "--decompress", action="store_true", help="Decompress input (default is to compress)")
	# type=int: the codec does integer arithmetic with the window size, so a
	# raw command-line string would raise TypeError.
	parser.add_argument("-w", "--window", type=int, default=256, help="Compression window size")
	args = parser.parse_args()
	if not args.decompress and args.window > 256:
		# Offsets and lengths are each emitted as a single byte.
		parser.error("window size must not exceed 256 when compressing")

	# Read all input before touching the output so that a failure (or using
	# the same path for both) cannot destroy data. The standard streams are
	# text-mode; their .buffer attribute gives the binary stream.
	if args.input == "-":
		ibuf = bytearray(sys.stdin.buffer.read())
	else:
		with open(args.input, "rb") as ifile:
			ibuf = bytearray(ifile.read())

	if args.decompress:
		obuf = decompress(ibuf, args.window)
	else:
		obuf = compress(ibuf, args.window)

	if args.output == "-":
		sys.stdout.buffer.write(obuf)
		sys.stdout.buffer.flush()
	else:
		with open(args.output, "wb") as ofile:
			ofile.write(obuf)
