package org.mockejb.jms;

import javax.jms.*;
import java.io.*;

/**
 * <code>BytesMessage</code> implementation.
 * <p>
 * Written data is accumulated in an in-memory byte buffer. {@link #reset()}
 * marks the body read-only and exposes a snapshot of that buffer for reading;
 * {@link #clearBody()} empties the buffer and creates a fresh writer over it.
 * @author Dimitar Gospodinov
 * @see javax.jms.BytesMessage
 */
public class BytesMessageImpl extends MessageImpl implements BytesMessage {

    /** Receives every byte written to the body; never replaced, only emptied by clearBody(). */
    private final ByteArrayOutputStream bytesOut = new ByteArrayOutputStream();
    /** Snapshot of <code>bytesOut</code> taken by reset(); <code>null</code> until then. */
    private ByteArrayInputStream bytesIn;
    /** Typed reader over <code>bytesIn</code>; <code>null</code> until reset(). */
    private DataInputStream dataIn;
    /** Typed writer over <code>bytesOut</code>; <code>null</code> after reset(). */
    private DataOutputStream dataOut;

    /**
     * Creates an empty <code>BytesMessage</code> whose body is ready for writing.
     * @throws JMSException if the body cannot be cleared
     */
    public BytesMessageImpl() throws JMSException {
        clearBody();
    }

    /**
     * Creates <code>BytesMessage</code> and copies its header, properties and
     * body from <code>msg</code>.
     * The state of <code>msg</code> is not changed: its read position (if it is
     * readable) or its written body (if it is write-only) is restored before
     * this constructor returns.
     * @param msg message to copy from
     * @throws JMSException if <code>msg</code> cannot be read or restored
     */
    public BytesMessageImpl(BytesMessage msg) throws JMSException {
        super(msg);
        clearBody();
        setBody(msg);
    }

    /**
     * Extracts body data from <code>msg</code> and writes it into this message.
     * The state of <code>msg</code> is not changed.
     * <p>
     * A probing <code>readByte</code> call tells the two possible modes of
     * <code>msg</code> apart:
     * <ul>
     * <li>readable: bytes are counted until EOF, the whole body is extracted,
     * then <code>msg</code> is reset and advanced by
     * <code>total - remaining</code> bytes to restore its read position;</li>
     * <li>write-only: <code>MessageNotReadableException</code> is thrown, the
     * body is extracted (which calls <code>msg.reset()</code>), and
     * <code>msg</code> is then cleared and rewritten with the same bytes.</li>
     * </ul>
     * @param msg message to copy the body from
     * @throws JMSException if <code>msg</code> cannot be read or restored
     */
    private void setBody(BytesMessage msg) throws JMSException {
        byte[] sourceBytes;
        // Number of bytes between the current read position of msg and its end
        int bytesRemaining = 0;
        try {
            // This loop only ends by exception: EOF (readable) or not-readable (write-only)
            while (true) {
                msg.readByte();
                bytesRemaining++;
            }
        } catch (MessageEOFException ex) {
            sourceBytes = extractBytes(msg);
            // Rewind, then skip forward to where msg was positioned before the probe
            msg.reset();
            msg.readBytes(new byte[sourceBytes.length - bytesRemaining]);
        } catch (MessageNotReadableException ex) {
            // Message is in write-only mode
            sourceBytes = extractBytes(msg);
            // extractBytes() called msg.reset(); clear it and rewrite the original bytes
            msg.clearBody();
            msg.writeBytes(sourceBytes);
        }
        writeBytes(sourceBytes);
    }

    /**
     * Calls <code>msg.reset</code> and counts number of bytes in the message.
     * The implementation is trivial and efficiency is ignored.
     * Creates byte array, populates it with all bytes from <code>msg</code> and
     * returns it. On return <code>msg</code> has been reset and fully read.
     * @param msg message to read; <code>reset()</code> is called on it twice
     * @return <code>byte[]</code> that represents the body of <code>msg</code>
     * @throws JMSException if <code>msg</code> cannot be reset or read
     */
    private byte[] extractBytes(BytesMessage msg) throws JMSException {
        msg.reset();
        int numberOfBytes = 0;
        while (true) {
            try {
                msg.readByte();
                numberOfBytes++;
            } catch (MessageEOFException ex) {
                break;
            }
        }
        byte[] result = new byte[numberOfBytes];
        msg.reset();
        msg.readBytes(result);
        return result;
    }

    /**
     * Converts an I/O failure from the underlying data streams into the JMS
     * exception type callers expect. End-of-stream becomes
     * <code>MessageEOFException</code>; anything else becomes a plain
     * <code>JMSException</code>. The original exception is attached as the
     * linked exception so its details are not lost.
     * @param ex the I/O failure
     * @return the JMS exception to throw
     */
    private static JMSException toJMSException(IOException ex) {
        JMSException jmsEx = (ex instanceof EOFException)
            ? new MessageEOFException(ex.getMessage())
            : new JMSException(ex.getMessage());
        jmsEx.setLinkedException(ex);
        return jmsEx;
    }

    /**
     * Puts the body in read-only mode and positions reading at its first byte.
     * Any writer is discarded; the bytes written so far are snapshotted for reading.
     * @see javax.jms.BytesMessage#reset()
     */
    public void reset() {
        setBodyReadOnly();
        dataOut = null;
        bytesIn = new ByteArrayInputStream(bytesOut.toByteArray());
        dataIn = new DataInputStream(bytesIn);
    }

    /**
     * Discards the body and any reader, and creates a fresh writer over the
     * now-empty buffer.
     * @see javax.jms.BytesMessage#clearBody()
     */
    public void clearBody() throws JMSException {
        super.clearBody();
        dataIn = null;
        bytesIn = null;
        bytesOut.reset();
        dataOut = new DataOutputStream(bytesOut);
    }

    /**
     * Returns body length. Method is part of <code>BytesMessage</code>
     * interface, version 1.1
     * @return body length in bytes.
     * @throws JMSException if the body is not readable
     */
    public long getBodyLength() throws JMSException {
        checkBodyReadable();
        return bytesOut.size();
    }

    /**
     * @see javax.jms.BytesMessage#readBoolean()
     */
    public boolean readBoolean() throws JMSException {
        checkBodyReadable();
        try {
            return dataIn.readBoolean();
        } catch (IOException ex) {
            throw toJMSException(ex);
        }
    }

    /**
     * @see javax.jms.BytesMessage#readByte()
     */
    public byte readByte() throws JMSException {
        checkBodyReadable();
        try {
            return dataIn.readByte();
        } catch (IOException ex) {
            throw toJMSException(ex);
        }
    }

    /**
     * @see javax.jms.BytesMessage#readUnsignedByte()
     */
    public int readUnsignedByte() throws JMSException {
        checkBodyReadable();
        try {
            return dataIn.readUnsignedByte();
        } catch (IOException ex) {
            throw toJMSException(ex);
        }
    }

    /**
     * @see javax.jms.BytesMessage#readShort()
     */
    public short readShort() throws JMSException {
        checkBodyReadable();
        try {
            return dataIn.readShort();
        } catch (IOException ex) {
            throw toJMSException(ex);
        }
    }

    /**
     * @see javax.jms.BytesMessage#readUnsignedShort()
     */
    public int readUnsignedShort() throws JMSException {
        checkBodyReadable();
        try {
            return dataIn.readUnsignedShort();
        } catch (IOException ex) {
            throw toJMSException(ex);
        }
    }

    /**
     * @see javax.jms.BytesMessage#readChar()
     */
    public char readChar() throws JMSException {
        checkBodyReadable();
        try {
            return dataIn.readChar();
        } catch (IOException ex) {
            throw toJMSException(ex);
        }
    }

    /**
     * @see javax.jms.BytesMessage#readInt()
     */
    public int readInt() throws JMSException {
        checkBodyReadable();
        try {
            return dataIn.readInt();
        } catch (IOException ex) {
            throw toJMSException(ex);
        }
    }

    /**
     * @see javax.jms.BytesMessage#readLong()
     */
    public long readLong() throws JMSException {
        checkBodyReadable();
        try {
            return dataIn.readLong();
        } catch (IOException ex) {
            throw toJMSException(ex);
        }
    }

    /**
     * @see javax.jms.BytesMessage#readFloat()
     */
    public float readFloat() throws JMSException {
        checkBodyReadable();
        try {
            return dataIn.readFloat();
        } catch (IOException ex) {
            throw toJMSException(ex);
        }
    }

    /**
     * @see javax.jms.BytesMessage#readDouble()
     */
    public double readDouble() throws JMSException {
        checkBodyReadable();
        try {
            return dataIn.readDouble();
        } catch (IOException ex) {
            throw toJMSException(ex);
        }
    }

    /**
     * Reads a modified-UTF-8 string. If the read fails (malformed data or
     * end of stream) the read position is restored to where it was before
     * the call.
     * @see javax.jms.BytesMessage#readUTF()
     */
    public String readUTF() throws JMSException {
        checkBodyReadable();
        try {
            // Remember the current position so a failed read can be rolled back
            dataIn.mark(dataIn.available());
            boolean succeeded = false;
            try {
                String result = dataIn.readUTF();
                succeeded = true;
                return result;
            } finally {
                if (!succeeded) {
                    dataIn.reset();
                }
            }
        } catch (UTFDataFormatException ex) {
            MessageFormatException formatEx = new MessageFormatException(ex.getMessage());
            formatEx.setLinkedException(ex);
            throw formatEx;
        } catch (IOException ex) {
            throw toJMSException(ex);
        }
    }

    /**
     * @see javax.jms.BytesMessage#readBytes(byte[])
     */
    public int readBytes(byte[] bytes) throws JMSException {
        checkBodyReadable();
        try {
            return dataIn.read(bytes);
        } catch (IOException ex) {
            throw toJMSException(ex);
        }
    }

    /**
     * @see javax.jms.BytesMessage#readBytes(byte[], int)
     */
    public int readBytes(byte[] bytes, int length) throws JMSException {
        checkBodyReadable();
        try {
            return dataIn.read(bytes, 0, length);
        } catch (IOException ex) {
            throw toJMSException(ex);
        }
    }

    /**
     * @see javax.jms.BytesMessage#writeBoolean(boolean)
     */
    public void writeBoolean(boolean value) throws JMSException {
        checkBodyWriteable();
        try {
            dataOut.writeBoolean(value);
        } catch (IOException ex) {
            throw toJMSException(ex);
        }
    }

    /**
     * @see javax.jms.BytesMessage#writeByte(byte)
     */
    public void writeByte(byte value) throws JMSException {
        checkBodyWriteable();
        try {
            dataOut.writeByte(value);
        } catch (IOException ex) {
            throw toJMSException(ex);
        }
    }

    /**
     * @see javax.jms.BytesMessage#writeShort(short)
     */
    public void writeShort(short value) throws JMSException {
        checkBodyWriteable();
        try {
            dataOut.writeShort(value);
        } catch (IOException ex) {
            throw toJMSException(ex);
        }
    }

    /**
     * @see javax.jms.BytesMessage#writeChar(char)
     */
    public void writeChar(char value) throws JMSException {
        checkBodyWriteable();
        try {
            dataOut.writeChar(value);
        } catch (IOException ex) {
            throw toJMSException(ex);
        }
    }

    /**
     * @see javax.jms.BytesMessage#writeInt(int)
     */
    public void writeInt(int value) throws JMSException {
        checkBodyWriteable();
        try {
            dataOut.writeInt(value);
        } catch (IOException ex) {
            throw toJMSException(ex);
        }
    }

    /**
     * @see javax.jms.BytesMessage#writeLong(long)
     */
    public void writeLong(long value) throws JMSException {
        checkBodyWriteable();
        try {
            dataOut.writeLong(value);
        } catch (IOException ex) {
            throw toJMSException(ex);
        }
    }

    /**
     * @see javax.jms.BytesMessage#writeFloat(float)
     */
    public void writeFloat(float value) throws JMSException {
        checkBodyWriteable();
        try {
            dataOut.writeFloat(value);
        } catch (IOException ex) {
            throw toJMSException(ex);
        }
    }

    /**
     * @see javax.jms.BytesMessage#writeDouble(double)
     */
    public void writeDouble(double value) throws JMSException {
        checkBodyWriteable();
        try {
            dataOut.writeDouble(value);
        } catch (IOException ex) {
            throw toJMSException(ex);
        }
    }

    /**
     * @see javax.jms.BytesMessage#writeUTF(java.lang.String)
     */
    public void writeUTF(String value) throws JMSException {
        checkBodyWriteable();
        try {
            dataOut.writeUTF(value);
        } catch (IOException ex) {
            throw toJMSException(ex);
        }
    }

    /**
     * @see javax.jms.BytesMessage#writeBytes(byte[])
     */
    public void writeBytes(byte[] bytes) throws JMSException {
        checkBodyWriteable();
        try {
            dataOut.write(bytes);
        } catch (IOException ex) {
            throw toJMSException(ex);
        }
    }

    /**
     * @see javax.jms.BytesMessage#writeBytes(byte[], int, int)
     */
    public void writeBytes(byte[] bytes, int offset, int length)
        throws JMSException {

        checkBodyWriteable();
        try {
            dataOut.write(bytes, offset, length);
        } catch (IOException ex) {
            throw toJMSException(ex);
        }
    }

    /**
     * Writes an objectified primitive (<code>Boolean</code>, <code>Byte</code>,
     * <code>Short</code>, <code>Character</code>, <code>Integer</code>,
     * <code>Long</code>, <code>Float</code>, <code>Double</code>), a
     * <code>String</code> (as UTF) or a <code>byte[]</code> to the body.
     * @param value the object to write
     * @throws NullPointerException if <code>value</code> is <code>null</code>
     * @throws MessageFormatException if <code>value</code> is of any other type
     * @see javax.jms.BytesMessage#writeObject(java.lang.Object)
     */
    public void writeObject(Object value) throws JMSException {
        checkBodyWriteable();

        if (value == null) {
            throw new NullPointerException();
        }

        // else-if chain: exactly one write per call, and the error is thrown
        // only when no supported type matched
        if (value instanceof Boolean) {
            writeBoolean(((Boolean) value).booleanValue());
        } else if (value instanceof Byte) {
            writeByte(((Byte) value).byteValue());
        } else if (value instanceof Short) {
            writeShort(((Short) value).shortValue());
        } else if (value instanceof Character) {
            writeChar(((Character) value).charValue());
        } else if (value instanceof Integer) {
            writeInt(((Integer) value).intValue());
        } else if (value instanceof Long) {
            writeLong(((Long) value).longValue());
        } else if (value instanceof Float) {
            writeFloat(((Float) value).floatValue());
        } else if (value instanceof Double) {
            writeDouble(((Double) value).doubleValue());
        } else if (value instanceof String) {
            writeUTF((String) value);
        } else if (value instanceof byte[]) {
            writeBytes((byte[]) value);
        } else {
            throw new MessageFormatException("Incorrect object type!");
        }
    }

    // Non-standard methods

    /**
     * Sets message body in read-only mode by delegating to {@link #reset()}.
     * @throws JMSException declared for callers; not thrown by the current
     *         implementation of reset()
     */
    void resetBody() throws JMSException {
        reset();
    }

}