import sys
import os

# Bytes per character row (8 scanlines) in each screen mode.
# MODE 0, 1 and 2 all use a 20K screen laid out as 32 character rows of
# 640 bytes each, so they share the same row size.
BLOCK_SIZES = {
    0: 640,   # MODE 0
    1: 640,   # MODE 1
    2: 640,   # MODE 2
}

# Total screen sizes in bytes (32 character rows x 640 bytes = 20K).
SCREEN_SIZES = {
    0: 20480,
    1: 20480,
    2: 20480,
}


def trim_bbc_image(filename, mode):
    """Trim trailing blank character rows from a BBC Micro screen dump.

    The file is read in full and, if shorter than a full screen for
    *mode*, padded with zero bytes. The output keeps everything up to the
    end of the character row containing the last non-zero byte. It always
    keeps at least one row and never more than a full screen. The output
    is written as ``<root>_trimmed<ext>`` alongside the input, which is
    never overwritten. Sizes and the output path are printed to stdout.

    Args:
        filename: Path of the screen dump to trim.
        mode: Screen mode; must be a key of ``BLOCK_SIZES`` (0, 1 or 2).

    Returns:
        The path of the trimmed file that was written.

    Raises:
        ValueError: If *mode* is not a supported screen mode.
        OSError: If the input cannot be read or the output written.
    """
    if mode not in BLOCK_SIZES:
        raise ValueError("Mode must be 0, 1, or 2")

    block_size = BLOCK_SIZES[mode]
    max_size = SCREEN_SIZES[mode]

    with open(filename, "rb") as f:
        data = f.read()
    # Record the real length before padding so the report is accurate.
    original_size = len(data)

    # Pad short files so the trimmed output always ends on a row boundary
    # (ljust is a no-op when the data is already long enough).
    data = data.ljust(max_size, b"\x00")

    # Index of the last non-zero byte; 0 for an all-zero image, so at
    # least one character row is always kept.
    last_used = max(len(data.rstrip(b"\x00")) - 1, 0)

    # Round up to the end of the containing character row, but never past
    # a full screen (oversized input files are cut at max_size).
    trimmed_size = min((last_used // block_size + 1) * block_size, max_size)

    # Build the output name from the extension only, so the result never
    # equals the input path (even without a ".bbc" suffix) and ".bbc" in
    # directory names is left alone.
    root, ext = os.path.splitext(filename)
    outname = f"{root}_trimmed{ext}"
    with open(outname, "wb") as f:
        f.write(data[:trimmed_size])

    print(f"Original size: {original_size} bytes")
    print(f"Trimmed size:  {trimmed_size} bytes")
    print(f"Saved as:      {outname}")
    return outname

if __name__ == "__main__":
    if len(sys.argv) < 3:
        print("Usage: python trim_bbc_image.py <filename><mode>