Object Detection
This example demonstrates query inference using an object detection model with WherobotsAI Raster Inference to identify marine infrastructure (offshore wind farms and platforms) in satellite imagery. We will use a machine-learning model from Satlas [1], which was trained on imagery from the European Space Agency’s Sentinel-2 satellites.
Before you start
This is a read-only preview of this notebook.
To execute the cells in this Jupyter Notebook, do the following:
- Log in to Wherobots Cloud.
- Start a GPU-Optimized runtime instance. We recommend using a Tiny GPU-Optimized runtime.
- Open a notebook.
- Open the examples/python/wherobots-ai/gpu/object_detection.ipynb notebook.
For more information on starting and using notebooks, see the Wherobots Documentation.
Access a GPU-Optimized runtime
This notebook requires a GPU-Optimized runtime. For more information on GPU-Optimized runtimes, see Runtime types.
To access this runtime category, do the following:
- Sign up for a paid Wherobots Organization Edition (Professional or Enterprise).
- Submit a Compute Request for a GPU-Optimized runtime.
Step 1: Set Up The WherobotsDB Context
Here we configure WherobotsDB to enable access to the necessary cloud object storage buckets with sample data.
import warnings
warnings.filterwarnings('ignore')
from wherobots.inference.data.io import read_raster_table
from sedona.spark import SedonaContext
from pyspark.sql.functions import expr
config = SedonaContext.builder().appName('object-detection-batch-inference')\
.getOrCreate()
sedona = SedonaContext.create(config)
# To use your own model, uncomment the line that sets the `user_mlm_uri` variable and replace the placeholder with the path to your MLM JSON.
# To learn more about bringing your own model, see [Bring your own model](https://docs.wherobots.com/latest/tutorials/wherobotsai/wherobots-inference/bring-your-own-model/) in the Wherobots Documentation.
# user_mlm_uri = "[PATH-TO-MLM-JSON]"
Step 2: Load Satellite Imagery
Next, we load the satellite imagery that we will run inference on. These GeoTIFF images are loaded as out-db rasters in WherobotsDB, where each row represents a different scene.
tif_folder_path = 's3://wherobots-benchmark-prod/data/ml/satlas-offshore-wind-scenes/'
files_df = read_raster_table(tif_folder_path, sedona, limit=500)
df_raster_input = files_df.withColumn(
"outdb_raster", expr("RS_FromPath(path)")
)
%%time
df_raster_input.cache().count()
df_raster_input.show(truncate=False)
Step 3: Run Predictions And Visualize Results
To run predictions, we specify the model we want to use. Some models are preloaded and available in Wherobots Cloud, and we can also load our own models. Predictions are run with Wherobots Spatial SQL functions, in this case RS_DETECT_BBOXES.
Here we generate 100 predictions using RS_DETECT_BBOXES.
df_raster_input.createOrReplaceTempView("df_raster_input")
model_id = 'marine-satlas-sentinel2'
predictions_df = sedona.sql(f"""
SELECT
outdb_raster,
detect_result.*
FROM (
SELECT
outdb_raster,
RS_DETECT_BBOXES('{model_id}', outdb_raster) AS detect_result
FROM
df_raster_input
) AS detect_fields
""")
predictions_df.cache().count()
predictions_df.show()
predictions_df.createOrReplaceTempView("predictions")
You can use your own model instead of one of our hosted models through the model_id variable. To do so, replace the value of model_id with the S3 URI of your Machine Learning Model Extension (MLM) metadata JSON, then pass that URI as the first argument to RS_DETECT_BBOXES.
For example:
user_mlm_uri = 's3://wherobots-modelhub-prod/professional/object-detection/marine-satlas-sentinel2/model-metadata.json'
predictions_df = sedona.sql(f"SELECT name, outdb_raster, RS_DETECT_BBOXES('{user_mlm_uri}', outdb_raster) AS preds FROM df_raster_input")
To learn more about bringing your own model, see Bring your own model in the Wherobots Documentation.
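Switching between a hosted model and your own model only changes the first argument to RS_DETECT_BBOXES. The plain-Python sketch below shows one way to choose that argument and build the same query text used in Step 3. The helper names resolve_model_ref and build_detect_query are hypothetical and are not part of wherobots.inference. In the notebook, you would pass the resulting string to sedona.sql().

```python
from typing import Optional


def resolve_model_ref(model_id: str, user_mlm_uri: Optional[str] = None) -> str:
    """Hypothetical helper: prefer the MLM JSON URI when one is set, else the hosted model ID."""
    return user_mlm_uri if user_mlm_uri else model_id


def build_detect_query(model_ref: str, view: str = "df_raster_input") -> str:
    """Hypothetical helper: build the RS_DETECT_BBOXES query from Step 3 as a string."""
    # Model IDs and S3 URIs should never contain quotes; reject them rather than
    # risk producing a malformed SQL string literal.
    if "'" in model_ref:
        raise ValueError("model_ref must not contain single quotes")
    return (
        "SELECT outdb_raster, detect_result.* FROM ("
        f"SELECT outdb_raster, RS_DETECT_BBOXES('{model_ref}', outdb_raster) AS detect_result "
        f"FROM {view}) AS detect_fields"
    )


# Hosted model, as used in this notebook
hosted_query = build_detect_query(resolve_model_ref("marine-satlas-sentinel2"))

# Your own model: the MLM metadata URI takes precedence over the hosted model ID
own_query = build_detect_query(
    resolve_model_ref(
        "marine-satlas-sentinel2",
        user_mlm_uri="s3://wherobots-modelhub-prod/professional/object-detection/marine-satlas-sentinel2/model-metadata.json",
    )
)
# In the notebook: predictions_df = sedona.sql(own_query)
```

This keeps the SQL identical in both cases, so downstream steps such as confidence filtering do not need to change.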
Because we ran inference over the coastlines of many countries around the world, many scenes contain no wind farms and have no positive detections. Now that we've generated predictions over our satellite imagery, we can filter the geometries by confidence score with RS_FILTER_BOX_CONFIDENCE (here with a threshold of 0.65) and by the integer label representing offshore wind farms, 2, to locate predicted offshore wind farms.
filtered_predictions = sedona.sql(f"""
SELECT
outdb_raster,
filtered.*
FROM (
SELECT
outdb_raster,
RS_FILTER_BOX_CONFIDENCE(bboxes_wkt, confidence_scores, labels, 0.65) AS filtered
FROM
predictions
) AS temp
WHERE size(filtered.max_confidence_bboxes) > 0
AND array_contains(filtered.max_confidence_labels, '2')
""")
filtered_predictions.createOrReplaceTempView("filtered_predictions")
filtered_predictions.cache().count()
filtered_predictions.show()
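To make the filter easier to reason about, here is a plain-Python sketch of the per-scene logic. It keeps detections whose confidence meets the 0.65 threshold, then keeps the scene only if at least one box survived and one of the surviving labels is the wind farm label. This is an approximation under stated assumptions, not the RS_FILTER_BOX_CONFIDENCE implementation. It assumes a `>=` comparison and no further suppression of overlapping boxes. The function name filter_scene is hypothetical, and labels are normalized to strings to mirror the '2' literal in the query.

```python
from typing import List, NamedTuple, Optional, Sequence, Union

WIND_FARM_LABEL = "2"  # mirrors array_contains(..., '2') in the query above


class Filtered(NamedTuple):
    max_confidence_bboxes: List[str]
    max_confidence_scores: List[float]
    max_confidence_labels: List[str]


def filter_scene(
    bboxes_wkt: Sequence[str],
    confidence_scores: Sequence[float],
    labels: Sequence[Union[int, str]],
    threshold: float = 0.65,
) -> Optional[Filtered]:
    """Hypothetical sketch: keep boxes scoring at or above threshold (assumed >=).

    Returns None when the scene would be dropped by the WHERE clause.
    """
    kept = [
        (box, score, str(label))
        for box, score, label in zip(bboxes_wkt, confidence_scores, labels)
        if score >= threshold
    ]
    # size(filtered.max_confidence_bboxes) > 0
    if not kept:
        return None
    # array_contains(filtered.max_confidence_labels, '2')
    if WIND_FARM_LABEL not in [label for _, _, label in kept]:
        return None
    boxes, scores, kept_labels = (list(column) for column in zip(*kept))
    return Filtered(boxes, scores, kept_labels)


# Made-up toy scene: only the first box clears the threshold
scene = filter_scene(
    ["POLYGON ((0 0, 1 0, 1 1, 0 1, 0 0))", "POLYGON ((2 2, 3 2, 3 3, 2 3, 2 2))"],
    [0.9, 0.4],
    [2, 1],
)
print(scene)
```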
Our final step before plotting the prediction results is to reshape the table so that each row represents one predicted bounding box, rather than all of one raster scene's predictions. To do this, we combine the list columns with arrays_zip and then use explode to turn each list element into its own row. Finally, we convert the string column that holds each geometry's WKT into a GeometryType column with ST_GeomFromWKT so that we can plot it with SedonaKepler.
exploded_df = sedona.sql("""
SELECT
outdb_raster,
exploded.*
FROM (
SELECT
outdb_raster,
explode(arrays_zip(max_confidence_bboxes, max_confidence_scores, max_confidence_labels)) AS exploded
FROM
filtered_predictions
) temp
""")
df_exploded = exploded_df.withColumn("geometry", expr("ST_GeomFromWkt(max_confidence_bboxes)")).drop("max_confidence_bboxes")
print(df_exploded.cache().count())
df_exploded.show()
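The reshaping above can be pictured in plain Python. arrays_zip pairs the i-th elements of the three lists, and Spark pads shorter arrays with nulls, which zip_longest mimics here. explode then emits one row per pair, so a scene with empty arrays produces no rows at all. The sketch also reads the extent out of a bounding-box WKT string, which is roughly the information the geometry built by ST_GeomFromWkt carries. The helper names explode_scene and polygon_bounds are hypothetical, and the WKT parser only handles a simple POLYGON without holes. In the notebook itself, rely on ST_GeomFromWkt.

```python
from itertools import zip_longest
from typing import Dict, List, Tuple


def explode_scene(row: Dict[str, list]) -> List[Dict[str, object]]:
    """Hypothetical sketch of explode(arrays_zip(...)): one output row per box."""
    zipped = zip_longest(
        row["max_confidence_bboxes"],
        row["max_confidence_scores"],
        row["max_confidence_labels"],
        fillvalue=None,  # arrays_zip pads shorter arrays with nulls
    )
    return [
        {
            "max_confidence_bboxes": box,
            "max_confidence_scores": score,
            "max_confidence_labels": label,
        }
        for box, score, label in zipped
    ]


def polygon_bounds(wkt: str) -> Tuple[float, float, float, float]:
    """Return (min_x, min_y, max_x, max_y) for a simple 'POLYGON ((x y, ...))' string."""
    text = wkt.strip()
    if not text.upper().startswith("POLYGON"):
        raise ValueError("expected a POLYGON WKT string")
    # Take the exterior ring between the opening '((' and the first ')'
    ring = text[text.index("((") + 2 : text.index(")")]
    points = [[float(v) for v in pair.split()] for pair in ring.split(",")]
    xs = [p[0] for p in points]
    ys = [p[1] for p in points]
    return min(xs), min(ys), max(xs), max(ys)


# Made-up scene with two detections
scene = {
    "max_confidence_bboxes": [
        "POLYGON ((0 0, 2 0, 2 1, 0 1, 0 0))",
        "POLYGON ((5 5, 6 5, 6 6, 5 6, 5 5))",
    ],
    "max_confidence_scores": [0.9, 0.7],
    "max_confidence_labels": ["2", "2"],
}
rows = explode_scene(scene)
for r in rows:
    print(r["max_confidence_scores"], polygon_bounds(r["max_confidence_bboxes"]))
```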
Zoom into the coasts of China or the Netherlands to spot some detected wind farms!
from sedona.maps.SedonaKepler import SedonaKepler
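# Use distinct names so the Kepler settings don't shadow the earlier Sedona `config`
# or Python's built-in `map`.
kepler_config = {
    'version': 'v1',
    'config': {
        'mapStyle': {
            'styleType': 'dark',
            'topLayerGroups': {},
            'visibleLayerGroups': {},
            'mapStyles': {}
        }
    }
}
kepler_map = SedonaKepler.create_map(config=kepler_config)
# Drop the raster column so that only the geometry and its attributes are sent to the map.
SedonaKepler.add_df(kepler_map, df=df_exploded.drop("outdb_raster"), name="Wind Farm Detections")
kepler_map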
wherobots.inference Python API
If you prefer Python, wherobots.inference offers a module for registering the SQL inference functions as Python functions. Below, we run the same inference as before with RS_DETECT_BBOXES.
from wherobots.inference.engine.register import create_object_detection_udfs
from pyspark.sql.functions import col
rs_detect, rs_threshold_geoms = create_object_detection_udfs(batch_size=10, sedona=sedona)
df = df_raster_input.withColumn("detect_result", rs_detect(model_id, col("outdb_raster"))).select(
"outdb_raster",
col("detect_result.bboxes_wkt").alias("bboxes_wkt"),
col("detect_result.confidence_scores").alias("confidence_scores"),
col("detect_result.labels").alias("labels")
)
df.show()
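The batch_size argument to create_object_detection_udfs presumably controls how many rasters are grouped into each inference call. That reading is an assumption based on the parameter name, so check the wherobots.inference reference for the exact semantics. As a generic illustration of fixed-size batching, the hypothetical helper below splits a sequence into lists of at most batch_size items, with a possibly shorter final batch. It is not the library's implementation. Larger batches typically make better use of a GPU but need more memory per call.

```python
from itertools import islice
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def batched(items: Iterable[T], batch_size: int) -> Iterator[List[T]]:
    """Yield consecutive lists of at most batch_size items (hypothetical helper)."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    iterator = iter(items)
    while True:
        batch = list(islice(iterator, batch_size))
        if not batch:
            return
        yield batch


# 23 hypothetical scene paths grouped the way batch_size=10 suggests
scene_paths = [f"scene_{i:02d}.tif" for i in range(23)]
batches = list(batched(scene_paths, 10))
print([len(b) for b in batches])  # [10, 10, 3]
```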