Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions proxy/src/decompile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,80 @@ pub fn rewrite_jdt_locations(
}
rewritten
}

/// A jdt:// URI in embedded markdown/text terminates at whitespace or any of these
/// delimiters commonly used in markdown links and JSON strings. The URI itself only
/// contains URL-encoded forms of these characters, so scanning until we hit one of
/// them is safe.
fn jdt_uri_end(s: &str) -> usize {
s.find(|c: char| c.is_whitespace() || matches!(c, ')' | ']' | '"' | '>' | '`' | '\''))
.unwrap_or(s.len())
}

/// Extract all unique `jdt://` URIs appearing inside any string in `value`.
fn collect_jdt_uris(value: &Value, out: &mut Vec<String>) {
match value {
Value::String(s) => {
let mut rest = s.as_str();
while let Some(pos) = rest.find("jdt://") {
let tail = &rest[pos..];
let end = jdt_uri_end(tail);
let uri = tail[..end].to_string();
if !out.contains(&uri) {
out.push(uri);
}
rest = &tail[end..];
}
}
Value::Array(arr) => arr.iter().for_each(|v| collect_jdt_uris(v, out)),
Value::Object(obj) => obj.values().for_each(|v| collect_jdt_uris(v, out)),
_ => {}
}
}

/// Replace all occurrences of any key in `map` with its value, inside every string
/// contained in `value` (recursively).
fn replace_in_strings(value: &mut Value, map: &HashMap<String, String>) {
match value {
Value::String(s) => {
for (from, to) in map {
if s.contains(from.as_str()) {
*s = s.replace(from.as_str(), to);
}
}
}
Value::Array(arr) => arr.iter_mut().for_each(|v| replace_in_strings(v, map)),
Value::Object(obj) => obj.values_mut().for_each(|v| replace_in_strings(v, map)),
_ => {}
}
}

/// Scan a documentation response (hover, signatureHelp, completionItem/resolve, …)
/// for embedded `jdt://` URIs, resolve each one to a `file://` URI backed by a temp
/// file, and replace the URIs in-place in every string of `msg.result`.
pub fn rewrite_jdt_in_strings(
msg: &mut Value,
writer: &Arc<Mutex<impl Write>>,
pending: &Arc<Mutex<HashMap<Value, mpsc::Sender<Value>>>>,
next_id: &mut impl FnMut() -> Value,
) {
let Some(result) = msg.get_mut("result") else {
return;
};

let mut uris = Vec::new();
collect_jdt_uris(result, &mut uris);
if uris.is_empty() {
return;
}

let mut map = HashMap::new();
for uri in uris {
if let Some(file_uri) = resolve_jdt_uri(&uri, writer, pending, next_id()) {
map.insert(uri, file_uri);
}
}
if !map.is_empty() {
replace_in_strings(result, &map);
}
}
11 changes: 11 additions & 0 deletions proxy/src/lsp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,17 @@ pub fn parse_lsp_content(raw: &[u8]) -> Option<serde_json::Value> {
serde_json::from_slice(&raw[sep_pos + 4..]).ok()
}

/// Cheap check for the presence of an `"id"` key in the JSON body of a raw LSP
/// message. Used to skip full JSON parsing for notifications, which carry no
/// `id` and therefore cannot be responses or completion results.
pub fn raw_has_id(raw: &[u8]) -> bool {
let Some(sep_pos) = raw.windows(4).position(|w| w == HEADER_SEP) else {
return false;
};
let body = &raw[sep_pos + 4..];
body.windows(5).any(|w| w == b"\"id\":")
}

pub fn encode_lsp(value: &impl Serialize) -> String {
let json = serde_json::to_string(value).unwrap();
format!("{CONTENT_LENGTH}: {}\r\n\r\n{json}", json.len())
Expand Down
86 changes: 64 additions & 22 deletions proxy/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@ mod lsp;
mod platform;

use completions::{should_sort_completions, sort_completions_by_param_count};
use decompile::rewrite_jdt_locations;
use decompile::{rewrite_jdt_in_strings, rewrite_jdt_locations};
use http::handle_http;
use lsp::{parse_lsp_content, write_raw, write_to_stdout, LspReader};
use lsp::{parse_lsp_content, raw_has_id, write_raw, write_to_stdout, LspReader};
use platform::spawn_parent_monitor;
use serde_json::Value;
use std::{
collections::{HashMap, HashSet},
collections::HashMap,
env, fs,
io::{self, BufReader, Write},
net::TcpListener,
Expand All @@ -25,6 +25,12 @@ use std::{
thread,
};

#[derive(Clone, Copy)]
enum TrackedKind {
Definition,
Doc,
}

fn main() {
let args: Vec<String> = env::args().skip(1).collect();
if args.len() < 2 {
Expand Down Expand Up @@ -90,29 +96,41 @@ fn main() {

let id_counter = Arc::new(AtomicU64::new(1));

// Track definition/typeDefinition/implementation request IDs for jdt:// rewriting
let definition_ids: Arc<Mutex<HashSet<Value>>> = Arc::new(Mutex::new(HashSet::new()));
// Track definition/typeDefinition/implementation and documentation request IDs
// so their responses can be intercepted and rewritten.
let tracked_ids: Arc<Mutex<HashMap<Value, TrackedKind>>> = Arc::new(Mutex::new(HashMap::new()));

// --- Thread 1: Zed stdin -> JDTLS stdin (track definition requests) ---
let stdin_writer = Arc::clone(&child_stdin);
let alive_stdin = Arc::clone(&alive);
let def_ids_in = Arc::clone(&definition_ids);
let tracked_in = Arc::clone(&tracked_ids);
thread::spawn(move || {
let stdin = io::stdin().lock();
let mut reader = LspReader::new(stdin);
let mut reader = LspReader::new(BufReader::new(stdin));
while alive_stdin.load(Ordering::Relaxed) {
match reader.read_message() {
Ok(Some(raw)) => {
if let Some(msg) = parse_lsp_content(&raw) {
if let Some(method) = msg.get("method").and_then(|m| m.as_str()) {
if matches!(
method,
"textDocument/definition"
// Only requests (not notifications) carry an `id`; skip the
// JSON parse entirely for high-volume notifications like
// textDocument/didChange.
if raw_has_id(&raw) {
if let Some(msg) = parse_lsp_content(&raw) {
if let Some(method) = msg.get("method").and_then(|m| m.as_str()) {
let kind = match method {
"textDocument/definition"
| "textDocument/typeDefinition"
| "textDocument/implementation"
) {
if let Some(id) = msg.get("id").cloned() {
def_ids_in.lock().unwrap().insert(id);
| "textDocument/implementation" => {
Some(TrackedKind::Definition)
}
"textDocument/hover"
| "textDocument/signatureHelp"
| "completionItem/resolve" => Some(TrackedKind::Doc),
_ => None,
};
if let Some(kind) = kind {
if let Some(id) = msg.get("id").cloned() {
tracked_in.lock().unwrap().insert(id, kind);
}
}
}
}
Expand All @@ -131,7 +149,7 @@ fn main() {
// --- Thread 2: JDTLS stdout -> rewrite jdt:// URIs, modify completions -> Zed stdout / resolve pending ---
let pending_out = Arc::clone(&pending);
let alive_out = Arc::clone(&alive);
let def_ids_out = Arc::clone(&definition_ids);
let tracked_out = Arc::clone(&tracked_ids);
let decompile_writer = Arc::clone(&child_stdin);
let decompile_pending = Arc::clone(&pending);
let decompile_counter = Arc::clone(&id_counter);
Expand All @@ -141,6 +159,13 @@ fn main() {
while alive_out.load(Ordering::Relaxed) {
match reader.read_message() {
Ok(Some(raw)) => {
// Fast path: notifications (no `id`) can't be responses we
// need to intercept. Forward the raw bytes without parsing.
if !raw_has_id(&raw) {
write_raw(&mut io::stdout().lock(), &raw);
continue;
}

let Some(mut msg) = parse_lsp_content(&raw) else {
write_raw(&mut io::stdout().lock(), &raw);
continue;
Expand All @@ -154,11 +179,11 @@ fn main() {
}
}

// Rewrite jdt:// URIs in definition responses
// Spawns a thread so this loop stays unblocked and can
// route the java/classFileContents response back via `pending`.
// Rewrite jdt:// URIs in definition or documentation responses.
// Spawns a thread so this loop stays unblocked and can route
// the java/classFileContents response back via `pending`.
if let Some(id) = msg.get("id").cloned() {
if def_ids_out.lock().unwrap().remove(&id) {
if let Some(kind) = tracked_out.lock().unwrap().remove(&id) {
let writer = Arc::clone(&decompile_writer);
let pending = Arc::clone(&decompile_pending);
let pid = decompile_proxy_id.clone();
Expand All @@ -168,7 +193,24 @@ fn main() {
let seq = counter.fetch_add(1, Ordering::Relaxed);
Value::String(format!("{pid}-decompile-{seq}"))
};
rewrite_jdt_locations(&mut msg, &writer, &pending, &mut next_id);
match kind {
TrackedKind::Definition => {
rewrite_jdt_locations(
&mut msg,
&writer,
&pending,
&mut next_id,
);
}
TrackedKind::Doc => {
rewrite_jdt_in_strings(
&mut msg,
&writer,
&pending,
&mut next_id,
);
}
}
write_to_stdout(&msg);
});
continue;
Expand Down
Loading