Files
sq/crates/sq-capnp-interface/src/codec.rs
2026-02-27 12:15:43 +01:00

186 lines
5.6 KiB
Rust

use bytes::{Buf, BufMut, Bytes, BytesMut};
use tokio_util::codec::{Decoder, Encoder, LengthDelimitedCodec};
// Opcodes
//
// Each frame carries one of these values in its first byte. For the
// publish/subscribe/ack pairs below, the response opcode equals the request
// opcode with the high bit (0x80) set.
/// Opcode for a publish request frame.
pub const OP_PUBLISH_REQ: u8 = 0x01;
/// Opcode for a publish response frame (`OP_PUBLISH_REQ | 0x80`).
pub const OP_PUBLISH_RES: u8 = 0x81;
/// Opcode for a subscribe request frame.
pub const OP_SUBSCRIBE_REQ: u8 = 0x02;
/// Opcode for a subscribe response frame (`OP_SUBSCRIBE_REQ | 0x80`).
pub const OP_SUBSCRIBE_RES: u8 = 0x82;
/// Opcode for an ack request frame.
pub const OP_ACK_REQ: u8 = 0x03;
/// Opcode for an ack response frame (`OP_ACK_REQ | 0x80`).
pub const OP_ACK_RES: u8 = 0x83;
/// Opcode for a subscribe-end frame.
// NOTE(review): presumably signals the end of a subscription stream; no
// matching request opcode (0x04) is defined in this file — confirm with the handler.
pub const OP_SUBSCRIBE_END: u8 = 0x84;
/// Opcode for an error frame; this is the opcode `error_frame` emits.
pub const OP_ERROR: u8 = 0xFE;
/// A decoded frame: opcode + capnp payload bytes.
///
/// This is the unit produced by the `Decoder` impl and consumed by the
/// `Encoder` impl of `SqCodec`. The codec itself does not inspect `payload`;
/// callers interpret it according to `opcode`.
///
/// Cloning is cheap: `Bytes` is reference-counted, so a clone shares the
/// underlying buffer instead of copying it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Frame {
    /// First byte of the frame body; one of the `OP_*` constants.
    pub opcode: u8,
    /// Everything in the frame body after the opcode byte.
    pub payload: Bytes,
}
/// Framing codec for the sq data plane.
///
/// Delegates length-prefix handling to a [`LengthDelimitedCodec`] and treats
/// the first byte of every frame body as an opcode.
///
/// Wire format: `[4-byte big-endian frame length][1-byte opcode][capnp payload]`
pub struct SqCodec {
    /// Handles the length prefix on both the read and write side.
    inner: LengthDelimitedCodec,
}
impl SqCodec {
    /// Upper bound handed to the inner length-delimited codec: 16 MiB.
    const MAX_FRAME_LEN: usize = 16 * 1024 * 1024;

    /// Creates a codec whose inner length-delimited codec is capped at
    /// 16 MiB per frame and otherwise uses the builder defaults.
    pub fn new() -> Self {
        let inner = LengthDelimitedCodec::builder()
            .max_frame_length(Self::MAX_FRAME_LEN)
            .new_codec();
        Self { inner }
    }
}
impl Default for SqCodec {
fn default() -> Self {
Self::new()
}
}
impl Decoder for SqCodec {
    type Item = Frame;
    type Error = std::io::Error;

    /// Tries to pull one complete frame out of `src`.
    ///
    /// Returns `Ok(None)` when the inner codec has not yet produced a full
    /// frame, and `Ok(Some(frame))` once one is available.
    ///
    /// # Errors
    ///
    /// Propagates any error from the inner length-delimited codec, and
    /// returns `InvalidData` ("empty frame") when a frame body has zero bytes
    /// and therefore no opcode.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        // No complete frame buffered yet: ask the caller for more bytes.
        let Some(mut body) = self.inner.decode(src)? else {
            return Ok(None);
        };

        // A body must hold at least the opcode byte.
        if !body.has_remaining() {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                "empty frame",
            ));
        }

        // Consume the opcode; whatever is left is the payload.
        let opcode = body.get_u8();
        Ok(Some(Frame {
            opcode,
            payload: body.freeze(),
        }))
    }
}
impl Encoder<Frame> for SqCodec {
type Error = std::io::Error;
fn encode(&mut self, item: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> {
let mut buf = BytesMut::with_capacity(1 + item.payload.len());
buf.put_u8(item.opcode);
buf.extend_from_slice(&item.payload);
self.inner.encode(buf.freeze(), dst)
}
}
/// Serialize a capnp message builder into bytes.
///
/// The output is the standard (unpacked) capnp stream encoding: the segment
/// table followed by the segments, i.e. the same bytes
/// `capnp::serialize::write_message` would emit.
///
/// Uses `write_message_to_words`, which builds the flat byte vector directly
/// rather than going through a fallible `io::Write` sink, so this function
/// has no error or panic path of its own.
pub fn serialize_capnp(builder: &capnp::message::Builder<capnp::message::HeapAllocator>) -> Bytes {
    Bytes::from(capnp::serialize::write_message_to_words(builder))
}
/// Wraps a capnp message in a [`Frame`] tagged with `opcode`.
///
/// The message is serialized with [`serialize_capnp`] and becomes the
/// frame's payload.
pub fn build_frame(
    opcode: u8,
    builder: &capnp::message::Builder<capnp::message::HeapAllocator>,
) -> Frame {
    let payload = serialize_capnp(builder);
    Frame { opcode, payload }
}
/// Creates an `OP_ERROR` frame whose payload is an `ErrorResponse` capnp
/// message carrying `msg` as its message text.
pub fn error_frame(msg: &str) -> Frame {
    let mut message = capnp::message::Builder::new_default();
    // The root builder is a temporary, so its mutable borrow of `message`
    // ends with this statement.
    message
        .init_root::<crate::data_plane_capnp::error_response::Builder>()
        .set_message(msg);
    build_frame(OP_ERROR, &message)
}
/// Parses a capnp message (stream encoding) out of `payload`.
///
/// Reads with default `ReaderOptions` and returns a reader that owns its
/// segments, so the result does not borrow from `payload`.
///
/// # Errors
///
/// Returns whatever error `capnp::serialize::read_message` reports for the
/// given bytes.
pub fn read_capnp(payload: &[u8]) -> capnp::Result<capnp::message::Reader<capnp::serialize::OwnedSegments>> {
    let options = capnp::message::ReaderOptions::new();
    let mut input = std::io::Cursor::new(payload);
    capnp::serialize::read_message(&mut input, options)
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio_util::codec::{Decoder, Encoder};

    /// Encoding a frame and decoding the produced bytes yields the same
    /// opcode and payload.
    #[test]
    fn roundtrip_frame() {
        let payload = Bytes::from_static(b"hello");
        let mut codec = SqCodec::new();
        let mut wire = BytesMut::new();

        let outgoing = Frame {
            opcode: OP_PUBLISH_REQ,
            payload: payload.clone(),
        };
        codec.encode(outgoing, &mut wire).unwrap();

        let incoming = codec.decode(&mut wire).unwrap().unwrap();
        assert_eq!(incoming.opcode, OP_PUBLISH_REQ);
        assert_eq!(incoming.payload, Bytes::from_static(b"hello"));
    }

    /// A PublishRequest survives `build_frame` followed by `read_capnp`.
    #[test]
    fn capnp_publish_roundtrip() {
        // Write side: one request holding a single message.
        let mut message = capnp::message::Builder::new_default();
        {
            let mut request =
                message.init_root::<crate::data_plane_capnp::publish_request::Builder>();
            request.set_ack_mode(1);
            request.set_producer_id("test");
            let mut batch = request.init_messages(1);
            let mut entry = batch.reborrow().get(0);
            entry.set_topic("orders");
            entry.set_key(b"key1");
            entry.set_value(b"value1");
        }

        let Frame { opcode, payload } = build_frame(OP_PUBLISH_REQ, &message);
        assert_eq!(opcode, OP_PUBLISH_REQ);

        // Read side: parse the payload and check every field.
        let decoded = read_capnp(&payload).unwrap();
        let request = decoded
            .get_root::<crate::data_plane_capnp::publish_request::Reader>()
            .unwrap();
        assert_eq!(request.get_ack_mode(), 1);
        assert_eq!(request.get_producer_id().unwrap(), "test");

        let batch = request.get_messages().unwrap();
        assert_eq!(batch.len(), 1);

        let entry = batch.get(0);
        assert_eq!(entry.get_topic().unwrap(), "orders");
        assert_eq!(entry.get_key().unwrap(), b"key1");
        assert_eq!(entry.get_value().unwrap(), b"value1");
    }

    /// `error_frame` uses `OP_ERROR` and carries the message text intact.
    #[test]
    fn error_frame_roundtrip() {
        let Frame { opcode, payload } = error_frame("something went wrong");
        assert_eq!(opcode, OP_ERROR);

        let decoded = read_capnp(&payload).unwrap();
        let response = decoded
            .get_root::<crate::data_plane_capnp::error_response::Reader>()
            .unwrap();
        assert_eq!(response.get_message().unwrap(), "something went wrong");
    }
}