#!/usr/bin/env python3

from PIL import Image
import argparse
import os
import struct
import sys

# Every pixel format name accepted by --format, in the order argparse lists them
FORMATS = [
	# Direct colour, 16 bits per pixel
	"argb1555", "rgab5515", "bgar5515", "rgb565",
	# Direct colour, 8 bits per pixel
	"argb1232", "ragb2132", "rgb332",
	# Red channel only
	"r8", "r6", "r2", "r1",
	# Palette indices of the given bit width
	"p8", "p4", "p2", "p1",
]

def bytes_from_bitstream_le(bitstream):
	"""Pack a stream of variable-width bit fields into bytes, least significant bit first.

	bitstream is any iterable of (nbits, value) pairs. Each value is placed
	directly above the bits that came before it, so the first field ends up in
	the low bits of the first byte. Values are expected to fit in nbits; they
	are not masked here.

	Yields ints in the range 0..255. If the stream ends part-way through a
	byte, the leftover bits are emitted as one final byte whose unused high
	bits are zero, rather than being dropped.
	"""
	accum = 0
	accum_size = 0
	for nbits, newdata in bitstream:
		accum |= newdata << accum_size
		accum_size += nbits
		# A single field may complete more than one byte (e.g. 16-bit pixels)
		while accum_size >= 8:
			yield accum & 0xff
			accum >>= 8
			accum_size -= 8
	# Flush the partial byte: without this, the last pixels of an image whose
	# total bit count is not a multiple of 8 would silently go missing
	if accum_size > 0:
		yield accum & 0xff

class BinHeader:
	"""Writes bytes as a C header defining a static, 4-byte-aligned char array.

	Offers the same write()/close() calls the script uses on a binary file, so
	either can be used as the output. It can also be used as a context manager,
	which closes the file on exit.
	"""

	def __init__(self, filename, arrayname=None):
		"""Create filename and write the opening of the array definition.

		arrayname is the C identifier of the array. When omitted it is taken
		from the file's base name, up to the first dot.
		"""
		if arrayname is None:
			# Strip the directory part first: a path separator is not valid in a
			# C identifier, and "./foo.h" would otherwise give an empty name
			arrayname = os.path.basename(filename).split(".")[0]
		self.f = open(filename, "w")
		# Number of bytes written so far; decides where the line breaks go
		self.out_count = 0
		self.f.write(
			"#ifndef _IMG_ASSET_SECTION\n" \
			"#define _IMG_ASSET_SECTION \".data\"\n" \
			"#endif\n\n" \
			f"static const char __attribute__((aligned(4), section(_IMG_ASSET_SECTION \".{arrayname}\"))) {arrayname}[] = {{\n\t"
		)

	def __enter__(self):
		return self

	def __exit__(self, exc_type, exc_value, traceback):
		self.close()

	def write(self, bs):
		"""Append each byte of bs as a hex literal, starting a new line after every 16th."""
		for b in bs:
			separator = ",\n\t" if self.out_count % 16 == 15 else ", "
			self.f.write(f"0x{b:02x}{separator}")
			self.out_count += 1

	def close(self):
		"""Terminate the array definition and close the file."""
		self.f.write("\n};\n")
		self.f.close()

# Fixed 4x4 dither table, indexed [y % 4][x % 4]. Each of the values 0...15
# appears exactly once (thanks Graham)
dither_pattern_4x4 = [
	[0, 8, 2, 10],
	[12, 4, 14, 6],
	[3, 11, 1, 9],
	[15, 7, 13, 5],
]

def format_channel(data, msb, lsb, dither=False, dithercoord=None):
	"""Reduce one 8-bit channel value to a bit field occupying bits msb..lsb.

	The field keeps the top (msb - lsb + 1) bits of data and is returned
	already shifted into position. A field narrower than one bit (msb < lsb)
	keeps no bits of an 8-bit value.

	With dither set, dithercoord is the pixel's (x, y) position; the table
	entry for that position is added to data (clamped to 0xff) before the low
	bits are discarded.
	"""
	out_width = msb - lsb + 1
	assert out_width <= 8
	# Number of low-order input bits that do not fit in the field
	discarded = 8 - out_width
	if dither:
		col, row = dithercoord[0], dithercoord[1]
		threshold = dither_pattern_4x4[row % 4][col % 4]
		# Table entries are 4 bits wide; rescale them to the discarded bit count
		scale = discarded - 4
		data += (threshold << scale) if scale >= 0 else (threshold >> -scale)
		if data > 0xff:
			data = 0xff
	return (data >> discarded) << lsb

def format_rgb_pixel(pix, fmt, dither=False, dithercoord=None):
	"""Pack the channels of one pixel into a single integer.

	pix holds the 8-bit channel values and fmt the matching (msb, lsb) bit
	positions; channels and positions are paired up in order. When pix has
	exactly one channel fewer than fmt, the last field of fmt is filled as if
	its channel were 0xff, without dithering (the formats in this file put
	alpha last, so a pixel without alpha comes out opaque).
	"""
	fields = [format_channel(value, span[0], span[1], dither, dithercoord) for value, span in zip(pix, fmt)]
	if len(pix) + 1 == len(fmt):
		last_span = fmt[-1]
		fields.append(format_channel(0xff, last_span[0], last_span[1]))
	packed = 0
	for field in fields:
		packed |= field
	return packed

# Bit layouts of the direct-colour formats, keyed by format name:
#   name: (bits per pixel, ((msb, lsb) of each channel's field, ...))
# Fields line up with the pixel's channels in order (red, green, blue, then
# alpha where the format has one).
# TODO would be kind of nice to generate these based on format string but I don't need that yet

# A field that is zero bits wide: the channel contributes nothing to the output
_NO_FIELD = (-1, 0)

rgb_formats = {
	# 16 bits per pixel
	"argb1555": (16, ((14, 10), (9, 5), (4, 0), (15, 15))),
	"rgab5515": (16, ((15, 11), (10, 6), (4, 0), (5, 5))),
	"bgar5515": (16, ((4, 0), (10, 6), (15, 11), (5, 5))),
	"rgb565": (16, ((15, 11), (10, 5), (4, 0))),
	# 8 bits per pixel
	"argb1232": (8, ((6, 5), (4, 2), (1, 0), (7, 7))),
	"ragb2132": (8, ((7, 6), (4, 2), (1, 0), (5, 5))),
	"rgb332": (8, ((7, 5), (4, 2), (1, 0))),
	# Red channel only
	"r8": (8, ((7, 0), _NO_FIELD, _NO_FIELD)),
	"r6": (8, ((5, 0), _NO_FIELD, _NO_FIELD)),
	"r2": (2, ((1, 0), _NO_FIELD, _NO_FIELD)),
	"r1": (1, ((0, 0), _NO_FIELD, _NO_FIELD)),
}

def format_pixel(format, src_has_transparency, pixel, dither=False, dithercoord=None):
	"""Convert one pixel to a (bit width, value) pair in the named format.

	For the direct-colour formats, pixel is a sequence of 8-bit channels and
	dither/dithercoord are passed on to the channel packing. For the paletted
	formats ("p8", "p4", "p2", "p1"), pixel is a palette index: it is moved up
	by one when src_has_transparency is true, then truncated to the format's
	bit width.
	"""
	assert format in FORMATS
	if format in rgb_formats:
		nbits, layout = rgb_formats[format]
		return (nbits, format_rgb_pixel(pixel, layout, dither, dithercoord))
	if format in ("p8", "p4", "p2", "p1"):
		nbits = int(format[1:])
		index_mask = (1 << nbits) - 1
		return (nbits, (pixel + src_has_transparency) & index_mask)
	raise Exception()

if __name__ == "__main__":
	# Command-line entry point: pack an image (a tileset, or one whole image
	# with --single) into the requested pixel format. The result is written as
	# raw binary, or as a C array when the output name ends in ".h". Paletted
	# formats also get a "<output>.pal" file holding the palette as argb1555.
	parser = argparse.ArgumentParser()
	parser.add_argument("input", help="Input file name")
	parser.add_argument("output", help="Output file name")
	parser.add_argument("--tilesize", "-t", help="Tile size (pixels), default 8",
		default="8", choices=[str(2 ** i) for i in range(3, 11)])
	parser.add_argument("--single", "-s", action="store_true",
		help="The input consists of a single image of arbitrary width/height, rather than a tileset")
	parser.add_argument("--format", "-f", help="Output pixel format, default argb1555",
		default="argb1555", choices=FORMATS)
	parser.add_argument("--dither", "-d", action="store_true",
		help="Apply a simple fixed dither pattern when packing RGB files")
	parser.add_argument("--metadata", "-m", action="store_true",
		help="Write out opacity metadata at end of file for faster alpha blit (must be used with --single)")
	args = parser.parse_args()
	img = Image.open(args.input)
	if args.single:
		tsize_x = img.width
		tsize_y = img.height
	else:
		tsize_x = int(args.tilesize)
		tsize_y = tsize_x
	if args.metadata and not args.single:
		sys.exit("--metadata must be used with --single")

	format_is_paletted = args.format.startswith("p")
	image_is_transparent = img.mode == "RGBA" and img.getextrema()[3][0] < 255
	if args.metadata and not image_is_transparent:
		sys.exit("Can't write opacity metadata for a non-transparent image")
	# Checked up front, and not with an assert, so that a bad size is reported
	# before any output is written and even when running under -O.
	# NOTE(review): presumably this keeps the 32-bit metadata words aligned
	# after the pixel data -- confirm against the consumer
	if args.metadata and tsize_x * tsize_y % 4 != 0:
		sys.exit("--metadata requires the pixel count (width * height) to be a multiple of 4")

	# Quantising below replaces img with a palette-index image that has no
	# alpha channel, so keep hold of the source for the opacity metadata
	alpha_src = img

	if format_is_paletted:
		ncolours_max = 1 << int(args.format[1:])
		# Transparent images use palette entry 0 for "transparent" and every
		# colour index is moved up by one, which leaves one fewer entry for
		# real colours. Quantising to the full count would wrap the highest
		# index round to 0 and make the palette file one entry too long.
		ncolours_usable = ncolours_max - 1 if image_is_transparent else ncolours_max
		# getcolors() returns None when the image has more colours than its
		# limit (256 by default); the pixel count is always a sufficient limit
		ncolours_actual = min(ncolours_usable, len(img.getcolors(img.width * img.height)))
		pimg = img.quantize(ncolours_usable)
		palette = pimg.getpalette()
		# TODO haven't found a sane way to make PIL map transparency to palette
		if image_is_transparent:
			# Mark transparent pixels with index 255: once moved up by one and
			# truncated to the format's width, that becomes index 0
			for x in range(img.width):
				for y in range(img.height):
					if not (img.getpixel((x, y))[3] & 0x80):
						pimg.putpixel((x, y), 255)

		with open(args.output + ".pal", "wb") as pfile:
			if image_is_transparent:
				# Entry 0: the transparent colour
				pfile.write(bytes(2))
			pfile.write(bytes(bytes_from_bitstream_le(
				format_pixel("argb1555", False, palette[i:i+3]) for i in range(0, ncolours_actual * 3, 3)
			)))
			# Zero-fill the entries that are not in use
			if ncolours_actual < ncolours_usable:
				pfile.write(bytes(2 * (ncolours_usable - ncolours_actual)))
		img = pimg

	if args.output.endswith(".h"):
		ofile = BinHeader(args.output, arrayname = os.path.basename(args.input).split(".")[0])
	else:
		ofile = open(args.output, "wb")

	try:
		# Tiles are written row by row; pixels that do not fill a whole tile at
		# the right and bottom edges are left out
		for y in range(0, img.height - (tsize_y - 1), tsize_y):
			for x in range(0, img.width - (tsize_x - 1), tsize_x):
				tile = img.crop((x, y, x + tsize_x, y + tsize_y))
				ofile.write(bytes(bytes_from_bitstream_le(
					format_pixel(args.format, image_is_transparent, tile.getpixel((i, j)), args.dither, dithercoord=(i, j))
					for j in range(tsize_y) for i in range(tsize_x)
				)))
		if args.metadata:
			# One little-endian 32-bit word per row:
			#   bits 0-15  x of the last opaque pixel
			#   bits 16-30 x of the first opaque pixel
			#   bit 31     set when every pixel between the two is opaque
			# A row with no opaque pixel is written as 0.
			for y in range(0, tsize_y):
				opacity = [alpha_src.getpixel((x, y))[3] >= 128 for x in range(tsize_x)]
				if True not in opacity:
					# Completely transparent row
					ofile.write(struct.pack("<L", 0))
					continue
				first_opaque = opacity.index(True)
				last_opaque = tsize_x - 1 - opacity[::-1].index(True)
				continuous_span = all(opacity[first_opaque:last_opaque + 1])
				ofile.write(struct.pack("<L", (last_opaque & 0xffff) | ((first_opaque & 0x7fff) << 16) | ((continuous_span & 1) << 31)))
	finally:
		# Close the output (and finish the C array) even if packing fails
		ofile.close()
