import PreProcess from './preprocess.js';
import PostProcess from './postprocess.js';
import AsyncQueue from './asyncqueue.js';
import Mutex from './mutex.js';
import DGError from './dgerror.js';
// Load msgpack library from the adjacent file
import { loadMsgpack } from './msgpack.min.js';
// msgpack is consumed as a global (msgpack.decode in the socket message handler). Only run the bundled
// loader when no other script has already provided that global.
if (typeof msgpack === 'undefined') loadMsgpack();
/**
* @class
* @classdesc A comprehensive class for handling AI model inference using an AIServer
* over WebSocket. Designed to provide a streamlined interface for sending data to the server for inference, receiving
* processed results, and displaying or further processing these results as needed. <br><br>
* Features: <br>
* - WebSocket Communication: Handles the full lifecycle of a WebSocket connection for real-time data streaming.<br>
* - Preprocessing & Postprocessing: Integrates with PreProcess and PostProcess classes to prepare data for the model and visualize results.<br>
* - Queue Management: Uses AsyncQueue instances to manage inbound and outbound data flow.<br>
* - Concurrency Control: Serializes overlapping asynchronous calls (e.g. concurrent predict() calls) through mutex usage.<br>
* - Dynamic Configuration: Allows runtime modification of model and overlay parameters.<br>
* - Callback Integration: Supports custom callback functions for handling results outside the class.<br>
*
* @example <caption>Usage:</caption>
* - Create an instance with the required model details and server URL.
* let model = zoo.loadModel('some_model_name', {} );
* - Use the `predict` method for inference with individual data items or `predict_batch` for multiple items.
* let result = await model.predict(someImage);
* for await (let result of model.predict_batch(someDataGeneratorFn)) { ... }
* - Access processed results directly or set up a callback function for custom result handling.
* - You can display results to a canvas to view drawn overlays.
* await model.displayResultToCanvas(result, canvas);
*/
class AIServerModel {
/**
* Do not call the constructor directly. Use the `loadModel` method of an AIServerZoo instance to create an AIServerModel.
* @constructor
* @param {Object} options - Options for initializing the model.
* @param {string} options.modelName - The name of the model to load.
* @param {string} options.serverUrl - The URL of the server.
* @param {Object} options.modelParams - The default model parameters.
* @param {number} [options.max_q_len=10] - Maximum queue length.
* @param {Function} [options.callback=null] - Callback function for handling results.
* @param {Object} [options.labels=null] - Label dictionary for the model.
* @param {Array<string>} options.systemDeviceTypes - Array of 'RUNTIME/DEVICE' strings supported by the AIServer.
* @param {Object} [additionalParams] - Additional parameters for the model.
*/
constructor({ modelName, serverUrl, modelParams, max_q_len = 10, callback = null, labels = null, systemDeviceTypes }, additionalParams) {
this.debugLogsEnabled = false;
// The creation of a AIServerModel only happens when the user requested a new model on the serverUrl.
// Thus we need to initialize a new websocket with the first packet having the modelname,
// as well as the config if needed
this.modelName = modelName;
this.serverUrl = serverUrl;
this.labels = labels; // Label dictionary for this model
this.systemDeviceTypes = systemDeviceTypes; // Array of 'RUNTIME/DEVICE' strings supported by the AIServer.
if (!this.systemDeviceTypes || this.systemDeviceTypes.length === 0) {
throw new DGError("System Device Types are missing from Zoo upon initalization of AIServerModel class!", "MISSING_SYSTEM_DEVICE_TYPES", {}, "An error occurred during initialization.", "AIServerZoo should have sent these to the model in loadModel().");
}
this._deviceType = null;
// modelParams is the JSON of DEFAULT model parameters grabbed during zoo.loadModel()
// All modifications to model params will have to be sent as config messages to the server
this.modelParams = modelParams;
this.additionalParams = additionalParams;
this.dirty = false; // Dirty flag to signify if model params have changed (inference parameters)
this.modelConfig = {}; // Configuration for websocket to initialize non-default model parameters.
// Pre/Post processors are initialized upon first call to functions that need them
this.preProcessor = null;
this.postProcessor = null;
this.configParamsDirty = true; // Dirty flag specifically for display / input handling parameters
this.initialized = false; // true if acknowledgement packet received from server
this.infoQ = new AsyncQueue(max_q_len, 'infoQ');
this.resultQ = new AsyncQueue(max_q_len, 'resultQ');
this.callback = callback;
this.poison = false; // Poison flag to stop all predict calls
// Now, initialize some other member values using modelParams
this.finishedSettingAdditionalParams = false;
this.initMemberValues();
// Initializes websocket to load the model, initialize listeners
this.initializeSocket();
// class member variable to store the last processed message Promise
this.lastProcessedMessage = Promise.resolve();
this.mutex = new Mutex(); // Why not?
this.MAX_SOCKET_WAIT_MS = 10000; // Max Time to wait for the socket connection to be opened before error.
}
/**
* Logs messages to the console if debug logs are enabled.
* @private
* @param {...any} args - The messages to log.
*/
log(...args) {
if (this.debugLogsEnabled) {
console.log(...args);
}
}
/**
* Initializes member values from model parameters.
* @private
*/
initMemberValues() {
if (this.modelParams && this.modelParams["MODEL_PARAMETERS"] && this.modelParams["MODEL_PARAMETERS"].length > 0) {
const parameters = this.modelParams["MODEL_PARAMETERS"][0];
const preProcessParams = this.modelParams["PRE_PROCESS"][0];
this.modelPath = parameters.ModelPath;
// NCHW info is either under 'MODEL_PARAMETERS' or 'PRE_PROCESS'
if (parameters.ModelInputN) {
this.modelInputN = parameters.ModelInputN;
this.modelInputC = parameters.ModelInputC;
this.modelInputH = parameters.ModelInputH;
this.modelInputW = parameters.ModelInputW;
} else if (preProcessParams.InputN) {
this.modelInputN = preProcessParams.InputN;
this.modelInputC = preProcessParams.InputC;
this.modelInputH = preProcessParams.InputH;
this.modelInputW = preProcessParams.InputW;
} else {
throw new DGError("Model Parameters don't contain input height / width.", "MISSING_PARAMS", { parameters, preProcessParams }, "Ensure model parameters include input height and width.", "Check the model's documentation to provide the required input height and width parameters.");
}
}
// Set internal parameters to default, only if the additionalParams weren't handled yet.
if (!this.finishedSettingAdditionalParams) {
// Internal model pre/post processing and inference parameters, initially set to defaults.
this._labelWhitelist = null;
this._labelBlacklist = null;
// Display Parameters
this._overlayColor = [255, 0, 0];
this._overlayLineWidth = 2;
this._overlayShowLabels = true;
this._overlayShowProbabilities = false;
this._overlayAlpha = 0.75;
this._overlayFontScale = 1.0;
// Input Handling Parameters
this._inputLetterboxFillColor = [0, 0, 0];
this._inputPadMethod = 'letterbox';
this._saveModelImage = false;
this._inputCropPercentage = 1.0;
}
// Assign additional parameters, using our set/get functions. Try to overwrite them by these values, and
// warn the user if that parameter doesn't exist.
// Only do this ONCE: prior to websocket opening.
if (this.additionalParams !== null && this.additionalParams !== undefined && !this.finishedSettingAdditionalParams) {
for (const [key, value] of Object.entries(this.additionalParams)) {
this.log('initMemberValues(): Setting additional param', key, 'to', value);
// Check for the existence of a setter for 'key'
const descriptor = Object.getOwnPropertyDescriptor(Object.getPrototypeOf(this), key);
const hasSetter = descriptor && typeof descriptor.set === 'function';
if (hasSetter) {
try {
// console.log('Invoking setter to set', key, 'to', value);
this[key] = value; // invoke the setter for 'key'
// console.log('Now, the value for', key, 'is:', this[key]);
} catch (error) {
console.warn(`Error using setter for '${key}': ${error.message}`);
}
} else {
console.warn(`Setter for '${key}' does not exist or cannot be used.`);
}
}
}
if (!this.additionalParams?.deviceType) {
console.log("User did not specify deviceType in additionalParams. Attempting to infer deviceType from Model Parameters.");
// If the user did NOT specify deviceType inside the additional options for the model,
// we need to create it internally ourselves as the 'RUNTIME/DEVICE' pair.
// We do this by combining 'RuntimeAgent' and 'DeviceType' strings from getModelParameter()
try {
this._deviceType = this.getModelParameter('RuntimeAgent') + '/' + this.getModelParameter('DeviceType');
// this.deviceType = this.getModelParameter('RuntimeAgent') + '/' + this.getModelParameter('DeviceType'); // Don't use setter, it will create unnecessary modelConfig, dirty flag, etc.
console.log('Set deviceType from Model Parameters:', this.deviceType);
} catch (error) {
throw new DGError('Failed to infer device type from Model Parameters:' + error, "DEVICE_TYPE_INFERENCE_ERROR", { error }, "An error occurred during device type inference.", "Please check the model parameters to ensure they have RuntimeAgent and DeviceType set");
}
} else {
// User DID set the deviceType. The validity was checked inside the setter of deviceType.
console.log('User set deviceType using additionalParams:', this.deviceType);
}
this.finishedSettingAdditionalParams = true;
// Now set dirty to false. As some setters use setModelParameter(), this creates the this.modelConfig which gets passed
// to the websocket upon model initialization. Dirty is set to true when these are modified even if model isn't initialized
// yet, so we set it back to false, so we don't reinitialize the websocket connection for no reason.
this.dirty = false;
/*
imageBackend - maybe in future
inputNumpyColorspace - not applicable
inputResizeMethod - maybe in future, add nicer downscale functiosn
*/
// Control / Information Parameters
/*
// TODO:
frameQueueDepth // int 1.160 - setters and getters need to have some nice implementation
measureTime // bool -
nonBlockingBatchPredict // bool - ?
eagerBatchSize // int 1..80 - ?
*/
// Inference Parameters set / get through setModelParameter() and getModelParameter()
// Setters / getters exist to wrap them in input handling
}
/**
* Constructs a configuration for overlay settings which is passed to the pre/post processors. Will
* compile every internal parameter settable by the user.
*
* @private
* @returns {Object} The overlay configuration.
*/
constructOverlayConfig() {
return {
labels: this.labels,
overlayColor: this._overlayColor,
overlayLineWidth: this._overlayLineWidth,
overlayShowLabels: this._overlayShowLabels,
overlayShowProbabilities: this._overlayShowProbabilities,
overlayAlpha: this._overlayAlpha,
overlayFontScale: this._overlayFontScale,
inputLetterboxFillColor: this._inputLetterboxFillColor,
inputPadMethod: this._inputPadMethod,
saveModelImage: this._saveModelImage,
inputCropPercentage: this._inputCropPercentage
};
}
/**
* Predicts the result for a given image. <br>
*
* @example If callback is provided:
* The WebSocket onmessage will invoke the callback directly when the result arrives.
*
* @example If callback is not provided:
* The function waits for the resultQ to get a result, then returns it.
* let result = await model.predict(someImage);
*
* @async
* @param {Blob|File|string|HTMLImageElement|HTMLVideoElement|HTMLCanvasElement|ArrayBuffer|TypedArray|ImageBitmap} imageFile
* @param {string} [info=performance.now()] - Unique frame information provided by user (such as frame num). Used for matching results back to input images within callback.
* @param {boolean} [bypassPreprocessing=false] - Whether to bypass preprocessing. Used to send Blob data directly to the socket without any preprocessing.
* @returns {Promise<Object>} The prediction result.
*/
async predict(imageFile, info = '_DEFAULT_FRAME_INFO_', bypassPreprocessing = false) {
let unlockedAlready = true; // only allow one mutex unlock operation per function call.
// passthrough if the error flag is enabled.
if (this.poison) return;
// Outer try catch finally block for handling unexpected errors with cleanup code (e.g. release mutex)
try {
// Check if the model needs to be reloaded
if (this.dirty) {
this.log('predict(): dirty flag caught. modelParams object prior to reset:', this.modelParams);
await this.handleDirtyFlag();
}
// Generate unique info from the imageFile here if needed
if (info == '_DEFAULT_FRAME_INFO_') {
// TODO: performance.now() isn't fast enough! Instead we can just implement a frame counter
info = `frame_${performance.now()}`;
}
// Instantly push the frame info to infoQ
await this.infoQ.push({ info });
// Attempt to lock the mutex with a timeout
const mutexLockPromise = this.mutex.lock();
const mutexTimeout_ms = 10000;
const mutexTimeoutPromise = this.timeoutPromise(mutexTimeout_ms, () => this.mutex.cancelLock(mutexLockPromise));
// Wait for either the mutex to be acquired, or for the timeout to occur
await Promise.race([
mutexLockPromise.then(() => { mutexTimeoutPromise.cancel(); }), // Lock acquired
mutexTimeoutPromise // Timeout
]).catch(error => {
throw new DGError("Error during mutex lock / mutex timeout unlock", "MUTEX_LOCK_TIMEOUT_ERROR", {}, "An error occurred while acquiring the mutex lock.", error);
});
unlockedAlready = false; // Set to false, so that now we are allowed to call unlock
if (bypassPreprocessing) {
// Ensure that the imageFile is a Blob
if (!(imageFile instanceof Blob)) {
throw new DGError("predict(): Bypassed image must be a Blob.", "INVALID_BYPASSED_IMAGE", {}, "An error occurred during image preprocessing.", "Please check the image and try again.");
}
// Attach fake transformationDetails to infoQ.
let fakeTransformationDetails = {
scaleX: 1.0,
scaleY: 1.0,
offsetX: 0,
offsetY: 0
};
await this.infoQ.update(
item => item.info === info,
{ transformationDetails: fakeTransformationDetails, imageFrame: null }
);
// Directly send the image to the socket.
await this.waitForSocketConnection(); // Timeout error here will be caught in outer trycatch
this.log('predict(): Sending bypassed image to socket with info:', info);
// unpack blob into float32array
let blobData = await imageFile.arrayBuffer();
let floatArray2 = new Float32Array(blobData);
this.log('predict() sending: ', floatArray2);
this.socket.send(imageFile);
} else {
// Validate / send the image frame
await this.validateAndSendFrame(imageFile, info);
}
this.mutex.unlock();
unlockedAlready = true;
if (this.callback == null) {
return await this.resultQ.pop(); // does not instantly complete the try block. waits for the Promise to resolve before returning the resolved value
}
} catch (error) {
this.poison = true; // Set the error flag.
throw new DGError("An error occurred during predict:" + error, "PREDICT_ERROR", { error }, "An error occurred during predict.");
} finally {
// We enter here if there was an exception, OR after the await expression has
// been resolved and the function has returned.
// So, the mutex should be unlocked prior to the return call, but if we
// get here and it's still locked, it means there was an exception and we
// must unlock it.
if (!unlockedAlready) this.mutex.unlock();
}
}
/**
* Validates/Converts input frame, then sends it to the server.
* @async
* @private
* @param {Blob|File|string|HTMLImageElement|HTMLVideoElement|HTMLCanvasElement|ArrayBuffer} imageFile - The input image.
* @param {string} info - Frame information for matching results.
*/
async validateAndSendFrame(imageFile, info) {
// Input frame validation / conversion
let imageFrame = await this.validateAndConvertInputFrame(imageFile);
if (this.configParamsDirty || !this.preProcessor || !this.postProcessor) {
await this.initPrePostProcessors();
}
await this.preprocessAndSend(imageFrame, info);
}
/**
* Predicts results for a batch of data. Will yield results if a callback is not provided.
* @async
* @generator
*
* @example The function asynchronously processes results. If a callback is not provided, it will yield results.
* for await (let result of model.predict_batch(data_source)) { console.log(result); }
* @param {AsyncIterable} data_source - An async iterable data source.
* @param {boolean} [bypassPreprocessing=false] - Whether to bypass preprocessing.
* @yields {Object} The prediction result.
*/
async *predict_batch(data_source, bypassPreprocessing = false) {
// passthrough if the error flag is enabled.
if (this.poison) return;
try {
// Check if the model needs to be reloaded upon first predict
if (this.dirty) {
this.log('predict_batch(): dirty flag caught. modelParams object prior to reset:', this.modelParams);
await this.handleDirtyFlag();
}
// Iterates over asynchronous generator (data source) and waits for each item to be yielded
// Each iteration of the loop waits for the completion of its asynchronous tasks before moving
// on to the next one, which implicitly serializes the sequence of operations for each frame
for await (let [data, info] of data_source) {
// Instantly push the frame info to infoQ
await this.infoQ.push({ info });
if (bypassPreprocessing) {
if (!(data instanceof Blob)) {
throw new DGError("predict_batch(): Bypassed image must be a Blob.", "INVALID_BYPASSED_IMAGE", {}, "An error occurred during image preprocessing.", "Please check the image and try again.");
}
let fakeTransformationDetails = {
scaleX: 1.0,
scaleY: 1.0,
offsetX: 0,
offsetY: 0
};
await this.infoQ.update(
item => item.info === info,
{ transformationDetails: fakeTransformationDetails, imageFrame: null }
);
await this.waitForSocketConnection();
this.socket.send(data);
} else {
// Validate / send the image frame
await this.validateAndSendFrame(data, info);
}
if (this.callback == null) {
// asynchronously process results
if (!this.resultQ.empty()) {
yield await this.resultQ.pop();
}
}
}
// there might still be pending results that have not been processed
if (this.callback == null) {
while (!this.infoQ.empty() || !this.resultQ.empty()) {
yield await this.resultQ.pop();
}
}
} catch (error) {
this.poison = true; // Set the error flag.
throw new DGError('An error occurred during predict_batch:' + error, "PREDICT_BATCH_ERROR", { error }, "An error occurred during predict_batch.");
}
}
/**
* Preprocesses the input image prior to sending to websocket. For now, only real preprocessing is resizing the image to model specification. If it's a video, we don't attach it to the infoQ.
* @async
* @private
* @param {Blob} imageFrame - The input frame.
* @param {string} [info=null] - Frame information for matching results.
*/
async preprocessAndSend(imageFrame, info = null) {
try {
// Validate input
if (!imageFrame || !this.modelInputW || !this.modelInputH) {
throw new DGError("preprocessAndSend(): missing input parameters.", "MISSING_INPUT_PARAMETERS_ERROR", {}, "An error occurred during image preprocessing.", "Please check the input and try again.");
}
const startTime = performance.now(); // PERFORMANCE LOGGING
const { blob: resizedBlob, transformationDetails } = await this.preProcessor.resizeImage(imageFrame);
this.log('resizeImage() took:', performance.now() - startTime, 'ms.'); // PERFORMANCE LOGGING
// Check for scale factors
if (!transformationDetails.scaleX || !transformationDetails.scaleY) {
throw new DGError("Scale factors are missing after resizeImage()", "MISSING_SCALE_FACTORS_ERROR", {}, "An error occurred during image preprocessing.", "Please check the image and try again.");
}
// Check for offset values
if (typeof transformationDetails.offsetX === 'undefined' || typeof transformationDetails.offsetY === 'undefined') {
throw new DGError("Offset values are missing after resizeImage()", "MISSING_OFFSET_VALUES_ERROR", {}, "An error occurred during image preprocessing.", "Please check the image and try again.");
}
// Check for resizedBlob
if (!resizedBlob) {
throw new DGError("preprocessAndSend(): resizedBlob is null or undefined", "RESIZED_BLOB_NULL_ERROR", {}, "An error occurred during image preprocessing.", "Please check the image and try again.");
}
// update the infoQ with the original input frame and transformation details
// First we need to handle the case where the input is a video element
if (imageFrame instanceof HTMLVideoElement) {
// We don't need to attach the video element to the infoQ.
// Instead we just make it null.
imageFrame = null;
}
if (this._saveModelImage) {
// Update the existing info in the queue, along with the modelImage (resized image blob)
await this.infoQ.update(
item => item.info === info,
{ transformationDetails, imageFrame: imageFrame, modelImage: resizedBlob }
);
} else {
// Update the existing info in the queue
await this.infoQ.update(
item => item.info === info,
{ transformationDetails, imageFrame: imageFrame }
);
}
await this.waitForSocketConnection(); // Timeout error here will be caught in outer trycatch
// Directly send Blob data
this.log('preprocessAndSend(): Sending blob to socket with info:', info);
this.socket.send(resizedBlob);
} catch (error) {
this.poison = true; // Set the error flag.
throw new DGError("preprocessAndSend(): Failed to preprocess the image:" + error, "PREPROCESS_IMAGE_FAILED", { error }, "Failed to preprocess the image.");
// Now, all subsequent predict / predict_batch calls will be passthrough
}
}
/**
* Validates and converts the user's input frame to a compatible format for resizeImage().
* @async
* @private
* @param {Blob|File|string|HTMLImageElement|HTMLVideoElement|HTMLCanvasElement|ArrayBuffer|SVGImageElement|ImageBitmap|OffscreenCanvas|ImageData|Array} image - The input image.
* @returns {Promise<ImageBitmap|HTMLImageElement|HTMLVideoElement|HTMLCanvasElement>} The validated and converted input frame.
*/
async validateAndConvertInputFrame(image) {
if (!image) {
throw new DGError('validateAndConvertInputFrame(): Image must be provided.', "INVALID_IMAGE_INPUT", {}, "Image must be provided.");
}
// Directly passthrough for these image types, as they can be directly used with our resizeImage() implementation:
// HTMLImageElement
// SVGImageElement
// HTMLVideoElement
// HTMLCanvasElement
// ImageBitmap
// OffscreenCanvas
// TODO: Create tests for all of the above types ^^ (only ImageBitmap is tested......!)
if (image instanceof HTMLImageElement || image instanceof SVGImageElement || image instanceof HTMLVideoElement || image instanceof HTMLCanvasElement || image instanceof ImageBitmap || image instanceof OffscreenCanvas) {
return image;
}
// For Blob, ImageData, and File types, we use createImageBitmap()
if (image instanceof Blob || image instanceof ImageData) {
// Blob and ImageData are valid input types, but we need to convert them to an ImageBitmap
const imageBitmap = await createImageBitmap(image);
return imageBitmap;
}
if (image instanceof File) {
// Check if the file is an image
if (!image.type.startsWith('image/')) {
throw new DGError('validateAndConvertInputFrame(): input image is a File but is not an image.', "INVALID_IMAGE_INPUT", {}, "File is not an image.");
}
// Convert the File to an ImageBitmap
const imageBitmap = await createImageBitmap(image);
return imageBitmap;
}
// Handle Data URLs, image URLs, base64 strings, ArrayBuffers, and typed arrays
if (typeof image === 'string') {
// Data URL
if (image.startsWith('data:')) {
return await this.convertDataUrlToImageBitmap(image);
}
try {
// TODO: Need some more robust way to validate URLs...
new URL(image); // This will throw an error if `image` is not a valid URL
// if (image.startsWith('http'))
} catch (error) {
// If here, the string is neither a Data URL nor a valid URL, so it should be a base64 string
return await this.convertBase64ToImageBitmap(image);
}
// Fetching image from URL
return await this.convertImageURLToImageBitmap(image);
}
if (image instanceof ArrayBuffer) {
return await this.convertArrayBufferToImageBitmap(image);
}
if (image instanceof Uint8Array || image instanceof Uint16Array || image instanceof Float32Array) {
return await this.convertTypedArrayToImageBitmap(image);
}
throw new DGError('Invalid image input type, it is: ' + typeof (image), "INVALID_IMAGE_INPUT", {}, "Invalid image input type.");
}
/**
* Converts a data URL to an ImageBitmap.
* @private
* @async
* @param {string} dataUrl - The data URL to convert.
* @returns {Promise<ImageBitmap>} The converted ImageBitmap.
*/
async convertDataUrlToImageBitmap(dataUrl) {
if (!dataUrl.startsWith('data:')) {
throw new DGError('Invalid data URL: ' + dataUrl, "INVALID_DATA_URL", {}, "Invalid data URL.");
}
const response = await fetch(dataUrl);
const blob = await response.blob();
return await createImageBitmap(blob);
}
/**
* Converts an image URL to an ImageBitmap.
* @private
* @async
* @param {string} imageUrl - The image URL to convert.
* @returns {Promise<ImageBitmap>} The converted ImageBitmap.
*/
async convertImageURLToImageBitmap(imageUrl) {
try {
new URL(imageUrl); // Validates the URL
const response = await fetch(imageUrl);
const blob = await response.blob();
return await createImageBitmap(blob);
} catch (error) {
throw new DGError('Invalid image URL: ' + imageUrl + ' : ' + error, "INVALID_IMAGE_URL", { error }, "Invalid image URL.");
}
}
/**
* Converts a base64 string to an ImageBitmap.
* @private
* @async
* @param {string} base64 - The base64 string to convert.
* @returns {Promise<ImageBitmap>} The converted ImageBitmap.
*/
async convertBase64ToImageBitmap(base64) {
try {
const byteString = atob(base64);
const ab = new ArrayBuffer(byteString.length);
const ia = new Uint8Array(ab);
for (let i = 0; i < byteString.length; i++) {
ia[i] = byteString.charCodeAt(i);
}
let blob = new Blob([ab], { type: 'image/jpeg' });
return await createImageBitmap(blob);
} catch (error) {
throw new DGError('Invalid base64 string: ' + base64 + ' : ' + error, "INVALID_BASE64_STRING", { error }, "Invalid base64 string.");
}
}
/**
* Converts an ArrayBuffer to an ImageBitmap.
* @private
* @async
* @param {ArrayBuffer} arrayBuffer - The ArrayBuffer to convert.
* @returns {Promise<ImageBitmap>} The converted ImageBitmap.
*/
async convertArrayBufferToImageBitmap(arrayBuffer) {
if (!(arrayBuffer instanceof ArrayBuffer)) {
throw new DGError('Invalid ArrayBuffer input: ' + arrayBuffer, "INVALID_ARRAY_BUFFER", {}, "Invalid ArrayBuffer input.");
}
const blob = new Blob([arrayBuffer]);
return await createImageBitmap(blob);
}
/**
* Converts a typed array to an ImageBitmap.
* @private
* @async
* @param {string} imageUrl - The image URL to convert.
* @returns {Promise<ImageBitmap>} The converted ImageBitmap.
*/
async convertTypedArrayToImageBitmap(typedArray) {
if (!(typedArray instanceof Uint8Array)) {
throw new DGError('Invalid Uint8Array input: ' + typedArray, "INVALID_TYPED_ARRAY", {}, "Invalid Uint8Array input.");
}
const blob = new Blob([typedArray.buffer]);
return await createImageBitmap(blob);
}
/**
* Reset the socket to handle model parameter change. Waits for outstanding frames to be processed.Handles model parameter changes and resets the socket if necessary.
* @private
* @async
*/
async handleDirtyFlag() {
this.log('handleDirtyFlag(): dirty flag caught. modelParams object prior to reset:', this.modelParams);
if (this.infoQ.empty() && this.resultQ.empty()) {
this.resetSocket();
} else {
// Wait for this.infoQ.onPop listener, check if all empty.
// If all empty, then finally reset the socket
// This promise resolves when both queues are empty
await new Promise(resolve => {
const checkQueuesEmptyAndReset = () => {
if (this.infoQ.empty() && this.resultQ.empty()) {
if (this.infoQ.hasEventListener('onPop')) {
this.infoQ.removeEventListener('onPop', checkQueuesEmptyAndReset);
}
if (this.resultQ.hasEventListener('onPop')) {
this.resultQ.removeEventListener('onPop', checkQueuesEmptyAndReset);
}
}
this.resetSocket();
resolve();
};
// Add the listeners
this.infoQ.addEventListener('onPop', checkQueuesEmptyAndReset);
this.resultQ.addEventListener('onPop', checkQueuesEmptyAndReset);
});
}
// After resetting the socket, no need to wait for it to be back up again
// as this happens prior to sending frame to socket
}
/**
* Initializes onmessage and onerror listeners for the websocket.
* @private
*/
initSocketListeners() {
this.socket.onmessage = (event) => {
// Chaining the promises of message processing, making sure that a new message
// will only start processing after the previous one has finished. The
// this.lastProcessedMessage keeps track of the most recent message's processing
// status, ensuring they are processed sequentially.
// This way, even if onmessage is triggered multiple times rapidly, each message
// will be processed in the order they arrived, ensuring that the results won't get mismatched
this.lastProcessedMessage = this.lastProcessedMessage.then(async () => {
// Convert Blob to ArrayBuffer
const arrayBuffer = await event.data.arrayBuffer();
// Create a Uint8Array from the ArrayBuffer
const uint8Array = new Uint8Array(arrayBuffer);
// Decode MessagePack to JSON
// eslint-disable-next-line no-undef
let data = msgpack.decode(uint8Array);
if (!this.initialized && data.model_params) {
this.log('onmessage: Succesfully loaded model', this.modelName);
// this.log('onmessage: Setting model params to:', data.model_params);
this.modelParams = data.model_params;
this.initMemberValues();
this.initialized = true;
} else if (this.initialized) {
let info, transformationDetails, imageFrame, modelImage, combinedResult;
if (this.poison) return;
// Error check the result. If there is an error:
// - ignore all subsequent results
// - turn on poison flag to make predict calls passthrough
const errorMsg = this.errorCheck(data);
if (errorMsg) {
this.poison = true;
throw new DGError(`Error caught in result object: ${errorMsg}`, "RESULT_ERROR", { error: errorMsg }, "Error caught in result object.");
}
// Prepare the frame info for this result
// Grab modelImage if saveModelImage enabled
if (this._saveModelImage) {
({ info, transformationDetails, imageFrame, modelImage } = await this.infoQ.pop());
} else {
({ info, transformationDetails, imageFrame } = await this.infoQ.pop());
}
// console.log('onmessage received data for info ', info, ':', data);
// Log the result and info
this.log('onmessage: Result received for info:', info);
// Check for transformationDetails
if (!transformationDetails) {
// We need to be very descriptive.
// Log info, imageFrame, data, infoQ status
console.warn('onmessage: Result received but transformationDetails are missing for info:', info);
throw new DGError("Transformation details are missing from the infoQ", "MISSING_TRANSFORMATION_DETAILS_ERROR", {}, "An error occurred while processing the result.", "Please try again later.");
}
// Logic for filtering objects based on labelWhitelist and labelBlacklist
if (this._labelWhitelist || this._labelBlacklist) {
// dummy check: Does this model even have labels?
if (!this.labels) {
console.warn('labelWhitelist/labelBlacklist is set but this model does not have a label dictionary. Ignoring the labelWhitelist/labelBlacklist.');
} else {
// whitelist set and blacklist set
if (this._labelWhitelist && this._labelBlacklist) {
const filteredData = data.filter(item => this._labelWhitelist.includes(item.label) && !this._labelBlacklist.includes(item.label));
data = filteredData;
}
// whitelist set, blacklist not set
else if (this._labelWhitelist) {
const filteredData = data.filter(item => this._labelWhitelist.includes(item.label));
data = filteredData;
}
// blacklist set, whitelist not set
else {
const filteredData = data.filter(item => !this._labelBlacklist.includes(item.label));
data = filteredData;
}
}
}
// Attach scales / offsets to the result object
data.scaleX = transformationDetails.scaleX;
data.scaleY = transformationDetails.scaleY;
data.offsetX = transformationDetails.offsetX;
data.offsetY = transformationDetails.offsetY;
// Form the result array and the combinedResult object
const resultArray = [data, info];
if (this._saveModelImage) {
combinedResult = { result: resultArray, imageFrame, modelImage };
} else {
combinedResult = { result: resultArray, imageFrame };
}
if (this.callback == null) {
this.log('onmessage: Pushing to resultQ with info:', info);
this.resultQ.push(combinedResult);
} else {
this.callback(combinedResult, info);
}
} else {
this.log('onmessage: Message from server received but this.initialized == false!');
}
}).catch(error => {
// Reset lastProcessedMessage to avoid blocking future messages
this.lastProcessedMessage = Promise.resolve();
throw new DGError(`Error in onmessage: ${error}`, "ONMESSAGE_ERROR", { error }, "Error in onmessage.");
});
};
this.socket.onerror = (event) => {
// We temporarily turn these into just console errors
// This is because CROW errors are not emitted properly yet
console.error('AIServerModel: WebSocket error observed:', event);
// throw new DGError("AIServerModel: WebSocket error observed:", "WEBSOCKET_ERROR", { event }, "WebSocket error observed.");
};
}
/**
* Waits for outstanding frames to be processed.
* @private
* @returns {Promise<void>}
*/
awaitOutsandingFrames() {
return new Promise((resolve) => {
const checkInterval = setInterval(() => {
if (!this.infoQ.empty() || !this.resultQ.empty()) {
// Still waiting for the queues to be empty
} else {
clearInterval(checkInterval); // Clear the interval when the condition is met
resolve(); // Resolve the promise when both queues are empty
}
}, 10); // Check every 10ms
});
}
/**
* Initializes the WebSocket connection with the server using modelName and modelConfig packet in order to load the model.
* @private
* @async
*/
async initializeSocket() {
try {
await this.waitFor(() => this.finishedSettingAdditionalParams, 500);
} catch (error) {
console.warn('Setting additional parameters timed out. Some values will be set to default.');
}
// this._deviceType is always set when we get here.
if (!this._deviceType) {
throw new DGError("Device type is not set. Uh-oh!");
}
// Last check for device type / runtime agent compatibility inside this.systemDeviceTypes
if (!this.systemDeviceTypes.includes(this._deviceType)) {
// Device type is not supported by the AIServer (this.systemDeviceTypes)
// However, the model still could have other supported device types (this.supportedDeviceTypes)
throw new DGError(`Device type ${this._deviceType} is not supported by the AIServer. Please use one of the supported device types: ${this.supportedDeviceTypes}.`, "UNSUPPORTED_DEVICE_TYPE", {}, "An error occurred while setting the device type.", "Please check the device type and try again.");
}
// console.log('Finished waiting for additional params to be set. Opening socket with model:', this.modelName, 'and config:', this.modelConfig);
this.socket = new WebSocket(`${this.serverUrl}/v1/stream`);
// Load model by sending name + config packet
this.socket.onopen = () => {
this.log('AIServerModel initializeSocket(): WebSocket connection opened. Loading model:', this.modelName);
// this.log('AIServerModel initializeSocket(): sending modelConfig:', this.modelConfig);
this.socket.send(JSON.stringify({
name: this.modelName,
config: this.modelConfig
}));
};
this.initialized = false;
this.initSocketListeners();
}
/**
* Resets the WebSocket connection.
* @private
*/
resetSocket() {
if (this.socket) {
this.socket.close();
}
this.initializeSocket();
// Reset dirty flag
this.dirty = false;
}
/////////////////// Internal Parameter Setters / Getters ///////////////////
// Internal parameters can be read and written like plain properties, without calling the getter / setter explicitly:
// model.overlayShowLabels = false; // This actually calls the setter method
// console.log(model.overlayShowLabels); // This calls the getter method
// deviceType must either be a string 'RUNTIME/DEVICE' or an array of such strings ['RUNTIME1/DEVICE1', 'RUNTIME2/DEVICE2']
// Then, it's checked against the system's available devices and only one such device is selected (the first one in the array that passed)
set deviceType(value) {
console.log('Entered deviceType setter with value:', value);
if (!value || (typeof value !== 'string' && !Array.isArray(value))) {
throw new TypeError("deviceType should be a string or an array of strings. e.g. 'RUNTIME/DEVICE' or ['RUNTIME1/DEVICE1', 'RUNTIME2/DEVICE2'].");
}
let currentDevice = this.modelParams.DEVICE[0]['RuntimeAgent'] + '/' + this.modelParams.DEVICE[0]['DeviceType'];
if (currentDevice === value) {
console.warn('Device type is already set to:', value);
// Directly update _deviceType. This is so deviceType passed to constructor will be set properly even if it's identical to model params
if (!this._deviceType) this._deviceType = value;
return;
}
console.log('Got here with value:', value);
const checkDeviceType = (deviceType) => {
// console.log('Checking device type:', deviceType);
const agentDevice = deviceType.split('/');
if (agentDevice.length !== 2) {
throw new DGError("deviceType should be in the format 'RUNTIME/DEVICE'.", "INVALID_DEVICE_TYPE", {}, "An error occurred while setting the device type.", "Please check the device type and try again.");
}
if (this.supportedDeviceTypes.includes(deviceType)) {
return agentDevice;
}
return null;
};
const values = Array.isArray(value) ? value : [value];
let agentDevice = null;
for (const deviceType of values) {
agentDevice = checkDeviceType(deviceType);
if (agentDevice !== null) {
break; // take the first suitable device type
}
}
if (agentDevice === null) {
throw new Error(`None of the device types in the list ${values} are supported by the model ${this.modelName}. Supported device types are: ${this.supportedDeviceTypes}.`);
}
// Assign the selected runtime and device to the model parameters
// this.modelParams.RuntimeAgent = agentDevice[0];
// this.modelParams.DeviceType = agentDevice[1];
// this.modelParams.dirty = true; // Mark the model parameters as dirty
this.setModelParameter('RuntimeAgent', agentDevice[0]);
this.setModelParameter('DeviceType', agentDevice[1]);
// Directly update _deviceType to the one that was chosen.
this._deviceType = agentDevice.join('/');
this.log(`Device type set to ${this._deviceType}`);
}
get deviceType() {
return this._deviceType;
}
/**
* Determines if a system device type matches any of the model's supported device types, considering wildcards.
* @private
* @param {Array<string>} modelSupportedTypes - An array of strings representing the device types supported by the model.
* Example: ["OPENVINO/*", "TENSORRT/*", "ONNX/*"]
* @param {Array<string>} systemDeviceTypes - An array of strings representing the device types available on the system.
* Example: ["OPENVINO/CPU", "TENSORRT/GPU", "ONNX/CPU"]
*
* @returns {Array<string>} - An array of strings representing the intersection of modelSupportedTypes and systemDeviceTypes,
* with wildcards considered.
* Example: If modelSupportedTypes is ["OPENVINO/*", "TENSORRT/*"] and systemDeviceTypes is ["OPENVINO/CPU", "TENSORRT/GPU"],
* it returns ["OPENVINO/CPU", "TENSORRT/GPU"].
*/
matchSupportedDevices(modelSupportedTypes, systemDeviceTypes) {
const matchesWildcard = (pattern, type) => {
const [patternRuntime, patternDevice] = pattern.split('/');
const [typeRuntime, typeDevice] = type.split('/');
const runtimeMatches = patternRuntime === '*' || patternRuntime === typeRuntime;
const deviceMatches = patternDevice === '*' || patternDevice === typeDevice;
return runtimeMatches && deviceMatches;
};
return systemDeviceTypes.filter(systemType =>
modelSupportedTypes.some(modelType => matchesWildcard(modelType, systemType))
);
}
get supportedDeviceTypes() {
let modelSupportedTypes;
try {
modelSupportedTypes = this.getModelParameter('SupportedDeviceTypes'); // Returns string such as OPENVINO/*, TENSORRT/*, ONNX/*
modelSupportedTypes = modelSupportedTypes.split(',').map(type => type.trim());
} catch (err) {
// if model does not have SupportedDeviceTypes, use systemDeviceTypes
modelSupportedTypes = this.systemDeviceTypes;
}
return this.matchSupportedDevices(modelSupportedTypes, this.systemDeviceTypes);
}
/**
* Sets the label whitelist. Only labels in the whitelist will be shown in the overlay.
* @type {Array.<string>}
* @private
*/
set labelWhitelist(value) {
if (!Array.isArray(value)) {
throw new TypeError("labelWhitelist should be an array of strings. e.g. ['cat', 'dog'].");
}
for (const label of value) {
if (typeof label !== 'string') {
throw new TypeError("All items in labelWhitelist must be strings. e.g. ['cat', 'dog'].");
}
}
this._labelWhitelist = value;
}
/**
* Gets the label whitelist. Only labels in the whitelist will be shown in the overlay.
* @type {Array.<string>}
* @private
*/
get labelWhitelist() {
return this._labelWhitelist;
}
/**
* Gets the label blacklist. Labels in the blacklist will not be shown in the overlay.
* @type {Array.<string>}
* @private
*/
set labelBlacklist(value) {
if (!Array.isArray(value)) {
throw new TypeError("labelBlacklist should be an array of strings. e.g. ['cat', 'dog'].");
}
for (const label of value) {
if (typeof label !== 'string') {
throw new TypeError("All items in labelBlacklist must be strings. e.g. ['cat', 'dog'].");
}
}
this._labelBlacklist = value;
}
/**
* Sets the label blacklist. Labels in the blacklist will not be shown in the overlay.
* @type {Array.<string>}
* @private
*/
get labelBlacklist() {
return this._labelBlacklist;
}
/////////////////// Display Parameters ///////////////////
/**
* Sets the overlay color. The overlay color is used to draw bounding boxes and labels.
*
* @type {Array.<Array.<number>>}
* @private
*/
set overlayColor(value) {
if (!Array.isArray(value)) {
throw new TypeError("overlayColor should be an array.");
}
// Validate if it's a list of [R, G, B] triplets or just a single triplet
const isValidTriplet = (triplet) => {
return Array.isArray(triplet) &&
triplet.length === 3 &&
triplet.every(color => typeof color === 'number' && color >= 0 && color <= 255);
};
if (!isValidTriplet(value)) {
if (!value.every(isValidTriplet)) {
throw new TypeError("overlayColor should either be a single [R, G, B] triplet or a list of such triplets.");
}
}
this.configParamsDirty = true;
this._overlayColor = value;
}
/**
* Gets the overlay color. The overlay color is used to draw bounding boxes and labels.
* @type {Array.<Array.<number>>}
* @private
*/
get overlayColor() {
return this._overlayColor;
}
/**
* Sets the overlay line width. The overlay line width is used to draw bounding boxes and labels.
* @type {number}
* @private
*/
set overlayLineWidth(value) {
if (typeof value !== 'number' || value <= 0) {
throw new TypeError("overlayLineWidth should be a positive number.");
}
this.configParamsDirty = true;
this._overlayLineWidth = value;
}
/**
* Gets the overlay line width. The overlay line width is used to draw bounding boxes and labels.
* @type {number}
* @private
*/
get overlayLineWidth() {
return this._overlayLineWidth;
}
/**
* Determines whether to show labels in the overlay.
* @type {boolean}
* @private
*/
set overlayShowLabels(value) {
if (typeof value !== 'boolean') {
throw new TypeError("overlayShowLabels should be a boolean value.");
}
this.configParamsDirty = true;
this._overlayShowLabels = value;
}
/**
* Gets whether to show labels in the overlay.
* @type {boolean}
* @private
*/
get overlayShowLabels() {
return this._overlayShowLabels;
}
/**
* Sets whether to show probabilities in the overlay.
* @type {boolean}
* @private
*/
set overlayShowProbabilities(value) {
if (typeof value !== 'boolean') {
throw new TypeError("overlayShowProbabilities should be a boolean value.");
}
this.configParamsDirty = true;
this._overlayShowProbabilities = value;
}
/**
* Determines whether to show probabilities in the overlay.
* @type {boolean}
* @private
*/
get overlayShowProbabilities() {
return this._overlayShowProbabilities;
}
// overlayAlpha
/**
* Sets the transparency percentage of the overlay.
* @type {number}
* @private
*/
set overlayAlpha(value) {
if (typeof value !== 'number' || value < 0 || value > 1) {
throw new TypeError("overlayAlpha should be a number between 0 and 1.");
}
this.configParamsDirty = true;
this._overlayAlpha = value;
}
/**
* Gets the transparency percentage of the overlay.
* @type {number}
* @private
*/
get overlayAlpha() {
return this._overlayAlpha;
}
/**
* Sets the font scale for the overlay.
* @type {number}
* @private
*/
set overlayFontScale(value) {
if (typeof value !== 'number' || value <= 0) {
throw new TypeError("overlayFontScale should be a positive number.");
}
this.configParamsDirty = true;
this._overlayFontScale = value;
}
/**
* Gets the font scale for the overlay.
* @type {number}
* @private
*/
get overlayFontScale() {
return this._overlayFontScale;
}
/////////////////// Input Handling Parameters ///////////////////
/**
* Sets the fill color for letterboxing the input image.
* @type {Array.<number>}
* @private
*/
set inputLetterboxFillColor(value) {
// Validation for single [R, G, B] triplet
if (!Array.isArray(value) ||
value.length !== 3 ||
!value.every(color => typeof color === 'number' && color >= 0 && color <= 255)) {
throw new TypeError("inputLetterboxFillColor should be a single [R, G, B] triplet.");
}
this.configParamsDirty = true;
this._inputLetterboxFillColor = value;
}
/**
* Gets the fill color for letterboxing the input image.
* @type {Array.<number>}
* @private
*/
get inputLetterboxFillColor() {
return this._inputLetterboxFillColor;
}
/**
* Sets the method for padding the input image. Can be one of 'stretch', 'letterbox', 'crop-first', or 'crop-last'.
* @type {string}
* @private
*/
set inputPadMethod(value) {
if (typeof value !== 'string' ||
!["stretch", "letterbox", "crop-first", "crop-last"].includes(value)) {
throw new TypeError("inputPadMethod should be one of 'stretch', 'letterbox', 'crop-first', or 'crop-last'.");
}
this.configParamsDirty = true;
this._inputPadMethod = value;
}
/**
* Gets the method for padding the input image. Can be one of 'stretch', 'letterbox', 'crop-first', or 'crop-last'.
* @type {string}
* @private
*/
get inputPadMethod() {
return this._inputPadMethod;
}
/**
* Sets whether to save the model image in the result object.
* @type {boolean}
* @private
*/
set saveModelImage(value) {
if (typeof value !== 'boolean') {
throw new TypeError("saveModelImage should be a boolean value.");
}
this.configParamsDirty = true;
this._saveModelImage = value;
}
/**
* Gets whether to save the model image in the result object.
* @type {boolean}
* @private
*/
get saveModelImage() {
return this._saveModelImage;
}
/**
* Sets the percentage of the input image to crop. The value should be between 0 and 1.
* @type {number}
* @private
*/
set inputCropPercentage(value) {
if (typeof value !== 'number' || value < 0 || value > 1) {
throw new TypeError("inputCropPercentage should be a number between 0 and 1.");
}
this.configParamsDirty = true;
this._inputCropPercentage = value;
}
/**
* Gets the percentage of the input image to crop. The value should be between 0 and 1.
* @type {number}
* @private
*/
get inputCropPercentage() {
return this._inputCropPercentage;
}
/////////////////// Inference Parameters ///////////////////
// These wrap setModelParameter() / getModelParameter() and add input validation.
/**
* Sets the cloud token. The value should be a string.
* @type {string}
* @private
*/
set cloudToken(value) {
if (typeof value !== 'string') {
throw new TypeError("cloudToken should be a string.");
}
this.setModelParameter('CloudToken', value);
}
/**
* Gets the cloud token. The value is a string.
* @type {string}
* @private
*/
get cloudToken() {
return this.getModelParameter('CloudToken');
}
/**
* Sets the cloud URL. The value should be a string.
* @type {string}
* @private
*/
set cloudURL(value) {
if (typeof value !== 'string') {
throw new TypeError("cloudURL should be a string.");
}
// this.setModelParameter('CloudURL', value);
// Parse the URL and reconstruct it without the path (patch for HttpServer not expecting a path)
try {
const urlObj = new URL(value);
const urlWithoutPath = urlObj.origin; // origin includes protocol and host
this.setModelParameter('CloudURL', urlWithoutPath);
} catch (e) {
throw new DGError("Invalid URL provided.", "INVALID_URL", {}, "Invalid URL provided.");
}
}
/**
* Gets the cloud URL. The value is a string.
* @type {string}
* @private
*/
get cloudURL() {
return this.getModelParameter('CloudURL');
}
/**
* Sets the output confidence threshold. The value should be a number between 0 and 1.
* @type {number}
* @private
*/
set outputConfidenceThreshold(value) {
if (typeof value !== 'number' || value < 0 || value > 1) {
throw new TypeError("outputConfidenceThreshold should be a number between 0 and 1.");
}
this.setModelParameter('OutputConfThreshold', value);
}
/**
* Gets the output confidence threshold. The value is a number between 0 and 1.
* @type {number}
* @private
*/
get outputConfidenceThreshold() {
return this.getModelParameter('OutputConfThreshold');
}
/**
* Sets the maximum number of detections. The value should be an integer.
* @type {number}
* @private
*/
set outputMaxDetections(value) {
if (typeof value !== 'number' || !Number.isInteger(value)) {
throw new TypeError("outputMaxDetections should be an integer.");
}
this.setModelParameter('MaxDetections', value);
}
/**
* Gets the maximum number of detections. The value should be an integer.
* @type {number}
* @private
*/
get outputMaxDetections() {
return this.getModelParameter('MaxDetections');
}
/**
* Sets the maximum number of detections per class. The value should be an integer.
* @type {number}
* @private
*/
set outputMaxDetectionsPerClass(value) {
if (typeof value !== 'number' || !Number.isInteger(value)) {
throw new TypeError("outputMaxDetectionsPerClass should be an integer.");
}
this.setModelParameter('MaxDetectionsPerClass', value);
}
/**
* Sets the maximum number of detections per class. The value should be an integer.
* @type {number}
* @private
*/
get outputMaxDetectionsPerClass() {
return this.getModelParameter('MaxDetectionsPerClass');
}
/**
* Sets the maximum number of classes per detection. The value should be an integer.
* @type {number}
* @private
*/
set outputMaxClassesPerDetection(value) {
if (typeof value !== 'number' || !Number.isInteger(value)) {
throw new TypeError("outputMaxClassesPerDetection should be an integer.");
}
this.setModelParameter('MaxClassesPerDetection', value);
}
/**
* Sets the maximum number of classes per detection. The value should be an integer.
* @type {number}
* @private
*/
get outputMaxClassesPerDetection() {
return this.getModelParameter('MaxClassesPerDetection');
}
/**
* Sets the non-maximum suppression threshold. The value should be a number between 0 and 1.
* @type {number}
* @private
*/
set outputNmsThreshold(value) {
if (typeof value !== 'number' || value < 0 || value > 1) {
throw new TypeError("outputNmsThreshold should be a number between 0 and 1.");
}
this.setModelParameter('OutputNMSThreshold', value);
}
/**
* Gets the non-maximum suppression threshold. The value should be a number between 0 and 1.
* @type {number}
* @private
*/
get outputNmsThreshold() {
return this.getModelParameter('OutputNMSThreshold');
}
/**
* Sets the output pose threshold. The value should be a number between 0 and 1.
* @type {number}
* @private
*/
set outputPoseThreshold(value) {
if (typeof value !== 'number' || value < 0 || value > 1) {
throw new TypeError("outputPoseThreshold should be a number between 0 and 1.");
}
this.setModelParameter('OutputConfThreshold', value); // set OutputConfThreshold with the value (not pose threshold)
}
/**
* Gets the output pose threshold. The value should be a number between 0 and 1.
* @type {number}
* @private
*/
get outputPoseThreshold() {
return this.getModelParameter('OutputConfThreshold');
}
/**
* Sets the output post-process type. The value should be one of the specified valid string values.
* @type {string}
* @private
*/
set outputPostprocessType(value) {
const validValues = ["Classification", "Detection", "DetectionYolo", "PoseDetection", "HandDetection", "FaceDetect", "Segmentation", "BodyPix", "Python", "None"];
if (typeof value !== 'string' || !validValues.includes(value)) {
throw new TypeError("outputPostprocessType should be one of the specified valid string values.");
}
this.setModelParameter('OutputPostprocessType', value);
}
/**
* Gets the output post-process type. The value should be one of the specified valid string values.
* @type {string}
* @private
*/
get outputPostprocessType() {
return this.getModelParameter('OutputPostprocessType');
}
/**
* Sets the output top K value. The value should be an integer.
* @type {number}
* @private
*/
set outputTopK(value) {
if (typeof value !== 'number' || !Number.isInteger(value)) {
throw new TypeError("outputTopK should be an integer.");
}
this.setModelParameter('OutputTopK', value);
}
/**
* Gets the output top K value. The value should be an integer.
* @type {number}
* @private
*/
get outputTopK() {
return this.getModelParameter('OutputTopK');
}
/**
* Sets whether to use regular non-maximum suppression. The value should be a boolean.
* @type {boolean}
* @private
*/
set outputUseRegularNms(value) {
if (typeof value !== 'boolean') {
throw new TypeError("outputUseRegularNms should be a boolean.");
}
this.setModelParameter('UseRegularNMS', value);
}
/**
* Gets whether to use regular non-maximum suppression. The value should be a boolean.
* @type {boolean}
* @private
*/
get outputUseRegularNms() {
return this.getModelParameter('UseRegularNMS');
}
/**
* Waits for a condition to be met.
* @private
* @param {Function} conditionFunction - The function representing the condition to wait for.
* @param {number} [timeout=1000] - The maximum time to wait.
* @param {number} [interval=10] - The interval between checks.
* @returns {Promise<void>}
*/
waitFor(conditionFunction, timeout = 1000, interval = 10) {
const poll = resolve => {
if (conditionFunction()) resolve();
else if (timeout > 0) setTimeout(() => poll(resolve), interval);
else throw new DGError("Timed out waiting.", "WAIT_TIMEOUT", {}, "Timed out waiting.");
};
return new Promise(poll);
}
/**
* Initializes the preprocessor.
* @async
* @private
*/
async initializePreProcessor() {
// First ensure that model instance is fully initialized with params
if (!this.initialized) await this.waitFor(() => this.initialized);
this.log('(re)setting preprocessor...');
this.preProcessor = null;
this.preProcessor = new PreProcess(this.modelParams, this.constructOverlayConfig());
}
/**
* Initializes the postprocessor.
* @async
* @private
*/
async initializePostProcessor() {
// First ensure that model instance is fully initialized with params
if (!this.initialized) await this.waitFor(() => this.initialized);
this.log('(re)setting postprocessor...');
this.postProcessor = null;
this.postProcessor = new PostProcess(this.modelParams, this.constructOverlayConfig());
}
/**
* Initializes both the preprocessor and postprocessor.
* @async
* @private
*/
async initPrePostProcessors() {
await this.initializePreProcessor();
await this.initializePostProcessor();
this.configParamsDirty = false;
}
/**
* Updates a key within modelParams and sets the dirty flag. Ultimately, designed only to modify leaf nodes with primitive values within the JSON used.
* @async
* @private
* @param {string} key - The key of the parameter to set.
* @param {any} value - The value of the parameter to set.
*/
async setModelParameter(key, value) {
this.log('setModelParameter(). Attempting to update:', key, 'to value:', value);
let updated = false;
try {
// Ensure modelParams exists
if (!this.modelParams) {
throw new DGError("Model parameters are not initialized!", "MODEL_PARAMETERS_NOT_INITIALIZED", {}, "Model parameters are not initialized. Please initialize the model parameters before updating.");
}
// Check for top-level key
if (Object.prototype.hasOwnProperty.call(this.modelParams, key)) {
this.log('Top-level key found! Updating key to:', value);
this.modelParams[key] = value;
if (!this.modelConfig) {
this.modelConfig = {};
}
this.modelConfig[key] = value;
this.log('Updated modelConfig for top-level key:', this.modelConfig);
this.dirty = true;
updated = true;
} else if (key === 'CloudToken' || key === 'CloudURL') { // TEMPORARY PATCH - Cloud doesn't return FULL model params like our websocket does, so we manually add cloudURL/token if missing
this.log('Cloud token bypass-patching in new key:', key);
// same as above block, creates the key/value pair...
if (!this.modelConfig) {
this.modelConfig = {};
}
// Add the key/value pair to modelParams and modelConfig
this.modelParams[key] = value;
this.modelConfig[key] = value;
// console.log('Updated modelParams and modelConfig for key:', key, 'with value:', value);
this.dirty = true;
updated = true;
} else {
// Try setting the value for each top-level key
for (const topLevelKey in this.modelParams) {
if (this.modelParams[topLevelKey] && this.modelParams[topLevelKey][0] && Object.prototype.hasOwnProperty.call(this.modelParams[topLevelKey][0], key)) {
this.log('Key found! Updating key to:', value);
// Updating local copy of model params is now done on confirmation message from websocket in initializeSocket()
// Need to update local copy anyway, even if it will be overwritten by next lazy reload upon predict()
// This is so querying the model params after user changes parameter without performing inference
// will yield expected new model params, not old unchanged params object
this.log('setModelParameter(): Updating local modelParams copy, setting', this.modelParams[topLevelKey][0][key], 'to', value);
this.modelParams[topLevelKey][0][key] = value;
if (!this.modelConfig[topLevelKey]) {
this.modelConfig[topLevelKey] = [];
this.modelConfig[topLevelKey].push({});
}
if (this.modelConfig[topLevelKey][0]) {
this.modelConfig[topLevelKey][0][key] = value;
} else {
const newEntry = {};
newEntry[key] = value;
this.modelConfig[topLevelKey].push(newEntry);
}
this.log('Updated modelConfig:', this.modelConfig);
this.dirty = true;
updated = true;
break;
}
}
}
} catch (error) {
throw new DGError(`Failed to set a parameter: ${error}`, "SET_PARAMETER_FAILED", {}, "Failed to set a parameter.");
}
// If not updated, log an error
if (!updated) {
throw new DGError(`Failed to update the parameter. Key "${key}" not found!`, "UPDATE_PARAMETER_FAILED", { key }, `Failed to update the parameter "${key}". Please make sure the key exists.`);
}
}
/**
* Retrieves a model parameter from modelParams JSON associated with this model instance.
* @private
* @param {string} key - The key of the parameter to retrieve.
* @returns {any} The value of the parameter.
*/
getModelParameter(key) {
this.log('Entered getModelParameter(). Querying value for key:', key);
if (!this.modelParams) {
throw new DGError("Model parameters are not initialized!", "MODEL_PARAMETERS_NOT_INITIALIZED", {}, "Model parameters are not initialized. Please initialize the model parameters before querying.");
}
// Check for top-level key
if (Object.prototype.hasOwnProperty.call(this.modelParams, key)) {
this.log('Top-level key found. Value:', this.modelParams[key]);
return this.modelParams[key];
} else {
// Check in nested structures
for (const topLevelKey in this.modelParams) {
if (this.modelParams[topLevelKey] && this.modelParams[topLevelKey][0] && Object.prototype.hasOwnProperty.call(this.modelParams[topLevelKey][0], key)) {
this.log('Key found in nested structure. Value:', this.modelParams[topLevelKey][0][key]);
return this.modelParams[topLevelKey][0][key];
}
}
}
throw new DGError(`Failed to get the parameter. Key "${key}" not found!`, "GET_PARAMETER_FAILED", { key }, `Failed to get the parameter "${key}". Please make sure the key exists.`);
}
/**
* Returns a read-only copy of the model parameters.
* @returns {Object} The model parameters.
*/
modelInfo() {
if (this.socket && this.modelParams) {
return JSON.parse(JSON.stringify(this.modelParams));
} else {
throw new DGError("Model parameters are not yet initialized for this model!", "MODEL_PARAMETERS_NOT_INITIALIZED", {}, "Model parameters are not yet initialized for this model!");
}
}
/**
* Returns the label dictionary for this AIServerModel instance.
* @returns {Object} The label dictionary.
*/
labelDictionary() {
return this.labels;
}
/**
* Overlay the result onto the image frame and display it on the canvas.
* @async
* @param {Object} combinedResult - The result object combined with the original image frame. This is directly received from `predict` or `predict_batch`
* @param {string|HTMLCanvasElement} outputCanvasName - The canvas to draw the image onto. Either the canvas element or the ID of the canvas element.
* @param {boolean} [justResults=false] - Whether to show only the result overlay without the image frame.
*/
async displayResultToCanvas(combinedResult, outputCanvasName, justResults = false) {
this.log('Entered displayResultToCanvas()');
// Handle incorrect / empty result object
if (!combinedResult || !combinedResult.result) {
throw new DGError('displayResultToCanvas(): Invalid or empty result object, returning', "INVALID_RESULT_OBJECT", {}, "Invalid or empty result object. Please make sure the result object is valid.");
}
// If !combinedResult.imageFrame then it means the input was a video element
// allow it, just set justResults to true
if (!combinedResult.imageFrame) {
justResults = true;
this.log('displayResultToCanvas(): No imageFrame found in combinedResult most likely due to video element inference. Setting justResults to true.');
}
const { result, imageFrame } = combinedResult; // Destructure to extract result and imageFrame
let canvas;
// Input validation for outputCanvasName
if (!outputCanvasName || typeof outputCanvasName !== 'string' || outputCanvasName.trim() === '') {
// also accept HTMLCanvasElement
if (!(outputCanvasName instanceof HTMLCanvasElement)) {
throw new DGError('Invalid outputCanvasName parameter', "INVALID_OUTPUT_CANVAS_NAME", {}, "Invalid outputCanvasName parameter. Please provide a valid outputCanvasName.");
} else {
canvas = outputCanvasName;
}
}
if (!canvas) {
canvas = document.getElementById(outputCanvasName);
}
try {
// Check result for errors
const errorMsg = this.errorCheck(result);
if (errorMsg) {
throw new DGError(`Error in result: ${errorMsg}`, "RESULT_ERROR", { errorMsg }, "Error in result. Please check the result for errors.");
}
// letterbox details attached to result already in onmessage
this.postProcessor.displayResultToCanvas(imageFrame, result, canvas, justResults);
} catch (error) {
throw new DGError("Error in parsing result: ", "PARSE_RESULT_ERROR", {}, "Error in parsing result.");
}
}
/**
* Processes the original image and draws the results on it, return png image with overlayed results.
* @async
* @param {Object} combinedResult - The result object combined with the original image frame.
* @returns {Promise<Blob>} The processed image file as a Blob of a PNG image.
*/
async processImageFile(combinedResult) {
this.log('Entered processImageFile()');
const { result, imageFrame } = combinedResult;
try {
// Check result for errors
const errorMsg = this.errorCheck(result);
if (errorMsg) {
throw new DGError(`Error in result: ${errorMsg}`, "RESULT_ERROR", { errorMsg }, "Error in result. Please check the result for errors.");
}
if (this.configParamsDirty) {
await this.initPrePostProcessors();
}
// letterbox details attached to result already in onmessage
return this.postProcessor.processImageFile(imageFrame, result);
} catch (error) {
throw new DGError("Error in processImageFile: ", "PROCESS_IMAGE_FILE_ERROR", {}, "Error in processImageFile.");
}
}
/**
* Helper function to display image in desired canvas while keeping the aspect ratio.
* @private
* @async
* @param {Blob|File|string|HTMLImageElement|HTMLVideoElement|HTMLCanvasElement|ArrayBuffer} imageFile - The input image.
* @param {string} outputCanvasName - The name of the canvas to display the image on.
*/
async showImg(imageFile, outputCanvasName) {
this.log('Entered showImg()');
// Input frame validation / conversion
let imageFrame = await this.validateAndConvertInputFrame(imageFile);
// Input validation for outputCanvasName
if (!outputCanvasName || typeof outputCanvasName !== 'string' || outputCanvasName.trim() === '') {
throw new DGError('Invalid outputCanvasName parameter', "INVALID_OUTPUT_CANVAS_NAME", {}, "Invalid outputCanvasName parameter. Please provide a valid outputCanvasName.");
}
try {
if (this.configParamsDirty) {
await this.initPrePostProcessors();
}
return this.postProcessor.showImg(imageFrame, outputCanvasName);
} catch (error) {
throw new DGError("Error in showImg: ", "SHOW_IMAGE_ERROR", {}, "Error in showImg.");
}
}
/**
* Checks the server response for errors, similar to c++ errorCheck function.
* @private
* @param {Object} response - The server response.
* @returns {string} The error message, if any.
*/
errorCheck(response) {
// console.log('Entered errorCheck with result:', JSON.stringify(response));
// let startTime = performance.now();
if (!response)
return "response JSON is null!";
// Check for the success flag
if (Object.prototype.hasOwnProperty.call(response, 'success')) {
if (!response.success) {
let msg = Object.prototype.hasOwnProperty.call(response, 'msg') ? response.msg : "unspecified error";
throw new DGError(msg, "RESPONSE_ERROR", { msg }, "Error in server response. Please check the server response for errors.");
}
}
// also add check for the string '[ERROR]' inside the first 25 characters of the stringified response
if (JSON.stringify(response).substring(0, 25).includes('[ERROR]')) {
// We have to parse the response to get the error message as well
return new DGError("Error in response: " + response, "RESPONSE_ERROR", { response }, "Error in response.");
}
// console.log('errorCheck took', performance.now() - startTime, 'ms.');
return ""; // no error
}
/**
* Waits for the WebSocket connection to be established using promises.
* @async
* @private
* @returns {Promise<void>}
*/
waitForSocketConnection() {
return new Promise((resolve, reject) => {
const maxWait = this.MAX_SOCKET_WAIT_MS;
const intervalTime = 50;
let elapsedWait = 0;
const checkConnection = () => {
if (this.socket && this.socket.readyState === WebSocket.OPEN) {
resolve();
} else if (elapsedWait >= maxWait) {
reject(new DGError('Cannot establish WebSocket connection.', "WEBSOCKET_CONNECTION_FAILED", {}, "Failed to establish WebSocket connection."));
} else {
elapsedWait += intervalTime;
setTimeout(checkConnection, intervalTime);
}
};
checkConnection();
});
}
/**
* Creates a promise with a timeout and cancellation handling. Used for mutex lock timeouts.
* @private
* @param {number} duration - The duration of the timeout.
* @param {Function} onCancel - The function to call on cancellation.
* @returns {Promise<void>} The timeout promise.
*/
timeoutPromise(duration, onCancel) {
let timeoutId;
const promise = new Promise((resolve, reject) => {
timeoutId = setTimeout(() => {
reject(new DGError('Mutex lock timeout exceeded', "MUTEX_LOCK_TIMEOUT", {}, "Mutex lock timeout exceeded."));
onCancel();
}, duration);
});
// Attach the cancel method
promise.cancel = () => {
clearTimeout(timeoutId);
};
return promise;
}
/**
* Cleans up resources and closes the WebSocket connection.
* Does so by following a destructor-like pattern which is manually called by the user.
* Makes sure to close the WebSocket connection, stop all inferences, remove the listeners, clear async queues, and nullify all references. <br>
* Call this whenever switching models or when the model instance is no longer needed.
* @async
*/
async cleanup() {
// Set poison flag to stop further inferences
this.poison = true;
// Remove WebSocket event listeners
if (this.socket) {
this.socket.onmessage = null;
this.socket.onerror = null;
this.socket.onopen = null;
this.socket.onclose = null;
}
// Close WebSocket connection
if (this.socket && this.socket.readyState === WebSocket.OPEN) {
this.socket.close();
}
// Clear Async Queues
await this.infoQ.clear();
await this.resultQ.clear();
// Nullify references
this.preProcessor = null;
this.postProcessor = null;
this.mutex = null;
this.infoQ = null;
this.resultQ = null;
// Check and resolve/reject outstanding promises
// if (this.lastProcessedMessage && this.lastProcessedMessage instanceof Promise) {
// this.lastProcessedMessage.then(() => {}, () => {});
// }
// Reset internal states and flags
this.initialized = false;
}
}
// Export the class for use in other files
export default AIServerModel;