# As with Estimators & training jobs, we can instead attach to an existing Endpoint.
# The string below is a placeholder: replace it with the name of your own endpoint.
ss_predictor = sagemaker.predictor.Predictor("XXX-2022-10-XX-XX-XX-XX-XX")
设置序列化器和反序列化器
在 SageMaker SDK 中,Predictor 有一个关联的序列化器和反序列化器:序列化器控制输入数据如何转换为 API 调用的请求内容,反序列化器控制响应如何加载回 Python 结果对象。
import io import tempfile import mxnet as mx from sagemaker.amazon.record_pb2 import Record
class SSProtobufDeserializer(sagemaker.deserializers.BaseDeserializer):
    """Deserialize protobuf semantic segmentation response into a numpy array.

    Args:
        accept (str): The MIME type reported through ``ACCEPT``
            (default: "application/x-protobuf").
    """

    def __init__(self, accept="application/x-protobuf"):
        self.accept = accept

    @property
    def ACCEPT(self):
        """Return the configured MIME type wrapped in a one-element tuple."""
        return (self.accept,)

    def deserialize(self, stream, content_type):
        """Read a stream of bytes returned from an inference endpoint.

        The stream is always closed before this method returns, whether or not
        parsing succeeds.

        Args:
            stream (botocore.response.StreamingBody): A stream of bytes.
            content_type (str): The MIME type of the data. Not inspected here.

        Returns:
            mask: The numpy array of class confidences per pixel
        """
        try:
            rec = Record()

            # mxnet.recordio can only read from files, not in-memory file-like objects, so we
            # buffer the response stream to a file on disk and then read it back:
            with tempfile.NamedTemporaryFile(mode="w+b") as ftemp:
                ftemp.write(stream.read())
                # The reader below re-opens the file by name, so the buffered bytes must be
                # on disk before it starts reading.
                ftemp.flush()

                recordio = mx.recordio.MXRecordIO(ftemp.name, "r")
                try:
                    # ParseFromString fills `rec` in place; its return value is not needed.
                    rec.ParseFromString(recordio.read())
                finally:
                    # Release the reader's handle before the temporary file is deleted.
                    recordio.close()

            values = list(rec.features["target"].float32_tensor.values)
            shape = list(rec.features["shape"].int32_tensor.values)

            # We 'squeeze' away extra dimensions introduced by the fact that the model can
            # operate on batches of images at a time:
            shape = np.squeeze(shape)
            mask = np.reshape(np.array(values), shape)
            # np.squeeze with axis=0 raises ValueError unless the leading dimension has size 1.
            return np.squeeze(mask, axis=0)
        finally:
            stream.close()
# Red-shift our image to make the cyan highlights more obvious: work on a copy and
# scale down channels 1 and 2 (presumably green and blue -- confirm the channel order).
imshifted = imarray.copy()
for _channel, _gain in ((1, 0.6), (2, 0.5)):
    imshifted[:, :, _channel] *= _gain

# Construct a mask with alpha channel taken from the model result: the highlight colour
# is repeated for every pixel of the image, then channel 3 is overwritten with the
# model output selected by target_cls_id.
hilitemask = np.tile(
    hilitecol[np.newaxis, np.newaxis, :],
    (*imarray.shape[:2], 1),
)
hilitemask[:, :, 3] = prob_mask[target_cls_id, :, :]

# Overlay the two images: three side-by-side panels are created; only the first and
# the last are drawn in this span (ax1 is left untouched here).
fig, (ax0, ax1, ax2) = plt.subplots(1, 3, figsize=(16, 6))
for _panel, _picture, _caption in (
    (ax0, imarray, "Original Image"),
    (ax2, hilitemask, "Highlight people Mask"),
):
    _panel.imshow(_picture)
    _panel.axis("off")
    _panel.set_title(_caption)