use bytes::{Buf, BufMut, Bytes, BytesMut}; use tokio_util::codec::{Decoder, Encoder, LengthDelimitedCodec}; // Opcodes pub const OP_PUBLISH_REQ: u8 = 0x01; pub const OP_PUBLISH_RES: u8 = 0x81; pub const OP_SUBSCRIBE_REQ: u8 = 0x02; pub const OP_SUBSCRIBE_RES: u8 = 0x82; pub const OP_ACK_REQ: u8 = 0x03; pub const OP_ACK_RES: u8 = 0x83; pub const OP_SUBSCRIBE_END: u8 = 0x84; pub const OP_ERROR: u8 = 0xFE; /// A decoded frame: opcode + capnp payload bytes. pub struct Frame { pub opcode: u8, pub payload: Bytes, } /// Codec that wraps `LengthDelimitedCodec` and prepends a 1-byte opcode. /// /// Wire format: `[4-byte big-endian frame length][1-byte opcode][capnp payload]` pub struct SqCodec { inner: LengthDelimitedCodec, } impl SqCodec { pub fn new() -> Self { Self { inner: LengthDelimitedCodec::builder() .max_frame_length(16 * 1024 * 1024) // 16 MB .new_codec(), } } } impl Default for SqCodec { fn default() -> Self { Self::new() } } impl Decoder for SqCodec { type Item = Frame; type Error = std::io::Error; fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { match self.inner.decode(src)? { Some(mut buf) => { if buf.is_empty() { return Err(std::io::Error::new( std::io::ErrorKind::InvalidData, "empty frame", )); } let opcode = buf.get_u8(); let payload = buf.freeze(); Ok(Some(Frame { opcode, payload })) } None => Ok(None), } } } impl Encoder for SqCodec { type Error = std::io::Error; fn encode(&mut self, item: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> { let mut buf = BytesMut::with_capacity(1 + item.payload.len()); buf.put_u8(item.opcode); buf.extend_from_slice(&item.payload); self.inner.encode(buf.freeze(), dst) } } /// Serialize a capnp message builder into bytes. pub fn serialize_capnp(builder: &capnp::message::Builder) -> Bytes { let mut buf = Vec::new(); capnp::serialize::write_message(&mut buf, builder).expect("capnp serialize failed"); Bytes::from(buf) } /// Build a Frame from an opcode and a capnp message builder. 
///
/// The payload is the capnp stream encoding produced by `serialize_capnp`.
// NOTE(review): in the copy reviewed, the generic arguments in this section look
// elided: `capnp::message::Builder` has no type argument, `init_root::()` and
// `get_root::()` are empty, and `read_capnp`'s return type has an unbalanced
// `>`. Confirm them against the real source.
pub fn build_frame(
    opcode: u8,
    builder: &capnp::message::Builder,
) -> Frame {
    Frame {
        opcode,
        payload: serialize_capnp(builder),
    }
}

/// Build an error frame with a text message.
///
/// The text is written to the `message` field of the root struct. The result
/// is framed with [`OP_ERROR`].
pub fn error_frame(msg: &str) -> Frame {
    let mut builder = capnp::message::Builder::new_default();
    {
        // Scoped so that the root builder's mutable borrow of `builder` ends
        // before `build_frame` borrows it again below.
        let mut err = builder.init_root::();
        err.set_message(msg);
    }
    build_frame(OP_ERROR, &builder)
}

/// Deserialize a capnp message from a byte slice.
///
/// Reads one message in the standard (unpacked) stream format, using default
/// `ReaderOptions`.
///
/// # Errors
///
/// Returns the capnp error if the bytes are not a valid message or exceed the
/// reader's default limits.
pub fn read_capnp(payload: &[u8]) -> capnp::Result> {
    let mut cursor = std::io::Cursor::new(payload);
    capnp::serialize::read_message(
        &mut cursor,
        capnp::message::ReaderOptions::new(),
    )
}

// Unit tests for the framing codec and the capnp helpers above.
#[cfg(test)]
mod tests {
    use super::*;
    use tokio_util::codec::{Decoder, Encoder};

    // Encodes one frame, then checks that decoding the bytes yields the same
    // opcode and payload.
    #[test]
    fn roundtrip_frame() {
        let mut codec = SqCodec::new();
        let original = Frame {
            opcode: OP_PUBLISH_REQ,
            payload: Bytes::from_static(b"hello"),
        };
        let mut buf = BytesMut::new();
        codec
            .encode(
                Frame {
                    opcode: original.opcode,
                    payload: original.payload.clone(),
                },
                &mut buf,
            )
            .unwrap();
        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(decoded.opcode, OP_PUBLISH_REQ);
        assert_eq!(decoded.payload, Bytes::from_static(b"hello"));
    }

    // Builds a publish request, frames it with `build_frame`, then reads it
    // back field by field through `read_capnp`.
    #[test]
    fn capnp_publish_roundtrip() {
        // Build a PublishRequest
        let mut builder = capnp::message::Builder::new_default();
        {
            let mut req = builder.init_root::();
            req.set_ack_mode(1);
            req.set_producer_id("test");
            let mut msgs = req.init_messages(1);
            let mut msg = msgs.reborrow().get(0);
            msg.set_topic("orders");
            msg.set_key(b"key1");
            msg.set_value(b"value1");
        }
        let frame = build_frame(OP_PUBLISH_REQ, &builder);
        assert_eq!(frame.opcode, OP_PUBLISH_REQ);

        // Decode
        let reader = read_capnp(&frame.payload).unwrap();
        let req = reader
            .get_root::()
            .unwrap();
        assert_eq!(req.get_ack_mode(), 1);
        assert_eq!(req.get_producer_id().unwrap(), "test");
        let msgs = req.get_messages().unwrap();
        assert_eq!(msgs.len(), 1);
        assert_eq!(msgs.get(0).get_topic().unwrap(), "orders");
        assert_eq!(msgs.get(0).get_key().unwrap(), b"key1");
        assert_eq!(msgs.get(0).get_value().unwrap(), b"value1");
    }

    // Checks that `error_frame` uses `OP_ERROR` and that the message text
    // survives a capnp round trip.
    #[test]
    fn error_frame_roundtrip() {
        let frame = error_frame("something went wrong");
        assert_eq!(frame.opcode, OP_ERROR);
        let reader = read_capnp(&frame.payload).unwrap();
        let err = reader
            .get_root::()
            .unwrap();
        assert_eq!(err.get_message().unwrap(), "something went wrong");
    }
}