import datetime
import queue

import paho.mqtt.client as mqtt
import psycopg2
from psycopg2 import sql
from psycopg2.extras import execute_values
from tzlocal import get_localzone
import threading
from queue import Queue
import signal
import sys


# Process-wide shutdown flag, raised by signal_handler when Ctrl+C arrives.
exit_event: threading.Event = threading.Event()


def signal_handler(sig, frame):
    """Handle SIGINT by announcing shutdown and raising ``exit_event``.

    ``sig`` and ``frame`` are supplied by the ``signal`` module and ignored.
    """
    shutdown_notice = "Ctrl+C received. Stopping threads..."
    print(shutdown_notice)
    exit_event.set()


class MqttListenerThread(threading.Thread):
    """Subscribe to device control topics and push readings onto a queue.

    Every received message is enqueued as ``(timestamp, topic, value)``:
    ``timestamp`` is an aware ``datetime`` in the local timezone and ``value``
    is the payload parsed as a float (0 when it is not a valid UTF-8 number).
    The network loop runs until ``exit_event`` is set.
    """

    def __init__(self, messages_queue: Queue, exit_event: threading.Event, *args,
                 host="192.168.74.120", port=80, keepalive=60, reconnect_delay=1.0,
                 **kwargs):
        """Create the MQTT client and connect to the broker.

        :param messages_queue: queue that receives ``(time, topic, value)`` tuples.
        :param exit_event: when set, ``run()`` stops its network loop.
        :param host: broker host (websocket transport).
        :param port: broker port.
        :param keepalive: MQTT keepalive interval in seconds.
        :param reconnect_delay: seconds to wait before each reconnect attempt.
        Remaining ``*args``/``**kwargs`` are forwarded to ``threading.Thread``.
        :raises OSError: if the initial connection to the broker fails.
        """
        super().__init__(*args, **kwargs)

        # Assign all state first so the object is fully initialised before
        # any network activity happens.
        self.messages_queue = messages_queue
        self.exit_event = exit_event
        self.reconnect_delay = reconnect_delay
        self.local_tz = get_localzone()

        self.mqttclient = mqtt.Client(callback_api_version=mqtt.CallbackAPIVersion.VERSION2,
                                      client_id='logger', transport='websockets')
        self.mqttclient.on_connect = self.on_connect
        self.mqttclient.on_message = self.on_message

        self.mqttclient.connect(host, port, keepalive)

    def run(self):
        """Drive the paho network loop until ``exit_event`` is set, then disconnect."""
        try:
            while not self.exit_event.is_set():
                rc = self.mqttclient.loop(timeout=1.0)
                if rc != mqtt.MQTT_ERR_SUCCESS:
                    # While disconnected, loop() returns an error immediately;
                    # back off instead of spinning, then try to reconnect.
                    self._reconnect()
        finally:
            self.mqttclient.disconnect()

    def _reconnect(self):
        """Wait ``reconnect_delay`` seconds (or until shutdown), then try one reconnect."""
        if self.exit_event.wait(self.reconnect_delay):
            return
        try:
            self.mqttclient.reconnect()
        except (OSError, ValueError) as err:
            print(f"MQTT reconnect failed: {err}")

    def on_connect(self, client, userdata, flags, reason_code, properties):
        """Report the connect result and subscribe to all device controls.

        Subscribing here (rather than once after ``connect``) means the
        subscription is re-established after every reconnect.
        """
        print(f"Connected with result code {reason_code}")

        client.subscribe("/devices/+/controls/+")

    def on_message(self, client, userdata, msg):
        """Queue ``(local time, topic, numeric value)`` for one received message.

        Payloads that are not valid UTF-8 or do not parse as a float are
        recorded as 0 instead of raising. paho re-raises callback exceptions
        from ``loop()`` by default, so an escaping exception would stop the
        listener thread.
        """
        try:
            n_value = float(msg.payload.decode('utf-8'))
        except (UnicodeDecodeError, ValueError):
            n_value = 0

        put_value = (datetime.datetime.now(tz=self.local_tz), msg.topic, n_value)
        self.messages_queue.put(put_value)


class SenderThread(threading.Thread):
    """Drain ``(time, topic, value)`` tuples from a queue into ``wb_values``.

    Rows are inserted in batches: a batch is flushed once ``batch_size`` rows
    have accumulated, or when the queue stays empty for ``poll_timeout``
    seconds. After ``exit_event`` is set, rows still buffered and anything
    left in the queue at that moment are written before the connection is
    closed.
    """

    def __init__(self, messages_queue: Queue, exit_event: threading.Event, *args,
                 batch_size=20, poll_timeout=0.5, **kwargs):
        """Open the database connection used by ``run()``.

        :param messages_queue: source of ``(time, topic, value)`` tuples.
        :param exit_event: when set, ``run()`` flushes remaining rows and stops.
        :param batch_size: number of buffered rows that triggers an insert.
        :param poll_timeout: seconds to wait on an empty queue before flushing.
        Remaining ``*args``/``**kwargs`` are forwarded to ``threading.Thread``.
        """
        super().__init__(*args, **kwargs)
        self.messages_queue = messages_queue
        self.exit_event = exit_event
        self.batch_size = batch_size
        self.poll_timeout = poll_timeout

        # NOTE(review): connection parameters are hard-coded; consider config/env.
        self.db_conn = psycopg2.connect(dbname='wb-log', user='wb',
                                        password='timescale', host='localhost', port=55432)

        self.db_conn.autocommit = True

        self.db_cursor = self.db_conn.cursor()

    def run(self):
        """Batch queued rows into the database until stopped, then flush and close."""
        local_buffer = []
        try:
            while not self.exit_event.is_set():
                try:
                    msg = self.messages_queue.get(timeout=self.poll_timeout)
                except queue.Empty:
                    # The queue went quiet: write whatever has accumulated.
                    self._flush(local_buffer)
                    continue

                local_buffer.append(msg)
                if len(local_buffer) >= self.batch_size:
                    self._flush(local_buffer)

            # Shutdown: do not drop rows that were buffered or still queued.
            local_buffer.extend(self._drain_queue())
            self._flush(local_buffer)
        finally:
            # Closing the connection also invalidates db_cursor.
            self.db_conn.close()

    def _flush(self, buffer):
        """Insert every row in ``buffer`` (if any), then empty it in place."""
        if not buffer:
            return
        insert_query = sql.SQL('INSERT INTO wb_values (time, topic, value) VALUES %s')
        # `with conn` wraps the insert in a transaction, so a batch that
        # execute_values splits into several statements is applied atomically.
        with self.db_conn:
            execute_values(self.db_cursor, insert_query, buffer)
        buffer.clear()

    def _drain_queue(self):
        """Return every item currently in the queue without blocking."""
        drained = []
        while True:
            try:
                drained.append(self.messages_queue.get_nowait())
            except queue.Empty:
                return drained


if __name__ == '__main__':
    signal.signal(signal.SIGINT, signal_handler)

    msg_queue = Queue()

    # The sender gets its own stop flag so it is told to finish only after
    # the MQTT listener (the only producer of msg_queue) has stopped.
    # With a shared flag, messages enqueued during shutdown could arrive
    # after the sender has already exited and be lost.
    sender_stop_event = threading.Event()

    sender_thread = SenderThread(messages_queue=msg_queue, exit_event=sender_stop_event)
    mqtt_thread = MqttListenerThread(messages_queue=msg_queue, exit_event=exit_event)

    sender_thread.start()
    mqtt_thread.start()

    # Shutdown order: stop the producer first, then let the consumer
    # finish writing whatever is left in the queue.
    mqtt_thread.join()
    sender_stop_event.set()
    sender_thread.join()

