Skip to content

Commit 9776bfb

Browse files
committed
poly smaller than pixel fix
1 parent ad31d6b commit 9776bfb

File tree

1 file changed

+220
-17
lines changed

1 file changed

+220
-17
lines changed

geospatial_learn/shape.py

+220-17
Original file line numberDiff line numberDiff line change
@@ -557,7 +557,7 @@ def _bbox_to_pixel_offsets(rgt, geom):
557557
----------
558558
559559
rgt : array
560-
List of points defining polygon (?)
560+
List of points defining polygon
561561
562562
geom : shapely.geometry
563563
Structure defining geometry
@@ -613,22 +613,10 @@ def _bbox_to_pixel_offsets(rgt, geom):
613613
# Specify offset and rows and columns to read
614614
xoff = int((xmin - xOrigin)/pixelWidth)
615615
yoff = int((yOrigin - ymax)/pixelWidth)
616-
xcount = int((xmax - xmin)/pixelWidth)#+1
617-
ycount = int((ymax - ymin)/pixelWidth)#+1
618-
# originX = rgt[0]
619-
# originY = rgt[3]
620-
# pixel_width = rgt[1]
621-
# pixel_height = rgt[5]
622-
# x1 = int((bbox[0] - originX) / pixel_width)
623-
# x2 = int((bbox[1] - originX) / pixel_width) + 1
624-
#
625-
# y1 = int((bbox[3] - originY) / pixel_height)
626-
# y2 = int((bbox[2] - originY) / pixel_height) + 1
627-
#
628-
# xsize = x2 - x1
629-
# ysize = y2 - y1
630-
# return (x1, y1, xsize, ysize)
631-
return (xoff, yoff, xcount, ycount)
616+
xcount = int((xmax - xmin)/pixelWidth)
617+
ycount = int((ymax - ymin)/pixelWidth)
618+
619+
return [xoff, yoff, xcount, ycount]
632620

633621
def sqlfilter(inShp, sql):
634622

@@ -953,6 +941,13 @@ def zonal_stats(inShp, inRas, band, bandname, layer=None, stat = 'mean',
953941
# if outside the raster
954942
src_offset = _bbox_to_pixel_offsets(rgt, geom)
955943

944+
#x, y, xcount, ycount - so if counts are 0, the geom is smaller than
945+
# the pixel, so we can give it a shape of 1x1
946+
if src_offset[2] == 0:
947+
src_offset[2] = 1
948+
if src_offset[3] == 0:
949+
src_offset[3] = 1
950+
956951
# This does not seem to be fullproof
957952
# This is a hacky mess that needs fixed
958953
if poly.Contains(geom) == False:
@@ -1083,6 +1078,214 @@ def zonal_stats(inShp, inRas, band, bandname, layer=None, stat = 'mean',
10831078

10841079
if write_stat != None:
10851080
return frame, rejects
1081+
1082+
def zonal_frac(inShp, inRas, band, bandname, layer=None,
1083+
write_stat=True, nodata_value=0, all_touched=True,
1084+
expression=None):
1085+
1086+
"""
1087+
Return the unique classes and their counts per polygon
1088+
1089+
Parameters
1090+
----------
1091+
1092+
inShp: string
1093+
input shapefile
1094+
1095+
inRas: string
1096+
input raster
1097+
1098+
band: int
1099+
an integer val eg - 2
1100+
1101+
bandname: string
1102+
eg - blue
1103+
1104+
layer: string
1105+
if using a db type format with multi layers, specify the name of the
1106+
layer in question
1107+
1108+
write_stat: bool (optional)
1109+
If True, stat will be written to OGR file, if false, dataframe
1110+
only returned (bool)
1111+
1112+
nodata_value: numerical
1113+
If used the no data val of the raster
1114+
1115+
all_touched: bool
1116+
whether to use all touched when raterising the polygon
1117+
if the poly is smaller/comaparable to the pixel size,
1118+
True is perhaps the best option
1119+
expression: string
1120+
process a selection only eg expression e.g. "DN >= 168"
1121+
"""
1122+
# gdal/ogr-based zonal stats
1123+
1124+
if all_touched == True:
1125+
touch = "ALL_TOUCHED=TRUE"
1126+
else:
1127+
touch = "ALL_TOUCHED=FALSE"
1128+
1129+
rds = gdal.Open(inRas, gdal.GA_ReadOnly)
1130+
#assert(rds)
1131+
rb = rds.GetRasterBand(band)
1132+
rgt = rds.GetGeoTransform()
1133+
1134+
if nodata_value:
1135+
nodata_value = float(nodata_value)
1136+
rb.SetNoDataValue(nodata_value)
1137+
1138+
vds = ogr.Open(inShp, 1)
1139+
1140+
# if we are using a db of some sort gpkg etc where we have to choose
1141+
if layer !=None:
1142+
vlyr = vds.GetLayerByName(layer)
1143+
else:
1144+
vlyr = vds.GetLayer()
1145+
1146+
if expression != None:
1147+
vlyr.SetAttributeFilter(expression)
1148+
fcount = str(vlyr.GetFeatureCount())
1149+
print(expression+"\nresults in "+fcount+" features to process")
1150+
1151+
if write_stat != None:
1152+
# if the field exists leave it as ogr is a pain with dropping it
1153+
# plus can break the file
1154+
if _fieldexist(vlyr, bandname+'_cls') == False:
1155+
vlyr.CreateField(ogr.FieldDefn(bandname+'_cls', ogr.OFTIntegerList))
1156+
if _fieldexist(vlyr, bandname+'_cnt') == False:
1157+
vlyr.CreateField(ogr.FieldDefn(bandname+'_cnt', ogr.OFTIntegerList))
1158+
1159+
1160+
mem_drv = ogr.GetDriverByName('Memory')
1161+
driver = gdal.GetDriverByName('MEM')
1162+
1163+
# Loop through vectors
1164+
stats = []
1165+
feat = vlyr.GetNextFeature()
1166+
features = np.arange(vlyr.GetFeatureCount())
1167+
rejects = []
1168+
1169+
#create a poly of raster bbox to test for within raster
1170+
poly = rasterext2poly(inRas)
1171+
1172+
#TODO FAR too many if statements in this loop.
1173+
# This is FAR too slow
1174+
1175+
for label in tqdm(features):
1176+
1177+
if feat is None:
1178+
continue
1179+
# debug
1180+
# wkt=geom.ExportToWkt()
1181+
# poly1 = loads(wkt)
1182+
geom = feat.geometry()
1183+
1184+
src_offset = _bbox_to_pixel_offsets(rgt, geom)
1185+
#x, y, xcount, ycount - so if counts are 0, the geom is smaller than
1186+
# the pixel, so we can give it a shape of 1x1
1187+
if src_offset[2] == 0:
1188+
src_offset[2] = 1
1189+
if src_offset[3] == 0:
1190+
src_offset[3] = 1
1191+
1192+
# This does not seem to be fullproof
1193+
# This is a hacky mess that needs fixed
1194+
if poly.Contains(geom) == False:
1195+
#print(src_offset[0],src_offset[1])
1196+
#offs.append()
1197+
feat = vlyr.GetNextFeature()
1198+
continue
1199+
elif src_offset[0] > rds.RasterXSize:
1200+
feat = vlyr.GetNextFeature()
1201+
continue
1202+
elif src_offset[1] > rds.RasterYSize:
1203+
feat = vlyr.GetNextFeature()
1204+
continue
1205+
elif src_offset[0] < 0 or src_offset[1] < 0:
1206+
feat = vlyr.GetNextFeature()
1207+
continue
1208+
1209+
if src_offset[0] + src_offset[2] > rds.RasterXSize:
1210+
# needs to be the diff otherwise neg vals are possble
1211+
xx = abs(rds.RasterXSize - src_offset[0])
1212+
1213+
src_offset = (src_offset[0], src_offset[1], xx, src_offset[3])
1214+
1215+
if src_offset[1] + src_offset[3] > rds.RasterYSize:
1216+
yy = abs(rds.RasterYSize - src_offset[1])
1217+
src_offset = (src_offset[0], src_offset[1], src_offset[2], yy)
1218+
1219+
src_array = rb.ReadAsArray(src_offset[0], src_offset[1], src_offset[2],
1220+
src_offset[3])
1221+
if src_array is None:
1222+
src_array = rb.ReadAsArray(src_offset[0]-1, src_offset[1], src_offset[2],
1223+
src_offset[3])
1224+
if src_array is None:
1225+
rejects.append(feat.GetFID())
1226+
continue
1227+
1228+
# calculate new geotransform of the feature subset
1229+
new_gt = (
1230+
(rgt[0] + (src_offset[0] * rgt[1])),
1231+
rgt[1],
1232+
0.0,
1233+
(rgt[3] + (src_offset[1] * rgt[5])),
1234+
0.0,
1235+
rgt[5])
1236+
1237+
1238+
# Create a temporary vector layer in memory
1239+
mem_ds = mem_drv.CreateDataSource('out')
1240+
mem_layer = mem_ds.CreateLayer('poly', None, ogr.wkbPolygon)
1241+
mem_layer.CreateFeature(feat.Clone())
1242+
1243+
# Rasterize it
1244+
1245+
rvds = driver.Create('', src_offset[2], src_offset[3], 1, gdal.GDT_Byte)
1246+
1247+
rvds.SetGeoTransform(new_gt)
1248+
rvds.SetProjection(rds.GetProjectionRef())
1249+
rvds.SetGeoTransform(new_gt)
1250+
gdal.RasterizeLayer(rvds, [1], mem_layer, burn_values=[1], options=[touch])
1251+
rv_array = rvds.ReadAsArray()
1252+
1253+
# Mask the source data array with our current feature using np mask
1254+
1255+
#rejects.append(feat.GetField('DN'))
1256+
masked = np.ma.MaskedArray(
1257+
src_array,
1258+
mask=np.logical_or(
1259+
src_array == nodata_value,
1260+
np.logical_not(rv_array)
1261+
)
1262+
)
1263+
unique, count = np.unique(masked.data, return_counts=True)
1264+
1265+
stats.append([feat.GetFID(), unique, count])
1266+
1267+
if write_stat != None:
1268+
# may have to insert into gdf as array
1269+
# TypeError: Feature_SetFieldIntegerList expected 3 arguments, got 4
1270+
# But there are 3 args apart from the self
1271+
feat.SetFieldIntegerList(bandname+'_cls', 3, unique.tolist())
1272+
feat.SetFieldIntegerList(bandname+'_cnt', 3, count.tolist())
1273+
# A hack could be to write a string then use eval(string) to get a list back
1274+
#feat.SetField(bandname+'_cls', str(unique.tolist()))
1275+
#feat.SetField(bandname+'_cnt', str((count.tolist()))
1276+
vlyr.SetFeature(feat)
1277+
feat = vlyr.GetNextFeature()
1278+
1279+
if write_stat != None:
1280+
vlyr.SyncToDisk()
1281+
1282+
vds = None
1283+
rds = None
1284+
frame = DataFrame(data=stats, columns=['fid', bandname+'_cls', bandname+'_cnt'])
1285+
1286+
if write_stat is None:
1287+
return frame, rejects
1288+
10861289

10871290
def zonal_stats_all(inShp, inRas, bandnames,
10881291
statList = ['mean', 'min', 'max', 'median', 'std',

0 commit comments

Comments
 (0)