Skip to content

Commit 3d5d1a8

Browse files
authored
Merge branch 'coder:main' into fix/listed-as-buffer
2 parents 7280ed4 + aa9a5ce commit 3d5d1a8

11 files changed

Lines changed: 376 additions & 53 deletions

File tree

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ For deep technical details, see [ARCHITECTURE.md](./ARCHITECTURE.md).
280280
auto_close_on_accept = true,
281281
vertical_split = true,
282282
open_in_current_tab = true,
283-
keep_terminal_focus = false, -- If true, moves focus back to terminal after diff opens
283+
keep_terminal_focus = false, -- If true, moves focus back to terminal after diff opens (including floating terminals)
284284
},
285285
},
286286
keys = {

lua/claudecode/config.lua

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ M.defaults = {
2323
diff_opts = {
2424
layout = "vertical",
2525
open_in_new_tab = false, -- Open diff in a new tab (false = use current tab)
26-
keep_terminal_focus = false, -- If true, moves focus back to terminal after diff opens
26+
keep_terminal_focus = false, -- If true, moves focus back to terminal after diff opens (including floating terminals)
2727
hide_terminal_in_new_tab = false, -- If true and opening in a new tab, do not show Claude terminal there
2828
on_new_file_reject = "keep_empty", -- "keep_empty" leaves an empty buffer; "close_window" closes the placeholder split
2929
},

lua/claudecode/diff.lua

Lines changed: 45 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -101,17 +101,24 @@ local function find_claudecode_terminal_window()
101101
return nil
102102
end
103103

104-
-- Find the window containing this buffer
104+
-- Find the window containing this buffer.
105+
-- Prefer a normal split window, but fall back to a floating terminal window (e.g. Snacks position="float").
106+
local floating_fallback = nil
107+
105108
for _, win in ipairs(vim.api.nvim_list_wins()) do
106109
if vim.api.nvim_win_get_buf(win) == terminal_bufnr then
107110
local win_config = vim.api.nvim_win_get_config(win)
108-
if not (win_config.relative and win_config.relative ~= "") then
111+
local is_floating = win_config.relative and win_config.relative ~= ""
112+
113+
if is_floating then
114+
floating_fallback = floating_fallback or win
115+
else
109116
return win
110117
end
111118
end
112119
end
113120

114-
return nil
121+
return floating_fallback
115122
end
116123

117124
---Create a split based on configured layout
@@ -619,11 +626,17 @@ local function setup_new_buffer(
619626
term_tab = vim.api.nvim_win_get_tabpage(terminal_win)
620627
end)
621628
if term_tab == current_tab then
622-
local terminal_config = config.terminal or {}
623-
local split_width = terminal_config.split_width_percentage or 0.30
624-
local total_width = vim.o.columns
625-
local terminal_width = math.floor(total_width * split_width)
626-
pcall(vim.api.nvim_win_set_width, terminal_win, terminal_width)
629+
local win_config = vim.api.nvim_win_get_config(terminal_win)
630+
local is_floating = win_config.relative and win_config.relative ~= ""
631+
632+
-- Only resize split terminals. Floating terminals control their own sizing.
633+
if not is_floating then
634+
local terminal_config = config.terminal or {}
635+
local split_width = terminal_config.split_width_percentage or 0.30
636+
local total_width = vim.o.columns
637+
local terminal_width = math.floor(total_width * split_width)
638+
pcall(vim.api.nvim_win_set_width, terminal_win, terminal_width)
639+
end
627640
end
628641
end
629642
end
@@ -1015,14 +1028,20 @@ function M._cleanup_diff_state(tab_name, reason)
10151028
local terminal_ok, terminal_module = pcall(require, "claudecode.terminal")
10161029
if terminal_ok and diff_data.had_terminal_in_original then
10171030
pcall(terminal_module.ensure_visible)
1018-
-- And restore its configured width if it is visible
1031+
-- And restore its configured width if it is visible.
1032+
-- (We intentionally do not resize floating terminals.)
10191033
local terminal_win = find_claudecode_terminal_window()
10201034
if terminal_win and vim.api.nvim_win_is_valid(terminal_win) then
1021-
local terminal_config = config.terminal or {}
1022-
local split_width = terminal_config.split_width_percentage or 0.30
1023-
local total_width = vim.o.columns
1024-
local terminal_width = math.floor(total_width * split_width)
1025-
pcall(vim.api.nvim_win_set_width, terminal_win, terminal_width)
1035+
local win_config = vim.api.nvim_win_get_config(terminal_win)
1036+
local is_floating = win_config.relative and win_config.relative ~= ""
1037+
1038+
if not is_floating then
1039+
local terminal_config = config.terminal or {}
1040+
local split_width = terminal_config.split_width_percentage or 0.30
1041+
local total_width = vim.o.columns
1042+
local terminal_width = math.floor(total_width * split_width)
1043+
pcall(vim.api.nvim_win_set_width, terminal_win, terminal_width)
1044+
end
10261045
end
10271046
end
10281047
else
@@ -1038,14 +1057,20 @@ function M._cleanup_diff_state(tab_name, reason)
10381057
end)
10391058
end
10401059

1041-
-- After closing the diff in the same tab, restore terminal width if visible
1060+
-- After closing the diff in the same tab, restore terminal width if visible.
1061+
-- (We intentionally do not resize floating terminals.)
10421062
local terminal_win = find_claudecode_terminal_window()
10431063
if terminal_win and vim.api.nvim_win_is_valid(terminal_win) then
1044-
local terminal_config = config.terminal or {}
1045-
local split_width = terminal_config.split_width_percentage or 0.30
1046-
local total_width = vim.o.columns
1047-
local terminal_width = math.floor(total_width * split_width)
1048-
pcall(vim.api.nvim_win_set_width, terminal_win, terminal_width)
1064+
local win_config = vim.api.nvim_win_get_config(terminal_win)
1065+
local is_floating = win_config.relative and win_config.relative ~= ""
1066+
1067+
if not is_floating then
1068+
local terminal_config = config.terminal or {}
1069+
local split_width = terminal_config.split_width_percentage or 0.30
1070+
local total_width = vim.o.columns
1071+
local terminal_width = math.floor(total_width * split_width)
1072+
pcall(vim.api.nvim_win_set_width, terminal_win, terminal_width)
1073+
end
10491074
end
10501075
end
10511076

lua/claudecode/server/init.lua

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,12 @@ local M = {}
1212
---@field server table|nil The TCP server instance
1313
---@field port number|nil The port server is running on
1414
---@field auth_token string|nil The authentication token for validating connections
15-
---@field clients table<string, WebSocketClient> A list of connected clients
1615
---@field handlers table Message handlers by method name
1716
---@field ping_timer table|nil Timer for sending pings
1817
M.state = {
1918
server = nil,
2019
port = nil,
2120
auth_token = nil,
22-
clients = {},
2321
handlers = {},
2422
ping_timer = nil,
2523
}
@@ -53,8 +51,6 @@ function M.start(config, auth_token)
5351
M._handle_message(client, message)
5452
end,
5553
on_connect = function(client)
56-
M.state.clients[client.id] = client
57-
5854
-- Log connection with auth status
5955
if M.state.auth_token then
6056
logger.debug("server", "Authenticated WebSocket client connected:", client.id)
@@ -71,7 +67,6 @@ function M.start(config, auth_token)
7167
end
7268
end,
7369
on_disconnect = function(client, code, reason)
74-
M.state.clients[client.id] = nil
7570
logger.debug(
7671
"server",
7772
"WebSocket client disconnected:",
@@ -124,8 +119,6 @@ function M.stop()
124119
M.state.server = nil
125120
M.state.port = nil
126121
M.state.auth_token = nil
127-
M.state.clients = {}
128-
129122
return true
130123
end
131124

@@ -213,8 +206,6 @@ end
213206
local module_instance_id = math.random(10000, 99999)
214207
logger.debug("server", "Server module loaded with instance ID:", module_instance_id)
215208

216-
-- Note: debug_deferred_table function removed as deferred_responses table is no longer used
217-
218209
function M._setup_deferred_response(deferred_info)
219210
local co = deferred_info.coroutine
220211

lua/claudecode/server/mock.lua

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ local tools = require("claudecode.tools.init")
1212
M.state = {
1313
server = nil,
1414
port = nil,
15-
clients = {},
1615
handlers = {},
1716
messages = {}, -- Store messages for testing
1817
}
@@ -74,7 +73,6 @@ function M.stop()
7473
-- Reset state
7574
M.state.server = nil
7675
M.state.port = nil
77-
M.state.clients = {}
7876
M.state.messages = {}
7977

8078
return true
@@ -101,29 +99,36 @@ end
10199
---@param client_id string A unique client identifier
102100
---@return table client The client object
103101
function M.add_client(client_id)
102+
assert(type(client_id) == "string", "Expected client_id to be a string")
104103
if not M.state.server then
105104
error("Server not running")
106105
end
106+
assert(type(M.state.server.clients) == "table", "Expected mock server.clients to be a table")
107107

108108
local client = {
109109
id = client_id,
110110
connected = true,
111111
messages = {},
112112
}
113113

114-
M.state.clients[client_id] = client
114+
M.state.server.clients[client_id] = client
115115
return client
116116
end
117117

118118
---Remove a client from the server
119119
---@param client_id string The client identifier
120120
---@return boolean success Whether removal was successful
121121
function M.remove_client(client_id)
122-
if not M.state.server or not M.state.clients[client_id] then
122+
assert(type(client_id) == "string", "Expected client_id to be a string")
123+
if not M.state.server or type(M.state.server.clients) ~= "table" then
123124
return false
124125
end
125126

126-
M.state.clients[client_id] = nil
127+
if not M.state.server.clients[client_id] then
128+
return false
129+
end
130+
131+
M.state.server.clients[client_id] = nil
127132
return true
128133
end
129134

@@ -136,7 +141,10 @@ function M.send(client, method, params)
136141
local client_obj
137142

138143
if type(client) == "string" then
139-
client_obj = M.state.clients[client]
144+
if not M.state.server or type(M.state.server.clients) ~= "table" then
145+
return false
146+
end
147+
client_obj = M.state.server.clients[client]
140148
else
141149
client_obj = client
142150
end
@@ -172,7 +180,10 @@ function M.send_response(client, id, result, error)
172180
local client_obj
173181

174182
if type(client) == "string" then
175-
client_obj = M.state.clients[client]
183+
if not M.state.server or type(M.state.server.clients) ~= "table" then
184+
return false
185+
end
186+
client_obj = M.state.server.clients[client]
176187
else
177188
client_obj = client
178189
end
@@ -208,9 +219,13 @@ end
208219
---@param params table The parameters to send
209220
---@return boolean success Whether broadcasting was successful
210221
function M.broadcast(method, params)
222+
if not M.state.server or type(M.state.server.clients) ~= "table" then
223+
return false
224+
end
225+
211226
local success = true
212227

213-
for client_id, _ in pairs(M.state.clients) do
228+
for client_id, _ in pairs(M.state.server.clients) do
214229
local send_success = M.send(client_id, method, params)
215230
success = success and send_success
216231
end
@@ -223,7 +238,12 @@ end
223238
---@param message table The message to process
224239
---@return table|nil response The response if any
225240
function M.simulate_message(client_id, message)
226-
local client = M.state.clients[client_id]
241+
assert(type(client_id) == "string", "Expected client_id to be a string")
242+
if not M.state.server or type(M.state.server.clients) ~= "table" then
243+
return nil
244+
end
245+
246+
local client = M.state.server.clients[client_id]
227247

228248
if not client then
229249
return nil
@@ -255,7 +275,11 @@ end
255275
function M.clear_messages()
256276
M.state.messages = {}
257277

258-
for _, client in pairs(M.state.clients) do
278+
if not M.state.server or type(M.state.server.clients) ~= "table" then
279+
return
280+
end
281+
282+
for _, client in pairs(M.state.server.clients) do
259283
client.messages = {}
260284
end
261285
end

lua/claudecode/server/tcp.lua

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -124,33 +124,68 @@ function M._handle_new_connection(server)
124124
-- Set up data handler
125125
client_tcp:read_start(function(err, data)
126126
if err then
127-
server.on_error("Client read error: " .. err)
128-
M._remove_client(server, client)
127+
local error_msg = "Client read error: " .. err
128+
server.on_error(error_msg)
129+
M._disconnect_client(server, client, 1006, error_msg)
129130
return
130131
end
131132

132133
if not data then
133134
-- EOF - client disconnected
134-
M._remove_client(server, client)
135+
M._disconnect_client(server, client, 1006, "EOF")
135136
return
136137
end
137138

138139
-- Process incoming data
139140
client_manager.process_data(client, data, function(cl, message)
140141
server.on_message(cl, message)
141142
end, function(cl, code, reason)
142-
server.on_disconnect(cl, code, reason)
143-
M._remove_client(server, cl)
143+
M._disconnect_client(server, cl, code, reason)
144144
end, function(cl, error_msg)
145145
server.on_error("Client " .. cl.id .. " error: " .. error_msg)
146-
M._remove_client(server, cl)
146+
M._disconnect_client(server, cl, 1006, "Client error: " .. error_msg)
147147
end, server.auth_token)
148148
end)
149149

150150
-- Notify about new connection
151151
server.on_connect(client)
152152
end
153153

154+
---Disconnect a client and remove it from the server.
155+
---This ensures `server.on_disconnect` is invoked for every disconnect path
156+
---(EOF, read errors, protocol errors, timeouts), and only once per client.
157+
---@param server TCPServer The server object
158+
---@param client WebSocketClient The client to disconnect
159+
---@param code number|nil WebSocket close code
160+
---@param reason string|nil WebSocket close reason
161+
function M._disconnect_client(server, client, code, reason)
162+
assert(type(server) == "table", "Expected server to be a table")
163+
local on_disconnect_type = type(server.on_disconnect)
164+
local on_disconnect_mt = on_disconnect_type == "table" and getmetatable(server.on_disconnect) or nil
165+
assert(
166+
on_disconnect_type == "function" or (on_disconnect_mt ~= nil and type(on_disconnect_mt.__call) == "function"),
167+
"Expected server.on_disconnect to be callable"
168+
)
169+
assert(type(server.clients) == "table", "Expected server.clients to be a table")
170+
assert(type(client) == "table", "Expected client to be a table")
171+
assert(type(client.id) == "string", "Expected client.id to be a string")
172+
if code ~= nil then
173+
assert(type(code) == "number", "Expected code to be a number")
174+
end
175+
if reason ~= nil then
176+
assert(type(reason) == "string", "Expected reason to be a string")
177+
end
178+
179+
-- Idempotency: a client can hit multiple disconnect paths (e.g. CLOSE frame
180+
-- followed by a TCP EOF). Only notify/remove once.
181+
if not server.clients[client.id] then
182+
return
183+
end
184+
185+
server.on_disconnect(client, code, reason)
186+
M._remove_client(server, client)
187+
end
188+
154189
---Remove a client from the server
155190
---@param server TCPServer The server object
156191
---@param client WebSocketClient The client to remove
@@ -293,7 +328,7 @@ function M.start_ping_timer(server, interval)
293328
string.format("Client %s keepalive timeout (%ds idle), closing connection", client.id, time_since_pong)
294329
)
295330
client_manager.close_client(client, 1006, "Connection timeout")
296-
M._remove_client(server, client)
331+
M._disconnect_client(server, client, 1006, "Connection timeout")
297332
end
298333
end
299334
end

tests/mocks/vim.lua

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -881,6 +881,7 @@ local vim = {
881881
return true
882882
end,
883883
read_start = function(self, callback)
884+
self._read_cb = callback
884885
return true
885886
end,
886887
write = function(self, data, callback)

0 commit comments

Comments
 (0)