import CloudServerModel from './cloudservermodel.js';
import DGError from './dgerror.js';
// Host and zoo path used when connect() is not given an explicit cloud zoo URL.
const DEFAULT_CLOUD_SERVER = 'cs.degirum.com';
const DEFAULT_CLOUD_ZOO = '/degirum/public';
// Resolves to 'https://cs.degirum.com/degirum/public'.
const DEFAULT_CLOUD_URL = new URL(DEFAULT_CLOUD_ZOO, `https://${DEFAULT_CLOUD_SERVER}`).href;
// How long loadModel() waits for a CloudServerModel to report `initialized` (milliseconds).
const MODEL_INIT_TIMEOUT_MS = 20 * 1000;
// How long listModels()/loadModel() wait for the zoo rescan to finish (milliseconds).
const RESCAN_TIMEOUT_MS = 20 * 1000;
/**
 * @class
 * @classdesc Manages a cloud-hosted AI model zoo for inference in the cloud.
 * It discovers the zoo's models (keeping only those the cloud system can run), fetches
 * model metadata and labels, and loads models for use.
 * Use loadModel() to instantiate and configure a model for use.
 * Various utility functions are available to interact with the cloud zoo.
 *
 * @example <caption>Usage:</caption>
 * let zoo = dg.connect('cloud', 'https://cs.degirum.com/degirum/public', secretToken);
 * let model = await zoo.loadModel('someModel', { max_q_len: 5, callback: someFunction });
 */
class CloudServerCloudZoo {
  /**
   * Do not call this constructor, instead use the `connect` function of the dg_sdk class to create a CloudServerCloudZoo instance.
   * A non-dummy instance immediately starts rescanning the zoo in the background; listModels() and
   * loadModel() wait for that scan to finish.
   * @param {string} [zooUrl=DEFAULT_CLOUD_URL] - Expected format: "https://<cloud server URL>[/<zoo URL>]".
   *   Trailing slashes are ignored; `null` also selects the default URL.
   * @param {string} accessToken - The access token for the cloud zoo.
   * @param {boolean} [dummy=false] - Dummy zoo instance - only for system info call (no token needed, no rescan).
   * @throws {DGError} INVALID_ARGS when no access token is given; INVALID_URL when the URL cannot be parsed.
   */
  constructor(zooUrl = DEFAULT_CLOUD_URL, accessToken, dummy = false) {
    this.isDummy = Boolean(dummy);
    const token = this.isDummy ? 'dummy' : accessToken;
    if (!token) {
      throw new DGError('CloudServerCloudZoo: accessToken is required.', "INVALID_ARGS", { accessToken: token }, "Access token is required.", "Please provide a valid access token.");
    }
    // fullUrl: https://cs.foo.bar.com/org/name (all trailing slashes removed)
    this.fullUrl = String(zooUrl ?? DEFAULT_CLOUD_URL).replace(/\/+$/, '');
    this.accessToken = token;
    this.assets = {}; // models usable on the cloud system: name -> model params
    this.modelNames = []; // every model name in the cloud zoo, including ones the system cannot run
    this.isRescanComplete = false;
    // Split https://cs.foo.bar.com/org/name into the server origin and the zoo path.
    const parsedUrl = this._parseFullUrl();
    this.url = parsedUrl.origin; // e.g. https://cs.foo.bar.com
    this._zoo_url = parsedUrl.pathname; // e.g. /org/name
    if (!this.isDummy) {
      // A constructor cannot await. Report failures here so they do not surface as unhandled
      // promise rejections; rescanZoo() sets isRescanComplete either way, so waiters are released.
      this.rescanZoo().catch((error) => {
        console.error('CloudServerCloudZoo: zoo rescan failed:', error);
      });
    }
  }

  /**
   * Parses this.fullUrl into a URL object.
   * @private
   * @returns {URL} The parsed zoo URL.
   * @throws {DGError} INVALID_URL if fullUrl is not an absolute URL with a network origin.
   */
  _parseFullUrl() {
    let parsedUrl = null;
    try {
      parsedUrl = new URL(this.fullUrl);
    } catch {
      // Reported as INVALID_URL below.
    }
    // Non-network schemes (e.g. "foo:bar") parse successfully but have the opaque origin "null".
    if (!parsedUrl || !parsedUrl.origin || parsedUrl.origin === 'null') {
      throw new DGError('CloudServerCloudZoo: Invalid URL provided. Please provide a valid URL.', "INVALID_URL", { url: this.fullUrl }, "Invalid URL provided.");
    }
    return parsedUrl;
  }

  /**
   * Returns the full model name in the format 'organization/zooname/simpleModelName'.
   * @private
   * @param {string} simpleModelName - The simple name of the model, e.g. 'someModel'
   * @returns {string} The full model name.
   * @throws {DGError} MISSING_ARGS when simpleModelName is empty.
   */
  getModelFullName(simpleModelName) {
    if (!simpleModelName) throw new DGError('getModelFullName: simpleModelName is required.', "MISSING_ARGS", {}, "Model name is required.");
    const zooPath = this._parseFullUrl().pathname.replace(/^\//, '');
    return `${zooPath}/${simpleModelName}`;
  }

  /**
   * Fetches data from the API with the access token, requesting
   * `${origin}/zoo/v1/public${apiUrl}` (or `${origin}${apiUrl}` when noPrefix is set).
   * Network failures and 5xx/408/429 responses are retried up to 3 attempts with exponential
   * backoff; other 4xx responses (including a rejected token) fail immediately.
   * @private
   * @param {string} apiUrl - The API endpoint to fetch from.
   * @param {boolean} [isOctetStream=false] - Whether to expect a blob response.
   * @param {boolean} [noPrefix=false] - Whether to omit the '/zoo/v1/public' prefix.
   * @returns {Promise<Object|Blob>} The fetched data.
   * @throws {DGError} INVALID_URL or FETCH_WITH_TOKEN_FAILED.
   */
  async fetchWithToken(apiUrl, isOctetStream = false, noPrefix = false) {
    const headers = {
      'Accept': isOctetStream ? 'application/octet-stream' : 'application/json',
      'token': this.accessToken,
    };
    const { origin } = this._parseFullUrl();
    const cloudServerPathSuffix = noPrefix ? '' : '/zoo/v1/public';
    const requestUrl = `${origin}${cloudServerPathSuffix}${apiUrl}`;
    const maxAttempts = 3;
    let lastError;
    for (let attempt = 1; attempt <= maxAttempts; attempt++) {
      let response = null;
      try {
        response = await fetch(requestUrl, { headers });
      } catch (networkError) {
        // Connection-level failure: worth retrying.
        lastError = networkError;
      }
      if (response) {
        if (response.ok) {
          this._updateUrlFromRedirect(response);
          return isOctetStream ? response.blob() : response.json();
        }
        if (response.status === 401) {
          // A rejected token will not fix itself, so report it right away with the server's reason.
          console.log('Wrong token. Response:', response);
          let errorDetails = 'Invalid token value';
          try {
            errorDetails = (await response.json())?.detail || errorDetails;
          } catch {
            // Body was not JSON; keep the generic message.
          }
          throw new DGError(`Unable to connect to server: ${errorDetails}`, "FETCH_WITH_TOKEN_FAILED", { status: response.status }, "Unable to connect to server.");
        }
        lastError = new DGError(`HTTP error! status: ${response.status}: ${response.statusText}`, "FETCH_WITH_TOKEN_FAILED", { status: response.status }, "HTTP error occurred.");
        const isClientError = response.status >= 400 && response.status < 500
          && response.status !== 408 && response.status !== 429;
        if (isClientError) throw lastError;
      }
      if (attempt < maxAttempts) {
        // Exponential backoff: 200 ms, then 400 ms.
        await new Promise((resolve) => setTimeout(resolve, 2 ** attempt * 100));
      }
    }
    throw new DGError(`All retries failed: ${lastError}`, "FETCH_WITH_TOKEN_FAILED", {}, "All retries failed.");
  }

  /**
   * If the server redirected a request to another https origin, points this zoo (fullUrl and url)
   * at the new origin while keeping the zoo path.
   * @private
   * @param {Response} response - A successful fetch response.
   */
  _updateUrlFromRedirect(response) {
    if (!response.redirected || typeof response.url !== 'string' || !response.url.startsWith('https://')) return;
    const newOrigin = new URL(response.url).origin;
    const zooPath = this._parseFullUrl().pathname.replace(/\/+$/, '');
    this.fullUrl = `${newOrigin}${zooPath}`;
    this.url = newOrigin;
  }

  /**
   * Rescans the cloud zoo to update the list of available models.
   * Fetches from https://cs.degirum.com/zoo/v1/public/models/ORGANIZATION/ZOONAME and keeps only
   * models whose device requirements match the cloud system's device types (dummy instances keep all).
   * If the system device types cannot be fetched, the error is logged and the previous lists are kept.
   * isRescanComplete is always set to true when this finishes, successfully or not.
   * @private
   * @returns {Promise<void>}
   * @throws {DGError} RESCAN_ZOO_FAILED when the model list cannot be fetched or reports an error.
   */
  async rescanZoo() {
    try {
      const zooPath = this._parseFullUrl().pathname;
      const modelsInfo = await this.fetchWithToken('/models' + zooPath);
      let systemDeviceKeys = new Set();
      if (!this.isDummy) {
        try {
          systemDeviceKeys = new Set(await this.systemSupportedDeviceTypes());
        } catch (error) {
          // Best effort: keep the previous model lists instead of failing the scan.
          // TODO: Turning this into a thrown DGError breaks callers; investigate before changing.
          console.error('Error fetching systemSupportedDeviceTypes for cloud while rescanning zoo:', error);
          return;
        }
      }
      if (modelsInfo.error) {
        throw new DGError(modelsInfo.error, "RESCAN_ZOO_FAILED", {}, "Error occurred while rescanning zoo.");
      }
      // Build the new lists locally and publish them together, so readers never see a
      // half-built list and repeated rescans do not accumulate duplicate names.
      const assets = {};
      const modelNames = [];
      for (const [modelName, modelDetails] of Object.entries(modelsInfo)) {
        modelNames.push(modelName);
        if (this.isDummy || this._isModelSupported(modelName, modelDetails, systemDeviceKeys)) {
          assets[modelName] = modelDetails;
        }
      }
      this.assets = assets;
      this.modelNames = modelNames;
      console.log('rescanZoo: Rescan complete. Model list:', this.assets);
    } catch (error) {
      throw new DGError('Error rescanning zoo: ' + error, "RESCAN_ZOO_FAILED", { error: error.message }, "Error occurred while rescanning zoo.");
    } finally {
      // Always release waiters in listModels()/loadModel(), even after a failure.
      this.isRescanComplete = true;
    }
  }

  /**
   * Decides whether a zoo model can run on the cloud system, logging the reason when it cannot.
   * Audio models are always skipped. A model's DEVICE[0].SupportedDeviceTypes (comma-separated,
   * wildcards allowed) is matched against the system devices; without it, the single
   * "RuntimeAgent/DeviceType" pair must be present on the system.
   * @private
   * @param {string} modelName - Model name, used for logging.
   * @param {Object} modelDetails - The model's parameters as returned by the zoo.
   * @param {Set<string>} systemDeviceKeys - System device types in "RUNTIME/DEVICE" format.
   * @returns {boolean} True if the model should be listed.
   */
  _isModelSupported(modelName, modelDetails, systemDeviceKeys) {
    const preProcess = modelDetails?.PRE_PROCESS;
    if (Array.isArray(preProcess) && preProcess.length > 0 && preProcess[0]?.InputType === 'Audio') {
      console.log('Skipping audio model', modelName);
      return false;
    }
    let device;
    let modelSupportedTypes;
    try {
      device = modelDetails.DEVICE[0];
      modelSupportedTypes = device.SupportedDeviceTypes;
      if (modelSupportedTypes) {
        modelSupportedTypes = modelSupportedTypes.split(',').map((type) => type.trim());
      }
    } catch (error) {
      console.error('Error processing SupportedDeviceTypes for model', modelName, error);
      return false;
    }
    if (!modelSupportedTypes) {
      // Fallback to default DEVICE type check
      const capabilityKey = `${device.RuntimeAgent}/${device.DeviceType}`;
      if (systemDeviceKeys.has(capabilityKey)) return true;
      console.log('Filtering out model', modelName, 'because it requires', capabilityKey, 'which is not available on the system.');
      return false;
    }
    if (this.matchSupportedDevices(modelSupportedTypes, Array.from(systemDeviceKeys)).length > 0) return true;
    console.log('Filtering out model', modelName, 'because none of the SupportedDeviceTypes match the system devices.');
    return false;
  }

  /**
   * Determines if a system device type matches any of the model's supported device types, considering wildcards.
   * @private
   * @param {Array<string>} modelSupportedTypes - An array of strings representing the device types supported by the model.
   * Example: ["OPENVINO/*", "TENSORRT/*", "ONNX/*"]
   * @param {Array<string>} systemDeviceTypes - An array of strings representing the device types available on the system.
   * Example: ["OPENVINO/CPU", "TENSORRT/GPU", "ONNX/CPU"]
   *
   * @returns {Array<string>} - The system device types (in their original order) matched by at least one
   * model pattern; a '*' in either the runtime or device part of a pattern matches anything.
   */
  matchSupportedDevices(modelSupportedTypes, systemDeviceTypes) {
    const matchesWildcard = (pattern, type) => {
      const [patternRuntime, patternDevice] = pattern.split('/');
      const [typeRuntime, typeDevice] = type.split('/');
      const runtimeMatches = patternRuntime === '*' || patternRuntime === typeRuntime;
      const deviceMatches = patternDevice === '*' || patternDevice === typeDevice;
      return runtimeMatches && deviceMatches;
    };
    return systemDeviceTypes.filter((systemType) =>
      modelSupportedTypes.some((modelType) => matchesWildcard(modelType, systemType))
    );
  }

  /**
   * Fetches the labels for a specific model from the cloud zoo.
   * queries https://cs.degirum.com/zoo/v1/public/models/degirum/public/example_model/dictionary
   * @param {string} modelName - The name of the model.
   * @returns {Promise<Object>} The model's label dictionary.
   * @throws {DGError} GET_MODEL_LABELS_FAILED.
   */
  async getModelLabels(modelName) {
    try {
      const fullModelName = this.getModelFullName(modelName);
      return await this.fetchWithToken(`/models/${fullModelName}/dictionary`);
    } catch (error) {
      throw new DGError(`Error fetching labels for model ${modelName}: ${error}`, "GET_MODEL_LABELS_FAILED", {}, "Error occurred while fetching labels for model.");
    }
  }

  /**
   * Downloads a model from the cloud zoo.
   * fetches from https://cs.degirum.com/zoo/v1/public/models/degirum/public/example_model
   * @param {string} modelName - The name of the model to download.
   * @returns {Promise<Blob>} The downloaded model as a Blob.
   * @throws {DGError} DOWNLOAD_MODEL_FAILED when the request fails or the blob is empty or untyped.
   */
  async downloadModel(modelName) {
    const fullModelName = this.getModelFullName(modelName);
    let modelBlob;
    try {
      modelBlob = await this.fetchWithToken(`/models/${fullModelName}`, true);
    } catch (error) {
      throw new DGError(
        `Error downloading model ${modelName}: ${error.message || error.toString()}`,
        "DOWNLOAD_MODEL_FAILED",
        { originalError: error },
        "An unexpected error occurred while downloading the model."
      );
    }
    if (!(modelBlob && modelBlob.size > 0 && modelBlob.type)) {
      throw new DGError(
        `Downloaded blob for model ${modelName} is invalid.`,
        "DOWNLOAD_MODEL_FAILED",
        {},
        "Downloaded blob for model is invalid."
      );
    }
    console.log(`Model ${modelName} downloaded successfully. Blob size: ${modelBlob.size} bytes, Type: ${modelBlob.type}`);
    return modelBlob;
  }

  /**
   * Lists all models in the cloud zoo that can run on the cloud system.
   * Waits (up to RESCAN_TIMEOUT_MS) for the current zoo rescan to finish.
   * @example
   * let models = await zoo.listModels();
   * let modelNames = Object.keys(models);
   * @returns {Promise<Object>} An object containing the available models and their params.
   * @throws {DGError} LIST_MODELS_FAILED on timeout.
   */
  async listModels() {
    try {
      await waitUntil(() => this.isRescanComplete, RESCAN_TIMEOUT_MS);
    } catch (error) {
      throw new DGError('cloudservercloudzoo.js: listModels() timed out.', "LIST_MODELS_FAILED", {}, "Timeout occurred while listing models.");
    }
    return this.assets;
  }

  /**
   * Fetches the list of zoos for a given organization.
   * @param {string} organization - The name of the organization.
   * @returns {Promise<Object>} The list of zoos.
   * @throws {DGError} GET_ZOOS_FAILED.
   */
  async getZoos(organization) {
    if (!organization) {
      throw new DGError('Organization name is required to get zoos.', "GET_ZOOS_FAILED", {}, "Organization name is required to get zoos.");
    }
    try {
      // Encode so the organization name cannot alter the request path.
      return await this.fetchWithToken(`/zoos/${encodeURIComponent(organization)}`);
    } catch (error) {
      throw new DGError(`Error fetching zoos for organization ${organization}. Error: ${error.message}`, "GET_ZOOS_FAILED", {}, "Error occurred while fetching zoos for organization.");
    }
  }

  /**
   * Loads a model from the cloud zoo and prepares it for inference.
   * @example
   * let model = await zoo.loadModel('someModel', { inputCropPercentage: 0.5, callback: someFunction } );
   * @param {string} modelName - The name of the model to load.
   * @param {Object} [options={}] - Additional options for loading the model; options other than
   *   max_q_len and callback are passed through to CloudServerModel.
   * @param {number} [options.max_q_len=10] - The maximum length of the internal inference queue.
   * @param {Function} [options.callback=null] - A callback function to handle inference results.
   * @returns {Promise<CloudServerModel>} The loaded model instance.
   * @throws {DGError} INVALID_INPUT, RESCAN_MODELS_FAILED, MODEL_NOT_FOUND, LOAD_MODEL_FAILED, or
   *   errors from fetching labels / system device types.
   */
  async loadModel(modelName, options = {}) {
    const startTime = performance.now();
    const {
      max_q_len = 10,
      callback = null,
      ...additionalParams
    } = options ?? {};
    if (!Number.isInteger(max_q_len) || max_q_len <= 0) {
      throw new DGError("Invalid value for max_q_len: It should be a positive integer.", "INVALID_INPUT", { parameter: "max_q_len", value: max_q_len }, "Please ensure max_q_len is a positive integer.");
    }
    if (callback !== null && typeof callback !== "function") {
      throw new DGError("Invalid value for callback: It should be a function.", "INVALID_INPUT", { parameter: "callback", value: callback }, "Please ensure callback is a function.");
    }
    try {
      await waitUntil(() => this.isRescanComplete, RESCAN_TIMEOUT_MS);
    } catch (error) {
      throw new DGError('cloudservercloudzoo.js: rescan of models timed out within loadModel().', "RESCAN_MODELS_FAILED", {}, "Timeout occurred while waiting for models to be fetched.");
    }
    // Own-property check, so names like "constructor" do not resolve to Object.prototype members.
    const modelParams = Object.prototype.hasOwnProperty.call(this.assets, modelName) ? this.assets[modelName] : undefined;
    if (!modelParams) {
      throw new DGError(`Model not found in the cloud zoo: ${modelName}. List of supported models: ${Object.keys(this.assets)}`, "MODEL_NOT_FOUND", { modelName: modelName }, "Model not found in the cloud zoo.");
    }
    const extendedModelName = this.getModelFullName(modelName);
    console.log(`Loading model: ${modelName}, queue length: ${max_q_len}, callback specified:`, callback !== null);
    // Labels and the system device list ('RUNTIME/DEVICE' strings) are independent requests.
    const [labels, systemDeviceTypes] = await Promise.all([
      this.getModelLabels(modelName),
      this.systemSupportedDeviceTypes(),
    ]);
    // Deep copy so the model cannot mutate the zoo's cached parameters.
    const deepCopiedModelParams = JSON.parse(JSON.stringify(modelParams));
    // additionalParams.cloudURL = this.fullUrl; // Uncomment to include cloudURL in the socketio init structure.
    // additionalParams.cloudToken = this.accessToken;
    const modelCreationParams = {
      modelName: extendedModelName,
      modelParams: deepCopiedModelParams,
      max_q_len,
      callback,
      labels,
      serverUrl: this.url,
      token: this.accessToken,
      systemDeviceTypes,
    };
    const model = new CloudServerModel(modelCreationParams, additionalParams);
    try {
      await waitUntil(() => model.initialized, MODEL_INIT_TIMEOUT_MS);
    } catch (error) {
      throw new DGError('cloudservercloudzoo: Timeout occurred while waiting for the model to initialize!', "LOAD_MODEL_FAILED", {}, "Timeout occurred while waiting for the model to initialize.", "Check the connection to the cloud server / internet.");
    }
    console.log('cloudservercloudzoo.js: loadModel() took', performance.now() - startTime, 'ms');
    return model;
  }

  /**
   * Fetches system information from the cloud server
   * (https://cs.degirum.com/devices/api/v1/public/system-info on the default server).
   * @returns {Promise<Object>} The system information.
   * @throws {DGError} SYSTEM_INFO_FAILED.
   */
  async systemInfo() {
    try {
      return await this.fetchWithToken('/devices/api/v1/public/system-info', false, true);
    } catch (error) {
      throw new DGError('Error fetching system information: ' + error, "SYSTEM_INFO_FAILED", {}, "Error occurred while fetching system info.");
    }
  }

  /**
   * Fetches the device types supported by the cloud server.
   * Dummy instances query the server too: they exist precisely to make system-info calls.
   * @private
   * @returns {Promise<Array<string>>} A promise that resolves to an array of strings in "RUNTIME/DEVICE" format.
   * @throws {DGError} SYSTEM_INFO_FAILED when the info cannot be fetched or lists no devices.
   */
  async systemSupportedDeviceTypes() {
    // private function, to be called by dg_sdk() instead.
    try {
      let systemInfo;
      try {
        systemInfo = await this.systemInfo();
      } catch (error) {
        throw new DGError('Error fetching system information for cloud:' + error, "FETCH_MODELS_FAILED", {}, "Error fetching sys info.");
      }
      if (!systemInfo?.Devices) {
        throw new DGError('No devices found in system info.', "FETCH_MODELS_FAILED", {}, "No devices found in system info.");
      }
      return Object.keys(systemInfo.Devices);
    } catch (error) {
      throw new DGError('systemSupportedDeviceTypes failed: ' + error, "SYSTEM_INFO_FAILED", {}, "Error occurred while fetching system info.");
    }
  }
}
/**
 * Polls `predicate` until it returns a truthy value or the timeout elapses.
 * @private
 * @param {Function} predicate - Zero-argument function; a truthy return value ends the wait.
 * @param {number} [timeout=10000] - The maximum time to wait for the condition to be met, in milliseconds.
 * @param {number} [interval=10] - The interval at which to check the condition, in milliseconds.
 * @returns {Promise<void>} Resolves once the predicate is truthy. Rejects with
 *   Error('Timed out waiting for condition') on timeout, or with whatever the predicate throws.
 */
function waitUntil(predicate, timeout = 10000, interval = 10) {
  const startTime = Date.now();
  return new Promise((resolve, reject) => {
    const checkCondition = () => {
      let isMet;
      try {
        isMet = predicate();
      } catch (error) {
        // A throw inside a setTimeout callback would otherwise escape as an uncaught
        // exception and leave the returned promise pending forever.
        reject(error);
        return;
      }
      if (isMet) {
        resolve();
      } else if (Date.now() - startTime > timeout) {
        reject(new Error('Timed out waiting for condition'));
      } else {
        setTimeout(checkCondition, interval);
      }
    };
    checkCondition();
  });
}
export default CloudServerCloudZoo;