Skip to content

Commit df4e40c

Browse files
authored
Replace pickle with msgspec for IPC serialization (#8713)
Pickle serialization across the ZeroMQ IPC boundary is fragile: any change to the command types (renaming them, or adding/removing fields) can silently break deserialization between the host and kernel processes. `msgspec.msgpack` gives us a proper binary wire format with well-defined schema-evolution rules tied to the existing `Command` struct definitions. Each Pull channel now constructs a `msgspec.msgpack.Decoder[T]` from its known message type at creation time, and the receiver thread uses that decoder to deserialize incoming messages back into their correct types. For the command channels this leverages the discriminated-union tags already on `Command` subclasses; the `input`, `win32_interrupt`, and `stream` channels carry plain `str`, `bool`, and `bytes` values respectively.
1 parent b8c556b commit df4e40c

3 files changed

Lines changed: 187 additions & 30 deletions

File tree

marimo/_ipc/connection.py

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,22 @@
88
import sys
99
import typing
1010

11+
import msgspec
12+
1113
from marimo import _loggers
1214
from marimo._ipc.queue_proxy import PushQueue, start_receiver_thread
1315
from marimo._ipc.types import ConnectionInfo
16+
from marimo._runtime.commands import (
17+
BatchableCommand,
18+
CodeCompletionCommand,
19+
CommandMessage,
20+
)
1421
from marimo._session.queue import QueueType
1522

1623
if typing.TYPE_CHECKING:
1724
import zmq
1825

1926
from marimo._messaging.types import KernelMessage
20-
from marimo._runtime.commands import (
21-
BatchableCommand,
22-
CodeCompletionCommand,
23-
CommandMessage,
24-
)
2527

2628
LOGGER = _loggers.marimo_logger()
2729
ADDR = "tcp://127.0.0.1"
@@ -41,6 +43,7 @@ class Channel(typing.Generic[T]):
4143
kind: typing.Literal["push", "pull"]
4244
socket: zmq.Socket[bytes]
4345
queue: QueueType[T]
46+
decoder: msgspec.msgpack.Decoder[T] | None = None
4447

4548
@classmethod
4649
def Push(
@@ -61,12 +64,17 @@ def Push(
6164

6265
@classmethod
6366
def Pull(
64-
cls, context: zmq.Context[zmq.Socket[bytes]], *, maxsize: int = 0
67+
cls,
68+
context: zmq.Context[zmq.Socket[bytes]],
69+
*,
70+
msg_type: type[T],
71+
maxsize: int = 0,
6572
) -> Channel[T]:
6673
"""Create a pull (receive-only) channel.
6774
6875
Args:
6976
context: ZeroMQ context for creating sockets
77+
msg_type: The type to decode incoming messages as
7078
maxsize: Maximum queue size (0 = unlimited)
7179
"""
7280
import zmq
@@ -76,6 +84,7 @@ def Pull(
7684
kind="pull",
7785
socket=socket,
7886
queue=queue.Queue(maxsize=maxsize),
87+
decoder=msgspec.msgpack.Decoder(msg_type),
7988
)
8089

8190

@@ -95,22 +104,22 @@ class Connection:
95104

96105
def __post_init__(self) -> None:
97106
"""Start receiver threads for all pull channels."""
98-
receivers: dict[zmq.Socket[bytes], QueueType[typing.Any]] = {}
107+
pull_channels: list[Channel[typing.Any]] = []
99108
if self.control.kind == "pull":
100-
receivers[self.control.socket] = self.control.queue
109+
pull_channels.append(self.control)
101110
if self.ui_element.kind == "pull":
102-
receivers[self.ui_element.socket] = self.ui_element.queue
111+
pull_channels.append(self.ui_element)
103112
if self.completion.kind == "pull":
104-
receivers[self.completion.socket] = self.completion.queue
113+
pull_channels.append(self.completion)
105114
if self.win32_interrupt and self.win32_interrupt.kind == "pull":
106-
receivers[self.win32_interrupt.socket] = self.win32_interrupt.queue
115+
pull_channels.append(self.win32_interrupt)
107116
if self.input.kind == "pull":
108-
receivers[self.input.socket] = self.input.queue
117+
pull_channels.append(self.input)
109118
if self.stream.kind == "pull":
110-
receivers[self.stream.socket] = self.stream.queue
119+
pull_channels.append(self.stream)
111120

112121
self._stop_event, self._receiver_thread = start_receiver_thread(
113-
receivers
122+
pull_channels
114123
)
115124

116125
@classmethod
@@ -132,7 +141,7 @@ def create(cls) -> tuple[Connection, ConnectionInfo]:
132141
Channel.Push(context) if sys.platform == "win32" else None
133142
),
134143
input=Channel.Push(context, maxsize=1),
135-
stream=Channel.Pull(context),
144+
stream=Channel.Pull(context, msg_type=bytes),
136145
)
137146
info = ConnectionInfo(
138147
control=conn.control.socket.bind_to_random_port(ADDR),
@@ -164,13 +173,13 @@ def connect(cls, connection_info: ConnectionInfo) -> Connection:
164173

165174
conn = cls(
166175
context=context,
167-
control=Channel.Pull(context),
168-
ui_element=Channel.Pull(context),
169-
completion=Channel.Pull(context),
170-
win32_interrupt=Channel.Pull(context)
176+
control=Channel.Pull(context, msg_type=CommandMessage),
177+
ui_element=Channel.Pull(context, msg_type=BatchableCommand),
178+
completion=Channel.Pull(context, msg_type=CodeCompletionCommand),
179+
win32_interrupt=Channel.Pull(context, msg_type=bool)
171180
if connection_info.win32_interrupt
172181
else None,
173-
input=Channel.Pull(context, maxsize=1),
182+
input=Channel.Pull(context, msg_type=str, maxsize=1),
174183
stream=Channel.Push(context),
175184
)
176185

marimo/_ipc/queue_proxy.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33

44
from __future__ import annotations
55

6-
import pickle
76
import threading
87
import typing
98

9+
import msgspec
10+
1011
from marimo import _loggers
1112
from marimo._session.queue import QueueType
1213

@@ -17,6 +18,8 @@
1718
if typing.TYPE_CHECKING:
1819
import zmq
1920

21+
from marimo._ipc.connection import Channel
22+
2023

2124
class PushQueue(QueueType[T]):
2225
"""Queue for pushing messages through ZeroMQ socket (sender side only).
@@ -40,7 +43,7 @@ def put(
4043
timeout: float | None = None, # noqa: ARG002
4144
) -> None:
4245
"""Put an item into the queue."""
43-
self.socket.send(pickle.dumps(obj))
46+
self.socket.send(msgspec.msgpack.encode(obj))
4447

4548
def put_nowait(self, obj: T) -> None:
4649
"""Put an item into the queue without blocking."""
@@ -63,29 +66,34 @@ def empty(self) -> bool:
6366

6467

6568
def start_receiver_thread(
66-
receivers: dict[zmq.Socket[bytes], QueueType[typing.Any]],
69+
channels: list[Channel[typing.Any]],
6770
) -> tuple[threading.Event, threading.Thread]:
6871
"""Start receiver thread."""
6972
import zmq
7073

7174
def receive_loop(
72-
receivers: dict[zmq.Socket[bytes], QueueType[typing.Any]],
75+
channels: list[Channel[typing.Any]],
7376
stop_event: threading.Event,
7477
) -> None:
7578
"""Receive messages from sockets and put them in queues using polling."""
7679
poller = zmq.Poller()
77-
for socket in receivers:
78-
poller.register(socket, zmq.POLLIN)
80+
socket_to_channel: dict[zmq.Socket[bytes], Channel[typing.Any]] = {}
81+
for channel in channels:
82+
poller.register(channel.socket, zmq.POLLIN)
83+
socket_to_channel[channel.socket] = channel
7984

8085
while not stop_event.is_set():
8186
try:
8287
# Poll with 100ms timeout
8388
socks = dict(poller.poll(100))
8489
for socket, event in socks.items():
8590
if event & zmq.POLLIN:
91+
ch = socket_to_channel[socket]
8692
msg = socket.recv(flags=zmq.NOBLOCK)
87-
obj = pickle.loads(msg)
88-
receivers[socket].put(obj)
93+
assert ch.decoder is not None, (
94+
"Pull channel must have a decoder"
95+
)
96+
ch.queue.put(ch.decoder.decode(msg))
8997
except zmq.Again:
9098
continue
9199
except zmq.ZMQError as e:
@@ -101,7 +109,7 @@ def receive_loop(
101109
stop_event = threading.Event()
102110
thread = threading.Thread(
103111
target=receive_loop,
104-
args=(receivers, stop_event),
112+
args=(channels, stop_event),
105113
daemon=True,
106114
)
107115
thread.start()

tests/_ipc/test_kernel_communication.py

Lines changed: 141 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import sys
99
import time
1010

11+
import msgspec
1112
import pytest
1213
from dirty_equals import IsFloat, IsList, IsUUID
1314

@@ -18,7 +19,14 @@
1819
from marimo._dependencies.dependencies import DependencyManager
1920
from marimo._runtime.commands import (
2021
AppMetadata,
22+
BatchableCommand,
23+
CodeCompletionCommand,
24+
CommandMessage,
2125
ExecuteCellsCommand,
26+
HTTPRequest,
27+
ModelCommand,
28+
ModelUpdateMessage,
29+
UpdateUIElementCommand,
2230
)
2331
from marimo._types.ids import CellId_t
2432

@@ -259,9 +267,141 @@ def test_queue_manager_connection():
259267
host_manager.control_queue.put(test_request)
260268
assert client_manager.control_queue.get(timeout=1) == test_request
261269

262-
kernel_message = ("test-op", b'{"data": "test"}')
270+
kernel_message = b'{"op": "test-op", "data": "test"}'
263271
client_manager.stream_queue.put(kernel_message)
264272
assert host_manager.stream_queue.get(timeout=1) == kernel_message
265273

266274
host_manager.close_queues()
267275
client_manager.close_queues()
276+
277+
278+
class TestMsgpackIPC:
279+
"""Test msgspec msgpack serialization for IPC channels.
280+
281+
Each IPC channel has a specific type that must survive encode/decode.
282+
These tests mirror the channel types in Connection.create/connect:
283+
284+
control: CommandMessage (discriminated union)
285+
ui_element: BatchableCommand (discriminated union)
286+
completion: CodeCompletionCommand
287+
input: str
288+
win32_interrupt: bool
289+
stream: bytes
290+
"""
291+
292+
def test_control_channel(self) -> None:
293+
"""CommandMessage union dispatches to the correct subtype."""
294+
cmd = ExecuteCellsCommand(cell_ids=["c1"], codes=["x=1"])
295+
encoded = msgspec.msgpack.encode(cmd)
296+
decoded = msgspec.msgpack.Decoder(CommandMessage).decode(encoded)
297+
assert decoded == cmd
298+
299+
def test_control_channel_with_http_request(self) -> None:
300+
"""HTTPRequest (a @dataclass) survives the round-trip as a nested field."""
301+
req = HTTPRequest(
302+
url={"path": "/run"},
303+
base_url={"path": "/"},
304+
headers={"content-type": "application/json"},
305+
query_params={"key": ["val1", "val2"]},
306+
path_params={},
307+
cookies={},
308+
meta={},
309+
user={},
310+
)
311+
cmd = ExecuteCellsCommand(cell_ids=["c1"], codes=["x=1"], request=req)
312+
encoded = msgspec.msgpack.encode(cmd)
313+
decoded = msgspec.msgpack.Decoder(CommandMessage).decode(encoded)
314+
assert decoded.request is not None
315+
assert decoded.request["url"]["path"] == "/run"
316+
assert decoded.request["query_params"]["key"] == ["val1", "val2"]
317+
318+
def test_ui_element_channel(self) -> None:
319+
"""BatchableCommand union round-trips both member types."""
320+
ui = UpdateUIElementCommand(
321+
object_ids=["e1"], values=[{"nested": [1, 2]}], token="tok"
322+
)
323+
encoded = msgspec.msgpack.encode(ui)
324+
assert msgspec.msgpack.Decoder(BatchableCommand).decode(encoded) == ui
325+
326+
msg = ModelUpdateMessage(state={"k": "v"}, buffer_paths=[])
327+
model = ModelCommand(
328+
model_id="m1", message=msg, buffers=[b"buf"], token="tok"
329+
)
330+
encoded = msgspec.msgpack.encode(model)
331+
assert (
332+
msgspec.msgpack.Decoder(BatchableCommand).decode(encoded) == model
333+
)
334+
335+
def test_completion_channel(self) -> None:
336+
cmd = CodeCompletionCommand(id="t", document="x = 1", cell_id="c1")
337+
encoded = msgspec.msgpack.encode(cmd)
338+
decoded = msgspec.msgpack.Decoder(CodeCompletionCommand).decode(
339+
encoded
340+
)
341+
assert decoded == cmd
342+
343+
def test_primitive_channels(self) -> None:
344+
"""str (input), bool (win32_interrupt), and bytes (stream) channels."""
345+
for value, typ in [
346+
("user_input", str),
347+
(True, bool),
348+
(b'{"op": "cell-op"}', bytes),
349+
]:
350+
encoded = msgspec.msgpack.encode(value)
351+
assert msgspec.msgpack.Decoder(typ).decode(encoded) == value
352+
353+
def test_unknown_fields_are_ignored(self) -> None:
354+
"""Decoder silently drops fields it doesn't recognize.
355+
356+
This matters when the sender is newer than the receiver (e.g.
357+
a field was added to a command). msgspec must not reject the
358+
message — it should decode the known fields and discard the rest.
359+
"""
360+
361+
# A "V2" struct with the same tag but an extra field
362+
class ExecuteCellsCommandV2(
363+
msgspec.Struct,
364+
rename="camel",
365+
tag_field="type",
366+
tag="execute-cells",
367+
):
368+
cell_ids: list[str]
369+
codes: list[str]
370+
new_field: str = "added_later"
371+
372+
v2 = ExecuteCellsCommandV2(
373+
cell_ids=["c1"], codes=["x=1"], new_field="extra"
374+
)
375+
encoded = msgspec.msgpack.encode(v2)
376+
377+
decoded = msgspec.msgpack.Decoder(CommandMessage).decode(encoded)
378+
assert type(decoded) is ExecuteCellsCommand
379+
assert decoded.cell_ids == ["c1"]
380+
assert decoded.codes == ["x=1"]
381+
382+
def test_missing_optional_fields_get_defaults(self) -> None:
383+
"""Decoder fills in defaults for fields the sender didn't include.
384+
385+
This matters when the receiver is newer than the sender (e.g.
386+
a field was added with a default, but the sender hasn't been
387+
updated yet).
388+
"""
389+
390+
class ExecuteCellsCommandV2(
391+
msgspec.Struct,
392+
rename="camel",
393+
tag_field="type",
394+
tag="execute-cells",
395+
):
396+
cell_ids: list[str]
397+
codes: list[str]
398+
new_field: str = "default_value"
399+
400+
old = ExecuteCellsCommand(cell_ids=["c1"], codes=["x=1"])
401+
encoded = msgspec.msgpack.encode(old)
402+
403+
decoded = msgspec.msgpack.Decoder(ExecuteCellsCommandV2).decode(
404+
encoded
405+
)
406+
assert decoded.cell_ids == ["c1"]
407+
assert decoded.new_field == "default_value"

0 commit comments

Comments
 (0)