Mosaic-glasses committed
Commit e11d3d5 · verified · 1 Parent(s): 04f46f9

Update README.md

Files changed (1)
  1. README.md +6 -1
README.md CHANGED
@@ -42,8 +42,8 @@ pip install transformers==4.51.0
 ```python
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
+# Load model and tokenizer
 model_name = "infly/inf-query-aligner"
-
 model = AutoModelForCausalLM.from_pretrained(
     model_name,
     torch_dtype="auto",
@@ -51,12 +51,14 @@ model = AutoModelForCausalLM.from_pretrained(
 )
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 
+# Define input query
 prompt = "Give me a short introduction to large language model."
 messages = [
     {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
     {"role": "user", "content": prompt}
 ]
 
+# Apply chat template
 text = tokenizer.apply_chat_template(
     messages,
     tokenize=False,
@@ -64,6 +66,7 @@ text = tokenizer.apply_chat_template(
 )
 model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
 
+# Generate rewritten query
 generated_ids = model.generate(
     **model_inputs,
     max_new_tokens=512
@@ -73,6 +76,8 @@ generated_ids = [
 ]
 
 response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+print(response)
 ```
 
 ---
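
For reference, here is the snippet as it reads after this commit, assembled from the hunks above into one runnable block. The diff elides a few README lines between hunks, so three pieces below are assumptions following the standard transformers chat-generation pattern rather than quotes from the file: the `device_map="auto"` argument, the `add_generation_prompt=True` argument, and the list comprehension that trims the prompt tokens from `generated_ids`.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load model and tokenizer
model_name = "infly/inf-query-aligner"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto",  # assumption: this line is elided in the diff
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Define input query
prompt = "Give me a short introduction to large language model."
messages = [
    {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
    {"role": "user", "content": prompt}
]

# Apply chat template
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True  # assumption: this line is elided in the diff
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

# Generate rewritten query
generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=512
)
# Assumption: the diff elides these lines; this is the usual pattern for
# keeping only the newly generated tokens before decoding.
generated_ids = [
    output_ids[len(input_ids):]
    for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]

response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

print(response)
```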