diff --git a/src/workerd/api/BUILD.bazel b/src/workerd/api/BUILD.bazel index 99fd4a0d8eb..b7b1cadfe87 100644 --- a/src/workerd/api/BUILD.bazel +++ b/src/workerd/api/BUILD.bazel @@ -85,6 +85,7 @@ filegroup( "unsafe.h", "url-standard.h", "web-socket.h", + "web-socket-data-message.h", "worker-rpc.h", ] + glob( [ diff --git a/src/workerd/api/actor-state.c++ b/src/workerd/api/actor-state.c++ index 3f4b406c363..5c4a4db9c67 100644 --- a/src/workerd/api/actor-state.c++ +++ b/src/workerd/api/actor-state.c++ @@ -1223,16 +1223,18 @@ void DurableObjectState::setWebSocketAutoResponse( auto reqResp = KJ_REQUIRE_NONNULL(kj::mv(maybeReqResp)); auto maxRequestOrResponseSize = 2048; - JSG_REQUIRE(reqResp->getRequest().size() <= maxRequestOrResponseSize, RangeError, + auto request = WebSocketDataMessage(reqResp->getRequest()); + auto response = WebSocketDataMessage(reqResp->getResponse()); + + JSG_REQUIRE(request.size() <= maxRequestOrResponseSize, RangeError, kj::str("Request cannot be larger than ", maxRequestOrResponseSize, " bytes. ", - "A request of size ", reqResp->getRequest().size(), " was provided.")); + "A request of size ", request.size(), " was provided.")); - JSG_REQUIRE(reqResp->getResponse().size() <= maxRequestOrResponseSize, RangeError, + JSG_REQUIRE(response.size() <= maxRequestOrResponseSize, RangeError, kj::str("Response cannot be larger than ", maxRequestOrResponseSize, " bytes. ", - "A response of size ", reqResp->getResponse().size(), " was provided.")); + "A response of size ", response.size(), " was provided.")); - maybeInitHibernationManager(a).setWebSocketAutoResponse( - reqResp->getRequest(), reqResp->getResponse()); + maybeInitHibernationManager(a).setWebSocketAutoResponse(kj::mv(request), kj::mv(response)); } kj::Maybe> DurableObjectState::getWebSocketAutoResponse( diff --git a/src/workerd/api/actor-state.h b/src/workerd/api/actor-state.h index a2cc488a708..588a7055e70 100644 --- a/src/workerd/api/actor-state.h +++ b/src/workerd/api/actor-state.h @@ -9,6 +9,7 @@ #include #include +#include #include #include #include @@ -537,20 +538,23 @@ class ActorState: public jsg::Object { class WebSocketRequestResponsePair: public jsg::Object { public: - WebSocketRequestResponsePair(kj::String request, kj::String response) + WebSocketRequestResponsePair(WebSocketDataMessage request, WebSocketDataMessage response) : request(kj::mv(request)), response(kj::mv(response)) {}; - static jsg::Ref constructor( - jsg::Lock& js, kj::String request, kj::String response) { - return js.alloc(kj::mv(request), kj::mv(response)); + static jsg::Ref constructor(jsg::Lock& js, + kj::OneOf, kj::String> request, + kj::OneOf, kj::String> response) { + auto req = WebSocketDataMessage(kj::mv(request)); + auto resp = WebSocketDataMessage(kj::mv(response)); + return js.alloc(kj::mv(req), kj::mv(resp)); }; - kj::StringPtr getRequest() { - return request.asPtr(); + kj::OneOf, kj::String> getRequest() { + return toJsgOneOf(request); } - kj::StringPtr getResponse() { - return response.asPtr(); + kj::OneOf, kj::String> getResponse() { + return toJsgOneOf(response); } JSG_RESOURCE_TYPE(WebSocketRequestResponsePair) { @@ -559,13 +563,25 @@ class WebSocketRequestResponsePair: public jsg::Object { } void visitForMemoryInfo(jsg::MemoryTracker& tracker) const { - tracker.trackField("request", request); - tracker.trackField("response", response); + tracker.trackFieldWithSize("request", request.size()); + tracker.trackFieldWithSize("response", response.size()); } private: - kj::String request; - kj::String response; + WebSocketDataMessage request; + WebSocketDataMessage response; + + static kj::OneOf, kj::String> toJsgOneOf(const WebSocketDataMessage& msg) { + KJ_SWITCH_ONEOF(msg.asOneOf()) { + KJ_CASE_ONEOF(text, kj::StringPtr) { + return kj::str(text); + } + KJ_CASE_ONEOF(bytes, kj::ArrayPtr) { + return kj::heapArray(bytes); + } + } + KJ_UNREACHABLE; + } }; // The type passed as the first parameter to durable object class's constructor. diff --git a/src/workerd/api/web-socket.c++ b/src/workerd/api/web-socket.c++ index ea58697e5b0..4882e85615c 100644 --- a/src/workerd/api/web-socket.c++ +++ b/src/workerd/api/web-socket.c++ @@ -862,11 +862,11 @@ void WebSocket::ensurePumping(jsg::Lock& js) { } } -kj::Promise WebSocket::sendAutoResponse(kj::String message, kj::WebSocket& ws) { +kj::Promise WebSocket::sendAutoResponse(WebSocketDataMessage message, kj::WebSocket& ws) { if (autoResponseStatus.isPumping) { autoResponseStatus.pendingAutoResponseDeque.push(kj::mv(message)); } else if (!autoResponseStatus.isClosed) { - auto p = ws.send(message).fork(); + auto p = message.sendVia(ws).fork(); KJ_IF_SOME(context, IoContext::tryCurrent()) { autoResponseStatus.ongoingAutoResponse.emplace(context.addObject(kj::heap(p.addBranch()))); } else { @@ -965,7 +965,7 @@ kj::Promise WebSocket::pump(IoContext& context, auto message = KJ_ASSERT_NONNULL(autoResponse.pendingAutoResponseDeque.pop()); gatedMessage.pendingAutoResponses--; autoResponse.queuedAutoResponses--; - co_await ws.send(message); + co_await message.sendVia(ws); } KJ_SWITCH_ONEOF(gatedMessage.message) { @@ -997,7 +997,7 @@ kj::Promise WebSocket::pump(IoContext& context, // We should also check if the last sent message was a close. Shouldn't happen. while (!autoResponse.pendingAutoResponseDeque.empty() && !autoResponse.isClosed) { auto message = KJ_ASSERT_NONNULL(autoResponse.pendingAutoResponseDeque.pop()); - co_await ws.send(message); + co_await message.sendVia(ws); } // While we were `co_await`ing the auto-response send, more messages could have been queued diff --git a/src/workerd/api/web-socket.h b/src/workerd/api/web-socket.h index f3fc3ad2d42..b5a735230ce 100644 --- a/src/workerd/api/web-socket.h +++ b/src/workerd/api/web-socket.h @@ -6,6 +6,7 @@ #include "basics.h" #include "events.h" +#include "web-socket-data-message.h" #include #include @@ -336,7 +337,7 @@ class WebSocket: public EventTarget { // These methods are c++ only and are not exposed to our js interface. kj::Maybe getAutoResponseTimestamp(); - kj::Promise sendAutoResponse(kj::String message, kj::WebSocket& ws); + kj::Promise sendAutoResponse(WebSocketDataMessage message, kj::WebSocket& ws); int getReadyState(); @@ -599,15 +600,16 @@ class WebSocket: public EventTarget { using OwnedAutoResponsePromise = kj::OneOf>, kj::Own>>; kj::Maybe ongoingAutoResponse; - workerd::util::Queue pendingAutoResponseDeque; + workerd::util::Queue pendingAutoResponseDeque; size_t queuedAutoResponses = 0; bool isPumping = false; bool isClosed = false; JSG_MEMORY_INFO(AutoResponse) { tracker.trackFieldWithSize("ongoingAutoResponse", sizeof(kj::Promise)); - pendingAutoResponseDeque.forEach( - [&](const kj::String& message) { tracker.trackField(nullptr, message); }); + pendingAutoResponseDeque.forEach([&](const WebSocketDataMessage& message) { + tracker.trackFieldWithSize(nullptr, message.size()); + }); } }; diff --git a/src/workerd/io/BUILD.bazel b/src/workerd/io/BUILD.bazel index 13c4220fa67..9a8f0df455e 100644 --- a/src/workerd/io/BUILD.bazel +++ b/src/workerd/io/BUILD.bazel @@ -101,6 +101,7 @@ wd_cc_library( "//src/workerd/api:fuzzilli", "//src/workerd/api:hibernation-event-params", "//src/workerd/api:url", + "//src/workerd/api:web-socket-data-message", "//src/workerd/jsg", "//src/workerd/jsg:inspector", "//src/workerd/jsg:script", diff --git a/src/workerd/io/hibernation-manager-test.c++ b/src/workerd/io/hibernation-manager-test.c++ index 67e5c364a8e..671b0bcbb65 100644 --- a/src/workerd/io/hibernation-manager-test.c++ +++ b/src/workerd/io/hibernation-manager-test.c++ @@ -136,7 +136,17 @@ kj::Own makeTestHm(TestFixture& fixture) { kj::Own makeTestHm( TestFixture& fixture, kj::StringPtr autoRequest, kj::StringPtr autoResponse) { auto hm = makeTestHm(fixture); - hm->setWebSocketAutoResponse(autoRequest, autoResponse); + hm->setWebSocketAutoResponse(api::WebSocketDataMessage(kj::str(autoRequest)), + api::WebSocketDataMessage(kj::str(autoResponse))); + return hm; +} + +kj::Own makeTestHm(TestFixture& fixture, + kj::ArrayPtr autoRequest, + kj::ArrayPtr autoResponse) { + auto hm = makeTestHm(fixture); + hm->setWebSocketAutoResponse(api::WebSocketDataMessage(kj::heapArray(autoRequest)), + api::WebSocketDataMessage(kj::heapArray(autoResponse))); return hm; } @@ -943,5 +953,291 @@ KJ_TEST("HibernationManager: auto-response (hibernated) bypasses the output gate fixture.drainAndDestroy(kj::mv(request)); } +KJ_TEST("HibernationManager: binary auto-response matches binary frame (active)") { + DispatchStats stats; + TestFixture fixture(stubLoopbackParams(stats, kj::str("bin-autoresp-active"))); + auto req = kj::heapArray({0x01, 0x02, 0x03}); + auto resp = kj::heapArray({0x04, 0x05, 0x06}); + auto hm = makeTestHm(fixture, req.asPtr(), resp.asPtr()); + auto request = fixture.newIncomingRequest(); + auto end1 = acceptNewWebSocket(fixture, *request, *hm); + + end1->send(req.asPtr()).wait(fixture.getWaitScope()); + + auto msg = end1->receive().wait(fixture.getWaitScope()); + KJ_ASSERT(msg.is>()); + auto& got = msg.get>(); + KJ_ASSERT(got.size() == 3); + KJ_ASSERT(got[0] == 0x04 && got[1] == 0x05 && got[2] == 0x06); + + fixture.pollEventLoop(); + KJ_ASSERT(stats.customEventCalls == 0, "binary auto-response should not dispatch to worker", + stats.customEventCalls); + + fixture.drainAndDestroy(kj::mv(request)); +} + +KJ_TEST("HibernationManager: binary auto-response matches binary frame (hibernated)") { + DispatchStats stats; + TestFixture fixture(stubLoopbackParams(stats, kj::str("bin-autoresp-hib"))); + auto req = kj::heapArray({0x01, 0x02, 0x03}); + auto resp = kj::heapArray({0x04, 0x05, 0x06}); + auto hm = makeTestHm(fixture, req.asPtr(), resp.asPtr()); + auto request = fixture.newIncomingRequest(); + auto end1 = acceptNewWebSocket(fixture, *request, *hm); + + fixture.enterWorkerLock([&](Worker::Lock& lock) { hm->hibernateWebSockets(lock); }); + + end1->send(req.asPtr()).wait(fixture.getWaitScope()); + auto msg = end1->receive().wait(fixture.getWaitScope()); + KJ_ASSERT(msg.is>()); + auto& got = msg.get>(); + KJ_ASSERT(got.size() == 3); + KJ_ASSERT(got[0] == 0x04 && got[1] == 0x05 && got[2] == 0x06); + + fixture.pollEventLoop(); + KJ_ASSERT(stats.customEventCalls == 0, "binary auto-response should not dispatch", + stats.customEventCalls); + + fixture.drainAndDestroy(kj::mv(request)); +} + +KJ_TEST("HibernationManager: binary auto-response does NOT match text frame") { + DispatchStats stats; + TestFixture fixture(stubLoopbackParams(stats, kj::str("bin-autoresp-no-text"))); + auto req = kj::heapArray({0x70, 0x69, 0x6e, 0x67}); // "ping" as bytes + auto resp = kj::heapArray({0x70, 0x6f, 0x6e, 0x67}); // "pong" as bytes + auto hm = makeTestHm(fixture, req.asPtr(), resp.asPtr()); + auto request = fixture.newIncomingRequest(); + auto end1 = acceptNewWebSocket(fixture, *request, *hm); + + // Send "ping" as a TEXT frame — should NOT match the binary auto-response. + end1->send("ping"_kj).wait(fixture.getWaitScope()); + + fixture.pollEventLoop(); + KJ_ASSERT(stats.customEventCalls >= 1, + "text frame should be dispatched despite matching binary content", stats.customEventCalls); + + fixture.drainAndDestroy(kj::mv(request)); +} + +KJ_TEST("HibernationManager: text auto-response does NOT match binary frame") { + DispatchStats stats; + TestFixture fixture(stubLoopbackParams(stats, kj::str("text-autoresp-no-bin"))); + auto hm = makeTestHm(fixture, "ping"_kj, "pong"_kj); + auto request = fixture.newIncomingRequest(); + auto end1 = acceptNewWebSocket(fixture, *request, *hm); + + // Send "ping" as a BINARY frame — should NOT match the text auto-response. + auto pingBytes = kj::heapArray({0x70, 0x69, 0x6e, 0x67}); + end1->send(pingBytes.asPtr()).wait(fixture.getWaitScope()); + + fixture.pollEventLoop(); + KJ_ASSERT(stats.customEventCalls >= 1, + "binary frame should be dispatched despite matching text content", stats.customEventCalls); + + fixture.drainAndDestroy(kj::mv(request)); +} + +KJ_TEST("HibernationManager: binary auto-response interleaved with DO sends (active)") { + DispatchStats stats; + TestFixture fixture(stubLoopbackParams(stats, kj::str("bin-autoresp-interleaved"))); + auto req = kj::heapArray({0x01, 0x02}); + auto resp = kj::heapArray({0x03, 0x04}); + auto hm = makeTestHm(fixture, req.asPtr(), resp.asPtr()); + auto request = fixture.newIncomingRequest(); + auto end1 = acceptNewWebSocket(fixture, *request, *hm); + + sendFromDo(fixture, *request, *hm, "before"_kj); + end1->send(req.asPtr()).wait(fixture.getWaitScope()); + sendFromDo(fixture, *request, *hm, "after"_kj); + + bool sawBefore = false, sawPong = false, sawAfter = false; + for (int i = 0; i < 3; ++i) { + auto msg = end1->receive().wait(fixture.getWaitScope()); + KJ_SWITCH_ONEOF(msg) { + KJ_CASE_ONEOF(text, kj::String) { + if (text == "before"_kj) + sawBefore = true; + else if (text == "after"_kj) + sawAfter = true; + else + KJ_FAIL_ASSERT("unexpected text message", text); + } + KJ_CASE_ONEOF(data, kj::Array) { + KJ_ASSERT(data.size() == 2 && data[0] == 0x03 && data[1] == 0x04); + sawPong = true; + } + KJ_CASE_ONEOF_DEFAULT { + KJ_FAIL_ASSERT("unexpected message type"); + } + } + } + KJ_ASSERT(sawBefore && sawPong && sawAfter); + + KJ_ASSERT(stats.customEventCalls == 0, "auto-response should not dispatch to worker", + stats.customEventCalls); + + fixture.drainAndDestroy(kj::mv(request)); +} + +KJ_TEST("HibernationManager: getWebSocketAutoResponse round-trips binary data") { + DispatchStats stats; + TestFixture fixture(stubLoopbackParams(stats, kj::str("bin-autoresp-roundtrip"))); + auto req = kj::heapArray({0xde, 0xad}); + auto resp = kj::heapArray({0xbe, 0xef}); + auto hm = makeTestHm(fixture, req.asPtr(), resp.asPtr()); + auto request = fixture.newIncomingRequest(); + + fixture.enterContext(*request, [&](const TestFixture::Environment& env) { + auto maybePair = hm->getWebSocketAutoResponse(env.js); + auto pair = KJ_ASSERT_NONNULL(kj::mv(maybePair)); + auto gotReq = pair->getRequest(); + auto gotResp = pair->getResponse(); + + KJ_ASSERT(gotReq.is>()); + auto& reqBytes = gotReq.get>(); + KJ_ASSERT(reqBytes.size() == 2 && reqBytes[0] == 0xde && reqBytes[1] == 0xad); + + KJ_ASSERT(gotResp.is>()); + auto& respBytes = gotResp.get>(); + KJ_ASSERT(respBytes.size() == 2 && respBytes[0] == 0xbe && respBytes[1] == 0xef); + }); + + fixture.drainAndDestroy(kj::mv(request)); +} + +KJ_TEST("HibernationManager: getWebSocketAutoResponse round-trips text data") { + DispatchStats stats; + TestFixture fixture(stubLoopbackParams(stats, kj::str("text-autoresp-roundtrip"))); + auto hm = makeTestHm(fixture, "hello"_kj, "world"_kj); + auto request = fixture.newIncomingRequest(); + + fixture.enterContext(*request, [&](const TestFixture::Environment& env) { + auto maybePair = hm->getWebSocketAutoResponse(env.js); + auto pair = KJ_ASSERT_NONNULL(kj::mv(maybePair)); + auto gotReq = pair->getRequest(); + auto gotResp = pair->getResponse(); + + KJ_ASSERT(gotReq.is()); + KJ_ASSERT(gotReq.get() == "hello"_kj); + + KJ_ASSERT(gotResp.is()); + KJ_ASSERT(gotResp.get() == "world"_kj); + }); + + fixture.drainAndDestroy(kj::mv(request)); +} + +KJ_TEST("HibernationManager: clear auto-response with kj::none after binary set") { + DispatchStats stats; + TestFixture fixture(stubLoopbackParams(stats, kj::str("bin-autoresp-clear"))); + auto req = kj::heapArray({0x01, 0x02}); + auto resp = kj::heapArray({0x03, 0x04}); + auto hm = makeTestHm(fixture, req.asPtr(), resp.asPtr()); + auto request = fixture.newIncomingRequest(); + auto end1 = acceptNewWebSocket(fixture, *request, *hm); + + hm->setWebSocketAutoResponse(kj::none, kj::none); + + end1->send(req.asPtr()).wait(fixture.getWaitScope()); + + fixture.pollEventLoop(); + KJ_ASSERT(stats.customEventCalls >= 1, "after clearing, binary frame should be dispatched", + stats.customEventCalls); + + fixture.drainAndDestroy(kj::mv(request)); +} + +KJ_TEST( + "HibernationManager: in-flight binary auto-response orphans BlockedSend during hibernation") { + KJ_EXPECT_LOG(ERROR, "another message send is already in progress"); + + DispatchStats stats; + TestFixture fixture(stubLoopbackParams(stats, kj::str("ew-10817-bin-autoresp"))); + auto req = kj::heapArray({0x01, 0x02}); + auto resp = kj::heapArray({0x03, 0x04}); + auto hm = makeTestHm(fixture, req.asPtr(), resp.asPtr()); + auto request = fixture.newIncomingRequest(); + auto end1 = acceptNewWebSocket(fixture, *request, *hm); + + end1->send(req.asPtr()).wait(fixture.getWaitScope()); + fixture.pollEventLoop(); + + fixture.enterWorkerLock([&](Worker::Lock& lock) { hm->hibernateWebSockets(lock); }); + + fixture.enterContext(*request, [&](const TestFixture::Environment& env) { + auto& js = env.js; + auto websockets = hm->getWebSockets(js, kj::none); + KJ_ASSERT(websockets.size() == 1); + websockets[0]->close(js, 1001, jsg::USVString(kj::str("stale"))); + }); + + fixture.pollEventLoop(); + end1->receive().wait(fixture.getWaitScope()); +} + +KJ_TEST( + "HibernationManager: in-flight binary auto-response orphans BlockedSend across actor eviction") { + KJ_EXPECT_LOG(ERROR, "another message send is already in progress"); + + DispatchStats stats; + TestFixture fixture(stubLoopbackParams(stats, kj::str("ew-10817-cross-bin-autoresp"))); + auto req = kj::heapArray({0x01, 0x02}); + auto resp = kj::heapArray({0x03, 0x04}); + auto hm = makeTestHm(fixture, req.asPtr(), resp.asPtr()); + + auto request1 = fixture.newIncomingRequest(); + auto end1 KJ_UNUSED = acceptNewWebSocket(fixture, *request1, *hm); + + end1->send(req.asPtr()).wait(fixture.getWaitScope()); + fixture.pollEventLoop(); + + fixture.enterWorkerLock([&](Worker::Lock& lock) { hm->hibernateWebSockets(lock); }); + request1 = nullptr; + fixture.resetActor(); + + auto request2 = fixture.newIncomingRequest(); + fixture.enterContext(*request2, [&](const TestFixture::Environment& env) { + auto& js = env.js; + auto websockets = hm->getWebSockets(js, kj::none); + KJ_ASSERT(websockets.size() == 1); + websockets[0]->close(js, 1001, jsg::USVString(kj::str("post-evict"))); + }); + fixture.pollEventLoop(); + + end1->receive().wait(fixture.getWaitScope()); +} + +KJ_TEST("HibernationManager: setWebSocketAutoResponse allows mismatched types") { + DispatchStats stats; + TestFixture fixture(stubLoopbackParams(stats, kj::str("mixed-types"))); + auto hm = makeTestHm(fixture); + + kj::byte binaryData[] = {0x01, 0x02, 0x03}; + + hm->setWebSocketAutoResponse(api::WebSocketDataMessage(kj::str("ping")), + api::WebSocketDataMessage(kj::heapArray(binaryData))); + + auto request = fixture.newIncomingRequest(); + + fixture.enterContext(*request, [&](const TestFixture::Environment& env) { + auto maybePair = hm->getWebSocketAutoResponse(env.js); + auto pair = KJ_ASSERT_NONNULL(kj::mv(maybePair)); + auto gotReq = pair->getRequest(); + auto gotResp = pair->getResponse(); + + KJ_ASSERT(gotReq.is()); + KJ_ASSERT(gotReq.get() == "ping"_kj); + + KJ_ASSERT(gotResp.is>()); + auto& respBytes = gotResp.get>(); + KJ_ASSERT(respBytes.size() == 3 && respBytes[0] == 0x01 && respBytes[1] == 0x02 && + respBytes[2] == 0x03); + }); + + fixture.drainAndDestroy(kj::mv(request)); +} + } // namespace } // namespace workerd diff --git a/src/workerd/io/hibernation-manager.c++ b/src/workerd/io/hibernation-manager.c++ index 0e1d5fb13d3..46a8679a261 100644 --- a/src/workerd/io/hibernation-manager.c++ +++ b/src/workerd/io/hibernation-manager.c++ @@ -189,25 +189,20 @@ kj::Vector> HibernationManagerImpl::getWebSockets( } void HibernationManagerImpl::setWebSocketAutoResponse( - kj::Maybe request, kj::Maybe response) { + kj::Maybe request, kj::Maybe response) { KJ_IF_SOME(req, request) { - // If we have a request, we must also have a response. If response is kj::none, we'll throw. - autoResponsePair->request = kj::str(req); - autoResponsePair->response = kj::str(KJ_REQUIRE_NONNULL(response)); + auto& resp = KJ_REQUIRE_NONNULL(response); + autoResponsePair->pair = AutoRequestResponsePair::Pair{kj::mv(req), kj::mv(resp)}; return; } - // If we don't have a request, we must unset both request and response. - autoResponsePair->request = kj::none; - autoResponsePair->response = kj::none; + autoResponsePair->pair = kj::none; } kj::Maybe> HibernationManagerImpl:: getWebSocketAutoResponse(jsg::Lock& js) { - KJ_IF_SOME(req, autoResponsePair->request) { - // When getting the currently set auto-response pair, if we have a request we must have a response - // set. If not, we'll throw. - return api::WebSocketRequestResponsePair::constructor( - js, kj::str(req), kj::str(KJ_REQUIRE_NONNULL(autoResponsePair->response))); + KJ_IF_SOME(pair, autoResponsePair->pair) { + return js.alloc( + pair.request.asPtr().toOwned(), pair.response.asPtr().toOwned()); } return kj::none; } @@ -293,55 +288,47 @@ kj::Promise HibernationManagerImpl::readLoop(HibernatableWebSocket& hib) { auto skip = false; - // If we have a request != kj::none, we can compare it the received message. This also implies - // that we have a response set in autoResponsePair. - KJ_IF_SOME(req, autoResponsePair->request) { - KJ_SWITCH_ONEOF(message) { - KJ_CASE_ONEOF(text, kj::String) { - if (text == req) { - // If the received message matches the one set for auto-response, we must - // short-circuit readLoop, store the current timestamp and and automatically respond - // with the expected response. - TimerChannel& timerChannel = KJ_REQUIRE_NONNULL(timer); - // This should count as a new IO event, hence we should call syncTime - // otherwise the autoResponseTimestamp wouldn't be accurate. - timerChannel.syncTime(); - // We should have set the timerChannel previously in the hibernation manager. - // If we haven't, we aren't able to get the current time. - hib.autoResponseTimestamp = timerChannel.now(); - // We'll store the current timestamp in the HibernatableWebSocket to assure it gets - // stored even if the WebSocket is currently hibernating. In that scenario, the timestamp - // value will be loaded into the WebSocket during unhibernation. - KJ_SWITCH_ONEOF(hib.activeOrPackage) { - KJ_CASE_ONEOF(apiWs, jsg::Ref) { - // If the actor is not hibernated/If the WebSocket is active, we need to update - // autoResponseTimestamp on the active websocket. - apiWs->setAutoResponseStatus(hib.autoResponseTimestamp, kj::READY_NOW); - // Since we had a request set, we must have and response that's sent back using the - // same websocket here. The sending of response is managed in web-socket to avoid - // possible racing problems with regular websocket messages. - co_await apiWs->sendAutoResponse( - kj::str(KJ_REQUIRE_NONNULL(autoResponsePair->response).asArray()), ws); - } - KJ_CASE_ONEOF(package, api::WebSocket::HibernationPackage) { - if (!package.closedOutgoingConnection) { - // We need to store the autoResponsePromise because we may instantiate an api::websocket - // If we do that, we have to provide it with the promise to avoid races. This can - // happen if we have a websocket hibernating, that unhibernates and sends a - // message while ws.send() for auto-response is also sending. - auto p = ws.send(KJ_REQUIRE_NONNULL(autoResponsePair->response).asArray()).fork(); - hib.autoResponsePromise = p.addBranch(); - co_await p; - hib.autoResponsePromise = kj::READY_NOW; - } - } + KJ_IF_SOME(pair, autoResponsePair->pair) { + bool matched = false; + if (pair.request.isText()) { + KJ_IF_SOME(text, message.tryGet()) { + matched = (text == pair.request); + } + } else if (pair.request.isBinary()) { + KJ_IF_SOME(data, message.tryGet>()) { + matched = (data == pair.request); + } + } + + if (matched) { + // Short-circuit readLoop: store the current timestamp and automatically respond + // with the expected response instead of dispatching to the actor. + TimerChannel& timerChannel = KJ_REQUIRE_NONNULL(timer); + // Count as a new IO event so autoResponseTimestamp is accurate. + timerChannel.syncTime(); + // Store the timestamp on the HibernatableWebSocket so it survives hibernation. + // During unhibernation, this value is loaded into the new api::WebSocket. + hib.autoResponseTimestamp = timerChannel.now(); + + KJ_SWITCH_ONEOF(hib.activeOrPackage) { + KJ_CASE_ONEOF(apiWs, jsg::Ref) { + apiWs->setAutoResponseStatus(hib.autoResponseTimestamp, kj::READY_NOW); + // Response sending is managed in web-socket to avoid racing with regular messages. + co_await apiWs->sendAutoResponse(pair.response.asPtr().toOwned(), ws); + } + KJ_CASE_ONEOF(package, api::WebSocket::HibernationPackage) { + if (!package.closedOutgoingConnection) { + // Store the autoResponsePromise so that if the WebSocket unhibernates mid-send, + // the new api::WebSocket can await it to avoid send races. + auto p = pair.response.sendVia(ws).fork(); + hib.autoResponsePromise = p.addBranch(); + co_await p; + hib.autoResponsePromise = kj::READY_NOW; } - // If we've sent an auto response message, we should not unhibernate or deliver the - // received message to the actor - skip = true; } } - KJ_CASE_ONEOF_DEFAULT {} + // Do not unhibernate or deliver the message to the actor. + skip = true; } } diff --git a/src/workerd/io/hibernation-manager.h b/src/workerd/io/hibernation-manager.h index a81f734bea8..77edb92fe60 100644 --- a/src/workerd/io/hibernation-manager.h +++ b/src/workerd/io/hibernation-manager.h @@ -6,6 +6,7 @@ #include #include +#include #include #include @@ -35,8 +36,8 @@ class HibernationManagerImpl final: public Worker::Actor::HibernationManager { // This converts our activeOrPackage from an api::WebSocket to a HibernationPackage. void hibernateWebSockets(Worker::Lock& lock) override; - void setWebSocketAutoResponse( - kj::Maybe request, kj::Maybe response) override; + void setWebSocketAutoResponse(kj::Maybe request, + kj::Maybe response) override; kj::Maybe> getWebSocketAutoResponse( jsg::Lock& js) override; void setTimerChannel(TimerChannel& timerChannel) override; @@ -165,13 +166,13 @@ class HibernationManagerImpl final: public Worker::Actor::HibernationManager { TagCollection(TagCollection&& other) = default; }; - // This structure will hold the request and corresponding response for hibernatable websockets - // auto-response feature. Although we store 2 kj::Maybe strings, if we don't have a request set - // we can't have a response, and vice versa. - // TODO(cleanup): Remove kj::Maybe from request and response strings. + // Holds the request/response pair for hibernatable websockets auto-response feature. struct AutoRequestResponsePair { - kj::Maybe request = kj::none; - kj::Maybe response = kj::none; + struct Pair { + api::WebSocketDataMessage request; + api::WebSocketDataMessage response; + }; + kj::Maybe pair; }; // A hashmap of tags to HibernatableWebSockets associated with the tag. diff --git a/src/workerd/io/worker.h b/src/workerd/io/worker.h index 8c3c6d1f376..63b997733f5 100644 --- a/src/workerd/io/worker.h +++ b/src/workerd/io/worker.h @@ -5,6 +5,7 @@ #pragma once // Classes to manage lifetime of workers, scripts, and isolates. +#include #include // because we can't forward-declare ActorCache::SharedLru. #include #include @@ -888,8 +889,8 @@ class Worker::Actor final: public kj::Refcounted { virtual kj::Vector> getWebSockets( jsg::Lock& js, kj::Maybe tag) = 0; virtual void hibernateWebSockets(Worker::Lock& lock) = 0; - virtual void setWebSocketAutoResponse( - kj::Maybe request, kj::Maybe response) = 0; + virtual void setWebSocketAutoResponse(kj::Maybe request, + kj::Maybe response) = 0; virtual kj::Maybe> getWebSocketAutoResponse( jsg::Lock& js) = 0; virtual void setTimerChannel(TimerChannel& timerChannel) = 0; diff --git a/types/generated-snapshot/experimental/index.d.ts b/types/generated-snapshot/experimental/index.d.ts index c0bd4661990..5dfcd7e4fee 100755 --- a/types/generated-snapshot/experimental/index.d.ts +++ b/types/generated-snapshot/experimental/index.d.ts @@ -823,9 +823,12 @@ interface DurableObjectSetAlarmOptions { allowUnconfirmed?: boolean; } declare class WebSocketRequestResponsePair { - constructor(request: string, response: string); - get request(): string; - get response(): string; + constructor( + request: (ArrayBuffer | ArrayBufferView) | string, + response: (ArrayBuffer | ArrayBufferView) | string, + ); + get request(): ArrayBuffer | string; + get response(): ArrayBuffer | string; } interface DurableObjectFacets { get( diff --git a/types/generated-snapshot/experimental/index.ts b/types/generated-snapshot/experimental/index.ts index b17203e0e78..b06336fb629 100755 --- a/types/generated-snapshot/experimental/index.ts +++ b/types/generated-snapshot/experimental/index.ts @@ -825,9 +825,12 @@ export interface DurableObjectSetAlarmOptions { allowUnconfirmed?: boolean; } export declare class WebSocketRequestResponsePair { - constructor(request: string, response: string); - get request(): string; - get response(): string; + constructor( + request: (ArrayBuffer | ArrayBufferView) | string, + response: (ArrayBuffer | ArrayBufferView) | string, + ); + get request(): ArrayBuffer | string; + get response(): ArrayBuffer | string; } export interface DurableObjectFacets { get( diff --git a/types/generated-snapshot/latest/index.d.ts b/types/generated-snapshot/latest/index.d.ts index 65623953506..9a2a10fe63c 100755 --- a/types/generated-snapshot/latest/index.d.ts +++ b/types/generated-snapshot/latest/index.d.ts @@ -775,9 +775,12 @@ interface DurableObjectSetAlarmOptions { allowUnconfirmed?: boolean; } declare class WebSocketRequestResponsePair { - constructor(request: string, response: string); - get request(): string; - get response(): string; + constructor( + request: (ArrayBuffer | ArrayBufferView) | string, + response: (ArrayBuffer | ArrayBufferView) | string, + ); + get request(): ArrayBuffer | string; + get response(): ArrayBuffer | string; } interface DurableObjectFacets { get( diff --git a/types/generated-snapshot/latest/index.ts b/types/generated-snapshot/latest/index.ts index ab8cb2fada2..b51fa3b2b30 100755 --- a/types/generated-snapshot/latest/index.ts +++ b/types/generated-snapshot/latest/index.ts @@ -777,9 +777,12 @@ export interface DurableObjectSetAlarmOptions { allowUnconfirmed?: boolean; } export declare class WebSocketRequestResponsePair { - constructor(request: string, response: string); - get request(): string; - get response(): string; + constructor( + request: (ArrayBuffer | ArrayBufferView) | string, + response: (ArrayBuffer | ArrayBufferView) | string, + ); + get request(): ArrayBuffer | string; + get response(): ArrayBuffer | string; } export interface DurableObjectFacets { get(