Zum Inhalt springen →

Run the mistralai/Mixtral-8x7B-Instruct-v0.1 model on a Jetson Orin 64GB in 4-bit

Das Python-Skript zum Video:

Run the mistralai/Mixtral-8x7B-Instruct-v0.1 model on a Jetson Orin 64GB in 4-bit
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextStreamer, StoppingCriteria

class MyStoppingCriteria(StoppingCriteria):
    """Stop generation once ``target_sequence`` appears in the generated continuation.

    Only the first sequence of the batch (``input_ids[0]``) is inspected.

    Only the tokens that follow the prompt are decoded and searched. This matters
    because the prompt itself may contain ``target_sequence``. For example, the
    Java source in this script's prompt contains `` */`` in ``/* Numerator */``.
    Stripping the prompt from the fully decoded text with ``str.replace`` only
    works when decoding reproduces the prompt text exactly, which tokenizers do
    not guarantee. Slicing by token count avoids that and is also cheaper than
    re-decoding the whole sequence on every step.

    Args:
        target_sequence: Substring whose appearance in the continuation ends generation.
        prompt: The prompt string that was tokenized to build ``input_ids``.
        tokenizer: Tokenizer used to encode/decode. When omitted, the module-level
            ``tokenizer`` is used, so existing two-argument calls keep working.
        prompt_length: Number of prompt tokens at the start of ``input_ids``.
            When omitted, it is computed on first use as
            ``len(tokenizer(prompt)["input_ids"])``. This assumes the prompt was
            encoded with the same tokenizer and default settings, as this script
            does with ``tokenizer(prompt, return_tensors="pt")``.
    """

    def __init__(self, target_sequence, prompt, tokenizer=None, prompt_length=None):
        self.target_sequence = target_sequence
        self.prompt = prompt
        self.tokenizer = tokenizer
        self.prompt_length = prompt_length

    def _get_tokenizer(self):
        """Return the explicitly supplied tokenizer, else the module-level ``tokenizer``."""
        return self.tokenizer if self.tokenizer is not None else tokenizer

    def __call__(self, input_ids, scores, **kwargs):
        """Return True (stop) when the continuation contains ``target_sequence``."""
        tok = self._get_tokenizer()
        if self.prompt_length is None:
            # Encode the prompt once, the same way the caller built `input_ids`,
            # to learn how many leading tokens belong to the prompt.
            self.prompt_length = len(tok(self.prompt)["input_ids"])
        # Decode only the newly generated tokens; special tokens (e.g. BOS/EOS)
        # are dropped so they cannot interfere with the substring test.
        generated_text = tok.decode(
            input_ids[0][self.prompt_length:], skip_special_tokens=True
        )
        return self.target_sequence in generated_text

    # The two methods below make a bare instance look like a one-element list
    # of criteria. That lets it be passed directly as `stopping_criteria=`, where
    # transformers presumably expects a StoppingCriteriaList that it iterates over.
    def __len__(self):
        return 1

    def __iter__(self):
        yield self


# Model to load from the Hugging Face Hub, and the device the inputs are sent to.
model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Report the target hardware. The CUDA-specific queries (GPU name, allocator
# statistics) are only run when CUDA is available. Calling
# torch.cuda.get_device_name() without a usable GPU raises, which previously
# crashed the script even though `device` already fell back to "cpu".
if device == "cuda":
    print(torch.cuda.get_device_name(0))
print('Target Device: ', device)
if device == "cuda":
    print('Memory Usage:  ')
    print('Allocated:     ', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
    print('Cached:        ', round(torch.cuda.memory_reserved(0)/1024**3,1), 'GB')
print('Model:         ', model_id)

# Tokenizer matching the chosen checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_id)

# bitsandbytes 4-bit quantization settings used when loading the weights.
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # load the linear-layer weights in 4 bits
    bnb_4bit_quant_type="nf4",              # NF4 (4-bit NormalFloat) data type
    bnb_4bit_use_double_quant=True,         # also quantize the quantization constants
    bnb_4bit_compute_dtype=torch.bfloat16,  # dtype used for the actual computation
)

# Load the model with the quantization config above and show its module tree.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=nf4_config,
)
print(model)

# Prompt sent to the model: an instruction to write JavaDoc, a "#####"
# separator, and the Java `tanh` method to document.
# Note: the Java source contains " */" (e.g. in "/* Numerator */"). That is
# the same string used as the stop sequence further down, so the stopping
# criterion must not look at the prompt portion of the sequence.
# NOTE(review): the text is used as a plain completion prompt and is not
# wrapped in the Mixtral-Instruct chat format ([INST] ... [/INST]). Consider
# tokenizer.apply_chat_template if output quality is poor — confirm.
prompt = """
Create a professional and comprehensive JavaDoc documentation for the given Java method. Add method descriptions, parameter explanations, and return value details.
#####
    public static double tanh(double x) {
        boolean negate = false;

        if (Double.isNaN(x)) {
            return x;
        }

        // tanh[z] = sinh[z] / cosh[z]
        // = (exp(z) - exp(-z)) / (exp(z) + exp(-z))
        // = (exp(2x) - 1) / (exp(2x) + 1)

        // for magnitude > 20, sinh[z] == cosh[z] in double precision

        if (x > 20.0) {
            return 1.0;
        }

        if (x < -20) {
            return -1.0;
        }

        if (x == 0) {
            return x;
        }

        if (x < 0.0) {
            x = -x;
            negate = true;
        }

        double result;
        if (x >= 0.5) {
            double[] hiPrec = new double[2];
            // tanh(x) = (exp(2x) - 1) / (exp(2x) + 1)
            exp(x * 2.0, 0.0, hiPrec);

            double ya = hiPrec[0] + hiPrec[1];
            double yb = -(ya - hiPrec[0] - hiPrec[1]);

            /* Numerator */
            double na = -1.0 + ya;
            double nb = -(na + 1.0 - ya);
            double temp = na + yb;
            nb += -(temp - na - yb);
            na = temp;

            /* Denominator */
            double da = 1.0 + ya;
            double db = -(da - 1.0 - ya);
            temp = da + yb;
            db += -(temp - da - yb);
            da = temp;

            temp = da * HEX_40000000;
            double daa = da + temp - temp;
            double dab = da - daa;

            // ratio = na/da
            double ratio = na / da;
            temp = ratio * HEX_40000000;
            double ratioa = ratio + temp - temp;
            double ratiob = ratio - ratioa;

            // Correct for rounding in division
            ratiob += (na - daa * ratioa - daa * ratiob - dab * ratioa - dab * ratiob) / da;

            // Account for nb
            ratiob += nb / da;
            // Account for db
            ratiob += -db * na / da / da;

            result = ratioa + ratiob;
        } else {
            double[] hiPrec = new double[2];
            // tanh(x) = expm1(2x) / (expm1(2x) + 2)
            expm1(x * 2.0, hiPrec);

            double ya = hiPrec[0] + hiPrec[1];
            double yb = -(ya - hiPrec[0] - hiPrec[1]);

            /* Numerator */
            double na = ya;
            double nb = yb;

            /* Denominator */
            double da = 2.0 + ya;
            double db = -(da - 2.0 - ya);
            double temp = da + yb;
            db += -(temp - da - yb);
            da = temp;

            temp = da * HEX_40000000;
            double daa = da + temp - temp;
            double dab = da - daa;

            // ratio = na/da
            double ratio = na / da;
            temp = ratio * HEX_40000000;
            double ratioa = ratio + temp - temp;
            double ratiob = ratio - ratioa;

            // Correct for rounding in division
            ratiob += (na - daa * ratioa - daa * ratiob - dab * ratioa - dab * ratiob) / da;

            // Account for nb
            ratiob += nb / da;
            // Account for db
            ratiob += -db * na / da / da;

            result = ratioa + ratiob;
        }

        if (negate) {
            result = -result;
        }

        return result;
    }
"""

# Tokenize the prompt and move the tensors to the target device.
inputs = tokenizer(prompt, return_tensors="pt").to(device)

# Print decoded text to stdout as it is generated. The prompt and special
# tokens (such as the end-of-sequence marker) are not echoed.
streamer = TextStreamer(tokenizer=tokenizer, skip_prompt=True, skip_special_tokens=True)

# End generation as soon as the model closes the JavaDoc comment with " */".
stopping_criteria = MyStoppingCriteria(" */", prompt)

# Greedy decoding (do_sample=False) picks the most likely token at every
# step. Sampling at a near-zero temperature (0.01) only approximates this
# while still drawing random numbers, so greedy gives the intended output
# deterministically.
outputs = model.generate(
    **inputs,
    streamer=streamer,
    do_sample=False,
    # Reuse EOS as the padding id (presumably the tokenizer defines no pad
    # token); this avoids generate() falling back to a default with a warning.
    pad_token_id=tokenizer.eos_token_id,
    max_new_tokens=1024,
    stopping_criteria=stopping_criteria,
)