import sys
import os
import json
from PyQt6.QtWidgets import QApplication, QWidget, QLabel, QLineEdit, QPushButton, QVBoxLayout, QHBoxLayout
from PyQt6.QtGui import QPixmap, QMovie
from PyQt6.QtCore import Qt

# ==== CONFIG ====
IMAGE_FOLDER = "image"
OUTPUT_FILE = "labels.json"
# File extensions (compared case-insensitively) that are offered for labeling.
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".gif")

# Images to label: regular files in IMAGE_FOLDER with a supported extension,
# in sorted filename order. Sub-directories whose names happen to end in an
# image extension are skipped. A missing IMAGE_FOLDER raises FileNotFoundError.
images = sorted(
    name
    for name in os.listdir(IMAGE_FOLDER)
    if name.lower().endswith(IMAGE_EXTENSIONS)
    and os.path.isfile(os.path.join(IMAGE_FOLDER, name))
)

# Labels from a previous session, keyed by image filename. A malformed file
# raises instead of silently falling back to {}; otherwise the next save
# would overwrite the user's existing labels.
if os.path.exists(OUTPUT_FILE):
    with open(OUTPUT_FILE, "r", encoding="utf-8") as f:
        labels_data = json.load(f)
    if not isinstance(labels_data, dict):
        raise ValueError(f"{OUTPUT_FILE} must contain a JSON object mapping filenames to labels")
else:
    labels_data = {}

class MemeLabeler(QWidget):
    """Window that steps through ``images`` and records a text label for each.

    Labels are stored in the module-level ``labels_data`` dict (keyed by image
    filename) and written to ``OUTPUT_FILE`` whenever the user moves to
    another image with Next, Enter, or Back.
    """

    # Static images are scaled down to fit inside a square of this many pixels.
    MAX_IMAGE_SIZE = 600

    def __init__(self):
        super().__init__()
        self.setWindowTitle("Meme Labeler")
        self.index = 0  # position in ``images`` of the image on screen
        self.current_movie = None  # QMovie for the GIF currently playing, if any

        # UI elements
        self.filename_label = QLabel()
        self.filename_label.setAlignment(Qt.AlignmentFlag.AlignCenter)
        self.filename_label.setStyleSheet("font-weight: bold; font-size: 14px;")

        self.image_label = QLabel()
        self.image_label.setAlignment(Qt.AlignmentFlag.AlignCenter)
        self.text_input = QLineEdit()
        # Pressing Enter in the text box behaves like clicking Next.
        self.text_input.returnPressed.connect(self.save_next)

        self.prev_button = QPushButton("Back")
        self.prev_button.clicked.connect(self.go_back)
        self.next_button = QPushButton("Next")
        self.next_button.clicked.connect(self.save_next)

        btn_layout = QHBoxLayout()
        btn_layout.addWidget(self.prev_button)
        btn_layout.addWidget(self.next_button)

        layout = QVBoxLayout()
        layout.addWidget(self.filename_label)  # show filename
        layout.addWidget(self.image_label)
        layout.addWidget(self.text_input)
        layout.addLayout(btn_layout)
        self.setLayout(layout)

        self.show_image()

    def _stop_movie(self):
        """Stop and release the GIF animation, if one is playing."""
        if self.current_movie is not None:
            self.current_movie.stop()
            self.current_movie = None

    def show_image(self):
        """Display the image at ``self.index``, or the completion screen past the end.

        Out-of-range indices are clamped to ``[0, len(images)]``. The text box
        is pre-filled with any label already saved for the image.
        """
        # go_back on the first image leaves index at -1; clamp it.
        if self.index < 0:
            self.index = 0

        # Stop any running GIF before switching content, including when
        # switching to the completion screen. Otherwise the movie keeps
        # animating in the background.
        self._stop_movie()

        # Back is only useful when there is an earlier image to return to.
        # This also holds on the completion screen, so the last image can
        # still be revisited.
        self.prev_button.setEnabled(self.index > 0)

        if self.index >= len(images):
            self.index = len(images)  # keep Back exactly one step from the last image
            self.filename_label.setText("")
            self.image_label.setText("All memes labeled!")
            self.text_input.clear()
            self.text_input.setDisabled(True)
            self.next_button.setDisabled(True)
            return

        # Re-enable editing in case we arrived here from the completion screen.
        self.text_input.setEnabled(True)
        self.next_button.setEnabled(True)

        current_image = images[self.index]
        self.filename_label.setText(current_image)  # display filename

        image_path = os.path.join(IMAGE_FOLDER, current_image)
        ext = os.path.splitext(image_path)[1].lower()

        if ext == ".gif":
            movie = QMovie(image_path)
            if movie.isValid():
                self.current_movie = movie  # keep a reference so it is not collected
                self.image_label.setMovie(movie)
                movie.start()
            else:
                self.image_label.setText(f"Could not load {current_image}")
        else:
            pixmap = QPixmap(image_path)
            if pixmap.isNull():
                # Unreadable/corrupt file: tell the user instead of showing nothing.
                self.image_label.setText(f"Could not load {current_image}")
            else:
                self.image_label.setPixmap(
                    pixmap.scaled(
                        self.MAX_IMAGE_SIZE,
                        self.MAX_IMAGE_SIZE,
                        Qt.AspectRatioMode.KeepAspectRatio,
                        Qt.TransformationMode.SmoothTransformation,
                    )
                )

        # Fill text input if already labeled
        self.text_input.setText(labels_data.get(current_image, ""))
        self.text_input.setFocus()

    def _save_current(self):
        """Record the typed label for the image on screen and persist all labels.

        Does nothing when no image is on screen (e.g. the completion screen).
        """
        if not 0 <= self.index < len(images):
            return
        labels_data[images[self.index]] = self.text_input.text().strip()

        # Write to a temporary file and atomically swap it in, so a crash or
        # kill mid-write cannot truncate the existing labels file.
        tmp_path = OUTPUT_FILE + ".tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(labels_data, f, indent=2)
        os.replace(tmp_path, OUTPUT_FILE)

    def save_next(self):
        """Save the current label and advance to the next image."""
        self._save_current()
        self.index += 1
        self.show_image()

    def go_back(self):
        """Save the current label and return to the previous image."""
        self._save_current()
        self.index -= 1
        self.show_image()

if __name__ == "__main__":
    # Build the Qt application, open one labeler window, and hand control to
    # the event loop; the process exits with the loop's return code.
    qt_app = QApplication(sys.argv)

    labeler = MemeLabeler()
    labeler.resize(800, 800)
    labeler.show()

    exit_code = qt_app.exec()
    sys.exit(exit_code)
