The Python script for the video:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextStreamer, StoppingCriteria
class MyStoppingCriteria(StoppingCriteria):
    def __init__(self, target_sequence, prompt):
        self.target_sequence = target_sequence
        self.prompt = prompt

    def __call__(self, input_ids, scores, **kwargs):
        # Get the generated text as a string
        generated_text = tokenizer.decode(input_ids[0])
        generated_text = generated_text.replace(self.prompt, '')
        # Check if the target sequence appears in the generated text
        if self.target_sequence in generated_text:
            return True  # Stop generation
        return False  # Continue generation

    def __len__(self):
        return 1

    def __iter__(self):
        yield self
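# Note (not part of the original script): __len__ and __iter__ above let a single
# criterion object be passed straight to model.generate(), which expects something
# it can wrap into a StoppingCriteriaList. The more common, equivalent pattern is:
#   from transformers import StoppingCriteriaList
#   stopping_criteria = StoppingCriteriaList([MyStoppingCriteria(" */", prompt)])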
model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
device = "cuda" if torch.cuda.is_available() else "cpu"
print(torch.cuda.get_device_name(0))
print('Target Device: ', device)
print('Memory Usage: ')
print('Allocated: ', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
print('Cached: ', round(torch.cuda.memory_reserved(0)/1024**3,1), 'GB')
print('Model: ', model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16
)
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=nf4_config)
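# Optional check (not in the original script): with NF4 4-bit quantization the model
# should take only a fraction of its fp16 size; transformers exposes the loaded size
# via model.get_memory_footprint(), which returns bytes.
# print('Model footprint: ', round(model.get_memory_footprint() / 1024**3, 1), 'GB')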
print(model)
prompt = """
Create a professional and comprehensive JavaDoc documentation for the given Java method. Add method descriptions, parameter explanations, and return value details.
#####
public static double tanh(double x) {
boolean negate = false;
if (Double.isNaN(x)) {
return x;
}
// tanh[z] = sinh[z] / cosh[z]
// = (exp(z) - exp(-z)) / (exp(z) + exp(-z))
// = (exp(2x) - 1) / (exp(2x) + 1)
// for magnitude > 20, sinh[z] == cosh[z] in double precision
if (x > 20.0) {
return 1.0;
}
if (x < -20) {
return -1.0;
}
if (x == 0) {
return x;
}
if (x < 0.0) {
x = -x;
negate = true;
}
double result;
if (x >= 0.5) {
double[] hiPrec = new double[2];
// tanh(x) = (exp(2x) - 1) / (exp(2x) + 1)
exp(x * 2.0, 0.0, hiPrec);
double ya = hiPrec[0] + hiPrec[1];
double yb = -(ya - hiPrec[0] - hiPrec[1]);
/* Numerator */
double na = -1.0 + ya;
double nb = -(na + 1.0 - ya);
double temp = na + yb;
nb += -(temp - na - yb);
na = temp;
/* Denominator */
double da = 1.0 + ya;
double db = -(da - 1.0 - ya);
temp = da + yb;
db += -(temp - da - yb);
da = temp;
temp = da * HEX_40000000;
double daa = da + temp - temp;
double dab = da - daa;
// ratio = na/da
double ratio = na / da;
temp = ratio * HEX_40000000;
double ratioa = ratio + temp - temp;
double ratiob = ratio - ratioa;
// Correct for rounding in division
ratiob += (na - daa * ratioa - daa * ratiob - dab * ratioa - dab * ratiob) / da;
// Account for nb
ratiob += nb / da;
// Account for db
ratiob += -db * na / da / da;
result = ratioa + ratiob;
} else {
double[] hiPrec = new double[2];
// tanh(x) = expm1(2x) / (expm1(2x) + 2)
expm1(x * 2.0, hiPrec);
double ya = hiPrec[0] + hiPrec[1];
double yb = -(ya - hiPrec[0] - hiPrec[1]);
/* Numerator */
double na = ya;
double nb = yb;
/* Denominator */
double da = 2.0 + ya;
double db = -(da - 2.0 - ya);
double temp = da + yb;
db += -(temp - da - yb);
da = temp;
temp = da * HEX_40000000;
double daa = da + temp - temp;
double dab = da - daa;
// ratio = na/da
double ratio = na / da;
temp = ratio * HEX_40000000;
double ratioa = ratio + temp - temp;
double ratiob = ratio - ratioa;
// Correct for rounding in division
ratiob += (na - daa * ratioa - daa * ratiob - dab * ratioa - dab * ratiob) / da;
// Account for nb
ratiob += nb / da;
// Account for db
ratiob += -db * na / da / da;
result = ratioa + ratiob;
}
if (negate) {
result = -result;
}
return result;
}
"""
inputs = tokenizer(prompt, return_tensors="pt").to(device)
streamer = TextStreamer(tokenizer=tokenizer, skip_prompt=True)
stopping_criteria = MyStoppingCriteria(" */", prompt)
outputs = model.generate(
    **inputs,
    streamer=streamer,
    do_sample=True,
    pad_token_id=tokenizer.eos_token_id,
    max_new_tokens=1024,
    temperature=0.01,
    # top_p=0.95,
    # top_k=50,
    stopping_criteria=stopping_criteria
)
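# Optional (not in the original script): besides the live TextStreamer output, the
# generated JavaDoc can be recovered from the returned tensor by cutting off the
# prompt tokens, e.g.:
# new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
# print(tokenizer.decode(new_tokens, skip_special_tokens=True))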