import AIServerModel from './aiservermodel.js';
import DGError from './dgerror.js';
// Host name of the default cloud server.
const DEFAULT_CLOUD_SERVER = 'cs.degirum.com';
// Default zoo path; the AIServerCloudZoo constructor appends it when the supplied zoo URL has no path.
const DEFAULT_CLOUD_ZOO = '/degirum/public';
// Full default zoo URL, used as the default value of the constructor's `zooUrl` parameter.
const DEFAULT_CLOUD_URL = `https://${DEFAULT_CLOUD_SERVER}${DEFAULT_CLOUD_ZOO}`;
// Maximum time (ms) loadModel() waits for a newly created model to report `initialized`.
const MODEL_INIT_TIMEOUT_MS = 20000;
// Maximum time (ms) listModels() / loadModel() wait for the zoo rescan to complete.
const RESCAN_TIMEOUT_MS = 20000;
/**
 * @class
 * @classdesc Manages and loads models from a cloud-based AI model zoo for inference on an AIServer.
 * It handles model discovery, fetching model information, and loading models for use.
 * Use loadModel() to instantiate and configure a model for use.
 * Various utility functions are available to interact with the cloud zoo and AIServer.
 *
 * @example <caption>Usage:</caption>
 * let zoo = dg.connect('ws://localhost:8779', 'https://cs.degirum.com/degirum/public', secretToken);
 * let model = await zoo.loadModel('someModel', { max_q_len: 5, callback: someFunction });
 */
class AIServerCloudZoo {
  /**
   * Do not call this constructor, instead use the `connect` function of the dg_sdk class to create an AIServerCloudZoo instance.
   * Construction immediately starts a background rescan of the zoo; listModels() and loadModel() wait for it.
   * @param {string} serverUrl - The URL of the AI server (WebSocket URL, `ws://` or `wss://`).
   * @param {string} [zooUrl='https://cs.degirum.com/degirum/public'] - The URL of the cloud model zoo.
   * @param {string} accessToken - The access token for the cloud zoo.
   * @param {boolean} [isDummy=false] - Whether this is a dummy instance (if we only need to have model listing functionality, without connecting to an AIServer).
   */
  constructor(serverUrl, zooUrl = DEFAULT_CLOUD_URL, accessToken, isDummy = false) {
    // The access token is a credential: report only whether one was supplied, never its value.
    console.log('AIServerCloudZoo constructor: serverUrl:', serverUrl, 'zooUrl:', zooUrl, 'accessToken provided:', Boolean(accessToken));
    this.isDummy = isDummy;
    // Cloud Model Zoo URL for model management (stored without a trailing slash).
    this.zooUrl = zooUrl.endsWith('/') ? zooUrl.slice(0, -1) : zooUrl;
    const { pathname } = new URL(this.zooUrl);
    // If the zooUrl doesn't have a path, fall back to the default zoo path.
    // We do this because construction of AIServerCloudZoo is never expected to perform local inference.
    if (pathname === '/' || pathname === '') {
      this.zooUrl += DEFAULT_CLOUD_ZOO;
    }
    this.accessToken = accessToken;
    this.assets = {}; // model assets usable on this system, keyed by model name
    this.modelNames = []; // every model name in the cloud zoo (including ones that aren't supported on the system)
    // AI Server URL for model inference
    this.serverUrl = serverUrl;
    // HTTP URL for AIServer REST API calls, e.g. systemInfo(): ws:// -> http://, wss:// -> https://
    this.httpUrl = serverUrl.replace(/^ws(s?):\/\//, 'http$1://');
    // Initialize the assets list in the background.
    this.isRescanComplete = false;
    // rescanZoo() always sets isRescanComplete before rejecting, so waiters are not blocked;
    // the handler here only prevents an unhandled promise rejection.
    this.rescanZoo().catch((error) => {
      console.error('AIServerCloudZoo: initial zoo rescan failed:', error);
    });
  }

  /**
   * Returns the full model name in the format 'organization/zooname/simpleModelName'.
   * @private
   * @param {string} simpleModelName - The simple name of the model.
   * @returns {string} The full model name.
   */
  getModelFullName(simpleModelName) {
    // zooUrl has no trailing slash (constructor logic), so the pathname is '/organization/zooname'.
    const { pathname } = new URL(this.zooUrl);
    const zooPath = pathname.startsWith('/') ? pathname.substring(1) : pathname;
    return `${zooPath}/${simpleModelName}`;
  }

  /**
   * Fetches data from the cloud zoo REST API using the stored access token.
   * Requests are attempted up to three times with exponential backoff. A 401 response
   * (wrong token) is not retried, because repeating the request cannot succeed.
   * @private
   * @param {string} apiUrl - The API endpoint to fetch from (appended to `<origin>/zoo/v1/public`).
   * @param {boolean} [isOctetStream=false] - Whether to expect a blob response.
   * @returns {Promise<Object|Blob>} The fetched data.
   * @throws {DGError} With code FETCH_WITH_TOKEN_FAILED on authentication failure or when all attempts fail.
   */
  async fetchWithToken(apiUrl, isOctetStream = false) {
    const headers = {
      'Accept': isOctetStream ? 'application/octet-stream' : 'application/json',
      'token': this.accessToken,
    };
    const cloudServerPathSuffix = '/zoo/v1/public';
    const zooLocation = new URL(this.zooUrl);
    const fullUrl = `${zooLocation.origin}${cloudServerPathSuffix}${apiUrl}`;
    const maxAttempts = 3;
    let authFailure = null;

    for (let attempt = 1; attempt <= maxAttempts; attempt++) {
      try {
        const response = await fetch(fullUrl, { headers });
        if (response.ok) {
          // After a redirect to an https origin, follow the new origin but keep the
          // '/organization/zooname' path so getModelFullName() keeps working.
          if (response.redirected && response.url.startsWith('https://')) {
            const zooPath = zooLocation.pathname === '/' ? '' : zooLocation.pathname;
            this.zooUrl = `${new URL(response.url).origin}${zooPath}`;
          }
          return isOctetStream ? response.blob() : response.json();
        }
        if (response.status === 401) {
          console.log('Wrong token. Response:', response);
          let errorDetails = 'Invalid token value';
          try {
            errorDetails = (await response.json()).detail || errorDetails;
          } catch {
            // The 401 body was not JSON; keep the generic description.
          }
          authFailure = new DGError(`Unable to connect to server: ${errorDetails}`, "FETCH_WITH_TOKEN_FAILED", { status: response.status }, "Unable to connect to server.");
          break; // do not retry with a token the server already rejected
        }
        throw new DGError(`HTTP error! status: ${response.status}: ${response.statusText}`, "FETCH_WITH_TOKEN_FAILED", { status: response.status }, "HTTP error occurred.");
      } catch (error) {
        if (attempt === maxAttempts) {
          throw new DGError(`All retries failed: ${error}`, "FETCH_WITH_TOKEN_FAILED", {}, "All retries failed.");
        }
        // Exponential backoff
        await new Promise((resolve) => setTimeout(resolve, Math.pow(2, attempt) * 100));
      }
    }
    // The loop only exits without returning or throwing via the 401 `break`.
    throw authFailure;
  }

  /**
   * Rescans the cloud zoo to update the list of available models.
   * Fetches from https://cs.degirum.com/zoo/v1/public/models/ORGANIZATION/ZOONAME
   * Populates `modelNames` with every model in the zoo and `assets` with the models usable
   * on this system (all models for dummy instances). `isRescanComplete` is set on every exit path.
   * @private
   * @returns {Promise<void>}
   * @throws {DGError} With code RESCAN_ZOO_FAILED when the model list cannot be fetched or processed.
   */
  async rescanZoo() {
    try {
      const { pathname } = new URL(this.zooUrl);
      const modelsInfo = await this.fetchWithToken('/models' + pathname);

      let systemDeviceKeys = new Set();
      if (!this.isDummy) {
        try {
          systemDeviceKeys = new Set(await this.systemSupportedDeviceTypes());
        } catch (error) {
          console.error('Error fetching system information for AIServer at', this.serverUrl, error);
          // TODO: THIS BREAKS EVERYTHING if it's turned into a DGError.
          // Deliberate best-effort: finish with the current (empty) asset list instead of failing.
          this.isRescanComplete = true;
          return;
        }
      }

      if (modelsInfo.error) {
        throw new DGError(modelsInfo.error, "RESCAN_ZOO_FAILED", {}, "Error occurred while rescanning zoo.");
      }

      this.assets = {};
      this.modelNames = [];
      for (const [modelName, modelDetails] of Object.entries(modelsInfo)) {
        this.modelNames.push(modelName);
        if (this.isDummy) {
          // Dummy instances list everything; there is no AIServer to filter against.
          this.assets[modelName] = modelDetails;
          continue;
        }
        // Bypass audio models.
        if (Array.isArray(modelDetails?.PRE_PROCESS) && modelDetails.PRE_PROCESS.length > 0 && modelDetails.PRE_PROCESS[0]?.InputType === 'Audio') {
          console.log('Skipping audio model', modelName);
          continue;
        }
        let modelSupportedTypes;
        try {
          const device = modelDetails.DEVICE[0];
          modelSupportedTypes = device.SupportedDeviceTypes;
          if (modelSupportedTypes) {
            modelSupportedTypes = modelSupportedTypes.split(',').map((type) => type.trim());
          } else {
            // No SupportedDeviceTypes: fall back to the single RuntimeAgent/DeviceType pair.
            const capabilityKey = `${device.RuntimeAgent}/${device.DeviceType}`;
            if (systemDeviceKeys.has(capabilityKey)) {
              this.assets[modelName] = modelDetails;
            } else {
              console.log('Skipping model', modelName, 'because it requires', capabilityKey, 'which is not available on the system.');
            }
            continue;
          }
        } catch (error) {
          console.error('Error processing SupportedDeviceTypes for model', modelName, error);
          continue;
        }
        // Keep the model when at least one supported type (wildcards allowed) matches a system device.
        const matchingDevices = this.matchSupportedDevices(modelSupportedTypes, Array.from(systemDeviceKeys));
        if (matchingDevices.length > 0) {
          this.assets[modelName] = modelDetails;
        } else {
          console.log('Skipping model', modelName, 'because none of the SupportedDeviceTypes match the system devices.');
        }
      }
      this.isRescanComplete = true; // Mark rescan as complete
      console.log('rescanZoo: Rescan complete. Model list:', this.assets);
    } catch (error) {
      this.isRescanComplete = true; // Still mark rescan as complete to avoid indefinite waiting
      throw new DGError('Error rescanning zoo: ' + error, "RESCAN_ZOO_FAILED", { error: error.message }, "Error occurred while rescanning zoo.");
    }
  }

  /**
   * Determines which system device types match any of the model's supported device types, considering wildcards.
   * @private
   * @param {Array<string>} modelSupportedTypes - Device types supported by the model.
   * Example: ["OPENVINO/*", "TENSORRT/*", "ONNX/*"]
   * @param {Array<string>} systemDeviceTypes - Device types available on the system.
   * Example: ["OPENVINO/CPU", "TENSORRT/GPU", "ONNX/CPU"]
   *
   * @returns {Array<string>} The system device types matched by at least one model pattern.
   * Example: If modelSupportedTypes is ["OPENVINO/*", "TENSORRT/*"] and systemDeviceTypes is ["OPENVINO/CPU", "TENSORRT/GPU"],
   * it returns ["OPENVINO/CPU", "TENSORRT/GPU"].
   */
  matchSupportedDevices(modelSupportedTypes, systemDeviceTypes) {
    // A pattern part matches when it is '*' or equal to the corresponding part of the type.
    const matchesWildcard = (pattern, type) => {
      const [patternRuntime, patternDevice] = pattern.split('/');
      const [typeRuntime, typeDevice] = type.split('/');
      const runtimeMatches = patternRuntime === '*' || patternRuntime === typeRuntime;
      const deviceMatches = patternDevice === '*' || patternDevice === typeDevice;
      return runtimeMatches && deviceMatches;
    };
    return systemDeviceTypes.filter((systemType) =>
      modelSupportedTypes.some((modelType) => matchesWildcard(modelType, systemType))
    );
  }

  /**
   * Fetches the list of zoos for a given organization.
   * @param {string} organization - The name of the organization.
   * @returns {Promise<Object>} The list of zoos.
   * @throws {DGError} With code GET_ZOOS_FAILED when the organization is missing or the request fails.
   */
  async getZoos(organization) {
    if (!organization) {
      throw new DGError('Organization name is required to get zoos.', "GET_ZOOS_FAILED", {}, "Organization name is required to get zoos.");
    }
    try {
      return await this.fetchWithToken(`/zoos/${organization}`);
    } catch (error) {
      throw new DGError(`Error fetching zoos for organization ${organization}. Error: ${error.message}`, "GET_ZOOS_FAILED", {}, "Error occurred while fetching zoos for organization.");
    }
  }

  /**
   * Fetches the labels for a specific model from the cloud zoo.
   * queries https://cs.degirum.com/zoo/v1/public/models/degirum/public/example_model/dictionary
   * @param {string} modelName - The name of the model.
   * @returns {Promise<Object>} The model's label dictionary.
   * @throws {DGError} With code GET_MODEL_LABELS_FAILED when the dictionary cannot be fetched.
   */
  async getModelLabels(modelName) {
    try {
      const dictionaryPath = `/models/${this.getModelFullName(modelName)}/dictionary`;
      return await this.fetchWithToken(dictionaryPath);
    } catch (error) {
      // Carry the underlying failure so callers can see why the fetch failed.
      const reason = error?.message ?? error;
      throw new DGError(`Error fetching labels for model ${modelName}: ${reason}`, "GET_MODEL_LABELS_FAILED", { error: reason }, "Error occurred while fetching labels for model.");
    }
  }

  /**
   * Fetches system information from the AI server.
   * Any failure (network error or non-OK HTTP status) is reported as a NETWORK_ERROR.
   * @returns {Promise<Object>} The system information.
   * @throws {DGError} With code NETWORK_ERROR when the server cannot be queried.
   */
  async systemInfo() {
    try {
      const response = await fetch(`${this.httpUrl}/v1/system_info`);
      if (response.ok) {
        return response.json();
      }
      throw new DGError(`HTTP error! status: ${response.status}: ${response.statusText}`, "SYSTEM_INFO_FAILED", { status: response.status }, "HTTP error occurred.");
    } catch (error) {
      throw new DGError(`No AIServer was found at ${this.httpUrl}. Please check the URL.`, "NETWORK_ERROR", { url: this.httpUrl }, "Please check your network connection and the URL.", "Verify the AI Server IP / port and ensure that it is running.");
    }
  }

  /**
   * Downloads a model from the cloud zoo.
   * fetches from https://cs.degirum.com/zoo/v1/public/models/degirum/public/example_model
   * @param {string} modelName - The name of the model to download.
   * @returns {Promise<Blob>} The downloaded model as a Blob.
   * @throws {DGError} With code DOWNLOAD_MODEL_FAILED when the download fails or the blob is empty/untyped.
   */
  async downloadModel(modelName) {
    const modelDownloadPath = `/models/${this.getModelFullName(modelName)}`;
    try {
      const modelBlob = await this.fetchWithToken(modelDownloadPath, true); // true to accept blob response
      // A usable download is non-empty and carries a content type.
      if (modelBlob && modelBlob.size > 0 && modelBlob.type) {
        console.log(`Model ${modelName} downloaded successfully. Blob size: ${modelBlob.size} bytes, Type: ${modelBlob.type}`);
        return modelBlob;
      }
      throw new DGError(`Downloaded blob for model ${modelName} is invalid.`, "DOWNLOAD_MODEL_FAILED", {}, "Downloaded blob for model is invalid.");
    } catch (error) {
      // Carry the underlying failure so callers can see why the download failed.
      const reason = error?.message ?? error;
      throw new DGError(`Error downloading model ${modelName}: ${reason}`, "DOWNLOAD_MODEL_FAILED", { error: reason }, "Error occurred while downloading model.");
    }
  }

  /**
   * Lists all available models in the cloud zoo.
   * Waits for the zoo rescan to complete before returning.
   * @example
   * let models = await zoo.listModels();
   * let modelNames = Object.keys(models);
   * @returns {Promise<Object>} An object containing the available models and their params.
   * @throws {DGError} With code LIST_MODELS_FAILED when the rescan does not complete in time.
   */
  async listModels() {
    try {
      await waitUntil(() => this.isRescanComplete, RESCAN_TIMEOUT_MS);
    } catch (error) {
      throw new DGError('aiservercloudzoo.js: listModels() timed out.', "LIST_MODELS_FAILED", {}, "Timeout occurred while listing models.");
    }
    return this.assets;
  }

  /**
   * Fetches the system supported device types from the AI server.
   * @private
   * @returns {Promise<Array<string>>} A promise that resolves to the keys of the `Devices` entry of the system info ("RUNTIME/DEVICE" format).
   * @throws {DGError} With code SYSTEM_INFO_FAILED when system info cannot be fetched or lists no devices.
   */
  async systemSupportedDeviceTypes() {
    // private function, to be called by dg_sdk() instead.
    try {
      let systemInfo;
      try {
        systemInfo = await this.systemInfo();
      } catch (error) {
        throw new DGError('Error fetching system information for AIServer at ' + this.httpUrl + ' : ' + error, "FETCH_MODELS_FAILED", {}, "Error fetching sys info.");
      }
      if (!systemInfo.Devices) {
        throw new DGError('No devices found in system info.', "FETCH_MODELS_FAILED", {}, "No devices found in system info.");
      }
      return Object.keys(systemInfo.Devices);
    } catch (error) {
      throw new DGError('systemSupportedDeviceTypes failed: ' + error, "SYSTEM_INFO_FAILED", {}, "Error occurred while fetching system info.");
    }
  }

  /**
   * Loads a model from the cloud zoo and prepares it for inference.
   * The model is looked up in the assets collected by the zoo rescan: an exact name match
   * is preferred; otherwise the first asset whose name contains `modelName` is used.
   * @example
   * let model = await zoo.loadModel('someModel', { inputCropPercentage: 0.5, callback: someFunction } );
   * @param {string} modelName - The name of the model to load.
   * @param {Object} [options={}] - Additional options for loading the model. Keys other than
   * `max_q_len` and `callback` are forwarded to AIServerModel as additional parameters.
   * @param {number} [options.max_q_len=10] - The maximum length of the internal inference queue.
   * @param {Function} [options.callback=null] - A callback function to handle inference results.
   * @returns {Promise<AIServerModel>} The loaded model instance.
   * @throws {DGError} On dummy instances, invalid options, rescan timeout, unreachable AIServer,
   * unknown or unsupported model, or model initialization timeout.
   */
  async loadModel(modelName, options = {}) {
    if (this.isDummy) {
      throw new DGError("Model loading is not supported on dummy instances.", "DUMMY_INSTANCE", {}, "Model loading is not supported on dummy instances.");
    }
    const startTime = performance.now();
    // Default values and destructuring of options
    const {
      max_q_len = 10,
      callback = null,
      ...additionalParams
    } = options;
    if (!Number.isInteger(max_q_len) || max_q_len <= 0) {
      throw new DGError("Invalid value for max_q_len: It should be a positive integer.", "INVALID_INPUT", { parameter: "max_q_len", value: max_q_len }, "Please ensure max_q_len is a positive integer.");
    }
    if (callback !== null && typeof callback !== "function") {
      throw new DGError("Invalid value for callback: It should be a function.", "INVALID_INPUT", { parameter: "callback", value: callback }, "Please ensure callback is a function.");
    }
    try {
      await waitUntil(() => this.isRescanComplete, RESCAN_TIMEOUT_MS);
    } catch (error) {
      throw new DGError('aiservercloudzoo.js: rescan of models timed out within loadModel().', "RESCAN_MODELS_FAILED", {}, "Timeout occurred while waiting for models to be fetched.");
    }
    // Ping the AIServer with systemInfo() to make sure it is up before doing any further work.
    try {
      await this.systemInfo();
    } catch (error) {
      throw new DGError('No response from AIServer at ' + this.serverUrl + '. Please check the server URL and ensure that the AIServer is running.', "NETWORK_ERROR", {}, "AIServer is not responding.");
    }
    // Resolve the model against the cached assets from rescanZoo(). An exact match wins so that
    // e.g. 'yolo' never resolves to 'yolo_v5' when both exist; substring matching is the fallback.
    const assetNames = Object.keys(this.assets);
    const matchingModelName = assetNames.includes(modelName)
      ? modelName
      : assetNames.find((name) => name.includes(modelName));
    const modelParams = matchingModelName === undefined ? undefined : this.assets[matchingModelName];
    if (!modelParams) {
      if (!this.modelNames.includes(modelName)) {
        throw new DGError(`Model not found in the cloud zoo: ${modelName}. List of supported models: ${Object.keys(this.assets)}`, "MODEL_NOT_FOUND", { modelName: modelName }, "Model not found in the cloud zoo.");
      }
      // Model found in the zoo, but filtered out as unsupported on the system
      throw new DGError(`Model exists in the zoo, but is not supported on the system: ${modelName}`, "MODEL_NOT_SUPPORTED", { modelName: modelName }, "Model exists in the zoo, but is not supported on the system.");
    }
    // construct the extended model name using the cloud zoo path and the simple model name
    const extendedModelName = this.getModelFullName(matchingModelName);
    console.log(`Loading model: ${matchingModelName}, queue length: ${max_q_len}, callback specified:`, !(callback === null));
    // Download Label Dictionary
    const labels = await this.getModelLabels(matchingModelName);
    // Deep copy so the model instance cannot mutate the cached assets.
    const deepCopiedModelParams = JSON.parse(JSON.stringify(modelParams));
    // We attach the cloudURL and cloudToken to additionalParams, not to model params
    additionalParams.cloudURL = this.zooUrl;
    additionalParams.cloudToken = this.accessToken;
    const systemDeviceTypes = await this.systemSupportedDeviceTypes(); // array of 'RUNTIME/DEVICE' strings
    // pack all arguments into a single object for AIServerModel
    const modelCreationParams = {
      modelName: extendedModelName,
      serverUrl: this.serverUrl,
      modelParams: deepCopiedModelParams,
      max_q_len: max_q_len,
      callback: callback,
      labels: labels,
      systemDeviceTypes: systemDeviceTypes
    };
    const model = new AIServerModel(modelCreationParams, additionalParams);
    // Wait for model.initialized to be true, then return model
    try {
      await waitUntil(() => model.initialized, MODEL_INIT_TIMEOUT_MS);
      console.log('aiservercloudzoo.js: loadModel() took', performance.now() - startTime, 'ms');
      return model;
    } catch (error) {
      throw new DGError('aiservercloudzoo: Timeout occurred while waiting for the model to initialize!', "LOAD_MODEL_FAILED", {}, "Timeout occurred while waiting for the model to initialize.", "Check the connection to the AIServer / internet.");
    }
  }
}
/**
 * Polls a predicate until it returns a truthy value or the time budget is exhausted.
 * The predicate is evaluated immediately and then once per `interval`.
 * @private
 * @param {Function} predicate - Called with no arguments; a truthy result ends the wait.
 * @param {number} [timeout=10000] - Time budget in milliseconds, measured from the call to waitUntil.
 * @param {number} [interval=10] - Delay between successive checks, in milliseconds.
 * @returns {Promise<void>} Resolves once the predicate is truthy; rejects with an Error when a check
 * happens after more than `timeout` milliseconds have elapsed.
 */
function waitUntil(predicate, timeout = 10000, interval = 10) {
  const startedAt = Date.now();
  return new Promise((resolve, reject) => {
    function poll() {
      if (predicate()) {
        resolve();
        return;
      }
      const elapsed = Date.now() - startedAt;
      if (elapsed > timeout) {
        reject(new Error('Timed out waiting for condition'));
        return;
      }
      // Not satisfied and still within budget: schedule the next check.
      setTimeout(poll, interval);
    }
    poll();
  });
}
export default AIServerCloudZoo;