Eoin
Eoin

Reputation: 350

How can I improve this search algorithms runtime?

I'm trying to solve an interview problem I was given a few years ago in preparation for upcoming interviews. The problem is outlined in a pdf here. I wrote a simple solution using DFS that works fine for the example outlined in the document, but I haven't been able to get the program to meet the criteria of

Your code should produce correct answers in under a second for a 10,000 x 10,000 Geo GeoBlock containing 10,000 occupied Geos.

To test this I generated a CSV file with 10000 random entries and when I run the code against it, it averages just over 2 seconds to find the largest geo block in it. I'm not sure what improvements could be made to my approach to cut the runtime by over half, other than running it on a faster laptop. From my investigations it appears the search itself seems to only take about 8ms, so perhaps the way I load the data into memory is the inefficient part?

I'd greatly appreciate an advice on how this could be improved. See code below:

GeoBlockAnalyzer

package analyzer.block.geo.main;

import analyzer.block.geo.model.Geo;
import analyzer.block.geo.result.GeoResult;

import java.awt.*;
import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.time.LocalDate;
import java.time.format.DateTimeFormatter;
import java.time.format.DateTimeParseException;
import java.util.List;
import java.util.*;

public class GeoBlockAnalyzer {

  private static final DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd");
  private final int width;
  private final int height;
  private final String csvFilePath;
  private GeoResult result = new GeoResult();

  // Map of the geo id and respective geo object
  private final Map<Integer, Geo> geoMap = new HashMap<>();
  // Map of coordinates to each geo in the grid
  private final Map<Point, Geo> coordMap = new HashMap<>();

  /**
   * Constructs a geo grid of the given width and height, populated with the geo data provided in
   * the csv file
   *
   * @param width the width of the grid
   * @param height the height of the grid
   * @param csvFilePath the csv file containing the geo data
   * @throws IOException
   */
  public GeoBlockAnalyzer(final int width, final int height, final String csvFilePath)
      throws IOException {

    if (!Files.exists(Paths.get(csvFilePath)) || Files.isDirectory(Paths.get(csvFilePath))) {
      throw new FileNotFoundException(csvFilePath);
    }

    if (width <= 0 || height <= 0) {
      throw new IllegalArgumentException("Input height or width is 0 or smaller");
    }

    this.width = width;
    this.height = height;
    this.csvFilePath = csvFilePath;

    populateGeoGrid();
    populateCoordinatesMap();
    calculateGeoNeighbours();
    // printNeighbours();
  }

  /** @return the largest geo block in the input grid */
  public GeoResult getLargestGeoBlock() {
    for (final Geo geo : this.geoMap.values()) {
      final List<Geo> visited = new ArrayList<>();
      search(geo, visited);
    }
    return this.result;
  }

  /**
   * Iterative DFS implementation to find largest geo block.
   *
   * @param geo the geo to be evaluated
   * @param visited list of visited geos
   */
  private void search(Geo geo, final List<Geo> visited) {
    final Deque<Geo> stack = new LinkedList<>();
    stack.push(geo);
    while (!stack.isEmpty()) {
      geo = stack.pop();
      if (visited.contains(geo)) {
        continue;
      }
      visited.add(geo);

      final List<Geo> neighbours = geo.getNeighbours();
      for (int i = neighbours.size() - 1; i >= 0; i--) {
        final Geo g = neighbours.get(i);
        if (!visited.contains(g)) {
          stack.push(g);
        }
      }
    }
    if (this.result.getSize() < visited.size()) {
      this.result = new GeoResult(visited);
    }
  }

  /**
   * Creates a map of the geo grid from the csv file data
   *
   * @throws IOException
   */
  private void populateGeoGrid() throws IOException {
    try (final BufferedReader br = Files.newBufferedReader(Paths.get(this.csvFilePath))) {
      int lineNumber = 0;
      String line = "";
      while ((line = br.readLine()) != null) {
        lineNumber++;
        final String[] geoData = line.split(",");
        LocalDate dateOccupied = null;

        // Handle for empty csv cells
        for (int i = 0; i < geoData.length; i++) {
          // Remove leading and trailing whitespace
          geoData[i] = geoData[i].replace(" ", "");

          if (geoData[i].isEmpty() || geoData.length > 3) {
            throw new IllegalArgumentException(
                "There is missing data in the csv file at line: " + lineNumber);
          }
        }
        try {
          dateOccupied = LocalDate.parse(geoData[2], formatter);
        } catch (final DateTimeParseException e) {
          throw new IllegalArgumentException("There input date is invalid on line: " + lineNumber);
        }
        this.geoMap.put(
            Integer.parseInt(geoData[0]),
            new Geo(Integer.parseInt(geoData[0]), geoData[1], dateOccupied));
      }
    }
  }

  /** Create a map of each coordinate in the grid to its respective geo */
  private void populateCoordinatesMap() {
    // Using the geo id, calculate its point on the grid
    for (int i = this.height - 1; i >= 0; i--) {
      int blockId = (i * this.width);
      for (int j = 0; j < this.width; j++) {
        if (this.geoMap.containsKey(blockId)) {
          final Geo geo = this.geoMap.get(blockId);
          geo.setCoordinates(i, j);
          this.coordMap.put(geo.getCoordinates(), geo);
        }
        blockId++;
      }
    }
  }

  private void calculateGeoNeighbours() {
    for (final Geo geo : this.geoMap.values()) {
      addNeighboursToGeo(geo);
    }
  }

  private void addNeighboursToGeo(final Geo geo) {
    final int x = geo.getCoordinates().x;
    final int y = geo.getCoordinates().y;

    final Point[] possibleNeighbours = {
      new Point(x, y + 1), new Point(x - 1, y), new Point(x + 1, y), new Point(x, y - 1)
    };

    Geo g;
    for (final Point p : possibleNeighbours) {
      if (this.coordMap.containsKey(p)) {
        g = this.coordMap.get(p);
        if (g != null) {
          geo.getNeighbours().add(g);
        }
      }
    }
  }

  private void printNeighbours() {
    for (final Geo geo : this.geoMap.values()) {
      System.out.println("Geo " + geo.getId() + " has the following neighbours: ");
      for (final Geo g : geo.getNeighbours()) {
        System.out.println(g.getId());
      }
    }
  }
}

GeoResult

package analyzer.block.geo.result;

import analyzer.block.geo.model.Geo;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

public class GeoResult {

    private final List<Geo> geosInBlock = new ArrayList<>();

    public GeoResult() {
    }

    public GeoResult(final List<Geo> geosInBlock) {
        this.geosInBlock.addAll(geosInBlock);
    }

    public List<Geo> getGeosInBlock() {
        this.geosInBlock.sort(Comparator.comparingInt(Geo::getId));
        return this.geosInBlock;
    }

    public int getSize() {
        return this.geosInBlock.size();
    }

    @Override
    public String toString() {
        final StringBuilder sb = new StringBuilder();
        sb.append("The geos in the largest cluster of occupied Geos for this GeoBlock are: \n");
        for(final Geo geo : this.geosInBlock) {
            sb.append(geo.toString()).append("\n");
        }
        return sb.toString();
    }
}

Geo

package analyzer.block.geo.model;

import java.awt.Point;
import java.time.LocalDate;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

public class Geo {

    private final int id;
    private final String name;
    private final LocalDate dateOccupied;
    private final Point coordinate;
    private final List<Geo> neighbours = new ArrayList<>();

    public Geo (final int id, final String name, final LocalDate dateOccupied) {
        this.id = id;
        this.name = name;
        this.dateOccupied = dateOccupied;
        this.coordinate = new Point();
    }

    public int getId() {
        return this.id;
    }

    public String getName() {
        return this.name;
    }

    public LocalDate getDateOccupied() {
        return this.dateOccupied;
    }

    public void setCoordinates(final int x, final int y) {
        this.coordinate.setLocation(x, y);
    }

    public Point getCoordinates() {
        return this.coordinate;
    }

    public String toString() {
        return this.id + ", " + this.name + ", " + this.dateOccupied;
    }

    public List<Geo> getNeighbours() {
        return this.neighbours;
    }

    @Override
    public int hashCode() {
        return Objects.hash(this.id, this.name, this.dateOccupied);
    }

    @Override
    public boolean equals(final Object obj) {
        if(this == obj) {
            return true;
        }

        if(obj == null || this.getClass() != obj.getClass()) {
           return false;
        }

        final Geo geo = (Geo) obj;
        return this.id == geo.getId() &&
                this.name.equals(geo.getName()) &&
                this.dateOccupied == geo.getDateOccupied();
    }
}

Upvotes: 5

Views: 272

Answers (2)

ldog
ldog

Reputation: 12161

The major optimization available here is a conceptual one. Unfortunately, this type of optimization is not easy to teach, nor look up in a reference somewhere. The principle being used here is:

It's (almost always) cheaper to use an analytic formula to compute a known result than to (pre)compute it. [1]

It's clear from your code & the definition of your problem that you are not taking advantage of this principle and the problem specification. In particular, one of the key points taken directly from the problem specification is this:

Your code should produce correct answers in under a second for a 10,000 x 10,000 Geo GeoBlock containing 10,000 occupied Geos.

When you read this statement a few things should be going through your mind (when thinking about runtime efficiency):

  • 10,000^2 is a much larger number than 10,000 (exactly 10,000 times larger!) There is a clear efficiency gain if you can maintain an algorithm that is O(n) as opposed to O(n^2) (in the expected case because of the use of hashing.)
  • touching (i.e. computing any O(1) operation) for the entire grid is going to immediately yield a O(n^2) algorithm; clearly, this is something that must be avoided if possible
  • from the problem statement, we should never expect O(n^2) geo's that need to be touched. This should be a major hint as to what the person who wrote the problem is looking for. BFS or DFS is an O(N+M) algorithm where N,M are the number of nodes and edges touched. Thus, we should be expecting an O(n) search.
  • based on the above points, it is clear that the solution being looked for here should be O(10,000) for a problem input with grid size 10,000 x 10,000 and 10,000 geos

The solution you provided is O(n^2) because,

  1. You use visited.contains where visited is a List. This is not showing up in your testing as a problem area because I suspect you are using small geo clusters. Try using a large geo cluster (one with 10,000 geos.) You should see a major slow down as compared to say the largest cluster having 3 geos. The solution here is to use an efficient data structure for visited, some that come to mind are a bit set (unknown to me if Java has any available, but any decent language should) or a hash set (clearly Java has some available.) Because you did not notice this in testing, this suggests to me you are not vetting/testing your code well enough with enough varied examples of the corner cases you expect. This should of come up immediately in any thorough testing/profiling of your code. As per my comment, I would of liked to have seen this type of groundwork/profiling done before the question was posted.
  2. You touch the entire 10,000 x 10,000 grid in the function/member populateCoordinatesMap. This is clearly already O(n^2) where n=10,000. Notice, that the only location where coordMap is used outside of populateCoordinatesMap is in addNeighboursToGeo. This is a major bottleneck, and for no reason, addNeighboursToGeo can be computed in O(1) time without the need for a coordMap. However, we can still use your code as is with a minor modification given below.

I hope it is obvious how to fix (1). To fix (2), replace populateCoordinatesMap

  /** Create a map of each coordinate in the grid to its respective geo */
  private void populateCoordinatesMap() {
   for (Map.Entry<int,Geo> entry : geoMap.entrySet()) {
     int key = entry.getKey();
     Geo value = entry.getValue();
     int x = key % this.width;
     int y = key / this.width;  
     value.setCoordinates(x, y);
     this.coordMap.put(geo.getCoordinates(), geo); 
   }
  }

Notice the principle being put to use here. Instead of iterating over the entire grid as you were doing before (O(n^2) immediately), this iterates only over the occupied Geos, and uses the analytic formula for indexing a 2D array (as opposed to doing copious computation to compute the same thing.) Effectively, this change improves populateCoordinatesMap from being O(n^2) to being O(n).

Some general & opinionated comments below:

  • Overall, I strongly disagree with using an object oriented approach over a procedural one for this problem. I think the OO approach is completely unjustified for how simple this code should be, but I understand that the interviewer wanted to see it.
  • This is a very simple problem you are trying to solve, and I think the object orientated approach you took here confounds it so much so you could not see the forest for the trees (or perhaps the trees for the forest.) A much simpler approach could of been taken in how this algorithm was implemented, even using an object oriented approach.
  • It's clear from the points above, you could benefit from knowing the available tools in the language you are working in. By this I mean you should know what containers are readily available and what the trade offs are for using each operation on each container. You should also know at least one decent profiling tool for the language you are working with if you are going to be looking into optimizing code. Given that you failed to post a profiling summary, even after I asked for it, it suggests to me you do not know of such a tool with Java. Learn one.

[1] I provide no reference for this principle because it is a first principle, and can be explained by the fact that running fewer constant time operations is cheaper than running many. The assumption here is that the known analytic form requires less computation. There are occasional exceptions to this rule. But it should be stressed that such exceptions are almost always because of hardware limitations or advantages. For example, when computing the hamming distance it is cheaper to use a precomputed LUT for computing the population count on a hardware architecture without access to SSE registers/operations.

Upvotes: 1

גלעד ברקן
גלעד ברקן

Reputation: 23955

Without testing, it seems to me that the main block here is the literal creation of the map, which could be up to 100,000,000 cells. There would be no need for that if instead we labeled each CSV entry and had a function getNeighbours(id, width, height) that returned the list of possible neighbour IDs (think modular arithmetic). As we iterate over each CSV entry in turn, if (1) neighbour IDs were already seen that all had the same label, we'd label the new ID with that label; if (2) no neighbours were seen, we'd use a new label for the new ID; and if (3) two or more different labels existed between seen neighbour IDs, we'd combine them to one label (say the minimal label), by having a hash that mapped a label to its "final" label. Also store the sum and size for each label. Your current solution is O(n), where n is width x height. The idea here would be O(n), where n is the number of occupied Geos.

Here's something really crude in Python that I wouldn't expect to have all scenarios handled but could hopefully give you an idea (sorry, I don't know Java):

def get_neighbours(id, width, height):
  neighbours = []

  if id % width != 0:
    neighbours.append(id - 1)
  if (id + 1) % width != 0:
    neighbours.append(id + 1)
  if id - width >= 0:
    neighbours.append(id - width)
  if id + width < width * height:
    neighbours.append(id + width)

  return neighbours

def f(data, width, height):
  ids = {}
  labels = {}
  current_label = 0
        
  for line in data:
    [idx, name, dt] = line.split(",")
    idx = int(idx)
    this_label = None
    neighbours = get_neighbours(idx, width, height)
    no_neighbour_was_seen = True

    for n in neighbours:
      # A neighbour was seen
      if n in ids:
        no_neighbour_was_seen = False

        # We have yet to assign a label to this ID
        if not this_label:
          this_label = ids[n]["label"]
          ids[idx] = {"label": this_label, "data": name + " " + dt}
          final_label = labels[this_label]["label"]
          labels[final_label]["size"] += 1
          labels[final_label]["sum"] += idx
          labels[final_label]["IDs"] += [idx]

        # This neighbour has yet to be connected
        elif ids[n]["label"] != this_label:
          old_label = ids[n]["label"]
          old_obj = labels[old_label]
          final_label = labels[this_label]["label"]
          ids[n]["label"] = final_label
          labels[final_label]["size"] += old_obj["size"]
          labels[final_label]["sum"] += old_obj["sum"]
          labels[final_label]["IDs"] += old_obj["IDs"]
          del labels[old_label]

    if no_neighbour_was_seen:
      this_label = current_label
      current_label += 1
      ids[idx] = {"label": this_label, "data": name + " " + dt}
      labels[this_label] = {"label": this_label, "size": 1, "sum": idx, "IDs": [idx]}

  for i in ids:
    print i, ids[i]["label"], ids[i]["data"]
  print ""
  for i in labels:
    print i
    print labels[i]

  return labels, ids
  
          
data = [
  "4, Tom, 2010-10-10",
  "5, Katie, 2010-08-24",
  "6, Nicole, 2011-01-09",
  "11, Mel, 2011-01-01",
  "13, Matt, 2010-10-14",
  "15, Mel, 2011-01-01",
  "17, Patrick, 2011-03-10",
  "21, Catherine, 2011-02-25",
  "22, Michael, 2011-02-25"
]

f(data, 4, 7)
print ""
f(data, 7, 4)

Output:

"""
4 0  Tom  2010-10-10
5 0  Katie  2010-08-24
6 0  Nicole  2011-01-09
11 1  Mel  2011-01-01
13 2  Matt  2010-10-14
15 1  Mel  2011-01-01
17 2  Patrick  2011-03-10
21 2  Catherine  2011-02-25
22 2  Michael  2011-02-25

0
{'sum': 15, 'size': 3, 'IDs': [4, 5, 6], 'label': 0}
1
{'sum': 26, 'size': 2, 'IDs': [11, 15], 'label': 1}
2
{'sum': 73, 'size': 4, 'IDs': [13, 17, 21, 22], 'label': 2}

---

4 0  Tom  2010-10-10
5 0  Katie  2010-08-24
6 0  Nicole  2011-01-09
11 0  Mel  2011-01-01
13 0  Matt  2010-10-14
15 3  Mel  2011-01-01
17 2  Patrick  2011-03-10
21 3  Catherine  2011-02-25
22 3  Michael  2011-02-25

0
{'sum': 39, 'size': 5, 'IDs': [4, 5, 6, 11, 13], 'label': 0}
2
{'sum': 17, 'size': 1, 'IDs': [17], 'label': 2}
3
{'sum': 58, 'size': 3, 'IDs': [21, 22, 15], 'label': 3}
"""

Upvotes: 1

Related Questions