diff --git a/.gitignore b/.gitignore index 1f9231f..c4d762b 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,8 @@ fileserver/ gradlew.bat gradle/ gradle/ +/target/ +.classpath +.factorypath +*.prefs +.project diff --git a/README.md b/README.md index cbcce1e..27f4b6a 100644 --- a/README.md +++ b/README.md @@ -94,6 +94,19 @@ There is a [simple file server](https://github.com/robaho/httpserver/blob/727759 gradle runSimpleFileServer ``` +## websockets + +For websocket usage, see the examples in the [websocket testing folder](https://github.com/robaho/httpserver/tree/main/src/test/java/robaho/net/httpserver/websockets). + +In general, create a handler that extends WebSocketHandler, and add an endpoint for the handler: + +``` + HttpHandler h = new EchoWebSocketHandler(); + HttpContext c = server.createContext("/ws", h); +``` + +The low-level websocket api is [nanohttpd](https://github.com/NanoHttpd/nanohttpd) so there are many examples on the web. + ## logging All logging is performed using the [Java System Logger](https://docs.oracle.com/en/java/javase/19/docs/api/java.base/java/lang/System.Logger.html) diff --git a/build.gradle b/build.gradle index 1f39938..e47b7dc 100644 --- a/build.gradle +++ b/build.gradle @@ -1,7 +1,8 @@ plugins { - id 'java-library' id 'maven-publish' id 'signing' + id 'java-library' + id 'tech.yanand.maven-central-publish' version '1.3.0' } repositories { @@ -29,7 +30,7 @@ tasks.withType(Test) { systemProperty("robaho.net.httpserver.http2OverNonSSL","true") // systemProperty("robaho.net.httpserver.http2MaxConcurrentStreams","5000") // systemProperty("robaho.net.httpserver.http2DisableFlushDelay","true") - systemProperty("robaho.net.httpserver.http2OverSSL","true") + // systemProperty("robaho.net.httpserver.http2OverSSL","true") systemProperty("robaho.net.httpserver.http2OverNonSSL","true") // systemProperty("javax.net.debug","ssl:handshake:verbose:keymanager:trustmanager") } @@ -38,7 +39,7 @@ tasks.withType(JavaExec) { jvmArgs += "--enable-preview" systemProperty("java.util.logging.config.file","logging.properties") systemProperty("com.sun.net.httpserver.HttpServerProvider","robaho.net.httpserver.DefaultHttpServerProvider") - systemProperty("robaho.net.httpserver.http2OverSSL","true") + // systemProperty("robaho.net.httpserver.http2OverSSL","true") systemProperty("robaho.net.httpserver.http2OverNonSSL","true") systemProperty("robaho.net.httpserver.http2InitialWindowSize","1024000") systemProperty("robaho.net.httpserver.http2ConnectionWindowSize","1024000000") @@ -72,14 +73,14 @@ sourceSets { test { java { srcDirs = [ - 'src/test/extras', - 'src/test/java', - 'src/test/java_default/bugs', - 'src/test/java_default/HttpExchange' + 'src/test/extras', + 'src/test/java', + 'src/test/java_default/bugs', + 'src/test/java_default/HttpExchange' ] } } - testMains { + create('testMains') { java { srcDirs = ['src/test/test_mains'] compileClasspath = test.output + main.output + configurations.testMainsCompile @@ -91,22 +92,15 @@ sourceSets { } } -def getGitVersion () { - def output = new ByteArrayOutputStream() - exec { - commandLine 'git', 'rev-list', '--tags', '--max-count=1' - standardOutput = output - } - def revision = output.toString().trim() - output.reset() - exec { - commandLine 'git', 'describe', '--tags', revision - standardOutput = output - } - return output.toString().trim() -} +// Use a lazy Provider to get the git version. This is the modern, configuration-cache-friendly approach. +def gitVersionProvider = project.providers.exec { + // 1. Describe the latest revision with a tag + commandLine = ['git', 'describe', '--tags', '--always'] + ignoreExitValue = true // Don't fail the build if git fails (e.g., no tags exist) +}.standardOutput.asText.map { it.trim() } -version = getGitVersion() +// Apply the git version to your project +version = gitVersionProvider.get() task showGitVersion { doLast { @@ -115,9 +109,6 @@ task showGitVersion { } build { - doFirst { - getGitVersion - } } jar { @@ -132,7 +123,7 @@ task runSingleUnitTest(type: Test) { outputs.upToDateWhen { false } dependsOn testClasses filter { - includeTestsMatching 'InputNotRead' + includeTestsMatching 'PipeliningStallTest' } useTestNG() } @@ -180,15 +171,17 @@ task runSimpleFileServer(type: JavaExec) { } dependsOn testClasses classpath sourceSets.test.runtimeClasspath - main "SimpleFileServer" - args = ['fileserver','8080','fileserver/logfile.txt'] + + // FIX 1: Use 'mainClass' instead of 'main' + // FIX 2: Replace "SimpleFileServer" with the FULLY QUALIFIED class name + // (e.g., if it's in a package named com.example) + mainClass = "com.example.SimpleFileServer" + + args = ['fileserver','443','fileserver/logfile.txt'] + javaLauncher = javaToolchains.launcherFor { languageVersion = JavaLanguageVersion.of(23) } - // debugOptions { - // enabled = true - // suspend = true - // } } task testJar(type: Jar) { @@ -254,14 +247,17 @@ publishing { } } } - repositories { - maven { - name = "OSSRH" - url = "https://s01.oss.sonatype.org/service/local/staging/deploy/maven2/" - credentials { - username = "$maven_user" - password = "$maven_password" - } - } - } +} + +mavenCentral { + def tokenString = "${maven_user}:${maven_password}" + def token = tokenString.bytes.encodeBase64().toString() + authToken = token + // Whether the upload should be automatically published or not. Use 'USER_MANAGED' if you wish to do this manually. + // This property is optional and defaults to 'AUTOMATIC'. + publishingType = 'AUTOMATIC' + // Max wait time for status API to get 'PUBLISHING' or 'PUBLISHED' status when the publishing type is 'AUTOMATIC', + // or additionally 'VALIDATED' when the publishing type is 'USER_MANAGED'. + // This property is optional and defaults to 60 seconds. + maxWait = 60 } diff --git a/src/main/java/robaho/net/httpserver/Code.java b/src/main/java/robaho/net/httpserver/Code.java index a32a5cc..47c46ec 100644 --- a/src/main/java/robaho/net/httpserver/Code.java +++ b/src/main/java/robaho/net/httpserver/Code.java @@ -28,6 +28,7 @@ public class Code { public static final int HTTP_CONTINUE = 100; + public static final int HTTP_SWITCHING_PROTOCOLS = 101; public static final int HTTP_OK = 200; public static final int HTTP_CREATED = 201; public static final int HTTP_ACCEPTED = 202; @@ -71,6 +72,8 @@ static String msg(int code) { return " OK"; case HTTP_CONTINUE: return " Continue"; + case HTTP_SWITCHING_PROTOCOLS: + return " Switching Protocols"; case HTTP_CREATED: return " Created"; case HTTP_ACCEPTED: diff --git a/src/main/java/robaho/net/httpserver/ExchangeImpl.java b/src/main/java/robaho/net/httpserver/ExchangeImpl.java index 66e366d..7da9ca0 100644 --- a/src/main/java/robaho/net/httpserver/ExchangeImpl.java +++ b/src/main/java/robaho/net/httpserver/ExchangeImpl.java @@ -40,8 +40,6 @@ import com.sun.net.httpserver.*; -import robaho.net.httpserver.websockets.WebSocketHandler; - class ExchangeImpl { Headers reqHdrs, rspHdrs; @@ -69,7 +67,8 @@ class ExchangeImpl { private static final String HEAD = "HEAD"; private static final String CONNECT = "CONNECT"; - + private static final String HEADER_CONNECTION = "Connection"; + private static final String HEADER_CONNECTION_UPGRADE = "Upgrade"; /* * streams which take care of the HTTP protocol framing * and are passed up to higher layers @@ -85,7 +84,7 @@ class ExchangeImpl { Map attributes; int rcode = -1; HttpPrincipal principal; - final boolean websocket; + boolean connectionUpgraded = false; ExchangeImpl( String m, URI u, Request req, long len, HttpConnection connection) throws IOException { @@ -97,11 +96,6 @@ class ExchangeImpl { this.method = m; this.uri = u; this.connection = connection; - this.websocket = WebSocketHandler.isWebsocketRequested(this.reqHdrs); - if (this.websocket) { - // length is indeterminate - len = -1; - } this.reqContentLen = len; /* ros only used for headers, body written directly to stream */ this.ros = req.outputStream(); @@ -135,6 +129,9 @@ private boolean isHeadRequest() { private boolean isConnectRequest() { return CONNECT.equals(getRequestMethod()); } + private boolean isUpgradeRequest() { + return HEADER_CONNECTION_UPGRADE.equalsIgnoreCase(reqHdrs.getFirst(HEADER_CONNECTION)); + } public void close() { if (closed) { @@ -170,7 +167,7 @@ public InputStream getRequestBody() { if (uis != null) { return uis; } - if (websocket || isConnectRequest()) { + if (connectionUpgraded || isConnectRequest() || isUpgradeRequest()) { // connection cannot be re-used uis = ris; } else if (reqContentLen == -1L) { @@ -232,7 +229,6 @@ public void sendResponseHeaders(int rCode, long contentLen) ros.write(statusLine.getBytes(ISO_CHARSET)); boolean noContentToSend = false; // assume there is content boolean noContentLengthHeader = false; // must not send Content-length is set - rspHdrs.set("Date", ActivityTimer.dateAndTime()); Integer bufferSize = (Integer)this.getAttribute(Attributes.SOCKET_WRITE_BUFFER); if(bufferSize!=null) { @@ -242,19 +238,21 @@ public void sendResponseHeaders(int rCode, long contentLen) boolean flush = false; /* check for response type that is not allowed to send a body */ - if (rCode == 101) { - logger.log(Level.DEBUG, () -> "switching protocols"); - - if (contentLen != 0) { - String msg = "sendResponseHeaders: rCode = " + rCode - + ": forcing contentLen = 0"; - logger.log(Level.WARNING, msg); - } - contentLen = 0; - flush = true; - - } else if ((rCode >= 100 && rCode < 200) /* informational */ - || (rCode == 204) /* no content */ + var informational = rCode >= 100 && rCode < 200; + + if (informational) { + if (rCode == 101) { + logger.log(Level.DEBUG, () -> "switching protocols"); + if (contentLen != 0) { + String msg = "sendResponseHeaders: rCode = " + rCode + + ": forcing contentLen = 0"; + logger.log(Level.WARNING, msg); + contentLen = 0; + } + connectionUpgraded = true; + } + noContentLengthHeader = true; // the Content-length header must not be set for interim responses as they cannot have a body + } else if ((rCode == 204) /* no content */ || (rCode == 304)) /* not modified */ { if (contentLen != -1) { @@ -266,6 +264,10 @@ public void sendResponseHeaders(int rCode, long contentLen) noContentLengthHeader = (rCode != 304); } + if(!informational) { + rspHdrs.set("Date", ActivityTimer.dateAndTime()); + } + if (isHeadRequest() || rCode == 304) { /* * HEAD requests or 304 responses should not set a content length by passing it @@ -278,14 +280,16 @@ public void sendResponseHeaders(int rCode, long contentLen) noContentToSend = true; contentLen = 0; o.setWrappedStream(new FixedLengthOutputStream(this, ros, contentLen)); - } else { /* not a HEAD request or 304 response */ + } else if(informational && !connectionUpgraded) { + // don't want to set the stream for 1xx responses, except 101, the handler must call sendResponseHeaders again with the final code + flush = true; + } else if(connectionUpgraded || isConnectRequest()) { + o.setWrappedStream(ros); + close = true; + flush = true; + } else { /* standard response with possible response data */ if (contentLen == 0) { - if (websocket || isConnectRequest()) { - o.setWrappedStream(ros); - close = true; - flush = true; - } - else if (http10) { + if (http10) { o.setWrappedStream(new UndefLengthOutputStream(this, ros)); close = true; } else { @@ -323,9 +327,9 @@ else if (http10) { writeHeaders(rspHdrs, ros); this.rspContentLen = contentLen; - sentHeaders = true; + sentHeaders = !informational; if(logger.isLoggable(Level.TRACE)) { - logger.log(Level.TRACE, "Sent headers: noContentToSend=" + noContentToSend); + logger.log(Level.TRACE, "sendResponseHeaders(), code="+rCode+", noContentToSend=" + noContentToSend + ", contentLen=" + contentLen); } if(flush) { ros.flush(); diff --git a/src/main/java/robaho/net/httpserver/NoSyncBufferedInputStream.java b/src/main/java/robaho/net/httpserver/NoSyncBufferedInputStream.java index 33c1da1..b99e5aa 100644 --- a/src/main/java/robaho/net/httpserver/NoSyncBufferedInputStream.java +++ b/src/main/java/robaho/net/httpserver/NoSyncBufferedInputStream.java @@ -83,7 +83,7 @@ public NoSyncBufferedInputStream(InputStream in) { private void fill() throws IOException { pos = 0; count = 0; - int n = getInIfOpen().read(buf); + int n = getInIfOpen().read(getBufIfOpen()); if (n > 0) count = n; } diff --git a/src/main/java/robaho/net/httpserver/ServerImpl.java b/src/main/java/robaho/net/httpserver/ServerImpl.java index 458e65f..214ad23 100644 --- a/src/main/java/robaho/net/httpserver/ServerImpl.java +++ b/src/main/java/robaho/net/httpserver/ServerImpl.java @@ -50,6 +50,7 @@ import java.util.concurrent.Executor; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.RejectedExecutionException; import java.util.logging.LogRecord; import javax.net.ssl.SSLSocket; @@ -342,82 +343,92 @@ public void run() { while (true) { try { Socket s = socket.accept(); - if(logger.isLoggable(Level.TRACE)) { - logger.log(Level.TRACE, "accepted connection: " + s.toString()); - } - stats.connectionCount.incrementAndGet(); - if (MAX_CONNECTIONS > 0 && allConnections.size() >= MAX_CONNECTIONS) { - // we've hit max limit of current open connections, so we go - // ahead and close this connection without processing it + try { + executor.execute(() -> { try { - stats.maxConnectionsExceededCount.incrementAndGet(); - logger.log(Level.WARNING, "closing accepted connection due to too many connections"); - s.close(); - } catch (IOException ignore) { + acceptConnection(s); + } catch (IOException t) { + logger.log(Level.WARNING, "unable to accept connection", t); + try { + s.close(); + } catch (IOException ex) { + } } - continue; + }); + } catch (RejectedExecutionException e) { + s.close(); } - - if (ServerConfig.noDelay()) { - s.setTcpNoDelay(true); + } catch (IOException e) { + if (!isFinishing()) { + logger.log(Level.ERROR, "socket accept failed", e); } + return; + } + } + } + private void acceptConnection(Socket s) throws IOException { + if(logger.isLoggable(Level.TRACE)) { + logger.log(Level.TRACE, "accepted connection: " + s.toString()); + } + stats.connectionCount.incrementAndGet(); + if (MAX_CONNECTIONS > 0 && allConnections.size() >= MAX_CONNECTIONS) { + // we've hit max limit of current open connections, so we go + // ahead and close this connection without processing it + try { + stats.maxConnectionsExceededCount.incrementAndGet(); + logger.log(Level.WARNING, "closing accepted connection due to too many connections"); + s.close(); + } catch (IOException ignore) { + } + return; + } - boolean http2 = false; + if (ServerConfig.noDelay()) { + s.setTcpNoDelay(true); + } - if (https) { - // for some reason, creating an SSLServerSocket and setting the default parameters would - // not work, so upgrade to a SSLSocket after connection - SSLSocketFactory ssf = httpsConfig.getSSLContext().getSocketFactory(); - SSLSocket sslSocket = (SSLSocket) ssf.createSocket(s, null, false); - SSLConfigurator.configure(sslSocket,httpsConfig); + boolean http2 = false; - sslSocket.setHandshakeApplicationProtocolSelector((_sslSocket, protocols) -> { - if (protocols.contains("h2") && ServerConfig.http2OverSSL()) { - return "h2"; - } else { - return "http/1.1"; - } - }); - // the following forces the SSL handshake to complete in order to determine the negotiated protocol - var session = sslSocket.getSession(); - if ("h2".equals(sslSocket.getApplicationProtocol())) { - logger.log(Level.DEBUG, () -> "http2 connection "+sslSocket.toString()); - http2 = true; - } else { - logger.log(Level.DEBUG, () -> "http/1.1 connection "+sslSocket.toString()); - } - s = sslSocket; + if (https) { + // for some reason, creating an SSLServerSocket and setting the default parameters would + // not work, so upgrade to a SSLSocket after connection + SSLSocketFactory ssf = httpsConfig.getSSLContext().getSocketFactory(); + SSLSocket sslSocket = (SSLSocket) ssf.createSocket(s, null, false); + SSLConfigurator.configure(sslSocket,httpsConfig); + + sslSocket.setHandshakeApplicationProtocolSelector((_sslSocket, protocols) -> { + if (protocols.contains("h2") && ServerConfig.http2OverSSL()) { + return "h2"; + } else { + return "http/1.1"; } + }); + // the following forces the SSL handshake to complete in order to determine the negotiated protocol + var session = sslSocket.getSession(); + if ("h2".equals(sslSocket.getApplicationProtocol())) { + logger.log(Level.DEBUG, () -> "http2 connection "+sslSocket.toString()); + http2 = true; + } else { + logger.log(Level.DEBUG, () -> "http/1.1 connection "+sslSocket.toString()); + } + s = sslSocket; + } - HttpConnection c; - try { - c = new HttpConnection(s); - } catch (IOException e) { - logger.log(Level.WARNING, "Failed to create HttpConnection", e); - continue; - } - try { - allConnections.add(c); - - if (http2) { - Http2Exchange t = new Http2Exchange(protocol, c); - executor.execute(t); - } else { - Exchange t = new Exchange(protocol, c); - executor.execute(t); - } + HttpConnection c = new HttpConnection(s); + try { + allConnections.add(c); - } catch (Exception e) { - logger.log(Level.TRACE, "Dispatcher Exception", e); - stats.handleExceptionCount.incrementAndGet(); - closeConnection(c); - } - } catch (IOException e) { - if (!isFinishing()) { - logger.log(Level.ERROR, "Dispatcher Exception, terminating", e); - } - return; + if (http2) { + Http2Exchange t = new Http2Exchange(protocol, c); + t.run(); + } else { + Exchange t = new Exchange(protocol, c); + t.run(); } + } catch (Throwable t) { + logger.log(Level.WARNING, "Dispatcher Exception", t); + stats.handleExceptionCount.incrementAndGet(); + closeConnection(c); } } } @@ -621,13 +632,13 @@ public void run() { logger.log(Level.TRACE, () -> "exchange started "+connection.toString()); - while (true) { + while (!connection.closed) { try { runPerRequest(); if (connection.closed) { break; } - } catch (SocketException e) { + } catch (IOException e) { // these are common with clients breaking connections etc logger.log(Level.TRACE, "ServerImpl IOException", e); stats.socketExceptionCount.incrementAndGet(); @@ -851,12 +862,17 @@ void sendReply( builder.append("HTTP/1.1 ") .append(code).append(Code.msg(code)).append("\r\n"); + var informational = (code >= 100 && code < 200); + if (text != null && text.length() != 0) { builder.append("Content-length: ") .append(text.length()).append("\r\n") .append("Content-type: text/html\r\n"); } else { - builder.append("Content-length: 0\r\n"); + if (!informational) { + // no body for 1xx responses + builder.append("Content-length: 0\r\n"); + } text = ""; } if (closeNow) { @@ -887,7 +903,7 @@ void logReply(int code, String requestStr, String text) { } else { r = requestStr; } - logger.log(Level.DEBUG, () -> "reply "+ r + " [" + code + " " + Code.msg(code) + "] (" + (text!=null ? text : "") + ")"); + logger.log(Level.DEBUG, () -> "reply "+ r + " [" + code + Code.msg(code) + "] (" + (text!=null ? text : "") + ")"); } void delay() { diff --git a/src/main/java/robaho/net/httpserver/extras/MultipartFormParser.java b/src/main/java/robaho/net/httpserver/extras/MultipartFormParser.java index 8d6f169..def4b3b 100644 --- a/src/main/java/robaho/net/httpserver/extras/MultipartFormParser.java +++ b/src/main/java/robaho/net/httpserver/extras/MultipartFormParser.java @@ -14,6 +14,7 @@ import java.util.LinkedList; import java.util.List; import java.util.Map; +import java.util.logging.Logger; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -23,6 +24,8 @@ * parse multipart form data */ public class MultipartFormParser { + static final Logger logger = Logger.getLogger("robaho.net.httpserver.MultipartFormParser"); + /** * a multipart part. * @@ -38,7 +41,7 @@ public record Part(String contentType, String filename, String data, File file) } - private record PartMetadata(String name, String filename) { + private record PartMetadata(String contentType, String name, String filename) { } @@ -66,7 +69,8 @@ public static Map> parse(String encoding, String content_type List headers = new LinkedList<>(); - System.out.println("reading until start of part"); + logger.finer(() -> "reading multipart form data with boundary '%s'".formatted(boundary)); + // read until boundary found int matchCount = 2; // starting at 2 allows matching non-compliant senders. rfc says CRLF is part of // boundary marker @@ -78,7 +82,6 @@ public static Map> parse(String encoding, String content_type if (c == boundaryCheck[matchCount]) { matchCount++; if (matchCount == boundaryCheck.length - 2) { - System.out.println("found boundary marker"); break; } } else { @@ -99,7 +102,7 @@ public static Map> parse(String encoding, String content_type while (true) { // read part headers until blank line - System.out.println("reading part headers"); + while (true) { s = readLine(charset, is); if (s == null) { @@ -111,7 +114,6 @@ public static Map> parse(String encoding, String content_type headers.add(s); } - System.out.println("reading part data"); // read part data - need to detect end of part PartMetadata meta = parseHeaders(headers); @@ -120,12 +122,12 @@ public static Map> parse(String encoding, String content_type if (meta.filename == null) { var bos = new ByteArrayOutputStream(); os = bos; - addToResults = () -> results.computeIfAbsent(meta.name, k -> new LinkedList()).add(new Part(null, null, bos.toString(charset), null)); + addToResults = () -> results.computeIfAbsent(meta.name, k -> new LinkedList()).add(new Part(meta.contentType, null, bos.toString(charset), null)); } else { File file = Path.of(storage.toString(), meta.filename).toFile(); file.deleteOnExit(); os = new NoSyncBufferedOutputStream(new FileOutputStream(file)); - addToResults = () -> results.computeIfAbsent(meta.name, k -> new LinkedList()).add(new Part(null, meta.filename, null, file)); + addToResults = () -> results.computeIfAbsent(meta.name, k -> new LinkedList()).add(new Part(meta.contentType, meta.filename, null, file)); } try (os) { @@ -138,7 +140,6 @@ public static Map> parse(String encoding, String content_type if (c == boundaryCheck[matchCount]) { matchCount++; if (matchCount == boundaryCheck.length) { - System.out.println("found boundary marker"); break; } } else { @@ -170,6 +171,7 @@ public static Map> parse(String encoding, String content_type private static PartMetadata parseHeaders(List headers) { String name = null; String filename = null; + String contentType = null; for (var header : headers) { String[] parts = header.split(":", 2); if ("content-disposition".equalsIgnoreCase(parts[0])) { @@ -188,9 +190,11 @@ private static PartMetadata parseHeaders(List headers) { } } + } else if ("content-type".equalsIgnoreCase(parts[0])) { + contentType = parts[1].trim(); } } - return new PartMetadata(name, filename); + return new PartMetadata(contentType, name, filename); } private static String readLine(Charset charset, InputStream is) throws IOException { diff --git a/src/test/java/InputRead100Test.java b/src/test/java/InputRead100Test.java new file mode 100644 index 0000000..34d142f --- /dev/null +++ b/src/test/java/InputRead100Test.java @@ -0,0 +1,151 @@ +/** + * @test id=default + * @bug 8349670 + * @summary Test 100 continue response handling + * @run junit/othervm InputRead100Test + */ +/** + * @test id=preferIPv6 + * @bug 8349670 + * @summary Test 100 continue response handling ipv6 + * @run junit/othervm -Djava.net.preferIPv6Addresses=true InputRead100Test + */ +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.OutputStreamWriter; +import java.io.PrintWriter; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.Socket; +import java.util.logging.Level; +import java.util.logging.Logger; + +import com.sun.net.httpserver.HttpServer; + +import org.testng.annotations.Test; + +import static java.nio.charset.StandardCharsets.*; + +public class InputRead100Test { + private static final String someContext = "/context"; + + static { + Logger.getLogger("").setLevel(Level.ALL); + Logger.getLogger("").getHandlers()[0].setLevel(Level.ALL); + } + + @Test + public static void testContinue() throws Exception { + System.out.println("testContinue()"); + InetAddress loopback = InetAddress.getLoopbackAddress(); + HttpServer server = HttpServer.create(new InetSocketAddress(loopback, 0), 0); + try { + server.createContext( + someContext, + msg -> { + System.err.println("Handling request: " + msg.getRequestURI()); + byte[] reply = "Here is my reply!".getBytes(UTF_8); + try { + msg.getRequestBody().readAllBytes(); + msg.sendResponseHeaders(200, reply.length); + msg.getResponseBody().write(reply); + msg.getResponseBody().close(); + } finally { + System.err.println("Request handled: " + msg.getRequestURI()); + } + }); + server.start(); + System.out.println("Server started at port " + server.getAddress().getPort()); + + runRawSocketHttpClient(loopback, server.getAddress().getPort(), 0); + } finally { + System.out.println("shutting server down"); + server.stop(0); + } + System.out.println("Server finished."); + } + + static void runRawSocketHttpClient(InetAddress address, int port, int contentLength) + throws Exception { + Socket socket = null; + PrintWriter writer = null; + BufferedReader reader = null; + + boolean foundContinue = false; + + final String CRLF = "\r\n"; + try { + socket = new Socket(address, port); + writer = new PrintWriter(new OutputStreamWriter(socket.getOutputStream())); + System.out.println("Client connected by socket: " + socket); + String body = "I will send all the data."; + if (contentLength <= 0) contentLength = body.getBytes(UTF_8).length; + + writer.print("GET " + someContext + "/ HTTP/1.1" + CRLF); + writer.print("User-Agent: Java/" + System.getProperty("java.version") + CRLF); + writer.print("Host: " + address.getHostName() + CRLF); + writer.print("Accept: */*" + CRLF); + writer.print("Content-Length: " + contentLength + CRLF); + writer.print("Connection: keep-alive" + CRLF); + writer.print("Expect: 100-continue" + CRLF); + writer.print(CRLF); // Important, else the server will expect that + // there's more into the request. + writer.flush(); + System.out.println("Client wrote request to socket: " + socket); + System.out.println("Client read 100 Continue response from server and headers"); + reader = new BufferedReader(new InputStreamReader(socket.getInputStream())); + String line = reader.readLine(); + for (; line != null; line = reader.readLine()) { + if (line.isEmpty()) { + break; + } + System.out.println("interim response \"" + line + "\""); + if (line.startsWith("HTTP/1.1 100")) { + foundContinue = true; + } + } + if (!foundContinue) { + throw new IOException("Did not receive 100 continue from server"); + } + writer.print(body); + writer.flush(); + System.out.println("Client wrote body to socket: " + socket); + + System.out.println("Client start reading from server:"); + line = reader.readLine(); + for (; line != null; line = reader.readLine()) { + if (line.isEmpty()) { + break; + } + System.out.println("final response \"" + line + "\""); + } + System.out.println("Client finished reading from server"); + } finally { + // give time to the server to try & drain its input stream + Thread.sleep(500); + // closes the client outputstream while the server is draining + // it + if (writer != null) { + writer.close(); + } + // give time to the server to trigger its assertion + // error before closing the connection + Thread.sleep(500); + if (reader != null) + try { + reader.close(); + } catch (IOException logOrIgnore) { + logOrIgnore.printStackTrace(); + } + if (socket != null) { + try { + socket.close(); + } catch (IOException logOrIgnore) { + logOrIgnore.printStackTrace(); + } + } + } + System.out.println("Client finished."); + } +} diff --git a/src/test/java/robaho/net/httpserver/extras/MultipartFormParserTest.java b/src/test/java/robaho/net/httpserver/extras/MultipartFormParserTest.java index 0a399e8..12b065d 100644 --- a/src/test/java/robaho/net/httpserver/extras/MultipartFormParserTest.java +++ b/src/test/java/robaho/net/httpserver/extras/MultipartFormParserTest.java @@ -52,13 +52,16 @@ public void testFiles() throws UnsupportedEncodingException, IOException { s += "111Y\r\n"; s += "111Z\rCCCC\nCCCC\r\nCCCCC@\r\n"; + Assert.assertEquals(values.get(0).contentType(),"text/plain"); Assert.assertEquals(s.getBytes("UTF-8"), Files.readAllBytes((values.get(0).file()).toPath()), "file1 failed"); + s = "\r\n"; s += "@22X"; s += "222Y\r\n"; s += "222Z\r222W\n2220\r\n666@"; + Assert.assertEquals(values.get(1).contentType(),"text/plain"); Assert.assertEquals(s.getBytes("UTF-8"), Files.readAllBytes((values.get(1).file()).toPath()), "file2 failed"); } @@ -118,6 +121,7 @@ public void testFormSample() throws IOException { Assert.assertEquals(results.size(), 1); List values = results.get("myfile"); + Assert.assertEquals(values.get(0).contentType(), "text/plain"); Assert.assertEquals(values.size(), 1); } @@ -133,6 +137,12 @@ public void testMultiFileFormSample() throws IOException { Assert.assertEquals(results.size(), 2); List values = results.get("myfile"); + Assert.assertEquals(values.get(0).contentType(), "text/plain"); + Assert.assertEquals(values.size(), 1); + + values = results.get("myfile2"); + Assert.assertEquals(values.get(0).contentType(), "image/png"); Assert.assertEquals(values.size(), 1); + } }