Gautam Mathur

Reputation: 11

Stitching Mask R-CNN predicted geospatial polygons separated by image chips

My team and I recently used a Mask R-CNN model to predict the spatial extent of agricultural fields from satellite imagery. The process involved splitting the imagery into 512 x 512 pixel chips, predicting the fields in each chip, and then joining the polygons from all chips back together. We found that the model didn't make predictions at the edges of the 512 x 512 chips, which left strips of no data along the chip boundaries:

The merged dataset ended up looking like this:

I wanted to adjust the polygons that had been split across chip boundaries so that they could be merged, and came up with the following fix:

The function is successful in most cases. In quite a few cases, however, it only merges a portion of the polygons.

This is acceptable because the result is still a single polygon that covers most of the field's area, but it is not ideal. In a few cases, some of the polygons remain unmerged.

I was wondering if anyone here has an idea of how I should amend my code to catch these missed polygons and merge them as well, and ideally also fix the partially merged polygons. Please let me know your thoughts. Thank you so much! Here is the code:

import os
import pandas as pd
import geopandas as gpd
import rasterio
from shapely.geometry import Polygon
from collections import Counter
from pathlib import Path


def stitchshp(img, polys, endpath):
    img = rasterio.open(img)
    polys = gpd.read_file(polys)
    # Build a fishnet (grid) covering the image; each cell is one chip,
    # 0.5 * 512 map units per side
    easting = [img.bounds.left]
    northing = [img.bounds.top]

    # Step east and south from the top-left corner until the image extent is passed
    while easting[-1] < img.bounds.right:
        easting.append(easting[-1] + (0.5*512))

    while northing[-1] > img.bounds.bottom:
        northing.append(northing[-1] - (0.5*512))

    geom = []

    for i in range(len(northing)):
        for j in range(len(easting)):
            coords = [(easting[j], northing[i]), (easting[j]+(0.5*512), northing[i]), (easting[j]+(0.5*512), northing[i]-(0.5*512)), (easting[j], northing[i]-(0.5*512))]
            geom.append(Polygon(coords))

    table = gpd.GeoDataFrame({"geometry": geom}, crs="EPSG:32644")


    #using fishnets to subset manipulate, and stitch polygons
    allpolyslist = []
    for fish in range(len(table.geometry)):
        net = table.geometry[fish]
        east = max([i[0] for i in [*net.exterior.coords]])
        west = min([i[0] for i in [*net.exterior.coords]])
        north = max([i[1] for i in [*net.exterior.coords]])
        south = min([i[1] for i in [*net.exterior.coords]])
        subpolys = polys[polys.intersects(net)]
        net = gpd.GeoSeries(net)

        maxeast = []
        mineast = []
        maxnorth = []
        minnorth = []
        for sequ in [[*i.exterior.coords] for i in subpolys.geometry]:
            maxeast.append(max(i[0] for i in sequ))
            mineast.append(min(i[0] for i in sequ))
            maxnorth.append(max(i[1] for i in sequ))
            minnorth.append(min(i[1] for i in sequ))
        
        eastdict = Counter(maxeast)
        eastchange = [key for key, value in eastdict.items() if key >east-2 and key <east+2 and value>1]
        westdict = Counter(mineast)
        westchange = [key for key, value in westdict.items() if key <west+2 and key >west-2 and value>1]
        northdict = Counter(maxnorth)
        northchange = [key for key, value in northdict.items() if key >north-2 and key <north+2 and value>1]
        southdict = Counter(minnorth)
        southchange = [key for key, value in southdict.items() if key <south+2 and key >south-2 and value>1]

        polylists = [[*i.exterior.coords] for i  in subpolys.geometry]

        properpolys = []
        for poly in polylists:
            if len(poly)>6:
                properpolys.append(poly)


        for poly in properpolys:
            for j in range(len(poly)):
                if poly[j][0] in eastchange:
                    poly[j] = (east, poly[j][1])
                if poly[j][0] in westchange:
                    poly[j] = (west, poly[j][1])
                if poly[j][1] in northchange:
                    poly[j] = (poly[j][0], north)
                if poly[j][1] in southchange:
                    poly[j] = (poly[j][0], south)
        

        
        allpolyslist = allpolyslist + properpolys
    print("ok, iterated thru!")
    #gpd.GeoSeries([Polygon(i) for i in allpolyslist]).plot()   

    allpolylist = [Polygon(i) for i in allpolyslist]
    print("converted to polygons!")
    allpolydict = gpd.GeoDataFrame({"geometry": allpolylist}, crs="EPSG:32644")
    merge = allpolydict.unary_union
    merge = gpd.GeoSeries(merge)
    merge = merge.explode()
    print("exploded polys!")
    united = gpd.GeoDataFrame(geometry = merge)
    print("Made final df!")
    united.to_file(endpath, crs = allpolydict.crs)

Upvotes: 1

Views: 153

Answers (1)

Pieter

Reputation: 1504

To avoid the poorer detection quality at the edges of the chips, you can run the detection on slightly larger, overlapping chips, e.g. by adding 32 pixels in each of the 4 directions.

After the detection, you can make the extra border pixels black so they are ignored when polygonizing the detection. This ensures consistent detection quality and avoids the issue you encounter.
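As a rough illustration of the overlapping-chip layout, here is a minimal sketch in plain Python. The function name `padded_chip_windows` and its parameters are hypothetical (they are not taken from orthoseg); it only computes pixel windows, and it assumes the 512-pixel chips from the question and a 32-pixel margin.

```python
def padded_chip_windows(width, height, chip=512, pad=32):
    """Yield (read, keep) window pairs for an image of width x height pixels.

    Each window is a (col_off, row_off, n_cols, n_rows) tuple.
    `read` is the enlarged chip fed to the model; `keep` is the inner
    region whose detections are retained. Names are illustrative only.
    """
    for row in range(0, height, chip):
        for col in range(0, width, chip):
            # Inner region: the original, non-overlapping chip
            keep_w = min(chip, width - col)
            keep_h = min(chip, height - row)
            # Enlarge by `pad` pixels on every side, clamped to the image
            left = max(col - pad, 0)
            top = max(row - pad, 0)
            right = min(col + keep_w + pad, width)
            bottom = min(row + keep_h + pad, height)
            read = (left, top, right - left, bottom - top)
            keep = (col, row, keep_w, keep_h)
            yield read, keep


# Example: a 1024 x 1024 image gives four chips whose inner regions tile it
windows = list(padded_chip_windows(1024, 1024))
```

Each `read` window would be passed to the model, and only detections inside the matching `keep` window would be retained. At the outer edges of the image there are no extra pixels to read, so no border should be blanked on those sides.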

This code snippet shows how you can ignore the extra border pixels:

# Make the pixels at the borders of the prediction black so they are ignored
if border_pixels_to_ignore and border_pixels_to_ignore > 0:
    mask_arr[0:border_pixels_to_ignore, :] = 0  # Top border (first rows)
    mask_arr[-border_pixels_to_ignore:, :] = 0  # Bottom border (last rows)
    mask_arr[:, 0:border_pixels_to_ignore] = 0  # Left border (first columns)
    mask_arr[:, -border_pixels_to_ignore:] = 0  # Right border (last columns)

This "trick" is also used in orthoseg, some software I developed to make it easier to detect stuff on ortho images. The full relevant code where the code snippet above was taken from can be found here.

Upvotes: -1
