Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/workerd/api/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ filegroup(
"unsafe.h",
"url-standard.h",
"web-socket.h",
"web-socket-data-message.h",
"worker-rpc.h",
] + glob(
[
Expand Down
14 changes: 8 additions & 6 deletions src/workerd/api/actor-state.c++
Original file line number Diff line number Diff line change
Expand Up @@ -1223,16 +1223,18 @@ void DurableObjectState::setWebSocketAutoResponse(
auto reqResp = KJ_REQUIRE_NONNULL(kj::mv(maybeReqResp));
auto maxRequestOrResponseSize = 2048;

JSG_REQUIRE(reqResp->getRequest().size() <= maxRequestOrResponseSize, RangeError,
auto request = WebSocketDataMessage(reqResp->getRequest());
auto response = WebSocketDataMessage(reqResp->getResponse());

JSG_REQUIRE(request.size() <= maxRequestOrResponseSize, RangeError,
kj::str("Request cannot be larger than ", maxRequestOrResponseSize, " bytes. ",
"A request of size ", reqResp->getRequest().size(), " was provided."));
"A request of size ", request.size(), " was provided."));

JSG_REQUIRE(reqResp->getResponse().size() <= maxRequestOrResponseSize, RangeError,
JSG_REQUIRE(response.size() <= maxRequestOrResponseSize, RangeError,
kj::str("Response cannot be larger than ", maxRequestOrResponseSize, " bytes. ",
"A response of size ", reqResp->getResponse().size(), " was provided."));
"A response of size ", response.size(), " was provided."));

maybeInitHibernationManager(a).setWebSocketAutoResponse(
reqResp->getRequest(), reqResp->getResponse());
maybeInitHibernationManager(a).setWebSocketAutoResponse(kj::mv(request), kj::mv(response));
}

kj::Maybe<jsg::Ref<api::WebSocketRequestResponsePair>> DurableObjectState::getWebSocketAutoResponse(
Expand Down
40 changes: 28 additions & 12 deletions src/workerd/api/actor-state.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include <workerd/api/actor.h>
#include <workerd/api/container.h>
#include <workerd/api/web-socket-data-message.h>
#include <workerd/io/actor-cache.h>
#include <workerd/io/actor-id.h>
#include <workerd/io/compatibility-date.capnp.h>
Expand Down Expand Up @@ -537,20 +538,23 @@ class ActorState: public jsg::Object {

class WebSocketRequestResponsePair: public jsg::Object {
public:
WebSocketRequestResponsePair(kj::String request, kj::String response)
WebSocketRequestResponsePair(WebSocketDataMessage request, WebSocketDataMessage response)
: request(kj::mv(request)),
response(kj::mv(response)) {};

static jsg::Ref<WebSocketRequestResponsePair> constructor(
jsg::Lock& js, kj::String request, kj::String response) {
return js.alloc<WebSocketRequestResponsePair>(kj::mv(request), kj::mv(response));
static jsg::Ref<WebSocketRequestResponsePair> constructor(jsg::Lock& js,
kj::OneOf<kj::Array<kj::byte>, kj::String> request,
kj::OneOf<kj::Array<kj::byte>, kj::String> response) {
auto req = WebSocketDataMessage(kj::mv(request));
auto resp = WebSocketDataMessage(kj::mv(response));
return js.alloc<WebSocketRequestResponsePair>(kj::mv(req), kj::mv(resp));
};

kj::StringPtr getRequest() {
return request.asPtr();
kj::OneOf<kj::Array<kj::byte>, kj::String> getRequest() {
return toJsgOneOf(request);
}
kj::StringPtr getResponse() {
return response.asPtr();
kj::OneOf<kj::Array<kj::byte>, kj::String> getResponse() {
return toJsgOneOf(response);
}

JSG_RESOURCE_TYPE(WebSocketRequestResponsePair) {
Expand All @@ -559,13 +563,25 @@ class WebSocketRequestResponsePair: public jsg::Object {
}

void visitForMemoryInfo(jsg::MemoryTracker& tracker) const {
tracker.trackField("request", request);
tracker.trackField("response", response);
tracker.trackFieldWithSize("request", request.size());
tracker.trackFieldWithSize("response", response.size());
Comment on lines +566 to +567
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know all of the nuances here, but I think this diminishes the usefulness of this tracking. Not sure what the proper fix is - maybe adding something like trackMemory(jsg::MemoryTracker&, const char*) to WebSocketDataMessage?

}

private:
kj::String request;
kj::String response;
WebSocketDataMessage request;
WebSocketDataMessage response;

static kj::OneOf<kj::Array<kj::byte>, kj::String> toJsgOneOf(const WebSocketDataMessage& msg) {
KJ_SWITCH_ONEOF(msg.asOneOf()) {
KJ_CASE_ONEOF(text, kj::StringPtr) {
return kj::str(text);
}
KJ_CASE_ONEOF(bytes, kj::ArrayPtr<const kj::byte>) {
return kj::heapArray(bytes);
}
}
KJ_UNREACHABLE;
}
};

// The type passed as the first parameter to durable object class's constructor.
Expand Down
8 changes: 4 additions & 4 deletions src/workerd/api/web-socket.c++
Original file line number Diff line number Diff line change
Expand Up @@ -862,11 +862,11 @@ void WebSocket::ensurePumping(jsg::Lock& js) {
}
}

kj::Promise<void> WebSocket::sendAutoResponse(kj::String message, kj::WebSocket& ws) {
kj::Promise<void> WebSocket::sendAutoResponse(WebSocketDataMessage message, kj::WebSocket& ws) {
if (autoResponseStatus.isPumping) {
autoResponseStatus.pendingAutoResponseDeque.push(kj::mv(message));
} else if (!autoResponseStatus.isClosed) {
auto p = ws.send(message).fork();
auto p = message.sendVia(ws).fork();
KJ_IF_SOME(context, IoContext::tryCurrent()) {
autoResponseStatus.ongoingAutoResponse.emplace(context.addObject(kj::heap(p.addBranch())));
} else {
Expand Down Expand Up @@ -965,7 +965,7 @@ kj::Promise<void> WebSocket::pump(IoContext& context,
auto message = KJ_ASSERT_NONNULL(autoResponse.pendingAutoResponseDeque.pop());
gatedMessage.pendingAutoResponses--;
autoResponse.queuedAutoResponses--;
co_await ws.send(message);
co_await message.sendVia(ws);
}

KJ_SWITCH_ONEOF(gatedMessage.message) {
Expand Down Expand Up @@ -997,7 +997,7 @@ kj::Promise<void> WebSocket::pump(IoContext& context,
// We should also check if the last sent message was a close. Shouldn't happen.
while (!autoResponse.pendingAutoResponseDeque.empty() && !autoResponse.isClosed) {
auto message = KJ_ASSERT_NONNULL(autoResponse.pendingAutoResponseDeque.pop());
co_await ws.send(message);
co_await message.sendVia(ws);
}

// While we were `co_await`ing the auto-response send, more messages could have been queued
Expand Down
10 changes: 6 additions & 4 deletions src/workerd/api/web-socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "basics.h"
#include "events.h"
#include "web-socket-data-message.h"

#include <workerd/io/io-gate.h>
#include <workerd/io/observer.h>
Expand Down Expand Up @@ -336,7 +337,7 @@ class WebSocket: public EventTarget {
// These methods are c++ only and are not exposed to our js interface.
kj::Maybe<kj::Date> getAutoResponseTimestamp();

kj::Promise<void> sendAutoResponse(kj::String message, kj::WebSocket& ws);
kj::Promise<void> sendAutoResponse(WebSocketDataMessage message, kj::WebSocket& ws);

int getReadyState();

Expand Down Expand Up @@ -599,15 +600,16 @@ class WebSocket: public EventTarget {
using OwnedAutoResponsePromise =
kj::OneOf<IoOwn<kj::Promise<void>>, kj::Own<kj::Promise<void>>>;
kj::Maybe<OwnedAutoResponsePromise> ongoingAutoResponse;
workerd::util::Queue<kj::String> pendingAutoResponseDeque;
workerd::util::Queue<WebSocketDataMessage> pendingAutoResponseDeque;
size_t queuedAutoResponses = 0;
bool isPumping = false;
bool isClosed = false;

JSG_MEMORY_INFO(AutoResponse) {
tracker.trackFieldWithSize("ongoingAutoResponse", sizeof(kj::Promise<void>));
pendingAutoResponseDeque.forEach(
[&](const kj::String& message) { tracker.trackField(nullptr, message); });
pendingAutoResponseDeque.forEach([&](const WebSocketDataMessage& message) {
tracker.trackFieldWithSize(nullptr, message.size());
});
}
};

Expand Down
1 change: 1 addition & 0 deletions src/workerd/io/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ wd_cc_library(
"//src/workerd/api:fuzzilli",
"//src/workerd/api:hibernation-event-params",
"//src/workerd/api:url",
"//src/workerd/api:web-socket-data-message",
"//src/workerd/jsg",
"//src/workerd/jsg:inspector",
"//src/workerd/jsg:script",
Expand Down
Loading
Loading