diff --git a/packages/myid/src/http.zig b/packages/myid/src/http.zig index 482f3ab..6feb10a 100644 --- a/packages/myid/src/http.zig +++ b/packages/myid/src/http.zig @@ -1,6 +1,8 @@ const std = @import("std"); const main = @import("main.zig"); +const Parser = @import("http/Parser.zig"); + threadlocal var read_buffer: [2 * 1024 * 1024]u8 = undefined; threadlocal var write_buffer: [2 * 1024 * 1024]u8 = undefined; @@ -87,6 +89,17 @@ fn makeResponseEmpty(options: ResponseEmptyOptions) ![]const u8 { return fbs.getWritten(); } +fn makeResponseClose(options: ResponseEmptyOptions) ![]const u8 { + var fbs = std.io.fixedBufferStream(&write_buffer); + const writer = fbs.writer(); + + try writer.print("{s}", .{options.status_text}); + try writer.print("Connection: close\r\n", .{}); + try writer.print("\r\n", .{}); + + return fbs.getWritten(); +} + fn makeResponse(options: ResponseOptions) ![]const u8 { var fbs = std.io.fixedBufferStream(&write_buffer); const writer = fbs.writer(); @@ -103,41 +116,73 @@ fn makeResponse(options: ResponseOptions) ![]const u8 { pub fn process(conn: std.net.Server.Connection) !void { defer conn.stream.close(); - var arena = std.heap.ArenaAllocator.init(main.allocator); - defer arena.deinit(); + var leftover_bytes: usize = 0; - const allocator = arena.allocator(); - _ = allocator; + while (true) { + var parser = Parser.init(); + var total_bytes_read: usize = 0; - var running = true; - while (running) { - var header_end: ?usize = null; - var request_len: usize = 0; + while (true) { + var bytes_read: usize = undefined; + var chars: []const u8 = undefined; - const request_head = blk: while (conn.stream.read(read_buffer[request_len..])) |len| { - if (len == 0) { - running = false; - return; + if (leftover_bytes > 0) { + bytes_read = leftover_bytes; + chars = read_buffer[0..leftover_bytes]; + } else { + bytes_read = try conn.stream.read(read_buffer[total_bytes_read..]); + chars = read_buffer[total_bytes_read .. total_bytes_read + bytes_read]; } - header_end = std.mem.indexOfPos(u8, &read_buffer, request_len, "\r\n\r\n"); - request_len += len; + total_bytes_read += bytes_read; - if (header_end) |end| { - break :blk read_buffer[0..end]; + const res = parser.consume(chars) catch |err| switch (err) { + error.MethodNotSupported => { + const response = try makeResponseClose(.{ .status_text = status.method_not_allowed }); + try conn.stream.writeAll(response); + return; + }, + error.HttpVersionNotSupported => { + const response = try makeResponseClose(.{ .status_text = status.http_version_not_supported }); + try conn.stream.writeAll(response); + return; + }, + error.MissingLineFeed => { + const response = try makeResponseClose(.{ .status_text = status.bad_request }); + try conn.stream.writeAll(response); + return; + }, + error.InvalidContentLength => { + const response = try makeResponseClose(.{ .status_text = status.bad_request }); + try conn.stream.writeAll(response); + return; + }, + }; + + if (total_bytes_read >= read_buffer.len and !res.done) { + if (parser.state == .body) { + const response = try makeResponseClose(.{ .status_text = status.content_too_large }); + try conn.stream.writeAll(response); + return; + } else { + const response = try makeResponseClose(.{ .status_text = status.request_header_fields_too_large }); + try conn.stream.writeAll(response); + return; + } } - } else |err| { - return err; - }; - const response = try makeResponse(.{ - .response_body = "PONG\n", - }); + if (res.done) { + leftover_bytes = bytes_read - res.consumed; + break; + } + } + + const response = try makeResponse(.{ .response_body = "PONG\n" }); try conn.stream.writeAll(response); - if (request_len > request_head.len + 4) { - @memmove(&read_buffer, read_buffer[request_head.len + 4 .. request_len]); + if (leftover_bytes > 0) { + @memmove(&read_buffer, read_buffer[total_bytes_read - leftover_bytes .. total_bytes_read]); } } } diff --git a/packages/myid/src/http/Parser.zig b/packages/myid/src/http/Parser.zig new file mode 100644 index 0000000..dc26d56 --- /dev/null +++ b/packages/myid/src/http/Parser.zig @@ -0,0 +1,297 @@ +const std = @import("std"); + +const Parser = @This(); + +const log = std.log.scoped(.http_parser); + +const Error = error{ + MethodNotSupported, + HttpVersionNotSupported, + MissingLineFeed, + InvalidContentLength, +}; + +const State = union(enum) { + init: void, + method_d: void, + method_g: void, + method_h: void, + method_p: void, + method_de: void, + method_ge: void, + method_he: void, + method_pa: void, + method_po: void, + method_pu: void, + method_del: void, + method_hea: void, + method_pat: void, + method_pos: void, + method_dele: void, + method_patc: void, + method_delet: void, + method_complete: void, + pathname: []const u8, + pathname_complete: void, + version_h: void, + version_ht: void, + version_htt: void, + version_http: void, + @"version_http/@": void, + @"version_http/1@": void, + @"version_http/1.@": void, + version_complete: void, + start_line_end: void, + header_name_start: void, + header_name: []const u8, + header_value: []const u8, + header_line_end: void, + headers_end: void, + body: []const u8, +}; + +pub const ConsumeResult = struct { + consumed: usize, + done: bool, +}; + +pub const ConsumeCharResult = enum { + not_done, + done, +}; + +const Method = enum { + DELETE, + GET, + HEAD, + PATCH, + POST, + PUT, +}; + +state: State = .init, +current_header_is_content_length: bool = false, +content_length: usize = 0, + +pub fn init() Parser { + return .{}; +} + +pub fn consume(self: *Parser, chars: []const u8) Error!ConsumeResult { + for (chars, 1..) |*c_ptr, len| { + const res = try self.consumeChar(c_ptr); + if (res == .done) return .{ + .consumed = len, + .done = true, + }; + } + + return .{ + .consumed = chars.len, + .done = false, + }; +} + +pub fn consumeChar(self: *Parser, c_ptr: *const u8) Error!ConsumeCharResult { + const c = c_ptr.*; + const c_slice = @as([*]const u8, @ptrCast(c_ptr))[0..1]; + switch (self.state) { + .init => switch (c) { + 'D' => self.state = .method_d, + 'G' => self.state = .method_g, + 'H' => self.state = .method_h, + 'P' => self.state = .method_p, + else => return error.MethodNotSupported, + }, + .method_d => switch (c) { + 'E' => self.state = .method_de, + else => return error.MethodNotSupported, + }, + .method_g => switch (c) { + 'E' => self.state = .method_ge, + else => return error.MethodNotSupported, + }, + .method_h => switch (c) { + 'E' => self.state = .method_he, + else => return error.MethodNotSupported, + }, + .method_p => switch (c) { + 'A' => self.state = .method_pa, + 'O' => self.state = .method_po, + 'U' => self.state = .method_pu, + else => return error.MethodNotSupported, + }, + .method_de => switch (c) { + 'L' => self.state = .method_del, + else => return error.MethodNotSupported, + }, + .method_ge => switch (c) { + 'T' => { + self.state = .method_complete; + log.debug("Method: GET", .{}); + }, + else => return error.MethodNotSupported, + }, + .method_he => switch (c) { + 'A' => self.state = .method_hea, + else => return error.MethodNotSupported, + }, + .method_pa => switch (c) { + 'T' => self.state = .method_pat, + else => return error.MethodNotSupported, + }, + .method_po => switch (c) { + 'S' => self.state = .method_pos, + else => return error.MethodNotSupported, + }, + .method_pu => switch (c) { + 'T' => { + self.state = .method_complete; + log.debug("Method: PUT", .{}); + }, + else => return error.MethodNotSupported, + }, + .method_del => switch (c) { + 'E' => self.state = .method_dele, + else => return error.MethodNotSupported, + }, + .method_hea => switch (c) { + 'D' => { + self.state = .method_complete; + log.debug("Method: HEAD", .{}); + }, + else => return error.MethodNotSupported, + }, + .method_pat => switch (c) { + 'C' => self.state = .method_patc, + else => return error.MethodNotSupported, + }, + .method_pos => switch (c) { + 'T' => { + self.state = .method_complete; + log.debug("Method: POST", .{}); + }, + else => return error.MethodNotSupported, + }, + .method_dele => switch (c) { + 'T' => self.state = .method_delet, + else => return error.MethodNotSupported, + }, + .method_patc => switch (c) { + 'H' => { + self.state = .method_complete; + log.debug("Method: PATCH", .{}); + }, + else => return error.MethodNotSupported, + }, + .method_delet => switch (c) { + 'E' => { + self.state = .method_complete; + log.debug("Method: DELETE", .{}); + }, + else => return error.MethodNotSupported, + }, + .method_complete => switch (c) { + ' ' => self.state = .{ .pathname = @as([*]const u8, @ptrCast(c_ptr))[1..1] }, + else => return error.MethodNotSupported, + }, + .pathname => |pathname| switch (c) { + ' ' => { + self.state = .pathname_complete; + log.debug("Pathname [{}]: {s}", .{ pathname.len, pathname }); + }, + else => self.state = .{ .pathname = pathname.ptr[0 .. pathname.len + 1] }, + }, + .pathname_complete => switch (c) { + 'H' => self.state = .version_h, + else => return error.HttpVersionNotSupported, + }, + .version_h => switch (c) { + 'T' => self.state = .version_ht, + else => return error.HttpVersionNotSupported, + }, + .version_ht => switch (c) { + 'T' => self.state = .version_htt, + else => return error.HttpVersionNotSupported, + }, + .version_htt => switch (c) { + 'P' => self.state = .version_http, + else => return error.HttpVersionNotSupported, + }, + .version_http => switch (c) { + '/' => self.state = .@"version_http/@", + else => return error.HttpVersionNotSupported, + }, + .@"version_http/@" => switch (c) { + '1' => self.state = .@"version_http/1@", + else => return error.HttpVersionNotSupported, + }, + .@"version_http/1@" => switch (c) { + '.' => self.state = .@"version_http/1.@", + else => return error.HttpVersionNotSupported, + }, + .@"version_http/1.@" => switch (c) { + '1' => { + self.state = .version_complete; + log.debug("Version: HTTP/1.1", .{}); + }, + else => return error.HttpVersionNotSupported, + }, + .version_complete => switch (c) { + '\r' => self.state = .start_line_end, + else => return error.HttpVersionNotSupported, + }, + .start_line_end => switch (c) { + '\n' => self.state = .header_name_start, + else => return error.MissingLineFeed, + }, + .header_name_start => switch (c) { + '\r' => self.state = .headers_end, + else => self.state = .{ .header_name = c_slice }, + }, + .header_name => |name| switch (c) { + ':' => { + self.state = .{ .header_value = @as([*]const u8, @ptrCast(c_ptr))[1..1] }; + self.current_header_is_content_length = std.ascii.eqlIgnoreCase(name, "Content-Length"); + log.debug("Header name [{}]: {s}", .{ name.len, name }); + }, + else => self.state = .{ .header_name = name.ptr[0 .. name.len + 1] }, + }, + .header_value => |value| switch (c) { + '\r' => { + self.state = .header_line_end; + const value_trimmed = std.mem.trim(u8, value, " \t"); + log.debug("Header value [{}]: {s}", .{ value_trimmed.len, value_trimmed }); + if (self.current_header_is_content_length) { + self.content_length = std.fmt.parseInt(usize, value_trimmed, 10) catch return error.InvalidContentLength; + self.current_header_is_content_length = false; + } + }, + else => self.state = .{ .header_value = value.ptr[0 .. value.len + 1] }, + }, + .header_line_end => switch (c) { + '\n' => self.state = .header_name_start, + else => return error.MissingLineFeed, + }, + .headers_end => switch (c) { + '\n' => { + if (self.content_length == 0) { + log.debug("End of request (no body)", .{}); + return .done; + } + self.state = .{ .body = @as([*]const u8, @ptrCast(c_ptr))[1..1] }; + }, + else => return error.MissingLineFeed, + }, + .body => |body| { + const new_body = body.ptr[0 .. body.len + 1]; + self.state = .{ .body = new_body }; + if (new_body.len >= self.content_length) { + log.debug("End of request ({} body bytes)", .{new_body.len}); + return .done; + } + }, + } + + return .not_done; +}