Multithreaded Server in Java

In Java, listening for TCP connections is simple with ServerSocket. The example below binds our program to port 5656:

ServerSocket serverSocket = new ServerSocket(5656);

To access the client’s input stream, we first need to accept the connection and obtain its Socket object. This is done with the accept method, which blocks until a client connects:

ServerSocket serverSocket = new ServerSocket(5656);
Socket socket = serverSocket.accept();
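Here is a minimal runnable sketch of these two steps. Two assumptions keep it self-contained: it binds to port 0 so the OS picks a free port (the article uses 5656), and a demo-only client thread connects and sends one line. The call to accept() blocks until that client connects and then returns the server-side Socket, which we read from.

```java
import static java.nio.charset.StandardCharsets.UTF_8;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.io.UncheckedIOException;
import java.net.InetAddress;
import java.net.ServerSocket;
import java.net.Socket;

class AcceptOnceDemo {

  // Binds, accepts exactly one connection and returns the first line that client sends
  static String acceptOnce(String message) {
    // Port 0 lets the OS pick a free port (the article uses 5656)
    try (ServerSocket serverSocket = new ServerSocket(0)) {
      int port = serverSocket.getLocalPort();

      // Demo-only client on a background thread: connects and sends one line
      Thread client = new Thread(() -> {
        try (Socket s = new Socket(InetAddress.getLoopbackAddress(), port);
             PrintWriter out = new PrintWriter(new OutputStreamWriter(s.getOutputStream(), UTF_8), true)) {
          out.println(message);
        } catch (IOException e) {
          throw new UncheckedIOException(e);
        }
      });
      client.start();

      // accept() blocks until the client connects, then returns the server-side Socket
      try (Socket socket = serverSocket.accept();
           BufferedReader reader = new BufferedReader(
               new InputStreamReader(socket.getInputStream(), UTF_8))) {
        return reader.readLine();
      }
    } catch (IOException e) {
      throw new UncheckedIOException(e);
    }
  }

  public static void main(String[] args) {
    System.out.println(acceptOnce("hello")); // prints hello
  }
}
```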

The program above accepts only one client connection. We want to accept any number of connections, so we wrap accept in an infinite loop:

ServerSocket serverSocket = new ServerSocket(5656);
while(true) {
  Socket socket = serverSocket.accept();
  // process client connection..
}
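To see this loop in action, here is a sketch that does the work inline, right where the "process client connection" comment sits. Assumptions for the demo: the loop stops after a fixed number of clients instead of running forever, each client sends a single line that is echoed back, and the send method is a hypothetical test client. Because the echo runs on the accepting thread, the next accept() only happens after the current client has been served.

```java
import static java.nio.charset.StandardCharsets.UTF_8;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.io.UncheckedIOException;
import java.net.InetAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.util.ArrayList;
import java.util.List;

class SequentialEchoDemo {

  // The accept loop with the work done inline on the accepting thread.
  // Assumption: it stops after maxClients connections so the demo can finish.
  static void serve(ServerSocket serverSocket, int maxClients) throws IOException {
    for (int i = 0; i < maxClients; i++) {
      try (Socket socket = serverSocket.accept();
           BufferedReader reader = new BufferedReader(
               new InputStreamReader(socket.getInputStream(), UTF_8));
           PrintWriter writer = new PrintWriter(
               new OutputStreamWriter(socket.getOutputStream(), UTF_8), true)) {
        // "process client connection": echo one line, then move on to the next client
        writer.println("Echo: " + reader.readLine());
      }
    }
  }

  // Hypothetical client: connects, sends one line and returns the reply
  static String send(int port, String message) throws IOException {
    try (Socket socket = new Socket(InetAddress.getLoopbackAddress(), port);
         BufferedReader reader = new BufferedReader(
             new InputStreamReader(socket.getInputStream(), UTF_8));
         PrintWriter writer = new PrintWriter(
             new OutputStreamWriter(socket.getOutputStream(), UTF_8), true)) {
      writer.println(message);
      return reader.readLine();
    }
  }

  // Runs the server on a background thread and connects three clients, one at a time
  static List<String> demo() {
    try (ServerSocket serverSocket = new ServerSocket(0)) {
      Thread server = new Thread(() -> {
        try {
          serve(serverSocket, 3);
        } catch (IOException e) {
          throw new UncheckedIOException(e);
        }
      });
      server.start();

      List<String> replies = new ArrayList<>();
      for (String message : List.of("a", "b", "c")) {
        replies.add(send(serverSocket.getLocalPort(), message));
      }
      server.join();
      return replies;
    } catch (IOException e) {
      throw new UncheckedIOException(e);
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
      throw new IllegalStateException(e);
    }
  }

  public static void main(String[] args) {
    System.out.println(demo()); // prints [Echo: a, Echo: b, Echo: c]
  }
}
```

This works here only because each demo client sends its line immediately and then disconnects; a client that lingers would hold up everyone behind it.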

However, once a client is connected, our code has to wait until that client disconnects before it can accept the next one. This might be acceptable if each connection does very little work, but if connections are long-lived (e.g., a chat program) it won't do. To solve this we can use an ExecutorService, which runs submitted work on a separate thread and lets the calling thread carry on. Here we use a fixed thread pool, which limits how many connections are serviced at the same time; connections beyond that limit wait in the executor's queue until a thread becomes free. The ConnectionHandler class that processes each connection is shown further below.

ExecutorService executorService = Executors.newFixedThreadPool(100);
ServerSocket serverSocket = new ServerSocket(5656);
while(true) {
  Socket socket = serverSocket.accept();
  executorService.execute(new ConnectionHandler(socket));
  // go back to start of infinite loop and listen for next incoming connection
}
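The claim that a fixed pool caps how many connections are serviced at once can be checked without any sockets. In this sketch (the class name and the sleep standing in for a long-lived connection are illustrative), five slow jobs are submitted to a pool of two threads, and the code records the highest number of jobs running at the same moment. The other jobs sit in the executor's queue, just as extra connections would.

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

class PoolLimitDemo {

  // Submits `tasks` slow jobs to a fixed pool and returns the most that ran at once
  static int peakConcurrency(int poolSize, int tasks) {
    ExecutorService executorService = Executors.newFixedThreadPool(poolSize);
    AtomicInteger running = new AtomicInteger();
    AtomicInteger peak = new AtomicInteger();

    for (int i = 0; i < tasks; i++) {
      // execute() returns immediately; jobs beyond poolSize wait in the queue
      executorService.execute(() -> {
        int now = running.incrementAndGet();
        peak.accumulateAndGet(now, Math::max); // remember the highest count seen
        try {
          Thread.sleep(100); // stand-in for a long-lived connection
        } catch (InterruptedException e) {
          Thread.currentThread().interrupt();
        } finally {
          running.decrementAndGet();
        }
      });
    }

    executorService.shutdown(); // accept no new jobs; queued ones still run
    try {
      executorService.awaitTermination(10, TimeUnit.SECONDS);
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
    }
    return peak.get();
  }

  public static void main(String[] args) {
    System.out.println(peakConcurrency(2, 5)); // never more than 2
  }
}
```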

The ConnectionHandler class used above implements Runnable and does the work of reading from and writing to the client. Here’s an example of a simple ConnectionHandler that just echoes each message sent by the client back to it:

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.net.Socket;

public class ConnectionHandler implements Runnable {

  private final Socket socket;

  public ConnectionHandler(Socket socket) {
    this.socket = socket;
  }

  @Override
  public void run() {
    // try-with-resources closes the writer, reader and socket (in that order),
    // even if an exception is thrown
    try (Socket s = socket;
         BufferedReader reader = new BufferedReader(new InputStreamReader(s.getInputStream()));
         PrintWriter writer = new PrintWriter(s.getOutputStream(), true)) {

      // The read loop. Code only exits this loop if the connection is lost
      // or the client disconnects (readLine() then returns null)
      String line;
      while ((line = reader.readLine()) != null) {
        writer.println("Echo: " + line);
      }
    } catch (IOException e) {
      throw new RuntimeException(e);
    }
  }

}
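With a handler in place, we can check that the thread pool really lets several clients be served at once. The sketch below is self-contained, so it restates the echo loop as a static echo method rather than reusing ConnectionHandler; the port (0, so the OS picks a free one), the two demo clients and the roundTrip helper are assumptions for the demo. The first client connects and stays silent, yet the second still gets its reply, which the single-threaded loop from earlier could not manage because it would be stuck reading from the first client.

```java
import static java.nio.charset.StandardCharsets.UTF_8;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.io.UncheckedIOException;
import java.net.InetAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

class ConcurrentEchoDemo {

  // Same echo loop as ConnectionHandler.run(), written as a static method
  static void echo(Socket socket) {
    try (Socket s = socket;
         BufferedReader reader = new BufferedReader(new InputStreamReader(s.getInputStream(), UTF_8));
         PrintWriter writer = new PrintWriter(new OutputStreamWriter(s.getOutputStream(), UTF_8), true)) {
      String line;
      while ((line = reader.readLine()) != null) {
        writer.println("Echo: " + line);
      }
    } catch (IOException e) {
      // connection dropped; try-with-resources has already closed everything
    }
  }

  // Hypothetical client helper: sends one line on an open socket and reads the reply.
  // The streams are left open because closing them would close the caller's socket.
  static String roundTrip(Socket socket, String message) throws IOException {
    PrintWriter writer = new PrintWriter(new OutputStreamWriter(socket.getOutputStream(), UTF_8), true);
    BufferedReader reader = new BufferedReader(new InputStreamReader(socket.getInputStream(), UTF_8));
    writer.println(message);
    return reader.readLine();
  }

  static List<String> demo() {
    ExecutorService executorService = Executors.newFixedThreadPool(100);
    try (ServerSocket serverSocket = new ServerSocket(0)) {
      // The accept loop from the article, running on its own thread
      Thread acceptor = new Thread(() -> {
        try {
          while (true) {
            Socket socket = serverSocket.accept();
            executorService.execute(() -> echo(socket));
          }
        } catch (IOException e) {
          // accept() throws once serverSocket is closed, ending the loop
        }
      });
      acceptor.start();

      InetAddress host = InetAddress.getLoopbackAddress();
      int port = serverSocket.getLocalPort();
      try (Socket first = new Socket(host, port);
           Socket second = new Socket(host, port)) {
        List<String> replies = new ArrayList<>();
        // `first` stays connected and silent, yet `second` is answered straight away
        replies.add(roundTrip(second, "from second"));
        replies.add(roundTrip(first, "from first"));
        return replies;
      }
    } catch (IOException e) {
      throw new UncheckedIOException(e);
    } finally {
      executorService.shutdownNow();
    }
  }

  public static void main(String[] args) {
    System.out.println(demo()); // prints [Echo: from second, Echo: from first]
  }
}
```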

Keep in mind that the run method executes on a pool thread, not the main thread, so make sure you synchronize reads and writes to any shared resources.
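As an illustration, suppose every handler bumps a shared count of echoed messages (a hypothetical addition; the handler above shares no state). This sketch shows two standard ways to keep such a counter correct when many pool threads update it: an AtomicInteger, and a plain int guarded by synchronized methods. The Stats class and the simulated handlers are made up for the demo.

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

class SharedCounterDemo {

  // Hypothetical shared statistics object that every handler thread updates
  static class Stats {
    // Option 1: AtomicInteger makes each increment a single atomic operation
    final AtomicInteger echoed = new AtomicInteger();

    // Option 2: a plain int that is only touched inside synchronized methods
    private int echoedLocked = 0;

    synchronized void incrementLocked() { echoedLocked++; }
    synchronized int getLocked() { return echoedLocked; }
  }

  // Simulates `handlers` pool threads that each echo `perHandler` messages
  static int[] simulate(int handlers, int perHandler) {
    Stats stats = new Stats();
    ExecutorService executorService = Executors.newFixedThreadPool(handlers);
    for (int h = 0; h < handlers; h++) {
      executorService.execute(() -> {
        for (int i = 0; i < perHandler; i++) {
          // in a real handler this would follow writer.println(...)
          stats.echoed.incrementAndGet();
          stats.incrementLocked();
        }
      });
    }
    executorService.shutdown();
    try {
      executorService.awaitTermination(30, TimeUnit.SECONDS);
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
    }
    return new int[] { stats.echoed.get(), stats.getLocked() };
  }

  public static void main(String[] args) {
    int[] counts = simulate(8, 10_000);
    System.out.println(counts[0] + " " + counts[1]); // prints 80000 80000
  }
}
```

With an unsynchronized int field and a bare count++, the same run could report fewer than 80000, because concurrent increments can overwrite one another.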

Socket Timeout

So far so good: we can handle multiple connections simultaneously in separate threads. But our code is still vulnerable. Recall that our thread pool is capped at 100 threads, so at most 100 connections can be serviced at once; any further connections are accepted but sit in the executor's queue until a thread frees up. This leaves us open to a denial-of-service attack: an attacker could simply open 100 connections and leave them idle, denying service to other clients. One way to mitigate this is to set a socket timeout:

ExecutorService executorService = Executors.newFixedThreadPool(100);
ServerSocket serverSocket = new ServerSocket(5656);
while(true) {
  Socket socket = serverSocket.accept();
  socket.setSoTimeout(3000); // reads on this socket's input stream time out if no data arrives within 3 seconds
  executorService.execute(new ConnectionHandler(socket));
  // go back to start of infinite loop and listen for next incoming connection
}

With the timeout set, the BufferedReader.readLine() call in ConnectionHandler throws a SocketTimeoutException if no data arrives for 3 seconds. We can catch it and let the connection close, which frees the pool thread for the next client:

// Also requires: import java.net.SocketTimeoutException;
@Override
public void run() {
  try (Socket s = socket;
       BufferedReader reader = new BufferedReader(new InputStreamReader(s.getInputStream()));
       PrintWriter writer = new PrintWriter(s.getOutputStream(), true)) {

    String line;
    // SocketTimeoutException is thrown here if nothing is read within 3 seconds
    while ((line = reader.readLine()) != null) {
      writer.println("Echo: " + line);
    }
  } catch (SocketTimeoutException e) {
    // The resources (and the socket) are already closed by the time this runs
    System.out.println("Connection timed out");
  } catch (IOException e) {
    throw new RuntimeException(e);
  }
}
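To see the timeout fire without waiting on a real idle client, this sketch connects a demo-only client that never sends anything and uses a shorter 200 ms timeout (an assumption for the demo; the article uses 3000 ms). readLine() throws SocketTimeoutException once no data has arrived within the timeout, and the method reports whether that happened.

```java
import static java.nio.charset.StandardCharsets.UTF_8;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.UncheckedIOException;
import java.net.InetAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketTimeoutException;

class TimeoutDemo {

  // Returns true if readLine() times out on a connection whose client never sends anything
  static boolean silentClientTimesOut(int timeoutMillis) {
    try (ServerSocket serverSocket = new ServerSocket(0);
         // Demo-only idle client: connects but never writes
         Socket idleClient = new Socket(InetAddress.getLoopbackAddress(), serverSocket.getLocalPort());
         // The connection is already queued, so accept() returns without waiting
         Socket socket = serverSocket.accept()) {
      socket.setSoTimeout(timeoutMillis); // each read now waits at most timeoutMillis
      BufferedReader reader = new BufferedReader(new InputStreamReader(socket.getInputStream(), UTF_8));
      try {
        reader.readLine(); // blocks until a line arrives, the client disconnects, or the timeout fires
        return false;
      } catch (SocketTimeoutException e) {
        return true; // the handler would now give up on this idle client
      }
    } catch (IOException e) {
      throw new UncheckedIOException(e);
    }
  }

  public static void main(String[] args) {
    System.out.println(silentClientTimesOut(200)); // prints true
  }
}
```

Note that the timeout applies to each individual read, so it bounds how long a connection may sit idle, not how long it may stay open: a client that trickles in a byte every couple of seconds would still never time out.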
