// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
import {
  ModelExecutionError,
  ModelLoader as ModelLoaderBase,
  ModelLoadError,
} from "../../core/on_device_model/types.js";
import {signal} from "../../core/reactive/signal.js";
import {
  assertExhaustive,
  assertExists,
  assertNotReached,
} from "../../core/utils/assert.js";
import {chunkContentByWord} from "../../core/utils/utils.js";
import {
  isCannedResponse,
  isInvalidFormatResponse,
  parseResponse,
  trimRepeatedBulletPoints,
} from "./on_device_model_utils.js";
import {
  FormatFeature,
  LoadModelResult as MojoLoadModelResult,
  ModelStateMonitorReceiver,
  ModelStateType,
  OnDeviceModelRemote,
  SafetyFeature,
  SessionRemote,
  StreamingResponderCallbackRouter,
} from "./types.js";

// Inputs whose token count is below this value are rejected by `executeRaw`
// with UNSUPPORTED_TRANSCRIPTION_IS_TOO_SHORT.
const MIN_TOKEN_LENGTH = 200;

// Chunk size (in words) passed to `chunkContentByWord` before each chunk is
// sent to the text safety classifier.
const MAX_TS_MODEL_INPUT_WORD_LENGTH = 700;

// Parameters forwarded to `trimRepeatedBulletPoints` when post-processing the
// model output.
const MODEL_REPETITION_CONFIG = {
  maxLength: 1000,
  lcsScoreThreshold: 0.9,
};

/**
 * Base class wrapping a loaded on-device model remote.
 *
 * Subclasses must implement
 * `executeInRemoteSession(content, language, session)`, which is invoked by
 * `execute` with a freshly started session.
 */
class OnDeviceModel {
  /**
   * @param remote Remote of the loaded model; used to start sessions and to
   *     classify text safety, and closed by `close()`.
   * @param pageRemote Remote used to format model input and to validate
   *     safety results.
   * @param modelInfo Model metadata; this file reads `modelId`,
   *     `inputTokenLimit` and `isLargeModel` from it.
   */
  constructor(remote, pageRemote, modelInfo) {
    this.remote = remote;
    this.pageRemote = pageRemote;
    this.modelInfo = modelInfo;
  }

  /**
   * Starts a new session on the model, runs the subclass-specific execution
   * in it, and closes the session afterwards.
   *
   * The session is closed in a `finally` block so that it is released even
   * when the execution rejects.
   *
   * @param content Text to run the model on.
   * @param language Language of `content`.
   * @return The result of `executeInRemoteSession`.
   */
  async execute(content, language) {
    const session = new SessionRemote();
    const receiver = session.$.bindNewPipeAndPassReceiver();
    try {
      this.remote.startSession(receiver, null);
      return await this.executeInRemoteSession(content, language, session);
    } finally {
      session.$.close();
    }
  }

  /**
   * Asks the session for the size of `text` in tokens.
   *
   * @param text Text to measure.
   * @param session Session used for the measurement.
   * @return The `size` reported by `session.getSizeInTokens`.
   */
  async getInputTokenSize(text, session) {
    const inputPieces = {pieces: [{text}]};
    const {size} = await session.getSizeInTokens(inputPieces);
    return size;
  }

  /**
   * Feeds `text` to the session, collects the streamed response and
   * post-processes it into newline-joined bullet points.
   *
   * Returns an error result instead of running the model when the token size
   * of `text` is below `MIN_TOKEN_LENGTH` or above
   * `modelInfo.inputTokenLimit`. Returns `ModelExecutionError.UNSAFE` when the
   * response is a canned response, has an invalid format, or has no bullet
   * point left after trimming repeated ones.
   *
   * @param text Fully formatted prompt.
   * @param session Session to run the prompt in.
   * @param expectedBulletPointCount Passed to `isInvalidFormatResponse`.
   * @param language Passed to `trimRepeatedBulletPoints`.
   * @return `{kind: "success", result}` or `{kind: "error", error}`.
   */
  async executeRaw(text, session, expectedBulletPointCount, language) {
    const inputPieces = {pieces: [{text}]};
    const size = await this.getInputTokenSize(text, session);
    if (size < MIN_TOKEN_LENGTH) {
      console.warn(`Skip GenAI model execution: too small token size: ${size}.`);
      return {
        kind: "error",
        error: ModelExecutionError.UNSUPPORTED_TRANSCRIPTION_IS_TOO_SHORT,
      };
    }
    if (size > this.modelInfo.inputTokenLimit) {
      console.warn(`Skip GenAI model execution: too large token size: ${size}.`);
      return {
        kind: "error",
        error: ModelExecutionError.UNSUPPORTED_TRANSCRIPTION_IS_TOO_LONG,
      };
    }

    const responseRouter = new StreamingResponderCallbackRouter();
    const {promise, resolve} = Promise.withResolvers();
    const response = [];
    const onResponseId = responseRouter.onResponse.addListener((chunk) => {
      response.push(chunk.text);
    });
    const onCompleteId = responseRouter.onComplete.addListener((_) => {
      // Streaming is done: detach both listeners, close the router and hand
      // the concatenated text to the awaiting caller.
      responseRouter.removeListener(onResponseId);
      responseRouter.removeListener(onCompleteId);
      responseRouter.$.close();
      resolve(response.join("").trimStart());
    });

    // NOTE(review): the 0 values for maxTokens / maxOutputTokens presumably
    // mean "no explicit limit" -- confirm against the mojo interface.
    session.append({maxTokens: 0, input: inputPieces}, null);
    session.generate(
      {maxOutputTokens: 0, constraint: null},
      responseRouter.$.bindNewPipeAndPassRemote(),
    );
    const result = await promise;

    if (isCannedResponse(result)) {
      console.warn("Invalid GenAI result: canned response.");
      return {kind: "error", error: ModelExecutionError.UNSAFE};
    }

    const parsedResult = parseResponse(result);
    if (isInvalidFormatResponse(parsedResult, expectedBulletPointCount)) {
      console.warn("Invalid GenAI result: invalid format.");
      return {kind: "error", error: ModelExecutionError.UNSAFE};
    }

    const finalBulletPoints = trimRepeatedBulletPoints(
      parsedResult,
      MODEL_REPETITION_CONFIG.maxLength,
      language,
      MODEL_REPETITION_CONFIG.lcsScoreThreshold,
    );
    if (finalBulletPoints.length === 0) {
      console.warn("Invalid GenAI result: no valid bullet point.");
      return {kind: "error", error: ModelExecutionError.UNSAFE};
    }

    return {kind: "success", result: finalBulletPoints.join("\n")};
  }

  /**
   * Checks `content` chunk by chunk against the text safety classifier.
   *
   * Chunks for which the classifier returns no safety info (`null`) are
   * skipped. Chunks are checked sequentially and the check stops at the first
   * chunk that fails validation.
   *
   * @param content Text to check.
   * @param safetyFeature Safety feature passed to `validateSafetyResult`.
   * @param language Passed to `chunkContentByWord`.
   * @return True if any chunk is reported as not safe.
   */
  async contentIsUnsafe(content, safetyFeature, language) {
    const contentChunks = chunkContentByWord(
      content,
      MAX_TS_MODEL_INPUT_WORD_LENGTH,
      language,
    );
    for (const chunk of contentChunks) {
      const {safetyInfo} = await this.remote.classifyTextSafety(chunk);
      if (safetyInfo === null) {
        continue;
      }
      const {isSafe} = await this.pageRemote.validateSafetyResult(
        safetyFeature,
        chunk,
        safetyInfo,
      );
      if (!isSafe) {
        return true;
      }
    }
    return false;
  }

  /** Closes the underlying model remote. */
  close() {
    this.remote.$.close();
  }

  /**
   * Formats `fields` into a model prompt for `feature`.
   *
   * @param feature Format feature to use.
   * @param fields Key/value fields to substitute into the prompt.
   * @return The `result` of `formatModelInput`; callers treat `null` as a
   *     formatting failure.
   */
  async formatInput(feature, fields) {
    const {result} = await this.pageRemote.formatModelInput(
      this.modelInfo.modelId,
      feature,
      fields,
    );
    return result;
  }

  /**
   * Formats the prompt, runs the safety check on it, executes the model and
   * runs the safety check on the response.
   *
   * @param formatFeature Feature used to format the prompt.
   * @param requestSafetyFeature Safety feature for checking the prompt.
   * @param responseSafetyFeature Safety feature for checking the response.
   * @param fields Fields passed to `formatInput`.
   * @param session Session to execute in.
   * @param language Language of the content.
   * @param expectedBulletPointCount Forwarded to `executeRaw`.
   * @return `{kind: "success", result}` or `{kind: "error", error}`; the error
   *     is GENERAL when formatting fails and UNSAFE when either safety check
   *     fails, otherwise whatever `executeRaw` reported.
   */
  async formatAndExecute(
    formatFeature,
    requestSafetyFeature,
    responseSafetyFeature,
    fields,
    session,
    language,
    expectedBulletPointCount,
  ) {
    const prompt = await this.formatInput(formatFeature, fields);
    if (prompt === null) {
      console.error("formatInput returns null, wrong model?");
      return {kind: "error", error: ModelExecutionError.GENERAL};
    }

    if (await this.contentIsUnsafe(prompt, requestSafetyFeature, language)) {
      console.warn("Unsafe GenAI prompt.");
      return {kind: "error", error: ModelExecutionError.UNSAFE};
    }

    const response = await this.executeRaw(
      prompt,
      session,
      expectedBulletPointCount,
      language,
    );
    if (response.kind === "error") {
      return response;
    }

    if (
      await this.contentIsUnsafe(
        response.result,
        responseSafetyFeature,
        language,
      )
    ) {
      console.warn("Unsafe GenAI result.");
      return {kind: "error", error: ModelExecutionError.UNSAFE};
    }
    return {kind: "success", result: response.result};
  }
}

/**
 * Model that produces a bullet-point summary of a transcription.
 */
export class SummaryModel extends OnDeviceModel {
  /**
   * Runs the summary prompt in `session`.
   *
   * The number of requested bullet points is derived from the token size of
   * the raw `content` (not of the formatted prompt).
   *
   * @param content Transcription to summarize.
   * @param language Language of the transcription.
   * @param session Session to execute in.
   * @return `{kind: "success", result}` where `result` is the summary text, or
   *     `{kind: "error", error}`.
   */
  async executeInRemoteSession(content, language, session) {
    const inputTokenSize = await this.getInputTokenSize(content, session);
    const safetyFeatureOnResponse = this.modelInfo.isLargeModel ?
      SafetyFeature.kAudioSummaryResponseV2 :
      SafetyFeature.kAudioSummaryResponse;
    const expectedBulletPointCount =
      this.getExpectedBulletPoints(inputTokenSize);
    const bulletPointsRequest =
      this.formatBulletPointRequest(expectedBulletPointCount);

    const resp = await this.formatAndExecute(
      FormatFeature.kAudioSummary,
      SafetyFeature.kAudioSummaryRequest,
      safetyFeatureOnResponse,
      {
        transcription: content,
        language,
        // Key name is part of the prompt template contract; keep as-is.
        bullet_points_request: bulletPointsRequest,
      },
      session,
      language,
      expectedBulletPointCount,
    );
    if (resp.kind === "error") {
      return resp;
    }
    return {kind: "success", result: resp.result};
  }

  /**
   * Maps the input token size to the number of bullet points to request.
   *
   * Models that are not large always use 3. For large models the count grows
   * from 1 to 6 with the thresholds below.
   *
   * @param inputTokenSize Token size of the transcription.
   * @return Number of bullet points to request.
   */
  getExpectedBulletPoints(inputTokenSize) {
    if (!this.modelInfo.isLargeModel) {
      return 3;
    }
    if (inputTokenSize < 250) {
      return 1;
    }
    if (inputTokenSize < 600) {
      return 2;
    }
    if (inputTokenSize < 4000) {
      return 3;
    }
    if (inputTokenSize < 6600) {
      return 4;
    }
    if (inputTokenSize < 9300) {
      return 5;
    }
    return 6;
  }

  /**
   * Builds the human-readable bullet point request, e.g. "1 bullet point" or
   * "3 bullet points".
   *
   * @param request Number of bullet points; must be positive.
   * @return The request string with correct pluralization.
   */
  formatBulletPointRequest(request) {
    if (request <= 0) {
      assertNotReached("Got non-positive bullet point request.");
    }
    const pluralSuffix = request > 1 ? "s" : "";
    return `${request} bullet point${pluralSuffix}`;
  }
}

/**
 * Model that suggests titles for a transcription.
 */
export class TitleSuggestionModel extends OnDeviceModel {
  /**
   * Runs the title prompt in `session` and extracts the suggested titles.
   *
   * Only response lines starting with "- " are treated as titles; at most
   * three are returned.
   *
   * @param content Transcription to suggest titles for.
   * @param language Language of the transcription.
   * @param session Session to execute in.
   * @return `{kind: "success", result}` where `result` is an array of titles,
   *     or `{kind: "error", error}`.
   */
  async executeInRemoteSession(content, language, session) {
    const safetyFeatureOnResponse = this.modelInfo.isLargeModel ?
      SafetyFeature.kAudioTitleResponseV2 :
      SafetyFeature.kAudioTitleResponse;
    const resp = await this.formatAndExecute(
      FormatFeature.kAudioTitle,
      SafetyFeature.kAudioTitleRequest,
      safetyFeatureOnResponse,
      {transcription: content, language},
      session,
      language,
      3,
    );
    if (resp.kind === "error") {
      return resp;
    }

    const lineStart = "- ";
    const titles = [];
    for (const line of resp.result.split("\n")) {
      if (line.startsWith(lineStart)) {
        titles.push(line.substring(lineStart.length));
      }
    }
    return {kind: "success", result: titles.slice(0, 3)};
  }
}

/**
 * Converts a mojo model state into the `{kind: ...}` state object used by the
 * rest of the app.
 *
 * @param state Mojo state with a `type` and, while installing, a `progress`.
 * @return The converted state object.
 */
export function mojoModelStateToModelState(state) {
  switch (state.type) {
    case ModelStateType.kNotInstalled:
      return {kind: "notInstalled"};
    case ModelStateType.kInstalling:
      return {kind: "installing", progress: assertExists(state.progress)};
    case ModelStateType.kInstalled:
      return {kind: "installed"};
    case ModelStateType.kNeedsReboot:
      return {kind: "needsReboot"};
    case ModelStateType.kError:
      return {kind: "error"};
    case ModelStateType.kUnavailable:
      return {kind: "unavailable"};
    case ModelStateType.MIN_VALUE:
    case ModelStateType.MAX_VALUE:
      return assertNotReached(
        `Got MIN_VALUE or MAX_VALUE from mojo ModelStateType: ${state.type}`,
      );
    default:
      assertExhaustive(state.type);
  }
}

/**
 * Converts a non-success mojo load result into a `ModelLoadError`.
 *
 * Must not be called with `kSuccess`.
 *
 * @param result Mojo load result.
 * @return The corresponding `ModelLoadError`.
 */
function mojoLoadModelResultToModelLoadError(result) {
  switch (result) {
    case MojoLoadModelResult.kFailedToLoadLibrary:
    case MojoLoadModelResult.kGpuBlocked:
      return ModelLoadError.LOAD_FAILURE;
    case MojoLoadModelResult.kCrosNeedReboot:
      return ModelLoadError.NEEDS_REBOOT;
    case MojoLoadModelResult.kSuccess:
      return assertNotReached(`Try transforming success load result to error`);
    case MojoLoadModelResult.MIN_VALUE:
    case MojoLoadModelResult.MAX_VALUE:
      return assertNotReached(
        `Got MIN_VALUE or MAX_VALUE from mojo LoadModelResult: ${result}`,
      );
    default:
      assertExhaustive(result);
  }
}

/**
 * Shared loader logic for on-device models.
 *
 * Subclasses must set `featureType` and implement `createModel(remote)`.
 */
class ModelLoader extends ModelLoaderBase {
  /**
   * @param remote Remote used to query model info, monitor model state and
   *     load the model.
   * @param platformHandler Handler used to look up language pack info.
   */
  constructor(remote, platformHandler) {
    super();
    this.remote = remote;
    this.platformHandler = platformHandler;
    // Reactive model state; starts as "unavailable" until `init` runs.
    this.state = signal({kind: "unavailable"});
    this.modelInfoInternal = null;
  }

  /** Model info fetched by `init`; asserts that it has been set. */
  get modelInfo() {
    return assertExists(this.modelInfoInternal);
  }

  /**
   * Fetches the model info for `featureType` and subscribes to model state
   * updates, mirroring them into `this.state`.
   */
  async init() {
    const {modelInfo} = await this.remote.getModelInfo(this.featureType);
    this.modelInfoInternal = modelInfo;

    const update = (state) => {
      this.state.value = mojoModelStateToModelState(state);
    };
    const monitor = new ModelStateMonitorReceiver({update});
    const {state} = await this.remote.addModelMonitor(
      this.modelInfo.modelId,
      monitor.$.bindNewPipeAndPassRemote(),
    );
    // Apply the initial state returned when registering the monitor.
    update(state);
  }

  /**
   * Loads the model.
   *
   * @return `{kind: "success", model}` with the model built by `createModel`,
   *     or `{kind: "error", error}` with a `ModelLoadError`.
   */
  async load() {
    const newModel = new OnDeviceModelRemote();
    const {result} = await this.remote.loadModel(
      this.modelInfo.modelId,
      newModel.$.bindNewPipeAndPassReceiver(),
    );
    if (result !== MojoLoadModelResult.kSuccess) {
      console.error("Load model failed:", result);
      return {
        kind: "error",
        error: mojoLoadModelResultToModelLoadError(result),
      };
    }
    return {kind: "success", model: this.createModel(newModel)};
  }

  /**
   * Loads the model, executes it on `content`, and closes the model.
   *
   * Returns UNSUPPORTED_LANGUAGE without loading anything when the language
   * pack info reports GenAI as unsupported for `language`.
   *
   * @param content Text to run the model on.
   * @param language Language of `content`.
   * @return The execution result, or the load / language error.
   */
  async loadAndExecute(content, language) {
    const langPackInfo = this.platformHandler.getLangPackInfo(language);
    if (!langPackInfo.isGenAiSupported) {
      console.warn(
        `Skip GenAI model execution: unsupported language: ${language}.`,
      );
      return {kind: "error", error: ModelExecutionError.UNSUPPORTED_LANGUAGE};
    }

    const loadResult = await this.load();
    if (loadResult.kind === "error") {
      return loadResult;
    }

    try {
      return await loadResult.model.execute(content, language);
    } finally {
      loadResult.model.close();
    }
  }
}

/** Loader for `SummaryModel`. */
export class SummaryModelLoader extends ModelLoader {
  constructor(...args) {
    super(...args);
    this.featureType = FormatFeature.kAudioSummary;
  }

  /**
   * @param remote Remote of the loaded model.
   * @return A `SummaryModel` bound to `remote`.
   */
  createModel(remote) {
    return new SummaryModel(remote, this.remote, this.modelInfo);
  }
}

/** Loader for `TitleSuggestionModel`. */
export class TitleSuggestionModelLoader extends ModelLoader {
  constructor(...args) {
    super(...args);
    this.featureType = FormatFeature.kAudioTitle;
  }

  /**
   * @param remote Remote of the loaded model.
   * @return A `TitleSuggestionModel` bound to `remote`.
   */
  createModel(remote) {
    return new TitleSuggestionModel(remote, this.remote, this.modelInfo);
  }
}