import AIServerModel from './aiservermodel.js';
import DGError from './dgerror.js';
// Maximum time (ms) loadModel() waits for a newly constructed AIServerModel to report `initialized`
// before giving up with an INITIALIZATION_TIMEOUT error.
const MODEL_INIT_TIMEOUT_MS = 5000;
/**
 * @class
 * @classdesc This class is responsible for loading models for inference on an AIServer.<br>
 * Use loadModel() to instantiate and configure a model for use.<br>
 * Use various utility functions to get AIServer information using REST API and manage models.<br>
 *
 * @example <caption>Usage:</caption>
 * let zoo = dg.connect('ws://localhost:8779');
 * let model = await zoo.loadModel('someModel', { inputCropPercentage: 0.5, saveModelImage: true, outputMaxDetectionsPerClass: 7 } );
 */
class AIServerZoo {
    /**
     * Note: do not call this constructor, instead use the `connect` function of the dg_sdk class to create an AIServerZoo instance.
     * @constructs AIServerZoo
     * @param {string} serverUrl - The WebSocket URL of the AI server, e.g. 'ws://host:port'
     *   ('wss://host:port' is also accepted and maps to an https:// REST base).
     */
    constructor(serverUrl) {
        this.serverUrl = serverUrl;
        // REST endpoints live on the same host/port as the WebSocket endpoint:
        // ws:// -> http://, wss:// -> https://. Anchored so only the scheme prefix is rewritten.
        this.httpUrl = serverUrl.replace(/^ws(s?):\/\//, 'http$1://');
    }

    /**
     * Loads a model from the AIServer with specific options. The model instance can be created with a
     * callback function, which will be called for every predict result instead of returning a promise. The
     * max_q_len parameter can be set to be able to control the maximum length of the internal inference queue.
     *
     * Model lookup: an exact name match is preferred; otherwise the first available model whose name
     * contains `modelName` as a substring is used.
     * @example
     * let model = await zoo.loadModel('someModel', { inputCropPercentage: 0.5, callback: someFunction } );
     * @param {string} modelName - The name (or a substring of the name) of the model to load.
     * @param {Object} [options={}] - Additional options to pass to the AIServerModel class. These options will be set on the model instance.
     * @param {Function} [options.callback=null] - The callback function to call for each predict result.
     * @param {number} [options.max_q_len=10] - The maximum length of internal inference queue for the AIServerModel class.
     * @returns {Promise<AIServerModel>} The loaded model instance.
     * @throws {DGError} INVALID_INPUT / INVALID_CALLBACK for bad options, NETWORK_ERROR if the server does not
     *   respond, NO_MODELS_FOUND / MODEL_NOT_FOUND if no usable model matches, INITIALIZATION_TIMEOUT if the
     *   model does not initialize within MODEL_INIT_TIMEOUT_MS.
     */
    async loadModel(modelName, options = {}) {
        const {
            max_q_len = 10,
            callback = null,
            ...additionalParams
        } = options;

        if (!Number.isInteger(max_q_len) || max_q_len <= 0) {
            throw new DGError("Invalid value for max_q_len: It should be a positive integer.", "INVALID_INPUT", { parameter: "max_q_len", value: max_q_len }, "Please ensure max_q_len is a positive integer.");
        }
        if (callback !== null && typeof callback !== "function") {
            throw new DGError("Invalid value for callback: It should be a function.", "INVALID_CALLBACK", {}, "Please provide a valid callback function.");
        }

        // Ping the server first so an unreachable AIServer produces a clear NETWORK_ERROR.
        try {
            await this.systemInfo();
        } catch (error) {
            throw new DGError('No response from AIServer at ' + this.serverUrl + '. Please check the server URL and ensure that the AIServer is running.', "NETWORK_ERROR", {}, "AIServer is not responding.");
        }

        const modelsList = await this.listModels();
        // listModels() returns {} (not null) when nothing is available, so test for emptiness explicitly.
        const availableNames = modelsList ? Object.keys(modelsList) : [];
        if (availableNames.length === 0) {
            throw new DGError('No models were found at the specified URL.', "NO_MODELS_FOUND", {}, "No models found at the specified URL.");
        }

        // Prefer an exact match so that e.g. 'foo' is not shadowed by an earlier-listed 'foo_bar'.
        const matchingModelName = availableNames.includes(modelName)
            ? modelName
            : availableNames.find((name) => name.includes(modelName));
        if (matchingModelName === undefined) {
            throw new DGError(`Model matching substring ${modelName} does not exist in the model list, list of models: ${availableNames}`, "MODEL_NOT_FOUND", { modelName }, "The specified model was not found in the model list.");
        }
        const modelParams = modelsList[matchingModelName];

        console.log(`Loading model: ${matchingModelName}, queue length: ${max_q_len}, callback specified:`, callback !== null);

        // Labels and the system's 'RUNTIME/DEVICE' list are independent requests; fetch them concurrently.
        const [labels, systemDeviceTypes] = await Promise.all([
            this.getModelLabels(matchingModelName),
            this.systemSupportedDeviceTypes(),
        ]);

        const modelCreationParams = {
            modelName: matchingModelName,
            serverUrl: this.serverUrl,
            modelParams,
            max_q_len,
            callback,
            labels,
            systemDeviceTypes,
        };
        const model = new AIServerModel(modelCreationParams, additionalParams);

        // The model finishes initializing asynchronously; poll its `initialized` flag.
        try {
            await waitUntil(() => model.initialized, MODEL_INIT_TIMEOUT_MS);
        } catch (error) {
            throw new DGError('Timeout occurred while waiting for the model to initialize.', "INITIALIZATION_TIMEOUT", { modelName: matchingModelName }, "The model could not be initialized in time.");
        }
        return model;
    }

    /**
     * Lists all available models on the AI server as a collection of objects (model names) whose values are the model parameters for that model.
     * Only models runnable on this server are returned: a model is kept when its DEVICE[0].SupportedDeviceTypes
     * patterns (comma-separated, '*' wildcards allowed) match at least one system device type, or, when that
     * field is absent, when its exact 'RuntimeAgent/DeviceType' pair is a system device type. Models whose
     * DEVICE entry cannot be read are skipped.
     * @example
     * let models = await zoo.listModels();
     * let modelNames = Object.keys(models);
     * @returns {Promise<Object>} A promise that resolves to an object containing the model names as keys and their parameters as values.
     * @throws {DGError} FETCH_MODELS_FAILED on an HTTP or server-reported error, SYSTEM_INFO_FAILED if the
     *   system device list cannot be obtained, NETWORK_ERROR for any other failure (e.g. unreachable server).
     */
    async listModels() {
        const url = `${this.httpUrl}/v1/modelzoo`;
        try {
            const systemDeviceTypes = await this.systemSupportedDeviceTypes();
            const systemDeviceKeys = new Set(systemDeviceTypes);
            const response = await fetch(url); // HTTP GET
            if (!response.ok) {
                throw new DGError(`Failed to fetch models list. HTTP status: ${response.status}`, "FETCH_MODELS_FAILED", { status: response.status }, "Failed to fetch the list of available models.");
            }
            const modelsInfo = await response.json();
            if (!modelsInfo) {
                return {};
            }
            if (modelsInfo.error) {
                throw new DGError(modelsInfo.error, "FETCH_MODELS_FAILED", {}, "Error occurred while fetching model list.");
            }

            const filteredModels = {};
            for (const [modelName, modelDetails] of Object.entries(modelsInfo)) {
                let isSupported;
                try {
                    const device = modelDetails.DEVICE[0];
                    if (device.SupportedDeviceTypes) {
                        // Comma-separated 'RUNTIME/DEVICE' patterns, possibly using '*' wildcards.
                        const modelSupportedTypes = device.SupportedDeviceTypes.split(',').map((type) => type.trim());
                        isSupported = this.matchSupportedDevices(modelSupportedTypes, systemDeviceTypes).length > 0;
                        if (!isSupported) {
                            console.log('Filtering out model', modelName, 'because none of the SupportedDeviceTypes match the system devices.');
                        }
                    } else {
                        // Fallback: the model's single default device must be present verbatim.
                        const capabilityKey = `${device.RuntimeAgent}/${device.DeviceType}`;
                        isSupported = systemDeviceKeys.has(capabilityKey);
                        if (!isSupported) {
                            console.log('Filtering out model', modelName, 'because it requires', capabilityKey, 'which is not available on the system.');
                        }
                    }
                } catch (error) {
                    // Malformed DEVICE section: skip this model rather than failing the whole listing.
                    console.error('Error processing SupportedDeviceTypes for model', modelName, error);
                    continue;
                }
                if (isSupported) {
                    filteredModels[modelName] = modelDetails;
                }
            }
            return filteredModels;
        } catch (error) {
            console.error(error);
            // Errors raised above already carry an accurate code/message; only wrap unexpected failures.
            if (error instanceof DGError) {
                throw error;
            }
            throw new DGError(`No AIServer was found at ${this.httpUrl}. Please check the URL.`, "NETWORK_ERROR", { url: this.httpUrl }, "Please check your network connection and the URL.");
        }
    }

    /**
     * Determines if a system device type matches any of the model's supported device types, considering wildcards.
     * @private
     * @param {Array<string>} modelSupportedTypes - An array of strings representing the device types supported by the model.
     * Example: ["OPENVINO/*", "TENSORRT/*", "ONNX/*"]
     * @param {Array<string>} systemDeviceTypes - An array of strings representing the device types available on the system.
     * Example: ["OPENVINO/CPU", "TENSORRT/GPU", "ONNX/CPU"]
     *
     * @returns {Array<string>} - An array of strings representing the intersection of modelSupportedTypes and systemDeviceTypes,
     * with wildcards considered.
     * Example: If modelSupportedTypes is ["OPENVINO/*", "TENSORRT/*"] and systemDeviceTypes is ["OPENVINO/CPU", "TENSORRT/GPU"],
     * it returns ["OPENVINO/CPU", "TENSORRT/GPU"].
     */
    matchSupportedDevices(modelSupportedTypes, systemDeviceTypes) {
        // Each half of a 'RUNTIME/DEVICE' pattern is either '*' (any) or must match exactly.
        const matchesWildcard = (pattern, type) => {
            const [patternRuntime, patternDevice] = pattern.split('/');
            const [typeRuntime, typeDevice] = type.split('/');
            const runtimeMatches = patternRuntime === '*' || patternRuntime === typeRuntime;
            const deviceMatches = patternDevice === '*' || patternDevice === typeDevice;
            return runtimeMatches && deviceMatches;
        };
        return systemDeviceTypes.filter((systemType) =>
            modelSupportedTypes.some((modelType) => matchesWildcard(modelType, systemType))
        );
    }

    /**
     * Gets the default model parameters for a specified model (exact name match only).
     * @param {string} modelName - The name of the model to retrieve information for.
     * @returns {Promise<Object|null>} A promise that resolves to the default model parameters if found, or null if not found.
     */
    async getModelInfo(modelName) {
        const modelsList = await this.listModels();
        // listModels() resolves to a plain object keyed by model name, which is not iterable; look the key up
        // directly, using an own-property check so inherited names such as 'toString' are not returned.
        return Object.prototype.hasOwnProperty.call(modelsList, modelName) ? modelsList[modelName] : null;
    }

    /**
     * Fetches the system information from the AI server.
     * @returns {Promise<Object>} Contains devices object and software version string.
     * @throws {DGError} SYSTEM_INFO_FAILED on a non-2xx HTTP status, NETWORK_ERROR for any other failure.
     */
    async systemInfo() {
        try {
            const response = await fetch(`${this.httpUrl}/v1/system_info`);
            if (!response.ok) {
                throw new DGError(`HTTP error! status: ${response.status}: ${response.statusText}`, "SYSTEM_INFO_FAILED", { status: response.status }, "HTTP error occurred.");
            }
            // Awaited so that a JSON parse failure is handled by the catch below.
            return await response.json();
        } catch (error) {
            if (error instanceof DGError) {
                throw error;
            }
            throw new DGError(`No AIServer was found at ${this.httpUrl}. Please check the URL.`, "NETWORK_ERROR", { url: this.httpUrl }, "Please check your network connection and the URL.", "Verify the AI Server IP / port and ensure that it is running.");
        }
    }

    /**
     * Fetches the system supported device types from the AI server (the keys of system_info's `Devices`).
     * Used internally by loadModel() and listModels().
     * @private
     * @returns {Promise<Array<string>>} A promise that resolves to an array of strings in "RUNTIME/DEVICE" format.
     * @throws {DGError} SYSTEM_INFO_FAILED if system info cannot be fetched or contains no `Devices`.
     */
    async systemSupportedDeviceTypes() {
        let systemInfo;
        try {
            systemInfo = await this.systemInfo();
        } catch (error) {
            throw new DGError(`systemSupportedDeviceTypes failed: error fetching system information for AIServer at ${this.httpUrl}: ${error}`, "SYSTEM_INFO_FAILED", {}, "Error occurred while fetching system info.");
        }
        if (!systemInfo || !systemInfo.Devices) {
            throw new DGError('systemSupportedDeviceTypes failed: no devices found in system info.', "SYSTEM_INFO_FAILED", {}, "Error occurred while fetching system info.");
        }
        return Object.keys(systemInfo.Devices);
    }

    /**
     * Fetches the labels of a specific model by its name.
     * @param {string} name - The name of the model to retrieve labels for.
     * @returns {Promise<Object>} A promise that resolves to the model's label dictionary.
     * @throws {DGError} FETCH_MODEL_DETAILS_FAILED (with `status` in details for HTTP errors).
     */
    async getModelLabels(name) {
        const url = `${this.httpUrl}/v1/label_dictionary/${name}`;
        try {
            const response = await fetch(url);
            if (!response.ok) {
                throw new DGError(`Failed to get model details. HTTP status: ${response.status}`, "FETCH_MODEL_DETAILS_FAILED", { status: response.status }, "Failed to fetch model details.");
            }
            return await response.json();
        } catch (error) {
            // Preserve the HTTP-status error instead of replacing it with a generic one.
            if (error instanceof DGError) {
                throw error;
            }
            throw new DGError(`Failed to fetch model details.`, "FETCH_MODEL_DETAILS_FAILED", {}, "Failed to fetch model details.");
        }
    }

    /**
     * Sends trace management data to the server.
     * @param {Object} data - The trace management data in JSON format. MUST USE THIS FORMAT: https://degirum.atlassian.net/wiki/spaces/SD/pages/1586298881/AI+Server+Protocol+Description
     * @returns {Promise<Object>} A promise that resolves to the server's response.
     * @throws {DGError} TRACE_MANAGEMENT_FAILED (with `status` in details for HTTP errors).
     */
    async traceManage(data) {
        const url = `${this.httpUrl}/v1/trace_manage`;
        try {
            const response = await fetch(url, {
                method: 'POST',
                headers: {
                    'Content-Type': 'application/json',
                },
                body: JSON.stringify(data),
            });
            if (!response.ok) {
                throw new DGError(`Failed to manage trace. HTTP status: ${response.status}`, "TRACE_MANAGEMENT_FAILED", { status: response.status }, "Failed to manage trace.");
            }
            return await response.json();
        } catch (error) {
            if (error instanceof DGError) {
                throw error;
            }
            throw new DGError(`Failed to manage trace.`, "TRACE_MANAGEMENT_FAILED", {}, "Failed to manage trace.");
        }
    }

    /**
     * Sends a model zoo management request to the server.
     * Currently, it works with the 'rescan' string in the request JSON.
     * @param {Object} data - The model zoo management data in JSON format.
     * @returns {Promise<Object>} A promise that resolves to the server's response.
     * @throws {DGError} ZOO_MANAGEMENT_FAILED (with `status` in details for HTTP errors).
     */
    async zooManage(data) {
        const url = `${this.httpUrl}/v1/zoo_manage`;
        try {
            const response = await fetch(url, {
                method: 'POST',
                headers: {
                    'Content-Type': 'application/json',
                },
                body: JSON.stringify(data),
            });
            if (!response.ok) {
                throw new DGError(`Failed to manage model zoo. HTTP status: ${response.status}`, "ZOO_MANAGEMENT_FAILED", { status: response.status }, "Failed to manage model zoo.");
            }
            return await response.json();
        } catch (error) {
            if (error instanceof DGError) {
                throw error;
            }
            throw new DGError(`Failed to manage model zoo.`, "ZOO_MANAGEMENT_FAILED", {}, "Failed to manage model zoo.");
        }
    }

    /**
     * Sends a request to make the server sleep for a specified amount of time.
     * Useful for pinging the server.
     * @param {number} ms - The amount of time in milliseconds for the server to sleep.
     * @returns {Promise<Response>} A promise that resolves to the server's response.
     * @throws {DGError} SLEEP_FAILED (with `status` in details for HTTP errors).
     */
    async sleep(ms) {
        const url = `${this.httpUrl}/v1/sleep/${ms}`;
        try {
            const response = await fetch(url, {
                method: 'POST',
            });
            if (!response.ok) {
                throw new DGError(`Failed to make the server sleep. HTTP status: ${response.status}`, "SLEEP_FAILED", { status: response.status }, "Failed to make the server sleep.");
            }
            return response;
        } catch (error) {
            if (error instanceof DGError) {
                throw error;
            }
            throw new DGError(`Failed to make the server sleep.`, "SLEEP_FAILED", {}, "Failed to make the server sleep.");
        }
    }

    /**
     * Sends a request to shut down the server.
     * @returns {Promise<void>} A promise that resolves when the server has been shut down.
     * @throws {DGError} SHUTDOWN_FAILED (with `status` in details for HTTP errors).
     */
    async shutdown() {
        const url = `${this.httpUrl}/v1/shutdown`;
        try {
            const response = await fetch(url, {
                method: 'POST',
            });
            if (!response.ok) {
                throw new DGError(`Failed to shut down the server. HTTP status: ${response.status}`, "SHUTDOWN_FAILED", { status: response.status }, "Failed to shut down the server.");
            }
        } catch (error) {
            if (error instanceof DGError) {
                throw error;
            }
            throw new DGError(`Failed to shut down the server.`, "SHUTDOWN_FAILED", {}, "Failed to shut down the server.");
        }
    }
}
/**
 * Polls a predicate until it returns a truthy value or a timeout elapses.
 *
 * @private
 * @function
 * @param {Function} predicate - Zero-argument function polled repeatedly; the wait ends as soon as it returns a truthy value.
 * @param {number} [timeout=1000] - Maximum time to wait for the predicate to become truthy, in milliseconds.
 * @param {number} [interval=10] - Delay between successive polls, in milliseconds.
 * @returns {Promise<void>} Resolves once the predicate returns a truthy value. Rejects with
 *   Error('Timed out waiting for condition') if the timeout elapses first, or with the predicate's
 *   own error if it throws.
 *
 * @example
 * // Wait until a variable is set to true
 * await waitUntil(() => some.thing === true, 5000);
 * expect(some.thing).toBe(true)
 */
function waitUntil(predicate, timeout = 1000, interval = 10) {
    const startTime = Date.now();
    return new Promise((resolve, reject) => {
        const checkCondition = () => {
            let isMet;
            try {
                isMet = predicate();
            } catch (error) {
                // Polls after the first run inside a setTimeout callback; without this, a throwing
                // predicate would surface as an uncaught exception and leave the promise pending forever.
                reject(error);
                return;
            }
            if (isMet) {
                resolve();
            } else if (Date.now() - startTime > timeout) {
                reject(new Error('Timed out waiting for condition'));
            } else {
                setTimeout(checkCondition, interval);
            }
        };
        checkCondition();
    });
}
// Export the class for use in other files
export default AIServerZoo;