# Source code for omnigenbench.cli.omnigenome_cli

# -*- coding: utf-8 -*-
# file: omnigenome_cli.py
# time: 12:51 05/02/2025
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
# Homepage: https://yangheng95.github.io
# github: https://github.com/yangheng95
# huggingface: https://huggingface.co/yangheng
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
# Copyright (C) 2019-2025. All Rights Reserved.
import argparse
import json
import warnings

# Suppress all warnings (any category, from any module) for cleaner CLI output.
# NOTE(review): this executes at import time, so it also silences warnings in
# any program that merely imports this module, not only in CLI invocations.
warnings.filterwarnings("ignore")


def _build_parser():
    """Create the argument parser for the ``rna_design`` and ``autoinfer`` sub-commands.

    Returns:
        argparse.ArgumentParser: Parser whose ``command`` attribute names the
        selected sub-command (a sub-command is required).
    """
    parser = argparse.ArgumentParser(description="OmniGenome CLI")
    subparsers = parser.add_subparsers(
        dest="command", required=True, help="Sub-command help"
    )

    # Design command
    design_parser = subparsers.add_parser(
        "rna_design", help="Design RNA sequences for a given secondary structure"
    )
    design_parser.add_argument(
        "--structure",
        type=str,
        required=True,
        help='Target RNA structure in dot-bracket notation (e.g., "(((...)))")',
    )
    design_parser.add_argument(
        "--model",
        type=str,
        default="yangheng/OmniGenome-186M",
        help="Path to the pre-trained model (default: yangheng/OmniGenome-186M)",
    )
    design_parser.add_argument(
        "--mutation-ratio",
        type=float,
        default=0.5,
        help="Mutation ratio for genetic algorithm (0.0-1.0, default: 0.5)",
    )
    design_parser.add_argument(
        "--num-population",
        type=int,
        default=100,
        help="Number of individuals in population (default: 100)",
    )
    design_parser.add_argument(
        "--num-generation",
        type=int,
        default=100,
        help="Number of generations to evolve (default: 100)",
    )
    design_parser.add_argument(
        "--output-file", type=str, help="Output JSON file to save results"
    )

    # AutoInfer command
    infer_parser = subparsers.add_parser(
        "autoinfer", help="Run inference with a fine-tuned model"
    )
    infer_parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Path or name of the fine-tuned model (e.g., yangheng/ogb_tfb_finetuned)",
    )
    infer_parser.add_argument(
        "--sequence",
        type=str,
        help="Input sequence(s): a single sequence, comma-separated sequences, "
        "or path to a .txt file with one sequence per line "
        "(takes precedence over --input-file)",
    )
    infer_parser.add_argument(
        "--input-file",
        type=str,
        help="Path to JSON/CSV/TXT file with input data. JSON format: "
        "{'sequences': [...]} or {'data': [{'sequence': ..., ...}]}; "
        "CSV needs a 'sequence' column; TXT holds one sequence per line",
    )
    infer_parser.add_argument(
        "--output-file",
        type=str,
        default="inference_results.json",
        help="Output file to save predictions (default: inference_results.json)",
    )
    infer_parser.add_argument(
        "--batch-size",
        type=int,
        default=32,
        help="Batch size for inference (default: 32)",
    )
    infer_parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Device to run inference on (e.g., 'cuda:0', 'cpu'). Auto-detected if not specified",
    )
    return parser


def _run_rna_design(args):
    """Handle ``rna_design``: validate arguments, design sequences, print/save them.

    Raises:
        ValueError: If ``--mutation-ratio`` is outside [0, 1] or the population
            or generation count is not positive.
    """
    # Validate before the heavy model import so bad arguments fail immediately.
    if not 0 <= args.mutation_ratio <= 1:
        raise ValueError("--mutation-ratio must be between 0.0 and 1.0")
    if args.num_population <= 0 or args.num_generation <= 0:
        raise ValueError(
            "Population and generation numbers must be positive integers"
        )

    from omnigenbench import OmniModelForRNADesign

    model = OmniModelForRNADesign(model=args.model)
    best_sequences = model.design(
        structure=args.structure,
        mutation_ratio=args.mutation_ratio,
        num_population=args.num_population,
        num_generation=args.num_generation,
    )

    print(f"Best RNA sequences for {args.structure}:")
    if not isinstance(best_sequences, list):
        # Fallback for a single (non-list) result so printing and JSON output
        # are handled uniformly below.
        best_sequences = [best_sequences]
    for seq in best_sequences:
        print(f"- {seq}")

    if args.output_file:
        with open(args.output_file, "w", encoding="utf-8") as f:
            json.dump(
                {
                    "structure": args.structure,
                    "parameters": {
                        "model": args.model,
                        "mutation_ratio": args.mutation_ratio,
                        "num_population": args.num_population,
                        "num_generation": args.num_generation,
                    },
                    "best_sequences": best_sequences,
                },
                f,
                indent=2,
            )
        print(f"\nResults saved to {args.output_file}")


def _indexed_metadata(sequences):
    """Return one ``{"index": i}`` metadata dict per sequence."""
    return [{"index": i} for i in range(len(sequences))]


def _read_sequence_lines(path):
    """Return the stripped, non-blank lines of a UTF-8 text file (one sequence per line)."""
    with open(path, "r", encoding="utf-8") as f:
        return [line.strip() for line in f if line.strip()]


def _load_json_input(path):
    """Parse a JSON input file into ``(sequences, metadata)``.

    Accepted layouts: a top-level list of sequences, ``{"sequences": [...]}``,
    or ``{"data": [{"sequence": ..., <other fields>}, ...]}`` where the other
    fields become that sequence's metadata.

    Raises:
        ValueError: If the layout is not one of the above.
    """
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)

    if isinstance(data, list):
        return data, _indexed_metadata(data)
    if not isinstance(data, dict):
        raise ValueError(
            "JSON file must be an object with 'sequences' or 'data' key, "
            "or a list of sequences"
        )
    if "sequences" in data:
        sequences = data["sequences"]
        return sequences, _indexed_metadata(sequences)
    if "data" in data:
        sequences, metadata = [], []
        for item in data["data"]:
            if not isinstance(item, dict) or "sequence" not in item:
                raise ValueError(
                    "Each entry of 'data' must be an object with a 'sequence' key"
                )
            sequences.append(item["sequence"])
            metadata.append({k: v for k, v in item.items() if k != "sequence"})
        return sequences, metadata
    raise ValueError("JSON file must contain 'sequences' or 'data' key")


def _load_sequences(args):
    """Collect ``(sequences, metadata)`` from ``--sequence`` or ``--input-file``.

    ``--sequence`` wins when both are given. It is either a path ending in
    ``.txt`` (one sequence per line) or a comma-separated list of sequences.
    ``--input-file`` may be ``.json``, ``.csv`` or ``.txt`` (case-insensitive).

    Raises:
        ValueError: For an unsupported extension or malformed file contents.
    """
    if args.sequence:
        if args.sequence.endswith(".txt"):
            sequences = _read_sequence_lines(args.sequence)
        else:
            sequences = [s.strip() for s in args.sequence.split(",") if s.strip()]
        return sequences, _indexed_metadata(sequences)

    path = args.input_file
    extension = path.lower()
    if extension.endswith(".json"):
        return _load_json_input(path)
    if extension.endswith(".csv"):
        import pandas as pd  # only needed for CSV input

        df = pd.read_csv(path)
        if "sequence" not in df.columns:
            raise ValueError("CSV file must have a 'sequence' column")
        return (
            df["sequence"].tolist(),
            df.drop(columns=["sequence"]).to_dict("records"),
        )
    if extension.endswith(".txt"):
        sequences = _read_sequence_lines(path)
        return sequences, _indexed_metadata(sequences)
    raise ValueError("Input file must be .json, .csv, or .txt format")


def _to_serializable(value):
    """Return ``value.tolist()`` for array-likes (anything with ``tolist``), else ``value``."""
    return value.tolist() if hasattr(value, "tolist") else value


def _format_output(seq, meta, output):
    """Build one JSON-ready result record from a model's ``inference`` output.

    Dict outputs are flattened into the record. ``predictions``,
    ``probabilities`` and ``logits`` are placed first, followed by the
    remaining keys in the model's order. Any other output is stored under
    ``"output"``.
    """
    primary_keys = ("predictions", "probabilities", "logits")
    result = {"sequence": seq, "metadata": meta}
    if isinstance(output, dict):
        for key in primary_keys:
            if key in output:
                result[key] = _to_serializable(output[key])
        for key, value in output.items():
            if key not in primary_keys:
                result[key] = _to_serializable(value)
    else:
        result["output"] = _to_serializable(output)
    return result


def _run_inference(model, sequences, metadata, batch_size):
    """Run ``model.inference`` on every sequence and return one record per sequence.

    Failures are best-effort: an exception for one sequence is reported and
    stored as an ``"error"`` record instead of aborting the whole run.
    NOTE: sequences are inferred one at a time; ``batch_size`` only groups
    them for progress reporting.
    """
    results = []
    total = len(sequences)
    num_batches = (total + batch_size - 1) // batch_size
    for start in range(0, total, batch_size):
        print(f"🔄 Inferring batch {start // batch_size + 1}/{num_batches}...")
        stop = start + batch_size
        batch = zip(sequences[start:stop], metadata[start:stop])
        # Track the global position so error messages point at the right input.
        for index, (seq, meta) in enumerate(batch, start=start):
            try:
                results.append(_format_output(seq, meta, model.inference(seq)))
            except Exception as e:
                print(f"⚠️ Error processing sequence {meta.get('index', index)}: {e}")
                results.append({"sequence": seq, "metadata": meta, "error": str(e)})
    return results


def _run_autoinfer(args):
    """Handle ``autoinfer``: load inputs and model, predict, and write a JSON report.

    Raises:
        ValueError: If no input source is given, ``--batch-size`` is not
            positive, or the input is malformed / has an unsupported extension.
    """
    if not args.sequence and not args.input_file:
        raise ValueError(
            "Either --sequence or --input-file must be provided for inference"
        )
    if args.batch_size <= 0:
        raise ValueError("--batch-size must be a positive integer")

    # Parse inputs before loading the model so malformed files fail fast.
    sequences, metadata = _load_sequences(args)

    from omnigenbench import ModelHub

    print(f"🔄 Loading model from: {args.model}")
    model = ModelHub.load(args.model, device=args.device)
    model.eval()
    print(f"✅ Model loaded successfully on device: {model.device}")

    print(f"📊 Processing {len(sequences)} sequence(s)...")
    results = _run_inference(model, sequences, metadata, args.batch_size)

    output_data = {
        "model": args.model,
        "total_sequences": len(sequences),
        "results": results,
    }
    with open(args.output_file, "w", encoding="utf-8") as f:
        json.dump(output_data, f, indent=2)

    succeeded = sum("error" not in r for r in results)
    print("✅ Inference completed!")
    print(f"📁 Results saved to: {args.output_file}")
    print(f"📊 Successfully processed: {succeeded}/{len(sequences)} sequences")


def main():
    """The main entry point for the OmniGenome command-line interface.

    Parses ``sys.argv`` and dispatches to the selected sub-command:

    * ``rna_design`` -- design RNA sequences for a dot-bracket structure with
      ``OmniModelForRNADesign`` and optionally save them as JSON.
    * ``autoinfer`` -- run a fine-tuned model loaded via ``ModelHub`` over
      sequences from the command line or a JSON/CSV/TXT file, and write the
      predictions as JSON.

    Raises:
        ValueError: If argument values are out of range, or the inference
            input is missing or malformed.

    Example::

        # Design RNA sequences from the command line
        python -m omnigenbench.cli.omnigenome_cli rna_design --structure "(((...)))"

        # Run model inference
        python -m omnigenbench.cli.omnigenome_cli autoinfer --model "yangheng/ogb_tfb_finetuned" --sequence "ATCGATCGATCG"
    """
    args = _build_parser().parse_args()
    if args.command == "rna_design":
        _run_rna_design(args)
    elif args.command == "autoinfer":
        _run_autoinfer(args)


if __name__ == "__main__":
    main()