#!/usr/bin/python3

# (C) 2019 Simon Funk simon@sifter.org

from datetime import datetime
import struct
from   time import time

from Btle import BtleDevice, UUID
from Units import convert_units

class AirthingsWavePlusBtle(BtleDevice):
    """Bluetooth LE interface to an Airthings Wave Plus radon / air-quality monitor.

    Commands are written to the "Access Control Point" characteristic
        (command_uuid) and their replies are awaited on that same
        characteristic (see exec_command()).  History records are collected
        from notifications on sensor_record_uuid (see get_history_raw()).
    """

    # Services:
    data_service_uuid   = UUID('b42e1c08-ade7-11e4-89d3-123b93f75cba')

    # Characteristics:
    current_values_uuid = UUID('b42e2a68-ade7-11e4-89d3-123b93f75cba')
    command_uuid        = UUID('b42e2d06-ade7-11e4-89d3-123b93f75cba')    # "Access Control Point" Characteristic
    sensor_record_uuid  = UUID('b42e2fc2-ade7-11e4-89d3-123b93f75cba')

    @staticmethod
    def poll(serial_number, mac_addy, hours=1):
        """Polls the specified device for everything of potential interest we
            know how to read, and returns the result as a nested dictionary.

        serial_number and mac_addy must identify the same device; the serial
            number is checked on connect as a safety measure.

        hours specifies how many hours of history to fetch.

        The result always has 'serial_number', 'mac_addy', 'time' and 'time_'
            (datetime version of 'time').  On success it also has 'current',
            'q2', 'q3' and 'hist'.  If anything fails part way, whatever was
            gathered so far is kept and the error text is stored under 'error'.
            wave.disconnect() is always called before returning.
        """
        d = {
            'serial_number': serial_number,
                 'mac_addy': mac_addy,
                     'time': time(),
        }
        AirthingsWavePlusBtle.add_datetimes(True, d, ['time'])

        wave = AirthingsWavePlusBtle(serial_number, mac_addy)

        try:
            d['current'] = wave.read_current_values()
            d[     'q2'] = wave.query2()    # Query1 is boring, so we'll skip right to this.
            d[     'q3'] = wave.query3()
            d[   'hist'] = wave.get_history(hours)
        except Exception as e:
            d[  'error'] = str(e)
        finally:
            wave.disconnect()

        return d

    def __init__(self, serial_number, mac_addy):
        """serial_number is the device's full serial number (an int);
            mac_addy is its Bluetooth MAC address.
        """
        BtleDevice.__init__(self, mac_addy)

        self.serial_number = serial_number

        self.enabled = False    # Have indications and notifications been enabled yet?

    def connect(self):
        """Connects to the device, resets any caches of things that might have changed while disconnected,
            and makes sure we're actually talking to the device we think we are.

        You probably *don't* need to call this manually unless you want to get initialization delay
            out of the way before you do your first query or issue your first command.

        Raises an Exception if the device's reported serial number doesn't match the
            last six digits of the one we were constructed with.
        """
        if BtleDevice.connect(self):    # This returns True only when we have actually just connected; False if we were already.

            self.enabled = False    # A fresh connection needs notifications/indications re-enabled.

            # Check the device's serial number matches so we know we're still talking to the right device:
            devices_sn  = self[self.serial_number_uuid].decode()
            expected_sn = str(self.serial_number)[-6:]
            if devices_sn != expected_sn:
                raise Exception("The device's serial number isn't what we expect! (%r!=%r)"%(devices_sn, expected_sn))

    @staticmethod
    def scan(serial_number, duration=5, dbg=False):
        """This looks through any advertising btle devices for a Wave matching the provided serial number...
        Must be run as root.
        Returns a matching ScanEntry object if found.  The .addr field of that is the MAC address.

        If dbg is True, every scanned device's properties are printed.
        """
        def match(props):
            if dbg:
                print("# Scan finds %s"%(props,))
            # props.get() may return None for devices with no manufacturer data;
            # parse_serial_number() handles that by returning None.
            return AirthingsWavePlusBtle.parse_serial_number(props.get("Manufacturer")) == serial_number

        return BtleDevice.scan(match, duration, True)

    def firmware_version(self):
        """Returns the device's firmware revision string (e.g. 'G-BLE-1.2.4')."""
        return self[self.firmware_revision_uuid].decode()    # This one line of code does actually poll the device...

    def read_current_values(self, metric=False, units=False, datetimes=True):
        """Queries the device for the most recent sample values of all of its
            sensors, and returns a dictionary mapping each sensor name to its value.

        If metric is True, radon is reported in Bq/m3 and temperature in C;
            otherwise in pCi/L and F respectively.

        If units is True, each sensor entry is a (value, units) tuple instead
            of just the value.

        A sensor whose raw value equals its all-bits-set marker (0xff / 0xffff)
            is treated as invalid and omitted.

        'raw' holds the raw bytes read.  'time' is the current time (as from
            time()), provided for convenience.  Note the sensor values may be
            out of date by up to five minutes relative to this time.  If
            datetimes is True, 'time_' is also added as a naive datetime in the
            local tz.

        Raises an Exception if the reply's data format version isn't 1.
        """
        raw  = self[self.current_values_uuid]
        now  = time()
        data = struct.unpack('<BBBBHHHHHHHH', raw)

        if data[0] != 1:
            raise Exception("Unknown data format version (%s)"%(data[0],))

        if metric:
            radon_units = 'Bq/m3'
            temp_units  = 'C'
        else:
            radon_units = 'pCi/L'
            temp_units  = 'F'

        d = { 'time':now, 'raw':raw }

        #                   key            raw value  bad value scale  raw units   output units
        self.import_value(d,    'humidity', data[1],   0xff, 0.5 ,   'pct',       'pct', units)
        self.import_value(d,'ambientlight', data[2],   0xff, 1   ,     '?',         '?', units)
        self.import_value(d,       'radon', data[4], 0xffff, 1   , 'Bq/m3', radon_units, units)
        self.import_value(d,    'radon-lt', data[5], 0xffff, 1   , 'Bq/m3', radon_units, units)
        self.import_value(d, 'temperature', data[6], 0xffff, 0.01,     'C',  temp_units, units)
        self.import_value(d,    'pressure', data[7], 0xffff, 0.02,  'mBar',      'mBar', units)
        self.import_value(d,         'co2', data[8], 0xffff, 1   ,   'ppm',       'ppm', units)
        self.import_value(d,         'voc', data[9], 0xffff, 1   ,   'ppb',       'ppb', units)

        # waves and mode seem to update immediately (not every 5 minutes like the rest).
        # None of these have known scaling or units, so they are passed through as-is.
        waves = data[ 3]      # Seems to count recent waves.
        mode  = data[10]      # Usually 0; 1 = recent waves? 2 = pairing taps?
        x3    = data[11]      # Maybe signal strength?  Free memory?  Or...???
        if units:
            # No trailing commas here: they would wrap each (value, units) pair in a 1-tuple.
            d['waves'] = (waves, '?')
            d[ 'mode'] = (mode , '?')
            d[   'x3'] = (x3   , '?')
        else:
            d['waves'] = waves
            d[ 'mode'] = mode
            d[   'x3'] = x3

        self.add_datetimes(datetimes, d, ['time'])

        return d

    def flash_light(self):
        """This tells the Wave+ to pulse its white ring light once.
        The pulse seems to begin, in practice, around when this
            method returns, and continues for a second or two.
        """
        self.exec_command(struct.pack('<BB', 0x67, 0x08))

    def query1(self):
        """This is one of the first things the app does.  Unclear
            what all of the returned information is.  Just guessing
            here at the struct packing.

        Except for the first 4 bytes this seems to return the same
            value from day to day and across multiple devices (not
            specific to device's particular serial number, etc).

        Currently, for firmware version G-BLE-1.2.4, I see:

        (0, 0, 0, 18, 17980, 6430, 100, 150, 250, 2000, 800, 1000)

        [Perhaps is it PPCP (Peripheral Preferred Connection Parameters)?
        Or do those always transact at a lower level in the bluetooth stack?]

        Raises an Exception on timeout or if the reply isn't 21 bytes.
        """
        # 0x10 <= 66 00
        data = self.exec_command(struct.pack('<BB', 0x66, 0x00))
        if data is None:
            raise Exception("Query1 Failed")

        # -> [66:00]01:00:00:12:00:3c:46:1e:19:64:00:96:00:fa:00:d0:07:20:03:e8:03
        if len(data) != 21:
            raise Exception("Unexpected reply to query1 (len=%d): %s"%(len(data), data.hex()))
        return struct.unpack('<bBB9H', data)

    def query2(self, datetimes=True):
        """This is the second major mystery query the app does.

        Returns a new dictionary with:

        cycle_start: This is the (approximate within a few seconds or less) absolute
            time (as returned by time()) when the current sampling cycle began.
            I have no idea when or why a new one starts, but they seem to last days
            if not weeks.  The main use of this is that the sample values all update
            relative to this time -- every 5 minutes for most, and every hour for
            radon.  Note this can be out of sync with the series_start from query3
            and history.

        time_elapsed: This is the exact time in seconds elapsed since cycle_start
            (the real cycle_start, not the approximate one we infer above), as
            reported by the device.

        ambientlight: The reply byte at offset 5 (presumably ambient light, as
            the key name suggests).

        raw: This is the raw bytes object the device returned, in case you want
            to try to figure out what the rest of it means...

        time: This is when we received the raw reply.

        If datetimes is True, then cycle_start_ and time_ are datetime versions
            of same.
        """
        return self.query2_parse(self.query2_raw(), datetimes)

    def query2_raw(self):
        """Returns a dict with the raw return from query2 ('raw') and the time
            the reply was received ('time').

        Useful if you want to archive the raw returns and parse them later.

        Raises an Exception on timeout or if the reply isn't 28 bytes.
        """
        # 0x10 <= 6d
        data = self.exec_command(struct.pack('<B', 0x6d))
        now  = time()
        if data is None:
            raise Exception("Query2 Failed")

        # -> [6d:00]0a:96:0b:00:02:26:88:41:01:00:00:00:00:00:40:4b:1c:00:b8:34:1e:00:c5:00:0a:0b:09:00
        if len(data) != 28:
            raise Exception("Unexpected reply to query2 (len=%d): %s"%(len(data), data.hex()))

        return { 'raw': data, 'time': now }

    @staticmethod
    def query2_parse(q2, datetimes=False):
        """Expands the result of query2_raw with more items broken down.  See query2()
            docs for details.

        If datetimes is True, this translates times into additional datetime objects.
            (This is less handy than you would think, since mostly we want
            the times in order to sync up with the device's cycle times,
            which means wanting to count seconds vs. time())

        ** NOTE this modifies q2 (a dict) and returns it. **
        """
        d = struct.unpack('<L24B', q2['raw'])    # 4-byte little-endian counter, then 24 single bytes.

        q2['time_elapsed'] = d[0]
        q2[ 'cycle_start'] = q2['time']-d[0]

        q2['ambientlight'] = d[2]

        AirthingsWavePlusBtle.add_datetimes(datetimes, q2, ['time', 'cycle_start'])

        return q2

    def send_time(self):
        """This sends the current time() to the device, which it
            presumably uses to adjust its internal clock.

        Use at your own risk...  I have no idea what the bad side
            effects of this may be.

        The app only does this after doing a query2, and then does
            another query2 afterward.

        This returns a 2-tuple of the time we sent, and the time
            it replied with.  I have no idea what the latter is
            but it's usually the same as the time we sent.  (Is
            it the time the device used to have, or is it just
            a confirmation of the time received, or..?)

        Raises an Exception on timeout or a malformed reply.
        """
        now = int(time())

        # 0x10 <= 71:52:01:00:00:00:3b:96:e2:5d
        data = self.exec_command(struct.pack('<BBLL',0x71,0x52,1,now))
        if data is None:
            raise Exception("send_time failed")

        # -> [71:00]52:01:00:00:00:3b:96:e2:5d
        if len(data) != 9:
            raise Exception("Unexpected reply to send_time (len=%d): %s"%(len(data), data.hex()))
        d = struct.unpack('<BLL', data)
        if d[0] != 0x52:
            raise Exception("Unexpected reply to send_time (%02x != 0x52): %s"%(d[0], data.hex()))
        if d[1] != 1:
            print("# WARNING: send_time returns %d where expecting 1"%(d[1],))

        return (now, d[2])

    def query3(self, datetimes=True):
        """This is the third query the app does, which is typically done
            after send_time and before a repeat of query2.

        Like query2, this also returns a start time ('series_start'), but this
            one is the long term start time of the entire sampling series.  I'm
            not sure if this is the last time the device was reset, or if it's
            the last time its location was changed, or what...  I suspect
            it's the last time it started its 7-day calibration window.
            Critically, the historical records are stored relative to this time.

        It also returns the number of records since that time ('num_records'),
            or possibly the number of the record it's currently working on; not
            sure...
        """
        return self.query3_parse(self.query3_raw(), datetimes)

    def query3_raw(self):
        """Like query2_raw: returns { 'raw': reply bytes, 'time': time received }.

        Raises an Exception on timeout, if the reply isn't 13 bytes, or if it
            doesn't start with 0x54.
        """

        # 0x10 <= 71:54:00
        data = self.exec_command(struct.pack('<BBB', 0x71, 0x54, 0x00))
        now  = time()
        if data is None:
            raise Exception("query3 failed")

        # -> [71:00]54:a9:8e:0d:5d:00:01:29:df:26:0f:ff:ff
        if len(data) != 13:
            raise Exception("Unexpected reply to query3 (len=%d): %s"%(len(data), data.hex()))
        if data[0] != 0x54:
            raise Exception("Unexpected reply to query3 (%02x != 0x54): %s"%(data[0], data.hex()))

        return { 'raw': data, 'time': now }

    @staticmethod
    def query3_parse(q3, datetimes=False):
        """Like query2_parse: adds 'series_start' and 'num_records' (and, if
            datetimes is True, 'time_' and 'series_start_') to q3 in place,
            and returns it.
        """

        d = struct.unpack('<BL4H', q3['raw'])    # Echoed 0x54, 4-byte start time, then four 16-bit words.

        q3['series_start'] = d[1]
        q3[ 'num_records'] = d[4]

        AirthingsWavePlusBtle.add_datetimes(datetimes, q3, ['time', 'series_start'])

        return q3

    def get_history(self, hours=1, datetimes=True):
        """Fetches the most recent specified number of hours of historic data.

        Returns a list of dicts (see parse_hour_block), oldest first.

        May be less than hours long if an error was encountered mid way.
        """
        return [self.parse_hour_block(hour, datetimes) for hour in self.get_history_raw(hours)]

    def get_history_raw(self, hours=1):
        """Fetches the last this-many hours of historic data.

        Returns a list of raw blocks, oldest first.

        If an error occurs mid process, a warning is printed and the returned list
            may be shorter than requested.

        In that case, beware the ongoing state may be defunct (the connection may have
            been lost or the state may be out of sync -- expect subsequent activity to fail...).
        """
        rec_handle = self.get_handle(self.sensor_record_uuid)   # Cached for later in case we have to recover from a major meltdown but still want to keep what we got so far...
        oob        = []                                         # The returned (handle, data_block) tuples get stuffed in here.
        expected   = None                                       # Hour count the device's reply announces, once known.
        error      = None                                       # First exception hit, if any; reported after collecting what we got.

        try:
            # 0x10 <= 01:02:00:00:00:0d:00:00:00
            data = self.exec_command(struct.pack('<BHHHH', 0x01, 2, 0, hours, 0), oob=oob)
            if data is None:
                raise Exception("get_history: exec_command returned None.")

            # -> [01:00]0d:00:00:00
            if len(data) != 4:
                raise Exception("Unexpected reply to get_history(len=%d): %s"%(len(data), data.hex()))

            expected = struct.unpack('<L', data)[0]
            if expected > hours:
                raise Exception("Unexpected reply to get_history (%d>%d): %s"%(expected, hours, data.hex()))

        except Exception as e:  # Catching our own exceptions but also any thrown by exec_command...
            error = e           # Deliberately best-effort: keep whatever records arrived before the failure.

        hist = [raw for handle, raw in oob if handle == rec_handle] # Should be all of them, but just in case...

        if error is not None:
            print("# WARNING: Error encountered during history fetch (%s).  Got %d/%d hours."%(error, len(hist), hours))

        # Only meaningful if we actually got a well-formed reply telling us how many to expect.
        if expected is not None and expected != len(hist):
            print("# WARNING: Got %d hours but command reply implies %d."%(len(hist), expected))

        return hist

    @staticmethod
    def parse_hour_block(raw, datetimes=False):
        """This parses a single hour of raw history as returned by get_history_raw.

        raw must be at least 230 bytes long (otherwise struct.error is raised),
            laid out as parsed below:
              bytes   0-103: 8 unknown words, 2 radon words, then 12 samples each of
                             temperature (16 bit), humidity (8 bit), pressure and co2 (16 bit);
              bytes 104-199: 12 groups of (x4: 16 bit, voc: 16 bit, x3: 32 bit);
              bytes 200-229: 12 light bytes, 6 unknown bytes, the 32-bit series start
                             time, 3 unknown words and the 16-bit record number.

        Returns a dict of the decoded values.  'start_time' (series_start + recno
            hours) is the start of this hour.  If datetimes is True, 'series_start_'
            and 'start_time_' datetime versions are added too.
        """
        parts = struct.unpack('<8H HH 12H 12B 12H 12H', raw[:104])

        unused = [parts[0: 8]]     # Stuff we didn't parse, so we can examine it to try to figure out what it is...
        radon  = parts[ 8:10]
        temp   = parts[10:22]
        hum    = parts[22:34]
        pres   = parts[34:46]
        co2    = parts[46:58]

        parts = [struct.unpack('<2HL', raw[104+i*8:112+i*8]) for i in range(12)]

        x4  = [p[0] for p in parts]
        voc = [p[1] for p in parts]
        x3  = [p[2] for p in parts]

        parts = struct.unpack('< 12B 6B L 4H', raw[200:230])

        light = parts[0:12]
        unused.append(parts[12:18])
        tim   = parts[18]
        unused.append(parts[19:22])
        recno = parts[22]

        d = {
                'radon'       : radon,
                'temperature' : [(t-27315)/100 for t in temp],  # Raw presumably in hundredths of a kelvin -> C.
                'humidity'    : [h/2           for h in hum],
                'pressure'    : [p/50          for p in pres],
                'co2'         : co2,
                'voc'         : voc,
                'ambientlight': light,

                'x3'          : [x/256         for x in x3], # Not really sure if that low order byte belongs or not
                'x4'          : x4,

                'recno'       : recno,
                'series_start': tim,                        # When the current major history record began.
                'start_time'  : tim + recno*3600,           # The start time of this hour.

                'unused'      : unused,
            }

        AirthingsWavePlusBtle.add_datetimes(datetimes, d, ['series_start', 'start_time'])

        return d

    #==== Unlikely you should call anything below here... ====

    def import_value(self, dest, key, raw, bad_value, scale, raw_units, out_units, show_units):
        """This essentially implements "dest[key] = raw*scale" except:

            If raw == bad_value, then nothing is done.

            If out_units is different from raw_units, then raw*scale (in raw_units) is converted to out_units.

            If show_units is True, then dest[key] gets (value, out_units) tuple instead of just value.
        """
        if raw == bad_value:
            return
        value = raw*scale
        if out_units != raw_units:
            value = convert_units(value, raw_units, out_units)
        if show_units:
            dest[key] = (value, out_units)
        else:
            dest[key] = value

    def exec_command(self, cmd, timeout=10, oob=None):
        """This issues a command by sending the specified byte array (cmd)
            to the "Access Control Point" port, and then waits for a reply Indication
            from that same port (which is presumed to start with the same first byte as
            the command).

        Returns the payload portion of the reply (byte array, i.e. the reply minus
            its first two bytes), or None on timeout.

        The timeout is from the time we start waiting for the reply and doesn't include
            the command issuing time since that could vary.

        oob can be a list, in which case any out-of-band notifications will be
            appended to it, in order, as (handle, data) pairs.  Otherwise they
            are printed and discarded.

        Raises an Exception if the reply's second byte is nonzero (an error status)
            or the reply is shorter than two bytes.
        """
        self.enable_command_port()

        self[self.command_uuid] = cmd

        cmd_handle = self.get_handle(self.command_uuid)

        start = time()
        while True:
            timeleft = timeout - (time()-start)
            if timeleft <= 0:
                return None
            handledata = self.wait_for_notification(timeleft)
            if handledata is None:
                return None
            handle, data = handledata

            if handle == cmd_handle:
                if data and data[0] == cmd[0]:
                    # Reply layout: [echoed command byte][status byte][payload...]
                    if len(data) < 2 or data[1]:
                        raise Exception("Exec(%s) -> %s"%(cmd.hex(), data.hex()))
                    return data[2:]
                else:
                    print("# UNEXPECTED NOTIFICATION: %s (%s)"%(data.hex(), data))
            else:
                if oob is None:
                    print("# NON-COMMAND NOTIFICATION: handle=%s data=%s"%(handle, data.hex()))
                else:
                    oob.append(handledata)

    def enable_command_port(self):
        """Connects if necessary, then enables notifications and indications on the
            command and sensor record characteristics, but only if that hasn't
            already been done this connection.

        Generally you would not need to call this directly.  It is used by exec_command().
        """
        self.connect()
        if not self.enabled:
            self.enable(self.command_uuid)
            self.enable(self.sensor_record_uuid)
            self.enabled = True

    @staticmethod
    def parse_serial_number(man_data):
        """Parses the serial number from the manufacturer specific data for an Airthings Wave Plus...

        man_data is expected to be a hex string whose first two bytes are 34:03,
            followed by the serial number as four little-endian bytes.

        Returns the serial number as an int, or None if man_data is missing or not
            a string (e.g. a scanned device with no manufacturer data), doesn't look
            like Airthings data, or can't be parsed.
        """
        if not isinstance(man_data, str) or not man_data.startswith('3403') or len(man_data) < 12:
            return None
        try:
            return struct.unpack('<L', bytearray.fromhex(man_data)[2:6])[0]
        except (ValueError, struct.error):    # Malformed hex, or embedded whitespace leaving too few bytes.
            print("# Warning: Couldn't parse Airthings manufacturer's data. (Debug this?)")
            return None

    def encode_serial_number(self):
        """This returns our serial number as it would appear in manufacturer's data:
            '3403' followed by the serial number's four little-endian bytes in hex.
        """
        # Byte-swap so the hex digits come out in little-endian byte order, and zero-pad
        # to 8 digits: otherwise a least-significant serial byte < 0x10 would lose digits.
        return "3403" + "%08x"%struct.unpack('<L',struct.pack('>L',self.serial_number))[0]

    @staticmethod
    def add_datetimes(doit, d, keys):
        """If doit is True, for each key f in keys adds d[f + "_"] =
            datetime.fromtimestamp(d[f]) (a naive local-time datetime).
            Modifies d in place; does nothing if doit is False.
        """
        if doit:
            for f in keys:
                d[f+"_"] = datetime.fromtimestamp(d[f])


if __name__ == "__main__":
    # Command-line entry point: poll one device and pretty-print everything read from it.
    from pprint import pprint
    from sys import argv

    if len(argv) != 3:
        print("Use: %s <serial_number> <mac_addy>"%(argv[0],))
    else:
        _, serial_arg, mac_arg = argv
        pprint(AirthingsWavePlusBtle.poll(int(serial_arg), mac_arg))


