# coding: utf-8
# vim: ts=4 sw=4 et ai:
from __future__ import print_function, unicode_literals

"a"


import logging
import os
import random
import time

from .TftpPacketTypes import *
from .TftpShared import *

if TYPE_CHECKING:
    from .TftpContexts import TftpContext


log = logging.getLogger("partftpy.TftpStates")



class TftpState(object):
    """Base class for every state of a TFTP session's state machine.

    A state wraps the session ``context`` and processes one incoming packet
    at a time via :meth:`handle`, which returns the state to use for the next
    packet (``None`` once the transfer is complete). The helpers below build
    and send DAT/ACK/ERR/OACK packets through ``self.context.sock`` and keep
    ``self.context.last_pkt`` and ``self.context.metrics`` up to date.
    """

    def __init__(self, context):
        """Attach this state to *context*, the session it drives."""
        self.context = context

    def handle(self, pkt, raddress, rport):
        """Process *pkt* received from ``raddress:rport``; return the next state.

        Abstract: every concrete state must override this.
        """
        raise NotImplementedError("Abstract method")

    def handleOACK(self, pkt):
        """Adopt the options acknowledged by the peer in an OACK packet.

        When ``pkt.match_options(self.context.options)`` accepts them, the
        OACK's options replace ``self.context.options`` and a non-empty
        ``tsize`` is copied into ``self.context.metrics.tsize``.

        Raises:
            TftpException: if the OACK carries no options, or the options do
                not match what was requested.
        """
        if not pkt.options:
            raise TftpException("No options found in OACK")

        if not pkt.match_options(self.context.options):
            log.error("Failed to negotiate options")
            raise TftpException("Failed to negotiate options")

        log.info("Successful negotiation of options")
        self.context.options = pkt.options
        tsize = pkt.options.get("tsize")
        if tsize:
            self.context.metrics.tsize = tsize
        for k, v in self.context.options.items():
            log.info("    %s = %s", k, v)

    def returnSupportedOptions(self, options):
        """Filter a client's requested *options* down to the ones we support.

        Supported options:
          * ``blksize`` -- parsed as an int and clamped to
            [MIN_BLKSIZE, MAX_BLKSIZE]; a non-numeric value is dropped.
          * ``tsize``   -- accepted with a placeholder value of 0 (the real
            size is filled in later by the caller).
        Options with a blank value and any other option are dropped.

        Returns:
            dict: the accepted options.
        """
        accepted_options = {}
        for option, optval in options.items():
            if optval == "":
                log.info("Option ignored due to blank value: '%s'", option)
            elif option == "blksize":
                # Parse once; a malformed value from the wire must not crash
                # the session with a ValueError.
                try:
                    blksize = int(optval)
                except (TypeError, ValueError):
                    log.warning("Ignoring invalid blksize value: %r", optval)
                    continue
                if blksize > MAX_BLKSIZE:
                    log.info(
                        "Client requested blksize greater than %d; setting to maximum",
                        MAX_BLKSIZE,
                    )
                    blksize = MAX_BLKSIZE
                elif blksize < MIN_BLKSIZE:
                    log.info(
                        "Client requested blksize less than %d; setting to minimum",
                        MIN_BLKSIZE,
                    )
                    blksize = MIN_BLKSIZE
                # Always store an int: sendDAT/handleDat use it as a byte count.
                accepted_options[option] = blksize
            elif option == "tsize":
                log.debug("tsize option is set")
                accepted_options["tsize"] = 0
            else:
                log.info("Dropping unsupported option '%s'", option)
        log.debug("Returning these accepted options: %s", accepted_options)
        return accepted_options

    def sendDAT(self):
        """Read the next block from ``context.fileobj`` and send it as a DAT.

        The block is numbered ``context.next_block`` and is at most
        ``context.options["blksize"]`` bytes long. The packet is recorded in
        ``context.last_pkt`` and passed to ``context.packethook`` if one is set.

        Returns:
            bool: True if the block read was shorter than blksize (end of file,
            so this is the final DAT), otherwise False.
        """
        finished = False
        blocknumber = self.context.next_block

        # Debug hook: stall when sending a specific block (presumably to
        # exercise peer timeouts in tests).
        if DELAY_BLOCK and DELAY_BLOCK == blocknumber:
            log.debug("Deliberately delaying 10 seconds...")
            time.sleep(10)
        blksize = self.context.options["blksize"]
        buffer = self.context.fileobj.read(blksize)
        log.debug("Read %d bytes into buffer", len(buffer))
        if len(buffer) < blksize:
            log.info("Reached EOF on file %s", self.context.file_to_transfer)
            finished = True
        dat = TftpPacketDAT()
        dat.data = buffer
        dat.blocknumber = blocknumber
        self.context.metrics.bytes += len(dat.data)
        self.context.metrics.packets += 1

        # Test hook: with NETWORK_UNRELIABILITY=N, about 1 in N packets is not
        # sent, yet is still counted and recorded as last_pkt so that the
        # retransmission path gets exercised.
        if NETWORK_UNRELIABILITY > 0 and random.randrange(NETWORK_UNRELIABILITY) == 0:
            log.warning("Skipping DAT packet %d for testing", dat.blocknumber)
        else:
            log.debug("Sending DAT packet %d", dat.blocknumber)
            self.context.sock.sendto(
                dat.encode().buffer, (self.context.host, self.context.tidport)
            )
            self.context.metrics.last_dat_time = time.time()
        if self.context.packethook:
            self.context.packethook(dat, self.context)
        self.context.last_pkt = dat
        return finished

    def sendACK(self, blocknumber=None):
        """Send an ACK for *blocknumber* (default: ``context.next_block``).

        The packet is recorded in ``context.last_pkt`` even when the
        NETWORK_UNRELIABILITY test hook suppresses the actual send.
        """
        log.debug("In sendACK, passed blocknumber is %s", blocknumber)
        if blocknumber is None:
            blocknumber = self.context.next_block
        log.debug("Sending ack to block %d", blocknumber)
        ackpkt = TftpPacketACK()
        ackpkt.blocknumber = blocknumber

        if NETWORK_UNRELIABILITY > 0 and random.randrange(NETWORK_UNRELIABILITY) == 0:
            log.warning("Skipping ACK packet %d for testing", ackpkt.blocknumber)
        else:
            self.context.sock.sendto(
                ackpkt.encode().buffer, (self.context.host, self.context.tidport)
            )
        self.context.last_pkt = ackpkt

    def sendError(self, errorcode):
        """Send an ERR packet with *errorcode* to the session peer.

        Nothing is sent when the session has no TID port yet; the packet is
        still recorded in ``context.last_pkt``.
        """
        log.debug("In sendError, being asked to send error %d", errorcode)
        errpkt = TftpPacketERR()
        errpkt.errorcode = errorcode
        if self.context.tidport is None:
            # We are the sender here; there is no peer port to send to yet.
            log.debug("No session TID port yet; not sending error packet")
        else:
            self.context.sock.sendto(
                errpkt.encode().buffer, (self.context.host, self.context.tidport)
            )
        self.context.last_pkt = errpkt

    def sendOACK(self):
        """Send an OACK carrying ``context.options`` to the session peer."""
        log.debug("In sendOACK with options %s", self.context.options)
        pkt = TftpPacketOACK()
        pkt.options = self.context.options
        self.context.sock.sendto(
            pkt.encode().buffer, (self.context.host, self.context.tidport)
        )
        self.context.last_pkt = pkt

    def resendLast(self):
        """Retransmit ``context.last_pkt`` and update the resend metrics.

        The packet goes to ``context.tidport``, falling back to
        ``context.port`` when no TID port has been learned yet.
        """
        assert self.context.last_pkt is not None
        pkt = self.context.last_pkt
        log.warning("Resending packet %s on session %s", pkt, self)
        # Encode before measuring: a packet suppressed by the
        # NETWORK_UNRELIABILITY hook in sendDAT/sendACK was never encoded,
        # so its .buffer may not have been populated yet.
        buf = pkt.encode().buffer
        self.context.metrics.resent_bytes += len(buf)
        self.context.metrics.resent_packets += 1
        self.context.metrics.add_dup(pkt)
        sendto_port = self.context.tidport or self.context.port
        self.context.sock.sendto(buf, (self.context.host, sendto_port))
        if self.context.packethook:
            self.context.packethook(pkt, self.context)

    def handleDat(self, pkt):
        """Handle an incoming DAT packet.

        * In-sequence block: ACK it, advance ``next_block``, write the payload
          to ``context.fileobj`` and update metrics. A payload shorter than
          blksize marks the final block and None is returned.
        * Earlier block: recorded as a duplicate and re-ACKed. Block 0 is
          never valid and aborts with IllegalTftpOp.
        * Later block: raises TftpException.

        Returns:
            A new TftpStateExpectDAT, or None when the transfer is complete.
        """
        log.debug("Handling DAT packet - block %d", pkt.blocknumber)
        log.debug("Expecting block %s", self.context.next_block)
        if pkt.blocknumber == self.context.next_block:
            log.debug("Good, received block %d in sequence", pkt.blocknumber)

            self.sendACK()
            self.context.next_block += 1

            log.debug("Writing %d bytes to output file", len(pkt.data))
            self.context.fileobj.write(pkt.data)
            self.context.metrics.bytes += len(pkt.data)
            self.context.metrics.packets += 1

            if len(pkt.data) < self.context.options["blksize"]:
                log.info("End of file detected")
                return None

        elif pkt.blocknumber < self.context.next_block:
            if pkt.blocknumber == 0:
                log.warning("There is no block zero!")
                self.sendError(TftpErrors.IllegalTftpOp)
                raise TftpException("There is no block zero!")
            log.warning("Dropping duplicate block %d", pkt.blocknumber)
            self.context.metrics.add_dup(pkt)
            # Our earlier ACK may have been lost; re-ACK so the peer moves on.
            log.debug("ACKing block %d again, just in case", pkt.blocknumber)
            self.sendACK(pkt.blocknumber)

        else:
            msg = "Whoa! Received future block %d but expected %d" % (
                pkt.blocknumber,
                self.context.next_block,
            )
            log.error(msg)
            raise TftpException(msg)

        return TftpStateExpectDAT(self.context)


class TftpServerState(TftpState):
    """Common base for the server states that react to an initial RRQ/WRQ."""

    def __init__(self, context):
        TftpState.__init__(self, context)
        # Absolute path of the requested file inside the server root; set by
        # serverInitial().
        self.full_path = None

    def serverInitial(self, pkt, raddress, rport):
        """Perform the request handling shared by RRQ and WRQ.

        * Learns the peer's TID port from *rport* if none is set yet.
        * Resets ``context.options`` to ``{"blksize": DEF_BLKSIZE}`` and merges
          in whichever requested options we support.
        * Resolves ``pkt.filename`` against ``context.root`` into
          ``self.full_path`` and stores the raw name in
          ``context.file_to_transfer``.

        Returns:
            bool: True if the client requested options (an OACK should be
            sent), otherwise False.

        Raises:
            TftpException: if the request comes from an unexpected
                host/port, or the path resolves outside the server root.
        """
        options = pkt.options
        sendoack = False
        if not self.context.tidport:
            self.context.tidport = rport
            log.debug("Setting tidport to %s", rport)

        log.debug("Setting default options, blksize")
        self.context.options = {"blksize": DEF_BLKSIZE}

        if options:
            log.debug("Options requested: %s", options)
            supported_options = self.returnSupportedOptions(options)
            self.context.options.update(supported_options)
            sendoack = True

        if pkt.mode != "octet":
            log.warning("Received non-octet mode request. I'll reply with binary data.")

        if self.context.host != raddress or self.context.port != rport:
            self.sendError(TftpErrors.UnknownTID)
            log.error(
                "Expected traffic from %s:%s but received it from %s:%s instead.",
                self.context.host,
                self.context.port,
                raddress,
                rport,
            )
            # Returning here would hand a truthy non-bool back to callers that
            # treat the result as the send-OACK flag and then open
            # self.full_path, which is still None. Abort cleanly instead.
            raise TftpException(
                "Request from unexpected TID %s:%s" % (raddress, rport)
            )

        log.debug("Requested filename is %s", pkt.filename)

        if pkt.filename.startswith(self.context.root):
            full_path = pkt.filename
        else:
            full_path = os.path.join(self.context.root, pkt.filename.lstrip("/"))

        self.full_path = os.path.abspath(full_path)
        log.debug("full_path is %s", full_path)
        # os.path.join(x, "") appends a separator only if x lacks one, so a
        # root such as "/" yields "/" instead of the never-matching "//".
        root_prefix = os.path.join(os.path.normpath(self.context.root), "")
        if self.full_path.startswith(root_prefix):
            log.debug("requested file is in the server root - good")
        else:
            log.warning("requested file is not within the server root - bad")
            self.sendError(TftpErrors.IllegalTftpOp)
            raise TftpException("bad file path")

        self.context.file_to_transfer = pkt.filename

        return sendoack


class TftpStateServerRecvRRQ(TftpServerState):
    """Server state for a just-received RRQ (the client wants to download)."""

    def handle(self, pkt, raddress, rport):
        """Open the requested file and start the download.

        The file at ``self.full_path`` is opened for reading. If it does not
        exist, ``context.dyn_file_func`` (when set) is asked for a file-like
        object instead. When options were requested an OACK is sent;
        otherwise block 1 is sent immediately.

        Returns:
            TftpStateExpectACK for the new session.

        Raises:
            TftpException: the file is missing or cannot be opened (an ERR
                packet is sent to the client first).
        """
        log.debug("In TftpStateServerRecvRRQ.handle")
        sendoack = self.serverInitial(pkt, raddress, rport)
        path = self.full_path
        log.info("Opening file %s for reading", path)
        if os.path.exists(path):
            # exists() is also true for directories and unreadable files;
            # report those to the client instead of escaping with OSError.
            try:
                self.context.fileobj = open(path, "rb")
            except (IOError, OSError) as err:
                log.warning("Could not open %s for reading: %s", path, err)
                self.sendError(TftpErrors.AccessViolation)
                raise TftpException("Could not open file for reading: %s" % (path,))
        elif self.context.dyn_file_func:
            log.debug("No such file %s but using dyn_file_func", path)
            self.context.fileobj = self.context.dyn_file_func(
                self.context.file_to_transfer, raddress=raddress, rport=rport
            )

            if self.context.fileobj is None:
                log.debug("dyn_file_func returned 'None', treating as " "FileNotFound")
                self.sendError(TftpErrors.FileNotFound)
                raise TftpException("File not found: %s" % path)
        else:
            log.warning("File not found: %s", path)
            self.sendError(TftpErrors.FileNotFound)
            raise TftpException("File not found: %s" % (path,))

        if sendoack and "tsize" in self.context.options:
            # Size is measured by seeking to the end and back, so fileobj must
            # be seekable (assumed for dyn_file_func results too).
            self.context.fileobj.seek(0, os.SEEK_END)
            tsize = str(self.context.fileobj.tell())
            self.context.fileobj.seek(0, 0)
            self.context.options["tsize"] = tsize
            self.context.metrics.tsize = tsize

        if sendoack:
            self.sendOACK()
        else:
            self.context.next_block = 1
            log.debug("No requested options, starting send...")
            self.context.pending_complete = self.sendDAT()

        return TftpStateExpectACK(self.context)



class TftpStateServerRecvWRQ(TftpServerState):
    """Server state for a just-received WRQ (the client wants to upload)."""

    def make_subdirs(self):
        """Create missing directories between the server root and the parent
        directory of ``self.full_path``, each with mode 0o700.

        Raises:
            TftpException: if the target directory is not below the root.
            OSError: if a directory cannot be created.
        """
        root = os.path.abspath(self.context.root)
        # relpath is immune to trailing separators or other unnormalized
        # spellings of context.root, unlike slicing by len(root).
        subpath = os.path.relpath(os.path.dirname(self.full_path), root)
        log.debug("make_subdirs: subpath is %s", subpath)

        dirs = subpath.split(os.sep)
        log.debug("dirs is %s", dirs)
        current = root
        for part in dirs:
            if not part or part == os.curdir:
                continue
            if part == os.pardir:
                # Defensive: serverInitial should already have rejected this.
                raise TftpException("bad file path")
            current = os.path.join(current, part)
            if os.path.isdir(current):
                log.debug("%s is already an existing directory", current)
            else:
                os.mkdir(current, 0o700)

    def handle(self, pkt, raddress, rport):
        """Open the upload target and acknowledge the WRQ.

        With ``context.upload_open`` set, it is asked for the file object (None
        means not permitted). Otherwise missing subdirectories are created and
        ``self.full_path`` is opened for writing, overwriting any existing
        file. An OACK is sent when options were requested, else ACK 0.

        Returns:
            TftpStateExpectDAT for the new session.

        Raises:
            TftpException: the upload is refused or the target cannot be
                created (an ERR packet is sent to the client first).
        """
        log.debug("In TftpStateServerRecvWRQ.handle")
        sendoack = self.serverInitial(pkt, raddress, rport)
        path = self.full_path
        if self.context.upload_open:
            f = self.context.upload_open(path, self.context)
            if f is None:
                self.sendError(TftpErrors.AccessViolation)
                raise TftpException("Dynamic path %s not permitted" % path)
            self.context.fileobj = f
        else:
            log.info("Opening file %s for writing", path)
            if os.path.exists(path):
                log.warning(
                    "File %s exists already, overwriting...",
                    self.context.file_to_transfer,
                )
            # mkdir/open failures (permissions, a file in place of a
            # directory, ...) are reported to the client rather than leaking
            # out as a raw OSError.
            try:
                self.make_subdirs()
                self.context.fileobj = open(path, "wb")
            except (IOError, OSError) as err:
                log.warning("Could not open %s for writing: %s", path, err)
                self.sendError(TftpErrors.AccessViolation)
                raise TftpException("Could not open file for writing: %s" % (path,))

        if sendoack:
            log.debug("Sending OACK to client")
            self.sendOACK()
        else:
            log.debug("No requested options, expecting transfer to begin...")
            self.sendACK()

        self.context.next_block = 1

        return TftpStateExpectDAT(self.context)



class TftpStateServerStart(TftpState):
    """Initial server state: route the session's first packet by its type."""

    def handle(self, pkt, raddress, rport):
        """Delegate an RRQ or WRQ to its dedicated state and return the result.

        Any other packet type is answered with IllegalTftpOp and raises
        TftpException.
        """
        log.debug("In TftpStateServerStart.handle")
        if isinstance(pkt, TftpPacketRRQ):
            log.debug("Handling an RRQ packet")
            next_state = TftpStateServerRecvRRQ(self.context)
        elif isinstance(pkt, TftpPacketWRQ):
            log.debug("Handling a WRQ packet")
            next_state = TftpStateServerRecvWRQ(self.context)
        else:
            self.sendError(TftpErrors.IllegalTftpOp)
            raise TftpException("Invalid packet to begin up/download: %s" % pkt)
        return next_state.handle(pkt, raddress, rport)


class TftpStateExpectACK(TftpState):
    """Sender-side state: a DAT (or OACK) went out and an ACK is awaited."""

    def handle(self, pkt, raddress, rport):
        """React to the peer's reply while sending a file.

        * ACK for ``next_block``: finish (return None) if that was the final
          DAT, otherwise send the next block.
        * ACK for an older block: counted as a duplicate. If the last DAT went
          out more than ``context.timeout`` seconds ago,
          TftpTimeoutExpectACK is raised.
        * ACK for a block not yet sent: counted as an error and ignored.
        * ERR: raises TftpException.
        * Anything else is ignored.
        """
        if not isinstance(pkt, TftpPacketACK):
            if isinstance(pkt, TftpPacketERR):
                log.error("Received ERR packet from peer: %s", str(pkt))
                raise TftpException("Received ERR packet from peer: %s" % str(pkt))
            log.warning("Discarding unsupported packet: %s", str(pkt))
            return self

        ctx = self.context
        acked = pkt.blocknumber
        log.debug("Received ACK for packet %d", acked)

        if acked == ctx.next_block:
            if ctx.pending_complete:
                log.info("Received ACK to final DAT, we're done.")
                return None
            log.debug("Good ACK, sending next DAT")
            ctx.next_block += 1
            log.debug("Incremented next_block to %d", ctx.next_block)
            ctx.pending_complete = self.sendDAT()
        elif acked < ctx.next_block:
            log.warning("Received duplicate ACK for block %d", acked)
            ctx.metrics.add_dup(pkt)
            sent_at = ctx.metrics.last_dat_time
            if sent_at > 0 and time.time() - sent_at > ctx.timeout:
                raise TftpTimeoutExpectACK(
                    "Timeout waiting for ACK for block %d" % ctx.next_block
                )
        else:
            log.warning(
                "Oooh, time warp. Received ACK to packet we "
                "didn't send yet. Discarding."
            )
            ctx.metrics.errors += 1
        return self


class TftpStateExpectDAT(TftpState):
    """Receiver-side state: the next packet from the peer should be a DAT."""

    def handle(self, pkt, raddress, rport):
        """Process the peer's packet while receiving a file.

        DAT packets go to :meth:`handleDat`. A stray OACK (for example, a
        retransmission) is ignored. An ERR from the peer ends the transfer.
        ACK, WRQ and unknown packets are protocol violations: they are
        answered with IllegalTftpOp and raise TftpException.
        """
        if isinstance(pkt, TftpPacketDAT):
            return self.handleDat(pkt)

        elif isinstance(pkt, TftpPacketACK):
            # We send ACKs in this direction; the peer must not.
            self.sendError(TftpErrors.IllegalTftpOp)
            raise TftpException("Received ACK from peer when expecting DAT")

        elif isinstance(pkt, TftpPacketWRQ):
            self.sendError(TftpErrors.IllegalTftpOp)
            raise TftpException("Received WRQ from peer when expecting DAT")

        elif isinstance(pkt, TftpPacketERR):
            # RFC 1350: error packets are not acknowledged, so don't answer
            # the peer's ERR with one of our own (the other states don't).
            log.error("Received ERR from peer: %s", pkt)
            raise TftpException("Received ERR from peer: " + str(pkt))

        elif isinstance(pkt, TftpPacketOACK):
            # ACK is handled (as an error) above, so only OACK reaches here.
            log.warning("Discarding unexpected packet type (retransmission?): %s", pkt)
            return self

        else:
            self.sendError(TftpErrors.IllegalTftpOp)
            raise TftpException("Received unknown packet type from peer: " + str(pkt))


class TftpStateSentWRQ(TftpState):
    """Client state after sending a WRQ: awaiting the server's OACK or ACK 0."""

    def handle(self, pkt, raddress, rport):
        """Process the server's first reply to our upload request.

        * OACK: adopt the negotiated options and send block 1.
        * ACK 0: the server ignored our options. Fall back to the default
          block size and send block 1.
        * ACK for any other block: ignored while waiting for a valid reply.
        * ERR: raises TftpException.
        * RRQ, DAT or unknown: answered with IllegalTftpOp; raises
          TftpException.

        Returns:
            TftpStateExpectACK once the first DAT is sent, else ``self``.
        """
        # The first reply reveals the server's transfer port (TID).
        if not self.context.tidport:
            self.context.tidport = rport
            log.debug("Set remote port for session to %s", rport)

        if isinstance(pkt, TftpPacketOACK):
            log.info("Received OACK from server")
            try:
                self.handleOACK(pkt)
            except TftpException:
                log.error("Failed to negotiate options")
                self.sendError(TftpErrors.FailedNegotiation)
                raise
            log.debug("Sending first DAT packet")
            self.context.pending_complete = self.sendDAT()
            log.debug("Changing state to TftpStateExpectACK")
            return TftpStateExpectACK(self.context)

        elif isinstance(pkt, TftpPacketACK):
            log.info("Received ACK from server")
            log.debug("Apparently the server ignored our options")

            if pkt.blocknumber != 0:
                log.warning("Discarding ACK to block %s", pkt.blocknumber)
                log.debug("Still waiting for valid response from server")
                return self

            log.debug("Ack blocknumber is zero as expected")
            # A plain ACK 0 instead of an OACK means none of our options were
            # accepted (RFC 2347). Revert to the default block size before
            # sending, the same way TftpStateSentRRQ does on a plain DAT.
            if self.context.options:
                log.info("Server ignored options, falling back to defaults")
                self.context.options = {"blksize": DEF_BLKSIZE}
            log.debug("Sending first DAT packet")
            self.context.pending_complete = self.sendDAT()
            log.debug("Changing state to TftpStateExpectACK")
            return TftpStateExpectACK(self.context)

        elif isinstance(pkt, TftpPacketERR):
            raise TftpException("Received ERR from server: %s" % pkt)

        elif isinstance(pkt, TftpPacketRRQ):
            self.sendError(TftpErrors.IllegalTftpOp)
            raise TftpException("Received RRQ from server while in upload")

        elif isinstance(pkt, TftpPacketDAT):
            self.sendError(TftpErrors.IllegalTftpOp)
            raise TftpException("Received DAT from server while in upload")

        else:
            self.sendError(TftpErrors.IllegalTftpOp)
            raise TftpException("Received unknown packet type from server: %s" % pkt)


class TftpStateSentRRQ(TftpState):
    """Client state after sending an RRQ: awaiting the server's OACK or DAT."""

    def handle(self, pkt, raddress, rport):
        """Process the server's first reply to our download request.

        * OACK: adopt the negotiated options, ACK block 0 and wait for data.
        * DAT: the server ignored our options. Fall back to the default
          block size and handle the data.
        * ERR: raises TftpFileNotFoundError for FileNotFound, otherwise
          TftpException.
        * ACK, WRQ or unknown: answered with IllegalTftpOp; raises
          TftpException.

        Returns:
            The next state from handleDat, or a TftpStateExpectDAT after an
            OACK.
        """
        # The first reply reveals the server's transfer port (TID).
        if not self.context.tidport:
            self.context.tidport = rport
            log.info("Set remote port for session to %s", rport)

        if isinstance(pkt, TftpPacketOACK):
            log.info("Received OACK from server")
            try:
                self.handleOACK(pkt)
            except TftpException as err:
                log.error("Failed to negotiate options: %s", str(err))
                self.sendError(TftpErrors.FailedNegotiation)
                raise
            log.debug("Sending ACK to OACK")
            # An OACK is acknowledged with ACK 0 (RFC 2347).
            self.sendACK(blocknumber=0)
            log.debug("Changing state to TftpStateExpectDAT")
            return TftpStateExpectDAT(self.context)

        elif isinstance(pkt, TftpPacketDAT):
            log.info("Received DAT from server")
            if self.context.options:
                log.info("Server ignored options, falling back to defaults")
                self.context.options = {"blksize": DEF_BLKSIZE}
            return self.handleDat(pkt)

        elif isinstance(pkt, TftpPacketACK):
            self.sendError(TftpErrors.IllegalTftpOp)
            raise TftpException("Received ACK from server while in download")

        elif isinstance(pkt, TftpPacketWRQ):
            self.sendError(TftpErrors.IllegalTftpOp)
            raise TftpException("Received WRQ from server while in download")

        elif isinstance(pkt, TftpPacketERR):
            log.debug("Received ERR packet: %s", pkt)
            if pkt.errorcode == TftpErrors.FileNotFound:
                raise TftpFileNotFoundError("File not found")
            raise TftpException("Received ERR from server: %s" % (pkt,))

        else:
            self.sendError(TftpErrors.IllegalTftpOp)
            raise TftpException("Received unknown packet type from server: %s" % pkt)
