#!/usr/bin/python3
#
# SPDX-License-Identifier: Apache-2.0
#

"""
This module provides functionality to connect to and communicate with a remote
host through a TCP connection on localhost.
"""

import codecs
import re
import socket
from sys import stdout
import time

import streamexpect

from utils.install_log import LOG


def connect(hostname, port=10000, prefix=""):
    """
    Open a TCP connection to ('localhost', port) for the given host.

    A socket path of the form /tmp/[<prefix>_]<hostname> (with a '_serial'
    suffix when hostname contains 'controller-0') is derived, but it is only
    used in the log message; the connection itself is always TCP to localhost.

    Arguments:
    - hostname: name of the target, e.g. controller-0
    - port: local TCP port to connect to
    - prefix: optional prefix of the derived socket name (logging only)

    Returns:
    - a connected, non-blocking socket, or None if any step failed.
    """

    if prefix:
        prefix = f"{prefix}_"

    socketname = f"/tmp/{prefix}{hostname}"
    if 'controller-0' in hostname:
        socketname += '_serial'

    LOG.info("Connecting to %s at %s", hostname, socketname)
    # Bind before the try so the cleanup in the handler works even when
    # creating the socket itself is what failed.
    sock = None
    try:
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
        sock.connect(('localhost', port))

        # TODO (WEI): double check this # pylint: disable=fixme
        sock.setblocking(False)

    except Exception as exc:
        LOG.info("Failed sock connection")
        LOG.debug("Error:\n%s\n", repr(exc))
        if sock is not None:
            sock.close()
            sock = None

    return sock


def disconnect(sock):
    """
    Shut down and close a socket.

    Arguments:
    - sock: socket to disconnect, e.g. one returned by connect()

    The socket is closed even if shutdown() raises (e.g. OSError when the
    socket is no longer connected); that error is still propagated.
    """

    LOG.info("Disconnecting from socket")
    try:
        sock.shutdown(socket.SHUT_RDWR)
    finally:
        # Always release the file descriptor, even if shutdown failed.
        sock.close()


# pylint: disable=too-many-arguments, too-many-locals, too-many-branches
def get_output(stream, cmd, prompts=None, timeout=5, log=True, as_lines=True, flush=True):
    """
    Execute a command and collect its output until a shell prompt appears.

    Make sure no other command is executing and that 'dmesg -D' was executed
    beforehand; otherwise unrelated console output ends up in the result.

    Arguments:
    - stream: object supporting poll(), sendall(), recv(), gettimeout() and
      settimeout() (all of which are called below)
    - cmd: command line to send; a trailing newline is appended
    - prompts: regular expressions marking the end of the output; defaults to
      common user, root and keystone shell prompts
    - timeout: seconds to wait for a prompt
    - log: log every complete line received at INFO level
    - as_lines: return a list of complete lines instead of the raw text
    - flush: drain pending input before sending cmd

    Returns:
    - when as_lines is True, the list of complete lines received after cmd
      was sent (trailing carriage returns stripped; the final unterminated
      line, normally the prompt, is not included); otherwise all decoded
      text received.

    Raises:
    - streamexpect.ExpectTimeout if no prompt is seen within timeout, or if
      the peer closes the connection before a prompt arrives.
    """
    # pylint: disable=fixme
    # TODO: Not tested, will not work if kernel or other processes throw data on stdout or stderr

    poll_period = 0.1
    max_read_buffer = 1024
    if not prompts:
        # '$', '(', ')' and ']' are regex metacharacters; an unescaped '$' is
        # an end-of-text anchor, after which a literal space can never match.
        prompts = [r':~\$ ', r':~# ', r':/home/wrsroot# ',
                   r'\(keystone_.*\)\]\$ ', r'\(keystone_.*\)\]# ']

    # Flush buffers
    if flush:
        try:
            trash = stream.poll(1)  # flush input buffers
            if trash:
                try:
                    LOG.debug("Buffer has bytes before cmd execution: %s",
                              trash.decode('utf-8'))
                except Exception as exc:
                    LOG.debug("Failed decoding buffer\nError: %s\n", repr(exc))
        except streamexpect.ExpectTimeout as exc:
            LOG.debug("Failed flushing buffer\nError: %s\n", repr(exc))

    # Send command
    stream.sendall(f"{cmd}\n".encode('utf-8'))

    # Get response
    patterns = [re.compile(prompt) for prompt in prompts]
    # Incremental decoding keeps multi-byte UTF-8 characters intact when they
    # are split across two recv() chunks.
    decoder = codecs.getincrementaldecoder('utf-8')(errors='replace')
    data = ""
    line_buf = ""
    lines = []

    # Monotonic clock so wall-clock adjustments cannot stretch or cut the wait.
    end_time = time.monotonic() + float(timeout)
    prev_timeout = stream.gettimeout()
    stream.settimeout(poll_period)
    try:
        while time.monotonic() <= end_time:
            try:
                incoming = stream.recv(max_read_buffer)
            except socket.timeout:
                # Nothing arrived within poll_period; re-check the deadline.
                continue
            if not incoming:
                # An empty recv() result means the peer closed the
                # connection, so no prompt can arrive any more.
                break

            text = incoming if isinstance(incoming, str) else decoder.decode(incoming)
            data += text

            # Emit every completed line; keep the unterminated tail buffered.
            *complete, line_buf = (line_buf + text).split('\n')
            for raw_line in complete:
                line = raw_line.rstrip('\r')
                if log:
                    LOG.info(line)
                lines.append(line)

            if any(pattern.search(data) for pattern in patterns):
                return lines if as_lines else data
        raise streamexpect.ExpectTimeout()
    finally:
        stream.settimeout(prev_timeout)


def expect_bytes(stream, text, timeout=180, fail_ok=False, flush=True, log=True):
    """
    Wait until *text* is seen on *stream*.

    Arguments:
    - stream: object providing expect_bytes() and poll()
    - text: value to wait for; it is str-formatted and UTF-8 encoded
    - timeout: seconds to wait for the text
    - fail_ok: return -1 instead of raising when the text is not found
    - flush: after a match, drain pending input and echo it to stdout
    - log: log the wait and the match

    Returns 0 when the text was found, or -1 when it was not and fail_ok is
    set. Raises streamexpect.ExpectTimeout when the text is not found and
    fail_ok is False; any other error from the stream is logged and re-raised.
    """

    # Give the remote side a moment before starting to wait.
    time.sleep(1)

    if log:
        if timeout < 60:
            LOG.info("Expecting text within %s seconds: %s\n", timeout, text)
        else:
            LOG.info("Expecting text within %s minutes: %s\n", timeout / 60, text)

    try:
        stream.expect_bytes(f"{text}".encode('utf-8'), timeout=timeout)
    except streamexpect.ExpectTimeout:
        if not fail_ok:
            stdout.write('\n')
            LOG.error("Did not find expected text")
            raise
        return -1
    except Exception as exc:
        LOG.debug("Failed connection\nError: %s\n", repr(exc))
        raise

    stdout.write('\n')
    if log:
        LOG.debug("Found expected text: %s", text)

    time.sleep(1)
    if not flush:
        return 0

    try:
        leftover = stream.poll(1)  # drain whatever arrived after the match
    except streamexpect.ExpectTimeout as exc:
        LOG.debug("Failed flushing buffer\nError: %s\n", repr(exc))
        return 0

    if leftover:
        leftover += b'\n'
        try:
            if log:
                LOG.debug(">>> expect_bytes: Buffer has bytes!")
            # Echo to stdout; original note: "streamexpect hardcodes it".
            stdout.write(leftover.decode('utf-8'))
        except Exception as exc:
            LOG.debug("Failed decoding buffer\nError: %s\n", repr(exc))
    return 0


# pylint: disable=inconsistent-return-statements
def send_bytes(stream, command, fail_ok=False, expect_prompt=True,
               prompt=None, timeout=180, send=True, flush=True, log=True):
    """
    Send a command to the stream and optionally wait for a prompt.

    Arguments:
    - stream: object providing poll(), sendall() and expect_bytes()
    - command: text to send, UTF-8 encoded
    - fail_ok: return -1 instead of raising when the prompt is not found
    - expect_prompt: when False, return right after sending
    - prompt: text to wait for; when None, wait for "~$" and fall back to a
      prompt containing "keystone"
    - timeout: seconds passed to each expect_bytes() wait
    - send: append a newline so the command is executed
    - flush: drain pending input (echoing it to stdout) before sending
    - log: log the command being sent

    Returns 0 on success, or -1 when the prompt was not found and fail_ok is
    set. Raises streamexpect.ExpectTimeout when the prompt is not found and
    fail_ok is False (after sending "exit"); other errors are logged and
    re-raised.
    """

    time.sleep(1)
    if flush:
        try:
            incoming = stream.poll(1)  # flush input buffers
            if incoming:
                incoming += b'\n'
                try:
                    LOG.debug(">>> send_bytes: Buffer has bytes!")
                    stdout.write(incoming.decode('utf-8'))  # streamexpect hardcodes it
                except Exception as exc:
                    LOG.debug("Failed decoding buffer\nError: %s\n", repr(exc))
        except streamexpect.ExpectTimeout as exc:
            LOG.debug("Failed flushing buffer\nError: %s\n", repr(exc))

    if log:
        LOG.info("Sending command: %s", command)
    try:
        if send:
            stream.sendall(f"{command}\n".encode('utf-8'))
        else:
            stream.sendall(f"{command}".encode('utf-8'))
        if expect_prompt:
            time.sleep(1)
            if prompt:
                return expect_bytes(stream, prompt, timeout=timeout, fail_ok=fail_ok)

            if expect_bytes(stream, "~$", timeout=timeout, fail_ok=True) != 0:
                # No "~$" seen: send blank line(s), presumably to get a fresh
                # prompt, and accept one containing 'keystone' instead.
                send_bytes(stream, '\n', expect_prompt=False)
                expect_bytes(stream, 'keystone', timeout=timeout)
                # Fall through so success is reported as 0 like other paths.
    except streamexpect.ExpectTimeout:
        if fail_ok:
            return -1

        LOG.error("Failed to send command, logging out.")
        # The trailing newline is required for the shell to actually run it.
        stream.sendall("exit\n".encode('utf-8'))
        raise
    except Exception as exc:
        LOG.error("Connection failed")
        LOG.debug("Failed sending command\nError: %s\n", repr(exc))
        raise

    return 0